use crate::config::OracleConfig;
use crate::context::CDecisionContext;
use crate::error::OracleError;
use crate::metrics::OracleMetrics;
#[cfg(feature = "citl")]
use entrenar::citl::{DecisionPatternStore, FixSuggestion as EntrenarFixSuggestion};
/// Fix suggestion returned by [`DecyOracle::suggest_fix`] when the `citl`
/// feature is enabled; an alias for `entrenar::citl::FixSuggestion`.
#[cfg(feature = "citl")]
pub type FixSuggestion = EntrenarFixSuggestion;
/// A single diagnostic reported by `rustc`.
///
/// Equality and hashing cover every field, so two errors compare equal only
/// when code, message and location all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RustcError {
    /// Error code such as `"E0382"`; used as the key for metrics and pattern lookups.
    pub code: String,
    /// Human-readable diagnostic message.
    pub message: String,
    /// Source file the diagnostic points at, if known.
    pub file: Option<String>,
    /// Line number within `file`, if known.
    pub line: Option<usize>,
}
impl RustcError {
    /// Builds an error from a code and a message, with no source location attached.
    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
        let (code, message) = (code.into(), message.into());
        Self {
            code,
            message,
            file: None,
            line: None,
        }
    }

    /// Returns this error with its source location set to `file`:`line`,
    /// overwriting any location that was already present.
    pub fn with_location(self, file: impl Into<String>, line: usize) -> Self {
        Self {
            file: Some(file.into()),
            line: Some(line),
            ..self
        }
    }
}
/// Suggests fixes for rustc errors from a pattern store and keeps per-error-code
/// query metrics.
///
/// The pattern store field only exists when the `citl` feature is enabled.
pub struct DecyOracle {
    /// Query, hit/miss and fix counters.
    metrics: OracleMetrics,
    /// Settings supplied at construction (pattern path, thresholds, limits).
    config: OracleConfig,
    /// Loaded or created pattern store; `None` until one is loaded or created.
    #[cfg(feature = "citl")]
    store: Option<DecisionPatternStore>,
}
impl DecyOracle {
    /// Pattern count below which [`Self::needs_bootstrap`] reports `true`.
    const BOOTSTRAP_THRESHOLD: usize = 10;

    /// Creates an oracle from `config`.
    ///
    /// With the `citl` feature enabled, the pattern store at
    /// `config.patterns_path` is loaded if that path exists; otherwise the
    /// oracle starts without a store.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::PatternStoreError`] if the pattern file exists but
    /// cannot be loaded (`citl` feature only).
    pub fn new(config: OracleConfig) -> Result<Self, OracleError> {
        #[cfg(feature = "citl")]
        let store = if config.patterns_path.exists() {
            Some(
                DecisionPatternStore::load_apr(&config.patterns_path)
                    .map_err(|e| OracleError::PatternStoreError(e.to_string()))?,
            )
        } else {
            None
        };
        Ok(Self {
            config,
            #[cfg(feature = "citl")]
            store,
            metrics: OracleMetrics::default(),
        })
    }

    /// Returns `true` if a pattern store is present (always `false` without `citl`).
    pub fn has_patterns(&self) -> bool {
        #[cfg(feature = "citl")]
        {
            self.store.is_some()
        }
        #[cfg(not(feature = "citl"))]
        {
            false
        }
    }

    /// Number of patterns in the store, or `0` when there is no store
    /// (always `0` without `citl`).
    pub fn pattern_count(&self) -> usize {
        #[cfg(feature = "citl")]
        {
            self.store.as_ref().map_or(0, |s| s.len())
        }
        #[cfg(not(feature = "citl"))]
        {
            0
        }
    }

    /// Looks up a fix for `error` in the pattern store.
    ///
    /// The store is queried with the error code and `context`'s context
    /// strings, asking for at most `config.max_suggestions` candidates. The
    /// first candidate whose `weighted_score()` reaches
    /// `config.confidence_threshold` is returned and recorded as a hit. A miss
    /// is recorded when there is no store, when the store query fails, or when
    /// no candidate clears the threshold.
    #[cfg(feature = "citl")]
    pub fn suggest_fix(
        &mut self,
        error: &RustcError,
        context: &CDecisionContext,
    ) -> Option<FixSuggestion> {
        let store = match self.store.as_ref() {
            Some(store) => store,
            None => {
                self.metrics.record_miss(&error.code);
                return None;
            }
        };

        let context_strings = context.to_context_strings();
        let suggestions =
            match store.suggest_fix(&error.code, &context_strings, self.config.max_suggestions) {
                Ok(suggestions) => suggestions,
                Err(e) => {
                    // A failed lookup is still just a miss for the caller, but keep
                    // the cause observable instead of discarding it silently.
                    tracing::debug!("pattern store lookup for {} failed: {:?}", error.code, e);
                    self.metrics.record_miss(&error.code);
                    return None;
                }
            };

        // NOTE(review): takes the first qualifying candidate, which assumes the
        // store returns suggestions best-first — confirm against entrenar.
        match suggestions
            .into_iter()
            .find(|s| s.weighted_score() >= self.config.confidence_threshold)
        {
            Some(best) => {
                self.metrics.record_hit(&error.code);
                Some(best)
            }
            None => {
                self.metrics.record_miss(&error.code);
                None
            }
        }
    }

    /// Without the `citl` feature there is no pattern store: every query is
    /// recorded as a miss and `None` is returned.
    #[cfg(not(feature = "citl"))]
    pub fn suggest_fix(&mut self, error: &RustcError, _context: &CDecisionContext) -> Option<()> {
        self.metrics.record_miss(&error.code);
        None
    }

    /// Records a miss for `error`'s code in the metrics.
    pub fn record_miss(&mut self, error: &RustcError) {
        self.metrics.record_miss(&error.code);
    }

    /// Records that a fix was applied for `error`'s code.
    pub fn record_fix_applied(&mut self, error: &RustcError) {
        self.metrics.record_fix_applied(&error.code);
    }

    /// Records that an applied fix for `error`'s code was verified.
    pub fn record_fix_verified(&mut self, error: &RustcError) {
        self.metrics.record_fix_verified(&error.code);
    }

    /// Metrics accumulated by this oracle.
    pub fn metrics(&self) -> &OracleMetrics {
        &self.metrics
    }

    /// Configuration this oracle was created with.
    pub fn config(&self) -> &OracleConfig {
        &self.config
    }

    /// Returns the pattern store, creating an empty one first if none exists.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::PatternStoreError`] if a new store cannot be
    /// created (previously this panicked).
    #[cfg(feature = "citl")]
    fn store_or_init(&mut self) -> Result<&mut DecisionPatternStore, OracleError> {
        // `ref mut` binds only in the `Some` arm, so the `None` arm is free to
        // mutate `self.store`.
        match self.store {
            Some(ref mut store) => Ok(store),
            None => {
                let store = DecisionPatternStore::new()
                    .map_err(|e| OracleError::PatternStoreError(e.to_string()))?;
                Ok(self.store.insert(store))
            }
        }
    }

    /// Imports transferable patterns from the `.apr` file at `path` using the
    /// default smart-import configuration.
    ///
    /// Returns the number of patterns added.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::PatternStoreError`] if `path` cannot be loaded or
    /// a store cannot be created.
    #[cfg(feature = "citl")]
    pub fn import_patterns(&mut self, path: &std::path::Path) -> Result<usize, OracleError> {
        self.import_patterns_with_config(path, crate::import::SmartImportConfig::default())
    }

    /// Like [`Self::import_patterns_with_stats`], but returns only the number
    /// of imported patterns.
    ///
    /// It also logs an acceptance summary when any pattern was evaluated.
    ///
    /// # Errors
    ///
    /// Same as [`Self::import_patterns_with_stats`].
    #[cfg(feature = "citl")]
    pub fn import_patterns_with_config(
        &mut self,
        path: &std::path::Path,
        config: crate::import::SmartImportConfig,
    ) -> Result<usize, OracleError> {
        let (count, stats) = self.import_patterns_with_stats(path, config)?;
        if stats.total_evaluated > 0 {
            tracing::info!(
                "Import stats: {}/{} patterns accepted ({:.1}%)",
                count,
                stats.total_evaluated,
                stats.overall_acceptance_rate() * 100.0
            );
        }
        Ok(count)
    }

    /// Imports patterns for a fixed set of transferable error codes from the
    /// `.apr` file at `path` into this oracle's store.
    ///
    /// Each candidate pattern is passed through `smart_import_filter` with
    /// `config`. A pattern counts as imported only if the filter allows it and
    /// indexing succeeds.
    ///
    /// Returns `(imported_count, stats)`, where `stats` records every evaluated
    /// pattern.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::PatternStoreError`] if `path` cannot be loaded or
    /// a store cannot be created.
    #[cfg(feature = "citl")]
    pub fn import_patterns_with_stats(
        &mut self,
        path: &std::path::Path,
        config: crate::import::SmartImportConfig,
    ) -> Result<(usize, crate::import::ImportStats), OracleError> {
        use crate::import::{analyze_fix_strategy, smart_import_filter, ImportStats};

        /// Error codes whose fix patterns are imported from other stores.
        const TRANSFERABLE_ERROR_CODES: [&str; 5] = ["E0382", "E0499", "E0506", "E0597", "E0515"];

        // Load the source first so an unreadable file fails before any store is created.
        let other_store = DecisionPatternStore::load_apr(path)
            .map_err(|e| OracleError::PatternStoreError(e.to_string()))?;
        let store = self.store_or_init()?;

        let mut count = 0;
        let mut stats = ImportStats::new();
        for code in &TRANSFERABLE_ERROR_CODES {
            for pattern in other_store.patterns_for_error(code) {
                let strategy = analyze_fix_strategy(&pattern.fix_diff);
                let decision = smart_import_filter(&pattern.fix_diff, &pattern.metadata, &config);
                stats.record(strategy, &decision);
                if decision.allows_import() && store.index_fix(pattern.clone()).is_ok() {
                    count += 1;
                }
            }
        }
        Ok((count, stats))
    }

    /// Writes the pattern store to `config.patterns_path`.
    ///
    /// Does nothing when there is no store.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::SaveError`] if writing the store fails.
    #[cfg(feature = "citl")]
    pub fn save(&self) -> Result<(), OracleError> {
        if let Some(store) = &self.store {
            store
                .save_apr(&self.config.patterns_path)
                .map_err(|e| OracleError::SaveError {
                    path: self.config.patterns_path.display().to_string(),
                    source: std::io::Error::new(std::io::ErrorKind::Other, e.to_string()),
                })?;
        }
        Ok(())
    }

    /// Seeds the pattern store (creating it if needed) via
    /// `crate::bootstrap::seed_pattern_store` and returns that function's result.
    ///
    /// # Errors
    ///
    /// Returns [`OracleError::PatternStoreError`] if a store cannot be created,
    /// or any error from `seed_pattern_store`.
    #[cfg(feature = "citl")]
    pub fn bootstrap(&mut self) -> Result<usize, OracleError> {
        use crate::bootstrap::seed_pattern_store;
        let store = self.store_or_init()?;
        seed_pattern_store(store)
    }

    /// Returns `true` while the store holds fewer than
    /// [`Self::BOOTSTRAP_THRESHOLD`] patterns.
    pub fn needs_bootstrap(&self) -> bool {
        self.pattern_count() < Self::BOOTSTRAP_THRESHOLD
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::CConstruct;
    use crate::decisions::CDecisionCategory;

    /// A pattern path that should never exist, so `DecyOracle::new` always
    /// starts without a pattern store regardless of the machine's state.
    const MISSING_PATTERNS_PATH: &str = "/nonexistent/decy_oracle_test_patterns.apr";

    /// Default config pointed at a missing pattern file.
    ///
    /// Using this instead of `OracleConfig::default()` keeps hit/miss
    /// assertions independent of whether a real pattern file exists at the
    /// default location.
    fn hermetic_config() -> OracleConfig {
        OracleConfig {
            patterns_path: std::path::PathBuf::from(MISSING_PATTERNS_PATH),
            ..Default::default()
        }
    }

    /// Oracle built from [`hermetic_config`].
    fn hermetic_oracle() -> DecyOracle {
        DecyOracle::new(hermetic_config()).expect("oracle without a pattern file must build")
    }

    /// Raw-pointer ownership context shared by the `suggest_fix` tests.
    fn raw_pointer_context(is_const: bool, pointee: &'static str) -> CDecisionContext {
        CDecisionContext::new(
            CConstruct::RawPointer {
                is_const,
                pointee: pointee.into(),
            },
            CDecisionCategory::PointerOwnership,
        )
    }

    #[test]
    fn test_oracle_creation_no_patterns() {
        let oracle = hermetic_oracle();
        assert!(!oracle.has_patterns());
    }

    #[test]
    fn test_oracle_pattern_count_empty() {
        let oracle = hermetic_oracle();
        assert_eq!(oracle.pattern_count(), 0);
    }

    #[test]
    fn test_oracle_config_access() {
        let config = OracleConfig {
            confidence_threshold: 0.9,
            ..hermetic_config()
        };
        let oracle = DecyOracle::new(config).unwrap();
        assert!((oracle.config().confidence_threshold - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_rustc_error() {
        let error = RustcError::new("E0382", "borrow of moved value").with_location("test.rs", 42);
        assert_eq!(error.code, "E0382");
        assert_eq!(error.line, Some(42));
    }

    #[test]
    fn test_rustc_error_without_location() {
        let error = RustcError::new("E0499", "cannot borrow as mutable more than once");
        assert_eq!(error.code, "E0499");
        assert_eq!(error.message, "cannot borrow as mutable more than once");
        assert!(error.file.is_none());
        assert!(error.line.is_none());
    }

    #[test]
    fn test_rustc_error_chained_builder() {
        let error = RustcError::new("E0506", "cannot assign").with_location("src/main.rs", 100);
        assert_eq!(error.code, "E0506");
        assert_eq!(error.file, Some("src/main.rs".into()));
        assert_eq!(error.line, Some(100));
    }

    #[test]
    fn test_metrics_recorded() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E0382", "test");
        let context = raw_pointer_context(false, "int");
        let _ = oracle.suggest_fix(&error, &context);
        assert_eq!(oracle.metrics().misses, 1);
    }

    #[test]
    fn test_record_miss() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E0597", "borrowed value does not live long enough");
        oracle.record_miss(&error);
        assert_eq!(oracle.metrics().misses, 1);
        assert_eq!(oracle.metrics().queries, 1);
    }

    #[test]
    fn test_record_fix_applied() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E0382", "use of moved value");
        oracle.record_fix_applied(&error);
        assert_eq!(oracle.metrics().fixes_applied, 1);
    }

    #[test]
    fn test_record_fix_verified() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E0515", "cannot return reference to local");
        oracle.record_fix_verified(&error);
        assert_eq!(oracle.metrics().fixes_verified, 1);
    }

    #[test]
    fn test_multiple_error_codes_tracked() {
        let mut oracle = hermetic_oracle();
        oracle.record_miss(&RustcError::new("E0382", "test"));
        oracle.record_miss(&RustcError::new("E0499", "test"));
        oracle.record_miss(&RustcError::new("E0382", "test"));
        let metrics = oracle.metrics();
        assert_eq!(metrics.misses, 3);
        assert_eq!(metrics.by_error_code.get("E0382").unwrap().queries, 2);
        assert_eq!(metrics.by_error_code.get("E0499").unwrap().queries, 1);
    }

    #[test]
    fn test_needs_bootstrap_when_empty() {
        let oracle = hermetic_oracle();
        assert!(oracle.needs_bootstrap());
    }

    #[test]
    fn test_needs_bootstrap_threshold() {
        let oracle = hermetic_oracle();
        // An oracle without a store is below the 10-pattern bootstrap threshold.
        assert!(oracle.pattern_count() < 10);
        assert!(oracle.needs_bootstrap());
    }

    #[test]
    fn test_rustc_error_new_with_empty_strings() {
        let error = RustcError::new("", "");
        assert_eq!(error.code, "");
        assert_eq!(error.message, "");
    }

    #[test]
    fn test_rustc_error_new_with_string_slices() {
        let code: &str = "E0382";
        let msg: &str = "use of moved value";
        let error = RustcError::new(code, msg);
        assert_eq!(error.code, "E0382");
        assert_eq!(error.message, "use of moved value");
    }

    #[test]
    fn test_rustc_error_new_with_string_type() {
        let code = String::from("E0499");
        let msg = String::from("cannot borrow");
        let error = RustcError::new(code, msg);
        assert_eq!(error.code, "E0499");
    }

    #[test]
    fn test_rustc_error_with_location_zero_line() {
        let error = RustcError::new("E0382", "test").with_location("test.rs", 0);
        assert_eq!(error.line, Some(0));
    }

    #[test]
    fn test_rustc_error_with_location_large_line() {
        let error = RustcError::new("E0382", "test").with_location("test.rs", usize::MAX);
        assert_eq!(error.line, Some(usize::MAX));
    }

    #[test]
    fn test_rustc_error_with_location_empty_file() {
        let error = RustcError::new("E0382", "test").with_location("", 10);
        assert_eq!(error.file, Some("".into()));
    }

    #[test]
    fn test_rustc_error_clone() {
        let error = RustcError::new("E0382", "borrow of moved value").with_location("test.rs", 42);
        let cloned = error.clone();
        assert_eq!(cloned.code, error.code);
        assert_eq!(cloned.message, error.message);
        assert_eq!(cloned.file, error.file);
        assert_eq!(cloned.line, error.line);
    }

    #[test]
    fn test_rustc_error_debug() {
        let error = RustcError::new("E0382", "test");
        let debug_str = format!("{:?}", error);
        assert!(debug_str.contains("RustcError"));
        assert!(debug_str.contains("E0382"));
    }

    #[test]
    fn test_has_patterns_false_when_no_file() {
        let oracle = hermetic_oracle();
        assert!(!oracle.has_patterns());
    }

    #[test]
    fn test_pattern_count_zero_when_no_file() {
        let oracle = hermetic_oracle();
        assert_eq!(oracle.pattern_count(), 0);
    }

    #[test]
    fn test_metrics_initial_state() {
        let oracle = hermetic_oracle();
        let metrics = oracle.metrics();
        assert_eq!(metrics.queries, 0);
        assert_eq!(metrics.hits, 0);
        assert_eq!(metrics.misses, 0);
    }

    #[test]
    fn test_record_miss_increments_queries() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E0382", "test");
        oracle.record_miss(&error);
        assert_eq!(oracle.metrics().queries, 1);
    }

    #[test]
    fn test_record_fix_applied_multiple() {
        let mut oracle = hermetic_oracle();
        let error1 = RustcError::new("E0382", "test1");
        let error2 = RustcError::new("E0499", "test2");
        oracle.record_fix_applied(&error1);
        oracle.record_fix_applied(&error2);
        oracle.record_fix_applied(&error1);
        assert_eq!(oracle.metrics().fixes_applied, 3);
    }

    #[test]
    fn test_record_fix_verified_multiple() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E0382", "test");
        oracle.record_fix_verified(&error);
        oracle.record_fix_verified(&error);
        assert_eq!(oracle.metrics().fixes_verified, 2);
    }

    #[test]
    fn test_metrics_by_error_code_new_code() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E9999", "custom error");
        oracle.record_miss(&error);
        let metrics = oracle.metrics();
        assert!(metrics.by_error_code.contains_key("E9999"));
    }

    #[test]
    fn test_config_returns_original_config() {
        let config = OracleConfig {
            confidence_threshold: 0.95,
            max_suggestions: 20,
            auto_fix: true,
            max_retries: 10,
            ..hermetic_config()
        };
        let oracle = DecyOracle::new(config).unwrap();
        assert!((oracle.config().confidence_threshold - 0.95).abs() < f32::EPSILON);
        assert_eq!(oracle.config().max_suggestions, 20);
        assert!(oracle.config().auto_fix);
        assert_eq!(oracle.config().max_retries, 10);
    }

    #[test]
    fn test_suggest_fix_records_miss_when_no_patterns() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E0382", "borrow of moved value");
        let context = raw_pointer_context(false, "int");
        let result = oracle.suggest_fix(&error, &context);
        assert!(result.is_none());
        assert_eq!(oracle.metrics().misses, 1);
    }

    #[test]
    fn test_suggest_fix_increments_queries() {
        let mut oracle = hermetic_oracle();
        let error = RustcError::new("E0499", "cannot borrow");
        let context = raw_pointer_context(true, "char");
        let _ = oracle.suggest_fix(&error, &context);
        // Without a store the query is recorded exactly once, as a miss.
        assert_eq!(oracle.metrics().queries, 1);
    }

    #[test]
    fn test_oracle_creation_with_custom_threshold() {
        let config = OracleConfig {
            confidence_threshold: 0.5,
            ..hermetic_config()
        };
        let oracle = DecyOracle::new(config).unwrap();
        assert!((oracle.config().confidence_threshold - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_oracle_creation_with_max_suggestions() {
        let config = OracleConfig {
            max_suggestions: 100,
            ..hermetic_config()
        };
        let oracle = DecyOracle::new(config).unwrap();
        assert_eq!(oracle.config().max_suggestions, 100);
    }

    #[test]
    fn test_oracle_creation_with_auto_fix_enabled() {
        let config = OracleConfig {
            auto_fix: true,
            ..hermetic_config()
        };
        let oracle = DecyOracle::new(config).unwrap();
        assert!(oracle.config().auto_fix);
    }
}