1use std::collections::{BTreeMap, HashMap, HashSet};
2
3use super::Module;
4
/// Per-module summary produced by layer inference.
///
/// `PartialEq` is derived so values can be compared directly (for example in
/// `assert_eq!`); `Eq` is not, because `instability` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerInfo {
    /// Name of the module this entry describes.
    pub name: String,
    /// Position of the module once all modules are sorted by ascending
    /// `instability` (0 is the lowest instability).
    pub level: usize,
    /// Number of files that belong to the module.
    pub file_count: usize,
    /// Number of distinct other modules that import this module.
    pub fan_in: usize,
    /// Number of distinct other modules that this module imports.
    pub fan_out: usize,
    /// `fan_out / (fan_in + fan_out)`, or `0.5` when both counts are zero.
    pub instability: f64,
}
15
/// A module-level import whose importing module has a strictly lower level
/// than the module it imports.
///
/// `PartialEq`/`Eq` are derived so violations can be compared directly
/// (for example in `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LayerViolation {
    /// Name of the importing module.
    pub from_module: String,
    /// Name of the imported module.
    pub to_module: String,
    /// Level of the importing module.
    pub from_level: usize,
    /// Level of the imported module.
    pub to_level: usize,
}
24
25pub fn infer_layers(
28 modules: &[Module],
29 file_imports: &[(String, String)],
30) -> (Vec<LayerInfo>, Vec<LayerViolation>) {
31 let file_to_mod: HashMap<&str, &str> = modules
32 .iter()
33 .flat_map(|m| m.files.iter().map(|f| (f.as_str(), m.name.as_str())))
34 .collect();
35
36 let (fan_in, fan_out) = compute_module_fans(&file_to_mod, file_imports);
37 let mut layers = build_layers(modules, &fan_in, &fan_out);
38
39 layers.sort_by(|a, b| a.instability.partial_cmp(&b.instability).unwrap());
40 for (i, l) in layers.iter_mut().enumerate() {
41 l.level = i;
42 }
43
44 let violations = detect_violations(&layers, &file_to_mod, file_imports);
45 (layers, violations)
46}
47
/// Maps a module name to a set of related module names.
type FanMap<'a> = HashMap<&'a str, HashSet<&'a str>>;

/// Lifts file-level import pairs to module-level edges.
///
/// Returns `(fan_in, fan_out)`: `fan_out[m]` holds the distinct modules that
/// `m` imports and `fan_in[m]` the distinct modules that import `m`. A pair is
/// skipped when either file is missing from `file_to_mod`, when either
/// resolved module name is empty, or when both files belong to the same
/// module.
fn compute_module_fans<'a>(
    file_to_mod: &HashMap<&'a str, &'a str>,
    file_imports: &[(String, String)],
) -> (FanMap<'a>, FanMap<'a>) {
    // Owning module of a file path; `None` for unmapped paths and for
    // modules whose name is the empty string.
    let module_of = |path: &str| -> Option<&'a str> {
        file_to_mod
            .get(path)
            .copied()
            .filter(|name| !name.is_empty())
    };

    let mut fan_in: FanMap<'a> = HashMap::new();
    let mut fan_out: FanMap<'a> = HashMap::new();

    for (src_file, dst_file) in file_imports {
        let (src_mod, dst_mod) = match (
            module_of(src_file.as_str()),
            module_of(dst_file.as_str()),
        ) {
            (Some(src), Some(dst)) if src != dst => (src, dst),
            _ => continue,
        };
        fan_out.entry(src_mod).or_default().insert(dst_mod);
        fan_in.entry(dst_mod).or_default().insert(src_mod);
    }

    (fan_in, fan_out)
}
66
67fn build_layers(
68 modules: &[Module],
69 fan_in: &HashMap<&str, HashSet<&str>>,
70 fan_out: &HashMap<&str, HashSet<&str>>,
71) -> Vec<LayerInfo> {
72 modules
73 .iter()
74 .map(|m| {
75 let fi = fan_in.get(m.name.as_str()).map(|s| s.len()).unwrap_or(0);
76 let fo = fan_out.get(m.name.as_str()).map(|s| s.len()).unwrap_or(0);
77 let total = fi + fo;
78 LayerInfo {
79 name: m.name.clone(),
80 level: 0,
81 file_count: m.files.len(),
82 fan_in: fi,
83 fan_out: fo,
84 instability: if total > 0 {
85 fo as f64 / total as f64
86 } else {
87 0.5
88 },
89 }
90 })
91 .collect()
92}
93
94fn detect_violations(
95 layers: &[LayerInfo],
96 file_to_mod: &HashMap<&str, &str>,
97 file_imports: &[(String, String)],
98) -> Vec<LayerViolation> {
99 let level_map: BTreeMap<&str, usize> =
100 layers.iter().map(|l| (l.name.as_str(), l.level)).collect();
101
102 let mut seen: BTreeMap<(&str, &str), (usize, usize)> = BTreeMap::new();
103 for (from, to) in file_imports {
104 let fm = file_to_mod.get(from.as_str()).copied().unwrap_or("");
105 let tm = file_to_mod.get(to.as_str()).copied().unwrap_or("");
106 if fm == tm || fm.is_empty() || tm.is_empty() {
107 continue;
108 }
109 if let (Some(&fl), Some(&tl)) = (level_map.get(fm), level_map.get(tm))
110 && fl < tl
111 {
112 seen.entry((fm, tm)).or_insert((fl, tl));
113 }
114 }
115
116 seen.into_iter()
117 .map(|((from, to), (fl, tl))| LayerViolation {
118 from_module: from.to_string(),
119 to_module: to.to_string(),
120 from_level: fl,
121 to_level: tl,
122 })
123 .collect()
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a module that owns exactly one file.
    fn module(name: &str, file: &str) -> Module {
        Module {
            name: name.into(),
            files: vec![file.into()],
        }
    }

    /// `core` imports `ui`, so `core` has fan-out 1 (instability 1.0) and
    /// `ui` has fan-in 1 (instability 0.0). `core` therefore lands on the
    /// higher level, and its import points from level 1 to level 0, which is
    /// not reported.
    #[test]
    fn import_from_higher_level_to_lower_level_is_not_violation() {
        let modules = vec![module("core", "core/a.rs"), module("ui", "ui/b.rs")];
        let imports: Vec<(String, String)> = vec![("core/a.rs".into(), "ui/b.rs".into())];

        let (layers, violations) = infer_layers(&modules, &imports);

        assert_eq!(layers.len(), 2);
        assert_eq!(layers[0].name, "ui");
        assert_eq!(layers[1].name, "core");
        assert!(violations.is_empty());
    }

    /// Instabilities: net = 0.0, log = 0.0, core = 1/2, helper = 2/3,
    /// app = 1.0. `core` (level 2) imports `helper` (level 3), which is the
    /// only import that points from a lower level to a higher one.
    #[test]
    fn import_from_lower_level_to_higher_level_is_violation() {
        let modules = vec![
            module("app", "app/main.rs"),
            module("core", "core/lib.rs"),
            module("helper", "helper/lib.rs"),
            module("net", "net/lib.rs"),
            module("log", "log/lib.rs"),
        ];
        let imports: Vec<(String, String)> = vec![
            ("app/main.rs".into(), "core/lib.rs".into()),
            ("core/lib.rs".into(), "helper/lib.rs".into()),
            ("helper/lib.rs".into(), "net/lib.rs".into()),
            ("helper/lib.rs".into(), "log/lib.rs".into()),
        ];

        let (layers, violations) = infer_layers(&modules, &imports);

        // Ties (net/log) keep their input order.
        let order: Vec<&str> = layers.iter().map(|l| l.name.as_str()).collect();
        assert_eq!(order, ["net", "log", "core", "helper", "app"]);

        assert_eq!(violations.len(), 1);
        let v = &violations[0];
        assert_eq!(v.from_module, "core");
        assert_eq!(v.to_module, "helper");
        assert_eq!((v.from_level, v.to_level), (2, 3));
    }
}