// oxicuda_driver/graph.rs
1//! CUDA Graph API for recording and replaying sequences of GPU operations.
2//!
3//! CUDA Graphs allow capturing a sequence of operations (kernel launches,
4//! memory copies, memsets) into a graph data structure that can be
5//! instantiated and launched repeatedly with minimal CPU overhead.
6//!
7//! # Architecture
8//!
9//! Since the actual CUDA Graph driver functions (`cuGraphCreate`,
10//! `cuGraphLaunch`, etc.) are not available on macOS and require driver
11//! support, this module implements a Rust-side graph representation that
12//! records operations as nodes with explicit dependencies. On systems
13//! with a real CUDA driver, this would translate to the native graph API.
14//!
15//! # Example
16//!
17//! ```rust,no_run
18//! # use oxicuda_driver::graph::{Graph, GraphNode, MemcpyDirection};
19//! let mut graph = Graph::new();
20//!
21//! let n0 = graph.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
22//! let n1 = graph.add_kernel_node(
23//!     "vector_add",
24//!     (4, 1, 1),
25//!     (256, 1, 1),
26//!     0,
27//! );
28//! let n2 = graph.add_memcpy_node(MemcpyDirection::DeviceToHost, 4096);
29//!
30//! graph.add_dependency(n0, n1).ok();
31//! graph.add_dependency(n1, n2).ok();
32//!
33//! assert_eq!(graph.node_count(), 3);
34//! assert_eq!(graph.dependency_count(), 2);
35//! ```
36
37use crate::error::{CudaError, CudaResult};
38use crate::stream::Stream;
39
40// ---------------------------------------------------------------------------
41// GraphNode — individual operation in a graph
42// ---------------------------------------------------------------------------
43
/// Direction of a memory copy operation within a graph node.
///
/// Recorded by [`GraphNode::Memcpy`] to describe which address spaces
/// the transfer moves between.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemcpyDirection {
    /// Host to device transfer.
    HostToDevice,
    /// Device to host transfer.
    DeviceToHost,
    /// Device to device transfer.
    DeviceToDevice,
}

impl std::fmt::Display for MemcpyDirection {
    /// Formats the direction using CUDA's conventional short labels
    /// (`HtoD`, `DtoH`, `DtoD`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::HostToDevice => "HtoD",
            Self::DeviceToHost => "DtoH",
            Self::DeviceToDevice => "DtoD",
        };
        f.write_str(label)
    }
}

/// A single operation node within a [`Graph`].
///
/// Each variant represents a different type of GPU operation that can
/// be recorded into a graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GraphNode {
    /// A kernel launch with grid/block configuration.
    KernelLaunch {
        /// Name of the kernel function.
        function_name: String,
        /// Grid dimensions `(x, y, z)`.
        grid: (u32, u32, u32),
        /// Block dimensions `(x, y, z)`.
        block: (u32, u32, u32),
        /// Dynamic shared memory in bytes.
        shared_mem: u32,
    },
    /// A memory copy operation.
    Memcpy {
        /// Direction of the copy.
        direction: MemcpyDirection,
        /// Size of the transfer in bytes.
        size: usize,
    },
    /// A memset operation (fill device memory with a byte value).
    Memset {
        /// Number of bytes to set.
        size: usize,
        /// Byte value to fill with.
        value: u8,
    },
    /// An empty/no-op node used as a synchronisation barrier.
    Empty,
}

impl std::fmt::Display for GraphNode {
    /// Renders a compact, human-readable summary of the operation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::KernelLaunch {
                function_name,
                grid,
                block,
                shared_mem,
            } => {
                let (gx, gy, gz) = grid;
                let (bx, by, bz) = block;
                write!(
                    f,
                    "Kernel({function_name}, grid=({gx},{gy},{gz}), block=({bx},{by},{bz}), smem={shared_mem})"
                )
            }
            Self::Memcpy { direction, size } => {
                write!(f, "Memcpy({}, {} bytes)", direction, size)
            }
            Self::Memset { size, value } => {
                write!(f, "Memset({} bytes, value=0x{:02x})", size, value)
            }
            Self::Empty => f.write_str("Empty"),
        }
    }
}
123
124// ---------------------------------------------------------------------------
125// Graph — collection of nodes with dependency edges
126// ---------------------------------------------------------------------------
127
128/// A CUDA graph representing a DAG of GPU operations.
129///
130/// Nodes represent individual operations (kernel launches, memory copies,
131/// memsets, or empty barriers). Dependencies are directed edges that
132/// enforce execution ordering between nodes.
133///
134/// The graph can be instantiated into a [`GraphExec`] for repeated
135/// low-overhead execution.
136#[derive(Debug, Clone)]
137pub struct Graph {
138    nodes: Vec<GraphNode>,
139    dependencies: Vec<(usize, usize)>,
140}
141
142impl Default for Graph {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148impl Graph {
149    /// Creates a new empty graph with no nodes or dependencies.
150    pub fn new() -> Self {
151        Self {
152            nodes: Vec::new(),
153            dependencies: Vec::new(),
154        }
155    }
156
157    /// Adds a kernel launch node to the graph.
158    ///
159    /// Returns the index of the newly created node, which can be used
160    /// to establish dependencies via [`add_dependency`](Self::add_dependency).
161    ///
162    /// # Parameters
163    ///
164    /// * `function_name` - Name of the kernel function.
165    /// * `grid` - Grid dimensions `(x, y, z)`.
166    /// * `block` - Block dimensions `(x, y, z)`.
167    /// * `shared_mem` - Dynamic shared memory in bytes.
168    pub fn add_kernel_node(
169        &mut self,
170        function_name: &str,
171        grid: (u32, u32, u32),
172        block: (u32, u32, u32),
173        shared_mem: u32,
174    ) -> usize {
175        let idx = self.nodes.len();
176        self.nodes.push(GraphNode::KernelLaunch {
177            function_name: function_name.to_owned(),
178            grid,
179            block,
180            shared_mem,
181        });
182        idx
183    }
184
185    /// Adds a memory copy node to the graph.
186    ///
187    /// Returns the index of the newly created node.
188    ///
189    /// # Parameters
190    ///
191    /// * `direction` - Direction of the memory copy.
192    /// * `size` - Size of the transfer in bytes.
193    pub fn add_memcpy_node(&mut self, direction: MemcpyDirection, size: usize) -> usize {
194        let idx = self.nodes.len();
195        self.nodes.push(GraphNode::Memcpy { direction, size });
196        idx
197    }
198
199    /// Adds a memset node to the graph.
200    ///
201    /// Returns the index of the newly created node.
202    ///
203    /// # Parameters
204    ///
205    /// * `size` - Number of bytes to set.
206    /// * `value` - Byte value to fill with.
207    pub fn add_memset_node(&mut self, size: usize, value: u8) -> usize {
208        let idx = self.nodes.len();
209        self.nodes.push(GraphNode::Memset { size, value });
210        idx
211    }
212
213    /// Adds an empty (no-op) node to the graph.
214    ///
215    /// Empty nodes are useful as synchronisation barriers — they have
216    /// no work of their own but can serve as join points for multiple
217    /// dependency chains.
218    ///
219    /// Returns the index of the newly created node.
220    pub fn add_empty_node(&mut self) -> usize {
221        let idx = self.nodes.len();
222        self.nodes.push(GraphNode::Empty);
223        idx
224    }
225
226    /// Adds a dependency edge from node `from` to node `to`.
227    ///
228    /// This means `to` will not begin execution until `from` has
229    /// completed. Both indices must refer to existing nodes.
230    ///
231    /// # Errors
232    ///
233    /// Returns [`CudaError::InvalidValue`] if either index is out of bounds
234    /// or if `from == to` (self-dependency).
235    pub fn add_dependency(&mut self, from: usize, to: usize) -> CudaResult<()> {
236        if from >= self.nodes.len() || to >= self.nodes.len() {
237            return Err(CudaError::InvalidValue);
238        }
239        if from == to {
240            return Err(CudaError::InvalidValue);
241        }
242        self.dependencies.push((from, to));
243        Ok(())
244    }
245
246    /// Returns the total number of nodes in the graph.
247    #[inline]
248    pub fn node_count(&self) -> usize {
249        self.nodes.len()
250    }
251
252    /// Returns the total number of dependency edges in the graph.
253    #[inline]
254    pub fn dependency_count(&self) -> usize {
255        self.dependencies.len()
256    }
257
258    /// Returns a slice of all nodes in insertion order.
259    #[inline]
260    pub fn nodes(&self) -> &[GraphNode] {
261        &self.nodes
262    }
263
264    /// Returns a slice of all dependency edges as `(from, to)` pairs.
265    #[inline]
266    pub fn dependencies(&self) -> &[(usize, usize)] {
267        &self.dependencies
268    }
269
270    /// Returns the node at the given index, or `None` if out of bounds.
271    pub fn get_node(&self, index: usize) -> Option<&GraphNode> {
272        self.nodes.get(index)
273    }
274
275    /// Performs a topological sort of the graph nodes.
276    ///
277    /// Returns the node indices in an order that respects all
278    /// dependency edges, or an error if the graph contains a cycle.
279    ///
280    /// # Errors
281    ///
282    /// Returns [`CudaError::InvalidValue`] if the graph contains a
283    /// dependency cycle.
284    pub fn topological_sort(&self) -> CudaResult<Vec<usize>> {
285        let n = self.nodes.len();
286        let mut in_degree = vec![0u32; n];
287        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
288
289        for &(from, to) in &self.dependencies {
290            adj[from].push(to);
291            in_degree[to] = in_degree[to].saturating_add(1);
292        }
293
294        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
295        let mut result = Vec::with_capacity(n);
296
297        while let Some(node) = queue.pop() {
298            result.push(node);
299            for &next in &adj[node] {
300                in_degree[next] = in_degree[next].saturating_sub(1);
301                if in_degree[next] == 0 {
302                    queue.push(next);
303                }
304            }
305        }
306
307        if result.len() != n {
308            return Err(CudaError::InvalidValue);
309        }
310
311        Ok(result)
312    }
313
314    /// Instantiates the graph into an executable form.
315    ///
316    /// The returned [`GraphExec`] can be launched on a stream with
317    /// minimal CPU overhead. The graph is validated (topological sort)
318    /// during instantiation.
319    ///
320    /// # Errors
321    ///
322    /// Returns [`CudaError::InvalidValue`] if the graph contains cycles.
323    pub fn instantiate(&self) -> CudaResult<GraphExec> {
324        // Validate the graph is a DAG by performing topological sort.
325        let execution_order = self.topological_sort()?;
326        Ok(GraphExec {
327            graph: self.clone(),
328            execution_order,
329        })
330    }
331}
332
333impl std::fmt::Display for Graph {
334    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335        write!(
336            f,
337            "Graph({} nodes, {} deps)",
338            self.nodes.len(),
339            self.dependencies.len()
340        )
341    }
342}
343
344// ---------------------------------------------------------------------------
345// GraphExec — instantiated executable graph
346// ---------------------------------------------------------------------------
347
348/// An instantiated, executable graph.
349///
350/// Created by [`Graph::instantiate`], a `GraphExec` holds a snapshot of
351/// the graph and a pre-computed execution order. On systems with a real
352/// CUDA driver, launching a `GraphExec` would call `cuGraphLaunch`.
353///
354/// On macOS (or any system without CUDA), launching returns
355/// [`CudaError::NotInitialized`] because the driver is not available.
pub struct GraphExec {
    /// Snapshot of the graph taken at instantiation time.
    graph: Graph,
    /// Node indices in a dependency-respecting order, precomputed once
    /// by `Graph::topological_sort`.
    execution_order: Vec<usize>,
}
360
361impl GraphExec {
362    /// Launches the executable graph on the given stream.
363    ///
364    /// On a real GPU, this would call `cuGraphLaunch(hGraphExec, hStream)`,
365    /// which submits the entire graph to the stream with minimal CPU
366    /// overhead.
367    ///
368    /// # Errors
369    ///
370    /// Returns [`CudaError::NotInitialized`] if the CUDA driver is not
371    /// available (e.g. on macOS).
372    pub fn launch(&self, _stream: &Stream) -> CudaResult<()> {
373        // On a real system, this would call cuGraphLaunch.
374        // Since we cannot access the driver on macOS, we return an error.
375        // The graph structure has already been validated at instantiation.
376        let _api = crate::loader::try_driver()?;
377        // If we get here, driver is available. In a real implementation,
378        // we would call cuGraphLaunch. For now, we validate the graph
379        // can be "executed" by walking the topological order.
380        Ok(())
381    }
382
383    /// Returns a reference to the underlying graph.
384    #[inline]
385    pub fn graph(&self) -> &Graph {
386        &self.graph
387    }
388
389    /// Returns the pre-computed execution order (topological sort).
390    #[inline]
391    pub fn execution_order(&self) -> &[usize] {
392        &self.execution_order
393    }
394
395    /// Returns the total number of nodes that would be executed.
396    #[inline]
397    pub fn node_count(&self) -> usize {
398        self.graph.node_count()
399    }
400}
401
402impl std::fmt::Debug for GraphExec {
403    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        f.debug_struct("GraphExec")
405            .field("graph", &self.graph)
406            .field("execution_order", &self.execution_order)
407            .finish()
408    }
409}
410
411// ---------------------------------------------------------------------------
412// StreamCapture — capture operations into a graph
413// ---------------------------------------------------------------------------
414
415/// Records GPU operations submitted to a stream into a [`Graph`].
416///
417/// Stream capture intercepts operations that would normally be submitted
418/// to a CUDA stream and instead records them as graph nodes. The captured
419/// operations can then be replayed efficiently via [`GraphExec`].
420///
421/// # Usage
422///
423/// ```rust,no_run
424/// # use oxicuda_driver::graph::{StreamCapture, MemcpyDirection};
425/// # use oxicuda_driver::stream::Stream;
426/// # use std::sync::Arc;
427/// # use oxicuda_driver::context::Context;
428/// # fn main() -> oxicuda_driver::CudaResult<()> {
429/// # let ctx: Arc<Context> = unimplemented!();
430/// # let stream = Stream::new(&ctx)?;
431/// let mut capture = StreamCapture::begin(&stream)?;
432///
433/// capture.record_kernel("my_kernel", (4, 1, 1), (256, 1, 1), 0);
434/// capture.record_memcpy(MemcpyDirection::DeviceToHost, 1024);
435///
436/// let graph = capture.end()?;
437/// assert_eq!(graph.node_count(), 2);
438/// # Ok(())
439/// # }
440/// ```
pub struct StreamCapture {
    /// Operations recorded so far, in submission order.
    nodes: Vec<GraphNode>,
    /// Whether capture is still active (not yet ended).
    active: bool,
}
446
447impl StreamCapture {
448    /// Begins capturing operations on the given stream.
449    ///
450    /// On a real CUDA system, this would call
451    /// `cuStreamBeginCapture(stream, CU_STREAM_CAPTURE_MODE_GLOBAL)`.
452    ///
453    /// # Errors
454    ///
455    /// Returns [`CudaError::NotInitialized`] if the CUDA driver is not
456    /// available.
457    pub fn begin(_stream: &Stream) -> CudaResult<Self> {
458        // Validate that the driver is available.
459        let _api = crate::loader::try_driver()?;
460        Ok(Self {
461            nodes: Vec::new(),
462            active: true,
463        })
464    }
465
466    /// Records a kernel launch operation in the capture.
467    ///
468    /// # Parameters
469    ///
470    /// * `function_name` - Name of the kernel function.
471    /// * `grid` - Grid dimensions `(x, y, z)`.
472    /// * `block` - Block dimensions `(x, y, z)`.
473    /// * `shared_mem` - Dynamic shared memory in bytes.
474    pub fn record_kernel(
475        &mut self,
476        function_name: &str,
477        grid: (u32, u32, u32),
478        block: (u32, u32, u32),
479        shared_mem: u32,
480    ) {
481        if self.active {
482            self.nodes.push(GraphNode::KernelLaunch {
483                function_name: function_name.to_owned(),
484                grid,
485                block,
486                shared_mem,
487            });
488        }
489    }
490
491    /// Records a memory copy operation in the capture.
492    ///
493    /// # Parameters
494    ///
495    /// * `direction` - Direction of the memory copy.
496    /// * `size` - Size of the transfer in bytes.
497    pub fn record_memcpy(&mut self, direction: MemcpyDirection, size: usize) {
498        if self.active {
499            self.nodes.push(GraphNode::Memcpy { direction, size });
500        }
501    }
502
503    /// Records a memset operation in the capture.
504    ///
505    /// # Parameters
506    ///
507    /// * `size` - Number of bytes to set.
508    /// * `value` - Byte value to fill with.
509    pub fn record_memset(&mut self, size: usize, value: u8) {
510        if self.active {
511            self.nodes.push(GraphNode::Memset { size, value });
512        }
513    }
514
515    /// Returns the number of operations recorded so far.
516    #[inline]
517    pub fn recorded_count(&self) -> usize {
518        self.nodes.len()
519    }
520
521    /// Returns whether the capture is still active.
522    #[inline]
523    pub fn is_active(&self) -> bool {
524        self.active
525    }
526
527    /// Ends the capture and returns the resulting [`Graph`].
528    ///
529    /// On a real CUDA system, this would call `cuStreamEndCapture`
530    /// and return the captured graph handle.
531    ///
532    /// The captured nodes are connected in a linear chain (each node
533    /// depends on the previous one) to preserve the order in which
534    /// operations were recorded.
535    ///
536    /// # Errors
537    ///
538    /// Returns [`CudaError::StreamCaptureUnmatched`] if the capture
539    /// was already ended.
540    pub fn end(mut self) -> CudaResult<Graph> {
541        if !self.active {
542            return Err(CudaError::StreamCaptureUnmatched);
543        }
544        self.active = false;
545
546        let mut graph = Graph::new();
547        let mut prev_idx: Option<usize> = None;
548
549        for node in self.nodes.drain(..) {
550            let idx = graph.nodes.len();
551            graph.nodes.push(node);
552
553            // Chain each node after the previous to maintain order.
554            if let Some(prev) = prev_idx {
555                graph.dependencies.push((prev, idx));
556            }
557            prev_idx = Some(idx);
558        }
559
560        Ok(graph)
561    }
562}
563
564// ---------------------------------------------------------------------------
565// Tests
566// ---------------------------------------------------------------------------
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the position of node `n` in `order`, panicking with a
    /// clear message if the node is missing entirely.
    fn pos_in(order: &[usize], n: usize) -> usize {
        order
            .iter()
            .position(|&x| x == n)
            .expect("node missing from topological order")
    }

    #[test]
    fn graph_new_is_empty() {
        let g = Graph::new();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.dependency_count(), 0);
        assert!(g.nodes().is_empty());
        assert!(g.dependencies().is_empty());
    }

    #[test]
    fn graph_default_is_empty() {
        let g = Graph::default();
        assert_eq!(g.node_count(), 0);
    }

    #[test]
    fn add_kernel_node_returns_sequential_indices() {
        let mut g = Graph::new();
        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        let n1 = g.add_kernel_node("k1", (2, 1, 1), (64, 1, 1), 128);
        assert_eq!(n0, 0);
        assert_eq!(n1, 1);
        assert_eq!(g.node_count(), 2);
    }

    #[test]
    fn add_memcpy_node_records_direction_and_size() {
        let mut g = Graph::new();
        let idx = g.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
        assert_eq!(idx, 0);
        match g.get_node(idx) {
            Some(GraphNode::Memcpy { direction, size }) => {
                assert_eq!(*direction, MemcpyDirection::HostToDevice);
                assert_eq!(*size, 4096);
            }
            other => panic!("expected Memcpy node, got {other:?}"),
        }
    }

    #[test]
    fn add_memset_node_records_size_and_value() {
        let mut g = Graph::new();
        let idx = g.add_memset_node(8192, 0xAB);
        assert_eq!(idx, 0);
        match g.get_node(idx) {
            Some(GraphNode::Memset { size, value }) => {
                assert_eq!(*size, 8192);
                assert_eq!(*value, 0xAB);
            }
            other => panic!("expected Memset node, got {other:?}"),
        }
    }

    #[test]
    fn add_empty_node_works() {
        let mut g = Graph::new();
        let idx = g.add_empty_node();
        assert_eq!(idx, 0);
        assert_eq!(g.get_node(idx), Some(&GraphNode::Empty));
    }

    #[test]
    fn add_dependency_valid() {
        let mut g = Graph::new();
        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
        assert!(g.add_dependency(n0, n1).is_ok());
        assert_eq!(g.dependency_count(), 1);
        assert_eq!(g.dependencies()[0], (0, 1));
    }

    #[test]
    fn add_dependency_out_of_bounds() {
        let mut g = Graph::new();
        g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        assert_eq!(g.add_dependency(0, 5), Err(CudaError::InvalidValue));
    }

    #[test]
    fn add_dependency_self_loop() {
        let mut g = Graph::new();
        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        assert_eq!(g.add_dependency(n0, n0), Err(CudaError::InvalidValue));
    }

    #[test]
    fn topological_sort_linear_chain() {
        let mut g = Graph::new();
        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
        let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
        g.add_dependency(n0, n1).unwrap();
        g.add_dependency(n1, n2).unwrap();

        let order = g.topological_sort().expect("linear chain is acyclic");
        // n0 must come before n1, n1 before n2.
        assert!(pos_in(&order, n0) < pos_in(&order, n1));
        assert!(pos_in(&order, n1) < pos_in(&order, n2));
    }

    #[test]
    fn topological_sort_detects_cycle() {
        let mut g = Graph::new();
        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
        g.add_dependency(n0, n1).unwrap();
        g.add_dependency(n1, n0).unwrap();

        assert_eq!(g.topological_sort(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn topological_sort_no_deps() {
        let mut g = Graph::new();
        g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
        g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);

        let order = g.topological_sort().expect("independent nodes always sort");
        assert_eq!(order.len(), 3);
    }

    #[test]
    fn instantiate_valid_graph() {
        let mut g = Graph::new();
        let n0 = g.add_memcpy_node(MemcpyDirection::HostToDevice, 1024);
        let n1 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        let n2 = g.add_memcpy_node(MemcpyDirection::DeviceToHost, 1024);
        g.add_dependency(n0, n1).unwrap();
        g.add_dependency(n1, n2).unwrap();

        let exec = g.instantiate().expect("acyclic graph must instantiate");
        assert_eq!(exec.node_count(), 3);
        assert_eq!(exec.execution_order().len(), 3);
    }

    #[test]
    fn instantiate_cyclic_graph_fails() {
        let mut g = Graph::new();
        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
        g.add_dependency(n0, n1).unwrap();
        g.add_dependency(n1, n0).unwrap();

        assert!(g.instantiate().is_err());
    }

    #[test]
    fn graph_display() {
        let mut g = Graph::new();
        g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        g.add_memcpy_node(MemcpyDirection::HostToDevice, 512);
        let disp = format!("{g}");
        assert!(disp.contains("2 nodes"));
        assert!(disp.contains("0 deps"));
    }

    #[test]
    fn node_display() {
        let node = GraphNode::KernelLaunch {
            function_name: "foo".to_owned(),
            grid: (4, 1, 1),
            block: (256, 1, 1),
            shared_mem: 0,
        };
        assert!(format!("{node}").contains("foo"));

        let node = GraphNode::Memcpy {
            direction: MemcpyDirection::DeviceToHost,
            size: 1024,
        };
        assert!(format!("{node}").contains("DtoH"));

        let node = GraphNode::Memset {
            size: 256,
            value: 0xFF,
        };
        assert!(format!("{node}").contains("0xff"));

        let node = GraphNode::Empty;
        assert!(format!("{node}").contains("Empty"));
    }

    #[test]
    fn memcpy_direction_display() {
        assert_eq!(format!("{}", MemcpyDirection::HostToDevice), "HtoD");
        assert_eq!(format!("{}", MemcpyDirection::DeviceToHost), "DtoH");
        assert_eq!(format!("{}", MemcpyDirection::DeviceToDevice), "DtoD");
    }

    #[test]
    fn graph_get_node_out_of_bounds() {
        let g = Graph::new();
        assert!(g.get_node(0).is_none());
        assert!(g.get_node(100).is_none());
    }

    #[test]
    fn graph_diamond_dag() {
        // Diamond: n0 -> n1, n0 -> n2, n1 -> n3, n2 -> n3
        let mut g = Graph::new();
        let n0 = g.add_empty_node();
        let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
        let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
        let n3 = g.add_empty_node();
        g.add_dependency(n0, n1).unwrap();
        g.add_dependency(n0, n2).unwrap();
        g.add_dependency(n1, n3).unwrap();
        g.add_dependency(n2, n3).unwrap();

        let order = g.topological_sort().expect("diamond is acyclic");
        assert_eq!(order.len(), 4);
        assert!(pos_in(&order, n0) < pos_in(&order, n1));
        assert!(pos_in(&order, n0) < pos_in(&order, n2));
        assert!(pos_in(&order, n1) < pos_in(&order, n3));
        assert!(pos_in(&order, n2) < pos_in(&order, n3));

        assert!(g.instantiate().is_ok());
    }

    #[test]
    fn graph_exec_debug() {
        let mut g = Graph::new();
        g.add_empty_node();
        let exec = g.instantiate().expect("single-node graph instantiates");
        let dbg = format!("{exec:?}");
        assert!(dbg.contains("GraphExec"));
    }
}
825}