use oxicuda_driver::ffi::CUstream;
use oxicuda_driver::loader::try_driver;
use crate::error::{CudaRtError, CudaRtResult};
/// Bit flags supplied when a stream is created.
///
/// A thin newtype over the raw `u32` flag word. The derived `Default` is the
/// zero value, which equals [`StreamFlags::DEFAULT`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct StreamFlags(pub u32);

impl StreamFlags {
    /// No flag bits set (`0`).
    pub const DEFAULT: Self = StreamFlags(0);

    /// Bit `0x1`.
    ///
    /// NOTE(review): presumably mirrors CUDA's `CU_STREAM_NON_BLOCKING` —
    /// confirm against the driver headers.
    pub const NON_BLOCKING: Self = StreamFlags(1);
}
/// Copyable wrapper around a raw driver [`CUstream`] handle.
///
/// The type is `Copy`, so it carries no destructor: copying it duplicates the
/// handle value only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaStream(CUstream);

impl CudaStream {
    /// The stream identified by a null handle.
    pub const DEFAULT: Self = CudaStream(CUstream(std::ptr::null_mut()));

    /// The stream identified by the sentinel handle value `2`.
    ///
    /// NOTE(review): presumably CUDA's per-thread default stream
    /// (`CU_STREAM_PER_THREAD`) — confirm against the driver headers.
    pub const PER_THREAD: Self = CudaStream(CUstream(0x2 as *mut std::ffi::c_void));

    /// Wraps an existing raw handle without performing any validation.
    ///
    /// # Safety
    ///
    /// Nothing is checked here. NOTE(review): the caller is presumably
    /// expected to supply a handle that remains valid for every later driver
    /// call made through this wrapper — confirm the intended contract.
    #[must_use]
    pub const unsafe fn from_raw(raw: CUstream) -> Self {
        CudaStream(raw)
    }

    /// Returns the wrapped raw handle.
    #[must_use]
    pub fn raw(self) -> CUstream {
        let Self(handle) = self;
        handle
    }

    /// Reports whether the wrapped handle is null, as it is for
    /// [`CudaStream::DEFAULT`].
    #[must_use]
    pub fn is_default(self) -> bool {
        let Self(handle) = self;
        handle.is_null()
    }
}
/// Creates a new stream using [`StreamFlags::DEFAULT`].
///
/// This is a convenience wrapper around [`stream_create_with_flags`].
///
/// # Errors
///
/// Returns whatever [`stream_create_with_flags`] returns for the default
/// flags.
pub fn stream_create() -> CudaRtResult<CudaStream> {
    let flags = StreamFlags::DEFAULT;
    stream_create_with_flags(flags)
}
/// Creates a new stream with the given creation `flags`.
///
/// The raw flag word is passed unchanged to the driver's `cu_stream_create`
/// entry point.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when `try_driver` fails.
/// * For a non-zero driver status: the value produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidResourceHandle`] when
///   `from_code` yields nothing for that status.
pub fn stream_create_with_flags(flags: StreamFlags) -> CudaRtResult<CudaStream> {
    let Ok(driver) = try_driver() else {
        return Err(CudaRtError::DriverNotAvailable);
    };

    let mut handle = CUstream::default();
    // SAFETY: `handle` is a live local for the whole call, so the out-pointer
    // is valid for writes. NOTE(review): assumes the loaded entry point
    // matches its declared signature.
    let status = unsafe { (driver.cu_stream_create)(&raw mut handle, flags.0) };

    match status {
        0 => Ok(CudaStream(handle)),
        code => Err(CudaRtError::from_code(code).unwrap_or(CudaRtError::InvalidResourceHandle)),
    }
}
/// Creates a new stream with the given creation `flags` and `priority`.
///
/// Both values are passed unchanged to the driver's
/// `cu_stream_create_with_priority` entry point.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when `try_driver` fails.
/// * For a non-zero driver status: the value produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidResourceHandle`] when
///   `from_code` yields nothing for that status.
pub fn stream_create_with_priority(flags: StreamFlags, priority: i32) -> CudaRtResult<CudaStream> {
    let driver = match try_driver() {
        Ok(table) => table,
        Err(_) => return Err(CudaRtError::DriverNotAvailable),
    };

    let mut handle = CUstream::default();
    // SAFETY: `handle` is a live local for the whole call, so the out-pointer
    // is valid for writes. NOTE(review): assumes the loaded entry point
    // matches its declared signature.
    let status =
        unsafe { (driver.cu_stream_create_with_priority)(&raw mut handle, flags.0, priority) };

    if status == 0 {
        Ok(CudaStream(handle))
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidResourceHandle))
    }
}
/// Destroys `stream` through the driver's `cu_stream_destroy_v2` entry point.
///
/// A stream whose handle is null (see [`CudaStream::is_default`]) is never
/// handed to the driver: the call returns `Ok(())` immediately, before the
/// driver is even looked up.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when `try_driver` fails.
/// * For a non-zero driver status: the value produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidResourceHandle`] when
///   `from_code` yields nothing for that status.
pub fn stream_destroy(stream: CudaStream) -> CudaRtResult<()> {
    if !stream.is_default() {
        let Ok(driver) = try_driver() else {
            return Err(CudaRtError::DriverNotAvailable);
        };

        // SAFETY: only a by-value handle is passed. NOTE(review): assumes the
        // handle is one the driver accepts and that the loaded entry point
        // matches its declared signature.
        let status = unsafe { (driver.cu_stream_destroy_v2)(stream.raw()) };
        if status != 0 {
            let mapped = CudaRtError::from_code(status);
            return Err(mapped.unwrap_or(CudaRtError::InvalidResourceHandle));
        }
    }
    Ok(())
}
/// Calls the driver's `cu_stream_synchronize` entry point on `stream`.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when `try_driver` fails.
/// * For a non-zero driver status: the value produced by
///   `CudaRtError::from_code`, or [`CudaRtError::Unknown`] when `from_code`
///   yields nothing for that status.
pub fn stream_synchronize(stream: CudaStream) -> CudaRtResult<()> {
    let Ok(driver) = try_driver() else {
        return Err(CudaRtError::DriverNotAvailable);
    };

    // SAFETY: only a by-value handle is passed. NOTE(review): assumes the
    // handle is one the driver accepts and that the loaded entry point matches
    // its declared signature.
    let status = unsafe { (driver.cu_stream_synchronize)(stream.raw()) };

    match status {
        0 => Ok(()),
        code => Err(CudaRtError::from_code(code).unwrap_or(CudaRtError::Unknown)),
    }
}
/// Queries `stream` through the driver's `cu_stream_query` entry point.
///
/// Returns `Ok(true)` for status `0` and `Ok(false)` for status `600`.
///
/// NOTE(review): `600` is presumably `CUDA_ERROR_NOT_READY`, which would mean
/// work is still pending on the stream — confirm against the driver headers.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when `try_driver` fails.
/// * For any other status: the value produced by `CudaRtError::from_code`, or
///   [`CudaRtError::Unknown`] when `from_code` yields nothing for that status.
pub fn stream_query(stream: CudaStream) -> CudaRtResult<bool> {
    let driver = match try_driver() {
        Ok(table) => table,
        Err(_) => return Err(CudaRtError::DriverNotAvailable),
    };

    // SAFETY: only a by-value handle is passed. NOTE(review): assumes the
    // handle is one the driver accepts and that the loaded entry point matches
    // its declared signature.
    let status = unsafe { (driver.cu_stream_query)(stream.raw()) };

    if status == 0 {
        Ok(true)
    } else if status == 600 {
        Ok(false)
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::Unknown))
    }
}
/// Calls the driver's `cu_stream_wait_event` entry point with `stream`,
/// `event` and `flags`.
///
/// `flags` is forwarded to the driver unchanged.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when `try_driver` fails.
/// * For a non-zero driver status: the value produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidResourceHandle`] when
///   `from_code` yields nothing for that status.
pub fn stream_wait_event(
    stream: CudaStream,
    event: crate::event::CudaEvent,
    flags: u32,
) -> CudaRtResult<()> {
    let Ok(driver) = try_driver() else {
        return Err(CudaRtError::DriverNotAvailable);
    };

    // SAFETY: only by-value handles and a plain integer are passed.
    // NOTE(review): assumes both handles are ones the driver accepts and that
    // the loaded entry point matches its declared signature.
    let status = unsafe { (driver.cu_stream_wait_event)(stream.raw(), event.raw(), flags) };

    if status == 0 {
        Ok(())
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidResourceHandle))
    }
}
/// Reads the priority of `stream` through the driver's
/// `cu_stream_get_priority` entry point.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when `try_driver` fails.
/// * For a non-zero driver status: the value produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidResourceHandle`] when
///   `from_code` yields nothing for that status.
pub fn stream_get_priority(stream: CudaStream) -> CudaRtResult<i32> {
    let Ok(driver) = try_driver() else {
        return Err(CudaRtError::DriverNotAvailable);
    };

    let mut value: std::ffi::c_int = 0;
    // SAFETY: `value` is a live local for the whole call, so the out-pointer
    // is valid for writes. NOTE(review): assumes the loaded entry point
    // matches its declared signature.
    let status = unsafe { (driver.cu_stream_get_priority)(stream.raw(), &raw mut value) };

    match status {
        0 => Ok(value),
        code => Err(CudaRtError::from_code(code).unwrap_or(CudaRtError::InvalidResourceHandle)),
    }
}
/// Reads the flag word of `stream` through the driver's `cu_stream_get_flags`
/// entry point and wraps it in [`StreamFlags`].
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] when `try_driver` fails.
/// * For a non-zero driver status: the value produced by
///   `CudaRtError::from_code`, or [`CudaRtError::InvalidResourceHandle`] when
///   `from_code` yields nothing for that status.
pub fn stream_get_flags(stream: CudaStream) -> CudaRtResult<StreamFlags> {
    let driver = match try_driver() {
        Ok(table) => table,
        Err(_) => return Err(CudaRtError::DriverNotAvailable),
    };

    let mut bits: u32 = 0;
    // SAFETY: `bits` is a live local for the whole call, so the out-pointer is
    // valid for writes. NOTE(review): assumes the loaded entry point matches
    // its declared signature.
    let status = unsafe { (driver.cu_stream_get_flags)(stream.raw(), &raw mut bits) };

    if status == 0 {
        Ok(StreamFlags(bits))
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidResourceHandle))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `DEFAULT` wraps a null handle; `PER_THREAD` wraps a non-null sentinel.
    #[test]
    fn default_stream_is_null() {
        assert!(CudaStream::DEFAULT.is_default());
        assert!(!CudaStream::PER_THREAD.is_default());
    }

    /// The flag constants keep their raw numeric values.
    #[test]
    fn stream_flags_values() {
        assert_eq!(StreamFlags::DEFAULT.0, 0);
        assert_eq!(StreamFlags::NON_BLOCKING.0, 1);
    }

    /// Destroying the null-handle stream returns before the driver is looked
    /// up, so it must succeed even on a machine with no driver installed.
    #[test]
    fn stream_destroy_default_is_noop() {
        assert!(stream_destroy(CudaStream::DEFAULT).is_ok());
    }

    /// Stream creation must not panic whether or not a usable driver is
    /// present. The outcome depends on the machine, so an error is accepted
    /// as-is; a successfully created stream is destroyed again so the test
    /// does not leak it.
    #[test]
    fn stream_create_without_gpu_returns_error() {
        if let Ok(stream) = stream_create() {
            assert!(stream_destroy(stream).is_ok());
        }
    }
}