use super::patterns::{FixSuggestion, PatternStore};
use super::{ErrorCategory, ErrorFeatures, OracleConfig};
use aprender::online::drift::{DriftDetector, DriftDetectorFactory};
use aprender::prelude::Matrix;
use aprender::tree::RandomForestClassifier;
/// Result of classifying a single compilation error.
///
/// Built with [`Classification::new`] and refined with the
/// [`with_suggestions`](Classification::with_suggestions) and
/// [`with_auto_fix`](Classification::with_auto_fix) builder methods.
#[derive(Debug, Clone)]
pub struct Classification {
    /// Predicted error category.
    pub category: ErrorCategory,
    /// Confidence attached to `category`; compared against a threshold by `with_auto_fix`.
    pub confidence: f64,
    /// Candidate fixes for this error.
    pub suggestions: Vec<FixSuggestion>,
    /// Whether a fix may be applied automatically.
    pub should_auto_fix: bool,
}

impl Classification {
    /// Creates a classification with no suggestions and auto-fix disabled.
    #[must_use]
    pub fn new(category: ErrorCategory, confidence: f64) -> Self {
        Self {
            category,
            confidence,
            suggestions: Vec::new(),
            should_auto_fix: false,
        }
    }

    /// Replaces the suggestion list.
    #[must_use]
    pub fn with_suggestions(mut self, suggestions: Vec<FixSuggestion>) -> Self {
        self.suggestions = suggestions;
        self
    }

    /// Enables auto-fix when `confidence >= threshold` and at least one
    /// suggestion is present; disables it otherwise.
    ///
    /// This inspects the *current* suggestion list, so call it after
    /// [`with_suggestions`](Self::with_suggestions).
    #[must_use]
    pub fn with_auto_fix(mut self, threshold: f64) -> Self {
        self.should_auto_fix = self.confidence >= threshold && !self.suggestions.is_empty();
        self
    }
}
/// A single compiler diagnostic to be classified.
///
/// Only `message` is required; the optional fields are filled in with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CompilationError {
    /// Error code such as `"E0308"`, if the diagnostic carried one.
    pub code: Option<String>,
    /// Diagnostic text.
    pub message: String,
    /// File the diagnostic points at, if known.
    pub file_path: Option<String>,
    /// Line of the diagnostic, if known.
    pub line: Option<u32>,
    /// Column of the diagnostic, if known.
    pub column: Option<u32>,
}

impl CompilationError {
    /// Creates an error with only a message; every optional field is `None`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            code: None,
            message: message.into(),
            file_path: None,
            line: None,
            column: None,
        }
    }

    /// Sets the error code (e.g. `"E0308"`).
    #[must_use]
    pub fn with_code(mut self, code: impl Into<String>) -> Self {
        self.code = Some(code.into());
        self
    }

    /// Sets the source file path.
    #[must_use]
    pub fn with_file(mut self, path: impl Into<String>) -> Self {
        self.file_path = Some(path.into());
        self
    }

    /// Sets both line and column.
    #[must_use]
    pub fn with_location(mut self, line: u32, column: u32) -> Self {
        self.line = Some(line);
        self.column = Some(column);
        self
    }
}
/// Bookkeeping about the most recent training run of a `RuchyOracle`.
#[derive(Debug, Clone, Default)]
pub struct OracleMetadata {
/// Number of samples given to the last completed `RuchyOracle::train` call.
pub sample_count: usize,
/// Accuracy copied from a loaded serialized model by `load_from_serialized`;
/// `train` itself does not compute it.
pub training_accuracy: f64,
/// Model version tag; `train` sets it to `"2.0.0-rf"`.
pub version: String,
/// Training timestamp, if known. Not populated by `RuchyOracle`'s train/load code.
pub trained_at: Option<String>,
}
/// Compilation-error classifier.
///
/// Combines an error-code lookup, a random-forest model, a nearest-neighbour
/// fallback over the stored training rows, and rule-based heuristics.
pub struct RuchyOracle {
/// Similarity and confidence thresholds read by `classify`.
config: OracleConfig,
/// Bookkeeping for the last training run.
metadata: OracleMetadata,
/// Source of fix suggestions, queried on every classification.
pattern_store: PatternStore,
/// Receives one error/no-error observation per `record_result` call.
drift_detector: Box<dyn DriftDetector>,
/// Set by `train`; when true, `classify` uses the model instead of the rules.
is_trained: bool,
/// Random forest from a successful `train` fit; `None` until one succeeds.
classifier: Option<RandomForestClassifier>,
/// Stored feature rows used for k-NN lookup and retraining; `record_error` appends here.
training_features: Vec<Vec<f32>>,
/// Category indices, parallel to `training_features`.
training_labels: Vec<usize>,
}
impl RuchyOracle {
#[must_use]
pub fn new() -> Self {
Self::with_config(OracleConfig::default())
}
#[must_use]
pub fn with_config(config: OracleConfig) -> Self {
Self {
config,
metadata: OracleMetadata::default(),
pattern_store: PatternStore::new(),
drift_detector: DriftDetectorFactory::recommended(),
is_trained: false,
classifier: None,
training_features: Vec::new(),
training_labels: Vec::new(),
}
}
/// Loads a serialized model from `path` and rebuilds an oracle from its stored data.
///
/// # Errors
/// - [`OracleError::ModelNotFound`] if `path` does not exist.
/// - Any error from `SerializedModel::load` (converted with `?`) or from
///   [`RuchyOracle::load_from_serialized`].
pub fn load(path: &std::path::Path) -> Result<Self, OracleError> {
    use crate::oracle::SerializedModel;
    if path.exists() {
        let serialized = SerializedModel::load(path)?;
        let mut oracle = Self::new();
        oracle.load_from_serialized(&serialized)?;
        Ok(oracle)
    } else {
        Err(OracleError::ModelNotFound(path.to_path_buf()))
    }
}

/// Retrains this oracle on the features and labels stored in `model`, then
/// copies the model's recorded accuracy into the metadata.
///
/// # Errors
/// Returns [`OracleError::EmptyTrainingData`] if either stored list is empty,
/// or any error produced by [`RuchyOracle::train`].
pub fn load_from_serialized(
    &mut self,
    model: &crate::oracle::SerializedModel,
) -> Result<(), OracleError> {
    let missing_data = model.training_features.is_empty() || model.training_labels.is_empty();
    if missing_data {
        return Err(OracleError::EmptyTrainingData);
    }
    self.train(&model.training_features, &model.training_labels)?;
    self.metadata.training_accuracy = model.metadata.accuracy;
    Ok(())
}

/// Loads the first existing model from the default model paths, or trains a
/// fresh oracle on the built-in bootstrap examples when none exists.
///
/// # Errors
/// Propagates errors from [`RuchyOracle::load`] or
/// [`RuchyOracle::train_from_examples`].
pub fn load_or_train() -> Result<Self, OracleError> {
    use crate::oracle::ModelPaths;
    let paths = ModelPaths::default();
    match paths.find_existing() {
        Some(model_path) => Self::load(&model_path),
        None => {
            let mut oracle = Self::new();
            oracle.train_from_examples()?;
            Ok(oracle)
        }
    }
}
/// Trains the oracle on `features` (one row per sample) and `labels`
/// (category indices, parallel to `features`).
///
/// The data is stored for k-NN fallback and later retraining. Then a random
/// forest (`RandomForestClassifier::new(10)`, max depth 5, random state 42)
/// is fitted. If fitting fails, the failure is reported on stderr, any forest
/// from an earlier run is discarded, and prediction falls back to k-NN over
/// the new data. The call still returns `Ok`.
///
/// # Errors
/// - [`OracleError::EmptyTrainingData`] if `features` is empty.
/// - [`OracleError::MismatchedData`] if `features` and `labels` differ in length.
/// - [`OracleError::TrainingFailed`] if the rows do not all have the same
///   length, or the feature matrix cannot be built.
///
/// On error the oracle's existing state is left untouched.
pub fn train(&mut self, features: &[Vec<f32>], labels: &[usize]) -> Result<(), OracleError> {
    if features.is_empty() {
        return Err(OracleError::EmptyTrainingData);
    }
    if features.len() != labels.len() {
        return Err(OracleError::MismatchedData {
            features: features.len(),
            labels: labels.len(),
        });
    }
    let n_samples = features.len();
    let n_features = features.first().map_or(0, Vec::len);
    // A ragged row would make the flattened buffer disagree with the matrix
    // shape (panic) or silently shift values between rows; reject it up front.
    if let Some((row, len)) = features
        .iter()
        .map(Vec::len)
        .enumerate()
        .find(|&(_, len)| len != n_features)
    {
        return Err(OracleError::TrainingFailed(format!(
            "row {row} has {len} features, expected {n_features}"
        )));
    }
    let flat_features: Vec<f32> = features.iter().flatten().copied().collect();
    let x = Matrix::from_vec(n_samples, n_features, flat_features).map_err(|_| {
        OracleError::TrainingFailed(format!(
            "cannot build a {n_samples}x{n_features} feature matrix"
        ))
    })?;

    // All validation passed: only now replace the stored data.
    self.training_features = features.to_vec();
    self.training_labels = labels.to_vec();

    let mut rf = RandomForestClassifier::new(10)
        .with_max_depth(5)
        .with_random_state(42);
    if let Err(e) = rf.fit(&x, labels) {
        eprintln!("RandomForest training failed, using k-NN: {e}");
        // A forest from a previous run was fitted on different data; keeping it
        // would make predict_with_model ignore the k-NN fallback announced above.
        self.classifier = None;
    } else {
        self.classifier = Some(rf);
    }
    self.metadata.sample_count = n_samples;
    self.metadata.version = "2.0.0-rf".to_string();
    self.is_trained = true;
    Ok(())
}
/// Trains on the built-in bootstrap examples.
///
/// Each example's message and optional code are turned into a feature row,
/// and its category into a label index.
///
/// # Errors
/// Propagates any error from [`RuchyOracle::train`].
pub fn train_from_examples(&mut self) -> Result<(), OracleError> {
    let examples = self.generate_bootstrap_samples();
    let mut rows = Vec::with_capacity(examples.len());
    let mut targets = Vec::with_capacity(examples.len());
    for (message, code, category) in &examples {
        rows.push(ErrorFeatures::extract(message, code.as_deref()).to_vec());
        targets.push(category.to_index());
    }
    self.train(&rows, &targets)
}
/// Returns the built-in labelled examples used by `train_from_examples`.
///
/// Each entry is `(message, optional error code, expected category)`.
fn generate_bootstrap_samples(&self) -> Vec<(String, Option<String>, ErrorCategory)> {
    let table: &[(&str, Option<&str>, ErrorCategory)] = &[
        (
            "mismatched types: expected `i32`, found `String`",
            Some("E0308"),
            ErrorCategory::TypeMismatch,
        ),
        (
            "expected `&str`, found `String`",
            Some("E0308"),
            ErrorCategory::TypeMismatch,
        ),
        (
            "type mismatch: expected Vec<i32>",
            Some("E0271"),
            ErrorCategory::TypeMismatch,
        ),
        (
            "borrow of moved value: `x`",
            Some("E0382"),
            ErrorCategory::BorrowChecker,
        ),
        (
            "cannot borrow `x` as mutable",
            Some("E0502"),
            ErrorCategory::BorrowChecker,
        ),
        ("value moved here", Some("E0382"), ErrorCategory::BorrowChecker),
        (
            "borrowed value does not live long enough",
            Some("E0597"),
            ErrorCategory::LifetimeError,
        ),
        (
            "lifetime `'a` required",
            Some("E0621"),
            ErrorCategory::LifetimeError,
        ),
        (
            "the trait `Debug` is not implemented",
            Some("E0277"),
            ErrorCategory::TraitBound,
        ),
        (
            "no method named `foo` found",
            Some("E0599"),
            ErrorCategory::TraitBound,
        ),
        (
            "cannot find type `HashMap` in this scope",
            Some("E0433"),
            ErrorCategory::MissingImport,
        ),
        (
            "cannot find value `x` in this scope",
            Some("E0425"),
            ErrorCategory::MissingImport,
        ),
        (
            "failed to resolve: use of undeclared type",
            Some("E0433"),
            ErrorCategory::MissingImport,
        ),
        (
            "use of undeclared type or module",
            Some("E0412"),
            ErrorCategory::MissingImport,
        ),
        (
            "cannot borrow `x` as mutable, as it is not declared as mutable",
            Some("E0596"),
            ErrorCategory::MutabilityError,
        ),
        (
            "cannot assign to `x`, as it is not declared as mutable",
            Some("E0594"),
            ErrorCategory::MutabilityError,
        ),
        // NOTE(review): rustc uses E0658 for unstable-feature use, not for parse
        // errors; confirm this pairing is intentional.
        (
            "expected `;`, found `}`",
            Some("E0658"),
            ErrorCategory::SyntaxError,
        ),
        (
            "this function takes 2 arguments but 1 was supplied",
            Some("E0061"),
            ErrorCategory::SyntaxError,
        ),
        // Module-resolution messages that carry no error code.
        (
            "Module 'scanner' not resolved",
            None,
            ErrorCategory::MissingImport,
        ),
        (
            "Failed to resolve module declaration",
            None,
            ErrorCategory::MissingImport,
        ),
        ("Module 'utils' not found", None, ErrorCategory::MissingImport),
        ("Failed to find module", None, ErrorCategory::MissingImport),
        (
            "no method named `resolve` found for struct",
            Some("E0599"),
            ErrorCategory::TraitBound,
        ),
        (
            "no method named `len` found",
            Some("E0599"),
            ErrorCategory::TraitBound,
        ),
        ("method not found in", Some("E0599"), ErrorCategory::TraitBound),
        // Code-less, lint-style messages, all labelled SyntaxError.
        (
            "item in documentation is missing backticks",
            None,
            ErrorCategory::SyntaxError,
        ),
        (
            "called `map(<f>).unwrap_or(<a>)` on an Option value",
            None,
            ErrorCategory::SyntaxError,
        ),
        ("redundant closure", None, ErrorCategory::SyntaxError),
        (
            "this function has too many arguments",
            None,
            ErrorCategory::SyntaxError,
        ),
        (
            "this argument is passed by value, but not consumed",
            None,
            ErrorCategory::SyntaxError,
        ),
    ];
    table
        .iter()
        .map(|&(message, code, category)| {
            (String::from(message), code.map(String::from), category)
        })
        .collect()
}
/// Classifies a compilation error and attaches fix suggestions.
///
/// If the error carries a code that `ErrorCategory::from_error_code` maps to
/// anything other than `Other`, that category is used with confidence 0.95.
/// Otherwise features are extracted from the message and code, and the
/// category comes from the trained model (or from the rules when the oracle
/// is untrained). Suggestions are queried from the pattern store with
/// `config.similarity_threshold`. Auto-fix is decided against
/// `config.confidence_threshold`.
#[must_use]
pub fn classify(&self, error: &CompilationError) -> Classification {
    let by_code = error
        .code
        .as_ref()
        .map(|code| ErrorCategory::from_error_code(code))
        .filter(|category| *category != ErrorCategory::Other);

    let (category, confidence) = match by_code {
        Some(category) => (category, 0.95),
        None => {
            let features = ErrorFeatures::extract(&error.message, error.code.as_deref());
            if self.is_trained {
                self.predict_with_model(&features)
            } else {
                self.predict_with_rules(&features)
            }
        }
    };

    let suggestions =
        self.pattern_store
            .query(category, &error.message, self.config.similarity_threshold);
    Classification::new(category, confidence)
        .with_suggestions(suggestions)
        .with_auto_fix(self.config.confidence_threshold)
}
fn predict_with_model(&self, features: &ErrorFeatures) -> (ErrorCategory, f64) {
if let Some(ref rf) = self.classifier {
let query = features.to_vec();
let n_features = query.len();
let x = match Matrix::from_vec(1, n_features, query) {
Ok(m) => m,
Err(_) => return self.predict_with_knn(features),
};
let predictions = rf.predict(&x);
if let Some(&label) = predictions.first() {
let category = ErrorCategory::from_index(label).unwrap_or(ErrorCategory::Other);
let proba = rf.predict_proba(&x);
let confidence: f64 = if proba.shape().0 > 0 && proba.shape().1 > 0 {
let row = proba.row(0);
let max_idx = row.argmax();
f64::from(row[max_idx])
} else {
0.8
};
return (category, confidence);
}
}
self.predict_with_knn(features)
}
/// 1-nearest-neighbour prediction over the stored training rows.
///
/// Picks the label of the row closest to the query by Euclidean distance.
/// Ties keep the earliest row. Confidence is `1 / (1 + distance)`.
///
/// Falls back to [`RuchyOracle::predict_with_rules`] when no stored row is
/// comparable to the query. That covers an empty training set and the case
/// where every row yields a distance that is not below `f64::MAX` (length
/// mismatch, NaN, or overflow). Without the fallback, such a query would
/// silently get category index 0.
fn predict_with_knn(&self, features: &ErrorFeatures) -> (ErrorCategory, f64) {
    let query = features.to_vec();
    // Closest comparable row so far, as (distance, label).
    let mut nearest: Option<(f64, usize)> = None;
    for (row, &label) in self.training_features.iter().zip(&self.training_labels) {
        let dist = euclidean_distance(&query, row);
        // euclidean_distance returns f64::MAX for rows of a different length;
        // using it as the initial bound means such rows are never selected.
        let bound = nearest.map_or(f64::MAX, |(best, _)| best);
        if dist < bound {
            nearest = Some((dist, label));
        }
    }
    match nearest {
        Some((dist, label)) => {
            let category = ErrorCategory::from_index(label).unwrap_or(ErrorCategory::Other);
            let confidence = (1.0 / (1.0 + dist)).min(1.0);
            (category, confidence)
        }
        None => self.predict_with_rules(features),
    }
}
/// Rule-based fallback using `ErrorFeatures::predict_category_rules`.
///
/// Confidence is fixed: 0.3 when the rules yield `Other`, 0.7 for any
/// specific category.
fn predict_with_rules(&self, features: &ErrorFeatures) -> (ErrorCategory, f64) {
    let category = features.predict_category_rules();
    let is_specific = category != ErrorCategory::Other;
    (category, if is_specific { 0.7 } else { 0.3 })
}
/// Returns `true` once a call to [`RuchyOracle::train`] has completed successfully.
#[must_use]
pub fn is_trained(&self) -> bool {
    self.is_trained
}

/// Metadata about the last training run.
#[must_use]
pub fn metadata(&self) -> &OracleMetadata {
    &self.metadata
}

/// The active configuration.
#[must_use]
pub fn config(&self) -> &OracleConfig {
    &self.config
}

/// Replaces the configuration; `classify` reads it on its next call.
pub fn set_config(&mut self, config: OracleConfig) {
    self.config = config;
}

/// Returns owned copies of the stored feature rows and their labels.
#[must_use]
pub fn get_training_data(&self) -> (Vec<Vec<f32>>, Vec<usize>) {
    let rows = self.training_features.clone();
    let labels = self.training_labels.clone();
    (rows, labels)
}

/// Feeds one prediction outcome to the drift detector; a mismatch between
/// `predicted` and `actual` is recorded as an error.
pub fn record_result(&mut self, predicted: ErrorCategory, actual: ErrorCategory) {
    self.drift_detector.add_element(predicted != actual);
}

/// The drift detector's current status.
#[must_use]
pub fn drift_status(&self) -> aprender::online::drift::DriftStatus {
    self.drift_detector.detected_change()
}

/// Resets the drift detector.
pub fn reset_drift_detector(&mut self) {
    self.drift_detector.reset();
}

/// Appends a labelled message to the stored training data.
///
/// Features are extracted without an error code. The model is not refitted
/// here; use `should_retrain` / `retrain` for that.
pub fn record_error(&mut self, message: &str, category: ErrorCategory) {
    let row = ErrorFeatures::extract(message, None).to_vec();
    let label = category.to_index();
    self.training_features.push(row);
    self.training_labels.push(label);
}
/// Returns `true` once more than 100 samples have been added (via
/// [`RuchyOracle::record_error`]) since the last successful
/// [`RuchyOracle::train`].
///
/// The baseline is `metadata.sample_count`, the size of the data set the
/// current model was trained on. The answer therefore resets after a retrain
/// instead of staying `true` and triggering a retrain on every call. For an
/// oracle trained on the 30 built-in bootstrap examples, it becomes `true`
/// after 101 recorded errors.
#[must_use]
pub fn should_retrain(&self) -> bool {
    const RETRAIN_THRESHOLD: usize = 100;
    let new_samples = self
        .training_labels
        .len()
        .saturating_sub(self.metadata.sample_count);
    new_samples > RETRAIN_THRESHOLD
}
/// Retrains on all stored data: the original training set plus anything added
/// with [`RuchyOracle::record_error`].
///
/// # Errors
/// Returns [`OracleError::EmptyTrainingData`] if nothing is stored, or any
/// error from [`RuchyOracle::train`]. In both cases the stored data is kept.
pub fn retrain(&mut self) -> Result<(), OracleError> {
    if self.training_features.is_empty() {
        return Err(OracleError::EmptyTrainingData);
    }
    // Move the data out rather than cloning it: `train` needs `&mut self`
    // while reading the slices, and it stores its own copy on success.
    let features = std::mem::take(&mut self.training_features);
    let labels = std::mem::take(&mut self.training_labels);
    let result = self.train(&features, &labels);
    if result.is_err() {
        // `train` reports errors before storing anything, so restore the data.
        self.training_features = features;
        self.training_labels = labels;
    }
    result
}
/// Returns `true` when the drift detector currently reports `DriftStatus::Drift`.
#[must_use]
pub fn drift_detected(&self) -> bool {
    use aprender::online::drift::DriftStatus;
    let status = self.drift_detector.detected_change();
    matches!(status, DriftStatus::Drift)
}
/// Extracts `error[EXXXX]: message` diagnostics from rustc's stderr output.
///
/// Each match becomes a production [`super::Sample`]. Its message is
/// normalized to `error[EXXXX]: <message>`, and its category comes from
/// `ErrorCategory::from_error_code`. Trailing whitespace is stripped from the
/// message, including the `\r` left by CRLF line endings.
#[must_use]
pub fn parse_rustc_errors(stderr: &str) -> Vec<super::Sample> {
    use super::{Sample, SampleSource};
    use regex::Regex;
    let error_re = Regex::new(r"error\[E(\d{4})\]:\s*(.+?)(?:\n|$)")
        .expect("error-line regex is a valid constant pattern");
    error_re
        .captures_iter(stderr)
        .map(|cap| {
            let code = format!("E{}", &cap[1]);
            // `.` also matches `\r`, so CRLF output would otherwise keep it.
            let message = format!("error[{}]: {}", code, cap[2].trim_end());
            let category = ErrorCategory::from_error_code(&code);
            Sample::new(message, Some(code), category).with_source(SampleSource::Production)
        })
        .collect()
}
/// Generates `count` synthetic samples by cycling round-robin through a fixed
/// set of rustc-style templates.
///
/// Each message gets a ` (variant N)` suffix, where `N` is the round in which
/// the template was used (0 for the first pass over the templates, 1 for the
/// second, and so on). Repeated uses of a template therefore yield distinct
/// messages. Templates with an empty code produce samples without a code.
#[must_use]
pub fn generate_synthetic_samples(count: usize) -> Vec<super::Sample> {
    use super::{Sample, SampleSource};
    let templates = [
        (
            "error[E0308]: mismatched types",
            "E0308",
            ErrorCategory::TypeMismatch,
        ),
        (
            "error[E0271]: type mismatch resolving",
            "E0271",
            ErrorCategory::TypeMismatch,
        ),
        (
            "error[E0382]: borrow of moved value",
            "E0382",
            ErrorCategory::BorrowChecker,
        ),
        (
            "error[E0502]: cannot borrow as mutable",
            "E0502",
            ErrorCategory::BorrowChecker,
        ),
        (
            "error[E0597]: borrowed value does not live long enough",
            "E0597",
            ErrorCategory::LifetimeError,
        ),
        (
            "error[E0621]: explicit lifetime required",
            "E0621",
            ErrorCategory::LifetimeError,
        ),
        (
            "error[E0277]: the trait bound is not satisfied",
            "E0277",
            ErrorCategory::TraitBound,
        ),
        (
            "error[E0599]: no method named",
            "E0599",
            ErrorCategory::TraitBound,
        ),
        (
            "error[E0433]: failed to resolve",
            "E0433",
            ErrorCategory::MissingImport,
        ),
        (
            "error[E0412]: cannot find type",
            "E0412",
            ErrorCategory::MissingImport,
        ),
        (
            "error[E0596]: cannot borrow as mutable",
            "E0596",
            ErrorCategory::MutabilityError,
        ),
        (
            "error[E0594]: cannot assign to immutable",
            "E0594",
            ErrorCategory::MutabilityError,
        ),
        (
            "error[E0658]: syntax error",
            "E0658",
            ErrorCategory::SyntaxError,
        ),
        (
            "error[E0061]: wrong number of arguments",
            "E0061",
            ErrorCategory::SyntaxError,
        ),
        ("error: unknown error", "", ErrorCategory::Other),
        ("error: internal compiler error", "", ErrorCategory::Other),
    ];
    let n_templates = templates.len();
    let mut samples = Vec::with_capacity(count);
    for (i, &(msg, code, cat)) in templates.iter().cycle().take(count).enumerate() {
        // Sample i uses template i % n_templates in round i / n_templates. The
        // round (not i % count-per-template) keeps each template's variants distinct.
        let round = i / n_templates;
        let varied_msg = format!("{msg} (variant {round})");
        let code_opt = if code.is_empty() {
            None
        } else {
            Some(code.to_string())
        };
        samples.push(Sample::new(varied_msg, code_opt, cat).with_source(SampleSource::Synthetic));
    }
    samples
}
}
impl Default for RuchyOracle {
fn default() -> Self {
Self::new()
}
}
/// Euclidean distance between two feature rows, computed in `f64`.
///
/// Rows of different lengths are treated as incomparable, and the function
/// returns `f64::MAX` for them.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f64 {
    if a.len() == b.len() {
        let squared: f64 = a
            .iter()
            .zip(b)
            .map(|(&p, &q)| {
                let delta = f64::from(p) - f64::from(q);
                delta.powi(2)
            })
            .sum();
        squared.sqrt()
    } else {
        f64::MAX
    }
}
/// Errors produced while loading or training a `RuchyOracle`.
#[derive(Debug, Clone)]
pub enum OracleError {
    /// No model file exists at the given path.
    ModelNotFound(std::path::PathBuf),
    /// Training was requested with no samples.
    EmptyTrainingData,
    /// The number of feature rows and labels differ.
    MismatchedData { features: usize, labels: usize },
    /// An I/O failure, described by a message.
    IoError(String),
    /// Training could not be performed, described by a message.
    TrainingFailed(String),
}

impl std::fmt::Display for OracleError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ModelNotFound(path) => write!(f, "Oracle model not found: {}", path.display()),
            Self::EmptyTrainingData => f.write_str("Cannot train on empty data"),
            Self::MismatchedData { features, labels } => {
                write!(f, "Mismatched data: {features} features, {labels} labels")
            }
            Self::IoError(msg) => write!(f, "IO error: {msg}"),
            Self::TrainingFailed(msg) => write!(f, "Training failed: {msg}"),
        }
    }
}

impl std::error::Error for OracleError {}
#[cfg(test)]
mod tests {
    use super::*;

    // ----- Classification -------------------------------------------------

    #[test]
    fn test_classification_new() {
        let result = Classification::new(ErrorCategory::TypeMismatch, 0.95);
        assert_eq!(result.category, ErrorCategory::TypeMismatch);
        assert!((result.confidence - 0.95).abs() < f64::EPSILON);
        assert!(result.suggestions.is_empty());
        assert!(!result.should_auto_fix);
    }

    #[test]
    fn test_classification_with_auto_fix_above_threshold() {
        let result = Classification::new(ErrorCategory::TypeMismatch, 0.95)
            .with_suggestions(vec![FixSuggestion::new("add .to_string()")])
            .with_auto_fix(0.85);
        assert!(result.should_auto_fix);
    }

    #[test]
    fn test_classification_with_auto_fix_below_threshold() {
        let result = Classification::new(ErrorCategory::TypeMismatch, 0.80)
            .with_suggestions(vec![FixSuggestion::new("add .to_string()")])
            .with_auto_fix(0.85);
        assert!(!result.should_auto_fix);
    }

    #[test]
    fn test_classification_no_auto_fix_without_suggestions() {
        // High confidence alone is not enough: auto-fix also needs a suggestion.
        let result = Classification::new(ErrorCategory::TypeMismatch, 0.95).with_auto_fix(0.85);
        assert!(!result.should_auto_fix);
    }

    // ----- CompilationError -----------------------------------------------

    #[test]
    fn test_compilation_error_new() {
        let err = CompilationError::new("mismatched types");
        assert_eq!(err.message, "mismatched types");
        assert_eq!(err.code, None);
    }

    #[test]
    fn test_compilation_error_with_code() {
        let err = CompilationError::new("mismatched types").with_code("E0308");
        assert_eq!(err.code.as_deref(), Some("E0308"));
    }

    #[test]
    fn test_compilation_error_with_location() {
        let err = CompilationError::new("error")
            .with_file("main.rs")
            .with_location(10, 5);
        assert_eq!(err.file_path.as_deref(), Some("main.rs"));
        assert_eq!((err.line, err.column), (Some(10), Some(5)));
    }

    // ----- Oracle construction and training ---------------------------------

    #[test]
    fn test_oracle_new() {
        let oracle = RuchyOracle::new();
        assert!(!oracle.is_trained());
        assert_eq!(oracle.metadata().sample_count, 0);
    }

    #[test]
    fn test_oracle_with_config() {
        let oracle = RuchyOracle::with_config(OracleConfig {
            confidence_threshold: 0.90,
            ..Default::default()
        });
        assert!((oracle.config().confidence_threshold - 0.90).abs() < f64::EPSILON);
    }

    #[test]
    fn test_oracle_train_empty_data() {
        let mut oracle = RuchyOracle::new();
        let outcome = oracle.train(&[], &[]);
        assert!(matches!(outcome, Err(OracleError::EmptyTrainingData)));
    }

    #[test]
    fn test_oracle_train_mismatched_data() {
        let mut oracle = RuchyOracle::new();
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let targets = vec![0];
        let outcome = oracle.train(&rows, &targets);
        assert!(matches!(outcome, Err(OracleError::MismatchedData { .. })));
    }

    #[test]
    fn test_oracle_train_success() {
        let mut oracle = RuchyOracle::new();
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let targets = vec![0, 1];
        assert!(oracle.train(&rows, &targets).is_ok());
        assert!(oracle.is_trained());
        assert_eq!(oracle.metadata().sample_count, 2);
    }

    // ----- Classification through the oracle --------------------------------

    #[test]
    fn test_oracle_classify_type_mismatch() {
        let mut oracle = RuchyOracle::new();
        oracle.train_from_examples().expect("bootstrap training");
        let err = CompilationError::new("mismatched types: expected `i32`, found `String`")
            .with_code("E0308");
        let result = oracle.classify(&err);
        assert_eq!(result.category, ErrorCategory::TypeMismatch);
        assert!(result.confidence > 0.0);
    }

    #[test]
    fn test_oracle_classify_borrow_checker() {
        let mut oracle = RuchyOracle::new();
        oracle.train_from_examples().expect("bootstrap training");
        let err = CompilationError::new("borrow of moved value").with_code("E0382");
        assert_eq!(oracle.classify(&err).category, ErrorCategory::BorrowChecker);
    }

    #[test]
    fn test_oracle_classify_untrained_fallback() {
        let oracle = RuchyOracle::new();
        let err = CompilationError::new("mismatched types").with_code("E0308");
        let result = oracle.classify(&err);
        assert_eq!(result.category, ErrorCategory::TypeMismatch);
        assert!((result.confidence - 0.95).abs() < 0.01);
    }

    // ----- euclidean_distance ------------------------------------------------

    #[test]
    fn test_euclidean_distance_same() {
        let v = vec![1.0, 2.0, 3.0];
        assert!(euclidean_distance(&v, &v).abs() < f64::EPSILON);
    }

    #[test]
    fn test_euclidean_distance_different() {
        let d = euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((d - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_euclidean_distance_different_lengths() {
        assert_eq!(euclidean_distance(&[1.0, 2.0], &[1.0, 2.0, 3.0]), f64::MAX);
    }

    // ----- Errors, defaults, drift -------------------------------------------

    #[test]
    fn test_oracle_error_display() {
        assert_eq!(
            OracleError::EmptyTrainingData.to_string(),
            "Cannot train on empty data"
        );
        let mismatch = OracleError::MismatchedData {
            features: 10,
            labels: 5,
        };
        assert!(mismatch.to_string().contains("10 features"));
    }

    #[test]
    fn test_oracle_default() {
        assert!(!RuchyOracle::default().is_trained());
    }

    #[test]
    fn test_oracle_record_result() {
        // Smoke test: recording a correct prediction must not panic.
        let mut oracle = RuchyOracle::new();
        oracle.record_result(ErrorCategory::TypeMismatch, ErrorCategory::TypeMismatch);
    }
}