1use std::collections::{BTreeMap, BTreeSet, VecDeque};
24use thiserror::Error;
25
26#[derive(Debug, Clone, PartialEq, Error)]
30pub enum GraphError {
31 #[error("Node not found: {0}")]
33 NodeNotFound(u32),
34 #[error("Duplicate node ID: {0}")]
36 DuplicateNode(u32),
37 #[error("Edge from node {from} to node {to} would create a cycle")]
39 CyclicEdge { from: u32, to: u32 },
40 #[error("Compute graph contains a cycle; cannot determine execution order")]
42 CycleDetected,
43 #[error("Node {node_id} is missing required resource binding '{resource}'")]
45 MissingBinding { node_id: u32, resource: String },
46 #[error("Resource '{resource}' cannot be bound to a {node_kind:?} node")]
48 IncompatibleBinding {
49 resource: String,
50 node_kind: NodeKind,
51 },
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
58pub enum NodeKind {
59 Kernel {
61 entry_point: String,
63 dispatch: [u32; 3],
65 },
66 Copy {
68 src_buffer: u32,
70 dst_buffer: u32,
72 byte_count: usize,
74 },
75 Barrier {
77 src_stage: PipelineStageFlags,
79 dst_stage: PipelineStageFlags,
81 },
82}
83
/// A bitmask of pipeline stages.
///
/// Barrier nodes carry a source and a destination mask of this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PipelineStageFlags(pub u32);

impl PipelineStageFlags {
    /// The empty mask: no stage bits set.
    pub const NONE: Self = Self(0);
    /// Compute-shader stage bit.
    pub const COMPUTE_SHADER: Self = Self(1 << 0);
    /// Transfer stage bit.
    pub const TRANSFER: Self = Self(1 << 1);
    /// Host stage bit.
    pub const HOST: Self = Self(1 << 2);
    /// Every bit set; contains every other mask.
    pub const ALL: Self = Self(0xFFFF_FFFF);

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// The empty mask [`Self::NONE`] is therefore contained in every mask.
    /// `const` so it can be evaluated in constant contexts.
    #[must_use]
    pub const fn contains(self, other: Self) -> bool {
        (self.0 & other.0) == other.0
    }

    /// Returns the bitwise union of `self` and `other`.
    ///
    /// `const` so combined masks can be declared as constants, e.g.
    /// `const CT: PipelineStageFlags = PipelineStageFlags::COMPUTE_SHADER.union(PipelineStageFlags::TRANSFER);`.
    #[must_use]
    pub const fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }
}
114
/// A named resource attached to a node, together with how the node accesses it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResourceBinding {
    /// Binding name; it is echoed back in `GraphError::IncompatibleBinding`.
    pub name: String,
    /// Identifier of the bound resource.
    pub resource_id: u32,
    /// Access mode the node uses for this resource.
    pub access: ResourceAccess,
}

/// How a node accesses a bound resource.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResourceAccess {
    /// Read access only.
    ReadOnly,
    /// Write access only.
    WriteOnly,
    /// Both read and write access.
    ReadWrite,
}
138
139#[derive(Debug, Clone)]
143pub struct GraphNode {
144 pub id: u32,
146 pub label: String,
148 pub kind: NodeKind,
150 pub bindings: Vec<ResourceBinding>,
152}
153
154impl GraphNode {
155 #[must_use]
157 pub fn new(id: u32, label: impl Into<String>, kind: NodeKind) -> Self {
158 Self {
159 id,
160 label: label.into(),
161 kind,
162 bindings: Vec::new(),
163 }
164 }
165}
166
/// Result of `ComputeGraph::execution_order`: a topological order of node ids
/// plus summary statistics over the scheduled nodes.
///
/// Derives `PartialEq`/`Eq` so plans can be compared directly (e.g. in tests
/// or when caching schedules).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionPlan {
    /// Node ids in execution order; among simultaneously-ready nodes the
    /// smaller id comes first.
    pub order: Vec<u32>,
    /// Sum, over kernel nodes, of the product of their three dispatch dimensions.
    pub total_dispatch_groups: u64,
    /// Number of barrier nodes in the plan.
    pub barrier_count: usize,
    /// Number of copy nodes in the plan.
    pub copy_count: usize,
}
181
182pub struct ComputeGraph {
186 nodes: BTreeMap<u32, GraphNode>,
188 adj: BTreeMap<u32, BTreeSet<u32>>,
190 radj: BTreeMap<u32, BTreeSet<u32>>,
192}
193
194impl ComputeGraph {
195 #[must_use]
197 pub fn new() -> Self {
198 Self {
199 nodes: BTreeMap::new(),
200 adj: BTreeMap::new(),
201 radj: BTreeMap::new(),
202 }
203 }
204
205 pub fn add_node(&mut self, node: GraphNode) -> Result<(), GraphError> {
212 if self.nodes.contains_key(&node.id) {
213 return Err(GraphError::DuplicateNode(node.id));
214 }
215 let id = node.id;
216 self.nodes.insert(id, node);
217 self.adj.entry(id).or_default();
218 self.radj.entry(id).or_default();
219 Ok(())
220 }
221
222 pub fn bind_resource(
230 &mut self,
231 node_id: u32,
232 binding: ResourceBinding,
233 ) -> Result<(), GraphError> {
234 let node = self
235 .nodes
236 .get_mut(&node_id)
237 .ok_or(GraphError::NodeNotFound(node_id))?;
238 if matches!(node.kind, NodeKind::Barrier { .. }) {
240 return Err(GraphError::IncompatibleBinding {
241 resource: binding.name,
242 node_kind: NodeKind::Barrier {
243 src_stage: PipelineStageFlags::NONE,
244 dst_stage: PipelineStageFlags::NONE,
245 },
246 });
247 }
248 node.bindings.push(binding);
249 Ok(())
250 }
251
252 pub fn add_edge(&mut self, from: u32, to: u32) -> Result<(), GraphError> {
259 if !self.nodes.contains_key(&from) {
260 return Err(GraphError::NodeNotFound(from));
261 }
262 if !self.nodes.contains_key(&to) {
263 return Err(GraphError::NodeNotFound(to));
264 }
265 if self.is_reachable(to, from) {
267 return Err(GraphError::CyclicEdge { from, to });
268 }
269 self.adj.entry(from).or_default().insert(to);
270 self.radj.entry(to).or_default().insert(from);
271 Ok(())
272 }
273
274 pub fn execution_order(&self) -> Result<ExecutionPlan, GraphError> {
282 let mut in_degree: BTreeMap<u32, usize> = self
283 .nodes
284 .keys()
285 .map(|&id| (id, self.radj[&id].len()))
286 .collect();
287
288 let mut ready: BTreeSet<u32> = in_degree
289 .iter()
290 .filter_map(|(&id, °)| if deg == 0 { Some(id) } else { None })
291 .collect();
292
293 let mut order = Vec::with_capacity(self.nodes.len());
294
295 while let Some(&next) = ready.iter().next() {
296 ready.remove(&next);
297 order.push(next);
298 for &successor in self
299 .adj
300 .get(&next)
301 .map_or(&BTreeSet::new() as &BTreeSet<u32>, |s| s)
302 {
303 let deg = in_degree.entry(successor).or_insert(0);
304 *deg = deg.saturating_sub(1);
305 if *deg == 0 {
306 ready.insert(successor);
307 }
308 }
309 }
310
311 if order.len() != self.nodes.len() {
312 return Err(GraphError::CycleDetected);
313 }
314
315 let mut total_dispatch_groups: u64 = 0;
317 let mut barrier_count = 0usize;
318 let mut copy_count = 0usize;
319
320 for &id in &order {
321 if let Some(node) = self.nodes.get(&id) {
322 match &node.kind {
323 NodeKind::Kernel { dispatch, .. } => {
324 total_dispatch_groups +=
325 dispatch.iter().map(|&d| u64::from(d)).product::<u64>();
326 }
327 NodeKind::Copy { .. } => copy_count += 1,
328 NodeKind::Barrier { .. } => barrier_count += 1,
329 }
330 }
331 }
332
333 Ok(ExecutionPlan {
334 order,
335 total_dispatch_groups,
336 barrier_count,
337 copy_count,
338 })
339 }
340
341 pub fn validate(&self) -> Result<(), GraphError> {
350 for node in self.nodes.values() {
351 match &node.kind {
352 NodeKind::Kernel { .. } | NodeKind::Copy { .. } => {
353 if node.bindings.is_empty() {
354 return Err(GraphError::MissingBinding {
355 node_id: node.id,
356 resource: "<any>".to_string(),
357 });
358 }
359 }
360 NodeKind::Barrier { .. } => {} }
362 }
363 Ok(())
364 }
365
366 #[must_use]
368 pub fn node_count(&self) -> usize {
369 self.nodes.len()
370 }
371
372 #[must_use]
374 pub fn edge_count(&self) -> usize {
375 self.adj.values().map(|s| s.len()).sum()
376 }
377
378 #[must_use]
380 pub fn node(&self, id: u32) -> Option<&GraphNode> {
381 self.nodes.get(&id)
382 }
383
384 pub fn predecessors(&self, node_id: u32) -> Result<Vec<u32>, GraphError> {
391 if !self.nodes.contains_key(&node_id) {
392 return Err(GraphError::NodeNotFound(node_id));
393 }
394 Ok(self
395 .radj
396 .get(&node_id)
397 .map_or(vec![], |s| s.iter().copied().collect()))
398 }
399
400 pub fn successors(&self, node_id: u32) -> Result<Vec<u32>, GraphError> {
407 if !self.nodes.contains_key(&node_id) {
408 return Err(GraphError::NodeNotFound(node_id));
409 }
410 Ok(self
411 .adj
412 .get(&node_id)
413 .map_or(vec![], |s| s.iter().copied().collect()))
414 }
415
416 fn is_reachable(&self, start: u32, target: u32) -> bool {
420 if start == target {
421 return true;
422 }
423 let mut visited = BTreeSet::new();
424 let mut queue = VecDeque::new();
425 queue.push_back(start);
426 while let Some(cur) = queue.pop_front() {
427 if visited.contains(&cur) {
428 continue;
429 }
430 visited.insert(cur);
431 if let Some(succs) = self.adj.get(&cur) {
432 for &s in succs {
433 if s == target {
434 return true;
435 }
436 queue.push_back(s);
437 }
438 }
439 }
440 false
441 }
442}
443
444impl Default for ComputeGraph {
445 fn default() -> Self {
446 Self::new()
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures -------------------------------------------------------

    fn kernel_node(id: u32, dispatch: [u32; 3]) -> GraphNode {
        GraphNode::new(
            id,
            format!("kernel_{id}"),
            NodeKind::Kernel {
                entry_point: format!("main_{id}"),
                dispatch,
            },
        )
    }

    fn copy_node(id: u32, src: u32, dst: u32, bytes: usize) -> GraphNode {
        GraphNode::new(
            id,
            format!("copy_{id}"),
            NodeKind::Copy {
                src_buffer: src,
                dst_buffer: dst,
                byte_count: bytes,
            },
        )
    }

    fn barrier_node(id: u32) -> GraphNode {
        GraphNode::new(
            id,
            format!("barrier_{id}"),
            NodeKind::Barrier {
                src_stage: PipelineStageFlags::COMPUTE_SHADER,
                dst_stage: PipelineStageFlags::TRANSFER,
            },
        )
    }

    fn simple_binding(name: &str, resource_id: u32) -> ResourceBinding {
        ResourceBinding {
            name: name.to_string(),
            resource_id,
            access: ResourceAccess::ReadWrite,
        }
    }

    // ---- PipelineStageFlags ---------------------------------------------

    #[test]
    fn test_pipeline_stage_contains() {
        let combined = PipelineStageFlags::COMPUTE_SHADER.union(PipelineStageFlags::TRANSFER);
        assert!(combined.contains(PipelineStageFlags::COMPUTE_SHADER));
        assert!(combined.contains(PipelineStageFlags::TRANSFER));
        assert!(!combined.contains(PipelineStageFlags::HOST));
    }

    #[test]
    fn test_pipeline_stage_all_contains_any() {
        assert!(PipelineStageFlags::ALL.contains(PipelineStageFlags::COMPUTE_SHADER));
        assert!(PipelineStageFlags::ALL.contains(PipelineStageFlags::HOST));
    }

    // ---- graph construction ---------------------------------------------

    #[test]
    fn test_add_node_and_count() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [4, 1, 1]))?;
        g.add_node(barrier_node(2))?;
        assert_eq!(g.node_count(), 2);
        Ok(())
    }

    #[test]
    fn test_default_is_empty() {
        let g = ComputeGraph::default();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.edge_count(), 0);
    }

    #[test]
    fn test_add_duplicate_node_error() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [1, 1, 1]))?;
        let err = g.add_node(kernel_node(1, [2, 2, 2]));
        assert!(matches!(err, Err(GraphError::DuplicateNode(1))));
        Ok(())
    }

    #[test]
    fn test_add_edge_increments_count() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [1, 1, 1]))?;
        g.add_node(kernel_node(2, [1, 1, 1]))?;
        g.add_edge(1, 2)?;
        assert_eq!(g.edge_count(), 1);
        Ok(())
    }

    #[test]
    fn test_add_duplicate_edge_is_idempotent() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [1, 1, 1]))?;
        g.add_node(kernel_node(2, [1, 1, 1]))?;
        g.add_edge(1, 2)?;
        g.add_edge(1, 2)?;
        assert_eq!(g.edge_count(), 1);
        assert_eq!(g.predecessors(2)?, vec![1]);
        Ok(())
    }

    #[test]
    fn test_add_edge_unknown_node_error() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [1, 1, 1]))?;
        assert!(matches!(
            g.add_edge(1, 99),
            Err(GraphError::NodeNotFound(99))
        ));
        Ok(())
    }

    #[test]
    fn test_add_cyclic_edge_error() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [1, 1, 1]))?;
        g.add_node(kernel_node(2, [1, 1, 1]))?;
        g.add_edge(1, 2)?;
        let err = g.add_edge(2, 1);
        assert!(matches!(
            err,
            Err(GraphError::CyclicEdge { from: 2, to: 1 })
        ));
        Ok(())
    }

    #[test]
    fn test_add_self_loop_error() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [1, 1, 1]))?;
        assert!(matches!(
            g.add_edge(1, 1),
            Err(GraphError::CyclicEdge { from: 1, to: 1 })
        ));
        assert_eq!(g.edge_count(), 0);
        Ok(())
    }

    #[test]
    fn test_rejected_cyclic_edge_leaves_graph_unchanged() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        for id in [1, 2, 3] {
            g.add_node(kernel_node(id, [1, 1, 1]))?;
        }
        g.add_edge(1, 2)?;
        g.add_edge(2, 3)?;
        assert!(g.add_edge(3, 1).is_err());
        assert_eq!(g.edge_count(), 2);
        assert_eq!(g.execution_order()?.order, vec![1, 2, 3]);
        Ok(())
    }

    // ---- scheduling -------------------------------------------------------

    #[test]
    fn test_execution_order_empty_graph() -> Result<(), GraphError> {
        let plan = ComputeGraph::new().execution_order()?;
        assert!(plan.order.is_empty());
        assert_eq!(plan.total_dispatch_groups, 0);
        assert_eq!(plan.barrier_count, 0);
        assert_eq!(plan.copy_count, 0);
        Ok(())
    }

    #[test]
    fn test_execution_order_single_node() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(5, [8, 1, 1]))?;
        let plan = g.execution_order()?;
        assert_eq!(plan.order, vec![5]);
        assert_eq!(plan.total_dispatch_groups, 8);
        Ok(())
    }

    #[test]
    fn test_execution_order_linear_chain() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        for id in [1, 2, 3] {
            g.add_node(kernel_node(id, [2, 1, 1]))?;
        }
        g.add_edge(1, 2)?;
        g.add_edge(2, 3)?;
        let plan = g.execution_order()?;
        assert_eq!(plan.order, vec![1, 2, 3]);
        assert_eq!(plan.total_dispatch_groups, 6);
        Ok(())
    }

    #[test]
    fn test_execution_order_diamond() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        for id in [4, 3, 2, 1] {
            g.add_node(kernel_node(id, [1, 2, 3]))?;
        }
        g.add_edge(1, 2)?;
        g.add_edge(1, 3)?;
        g.add_edge(2, 4)?;
        g.add_edge(3, 4)?;
        let plan = g.execution_order()?;
        assert_eq!(plan.order, vec![1, 2, 3, 4]);
        assert_eq!(plan.total_dispatch_groups, 24);
        assert_eq!(plan.barrier_count, 0);
        assert_eq!(plan.copy_count, 0);
        Ok(())
    }

    #[test]
    fn test_execution_order_with_barrier_and_copy() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [4, 4, 1]))?;
        g.add_node(barrier_node(2))?;
        g.add_node(copy_node(3, 0, 1, 1024))?;
        g.add_edge(1, 2)?;
        g.add_edge(2, 3)?;
        let plan = g.execution_order()?;
        assert_eq!(plan.order, vec![1, 2, 3]);
        assert_eq!(plan.barrier_count, 1);
        assert_eq!(plan.copy_count, 1);
        assert_eq!(plan.total_dispatch_groups, 16);
        Ok(())
    }

    #[test]
    fn test_execution_order_independent_nodes_sorted_by_id() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        for id in [5, 3, 1] {
            g.add_node(kernel_node(id, [1, 1, 1]))?;
        }
        let plan = g.execution_order()?;
        assert_eq!(plan.order, vec![1, 3, 5]);
        Ok(())
    }

    // ---- resource bindings ------------------------------------------------

    #[test]
    fn test_bind_resource_to_kernel() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [1, 1, 1]))?;
        g.bind_resource(1, simple_binding("input", 10))?;
        let node = g.node(1).ok_or(GraphError::NodeNotFound(1))?;
        assert_eq!(node.bindings.len(), 1);
        Ok(())
    }

    #[test]
    fn test_bind_resource_to_barrier_fails() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(barrier_node(1))?;
        let err = g.bind_resource(1, simple_binding("buf", 0));
        assert!(matches!(err, Err(GraphError::IncompatibleBinding { .. })));
        Ok(())
    }

    #[test]
    fn test_bind_resource_unknown_node_fails() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        let err = g.bind_resource(99, simple_binding("buf", 0));
        assert!(matches!(err, Err(GraphError::NodeNotFound(99))));
        Ok(())
    }

    // ---- validation -------------------------------------------------------

    #[test]
    fn test_validate_passes_when_all_bound() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        let mut n = kernel_node(1, [1, 1, 1]);
        n.bindings.push(simple_binding("buf", 0));
        g.add_node(n)?;
        g.add_node(barrier_node(2))?;
        g.add_edge(1, 2)?;
        assert!(g.validate().is_ok());
        Ok(())
    }

    #[test]
    fn test_validate_fails_when_kernel_has_no_bindings() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(kernel_node(1, [1, 1, 1]))?;
        assert!(matches!(
            g.validate(),
            Err(GraphError::MissingBinding { node_id: 1, .. })
        ));
        Ok(())
    }

    #[test]
    fn test_validate_fails_when_copy_has_no_bindings() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        g.add_node(copy_node(7, 0, 1, 64))?;
        assert!(matches!(
            g.validate(),
            Err(GraphError::MissingBinding { node_id: 7, .. })
        ));
        g.bind_resource(7, simple_binding("src", 0))?;
        assert!(g.validate().is_ok());
        Ok(())
    }

    // ---- neighbour queries ------------------------------------------------

    #[test]
    fn test_predecessors_and_successors() -> Result<(), GraphError> {
        let mut g = ComputeGraph::new();
        for id in [1, 2, 3] {
            g.add_node(kernel_node(id, [1, 1, 1]))?;
        }
        g.add_edge(1, 3)?;
        g.add_edge(2, 3)?;
        let mut preds = g.predecessors(3)?;
        preds.sort_unstable();
        assert_eq!(preds, vec![1, 2]);
        let succs_1 = g.successors(1)?;
        assert_eq!(succs_1, vec![3]);
        Ok(())
    }

    #[test]
    fn test_predecessors_unknown_node_error() -> Result<(), GraphError> {
        let g = ComputeGraph::new();
        assert!(matches!(
            g.predecessors(42),
            Err(GraphError::NodeNotFound(42))
        ));
        Ok(())
    }

    #[test]
    fn test_successors_unknown_node_error() {
        let g = ComputeGraph::new();
        assert!(matches!(
            g.successors(42),
            Err(GraphError::NodeNotFound(42))
        ));
    }
}