use crate::{AutoFixer, OracleError};
use aprender::citl::{
CompilationMode, CompilationResult, CompileOptions, CompilerDiagnostic, CompilerInterface,
ErrorEmbedding, ErrorEncoder, FixTemplate, MetricsTracker, PatternLibrary, PatternMatch,
RustCompiler,
};
use std::collections::HashMap;
use std::time::Instant;
/// Outcome of one iterative fixing run.
#[derive(Debug, Clone)]
pub struct IterativeFixResult {
    /// `true` when the run ended with source that compiled.
    pub success: bool,
    /// Source text as it stood when the run ended.
    pub fixed_source: String,
    /// Number of fix iterations that were executed.
    pub iterations: usize,
    /// Descriptions of the fixes that were applied, in order.
    pub fixes_applied: Vec<String>,
    /// Number of errors still reported when the run ended.
    pub remaining_errors: usize,
    /// Wall-clock duration of the run in milliseconds (0 until set).
    pub fix_duration_ms: u64,
}

impl IterativeFixResult {
    /// Shared constructor: every result starts with a zero duration that
    /// callers fill in later through [`IterativeFixResult::with_duration`].
    fn unmeasured(
        success: bool,
        fixed_source: String,
        iterations: usize,
        fixes_applied: Vec<String>,
        remaining_errors: usize,
    ) -> Self {
        Self {
            success,
            fixed_source,
            iterations,
            fixes_applied,
            remaining_errors,
            fix_duration_ms: 0,
        }
    }

    /// Builds a successful result: no remaining errors, duration 0.
    #[must_use]
    pub fn success(fixed_source: String, iterations: usize, fixes: Vec<String>) -> Self {
        Self::unmeasured(true, fixed_source, iterations, fixes, 0)
    }

    /// Builds a failed result with an empty fix list and `remaining`
    /// outstanding errors; duration 0.
    #[must_use]
    pub fn failure(source: String, iterations: usize, remaining: usize) -> Self {
        Self::unmeasured(false, source, iterations, Vec::new(), remaining)
    }

    /// Returns the same result with `fix_duration_ms` replaced by `ms`.
    #[must_use]
    pub fn with_duration(self, ms: u64) -> Self {
        Self {
            fix_duration_ms: ms,
            ..self
        }
    }
}
/// Compiler-in-the-loop fixer: compiles Rust source, looks up fixes for the
/// reported errors in a pattern library, and can fall back to `AutoFixer`.
pub struct CITLFixer {
/// Compiler used to check candidate sources.
compiler: RustCompiler,
/// Turns a compiler diagnostic plus the current source into an embedding.
encoder: ErrorEncoder,
/// Library searched for fix patterns; can be extended and saved.
pattern_library: PatternLibrary,
/// Records fix attempts, pattern usage and convergence.
metrics: MetricsTracker,
/// Upper bound on the number of fix iterations in `fix_all`.
max_iterations: usize,
/// Minimum `similarity` a pattern match needs before it is tried.
confidence_threshold: f32,
/// Fallback fixer; `None` when disabled in the config or when its
/// construction failed.
autofixer: Option<AutoFixer>,
}
impl CITLFixer {
/// Creates a fixer configured with `CITLFixerConfig::default()`.
///
/// # Errors
///
/// Returns whatever error `with_config` reports.
pub fn new() -> Result<Self, OracleError> {
    let config = CITLFixerConfig::default();
    Self::with_config(config)
}
pub fn with_config(config: CITLFixerConfig) -> Result<Self, OracleError> {
let compiler = RustCompiler::new().mode(config.compilation_mode);
let encoder = ErrorEncoder::new();
let pattern_library = if let Some(ref path) = config.pattern_library_path {
PatternLibrary::load(path).unwrap_or_else(|_| PatternLibrary::new())
} else {
PatternLibrary::new()
};
let autofixer = if config.use_autofixer_fallback {
AutoFixer::new().ok()
} else {
None
};
Ok(Self {
compiler,
encoder,
pattern_library,
metrics: MetricsTracker::new(),
max_iterations: config.max_iterations,
confidence_threshold: config.confidence_threshold,
autofixer,
})
}
/// Repeatedly compiles `source` and applies one fix per iteration until the
/// code compiles, no fix can be found, or `max_iterations` is reached.
///
/// Each iteration looks only at the first reported error. Pattern-library
/// matches whose `similarity` is at least `confidence_threshold` are tried
/// first, in the order the library returned them; if none applies, the
/// optional `AutoFixer` fallback is asked.
///
/// The returned result always carries the latest source text together with
/// the descriptions of every fix applied so far, including when the run
/// fails part-way. If the compiler invocation itself errors, the loop stops
/// and `remaining_errors` is taken from one more `count_errors` call.
pub fn fix_all(&mut self, source: &str) -> IterativeFixResult {
    let start = Instant::now();
    // Saturate rather than silently truncate the `u128` millisecond count;
    // after the `min` the cast is lossless.
    let elapsed_ms = move || start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;

    let mut current = source.to_string();
    let mut iterations = 0;
    let mut applied_fixes = Vec::new();

    // Nothing to do when the input already compiles: report 0 iterations.
    if self.compiles(&current) {
        return IterativeFixResult::success(current, 0, applied_fixes)
            .with_duration(elapsed_ms());
    }

    while iterations < self.max_iterations {
        iterations += 1;

        let errors = match self.compile(&current) {
            Ok(CompilationResult::Success { .. }) => {
                self.metrics.record_convergence(iterations, true);
                return IterativeFixResult::success(current, iterations, applied_fixes)
                    .with_duration(elapsed_ms());
            }
            Ok(CompilationResult::Failure { errors, .. }) => errors,
            // The compiler could not be run; give up on further iterations.
            Err(_) => break,
        };

        // A failure without diagnostics is treated as convergence.
        let error = match errors.first() {
            Some(error) => error,
            None => {
                self.metrics.record_convergence(iterations, true);
                return IterativeFixResult::success(current, iterations, applied_fixes)
                    .with_duration(elapsed_ms());
            }
        };
        let error_code = &error.code.code;

        let embedding = self.encoder.encode(error, &current);
        // 5 is the search limit passed to the library (presumably top-k).
        let matches = self.pattern_library.search(&embedding, 5);

        let mut fixed = false;
        for (idx, m) in matches.iter().enumerate() {
            if m.similarity >= self.confidence_threshold {
                if let Some(new_source) = self.apply_pattern_fix(&current, error, m) {
                    self.metrics.record_fix_attempt(true, error_code);
                    self.metrics.record_pattern_use(idx, true);
                    applied_fixes.push(m.pattern.fix_template.description.clone());
                    current = new_source;
                    fixed = true;
                    break;
                }
            }
        }

        // No pattern applied: ask the fallback fixer, if one is configured.
        if !fixed {
            if let Some(ref autofixer) = self.autofixer {
                let error_str = format!("error[{}]: {}", error.code.code, error.message);
                let fix_result = autofixer.fix(&current, &error_str);
                if fix_result.fixed {
                    self.metrics.record_fix_attempt(true, error_code);
                    applied_fixes.push(fix_result.description);
                    current = fix_result.source;
                    fixed = true;
                }
            }
        }

        if !fixed {
            self.metrics.record_fix_attempt(false, error_code);
            self.metrics.record_convergence(iterations, false);
            // Built by hand because `IterativeFixResult::failure` starts
            // with an empty fix list, which would drop the fixes from
            // earlier iterations that `current` already contains.
            return IterativeFixResult {
                success: false,
                fixed_source: current,
                iterations,
                fixes_applied: applied_fixes,
                remaining_errors: errors.len(),
                fix_duration_ms: elapsed_ms(),
            };
        }
    }

    // Iteration budget exhausted (or the compiler could not be run).
    let errors = self.count_errors(&current);
    self.metrics.record_convergence(iterations, false);
    IterativeFixResult {
        success: false,
        fixed_source: current,
        iterations,
        fixes_applied: applied_fixes,
        remaining_errors: errors,
        fix_duration_ms: elapsed_ms(),
    }
}
/// Returns `true` only when compiling `source` with default options yields
/// `CompilationResult::Success`; compile failures and compiler errors both
/// give `false`.
#[must_use]
pub fn compiles(&self, source: &str) -> bool {
    let outcome = self.compiler.compile(source, &CompileOptions::default());
    matches!(outcome, Ok(CompilationResult::Success { .. }))
}
/// Compiles `source` with `CompileOptions::default()` and hands back the
/// compiler's raw result.
fn compile(&self, source: &str) -> Result<CompilationResult, aprender::citl::CITLError> {
    let options = CompileOptions::default();
    self.compiler.compile(source, &options)
}
/// Number of errors reported when compiling `source`.
///
/// Any outcome other than `CompilationResult::Failure` counts as 0, which
/// includes the case where the compiler call itself returned an error.
fn count_errors(&self, source: &str) -> usize {
    if let Ok(CompilationResult::Failure { errors, .. }) = self.compile(source) {
        errors.len()
    } else {
        0
    }
}
/// Tries to produce a patched copy of `source` for `error`.
///
/// Two strategies are attempted in order:
/// 1. the first suggestion attached to the diagnostic whose span is usable
///    is applied verbatim;
/// 2. otherwise the matched pattern's fix template is instantiated (with
///    `expected_type` / `found_type` bindings when the diagnostic provides
///    them) and written over the error's own span.
///
/// Returns `None` when neither strategy yields a usable edit. Spans that
/// are out of bounds, inverted, or not aligned to UTF-8 character
/// boundaries are skipped instead of panicking.
fn apply_pattern_fix(
    &self,
    source: &str,
    error: &CompilerDiagnostic,
    pattern_match: &PatternMatch,
) -> Option<String> {
    /// Replaces `source[start..end]` with `replacement`.
    ///
    /// Returns `None` when `start` is at or past the end of `source`, or
    /// when `str::get` rejects the range (out of bounds, `start > end`, or
    /// an offset inside a multi-byte character). Those last two cases would
    /// make `String::replace_range` panic.
    fn splice(source: &str, start: usize, end: usize, replacement: &str) -> Option<String> {
        if start >= source.len() || source.get(start..end).is_none() {
            return None;
        }
        let mut patched = source.to_string();
        patched.replace_range(start..end, replacement);
        Some(patched)
    }

    // Strategy 1: compiler-provided suggestions, first usable one wins.
    for suggestion in &error.suggestions {
        let replacement = &suggestion.replacement;
        let span = &replacement.span;
        if let Some(patched) = splice(
            source,
            span.byte_start,
            span.byte_end,
            &replacement.replacement,
        ) {
            return Some(patched);
        }
    }

    // Strategy 2: instantiate the pattern's template over the error span.
    let template = &pattern_match.pattern.fix_template;
    let mut bindings = HashMap::new();
    if let Some(ref expected) = error.expected {
        bindings.insert("expected_type".to_string(), expected.full.clone());
    }
    if let Some(ref found) = error.found {
        bindings.insert("found_type".to_string(), found.full.clone());
    }

    if !bindings.is_empty() || !template.pattern.is_empty() {
        let fix_code = template.apply(&bindings);
        // An empty result, or one identical to the raw pattern, is not
        // treated as a usable fix.
        if !fix_code.is_empty() && fix_code != template.pattern {
            let span = &error.span;
            return splice(source, span.byte_start, span.byte_end, &fix_code);
        }
    }

    None
}
/// Adds a fix template, keyed by the embedding of the error it resolved,
/// to the pattern library.
pub fn record_success(&mut self, error_embedding: ErrorEmbedding, fix_template: FixTemplate) {
    let library = &mut self.pattern_library;
    library.add_pattern(error_embedding, fix_template);
}
/// Saves the pattern library to `path`.
///
/// # Errors
///
/// A failed save is reported as `OracleError::Model` carrying the
/// underlying error's text.
pub fn save_patterns(&self, path: &str) -> Result<(), OracleError> {
    let saved = self.pattern_library.save(path);
    saved.map_err(|err| OracleError::Model(err.to_string()))
}
/// Returns the summary produced by the internal metrics tracker.
#[must_use]
pub fn metrics_summary(&self) -> aprender::citl::MetricsSummary {
    let tracker = &self.metrics;
    tracker.summary()
}
#[must_use]
pub fn metrics(&self) -> &MetricsTracker {
&self.metrics
}
}
impl Default for CITLFixer {
    /// Builds a fixer through [`CITLFixer::new`].
    ///
    /// # Panics
    ///
    /// Panics with "CITLFixer initialization failed" if `CITLFixer::new`
    /// returns an error.
    fn default() -> Self {
        let built = Self::new();
        built.expect("CITLFixer initialization failed")
    }
}
/// Settings used to build a `CITLFixer`.
#[derive(Debug, Clone)]
pub struct CITLFixerConfig {
/// Upper bound on the number of fix iterations.
pub max_iterations: usize,
/// Minimum pattern-match `similarity` required before a pattern is tried.
pub confidence_threshold: f32,
/// Optional path to load the pattern library from; if loading fails the
/// fixer starts with `PatternLibrary::new()` instead.
pub pattern_library_path: Option<String>,
/// Whether to construct an `AutoFixer` to use when no pattern applies.
pub use_autofixer_fallback: bool,
/// Mode handed to `RustCompiler::mode`.
pub compilation_mode: CompilationMode,
}
impl Default for CITLFixerConfig {
    /// Baseline settings: 10 iterations, 0.7 confidence threshold, no
    /// pattern-library path, autofixer fallback on, standalone compilation.
    fn default() -> Self {
        let max_iterations = 10;
        let confidence_threshold = 0.7;
        Self {
            max_iterations,
            confidence_threshold,
            pattern_library_path: None,
            use_autofixer_fallback: true,
            compilation_mode: CompilationMode::Standalone,
        }
    }
}
impl CITLFixerConfig {
#[must_use]
pub fn quick() -> Self {
Self {
max_iterations: 3,
confidence_threshold: 0.8,
pattern_library_path: None,
use_autofixer_fallback: true,
compilation_mode: CompilationMode::Standalone,
}
}
#[must_use]
pub fn thorough() -> Self {
Self {
max_iterations: 20,
confidence_threshold: 0.5,
pattern_library_path: None,
use_autofixer_fallback: true,
compilation_mode: CompilationMode::Standalone,
}
}
#[must_use]
pub fn with_pattern_library(mut self, path: &str) -> Self {
self.pattern_library_path = Some(path.to_string());
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `DEPYLER_FAST_TESTS` is set; tests that invoke the
    /// compiler return early in that case.
    fn fast_tests_only() -> bool {
        std::env::var("DEPYLER_FAST_TESTS").is_ok()
    }

    /// The `quick` preset with the autofixer fallback switched off.
    fn quick_without_autofixer() -> CITLFixerConfig {
        CITLFixerConfig {
            use_autofixer_fallback: false,
            ..CITLFixerConfig::quick()
        }
    }

    /// Builds a fixer from `config`, panicking if construction fails.
    fn fixer_from(config: CITLFixerConfig) -> CITLFixer {
        CITLFixer::with_config(config).unwrap()
    }

    // ---- IterativeFixResult -------------------------------------------

    #[test]
    fn test_iterative_fix_result_success() {
        let fixes = vec!["fix1".to_string()];
        let result = IterativeFixResult::success("fn main() {}".to_string(), 3, fixes);
        assert!(result.success);
        assert_eq!(result.iterations, 3);
        assert_eq!(result.remaining_errors, 0);
        assert_eq!(result.fixes_applied.len(), 1);
    }

    #[test]
    fn test_iterative_fix_result_failure() {
        let result = IterativeFixResult::failure("broken".to_string(), 10, 5);
        assert!(!result.success);
        assert_eq!(result.iterations, 10);
        assert_eq!(result.remaining_errors, 5);
    }

    #[test]
    fn test_iterative_fix_result_with_duration() {
        let base = IterativeFixResult::success("code".to_string(), 1, vec![]);
        let result = base.with_duration(150);
        assert_eq!(result.fix_duration_ms, 150);
    }

    // ---- CITLFixerConfig ----------------------------------------------

    #[test]
    fn test_config_default() {
        let config = CITLFixerConfig::default();
        assert_eq!(config.max_iterations, 10);
        assert!((config.confidence_threshold - 0.7).abs() < 0.001);
        assert!(config.use_autofixer_fallback);
    }

    #[test]
    fn test_config_quick() {
        let config = CITLFixerConfig::quick();
        assert_eq!(config.max_iterations, 3);
        assert!((config.confidence_threshold - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_config_thorough() {
        let config = CITLFixerConfig::thorough();
        assert_eq!(config.max_iterations, 20);
        assert!((config.confidence_threshold - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_config_with_pattern_library() {
        let config = CITLFixerConfig::default().with_pattern_library("patterns.citl");
        assert_eq!(
            config.pattern_library_path,
            Some("patterns.citl".to_string())
        );
    }

    // ---- CITLFixer (compiler-backed; skipped in fast mode) --------------

    #[test]
    fn test_fix_all_already_compiles() {
        if fast_tests_only() {
            return;
        }
        let mut fixer = fixer_from(quick_without_autofixer());
        let valid_code = r#"fn main() { println!("Hello"); }"#;
        let result = fixer.fix_all(valid_code);
        assert!(result.success);
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_fix_all_respects_max_iterations() {
        if fast_tests_only() {
            return;
        }
        let mut fixer = fixer_from(CITLFixerConfig {
            max_iterations: 2,
            use_autofixer_fallback: false,
            ..CITLFixerConfig::default()
        });
        let broken_code = "fn main() { undefined_function(); }";
        let result = fixer.fix_all(broken_code);
        assert!(!result.success);
        assert!(result.iterations <= 2);
    }

    #[test]
    fn test_compiles_valid_code() {
        if fast_tests_only() {
            return;
        }
        let fixer = fixer_from(quick_without_autofixer());
        assert!(fixer.compiles("fn main() {}"));
        assert!(!fixer.compiles("fn main() { undefined() }"));
    }

    #[test]
    fn test_metrics_tracking() {
        if fast_tests_only() {
            return;
        }
        let mut fixer = fixer_from(CITLFixerConfig {
            max_iterations: 1,
            use_autofixer_fallback: false,
            ..CITLFixerConfig::default()
        });
        let _ = fixer.fix_all("fn main() { x }");
        // Only checks that a summary can be produced and read.
        let summary = fixer.metrics_summary();
        let _ = summary.session_duration;
    }

    #[test]
    #[ignore]
    fn test_fix_all_never_increases_errors() {
        if fast_tests_only() {
            return;
        }
        let mut fixer = fixer_from(CITLFixerConfig {
            max_iterations: 5,
            use_autofixer_fallback: true,
            ..CITLFixerConfig::default()
        });
        let source = "fn main() { let x: i32 = \"string\"; }";
        let initial_errors = fixer.count_errors(source);
        let result = fixer.fix_all(source);
        let final_errors = fixer.count_errors(&result.fixed_source);
        assert!(
            final_errors <= initial_errors || result.success,
            "Errors increased: {} -> {}",
            initial_errors,
            final_errors
        );
    }

    #[test]
    fn test_fix_preserves_structure() {
        if fast_tests_only() {
            return;
        }
        let mut fixer = fixer_from(quick_without_autofixer());
        let valid = "fn main() { let x = 42; println!(\"{}\", x); }";
        let result = fixer.fix_all(valid);
        assert!(result.success);
        assert!(result.fixed_source.contains("fn main()"));
        assert!(result.fixed_source.contains("let x"));
    }
}