use crate::error::{CudaResult, DropResult, ToResult};
use crate::event::Event;
use crate::function::{BlockSize, Function, GridSize};
use cuda_driver_sys::{cudaError_enum, CUstream};
use std::ffi::c_void;
use std::mem;
use std::panic;
use std::ptr;
bitflags! {
    /// Flags accepted by [`Stream::new`] and reported back by [`Stream::get_flags`].
    ///
    /// The raw bits are forwarded unchanged to `cuStreamCreateWithPriority`.
    pub struct StreamFlags: u32 {
        /// No flags set (bit pattern `0`).
        const DEFAULT = 0b0000_0000;
        /// Non-blocking stream. Presumably mirrors the driver's `CU_STREAM_NON_BLOCKING`
        /// (no implicit synchronization with the NULL stream); confirm against the
        /// CUDA Driver API documentation.
        const NON_BLOCKING = 0b0000_0001;
    }
}
bitflags! {
    /// Flags for [`Stream::wait_event`], forwarded as raw bits to `cuStreamWaitEvent`.
    pub struct StreamWaitEventFlags: u32 {
        /// No flags set; the only value defined here.
        const DEFAULT = 0b0;
    }
}
/// An owned CUDA stream handle.
///
/// Created by [`Stream::new`] and released either explicitly through
/// [`Stream::drop`] (which reports errors) or implicitly by the `Drop` impl.
#[derive(Debug)]
pub struct Stream {
// Raw driver handle. A null value means there is nothing left to destroy
// (both destruction paths check for null before calling the driver).
inner: CUstream,
}
impl Stream {
    /// Creates a new CUDA stream with the given `flags` and `priority`.
    ///
    /// A `priority` of `None` is passed to the driver as `0`. The handle is only
    /// wrapped in a `Stream` once the driver reports success, so a failed creation
    /// never leaves a `Stream` behind for `Drop` to try to destroy.
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuStreamCreateWithPriority` fails.
    pub fn new(flags: StreamFlags, priority: Option<i32>) -> CudaResult<Self> {
        let mut inner: CUstream = ptr::null_mut();
        // SAFETY: `inner` is a valid, writable out-pointer for the whole call.
        unsafe {
            cuda_driver_sys::cuStreamCreateWithPriority(
                &mut inner,
                flags.bits(),
                priority.unwrap_or(0),
            )
            .to_result()?;
        }
        Ok(Stream { inner })
    }

    /// Queries this stream's flags from the driver via `cuStreamGetFlags`.
    ///
    /// Bits that [`StreamFlags`] does not define are silently discarded.
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuStreamGetFlags` fails.
    pub fn get_flags(&self) -> CudaResult<StreamFlags> {
        let mut bits = 0u32;
        // SAFETY: `self.inner` is the handle owned by `self`; `bits` is a valid out-pointer.
        unsafe {
            cuda_driver_sys::cuStreamGetFlags(self.inner, &mut bits).to_result()?;
        }
        Ok(StreamFlags::from_bits_truncate(bits))
    }

    /// Queries this stream's priority from the driver via `cuStreamGetPriority`.
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuStreamGetPriority` fails.
    pub fn get_priority(&self) -> CudaResult<i32> {
        let mut priority = 0i32;
        // SAFETY: `self.inner` is the handle owned by `self`; `priority` is a valid out-pointer.
        unsafe {
            cuda_driver_sys::cuStreamGetPriority(self.inner, &mut priority).to_result()?;
        }
        Ok(priority)
    }

    /// Enqueues `callback` to be invoked, with the stream status, once previously
    /// queued work on this stream has been reached.
    ///
    /// Ownership of the closure passes to the driver on success. The closure is
    /// consumed by `callback_wrapper` when it runs, and any panic it raises is
    /// caught and discarded there. If registration fails, the closure is dropped
    /// here before the error is returned, so it is never leaked.
    ///
    /// NOTE(review): `T` is not required to be `'static`, yet the closure is
    /// presumably run asynchronously (possibly on a driver-owned thread) after this
    /// call returns. Callers must ensure anything it borrows outlives its execution.
    /// Consider adding a `'static` bound.
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuStreamAddCallback` fails.
    pub fn add_callback<T>(&self, callback: Box<T>) -> CudaResult<()>
    where
        T: FnOnce(CudaResult<()>) + Send,
    {
        let raw = Box::into_raw(callback);
        // SAFETY: `self.inner` is the handle owned by `self`. `raw` comes from
        // `Box::into_raw` and is consumed exactly once by `callback_wrapper::<T>`
        // if the driver accepts the registration.
        let result = unsafe {
            cuda_driver_sys::cuStreamAddCallback(
                self.inner,
                Some(callback_wrapper::<T>),
                raw as *mut c_void,
                0,
            )
            .to_result()
        };
        if result.is_err() {
            // The callback was not registered, so nothing else will ever free `raw`.
            // SAFETY: `raw` came from `Box::into_raw` above and has not been consumed.
            unsafe { mem::drop(Box::from_raw(raw)) };
        }
        result
    }

    /// Calls `cuStreamSynchronize` on this stream. Per the CUDA Driver API, this
    /// blocks until previously queued work on the stream has completed.
    ///
    /// # Errors
    ///
    /// Returns the driver error reported by `cuStreamSynchronize`.
    pub fn synchronize(&self) -> CudaResult<()> {
        // SAFETY: `self.inner` is the handle owned by `self`.
        unsafe { cuda_driver_sys::cuStreamSynchronize(self.inner).to_result() }
    }

    /// Makes this stream wait on `event` via `cuStreamWaitEvent`.
    ///
    /// `event` is taken by value and dropped when this method returns.
    ///
    /// # Errors
    ///
    /// Returns the driver error reported by `cuStreamWaitEvent`.
    pub fn wait_event(&self, event: Event, flags: StreamWaitEventFlags) -> CudaResult<()> {
        // SAFETY: `self.inner` is owned by `self`; `event` is alive for the whole call.
        unsafe {
            cuda_driver_sys::cuStreamWaitEvent(self.inner, event.as_inner(), flags.bits())
                .to_result()
        }
    }

    /// Launches `func` on this stream with the given grid and block dimensions and
    /// `shared_mem_bytes` of dynamic shared memory, via `cuLaunchKernel`.
    ///
    /// # Safety
    ///
    /// `args` must hold exactly one pointer per kernel parameter, in order. Each
    /// pointer must point to a live value whose type and layout match that
    /// parameter. Any memory the kernel accesses must remain valid until the kernel
    /// has finished executing on this stream.
    ///
    /// # Errors
    ///
    /// Returns the driver error reported by `cuLaunchKernel`.
    #[doc(hidden)]
    pub unsafe fn launch<G, B>(
        &self,
        func: &Function,
        grid_size: G,
        block_size: B,
        shared_mem_bytes: u32,
        args: &[*mut c_void],
    ) -> CudaResult<()>
    where
        G: Into<GridSize>,
        B: Into<BlockSize>,
    {
        let grid_size: GridSize = grid_size.into();
        let block_size: BlockSize = block_size.into();
        cuda_driver_sys::cuLaunchKernel(
            func.to_inner(),
            grid_size.x,
            grid_size.y,
            grid_size.z,
            block_size.x,
            block_size.y,
            block_size.z,
            shared_mem_bytes,
            self.inner,
            // The driver's signature takes `*mut`; it reads the pointer array.
            args.as_ptr() as *mut _,
            // No `extra` launch options.
            ptr::null_mut(),
        )
        .to_result()
    }

    /// Returns the raw driver handle without transferring ownership.
    pub(crate) fn as_inner(&self) -> CUstream {
        self.inner
    }

    /// Destroys `stream`, reporting any driver error instead of panicking.
    ///
    /// # Errors
    ///
    /// If `cuStreamDestroy_v2` fails, returns the error together with a `Stream`
    /// that still owns the handle, so the caller can retry or drop it.
    pub fn drop(mut stream: Stream) -> DropResult<Stream> {
        // Take the handle out and leave null behind. The `Drop` impl then does
        // nothing when `stream` goes out of scope at the end of this function.
        let inner = mem::replace(&mut stream.inner, ptr::null_mut());
        if inner.is_null() {
            return Ok(());
        }
        // SAFETY: `inner` was owned by `stream` and is destroyed at most once here.
        let status = unsafe { cuda_driver_sys::cuStreamDestroy_v2(inner) };
        match status.to_result() {
            Ok(()) => Ok(()),
            // Destruction failed: hand ownership of the handle back to the caller.
            Err(e) => Err((e, Stream { inner })),
        }
    }
}
impl Drop for Stream {
fn drop(&mut self) {
if self.inner.is_null() {
return;
}
unsafe {
let inner = mem::replace(&mut self.inner, ptr::null_mut());
cuda_driver_sys::cuStreamDestroy_v2(inner)
.to_result()
.expect("Failed to destroy CUDA stream.");
}
}
}
/// C-ABI trampoline registered with `cuStreamAddCallback` by [`Stream::add_callback`].
///
/// Rebuilds the `Box<T>` behind `callback`, converts `status` into a
/// `CudaResult<()>`, and invokes the closure with it, consuming the box.
///
/// # Safety
///
/// `callback` must be a pointer obtained from `Box::<T>::into_raw` for this same
/// `T`, and this function must be called at most once for that pointer.
unsafe extern "C" fn callback_wrapper<T>(
    _stream: CUstream,
    status: cudaError_enum,
    callback: *mut c_void,
) where
    T: FnOnce(CudaResult<()>) + Send,
{
    // A panic must not unwind back into the driver across the FFI boundary, so
    // any panic raised by the user's closure is caught here and discarded.
    let outcome = panic::catch_unwind(move || {
        // SAFETY: per this function's contract, `callback` owns a leaked `Box<T>`.
        let user_fn: Box<T> = Box::from_raw(callback as *mut T);
        let converted: CudaResult<()> = status.to_result();
        user_fn(converted);
    });
    drop(outcome);
}