use super::sys::{self};
use core::{
ffi::{c_char, c_int, CStr},
mem::MaybeUninit,
};
use std::{ffi::CString, vec::Vec};
/// Error type returned by the NVRTC wrappers in this module.
///
/// It is a thin newtype around the raw [`sys::nvrtcResult`] status code, which
/// stays accessible through the public tuple field. `sys::nvrtcResult::result`
/// builds it from every status other than `NVRTC_SUCCESS`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NvrtcError(pub sys::nvrtcResult);
impl sys::nvrtcResult {
    /// Converts a raw NVRTC status code into a [`Result`].
    ///
    /// `NVRTC_SUCCESS` becomes `Ok(())`; every other status is wrapped,
    /// unchanged, in an [`NvrtcError`].
    pub fn result(self) -> Result<(), NvrtcError> {
        if matches!(self, sys::nvrtcResult::NVRTC_SUCCESS) {
            return Ok(());
        }
        Err(NvrtcError(self))
    }
}
#[cfg(feature = "std")]
impl std::fmt::Display for NvrtcError {
    /// Writes the error's `Debug` representation; there is no separate
    /// human-oriented message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{self:?}"))
    }
}
/// Lets [`NvrtcError`] be used as a standard error type when the `std` feature
/// is enabled. The empty body keeps every default method of the trait.
#[cfg(feature = "std")]
impl std::error::Error for NvrtcError {}
/// Creates an NVRTC program from the source text `src`.
///
/// `name` is forwarded to NVRTC when given; `None` is passed as a null
/// pointer. The remaining arguments of `nvrtcCreateProgram` are always
/// `0`, null, null (per the NVRTC API these are the header count, header
/// sources and include names), so no headers are supplied.
///
/// # Errors
///
/// Returns an [`NvrtcError`] carrying the status code whenever
/// `nvrtcCreateProgram` reports anything other than `NVRTC_SUCCESS`.
pub fn create_program(src: &CStr, name: Option<&CStr>) -> Result<sys::nvrtcProgram, NvrtcError> {
    let name_ptr = name.map_or(std::ptr::null(), CStr::as_ptr);
    let mut handle = MaybeUninit::uninit();

    // SAFETY: `src` and `name` are `CStr`s, hence valid NUL-terminated strings
    // for the duration of the call, and `handle` points to writable storage.
    let status = unsafe {
        sys::nvrtcCreateProgram(
            handle.as_mut_ptr(),
            src.as_ptr(),
            name_ptr,
            0,
            std::ptr::null(),
            std::ptr::null(),
        )
    };
    status.result()?;

    // SAFETY: only reached after NVRTC reported success, in which case it is
    // expected to have written the program handle through the out-pointer.
    Ok(unsafe { handle.assume_init() })
}
pub unsafe fn compile_program<O: Clone + Into<Vec<u8>>>(
prog: sys::nvrtcProgram,
options: &[O],
) -> Result<(), NvrtcError> {
let c_strings: Vec<CString> = options
.iter()
.cloned()
.map(|o| CString::new(o).unwrap())
.collect();
let c_strs: Vec<&CStr> = c_strings.iter().map(CString::as_c_str).collect();
let opts: Vec<*const c_char> = c_strs.iter().cloned().map(CStr::as_ptr).collect();
sys::nvrtcCompileProgram(prog, opts.len() as c_int, opts.as_ptr()).result()
}
/// Destroys the NVRTC program `prog`.
///
/// `nvrtcDestroyProgram` takes a *mutable* pointer to the handle, so the
/// handle is first copied into a mutable local. Casting a shared reference to
/// `*mut` would make any write performed by the C side undefined behaviour.
/// The caller's copy of the handle is not modified.
///
/// # Errors
///
/// Returns an [`NvrtcError`] carrying the status code whenever
/// `nvrtcDestroyProgram` reports anything other than `NVRTC_SUCCESS`.
///
/// # Safety
///
/// `prog` must be a live program handle, e.g. one obtained from
/// [`create_program`], and must not be used again after this call.
pub unsafe fn destroy_program(prog: sys::nvrtcProgram) -> Result<(), NvrtcError> {
    let mut handle = prog;
    // SAFETY: `handle` is a writable local that lives for the whole call;
    // validity of the handle value is the caller's obligation (see `# Safety`).
    unsafe { sys::nvrtcDestroyProgram(&mut handle) }.result()
}
/// Fetches the PTX that NVRTC holds for `prog`.
///
/// Works in two steps: `nvrtcGetPTXSize` reports the required length, then a
/// zero-filled buffer of exactly that length is handed to `nvrtcGetPTX` to be
/// filled in. Per the NVRTC documentation the reported size includes the
/// trailing NUL, so the returned vector is expected to end with one.
///
/// # Errors
///
/// Returns an [`NvrtcError`] if either NVRTC call reports a non-success
/// status.
///
/// # Safety
///
/// `prog` must be a live program handle, e.g. one obtained from
/// [`create_program`] that has not yet been passed to [`destroy_program`].
#[allow(clippy::slow_vector_initialization)]
pub unsafe fn get_ptx(prog: sys::nvrtcProgram) -> Result<Vec<c_char>, NvrtcError> {
    let mut len = 0usize;
    // SAFETY: `len` is a valid, writable out-parameter; validity of `prog` is
    // the caller's obligation.
    unsafe { sys::nvrtcGetPTXSize(prog, &mut len) }.result()?;

    let mut buf = Vec::<c_char>::with_capacity(len);
    buf.resize(len, 0);
    // SAFETY: `buf` owns `len` initialised elements, the size NVRTC asked for.
    unsafe { sys::nvrtcGetPTX(prog, buf.as_mut_ptr()) }.result()?;

    Ok(buf)
}
/// Fetches the log that NVRTC holds for `prog`.
///
/// Works in two steps: `nvrtcGetProgramLogSize` reports the required length,
/// then a zero-filled buffer of exactly that length is handed to
/// `nvrtcGetProgramLog` to be filled in.
///
/// # Errors
///
/// Returns an [`NvrtcError`] if either NVRTC call reports a non-success
/// status.
///
/// # Safety
///
/// `prog` must be a live program handle, e.g. one obtained from
/// [`create_program`] that has not yet been passed to [`destroy_program`].
#[allow(clippy::slow_vector_initialization)]
pub unsafe fn get_program_log(prog: sys::nvrtcProgram) -> Result<Vec<c_char>, NvrtcError> {
    let mut len = 0usize;
    // SAFETY: `len` is a valid, writable out-parameter; validity of `prog` is
    // the caller's obligation.
    unsafe { sys::nvrtcGetProgramLogSize(prog, &mut len) }.result()?;

    let mut buf = Vec::<c_char>::with_capacity(len);
    buf.resize(len, 0);
    // SAFETY: `buf` owns `len` initialised elements, the size NVRTC asked for.
    unsafe { sys::nvrtcGetProgramLog(prog, buf.as_mut_ptr()) }.result()?;

    Ok(buf)
}
/// Tests that drive the wrappers end to end: create a program, compile it,
/// read back its outputs, and destroy it.
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiling with an empty option list succeeds.
    #[test]
    fn test_compile_program_no_opts() {
        let prog = create_program(c"extern \"C\" __global__ void kernel() { }", None).unwrap();
        unsafe { compile_program::<&str>(prog, &[]) }.unwrap();
        unsafe { destroy_program(prog) }.unwrap();
    }

    /// A single option is forwarded to the compiler.
    #[test]
    fn test_compile_program_1_opt() {
        let prog = create_program(c"extern \"C\" __global__ void kernel() { }", None).unwrap();
        unsafe { compile_program(prog, &["--ftz=true"]) }.unwrap();
        unsafe { destroy_program(prog) }.unwrap();
    }

    /// Several options are forwarded to the compiler.
    #[test]
    fn test_compile_program_2_opt() {
        let prog = create_program(c"extern \"C\" __global__ void kernel() { }", None).unwrap();
        unsafe { compile_program(prog, &["--ftz=true", "--fmad=true"]) }.unwrap();
        unsafe { destroy_program(prog) }.unwrap();
    }

    /// Malformed source is reported as `NVRTC_ERROR_COMPILATION`.
    #[test]
    fn test_compile_bad_program() {
        let prog = create_program(c"extern \"C\" __global__ void kernel(", None).unwrap();
        let outcome = unsafe { compile_program::<&str>(prog, &[]) };
        // Release the handle before asserting so it is not leaked, even if
        // the assertion below fails.
        unsafe { destroy_program(prog) }.unwrap();
        assert_eq!(
            outcome.unwrap_err(),
            NvrtcError(sys::nvrtcResult::NVRTC_ERROR_COMPILATION)
        );
    }

    /// After a successful compile both the PTX and the log are non-empty.
    #[test]
    fn test_get_ptx() {
        const SRC: &CStr =
            c"extern \"C\" __global__ void sin_kernel(float *out, const float *inp, int numel) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}";
        let prog = create_program(SRC, None).unwrap();
        unsafe { compile_program::<&str>(prog, &[]) }.unwrap();
        let ptx = unsafe { get_ptx(prog) }.unwrap();
        assert!(!ptx.is_empty());
        let log = unsafe { get_program_log(prog) }.unwrap();
        assert!(!log.is_empty());
        unsafe { destroy_program(prog) }.unwrap();
    }
}