1use std::{
2 collections::{HashMap, HashSet},
3 fmt::Debug,
4 hash::Hash,
5};
6
/// A directed graph over borrowed node ids and payloads that is meant to be
/// free of cycles.
///
/// `nodes` holds the payload of every node; `edges` maps a node id to the ids
/// it has an outgoing edge to. A node without outgoing edges may either be
/// absent from `edges` or map to an empty set.
#[derive(Debug)]
pub struct AcyclicDirectedGraph<'g, ID, T> {
    /// Payload of each node, keyed by node id.
    pub(crate) nodes: HashMap<&'g ID, &'g T>,
    /// Adjacency map: `from -> {to, ...}`.
    edges: HashMap<&'g ID, HashSet<&'g ID>>,
}
12
13impl<'g, ID, T> AcyclicDirectedGraph<'g, ID, T>
14where
15 ID: Hash + Eq,
16{
17 pub fn new(nodes: HashMap<&'g ID, &'g T>, edges: HashMap<&'g ID, HashSet<&'g ID>>) -> Self {
18 Self { nodes, edges }
19 }
20
21 pub fn transitive_reduction(&self) -> MinimalAcyclicDirectedGraph<'g, ID, T> {
24 let reachable = {
25 let mut reachable: HashMap<&ID, HashSet<&ID>> = HashMap::new();
26
27 for id in self.nodes.keys() {
28 if reachable.contains_key(id) {
29 continue;
30 }
31
32 let mut stack: Vec<&ID> = vec![*id];
33 while let Some(id) = stack.pop() {
34 if reachable.contains_key(id) {
35 continue;
36 }
37
38 let succs = match self.edges.get(id) {
39 Some(s) => s,
40 None => {
41 reachable.insert(id, HashSet::new());
42 continue;
43 }
44 };
45 if succs.is_empty() {
46 reachable.insert(id, HashSet::new());
47 continue;
48 }
49
50 if succs.iter().all(|id| reachable.contains_key(id)) {
51 let others: HashSet<&ID> = succs
52 .iter()
53 .flat_map(|id| {
54 reachable
55 .get(id)
56 .expect("We previously check that it contains the Key")
57 .iter()
58 .copied()
59 })
60 .chain(succs.iter().copied())
61 .collect();
62
63 reachable.insert(id, others);
64
65 continue;
66 }
67
68 stack.push(id);
69 stack.extend(succs.iter());
70 }
71 }
72
73 reachable
74 };
75
76 let mut remove_edges = HashMap::new();
77
78 let empty_succs = HashSet::new();
79 for node in self.nodes.keys() {
80 let edges = self.edges.get(node).unwrap_or(&empty_succs);
81
82 let succ_reachs: HashSet<_> = edges
83 .iter()
84 .flat_map(|id| {
85 reachable
86 .get(id)
87 .expect("There is an Entry in the reachable Map for every Node")
88 })
89 .collect();
90
91 let unique_edges: HashSet<&ID> = edges
92 .iter()
93 .filter(|id| !succ_reachs.contains(id))
94 .copied()
95 .collect();
96
97 let remove: HashSet<&ID> = edges.difference(&unique_edges).copied().collect();
98
99 remove_edges.insert(*node, remove);
100 }
101
102 let n_edges: HashMap<&ID, HashSet<&ID>> = self
103 .edges
104 .iter()
105 .map(|(from, to)| {
106 let filter_targets = remove_edges.get(from).expect("");
107
108 (
109 *from,
110 to.iter()
111 .filter(|t_id| !filter_targets.contains(*t_id))
112 .copied()
113 .collect(),
114 )
115 })
116 .collect();
117
118 MinimalAcyclicDirectedGraph {
119 inner: AcyclicDirectedGraph {
120 nodes: self.nodes.clone(),
121 edges: n_edges,
122 },
123 }
124 }
125
126 pub fn successors(&self, node: &ID) -> Option<&HashSet<&'g ID>> {
127 self.edges.get(node)
128 }
129}
130
131impl<'g, ID, T> PartialEq for AcyclicDirectedGraph<'g, ID, T>
132where
133 ID: PartialEq + Hash + Eq,
134 T: PartialEq,
135{
136 fn eq(&self, other: &Self) -> bool {
137 if self.nodes != other.nodes {
138 return false;
139 }
140 if self.edges != other.edges {
141 return false;
142 }
143
144 true
145 }
146}
147
148#[derive(Debug)]
154pub struct MinimalAcyclicDirectedGraph<'g, ID, T> {
155 pub(crate) inner: AcyclicDirectedGraph<'g, ID, T>,
156}
157
158impl<'g, ID, T> PartialEq for MinimalAcyclicDirectedGraph<'g, ID, T>
159where
160 ID: PartialEq + Hash + Eq,
161 T: PartialEq,
162{
163 fn eq(&self, other: &Self) -> bool {
164 self.inner == other.inner
165 }
166}
167
168impl<'g, ID, T> MinimalAcyclicDirectedGraph<'g, ID, T>
169where
170 ID: Hash + Eq,
171{
172 pub fn incoming_mapping(&self) -> HashMap<&'g ID, HashSet<&'g ID>> {
174 let mut result: HashMap<&ID, HashSet<&ID>> = HashMap::with_capacity(self.inner.nodes.len());
175 for node in self.inner.nodes.keys() {
176 result.insert(*node, HashSet::new());
177 }
178
179 for (from, to) in self.inner.edges.iter() {
180 for target in to {
181 let entry = result.entry(target);
182 let value = entry.or_insert_with(HashSet::new);
183 value.insert(*from);
184 }
185 }
186
187 result
188 }
189
190 pub fn outgoing(&self, node: &ID) -> Option<impl Iterator<Item = &'g ID> + '_> {
191 let targets = self.inner.edges.get(node)?;
192 Some(targets.iter().copied())
193 }
194
195 pub fn topological_sort(&self) -> Vec<&'g ID>
196 where
197 ID: Hash + Eq,
198 {
199 let incoming = self.incoming_mapping();
200
201 let mut ordering: Vec<&ID> = Vec::new();
202
203 let mut nodes: Vec<_> = self.inner.nodes.keys().copied().collect();
204
205 while !nodes.is_empty() {
206 let mut potential: Vec<(usize, &ID)> = nodes
207 .iter()
208 .enumerate()
209 .filter(|(_, id)| match incoming.get(*id) {
210 Some(in_edges) => in_edges.iter().all(|id| ordering.contains(id)),
211 None => true,
212 })
213 .map(|(i, id)| (i, *id))
214 .collect();
215
216 if potential.len() == 1 {
221 let (index, entry) = potential
222 .pop()
223 .expect("We previously checked that there is at least one item in it");
224 ordering.push(entry);
225 nodes.remove(index);
226 continue;
227 }
228
229 potential.sort_by(|(_, a), (_, b)| {
230 let a_incoming = match incoming.get(a) {
231 Some(i) => i,
232 None => return std::cmp::Ordering::Less,
233 };
234 let a_first_index = ordering
235 .iter()
236 .enumerate()
237 .find(|(_, id)| a_incoming.contains(*id))
238 .map(|(i, _)| i);
239
240 let b_incoming = match incoming.get(b) {
241 Some(i) => i,
242 None => return std::cmp::Ordering::Greater,
243 };
244 let b_first_index = ordering
245 .iter()
246 .enumerate()
247 .find(|(_, id)| b_incoming.contains(*id))
248 .map(|(i, _)| i);
249
250 a_first_index.cmp(&b_first_index)
251 });
252
253 let (_, entry) = potential.remove(0);
254 let index = nodes
255 .iter()
256 .enumerate()
257 .find(|(_, id)| **id == entry)
258 .map(|(i, _)| i)
259 .expect("We know that the there is at least one potential entry, so we can assume that we find that entry");
260 ordering.push(entry);
261 nodes.remove(index);
262 }
263
264 ordering
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Node map in which every id carries the same placeholder payload.
    fn node_map(ids: &[&'static i32]) -> HashMap<&'static i32, &'static &'static str> {
        ids.iter().map(|id| (*id, &"test")).collect()
    }

    /// Id -> id-set map built from `(key, [value, ...])` pairs; used both for
    /// adjacency maps and for expected incoming mappings.
    fn id_sets(
        pairs: Vec<(&'static i32, Vec<&'static i32>)>,
    ) -> HashMap<&'static i32, HashSet<&'static i32>> {
        pairs
            .into_iter()
            .map(|(key, ids)| (key, ids.into_iter().collect()))
            .collect()
    }

    /// `0 -> 1 -> 2 -> 3 -> 4`
    fn linear() -> MinimalAcyclicDirectedGraph<'static, i32, &'static str> {
        MinimalAcyclicDirectedGraph {
            inner: AcyclicDirectedGraph::new(
                node_map(&[&0, &1, &2, &3, &4]),
                id_sets(vec![
                    (&0, vec![&1]),
                    (&1, vec![&2]),
                    (&2, vec![&3]),
                    (&3, vec![&4]),
                ]),
            ),
        }
    }

    /// `0 -> {1, 2}`, `1 -> 3`, `2 -> 4`
    fn branched() -> MinimalAcyclicDirectedGraph<'static, i32, &'static str> {
        MinimalAcyclicDirectedGraph {
            inner: AcyclicDirectedGraph::new(
                node_map(&[&0, &1, &2, &3, &4]),
                id_sets(vec![(&0, vec![&1, &2]), (&1, vec![&3]), (&2, vec![&4])]),
            ),
        }
    }

    #[test]
    fn reduce_with_changes() {
        let graph = AcyclicDirectedGraph::new(
            node_map(&[&0, &1, &2]),
            id_sets(vec![(&0, vec![&1, &2]), (&1, vec![&2]), (&2, vec![])]),
        );

        let expected = MinimalAcyclicDirectedGraph {
            inner: AcyclicDirectedGraph::new(
                node_map(&[&0, &1, &2]),
                id_sets(vec![(&0, vec![&1]), (&1, vec![&2]), (&2, vec![])]),
            ),
        };

        assert_eq!(expected, graph.transitive_reduction());
    }

    #[test]
    fn reduce_diamond_drops_shortcut() {
        // 0 -> 3 is implied by both 0 -> 1 -> 3 and 0 -> 2 -> 3.
        let graph = AcyclicDirectedGraph::new(
            node_map(&[&0, &1, &2, &3]),
            id_sets(vec![(&0, vec![&1, &2, &3]), (&1, vec![&3]), (&2, vec![&3])]),
        );

        let expected = MinimalAcyclicDirectedGraph {
            inner: AcyclicDirectedGraph::new(
                node_map(&[&0, &1, &2, &3]),
                id_sets(vec![(&0, vec![&1, &2]), (&1, vec![&3]), (&2, vec![&3])]),
            ),
        };

        assert_eq!(expected, graph.transitive_reduction());
    }

    #[test]
    fn incoming_mapping_linear() {
        let expected = id_sets(vec![
            (&0, vec![]),
            (&1, vec![&0]),
            (&2, vec![&1]),
            (&3, vec![&2]),
            (&4, vec![&3]),
        ]);

        assert_eq!(expected, linear().incoming_mapping());
    }

    #[test]
    fn incoming_mapping_branched() {
        let expected = id_sets(vec![
            (&0, vec![]),
            (&1, vec![&0]),
            (&2, vec![&0]),
            (&3, vec![&1]),
            (&4, vec![&2]),
        ]);

        assert_eq!(expected, branched().incoming_mapping());
    }

    #[test]
    fn topsort_linear() {
        assert_eq!(vec![&0, &1, &2, &3, &4], linear().topological_sort());
    }

    #[test]
    fn topsort_branched() {
        let sort = branched().topological_sort();

        // Siblings 1 and 2 may come in either order; their children follow the
        // order of their parents.
        let expected1 = vec![&0, &1, &2, &3, &4];
        let expected2 = vec![&0, &2, &1, &4, &3];
        assert!(sort == expected1 || sort == expected2, "unexpected order {:?}", sort);
    }

    #[test]
    fn topsort_diamond() {
        let graph = MinimalAcyclicDirectedGraph {
            inner: AcyclicDirectedGraph::new(
                node_map(&[&0, &1, &2, &3]),
                id_sets(vec![(&0, vec![&1, &2]), (&1, vec![&3]), (&2, vec![&3])]),
            ),
        };

        let sort = graph.topological_sort();

        assert!(
            sort == vec![&0, &1, &2, &3] || sort == vec![&0, &2, &1, &3],
            "unexpected order {:?}",
            sort
        );
    }
}