use crate::hip::{Stream, bindings::hipStream_t};
use crate::miopen::error::{Error, Result};
use crate::miopen::ffi;
use std::ptr;
/// Owning wrapper around a raw MIOpen handle (`miopenHandle_t`).
///
/// The handle is created by [`Handle::new`] or [`Handle::with_stream`] and is
/// released with `miopenDestroy` when the wrapper is dropped.
#[derive(Debug)]
pub struct Handle {
    // Raw MIOpen handle; non-null for every successfully constructed `Handle`.
    handle: ffi::miopenHandle_t,
}
impl Handle {
pub fn new() -> Result<Self> {
let mut handle = ptr::null_mut();
let status = unsafe { ffi::miopenCreate(&mut handle) };
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(Self { handle })
}
pub fn with_stream(stream: &Stream) -> Result<Self> {
let mut handle = ptr::null_mut();
let status = unsafe {
ffi::miopenCreateWithStream(
&mut handle,
stream.as_raw() as crate::miopen::bindings::miopenAcceleratorQueue_t,
)
};
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(Self { handle })
}
pub fn set_stream(&self, stream: &Stream) -> Result<()> {
let status = unsafe {
ffi::miopenSetStream(
self.handle,
stream.as_raw() as crate::miopen::bindings::miopenAcceleratorQueue_t,
)
};
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(())
}
pub fn get_stream(&self) -> Result<Stream> {
let mut stream_id = ptr::null_mut();
let status = unsafe { ffi::miopenGetStream(self.handle, &mut stream_id) };
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(Stream::from_raw(stream_id as hipStream_t))
}
pub fn enable_profiling(&self, enable: bool) -> Result<()> {
let status = unsafe { ffi::miopenEnableProfiling(self.handle, enable) };
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(())
}
pub fn get_kernel_time(&self) -> Result<f32> {
let mut time = 0.0;
let status = unsafe { ffi::miopenGetKernelTime(self.handle, &mut time) };
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(time)
}
pub unsafe fn set_allocator(
&self,
allocator: ffi::miopenAllocatorFunction,
deallocator: ffi::miopenDeallocatorFunction,
context: *mut ::std::os::raw::c_void,
) -> Result<()> {
let status =
unsafe { ffi::miopenSetAllocator(self.handle, allocator, deallocator, context) };
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(())
}
pub fn as_raw(&self) -> ffi::miopenHandle_t {
self.handle
}
}
impl Drop for Handle {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe {
let _ = ffi::miopenDestroy(self.handle);
};
self.handle = ptr::null_mut();
}
}
}