use super::{DriverError, sys};
use std::ffi::{CString, c_void};
/// Converts a raw HIP status code into a `Result`.
///
/// `hipSuccess` maps to `Ok(())`; every other status is wrapped in [`DriverError::Hip`].
#[inline]
fn check(status: sys::hipError_t) -> Result<(), DriverError> {
    if status != sys::hipError_t::hipSuccess {
        return Err(DriverError::Hip(status));
    }
    Ok(())
}
/// Selects device `ordinal` with `hipSetDevice`.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipSetDevice`.
pub fn set_device(ordinal: i32) -> Result<(), DriverError> {
    // SAFETY: only a plain integer crosses the FFI boundary; no pointers are involved.
    let status = unsafe { sys::hipSetDevice(ordinal) };
    check(status)
}
/// Returns the name string HIP reports for device `ordinal`.
///
/// The name is read into a fixed 256-byte buffer and decoded lossily as UTF-8. Invalid byte
/// sequences become U+FFFD. If HIP leaves no NUL terminator, the whole buffer is used.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] if `hipDeviceGetName` fails.
pub fn device_name(ordinal: i32) -> Result<String, DriverError> {
    // A `u8` buffer plus `.cast()` stays portable. `c_char` is `i8` on some targets and `u8`
    // on others (e.g. aarch64 Linux), so a hard-coded `i8` buffer fails to type-check there.
    let mut buf = [0u8; 256];
    let len = buf.len() as i32; // 256 always fits in i32
    // SAFETY: `buf` is writable for `len` bytes, and that exact length is passed to HIP.
    unsafe { check(sys::hipDeviceGetName(buf.as_mut_ptr().cast(), len, ordinal))? };
    // Parse strictly within `buf`. `CStr::from_ptr` would scan past the end if HIP filled
    // every byte without writing a terminating NUL.
    let name = match std::ffi::CStr::from_bytes_until_nul(&buf) {
        Ok(cstr) => cstr.to_string_lossy().into_owned(),
        Err(_) => String::from_utf8_lossy(&buf).into_owned(),
    };
    Ok(name)
}
/// Creates a new stream with `hipStreamCreate` and returns its handle.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipStreamCreate`.
pub fn stream_create() -> Result<sys::hipStream_t, DriverError> {
    let mut stream: sys::hipStream_t = std::ptr::null_mut();
    // SAFETY: `stream` is a live local that HIP writes the new handle into.
    let status = unsafe { sys::hipStreamCreate(&mut stream) };
    check(status).map(|()| stream)
}
/// Destroys stream `s` with `hipStreamDestroy`.
///
/// NOTE(review): this is a safe fn that accepts a raw handle. The caller must not use `s`
/// again afterwards, and must not destroy it twice.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipStreamDestroy`.
pub fn stream_destroy(s: sys::hipStream_t) -> Result<(), DriverError> {
    // SAFETY: the handle is forwarded unchanged; no Rust memory is shared with HIP.
    let status = unsafe { sys::hipStreamDestroy(s) };
    check(status)
}
/// Blocks until the work queued on stream `s` has finished (`hipStreamSynchronize`).
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipStreamSynchronize`.
pub fn stream_synchronize(s: sys::hipStream_t) -> Result<(), DriverError> {
    // SAFETY: the handle is forwarded unchanged; no Rust memory is shared with HIP.
    let status = unsafe { sys::hipStreamSynchronize(s) };
    check(status)
}
/// Allocates `n_bytes` of device memory with `hipMalloc`.
///
/// The device address is returned as a `u64`.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipMalloc`.
pub fn malloc(n_bytes: usize) -> Result<u64, DriverError> {
    let mut dev_ptr = std::ptr::null_mut::<c_void>();
    // SAFETY: `dev_ptr` is a live local out-param; HIP writes the allocation address into it.
    let status = unsafe { sys::hipMalloc(&mut dev_ptr, n_bytes) };
    check(status).map(|()| dev_ptr as u64)
}
/// Allocates `n_bytes` of device memory with `hipMallocAsync`, ordered on stream `s`.
///
/// The device address is returned as a `u64`.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipMallocAsync`.
pub fn malloc_async(n_bytes: usize, s: sys::hipStream_t) -> Result<u64, DriverError> {
    let mut dev_ptr = std::ptr::null_mut::<c_void>();
    // SAFETY: `dev_ptr` is a live local out-param; the stream handle is forwarded unchanged.
    let status = unsafe { sys::hipMallocAsync(&mut dev_ptr, n_bytes, s) };
    check(status).map(|()| dev_ptr as u64)
}
/// Releases device memory at address `ptr` with `hipFree`.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipFree`.
pub fn free(ptr: u64) -> Result<(), DriverError> {
    let raw = ptr as *mut c_void;
    // SAFETY: `raw` is only forwarded to HIP; Rust never dereferences it.
    check(unsafe { sys::hipFree(raw) })
}
/// Releases device memory at address `ptr` with `hipFreeAsync`, ordered on stream `s`.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipFreeAsync`.
pub fn free_async(ptr: u64, s: sys::hipStream_t) -> Result<(), DriverError> {
    let raw = ptr as *mut c_void;
    // SAFETY: `raw` is only forwarded to HIP; Rust never dereferences it.
    check(unsafe { sys::hipFreeAsync(raw, s) })
}
/// Enqueues a host-to-device copy of `src` into device memory at `dst` on `stream`.
///
/// The copy is done via `hipMemcpyHtoDAsync` and is asynchronous. `Ok(())` means it was
/// issued, not that it has finished. Wait on `stream` (e.g. with `stream_synchronize`) before
/// relying on the data at `dst`.
///
/// # Safety
///
/// - `dst` must be a device address (e.g. from `malloc`/`malloc_async`) with at least
///   `src.len()` bytes writable.
/// - `stream` must be a stream handle HIP accepts and that has not been destroyed.
/// - HIP may still read `src` after this function returns, but the borrow ends here. The
///   caller must keep the memory behind `src` alive and unmodified until the work queued on
///   `stream` has completed.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipMemcpyHtoDAsync`.
pub unsafe fn memcpy_htod_async(
    dst: u64,
    src: &[u8],
    stream: sys::hipStream_t,
) -> Result<(), DriverError> {
    // SAFETY: `dst`, `stream` and the lifetime of `src` are guaranteed by the caller (see
    // `# Safety`). `src.len()` bounds the number of bytes HIP reads from `src`.
    unsafe {
        check(sys::hipMemcpyHtoDAsync(
            dst as sys::hipDeviceptr_t,
            src.as_ptr() as *const c_void,
            src.len(),
            stream,
        ))
    }
}
/// Enqueues a device-to-host copy from device address `src` into `dst` on `stream`.
///
/// The copy is done via `hipMemcpyDtoHAsync` and is asynchronous. `Ok(())` means it was
/// issued, not that `dst` has been filled. Wait on `stream` (e.g. with `stream_synchronize`)
/// before reading `dst`.
///
/// # Safety
///
/// - `src` must be a device address with at least `dst.len()` bytes readable.
/// - `stream` must be a stream handle HIP accepts and that has not been destroyed.
/// - HIP may still write into `dst` after this function returns, but the borrow ends here. The
///   caller must keep the memory behind `dst` alive, and must neither read nor write it from
///   the host, until the work queued on `stream` has completed.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipMemcpyDtoHAsync`.
pub unsafe fn memcpy_dtoh_async(
    dst: &mut [u8],
    src: u64,
    stream: sys::hipStream_t,
) -> Result<(), DriverError> {
    // SAFETY: `src`, `stream` and the lifetime/exclusivity of `dst` are guaranteed by the
    // caller (see `# Safety`). `dst.len()` bounds the number of bytes HIP writes into `dst`.
    unsafe {
        check(sys::hipMemcpyDtoHAsync(
            dst.as_mut_ptr() as *mut c_void,
            src as sys::hipDeviceptr_t,
            dst.len(),
            stream,
        ))
    }
}
pub fn module_load_data(image: &[u8]) -> Result<sys::hipModule_t, DriverError> {
let mut m: sys::hipModule_t = std::ptr::null_mut();
unsafe { check(sys::hipModuleLoadData(&mut m, image.as_ptr() as *const c_void))? };
Ok(m)
}
/// Unloads module `m` with `hipModuleUnload`.
///
/// NOTE(review): this is a safe fn that accepts a raw handle. Function handles obtained from
/// `m` via `module_get_function` should not be used after this call.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipModuleUnload`.
pub fn module_unload(m: sys::hipModule_t) -> Result<(), DriverError> {
    // SAFETY: the handle is forwarded unchanged; no Rust memory is shared with HIP.
    let status = unsafe { sys::hipModuleUnload(m) };
    check(status)
}
/// Looks up the kernel called `name` in module `m` with `hipModuleGetFunction`.
///
/// # Errors
///
/// - [`DriverError::InvalidName`] if `name` contains an interior NUL byte.
/// - [`DriverError::Hip`] with the status reported by `hipModuleGetFunction`.
pub fn module_get_function(
    m: sys::hipModule_t,
    name: &str,
) -> Result<sys::hipFunction_t, DriverError> {
    let Ok(c_name) = CString::new(name) else {
        return Err(DriverError::InvalidName);
    };
    let mut func: sys::hipFunction_t = std::ptr::null_mut();
    // SAFETY: `c_name` is NUL-terminated and outlives the call; `func` is a live out-param.
    let status = unsafe { sys::hipModuleGetFunction(&mut func, m, c_name.as_ptr()) };
    check(status)?;
    Ok(func)
}
/// Launches kernel `f` on `stream` via `hipModuleLaunchKernel`.
///
/// - `grid` and `block` are `(x, y, z)` dimension tuples.
/// - `shared_mem` is forwarded as the dynamic shared-memory byte count.
/// - `params` is passed as the `kernelParams` array, and the `extra` argument is always null.
///
/// The launch is asynchronous. `Ok(())` means the launch was accepted, not that the kernel
/// has finished.
///
/// # Safety
///
/// - `f` must be a function handle obtained from `module_get_function` whose module has not
///   been unloaded.
/// - `stream` must be a stream handle HIP accepts and that has not been destroyed.
/// - `params` must contain exactly one entry per kernel parameter, in declaration order.
///   Each entry must point to a value whose type and size match that parameter. The
///   pointed-to values must stay valid at least until this call returns.
/// - Any device addresses passed as kernel arguments must stay allocated until the kernel has
///   finished executing on `stream`.
///
/// # Errors
///
/// Returns [`DriverError::Hip`] with the status reported by `hipModuleLaunchKernel`.
pub unsafe fn module_launch_kernel(
    f: sys::hipFunction_t,
    grid: (u32, u32, u32),
    block: (u32, u32, u32),
    shared_mem: u32,
    stream: sys::hipStream_t,
    params: &mut [*mut c_void],
) -> Result<(), DriverError> {
    // SAFETY: the validity of `f`, `stream` and the `params` entries is guaranteed by the
    // caller (see `# Safety`). `params` is borrowed for the whole call. The C API takes a
    // mutable `void**`, hence `&mut` and `as_mut_ptr`.
    unsafe {
        check(sys::hipModuleLaunchKernel(
            f,
            grid.0,
            grid.1,
            grid.2,
            block.0,
            block.1,
            block.2,
            shared_mem,
            stream,
            params.as_mut_ptr(),
            std::ptr::null_mut(),
        ))
    }
}