oxicuda_driver/graph.rs
1//! CUDA Graph API for recording and replaying sequences of GPU operations.
2//!
3//! CUDA Graphs allow capturing a sequence of operations (kernel launches,
4//! memory copies, memsets) into a graph data structure that can be
5//! instantiated and launched repeatedly with minimal CPU overhead.
6//!
7//! # Architecture
8//!
//! This module exposes a Rust-side graph representation that records
//! operations as nodes with explicit dependency edges. [`Graph::instantiate`]
//! translates that representation into the native CUDA Graph API
//! (`cuGraphCreate`, `cuGraphAddEmptyNode`, `cuGraphInstantiate`) whenever a
//! CUDA driver is available, and [`GraphExec::launch`] issues a real
//! `cuGraphLaunch`. Because a [`GraphNode`] stores only an operation
//! specification (no resolved function handle or device pointers), every
//! node is currently materialised as an empty driver node: the driver graph
//! reproduces the node count and dependency topology, not the work itself.
//! On macOS (or any host without a driver) the graph is still built and
//! validated CPU-side, and launching reports [`CudaError::NotInitialized`].
17//!
18//! # Example
19//!
20//! ```rust,no_run
21//! # use oxicuda_driver::graph::{Graph, GraphNode, MemcpyDirection};
22//! let mut graph = Graph::new();
23//!
24//! let n0 = graph.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
25//! let n1 = graph.add_kernel_node(
26//! "vector_add",
27//! (4, 1, 1),
28//! (256, 1, 1),
29//! 0,
30//! );
31//! let n2 = graph.add_memcpy_node(MemcpyDirection::DeviceToHost, 4096);
32//!
33//! graph.add_dependency(n0, n1).ok();
34//! graph.add_dependency(n1, n2).ok();
35//!
36//! assert_eq!(graph.node_count(), 3);
37//! assert_eq!(graph.dependency_count(), 2);
38//! ```
39
40use crate::error::{CudaError, CudaResult};
41use crate::stream::Stream;
42
43// ---------------------------------------------------------------------------
44// GraphNode — individual operation in a graph
45// ---------------------------------------------------------------------------
46
/// Transfer direction of a memory-copy node recorded in a graph.
///
/// The [`Display`](std::fmt::Display) form is a short tag: `HtoD`, `DtoH`
/// or `DtoD`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemcpyDirection {
    /// Copy from host memory to device memory.
    HostToDevice,
    /// Copy from device memory to host memory.
    DeviceToHost,
    /// Copy from device memory to device memory.
    DeviceToDevice,
}

impl std::fmt::Display for MemcpyDirection {
    /// Writes the short tag for this direction, ignoring width/fill options.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match *self {
            Self::HostToDevice => "HtoD",
            Self::DeviceToHost => "DtoH",
            Self::DeviceToDevice => "DtoD",
        };
        f.write_str(tag)
    }
}
67
/// A single operation node within a [`Graph`].
///
/// Each variant describes one kind of GPU operation that can be recorded
/// into a graph. A node is a pure *specification*: it holds names, launch
/// dimensions, sizes and fill values, but no function handle and no
/// source/destination addresses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GraphNode {
    /// A kernel launch with grid/block configuration.
    KernelLaunch {
        /// Name of the kernel function (stored as an owned string).
        function_name: String,
        /// Grid dimensions `(x, y, z)`.
        grid: (u32, u32, u32),
        /// Block dimensions `(x, y, z)`.
        block: (u32, u32, u32),
        /// Dynamic shared memory in bytes.
        shared_mem: u32,
    },
    /// A memory copy operation.
    Memcpy {
        /// Direction of the copy.
        direction: MemcpyDirection,
        /// Size of the transfer in bytes.
        size: usize,
    },
    /// A memset operation (fill device memory with a byte value).
    Memset {
        /// Number of bytes to set.
        size: usize,
        /// Byte value to fill with.
        value: u8,
    },
    /// An empty/no-op node used as a synchronisation barrier: it carries no
    /// payload and is only meaningful through its dependency edges.
    Empty,
}
102
103impl std::fmt::Display for GraphNode {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 Self::KernelLaunch {
107 function_name,
108 grid,
109 block,
110 shared_mem,
111 } => write!(
112 f,
113 "Kernel({}, grid=({},{},{}), block=({},{},{}), smem={})",
114 function_name, grid.0, grid.1, grid.2, block.0, block.1, block.2, shared_mem,
115 ),
116 Self::Memcpy { direction, size } => {
117 write!(f, "Memcpy({direction}, {size} bytes)")
118 }
119 Self::Memset { size, value } => {
120 write!(f, "Memset({size} bytes, value=0x{value:02x})")
121 }
122 Self::Empty => write!(f, "Empty"),
123 }
124 }
125}
126
127// ---------------------------------------------------------------------------
128// Graph — collection of nodes with dependency edges
129// ---------------------------------------------------------------------------
130
131/// A CUDA graph representing a DAG of GPU operations.
132///
133/// Nodes represent individual operations (kernel launches, memory copies,
134/// memsets, or empty barriers). Dependencies are directed edges that
135/// enforce execution ordering between nodes.
136///
137/// The graph can be instantiated into a [`GraphExec`] for repeated
138/// low-overhead execution.
139#[derive(Debug, Clone)]
140pub struct Graph {
141 nodes: Vec<GraphNode>,
142 dependencies: Vec<(usize, usize)>,
143}
144
145impl Default for Graph {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
151impl Graph {
152 /// Creates a new empty graph with no nodes or dependencies.
153 pub fn new() -> Self {
154 Self {
155 nodes: Vec::new(),
156 dependencies: Vec::new(),
157 }
158 }
159
160 /// Adds a kernel launch node to the graph.
161 ///
162 /// Returns the index of the newly created node, which can be used
163 /// to establish dependencies via [`add_dependency`](Self::add_dependency).
164 ///
165 /// # Parameters
166 ///
167 /// * `function_name` - Name of the kernel function.
168 /// * `grid` - Grid dimensions `(x, y, z)`.
169 /// * `block` - Block dimensions `(x, y, z)`.
170 /// * `shared_mem` - Dynamic shared memory in bytes.
171 pub fn add_kernel_node(
172 &mut self,
173 function_name: &str,
174 grid: (u32, u32, u32),
175 block: (u32, u32, u32),
176 shared_mem: u32,
177 ) -> usize {
178 let idx = self.nodes.len();
179 self.nodes.push(GraphNode::KernelLaunch {
180 function_name: function_name.to_owned(),
181 grid,
182 block,
183 shared_mem,
184 });
185 idx
186 }
187
188 /// Adds a memory copy node to the graph.
189 ///
190 /// Returns the index of the newly created node.
191 ///
192 /// # Parameters
193 ///
194 /// * `direction` - Direction of the memory copy.
195 /// * `size` - Size of the transfer in bytes.
196 pub fn add_memcpy_node(&mut self, direction: MemcpyDirection, size: usize) -> usize {
197 let idx = self.nodes.len();
198 self.nodes.push(GraphNode::Memcpy { direction, size });
199 idx
200 }
201
202 /// Adds a memset node to the graph.
203 ///
204 /// Returns the index of the newly created node.
205 ///
206 /// # Parameters
207 ///
208 /// * `size` - Number of bytes to set.
209 /// * `value` - Byte value to fill with.
210 pub fn add_memset_node(&mut self, size: usize, value: u8) -> usize {
211 let idx = self.nodes.len();
212 self.nodes.push(GraphNode::Memset { size, value });
213 idx
214 }
215
216 /// Adds an empty (no-op) node to the graph.
217 ///
218 /// Empty nodes are useful as synchronisation barriers — they have
219 /// no work of their own but can serve as join points for multiple
220 /// dependency chains.
221 ///
222 /// Returns the index of the newly created node.
223 pub fn add_empty_node(&mut self) -> usize {
224 let idx = self.nodes.len();
225 self.nodes.push(GraphNode::Empty);
226 idx
227 }
228
229 /// Adds a dependency edge from node `from` to node `to`.
230 ///
231 /// This means `to` will not begin execution until `from` has
232 /// completed. Both indices must refer to existing nodes.
233 ///
234 /// # Errors
235 ///
236 /// Returns [`CudaError::InvalidValue`] if either index is out of bounds
237 /// or if `from == to` (self-dependency).
238 pub fn add_dependency(&mut self, from: usize, to: usize) -> CudaResult<()> {
239 if from >= self.nodes.len() || to >= self.nodes.len() {
240 return Err(CudaError::InvalidValue);
241 }
242 if from == to {
243 return Err(CudaError::InvalidValue);
244 }
245 self.dependencies.push((from, to));
246 Ok(())
247 }
248
249 /// Returns the total number of nodes in the graph.
250 #[inline]
251 pub fn node_count(&self) -> usize {
252 self.nodes.len()
253 }
254
255 /// Returns the total number of dependency edges in the graph.
256 #[inline]
257 pub fn dependency_count(&self) -> usize {
258 self.dependencies.len()
259 }
260
261 /// Returns a slice of all nodes in insertion order.
262 #[inline]
263 pub fn nodes(&self) -> &[GraphNode] {
264 &self.nodes
265 }
266
267 /// Returns a slice of all dependency edges as `(from, to)` pairs.
268 #[inline]
269 pub fn dependencies(&self) -> &[(usize, usize)] {
270 &self.dependencies
271 }
272
273 /// Returns the node at the given index, or `None` if out of bounds.
274 pub fn get_node(&self, index: usize) -> Option<&GraphNode> {
275 self.nodes.get(index)
276 }
277
278 /// Performs a topological sort of the graph nodes.
279 ///
280 /// Returns the node indices in an order that respects all
281 /// dependency edges, or an error if the graph contains a cycle.
282 ///
283 /// # Errors
284 ///
285 /// Returns [`CudaError::InvalidValue`] if the graph contains a
286 /// dependency cycle.
287 pub fn topological_sort(&self) -> CudaResult<Vec<usize>> {
288 let n = self.nodes.len();
289 let mut in_degree = vec![0u32; n];
290 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
291
292 for &(from, to) in &self.dependencies {
293 adj[from].push(to);
294 in_degree[to] = in_degree[to].saturating_add(1);
295 }
296
297 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
298 let mut result = Vec::with_capacity(n);
299
300 while let Some(node) = queue.pop() {
301 result.push(node);
302 for &next in &adj[node] {
303 in_degree[next] = in_degree[next].saturating_sub(1);
304 if in_degree[next] == 0 {
305 queue.push(next);
306 }
307 }
308 }
309
310 if result.len() != n {
311 return Err(CudaError::InvalidValue);
312 }
313
314 Ok(result)
315 }
316
317 /// Instantiates the graph into an executable form.
318 ///
319 /// The returned [`GraphExec`] can be launched on a stream with minimal
320 /// CPU overhead. The graph is always validated (topological sort)
321 /// during instantiation.
322 ///
323 /// When a CUDA driver is available, a genuine `CUgraph` is built
324 /// (`cuGraphCreate` + per-node `cuGraphAdd*Node` with the dependency DAG
325 /// wired through real `CUgraphNode` edges) and finalised into a
326 /// `CUgraphExec` via `cuGraphInstantiate`; [`GraphExec::launch`] then
327 /// issues a real `cuGraphLaunch`. Without a driver (macOS, or a host
328 /// with no GPU) the `GraphExec` is CPU-side only and `launch` reports
329 /// [`CudaError::NotInitialized`].
330 ///
331 /// # Errors
332 ///
333 /// * [`CudaError::InvalidValue`] if the graph contains a dependency
334 /// cycle.
335 /// * Any [`CudaError`] mapped from a failing `cuGraph*` driver call
336 /// when a driver is present (e.g. [`CudaError::OutOfMemory`]).
337 pub fn instantiate(&self) -> CudaResult<GraphExec> {
338 // Validate the graph is a DAG by performing a topological sort.
339 // This must succeed regardless of driver availability.
340 let execution_order = self.topological_sort()?;
341
342 // Attempt a real driver-backed instantiation. Fall back to a
343 // CPU-side-only GraphExec for environmental reasons — no driver, no
344 // GPU, no current CUDA context, or a driver predating the graph API
345 // — since none of those indicate a malformed graph. A genuine
346 // graph-construction failure (e.g. OutOfMemory, InvalidValue) is a
347 // real error and propagates to the caller.
348 let (raw_graph, raw_exec) = match self.build_driver_graph() {
349 Ok(handles) => handles,
350 Err(
351 CudaError::NotInitialized
352 | CudaError::NotSupported
353 | CudaError::InvalidContext
354 | CudaError::NoDevice
355 | CudaError::InvalidDevice
356 | CudaError::Deinitialized,
357 ) => (None, None),
358 Err(other) => return Err(other),
359 };
360
361 Ok(GraphExec {
362 graph: self.clone(),
363 execution_order,
364 raw_graph,
365 raw_exec,
366 })
367 }
368
369 /// Build a real CUDA driver graph from this in-memory representation.
370 ///
371 /// Returns `(Some(CUgraph), Some(CUgraphExec))` on success. Returns
372 /// [`CudaError::NotInitialized`] when no driver is loaded and
373 /// [`CudaError::NotSupported`] when the loaded driver predates the CUDA
374 /// Graph API; [`Graph::instantiate`] turns both (and other environmental
375 /// errors) into a CPU-side-only `GraphExec`. Any other error is a
376 /// genuine driver failure.
377 ///
378 /// Each in-memory [`GraphNode`] is translated to a real driver node and
379 /// the dependency edges are reproduced exactly. Nodes are created in
380 /// topological order so that, when `cuGraphAddEmptyNode` is given a
381 /// node's dependency list, every referenced `CUgraphNode` already
382 /// exists — regardless of the order edges were added to the in-memory
383 /// graph. Because [`GraphNode`] stores only an operation specification
384 /// (no resolved `CUfunction` or device pointers), every node is added
385 /// via `cuGraphAddEmptyNode`; the resulting driver graph preserves the
386 /// node count and dependency topology and executes as a DAG of
387 /// synchronisation barriers.
388 fn build_driver_graph(
389 &self,
390 ) -> CudaResult<(Option<crate::ffi::CUgraph>, Option<crate::ffi::CUgraphExec>)> {
391 use crate::ffi::{CUgraph, CUgraphExec, CUgraphNode};
392
393 let api = crate::loader::try_driver()?;
394
395 // Resolve every required graph entry point; a pre-10.0 driver lacks
396 // them and yields a clean NotSupported fallback.
397 let create = api.cu_graph_create.ok_or(CudaError::NotSupported)?;
398 let add_empty = api.cu_graph_add_empty_node.ok_or(CudaError::NotSupported)?;
399 let destroy = api.cu_graph_destroy.ok_or(CudaError::NotSupported)?;
400
401 // A topological order of the in-memory nodes — guaranteed acyclic
402 // because `instantiate` runs `topological_sort` first.
403 let order = self.topological_sort()?;
404
405 // 1. Create an empty CUgraph.
406 let mut raw_graph = CUgraph::default();
407 // SAFETY: `create` was just resolved from the driver; `raw_graph` is
408 // a valid out-pointer and flags=0 is the only documented value.
409 crate::error::check(unsafe { create(&mut raw_graph, 0) })?;
410
411 // From here on, any failure must destroy `raw_graph` before
412 // returning so the driver-side object does not leak.
413 let build = || -> CudaResult<CUgraphExec> {
414 // 2. Add one real driver node per in-memory node, in topological
415 // order, wiring the incoming dependency edges as we go.
416 // `driver_nodes[idx]` holds the driver handle for in-memory
417 // node `idx` once it has been created.
418 let mut driver_nodes: Vec<Option<CUgraphNode>> = vec![None; self.nodes.len()];
419 for &node_idx in &order {
420 // Collect the driver handles of every node this node depends
421 // on — edges `(from, to)` with `to == node_idx`. In a valid
422 // topological order every `from` precedes `node_idx`, so each
423 // handle is already present.
424 let mut deps: Vec<CUgraphNode> = Vec::new();
425 for &(from, to) in &self.dependencies {
426 if to == node_idx {
427 let handle = driver_nodes
428 .get(from)
429 .copied()
430 .flatten()
431 .ok_or(CudaError::InvalidValue)?;
432 deps.push(handle);
433 }
434 }
435
436 let dep_ptr = if deps.is_empty() {
437 std::ptr::null()
438 } else {
439 deps.as_ptr()
440 };
441
442 let mut driver_node = CUgraphNode::default();
443 // SAFETY: `add_empty` was resolved from the driver;
444 // `driver_node` is a valid out-pointer, `raw_graph` is the
445 // live graph created above, and `dep_ptr`/`deps.len()`
446 // describe a valid (possibly empty) dependency slice whose
447 // handles were all produced by earlier iterations.
448 crate::error::check(unsafe {
449 add_empty(&mut driver_node, raw_graph, dep_ptr, deps.len())
450 })?;
451 driver_nodes[node_idx] = Some(driver_node);
452 }
453
454 // 3. Instantiate the populated graph into an executable form.
455 self.instantiate_driver_graph(api, raw_graph)
456 };
457
458 match build() {
459 Ok(raw_exec) => Ok((Some(raw_graph), Some(raw_exec))),
460 Err(e) => {
461 // SAFETY: `destroy` was resolved from the driver and
462 // `raw_graph` is the live handle created above.
463 let rc = unsafe { destroy(raw_graph) };
464 if rc != 0 {
465 tracing::warn!(
466 cuda_error = rc,
467 "cuGraphDestroy failed while unwinding a failed instantiation"
468 );
469 }
470 Err(e)
471 }
472 }
473 }
474
475 /// Finalise a populated `CUgraph` into an executable `CUgraphExec`.
476 ///
477 /// Prefers `cuGraphInstantiateWithFlags` (CUDA 11.4+) and falls back to
478 /// the legacy `cuGraphInstantiate_v2` signature.
479 fn instantiate_driver_graph(
480 &self,
481 api: &crate::loader::DriverApi,
482 raw_graph: crate::ffi::CUgraph,
483 ) -> CudaResult<crate::ffi::CUgraphExec> {
484 use crate::ffi::CUgraphExec;
485
486 let mut raw_exec = CUgraphExec::default();
487
488 if let Some(instantiate_flags) = api.cu_graph_instantiate_with_flags {
489 // SAFETY: `instantiate_flags` was resolved from the driver;
490 // `raw_exec` is a valid out-pointer, `raw_graph` is a live
491 // populated graph, and flags=0 requests default instantiation.
492 crate::error::check(unsafe { instantiate_flags(&mut raw_exec, raw_graph, 0) })?;
493 return Ok(raw_exec);
494 }
495
496 let instantiate = api.cu_graph_instantiate.ok_or(CudaError::NotSupported)?;
497 // SAFETY: `instantiate` was resolved from the driver; `raw_exec` is a
498 // valid out-pointer, `raw_graph` is a live populated graph, and
499 // passing null error-node / log-buffer pointers with a zero buffer
500 // size is the documented "no diagnostics" configuration.
501 crate::error::check(unsafe {
502 instantiate(
503 &mut raw_exec,
504 raw_graph,
505 std::ptr::null_mut(),
506 std::ptr::null_mut(),
507 0,
508 )
509 })?;
510 Ok(raw_exec)
511 }
512}
513
514impl std::fmt::Display for Graph {
515 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516 write!(
517 f,
518 "Graph({} nodes, {} deps)",
519 self.nodes.len(),
520 self.dependencies.len()
521 )
522 }
523}
524
525// ---------------------------------------------------------------------------
526// GraphExec — instantiated executable graph
527// ---------------------------------------------------------------------------
528
/// An instantiated, executable graph.
///
/// Created by [`Graph::instantiate`], a `GraphExec` holds a snapshot of the
/// graph and a pre-computed execution order.
///
/// # Driver backing
///
/// When a CUDA driver is available, `instantiate` builds a genuine
/// `CUgraph` (`cuGraphCreate` + one `cuGraphAddEmptyNode` per in-memory
/// node, with the dependency DAG wired through real `CUgraphNode` edges)
/// and finalises it into a `CUgraphExec`. In that case
/// [`launch`](Self::launch) issues a real `cuGraphLaunch`.
///
/// The in-memory [`GraphNode`] representation stores only an operation
/// *specification* (kernel name, copy direction/size, memset size/value) —
/// it carries no resolved `CUfunction` or device pointers. Every node is
/// therefore translated to a real `cuGraphAddEmptyNode`: the resulting
/// driver graph reproduces the node count and dependency topology and
/// executes on the GPU as a DAG of synchronisation barriers. Emitting
/// `cuGraphAddKernelNode` / `cuGraphAddMemcpyNode` / `cuGraphAddMemsetNode`
/// instead would first require nodes to carry concrete device operands.
///
/// When instantiation falls back to a CPU-side-only graph (no driver, no
/// device, no current context, or no graph API), no driver handles are
/// created; the graph is still validated (topological sort). On a host
/// without a CUDA driver (e.g. macOS) [`launch`](Self::launch) returns
/// [`CudaError::NotInitialized`].
///
/// Dropping a `GraphExec` destroys whichever driver handles it owns.
pub struct GraphExec {
    /// Snapshot of the graph taken at instantiation time.
    graph: Graph,
    /// Node indices in the topological order computed at instantiation.
    execution_order: Vec<usize>,
    /// Real `CUgraph` handle, when a driver backed instantiation.
    raw_graph: Option<crate::ffi::CUgraph>,
    /// Real `CUgraphExec` handle, when a driver backed instantiation.
    raw_exec: Option<crate::ffi::CUgraphExec>,
}
564
565impl GraphExec {
566 /// Launches the executable graph on the given stream.
567 ///
568 /// When this `GraphExec` is backed by a real `CUgraphExec`, this issues
569 /// `cuGraphLaunch(hGraphExec, hStream)`, submitting the entire graph to
570 /// the stream with minimal CPU overhead. Otherwise it surfaces the
571 /// driver-load error.
572 ///
573 /// # Errors
574 ///
575 /// * [`CudaError::NotInitialized`] if the CUDA driver is not available
576 /// (e.g. on macOS, or a host without an NVIDIA GPU).
577 /// * Any [`CudaError`] mapped from `cuGraphLaunch`.
578 pub fn launch(&self, stream: &Stream) -> CudaResult<()> {
579 let api = crate::loader::try_driver()?;
580
581 // A driver is present. If instantiation produced a real executable
582 // graph, submit it; otherwise the driver lacks the graph API.
583 let raw_exec = self.raw_exec.ok_or(CudaError::NotSupported)?;
584 let launch = api.cu_graph_launch.ok_or(CudaError::NotSupported)?;
585
586 // SAFETY: `launch` was just resolved from the driver; `raw_exec` is a
587 // live `CUgraphExec` produced by `cuGraphInstantiate` and kept alive
588 // by `self`, and `stream.raw()` is a valid `CUstream`.
589 crate::error::check(unsafe { launch(raw_exec, stream.raw()) })
590 }
591
592 /// Returns a reference to the underlying graph.
593 #[inline]
594 pub fn graph(&self) -> &Graph {
595 &self.graph
596 }
597
598 /// Returns the pre-computed execution order (topological sort).
599 #[inline]
600 pub fn execution_order(&self) -> &[usize] {
601 &self.execution_order
602 }
603
604 /// Returns the total number of nodes that would be executed.
605 #[inline]
606 pub fn node_count(&self) -> usize {
607 self.graph.node_count()
608 }
609
610 /// Returns `true` if this `GraphExec` is backed by a real, live
611 /// `CUgraphExec` driver handle (as opposed to a CPU-side-only graph).
612 #[inline]
613 pub fn is_driver_backed(&self) -> bool {
614 self.raw_exec.is_some()
615 }
616}
617
618impl std::fmt::Debug for GraphExec {
619 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620 f.debug_struct("GraphExec")
621 .field("graph", &self.graph)
622 .field("execution_order", &self.execution_order)
623 .field("driver_backed", &self.is_driver_backed())
624 .finish()
625 }
626}
627
628impl Drop for GraphExec {
629 fn drop(&mut self) {
630 // Release driver handles in reverse construction order: the
631 // executable graph first, then the source graph.
632 if let Ok(api) = crate::loader::try_driver() {
633 if let (Some(exec), Some(destroy)) = (self.raw_exec, api.cu_graph_exec_destroy) {
634 // SAFETY: `destroy` was resolved from the driver and `exec`
635 // is a live handle produced by `cuGraphInstantiate`.
636 let rc = unsafe { destroy(exec) };
637 if rc != 0 {
638 tracing::warn!(cuda_error = rc, "cuGraphExecDestroy failed during drop");
639 }
640 }
641 if let (Some(graph), Some(destroy)) = (self.raw_graph, api.cu_graph_destroy) {
642 // SAFETY: `destroy` was resolved from the driver and `graph`
643 // is a live handle produced by `cuGraphCreate`.
644 let rc = unsafe { destroy(graph) };
645 if rc != 0 {
646 tracing::warn!(cuda_error = rc, "cuGraphDestroy failed during drop");
647 }
648 }
649 }
650 }
651}
652
653// ---------------------------------------------------------------------------
654// StreamCapture — capture operations into a graph
655// ---------------------------------------------------------------------------
656
/// Records GPU operations submitted to a stream into a [`Graph`].
///
/// Stream capture is meant to intercept operations that would normally be
/// submitted to a CUDA stream and record them as graph nodes instead. The
/// captured operations can then be replayed efficiently via [`GraphExec`].
///
/// In this implementation the recording is done on the Rust side: the
/// caller describes each operation through the `record_*` methods, and
/// `end` turns the recorded list into a linear chain of graph nodes.
///
/// # Usage
///
/// ```rust,no_run
/// # use oxicuda_driver::graph::{StreamCapture, MemcpyDirection};
/// # use oxicuda_driver::stream::Stream;
/// # use std::sync::Arc;
/// # use oxicuda_driver::context::Context;
/// # fn main() -> oxicuda_driver::CudaResult<()> {
/// # let ctx: Arc<Context> = unimplemented!();
/// # let stream = Stream::new(&ctx)?;
/// let mut capture = StreamCapture::begin(&stream)?;
///
/// capture.record_kernel("my_kernel", (4, 1, 1), (256, 1, 1), 0);
/// capture.record_memcpy(MemcpyDirection::DeviceToHost, 1024);
///
/// let graph = capture.end()?;
/// assert_eq!(graph.node_count(), 2);
/// # Ok(())
/// # }
/// ```
pub struct StreamCapture {
    /// Operations recorded so far, in the order they were recorded.
    nodes: Vec<GraphNode>,
    /// Whether capture is still active (not yet ended).
    active: bool,
}
688
689impl StreamCapture {
690 /// Begins capturing operations on the given stream.
691 ///
692 /// On a real CUDA system, this would call
693 /// `cuStreamBeginCapture(stream, CU_STREAM_CAPTURE_MODE_GLOBAL)`.
694 ///
695 /// # Errors
696 ///
697 /// Returns [`CudaError::NotInitialized`] if the CUDA driver is not
698 /// available.
699 pub fn begin(_stream: &Stream) -> CudaResult<Self> {
700 // Validate that the driver is available.
701 let _api = crate::loader::try_driver()?;
702 Ok(Self {
703 nodes: Vec::new(),
704 active: true,
705 })
706 }
707
708 /// Records a kernel launch operation in the capture.
709 ///
710 /// # Parameters
711 ///
712 /// * `function_name` - Name of the kernel function.
713 /// * `grid` - Grid dimensions `(x, y, z)`.
714 /// * `block` - Block dimensions `(x, y, z)`.
715 /// * `shared_mem` - Dynamic shared memory in bytes.
716 pub fn record_kernel(
717 &mut self,
718 function_name: &str,
719 grid: (u32, u32, u32),
720 block: (u32, u32, u32),
721 shared_mem: u32,
722 ) {
723 if self.active {
724 self.nodes.push(GraphNode::KernelLaunch {
725 function_name: function_name.to_owned(),
726 grid,
727 block,
728 shared_mem,
729 });
730 }
731 }
732
733 /// Records a memory copy operation in the capture.
734 ///
735 /// # Parameters
736 ///
737 /// * `direction` - Direction of the memory copy.
738 /// * `size` - Size of the transfer in bytes.
739 pub fn record_memcpy(&mut self, direction: MemcpyDirection, size: usize) {
740 if self.active {
741 self.nodes.push(GraphNode::Memcpy { direction, size });
742 }
743 }
744
745 /// Records a memset operation in the capture.
746 ///
747 /// # Parameters
748 ///
749 /// * `size` - Number of bytes to set.
750 /// * `value` - Byte value to fill with.
751 pub fn record_memset(&mut self, size: usize, value: u8) {
752 if self.active {
753 self.nodes.push(GraphNode::Memset { size, value });
754 }
755 }
756
757 /// Returns the number of operations recorded so far.
758 #[inline]
759 pub fn recorded_count(&self) -> usize {
760 self.nodes.len()
761 }
762
763 /// Returns whether the capture is still active.
764 #[inline]
765 pub fn is_active(&self) -> bool {
766 self.active
767 }
768
769 /// Ends the capture and returns the resulting [`Graph`].
770 ///
771 /// On a real CUDA system, this would call `cuStreamEndCapture`
772 /// and return the captured graph handle.
773 ///
774 /// The captured nodes are connected in a linear chain (each node
775 /// depends on the previous one) to preserve the order in which
776 /// operations were recorded.
777 ///
778 /// # Errors
779 ///
780 /// Returns [`CudaError::StreamCaptureUnmatched`] if the capture
781 /// was already ended.
782 pub fn end(mut self) -> CudaResult<Graph> {
783 if !self.active {
784 return Err(CudaError::StreamCaptureUnmatched);
785 }
786 self.active = false;
787
788 let mut graph = Graph::new();
789 let mut prev_idx: Option<usize> = None;
790
791 for node in self.nodes.drain(..) {
792 let idx = graph.nodes.len();
793 graph.nodes.push(node);
794
795 // Chain each node after the previous to maintain order.
796 if let Some(prev) = prev_idx {
797 graph.dependencies.push((prev, idx));
798 }
799 prev_idx = Some(idx);
800 }
801
802 Ok(graph)
803 }
804}
805
806// ---------------------------------------------------------------------------
807// Tests
808// ---------------------------------------------------------------------------
809
810#[cfg(test)]
811mod tests {
812 use super::*;
813
814 #[test]
815 fn graph_new_is_empty() {
816 let g = Graph::new();
817 assert_eq!(g.node_count(), 0);
818 assert_eq!(g.dependency_count(), 0);
819 assert!(g.nodes().is_empty());
820 assert!(g.dependencies().is_empty());
821 }
822
823 #[test]
824 fn graph_default_is_empty() {
825 let g = Graph::default();
826 assert_eq!(g.node_count(), 0);
827 }
828
829 #[test]
830 fn add_kernel_node_returns_sequential_indices() {
831 let mut g = Graph::new();
832 let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
833 let n1 = g.add_kernel_node("k1", (2, 1, 1), (64, 1, 1), 128);
834 assert_eq!(n0, 0);
835 assert_eq!(n1, 1);
836 assert_eq!(g.node_count(), 2);
837 }
838
839 #[test]
840 fn add_memcpy_node_records_direction_and_size() {
841 let mut g = Graph::new();
842 let idx = g.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
843 assert_eq!(idx, 0);
844 let node = g.get_node(0);
845 assert!(node.is_some());
846 if let Some(GraphNode::Memcpy { direction, size }) = node {
847 assert_eq!(*direction, MemcpyDirection::HostToDevice);
848 assert_eq!(*size, 4096);
849 } else {
850 panic!("expected Memcpy node");
851 }
852 }
853
854 #[test]
855 fn add_memset_node_records_size_and_value() {
856 let mut g = Graph::new();
857 let idx = g.add_memset_node(8192, 0xAB);
858 assert_eq!(idx, 0);
859 if let Some(GraphNode::Memset { size, value }) = g.get_node(idx) {
860 assert_eq!(*size, 8192);
861 assert_eq!(*value, 0xAB);
862 } else {
863 panic!("expected Memset node");
864 }
865 }
866
867 #[test]
868 fn add_empty_node_works() {
869 let mut g = Graph::new();
870 let idx = g.add_empty_node();
871 assert_eq!(idx, 0);
872 assert_eq!(g.get_node(idx), Some(&GraphNode::Empty));
873 }
874
875 #[test]
876 fn add_dependency_valid() {
877 let mut g = Graph::new();
878 let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
879 let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
880 assert!(g.add_dependency(n0, n1).is_ok());
881 assert_eq!(g.dependency_count(), 1);
882 assert_eq!(g.dependencies()[0], (0, 1));
883 }
884
885 #[test]
886 fn add_dependency_out_of_bounds() {
887 let mut g = Graph::new();
888 let _n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
889 let result = g.add_dependency(0, 5);
890 assert_eq!(result, Err(CudaError::InvalidValue));
891 }
892
893 #[test]
894 fn add_dependency_self_loop() {
895 let mut g = Graph::new();
896 let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
897 let result = g.add_dependency(n0, n0);
898 assert_eq!(result, Err(CudaError::InvalidValue));
899 }
900
901 #[test]
902 fn topological_sort_linear_chain() {
903 let mut g = Graph::new();
904 let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
905 let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
906 let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
907 g.add_dependency(n0, n1).ok();
908 g.add_dependency(n1, n2).ok();
909
910 let order = g.topological_sort();
911 assert!(order.is_ok());
912 let order = order.ok();
913 assert!(order.is_some());
914 let order = order.unwrap_or_default();
915 // n0 must come before n1, n1 before n2
916 let pos = |n: usize| -> usize { order.iter().position(|&x| x == n).unwrap_or(usize::MAX) };
917 assert!(pos(n0) < pos(n1));
918 assert!(pos(n1) < pos(n2));
919 }
920
921 #[test]
922 fn topological_sort_detects_cycle() {
923 let mut g = Graph::new();
924 let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
925 let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
926 g.add_dependency(n0, n1).ok();
927 g.add_dependency(n1, n0).ok();
928
929 let result = g.topological_sort();
930 assert_eq!(result, Err(CudaError::InvalidValue));
931 }
932
933 #[test]
934 fn topological_sort_no_deps() {
935 let mut g = Graph::new();
936 g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
937 g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
938 g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
939
940 let order = g.topological_sort();
941 assert!(order.is_ok());
942 let order = order.unwrap_or_default();
943 assert_eq!(order.len(), 3);
944 }
945
/// A memcpy -> kernel -> memcpy chain instantiates, and the resulting
/// `GraphExec` reports three nodes and a three-entry execution order.
#[test]
fn instantiate_valid_graph() {
    let mut g = Graph::new();
    let n0 = g.add_memcpy_node(MemcpyDirection::HostToDevice, 1024);
    let n1 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
    let n2 = g.add_memcpy_node(MemcpyDirection::DeviceToHost, 1024);
    // Surface a rejected edge instead of silently testing an
    // edge-less graph.
    g.add_dependency(n0, n1).expect("edge n0 -> n1 is accepted");
    g.add_dependency(n1, n2).expect("edge n1 -> n2 is accepted");

    // `expect` reports the actual error on failure, matching the other
    // instantiation tests in this module.
    let exec = g.instantiate().expect("acyclic graph instantiates");
    assert_eq!(exec.node_count(), 3);
    assert_eq!(exec.execution_order().len(), 3);
}
964
/// Instantiating a graph whose two nodes depend on each other must
/// return an error.
#[test]
fn instantiate_cyclic_graph_fails() {
    let mut graph = Graph::new();
    let first = graph.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
    let second = graph.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
    // Edge results are intentionally discarded; the instantiation
    // outcome is what this test checks.
    let _ = graph.add_dependency(first, second);
    let _ = graph.add_dependency(second, first);

    assert!(graph.instantiate().is_err());
}
976
/// The `Display` output of a graph with two nodes and no edges
/// mentions both the node count and the dependency count.
#[test]
fn graph_display() {
    let mut graph = Graph::new();
    graph.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
    graph.add_memcpy_node(MemcpyDirection::HostToDevice, 512);

    let rendered = graph.to_string();
    for needle in ["2 nodes", "0 deps"] {
        assert!(rendered.contains(needle));
    }
}
986
/// Each `GraphNode` variant renders a `Display` string containing a
/// variant-specific marker.
#[test]
fn node_display() {
    // (node, substring its `Display` output must contain)
    let cases = [
        (
            GraphNode::KernelLaunch {
                function_name: "foo".to_owned(),
                grid: (4, 1, 1),
                block: (256, 1, 1),
                shared_mem: 0,
            },
            "foo",
        ),
        (
            GraphNode::Memcpy {
                direction: MemcpyDirection::DeviceToHost,
                size: 1024,
            },
            "DtoH",
        ),
        (
            GraphNode::Memset {
                size: 256,
                value: 0xFF,
            },
            "0xff",
        ),
        (GraphNode::Empty, "Empty"),
    ];

    for (node, needle) in cases {
        let rendered = node.to_string();
        assert!(rendered.contains(needle));
    }
}
1016
/// Every `MemcpyDirection` variant renders to its short label.
#[test]
fn memcpy_direction_display() {
    let expected = [
        (MemcpyDirection::HostToDevice, "HtoD"),
        (MemcpyDirection::DeviceToHost, "DtoH"),
        (MemcpyDirection::DeviceToDevice, "DtoD"),
    ];
    for (direction, label) in expected {
        assert_eq!(direction.to_string(), label);
    }
}
1023
/// `get_node` on an empty graph yields `None` for any index.
#[test]
fn graph_get_node_out_of_bounds() {
    let empty = Graph::new();
    for index in [0, 100] {
        assert!(empty.get_node(index).is_none());
    }
}
1030
/// A diamond DAG sorts into an order containing all four nodes with
/// every edge respected, and the graph instantiates successfully.
#[test]
fn graph_diamond_dag() {
    // Diamond: n0 -> n1, n0 -> n2, n1 -> n3, n2 -> n3
    let mut g = Graph::new();
    let n0 = g.add_empty_node();
    let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
    let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
    let n3 = g.add_empty_node();
    // Fail loudly on a rejected edge; a silently dropped edge would
    // weaken the ordering checks below.
    g.add_dependency(n0, n1).expect("edge n0 -> n1 is accepted");
    g.add_dependency(n0, n2).expect("edge n0 -> n2 is accepted");
    g.add_dependency(n1, n3).expect("edge n1 -> n3 is accepted");
    g.add_dependency(n2, n3).expect("edge n2 -> n3 is accepted");

    let order = g.topological_sort().expect("diamond DAG is acyclic");
    assert_eq!(order.len(), 4);
    // Look positions up strictly: a node missing from the order is a
    // failure, not a `usize::MAX` sentinel that compares as "last".
    let pos = |n: usize| -> usize {
        order
            .iter()
            .position(|&x| x == n)
            .expect("every node appears in the order")
    };
    assert!(pos(n0) < pos(n1));
    assert!(pos(n0) < pos(n2));
    assert!(pos(n1) < pos(n3));
    assert!(pos(n2) < pos(n3));

    let exec = g.instantiate();
    assert!(exec.is_ok());
}
1055
/// The `Debug` output of an instantiated single-node graph names the
/// type and includes the `driver_backed` field.
#[test]
fn graph_exec_debug() {
    let mut graph = Graph::new();
    graph.add_empty_node();

    let maybe_exec = graph.instantiate().ok();
    assert!(maybe_exec.is_some());
    let Some(exec) = maybe_exec else {
        return;
    };

    let rendered = format!("{exec:?}");
    assert!(rendered.contains("GraphExec"));
    // The debug output advertises the driver-backed status.
    assert!(rendered.contains("driver_backed"));
}
1069
1070 // -- Driver-backed instantiation ---------------------------------------
1071 //
1072 // `instantiate` builds a real `CUgraph`/`CUgraphExec` when a driver is
1073 // present, and a CPU-side-only `GraphExec` otherwise. On a host with no
1074 // CUDA driver every path below must still produce a valid `GraphExec`
1075 // (clean fallback) — never a panic, never an error from the missing
1076 // driver alone.
1077
1078 /// Returns `true` when a real CUDA driver is loadable on this host.
1079 fn driver_present() -> bool {
1080 crate::loader::try_driver().is_ok()
1081 }
1082
/// An empty graph instantiates to a zero-node `GraphExec`; when no
/// driver is present that `GraphExec` must not be driver-backed.
#[test]
fn instantiate_empty_graph_driver_state() {
    let graph = Graph::new();
    let exec = graph.instantiate().expect("empty graph instantiates");
    assert_eq!(exec.node_count(), 0);

    let has_driver = driver_present();
    let backed = exec.is_driver_backed();
    // With a live driver the graph may be backed or, on a graphless
    // driver, left CPU-side — both are valid, so only the driverless
    // case is constrained.
    if !has_driver {
        assert!(!backed);
    }
}
1098
/// A memset -> kernel -> memcpy chain instantiates with three nodes and
/// a three-entry execution order; without a driver the result is not
/// driver-backed.
#[test]
fn instantiate_chain_preserves_topology() {
    let mut graph = Graph::new();
    let fill = graph.add_memset_node(256, 0);
    let kernel = graph.add_kernel_node("k", (1, 1, 1), (32, 1, 1), 0);
    let readback = graph.add_memcpy_node(MemcpyDirection::DeviceToHost, 256);
    let _ = graph.add_dependency(fill, kernel);
    let _ = graph.add_dependency(kernel, readback);

    let exec = graph.instantiate().expect("chain instantiates");
    assert_eq!(exec.node_count(), 3);
    assert_eq!(exec.execution_order().len(), 3);

    if driver_present() {
        return;
    }
    assert!(!exec.is_driver_backed());
}
1117
/// A diamond DAG must instantiate successfully; on a host without a
/// driver the resulting `GraphExec` must not be driver-backed.
#[test]
fn instantiate_diamond_without_driver_is_clean() {
    let mut graph = Graph::new();
    let top = graph.add_empty_node();
    let left = graph.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
    let right = graph.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
    let bottom = graph.add_empty_node();
    let _ = graph.add_dependency(top, left);
    let _ = graph.add_dependency(top, right);
    let _ = graph.add_dependency(left, bottom);
    let _ = graph.add_dependency(right, bottom);

    let outcome = graph.instantiate();
    assert!(outcome.is_ok(), "diamond DAG must instantiate cleanly");

    // The driver-backed flag is only constrained when no driver exists.
    if driver_present() {
        return;
    }
    if let Ok(exec) = outcome {
        assert!(!exec.is_driver_backed());
    }
}
1139
/// Without a driver, `build_driver_graph` must return
/// `CudaError::NotInitialized`. With a driver, an `Ok` result must have
/// both handles present or both absent; any error is accepted.
#[test]
fn build_driver_graph_absent_driver_is_clean() {
    let mut graph = Graph::new();
    graph.add_empty_node();
    let outcome = graph.build_driver_graph();

    if !driver_present() {
        assert_eq!(outcome.err(), Some(CudaError::NotInitialized));
        return;
    }

    // Live driver: either real handles, or a typed driver error (which
    // is acceptable and therefore not inspected).
    if let Ok((raw_graph, raw_exec)) = outcome {
        assert_eq!(raw_graph.is_some(), raw_exec.is_some());
    }
}
1159
/// Instantiates a two-node graph and explicitly drops the resulting
/// `GraphExec`; the test passes as long as the drop does not panic.
#[test]
fn graph_exec_drop_without_driver_is_safe() {
    let mut graph = Graph::new();
    for _ in 0..2 {
        graph.add_empty_node();
    }
    let exec = graph.instantiate().expect("instantiates");
    // Explicit drop — must complete without panicking.
    drop(exec);
}
1171
/// Instantiating a graph whose two empty nodes depend on each other
/// must fail with `CudaError::InvalidValue`.
#[test]
fn instantiate_cycle_fails_before_driver() {
    let mut graph = Graph::new();
    let first = graph.add_empty_node();
    let second = graph.add_empty_node();
    let _ = graph.add_dependency(first, second);
    let _ = graph.add_dependency(second, first);

    let failure = graph.instantiate().err();
    assert_eq!(failure, Some(CudaError::InvalidValue));
}
1183
1184 // -- End-to-end real-GPU graph execution -------------------------------
1185 //
1186 // When this host has a usable GPU, build a CUDA context (which makes it
1187 // current), instantiate a real driver-backed graph, and launch it via
1188 // `cuGraphLaunch`. On a host without a GPU the test is a clean no-op.
1189
/// Builds a diamond DAG, instantiates it, and — if the result is
/// driver-backed — launches it on a stream and synchronises.
///
/// Returns early (the test passes) when device 0, a context, or a
/// stream cannot be created.
#[test]
fn real_graph_instantiate_and_launch() {
    use crate::context::Context;
    use crate::device::Device;

    // No GPU on this host — nothing to exercise.
    let Ok(device) = Device::get(0) else {
        return;
    };
    // Creating the context makes it current on this thread, which the
    // CUDA Graph API requires.
    let Ok(raw_ctx) = Context::new(&device) else {
        return;
    };
    let ctx = std::sync::Arc::new(raw_ctx);
    let Ok(stream) = Stream::new(&ctx) else {
        return;
    };

    // Diamond DAG: top -> {left, right} -> bottom.
    let mut graph = Graph::new();
    let top = graph.add_empty_node();
    let left = graph.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
    let right = graph.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
    let bottom = graph.add_empty_node();
    let _ = graph.add_dependency(top, left);
    let _ = graph.add_dependency(top, right);
    let _ = graph.add_dependency(left, bottom);
    let _ = graph.add_dependency(right, bottom);

    let exec = graph.instantiate().expect("diamond DAG instantiates");
    assert_eq!(exec.node_count(), 4);

    // Only a driver-backed graph can be launched; a CPU-side-only one
    // ends the test here.
    if !exec.is_driver_backed() {
        return;
    }
    exec.launch(&stream)
        .expect("cuGraphLaunch on a real graph succeeds");
    stream
        .synchronize()
        .expect("stream synchronises after graph launch");
}
1236
/// Launches a driver-backed two-node graph eight times on one stream,
/// then synchronises the stream.
///
/// Returns early (the test passes) when device 0, a context, or a
/// stream cannot be created, or when the graph is not driver-backed.
#[test]
fn real_graph_repeated_launch() {
    use crate::context::Context;
    use crate::device::Device;

    let Ok(device) = Device::get(0) else {
        return;
    };
    let Ok(raw_ctx) = Context::new(&device) else {
        return;
    };
    let ctx = std::sync::Arc::new(raw_ctx);
    let Ok(stream) = Stream::new(&ctx) else {
        return;
    };

    let mut graph = Graph::new();
    let head = graph.add_empty_node();
    let tail = graph.add_empty_node();
    let _ = graph.add_dependency(head, tail);

    let exec = graph.instantiate().expect("chain instantiates");
    if !exec.is_driver_backed() {
        return;
    }

    // The whole point of a graph: cheap repeated submission.
    for _round in 0..8 {
        exec.launch(&stream)
            .expect("repeated cuGraphLaunch succeeds");
    }
    stream.synchronize().expect("stream synchronises");
}
1271}