1pub mod jidoka;
21pub mod multi_turn;
22pub mod prediction;
23
24#[cfg(test)]
25mod tests;
26
27pub use jidoka::*;
28pub use multi_turn::*;
29pub use prediction::*;
30
31use serde::{Deserialize, Serialize};
32
33use crate::engine::rng::{RngState, SimRng};
34use crate::engine::SimTime;
35use crate::error::{SimError, SimResult};
36use crate::replay::EventJournal;
37
/// Hyperparameters for a simulated training run.
///
/// NOTE(review): of these fields, only `learning_rate` and
/// `early_stopping` are read by `TrainingSimulation` in this file; the
/// rest are presumably consumed by the submodules — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Optimizer step size (copied into each epoch's metrics).
    pub learning_rate: f64,
    /// Samples per optimization step.
    pub batch_size: usize,
    /// Maximum number of epochs.
    pub epochs: u64,
    /// Early-stopping patience in epochs without validation improvement;
    /// `None` disables early stopping.
    pub early_stopping: Option<usize>,
    /// Gradient-norm clipping threshold; `None` disables clipping.
    pub gradient_clip: Option<f64>,
    /// L2 regularization coefficient.
    pub weight_decay: f64,
}
58
59impl Default for TrainingConfig {
60 fn default() -> Self {
61 Self {
62 learning_rate: 0.001,
63 batch_size: 32,
64 epochs: 100,
65 early_stopping: Some(10),
66 gradient_clip: Some(1.0),
67 weight_decay: 0.0001,
68 }
69 }
70}
71
/// Snapshot of the training loop after one completed epoch.
///
/// Carries the RNG state captured at the end of the epoch so a run can be
/// resumed deterministically (see `TrainingSimulation::replay_from`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState {
    /// Epoch index this snapshot belongs to (0-based).
    pub epoch: u64,
    /// Training loss for this epoch.
    pub loss: f64,
    /// Validation loss for this epoch.
    pub val_loss: f64,
    /// Detailed metrics (duplicates `loss`/`val_loss` for convenience).
    pub metrics: TrainingMetrics,
    /// RNG state saved after this epoch's random draws, enabling replay.
    pub rng_state: RngState,
}
86
/// Per-epoch training metrics recorded inside a [`TrainingState`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Training-set loss.
    pub train_loss: f64,
    /// Validation-set loss.
    pub val_loss: f64,
    /// Optional accuracy; `None` when not measured (always `None` in the
    /// simulation code visible in this file).
    pub accuracy: Option<f64>,
    /// L2 norm of the gradients for this epoch.
    pub gradient_norm: f64,
    /// Learning rate in effect for this epoch.
    pub learning_rate: f64,
    /// Number of parameters updated this epoch.
    pub params_updated: usize,
}
103
104impl Default for TrainingMetrics {
105 fn default() -> Self {
106 Self {
107 train_loss: 0.0,
108 val_loss: 0.0,
109 accuracy: None,
110 gradient_norm: 0.0,
111 learning_rate: 0.001,
112 params_updated: 0,
113 }
114 }
115}
116
/// Ordered record of per-epoch [`TrainingState`] snapshots, in the order
/// they were pushed.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingTrajectory {
    /// Recorded states, oldest first.
    pub states: Vec<TrainingState>,
}
123
124impl TrainingTrajectory {
125 #[must_use]
127 pub fn new() -> Self {
128 Self { states: Vec::new() }
129 }
130
131 pub fn push(&mut self, state: TrainingState) {
133 self.states.push(state);
134 }
135
136 #[must_use]
138 pub fn final_state(&self) -> Option<&TrainingState> {
139 self.states.last()
140 }
141
142 #[must_use]
144 pub fn best_val_loss(&self) -> Option<f64> {
145 self.states
146 .iter()
147 .map(|s| s.val_loss)
148 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
149 }
150
151 #[must_use]
153 pub fn converged(&self, tolerance: f64) -> bool {
154 if self.states.len() < 10 {
155 return false;
156 }
157 let recent: Vec<f64> = self.states.iter().rev().take(10).map(|s| s.loss).collect();
158 let mean = recent.iter().sum::<f64>() / recent.len() as f64;
159 let variance = recent.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / recent.len() as f64;
160 variance.sqrt() < tolerance
161 }
162}
163
/// Anomalies that can be flagged during a training run.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum TrainingAnomaly {
    /// Loss was NaN or infinite.
    NonFiniteLoss,
    /// Gradient norm exceeded the explosion threshold.
    GradientExplosion {
        /// Observed gradient norm.
        norm: f64,
        /// Threshold that was exceeded.
        threshold: f64,
    },
    /// Gradient norm was positive but below the vanishing threshold.
    GradientVanishing {
        /// Observed gradient norm.
        norm: f64,
        /// Threshold that was undercut.
        threshold: f64,
    },
    /// Loss deviated from the rolling mean by more than the configured
    /// number of standard deviations.
    LossSpike {
        /// Standard score of the offending loss.
        z_score: f64,
        /// The offending loss value.
        loss: f64,
    },
    /// Prediction confidence fell below a threshold.
    /// NOTE(review): not produced by `AnomalyDetector::check` in this
    /// file; presumably emitted by the `prediction`/`jidoka` submodules —
    /// confirm.
    LowConfidence {
        /// Observed confidence.
        confidence: f64,
        /// Minimum acceptable confidence.
        threshold: f64,
    },
}
202
203impl std::fmt::Display for TrainingAnomaly {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 match self {
206 Self::NonFiniteLoss => write!(f, "Non-finite loss detected (NaN/Inf)"),
207 Self::GradientExplosion { norm, threshold } => {
208 write!(
209 f,
210 "Gradient explosion: norm={norm:.2e} > threshold={threshold:.2e}"
211 )
212 }
213 Self::GradientVanishing { norm, threshold } => {
214 write!(
215 f,
216 "Gradient vanishing: norm={norm:.2e} < threshold={threshold:.2e}"
217 )
218 }
219 Self::LossSpike { z_score, loss } => {
220 write!(f, "Loss spike: z-score={z_score:.2}, loss={loss:.4}")
221 }
222 Self::LowConfidence {
223 confidence,
224 threshold,
225 } => {
226 write!(
227 f,
228 "Low confidence: {confidence:.4} < threshold={threshold:.4}"
229 )
230 }
231 }
232 }
233}
234
/// Streaming mean/variance accumulator (Welford's online algorithm) with
/// an optional bounded buffer of the most recent observations.
#[derive(Debug, Clone, Default)]
pub struct RollingStats {
    count: u64,
    mean: f64,
    m2: f64,
    window_size: usize,
    recent: Vec<f64>,
}

impl RollingStats {
    /// Creates an empty accumulator. `window_size` bounds the `recent`
    /// buffer; 0 disables buffering entirely.
    #[must_use]
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            ..Self::default()
        }
    }

    /// Folds one observation into the running statistics.
    ///
    /// NOTE(review): `recent` is filled here but nothing in this file
    /// reads it back; presumably kept for tests or future windowed stats.
    pub fn update(&mut self, value: f64) {
        self.count += 1;
        // Welford: delta before and after the mean update.
        let d_before = value - self.mean;
        self.mean += d_before / self.count as f64;
        let d_after = value - self.mean;
        self.m2 += d_before * d_after;

        if self.window_size > 0 {
            self.recent.push(value);
            if self.recent.len() > self.window_size {
                // O(window) shift; windows here are small (~100).
                self.recent.remove(0);
            }
        }
    }

    /// Running mean of all observations (0.0 before any update).
    #[must_use]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Unbiased (Bessel-corrected) sample variance; 0.0 with fewer than
    /// two observations.
    #[must_use]
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / (n - 1) as f64,
        }
    }

    /// Sample standard deviation.
    #[must_use]
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Standard score of `value` against the running distribution;
    /// returns 0.0 when the spread is (near) zero to avoid blow-up.
    #[must_use]
    pub fn z_score(&self, value: f64) -> f64 {
        let sd = self.std_dev();
        if sd < 1e-10 {
            0.0
        } else {
            (value - self.mean) / sd
        }
    }

    /// Clears all accumulated state; `window_size` is preserved.
    pub fn reset(&mut self) {
        self.count = 0;
        self.mean = 0.0;
        self.m2 = 0.0;
        self.recent.clear();
    }
}
318
/// Online detector for training anomalies: non-finite losses, gradient
/// explosion/vanishing, and statistical loss spikes.
#[derive(Debug, Clone)]
pub struct AnomalyDetector {
    // Rolling loss statistics backing spike detection.
    loss_stats: RollingStats,
    // Spike threshold in standard deviations (|z| above this flags).
    threshold_sigma: f64,
    // Gradient norms above this flag GradientExplosion.
    gradient_explosion_threshold: f64,
    // Positive gradient norms below this flag GradientVanishing.
    gradient_vanishing_threshold: f64,
    // Samples to accumulate before spike detection activates.
    warmup_count: u64,
    // Total anomalies flagged since construction or last reset.
    anomaly_count: u64,
}
335
336impl AnomalyDetector {
337 #[must_use]
339 pub fn new(threshold_sigma: f64) -> Self {
340 Self {
341 loss_stats: RollingStats::new(100),
342 threshold_sigma,
343 gradient_explosion_threshold: 1e6,
344 gradient_vanishing_threshold: 1e-10,
345 warmup_count: 10,
346 anomaly_count: 0,
347 }
348 }
349
350 #[must_use]
352 pub fn with_gradient_explosion_threshold(mut self, threshold: f64) -> Self {
353 self.gradient_explosion_threshold = threshold;
354 self
355 }
356
357 #[must_use]
359 pub fn with_gradient_vanishing_threshold(mut self, threshold: f64) -> Self {
360 self.gradient_vanishing_threshold = threshold;
361 self
362 }
363
364 #[must_use]
366 pub fn with_warmup(mut self, count: u64) -> Self {
367 self.warmup_count = count;
368 self
369 }
370
371 pub fn check(&mut self, loss: f64, gradient_norm: f64) -> Option<TrainingAnomaly> {
373 if !loss.is_finite() {
375 self.anomaly_count += 1;
376 return Some(TrainingAnomaly::NonFiniteLoss);
377 }
378
379 if gradient_norm > self.gradient_explosion_threshold {
381 self.anomaly_count += 1;
382 return Some(TrainingAnomaly::GradientExplosion {
383 norm: gradient_norm,
384 threshold: self.gradient_explosion_threshold,
385 });
386 }
387
388 if gradient_norm < self.gradient_vanishing_threshold && gradient_norm > 0.0 {
390 self.anomaly_count += 1;
391 return Some(TrainingAnomaly::GradientVanishing {
392 norm: gradient_norm,
393 threshold: self.gradient_vanishing_threshold,
394 });
395 }
396
397 self.loss_stats.update(loss);
399 if self.loss_stats.count > self.warmup_count {
400 let z_score = self.loss_stats.z_score(loss);
401 if z_score.abs() > self.threshold_sigma {
402 self.anomaly_count += 1;
403 return Some(TrainingAnomaly::LossSpike { z_score, loss });
404 }
405 }
406
407 None
408 }
409
410 #[must_use]
412 pub fn anomaly_count(&self) -> u64 {
413 self.anomaly_count
414 }
415
416 pub fn reset(&mut self) {
418 self.loss_stats.reset();
419 self.anomaly_count = 0;
420 }
421}
422
/// Events appended to the [`EventJournal`] during a simulated run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrainEvent {
    /// A completed epoch with its full state snapshot.
    Epoch(TrainingState),
    /// A detected anomaly, stored as its display string.
    Anomaly(String),
    /// A checkpoint marker at the given epoch.
    /// NOTE(review): nothing in this file emits this variant — presumably
    /// produced elsewhere; confirm.
    Checkpoint { epoch: u64 },
    /// Early stopping fired; records the best epoch and its val loss.
    EarlyStopping { best_epoch: u64, best_val_loss: f64 },
}
439
/// Deterministic, replayable training-loop simulator with jidoka-style
/// anomaly stops, an event journal, and optional early stopping.
pub struct TrainingSimulation {
    // Hyperparameters; `learning_rate` and `early_stopping` are the
    // fields read by the code in this file.
    config: TrainingConfig,
    // Seeded RNG; its state is checkpointed into every TrainingState.
    rng: SimRng,
    // Append-only record of epoch / anomaly / early-stop events.
    journal: EventJournal,
    // Flags non-finite losses, gradient problems, and loss spikes.
    anomaly_detector: AnomalyDetector,
    // Index of the next epoch to run (0-based).
    current_epoch: u64,
    // All per-epoch states recorded so far.
    trajectory: TrainingTrajectory,
    // Lowest validation loss seen (f64::INFINITY before any epoch).
    best_val_loss: f64,
    // Consecutive epochs without validation improvement (early stopping).
    epochs_without_improvement: usize,
}
464
465impl TrainingSimulation {
466 #[must_use]
468 pub fn new(seed: u64) -> Self {
469 Self {
470 config: TrainingConfig::default(),
471 rng: SimRng::new(seed),
472 journal: EventJournal::new(true), anomaly_detector: AnomalyDetector::new(3.0), current_epoch: 0,
475 trajectory: TrainingTrajectory::new(),
476 best_val_loss: f64::INFINITY,
477 epochs_without_improvement: 0,
478 }
479 }
480
481 #[must_use]
483 pub fn with_config(seed: u64, config: TrainingConfig) -> Self {
484 Self {
485 config,
486 rng: SimRng::new(seed),
487 journal: EventJournal::new(true), anomaly_detector: AnomalyDetector::new(3.0),
489 current_epoch: 0,
490 trajectory: TrainingTrajectory::new(),
491 best_val_loss: f64::INFINITY,
492 epochs_without_improvement: 0,
493 }
494 }
495
496 pub fn set_anomaly_detector(&mut self, detector: AnomalyDetector) {
498 self.anomaly_detector = detector;
499 }
500
501 #[must_use]
503 pub fn config(&self) -> &TrainingConfig {
504 &self.config
505 }
506
507 #[must_use]
509 pub fn trajectory(&self) -> &TrainingTrajectory {
510 &self.trajectory
511 }
512
513 pub fn step(&mut self, loss: f64, gradient_norm: f64) -> SimResult<Option<TrainingState>> {
523 if let Some(anomaly) = self.anomaly_detector.check(loss, gradient_norm) {
525 let event = TrainEvent::Anomaly(anomaly.to_string());
526 let rng_state = self.rng.save_state();
527 let _ = self.journal.append(
528 SimTime::from_secs(self.current_epoch as f64),
529 self.current_epoch,
530 &event,
531 Some(&rng_state),
532 );
533 return Err(SimError::jidoka(format!(
534 "Training anomaly at epoch {}: {anomaly}",
535 self.current_epoch
536 )));
537 }
538
539 let val_loss = loss * (1.0 + 0.1 * (self.rng.gen_f64() - 0.5));
541
542 let rng_state = self.rng.save_state();
544 let state = TrainingState {
545 epoch: self.current_epoch,
546 loss,
547 val_loss,
548 metrics: TrainingMetrics {
549 train_loss: loss,
550 val_loss,
551 accuracy: None,
552 gradient_norm,
553 learning_rate: self.config.learning_rate,
554 params_updated: 1000, },
556 rng_state: rng_state.clone(),
557 };
558
559 if val_loss < self.best_val_loss {
561 self.best_val_loss = val_loss;
562 self.epochs_without_improvement = 0;
563 } else {
564 self.epochs_without_improvement += 1;
565 }
566
567 let event = TrainEvent::Epoch(state.clone());
569 let _ = self.journal.append(
570 SimTime::from_secs(self.current_epoch as f64),
571 self.current_epoch,
572 &event,
573 Some(&rng_state),
574 );
575 self.trajectory.push(state.clone());
576
577 self.current_epoch += 1;
578
579 if let Some(patience) = self.config.early_stopping {
581 if self.epochs_without_improvement >= patience {
582 let event = TrainEvent::EarlyStopping {
583 best_epoch: self.current_epoch - patience as u64,
584 best_val_loss: self.best_val_loss,
585 };
586 let rng_state = self.rng.save_state();
587 let _ = self.journal.append(
588 SimTime::from_secs(self.current_epoch as f64),
589 self.current_epoch,
590 &event,
591 Some(&rng_state),
592 );
593 return Ok(None); }
595 }
596
597 Ok(Some(state))
598 }
599
600 pub fn simulate<F>(&mut self, epochs: u64, mut loss_fn: F) -> SimResult<&TrainingTrajectory>
608 where
609 F: FnMut(u64, &mut SimRng) -> (f64, f64),
610 {
611 contract_pre_iterator!();
612 for epoch in 0..epochs {
613 let (loss, grad_norm) = loss_fn(epoch, &mut self.rng);
614 if self.step(loss, grad_norm)?.is_none() {
615 break; }
617 }
618 Ok(&self.trajectory)
619 }
620
621 pub fn replay_from(&mut self, checkpoint: &TrainingState) -> SimResult<()> {
627 self.rng
628 .restore_state(&checkpoint.rng_state)
629 .map_err(|e| SimError::config(format!("Failed to restore RNG state: {e}")))?;
630 self.current_epoch = checkpoint.epoch;
631 Ok(())
632 }
633
634 #[must_use]
636 pub fn journal(&self) -> &EventJournal {
637 &self.journal
638 }
639
640 pub fn reset(&mut self, seed: u64) {
642 self.rng = SimRng::new(seed);
643 self.journal = EventJournal::new(true);
644 self.anomaly_detector.reset();
645 self.current_epoch = 0;
646 self.trajectory = TrainingTrajectory::new();
647 self.best_val_loss = f64::INFINITY;
648 self.epochs_without_improvement = 0;
649 }
650}