use std::collections::HashMap;
use std::ffi::CString;
use std::ptr;
use super::context::{get_driver, CudaContext};
use super::sys::{CUfunction, CUmodule, CudaDriver};
use crate::GpuError;
/// A module loaded through the CUDA driver, together with a cache of the
/// function handles that have been looked up in it so far.
///
/// The module handle is released with `cuModuleUnload` when the value is
/// dropped.
pub struct CudaModule {
    /// Raw driver handle filled in by `cuModuleLoadData`.
    module: CUmodule,
    /// Function handles already resolved via `cuModuleGetFunction`, keyed by
    /// the name they were requested under.
    functions: HashMap<String, CUfunction>,
}
// SAFETY: NOTE(review) — the raw `CUmodule` / `CUfunction` handles stored in
// `CudaModule` are what prevent these traits from being derived
// automatically. These impls assume the driver does not tie those handles to
// the thread that created them and that callers make the owning context
// current before using them; confirm against the CUDA driver API threading
// rules.
unsafe impl Send for CudaModule {}
// SAFETY: see the note on the `Send` impl directly above; in addition, the
// only method that mutates the function cache takes `&mut self`.
unsafe impl Sync for CudaModule {}
impl CudaModule {
    /// Loads a module from PTX source text using the driver's
    /// `cuModuleLoadData`.
    ///
    /// `_ctx` is not read by this function.
    /// NOTE(review): presumably it is required so that a context exists (and
    /// is current) before the driver call — confirm against callers.
    ///
    /// When the driver rejects the PTX, diagnostics are printed to stderr and
    /// the PTX is dumped (best effort) to `failing_ptx.txt` in the system
    /// temporary directory before the error is returned.
    ///
    /// # Errors
    ///
    /// * Propagates the error returned by `get_driver`.
    /// * Returns `GpuError::ModuleLoad` if `ptx` contains an interior NUL
    ///   byte, or if the driver call does not pass `CudaDriver::check`.
    pub fn from_ptx(_ctx: &CudaContext, ptx: &str) -> Result<Self, GpuError> {
        let driver = get_driver()?;
        let ptx_cstring = CString::new(ptx)
            .map_err(|_| GpuError::ModuleLoad("PTX contains null bytes".to_string()))?;

        let mut module: CUmodule = ptr::null_mut();
        // SAFETY: `module` is a valid out-pointer to a local variable, and
        // `ptx_cstring` is a NUL-terminated buffer that outlives the call.
        let result =
            unsafe { (driver.cuModuleLoadData)(&mut module, ptx_cstring.as_ptr() as *const _) };

        if let Err(e) = CudaDriver::check(result) {
            Self::dump_failing_ptx(ptx);
            return Err(GpuError::ModuleLoad(e.to_string()));
        }

        Ok(Self {
            module,
            functions: HashMap::new(),
        })
    }

    /// Writes debugging information about PTX that failed to load to stderr
    /// and dumps the PTX itself into the system temporary directory.
    ///
    /// Everything here is best effort: a failed dump is reported on stderr
    /// but never turned into an error, because the caller is already on an
    /// error path.
    fn dump_failing_ptx(ptx: &str) {
        // Use the platform's temporary directory rather than a hard-coded
        // `/tmp`, which does not exist on every platform and ignores `TMPDIR`.
        let ptx_path = std::env::temp_dir().join("failing_ptx.txt");
        match std::fs::write(&ptx_path, ptx) {
            Ok(()) => eprintln!("[PTX-DEBUG] Failing PTX dumped to {}", ptx_path.display()),
            Err(err) => eprintln!(
                "[PTX-DEBUG] Could not dump failing PTX to {}: {}",
                ptx_path.display(),
                err
            ),
        }

        // The first line mentioning `.entry` is reported as the kernel; a
        // module with several entry points only has its first one shown.
        let kernel_name = ptx
            .lines()
            .find(|line| line.contains(".entry"))
            .map(str::trim)
            .unwrap_or("<unknown>");
        eprintln!("[PTX-DEBUG] Failed kernel: {}", kernel_name);
        eprintln!("[PTX-DEBUG] PTX length: {} bytes", ptx.len());
    }

    /// Returns the handle of the function called `name` in this module.
    ///
    /// A handle that was resolved before is served from the internal cache;
    /// otherwise it is looked up with `cuModuleGetFunction` and cached under
    /// `name` for later calls.
    ///
    /// # Errors
    ///
    /// * Propagates the error returned by `get_driver`.
    /// * Returns `GpuError::FunctionNotFound(name)` if `name` contains an
    ///   interior NUL byte, or if the driver call does not pass
    ///   `CudaDriver::check` (the underlying driver error is not preserved).
    pub fn get_function(&mut self, name: &str) -> Result<CUfunction, GpuError> {
        if let Some(&func) = self.functions.get(name) {
            return Ok(func);
        }

        let driver = get_driver()?;
        let name_cstring =
            CString::new(name).map_err(|_| GpuError::FunctionNotFound(name.to_string()))?;

        let mut func: CUfunction = ptr::null_mut();
        // SAFETY: `func` is a valid out-pointer to a local variable,
        // `self.module` is the handle obtained in `from_ptx`, and
        // `name_cstring` is a NUL-terminated buffer that outlives the call.
        let result =
            unsafe { (driver.cuModuleGetFunction)(&mut func, self.module, name_cstring.as_ptr()) };
        CudaDriver::check(result).map_err(|_| GpuError::FunctionNotFound(name.to_string()))?;

        self.functions.insert(name.to_string(), func);
        Ok(func)
    }

    /// Returns the raw driver handle of this module.
    ///
    /// The handle is still unloaded when `self` is dropped, so it must not be
    /// used after that point.
    #[must_use]
    pub fn raw(&self) -> CUmodule {
        self.module
    }

    /// Reports whether `get_function(name)` succeeds.
    ///
    /// Takes `&mut self` because a successful lookup is stored in the
    /// function cache. Any failure of `get_function` yields `false`.
    pub fn has_function(&mut self, name: &str) -> bool {
        self.get_function(name).is_ok()
    }

    /// Returns the names of all functions currently held in the cache.
    ///
    /// The names come from a `HashMap`, so their order is unspecified.
    #[must_use]
    pub fn cached_functions(&self) -> Vec<&str> {
        self.functions.keys().map(String::as_str).collect()
    }
}
impl Drop for CudaModule {
    /// Unloads the module through the driver on a best-effort basis.
    ///
    /// Nothing is done when the driver cannot be obtained, and the status
    /// returned by `cuModuleUnload` is discarded because `drop` has no way to
    /// report it.
    fn drop(&mut self) {
        let driver = match get_driver() {
            Ok(driver) => driver,
            Err(_) => return,
        };
        // SAFETY: `self.module` is the handle stored at construction, and
        // `drop` runs at most once for this value.
        let _ = unsafe { (driver.cuModuleUnload)(self.module) };
    }
}
#[cfg(test)]
mod tests {
    /// Placeholder that is compiled only when the `cuda` feature is disabled.
    ///
    /// It makes no assertions; it passes as long as this module builds in
    /// that configuration.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_module_requires_cuda_feature() {
        // Intentionally empty: a constant `assert!(true)` checks nothing and
        // is flagged by Clippy's `assertions_on_constants` lint.
    }
}