use std::collections::HashMap;
use std::ffi::CString;
use std::ptr;
use std::ffi::c_void;
use std::os::raw::c_uint;
use super::context::{get_driver, CudaContext};
use super::ptx_cache::{load_cached_cubin, ptx_cache_dir, ptx_cache_key, save_cached_cubin};
use super::ptx_patch::patch_backward_branches_sm121;
use super::sys::{
CUfunction, CUmodule, CudaDriver, CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
CU_JIT_INPUT_PTX, CU_JIT_TARGET,
};
use crate::GpuError;
// JIT option identifiers used by `from_ptx_legacy` when calling
// `cuModuleLoadDataEx`. NOTE(review): these two are defined locally while the
// error-log counterparts are imported from `super::sys`; the numeric values are
// expected to match the CUDA driver's `CUjit_option` enum — confirm against cuda.h.

/// Option id selecting the buffer that receives the JIT informational log.
const CU_JIT_INFO_LOG_BUFFER: c_uint = 3;
/// Option id carrying the size, in bytes, of the informational log buffer.
const CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES: c_uint = 4;
/// Asks the CUDA driver for its version number via `cuDriverGetVersion`.
///
/// Returns the value written by the driver, or `0` when the call reports an
/// error, so the caller always receives a usable integer.
fn query_driver_version(driver: &CudaDriver) -> i32 {
    let mut reported: i32 = 0;
    // SAFETY: `reported` is a live, writable `i32` for the duration of the call.
    let status = unsafe { (driver.cuDriverGetVersion)(&mut reported) };
    match CudaDriver::check(status) {
        Ok(_) => reported,
        Err(_) => 0,
    }
}
/// Compiles PTX text into a cubin image with the CUDA driver linker
/// (`cuLinkCreate` -> `cuLinkAddData` -> `cuLinkComplete`).
///
/// # Arguments
/// * `driver` - loaded driver function table.
/// * `ptx` - PTX source; must not contain interior NUL bytes.
/// * `jit_target` - value supplied for the `CU_JIT_TARGET` linker option.
///
/// # Errors
/// Returns `GpuError::ModuleLoad` when the PTX contains a NUL byte, when any
/// linker call fails, or when the linker reports success but yields no image.
fn compile_ptx_to_cubin(
    driver: &CudaDriver,
    ptx: &str,
    jit_target: c_uint,
) -> Result<Vec<u8>, GpuError> {
    use super::sys::CUlinkState;

    /// Destroys the wrapped link state on drop, so every exit path after a
    /// successful `cuLinkCreate` releases it exactly once.
    struct LinkGuard<'a> {
        driver: &'a CudaDriver,
        state: CUlinkState,
    }

    impl Drop for LinkGuard<'_> {
        fn drop(&mut self) {
            // SAFETY: `state` was produced by a successful `cuLinkCreate` and
            // is not destroyed anywhere else.
            let _ = unsafe { (self.driver.cuLinkDestroy)(self.state) };
        }
    }

    // Convert the inputs before touching the driver: previously a NUL byte in
    // the PTX caused an early return *after* `cuLinkCreate`, leaking the state.
    let ptx_cstring = CString::new(ptx)
        .map_err(|_| GpuError::ModuleLoad("PTX contains null bytes".to_string()))?;
    let ptx_name = CString::new("kernel.ptx").expect("static string has no null bytes");

    let mut link_state: CUlinkState = ptr::null_mut();
    let mut opt_keys: [c_uint; 1] = [CU_JIT_TARGET];
    // The target is passed by value, encoded in the pointer-sized option slot.
    let mut opt_vals: [*mut c_void; 1] = [jit_target as *mut c_void];
    // SAFETY: both option arrays hold exactly one element, matching the count
    // passed, and `link_state` is a valid out-pointer.
    let result = unsafe {
        (driver.cuLinkCreate)(
            1,
            opt_keys.as_mut_ptr(),
            opt_vals.as_mut_ptr(),
            &mut link_state,
        )
    };
    CudaDriver::check(result)
        .map_err(|e| GpuError::ModuleLoad(format!("cuLinkCreate failed: {e}")))?;
    let link = LinkGuard {
        driver,
        state: link_state,
    };

    // SAFETY: `ptx_cstring` and `ptx_name` are NUL-terminated and outlive the
    // link state (they are declared before `link`, so they drop after it).
    let result = unsafe {
        (driver.cuLinkAddData)(
            link.state,
            CU_JIT_INPUT_PTX,
            ptx_cstring.as_ptr() as *mut c_void,
            ptx_cstring.as_bytes_with_nul().len(),
            ptx_name.as_ptr(),
            0,
            ptr::null_mut(),
            ptr::null_mut(),
        )
    };
    if CudaDriver::check(result).is_err() {
        return Err(GpuError::ModuleLoad(format!(
            "cuLinkAddData failed: result={result}"
        )));
    }

    let mut cubin_ptr: *mut c_void = ptr::null_mut();
    let mut cubin_size: usize = 0;
    // SAFETY: both out-pointers reference live locals.
    let result = unsafe { (driver.cuLinkComplete)(link.state, &mut cubin_ptr, &mut cubin_size) };
    if CudaDriver::check(result).is_err() {
        return Err(GpuError::ModuleLoad(format!(
            "cuLinkComplete failed: result={result}"
        )));
    }
    if cubin_ptr.is_null() || cubin_size == 0 {
        return Err(GpuError::ModuleLoad(
            "cuLinkComplete returned null cubin".to_string(),
        ));
    }

    // Copy the image out while the link state is still alive; the guard
    // destroys the state only when this function returns.
    // SAFETY: the driver reported `cubin_size` readable bytes at `cubin_ptr`,
    // and the pointer was checked to be non-null above.
    let cubin = unsafe { std::slice::from_raw_parts(cubin_ptr as *const u8, cubin_size) }.to_vec();
    Ok(cubin)
}
/// A loaded CUDA module together with a cache of function handles that have
/// already been looked up in it.
///
/// The module handle is passed to `cuModuleUnload` when the value is dropped.
pub struct CudaModule {
/// Raw driver handle of the loaded module.
module: CUmodule,
/// Function handles resolved so far by `get_function`, keyed by function name.
functions: HashMap<String, CUfunction>,
}
// SAFETY: NOTE(review) — these impls assert that the raw `CUmodule` /
// `CUfunction` handles may be moved to and shared between threads. The only
// mutation of the function cache goes through `&mut self` (`get_function`), so
// shared references never write to it; that the driver handles themselves are
// not thread-affine is an assumption to confirm against the CUDA driver docs.
unsafe impl Send for CudaModule {}
unsafe impl Sync for CudaModule {}
impl CudaModule {
/// Loads a module straight from PTX with `cuModuleLoadData`, without the
/// cubin cache or linker step used by `from_ptx`.
///
/// When the device's compute-capability major version is 12 or higher the
/// PTX is first passed through `patch_backward_branches_sm121`; if that
/// returns `None` the input is used unchanged.
///
/// # Errors
/// Propagates errors from obtaining the driver, making `ctx` current and
/// querying the compute capability. Returns `GpuError::ModuleLoad` when the
/// PTX contains a NUL byte or the driver rejects it; in the latter case the
/// PTX is written, best effort, to `/tmp/ptx-fail-<pid>.ptx`.
pub fn from_ptx_direct(ctx: &CudaContext, ptx: &str) -> Result<Self, GpuError> {
    let driver = get_driver()?;
    ctx.make_current()?;
    let (major, _) = ctx.compute_capability()?;
    let ptx_patched = if major >= 12 {
        patch_backward_branches_sm121(ptx)
    } else {
        None
    };
    let ptx = ptx_patched.as_deref().unwrap_or(ptx);
    let ptx_cstring = CString::new(ptx)
        .map_err(|_| GpuError::ModuleLoad("PTX contains null bytes".to_string()))?;
    let mut module: CUmodule = ptr::null_mut();
    // SAFETY: `module` is a valid out-pointer and `ptx_cstring` is a
    // NUL-terminated buffer that lives across the call.
    let result =
        unsafe { (driver.cuModuleLoadData)(&mut module, ptx_cstring.as_ptr() as *const _) };
    CudaDriver::check(result).map_err(|e| {
        // First line mentioning `.entry`, trimmed like `from_ptx_legacy` does.
        let kernel_name = ptx
            .lines()
            .find(|l| l.contains(".entry"))
            .map(|l| l.trim())
            .unwrap_or("unknown");
        let dump_path = format!("/tmp/ptx-fail-{}.ptx", std::process::id());
        // Only claim the dump exists when the write actually succeeded.
        match std::fs::write(&dump_path, ptx) {
            Ok(()) => eprintln!(
                "[PTX-FAIL] Invalid PTX dumped to {dump_path} ({} bytes)",
                ptx.len()
            ),
            Err(io_err) => eprintln!(
                "[PTX-FAIL] Could not dump invalid PTX to {dump_path}: {io_err}"
            ),
        }
        GpuError::ModuleLoad(format!(
            "cuModuleLoadData failed: result={result} (kernel: {kernel_name}), error: {e}"
        ))
    })?;
    Ok(Self {
        module,
        functions: HashMap::new(),
    })
}
}
impl CudaModule {
/// Loads a module from PTX, preferring a cached cubin, then a fresh linker
/// compilation, and finally the legacy JIT path.
///
/// Steps, in order:
/// 1. Optionally patch the PTX (`patch_backward_branches_sm121`) when the
///    compute-capability major version is 12 or higher.
/// 2. Look up a cubin under a key built from the PTX, the JIT target and the
///    driver version; if one is found and loads, return it. A cached cubin
///    that fails to load is deleted from the cache directory.
/// 3. Compile with `compile_ptx_to_cubin`; if the result loads, store it in
///    the cache and return it.
/// 4. Otherwise fall back to `from_ptx_legacy`.
///
/// # Errors
/// Propagates errors from obtaining the driver, making `ctx` current and
/// querying the compute capability, plus whatever `from_ptx_legacy` returns
/// when every earlier path failed.
pub fn from_ptx(ctx: &CudaContext, ptx: &str) -> Result<Self, GpuError> {
    let driver = get_driver()?;
    ctx.make_current()?;
    let (major, minor) = ctx.compute_capability()?;
    // e.g. major 8, minor 6 -> 86.
    let jit_target: c_uint = (major * 10 + minor) as c_uint;
    let ptx_patched = if major >= 12 {
        patch_backward_branches_sm121(ptx)
    } else {
        None
    };
    let ptx: &str = ptx_patched.as_deref().unwrap_or(ptx);
    let driver_version = query_driver_version(driver);
    let cache_key = ptx_cache_key(ptx, jit_target, driver_version);

    if let Some(cubin) = load_cached_cubin(&cache_key) {
        let mut module: CUmodule = ptr::null_mut();
        // SAFETY: `module` is a valid out-pointer and `cubin` stays alive
        // across the call.
        let result =
            unsafe { (driver.cuModuleLoadData)(&mut module, cubin.as_ptr() as *const c_void) };
        if CudaDriver::check(result).is_ok() {
            return Ok(Self {
                module,
                functions: HashMap::new(),
            });
        }
        // The cached image is unusable: drop it so it is not retried forever.
        if let Some(dir) = ptx_cache_dir() {
            let _ = std::fs::remove_file(dir.join(format!("{cache_key}.cubin")));
        }
        eprintln!(
            "[PTX-CACHE] Cache hit but cuModuleLoadData failed (result={result}), \
             falling through to JIT compilation"
        );
    }

    match compile_ptx_to_cubin(driver, ptx, jit_target) {
        Ok(cubin) => {
            let mut module: CUmodule = ptr::null_mut();
            // SAFETY: as above; `cubin` outlives the call.
            let result = unsafe {
                (driver.cuModuleLoadData)(&mut module, cubin.as_ptr() as *const c_void)
            };
            if CudaDriver::check(result).is_ok() {
                // Persist only images the driver actually accepted. Saving
                // before the load check would cache a cubin that the next
                // run loads, rejects, deletes and recompiles again.
                save_cached_cubin(&cache_key, &cubin);
                return Ok(Self {
                    module,
                    functions: HashMap::new(),
                });
            }
            eprintln!(
                "[PTX-CACHE] Linker produced cubin but cuModuleLoadData failed (result={result}), \
                 falling through to legacy JIT"
            );
        }
        Err(e) => {
            eprintln!("[PTX-CACHE] Linker compilation failed: {e}, falling through to legacy JIT");
        }
    }

    Self::from_ptx_legacy(driver, ptx, jit_target, major, minor)
}
/// Fallback loader: JIT-compiles the PTX with `cuModuleLoadDataEx` for an
/// explicit target and, if that fails, retries with plain `cuModuleLoadData`.
///
/// When the first attempt fails, the kernel entry line, the JIT info/error
/// logs and the result code are printed to stderr, and the PTX is written to
/// a file under `/tmp` before the retry.
///
/// # Errors
/// Returns `GpuError::ModuleLoad` when the PTX contains a NUL byte or when
/// both load attempts fail (the message carries both result codes).
fn from_ptx_legacy(
    driver: &CudaDriver,
    ptx: &str,
    jit_target: c_uint,
    major: i32,
    minor: i32,
) -> Result<Self, GpuError> {
    const LOG_CAPACITY: usize = 4096;

    let source = CString::new(ptx.as_bytes().to_vec())
        .map_err(|_| GpuError::ModuleLoad("PTX contains null bytes".to_string()))?;

    // Zero-filled so that unused tail bytes can be stripped as NULs later.
    let mut info_buf = vec![0u8; LOG_CAPACITY];
    let mut error_buf = vec![0u8; LOG_CAPACITY];
    let info_capacity: usize = info_buf.len();
    let error_capacity: usize = error_buf.len();

    // Keys and values are parallel arrays: entry `i` of one pairs with entry
    // `i` of the other. Scalar values travel in the pointer-sized slot.
    let mut jit_keys: [c_uint; 5] = [
        CU_JIT_TARGET,
        CU_JIT_INFO_LOG_BUFFER,
        CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
        CU_JIT_ERROR_LOG_BUFFER,
        CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
    ];
    let mut jit_values: [*mut c_void; 5] = [
        jit_target as *mut c_void,
        info_buf.as_mut_ptr() as *mut c_void,
        info_capacity as *mut c_void,
        error_buf.as_mut_ptr() as *mut c_void,
        error_capacity as *mut c_void,
    ];

    // Attempt 1: explicit JIT target with log capture.
    let mut targeted: CUmodule = ptr::null_mut();
    let result = unsafe {
        (driver.cuModuleLoadDataEx)(
            &mut targeted,
            source.as_ptr() as *const _,
            5,
            jit_keys.as_mut_ptr(),
            jit_values.as_mut_ptr(),
        )
    };
    if CudaDriver::check(result).is_ok() {
        return Ok(Self {
            module: targeted,
            functions: HashMap::new(),
        });
    }

    // Diagnostics for the failed first attempt.
    let kernel_name = match ptx.lines().find(|line| line.contains(".entry")) {
        Some(line) => line.trim(),
        None => "<unknown>",
    };
    let decode_log = |raw: &[u8]| -> String {
        String::from_utf8_lossy(raw)
            .trim_end_matches('\0')
            .to_string()
    };
    let jit_info = decode_log(&info_buf[..]);
    let jit_err = decode_log(&error_buf[..]);

    eprintln!(
        "[PTX-JIT] Try 1 failed: {kernel_name}, target: sm_{major}{minor}, \
         PTX: {} bytes, result: {result}",
        ptx.len()
    );
    if !jit_info.is_empty() {
        eprintln!("[PTX-JIT] Info log: {jit_info}");
    }
    if !jit_err.is_empty() {
        eprintln!("[PTX-JIT] Error log: {jit_err}");
    }

    // Anything other than alphanumerics and `_` becomes `_` in the file name.
    let file_stem = kernel_name.replace(|c: char| !(c.is_alphanumeric() || c == '_'), "_");
    let dump_path = format!(
        "/tmp/failed-ptx-sm_{major}{minor}-{}.ptx",
        file_stem
    );
    if std::fs::write(&dump_path, ptx).is_ok() {
        eprintln!("[PTX-JIT] PTX dumped to {dump_path}");
    }

    // Attempt 2: let the driver pick the target.
    eprintln!("[PTX-JIT] Retrying with cuModuleLoadData (no explicit target)...");
    let mut untargeted: CUmodule = ptr::null_mut();
    let result2 =
        unsafe { (driver.cuModuleLoadData)(&mut untargeted, source.as_ptr() as *const _) };
    if CudaDriver::check(result2).is_ok() {
        eprintln!("[PTX-JIT] Fallback succeeded for {kernel_name}");
        return Ok(Self {
            module: untargeted,
            functions: HashMap::new(),
        });
    }

    eprintln!("[PTX-JIT] Both attempts failed for {kernel_name}");
    Err(GpuError::ModuleLoad(format!(
        "CUDA module loading failed: try1={result} try2={result2} (JIT target: sm_{major}{minor})"
    )))
}
/// Returns the handle of the function called `name` in this module.
///
/// Handles are memoised: the first successful lookup goes through
/// `cuModuleGetFunction` and is stored, later calls are served from the map.
///
/// # Errors
/// Propagates the error from `get_driver`. Returns
/// `GpuError::FunctionNotFound(name)` when `name` contains a NUL byte or the
/// driver lookup fails.
pub fn get_function(&mut self, name: &str) -> Result<CUfunction, GpuError> {
    if let Some(cached) = self.functions.get(name).copied() {
        return Ok(cached);
    }

    let not_found = || GpuError::FunctionNotFound(name.to_string());

    let driver = get_driver()?;
    let symbol = CString::new(name).map_err(|_| not_found())?;
    let mut handle: CUfunction = ptr::null_mut();
    let status =
        unsafe { (driver.cuModuleGetFunction)(&mut handle, self.module, symbol.as_ptr()) };
    if CudaDriver::check(status).is_err() {
        return Err(not_found());
    }

    self.functions.insert(name.to_string(), handle);
    Ok(handle)
}
/// Returns the raw driver module handle held by this wrapper.
///
/// The handle is copied out; this wrapper still passes it to
/// `cuModuleUnload` when dropped.
#[must_use]
pub fn raw(&self) -> CUmodule {
self.module
}
/// Reports whether `get_function(name)` succeeds.
///
/// Takes `&mut self` because a successful lookup is stored in the function
/// cache. Any error from `get_function` yields `false`.
pub fn has_function(&mut self, name: &str) -> bool {
self.get_function(name).is_ok()
}
/// Lists the names of the functions currently held in the lookup cache.
///
/// Only names that `get_function` has already resolved appear here; the
/// order follows the underlying `HashMap` iteration and is unspecified.
#[must_use]
pub fn cached_functions(&self) -> Vec<&str> {
    let mut names: Vec<&str> = Vec::with_capacity(self.functions.len());
    for key in self.functions.keys() {
        names.push(key.as_str());
    }
    names
}
}
/// Unloads the module when the wrapper goes out of scope.
impl Drop for CudaModule {
    fn drop(&mut self) {
        // Without a driver there is nothing to call; skip silently.
        let driver = match get_driver() {
            Ok(driver) => driver,
            Err(_) => return,
        };
        // The unload status is deliberately discarded: `drop` cannot report it.
        let _ = unsafe { (driver.cuModuleUnload)(self.module) };
    }
}
#[cfg(test)]
mod tests {
// Placeholder test that is only compiled when the `cuda` feature is off.
// It asserts nothing about the module loader itself.
#[test]
#[cfg(not(feature = "cuda"))]
fn test_module_requires_cuda_feature() {
assert!(true);
}
}