singe_cuda/stream.rs
1#[allow(unused_imports)]
2use crate::error::Status;
3
4use std::{iter, marker::PhantomData, mem::ManuallyDrop, panic, ptr, sync::Arc};
5
6use num_enum::{IntoPrimitive, TryFromPrimitive};
7use singe_core::impl_enum_conversion;
8use singe_cuda_sys::runtime;
9
10use crate::{
11 context::Context,
12 device::Device,
13 error::{Error, Result},
14 event::Event,
15 graph::{
16 ExecutableGraph, Graph, GraphDependency, GraphEdgeData, GraphInstantiateFlags, GraphNode,
17 },
18 try_ffi,
19};
20
bitflags::bitflags! {
    /// Flags for CUDA stream creation ([`Context::create_stream_with_flags`]).
    ///
    /// This is also the type returned by [`Stream::flags`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct StreamFlags: u32 {
        /// Mirrors the runtime constant `cudaStreamDefault`.
        const DEFAULT = runtime::cudaStreamDefault;
        /// Mirrors the runtime constant `cudaStreamNonBlocking`.
        const NON_BLOCKING = runtime::cudaStreamNonBlocking;
    }
}
29
/// Capture state of a stream, as reported by [`Stream::capture_status`] and
/// [`Stream::capture_info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum StreamCaptureStatus {
    /// The stream is not capturing.
    None = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE as _,
    /// The stream is capturing.
    Active = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_ACTIVE as _,
    /// The stream was capturing, but an error has invalidated the capture
    /// sequence. The sequence must be terminated with [`Stream::end_capture`]
    /// before the stream can be used again.
    Invalidated = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_INVALIDATED as _,
}

// Generates the conversions between the raw CUDA enum and
// `StreamCaptureStatus`; see `singe_core::impl_enum_conversion` for the exact
// impls it emits.
impl_enum_conversion!(u32, runtime::cudaStreamCaptureStatus, StreamCaptureStatus);
40
/// Capture mode passed to [`Stream::begin_capture`] and the
/// `begin_capture_to_graph*` family.
///
/// Unless the mode is [`StreamCaptureMode::Relaxed`], the capture must be
/// ended from the same thread that started it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum StreamCaptureMode {
    /// Mirrors `CU_STREAM_CAPTURE_MODE_GLOBAL`.
    Global = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_GLOBAL as _,
    /// Mirrors `CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`.
    ThreadLocal = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL as _,
    /// Mirrors `CU_STREAM_CAPTURE_MODE_RELAXED`.
    Relaxed = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED as _,
}

// Generates the conversions between the raw CUDA enum and `StreamCaptureMode`;
// see `singe_core::impl_enum_conversion` for the exact impls it emits.
impl_enum_conversion!(u32, runtime::cudaStreamCaptureMode, StreamCaptureMode);
51
/// Flags for [`Stream::update_capture_dependencies_with_dependencies`].
///
/// Selects whether the supplied nodes extend or replace the dependency set of
/// a capturing stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum StreamCaptureDependencyUpdate {
    /// Add new nodes to the dependency set.
    Add = runtime::cudaStreamUpdateCaptureDependenciesFlags::cudaStreamAddCaptureDependencies as _,
    /// Replace the dependency set with the new nodes.
    Set = runtime::cudaStreamUpdateCaptureDependenciesFlags::cudaStreamSetCaptureDependencies as _,
}

// Generates the conversions between the raw CUDA flag enum and
// `StreamCaptureDependencyUpdate`; see `singe_core::impl_enum_conversion` for
// the exact impls it emits.
impl_enum_conversion!(
    u32,
    runtime::cudaStreamUpdateCaptureDependenciesFlags,
    StreamCaptureDependencyUpdate,
);
68
/// Capture state returned by [`Stream::capture_info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamCaptureInfo {
    /// Current capture status of the stream.
    pub status: StreamCaptureStatus,
    /// Unique id of the capture sequence. Only meaningful when `status` is
    /// [`StreamCaptureStatus::Active`].
    pub id: u64,
}
74
// Boxed trait object holding a user callback for `Stream::add_callback`
// (the inner box). This is a fat pointer, so it cannot be handed to CUDA
// directly.
type RustStreamCallbackDyn = Box<dyn FnOnce(Result<()>) + Send + 'static>;

// Thin raw pointer to the inner box, obtained by boxing it a second time and
// calling `Box::into_raw`; this is what gets cast to CUDA's user-data pointer.
type BoxedCallbackPtr = *mut RustStreamCallbackDyn;

// Same double-boxing scheme for `Stream::launch_host_func`: the boxed closure
// (inner box) and the thin raw pointer to it that is passed as user data.
type RustHostFunctionDyn = Box<dyn FnOnce() + Send + 'static>;
type BoxedHostFunctionPtr = *mut RustHostFunctionDyn;
83
/// Owning wrapper around a CUDA stream handle.
///
/// Clones share the same underlying stream through an [`Arc`], so cloning does
/// not create a new CUDA stream.
#[derive(Debug, Clone)]
pub struct Stream {
    /// Shared state; every clone of this `Stream` points at the same value.
    inner: Arc<StreamInner>,
}

/// State shared by all clones of a [`Stream`].
#[derive(Debug)]
struct StreamInner {
    /// Raw CUDA stream handle.
    handle: runtime::cudaStream_t,
    /// Context the stream belongs to; bound before the stream's CUDA calls.
    ctx: Arc<Context>,
    // TODO: Store device ID? Could be useful for multi-GPU.
    // device_id: DeviceId,
}
96
97impl PartialEq for Stream {
98 fn eq(&self, other: &Self) -> bool {
99 self.as_raw() == other.as_raw() && Arc::ptr_eq(&self.inner.ctx, &other.inner.ctx)
100 }
101}
102
103impl Eq for Stream {}
104
/// Scope handle passed to the closure given to [`Stream::sync_scope`].
///
/// `'scope` is the lifetime of the scope itself; `'env` is the lifetime of
/// data borrowed from outside the scope.
#[derive(Debug)]
pub struct StreamScope<'scope, 'env> {
    /// Stream that is synchronized when the scope ends.
    stream: &'scope Stream,
    /// Zero-sized marker that makes the type invariant in `'env`.
    _env: PhantomData<&'env mut &'env ()>,
}
110
/// Scope handle passed to the closure given to [`Stream::capture`].
///
/// The scope is intentionally `!Send`, so it cannot be moved to another thread
/// while capture is active.
#[derive(Debug)]
pub struct StreamCaptureScope<'scope> {
    /// Stream that is currently capturing.
    stream: &'scope Stream,
    /// Raw-pointer marker; its only purpose is to opt the type out of `Send`.
    _not_send: PhantomData<*const ()>,
}
116
/// Operation that may be recorded into a CUDA graph capture scope.
///
/// # Safety
///
/// Implementors must only enqueue CUDA work that is valid during stream
/// capture. Every pointer, handle, and side effect captured into the resulting
/// graph must have its replay safety contract represented by the operation's
/// type and constructor.
pub unsafe trait GraphRecordable {
    /// Value produced by a successful [`GraphRecordable::record`] call.
    type Output;

    /// Records this operation using `scope`, consuming the operation.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the operation cannot be recorded.
    fn record(self, scope: &StreamCaptureScope<'_>) -> Result<Self::Output>;
}
130
131struct ActiveStreamCapture<'stream> {
132 stream: &'stream Stream,
133 finished: bool,
134}
135
136impl ActiveStreamCapture<'_> {
137 fn finish(mut self) -> Result<Graph> {
138 self.finished = true;
139 self.stream.end_capture()
140 }
141
142 fn discard(mut self) {
143 self.finished = true;
144 drop(self.stream.end_capture());
145 }
146}
147
148impl Drop for ActiveStreamCapture<'_> {
149 fn drop(&mut self) {
150 if !self.finished {
151 drop(self.stream.end_capture());
152 }
153 }
154}
155
/// Raw stream handle paired with its context, as produced by
/// [`Stream::to_borrowed`].
///
/// NOTE(review): the name suggests this type does not destroy the handle on
/// drop; confirm against its impls, which are outside this section.
#[derive(Debug, Clone)]
pub struct BorrowedStream {
    /// Raw CUDA stream handle.
    handle: runtime::cudaStream_t,
    /// Context associated with the handle.
    ctx: Arc<Context>,
}
161
/// Identifies which stream an operation should use.
///
/// NOTE(review): the consumers of this enum are outside this section; the
/// variant notes below are inferred from the payload types and should be
/// confirmed against them.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StreamBinding {
    /// Carries only a context; presumably selects that context's default
    /// stream.
    Default(Arc<Context>),
    /// Carries an explicit borrowed stream handle.
    Borrowed(BorrowedStream),
}
168
169impl Stream {
170 /// Wraps an existing CUDA stream handle and takes ownership of it.
171 ///
172 /// Dropping the returned stream may block while synchronizing the stream
173 /// before destruction. Use [`Stream::shutdown`] to surface synchronization
174 /// or destruction errors explicitly.
175 ///
176 /// # Safety
177 ///
178 /// `handle` must be a valid CUDA stream owned by `ctx`, and ownership of
179 /// the handle is transferred to the returned [`Stream`]. The handle must
180 /// not be destroyed elsewhere after calling this function.
181 pub unsafe fn from_raw(handle: runtime::cudaStream_t, ctx: Arc<Context>) -> Result<Self> {
182 if handle.is_null() {
183 return Err(Error::NullHandle);
184 }
185
186 Ok(Self {
187 inner: Arc::new(StreamInner { handle, ctx }),
188 })
189 }
190
191 pub fn to_borrowed(&self) -> BorrowedStream {
192 unsafe { BorrowedStream::from_raw(self.as_raw(), Arc::clone(&self.inner.ctx)) }
193 }
194
195 /// Runs `f` with a stream scope and synchronizes this stream before returning.
196 ///
197 /// Use this for scoped asynchronous operations that borrow host or device
198 /// memory until stream completion. For CUDA graph capture, use
199 /// [`Stream::capture`] or [`Stream::capture_executable`].
200 pub fn sync_scope<'env, F, R>(&self, f: F) -> Result<R>
201 where
202 F: for<'scope> FnOnce(&'scope StreamScope<'scope, 'env>) -> Result<R>,
203 {
204 let scope = StreamScope {
205 stream: self,
206 _env: PhantomData,
207 };
208 let result = f(&scope);
209 let sync_result = self.synchronize();
210
211 match (result, sync_result) {
212 (Ok(value), Ok(())) => Ok(value),
213 (Ok(_), Err(err)) => Err(err),
214 (Err(err), Ok(())) | (Err(err), Err(_)) => Err(err),
215 }
216 }
217
218 /// Blocks until stream has completed all operations.
219 /// If [`ContextFlags::SCHEDULE_BLOCKING_SYNC`](crate::context::ContextFlags::SCHEDULE_BLOCKING_SYNC) was set for this device, the host thread will block until the stream is finished with all of its tasks.
220 ///
221 /// Uses standard `default stream` semantics.
222 ///
223 /// # Errors
224 ///
225 /// Returns an error if stream synchronization fails or if a previous asynchronous launch
226 /// reported an error. CUDA may also return initialization-related errors such as
227 /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
228 /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
229 /// call CUDA functions; see [`Stream::add_callback`].
230 pub fn synchronize(&self) -> Result<()> {
231 self.inner.ctx.bind()?;
232 unsafe { try_ffi!(runtime::cudaStreamSynchronize(self.as_raw())) }
233 }
234
235 /// Synchronizes this stream, destroys it, and returns any CUDA error.
236 ///
237 /// This is the explicit version of the cleanup normally performed by
238 /// [`Drop`]. It may block while waiting for stream work and callbacks to
239 /// complete. If synchronization fails, destruction is still attempted and
240 /// the synchronization error is returned. If synchronization succeeds but
241 /// destruction fails, the destruction error is returned.
242 pub fn shutdown(self) -> Result<()> {
243 let inner = Arc::try_unwrap(self.inner).map_err(|_| Error::InvalidValue)?;
244 let inner = ManuallyDrop::new(inner);
245 Self::destroy_handle(inner.ctx.as_ref(), inner.handle)
246 }
247
248 /// Returns `true` if all operations in stream have completed, or `false` if not.
249 ///
250 /// For the purposes of Unified Memory, a return value of `true` is equivalent to having called [`Stream::synchronize`].
251 ///
252 /// Uses standard `default stream` semantics.
253 ///
254 /// # Errors
255 ///
256 /// Returns an error if querying the stream fails or if a previous asynchronous launch reported
257 /// an error. CUDA may also return initialization-related errors such as
258 /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
259 /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
260 /// call CUDA functions; see [`Stream::add_callback`].
261 pub fn query(&self) -> Result<bool> {
262 let error = unsafe { runtime::cudaStreamQuery(self.as_raw()) };
263 match error {
264 runtime::cudaError_t::CUDA_SUCCESS => Ok(true),
265 runtime::cudaError_t::CUDA_ERROR_NOT_READY => Ok(false),
266 _ => Err(error.into()),
267 }
268 }
269
270 /// Makes all future work submitted to stream wait for all work captured in event.
271 /// See [`sys::cudaEventRecord`](singe_cuda_sys::runtime::cudaEventRecord) for details on what is captured by an event.
272 /// Synchronization is performed efficiently on the device when applicable.
273 /// `event` may be from a different device than `stream`.
274 ///
275 /// Uses standard `default stream` semantics.
276 ///
277 /// # Errors
278 ///
279 /// Returns an error if the stream cannot wait on the event or if a previous asynchronous launch
280 /// reported an error. CUDA may also return initialization-related errors such as
281 /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
282 /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
283 /// call CUDA functions; see [`Stream::add_callback`].
284 pub fn wait_event(&self, event: &Event) -> Result<()> {
285 self.wait_event_with_flags(event, 0)
286 }
287
288 /// Makes all future work submitted to stream wait for all work captured in event.
289 /// See [`sys::cudaEventRecord`](singe_cuda_sys::runtime::cudaEventRecord) for details on what is captured by an event.
290 /// `flags` controls how strictly the wait is enforced.
291 /// Synchronization is performed efficiently on the device when applicable.
292 /// `event` may be from a different device than `stream`.
293 ///
294 /// Uses standard `default stream` semantics.
295 ///
296 /// # Errors
297 ///
298 /// Returns an error if the stream cannot wait on the event or if a previous asynchronous launch
299 /// reported an error. CUDA may also return initialization-related errors such as
300 /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
301 /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
302 /// call CUDA functions; see [`Stream::add_callback`].
303 pub fn wait_event_with_flags(&self, event: &Event, flags: u32) -> Result<()> {
304 self.inner.ctx.bind()?;
305 unsafe {
306 try_ffi!(runtime::cudaStreamWaitEvent(
307 self.as_raw(),
308 event.as_raw(),
309 flags,
310 ))
311 }
312 }
313
314 /// Begin graph capture on stream.
315 /// When a stream is in capture mode, operations pushed into the stream are captured
316 /// into a graph instead of executed. [`Stream::end_capture`] returns the graph.
317 /// Capture may not be initiated on the legacy default stream.
318 /// Capture must be ended on the same stream in which it was initiated, and it may only be initiated if the stream is not already in capture mode.
319 /// The capture mode may be queried via [`Stream::capture_status`].
320 /// A unique id representing the capture sequence may be queried via [`Stream::capture_info`].
321 ///
322 /// If mode is not [`StreamCaptureMode::Relaxed`], [`Stream::end_capture`] must be called on this stream from the same thread.
323 ///
324 /// Kernels captured using this API must not use texture and surface references.
325 /// Reading or writing through any texture or surface reference is undefined behavior.
326 /// This restriction does not apply to texture and surface objects.
327 ///
328 /// # Errors
329 ///
330 /// Returns an error if the context cannot be bound, capture cannot begin on
331 /// this stream, the capture mode is invalid for the current thread state,
332 /// or a previous asynchronous launch reports an error.
333 pub fn begin_capture(&self, mode: StreamCaptureMode) -> Result<()> {
334 self.inner.ctx.bind()?;
335 unsafe {
336 try_ffi!(runtime::cudaStreamBeginCapture(self.as_raw(), mode.into()))?;
337 }
338 Ok(())
339 }
340
341 /// Begins stream capture into an existing graph.
342 ///
343 /// # Safety
344 ///
345 /// This low-level API captures into `graph`'s existing CUDA handle. Calling
346 /// [`Stream::end_capture`] after this may return that same raw handle; the
347 /// caller must not wrap it as a second owned [`Graph`]. Prefer
348 /// [`Stream::capture`] unless manually managing capture into an existing
349 /// graph is required.
350 pub unsafe fn begin_capture_to_graph(
351 &self,
352 graph: &Graph,
353 dependencies: &[GraphNode],
354 mode: StreamCaptureMode,
355 ) -> Result<()> {
356 unsafe { self.begin_capture_to_graph_with_data(graph, dependencies, &[], mode) }
357 }
358
359 /// Begins stream capture into an existing graph with annotated dependency edges.
360 ///
361 /// # Safety
362 ///
363 /// This has the same ownership restrictions as
364 /// [`Stream::begin_capture_to_graph`].
365 pub unsafe fn begin_capture_to_graph_with_data(
366 &self,
367 graph: &Graph,
368 dependencies: &[GraphNode],
369 edge_data: &[GraphEdgeData],
370 mode: StreamCaptureMode,
371 ) -> Result<()> {
372 if !edge_data.is_empty() && edge_data.len() != dependencies.len() {
373 return Err(Error::GraphDependencyMismatch);
374 }
375
376 let dependencies: Vec<_> = dependencies
377 .iter()
378 .zip(
379 edge_data
380 .iter()
381 .copied()
382 .chain(iter::repeat(GraphEdgeData::default())),
383 )
384 .map(|(node, data)| GraphDependency {
385 node: node.clone(),
386 data,
387 })
388 .collect();
389
390 unsafe { self.begin_capture_to_graph_with_dependencies(graph, &dependencies, mode) }
391 }
392
393 /// Begin graph capture on stream.
394 /// When a stream is in capture mode, operations pushed into the stream are captured
395 /// into a graph instead of executed. [`Stream::end_capture`] returns the graph.
396 ///
397 /// Capture may not be initiated on the legacy default stream.
398 /// Capture must be ended on the same stream in which it was initiated, and it may only be initiated if the stream is not already in capture mode.
399 /// The capture mode may be queried via [`Stream::capture_status`].
400 /// A unique id representing the capture sequence may be queried via [`Stream::capture_info`].
401 ///
402 /// If mode is not [`StreamCaptureMode::Relaxed`], [`Stream::end_capture`] must be called on this stream from the same thread.
403 ///
404 /// Kernels captured using this API must not use texture and surface references.
405 /// Reading or writing through any texture or surface reference is undefined behavior.
406 /// This restriction does not apply to texture and surface objects.
407 ///
408 /// # Errors
409 ///
410 /// Returns an error if the context cannot be bound, capture cannot begin on
411 /// this stream, the graph dependencies are invalid, the capture mode is
412 /// invalid for the current thread state, or a previous asynchronous launch
413 /// reports an error.
414 /// # Safety
415 ///
416 /// This captures into `graph`'s existing CUDA handle. Calling
417 /// [`Stream::end_capture`] after this may return that same raw handle; the
418 /// caller must not wrap it as a second owned [`Graph`].
419 pub unsafe fn begin_capture_to_graph_with_dependencies(
420 &self,
421 graph: &Graph,
422 dependencies: &[GraphDependency],
423 mode: StreamCaptureMode,
424 ) -> Result<()> {
425 self.check_graph_context(graph)?;
426 self.check_capture_dependency_contexts(dependencies)?;
427 self.check_capture_graph_dependencies(graph, dependencies)?;
428 self.inner.ctx.bind()?;
429
430 let dependencies_raw: Vec<_> = dependencies
431 .iter()
432 .map(|dependency| dependency.node.as_raw())
433 .collect();
434 let edge_data_raw: Vec<_> = dependencies
435 .iter()
436 .map(|dependency| dependency.data.into())
437 .collect();
438 unsafe {
439 try_ffi!(runtime::cudaStreamBeginCaptureToGraph(
440 self.as_raw(),
441 graph.as_raw(),
442 dependencies_raw.as_ptr(),
443 if edge_data_raw.is_empty() {
444 ptr::null()
445 } else {
446 edge_data_raw.as_ptr()
447 },
448 dependencies_raw.len() as _,
449 mode.into(),
450 ))?;
451 }
452 Ok(())
453 }
454
455 /// Ends capture on this stream, returning the captured graph.
456 /// Capture must have been initiated on stream via a call to [`Stream::begin_capture`].
457 /// If capture was invalidated due to a violation of the rules of stream capture, an error is returned.
458 ///
459 /// If the mode argument to [`Stream::begin_capture`] was not [`StreamCaptureMode::Relaxed`], this call must be from the same thread as [`Stream::begin_capture`].
460 ///
461 /// # Errors
462 ///
463 /// Returns an error if the context cannot be bound, capture is not active
464 /// on this stream, the capture has been invalidated, or a previous
465 /// asynchronous launch reports an error.
466 pub fn end_capture(&self) -> Result<Graph> {
467 self.inner.ctx.bind()?;
468 let mut handle = ptr::null_mut();
469 unsafe {
470 try_ffi!(runtime::cudaStreamEndCapture(
471 self.as_raw(),
472 &raw mut handle
473 ))?;
474 Graph::from_raw_in_context(handle, Arc::clone(&self.inner.ctx))
475 }
476 }
477
478 /// Captures stream work recorded by `f` into a CUDA graph.
479 ///
480 /// This is the scoped form of [`Stream::begin_capture`] and
481 /// [`Stream::end_capture`]. The capture is always ended before this method
482 /// returns or resumes a panic. If `f` returns an error, this method attempts
483 /// to end capture to restore stream usability, destroys any graph returned
484 /// by CUDA, and returns the closure error.
485 ///
486 /// The scope is intentionally `!Send`, so it cannot be moved to another
487 /// thread while capture is active. Future graph-safe recording helpers can
488 /// be added to [`StreamCaptureScope`] without changing this API shape.
489 ///
490 /// # Errors
491 ///
492 /// Returns an error if capture cannot begin, if `f` returns an error, or if
493 /// capture cannot be ended successfully.
494 pub fn capture<F>(&self, mode: StreamCaptureMode, f: F) -> Result<Graph>
495 where
496 F: FnOnce(&StreamCaptureScope<'_>) -> Result<()>,
497 {
498 self.begin_capture(mode)?;
499 let capture = ActiveStreamCapture {
500 stream: self,
501 finished: false,
502 };
503
504 let scope = StreamCaptureScope {
505 stream: self,
506 _not_send: PhantomData,
507 };
508
509 let capture_result = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&scope)));
510 match capture_result {
511 Ok(Ok(())) => capture.finish(),
512 Ok(Err(err)) => {
513 capture.discard();
514 Err(err)
515 }
516 Err(payload) => {
517 drop(capture);
518 panic::resume_unwind(payload);
519 }
520 }
521 }
522
523 pub fn capture_executable<F>(&self, mode: StreamCaptureMode, f: F) -> Result<ExecutableGraph>
524 where
525 F: FnOnce(&StreamCaptureScope<'_>) -> Result<()>,
526 {
527 self.capture_executable_with_flags(mode, GraphInstantiateFlags::empty(), f)
528 }
529
530 pub fn capture_executable_with_flags<F>(
531 &self,
532 mode: StreamCaptureMode,
533 flags: GraphInstantiateFlags,
534 f: F,
535 ) -> Result<ExecutableGraph>
536 where
537 F: FnOnce(&StreamCaptureScope<'_>) -> Result<()>,
538 {
539 let graph = self.capture(mode, f)?;
540 graph.instantiate_with_flags(flags)
541 }
542
543 /// Returns the capture status of this stream.
544 /// After a successful call, the status is one of the following:
545 ///
546 /// * [`StreamCaptureStatus::None`]: The stream is not capturing.
547 /// * [`StreamCaptureStatus::Active`]: The stream is capturing.
548 /// * [`StreamCaptureStatus::Invalidated`]: The stream was capturing but an error has invalidated the capture sequence.
549 /// The capture sequence must be terminated with
550 /// [`Stream::end_capture`] on the stream where it was initiated to continue using the stream.
551 ///
552 /// If this is called on the legacy default stream while a blocking stream on the same device is capturing, it returns [`crate::error::Status::StreamCaptureImplicit`].
553 /// The blocking stream capture is not invalidated.
554 ///
555 /// When a blocking stream is capturing, the legacy stream is in an unusable state until the blocking stream capture is terminated.
556 /// The legacy stream is not supported for stream capture, but attempted use would have an implicit dependency on the capturing stream(s).
557 ///
558 /// # Errors
559 ///
560 /// Returns an error if the context cannot be bound, CUDA cannot query the
561 /// capture status, or a previous asynchronous launch reports an error.
562 pub fn capture_status(&self) -> Result<StreamCaptureStatus> {
563 self.inner.ctx.bind()?;
564 let mut status = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE;
565 unsafe {
566 try_ffi!(runtime::cudaStreamIsCapturing(
567 self.as_raw(),
568 &raw mut status
569 ))?;
570 }
571 Ok(status.into())
572 }
573
574 /// Query stream state related to stream capture.
575 ///
576 /// If called on the legacy default stream while a stream not created with [`StreamFlags::NON_BLOCKING`] is capturing, returns [`crate::error::Status::StreamCaptureImplicit`].
577 ///
578 /// Valid data (other than capture status) is returned only if both of the following are true:
579 ///
580 /// * the call succeeds
581 /// * the returned capture status is [`StreamCaptureStatus::Active`]
582 ///
583 /// If there is non-zero edge data for one or more current stream dependencies and the query cannot return that data, the call returns [`crate::error::Status::LossyQuery`].
584 ///
585 /// # Errors
586 ///
587 /// Returns an error if the context cannot be bound, CUDA cannot query the
588 /// capture info, the query would lose non-zero edge data, or a previous
589 /// asynchronous launch reports an error.
590 pub fn capture_info(&self) -> Result<StreamCaptureInfo> {
591 self.inner.ctx.bind()?;
592 let mut status = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE;
593 let mut id = 0;
594 unsafe {
595 try_ffi!(runtime::cudaStreamGetCaptureInfo(
596 self.as_raw(),
597 &raw mut status,
598 &raw mut id,
599 ptr::null_mut(),
600 ptr::null_mut(),
601 ptr::null_mut(),
602 ptr::null_mut(),
603 ))?;
604 }
605 Ok(StreamCaptureInfo {
606 status: status.into(),
607 id,
608 })
609 }
610
611 pub fn update_capture_dependencies(&self, dependencies: &[GraphNode]) -> Result<()> {
612 self.update_capture_dependencies_with_mode(
613 dependencies,
614 &[],
615 StreamCaptureDependencyUpdate::Add,
616 )
617 }
618
619 pub fn update_capture_dependencies_with_data(
620 &self,
621 dependencies: &[GraphNode],
622 edge_data: &[GraphEdgeData],
623 ) -> Result<()> {
624 self.update_capture_dependencies_with_mode(
625 dependencies,
626 edge_data,
627 StreamCaptureDependencyUpdate::Add,
628 )
629 }
630
631 pub fn update_capture_dependencies_with_mode(
632 &self,
633 dependencies: &[GraphNode],
634 edge_data: &[GraphEdgeData],
635 mode: StreamCaptureDependencyUpdate,
636 ) -> Result<()> {
637 if !edge_data.is_empty() && edge_data.len() != dependencies.len() {
638 return Err(Error::GraphDependencyMismatch);
639 }
640
641 let dependencies: Vec<_> = dependencies
642 .iter()
643 .zip(
644 edge_data
645 .iter()
646 .copied()
647 .chain(iter::repeat(GraphEdgeData::default())),
648 )
649 .map(|(node, data)| GraphDependency {
650 node: node.clone(),
651 data,
652 })
653 .collect();
654
655 self.update_capture_dependencies_with_dependencies(&dependencies, mode)
656 }
657
658 /// Modifies the dependency set of a capturing stream.
659 /// The dependency set is the set of nodes that the next captured node in the stream will depend on.
660 ///
661 /// Valid flags are [`StreamCaptureDependencyUpdate::Add`] and [`StreamCaptureDependencyUpdate::Set`].
662 /// These control whether the supplied set is added to the existing set or replaces it.
663 /// A flags value of 0 defaults to [`StreamCaptureDependencyUpdate::Add`].
664 ///
665 /// Nodes that are removed from the dependency set by this call do not result in [`crate::error::Status::StreamCaptureUnjoined`] if they are unreachable from the stream at [`Stream::end_capture`].
666 ///
667 /// Returns [`crate::error::Status::IllegalState`] if the stream is not capturing.
668 ///
669 /// # Errors
670 ///
671 /// Returns an error if the context cannot be bound, the stream is not
672 /// capturing, the supplied dependencies are invalid, or a previous
673 /// asynchronous launch reports an error.
674 pub fn update_capture_dependencies_with_dependencies(
675 &self,
676 dependencies: &[GraphDependency],
677 mode: StreamCaptureDependencyUpdate,
678 ) -> Result<()> {
679 self.check_capture_dependency_contexts(dependencies)?;
680 self.check_active_capture_graph_dependencies(dependencies)?;
681 self.inner.ctx.bind()?;
682
683 let mut dependencies_raw: Vec<_> = dependencies
684 .iter()
685 .map(|dependency| dependency.node.as_raw())
686 .collect();
687 let edge_data_raw: Vec<_> = dependencies
688 .iter()
689 .map(|dependency| dependency.data.into())
690 .collect();
691 unsafe {
692 try_ffi!(runtime::cudaStreamUpdateCaptureDependencies(
693 self.as_raw(),
694 dependencies_raw.as_mut_ptr(),
695 if edge_data_raw.is_empty() {
696 ptr::null()
697 } else {
698 edge_data_raw.as_ptr()
699 },
700 dependencies_raw.len() as _,
701 mode.into(),
702 ))?;
703 }
704 Ok(())
705 }
706
707 /// This callback API is slated for eventual deprecation and removal.
708 /// If you do not require the callback to execute after a device error, consider using [`sys::cudaLaunchHostFunc`](singe_cuda_sys::runtime::cudaLaunchHostFunc).
709 /// Additionally, this callback mechanism is not supported with [`Stream::begin_capture`] and [`Stream::end_capture`], unlike [`sys::cudaLaunchHostFunc`](singe_cuda_sys::runtime::cudaLaunchHostFunc).
710 ///
711 /// Adds a callback to be called on the host after all currently enqueued items in the stream have completed.
712 /// For each [`Stream::add_callback`] call, a callback is executed exactly once.
713 /// The callback blocks later work in the stream until it is finished.
714 ///
715 /// The callback may be passed a successful status or an error code.
716 /// In the event of a device error, all subsequently executed callbacks receive an appropriate [`Status`].
717 ///
718 /// Callbacks must not call CUDA functions.
719 /// Attempting to do so may result in [`crate::error::Status::NotPermitted`].
720 /// Callbacks must not perform any synchronization that may depend on outstanding device work or other callbacks that are not mandated to run earlier.
721 /// Callbacks without a mandated order (in independent streams) execute in undefined order and may be serialized.
722 ///
723 /// For the purposes of Unified Memory, callback execution makes a number of guarantees:
724 ///
725 /// * The callback stream is considered idle for the duration of the callback.
726 /// Thus, for example, a callback may always use memory
727 /// attached to the callback stream.
728 /// * The start of execution of a callback has the same effect as synchronizing an event recorded in the same stream immediately
729 /// before the callback.
730 /// It thus synchronizes streams which have been "joined" before the callback.
731 /// * Adding device work to any stream does not have the effect of making the stream active until all preceding callbacks have executed.
732 /// Thus, for example, a callback might use global attached memory even if work has been added to another stream, if it has been
733 /// properly ordered with an event.
734 /// * Completion of a callback does not cause a stream to become active except as described above.
735 /// The callback stream will remain
736 /// idle if no device work follows the callback, and will remain idle across consecutive callbacks without device work in between.
737 /// Thus, for example, stream synchronization can be done by signaling from a callback at the end of the stream.
738 ///
739 /// # Errors
740 ///
741 /// Returns an error if the context cannot be bound, CUDA rejects the
742 /// callback registration, a previous asynchronous launch reports an error,
743 /// or CUDA reports runtime initialization diagnostics such as
744 /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`],
745 /// or [`crate::error::Status::NoDevice`].
746 pub fn add_callback<F>(&self, callback: F) -> Result<()>
747 where
748 F: FnOnce(Result<()>) + Send + 'static,
749 {
750 self.inner.ctx.bind()?;
751
752 let boxed_dyn_callback: RustStreamCallbackDyn = Box::new(callback);
753 let boxed_wrapper: Box<RustStreamCallbackDyn> = Box::new(boxed_dyn_callback);
754 let user_data_ptr: BoxedCallbackPtr = Box::into_raw(boxed_wrapper);
755 let final_user_data = user_data_ptr.cast();
756
757 let flags = 0u32;
758
759 unsafe {
760 let status = runtime::cudaStreamAddCallback(
761 self.as_raw(),
762 Some(stream_callback_trampoline),
763 final_user_data, // Pass the thin pointer
764 flags,
765 );
766
767 // If adding the callback fails, manually reconstruct and drop the *outer* Box to prevent leaking both boxes.
768 if status != runtime::cudaError_t::CUDA_SUCCESS {
769 // Reconstruct the outer box (Box<Box<dyn Trait>>)
770 let _leaked_box = Box::from_raw(user_data_ptr);
771 // Drop the reconstructed outer box.
772 try_ffi!(status)?;
773 }
774 }
775
776 Ok(())
777 }
778
779 /// Enqueues a host function to run after all currently enqueued work in this stream completes.
780 ///
781 /// Unlike [`Stream::add_callback`], CUDA does not call this function if the CUDA context is already in an error state.
782 /// This API is supported during stream capture by CUDA, but the host function still must not call CUDA
783 /// APIs or perform synchronization that depends on outstanding device work.
784 ///
785 /// # Errors
786 ///
787 /// Returns an error if the context cannot be bound, CUDA rejects the host
788 /// function registration, a previous asynchronous launch reports an error,
789 /// or CUDA reports runtime initialization diagnostics.
790 pub fn launch_host_func<F>(&self, function: F) -> Result<()>
791 where
792 F: FnOnce() + Send + 'static,
793 {
794 self.inner.ctx.bind()?;
795
796 let boxed_dyn_function: RustHostFunctionDyn = Box::new(function);
797 let boxed_wrapper: Box<RustHostFunctionDyn> = Box::new(boxed_dyn_function);
798 let user_data_ptr: BoxedHostFunctionPtr = Box::into_raw(boxed_wrapper);
799 let final_user_data = user_data_ptr.cast();
800
801 unsafe {
802 let status = runtime::cudaLaunchHostFunc(
803 self.as_raw(),
804 Some(stream_host_function_trampoline),
805 final_user_data,
806 );
807
808 if status != runtime::cudaError_t::CUDA_SUCCESS {
809 let _leaked_box = Box::from_raw(user_data_ptr);
810 try_ffi!(status)?;
811 }
812 }
813
814 Ok(())
815 }
816
817 /// Query the flags of a stream.
818 /// Returns the stream flags.
819 /// See [`Context::create_stream_with_flags`] for a list of valid flags.
820 ///
821 /// Uses standard `default stream` semantics.
822 ///
823 /// # Errors
824 ///
825 /// Returns an error if CUDA cannot query the flags or if a previous asynchronous launch
826 /// reported an error. CUDA may also return initialization-related errors such as
827 /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
828 /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
829 /// call CUDA functions; see [`Stream::add_callback`].
830 pub fn flags(&self) -> Result<StreamFlags> {
831 self.inner.ctx.bind()?;
832 let mut flags_raw = 0u32;
833 unsafe {
834 try_ffi!(runtime::cudaStreamGetFlags(
835 self.as_raw(),
836 &raw mut flags_raw
837 ))?;
838 }
839 Ok(StreamFlags::from_bits_retain(flags_raw))
840 }
841
842 /// Query the priority of a stream.
843 /// Returns the stream priority.
844 /// If the stream was created with a priority outside the meaningful numerical range returned by [`Device::stream_priority_range`], this returns the clamped priority.
845 /// See [`Context::create_stream_with_priority`] for details about priority clamping.
846 ///
847 /// # Errors
848 ///
849 /// Returns an error if the context cannot be bound, CUDA cannot query the
850 /// priority, a previous asynchronous launch reports an error, or CUDA
851 /// reports runtime initialization diagnostics.
852 pub fn priority(&self) -> Result<i32> {
853 self.inner.ctx.bind()?;
854 let mut priority = 0i32;
855 unsafe {
856 try_ffi!(runtime::cudaStreamGetPriority(
857 self.as_raw(),
858 &raw mut priority
859 ))?;
860 }
861 Ok(priority)
862 }
863
864 /// Returns a stream identifier that remains unique for the life of the program.
865 ///
866 /// The stream handle may refer to any of the following:
867 ///
868 /// * a stream created via any of the CUDA runtime APIs such as [`sys::cudaStreamCreate`](singe_cuda_sys::runtime::cudaStreamCreate), [`Context::create_stream_with_flags`] and [`Context::create_stream_with_priority`], or their driver API equivalents such as [`sys::cuStreamCreate`](singe_cuda_sys::driver::cuStreamCreate) or [`sys::cuStreamCreateWithPriority`](singe_cuda_sys::driver::cuStreamCreateWithPriority).
869 /// Passing an invalid handle results in undefined behavior.
870 /// * the special legacy default stream and per-thread default stream.
871 /// The driver API equivalents of these are also accepted.
872 ///
873 /// # Errors
874 ///
875 /// Returns an error if the context cannot be bound, CUDA cannot query the
876 /// stream identifier, a previous asynchronous launch reports an error, or
877 /// CUDA reports runtime initialization diagnostics.
878 pub fn id(&self) -> Result<u64> {
879 self.inner.ctx.bind()?;
880 let mut id = 0u64;
881 unsafe {
882 try_ffi!(runtime::cudaStreamGetId(self.as_raw(), &raw mut id))?;
883 }
884 Ok(id)
885 }
886
887 /// Returns the device of the stream.
888 ///
889 /// # Errors
890 ///
891 /// Returns an error if the context cannot be bound, CUDA cannot query the
892 /// stream device, a previous asynchronous launch reports an error, or CUDA
893 /// reports runtime initialization diagnostics.
894 pub fn device(&self) -> Result<Device> {
895 self.inner.ctx.bind()?;
896 let mut device = 0i32;
897 unsafe {
898 try_ffi!(runtime::cudaStreamGetDevice(self.as_raw(), &raw mut device))?;
899 }
900 Ok(Device::new(device))
901 }
902
903 pub fn context(&self) -> &Context {
904 &self.inner.ctx
905 }
906
907 pub fn as_raw(&self) -> runtime::cudaStream_t {
908 self.inner.handle
909 }
910
911 /// Consumes the stream and returns the raw CUDA stream handle without
912 /// destroying it.
913 ///
914 /// The caller becomes responsible for eventually destroying the returned
915 /// handle with CUDA.
916 pub fn into_raw(self) -> runtime::cudaStream_t {
917 let inner = Arc::try_unwrap(self.inner)
918 .expect("cannot transfer raw stream handle while cloned stream handles exist");
919 let inner = ManuallyDrop::new(inner);
920 inner.handle
921 }
922
923 fn destroy_handle(ctx: &Context, handle: runtime::cudaStream_t) -> Result<()> {
924 let bind_result = ctx.bind();
925 let sync_result =
926 bind_result.and_then(|()| unsafe { try_ffi!(runtime::cudaStreamSynchronize(handle)) });
927 let destroy_result = unsafe { try_ffi!(runtime::cudaStreamDestroy(handle)) };
928
929 match (sync_result, destroy_result) {
930 (Ok(()), Ok(())) => Ok(()),
931 (Err(err), _) | (Ok(()), Err(err)) => Err(err),
932 }
933 }
934
935 // pub fn is_null(&self) -> bool {
936 // self.inner.handle.is_null()
937 // }
938
939 // --- Methods related to Memory Management ---
940 // Add methods like malloc_async, free_async, attach_mem_async if needed
941 // These would likely take wrappers around device memory pointers.
942
943 fn check_graph_context(&self, graph: &Graph) -> Result<()> {
944 if matches!(graph.context(), Some(ctx) if ctx != self.inner.ctx.as_ref()) {
945 return Err(Error::GraphContextMismatch);
946 }
947 Ok(())
948 }
949
950 fn check_capture_dependency_contexts(&self, dependencies: &[GraphDependency]) -> Result<()> {
951 for dependency in dependencies {
952 if matches!(dependency.node.context(), Some(ctx) if ctx != self.inner.ctx.as_ref()) {
953 return Err(Error::GraphContextMismatch);
954 }
955 }
956 Ok(())
957 }
958
959 fn check_capture_graph_dependencies(
960 &self,
961 graph: &Graph,
962 dependencies: &[GraphDependency],
963 ) -> Result<()> {
964 for dependency in dependencies {
965 graph.check_node(&dependency.node)?;
966 }
967 Ok(())
968 }
969
970 fn check_active_capture_graph_dependencies(
971 &self,
972 dependencies: &[GraphDependency],
973 ) -> Result<()> {
974 if dependencies.is_empty() {
975 return Ok(());
976 }
977
978 self.inner.ctx.bind()?;
979 let mut status = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE;
980 let mut graph = ptr::null_mut();
981 unsafe {
982 try_ffi!(runtime::cudaStreamGetCaptureInfo(
983 self.as_raw(),
984 &raw mut status,
985 ptr::null_mut(),
986 &raw mut graph,
987 ptr::null_mut(),
988 ptr::null_mut(),
989 ptr::null_mut(),
990 ))?;
991 }
992 if StreamCaptureStatus::from(status) != StreamCaptureStatus::Active {
993 return Ok(());
994 }
995 if graph.is_null() {
996 return Err(Error::NullHandle);
997 }
998
999 for dependency in dependencies {
1000 if !matches!(dependency.node.graph_raw(), Some(node_graph) if node_graph == graph) {
1001 return Err(Error::GraphNodeMismatch);
1002 }
1003 }
1004 Ok(())
1005 }
1006
1007 pub(crate) fn ensure_not_capturing_for_future(&self) -> Result<()> {
1008 match self.capture_status()? {
1009 StreamCaptureStatus::None => Ok(()),
1010 StreamCaptureStatus::Active => Err(Status::StreamCaptureUnsupported.into()),
1011 StreamCaptureStatus::Invalidated => Err(Status::StreamCaptureInvalidated.into()),
1012 }
1013 }
1014}
1015
1016impl<'scope, 'env> StreamScope<'scope, 'env> {
1017 pub const fn stream(&self) -> &'scope Stream {
1018 self.stream
1019 }
1020
1021 pub fn synchronize(&self) -> Result<()> {
1022 self.stream.synchronize()
1023 }
1024}
1025
1026impl<'scope> StreamCaptureScope<'scope> {
1027 pub const fn stream(&self) -> &'scope Stream {
1028 self.stream
1029 }
1030
1031 /// Records a graph-safe operation into this active stream capture.
1032 ///
1033 /// Only operations implementing [`GraphRecordable`] can be submitted
1034 /// through this method. Allocation/free and other capture-unsafe CUDA calls
1035 /// should stay outside this trait unless their replay ownership and address
1036 /// stability are explicitly modeled.
1037 pub fn record<O>(&self, operation: O) -> Result<O::Output>
1038 where
1039 O: GraphRecordable,
1040 {
1041 operation.record(self)
1042 }
1043}
1044
1045impl BorrowedStream {
1046 /// Wraps an existing CUDA stream handle without taking ownership.
1047 ///
1048 /// # Safety
1049 ///
1050 /// `handle` must be a valid CUDA stream associated with `ctx`, and it must
1051 /// remain valid for every operation using the returned borrowed stream.
1052 pub const unsafe fn from_raw(handle: runtime::cudaStream_t, ctx: Arc<Context>) -> Self {
1053 Self { handle, ctx }
1054 }
1055
1056 pub fn synchronize(&self) -> Result<()> {
1057 self.ctx.bind()?;
1058 unsafe { try_ffi!(runtime::cudaStreamSynchronize(self.as_raw())) }
1059 }
1060
1061 pub fn context(&self) -> &Context {
1062 &self.ctx
1063 }
1064
1065 pub const fn as_raw(&self) -> runtime::cudaStream_t {
1066 self.handle
1067 }
1068}
1069
1070impl StreamBinding {
1071 pub fn context(&self) -> &Context {
1072 match self {
1073 Self::Default(ctx) => ctx.as_ref(),
1074 Self::Borrowed(stream) => stream.context(),
1075 }
1076 }
1077
1078 pub fn is_default(&self) -> bool {
1079 matches!(self, Self::Default(..))
1080 }
1081
1082 pub fn as_raw(&self) -> runtime::cudaStream_t {
1083 match self {
1084 Self::Default(_) => ptr::null_mut(),
1085 Self::Borrowed(stream) => stream.as_raw(),
1086 }
1087 }
1088}
1089
// CUDA streams are ordering handles.
// Operations take &self and are serialized by CUDA stream semantics rather than mutable Rust state.
//
// SAFETY (as asserted by the original author; the struct definitions are not visible here):
// the raw stream handle presumably prevents `Send`/`Sync` from being derived automatically,
// so these impls opt both types in by hand. They are sound only if every method reachable
// through a shared reference may be called on the same handle from several threads at
// once. TODO confirm this for every operation that takes `&self`.
unsafe impl Send for StreamInner {}
unsafe impl Sync for StreamInner {}
unsafe impl Send for Stream {}
unsafe impl Sync for Stream {}
1096
1097impl Drop for StreamInner {
1098 fn drop(&mut self) {
1099 if let Err(err) = Stream::destroy_handle(self.ctx.as_ref(), self.handle) {
1100 #[cfg(debug_assertions)]
1101 eprintln!("failed to synchronize or destroy CUDA stream: {err}");
1102 }
1103 }
1104}
1105
1106// Trampoline function to bridge C FFI callback to Rust closure
1107extern "C" fn stream_callback_trampoline(
1108 _stream: runtime::cudaStream_t,
1109 status: runtime::cudaError_t,
1110 user_data: *mut std::ffi::c_void,
1111) {
1112 if user_data.is_null() {
1113 return;
1114 }
1115
1116 let user_data_ptr = user_data as BoxedCallbackPtr;
1117
1118 let boxed_callback: Box<RustStreamCallbackDyn> = unsafe { Box::from_raw(user_data_ptr) };
1119 let callback: RustStreamCallbackDyn = *boxed_callback;
1120
1121 let result = if status == runtime::cudaError_t::CUDA_SUCCESS {
1122 Ok(())
1123 } else {
1124 Err(status.into())
1125 };
1126
1127 callback(result);
1128}
1129
1130extern "C" fn stream_host_function_trampoline(user_data: *mut std::ffi::c_void) {
1131 if user_data.is_null() {
1132 return;
1133 }
1134
1135 let user_data_ptr = user_data as BoxedHostFunctionPtr;
1136 let boxed_function: Box<RustHostFunctionDyn> = unsafe { Box::from_raw(user_data_ptr) };
1137 let function: RustHostFunctionDyn = *boxed_function;
1138 function();
1139}
1140
1141impl Context {
1142 pub fn create_stream(self: &Arc<Self>) -> Result<Stream> {
1143 self.create_stream_with_flags(StreamFlags::DEFAULT)
1144 }
1145
1146 /// Creates a new asynchronous stream on the context that is current to the calling host thread.
1147 /// If no context is current to the calling host thread, then the primary context for a device is selected, made current to the calling thread, and initialized before creating a stream on it.
1148 /// The flags argument determines the behaviors of the stream.
1149 /// Valid values are provided by [`StreamFlags`]:
1150 ///
1151 /// * [`StreamFlags::DEFAULT`]: default stream creation behavior.
1152 /// * [`StreamFlags::NON_BLOCKING`]: allows the created stream to run concurrently with the legacy default stream without implicit synchronization.
1153 ///
1154 /// # Errors
1155 ///
1156 /// Returns an error if CUDA cannot create the stream, if it does not return a valid stream
1157 /// handle, or if a previous asynchronous launch reported an error. CUDA may also return
1158 /// initialization-related errors such as [`crate::error::Status::NotInitialized`],
1159 /// [`crate::error::Status::CallRequiresNewerDriver`], or [`crate::error::Status::NoDevice`] if this call initializes
1160 /// internal runtime state. Callbacks must not call CUDA functions; see
1161 /// [`Stream::add_callback`].
1162 pub fn create_stream_with_flags(self: &Arc<Self>, flags: StreamFlags) -> Result<Stream> {
1163 self.bind()?;
1164 let mut handle = ptr::null_mut();
1165 unsafe {
1166 try_ffi!(runtime::cudaStreamCreateWithFlags(
1167 &raw mut handle,
1168 flags.bits(),
1169 ))?;
1170 }
1171 if handle.is_null() {
1172 return Err(Error::NullHandle);
1173 }
1174 // let mut device_id = 0;
1175 // unsafe { check(cudaStreamGetDevice(stream, &mut device_id))?; }
1176 unsafe { Stream::from_raw(handle, Arc::clone(self)) }
1177 }
1178
1179 /// Creates a stream with the specified priority.
1180 /// The stream is created on this context.
1181 /// This affects the scheduling priority of work in the stream.
1182 /// Priorities provide a hint to preferentially run work with higher priority when possible, but do not preempt already-running work or provide any other functional guarantee on execution order.
1183 ///
1184 /// `priority` follows a convention where lower numbers represent higher priorities.
1185 /// `0` represents default priority.
1186 /// The range of meaningful numerical priorities can be queried using [`Device::stream_priority_range`].
1187 /// If the specified priority is outside the numerical range returned by [`Device::stream_priority_range`], it will automatically be clamped to the lowest or the highest number in the range.
1188 ///
1189 /// * Stream priorities are supported only on GPUs with compute capability 3.5 or higher.
1190 /// * In the current implementation, only compute kernels launched in priority streams are affected by the stream's priority.
1191 /// Stream
1192 /// priorities have no effect on host-to-device and device-to-host memory operations.
1193 ///
1194 /// # Errors
1195 ///
1196 /// Returns an error if the context cannot be bound, CUDA cannot create the
1197 /// stream, CUDA returns a null stream handle, a previous asynchronous launch
1198 /// reports an error, or CUDA reports runtime initialization diagnostics.
1199 pub fn create_stream_with_priority(
1200 self: &Arc<Self>,
1201 flags: StreamFlags,
1202 priority: i32,
1203 ) -> Result<Stream> {
1204 self.bind()?;
1205 let mut handle = ptr::null_mut();
1206 unsafe {
1207 try_ffi!(runtime::cudaStreamCreateWithPriority(
1208 &raw mut handle,
1209 flags.bits(),
1210 priority,
1211 ))?;
1212 }
1213 if handle.is_null() {
1214 return Err(Error::NullHandle);
1215 }
1216 unsafe { Stream::from_raw(handle, Arc::clone(self)) }
1217 }
1218}
1219
1220/// Sets the calling thread's stream capture interaction mode, returning the previous mode for the thread.
1221/// To facilitate deterministic behavior across function or module boundaries, use this in a push-pop fashion.
1222///
1223/// During stream capture (see [`Stream::begin_capture`]), some actions, such as a call to [`DeviceMemory::alloc`](crate::memory::DeviceMemory::alloc), may be unsafe.
1224/// In the case of [`DeviceMemory::alloc`](crate::memory::DeviceMemory::alloc), the operation is not enqueued asynchronously to a stream, and is not observed by stream capture.
1225/// If the sequence of operations captured via [`Stream::begin_capture`] depended on the allocation being replayed whenever the graph is launched, the captured graph would be invalid.
1226///
1227/// Therefore, stream capture places restrictions on CUDA calls that can be made within or concurrently to a [`Stream::begin_capture`]-[`Stream::end_capture`] sequence.
1228/// Control this behavior with this function and flags to [`Stream::begin_capture`].
1229///
1230/// A thread's mode is one of the following:
1231///
1232/// * [`StreamCaptureMode::Global`]: default mode.
1233/// If the local thread has an ongoing capture sequence that was not initiated with [`StreamCaptureMode::Relaxed`] at [`Stream::begin_capture`], or if any other thread has a concurrent capture sequence initiated with [`StreamCaptureMode::Global`], this thread is prohibited from potentially unsafe CUDA calls.
1234/// * [`StreamCaptureMode::ThreadLocal`]: If the local thread has an ongoing capture sequence not initiated with [`StreamCaptureMode::Relaxed`], it is prohibited from potentially unsafe CUDA calls.
1235/// Concurrent capture sequences in other threads are ignored.
1236/// * [`StreamCaptureMode::Relaxed`]: The local thread is not prohibited from potentially unsafe CUDA calls.
1237/// The thread is still prohibited from CUDA calls
1238/// which necessarily conflict with stream capture, for example, attempting [`Event::query`] on an event that was last recorded inside a capture sequence.
1239///
1240/// # Errors
1241///
1242/// Returns an error if CUDA rejects the capture-mode exchange or if a previous asynchronous launch
1243/// reported an error.
1244pub fn exchange_capture_mode(mode: StreamCaptureMode) -> Result<StreamCaptureMode> {
1245 let mut mode_raw: runtime::cudaStreamCaptureMode = mode.into();
1246 unsafe {
1247 try_ffi!(runtime::cudaThreadExchangeStreamCaptureMode(
1248 &raw mut mode_raw
1249 ))?;
1250 }
1251 Ok(mode_raw.into())
1252}
1253
#[cfg(all(test, feature = "testing"))]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
            mpsc,
        },
        thread,
    };

    use super::*;
    use crate::{event::EventRecordFlags, memory::DeviceMemory, testing};

    /// Smoke test: a stream with a pending callback reports not-done, and reports done with
    /// the callback executed after synchronizing.
    #[test]
    fn it_works() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream1 = ctx.create_stream()?;
        let _stream2 = ctx.create_stream_with_flags(StreamFlags::NON_BLOCKING)?;

        // The callback parks on this channel until the pending state has been observed.
        // Without the gate, the callback could finish before `query` runs and make the
        // "not done" assertion fail intermittently. If the test exits early, dropping
        // `release` wakes the callback, so the stream can still drain on drop.
        let (release, gate) = mpsc::channel::<()>();

        let stream1_called = Arc::new(AtomicBool::new(false));
        stream1.add_callback(Box::new({
            let stream1_called = Arc::clone(&stream1_called);
            move |_status| {
                // Either outcome (signal received or sender dropped) means "proceed".
                let _ = gate.recv();
                stream1_called.store(true, Ordering::SeqCst);
            }
        }))?;

        let is_done = stream1.query()?;
        assert!(!is_done);

        release
            .send(())
            .expect("callback gate receiver was dropped before the callback ran");
        stream1.synchronize()?;

        let is_done_after = stream1.query()?;
        assert!(is_done_after);

        assert!(stream1_called.load(Ordering::SeqCst));

        Ok(())
    }

    /// An event recorded on a stream reports completion once the stream is synchronized.
    #[test]
    fn event_query_uses_event_context() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;
        let event = ctx.create_event()?;

        event.record(&stream, EventRecordFlags::DEFAULT)?;
        stream.synchronize()?;

        assert!(event.query()?);
        Ok(())
    }

    /// `shutdown` waits for pending callbacks before the stream is destroyed.
    #[test]
    fn shutdown_synchronizes_and_destroys_stream() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let called = Arc::new(AtomicBool::new(false));
        stream.add_callback(Box::new({
            let called = Arc::clone(&called);
            move |_status| {
                called.store(true, Ordering::SeqCst);
            }
        }))?;

        stream.shutdown()?;
        assert!(called.load(Ordering::SeqCst));
        Ok(())
    }

    /// A host function enqueued outside of capture has run once the stream is synchronized.
    #[test]
    fn launch_host_func_runs_after_stream_work() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let called = Arc::new(AtomicBool::new(false));
        stream.launch_host_func({
            let called = Arc::clone(&called);
            move || {
                called.store(true, Ordering::SeqCst);
            }
        })?;
        stream.synchronize()?;

        assert!(called.load(Ordering::SeqCst));
        Ok(())
    }

    /// The graph produced by a scoped capture is associated with the stream's context.
    #[test]
    fn scoped_capture_returns_context_associated_graph() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let graph = stream.capture(StreamCaptureMode::Relaxed, |scope| {
            assert_eq!(scope.stream().context(), ctx.as_ref());
            Ok(())
        })?;

        assert_eq!(graph.context(), Some(ctx.as_ref()));
        Ok(())
    }

    /// Capturing into a graph owned by another context fails before capture starts.
    #[test]
    fn capture_to_graph_rejects_graph_from_different_context() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let other_ctx = Context::create()?;

        let stream = ctx.create_stream()?;
        let graph = other_ctx.create_graph()?;

        assert!(matches!(
            unsafe { stream.begin_capture_to_graph(&graph, &[], StreamCaptureMode::Relaxed) },
            Err(Error::GraphContextMismatch)
        ));
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// A dependency node from a different graph is rejected before capture starts.
    #[test]
    fn capture_to_graph_rejects_node_from_different_graph() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;

        let stream = ctx.create_stream()?;
        let graph = ctx.create_graph()?;
        let mut other_graph = ctx.create_graph()?;
        let other_node = other_graph.add_empty_node(&[])?;

        assert!(matches!(
            unsafe {
                stream.begin_capture_to_graph(&graph, &[other_node], StreamCaptureMode::Relaxed)
            },
            Err(Error::GraphNodeMismatch)
        ));
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// A raw node with no graph association is rejected before capture starts.
    #[test]
    fn capture_to_graph_rejects_unassociated_raw_dependency_node() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;

        let stream = ctx.create_stream()?;
        let graph = ctx.create_graph()?;
        // The handle is only a placeholder value; the test expects rejection before use.
        let raw_node = unsafe { GraphNode::from_raw(0x1usize as _) };

        assert!(matches!(
            unsafe {
                stream.begin_capture_to_graph(&graph, &[raw_node], StreamCaptureMode::Relaxed)
            },
            Err(Error::GraphNodeMismatch)
        ));
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Updating capture dependencies with a node from another context fails, and the
    /// stream is no longer capturing afterwards.
    #[test]
    fn capture_dependency_update_rejects_node_from_different_context() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let other_ctx = Context::create()?;

        let stream = ctx.create_stream()?;
        let mut other_graph = other_ctx.create_graph()?;
        let other_node = other_graph.add_empty_node(&[])?;

        let result = stream.capture(StreamCaptureMode::Relaxed, |_scope| {
            stream.update_capture_dependencies(&[other_node])
        });

        assert!(matches!(result, Err(Error::GraphContextMismatch)));
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Updating capture dependencies with a node from a different graph is rejected.
    #[test]
    fn capture_dependency_update_rejects_node_from_different_graph() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;

        let stream = ctx.create_stream()?;
        let mut other_graph = ctx.create_graph()?;
        let other_node = other_graph.add_empty_node(&[])?;

        stream.begin_capture(StreamCaptureMode::Relaxed)?;
        assert!(matches!(
            stream.update_capture_dependencies(&[other_node]),
            Err(Error::GraphNodeMismatch)
        ));
        // End the capture; its result is irrelevant to this test.
        drop(stream.end_capture());
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Updating capture dependencies with a raw node that has no graph association is rejected.
    #[test]
    fn capture_dependency_update_rejects_unassociated_raw_node() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;

        let stream = ctx.create_stream()?;
        // The handle is only a placeholder value; the test expects rejection before use.
        let raw_node = unsafe { GraphNode::from_raw(0x1usize as _) };

        stream.begin_capture(StreamCaptureMode::Relaxed)?;
        assert!(matches!(
            stream.update_capture_dependencies(&[raw_node]),
            Err(Error::GraphNodeMismatch)
        ));
        // End the capture; its result is irrelevant to this test.
        drop(stream.end_capture());
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Two threads can each capture on their own stream at the same time.
    #[test]
    fn captures_on_separate_streams_can_overlap() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream_a = ctx.create_stream()?;
        let stream_b = ctx.create_stream()?;

        thread::scope(|scope| {
            let a = scope.spawn(|| {
                let graph = stream_a.capture(StreamCaptureMode::Relaxed, |_scope| Ok(()))?;
                assert_eq!(graph.context(), Some(ctx.as_ref()));
                Result::<()>::Ok(())
            });
            let b = scope.spawn(|| {
                let graph = stream_b.capture(StreamCaptureMode::Relaxed, |_scope| Ok(()))?;
                assert_eq!(graph.context(), Some(ctx.as_ref()));
                Result::<()>::Ok(())
            });

            a.join().expect("capture thread panicked")?;
            b.join().expect("capture thread panicked")?;
            Result::<()>::Ok(())
        })?;

        Ok(())
    }

    /// An error returned from the capture body is propagated and the stream stays usable.
    #[test]
    fn scoped_capture_error_leaves_stream_usable() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let result = stream.capture(StreamCaptureMode::Relaxed, |_scope| {
            Err(Error::InvalidValue)
        });

        assert!(matches!(result, Err(Error::InvalidValue)));
        stream.synchronize()?;
        Ok(())
    }

    /// A panic inside the capture body propagates and the stream stays usable afterwards.
    #[test]
    fn scoped_capture_panic_leaves_stream_usable() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let _ = stream.capture(StreamCaptureMode::Relaxed, |_scope| -> Result<()> {
                panic!("capture body panic");
            });
        }));

        assert!(result.is_err());
        stream.synchronize()?;
        Ok(())
    }

    /// Memory operations recorded through the capture scope take effect when the
    /// instantiated graph is launched.
    #[test]
    fn scoped_capture_records_memory_operations() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let input = [1u8, 2, 3, 4];
        let source = DeviceMemory::from_slice(&input)?;
        let mut copied = DeviceMemory::<u8>::zeroes(input.len())?;
        let mut filled = DeviceMemory::<u8>::zeroes(input.len())?;

        let graph = stream.capture(StreamCaptureMode::Relaxed, |scope| {
            let copy = unsafe { copied.copy_from_device_operation(&source)? };
            scope.record(copy)?;

            let memset = unsafe { filled.set_value_operation(0xab) };
            scope.record(memset)
        })?;

        let executable = graph.instantiate()?;
        executable.launch(&stream)?;
        stream.synchronize()?;

        assert_eq!(copied.copy_to_host_vec()?, input);
        assert_eq!(filled.copy_to_host_vec()?, [0xab; 4]);
        Ok(())
    }
}