use std::{iter, marker::PhantomData, ptr, sync::Arc};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use singe_core::impl_enum_conversion;
use singe_cuda_sys::runtime;
use crate::{
context::Context,
device::Device,
error::{Error, Result},
event::Event,
graph::{Graph, GraphDependency, GraphEdgeData, GraphNode},
try_cuda,
};
bitflags::bitflags! {
    /// Creation flags for a CUDA stream, passed to `cudaStreamCreateWithFlags`
    /// / `cudaStreamCreateWithPriority`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct StreamFlags: u32 {
        /// `cudaStreamDefault`: no special stream behavior.
        const DEFAULT = runtime::cudaStreamDefault;
        /// `cudaStreamNonBlocking`: per the CUDA documentation, the stream does
        /// not implicitly synchronize with the legacy default (null) stream.
        const NON_BLOCKING = runtime::cudaStreamNonBlocking;
    }
}
/// Capture state of a stream, mirroring the runtime's `cudaStreamCaptureStatus`.
///
/// Discriminants are the raw runtime values, so the enum converts losslessly
/// to and from `u32` and the runtime type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum StreamCaptureStatus {
    /// The stream is not capturing.
    None = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE as u32,
    /// The stream is currently capturing.
    Active = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_ACTIVE as u32,
    /// Capture was invalidated by an error (CUDA still expects it to be ended).
    Invalidated = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_INVALIDATED as u32,
}
impl_enum_conversion!(u32, runtime::cudaStreamCaptureStatus, StreamCaptureStatus);
/// How stream capture interacts with potentially unsafe CUDA API calls made
/// while capturing, mirroring the runtime's `cudaStreamCaptureMode`.
///
/// Discriminants are the raw runtime values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum StreamCaptureMode {
    /// `CU_STREAM_CAPTURE_MODE_GLOBAL`.
    Global = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_GLOBAL as u32,
    /// `CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`.
    ThreadLocal = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL as u32,
    /// `CU_STREAM_CAPTURE_MODE_RELAXED`.
    Relaxed = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED as u32,
}
impl_enum_conversion!(u32, runtime::cudaStreamCaptureMode, StreamCaptureMode);
/// Whether `cudaStreamUpdateCaptureDependencies` adds to or replaces the set
/// of nodes that subsequently captured work depends on.
///
/// Discriminants are the raw `cudaStreamUpdateCaptureDependenciesFlags` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum StreamCaptureDependencyUpdate {
    /// `cudaStreamAddCaptureDependencies`: append to the current set.
    Add = runtime::cudaStreamUpdateCaptureDependenciesFlags::cudaStreamAddCaptureDependencies
        as u32,
    /// `cudaStreamSetCaptureDependencies`: replace the current set.
    Set = runtime::cudaStreamUpdateCaptureDependenciesFlags::cudaStreamSetCaptureDependencies
        as u32,
}
impl_enum_conversion!(
    u32,
    runtime::cudaStreamUpdateCaptureDependenciesFlags,
    StreamCaptureDependencyUpdate,
);
/// Capture state reported by `cudaStreamGetCaptureInfo` for a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamCaptureInfo {
    /// Current capture status of the stream.
    pub status: StreamCaptureStatus,
    /// Capture sequence id as reported by CUDA.
    pub id: u64,
}
/// Type-erased, one-shot host callback invoked with the stream's completion
/// status.
type RustStreamCallbackDyn = Box<dyn FnOnce(Result<()>) + Send + 'static>;
/// Thin pointer to a heap-allocated [`RustStreamCallbackDyn`].
///
/// The callback is boxed twice because `Box<dyn FnOnce ..>` is a fat pointer
/// and cannot travel through CUDA's `void*` user-data argument on its own.
type BoxedCallbackPtr = *mut RustStreamCallbackDyn;
/// An owned CUDA stream.
///
/// Dropping a `Stream` binds its context, synchronizes the stream and then
/// destroys the underlying handle.
#[derive(Debug)]
pub struct Stream {
    /// Raw runtime handle; owned (and eventually destroyed) by this value.
    handle: runtime::cudaStream_t,
    /// Context associated with the stream; methods bind it before most CUDA calls.
    ctx: Arc<Context>,
}
/// Borrowed view of a [`Stream`] handed to the closure of [`Stream::scope`].
///
/// The stream is synchronized when the scope ends.
#[derive(Debug)]
pub struct StreamScope<'scope, 'env> {
    /// The stream the scope was opened on.
    stream: &'scope Stream,
    /// `&'env mut &'env ()` makes `'env` invariant, the same marker pattern
    /// `std::thread::Scope` uses for its environment lifetime.
    _env: PhantomData<&'env mut &'env ()>,
}
/// A raw stream handle owned elsewhere, paired with its context.
///
/// No `Drop` impl is defined for this type, so dropping or cloning it never
/// destroys the underlying stream.
#[derive(Debug, Clone)]
pub struct BorrowedStream {
    /// Raw runtime handle; not owned by this value.
    handle: runtime::cudaStream_t,
    /// Context associated with the stream.
    ctx: Arc<Context>,
}
/// Which stream work should be issued on.
#[derive(Debug, Clone)]
pub enum StreamBinding {
    /// The default (null) stream of the given context.
    Default(Arc<Context>),
    /// An explicitly provided, externally owned stream.
    Borrowed(BorrowedStream),
}
impl Stream {
    /// Wraps a raw CUDA stream handle, taking ownership of it.
    ///
    /// The returned [`Stream`] destroys `handle` when dropped (after binding
    /// `ctx` and synchronizing), so `handle` must be a valid stream that
    /// nothing else will destroy.
    /// NOTE(review): this ownership contract is not enforced because the
    /// constructor is safe to call; consider an `unsafe` variant.
    pub const fn from_raw(handle: runtime::cudaStream_t, ctx: Arc<Context>) -> Self {
        Self { handle, ctx }
    }

    /// Runs `f` with a [`StreamScope`] borrowing this stream, then blocks until
    /// all work on the stream has completed.
    ///
    /// The stream is synchronized whether `f` succeeds or fails.
    ///
    /// # Errors
    ///
    /// Returns the error from `f` if it failed (a synchronization error is then
    /// discarded); otherwise returns the synchronization error, if any.
    pub fn scope<'env, F, R>(&self, f: F) -> Result<R>
    where
        F: for<'scope> FnOnce(&'scope StreamScope<'scope, 'env>) -> Result<R>,
    {
        let scope = StreamScope {
            stream: self,
            _env: PhantomData,
        };
        let result = f(&scope);
        // Always wait for the work `f` enqueued, even when `f` itself failed.
        let sync_result = self.synchronize();
        let value = result?;
        sync_result?;
        Ok(value)
    }

    /// Blocks the calling thread until all work queued on this stream completes.
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamSynchronize` fails.
    pub fn synchronize(&self) -> Result<()> {
        self.ctx.bind()?;
        // SAFETY: `self.handle` is the live stream owned by `self`.
        unsafe { try_cuda!(runtime::cudaStreamSynchronize(self.as_raw())) }
    }

    /// Returns `true` if all work queued on this stream has completed, and
    /// `false` if work is still pending (`CUDA_ERROR_NOT_READY`).
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamQuery` reports any
    /// other error.
    pub fn query(&self) -> Result<bool> {
        // Bind like every other method so the query runs against this stream's
        // context rather than whatever context is current on this thread.
        self.ctx.bind()?;
        // SAFETY: `self.handle` is the live stream owned by `self`.
        let error = unsafe { runtime::cudaStreamQuery(self.as_raw()) };
        match error {
            runtime::cudaError_t::CUDA_SUCCESS => Ok(true),
            runtime::cudaError_t::CUDA_ERROR_NOT_READY => Ok(false),
            _ => Err(error.into()),
        }
    }

    /// Makes all work submitted to this stream afterwards wait on `event`
    /// (`cudaStreamWaitEvent` with flags `0`).
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamWaitEvent` fails.
    pub fn wait_event(&self, event: &Event) -> Result<()> {
        self.ctx.bind()?;
        // SAFETY: both the stream and the event handles are live for the call.
        unsafe {
            try_cuda!(runtime::cudaStreamWaitEvent(
                self.as_raw(),
                event.as_raw(),
                0,
            ))
        }
    }

    /// Starts capturing work submitted to this stream into a new graph
    /// (`cudaStreamBeginCapture`); finish with [`Stream::end_capture`].
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamBeginCapture` fails.
    pub fn begin_capture(&self, mode: StreamCaptureMode) -> Result<()> {
        self.ctx.bind()?;
        // SAFETY: `self.handle` is the live stream owned by `self`.
        unsafe {
            try_cuda!(runtime::cudaStreamBeginCapture(self.as_raw(), mode.into()))?;
        }
        Ok(())
    }

    /// Starts capturing into the existing `graph`, making captured work depend
    /// on `dependencies` with default edge data.
    ///
    /// # Errors
    ///
    /// See [`Stream::begin_capture_to_graph_with_dependencies`].
    pub fn begin_capture_to_graph(
        &self,
        graph: &Graph,
        dependencies: &[GraphNode],
        mode: StreamCaptureMode,
    ) -> Result<()> {
        self.begin_capture_to_graph_with_data(graph, dependencies, &[], mode)
    }

    /// Like [`Stream::begin_capture_to_graph`], with per-dependency edge data.
    ///
    /// `edge_data` must either be empty (default edge data for every
    /// dependency) or contain exactly one entry per dependency.
    ///
    /// # Errors
    ///
    /// Returns [`Error::GraphDependencyMismatch`] when `edge_data` is non-empty
    /// and its length differs from `dependencies`; otherwise see
    /// [`Stream::begin_capture_to_graph_with_dependencies`].
    pub fn begin_capture_to_graph_with_data(
        &self,
        graph: &Graph,
        dependencies: &[GraphNode],
        edge_data: &[GraphEdgeData],
        mode: StreamCaptureMode,
    ) -> Result<()> {
        let dependencies = Self::zip_dependencies_with_edge_data(dependencies, edge_data)?;
        self.begin_capture_to_graph_with_dependencies(graph, &dependencies, mode)
    }

    /// Starts capturing into `graph` (`cudaStreamBeginCaptureToGraph`), with
    /// captured work depending on each `dependency.node` via `dependency.data`.
    ///
    /// When `dependencies` is empty, a null edge-data pointer is passed.
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or the CUDA call fails.
    pub fn begin_capture_to_graph_with_dependencies(
        &self,
        graph: &Graph,
        dependencies: &[GraphDependency],
        mode: StreamCaptureMode,
    ) -> Result<()> {
        self.ctx.bind()?;
        let dependencies_raw: Vec<_> = dependencies
            .iter()
            .map(|dependency| unsafe { dependency.node.as_raw() })
            .collect();
        let edge_data_raw: Vec<_> = dependencies
            .iter()
            .map(|dependency| dependency.data.into())
            .collect();
        // SAFETY: both arrays hold `dependencies_raw.len()` elements and outlive
        // the call; the stream and graph handles are live.
        unsafe {
            try_cuda!(runtime::cudaStreamBeginCaptureToGraph(
                self.as_raw(),
                graph.as_raw(),
                dependencies_raw.as_ptr(),
                if edge_data_raw.is_empty() {
                    ptr::null()
                } else {
                    edge_data_raw.as_ptr()
                },
                dependencies_raw.len() as _,
                mode.into(),
            ))?;
        }
        Ok(())
    }

    /// Ends capture on this stream and returns the captured graph
    /// (`cudaStreamEndCapture`).
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamEndCapture` fails.
    pub fn end_capture(&self) -> Result<Graph> {
        self.ctx.bind()?;
        let mut handle = ptr::null_mut();
        // SAFETY: `handle` is a valid out-pointer; on success CUDA stores the
        // graph it created there.
        unsafe {
            try_cuda!(runtime::cudaStreamEndCapture(
                self.as_raw(),
                &raw mut handle
            ))?;
            Ok(Graph::from_raw(handle))
        }
    }

    /// Returns this stream's capture status (`cudaStreamIsCapturing`).
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamIsCapturing` fails.
    pub fn capture_status(&self) -> Result<StreamCaptureStatus> {
        self.ctx.bind()?;
        let mut status = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE;
        // SAFETY: `status` is a valid out-pointer for the duration of the call.
        unsafe {
            try_cuda!(runtime::cudaStreamIsCapturing(
                self.as_raw(),
                &raw mut status
            ))?;
        }
        Ok(status.into())
    }

    /// Returns the capture status and capture sequence id of this stream
    /// (`cudaStreamGetCaptureInfo`). The optional graph/dependency outputs are
    /// not requested.
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamGetCaptureInfo` fails.
    pub fn capture_info(&self) -> Result<StreamCaptureInfo> {
        self.ctx.bind()?;
        let mut status = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE;
        let mut id = 0;
        // SAFETY: `status` and `id` are valid out-pointers; the remaining
        // outputs are optional and passed as null.
        unsafe {
            try_cuda!(runtime::cudaStreamGetCaptureInfo(
                self.as_raw(),
                &raw mut status,
                &raw mut id,
                ptr::null_mut(),
                ptr::null_mut(),
                ptr::null_mut(),
                ptr::null_mut(),
            ))?;
        }
        Ok(StreamCaptureInfo {
            status: status.into(),
            id,
        })
    }

    /// Adds `dependencies` (with default edge data) to the set of nodes that
    /// work captured from now on depends on.
    ///
    /// # Errors
    ///
    /// See [`Stream::update_capture_dependencies_with_dependencies`].
    pub fn update_capture_dependencies(&self, dependencies: &[GraphNode]) -> Result<()> {
        self.update_capture_dependencies_with_mode(
            dependencies,
            &[],
            StreamCaptureDependencyUpdate::Add,
        )
    }

    /// Like [`Stream::update_capture_dependencies`], with per-dependency edge
    /// data (empty, or one entry per dependency).
    ///
    /// # Errors
    ///
    /// See [`Stream::update_capture_dependencies_with_mode`].
    pub fn update_capture_dependencies_with_data(
        &self,
        dependencies: &[GraphNode],
        edge_data: &[GraphEdgeData],
    ) -> Result<()> {
        self.update_capture_dependencies_with_mode(
            dependencies,
            edge_data,
            StreamCaptureDependencyUpdate::Add,
        )
    }

    /// Adds to or replaces (per `mode`) the capture dependency set.
    ///
    /// `edge_data` must either be empty (default edge data for every
    /// dependency) or contain exactly one entry per dependency.
    ///
    /// # Errors
    ///
    /// Returns [`Error::GraphDependencyMismatch`] when `edge_data` is non-empty
    /// and its length differs from `dependencies`; otherwise see
    /// [`Stream::update_capture_dependencies_with_dependencies`].
    pub fn update_capture_dependencies_with_mode(
        &self,
        dependencies: &[GraphNode],
        edge_data: &[GraphEdgeData],
        mode: StreamCaptureDependencyUpdate,
    ) -> Result<()> {
        let dependencies = Self::zip_dependencies_with_edge_data(dependencies, edge_data)?;
        self.update_capture_dependencies_with_dependencies(&dependencies, mode)
    }

    /// Updates the capture dependency set via
    /// `cudaStreamUpdateCaptureDependencies`.
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or the CUDA call fails.
    pub fn update_capture_dependencies_with_dependencies(
        &self,
        dependencies: &[GraphDependency],
        mode: StreamCaptureDependencyUpdate,
    ) -> Result<()> {
        self.ctx.bind()?;
        let mut dependencies_raw: Vec<_> = dependencies
            .iter()
            .map(|dependency| unsafe { dependency.node.as_raw() })
            .collect();
        let edge_data_raw: Vec<_> = dependencies
            .iter()
            .map(|dependency| dependency.data.into())
            .collect();
        // SAFETY: both arrays hold `dependencies_raw.len()` elements and outlive
        // the call; the stream handle is live.
        unsafe {
            try_cuda!(runtime::cudaStreamUpdateCaptureDependencies(
                self.as_raw(),
                dependencies_raw.as_mut_ptr(),
                if edge_data_raw.is_empty() {
                    ptr::null()
                } else {
                    edge_data_raw.as_ptr()
                },
                dependencies_raw.len() as _,
                mode.into(),
            ))?;
        }
        Ok(())
    }

    /// Pairs each dependency node with its edge data.
    ///
    /// An empty `edge_data` means default edge data for every node. Otherwise
    /// the lengths must match, or [`Error::GraphDependencyMismatch`] is returned.
    fn zip_dependencies_with_edge_data(
        dependencies: &[GraphNode],
        edge_data: &[GraphEdgeData],
    ) -> Result<Vec<GraphDependency>> {
        if !edge_data.is_empty() && edge_data.len() != dependencies.len() {
            return Err(Error::GraphDependencyMismatch);
        }
        // The default tail is only consumed when `edge_data` is empty.
        let edge_data = edge_data
            .iter()
            .copied()
            .chain(iter::repeat(GraphEdgeData::default()));
        Ok(dependencies
            .iter()
            .zip(edge_data)
            .map(|(&node, data)| GraphDependency { node, data })
            .collect())
    }

    /// Enqueues `callback` to run on the host once all previously submitted
    /// work on this stream has finished (`cudaStreamAddCallback`).
    ///
    /// The callback receives `Ok(())`, or the error status CUDA reported for
    /// the stream. Per the CUDA documentation it must not call CUDA APIs. It is
    /// invoked through an `extern "C"` trampoline, so a panic escaping it aborts
    /// the process.
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or registration fails. In that case
    /// the callback is dropped without being called.
    pub fn add_callback<F>(&self, callback: F) -> Result<()>
    where
        F: FnOnce(Result<()>) + Send + 'static,
    {
        self.ctx.bind()?;
        // Double-box so the fat `dyn FnOnce` pointer fits in CUDA's thin
        // `void*` user-data slot; the trampoline reclaims it.
        let boxed_dyn_callback: RustStreamCallbackDyn = Box::new(callback);
        let user_data_ptr: BoxedCallbackPtr = Box::into_raw(Box::new(boxed_dyn_callback));
        // `cudaStreamAddCallback` reserves its flags argument; it must be zero.
        let flags = 0u32;
        // SAFETY: `user_data_ptr` is a live allocation whose ownership passes
        // to the trampoline on success; on failure it is reclaimed right here.
        unsafe {
            let status = runtime::cudaStreamAddCallback(
                self.as_raw(),
                Some(stream_callback_trampoline),
                user_data_ptr.cast(),
                flags,
            );
            if status != runtime::cudaError_t::CUDA_SUCCESS {
                // Not enqueued: the trampoline will never run, so free the box.
                drop(Box::from_raw(user_data_ptr));
                try_cuda!(status)?;
            }
        }
        Ok(())
    }

    /// Returns the flags this stream was created with (`cudaStreamGetFlags`).
    /// Bits unknown to [`StreamFlags`] are retained.
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamGetFlags` fails.
    pub fn flags(&self) -> Result<StreamFlags> {
        self.ctx.bind()?;
        let mut flags_raw = 0u32;
        // SAFETY: `flags_raw` is a valid out-pointer for the duration of the call.
        unsafe {
            try_cuda!(runtime::cudaStreamGetFlags(
                self.as_raw(),
                &raw mut flags_raw
            ))?;
        }
        Ok(StreamFlags::from_bits_retain(flags_raw))
    }

    /// Returns this stream's priority (`cudaStreamGetPriority`).
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamGetPriority` fails.
    pub fn priority(&self) -> Result<i32> {
        self.ctx.bind()?;
        let mut priority = 0i32;
        // SAFETY: `priority` is a valid out-pointer for the duration of the call.
        unsafe {
            try_cuda!(runtime::cudaStreamGetPriority(
                self.as_raw(),
                &raw mut priority
            ))?;
        }
        Ok(priority)
    }

    /// Returns the CUDA-assigned id of this stream (`cudaStreamGetId`).
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamGetId` fails.
    pub fn id(&self) -> Result<u64> {
        self.ctx.bind()?;
        let mut id = 0u64;
        // SAFETY: `id` is a valid out-pointer for the duration of the call.
        unsafe {
            try_cuda!(runtime::cudaStreamGetId(self.as_raw(), &raw mut id))?;
        }
        Ok(id)
    }

    /// Returns the device this stream belongs to (`cudaStreamGetDevice`).
    ///
    /// # Errors
    ///
    /// Fails if the context cannot be bound or `cudaStreamGetDevice` fails.
    pub fn device(&self) -> Result<Device> {
        self.ctx.bind()?;
        let mut device = 0i32;
        // SAFETY: `device` is a valid out-pointer for the duration of the call.
        unsafe {
            try_cuda!(runtime::cudaStreamGetDevice(self.as_raw(), &raw mut device))?;
        }
        Ok(Device::new(device))
    }

    /// Returns the context associated with this stream.
    pub fn context(&self) -> &Context {
        &self.ctx
    }

    /// Returns the raw runtime handle.
    ///
    /// # Safety
    ///
    /// The handle is still owned by this `Stream`, whose `Drop` impl destroys
    /// it. Callers must not destroy it, and must not use it after the `Stream`
    /// is dropped.
    pub const unsafe fn as_raw(&self) -> runtime::cudaStream_t {
        self.handle
    }
}
impl<'scope, 'env> StreamScope<'scope, 'env> {
pub const fn stream(&self) -> &'scope Stream {
self.stream
}
pub fn synchronize(&self) -> Result<()> {
self.stream.synchronize()
}
}
impl BorrowedStream {
pub const fn from_raw(handle: runtime::cudaStream_t, ctx: Arc<Context>) -> Self {
Self { handle, ctx }
}
pub fn context(&self) -> &Context {
&self.ctx
}
pub const fn as_raw(&self) -> runtime::cudaStream_t {
self.handle
}
}
impl StreamBinding {
    /// Returns the context work on this binding is issued against.
    pub fn context(&self) -> &Context {
        match self {
            Self::Borrowed(stream) => stream.context(),
            Self::Default(ctx) => ctx.as_ref(),
        }
    }

    /// Returns `true` when this binding targets the default (null) stream.
    pub fn is_default(&self) -> bool {
        !matches!(self, Self::Borrowed(_))
    }

    /// Returns the raw handle to issue work on: null for [`Self::Default`],
    /// the borrowed handle otherwise.
    pub const fn as_raw(&self) -> runtime::cudaStream_t {
        if let Self::Borrowed(stream) = self {
            stream.as_raw()
        } else {
            ptr::null_mut()
        }
    }
}
// SAFETY: `Stream` stores a raw `cudaStream_t`, which is a raw pointer, so the
// auto traits are not derived. It also holds an `Arc<Context>`. The handle is
// destroyed only once, in `Drop`, which requires exclusive ownership.
// NOTE(review): these impls additionally assume (a) that the CUDA runtime lets
// the same stream handle be used from several host threads concurrently, and
// (b) that `Context` is itself safe to send and share across threads. Confirm
// both; neither is established in this file.
unsafe impl Send for Stream {}
unsafe impl Sync for Stream {}
impl Drop for Stream {
    /// Best-effort teardown: binds the owning context, waits for all queued
    /// work, then destroys the stream handle.
    ///
    /// `Drop` cannot return errors. Each failure is printed to stderr in builds
    /// with `debug_assertions` and ignored otherwise, and later steps still run
    /// after an earlier one fails.
    fn drop(&mut self) {
        /// Prints `result`'s error as `"{what}: {err}"` in debug builds.
        ///
        /// In release builds the error is consumed explicitly, so no
        /// `unused_variables` warning is emitted for it.
        fn report<T, E: std::fmt::Display>(result: std::result::Result<T, E>, what: &str) {
            if let Err(err) = result {
                #[cfg(debug_assertions)]
                eprintln!("{what}: {err}");
                #[cfg(not(debug_assertions))]
                let _ = (err, what);
            }
        }

        report(
            self.ctx.bind(),
            "failed to bind context before destroying stream",
        );
        // SAFETY: `self.handle` is owned by this `Stream`. It is synchronized,
        // then destroyed exactly once, and never used again afterwards.
        unsafe {
            report(
                try_cuda!(runtime::cudaStreamSynchronize(self.handle)),
                "failed to synchronize stream before destroy",
            );
            report(
                try_cuda!(runtime::cudaStreamDestroy(self.handle)),
                "failed to destroy CUDA stream",
            );
        }
    }
}
/// C-ABI entry point that [`Stream::add_callback`] registers with
/// `cudaStreamAddCallback`.
///
/// A non-null `user_data` is the `Box<RustStreamCallbackDyn>` leaked by
/// `add_callback`. It is reclaimed here, so each registration must be invoked
/// at most once. The CUDA `status` reaches the callback as `Ok(())` on success,
/// or as the converted error otherwise.
///
/// A panic escaping the callback cannot unwind through this `extern "C"` frame;
/// Rust aborts the process instead.
extern "C" fn stream_callback_trampoline(
    _stream: runtime::cudaStream_t,
    status: runtime::cudaError_t,
    user_data: *mut std::ffi::c_void,
) {
    let boxed: BoxedCallbackPtr = user_data.cast();
    if boxed.is_null() {
        // Nothing was registered, so there is no callback to run or free.
        return;
    }
    // SAFETY: non-null user data is only produced by `Box::into_raw` in
    // `Stream::add_callback`, and ownership is reclaimed exactly once, here.
    let callback = *unsafe { Box::from_raw(boxed) };
    let outcome = match status {
        runtime::cudaError_t::CUDA_SUCCESS => Ok(()),
        failure => Err(failure.into()),
    };
    callback(outcome);
}
impl Context {
pub fn create_stream(self: &Arc<Self>) -> Result<Stream> {
self.create_stream_with_flags(StreamFlags::DEFAULT)
}
pub fn create_stream_with_flags(self: &Arc<Self>, flags: StreamFlags) -> Result<Stream> {
self.bind()?;
let mut handle = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaStreamCreateWithFlags(
&raw mut handle,
flags.bits(),
))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Stream::from_raw(handle, Arc::clone(self)))
}
pub fn create_stream_with_priority(
self: &Arc<Self>,
flags: StreamFlags,
priority: i32,
) -> Result<Stream> {
self.bind()?;
let mut handle = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaStreamCreateWithPriority(
&raw mut handle,
flags.bits(),
priority,
))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Stream::from_raw(handle, Arc::clone(self)))
}
}
/// Installs `mode` as the calling thread's stream-capture interaction mode and
/// returns the mode it replaced (`cudaThreadExchangeStreamCaptureMode`).
///
/// Passing the returned value back in restores the earlier mode.
///
/// # Errors
///
/// Fails if `cudaThreadExchangeStreamCaptureMode` reports an error.
pub fn exchange_capture_mode(mode: StreamCaptureMode) -> Result<StreamCaptureMode> {
    let mut swapped: runtime::cudaStreamCaptureMode = mode.into();
    // SAFETY: `swapped` is a valid, writable pointer for the duration of the call.
    unsafe {
        try_cuda!(runtime::cudaThreadExchangeStreamCaptureMode(&raw mut swapped))?;
    }
    Ok(swapped.into())
}
#[cfg(all(test, feature = "testing"))]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        thread,
    };

    use super::*;
    use crate::testing;

    /// Exercises stream creation, host callbacks, `query` and `synchronize`.
    #[test]
    fn it_works() -> Result<()> {
        let _lock = testing::device_lock(0)?;
        let ctx = match Context::create() {
            Ok(ctx) => ctx,
            Err(error) if testing::is_stub_library(&error) => return Ok(()),
            Err(error) => return Err(error),
        };
        let stream1 = ctx.create_stream()?;
        let _stream2 = ctx.create_stream_with_flags(StreamFlags::NON_BLOCKING)?;

        let stream1_called = Arc::new(AtomicBool::new(false));
        // The callback will not finish until `release` is set, which keeps work
        // pending on `stream1` so the `query` below must observe an unfinished
        // stream. Without this, the callback on an otherwise idle stream can
        // complete before `query` runs, making the "not done" check racy.
        let release = Arc::new(AtomicBool::new(false));
        stream1.add_callback({
            let stream1_called = Arc::clone(&stream1_called);
            let release = Arc::clone(&release);
            move |_status| {
                while !release.load(Ordering::SeqCst) {
                    thread::yield_now();
                }
                stream1_called.store(true, Ordering::SeqCst);
            }
        })?;

        let is_done = stream1.query();
        // Release the callback before inspecting the result. Otherwise an error
        // or failed assertion would drop `stream1`, and its `Drop` impl
        // synchronizes and would wait forever on the blocked callback.
        release.store(true, Ordering::SeqCst);
        assert!(!is_done?);

        stream1.synchronize()?;
        assert!(stream1.query()?);
        assert!(stream1_called.load(Ordering::SeqCst));
        Ok(())
    }
}