Skip to main content

singe_cuda/
stream.rs

1#[allow(unused_imports)]
2use crate::error::Status;
3
4use std::{iter, marker::PhantomData, mem::ManuallyDrop, panic, ptr, sync::Arc};
5
6use num_enum::{IntoPrimitive, TryFromPrimitive};
7use singe_core::impl_enum_conversion;
8use singe_cuda_sys::runtime;
9
10use crate::{
11    context::Context,
12    device::Device,
13    error::{Error, Result},
14    event::Event,
15    graph::{
16        ExecutableGraph, Graph, GraphDependency, GraphEdgeData, GraphInstantiateFlags, GraphNode,
17    },
18    try_ffi,
19};
20
bitflags::bitflags! {
    /// Flags for CUDA stream creation ([`Context::create_stream_with_flags`]).
    ///
    /// Each constant takes its bit value directly from the runtime binding
    /// constant of the corresponding name.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct StreamFlags: u32 {
        /// Mirrors `runtime::cudaStreamDefault`.
        const DEFAULT = runtime::cudaStreamDefault;
        /// Mirrors `runtime::cudaStreamNonBlocking`.
        const NON_BLOCKING = runtime::cudaStreamNonBlocking;
    }
}
29
/// Capture state of a stream, as returned by [`Stream::capture_status`] and
/// [`Stream::capture_info`].
///
/// Discriminants are taken from `runtime::cudaStreamCaptureStatus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum StreamCaptureStatus {
    /// The stream is not capturing.
    None = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE as _,
    /// The stream is capturing.
    Active = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_ACTIVE as _,
    /// The stream was capturing, but an error has invalidated the capture sequence.
    Invalidated = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_INVALIDATED as _,
}

// Wires up conversions between the raw runtime enum and this wrapper (the raw
// status is turned into the wrapper with `.into()` in this file).
// NOTE(review): the exact impls generated live in `singe_core` — confirm there.
impl_enum_conversion!(u32, runtime::cudaStreamCaptureStatus, StreamCaptureStatus);
40
/// Capture mode passed to [`Stream::begin_capture`] and the
/// `begin_capture_to_graph*` family.
///
/// Discriminants are taken from `runtime::cudaStreamCaptureMode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum StreamCaptureMode {
    /// Mirrors `CU_STREAM_CAPTURE_MODE_GLOBAL`.
    Global = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_GLOBAL as _,
    /// Mirrors `CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`.
    ThreadLocal = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL as _,
    /// Mirrors `CU_STREAM_CAPTURE_MODE_RELAXED`. Per the [`Stream::begin_capture`]
    /// documentation, this is the only mode that does not require
    /// [`Stream::end_capture`] to be called from the thread that began capture.
    Relaxed = runtime::cudaStreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED as _,
}

// Wires up conversions between this wrapper and the raw runtime enum (the
// wrapper is turned into the raw value with `.into()` at the FFI call sites).
// NOTE(review): the exact impls generated live in `singe_core` — confirm there.
impl_enum_conversion!(u32, runtime::cudaStreamCaptureMode, StreamCaptureMode);
51
/// Flags for [`Stream::update_capture_dependencies_with_dependencies`].
///
/// Selects whether the supplied nodes extend or replace the dependency set of
/// a capturing stream. Discriminants are taken from
/// `runtime::cudaStreamUpdateCaptureDependenciesFlags`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum StreamCaptureDependencyUpdate {
    /// Add new nodes to the dependency set.
    Add = runtime::cudaStreamUpdateCaptureDependenciesFlags::cudaStreamAddCaptureDependencies as _,
    /// Replace the dependency set with the new nodes.
    Set = runtime::cudaStreamUpdateCaptureDependenciesFlags::cudaStreamSetCaptureDependencies as _,
}

// Wires up conversions between this wrapper and the raw runtime flag enum (the
// wrapper is turned into the raw value with `.into()` at the FFI call site).
// NOTE(review): the exact impls generated live in `singe_core` — confirm there.
impl_enum_conversion!(
    u32,
    runtime::cudaStreamUpdateCaptureDependenciesFlags,
    StreamCaptureDependencyUpdate,
);
68
/// Snapshot of a stream's capture state, returned by [`Stream::capture_info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamCaptureInfo {
    /// Capture status of the stream at the time of the query.
    pub status: StreamCaptureStatus,
    /// Id of the capture sequence. Per [`Stream::capture_info`], this is only
    /// meaningful when `status` is [`StreamCaptureStatus::Active`].
    pub id: u64,
}
74
// Type alias for the trait object Box itself (inner box): a type-erased,
// one-shot stream callback that receives the stream status as a `Result<()>`.
type RustStreamCallbackDyn = Box<dyn FnOnce(Result<()>) + Send + 'static>;

// Type alias for the pointer type stored in the outer box. `Box<dyn Trait>` is
// a fat pointer, so the callback is boxed a second time to obtain a thin
// pointer that fits in a single user-data argument (see `Stream::add_callback`).
type BoxedCallbackPtr = *mut RustStreamCallbackDyn;

// Type-erased, one-shot host function taking no arguments (inner box).
type RustHostFunctionDyn = Box<dyn FnOnce() + Send + 'static>;
// Thin pointer to a heap-allocated `RustHostFunctionDyn` (outer box).
// NOTE(review): presumably used like `BoxedCallbackPtr` for host-function
// launches — the call site is outside this chunk; confirm.
type BoxedHostFunctionPtr = *mut RustHostFunctionDyn;
83
/// Handle to a CUDA stream.
///
/// The state lives behind an [`Arc`], so cloning a `Stream` yields another
/// handle to the same underlying stream rather than a new stream.
#[derive(Debug, Clone)]
pub struct Stream {
    // Shared state: the raw handle and the context it belongs to.
    inner: Arc<StreamInner>,
}
88
// Shared state behind every clone of a `Stream`.
// NOTE(review): the cleanup logic (Drop impl) for this type is not visible in
// this chunk; `Stream::shutdown` documents itself as the explicit form of it.
#[derive(Debug)]
struct StreamInner {
    // Raw CUDA stream handle; `Stream::from_raw` rejects null handles.
    handle: runtime::cudaStream_t,
    // Context the stream belongs to; bound before FFI calls on the stream.
    ctx: Arc<Context>,
    // TODO: Store device ID? Could be useful for multi-GPU.
    // device_id: DeviceId,
}
96
97impl PartialEq for Stream {
98    fn eq(&self, other: &Self) -> bool {
99        self.as_raw() == other.as_raw() && Arc::ptr_eq(&self.inner.ctx, &other.inner.ctx)
100    }
101}
102
103impl Eq for Stream {}
104
/// Scope handed to the closure of [`Stream::sync_scope`].
///
/// `'scope` is the lifetime of the scope itself; `'env` is the lifetime of
/// data borrowed from the environment of the `sync_scope` call.
#[derive(Debug)]
pub struct StreamScope<'scope, 'env> {
    // Stream that is synchronized when the scope ends.
    stream: &'scope Stream,
    // `&'env mut &'env ()` makes the struct invariant over `'env` without
    // storing any data.
    _env: PhantomData<&'env mut &'env ()>,
}
110
/// Scope handed to the closure of [`Stream::capture`] while the stream is in
/// capture mode.
///
/// The scope is intentionally `!Send`, so it cannot be moved to another thread
/// while capture is active.
#[derive(Debug)]
pub struct StreamCaptureScope<'scope> {
    // Stream on which capture is active.
    stream: &'scope Stream,
    // A raw-pointer marker opts the type out of `Send` (and `Sync`).
    _not_send: PhantomData<*const ()>,
}
116
/// Operation that may be recorded into a CUDA graph capture scope.
///
/// # Safety
///
/// Implementors must only enqueue CUDA work that is valid during stream
/// capture. Every pointer, handle, and side effect captured into the resulting
/// graph must have its replay safety contract represented by the operation's
/// type and constructor.
pub unsafe trait GraphRecordable {
    /// Value produced by successfully recording the operation.
    type Output;

    /// Consumes the operation and records it into `scope`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementor reports when recording fails.
    fn record(self, scope: &StreamCaptureScope<'_>) -> Result<Self::Output>;
}
130
// Guard for a capture that has been started on `stream`. Unless it has been
// marked finished, dropping it ends the capture (see the `Drop` impl).
struct ActiveStreamCapture<'stream> {
    // Stream whose capture this guard is responsible for ending.
    stream: &'stream Stream,
    // Set once `end_capture` has been issued, so `Drop` does not end it twice.
    finished: bool,
}
135
136impl ActiveStreamCapture<'_> {
137    fn finish(mut self) -> Result<Graph> {
138        self.finished = true;
139        self.stream.end_capture()
140    }
141
142    fn discard(mut self) {
143        self.finished = true;
144        drop(self.stream.end_capture());
145    }
146}
147
148impl Drop for ActiveStreamCapture<'_> {
149    fn drop(&mut self) {
150        if !self.finished {
151            drop(self.stream.end_capture());
152        }
153    }
154}
155
/// Raw stream handle paired with its context; produced by
/// [`Stream::to_borrowed`].
///
/// NOTE(review): presumably this does not destroy the handle when dropped —
/// no `Drop` impl is visible in this chunk; confirm.
#[derive(Debug, Clone)]
pub struct BorrowedStream {
    // Raw CUDA stream handle.
    handle: runtime::cudaStream_t,
    // Context the handle belongs to.
    ctx: Arc<Context>,
}
161
/// Selects which stream an operation is bound to.
///
/// NOTE(review): consumers of this enum are outside this chunk; the variant
/// descriptions below follow the names and payload types only.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StreamBinding {
    /// Carries only a context — presumably meaning that context's default
    /// stream; confirm against the consumers.
    Default(Arc<Context>),
    /// An explicit stream handle together with its context.
    Borrowed(BorrowedStream),
}
168
169impl Stream {
170    /// Wraps an existing CUDA stream handle and takes ownership of it.
171    ///
172    /// Dropping the returned stream may block while synchronizing the stream
173    /// before destruction. Use [`Stream::shutdown`] to surface synchronization
174    /// or destruction errors explicitly.
175    ///
176    /// # Safety
177    ///
178    /// `handle` must be a valid CUDA stream owned by `ctx`, and ownership of
179    /// the handle is transferred to the returned [`Stream`]. The handle must
180    /// not be destroyed elsewhere after calling this function.
181    pub unsafe fn from_raw(handle: runtime::cudaStream_t, ctx: Arc<Context>) -> Result<Self> {
182        if handle.is_null() {
183            return Err(Error::NullHandle);
184        }
185
186        Ok(Self {
187            inner: Arc::new(StreamInner { handle, ctx }),
188        })
189    }
190
191    pub fn to_borrowed(&self) -> BorrowedStream {
192        unsafe { BorrowedStream::from_raw(self.as_raw(), Arc::clone(&self.inner.ctx)) }
193    }
194
195    /// Runs `f` with a stream scope and synchronizes this stream before returning.
196    ///
197    /// Use this for scoped asynchronous operations that borrow host or device
198    /// memory until stream completion. For CUDA graph capture, use
199    /// [`Stream::capture`] or [`Stream::capture_executable`].
200    pub fn sync_scope<'env, F, R>(&self, f: F) -> Result<R>
201    where
202        F: for<'scope> FnOnce(&'scope StreamScope<'scope, 'env>) -> Result<R>,
203    {
204        let scope = StreamScope {
205            stream: self,
206            _env: PhantomData,
207        };
208        let result = f(&scope);
209        let sync_result = self.synchronize();
210
211        match (result, sync_result) {
212            (Ok(value), Ok(())) => Ok(value),
213            (Ok(_), Err(err)) => Err(err),
214            (Err(err), Ok(())) | (Err(err), Err(_)) => Err(err),
215        }
216    }
217
218    /// Blocks until stream has completed all operations.
219    /// If [`ContextFlags::SCHEDULE_BLOCKING_SYNC`](crate::context::ContextFlags::SCHEDULE_BLOCKING_SYNC) was set for this device, the host thread will block until the stream is finished with all of its tasks.
220    ///
221    /// Uses standard `default stream` semantics.
222    ///
223    /// # Errors
224    ///
225    /// Returns an error if stream synchronization fails or if a previous asynchronous launch
226    /// reported an error. CUDA may also return initialization-related errors such as
227    /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
228    /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
229    /// call CUDA functions; see [`Stream::add_callback`].
230    pub fn synchronize(&self) -> Result<()> {
231        self.inner.ctx.bind()?;
232        unsafe { try_ffi!(runtime::cudaStreamSynchronize(self.as_raw())) }
233    }
234
235    /// Synchronizes this stream, destroys it, and returns any CUDA error.
236    ///
237    /// This is the explicit version of the cleanup normally performed by
238    /// [`Drop`]. It may block while waiting for stream work and callbacks to
239    /// complete. If synchronization fails, destruction is still attempted and
240    /// the synchronization error is returned. If synchronization succeeds but
241    /// destruction fails, the destruction error is returned.
242    pub fn shutdown(self) -> Result<()> {
243        let inner = Arc::try_unwrap(self.inner).map_err(|_| Error::InvalidValue)?;
244        let inner = ManuallyDrop::new(inner);
245        Self::destroy_handle(inner.ctx.as_ref(), inner.handle)
246    }
247
248    /// Returns `true` if all operations in stream have completed, or `false` if not.
249    ///
250    /// For the purposes of Unified Memory, a return value of `true` is equivalent to having called [`Stream::synchronize`].
251    ///
252    /// Uses standard `default stream` semantics.
253    ///
254    /// # Errors
255    ///
256    /// Returns an error if querying the stream fails or if a previous asynchronous launch reported
257    /// an error. CUDA may also return initialization-related errors such as
258    /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
259    /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
260    /// call CUDA functions; see [`Stream::add_callback`].
261    pub fn query(&self) -> Result<bool> {
262        let error = unsafe { runtime::cudaStreamQuery(self.as_raw()) };
263        match error {
264            runtime::cudaError_t::CUDA_SUCCESS => Ok(true),
265            runtime::cudaError_t::CUDA_ERROR_NOT_READY => Ok(false),
266            _ => Err(error.into()),
267        }
268    }
269
270    /// Makes all future work submitted to stream wait for all work captured in event.
271    /// See [`sys::cudaEventRecord`](singe_cuda_sys::runtime::cudaEventRecord) for details on what is captured by an event.
272    /// Synchronization is performed efficiently on the device when applicable.
273    /// `event` may be from a different device than `stream`.
274    ///
275    /// Uses standard `default stream` semantics.
276    ///
277    /// # Errors
278    ///
279    /// Returns an error if the stream cannot wait on the event or if a previous asynchronous launch
280    /// reported an error. CUDA may also return initialization-related errors such as
281    /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
282    /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
283    /// call CUDA functions; see [`Stream::add_callback`].
284    pub fn wait_event(&self, event: &Event) -> Result<()> {
285        self.wait_event_with_flags(event, 0)
286    }
287
288    /// Makes all future work submitted to stream wait for all work captured in event.
289    /// See [`sys::cudaEventRecord`](singe_cuda_sys::runtime::cudaEventRecord) for details on what is captured by an event.
290    /// `flags` controls how strictly the wait is enforced.
291    /// Synchronization is performed efficiently on the device when applicable.
292    /// `event` may be from a different device than `stream`.
293    ///
294    /// Uses standard `default stream` semantics.
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if the stream cannot wait on the event or if a previous asynchronous launch
299    /// reported an error. CUDA may also return initialization-related errors such as
300    /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
301    /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
302    /// call CUDA functions; see [`Stream::add_callback`].
303    pub fn wait_event_with_flags(&self, event: &Event, flags: u32) -> Result<()> {
304        self.inner.ctx.bind()?;
305        unsafe {
306            try_ffi!(runtime::cudaStreamWaitEvent(
307                self.as_raw(),
308                event.as_raw(),
309                flags,
310            ))
311        }
312    }
313
314    /// Begin graph capture on stream.
315    /// When a stream is in capture mode, operations pushed into the stream are captured
316    /// into a graph instead of executed. [`Stream::end_capture`] returns the graph.
317    /// Capture may not be initiated on the legacy default stream.
318    /// Capture must be ended on the same stream in which it was initiated, and it may only be initiated if the stream is not already in capture mode.
319    /// The capture mode may be queried via [`Stream::capture_status`].
320    /// A unique id representing the capture sequence may be queried via [`Stream::capture_info`].
321    ///
322    /// If mode is not [`StreamCaptureMode::Relaxed`], [`Stream::end_capture`] must be called on this stream from the same thread.
323    ///
324    /// Kernels captured using this API must not use texture and surface references.
325    /// Reading or writing through any texture or surface reference is undefined behavior.
326    /// This restriction does not apply to texture and surface objects.
327    ///
328    /// # Errors
329    ///
330    /// Returns an error if the context cannot be bound, capture cannot begin on
331    /// this stream, the capture mode is invalid for the current thread state,
332    /// or a previous asynchronous launch reports an error.
333    pub fn begin_capture(&self, mode: StreamCaptureMode) -> Result<()> {
334        self.inner.ctx.bind()?;
335        unsafe {
336            try_ffi!(runtime::cudaStreamBeginCapture(self.as_raw(), mode.into()))?;
337        }
338        Ok(())
339    }
340
341    /// Begins stream capture into an existing graph.
342    ///
343    /// # Safety
344    ///
345    /// This low-level API captures into `graph`'s existing CUDA handle. Calling
346    /// [`Stream::end_capture`] after this may return that same raw handle; the
347    /// caller must not wrap it as a second owned [`Graph`]. Prefer
348    /// [`Stream::capture`] unless manually managing capture into an existing
349    /// graph is required.
350    pub unsafe fn begin_capture_to_graph(
351        &self,
352        graph: &Graph,
353        dependencies: &[GraphNode],
354        mode: StreamCaptureMode,
355    ) -> Result<()> {
356        unsafe { self.begin_capture_to_graph_with_data(graph, dependencies, &[], mode) }
357    }
358
359    /// Begins stream capture into an existing graph with annotated dependency edges.
360    ///
361    /// # Safety
362    ///
363    /// This has the same ownership restrictions as
364    /// [`Stream::begin_capture_to_graph`].
365    pub unsafe fn begin_capture_to_graph_with_data(
366        &self,
367        graph: &Graph,
368        dependencies: &[GraphNode],
369        edge_data: &[GraphEdgeData],
370        mode: StreamCaptureMode,
371    ) -> Result<()> {
372        if !edge_data.is_empty() && edge_data.len() != dependencies.len() {
373            return Err(Error::GraphDependencyMismatch);
374        }
375
376        let dependencies: Vec<_> = dependencies
377            .iter()
378            .zip(
379                edge_data
380                    .iter()
381                    .copied()
382                    .chain(iter::repeat(GraphEdgeData::default())),
383            )
384            .map(|(node, data)| GraphDependency {
385                node: node.clone(),
386                data,
387            })
388            .collect();
389
390        unsafe { self.begin_capture_to_graph_with_dependencies(graph, &dependencies, mode) }
391    }
392
393    /// Begin graph capture on stream.
394    /// When a stream is in capture mode, operations pushed into the stream are captured
395    /// into a graph instead of executed. [`Stream::end_capture`] returns the graph.
396    ///
397    /// Capture may not be initiated on the legacy default stream.
398    /// Capture must be ended on the same stream in which it was initiated, and it may only be initiated if the stream is not already in capture mode.
399    /// The capture mode may be queried via [`Stream::capture_status`].
400    /// A unique id representing the capture sequence may be queried via [`Stream::capture_info`].
401    ///
402    /// If mode is not [`StreamCaptureMode::Relaxed`], [`Stream::end_capture`] must be called on this stream from the same thread.
403    ///
404    /// Kernels captured using this API must not use texture and surface references.
405    /// Reading or writing through any texture or surface reference is undefined behavior.
406    /// This restriction does not apply to texture and surface objects.
407    ///
408    /// # Errors
409    ///
410    /// Returns an error if the context cannot be bound, capture cannot begin on
411    /// this stream, the graph dependencies are invalid, the capture mode is
412    /// invalid for the current thread state, or a previous asynchronous launch
413    /// reports an error.
414    /// # Safety
415    ///
416    /// This captures into `graph`'s existing CUDA handle. Calling
417    /// [`Stream::end_capture`] after this may return that same raw handle; the
418    /// caller must not wrap it as a second owned [`Graph`].
419    pub unsafe fn begin_capture_to_graph_with_dependencies(
420        &self,
421        graph: &Graph,
422        dependencies: &[GraphDependency],
423        mode: StreamCaptureMode,
424    ) -> Result<()> {
425        self.check_graph_context(graph)?;
426        self.check_capture_dependency_contexts(dependencies)?;
427        self.check_capture_graph_dependencies(graph, dependencies)?;
428        self.inner.ctx.bind()?;
429
430        let dependencies_raw: Vec<_> = dependencies
431            .iter()
432            .map(|dependency| dependency.node.as_raw())
433            .collect();
434        let edge_data_raw: Vec<_> = dependencies
435            .iter()
436            .map(|dependency| dependency.data.into())
437            .collect();
438        unsafe {
439            try_ffi!(runtime::cudaStreamBeginCaptureToGraph(
440                self.as_raw(),
441                graph.as_raw(),
442                dependencies_raw.as_ptr(),
443                if edge_data_raw.is_empty() {
444                    ptr::null()
445                } else {
446                    edge_data_raw.as_ptr()
447                },
448                dependencies_raw.len() as _,
449                mode.into(),
450            ))?;
451        }
452        Ok(())
453    }
454
455    /// Ends capture on this stream, returning the captured graph.
456    /// Capture must have been initiated on stream via a call to [`Stream::begin_capture`].
457    /// If capture was invalidated due to a violation of the rules of stream capture, an error is returned.
458    ///
459    /// If the mode argument to [`Stream::begin_capture`] was not [`StreamCaptureMode::Relaxed`], this call must be from the same thread as [`Stream::begin_capture`].
460    ///
461    /// # Errors
462    ///
463    /// Returns an error if the context cannot be bound, capture is not active
464    /// on this stream, the capture has been invalidated, or a previous
465    /// asynchronous launch reports an error.
466    pub fn end_capture(&self) -> Result<Graph> {
467        self.inner.ctx.bind()?;
468        let mut handle = ptr::null_mut();
469        unsafe {
470            try_ffi!(runtime::cudaStreamEndCapture(
471                self.as_raw(),
472                &raw mut handle
473            ))?;
474            Graph::from_raw_in_context(handle, Arc::clone(&self.inner.ctx))
475        }
476    }
477
478    /// Captures stream work recorded by `f` into a CUDA graph.
479    ///
480    /// This is the scoped form of [`Stream::begin_capture`] and
481    /// [`Stream::end_capture`]. The capture is always ended before this method
482    /// returns or resumes a panic. If `f` returns an error, this method attempts
483    /// to end capture to restore stream usability, destroys any graph returned
484    /// by CUDA, and returns the closure error.
485    ///
486    /// The scope is intentionally `!Send`, so it cannot be moved to another
487    /// thread while capture is active. Future graph-safe recording helpers can
488    /// be added to [`StreamCaptureScope`] without changing this API shape.
489    ///
490    /// # Errors
491    ///
492    /// Returns an error if capture cannot begin, if `f` returns an error, or if
493    /// capture cannot be ended successfully.
494    pub fn capture<F>(&self, mode: StreamCaptureMode, f: F) -> Result<Graph>
495    where
496        F: FnOnce(&StreamCaptureScope<'_>) -> Result<()>,
497    {
498        self.begin_capture(mode)?;
499        let capture = ActiveStreamCapture {
500            stream: self,
501            finished: false,
502        };
503
504        let scope = StreamCaptureScope {
505            stream: self,
506            _not_send: PhantomData,
507        };
508
509        let capture_result = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&scope)));
510        match capture_result {
511            Ok(Ok(())) => capture.finish(),
512            Ok(Err(err)) => {
513                capture.discard();
514                Err(err)
515            }
516            Err(payload) => {
517                drop(capture);
518                panic::resume_unwind(payload);
519            }
520        }
521    }
522
523    pub fn capture_executable<F>(&self, mode: StreamCaptureMode, f: F) -> Result<ExecutableGraph>
524    where
525        F: FnOnce(&StreamCaptureScope<'_>) -> Result<()>,
526    {
527        self.capture_executable_with_flags(mode, GraphInstantiateFlags::empty(), f)
528    }
529
530    pub fn capture_executable_with_flags<F>(
531        &self,
532        mode: StreamCaptureMode,
533        flags: GraphInstantiateFlags,
534        f: F,
535    ) -> Result<ExecutableGraph>
536    where
537        F: FnOnce(&StreamCaptureScope<'_>) -> Result<()>,
538    {
539        let graph = self.capture(mode, f)?;
540        graph.instantiate_with_flags(flags)
541    }
542
543    /// Returns the capture status of this stream.
544    /// After a successful call, the status is one of the following:
545    ///
546    /// * [`StreamCaptureStatus::None`]: The stream is not capturing.
547    /// * [`StreamCaptureStatus::Active`]: The stream is capturing.
548    /// * [`StreamCaptureStatus::Invalidated`]: The stream was capturing but an error has invalidated the capture sequence.
549    ///   The capture sequence must be terminated with
550    ///   [`Stream::end_capture`] on the stream where it was initiated to continue using the stream.
551    ///
552    /// If this is called on the legacy default stream while a blocking stream on the same device is capturing, it returns [`crate::error::Status::StreamCaptureImplicit`].
553    /// The blocking stream capture is not invalidated.
554    ///
555    /// When a blocking stream is capturing, the legacy stream is in an unusable state until the blocking stream capture is terminated.
556    /// The legacy stream is not supported for stream capture, but attempted use would have an implicit dependency on the capturing stream(s).
557    ///
558    /// # Errors
559    ///
560    /// Returns an error if the context cannot be bound, CUDA cannot query the
561    /// capture status, or a previous asynchronous launch reports an error.
562    pub fn capture_status(&self) -> Result<StreamCaptureStatus> {
563        self.inner.ctx.bind()?;
564        let mut status = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE;
565        unsafe {
566            try_ffi!(runtime::cudaStreamIsCapturing(
567                self.as_raw(),
568                &raw mut status
569            ))?;
570        }
571        Ok(status.into())
572    }
573
574    /// Query stream state related to stream capture.
575    ///
576    /// If called on the legacy default stream while a stream not created with [`StreamFlags::NON_BLOCKING`] is capturing, returns [`crate::error::Status::StreamCaptureImplicit`].
577    ///
578    /// Valid data (other than capture status) is returned only if both of the following are true:
579    ///
580    /// * the call succeeds
581    /// * the returned capture status is [`StreamCaptureStatus::Active`]
582    ///
583    /// If there is non-zero edge data for one or more current stream dependencies and the query cannot return that data, the call returns [`crate::error::Status::LossyQuery`].
584    ///
585    /// # Errors
586    ///
587    /// Returns an error if the context cannot be bound, CUDA cannot query the
588    /// capture info, the query would lose non-zero edge data, or a previous
589    /// asynchronous launch reports an error.
590    pub fn capture_info(&self) -> Result<StreamCaptureInfo> {
591        self.inner.ctx.bind()?;
592        let mut status = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE;
593        let mut id = 0;
594        unsafe {
595            try_ffi!(runtime::cudaStreamGetCaptureInfo(
596                self.as_raw(),
597                &raw mut status,
598                &raw mut id,
599                ptr::null_mut(),
600                ptr::null_mut(),
601                ptr::null_mut(),
602                ptr::null_mut(),
603            ))?;
604        }
605        Ok(StreamCaptureInfo {
606            status: status.into(),
607            id,
608        })
609    }
610
611    pub fn update_capture_dependencies(&self, dependencies: &[GraphNode]) -> Result<()> {
612        self.update_capture_dependencies_with_mode(
613            dependencies,
614            &[],
615            StreamCaptureDependencyUpdate::Add,
616        )
617    }
618
619    pub fn update_capture_dependencies_with_data(
620        &self,
621        dependencies: &[GraphNode],
622        edge_data: &[GraphEdgeData],
623    ) -> Result<()> {
624        self.update_capture_dependencies_with_mode(
625            dependencies,
626            edge_data,
627            StreamCaptureDependencyUpdate::Add,
628        )
629    }
630
631    pub fn update_capture_dependencies_with_mode(
632        &self,
633        dependencies: &[GraphNode],
634        edge_data: &[GraphEdgeData],
635        mode: StreamCaptureDependencyUpdate,
636    ) -> Result<()> {
637        if !edge_data.is_empty() && edge_data.len() != dependencies.len() {
638            return Err(Error::GraphDependencyMismatch);
639        }
640
641        let dependencies: Vec<_> = dependencies
642            .iter()
643            .zip(
644                edge_data
645                    .iter()
646                    .copied()
647                    .chain(iter::repeat(GraphEdgeData::default())),
648            )
649            .map(|(node, data)| GraphDependency {
650                node: node.clone(),
651                data,
652            })
653            .collect();
654
655        self.update_capture_dependencies_with_dependencies(&dependencies, mode)
656    }
657
658    /// Modifies the dependency set of a capturing stream.
659    /// The dependency set is the set of nodes that the next captured node in the stream will depend on.
660    ///
661    /// Valid flags are [`StreamCaptureDependencyUpdate::Add`] and [`StreamCaptureDependencyUpdate::Set`].
662    /// These control whether the supplied set is added to the existing set or replaces it.
663    /// A flags value of 0 defaults to [`StreamCaptureDependencyUpdate::Add`].
664    ///
665    /// Nodes that are removed from the dependency set by this call do not result in [`crate::error::Status::StreamCaptureUnjoined`] if they are unreachable from the stream at [`Stream::end_capture`].
666    ///
667    /// Returns [`crate::error::Status::IllegalState`] if the stream is not capturing.
668    ///
669    /// # Errors
670    ///
671    /// Returns an error if the context cannot be bound, the stream is not
672    /// capturing, the supplied dependencies are invalid, or a previous
673    /// asynchronous launch reports an error.
674    pub fn update_capture_dependencies_with_dependencies(
675        &self,
676        dependencies: &[GraphDependency],
677        mode: StreamCaptureDependencyUpdate,
678    ) -> Result<()> {
679        self.check_capture_dependency_contexts(dependencies)?;
680        self.check_active_capture_graph_dependencies(dependencies)?;
681        self.inner.ctx.bind()?;
682
683        let mut dependencies_raw: Vec<_> = dependencies
684            .iter()
685            .map(|dependency| dependency.node.as_raw())
686            .collect();
687        let edge_data_raw: Vec<_> = dependencies
688            .iter()
689            .map(|dependency| dependency.data.into())
690            .collect();
691        unsafe {
692            try_ffi!(runtime::cudaStreamUpdateCaptureDependencies(
693                self.as_raw(),
694                dependencies_raw.as_mut_ptr(),
695                if edge_data_raw.is_empty() {
696                    ptr::null()
697                } else {
698                    edge_data_raw.as_ptr()
699                },
700                dependencies_raw.len() as _,
701                mode.into(),
702            ))?;
703        }
704        Ok(())
705    }
706
707    /// This callback API is slated for eventual deprecation and removal.
708    /// If you do not require the callback to execute after a device error, consider using [`sys::cudaLaunchHostFunc`](singe_cuda_sys::runtime::cudaLaunchHostFunc).
709    /// Additionally, this callback mechanism is not supported with [`Stream::begin_capture`] and [`Stream::end_capture`], unlike [`sys::cudaLaunchHostFunc`](singe_cuda_sys::runtime::cudaLaunchHostFunc).
710    ///
711    /// Adds a callback to be called on the host after all currently enqueued items in the stream have completed.
712    /// For each [`Stream::add_callback`] call, a callback is executed exactly once.
713    /// The callback blocks later work in the stream until it is finished.
714    ///
715    /// The callback may be passed a successful status or an error code.
716    /// In the event of a device error, all subsequently executed callbacks receive an appropriate [`Status`].
717    ///
718    /// Callbacks must not call CUDA functions.
719    /// Attempting to do so may result in [`crate::error::Status::NotPermitted`].
720    /// Callbacks must not perform any synchronization that may depend on outstanding device work or other callbacks that are not mandated to run earlier.
721    /// Callbacks without a mandated order (in independent streams) execute in undefined order and may be serialized.
722    ///
723    /// For the purposes of Unified Memory, callback execution makes a number of guarantees:
724    ///
725    /// * The callback stream is considered idle for the duration of the callback.
726    ///   Thus, for example, a callback may always use memory
727    ///   attached to the callback stream.
728    /// * The start of execution of a callback has the same effect as synchronizing an event recorded in the same stream immediately
729    ///   before the callback.
730    ///   It thus synchronizes streams which have been "joined" before the callback.
731    /// * Adding device work to any stream does not have the effect of making the stream active until all preceding callbacks have executed.
732    ///   Thus, for example, a callback might use global attached memory even if work has been added to another stream, if it has been
733    ///   properly ordered with an event.
734    /// * Completion of a callback does not cause a stream to become active except as described above.
735    ///   The callback stream will remain
736    ///   idle if no device work follows the callback, and will remain idle across consecutive callbacks without device work in between.
737    ///   Thus, for example, stream synchronization can be done by signaling from a callback at the end of the stream.
738    ///
739    /// # Errors
740    ///
741    /// Returns an error if the context cannot be bound, CUDA rejects the
742    /// callback registration, a previous asynchronous launch reports an error,
743    /// or CUDA reports runtime initialization diagnostics such as
744    /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`],
745    /// or [`crate::error::Status::NoDevice`].
746    pub fn add_callback<F>(&self, callback: F) -> Result<()>
747    where
748        F: FnOnce(Result<()>) + Send + 'static,
749    {
750        self.inner.ctx.bind()?;
751
752        let boxed_dyn_callback: RustStreamCallbackDyn = Box::new(callback);
753        let boxed_wrapper: Box<RustStreamCallbackDyn> = Box::new(boxed_dyn_callback);
754        let user_data_ptr: BoxedCallbackPtr = Box::into_raw(boxed_wrapper);
755        let final_user_data = user_data_ptr.cast();
756
757        let flags = 0u32;
758
759        unsafe {
760            let status = runtime::cudaStreamAddCallback(
761                self.as_raw(),
762                Some(stream_callback_trampoline),
763                final_user_data, // Pass the thin pointer
764                flags,
765            );
766
767            // If adding the callback fails, manually reconstruct and drop the *outer* Box to prevent leaking both boxes.
768            if status != runtime::cudaError_t::CUDA_SUCCESS {
769                // Reconstruct the outer box (Box<Box<dyn Trait>>)
770                let _leaked_box = Box::from_raw(user_data_ptr);
771                // Drop the reconstructed outer box.
772                try_ffi!(status)?;
773            }
774        }
775
776        Ok(())
777    }
778
779    /// Enqueues a host function to run after all currently enqueued work in this stream completes.
780    ///
781    /// Unlike [`Stream::add_callback`], CUDA does not call this function if the CUDA context is already in an error state.
782    /// This API is supported during stream capture by CUDA, but the host function still must not call CUDA
783    /// APIs or perform synchronization that depends on outstanding device work.
784    ///
785    /// # Errors
786    ///
787    /// Returns an error if the context cannot be bound, CUDA rejects the host
788    /// function registration, a previous asynchronous launch reports an error,
789    /// or CUDA reports runtime initialization diagnostics.
790    pub fn launch_host_func<F>(&self, function: F) -> Result<()>
791    where
792        F: FnOnce() + Send + 'static,
793    {
794        self.inner.ctx.bind()?;
795
796        let boxed_dyn_function: RustHostFunctionDyn = Box::new(function);
797        let boxed_wrapper: Box<RustHostFunctionDyn> = Box::new(boxed_dyn_function);
798        let user_data_ptr: BoxedHostFunctionPtr = Box::into_raw(boxed_wrapper);
799        let final_user_data = user_data_ptr.cast();
800
801        unsafe {
802            let status = runtime::cudaLaunchHostFunc(
803                self.as_raw(),
804                Some(stream_host_function_trampoline),
805                final_user_data,
806            );
807
808            if status != runtime::cudaError_t::CUDA_SUCCESS {
809                let _leaked_box = Box::from_raw(user_data_ptr);
810                try_ffi!(status)?;
811            }
812        }
813
814        Ok(())
815    }
816
817    /// Query the flags of a stream.
818    /// Returns the stream flags.
819    /// See [`Context::create_stream_with_flags`] for a list of valid flags.
820    ///
821    /// Uses standard `default stream` semantics.
822    ///
823    /// # Errors
824    ///
825    /// Returns an error if CUDA cannot query the flags or if a previous asynchronous launch
826    /// reported an error. CUDA may also return initialization-related errors such as
827    /// [`crate::error::Status::NotInitialized`], [`crate::error::Status::CallRequiresNewerDriver`], or
828    /// [`crate::error::Status::NoDevice`] if this call initializes internal runtime state. Callbacks must not
829    /// call CUDA functions; see [`Stream::add_callback`].
830    pub fn flags(&self) -> Result<StreamFlags> {
831        self.inner.ctx.bind()?;
832        let mut flags_raw = 0u32;
833        unsafe {
834            try_ffi!(runtime::cudaStreamGetFlags(
835                self.as_raw(),
836                &raw mut flags_raw
837            ))?;
838        }
839        Ok(StreamFlags::from_bits_retain(flags_raw))
840    }
841
842    /// Query the priority of a stream.
843    /// Returns the stream priority.
844    /// If the stream was created with a priority outside the meaningful numerical range returned by [`Device::stream_priority_range`], this returns the clamped priority.
845    /// See [`Context::create_stream_with_priority`] for details about priority clamping.
846    ///
847    /// # Errors
848    ///
849    /// Returns an error if the context cannot be bound, CUDA cannot query the
850    /// priority, a previous asynchronous launch reports an error, or CUDA
851    /// reports runtime initialization diagnostics.
852    pub fn priority(&self) -> Result<i32> {
853        self.inner.ctx.bind()?;
854        let mut priority = 0i32;
855        unsafe {
856            try_ffi!(runtime::cudaStreamGetPriority(
857                self.as_raw(),
858                &raw mut priority
859            ))?;
860        }
861        Ok(priority)
862    }
863
864    /// Returns a stream identifier that remains unique for the life of the program.
865    ///
866    /// The stream handle may refer to any of the following:
867    ///
868    /// * a stream created via any of the CUDA runtime APIs such as [`sys::cudaStreamCreate`](singe_cuda_sys::runtime::cudaStreamCreate), [`Context::create_stream_with_flags`] and [`Context::create_stream_with_priority`], or their driver API equivalents such as [`sys::cuStreamCreate`](singe_cuda_sys::driver::cuStreamCreate) or [`sys::cuStreamCreateWithPriority`](singe_cuda_sys::driver::cuStreamCreateWithPriority).
869    ///   Passing an invalid handle results in undefined behavior.
870    /// * the special legacy default stream and per-thread default stream.
871    ///   The driver API equivalents of these are also accepted.
872    ///
873    /// # Errors
874    ///
875    /// Returns an error if the context cannot be bound, CUDA cannot query the
876    /// stream identifier, a previous asynchronous launch reports an error, or
877    /// CUDA reports runtime initialization diagnostics.
878    pub fn id(&self) -> Result<u64> {
879        self.inner.ctx.bind()?;
880        let mut id = 0u64;
881        unsafe {
882            try_ffi!(runtime::cudaStreamGetId(self.as_raw(), &raw mut id))?;
883        }
884        Ok(id)
885    }
886
887    /// Returns the device of the stream.
888    ///
889    /// # Errors
890    ///
891    /// Returns an error if the context cannot be bound, CUDA cannot query the
892    /// stream device, a previous asynchronous launch reports an error, or CUDA
893    /// reports runtime initialization diagnostics.
894    pub fn device(&self) -> Result<Device> {
895        self.inner.ctx.bind()?;
896        let mut device = 0i32;
897        unsafe {
898            try_ffi!(runtime::cudaStreamGetDevice(self.as_raw(), &raw mut device))?;
899        }
900        Ok(Device::new(device))
901    }
902
903    pub fn context(&self) -> &Context {
904        &self.inner.ctx
905    }
906
907    pub fn as_raw(&self) -> runtime::cudaStream_t {
908        self.inner.handle
909    }
910
911    /// Consumes the stream and returns the raw CUDA stream handle without
912    /// destroying it.
913    ///
914    /// The caller becomes responsible for eventually destroying the returned
915    /// handle with CUDA.
916    pub fn into_raw(self) -> runtime::cudaStream_t {
917        let inner = Arc::try_unwrap(self.inner)
918            .expect("cannot transfer raw stream handle while cloned stream handles exist");
919        let inner = ManuallyDrop::new(inner);
920        inner.handle
921    }
922
923    fn destroy_handle(ctx: &Context, handle: runtime::cudaStream_t) -> Result<()> {
924        let bind_result = ctx.bind();
925        let sync_result =
926            bind_result.and_then(|()| unsafe { try_ffi!(runtime::cudaStreamSynchronize(handle)) });
927        let destroy_result = unsafe { try_ffi!(runtime::cudaStreamDestroy(handle)) };
928
929        match (sync_result, destroy_result) {
930            (Ok(()), Ok(())) => Ok(()),
931            (Err(err), _) | (Ok(()), Err(err)) => Err(err),
932        }
933    }
934
935    // pub fn is_null(&self) -> bool {
936    //     self.inner.handle.is_null()
937    // }
938
939    // --- Methods related to Memory Management ---
940    // Add methods like malloc_async, free_async, attach_mem_async if needed
941    // These would likely take wrappers around device memory pointers.
942
943    fn check_graph_context(&self, graph: &Graph) -> Result<()> {
944        if matches!(graph.context(), Some(ctx) if ctx != self.inner.ctx.as_ref()) {
945            return Err(Error::GraphContextMismatch);
946        }
947        Ok(())
948    }
949
950    fn check_capture_dependency_contexts(&self, dependencies: &[GraphDependency]) -> Result<()> {
951        for dependency in dependencies {
952            if matches!(dependency.node.context(), Some(ctx) if ctx != self.inner.ctx.as_ref()) {
953                return Err(Error::GraphContextMismatch);
954            }
955        }
956        Ok(())
957    }
958
959    fn check_capture_graph_dependencies(
960        &self,
961        graph: &Graph,
962        dependencies: &[GraphDependency],
963    ) -> Result<()> {
964        for dependency in dependencies {
965            graph.check_node(&dependency.node)?;
966        }
967        Ok(())
968    }
969
970    fn check_active_capture_graph_dependencies(
971        &self,
972        dependencies: &[GraphDependency],
973    ) -> Result<()> {
974        if dependencies.is_empty() {
975            return Ok(());
976        }
977
978        self.inner.ctx.bind()?;
979        let mut status = runtime::cudaStreamCaptureStatus::CU_STREAM_CAPTURE_STATUS_NONE;
980        let mut graph = ptr::null_mut();
981        unsafe {
982            try_ffi!(runtime::cudaStreamGetCaptureInfo(
983                self.as_raw(),
984                &raw mut status,
985                ptr::null_mut(),
986                &raw mut graph,
987                ptr::null_mut(),
988                ptr::null_mut(),
989                ptr::null_mut(),
990            ))?;
991        }
992        if StreamCaptureStatus::from(status) != StreamCaptureStatus::Active {
993            return Ok(());
994        }
995        if graph.is_null() {
996            return Err(Error::NullHandle);
997        }
998
999        for dependency in dependencies {
1000            if !matches!(dependency.node.graph_raw(), Some(node_graph) if node_graph == graph) {
1001                return Err(Error::GraphNodeMismatch);
1002            }
1003        }
1004        Ok(())
1005    }
1006
1007    pub(crate) fn ensure_not_capturing_for_future(&self) -> Result<()> {
1008        match self.capture_status()? {
1009            StreamCaptureStatus::None => Ok(()),
1010            StreamCaptureStatus::Active => Err(Status::StreamCaptureUnsupported.into()),
1011            StreamCaptureStatus::Invalidated => Err(Status::StreamCaptureInvalidated.into()),
1012        }
1013    }
1014}
1015
1016impl<'scope, 'env> StreamScope<'scope, 'env> {
1017    pub const fn stream(&self) -> &'scope Stream {
1018        self.stream
1019    }
1020
1021    pub fn synchronize(&self) -> Result<()> {
1022        self.stream.synchronize()
1023    }
1024}
1025
1026impl<'scope> StreamCaptureScope<'scope> {
1027    pub const fn stream(&self) -> &'scope Stream {
1028        self.stream
1029    }
1030
1031    /// Records a graph-safe operation into this active stream capture.
1032    ///
1033    /// Only operations implementing [`GraphRecordable`] can be submitted
1034    /// through this method. Allocation/free and other capture-unsafe CUDA calls
1035    /// should stay outside this trait unless their replay ownership and address
1036    /// stability are explicitly modeled.
1037    pub fn record<O>(&self, operation: O) -> Result<O::Output>
1038    where
1039        O: GraphRecordable,
1040    {
1041        operation.record(self)
1042    }
1043}
1044
1045impl BorrowedStream {
1046    /// Wraps an existing CUDA stream handle without taking ownership.
1047    ///
1048    /// # Safety
1049    ///
1050    /// `handle` must be a valid CUDA stream associated with `ctx`, and it must
1051    /// remain valid for every operation using the returned borrowed stream.
1052    pub const unsafe fn from_raw(handle: runtime::cudaStream_t, ctx: Arc<Context>) -> Self {
1053        Self { handle, ctx }
1054    }
1055
1056    pub fn synchronize(&self) -> Result<()> {
1057        self.ctx.bind()?;
1058        unsafe { try_ffi!(runtime::cudaStreamSynchronize(self.as_raw())) }
1059    }
1060
1061    pub fn context(&self) -> &Context {
1062        &self.ctx
1063    }
1064
1065    pub const fn as_raw(&self) -> runtime::cudaStream_t {
1066        self.handle
1067    }
1068}
1069
1070impl StreamBinding {
1071    pub fn context(&self) -> &Context {
1072        match self {
1073            Self::Default(ctx) => ctx.as_ref(),
1074            Self::Borrowed(stream) => stream.context(),
1075        }
1076    }
1077
1078    pub fn is_default(&self) -> bool {
1079        matches!(self, Self::Default(..))
1080    }
1081
1082    pub fn as_raw(&self) -> runtime::cudaStream_t {
1083        match self {
1084            Self::Default(_) => ptr::null_mut(),
1085            Self::Borrowed(stream) => stream.as_raw(),
1086        }
1087    }
1088}
1089
1090// CUDA streams are ordering handles.
1091// Operations take &self and are serialized by CUDA stream semantics rather than mutable Rust state.
1092unsafe impl Send for StreamInner {}
1093unsafe impl Sync for StreamInner {}
1094unsafe impl Send for Stream {}
1095unsafe impl Sync for Stream {}
1096
1097impl Drop for StreamInner {
1098    fn drop(&mut self) {
1099        if let Err(err) = Stream::destroy_handle(self.ctx.as_ref(), self.handle) {
1100            #[cfg(debug_assertions)]
1101            eprintln!("failed to synchronize or destroy CUDA stream: {err}");
1102        }
1103    }
1104}
1105
1106// Trampoline function to bridge C FFI callback to Rust closure
1107extern "C" fn stream_callback_trampoline(
1108    _stream: runtime::cudaStream_t,
1109    status: runtime::cudaError_t,
1110    user_data: *mut std::ffi::c_void,
1111) {
1112    if user_data.is_null() {
1113        return;
1114    }
1115
1116    let user_data_ptr = user_data as BoxedCallbackPtr;
1117
1118    let boxed_callback: Box<RustStreamCallbackDyn> = unsafe { Box::from_raw(user_data_ptr) };
1119    let callback: RustStreamCallbackDyn = *boxed_callback;
1120
1121    let result = if status == runtime::cudaError_t::CUDA_SUCCESS {
1122        Ok(())
1123    } else {
1124        Err(status.into())
1125    };
1126
1127    callback(result);
1128}
1129
1130extern "C" fn stream_host_function_trampoline(user_data: *mut std::ffi::c_void) {
1131    if user_data.is_null() {
1132        return;
1133    }
1134
1135    let user_data_ptr = user_data as BoxedHostFunctionPtr;
1136    let boxed_function: Box<RustHostFunctionDyn> = unsafe { Box::from_raw(user_data_ptr) };
1137    let function: RustHostFunctionDyn = *boxed_function;
1138    function();
1139}
1140
1141impl Context {
1142    pub fn create_stream(self: &Arc<Self>) -> Result<Stream> {
1143        self.create_stream_with_flags(StreamFlags::DEFAULT)
1144    }
1145
1146    /// Creates a new asynchronous stream on the context that is current to the calling host thread.
1147    /// If no context is current to the calling host thread, then the primary context for a device is selected, made current to the calling thread, and initialized before creating a stream on it.
1148    /// The flags argument determines the behaviors of the stream.
1149    /// Valid values are provided by [`StreamFlags`]:
1150    ///
1151    /// * [`StreamFlags::DEFAULT`]: default stream creation behavior.
1152    /// * [`StreamFlags::NON_BLOCKING`]: allows the created stream to run concurrently with the legacy default stream without implicit synchronization.
1153    ///
1154    /// # Errors
1155    ///
1156    /// Returns an error if CUDA cannot create the stream, if it does not return a valid stream
1157    /// handle, or if a previous asynchronous launch reported an error. CUDA may also return
1158    /// initialization-related errors such as [`crate::error::Status::NotInitialized`],
1159    /// [`crate::error::Status::CallRequiresNewerDriver`], or [`crate::error::Status::NoDevice`] if this call initializes
1160    /// internal runtime state. Callbacks must not call CUDA functions; see
1161    /// [`Stream::add_callback`].
1162    pub fn create_stream_with_flags(self: &Arc<Self>, flags: StreamFlags) -> Result<Stream> {
1163        self.bind()?;
1164        let mut handle = ptr::null_mut();
1165        unsafe {
1166            try_ffi!(runtime::cudaStreamCreateWithFlags(
1167                &raw mut handle,
1168                flags.bits(),
1169            ))?;
1170        }
1171        if handle.is_null() {
1172            return Err(Error::NullHandle);
1173        }
1174        // let mut device_id = 0;
1175        // unsafe { check(cudaStreamGetDevice(stream, &mut device_id))?; }
1176        unsafe { Stream::from_raw(handle, Arc::clone(self)) }
1177    }
1178
1179    /// Creates a stream with the specified priority.
1180    /// The stream is created on this context.
1181    /// This affects the scheduling priority of work in the stream.
1182    /// Priorities provide a hint to preferentially run work with higher priority when possible, but do not preempt already-running work or provide any other functional guarantee on execution order.
1183    ///
1184    /// `priority` follows a convention where lower numbers represent higher priorities.
1185    /// `0` represents default priority.
1186    /// The range of meaningful numerical priorities can be queried using [`Device::stream_priority_range`].
1187    /// If the specified priority is outside the numerical range returned by [`Device::stream_priority_range`], it will automatically be clamped to the lowest or the highest number in the range.
1188    ///
1189    /// * Stream priorities are supported only on GPUs with compute capability 3.5 or higher.
1190    /// * In the current implementation, only compute kernels launched in priority streams are affected by the stream's priority.
1191    ///   Stream
1192    ///   priorities have no effect on host-to-device and device-to-host memory operations.
1193    ///
1194    /// # Errors
1195    ///
1196    /// Returns an error if the context cannot be bound, CUDA cannot create the
1197    /// stream, CUDA returns a null stream handle, a previous asynchronous launch
1198    /// reports an error, or CUDA reports runtime initialization diagnostics.
1199    pub fn create_stream_with_priority(
1200        self: &Arc<Self>,
1201        flags: StreamFlags,
1202        priority: i32,
1203    ) -> Result<Stream> {
1204        self.bind()?;
1205        let mut handle = ptr::null_mut();
1206        unsafe {
1207            try_ffi!(runtime::cudaStreamCreateWithPriority(
1208                &raw mut handle,
1209                flags.bits(),
1210                priority,
1211            ))?;
1212        }
1213        if handle.is_null() {
1214            return Err(Error::NullHandle);
1215        }
1216        unsafe { Stream::from_raw(handle, Arc::clone(self)) }
1217    }
1218}
1219
1220/// Sets the calling thread's stream capture interaction mode, returning the previous mode for the thread.
1221/// To facilitate deterministic behavior across function or module boundaries, use this in a push-pop fashion.
1222///
1223/// During stream capture (see [`Stream::begin_capture`]), some actions, such as a call to [`DeviceMemory::alloc`](crate::memory::DeviceMemory::alloc), may be unsafe.
1224/// In the case of [`DeviceMemory::alloc`](crate::memory::DeviceMemory::alloc), the operation is not enqueued asynchronously to a stream, and is not observed by stream capture.
1225/// If the sequence of operations captured via [`Stream::begin_capture`] depended on the allocation being replayed whenever the graph is launched, the captured graph would be invalid.
1226///
1227/// Therefore, stream capture places restrictions on CUDA calls that can be made within or concurrently to a [`Stream::begin_capture`]-[`Stream::end_capture`] sequence.
1228/// Control this behavior with this function and flags to [`Stream::begin_capture`].
1229///
1230/// A thread's mode is one of the following:
1231///
1232/// * [`StreamCaptureMode::Global`]: default mode.
1233///   If the local thread has an ongoing capture sequence that was not initiated with [`StreamCaptureMode::Relaxed`] at [`Stream::begin_capture`], or if any other thread has a concurrent capture sequence initiated with [`StreamCaptureMode::Global`], this thread is prohibited from potentially unsafe CUDA calls.
1234/// * [`StreamCaptureMode::ThreadLocal`]: If the local thread has an ongoing capture sequence not initiated with [`StreamCaptureMode::Relaxed`], it is prohibited from potentially unsafe CUDA calls.
1235///   Concurrent capture sequences in other threads are ignored.
1236/// * [`StreamCaptureMode::Relaxed`]: The local thread is not prohibited from potentially unsafe CUDA calls.
1237///   The thread is still prohibited from CUDA calls
1238///   which necessarily conflict with stream capture, for example, attempting [`Event::query`] on an event that was last recorded inside a capture sequence.
1239///
1240/// # Errors
1241///
1242/// Returns an error if CUDA rejects the capture-mode exchange or if a previous asynchronous launch
1243/// reported an error.
1244pub fn exchange_capture_mode(mode: StreamCaptureMode) -> Result<StreamCaptureMode> {
1245    let mut mode_raw: runtime::cudaStreamCaptureMode = mode.into();
1246    unsafe {
1247        try_ffi!(runtime::cudaThreadExchangeStreamCaptureMode(
1248            &raw mut mode_raw
1249        ))?;
1250    }
1251    Ok(mode_raw.into())
1252}
1253
/// Stream tests. They require the `testing` feature; each test obtains its context (and a
/// lock guard that it holds for its duration) from `testing::bootstrap`.
#[cfg(all(test, feature = "testing"))]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        thread,
    };

    use super::*;
    use crate::{event::EventRecordFlags, memory::DeviceMemory, testing};

    /// A pending callback keeps the stream busy; after synchronizing, the stream is idle and
    /// the callback has run.
    #[test]
    fn it_works() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream1 = ctx.create_stream()?;
        let _stream2 = ctx.create_stream_with_flags(StreamFlags::NON_BLOCKING)?;

        // Gate the callback on a channel so the stream is guaranteed to still be busy when it
        // is queried below. Without the gate, the callback could finish before the query and
        // make the "not done" assertion flaky.
        let (release, gate) = std::sync::mpsc::channel::<()>();

        let stream1_called = Arc::new(AtomicBool::new(false));
        stream1.add_callback({
            let stream1_called = Arc::clone(&stream1_called);
            move |_status| {
                // An `Err` here only means the sender was dropped (the test is unwinding or
                // returning early); proceed either way so the stream can drain.
                let _ = gate.recv();
                stream1_called.store(true, Ordering::SeqCst);
            }
        })?;

        let is_done = stream1.query()?;
        assert!(!is_done);

        // The receiver lives inside the callback, which cannot finish before this send.
        release
            .send(())
            .expect("callback receiver must be alive until it is released");

        stream1.synchronize()?;

        let is_done_after = stream1.query()?;
        assert!(is_done_after);

        assert!(stream1_called.load(Ordering::SeqCst));

        Ok(())
    }

    /// An event recorded on a stream reports completion once the stream is synchronized.
    #[test]
    fn event_query_uses_event_context() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;
        let event = ctx.create_event()?;

        event.record(&stream, EventRecordFlags::DEFAULT)?;
        stream.synchronize()?;

        assert!(event.query()?);
        Ok(())
    }

    /// `shutdown` waits for pending callbacks before returning.
    #[test]
    fn shutdown_synchronizes_and_destroys_stream() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let called = Arc::new(AtomicBool::new(false));
        stream.add_callback({
            let called = Arc::clone(&called);
            move |_status| {
                called.store(true, Ordering::SeqCst);
            }
        })?;

        stream.shutdown()?;
        assert!(called.load(Ordering::SeqCst));
        Ok(())
    }

    /// A host function enqueued on a stream has run by the time the stream is synchronized.
    #[test]
    fn launch_host_func_runs_after_stream_work() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let called = Arc::new(AtomicBool::new(false));
        stream.launch_host_func({
            let called = Arc::clone(&called);
            move || {
                called.store(true, Ordering::SeqCst);
            }
        })?;
        stream.synchronize()?;

        assert!(called.load(Ordering::SeqCst));
        Ok(())
    }

    /// The graph produced by a scoped capture is associated with the stream's context.
    #[test]
    fn scoped_capture_returns_context_associated_graph() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let graph = stream.capture(StreamCaptureMode::Relaxed, |scope| {
            assert_eq!(scope.stream().context(), ctx.as_ref());
            Ok(())
        })?;

        assert_eq!(graph.context(), Some(ctx.as_ref()));
        Ok(())
    }

    /// Capturing into a graph owned by another context fails and leaves no capture active.
    #[test]
    fn capture_to_graph_rejects_graph_from_different_context() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let other_ctx = Context::create()?;

        let stream = ctx.create_stream()?;
        let graph = other_ctx.create_graph()?;

        assert!(matches!(
            unsafe { stream.begin_capture_to_graph(&graph, &[], StreamCaptureMode::Relaxed) },
            Err(Error::GraphContextMismatch)
        ));
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Capturing into a graph with a dependency node from another graph fails.
    #[test]
    fn capture_to_graph_rejects_node_from_different_graph() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;

        let stream = ctx.create_stream()?;
        let graph = ctx.create_graph()?;
        let mut other_graph = ctx.create_graph()?;
        let other_node = other_graph.add_empty_node(&[])?;

        assert!(matches!(
            unsafe {
                stream.begin_capture_to_graph(&graph, &[other_node], StreamCaptureMode::Relaxed)
            },
            Err(Error::GraphNodeMismatch)
        ));
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Capturing into a graph with a raw node that has no graph association fails.
    #[test]
    fn capture_to_graph_rejects_unassociated_raw_dependency_node() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;

        let stream = ctx.create_stream()?;
        let graph = ctx.create_graph()?;
        let raw_node = unsafe { GraphNode::from_raw(0x1usize as _) };

        assert!(matches!(
            unsafe {
                stream.begin_capture_to_graph(&graph, &[raw_node], StreamCaptureMode::Relaxed)
            },
            Err(Error::GraphNodeMismatch)
        ));
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Updating capture dependencies with a node from another context fails.
    #[test]
    fn capture_dependency_update_rejects_node_from_different_context() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let other_ctx = Context::create()?;

        let stream = ctx.create_stream()?;
        let mut other_graph = other_ctx.create_graph()?;
        let other_node = other_graph.add_empty_node(&[])?;

        let result = stream.capture(StreamCaptureMode::Relaxed, |_scope| {
            stream.update_capture_dependencies(&[other_node])
        });

        assert!(matches!(result, Err(Error::GraphContextMismatch)));
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Updating capture dependencies with a node from a different graph fails.
    #[test]
    fn capture_dependency_update_rejects_node_from_different_graph() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;

        let stream = ctx.create_stream()?;
        let mut other_graph = ctx.create_graph()?;
        let other_node = other_graph.add_empty_node(&[])?;

        stream.begin_capture(StreamCaptureMode::Relaxed)?;
        assert!(matches!(
            stream.update_capture_dependencies(&[other_node]),
            Err(Error::GraphNodeMismatch)
        ));
        drop(stream.end_capture());
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Updating capture dependencies with a raw node that has no graph association fails.
    #[test]
    fn capture_dependency_update_rejects_unassociated_raw_node() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;

        let stream = ctx.create_stream()?;
        let raw_node = unsafe { GraphNode::from_raw(0x1usize as _) };

        stream.begin_capture(StreamCaptureMode::Relaxed)?;
        assert!(matches!(
            stream.update_capture_dependencies(&[raw_node]),
            Err(Error::GraphNodeMismatch)
        ));
        drop(stream.end_capture());
        assert_eq!(stream.capture_status()?, StreamCaptureStatus::None);
        Ok(())
    }

    /// Two threads can each run a capture on their own stream at the same time.
    #[test]
    fn captures_on_separate_streams_can_overlap() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream_a = ctx.create_stream()?;
        let stream_b = ctx.create_stream()?;

        thread::scope(|scope| {
            let a = scope.spawn(|| {
                let graph = stream_a.capture(StreamCaptureMode::Relaxed, |_scope| Ok(()))?;
                assert_eq!(graph.context(), Some(ctx.as_ref()));
                Result::<()>::Ok(())
            });
            let b = scope.spawn(|| {
                let graph = stream_b.capture(StreamCaptureMode::Relaxed, |_scope| Ok(()))?;
                assert_eq!(graph.context(), Some(ctx.as_ref()));
                Result::<()>::Ok(())
            });

            a.join().expect("capture thread panicked")?;
            b.join().expect("capture thread panicked")?;
            Result::<()>::Ok(())
        })?;

        Ok(())
    }

    /// An error returned from the capture body is propagated and the stream stays usable.
    #[test]
    fn scoped_capture_error_leaves_stream_usable() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let result = stream.capture(StreamCaptureMode::Relaxed, |_scope| {
            Err(Error::InvalidValue)
        });

        assert!(matches!(result, Err(Error::InvalidValue)));
        stream.synchronize()?;
        Ok(())
    }

    /// A panic inside the capture body unwinds to the caller and the stream stays usable.
    #[test]
    fn scoped_capture_panic_leaves_stream_usable() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let _ = stream.capture(StreamCaptureMode::Relaxed, |_scope| -> Result<()> {
                panic!("capture body panic");
            });
        }));

        assert!(result.is_err());
        stream.synchronize()?;
        Ok(())
    }

    /// Memory operations recorded during a capture are replayed when the graph is launched.
    #[test]
    fn scoped_capture_records_memory_operations() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        let input = [1u8, 2, 3, 4];
        let source = DeviceMemory::from_slice(&input)?;
        let mut copied = DeviceMemory::<u8>::zeroes(input.len())?;
        let mut filled = DeviceMemory::<u8>::zeroes(input.len())?;

        let graph = stream.capture(StreamCaptureMode::Relaxed, |scope| {
            let copy = unsafe { copied.copy_from_device_operation(&source)? };
            scope.record(copy)?;

            let memset = unsafe { filled.set_value_operation(0xab) };
            scope.record(memset)
        })?;

        let executable = graph.instantiate()?;
        executable.launch(&stream)?;
        stream.synchronize()?;

        assert_eq!(copied.copy_to_host_vec()?, input);
        assert_eq!(filled.copy_to_host_vec()?, [0xab; 4]);
        Ok(())
    }
}