use crate::hip::error::{Error, Result};
use crate::hip::ffi;
use crate::hip::kernel::Function;
use std::ffi::{CString, c_void};
use std::fs;
use std::path::Path;
use std::ptr;
/// Owned handle to a HIP module (`hipModule_t`) created by one of the `Module::load*`
/// constructors.
///
/// The handle is released with `hipModuleUnload` when this value is dropped.
pub struct Module {
    // Raw runtime handle; the `Drop` impl unloads it and then resets it to null.
    module: ffi::hipModule_t,
}
impl Module {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let path_str = path.as_ref().to_string_lossy();
let path_cstr = CString::new(path_str.as_bytes()).unwrap();
let mut module = ptr::null_mut();
let error = unsafe { ffi::hipModuleLoad(&mut module, path_cstr.as_ptr()) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(Self { module })
}
pub fn load_data(data: impl AsRef<[u8]>) -> Result<Self> {
let mut module = ptr::null_mut();
let error =
unsafe { ffi::hipModuleLoadData(&mut module, data.as_ref().as_ptr() as *const c_void) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(Self { module })
}
pub unsafe fn load_with_options(
data: impl AsRef<[u8]>,
num_options: u32,
options: *mut ffi::hipJitOption,
option_values: *mut *mut c_void,
) -> Result<Self> {
let mut module = ptr::null_mut();
let error = unsafe {
ffi::hipModuleLoadDataEx(
&mut module,
data.as_ref().as_ptr() as *const c_void,
num_options,
options,
option_values,
)
};
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(Self { module })
}
pub fn get_function(&self, name: &str) -> Result<Function> {
unsafe { Function::new(self.module, name) }
}
pub fn get_global<T>(&self, name: &str) -> Result<*mut T> {
let name_cstr = CString::new(name).unwrap();
let mut dev_ptr = ptr::null_mut();
let mut size = 0usize;
let error = unsafe {
ffi::hipModuleGetGlobal(&mut dev_ptr, &mut size, self.module, name_cstr.as_ptr())
};
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
if size < std::mem::size_of::<T>() {
return Err(Error::new(ffi::hipError_t_hipErrorInvalidValue));
}
Ok(dev_ptr as *mut T)
}
pub fn as_raw(&self) -> ffi::hipModule_t {
self.module
}
}
impl Drop for Module {
fn drop(&mut self) {
if !self.module.is_null() {
unsafe {
let _ = ffi::hipModuleUnload(self.module);
}
self.module = ptr::null_mut();
}
}
}
/// Convenience wrapper around [`Module::load`]: loads a code object from the file at `path`.
pub fn load_module<P: AsRef<Path>>(path: P) -> Result<Module> {
    let file: &Path = path.as_ref();
    Module::load(file)
}
/// Loads a module from an in-memory image supplied as text.
///
/// `hipModuleLoadData` receives only a pointer to the image, never its length. Textual
/// images (e.g. PTX on the CUDA backend) must therefore be NUL-terminated, or the runtime
/// reads past the end of the buffer. The text is copied and a trailing NUL byte is appended,
/// unless one is already present, before the bytes are passed to [`Module::load_data`].
pub fn load_module_data(data: &str) -> Result<Module> {
    let mut image = Vec::with_capacity(data.len() + 1);
    image.extend_from_slice(data.as_bytes());
    if image.last() != Some(&0) {
        image.push(0);
    }
    Module::load_data(image)
}
pub fn compile_and_load(source: &str, options: &[String]) -> Result<Module> {
use std::env::temp_dir;
use std::process::Command;
let temp_src_path = temp_dir().join("temp_kernel.cpp");
let temp_bin_path = temp_dir().join("temp_kernel.hsaco");
fs::write(&temp_src_path, source)
.map_err(|_| Error::new(ffi::hipError_t_hipErrorInvalidValue))?;
let mut cmd = Command::new("hipcc");
cmd.arg("--genco");
for opt in options {
cmd.arg(opt);
}
cmd.arg("-o").arg(&temp_bin_path).arg(&temp_src_path);
let status = cmd
.status()
.map_err(|_| Error::new(ffi::hipError_t_hipErrorInvalidValue))?;
if !status.success() {
return Err(Error::new(ffi::hipError_t_hipErrorInvalidValue));
}
Module::load(temp_bin_path)
}