1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::graph::CodeGraph;
11use crate::types::CodeUnitType;
12
/// A recurring convention detected in a [`CodeGraph`], such as a shared naming
/// prefix/suffix or a directory that mostly holds one kind of code unit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedPattern {
    /// Human-readable pattern name, e.g. `Naming: get_*`.
    pub name: String,
    /// Longer explanation of the pattern, including how many units support it.
    pub description: String,
    /// The code units that follow this pattern.
    pub instances: Vec<PatternInstance>,
    /// The rules (template, required/optional items, anti-patterns) of the pattern.
    pub structure: PatternStructure,
    /// How strongly the graph supports the pattern; higher values are ranked
    /// first by `PatternExtractor::extract_patterns`.
    pub confidence: f64,
    /// Violations recorded against this pattern (the detectors in this module
    /// leave it empty).
    pub violations: Vec<PatternViolation>,
}
31
/// One code unit that follows an [`ExtractedPattern`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternInstance {
    /// Graph id of the matching unit.
    pub node_id: u64,
    /// Name of the matching unit.
    pub name: String,
    /// Location of the unit. For naming patterns this is the unit's file path;
    /// for directory patterns it is the containing directory.
    pub file_path: String,
    /// How closely the unit matches the pattern (the detectors in this module
    /// always record `1.0`).
    pub match_strength: f64,
    /// Ways in which the unit departs from the pattern (left empty by the
    /// detectors in this module).
    pub deviations: Vec<String>,
}
46
/// The rules that make up an [`ExtractedPattern`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternStructure {
    /// Short form of the pattern, e.g. `get_*` or a placement rule; used in
    /// suggested fixes.
    pub template: String,
    /// Human-readable rules a conforming unit must satisfy.
    pub required: Vec<String>,
    /// Human-readable rules a conforming unit may satisfy.
    pub optional: Vec<String>,
    /// Human-readable descriptions of what to avoid.
    pub anti_patterns: Vec<String>,
}
59
/// A code unit that is expected to follow a pattern but does not.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternViolation {
    /// Graph id of the offending unit.
    pub node_id: u64,
    /// Name of the offending unit.
    pub name: String,
    /// Description of which pattern is not followed.
    pub violation: String,
    /// Human-readable suggestion for bringing the unit in line with the pattern.
    pub suggested_fix: String,
    /// How serious the violation is.
    pub severity: ViolationSeverity,
}
74
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
77pub enum ViolationSeverity {
78 Info,
79 Warning,
80 Error,
81}
82
/// Detects naming and directory-layout patterns in a [`CodeGraph`] and checks
/// individual units against them.
pub struct PatternExtractor<'g> {
    /// The graph whose units are analysed; borrowed for the extractor's lifetime.
    graph: &'g CodeGraph,
}
89
90impl<'g> PatternExtractor<'g> {
91 pub fn new(graph: &'g CodeGraph) -> Self {
92 Self { graph }
93 }
94
95 pub fn extract_patterns(&self) -> Vec<ExtractedPattern> {
97 let mut patterns = Vec::new();
98
99 patterns.extend(self.detect_naming_patterns());
100 patterns.extend(self.detect_structural_patterns());
101
102 patterns.sort_by(|a, b| {
104 b.confidence
105 .partial_cmp(&a.confidence)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108 patterns
109 }
110
111 pub fn check_patterns(&self, unit_id: u64) -> Vec<PatternViolation> {
113 let patterns = self.extract_patterns();
114 let mut violations = Vec::new();
115
116 if let Some(unit) = self.graph.get_unit(unit_id) {
117 for pattern in &patterns {
118 let should_follow = pattern.instances.iter().any(|inst| {
120 let unit_path = unit.file_path.display().to_string();
122 let inst_path = &inst.file_path;
123 unit_path
124 .rsplit_once('/')
125 .map(|(d, _)| inst_path.starts_with(d))
126 .unwrap_or(false)
127 });
128
129 if should_follow && !pattern.instances.iter().any(|inst| inst.node_id == unit_id) {
130 violations.push(PatternViolation {
131 node_id: unit_id,
132 name: unit.name.clone(),
133 violation: format!("Does not follow '{}' pattern", pattern.name),
134 suggested_fix: format!("Apply pattern: {}", pattern.structure.template),
135 severity: ViolationSeverity::Warning,
136 });
137 }
138 }
139 }
140
141 violations
142 }
143
144 pub fn suggest_patterns(&self, file_path: &str) -> Vec<ExtractedPattern> {
146 let patterns = self.extract_patterns();
147 patterns
148 .into_iter()
149 .filter(|p| {
150 p.instances.iter().any(|inst| {
151 file_path
153 .rsplit_once('/')
154 .map(|(d, _)| inst.file_path.starts_with(d))
155 .unwrap_or(false)
156 })
157 })
158 .collect()
159 }
160
161 fn detect_naming_patterns(&self) -> Vec<ExtractedPattern> {
164 let mut prefix_groups: HashMap<String, Vec<(u64, String, String)>> = HashMap::new();
165 let mut suffix_groups: HashMap<String, Vec<(u64, String, String)>> = HashMap::new();
166
167 for unit in self.graph.units() {
168 if unit.unit_type != CodeUnitType::Function && unit.unit_type != CodeUnitType::Type {
169 continue;
170 }
171
172 let name = &unit.name;
173
174 if let Some(prefix) = name.split('_').next() {
176 if prefix.len() >= 3 {
177 prefix_groups
178 .entry(format!("{}_*", prefix))
179 .or_default()
180 .push((unit.id, name.clone(), unit.file_path.display().to_string()));
181 }
182 }
183
184 if let Some(suffix) = name.rsplit('_').next() {
186 if suffix.len() >= 4 {
187 suffix_groups
188 .entry(format!("*_{}", suffix))
189 .or_default()
190 .push((unit.id, name.clone(), unit.file_path.display().to_string()));
191 }
192 }
193 }
194
195 let mut patterns = Vec::new();
196
197 for (pattern_name, members) in prefix_groups.into_iter().chain(suffix_groups.into_iter()) {
199 if members.len() < 3 {
200 continue;
201 }
202
203 let instances: Vec<PatternInstance> = members
204 .iter()
205 .map(|(id, name, path)| PatternInstance {
206 node_id: *id,
207 name: name.clone(),
208 file_path: path.clone(),
209 match_strength: 1.0,
210 deviations: Vec::new(),
211 })
212 .collect();
213
214 let confidence = (members.len() as f64 * 0.15).min(0.95);
215
216 patterns.push(ExtractedPattern {
217 name: format!("Naming: {}", pattern_name),
218 description: format!(
219 "Functions/types following the '{}' naming pattern ({} instances)",
220 pattern_name,
221 members.len()
222 ),
223 instances,
224 structure: PatternStructure {
225 template: pattern_name.clone(),
226 required: vec![format!("Follow '{}' naming convention", pattern_name)],
227 optional: Vec::new(),
228 anti_patterns: Vec::new(),
229 },
230 confidence,
231 violations: Vec::new(),
232 });
233 }
234
235 patterns
236 }
237
238 fn detect_structural_patterns(&self) -> Vec<ExtractedPattern> {
239 let mut patterns = Vec::new();
240
241 let mut dir_groups: HashMap<String, Vec<(u64, String, CodeUnitType)>> = HashMap::new();
243 for unit in self.graph.units() {
244 let dir = unit
245 .file_path
246 .parent()
247 .map(|p| p.display().to_string())
248 .unwrap_or_default();
249 dir_groups
250 .entry(dir)
251 .or_default()
252 .push((unit.id, unit.name.clone(), unit.unit_type));
253 }
254
255 for (dir, members) in &dir_groups {
256 if members.len() < 3 || dir.is_empty() {
257 continue;
258 }
259
260 let type_counts: HashMap<CodeUnitType, usize> =
262 members.iter().fold(HashMap::new(), |mut acc, (_, _, t)| {
263 *acc.entry(*t).or_insert(0) += 1;
264 acc
265 });
266
267 if let Some((&dominant_type, &count)) = type_counts.iter().max_by_key(|(_, c)| *c) {
268 if count as f64 / members.len() as f64 > 0.7 {
269 let instances: Vec<PatternInstance> = members
270 .iter()
271 .filter(|(_, _, t)| *t == dominant_type)
272 .map(|(id, name, _)| PatternInstance {
273 node_id: *id,
274 name: name.clone(),
275 file_path: dir.clone(),
276 match_strength: 1.0,
277 deviations: Vec::new(),
278 })
279 .collect();
280
281 patterns.push(ExtractedPattern {
282 name: format!("Directory: {} is {}", dir, dominant_type.label()),
283 description: format!(
284 "Directory '{}' primarily contains {} ({}% of {})",
285 dir,
286 dominant_type.label(),
287 (count * 100) / members.len(),
288 members.len()
289 ),
290 instances,
291 structure: PatternStructure {
292 template: format!("Place {} in {}", dominant_type.label(), dir),
293 required: vec![format!(
294 "New {} should go in {}",
295 dominant_type.label(),
296 dir
297 )],
298 optional: Vec::new(),
299 anti_patterns: vec![format!(
300 "Don't place non-{} code in {}",
301 dominant_type.label(),
302 dir
303 )],
304 },
305 confidence: (count as f64 / members.len() as f64).min(0.9),
306 violations: Vec::new(),
307 });
308 }
309 }
310 }
311
312 patterns
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    /// Builds a graph with four functions in `src/api.rs`: three `get_*` plus `create_user`.
    fn test_graph() -> CodeGraph {
        let mut graph = CodeGraph::with_default_dimension();
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "get_user".to_string(),
            "mod::get_user".to_string(),
            PathBuf::from("src/api.rs"),
            Span::new(1, 0, 10, 0),
        ));
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "get_order".to_string(),
            "mod::get_order".to_string(),
            PathBuf::from("src/api.rs"),
            Span::new(11, 0, 20, 0),
        ));
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "get_product".to_string(),
            "mod::get_product".to_string(),
            PathBuf::from("src/api.rs"),
            Span::new(21, 0, 30, 0),
        ));
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "create_user".to_string(),
            "mod::create_user".to_string(),
            PathBuf::from("src/api.rs"),
            Span::new(31, 0, 40, 0),
        ));
        graph
    }

    /// Returns the id of the unit named `name`; panics if the test graph lacks it.
    fn unit_id(graph: &CodeGraph, name: &str) -> u64 {
        let mut found = None;
        for unit in graph.units() {
            if unit.name == name {
                found = Some(unit.id);
            }
        }
        found.expect("unit should exist in the test graph")
    }

    #[test]
    fn extract_naming_patterns() {
        let graph = test_graph();
        let extractor = PatternExtractor::new(&graph);
        let patterns = extractor.extract_patterns();
        assert!(patterns.iter().any(|p| p.name.contains("get_")));
    }

    #[test]
    fn patterns_are_sorted_by_descending_confidence() {
        let graph = test_graph();
        let patterns = PatternExtractor::new(&graph).extract_patterns();
        // Every unit is a function in `src`, so the directory pattern (capped at 0.9)
        // outranks the `get_*` naming pattern (3 * 0.15).
        assert!(patterns[0].name.starts_with("Directory: src"));
        assert!(patterns
            .windows(2)
            .all(|pair| pair[0].confidence >= pair[1].confidence));
    }

    #[test]
    fn suggest_patterns_for_file() {
        let graph = test_graph();
        let extractor = PatternExtractor::new(&graph);
        let suggestions = extractor.suggest_patterns("src/api.rs");
        assert!(!suggestions.is_empty());
    }

    #[test]
    fn suggest_patterns_ignores_unrelated_directory() {
        let graph = test_graph();
        let extractor = PatternExtractor::new(&graph);
        assert!(extractor.suggest_patterns("lib/other.rs").is_empty());
    }

    #[test]
    fn check_patterns_flags_unit_outside_naming_convention() {
        let graph = test_graph();
        let extractor = PatternExtractor::new(&graph);
        let violations = extractor.check_patterns(unit_id(&graph, "create_user"));
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].name, "create_user");
        assert_eq!(
            violations[0].violation,
            "Does not follow 'Naming: get_*' pattern"
        );
        assert_eq!(violations[0].severity, ViolationSeverity::Warning);
    }

    #[test]
    fn check_patterns_accepts_conforming_unit() {
        let graph = test_graph();
        let extractor = PatternExtractor::new(&graph);
        assert!(extractor
            .check_patterns(unit_id(&graph, "get_user"))
            .is_empty());
    }
}