1use crate::error::{CudaError, CudaResult};
38use crate::stream::Stream;
39
/// Direction of a memory copy recorded in a `GraphNode::Memcpy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemcpyDirection {
    /// Host memory to device memory (displayed as `HtoD`).
    HostToDevice,
    /// Device memory to host memory (displayed as `DtoH`).
    DeviceToHost,
    /// Device memory to device memory (displayed as `DtoD`).
    DeviceToDevice,
}

impl std::fmt::Display for MemcpyDirection {
    /// Writes the short tag for the direction: `HtoD`, `DtoH` or `DtoD`.
    ///
    /// Width/fill flags of the formatter are ignored; the tag is written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match self {
            Self::HostToDevice => "HtoD",
            Self::DeviceToHost => "DtoH",
            Self::DeviceToDevice => "DtoD",
        };
        f.write_str(tag)
    }
}
64
65#[derive(Debug, Clone, PartialEq, Eq)]
70pub enum GraphNode {
71 KernelLaunch {
73 function_name: String,
75 grid: (u32, u32, u32),
77 block: (u32, u32, u32),
79 shared_mem: u32,
81 },
82 Memcpy {
84 direction: MemcpyDirection,
86 size: usize,
88 },
89 Memset {
91 size: usize,
93 value: u8,
95 },
96 Empty,
98}
99
100impl std::fmt::Display for GraphNode {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 match self {
103 Self::KernelLaunch {
104 function_name,
105 grid,
106 block,
107 shared_mem,
108 } => write!(
109 f,
110 "Kernel({}, grid=({},{},{}), block=({},{},{}), smem={})",
111 function_name, grid.0, grid.1, grid.2, block.0, block.1, block.2, shared_mem,
112 ),
113 Self::Memcpy { direction, size } => {
114 write!(f, "Memcpy({direction}, {size} bytes)")
115 }
116 Self::Memset { size, value } => {
117 write!(f, "Memset({size} bytes, value=0x{value:02x})")
118 }
119 Self::Empty => write!(f, "Empty"),
120 }
121 }
122}
123
124#[derive(Debug, Clone)]
137pub struct Graph {
138 nodes: Vec<GraphNode>,
139 dependencies: Vec<(usize, usize)>,
140}
141
142impl Default for Graph {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148impl Graph {
149 pub fn new() -> Self {
151 Self {
152 nodes: Vec::new(),
153 dependencies: Vec::new(),
154 }
155 }
156
157 pub fn add_kernel_node(
169 &mut self,
170 function_name: &str,
171 grid: (u32, u32, u32),
172 block: (u32, u32, u32),
173 shared_mem: u32,
174 ) -> usize {
175 let idx = self.nodes.len();
176 self.nodes.push(GraphNode::KernelLaunch {
177 function_name: function_name.to_owned(),
178 grid,
179 block,
180 shared_mem,
181 });
182 idx
183 }
184
185 pub fn add_memcpy_node(&mut self, direction: MemcpyDirection, size: usize) -> usize {
194 let idx = self.nodes.len();
195 self.nodes.push(GraphNode::Memcpy { direction, size });
196 idx
197 }
198
199 pub fn add_memset_node(&mut self, size: usize, value: u8) -> usize {
208 let idx = self.nodes.len();
209 self.nodes.push(GraphNode::Memset { size, value });
210 idx
211 }
212
213 pub fn add_empty_node(&mut self) -> usize {
221 let idx = self.nodes.len();
222 self.nodes.push(GraphNode::Empty);
223 idx
224 }
225
226 pub fn add_dependency(&mut self, from: usize, to: usize) -> CudaResult<()> {
236 if from >= self.nodes.len() || to >= self.nodes.len() {
237 return Err(CudaError::InvalidValue);
238 }
239 if from == to {
240 return Err(CudaError::InvalidValue);
241 }
242 self.dependencies.push((from, to));
243 Ok(())
244 }
245
246 #[inline]
248 pub fn node_count(&self) -> usize {
249 self.nodes.len()
250 }
251
252 #[inline]
254 pub fn dependency_count(&self) -> usize {
255 self.dependencies.len()
256 }
257
258 #[inline]
260 pub fn nodes(&self) -> &[GraphNode] {
261 &self.nodes
262 }
263
264 #[inline]
266 pub fn dependencies(&self) -> &[(usize, usize)] {
267 &self.dependencies
268 }
269
270 pub fn get_node(&self, index: usize) -> Option<&GraphNode> {
272 self.nodes.get(index)
273 }
274
275 pub fn topological_sort(&self) -> CudaResult<Vec<usize>> {
285 let n = self.nodes.len();
286 let mut in_degree = vec![0u32; n];
287 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
288
289 for &(from, to) in &self.dependencies {
290 adj[from].push(to);
291 in_degree[to] = in_degree[to].saturating_add(1);
292 }
293
294 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
295 let mut result = Vec::with_capacity(n);
296
297 while let Some(node) = queue.pop() {
298 result.push(node);
299 for &next in &adj[node] {
300 in_degree[next] = in_degree[next].saturating_sub(1);
301 if in_degree[next] == 0 {
302 queue.push(next);
303 }
304 }
305 }
306
307 if result.len() != n {
308 return Err(CudaError::InvalidValue);
309 }
310
311 Ok(result)
312 }
313
314 pub fn instantiate(&self) -> CudaResult<GraphExec> {
324 let execution_order = self.topological_sort()?;
326 Ok(GraphExec {
327 graph: self.clone(),
328 execution_order,
329 })
330 }
331}
332
333impl std::fmt::Display for Graph {
334 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335 write!(
336 f,
337 "Graph({} nodes, {} deps)",
338 self.nodes.len(),
339 self.dependencies.len()
340 )
341 }
342}
343
344pub struct GraphExec {
357 graph: Graph,
358 execution_order: Vec<usize>,
359}
360
361impl GraphExec {
362 pub fn launch(&self, _stream: &Stream) -> CudaResult<()> {
373 let _api = crate::loader::try_driver()?;
377 Ok(())
381 }
382
383 #[inline]
385 pub fn graph(&self) -> &Graph {
386 &self.graph
387 }
388
389 #[inline]
391 pub fn execution_order(&self) -> &[usize] {
392 &self.execution_order
393 }
394
395 #[inline]
397 pub fn node_count(&self) -> usize {
398 self.graph.node_count()
399 }
400}
401
402impl std::fmt::Debug for GraphExec {
403 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404 f.debug_struct("GraphExec")
405 .field("graph", &self.graph)
406 .field("execution_order", &self.execution_order)
407 .finish()
408 }
409}
410
411pub struct StreamCapture {
442 nodes: Vec<GraphNode>,
443 active: bool,
445}
446
447impl StreamCapture {
448 pub fn begin(_stream: &Stream) -> CudaResult<Self> {
458 let _api = crate::loader::try_driver()?;
460 Ok(Self {
461 nodes: Vec::new(),
462 active: true,
463 })
464 }
465
466 pub fn record_kernel(
475 &mut self,
476 function_name: &str,
477 grid: (u32, u32, u32),
478 block: (u32, u32, u32),
479 shared_mem: u32,
480 ) {
481 if self.active {
482 self.nodes.push(GraphNode::KernelLaunch {
483 function_name: function_name.to_owned(),
484 grid,
485 block,
486 shared_mem,
487 });
488 }
489 }
490
491 pub fn record_memcpy(&mut self, direction: MemcpyDirection, size: usize) {
498 if self.active {
499 self.nodes.push(GraphNode::Memcpy { direction, size });
500 }
501 }
502
503 pub fn record_memset(&mut self, size: usize, value: u8) {
510 if self.active {
511 self.nodes.push(GraphNode::Memset { size, value });
512 }
513 }
514
515 #[inline]
517 pub fn recorded_count(&self) -> usize {
518 self.nodes.len()
519 }
520
521 #[inline]
523 pub fn is_active(&self) -> bool {
524 self.active
525 }
526
527 pub fn end(mut self) -> CudaResult<Graph> {
541 if !self.active {
542 return Err(CudaError::StreamCaptureUnmatched);
543 }
544 self.active = false;
545
546 let mut graph = Graph::new();
547 let mut prev_idx: Option<usize> = None;
548
549 for node in self.nodes.drain(..) {
550 let idx = graph.nodes.len();
551 graph.nodes.push(node);
552
553 if let Some(prev) = prev_idx {
555 graph.dependencies.push((prev, idx));
556 }
557 prev_idx = Some(idx);
558 }
559
560 Ok(graph)
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds the edge `from -> to`, failing the test if the graph rejects it.
    ///
    /// Setup edges must not be added with `.ok()`: a silently rejected edge can
    /// make ordering assertions pass for the wrong reason.
    fn link(g: &mut Graph, from: usize, to: usize) {
        assert_eq!(
            g.add_dependency(from, to),
            Ok(()),
            "dependency {from} -> {to} was rejected"
        );
    }

    /// Adds a minimal kernel node named `name` and returns its index.
    fn kernel(g: &mut Graph, name: &str) -> usize {
        g.add_kernel_node(name, (1, 1, 1), (32, 1, 1), 0)
    }

    /// Returns a topological order of `g`, failing the test on error.
    fn sorted(g: &Graph) -> Vec<usize> {
        match g.topological_sort() {
            Ok(order) => order,
            Err(err) => panic!("topological_sort failed: {err:?}"),
        }
    }

    /// Instantiates `g`, failing the test on error.
    fn instantiated(g: &Graph) -> GraphExec {
        match g.instantiate() {
            Ok(exec) => exec,
            Err(err) => panic!("instantiate failed: {err:?}"),
        }
    }

    /// Position of `node` in `order`. Fails the test if the node is missing,
    /// so ordering comparisons can never pass vacuously.
    fn position_of(order: &[usize], node: usize) -> usize {
        match order.iter().position(|&x| x == node) {
            Some(pos) => pos,
            None => panic!("node {node} missing from order {order:?}"),
        }
    }

    #[test]
    fn graph_new_is_empty() {
        let g = Graph::new();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.dependency_count(), 0);
        assert!(g.nodes().is_empty());
        assert!(g.dependencies().is_empty());
    }

    #[test]
    fn graph_default_is_empty() {
        let g = Graph::default();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.dependency_count(), 0);
    }

    #[test]
    fn add_kernel_node_returns_sequential_indices() {
        let mut g = Graph::new();
        let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        let n1 = g.add_kernel_node("k1", (2, 1, 1), (64, 1, 1), 128);
        assert_eq!(n0, 0);
        assert_eq!(n1, 1);
        assert_eq!(g.node_count(), 2);
        assert_eq!(
            g.get_node(n1),
            Some(&GraphNode::KernelLaunch {
                function_name: "k1".to_owned(),
                grid: (2, 1, 1),
                block: (64, 1, 1),
                shared_mem: 128,
            })
        );
    }

    #[test]
    fn add_memcpy_node_records_direction_and_size() {
        let mut g = Graph::new();
        let idx = g.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
        assert_eq!(idx, 0);
        assert_eq!(
            g.get_node(idx),
            Some(&GraphNode::Memcpy {
                direction: MemcpyDirection::HostToDevice,
                size: 4096,
            })
        );
    }

    #[test]
    fn add_memset_node_records_size_and_value() {
        let mut g = Graph::new();
        let idx = g.add_memset_node(8192, 0xAB);
        assert_eq!(idx, 0);
        assert_eq!(
            g.get_node(idx),
            Some(&GraphNode::Memset {
                size: 8192,
                value: 0xAB,
            })
        );
    }

    #[test]
    fn add_empty_node_works() {
        let mut g = Graph::new();
        let idx = g.add_empty_node();
        assert_eq!(idx, 0);
        assert_eq!(g.get_node(idx), Some(&GraphNode::Empty));
    }

    #[test]
    fn add_dependency_valid() {
        let mut g = Graph::new();
        let n0 = kernel(&mut g, "k0");
        let n1 = kernel(&mut g, "k1");
        assert_eq!(g.add_dependency(n0, n1), Ok(()));
        assert_eq!(g.dependency_count(), 1);
        assert_eq!(g.dependencies(), &[(0, 1)]);
    }

    #[test]
    fn add_dependency_out_of_bounds() {
        let mut g = Graph::new();
        let n0 = kernel(&mut g, "k0");
        assert_eq!(g.add_dependency(n0, 5), Err(CudaError::InvalidValue));
        assert_eq!(g.add_dependency(5, n0), Err(CudaError::InvalidValue));
        // A rejected edge must not be recorded.
        assert_eq!(g.dependency_count(), 0);
    }

    #[test]
    fn add_dependency_self_loop() {
        let mut g = Graph::new();
        let n0 = kernel(&mut g, "k0");
        assert_eq!(g.add_dependency(n0, n0), Err(CudaError::InvalidValue));
        assert_eq!(g.dependency_count(), 0);
    }

    #[test]
    fn topological_sort_empty_graph() {
        let g = Graph::new();
        assert_eq!(g.topological_sort(), Ok(Vec::new()));
    }

    #[test]
    fn topological_sort_linear_chain() {
        let mut g = Graph::new();
        let n0 = kernel(&mut g, "k0");
        let n1 = kernel(&mut g, "k1");
        let n2 = kernel(&mut g, "k2");
        link(&mut g, n0, n1);
        link(&mut g, n1, n2);

        // A chain has exactly one valid order.
        assert_eq!(sorted(&g), vec![n0, n1, n2]);
    }

    #[test]
    fn topological_sort_detects_cycle() {
        let mut g = Graph::new();
        let n0 = kernel(&mut g, "k0");
        let n1 = kernel(&mut g, "k1");
        link(&mut g, n0, n1);
        link(&mut g, n1, n0);

        assert_eq!(g.topological_sort(), Err(CudaError::InvalidValue));
    }

    #[test]
    fn topological_sort_no_deps() {
        let mut g = Graph::new();
        kernel(&mut g, "k0");
        kernel(&mut g, "k1");
        kernel(&mut g, "k2");

        // Any permutation is valid; it must contain every node exactly once.
        let mut order = sorted(&g);
        order.sort_unstable();
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn instantiate_valid_graph() {
        let mut g = Graph::new();
        let n0 = g.add_memcpy_node(MemcpyDirection::HostToDevice, 1024);
        let n1 = kernel(&mut g, "k0");
        let n2 = g.add_memcpy_node(MemcpyDirection::DeviceToHost, 1024);
        link(&mut g, n0, n1);
        link(&mut g, n1, n2);

        let exec = instantiated(&g);
        assert_eq!(exec.node_count(), 3);
        assert_eq!(exec.execution_order(), &[n0, n1, n2]);
        assert_eq!(exec.graph().dependency_count(), 2);
    }

    #[test]
    fn instantiate_cyclic_graph_fails() {
        let mut g = Graph::new();
        let n0 = kernel(&mut g, "k0");
        let n1 = kernel(&mut g, "k1");
        link(&mut g, n0, n1);
        link(&mut g, n1, n0);

        assert!(matches!(g.instantiate(), Err(CudaError::InvalidValue)));
    }

    #[test]
    fn instantiate_snapshots_graph() {
        let mut g = Graph::new();
        g.add_empty_node();
        let exec = instantiated(&g);

        // Later edits to the source graph must not leak into the executable.
        g.add_empty_node();
        assert_eq!(g.node_count(), 2);
        assert_eq!(exec.node_count(), 1);
        assert_eq!(exec.execution_order(), &[0]);
    }

    #[test]
    fn graph_display() {
        let mut g = Graph::new();
        kernel(&mut g, "k0");
        g.add_memcpy_node(MemcpyDirection::HostToDevice, 512);
        assert_eq!(format!("{g}"), "Graph(2 nodes, 0 deps)");
    }

    #[test]
    fn node_display() {
        let node = GraphNode::KernelLaunch {
            function_name: "foo".to_owned(),
            grid: (4, 1, 1),
            block: (256, 1, 1),
            shared_mem: 0,
        };
        assert_eq!(
            format!("{node}"),
            "Kernel(foo, grid=(4,1,1), block=(256,1,1), smem=0)"
        );

        let node = GraphNode::Memcpy {
            direction: MemcpyDirection::DeviceToHost,
            size: 1024,
        };
        assert_eq!(format!("{node}"), "Memcpy(DtoH, 1024 bytes)");

        let node = GraphNode::Memset {
            size: 256,
            value: 0xFF,
        };
        assert_eq!(format!("{node}"), "Memset(256 bytes, value=0xff)");

        assert_eq!(format!("{}", GraphNode::Empty), "Empty");
    }

    #[test]
    fn memcpy_direction_display() {
        assert_eq!(format!("{}", MemcpyDirection::HostToDevice), "HtoD");
        assert_eq!(format!("{}", MemcpyDirection::DeviceToHost), "DtoH");
        assert_eq!(format!("{}", MemcpyDirection::DeviceToDevice), "DtoD");
    }

    #[test]
    fn graph_get_node_out_of_bounds() {
        let g = Graph::new();
        assert!(g.get_node(0).is_none());
        assert!(g.get_node(100).is_none());
    }

    #[test]
    fn graph_diamond_dag() {
        let mut g = Graph::new();
        let n0 = g.add_empty_node();
        let n1 = kernel(&mut g, "k1");
        let n2 = kernel(&mut g, "k2");
        let n3 = g.add_empty_node();
        link(&mut g, n0, n1);
        link(&mut g, n0, n2);
        link(&mut g, n1, n3);
        link(&mut g, n2, n3);

        let order = sorted(&g);
        assert_eq!(order.len(), 4);
        let pos = |n: usize| position_of(&order, n);
        assert!(pos(n0) < pos(n1));
        assert!(pos(n0) < pos(n2));
        assert!(pos(n1) < pos(n3));
        assert!(pos(n2) < pos(n3));

        let exec = instantiated(&g);
        assert_eq!(exec.execution_order().len(), 4);
    }

    #[test]
    fn graph_exec_debug() {
        let mut g = Graph::new();
        g.add_empty_node();
        let exec = instantiated(&g);
        let dbg = format!("{exec:?}");
        assert!(dbg.starts_with("GraphExec"));
        assert!(dbg.contains("execution_order"));
    }
}
825}