Skip to main content

kizzasi_core/
training_loop.rs

1//! Training loop — ConstraintLoss, Loss, Trainer, CheckpointMetadata
2//!
3//! This module contains the high-level training utilities:
4//!
5//! - [`ConstraintLoss`] — bridges kizzasi-logic constraints with candle tensor ops
6//! - [`Loss`] — MSE, MAE, Huber, and cross-entropy loss functions
7//! - [`Trainer`] — full training loop with scheduler, metrics, validation, and checkpointing
8//! - [`CheckpointMetadata`] — serialisable checkpoint state for training resumption
9
10use crate::config::KizzasiConfig;
11use crate::dataloader::TimeSeriesDataLoader;
12use crate::error::{CoreError, CoreResult};
13use crate::metrics::{MetricsLogger, TrainingMetrics};
14use crate::scheduler::LRScheduler;
15use crate::training_core::{SchedulerType, TrainableSSM, TrainingConfig};
16use candle_core::Tensor;
17use candle_nn::{AdamW, Optimizer};
18use serde::{Deserialize, Serialize};
19
/// Constraint-aware loss wrapper
///
/// Bridges kizzasi-logic constraints with candle tensor operations.
/// Allows combining task loss with constraint violations for constrained optimization.
///
/// The violation is reported by a user callback as a plain `f32`, so the weighted
/// penalty is added to the task loss as a constant offset: it shifts the loss value
/// but does not itself contribute gradients with respect to the prediction.
///
/// # Examples
///
/// ```rust,ignore
/// use kizzasi_core::{ConstraintLoss, Loss};
///
/// let constraint_loss = ConstraintLoss::new(0.1);
///
/// // In training loop:
/// let task_loss = Loss::mse(&predictions, &targets)?;
/// let total_loss = constraint_loss.compute(&task_loss, &predictions, |_pred| {
///     // Compute constraint violation from prediction
///     Ok(0.0)
/// })?;
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConstraintLoss {
    /// Base weight for constraint violations
    pub(crate) constraint_weight: f32,
}
43
44impl ConstraintLoss {
45    /// Create a new constraint-aware loss
46    pub fn new(constraint_weight: f32) -> Self {
47        Self { constraint_weight }
48    }
49
50    /// Compute combined loss: task_loss + constraint_weight * constraint_penalty
51    ///
52    /// # Arguments
53    /// * `task_loss` - Base task loss (MSE, MAE, etc.)
54    /// * `prediction` - Model prediction tensor
55    /// * `constraint_fn` - Function that computes constraint violation from prediction
56    pub fn compute<F>(
57        &self,
58        task_loss: &Tensor,
59        prediction: &Tensor,
60        constraint_fn: F,
61    ) -> CoreResult<Tensor>
62    where
63        F: Fn(&Tensor) -> CoreResult<f32>,
64    {
65        // Compute constraint violation
66        let violation = constraint_fn(prediction)?;
67
68        // Add constraint penalty to task loss
69        // Create a scalar penalty value matching task_loss shape
70        let penalty_value = self.constraint_weight * violation;
71
72        // Use affine to add the penalty: task_loss + penalty = task_loss * 1.0 + penalty
73        task_loss
74            .affine(1.0, penalty_value as f64)
75            .map_err(|e| CoreError::Generic(format!("Failed to add constraint penalty: {}", e)))
76    }
77}
78
79/// Loss functions for training
80pub struct Loss;
81
82impl Loss {
83    /// Mean Squared Error loss
84    pub fn mse(predictions: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
85        predictions
86            .sub(targets)
87            .map_err(|e| CoreError::Generic(format!("MSE subtraction failed: {}", e)))?
88            .sqr()
89            .map_err(|e| CoreError::Generic(format!("MSE square failed: {}", e)))?
90            .mean_all()
91            .map_err(|e| CoreError::Generic(format!("MSE mean failed: {}", e)))
92    }
93
94    /// Mean Absolute Error loss
95    pub fn mae(predictions: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
96        predictions
97            .sub(targets)
98            .map_err(|e| CoreError::Generic(format!("MAE subtraction failed: {}", e)))?
99            .abs()
100            .map_err(|e| CoreError::Generic(format!("MAE abs failed: {}", e)))?
101            .mean_all()
102            .map_err(|e| CoreError::Generic(format!("MAE mean failed: {}", e)))
103    }
104
105    /// Huber loss (smooth L1 loss)
106    pub fn huber(predictions: &Tensor, targets: &Tensor, delta: f64) -> CoreResult<Tensor> {
107        let diff = predictions
108            .sub(targets)
109            .map_err(|e| CoreError::Generic(format!("Huber subtraction failed: {}", e)))?;
110        let abs_diff = diff
111            .abs()
112            .map_err(|e| CoreError::Generic(format!("Huber abs failed: {}", e)))?;
113
114        // If |diff| <= delta: 0.5 * diff^2
115        // If |diff| > delta: delta * (|diff| - 0.5 * delta)
116        let squared = diff
117            .sqr()
118            .map_err(|e| CoreError::Generic(format!("Huber square failed: {}", e)))?
119            .affine(0.5, 0.0)
120            .map_err(|e| CoreError::Generic(format!("Huber mul 0.5 failed: {}", e)))?;
121
122        let linear_offset = delta * delta * 0.5;
123        let linear = abs_diff
124            .affine(delta, -linear_offset)
125            .map_err(|e| CoreError::Generic(format!("Huber linear computation failed: {}", e)))?;
126
127        let mask = abs_diff
128            .le(delta)
129            .map_err(|e| CoreError::Generic(format!("Huber comparison failed: {}", e)))?
130            .to_dtype(predictions.dtype())
131            .map_err(|e| CoreError::Generic(format!("Huber mask conversion failed: {}", e)))?;
132
133        // Invert mask: 1 - mask
134        let inv_mask = mask
135            .affine(-1.0, 1.0)
136            .map_err(|e| CoreError::Generic(format!("Huber mask inversion failed: {}", e)))?;
137
138        let loss = squared
139            .mul(&mask)
140            .map_err(|e| CoreError::Generic(format!("Huber squared mul failed: {}", e)))?
141            .add(
142                &linear
143                    .mul(&inv_mask)
144                    .map_err(|e| CoreError::Generic(format!("Huber linear mul failed: {}", e)))?,
145            )
146            .map_err(|e| CoreError::Generic(format!("Huber final add failed: {}", e)))?;
147
148        loss.mean_all()
149            .map_err(|e| CoreError::Generic(format!("Huber mean failed: {}", e)))
150    }
151
152    /// Cross-entropy loss for classification
153    pub fn cross_entropy(logits: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
154        // Log softmax
155        let log_probs = candle_nn::ops::log_softmax(logits, candle_core::D::Minus1)
156            .map_err(|e| CoreError::Generic(format!("Log softmax failed: {}", e)))?;
157
158        // Negative log likelihood
159        let nll = log_probs
160            .mul(targets)
161            .map_err(|e| CoreError::Generic(format!("NLL multiplication failed: {}", e)))?
162            .sum_all()
163            .map_err(|e| CoreError::Generic(format!("NLL sum failed: {}", e)))?
164            .neg()
165            .map_err(|e| CoreError::Generic(format!("NLL negation failed: {}", e)))?;
166
167        // Average over batch
168        let batch_size = logits
169            .dim(0)
170            .map_err(|e| CoreError::Generic(format!("Failed to get batch size: {}", e)))?;
171        nll.affine(1.0 / batch_size as f64, 0.0)
172            .map_err(|e| CoreError::Generic(format!("Cross entropy division failed: {}", e)))
173    }
174}
175
176/// Training utilities with scheduler, metrics, and validation
177pub struct Trainer {
178    pub(crate) model: TrainableSSM,
179    pub(crate) optimizer: AdamW,
180    pub(crate) config: TrainingConfig,
181    pub(crate) scheduler: Option<Box<dyn LRScheduler>>,
182    pub(crate) metrics: TrainingMetrics,
183    pub(crate) logger: MetricsLogger,
184    pub(crate) current_step: usize,
185}
186
187impl Trainer {
188    /// Create a new trainer
189    pub fn new(model: TrainableSSM, config: TrainingConfig) -> CoreResult<Self> {
190        let optimizer = model.create_optimizer()?;
191
192        // Create scheduler based on config
193        let scheduler = Self::create_scheduler(&config);
194
195        let metrics = TrainingMetrics::new();
196
197        let logger = MetricsLogger::new()
198            .with_verbose(config.track_metrics)
199            .with_log_interval(config.log_interval);
200
201        Ok(Self {
202            model,
203            optimizer,
204            config,
205            scheduler,
206            metrics,
207            logger,
208            current_step: 0,
209        })
210    }
211
212    /// Create scheduler from config
213    fn create_scheduler(config: &TrainingConfig) -> Option<Box<dyn LRScheduler>> {
214        use crate::scheduler::*;
215
216        config.scheduler.as_ref().map(|sched_type| {
217            let total_steps = config.epochs * 100; // Rough estimate, can be updated later
218
219            match sched_type {
220                SchedulerType::Constant => {
221                    Box::new(ConstantScheduler::new(config.learning_rate)) as Box<dyn LRScheduler>
222                }
223                SchedulerType::Linear {
224                    warmup_steps,
225                    final_lr,
226                } => Box::new(LinearScheduler::new(
227                    config.learning_rate,
228                    *final_lr,
229                    total_steps,
230                    *warmup_steps,
231                )) as Box<dyn LRScheduler>,
232                SchedulerType::Cosine {
233                    warmup_steps,
234                    min_lr,
235                } => Box::new(
236                    CosineScheduler::new(config.learning_rate, total_steps, *warmup_steps)
237                        .with_min_lr(*min_lr),
238                ) as Box<dyn LRScheduler>,
239                SchedulerType::Step {
240                    milestones,
241                    decay_factor,
242                } => Box::new(StepScheduler::new(
243                    config.learning_rate,
244                    *decay_factor,
245                    milestones.clone(),
246                )) as Box<dyn LRScheduler>,
247                SchedulerType::Exponential {
248                    decay_rate,
249                    decay_steps,
250                } => Box::new(ExponentialScheduler::new(
251                    config.learning_rate,
252                    *decay_rate,
253                    *decay_steps,
254                )) as Box<dyn LRScheduler>,
255                SchedulerType::OneCycle { warmup_pct } => Box::new(
256                    OneCycleScheduler::new(config.learning_rate, total_steps)
257                        .with_warmup_pct(*warmup_pct),
258                ) as Box<dyn LRScheduler>,
259                SchedulerType::Polynomial { final_lr, power } => Box::new(PolynomialScheduler::new(
260                    config.learning_rate,
261                    *final_lr,
262                    total_steps,
263                    *power,
264                ))
265                    as Box<dyn LRScheduler>,
266            }
267        })
268    }
269
270    /// Get current learning rate
271    fn get_current_lr(&self) -> f64 {
272        self.scheduler
273            .as_ref()
274            .map(|s| s.get_lr(self.current_step))
275            .unwrap_or(self.config.learning_rate)
276    }
277
278    /// Train for one epoch
279    pub fn train_epoch<F>(
280        &mut self,
281        data_loader: &[(Tensor, Tensor)],
282        loss_fn: F,
283    ) -> CoreResult<f32>
284    where
285        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor>,
286    {
287        let mut total_loss = 0.0;
288        let num_batches = data_loader.len();
289        let epoch = self.current_step / num_batches.max(1);
290
291        for (batch_idx, (inputs, targets)) in data_loader.iter().enumerate() {
292            // Update learning rate from scheduler
293            let lr = self.get_current_lr();
294            if self.config.track_metrics {
295                self.metrics.record_learning_rate(lr);
296            }
297
298            // Forward pass
299            let predictions = self.model.forward(inputs)?;
300
301            // Compute loss
302            let loss = loss_fn(&predictions, targets)?;
303
304            // Backward pass
305            self.optimizer
306                .backward_step(&loss)
307                .map_err(|e| CoreError::Generic(format!("Backward step failed: {}", e)))?;
308
309            // Accumulate loss
310            let loss_val = loss
311                .to_vec0::<f32>()
312                .map_err(|e| CoreError::Generic(format!("Failed to extract loss value: {}", e)))?;
313            total_loss += loss_val;
314
315            // Track metrics
316            if self.config.track_metrics {
317                self.metrics.record_train_loss(epoch, loss_val);
318                self.logger.log_batch(epoch, batch_idx, loss_val);
319
320                // Compute and track gradient norm
321                let grad_norm = self.compute_grad_norm()?;
322                self.metrics.record_grad_norm(grad_norm);
323            }
324
325            // Gradient clipping if enabled
326            if let Some(max_norm) = self.config.grad_clip {
327                self.clip_gradients(max_norm)?;
328            }
329
330            self.current_step += 1;
331        }
332
333        Ok(total_loss / num_batches as f32)
334    }
335
336    /// Compute gradient norm
337    fn compute_grad_norm(&self) -> CoreResult<f32> {
338        // Placeholder: In candle, gradient norms would be computed from VarMap
339        // For now, return a dummy value
340        // TODO: Implement proper gradient norm computation when candle exposes gradient access
341        Ok(1.0)
342    }
343
344    /// Clip gradients by global norm
345    ///
346    /// Note: Gradient clipping is handled internally by candle's optimizer.
347    /// This is a placeholder for custom gradient clipping if needed.
348    fn clip_gradients(&self, _max_norm: f32) -> CoreResult<()> {
349        // Gradient clipping will be handled by the optimizer's built-in mechanism
350        // or via custom backward hooks in future implementations
351        Ok(())
352    }
353
354    /// Evaluate on validation data
355    pub fn evaluate<F>(&self, data_loader: &[(Tensor, Tensor)], loss_fn: F) -> CoreResult<f32>
356    where
357        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor>,
358    {
359        let mut total_loss = 0.0;
360        let num_batches = data_loader.len();
361
362        for (inputs, targets) in data_loader {
363            // Forward pass (no gradient tracking needed)
364            let predictions = self.model.forward(inputs)?;
365
366            // Compute loss
367            let loss = loss_fn(&predictions, targets)?;
368
369            // Accumulate loss
370            let loss_val = loss
371                .to_vec0::<f32>()
372                .map_err(|e| CoreError::Generic(format!("Failed to extract loss value: {}", e)))?;
373            total_loss += loss_val;
374        }
375
376        Ok(total_loss / num_batches as f32)
377    }
378
379    /// Full training loop with validation and early stopping
380    pub fn fit<F>(
381        &mut self,
382        mut train_loader: TimeSeriesDataLoader,
383        mut val_loader: Option<TimeSeriesDataLoader>,
384        loss_fn: F,
385    ) -> CoreResult<()>
386    where
387        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor> + Copy,
388    {
389        use std::time::Instant;
390
391        for epoch in 0..self.config.epochs {
392            let epoch_start = Instant::now();
393
394            // Shuffle training data
395            train_loader.shuffle();
396
397            // Prepare batches (simplified - actual implementation would iterate batches)
398            // For now, this is a placeholder for the integration
399            // TODO: Implement proper batch iteration with TimeSeriesDataLoader
400            let train_batches: Vec<(Tensor, Tensor)> = Vec::new();
401
402            // Train for one epoch
403            let train_loss = self.train_epoch(&train_batches, loss_fn)?;
404
405            // Validation
406            let val_loss = if let Some(ref mut _val_data) = val_loader {
407                let val_batches: Vec<(Tensor, Tensor)> = Vec::new();
408                let val_loss = self.evaluate(&val_batches, loss_fn)?;
409
410                if self.config.track_metrics {
411                    self.metrics.record_val_loss(epoch, val_loss);
412                }
413
414                Some(val_loss)
415            } else {
416                None
417            };
418
419            // Track epoch duration
420            let epoch_duration = epoch_start.elapsed().as_secs_f64();
421            if self.config.track_metrics {
422                self.metrics.record_epoch_duration(epoch, epoch_duration);
423            }
424
425            // Log epoch metrics
426            let current_lr = self.get_current_lr();
427            self.logger
428                .log_epoch(epoch, train_loss, val_loss, current_lr);
429
430            // Early stopping check
431            if let Some(patience) = self.config.early_stopping_patience {
432                if !self.metrics.is_improving(patience) {
433                    tracing::info!("Early stopping triggered at epoch {}", epoch);
434                    break;
435                }
436            }
437        }
438
439        // Log training summary
440        if self.config.track_metrics {
441            let summary = self.metrics.summary();
442            self.logger.log_summary(&summary);
443        }
444
445        Ok(())
446    }
447
448    /// Get reference to the model
449    pub fn model(&self) -> &TrainableSSM {
450        &self.model
451    }
452
453    /// Get mutable reference to the model
454    pub fn model_mut(&mut self) -> &mut TrainableSSM {
455        &mut self.model
456    }
457
458    /// Get reference to training metrics
459    pub fn metrics(&self) -> &TrainingMetrics {
460        &self.metrics
461    }
462
463    /// Get mutable reference to training metrics
464    pub fn metrics_mut(&mut self) -> &mut TrainingMetrics {
465        &mut self.metrics
466    }
467
468    /// Get current training step
469    pub fn current_step(&self) -> usize {
470        self.current_step
471    }
472
473    /// Save checkpoint to disk
474    ///
475    /// Saves model weights, optimizer state, training configuration, metrics, and metadata.
476    ///
477    /// # Arguments
478    /// * `path` - Directory to save checkpoint files
479    /// * `name` - Checkpoint name (without extension)
480    ///
481    /// # Example
482    /// ```rust,ignore
483    /// trainer.save_checkpoint("checkpoints", "epoch_10")?;
484    /// // Creates: checkpoints/epoch_10.safetensors and checkpoints/epoch_10.json
485    /// ```
486    pub fn save_checkpoint<P: AsRef<std::path::Path>>(
487        &self,
488        path: P,
489        name: &str,
490    ) -> CoreResult<()> {
491        use std::fs;
492        use std::path::PathBuf;
493
494        let checkpoint_dir = path.as_ref();
495        fs::create_dir_all(checkpoint_dir).map_err(|e| {
496            CoreError::Generic(format!("Failed to create checkpoint directory: {}", e))
497        })?;
498
499        // Save model weights to safetensors
500        let weights_path: PathBuf = checkpoint_dir.join(format!("{}.safetensors", name));
501        self.model
502            .save_weights(&weights_path)
503            .map_err(|e| CoreError::Generic(format!("Failed to save model weights: {}", e)))?;
504
505        // Create checkpoint metadata
506        let metadata = CheckpointMetadata {
507            version: env!("CARGO_PKG_VERSION").to_string(),
508            timestamp: chrono::Utc::now().to_rfc3339(),
509            current_step: self.current_step,
510            current_epoch: self.metrics.summary().total_epochs,
511            config: self.config.clone(),
512            metrics: self.metrics.clone(),
513        };
514
515        // Save metadata to JSON
516        let metadata_path: PathBuf = checkpoint_dir.join(format!("{}.json", name));
517        let metadata_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
518            CoreError::Generic(format!("Failed to serialize checkpoint metadata: {}", e))
519        })?;
520
521        fs::write(&metadata_path, metadata_json).map_err(|e| {
522            CoreError::Generic(format!("Failed to write checkpoint metadata: {}", e))
523        })?;
524
525        tracing::info!(
526            "Checkpoint saved: weights={}, metadata={}",
527            weights_path.display(),
528            metadata_path.display()
529        );
530
531        Ok(())
532    }
533
534    /// Load checkpoint and resume training
535    ///
536    /// Creates a new Trainer from a saved checkpoint, restoring model weights,
537    /// configuration, and training state.
538    ///
539    /// # Arguments
540    /// * `path` - Directory containing checkpoint files
541    /// * `name` - Checkpoint name (without extension)
542    /// * `model_config` - Model configuration (must match saved model)
543    ///
544    /// # Example
545    /// ```rust,ignore
546    /// let trainer = Trainer::load_checkpoint("checkpoints", "epoch_10", model_config)?;
547    /// // Continue training from epoch 10
548    /// ```
549    pub fn load_checkpoint<P: AsRef<std::path::Path>>(
550        path: P,
551        name: &str,
552        model_config: KizzasiConfig,
553    ) -> CoreResult<Self> {
554        use std::fs;
555        use std::path::PathBuf;
556
557        let checkpoint_dir = path.as_ref();
558
559        // Load metadata from JSON
560        let metadata_path: PathBuf = checkpoint_dir.join(format!("{}.json", name));
561        let metadata_json = fs::read_to_string(&metadata_path).map_err(|e| {
562            CoreError::Generic(format!("Failed to read checkpoint metadata: {}", e))
563        })?;
564
565        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json).map_err(|e| {
566            CoreError::Generic(format!("Failed to parse checkpoint metadata: {}", e))
567        })?;
568
569        // Load model weights
570        let weights_path: PathBuf = checkpoint_dir.join(format!("{}.safetensors", name));
571        let mut model = TrainableSSM::new(model_config, metadata.config.clone())?;
572        model
573            .load_weights(&weights_path)
574            .map_err(|e| CoreError::Generic(format!("Failed to load model weights: {}", e)))?;
575
576        // Create trainer with loaded state
577        let optimizer = model.create_optimizer()?;
578        let scheduler = Self::create_scheduler(&metadata.config);
579
580        let logger = MetricsLogger::new()
581            .with_verbose(metadata.config.track_metrics)
582            .with_log_interval(metadata.config.log_interval);
583
584        tracing::info!(
585            "Checkpoint loaded: version={}, step={}, epoch={}",
586            metadata.version,
587            metadata.current_step,
588            metadata.current_epoch
589        );
590
591        Ok(Self {
592            model,
593            optimizer,
594            config: metadata.config,
595            scheduler,
596            metrics: metadata.metrics,
597            logger,
598            current_step: metadata.current_step,
599        })
600    }
601
602    /// Save checkpoint with automatic naming (epoch-based)
603    ///
604    /// Convenience method that automatically names checkpoints based on current epoch.
605    ///
606    /// # Example
607    /// ```rust,ignore
608    /// trainer.save_checkpoint_auto("checkpoints")?;
609    /// // Creates: checkpoints/checkpoint_epoch_5.safetensors, etc.
610    /// ```
611    pub fn save_checkpoint_auto<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
612        let current_epoch = self.metrics.summary().total_epochs;
613        let name = format!("checkpoint_epoch_{}", current_epoch);
614        self.save_checkpoint(path, &name)
615    }
616
617    /// Save checkpoint if this is the best epoch (lowest validation loss)
618    ///
619    /// Automatically saves a "best" checkpoint when validation loss improves.
620    ///
621    /// # Example
622    /// ```rust,ignore
623    /// // After each validation epoch
624    /// trainer.save_best_checkpoint("checkpoints")?;
625    /// ```
626    pub fn save_best_checkpoint<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
627        let summary = self.metrics.summary();
628
629        // Only save if this is the best epoch
630        // Note: total_epochs is 1-indexed (count), best_epoch is 0-indexed (epoch number)
631        if let (Some(best_epoch), Some(_best_loss)) = (summary.best_epoch, summary.best_val_loss) {
632            // Current epoch is total_epochs - 1 (convert from count to 0-indexed)
633            let current_epoch = summary.total_epochs.saturating_sub(1);
634            if current_epoch == best_epoch {
635                tracing::info!("New best validation loss! Saving best checkpoint");
636                return self.save_checkpoint(path, "best");
637            }
638        }
639
640        Ok(())
641    }
642}
643
/// Checkpoint metadata for training state persistence
///
/// Written as `<name>.json` next to the `<name>.safetensors` weights file by
/// `Trainer::save_checkpoint`, and read back by `Trainer::load_checkpoint`.
/// Field order determines the order of keys in the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Package version (`CARGO_PKG_VERSION`) when checkpoint was created
    pub version: String,
    /// RFC 3339 / ISO 8601 timestamp (UTC) of when the checkpoint was written
    pub timestamp: String,
    /// Current training step (number of batches processed by the trainer)
    pub current_step: usize,
    /// Epoch count taken from the metrics summary's `total_epochs` at save time
    pub current_epoch: usize,
    /// Training configuration
    pub config: TrainingConfig,
    /// Training metrics history
    pub metrics: TrainingMetrics,
}
660
661#[cfg(test)]
662mod tests {
663    use super::*;
664    use crate::training_core::TrainingConfig;
665    use candle_core::{Device, Tensor};
666
667    #[test]
668    fn test_mse_loss() {
669        let device = Device::Cpu;
670        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
671        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
672
673        let loss = Loss::mse(&predictions, &targets).unwrap();
674        let loss_val = loss.to_vec0::<f32>().unwrap();
675
676        // Expected: mean((0.5^2 + 0.5^2 + 0.5^2)) = 0.25
677        assert!((loss_val - 0.25).abs() < 1e-5);
678    }
679
680    #[test]
681    fn test_trainer_with_scheduler() {
682        use crate::config::KizzasiConfig;
683        use crate::training_core::{SchedulerType, TrainableSSM};
684
685        let model_config = KizzasiConfig::new()
686            .input_dim(3)
687            .output_dim(3)
688            .hidden_dim(64)
689            .state_dim(8)
690            .num_layers(2);
691
692        let training_config = TrainingConfig::default().with_scheduler(SchedulerType::Linear {
693            warmup_steps: 50,
694            final_lr: 1e-6,
695        });
696
697        let model = TrainableSSM::new(model_config, training_config.clone()).unwrap();
698        let trainer = Trainer::new(model, training_config);
699
700        assert!(trainer.is_ok());
701        let trainer = trainer.unwrap();
702        assert!(trainer.scheduler.is_some());
703    }
704
705    #[test]
706    fn test_trainer_metrics_tracking() {
707        use crate::config::KizzasiConfig;
708        use crate::training_core::TrainableSSM;
709
710        let model_config = KizzasiConfig::new()
711            .input_dim(3)
712            .output_dim(3)
713            .hidden_dim(64)
714            .state_dim(8)
715            .num_layers(2);
716
717        let training_config = TrainingConfig::default();
718        let model = TrainableSSM::new(model_config, training_config.clone()).unwrap();
719        let trainer = Trainer::new(model, training_config).unwrap();
720
721        // Check that metrics are initialized
722        assert_eq!(trainer.metrics().current_step(), 0);
723        assert_eq!(trainer.current_step(), 0);
724    }
725
726    #[test]
727    fn test_mae_loss() {
728        let device = Device::Cpu;
729        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
730        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
731
732        let loss = Loss::mae(&predictions, &targets).unwrap();
733        let loss_val = loss.to_vec0::<f32>().unwrap();
734
735        // Expected: mean(|0.5| + |0.5| + |0.5|) = 0.5
736        assert!((loss_val - 0.5).abs() < 1e-5);
737    }
738
739    #[test]
740    fn test_huber_loss() {
741        let device = Device::Cpu;
742        let predictions = Tensor::new(&[1.0f32, 2.0, 5.0], &device).unwrap();
743        let targets = Tensor::new(&[1.1f32, 2.1, 3.0], &device).unwrap();
744
745        let loss = Loss::huber(&predictions, &targets, 1.0).unwrap();
746        let loss_val = loss.to_vec0::<f32>().unwrap();
747
748        // Huber loss is smooth L1
749        assert!(loss_val > 0.0);
750        assert!(loss_val < 2.0); // Should be less than L1 loss for large errors
751    }
752
753    #[test]
754    fn test_constraint_loss_creation() {
755        let constraint_loss = ConstraintLoss::new(0.5);
756        assert_eq!(constraint_loss.constraint_weight, 0.5);
757    }
758
759    #[test]
760    fn test_constraint_loss_no_violation() {
761        let device = Device::Cpu;
762        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
763        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
764
765        let task_loss = Loss::mse(&predictions, &targets).unwrap();
766        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();
767
768        let constraint_loss = ConstraintLoss::new(0.5);
769
770        // No constraint violation
771        let total_loss = constraint_loss
772            .compute(&task_loss, &predictions, |_pred| Ok(0.0))
773            .unwrap();
774        let total_loss_val = total_loss.to_vec0::<f32>().unwrap();
775
776        // Should equal task loss when no violation
777        assert!((total_loss_val - task_loss_val).abs() < 1e-5);
778    }
779
780    #[test]
781    fn test_constraint_loss_with_violation() {
782        let device = Device::Cpu;
783        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
784        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
785
786        let task_loss = Loss::mse(&predictions, &targets).unwrap();
787        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();
788
789        let constraint_loss = ConstraintLoss::new(0.5);
790
791        // Constraint violation of 1.0
792        let total_loss = constraint_loss
793            .compute(&task_loss, &predictions, |_pred| Ok(1.0))
794            .unwrap();
795        let total_loss_val = total_loss.to_vec0::<f32>().unwrap();
796
797        // Should be task_loss + 0.5 * 1.0 = task_loss + 0.5
798        let expected = task_loss_val + 0.5;
799        assert!((total_loss_val - expected).abs() < 1e-5);
800    }
801
802    #[test]
803    fn test_constraint_loss_scaling() {
804        let device = Device::Cpu;
805        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
806        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();
807
808        let task_loss = Loss::mse(&predictions, &targets).unwrap();
809        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();
810
811        // Test different constraint weights
812        let weights = [0.1, 0.5, 1.0, 2.0];
813        let violation = 1.5;
814
815        for &weight in &weights {
816            let constraint_loss = ConstraintLoss::new(weight);
817            let total_loss = constraint_loss
818                .compute(&task_loss, &predictions, |_pred| Ok(violation))
819                .unwrap();
820            let total_loss_val = total_loss.to_vec0::<f32>().unwrap();
821
822            let expected = task_loss_val + weight * violation;
823            assert!(
824                (total_loss_val - expected).abs() < 1e-4,
825                "Weight {} failed: got {}, expected {}",
826                weight,
827                total_loss_val,
828                expected
829            );
830        }
831    }
832
833    #[test]
834    fn test_checkpoint_save_load() {
835        use crate::config::KizzasiConfig;
836        use crate::training_core::TrainableSSM;
837        use std::env;
838        use std::fs;
839
840        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_test");
841        fs::create_dir_all(&temp_dir).unwrap();
842
843        // Create a model
844        let config = KizzasiConfig::new()
845            .input_dim(3)
846            .output_dim(3)
847            .hidden_dim(64)
848            .state_dim(8)
849            .num_layers(2);
850
851        let training_config = TrainingConfig {
852            epochs: 5,
853            learning_rate: 1e-3,
854            ..Default::default()
855        };
856
857        let model = TrainableSSM::new(config.clone(), training_config.clone()).unwrap();
858        let trainer = Trainer::new(model, training_config).unwrap();
859
860        // Save checkpoint
861        trainer
862            .save_checkpoint(&temp_dir, "test_checkpoint")
863            .unwrap();
864
865        // Verify files exist
866        assert!(temp_dir.join("test_checkpoint.safetensors").exists());
867        assert!(temp_dir.join("test_checkpoint.json").exists());
868
869        // Load checkpoint
870        let loaded_trainer =
871            Trainer::load_checkpoint(&temp_dir, "test_checkpoint", config).unwrap();
872
873        // Verify loaded config matches
874        assert_eq!(loaded_trainer.config.epochs, 5);
875        assert_eq!(loaded_trainer.config.learning_rate, 1e-3);
876        assert_eq!(loaded_trainer.current_step, 0);
877
878        // Clean up
879        fs::remove_dir_all(&temp_dir).unwrap();
880    }
881
882    #[test]
883    fn test_checkpoint_auto_save() {
884        use crate::config::KizzasiConfig;
885        use crate::training_core::TrainableSSM;
886        use std::env;
887        use std::fs;
888
889        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_auto_test");
890        fs::create_dir_all(&temp_dir).unwrap();
891
892        let config = KizzasiConfig::new()
893            .input_dim(3)
894            .output_dim(3)
895            .hidden_dim(64)
896            .state_dim(8)
897            .num_layers(2);
898
899        let training_config = TrainingConfig::default();
900        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
901        let mut trainer = Trainer::new(model, training_config).unwrap();
902
903        // Record some metrics to simulate training
904        trainer.metrics.record_train_loss(0, 0.5);
905
906        // Save checkpoint with auto naming
907        trainer.save_checkpoint_auto(&temp_dir).unwrap();
908
909        // Verify file exists with auto-generated name
910        assert!(temp_dir.join("checkpoint_epoch_1.safetensors").exists());
911        assert!(temp_dir.join("checkpoint_epoch_1.json").exists());
912
913        // Clean up
914        fs::remove_dir_all(&temp_dir).unwrap();
915    }
916
917    #[test]
918    fn test_checkpoint_best_save() {
919        use crate::config::KizzasiConfig;
920        use crate::training_core::TrainableSSM;
921        use std::env;
922        use std::fs;
923
924        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_best_test");
925        fs::create_dir_all(&temp_dir).unwrap();
926
927        let config = KizzasiConfig::new()
928            .input_dim(3)
929            .output_dim(3)
930            .hidden_dim(64)
931            .state_dim(8)
932            .num_layers(2);
933
934        let training_config = TrainingConfig::default();
935        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
936        let mut trainer = Trainer::new(model, training_config).unwrap();
937
938        // Simulate training epoch 0 (not best yet)
939        trainer.metrics.record_train_loss(0, 1.2);
940        trainer.metrics.record_val_loss(0, 1.0);
941        trainer.save_best_checkpoint(&temp_dir).unwrap();
942
943        // Epoch 0 is the best so far, so checkpoint should be saved
944        assert!(temp_dir.join("best.safetensors").exists());
945        assert!(temp_dir.join("best.json").exists());
946
947        // Simulate training epoch 1 with worse loss (should not overwrite)
948        trainer.metrics.record_train_loss(1, 0.9);
949        trainer.metrics.record_val_loss(1, 1.2);
950
951        // Remove old best to test that it doesn't get overwritten
952        fs::remove_file(temp_dir.join("best.safetensors")).unwrap();
953        fs::remove_file(temp_dir.join("best.json")).unwrap();
954
955        trainer.save_best_checkpoint(&temp_dir).unwrap();
956        // Should not save because epoch 1 is not the best
957        assert!(!temp_dir.join("best.safetensors").exists());
958
959        // Clean up
960        fs::remove_dir_all(&temp_dir).unwrap();
961    }
962
963    #[test]
964    fn test_checkpoint_metadata() {
965        use crate::config::KizzasiConfig;
966        use crate::training_core::TrainableSSM;
967        use std::env;
968        use std::fs;
969
970        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_metadata_test");
971        fs::create_dir_all(&temp_dir).unwrap();
972
973        let config = KizzasiConfig::new()
974            .input_dim(3)
975            .output_dim(3)
976            .hidden_dim(64)
977            .state_dim(8)
978            .num_layers(2);
979
980        let training_config = TrainingConfig::default();
981        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
982        let mut trainer = Trainer::new(model, training_config).unwrap();
983
984        // Add some metrics
985        trainer.metrics.record_train_loss(0, 0.5);
986        trainer.metrics.record_val_loss(0, 0.45);
987
988        // Save checkpoint
989        trainer.save_checkpoint(&temp_dir, "metadata_test").unwrap();
990
991        // Load and verify metadata
992        let metadata_path = temp_dir.join("metadata_test.json");
993        let metadata_json = fs::read_to_string(&metadata_path).unwrap();
994        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json).unwrap();
995
996        assert_eq!(metadata.version, env!("CARGO_PKG_VERSION"));
997        assert!(!metadata.timestamp.is_empty());
998        assert_eq!(metadata.current_step, 0);
999        assert!(metadata.metrics.val_loss(0).is_some());
1000        assert_eq!(metadata.metrics.val_loss(0).unwrap(), 0.45);
1001
1002        // Clean up
1003        fs::remove_dir_all(&temp_dir).unwrap();
1004    }
1005}