Skip to main content

cuda_rust_wasm/runtime/
cuda_graph.rs

1//! CUDA Graphs: graph-based kernel execution
2//!
3//! Provides a dependency graph for kernel launches, enabling the runtime
4//! to optimise scheduling by executing independent nodes in parallel and
5//! replaying captured workloads without re-recording overhead.
6
7use crate::{Result, runtime_error};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
/// Identifier of a node, unique within the graph that issued it.
///
/// [`CudaGraph`] hands these out sequentially, starting at zero, as nodes
/// are added.
pub type NodeId = usize;
14
/// Kind of work represented by a graph node
#[derive(Debug, Clone)]
pub enum NodeKind {
    /// GPU kernel launch
    Kernel {
        /// Kernel name; reported as the node's name in execution results.
        name: String,
        /// Grid extents. The product of all grid and block extents is
        /// reported as the number of threads launched.
        grid: [u32; 3],
        /// Block extents (see `grid`).
        block: [u32; 3],
    },
    /// Memory copy; the direction (host-to-device, device-to-host or
    /// device-to-device) is given by `kind`
    Memcpy {
        /// Number of bytes to copy.
        size: usize,
        /// Direction of the copy.
        kind: MemcpyDirection,
    },
    /// Memory set (fill with a value)
    Memset {
        /// Number of bytes to fill.
        size: usize,
        /// Byte value to fill with.
        value: u8,
    },
    /// Host callback
    HostCallback {
        /// Callback name; reported as the node's name in execution results.
        name: String,
    },
    /// Empty / synchronization-only node
    Empty,
}
41
/// Direction of the transfer performed by a memcpy node.
///
/// Derives `PartialEq`/`Eq` (like [`NodeState`]) so directions can be
/// compared directly instead of only pattern-matched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemcpyDirection {
    /// Copy from host memory to device memory.
    HostToDevice,
    /// Copy from device memory to host memory.
    DeviceToHost,
    /// Copy between two device memory regions.
    DeviceToDevice,
}
49
/// A node in a CUDA graph
///
/// Bundles the work to perform with the IDs of the nodes that must come
/// before it in any execution order.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Unique ID (the key under which the owning graph stores this node)
    pub id: NodeId,
    /// Kind of work
    pub kind: NodeKind,
    /// IDs of nodes that this node depends on; each one is placed before
    /// this node in the graph's topological order
    pub dependencies: Vec<NodeId>,
    /// Execution state; nodes are created as [`NodeState::Pending`]
    pub state: NodeState,
}
62
/// Execution state of a graph node
///
/// NOTE(review): in the code visible in this file nodes are only ever
/// created as `Pending`; nothing here transitions a node to another state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// Not yet executed (the state every new node starts in)
    Pending,
    /// Currently executing
    Running,
    /// Finished successfully
    Completed,
    /// Finished with an error
    Failed,
}
71
72/// CUDA Graph: a directed acyclic graph of GPU operations
73pub struct CudaGraph {
74    /// Graph name
75    name: String,
76    /// All nodes in the graph, indexed by NodeId
77    nodes: HashMap<NodeId, GraphNode>,
78    /// Next available node ID
79    next_id: NodeId,
80    /// Whether the graph has been instantiated (compiled for execution)
81    instantiated: bool,
82}
83
84impl CudaGraph {
85    /// Create a new empty graph
86    pub fn new(name: &str) -> Self {
87        Self {
88            name: name.to_string(),
89            nodes: HashMap::new(),
90            next_id: 0,
91            instantiated: false,
92        }
93    }
94
95    /// Get graph name
96    pub fn name(&self) -> &str {
97        &self.name
98    }
99
100    /// Get node count
101    pub fn node_count(&self) -> usize {
102        self.nodes.len()
103    }
104
105    /// Add a kernel node
106    pub fn add_kernel_node(
107        &mut self,
108        name: &str,
109        grid: [u32; 3],
110        block: [u32; 3],
111        dependencies: &[NodeId],
112    ) -> Result<NodeId> {
113        self.validate_dependencies(dependencies)?;
114        let id = self.allocate_id();
115        self.nodes.insert(id, GraphNode {
116            id,
117            kind: NodeKind::Kernel {
118                name: name.to_string(),
119                grid,
120                block,
121            },
122            dependencies: dependencies.to_vec(),
123            state: NodeState::Pending,
124        });
125        self.instantiated = false;
126        Ok(id)
127    }
128
129    /// Add a memcpy node
130    pub fn add_memcpy_node(
131        &mut self,
132        size: usize,
133        kind: MemcpyDirection,
134        dependencies: &[NodeId],
135    ) -> Result<NodeId> {
136        self.validate_dependencies(dependencies)?;
137        let id = self.allocate_id();
138        self.nodes.insert(id, GraphNode {
139            id,
140            kind: NodeKind::Memcpy { size, kind },
141            dependencies: dependencies.to_vec(),
142            state: NodeState::Pending,
143        });
144        self.instantiated = false;
145        Ok(id)
146    }
147
148    /// Add a memset node
149    pub fn add_memset_node(
150        &mut self,
151        size: usize,
152        value: u8,
153        dependencies: &[NodeId],
154    ) -> Result<NodeId> {
155        self.validate_dependencies(dependencies)?;
156        let id = self.allocate_id();
157        self.nodes.insert(id, GraphNode {
158            id,
159            kind: NodeKind::Memset { size, value },
160            dependencies: dependencies.to_vec(),
161            state: NodeState::Pending,
162        });
163        self.instantiated = false;
164        Ok(id)
165    }
166
167    /// Add a host callback node
168    pub fn add_host_node(
169        &mut self,
170        name: &str,
171        dependencies: &[NodeId],
172    ) -> Result<NodeId> {
173        self.validate_dependencies(dependencies)?;
174        let id = self.allocate_id();
175        self.nodes.insert(id, GraphNode {
176            id,
177            kind: NodeKind::HostCallback {
178                name: name.to_string(),
179            },
180            dependencies: dependencies.to_vec(),
181            state: NodeState::Pending,
182        });
183        self.instantiated = false;
184        Ok(id)
185    }
186
187    /// Add an empty synchronization node
188    pub fn add_empty_node(&mut self, dependencies: &[NodeId]) -> Result<NodeId> {
189        self.validate_dependencies(dependencies)?;
190        let id = self.allocate_id();
191        self.nodes.insert(id, GraphNode {
192            id,
193            kind: NodeKind::Empty,
194            dependencies: dependencies.to_vec(),
195            state: NodeState::Pending,
196        });
197        self.instantiated = false;
198        Ok(id)
199    }
200
201    /// Get a node by ID
202    pub fn get_node(&self, id: NodeId) -> Option<&GraphNode> {
203        self.nodes.get(&id)
204    }
205
206    /// Get all root nodes (no dependencies)
207    pub fn root_nodes(&self) -> Vec<NodeId> {
208        self.nodes
209            .values()
210            .filter(|n| n.dependencies.is_empty())
211            .map(|n| n.id)
212            .collect()
213    }
214
215    /// Get topological ordering of nodes for execution
216    pub fn topological_order(&self) -> Result<Vec<NodeId>> {
217        let mut visited = HashMap::new();
218        let mut order = Vec::new();
219
220        // Sort keys for deterministic iteration order
221        let mut keys: Vec<NodeId> = self.nodes.keys().copied().collect();
222        keys.sort();
223
224        for id in keys {
225            if !visited.contains_key(&id) {
226                self.topo_visit(id, &mut visited, &mut order)?;
227            }
228        }
229
230        // DFS visiting predecessors (deps) produces order where dependencies
231        // come before dependents -- already a valid topological order.
232        Ok(order)
233    }
234
235    /// Check if the graph is a valid DAG (no cycles)
236    pub fn validate(&self) -> Result<()> {
237        self.topological_order()?;
238        Ok(())
239    }
240
241    /// Instantiate the graph (compile for execution)
242    pub fn instantiate(&mut self) -> Result<GraphExec> {
243        self.validate()?;
244        self.instantiated = true;
245
246        let order = self.topological_order()?;
247        let nodes: Vec<GraphNode> = order
248            .iter()
249            .map(|id| self.nodes[id].clone())
250            .collect();
251
252        Ok(GraphExec {
253            graph_name: self.name.clone(),
254            nodes,
255            execution_count: 0,
256            total_execution_time_us: 0,
257        })
258    }
259
260    /// Whether the graph has been instantiated
261    pub fn is_instantiated(&self) -> bool {
262        self.instantiated
263    }
264
265    // --- Private helpers ---
266
267    fn allocate_id(&mut self) -> NodeId {
268        let id = self.next_id;
269        self.next_id += 1;
270        id
271    }
272
273    fn validate_dependencies(&self, deps: &[NodeId]) -> Result<()> {
274        for &dep in deps {
275            if !self.nodes.contains_key(&dep) {
276                return Err(runtime_error!(
277                    "Dependency node {} does not exist in graph",
278                    dep
279                ));
280            }
281        }
282        Ok(())
283    }
284
285    fn topo_visit(
286        &self,
287        id: NodeId,
288        visited: &mut HashMap<NodeId, bool>,
289        order: &mut Vec<NodeId>,
290    ) -> Result<()> {
291        if let Some(&in_progress) = visited.get(&id) {
292            if in_progress {
293                return Err(runtime_error!("Cycle detected in graph at node {}", id));
294            }
295            return Ok(());
296        }
297
298        visited.insert(id, true); // Mark as in-progress
299
300        if let Some(node) = self.nodes.get(&id) {
301            for &dep in &node.dependencies {
302                self.topo_visit(dep, visited, order)?;
303            }
304        }
305
306        visited.insert(id, false); // Mark as completed
307        order.push(id);
308        Ok(())
309    }
310}
311
312/// Executable (instantiated) graph
313pub struct GraphExec {
314    /// Graph name
315    graph_name: String,
316    /// Nodes in topological order
317    nodes: Vec<GraphNode>,
318    /// Number of times this graph has been executed
319    execution_count: u64,
320    /// Total execution time in microseconds
321    total_execution_time_us: u64,
322}
323
324impl GraphExec {
325    /// Execute the graph
326    ///
327    /// In the CPU emulation backend, nodes are executed sequentially in
328    /// topological order. With a real GPU backend, independent nodes could
329    /// be dispatched in parallel.
330    pub fn launch(&mut self) -> Result<GraphExecResult> {
331        let start = Instant::now();
332        let mut node_results = Vec::new();
333
334        for node in &self.nodes {
335            let node_start = Instant::now();
336
337            // CPU emulation: just record that we "executed" each node
338            match &node.kind {
339                NodeKind::Kernel { name, grid, block } => {
340                    let total_threads =
341                        grid[0] * grid[1] * grid[2] * block[0] * block[1] * block[2];
342                    node_results.push(NodeExecResult {
343                        node_id: node.id,
344                        name: name.clone(),
345                        duration_us: node_start.elapsed().as_micros() as u64,
346                        threads_launched: total_threads as u64,
347                    });
348                }
349                NodeKind::Memcpy { size, .. } => {
350                    node_results.push(NodeExecResult {
351                        node_id: node.id,
352                        name: format!("memcpy_{}_bytes", size),
353                        duration_us: node_start.elapsed().as_micros() as u64,
354                        threads_launched: 0,
355                    });
356                }
357                NodeKind::Memset { size, .. } => {
358                    node_results.push(NodeExecResult {
359                        node_id: node.id,
360                        name: format!("memset_{}_bytes", size),
361                        duration_us: node_start.elapsed().as_micros() as u64,
362                        threads_launched: 0,
363                    });
364                }
365                NodeKind::HostCallback { name } => {
366                    node_results.push(NodeExecResult {
367                        node_id: node.id,
368                        name: name.clone(),
369                        duration_us: node_start.elapsed().as_micros() as u64,
370                        threads_launched: 0,
371                    });
372                }
373                NodeKind::Empty => {
374                    node_results.push(NodeExecResult {
375                        node_id: node.id,
376                        name: "sync".to_string(),
377                        duration_us: 0,
378                        threads_launched: 0,
379                    });
380                }
381            }
382        }
383
384        let total_us = start.elapsed().as_micros() as u64;
385        self.execution_count += 1;
386        self.total_execution_time_us += total_us;
387
388        Ok(GraphExecResult {
389            graph_name: self.graph_name.clone(),
390            node_results,
391            total_duration_us: total_us,
392            execution_number: self.execution_count,
393        })
394    }
395
396    /// Get execution count
397    pub fn execution_count(&self) -> u64 {
398        self.execution_count
399    }
400
401    /// Get average execution time in microseconds
402    pub fn avg_execution_time_us(&self) -> u64 {
403        if self.execution_count == 0 {
404            0
405        } else {
406            self.total_execution_time_us / self.execution_count
407        }
408    }
409
410    /// Get number of nodes
411    pub fn node_count(&self) -> usize {
412        self.nodes.len()
413    }
414}
415
/// Result of executing a graph
#[derive(Debug)]
pub struct GraphExecResult {
    /// Name of the graph that was executed
    pub graph_name: String,
    /// One entry per node, in the order the nodes were executed
    pub node_results: Vec<NodeExecResult>,
    /// Wall-clock duration of the whole launch, in microseconds
    pub total_duration_us: u64,
    /// 1-based ordinal of this launch among all launches of the same
    /// executable graph
    pub execution_number: u64,
}
424
/// Result of executing a single node
#[derive(Debug)]
pub struct NodeExecResult {
    /// ID of the executed node
    pub node_id: NodeId,
    /// Display name: the kernel or callback name, `memcpy_<size>_bytes`,
    /// `memset_<size>_bytes`, or `sync` for empty nodes
    pub name: String,
    /// Time attributed to this node, in microseconds (always 0 for empty
    /// nodes)
    pub duration_us: u64,
    /// Product of all grid and block extents for kernel nodes; 0 for every
    /// other kind of node
    pub threads_launched: u64,
}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds a kernel node with a 1x1x1 grid and the given block width.
    fn kernel(graph: &mut CudaGraph, name: &str, block_x: u32, deps: &[NodeId]) -> NodeId {
        graph
            .add_kernel_node(name, [1, 1, 1], [block_x, 1, 1], deps)
            .unwrap()
    }

    /// Index of `id` within a topological order; panics if it is missing.
    fn index_of(order: &[NodeId], id: NodeId) -> usize {
        order.iter().position(|&candidate| candidate == id).unwrap()
    }

    /// A fresh graph carries its name and no nodes.
    #[test]
    fn test_graph_creation() {
        let graph = CudaGraph::new("test_graph");
        assert_eq!(graph.name(), "test_graph");
        assert_eq!(graph.node_count(), 0);
    }

    /// A kernel node is stored under the returned ID with its name intact.
    #[test]
    fn test_add_kernel_node() {
        let mut graph = CudaGraph::new("test");
        let id = kernel(&mut graph, "my_kernel", 256, &[]);
        assert_eq!(graph.node_count(), 1);
        match &graph.get_node(id).unwrap().kind {
            NodeKind::Kernel { name, .. } => assert_eq!(name, "my_kernel"),
            other => panic!("expected a kernel node, got {:?}", other),
        }
    }

    /// A memcpy node records the requested size.
    #[test]
    fn test_add_memcpy_node() {
        let mut graph = CudaGraph::new("test");
        let id = graph
            .add_memcpy_node(1024, MemcpyDirection::HostToDevice, &[])
            .unwrap();
        assert_eq!(graph.node_count(), 1);
        match &graph.get_node(id).unwrap().kind {
            NodeKind::Memcpy { size, .. } => assert_eq!(*size, 1024),
            other => panic!("expected a memcpy node, got {:?}", other),
        }
    }

    /// A linear upload -> compute -> download pipeline is ordered correctly.
    #[test]
    fn test_graph_dependencies() {
        let mut graph = CudaGraph::new("pipeline");
        let upload = graph
            .add_memcpy_node(1024, MemcpyDirection::HostToDevice, &[])
            .unwrap();
        let compute = graph
            .add_kernel_node("process", [4, 1, 1], [256, 1, 1], &[upload])
            .unwrap();
        let download = graph
            .add_memcpy_node(1024, MemcpyDirection::DeviceToHost, &[compute])
            .unwrap();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.root_nodes(), vec![upload]);

        // Verify topological order
        let order = graph.topological_order().unwrap();
        assert!(index_of(&order, upload) < index_of(&order, compute));
        assert!(index_of(&order, compute) < index_of(&order, download));
    }

    /// Depending on a node that does not exist is rejected.
    #[test]
    fn test_invalid_dependency() {
        let mut graph = CudaGraph::new("test");
        assert!(graph
            .add_kernel_node("k", [1, 1, 1], [1, 1, 1], &[999])
            .is_err());
    }

    /// Instantiating a valid graph succeeds and marks it instantiated.
    #[test]
    fn test_graph_instantiate() {
        let mut graph = CudaGraph::new("test");
        for name in ["k1", "k2"] {
            kernel(&mut graph, name, 256, &[]);
        }

        assert!(graph.instantiate().is_ok());
        assert!(graph.is_instantiated());
    }

    /// Launching reports one result per node and numbers the run.
    #[test]
    fn test_graph_execute() {
        let mut graph = CudaGraph::new("pipeline");
        let init = kernel(&mut graph, "init", 128, &[]);
        let compute = graph
            .add_kernel_node("compute", [4, 1, 1], [256, 1, 1], &[init])
            .unwrap();
        kernel(&mut graph, "finalize", 64, &[compute]);

        let result = graph.instantiate().unwrap().launch().unwrap();

        assert_eq!(result.graph_name, "pipeline");
        assert_eq!(result.node_results.len(), 3);
        assert_eq!(result.execution_number, 1);
    }

    /// An executable graph can be replayed; each run gets the next number.
    #[test]
    fn test_graph_replay() {
        let mut graph = CudaGraph::new("replay_test");
        kernel(&mut graph, "k", 32, &[]);

        let mut exec = graph.instantiate().unwrap();

        // Execute multiple times (replay)
        for expected in 1..=5u64 {
            assert_eq!(exec.launch().unwrap().execution_number, expected);
        }
        assert_eq!(exec.execution_count(), 5);
    }

    /// A graph with shared dependencies validates as a DAG.
    #[test]
    fn test_graph_validate_dag() {
        let mut graph = CudaGraph::new("valid");
        let a = kernel(&mut graph, "a", 1, &[]);
        let b = kernel(&mut graph, "b", 1, &[a]);
        kernel(&mut graph, "c", 1, &[a, b]);

        assert!(graph.validate().is_ok());
    }

    /// An empty graph instantiates and launches with no node results.
    #[test]
    fn test_empty_graph_instantiate() {
        let mut graph = CudaGraph::new("empty");
        let mut exec = graph.instantiate().unwrap();
        let result = exec.launch().unwrap();
        assert!(result.node_results.len() == 0);
    }

    /// A memset node records both size and fill value.
    #[test]
    fn test_memset_node() {
        let mut graph = CudaGraph::new("memset_test");
        let id = graph.add_memset_node(4096, 0, &[]).unwrap();
        match &graph.get_node(id).unwrap().kind {
            NodeKind::Memset { size, value } => {
                assert_eq!(*size, 4096);
                assert_eq!(*value, 0);
            }
            other => panic!("expected a memset node, got {:?}", other),
        }
    }

    /// A host callback node keeps its name.
    #[test]
    fn test_host_callback_node() {
        let mut graph = CudaGraph::new("callback_test");
        let id = graph.add_host_node("my_callback", &[]).unwrap();
        match &graph.get_node(id).unwrap().kind {
            NodeKind::HostCallback { name } => assert_eq!(name, "my_callback"),
            other => panic!("expected a host callback node, got {:?}", other),
        }
    }

    /// In a diamond, the root precedes both branches and both precede the join.
    #[test]
    fn test_diamond_dependency_graph() {
        let mut graph = CudaGraph::new("diamond");
        let root = kernel(&mut graph, "root", 1, &[]);
        let left = kernel(&mut graph, "left", 1, &[root]);
        let right = kernel(&mut graph, "right", 1, &[root]);
        let join = kernel(&mut graph, "join", 1, &[left, right]);

        let order = graph.topological_order().unwrap();
        let (root_pos, join_pos) = (index_of(&order, root), index_of(&order, join));

        for branch in [left, right] {
            let branch_pos = index_of(&order, branch);
            assert!(root_pos < branch_pos);
            assert!(branch_pos < join_pos);
        }
    }
}
592}