use std::ffi::CString;
use cudarc::nvrtc::{Ptx, sys};
/// Compiles CUDA C++ `src` to a CUBIN image for `arch` using NVRTC, with no extra options.
///
/// This is shorthand for [`compile_cubin_with_extra_opts`] called with an empty option list.
///
/// # Errors
///
/// Propagates every error string produced by [`compile_cubin_with_extra_opts`].
pub fn compile_cubin(src: &str, arch: &str) -> Result<Ptx, String> {
    let no_extra_opts: &[&str] = &[];
    compile_cubin_with_extra_opts(src, arch, no_extra_opts)
}
/// Compiles CUDA C++ `src` to a CUBIN image for `arch` via NVRTC, appending `extra_opts`.
///
/// The options handed to NVRTC are `--gpu-architecture={arch}` followed by every entry of
/// `extra_opts`, in order. The NVRTC program is always destroyed before returning, whether
/// compilation succeeds or fails.
///
/// # Errors
///
/// Returns a descriptive `String` when `src`, the architecture option, or any extra option
/// contains an interior NUL byte. These are checked in that order. An error is also returned
/// when `nvrtcCreateProgram` fails, or when compiling or fetching the CUBIN fails
/// (see `compile_and_fetch`).
#[expect(unsafe_code, reason = "NVRTC C API requires raw program handles")]
pub fn compile_cubin_with_extra_opts(
    src: &str,
    arch: &str,
    extra_opts: &[&str],
) -> Result<Ptx, String> {
    let c_src = CString::new(src).map_err(|e| format!("kernel source contains NUL: {e}"))?;

    let arch_opt = CString::new(format!("--gpu-architecture={arch}"))
        .map_err(|e| format!("arch option contains NUL: {e}"))?;
    // The architecture flag comes first; the first bad extra option short-circuits.
    let option_strings = std::iter::once(Ok(arch_opt))
        .chain(extra_opts.iter().map(|&opt| {
            CString::new(opt).map_err(|e| format!("option {opt:?} contains NUL: {e}"))
        }))
        .collect::<Result<Vec<CString>, String>>()?;
    // These raw pointers borrow from `option_strings`, which outlives every NVRTC call below.
    let option_ptrs: Vec<*const ::core::ffi::c_char> =
        option_strings.iter().map(|s| s.as_ptr()).collect();

    let mut prog: sys::nvrtcProgram = std::ptr::null_mut();
    // SAFETY: `prog` is a valid out-pointer. `c_src` is NUL-terminated and alive for the call.
    // The program name is null and there are zero headers, so the header arrays are null.
    let created = unsafe {
        sys::nvrtcCreateProgram(
            std::ptr::addr_of_mut!(prog),
            c_src.as_ptr(),
            std::ptr::null(),
            0,
            std::ptr::null(),
            std::ptr::null(),
        )
    };
    if created != sys::nvrtcResult::NVRTC_SUCCESS {
        return Err(format!("nvrtcCreateProgram failed: {created:?}"));
    }

    // SAFETY: `prog` was created successfully above. Every pointer in `option_ptrs` refers to
    // a NUL-terminated string owned by `option_strings`.
    let outcome = unsafe { compile_and_fetch(prog, arch, &option_ptrs) };
    // SAFETY: `prog` is a live handle and is destroyed exactly once, here. The destroy status
    // is intentionally ignored; the compile outcome is what gets reported.
    let _ = unsafe { sys::nvrtcDestroyProgram(std::ptr::addr_of_mut!(prog)) };
    outcome
}
/// Compiles an already-created NVRTC program with `option_ptrs` and copies out its CUBIN.
///
/// `arch` is only used to build the error message when NVRTC reports a zero-byte CUBIN.
///
/// # Safety
///
/// `prog` must be a live handle from `nvrtcCreateProgram` that has not been destroyed.
/// Every pointer in `option_ptrs` must point to a NUL-terminated string that stays valid
/// for the duration of this call.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the option count does not fit in a C `int`;
/// - `nvrtcCompileProgram` fails (the program log is appended when readable);
/// - querying or copying the CUBIN fails;
/// - the reported CUBIN size is zero.
#[expect(unsafe_code, reason = "NVRTC C API requires raw program handles")]
unsafe fn compile_and_fetch(
    prog: sys::nvrtcProgram,
    arch: &str,
    option_ptrs: &[*const ::core::ffi::c_char],
) -> Result<Ptx, String> {
    let n_opts = option_ptrs.len();
    let option_count = ::core::ffi::c_int::try_from(n_opts)
        .map_err(|_| format!("too many NVRTC options: {n_opts}"))?;

    // SAFETY: the caller guarantees `prog` and the option pointers are valid, and
    // `option_count` equals the length of `option_ptrs`.
    let compiled = unsafe { sys::nvrtcCompileProgram(prog, option_count, option_ptrs.as_ptr()) };
    if compiled != sys::nvrtcResult::NVRTC_SUCCESS {
        // SAFETY: `prog` is still live. The log is best-effort and becomes empty if unreadable.
        let log = unsafe { program_log(prog) }.unwrap_or_default();
        return Err(format!("nvrtcCompileProgram failed ({compiled:?}): {log}"));
    }

    let mut cubin_len: usize = 0;
    // SAFETY: `prog` is live and `cubin_len` is a valid out-pointer.
    let sized = unsafe { sys::nvrtcGetCUBINSize(prog, std::ptr::addr_of_mut!(cubin_len)) };
    if sized != sys::nvrtcResult::NVRTC_SUCCESS {
        return Err(format!("nvrtcGetCUBINSize failed: {sized:?}"));
    }
    // A zero-byte CUBIN is reported as an error; the message attributes it to a virtual
    // `compute_XX` target having been requested instead of a real `sm_XX` one.
    if cubin_len == 0 {
        return Err(format!(
            "nvrtcGetCUBIN returned empty — arch {arch:?} must be a real sm_XX target, not virtual compute_XX"
        ));
    }

    let mut cubin = vec![0u8; cubin_len];
    // SAFETY: `cubin` is exactly as large as the size NVRTC reported for this program's CUBIN.
    let copied =
        unsafe { sys::nvrtcGetCUBIN(prog, cubin.as_mut_ptr().cast::<::core::ffi::c_char>()) };
    if copied == sys::nvrtcResult::NVRTC_SUCCESS {
        Ok(Ptx::from_binary(cubin))
    } else {
        Err(format!("nvrtcGetCUBIN failed: {copied:?}"))
    }
}
/// Reads the NVRTC compilation log of `prog` as a `String`.
///
/// Returns `None` only when one of the NVRTC log queries fails. Returns `Some(String::new())`
/// when the reported log size is at most one byte (nothing beyond a terminator).
///
/// Otherwise the log buffer is treated as a C string and cut at its first NUL byte. Any bytes
/// that are not valid UTF-8 are replaced with U+FFFD, so a single stray byte no longer
/// discards the whole diagnostic.
///
/// # Safety
///
/// `prog` must be a live NVRTC program handle that has not been destroyed.
#[expect(unsafe_code, reason = "NVRTC C API requires raw program handles")]
unsafe fn program_log(prog: sys::nvrtcProgram) -> Option<String> {
    let mut size: usize = 0;
    // SAFETY: the caller guarantees `prog` is live; `size` is a valid out-pointer.
    let status = unsafe { sys::nvrtcGetProgramLogSize(prog, std::ptr::addr_of_mut!(size)) };
    if status != sys::nvrtcResult::NVRTC_SUCCESS {
        return None;
    }
    if size <= 1 {
        return Some(String::new());
    }

    let mut buf: Vec<u8> = vec![0u8; size];
    // SAFETY: `buf` is exactly as large as the log size NVRTC just reported for `prog`.
    let status =
        unsafe { sys::nvrtcGetProgramLog(prog, buf.as_mut_ptr().cast::<::core::ffi::c_char>()) };
    if status != sys::nvrtcResult::NVRTC_SUCCESS {
        return None;
    }

    // C-string semantics: keep only the bytes before the first NUL.
    if let Some(nul) = buf.iter().position(|&b| b == 0) {
        buf.truncate(nul);
    }
    // Decode lossily instead of returning `None` on invalid UTF-8. The log is the only
    // explanation of why compilation failed, so dropping it entirely is worse than a
    // few replacement characters. Valid UTF-8 is reused without copying.
    Some(String::from_utf8(buf).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned()))
}