1pub use fast_topo::CycleError;
12pub use fast_topo::FastTopoSorter;
13
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// Identifier of a node, wrapping a plain `usize` index.
///
/// Equality, ordering and hashing are all derived from the wrapped index, so
/// ids sort in ascending index order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TopoNodeId(pub usize);
22
23impl std::fmt::Display for TopoNodeId {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 write!(f, "Node({})", self.0)
26 }
27}
28
29#[derive(Debug, Clone, PartialEq, Eq)]
31pub enum TopoError {
32 CycleDetected(
34 Vec<TopoNodeId>,
36 ),
37 NodeNotFound(
39 TopoNodeId,
41 ),
42 EmptyGraph,
44}
45
46impl std::fmt::Display for TopoError {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 match self {
49 Self::CycleDetected(nodes) => {
50 write!(f, "Cycle detected involving {} nodes", nodes.len())
51 }
52 Self::NodeNotFound(id) => write!(f, "Node {id} not found"),
53 Self::EmptyGraph => write!(f, "Graph is empty"),
54 }
55 }
56}
57
58pub struct TopoGraph {
60 adjacency: HashMap<TopoNodeId, HashSet<TopoNodeId>>,
62 reverse: HashMap<TopoNodeId, HashSet<TopoNodeId>>,
64}
65
66impl TopoGraph {
67 pub fn new() -> Self {
69 Self {
70 adjacency: HashMap::new(),
71 reverse: HashMap::new(),
72 }
73 }
74
75 pub fn add_node(&mut self, id: TopoNodeId) {
77 self.adjacency.entry(id).or_default();
78 self.reverse.entry(id).or_default();
79 }
80
81 pub fn add_edge(&mut self, from: TopoNodeId, to: TopoNodeId) {
83 self.add_node(from);
84 self.add_node(to);
85 self.adjacency.entry(from).or_default().insert(to);
86 self.reverse.entry(to).or_default().insert(from);
87 }
88
89 pub fn node_count(&self) -> usize {
91 self.adjacency.len()
92 }
93
94 pub fn edge_count(&self) -> usize {
96 self.adjacency.values().map(|s| s.len()).sum()
97 }
98
99 pub fn in_degree(&self, id: TopoNodeId) -> usize {
101 self.reverse.get(&id).map_or(0, |s| s.len())
102 }
103
104 pub fn out_degree(&self, id: TopoNodeId) -> usize {
106 self.adjacency.get(&id).map_or(0, |s| s.len())
107 }
108
109 pub fn sources(&self) -> Vec<TopoNodeId> {
111 let mut sources: Vec<TopoNodeId> = self
112 .adjacency
113 .keys()
114 .filter(|id| self.in_degree(**id) == 0)
115 .copied()
116 .collect();
117 sources.sort();
118 sources
119 }
120
121 pub fn sinks(&self) -> Vec<TopoNodeId> {
123 let mut sinks: Vec<TopoNodeId> = self
124 .adjacency
125 .keys()
126 .filter(|id| self.out_degree(**id) == 0)
127 .copied()
128 .collect();
129 sinks.sort();
130 sinks
131 }
132
133 pub fn sort_kahn(&self) -> Result<Vec<TopoNodeId>, TopoError> {
137 if self.adjacency.is_empty() {
138 return Err(TopoError::EmptyGraph);
139 }
140
141 let mut in_degrees: HashMap<TopoNodeId, usize> = HashMap::new();
142 for &node in self.adjacency.keys() {
143 in_degrees.insert(node, self.in_degree(node));
144 }
145
146 let mut queue: VecDeque<TopoNodeId> = in_degrees
147 .iter()
148 .filter(|(_, °)| deg == 0)
149 .map(|(&id, _)| id)
150 .collect();
151
152 let mut sorted_start: Vec<TopoNodeId> = queue.drain(..).collect();
154 sorted_start.sort();
155 queue.extend(sorted_start);
156
157 let mut result = Vec::with_capacity(self.adjacency.len());
158
159 while let Some(node) = queue.pop_front() {
160 result.push(node);
161 if let Some(successors) = self.adjacency.get(&node) {
162 let mut sorted_succ: Vec<TopoNodeId> = successors.iter().copied().collect();
163 sorted_succ.sort();
164 for succ in sorted_succ {
165 if let Some(deg) = in_degrees.get_mut(&succ) {
166 *deg -= 1;
167 if *deg == 0 {
168 queue.push_back(succ);
169 }
170 }
171 }
172 }
173 }
174
175 if result.len() != self.adjacency.len() {
176 let remaining: Vec<TopoNodeId> = self
177 .adjacency
178 .keys()
179 .filter(|id| !result.contains(id))
180 .copied()
181 .collect();
182 return Err(TopoError::CycleDetected(remaining));
183 }
184
185 Ok(result)
186 }
187
188 pub fn sort_dfs(&self) -> Result<Vec<TopoNodeId>, TopoError> {
192 if self.adjacency.is_empty() {
193 return Err(TopoError::EmptyGraph);
194 }
195
196 let mut visited: HashSet<TopoNodeId> = HashSet::new();
197 let mut in_stack: HashSet<TopoNodeId> = HashSet::new();
198 let mut result: Vec<TopoNodeId> = Vec::new();
199
200 let mut nodes: Vec<TopoNodeId> = self.adjacency.keys().copied().collect();
201 nodes.sort();
202
203 for node in &nodes {
204 if !visited.contains(node)
205 && !Self::dfs_visit(
206 *node,
207 &self.adjacency,
208 &mut visited,
209 &mut in_stack,
210 &mut result,
211 )
212 {
213 let cycle_nodes: Vec<TopoNodeId> = in_stack.into_iter().collect();
214 return Err(TopoError::CycleDetected(cycle_nodes));
215 }
216 }
217
218 result.reverse();
219 Ok(result)
220 }
221
222 fn dfs_visit(
224 node: TopoNodeId,
225 adjacency: &HashMap<TopoNodeId, HashSet<TopoNodeId>>,
226 visited: &mut HashSet<TopoNodeId>,
227 in_stack: &mut HashSet<TopoNodeId>,
228 result: &mut Vec<TopoNodeId>,
229 ) -> bool {
230 visited.insert(node);
231 in_stack.insert(node);
232
233 if let Some(successors) = adjacency.get(&node) {
234 let mut sorted_succ: Vec<TopoNodeId> = successors.iter().copied().collect();
235 sorted_succ.sort();
236 for succ in sorted_succ {
237 if in_stack.contains(&succ) {
238 return false;
239 }
240 if !visited.contains(&succ)
241 && !Self::dfs_visit(succ, adjacency, visited, in_stack, result)
242 {
243 return false;
244 }
245 }
246 }
247
248 in_stack.remove(&node);
249 result.push(node);
250 true
251 }
252
253 pub fn is_dag(&self) -> bool {
255 self.sort_kahn().is_ok()
256 }
257
258 pub fn longest_path(&self) -> Result<usize, TopoError> {
260 let order = self.sort_kahn()?;
261 let mut dist: HashMap<TopoNodeId, usize> = HashMap::new();
262 for &node in &order {
263 dist.insert(node, 0);
264 }
265
266 for &node in &order {
267 let node_dist = dist[&node];
268 if let Some(successors) = self.adjacency.get(&node) {
269 for &succ in successors {
270 let entry = dist.entry(succ).or_insert(0);
271 if node_dist + 1 > *entry {
272 *entry = node_dist + 1;
273 }
274 }
275 }
276 }
277
278 Ok(dist.values().copied().max().unwrap_or(0))
279 }
280
281 pub fn node_depths(&self) -> Result<HashMap<TopoNodeId, usize>, TopoError> {
283 let order = self.sort_kahn()?;
284 let mut depths: HashMap<TopoNodeId, usize> = HashMap::new();
285 for &node in &order {
286 depths.insert(node, 0);
287 }
288
289 for &node in &order {
290 let node_depth = depths[&node];
291 if let Some(successors) = self.adjacency.get(&node) {
292 for &succ in successors {
293 let entry = depths.entry(succ).or_insert(0);
294 if node_depth + 1 > *entry {
295 *entry = node_depth + 1;
296 }
297 }
298 }
299 }
300
301 Ok(depths)
302 }
303
304 pub fn can_reach(&self, a: TopoNodeId, b: TopoNodeId) -> bool {
306 let mut visited: HashSet<TopoNodeId> = HashSet::new();
307 let mut queue = VecDeque::new();
308 queue.push_back(a);
309
310 while let Some(current) = queue.pop_front() {
311 if current == b {
312 return true;
313 }
314 if visited.insert(current) {
315 if let Some(successors) = self.adjacency.get(¤t) {
316 for &succ in successors {
317 queue.push_back(succ);
318 }
319 }
320 }
321 }
322
323 false
324 }
325}
326
327impl Default for TopoGraph {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a node id.
    fn n(id: usize) -> TopoNodeId {
        TopoNodeId(id)
    }

    /// Builds a graph by adding the given `(from, to)` edges in order.
    fn graph_of(edges: &[(usize, usize)]) -> TopoGraph {
        let mut graph = TopoGraph::new();
        for &(from, to) in edges {
            graph.add_edge(n(from), n(to));
        }
        graph
    }

    /// The four-node diamond `0 -> {1, 2} -> 3`.
    fn diamond() -> TopoGraph {
        graph_of(&[(0, 1), (0, 2), (1, 3), (2, 3)])
    }

    /// The three-node cycle `0 -> 1 -> 2 -> 0`.
    fn triangle_cycle() -> TopoGraph {
        graph_of(&[(0, 1), (1, 2), (2, 0)])
    }

    #[test]
    fn test_empty_graph() {
        let graph = TopoGraph::new();
        assert_eq!(graph.node_count(), 0);
        assert!(matches!(graph.sort_kahn(), Err(TopoError::EmptyGraph)));
    }

    #[test]
    fn test_single_node() {
        let mut graph = TopoGraph::new();
        graph.add_node(n(0));
        let order = graph.sort_kahn().expect("sort_kahn should succeed");
        assert_eq!(order, vec![n(0)]);
    }

    #[test]
    fn test_linear_chain() {
        let graph = graph_of(&[(0, 1), (1, 2), (2, 3)]);
        let order = graph.sort_kahn().expect("sort_kahn should succeed");
        assert_eq!(order, vec![n(0), n(1), n(2), n(3)]);
    }

    #[test]
    fn test_diamond_graph() {
        let order = diamond().sort_kahn().expect("sort_kahn should succeed");
        assert_eq!(order[0], n(0));
        assert_eq!(order[3], n(3));
    }

    #[test]
    fn test_cycle_detection_kahn() {
        let result = triangle_cycle().sort_kahn();
        assert!(matches!(result, Err(TopoError::CycleDetected(_))));
    }

    #[test]
    fn test_cycle_detection_dfs() {
        let result = triangle_cycle().sort_dfs();
        assert!(matches!(result, Err(TopoError::CycleDetected(_))));
    }

    #[test]
    fn test_dfs_sort_matches_kahn() {
        let graph = diamond();
        let kahn = graph.sort_kahn().expect("sort_kahn should succeed");
        let dfs = graph.sort_dfs().expect("sort_dfs should succeed");
        // Both algorithms must agree on the unique source and unique sink.
        assert_eq!(kahn[0], n(0));
        assert_eq!(dfs[0], n(0));
        assert_eq!(*kahn.last().expect("last should succeed"), n(3));
        assert_eq!(*dfs.last().expect("last should succeed"), n(3));
    }

    #[test]
    fn test_sources_and_sinks() {
        let graph = graph_of(&[(0, 2), (1, 2), (2, 3), (2, 4)]);
        assert_eq!(graph.sources(), vec![n(0), n(1)]);
        assert_eq!(graph.sinks(), vec![n(3), n(4)]);
    }

    #[test]
    fn test_in_out_degree() {
        let graph = graph_of(&[(0, 1), (0, 2), (1, 2)]);
        assert_eq!(graph.out_degree(n(0)), 2);
        assert_eq!(graph.in_degree(n(2)), 2);
        assert_eq!(graph.in_degree(n(0)), 0);
    }

    #[test]
    fn test_is_dag() {
        let mut graph = graph_of(&[(0, 1), (1, 2)]);
        assert!(graph.is_dag());

        // Closing the chain into a loop must flip the answer.
        graph.add_edge(n(2), n(0));
        assert!(!graph.is_dag());
    }

    #[test]
    fn test_longest_path() {
        let graph = graph_of(&[(0, 1), (1, 2), (0, 2)]);
        let longest = graph.longest_path().expect("longest_path should succeed");
        assert_eq!(longest, 2);
    }

    #[test]
    fn test_node_depths() {
        let depths = diamond()
            .node_depths()
            .expect("node_depths should succeed");
        assert_eq!(depths[&n(0)], 0);
        assert_eq!(depths[&n(3)], 2);
    }

    #[test]
    fn test_can_reach() {
        let graph = graph_of(&[(0, 1), (1, 2)]);
        assert!(graph.can_reach(n(0), n(2)));
        assert!(!graph.can_reach(n(2), n(0)));
    }

    #[test]
    fn test_topo_error_display() {
        let empty = TopoError::EmptyGraph;
        assert_eq!(empty.to_string(), "Graph is empty");
        let missing = TopoError::NodeNotFound(n(5));
        assert!(missing.to_string().contains("5"));
    }

    #[test]
    fn test_edge_count() {
        let graph = graph_of(&[(0, 1), (1, 2), (0, 2)]);
        assert_eq!(graph.edge_count(), 3);
    }

    #[test]
    fn test_node_id_display() {
        let id = TopoNodeId(42);
        assert_eq!(id.to_string(), "Node(42)");
    }
}
503
/// Index-based topological sorter for graphs whose nodes are the integers
/// `0..n`, stored in plain vectors instead of hash maps.
pub mod fast_topo {
    use std::collections::VecDeque;

    /// Error returned by [`FastTopoSorter::sort`] when the graph has a cycle.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct CycleError {
        /// Nodes whose in-degree never dropped to zero, in ascending order.
        pub remaining: Vec<usize>,
    }

    impl std::fmt::Display for CycleError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "cycle detected; {} nodes could not be ordered",
                self.remaining.len()
            )
        }
    }

    impl std::error::Error for CycleError {}

    /// Kahn-style topological sorter over the fixed node set `0..n`.
    ///
    /// Edges are appended with [`FastTopoSorter::add_edge`]; duplicate edges
    /// are stored and counted individually.
    #[derive(Debug, Clone)]
    pub struct FastTopoSorter {
        /// Number of nodes; valid ids are `0..n`.
        n: usize,
        /// `adjacency[i]` lists the successors of node `i` in insertion order.
        adjacency: Vec<Vec<usize>>,
        /// `in_degree[i]` counts the edges ending at node `i` (saturating).
        in_degree: Vec<u32>,
    }

    impl FastTopoSorter {
        /// Creates a sorter for `n` nodes and no edges.
        #[must_use]
        pub fn new(n: usize) -> Self {
            Self {
                n,
                adjacency: vec![Vec::new(); n],
                in_degree: vec![0u32; n],
            }
        }

        /// Adds the directed edge `from -> to`.
        ///
        /// # Panics
        ///
        /// Panics when `from` or `to` is not smaller than the node count.
        pub fn add_edge(&mut self, from: usize, to: usize) {
            assert!(from < self.n, "node id {from} out of range (n={})", self.n);
            assert!(to < self.n, "node id {to} out of range (n={})", self.n);
            self.adjacency[from].push(to);
            self.in_degree[to] = self.in_degree[to].saturating_add(1);
        }

        /// Returns the nodes in topological order.
        ///
        /// Nodes with in-degree zero are queued in ascending index order and
        /// successors are released in edge-insertion order, so the result is
        /// deterministic.
        ///
        /// # Errors
        ///
        /// Returns [`CycleError`] listing every node that could not be
        /// ordered when the graph contains a cycle.
        pub fn sort(&self) -> Result<Vec<usize>, CycleError> {
            // Work on a copy so the sorter can be sorted again later.
            let mut deg = self.in_degree.clone();

            let mut queue: VecDeque<usize> = deg
                .iter()
                .enumerate()
                .filter(|&(_, &d)| d == 0)
                .map(|(i, _)| i)
                .collect();

            let mut result = Vec::with_capacity(self.n);

            while let Some(node) = queue.pop_front() {
                result.push(node);
                for &succ in &self.adjacency[node] {
                    deg[succ] = deg[succ].saturating_sub(1);
                    if deg[succ] == 0 {
                        queue.push_back(succ);
                    }
                }
            }

            if result.len() != self.n {
                let remaining: Vec<usize> = deg
                    .iter()
                    .enumerate()
                    .filter(|&(_, &d)| d > 0)
                    .map(|(i, _)| i)
                    .collect();
                return Err(CycleError { remaining });
            }

            Ok(result)
        }

        /// Number of nodes the sorter was created with.
        #[must_use]
        pub fn node_count(&self) -> usize {
            self.n
        }

        /// Number of stored edges, counting duplicates.
        #[must_use]
        pub fn edge_count(&self) -> usize {
            self.adjacency.iter().map(Vec::len).sum()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_fast_topo_simple() {
            let mut s = FastTopoSorter::new(5);
            s.add_edge(0, 1);
            s.add_edge(0, 2);
            s.add_edge(1, 3);
            s.add_edge(2, 3);
            s.add_edge(3, 4);

            let order = s.sort().expect("DAG must succeed");
            assert_eq!(order.len(), 5, "all 5 nodes must appear");
            assert_eq!(order[0], 0, "node 0 has no predecessors");
            assert_eq!(
                *order.last().expect("non-empty"),
                4,
                "node 4 is the only sink"
            );

            // pos[node] = rank of the node in the produced order.
            let mut pos = vec![0usize; 5];
            for (rank, &node) in order.iter().enumerate() {
                pos[node] = rank;
            }
            assert!(pos[0] < pos[1]);
            assert!(pos[0] < pos[2]);
            assert!(pos[1] < pos[3]);
            assert!(pos[2] < pos[3]);
            assert!(pos[3] < pos[4]);
        }

        #[test]
        fn test_fast_topo_cycle_detected() {
            let mut s = FastTopoSorter::new(3);
            s.add_edge(0, 1);
            s.add_edge(1, 2);
            s.add_edge(2, 0);

            let result = s.sort();
            assert!(result.is_err(), "cycle must produce an error");
            let err = result.expect_err("expected CycleError");
            assert_eq!(
                err.remaining.len(),
                3,
                "all three nodes are stuck in the cycle"
            );
        }

        #[test]
        fn test_fast_topo_large() {
            let n = 10_000usize;
            let mut s = FastTopoSorter::new(n);
            for i in 0..n - 1 {
                s.add_edge(i, i + 1);
            }

            let start = std::time::Instant::now();
            let order = s.sort().expect("linear chain must sort cleanly");
            let elapsed = start.elapsed();

            assert_eq!(order.len(), n, "all {n} nodes must appear");
            // NOTE(review): wall-clock bound; may be flaky on a heavily
            // loaded machine. Confirm it is an intended performance guard.
            assert!(
                elapsed.as_millis() < 50,
                "sort must complete in < 50 ms, took {} ms",
                elapsed.as_millis()
            );

            for (rank, &node) in order.iter().enumerate() {
                assert_eq!(node, rank, "linear chain must be in strict ascending order");
            }
        }
    }
}