use std::{marker::PhantomData, sync::Arc};
pub mod result;
pub mod sys;
/// Errors returned by the HIP wrappers in this module.
#[derive(Debug, thiserror::Error)]
pub enum DriverError {
/// A HIP status code; the message shows its `Debug` representation.
#[error("HIP error: {0:?}")]
Hip(sys::hipError_t),
/// A kernel name contained a null byte.
// NOTE(review): presumably raised by `result` when turning the name into a C string — confirm.
#[error("kernel name contained a null byte")]
InvalidName,
}
/// A HIP device selected by ordinal, together with a default stream and an architecture string.
pub struct HipContext {
/// Device ordinal that is passed to `result::set_device` when the context is created.
ordinal: i32,
/// Stream created together with the context; shared with every slice allocated through it.
default_stream: Arc<HipStream>,
/// Caller-supplied architecture string, stored verbatim.
// NOTE(review): presumably a gfx target name for the device — confirm with callers.
gfx_arch: String,
}
impl HipContext {
    /// Selects device `ordinal`, creates a default stream for it, and returns the shared context.
    ///
    /// `gfx_arch` is converted to a `String` and stored as given; it is handed back by
    /// [`HipContext::gfx_arch`].
    ///
    /// # Errors
    ///
    /// Propagates any failure from selecting the device or creating the stream.
    pub fn new(ordinal: i32, gfx_arch: impl Into<String>) -> Result<Arc<Self>, DriverError> {
        result::set_device(ordinal)?;
        let default_stream = Arc::new(HipStream::new_on_device(ordinal)?);
        let context = Self {
            ordinal,
            default_stream,
            gfx_arch: gfx_arch.into(),
        };
        Ok(Arc::new(context))
    }

    /// Returns the device ordinal this context was created with.
    pub fn ordinal(&self) -> i32 {
        self.ordinal
    }

    /// Returns a new shared handle to the context's default stream.
    pub fn default_stream(&self) -> Arc<HipStream> {
        Arc::clone(&self.default_stream)
    }

    /// Returns the architecture string supplied to [`HipContext::new`].
    pub fn gfx_arch(&self) -> &str {
        self.gfx_arch.as_str()
    }

    /// Looks up the device name for this context's ordinal via `result::device_name`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the lookup reports.
    pub fn name(&self) -> Result<String, DriverError> {
        result::device_name(self.ordinal)
    }

    /// Allocates a device buffer of `len` elements of `T`, associated with the default stream.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying allocation.
    pub fn alloc<T>(self: &Arc<Self>, len: usize) -> Result<HipSlice<T>, DriverError> {
        let stream = Arc::clone(&self.default_stream);
        HipSlice::alloc(stream, len)
    }
}
/// Owns a raw HIP stream handle; the handle is destroyed when this value is dropped.
pub struct HipStream {
/// Raw handle returned by `result::stream_create`.
raw: sys::hipStream_t,
/// Ordinal passed to `result::set_device` immediately before the stream was created.
#[allow(dead_code)]
device_ordinal: i32,
}
impl HipStream {
    /// Selects device `ordinal` and then creates a new stream, remembering the ordinal.
    ///
    /// # Errors
    ///
    /// Propagates a failure from either the device selection or the stream creation.
    fn new_on_device(ordinal: i32) -> Result<Self, DriverError> {
        result::set_device(ordinal)?;
        let stream = Self {
            raw: result::stream_create()?,
            device_ordinal: ordinal,
        };
        Ok(stream)
    }

    /// Synchronizes this stream by delegating to `result::stream_synchronize`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the synchronization reports.
    pub fn synchronize(&self) -> Result<(), DriverError> {
        result::stream_synchronize(self.hip_stream())
    }

    /// Returns the raw stream handle; ownership stays with this `HipStream`.
    pub fn hip_stream(&self) -> sys::hipStream_t {
        self.raw
    }
}
impl Drop for HipStream {
    /// Destroys the underlying stream handle.
    ///
    /// Any error is deliberately discarded: `drop` has no way to report it.
    fn drop(&mut self) {
        let handle = self.raw;
        let _ = result::stream_destroy(handle);
    }
}
// SAFETY: these impls are a manual promise that the wrapped `sys::hipStream_t` may be moved to
// and used from other threads.
// NOTE(review): presumably relies on the HIP runtime allowing a stream handle to be used from
// any host thread — confirm against the HIP documentation.
unsafe impl Send for HipStream {}
unsafe impl Sync for HipStream {}
/// Owning handle to a device allocation sized for `len` elements of `T`.
///
/// A non-zero `ptr` is released through `result::free` when the slice is dropped.
pub struct HipSlice<T> {
/// Device address as an integer; `0` when the requested size was zero bytes.
ptr: u64,
/// Number of `T` elements (not bytes).
len: usize,
/// Stream the slice was created with; holding the `Arc` keeps that stream alive as long as
/// the slice exists.
#[allow(dead_code)]
stream: Arc<HipStream>,
/// Uses the `T` parameter without storing any `T` on the host.
_marker: PhantomData<*const T>,
}
impl<T> HipSlice<T> {
    /// Allocates room for `len` elements of `T` on the device that `stream` was created on.
    ///
    /// A request of zero bytes (either `len == 0` or a zero-sized `T`) performs no device
    /// allocation and yields a slice whose pointer is `0`.
    ///
    /// Side effect: for a non-empty request the calling thread's current device is switched to
    /// the stream's device.
    ///
    /// # Errors
    ///
    /// Propagates a failure from selecting the device or from the allocation itself.
    ///
    /// # Panics
    ///
    /// Panics with `"size overflow"` if `len * size_of::<T>()` does not fit in `usize`.
    fn alloc(stream: Arc<HipStream>, len: usize) -> Result<Self, DriverError> {
        let bytes = len
            .checked_mul(std::mem::size_of::<T>())
            .expect("size overflow");
        let ptr = if bytes == 0 {
            0
        } else {
            // The runtime allocates on the calling thread's *current* device. The context only
            // selected its device once, on the thread that created it, so re-select the
            // stream's device here; otherwise an allocation made from another thread (or after
            // another context switched devices) would land on the wrong device.
            result::set_device(stream.device_ordinal)?;
            result::malloc(bytes)?
        };
        Ok(Self {
            ptr,
            len,
            stream,
            _marker: PhantomData,
        })
    }

    /// Returns the number of `T` elements in the slice.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the slice holds zero elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the device address as an integer (`0` for a zero-byte slice).
    pub fn device_ptr(&self) -> u64 {
        self.ptr
    }
}
impl<T> Drop for HipSlice<T> {
    /// Releases the device allocation, if there is one.
    ///
    /// A pointer of `0` marks a zero-byte slice for which nothing was allocated. Any error from
    /// the release is deliberately discarded because `drop` cannot report it.
    fn drop(&mut self) {
        if self.ptr == 0 {
            return;
        }
        let _ = result::free(self.ptr);
    }
}
// The `PhantomData<*const T>` field stops `HipSlice<T>` from being `Send`/`Sync` automatically,
// so both are re-enabled by hand, bounded on `T`.
// SAFETY: NOTE(review) — presumably sound because the slice only stores an integer device
// address and a shared stream handle, neither tied to a host thread — confirm.
unsafe impl<T: Send> Send for HipSlice<T> {}
unsafe impl<T: Sync> Sync for HipSlice<T> {}
/// Types that can report a device address as a `u64`.
pub trait DevicePtr {
/// Returns the device address as an integer.
fn device_ptr(&self) -> u64;
}
impl<T> DevicePtr for HipSlice<T> {
    /// Returns the slice's stored device address (`0` for a zero-byte slice).
    fn device_ptr(&self) -> u64 {
        let Self { ptr, .. } = self;
        *ptr
    }
}
/// Owns a loaded HIP module handle; the module is unloaded when this value is dropped.
pub struct HipModule {
/// Raw handle returned by `result::module_load_data`.
raw: sys::hipModule_t,
}
impl HipModule {
    /// Loads a module from an in-memory image and returns a shared handle to it.
    ///
    /// `bytes` is handed to `result::module_load_data` unchanged (per the method name, it is
    /// expected to be an HSACO code object).
    ///
    /// # Errors
    ///
    /// Returns the error reported by the load.
    pub fn from_hsaco(bytes: &[u8]) -> Result<Arc<Self>, DriverError> {
        let module = Self {
            raw: result::module_load_data(bytes)?,
        };
        Ok(Arc::new(module))
    }

    /// Looks up the function called `name` in this module.
    ///
    /// The returned [`HipFunction`] holds a clone of this `Arc`, so the module stays loaded for
    /// as long as the function handle is alive.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the lookup.
    pub fn get_function(self: &Arc<Self>, name: &str) -> Result<HipFunction, DriverError> {
        let function = HipFunction {
            raw: result::module_get_function(self.raw, name)?,
            _module: Arc::clone(self),
        };
        Ok(function)
    }
}
impl Drop for HipModule {
    /// Unloads the module.
    ///
    /// Any error is deliberately discarded: `drop` has no way to report it.
    fn drop(&mut self) {
        let handle = self.raw;
        let _ = result::module_unload(handle);
    }
}
// SAFETY: these impls are a manual promise that the wrapped `sys::hipModule_t` may be moved to
// and used from other threads.
// NOTE(review): presumably relies on the HIP runtime allowing a module handle to be used from
// any host thread — confirm against the HIP documentation.
unsafe impl Send for HipModule {}
unsafe impl Sync for HipModule {}
/// Handle to a function looked up in a `HipModule`.
pub struct HipFunction {
/// Raw handle returned by `result::module_get_function`.
raw: sys::hipFunction_t,
/// Keeps the originating module alive (and therefore not yet unloaded) while this handle exists.
_module: Arc<HipModule>,
}
impl HipFunction {
    /// Launches this function on `stream` through `result::module_launch_kernel`.
    ///
    /// `grid`, `block`, `shared_mem_bytes`, and `params` are forwarded unchanged.
    ///
    /// # Safety
    ///
    /// Nothing about `params` is validated here. The caller must uphold whatever
    /// `result::module_launch_kernel` requires. Presumably every entry must point to a live
    /// argument value, and the entries must match the kernel's parameter list in count, order,
    /// and type — confirm against that function's documentation.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the launch call.
    pub unsafe fn launch(
        &self,
        grid: (u32, u32, u32),
        block: (u32, u32, u32),
        shared_mem_bytes: u32,
        stream: &HipStream,
        params: &mut [*mut std::ffi::c_void],
    ) -> Result<(), DriverError> {
        let raw_stream = stream.hip_stream();
        // SAFETY: the caller has accepted this function's `# Safety` contract, which is passed
        // straight through to the underlying launch call.
        unsafe {
            result::module_launch_kernel(
                self.raw,
                grid,
                block,
                shared_mem_bytes,
                raw_stream,
                params,
            )
        }
    }
}
// SAFETY: these impls are a manual promise that the wrapped `sys::hipFunction_t` may be moved
// to and used from other threads; the `Arc<HipModule>` field keeps the owning module alive.
// NOTE(review): presumably relies on the HIP runtime allowing a function handle to be used from
// any host thread — confirm against the HIP documentation.
unsafe impl Send for HipFunction {}
unsafe impl Sync for HipFunction {}