use crate::hip;
use crate::hip::error::{Error, Result};
use crate::hip::event::Event;
use crate::hip::ffi;
use std::{panic, ptr};
use super::memory::SynchronizeCopies;
#[derive(Clone, Debug)]
pub struct Stream {
pub(crate) stream: hip::ffi::hipStream_t,
}
impl Stream {
pub(crate) fn new() -> Result<Self> {
let mut stream = ptr::null_mut();
let error = unsafe { ffi::hipStreamCreate(&mut stream) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(Self { stream })
}
pub(crate) fn with_flags(flags: u32) -> Result<Self> {
let mut stream = ptr::null_mut();
let error = unsafe { ffi::hipStreamCreateWithFlags(&mut stream, flags) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(Self { stream })
}
pub(crate) fn with_priority(flags: u32, priority: i32) -> Result<Self> {
let mut stream = ptr::null_mut();
let error = unsafe { ffi::hipStreamCreateWithPriority(&mut stream, flags, priority) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(Self { stream })
}
pub fn synchronize(&self) -> Result<()> {
let error = unsafe { ffi::hipStreamSynchronize(self.stream) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(())
}
pub fn synchronize_memory<T: SynchronizeCopies>(&self, copies: T) -> Result<T::Output> {
Self::synchronize(&self)?;
Ok(unsafe { copies.finalize() })
}
pub fn query(&self) -> Result<()> {
let error = unsafe { ffi::hipStreamQuery(self.stream) };
if error == ffi::hipError_t_hipSuccess {
Ok(())
} else if error == ffi::hipError_t_hipErrorNotReady {
Err(Error::new(error))
} else {
Err(Error::new(error))
}
}
pub fn wait_event(&self, event: &Event, flags: u32) -> Result<()> {
let error = unsafe { ffi::hipStreamWaitEvent(self.stream, event.as_raw(), flags) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(())
}
pub fn add_callback<F>(&self, callback: F) -> Result<()>
where
F: FnOnce() + Send + 'static,
{
type Callback = dyn FnOnce() + Send + 'static;
let boxed: Box<Option<Box<Callback>>> = Box::new(Some(Box::new(callback)));
let ptr = Box::into_raw(boxed) as *mut std::ffi::c_void;
unsafe extern "C" fn helper_callback(
_stream: ffi::hipStream_t,
_status: ffi::hipError_t,
user_data: *mut std::ffi::c_void,
) {
let callback_box = unsafe { Box::from_raw(user_data as *mut Option<Box<Callback>>) };
if let Some(callback) = *callback_box {
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| callback()));
}
}
let error =
unsafe { ffi::hipStreamAddCallback(self.stream, Some(helper_callback), ptr, 0) };
if error != ffi::hipError_t_hipSuccess {
unsafe { drop(Box::from_raw(ptr)) }
return Err(Error::new(error));
}
Ok(())
}
pub fn as_raw(&self) -> ffi::hipStream_t {
self.stream
}
pub fn priority_range() -> Result<(i32, i32)> {
let mut least_priority = 0;
let mut greatest_priority = 0;
let error = unsafe {
ffi::hipDeviceGetStreamPriorityRange(&mut least_priority, &mut greatest_priority)
};
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok((least_priority, greatest_priority))
}
pub fn get_priority(&self) -> Result<i32> {
let mut priority = 0;
let error = unsafe { ffi::hipStreamGetPriority(self.stream, &mut priority) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(priority)
}
pub fn get_flags(&self) -> Result<u32> {
let mut flags = 0;
let error = unsafe { ffi::hipStreamGetFlags(self.stream, &mut flags) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(flags)
}
pub fn get_device(&self) -> Result<i32> {
let mut device = 0;
let error = unsafe { ffi::hipStreamGetDevice(self.stream, &mut device) };
if error != ffi::hipError_t_hipSuccess {
return Err(Error::new(error));
}
Ok(device)
}
pub fn from_raw(stream: ffi::hipStream_t) -> Self {
Self { stream }
}
}
impl Drop for Stream {
    /// Destroys the wrapped stream, if any.
    ///
    /// A null handle is skipped. The status returned by `hipStreamDestroy` is
    /// deliberately ignored because `drop` cannot report an error.
    fn drop(&mut self) {
        let handle = self.stream;
        if handle.is_null() {
            return;
        }
        // SAFETY: `handle` is the non-null handle stored in this wrapper.
        let _ = unsafe { ffi::hipStreamDestroy(handle) };
        // Clear the field so the wrapper no longer holds the destroyed handle.
        self.stream = ptr::null_mut();
    }
}
/// Flag values accepted by the stream constructors that take a `flags` argument.
///
/// NOTE(review): the values presumably mirror `hipStreamDefault` and
/// `hipStreamNonBlocking` from the HIP headers — confirm against the bindings.
pub mod stream_flags {
    /// No flags set.
    pub const DEFAULT: u32 = 0x0;
    /// Lowest bit set; by its name, requests a non-blocking stream.
    pub const NON_BLOCKING: u32 = 1 << 0;
}