use std::collections::HashMap;
use std::path::{Path, PathBuf};
use aprender::format::{self, Compression, ModelType, SaveOptions};
use aprender::online::drift::{DriftDetector, DriftStats, DriftStatus, ADWIN};
use aprender::primitives::Matrix;
use aprender::tree::RandomForestClassifier;
use serde::{Deserialize, Serialize};
pub mod autofixer;
pub mod automl_tuning;
pub mod citl_fixer;
pub mod classifier;
#[cfg(feature = "training")]
pub mod corpus_citl;
#[cfg(feature = "training")]
pub mod data_store;
pub mod depyler_training;
pub mod estimator;
pub mod features;
pub mod github_corpus;
#[cfg(feature = "api-fallback")]
pub mod hybrid;
pub mod moe_oracle;
#[cfg(feature = "training")]
pub mod acceleration_pipeline;
pub mod ast_embeddings; pub mod corpus_extract;
pub mod curriculum;
#[cfg(feature = "training")]
pub mod distillation;
#[cfg(feature = "training")]
pub mod error_patterns;
#[cfg(feature = "training")]
pub mod gnn_encoder;
pub mod graph_corpus;
pub mod hansei;
pub mod hybrid_retrieval;
pub mod ngram;
#[cfg(feature = "training")]
pub mod oip_export;
#[cfg(feature = "training")]
pub mod oracle_lineage; pub mod params_persistence;
pub mod patterns;
#[cfg(feature = "training")]
pub mod query_loop;
pub mod self_supervised;
pub mod synthetic;
#[cfg(feature = "training")]
pub mod tarantula;
#[cfg(feature = "training")]
pub mod tarantula_bridge;
#[cfg(feature = "training")]
pub mod tarantula_corpus;
pub mod tfidf;
pub mod training;
pub mod tuning;
pub mod unified_training;
pub mod utol; pub mod verificar_integration;
pub use autofixer::{AutoFixer, FixContext, FixResult, TransformRule};
pub use automl_tuning::{automl_full, automl_optimize, automl_quick, AutoMLConfig, AutoMLResult};
pub use citl_fixer::{CITLFixer, CITLFixerConfig, IterativeFixResult};
#[cfg(feature = "training")]
pub use corpus_citl::{CorpusCITL, IngestionStats};
pub use estimator::{message_to_features, samples_to_features, OracleEstimator};
pub use graph_corpus::{
analyze_graph_corpus, build_graph_corpus, convert_to_training_samples,
load_vectorized_failures, GraphCorpusStats, VectorizedFailure,
};
pub use params_persistence::{
default_params_path, load_params, params_exist, save_params, OptimizedParams,
};
pub use synthetic::{
generate_synthetic_corpus, generate_synthetic_corpus_sized, SyntheticConfig, SyntheticGenerator,
};
pub use tuning::{find_best_config, quick_tune, TuningConfig, TuningResult};
#[cfg(test)]
mod proptests;
pub use classifier::{ErrorCategory, ErrorClassifier};
pub use features::ErrorFeatures;
pub use hansei::{
CategorySummary, HanseiConfig, HanseiReport, IssueSeverity, TranspileHanseiAnalyzer,
TranspileIssue, TranspileOutcome, Trend,
};
#[cfg(feature = "api-fallback")]
pub use hybrid::{
HybridConfig, HybridTranspiler, PatternComplexity, Strategy, TrainingDataCollector,
TranslationPair, TranspileError, TranspileResult, TranspileStats,
};
pub use hybrid_retrieval::{reciprocal_rank_fusion, Bm25Scorer, HybridRetriever, RrfResult};
pub use ngram::{FixPattern, FixSuggestion, NgramFixPredictor};
#[cfg(feature = "training")]
pub use oracle_lineage::OracleLineage;
pub use patterns::{CodeTransform, FixTemplate, FixTemplateRegistry};
pub use tfidf::{CombinedFeatureExtractor, TfidfConfig, TfidfFeatureExtractor};
pub use training::{TrainingDataset, TrainingSample};
pub use depyler_training::{
classify_with_moe, load_real_corpus, train_moe_on_real_corpus, train_moe_oracle,
};
pub use moe_oracle::{ExpertDomain, MoeClassificationResult, MoeOracle, MoeOracleConfig};
#[cfg(feature = "training")]
pub use query_loop::{
apply_simple_diff, auto_fix_loop, AutoFixResult, ErrorContext, OracleMetrics, OracleQueryError,
OracleQueryLoop, OracleStats, OracleSuggestion, ParseRustErrorCodeError, QueryLoopConfig,
RustErrorCode,
};
pub use github_corpus::{
analyze_corpus, build_github_corpus, convert_oip_to_depyler, get_moe_samples_from_oip,
load_oip_training_data, CorpusStats, OipDefectCategory, OipTrainingDataset, OipTrainingExample,
};
pub use unified_training::{
build_default_unified_corpus, build_unified_corpus, build_unified_corpus_with_oip,
print_merge_stats, MergeStats, UnifiedTrainingConfig, UnifiedTrainingResult,
};
#[cfg(feature = "training")]
pub use tarantula::{
FixPriority, SuspiciousTranspilerDecision, TarantulaAnalyzer, TarantulaResult,
TranspilerDecision, TranspilerDecisionRecord,
};
#[cfg(feature = "training")]
pub use tarantula_corpus::{CorpusAnalysisReport, CorpusAnalyzer, TranspilationResult};
#[cfg(feature = "training")]
pub use tarantula_bridge::{
category_to_decision, decision_to_record, decisions_to_records, infer_decisions_from_error,
synthetic_decisions_from_errors,
};
#[cfg(feature = "training")]
pub use error_patterns::{
CorpusEntry, ErrorPattern, ErrorPatternConfig, ErrorPatternLibrary, ErrorPatternStats,
GoldenTraceEntry,
};
pub use curriculum::{
classify_error_difficulty, classify_from_category, CurriculumEntry, CurriculumScheduler,
CurriculumStats, DifficultyLevel,
};
#[cfg(feature = "training")]
pub use distillation::{
DistillationConfig, DistillationStats, ExtractedPattern, KnowledgeDistiller, LlmFixExample,
};
#[cfg(feature = "training")]
pub use gnn_encoder::{
infer_decision_from_match, map_error_category, DepylerGnnEncoder, GnnEncoderConfig,
GnnEncoderStats, SimilarPattern, StructuralPattern,
};
pub use ast_embeddings::{
AstEmbedder, AstEmbedding, AstEmbeddingConfig, CombinedEmbeddingExtractor, CombinedFeatures,
PathContext,
};
#[cfg(feature = "training")]
pub use oip_export::{
export_to_jsonl, BatchExporter, DepylerExport, ErrorCodeClass, ExportStats, SpanInfo,
SuggestionInfo,
};
#[cfg(feature = "training")]
pub use acceleration_pipeline::{
AccelerationPipeline, AnalysisResult, FixSource, PipelineConfig, PipelineStats,
};
/// Errors produced by the oracle while training, persisting, or classifying.
#[derive(Debug, thiserror::Error)]
pub enum OracleError {
    /// The random-forest model failed (fitting, saving, or loading).
    #[error("Model error: {0}")]
    Model(String),
    /// Failure while extracting features from an error.
    #[error("Feature extraction error: {0}")]
    Feature(String),
    /// The classifier could not produce a usable prediction.
    #[error("Classification error: {0}")]
    Classification(String),
    /// An underlying I/O operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
/// Result alias whose error type is [`OracleError`].
pub type Result<T> = std::result::Result<T, OracleError>;
/// Outcome of classifying a single compiler error.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ClassificationResult {
    /// Predicted error category.
    pub category: ErrorCategory,
    /// Confidence attached to the prediction (the oracle currently reports a fixed value).
    pub confidence: f32,
    /// Primary fix suggestion for the category, if a template exists.
    pub suggested_fix: Option<String>,
    /// Additional fix suggestions for the same category.
    pub related_patterns: Vec<String>,
}
/// Hyperparameters used to build the oracle's random-forest classifier.
#[derive(Debug, Clone)]
pub struct OracleConfig {
    /// Number of trees in the forest.
    pub n_estimators: usize,
    /// Maximum depth of each tree.
    pub max_depth: usize,
    /// Seed handed to the classifier; when `None`, no seed is set.
    pub random_state: Option<u64>,
}

impl Default for OracleConfig {
    /// 100 trees of depth at most 10, seeded with 42 so training is reproducible.
    fn default() -> Self {
        let (n_estimators, max_depth) = (100, 10);
        Self {
            random_state: Some(42),
            max_depth,
            n_estimators,
        }
    }
}
/// Compile-error classifier for transpiled code: a random forest plus per-category
/// fix templates and an ADWIN detector that tracks drift in prediction quality.
pub struct Oracle {
    /// Random forest mapping feature rows to category indices.
    classifier: RandomForestClassifier,
    // Stored for reference only: nothing reads it after construction, hence the allow.
    #[allow(dead_code)]
    config: OracleConfig,
    /// Category for each class id; the classifier's prediction indexes into this list.
    categories: Vec<ErrorCategory>,
    /// Fix suggestions per category; the first entry is the primary suggestion.
    fix_templates: HashMap<ErrorCategory, Vec<String>>,
    /// Drift detector fed with prediction outcomes via `observe_prediction`.
    adwin_detector: ADWIN,
}
/// File name of the cached oracle model.
const DEFAULT_MODEL_NAME: &str = "depyler_oracle.apr";

/// Finds the nearest directory containing a `Cargo.toml`.
///
/// The working directory and at most four of its ancestors are probed (five levels
/// in total); the search also stops early at the filesystem root. If the working
/// directory cannot be determined, only the relative path `Cargo.toml` is probed.
fn find_project_root() -> Option<PathBuf> {
    const MAX_LEVELS: usize = 5;
    let start = std::env::current_dir().unwrap_or_default();
    start
        .ancestors()
        .take(MAX_LEVELS)
        .find(|dir| dir.join("Cargo.toml").exists())
        .map(Path::to_path_buf)
}
/// Lists the `.rs` and `.json` entries directly inside `dir` (not recursive).
///
/// An unreadable directory yields an empty list and unreadable entries are skipped.
/// The order follows `read_dir` and is therefore unspecified; callers sort.
fn collect_corpus_files(dir: &Path) -> Vec<PathBuf> {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return Vec::new(),
    };
    let mut files = Vec::new();
    for entry in entries.filter_map(|entry| entry.ok()) {
        let path = entry.path();
        let is_corpus_file = matches!(
            path.extension().and_then(|ext| ext.to_str()),
            Some("rs" | "json")
        );
        if is_corpus_file {
            files.push(path);
        }
    }
    files
}
/// Collects the files that make up the oracle's training corpus.
///
/// Looks in `crates/depyler-oracle/src`, `verificar/corpus` and `training_data` under
/// the project root, taking only `.rs` and `.json` files directly inside each one.
/// The list is sorted so consumers such as the corpus hash see a stable order.
/// Returns an empty list when no project root is found.
#[must_use]
pub fn get_training_corpus_paths() -> Vec<PathBuf> {
    const CORPUS_SUBDIRS: [&str; 3] = [
        "crates/depyler-oracle/src",
        "verificar/corpus",
        "training_data",
    ];
    let root = match find_project_root() {
        Some(root) => root,
        None => return Vec::new(),
    };
    let mut paths: Vec<PathBuf> = Vec::new();
    for subdir in CORPUS_SUBDIRS {
        let dir = root.join(subdir);
        if dir.exists() {
            paths.extend(collect_corpus_files(&dir));
        }
    }
    paths.sort();
    paths
}
impl Oracle {
    /// ADWIN confidence parameter (`delta`) used for newly built and loaded oracles.
    const DEFAULT_ADWIN_DELTA: f64 = 0.002;

    /// Returns where the cached model file lives.
    ///
    /// This is `depyler_oracle.apr` inside the nearest directory (the working directory
    /// or one of its four closest ancestors) that contains a `Cargo.toml`. When no such
    /// directory is found, the bare file name is returned, i.e. a path relative to the
    /// working directory.
    #[must_use]
    pub fn default_model_path() -> PathBuf {
        // Reuse the corpus discovery's root search so the model cache and the corpus
        // hash always refer to the same project.
        find_project_root().map_or_else(
            || PathBuf::from(DEFAULT_MODEL_NAME),
            |root| root.join(DEFAULT_MODEL_NAME),
        )
    }

    /// Returns a ready-to-use oracle, reusing the cached model when nothing changed.
    ///
    /// The model lineage records the commit SHA and corpus hash of each training run.
    /// When neither has changed and the cached model at [`Self::default_model_path`]
    /// loads, that model is returned. Otherwise a fresh model is trained on the
    /// verificar, depyler and synthetic corpora, written to the cache, and recorded in
    /// the lineage. Failures to load or save the cache or the lineage are reported on
    /// stderr and do not abort the call.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Model`] if fitting the classifier fails.
    #[cfg(feature = "training")]
    pub fn load_or_train() -> Result<Self> {
        let model_path = Self::default_model_path();
        let lineage_path = OracleLineage::default_lineage_path();
        let current_sha = OracleLineage::get_current_commit_sha();
        let corpus_paths = get_training_corpus_paths();
        let current_corpus_hash = OracleLineage::compute_corpus_hash(&corpus_paths);
        let mut lineage = match OracleLineage::load(&lineage_path) {
            Ok(l) => l,
            Err(e) => {
                eprintln!("Warning: Failed to load lineage: {e}. Starting fresh...");
                OracleLineage::new()
            }
        };
        // Borrow here; ownership of the SHA and hash moves into `record_training` below.
        let needs_retrain = lineage.needs_retraining(&current_sha, &current_corpus_hash);
        if needs_retrain && lineage.model_count() > 0 {
            eprintln!("📊 Oracle: Codebase changes detected, triggering retraining...");
        } else if needs_retrain {
            eprintln!("📊 Oracle: No training history found, will train fresh...");
        }
        if !needs_retrain && model_path.exists() {
            match Self::load(&model_path) {
                Ok(oracle) => {
                    eprintln!("📊 Oracle: Loaded cached model (no changes detected)");
                    return Ok(oracle);
                }
                Err(e) => {
                    eprintln!("Warning: Failed to load cached model: {e}. Retraining...");
                }
            }
        }

        let mut dataset = verificar_integration::build_verificar_corpus();
        let depyler_corpus = depyler_training::build_combined_corpus();
        for sample in depyler_corpus.samples() {
            dataset.add(sample.clone());
        }
        let synthetic_corpus = synthetic::generate_synthetic_corpus();
        for sample in synthetic_corpus.samples() {
            dataset.add(sample.clone());
        }
        let sample_count = dataset.samples().len();
        let (features, labels_vec) = samples_to_features(dataset.samples());
        let labels: Vec<usize> = labels_vec.as_slice().iter().map(|&x| x as usize).collect();

        let mut oracle = Self::new();
        oracle.train(&features, &labels)?;
        if let Err(e) = oracle.save(&model_path) {
            eprintln!(
                "Warning: Failed to cache model to {}: {e}",
                model_path.display()
            );
        }

        // NOTE(review): 0.85 is a fixed placeholder, not a measured accuracy. No
        // held-out evaluation happens here, so lineage regression checks compare
        // constants. Replace with a real validation score.
        let model_id = lineage.record_training(
            current_sha,
            current_corpus_hash,
            sample_count,
            0.85,
        );
        if let Some((reason, delta)) = lineage.find_regression() {
            eprintln!(
                "⚠️ Oracle: Regression detected! Accuracy dropped by {:.2}% ({})",
                delta.abs() * 100.0,
                reason
            );
        }
        if let Err(e) = lineage.save(&lineage_path) {
            eprintln!("Warning: Failed to save lineage: {e}");
        } else {
            eprintln!(
                "📊 Oracle: Training complete ({} samples), lineage recorded as {}",
                sample_count, model_id
            );
        }
        Ok(oracle)
    }

    /// Creates an untrained oracle using [`OracleConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(OracleConfig::default())
    }

    /// Creates an untrained oracle whose random forest is built from `config`.
    #[must_use]
    pub fn with_config(config: OracleConfig) -> Self {
        let mut classifier =
            RandomForestClassifier::new(config.n_estimators).with_max_depth(config.max_depth);
        if let Some(seed) = config.random_state {
            classifier = classifier.with_random_state(seed);
        }
        Self::from_parts(classifier, config)
    }

    /// Wraps `classifier` with the default category table, fix templates and a fresh
    /// drift detector. Shared by `with_config` and `load` so both construction paths
    /// stay in sync.
    fn from_parts(classifier: RandomForestClassifier, config: OracleConfig) -> Self {
        Self {
            classifier,
            config,
            categories: Self::default_categories(),
            fix_templates: Self::default_fix_templates(),
            adwin_detector: ADWIN::with_delta(Self::DEFAULT_ADWIN_DELTA),
        }
    }

    /// Category for each class id; the classifier's predicted index selects the entry.
    fn default_categories() -> Vec<ErrorCategory> {
        vec![
            ErrorCategory::TypeMismatch,
            ErrorCategory::BorrowChecker,
            ErrorCategory::MissingImport,
            ErrorCategory::SyntaxError,
            ErrorCategory::LifetimeError,
            ErrorCategory::TraitBound,
            ErrorCategory::Other,
        ]
    }

    /// Built-in fix suggestions per category; the first entry is the primary one.
    fn default_fix_templates() -> HashMap<ErrorCategory, Vec<String>> {
        let mut templates = HashMap::new();
        templates.insert(
            ErrorCategory::TypeMismatch,
            vec![
                "Convert type using `.into()` or `as`".to_string(),
                "Check function signature for expected type".to_string(),
                "Use type annotation to clarify".to_string(),
            ],
        );
        templates.insert(
            ErrorCategory::BorrowChecker,
            vec![
                "Clone the value instead of borrowing".to_string(),
                "Use a reference (&) instead of moving".to_string(),
                "Introduce a scope to limit borrow lifetime".to_string(),
            ],
        );
        templates.insert(
            ErrorCategory::MissingImport,
            vec![
                "Add `use` statement for the missing type".to_string(),
                "Check crate dependencies in Cargo.toml".to_string(),
            ],
        );
        templates.insert(
            ErrorCategory::SyntaxError,
            vec![
                "Check for missing semicolons or braces".to_string(),
                "Verify function/struct syntax".to_string(),
            ],
        );
        templates.insert(
            ErrorCategory::LifetimeError,
            vec![
                "Add explicit lifetime annotation".to_string(),
                "Use 'static lifetime for owned data".to_string(),
                "Consider using Rc/Arc for shared ownership".to_string(),
            ],
        );
        templates.insert(
            ErrorCategory::TraitBound,
            vec![
                "Implement the required trait".to_string(),
                "Add trait bound to generic parameter".to_string(),
                "Use a wrapper type that implements the trait".to_string(),
            ],
        );
        templates.insert(
            ErrorCategory::Other,
            vec!["Review the full error message for specifics".to_string()],
        );
        templates
    }

    /// Fits the classifier on `features` with class-id `labels`.
    ///
    /// Each label is an index into the oracle's category table.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Model`] if the random forest fails to fit.
    pub fn train(&mut self, features: &Matrix<f32>, labels: &[usize]) -> Result<()> {
        self.classifier
            .fit(features, labels)
            .map_err(|e| OracleError::Model(e.to_string()))?;
        Ok(())
    }

    /// Classifies a raw compiler error message.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Classification`] if the classifier yields no prediction.
    pub fn classify_message(&self, message: &str) -> Result<ClassificationResult> {
        let feature_matrix = message_to_features(message);
        let predictions = self.classifier.predict(&feature_matrix);
        self.build_classification_result(predictions)
    }

    /// Classifies pre-extracted [`ErrorFeatures`].
    ///
    /// The features are placed after zeroed error-code and keyword slots so the row
    /// has the same width as the message-based feature layout.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Classification`] if the classifier yields no prediction.
    #[deprecated(since = "3.22.0", note = "Use classify_message for better accuracy")]
    pub fn classify(&self, features: &ErrorFeatures) -> Result<ClassificationResult> {
        let error_features = features.to_vec();
        let offset = estimator::feature_config::ERROR_CODES.len()
            + estimator::feature_config::KEYWORDS.len();
        let n_total = offset + ErrorFeatures::DIM;
        let mut full_features = vec![0.0f32; n_total];
        // `zip` stops at the shorter side, so a `to_vec` length that differs from
        // `ErrorFeatures::DIM` cannot index past the end of the row.
        for (slot, &val) in full_features[offset..]
            .iter_mut()
            .zip(error_features.iter())
        {
            *slot = val;
        }
        let feature_matrix = Matrix::from_vec(1, n_total, full_features)
            .expect("Feature matrix dimensions are correct");
        let predictions = self.classifier.predict(&feature_matrix);
        self.build_classification_result(predictions)
    }

    /// Turns the first prediction into a [`ClassificationResult`] with fix templates.
    ///
    /// Out-of-range class ids map to [`ErrorCategory::Other`].
    fn build_classification_result(&self, predictions: Vec<usize>) -> Result<ClassificationResult> {
        let Some(&pred_idx) = predictions.first() else {
            return Err(OracleError::Classification(
                "No prediction produced".to_string(),
            ));
        };
        let category = self
            .categories
            .get(pred_idx)
            .copied()
            .unwrap_or(ErrorCategory::Other);
        let (suggested_fix, related_patterns) = match self.fix_templates.get(&category) {
            Some(fixes) => (
                fixes.first().cloned(),
                fixes.iter().skip(1).cloned().collect(),
            ),
            None => (None, Vec::new()),
        };
        Ok(ClassificationResult {
            category,
            // NOTE(review): fixed placeholder, not derived from the model's votes.
            confidence: 0.85,
            suggested_fix,
            related_patterns,
        })
    }

    /// Feeds one prediction outcome (`true` = the prediction was wrong) to the drift
    /// detector and returns its resulting status.
    pub fn observe_prediction(&mut self, was_error: bool) -> DriftStatus {
        self.adwin_detector.add_element(was_error);
        self.adwin_detector.detected_change()
    }

    /// Current status of the drift detector.
    #[must_use]
    pub fn drift_status(&self) -> DriftStatus {
        self.adwin_detector.detected_change()
    }

    /// Whether the drift detector currently reports [`DriftStatus::Drift`].
    #[must_use]
    pub fn needs_retraining(&self) -> bool {
        matches!(self.drift_status(), DriftStatus::Drift)
    }

    /// Clears the drift detector's observation history.
    pub fn reset_drift_detector(&mut self) {
        self.adwin_detector.reset();
    }

    /// Statistics reported by the drift detector.
    #[must_use]
    pub fn drift_stats(&self) -> DriftStats {
        self.adwin_detector.stats()
    }

    /// Replaces the drift detector with a fresh one using `delta`; previously observed
    /// outcomes are discarded.
    pub fn set_adwin_delta(&mut self, delta: f64) {
        self.adwin_detector = ADWIN::with_delta(delta);
    }

    /// Saves the classifier to `path` in the zstd-compressed aprender model format.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Model`] if serialization or writing fails.
    pub fn save(&self, path: &Path) -> Result<()> {
        let options = SaveOptions::default()
            .with_name("depyler-oracle")
            .with_description("RandomForest error classification model for Depyler transpiler")
            .with_compression(Compression::ZstdDefault);
        format::save(&self.classifier, ModelType::RandomForest, path, options)
            .map_err(|e| OracleError::Model(e.to_string()))?;
        Ok(())
    }

    /// Loads a classifier previously written by [`Self::save`].
    ///
    /// Only the classifier is persisted. The config is `OracleConfig::default()`, and the
    /// category table, fix templates and drift detector are reset to their defaults.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::Model`] if the file cannot be read or decoded.
    pub fn load(path: &Path) -> Result<Self> {
        let classifier: RandomForestClassifier = format::load(path, ModelType::RandomForest)
            .map_err(|e| OracleError::Model(e.to_string()))?;
        Ok(Self::from_parts(classifier, OracleConfig::default()))
    }

    /// Classifies an error and enriches the result with GNN-based similar patterns.
    ///
    /// The base result comes from [`Self::classify_message`]; if that fails, the result
    /// is `Other` with confidence 0.5. The confidence is raised by 10% of the first
    /// similar pattern's similarity, capped at 1.0. Up to three non-empty fix diffs from
    /// the similar patterns are attached.
    #[cfg(feature = "training")]
    #[must_use]
    pub fn classify_enhanced(
        &self,
        error_code: &str,
        error_message: &str,
        python_source: &str,
        rust_source: &str,
        gnn_encoder: &mut DepylerGnnEncoder,
    ) -> EnhancedClassificationResult {
        let base_result =
            self.classify_message(error_message)
                .unwrap_or_else(|_| ClassificationResult {
                    category: ErrorCategory::Other,
                    confidence: 0.5,
                    suggested_fix: None,
                    related_patterns: vec![],
                });
        let enhanced_features = features::EnhancedErrorFeatures::from_error_message(error_message);
        let similar_patterns = gnn_encoder.find_similar(error_code, error_message, rust_source);
        let combined_embedding =
            gnn_encoder.encode_combined(error_code, error_message, python_source, rust_source);
        // Presumably `find_similar` returns the best match first; TODO confirm.
        let similarity_boost = similar_patterns
            .first()
            .map_or(0.0, |best| best.similarity * 0.1);
        let enhanced_confidence = (base_result.confidence + similarity_boost).min(1.0);
        let pattern_fixes: Vec<String> = similar_patterns
            .iter()
            .filter_map(|sp| {
                sp.pattern
                    .error_pattern
                    .as_ref()
                    .map(|ep| ep.fix_diff.clone())
            })
            .filter(|fix| !fix.is_empty())
            .take(3)
            .collect();
        EnhancedClassificationResult {
            category: base_result.category,
            confidence: enhanced_confidence,
            suggested_fix: base_result.suggested_fix,
            related_patterns: base_result.related_patterns,
            similar_patterns,
            enhanced_features,
            combined_embedding,
            pattern_fixes,
            hnsw_used: gnn_encoder.is_hnsw_active(),
        }
    }
}
/// Result of [`Oracle::classify_enhanced`]: the base classification plus GNN-derived context.
#[cfg(feature = "training")]
#[derive(Debug, Clone)]
pub struct EnhancedClassificationResult {
    /// Category from the base classification.
    pub category: ErrorCategory,
    /// Base confidence plus a similarity boost, capped at 1.0.
    pub confidence: f32,
    /// Primary fix suggestion from the base classification.
    pub suggested_fix: Option<String>,
    /// Additional fix suggestions from the base classification.
    pub related_patterns: Vec<String>,
    /// Patterns returned by the GNN encoder's similarity search.
    pub similar_patterns: Vec<SimilarPattern>,
    /// Features extracted from the raw error message.
    pub enhanced_features: features::EnhancedErrorFeatures,
    /// Embedding combining error code, message, and Python/Rust sources.
    pub combined_embedding: Vec<f32>,
    /// Up to three non-empty fix diffs taken from the similar patterns.
    pub pattern_fixes: Vec<String>,
    /// Whether the encoder reported its HNSW index as active.
    pub hnsw_used: bool,
}
impl Default for Oracle {
fn default() -> Self {
Self::new()
}
}
/// Prints a boxed report of the drift detector's status and statistics to stdout.
pub fn print_drift_status(stats: &DriftStats, status: &DriftStatus) {
    let badge = match status {
        DriftStatus::Stable => "🟢 STABLE",
        DriftStatus::Warning => "🟡 WARNING",
        DriftStatus::Drift => "🔴 DRIFT DETECTED",
    };
    let samples = &stats.n_samples;
    let error_pct = stats.error_rate * 100.0;
    let min_error_pct = stats.min_error_rate * 100.0;
    let std_dev = &stats.std_dev;

    println!("╭─────────────────────────────────────────────────────╮");
    println!("│ Drift Detection Status │");
    println!("├─────────────────────────────────────────────────────┤");
    println!("│ Status: {badge:^40} │");
    println!("│ Samples: {samples:>8} │");
    println!("│ Error Rate: {error_pct:>6.2}% │");
    println!("│ Min Error Rate: {min_error_pct:>6.2}% │");
    println!("│ Std Dev: {std_dev:>8.4} │");
    println!("╰─────────────────────────────────────────────────────╯");
}
pub fn print_retrain_status(stats: &RetrainStats) {
let status_indicator = match &stats.drift_status {
DriftStatus::Stable => "🟢",
DriftStatus::Warning => "🟡",
DriftStatus::Drift => "🔴",
};
let accuracy_bar = create_accuracy_bar(stats.accuracy());
println!("╭─────────────────────────────────────────────────────╮");
println!("│ Retrain Trigger Status │");
println!("├─────────────────────────────────────────────────────┤");
println!(
"│ {} Drift Status: {:?} │",
status_indicator, stats.drift_status
);
println!(
"│ Predictions: {:>8} │",
stats.predictions_observed
);
println!(
"│ Correct: {:>8} │",
stats.correct_predictions
);
println!(
"│ Errors: {:>8} │",
stats.errors
);
println!(
"│ Consecutive: {:>8} │",
stats.consecutive_errors
);
println!(
"│ Drift Count: {:>8} │",
stats.drift_count
);
println!("├─────────────────────────────────────────────────────┤");
println!(
"│ Accuracy: {:>6.2}% {} │",
stats.accuracy() * 100.0,
accuracy_bar
);
println!(
"│ Error Rate: {:>6.2}% │",
stats.error_rate() * 100.0
);
println!("╰─────────────────────────────────────────────────────╯");
}
/// Prints a boxed summary of the oracle's training lineage to stdout.
///
/// The summary contains:
/// - the number of recorded models;
/// - details of the latest model (id, version, accuracy and the first eight characters
///   of its `commit_sha` tag);
/// - the regression reported by `find_regression`, if any;
/// - up to five entries of the lineage chain.
#[cfg(feature = "training")]
pub fn print_lineage_history(lineage: &OracleLineage) {
    println!("╭─────────────────────────────────────────────────────╮");
    println!("│ Model Lineage History │");
    println!("├─────────────────────────────────────────────────────┤");
    println!(
        "│ Total Models: {:>6} │",
        lineage.model_count()
    );
    if let Some(latest) = lineage.latest_model() {
        // Truncate by characters, not bytes: `&s[..8]` panics when byte 8 falls inside
        // a multi-byte character of a malformed or hand-edited tag.
        let commit_sha = latest
            .tags
            .get("commit_sha")
            .map(|s| {
                s.char_indices()
                    .nth(8)
                    .map_or(s.as_str(), |(end, _)| &s[..end])
            })
            .unwrap_or("unknown");
        println!(
            "│ Latest Model: {} │",
            latest.model_id.chars().take(30).collect::<String>()
        );
        println!(
            "│ Version: {} │",
            latest.version
        );
        println!(
            "│ Accuracy: {:>6.2}% │",
            latest.accuracy * 100.0
        );
        println!("│ Commit: {} │", commit_sha);
    }
    if let Some((reason, delta)) = lineage.find_regression() {
        let indicator = if delta < 0.0 { "🔴" } else { "🟢" };
        println!("├─────────────────────────────────────────────────────┤");
        println!(
            "│ {} Regression: {:+.2}% │",
            indicator,
            delta * 100.0
        );
        println!(
            "│ Reason: {:40} │",
            reason.chars().take(40).collect::<String>()
        );
    }
    let chain = lineage.get_lineage_chain();
    if !chain.is_empty() {
        println!("├─────────────────────────────────────────────────────┤");
        println!(
            "│ Lineage Chain ({} models): │",
            chain.len()
        );
        for (i, model_id) in chain.iter().take(5).enumerate() {
            let arrow = if i == 0 { "└" } else { "├" };
            println!(
                "│ {} {} │",
                arrow,
                model_id.chars().take(35).collect::<String>()
            );
        }
        if chain.len() > 5 {
            println!(
                "│ ... and {} more │",
                chain.len() - 5
            );
        }
    }
    println!("╰─────────────────────────────────────────────────────╯");
}
/// Renders `accuracy` (expected in `[0, 1]`) as a ten-cell text bar.
///
/// For example, `0.5` renders as `[█████░░░░░]`. Out-of-range input is clamped:
/// values above 1 give a full bar, and negative or NaN values give an empty one.
fn create_accuracy_bar(accuracy: f64) -> String {
    const CELLS: usize = 10;
    // A float-to-int `as` cast saturates negatives and NaN to 0. The `min` keeps
    // `CELLS - filled` from underflowing when accuracy exceeds 1.0.
    let filled = ((accuracy * CELLS as f64).round() as usize).min(CELLS);
    format!("[{}{}]", "█".repeat(filled), "░".repeat(CELLS - filled))
}
/// Prints the retrain-trigger, drift-detector and lineage reports, separated by blank
/// lines.
#[cfg(feature = "training")]
pub fn print_oracle_status(trigger: &RetrainTrigger, lineage: &OracleLineage) {
    let stats = trigger.stats();
    print_retrain_status(stats);
    println!();
    let drift = trigger.drift_stats();
    print_drift_status(&drift, &stats.drift_status);
    println!();
    print_lineage_history(lineage);
}
/// Outcome of a single [`RetrainTrigger::observe`] call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObserveResult {
    /// No drift or warning condition fired.
    Stable,
    /// The error rate reached `warning_threshold` (checked only after `min_samples`
    /// observations) without reaching `drift_threshold`.
    Warning,
    /// ADWIN flagged drift, the consecutive-error limit was reached, or the error rate
    /// reached `drift_threshold`.
    DriftDetected,
}
/// Thresholds that decide when a retrain trigger reports a warning or drift.
#[derive(Clone, Debug)]
pub struct RetrainConfig {
    /// Observations required before the error-rate thresholds are evaluated.
    pub min_samples: usize,
    /// Length of a run of wrong predictions that counts as drift on its own.
    pub max_consecutive_errors: usize,
    /// Error rate at or above which a warning is reported.
    pub warning_threshold: f64,
    /// Error rate at or above which drift is reported.
    pub drift_threshold: f64,
}

impl Default for RetrainConfig {
    /// 50-sample warm-up, 10 consecutive errors, 20% warning and 30% drift error rates.
    fn default() -> Self {
        let warning_threshold = 0.2;
        let drift_threshold = 0.3;
        Self {
            drift_threshold,
            warning_threshold,
            max_consecutive_errors: 10,
            min_samples: 50,
        }
    }
}
/// Running counters maintained by a retrain trigger.
#[derive(Clone, Debug, Default)]
pub struct RetrainStats {
    /// Total outcomes recorded.
    pub predictions_observed: u64,
    /// Outcomes where the prediction was right.
    pub correct_predictions: u64,
    /// Outcomes where the prediction was wrong.
    pub errors: u64,
    /// Length of the current run of wrong predictions.
    pub consecutive_errors: usize,
    /// Last drift-detector status recorded by the trigger.
    pub drift_status: DriftStatus,
    /// Number of times drift has been reported.
    pub drift_count: u64,
}

impl RetrainStats {
    /// Fraction of recorded outcomes that were wrong; `0.0` before any outcome.
    #[must_use]
    pub fn error_rate(&self) -> f64 {
        match self.predictions_observed {
            0 => 0.0,
            total => self.errors as f64 / total as f64,
        }
    }

    /// Complement of [`Self::error_rate`]; `1.0` before any outcome.
    #[must_use]
    pub fn accuracy(&self) -> f64 {
        let error_rate = self.error_rate();
        1.0 - error_rate
    }
}
/// Wraps an [`Oracle`] and decides, from observed prediction outcomes, when it should
/// be retrained.
pub struct RetrainTrigger {
    /// Oracle whose drift detector receives every observed outcome.
    oracle: Oracle,
    /// Consecutive-error and error-rate thresholds.
    config: RetrainConfig,
    /// Running counters updated by `observe`.
    stats: RetrainStats,
}
impl RetrainTrigger {
    /// Creates a trigger around `oracle` with the given thresholds and zeroed statistics.
    #[must_use]
    pub fn new(oracle: Oracle, config: RetrainConfig) -> Self {
        Self {
            oracle,
            config,
            stats: RetrainStats::default(),
        }
    }

    /// Creates a trigger around `oracle` using [`RetrainConfig::default`] thresholds.
    #[must_use]
    pub fn with_oracle(oracle: Oracle) -> Self {
        Self::new(oracle, RetrainConfig::default())
    }

    /// Records one prediction outcome and reports whether the oracle is drifting.
    ///
    /// `was_error` is `true` when the prediction turned out wrong. The checks run in
    /// this order, and the first that fires decides the result:
    /// 1. the oracle's ADWIN detector;
    /// 2. the consecutive-error limit;
    /// 3. once `min_samples` outcomes have been seen, the drift and then the warning
    ///    error-rate thresholds.
    ///
    /// Every [`ObserveResult::DriftDetected`] increments `drift_count`.
    pub fn observe(&mut self, was_error: bool) -> ObserveResult {
        self.stats.predictions_observed += 1;
        if was_error {
            self.stats.errors += 1;
            self.stats.consecutive_errors += 1;
        } else {
            self.stats.correct_predictions += 1;
            self.stats.consecutive_errors = 0;
        }
        let drift_status = self.oracle.observe_prediction(was_error);
        self.stats.drift_status = drift_status;

        let drifted = matches!(drift_status, DriftStatus::Drift)
            || self.consecutive_error_limit_reached()
            || self.error_rate_drifted();
        if drifted {
            self.stats.drift_count += 1;
            return ObserveResult::DriftDetected;
        }
        if self.has_min_samples() && self.stats.error_rate() >= self.config.warning_threshold {
            return ObserveResult::Warning;
        }
        ObserveResult::Stable
    }

    /// Clears the per-window statistics after the oracle has been retrained.
    ///
    /// This does three things:
    /// - resets the oracle's drift detector;
    /// - zeroes the prediction and error counters;
    /// - re-reads the detector's status, so [`Self::stats`] stops reporting the
    ///   pre-retraining drift.
    ///
    /// `drift_count` is not reset.
    pub fn mark_retrained(&mut self) {
        self.oracle.reset_drift_detector();
        // Without this, `drift_status` kept the stale pre-reset value (e.g. `Drift`).
        self.stats.drift_status = self.oracle.drift_status();
        self.stats.consecutive_errors = 0;
        self.stats.predictions_observed = 0;
        self.stats.correct_predictions = 0;
        self.stats.errors = 0;
    }

    /// Current counters.
    #[must_use]
    pub fn stats(&self) -> &RetrainStats {
        &self.stats
    }

    /// Mutable access to the wrapped oracle, e.g. to retrain it in place.
    pub fn oracle_mut(&mut self) -> &mut Oracle {
        &mut self.oracle
    }

    /// The wrapped oracle.
    #[must_use]
    pub fn oracle(&self) -> &Oracle {
        &self.oracle
    }

    /// Whether any drift condition currently holds.
    ///
    /// The conditions are: ADWIN drift, the consecutive-error limit, or the error rate
    /// at or above `drift_threshold` after warm-up.
    #[must_use]
    pub fn needs_retraining(&self) -> bool {
        self.oracle.needs_retraining()
            || self.consecutive_error_limit_reached()
            || self.error_rate_drifted()
    }

    /// Statistics reported by the oracle's drift detector.
    #[must_use]
    pub fn drift_stats(&self) -> DriftStats {
        self.oracle.drift_stats()
    }

    /// Whether the current run of wrong predictions has hit the configured limit.
    fn consecutive_error_limit_reached(&self) -> bool {
        self.stats.consecutive_errors >= self.config.max_consecutive_errors
    }

    /// Whether enough outcomes were seen for the error-rate thresholds to apply.
    fn has_min_samples(&self) -> bool {
        // usize -> u64 is lossless on all supported (<= 64-bit) targets.
        self.stats.predictions_observed >= self.config.min_samples as u64
    }

    /// Whether the error rate has reached the drift threshold after warm-up.
    fn error_rate_drifted(&self) -> bool {
        self.has_min_samples() && self.stats.error_rate() >= self.config.drift_threshold
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_oracle_creation() {
let oracle = Oracle::new();
assert_eq!(oracle.categories.len(), 7);
}
#[test]
#[cfg(feature = "training")]
fn test_needs_retrain_no_lineage_file() {
let temp_dir = tempfile::TempDir::new().expect("create temp dir");
let lineage_path = temp_dir.path().join(".depyler").join("oracle_lineage.json");
let lineage = OracleLineage::load(&lineage_path).expect("load should not error");
assert_eq!(
lineage.model_count(),
0,
"No lineage file should return empty lineage"
);
assert!(
lineage.needs_retraining("any_sha", "any_hash"),
"Empty lineage should need retraining"
);
}
#[test]
#[cfg(feature = "training")]
fn test_needs_retrain_commit_changed_lineage() {
let mut lineage = OracleLineage::new();
lineage.record_training(
"abc123def456".to_string(),
"corpus_hash_123".to_string(),
1000,
0.85,
);
assert!(
lineage.needs_retraining("different_sha_789", "corpus_hash_123"),
"Changed commit SHA should trigger retraining"
);
}
#[test]
#[cfg(feature = "training")]
fn test_needs_retrain_corpus_changed_lineage() {
let mut lineage = OracleLineage::new();
lineage.record_training(
"abc123def456".to_string(),
"corpus_hash_123".to_string(),
1000,
0.85,
);
assert!(
lineage.needs_retraining("abc123def456", "different_corpus_hash"),
"Changed corpus hash should trigger retraining"
);
}
#[test]
#[cfg(feature = "training")]
fn test_no_retrain_when_unchanged_lineage() {
let mut lineage = OracleLineage::new();
lineage.record_training(
"abc123def456".to_string(),
"corpus_hash_123".to_string(),
1000,
0.85,
);
assert!(
!lineage.needs_retraining("abc123def456", "corpus_hash_123"),
"Unchanged state should NOT need retraining"
);
}
#[test]
#[cfg(feature = "training")]
fn test_lineage_saves_after_training() {
let temp_dir = tempfile::TempDir::new().expect("create temp dir");
let lineage_path = temp_dir.path().join(".depyler").join("oracle_lineage.json");
let mut lineage = OracleLineage::new();
lineage.record_training(
"test_sha_12345".to_string(),
"test_hash_67890".to_string(),
500,
0.85,
);
lineage.save(&lineage_path).expect("save should work");
assert!(
lineage_path.exists(),
"Lineage file should exist after save"
);
let loaded = OracleLineage::load(&lineage_path).expect("load should work");
assert_eq!(loaded.model_count(), 1);
let latest = loaded.latest_model().expect("should have model");
assert_eq!(
latest.tags.get("commit_sha"),
Some(&"test_sha_12345".to_string())
);
assert_eq!(latest.config_hash, "test_hash_67890");
assert_eq!(latest.tags.get("sample_count"), Some(&"500".to_string()));
}
#[test]
#[cfg(feature = "training")]
fn test_get_corpus_paths_for_hashing() {
let paths = get_training_corpus_paths();
assert!(
paths.is_empty() || !paths.is_empty(),
"get_training_corpus_paths should return a Vec"
);
}
#[test]
fn test_fix_templates() {
let oracle = Oracle::new();
assert!(oracle
.fix_templates
.contains_key(&ErrorCategory::TypeMismatch));
assert!(oracle
.fix_templates
.contains_key(&ErrorCategory::BorrowChecker));
}
#[test]
fn test_adwin_drift_detection_stable() {
let mut oracle = Oracle::new();
for _ in 0..50 {
let status = oracle.observe_prediction(false); assert!(
matches!(status, DriftStatus::Stable),
"All correct predictions should be stable"
);
}
assert!(
!oracle.needs_retraining(),
"Should not need retraining with all correct"
);
}
#[test]
fn test_adwin_drift_detection_gradual_degradation() {
let mut oracle = Oracle::new();
oracle.set_adwin_delta(0.1);
for _ in 0..200 {
oracle.observe_prediction(false);
}
for _ in 0..200 {
oracle.observe_prediction(true);
}
let stats = oracle.drift_stats();
assert!(
oracle.needs_retraining() || stats.error_rate > 0.3,
"Should detect drift or have high error rate: {:?}, drift status: {:?}",
stats,
oracle.drift_status()
);
}
#[test]
fn test_adwin_drift_detector_reset() {
let mut oracle = Oracle::new();
for _ in 0..50 {
oracle.observe_prediction(true);
}
oracle.reset_drift_detector();
assert!(matches!(oracle.drift_status(), DriftStatus::Stable));
}
#[test]
fn test_adwin_drift_stats() {
let mut oracle = Oracle::new();
for _ in 0..10 {
oracle.observe_prediction(false);
}
for _ in 0..10 {
oracle.observe_prediction(true);
}
let stats = oracle.drift_stats();
assert_eq!(stats.n_samples, 20, "Should have 20 samples");
}
#[test]
#[ignore] #[cfg(feature = "training")]
fn test_load_or_train() {
if std::env::var("DEPYLER_FAST_TESTS").is_ok() {
let oracle = Oracle::new();
assert_eq!(oracle.categories.len(), 7);
return;
}
let oracle = Oracle::load_or_train().expect("load_or_train should succeed");
assert_eq!(oracle.categories.len(), 7);
let path = Oracle::default_model_path();
assert!(path.exists(), "Model file should be created at {:?}", path);
let oracle2 = Oracle::load_or_train().expect("second load_or_train should succeed");
assert_eq!(oracle2.categories.len(), 7);
let _ = std::fs::remove_file(path);
}
#[test]
fn test_default_model_path() {
let path = Oracle::default_model_path();
assert!(path.to_string_lossy().contains("depyler_oracle.apr"));
}
#[test]
fn test_retrain_trigger_creation() {
let oracle = Oracle::new();
let trigger = RetrainTrigger::with_oracle(oracle);
let stats = trigger.stats();
assert_eq!(stats.predictions_observed, 0);
assert_eq!(stats.errors, 0);
}
#[test]
fn test_retrain_trigger_observe_correct() {
let oracle = Oracle::new();
let mut trigger = RetrainTrigger::with_oracle(oracle);
for _ in 0..10 {
let result = trigger.observe(false); assert_eq!(result, ObserveResult::Stable);
}
let stats = trigger.stats();
assert_eq!(stats.predictions_observed, 10);
assert_eq!(stats.correct_predictions, 10);
assert_eq!(stats.errors, 0);
}
#[test]
fn test_retrain_trigger_consecutive_errors() {
let oracle = Oracle::new();
let config = RetrainConfig {
max_consecutive_errors: 5,
..Default::default()
};
let mut trigger = RetrainTrigger::new(oracle, config);
for _ in 0..4 {
let result = trigger.observe(true);
assert_eq!(result, ObserveResult::Stable);
}
let result = trigger.observe(true);
assert_eq!(result, ObserveResult::DriftDetected);
}
#[test]
fn test_retrain_trigger_error_rate_threshold() {
let oracle = Oracle::new();
let config = RetrainConfig {
min_samples: 10,
drift_threshold: 0.5,
warning_threshold: 0.3,
max_consecutive_errors: 100, };
let mut trigger = RetrainTrigger::new(oracle, config);
for _ in 0..7 {
trigger.observe(false);
}
for _ in 0..6 {
trigger.observe(true);
}
let result = trigger.observe(true);
assert!(
result == ObserveResult::DriftDetected || result == ObserveResult::Warning,
"Should detect drift or warning at 50% error rate"
);
}
#[test]
fn test_retrain_trigger_mark_retrained() {
let oracle = Oracle::new();
let mut trigger = RetrainTrigger::with_oracle(oracle);
for _ in 0..10 {
trigger.observe(true);
}
assert_eq!(trigger.stats().errors, 10);
trigger.mark_retrained();
let stats = trigger.stats();
assert_eq!(stats.predictions_observed, 0);
assert_eq!(stats.errors, 0);
assert_eq!(stats.consecutive_errors, 0);
}
#[test]
fn test_retrain_stats_error_rate() {
let mut stats = RetrainStats::default();
assert_eq!(stats.error_rate(), 0.0);
stats.predictions_observed = 100;
stats.errors = 25;
assert!((stats.error_rate() - 0.25).abs() < 0.001);
assert!((stats.accuracy() - 0.75).abs() < 0.001);
}
#[test]
fn test_oracle_config_default() {
let config = OracleConfig::default();
assert_eq!(config.n_estimators, 100);
assert_eq!(config.max_depth, 10);
assert_eq!(config.random_state, Some(42));
}
#[test]
fn test_oracle_config_custom() {
let config = OracleConfig {
n_estimators: 50,
max_depth: 5,
random_state: Some(123),
};
assert_eq!(config.n_estimators, 50);
assert_eq!(config.max_depth, 5);
assert_eq!(config.random_state, Some(123));
}
/// `Oracle::with_config` builds an oracle with the usual seven categories.
#[test]
fn test_oracle_with_config() {
    // Small, unseeded forest settings.
    let oracle = Oracle::with_config(OracleConfig {
        n_estimators: 20,
        max_depth: 3,
        random_state: None,
    });

    assert_eq!(oracle.categories.len(), 7);
}
/// Each `OracleError` variant's `Display` output names its error kind.
#[test]
fn test_oracle_error_display() {
    let not_found = std::io::Error::new(std::io::ErrorKind::NotFound, "not found");

    // (error value, substring its rendered message must contain)
    let cases = [
        (OracleError::Model("test error".to_string()), "Model error"),
        (
            OracleError::Feature("feature error".to_string()),
            "Feature extraction error",
        ),
        (
            OracleError::Classification("class error".to_string()),
            "Classification error",
        ),
        (OracleError::Io(not_found), "IO error"),
    ];

    for (err, needle) in cases {
        assert!(err.to_string().contains(needle));
    }
}
/// A `ClassificationResult` literal exposes the fields it was built with.
#[test]
fn test_classification_result_creation() {
    let patterns: Vec<String> = ["pattern1", "pattern2"]
        .iter()
        .map(|p| p.to_string())
        .collect();

    let result = ClassificationResult {
        category: ErrorCategory::TypeMismatch,
        confidence: 0.95,
        suggested_fix: Some(String::from("Use .into()")),
        related_patterns: patterns,
    };

    assert_eq!(result.category, ErrorCategory::TypeMismatch);
    assert_eq!(result.confidence, 0.95);
    assert!(result.suggested_fix.is_some());
    assert_eq!(result.related_patterns.len(), 2);
}
/// Cloning a `ClassificationResult` copies every field, including owned data.
#[test]
fn test_classification_result_clone() {
    // Populate the heap-owned fields so the clone is checked against real data,
    // not just `None` / an empty vector.
    let result = ClassificationResult {
        category: ErrorCategory::BorrowChecker,
        confidence: 0.80,
        suggested_fix: Some("Borrow with `&` instead of moving".to_string()),
        related_patterns: vec!["use_after_move".to_string()],
    };
    let cloned = result.clone();

    assert_eq!(cloned.category, result.category);
    assert_eq!(cloned.confidence, result.confidence);
    assert_eq!(cloned.suggested_fix, result.suggested_fix);
    assert_eq!(cloned.related_patterns, result.related_patterns);
}
/// `ObserveResult` equality: each variant equals itself and differs from the others.
#[test]
fn test_observe_result_eq() {
    // Reflexivity for every variant exercised here.
    assert_eq!(ObserveResult::Stable, ObserveResult::Stable);
    assert_eq!(ObserveResult::Warning, ObserveResult::Warning);
    assert_eq!(ObserveResult::DriftDetected, ObserveResult::DriftDetected);

    // Every distinct pair compares unequal; previously only Stable/Warning was checked.
    assert_ne!(ObserveResult::Stable, ObserveResult::Warning);
    assert_ne!(ObserveResult::Stable, ObserveResult::DriftDetected);
    assert_ne!(ObserveResult::Warning, ObserveResult::DriftDetected);
}
/// Pins the default retraining thresholds.
#[test]
fn test_retrain_config_default() {
    let defaults = RetrainConfig::default();

    assert_eq!(defaults.min_samples, 50);
    assert_eq!(defaults.max_consecutive_errors, 10);

    // Thresholds are floats, so compare within a tolerance.
    assert!((defaults.warning_threshold - 0.2).abs() < 0.001);
    assert!((defaults.drift_threshold - 0.3).abs() < 0.001);
}
/// A hand-built `RetrainConfig` keeps all four values it was given.
#[test]
fn test_retrain_config_custom() {
    let config = RetrainConfig {
        min_samples: 100,
        max_consecutive_errors: 5,
        warning_threshold: 0.15,
        drift_threshold: 0.25,
    };

    assert_eq!(config.min_samples, 100);
    assert_eq!(config.max_consecutive_errors, 5);

    // The float thresholds are set too, so verify them (within a tolerance).
    assert!((config.warning_threshold - 0.15).abs() < 0.001);
    assert!((config.drift_threshold - 0.25).abs() < 0.001);
}
/// The accuracy bar renders ten cells of filled/empty blocks in brackets.
#[test]
fn test_create_accuracy_bar() {
    // Exact renderings at both extremes and the midpoint.
    for (accuracy, expected) in [
        (1.0, "[██████████]"),
        (0.5, "[█████░░░░░]"),
        (0.0, "[░░░░░░░░░░]"),
    ] {
        let bar = create_accuracy_bar(accuracy);
        assert_eq!(bar, expected);
    }

    // A fractional accuracy still renders at least one filled cell.
    let partial = create_accuracy_bar(0.85);
    assert!(partial.contains("█"));
}
/// `print_drift_status` can render every drift status without panicking.
#[test]
fn test_print_drift_status_does_not_panic() {
    let mut oracle = Oracle::new();

    // Give the detector some samples before reading its stats.
    for _ in 0..10 {
        oracle.observe_prediction(false);
    }
    let stats = oracle.drift_stats();

    for status in [DriftStatus::Stable, DriftStatus::Warning, DriftStatus::Drift] {
        print_drift_status(&stats, &status);
    }
}
/// `print_retrain_status` handles a populated stats snapshot.
#[test]
fn test_print_retrain_status_does_not_panic() {
    // 80 correct + 20 errors over 100 observations, no drift events.
    let snapshot = RetrainStats {
        predictions_observed: 100,
        correct_predictions: 80,
        errors: 20,
        consecutive_errors: 2,
        drift_status: DriftStatus::Stable,
        drift_count: 0,
    };

    print_retrain_status(&snapshot);
}
/// `print_lineage_history` handles both an empty and a one-entry lineage.
#[test]
#[cfg(feature = "training")]
fn test_print_lineage_history_does_not_panic() {
    // Empty history.
    print_lineage_history(&OracleLineage::new());

    // History with a single recorded training run.
    let mut populated = OracleLineage::new();
    populated.record_training("sha123".to_string(), "hash456".to_string(), 1000, 0.9);
    print_lineage_history(&populated);
}
/// `print_oracle_status` handles a fresh trigger and an empty lineage.
#[test]
#[cfg(feature = "training")]
fn test_print_oracle_status_does_not_panic() {
    let trigger = RetrainTrigger::with_oracle(Oracle::new());
    print_oracle_status(&trigger, &OracleLineage::new());
}
/// Both the shared and mutable oracle accessors expose the wrapped oracle.
#[test]
fn test_retrain_trigger_oracle_access() {
    let mut trigger = RetrainTrigger::with_oracle(Oracle::new());

    let shared_count = trigger.oracle().categories.len();
    assert_eq!(shared_count, 7);

    // The mutable accessor also sees the seven categories.
    assert_eq!(trigger.oracle_mut().categories.len(), 7);
}
/// A new trigger's drift detector reports zero samples.
#[test]
fn test_retrain_trigger_drift_stats() {
    let fresh = RetrainTrigger::with_oracle(Oracle::new());
    assert_eq!(fresh.drift_stats().n_samples, 0);
}
/// A trigger with no observations does not ask for retraining.
#[test]
fn test_retrain_trigger_needs_retraining() {
    let trigger = RetrainTrigger::with_oracle(Oracle::new());
    let needs = trigger.needs_retraining();
    assert!(!needs);
}
/// `Oracle::default()` produces an oracle with seven categories.
#[test]
fn test_oracle_default() {
    assert_eq!(Oracle::default().categories.len(), 7);
}
/// `RetrainStats::default()` starts with zeroed counters and a stable status.
#[test]
fn test_retrain_stats_default() {
    let RetrainStats {
        predictions_observed,
        correct_predictions,
        errors,
        consecutive_errors,
        drift_status,
        drift_count,
    } = RetrainStats::default();

    assert_eq!(predictions_observed, 0);
    assert_eq!(correct_predictions, 0);
    assert_eq!(errors, 0);
    assert_eq!(consecutive_errors, 0);
    assert_eq!(drift_count, 0);
    assert!(matches!(drift_status, DriftStatus::Stable));
}
/// `set_adwin_delta` accepts a range of delta values without panicking.
#[test]
fn test_oracle_set_adwin_delta() {
    let mut oracle = Oracle::new();
    for delta in [0.001, 0.01, 0.1] {
        oracle.set_adwin_delta(delta);
    }
}
/// Enhanced classification returns a bounded confidence, a non-empty combined
/// embedding, and populated base features.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_classification() {
    let oracle = Oracle::new();

    // NOTE(review): a zero similarity threshold presumably lets any indexed
    // pattern qualify as a neighbour — confirm against GnnEncoderConfig.
    let encoder_config = GnnEncoderConfig {
        similarity_threshold: 0.0,
        ..Default::default()
    };
    let mut encoder = DepylerGnnEncoder::new(encoder_config);

    let indexed =
        error_patterns::ErrorPattern::new("E0308", "mismatched types", "+let x: i32 = 42;");
    encoder.index_pattern(&indexed, "let x: i32 = \"hello\";");

    let outcome = oracle.classify_enhanced(
        "E0308",
        "mismatched types: expected i32, found String",
        "def foo(): return \"hello\"",
        "fn foo() -> i32 { \"hello\" }",
        &mut encoder,
    );

    assert!((0.0..=1.0).contains(&outcome.confidence));
    assert!(!outcome.combined_embedding.is_empty());
    assert!(outcome.enhanced_features.base.message_length > 0.0);
}
/// With at least one pattern indexed, classification reports HNSW usage.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_classification_hnsw_used() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::with_defaults();

    let indexed = error_patterns::ErrorPattern::new("E0308", "type mismatch", "+fix");
    encoder.index_pattern(&indexed, "source");

    let hnsw_used = oracle
        .classify_enhanced(
            "E0308",
            "type mismatch",
            "def foo(): pass",
            "fn foo() {}",
            &mut encoder,
        )
        .hnsw_used;

    assert!(hnsw_used, "HNSW should be used when patterns are indexed");
}
/// Disabling HNSW in the encoder config is reflected in the result.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_classification_without_hnsw() {
    let oracle = Oracle::new();
    let no_hnsw = GnnEncoderConfig {
        use_hnsw: false,
        ..Default::default()
    };
    let mut encoder = DepylerGnnEncoder::new(no_hnsw);

    let hnsw_used = oracle
        .classify_enhanced(
            "E0308",
            "type mismatch",
            "def foo(): pass",
            "fn foo() {}",
            &mut encoder,
        )
        .hnsw_used;

    assert!(!hnsw_used, "HNSW should not be used when disabled");
}
/// The combined embedding length matches the encoder's advertised dimension.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_classification_combined_embedding_size() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::with_defaults();

    let embedding_len = oracle
        .classify_enhanced(
            "E0382",
            "borrow of moved value",
            "def foo(): x = []; return x",
            "fn foo() { let x = vec![]; x }",
            &mut encoder,
        )
        .combined_embedding
        .len();

    // Query the dimension after classification, as the original did.
    assert_eq!(
        embedding_len,
        encoder.combined_dim(),
        "Combined embedding should have correct dimension"
    );
}
/// A trait-related error message yields non-zero keyword counts.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_features_extraction() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::with_defaults();

    let outcome = oracle.classify_enhanced(
        "E0277",
        "the trait `Clone` is not implemented for `Foo`",
        "class Foo: pass",
        "struct Foo {}",
        &mut encoder,
    );

    let total_keywords: f32 = outcome.enhanced_features.keyword_counts.iter().sum();
    assert!(total_keywords > 0.0, "Should extract trait-related keywords");
}
/// Whenever similar patterns are found, their fixes are surfaced as well.
#[test]
#[cfg(feature = "training")]
fn test_phase4_pattern_fixes_extraction() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
        similarity_threshold: 0.0,
        ..Default::default()
    });

    let indexed = error_patterns::ErrorPattern::new(
        "E0308",
        "type error",
        "-let x = \"hello\";\n+let x: i32 = 42;",
    );
    encoder.index_pattern(&indexed, "source");

    let outcome = oracle.classify_enhanced(
        "E0308",
        "type error",
        "def foo(): pass",
        "fn foo() {}",
        &mut encoder,
    );

    // Implication: matched patterns => extracted fixes. No matches is acceptable.
    if !outcome.similar_patterns.is_empty() {
        assert!(
            !outcome.pattern_fixes.is_empty(),
            "Should extract fixes from matched patterns"
        );
    }
}
/// Cloning an enhanced classification result preserves its key fields.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_result_clone() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::with_defaults();

    let original = oracle.classify_enhanced(
        "E0599",
        "method not found",
        "foo.bar()",
        "foo.bar()",
        &mut encoder,
    );
    let copy = original.clone();

    assert_eq!(copy.category, original.category);
    assert_eq!(copy.confidence, original.confidence);
    assert_eq!(copy.hnsw_used, original.hnsw_used);
}
}