Skip to main content

baracuda_runtime/
graph.rs

1//! CUDA Graphs (Runtime API).
2//!
3//! Two construction paths match the Driver-side API:
4//!
5//! 1. **Stream capture** ([`Stream::begin_capture`], [`Stream::end_capture`])
6//!    — run your work on a stream; CUDA records it into a graph instead
7//!    of executing. This is the usual path for applications.
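//! 2. **Explicit construction** ([`Graph::new`]) — start from an empty
//!    graph and add nodes by hand. Typed builders wrap the runtime's
//!    `cudaGraphAdd*Node` functions for the common node types (for
//!    example [`Graph::add_kernel_node`], [`Graph::add_memset_u32_node`]
//!    and [`Graph::add_host_node`]); [`Graph::add_node_raw`] is the
//!    escape hatch for node types without a typed builder.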
12//!
13//! Either way, [`Graph::instantiate`] compiles to [`GraphExec`], and
14//! [`GraphExec::launch`] runs it on any stream.
15
16use std::sync::Arc;
17
18use baracuda_cuda_sys::runtime::{
19    cudaGraphExec_t, cudaGraphNode_t, cudaGraph_t, runtime, types::cudaStreamCaptureStatus,
20};
21
22use crate::error::{check, Result};
23use crate::stream::Stream;
24
/// Mode selector for stream capture; mirrors the runtime's
/// `cudaStreamCaptureMode` enum.
///
/// [`CaptureMode::ThreadLocal`] is the [`Default`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CaptureMode {
    /// `cudaStreamCaptureModeGlobal`.
    Global,
    /// `cudaStreamCaptureModeThreadLocal`.
    #[default]
    ThreadLocal,
    /// `cudaStreamCaptureModeRelaxed`.
    Relaxed,
}

impl CaptureMode {
    /// Integer value handed to the runtime for this mode.
    #[inline]
    fn raw(self) -> i32 {
        match self {
            Self::Global => 0,
            Self::ThreadLocal => 1,
            Self::Relaxed => 2,
        }
    }
}
44
45impl Stream {
46    /// Begin recording operations submitted to this stream into a graph.
47    pub fn begin_capture(&self, mode: CaptureMode) -> Result<()> {
48        let r = runtime()?;
49        let cu = r.cuda_stream_begin_capture()?;
50        check(unsafe { cu(self.as_raw(), mode.raw()) })
51    }
52
53    /// Stop capture and return the graph of everything recorded.
54    pub fn end_capture(&self) -> Result<Graph> {
55        let r = runtime()?;
56        let cu = r.cuda_stream_end_capture()?;
57        let mut graph: cudaGraph_t = core::ptr::null_mut();
58        check(unsafe { cu(self.as_raw(), &mut graph) })?;
59        Ok(Graph {
60            inner: Arc::new(GraphInner { handle: graph }),
61        })
62    }
63
64    /// Convenience wrapper: run `f`, capturing its submissions to this
65    /// stream, and return the resulting graph.
66    ///
67    /// **Panic safety**: if `f` panics, the stream's capture state is
68    /// cleaned up before the panic propagates. This matters for
69    /// `CaptureMode::ThreadLocal` — without the cleanup the calling
70    /// thread is left in capture mode, and ALL subsequent CUDA
71    /// allocations on that thread fail with `cudaErrorStreamCaptureImplicit`
72    /// (906). cargo's test harness reuses threads from a pool, so a
73    /// leaked capture from a panicking test poisons every following
74    /// test that lands on the same thread.
75    pub fn capture<F>(&self, mode: CaptureMode, f: F) -> Result<Graph>
76    where
77        F: FnOnce(&Stream) -> Result<()>,
78    {
79        /// RAII guard: ensures `end_capture` runs on the unwind path
80        /// even if `f` panics. We deliberately swallow any error from
81        /// the cleanup `end_capture` — by the time Drop fires we have
82        /// no return channel, and the alternative (double panic) is
83        /// worse than a silently-dropped status code.
84        struct CaptureGuard<'a> {
85            stream: &'a Stream,
86            armed: bool,
87        }
88        impl Drop for CaptureGuard<'_> {
89            fn drop(&mut self) {
90                if self.armed {
91                    let _ = self.stream.end_capture();
92                }
93            }
94        }
95
96        self.begin_capture(mode)?;
97        let mut guard = CaptureGuard { stream: self, armed: true };
98        let inner_result = f(self);
99        // Disarm the guard before our explicit end_capture so we don't
100        // double-end on the normal path.
101        guard.armed = false;
102        let end_result = self.end_capture();
103        match (inner_result, end_result) {
104            (Ok(()), Ok(graph)) => Ok(graph),
105            (Err(e), _) => Err(e),
106            (Ok(()), Err(e)) => Err(e),
107        }
108    }
109
110    /// `true` if this stream is currently recording into a graph.
111    pub fn is_capturing(&self) -> Result<bool> {
112        let r = runtime()?;
113        let cu = r.cuda_stream_is_capturing()?;
114        let mut status: core::ffi::c_int = 0;
115        check(unsafe { cu(self.as_raw(), &mut status) })?;
116        Ok(status == cudaStreamCaptureStatus::ACTIVE)
117    }
118}
119
120/// A CUDA graph — DAG of operations, replayable via [`Graph::instantiate`].
121#[derive(Clone)]
122pub struct Graph {
123    inner: Arc<GraphInner>,
124}
125
126struct GraphInner {
127    handle: cudaGraph_t,
128}
129
130unsafe impl Send for GraphInner {}
131unsafe impl Sync for GraphInner {}
132
133impl core::fmt::Debug for GraphInner {
134    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
135        f.debug_struct("Graph")
136            .field("handle", &self.handle)
137            .finish_non_exhaustive()
138    }
139}
140
141impl core::fmt::Debug for Graph {
142    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
143        self.inner.fmt(f)
144    }
145}
146
147impl Graph {
148    /// Create an empty graph.
149    pub fn new() -> Result<Self> {
150        let r = runtime()?;
151        let cu = r.cuda_graph_create()?;
152        let mut graph: cudaGraph_t = core::ptr::null_mut();
153        check(unsafe { cu(&mut graph, 0) })?;
154        Ok(Self {
155            inner: Arc::new(GraphInner { handle: graph }),
156        })
157    }
158
159    /// Compile this graph into an executable form.
160    pub fn instantiate(&self) -> Result<GraphExec> {
161        let r = runtime()?;
162        let cu = r.cuda_graph_instantiate()?;
163        let mut exec: cudaGraphExec_t = core::ptr::null_mut();
164        check(unsafe { cu(&mut exec, self.inner.handle, 0) })?;
165        Ok(GraphExec {
166            inner: Arc::new(GraphExecInner { handle: exec }),
167        })
168    }
169
170    /// Approximate node count (for debugging).
171    pub fn node_count(&self) -> Result<usize> {
172        let r = runtime()?;
173        let cu = r.cuda_graph_get_nodes()?;
174        let mut count: usize = 0;
175        check(unsafe { cu(self.inner.handle, core::ptr::null_mut(), &mut count) })?;
176        Ok(count)
177    }
178
179    #[inline]
180    pub fn as_raw(&self) -> cudaGraph_t {
181        self.inner.handle
182    }
183
184    /// Add an empty "barrier" node with the given dependencies.
185    pub fn add_empty_node(&self, dependencies: &[GraphNode]) -> Result<GraphNode> {
186        let r = runtime()?;
187        let cu = r.cuda_graph_add_empty_node()?;
188        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
189        let (dp, dl) = deps_raw(&deps);
190        let mut node: cudaGraphNode_t = core::ptr::null_mut();
191        check(unsafe { cu(&mut node, self.inner.handle, dp, dl) })?;
192        Ok(GraphNode { raw: node })
193    }
194
195    /// Add a kernel-launch node.
196    ///
197    /// # Safety
198    ///
199    /// Same discipline as [`crate::LaunchBuilder::launch`]: argument
200    /// count/order/types must match the kernel's C signature.
201    pub unsafe fn add_kernel_node(
202        &self,
203        dependencies: &[GraphNode],
204        kernel: &crate::Kernel,
205        grid: crate::Dim3,
206        block: crate::Dim3,
207        shared_mem_bytes: u32,
208        args: &mut [*mut core::ffi::c_void],
209    ) -> Result<GraphNode> { unsafe {
210        use baracuda_cuda_sys::runtime::types::{cudaKernelNodeParams, dim3};
211        let r = runtime()?;
212        let cu = r.cuda_graph_add_kernel_node()?;
213        let params = cudaKernelNodeParams {
214            func: kernel.as_launch_ptr() as *mut core::ffi::c_void,
215            grid_dim: dim3::new(grid.x, grid.y, grid.z),
216            block_dim: dim3::new(block.x, block.y, block.z),
217            shared_mem_bytes,
218            kernel_params: if args.is_empty() {
219                core::ptr::null_mut()
220            } else {
221                args.as_mut_ptr()
222            },
223            extra: core::ptr::null_mut(),
224        };
225        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
226        let (dp, dl) = deps_raw(&deps);
227        let mut node: cudaGraphNode_t = core::ptr::null_mut();
228        check(cu(&mut node, self.inner.handle, dp, dl, &params))?;
229        Ok(GraphNode { raw: node })
230    }}
231
232    /// Add a memset node filling `count` 4-byte words at `dst` with `value`.
233    pub fn add_memset_u32_node(
234        &self,
235        dependencies: &[GraphNode],
236        dst: *mut core::ffi::c_void,
237        value: u32,
238        count: usize,
239    ) -> Result<GraphNode> {
240        use baracuda_cuda_sys::runtime::types::cudaMemsetParams;
241        let r = runtime()?;
242        let cu = r.cuda_graph_add_memset_node()?;
243        let params = cudaMemsetParams {
244            dst,
245            pitch: 0,
246            value,
247            element_size: 4,
248            width: count,
249            height: 1,
250        };
251        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
252        let (dp, dl) = deps_raw(&deps);
253        let mut node: cudaGraphNode_t = core::ptr::null_mut();
254        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, &params) })?;
255        Ok(GraphNode { raw: node })
256    }
257
258    /// Add a host-function node. `fn_` runs on a driver-owned thread
259    /// when this node executes.
260    ///
261    /// # Safety
262    ///
263    /// `fn_` must remain callable with `user_data` as long as any
264    /// `GraphExec` derived from this graph is alive.
265    pub unsafe fn add_host_node(
266        &self,
267        dependencies: &[GraphNode],
268        fn_: unsafe extern "C" fn(*mut core::ffi::c_void),
269        user_data: *mut core::ffi::c_void,
270    ) -> Result<GraphNode> { unsafe {
271        use baracuda_cuda_sys::runtime::types::cudaHostNodeParams;
272        let r = runtime()?;
273        let cu = r.cuda_graph_add_host_node()?;
274        let params = cudaHostNodeParams {
275            fn_: Some(fn_),
276            user_data,
277        };
278        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
279        let (dp, dl) = deps_raw(&deps);
280        let mut node: cudaGraphNode_t = core::ptr::null_mut();
281        check(cu(&mut node, self.inner.handle, dp, dl, &params))?;
282        Ok(GraphNode { raw: node })
283    }}
284
285    /// Add a child-graph node.
286    pub fn add_child_graph_node(
287        &self,
288        dependencies: &[GraphNode],
289        child: &Graph,
290    ) -> Result<GraphNode> {
291        let r = runtime()?;
292        let cu = r.cuda_graph_add_child_graph_node()?;
293        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
294        let (dp, dl) = deps_raw(&deps);
295        let mut node: cudaGraphNode_t = core::ptr::null_mut();
296        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, child.as_raw()) })?;
297        Ok(GraphNode { raw: node })
298    }
299
300    /// Add an event-record node.
301    pub fn add_event_record_node(
302        &self,
303        dependencies: &[GraphNode],
304        event: &crate::Event,
305    ) -> Result<GraphNode> {
306        let r = runtime()?;
307        let cu = r.cuda_graph_add_event_record_node()?;
308        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
309        let (dp, dl) = deps_raw(&deps);
310        let mut node: cudaGraphNode_t = core::ptr::null_mut();
311        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, event.as_raw()) })?;
312        Ok(GraphNode { raw: node })
313    }
314
315    /// Add an event-wait node.
316    pub fn add_event_wait_node(
317        &self,
318        dependencies: &[GraphNode],
319        event: &crate::Event,
320    ) -> Result<GraphNode> {
321        let r = runtime()?;
322        let cu = r.cuda_graph_add_event_wait_node()?;
323        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
324        let (dp, dl) = deps_raw(&deps);
325        let mut node: cudaGraphNode_t = core::ptr::null_mut();
326        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, event.as_raw()) })?;
327        Ok(GraphNode { raw: node })
328    }
329
330    /// Add a stream-ordered mem-alloc node. Returns the node plus the
331    /// device pointer the node will allocate at launch time.
332    pub fn add_mem_alloc_node(
333        &self,
334        dependencies: &[GraphNode],
335        device: &crate::Device,
336        bytesize: usize,
337    ) -> Result<(GraphNode, *mut core::ffi::c_void)> {
338        use baracuda_cuda_sys::runtime::types::{
339            cudaMemAllocNodeParams, cudaMemAllocationHandleType, cudaMemAllocationType,
340            cudaMemLocation, cudaMemLocationType, cudaMemPoolProps,
341        };
342        let r = runtime()?;
343        let cu = r.cuda_graph_add_mem_alloc_node()?;
344        let mut params = cudaMemAllocNodeParams {
345            pool_props: cudaMemPoolProps {
346                alloc_type: cudaMemAllocationType::PINNED,
347                handle_types: cudaMemAllocationHandleType::NONE,
348                location: cudaMemLocation {
349                    type_: cudaMemLocationType::DEVICE,
350                    id: device.ordinal(),
351                },
352                ..Default::default()
353            },
354            access_descs: core::ptr::null(),
355            access_desc_count: 0,
356            bytesize,
357            dptr: core::ptr::null_mut(),
358        };
359        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
360        let (dp, dl) = deps_raw(&deps);
361        let mut node: cudaGraphNode_t = core::ptr::null_mut();
362        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, &mut params) })?;
363        Ok((GraphNode { raw: node }, params.dptr))
364    }
365
366    /// Add a stream-ordered mem-free node for `dptr`.
367    ///
368    /// # Safety
369    ///
370    /// `dptr` must be a pointer returned by a prior mem-alloc node in
371    /// this graph.
372    pub unsafe fn add_mem_free_node(
373        &self,
374        dependencies: &[GraphNode],
375        dptr: *mut core::ffi::c_void,
376    ) -> Result<GraphNode> { unsafe {
377        let r = runtime()?;
378        let cu = r.cuda_graph_add_mem_free_node()?;
379        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
380        let (dp, dl) = deps_raw(&deps);
381        let mut node: cudaGraphNode_t = core::ptr::null_mut();
382        check(cu(&mut node, self.inner.handle, dp, dl, dptr))?;
383        Ok(GraphNode { raw: node })
384    }}
385
386    /// Create a conditional-node handle (CUDA 12.3+). The returned u64
387    /// is an opaque driver handle used by `cuGraphAddNode`-style
388    /// conditional-node construction, which the runtime exposes via
389    /// `cudaGraphAddNode`. `default_launch_value` is the handle's
390    /// starting value (commonly 0 = "don't execute"), `flags` = 0 for
391    /// the default. Returns [`crate::Error::FeatureNotSupported`] on
392    /// older CUDA.
393    pub fn conditional_handle_create(&self, default_launch_value: u32, flags: u32) -> Result<u64> {
394        use baracuda_types::{supports, Feature};
395        let installed = crate::init::driver_version()?;
396        if !supports(installed, Feature::GraphConditionalNodes) {
397            return Err(crate::error::Error::FeatureNotSupported {
398                api: "cudaGraphConditionalHandleCreate",
399                since: Feature::GraphConditionalNodes.required_version(),
400            });
401        }
402        let r = runtime()?;
403        let cu = r.cuda_graph_conditional_handle_create()?;
404        let mut handle: u64 = 0;
405        check(unsafe { cu(&mut handle, self.inner.handle, default_launch_value, flags) })?;
406        Ok(handle)
407    }
408
409    /// Low-level `cudaGraphAddNode` — add a node from a tagged
410    /// `cudaGraphNodeParams` struct. baracuda-runtime exposes typed
411    /// builders for the common node types (kernel, memset, host, etc.);
412    /// this escape hatch exists for the node types the typed API does
413    /// not cover (notably conditional nodes on CUDA 12.3+).
414    ///
415    /// # Safety
416    ///
417    /// `node_params` must point at a correctly-tagged
418    /// `cudaGraphNodeParams` whose union payload matches the `type`
419    /// field.
420    pub unsafe fn add_node_raw(
421        &self,
422        dependencies: &[GraphNode],
423        node_params: *mut core::ffi::c_void,
424    ) -> Result<GraphNode> { unsafe {
425        let r = runtime()?;
426        let cu = r.cuda_graph_add_node()?;
427        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
428        let (dp, dl) = deps_raw(&deps);
429        let mut node: cudaGraphNode_t = core::ptr::null_mut();
430        check(cu(&mut node, self.inner.handle, dp, dl, node_params))?;
431        Ok(GraphNode { raw: node })
432    }}
433
434    /// Add dependency edges `from[i] -> to[i]`.
435    pub fn add_dependencies(&self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
436        assert_eq!(from.len(), to.len());
437        if from.is_empty() {
438            return Ok(());
439        }
440        let r = runtime()?;
441        let cu = r.cuda_graph_add_dependencies()?;
442        let f: Vec<_> = from.iter().map(|n| n.raw).collect();
443        let t: Vec<_> = to.iter().map(|n| n.raw).collect();
444        check(unsafe { cu(self.inner.handle, f.as_ptr(), t.as_ptr(), f.len()) })
445    }
446}
447
448fn deps_raw(deps: &[cudaGraphNode_t]) -> (*const cudaGraphNode_t, usize) {
449    if deps.is_empty() {
450        (core::ptr::null(), 0)
451    } else {
452        (deps.as_ptr(), deps.len())
453    }
454}
455
456/// A node inside a [`Graph`]. Lightweight `Copy` handle that borrows the
457/// parent graph's storage.
458#[derive(Copy, Clone, Debug)]
459pub struct GraphNode {
460    raw: cudaGraphNode_t,
461}
462
463impl GraphNode {
464    #[inline]
465    pub fn as_raw(&self) -> cudaGraphNode_t {
466        self.raw
467    }
468
469    /// Return the `cudaGraphNodeType` integer. Compare against CUDA's
470    /// enum values in the runtime API docs (0=Kernel, 1=Memcpy, 2=Memset,
471    /// 3=Host, 4=Graph, 5=Empty, 6=WaitEvent, 7=EventRecord, 10=MemAlloc,
472    /// 11=MemFree).
473    pub fn node_type(&self) -> Result<i32> {
474        let r = runtime()?;
475        let cu = r.cuda_graph_node_get_type()?;
476        let mut t: core::ffi::c_int = 0;
477        check(unsafe { cu(self.raw, &mut t) })?;
478        Ok(t)
479    }
480
481    /// For `MemFree` nodes: return the device pointer this node will free.
482    pub fn mem_free_ptr(&self) -> Result<*mut core::ffi::c_void> {
483        let r = runtime()?;
484        let cu = r.cuda_graph_mem_free_node_get_params()?;
485        let mut p: *mut core::ffi::c_void = core::ptr::null_mut();
486        check(unsafe { cu(self.raw, &mut p) })?;
487        Ok(p)
488    }
489}
490
491impl Drop for GraphInner {
492    fn drop(&mut self) {
493        if let Ok(r) = runtime() {
494            if let Ok(cu) = r.cuda_graph_destroy() {
495                let _ = unsafe { cu(self.handle) };
496            }
497        }
498    }
499}
500
501/// An instantiated (executable) CUDA graph.
502#[derive(Clone)]
503pub struct GraphExec {
504    inner: Arc<GraphExecInner>,
505}
506
507struct GraphExecInner {
508    handle: cudaGraphExec_t,
509}
510
511unsafe impl Send for GraphExecInner {}
512unsafe impl Sync for GraphExecInner {}
513
514impl core::fmt::Debug for GraphExecInner {
515    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
516        f.debug_struct("GraphExec")
517            .field("handle", &self.handle)
518            .finish_non_exhaustive()
519    }
520}
521
522impl core::fmt::Debug for GraphExec {
523    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
524        self.inner.fmt(f)
525    }
526}
527
528impl GraphExec {
529    /// Launch this graph on `stream`.
530    pub fn launch(&self, stream: &Stream) -> Result<()> {
531        let r = runtime()?;
532        let cu = r.cuda_graph_launch()?;
533        check(unsafe { cu(self.inner.handle, stream.as_raw()) })
534    }
535
536    /// Attempt to update this executable graph in place from a
537    /// topology-identical template. On refusal the exec is left
538    /// unchanged; inspect [`UpdateResult`] for the specific reason.
539    pub fn update(&self, new_template: &Graph) -> Result<UpdateResult> {
540        let r = runtime()?;
541        let cu = r.cuda_graph_exec_update()?;
542        let mut error_node: cudaGraphNode_t = core::ptr::null_mut();
543        let mut result: core::ffi::c_int = 0;
544        // cudaGraphExecUpdate may return non-SUCCESS even when the
545        // `result` field has its own more informative value (per CUDA
546        // docs). We propagate the returned rc only when `result` is
547        // `SUCCESS` — otherwise we return the parsed result.
548        let rc = unsafe {
549            cu(
550                self.inner.handle,
551                new_template.as_raw(),
552                &mut error_node,
553                &mut result,
554            )
555        };
556        if rc != baracuda_cuda_sys::runtime::cudaError_t::Success
557            && result == baracuda_cuda_sys::runtime::types::cudaGraphExecUpdateResult::SUCCESS
558        {
559            return Err(crate::error::Error::Status { status: rc });
560        }
561        Ok(UpdateResult {
562            result,
563            error_node: if error_node.is_null() {
564                None
565            } else {
566                Some(GraphNode { raw: error_node })
567            },
568        })
569    }
570
571    #[inline]
572    pub fn as_raw(&self) -> cudaGraphExec_t {
573        self.inner.handle
574    }
575}
576
577/// Outcome of [`GraphExec::update`]. `result` is a
578/// `cudaGraphExecUpdateResult` code — `SUCCESS` (0) means the executable
579/// graph was patched in place.
580#[derive(Clone, Debug)]
581pub struct UpdateResult {
582    pub result: core::ffi::c_int,
583    pub error_node: Option<GraphNode>,
584}
585
586impl UpdateResult {
587    pub fn is_success(&self) -> bool {
588        self.result == baracuda_cuda_sys::runtime::types::cudaGraphExecUpdateResult::SUCCESS
589    }
590}
591
592impl Drop for GraphExecInner {
593    fn drop(&mut self) {
594        if let Ok(r) = runtime() {
595            if let Ok(cu) = r.cuda_graph_exec_destroy() {
596                let _ = unsafe { cu(self.handle) };
597            }
598        }
599    }
600}