Skip to main content

lisette_semantics/module_graph/
kahn.rs

1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2
3use super::ModuleId;
4
5pub fn topological_sort(
6    edges: &HashMap<ModuleId, HashSet<ModuleId>>,
7) -> (Vec<ModuleId>, Vec<Vec<ModuleId>>) {
8    let mut in_degree: HashMap<ModuleId, usize> = HashMap::default();
9    let mut order = Vec::new();
10
11    for (module, imports) in edges {
12        in_degree.entry(module.clone()).or_insert(0);
13        for import in imports {
14            *in_degree.entry(import.clone()).or_insert(0) += 1;
15        }
16    }
17
18    let mut queue: Vec<_> = in_degree
19        .iter()
20        .filter(|&(_, deg)| *deg == 0)
21        .map(|(id, _)| id.clone())
22        .collect();
23
24    queue.sort();
25
26    while let Some(module) = queue.pop() {
27        order.push(module.clone());
28
29        if let Some(imports) = edges.get(&module) {
30            for import in imports {
31                if let Some(degree) = in_degree.get_mut(import) {
32                    *degree -= 1;
33                    if *degree == 0 {
34                        queue.push(import.clone());
35                        queue.sort();
36                    }
37                }
38            }
39        }
40    }
41
42    let cycles = if order.len() < edges.len() {
43        find_cycles(edges, &order)
44    } else {
45        vec![]
46    };
47
48    order.reverse();
49
50    (order, cycles)
51}
52
53fn find_cycles(
54    edges: &HashMap<ModuleId, HashSet<ModuleId>>,
55    processed: &[ModuleId],
56) -> Vec<Vec<ModuleId>> {
57    let processed_set: HashSet<_> = processed.iter().collect();
58    let unprocessed: Vec<_> = edges
59        .keys()
60        .filter(|k| !processed_set.contains(k))
61        .collect();
62
63    let mut cycles = Vec::new();
64    let mut visited = HashSet::default();
65
66    for start in unprocessed {
67        if visited.contains(start) {
68            continue;
69        }
70
71        let mut stack = vec![(start, vec![start.clone()])];
72        let mut on_stack: HashSet<ModuleId> = HashSet::default();
73
74        while let Some((node, path)) = stack.pop() {
75            if on_stack.contains(node) {
76                continue;
77            }
78            on_stack.insert(node.clone());
79            visited.insert(node.clone());
80
81            if let Some(imports) = edges.get(node) {
82                for import in imports {
83                    if let Some(position) = path.iter().position(|p| p == import) {
84                        let mut cycle_path: Vec<_> = path[position..].to_vec();
85                        cycle_path.push(import.clone());
86                        cycles.push(cycle_path);
87                    } else if !visited.contains(import) {
88                        let mut new_path = path.clone();
89                        new_path.push(import.clone());
90                        stack.push((import, new_path));
91                    }
92                }
93            }
94        }
95    }
96
97    cycles
98}