1use crate::{Result, runtime_error};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
/// Identifier of a node within a `CudaGraph`; handed out sequentially starting at 0.
pub type NodeId = usize;
14
15#[derive(Debug, Clone)]
17pub enum NodeKind {
18 Kernel {
20 name: String,
21 grid: [u32; 3],
22 block: [u32; 3],
23 },
24 Memcpy {
26 size: usize,
27 kind: MemcpyDirection,
28 },
29 Memset {
31 size: usize,
32 value: u8,
33 },
34 HostCallback {
36 name: String,
37 },
38 Empty,
40}
41
/// Direction of a memory copy node.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare directions or use them
/// as map/set keys; as a fieldless enum these derives are always well defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemcpyDirection {
    /// Host memory to device memory.
    HostToDevice,
    /// Device memory to host memory.
    DeviceToHost,
    /// Device memory to device memory.
    DeviceToDevice,
}
49
50#[derive(Debug, Clone)]
52pub struct GraphNode {
53 pub id: NodeId,
55 pub kind: NodeKind,
57 pub dependencies: Vec<NodeId>,
59 pub state: NodeState,
61}
62
/// Lifecycle state of a graph node.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NodeState {
    /// Not yet started.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
}
71
72pub struct CudaGraph {
74 name: String,
76 nodes: HashMap<NodeId, GraphNode>,
78 next_id: NodeId,
80 instantiated: bool,
82}
83
84impl CudaGraph {
85 pub fn new(name: &str) -> Self {
87 Self {
88 name: name.to_string(),
89 nodes: HashMap::new(),
90 next_id: 0,
91 instantiated: false,
92 }
93 }
94
95 pub fn name(&self) -> &str {
97 &self.name
98 }
99
100 pub fn node_count(&self) -> usize {
102 self.nodes.len()
103 }
104
105 pub fn add_kernel_node(
107 &mut self,
108 name: &str,
109 grid: [u32; 3],
110 block: [u32; 3],
111 dependencies: &[NodeId],
112 ) -> Result<NodeId> {
113 self.validate_dependencies(dependencies)?;
114 let id = self.allocate_id();
115 self.nodes.insert(id, GraphNode {
116 id,
117 kind: NodeKind::Kernel {
118 name: name.to_string(),
119 grid,
120 block,
121 },
122 dependencies: dependencies.to_vec(),
123 state: NodeState::Pending,
124 });
125 self.instantiated = false;
126 Ok(id)
127 }
128
129 pub fn add_memcpy_node(
131 &mut self,
132 size: usize,
133 kind: MemcpyDirection,
134 dependencies: &[NodeId],
135 ) -> Result<NodeId> {
136 self.validate_dependencies(dependencies)?;
137 let id = self.allocate_id();
138 self.nodes.insert(id, GraphNode {
139 id,
140 kind: NodeKind::Memcpy { size, kind },
141 dependencies: dependencies.to_vec(),
142 state: NodeState::Pending,
143 });
144 self.instantiated = false;
145 Ok(id)
146 }
147
148 pub fn add_memset_node(
150 &mut self,
151 size: usize,
152 value: u8,
153 dependencies: &[NodeId],
154 ) -> Result<NodeId> {
155 self.validate_dependencies(dependencies)?;
156 let id = self.allocate_id();
157 self.nodes.insert(id, GraphNode {
158 id,
159 kind: NodeKind::Memset { size, value },
160 dependencies: dependencies.to_vec(),
161 state: NodeState::Pending,
162 });
163 self.instantiated = false;
164 Ok(id)
165 }
166
167 pub fn add_host_node(
169 &mut self,
170 name: &str,
171 dependencies: &[NodeId],
172 ) -> Result<NodeId> {
173 self.validate_dependencies(dependencies)?;
174 let id = self.allocate_id();
175 self.nodes.insert(id, GraphNode {
176 id,
177 kind: NodeKind::HostCallback {
178 name: name.to_string(),
179 },
180 dependencies: dependencies.to_vec(),
181 state: NodeState::Pending,
182 });
183 self.instantiated = false;
184 Ok(id)
185 }
186
187 pub fn add_empty_node(&mut self, dependencies: &[NodeId]) -> Result<NodeId> {
189 self.validate_dependencies(dependencies)?;
190 let id = self.allocate_id();
191 self.nodes.insert(id, GraphNode {
192 id,
193 kind: NodeKind::Empty,
194 dependencies: dependencies.to_vec(),
195 state: NodeState::Pending,
196 });
197 self.instantiated = false;
198 Ok(id)
199 }
200
201 pub fn get_node(&self, id: NodeId) -> Option<&GraphNode> {
203 self.nodes.get(&id)
204 }
205
206 pub fn root_nodes(&self) -> Vec<NodeId> {
208 self.nodes
209 .values()
210 .filter(|n| n.dependencies.is_empty())
211 .map(|n| n.id)
212 .collect()
213 }
214
215 pub fn topological_order(&self) -> Result<Vec<NodeId>> {
217 let mut visited = HashMap::new();
218 let mut order = Vec::new();
219
220 let mut keys: Vec<NodeId> = self.nodes.keys().copied().collect();
222 keys.sort();
223
224 for id in keys {
225 if !visited.contains_key(&id) {
226 self.topo_visit(id, &mut visited, &mut order)?;
227 }
228 }
229
230 Ok(order)
233 }
234
235 pub fn validate(&self) -> Result<()> {
237 self.topological_order()?;
238 Ok(())
239 }
240
241 pub fn instantiate(&mut self) -> Result<GraphExec> {
243 self.validate()?;
244 self.instantiated = true;
245
246 let order = self.topological_order()?;
247 let nodes: Vec<GraphNode> = order
248 .iter()
249 .map(|id| self.nodes[id].clone())
250 .collect();
251
252 Ok(GraphExec {
253 graph_name: self.name.clone(),
254 nodes,
255 execution_count: 0,
256 total_execution_time_us: 0,
257 })
258 }
259
260 pub fn is_instantiated(&self) -> bool {
262 self.instantiated
263 }
264
265 fn allocate_id(&mut self) -> NodeId {
268 let id = self.next_id;
269 self.next_id += 1;
270 id
271 }
272
273 fn validate_dependencies(&self, deps: &[NodeId]) -> Result<()> {
274 for &dep in deps {
275 if !self.nodes.contains_key(&dep) {
276 return Err(runtime_error!(
277 "Dependency node {} does not exist in graph",
278 dep
279 ));
280 }
281 }
282 Ok(())
283 }
284
285 fn topo_visit(
286 &self,
287 id: NodeId,
288 visited: &mut HashMap<NodeId, bool>,
289 order: &mut Vec<NodeId>,
290 ) -> Result<()> {
291 if let Some(&in_progress) = visited.get(&id) {
292 if in_progress {
293 return Err(runtime_error!("Cycle detected in graph at node {}", id));
294 }
295 return Ok(());
296 }
297
298 visited.insert(id, true); if let Some(node) = self.nodes.get(&id) {
301 for &dep in &node.dependencies {
302 self.topo_visit(dep, visited, order)?;
303 }
304 }
305
306 visited.insert(id, false); order.push(id);
308 Ok(())
309 }
310}
311
312pub struct GraphExec {
314 graph_name: String,
316 nodes: Vec<GraphNode>,
318 execution_count: u64,
320 total_execution_time_us: u64,
322}
323
324impl GraphExec {
325 pub fn launch(&mut self) -> Result<GraphExecResult> {
331 let start = Instant::now();
332 let mut node_results = Vec::new();
333
334 for node in &self.nodes {
335 let node_start = Instant::now();
336
337 match &node.kind {
339 NodeKind::Kernel { name, grid, block } => {
340 let total_threads =
341 grid[0] * grid[1] * grid[2] * block[0] * block[1] * block[2];
342 node_results.push(NodeExecResult {
343 node_id: node.id,
344 name: name.clone(),
345 duration_us: node_start.elapsed().as_micros() as u64,
346 threads_launched: total_threads as u64,
347 });
348 }
349 NodeKind::Memcpy { size, .. } => {
350 node_results.push(NodeExecResult {
351 node_id: node.id,
352 name: format!("memcpy_{}_bytes", size),
353 duration_us: node_start.elapsed().as_micros() as u64,
354 threads_launched: 0,
355 });
356 }
357 NodeKind::Memset { size, .. } => {
358 node_results.push(NodeExecResult {
359 node_id: node.id,
360 name: format!("memset_{}_bytes", size),
361 duration_us: node_start.elapsed().as_micros() as u64,
362 threads_launched: 0,
363 });
364 }
365 NodeKind::HostCallback { name } => {
366 node_results.push(NodeExecResult {
367 node_id: node.id,
368 name: name.clone(),
369 duration_us: node_start.elapsed().as_micros() as u64,
370 threads_launched: 0,
371 });
372 }
373 NodeKind::Empty => {
374 node_results.push(NodeExecResult {
375 node_id: node.id,
376 name: "sync".to_string(),
377 duration_us: 0,
378 threads_launched: 0,
379 });
380 }
381 }
382 }
383
384 let total_us = start.elapsed().as_micros() as u64;
385 self.execution_count += 1;
386 self.total_execution_time_us += total_us;
387
388 Ok(GraphExecResult {
389 graph_name: self.graph_name.clone(),
390 node_results,
391 total_duration_us: total_us,
392 execution_number: self.execution_count,
393 })
394 }
395
396 pub fn execution_count(&self) -> u64 {
398 self.execution_count
399 }
400
401 pub fn avg_execution_time_us(&self) -> u64 {
403 if self.execution_count == 0 {
404 0
405 } else {
406 self.total_execution_time_us / self.execution_count
407 }
408 }
409
410 pub fn node_count(&self) -> usize {
412 self.nodes.len()
413 }
414}
415
416#[derive(Debug)]
418pub struct GraphExecResult {
419 pub graph_name: String,
420 pub node_results: Vec<NodeExecResult>,
421 pub total_duration_us: u64,
422 pub execution_number: u64,
423}
424
425#[derive(Debug)]
427pub struct NodeExecResult {
428 pub node_id: NodeId,
429 pub name: String,
430 pub duration_us: u64,
431 pub threads_launched: u64,
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_creation() {
        let graph = CudaGraph::new("test_graph");
        assert_eq!(graph.name(), "test_graph");
        assert_eq!(graph.node_count(), 0);
    }

    #[test]
    fn test_add_kernel_node() {
        let mut graph = CudaGraph::new("test");
        let id = graph.add_kernel_node("my_kernel", [1, 1, 1], [256, 1, 1], &[]).unwrap();
        assert_eq!(graph.node_count(), 1);
        let node = graph.get_node(id).unwrap();
        assert!(matches!(&node.kind, NodeKind::Kernel { name, .. } if name == "my_kernel"));
    }

    #[test]
    fn test_add_memcpy_node() {
        let mut graph = CudaGraph::new("test");
        let id = graph
            .add_memcpy_node(1024, MemcpyDirection::HostToDevice, &[])
            .unwrap();
        assert_eq!(graph.node_count(), 1);
        let node = graph.get_node(id).unwrap();
        assert!(matches!(&node.kind, NodeKind::Memcpy { size: 1024, .. }));
    }

    #[test]
    fn test_graph_dependencies() {
        let mut graph = CudaGraph::new("pipeline");
        let upload = graph
            .add_memcpy_node(1024, MemcpyDirection::HostToDevice, &[])
            .unwrap();
        let compute = graph
            .add_kernel_node("process", [4, 1, 1], [256, 1, 1], &[upload])
            .unwrap();
        let download = graph
            .add_memcpy_node(1024, MemcpyDirection::DeviceToHost, &[compute])
            .unwrap();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.root_nodes(), vec![upload]);

        let order = graph.topological_order().unwrap();
        let upload_pos = order.iter().position(|&x| x == upload).unwrap();
        let compute_pos = order.iter().position(|&x| x == compute).unwrap();
        let download_pos = order.iter().position(|&x| x == download).unwrap();

        assert!(upload_pos < compute_pos);
        assert!(compute_pos < download_pos);
    }

    #[test]
    fn test_invalid_dependency() {
        let mut graph = CudaGraph::new("test");
        let result = graph.add_kernel_node("k", [1, 1, 1], [1, 1, 1], &[999]);
        assert!(result.is_err());
        assert_eq!(graph.node_count(), 0);
        // A rejected node must not consume an id.
        let id = graph.add_empty_node(&[]).unwrap();
        assert_eq!(id, 0);
    }

    #[test]
    fn test_graph_instantiate() {
        let mut graph = CudaGraph::new("test");
        graph.add_kernel_node("k1", [1, 1, 1], [256, 1, 1], &[]).unwrap();
        graph.add_kernel_node("k2", [1, 1, 1], [256, 1, 1], &[]).unwrap();

        let exec = graph.instantiate();
        assert!(exec.is_ok());
        assert!(graph.is_instantiated());
    }

    #[test]
    fn test_graph_execute() {
        let mut graph = CudaGraph::new("pipeline");
        let n1 = graph.add_kernel_node("init", [1, 1, 1], [128, 1, 1], &[]).unwrap();
        let n2 = graph.add_kernel_node("compute", [4, 1, 1], [256, 1, 1], &[n1]).unwrap();
        graph.add_kernel_node("finalize", [1, 1, 1], [64, 1, 1], &[n2]).unwrap();

        let mut exec = graph.instantiate().unwrap();
        let result = exec.launch().unwrap();

        assert_eq!(result.graph_name, "pipeline");
        assert_eq!(result.node_results.len(), 3);
        assert_eq!(result.execution_number, 1);
    }

    #[test]
    fn test_graph_replay() {
        let mut graph = CudaGraph::new("replay_test");
        graph.add_kernel_node("k", [1, 1, 1], [32, 1, 1], &[]).unwrap();

        let mut exec = graph.instantiate().unwrap();

        for i in 1..=5 {
            let result = exec.launch().unwrap();
            assert_eq!(result.execution_number, i);
        }
        assert_eq!(exec.execution_count(), 5);
    }

    #[test]
    fn test_graph_validate_dag() {
        let mut graph = CudaGraph::new("valid");
        let a = graph.add_kernel_node("a", [1, 1, 1], [1, 1, 1], &[]).unwrap();
        let b = graph.add_kernel_node("b", [1, 1, 1], [1, 1, 1], &[a]).unwrap();
        graph.add_kernel_node("c", [1, 1, 1], [1, 1, 1], &[a, b]).unwrap();

        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_empty_graph_instantiate() {
        let mut graph = CudaGraph::new("empty");
        let mut exec = graph.instantiate().unwrap();
        let result = exec.launch().unwrap();
        assert_eq!(result.node_results.len(), 0);
    }

    #[test]
    fn test_memset_node() {
        let mut graph = CudaGraph::new("memset_test");
        let id = graph.add_memset_node(4096, 0, &[]).unwrap();
        let node = graph.get_node(id).unwrap();
        assert!(matches!(&node.kind, NodeKind::Memset { size: 4096, value: 0 }));
    }

    #[test]
    fn test_host_callback_node() {
        let mut graph = CudaGraph::new("callback_test");
        let id = graph.add_host_node("my_callback", &[]).unwrap();
        let node = graph.get_node(id).unwrap();
        assert!(matches!(&node.kind, NodeKind::HostCallback { name } if name == "my_callback"));
    }

    #[test]
    fn test_diamond_dependency_graph() {
        let mut graph = CudaGraph::new("diamond");
        let root = graph.add_kernel_node("root", [1, 1, 1], [1, 1, 1], &[]).unwrap();
        let left = graph.add_kernel_node("left", [1, 1, 1], [1, 1, 1], &[root]).unwrap();
        let right = graph.add_kernel_node("right", [1, 1, 1], [1, 1, 1], &[root]).unwrap();
        let join = graph.add_kernel_node("join", [1, 1, 1], [1, 1, 1], &[left, right]).unwrap();

        let order = graph.topological_order().unwrap();
        let root_pos = order.iter().position(|&x| x == root).unwrap();
        let left_pos = order.iter().position(|&x| x == left).unwrap();
        let right_pos = order.iter().position(|&x| x == right).unwrap();
        let join_pos = order.iter().position(|&x| x == join).unwrap();

        assert!(root_pos < left_pos);
        assert!(root_pos < right_pos);
        assert!(left_pos < join_pos);
        assert!(right_pos < join_pos);
    }

    #[test]
    fn test_kernel_threads_launched() {
        let mut graph = CudaGraph::new("threads");
        graph.add_kernel_node("k", [4, 2, 1], [256, 1, 1], &[]).unwrap();
        let mut exec = graph.instantiate().unwrap();
        let result = exec.launch().unwrap();
        assert_eq!(result.node_results[0].threads_launched, 2048);
    }

    #[test]
    fn test_launch_reports_node_names_in_order() {
        let mut graph = CudaGraph::new("names");
        let m = graph
            .add_memcpy_node(1024, MemcpyDirection::HostToDevice, &[])
            .unwrap();
        let s = graph.add_memset_node(4096, 0, &[m]).unwrap();
        let h = graph.add_host_node("cb", &[s]).unwrap();
        let e = graph.add_empty_node(&[h]).unwrap();

        let mut exec = graph.instantiate().unwrap();
        let result = exec.launch().unwrap();

        let names: Vec<&str> = result.node_results.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, vec!["memcpy_1024_bytes", "memset_4096_bytes", "cb", "sync"]);
        let ids: Vec<NodeId> = result.node_results.iter().map(|r| r.node_id).collect();
        assert_eq!(ids, vec![m, s, h, e]);
        assert!(result.node_results.iter().all(|r| r.threads_launched == 0));
        assert_eq!(result.node_results[3].duration_us, 0);
    }

    #[test]
    fn test_adding_node_resets_instantiated_and_exec_is_snapshot() {
        let mut graph = CudaGraph::new("reset");
        let a = graph.add_kernel_node("a", [1, 1, 1], [1, 1, 1], &[]).unwrap();
        assert!(!graph.is_instantiated());

        let exec = graph.instantiate().unwrap();
        assert!(graph.is_instantiated());

        graph.add_empty_node(&[a]).unwrap();
        assert!(!graph.is_instantiated());
        // The executable owns copies of the nodes and ignores later additions.
        assert_eq!(exec.node_count(), 1);
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_exec_stats_before_launch() {
        let mut graph = CudaGraph::new("stats");
        graph.add_kernel_node("k", [1, 1, 1], [1, 1, 1], &[]).unwrap();
        let exec = graph.instantiate().unwrap();
        assert_eq!(exec.execution_count(), 0);
        assert_eq!(exec.avg_execution_time_us(), 0);
        assert_eq!(exec.node_count(), 1);
    }

    #[test]
    fn test_multiple_root_nodes() {
        let mut graph = CudaGraph::new("roots");
        let a = graph.add_kernel_node("a", [1, 1, 1], [1, 1, 1], &[]).unwrap();
        let b = graph.add_kernel_node("b", [1, 1, 1], [1, 1, 1], &[]).unwrap();
        graph.add_kernel_node("c", [1, 1, 1], [1, 1, 1], &[a, b]).unwrap();

        let mut roots = graph.root_nodes();
        roots.sort_unstable();
        assert_eq!(roots, vec![a, b]);
    }

    #[test]
    fn test_node_records_dependencies_and_state() {
        let mut graph = CudaGraph::new("deps");
        let a = graph.add_empty_node(&[]).unwrap();
        let b = graph.add_empty_node(&[]).unwrap();
        let c = graph.add_empty_node(&[a, b]).unwrap();

        let node = graph.get_node(c).unwrap();
        assert_eq!(node.id, c);
        assert_eq!(node.dependencies, vec![a, b]);
        assert_eq!(node.state, NodeState::Pending);
        assert!(graph.get_node(999).is_none());
    }
}