1use ade_common::INVALID_KEY_SEQUENCE;
2use ade_traits::{EdgeTrait, GraphViewTrait, NodeTrait};
3use fixedbitset::FixedBitSet;
4
5pub fn topological_sort<N, E, K, F>(
6 graph: &impl GraphViewTrait<N, E>,
7 key_fn: Option<F>,
8) -> Result<Vec<u32>, String>
9where
10 N: NodeTrait,
11 E: EdgeTrait,
12 K: Ord,
13 F: Fn(&N) -> K,
14{
15 fn dfs<N, E, K, F>(
16 node_key: u32,
17 graph: &impl GraphViewTrait<N, E>,
18 visiting: &mut FixedBitSet,
19 visited: &mut FixedBitSet,
20 result: &mut Vec<u32>,
21 key_fn: &Option<F>,
22 ) -> Result<(), String>
23 where
24 N: NodeTrait,
25 E: EdgeTrait,
26 K: Ord,
27 F: Fn(&N) -> K,
28 {
29 let idx = node_key as usize;
30
31 if visiting[idx] {
32 return Err("Graph contains a cycle".into());
33 }
34 if visited[idx] {
35 return Ok(());
36 }
37
38 visiting.set(idx, true);
39
40 match key_fn {
41 Some(f) => {
42 let mut successors = graph.get_successors(node_key).collect::<Vec<_>>();
43 successors.sort_by_key(|n| std::cmp::Reverse(f(n)));
44 for successor in successors {
45 dfs(successor.key(), graph, visiting, visited, result, key_fn)?;
46 }
47 }
48 None => {
49 for successor_key in graph.get_successors_keys(node_key) {
50 dfs(successor_key, graph, visiting, visited, result, key_fn)?;
51 }
52 }
53 }
54
55 visiting.set(idx, false);
56 visited.set(idx, true);
57 result.push(node_key);
58 Ok(())
59 }
60
61 if !graph.has_sequential_keys() {
63 panic!("{}", INVALID_KEY_SEQUENCE);
64 }
65
66 let node_count = graph.get_node_keys().count();
67 let mut visiting = FixedBitSet::with_capacity(node_count);
68 let mut visited = FixedBitSet::with_capacity(node_count);
69 let mut result = Vec::new();
70
71 match &key_fn {
72 Some(f) => {
73 let mut nodes: Vec<_> = graph.get_nodes().collect();
74 nodes.sort_by_key(|n| std::cmp::Reverse(f(n)));
75 for node in nodes {
76 let idx = node.key() as usize;
77 if !visited[idx] {
78 dfs(
79 node.key(),
80 graph,
81 &mut visiting,
82 &mut visited,
83 &mut result,
84 &key_fn,
85 )?;
86 }
87 }
88 }
89 None => {
90 for node_key in graph.get_node_keys() {
91 let idx = node_key as usize;
92 if !visited[idx] {
93 dfs(
94 node_key,
95 graph,
96 &mut visiting,
97 &mut visited,
98 &mut result,
99 &key_fn,
100 )?;
101 }
102 }
103 }
104 }
105
106 result.reverse();
107 Ok(result)
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use ade_graph::implementations::Edge;
    use ade_graph::implementations::Graph;
    use ade_graph::implementations::Node;
    use ade_graph::utils::build::build_graph;
    use ade_graph_generators::generate_random_graph_data;

    /// Key-function type named explicitly when no key function is supplied.
    type NoKeyFn = fn(&Node) -> u32;

    /// Runs `topological_sort` on `graph` without a key function.
    fn sort_without_key(graph: &impl GraphViewTrait<Node, Edge>) -> Result<Vec<u32>, String> {
        topological_sort::<Node, Edge, u32, NoKeyFn>(graph, None)
    }

    /// Runs `topological_sort` on `graph` using each node's key as the sort key.
    fn sort_by_key_ascending(graph: &impl GraphViewTrait<Node, Edge>) -> Vec<u32> {
        topological_sort::<Node, Edge, u32, _>(graph, Some(|n: &Node| (n.key() as u32))).unwrap()
    }

    /// Runs `topological_sort` on `graph` using each node's negated key as the sort key.
    fn sort_by_key_descending(graph: &impl GraphViewTrait<Node, Edge>) -> Vec<u32> {
        topological_sort::<Node, Edge, i32, _>(graph, Some(|n: &Node| -(n.key() as i32))).unwrap()
    }

    /// Checks the orderings produced for `graph` by the plain and the negated key function.
    fn assert_keyed_orders(
        graph: &impl GraphViewTrait<Node, Edge>,
        expected_plain: &[u32],
        expected_negated: &[u32],
    ) {
        assert_eq!(sort_by_key_ascending(graph).as_slice(), expected_plain);
        assert_eq!(sort_by_key_descending(graph).as_slice(), expected_negated);
    }

    #[test]
    fn test_topological_sort() {
        // Simple chain 0 -> 1 -> 2.
        let nodes = vec![Node::new(0), Node::new(1), Node::new(2)];
        let edges = vec![Edge::new(0, 1), Edge::new(1, 2)];
        let graph = Graph::<Node, Edge>::new(nodes, edges);

        let sorted = sort_without_key(&graph).unwrap();

        assert_eq!(sorted, vec![0, 1, 2]);
    }

    #[test]
    fn test_topological_sort_non_sequential_keys() {
        use ade_common::assert_panics_with;

        let graph = build_graph(vec![1, 3, 5], vec![(1, 3), (3, 5), (5, 1)]);
        assert_panics_with!(
            sort_without_key(&graph),
            ade_common::INVALID_KEY_SEQUENCE
        );
    }

    #[test]
    fn test_topological_sort_cycle() {
        // Two nodes pointing at each other: 0 -> 1 -> 0.
        let nodes = vec![Node::new(0), Node::new(1)];
        let edges = vec![Edge::new(0, 1), Edge::new(1, 0)];
        let graph = Graph::<Node, Edge>::new(nodes, edges);

        let outcome = sort_without_key(&graph);

        assert_eq!(outcome, Err("Graph contains a cycle".to_string()));
    }

    #[test]
    fn test_topological_sort_with_compare_by_key_1() {
        // One source with two sinks.
        let fan_out = build_graph(vec![0, 1, 2], vec![(0, 1), (0, 2)]);
        assert_keyed_orders(&fan_out, &[0, 1, 2], &[0, 2, 1]);

        // Two sources sharing one sink.
        let fan_in = build_graph(vec![0, 1, 2], vec![(0, 2), (1, 2)]);
        assert_keyed_orders(&fan_in, &[0, 1, 2], &[1, 0, 2]);

        // Two sources with one shared and one private successor each.
        let mixed = build_graph(vec![0, 1, 2, 3, 4], vec![(0, 1), (0, 4), (2, 4), (2, 3)]);
        assert_keyed_orders(&mixed, &[0, 1, 2, 3, 4], &[2, 3, 0, 4, 1]);
    }

    #[test]
    fn test_topological_sort_random_graph() {
        let (nodes, edges) = generate_random_graph_data(20, 20, 3);
        let graph = build_graph(nodes, edges);

        assert!(sort_without_key(&graph).is_ok());
    }
}
214 }