// simular/domains/ml/mod.rs

1//! Machine Learning Simulation Engine.
2//!
3//! Provides deterministic, reproducible simulation of ML workflows using
4//! Popperian falsification methodology. Implements TPS principles:
5//! - Jidoka: Stop-on-anomaly detection
6//! - Heijunka: Load-balanced batch processing
7//! - Kaizen: Continuous improvement via feedback
8//!
9//! # Example
10//!
11//! ```rust
12//! use simular::domains::ml::{TrainingSimulation, TrainingConfig, AnomalyDetector};
13//! use simular::engine::rng::SimRng;
14//!
15//! let mut sim = TrainingSimulation::new(42);
16//! let config = TrainingConfig::default();
17//! // Training simulation would run here
18//! ```
19
20pub mod jidoka;
21pub mod multi_turn;
22pub mod prediction;
23
24#[cfg(test)]
25mod tests;
26
27pub use jidoka::*;
28pub use multi_turn::*;
29pub use prediction::*;
30
31use serde::{Deserialize, Serialize};
32
33use crate::engine::rng::{RngState, SimRng};
34use crate::engine::SimTime;
35use crate::error::{SimError, SimResult};
36use crate::replay::EventJournal;
37
38// ============================================================================
39// Training Simulation Types
40// ============================================================================
41
/// Training hyperparameters configuration.
///
/// Defaults are supplied by the `Default` impl (lr 1e-3, batch 32,
/// 100 epochs, patience 10, clip 1.0, weight decay 1e-4).
///
/// NOTE(review): within this module only `learning_rate` and
/// `early_stopping` are read by `TrainingSimulation::step`; the remaining
/// fields are presumably consumed by the submodules — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Learning rate.
    pub learning_rate: f64,
    /// Batch size for training.
    pub batch_size: usize,
    /// Number of epochs.
    pub epochs: u64,
    /// Early stopping patience in epochs (None = disabled).
    pub early_stopping: Option<usize>,
    /// Gradient clipping max norm (None = disabled).
    pub gradient_clip: Option<f64>,
    /// Weight decay (L2 regularization) coefficient.
    pub weight_decay: f64,
}
58
59impl Default for TrainingConfig {
60    fn default() -> Self {
61        Self {
62            learning_rate: 0.001,
63            batch_size: 32,
64            epochs: 100,
65            early_stopping: Some(10),
66            gradient_clip: Some(1.0),
67            weight_decay: 0.0001,
68        }
69    }
70}
71
/// Training state captured at each epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState {
    /// Epoch index this state was captured at (0-based).
    pub epoch: u64,
    /// Training loss for this epoch.
    pub loss: f64,
    /// Validation loss for this epoch (also duplicated in `metrics.val_loss`).
    pub val_loss: f64,
    /// Full metrics snapshot for this epoch.
    pub metrics: TrainingMetrics,
    /// RNG state snapshot enabling exact replay from this epoch
    /// (see `TrainingSimulation::replay_from`).
    pub rng_state: RngState,
}
86
/// Training metrics collected during simulation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Training loss.
    pub train_loss: f64,
    /// Validation loss.
    pub val_loss: f64,
    /// Accuracy for classification tasks; `None` when not applicable.
    pub accuracy: Option<f64>,
    /// Gradient L2 norm.
    pub gradient_norm: f64,
    /// Current learning rate (after any scheduling).
    pub learning_rate: f64,
    /// Number of parameters updated this epoch.
    pub params_updated: usize,
}
103
104impl Default for TrainingMetrics {
105    fn default() -> Self {
106        Self {
107            train_loss: 0.0,
108            val_loss: 0.0,
109            accuracy: None,
110            gradient_norm: 0.0,
111            learning_rate: 0.001,
112            params_updated: 0,
113        }
114    }
115}
116
/// Training trajectory - an append-only sequence of per-epoch
/// `TrainingState` snapshots.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingTrajectory {
    /// Per-epoch states, in the order they were recorded.
    pub states: Vec<TrainingState>,
}
123
124impl TrainingTrajectory {
125    /// Create new empty trajectory.
126    #[must_use]
127    pub fn new() -> Self {
128        Self { states: Vec::new() }
129    }
130
131    /// Add a state to the trajectory.
132    pub fn push(&mut self, state: TrainingState) {
133        self.states.push(state);
134    }
135
136    /// Get the final training state.
137    #[must_use]
138    pub fn final_state(&self) -> Option<&TrainingState> {
139        self.states.last()
140    }
141
142    /// Get best validation loss achieved.
143    #[must_use]
144    pub fn best_val_loss(&self) -> Option<f64> {
145        self.states
146            .iter()
147            .map(|s| s.val_loss)
148            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
149    }
150
151    /// Check if training converged (loss stabilized).
152    #[must_use]
153    pub fn converged(&self, tolerance: f64) -> bool {
154        if self.states.len() < 10 {
155            return false;
156        }
157        let recent: Vec<f64> = self.states.iter().rev().take(10).map(|s| s.loss).collect();
158        let mean = recent.iter().sum::<f64>() / recent.len() as f64;
159        let variance = recent.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / recent.len() as f64;
160        variance.sqrt() < tolerance
161    }
162}
163
164// ============================================================================
165// Anomaly Detection (Jidoka)
166// ============================================================================
167
/// Training anomaly types for Jidoka detection.
///
/// Each variant carries the observed value(s) and, where applicable, the
/// threshold that was violated; `Display` renders a one-line summary.
///
/// NOTE(review): `LowConfidence` is not emitted by `AnomalyDetector::check`
/// in this module — presumably produced by the prediction submodule; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum TrainingAnomaly {
    /// Loss became NaN or Infinity.
    NonFiniteLoss,
    /// Gradient norm exceeded threshold.
    GradientExplosion {
        /// Observed gradient norm.
        norm: f64,
        /// Threshold that was exceeded.
        threshold: f64,
    },
    /// Gradient norm fell below threshold (while still positive).
    GradientVanishing {
        /// Observed gradient norm.
        norm: f64,
        /// Threshold that was violated.
        threshold: f64,
    },
    /// Loss spike detected (statistical outlier vs. rolling loss stats).
    LossSpike {
        /// Z-score of the spike.
        z_score: f64,
        /// Actual loss value.
        loss: f64,
    },
    /// Prediction confidence below threshold.
    LowConfidence {
        /// Observed confidence.
        confidence: f64,
        /// Required threshold.
        threshold: f64,
    },
}
202
203impl std::fmt::Display for TrainingAnomaly {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        match self {
206            Self::NonFiniteLoss => write!(f, "Non-finite loss detected (NaN/Inf)"),
207            Self::GradientExplosion { norm, threshold } => {
208                write!(
209                    f,
210                    "Gradient explosion: norm={norm:.2e} > threshold={threshold:.2e}"
211                )
212            }
213            Self::GradientVanishing { norm, threshold } => {
214                write!(
215                    f,
216                    "Gradient vanishing: norm={norm:.2e} < threshold={threshold:.2e}"
217                )
218            }
219            Self::LossSpike { z_score, loss } => {
220                write!(f, "Loss spike: z-score={z_score:.2}, loss={loss:.4}")
221            }
222            Self::LowConfidence {
223                confidence,
224                threshold,
225            } => {
226                write!(
227                    f,
228                    "Low confidence: {confidence:.4} < threshold={threshold:.4}"
229                )
230            }
231        }
232    }
233}
234
/// Rolling statistics accumulator using Welford's online algorithm.
///
/// `mean`/`variance`/`z_score` are computed over *all* observations seen
/// since the last reset. When `window_size > 0`, a bounded buffer of the
/// most recent values is additionally maintained.
///
/// NOTE(review): the `recent` window is maintained but never consulted by
/// any of the statistics methods, which always use the all-time Welford
/// accumulators — confirm whether windowed statistics were intended.
#[derive(Debug, Clone, Default)]
pub struct RollingStats {
    /// Number of observations since the last reset.
    count: u64,
    /// Running mean (Welford).
    mean: f64,
    /// Running sum of squared deviations from the mean (Welford's M2).
    m2: f64,
    /// Window size (0 = unlimited; window maintenance disabled).
    window_size: usize,
    /// Most recent `window_size` values (currently unused by the stats;
    /// see type-level note). Deque so eviction from the front is O(1).
    recent: std::collections::VecDeque<f64>,
}

impl RollingStats {
    /// Create new rolling stats; `window_size == 0` disables the window.
    #[must_use]
    pub fn new(window_size: usize) -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
            window_size,
            recent: std::collections::VecDeque::new(),
        }
    }

    /// Record a new observation (Welford's numerically stable update).
    pub fn update(&mut self, value: f64) {
        self.count += 1;
        let delta = value - self.mean;
        self.mean += delta / self.count as f64;
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;

        if self.window_size > 0 {
            self.recent.push_back(value);
            if self.recent.len() > self.window_size {
                // O(1) eviction; the previous Vec::remove(0) was O(n).
                self.recent.pop_front();
            }
        }
    }

    /// Get current mean (0.0 before any observations).
    #[must_use]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Get current sample variance (Bessel-corrected, n - 1 denominator).
    /// Returns 0.0 with fewer than two observations.
    #[must_use]
    pub fn variance(&self) -> f64 {
        if self.count < 2 {
            return 0.0;
        }
        self.m2 / (self.count - 1) as f64
    }

    /// Get current standard deviation.
    #[must_use]
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Compute the z-score of `value` against the accumulated statistics.
    /// Returns 0.0 when the standard deviation is (near) zero, to avoid
    /// dividing by zero on degenerate/constant data.
    #[must_use]
    pub fn z_score(&self, value: f64) -> f64 {
        let std = self.std_dev();
        if std < 1e-10 {
            return 0.0;
        }
        (value - self.mean) / std
    }

    /// Reset all statistics; `window_size` is preserved.
    pub fn reset(&mut self) {
        self.count = 0;
        self.mean = 0.0;
        self.m2 = 0.0;
        self.recent.clear();
    }
}
318
/// Anomaly detector for Jidoka-style training quality gates.
///
/// Defaults set by `new`: explosion threshold 1e6, vanishing threshold
/// 1e-10, warmup of 10 observations, loss stats with window 100.
#[derive(Debug, Clone)]
pub struct AnomalyDetector {
    /// Rolling statistics for loss values (feeds z-score spike detection).
    loss_stats: RollingStats,
    /// Loss-spike threshold in standard deviations (z-score magnitude).
    threshold_sigma: f64,
    /// Gradient norm above this triggers `GradientExplosion`.
    gradient_explosion_threshold: f64,
    /// Positive gradient norm below this triggers `GradientVanishing`.
    gradient_vanishing_threshold: f64,
    /// Observations required before loss-spike detection activates;
    /// non-finite loss and gradient checks are active from the start.
    warmup_count: u64,
    /// Total anomalies detected since the last reset.
    anomaly_count: u64,
}
335
336impl AnomalyDetector {
337    /// Create new anomaly detector with sigma threshold.
338    #[must_use]
339    pub fn new(threshold_sigma: f64) -> Self {
340        Self {
341            loss_stats: RollingStats::new(100),
342            threshold_sigma,
343            gradient_explosion_threshold: 1e6,
344            gradient_vanishing_threshold: 1e-10,
345            warmup_count: 10,
346            anomaly_count: 0,
347        }
348    }
349
350    /// Set gradient explosion threshold.
351    #[must_use]
352    pub fn with_gradient_explosion_threshold(mut self, threshold: f64) -> Self {
353        self.gradient_explosion_threshold = threshold;
354        self
355    }
356
357    /// Set gradient vanishing threshold.
358    #[must_use]
359    pub fn with_gradient_vanishing_threshold(mut self, threshold: f64) -> Self {
360        self.gradient_vanishing_threshold = threshold;
361        self
362    }
363
364    /// Set warmup count before anomaly detection activates.
365    #[must_use]
366    pub fn with_warmup(mut self, count: u64) -> Self {
367        self.warmup_count = count;
368        self
369    }
370
371    /// Check for training anomalies given loss and gradient norm.
372    pub fn check(&mut self, loss: f64, gradient_norm: f64) -> Option<TrainingAnomaly> {
373        // NaN/Inf detection (Poka-Yoke) - always active
374        if !loss.is_finite() {
375            self.anomaly_count += 1;
376            return Some(TrainingAnomaly::NonFiniteLoss);
377        }
378
379        // Gradient explosion detection
380        if gradient_norm > self.gradient_explosion_threshold {
381            self.anomaly_count += 1;
382            return Some(TrainingAnomaly::GradientExplosion {
383                norm: gradient_norm,
384                threshold: self.gradient_explosion_threshold,
385            });
386        }
387
388        // Gradient vanishing detection
389        if gradient_norm < self.gradient_vanishing_threshold && gradient_norm > 0.0 {
390            self.anomaly_count += 1;
391            return Some(TrainingAnomaly::GradientVanishing {
392                norm: gradient_norm,
393                threshold: self.gradient_vanishing_threshold,
394            });
395        }
396
397        // Loss spike detection (statistical) - only after warmup
398        self.loss_stats.update(loss);
399        if self.loss_stats.count > self.warmup_count {
400            let z_score = self.loss_stats.z_score(loss);
401            if z_score.abs() > self.threshold_sigma {
402                self.anomaly_count += 1;
403                return Some(TrainingAnomaly::LossSpike { z_score, loss });
404            }
405        }
406
407        None
408    }
409
410    /// Get number of anomalies detected.
411    #[must_use]
412    pub fn anomaly_count(&self) -> u64 {
413        self.anomaly_count
414    }
415
416    /// Reset detector state.
417    pub fn reset(&mut self) {
418        self.loss_stats.reset();
419        self.anomaly_count = 0;
420    }
421}
422
423// ============================================================================
424// Training Simulation
425// ============================================================================
426
/// Simulated training event for journaling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrainEvent {
    /// Epoch completed.
    Epoch(TrainingState),
    /// Anomaly detected (human-readable description from `Display`).
    Anomaly(String),
    /// Checkpoint created.
    /// NOTE(review): not emitted by `TrainingSimulation` in this module —
    /// presumably journaled elsewhere; confirm.
    Checkpoint {
        /// Epoch at which the checkpoint was taken.
        epoch: u64,
    },
    /// Early stopping triggered.
    EarlyStopping {
        /// Epoch credited with the best validation loss.
        best_epoch: u64,
        /// Best validation loss observed.
        best_val_loss: f64,
    },
}
439
/// Simulated training scenario for reproducible ML experiments.
///
/// Implements Toyota Way principles:
/// - Jidoka: Stop-on-anomaly via `AnomalyDetector`
/// - Heijunka: Load-balanced batch iteration
/// - Kaizen: Continuous improvement tracking
pub struct TrainingSimulation {
    /// Training hyperparameters.
    config: TrainingConfig,
    /// Deterministic RNG for reproducibility.
    rng: SimRng,
    /// Event journal (epochs, anomalies, early stopping) for replay.
    journal: EventJournal,
    /// Anomaly detector (Jidoka quality gate).
    anomaly_detector: AnomalyDetector,
    /// Next epoch index; incremented after each successfully recorded step.
    current_epoch: u64,
    /// Per-epoch training state history.
    trajectory: TrainingTrajectory,
    /// Best (lowest) validation loss seen, for early stopping.
    best_val_loss: f64,
    /// Consecutive epochs without a validation-loss improvement.
    epochs_without_improvement: usize,
}
464
465impl TrainingSimulation {
466    /// Create new training simulation with deterministic seed.
467    #[must_use]
468    pub fn new(seed: u64) -> Self {
469        Self {
470            config: TrainingConfig::default(),
471            rng: SimRng::new(seed),
472            journal: EventJournal::new(true), // Record RNG state
473            anomaly_detector: AnomalyDetector::new(3.0), // 3σ threshold
474            current_epoch: 0,
475            trajectory: TrainingTrajectory::new(),
476            best_val_loss: f64::INFINITY,
477            epochs_without_improvement: 0,
478        }
479    }
480
481    /// Create with custom configuration.
482    #[must_use]
483    pub fn with_config(seed: u64, config: TrainingConfig) -> Self {
484        Self {
485            config,
486            rng: SimRng::new(seed),
487            journal: EventJournal::new(true), // Record RNG state
488            anomaly_detector: AnomalyDetector::new(3.0),
489            current_epoch: 0,
490            trajectory: TrainingTrajectory::new(),
491            best_val_loss: f64::INFINITY,
492            epochs_without_improvement: 0,
493        }
494    }
495
496    /// Set anomaly detector.
497    pub fn set_anomaly_detector(&mut self, detector: AnomalyDetector) {
498        self.anomaly_detector = detector;
499    }
500
501    /// Get current training configuration.
502    #[must_use]
503    pub fn config(&self) -> &TrainingConfig {
504        &self.config
505    }
506
507    /// Get current trajectory.
508    #[must_use]
509    pub fn trajectory(&self) -> &TrainingTrajectory {
510        &self.trajectory
511    }
512
513    /// Simulate a single training step with given loss and gradient norm.
514    ///
515    /// This is a simplified simulation - real training would compute actual
516    /// forward/backward passes. This enables testing training dynamics
517    /// without actual model computation.
518    ///
519    /// # Errors
520    ///
521    /// Returns error if a training anomaly is detected (Jidoka).
522    pub fn step(&mut self, loss: f64, gradient_norm: f64) -> SimResult<Option<TrainingState>> {
523        // Jidoka: Check for anomalies
524        if let Some(anomaly) = self.anomaly_detector.check(loss, gradient_norm) {
525            let event = TrainEvent::Anomaly(anomaly.to_string());
526            let rng_state = self.rng.save_state();
527            let _ = self.journal.append(
528                SimTime::from_secs(self.current_epoch as f64),
529                self.current_epoch,
530                &event,
531                Some(&rng_state),
532            );
533            return Err(SimError::jidoka(format!(
534                "Training anomaly at epoch {}: {anomaly}",
535                self.current_epoch
536            )));
537        }
538
539        // Simulate validation loss (simplified: add noise to training loss)
540        let val_loss = loss * (1.0 + 0.1 * (self.rng.gen_f64() - 0.5));
541
542        // Create training state
543        let rng_state = self.rng.save_state();
544        let state = TrainingState {
545            epoch: self.current_epoch,
546            loss,
547            val_loss,
548            metrics: TrainingMetrics {
549                train_loss: loss,
550                val_loss,
551                accuracy: None,
552                gradient_norm,
553                learning_rate: self.config.learning_rate,
554                params_updated: 1000, // Simulated
555            },
556            rng_state: rng_state.clone(),
557        };
558
559        // Track best validation loss for early stopping
560        if val_loss < self.best_val_loss {
561            self.best_val_loss = val_loss;
562            self.epochs_without_improvement = 0;
563        } else {
564            self.epochs_without_improvement += 1;
565        }
566
567        // Record in journal and trajectory
568        let event = TrainEvent::Epoch(state.clone());
569        let _ = self.journal.append(
570            SimTime::from_secs(self.current_epoch as f64),
571            self.current_epoch,
572            &event,
573            Some(&rng_state),
574        );
575        self.trajectory.push(state.clone());
576
577        self.current_epoch += 1;
578
579        // Check early stopping
580        if let Some(patience) = self.config.early_stopping {
581            if self.epochs_without_improvement >= patience {
582                let event = TrainEvent::EarlyStopping {
583                    best_epoch: self.current_epoch - patience as u64,
584                    best_val_loss: self.best_val_loss,
585                };
586                let rng_state = self.rng.save_state();
587                let _ = self.journal.append(
588                    SimTime::from_secs(self.current_epoch as f64),
589                    self.current_epoch,
590                    &event,
591                    Some(&rng_state),
592                );
593                return Ok(None); // Signal early stopping
594            }
595        }
596
597        Ok(Some(state))
598    }
599
600    /// Simulate training for specified epochs using a loss function.
601    ///
602    /// The `loss_fn` takes (epoch, rng) and returns (loss, `gradient_norm`).
603    ///
604    /// # Errors
605    ///
606    /// Returns error if a training anomaly is detected.
607    pub fn simulate<F>(&mut self, epochs: u64, mut loss_fn: F) -> SimResult<&TrainingTrajectory>
608    where
609        F: FnMut(u64, &mut SimRng) -> (f64, f64),
610    {
611        contract_pre_iterator!();
612        for epoch in 0..epochs {
613            let (loss, grad_norm) = loss_fn(epoch, &mut self.rng);
614            if self.step(loss, grad_norm)?.is_none() {
615                break; // Early stopping
616            }
617        }
618        Ok(&self.trajectory)
619    }
620
621    /// Replay training from a checkpoint state.
622    ///
623    /// # Errors
624    ///
625    /// Returns error if RNG state restoration fails.
626    pub fn replay_from(&mut self, checkpoint: &TrainingState) -> SimResult<()> {
627        self.rng
628            .restore_state(&checkpoint.rng_state)
629            .map_err(|e| SimError::config(format!("Failed to restore RNG state: {e}")))?;
630        self.current_epoch = checkpoint.epoch;
631        Ok(())
632    }
633
634    /// Get the event journal.
635    #[must_use]
636    pub fn journal(&self) -> &EventJournal {
637        &self.journal
638    }
639
640    /// Reset simulation state.
641    pub fn reset(&mut self, seed: u64) {
642        self.rng = SimRng::new(seed);
643        self.journal = EventJournal::new(true);
644        self.anomaly_detector.reset();
645        self.current_epoch = 0;
646        self.trajectory = TrainingTrajectory::new();
647        self.best_val_loss = f64::INFINITY;
648        self.epochs_without_improvement = 0;
649    }
650}