use super::{CodeStatistics, DetectedPattern};
use anyhow::Result;
use regex::Regex;
use std::collections::HashMap;
use std::path::Path;
use tokio::fs;
use tracing::debug;
/// Regex-driven detector that scores a fixed table of code patterns
/// (naming conventions, error handling, async usage, ...) against a project.
///
/// Build one with `PatternExtractor::new` and run it with `extract_patterns`.
pub struct PatternExtractor {
/// Pattern table populated by `new` and scanned by `extract_patterns`.
patterns: Vec<PatternDefinition>,
}
/// One row of the extractor's pattern table: a regex plus the metadata used
/// to decide where it applies and how much weight it carries.
#[derive(Clone)]
struct PatternDefinition {
/// Name of the pattern (e.g. "snake_case"); `extract_patterns` looks
/// definitions up by this value.
name: String,
/// Category copied into the resulting `DetectedPattern`
/// (e.g. "naming_convention").
pattern_type: String,
/// Expression run against file contents; every match counts as one occurrence.
regex: Regex,
/// Language names, as produced by `extension_to_language`, that this pattern
/// targets. An empty list makes the pattern apply to every language.
languages: Vec<String>,
/// Per-pattern weight that `extract_patterns` averages with its count-based
/// score to produce the final confidence.
confidence_boost: f64,
}
impl PatternExtractor {
pub fn new() -> Self {
let patterns = vec![
PatternDefinition {
name: "snake_case".to_string(),
pattern_type: "naming_convention".to_string(),
regex: Regex::new(r"\b[a-z]+(_[a-z]+)+\b").unwrap(),
languages: vec!["Rust".to_string(), "Python".to_string(), "Ruby".to_string()],
confidence_boost: 0.8,
},
PatternDefinition {
name: "camelCase".to_string(),
pattern_type: "naming_convention".to_string(),
regex: Regex::new(r"\b[a-z]+([A-Z][a-z]+)+\b").unwrap(),
languages: vec!["JavaScript".to_string(), "TypeScript".to_string(), "Java".to_string()],
confidence_boost: 0.8,
},
PatternDefinition {
name: "PascalCase".to_string(),
pattern_type: "naming_convention".to_string(),
regex: Regex::new(r"\b[A-Z][a-z]+([A-Z][a-z]+)+\b").unwrap(),
languages: vec!["Java".to_string(), "C#".to_string(), "TypeScript".to_string()],
confidence_boost: 0.8,
},
PatternDefinition {
name: "error_handling_result".to_string(),
pattern_type: "error_handling".to_string(),
regex: Regex::new(r"Result<|\.unwrap\(\)|\.expect\(").unwrap(),
languages: vec!["Rust".to_string()],
confidence_boost: 0.9,
},
PatternDefinition {
name: "async_await".to_string(),
pattern_type: "async_pattern".to_string(),
regex: Regex::new(r"\basync\b|\bawait\b").unwrap(),
languages: vec!["Rust".to_string(), "JavaScript".to_string(), "TypeScript".to_string()],
confidence_boost: 0.85,
},
PatternDefinition {
name: "functional_style".to_string(),
pattern_type: "code_style".to_string(),
regex: Regex::new(r"\.map\(|\.filter\(|\.reduce\(|\.forEach\(").unwrap(),
languages: vec!["JavaScript".to_string(), "TypeScript".to_string(), "Rust".to_string()],
confidence_boost: 0.75,
},
PatternDefinition {
name: "type_definitions".to_string(),
pattern_type: "type_safety".to_string(),
regex: Regex::new(r"\binterface\b|\btype\b|\bstruct\b|\benum\b").unwrap(),
languages: vec!["TypeScript".to_string(), "Rust".to_string()],
confidence_boost: 0.8,
},
PatternDefinition {
name: "documentation_comments".to_string(),
pattern_type: "documentation".to_string(),
regex: Regex::new(r"///|//!|# |<!--").unwrap(),
languages: vec!["Rust".to_string(), "Python".to_string(), "HTML".to_string()],
confidence_boost: 0.7,
},
PatternDefinition {
name: "unit_tests".to_string(),
pattern_type: "testing".to_string(),
regex: Regex::new(r"#\[test\]|#\[cfg\(test\)\]|\btest\b|\bdescribe\b|\bit\b").unwrap(),
languages: vec!["Rust".to_string(), "JavaScript".to_string(), "Python".to_string()],
confidence_boost: 0.85,
},
PatternDefinition {
name: "dependency_injection".to_string(),
pattern_type: "architecture".to_string(),
regex: Regex::new(r"inject|provider|container|module").unwrap(),
languages: vec!["TypeScript".to_string(), "Java".to_string()],
confidence_boost: 0.7,
},
];
Self { patterns }
}
pub async fn extract_patterns(
&self,
path: &Path,
stats: &CodeStatistics,
) -> Result<Vec<DetectedPattern>> {
let mut detected = Vec::new();
let mut pattern_counts: HashMap<String, (usize, Vec<String>)> = HashMap::new();
for (ext, count) in &stats.file_types {
let language = self.extension_to_language(ext);
for pattern in &self.patterns {
if pattern.languages.contains(&language) || pattern.languages.is_empty() {
let key = format!("{}:{}", pattern.name, pattern.pattern_type);
pattern_counts.entry(key)
.or_insert_with(|| (0, Vec::new()))
.0 += count;
}
}
}
let sample_files = self.get_sample_files(path, 20).await?;
for file_path in sample_files {
if let Ok(content) = fs::read_to_string(&file_path).await {
let relative_path = file_path.strip_prefix(path)
.unwrap_or(&file_path)
.to_string_lossy()
.to_string();
for pattern in &self.patterns {
let matches: Vec<_> = pattern.regex.find_iter(&content).collect();
if !matches.is_empty() {
let key = format!("{}:{}", pattern.name, pattern.pattern_type);
if let Some((count, files)) = pattern_counts.get_mut(&key) {
*count += matches.len();
if !files.contains(&relative_path) {
files.push(relative_path.clone());
}
}
}
}
}
}
for (key, (count, files)) in pattern_counts {
if count > 0 {
let parts: Vec<_> = key.split(':').collect();
let name = parts[0].to_string();
let pattern_type = parts[1].to_string();
let base_confidence = if count > 100 {
0.95
} else if count > 50 {
0.85
} else if count > 20 {
0.75
} else {
0.6
};
let pattern_def = self.patterns.iter()
.find(|p| p.name == name)
.map(|p| p.confidence_boost)
.unwrap_or(0.5);
let confidence = (base_confidence + pattern_def) / 2.0;
detected.push(DetectedPattern {
pattern_type,
confidence,
evidence: format!("Found {} occurrences in {} files", count, files.len()),
files: files.into_iter().take(5).collect(),
});
}
}
detected.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
debug!("Extracted {} patterns", detected.len());
Ok(detected)
}
/// Collects up to `limit` source files found at most five levels below `path`.
///
/// Any entry below the root whose own name is on the `should_skip` list is
/// pruned together with everything under it. The skip check looks only at
/// names *below* `path`, so a project that itself lives under a directory
/// called e.g. `build` or `vendor` is still sampled. Only regular files whose
/// lower-cased extension passes `is_code_file` are returned; entries that
/// cannot be read during the walk are ignored.
///
/// The function is declared `async` but performs the directory walk
/// synchronously.
///
/// # Errors
/// Never returns an error at present; the `Result` is kept for the signature.
async fn get_sample_files(&self, path: &Path, limit: usize) -> Result<Vec<std::path::PathBuf>> {
    let mut files = Vec::new();

    let walker = walkdir::WalkDir::new(path)
        .max_depth(5)
        .into_iter()
        // Prune on the entry's own name so skipped directories are never
        // descended into. Depth 0 is the root the caller asked for and is
        // always kept, whatever it is called.
        .filter_entry(|entry| {
            entry.depth() == 0 || !self.should_skip(Path::new(entry.file_name()))
        })
        .filter_map(|entry| entry.ok());

    for entry in walker {
        if files.len() >= limit {
            break;
        }
        let candidate = entry.path();
        if !candidate.is_file() {
            continue;
        }
        let is_code = candidate
            .extension()
            .map(|ext| self.is_code_file(&ext.to_string_lossy().to_lowercase()))
            .unwrap_or(false);
        if is_code {
            files.push(candidate.to_path_buf());
        }
    }

    Ok(files)
}
/// Returns `true` when any component of `path` is exactly one of the
/// build-output, dependency or VCS directory names listed below.
///
/// Components that are not valid UTF-8 never match.
fn should_skip(&self, path: &std::path::Path) -> bool {
    const SKIPPED_NAMES: [&str; 9] = [
        "node_modules",
        "target",
        "build",
        "dist",
        ".git",
        "__pycache__",
        ".next",
        "out",
        "vendor",
    ];

    path.components()
        .filter_map(|component| component.as_os_str().to_str())
        .any(|name| SKIPPED_NAMES.contains(&name))
}
/// Returns `true` when `ext` (without a leading dot, compared
/// case-sensitively) is one of the recognised source-code extensions.
fn is_code_file(&self, ext: &str) -> bool {
    matches!(
        ext,
        "rs" | "js"
            | "ts"
            | "jsx"
            | "tsx"
            | "py"
            | "java"
            | "go"
            | "rb"
            | "php"
            | "c"
            | "cpp"
            | "h"
            | "hpp"
            | "cs"
            | "swift"
            | "scala"
            | "kt"
    )
}
/// Maps a file extension (without a leading dot, compared case-sensitively)
/// to the language name used in the pattern table.
///
/// Unrecognised extensions map to `"Other"`.
fn extension_to_language(&self, ext: &str) -> String {
    let language = match ext {
        "rs" => "Rust",
        "js" | "jsx" => "JavaScript",
        "ts" | "tsx" => "TypeScript",
        "py" => "Python",
        "java" => "Java",
        "go" => "Go",
        "rb" => "Ruby",
        "php" => "PHP",
        "c" | "cpp" | "h" | "hpp" => "C/C++",
        "cs" => "C#",
        "swift" => "Swift",
        "scala" => "Scala",
        "kt" => "Kotlin",
        _ => "Other",
    };
    String::from(language)
}
}