Skip to main content

agentic_codebase/semantic/
pattern_extract.rs

//! Pattern Extraction — Invention 12.
//!
//! Extracts implicit patterns and makes them explicit and enforceable.
//! A codebase has patterns, but they are implicit, so new code often does not follow them.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::graph::CodeGraph;
11use crate::types::CodeUnitType;
12
13// ── Types ────────────────────────────────────────────────────────────────────
14
15/// An extracted pattern.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ExtractedPattern {
18    /// Pattern name.
19    pub name: String,
20    /// Description.
21    pub description: String,
22    /// Where it's used.
23    pub instances: Vec<PatternInstance>,
24    /// The pattern structure.
25    pub structure: PatternStructure,
26    /// Confidence it's intentional.
27    pub confidence: f64,
28    /// Violations (code that should follow but doesn't).
29    pub violations: Vec<PatternViolation>,
30}
31
32/// An instance of a pattern.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct PatternInstance {
35    /// Node ID.
36    pub node_id: u64,
37    /// Name.
38    pub name: String,
39    /// File.
40    pub file_path: String,
41    /// How well it matches.
42    pub match_strength: f64,
43    /// Any deviations.
44    pub deviations: Vec<String>,
45}
46
47/// Structure of a pattern.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct PatternStructure {
50    /// Template description.
51    pub template: String,
52    /// Required elements.
53    pub required: Vec<String>,
54    /// Optional elements.
55    pub optional: Vec<String>,
56    /// Anti-patterns (what NOT to do).
57    pub anti_patterns: Vec<String>,
58}
59
60/// A pattern violation.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct PatternViolation {
63    /// Node ID.
64    pub node_id: u64,
65    /// Node name.
66    pub name: String,
67    /// What's wrong.
68    pub violation: String,
69    /// How to fix.
70    pub suggested_fix: String,
71    /// Severity.
72    pub severity: ViolationSeverity,
73}
74
75/// Severity of a violation.
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
77pub enum ViolationSeverity {
78    Info,
79    Warning,
80    Error,
81}
82
83// ── PatternExtractor ─────────────────────────────────────────────────────────
84
85/// Extracts and validates patterns in the codebase.
86pub struct PatternExtractor<'g> {
87    graph: &'g CodeGraph,
88}
89
90impl<'g> PatternExtractor<'g> {
91    pub fn new(graph: &'g CodeGraph) -> Self {
92        Self { graph }
93    }
94
95    /// Extract all detected patterns from the codebase.
96    pub fn extract_patterns(&self) -> Vec<ExtractedPattern> {
97        let mut patterns = Vec::new();
98
99        patterns.extend(self.detect_naming_patterns());
100        patterns.extend(self.detect_structural_patterns());
101
102        // Sort by confidence
103        patterns.sort_by(|a, b| {
104            b.confidence
105                .partial_cmp(&a.confidence)
106                .unwrap_or(std::cmp::Ordering::Equal)
107        });
108        patterns
109    }
110
111    /// Check code against detected patterns.
112    pub fn check_patterns(&self, unit_id: u64) -> Vec<PatternViolation> {
113        let patterns = self.extract_patterns();
114        let mut violations = Vec::new();
115
116        if let Some(unit) = self.graph.get_unit(unit_id) {
117            for pattern in &patterns {
118                // Check if this unit should follow the pattern
119                let should_follow = pattern.instances.iter().any(|inst| {
120                    // Same file directory or same module prefix
121                    let unit_path = unit.file_path.display().to_string();
122                    let inst_path = &inst.file_path;
123                    unit_path
124                        .rsplit_once('/')
125                        .map(|(d, _)| inst_path.starts_with(d))
126                        .unwrap_or(false)
127                });
128
129                if should_follow && !pattern.instances.iter().any(|inst| inst.node_id == unit_id) {
130                    violations.push(PatternViolation {
131                        node_id: unit_id,
132                        name: unit.name.clone(),
133                        violation: format!("Does not follow '{}' pattern", pattern.name),
134                        suggested_fix: format!("Apply pattern: {}", pattern.structure.template),
135                        severity: ViolationSeverity::Warning,
136                    });
137                }
138            }
139        }
140
141        violations
142    }
143
144    /// Suggest patterns for new code based on location.
145    pub fn suggest_patterns(&self, file_path: &str) -> Vec<ExtractedPattern> {
146        let patterns = self.extract_patterns();
147        patterns
148            .into_iter()
149            .filter(|p| {
150                p.instances.iter().any(|inst| {
151                    // Same directory
152                    file_path
153                        .rsplit_once('/')
154                        .map(|(d, _)| inst.file_path.starts_with(d))
155                        .unwrap_or(false)
156                })
157            })
158            .collect()
159    }
160
161    // ── Internal ─────────────────────────────────────────────────────────
162
163    fn detect_naming_patterns(&self) -> Vec<ExtractedPattern> {
164        let mut prefix_groups: HashMap<String, Vec<(u64, String, String)>> = HashMap::new();
165        let mut suffix_groups: HashMap<String, Vec<(u64, String, String)>> = HashMap::new();
166
167        for unit in self.graph.units() {
168            if unit.unit_type != CodeUnitType::Function && unit.unit_type != CodeUnitType::Type {
169                continue;
170            }
171
172            let name = &unit.name;
173
174            // Detect prefix patterns (e.g., get_*, create_*, handle_*)
175            if let Some(prefix) = name.split('_').next() {
176                if prefix.len() >= 3 {
177                    prefix_groups
178                        .entry(format!("{}_*", prefix))
179                        .or_default()
180                        .push((unit.id, name.clone(), unit.file_path.display().to_string()));
181                }
182            }
183
184            // Detect suffix patterns (e.g., *_handler, *_service, *_controller)
185            if let Some(suffix) = name.rsplit('_').next() {
186                if suffix.len() >= 4 {
187                    suffix_groups
188                        .entry(format!("*_{}", suffix))
189                        .or_default()
190                        .push((unit.id, name.clone(), unit.file_path.display().to_string()));
191                }
192            }
193        }
194
195        let mut patterns = Vec::new();
196
197        // Only report groups with 3+ members as patterns
198        for (pattern_name, members) in prefix_groups.into_iter().chain(suffix_groups.into_iter()) {
199            if members.len() < 3 {
200                continue;
201            }
202
203            let instances: Vec<PatternInstance> = members
204                .iter()
205                .map(|(id, name, path)| PatternInstance {
206                    node_id: *id,
207                    name: name.clone(),
208                    file_path: path.clone(),
209                    match_strength: 1.0,
210                    deviations: Vec::new(),
211                })
212                .collect();
213
214            let confidence = (members.len() as f64 * 0.15).min(0.95);
215
216            patterns.push(ExtractedPattern {
217                name: format!("Naming: {}", pattern_name),
218                description: format!(
219                    "Functions/types following the '{}' naming pattern ({} instances)",
220                    pattern_name,
221                    members.len()
222                ),
223                instances,
224                structure: PatternStructure {
225                    template: pattern_name.clone(),
226                    required: vec![format!("Follow '{}' naming convention", pattern_name)],
227                    optional: Vec::new(),
228                    anti_patterns: Vec::new(),
229                },
230                confidence,
231                violations: Vec::new(),
232            });
233        }
234
235        patterns
236    }
237
238    fn detect_structural_patterns(&self) -> Vec<ExtractedPattern> {
239        let mut patterns = Vec::new();
240
241        // Detect module organization patterns
242        let mut dir_groups: HashMap<String, Vec<(u64, String, CodeUnitType)>> = HashMap::new();
243        for unit in self.graph.units() {
244            let dir = unit
245                .file_path
246                .parent()
247                .map(|p| p.display().to_string())
248                .unwrap_or_default();
249            dir_groups
250                .entry(dir)
251                .or_default()
252                .push((unit.id, unit.name.clone(), unit.unit_type));
253        }
254
255        for (dir, members) in &dir_groups {
256            if members.len() < 3 || dir.is_empty() {
257                continue;
258            }
259
260            // Check if all members are the same type (e.g., all functions, all types)
261            let type_counts: HashMap<CodeUnitType, usize> =
262                members.iter().fold(HashMap::new(), |mut acc, (_, _, t)| {
263                    *acc.entry(*t).or_insert(0) += 1;
264                    acc
265                });
266
267            if let Some((&dominant_type, &count)) = type_counts.iter().max_by_key(|(_, c)| *c) {
268                if count as f64 / members.len() as f64 > 0.7 {
269                    let instances: Vec<PatternInstance> = members
270                        .iter()
271                        .filter(|(_, _, t)| *t == dominant_type)
272                        .map(|(id, name, _)| PatternInstance {
273                            node_id: *id,
274                            name: name.clone(),
275                            file_path: dir.clone(),
276                            match_strength: 1.0,
277                            deviations: Vec::new(),
278                        })
279                        .collect();
280
281                    patterns.push(ExtractedPattern {
282                        name: format!("Directory: {} is {}", dir, dominant_type.label()),
283                        description: format!(
284                            "Directory '{}' primarily contains {} ({}% of {})",
285                            dir,
286                            dominant_type.label(),
287                            (count * 100) / members.len(),
288                            members.len()
289                        ),
290                        instances,
291                        structure: PatternStructure {
292                            template: format!("Place {} in {}", dominant_type.label(), dir),
293                            required: vec![format!(
294                                "New {} should go in {}",
295                                dominant_type.label(),
296                                dir
297                            )],
298                            optional: Vec::new(),
299                            anti_patterns: vec![format!(
300                                "Don't place non-{} code in {}",
301                                dominant_type.label(),
302                                dir
303                            )],
304                        },
305                        confidence: (count as f64 / members.len() as f64).min(0.9),
306                        violations: Vec::new(),
307                    });
308                }
309            }
310        }
311
312        patterns
313    }
314}
315
316// ── Tests ────────────────────────────────────────────────────────────────────
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    /// Builds a graph holding four Rust functions in `src/api.rs`:
    /// three that share the `get_` prefix and one `create_` function.
    fn test_graph() -> CodeGraph {
        // (function name, first line, last line)
        let functions = [
            ("get_user", 1, 10),
            ("get_order", 11, 20),
            ("get_product", 21, 30),
            ("create_user", 31, 40),
        ];

        let mut graph = CodeGraph::with_default_dimension();
        for &(name, first_line, last_line) in &functions {
            graph.add_unit(CodeUnit::new(
                CodeUnitType::Function,
                Language::Rust,
                name.to_string(),
                format!("mod::{}", name),
                PathBuf::from("src/api.rs"),
                Span::new(first_line, 0, last_line, 0),
            ));
        }
        graph
    }

    /// The three `get_*` functions must surface as a naming pattern.
    #[test]
    fn extract_naming_patterns() {
        let graph = test_graph();
        let found = PatternExtractor::new(&graph).extract_patterns();
        assert!(found.iter().any(|p| p.name.contains("get_")));
    }

    /// A file in `src/` must receive at least one suggestion.
    #[test]
    fn suggest_patterns_for_file() {
        let graph = test_graph();
        let suggestions = PatternExtractor::new(&graph).suggest_patterns("src/api.rs");
        assert!(!suggestions.is_empty());
    }
}