Skip to main content

baracuda_driver/
graph.rs

1//! CUDA Graphs — record a sequence of operations once, replay cheaply.
2//!
3//! Two construction paths:
4//!
5//! 1. **Stream capture** ([`Stream::begin_capture`], [`Stream::end_capture`])
6//!    — run your work normally on a stream; the driver records it into a
7//!    graph instead of executing it. This is the recommended entry point
8//!    for most users.
9//! 2. **Explicit construction** ([`Graph::new`]) — build a graph node by
10//!    node. baracuda v0.1 does not yet expose typed node builders; use
11//!    capture for now.
12//!
13//! Either way, instantiate the [`Graph`] to a [`GraphExec`] and launch it
14//! on any stream as many times as you like.
15
16use std::sync::Arc;
17
18use baracuda_cuda_sys::types::{
19    CUgraphConditionalHandle, CUgraphExecUpdateResultInfo, CUgraphNodeParams, CUgraphNodeType,
20    CUmemAllocationHandleType, CUmemAllocationType, CUmemLocation, CUmemLocationType,
21    CUmemPoolProps, CUDA_CONDITIONAL_NODE_PARAMS, CUDA_HOST_NODE_PARAMS, CUDA_KERNEL_NODE_PARAMS,
22    CUDA_MEMCPY3D, CUDA_MEMSET_NODE_PARAMS, CUDA_MEM_ALLOC_NODE_PARAMS,
23};
24use baracuda_cuda_sys::{driver, CUdeviceptr, CUgraph, CUgraphExec, CUgraphNode};
25
26use crate::context::Context;
27use crate::error::{check, Result};
28use crate::event::Event;
29use crate::launch::Dim3;
30use crate::module::Function;
31use crate::stream::Stream;
32
/// Scope of a stream capture; mirrors the driver's `CUstreamCaptureMode`.
///
/// Each variant's discriminant is the raw value handed to the driver by
/// [`Stream::begin_capture`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum CaptureMode {
    /// While this thread's stream is capturing, operations on *any* stream
    /// in the process are captured. Discouraged — kept mostly for
    /// compatibility with legacy code.
    Global = 0,
    /// Only operations on streams whose capture was begun from the current
    /// thread are captured. **Recommended default.**
    #[default]
    ThreadLocal = 1,
    /// Permissive mode — unsynchronized cross-stream activity does not
    /// fail the capture.
    Relaxed = 2,
}

impl CaptureMode {
    /// Raw `CUstreamCaptureMode` value for this mode
    /// (`Global` = 0, `ThreadLocal` = 1, `Relaxed` = 2).
    #[inline]
    fn raw(self) -> u32 {
        self as u32
    }
}
59
60impl Stream {
61    /// Begin recording operations submitted to this stream into a CUDA graph.
62    ///
63    /// Call [`Stream::end_capture`] to retrieve the resulting [`Graph`].
64    /// Most operations (kernel launches, memcpys, event records) enqueued
65    /// between these two calls are captured rather than executed.
66    pub fn begin_capture(&self, mode: CaptureMode) -> Result<()> {
67        let d = driver()?;
68        let cu = d.cu_stream_begin_capture()?;
69        check(unsafe { cu(self.as_raw(), mode.raw()) })
70    }
71
72    /// Stop capture and return the graph of everything that was recorded.
73    pub fn end_capture(&self) -> Result<Graph> {
74        let d = driver()?;
75        let cu = d.cu_stream_end_capture()?;
76        let mut graph: CUgraph = core::ptr::null_mut();
77        check(unsafe { cu(self.as_raw(), &mut graph) })?;
78        Ok(Graph {
79            inner: Arc::new(GraphInner {
80                handle: graph,
81                context: self.context().clone(),
82                owned: true,
83            }),
84        })
85    }
86
87    /// Convenience wrapper: run `f`, capturing everything it submits to
88    /// this stream, and return the resulting graph.
89    ///
90    /// `f` should enqueue its work on `self`. If it errors out mid-capture,
91    /// we still end the capture to avoid leaking the captured state.
92    pub fn capture<F>(&self, mode: CaptureMode, f: F) -> Result<Graph>
93    where
94        F: FnOnce(&Stream) -> Result<()>,
95    {
96        self.begin_capture(mode)?;
97        let inner_result = f(self);
98        // End capture regardless of f's success so we don't leak state.
99        let end_result = self.end_capture();
100        match (inner_result, end_result) {
101            (Ok(()), Ok(graph)) => Ok(graph),
102            (Err(e), _) => Err(e),
103            (Ok(()), Err(e)) => Err(e),
104        }
105    }
106
107    /// `true` if this stream is currently in capture mode.
108    pub fn is_capturing(&self) -> Result<bool> {
109        let d = driver()?;
110        let cu = d.cu_stream_is_capturing()?;
111        let mut status: core::ffi::c_uint = 0;
112        check(unsafe { cu(self.as_raw(), &mut status) })?;
113        // CUstreamCaptureStatus::NONE = 0, ACTIVE = 1, INVALIDATED = 2.
114        Ok(status == 1)
115    }
116}
117
/// A CUDA graph — a DAG of CUDA operations.
///
/// The handle lives behind an `Arc`, so `Clone` yields another reference
/// to the *same* driver graph. Use [`Graph::clone_graph`] for an
/// independent copy.
#[derive(Clone)]
pub struct Graph {
    /// Shared, reference-counted handle state.
    inner: Arc<GraphInner>,
}
123
/// State shared by all clones of a [`Graph`]: the raw handle, the context
/// it was created with, and whether this wrapper owns the handle.
struct GraphInner {
    /// Raw driver handle for the graph.
    handle: CUgraph,
    /// Context recorded when the wrapper was created; node-adding methods
    /// pass its raw handle to the driver where a context is required.
    context: Context,
    /// When `false`, this `Graph` wraps a graph owned by something else
    /// (e.g. the body of a conditional node). Drop is a no-op.
    owned: bool,
}
131
// NOTE(review): these impls assert that the raw `CUgraph` handle may be
// moved to and shared between threads. Nothing in this file establishes
// that — presumably it relies on the driver's threading guarantees for
// graph handles; confirm and record the justification here.
unsafe impl Send for GraphInner {}
unsafe impl Sync for GraphInner {}
134
135impl core::fmt::Debug for GraphInner {
136    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
137        f.debug_struct("Graph")
138            .field("handle", &self.handle)
139            .finish_non_exhaustive()
140    }
141}
142
143impl core::fmt::Debug for Graph {
144    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145        self.inner.fmt(f)
146    }
147}
148
149impl Graph {
150    /// Create an empty graph in the given context. Use this as a starting
151    /// point for explicit node construction — note that baracuda v0.1
152    /// does not yet expose typed node builders, so in practice you'll
153    /// almost always build graphs via stream capture instead.
154    pub fn new(context: &Context) -> Result<Self> {
155        context.set_current()?;
156        let d = driver()?;
157        let cu = d.cu_graph_create()?;
158        let mut graph: CUgraph = core::ptr::null_mut();
159        check(unsafe { cu(&mut graph, 0) })?;
160        Ok(Self {
161            inner: Arc::new(GraphInner {
162                handle: graph,
163                context: context.clone(),
164                owned: true,
165            }),
166        })
167    }
168
169    /// Compile this graph into an executable form that can be launched.
170    pub fn instantiate(&self) -> Result<GraphExec> {
171        self.instantiate_with_flags(0)
172    }
173
174    /// As [`Self::instantiate`] but passes `flags` to
175    /// `cuGraphInstantiateWithFlags` (see [`instantiate_flags`]).
176    pub fn instantiate_with_flags(&self, flags: u64) -> Result<GraphExec> {
177        let d = driver()?;
178        let cu = d.cu_graph_instantiate_with_flags()?;
179        let mut exec: CUgraphExec = core::ptr::null_mut();
180        check(unsafe { cu(&mut exec, self.inner.handle, flags) })?;
181        Ok(GraphExec {
182            inner: Arc::new(GraphExecInner {
183                handle: exec,
184                context: self.inner.context.clone(),
185            }),
186        })
187    }
188
189    /// Approximate number of nodes in the graph (useful for debugging).
190    pub fn node_count(&self) -> Result<usize> {
191        let d = driver()?;
192        let cu = d.cu_graph_get_nodes()?;
193        let mut count: usize = 0;
194        check(unsafe { cu(self.inner.handle, core::ptr::null_mut(), &mut count) })?;
195        Ok(count)
196    }
197
    /// Raw `CUgraph`. Use with care.
    ///
    /// The handle is returned by value; this wrapper keeps managing it.
    #[inline]
    pub fn as_raw(&self) -> CUgraph {
        self.inner.handle
    }
203
204    /// Add an empty "join / barrier" node with the given dependencies.
205    /// Returns the new node so it can be used as a dependency of later nodes.
206    pub fn add_empty_node(&self, dependencies: &[GraphNode]) -> Result<GraphNode> {
207        let d = driver()?;
208        let cu = d.cu_graph_add_empty_node()?;
209        let mut node: CUgraphNode = core::ptr::null_mut();
210        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
211        let (deps_ptr, deps_len) = deps_raw(&deps);
212        check(unsafe { cu(&mut node, self.inner.handle, deps_ptr, deps_len) })?;
213        Ok(GraphNode { raw: node })
214    }
215
216    /// Add a kernel-launch node. `args` is the same `*mut c_void` array you'd
217    /// pass to `cuLaunchKernel` — build it from [`baracuda_types::KernelArg`]
218    /// pointers, same as the [`crate::LaunchBuilder`] does.
219    ///
220    /// # Safety
221    ///
222    /// Same responsibilities as [`crate::LaunchBuilder::launch`]: argument
223    /// count, order, and types must match the kernel's C signature, and
224    /// pointer-typed arguments must remain live as long as any executable
225    /// derived from this graph is running.
226    pub unsafe fn add_kernel_node(
227        &self,
228        dependencies: &[GraphNode],
229        function: &Function,
230        grid: impl Into<Dim3>,
231        block: impl Into<Dim3>,
232        shared_mem_bytes: u32,
233        args: &mut [*mut core::ffi::c_void],
234    ) -> Result<GraphNode> { unsafe {
235        let d = driver()?;
236        let cu = d.cu_graph_add_kernel_node()?;
237        let grid = grid.into();
238        let block = block.into();
239        let params = CUDA_KERNEL_NODE_PARAMS {
240            func: function.as_raw(),
241            grid_dim_x: grid.x,
242            grid_dim_y: grid.y,
243            grid_dim_z: grid.z,
244            block_dim_x: block.x,
245            block_dim_y: block.y,
246            block_dim_z: block.z,
247            shared_mem_bytes,
248            kernel_params: if args.is_empty() {
249                core::ptr::null_mut()
250            } else {
251                args.as_mut_ptr()
252            },
253            extra: core::ptr::null_mut(),
254            kern: core::ptr::null_mut(),
255            ctx: core::ptr::null_mut(),
256        };
257        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
258        let (deps_ptr, deps_len) = deps_raw(&deps);
259        let mut node: CUgraphNode = core::ptr::null_mut();
260        check(cu(
261            &mut node,
262            self.inner.handle,
263            deps_ptr,
264            deps_len,
265            &params,
266        ))?;
267        Ok(GraphNode { raw: node })
268    }}
269
270    /// Add a 1-D memset node that fills `count` elements starting at `dst`
271    /// with the 4-byte pattern `value`. Operates in the graph's parent
272    /// context.
273    pub fn add_memset_u32_node(
274        &self,
275        dependencies: &[GraphNode],
276        dst: CUdeviceptr,
277        value: u32,
278        count: usize,
279    ) -> Result<GraphNode> {
280        let d = driver()?;
281        let cu = d.cu_graph_add_memset_node()?;
282        let params = CUDA_MEMSET_NODE_PARAMS {
283            dst,
284            pitch: 0,
285            value,
286            element_size: 4,
287            width: count,
288            height: 1,
289        };
290        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
291        let (deps_ptr, deps_len) = deps_raw(&deps);
292        let mut node: CUgraphNode = core::ptr::null_mut();
293        check(unsafe {
294            cu(
295                &mut node,
296                self.inner.handle,
297                deps_ptr,
298                deps_len,
299                &params,
300                self.inner.context.as_raw(),
301            )
302        })?;
303        Ok(GraphNode { raw: node })
304    }
305
306    /// Deep-copy this graph (including its topology). The clone is
307    /// independent — destroying one does not affect the other.
308    pub fn clone_graph(&self) -> Result<Self> {
309        let d = driver()?;
310        let cu = d.cu_graph_clone()?;
311        let mut out: CUgraph = core::ptr::null_mut();
312        check(unsafe { cu(&mut out, self.inner.handle) })?;
313        Ok(Self {
314            inner: Arc::new(GraphInner {
315                handle: out,
316                context: self.inner.context.clone(),
317                owned: true,
318            }),
319        })
320    }
321
322    /// Add a memcpy node. `params` is a fully-populated [`CUDA_MEMCPY3D`].
323    pub fn add_memcpy_node(
324        &self,
325        dependencies: &[GraphNode],
326        params: &CUDA_MEMCPY3D,
327    ) -> Result<GraphNode> {
328        let d = driver()?;
329        let cu = d.cu_graph_add_memcpy_node()?;
330        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
331        let (deps_ptr, deps_len) = deps_raw(&deps);
332        let mut node: CUgraphNode = core::ptr::null_mut();
333        check(unsafe {
334            cu(
335                &mut node,
336                self.inner.handle,
337                deps_ptr,
338                deps_len,
339                params,
340                self.inner.context.as_raw(),
341            )
342        })?;
343        Ok(GraphNode { raw: node })
344    }
345
346    /// Add a host-function node. Hand the closure's trampoline address
347    /// and a user-data pointer that remains valid for the lifetime of any
348    /// executable graph derived from this graph.
349    ///
350    /// # Safety
351    ///
352    /// `fn_` will be invoked on a CUDA-internal host thread with
353    /// `user_data` as its argument. The pointer must remain valid as long
354    /// as any `GraphExec` containing this node is alive.
355    pub unsafe fn add_host_node(
356        &self,
357        dependencies: &[GraphNode],
358        fn_: unsafe extern "C" fn(*mut core::ffi::c_void),
359        user_data: *mut core::ffi::c_void,
360    ) -> Result<GraphNode> { unsafe {
361        let d = driver()?;
362        let cu = d.cu_graph_add_host_node()?;
363        let params = CUDA_HOST_NODE_PARAMS {
364            fn_: Some(fn_),
365            user_data,
366        };
367        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
368        let (deps_ptr, deps_len) = deps_raw(&deps);
369        let mut node: CUgraphNode = core::ptr::null_mut();
370        check(cu(
371            &mut node,
372            self.inner.handle,
373            deps_ptr,
374            deps_len,
375            &params,
376        ))?;
377        Ok(GraphNode { raw: node })
378    }}
379
380    /// Add a child-graph node — executes `child` in its entirety when
381    /// reached.
382    pub fn add_child_graph_node(
383        &self,
384        dependencies: &[GraphNode],
385        child: &Graph,
386    ) -> Result<GraphNode> {
387        let d = driver()?;
388        let cu = d.cu_graph_add_child_graph_node()?;
389        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
390        let (deps_ptr, deps_len) = deps_raw(&deps);
391        let mut node: CUgraphNode = core::ptr::null_mut();
392        check(unsafe {
393            cu(
394                &mut node,
395                self.inner.handle,
396                deps_ptr,
397                deps_len,
398                child.as_raw(),
399            )
400        })?;
401        Ok(GraphNode { raw: node })
402    }
403
404    /// Add an event-record node — records `event` when executed.
405    pub fn add_event_record_node(
406        &self,
407        dependencies: &[GraphNode],
408        event: &Event,
409    ) -> Result<GraphNode> {
410        let d = driver()?;
411        let cu = d.cu_graph_add_event_record_node()?;
412        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
413        let (deps_ptr, deps_len) = deps_raw(&deps);
414        let mut node: CUgraphNode = core::ptr::null_mut();
415        check(unsafe {
416            cu(
417                &mut node,
418                self.inner.handle,
419                deps_ptr,
420                deps_len,
421                event.as_raw(),
422            )
423        })?;
424        Ok(GraphNode { raw: node })
425    }
426
427    /// Add an event-wait node — blocks downstream nodes until `event` has
428    /// been recorded.
429    pub fn add_event_wait_node(
430        &self,
431        dependencies: &[GraphNode],
432        event: &Event,
433    ) -> Result<GraphNode> {
434        let d = driver()?;
435        let cu = d.cu_graph_add_event_wait_node()?;
436        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
437        let (deps_ptr, deps_len) = deps_raw(&deps);
438        let mut node: CUgraphNode = core::ptr::null_mut();
439        check(unsafe {
440            cu(
441                &mut node,
442                self.inner.handle,
443                deps_ptr,
444                deps_len,
445                event.as_raw(),
446            )
447        })?;
448        Ok(GraphNode { raw: node })
449    }
450
451    /// Add a stream-ordered memory allocation node. When the graph runs,
452    /// the node allocates `bytesize` bytes on `device` (from the device's
453    /// default pool). The resulting device pointer is returned in the
454    /// output tuple alongside the new node.
455    pub fn add_mem_alloc_node(
456        &self,
457        dependencies: &[GraphNode],
458        device: &crate::Device,
459        bytesize: usize,
460    ) -> Result<(GraphNode, CUdeviceptr)> {
461        let d = driver()?;
462        let cu = d.cu_graph_add_mem_alloc_node()?;
463        let mut params = CUDA_MEM_ALLOC_NODE_PARAMS {
464            pool_props: CUmemPoolProps {
465                alloc_type: CUmemAllocationType::PINNED,
466                handle_types: CUmemAllocationHandleType::NONE,
467                location: CUmemLocation {
468                    type_: CUmemLocationType::DEVICE,
469                    id: device.as_raw().0,
470                },
471                ..Default::default()
472            },
473            access_descs: core::ptr::null(),
474            access_desc_count: 0,
475            bytesize,
476            dptr: CUdeviceptr(0),
477        };
478        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
479        let (deps_ptr, deps_len) = deps_raw(&deps);
480        let mut node: CUgraphNode = core::ptr::null_mut();
481        check(unsafe {
482            cu(
483                &mut node,
484                self.inner.handle,
485                deps_ptr,
486                deps_len,
487                &mut params,
488            )
489        })?;
490        Ok((GraphNode { raw: node }, params.dptr))
491    }
492
493    /// Add a stream-ordered memory-free node for `dptr` (which is
494    /// typically the `dptr` returned by a prior
495    /// [`Self::add_mem_alloc_node`] on the same graph).
496    pub fn add_mem_free_node(
497        &self,
498        dependencies: &[GraphNode],
499        dptr: CUdeviceptr,
500    ) -> Result<GraphNode> {
501        let d = driver()?;
502        let cu = d.cu_graph_add_mem_free_node()?;
503        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
504        let (deps_ptr, deps_len) = deps_raw(&deps);
505        let mut node: CUgraphNode = core::ptr::null_mut();
506        check(unsafe { cu(&mut node, self.inner.handle, deps_ptr, deps_len, dptr) })?;
507        Ok(GraphNode { raw: node })
508    }
509
510    /// Add a batch-memop node — a single node that performs a sequence of
511    /// 32/64-bit wait/write value operations on device memory atomically
512    /// wrt the graph's execution order.
513    ///
514    /// `ops` may include any mix of [`baracuda_cuda_sys::types::CUstreamBatchMemOpParams`]
515    /// entries built with that type's `wait_value_*` / `write_value_*` helpers.
516    pub fn add_batch_mem_op_node(
517        &self,
518        dependencies: &[GraphNode],
519        ops: &mut [baracuda_cuda_sys::types::CUstreamBatchMemOpParams],
520    ) -> Result<GraphNode> {
521        let d = driver()?;
522        let cu = d.cu_graph_add_batch_mem_op_node()?;
523        let params = baracuda_cuda_sys::types::CUDA_BATCH_MEM_OP_NODE_PARAMS {
524            ctx: self.inner.context.as_raw(),
525            count: ops.len() as core::ffi::c_uint,
526            param_array: ops.as_mut_ptr(),
527            flags: 0,
528        };
529        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
530        let (deps_ptr, deps_len) = deps_raw(&deps);
531        let mut node: CUgraphNode = core::ptr::null_mut();
532        check(unsafe { cu(&mut node, self.inner.handle, deps_ptr, deps_len, &params) })?;
533        Ok(GraphNode { raw: node })
534    }
535
536    /// Add dependency edges from each node in `from` to its counterpart in
537    /// `to`. Both slices must have the same length.
538    pub fn add_dependencies(&self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
539        assert_eq!(from.len(), to.len(), "add_dependencies: length mismatch");
540        if from.is_empty() {
541            return Ok(());
542        }
543        let d = driver()?;
544        let cu = d.cu_graph_add_dependencies()?;
545        let f: Vec<CUgraphNode> = from.iter().map(|n| n.raw).collect();
546        let t: Vec<CUgraphNode> = to.iter().map(|n| n.raw).collect();
547        check(unsafe { cu(self.inner.handle, f.as_ptr(), t.as_ptr(), f.len()) })
548    }
549
550    /// Remove previously-added dependency edges.
551    pub fn remove_dependencies(&self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
552        assert_eq!(from.len(), to.len(), "remove_dependencies: length mismatch");
553        if from.is_empty() {
554            return Ok(());
555        }
556        let d = driver()?;
557        let cu = d.cu_graph_remove_dependencies()?;
558        let f: Vec<CUgraphNode> = from.iter().map(|n| n.raw).collect();
559        let t: Vec<CUgraphNode> = to.iter().map(|n| n.raw).collect();
560        check(unsafe { cu(self.inner.handle, f.as_ptr(), t.as_ptr(), f.len()) })
561    }
562
563    /// Dump a Graphviz-compatible representation of this graph to `path`.
564    /// Pass `flags = 0` for the default verbose output.
565    pub fn debug_dot_print(&self, path: &str, flags: u32) -> Result<()> {
566        let d = driver()?;
567        let cu = d.cu_graph_debug_dot_print()?;
568        let c_path = std::ffi::CString::new(path).map_err(|_| {
569            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
570                library: "cuda-driver",
571                symbol: "cuGraphDebugDotPrint(path contained a NUL byte)",
572            })
573        })?;
574        check(unsafe { cu(self.inner.handle, c_path.as_ptr(), flags) })
575    }
576
577    /// Create a conditional handle tied to this parent graph. Pass the
578    /// handle's value from inside a kernel (via
579    /// `cudaGraphSetConditional(handle, val)`) to drive whether or how
580    /// many times the conditional body executes.
581    pub fn conditional_handle(
582        &self,
583        default_launch_value: u32,
584        flags: u32,
585    ) -> Result<CUgraphConditionalHandle> {
586        let d = driver()?;
587        let cu = d.cu_graph_conditional_handle_create()?;
588        let mut h: CUgraphConditionalHandle = 0;
589        check(unsafe {
590            cu(
591                &mut h,
592                self.inner.handle,
593                self.inner.context.as_raw(),
594                default_launch_value,
595                flags,
596            )
597        })?;
598        Ok(h)
599    }
600
601    /// Add a conditional node (IF / WHILE / SWITCH). Returns `(node, body)`
602    /// — populate the `body` graph with the code to execute conditionally.
603    ///
604    /// `type_` is one of
605    /// [`baracuda_cuda_sys::types::CUgraphConditionalNodeType`].
606    /// `size` is the count of body graphs (1 for IF/WHILE; up to N for SWITCH).
607    pub fn add_conditional_node(
608        &self,
609        dependencies: &[GraphNode],
610        handle: CUgraphConditionalHandle,
611        type_: i32,
612        size: u32,
613    ) -> Result<(GraphNode, Graph)> {
614        let d = driver()?;
615        let cu = d.cu_graph_add_node()?;
616        let mut body: CUgraph = core::ptr::null_mut();
617        let cond = CUDA_CONDITIONAL_NODE_PARAMS {
618            handle,
619            type_,
620            size,
621            body_graph_out: &mut body,
622            ctx: self.inner.context.as_raw(),
623        };
624        let mut params = CUgraphNodeParams {
625            type_: CUgraphNodeType::CONDITIONAL,
626            ..Default::default()
627        };
628        // Write the CUDA_CONDITIONAL_NODE_PARAMS at the start of the payload.
629        // SAFETY: payload is [u64; 29] = 232 bytes, 8-aligned; conditional
630        // params fit in 32 bytes.
631        unsafe {
632            let dst = params.payload.as_mut_ptr() as *mut CUDA_CONDITIONAL_NODE_PARAMS;
633            dst.write(cond);
634        }
635        let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
636        let (deps_ptr, deps_len) = deps_raw(&deps);
637        let mut node: CUgraphNode = core::ptr::null_mut();
638        check(unsafe {
639            cu(
640                &mut node,
641                self.inner.handle,
642                deps_ptr,
643                core::ptr::null(),
644                deps_len,
645                &mut params,
646            )
647        })?;
648        // The body CUgraph is owned by the conditional node; wrap it
649        // non-owning so our Drop doesn't double-free.
650        let body_graph = Graph {
651            inner: Arc::new(GraphInner {
652                handle: body,
653                context: self.inner.context.clone(),
654                owned: false,
655            }),
656        };
657        Ok((GraphNode { raw: node }, body_graph))
658    }
659
660    /// Return `(from, to)` vectors describing every edge in the graph.
661    pub fn edges(&self) -> Result<(Vec<GraphNode>, Vec<GraphNode>)> {
662        let d = driver()?;
663        let cu = d.cu_graph_get_edges()?;
664        // First call: ask for edge count.
665        let mut count: usize = 0;
666        check(unsafe {
667            cu(
668                self.inner.handle,
669                core::ptr::null_mut(),
670                core::ptr::null_mut(),
671                &mut count,
672            )
673        })?;
674        let mut from = vec![core::ptr::null_mut(); count];
675        let mut to = vec![core::ptr::null_mut(); count];
676        if count > 0 {
677            check(unsafe {
678                cu(
679                    self.inner.handle,
680                    from.as_mut_ptr(),
681                    to.as_mut_ptr(),
682                    &mut count,
683                )
684            })?;
685        }
686        Ok((
687            from.into_iter().map(|raw| GraphNode { raw }).collect(),
688            to.into_iter().map(|raw| GraphNode { raw }).collect(),
689        ))
690    }
691}
692
693fn deps_raw(deps: &[CUgraphNode]) -> (*const CUgraphNode, usize) {
694    if deps.is_empty() {
695        (core::ptr::null(), 0)
696    } else {
697        (deps.as_ptr(), deps.len())
698    }
699}
700
/// A node inside a [`Graph`]. Lightweight handle that borrows the parent
/// graph's storage — destroying the graph invalidates all of its nodes.
///
/// `GraphNode` is `Copy` so you can use a single node as a dependency of
/// many successors without cloning.
///
/// The type carries no lifetime or reference to its graph, so nothing in
/// the type system stops a node from outliving the graph it came from.
#[derive(Copy, Clone, Debug)]
pub struct GraphNode {
    /// Raw driver handle for the node.
    raw: CUgraphNode,
}
710
711impl GraphNode {
    /// Raw `CUgraphNode`. Use with care.
    ///
    /// The handle is returned by value.
    #[inline]
    pub fn as_raw(&self) -> CUgraphNode {
        self.raw
    }
717
718    /// Return the `CUgraphNodeType` code for this node. Compare against
719    /// constants in [`baracuda_cuda_sys::types::CUgraphNodeType`].
720    pub fn node_type(&self) -> Result<core::ffi::c_int> {
721        let d = driver()?;
722        let cu = d.cu_graph_node_get_type()?;
723        let mut t: core::ffi::c_int = 0;
724        check(unsafe { cu(self.raw, &mut t) })?;
725        Ok(t)
726    }
727
728    /// Return this node's upstream dependencies.
729    pub fn dependencies(&self) -> Result<Vec<GraphNode>> {
730        let d = driver()?;
731        let cu = d.cu_graph_node_get_dependencies()?;
732        let mut count: usize = 0;
733        check(unsafe { cu(self.raw, core::ptr::null_mut(), &mut count) })?;
734        let mut out = vec![core::ptr::null_mut(); count];
735        if count > 0 {
736            check(unsafe { cu(self.raw, out.as_mut_ptr(), &mut count) })?;
737        }
738        Ok(out.into_iter().map(|raw| GraphNode { raw }).collect())
739    }
740
741    /// Return nodes that depend on this node.
742    pub fn dependent_nodes(&self) -> Result<Vec<GraphNode>> {
743        let d = driver()?;
744        let cu = d.cu_graph_node_get_dependent_nodes()?;
745        let mut count: usize = 0;
746        check(unsafe { cu(self.raw, core::ptr::null_mut(), &mut count) })?;
747        let mut out = vec![core::ptr::null_mut(); count];
748        if count > 0 {
749            check(unsafe { cu(self.raw, out.as_mut_ptr(), &mut count) })?;
750        }
751        Ok(out.into_iter().map(|raw| GraphNode { raw }).collect())
752    }
753
754    /// Fetch current kernel-node params (kernel-node nodes only).
755    pub fn kernel_params(&self) -> Result<CUDA_KERNEL_NODE_PARAMS> {
756        let d = driver()?;
757        let cu = d.cu_graph_kernel_node_get_params()?;
758        let mut p = CUDA_KERNEL_NODE_PARAMS::default();
759        check(unsafe { cu(self.raw, &mut p) })?;
760        Ok(p)
761    }
762
763    /// Overwrite this kernel-node's params on the template graph (not the
764    /// instantiated exec — use [`GraphExec::set_kernel_node_params`] for
765    /// live edit).
766    ///
767    /// # Safety
768    ///
769    /// The caller ensures the new params describe a valid kernel launch
770    /// — same kind of invariants as [`crate::LaunchBuilder::launch`].
771    pub unsafe fn set_kernel_params(&self, params: &CUDA_KERNEL_NODE_PARAMS) -> Result<()> { unsafe {
772        let d = driver()?;
773        let cu = d.cu_graph_kernel_node_set_params()?;
774        check(cu(self.raw, params))
775    }}
776
777    /// Generic params edit on the template graph — works for any node
778    /// kind (kernel, memcpy, memset, child-graph, host, …) by reading
779    /// the `type_` field of `CUgraphNodeParams` and dispatching to the
780    /// matching internal setter. CUDA 12.3+.
781    ///
782    /// # Safety
783    ///
784    /// The `type_` discriminant in `params` must match the node's
785    /// actual kind, and the union payload must be initialized for that
786    /// type.
787    pub unsafe fn set_params(&self, params: &mut CUgraphNodeParams) -> Result<()> { unsafe {
788        let d = driver()?;
789        let cu = d.cu_graph_node_set_params()?;
790        check(cu(self.raw, params))
791    }}
792
793    /// Fetch current memset-node params (memset-node nodes only).
794    pub fn memset_params(&self) -> Result<CUDA_MEMSET_NODE_PARAMS> {
795        let d = driver()?;
796        let cu = d.cu_graph_memset_node_get_params()?;
797        let mut p = CUDA_MEMSET_NODE_PARAMS::default();
798        check(unsafe { cu(self.raw, &mut p) })?;
799        Ok(p)
800    }
801
802    /// Overwrite this memset-node's params on the template graph.
803    pub fn set_memset_params(&self, params: &CUDA_MEMSET_NODE_PARAMS) -> Result<()> {
804        let d = driver()?;
805        let cu = d.cu_graph_memset_node_set_params()?;
806        check(unsafe { cu(self.raw, params) })
807    }
808
809    /// Fetch the device pointer this `MemFree` node will free. Only valid
810    /// on nodes of type `CUgraphNodeType::MEM_FREE`.
811    pub fn mem_free_ptr(&self) -> Result<CUdeviceptr> {
812        let d = driver()?;
813        let cu = d.cu_graph_mem_free_node_get_params()?;
814        let mut p = CUdeviceptr(0);
815        check(unsafe { cu(self.raw, &mut p) })?;
816        Ok(p)
817    }
818
819    /// Fetch the full mem-alloc-node params (pool props + bytesize + the
820    /// output `dptr` CUDA will write into at execute time).
821    pub fn mem_alloc_params(&self) -> Result<CUDA_MEM_ALLOC_NODE_PARAMS> {
822        let d = driver()?;
823        let cu = d.cu_graph_mem_alloc_node_get_params()?;
824        let mut p = CUDA_MEM_ALLOC_NODE_PARAMS::default();
825        check(unsafe { cu(self.raw, &mut p) })?;
826        Ok(p)
827    }
828
829    /// Fetch current memcpy-node params.
830    pub fn memcpy_params(&self) -> Result<CUDA_MEMCPY3D> {
831        let d = driver()?;
832        let cu = d.cu_graph_memcpy_node_get_params()?;
833        let mut p = CUDA_MEMCPY3D::default();
834        check(unsafe { cu(self.raw, &mut p) })?;
835        Ok(p)
836    }
837
838    /// Overwrite this memcpy-node's params on the template graph.
839    pub fn set_memcpy_params(&self, params: &CUDA_MEMCPY3D) -> Result<()> {
840        let d = driver()?;
841        let cu = d.cu_graph_memcpy_node_set_params()?;
842        check(unsafe { cu(self.raw, params) })
843    }
844
845    /// Explicitly destroy this node inside its parent graph. Usually you
846    /// just drop the [`Graph`] to clean up everything at once; this is only
847    /// useful for surgically editing a graph mid-construction.
848    ///
849    /// # Safety
850    ///
851    /// The caller must not use this `GraphNode` (or any dependency-list
852    /// reference to it) after calling this function.
853    pub unsafe fn destroy(self) -> Result<()> { unsafe {
854        let d = driver()?;
855        let cu = d.cu_graph_destroy_node()?;
856        check(cu(self.raw))
857    }}
858}
859
860/// Re-export of `cuGraphInstantiateWithFlags` flag constants.
861pub mod instantiate_flags {
862    pub use baracuda_cuda_sys::types::CUgraphInstantiate_flags::*;
863}
864
865impl Drop for GraphInner {
866    fn drop(&mut self) {
867        if !self.owned || self.handle.is_null() {
868            return;
869        }
870        if let Ok(d) = driver() {
871            if let Ok(cu) = d.cu_graph_destroy() {
872                let _ = unsafe { cu(self.handle) };
873            }
874        }
875    }
876}
877
878/// An instantiated (executable) CUDA graph.
879#[derive(Clone)]
880pub struct GraphExec {
881    inner: Arc<GraphExecInner>,
882}
883
884struct GraphExecInner {
885    handle: CUgraphExec,
886    #[allow(dead_code)]
887    context: Context,
888}
889
890unsafe impl Send for GraphExecInner {}
891unsafe impl Sync for GraphExecInner {}
892
893impl core::fmt::Debug for GraphExecInner {
894    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
895        f.debug_struct("GraphExec")
896            .field("handle", &self.handle)
897            .finish_non_exhaustive()
898    }
899}
900
901impl core::fmt::Debug for GraphExec {
902    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
903        self.inner.fmt(f)
904    }
905}
906
907impl GraphExec {
908    /// Launch this graph on `stream`. Can be called repeatedly — that's the
909    /// whole point of CUDA graphs.
910    pub fn launch(&self, stream: &Stream) -> Result<()> {
911        let d = driver()?;
912        let cu = d.cu_graph_launch()?;
913        check(unsafe { cu(self.inner.handle, stream.as_raw()) })
914    }
915
916    /// Raw `CUgraphExec`. Use with care.
917    #[inline]
918    pub fn as_raw(&self) -> CUgraphExec {
919        self.inner.handle
920    }
921
922    /// Try to update this executable graph in place based on a new
923    /// template graph. Returns `Ok(info)` in all cases — inspect
924    /// [`UpdateResult`] to distinguish success from the various
925    /// reasons CUDA refused the update. On refusal, the executable graph
926    /// is left unchanged (not corrupted), so the caller can fall back to
927    /// re-instantiating.
928    ///
929    /// Only topology-invariant changes are allowed (same number + type of
930    /// nodes, same dependency edges). Changes to kernel arguments, grid
931    /// dims, memset values, memcpy params, etc. are supported.
932    pub fn update(&self, new_template: &Graph) -> Result<UpdateResult> {
933        let d = driver()?;
934        let cu = d.cu_graph_exec_update()?;
935        let mut info = CUgraphExecUpdateResultInfo::default();
936        // cuGraphExecUpdate_v2 returns SUCCESS even when the update
937        // failed for topology/type reasons — the info.result field
938        // carries the real outcome.
939        let rc = unsafe { cu(self.inner.handle, new_template.as_raw(), &mut info) };
940        if rc != baracuda_cuda_sys::CUresult::SUCCESS
941            && info.result == baracuda_cuda_sys::types::CUgraphExecUpdateResult::SUCCESS
942        {
943            return Err(crate::error::Error::Status { status: rc });
944        }
945        Ok(UpdateResult {
946            result: info.result,
947            error_node: if info.error_node.is_null() {
948                None
949            } else {
950                Some(GraphNode {
951                    raw: info.error_node,
952                })
953            },
954            error_from_node: if info.error_from_node.is_null() {
955                None
956            } else {
957                Some(GraphNode {
958                    raw: info.error_from_node,
959                })
960            },
961        })
962    }
963
964    /// Live-edit a kernel-node's parameters on the instantiated graph.
965    /// Avoids re-instantiation when only arg values / grid dims / shmem
966    /// change. The topology must match (no new nodes, no new edges).
967    ///
968    /// # Safety
969    ///
970    /// Same launch invariants as [`crate::LaunchBuilder::launch`].
971    pub unsafe fn set_kernel_node_params(
972        &self,
973        node: GraphNode,
974        params: &CUDA_KERNEL_NODE_PARAMS,
975    ) -> Result<()> { unsafe {
976        let d = driver()?;
977        let cu = d.cu_graph_exec_kernel_node_set_params()?;
978        check(cu(self.inner.handle, node.raw, params))
979    }}
980
981    /// Live-edit a memcpy-node's parameters on the instantiated graph.
982    pub fn set_memcpy_node_params(&self, node: GraphNode, params: &CUDA_MEMCPY3D) -> Result<()> {
983        let d = driver()?;
984        let cu = d.cu_graph_exec_memcpy_node_set_params()?;
985        check(unsafe {
986            cu(
987                self.inner.handle,
988                node.raw,
989                params,
990                self.inner.context.as_raw(),
991            )
992        })
993    }
994
995    /// Live-edit a memset-node's parameters on the instantiated graph.
996    pub fn set_memset_node_params(
997        &self,
998        node: GraphNode,
999        params: &CUDA_MEMSET_NODE_PARAMS,
1000    ) -> Result<()> {
1001        let d = driver()?;
1002        let cu = d.cu_graph_exec_memset_node_set_params()?;
1003        check(unsafe {
1004            cu(
1005                self.inner.handle,
1006                node.raw,
1007                params,
1008                self.inner.context.as_raw(),
1009            )
1010        })
1011    }
1012
1013    /// Live-edit a host-node's callback on the instantiated graph.
1014    ///
1015    /// # Safety
1016    ///
1017    /// `fn_` must remain callable with `user_data` for the lifetime of
1018    /// this `GraphExec`.
1019    pub unsafe fn set_host_node_params(
1020        &self,
1021        node: GraphNode,
1022        fn_: unsafe extern "C" fn(*mut core::ffi::c_void),
1023        user_data: *mut core::ffi::c_void,
1024    ) -> Result<()> { unsafe {
1025        let d = driver()?;
1026        let cu = d.cu_graph_exec_host_node_set_params()?;
1027        let params = CUDA_HOST_NODE_PARAMS {
1028            fn_: Some(fn_),
1029            user_data,
1030        };
1031        check(cu(self.inner.handle, node.raw, &params))
1032    }}
1033}
1034
1035impl Drop for GraphExecInner {
1036    fn drop(&mut self) {
1037        if let Ok(d) = driver() {
1038            if let Ok(cu) = d.cu_graph_exec_destroy() {
1039                let _ = unsafe { cu(self.handle) };
1040            }
1041        }
1042    }
1043}
1044
1045/// Outcome of [`GraphExec::update`]. Inspect [`result`](Self::result)
1046/// against the constants in
1047/// [`baracuda_cuda_sys::types::CUgraphExecUpdateResult`] —
1048/// `SUCCESS` (0) means the executable graph was patched in place.
1049#[derive(Clone, Debug)]
1050pub struct UpdateResult {
1051    pub result: core::ffi::c_int,
1052    /// Node in the *new* template that triggered the failure, if any.
1053    pub error_node: Option<GraphNode>,
1054    /// Corresponding node in the *old* (already-instantiated) template.
1055    pub error_from_node: Option<GraphNode>,
1056}
1057
1058impl UpdateResult {
1059    /// `true` iff CUDA accepted the update.
1060    pub fn is_success(&self) -> bool {
1061        self.result == baracuda_cuda_sys::types::CUgraphExecUpdateResult::SUCCESS
1062    }
1063}
1064
1065// ---- Graph-memory per-device attribute queries --------------------------
1066
1067/// Trim the per-device graph-mem reserve back to the minimum footprint.
1068pub fn device_graph_mem_trim(device: &crate::Device) -> Result<()> {
1069    let d = driver()?;
1070    let cu = d.cu_device_graph_mem_trim()?;
1071    check(unsafe { cu(device.as_raw()) })
1072}
1073
1074/// Query a `CUgraphMem_attribute` (used / reserved memory, high-water
1075/// marks) — see [`baracuda_cuda_sys::types::CUgraphMem_attribute`].
1076pub fn device_graph_mem_attribute(device: &crate::Device, attr: i32) -> Result<u64> {
1077    let d = driver()?;
1078    let cu = d.cu_device_get_graph_mem_attribute()?;
1079    let mut v: u64 = 0;
1080    check(unsafe {
1081        cu(
1082            device.as_raw(),
1083            attr,
1084            &mut v as *mut u64 as *mut core::ffi::c_void,
1085        )
1086    })?;
1087    Ok(v)
1088}
1089
1090/// Reset a `CUgraphMem_attribute` (typically the high-water marks).
1091pub fn device_set_graph_mem_attribute(device: &crate::Device, attr: i32, value: u64) -> Result<()> {
1092    let d = driver()?;
1093    let cu = d.cu_device_set_graph_mem_attribute()?;
1094    let mut v = value;
1095    check(unsafe {
1096        cu(
1097            device.as_raw(),
1098            attr,
1099            &mut v as *mut u64 as *mut core::ffi::c_void,
1100        )
1101    })
1102}