1use std::collections::{BTreeMap, BTreeSet, VecDeque};
24use thiserror::Error;
25
26#[derive(Debug, Clone, PartialEq, Error)]
30pub enum GraphError {
31 #[error("Node not found: {0}")]
33 NodeNotFound(u32),
34 #[error("Duplicate node ID: {0}")]
36 DuplicateNode(u32),
37 #[error("Edge from node {from} to node {to} would create a cycle")]
39 CyclicEdge { from: u32, to: u32 },
40 #[error("Compute graph contains a cycle; cannot determine execution order")]
42 CycleDetected,
43 #[error("Node {node_id} is missing required resource binding '{resource}'")]
45 MissingBinding { node_id: u32, resource: String },
46 #[error("Resource '{resource}' cannot be bound to a {node_kind:?} node")]
48 IncompatibleBinding {
49 resource: String,
50 node_kind: NodeKind,
51 },
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
58pub enum NodeKind {
59 Kernel {
61 entry_point: String,
63 dispatch: [u32; 3],
65 },
66 Copy {
68 src_buffer: u32,
70 dst_buffer: u32,
72 byte_count: usize,
74 },
75 Barrier {
77 src_stage: PipelineStageFlags,
79 dst_stage: PipelineStageFlags,
81 },
82}
83
/// Bit set of pipeline stages, as carried by [`NodeKind::Barrier`].
///
/// Masks combine with [`PipelineStageFlags::union`] and are tested with
/// [`PipelineStageFlags::contains`]. Both are `const fn`, so combined masks can
/// be declared as `const` items.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PipelineStageFlags(pub u32);

impl PipelineStageFlags {
    /// The empty mask (no stages).
    pub const NONE: Self = Self(0);
    /// Compute-shader stage bit.
    pub const COMPUTE_SHADER: Self = Self(1 << 0);
    /// Transfer stage bit.
    pub const TRANSFER: Self = Self(1 << 1);
    /// Host stage bit.
    pub const HOST: Self = Self(1 << 2);
    /// Every bit set, including bits not assigned to a named stage.
    pub const ALL: Self = Self(0xFFFF_FFFF);

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// Any mask contains [`PipelineStageFlags::NONE`].
    #[must_use]
    pub const fn contains(self, other: Self) -> bool {
        (self.0 & other.0) == other.0
    }

    /// Returns the bitwise union of `self` and `other`.
    #[must_use]
    pub const fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }
}
114
115#[derive(Debug, Clone, PartialEq, Eq)]
119pub struct ResourceBinding {
120 pub name: String,
122 pub resource_id: u32,
124 pub access: ResourceAccess,
126}
127
/// Access mode a node declares for a bound resource.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResourceAccess {
    /// Read access only.
    ReadOnly,
    /// Write access only.
    WriteOnly,
    /// Both read and write access.
    ReadWrite,
}
138
139#[derive(Debug, Clone)]
143pub struct GraphNode {
144 pub id: u32,
146 pub label: String,
148 pub kind: NodeKind,
150 pub bindings: Vec<ResourceBinding>,
152}
153
154impl GraphNode {
155 #[must_use]
157 pub fn new(id: u32, label: impl Into<String>, kind: NodeKind) -> Self {
158 Self {
159 id,
160 label: label.into(),
161 kind,
162 bindings: Vec::new(),
163 }
164 }
165}
166
/// Result of scheduling a [`ComputeGraph`]: the execution order plus per-kind
/// summary counts.
///
/// Derives `PartialEq`/`Eq` so whole plans can be compared. `Default` is the
/// plan for an empty graph.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ExecutionPlan {
    /// Node IDs in the order they should execute.
    pub order: Vec<u32>,
    /// Sum over kernel nodes of the product of their three dispatch dimensions.
    pub total_dispatch_groups: u64,
    /// Number of barrier nodes in the plan.
    pub barrier_count: usize,
    /// Number of copy nodes in the plan.
    pub copy_count: usize,
}
181
182pub struct ComputeGraph {
186 nodes: BTreeMap<u32, GraphNode>,
188 adj: BTreeMap<u32, BTreeSet<u32>>,
190 radj: BTreeMap<u32, BTreeSet<u32>>,
192}
193
194impl ComputeGraph {
195 #[must_use]
197 pub fn new() -> Self {
198 Self {
199 nodes: BTreeMap::new(),
200 adj: BTreeMap::new(),
201 radj: BTreeMap::new(),
202 }
203 }
204
205 pub fn add_node(&mut self, node: GraphNode) -> Result<(), GraphError> {
212 if self.nodes.contains_key(&node.id) {
213 return Err(GraphError::DuplicateNode(node.id));
214 }
215 let id = node.id;
216 self.nodes.insert(id, node);
217 self.adj.entry(id).or_default();
218 self.radj.entry(id).or_default();
219 Ok(())
220 }
221
222 pub fn bind_resource(
230 &mut self,
231 node_id: u32,
232 binding: ResourceBinding,
233 ) -> Result<(), GraphError> {
234 let node = self
235 .nodes
236 .get_mut(&node_id)
237 .ok_or(GraphError::NodeNotFound(node_id))?;
238 if matches!(node.kind, NodeKind::Barrier { .. }) {
240 return Err(GraphError::IncompatibleBinding {
241 resource: binding.name,
242 node_kind: NodeKind::Barrier {
243 src_stage: PipelineStageFlags::NONE,
244 dst_stage: PipelineStageFlags::NONE,
245 },
246 });
247 }
248 node.bindings.push(binding);
249 Ok(())
250 }
251
252 pub fn add_edge(&mut self, from: u32, to: u32) -> Result<(), GraphError> {
259 if !self.nodes.contains_key(&from) {
260 return Err(GraphError::NodeNotFound(from));
261 }
262 if !self.nodes.contains_key(&to) {
263 return Err(GraphError::NodeNotFound(to));
264 }
265 if self.is_reachable(to, from) {
267 return Err(GraphError::CyclicEdge { from, to });
268 }
269 self.adj.entry(from).or_default().insert(to);
270 self.radj.entry(to).or_default().insert(from);
271 Ok(())
272 }
273
274 pub fn execution_order(&self) -> Result<ExecutionPlan, GraphError> {
282 let mut in_degree: BTreeMap<u32, usize> = self
283 .nodes
284 .keys()
285 .map(|&id| (id, self.radj[&id].len()))
286 .collect();
287
288 let mut ready: BTreeSet<u32> = in_degree
289 .iter()
290 .filter_map(|(&id, °)| if deg == 0 { Some(id) } else { None })
291 .collect();
292
293 let mut order = Vec::with_capacity(self.nodes.len());
294
295 while let Some(&next) = ready.iter().next() {
296 ready.remove(&next);
297 order.push(next);
298 for &successor in self
299 .adj
300 .get(&next)
301 .map_or(&BTreeSet::new() as &BTreeSet<u32>, |s| s)
302 {
303 let deg = in_degree.entry(successor).or_insert(0);
304 *deg = deg.saturating_sub(1);
305 if *deg == 0 {
306 ready.insert(successor);
307 }
308 }
309 }
310
311 if order.len() != self.nodes.len() {
312 return Err(GraphError::CycleDetected);
313 }
314
315 let mut total_dispatch_groups: u64 = 0;
317 let mut barrier_count = 0usize;
318 let mut copy_count = 0usize;
319
320 for &id in &order {
321 if let Some(node) = self.nodes.get(&id) {
322 match &node.kind {
323 NodeKind::Kernel { dispatch, .. } => {
324 total_dispatch_groups +=
325 dispatch.iter().map(|&d| u64::from(d)).product::<u64>();
326 }
327 NodeKind::Copy { .. } => copy_count += 1,
328 NodeKind::Barrier { .. } => barrier_count += 1,
329 }
330 }
331 }
332
333 Ok(ExecutionPlan {
334 order,
335 total_dispatch_groups,
336 barrier_count,
337 copy_count,
338 })
339 }
340
341 pub fn validate(&self) -> Result<(), GraphError> {
350 for node in self.nodes.values() {
351 match &node.kind {
352 NodeKind::Kernel { .. } | NodeKind::Copy { .. } => {
353 if node.bindings.is_empty() {
354 return Err(GraphError::MissingBinding {
355 node_id: node.id,
356 resource: "<any>".to_string(),
357 });
358 }
359 }
360 NodeKind::Barrier { .. } => {} }
362 }
363 Ok(())
364 }
365
366 #[must_use]
368 pub fn node_count(&self) -> usize {
369 self.nodes.len()
370 }
371
372 #[must_use]
374 pub fn edge_count(&self) -> usize {
375 self.adj.values().map(|s| s.len()).sum()
376 }
377
378 #[must_use]
380 pub fn node(&self, id: u32) -> Option<&GraphNode> {
381 self.nodes.get(&id)
382 }
383
384 pub fn predecessors(&self, node_id: u32) -> Result<Vec<u32>, GraphError> {
391 if !self.nodes.contains_key(&node_id) {
392 return Err(GraphError::NodeNotFound(node_id));
393 }
394 Ok(self
395 .radj
396 .get(&node_id)
397 .map_or(vec![], |s| s.iter().copied().collect()))
398 }
399
400 pub fn successors(&self, node_id: u32) -> Result<Vec<u32>, GraphError> {
407 if !self.nodes.contains_key(&node_id) {
408 return Err(GraphError::NodeNotFound(node_id));
409 }
410 Ok(self
411 .adj
412 .get(&node_id)
413 .map_or(vec![], |s| s.iter().copied().collect()))
414 }
415
416 fn is_reachable(&self, start: u32, target: u32) -> bool {
420 if start == target {
421 return true;
422 }
423 let mut visited = BTreeSet::new();
424 let mut queue = VecDeque::new();
425 queue.push_back(start);
426 while let Some(cur) = queue.pop_front() {
427 if visited.contains(&cur) {
428 continue;
429 }
430 visited.insert(cur);
431 if let Some(succs) = self.adj.get(&cur) {
432 for &s in succs {
433 if s == target {
434 return true;
435 }
436 queue.push_back(s);
437 }
438 }
439 }
440 false
441 }
442}
443
444impl Default for ComputeGraph {
445 fn default() -> Self {
446 Self::new()
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Kernel node `id` with entry point `main_{id}` and the given dispatch size.
    fn kernel_node(id: u32, dispatch: [u32; 3]) -> GraphNode {
        GraphNode::new(
            id,
            format!("kernel_{id}"),
            NodeKind::Kernel {
                entry_point: format!("main_{id}"),
                dispatch,
            },
        )
    }

    /// Copy node `id` moving `bytes` bytes from buffer `src` to buffer `dst`.
    fn copy_node(id: u32, src: u32, dst: u32, bytes: usize) -> GraphNode {
        GraphNode::new(
            id,
            format!("copy_{id}"),
            NodeKind::Copy {
                src_buffer: src,
                dst_buffer: dst,
                byte_count: bytes,
            },
        )
    }

    /// Barrier node `id` from the compute-shader stage to the transfer stage.
    fn barrier_node(id: u32) -> GraphNode {
        GraphNode::new(
            id,
            format!("barrier_{id}"),
            NodeKind::Barrier {
                src_stage: PipelineStageFlags::COMPUTE_SHADER,
                dst_stage: PipelineStageFlags::TRANSFER,
            },
        )
    }

    /// Read-write binding called `name` for resource `resource_id`.
    fn simple_binding(name: &str, resource_id: u32) -> ResourceBinding {
        ResourceBinding {
            name: name.to_string(),
            resource_id,
            access: ResourceAccess::ReadWrite,
        }
    }

    /// Builds a graph from `nodes` and `edges`; panics on any setup error.
    fn build(nodes: Vec<GraphNode>, edges: &[(u32, u32)]) -> ComputeGraph {
        let mut graph = ComputeGraph::new();
        for node in nodes {
            graph.add_node(node).unwrap();
        }
        for &(from, to) in edges {
            graph.add_edge(from, to).unwrap();
        }
        graph
    }

    // ---- PipelineStageFlags -------------------------------------------------

    #[test]
    fn test_pipeline_stage_contains() {
        let combined = PipelineStageFlags::COMPUTE_SHADER.union(PipelineStageFlags::TRANSFER);
        assert!(combined.contains(PipelineStageFlags::COMPUTE_SHADER));
        assert!(combined.contains(PipelineStageFlags::TRANSFER));
        assert!(!combined.contains(PipelineStageFlags::HOST));
    }

    #[test]
    fn test_pipeline_stage_all_contains_any() {
        assert!(PipelineStageFlags::ALL.contains(PipelineStageFlags::COMPUTE_SHADER));
        assert!(PipelineStageFlags::ALL.contains(PipelineStageFlags::HOST));
    }

    // ---- Construction -------------------------------------------------------

    #[test]
    fn test_add_node_and_count() {
        let graph = build(vec![kernel_node(1, [4, 1, 1]), barrier_node(2)], &[]);
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_default_graph_is_empty() {
        let graph = ComputeGraph::default();
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert!(graph.execution_order().unwrap().order.is_empty());
    }

    #[test]
    fn test_add_duplicate_node_error() {
        let mut graph = build(vec![kernel_node(1, [1, 1, 1])], &[]);
        assert_eq!(
            graph.add_node(kernel_node(1, [2, 2, 2])),
            Err(GraphError::DuplicateNode(1))
        );
        // The original node must survive the rejected insert.
        assert_eq!(graph.node_count(), 1);
        assert!(matches!(
            graph.node(1).map(|n| &n.kind),
            Some(NodeKind::Kernel { dispatch: [1, 1, 1], .. })
        ));
    }

    #[test]
    fn test_add_edge_increments_count() {
        let graph = build(
            vec![kernel_node(1, [1, 1, 1]), kernel_node(2, [1, 1, 1])],
            &[(1, 2)],
        );
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_duplicate_edge_is_noop() {
        let mut graph = build(
            vec![kernel_node(1, [1, 1, 1]), kernel_node(2, [1, 1, 1])],
            &[(1, 2)],
        );
        assert_eq!(graph.add_edge(1, 2), Ok(()));
        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.predecessors(2).unwrap(), vec![1]);
    }

    #[test]
    fn test_add_edge_unknown_node_error() {
        let mut graph = build(vec![kernel_node(1, [1, 1, 1])], &[]);
        assert_eq!(graph.add_edge(1, 99), Err(GraphError::NodeNotFound(99)));
    }

    #[test]
    fn test_add_edge_unknown_source_error() {
        let mut graph = build(vec![kernel_node(1, [1, 1, 1])], &[]);
        assert_eq!(graph.add_edge(99, 1), Err(GraphError::NodeNotFound(99)));
        // When both endpoints are missing, `from` is reported.
        assert_eq!(graph.add_edge(98, 99), Err(GraphError::NodeNotFound(98)));
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_add_cyclic_edge_error() {
        let mut graph = build(
            vec![kernel_node(1, [1, 1, 1]), kernel_node(2, [1, 1, 1])],
            &[(1, 2)],
        );
        assert_eq!(
            graph.add_edge(2, 1),
            Err(GraphError::CyclicEdge { from: 2, to: 1 })
        );
    }

    #[test]
    fn test_add_self_loop_rejected() {
        let mut graph = build(vec![kernel_node(1, [1, 1, 1])], &[]);
        assert_eq!(
            graph.add_edge(1, 1),
            Err(GraphError::CyclicEdge { from: 1, to: 1 })
        );
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_add_transitive_cycle_rejected_and_graph_unchanged() {
        let mut graph = build(
            (1..=3).map(|id| kernel_node(id, [1, 1, 1])).collect(),
            &[(1, 2), (2, 3)],
        );
        assert_eq!(
            graph.add_edge(3, 1),
            Err(GraphError::CyclicEdge { from: 3, to: 1 })
        );
        assert_eq!(graph.edge_count(), 2);
        assert!(graph.successors(3).unwrap().is_empty());
    }

    // ---- Scheduling ---------------------------------------------------------

    #[test]
    fn test_execution_order_empty_graph() {
        let plan = ComputeGraph::new().execution_order().unwrap();
        assert!(plan.order.is_empty());
        assert_eq!(plan.total_dispatch_groups, 0);
        assert_eq!((plan.barrier_count, plan.copy_count), (0, 0));
    }

    #[test]
    fn test_execution_order_single_node() {
        let plan = build(vec![kernel_node(5, [8, 1, 1])], &[])
            .execution_order()
            .unwrap();
        assert_eq!(plan.order, vec![5]);
        assert_eq!(plan.total_dispatch_groups, 8);
    }

    #[test]
    fn test_execution_order_linear_chain() {
        let plan = build(
            [1, 2, 3].into_iter().map(|id| kernel_node(id, [2, 1, 1])).collect(),
            &[(1, 2), (2, 3)],
        )
        .execution_order()
        .unwrap();
        assert_eq!(plan.order, vec![1, 2, 3]);
        assert_eq!(plan.total_dispatch_groups, 6);
    }

    #[test]
    fn test_execution_order_diamond() {
        let plan = build(
            (1..=4).map(|id| kernel_node(id, [1, 1, 1])).collect(),
            &[(1, 2), (1, 3), (2, 4), (3, 4)],
        )
        .execution_order()
        .unwrap();
        assert_eq!(plan.order, vec![1, 2, 3, 4]);
        assert_eq!(plan.total_dispatch_groups, 4);
    }

    #[test]
    fn test_execution_order_ties_break_by_smallest_ready_id() {
        // Node 1 depends on node 3, so 2 and 3 are ready first; 1 runs last.
        let plan = build(
            (1..=3).map(|id| kernel_node(id, [1, 1, 1])).collect(),
            &[(3, 1)],
        )
        .execution_order()
        .unwrap();
        assert_eq!(plan.order, vec![2, 3, 1]);
    }

    #[test]
    fn test_execution_order_with_barrier_and_copy() {
        let plan = build(
            vec![
                kernel_node(1, [4, 4, 1]),
                barrier_node(2),
                copy_node(3, 0, 1, 1024),
            ],
            &[(1, 2), (2, 3)],
        )
        .execution_order()
        .unwrap();
        assert_eq!(plan.order, vec![1, 2, 3]);
        assert_eq!(plan.barrier_count, 1);
        assert_eq!(plan.copy_count, 1);
        assert_eq!(plan.total_dispatch_groups, 16);
    }

    #[test]
    fn test_zero_dispatch_dimension_contributes_nothing() {
        let plan = build(vec![kernel_node(1, [0, 5, 5]), kernel_node(2, [2, 3, 1])], &[])
            .execution_order()
            .unwrap();
        assert_eq!(plan.total_dispatch_groups, 6);
    }

    #[test]
    fn test_execution_order_independent_nodes_sorted_by_id() {
        let plan = build(
            [5, 3, 1].into_iter().map(|id| kernel_node(id, [1, 1, 1])).collect(),
            &[],
        )
        .execution_order()
        .unwrap();
        assert_eq!(plan.order, vec![1, 3, 5]);
    }

    // ---- Bindings -----------------------------------------------------------

    #[test]
    fn test_bind_resource_to_kernel() {
        let mut graph = build(vec![kernel_node(1, [1, 1, 1])], &[]);
        graph.bind_resource(1, simple_binding("input", 10)).unwrap();
        assert_eq!(graph.node(1).unwrap().bindings.len(), 1);
    }

    #[test]
    fn test_bind_resource_to_barrier_fails() {
        let mut graph = build(vec![barrier_node(1)], &[]);
        let err = graph.bind_resource(1, simple_binding("buf", 0));
        assert!(matches!(
            &err,
            Err(GraphError::IncompatibleBinding { resource, .. }) if resource == "buf"
        ));
        assert!(graph.node(1).unwrap().bindings.is_empty());
    }

    #[test]
    fn test_bind_resource_unknown_node_fails() {
        let mut graph = ComputeGraph::new();
        assert_eq!(
            graph.bind_resource(99, simple_binding("buf", 0)),
            Err(GraphError::NodeNotFound(99))
        );
    }

    // ---- Validation ---------------------------------------------------------

    #[test]
    fn test_validate_passes_when_all_bound() {
        let mut kernel = kernel_node(1, [1, 1, 1]);
        kernel.bindings.push(simple_binding("buf", 0));
        let graph = build(vec![kernel, barrier_node(2)], &[(1, 2)]);
        assert_eq!(graph.validate(), Ok(()));
    }

    #[test]
    fn test_validate_fails_when_kernel_has_no_bindings() {
        let graph = build(vec![kernel_node(1, [1, 1, 1])], &[]);
        assert!(matches!(
            graph.validate(),
            Err(GraphError::MissingBinding { node_id: 1, .. })
        ));
    }

    #[test]
    fn test_validate_fails_when_copy_has_no_bindings() {
        let mut kernel = kernel_node(1, [1, 1, 1]);
        kernel.bindings.push(simple_binding("buf", 0));
        let graph = build(vec![kernel, copy_node(2, 0, 1, 64)], &[]);
        assert_eq!(
            graph.validate(),
            Err(GraphError::MissingBinding {
                node_id: 2,
                resource: "<any>".to_string(),
            })
        );
    }

    #[test]
    fn test_validate_ignores_unbound_barrier() {
        let graph = build(vec![barrier_node(1)], &[]);
        assert_eq!(graph.validate(), Ok(()));
    }

    // ---- Neighbourhood queries ----------------------------------------------

    #[test]
    fn test_predecessors_and_successors() {
        let graph = build(
            [1, 2, 3].into_iter().map(|id| kernel_node(id, [1, 1, 1])).collect(),
            &[(1, 3), (2, 3)],
        );
        assert_eq!(graph.predecessors(3).unwrap(), vec![1, 2]);
        assert_eq!(graph.successors(1).unwrap(), vec![3]);
    }

    #[test]
    fn test_predecessors_unknown_node_error() {
        let graph = ComputeGraph::new();
        assert_eq!(graph.predecessors(42), Err(GraphError::NodeNotFound(42)));
    }

    #[test]
    fn test_successors_unknown_node_error() {
        let graph = ComputeGraph::new();
        assert_eq!(graph.successors(7), Err(GraphError::NodeNotFound(7)));
    }
}