jujube_lib/
dag_walk.rs

1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::iter::Iterator;
17
18use crate::commit::Commit;
19use crate::store::CommitId;
20use std::hash::Hash;
21
22pub struct AncestorsIter {
23    bfs_iter: BfsIter<'static, 'static, Commit, CommitId, Vec<Commit>>,
24}
25
26impl Iterator for AncestorsIter {
27    type Item = Commit;
28
29    fn next(&mut self) -> Option<Self::Item> {
30        self.bfs_iter.next()
31    }
32}
33
34pub fn walk_ancestors<II>(start: II) -> AncestorsIter
35where
36    II: IntoIterator<Item = Commit>,
37{
38    let bfs_iter = bfs(
39        start,
40        Box::new(|commit| commit.id().clone()),
41        Box::new(|commit| commit.parents()),
42    );
43    AncestorsIter { bfs_iter }
44}
45
/// Lazy graph traversal that yields each node reachable from the initial work
/// list once, de-duplicated by the ID computed by `id_fn`.
///
/// Note that, despite the name, pending nodes are kept in a `Vec` that is
/// used as a stack: the most recently discovered node is visited next, which
/// gives a depth-first rather than a breadth-first order.
pub struct BfsIter<'id_fn, 'neighbors_fn, T, ID, NI> {
    /// Maps a node to the key used for de-duplication.
    id_fn: Box<dyn Fn(&T) -> ID + 'id_fn>,
    /// Produces the nodes directly reachable from a node.
    neighbors_fn: Box<dyn FnMut(&T) -> NI + 'neighbors_fn>,
    /// Nodes discovered but not yet yielded; consumed from the back.
    work: Vec<T>,
    /// IDs of the nodes that have already been yielded.
    visited: HashSet<ID>,
}

impl<T, ID, NI> Iterator for BfsIter<'_, '_, T, ID, NI>
where
    ID: Hash + Eq,
    NI: IntoIterator<Item = T>,
{
    type Item = T;

    /// Pops pending nodes until one with an unseen ID is found, queues its
    /// neighbors, and yields it. Returns `None` when no work is left.
    fn next(&mut self) -> Option<T> {
        while let Some(node) = self.work.pop() {
            let node_id = (self.id_fn)(&node);
            if !self.visited.contains(&node_id) {
                let neighbors = (self.neighbors_fn)(&node);
                self.work.extend(neighbors);
                self.visited.insert(node_id);
                return Some(node);
            }
        }
        None
    }
}
76
77pub fn bfs<'id_fn, 'neighbors_fn, T, ID, II, NI>(
78    start: II,
79    id_fn: Box<dyn Fn(&T) -> ID + 'id_fn>,
80    neighbors_fn: Box<dyn FnMut(&T) -> NI + 'neighbors_fn>,
81) -> BfsIter<'id_fn, 'neighbors_fn, T, ID, NI>
82where
83    ID: Hash + Eq,
84    II: IntoIterator<Item = T>,
85    NI: IntoIterator<Item = T>,
86{
87    BfsIter {
88        id_fn,
89        neighbors_fn,
90        work: start.into_iter().collect(),
91        visited: Default::default(),
92    }
93}
94
/// Lazy graph traversal that yields each node reachable from its work list
/// once, de-duplicated by the ID computed by `id_fn`.
///
/// NOTE(review): the traversal below is the same stack-based walk as
/// `BfsIter`; a node is yielded as soon as it is popped, so nothing here
/// enforces a topological order. No constructor for this type is visible in
/// this file.
pub struct TopoIter<'id_fn, 'neighbors_fn, T, ID, NI> {
    /// Maps a node to the key used for de-duplication.
    id_fn: Box<dyn Fn(&T) -> ID + 'id_fn>,
    /// Produces the nodes directly reachable from a node.
    neighbors_fn: Box<dyn FnMut(&T) -> NI + 'neighbors_fn>,
    /// Nodes discovered but not yet yielded; consumed from the back.
    work: Vec<T>,
    /// IDs of the nodes that have already been yielded.
    visited: HashSet<ID>,
}

impl<T, ID, NI> Iterator for TopoIter<'_, '_, T, ID, NI>
where
    ID: Hash + Eq,
    NI: IntoIterator<Item = T>,
{
    type Item = T;

    /// Yields the next pending node whose ID has not been seen yet, after
    /// queueing its neighbors. Returns `None` when the work list runs dry.
    fn next(&mut self) -> Option<T> {
        loop {
            let node = self.work.pop()?;
            let node_id = (self.id_fn)(&node);
            if self.visited.contains(&node_id) {
                continue;
            }
            self.work.extend((self.neighbors_fn)(&node));
            self.visited.insert(node_id);
            return Some(node);
        }
    }
}
125
/// Returns the nodes reachable from `start`, ordered so that every node comes
/// before all of its neighbors.
///
/// When the neighbors of a node are its parents, this means children come
/// before parents. Nodes are identified by `id_fn`; each ID appears at most
/// once in the result. Nodes listed earlier in `start` end up earlier in the
/// result where the graph allows it.
///
/// # Panics
///
/// Panics with "graph has cycle" if a node is reached again while its own
/// neighbors are still being processed.
pub fn topo_order_reverse<T, ID, II, NI>(
    start: II,
    id_fn: Box<dyn Fn(&T) -> ID>,
    mut neighbors_fn: Box<dyn FnMut(&T) -> NI>,
) -> Vec<T>
where
    ID: Hash + Eq,
    II: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    // IDs of nodes whose neighbors are currently being processed.
    let mut visiting = HashSet::new();
    // IDs of nodes already appended to `result`.
    let mut emitted = HashSet::new();
    // Built in post-order (neighbors first) and reversed at the end.
    let mut result = vec![];

    // The final reversal flips the order of the start nodes as well, so walk
    // them back to front to keep earlier start nodes earlier in the output.
    let start_nodes: Vec<_> = start.into_iter().collect();

    for start_node in start_nodes.into_iter().rev() {
        // Each entry is a node plus a flag telling whether its neighbors have
        // already been pushed (i.e. whether this is the second visit).
        let mut stack = vec![(start_node, false)];
        while let Some((node, neighbors_visited)) = stack.pop() {
            let id = id_fn(&node);
            if emitted.contains(&id) {
                continue;
            }
            if neighbors_visited {
                // All neighbors have been emitted; now emit the node itself.
                visiting.remove(&id);
                emitted.insert(id);
                result.push(node);
            } else {
                assert!(visiting.insert(id), "graph has cycle");
                let neighbors = neighbors_fn(&node);
                // Revisit the node once everything pushed above it is done.
                stack.push((node, true));
                stack.extend(neighbors.into_iter().map(|neighbor| (neighbor, false)));
            }
        }
    }
    result.reverse();
    result
}
170
/// Returns the nodes reachable from `start` (including the start nodes
/// themselves) for which `neighbors_fn` yields no neighbors.
///
/// The graph is explored one level at a time. `id_fn` provides the key used
/// to avoid expanding the same node twice.
pub fn leaves<T, ID, II, NI>(
    start: II,
    neighbors_fn: &mut impl FnMut(&T) -> NI,
    id_fn: &impl Fn(&T) -> ID,
) -> HashSet<T>
where
    T: Hash + Eq + Clone,
    ID: Hash + Eq,
    II: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    let mut frontier: Vec<T> = start.into_iter().collect();
    let mut seen_ids = HashSet::new();
    // Every node encountered so far; seeded with the start nodes.
    let mut candidates: HashSet<T> = frontier.iter().cloned().collect();
    // Nodes known to have at least one neighbor.
    let mut with_neighbors = HashSet::new();

    // TODO: make this not waste so much memory on the sets
    while !frontier.is_empty() {
        let mut next_frontier = Vec::new();
        for node in frontier.drain(..) {
            let node_id: ID = id_fn(&node);
            if seen_ids.contains(&node_id) {
                continue;
            }
            seen_ids.insert(node_id);
            for neighbor in neighbors_fn(&node) {
                with_neighbors.insert(node.clone());
                next_frontier.push(neighbor);
            }
            candidates.insert(node);
        }
        frontier = next_frontier;
    }

    candidates
        .into_iter()
        .filter(|node| !with_neighbors.contains(node))
        .collect()
}
205
206/// Find nodes in the start set that are not reachable from other nodes in the
207/// start set.
208pub fn unreachable<T, ID, II, NI>(
209    start: II,
210    neighbors_fn: &impl Fn(&T) -> NI,
211    id_fn: &impl Fn(&T) -> ID,
212) -> HashSet<T>
213where
214    T: Hash + Eq + Clone,
215    ID: Hash + Eq,
216    II: IntoIterator<Item = T>,
217    NI: IntoIterator<Item = T>,
218{
219    let start: Vec<T> = start.into_iter().collect();
220    let mut reachable: HashSet<T> = start.iter().cloned().collect();
221    for _node in bfs(
222        start.into_iter(),
223        Box::new(id_fn),
224        Box::new(|node| {
225            let neighbors: Vec<T> = neighbors_fn(node).into_iter().collect();
226            for neighbor in &neighbors {
227                reachable.remove(&neighbor);
228            }
229            neighbors
230        }),
231    ) {}
232    reachable
233}
234
235pub fn common_ancestor<'a, I1, I2>(set1: I1, set2: I2) -> Commit
236where
237    I1: IntoIterator<Item = &'a Commit>,
238    I2: IntoIterator<Item = &'a Commit>,
239{
240    let set1: Vec<Commit> = set1.into_iter().cloned().collect();
241    let set2: Vec<Commit> = set2.into_iter().cloned().collect();
242    closest_common_node(set1, set2, &|commit| commit.parents(), &|commit| {
243        commit.id().clone()
244    })
245    .unwrap()
246}
247
/// Searches for a node that is reachable from both `set1` and `set2`.
///
/// The two sides are expanded alternately, one level at a time, starting
/// with `set1`. The first node that is examined on one side while its ID has
/// already been visited from the other side is returned. Returns `None` if
/// both sides are exhausted without finding such a node.
///
/// `neighbors_fn` returns the nodes directly reachable from a node and
/// `id_fn` provides the key used to recognise a node.
///
/// Note that when several common nodes exist, the one returned is the first
/// one encountered by this level-by-level search, which is not necessarily
/// the one that minimizes the longest distance to the two sets.
pub fn closest_common_node<T, ID, II1, II2, NI>(
    set1: II1,
    set2: II2,
    neighbors_fn: &impl Fn(&T) -> NI,
    id_fn: &impl Fn(&T) -> ID,
) -> Option<T>
where
    ID: Hash + Eq,
    II1: IntoIterator<Item = T>,
    II2: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    // Index 0 holds the state for `set1`, index 1 the state for `set2`.
    let mut visited: [HashSet<ID>; 2] = [HashSet::new(), HashSet::new()];
    let mut work: [Vec<T>; 2] = [set1.into_iter().collect(), set2.into_iter().collect()];

    while work.iter().any(|level| !level.is_empty()) {
        for side in 0..2 {
            let other = 1 - side;
            let mut next_level = vec![];
            for node in work[side].drain(..) {
                let id: ID = id_fn(&node);
                if visited[other].contains(&id) {
                    return Some(node);
                }
                // Expand a node only the first time it is seen on this side.
                if visited[side].insert(id) {
                    next_level.extend(neighbors_fn(&node));
                }
            }
            work[side] = next_level;
        }
    }
    None
}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// A single chain is returned from the head down to the root.
    #[test]
    fn topo_order_reverse_linear() {
        // Graph under test:
        //  o C
        //  o B
        //  o A

        let graph = hashmap! {
            'A' => vec![],
            'B' => vec!['A'],
            'C' => vec!['B'],
        };

        let order = topo_order_reverse(
            vec!['C'],
            Box::new(|n| *n),
            Box::new(move |n| graph[n].clone()),
        );

        assert_eq!(order, vec!['C', 'B', 'A']);
    }

    /// A merge node comes first, followed by both of its branches.
    #[test]
    fn topo_order_reverse_merge() {
        // Graph under test:
        //  o F
        //  |\
        //  o | E
        //  | o D
        //  | o C
        //  | o B
        //  |/
        //  o A

        let graph = hashmap! {
            'A' => vec![],
            'B' => vec!['A'],
            'C' => vec!['B'],
            'D' => vec!['C'],
            'E' => vec!['A'],
            'F' => vec!['E', 'D'],
        };

        let order = topo_order_reverse(
            vec!['F'],
            Box::new(|n| *n),
            Box::new(move |n| graph[n].clone()),
        );

        assert_eq!(order, vec!['F', 'E', 'D', 'C', 'B', 'A']);
    }

    /// Two start nodes that share an ancestor.
    #[test]
    fn topo_order_reverse_multiple_heads() {
        // Graph under test:
        //  o F
        //  |\
        //  o | E
        //  | o D
        //  | | o C
        //  | | |
        //  | | o B
        //  | |/
        //  |/
        //  o A

        let graph = hashmap! {
            'A' => vec![],
            'B' => vec!['A'],
            'C' => vec!['B'],
            'D' => vec!['A'],
            'E' => vec!['A'],
            'F' => vec!['E', 'D'],
        };

        let order = topo_order_reverse(
            vec!['F', 'C'],
            Box::new(|n| *n),
            Box::new(move |n| graph[n].clone()),
        );

        assert_eq!(order, vec!['F', 'E', 'D', 'C', 'B', 'A']);
    }

    #[test]
    fn closest_common_node_tricky() {
        // Here A is the shortest distance away, but the desired result is B
        // because A is an ancestor of B. In other words, the longest distance
        // is what should be minimized.
        //
        //  E       H
        //  |\     /|
        //  | D   G |
        //  | C   F |
        //   \ \ / /
        //    \ B /
        //     \|/
        //      A

        let graph = hashmap! {
            'A' => vec![],
            'B' => vec!['A'],
            'C' => vec!['B'],
            'D' => vec!['C'],
            'E' => vec!['A','D'],
            'F' => vec!['B'],
            'G' => vec!['F'],
            'H' => vec!['A', 'G'],
        };

        let found = closest_common_node(
            vec!['E'],
            vec!['H'],
            &|n| graph[n].clone(),
            &|n| *n,
        );

        // TODO: fix the implementation to return B
        assert_eq!(found, Some('A'));
    }

    #[test]
    fn unreachable_mixed() {
        // The uppercase letters are the ones in the start set.
        //
        //  D F
        //  |/|
        //  C e
        //  |/
        //  b
        //  |
        //  A

        let graph = hashmap! {
            'A' => vec![],
            'b' => vec!['A'],
            'C' => vec!['b'],
            'D' => vec!['C'],
            'e' => vec!['b'],
            'F' => vec!['C', 'e'],
        };
        let expected: HashSet<char> = vec!['D', 'F'].into_iter().collect();

        let first_result = unreachable(
            vec!['A', 'C', 'D', 'F'],
            &|n| graph[n].clone(),
            &|n| *n,
        );
        assert_eq!(first_result, expected);

        // The result must not depend on the order of the start set.
        let second_result = unreachable(
            vec!['F', 'D', 'C', 'A'],
            &|n| graph[n].clone(),
            &|n| *n,
        );
        assert_eq!(second_result, expected);
    }
}