pub mod jidoka;
pub mod multi_turn;
pub mod prediction;
#[cfg(test)]
mod tests;
pub use jidoka::*;
pub use multi_turn::*;
pub use prediction::*;
use serde::{Deserialize, Serialize};
use crate::engine::rng::{RngState, SimRng};
use crate::engine::SimTime;
use crate::error::{SimError, SimResult};
use crate::replay::EventJournal;
/// Hyperparameters describing a simulated training run.
///
/// In the code visible in this file only `learning_rate` and `early_stopping`
/// are read (both by `TrainingSimulation::step`); the other fields are stored
/// and serialized but not consulted here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Learning rate; copied into each epoch's `TrainingMetrics`.
pub learning_rate: f64,
/// Batch size (not read by the code visible here).
pub batch_size: usize,
/// Configured number of epochs (not read by the code visible here;
/// `TrainingSimulation::simulate` takes its own `epochs` argument).
pub epochs: u64,
/// Early-stopping patience: number of consecutive epochs without a new best
/// validation loss after which `step` signals a stop. `None` disables it.
pub early_stopping: Option<usize>,
/// Optional gradient clipping value (not read by the code visible here).
pub gradient_clip: Option<f64>,
/// Weight decay coefficient (not read by the code visible here).
pub weight_decay: f64,
}
impl Default for TrainingConfig {
    /// Baseline configuration: learning rate `0.001`, batch size `32`,
    /// `100` epochs, early-stopping patience `10`, gradient clip `1.0`
    /// and weight decay `0.0001`.
    fn default() -> Self {
        // Optional knobs are enabled by default; name them so the struct
        // literal below reads as plain field assignments.
        let early_stopping = Some(10);
        let gradient_clip = Some(1.0);

        Self {
            learning_rate: 1e-3,
            weight_decay: 1e-4,
            batch_size: 32,
            epochs: 100,
            early_stopping,
            gradient_clip,
        }
    }
}
/// Snapshot of one recorded training epoch.
///
/// `TrainingSimulation::step` builds one of these per accepted epoch and
/// stores it in the trajectory and the journal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState {
/// Epoch index this snapshot belongs to.
pub epoch: u64,
/// Training loss supplied to `step` for this epoch.
pub loss: f64,
/// Validation loss; `step` derives it from `loss` scaled by an RNG-drawn factor.
pub val_loss: f64,
/// Per-epoch metrics (duplicates `loss`/`val_loss` alongside gradient data).
pub metrics: TrainingMetrics,
/// RNG state captured by `step` right after the validation loss was drawn.
pub rng_state: RngState,
}
/// Metrics recorded for a single training epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
/// Training loss for the epoch.
pub train_loss: f64,
/// Validation loss for the epoch.
pub val_loss: f64,
/// Optional accuracy; `TrainingSimulation::step` always records `None`.
pub accuracy: Option<f64>,
/// Gradient norm reported for the epoch.
pub gradient_norm: f64,
/// Learning rate in effect (copied from `TrainingConfig::learning_rate` by `step`).
pub learning_rate: f64,
/// Number of parameters updated; `step` records a fixed `1000`.
pub params_updated: usize,
}
impl Default for TrainingMetrics {
    /// Zeroed metrics with no accuracy and a learning rate of `0.001`.
    fn default() -> Self {
        // All measured quantities start from the same zero value.
        let zero = 0.0_f64;

        Self {
            accuracy: None,
            learning_rate: 1e-3,
            params_updated: 0,
            train_loss: zero,
            val_loss: zero,
            gradient_norm: zero,
        }
    }
}
/// Ordered history of recorded training epochs.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingTrajectory {
/// Recorded states, oldest first (entries are appended by `push`).
pub states: Vec<TrainingState>,
}
impl TrainingTrajectory {
    /// Creates an empty trajectory.
    #[must_use]
    pub fn new() -> Self {
        Self { states: Vec::new() }
    }

    /// Appends `state` as the most recent entry.
    pub fn push(&mut self, state: TrainingState) {
        self.states.push(state);
    }

    /// Returns the most recently pushed state, or `None` if the trajectory is empty.
    #[must_use]
    pub fn final_state(&self) -> Option<&TrainingState> {
        self.states.last()
    }

    /// Returns the smallest validation loss across all recorded states.
    ///
    /// Returns `None` for an empty trajectory. NaN entries are ignored as long
    /// as at least one non-NaN value exists; if every entry is NaN the result
    /// is `Some(NaN)`.
    #[must_use]
    pub fn best_val_loss(&self) -> Option<f64> {
        // `f64::min` returns the non-NaN operand when exactly one side is NaN,
        // so the result no longer depends on where a NaN sits in the history
        // (a comparator that maps "incomparable" to `Equal` would let a
        // leading NaN win).
        self.states.iter().map(|s| s.val_loss).reduce(f64::min)
    }

    /// Reports whether the training loss has settled.
    ///
    /// Looks at the `loss` of the 10 most recent states and returns `true`
    /// when their population standard deviation is strictly below
    /// `tolerance`. Always `false` while fewer than 10 states are recorded.
    #[must_use]
    pub fn converged(&self, tolerance: f64) -> bool {
        /// Number of trailing epochs inspected.
        const WINDOW: usize = 10;

        if self.states.len() < WINDOW {
            return false;
        }
        let recent: Vec<f64> = self
            .states
            .iter()
            .rev()
            .take(WINDOW)
            .map(|s| s.loss)
            .collect();
        let mean = recent.iter().sum::<f64>() / recent.len() as f64;
        let variance =
            recent.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / recent.len() as f64;
        variance.sqrt() < tolerance
    }
}
/// Kinds of anomaly that can interrupt training.
///
/// `AnomalyDetector::check` produces every variant except `LowConfidence`,
/// which is not constructed in the code visible here.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum TrainingAnomaly {
/// The loss was NaN or infinite.
NonFiniteLoss,
/// The gradient norm exceeded the explosion threshold.
GradientExplosion {
/// Observed gradient norm.
norm: f64,
/// Threshold that was exceeded.
threshold: f64,
},
/// The gradient norm was positive but below the vanishing threshold.
GradientVanishing {
/// Observed gradient norm.
norm: f64,
/// Threshold the norm fell below.
threshold: f64,
},
/// The loss deviated from the running mean by more than the allowed z-score.
LossSpike {
/// Z-score of the offending loss.
z_score: f64,
/// The offending loss value.
loss: f64,
},
/// A confidence value fell below its threshold.
LowConfidence {
/// Observed confidence.
confidence: f64,
/// Threshold the confidence fell below.
threshold: f64,
},
}
impl std::fmt::Display for TrainingAnomaly {
    /// Renders a one-line, human-readable description of the anomaly.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum is `Copy`, so match on the value and bind fields by value.
        match *self {
            Self::NonFiniteLoss => f.write_str("Non-finite loss detected (NaN/Inf)"),
            Self::GradientExplosion { norm, threshold } => write!(
                f,
                "Gradient explosion: norm={norm:.2e} > threshold={threshold:.2e}"
            ),
            Self::GradientVanishing { norm, threshold } => write!(
                f,
                "Gradient vanishing: norm={norm:.2e} < threshold={threshold:.2e}"
            ),
            Self::LossSpike { z_score, loss } => {
                write!(f, "Loss spike: z-score={z_score:.2}, loss={loss:.4}")
            }
            Self::LowConfidence {
                confidence,
                threshold,
            } => write!(
                f,
                "Low confidence: {confidence:.4} < threshold={threshold:.4}"
            ),
        }
    }
}
/// Running mean/variance accumulator with a bounded buffer of recent samples.
///
/// Mean and variance are cumulative over every sample seen since the last
/// `reset` (incremental Welford-style update); the `recent` buffer only keeps
/// the last `window_size` raw values and does not influence the statistics.
#[derive(Debug, Clone, Default)]
pub struct RollingStats {
    /// Number of samples folded in so far.
    count: u64,
    /// Running mean of all samples.
    mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
    /// Maximum number of raw samples retained in `recent` (0 disables the buffer).
    window_size: usize,
    /// Most recent raw samples, oldest first.
    recent: Vec<f64>,
}

impl RollingStats {
    /// Creates an empty accumulator that retains up to `window_size` raw samples.
    #[must_use]
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            ..Self::default()
        }
    }

    /// Folds `value` into the running statistics and the recent-sample buffer.
    pub fn update(&mut self, value: f64) {
        self.count += 1;

        // Deviation from the mean before and after it absorbs `value`.
        let before = value - self.mean;
        self.mean += before / self.count as f64;
        let after = value - self.mean;
        self.m2 += before * after;

        if self.window_size == 0 {
            return;
        }
        self.recent.push(value);
        if self.recent.len() > self.window_size {
            // Drop the oldest sample to stay within the window.
            self.recent.remove(0);
        }
    }

    /// Running mean of all samples (0.0 before any sample is added).
    #[must_use]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Sample variance (divides by `count - 1`); 0.0 with fewer than two samples.
    #[must_use]
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / (n - 1) as f64,
        }
    }

    /// Square root of [`variance`](Self::variance).
    #[must_use]
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Number of standard deviations `value` lies from the running mean.
    ///
    /// Returns 0.0 when the standard deviation is below `1e-10`, avoiding a
    /// division by (nearly) zero.
    #[must_use]
    pub fn z_score(&self, value: f64) -> f64 {
        let sigma = self.std_dev();
        if sigma < 1e-10 {
            0.0
        } else {
            (value - self.mean) / sigma
        }
    }

    /// Discards all samples and statistics; the window size is kept.
    pub fn reset(&mut self) {
        self.recent.clear();
        self.m2 = 0.0;
        self.mean = 0.0;
        self.count = 0;
    }
}
/// Stateful detector for loss and gradient anomalies during training.
#[derive(Debug, Clone)]
pub struct AnomalyDetector {
/// Running statistics over the losses accepted by `check`.
loss_stats: RollingStats,
/// Absolute z-score above which a loss is reported as a spike.
threshold_sigma: f64,
/// Gradient norms above this value are reported as explosions.
gradient_explosion_threshold: f64,
/// Positive gradient norms below this value are reported as vanishing.
gradient_vanishing_threshold: f64,
/// Spike detection only runs once more than this many losses have been recorded.
warmup_count: u64,
/// Total number of anomalies reported since construction or the last `reset`.
anomaly_count: u64,
}
impl AnomalyDetector {
    /// Creates a detector that reports a loss spike when the loss lies more
    /// than `threshold_sigma` standard deviations from the running mean.
    ///
    /// Defaults: recent-sample window of 100, gradient explosion threshold
    /// `1e6`, gradient vanishing threshold `1e-10`, warm-up of 10 samples.
    #[must_use]
    pub fn new(threshold_sigma: f64) -> Self {
        Self {
            loss_stats: RollingStats::new(100),
            threshold_sigma,
            gradient_explosion_threshold: 1e6,
            gradient_vanishing_threshold: 1e-10,
            warmup_count: 10,
            anomaly_count: 0,
        }
    }

    /// Overrides the gradient explosion threshold (builder style).
    #[must_use]
    pub fn with_gradient_explosion_threshold(mut self, threshold: f64) -> Self {
        self.gradient_explosion_threshold = threshold;
        self
    }

    /// Overrides the gradient vanishing threshold (builder style).
    #[must_use]
    pub fn with_gradient_vanishing_threshold(mut self, threshold: f64) -> Self {
        self.gradient_vanishing_threshold = threshold;
        self
    }

    /// Overrides the number of samples required before spike detection starts.
    #[must_use]
    pub fn with_warmup(mut self, count: u64) -> Self {
        self.warmup_count = count;
        self
    }

    /// Inspects one training step and returns the first anomaly found, if any.
    ///
    /// Checks run in this order, and the first hit is returned:
    /// 1. non-finite `loss` -> [`TrainingAnomaly::NonFiniteLoss`];
    /// 2. `gradient_norm` that is NaN or above the explosion threshold ->
    ///    [`TrainingAnomaly::GradientExplosion`];
    /// 3. positive `gradient_norm` below the vanishing threshold ->
    ///    [`TrainingAnomaly::GradientVanishing`];
    /// 4. once more than `warmup_count` losses are recorded, a loss whose
    ///    absolute z-score exceeds `threshold_sigma` ->
    ///    [`TrainingAnomaly::LossSpike`].
    ///
    /// Samples rejected by checks 1-3 are not added to the running loss
    /// statistics. The loss is added *before* its z-score is computed, so a
    /// spike contributes to the mean and variance it is compared against.
    /// Every reported anomaly increments the anomaly counter.
    pub fn check(&mut self, loss: f64, gradient_norm: f64) -> Option<TrainingAnomaly> {
        if !loss.is_finite() {
            self.anomaly_count += 1;
            return Some(TrainingAnomaly::NonFiniteLoss);
        }

        // NaN compares false against every threshold, so without the explicit
        // test a NaN gradient norm would pass both gradient checks unnoticed.
        if gradient_norm.is_nan() || gradient_norm > self.gradient_explosion_threshold {
            self.anomaly_count += 1;
            return Some(TrainingAnomaly::GradientExplosion {
                norm: gradient_norm,
                threshold: self.gradient_explosion_threshold,
            });
        }

        // An exactly-zero norm is deliberately not treated as vanishing.
        if gradient_norm < self.gradient_vanishing_threshold && gradient_norm > 0.0 {
            self.anomaly_count += 1;
            return Some(TrainingAnomaly::GradientVanishing {
                norm: gradient_norm,
                threshold: self.gradient_vanishing_threshold,
            });
        }

        self.loss_stats.update(loss);
        if self.loss_stats.count > self.warmup_count {
            let z_score = self.loss_stats.z_score(loss);
            if z_score.abs() > self.threshold_sigma {
                self.anomaly_count += 1;
                return Some(TrainingAnomaly::LossSpike { z_score, loss });
            }
        }
        None
    }

    /// Number of anomalies reported since construction or the last `reset`.
    #[must_use]
    pub fn anomaly_count(&self) -> u64 {
        self.anomaly_count
    }

    /// Clears the loss statistics and the anomaly counter; thresholds and
    /// warm-up settings are kept.
    pub fn reset(&mut self) {
        self.loss_stats.reset();
        self.anomaly_count = 0;
    }
}
/// Events written to the simulation's `EventJournal`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrainEvent {
/// A training epoch was recorded; carries the full epoch snapshot.
Epoch(TrainingState),
/// An anomaly aborted a step; carries the anomaly's display text.
Anomaly(String),
/// A checkpoint at the given epoch (not emitted by the code visible here).
Checkpoint { epoch: u64 },
/// Early stopping triggered; carries the reported best epoch and the best validation loss seen.
EarlyStopping { best_epoch: u64, best_val_loss: f64 },
}
/// Seeded training simulation with anomaly detection, journaling and early stopping.
pub struct TrainingSimulation {
/// Hyperparameters for the run.
config: TrainingConfig,
/// Random number generator, seeded at construction or `reset`.
rng: SimRng,
/// Journal receiving epoch, anomaly and early-stopping events.
journal: EventJournal,
/// Detector consulted at the start of every `step`.
anomaly_detector: AnomalyDetector,
/// Index that will be assigned to the next recorded epoch.
current_epoch: u64,
/// States of all epochs recorded so far.
trajectory: TrainingTrajectory,
/// Lowest validation loss seen so far (`f64::INFINITY` until the first epoch).
best_val_loss: f64,
/// Consecutive recorded epochs whose validation loss did not beat `best_val_loss`.
epochs_without_improvement: usize,
}
impl TrainingSimulation {
    /// Creates a simulation seeded with `seed` using `TrainingConfig::default()`.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self::with_config(seed, TrainingConfig::default())
    }

    /// Creates a simulation seeded with `seed` using the supplied `config`.
    ///
    /// The simulation starts at epoch 0 with an empty trajectory, a fresh
    /// journal and an anomaly detector using a 3-sigma loss-spike threshold.
    #[must_use]
    pub fn with_config(seed: u64, config: TrainingConfig) -> Self {
        Self {
            config,
            rng: SimRng::new(seed),
            journal: EventJournal::new(true),
            anomaly_detector: AnomalyDetector::new(3.0),
            current_epoch: 0,
            trajectory: TrainingTrajectory::new(),
            best_val_loss: f64::INFINITY,
            epochs_without_improvement: 0,
        }
    }

    /// Replaces the anomaly detector used by subsequent steps.
    pub fn set_anomaly_detector(&mut self, detector: AnomalyDetector) {
        self.anomaly_detector = detector;
    }

    /// Returns the training configuration.
    #[must_use]
    pub fn config(&self) -> &TrainingConfig {
        &self.config
    }

    /// Returns the states of all epochs recorded so far.
    #[must_use]
    pub fn trajectory(&self) -> &TrainingTrajectory {
        &self.trajectory
    }

    /// Records one training epoch from the given `loss` and `gradient_norm`.
    ///
    /// The validation loss is synthesized as `loss` scaled by an RNG-drawn
    /// factor. The resulting state is journaled, appended to the trajectory,
    /// and the epoch counter advances.
    ///
    /// Returns `Ok(Some(state))` for a normal epoch and `Ok(None)` when the
    /// early-stopping patience has been reached; in the latter case the epoch
    /// has still been recorded and an `EarlyStopping` event is journaled.
    ///
    /// # Errors
    ///
    /// Returns a jidoka error when the anomaly detector flags the inputs. The
    /// anomaly is journaled, but no state is recorded and the epoch counter
    /// does not advance.
    pub fn step(&mut self, loss: f64, gradient_norm: f64) -> SimResult<Option<TrainingState>> {
        if let Some(anomaly) = self.anomaly_detector.check(loss, gradient_norm) {
            let event = TrainEvent::Anomaly(anomaly.to_string());
            let rng_state = self.rng.save_state();
            // Journaling is best-effort: a failed append must not mask the anomaly.
            let _ = self.journal.append(
                SimTime::from_secs(self.current_epoch as f64),
                self.current_epoch,
                &event,
                Some(&rng_state),
            );
            return Err(SimError::jidoka(format!(
                "Training anomaly at epoch {}: {anomaly}",
                self.current_epoch
            )));
        }

        let val_loss = loss * (1.0 + 0.1 * (self.rng.gen_f64() - 0.5));
        // Captured after the draw above so the snapshot reflects the consumed randomness.
        let rng_state = self.rng.save_state();
        let state = TrainingState {
            epoch: self.current_epoch,
            loss,
            val_loss,
            metrics: TrainingMetrics {
                train_loss: loss,
                val_loss,
                accuracy: None,
                gradient_norm,
                learning_rate: self.config.learning_rate,
                params_updated: 1000,
            },
            rng_state: rng_state.clone(),
        };

        if val_loss < self.best_val_loss {
            self.best_val_loss = val_loss;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }

        let event = TrainEvent::Epoch(state.clone());
        // Best-effort journaling, as above.
        let _ = self.journal.append(
            SimTime::from_secs(self.current_epoch as f64),
            self.current_epoch,
            &event,
            Some(&rng_state),
        );
        self.trajectory.push(state.clone());
        self.current_epoch += 1;

        if let Some(patience) = self.config.early_stopping {
            if self.epochs_without_improvement >= patience {
                // The best epoch is the one just recorded minus the number of
                // non-improving epochs since. `current_epoch` has already
                // advanced past the recorded epoch, hence the extra step back.
                // Saturating arithmetic because `replay_from` can rewind
                // `current_epoch` without resetting the stall counter.
                let last_epoch = self.current_epoch.saturating_sub(1);
                let best_epoch =
                    last_epoch.saturating_sub(self.epochs_without_improvement as u64);
                let event = TrainEvent::EarlyStopping {
                    best_epoch,
                    best_val_loss: self.best_val_loss,
                };
                let rng_state = self.rng.save_state();
                let _ = self.journal.append(
                    SimTime::from_secs(self.current_epoch as f64),
                    self.current_epoch,
                    &event,
                    Some(&rng_state),
                );
                return Ok(None);
            }
        }
        Ok(Some(state))
    }

    /// Runs up to `epochs` steps, obtaining `(loss, gradient_norm)` for each
    /// from `loss_fn(epoch_index, rng)`.
    ///
    /// `epoch_index` counts from 0 within this call. The loop ends early when
    /// `step` signals early stopping. Returns the accumulated trajectory.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `step`.
    pub fn simulate<F>(&mut self, epochs: u64, mut loss_fn: F) -> SimResult<&TrainingTrajectory>
    where
        F: FnMut(u64, &mut SimRng) -> (f64, f64),
    {
        contract_pre_iterator!();
        for epoch in 0..epochs {
            let (loss, grad_norm) = loss_fn(epoch, &mut self.rng);
            if self.step(loss, grad_norm)?.is_none() {
                break;
            }
        }
        Ok(&self.trajectory)
    }

    /// Restores the RNG state stored in `checkpoint` and sets the epoch
    /// counter to `checkpoint.epoch`.
    ///
    /// The trajectory, journal, best validation loss and stall counter are
    /// left untouched.
    ///
    /// # Errors
    ///
    /// Returns a config error when the RNG state cannot be restored.
    pub fn replay_from(&mut self, checkpoint: &TrainingState) -> SimResult<()> {
        self.rng
            .restore_state(&checkpoint.rng_state)
            .map_err(|e| SimError::config(format!("Failed to restore RNG state: {e}")))?;
        self.current_epoch = checkpoint.epoch;
        Ok(())
    }

    /// Returns the event journal.
    #[must_use]
    pub fn journal(&self) -> &EventJournal {
        &self.journal
    }

    /// Resets the simulation to its initial state with a new `seed`.
    ///
    /// The configuration and the anomaly detector's thresholds are kept; the
    /// detector's statistics, the journal, the trajectory and all counters
    /// are cleared.
    pub fn reset(&mut self, seed: u64) {
        self.rng = SimRng::new(seed);
        self.journal = EventJournal::new(true);
        self.anomaly_detector.reset();
        self.current_epoch = 0;
        self.trajectory = TrainingTrajectory::new();
        self.best_val_loss = f64::INFINITY;
        self.epochs_without_improvement = 0;
    }
}