Skip to main content

baracuda_runtime/
graph.rs

1//! CUDA Graphs (Runtime API).
2//!
3//! Two construction paths match the Driver-side API:
4//!
5//! 1. **Stream capture** ([`Stream::begin_capture`], [`Stream::end_capture`])
6//!    — run your work on a stream; CUDA records it into a graph instead
7//!    of executing. This is the usual path for applications.
//! 2. **Explicit construction** ([`Graph::new`]) — start from an empty
//!    graph and add nodes by hand with the typed builders
//!    ([`Graph::add_kernel_node`], [`Graph::add_memset_u32_node`],
//!    [`Graph::add_host_node`], …) or, for node kinds without a typed
//!    builder, the raw [`Graph::add_node_raw`] escape hatch.
12//!
13//! Either way, [`Graph::instantiate`] compiles to [`GraphExec`], and
14//! [`GraphExec::launch`] runs it on any stream.
15
16use std::sync::Arc;
17
18use baracuda_cuda_sys::runtime::{
19    cudaGraphExec_t, cudaGraphNode_t, cudaGraph_t, runtime, types::cudaStreamCaptureStatus,
20};
21
22use crate::error::{check, Result};
23use crate::stream::Stream;
24
/// Stream-capture mode, mirroring CUDA's `cudaStreamCaptureMode`.
///
/// The discriminants are the raw values handed to
/// `cudaStreamBeginCapture`. [`CaptureMode::ThreadLocal`] is this crate's
/// default.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum CaptureMode {
    /// `cudaStreamCaptureModeGlobal` (raw value 0).
    Global = 0,
    /// `cudaStreamCaptureModeThreadLocal` (raw value 1).
    #[default]
    ThreadLocal = 1,
    /// `cudaStreamCaptureModeRelaxed` (raw value 2).
    Relaxed = 2,
}

impl CaptureMode {
    /// Raw `cudaStreamCaptureMode` integer for FFI calls.
    #[inline]
    fn raw(self) -> i32 {
        // Fieldless enum with explicit discriminants: the cast is exact.
        self as i32
    }
}
44
45impl Stream {
46    /// Begin recording operations submitted to this stream into a graph.
47    pub fn begin_capture(&self, mode: CaptureMode) -> Result<()> {
48        let r = runtime()?;
49        let cu = r.cuda_stream_begin_capture()?;
50        check(unsafe { cu(self.as_raw(), mode.raw()) })
51    }
52
53    /// Stop capture and return the graph of everything recorded.
54    pub fn end_capture(&self) -> Result<Graph> {
55        let r = runtime()?;
56        let cu = r.cuda_stream_end_capture()?;
57        let mut graph: cudaGraph_t = core::ptr::null_mut();
58        check(unsafe { cu(self.as_raw(), &mut graph) })?;
59        Ok(Graph {
60            inner: Arc::new(GraphInner { handle: graph }),
61        })
62    }
63
64    /// Convenience wrapper: run `f`, capturing its submissions to this
65    /// stream, and return the resulting graph.
66    pub fn capture<F>(&self, mode: CaptureMode, f: F) -> Result<Graph>
67    where
68        F: FnOnce(&Stream) -> Result<()>,
69    {
70        self.begin_capture(mode)?;
71        let inner_result = f(self);
72        let end_result = self.end_capture();
73        match (inner_result, end_result) {
74            (Ok(()), Ok(graph)) => Ok(graph),
75            (Err(e), _) => Err(e),
76            (Ok(()), Err(e)) => Err(e),
77        }
78    }
79
80    /// `true` if this stream is currently recording into a graph.
81    pub fn is_capturing(&self) -> Result<bool> {
82        let r = runtime()?;
83        let cu = r.cuda_stream_is_capturing()?;
84        let mut status: core::ffi::c_int = 0;
85        check(unsafe { cu(self.as_raw(), &mut status) })?;
86        Ok(status == cudaStreamCaptureStatus::ACTIVE)
87    }
88}
89
90/// A CUDA graph — DAG of operations, replayable via [`Graph::instantiate`].
91#[derive(Clone)]
92pub struct Graph {
93    inner: Arc<GraphInner>,
94}
95
96struct GraphInner {
97    handle: cudaGraph_t,
98}
99
100unsafe impl Send for GraphInner {}
101unsafe impl Sync for GraphInner {}
102
103impl core::fmt::Debug for GraphInner {
104    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
105        f.debug_struct("Graph")
106            .field("handle", &self.handle)
107            .finish_non_exhaustive()
108    }
109}
110
111impl core::fmt::Debug for Graph {
112    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
113        self.inner.fmt(f)
114    }
115}
116
117impl Graph {
118    /// Create an empty graph.
119    pub fn new() -> Result<Self> {
120        let r = runtime()?;
121        let cu = r.cuda_graph_create()?;
122        let mut graph: cudaGraph_t = core::ptr::null_mut();
123        check(unsafe { cu(&mut graph, 0) })?;
124        Ok(Self {
125            inner: Arc::new(GraphInner { handle: graph }),
126        })
127    }
128
129    /// Compile this graph into an executable form.
130    pub fn instantiate(&self) -> Result<GraphExec> {
131        let r = runtime()?;
132        let cu = r.cuda_graph_instantiate()?;
133        let mut exec: cudaGraphExec_t = core::ptr::null_mut();
134        check(unsafe { cu(&mut exec, self.inner.handle, 0) })?;
135        Ok(GraphExec {
136            inner: Arc::new(GraphExecInner { handle: exec }),
137        })
138    }
139
140    /// Approximate node count (for debugging).
141    pub fn node_count(&self) -> Result<usize> {
142        let r = runtime()?;
143        let cu = r.cuda_graph_get_nodes()?;
144        let mut count: usize = 0;
145        check(unsafe { cu(self.inner.handle, core::ptr::null_mut(), &mut count) })?;
146        Ok(count)
147    }
148
149    #[inline]
150    pub fn as_raw(&self) -> cudaGraph_t {
151        self.inner.handle
152    }
153
154    /// Add an empty "barrier" node with the given dependencies.
155    pub fn add_empty_node(&self, dependencies: &[GraphNode]) -> Result<GraphNode> {
156        let r = runtime()?;
157        let cu = r.cuda_graph_add_empty_node()?;
158        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
159        let (dp, dl) = deps_raw(&deps);
160        let mut node: cudaGraphNode_t = core::ptr::null_mut();
161        check(unsafe { cu(&mut node, self.inner.handle, dp, dl) })?;
162        Ok(GraphNode { raw: node })
163    }
164
165    /// Add a kernel-launch node.
166    ///
167    /// # Safety
168    ///
169    /// Same discipline as [`crate::LaunchBuilder::launch`]: argument
170    /// count/order/types must match the kernel's C signature.
171    pub unsafe fn add_kernel_node(
172        &self,
173        dependencies: &[GraphNode],
174        kernel: &crate::Kernel,
175        grid: crate::Dim3,
176        block: crate::Dim3,
177        shared_mem_bytes: u32,
178        args: &mut [*mut core::ffi::c_void],
179    ) -> Result<GraphNode> {
180        use baracuda_cuda_sys::runtime::types::{cudaKernelNodeParams, dim3};
181        let r = runtime()?;
182        let cu = r.cuda_graph_add_kernel_node()?;
183        let params = cudaKernelNodeParams {
184            func: kernel.as_launch_ptr() as *mut core::ffi::c_void,
185            grid_dim: dim3::new(grid.x, grid.y, grid.z),
186            block_dim: dim3::new(block.x, block.y, block.z),
187            shared_mem_bytes,
188            kernel_params: if args.is_empty() {
189                core::ptr::null_mut()
190            } else {
191                args.as_mut_ptr()
192            },
193            extra: core::ptr::null_mut(),
194        };
195        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
196        let (dp, dl) = deps_raw(&deps);
197        let mut node: cudaGraphNode_t = core::ptr::null_mut();
198        check(cu(&mut node, self.inner.handle, dp, dl, &params))?;
199        Ok(GraphNode { raw: node })
200    }
201
202    /// Add a memset node filling `count` 4-byte words at `dst` with `value`.
203    pub fn add_memset_u32_node(
204        &self,
205        dependencies: &[GraphNode],
206        dst: *mut core::ffi::c_void,
207        value: u32,
208        count: usize,
209    ) -> Result<GraphNode> {
210        use baracuda_cuda_sys::runtime::types::cudaMemsetParams;
211        let r = runtime()?;
212        let cu = r.cuda_graph_add_memset_node()?;
213        let params = cudaMemsetParams {
214            dst,
215            pitch: 0,
216            value,
217            element_size: 4,
218            width: count,
219            height: 1,
220        };
221        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
222        let (dp, dl) = deps_raw(&deps);
223        let mut node: cudaGraphNode_t = core::ptr::null_mut();
224        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, &params) })?;
225        Ok(GraphNode { raw: node })
226    }
227
228    /// Add a host-function node. `fn_` runs on a driver-owned thread
229    /// when this node executes.
230    ///
231    /// # Safety
232    ///
233    /// `fn_` must remain callable with `user_data` as long as any
234    /// `GraphExec` derived from this graph is alive.
235    pub unsafe fn add_host_node(
236        &self,
237        dependencies: &[GraphNode],
238        fn_: unsafe extern "C" fn(*mut core::ffi::c_void),
239        user_data: *mut core::ffi::c_void,
240    ) -> Result<GraphNode> {
241        use baracuda_cuda_sys::runtime::types::cudaHostNodeParams;
242        let r = runtime()?;
243        let cu = r.cuda_graph_add_host_node()?;
244        let params = cudaHostNodeParams {
245            fn_: Some(fn_),
246            user_data,
247        };
248        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
249        let (dp, dl) = deps_raw(&deps);
250        let mut node: cudaGraphNode_t = core::ptr::null_mut();
251        check(cu(&mut node, self.inner.handle, dp, dl, &params))?;
252        Ok(GraphNode { raw: node })
253    }
254
255    /// Add a child-graph node.
256    pub fn add_child_graph_node(
257        &self,
258        dependencies: &[GraphNode],
259        child: &Graph,
260    ) -> Result<GraphNode> {
261        let r = runtime()?;
262        let cu = r.cuda_graph_add_child_graph_node()?;
263        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
264        let (dp, dl) = deps_raw(&deps);
265        let mut node: cudaGraphNode_t = core::ptr::null_mut();
266        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, child.as_raw()) })?;
267        Ok(GraphNode { raw: node })
268    }
269
270    /// Add an event-record node.
271    pub fn add_event_record_node(
272        &self,
273        dependencies: &[GraphNode],
274        event: &crate::Event,
275    ) -> Result<GraphNode> {
276        let r = runtime()?;
277        let cu = r.cuda_graph_add_event_record_node()?;
278        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
279        let (dp, dl) = deps_raw(&deps);
280        let mut node: cudaGraphNode_t = core::ptr::null_mut();
281        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, event.as_raw()) })?;
282        Ok(GraphNode { raw: node })
283    }
284
285    /// Add an event-wait node.
286    pub fn add_event_wait_node(
287        &self,
288        dependencies: &[GraphNode],
289        event: &crate::Event,
290    ) -> Result<GraphNode> {
291        let r = runtime()?;
292        let cu = r.cuda_graph_add_event_wait_node()?;
293        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
294        let (dp, dl) = deps_raw(&deps);
295        let mut node: cudaGraphNode_t = core::ptr::null_mut();
296        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, event.as_raw()) })?;
297        Ok(GraphNode { raw: node })
298    }
299
300    /// Add a stream-ordered mem-alloc node. Returns the node plus the
301    /// device pointer the node will allocate at launch time.
302    pub fn add_mem_alloc_node(
303        &self,
304        dependencies: &[GraphNode],
305        device: &crate::Device,
306        bytesize: usize,
307    ) -> Result<(GraphNode, *mut core::ffi::c_void)> {
308        use baracuda_cuda_sys::runtime::types::{
309            cudaMemAllocNodeParams, cudaMemAllocationHandleType, cudaMemAllocationType,
310            cudaMemLocation, cudaMemLocationType, cudaMemPoolProps,
311        };
312        let r = runtime()?;
313        let cu = r.cuda_graph_add_mem_alloc_node()?;
314        let mut params = cudaMemAllocNodeParams {
315            pool_props: cudaMemPoolProps {
316                alloc_type: cudaMemAllocationType::PINNED,
317                handle_types: cudaMemAllocationHandleType::NONE,
318                location: cudaMemLocation {
319                    type_: cudaMemLocationType::DEVICE,
320                    id: device.ordinal(),
321                },
322                ..Default::default()
323            },
324            access_descs: core::ptr::null(),
325            access_desc_count: 0,
326            bytesize,
327            dptr: core::ptr::null_mut(),
328        };
329        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
330        let (dp, dl) = deps_raw(&deps);
331        let mut node: cudaGraphNode_t = core::ptr::null_mut();
332        check(unsafe { cu(&mut node, self.inner.handle, dp, dl, &mut params) })?;
333        Ok((GraphNode { raw: node }, params.dptr))
334    }
335
336    /// Add a stream-ordered mem-free node for `dptr`.
337    ///
338    /// # Safety
339    ///
340    /// `dptr` must be a pointer returned by a prior mem-alloc node in
341    /// this graph.
342    pub unsafe fn add_mem_free_node(
343        &self,
344        dependencies: &[GraphNode],
345        dptr: *mut core::ffi::c_void,
346    ) -> Result<GraphNode> {
347        let r = runtime()?;
348        let cu = r.cuda_graph_add_mem_free_node()?;
349        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
350        let (dp, dl) = deps_raw(&deps);
351        let mut node: cudaGraphNode_t = core::ptr::null_mut();
352        check(cu(&mut node, self.inner.handle, dp, dl, dptr))?;
353        Ok(GraphNode { raw: node })
354    }
355
356    /// Create a conditional-node handle (CUDA 12.3+). The returned u64
357    /// is an opaque driver handle used by `cuGraphAddNode`-style
358    /// conditional-node construction, which the runtime exposes via
359    /// `cudaGraphAddNode`. `default_launch_value` is the handle's
360    /// starting value (commonly 0 = "don't execute"), `flags` = 0 for
361    /// the default. Returns [`crate::Error::FeatureNotSupported`] on
362    /// older CUDA.
363    pub fn conditional_handle_create(&self, default_launch_value: u32, flags: u32) -> Result<u64> {
364        use baracuda_types::{supports, Feature};
365        let installed = crate::init::driver_version()?;
366        if !supports(installed, Feature::GraphConditionalNodes) {
367            return Err(crate::error::Error::FeatureNotSupported {
368                api: "cudaGraphConditionalHandleCreate",
369                since: Feature::GraphConditionalNodes.required_version(),
370            });
371        }
372        let r = runtime()?;
373        let cu = r.cuda_graph_conditional_handle_create()?;
374        let mut handle: u64 = 0;
375        check(unsafe { cu(&mut handle, self.inner.handle, default_launch_value, flags) })?;
376        Ok(handle)
377    }
378
379    /// Low-level `cudaGraphAddNode` — add a node from a tagged
380    /// `cudaGraphNodeParams` struct. baracuda-runtime exposes typed
381    /// builders for the common node types (kernel, memset, host, etc.);
382    /// this escape hatch exists for the node types the typed API does
383    /// not cover (notably conditional nodes on CUDA 12.3+).
384    ///
385    /// # Safety
386    ///
387    /// `node_params` must point at a correctly-tagged
388    /// `cudaGraphNodeParams` whose union payload matches the `type`
389    /// field.
390    pub unsafe fn add_node_raw(
391        &self,
392        dependencies: &[GraphNode],
393        node_params: *mut core::ffi::c_void,
394    ) -> Result<GraphNode> {
395        let r = runtime()?;
396        let cu = r.cuda_graph_add_node()?;
397        let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
398        let (dp, dl) = deps_raw(&deps);
399        let mut node: cudaGraphNode_t = core::ptr::null_mut();
400        check(cu(&mut node, self.inner.handle, dp, dl, node_params))?;
401        Ok(GraphNode { raw: node })
402    }
403
404    /// Add dependency edges `from[i] -> to[i]`.
405    pub fn add_dependencies(&self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
406        assert_eq!(from.len(), to.len());
407        if from.is_empty() {
408            return Ok(());
409        }
410        let r = runtime()?;
411        let cu = r.cuda_graph_add_dependencies()?;
412        let f: Vec<_> = from.iter().map(|n| n.raw).collect();
413        let t: Vec<_> = to.iter().map(|n| n.raw).collect();
414        check(unsafe { cu(self.inner.handle, f.as_ptr(), t.as_ptr(), f.len()) })
415    }
416}
417
418fn deps_raw(deps: &[cudaGraphNode_t]) -> (*const cudaGraphNode_t, usize) {
419    if deps.is_empty() {
420        (core::ptr::null(), 0)
421    } else {
422        (deps.as_ptr(), deps.len())
423    }
424}
425
426/// A node inside a [`Graph`]. Lightweight `Copy` handle that borrows the
427/// parent graph's storage.
428#[derive(Copy, Clone, Debug)]
429pub struct GraphNode {
430    raw: cudaGraphNode_t,
431}
432
433impl GraphNode {
434    #[inline]
435    pub fn as_raw(&self) -> cudaGraphNode_t {
436        self.raw
437    }
438
439    /// Return the `cudaGraphNodeType` integer. Compare against CUDA's
440    /// enum values in the runtime API docs (0=Kernel, 1=Memcpy, 2=Memset,
441    /// 3=Host, 4=Graph, 5=Empty, 6=WaitEvent, 7=EventRecord, 10=MemAlloc,
442    /// 11=MemFree).
443    pub fn node_type(&self) -> Result<i32> {
444        let r = runtime()?;
445        let cu = r.cuda_graph_node_get_type()?;
446        let mut t: core::ffi::c_int = 0;
447        check(unsafe { cu(self.raw, &mut t) })?;
448        Ok(t)
449    }
450
451    /// For `MemFree` nodes: return the device pointer this node will free.
452    pub fn mem_free_ptr(&self) -> Result<*mut core::ffi::c_void> {
453        let r = runtime()?;
454        let cu = r.cuda_graph_mem_free_node_get_params()?;
455        let mut p: *mut core::ffi::c_void = core::ptr::null_mut();
456        check(unsafe { cu(self.raw, &mut p) })?;
457        Ok(p)
458    }
459}
460
461impl Drop for GraphInner {
462    fn drop(&mut self) {
463        if let Ok(r) = runtime() {
464            if let Ok(cu) = r.cuda_graph_destroy() {
465                let _ = unsafe { cu(self.handle) };
466            }
467        }
468    }
469}
470
471/// An instantiated (executable) CUDA graph.
472#[derive(Clone)]
473pub struct GraphExec {
474    inner: Arc<GraphExecInner>,
475}
476
477struct GraphExecInner {
478    handle: cudaGraphExec_t,
479}
480
481unsafe impl Send for GraphExecInner {}
482unsafe impl Sync for GraphExecInner {}
483
484impl core::fmt::Debug for GraphExecInner {
485    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
486        f.debug_struct("GraphExec")
487            .field("handle", &self.handle)
488            .finish_non_exhaustive()
489    }
490}
491
492impl core::fmt::Debug for GraphExec {
493    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
494        self.inner.fmt(f)
495    }
496}
497
498impl GraphExec {
499    /// Launch this graph on `stream`.
500    pub fn launch(&self, stream: &Stream) -> Result<()> {
501        let r = runtime()?;
502        let cu = r.cuda_graph_launch()?;
503        check(unsafe { cu(self.inner.handle, stream.as_raw()) })
504    }
505
506    /// Attempt to update this executable graph in place from a
507    /// topology-identical template. On refusal the exec is left
508    /// unchanged; inspect [`UpdateResult`] for the specific reason.
509    pub fn update(&self, new_template: &Graph) -> Result<UpdateResult> {
510        let r = runtime()?;
511        let cu = r.cuda_graph_exec_update()?;
512        let mut error_node: cudaGraphNode_t = core::ptr::null_mut();
513        let mut result: core::ffi::c_int = 0;
514        // cudaGraphExecUpdate may return non-SUCCESS even when the
515        // `result` field has its own more informative value (per CUDA
516        // docs). We propagate the returned rc only when `result` is
517        // `SUCCESS` — otherwise we return the parsed result.
518        let rc = unsafe {
519            cu(
520                self.inner.handle,
521                new_template.as_raw(),
522                &mut error_node,
523                &mut result,
524            )
525        };
526        if rc != baracuda_cuda_sys::runtime::cudaError_t::Success
527            && result == baracuda_cuda_sys::runtime::types::cudaGraphExecUpdateResult::SUCCESS
528        {
529            return Err(crate::error::Error::Status { status: rc });
530        }
531        Ok(UpdateResult {
532            result,
533            error_node: if error_node.is_null() {
534                None
535            } else {
536                Some(GraphNode { raw: error_node })
537            },
538        })
539    }
540
541    #[inline]
542    pub fn as_raw(&self) -> cudaGraphExec_t {
543        self.inner.handle
544    }
545}
546
547/// Outcome of [`GraphExec::update`]. `result` is a
548/// `cudaGraphExecUpdateResult` code — `SUCCESS` (0) means the executable
549/// graph was patched in place.
550#[derive(Clone, Debug)]
551pub struct UpdateResult {
552    pub result: core::ffi::c_int,
553    pub error_node: Option<GraphNode>,
554}
555
556impl UpdateResult {
557    pub fn is_success(&self) -> bool {
558        self.result == baracuda_cuda_sys::runtime::types::cudaGraphExecUpdateResult::SUCCESS
559    }
560}
561
562impl Drop for GraphExecInner {
563    fn drop(&mut self) {
564        if let Ok(r) = runtime() {
565            if let Ok(cu) = r.cuda_graph_exec_destroy() {
566                let _ = unsafe { cu(self.handle) };
567            }
568        }
569    }
570}