use std::collections::HashMap;
use aprender::text::tokenize::WhitespaceTokenizer;
use aprender::text::vectorize::TfidfVectorizer;
use aprender::text::Tokenizer;
use serde::{Deserialize, Serialize};
use crate::classifier::ErrorCategory;
use crate::OracleError;
/// A learned mapping from an error-message pattern to a suggested fix, plus
/// usage statistics that feed into suggestion confidence.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FixPattern {
/// Error text this pattern is compared against. `NgramFixPredictor::learn_pattern`
/// stores the normalized form of the learned message here.
pub error_pattern: String,
/// Fix text reported as `FixSuggestion::fix` when this pattern matches.
pub fix_template: String,
/// Error category this pattern was learned under.
pub category: ErrorCategory,
/// Number of times this pattern has been learned (`FixPattern::new` sets it to 1).
pub frequency: usize,
/// Exponential moving average of feedback, where success counts as 1.0 and
/// failure as 0.0 (`FixPattern::new` sets it to 0.0).
pub success_rate: f32,
}
impl FixPattern {
#[must_use]
pub fn new(error_pattern: &str, fix_template: &str, category: ErrorCategory) -> Self {
Self {
error_pattern: error_pattern.to_string(),
fix_template: fix_template.to_string(),
category,
frequency: 1,
success_rate: 0.0,
}
}
pub fn increment(&mut self) {
self.frequency += 1;
}
pub fn update_success(&mut self, success: bool) {
let alpha = 0.1; let success_val = if success { 1.0 } else { 0.0 };
self.success_rate = alpha * success_val + (1.0 - alpha) * self.success_rate;
}
}
/// A candidate fix produced by `NgramFixPredictor` for a given error message.
#[derive(Clone, Debug)]
pub struct FixSuggestion {
/// Fix text copied from the matched pattern's `fix_template`.
pub fix: String,
/// Ranking score of this suggestion, capped at 1.0.
pub confidence: f32,
/// Category of the pattern that produced this suggestion.
pub category: ErrorCategory,
/// The stored (normalized) `error_pattern` that matched.
pub matched_pattern: String,
}
/// Suggests fixes for compiler error messages by comparing a normalized message
/// against previously learned error patterns using n-gram overlap.
pub struct NgramFixPredictor {
/// Learned patterns, grouped by error category.
patterns: HashMap<ErrorCategory, Vec<FixPattern>>,
/// TF-IDF vectorizer fitted on every stored pattern by `fit()`.
/// NOTE(review): in the code visible here, prediction scores patterns with
/// `compute_similarity` and never consults this vectorizer — confirm it is still needed.
vectorizer: TfidfVectorizer,
/// Set by `fit()`, cleared by `learn_pattern` and `with_ngram_range`;
/// `predict_fixes` returns nothing while it is false.
is_fitted: bool,
/// Minimum similarity a stored pattern needs in order to be suggested.
min_similarity: f32,
/// Inclusive (min, max) n-gram sizes used when computing similarity.
ngram_range: (usize, usize),
}
impl NgramFixPredictor {
#[must_use]
pub fn new() -> Self {
Self {
patterns: HashMap::new(),
vectorizer: TfidfVectorizer::new()
.with_tokenizer(Box::new(WhitespaceTokenizer::new()))
.with_ngram_range(1, 2) .with_sublinear_tf(true),
is_fitted: false,
min_similarity: 0.05, ngram_range: (1, 2), }
}
#[must_use]
pub fn with_min_similarity(mut self, threshold: f32) -> Self {
self.min_similarity = threshold.clamp(0.0, 1.0);
self
}
#[must_use]
pub fn with_ngram_range(mut self, min_n: usize, max_n: usize) -> Self {
self.ngram_range = (min_n.max(1), max_n.max(1));
self.vectorizer = TfidfVectorizer::new()
.with_tokenizer(Box::new(WhitespaceTokenizer::new()))
.with_ngram_range(min_n.max(1), max_n.max(1))
.with_sublinear_tf(true);
self.is_fitted = false;
self
}
pub fn learn_pattern(
&mut self,
error_message: &str,
fix_template: &str,
category: ErrorCategory,
) {
let normalized = normalize_error(error_message);
let patterns = self.patterns.entry(category).or_default();
if let Some(existing) = patterns.iter_mut().find(|p| p.error_pattern == normalized) {
existing.increment();
} else {
patterns.push(FixPattern::new(&normalized, fix_template, category));
}
self.is_fitted = false;
}
pub fn learn_batch(&mut self, training_data: &[(String, String, ErrorCategory)]) {
for (error, fix, category) in training_data {
self.learn_pattern(error, fix, *category);
}
}
pub fn fit(&mut self) -> Result<(), OracleError> {
let all_patterns: Vec<String> = self
.patterns
.values()
.flat_map(|ps| ps.iter().map(|p| p.error_pattern.clone()))
.collect();
if all_patterns.is_empty() {
return Err(OracleError::Model(
"No patterns to fit. Call learn_pattern() first.".to_string(),
));
}
self.vectorizer
.fit(&all_patterns)
.map_err(|e| OracleError::Model(e.to_string()))?;
self.is_fitted = true;
Ok(())
}
#[must_use]
pub fn predict_fixes(&self, error_message: &str, top_k: usize) -> Vec<FixSuggestion> {
if !self.is_fitted || self.patterns.is_empty() {
return Vec::new();
}
let normalized = normalize_error(error_message);
let mut suggestions: Vec<FixSuggestion> = Vec::new();
for (category, patterns) in &self.patterns {
for pattern in patterns {
let similarity = self.compute_similarity(&normalized, &pattern.error_pattern);
if similarity >= self.min_similarity {
let confidence = similarity
* (1.0 + (pattern.frequency as f32).ln())
* (0.5 + pattern.success_rate);
suggestions.push(FixSuggestion {
fix: pattern.fix_template.clone(),
confidence: confidence.min(1.0),
category: *category,
matched_pattern: pattern.error_pattern.clone(),
});
}
}
}
suggestions.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
});
suggestions.truncate(top_k);
suggestions
}
#[must_use]
pub fn predict_for_category(
&self,
error_message: &str,
category: ErrorCategory,
top_k: usize,
) -> Vec<FixSuggestion> {
let all = self.predict_fixes(error_message, top_k * 2);
all.into_iter()
.filter(|s| s.category == category)
.take(top_k)
.collect()
}
fn compute_similarity(&self, a: &str, b: &str) -> f32 {
let tokenizer = WhitespaceTokenizer::new();
let tokens_a = tokenizer.tokenize(a).unwrap_or_default();
let tokens_b = tokenizer.tokenize(b).unwrap_or_default();
if tokens_a.is_empty() || tokens_b.is_empty() {
return 0.0;
}
let ngrams_a = generate_ngrams(&tokens_a, self.ngram_range.0, self.ngram_range.1);
let ngrams_b = generate_ngrams(&tokens_b, self.ngram_range.0, self.ngram_range.1);
let intersection = ngrams_a.iter().filter(|ng| ngrams_b.contains(ng)).count();
let union = ngrams_a.len() + ngrams_b.len() - intersection;
if union == 0 {
0.0
} else {
intersection as f32 / union as f32
}
}
pub fn record_feedback(&mut self, error_pattern: &str, success: bool) {
let normalized = normalize_error(error_pattern);
for patterns in self.patterns.values_mut() {
if let Some(pattern) = patterns.iter_mut().find(|p| p.error_pattern == normalized) {
pattern.update_success(success);
return;
}
}
}
#[must_use]
pub fn patterns_for_category(&self, category: ErrorCategory) -> &[FixPattern] {
self.patterns.get(&category).map_or(&[], |v| v.as_slice())
}
#[must_use]
pub fn pattern_count(&self) -> usize {
self.patterns.values().map(|v| v.len()).sum()
}
#[must_use]
pub fn is_fitted(&self) -> bool {
self.is_fitted
}
pub fn save(&self, path: &std::path::Path) -> Result<(), crate::OracleError> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let data = serde_json::to_string_pretty(&self.patterns)
.map_err(|e| crate::OracleError::Model(format!("serialize: {e}")))?;
std::fs::write(path, data)?;
Ok(())
}
pub fn load(&mut self, path: &std::path::Path) -> Result<(), crate::OracleError> {
if !path.exists() {
return Ok(()); }
let data = std::fs::read_to_string(path)?;
let loaded: HashMap<ErrorCategory, Vec<FixPattern>> = serde_json::from_str(&data)
.map_err(|e| crate::OracleError::Model(format!("deserialize: {e}")))?;
for (category, patterns) in loaded {
let existing = self.patterns.entry(category).or_default();
for pattern in patterns {
if !existing
.iter()
.any(|p| p.error_pattern == pattern.error_pattern)
{
existing.push(pattern);
}
}
}
Ok(())
}
#[must_use]
pub fn default_user_model_path() -> std::path::PathBuf {
dirs::home_dir()
.unwrap_or_else(|| std::path::PathBuf::from("."))
.join(".depyler")
.join("oracle_user.bin")
}
}
impl Default for NgramFixPredictor {
fn default() -> Self {
Self::new()
}
}
/// Normalizes a compiler error message for pattern storage and comparison.
///
/// Steps:
/// - If the message carries a rustc error code (see `extract_error_code`), the
///   code is prepended twice, presumably so it carries extra weight in n-gram
///   overlap.
/// - The text is lowercased.
/// - Each ASCII digit is replaced with an uppercase `N` placeholder.
/// - The literal substrings `error:` and `-->` are removed.
/// - Every run of whitespace (spaces, tabs, newlines) is collapsed to a single
///   space, with leading and trailing whitespace dropped. Messages that differ
///   only in spacing therefore normalize to the same pattern.
fn normalize_error(message: &str) -> String {
    let code_prefix = extract_error_code(message)
        .map(|code| format!("{code} {code} "))
        .unwrap_or_default();
    let cleaned = message
        .to_lowercase()
        .replace(|c: char| c.is_ascii_digit(), "N")
        .replace("error:", "")
        .replace("-->", "");
    let collapsed = cleaned.split_whitespace().collect::<Vec<_>>().join(" ");
    format!("{code_prefix}{collapsed}")
}
/// Extracts the rustc error code from the first `error[E....]` marker in
/// `message`, in lowercase form (e.g. `"error[E0308]: ..."` gives `"e0308"`).
///
/// Returns `None` when there is no marker, the marker is not closed by `]`, or
/// the text between `E` and `]` is not exactly four ASCII digits.
fn extract_error_code(message: &str) -> Option<String> {
    const MARKER: &str = "error[E";
    // Skip the whole marker, including the `E`, so only the digits are checked.
    // `MARKER` and `]` are ASCII, so both slice offsets are char boundaries.
    let after_marker = &message[message.find(MARKER)? + MARKER.len()..];
    let digits = &after_marker[..after_marker.find(']')?];
    (digits.len() == 4 && digits.bytes().all(|b| b.is_ascii_digit()))
        .then(|| format!("e{digits}"))
}
/// Returns every contiguous n-gram of `tokens` for each size in
/// `min_n..=max_n`, smallest size first, with the tokens of each n-gram joined
/// by `_`.
///
/// Sizes larger than `tokens.len()` contribute nothing. A size of 0 is skipped
/// because `slice::windows(0)` panics. A reversed range (`min_n > max_n`)
/// yields an empty vector.
fn generate_ngrams(tokens: &[String], min_n: usize, max_n: usize) -> Vec<String> {
    (min_n.max(1)..=max_n)
        .flat_map(move |n| tokens.windows(n).map(|window| window.join("_")))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unfitted predictor that has learned `entries` in order.
    fn trained(entries: &[(&str, &str, ErrorCategory)]) -> NgramFixPredictor {
        let mut predictor = NgramFixPredictor::new();
        for &(error, fix, category) in entries {
            predictor.learn_pattern(error, fix, category);
        }
        predictor
    }

    /// Same as `trained`, then fits the predictor (failing the test if fitting fails).
    fn fitted(entries: &[(&str, &str, ErrorCategory)]) -> NgramFixPredictor {
        let mut predictor = trained(entries);
        predictor.fit().expect("fit should succeed");
        predictor
    }

    #[test]
    fn test_fix_pattern_creation() {
        let pattern = FixPattern::new(
            "expected i32, found str",
            "Use .parse()",
            ErrorCategory::TypeMismatch,
        );
        assert_eq!(
            (pattern.error_pattern.as_str(), pattern.fix_template.as_str()),
            ("expected i32, found str", "Use .parse()")
        );
        assert_eq!(pattern.category, ErrorCategory::TypeMismatch);
        assert_eq!(pattern.frequency, 1);
        assert!(pattern.success_rate.abs() < 1e-6);
    }

    #[test]
    fn test_fix_pattern_increment() {
        let mut pattern = FixPattern::new("test", "fix", ErrorCategory::Other);
        assert_eq!(pattern.frequency, 1);
        for expected in 2..=3 {
            pattern.increment();
            assert_eq!(pattern.frequency, expected);
        }
    }

    #[test]
    fn test_fix_pattern_success_update() {
        let mut pattern = FixPattern::new("test", "fix", ErrorCategory::Other);
        assert!(pattern.success_rate.abs() < 1e-6);
        pattern.update_success(true);
        assert!(pattern.success_rate > 0.0);
        (0..10).for_each(|_| pattern.update_success(true));
        assert!(pattern.success_rate > 0.5);
    }

    #[test]
    fn test_predictor_creation() {
        let predictor = NgramFixPredictor::new();
        assert!(!predictor.is_fitted());
        assert_eq!(predictor.pattern_count(), 0);
    }

    #[test]
    fn test_learn_single_pattern() {
        let predictor = trained(&[(
            "expected i32, found &str",
            "Convert using .to_string() or .parse()",
            ErrorCategory::TypeMismatch,
        )]);
        assert_eq!(predictor.pattern_count(), 1);
        assert!(!predictor.is_fitted());
    }

    #[test]
    fn test_learn_duplicate_pattern() {
        let predictor = trained(&[("error msg", "fix", ErrorCategory::Other); 2]);
        assert_eq!(predictor.pattern_count(), 1);
        assert_eq!(
            predictor.patterns_for_category(ErrorCategory::Other)[0].frequency,
            2
        );
    }

    #[test]
    fn test_learn_batch() {
        let training: Vec<(String, String, ErrorCategory)> = [
            ("expected i32", "use .parse()", ErrorCategory::TypeMismatch),
            ("cannot borrow", "use .clone()", ErrorCategory::BorrowChecker),
            ("not found", "add use statement", ErrorCategory::MissingImport),
        ]
        .iter()
        .map(|&(error, fix, category)| (error.to_string(), fix.to_string(), category))
        .collect();
        let mut predictor = NgramFixPredictor::new();
        predictor.learn_batch(&training);
        assert_eq!(predictor.pattern_count(), 3);
    }

    #[test]
    fn test_fit_empty_patterns() {
        assert!(NgramFixPredictor::new().fit().is_err());
    }

    #[test]
    fn test_fit_with_patterns() {
        let mut predictor = trained(&[
            ("expected i32", "convert type", ErrorCategory::TypeMismatch),
            ("cannot borrow", "clone value", ErrorCategory::BorrowChecker),
        ]);
        assert!(predictor.fit().is_ok());
        assert!(predictor.is_fitted());
    }

    #[test]
    fn test_predict_without_fit() {
        let predictor = trained(&[("test", "fix", ErrorCategory::Other)]);
        assert!(predictor.predict_fixes("test error", 3).is_empty());
    }

    #[test]
    fn test_predict_basic() {
        let predictor = fitted(&[(
            "expected i32, found str",
            "Use type conversion",
            ErrorCategory::TypeMismatch,
        )]);
        let suggestions = predictor.predict_fixes("expected i32, found string", 3);
        let best = suggestions
            .first()
            .expect("a similar message should yield at least one suggestion");
        assert!(best.confidence > 0.0);
    }

    #[test]
    fn test_predict_ranking() {
        let predictor = fitted(&[
            ("expected i32, found str", "Use .parse()", ErrorCategory::TypeMismatch),
            (
                "expected u32, found string",
                "Use .parse::<u32>()",
                ErrorCategory::TypeMismatch,
            ),
            ("cannot borrow", "Use .clone()", ErrorCategory::BorrowChecker),
        ]);
        let suggestions = predictor.predict_fixes("expected u64, found str", 5);
        if let [first, second, ..] = suggestions.as_slice() {
            assert!(first.confidence >= second.confidence);
        }
    }

    #[test]
    fn test_predict_for_category() {
        let predictor = fitted(&[
            ("type error", "fix type", ErrorCategory::TypeMismatch),
            ("borrow error", "fix borrow", ErrorCategory::BorrowChecker),
        ]);
        let suggestions =
            predictor.predict_for_category("type error", ErrorCategory::TypeMismatch, 3);
        assert!(suggestions
            .iter()
            .all(|s| s.category == ErrorCategory::TypeMismatch));
    }

    #[test]
    fn test_record_feedback() {
        let mut predictor = trained(&[("test error", "test fix", ErrorCategory::Other)]);
        let rate_of =
            |p: &NgramFixPredictor| p.patterns_for_category(ErrorCategory::Other)[0].success_rate;
        let before = rate_of(&predictor);
        predictor.record_feedback("test error", true);
        assert!(rate_of(&predictor) > before);
    }

    #[test]
    fn test_normalize_error() {
        let normalized = normalize_error("error: expected i32, found &str at line 42");
        assert!(!normalized.contains("E"));
        assert!(!normalized.contains("error:"));
        assert!(normalized.contains('n') || normalized.contains('N'));
    }

    #[test]
    fn test_generate_ngrams() {
        let tokens: Vec<String> = ["hello", "world", "rust"]
            .iter()
            .map(|t| t.to_string())
            .collect();
        let ngrams = generate_ngrams(&tokens, 1, 2);
        for expected in ["hello", "hello_world", "world_rust"] {
            assert!(ngrams.iter().any(|ng| ng == expected));
        }
    }

    #[test]
    fn test_similarity_identical() {
        let predictor = NgramFixPredictor::new();
        let sim = predictor.compute_similarity("expected i32 found str", "expected i32 found str");
        assert!((sim - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_similarity_different() {
        let predictor = NgramFixPredictor::new();
        let sim = predictor.compute_similarity(
            "expected i32 found str",
            "completely different error message",
        );
        assert!(sim < 0.5);
    }

    #[test]
    fn test_min_similarity_threshold() {
        let mut predictor = NgramFixPredictor::new().with_min_similarity(0.8);
        predictor.learn_pattern("exact error", "exact fix", ErrorCategory::Other);
        predictor.fit().expect("fit should succeed");
        assert!(predictor.predict_fixes("completely different", 3).is_empty());
    }

    #[test]
    fn test_ngram_range_config() {
        let predictor = NgramFixPredictor::new().with_ngram_range(2, 4);
        assert_eq!(predictor.ngram_range, (2, 4));
        assert!(!predictor.is_fitted());
    }

    #[test]
    fn test_patterns_for_nonexistent_category() {
        assert!(NgramFixPredictor::new()
            .patterns_for_category(ErrorCategory::TypeMismatch)
            .is_empty());
    }

    #[test]
    fn test_fix_suggestion_structure() {
        let predictor = fitted(&[("test error", "test fix", ErrorCategory::TypeMismatch)]);
        let suggestions = predictor.predict_fixes("test error", 1);
        if let Some(s) = suggestions.first() {
            assert!(!s.fix.is_empty());
            assert!(s.confidence > 0.0 && s.confidence <= 1.0);
            assert!(!s.matched_pattern.is_empty());
        }
    }

    #[test]
    fn test_frequency_affects_confidence() {
        let mut entries =
            vec![("frequent error", "frequent fix", ErrorCategory::TypeMismatch); 5];
        entries.push(("rare error", "rare fix", ErrorCategory::TypeMismatch));
        let predictor = fitted(&entries);
        let frequent = predictor.predict_fixes("frequent error", 1);
        let rare = predictor.predict_fixes("rare error", 1);
        if let (Some(f), Some(r)) = (frequent.first(), rare.first()) {
            assert!(f.confidence >= r.confidence);
        }
    }

    #[test]
    fn test_success_rate_affects_confidence() {
        let mut predictor = trained(&[
            ("good pattern", "good fix", ErrorCategory::Other),
            ("bad pattern", "bad fix", ErrorCategory::Other),
        ]);
        for _ in 0..10 {
            predictor.record_feedback("good pattern", true);
            predictor.record_feedback("bad pattern", false);
        }
        predictor.fit().expect("fit should succeed");
        let good = predictor.predict_fixes("good pattern", 1);
        let bad = predictor.predict_fixes("bad pattern", 1);
        if let (Some(g), Some(b)) = (good.first(), bad.first()) {
            assert!(g.confidence >= b.confidence);
        }
    }

    #[test]
    fn test_multiple_categories() {
        let predictor = trained(&[
            ("type error", "type fix", ErrorCategory::TypeMismatch),
            ("borrow error", "borrow fix", ErrorCategory::BorrowChecker),
            ("import error", "import fix", ErrorCategory::MissingImport),
            ("lifetime error", "lifetime fix", ErrorCategory::LifetimeError),
        ]);
        assert_eq!(predictor.pattern_count(), 4);
        for category in [ErrorCategory::TypeMismatch, ErrorCategory::BorrowChecker] {
            assert_eq!(predictor.patterns_for_category(category).len(), 1);
        }
    }
}