use super::{
andon::{render_andon_tui, render_compact},
curriculum::{CurriculumConfig, CurriculumScheduler},
DriftStatus, ErrorCategory, RuchyOracle, Sample, SampleSource,
};
use std::time::{Duration, Instant};
/// Tunable settings for a [`TrainingLoop`].
#[derive(Clone, Debug)]
pub struct TrainingLoopConfig {
    /// Iteration cap: `step` reports `MaxIterationsReached` once its counter reaches this.
    pub max_iterations: usize,
    /// Batch accuracy at or above which `step` reports `Converged`.
    pub target_accuracy: f64,
    /// When `true`, `step` may retrain the oracle on drift or on the sample threshold.
    pub auto_retrain: bool,
    /// Count of results recorded since the last retrain that triggers a threshold retrain.
    pub retrain_threshold: usize,
    /// Selects what `render` produces.
    pub display_mode: DisplayMode,
    /// Settings handed to the curriculum scheduler when the loop is built.
    pub curriculum: CurriculumConfig,
    /// When `true`, accuracies and prediction outcomes are forwarded to the scheduler.
    pub curriculum_enabled: bool,
}
impl Default for TrainingLoopConfig {
    /// Defaults: 50 iterations, 80% target accuracy, automatic retraining after
    /// 100 recorded samples, compact display, curriculum enabled.
    fn default() -> Self {
        let curriculum = CurriculumConfig::default();
        Self {
            max_iterations: 50,
            target_accuracy: 0.8,
            auto_retrain: true,
            retrain_threshold: 100,
            display_mode: DisplayMode::Compact,
            curriculum,
            curriculum_enabled: true,
        }
    }
}
/// Output style used by `TrainingLoop::render`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DisplayMode {
    /// Full dashboard rendering.
    Verbose,
    /// Condensed rendering.
    Compact,
    /// No output: `render` returns an empty string and `run` prints nothing.
    Silent,
}
/// Outcome reported by a single `TrainingLoop::step` (and, for terminal
/// variants, by `TrainingLoop::run`).
#[derive(Clone, Debug)]
pub enum TrainingEvent {
    /// An ordinary iteration finished without any other event taking precedence.
    IterationComplete {
        /// Iteration counter after this step.
        iteration: usize,
        /// Accuracy measured on this step's holdout batch.
        accuracy: f64,
        /// Drift status read from the oracle during this step.
        drift_status: DriftStatus,
    },
    /// The batch accuracy reached the configured target.
    Converged { iteration: usize, accuracy: f64 },
    /// Announces that a retrain is starting.
    /// NOTE(review): `TrainingLoop` does not currently construct this variant —
    /// confirm whether it is still needed.
    RetrainingTriggered {
        reason: RetrainReason,
        samples: usize,
    },
    /// The oracle was retrained successfully.
    RetrainingComplete {
        /// Oracle training accuracy read before retraining.
        accuracy_before: f64,
        /// Oracle training accuracy read after retraining.
        accuracy_after: f64,
    },
    /// The curriculum scheduler changed level after an accuracy report.
    CurriculumAdvanced {
        from: super::DifficultyLevel,
        to: super::DifficultyLevel,
    },
    /// The iteration cap was hit before the target accuracy.
    MaxIterationsReached { accuracy: f64 },
    /// Retraining failed; `message` carries the formatted error.
    Error { message: String },
}
/// Why a retrain was requested.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetrainReason {
    /// The oracle reported `DriftStatus::Drift`.
    Drift,
    /// Enough results were recorded since the last retrain.
    Threshold,
    /// Presumably an explicit caller request; `TrainingLoop::step` never
    /// produces this reason — TODO confirm intended use.
    Manual,
}
/// Drives repeated evaluation of a `RuchyOracle` against a holdout set,
/// retraining it and advancing a curriculum as configured.
pub struct TrainingLoop {
    /// Classifier being evaluated and retrained.
    oracle: RuchyOracle,
    /// Settings supplied at construction.
    config: TrainingLoopConfig,
    /// Curriculum scheduler fed with accuracies and prediction outcomes.
    curriculum: CurriculumScheduler,
    /// Number of `step` calls since construction or the last `reset`.
    iteration: usize,
    /// Most recent batch accuracies, oldest first; `step` caps it at 50 entries.
    accuracy_history: Vec<f64>,
    /// Moment of construction or of the last successful retrain.
    last_trained: Instant,
    /// Set by `step`; cleared on terminal events and by `stop`/`reset`.
    running: bool,
    /// Results recorded through `record_result` since the last successful retrain.
    samples_since_retrain: usize,
    /// Samples evaluated by `step`.
    holdout_set: Vec<Sample>,
    /// Index of the next holdout sample to evaluate; wraps at the set length.
    holdout_index: usize,
}
impl TrainingLoop {
/// Builds a training loop around `oracle` using `TrainingLoopConfig::default()`.
#[must_use]
pub fn new(oracle: RuchyOracle) -> Self {
    let defaults = TrainingLoopConfig::default();
    Self::with_config(oracle, defaults)
}
/// Builds a training loop around `oracle` with explicit settings.
///
/// The curriculum scheduler is created from a clone of `config.curriculum`,
/// the built-in holdout set is installed, and all counters start at zero.
#[must_use]
pub fn with_config(oracle: RuchyOracle, config: TrainingLoopConfig) -> Self {
    // Clone before `config` is moved into the struct below.
    let scheduler_settings = config.curriculum.clone();
    Self {
        curriculum: CurriculumScheduler::with_config(scheduler_settings),
        holdout_set: Self::build_holdout_set(),
        holdout_index: 0,
        oracle,
        config,
        iteration: 0,
        samples_since_retrain: 0,
        accuracy_history: Vec::new(),
        running: false,
        last_trained: Instant::now(),
    }
}
fn build_holdout_set() -> Vec<Sample> {
vec![
Sample::new(
"expected `u32`, found `i64`",
Some("E0308".into()),
ErrorCategory::TypeMismatch,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"type mismatch: expected String but found integer",
None,
ErrorCategory::TypeMismatch,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"cannot borrow `x` as mutable because it is also borrowed as immutable",
Some("E0502".into()),
ErrorCategory::BorrowChecker,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"cannot borrow value after move to another function",
None,
ErrorCategory::BorrowChecker,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"borrowed value does not live long enough",
Some("E0597".into()),
ErrorCategory::LifetimeError,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"lifetime of reference outlives the data it points to",
None,
ErrorCategory::LifetimeError,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"the trait bound `MyType: Clone` is not satisfied",
Some("E0277".into()),
ErrorCategory::TraitBound,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"trait Clone is not implemented for this type",
None,
ErrorCategory::TraitBound,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"cannot find value `HashMap` in this scope",
Some("E0425".into()),
ErrorCategory::MissingImport,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"unresolved import: module not found in crate",
None,
ErrorCategory::MissingImport,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"cannot borrow `vec` as mutable, as it is not declared as mutable",
Some("E0596".into()),
ErrorCategory::MutabilityError,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"requires mut binding but variable is immutable",
None,
ErrorCategory::MutabilityError,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"expected `;`, found `let`",
None,
ErrorCategory::SyntaxError,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"this function takes 2 arguments but 3 arguments were supplied",
Some("E0061".into()),
ErrorCategory::SyntaxError,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"recursion limit reached while expanding the macro",
None,
ErrorCategory::Other,
)
.with_source(SampleSource::Synthetic),
Sample::new(
"internal compiler error: unexpected panic during compilation",
None,
ErrorCategory::Other,
)
.with_source(SampleSource::Synthetic),
]
}
/// Borrows the settings this loop was built with.
#[must_use]
pub fn config(&self) -> &TrainingLoopConfig {
    let settings = &self.config;
    settings
}
/// Number of `step` calls made since construction or the last `reset`.
#[must_use]
pub fn iteration(&self) -> usize {
    self.iteration
}
/// Whether the loop is marked as running: set by `step`, cleared on
/// terminal events and by `stop`/`reset`.
#[must_use]
pub fn is_running(&self) -> bool {
    self.running
}
/// Mutably borrows the wrapped oracle.
pub fn oracle_mut(&mut self) -> &mut RuchyOracle {
    let oracle = &mut self.oracle;
    oracle
}
/// Borrows the wrapped oracle.
#[must_use]
pub fn oracle(&self) -> &RuchyOracle {
    let oracle = &self.oracle;
    oracle
}
/// Batch accuracies recorded by `step`, oldest first.
#[must_use]
pub fn accuracy_history(&self) -> &[f64] {
    self.accuracy_history.as_slice()
}
/// Runs one iteration and reports what happened.
///
/// Evaluates one holdout batch, appends its accuracy to the history (keeping
/// at most 50 entries), then returns the first applicable event, in this
/// order of precedence:
///
/// 1. a retrain result, when `auto_retrain` is on and the oracle reports
///    drift or the recorded-sample threshold has been reached;
/// 2. `Converged`, when the batch accuracy meets the target;
/// 3. `MaxIterationsReached`, when the iteration counter hits the cap;
/// 4. `CurriculumAdvanced`, when reporting the accuracy changed the level;
/// 5. `IterationComplete` otherwise.
///
/// `running` is set on entry and cleared only for cases 2 and 3.
pub fn step(&mut self) -> TrainingEvent {
    const HISTORY_CAP: usize = 50;

    self.running = true;
    self.iteration += 1;

    let accuracy = self.evaluate_holdout_batch();
    self.accuracy_history.push(accuracy);
    if self.accuracy_history.len() > HISTORY_CAP {
        // Drop the oldest entry so the history stays bounded.
        self.accuracy_history.remove(0);
    }

    let drift_status = self.oracle.drift_status();
    if self.config.auto_retrain {
        // Drift takes priority over the sample-count threshold.
        let pending = if drift_status == DriftStatus::Drift {
            Some(RetrainReason::Drift)
        } else if self.samples_since_retrain >= self.config.retrain_threshold {
            Some(RetrainReason::Threshold)
        } else {
            None
        };
        if let Some(reason) = pending {
            return self.trigger_retrain(reason);
        }
    }

    if accuracy >= self.config.target_accuracy {
        self.running = false;
        return TrainingEvent::Converged {
            iteration: self.iteration,
            accuracy,
        };
    }

    if self.iteration >= self.config.max_iterations {
        self.running = false;
        return TrainingEvent::MaxIterationsReached { accuracy };
    }

    if self.config.curriculum_enabled {
        let level_before = self.curriculum.current_level();
        self.curriculum.report_accuracy(accuracy);
        let level_after = self.curriculum.current_level();
        if level_before != level_after {
            return TrainingEvent::CurriculumAdvanced {
                from: level_before,
                to: level_after,
            };
        }
    }

    TrainingEvent::IterationComplete {
        iteration: self.iteration,
        accuracy,
        drift_status,
    }
}
/// Classifies the next four holdout samples and returns the fraction that
/// matched their labelled category.
///
/// Each prediction is also recorded with the oracle. The cursor advances by
/// one per sample and wraps at the end of the set, so a set shorter than the
/// batch is visited more than once within a single batch. Returns `0.0`
/// without touching the oracle when the holdout set is empty.
fn evaluate_holdout_batch(&mut self) -> f64 {
    // `u32` so the final ratio can use lossless `f64::from` conversions.
    const BATCH_SIZE: u32 = 4;

    if self.holdout_set.is_empty() {
        return 0.0;
    }

    let mut correct: u32 = 0;
    for _ in 0..BATCH_SIZE {
        let sample = &self.holdout_set[self.holdout_index];
        let error = sample.to_compilation_error();
        let classification = self.oracle.classify(&error);
        if classification.category == sample.category {
            correct += 1;
        }
        self.oracle
            .record_result(classification.category, sample.category);
        self.holdout_index = (self.holdout_index + 1) % self.holdout_set.len();
    }

    // The loop always runs exactly BATCH_SIZE times, so the denominator is
    // a non-zero constant and no zero-division guard is needed.
    f64::from(correct) / f64::from(BATCH_SIZE)
}
fn trigger_retrain(&mut self, _reason: RetrainReason) -> TrainingEvent {
let _samples = self.samples_since_retrain;
let accuracy_before = self.oracle.metadata().training_accuracy;
match self.oracle.retrain() {
Ok(()) => {
self.samples_since_retrain = 0;
self.last_trained = Instant::now();
self.oracle.reset_drift_detector();
let accuracy_after = self.oracle.metadata().training_accuracy;
TrainingEvent::RetrainingComplete {
accuracy_before,
accuracy_after,
}
}
Err(e) => TrainingEvent::Error {
message: format!("Retraining failed: {e}"),
},
}
}
/// Records one externally observed prediction.
///
/// Forwards the pair to the oracle, counts it toward the retrain threshold,
/// and — when the curriculum is enabled — tells the scheduler whether the
/// prediction was correct.
pub fn record_result(&mut self, predicted: ErrorCategory, actual: ErrorCategory) {
    let was_correct = predicted == actual;

    self.oracle.record_result(predicted, actual);
    self.samples_since_retrain += 1;

    if self.config.curriculum_enabled {
        self.curriculum.record_prediction(was_correct);
    }
}
/// Appends `samples` to the end of the holdout set, preserving their order.
pub fn add_live_samples(&mut self, samples: Vec<Sample>) {
    let mut incoming = samples;
    self.holdout_set.append(&mut incoming);
}
/// Number of samples currently in the holdout set.
#[must_use]
pub fn holdout_size(&self) -> usize {
    let set = &self.holdout_set;
    set.len()
}
/// Renders the current state according to the configured display mode.
///
/// Returns an empty string in `DisplayMode::Silent`.
#[must_use]
pub fn render(&self) -> String {
    match self.config.display_mode {
        DisplayMode::Silent => String::new(),
        DisplayMode::Compact => self.render_compact(),
        DisplayMode::Verbose => self.render_verbose(),
    }
}
/// Produces the verbose view by delegating to `render_andon_tui`.
fn render_verbose(&self) -> String {
    let latest = self.current_accuracy();
    let change = self.compute_accuracy_delta();
    let trained_label = self.format_last_trained();
    // NOTE(review): fixed placeholder, not read from the oracle.
    let reported_model_size = 500;
    let drift = self.oracle.drift_status();

    render_andon_tui(
        self.iteration,
        self.config.max_iterations,
        latest,
        self.config.target_accuracy,
        change,
        &trained_label,
        reported_model_size,
        &self.accuracy_history,
        &drift,
    )
}
/// Produces the compact view by delegating to the imported `render_compact`
/// function.
fn render_compact(&self) -> String {
    let latest = self.current_accuracy();
    let trained_label = self.format_last_trained_ago();
    // NOTE(review): fixed placeholder, not read from the oracle.
    let reported_model_size = 500;
    let drift = self.oracle.drift_status();

    render_compact(
        self.iteration,
        self.config.max_iterations,
        latest,
        reported_model_size,
        &trained_label,
        &drift,
    )
}
/// Most recent batch accuracy, or `0.0` when no batch has been evaluated.
fn current_accuracy(&self) -> f64 {
    match self.accuracy_history.last() {
        Some(&latest) => latest,
        None => 0.0,
    }
}
/// Difference between the two most recent accuracies (latest minus previous),
/// or `0.0` when fewer than two are recorded.
fn compute_accuracy_delta(&self) -> f64 {
    match self.accuracy_history.as_slice() {
        [.., previous, latest] => latest - previous,
        _ => 0.0,
    }
}
/// Formats the time of the last training as `"<UTC timestamp> (<age> ago)"`.
///
/// `last_trained` is a monotonic `Instant` with no calendar meaning, so the
/// wall-clock timestamp is reconstructed as "now minus elapsed". Previously
/// the current time itself was printed, which contradicted the age suffix.
fn format_last_trained(&self) -> String {
    let elapsed = self.last_trained.elapsed();

    // Fall back to the current time if the elapsed span cannot be
    // represented or the subtraction would leave the supported date range.
    let trained_at = chrono::Duration::from_std(elapsed)
        .ok()
        .and_then(|age| chrono::Utc::now().checked_sub_signed(age))
        .unwrap_or_else(chrono::Utc::now);

    format!(
        "{} ({})",
        trained_at.format("%Y-%m-%d %H:%M:%S UTC"),
        self.format_duration(elapsed)
    )
}
/// Formats how long ago the last training happened, e.g. `"5m ago"`.
fn format_last_trained_ago(&self) -> String {
    let since_training = self.last_trained.elapsed();
    self.format_duration(since_training)
}
/// Formats `duration` as a coarse age using the largest fitting unit among
/// seconds, minutes, hours and days (whole units, rounded down).
fn format_duration(&self, duration: Duration) -> String {
    match duration.as_secs() {
        secs @ 0..=59 => format!("{secs}s ago"),
        secs @ 60..=3599 => format!("{}m ago", secs / 60),
        secs @ 3600..=86399 => format!("{}h ago", secs / 3600),
        secs => format!("{}d ago", secs / 86400),
    }
}
pub fn run(&mut self) -> TrainingEvent {
loop {
let event = self.step();
match &event {
TrainingEvent::Converged { .. }
| TrainingEvent::MaxIterationsReached { .. }
| TrainingEvent::Error { .. } => {
self.running = false;
return event;
}
_ => {
if self.config.display_mode != DisplayMode::Silent {
println!("{}", self.render());
}
}
}
}
}
/// Clears the running flag. Other state is left untouched.
pub fn stop(&mut self) {
    self.running = false;
}
/// Returns the loop to its initial counters: iteration, accuracy history,
/// recorded-sample count, holdout cursor, running flag and curriculum.
///
/// The oracle, the holdout samples and `last_trained` are not modified.
pub fn reset(&mut self) {
    self.running = false;
    self.iteration = 0;
    self.holdout_index = 0;
    self.samples_since_retrain = 0;
    self.accuracy_history.clear();
    self.curriculum.reset();
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A freshly built loop is idle and has run no iterations.
#[test]
fn test_training_loop_new() {
    let loop_runner = TrainingLoop::new(RuchyOracle::new());

    assert_eq!(loop_runner.iteration(), 0);
    assert!(!loop_runner.is_running());
}
/// Custom settings passed to `with_config` are visible through `config()`.
#[test]
fn test_training_loop_with_config() {
    let loop_runner = TrainingLoop::with_config(
        RuchyOracle::new(),
        TrainingLoopConfig {
            max_iterations: 100,
            target_accuracy: 0.90,
            ..Default::default()
        },
    );

    let applied = loop_runner.config();
    assert_eq!(applied.max_iterations, 100);
    assert!((applied.target_accuracy - 0.90).abs() < f64::EPSILON);
}
/// One step on a bootstrapped oracle completes or converges and bumps the
/// iteration counter to one.
#[test]
fn test_training_loop_step() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let mut loop_runner = TrainingLoop::new(oracle);

    let event = loop_runner.step();
    let acceptable = matches!(
        event,
        TrainingEvent::IterationComplete { .. } | TrainingEvent::Converged { .. }
    );

    assert!(acceptable);
    assert_eq!(loop_runner.iteration(), 1);
}
/// With a low target and no auto-retrain, the loop either converges with an
/// accuracy at or above the target or uses up its ten iterations.
#[test]
fn test_training_loop_convergence() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let config = TrainingLoopConfig {
        max_iterations: 10,
        target_accuracy: 0.50,
        auto_retrain: false,
        ..Default::default()
    };
    let mut loop_runner = TrainingLoop::with_config(oracle, config);

    let mut converged = false;
    for _ in 0..10 {
        let event = loop_runner.step();
        if let TrainingEvent::Converged { accuracy, .. } = event {
            converged = true;
            assert!(accuracy >= 0.50);
            break;
        }
        if matches!(event, TrainingEvent::MaxIterationsReached { .. }) {
            break;
        }
    }

    assert!(converged || loop_runner.iteration() >= 10);
}
/// Recording a result counts one sample toward the retrain threshold.
#[test]
fn test_training_loop_record_result() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let mut loop_runner = TrainingLoop::new(oracle);
    assert_eq!(loop_runner.samples_since_retrain, 0);

    loop_runner.record_result(ErrorCategory::TypeMismatch, ErrorCategory::TypeMismatch);

    // Previously this test made no assertion at all; pin the visible effect.
    assert_eq!(loop_runner.samples_since_retrain, 1);
}
/// Compact rendering mentions the oracle.
#[test]
fn test_training_loop_render_compact() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let loop_runner = TrainingLoop::with_config(
        oracle,
        TrainingLoopConfig {
            display_mode: DisplayMode::Compact,
            ..Default::default()
        },
    );

    let rendered = loop_runner.render();

    assert!(rendered.contains("Oracle"));
}
/// Silent mode renders nothing.
#[test]
fn test_training_loop_render_silent() {
    let loop_runner = TrainingLoop::with_config(
        RuchyOracle::new(),
        TrainingLoopConfig {
            display_mode: DisplayMode::Silent,
            ..Default::default()
        },
    );

    let rendered = loop_runner.render();

    assert!(rendered.is_empty());
}
/// `reset` zeroes the iteration counter and empties the accuracy history.
#[test]
fn test_training_loop_reset() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let mut loop_runner = TrainingLoop::new(oracle);

    for _ in 0..2 {
        loop_runner.step();
    }
    assert!(loop_runner.iteration() >= 2);

    loop_runner.reset();

    assert_eq!(loop_runner.iteration(), 0);
    assert!(loop_runner.accuracy_history().is_empty());
}
/// The default configuration has the documented cap, target and retrain flag.
#[test]
fn test_training_loop_config_default() {
    let defaults = TrainingLoopConfig::default();

    assert!(defaults.auto_retrain);
    assert_eq!(defaults.max_iterations, 50);
    assert!((defaults.target_accuracy - 0.80).abs() < f64::EPSILON);
}
/// `DisplayMode` variants compare equal only to themselves.
#[test]
fn test_display_mode_variants() {
    let verbose = DisplayMode::Verbose;
    let compact = DisplayMode::Compact;
    let silent = DisplayMode::Silent;

    assert_eq!(verbose, DisplayMode::Verbose);
    assert_ne!(verbose, compact);
    assert_ne!(compact, silent);
}
/// `RetrainReason` variants compare equal only to themselves.
#[test]
fn test_retrain_reason_variants() {
    let drift = RetrainReason::Drift;

    assert_eq!(drift, RetrainReason::Drift);
    assert_ne!(drift, RetrainReason::Threshold);
}
/// A new loop starts with the sixteen built-in holdout samples.
#[test]
fn test_holdout_set_initialized() {
    let loop_runner = TrainingLoop::new(RuchyOracle::new());
    let holdout = &loop_runner.holdout_set;

    assert!(!holdout.is_empty());
    assert_eq!(holdout.len(), 16);
}
/// After bootstrapping, a step reports a positive accuracy and every
/// recorded accuracy lies within `[0, 1]`.
#[test]
fn test_holdout_evaluation_produces_real_accuracy() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let mut loop_runner = TrainingLoop::new(oracle);

    // Events that carry no batch accuracy are deliberately not checked here.
    match loop_runner.step() {
        TrainingEvent::IterationComplete { accuracy, .. }
        | TrainingEvent::Converged { accuracy, .. }
        | TrainingEvent::MaxIterationsReached { accuracy } => {
            assert!(accuracy > 0.0, "Accuracy should be > 0% after training");
        }
        _ => {}
    }

    let history = loop_runner.accuracy_history();
    assert!(!history.is_empty());
    for &acc in history {
        assert!((0.0..=1.0).contains(&acc), "Accuracy must be in [0, 1]");
    }
}
/// `current_accuracy` is zero before any step and afterwards equals the last
/// history entry.
#[test]
fn test_current_accuracy_reflects_history() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let mut loop_runner = TrainingLoop::new(oracle);
    assert_eq!(loop_runner.current_accuracy(), 0.0);

    loop_runner.step();

    let current = loop_runner.current_accuracy();
    let last_history = match loop_runner.accuracy_history().last() {
        Some(&value) => value,
        None => 0.0,
    };
    assert!((current - last_history).abs() < f64::EPSILON);
}
/// Stepping past the end of the holdout set wraps the cursor back to the
/// start instead of running off the end.
#[test]
fn test_holdout_index_wraps_around() {
    // Mirrors the batch size used by `evaluate_holdout_batch`.
    const BATCH: usize = 4;

    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let mut loop_runner = TrainingLoop::new(oracle);
    let holdout_size = loop_runner.holdout_set.len();
    let steps_needed = (holdout_size / BATCH) + 2;

    for _ in 0..steps_needed {
        loop_runner.step();
    }

    assert!(loop_runner.holdout_index < holdout_size);
    // The bound above is guaranteed by the modulo and cannot fail on its
    // own; every step consumes one full batch, so pin the exact position.
    assert_eq!(
        loop_runner.holdout_index,
        (steps_needed * BATCH) % holdout_size
    );
}
/// Compact rendering still works after a step has been evaluated.
#[test]
fn test_render_uses_evaluated_accuracy() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let mut loop_runner = TrainingLoop::with_config(
        oracle,
        TrainingLoopConfig {
            display_mode: DisplayMode::Compact,
            ..Default::default()
        },
    );

    loop_runner.step();
    let rendered = loop_runner.render();

    assert!(rendered.contains("Oracle"));
}
/// Adding two live samples grows the holdout set from sixteen to eighteen.
#[test]
fn test_add_live_samples() {
    let mut loop_runner = TrainingLoop::new(RuchyOracle::new());
    let initial_size = loop_runner.holdout_size();
    assert_eq!(initial_size, 16);

    let with_code = Sample::new(
        "custom error from examples",
        Some("E0999".into()),
        ErrorCategory::Other,
    )
    .with_source(SampleSource::Examples);
    let without_code = Sample::new("another custom error", None, ErrorCategory::SyntaxError)
        .with_source(SampleSource::Examples);

    loop_runner.add_live_samples(vec![with_code, without_code]);

    assert_eq!(loop_runner.holdout_size(), initial_size + 2);
}
/// Stepping still records accuracies after a live sample has been added.
#[test]
fn test_live_samples_evaluated_in_step() {
    let mut oracle = RuchyOracle::new();
    oracle.train_from_examples().expect("bootstrap");
    let mut loop_runner = TrainingLoop::new(oracle);

    let live_sample = Sample::new(
        "expected type `String`, found `i32`",
        Some("E0308".into()),
        ErrorCategory::TypeMismatch,
    )
    .with_source(SampleSource::Examples);
    loop_runner.add_live_samples(vec![live_sample]);

    let mut remaining = 5;
    while remaining > 0 {
        loop_runner.step();
        remaining -= 1;
    }

    assert!(!loop_runner.accuracy_history().is_empty());
}
}