use crate::config::KizzasiConfig;
use crate::dataloader::TimeSeriesDataLoader;
use crate::error::{CoreError, CoreResult};
use crate::metrics::{MetricsLogger, TrainingMetrics};
use crate::scheduler::LRScheduler;
use crate::training_core::{SchedulerType, TrainableSSM, TrainingConfig};
use candle_core::Tensor;
use candle_nn::{AdamW, Optimizer};
use serde::{Deserialize, Serialize};

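/// Soft-constraint penalty added to a task loss.
///
/// The combined objective is `task_loss + constraint_weight * violation`,
/// where `violation` is a scalar produced by a user-supplied constraint
/// function evaluated on the model's predictions.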
pub struct ConstraintLoss {
    pub(crate) constraint_weight: f32,
}

impl ConstraintLoss {
    /// Creates a constraint loss with the given penalty weight.
    pub fn new(constraint_weight: f32) -> Self {
        Self { constraint_weight }
    }

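    /// Combines a task loss with a weighted constraint violation.
    ///
    /// `constraint_fn` maps the prediction tensor to a scalar violation;
    /// the result is `task_loss + constraint_weight * violation`. Because
    /// the penalty enters as a constant, no gradients flow through
    /// `constraint_fn` itself.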
    pub fn compute<F>(
        &self,
        task_loss: &Tensor,
        prediction: &Tensor,
        constraint_fn: F,
    ) -> CoreResult<Tensor>
    where
        F: Fn(&Tensor) -> CoreResult<f32>,
    {
        let violation = constraint_fn(prediction)?;

        let penalty_value = self.constraint_weight * violation;

        task_loss
            .affine(1.0, penalty_value as f64)
            .map_err(|e| CoreError::Generic(format!("Failed to add constraint penalty: {}", e)))
    }
}

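/// Namespace for common loss functions on `candle` tensors.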
pub struct Loss;

impl Loss {
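    /// Mean squared error: `mean((predictions - targets)^2)`.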
    pub fn mse(predictions: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
        predictions
            .sub(targets)
            .map_err(|e| CoreError::Generic(format!("MSE subtraction failed: {}", e)))?
            .sqr()
            .map_err(|e| CoreError::Generic(format!("MSE square failed: {}", e)))?
            .mean_all()
            .map_err(|e| CoreError::Generic(format!("MSE mean failed: {}", e)))
    }

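    /// Mean absolute error: `mean(|predictions - targets|)`.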
    pub fn mae(predictions: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
        predictions
            .sub(targets)
            .map_err(|e| CoreError::Generic(format!("MAE subtraction failed: {}", e)))?
            .abs()
            .map_err(|e| CoreError::Generic(format!("MAE abs failed: {}", e)))?
            .mean_all()
            .map_err(|e| CoreError::Generic(format!("MAE mean failed: {}", e)))
    }

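    /// Huber loss with threshold `delta`: `0.5 * d^2` where `|d| <= delta`,
    /// and `delta * |d| - 0.5 * delta^2` elsewhere, averaged over all
    /// elements. Quadratic near zero and linear in the tails, so it is less
    /// sensitive to outliers than MSE.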
    pub fn huber(predictions: &Tensor, targets: &Tensor, delta: f64) -> CoreResult<Tensor> {
        let diff = predictions
            .sub(targets)
            .map_err(|e| CoreError::Generic(format!("Huber subtraction failed: {}", e)))?;
        let abs_diff = diff
            .abs()
            .map_err(|e| CoreError::Generic(format!("Huber abs failed: {}", e)))?;

        // Quadratic branch: 0.5 * diff^2.
        let squared = diff
            .sqr()
            .map_err(|e| CoreError::Generic(format!("Huber square failed: {}", e)))?
            .affine(0.5, 0.0)
            .map_err(|e| CoreError::Generic(format!("Huber mul 0.5 failed: {}", e)))?;

        // Linear branch: delta * |diff| - 0.5 * delta^2.
        let linear_offset = delta * delta * 0.5;
        let linear = abs_diff
            .affine(delta, -linear_offset)
            .map_err(|e| CoreError::Generic(format!("Huber linear computation failed: {}", e)))?;

        // Select the quadratic branch where |diff| <= delta, linear elsewhere.
        let mask = abs_diff
            .le(delta)
            .map_err(|e| CoreError::Generic(format!("Huber comparison failed: {}", e)))?
            .to_dtype(predictions.dtype())
            .map_err(|e| CoreError::Generic(format!("Huber mask conversion failed: {}", e)))?;

        let inv_mask = mask
            .affine(-1.0, 1.0)
            .map_err(|e| CoreError::Generic(format!("Huber mask inversion failed: {}", e)))?;

        let loss = squared
            .mul(&mask)
            .map_err(|e| CoreError::Generic(format!("Huber squared mul failed: {}", e)))?
            .add(
                &linear
                    .mul(&inv_mask)
                    .map_err(|e| CoreError::Generic(format!("Huber linear mul failed: {}", e)))?,
            )
            .map_err(|e| CoreError::Generic(format!("Huber final add failed: {}", e)))?;

        loss.mean_all()
            .map_err(|e| CoreError::Generic(format!("Huber mean failed: {}", e)))
    }

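    /// Cross-entropy between `logits` and soft targets.
    ///
    /// `targets` must have the same shape as `logits` and contain
    /// probability distributions (e.g. one-hot rows) along the last
    /// dimension. The negative log-likelihood is summed over all elements
    /// and divided by the batch size (dimension 0).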
    pub fn cross_entropy(logits: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
        let log_probs = candle_nn::ops::log_softmax(logits, candle_core::D::Minus1)
            .map_err(|e| CoreError::Generic(format!("Log softmax failed: {}", e)))?;

        let nll = log_probs
            .mul(targets)
            .map_err(|e| CoreError::Generic(format!("NLL multiplication failed: {}", e)))?
            .sum_all()
            .map_err(|e| CoreError::Generic(format!("NLL sum failed: {}", e)))?
            .neg()
            .map_err(|e| CoreError::Generic(format!("NLL negation failed: {}", e)))?;

        let batch_size = logits
            .dim(0)
            .map_err(|e| CoreError::Generic(format!("Failed to get batch size: {}", e)))?;
        nll.affine(1.0 / batch_size as f64, 0.0)
            .map_err(|e| CoreError::Generic(format!("Cross entropy division failed: {}", e)))
    }
}

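/// Training loop driver: owns the model, optimizer, optional learning-rate
/// scheduler, and metrics.
///
/// Sketch of typical usage (hypothetical config values):
///
/// ```ignore
/// let model = TrainableSSM::new(model_config, training_config.clone())?;
/// let mut trainer = Trainer::new(model, training_config)?;
/// let avg_loss = trainer.train_epoch(&batches, Loss::mse)?;
/// ```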
pub struct Trainer {
    pub(crate) model: TrainableSSM,
    pub(crate) optimizer: AdamW,
    pub(crate) config: TrainingConfig,
    pub(crate) scheduler: Option<Box<dyn LRScheduler>>,
    pub(crate) metrics: TrainingMetrics,
    pub(crate) logger: MetricsLogger,
    pub(crate) current_step: usize,
}

impl Trainer {
    pub fn new(model: TrainableSSM, config: TrainingConfig) -> CoreResult<Self> {
        let optimizer = model.create_optimizer()?;

        let scheduler = Self::create_scheduler(&config);

        let metrics = TrainingMetrics::new();

        let logger = MetricsLogger::new()
            .with_verbose(config.track_metrics)
            .with_log_interval(config.log_interval);

        Ok(Self {
            model,
            optimizer,
            config,
            scheduler,
            metrics,
            logger,
            current_step: 0,
        })
    }

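    /// Builds the LR scheduler described by `config.scheduler`, if any.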
    fn create_scheduler(config: &TrainingConfig) -> Option<Box<dyn LRScheduler>> {
        use crate::scheduler::*;

        config.scheduler.as_ref().map(|sched_type| {
            // Rough estimate of the total optimizer steps: the batch count
            // per epoch is not known here, so assume 100 batches per epoch.
            let total_steps = config.epochs * 100;
            match sched_type {
                SchedulerType::Constant => {
                    Box::new(ConstantScheduler::new(config.learning_rate)) as Box<dyn LRScheduler>
                }
                SchedulerType::Linear {
                    warmup_steps,
                    final_lr,
                } => Box::new(LinearScheduler::new(
                    config.learning_rate,
                    *final_lr,
                    total_steps,
                    *warmup_steps,
                )) as Box<dyn LRScheduler>,
                SchedulerType::Cosine {
                    warmup_steps,
                    min_lr,
                } => Box::new(
                    CosineScheduler::new(config.learning_rate, total_steps, *warmup_steps)
                        .with_min_lr(*min_lr),
                ) as Box<dyn LRScheduler>,
                SchedulerType::Step {
                    milestones,
                    decay_factor,
                } => Box::new(StepScheduler::new(
                    config.learning_rate,
                    *decay_factor,
                    milestones.clone(),
                )) as Box<dyn LRScheduler>,
                SchedulerType::Exponential {
                    decay_rate,
                    decay_steps,
                } => Box::new(ExponentialScheduler::new(
                    config.learning_rate,
                    *decay_rate,
                    *decay_steps,
                )) as Box<dyn LRScheduler>,
                SchedulerType::OneCycle { warmup_pct } => Box::new(
                    OneCycleScheduler::new(config.learning_rate, total_steps)
                        .with_warmup_pct(*warmup_pct),
                ) as Box<dyn LRScheduler>,
                SchedulerType::Polynomial { final_lr, power } => Box::new(
                    PolynomialScheduler::new(config.learning_rate, *final_lr, total_steps, *power),
                ) as Box<dyn LRScheduler>,
            }
        })
    }

    fn get_current_lr(&self) -> f64 {
        self.scheduler
            .as_ref()
            .map(|s| s.get_lr(self.current_step))
            .unwrap_or(self.config.learning_rate)
    }

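    /// Runs one pass over `data_loader`, performing a forward pass, loss
    /// computation, and optimizer step per batch. Returns the average loss
    /// across batches.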
    pub fn train_epoch<F>(
        &mut self,
        data_loader: &[(Tensor, Tensor)],
        loss_fn: F,
    ) -> CoreResult<f32>
    where
        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor>,
    {
        let mut total_loss = 0.0;
        let num_batches = data_loader.len();
        let epoch = self.current_step / num_batches.max(1);

        for (batch_idx, (inputs, targets)) in data_loader.iter().enumerate() {
            let lr = self.get_current_lr();
            if self.config.track_metrics {
                self.metrics.record_learning_rate(lr);
            }

            // Forward pass.
            let predictions = self.model.forward(inputs)?;

            // Compute the batch loss.
            let loss = loss_fn(&predictions, targets)?;

            // Backward pass and parameter update in one step.
            self.optimizer
                .backward_step(&loss)
                .map_err(|e| CoreError::Generic(format!("Backward step failed: {}", e)))?;

            // Extract the scalar loss value for reporting.
            let loss_val = loss
                .to_vec0::<f32>()
                .map_err(|e| CoreError::Generic(format!("Failed to extract loss value: {}", e)))?;
            total_loss += loss_val;

            if self.config.track_metrics {
                self.metrics.record_train_loss(epoch, loss_val);
                self.logger.log_batch(epoch, batch_idx, loss_val);

                let grad_norm = self.compute_grad_norm()?;
                self.metrics.record_grad_norm(grad_norm);
            }

            // Note: `clip_gradients` is currently a no-op placeholder, and it
            // runs after `backward_step` has already applied the update; real
            // clipping would need to happen between backward and step.
            if let Some(max_norm) = self.config.grad_clip {
                self.clip_gradients(max_norm)?;
            }

            self.current_step += 1;
        }

        Ok(total_loss / num_batches as f32)
    }

    fn compute_grad_norm(&self) -> CoreResult<f32> {
        // Placeholder: gradient introspection is not wired up yet, so a
        // dummy norm of 1.0 is reported.
        Ok(1.0)
    }

    fn clip_gradients(&self, _max_norm: f32) -> CoreResult<()> {
        // Placeholder: gradient clipping is not implemented yet; this is a
        // no-op so that the training loop can run unchanged.
        Ok(())
    }

    pub fn evaluate<F>(&self, data_loader: &[(Tensor, Tensor)], loss_fn: F) -> CoreResult<f32>
    where
        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor>,
    {
        let mut total_loss = 0.0;
        let num_batches = data_loader.len();

        for (inputs, targets) in data_loader {
            // Forward pass only; no optimizer step is taken.
            let predictions = self.model.forward(inputs)?;

            let loss = loss_fn(&predictions, targets)?;

            let loss_val = loss
                .to_vec0::<f32>()
                .map_err(|e| CoreError::Generic(format!("Failed to extract loss value: {}", e)))?;
            total_loss += loss_val;
        }

        Ok(total_loss / num_batches as f32)
    }

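    /// Full training loop: for each epoch, shuffles the training data, runs
    /// `train_epoch`, optionally evaluates on validation data, logs metrics,
    /// and checks early stopping.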
    pub fn fit<F>(
        &mut self,
        mut train_loader: TimeSeriesDataLoader,
        mut val_loader: Option<TimeSeriesDataLoader>,
        loss_fn: F,
    ) -> CoreResult<()>
    where
        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor> + Copy,
    {
        use std::time::Instant;

        for epoch in 0..self.config.epochs {
            let epoch_start = Instant::now();

            train_loader.shuffle();

            // Placeholder: batches are not yet materialized from the loader,
            // so this epoch trains on an empty batch list.
            let train_batches: Vec<(Tensor, Tensor)> = Vec::new();

            let train_loss = self.train_epoch(&train_batches, loss_fn)?;

            let val_loss = if let Some(ref mut _val_data) = val_loader {
                // Placeholder: validation batches are likewise not yet drawn
                // from the loader.
                let val_batches: Vec<(Tensor, Tensor)> = Vec::new();
                let val_loss = self.evaluate(&val_batches, loss_fn)?;

                if self.config.track_metrics {
                    self.metrics.record_val_loss(epoch, val_loss);
                }

                Some(val_loss)
            } else {
                None
            };

            let epoch_duration = epoch_start.elapsed().as_secs_f64();
            if self.config.track_metrics {
                self.metrics.record_epoch_duration(epoch, epoch_duration);
            }

            let current_lr = self.get_current_lr();
            self.logger
                .log_epoch(epoch, train_loss, val_loss, current_lr);

            if let Some(patience) = self.config.early_stopping_patience {
                if !self.metrics.is_improving(patience) {
                    tracing::info!("Early stopping triggered at epoch {}", epoch);
                    break;
                }
            }
        }

        if self.config.track_metrics {
            let summary = self.metrics.summary();
            self.logger.log_summary(&summary);
        }

        Ok(())
    }

    /// Returns a shared reference to the underlying model.
    pub fn model(&self) -> &TrainableSSM {
        &self.model
    }

    /// Returns a mutable reference to the underlying model.
    pub fn model_mut(&mut self) -> &mut TrainableSSM {
        &mut self.model
    }

    /// Returns the recorded training metrics.
    pub fn metrics(&self) -> &TrainingMetrics {
        &self.metrics
    }

    /// Returns the training metrics mutably.
    pub fn metrics_mut(&mut self) -> &mut TrainingMetrics {
        &mut self.metrics
    }

    /// Returns the number of optimizer steps taken so far.
    pub fn current_step(&self) -> usize {
        self.current_step
    }

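    /// Saves a checkpoint under `path`: model weights as
    /// `{name}.safetensors` and training state (config, metrics, step
    /// counters) as `{name}.json`.
    ///
    /// ```ignore
    /// trainer.save_checkpoint("checkpoints/", "epoch_10")?;
    /// ```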
    pub fn save_checkpoint<P: AsRef<std::path::Path>>(
        &self,
        path: P,
        name: &str,
    ) -> CoreResult<()> {
        use std::fs;
        use std::path::PathBuf;

        let checkpoint_dir = path.as_ref();
        fs::create_dir_all(checkpoint_dir).map_err(|e| {
            CoreError::Generic(format!("Failed to create checkpoint directory: {}", e))
        })?;

        // Model weights go into `{name}.safetensors`.
        let weights_path: PathBuf = checkpoint_dir.join(format!("{}.safetensors", name));
        self.model
            .save_weights(&weights_path)
            .map_err(|e| CoreError::Generic(format!("Failed to save model weights: {}", e)))?;

        // Training state (config, metrics, step counters) goes into `{name}.json`.
        let metadata = CheckpointMetadata {
            version: env!("CARGO_PKG_VERSION").to_string(),
            timestamp: chrono::Utc::now().to_rfc3339(),
            current_step: self.current_step,
            current_epoch: self.metrics.summary().total_epochs,
            config: self.config.clone(),
            metrics: self.metrics.clone(),
        };

        let metadata_path: PathBuf = checkpoint_dir.join(format!("{}.json", name));
        let metadata_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
            CoreError::Generic(format!("Failed to serialize checkpoint metadata: {}", e))
        })?;

        fs::write(&metadata_path, metadata_json).map_err(|e| {
            CoreError::Generic(format!("Failed to write checkpoint metadata: {}", e))
        })?;

        tracing::info!(
            "Checkpoint saved: weights={}, metadata={}",
            weights_path.display(),
            metadata_path.display()
        );

        Ok(())
    }

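    /// Restores a trainer from a checkpoint written by
    /// [`save_checkpoint`](Self::save_checkpoint). The model is rebuilt from
    /// `model_config`, weights are loaded from `{name}.safetensors`, and
    /// training state from `{name}.json`.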
    pub fn load_checkpoint<P: AsRef<std::path::Path>>(
        path: P,
        name: &str,
        model_config: KizzasiConfig,
    ) -> CoreResult<Self> {
        use std::fs;
        use std::path::PathBuf;

        let checkpoint_dir = path.as_ref();

        // Read and parse the serialized training state.
        let metadata_path: PathBuf = checkpoint_dir.join(format!("{}.json", name));
        let metadata_json = fs::read_to_string(&metadata_path).map_err(|e| {
            CoreError::Generic(format!("Failed to read checkpoint metadata: {}", e))
        })?;

        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json).map_err(|e| {
            CoreError::Generic(format!("Failed to parse checkpoint metadata: {}", e))
        })?;

        // Rebuild the model and restore its weights.
        let weights_path: PathBuf = checkpoint_dir.join(format!("{}.safetensors", name));
        let mut model = TrainableSSM::new(model_config, metadata.config.clone())?;
        model
            .load_weights(&weights_path)
            .map_err(|e| CoreError::Generic(format!("Failed to load model weights: {}", e)))?;

        // The optimizer is created fresh: its internal state (e.g. AdamW
        // moment estimates) is not part of the checkpoint.
        let optimizer = model.create_optimizer()?;
        let scheduler = Self::create_scheduler(&metadata.config);

        let logger = MetricsLogger::new()
            .with_verbose(metadata.config.track_metrics)
            .with_log_interval(metadata.config.log_interval);

        tracing::info!(
            "Checkpoint loaded: version={}, step={}, epoch={}",
            metadata.version,
            metadata.current_step,
            metadata.current_epoch
        );

        Ok(Self {
            model,
            optimizer,
            config: metadata.config,
            scheduler,
            metrics: metadata.metrics,
            logger,
            current_step: metadata.current_step,
        })
    }

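    /// Saves a checkpoint named after the current epoch, e.g.
    /// `checkpoint_epoch_3.safetensors` / `checkpoint_epoch_3.json`.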
    pub fn save_checkpoint_auto<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
        let current_epoch = self.metrics.summary().total_epochs;
        let name = format!("checkpoint_epoch_{}", current_epoch);
        self.save_checkpoint(path, &name)
    }

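    /// Saves a checkpoint named `best` if and only if the most recently
    /// completed epoch achieved the best validation loss so far; otherwise
    /// does nothing.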
    pub fn save_best_checkpoint<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
        let summary = self.metrics.summary();

        if let (Some(best_epoch), Some(_best_loss)) = (summary.best_epoch, summary.best_val_loss) {
            // Only save when the most recently completed epoch is the best one.
            let current_epoch = summary.total_epochs.saturating_sub(1);
            if current_epoch == best_epoch {
                tracing::info!("New best validation loss! Saving best checkpoint");
                return self.save_checkpoint(path, "best");
            }
        }

        Ok(())
    }
}

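/// Training state serialized alongside model weights in a checkpoint.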
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Crate version that wrote the checkpoint.
    pub version: String,
    /// RFC 3339 timestamp of when the checkpoint was saved.
    pub timestamp: String,
    /// Optimizer steps taken at save time.
    pub current_step: usize,
    /// Epochs completed at save time.
    pub current_epoch: usize,
    /// Training configuration in effect.
    pub config: TrainingConfig,
    /// Recorded training metrics.
    pub metrics: TrainingMetrics,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::training_core::TrainingConfig;
    use candle_core::{Device, Tensor};

    #[test]
    fn test_mse_loss() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let loss = Loss::mse(&predictions, &targets).unwrap();
        let loss_val = loss.to_vec0::<f32>().unwrap();

        assert!((loss_val - 0.25).abs() < 1e-5);
    }

    #[test]
    fn test_trainer_with_scheduler() {
        use crate::config::KizzasiConfig;
        use crate::training_core::{SchedulerType, TrainableSSM};

        let model_config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default().with_scheduler(SchedulerType::Linear {
            warmup_steps: 50,
            final_lr: 1e-6,
        });

        let model = TrainableSSM::new(model_config, training_config.clone()).unwrap();
        let trainer = Trainer::new(model, training_config);

        assert!(trainer.is_ok());
        let trainer = trainer.unwrap();
        assert!(trainer.scheduler.is_some());
    }

    #[test]
    fn test_trainer_metrics_tracking() {
        use crate::config::KizzasiConfig;
        use crate::training_core::TrainableSSM;

        let model_config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();
        let model = TrainableSSM::new(model_config, training_config.clone()).unwrap();
        let trainer = Trainer::new(model, training_config).unwrap();

        assert_eq!(trainer.metrics().current_step(), 0);
        assert_eq!(trainer.current_step(), 0);
    }

    #[test]
    fn test_mae_loss() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let loss = Loss::mae(&predictions, &targets).unwrap();
        let loss_val = loss.to_vec0::<f32>().unwrap();

        assert!((loss_val - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_huber_loss() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 5.0], &device).unwrap();
        let targets = Tensor::new(&[1.1f32, 2.1, 3.0], &device).unwrap();

        let loss = Loss::huber(&predictions, &targets, 1.0).unwrap();
        let loss_val = loss.to_vec0::<f32>().unwrap();

        assert!(loss_val > 0.0);
        assert!(loss_val < 2.0);
    }

    #[test]
    fn test_constraint_loss_creation() {
        let constraint_loss = ConstraintLoss::new(0.5);
        assert_eq!(constraint_loss.constraint_weight, 0.5);
    }

    #[test]
    fn test_constraint_loss_no_violation() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let task_loss = Loss::mse(&predictions, &targets).unwrap();
        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();

        let constraint_loss = ConstraintLoss::new(0.5);

        let total_loss = constraint_loss
            .compute(&task_loss, &predictions, |_pred| Ok(0.0))
            .unwrap();
        let total_loss_val = total_loss.to_vec0::<f32>().unwrap();

        assert!((total_loss_val - task_loss_val).abs() < 1e-5);
    }

    #[test]
    fn test_constraint_loss_with_violation() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let task_loss = Loss::mse(&predictions, &targets).unwrap();
        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();

        let constraint_loss = ConstraintLoss::new(0.5);

        let total_loss = constraint_loss
            .compute(&task_loss, &predictions, |_pred| Ok(1.0))
            .unwrap();
        let total_loss_val = total_loss.to_vec0::<f32>().unwrap();

        let expected = task_loss_val + 0.5;
        assert!((total_loss_val - expected).abs() < 1e-5);
    }

    #[test]
    fn test_constraint_loss_scaling() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let task_loss = Loss::mse(&predictions, &targets).unwrap();
        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();

        let weights = [0.1, 0.5, 1.0, 2.0];
        let violation = 1.5;

        for &weight in &weights {
            let constraint_loss = ConstraintLoss::new(weight);
            let total_loss = constraint_loss
                .compute(&task_loss, &predictions, |_pred| Ok(violation))
                .unwrap();
            let total_loss_val = total_loss.to_vec0::<f32>().unwrap();

            let expected = task_loss_val + weight * violation;
            assert!(
                (total_loss_val - expected).abs() < 1e-4,
                "Weight {} failed: got {}, expected {}",
                weight,
                total_loss_val,
                expected
            );
        }
    }

    #[test]
    fn test_checkpoint_save_load() {
        use crate::config::KizzasiConfig;
        use crate::training_core::TrainableSSM;
        use std::env;
        use std::fs;

        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_test");
        fs::create_dir_all(&temp_dir).unwrap();

        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig {
            epochs: 5,
            learning_rate: 1e-3,
            ..Default::default()
        };

        let model = TrainableSSM::new(config.clone(), training_config.clone()).unwrap();
        let trainer = Trainer::new(model, training_config).unwrap();

        trainer
            .save_checkpoint(&temp_dir, "test_checkpoint")
            .unwrap();

        assert!(temp_dir.join("test_checkpoint.safetensors").exists());
        assert!(temp_dir.join("test_checkpoint.json").exists());

        let loaded_trainer =
            Trainer::load_checkpoint(&temp_dir, "test_checkpoint", config).unwrap();

        assert_eq!(loaded_trainer.config.epochs, 5);
        assert_eq!(loaded_trainer.config.learning_rate, 1e-3);
        assert_eq!(loaded_trainer.current_step, 0);

        fs::remove_dir_all(&temp_dir).unwrap();
    }

    #[test]
    fn test_checkpoint_auto_save() {
        use crate::config::KizzasiConfig;
        use crate::training_core::TrainableSSM;
        use std::env;
        use std::fs;

        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_auto_test");
        fs::create_dir_all(&temp_dir).unwrap();

        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();
        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
        let mut trainer = Trainer::new(model, training_config).unwrap();

        trainer.metrics.record_train_loss(0, 0.5);

        trainer.save_checkpoint_auto(&temp_dir).unwrap();

        assert!(temp_dir.join("checkpoint_epoch_1.safetensors").exists());
        assert!(temp_dir.join("checkpoint_epoch_1.json").exists());

        fs::remove_dir_all(&temp_dir).unwrap();
    }

    #[test]
    fn test_checkpoint_best_save() {
        use crate::config::KizzasiConfig;
        use crate::training_core::TrainableSSM;
        use std::env;
        use std::fs;

        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_best_test");
        fs::create_dir_all(&temp_dir).unwrap();

        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();
        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
        let mut trainer = Trainer::new(model, training_config).unwrap();

        trainer.metrics.record_train_loss(0, 1.2);
        trainer.metrics.record_val_loss(0, 1.0);
        trainer.save_best_checkpoint(&temp_dir).unwrap();

        assert!(temp_dir.join("best.safetensors").exists());
        assert!(temp_dir.join("best.json").exists());

        trainer.metrics.record_train_loss(1, 0.9);
        trainer.metrics.record_val_loss(1, 1.2);

        fs::remove_file(temp_dir.join("best.safetensors")).unwrap();
        fs::remove_file(temp_dir.join("best.json")).unwrap();

        trainer.save_best_checkpoint(&temp_dir).unwrap();
        assert!(!temp_dir.join("best.safetensors").exists());

        fs::remove_dir_all(&temp_dir).unwrap();
    }

    #[test]
    fn test_checkpoint_metadata() {
        use crate::config::KizzasiConfig;
        use crate::training_core::TrainableSSM;
        use std::env;
        use std::fs;

        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_metadata_test");
        fs::create_dir_all(&temp_dir).unwrap();

        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();
        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
        let mut trainer = Trainer::new(model, training_config).unwrap();

        trainer.metrics.record_train_loss(0, 0.5);
        trainer.metrics.record_val_loss(0, 0.45);

        trainer.save_checkpoint(&temp_dir, "metadata_test").unwrap();

        let metadata_path = temp_dir.join("metadata_test.json");
        let metadata_json = fs::read_to_string(&metadata_path).unwrap();
        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json).unwrap();

        assert_eq!(metadata.version, env!("CARGO_PKG_VERSION"));
        assert!(!metadata.timestamp.is_empty());
        assert_eq!(metadata.current_step, 0);
        assert!(metadata.metrics.val_loss(0).is_some());
        assert_eq!(metadata.metrics.val_loss(0).unwrap(), 0.45);

        fs::remove_dir_all(&temp_dir).unwrap();
    }
}