Skip to main content

oxicuda_driver/
graph.rs

1//! CUDA Graph API for recording and replaying sequences of GPU operations.
2//!
3//! CUDA Graphs allow capturing a sequence of operations (kernel launches,
4//! memory copies, memsets) into a graph data structure that can be
5//! instantiated and launched repeatedly with minimal CPU overhead.
6//!
7//! # Architecture
8//!
9//! This module exposes a Rust-side graph representation that records
10//! operations as nodes with explicit dependency edges. [`Graph::instantiate`]
11//! translates that representation into the native CUDA Graph API
12//! (`cuGraphCreate`, `cuGraphAdd*Node`, `cuGraphInstantiate`) whenever a
13//! CUDA driver is available, and [`GraphExec::launch`] issues a real
14//! `cuGraphLaunch`. On macOS (or any host without a driver) the graph is
15//! still built and validated CPU-side, and launching reports
16//! [`CudaError::NotInitialized`].
17//!
18//! # Example
19//!
20//! ```rust,no_run
21//! # use oxicuda_driver::graph::{Graph, GraphNode, MemcpyDirection};
22//! let mut graph = Graph::new();
23//!
24//! let n0 = graph.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
25//! let n1 = graph.add_kernel_node(
26//!     "vector_add",
27//!     (4, 1, 1),
28//!     (256, 1, 1),
29//!     0,
30//! );
31//! let n2 = graph.add_memcpy_node(MemcpyDirection::DeviceToHost, 4096);
32//!
33//! graph.add_dependency(n0, n1).ok();
34//! graph.add_dependency(n1, n2).ok();
35//!
36//! assert_eq!(graph.node_count(), 3);
37//! assert_eq!(graph.dependency_count(), 2);
38//! ```
39
40use crate::error::{CudaError, CudaResult};
41use crate::stream::Stream;
42
43// ---------------------------------------------------------------------------
44// GraphNode — individual operation in a graph
45// ---------------------------------------------------------------------------
46
/// Which way data moves for a memory-copy node recorded in a graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemcpyDirection {
    /// Copy from host memory to device memory.
    HostToDevice,
    /// Copy from device memory to host memory.
    DeviceToHost,
    /// Copy from device memory to device memory.
    DeviceToDevice,
}

impl std::fmt::Display for MemcpyDirection {
    /// Writes the short mnemonic for the direction: `HtoD`, `DtoH` or `DtoD`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mnemonic = match *self {
            Self::HostToDevice => "HtoD",
            Self::DeviceToHost => "DtoH",
            Self::DeviceToDevice => "DtoD",
        };
        f.write_str(mnemonic)
    }
}
67
68/// A single operation node within a [`Graph`].
69///
70/// Each variant represents a different type of GPU operation that can
71/// be recorded into a graph.
72#[derive(Debug, Clone, PartialEq, Eq)]
73pub enum GraphNode {
74    /// A kernel launch with grid/block configuration.
75    KernelLaunch {
76        /// Name of the kernel function.
77        function_name: String,
78        /// Grid dimensions `(x, y, z)`.
79        grid: (u32, u32, u32),
80        /// Block dimensions `(x, y, z)`.
81        block: (u32, u32, u32),
82        /// Dynamic shared memory in bytes.
83        shared_mem: u32,
84    },
85    /// A memory copy operation.
86    Memcpy {
87        /// Direction of the copy.
88        direction: MemcpyDirection,
89        /// Size of the transfer in bytes.
90        size: usize,
91    },
92    /// A memset operation (fill device memory with a byte value).
93    Memset {
94        /// Number of bytes to set.
95        size: usize,
96        /// Byte value to fill with.
97        value: u8,
98    },
99    /// An empty/no-op node used as a synchronisation barrier.
100    Empty,
101}
102
103impl std::fmt::Display for GraphNode {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        match self {
106            Self::KernelLaunch {
107                function_name,
108                grid,
109                block,
110                shared_mem,
111            } => write!(
112                f,
113                "Kernel({}, grid=({},{},{}), block=({},{},{}), smem={})",
114                function_name, grid.0, grid.1, grid.2, block.0, block.1, block.2, shared_mem,
115            ),
116            Self::Memcpy { direction, size } => {
117                write!(f, "Memcpy({direction}, {size} bytes)")
118            }
119            Self::Memset { size, value } => {
120                write!(f, "Memset({size} bytes, value=0x{value:02x})")
121            }
122            Self::Empty => write!(f, "Empty"),
123        }
124    }
125}
126
127// ---------------------------------------------------------------------------
128// Graph — collection of nodes with dependency edges
129// ---------------------------------------------------------------------------
130
131/// A CUDA graph representing a DAG of GPU operations.
132///
133/// Nodes represent individual operations (kernel launches, memory copies,
134/// memsets, or empty barriers). Dependencies are directed edges that
135/// enforce execution ordering between nodes.
136///
137/// The graph can be instantiated into a [`GraphExec`] for repeated
138/// low-overhead execution.
139#[derive(Debug, Clone)]
140pub struct Graph {
141    nodes: Vec<GraphNode>,
142    dependencies: Vec<(usize, usize)>,
143}
144
145impl Default for Graph {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151impl Graph {
152    /// Creates a new empty graph with no nodes or dependencies.
153    pub fn new() -> Self {
154        Self {
155            nodes: Vec::new(),
156            dependencies: Vec::new(),
157        }
158    }
159
160    /// Adds a kernel launch node to the graph.
161    ///
162    /// Returns the index of the newly created node, which can be used
163    /// to establish dependencies via [`add_dependency`](Self::add_dependency).
164    ///
165    /// # Parameters
166    ///
167    /// * `function_name` - Name of the kernel function.
168    /// * `grid` - Grid dimensions `(x, y, z)`.
169    /// * `block` - Block dimensions `(x, y, z)`.
170    /// * `shared_mem` - Dynamic shared memory in bytes.
171    pub fn add_kernel_node(
172        &mut self,
173        function_name: &str,
174        grid: (u32, u32, u32),
175        block: (u32, u32, u32),
176        shared_mem: u32,
177    ) -> usize {
178        let idx = self.nodes.len();
179        self.nodes.push(GraphNode::KernelLaunch {
180            function_name: function_name.to_owned(),
181            grid,
182            block,
183            shared_mem,
184        });
185        idx
186    }
187
188    /// Adds a memory copy node to the graph.
189    ///
190    /// Returns the index of the newly created node.
191    ///
192    /// # Parameters
193    ///
194    /// * `direction` - Direction of the memory copy.
195    /// * `size` - Size of the transfer in bytes.
196    pub fn add_memcpy_node(&mut self, direction: MemcpyDirection, size: usize) -> usize {
197        let idx = self.nodes.len();
198        self.nodes.push(GraphNode::Memcpy { direction, size });
199        idx
200    }
201
202    /// Adds a memset node to the graph.
203    ///
204    /// Returns the index of the newly created node.
205    ///
206    /// # Parameters
207    ///
208    /// * `size` - Number of bytes to set.
209    /// * `value` - Byte value to fill with.
210    pub fn add_memset_node(&mut self, size: usize, value: u8) -> usize {
211        let idx = self.nodes.len();
212        self.nodes.push(GraphNode::Memset { size, value });
213        idx
214    }
215
216    /// Adds an empty (no-op) node to the graph.
217    ///
218    /// Empty nodes are useful as synchronisation barriers — they have
219    /// no work of their own but can serve as join points for multiple
220    /// dependency chains.
221    ///
222    /// Returns the index of the newly created node.
223    pub fn add_empty_node(&mut self) -> usize {
224        let idx = self.nodes.len();
225        self.nodes.push(GraphNode::Empty);
226        idx
227    }
228
229    /// Adds a dependency edge from node `from` to node `to`.
230    ///
231    /// This means `to` will not begin execution until `from` has
232    /// completed. Both indices must refer to existing nodes.
233    ///
234    /// # Errors
235    ///
236    /// Returns [`CudaError::InvalidValue`] if either index is out of bounds
237    /// or if `from == to` (self-dependency).
238    pub fn add_dependency(&mut self, from: usize, to: usize) -> CudaResult<()> {
239        if from >= self.nodes.len() || to >= self.nodes.len() {
240            return Err(CudaError::InvalidValue);
241        }
242        if from == to {
243            return Err(CudaError::InvalidValue);
244        }
245        self.dependencies.push((from, to));
246        Ok(())
247    }
248
249    /// Returns the total number of nodes in the graph.
250    #[inline]
251    pub fn node_count(&self) -> usize {
252        self.nodes.len()
253    }
254
255    /// Returns the total number of dependency edges in the graph.
256    #[inline]
257    pub fn dependency_count(&self) -> usize {
258        self.dependencies.len()
259    }
260
261    /// Returns a slice of all nodes in insertion order.
262    #[inline]
263    pub fn nodes(&self) -> &[GraphNode] {
264        &self.nodes
265    }
266
267    /// Returns a slice of all dependency edges as `(from, to)` pairs.
268    #[inline]
269    pub fn dependencies(&self) -> &[(usize, usize)] {
270        &self.dependencies
271    }
272
273    /// Returns the node at the given index, or `None` if out of bounds.
274    pub fn get_node(&self, index: usize) -> Option<&GraphNode> {
275        self.nodes.get(index)
276    }
277
278    /// Performs a topological sort of the graph nodes.
279    ///
280    /// Returns the node indices in an order that respects all
281    /// dependency edges, or an error if the graph contains a cycle.
282    ///
283    /// # Errors
284    ///
285    /// Returns [`CudaError::InvalidValue`] if the graph contains a
286    /// dependency cycle.
287    pub fn topological_sort(&self) -> CudaResult<Vec<usize>> {
288        let n = self.nodes.len();
289        let mut in_degree = vec![0u32; n];
290        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
291
292        for &(from, to) in &self.dependencies {
293            adj[from].push(to);
294            in_degree[to] = in_degree[to].saturating_add(1);
295        }
296
297        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
298        let mut result = Vec::with_capacity(n);
299
300        while let Some(node) = queue.pop() {
301            result.push(node);
302            for &next in &adj[node] {
303                in_degree[next] = in_degree[next].saturating_sub(1);
304                if in_degree[next] == 0 {
305                    queue.push(next);
306                }
307            }
308        }
309
310        if result.len() != n {
311            return Err(CudaError::InvalidValue);
312        }
313
314        Ok(result)
315    }
316
317    /// Instantiates the graph into an executable form.
318    ///
319    /// The returned [`GraphExec`] can be launched on a stream with minimal
320    /// CPU overhead.  The graph is always validated (topological sort)
321    /// during instantiation.
322    ///
323    /// When a CUDA driver is available, a genuine `CUgraph` is built
324    /// (`cuGraphCreate` + per-node `cuGraphAdd*Node` with the dependency DAG
325    /// wired through real `CUgraphNode` edges) and finalised into a
326    /// `CUgraphExec` via `cuGraphInstantiate`; [`GraphExec::launch`] then
327    /// issues a real `cuGraphLaunch`.  Without a driver (macOS, or a host
328    /// with no GPU) the `GraphExec` is CPU-side only and `launch` reports
329    /// [`CudaError::NotInitialized`].
330    ///
331    /// # Errors
332    ///
333    /// * [`CudaError::InvalidValue`] if the graph contains a dependency
334    ///   cycle.
335    /// * Any [`CudaError`] mapped from a failing `cuGraph*` driver call
336    ///   when a driver is present (e.g. [`CudaError::OutOfMemory`]).
337    pub fn instantiate(&self) -> CudaResult<GraphExec> {
338        // Validate the graph is a DAG by performing a topological sort.
339        // This must succeed regardless of driver availability.
340        let execution_order = self.topological_sort()?;
341
342        // Attempt a real driver-backed instantiation.  Fall back to a
343        // CPU-side-only GraphExec for environmental reasons — no driver, no
344        // GPU, no current CUDA context, or a driver predating the graph API
345        // — since none of those indicate a malformed graph.  A genuine
346        // graph-construction failure (e.g. OutOfMemory, InvalidValue) is a
347        // real error and propagates to the caller.
348        let (raw_graph, raw_exec) = match self.build_driver_graph() {
349            Ok(handles) => handles,
350            Err(
351                CudaError::NotInitialized
352                | CudaError::NotSupported
353                | CudaError::InvalidContext
354                | CudaError::NoDevice
355                | CudaError::InvalidDevice
356                | CudaError::Deinitialized,
357            ) => (None, None),
358            Err(other) => return Err(other),
359        };
360
361        Ok(GraphExec {
362            graph: self.clone(),
363            execution_order,
364            raw_graph,
365            raw_exec,
366        })
367    }
368
369    /// Build a real CUDA driver graph from this in-memory representation.
370    ///
371    /// Returns `(Some(CUgraph), Some(CUgraphExec))` on success.  Returns
372    /// [`CudaError::NotInitialized`] when no driver is loaded and
373    /// [`CudaError::NotSupported`] when the loaded driver predates the CUDA
374    /// Graph API; [`Graph::instantiate`] turns both (and other environmental
375    /// errors) into a CPU-side-only `GraphExec`.  Any other error is a
376    /// genuine driver failure.
377    ///
378    /// Each in-memory [`GraphNode`] is translated to a real driver node and
379    /// the dependency edges are reproduced exactly.  Nodes are created in
380    /// topological order so that, when `cuGraphAddEmptyNode` is given a
381    /// node's dependency list, every referenced `CUgraphNode` already
382    /// exists — regardless of the order edges were added to the in-memory
383    /// graph.  Because [`GraphNode`] stores only an operation specification
384    /// (no resolved `CUfunction` or device pointers), every node is added
385    /// via `cuGraphAddEmptyNode`; the resulting driver graph preserves the
386    /// node count and dependency topology and executes as a DAG of
387    /// synchronisation barriers.
388    fn build_driver_graph(
389        &self,
390    ) -> CudaResult<(Option<crate::ffi::CUgraph>, Option<crate::ffi::CUgraphExec>)> {
391        use crate::ffi::{CUgraph, CUgraphExec, CUgraphNode};
392
393        let api = crate::loader::try_driver()?;
394
395        // Resolve every required graph entry point; a pre-10.0 driver lacks
396        // them and yields a clean NotSupported fallback.
397        let create = api.cu_graph_create.ok_or(CudaError::NotSupported)?;
398        let add_empty = api.cu_graph_add_empty_node.ok_or(CudaError::NotSupported)?;
399        let destroy = api.cu_graph_destroy.ok_or(CudaError::NotSupported)?;
400
401        // A topological order of the in-memory nodes — guaranteed acyclic
402        // because `instantiate` runs `topological_sort` first.
403        let order = self.topological_sort()?;
404
405        // 1. Create an empty CUgraph.
406        let mut raw_graph = CUgraph::default();
407        // SAFETY: `create` was just resolved from the driver; `raw_graph` is
408        // a valid out-pointer and flags=0 is the only documented value.
409        crate::error::check(unsafe { create(&mut raw_graph, 0) })?;
410
411        // From here on, any failure must destroy `raw_graph` before
412        // returning so the driver-side object does not leak.
413        let build = || -> CudaResult<CUgraphExec> {
414            // 2. Add one real driver node per in-memory node, in topological
415            //    order, wiring the incoming dependency edges as we go.
416            //    `driver_nodes[idx]` holds the driver handle for in-memory
417            //    node `idx` once it has been created.
418            let mut driver_nodes: Vec<Option<CUgraphNode>> = vec![None; self.nodes.len()];
419            for &node_idx in &order {
420                // Collect the driver handles of every node this node depends
421                // on — edges `(from, to)` with `to == node_idx`.  In a valid
422                // topological order every `from` precedes `node_idx`, so each
423                // handle is already present.
424                let mut deps: Vec<CUgraphNode> = Vec::new();
425                for &(from, to) in &self.dependencies {
426                    if to == node_idx {
427                        let handle = driver_nodes
428                            .get(from)
429                            .copied()
430                            .flatten()
431                            .ok_or(CudaError::InvalidValue)?;
432                        deps.push(handle);
433                    }
434                }
435
436                let dep_ptr = if deps.is_empty() {
437                    std::ptr::null()
438                } else {
439                    deps.as_ptr()
440                };
441
442                let mut driver_node = CUgraphNode::default();
443                // SAFETY: `add_empty` was resolved from the driver;
444                // `driver_node` is a valid out-pointer, `raw_graph` is the
445                // live graph created above, and `dep_ptr`/`deps.len()`
446                // describe a valid (possibly empty) dependency slice whose
447                // handles were all produced by earlier iterations.
448                crate::error::check(unsafe {
449                    add_empty(&mut driver_node, raw_graph, dep_ptr, deps.len())
450                })?;
451                driver_nodes[node_idx] = Some(driver_node);
452            }
453
454            // 3. Instantiate the populated graph into an executable form.
455            self.instantiate_driver_graph(api, raw_graph)
456        };
457
458        match build() {
459            Ok(raw_exec) => Ok((Some(raw_graph), Some(raw_exec))),
460            Err(e) => {
461                // SAFETY: `destroy` was resolved from the driver and
462                // `raw_graph` is the live handle created above.
463                let rc = unsafe { destroy(raw_graph) };
464                if rc != 0 {
465                    tracing::warn!(
466                        cuda_error = rc,
467                        "cuGraphDestroy failed while unwinding a failed instantiation"
468                    );
469                }
470                Err(e)
471            }
472        }
473    }
474
475    /// Finalise a populated `CUgraph` into an executable `CUgraphExec`.
476    ///
477    /// Prefers `cuGraphInstantiateWithFlags` (CUDA 11.4+) and falls back to
478    /// the legacy `cuGraphInstantiate_v2` signature.
479    fn instantiate_driver_graph(
480        &self,
481        api: &crate::loader::DriverApi,
482        raw_graph: crate::ffi::CUgraph,
483    ) -> CudaResult<crate::ffi::CUgraphExec> {
484        use crate::ffi::CUgraphExec;
485
486        let mut raw_exec = CUgraphExec::default();
487
488        if let Some(instantiate_flags) = api.cu_graph_instantiate_with_flags {
489            // SAFETY: `instantiate_flags` was resolved from the driver;
490            // `raw_exec` is a valid out-pointer, `raw_graph` is a live
491            // populated graph, and flags=0 requests default instantiation.
492            crate::error::check(unsafe { instantiate_flags(&mut raw_exec, raw_graph, 0) })?;
493            return Ok(raw_exec);
494        }
495
496        let instantiate = api.cu_graph_instantiate.ok_or(CudaError::NotSupported)?;
497        // SAFETY: `instantiate` was resolved from the driver; `raw_exec` is a
498        // valid out-pointer, `raw_graph` is a live populated graph, and
499        // passing null error-node / log-buffer pointers with a zero buffer
500        // size is the documented "no diagnostics" configuration.
501        crate::error::check(unsafe {
502            instantiate(
503                &mut raw_exec,
504                raw_graph,
505                std::ptr::null_mut(),
506                std::ptr::null_mut(),
507                0,
508            )
509        })?;
510        Ok(raw_exec)
511    }
512}
513
514impl std::fmt::Display for Graph {
515    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516        write!(
517            f,
518            "Graph({} nodes, {} deps)",
519            self.nodes.len(),
520            self.dependencies.len()
521        )
522    }
523}
524
525// ---------------------------------------------------------------------------
526// GraphExec — instantiated executable graph
527// ---------------------------------------------------------------------------
528
529/// An instantiated, executable graph.
530///
531/// Created by [`Graph::instantiate`], a `GraphExec` holds a snapshot of the
532/// graph and a pre-computed execution order.
533///
534/// # Driver backing
535///
536/// When a CUDA driver is available, `instantiate` builds a genuine
537/// `CUgraph` (`cuGraphCreate` + one `cuGraphAdd*Node` per in-memory node,
538/// with the dependency DAG wired through real `CUgraphNode` edges) and
539/// finalises it into a `CUgraphExec` via `cuGraphInstantiate`.  In that
540/// case [`launch`](Self::launch) issues a real `cuGraphLaunch`.
541///
542/// The in-memory [`GraphNode`] representation stores only an operation
543/// *specification* (kernel name, copy direction/size, memset size/value) —
544/// it carries no resolved `CUfunction` or device pointers.  Every node is
545/// therefore translated to a real `cuGraphAddEmptyNode`: the resulting
546/// driver graph reproduces the node count and dependency topology exactly
547/// and executes on the GPU as a DAG of synchronisation barriers.  The
548/// per-node dispatch in `Graph::build_driver_graph` is structured so that
549/// kernel / memcpy / memset nodes that gain concrete device operands can be
550/// promoted to `cuGraphAddKernelNode` / `cuGraphAddMemcpyNode` /
551/// `cuGraphAddMemsetNode` without further restructuring.
552///
553/// On macOS (or any host without a CUDA driver), no driver handles are
554/// created; the graph is still validated (topological sort) and
555/// [`launch`](Self::launch) returns [`CudaError::NotInitialized`].
556pub struct GraphExec {
557    graph: Graph,
558    execution_order: Vec<usize>,
559    /// Real `CUgraph` handle, when a driver backed instantiation.
560    raw_graph: Option<crate::ffi::CUgraph>,
561    /// Real `CUgraphExec` handle, when a driver backed instantiation.
562    raw_exec: Option<crate::ffi::CUgraphExec>,
563}
564
565impl GraphExec {
566    /// Launches the executable graph on the given stream.
567    ///
568    /// When this `GraphExec` is backed by a real `CUgraphExec`, this issues
569    /// `cuGraphLaunch(hGraphExec, hStream)`, submitting the entire graph to
570    /// the stream with minimal CPU overhead.  Otherwise it surfaces the
571    /// driver-load error.
572    ///
573    /// # Errors
574    ///
575    /// * [`CudaError::NotInitialized`] if the CUDA driver is not available
576    ///   (e.g. on macOS, or a host without an NVIDIA GPU).
577    /// * Any [`CudaError`] mapped from `cuGraphLaunch`.
578    pub fn launch(&self, stream: &Stream) -> CudaResult<()> {
579        let api = crate::loader::try_driver()?;
580
581        // A driver is present.  If instantiation produced a real executable
582        // graph, submit it; otherwise the driver lacks the graph API.
583        let raw_exec = self.raw_exec.ok_or(CudaError::NotSupported)?;
584        let launch = api.cu_graph_launch.ok_or(CudaError::NotSupported)?;
585
586        // SAFETY: `launch` was just resolved from the driver; `raw_exec` is a
587        // live `CUgraphExec` produced by `cuGraphInstantiate` and kept alive
588        // by `self`, and `stream.raw()` is a valid `CUstream`.
589        crate::error::check(unsafe { launch(raw_exec, stream.raw()) })
590    }
591
592    /// Returns a reference to the underlying graph.
593    #[inline]
594    pub fn graph(&self) -> &Graph {
595        &self.graph
596    }
597
598    /// Returns the pre-computed execution order (topological sort).
599    #[inline]
600    pub fn execution_order(&self) -> &[usize] {
601        &self.execution_order
602    }
603
604    /// Returns the total number of nodes that would be executed.
605    #[inline]
606    pub fn node_count(&self) -> usize {
607        self.graph.node_count()
608    }
609
610    /// Returns `true` if this `GraphExec` is backed by a real, live
611    /// `CUgraphExec` driver handle (as opposed to a CPU-side-only graph).
612    #[inline]
613    pub fn is_driver_backed(&self) -> bool {
614        self.raw_exec.is_some()
615    }
616}
617
618impl std::fmt::Debug for GraphExec {
619    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620        f.debug_struct("GraphExec")
621            .field("graph", &self.graph)
622            .field("execution_order", &self.execution_order)
623            .field("driver_backed", &self.is_driver_backed())
624            .finish()
625    }
626}
627
628impl Drop for GraphExec {
629    fn drop(&mut self) {
630        // Release driver handles in reverse construction order: the
631        // executable graph first, then the source graph.
632        if let Ok(api) = crate::loader::try_driver() {
633            if let (Some(exec), Some(destroy)) = (self.raw_exec, api.cu_graph_exec_destroy) {
634                // SAFETY: `destroy` was resolved from the driver and `exec`
635                // is a live handle produced by `cuGraphInstantiate`.
636                let rc = unsafe { destroy(exec) };
637                if rc != 0 {
638                    tracing::warn!(cuda_error = rc, "cuGraphExecDestroy failed during drop");
639                }
640            }
641            if let (Some(graph), Some(destroy)) = (self.raw_graph, api.cu_graph_destroy) {
642                // SAFETY: `destroy` was resolved from the driver and `graph`
643                // is a live handle produced by `cuGraphCreate`.
644                let rc = unsafe { destroy(graph) };
645                if rc != 0 {
646                    tracing::warn!(cuda_error = rc, "cuGraphDestroy failed during drop");
647                }
648            }
649        }
650    }
651}
652
653// ---------------------------------------------------------------------------
654// StreamCapture — capture operations into a graph
655// ---------------------------------------------------------------------------
656
657/// Records GPU operations submitted to a stream into a [`Graph`].
658///
659/// Stream capture intercepts operations that would normally be submitted
660/// to a CUDA stream and instead records them as graph nodes. The captured
661/// operations can then be replayed efficiently via [`GraphExec`].
662///
663/// # Usage
664///
665/// ```rust,no_run
666/// # use oxicuda_driver::graph::{StreamCapture, MemcpyDirection};
667/// # use oxicuda_driver::stream::Stream;
668/// # use std::sync::Arc;
669/// # use oxicuda_driver::context::Context;
670/// # fn main() -> oxicuda_driver::CudaResult<()> {
671/// # let ctx: Arc<Context> = unimplemented!();
672/// # let stream = Stream::new(&ctx)?;
673/// let mut capture = StreamCapture::begin(&stream)?;
674///
675/// capture.record_kernel("my_kernel", (4, 1, 1), (256, 1, 1), 0);
676/// capture.record_memcpy(MemcpyDirection::DeviceToHost, 1024);
677///
678/// let graph = capture.end()?;
679/// assert_eq!(graph.node_count(), 2);
680/// # Ok(())
681/// # }
682/// ```
683pub struct StreamCapture {
684    nodes: Vec<GraphNode>,
685    /// Whether capture is still active (not yet ended).
686    active: bool,
687}
688
689impl StreamCapture {
690    /// Begins capturing operations on the given stream.
691    ///
692    /// On a real CUDA system, this would call
693    /// `cuStreamBeginCapture(stream, CU_STREAM_CAPTURE_MODE_GLOBAL)`.
694    ///
695    /// # Errors
696    ///
697    /// Returns [`CudaError::NotInitialized`] if the CUDA driver is not
698    /// available.
699    pub fn begin(_stream: &Stream) -> CudaResult<Self> {
700        // Validate that the driver is available.
701        let _api = crate::loader::try_driver()?;
702        Ok(Self {
703            nodes: Vec::new(),
704            active: true,
705        })
706    }
707
708    /// Records a kernel launch operation in the capture.
709    ///
710    /// # Parameters
711    ///
712    /// * `function_name` - Name of the kernel function.
713    /// * `grid` - Grid dimensions `(x, y, z)`.
714    /// * `block` - Block dimensions `(x, y, z)`.
715    /// * `shared_mem` - Dynamic shared memory in bytes.
716    pub fn record_kernel(
717        &mut self,
718        function_name: &str,
719        grid: (u32, u32, u32),
720        block: (u32, u32, u32),
721        shared_mem: u32,
722    ) {
723        if self.active {
724            self.nodes.push(GraphNode::KernelLaunch {
725                function_name: function_name.to_owned(),
726                grid,
727                block,
728                shared_mem,
729            });
730        }
731    }
732
733    /// Records a memory copy operation in the capture.
734    ///
735    /// # Parameters
736    ///
737    /// * `direction` - Direction of the memory copy.
738    /// * `size` - Size of the transfer in bytes.
739    pub fn record_memcpy(&mut self, direction: MemcpyDirection, size: usize) {
740        if self.active {
741            self.nodes.push(GraphNode::Memcpy { direction, size });
742        }
743    }
744
745    /// Records a memset operation in the capture.
746    ///
747    /// # Parameters
748    ///
749    /// * `size` - Number of bytes to set.
750    /// * `value` - Byte value to fill with.
751    pub fn record_memset(&mut self, size: usize, value: u8) {
752        if self.active {
753            self.nodes.push(GraphNode::Memset { size, value });
754        }
755    }
756
757    /// Returns the number of operations recorded so far.
758    #[inline]
759    pub fn recorded_count(&self) -> usize {
760        self.nodes.len()
761    }
762
763    /// Returns whether the capture is still active.
764    #[inline]
765    pub fn is_active(&self) -> bool {
766        self.active
767    }
768
769    /// Ends the capture and returns the resulting [`Graph`].
770    ///
771    /// On a real CUDA system, this would call `cuStreamEndCapture`
772    /// and return the captured graph handle.
773    ///
774    /// The captured nodes are connected in a linear chain (each node
775    /// depends on the previous one) to preserve the order in which
776    /// operations were recorded.
777    ///
778    /// # Errors
779    ///
780    /// Returns [`CudaError::StreamCaptureUnmatched`] if the capture
781    /// was already ended.
782    pub fn end(mut self) -> CudaResult<Graph> {
783        if !self.active {
784            return Err(CudaError::StreamCaptureUnmatched);
785        }
786        self.active = false;
787
788        let mut graph = Graph::new();
789        let mut prev_idx: Option<usize> = None;
790
791        for node in self.nodes.drain(..) {
792            let idx = graph.nodes.len();
793            graph.nodes.push(node);
794
795            // Chain each node after the previous to maintain order.
796            if let Some(prev) = prev_idx {
797                graph.dependencies.push((prev, idx));
798            }
799            prev_idx = Some(idx);
800        }
801
802        Ok(graph)
803    }
804}
805
806// ---------------------------------------------------------------------------
807// Tests
808// ---------------------------------------------------------------------------
809
810#[cfg(test)]
811mod tests {
812    use super::*;
813
814    #[test]
815    fn graph_new_is_empty() {
816        let g = Graph::new();
817        assert_eq!(g.node_count(), 0);
818        assert_eq!(g.dependency_count(), 0);
819        assert!(g.nodes().is_empty());
820        assert!(g.dependencies().is_empty());
821    }
822
823    #[test]
824    fn graph_default_is_empty() {
825        let g = Graph::default();
826        assert_eq!(g.node_count(), 0);
827    }
828
829    #[test]
830    fn add_kernel_node_returns_sequential_indices() {
831        let mut g = Graph::new();
832        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
833        let n1 = g.add_kernel_node("k1", (2, 1, 1), (64, 1, 1), 128);
834        assert_eq!(n0, 0);
835        assert_eq!(n1, 1);
836        assert_eq!(g.node_count(), 2);
837    }
838
839    #[test]
840    fn add_memcpy_node_records_direction_and_size() {
841        let mut g = Graph::new();
842        let idx = g.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
843        assert_eq!(idx, 0);
844        let node = g.get_node(0);
845        assert!(node.is_some());
846        if let Some(GraphNode::Memcpy { direction, size }) = node {
847            assert_eq!(*direction, MemcpyDirection::HostToDevice);
848            assert_eq!(*size, 4096);
849        } else {
850            panic!("expected Memcpy node");
851        }
852    }
853
854    #[test]
855    fn add_memset_node_records_size_and_value() {
856        let mut g = Graph::new();
857        let idx = g.add_memset_node(8192, 0xAB);
858        assert_eq!(idx, 0);
859        if let Some(GraphNode::Memset { size, value }) = g.get_node(idx) {
860            assert_eq!(*size, 8192);
861            assert_eq!(*value, 0xAB);
862        } else {
863            panic!("expected Memset node");
864        }
865    }
866
867    #[test]
868    fn add_empty_node_works() {
869        let mut g = Graph::new();
870        let idx = g.add_empty_node();
871        assert_eq!(idx, 0);
872        assert_eq!(g.get_node(idx), Some(&GraphNode::Empty));
873    }
874
875    #[test]
876    fn add_dependency_valid() {
877        let mut g = Graph::new();
878        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
879        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
880        assert!(g.add_dependency(n0, n1).is_ok());
881        assert_eq!(g.dependency_count(), 1);
882        assert_eq!(g.dependencies()[0], (0, 1));
883    }
884
885    #[test]
886    fn add_dependency_out_of_bounds() {
887        let mut g = Graph::new();
888        let _n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
889        let result = g.add_dependency(0, 5);
890        assert_eq!(result, Err(CudaError::InvalidValue));
891    }
892
893    #[test]
894    fn add_dependency_self_loop() {
895        let mut g = Graph::new();
896        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
897        let result = g.add_dependency(n0, n0);
898        assert_eq!(result, Err(CudaError::InvalidValue));
899    }
900
901    #[test]
902    fn topological_sort_linear_chain() {
903        let mut g = Graph::new();
904        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
905        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
906        let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
907        g.add_dependency(n0, n1).ok();
908        g.add_dependency(n1, n2).ok();
909
910        let order = g.topological_sort();
911        assert!(order.is_ok());
912        let order = order.ok();
913        assert!(order.is_some());
914        let order = order.unwrap_or_default();
915        // n0 must come before n1, n1 before n2
916        let pos = |n: usize| -> usize { order.iter().position(|&x| x == n).unwrap_or(usize::MAX) };
917        assert!(pos(n0) < pos(n1));
918        assert!(pos(n1) < pos(n2));
919    }
920
921    #[test]
922    fn topological_sort_detects_cycle() {
923        let mut g = Graph::new();
924        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
925        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
926        g.add_dependency(n0, n1).ok();
927        g.add_dependency(n1, n0).ok();
928
929        let result = g.topological_sort();
930        assert_eq!(result, Err(CudaError::InvalidValue));
931    }
932
933    #[test]
934    fn topological_sort_no_deps() {
935        let mut g = Graph::new();
936        g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
937        g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
938        g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
939
940        let order = g.topological_sort();
941        assert!(order.is_ok());
942        let order = order.unwrap_or_default();
943        assert_eq!(order.len(), 3);
944    }
945
946    #[test]
947    fn instantiate_valid_graph() {
948        let mut g = Graph::new();
949        let n0 = g.add_memcpy_node(MemcpyDirection::HostToDevice, 1024);
950        let n1 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
951        let n2 = g.add_memcpy_node(MemcpyDirection::DeviceToHost, 1024);
952        g.add_dependency(n0, n1).ok();
953        g.add_dependency(n1, n2).ok();
954
955        let exec = g.instantiate();
956        assert!(exec.is_ok());
957        let exec = exec.ok();
958        assert!(exec.is_some());
959        if let Some(exec) = exec {
960            assert_eq!(exec.node_count(), 3);
961            assert_eq!(exec.execution_order().len(), 3);
962        }
963    }
964
965    #[test]
966    fn instantiate_cyclic_graph_fails() {
967        let mut g = Graph::new();
968        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
969        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
970        g.add_dependency(n0, n1).ok();
971        g.add_dependency(n1, n0).ok();
972
973        let result = g.instantiate();
974        assert!(result.is_err());
975    }
976
977    #[test]
978    fn graph_display() {
979        let mut g = Graph::new();
980        g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
981        g.add_memcpy_node(MemcpyDirection::HostToDevice, 512);
982        let disp = format!("{g}");
983        assert!(disp.contains("2 nodes"));
984        assert!(disp.contains("0 deps"));
985    }
986
987    #[test]
988    fn node_display() {
989        let node = GraphNode::KernelLaunch {
990            function_name: "foo".to_owned(),
991            grid: (4, 1, 1),
992            block: (256, 1, 1),
993            shared_mem: 0,
994        };
995        let disp = format!("{node}");
996        assert!(disp.contains("foo"));
997
998        let node = GraphNode::Memcpy {
999            direction: MemcpyDirection::DeviceToHost,
1000            size: 1024,
1001        };
1002        let disp = format!("{node}");
1003        assert!(disp.contains("DtoH"));
1004
1005        let node = GraphNode::Memset {
1006            size: 256,
1007            value: 0xFF,
1008        };
1009        let disp = format!("{node}");
1010        assert!(disp.contains("0xff"));
1011
1012        let node = GraphNode::Empty;
1013        let disp = format!("{node}");
1014        assert!(disp.contains("Empty"));
1015    }
1016
1017    #[test]
1018    fn memcpy_direction_display() {
1019        assert_eq!(format!("{}", MemcpyDirection::HostToDevice), "HtoD");
1020        assert_eq!(format!("{}", MemcpyDirection::DeviceToHost), "DtoH");
1021        assert_eq!(format!("{}", MemcpyDirection::DeviceToDevice), "DtoD");
1022    }
1023
1024    #[test]
1025    fn graph_get_node_out_of_bounds() {
1026        let g = Graph::new();
1027        assert!(g.get_node(0).is_none());
1028        assert!(g.get_node(100).is_none());
1029    }
1030
1031    #[test]
1032    fn graph_diamond_dag() {
1033        // Diamond: n0 -> n1, n0 -> n2, n1 -> n3, n2 -> n3
1034        let mut g = Graph::new();
1035        let n0 = g.add_empty_node();
1036        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
1037        let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
1038        let n3 = g.add_empty_node();
1039        g.add_dependency(n0, n1).ok();
1040        g.add_dependency(n0, n2).ok();
1041        g.add_dependency(n1, n3).ok();
1042        g.add_dependency(n2, n3).ok();
1043
1044        let order = g.topological_sort().unwrap_or_default();
1045        assert_eq!(order.len(), 4);
1046        let pos = |n: usize| -> usize { order.iter().position(|&x| x == n).unwrap_or(usize::MAX) };
1047        assert!(pos(n0) < pos(n1));
1048        assert!(pos(n0) < pos(n2));
1049        assert!(pos(n1) < pos(n3));
1050        assert!(pos(n2) < pos(n3));
1051
1052        let exec = g.instantiate();
1053        assert!(exec.is_ok());
1054    }
1055
1056    #[test]
1057    fn graph_exec_debug() {
1058        let mut g = Graph::new();
1059        g.add_empty_node();
1060        let exec = g.instantiate().ok();
1061        assert!(exec.is_some());
1062        if let Some(exec) = exec {
1063            let dbg = format!("{exec:?}");
1064            assert!(dbg.contains("GraphExec"));
1065            // The debug output advertises the driver-backed status.
1066            assert!(dbg.contains("driver_backed"));
1067        }
1068    }
1069
1070    // -- Driver-backed instantiation ---------------------------------------
1071    //
1072    // `instantiate` builds a real `CUgraph`/`CUgraphExec` when a driver is
1073    // present, and a CPU-side-only `GraphExec` otherwise.  On a host with no
1074    // CUDA driver every path below must still produce a valid `GraphExec`
1075    // (clean fallback) — never a panic, never an error from the missing
1076    // driver alone.
1077
1078    /// Returns `true` when a real CUDA driver is loadable on this host.
1079    fn driver_present() -> bool {
1080        crate::loader::try_driver().is_ok()
1081    }
1082
1083    /// Instantiating an empty graph succeeds; without a driver the result
1084    /// is a CPU-side-only `GraphExec`.
1085    #[test]
1086    fn instantiate_empty_graph_driver_state() {
1087        let g = Graph::new();
1088        let exec = g.instantiate().expect("empty graph instantiates");
1089        assert_eq!(exec.node_count(), 0);
1090        if driver_present() {
1091            // A live driver either backs the graph or, on a graphless
1092            // driver, leaves it CPU-side — both are valid, typed outcomes.
1093            let _ = exec.is_driver_backed();
1094        } else {
1095            assert!(!exec.is_driver_backed());
1096        }
1097    }
1098
1099    /// A linear-chain graph instantiates and preserves topology; the
1100    /// `GraphExec` reports a consistent driver-backed flag.
1101    #[test]
1102    fn instantiate_chain_preserves_topology() {
1103        let mut g = Graph::new();
1104        let n0 = g.add_memset_node(256, 0);
1105        let n1 = g.add_kernel_node("k", (1, 1, 1), (32, 1, 1), 0);
1106        let n2 = g.add_memcpy_node(MemcpyDirection::DeviceToHost, 256);
1107        g.add_dependency(n0, n1).ok();
1108        g.add_dependency(n1, n2).ok();
1109
1110        let exec = g.instantiate().expect("chain instantiates");
1111        assert_eq!(exec.node_count(), 3);
1112        assert_eq!(exec.execution_order().len(), 3);
1113        if !driver_present() {
1114            assert!(!exec.is_driver_backed());
1115        }
1116    }
1117
1118    /// A diamond DAG instantiates without a driver to a CPU-side `GraphExec`.
1119    #[test]
1120    fn instantiate_diamond_without_driver_is_clean() {
1121        let mut g = Graph::new();
1122        let n0 = g.add_empty_node();
1123        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
1124        let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
1125        let n3 = g.add_empty_node();
1126        g.add_dependency(n0, n1).ok();
1127        g.add_dependency(n0, n2).ok();
1128        g.add_dependency(n1, n3).ok();
1129        g.add_dependency(n2, n3).ok();
1130
1131        let exec = g.instantiate();
1132        assert!(exec.is_ok(), "diamond DAG must instantiate cleanly");
1133        if !driver_present() {
1134            if let Ok(exec) = exec {
1135                assert!(!exec.is_driver_backed());
1136            }
1137        }
1138    }
1139
1140    /// `build_driver_graph` surfaces a clean typed error on a host with no
1141    /// driver — `NotInitialized`, never a panic.
1142    #[test]
1143    fn build_driver_graph_absent_driver_is_clean() {
1144        let mut g = Graph::new();
1145        g.add_empty_node();
1146        let result = g.build_driver_graph();
1147        if driver_present() {
1148            // Live driver: either real handles, or a typed driver error.
1149            match result {
1150                Ok((raw_graph, raw_exec)) => {
1151                    assert_eq!(raw_graph.is_some(), raw_exec.is_some());
1152                }
1153                Err(_) => { /* typed driver error is acceptable */ }
1154            }
1155        } else {
1156            assert_eq!(result.err(), Some(CudaError::NotInitialized));
1157        }
1158    }
1159
1160    /// Dropping a CPU-side-only `GraphExec` must not panic (the `Drop` impl
1161    /// only touches driver handles when both they and the driver exist).
1162    #[test]
1163    fn graph_exec_drop_without_driver_is_safe() {
1164        let mut g = Graph::new();
1165        g.add_empty_node();
1166        g.add_empty_node();
1167        let exec = g.instantiate().expect("instantiates");
1168        // Explicit drop — must complete without panicking.
1169        drop(exec);
1170    }
1171
1172    /// A cyclic graph fails instantiation at the topological-sort stage,
1173    /// before any driver call is attempted.
1174    #[test]
1175    fn instantiate_cycle_fails_before_driver() {
1176        let mut g = Graph::new();
1177        let n0 = g.add_empty_node();
1178        let n1 = g.add_empty_node();
1179        g.add_dependency(n0, n1).ok();
1180        g.add_dependency(n1, n0).ok();
1181        assert_eq!(g.instantiate().err(), Some(CudaError::InvalidValue));
1182    }
1183
1184    // -- End-to-end real-GPU graph execution -------------------------------
1185    //
1186    // When this host has a usable GPU, build a CUDA context (which makes it
1187    // current), instantiate a real driver-backed graph, and launch it via
1188    // `cuGraphLaunch`.  On a host without a GPU the test is a clean no-op.
1189
1190    /// Instantiate and launch a real diamond-DAG graph on the GPU.
1191    #[test]
1192    fn real_graph_instantiate_and_launch() {
1193        use crate::context::Context;
1194        use crate::device::Device;
1195
1196        // No GPU on this host — nothing to exercise.
1197        let device = match Device::get(0) {
1198            Ok(d) => d,
1199            Err(_) => return,
1200        };
1201        // Creating the context makes it current on this thread, which the
1202        // CUDA Graph API requires.
1203        let ctx = match Context::new(&device) {
1204            Ok(c) => std::sync::Arc::new(c),
1205            Err(_) => return,
1206        };
1207        let stream = match Stream::new(&ctx) {
1208            Ok(s) => s,
1209            Err(_) => return,
1210        };
1211
1212        // Diamond DAG: n0 -> {n1, n2} -> n3.
1213        let mut g = Graph::new();
1214        let n0 = g.add_empty_node();
1215        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
1216        let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
1217        let n3 = g.add_empty_node();
1218        g.add_dependency(n0, n1).ok();
1219        g.add_dependency(n0, n2).ok();
1220        g.add_dependency(n1, n3).ok();
1221        g.add_dependency(n2, n3).ok();
1222
1223        let exec = g.instantiate().expect("diamond DAG instantiates");
1224        assert_eq!(exec.node_count(), 4);
1225
1226        // With a context current and a graph-capable driver, the graph must
1227        // be driver-backed and `cuGraphLaunch` must succeed.
1228        if exec.is_driver_backed() {
1229            exec.launch(&stream)
1230                .expect("cuGraphLaunch on a real graph succeeds");
1231            stream
1232                .synchronize()
1233                .expect("stream synchronises after graph launch");
1234        }
1235    }
1236
1237    /// A driver-backed graph can be relaunched repeatedly on the same stream.
1238    #[test]
1239    fn real_graph_repeated_launch() {
1240        use crate::context::Context;
1241        use crate::device::Device;
1242
1243        let device = match Device::get(0) {
1244            Ok(d) => d,
1245            Err(_) => return,
1246        };
1247        let ctx = match Context::new(&device) {
1248            Ok(c) => std::sync::Arc::new(c),
1249            Err(_) => return,
1250        };
1251        let stream = match Stream::new(&ctx) {
1252            Ok(s) => s,
1253            Err(_) => return,
1254        };
1255
1256        let mut g = Graph::new();
1257        let a = g.add_empty_node();
1258        let b = g.add_empty_node();
1259        g.add_dependency(a, b).ok();
1260
1261        let exec = g.instantiate().expect("chain instantiates");
1262        if exec.is_driver_backed() {
1263            // The whole point of a graph: cheap repeated submission.
1264            for _ in 0..8 {
1265                exec.launch(&stream)
1266                    .expect("repeated cuGraphLaunch succeeds");
1267            }
1268            stream.synchronize().expect("stream synchronises");
1269        }
1270    }
1271}