use std::collections::HashMap;
use std::ffi::CString;
use std::ptr;
use std::ffi::c_void;
use std::os::raw::c_uint;
use super::context::{get_driver, CudaContext};
use super::sys::{
CUfunction, CUmodule, CudaDriver, CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
CU_JIT_TARGET,
};
use crate::GpuError;
/// `CUjit_option` id of `CU_JIT_INFO_LOG_BUFFER` from `cuda.h`.
///
/// Its option value is a pointer to a buffer that receives informational JIT log messages.
/// Declared locally, presumably because `super::sys` only exports the error-log options.
const CU_JIT_INFO_LOG_BUFFER: c_uint = 3;
/// `CUjit_option` id of `CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES` from `cuda.h`.
///
/// Its option value is the size in bytes of the buffer passed for `CU_JIT_INFO_LOG_BUFFER`.
const CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES: c_uint = 4;
/// A CUDA driver module (`CUmodule`) loaded from PTX, together with a cache of kernel
/// handles that have been looked up from it by name.
///
/// The module handle is released with `cuModuleUnload` when this value is dropped.
pub struct CudaModule {
    // Raw driver handle produced by a successful load in `from_ptx`.
    module: CUmodule,
    // Kernel handles already resolved via `cuModuleGetFunction`, keyed by kernel name.
    functions: HashMap<String, CUfunction>,
}
// SAFETY: `CudaModule` exclusively owns its raw handles, and the function cache is only
// mutated through `&mut self` (`get_function`, `has_function`).
// NOTE(review): this assumes the CUDA driver allows these handles to be used (and the module
// unloaded) from any thread on which the owning context is current — confirm against how
// callers manage contexts.
unsafe impl Send for CudaModule {}
// SAFETY: the `&self` methods (`raw`, `cached_functions`) only read plain data.
// NOTE(review): the same thread/context assumption as the `Send` impl above applies.
unsafe impl Sync for CudaModule {}
impl CudaModule {
    /// Size in bytes of each JIT log buffer (info and error) handed to `cuModuleLoadDataEx`.
    const JIT_LOG_BUFFER_BYTES: usize = 4096;

    /// Maximum number of characters of the sanitized kernel name used in a PTX dump file
    /// name. This keeps the whole file name well below the common 255-byte file-name limit.
    const DUMP_NAME_MAX_CHARS: usize = 128;

    /// Loads (JIT-compiles) `ptx` into a new module in `ctx`.
    ///
    /// `ctx` is made current first. The first attempt calls `cuModuleLoadDataEx` with an
    /// explicit JIT target derived from the context's compute capability and captures the
    /// JIT info/error logs. If that attempt fails:
    ///
    /// - diagnostics are printed to stderr;
    /// - the PTX is written to `failed-ptx-sm_<XY>-<kernel>.ptx` in [`std::env::temp_dir`];
    /// - a second attempt is made with plain `cuModuleLoadData`, which passes no JIT options.
    ///
    /// # Errors
    ///
    /// - Any error returned by `get_driver`, `CudaContext::make_current` or
    ///   `CudaContext::compute_capability`.
    /// - [`GpuError::ModuleLoad`] if `ptx` contains an interior NUL byte, or if both load
    ///   attempts fail. In the latter case the message carries both driver result codes and,
    ///   when the driver produced one, the JIT error log.
    pub fn from_ptx(ctx: &CudaContext, ptx: &str) -> Result<Self, GpuError> {
        let driver = get_driver()?;
        ctx.make_current()?;
        let (major, minor) = ctx.compute_capability()?;
        // CUjit_target encodes sm_XY as the integer XY (e.g. sm_86 -> 86).
        let jit_target: c_uint = (major * 10 + minor) as c_uint;
        let ptx_cstring = CString::new(ptx)
            .map_err(|_| GpuError::ModuleLoad("PTX contains null bytes".to_string()))?;

        let mut info_log = vec![0u8; Self::JIT_LOG_BUFFER_BYTES];
        let mut error_log = vec![0u8; Self::JIT_LOG_BUFFER_BYTES];
        let info_log_size: usize = info_log.len();
        let error_log_size: usize = error_log.len();
        let mut options: [c_uint; 5] = [
            CU_JIT_TARGET,
            CU_JIT_INFO_LOG_BUFFER,
            CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
            CU_JIT_ERROR_LOG_BUFFER,
            CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
        ];
        // Entry i is the value for `options[i]`. Every JIT option value travels through a
        // `void*`: pointer options (log buffers) are passed as-is, and scalar options (target,
        // buffer sizes) are passed as the integer cast to a pointer.
        let mut option_values: [*mut c_void; 5] = [
            jit_target as *mut c_void,
            info_log.as_mut_ptr() as *mut c_void,
            info_log_size as *mut c_void,
            error_log.as_mut_ptr() as *mut c_void,
            error_log_size as *mut c_void,
        ];
        let mut module: CUmodule = ptr::null_mut();
        // SAFETY: `options` and `option_values` both hold `options.len()` entries. Each log
        // buffer pointer is paired with that buffer's real length. The buffers and
        // `ptx_cstring` (a NUL-terminated PTX image) outlive the call.
        let result = unsafe {
            (driver.cuModuleLoadDataEx)(
                &mut module,
                ptx_cstring.as_ptr() as *const _,
                options.len() as _,
                options.as_mut_ptr(),
                option_values.as_mut_ptr(),
            )
        };
        if CudaDriver::check(result).is_ok() {
            return Ok(Self { module, functions: HashMap::new() });
        }

        let kernel_name = ptx
            .lines()
            .find(|l| l.contains(".entry"))
            .map(str::trim)
            .unwrap_or("<unknown>");
        let jit_info = Self::jit_log_text(&info_log);
        let jit_err = Self::jit_log_text(&error_log);
        eprintln!(
            "[PTX-JIT] Try 1 failed: {kernel_name}, target: sm_{major}{minor}, \
             PTX: {} bytes, result: {result}",
            ptx.len()
        );
        if !jit_info.is_empty() {
            eprintln!("[PTX-JIT] Info log: {jit_info}");
        }
        if !jit_err.is_empty() {
            eprintln!("[PTX-JIT] Error log: {jit_err}");
        }

        // Best-effort dump for offline inspection. The platform temp dir honours $TMPDIR and
        // works outside Unix. The name fragment is bounded so a long `.entry` line cannot
        // exceed the file-name length limit.
        let dump_path = std::env::temp_dir().join(format!(
            "failed-ptx-sm_{major}{minor}-{}.ptx",
            Self::dump_file_fragment(kernel_name)
        ));
        if std::fs::write(&dump_path, ptx).is_ok() {
            eprintln!("[PTX-JIT] PTX dumped to {}", dump_path.display());
        }

        eprintln!("[PTX-JIT] Retrying with cuModuleLoadData (no explicit target)...");
        let mut module2: CUmodule = ptr::null_mut();
        // SAFETY: `ptx_cstring` is a NUL-terminated PTX image that outlives the call.
        let result2 =
            unsafe { (driver.cuModuleLoadData)(&mut module2, ptx_cstring.as_ptr() as *const _) };
        if CudaDriver::check(result2).is_ok() {
            eprintln!("[PTX-JIT] Fallback succeeded for {kernel_name}");
            return Ok(Self { module: module2, functions: HashMap::new() });
        }

        eprintln!("[PTX-JIT] Both attempts failed for {kernel_name}");
        let mut message = format!(
            "CUDA module loading failed: try1={result} try2={result2} (JIT target: sm_{major}{minor})"
        );
        // Callers may not capture stderr, so keep the compiler's own diagnosis in the error.
        if !jit_err.is_empty() {
            message.push_str("; JIT error log: ");
            message.push_str(&jit_err);
        }
        Err(GpuError::ModuleLoad(message))
    }

    /// Decodes a NUL-terminated JIT log buffer.
    ///
    /// Decoding stops at the first NUL byte (C-string semantics). Invalid UTF-8 is replaced
    /// with U+FFFD.
    fn jit_log_text(buf: &[u8]) -> String {
        let end = buf.iter().position(|&b| b == 0).unwrap_or(buf.len());
        String::from_utf8_lossy(&buf[..end]).into_owned()
    }

    /// Turns `name` into a file-name-safe fragment.
    ///
    /// Every character other than an ASCII letter, digit or `_` becomes `_`. The result is
    /// capped at `DUMP_NAME_MAX_CHARS` characters (all ASCII, so also that many bytes).
    fn dump_file_fragment(name: &str) -> String {
        name.chars()
            .take(Self::DUMP_NAME_MAX_CHARS)
            .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
            .collect()
    }

    /// Returns the handle of kernel `name` in this module.
    ///
    /// The handle is looked up with `cuModuleGetFunction` on first use and served from the
    /// cache afterwards.
    ///
    /// # Errors
    ///
    /// - Any error from `get_driver` is propagated unchanged.
    /// - [`GpuError::FunctionNotFound`] if `name` contains a NUL byte or the driver lookup
    ///   fails.
    pub fn get_function(&mut self, name: &str) -> Result<CUfunction, GpuError> {
        if let Some(&func) = self.functions.get(name) {
            return Ok(func);
        }
        let driver = get_driver()?;
        let name_cstring =
            CString::new(name).map_err(|_| GpuError::FunctionNotFound(name.to_string()))?;
        let mut func: CUfunction = ptr::null_mut();
        // SAFETY: `name_cstring` is NUL-terminated and outlives the call. `self.module` is the
        // module handle owned by `self`.
        let result =
            unsafe { (driver.cuModuleGetFunction)(&mut func, self.module, name_cstring.as_ptr()) };
        CudaDriver::check(result).map_err(|_| GpuError::FunctionNotFound(name.to_string()))?;
        self.functions.insert(name.to_string(), func);
        Ok(func)
    }

    /// Returns the raw `CUmodule` handle.
    ///
    /// The handle remains owned by `self` and is unloaded when `self` is dropped.
    #[must_use]
    pub fn raw(&self) -> CUmodule {
        self.module
    }

    /// Reports whether kernel `name` can be resolved in this module.
    ///
    /// A successful lookup is cached exactly as in [`CudaModule::get_function`]. Any error,
    /// including failure to obtain the driver, reads as `false`.
    pub fn has_function(&mut self, name: &str) -> bool {
        self.get_function(name).is_ok()
    }

    /// Names of the kernels resolved (and cached) so far, in unspecified order.
    #[must_use]
    pub fn cached_functions(&self) -> Vec<&str> {
        self.functions.keys().map(String::as_str).collect()
    }
}
impl Drop for CudaModule {
    /// Best-effort release of the underlying module via `cuModuleUnload`.
    ///
    /// If the driver cannot be obtained, nothing is done. The status returned by the unload
    /// call is discarded because `drop` has no way to report failure.
    fn drop(&mut self) {
        let handle = self.module;
        if let Ok(driver) = get_driver() {
            // SAFETY: `handle` is the module owned by this value, and `Drop` runs at most once,
            // so it is not unloaded twice from here.
            // NOTE(review): assumes no caller unloaded the handle obtained from `raw()`.
            // NOTE(review): `cuModuleUnload` acts on the current context, and nothing here
            // makes the owning context current — confirm callers drop on a suitable thread.
            let _ = unsafe { (driver.cuModuleUnload)(handle) };
        }
    }
}
#[cfg(test)]
mod tests {
    /// Placeholder that keeps the test module non-empty when the `cuda` feature is off.
    ///
    /// Loading a module needs a CUDA driver and a context (`get_driver`, `CudaContext`), so
    /// there is nothing to exercise without them. The body is intentionally empty; the previous
    /// `assert!(true)` asserted nothing and tripped `clippy::assertions_on_constants`.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_module_requires_cuda_feature() {}
}