1use std::collections::HashSet;
16use std::hash::Hash;
17use std::iter::Iterator;
18
/// Lazy graph traversal that yields every node reachable from a set of start
/// nodes, at most once per ID.
///
/// Nodes are deduplicated by the value returned from `id_fn`, and the graph
/// shape is supplied by `neighbors_fn`. Note that pending nodes are kept on a
/// stack (`Vec::push` / `Vec::pop`), so the most recently discovered node is
/// the next one examined.
pub struct BfsIter<'id_fn, 'neighbors_fn, T, ID, NI> {
    /// Maps a node to the ID used for deduplication.
    id_fn: Box<dyn Fn(&T) -> ID + 'id_fn>,
    /// Produces the neighbors of a node; called once per yielded node.
    neighbors_fn: Box<dyn FnMut(&T) -> NI + 'neighbors_fn>,
    /// Nodes discovered but not yet examined (used as a LIFO stack).
    work: Vec<T>,
    /// IDs of the nodes that have already been yielded.
    visited: HashSet<ID>,
}

impl<T, ID, NI> Iterator for BfsIter<'_, '_, T, ID, NI>
where
    ID: Hash + Eq,
    NI: IntoIterator<Item = T>,
{
    type Item = T;

    /// Pops pending nodes until one with an unseen ID is found, queues that
    /// node's neighbors, records its ID and returns it. Returns `None` once
    /// the work list is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(node) = self.work.pop() {
            let id = (self.id_fn)(&node);
            if !self.visited.contains(&id) {
                // Neighbors are queued before the ID is recorded, matching
                // the order in which the callbacks have always been invoked.
                self.work.extend((self.neighbors_fn)(&node));
                self.visited.insert(id);
                return Some(node);
            }
        }
        None
    }
}
48
49pub fn bfs<'id_fn, 'neighbors_fn, T, ID, II, NI>(
50 start: II,
51 id_fn: Box<dyn Fn(&T) -> ID + 'id_fn>,
52 neighbors_fn: Box<dyn FnMut(&T) -> NI + 'neighbors_fn>,
53) -> BfsIter<'id_fn, 'neighbors_fn, T, ID, NI>
54where
55 ID: Hash + Eq,
56 II: IntoIterator<Item = T>,
57 NI: IntoIterator<Item = T>,
58{
59 BfsIter {
60 id_fn,
61 neighbors_fn,
62 work: start.into_iter().collect(),
63 visited: Default::default(),
64 }
65}
66
/// Returns the nodes reachable from `start` ordered so that every node comes
/// before all of the nodes reachable from it through `neighbors_fn`.
///
/// The walk is an iterative depth-first search: each node is emitted in
/// post-order (after all of its neighbors), and the collected list is
/// reversed at the end. Nodes are identified by `id_fn`, so a node reachable
/// along several paths appears only once.
///
/// # Panics
///
/// Panics with "graph has cycle" if a node is reached again while it is still
/// being expanded.
pub fn topo_order_reverse<'a, T, ID, II, NI>(
    start: II,
    id_fn: Box<dyn Fn(&T) -> ID + 'a>,
    mut neighbors_fn: Box<dyn FnMut(&T) -> NI + 'a>,
) -> Vec<T>
where
    T: Hash + Eq + Clone,
    ID: Hash + Eq + Clone,
    II: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    // IDs of nodes whose neighbors are currently being processed.
    let mut in_progress = HashSet::new();
    // IDs of nodes already appended to `postorder`.
    let mut done = HashSet::new();
    let mut postorder = Vec::new();

    // Start nodes are processed last-to-first so that, after the final
    // reversal, earlier start nodes come out earlier.
    let roots: Vec<T> = start.into_iter().collect();
    for root in roots.into_iter().rev() {
        // Each entry is (node, whether its neighbors were already pushed).
        let mut pending = vec![(root, false)];
        while let Some((node, expanded)) = pending.pop() {
            let id = id_fn(&node);
            if done.contains(&id) {
                continue;
            }
            if expanded {
                // All neighbors have been handled; the node can be emitted.
                in_progress.remove(&id);
                done.insert(id);
                postorder.push(node);
                continue;
            }
            assert!(in_progress.insert(id.clone()), "graph has cycle");
            let neighbors = neighbors_fn(&node);
            // Revisit the node itself once everything pushed above it is done.
            pending.push((node, true));
            pending.extend(neighbors.into_iter().map(|neighbor| (neighbor, false)));
        }
    }
    postorder.reverse();
    postorder
}
110
/// Returns the nodes reachable from `start` (including the start nodes
/// themselves) for which `neighbors_fn` yields no neighbors.
///
/// * `start` - the nodes to begin walking from.
/// * `neighbors_fn` - returns the neighbors of a node; called once for each
///   distinct ID encountered.
/// * `id_fn` - maps a node to the ID used to avoid expanding a node twice.
///
/// The graph is explored one level at a time. Each expanded node is moved
/// into exactly one of two sets depending on whether it produced neighbors,
/// so no node is cloned inside the traversal loop.
pub fn leaves<T, ID, II, NI>(
    start: II,
    neighbors_fn: &mut impl FnMut(&T) -> NI,
    id_fn: &impl Fn(&T) -> ID,
) -> HashSet<T>
where
    T: Hash + Eq + Clone,
    ID: Hash + Eq,
    II: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    let mut visited = HashSet::new();
    let mut work: Vec<T> = start.into_iter().collect();
    // Start nodes are leaf candidates until they are shown to have neighbors.
    let mut leaves: HashSet<T> = work.iter().cloned().collect();
    let mut non_leaves = HashSet::new();
    while !work.is_empty() {
        let mut new_work: Vec<T> = Vec::new();
        for c in work {
            let id: ID = id_fn(&c);
            if visited.contains(&id) {
                continue;
            }
            let queued_before = new_work.len();
            new_work.extend(neighbors_fn(&c));
            visited.insert(id);
            if new_work.len() > queued_before {
                // The node produced at least one neighbor, so it is not a leaf.
                non_leaves.insert(c);
            } else {
                leaves.insert(c);
            }
        }
        work = new_work;
    }
    // Drops start nodes that turned out to have neighbors.
    leaves.difference(&non_leaves).cloned().collect()
}
145
146pub fn heads<T, ID, II, NI>(
149 start: II,
150 neighbors_fn: &impl Fn(&T) -> NI,
151 id_fn: &impl Fn(&T) -> ID,
152) -> HashSet<T>
153where
154 T: Hash + Eq + Clone,
155 ID: Hash + Eq,
156 II: IntoIterator<Item = T>,
157 NI: IntoIterator<Item = T>,
158{
159 let start: Vec<T> = start.into_iter().collect();
160 let mut reachable: HashSet<T> = start.iter().cloned().collect();
161 for _node in bfs(
162 start.into_iter(),
163 Box::new(id_fn),
164 Box::new(|node| {
165 let neighbors: Vec<T> = neighbors_fn(node).into_iter().collect();
166 for neighbor in &neighbors {
167 reachable.remove(neighbor);
168 }
169 neighbors
170 }),
171 ) {}
172 reachable
173}
174
/// Searches outward from `set1` and `set2` simultaneously and returns the
/// first node found from one side whose ID was already visited from the
/// other side, or `None` if both searches run out of nodes first.
///
/// * `set1`, `set2` - the two groups of start nodes.
/// * `neighbors_fn` - returns the neighbors of a node.
/// * `id_fn` - maps a node to the ID used to compare nodes across the two
///   searches.
///
/// The two searches advance alternately, one level each, with the `set1`
/// side going first in every round.
pub fn closest_common_node<T, ID, II1, II2, NI>(
    set1: II1,
    set2: II2,
    neighbors_fn: &impl Fn(&T) -> NI,
    id_fn: &impl Fn(&T) -> ID,
) -> Option<T>
where
    T: Hash + Eq + Clone,
    ID: Hash + Eq,
    II1: IntoIterator<Item = T>,
    II2: IntoIterator<Item = T>,
    NI: IntoIterator<Item = T>,
{
    /// Processes one level of one side's search. Returns `Err(node)` as soon
    /// as a node already seen by the other side is met, otherwise `Ok` with
    /// the next level's nodes.
    fn expand_level<T, ID, NI>(
        frontier: Vec<T>,
        seen_here: &mut HashSet<ID>,
        seen_there: &HashSet<ID>,
        neighbors_fn: &impl Fn(&T) -> NI,
        id_fn: &impl Fn(&T) -> ID,
    ) -> Result<Vec<T>, T>
    where
        ID: Hash + Eq,
        NI: IntoIterator<Item = T>,
    {
        let mut next_level = Vec::new();
        for node in frontier {
            let id = id_fn(&node);
            if seen_there.contains(&id) {
                return Err(node);
            }
            // Only expand a node the first time this side encounters it.
            if seen_here.insert(id) {
                next_level.extend(neighbors_fn(&node));
            }
        }
        Ok(next_level)
    }

    let mut seen1 = HashSet::new();
    let mut seen2 = HashSet::new();
    let mut frontier1: Vec<T> = set1.into_iter().collect();
    let mut frontier2: Vec<T> = set2.into_iter().collect();

    loop {
        if frontier1.is_empty() && frontier2.is_empty() {
            return None;
        }
        frontier1 = match expand_level(frontier1, &mut seen1, &seen2, neighbors_fn, id_fn) {
            Ok(next_level) => next_level,
            Err(common) => return Some(common),
        };
        frontier2 = match expand_level(frontier2, &mut seen2, &seen1, neighbors_fn, id_fn) {
            Ok(next_level) => next_level,
            Err(common) => return Some(common),
        };
    }
}
224
#[cfg(test)]
mod tests {
    use maplit::{hashmap, hashset};

    use super::*;

    /// A single chain: C -> B -> A.
    #[test]
    fn test_topo_order_reverse_linear() {
        let edges = hashmap! {
            'A' => vec![],
            'B' => vec!['A'],
            'C' => vec!['B'],
        };

        let order = topo_order_reverse(
            vec!['C'],
            Box::new(|n| *n),
            Box::new(move |n| edges[n].clone()),
        );

        assert_eq!(order, vec!['C', 'B', 'A']);
    }

    /// F has two neighbors: E (which leads straight to A) and D (which leads
    /// to A through C and B).
    #[test]
    fn test_topo_order_reverse_merge() {
        let edges = hashmap! {
            'A' => vec![],
            'B' => vec!['A'],
            'C' => vec!['B'],
            'D' => vec!['C'],
            'E' => vec!['A'],
            'F' => vec!['E', 'D'],
        };

        let order = topo_order_reverse(
            vec!['F'],
            Box::new(|n| *n),
            Box::new(move |n| edges[n].clone()),
        );

        assert_eq!(order, vec!['F', 'E', 'D', 'C', 'B', 'A']);
    }

    /// Two start nodes, F and C, whose walks only meet at A.
    #[test]
    fn test_topo_order_reverse_multiple_heads() {
        let edges = hashmap! {
            'A' => vec![],
            'B' => vec!['A'],
            'C' => vec!['B'],
            'D' => vec!['A'],
            'E' => vec!['A'],
            'F' => vec!['E', 'D'],
        };

        let order = topo_order_reverse(
            vec!['F', 'C'],
            Box::new(|n| *n),
            Box::new(move |n| edges[n].clone()),
        );

        assert_eq!(order, vec!['F', 'E', 'D', 'C', 'B', 'A']);
    }

    /// E and H both have A as a direct neighbor, and both also reach B over
    /// longer paths (E via D and C, H via G and F).
    #[test]
    fn test_closest_common_node_tricky() {
        let edges = hashmap! {
            'A' => vec![],
            'B' => vec!['A'],
            'C' => vec!['B'],
            'D' => vec!['C'],
            'E' => vec!['A','D'],
            'F' => vec!['B'],
            'G' => vec!['F'],
            'H' => vec!['A', 'G'],
        };

        let found = closest_common_node(
            vec!['E'],
            vec!['H'],
            &|n| edges[n].clone(),
            &|n| *n,
        );

        assert_eq!(found, Some('A'));
    }

    /// Only the upper-case nodes are passed as candidates; the lower-case
    /// nodes ('b', 'e') exist in the graph but are not part of the input.
    #[test]
    fn test_heads_mixed() {
        let edges = hashmap! {
            'A' => vec![],
            'b' => vec!['A'],
            'C' => vec!['b'],
            'D' => vec!['C'],
            'e' => vec!['b'],
            'F' => vec!['C', 'e'],
        };

        let result = heads(
            vec!['A', 'C', 'D', 'F'],
            &|n| edges[n].clone(),
            &|n| *n,
        );
        assert_eq!(result, hashset!['D', 'F']);

        // Same candidates supplied in the opposite order.
        let result = heads(
            vec!['F', 'D', 'C', 'A'],
            &|n| edges[n].clone(),
            &|n| *n,
        );
        assert_eq!(result, hashset!['D', 'F']);
    }
}