jujutsu_lib/
dag_walk.rs

1// Copyright 2020 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::hash::Hash;
17use std::iter::Iterator;
18
/// Lazy traversal over the nodes reachable from a start set.
///
/// Every node is yielded at most once, where two nodes count as "the same"
/// when `id_fn` returns equal IDs for them.
///
/// Note that the pending nodes are kept in a `Vec` that is consumed with
/// `pop()`, so nodes are visited in last-in, first-out order rather than in
/// strict breadth-first order, despite the name.
pub struct BfsIter<'id_fn, 'neighbors_fn, T, ID, NI> {
    /// Maps a node to the key used to detect already-visited nodes.
    id_fn: Box<dyn Fn(&T) -> ID + 'id_fn>,
    /// Produces the nodes to visit after the given node.
    neighbors_fn: Box<dyn FnMut(&T) -> NI + 'neighbors_fn>,
    /// Nodes waiting to be visited; the last element is taken next.
    work: Vec<T>,
    /// IDs of the nodes that have already been yielded.
    visited: HashSet<ID>,
}

impl<T, ID, NI> Iterator for BfsIter<'_, '_, T, ID, NI>
where
    ID: Hash + Eq,
    NI: IntoIterator<Item = T>,
{
    type Item = T;

    /// Yields the next node whose ID has not been seen yet and queues its
    /// neighbors. Returns `None` once the work list is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(node) = self.work.pop() {
            let id = (self.id_fn)(&node);
            if self.visited.contains(&id) {
                // Reached again through another path; already yielded.
                continue;
            }
            let neighbors = (self.neighbors_fn)(&node);
            self.work.extend(neighbors);
            self.visited.insert(id);
            return Some(node);
        }
        None
    }
}

/// Creates a [`BfsIter`] that walks everything reachable from `start`.
///
/// `id_fn` provides the identity used for de-duplication and `neighbors_fn`
/// provides the outgoing edges of a node. Nothing is traversed until the
/// returned iterator is advanced.
pub fn bfs<'id_fn, 'neighbors_fn, T, ID, II, NI>(
    start: II,
    id_fn: Box<dyn Fn(&T) -> ID + 'id_fn>,
    neighbors_fn: Box<dyn FnMut(&T) -> NI + 'neighbors_fn>,
) -> BfsIter<'id_fn, 'neighbors_fn, T, ID, NI>
where
    ID: Hash + Eq,
    II: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    let work: Vec<T> = start.into_iter().collect();
    let visited = HashSet::new();
    BfsIter {
        id_fn,
        neighbors_fn,
        work,
        visited,
    }
}
66
/// Builds a list of the nodes reachable from `start` in which every node
/// comes before its neighbors (i.e. the reverse of a post-order walk).
///
/// For example, if `neighbors_fn` returns the parents of a commit, children
/// are listed before their parents. Each node appears once, where identity is
/// decided by `id_fn`. The neighbors of a node, and the start nodes
/// themselves, keep their relative input order wherever the graph allows it.
///
/// The `T: Hash + Eq + Clone` and `ID: Clone` bounds are not needed by the
/// implementation; they are kept so that the signature stays unchanged.
///
/// # Panics
///
/// Panics with "graph has cycle" if a node is reachable from itself.
pub fn topo_order_reverse<'a, T, ID, II, NI>(
    start: II,
    id_fn: Box<dyn Fn(&T) -> ID + 'a>,
    mut neighbors_fn: Box<dyn FnMut(&T) -> NI + 'a>,
) -> Vec<T>
where
    T: Hash + Eq + Clone,
    ID: Hash + Eq + Clone,
    II: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    // IDs of nodes whose neighbors are currently being processed. Meeting one
    // of them again means we followed an edge back to an ancestor on the
    // current path.
    let mut in_progress: HashSet<ID> = HashSet::new();
    // IDs of nodes that have already been appended to `post_order`.
    let mut done: HashSet<ID> = HashSet::new();
    let mut post_order: Vec<T> = Vec::new();

    // The final list is reversed, so walk the start nodes back to front to
    // keep them in their original relative order in the output.
    let roots: Vec<T> = start.into_iter().collect();
    for root in roots.into_iter().rev() {
        // Explicit stack of `(node, expanded)`; `expanded` is true once the
        // node's neighbors have been pushed above it.
        let mut pending = vec![(root, false)];
        while let Some((node, expanded)) = pending.pop() {
            let id = id_fn(&node);
            if done.contains(&id) {
                continue;
            }
            if expanded {
                // Everything above this entry has been emitted, so all of the
                // node's neighbors are already in `post_order`.
                in_progress.remove(&id);
                done.insert(id);
                post_order.push(node);
            } else {
                assert!(in_progress.insert(id), "graph has cycle");
                let neighbors = neighbors_fn(&node);
                pending.push((node, true));
                pending.extend(neighbors.into_iter().map(|neighbor| (neighbor, false)));
            }
        }
    }
    post_order.reverse();
    post_order
}
110
/// Finds the nodes reachable from `start` (including the start nodes
/// themselves) for which `neighbors_fn` yields no neighbors.
///
/// The graph is explored one generation at a time. A node is expanded only
/// the first time its ID (as returned by `id_fn`) is seen, so `neighbors_fn`
/// is called at most once per distinct ID.
pub fn leaves<T, ID, II, NI>(
    start: II,
    neighbors_fn: &mut impl FnMut(&T) -> NI,
    id_fn: &impl Fn(&T) -> ID,
) -> HashSet<T>
where
    T: Hash + Eq + Clone,
    ID: Hash + Eq,
    II: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    let mut visited = HashSet::new();
    let mut work: Vec<T> = start.into_iter().collect();
    // Every start node is a candidate up front; candidates that turn out to
    // have neighbors are filtered out at the end.
    let mut leaves: HashSet<T> = work.iter().cloned().collect();
    let mut non_leaves = HashSet::new();
    while !work.is_empty() {
        let mut new_work = vec![];
        for c in work {
            let id: ID = id_fn(&c);
            if visited.contains(&id) {
                continue;
            }
            let queued_before = new_work.len();
            new_work.extend(neighbors_fn(&c));
            visited.insert(id);
            // Move the node into exactly one of the two sets instead of
            // cloning it once per neighbor and storing it in both.
            if new_work.len() > queued_before {
                non_leaves.insert(c);
            } else {
                leaves.insert(c);
            }
        }
        work = new_work;
    }
    // Drop the start nodes that turned out to have neighbors, in place.
    leaves.retain(|node| !non_leaves.contains(node));
    leaves
}
145
/// Returns the subset of `start` that cannot be reached from any node in
/// `start` by following one or more `neighbors_fn` edges.
///
/// Everything reachable from the start set is walked once (nodes are
/// de-duplicated by the ID returned from `id_fn`), and every node that shows
/// up as somebody's neighbor is struck from the candidate set.
pub fn heads<T, ID, II, NI>(
    start: II,
    neighbors_fn: &impl Fn(&T) -> NI,
    id_fn: &impl Fn(&T) -> ID,
) -> HashSet<T>
where
    T: Hash + Eq + Clone,
    ID: Hash + Eq,
    II: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    // Stack of nodes still to expand; the last element is expanded next.
    let mut pending: Vec<T> = start.into_iter().collect();
    // All start nodes are candidates until proven reachable.
    let mut candidates: HashSet<T> = pending.iter().cloned().collect();
    let mut seen: HashSet<ID> = HashSet::new();

    while let Some(node) = pending.pop() {
        let id = id_fn(&node);
        if seen.contains(&id) {
            continue;
        }
        for neighbor in neighbors_fn(&node) {
            // Being a neighbor means being reachable, hence not a head.
            candidates.remove(&neighbor);
            pending.push(neighbor);
        }
        seen.insert(id);
    }
    candidates
}
174
/// Searches for a node that is reachable from both `set1` and `set2`.
///
/// Both sets are expanded alternately, one generation of `neighbors_fn` edges
/// at a time, starting with `set1`. The first node encountered by one side
/// whose ID (per `id_fn`) has already been visited by the other side is
/// returned. Returns `None` if both searches run out of nodes without
/// meeting.
///
/// NOTE: the first hit wins, so when several common nodes exist the result is
/// not guaranteed to be the one that minimizes the longest distance to the
/// two sets.
pub fn closest_common_node<T, ID, II1, II2, NI>(
    set1: II1,
    set2: II2,
    neighbors_fn: &impl Fn(&T) -> NI,
    id_fn: &impl Fn(&T) -> ID,
) -> Option<T>
where
    T: Hash + Eq + Clone,
    ID: Hash + Eq,
    II1: IntoIterator<Item = T>,
    II2: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    // Index 0 belongs to the search started from `set1`, index 1 to the one
    // started from `set2`.
    let mut seen: [HashSet<ID>; 2] = [HashSet::new(), HashSet::new()];
    let mut frontiers: [Vec<T>; 2] = [set1.into_iter().collect(), set2.into_iter().collect()];

    while frontiers.iter().any(|frontier| !frontier.is_empty()) {
        for side in 0..2 {
            let other = 1 - side;
            let mut next_frontier = Vec::new();
            for node in std::mem::take(&mut frontiers[side]) {
                let id: ID = id_fn(&node);
                if seen[other].contains(&id) {
                    return Some(node);
                }
                // Expand a node only the first time this side reaches it.
                if seen[side].insert(id) {
                    next_frontier.extend(neighbors_fn(&node));
                }
            }
            frontiers[side] = next_frontier;
        }
    }
    None
}
224
#[cfg(test)]
mod tests {
    use maplit::{hashmap, hashset};

    use super::*;

    #[test]
    fn test_topo_order_reverse_linear() {
        // A single chain; each node's only neighbor is the node below it:
        //  o C
        //  o B
        //  o A

        let edges = hashmap! {
            'C' => vec!['B'],
            'B' => vec!['A'],
            'A' => vec![],
        };

        let ordered = topo_order_reverse(
            vec!['C'],
            Box::new(|node| *node),
            Box::new(move |node| edges[node].clone()),
        );

        assert_eq!(ordered, vec!['C', 'B', 'A']);
    }

    #[test]
    fn test_topo_order_reverse_merge() {
        // Two branches of different length joined by F:
        //  o F
        //  |\
        //  o | E
        //  | o D
        //  | o C
        //  | o B
        //  |/
        //  o A

        let edges = hashmap! {
            'F' => vec!['E', 'D'],
            'E' => vec!['A'],
            'D' => vec!['C'],
            'C' => vec!['B'],
            'B' => vec!['A'],
            'A' => vec![],
        };

        let ordered = topo_order_reverse(
            vec!['F'],
            Box::new(|node| *node),
            Box::new(move |node| edges[node].clone()),
        );

        assert_eq!(ordered, vec!['F', 'E', 'D', 'C', 'B', 'A']);
    }

    #[test]
    fn test_topo_order_reverse_multiple_heads() {
        // Two start nodes (F and C) sharing the root A:
        //  o F
        //  |\
        //  o | E
        //  | o D
        //  | | o C
        //  | | |
        //  | | o B
        //  | |/
        //  |/
        //  o A

        let edges = hashmap! {
            'F' => vec!['E', 'D'],
            'E' => vec!['A'],
            'D' => vec!['A'],
            'C' => vec!['B'],
            'B' => vec!['A'],
            'A' => vec![],
        };

        let ordered = topo_order_reverse(
            vec!['F', 'C'],
            Box::new(|node| *node),
            Box::new(move |node| edges[node].clone()),
        );

        assert_eq!(ordered, vec!['F', 'E', 'D', 'C', 'B', 'A']);
    }

    #[test]
    fn test_closest_common_node_tricky() {
        // A is the fewest steps away from both E and H, but because A is an
        // ancestor of B we would prefer the result to be B. In other words,
        // we want to minimize the longest distance.
        //
        //  E       H
        //  |\     /|
        //  | D   G |
        //  | C   F |
        //   \ \ / /
        //    \ B /
        //     \|/
        //      A

        let edges = hashmap! {
            'H' => vec!['A', 'G'],
            'G' => vec!['F'],
            'F' => vec!['B'],
            'E' => vec!['A','D'],
            'D' => vec!['C'],
            'C' => vec!['B'],
            'B' => vec!['A'],
            'A' => vec![],
        };

        let found = closest_common_node(
            vec!['E'],
            vec!['H'],
            &|node| edges[node].clone(),
            &|node| *node,
        );

        // TODO: fix the implementation to return B
        assert_eq!(found, Some('A'));
    }

    #[test]
    fn test_heads_mixed() {
        // The uppercase letters are in the start set, the lowercase ones are
        // only reachable from it:
        //
        //  D F
        //  |/|
        //  C e
        //  |/
        //  b
        //  |
        //  A

        let edges = hashmap! {
            'F' => vec!['C', 'e'],
            'e' => vec!['b'],
            'D' => vec!['C'],
            'C' => vec!['b'],
            'b' => vec!['A'],
            'A' => vec![],
        };

        // The result must not depend on the order of the start set, so run
        // the same check with the start nodes in both directions.
        let start_orders = vec![vec!['A', 'C', 'D', 'F'], vec!['F', 'D', 'C', 'A']];
        for start in start_orders {
            let actual = heads(start, &|node| edges[node].clone(), &|node| *node);
            assert_eq!(actual, hashset!['D', 'F']);
        }
    }
}