cot/utils/graph.rs

1use thiserror::Error;
2
3#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Error)]
4#[error("cycle detected in the graph")]
5pub struct CycleDetected;
6
/// Reorders `items` in place according to `order`.
///
/// After the call, `items[i]` holds the element that was at `items[order[i]]` before
/// the call (for every `i < order.len()`). The permutation is applied cycle by cycle
/// using swaps, so no extra memory is allocated. `order` is used as scratch space and
/// is left as the identity permutation (`order[i] == i` for every `i`).
///
/// # Panics
///
/// Panics if `order` is not a permutation of `0..order.len()`, or if applying it
/// requires swapping an index that is out of bounds for `items`. `items` may be left
/// partially reordered when this happens.
#[doc(hidden)] // not part of public API, used in the Cot CLI
pub fn apply_permutation<T>(items: &mut [T], order: &mut [usize]) {
    for i in 0..order.len() {
        let mut current = i;
        let mut next = order[current];

        // Walk the cycle containing `i`, swapping each element into place. Positions
        // that are done are marked with `order[x] = x`, so later iterations skip them.
        while next != i {
            // For a valid permutation the walk only visits unmarked positions before
            // returning to `i`. Landing on a marked position means `order` contains a
            // duplicate entry; without this check the loop would never terminate.
            assert_ne!(
                next, current,
                "`order` must be a permutation of `0..order.len()`"
            );

            items.swap(current, next);
            order[current] = current;

            current = next;
            next = order[current];
        }

        order[current] = current;
    }
}
25
/// A directed graph over the vertices `0..vertex_num`, stored as adjacency lists.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct Graph {
    /// `vertex_edges[v]` holds the targets of the edges leaving vertex `v`,
    /// in the order the edges were added.
    vertex_edges: Vec<Vec<usize>>,
}
30
31impl Graph {
32    #[must_use]
33    pub(crate) fn new(vertex_num: usize) -> Self {
34        Self {
35            vertex_edges: vec![Vec::new(); vertex_num],
36        }
37    }
38
39    pub(crate) fn add_edge(&mut self, from: usize, to: usize) {
40        self.vertex_edges[from].push(to);
41    }
42
43    #[must_use]
44    pub(crate) fn vertex_num(&self) -> usize {
45        self.vertex_edges.len()
46    }
47
48    pub(crate) fn toposort(&mut self) -> Result<Vec<usize>, CycleDetected> {
49        let mut visited = vec![VisitedStatus::NotVisited; self.vertex_num()];
50        let mut sorted_indices_stack = Vec::with_capacity(self.vertex_num());
51
52        for index in (0..self.vertex_num()).rev() {
53            self.toposort_visit(index, &mut visited, &mut sorted_indices_stack)?;
54        }
55
56        assert_eq!(sorted_indices_stack.len(), self.vertex_num());
57
58        sorted_indices_stack.reverse();
59        Ok(sorted_indices_stack)
60    }
61
62    #[expect(clippy::indexing_slicing)]
63    fn toposort_visit(
64        &self,
65        index: usize,
66        visited: &mut Vec<VisitedStatus>,
67        sorted_indices_stack: &mut Vec<usize>,
68    ) -> Result<(), CycleDetected> {
69        match visited[index] {
70            VisitedStatus::Visited => return Ok(()),
71            VisitedStatus::Visiting => {
72                return Err(CycleDetected);
73            }
74            VisitedStatus::NotVisited => {}
75        }
76
77        visited[index] = VisitedStatus::Visiting;
78
79        for &neighbor in &self.vertex_edges[index] {
80            self.toposort_visit(neighbor, visited, sorted_indices_stack)?;
81        }
82
83        visited[index] = VisitedStatus::Visited;
84        sorted_indices_stack.push(index);
85
86        Ok(())
87    }
88}
89
/// Per-vertex depth-first-search state used by the topological sort.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum VisitedStatus {
    /// The vertex has not been reached yet.
    NotVisited,
    /// The vertex is on the current DFS path; reaching it again means a cycle exists.
    Visiting,
    /// The vertex and all of its successors have already been emitted.
    Visited,
}
96
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn graph_toposort_stable() {
        let mut graph = Graph::new(8);
        let sorted_indices = graph.toposort().unwrap();
        assert_eq!(sorted_indices, vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn graph_toposort() {
        let mut graph = Graph::new(8);
        graph.add_edge(5, 3);
        graph.add_edge(1, 3);
        graph.add_edge(1, 2);
        graph.add_edge(4, 2);
        graph.add_edge(4, 6);
        graph.add_edge(3, 0);
        graph.add_edge(3, 7);
        graph.add_edge(3, 6);
        graph.add_edge(2, 7);

        let sorted_indices = graph.toposort().unwrap();

        assert_eq!(sorted_indices, vec![1, 4, 2, 5, 3, 0, 6, 7]);
    }

    #[test]
    fn graph_toposort_with_cycle() {
        let mut graph = Graph::new(4);
        graph.add_edge(0, 1);
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(3, 0);

        assert!(matches!(graph.toposort(), Err(CycleDetected)));
    }

    #[test]
    fn graph_toposort_empty() {
        let mut graph = Graph::new(0);
        assert_eq!(graph.toposort().unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn graph_toposort_with_self_loop() {
        let mut graph = Graph::new(2);
        graph.add_edge(1, 1);

        assert!(matches!(graph.toposort(), Err(CycleDetected)));
    }

    #[test]
    fn apply_permutation_reorders_items_and_resets_order() {
        let mut items = vec!['a', 'b', 'c', 'd'];
        let mut order = vec![2, 0, 3, 1];

        apply_permutation(&mut items, &mut order);

        assert_eq!(items, vec!['c', 'a', 'd', 'b']);
        assert_eq!(order, vec![0, 1, 2, 3]);
    }

    #[test]
    fn apply_permutation_identity_and_empty() {
        let mut items = vec![10, 20, 30];
        let mut order = vec![0, 1, 2];
        apply_permutation(&mut items, &mut order);
        assert_eq!(items, vec![10, 20, 30]);

        let mut empty: Vec<u8> = Vec::new();
        let mut no_order: Vec<usize> = Vec::new();
        apply_permutation(&mut empty, &mut no_order);
        assert!(empty.is_empty());
    }

    #[test]
    fn apply_permutation_with_toposort_order() {
        let mut graph = Graph::new(3);
        graph.add_edge(2, 0);
        let mut order = graph.toposort().unwrap();
        assert_eq!(order, vec![1, 2, 0]);

        let mut items = vec!["zero", "one", "two"];
        apply_permutation(&mut items, &mut order);

        assert_eq!(items, vec!["one", "two", "zero"]);
    }
}