Skip to main content

dk_runner/steps/semantic/
quality.rs

1use std::collections::{HashMap, HashSet};
2
3use regex::Regex;
4
5use dk_core::types::{SymbolId, Visibility, SymbolKind};
6
7use crate::findings::{Finding, Severity};
8
9use super::checks::{CheckContext, SemanticCheck};
10
11// ─── complexity-limit ────────────────────────────────────────────────────
12
13/// Counts nesting depth of branching constructs (if, else, match, for,
14/// while, loop) in changed files and flags functions that exceed a
15/// configurable threshold.
16pub struct ComplexityLimit {
17    threshold: usize,
18    branch_re: Regex,
19}
20
21impl ComplexityLimit {
22    pub fn new() -> Self {
23        Self::with_threshold(10)
24    }
25
26    pub fn with_threshold(threshold: usize) -> Self {
27        Self {
28            threshold,
29            // Match branching keywords at word boundaries. We count occurrences
30            // per file as a rough proxy for McCabe-like complexity.
31            branch_re: Regex::new(
32                r"\b(if|else|match|for|while|loop)\b"
33            )
34            .expect("invalid regex"),
35        }
36    }
37}
38
39impl SemanticCheck for ComplexityLimit {
40    fn name(&self) -> &str {
41        "complexity-limit"
42    }
43
44    fn run(&self, ctx: &CheckContext) -> Vec<Finding> {
45        let mut findings = Vec::new();
46        let fn_re = Regex::new(r"\b(pub\s+)?(async\s+)?fn\s+(\w+)").expect("invalid fn regex");
47
48        for file in &ctx.changed_files {
49            let content = match &file.content {
50                Some(c) => c,
51                None => continue,
52            };
53
54            // Track per-function complexity by detecting `fn` declarations
55            // and measuring branching depth within each function's scope.
56            let mut current_fn: Option<(String, usize)> = None; // (name, start_line)
57            let mut fn_depth: usize = 0;       // brace depth within current function
58            let mut branch_depth: usize = 0;   // branching nesting within current function
59            let mut max_branch_depth: usize = 0;
60            let mut max_branch_line: usize = 0;
61            let mut in_function = false;
62
63            for (line_idx, line) in content.lines().enumerate() {
64                let trimmed = line.trim();
65
66                // Detect function start
67                if let Some(caps) = fn_re.captures(trimmed) {
68                    if !in_function {
69                        let fn_name = caps.get(3).map(|m| m.as_str().to_string()).unwrap_or_default();
70                        current_fn = Some((fn_name, line_idx + 1));
71                        fn_depth = 0;
72                        branch_depth = 0;
73                        max_branch_depth = 0;
74                        max_branch_line = line_idx + 1;
75                        in_function = true;
76                    }
77                }
78
79                if in_function {
80                    // Count opening braces
81                    fn_depth += trimmed.matches('{').count();
82
83                    // Track branching keywords for complexity
84                    if self.branch_re.is_match(trimmed) {
85                        branch_depth += 1;
86                        if branch_depth > max_branch_depth {
87                            max_branch_depth = branch_depth;
88                            max_branch_line = line_idx + 1;
89                        }
90                    }
91
92                    // Count closing braces
93                    let close_count = trimmed.matches('}').count();
94                    if close_count > 0 {
95                        // Reduce branch depth for closing braces (heuristic)
96                        if branch_depth > 0 && trimmed.starts_with('}') {
97                            branch_depth = branch_depth.saturating_sub(1);
98                        }
99                    }
100                    fn_depth = fn_depth.saturating_sub(close_count);
101
102                    // Function ended when brace depth returns to 0
103                    if fn_depth == 0 && current_fn.is_some() {
104                        if max_branch_depth > self.threshold {
105                            let (fn_name, fn_start) = current_fn.as_ref().unwrap();
106                            findings.push(Finding {
107                                severity: Severity::Warning,
108                                check_name: self.name().to_string(),
109                                message: format!(
110                                    "function '{}' (line {}) has branching complexity {} exceeding threshold {} (deepest near line {})",
111                                    fn_name, fn_start, max_branch_depth, self.threshold, max_branch_line
112                                ),
113                                file_path: Some(file.path.clone()),
114                                line: Some(max_branch_line as u32),
115                                symbol: Some(fn_name.clone()),
116                            });
117                        }
118                        current_fn = None;
119                        in_function = false;
120                    }
121                }
122            }
123        }
124
125        findings
126    }
127}
128
129// ─── no-dependency-cycles ────────────────────────────────────────────────
130
/// Performs DFS cycle detection on the call graph and flags any cycles found.
#[derive(Debug, Default)]
pub struct NoDependencyCycles;

impl NoDependencyCycles {
    /// Creates the check; it has no configuration, so this is equivalent to
    /// `NoDependencyCycles::default()`.
    pub fn new() -> Self {
        Self
    }
}
139
140impl SemanticCheck for NoDependencyCycles {
141    fn name(&self) -> &str {
142        "no-dependency-cycles"
143    }
144
145    fn run(&self, ctx: &CheckContext) -> Vec<Finding> {
146        let mut findings = Vec::new();
147
148        // Build adjacency list from the call graph.
149        let mut adjacency: HashMap<SymbolId, Vec<SymbolId>> = HashMap::new();
150        for edge in &ctx.before_call_graph {
151            adjacency.entry(edge.caller).or_default().push(edge.callee);
152        }
153        // Also include after call graph if available.
154        for edge in &ctx.after_call_graph {
155            adjacency.entry(edge.caller).or_default().push(edge.callee);
156        }
157
158        if adjacency.is_empty() {
159            return findings;
160        }
161
162        // DFS cycle detection.
163        let nodes: Vec<SymbolId> = adjacency.keys().copied().collect();
164        let mut visited: HashSet<SymbolId> = HashSet::new();
165        let mut on_stack: HashSet<SymbolId> = HashSet::new();
166        let mut cycles: Vec<Vec<SymbolId>> = Vec::new();
167
168        for &node in &nodes {
169            if !visited.contains(&node) {
170                let mut path = Vec::new();
171                dfs_detect_cycle(
172                    node,
173                    &adjacency,
174                    &mut visited,
175                    &mut on_stack,
176                    &mut path,
177                    &mut cycles,
178                );
179            }
180        }
181
182        for cycle in &cycles {
183            let cycle_ids: Vec<String> = cycle.iter().map(|id| id.to_string()).collect();
184            findings.push(Finding {
185                severity: Severity::Error,
186                check_name: self.name().to_string(),
187                message: format!(
188                    "dependency cycle detected involving {} symbol(s): {}",
189                    cycle.len(),
190                    cycle_ids.join(" -> ")
191                ),
192                file_path: None,
193                line: None,
194                symbol: None,
195            });
196        }
197
198        findings
199    }
200}
201
202fn dfs_detect_cycle(
203    node: SymbolId,
204    adj: &HashMap<SymbolId, Vec<SymbolId>>,
205    visited: &mut HashSet<SymbolId>,
206    on_stack: &mut HashSet<SymbolId>,
207    path: &mut Vec<SymbolId>,
208    cycles: &mut Vec<Vec<SymbolId>>,
209) {
210    visited.insert(node);
211    on_stack.insert(node);
212    path.push(node);
213
214    if let Some(neighbors) = adj.get(&node) {
215        for &next in neighbors {
216            if !visited.contains(&next) {
217                dfs_detect_cycle(next, adj, visited, on_stack, path, cycles);
218            } else if on_stack.contains(&next) {
219                // Found a cycle: extract the cycle from the path.
220                if let Some(pos) = path.iter().position(|&n| n == next) {
221                    let cycle: Vec<SymbolId> = path[pos..].to_vec();
222                    cycles.push(cycle);
223                }
224            }
225        }
226    }
227
228    path.pop();
229    on_stack.remove(&node);
230}
231
232// ─── dead-code-detection ─────────────────────────────────────────────────
233
/// Detects private functions with zero incoming calls in the call graph.
#[derive(Debug, Default)]
pub struct DeadCodeDetection;

impl DeadCodeDetection {
    /// Creates the check; it has no configuration, so this is equivalent to
    /// `DeadCodeDetection::default()`.
    pub fn new() -> Self {
        Self
    }
}
242
243impl SemanticCheck for DeadCodeDetection {
244    fn name(&self) -> &str {
245        "dead-code-detection"
246    }
247
248    fn run(&self, ctx: &CheckContext) -> Vec<Finding> {
249        let mut findings = Vec::new();
250
251        // Collect all callee symbol IDs from the call graph.
252        let mut called_symbols: HashSet<SymbolId> = HashSet::new();
253        for edge in &ctx.before_call_graph {
254            called_symbols.insert(edge.callee);
255        }
256        for edge in &ctx.after_call_graph {
257            called_symbols.insert(edge.callee);
258        }
259
260        // Check after_symbols (current state) for private functions with zero callers.
261        for sym in &ctx.after_symbols {
262            if sym.kind != SymbolKind::Function {
263                continue;
264            }
265            if sym.visibility != Visibility::Private {
266                continue;
267            }
268            // Skip "main" functions and test helpers.
269            if sym.name == "main" || sym.name.starts_with("test") {
270                continue;
271            }
272
273            if !called_symbols.contains(&sym.id) {
274                findings.push(Finding {
275                    severity: Severity::Info,
276                    check_name: self.name().to_string(),
277                    message: format!(
278                        "private function '{}' has no callers and may be dead code",
279                        sym.qualified_name
280                    ),
281                    file_path: Some(sym.file_path.to_string_lossy().to_string()),
282                    line: None,
283                    symbol: Some(sym.qualified_name.clone()),
284                });
285            }
286        }
287
288        findings
289    }
290}
291
292/// Returns all 3 quality checks.
293pub fn quality_checks() -> Vec<Box<dyn SemanticCheck>> {
294    vec![
295        Box::new(ComplexityLimit::new()),
296        Box::new(NoDependencyCycles::new()),
297        Box::new(DeadCodeDetection::new()),
298    ]
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::checks::{ChangedFile, CheckContext};
    use dk_core::types::{CallEdge, CallKind, Span, Symbol};
    use uuid::Uuid;

    /// Context with every collection empty; each test adds only what it needs.
    fn empty_context() -> CheckContext {
        CheckContext {
            before_symbols: Vec::new(),
            after_symbols: Vec::new(),
            before_call_graph: Vec::new(),
            after_call_graph: Vec::new(),
            before_deps: Vec::new(),
            after_deps: Vec::new(),
            changed_files: Vec::new(),
        }
    }

    /// Function symbol in `src/lib.rs` with a fresh id; its short name is the
    /// last `::` segment of `name`.
    fn make_fn(name: &str, vis: Visibility) -> Symbol {
        let short_name = name.split("::").last().unwrap_or(name);
        Symbol {
            id: Uuid::new_v4(),
            name: short_name.into(),
            qualified_name: name.into(),
            kind: SymbolKind::Function,
            visibility: vis,
            file_path: "src/lib.rs".into(),
            span: Span { start_byte: 0, end_byte: 100 },
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Direct-call edge `caller -> callee` in repository `repo_id`.
    fn direct_call(repo_id: Uuid, caller: Uuid, callee: Uuid) -> CallEdge {
        CallEdge {
            id: Uuid::new_v4(),
            repo_id,
            caller,
            callee,
            kind: CallKind::DirectCall,
        }
    }

    /// Context whose only changed file is `src/main.rs` containing `source`.
    fn context_with_main_rs(source: &str) -> CheckContext {
        let mut ctx = empty_context();
        ctx.changed_files.push(ChangedFile {
            path: "src/main.rs".into(),
            content: Some(source.into()),
        });
        ctx
    }

    #[test]
    fn test_complexity_under_threshold() {
        let ctx = context_with_main_rs("fn simple() {\n    if true {\n        return;\n    }\n}");
        let findings = ComplexityLimit::with_threshold(10).run(&ctx);
        assert!(findings.is_empty());
    }

    #[test]
    fn test_complexity_over_threshold() {
        // Three nested `if`s inside one function, checked against a threshold of 2.
        let code = "\
fn complex() {
    if a {
        if b {
            if c {
                x
            }
        }
    }
}";
        let ctx = context_with_main_rs(code);
        let findings = ComplexityLimit::with_threshold(2).run(&ctx);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].severity, Severity::Warning);
        assert!(findings[0].symbol.is_some(), "should report the function name");
    }

    #[test]
    fn test_no_dependency_cycles_clean() {
        let repo_id = Uuid::new_v4();
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());

        let mut ctx = empty_context();
        ctx.before_call_graph.push(direct_call(repo_id, a, b));

        assert!(NoDependencyCycles::new().run(&ctx).is_empty());
    }

    #[test]
    fn test_no_dependency_cycles_detects_cycle() {
        let repo_id = Uuid::new_v4();
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());

        let mut ctx = empty_context();
        ctx.before_call_graph.push(direct_call(repo_id, a, b));
        ctx.before_call_graph.push(direct_call(repo_id, b, a));

        let findings = NoDependencyCycles::new().run(&ctx);
        assert!(!findings.is_empty());
        assert_eq!(findings[0].severity, Severity::Error);
    }

    #[test]
    fn test_dead_code_detects_uncalled_private() {
        // A private helper that no call edge points to.
        let mut ctx = empty_context();
        ctx.after_symbols.push(make_fn("crate::helper", Visibility::Private));

        let findings = DeadCodeDetection::new().run(&ctx);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].severity, Severity::Info);
    }

    #[test]
    fn test_dead_code_ignores_public() {
        let mut ctx = empty_context();
        ctx.after_symbols.push(make_fn("crate::api_handler", Visibility::Public));

        assert!(DeadCodeDetection::new().run(&ctx).is_empty());
    }

    #[test]
    fn test_dead_code_ignores_called_private() {
        let repo_id = Uuid::new_v4();
        let helper = make_fn("crate::helper", Visibility::Private);
        let helper_id = helper.id;

        let mut ctx = empty_context();
        ctx.after_symbols.push(helper);
        // Some other function calls the helper.
        ctx.before_call_graph.push(direct_call(repo_id, Uuid::new_v4(), helper_id));

        assert!(DeadCodeDetection::new().run(&ctx).is_empty());
    }
}