Skip to main content

decy_oracle/
oracle.rs

1//! Main oracle implementation
2
3use crate::config::OracleConfig;
4use crate::context::CDecisionContext;
5use crate::error::OracleError;
6use crate::metrics::OracleMetrics;
7
8#[cfg(feature = "citl")]
9use entrenar::citl::{DecisionPatternStore, FixSuggestion as EntrenarFixSuggestion};
10
/// Fix suggestion returned by the oracle.
///
/// Alias for the `FixSuggestion` type exported by `entrenar::citl`; it is only
/// defined when the `citl` feature is enabled.
#[cfg(feature = "citl")]
pub type FixSuggestion = EntrenarFixSuggestion;
14
/// Rustc error information.
///
/// Holds the diagnostic code and message, plus an optional source location.
/// Two values compare equal when all four fields are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RustcError {
    /// Error code (e.g., "E0382")
    pub code: String,
    /// Error message
    pub message: String,
    /// File path; `None` until [`RustcError::with_location`] is called
    pub file: Option<String>,
    /// Line number; `None` until [`RustcError::with_location`] is called
    pub line: Option<usize>,
}

impl RustcError {
    /// Create a new rustc error with no location attached.
    ///
    /// `code` and `message` are stored as given; no validation is performed.
    #[must_use]
    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
        Self { code: code.into(), message: message.into(), file: None, line: None }
    }

    /// Attach a file location (builder style).
    ///
    /// Consumes `self` and returns it with `file` and `line` set, so the result
    /// must be used; discarding it drops the error.
    #[must_use]
    pub fn with_location(mut self, file: impl Into<String>, line: usize) -> Self {
        self.file = Some(file.into());
        self.line = Some(line);
        self
    }
}
41
42/// Decy CITL Oracle
43///
44/// Queries accumulated fix patterns to suggest corrections for rustc errors.
45pub struct DecyOracle {
46    config: OracleConfig,
47    #[cfg(feature = "citl")]
48    store: Option<DecisionPatternStore>,
49    metrics: OracleMetrics,
50}
51
52impl DecyOracle {
53    /// Create a new oracle from configuration
54    pub fn new(config: OracleConfig) -> Result<Self, OracleError> {
55        #[cfg(feature = "citl")]
56        let store = if config.patterns_path.exists() {
57            Some(
58                DecisionPatternStore::load_apr(&config.patterns_path)
59                    .map_err(|e| OracleError::PatternStoreError(e.to_string()))?,
60            )
61        } else {
62            None
63        };
64
65        Ok(Self {
66            config,
67            #[cfg(feature = "citl")]
68            store,
69            metrics: OracleMetrics::default(),
70        })
71    }
72
73    /// Check if the oracle has patterns loaded
74    pub fn has_patterns(&self) -> bool {
75        #[cfg(feature = "citl")]
76        {
77            self.store.is_some()
78        }
79        #[cfg(not(feature = "citl"))]
80        {
81            false
82        }
83    }
84
85    /// Get the number of patterns loaded
86    pub fn pattern_count(&self) -> usize {
87        #[cfg(feature = "citl")]
88        {
89            self.store.as_ref().map(|s| s.len()).unwrap_or(0)
90        }
91        #[cfg(not(feature = "citl"))]
92        {
93            0
94        }
95    }
96
97    /// Query for fix suggestion
98    #[cfg(feature = "citl")]
99    pub fn suggest_fix(
100        &mut self,
101        error: &RustcError,
102        context: &CDecisionContext,
103    ) -> Option<FixSuggestion> {
104        let store = match self.store.as_ref() {
105            Some(s) => s,
106            None => {
107                self.metrics.record_miss(&error.code);
108                return None;
109            }
110        };
111
112        let context_strings = context.to_context_strings();
113        let suggestions =
114            match store.suggest_fix(&error.code, &context_strings, self.config.max_suggestions) {
115                Ok(s) => s,
116                Err(_) => {
117                    self.metrics.record_miss(&error.code);
118                    return None;
119                }
120            };
121
122        let best = match suggestions
123            .into_iter()
124            .find(|s| s.weighted_score() >= self.config.confidence_threshold)
125        {
126            Some(b) => b,
127            None => {
128                self.metrics.record_miss(&error.code);
129                return None;
130            }
131        };
132
133        self.metrics.record_hit(&error.code);
134        Some(best)
135    }
136
137    /// Query for fix suggestion (stub when citl feature disabled)
138    #[cfg(not(feature = "citl"))]
139    pub fn suggest_fix(&mut self, error: &RustcError, _context: &CDecisionContext) -> Option<()> {
140        self.metrics.record_miss(&error.code);
141        None
142    }
143
144    /// Record a miss (no suggestion found)
145    pub fn record_miss(&mut self, error: &RustcError) {
146        self.metrics.record_miss(&error.code);
147    }
148
149    /// Record a successful fix application
150    pub fn record_fix_applied(&mut self, error: &RustcError) {
151        self.metrics.record_fix_applied(&error.code);
152    }
153
154    /// Record a verified fix (compiled successfully)
155    pub fn record_fix_verified(&mut self, error: &RustcError) {
156        self.metrics.record_fix_verified(&error.code);
157    }
158
159    /// Get current metrics
160    pub fn metrics(&self) -> &OracleMetrics {
161        &self.metrics
162    }
163
164    /// Get configuration
165    pub fn config(&self) -> &OracleConfig {
166        &self.config
167    }
168
169    /// Import patterns from another .apr file (cross-project transfer)
170    ///
171    /// Uses the smart import filter to verify fix strategies are applicable
172    /// to C→Rust context (not just Python→Rust patterns).
173    #[cfg(feature = "citl")]
174    pub fn import_patterns(&mut self, path: &std::path::Path) -> Result<usize, OracleError> {
175        self.import_patterns_with_config(path, crate::import::SmartImportConfig::default())
176    }
177
178    /// Import patterns with custom configuration
179    #[cfg(feature = "citl")]
180    pub fn import_patterns_with_config(
181        &mut self,
182        path: &std::path::Path,
183        config: crate::import::SmartImportConfig,
184    ) -> Result<usize, OracleError> {
185        use crate::import::{smart_import_filter, ImportStats};
186
187        let other_store = DecisionPatternStore::load_apr(path)
188            .map_err(|e| OracleError::PatternStoreError(e.to_string()))?;
189
190        // Transferable error codes (ownership/lifetime)
191        let transferable = ["E0382", "E0499", "E0506", "E0597", "E0515"];
192
193        let store = self.store.get_or_insert_with(|| {
194            DecisionPatternStore::new().expect("Failed to create pattern store")
195        });
196
197        let mut count = 0;
198        let mut stats = ImportStats::new();
199
200        for code in &transferable {
201            let patterns = other_store.patterns_for_error(code);
202            for pattern in patterns {
203                // Apply smart import filter
204                let strategy = crate::import::analyze_fix_strategy(&pattern.fix_diff);
205                let decision = smart_import_filter(&pattern.fix_diff, &pattern.metadata, &config);
206
207                stats.record(strategy, &decision);
208
209                if decision.allows_import() && store.index_fix(pattern.clone()).is_ok() {
210                    count += 1;
211                }
212            }
213        }
214
215        // Log import statistics
216        if stats.total_evaluated > 0 {
217            tracing::info!(
218                "Import stats: {}/{} patterns accepted ({:.1}%)",
219                count,
220                stats.total_evaluated,
221                stats.overall_acceptance_rate() * 100.0
222            );
223        }
224
225        Ok(count)
226    }
227
228    /// Import patterns with statistics tracking
229    #[cfg(feature = "citl")]
230    pub fn import_patterns_with_stats(
231        &mut self,
232        path: &std::path::Path,
233        config: crate::import::SmartImportConfig,
234    ) -> Result<(usize, crate::import::ImportStats), OracleError> {
235        use crate::import::{smart_import_filter, ImportStats};
236
237        let other_store = DecisionPatternStore::load_apr(path)
238            .map_err(|e| OracleError::PatternStoreError(e.to_string()))?;
239
240        let transferable = ["E0382", "E0499", "E0506", "E0597", "E0515"];
241
242        let store = self.store.get_or_insert_with(|| {
243            DecisionPatternStore::new().expect("Failed to create pattern store")
244        });
245
246        let mut count = 0;
247        let mut stats = ImportStats::new();
248
249        for code in &transferable {
250            let patterns = other_store.patterns_for_error(code);
251            for pattern in patterns {
252                let strategy = crate::import::analyze_fix_strategy(&pattern.fix_diff);
253                let decision = smart_import_filter(&pattern.fix_diff, &pattern.metadata, &config);
254
255                stats.record(strategy, &decision);
256
257                if decision.allows_import() && store.index_fix(pattern.clone()).is_ok() {
258                    count += 1;
259                }
260            }
261        }
262
263        Ok((count, stats))
264    }
265
266    /// Save patterns to .apr file
267    #[cfg(feature = "citl")]
268    pub fn save(&self) -> Result<(), OracleError> {
269        if let Some(ref store) = self.store {
270            store.save_apr(&self.config.patterns_path).map_err(|e| OracleError::SaveError {
271                path: self.config.patterns_path.display().to_string(),
272                source: std::io::Error::new(std::io::ErrorKind::Other, e.to_string()),
273            })?;
274        }
275        Ok(())
276    }
277
278    /// Bootstrap the oracle with seed patterns for cold start
279    ///
280    /// This loads predefined patterns for common C→Rust transpilation errors,
281    /// solving the cold start problem where the oracle has no patterns to learn from.
282    ///
283    /// # Toyota Way Principles
284    ///
285    /// - **Genchi Genbutsu**: Patterns derived from real C→Rust errors
286    /// - **Yokoten**: Cross-project pattern sharing
287    /// - **Jidoka**: Automated quality built-in
288    #[cfg(feature = "citl")]
289    pub fn bootstrap(&mut self) -> Result<usize, OracleError> {
290        use crate::bootstrap::seed_pattern_store;
291
292        let store = self.store.get_or_insert_with(|| {
293            DecisionPatternStore::new().expect("Failed to create pattern store")
294        });
295
296        seed_pattern_store(store)
297    }
298
299    /// Check if bootstrap patterns are needed
300    ///
301    /// Returns true if the oracle has no patterns or very few patterns,
302    /// indicating that bootstrapping would be beneficial.
303    pub fn needs_bootstrap(&self) -> bool {
304        self.pattern_count() < 10
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::CConstruct;
    use crate::decisions::CDecisionCategory;

    /// Pattern-file path that is expected not to exist on any test machine.
    const MISSING_PATTERNS: &str = "/nonexistent/decy_oracle_test_patterns.apr";

    /// Config whose pattern file does not exist, so no test depends on whatever
    /// happens to live at the default `patterns_path`.
    fn config_without_patterns() -> OracleConfig {
        OracleConfig {
            patterns_path: std::path::PathBuf::from(MISSING_PATTERNS),
            ..Default::default()
        }
    }

    /// Build an oracle, failing the test with a clear message if construction fails.
    fn oracle_from(config: OracleConfig) -> DecyOracle {
        DecyOracle::new(config).expect("oracle construction without a pattern file must succeed")
    }

    /// Oracle with no pattern store loaded.
    fn empty_oracle() -> DecyOracle {
        oracle_from(config_without_patterns())
    }

    /// Raw-pointer ownership context used by the `suggest_fix` tests.
    fn pointer_context(is_const: bool, pointee: &str) -> CDecisionContext {
        CDecisionContext::new(
            CConstruct::RawPointer { is_const, pointee: pointee.into() },
            CDecisionCategory::PointerOwnership,
        )
    }

    #[test]
    fn test_oracle_creation_no_patterns() {
        let oracle = empty_oracle();
        assert!(!oracle.has_patterns()); // No patterns file exists
    }

    #[test]
    fn test_oracle_pattern_count_empty() {
        let oracle = empty_oracle();
        assert_eq!(oracle.pattern_count(), 0);
    }

    #[test]
    fn test_oracle_config_access() {
        let config = OracleConfig { confidence_threshold: 0.9, ..config_without_patterns() };
        let oracle = oracle_from(config);
        assert!((oracle.config().confidence_threshold - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_rustc_error() {
        let error = RustcError::new("E0382", "borrow of moved value").with_location("test.rs", 42);
        assert_eq!(error.code, "E0382");
        assert_eq!(error.line, Some(42));
    }

    #[test]
    fn test_rustc_error_without_location() {
        let error = RustcError::new("E0499", "cannot borrow as mutable more than once");
        assert_eq!(error.code, "E0499");
        assert_eq!(error.message, "cannot borrow as mutable more than once");
        assert!(error.file.is_none());
        assert!(error.line.is_none());
    }

    #[test]
    fn test_rustc_error_chained_builder() {
        let error = RustcError::new("E0506", "cannot assign").with_location("src/main.rs", 100);
        assert_eq!(error.code, "E0506");
        assert_eq!(error.file, Some("src/main.rs".into()));
        assert_eq!(error.line, Some(100));
    }

    #[test]
    fn test_metrics_recorded() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E0382", "test");
        let context = pointer_context(false, "int");

        // No patterns, should be a miss
        let _ = oracle.suggest_fix(&error, &context);
        assert_eq!(oracle.metrics().misses, 1);
    }

    #[test]
    fn test_record_miss() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E0597", "borrowed value does not live long enough");
        oracle.record_miss(&error);
        assert_eq!(oracle.metrics().misses, 1);
        assert_eq!(oracle.metrics().queries, 1);
    }

    #[test]
    fn test_record_fix_applied() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E0382", "use of moved value");
        oracle.record_fix_applied(&error);
        assert_eq!(oracle.metrics().fixes_applied, 1);
    }

    #[test]
    fn test_record_fix_verified() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E0515", "cannot return reference to local");
        oracle.record_fix_verified(&error);
        assert_eq!(oracle.metrics().fixes_verified, 1);
    }

    #[test]
    fn test_multiple_error_codes_tracked() {
        let mut oracle = empty_oracle();

        oracle.record_miss(&RustcError::new("E0382", "test"));
        oracle.record_miss(&RustcError::new("E0499", "test"));
        oracle.record_miss(&RustcError::new("E0382", "test"));

        let metrics = oracle.metrics();
        assert_eq!(metrics.misses, 3);
        assert_eq!(metrics.by_error_code.get("E0382").unwrap().queries, 2);
        assert_eq!(metrics.by_error_code.get("E0499").unwrap().queries, 1);
    }

    // ============================================================================
    // NEEDS_BOOTSTRAP TESTS
    // ============================================================================

    #[test]
    fn test_needs_bootstrap_when_empty() {
        let oracle = empty_oracle();
        assert!(oracle.needs_bootstrap()); // 0 patterns < 10
    }

    #[test]
    fn test_needs_bootstrap_threshold() {
        let oracle = empty_oracle();
        // pattern_count() is 0, needs_bootstrap checks < 10
        assert_eq!(oracle.pattern_count(), 0);
        assert!(oracle.needs_bootstrap());
    }

    // ============================================================================
    // RUSTC ERROR BUILDER TESTS
    // ============================================================================

    #[test]
    fn test_rustc_error_new_with_empty_strings() {
        let error = RustcError::new("", "");
        assert_eq!(error.code, "");
        assert_eq!(error.message, "");
    }

    #[test]
    fn test_rustc_error_new_with_string_slices() {
        let code: &str = "E0382";
        let msg: &str = "use of moved value";
        let error = RustcError::new(code, msg);
        assert_eq!(error.code, "E0382");
        assert_eq!(error.message, "use of moved value");
    }

    #[test]
    fn test_rustc_error_new_with_string_type() {
        let code = String::from("E0499");
        let msg = String::from("cannot borrow");
        let error = RustcError::new(code, msg);
        assert_eq!(error.code, "E0499");
    }

    #[test]
    fn test_rustc_error_with_location_zero_line() {
        let error = RustcError::new("E0382", "test").with_location("test.rs", 0);
        assert_eq!(error.line, Some(0));
    }

    #[test]
    fn test_rustc_error_with_location_large_line() {
        let error = RustcError::new("E0382", "test").with_location("test.rs", usize::MAX);
        assert_eq!(error.line, Some(usize::MAX));
    }

    #[test]
    fn test_rustc_error_with_location_empty_file() {
        let error = RustcError::new("E0382", "test").with_location("", 10);
        assert_eq!(error.file, Some("".into()));
    }

    #[test]
    fn test_rustc_error_clone() {
        let error = RustcError::new("E0382", "borrow of moved value").with_location("test.rs", 42);
        let cloned = error.clone();
        assert_eq!(cloned.code, error.code);
        assert_eq!(cloned.message, error.message);
        assert_eq!(cloned.file, error.file);
        assert_eq!(cloned.line, error.line);
    }

    #[test]
    fn test_rustc_error_debug() {
        let error = RustcError::new("E0382", "test");
        let debug_str = format!("{:?}", error);
        assert!(debug_str.contains("RustcError"));
        assert!(debug_str.contains("E0382"));
    }

    // ============================================================================
    // ORACLE HAS_PATTERNS TESTS
    // ============================================================================

    #[test]
    fn test_has_patterns_false_when_no_file() {
        let oracle = empty_oracle();
        assert!(!oracle.has_patterns());
    }

    // ============================================================================
    // ORACLE PATTERN_COUNT TESTS
    // ============================================================================

    #[test]
    fn test_pattern_count_zero_when_no_file() {
        let oracle = empty_oracle();
        assert_eq!(oracle.pattern_count(), 0);
    }

    // ============================================================================
    // METRICS TRACKING TESTS
    // ============================================================================

    #[test]
    fn test_metrics_initial_state() {
        let oracle = empty_oracle();
        let metrics = oracle.metrics();
        assert_eq!(metrics.queries, 0);
        assert_eq!(metrics.hits, 0);
        assert_eq!(metrics.misses, 0);
    }

    #[test]
    fn test_record_miss_increments_queries() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E0382", "test");
        oracle.record_miss(&error);

        assert_eq!(oracle.metrics().queries, 1);
    }

    #[test]
    fn test_record_fix_applied_multiple() {
        let mut oracle = empty_oracle();

        let error1 = RustcError::new("E0382", "test1");
        let error2 = RustcError::new("E0499", "test2");

        oracle.record_fix_applied(&error1);
        oracle.record_fix_applied(&error2);
        oracle.record_fix_applied(&error1);

        assert_eq!(oracle.metrics().fixes_applied, 3);
    }

    #[test]
    fn test_record_fix_verified_multiple() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E0382", "test");

        oracle.record_fix_verified(&error);
        oracle.record_fix_verified(&error);

        assert_eq!(oracle.metrics().fixes_verified, 2);
    }

    #[test]
    fn test_metrics_by_error_code_new_code() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E9999", "custom error");
        oracle.record_miss(&error);

        let metrics = oracle.metrics();
        assert!(metrics.by_error_code.contains_key("E9999"));
    }

    // ============================================================================
    // CONFIG ACCESS TESTS
    // ============================================================================

    #[test]
    fn test_config_returns_original_config() {
        let config = OracleConfig {
            confidence_threshold: 0.95,
            max_suggestions: 20,
            auto_fix: true,
            max_retries: 10,
            ..config_without_patterns()
        };
        let oracle = oracle_from(config);

        assert!((oracle.config().confidence_threshold - 0.95).abs() < f32::EPSILON);
        assert_eq!(oracle.config().max_suggestions, 20);
        assert!(oracle.config().auto_fix);
        assert_eq!(oracle.config().max_retries, 10);
    }

    // ============================================================================
    // SUGGEST_FIX TESTS (NO PATTERNS LOADED; RUN WITH OR WITHOUT THE CITL FEATURE)
    // ============================================================================

    #[test]
    fn test_suggest_fix_records_miss_when_no_patterns() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E0382", "borrow of moved value");
        let context = pointer_context(false, "int");

        let result = oracle.suggest_fix(&error, &context);
        assert!(result.is_none());
        assert_eq!(oracle.metrics().misses, 1);
    }

    #[test]
    fn test_suggest_fix_increments_queries() {
        let mut oracle = empty_oracle();

        let error = RustcError::new("E0499", "cannot borrow");
        let context = pointer_context(true, "char");

        let _ = oracle.suggest_fix(&error, &context);
        // With no store loaded the call records exactly one miss, and a miss
        // counts as one query.
        assert_eq!(oracle.metrics().queries, 1);
    }

    // ============================================================================
    // ORACLE CREATION WITH VARIOUS CONFIGS
    // ============================================================================

    #[test]
    fn test_oracle_creation_with_custom_threshold() {
        let config = OracleConfig { confidence_threshold: 0.5, ..config_without_patterns() };
        let oracle = oracle_from(config);
        assert!((oracle.config().confidence_threshold - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_oracle_creation_with_max_suggestions() {
        let config = OracleConfig { max_suggestions: 100, ..config_without_patterns() };
        let oracle = oracle_from(config);
        assert_eq!(oracle.config().max_suggestions, 100);
    }

    #[test]
    fn test_oracle_creation_with_auto_fix_enabled() {
        let config = OracleConfig { auto_fix: true, ..config_without_patterns() };
        let oracle = oracle_from(config);
        assert!(oracle.config().auto_fix);
    }
}