ade_topological_sort/
lib.rs

1use ade_common::INVALID_KEY_SEQUENCE;
2use ade_traits::{EdgeTrait, GraphViewTrait, NodeTrait};
3use fixedbitset::FixedBitSet;
4
5pub fn topological_sort<N, E, K, F>(
6    graph: &impl GraphViewTrait<N, E>,
7    key_fn: Option<F>,
8) -> Result<Vec<u32>, String>
9where
10    N: NodeTrait,
11    E: EdgeTrait,
12    K: Ord,
13    F: Fn(&N) -> K,
14{
15    fn dfs<N, E, K, F>(
16        node_key: u32,
17        graph: &impl GraphViewTrait<N, E>,
18        visiting: &mut FixedBitSet,
19        visited: &mut FixedBitSet,
20        result: &mut Vec<u32>,
21        key_fn: &Option<F>,
22    ) -> Result<(), String>
23    where
24        N: NodeTrait,
25        E: EdgeTrait,
26        K: Ord,
27        F: Fn(&N) -> K,
28    {
29        let idx = node_key as usize;
30
31        if visiting[idx] {
32            return Err("Graph contains a cycle".into());
33        }
34        if visited[idx] {
35            return Ok(());
36        }
37
38        visiting.set(idx, true);
39
40        match key_fn {
41            Some(f) => {
42                let mut successors = graph.get_successors(node_key).collect::<Vec<_>>();
43                successors.sort_by_key(|n| std::cmp::Reverse(f(n)));
44                for successor in successors {
45                    dfs(successor.key(), graph, visiting, visited, result, key_fn)?;
46                }
47            }
48            None => {
49                for successor_key in graph.get_successors_keys(node_key) {
50                    dfs(successor_key, graph, visiting, visited, result, key_fn)?;
51                }
52            }
53        }
54
55        visiting.set(idx, false);
56        visited.set(idx, true);
57        result.push(node_key);
58        Ok(())
59    }
60
61    // Panic if the graph does not have sequential keys
62    if !graph.has_sequential_keys() {
63        panic!("{}", INVALID_KEY_SEQUENCE);
64    }
65
66    let node_count = graph.get_node_keys().count();
67    let mut visiting = FixedBitSet::with_capacity(node_count);
68    let mut visited = FixedBitSet::with_capacity(node_count);
69    let mut result = Vec::new();
70
71    match &key_fn {
72        Some(f) => {
73            let mut nodes: Vec<_> = graph.get_nodes().collect();
74            nodes.sort_by_key(|n| std::cmp::Reverse(f(n)));
75            for node in nodes {
76                let idx = node.key() as usize;
77                if !visited[idx] {
78                    dfs(
79                        node.key(),
80                        graph,
81                        &mut visiting,
82                        &mut visited,
83                        &mut result,
84                        &key_fn,
85                    )?;
86                }
87            }
88        }
89        None => {
90            for node_key in graph.get_node_keys() {
91                let idx = node_key as usize;
92                if !visited[idx] {
93                    dfs(
94                        node_key,
95                        graph,
96                        &mut visiting,
97                        &mut visited,
98                        &mut result,
99                        &key_fn,
100                    )?;
101                }
102            }
103        }
104    }
105
106    result.reverse();
107    Ok(result)
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use ade_graph::implementations::Edge;
    use ade_graph::implementations::Graph;
    use ade_graph::implementations::Node;
    use ade_graph::utils::build::build_graph;
    use ade_graph_generators::generate_random_graph_data;

    /// Key-function type to name in the turbofish when no key function is passed.
    type NoKey = fn(&Node) -> u32;

    #[test]
    fn test_topological_sort() {
        let n1 = Node::new(0);
        let n2 = Node::new(1);
        let n3 = Node::new(2);

        let e1 = Edge::new(0, 1);
        let e2 = Edge::new(1, 2);

        let graph = Graph::<Node, Edge>::new(vec![n1, n2, n3], vec![e1, e2]);
        let sorted = topological_sort::<Node, Edge, u32, NoKey>(&graph, None).unwrap();

        assert_eq!(sorted, vec![0, 1, 2]);
    }

    #[test]
    fn test_topological_sort_non_sequential_keys() {
        use ade_common::assert_panics_with;

        let graph = build_graph(vec![1, 3, 5], vec![(1, 3), (3, 5), (5, 1)]);
        assert_panics_with!(
            topological_sort::<Node, Edge, u32, NoKey>(&graph, None),
            ade_common::INVALID_KEY_SEQUENCE
        );
    }

    #[test]
    fn test_topological_sort_cycle() {
        let n1 = Node::new(0);
        let n2 = Node::new(1);

        let e1 = Edge::new(0, 1);
        let e2 = Edge::new(1, 0); // creates a cycle

        let graph = Graph::<Node, Edge>::new(vec![n1, n2], vec![e1, e2]);

        let result = topological_sort::<Node, Edge, u32, NoKey>(&graph, None);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Graph contains a cycle");
    }

    #[test]
    fn test_topological_sort_with_compare_by_key_1() {
        let graph1 = build_graph(vec![0, 1, 2], vec![(0, 1), (0, 2)]);

        assert_eq!(
            topological_sort::<Node, Edge, u32, _>(&graph1, Some(|n: &Node| n.key() as u32))
                .unwrap(),
            vec![0, 1, 2]
        );

        assert_eq!(
            topological_sort::<Node, Edge, i32, _>(&graph1, Some(|n: &Node| -(n.key() as i32)))
                .unwrap(),
            vec![0, 2, 1]
        );

        let graph2 = build_graph(vec![0, 1, 2], vec![(0, 2), (1, 2)]);

        assert_eq!(
            topological_sort::<Node, Edge, u32, _>(&graph2, Some(|n: &Node| n.key() as u32))
                .unwrap(),
            vec![0, 1, 2]
        );

        assert_eq!(
            topological_sort::<Node, Edge, i32, _>(&graph2, Some(|n: &Node| -(n.key() as i32)))
                .unwrap(),
            vec![1, 0, 2]
        );

        let graph3 = build_graph(vec![0, 1, 2, 3, 4], vec![(0, 1), (0, 4), (2, 4), (2, 3)]);

        assert_eq!(
            topological_sort::<Node, Edge, u32, _>(&graph3, Some(|n: &Node| n.key() as u32))
                .unwrap(),
            vec![0, 1, 2, 3, 4]
        );

        assert_eq!(
            topological_sort::<Node, Edge, i32, _>(&graph3, Some(|n: &Node| -(n.key() as i32)))
                .unwrap(),
            vec![2, 3, 0, 4, 1]
        );
    }

    #[test]
    fn test_topological_sort_with_compare_by_layer() {
        // Ported from an earlier comparator-based test on hierarchy nodes:
        // node 0 is on layer 3, node 1 on layer 2 and node 2 on layer 1, with
        // edges 0 -> 1 and 0 -> 2. Lower layers come first where edges allow.
        let layers = [3u32, 2, 1];
        let graph = build_graph(vec![0, 1, 2], vec![(0, 1), (0, 2)]);

        let sorted = topological_sort::<Node, Edge, u32, _>(
            &graph,
            Some(|n: &Node| layers[n.key() as usize]),
        )
        .unwrap();

        assert_eq!(sorted, vec![0, 2, 1]);
    }

    #[test]
    fn test_topological_sort_random_graph() {
        use std::collections::HashMap;

        let (nodes, edges) = generate_random_graph_data(20, 20, 3);
        let graph = build_graph(nodes.clone(), edges.clone());
        let sorted = topological_sort::<Node, Edge, u32, NoKey>(&graph, None)
            .expect("random graph should be acyclic");

        // Every node appears exactly once...
        assert_eq!(sorted.len(), nodes.len());
        let position: HashMap<u32, usize> = sorted
            .iter()
            .enumerate()
            .map(|(index, &key)| (key, index))
            .collect();
        assert_eq!(position.len(), sorted.len());

        // ...and every edge points forward in the ordering.
        for (from, to) in &edges {
            assert!(
                position[from] < position[to],
                "edge {} -> {} is out of order",
                from,
                to
            );
        }
    }
}