1use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10
11use aprender::format::{self, Compression, ModelType, SaveOptions};
12use aprender::online::drift::{DriftDetector, DriftStats, DriftStatus, ADWIN};
13use aprender::primitives::Matrix;
14use aprender::tree::RandomForestClassifier;
15use serde::{Deserialize, Serialize};
16
17pub mod autofixer;
18pub mod automl_tuning;
19pub mod citl_fixer;
20pub mod classifier;
21#[cfg(feature = "training")]
22pub mod corpus_citl;
23#[cfg(feature = "training")]
24pub mod data_store;
25pub mod depyler_training;
26pub mod estimator;
27pub mod features;
28pub mod github_corpus;
29#[cfg(feature = "api-fallback")]
30pub mod hybrid;
31pub mod moe_oracle;
32#[cfg(feature = "training")]
33pub mod acceleration_pipeline;
34pub mod ast_embeddings; pub mod corpus_extract;
36pub mod curriculum;
37#[cfg(feature = "training")]
38pub mod distillation;
39#[cfg(feature = "training")]
40pub mod error_patterns;
41#[cfg(feature = "training")]
42pub mod gnn_encoder;
43pub mod graph_corpus;
44pub mod hansei;
45pub mod hybrid_retrieval;
46pub mod ngram;
47#[cfg(feature = "training")]
48pub mod oip_export;
49#[cfg(feature = "training")]
50pub mod oracle_lineage; pub mod params_persistence;
52pub mod patterns;
53#[cfg(feature = "training")]
54pub mod query_loop;
55pub mod self_supervised;
56pub mod synthetic;
57#[cfg(feature = "training")]
58pub mod tarantula;
59#[cfg(feature = "training")]
60pub mod tarantula_bridge;
61#[cfg(feature = "training")]
62pub mod tarantula_corpus;
63pub mod tfidf;
64pub mod training;
65pub mod tuning;
66pub mod unified_training;
67pub mod utol; pub mod verificar_integration; pub use autofixer::{AutoFixer, FixContext, FixResult, TransformRule};
71pub use automl_tuning::{automl_full, automl_optimize, automl_quick, AutoMLConfig, AutoMLResult};
72pub use citl_fixer::{CITLFixer, CITLFixerConfig, IterativeFixResult};
73#[cfg(feature = "training")]
74pub use corpus_citl::{CorpusCITL, IngestionStats};
75pub use estimator::{message_to_features, samples_to_features, OracleEstimator};
76pub use graph_corpus::{
77 analyze_graph_corpus, build_graph_corpus, convert_to_training_samples,
78 load_vectorized_failures, GraphCorpusStats, VectorizedFailure,
79};
80pub use params_persistence::{
81 default_params_path, load_params, params_exist, save_params, OptimizedParams,
82};
83pub use synthetic::{
84 generate_synthetic_corpus, generate_synthetic_corpus_sized, SyntheticConfig, SyntheticGenerator,
85};
86pub use tuning::{find_best_config, quick_tune, TuningConfig, TuningResult};
87
88#[cfg(test)]
89mod proptests;
90
91pub use classifier::{ErrorCategory, ErrorClassifier};
92pub use features::ErrorFeatures;
93pub use hansei::{
94 CategorySummary, HanseiConfig, HanseiReport, IssueSeverity, TranspileHanseiAnalyzer,
95 TranspileIssue, TranspileOutcome, Trend,
96};
97#[cfg(feature = "api-fallback")]
98pub use hybrid::{
99 HybridConfig, HybridTranspiler, PatternComplexity, Strategy, TrainingDataCollector,
100 TranslationPair, TranspileError, TranspileResult, TranspileStats,
101};
102pub use hybrid_retrieval::{reciprocal_rank_fusion, Bm25Scorer, HybridRetriever, RrfResult};
103pub use ngram::{FixPattern, FixSuggestion, NgramFixPredictor};
104#[cfg(feature = "training")]
105pub use oracle_lineage::OracleLineage;
106pub use patterns::{CodeTransform, FixTemplate, FixTemplateRegistry};
107pub use tfidf::{CombinedFeatureExtractor, TfidfConfig, TfidfFeatureExtractor};
108pub use training::{TrainingDataset, TrainingSample}; pub use depyler_training::{
112 classify_with_moe, load_real_corpus, train_moe_on_real_corpus, train_moe_oracle,
113};
114pub use moe_oracle::{ExpertDomain, MoeClassificationResult, MoeOracle, MoeOracleConfig};
115
116#[cfg(feature = "training")]
118pub use query_loop::{
119 apply_simple_diff, auto_fix_loop, AutoFixResult, ErrorContext, OracleMetrics, OracleQueryError,
120 OracleQueryLoop, OracleStats, OracleSuggestion, ParseRustErrorCodeError, QueryLoopConfig,
121 RustErrorCode,
122};
123
124pub use github_corpus::{
126 analyze_corpus, build_github_corpus, convert_oip_to_depyler, get_moe_samples_from_oip,
127 load_oip_training_data, CorpusStats, OipDefectCategory, OipTrainingDataset, OipTrainingExample,
128};
129
130pub use unified_training::{
132 build_default_unified_corpus, build_unified_corpus, build_unified_corpus_with_oip,
133 print_merge_stats, MergeStats, UnifiedTrainingConfig, UnifiedTrainingResult,
134};
135
136#[cfg(feature = "training")]
138pub use tarantula::{
139 FixPriority, SuspiciousTranspilerDecision, TarantulaAnalyzer, TarantulaResult,
140 TranspilerDecision, TranspilerDecisionRecord,
141};
142
143#[cfg(feature = "training")]
145pub use tarantula_corpus::{CorpusAnalysisReport, CorpusAnalyzer, TranspilationResult};
146
147#[cfg(feature = "training")]
149pub use tarantula_bridge::{
150 category_to_decision, decision_to_record, decisions_to_records, infer_decisions_from_error,
151 synthetic_decisions_from_errors,
152};
153
154#[cfg(feature = "training")]
156pub use error_patterns::{
157 CorpusEntry, ErrorPattern, ErrorPatternConfig, ErrorPatternLibrary, ErrorPatternStats,
158 GoldenTraceEntry,
159};
160
161pub use curriculum::{
163 classify_error_difficulty, classify_from_category, CurriculumEntry, CurriculumScheduler,
164 CurriculumStats, DifficultyLevel,
165};
166
167#[cfg(feature = "training")]
169pub use distillation::{
170 DistillationConfig, DistillationStats, ExtractedPattern, KnowledgeDistiller, LlmFixExample,
171};
172
173#[cfg(feature = "training")]
175pub use gnn_encoder::{
176 infer_decision_from_match, map_error_category, DepylerGnnEncoder, GnnEncoderConfig,
177 GnnEncoderStats, SimilarPattern, StructuralPattern,
178};
179
180pub use ast_embeddings::{
182 AstEmbedder, AstEmbedding, AstEmbeddingConfig, CombinedEmbeddingExtractor, CombinedFeatures,
183 PathContext,
184};
185
186#[cfg(feature = "training")]
188pub use oip_export::{
189 export_to_jsonl, BatchExporter, DepylerExport, ErrorCodeClass, ExportStats, SpanInfo,
190 SuggestionInfo,
191};
192
193#[cfg(feature = "training")]
195pub use acceleration_pipeline::{
196 AccelerationPipeline, AnalysisResult, FixSource, PipelineConfig, PipelineStats,
197};
198
199#[derive(Debug, thiserror::Error)]
201pub enum OracleError {
202 #[error("Model error: {0}")]
204 Model(String),
205 #[error("Feature extraction error: {0}")]
207 Feature(String),
208 #[error("Classification error: {0}")]
210 Classification(String),
211 #[error("IO error: {0}")]
213 Io(#[from] std::io::Error),
214}
215
216pub type Result<T> = std::result::Result<T, OracleError>;
218
219#[derive(Clone, Debug, Serialize, Deserialize)]
221pub struct ClassificationResult {
222 pub category: ErrorCategory,
224 pub confidence: f32,
226 pub suggested_fix: Option<String>,
228 pub related_patterns: Vec<String>,
230}
231
/// Hyper-parameters for the random forest that backs [`Oracle`].
#[derive(Clone, Debug)]
pub struct OracleConfig {
    /// Number of trees in the forest.
    pub n_estimators: usize,
    /// Maximum depth of each tree.
    pub max_depth: usize,
    /// Seed for the forest; with `None` the classifier is left with its own default seeding.
    pub random_state: Option<u64>,
}

impl Default for OracleConfig {
    /// 100 trees of depth at most 10, seeded with 42 so training is reproducible.
    fn default() -> Self {
        const DEFAULT_SEED: u64 = 42;
        Self {
            max_depth: 10,
            n_estimators: 100,
            random_state: Some(DEFAULT_SEED),
        }
    }
}
266
267pub struct Oracle {
268 classifier: RandomForestClassifier,
270 #[allow(dead_code)]
272 config: OracleConfig,
273 categories: Vec<ErrorCategory>,
275 fix_templates: HashMap<ErrorCategory, Vec<String>>,
277 adwin_detector: ADWIN,
280}
281
/// File name of the cached oracle model; `Oracle::default_model_path` places it in
/// the project root when one is found.
const DEFAULT_MODEL_NAME: &str = "depyler_oracle.apr";
284
/// Finds the nearest directory that contains a `Cargo.toml`.
///
/// The search starts at the current working directory and moves upward. It inspects
/// at most five directories: the working directory and its four closest ancestors.
/// It returns `None` if none of them has a `Cargo.toml`. If the working directory
/// cannot be read, only the empty relative path is checked.
fn find_project_root() -> Option<PathBuf> {
    let start = std::env::current_dir().unwrap_or_default();
    start
        .ancestors()
        .take(5)
        .find(|dir| dir.join("Cargo.toml").exists())
        .map(Path::to_path_buf)
}
298
/// Lists the `.rs` and `.json` entries directly inside `dir`, without recursing.
///
/// An unreadable directory yields an empty list, and unreadable entries are skipped.
/// Entries come back in `read_dir` order, which is unspecified, so callers must sort
/// the result if they need a stable order.
fn collect_corpus_files(dir: &Path) -> Vec<PathBuf> {
    let mut found = Vec::new();
    let Ok(listing) = std::fs::read_dir(dir) else {
        return found;
    };
    for path in listing.filter_map(|entry| entry.ok()).map(|entry| entry.path()) {
        let is_corpus_file = matches!(
            path.extension().and_then(|ext| ext.to_str()),
            Some("rs" | "json")
        );
        if is_corpus_file {
            found.push(path);
        }
    }
    found
}
310
311#[must_use]
315pub fn get_training_corpus_paths() -> Vec<PathBuf> {
316 let Some(root) = find_project_root() else {
317 return Vec::new();
318 };
319
320 let corpus_dirs = [
321 root.join("crates/depyler-oracle/src"),
322 root.join("verificar/corpus"),
323 root.join("training_data"),
324 ];
325
326 let mut paths: Vec<PathBuf> = corpus_dirs
327 .iter()
328 .filter(|d| d.exists())
329 .flat_map(|d| collect_corpus_files(d))
330 .collect();
331
332 paths.sort();
333 paths
334}
335
336impl Oracle {
337 #[must_use]
339 pub fn default_model_path() -> PathBuf {
340 let mut path = std::env::current_dir().unwrap_or_default();
342 for _ in 0..5 {
343 if path.join("Cargo.toml").exists() {
344 return path.join(DEFAULT_MODEL_NAME);
345 }
346 if !path.pop() {
347 break;
348 }
349 }
350 PathBuf::from(DEFAULT_MODEL_NAME)
352 }
353
354 #[cfg(feature = "training")]
367 pub fn load_or_train() -> Result<Self> {
368 let model_path = Self::default_model_path();
369 let lineage_path = OracleLineage::default_lineage_path();
370
371 let current_sha = OracleLineage::get_current_commit_sha();
373 let corpus_paths = get_training_corpus_paths();
374 let current_corpus_hash = OracleLineage::compute_corpus_hash(&corpus_paths);
375
376 let mut lineage = match OracleLineage::load(&lineage_path) {
378 Ok(l) => l,
379 Err(e) => {
380 eprintln!("Warning: Failed to load lineage: {e}. Starting fresh...");
381 OracleLineage::new()
382 }
383 };
384
385 let needs_retrain = lineage.needs_retraining(¤t_sha, ¤t_corpus_hash);
387 if needs_retrain && lineage.model_count() > 0 {
388 eprintln!("📊 Oracle: Codebase changes detected, triggering retraining...");
389 } else if needs_retrain {
390 eprintln!("📊 Oracle: No training history found, will train fresh...");
391 }
392
393 if !needs_retrain && model_path.exists() {
395 match Self::load(&model_path) {
396 Ok(oracle) => {
397 eprintln!("📊 Oracle: Loaded cached model (no changes detected)");
398 return Ok(oracle);
399 }
400 Err(e) => {
401 eprintln!("Warning: Failed to load cached model: {e}. Retraining...");
402 }
403 }
404 }
405
406 let mut dataset = verificar_integration::build_verificar_corpus();
408 let depyler_corpus = depyler_training::build_combined_corpus();
409 for sample in depyler_corpus.samples() {
410 dataset.add(sample.clone());
411 }
412
413 let synthetic_corpus = synthetic::generate_synthetic_corpus();
415 for sample in synthetic_corpus.samples() {
416 dataset.add(sample.clone());
417 }
418
419 let sample_count = dataset.samples().len();
420 let (features, labels_vec) = samples_to_features(dataset.samples());
421 let labels: Vec<usize> = labels_vec.as_slice().iter().map(|&x| x as usize).collect();
422
423 let mut oracle = Self::new();
424 oracle.train(&features, &labels)?;
425
426 if let Err(e) = oracle.save(&model_path) {
428 eprintln!(
429 "Warning: Failed to cache model to {}: {e}",
430 model_path.display()
431 );
432 }
433
434 let model_id = lineage.record_training(
437 current_sha,
438 current_corpus_hash,
439 sample_count,
440 0.85, );
442
443 if let Some((reason, delta)) = lineage.find_regression() {
445 eprintln!(
446 "⚠️ Oracle: Regression detected! Accuracy dropped by {:.2}% ({})",
447 delta.abs() * 100.0,
448 reason
449 );
450 }
451
452 if let Err(e) = lineage.save(&lineage_path) {
454 eprintln!("Warning: Failed to save lineage: {e}");
455 } else {
456 eprintln!(
457 "📊 Oracle: Training complete ({} samples), lineage recorded as {}",
458 sample_count, model_id
459 );
460 }
461
462 Ok(oracle)
463 }
464
465 #[must_use]
467 pub fn new() -> Self {
468 Self::with_config(OracleConfig::default())
469 }
470
471 #[must_use]
473 pub fn with_config(config: OracleConfig) -> Self {
474 let mut classifier =
475 RandomForestClassifier::new(config.n_estimators).with_max_depth(config.max_depth);
476 if let Some(seed) = config.random_state {
477 classifier = classifier.with_random_state(seed);
478 }
479
480 Self {
481 classifier,
482 config,
483 categories: vec![
484 ErrorCategory::TypeMismatch,
485 ErrorCategory::BorrowChecker,
486 ErrorCategory::MissingImport,
487 ErrorCategory::SyntaxError,
488 ErrorCategory::LifetimeError,
489 ErrorCategory::TraitBound,
490 ErrorCategory::Other,
491 ],
492 fix_templates: Self::default_fix_templates(),
493 adwin_detector: ADWIN::with_delta(0.002),
496 }
497 }
498
499 fn default_fix_templates() -> HashMap<ErrorCategory, Vec<String>> {
501 let mut templates = HashMap::new();
502
503 templates.insert(
504 ErrorCategory::TypeMismatch,
505 vec![
506 "Convert type using `.into()` or `as`".to_string(),
507 "Check function signature for expected type".to_string(),
508 "Use type annotation to clarify".to_string(),
509 ],
510 );
511
512 templates.insert(
513 ErrorCategory::BorrowChecker,
514 vec![
515 "Clone the value instead of borrowing".to_string(),
516 "Use a reference (&) instead of moving".to_string(),
517 "Introduce a scope to limit borrow lifetime".to_string(),
518 ],
519 );
520
521 templates.insert(
522 ErrorCategory::MissingImport,
523 vec![
524 "Add `use` statement for the missing type".to_string(),
525 "Check crate dependencies in Cargo.toml".to_string(),
526 ],
527 );
528
529 templates.insert(
530 ErrorCategory::SyntaxError,
531 vec![
532 "Check for missing semicolons or braces".to_string(),
533 "Verify function/struct syntax".to_string(),
534 ],
535 );
536
537 templates.insert(
538 ErrorCategory::LifetimeError,
539 vec![
540 "Add explicit lifetime annotation".to_string(),
541 "Use 'static lifetime for owned data".to_string(),
542 "Consider using Rc/Arc for shared ownership".to_string(),
543 ],
544 );
545
546 templates.insert(
547 ErrorCategory::TraitBound,
548 vec![
549 "Implement the required trait".to_string(),
550 "Add trait bound to generic parameter".to_string(),
551 "Use a wrapper type that implements the trait".to_string(),
552 ],
553 );
554
555 templates.insert(
556 ErrorCategory::Other,
557 vec!["Review the full error message for specifics".to_string()],
558 );
559
560 templates
561 }
562
563 pub fn train(&mut self, features: &Matrix<f32>, labels: &[usize]) -> Result<()> {
569 self.classifier
570 .fit(features, labels)
571 .map_err(|e| OracleError::Model(e.to_string()))?;
572
573 Ok(())
574 }
575
576 pub fn classify_message(&self, message: &str) -> Result<ClassificationResult> {
580 let feature_matrix = message_to_features(message);
581 let predictions = self.classifier.predict(&feature_matrix);
582
583 self.build_classification_result(predictions)
584 }
585
586 #[deprecated(since = "3.22.0", note = "Use classify_message for better accuracy")]
591 pub fn classify(&self, features: &ErrorFeatures) -> Result<ClassificationResult> {
592 let error_features = features.to_vec();
594 let n_error_codes = estimator::feature_config::ERROR_CODES.len();
595 let n_keywords = estimator::feature_config::KEYWORDS.len();
596 let n_total = n_error_codes + n_keywords + ErrorFeatures::DIM;
597
598 let mut full_features = vec![0.0f32; n_total];
599 for (i, &val) in error_features.iter().enumerate() {
601 full_features[n_error_codes + n_keywords + i] = val;
602 }
603
604 let feature_matrix = aprender::primitives::Matrix::from_vec(1, n_total, full_features)
605 .expect("Feature matrix dimensions are correct");
606 let predictions = self.classifier.predict(&feature_matrix);
607
608 self.build_classification_result(predictions)
609 }
610
611 fn build_classification_result(&self, predictions: Vec<usize>) -> Result<ClassificationResult> {
612 if predictions.is_empty() {
613 return Err(OracleError::Classification(
614 "No prediction produced".to_string(),
615 ));
616 }
617
618 let pred_idx = predictions[0];
619 let category = self
620 .categories
621 .get(pred_idx)
622 .copied()
623 .unwrap_or(ErrorCategory::Other);
624
625 let suggested_fix = self
626 .fix_templates
627 .get(&category)
628 .and_then(|fixes| fixes.first().cloned());
629
630 let related = self
631 .fix_templates
632 .get(&category)
633 .map(|fixes| fixes.iter().skip(1).cloned().collect())
634 .unwrap_or_default();
635
636 Ok(ClassificationResult {
637 category,
638 confidence: 0.85, suggested_fix,
640 related_patterns: related,
641 })
642 }
643
644 pub fn observe_prediction(&mut self, was_error: bool) -> DriftStatus {
657 self.adwin_detector.add_element(was_error);
658 self.adwin_detector.detected_change()
659 }
660
661 #[must_use]
663 pub fn drift_status(&self) -> DriftStatus {
664 self.adwin_detector.detected_change()
665 }
666
667 #[must_use]
669 pub fn needs_retraining(&self) -> bool {
670 matches!(self.drift_status(), DriftStatus::Drift)
671 }
672
673 pub fn reset_drift_detector(&mut self) {
675 self.adwin_detector.reset();
676 }
677
678 #[must_use]
680 pub fn drift_stats(&self) -> DriftStats {
681 self.adwin_detector.stats()
682 }
683
684 pub fn set_adwin_delta(&mut self, delta: f64) {
690 self.adwin_detector = ADWIN::with_delta(delta);
691 }
692
693 pub fn save(&self, path: &Path) -> Result<()> {
699 let options = SaveOptions::default()
700 .with_name("depyler-oracle")
701 .with_description("RandomForest error classification model for Depyler transpiler")
702 .with_compression(Compression::ZstdDefault);
703
704 format::save(&self.classifier, ModelType::RandomForest, path, options)
705 .map_err(|e| OracleError::Model(e.to_string()))?;
706
707 Ok(())
708 }
709
710 pub fn load(path: &Path) -> Result<Self> {
716 let classifier: RandomForestClassifier = format::load(path, ModelType::RandomForest)
717 .map_err(|e| OracleError::Model(e.to_string()))?;
718
719 let config = OracleConfig::default();
720 Ok(Self {
721 classifier,
722 config,
723 categories: vec![
724 ErrorCategory::TypeMismatch,
725 ErrorCategory::BorrowChecker,
726 ErrorCategory::MissingImport,
727 ErrorCategory::SyntaxError,
728 ErrorCategory::LifetimeError,
729 ErrorCategory::TraitBound,
730 ErrorCategory::Other,
731 ],
732 fix_templates: Self::default_fix_templates(),
733 adwin_detector: ADWIN::with_delta(0.002),
735 })
736 }
737
738 #[cfg(feature = "training")]
760 #[must_use]
761 pub fn classify_enhanced(
762 &self,
763 error_code: &str,
764 error_message: &str,
765 python_source: &str,
766 rust_source: &str,
767 gnn_encoder: &mut DepylerGnnEncoder,
768 ) -> EnhancedClassificationResult {
769 let base_result =
771 self.classify_message(error_message)
772 .unwrap_or_else(|_| ClassificationResult {
773 category: ErrorCategory::Other,
774 confidence: 0.5,
775 suggested_fix: None,
776 related_patterns: vec![],
777 });
778
779 let enhanced_features = features::EnhancedErrorFeatures::from_error_message(error_message);
781
782 let similar_patterns = gnn_encoder.find_similar(error_code, error_message, rust_source);
784
785 let combined_embedding =
787 gnn_encoder.encode_combined(error_code, error_message, python_source, rust_source);
788
789 let similarity_boost = if !similar_patterns.is_empty() {
792 similar_patterns[0].similarity * 0.1 } else {
794 0.0
795 };
796 let enhanced_confidence = (base_result.confidence + similarity_boost).min(1.0);
797
798 let pattern_fixes: Vec<String> = similar_patterns
800 .iter()
801 .filter_map(|sp| {
802 sp.pattern
803 .error_pattern
804 .as_ref()
805 .map(|ep| ep.fix_diff.clone())
806 })
807 .filter(|fix| !fix.is_empty())
808 .take(3)
809 .collect();
810
811 EnhancedClassificationResult {
812 category: base_result.category,
813 confidence: enhanced_confidence,
814 suggested_fix: base_result.suggested_fix,
815 related_patterns: base_result.related_patterns,
816 similar_patterns,
817 enhanced_features,
818 combined_embedding,
819 pattern_fixes,
820 hnsw_used: gnn_encoder.is_hnsw_active(),
821 }
822 }
823}
824
/// Classification result enriched with GNN similarity data; produced by
/// `Oracle::classify_enhanced`.
#[cfg(feature = "training")]
#[derive(Debug, Clone)]
pub struct EnhancedClassificationResult {
    /// Category from the base classification.
    pub category: ErrorCategory,
    /// Base confidence plus the similarity boost, capped at 1.0.
    pub confidence: f32,
    /// Primary fix suggestion from the base classification.
    pub suggested_fix: Option<String>,
    /// Additional fix suggestions from the base classification.
    pub related_patterns: Vec<String>,
    /// Patterns returned by the GNN encoder's similarity search.
    pub similar_patterns: Vec<SimilarPattern>,
    /// Extra features extracted from the error message.
    pub enhanced_features: features::EnhancedErrorFeatures,
    /// Combined embedding of the error code, message, and both sources.
    pub combined_embedding: Vec<f32>,
    /// Up to three non-empty fix diffs taken from the similar patterns.
    pub pattern_fixes: Vec<String>,
    /// Whether the encoder reported HNSW as active.
    pub hnsw_used: bool,
}
848
849impl Default for Oracle {
850 fn default() -> Self {
851 Self::new()
852 }
853}
854
855pub fn print_drift_status(stats: &DriftStats, status: &DriftStatus) {
861 let status_indicator = match status {
862 DriftStatus::Stable => "🟢 STABLE",
863 DriftStatus::Warning => "🟡 WARNING",
864 DriftStatus::Drift => "🔴 DRIFT DETECTED",
865 };
866
867 println!("╭─────────────────────────────────────────────────────╮");
868 println!("│ Drift Detection Status │");
869 println!("├─────────────────────────────────────────────────────┤");
870 println!("│ Status: {:^40} │", status_indicator);
871 println!(
872 "│ Samples: {:>8} │",
873 stats.n_samples
874 );
875 println!(
876 "│ Error Rate: {:>6.2}% │",
877 stats.error_rate * 100.0
878 );
879 println!(
880 "│ Min Error Rate: {:>6.2}% │",
881 stats.min_error_rate * 100.0
882 );
883 println!(
884 "│ Std Dev: {:>8.4} │",
885 stats.std_dev
886 );
887 println!("╰─────────────────────────────────────────────────────╯");
888}
889
890pub fn print_retrain_status(stats: &RetrainStats) {
892 let status_indicator = match &stats.drift_status {
893 DriftStatus::Stable => "🟢",
894 DriftStatus::Warning => "🟡",
895 DriftStatus::Drift => "🔴",
896 };
897
898 let accuracy_bar = create_accuracy_bar(stats.accuracy());
899
900 println!("╭─────────────────────────────────────────────────────╮");
901 println!("│ Retrain Trigger Status │");
902 println!("├─────────────────────────────────────────────────────┤");
903 println!(
904 "│ {} Drift Status: {:?} │",
905 status_indicator, stats.drift_status
906 );
907 println!(
908 "│ Predictions: {:>8} │",
909 stats.predictions_observed
910 );
911 println!(
912 "│ Correct: {:>8} │",
913 stats.correct_predictions
914 );
915 println!(
916 "│ Errors: {:>8} │",
917 stats.errors
918 );
919 println!(
920 "│ Consecutive: {:>8} │",
921 stats.consecutive_errors
922 );
923 println!(
924 "│ Drift Count: {:>8} │",
925 stats.drift_count
926 );
927 println!("├─────────────────────────────────────────────────────┤");
928 println!(
929 "│ Accuracy: {:>6.2}% {} │",
930 stats.accuracy() * 100.0,
931 accuracy_bar
932 );
933 println!(
934 "│ Error Rate: {:>6.2}% │",
935 stats.error_rate() * 100.0
936 );
937 println!("╰─────────────────────────────────────────────────────╯");
938}
939
/// Prints a boxed report of the model lineage to stdout.
///
/// The report includes the total model count and details of the latest model. It
/// also shows any regression returned by `find_regression` (red when the delta is
/// negative) and the first five entries of the lineage chain.
#[cfg(feature = "training")]
pub fn print_lineage_history(lineage: &OracleLineage) {
    println!("╭─────────────────────────────────────────────────────╮");
    println!("│ Model Lineage History │");
    println!("├─────────────────────────────────────────────────────┤");
    println!(
        "│ Total Models: {:>6} │",
        lineage.model_count()
    );

    if let Some(latest) = lineage.latest_model() {
        // Show a short SHA: the first 8 characters. Cutting at a char boundary
        // (instead of byte 8) keeps a non-ASCII tag value from panicking the report.
        let commit_sha = latest
            .tags
            .get("commit_sha")
            .map(|s| {
                s.char_indices()
                    .nth(8)
                    .map_or(s.as_str(), |(end, _)| &s[..end])
            })
            .unwrap_or("unknown");
        println!(
            "│ Latest Model: {} │",
            latest.model_id.chars().take(30).collect::<String>()
        );
        println!(
            "│ Version: {} │",
            latest.version
        );
        println!(
            "│ Accuracy: {:>6.2}% │",
            latest.accuracy * 100.0
        );
        println!("│ Commit: {} │", commit_sha);
    }

    if let Some((reason, delta)) = lineage.find_regression() {
        let indicator = if delta < 0.0 { "🔴" } else { "🟢" };
        println!("├─────────────────────────────────────────────────────┤");
        println!(
            "│ {} Regression: {:+.2}% │",
            indicator,
            delta * 100.0
        );
        println!(
            "│ Reason: {:40} │",
            reason.chars().take(40).collect::<String>()
        );
    }

    let chain = lineage.get_lineage_chain();
    if !chain.is_empty() {
        println!("├─────────────────────────────────────────────────────┤");
        println!(
            "│ Lineage Chain ({} models): │",
            chain.len()
        );
        for (i, model_id) in chain.iter().take(5).enumerate() {
            let arrow = if i == 0 { "└" } else { "├" };
            println!(
                "│ {} {} │",
                arrow,
                model_id.chars().take(35).collect::<String>()
            );
        }
        if chain.len() > 5 {
            println!(
                "│ ... and {} more │",
                chain.len() - 5
            );
        }
    }

    println!("╰─────────────────────────────────────────────────────╯");
}
1013
/// Renders `accuracy` (expected to be a fraction in `[0, 1]`) as a 10-cell bar,
/// e.g. `0.7` becomes `[███████░░░]`.
///
/// Out-of-range values are clamped to `[0, 1]`. Without the clamp, `10 - filled`
/// underflows for inputs above 1.0. NaN survives the clamp but casts to 0, so it
/// renders as an empty bar.
fn create_accuracy_bar(accuracy: f64) -> String {
    const WIDTH: usize = 10;
    let fraction = accuracy.clamp(0.0, 1.0);
    let filled = ((fraction * WIDTH as f64).round() as usize).min(WIDTH);
    format!("[{}{}]", "█".repeat(filled), "░".repeat(WIDTH - filled))
}
1020
/// Prints three reports separated by blank lines: the retrain-trigger status, the
/// drift-detector status, and the model lineage history.
#[cfg(feature = "training")]
pub fn print_oracle_status(trigger: &RetrainTrigger, lineage: &OracleLineage) {
    let stats = trigger.stats();
    print_retrain_status(stats);
    println!();

    let drift = trigger.drift_stats();
    print_drift_status(&drift, &stats.drift_status);
    println!();

    print_lineage_history(lineage);
}
1030
/// Verdict returned by `RetrainTrigger::observe` for one observation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObserveResult {
    /// No threshold was crossed.
    Stable,
    /// After warm-up, the error rate reached the warning threshold but not drift.
    Warning,
    /// ADWIN drift, too many consecutive errors, or an error rate at or above the
    /// drift threshold.
    DriftDetected,
}
1045
/// Thresholds used by `RetrainTrigger` to report warnings and drift.
#[derive(Debug, Clone)]
pub struct RetrainConfig {
    /// Observations required before the error-rate thresholds are applied.
    pub min_samples: usize,
    /// A run of this many consecutive errors is reported as drift.
    pub max_consecutive_errors: usize,
    /// Error rate (fraction of observations) at which a warning is reported.
    pub warning_threshold: f64,
    /// Error rate (fraction of observations) at which drift is reported.
    pub drift_threshold: f64,
}

impl Default for RetrainConfig {
    /// 50-sample warm-up, 10 consecutive errors, 20% warning and 30% drift rates.
    fn default() -> Self {
        Self {
            warning_threshold: 0.2,
            drift_threshold: 0.3,
            min_samples: 50,
            max_consecutive_errors: 10,
        }
    }
}
1069
1070#[derive(Debug, Clone, Default)]
1072pub struct RetrainStats {
1073 pub predictions_observed: u64,
1075 pub correct_predictions: u64,
1077 pub errors: u64,
1079 pub consecutive_errors: usize,
1081 pub drift_status: DriftStatus,
1083 pub drift_count: u64,
1085}
1086
1087impl RetrainStats {
1088 #[must_use]
1090 pub fn error_rate(&self) -> f64 {
1091 if self.predictions_observed == 0 {
1092 0.0
1093 } else {
1094 self.errors as f64 / self.predictions_observed as f64
1095 }
1096 }
1097
1098 #[must_use]
1100 pub fn accuracy(&self) -> f64 {
1101 1.0 - self.error_rate()
1102 }
1103}
1104
1105pub struct RetrainTrigger {
1123 oracle: Oracle,
1125 config: RetrainConfig,
1127 stats: RetrainStats,
1129}
1130
1131impl RetrainTrigger {
1132 pub fn new(oracle: Oracle, config: RetrainConfig) -> Self {
1134 Self {
1135 oracle,
1136 config,
1137 stats: RetrainStats::default(),
1138 }
1139 }
1140
1141 pub fn with_oracle(oracle: Oracle) -> Self {
1143 Self::new(oracle, RetrainConfig::default())
1144 }
1145
1146 pub fn observe(&mut self, was_error: bool) -> ObserveResult {
1154 self.stats.predictions_observed += 1;
1155
1156 if was_error {
1157 self.stats.errors += 1;
1158 self.stats.consecutive_errors += 1;
1159 } else {
1160 self.stats.correct_predictions += 1;
1161 self.stats.consecutive_errors = 0;
1162 }
1163
1164 let drift_status = self.oracle.observe_prediction(was_error);
1166 self.stats.drift_status = drift_status;
1167
1168 if matches!(drift_status, DriftStatus::Drift) {
1170 self.stats.drift_count += 1;
1171 return ObserveResult::DriftDetected;
1172 }
1173
1174 if self.stats.consecutive_errors >= self.config.max_consecutive_errors {
1176 self.stats.drift_count += 1;
1177 return ObserveResult::DriftDetected;
1178 }
1179
1180 if self.stats.predictions_observed >= self.config.min_samples as u64 {
1182 let error_rate = self.stats.error_rate();
1183 if error_rate >= self.config.drift_threshold {
1184 self.stats.drift_count += 1;
1185 return ObserveResult::DriftDetected;
1186 }
1187 if error_rate >= self.config.warning_threshold {
1188 return ObserveResult::Warning;
1189 }
1190 }
1191
1192 ObserveResult::Stable
1193 }
1194
1195 pub fn mark_retrained(&mut self) {
1197 self.oracle.reset_drift_detector();
1198 self.stats.consecutive_errors = 0;
1199 self.stats.predictions_observed = 0;
1200 self.stats.correct_predictions = 0;
1201 self.stats.errors = 0;
1202 }
1203
1204 #[must_use]
1206 pub fn stats(&self) -> &RetrainStats {
1207 &self.stats
1208 }
1209
1210 pub fn oracle_mut(&mut self) -> &mut Oracle {
1212 &mut self.oracle
1213 }
1214
1215 #[must_use]
1217 pub fn oracle(&self) -> &Oracle {
1218 &self.oracle
1219 }
1220
1221 #[must_use]
1223 pub fn needs_retraining(&self) -> bool {
1224 self.oracle.needs_retraining()
1225 || self.stats.consecutive_errors >= self.config.max_consecutive_errors
1226 || (self.stats.predictions_observed >= self.config.min_samples as u64
1227 && self.stats.error_rate() >= self.config.drift_threshold)
1228 }
1229
1230 #[must_use]
1232 pub fn drift_stats(&self) -> DriftStats {
1233 self.oracle.drift_stats()
1234 }
1235}
1236
1237#[cfg(test)]
1238mod tests {
1239 use super::*;
1240
#[test]
fn test_oracle_creation() {
    // A fresh oracle knows all seven error categories.
    let fresh = Oracle::new();
    let category_count = fresh.categories.len();
    assert_eq!(category_count, 7);
}
1246
#[test]
#[cfg(feature = "training")]
fn test_needs_retrain_no_lineage_file() {
    // The path points inside a fresh temp dir and is never written.
    let temp_dir = tempfile::TempDir::new().expect("create temp dir");
    let missing_path = temp_dir.path().join(".depyler").join("oracle_lineage.json");

    let lineage = OracleLineage::load(&missing_path).expect("load should not error");
    assert_eq!(
        lineage.model_count(),
        0,
        "No lineage file should return empty lineage"
    );

    let retrain = lineage.needs_retraining("any_sha", "any_hash");
    assert!(retrain, "Empty lineage should need retraining");
}
1271
#[test]
#[cfg(feature = "training")]
fn test_needs_retrain_commit_changed_lineage() {
    let (recorded_sha, recorded_hash) = ("abc123def456", "corpus_hash_123");
    let mut lineage = OracleLineage::new();
    lineage.record_training(recorded_sha.to_string(), recorded_hash.to_string(), 1000, 0.85);

    // Same corpus, different commit.
    let retrain = lineage.needs_retraining("different_sha_789", recorded_hash);
    assert!(retrain, "Changed commit SHA should trigger retraining");
}
1289
#[test]
#[cfg(feature = "training")]
fn test_needs_retrain_corpus_changed_lineage() {
    let (recorded_sha, recorded_hash) = ("abc123def456", "corpus_hash_123");
    let mut lineage = OracleLineage::new();
    lineage.record_training(recorded_sha.to_string(), recorded_hash.to_string(), 1000, 0.85);

    // Same commit, different corpus.
    let retrain = lineage.needs_retraining(recorded_sha, "different_corpus_hash");
    assert!(retrain, "Changed corpus hash should trigger retraining");
}
1307
#[test]
#[cfg(feature = "training")]
fn test_no_retrain_when_unchanged_lineage() {
    let (recorded_sha, recorded_hash) = ("abc123def456", "corpus_hash_123");
    let mut lineage = OracleLineage::new();
    lineage.record_training(recorded_sha.to_string(), recorded_hash.to_string(), 1000, 0.85);

    // Both commit and corpus match the recorded training.
    let retrain = lineage.needs_retraining(recorded_sha, recorded_hash);
    assert!(!retrain, "Unchanged state should NOT need retraining");
}
1325
#[test]
#[cfg(feature = "training")]
fn test_lineage_saves_after_training() {
    let temp_dir = tempfile::TempDir::new().expect("create temp dir");
    let lineage_path = temp_dir.path().join(".depyler").join("oracle_lineage.json");

    // Record one training run and persist it.
    let mut written = OracleLineage::new();
    written.record_training(
        "test_sha_12345".to_string(),
        "test_hash_67890".to_string(),
        500,
        0.85,
    );
    written.save(&lineage_path).expect("save should work");
    assert!(lineage_path.exists(), "Lineage file should exist after save");

    // Round-trip: the reloaded entry carries the same SHA, hash, and sample count.
    let reloaded = OracleLineage::load(&lineage_path).expect("load should work");
    assert_eq!(reloaded.model_count(), 1);

    let newest = reloaded.latest_model().expect("should have model");
    assert_eq!(
        newest.tags.get("commit_sha"),
        Some(&"test_sha_12345".to_string())
    );
    assert_eq!(newest.config_hash, "test_hash_67890");
    assert_eq!(newest.tags.get("sample_count"), Some(&"500".to_string()));
}
1361
#[test]
#[cfg(feature = "training")]
fn test_get_corpus_paths_for_hashing() {
    // The list feeds the corpus hash, so it must come back sorted and contain only
    // the file types the corpus collector accepts. (The previous assertion,
    // `is_empty() || !is_empty()`, could never fail.)
    let paths = get_training_corpus_paths();
    assert!(
        paths.windows(2).all(|pair| pair[0] <= pair[1]),
        "get_training_corpus_paths should return sorted paths"
    );
    for path in &paths {
        assert!(
            path.extension().is_some_and(|ext| ext == "rs" || ext == "json"),
            "unexpected corpus file: {}",
            path.display()
        );
    }
}
1374
#[test]
fn test_fix_templates() {
    let oracle = Oracle::new();
    assert!(oracle
        .fix_templates
        .contains_key(&ErrorCategory::TypeMismatch));
    assert!(oracle
        .fix_templates
        .contains_key(&ErrorCategory::BorrowChecker));

    // Every category the classifier can emit needs at least one template;
    // otherwise classification returns no suggested fix for it.
    for category in &oracle.categories {
        let has_fix = oracle
            .fix_templates
            .get(category)
            .is_some_and(|fixes| !fixes.is_empty());
        assert!(has_fix, "missing fix templates for {category:?}");
    }
}
1385
#[test]
fn test_adwin_drift_detection_stable() {
    let mut oracle = Oracle::new();

    // A stream of only correct predictions must never leave the stable state.
    for _step in 0..50 {
        let observed = oracle.observe_prediction(false);
        assert!(
            matches!(observed, DriftStatus::Stable),
            "All correct predictions should be stable"
        );
    }

    let retrain = oracle.needs_retraining();
    assert!(!retrain, "Should not need retraining with all correct");
}
1408
/// A switch from all-correct to all-error predictions should surface as drift
/// or as a high error rate.
#[test]
fn test_adwin_drift_detection_gradual_degradation() {
    let mut oracle = Oracle::new();
    // Presumably a larger delta makes ADWIN react faster; 0.1 is set explicitly here.
    oracle.set_adwin_delta(0.1);

    // 200 correct observations, then 200 erroneous ones.
    for is_error in [false, true] {
        for _ in 0..200 {
            oracle.observe_prediction(is_error);
        }
    }

    let stats = oracle.drift_stats();
    assert!(
        oracle.needs_retraining() || stats.error_rate > 0.3,
        "Should detect drift or have high error rate: {:?}, drift status: {:?}",
        stats,
        oracle.drift_status()
    );
}
1437
/// Resetting the detector after a burst of errors returns it to `Stable`.
#[test]
fn test_adwin_drift_detector_reset() {
    let mut oracle = Oracle::new();
    (0..50).for_each(|_| {
        oracle.observe_prediction(true);
    });

    oracle.reset_drift_detector();

    let status = oracle.drift_status();
    assert!(matches!(status, DriftStatus::Stable));
}
1453
/// Drift statistics count every observation, regardless of outcome.
#[test]
fn test_adwin_drift_stats() {
    let mut oracle = Oracle::new();

    // Ten non-error observations followed by ten error observations.
    for outcome in [false, true] {
        (0..10).for_each(|_| {
            oracle.observe_prediction(outcome);
        });
    }

    assert_eq!(
        oracle.drift_stats().n_samples,
        20,
        "Should have 20 samples"
    );
}
1469
/// `load_or_train` trains and persists a model on first use, then reloads it.
///
/// With `DEPYLER_FAST_TESTS` set, the slow training path is skipped and only
/// construction is checked. Cleanup removes the model file only if this test
/// created it, so running the (ignored) test never deletes a model that was
/// already present at the default path.
#[test]
#[ignore]
#[cfg(feature = "training")]
fn test_load_or_train() {
    if std::env::var("DEPYLER_FAST_TESTS").is_ok() {
        let oracle = Oracle::new();
        assert_eq!(oracle.categories.len(), 7);
        return;
    }

    let path = Oracle::default_model_path();
    // Record whether a model was already there before we touched anything.
    let preexisting = path.exists();

    let oracle = Oracle::load_or_train().expect("load_or_train should succeed");
    assert_eq!(oracle.categories.len(), 7);
    assert!(path.exists(), "Model file should be created at {:?}", path);

    // Second call should take the load path now that the file exists.
    let oracle2 = Oracle::load_or_train().expect("second load_or_train should succeed");
    assert_eq!(oracle2.categories.len(), 7);

    // Best-effort cleanup, restricted to the file this test produced.
    if !preexisting {
        let _ = std::fs::remove_file(&path);
    }
}
1497
/// The default model location points at the `depyler_oracle.apr` artifact.
#[test]
fn test_default_model_path() {
    let rendered = Oracle::default_model_path()
        .to_string_lossy()
        .into_owned();
    assert!(rendered.contains("depyler_oracle.apr"));
}
1503
/// A newly wrapped oracle starts with zeroed retraining counters.
#[test]
fn test_retrain_trigger_creation() {
    let trigger = RetrainTrigger::with_oracle(Oracle::new());
    let snapshot = trigger.stats();
    assert_eq!(snapshot.predictions_observed, 0);
    assert_eq!(snapshot.errors, 0);
}
1516
/// Non-error observations stay `Stable` and are tallied as correct.
#[test]
fn test_retrain_trigger_observe_correct() {
    let mut trigger = RetrainTrigger::with_oracle(Oracle::new());

    for _ in 0..10 {
        assert_eq!(trigger.observe(false), ObserveResult::Stable);
    }

    let snapshot = trigger.stats();
    assert_eq!(snapshot.predictions_observed, 10);
    assert_eq!(snapshot.correct_predictions, 10);
    assert_eq!(snapshot.errors, 0);
}
1532
/// Hitting `max_consecutive_errors` in a row flags drift.
#[test]
fn test_retrain_trigger_consecutive_errors() {
    let limit = RetrainConfig {
        max_consecutive_errors: 5,
        ..Default::default()
    };
    let mut trigger = RetrainTrigger::new(Oracle::new(), limit);

    // Four errors stay below the limit...
    (0..4).for_each(|_| assert_eq!(trigger.observe(true), ObserveResult::Stable));

    // ...and the fifth consecutive error trips it.
    assert_eq!(trigger.observe(true), ObserveResult::DriftDetected);
}
1552
/// Once `min_samples` is reached, a 50% error rate triggers warning or drift.
#[test]
fn test_retrain_trigger_error_rate_threshold() {
    let oracle = Oracle::new();
    let config = RetrainConfig {
        min_samples: 10,
        drift_threshold: 0.5,
        warning_threshold: 0.3,
        // Set high so the consecutive-error rule cannot be what fires.
        max_consecutive_errors: 100,
    };
    let mut trigger = RetrainTrigger::new(oracle, config);

    // 7 correct then 6 errors; the final error below makes it 7/14.
    for (count, is_error) in [(7, false), (6, true)] {
        for _ in 0..count {
            trigger.observe(is_error);
        }
    }

    let outcome = trigger.observe(true);
    let acceptable = [ObserveResult::DriftDetected, ObserveResult::Warning];
    assert!(
        acceptable.contains(&outcome),
        "Should detect drift or warning at 50% error rate"
    );
}
1581
/// `mark_retrained` clears all accumulated counters.
#[test]
fn test_retrain_trigger_mark_retrained() {
    let mut trigger = RetrainTrigger::with_oracle(Oracle::new());
    (0..10).for_each(|_| {
        trigger.observe(true);
    });
    assert_eq!(trigger.stats().errors, 10);

    trigger.mark_retrained();

    let after = trigger.stats();
    assert_eq!(after.predictions_observed, 0);
    assert_eq!(after.errors, 0);
    assert_eq!(after.consecutive_errors, 0);
}
1603
/// Error rate and accuracy derive from the observed/error counters.
#[test]
fn test_retrain_stats_error_rate() {
    // With no observations the error rate is reported as 0.0.
    let empty = RetrainStats::default();
    assert_eq!(empty.error_rate(), 0.0);

    let quarter_wrong = RetrainStats {
        predictions_observed: 100,
        errors: 25,
        ..RetrainStats::default()
    };
    assert!((quarter_wrong.error_rate() - 0.25).abs() < 0.001);
    assert!((quarter_wrong.accuracy() - 0.75).abs() < 0.001);
}
1614
/// The default random-forest configuration uses 100 trees, depth 10 and seed 42.
#[test]
fn test_oracle_config_default() {
    let OracleConfig {
        n_estimators,
        max_depth,
        random_state,
    } = OracleConfig::default();
    assert_eq!(n_estimators, 100);
    assert_eq!(max_depth, 10);
    assert_eq!(random_state, Some(42));
}
1626
/// Explicitly provided configuration values are kept verbatim.
#[test]
fn test_oracle_config_custom() {
    let OracleConfig {
        n_estimators,
        max_depth,
        random_state,
    } = OracleConfig {
        n_estimators: 50,
        max_depth: 5,
        random_state: Some(123),
    };
    assert_eq!(n_estimators, 50);
    assert_eq!(max_depth, 5);
    assert_eq!(random_state, Some(123));
}
1638
/// Building from a custom config still registers all seven error categories.
#[test]
fn test_oracle_with_config() {
    let small_forest = Oracle::with_config(OracleConfig {
        n_estimators: 20,
        max_depth: 3,
        random_state: None,
    });
    assert_eq!(small_forest.categories.len(), 7);
}
1649
/// Each `OracleError` variant renders with its descriptive prefix.
#[test]
fn test_oracle_error_display() {
    let cases = [
        (OracleError::Model("test error".to_string()), "Model error"),
        (
            OracleError::Feature("feature error".to_string()),
            "Feature extraction error",
        ),
        (
            OracleError::Classification("class error".to_string()),
            "Classification error",
        ),
        (
            OracleError::Io(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "not found",
            )),
            "IO error",
        ),
    ];

    for (error, prefix) in cases {
        assert!(error.to_string().contains(prefix));
    }
}
1667
/// A hand-built `ClassificationResult` exposes exactly the values it was given.
#[test]
fn test_classification_result_creation() {
    let patterns: Vec<String> = ["pattern1", "pattern2"]
        .iter()
        .map(|p| p.to_string())
        .collect();
    let built = ClassificationResult {
        category: ErrorCategory::TypeMismatch,
        confidence: 0.95,
        suggested_fix: Some("Use .into()".to_string()),
        related_patterns: patterns,
    };

    assert_eq!(built.category, ErrorCategory::TypeMismatch);
    assert_eq!(built.confidence, 0.95);
    assert!(built.suggested_fix.is_some());
    assert_eq!(built.related_patterns.len(), 2);
}
1681
/// Cloning a `ClassificationResult` preserves category and confidence.
#[test]
fn test_classification_result_clone() {
    let source = ClassificationResult {
        category: ErrorCategory::BorrowChecker,
        confidence: 0.80,
        suggested_fix: None,
        related_patterns: Vec::new(),
    };
    let duplicate = source.clone();
    assert_eq!(duplicate.category, source.category);
    assert_eq!(duplicate.confidence, source.confidence);
}
1694
/// `ObserveResult` equality: each variant equals itself and differs from every other.
///
/// The previous version compared syntactically identical expressions
/// (`assert_eq!(X, X)`, the pattern `clippy::eq_op` targets). It also
/// checked only one of the three distinct-variant pairs. This version compares
/// independently constructed values and covers every pair.
#[test]
fn test_observe_result_eq() {
    let lhs = [
        ObserveResult::Stable,
        ObserveResult::Warning,
        ObserveResult::DriftDetected,
    ];
    let rhs = [
        ObserveResult::Stable,
        ObserveResult::Warning,
        ObserveResult::DriftDetected,
    ];

    for (i, a) in lhs.iter().enumerate() {
        for (j, b) in rhs.iter().enumerate() {
            if i == j {
                assert_eq!(a, b);
            } else {
                assert_ne!(a, b);
            }
        }
    }
}
1702
/// Default retraining thresholds: 50 samples, 10 consecutive errors, 0.2 warn, 0.3 drift.
#[test]
fn test_retrain_config_default() {
    let RetrainConfig {
        min_samples,
        max_consecutive_errors,
        warning_threshold,
        drift_threshold,
    } = RetrainConfig::default();
    assert_eq!(min_samples, 50);
    assert_eq!(max_consecutive_errors, 10);
    assert!((warning_threshold - 0.2).abs() < 0.001);
    assert!((drift_threshold - 0.3).abs() < 0.001);
}
1711
/// Custom retraining thresholds are stored as given.
#[test]
fn test_retrain_config_custom() {
    let custom = RetrainConfig {
        min_samples: 100,
        max_consecutive_errors: 5,
        warning_threshold: 0.15,
        drift_threshold: 0.25,
    };
    let RetrainConfig {
        min_samples,
        max_consecutive_errors,
        ..
    } = custom;
    assert_eq!(min_samples, 100);
    assert_eq!(max_consecutive_errors, 5);
}
1723
/// The accuracy bar renders ten cells filled in proportion to the ratio.
#[test]
fn test_create_accuracy_bar() {
    let exact = [
        (1.0, "[██████████]"),
        (0.5, "[█████░░░░░]"),
        (0.0, "[░░░░░░░░░░]"),
    ];
    for (ratio, expected) in exact {
        assert_eq!(create_accuracy_bar(ratio), expected);
    }

    // A non-round ratio must still show at least one filled cell.
    let partial = create_accuracy_bar(0.85);
    assert!(partial.contains("█"));
}
1742
/// Drift status printing handles every status variant without panicking.
#[test]
fn test_print_drift_status_does_not_panic() {
    let mut oracle = Oracle::new();
    (0..10).for_each(|_| {
        oracle.observe_prediction(false);
    });
    let snapshot = oracle.drift_stats();

    for status in [DriftStatus::Stable, DriftStatus::Warning, DriftStatus::Drift] {
        print_drift_status(&snapshot, &status);
    }
}
1756
/// Printing a populated `RetrainStats` must not panic.
#[test]
fn test_print_retrain_status_does_not_panic() {
    print_retrain_status(&RetrainStats {
        predictions_observed: 100,
        correct_predictions: 80,
        errors: 20,
        consecutive_errors: 2,
        drift_status: DriftStatus::Stable,
        drift_count: 0,
    });
}
1770
/// Lineage history printing works for both empty and populated lineages.
#[test]
#[cfg(feature = "training")]
fn test_print_lineage_history_does_not_panic() {
    print_lineage_history(&OracleLineage::new());

    let mut populated = OracleLineage::new();
    populated.record_training("sha123".to_string(), "hash456".to_string(), 1000, 0.9);
    print_lineage_history(&populated);
}
1783
/// The combined oracle status report renders for a fresh trigger and lineage.
#[test]
#[cfg(feature = "training")]
fn test_print_oracle_status_does_not_panic() {
    let trigger = RetrainTrigger::with_oracle(Oracle::new());
    print_oracle_status(&trigger, &OracleLineage::new());
}
1793
/// Both the shared and the mutable accessors expose an oracle with all seven categories.
#[test]
fn test_retrain_trigger_oracle_access() {
    let mut trigger = RetrainTrigger::with_oracle(Oracle::new());
    assert_eq!(trigger.oracle().categories.len(), 7);
    assert_eq!(trigger.oracle_mut().categories.len(), 7);
}
1806
/// A fresh trigger reports no drift samples yet.
#[test]
fn test_retrain_trigger_drift_stats() {
    let trigger = RetrainTrigger::with_oracle(Oracle::new());
    assert_eq!(trigger.drift_stats().n_samples, 0);
}
1814
/// A fresh trigger does not ask for retraining.
#[test]
fn test_retrain_trigger_needs_retraining() {
    let fresh = RetrainTrigger::with_oracle(Oracle::new());
    let wants_retrain = fresh.needs_retraining();
    assert!(!wants_retrain);
}
1822
/// `Oracle::default()` registers all seven error categories.
#[test]
fn test_oracle_default() {
    assert_eq!(Oracle::default().categories.len(), 7);
}
1828
/// Default `RetrainStats` has zeroed counters and a `Stable` drift status.
#[test]
fn test_retrain_stats_default() {
    let RetrainStats {
        predictions_observed,
        correct_predictions,
        errors,
        consecutive_errors,
        drift_status,
        drift_count,
    } = RetrainStats::default();
    assert_eq!(predictions_observed, 0);
    assert_eq!(correct_predictions, 0);
    assert_eq!(errors, 0);
    assert_eq!(consecutive_errors, 0);
    assert_eq!(drift_count, 0);
    assert!(matches!(drift_status, DriftStatus::Stable));
}
1839
/// Reconfiguring the ADWIN delta repeatedly must not panic.
#[test]
fn test_oracle_set_adwin_delta() {
    let mut oracle = Oracle::new();
    for delta in [0.001, 0.01, 0.1] {
        oracle.set_adwin_delta(delta);
    }
}
1848
/// Enhanced classification returns a bounded confidence, a non-empty combined
/// embedding and populated base features.
///
/// Uses `RangeInclusive::contains` instead of a manual `>= && <=` pair
/// (`clippy::manual_range_contains`). On failure it reports the offending
/// confidence.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_classification() {
    let oracle = Oracle::new();
    // Presumably a 0.0 threshold lets any indexed pattern count as similar.
    let mut gnn_encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
        similarity_threshold: 0.0,
        ..Default::default()
    });

    let pattern =
        error_patterns::ErrorPattern::new("E0308", "mismatched types", "+let x: i32 = 42;");
    gnn_encoder.index_pattern(&pattern, "let x: i32 = \"hello\";");

    let result = oracle.classify_enhanced(
        "E0308",
        "mismatched types: expected i32, found String",
        "def foo(): return \"hello\"",
        "fn foo() -> i32 { \"hello\" }",
        &mut gnn_encoder,
    );

    assert!(
        (0.0..=1.0).contains(&result.confidence),
        "confidence must lie in [0, 1], got {}",
        result.confidence
    );
    assert!(!result.combined_embedding.is_empty());
    assert!(result.enhanced_features.base.message_length > 0.0);
}
1880
/// With the default encoder and at least one indexed pattern, HNSW lookup is used.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_classification_hnsw_used() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::with_defaults();

    let indexed = error_patterns::ErrorPattern::new("E0308", "type mismatch", "+fix");
    encoder.index_pattern(&indexed, "source");

    let outcome = oracle.classify_enhanced(
        "E0308",
        "type mismatch",
        "def foo(): pass",
        "fn foo() {}",
        &mut encoder,
    );

    assert!(
        outcome.hnsw_used,
        "HNSW should be used when patterns are indexed"
    );
}
1904
/// Disabling HNSW in the encoder config means classification never reports using it.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_classification_without_hnsw() {
    let config = GnnEncoderConfig {
        use_hnsw: false,
        ..Default::default()
    };
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::new(config);

    let outcome = oracle.classify_enhanced(
        "E0308",
        "type mismatch",
        "def foo(): pass",
        "fn foo() {}",
        &mut encoder,
    );

    assert!(!outcome.hnsw_used, "HNSW should not be used when disabled");
}
1924
/// The combined embedding length matches the encoder's advertised `combined_dim`.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_classification_combined_embedding_size() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::with_defaults();

    let outcome = oracle.classify_enhanced(
        "E0382",
        "borrow of moved value",
        "def foo(): x = []; return x",
        "fn foo() { let x = vec![]; x }",
        &mut encoder,
    );

    let actual = outcome.combined_embedding.len();
    let expected = encoder.combined_dim();
    assert_eq!(
        actual, expected,
        "Combined embedding should have correct dimension"
    );
}
1946
/// A trait-not-implemented message yields non-zero keyword feature counts.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_features_extraction() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::with_defaults();

    let outcome = oracle.classify_enhanced(
        "E0277",
        "the trait `Clone` is not implemented for `Foo`",
        "class Foo: pass",
        "struct Foo {}",
        &mut encoder,
    );

    let total = outcome.enhanced_features.keyword_counts.iter().sum::<f32>();
    assert!(total > 0.0, "Should extract trait-related keywords");
}
1965
/// Whenever similar patterns are matched, their fixes must be extracted too.
#[test]
#[cfg(feature = "training")]
fn test_phase4_pattern_fixes_extraction() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
        similarity_threshold: 0.0,
        ..Default::default()
    });

    // Unified-diff style fix: remove the string binding, add a typed one.
    let indexed = error_patterns::ErrorPattern::new(
        "E0308",
        "type error",
        "-let x = \"hello\";\n+let x: i32 = 42;",
    );
    encoder.index_pattern(&indexed, "source");

    let outcome = oracle.classify_enhanced(
        "E0308",
        "type error",
        "def foo(): pass",
        "fn foo() {}",
        &mut encoder,
    );

    // The requirement only applies when at least one similar pattern came back.
    if !outcome.similar_patterns.is_empty() {
        assert!(
            !outcome.pattern_fixes.is_empty(),
            "Should extract fixes from matched patterns"
        );
    }
}
1997
/// Cloning an enhanced classification result preserves its key fields.
#[test]
#[cfg(feature = "training")]
fn test_phase4_enhanced_result_clone() {
    let oracle = Oracle::new();
    let mut encoder = DepylerGnnEncoder::with_defaults();

    let source = oracle.classify_enhanced(
        "E0599",
        "method not found",
        "foo.bar()",
        "foo.bar()",
        &mut encoder,
    );
    let duplicate = source.clone();

    assert_eq!(duplicate.category, source.category);
    assert_eq!(duplicate.confidence, source.confidence);
    assert_eq!(duplicate.hnsw_used, source.hnsw_used);
}
2017}