use cpp::cpp;
use crate::device::DeviceId;
use crate::ffi::device::Device;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
/// Module-local shorthand: a `std::result::Result` whose error type is the
/// crate-wide `crate::error::Error`.
type Result<T> = std::result::Result<T, crate::error::Error>;
/// Handle to a CUDA stream together with the device it belongs to.
///
/// Values are created with `Stream::new` (a fresh stream obtained from
/// `cudaStreamCreate`) or `Stream::null` (a null handle). Dropping the value
/// calls `destroy`, which synchronizes and destroys any non-null stream.
pub struct Stream {
    /// Raw stream handle. It is cast to `cudaStream_t` on the C++ side and is
    /// null for values built by `Stream::null`.
    internal: DevicePtr,
    /// Device that was current when this value was constructed. Methods make
    /// it current again before issuing CUDA calls on the stream.
    device: DeviceId,
}
// SAFETY: NOTE(review) — `Stream` only stores a raw stream handle and a device
// id, and every method re-selects `self.device` before touching the handle.
// Soundness of moving the value to another thread presumably rests on the CUDA
// runtime accepting stream handles from any host thread — confirm.
unsafe impl Send for Stream {}
// SAFETY: NOTE(review) — the `&self` methods (`synchronize`, `add_callback`)
// only pass the handle to CUDA runtime calls and mutate no Rust-side state.
// Sharing across threads presumably rests on those runtime calls being
// thread-safe — confirm.
unsafe impl Sync for Stream {}
impl Stream {
    /// Builds a `Stream` around a null handle, tagged with the current device.
    ///
    /// No CUDA stream is created. Because the handle is null, `destroy` (and
    /// therefore `Drop`) returns early and never hands it to
    /// `cudaStreamDestroy`.
    ///
    /// # Panics
    ///
    /// The device is obtained through `Device::get_or_panic`, which, going by
    /// its name, panics when the current device cannot be determined.
    pub fn null() -> Self {
        Self {
            // SAFETY: NOTE(review) — the exact contract of `DevicePtr::null` is
            // not visible in this file. The null handle is only ever passed to
            // `cudaStreamSynchronize` / `cudaStreamAddCallback`, and `destroy`
            // skips it entirely.
            internal: unsafe { DevicePtr::null() },
            device: Device::get_or_panic(),
        }
    }

    /// Creates a new CUDA stream with `cudaStreamCreate`.
    ///
    /// The device reported by `Device::get` at the time of the call is
    /// recorded and re-selected by later operations on the stream.
    ///
    /// # Errors
    ///
    /// Propagates a failure of `Device::get`. The status code returned by
    /// `cudaStreamCreate` is converted by the crate's `result!` macro
    /// (presumably into an error for any non-success code).
    pub fn new() -> Result<Self> {
        let device = Device::get()?;
        // Out-parameter that `cudaStreamCreate` fills in with the new handle.
        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
        let ret = cpp!(unsafe [
            ptr_ptr as "void**"
        ] -> i32 as "std::int32_t" {
            return cudaStreamCreate((cudaStream_t*) ptr_ptr);
        });
        result!(
            ret,
            Stream {
                internal: DevicePtr::from_addr(ptr),
                device,
            }
        )
    }

    /// Blocks the calling thread in `cudaStreamSynchronize` for this stream.
    ///
    /// # Errors
    ///
    /// Propagates a failure of `Device::set`. The status code returned by
    /// `cudaStreamSynchronize` is converted by the crate's `result!` macro.
    pub fn synchronize(&self) -> Result<()> {
        Device::set(self.device)?;
        let ptr = self.internal.as_ptr();
        let ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamSynchronize((cudaStream_t) ptr);
        });
        result!(ret)
    }

    /// Registers `f` to be run through `cudaStreamAddCallback` on this stream.
    ///
    /// The closure is moved to the heap and handed to CUDA as the callback's
    /// user data. The `cuda_ffi_Callback` trampoline defined in this file
    /// takes ownership back and invokes it exactly once.
    ///
    /// NOTE(review): `f` carries no `'static` bound, yet it is invoked after
    /// this function has returned. Confirm that callers never capture borrows
    /// that may end before the callback runs.
    ///
    /// # Errors
    ///
    /// Propagates a failure of `Device::set`. The status code returned by
    /// `cudaStreamAddCallback` is converted by the crate's `result!` macro.
    /// When registration fails the closure is dropped without being called.
    pub fn add_callback(&self, f: impl FnOnce() + Send) -> Result<()> {
        Device::set(self.device)?;
        let ptr = self.internal.as_ptr();

        // `Box<dyn FnOnce()>` is a fat pointer. Boxing it a second time yields
        // a thin pointer that fits into the `void*` user-data slot.
        let f_boxed = Box::new(f) as Box<dyn FnOnce()>;
        let f_boxed2_ptr = Box::into_raw(Box::new(f_boxed));
        let user_data = f_boxed2_ptr as *mut std::ffi::c_void;

        let ret = cpp!(unsafe [
            ptr as "void*",
            user_data as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamAddCallback(
                (cudaStream_t) ptr,
                cuda_ffi_Callback,
                user_data,
                0
            );
        });

        // `cudaSuccess` is 0. On any other code the callback was not enqueued,
        // so the trampoline will never reclaim the allocation. Take it back
        // here instead of leaking it.
        if ret != 0 {
            // SAFETY: `f_boxed2_ptr` came from `Box::into_raw` just above, and
            // since registration failed nothing else owns or will free it.
            drop(unsafe { Box::from_raw(f_boxed2_ptr) });
        }

        result!(ret)
    }

    /// Shared access to the underlying stream handle.
    #[inline(always)]
    pub fn as_internal(&self) -> &DevicePtr {
        &self.internal
    }

    /// Exclusive access to the underlying stream handle.
    #[inline(always)]
    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
        &mut self.internal
    }

    /// Device recorded when this stream value was constructed.
    #[inline(always)]
    pub fn device(&self) -> DeviceId {
        self.device
    }

    /// Synchronizes and then destroys the underlying CUDA stream, if any.
    ///
    /// Returns immediately when the handle is null, which covers values built
    /// by `Stream::null`. Otherwise the handle is moved out of `self` with
    /// `DevicePtr::take`, `cudaStreamSynchronize` is called, and then
    /// `cudaStreamDestroy`. The status codes of both calls are deliberately
    /// ignored, so teardown is best-effort.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out in this file. Presumably
    /// the caller must ensure nothing else still uses the raw handle, and
    /// `take` presumably leaves a null handle behind so that a later call
    /// (including the one from `Drop`) is a no-op — confirm.
    ///
    /// # Panics
    ///
    /// Selects the device through `Device::set_or_panic`, which, going by its
    /// name, panics when the device cannot be made current.
    pub unsafe fn destroy(&mut self) {
        if self.internal.is_null() {
            return;
        }

        Device::set_or_panic(self.device);

        let mut internal = unsafe { self.internal.take() };
        let ptr = internal.as_mut_ptr();

        // Synchronize before destroying the stream; the status is ignored.
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamSynchronize((cudaStream_t) ptr);
        });

        // The status is ignored here too: there is nothing useful to do with
        // a failure while tearing the stream down.
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamDestroy((cudaStream_t) ptr);
        });
    }
}
impl Drop for Stream {
    /// Releases the CUDA stream by delegating to `Stream::destroy`.
    #[inline]
    fn drop(&mut self) {
        // SAFETY: the value is being dropped, so no safe code can reach the
        // handle afterwards. `destroy` itself returns early for a null handle.
        unsafe { self.destroy() }
    }
}
// C++ trampoline passed to `cudaStreamAddCallback` by `Stream::add_callback`.
//
// `user_data` is the thin pointer produced there by `Box::into_raw` on a
// `Box<Box<dyn FnOnce()>>`. The embedded Rust block takes ownership back with
// `Box::from_raw` and calls the closure once, so the allocation is freed after
// the call. The `stream` argument is unused, and `status` is forwarded to the
// Rust block but not inspected, so the closure runs whatever CUDA reports.
//
// The pointer is converted with `pointer::cast` rather than an untyped
// `transmute`; the target type is fixed by the explicit `Box::<..>::from_raw`.
//
// NOTE(review): a panic inside the closure would unwind out of the generated
// callback function — confirm how that should be handled.
cpp! {{
    static void cuda_ffi_Callback(
        __attribute__((unused)) cudaStream_t stream,
        cudaError_t status,
        void* user_data
    ) {
        rust!(cuda_ffi_Callback_internal [
            status : i32 as "std::int32_t",
            user_data : *mut std::ffi::c_void as "void*"
        ] {
            unsafe {
                let function = Box::<Box<dyn FnOnce()>>::from_raw(user_data.cast());
                function()
            }
        });
    }
}}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests call straight into the CUDA runtime through `Stream`.

    // Creating a stream should succeed.
    #[test]
    fn test_new() {
        let created = Stream::new();
        assert!(created.is_ok());
    }

    // A freshly created stream can be synchronized.
    #[test]
    fn test_synchronize() {
        let stream = Stream::new().unwrap();
        let outcome = stream.synchronize();
        assert!(outcome.is_ok());
    }

    // Synchronizing through the null handle is accepted as well.
    #[test]
    fn test_synchronize_null_stream() {
        let stream = Stream::null();
        let outcome = stream.synchronize();
        assert!(outcome.is_ok());
    }
}