use std::sync::Arc;
use baracuda_cuda_sys::runtime::{cudaStream_t, runtime, types::cudaStreamFlags};
use crate::device::Device;
use crate::error::{check, Result};
/// Cheaply clonable handle to a CUDA stream.
///
/// All clones share one `StreamInner` through an `Arc`; the raw stream is
/// handed to the runtime's destroy entry point by `StreamInner`'s `Drop`
/// once the last clone goes away.
#[derive(Clone)]
pub struct Stream {
// Shared state: the raw handle plus the device recorded at construction.
inner: Arc<StreamInner>,
}
/// State shared by every clone of a `Stream`.
///
/// Its `Drop` implementation passes `handle` to the runtime's
/// `cuda_stream_destroy` entry point (best effort).
struct StreamInner {
// Raw stream handle as produced by the runtime or supplied to `Stream::from_raw`.
handle: cudaStream_t,
// Device recorded when the stream was created/wrapped (from `Device::current()`).
device: Device,
}
// SAFETY / NOTE(review): these impls assert that the raw `cudaStream_t` handle
// may be moved to and used from other threads. Nothing in this file proves
// that; it presumably relies on the CUDA runtime treating stream handles as
// thread-safe — confirm against the CUDA documentation.
unsafe impl Send for StreamInner {}
unsafe impl Sync for StreamInner {}
impl core::fmt::Debug for StreamInner {
    /// Formats the shared state as `Stream { handle: .., device: .. }`.
    ///
    /// The struct label is deliberately `"Stream"` (not `"StreamInner"`), so
    /// the public wrapper's `Debug` output carries the public type's name.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive destructuring: adding a field later forces this impl to
        // be revisited.
        let Self { handle, device } = self;
        let mut out = f.debug_struct("Stream");
        out.field("handle", handle);
        out.field("device", device);
        out.finish()
    }
}
impl core::fmt::Debug for Stream {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.inner.fmt(f)
}
}
impl Stream {
    /// Creates a stream using `cudaStreamFlags::DEFAULT`.
    ///
    /// # Errors
    /// Returns whatever [`Stream::with_flags`] returns.
    pub fn new() -> Result<Self> {
        Self::with_flags(cudaStreamFlags::DEFAULT)
    }

    /// Creates a stream using `cudaStreamFlags::NON_BLOCKING`.
    ///
    /// # Errors
    /// Returns whatever [`Stream::with_flags`] returns.
    pub fn non_blocking() -> Result<Self> {
        Self::with_flags(cudaStreamFlags::NON_BLOCKING)
    }

    /// Wraps an already existing raw stream handle.
    ///
    /// The recorded device is `Device::current()`; if that query fails the
    /// stream is recorded against ordinal 0 instead of reporting an error.
    ///
    /// # Safety
    /// The returned `Stream` takes over `handle`: when its last clone is
    /// dropped, the handle is passed to the runtime's stream-destroy entry
    /// point. The caller must therefore supply a handle that is valid for the
    /// runtime calls made by this type and must not destroy it elsewhere.
    pub unsafe fn from_raw(handle: cudaStream_t) -> Self {
        let device = Device::current().unwrap_or(Device::from_ordinal(0));
        Self {
            inner: Arc::new(StreamInner { handle, device }),
        }
    }

    /// Creates a stream with the given raw creation `flags`.
    ///
    /// # Errors
    /// Fails if the runtime or its create-with-flags entry point cannot be
    /// resolved, if the current device cannot be queried, or if stream
    /// creation itself reports a non-success status.
    pub fn with_flags(flags: u32) -> Result<Self> {
        let r = runtime()?;
        let cu = r.cuda_stream_create_with_flags()?;
        // Query the device *before* creating the stream. Doing it afterwards
        // would leak the freshly created stream whenever the query failed,
        // because nothing would own (and later destroy) the handle.
        let device = Device::current()?;
        let mut stream: cudaStream_t = core::ptr::null_mut();
        // SAFETY: `stream` is a live local, so the out-pointer is valid for
        // the duration of the call.
        check(unsafe { cu(&mut stream, flags) })?;
        Ok(Self {
            inner: Arc::new(StreamInner {
                handle: stream,
                device,
            }),
        })
    }

    /// Calls the runtime's stream-synchronize entry point on this stream.
    ///
    /// In CUDA this blocks the calling thread until the work queued on the
    /// stream has completed.
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub fn synchronize(&self) -> Result<()> {
        let r = runtime()?;
        let cu = r.cuda_stream_synchronize()?;
        // SAFETY: the handle is owned by `self.inner` and is only destroyed
        // in `StreamInner::drop`, which cannot run while `self` is borrowed.
        check(unsafe { cu(self.inner.handle) })
    }

    /// Queries the stream without blocking.
    ///
    /// Returns `Ok(true)` when the query reports `Success`, `Ok(false)` when
    /// it reports `NotReady`.
    ///
    /// # Errors
    /// Any other status is returned as `Error::Status`; lookup failures of the
    /// runtime or the query entry point are propagated as well.
    pub fn is_complete(&self) -> Result<bool> {
        use baracuda_cuda_sys::runtime::cudaError_t;
        let r = runtime()?;
        let cu = r.cuda_stream_query()?;
        // SAFETY: see `synchronize` — the handle outlives this borrow.
        match unsafe { cu(self.inner.handle) } {
            cudaError_t::Success => Ok(true),
            cudaError_t::NotReady => Ok(false),
            other => Err(crate::error::Error::Status { status: other }),
        }
    }

    /// Returns the device that was recorded when this stream was constructed.
    #[inline]
    pub fn device(&self) -> Device {
        self.inner.device
    }

    /// Returns the raw stream handle without giving up ownership.
    ///
    /// The handle is still destroyed when the last clone of this `Stream` is
    /// dropped, so it must not be used after that point.
    #[inline]
    pub fn as_raw(&self) -> cudaStream_t {
        self.inner.handle
    }

    /// Creates a stream with the given raw creation `flags` and `priority`.
    ///
    /// See [`stream_priority_range`] for the range the runtime reports.
    ///
    /// # Errors
    /// Fails if the runtime or its create-with-priority entry point cannot be
    /// resolved, if the current device cannot be queried, or if stream
    /// creation itself reports a non-success status.
    pub fn with_priority(flags: u32, priority: i32) -> Result<Self> {
        let r = runtime()?;
        let cu = r.cuda_stream_create_with_priority()?;
        // As in `with_flags`: resolve the device first so that a failing
        // query cannot strand an already created stream.
        let device = Device::current()?;
        let mut stream: cudaStream_t = core::ptr::null_mut();
        // SAFETY: `stream` is a live local, so the out-pointer is valid for
        // the duration of the call.
        check(unsafe { cu(&mut stream, flags, priority) })?;
        Ok(Self {
            inner: Arc::new(StreamInner {
                handle: stream,
                device,
            }),
        })
    }

    /// Returns the priority the runtime reports for this stream.
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub fn priority(&self) -> Result<i32> {
        let r = runtime()?;
        let cu = r.cuda_stream_get_priority()?;
        let mut p: core::ffi::c_int = 0;
        // SAFETY: `p` is a live local out-parameter; the handle outlives this
        // borrow of `self`.
        check(unsafe { cu(self.inner.handle, &mut p) })?;
        Ok(p)
    }

    /// Returns the flags the runtime reports for this stream.
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub fn flags(&self) -> Result<u32> {
        let r = runtime()?;
        let cu = r.cuda_stream_get_flags()?;
        let mut f: core::ffi::c_uint = 0;
        // SAFETY: `f` is a live local out-parameter; the handle outlives this
        // borrow of `self`.
        check(unsafe { cu(self.inner.handle, &mut f) })?;
        Ok(f)
    }

    /// Forwards this stream, `event`'s raw handle and `flags` to the
    /// runtime's stream-wait-event entry point.
    ///
    /// In CUDA this makes later work on the stream wait for the event.
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub fn wait_event(&self, event: &crate::Event, flags: u32) -> Result<()> {
        let r = runtime()?;
        let cu = r.cuda_stream_wait_event()?;
        // SAFETY: both handles are borrowed from live wrappers for the
        // duration of the call.
        check(unsafe { cu(self.inner.handle, event.as_raw(), flags) })
    }
}
/// Queries the runtime's device stream-priority range.
///
/// Returns the two values in the order the runtime writes them to its two
/// out-parameters. Per the CUDA documentation these are the least and the
/// greatest priority respectively — confirm against the wrapper in use.
///
/// # Errors
/// Fails if the runtime or the entry point cannot be resolved, or if the
/// query reports a non-success status.
pub fn stream_priority_range() -> Result<(i32, i32)> {
    let rt = runtime()?;
    let query = rt.cuda_device_get_stream_priority_range()?;
    let (mut first, mut second): (core::ffi::c_int, core::ffi::c_int) = (0, 0);
    // SAFETY: both out-pointers refer to live locals.
    let status = unsafe { query(&mut first, &mut second) };
    check(status)?;
    Ok((first, second))
}
impl Stream {
    /// Enqueues `f` on this stream through the runtime's launch-host-func
    /// entry point.
    ///
    /// The closure is double-boxed so that it fits behind a thin
    /// `*mut c_void`; the trampoline rebuilds the box and invokes the closure
    /// once. If the launch reports a non-success status, the closure is
    /// reclaimed and dropped without being run.
    ///
    /// # Errors
    /// Fails if the runtime or the entry point cannot be resolved, or returns
    /// `Error::Status` when the launch call reports a non-success status.
    pub fn launch_host_func<F>(&self, f: F) -> Result<()>
    where
        F: FnOnce() + Send + 'static,
    {
        use core::ffi::c_void;

        unsafe extern "C" fn trampoline(user_data: *mut c_void) {
            // SAFETY: `user_data` is the pointer produced by `Box::into_raw`
            // below, and ownership is handed back here exactly once.
            let f: Box<Box<dyn FnOnce() + Send>> =
                unsafe { Box::from_raw(user_data as *mut Box<dyn FnOnce() + Send>) };
            (*f)();
        }

        // Resolve the entry point *before* leaking the closure into a raw
        // pointer; otherwise an early return from either lookup would leak
        // the boxed closure.
        let r = runtime()?;
        let cu = r.cuda_launch_host_func()?;

        let boxed: Box<Box<dyn FnOnce() + Send>> = Box::new(Box::new(f));
        let raw = Box::into_raw(boxed) as *mut c_void;

        // SAFETY: `raw` stays valid until the trampoline (or the error path
        // below) reclaims it; the stream handle outlives this borrow.
        let rc = unsafe { cu(self.inner.handle, Some(trampoline), raw) };
        if rc != baracuda_cuda_sys::runtime::cudaError_t::Success {
            // The callback was not accepted, so the trampoline will never
            // run: take the allocation back and free it here.
            // SAFETY: `raw` came from `Box::into_raw` above and has not been
            // reclaimed by anyone else.
            drop(unsafe { Box::from_raw(raw as *mut Box<dyn FnOnce() + Send>) });
            return Err(crate::error::Error::Status { status: rc });
        }
        Ok(())
    }

    /// Forwards to the runtime's 32-bit stream write-value entry point.
    ///
    /// # Safety
    /// `addr` is passed to the runtime unchecked; the caller must satisfy
    /// whatever the underlying CUDA operation requires of it (presumably a
    /// suitably aligned address that stays valid until the operation has
    /// executed — confirm against the CUDA documentation).
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub unsafe fn write_value_32(
        &self,
        addr: *mut core::ffi::c_void,
        value: u32,
        flags: u32,
    ) -> Result<()> {
        unsafe {
            let r = runtime()?;
            let cu = r.cuda_stream_write_value_32()?;
            check(cu(self.inner.handle, addr, value, flags))
        }
    }

    /// Forwards to the runtime's 64-bit stream write-value entry point.
    ///
    /// # Safety
    /// Same contract as [`Stream::write_value_32`], for a 64-bit value.
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub unsafe fn write_value_64(
        &self,
        addr: *mut core::ffi::c_void,
        value: u64,
        flags: u32,
    ) -> Result<()> {
        unsafe {
            let r = runtime()?;
            let cu = r.cuda_stream_write_value_64()?;
            check(cu(self.inner.handle, addr, value, flags))
        }
    }

    /// Forwards to the runtime's 32-bit stream wait-value entry point.
    ///
    /// # Safety
    /// `addr` is passed to the runtime unchecked; the caller must satisfy
    /// whatever the underlying CUDA operation requires of it (presumably a
    /// suitably aligned address that stays valid until the wait has been
    /// satisfied — confirm against the CUDA documentation).
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub unsafe fn wait_value_32(
        &self,
        addr: *mut core::ffi::c_void,
        value: u32,
        flags: u32,
    ) -> Result<()> {
        unsafe {
            let r = runtime()?;
            let cu = r.cuda_stream_wait_value_32()?;
            check(cu(self.inner.handle, addr, value, flags))
        }
    }

    /// Forwards to the runtime's 64-bit stream wait-value entry point.
    ///
    /// # Safety
    /// Same contract as [`Stream::wait_value_32`], for a 64-bit value.
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub unsafe fn wait_value_64(
        &self,
        addr: *mut core::ffi::c_void,
        value: u64,
        flags: u32,
    ) -> Result<()> {
        unsafe {
            let r = runtime()?;
            let cu = r.cuda_stream_wait_value_64()?;
            check(cu(self.inner.handle, addr, value, flags))
        }
    }

    /// Forwards to the runtime's stream attach-mem-async entry point.
    ///
    /// # Safety
    /// `dev_ptr` and `length` are passed to the runtime unchecked; the caller
    /// must supply a pointer/length pair that the underlying CUDA operation
    /// accepts (presumably managed memory — confirm against the CUDA
    /// documentation).
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub unsafe fn attach_mem_async(
        &self,
        dev_ptr: *mut core::ffi::c_void,
        length: usize,
        flags: u32,
    ) -> Result<()> {
        unsafe {
            let r = runtime()?;
            let cu = r.cuda_stream_attach_mem_async()?;
            check(cu(self.inner.handle, dev_ptr, length, flags))
        }
    }

    /// Copies stream attributes from `src` onto this stream via the runtime's
    /// copy-attributes entry point (`self` is the destination).
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub fn copy_attributes_from(&self, src: &Stream) -> Result<()> {
        let r = runtime()?;
        let cu = r.cuda_stream_copy_attributes()?;
        // SAFETY: both handles are borrowed from live `Stream`s for the
        // duration of the call.
        check(unsafe { cu(self.inner.handle, src.inner.handle) })
    }

    /// Submits `params` to the runtime's stream batch-mem-op entry point.
    ///
    /// # Safety
    /// The entries of `params` are handed to the runtime as-is; the caller
    /// must make sure every entry describes an operation that is valid for
    /// the underlying CUDA call (including any addresses it contains).
    ///
    /// # Errors
    /// Fails if the entry point cannot be resolved or reports an error.
    pub unsafe fn batch_mem_op(
        &self,
        params: &mut [baracuda_cuda_sys::types::CUstreamBatchMemOpParams],
        flags: u32,
    ) -> Result<()> {
        unsafe {
            let r = runtime()?;
            let cu = r.cuda_stream_batch_mem_op()?;
            check(cu(
                self.inner.handle,
                // NOTE(review): this cast truncates for slices longer than
                // `c_uint::MAX` entries; unreachable in practice but unchecked.
                params.len() as core::ffi::c_uint,
                params.as_mut_ptr(),
                flags,
            ))
        }
    }
}
impl Drop for StreamInner {
    /// Best-effort destruction of the raw stream.
    ///
    /// If the runtime or its stream-destroy entry point cannot be resolved,
    /// nothing is done; the status returned by the destroy call is ignored,
    /// since `drop` has no way to report it.
    fn drop(&mut self) {
        let Ok(rt) = runtime() else {
            return;
        };
        let Ok(destroy) = rt.cuda_stream_destroy() else {
            return;
        };
        // SAFETY: this is the last reference to the shared state, so no other
        // user of the handle can exist through this wrapper.
        let _ = unsafe { destroy(self.handle) };
    }
}