kizzasi_core/training.rs

1//! Training infrastructure for SSM models
2//!
3//! Provides trainable versions of SSM components with automatic differentiation,
4//! loss functions, and optimization utilities using candle-core.
5//!
6//! # Features
7//!
8//! - **TrainableSSM**: Differentiable SSM model with automatic gradient tracking
9//! - **Trainer**: Full training loop with scheduler, metrics, and validation
10//! - **Loss Functions**: MSE, MAE, Huber, Cross-Entropy
11//! - **LR Scheduling**: Integrated support for all scheduler types
12//! - **Metrics Tracking**: Automatic loss, LR, and gradient monitoring
13//! - **Early Stopping**: Validation-based early stopping with patience
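//!
//! # Quick start
//!
//! An illustrative sketch of the intended flow (import paths assume the crate-root
//! re-exports used by the other doc examples in this file, and the in-memory
//! `batches` vector stands in for a real data pipeline):
//!
//! ```rust,ignore
//! use kizzasi_core::{KizzasiConfig, Loss, TrainableSSM, Trainer, TrainingConfig};
//!
//! // Model and training configuration (builder calls mirror the tests below).
//! let model_config = KizzasiConfig::new()
//!     .input_dim(3)
//!     .output_dim(3)
//!     .hidden_dim(64)
//!     .state_dim(8)
//!     .num_layers(2);
//! let training_config = TrainingConfig::default();
//!
//! // Build the differentiable model and wrap it in a Trainer.
//! let model = TrainableSSM::new(model_config, training_config.clone())?;
//! let mut trainer = Trainer::new(model, training_config)?;
//!
//! // `batches` is a Vec<(Tensor, Tensor)> of (input, target) pairs.
//! let avg_loss = trainer.train_epoch(&batches, Loss::mse)?;
//! trainer.save_checkpoint("checkpoints", "epoch_0")?;
//! ```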
14
15use crate::config::KizzasiConfig;
16use crate::dataloader::TimeSeriesDataLoader;
17use crate::device::DeviceConfig;
18use crate::error::{CoreError, CoreResult};
19use crate::metrics::{MetricsLogger, TrainingMetrics};
20use crate::scheduler::LRScheduler;
21use candle_core::{DType, Device, Tensor, Var};
22use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
23use serde::{Deserialize, Serialize};
24
25/// Scheduler type enumeration
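///
/// Selected via `TrainingConfig::with_scheduler`; an illustrative sketch
/// (values mirror the tests at the bottom of this file):
///
/// ```rust,ignore
/// let config = TrainingConfig::default().with_scheduler(SchedulerType::Cosine {
///     warmup_steps: 100,
///     min_lr: 1e-6,
/// });
/// ```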
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum SchedulerType {
28    Constant,
29    Linear {
30        warmup_steps: usize,
31        final_lr: f64,
32    },
33    Cosine {
34        warmup_steps: usize,
35        min_lr: f64,
36    },
37    Step {
38        milestones: Vec<usize>,
39        decay_factor: f64,
40    },
41    Exponential {
42        decay_rate: f64,
43        decay_steps: usize,
44    },
45    OneCycle {
46        warmup_pct: f64,
47    },
48    Polynomial {
49        final_lr: f64,
50        power: f64,
51    },
52}
53
54/// Mixed precision training mode
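///
/// Each variant maps onto a candle `DType`; an illustrative sketch:
///
/// ```rust,ignore
/// assert_eq!(MixedPrecision::BF16.to_dtype(), DType::BF16);
/// assert!(MixedPrecision::BF16.is_enabled());
///
/// // Enable BF16 training through the config builder.
/// let config = TrainingConfig::default().with_bf16();
/// ```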
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56pub enum MixedPrecision {
57    /// Full precision (FP32)
58    None,
59    /// Half precision (FP16) - faster but less stable
60    FP16,
61    /// Brain float 16 (BF16) - better stability than FP16
62    BF16,
63}
64
65impl MixedPrecision {
66    /// Convert to candle DType
67    pub fn to_dtype(&self) -> DType {
68        match self {
69            MixedPrecision::None => DType::F32,
70            MixedPrecision::FP16 => DType::F16,
71            MixedPrecision::BF16 => DType::BF16,
72        }
73    }
74
75    /// Check if mixed precision is enabled
76    pub fn is_enabled(&self) -> bool {
77        !matches!(self, MixedPrecision::None)
78    }
79}
80
81/// Configuration for training
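///
/// Usually built from `TrainingConfig::default()` plus the builder methods below;
/// an illustrative sketch:
///
/// ```rust,ignore
/// let config = TrainingConfig::default()
///     .with_validation_split(0.15)
///     .with_early_stopping(10)
///     .with_gradient_checkpointing(Some(2))
///     .with_bf16();
/// ```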
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct TrainingConfig {
84    /// Device configuration (CPU/CUDA/Metal)
85    pub device_config: DeviceConfig,
86    /// Learning rate (initial for schedulers)
87    pub learning_rate: f64,
88    /// Batch size
89    pub batch_size: usize,
90    /// Number of epochs
91    pub epochs: usize,
92    /// Weight decay (L2 regularization)
93    pub weight_decay: f64,
94    /// Gradient clipping threshold
95    pub grad_clip: Option<f32>,
96    /// Beta1 for Adam optimizer
97    pub beta1: f64,
98    /// Beta2 for Adam optimizer
99    pub beta2: f64,
100    /// Epsilon for Adam optimizer
101    pub eps: f64,
102    /// Learning rate scheduler type
103    pub scheduler: Option<SchedulerType>,
104    /// Enable metrics tracking
105    pub track_metrics: bool,
106    /// Log interval (batches)
107    pub log_interval: usize,
108    /// Validation split (0.0 to 1.0)
109    pub validation_split: f32,
110    /// Early stopping patience (epochs)
111    pub early_stopping_patience: Option<usize>,
112    /// Enable gradient checkpointing (saves memory by recomputing activations)
113    pub use_gradient_checkpointing: bool,
114    /// Checkpoint every N layers (None = checkpoint all layers)
115    pub checkpoint_segment_size: Option<usize>,
116    /// Mixed precision training mode
117    pub mixed_precision: MixedPrecision,
118    /// Loss scaling factor for mixed precision (to prevent underflow)
119    pub loss_scale: f32,
120}
121
122impl Default for TrainingConfig {
123    fn default() -> Self {
124        Self {
125            device_config: DeviceConfig::default(),
126            learning_rate: 1e-4,
127            batch_size: 32,
128            epochs: 10,
129            weight_decay: 1e-2,
130            grad_clip: Some(1.0),
131            beta1: 0.9,
132            beta2: 0.999,
133            eps: 1e-8,
134            scheduler: None,
135            track_metrics: true,
136            log_interval: 10,
137            validation_split: 0.2,
138            early_stopping_patience: Some(5),
139            use_gradient_checkpointing: false,
140            checkpoint_segment_size: Some(2), // Segment size used when gradient checkpointing is enabled
141            mixed_precision: MixedPrecision::None,
142            loss_scale: 1.0, // No scaling by default
143        }
144    }
145}
146
147impl TrainingConfig {
148    /// Set scheduler type
149    pub fn with_scheduler(mut self, scheduler: SchedulerType) -> Self {
150        self.scheduler = Some(scheduler);
151        self
152    }
153
154    /// Disable metrics tracking
155    pub fn without_metrics(mut self) -> Self {
156        self.track_metrics = false;
157        self
158    }
159
160    /// Set validation split
161    pub fn with_validation_split(mut self, split: f32) -> Self {
162        self.validation_split = split;
163        self
164    }
165
166    /// Set early stopping patience
167    pub fn with_early_stopping(mut self, patience: usize) -> Self {
168        self.early_stopping_patience = Some(patience);
169        self
170    }
171
172    /// Disable early stopping
173    pub fn without_early_stopping(mut self) -> Self {
174        self.early_stopping_patience = None;
175        self
176    }
177
178    /// Enable gradient checkpointing for memory-efficient training
179    pub fn with_gradient_checkpointing(mut self, segment_size: Option<usize>) -> Self {
180        self.use_gradient_checkpointing = true;
181        self.checkpoint_segment_size = segment_size;
182        self
183    }
184
185    /// Disable gradient checkpointing
186    pub fn without_gradient_checkpointing(mut self) -> Self {
187        self.use_gradient_checkpointing = false;
188        self
189    }
190
191    /// Enable mixed precision training (FP16)
192    pub fn with_fp16(mut self) -> Self {
193        self.mixed_precision = MixedPrecision::FP16;
194        self.loss_scale = 128.0; // Default loss scale for FP16
195        self
196    }
197
198    /// Enable mixed precision training (BF16)
199    pub fn with_bf16(mut self) -> Self {
200        self.mixed_precision = MixedPrecision::BF16;
201        self.loss_scale = 1.0; // BF16 is more stable, doesn't need scaling
202        self
203    }
204
205    /// Set mixed precision mode
206    pub fn with_mixed_precision(mut self, mode: MixedPrecision, loss_scale: f32) -> Self {
207        self.mixed_precision = mode;
208        self.loss_scale = loss_scale;
209        self
210    }
211
212    /// Disable mixed precision training
213    pub fn without_mixed_precision(mut self) -> Self {
214        self.mixed_precision = MixedPrecision::None;
215        self.loss_scale = 1.0;
216        self
217    }
218}
219
220/// Trainable Selective SSM using candle Tensors
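///
/// All learnable parameters are held in a `VarMap`, which is what
/// `create_optimizer` and the weight save/load helpers operate on.
/// A minimal construction sketch (config values mirror the tests below):
///
/// ```rust,ignore
/// let config = KizzasiConfig::new()
///     .input_dim(3)
///     .output_dim(3)
///     .hidden_dim(64)
///     .state_dim(8)
///     .num_layers(2);
///
/// let model = TrainableSSM::new(config, TrainingConfig::default())?;
/// let optimizer = model.create_optimizer()?;
/// ```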
221pub struct TrainableSSM {
222    config: KizzasiConfig,
223    training_config: TrainingConfig,
224    device: Device,
225    dtype: DType,
226    // Learnable parameters
227    embedding_weight: Var,
228    a_matrices: Vec<Var>,
229    b_matrices: Vec<Var>,
230    c_matrices: Vec<Var>,
231    d_vectors: Vec<Var>,
232    output_proj: Var,
233    // Layer normalization parameters
234    ln_gamma: Vec<Var>,
235    ln_beta: Vec<Var>,
236    // Variable map for optimizer
237    varmap: VarMap,
238}
239
240impl TrainableSSM {
241    /// Create a new trainable SSM model
242    pub fn new(config: KizzasiConfig, training_config: TrainingConfig) -> CoreResult<Self> {
243        // Create device from configuration
244        let device = training_config.device_config.create_device()?;
245
246        // Use mixed precision dtype from training config
247        let dtype = training_config.mixed_precision.to_dtype();
248
249        let hidden_dim = config.get_hidden_dim();
250        let state_dim = config.get_state_dim();
251        let num_layers = config.get_num_layers();
252        let input_dim = config.get_input_dim();
253        let output_dim = config.get_output_dim();
254
255        let varmap = VarMap::new();
256        let vb = VarBuilder::from_varmap(&varmap, dtype, &device);
257
258        // Initialize embedding layer
259        let embedding_weight_tensor = vb
260            .get_with_hints(
261                (input_dim, hidden_dim),
262                "embedding.weight",
263                candle_nn::init::DEFAULT_KAIMING_NORMAL,
264            )
265            .map_err(|e| CoreError::Generic(format!("Failed to create embedding: {}", e)))?;
266        let embedding_weight = Var::from_tensor(&embedding_weight_tensor)
267            .map_err(|e| CoreError::Generic(format!("Failed to create embedding var: {}", e)))?;
268
269        // Initialize SSM matrices for each layer
270        let mut a_matrices = Vec::with_capacity(num_layers);
271        let mut b_matrices = Vec::with_capacity(num_layers);
272        let mut c_matrices = Vec::with_capacity(num_layers);
273        let mut d_vectors = Vec::with_capacity(num_layers);
274        let mut ln_gamma = Vec::with_capacity(num_layers);
275        let mut ln_beta = Vec::with_capacity(num_layers);
276
277        for layer_idx in 0..num_layers {
278            // A matrix: state transition (initialized for stability)
279            let a_tensor = vb
280                .get_with_hints(
281                    (hidden_dim, state_dim),
282                    &format!("ssm.layer_{}.a", layer_idx),
283                    candle_nn::init::Init::Const(-0.5),
284                )
285                .map_err(|e| CoreError::Generic(format!("Failed to create A matrix: {}", e)))?;
286            let a = Var::from_tensor(&a_tensor)
287                .map_err(|e| CoreError::Generic(format!("Failed to create A var: {}", e)))?;
288            a_matrices.push(a);
289
290            // B matrix: input projection to state
291            let b_tensor = vb
292                .get_with_hints(
293                    (hidden_dim, state_dim),
294                    &format!("ssm.layer_{}.b", layer_idx),
295                    candle_nn::init::DEFAULT_KAIMING_NORMAL,
296                )
297                .map_err(|e| CoreError::Generic(format!("Failed to create B matrix: {}", e)))?;
298            let b = Var::from_tensor(&b_tensor)
299                .map_err(|e| CoreError::Generic(format!("Failed to create B var: {}", e)))?;
300            b_matrices.push(b);
301
302            // C matrix: state to output projection
303            let c_tensor = vb
304                .get_with_hints(
305                    (hidden_dim, state_dim),
306                    &format!("ssm.layer_{}.c", layer_idx),
307                    candle_nn::init::DEFAULT_KAIMING_NORMAL,
308                )
309                .map_err(|e| CoreError::Generic(format!("Failed to create C matrix: {}", e)))?;
310            let c = Var::from_tensor(&c_tensor)
311                .map_err(|e| CoreError::Generic(format!("Failed to create C var: {}", e)))?;
312            c_matrices.push(c);
313
314            // D vector: skip connection
315            let d_tensor = vb
316                .get_with_hints(
317                    hidden_dim,
318                    &format!("ssm.layer_{}.d", layer_idx),
319                    candle_nn::init::Init::Const(1.0),
320                )
321                .map_err(|e| CoreError::Generic(format!("Failed to create D vector: {}", e)))?;
322            let d = Var::from_tensor(&d_tensor)
323                .map_err(|e| CoreError::Generic(format!("Failed to create D var: {}", e)))?;
324            d_vectors.push(d);
325
326            // Layer normalization parameters
327            let gamma_tensor = vb
328                .get_with_hints(
329                    hidden_dim,
330                    &format!("ln.layer_{}.gamma", layer_idx),
331                    candle_nn::init::Init::Const(1.0),
332                )
333                .map_err(|e| CoreError::Generic(format!("Failed to create LN gamma: {}", e)))?;
334            let gamma = Var::from_tensor(&gamma_tensor)
335                .map_err(|e| CoreError::Generic(format!("Failed to create LN gamma var: {}", e)))?;
336            ln_gamma.push(gamma);
337
338            let beta_tensor = vb
339                .get_with_hints(
340                    hidden_dim,
341                    &format!("ln.layer_{}.beta", layer_idx),
342                    candle_nn::init::Init::Const(0.0),
343                )
344                .map_err(|e| CoreError::Generic(format!("Failed to create LN beta: {}", e)))?;
345            let beta = Var::from_tensor(&beta_tensor)
346                .map_err(|e| CoreError::Generic(format!("Failed to create LN beta var: {}", e)))?;
347            ln_beta.push(beta);
348        }
349
350        // Output projection
351        let output_proj_tensor = vb
352            .get_with_hints(
353                (hidden_dim, output_dim),
354                "output.proj",
355                candle_nn::init::DEFAULT_KAIMING_NORMAL,
356            )
357            .map_err(|e| {
358                CoreError::Generic(format!("Failed to create output projection: {}", e))
359            })?;
360        let output_proj = Var::from_tensor(&output_proj_tensor)
361            .map_err(|e| CoreError::Generic(format!("Failed to create output proj var: {}", e)))?;
362
        // `Var::from_tensor` above creates Vars that are independent copies of the
        // variables the `VarBuilder` registered in `varmap`. Re-register those copies
        // under the same names so that the optimizer (built from `varmap.all_vars()`)
        // and the save/load helpers operate on the variables actually used in the
        // forward pass; otherwise parameter updates would never reach the model.
        {
            let mut vars = varmap.data().lock().unwrap();
            vars.insert("embedding.weight".to_string(), embedding_weight.clone());
            vars.insert("output.proj".to_string(), output_proj.clone());
            for (i, v) in a_matrices.iter().enumerate() {
                vars.insert(format!("ssm.layer_{}.a", i), v.clone());
            }
            for (i, v) in b_matrices.iter().enumerate() {
                vars.insert(format!("ssm.layer_{}.b", i), v.clone());
            }
            for (i, v) in c_matrices.iter().enumerate() {
                vars.insert(format!("ssm.layer_{}.c", i), v.clone());
            }
            for (i, v) in d_vectors.iter().enumerate() {
                vars.insert(format!("ssm.layer_{}.d", i), v.clone());
            }
            for (i, v) in ln_gamma.iter().enumerate() {
                vars.insert(format!("ln.layer_{}.gamma", i), v.clone());
            }
            for (i, v) in ln_beta.iter().enumerate() {
                vars.insert(format!("ln.layer_{}.beta", i), v.clone());
            }
        }

363        Ok(Self {
364            config,
365            training_config,
366            device,
367            dtype,
368            embedding_weight,
369            a_matrices,
370            b_matrices,
371            c_matrices,
372            d_vectors,
373            output_proj,
374            ln_gamma,
375            ln_beta,
376            varmap,
377        })
378    }
379
380    /// Forward pass for training (tracks gradients)
381    ///
382    /// # Arguments
383    /// * `input` - Input tensor of shape [batch_size, seq_len, input_dim]
385    ///
386    /// # Returns
387    /// Output tensor of shape [batch_size, seq_len, output_dim]
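    ///
    /// # Example
    /// ```rust,ignore
    /// // Illustrative: a dummy batch of shape [batch=2, seq=10, input_dim=3],
    /// // mirroring the forward-pass test below.
    /// let input = Tensor::randn(0f32, 1.0, (2, 10, 3), model.device())?;
    /// let output = model.forward(&input)?;
    /// assert_eq!(output.dims(), &[2, 10, 3]);
    /// ```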
388    pub fn forward(&self, input: &Tensor) -> CoreResult<Tensor> {
389        // Embed input: [batch, seq, input_dim] -> [batch, seq, hidden_dim]
390        // Reshape input to [batch * seq, input_dim] for matmul, then reshape back
391        let batch_size = input
392            .dim(0)
393            .map_err(|e| CoreError::Generic(format!("Failed to get batch dimension: {}", e)))?;
394        let seq_len = input
395            .dim(1)
396            .map_err(|e| CoreError::Generic(format!("Failed to get sequence dimension: {}", e)))?;
397        let input_dim = input
398            .dim(2)
399            .map_err(|e| CoreError::Generic(format!("Failed to get input dimension: {}", e)))?;
400
401        let x_flat = input
402            .reshape((batch_size * seq_len, input_dim))
403            .map_err(|e| CoreError::Generic(format!("Failed to reshape input: {}", e)))?;
404
405        let hidden_dim = self.config.get_hidden_dim();
406        let x_embedded = x_flat
407            .matmul(self.embedding_weight.as_tensor())
408            .map_err(|e| CoreError::Generic(format!("Embedding forward failed: {}", e)))?;
409
410        let x = x_embedded
411            .reshape((batch_size, seq_len, hidden_dim))
412            .map_err(|e| CoreError::Generic(format!("Failed to reshape embedded: {}", e)))?;
413
414        // Initialize hidden state
415        let state_dim = self.config.get_state_dim();
416
417        let mut h = Tensor::zeros(
418            (batch_size, hidden_dim, state_dim),
419            self.dtype,
420            &self.device,
421        )
422        .map_err(|e| CoreError::Generic(format!("Failed to create hidden state: {}", e)))?;
423
424        let mut x = x;
425
426        // Process through each layer
427        for layer_idx in 0..self.config.get_num_layers() {
428            x = self.layer_norm(&x, layer_idx)?;
429            x = self.ssm_layer(&x, &mut h, layer_idx)?;
430        }
431
432        // Project to output dimension: [batch, seq, hidden_dim] -> [batch, seq, output_dim]
433        // Reshape to [batch * seq, hidden_dim], matmul, then reshape back
434        let x_flat = x
435            .reshape((batch_size * seq_len, hidden_dim))
436            .map_err(|e| CoreError::Generic(format!("Failed to reshape for output: {}", e)))?;
437
438        let output_dim = self.config.get_output_dim();
439        let output_flat = x_flat
440            .matmul(self.output_proj.as_tensor())
441            .map_err(|e| CoreError::Generic(format!("Output projection failed: {}", e)))?;
442
443        let output = output_flat
444            .reshape((batch_size, seq_len, output_dim))
445            .map_err(|e| CoreError::Generic(format!("Failed to reshape output: {}", e)))?;
446
447        Ok(output)
448    }
449
450    /// Apply layer normalization
451    fn layer_norm(&self, x: &Tensor, layer_idx: usize) -> CoreResult<Tensor> {
452        const EPS: f64 = 1e-5;
453
454        // Compute mean and variance along the last dimension
455        let mean = x
456            .mean_keepdim(candle_core::D::Minus1)
457            .map_err(|e| CoreError::Generic(format!("Layer norm mean failed: {}", e)))?;
458        let x_centered = x.broadcast_sub(&mean).map_err(|e| {
459            CoreError::Generic(format!("Layer norm variance computation failed: {}", e))
460        })?;
461        let variance = x_centered
462            .sqr()
463            .map_err(|e| CoreError::Generic(format!("Layer norm variance sqr failed: {}", e)))?
464            .mean_keepdim(candle_core::D::Minus1)
465            .map_err(|e| CoreError::Generic(format!("Layer norm variance mean failed: {}", e)))?;
466
467        // Normalize: (x - mean) / sqrt(variance + eps)
468        let std = (variance.affine(1.0, EPS))
469            .map_err(|e| CoreError::Generic(format!("Layer norm variance add eps failed: {}", e)))?
470            .sqrt()
471            .map_err(|e| CoreError::Generic(format!("Layer norm sqrt failed: {}", e)))?;
472
473        let normalized = x_centered
474            .broadcast_div(&std)
475            .map_err(|e| CoreError::Generic(format!("Layer norm division failed: {}", e)))?;
476
477        // Apply affine transformation
478        let gamma = self.ln_gamma[layer_idx].as_tensor();
479        let beta = self.ln_beta[layer_idx].as_tensor();
480
481        normalized
482            .broadcast_mul(gamma)
483            .map_err(|e| CoreError::Generic(format!("Layer norm gamma mul failed: {}", e)))?
484            .broadcast_add(beta)
485            .map_err(|e| CoreError::Generic(format!("Layer norm beta add failed: {}", e)))
486    }
487
488    /// SSM layer computation
489    fn ssm_layer(&self, x: &Tensor, _h: &mut Tensor, layer_idx: usize) -> CoreResult<Tensor> {
490        let _a = self.a_matrices[layer_idx].as_tensor();
491        let _b = self.b_matrices[layer_idx].as_tensor();
492        let _c = self.c_matrices[layer_idx].as_tensor();
493        let d = self.d_vectors[layer_idx].as_tensor();
494
495        // Simplified SSM step (full implementation would include selective scan)
496        // For now, implementing a basic skip connection
497        // TODO: Implement proper selective scan mechanism with state evolution
498
499        // For training, we process the entire sequence in parallel (teacher forcing)
500        // Output: y = D * x (simplified - full version uses state)
501        let y = x
502            .broadcast_mul(d)
503            .map_err(|e| CoreError::Generic(format!("Skip connection failed: {}", e)))?;
504
505        Ok(y)
506    }
507
508    /// Create an optimizer for this model
509    pub fn create_optimizer(&self) -> CoreResult<AdamW> {
510        let params = ParamsAdamW {
511            lr: self.training_config.learning_rate,
512            beta1: self.training_config.beta1,
513            beta2: self.training_config.beta2,
514            eps: self.training_config.eps,
515            weight_decay: self.training_config.weight_decay,
516        };
517
518        AdamW::new(self.varmap.all_vars(), params)
519            .map_err(|e| CoreError::Generic(format!("Failed to create optimizer: {}", e)))
520    }
521
522    /// Get the variable map for loading/saving weights
523    pub fn varmap(&self) -> &VarMap {
524        &self.varmap
525    }
526
527    /// Get device
528    pub fn device(&self) -> &Device {
529        &self.device
530    }
531
532    /// Get dtype
533    pub fn dtype(&self) -> DType {
534        self.dtype
535    }
536
537    /// Save model weights to a safetensors file
538    ///
539    /// # Arguments
540    /// * `path` - Path to save the safetensors file
541    ///
542    /// # Example
543    /// ```rust,ignore
544    /// model.save_weights("model.safetensors")?;
545    /// ```
546    pub fn save_weights<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
547        self.varmap
548            .save(path)
549            .map_err(|e| CoreError::Generic(format!("Failed to save weights: {}", e)))
550    }
551
552    /// Load model weights from a safetensors file
553    ///
554    /// # Arguments
555    /// * `path` - Path to the safetensors file
556    ///
557    /// # Example
558    /// ```rust,ignore
559    /// model.load_weights("model.safetensors")?;
560    /// ```
561    pub fn load_weights<P: AsRef<std::path::Path>>(&mut self, path: P) -> CoreResult<()> {
562        self.varmap
563            .load(path)
564            .map_err(|e| CoreError::Generic(format!("Failed to load weights: {}", e)))
565    }
566}
567
568/// Constraint-aware loss wrapper
569///
570/// Bridges kizzasi-logic constraints with candle tensor operations.
571/// Allows combining task loss with constraint violations for constrained optimization.
572///
573/// # Examples
574///
575/// ```rust,ignore
576/// use kizzasi_core::{ConstraintLoss, Loss};
577///
578/// let constraint_loss = ConstraintLoss::new(0.1);
579///
580/// // In training loop:
581/// let task_loss = Loss::mse(&predictions, &targets)?;
582/// let total_loss = constraint_loss.compute(&task_loss, &predictions, |pred| {
583///     // Compute constraint violation from prediction
584///     Ok(0.0)
585/// })?;
586/// ```
587pub struct ConstraintLoss {
588    /// Base weight for constraint violations
589    constraint_weight: f32,
590}
591
592impl ConstraintLoss {
593    /// Create a new constraint-aware loss
594    pub fn new(constraint_weight: f32) -> Self {
595        Self { constraint_weight }
596    }
597
598    /// Compute combined loss: task_loss + constraint_weight * constraint_penalty
599    ///
600    /// # Arguments
601    /// * `task_loss` - Base task loss (MSE, MAE, etc.)
602    /// * `prediction` - Model prediction tensor
603    /// * `constraint_fn` - Function that computes constraint violation from prediction
604    pub fn compute<F>(
605        &self,
606        task_loss: &Tensor,
607        prediction: &Tensor,
608        constraint_fn: F,
609    ) -> CoreResult<Tensor>
610    where
611        F: Fn(&Tensor) -> CoreResult<f32>,
612    {
613        // Compute constraint violation
614        let violation = constraint_fn(prediction)?;
615
616        // Add constraint penalty to task loss
617        // Create a scalar penalty value matching task_loss shape
618        let penalty_value = self.constraint_weight * violation;
619
620        // Use affine to add the penalty: task_loss + penalty = task_loss * 1.0 + penalty
621        task_loss
622            .affine(1.0, penalty_value as f64)
623            .map_err(|e| CoreError::Generic(format!("Failed to add constraint penalty: {}", e)))
624    }
625}
626
627/// Loss functions for training
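///
/// Each function reduces to a scalar tensor; an illustrative sketch (values mirror
/// the tests below):
///
/// ```rust,ignore
/// let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &Device::Cpu)?;
/// let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &Device::Cpu)?;
///
/// let mse = Loss::mse(&predictions, &targets)?;          // mean squared error
/// let mae = Loss::mae(&predictions, &targets)?;          // mean absolute error
/// let huber = Loss::huber(&predictions, &targets, 1.0)?; // smooth L1, delta = 1.0
/// ```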
628pub struct Loss;
629
630impl Loss {
631    /// Mean Squared Error loss
632    pub fn mse(predictions: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
633        predictions
634            .sub(targets)
635            .map_err(|e| CoreError::Generic(format!("MSE subtraction failed: {}", e)))?
636            .sqr()
637            .map_err(|e| CoreError::Generic(format!("MSE square failed: {}", e)))?
638            .mean_all()
639            .map_err(|e| CoreError::Generic(format!("MSE mean failed: {}", e)))
640    }
641
642    /// Mean Absolute Error loss
643    pub fn mae(predictions: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
644        predictions
645            .sub(targets)
646            .map_err(|e| CoreError::Generic(format!("MAE subtraction failed: {}", e)))?
647            .abs()
648            .map_err(|e| CoreError::Generic(format!("MAE abs failed: {}", e)))?
649            .mean_all()
650            .map_err(|e| CoreError::Generic(format!("MAE mean failed: {}", e)))
651    }
652
653    /// Huber loss (smooth L1 loss)
654    pub fn huber(predictions: &Tensor, targets: &Tensor, delta: f64) -> CoreResult<Tensor> {
655        let diff = predictions
656            .sub(targets)
657            .map_err(|e| CoreError::Generic(format!("Huber subtraction failed: {}", e)))?;
658        let abs_diff = diff
659            .abs()
660            .map_err(|e| CoreError::Generic(format!("Huber abs failed: {}", e)))?;
661
662        // If |diff| <= delta: 0.5 * diff^2
663        // If |diff| > delta: delta * (|diff| - 0.5 * delta)
664        let squared = diff
665            .sqr()
666            .map_err(|e| CoreError::Generic(format!("Huber square failed: {}", e)))?
667            .affine(0.5, 0.0)
668            .map_err(|e| CoreError::Generic(format!("Huber mul 0.5 failed: {}", e)))?;
669
670        let linear_offset = delta * delta * 0.5;
671        let linear = abs_diff
672            .affine(delta, -linear_offset)
673            .map_err(|e| CoreError::Generic(format!("Huber linear computation failed: {}", e)))?;
674
675        let mask = abs_diff
676            .le(delta)
677            .map_err(|e| CoreError::Generic(format!("Huber comparison failed: {}", e)))?
678            .to_dtype(predictions.dtype())
679            .map_err(|e| CoreError::Generic(format!("Huber mask conversion failed: {}", e)))?;
680
681        // Invert mask: 1 - mask
682        let inv_mask = mask
683            .affine(-1.0, 1.0)
684            .map_err(|e| CoreError::Generic(format!("Huber mask inversion failed: {}", e)))?;
685
686        let loss = squared
687            .mul(&mask)
688            .map_err(|e| CoreError::Generic(format!("Huber squared mul failed: {}", e)))?
689            .add(
690                &linear
691                    .mul(&inv_mask)
692                    .map_err(|e| CoreError::Generic(format!("Huber linear mul failed: {}", e)))?,
693            )
694            .map_err(|e| CoreError::Generic(format!("Huber final add failed: {}", e)))?;
695
696        loss.mean_all()
697            .map_err(|e| CoreError::Generic(format!("Huber mean failed: {}", e)))
698    }
699
700    /// Cross-entropy loss for classification (expects one-hot or probability targets with the same shape as `logits`)
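    ///
    /// # Example
    /// ```rust,ignore
    /// // Illustrative: two samples, two classes, one-hot targets.
    /// let logits = Tensor::new(&[[2.0f32, 0.5], [0.2, 1.5]], &Device::Cpu)?;
    /// let targets = Tensor::new(&[[1.0f32, 0.0], [0.0, 1.0]], &Device::Cpu)?;
    /// let loss = Loss::cross_entropy(&logits, &targets)?;
    /// ```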
701    pub fn cross_entropy(logits: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
702        // Log softmax
703        let log_probs = candle_nn::ops::log_softmax(logits, candle_core::D::Minus1)
704            .map_err(|e| CoreError::Generic(format!("Log softmax failed: {}", e)))?;
705
706        // Negative log likelihood
707        let nll = log_probs
708            .mul(targets)
709            .map_err(|e| CoreError::Generic(format!("NLL multiplication failed: {}", e)))?
710            .sum_all()
711            .map_err(|e| CoreError::Generic(format!("NLL sum failed: {}", e)))?
712            .neg()
713            .map_err(|e| CoreError::Generic(format!("NLL negation failed: {}", e)))?;
714
715        // Average over batch
716        let batch_size = logits
717            .dim(0)
718            .map_err(|e| CoreError::Generic(format!("Failed to get batch size: {}", e)))?;
719        nll.affine(1.0 / batch_size as f64, 0.0)
720            .map_err(|e| CoreError::Generic(format!("Cross entropy division failed: {}", e)))
721    }
722}
723
724/// Training utilities with scheduler, metrics, and validation
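///
/// A typical lifecycle, as an illustrative sketch (the configs and batch vectors
/// are assumed to be built as in the other examples in this file):
///
/// ```rust,ignore
/// let model = TrainableSSM::new(model_config, training_config.clone())?;
/// let mut trainer = Trainer::new(model, training_config)?;
///
/// for epoch in 0..10 {
///     let train_loss = trainer.train_epoch(&train_batches, Loss::mse)?;
///     let val_loss = trainer.evaluate(&val_batches, Loss::mse)?;
///     println!("epoch {epoch}: train={train_loss:.4} val={val_loss:.4}");
/// }
///
/// trainer.save_checkpoint("checkpoints", "final")?;
/// ```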
725pub struct Trainer {
726    model: TrainableSSM,
727    optimizer: AdamW,
728    config: TrainingConfig,
729    scheduler: Option<Box<dyn LRScheduler>>,
730    metrics: TrainingMetrics,
731    logger: MetricsLogger,
732    current_step: usize,
733}
734
735impl Trainer {
736    /// Create a new trainer
737    pub fn new(model: TrainableSSM, config: TrainingConfig) -> CoreResult<Self> {
738        let optimizer = model.create_optimizer()?;
739
740        // Create scheduler based on config
741        let scheduler = Self::create_scheduler(&config);
742
743        let metrics = TrainingMetrics::new();
744
745        let logger = MetricsLogger::new()
746            .with_verbose(config.track_metrics)
747            .with_log_interval(config.log_interval);
748
749        Ok(Self {
750            model,
751            optimizer,
752            config,
753            scheduler,
754            metrics,
755            logger,
756            current_step: 0,
757        })
758    }
759
760    /// Create scheduler from config
761    fn create_scheduler(config: &TrainingConfig) -> Option<Box<dyn LRScheduler>> {
762        use crate::scheduler::*;
763
764        config.scheduler.as_ref().map(|sched_type| {
765            let total_steps = config.epochs * 100; // Rough estimate, can be updated later
766
767            match sched_type {
768                SchedulerType::Constant => {
769                    Box::new(ConstantScheduler::new(config.learning_rate)) as Box<dyn LRScheduler>
770                }
771                SchedulerType::Linear {
772                    warmup_steps,
773                    final_lr,
774                } => Box::new(LinearScheduler::new(
775                    config.learning_rate,
776                    *final_lr,
777                    total_steps,
778                    *warmup_steps,
779                )) as Box<dyn LRScheduler>,
780                SchedulerType::Cosine {
781                    warmup_steps,
782                    min_lr,
783                } => Box::new(
784                    CosineScheduler::new(config.learning_rate, total_steps, *warmup_steps)
785                        .with_min_lr(*min_lr),
786                ) as Box<dyn LRScheduler>,
787                SchedulerType::Step {
788                    milestones,
789                    decay_factor,
790                } => Box::new(StepScheduler::new(
791                    config.learning_rate,
792                    *decay_factor,
793                    milestones.clone(),
794                )) as Box<dyn LRScheduler>,
795                SchedulerType::Exponential {
796                    decay_rate,
797                    decay_steps,
798                } => Box::new(ExponentialScheduler::new(
799                    config.learning_rate,
800                    *decay_rate,
801                    *decay_steps,
802                )) as Box<dyn LRScheduler>,
803                SchedulerType::OneCycle { warmup_pct } => Box::new(
804                    OneCycleScheduler::new(config.learning_rate, total_steps)
805                        .with_warmup_pct(*warmup_pct),
806                ) as Box<dyn LRScheduler>,
807                SchedulerType::Polynomial { final_lr, power } => Box::new(PolynomialScheduler::new(
808                    config.learning_rate,
809                    *final_lr,
810                    total_steps,
811                    *power,
812                ))
813                    as Box<dyn LRScheduler>,
814            }
815        })
816    }
817
818    /// Get current learning rate
819    fn get_current_lr(&self) -> f64 {
820        self.scheduler
821            .as_ref()
822            .map(|s| s.get_lr(self.current_step))
823            .unwrap_or(self.config.learning_rate)
824    }
825
826    /// Train for one epoch
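    ///
    /// # Example
    /// ```rust,ignore
    /// // Illustrative: any Fn(&Tensor, &Tensor) -> CoreResult<Tensor> works as the
    /// // loss, e.g. a closure around one of the `Loss` functions.
    /// let avg_loss = trainer.train_epoch(&batches, |pred, target| {
    ///     Loss::huber(pred, target, 1.0)
    /// })?;
    /// ```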
827    pub fn train_epoch<F>(
828        &mut self,
829        data_loader: &[(Tensor, Tensor)],
830        loss_fn: F,
831    ) -> CoreResult<f32>
832    where
833        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor>,
834    {
835        let mut total_loss = 0.0;
836        let num_batches = data_loader.len();
837        let epoch = self.current_step / num_batches.max(1);
838
839        for (batch_idx, (inputs, targets)) in data_loader.iter().enumerate() {
840            // Update learning rate from scheduler
841            let lr = self.get_current_lr();
842            if self.config.track_metrics {
843                self.metrics.record_learning_rate(lr);
844            }
845
846            // Forward pass
847            let predictions = self.model.forward(inputs)?;
848
849            // Compute loss
850            let loss = loss_fn(&predictions, targets)?;
851
852            // Backward pass
853            self.optimizer
854                .backward_step(&loss)
855                .map_err(|e| CoreError::Generic(format!("Backward step failed: {}", e)))?;
856
857            // Accumulate loss
858            let loss_val = loss
859                .to_vec0::<f32>()
860                .map_err(|e| CoreError::Generic(format!("Failed to extract loss value: {}", e)))?;
861            total_loss += loss_val;
862
863            // Track metrics
864            if self.config.track_metrics {
865                self.metrics.record_train_loss(epoch, loss_val);
866                self.logger.log_batch(epoch, batch_idx, loss_val);
867
868                // Compute and track gradient norm
869                let grad_norm = self.compute_grad_norm()?;
870                self.metrics.record_grad_norm(grad_norm);
871            }
872
873            // Gradient clipping if enabled (note: `clip_gradients` is currently a no-op placeholder)
874            if let Some(max_norm) = self.config.grad_clip {
875                self.clip_gradients(max_norm)?;
876            }
877
878            self.current_step += 1;
879        }
880
881        Ok(total_loss / num_batches as f32)
882    }
883
884    /// Compute the global gradient norm (currently a placeholder that returns a constant)
885    fn compute_grad_norm(&self) -> CoreResult<f32> {
886        // Placeholder: In candle, gradient norms would be computed from VarMap
887        // For now, return a dummy value
888        // TODO: Implement proper gradient norm computation when candle exposes gradient access
889        Ok(1.0)
890    }
891
892    /// Clip gradients by global norm
893    ///
894    /// Note: this is currently a no-op placeholder. For clipping to take effect it
895    /// would have to be applied to the gradients between `loss.backward()` and the
896    /// parameter update, which the optimizer's `backward_step` performs in one call.
897    fn clip_gradients(&self, _max_norm: f32) -> CoreResult<()> {
898        // Intentionally a no-op until per-variable gradient access is wired in here.
899        Ok(())
900    }
901
902    /// Evaluate on validation data
903    pub fn evaluate<F>(&self, data_loader: &[(Tensor, Tensor)], loss_fn: F) -> CoreResult<f32>
904    where
905        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor>,
906    {
907        let mut total_loss = 0.0;
908        let num_batches = data_loader.len();
909
910        for (inputs, targets) in data_loader {
911            // Forward pass (no gradient tracking needed)
912            let predictions = self.model.forward(inputs)?;
913
914            // Compute loss
915            let loss = loss_fn(&predictions, targets)?;
916
917            // Accumulate loss
918            let loss_val = loss
919                .to_vec0::<f32>()
920                .map_err(|e| CoreError::Generic(format!("Failed to extract loss value: {}", e)))?;
921            total_loss += loss_val;
922        }
923
924        Ok(total_loss / num_batches as f32)
925    }
926
927    /// Full training loop with validation and early stopping
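    ///
    /// # Example
    /// ```rust,ignore
    /// // Illustrative: the data loaders are constructed elsewhere; `Loss::mse`
    /// // satisfies the `Fn(&Tensor, &Tensor) -> CoreResult<Tensor> + Copy` bound.
    /// trainer.fit(train_loader, Some(val_loader), Loss::mse)?;
    /// ```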
928    pub fn fit<F>(
929        &mut self,
930        mut train_loader: TimeSeriesDataLoader,
931        mut val_loader: Option<TimeSeriesDataLoader>,
932        loss_fn: F,
933    ) -> CoreResult<()>
934    where
935        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor> + Copy,
936    {
937        use std::time::Instant;
938
939        for epoch in 0..self.config.epochs {
940            let epoch_start = Instant::now();
941
942            // Shuffle training data
943            train_loader.shuffle();
944
945            // Prepare batches (simplified - actual implementation would iterate batches)
946            // For now, this is a placeholder for the integration
947            // TODO: Implement proper batch iteration with TimeSeriesDataLoader
948            let train_batches: Vec<(Tensor, Tensor)> = Vec::new();
949
950            // Train for one epoch
951            let train_loss = self.train_epoch(&train_batches, loss_fn)?;
952
953            // Validation
954            let val_loss = if let Some(ref mut _val_data) = val_loader {
955                let val_batches: Vec<(Tensor, Tensor)> = Vec::new();
956                let val_loss = self.evaluate(&val_batches, loss_fn)?;
957
958                if self.config.track_metrics {
959                    self.metrics.record_val_loss(epoch, val_loss);
960                }
961
962                Some(val_loss)
963            } else {
964                None
965            };
966
967            // Track epoch duration
968            let epoch_duration = epoch_start.elapsed().as_secs_f64();
969            if self.config.track_metrics {
970                self.metrics.record_epoch_duration(epoch, epoch_duration);
971            }
972
973            // Log epoch metrics
974            let current_lr = self.get_current_lr();
975            self.logger
976                .log_epoch(epoch, train_loss, val_loss, current_lr);
977
978            // Early stopping check
979            if let Some(patience) = self.config.early_stopping_patience {
980                if !self.metrics.is_improving(patience) {
981                    tracing::info!("Early stopping triggered at epoch {}", epoch);
982                    break;
983                }
984            }
985        }
986
987        // Log training summary
988        if self.config.track_metrics {
989            let summary = self.metrics.summary();
990            self.logger.log_summary(&summary);
991        }
992
993        Ok(())
994    }
995
996    /// Get reference to the model
997    pub fn model(&self) -> &TrainableSSM {
998        &self.model
999    }
1000
1001    /// Get mutable reference to the model
1002    pub fn model_mut(&mut self) -> &mut TrainableSSM {
1003        &mut self.model
1004    }
1005
1006    /// Get reference to training metrics
1007    pub fn metrics(&self) -> &TrainingMetrics {
1008        &self.metrics
1009    }
1010
1011    /// Get mutable reference to training metrics
1012    pub fn metrics_mut(&mut self) -> &mut TrainingMetrics {
1013        &mut self.metrics
1014    }
1015
1016    /// Get current training step
1017    pub fn current_step(&self) -> usize {
1018        self.current_step
1019    }
1020
1021    /// Save checkpoint to disk
1022    ///
1023    /// Saves model weights, optimizer state, training configuration, metrics, and metadata.
1024    ///
1025    /// # Arguments
1026    /// * `path` - Directory to save checkpoint files
1027    /// * `name` - Checkpoint name (without extension)
1028    ///
1029    /// # Example
1030    /// ```rust,ignore
1031    /// trainer.save_checkpoint("checkpoints", "epoch_10")?;
1032    /// // Creates: checkpoints/epoch_10.safetensors and checkpoints/epoch_10.json
1033    /// ```
1034    pub fn save_checkpoint<P: AsRef<std::path::Path>>(
1035        &self,
1036        path: P,
1037        name: &str,
1038    ) -> CoreResult<()> {
1039        use std::fs;
1040        use std::path::PathBuf;
1041
1042        let checkpoint_dir = path.as_ref();
1043        fs::create_dir_all(checkpoint_dir).map_err(|e| {
1044            CoreError::Generic(format!("Failed to create checkpoint directory: {}", e))
1045        })?;
1046
1047        // Save model weights to safetensors
1048        let weights_path: PathBuf = checkpoint_dir.join(format!("{}.safetensors", name));
1049        self.model
1050            .save_weights(&weights_path)
1051            .map_err(|e| CoreError::Generic(format!("Failed to save model weights: {}", e)))?;
1052
1053        // Create checkpoint metadata
1054        let metadata = CheckpointMetadata {
1055            version: env!("CARGO_PKG_VERSION").to_string(),
1056            timestamp: chrono::Utc::now().to_rfc3339(),
1057            current_step: self.current_step,
1058            current_epoch: self.metrics.summary().total_epochs,
1059            config: self.config.clone(),
1060            metrics: self.metrics.clone(),
1061        };
1062
1063        // Save metadata to JSON
1064        let metadata_path: PathBuf = checkpoint_dir.join(format!("{}.json", name));
1065        let metadata_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
1066            CoreError::Generic(format!("Failed to serialize checkpoint metadata: {}", e))
1067        })?;
1068
1069        fs::write(&metadata_path, metadata_json).map_err(|e| {
1070            CoreError::Generic(format!("Failed to write checkpoint metadata: {}", e))
1071        })?;
1072
1073        tracing::info!(
1074            "Checkpoint saved: weights={}, metadata={}",
1075            weights_path.display(),
1076            metadata_path.display()
1077        );
1078
1079        Ok(())
1080    }
1081
1082    /// Load checkpoint and resume training
1083    ///
1084    /// Creates a new Trainer from a saved checkpoint, restoring model weights,
1085    /// configuration, and training state.
1086    ///
1087    /// # Arguments
1088    /// * `path` - Directory containing checkpoint files
1089    /// * `name` - Checkpoint name (without extension)
1090    /// * `model_config` - Model configuration (must match saved model)
1091    ///
1092    /// # Example
1093    /// ```rust,ignore
1094    /// let trainer = Trainer::load_checkpoint("checkpoints", "epoch_10", model_config)?;
1095    /// // Continue training from epoch 10
1096    /// ```
1097    pub fn load_checkpoint<P: AsRef<std::path::Path>>(
1098        path: P,
1099        name: &str,
1100        model_config: KizzasiConfig,
1101    ) -> CoreResult<Self> {
1102        use std::fs;
1103        use std::path::PathBuf;
1104
1105        let checkpoint_dir = path.as_ref();
1106
1107        // Load metadata from JSON
1108        let metadata_path: PathBuf = checkpoint_dir.join(format!("{}.json", name));
1109        let metadata_json = fs::read_to_string(&metadata_path).map_err(|e| {
1110            CoreError::Generic(format!("Failed to read checkpoint metadata: {}", e))
1111        })?;
1112
1113        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json).map_err(|e| {
1114            CoreError::Generic(format!("Failed to parse checkpoint metadata: {}", e))
1115        })?;
1116
1117        // Load model weights
1118        let weights_path: PathBuf = checkpoint_dir.join(format!("{}.safetensors", name));
1119        let mut model = TrainableSSM::new(model_config, metadata.config.clone())?;
1120        model
1121            .load_weights(&weights_path)
1122            .map_err(|e| CoreError::Generic(format!("Failed to load model weights: {}", e)))?;
1123
1124        // Create trainer with loaded state
1125        let optimizer = model.create_optimizer()?;
1126        let scheduler = Self::create_scheduler(&metadata.config);
1127
1128        let logger = MetricsLogger::new()
1129            .with_verbose(metadata.config.track_metrics)
1130            .with_log_interval(metadata.config.log_interval);
1131
1132        tracing::info!(
1133            "Checkpoint loaded: version={}, step={}, epoch={}",
1134            metadata.version,
1135            metadata.current_step,
1136            metadata.current_epoch
1137        );
1138
1139        Ok(Self {
1140            model,
1141            optimizer,
1142            config: metadata.config,
1143            scheduler,
1144            metrics: metadata.metrics,
1145            logger,
1146            current_step: metadata.current_step,
1147        })
1148    }
1149
1150    /// Save checkpoint with automatic naming (epoch-based)
1151    ///
1152    /// Convenience method that automatically names checkpoints based on current epoch.
1153    ///
1154    /// # Example
1155    /// ```rust,ignore
1156    /// trainer.save_checkpoint_auto("checkpoints")?;
1157    /// // Creates: checkpoints/checkpoint_epoch_5.safetensors, etc.
1158    /// ```
1159    pub fn save_checkpoint_auto<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
1160        let current_epoch = self.metrics.summary().total_epochs;
1161        let name = format!("checkpoint_epoch_{}", current_epoch);
1162        self.save_checkpoint(path, &name)
1163    }
1164
1165    /// Save checkpoint if this is the best epoch (lowest validation loss)
1166    ///
1167    /// Automatically saves a "best" checkpoint when validation loss improves.
1168    ///
1169    /// # Example
1170    /// ```rust,ignore
1171    /// // After each validation epoch
1172    /// trainer.save_best_checkpoint("checkpoints")?;
1173    /// ```
1174    pub fn save_best_checkpoint<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
1175        let summary = self.metrics.summary();
1176
1177        // Only save if this is the best epoch
1178        // Note: total_epochs is 1-indexed (count), best_epoch is 0-indexed (epoch number)
1179        if let (Some(best_epoch), Some(_best_loss)) = (summary.best_epoch, summary.best_val_loss) {
1180            // Current epoch is total_epochs - 1 (convert from count to 0-indexed)
1181            let current_epoch = summary.total_epochs.saturating_sub(1);
1182            if current_epoch == best_epoch {
1183                tracing::info!("New best validation loss! Saving best checkpoint");
1184                return self.save_checkpoint(path, "best");
1185            }
1186        }
1187
1188        Ok(())
1189    }
1190}
1191
1192/// Checkpoint metadata for training state persistence
1193#[derive(Debug, Clone, Serialize, Deserialize)]
1194pub struct CheckpointMetadata {
1195    /// Package version when checkpoint was created
1196    pub version: String,
1197    /// ISO 8601 timestamp
1198    pub timestamp: String,
1199    /// Current training step
1200    pub current_step: usize,
1201    /// Current epoch number
1202    pub current_epoch: usize,
1203    /// Training configuration
1204    pub config: TrainingConfig,
1205    /// Training metrics history
1206    pub metrics: TrainingMetrics,
1207}
1208
1209#[cfg(test)]
1210mod tests {
1211    use super::*;
1212
1213    #[test]
1214    fn test_trainable_ssm_creation() {
1215        let config = KizzasiConfig::new()
1216            .input_dim(3)
1217            .output_dim(3)
1218            .hidden_dim(64)
1219            .state_dim(8)
1220            .num_layers(2);
1221
1222        let training_config = TrainingConfig::default();
1223
1224        let model = TrainableSSM::new(config, training_config);
1225        assert!(model.is_ok());
1226    }
1227
1228    #[test]
1229    fn test_forward_pass() {
1230        let config = KizzasiConfig::new()
1231            .input_dim(3)
1232            .output_dim(3)
1233            .hidden_dim(64)
1234            .state_dim(8)
1235            .num_layers(2);
1236
1237        let training_config = TrainingConfig::default();
1238
1239        let model = TrainableSSM::new(config, training_config).unwrap();
1240        let device = model.device().clone();
1241
1242        // Create dummy input: [batch=2, seq=10, input_dim=3]
1243        let input = Tensor::randn(0f32, 1.0, (2, 10, 3), &device).unwrap();
1244
1245        let output = model.forward(&input);
1246        if let Err(e) = &output {
1247            panic!("Forward pass failed: {:?}", e);
1248        }
1249
1250        let output = output.unwrap();
1251        assert_eq!(output.dims(), &[2, 10, 3]);
1252    }
1253
1254    #[test]
1255    fn test_mse_loss() {
1256        let device = Device::Cpu;
1257        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
1258        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
1259
1260        let loss = Loss::mse(&predictions, &targets).unwrap();
1261        let loss_val = loss.to_vec0::<f32>().unwrap();
1262
1263        // Expected: mean(0.5^2, 0.5^2, 0.5^2) = 0.25
1264        assert!((loss_val - 0.25).abs() < 1e-5);
1265    }
1266
1267    #[test]
1268    fn test_training_config_default() {
1269        let config = TrainingConfig::default();
1270        assert_eq!(config.learning_rate, 1e-4);
1271        assert_eq!(config.batch_size, 32);
1272        assert_eq!(config.epochs, 10);
1273        assert!(config.track_metrics);
1274        assert_eq!(config.validation_split, 0.2);
1275        assert_eq!(config.early_stopping_patience, Some(5));
1276    }
1277
1278    #[test]
1279    fn test_training_config_with_scheduler() {
1280        let config = TrainingConfig::default().with_scheduler(SchedulerType::Cosine {
1281            warmup_steps: 100,
1282            min_lr: 1e-6,
1283        });
1284
1285        assert!(config.scheduler.is_some());
1286        if let Some(SchedulerType::Cosine {
1287            warmup_steps,
1288            min_lr,
1289        }) = config.scheduler
1290        {
1291            assert_eq!(warmup_steps, 100);
1292            assert_eq!(min_lr, 1e-6);
1293        } else {
1294            panic!("Expected Cosine scheduler");
1295        }
1296    }
1297
1298    #[test]
1299    fn test_training_config_builder() {
1300        let config = TrainingConfig::default()
1301            .with_validation_split(0.15)
1302            .with_early_stopping(10)
1303            .without_metrics();
1304
1305        assert_eq!(config.validation_split, 0.15);
1306        assert_eq!(config.early_stopping_patience, Some(10));
1307        assert!(!config.track_metrics);
1308    }
1309
1310    #[test]
1311    fn test_trainer_with_scheduler() {
1312        let model_config = KizzasiConfig::new()
1313            .input_dim(3)
1314            .output_dim(3)
1315            .hidden_dim(64)
1316            .state_dim(8)
1317            .num_layers(2);
1318
1319        let training_config = TrainingConfig::default().with_scheduler(SchedulerType::Linear {
1320            warmup_steps: 50,
1321            final_lr: 1e-6,
1322        });
1323
1324        let model = TrainableSSM::new(model_config, training_config.clone()).unwrap();
1325        let trainer = Trainer::new(model, training_config);
1326
1327        assert!(trainer.is_ok());
1328        let trainer = trainer.unwrap();
1329        assert!(trainer.scheduler.is_some());
1330    }
1331
1332    #[test]
1333    fn test_trainer_metrics_tracking() {
1334        let model_config = KizzasiConfig::new()
1335            .input_dim(3)
1336            .output_dim(3)
1337            .hidden_dim(64)
1338            .state_dim(8)
1339            .num_layers(2);
1340
1341        let training_config = TrainingConfig::default();
1342        let model = TrainableSSM::new(model_config, training_config.clone()).unwrap();
1343        let trainer = Trainer::new(model, training_config).unwrap();
1344
1345        // Check that metrics are initialized
1346        assert_eq!(trainer.metrics().current_step(), 0);
1347        assert_eq!(trainer.current_step(), 0);
1348    }
1349
1350    #[test]
1351    fn test_scheduler_type_constant() {
1352        let config = TrainingConfig::default().with_scheduler(SchedulerType::Constant);
1353
1354        assert!(config.scheduler.is_some());
1355    }
1356
1357    #[test]
1358    fn test_scheduler_type_step() {
1359        let config = TrainingConfig::default().with_scheduler(SchedulerType::Step {
1360            milestones: vec![100, 200, 300],
1361            decay_factor: 0.1,
1362        });
1363
1364        if let Some(SchedulerType::Step {
1365            milestones,
1366            decay_factor,
1367        }) = config.scheduler
1368        {
1369            assert_eq!(milestones, vec![100, 200, 300]);
1370            assert_eq!(decay_factor, 0.1);
1371        } else {
1372            panic!("Expected Step scheduler");
1373        }
1374    }
1375
1376    #[test]
1377    fn test_scheduler_type_onecycle() {
1378        let config =
1379            TrainingConfig::default().with_scheduler(SchedulerType::OneCycle { warmup_pct: 0.3 });
1380
1381        if let Some(SchedulerType::OneCycle { warmup_pct }) = config.scheduler {
1382            assert_eq!(warmup_pct, 0.3);
1383        } else {
1384            panic!("Expected OneCycle scheduler");
1385        }
1386    }
1387
1388    #[test]
1389    fn test_mae_loss() {
1390        let device = Device::Cpu;
1391        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
1392        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
1393
1394        let loss = Loss::mae(&predictions, &targets).unwrap();
1395        let loss_val = loss.to_vec0::<f32>().unwrap();
1396
1397        // Expected: mean(|0.5|, |0.5|, |0.5|) = 0.5
1398        assert!((loss_val - 0.5).abs() < 1e-5);
1399    }
1400
1401    #[test]
1402    fn test_huber_loss() {
1403        let device = Device::Cpu;
1404        let predictions = Tensor::new(&[1.0f32, 2.0, 5.0], &device).unwrap();
1405        let targets = Tensor::new(&[1.1f32, 2.1, 3.0], &device).unwrap();
1406
1407        let loss = Loss::huber(&predictions, &targets, 1.0).unwrap();
1408        let loss_val = loss.to_vec0::<f32>().unwrap();
1409
1410        // Huber loss is smooth L1
1411        assert!(loss_val > 0.0);
1412        assert!(loss_val < 2.0); // Should be less than L1 loss for large errors
1413    }
1414
1415    #[test]
1416    fn test_constraint_loss_creation() {
1417        let constraint_loss = ConstraintLoss::new(0.5);
1418        assert_eq!(constraint_loss.constraint_weight, 0.5);
1419    }
1420
1421    #[test]
1422    fn test_constraint_loss_no_violation() {
1423        let device = Device::Cpu;
1424        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
1425        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
1426
1427        let task_loss = Loss::mse(&predictions, &targets).unwrap();
1428        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();
1429
1430        let constraint_loss = ConstraintLoss::new(0.5);
1431
1432        // No constraint violation
1433        let total_loss = constraint_loss
1434            .compute(&task_loss, &predictions, |_pred| Ok(0.0))
1435            .unwrap();
1436        let total_loss_val = total_loss.to_vec0::<f32>().unwrap();
1437
1438        // Should equal task loss when no violation
1439        assert!((total_loss_val - task_loss_val).abs() < 1e-5);
1440    }
1441
1442    #[test]
1443    fn test_constraint_loss_with_violation() {
1444        let device = Device::Cpu;
1445        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
1446        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
1447
1448        let task_loss = Loss::mse(&predictions, &targets).unwrap();
1449        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();
1450
1451        let constraint_loss = ConstraintLoss::new(0.5);
1452
1453        // Constraint violation of 1.0
1454        let total_loss = constraint_loss
1455            .compute(&task_loss, &predictions, |_pred| Ok(1.0))
1456            .unwrap();
1457        let total_loss_val = total_loss.to_vec0::<f32>().unwrap();
1458
1459        // Should be task_loss + 0.5 * 1.0 = task_loss + 0.5
1460        let expected = task_loss_val + 0.5;
1461        assert!((total_loss_val - expected).abs() < 1e-5);
1462    }
1463
1464    #[test]
1465    fn test_constraint_loss_scaling() {
1466        let device = Device::Cpu;
1467        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
1468        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
1469
1470        let task_loss = Loss::mse(&predictions, &targets).unwrap();
1471        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();
1472
1473        // Test different constraint weights
1474        let weights = [0.1, 0.5, 1.0, 2.0];
1475        let violation = 1.5;
1476
1477        for &weight in &weights {
1478            let constraint_loss = ConstraintLoss::new(weight);
1479            let total_loss = constraint_loss
1480                .compute(&task_loss, &predictions, |_pred| Ok(violation))
1481                .unwrap();
1482            let total_loss_val = total_loss.to_vec0::<f32>().unwrap();
1483
1484            let expected = task_loss_val + weight * violation;
1485            assert!(
1486                (total_loss_val - expected).abs() < 1e-4,
1487                "Weight {} failed: got {}, expected {}",
1488                weight,
1489                total_loss_val,
1490                expected
1491            );
1492        }
1493    }
1494
1495    #[test]
1496    fn test_checkpoint_save_load() {
1497        use std::env;
1498        use std::fs;
1499
1500        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_test");
1501        fs::create_dir_all(&temp_dir).unwrap();
1502
1503        // Create a model
1504        let config = KizzasiConfig::new()
1505            .input_dim(3)
1506            .output_dim(3)
1507            .hidden_dim(64)
1508            .state_dim(8)
1509            .num_layers(2);
1510
1511        let training_config = TrainingConfig {
1512            epochs: 5,
1513            learning_rate: 1e-3,
1514            ..Default::default()
1515        };
1516
1517        let model = TrainableSSM::new(config.clone(), training_config.clone()).unwrap();
1518        let trainer = Trainer::new(model, training_config).unwrap();
1519
1520        // Save checkpoint
1521        trainer
1522            .save_checkpoint(&temp_dir, "test_checkpoint")
1523            .unwrap();
1524
1525        // Verify files exist
1526        assert!(temp_dir.join("test_checkpoint.safetensors").exists());
1527        assert!(temp_dir.join("test_checkpoint.json").exists());
1528
1529        // Load checkpoint
1530        let loaded_trainer =
1531            Trainer::load_checkpoint(&temp_dir, "test_checkpoint", config).unwrap();
1532
1533        // Verify loaded config matches
1534        assert_eq!(loaded_trainer.config.epochs, 5);
1535        assert_eq!(loaded_trainer.config.learning_rate, 1e-3);
1536        assert_eq!(loaded_trainer.current_step, 0);
1537
1538        // Clean up
1539        fs::remove_dir_all(&temp_dir).unwrap();
1540    }
1541
1542    #[test]
1543    fn test_checkpoint_auto_save() {
1544        use std::env;
1545        use std::fs;
1546
1547        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_auto_test");
1548        fs::create_dir_all(&temp_dir).unwrap();
1549
1550        let config = KizzasiConfig::new()
1551            .input_dim(3)
1552            .output_dim(3)
1553            .hidden_dim(64)
1554            .state_dim(8)
1555            .num_layers(2);
1556
1557        let training_config = TrainingConfig::default();
1558        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
1559        let mut trainer = Trainer::new(model, training_config).unwrap();
1560
1561        // Record some metrics to simulate training
1562        trainer.metrics.record_train_loss(0, 0.5);
1563
1564        // Save checkpoint with auto naming
1565        trainer.save_checkpoint_auto(&temp_dir).unwrap();
1566
1567        // Verify file exists with auto-generated name
1568        assert!(temp_dir.join("checkpoint_epoch_1.safetensors").exists());
1569        assert!(temp_dir.join("checkpoint_epoch_1.json").exists());
1570
1571        // Clean up
1572        fs::remove_dir_all(&temp_dir).unwrap();
1573    }
1574
1575    #[test]
1576    fn test_checkpoint_best_save() {
1577        use std::env;
1578        use std::fs;
1579
1580        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_best_test");
1581        fs::create_dir_all(&temp_dir).unwrap();
1582
1583        let config = KizzasiConfig::new()
1584            .input_dim(3)
1585            .output_dim(3)
1586            .hidden_dim(64)
1587            .state_dim(8)
1588            .num_layers(2);
1589
1590        let training_config = TrainingConfig::default();
1591        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
1592        let mut trainer = Trainer::new(model, training_config).unwrap();
1593
1594        // Simulate training epoch 0 (not best yet)
1595        trainer.metrics.record_train_loss(0, 1.2);
1596        trainer.metrics.record_val_loss(0, 1.0);
1597        trainer.save_best_checkpoint(&temp_dir).unwrap();
1598
1599        // Epoch 0 is the best so far, so checkpoint should be saved
1600        assert!(temp_dir.join("best.safetensors").exists());
1601        assert!(temp_dir.join("best.json").exists());
1602
1603        // Simulate training epoch 1 with worse loss (should not overwrite)
1604        trainer.metrics.record_train_loss(1, 0.9);
1605        trainer.metrics.record_val_loss(1, 1.2);
1606
1607        // Remove old best to test that it doesn't get overwritten
1608        fs::remove_file(temp_dir.join("best.safetensors")).unwrap();
1609        fs::remove_file(temp_dir.join("best.json")).unwrap();
1610
1611        trainer.save_best_checkpoint(&temp_dir).unwrap();
1612        // Should not save because epoch 1 is not the best
1613        assert!(!temp_dir.join("best.safetensors").exists());
1614
1615        // Clean up
1616        fs::remove_dir_all(&temp_dir).unwrap();
1617    }
1618
1619    #[test]
1620    fn test_checkpoint_metadata() {
1621        use std::env;
1622        use std::fs;
1623
1624        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_metadata_test");
1625        fs::create_dir_all(&temp_dir).unwrap();
1626
1627        let config = KizzasiConfig::new()
1628            .input_dim(3)
1629            .output_dim(3)
1630            .hidden_dim(64)
1631            .state_dim(8)
1632            .num_layers(2);
1633
1634        let training_config = TrainingConfig::default();
1635        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
1636        let mut trainer = Trainer::new(model, training_config).unwrap();
1637
1638        // Add some metrics
1639        trainer.metrics.record_train_loss(0, 0.5);
1640        trainer.metrics.record_val_loss(0, 0.45);
1641
1642        // Save checkpoint
1643        trainer.save_checkpoint(&temp_dir, "metadata_test").unwrap();
1644
1645        // Load and verify metadata
1646        let metadata_path = temp_dir.join("metadata_test.json");
1647        let metadata_json = fs::read_to_string(&metadata_path).unwrap();
1648        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json).unwrap();
1649
1650        assert_eq!(metadata.version, env!("CARGO_PKG_VERSION"));
1651        assert!(!metadata.timestamp.is_empty());
1652        assert_eq!(metadata.current_step, 0);
1653        assert!(metadata.metrics.val_loss(0).is_some());
1654        assert_eq!(metadata.metrics.val_loss(0).unwrap(), 0.45);
1655
1656        // Clean up
1657        fs::remove_dir_all(&temp_dir).unwrap();
1658    }
1659}