1use glob::Pattern;
30use serde::Deserialize;
31use std::collections::HashMap;
32use std::fs;
33use std::path::Path;
34use thiserror::Error;
35
36use crate::metrics::Volatility;
37
38#[derive(Error, Debug)]
40pub enum ConfigError {
41 #[error("Failed to read config file: {0}")]
42 IoError(#[from] std::io::Error),
43
44 #[error("Failed to parse config file: {0}")]
45 ParseError(#[from] toml::de::Error),
46
47 #[error("Invalid glob pattern: {0}")]
48 PatternError(String),
49}
50
51#[derive(Debug, Clone, Deserialize, Default)]
53pub struct VolatilityConfig {
54 #[serde(default)]
56 pub high: Vec<String>,
57
58 #[serde(default)]
60 pub medium: Vec<String>,
61
62 #[serde(default)]
64 pub low: Vec<String>,
65
66 #[serde(default)]
68 pub ignore: Vec<String>,
69}
70
71#[derive(Debug, Clone, Deserialize)]
73pub struct ThresholdsConfig {
74 #[serde(default = "default_max_dependencies")]
76 pub max_dependencies: usize,
77
78 #[serde(default = "default_max_dependents")]
80 pub max_dependents: usize,
81}
82
/// Serde fallback for `ThresholdsConfig::max_dependencies`.
fn default_max_dependencies() -> usize {
    const DEFAULT_MAX_DEPENDENCIES: usize = 15;
    DEFAULT_MAX_DEPENDENCIES
}
86
/// Serde fallback for `ThresholdsConfig::max_dependents`.
fn default_max_dependents() -> usize {
    const DEFAULT_MAX_DEPENDENTS: usize = 20;
    DEFAULT_MAX_DEPENDENTS
}
90
91impl Default for ThresholdsConfig {
92 fn default() -> Self {
93 Self {
94 max_dependencies: default_max_dependencies(),
95 max_dependents: default_max_dependents(),
96 }
97 }
98}
99
100#[derive(Debug, Clone, Deserialize)]
102pub struct AposdConfig {
103 #[serde(default = "default_min_depth_ratio")]
105 pub min_depth_ratio: f64,
106
107 #[serde(default = "default_max_cognitive_load")]
109 pub max_cognitive_load: f64,
110
111 #[serde(default = "default_exclude_rust_idioms")]
113 pub exclude_rust_idioms: bool,
114
115 #[serde(default)]
117 pub exclude_prefixes: Vec<String>,
118
119 #[serde(default)]
121 pub exclude_methods: Vec<String>,
122}
123
/// Serde fallback for `AposdConfig::min_depth_ratio`.
fn default_min_depth_ratio() -> f64 {
    const DEFAULT_MIN_DEPTH_RATIO: f64 = 2.0;
    DEFAULT_MIN_DEPTH_RATIO
}
127
/// Serde fallback for `AposdConfig::max_cognitive_load`.
fn default_max_cognitive_load() -> f64 {
    const DEFAULT_MAX_COGNITIVE_LOAD: f64 = 15.0;
    DEFAULT_MAX_COGNITIVE_LOAD
}
131
/// Serde fallback for `AposdConfig::exclude_rust_idioms`.
fn default_exclude_rust_idioms() -> bool {
    const DEFAULT_EXCLUDE_RUST_IDIOMS: bool = true;
    DEFAULT_EXCLUDE_RUST_IDIOMS
}
135
136impl Default for AposdConfig {
137 fn default() -> Self {
138 Self {
139 min_depth_ratio: default_min_depth_ratio(),
140 max_cognitive_load: default_max_cognitive_load(),
141 exclude_rust_idioms: default_exclude_rust_idioms(),
142 exclude_prefixes: Vec::new(),
143 exclude_methods: Vec::new(),
144 }
145 }
146}
147
148#[derive(Debug, Clone, Deserialize, Default)]
150pub struct CouplingConfig {
151 #[serde(default)]
153 pub volatility: VolatilityConfig,
154
155 #[serde(default)]
157 pub thresholds: ThresholdsConfig,
158
159 #[serde(default)]
161 pub aposd: AposdConfig,
162}
163
164#[derive(Debug)]
166pub struct CompiledConfig {
167 high_patterns: Vec<Pattern>,
169 medium_patterns: Vec<Pattern>,
171 low_patterns: Vec<Pattern>,
173 ignore_patterns: Vec<Pattern>,
175 pub thresholds: ThresholdsConfig,
177 pub aposd: AposdConfig,
179 cache: HashMap<String, Option<Volatility>>,
181}
182
183impl CompiledConfig {
184 pub fn from_config(config: CouplingConfig) -> Result<Self, ConfigError> {
186 let compile_patterns = |patterns: &[String]| -> Result<Vec<Pattern>, ConfigError> {
187 patterns
188 .iter()
189 .map(|p| {
190 Pattern::new(p).map_err(|e| ConfigError::PatternError(format!("{}: {}", p, e)))
191 })
192 .collect()
193 };
194
195 Ok(Self {
196 high_patterns: compile_patterns(&config.volatility.high)?,
197 medium_patterns: compile_patterns(&config.volatility.medium)?,
198 low_patterns: compile_patterns(&config.volatility.low)?,
199 ignore_patterns: compile_patterns(&config.volatility.ignore)?,
200 thresholds: config.thresholds,
201 aposd: config.aposd,
202 cache: HashMap::new(),
203 })
204 }
205
206 pub fn empty() -> Self {
208 Self {
209 high_patterns: Vec::new(),
210 medium_patterns: Vec::new(),
211 low_patterns: Vec::new(),
212 ignore_patterns: Vec::new(),
213 thresholds: ThresholdsConfig::default(),
214 aposd: AposdConfig::default(),
215 cache: HashMap::new(),
216 }
217 }
218
219 pub fn should_ignore(&self, path: &str) -> bool {
221 self.ignore_patterns.iter().any(|p| p.matches(path))
222 }
223
224 pub fn get_volatility_override(&mut self, path: &str) -> Option<Volatility> {
226 if let Some(cached) = self.cache.get(path) {
228 return *cached;
229 }
230
231 let result = if self.high_patterns.iter().any(|p| p.matches(path)) {
233 Some(Volatility::High)
234 } else if self.medium_patterns.iter().any(|p| p.matches(path)) {
235 Some(Volatility::Medium)
236 } else if self.low_patterns.iter().any(|p| p.matches(path)) {
237 Some(Volatility::Low)
238 } else {
239 None
240 };
241
242 self.cache.insert(path.to_string(), result);
244 result
245 }
246
247 pub fn get_volatility(&mut self, path: &str, git_volatility: Volatility) -> Volatility {
249 self.get_volatility_override(path).unwrap_or(git_volatility)
250 }
251
252 pub fn has_volatility_overrides(&self) -> bool {
254 !self.high_patterns.is_empty()
255 || !self.medium_patterns.is_empty()
256 || !self.low_patterns.is_empty()
257 }
258}
259
260pub fn load_config(project_path: &Path) -> Result<CouplingConfig, ConfigError> {
264 let config_path = find_config_file(project_path);
266
267 match config_path {
268 Some(path) => {
269 let content = fs::read_to_string(&path)?;
270 let config: CouplingConfig = toml::from_str(&content)?;
271 Ok(config)
272 }
273 None => Ok(CouplingConfig::default()),
274 }
275}
276
/// Searches `start_path` and each of its ancestor directories for a
/// configuration file.
///
/// In every directory `.coupling.toml` is tried before `coupling.toml`, and
/// the nearest directory containing either one wins. If `start_path` is a
/// file, the search begins in its parent directory.
///
/// The start path is canonicalized first so that relative paths (such as `.`)
/// and paths containing `..` walk up the real directory tree instead of
/// stopping at the lexical root of the relative path. When canonicalization
/// fails (for example because the path does not exist) the path is used as
/// given. As a consequence the returned path is absolute whenever
/// canonicalization succeeded.
///
/// Only regular files are accepted, so a directory that happens to be named
/// like a configuration file is skipped rather than returned.
fn find_config_file(start_path: &Path) -> Option<std::path::PathBuf> {
    const CONFIG_NAMES: [&str; 2] = [".coupling.toml", "coupling.toml"];

    let resolved = fs::canonicalize(start_path).unwrap_or_else(|_| start_path.to_path_buf());
    let start_dir = if resolved.is_file() {
        resolved.parent()?
    } else {
        resolved.as_path()
    };

    // `ancestors` yields `start_dir` itself first, then each parent in turn.
    for dir in start_dir.ancestors() {
        for name in &CONFIG_NAMES {
            let candidate = dir.join(name);
            if candidate.is_file() {
                return Some(candidate);
            }
        }
    }

    None
}
305
306pub fn load_compiled_config(project_path: &Path) -> Result<CompiledConfig, ConfigError> {
308 let config = load_config(project_path)?;
309 CompiledConfig::from_config(config)
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `source` into a `CouplingConfig`, panicking on failure.
    fn parse(source: &str) -> CouplingConfig {
        toml::from_str(source).unwrap()
    }

    /// Deserializes and compiles `source`, panicking if either step fails.
    fn compile(source: &str) -> CompiledConfig {
        CompiledConfig::from_config(parse(source)).unwrap()
    }

    /// The default configuration has no patterns and the stock thresholds.
    #[test]
    fn test_default_config() {
        let defaults = CouplingConfig::default();

        assert!(defaults.volatility.high.is_empty());
        assert!(defaults.volatility.low.is_empty());
        assert_eq!(defaults.thresholds.max_dependencies, 15);
        assert_eq!(defaults.thresholds.max_dependents, 20);
    }

    /// Explicit values in the TOML replace the defaults.
    #[test]
    fn test_parse_config() {
        let parsed = parse(
            r#"
            [volatility]
            high = ["src/api/*", "src/handlers/*"]
            low = ["src/core/*"]
            ignore = ["tests/*"]

            [thresholds]
            max_dependencies = 20
            max_dependents = 30
        "#,
        );

        assert_eq!(parsed.volatility.high.len(), 2);
        assert_eq!(parsed.volatility.low.len(), 1);
        assert_eq!(parsed.volatility.ignore.len(), 1);
        assert_eq!(parsed.thresholds.max_dependencies, 20);
        assert_eq!(parsed.thresholds.max_dependents, 30);
    }

    /// Compiled patterns map matching paths to their configured level.
    #[test]
    fn test_compiled_config() {
        let mut compiled = compile(
            r#"
            [volatility]
            high = ["src/business/*"]
            low = ["src/core/*"]
        "#,
        );

        let business = compiled.get_volatility_override("src/business/pricing.rs");
        assert_eq!(business, Some(Volatility::High));

        let core = compiled.get_volatility_override("src/core/types.rs");
        assert_eq!(core, Some(Volatility::Low));

        let other = compiled.get_volatility_override("src/other/file.rs");
        assert_eq!(other, None);
    }

    /// Only paths matching an ignore glob are reported as ignored.
    #[test]
    fn test_ignore_patterns() {
        let compiled = compile(
            r#"
            [volatility]
            ignore = ["tests/*", "benches/*"]
        "#,
        );

        assert!(compiled.should_ignore("tests/integration.rs"));
        assert!(compiled.should_ignore("benches/perf.rs"));
        assert!(!compiled.should_ignore("src/lib.rs"));
    }

    /// An override wins over the git-derived value; otherwise the
    /// git-derived value is returned unchanged.
    #[test]
    fn test_get_volatility_with_fallback() {
        let mut compiled = compile(
            r#"
            [volatility]
            high = ["src/api/*"]
        "#,
        );

        let overridden = compiled.get_volatility("src/api/handler.rs", Volatility::Low);
        assert_eq!(overridden, Volatility::High);

        let fallback = compiled.get_volatility("src/other/file.rs", Volatility::Medium);
        assert_eq!(fallback, Volatility::Medium);
    }
}