termgraph/
acyclic.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::Debug,
4    hash::Hash,
5};
6
/// A directed graph whose edges are expected to contain no cycles.
///
/// The graph only borrows its data: node ids and payloads are references that live for `'g`.
#[derive(Debug)]
pub struct AcyclicDirectedGraph<'g, ID, T> {
    /// All nodes of the graph, mapping each node id to its payload.
    pub(crate) nodes: HashMap<&'g ID, &'g T>,
    /// Adjacency map from a node id to the ids of its direct successors.
    edges: HashMap<&'g ID, HashSet<&'g ID>>,
}
12
13impl<'g, ID, T> AcyclicDirectedGraph<'g, ID, T>
14where
15    ID: Hash + Eq,
16{
17    pub fn new(nodes: HashMap<&'g ID, &'g T>, edges: HashMap<&'g ID, HashSet<&'g ID>>) -> Self {
18        Self { nodes, edges }
19    }
20
21    /// Performs a transitive reduction on the current acyclic graph. This means that all of the
22    /// Edges `a -> c` are removed if the Edges `a -> b` and `b -> c` exist.
23    pub fn transitive_reduction(&self) -> MinimalAcyclicDirectedGraph<'g, ID, T> {
24        let reachable = {
25            let mut reachable: HashMap<&ID, HashSet<&ID>> = HashMap::new();
26
27            for id in self.nodes.keys() {
28                if reachable.contains_key(id) {
29                    continue;
30                }
31
32                let mut stack: Vec<&ID> = vec![*id];
33                while let Some(id) = stack.pop() {
34                    if reachable.contains_key(id) {
35                        continue;
36                    }
37
38                    let succs = match self.edges.get(id) {
39                        Some(s) => s,
40                        None => {
41                            reachable.insert(id, HashSet::new());
42                            continue;
43                        }
44                    };
45                    if succs.is_empty() {
46                        reachable.insert(id, HashSet::new());
47                        continue;
48                    }
49
50                    if succs.iter().all(|id| reachable.contains_key(id)) {
51                        let others: HashSet<&ID> = succs
52                            .iter()
53                            .flat_map(|id| {
54                                reachable
55                                    .get(id)
56                                    .expect("We previously check that it contains the Key")
57                                    .iter()
58                                    .copied()
59                            })
60                            .chain(succs.iter().copied())
61                            .collect();
62
63                        reachable.insert(id, others);
64
65                        continue;
66                    }
67
68                    stack.push(id);
69                    stack.extend(succs.iter());
70                }
71            }
72
73            reachable
74        };
75
76        let mut remove_edges = HashMap::new();
77
78        let empty_succs = HashSet::new();
79        for node in self.nodes.keys() {
80            let edges = self.edges.get(node).unwrap_or(&empty_succs);
81
82            let succ_reachs: HashSet<_> = edges
83                .iter()
84                .flat_map(|id| {
85                    reachable
86                        .get(id)
87                        .expect("There is an Entry in the reachable Map for every Node")
88                })
89                .collect();
90
91            let unique_edges: HashSet<&ID> = edges
92                .iter()
93                .filter(|id| !succ_reachs.contains(id))
94                .copied()
95                .collect();
96
97            let remove: HashSet<&ID> = edges.difference(&unique_edges).copied().collect();
98
99            remove_edges.insert(*node, remove);
100        }
101
102        let n_edges: HashMap<&ID, HashSet<&ID>> = self
103            .edges
104            .iter()
105            .map(|(from, to)| {
106                let filter_targets = remove_edges.get(from).expect("");
107
108                (
109                    *from,
110                    to.iter()
111                        .filter(|t_id| !filter_targets.contains(*t_id))
112                        .copied()
113                        .collect(),
114                )
115            })
116            .collect();
117
118        MinimalAcyclicDirectedGraph {
119            inner: AcyclicDirectedGraph {
120                nodes: self.nodes.clone(),
121                edges: n_edges,
122            },
123        }
124    }
125
126    pub fn successors(&self, node: &ID) -> Option<&HashSet<&'g ID>> {
127        self.edges.get(node)
128    }
129}
130
131impl<'g, ID, T> PartialEq for AcyclicDirectedGraph<'g, ID, T>
132where
133    ID: PartialEq + Hash + Eq,
134    T: PartialEq,
135{
136    fn eq(&self, other: &Self) -> bool {
137        if self.nodes != other.nodes {
138            return false;
139        }
140        if self.edges != other.edges {
141            return false;
142        }
143
144        true
145    }
146}
147
148/// This is an acyclic directed Graph that is transitively reduced so there should be no edges in
149/// the form `a -> c` if the edges `a -> b` and `b -> c` exist.
150///
151/// This form makes the level generation easier as we can basically attempt to assign all the
152/// successors of a node to the level below the node.
153#[derive(Debug)]
154pub struct MinimalAcyclicDirectedGraph<'g, ID, T> {
155    pub(crate) inner: AcyclicDirectedGraph<'g, ID, T>,
156}
157
158impl<'g, ID, T> PartialEq for MinimalAcyclicDirectedGraph<'g, ID, T>
159where
160    ID: PartialEq + Hash + Eq,
161    T: PartialEq,
162{
163    fn eq(&self, other: &Self) -> bool {
164        self.inner == other.inner
165    }
166}
167
168impl<'g, ID, T> MinimalAcyclicDirectedGraph<'g, ID, T>
169where
170    ID: Hash + Eq,
171{
172    /// Generates a Mapping for each Vertex to Vertices that are leading to it
173    pub fn incoming_mapping(&self) -> HashMap<&'g ID, HashSet<&'g ID>> {
174        let mut result: HashMap<&ID, HashSet<&ID>> = HashMap::with_capacity(self.inner.nodes.len());
175        for node in self.inner.nodes.keys() {
176            result.insert(*node, HashSet::new());
177        }
178
179        for (from, to) in self.inner.edges.iter() {
180            for target in to {
181                let entry = result.entry(target);
182                let value = entry.or_insert_with(HashSet::new);
183                value.insert(*from);
184            }
185        }
186
187        result
188    }
189
190    pub fn outgoing(&self, node: &ID) -> Option<impl Iterator<Item = &'g ID> + '_> {
191        let targets = self.inner.edges.get(node)?;
192        Some(targets.iter().copied())
193    }
194
195    pub fn topological_sort(&self) -> Vec<&'g ID>
196    where
197        ID: Hash + Eq,
198    {
199        let incoming = self.incoming_mapping();
200
201        let mut ordering: Vec<&ID> = Vec::new();
202
203        let mut nodes: Vec<_> = self.inner.nodes.keys().copied().collect();
204
205        while !nodes.is_empty() {
206            let mut potential: Vec<(usize, &ID)> = nodes
207                .iter()
208                .enumerate()
209                .filter(|(_, id)| match incoming.get(*id) {
210                    Some(in_edges) => in_edges.iter().all(|id| ordering.contains(id)),
211                    None => true,
212                })
213                .map(|(i, id)| (i, *id))
214                .collect();
215
216            // TODO
217            // The Second part of the Ordering Condition is not really used/implemented
218            // and may even be outright wrong
219
220            if potential.len() == 1 {
221                let (index, entry) = potential
222                    .pop()
223                    .expect("We previously checked that there is at least one item in it");
224                ordering.push(entry);
225                nodes.remove(index);
226                continue;
227            }
228
229            potential.sort_by(|(_, a), (_, b)| {
230                let a_incoming = match incoming.get(a) {
231                    Some(i) => i,
232                    None => return std::cmp::Ordering::Less,
233                };
234                let a_first_index = ordering
235                    .iter()
236                    .enumerate()
237                    .find(|(_, id)| a_incoming.contains(*id))
238                    .map(|(i, _)| i);
239
240                let b_incoming = match incoming.get(b) {
241                    Some(i) => i,
242                    None => return std::cmp::Ordering::Greater,
243                };
244                let b_first_index = ordering
245                    .iter()
246                    .enumerate()
247                    .find(|(_, id)| b_incoming.contains(*id))
248                    .map(|(i, _)| i);
249
250                a_first_index.cmp(&b_first_index)
251            });
252
253            let (_, entry) = potential.remove(0);
254            let index = nodes
255                .iter()
256                .enumerate()
257                .find(|(_, id)| **id == entry)
258                .map(|(i, _)| i)
259                .expect("We know that the there is at least one potential entry, so we can assume that we find that entry");
260            ordering.push(entry);
261            nodes.remove(index);
262        }
263
264        ordering
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an adjacency map from `(from, successors)` pairs.
    fn adjacency(
        pairs: Vec<(&'static i32, Vec<&'static i32>)>,
    ) -> HashMap<&'static i32, HashSet<&'static i32>> {
        pairs
            .into_iter()
            .map(|(from, to)| (from, to.into_iter().collect()))
            .collect()
    }

    /// Builds a node map that labels every id with the same placeholder payload.
    fn node_map(ids: &[&'static i32]) -> HashMap<&'static i32, &'static &'static str> {
        ids.iter().map(|id| (*id, &"test")).collect()
    }

    /// Wraps the given nodes and edges as-is (without reducing them) in a minimal graph.
    fn minimal(
        ids: &[&'static i32],
        pairs: Vec<(&'static i32, Vec<&'static i32>)>,
    ) -> MinimalAcyclicDirectedGraph<'static, i32, &'static str> {
        MinimalAcyclicDirectedGraph {
            inner: AcyclicDirectedGraph::new(node_map(ids), adjacency(pairs)),
        }
    }

    #[test]
    fn reduce_with_changes() {
        let nodes: HashMap<&i32, &&str> =
            HashMap::from([(&0, &"first"), (&1, &"second"), (&2, &"third")]);
        let graph = AcyclicDirectedGraph::new(
            nodes.clone(),
            adjacency(vec![(&0, vec![&1, &2]), (&1, vec![&2]), (&2, vec![])]),
        );

        let result = graph.transitive_reduction();

        let expected = MinimalAcyclicDirectedGraph {
            inner: AcyclicDirectedGraph::new(
                nodes,
                adjacency(vec![(&0, vec![&1]), (&1, vec![&2]), (&2, vec![])]),
            ),
        };
        assert_eq!(expected, result);
    }

    #[test]
    fn reduce_diamond() {
        let graph = AcyclicDirectedGraph::new(
            node_map(&[&0, &1, &2, &3]),
            adjacency(vec![(&0, vec![&1, &2, &3]), (&1, vec![&3]), (&2, vec![&3])]),
        );

        let expected = minimal(
            &[&0, &1, &2, &3],
            vec![(&0, vec![&1, &2]), (&1, vec![&3]), (&2, vec![&3])],
        );
        assert_eq!(expected, graph.transitive_reduction());
    }

    #[test]
    fn reduce_keeps_minimal_graph() {
        let graph = AcyclicDirectedGraph::new(
            node_map(&[&0, &1, &2]),
            adjacency(vec![(&0, vec![&1]), (&1, vec![&2])]),
        );

        let expected = minimal(&[&0, &1, &2], vec![(&0, vec![&1]), (&1, vec![&2])]);
        assert_eq!(expected, graph.transitive_reduction());
    }

    #[test]
    fn successors_returns_direct_edges() {
        let graph =
            AcyclicDirectedGraph::new(node_map(&[&0, &1]), adjacency(vec![(&0, vec![&1])]));

        assert_eq!(Some(&HashSet::from([&1])), graph.successors(&0));
        assert_eq!(None, graph.successors(&1));
    }

    #[test]
    fn incoming_mapping_linear() {
        let graph = minimal(
            &[&0, &1, &2, &3, &4],
            vec![(&0, vec![&1]), (&1, vec![&2]), (&2, vec![&3]), (&3, vec![&4])],
        );

        let expected = adjacency(vec![
            (&0, vec![]),
            (&1, vec![&0]),
            (&2, vec![&1]),
            (&3, vec![&2]),
            (&4, vec![&3]),
        ]);
        assert_eq!(expected, graph.incoming_mapping());
    }

    #[test]
    fn incoming_mapping_branched() {
        let graph = minimal(
            &[&0, &1, &2, &3, &4],
            vec![(&0, vec![&1, &2]), (&1, vec![&3]), (&2, vec![&4])],
        );

        let expected = adjacency(vec![
            (&0, vec![]),
            (&1, vec![&0]),
            (&2, vec![&0]),
            (&3, vec![&1]),
            (&4, vec![&2]),
        ]);
        assert_eq!(expected, graph.incoming_mapping());
    }

    #[test]
    fn outgoing_lists_direct_successors() {
        let graph = minimal(&[&0, &1, &2], vec![(&0, vec![&1, &2])]);

        let targets: HashSet<&i32> = graph
            .outgoing(&0)
            .expect("node 0 has successors")
            .collect();
        assert_eq!(HashSet::from([&1, &2]), targets);
        assert!(graph.outgoing(&1).is_none());
    }

    #[test]
    fn topsort_linear() {
        let graph = minimal(
            &[&0, &1, &2, &3, &4],
            vec![(&0, vec![&1]), (&1, vec![&2]), (&2, vec![&3]), (&3, vec![&4])],
        );

        assert_eq!(vec![&0, &1, &2, &3, &4], graph.topological_sort());
    }

    #[test]
    fn topsort_branched() {
        let graph = minimal(
            &[&0, &1, &2, &3, &4],
            vec![(&0, vec![&1, &2]), (&1, vec![&3]), (&2, vec![&4])],
        );

        let sort = graph.topological_sort();

        // Nodes 1 and 2 tie, so either branch may come first, but the levels must interleave.
        let expected1 = vec![&0, &1, &2, &3, &4];
        let expected2 = vec![&0, &2, &1, &4, &3];
        assert!(sort == expected1 || sort == expected2, "unexpected order: {:?}", sort);
    }
}