use crate::config::KizzasiConfig;
use crate::dataloader::TimeSeriesDataLoader;
use crate::device::DeviceConfig;
use crate::error::{CoreError, CoreResult};
use crate::metrics::{MetricsLogger, TrainingMetrics};
use crate::scheduler::LRScheduler;
use candle_core::{DType, Device, Tensor, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use serde::{Deserialize, Serialize};

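/// Learning-rate schedule selection used by the `Trainer`.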
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SchedulerType {
    Constant,
    Linear {
        warmup_steps: usize,
        final_lr: f64,
    },
    Cosine {
        warmup_steps: usize,
        min_lr: f64,
    },
    Step {
        milestones: Vec<usize>,
        decay_factor: f64,
    },
    Exponential {
        decay_rate: f64,
        decay_steps: usize,
    },
    OneCycle {
        warmup_pct: f64,
    },
    Polynomial {
        final_lr: f64,
        power: f64,
    },
}

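/// Numeric precision used for model parameters and activations.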
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MixedPrecision {
    None,
    FP16,
    BF16,
}

impl MixedPrecision {
    /// Convert the precision mode into the corresponding candle `DType`.
    pub fn to_dtype(&self) -> DType {
        match self {
            MixedPrecision::None => DType::F32,
            MixedPrecision::FP16 => DType::F16,
            MixedPrecision::BF16 => DType::BF16,
        }
    }

    /// Returns `true` when reduced-precision training is active.
    pub fn is_enabled(&self) -> bool {
        !matches!(self, MixedPrecision::None)
    }
}

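/// Hyperparameters and runtime options for training a `TrainableSSM`.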
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Device selection used to place model tensors.
    pub device_config: DeviceConfig,
    /// Base learning rate for the AdamW optimizer.
    pub learning_rate: f64,
    /// Number of samples per training batch.
    pub batch_size: usize,
    /// Number of passes over the training data.
    pub epochs: usize,
    /// AdamW weight decay coefficient.
    pub weight_decay: f64,
    /// Optional gradient-clipping threshold (maximum norm).
    pub grad_clip: Option<f32>,
    /// AdamW beta1 coefficient.
    pub beta1: f64,
    /// AdamW beta2 coefficient.
    pub beta2: f64,
    /// AdamW epsilon for numerical stability.
    pub eps: f64,
    /// Optional learning-rate schedule; `None` keeps the rate constant.
    pub scheduler: Option<SchedulerType>,
    /// Whether to record metrics during training.
    pub track_metrics: bool,
    /// Interval (in batches) between metric log lines.
    pub log_interval: usize,
    /// Fraction of data held out for validation.
    pub validation_split: f32,
    /// Stop after this many epochs without validation improvement.
    pub early_stopping_patience: Option<usize>,
    /// Trade compute for memory by recomputing activations during the backward pass.
    pub use_gradient_checkpointing: bool,
    /// Segment size used when gradient checkpointing is enabled.
    pub checkpoint_segment_size: Option<usize>,
    /// Precision mode for parameters and activations.
    pub mixed_precision: MixedPrecision,
    /// Loss scaling factor used with mixed-precision training.
    pub loss_scale: f32,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            device_config: DeviceConfig::default(),
            learning_rate: 1e-4,
            batch_size: 32,
            epochs: 10,
            weight_decay: 1e-2,
            grad_clip: Some(1.0),
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            scheduler: None,
            track_metrics: true,
            log_interval: 10,
            validation_split: 0.2,
            early_stopping_patience: Some(5),
            use_gradient_checkpointing: false,
            checkpoint_segment_size: Some(2),
            mixed_precision: MixedPrecision::None,
            loss_scale: 1.0,
        }
    }
}

impl TrainingConfig {
    pub fn with_scheduler(mut self, scheduler: SchedulerType) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    pub fn without_metrics(mut self) -> Self {
        self.track_metrics = false;
        self
    }

    pub fn with_validation_split(mut self, split: f32) -> Self {
        self.validation_split = split;
        self
    }

    pub fn with_early_stopping(mut self, patience: usize) -> Self {
        self.early_stopping_patience = Some(patience);
        self
    }

    pub fn without_early_stopping(mut self) -> Self {
        self.early_stopping_patience = None;
        self
    }

    pub fn with_gradient_checkpointing(mut self, segment_size: Option<usize>) -> Self {
        self.use_gradient_checkpointing = true;
        self.checkpoint_segment_size = segment_size;
        self
    }

    pub fn without_gradient_checkpointing(mut self) -> Self {
        self.use_gradient_checkpointing = false;
        self
    }

    pub fn with_fp16(mut self) -> Self {
        self.mixed_precision = MixedPrecision::FP16;
        // FP16 has a narrow dynamic range, so start with a conservative loss scale.
        self.loss_scale = 128.0;
        self
    }

    pub fn with_bf16(mut self) -> Self {
        self.mixed_precision = MixedPrecision::BF16;
        // BF16 keeps FP32's exponent range, so no loss scaling is needed.
        self.loss_scale = 1.0;
        self
    }

    pub fn with_mixed_precision(mut self, mode: MixedPrecision, loss_scale: f32) -> Self {
        self.mixed_precision = mode;
        self.loss_scale = loss_scale;
        self
    }

    pub fn without_mixed_precision(mut self) -> Self {
        self.mixed_precision = MixedPrecision::None;
        self.loss_scale = 1.0;
        self
    }
}

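/// A trainable state-space model whose parameters are registered in a candle
/// `VarMap` so they can be optimized, saved, and reloaded.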
pub struct TrainableSSM {
    config: KizzasiConfig,
    training_config: TrainingConfig,
    device: Device,
    dtype: DType,
    embedding_weight: Var,
    /// Per-layer state transition matrices (A).
    a_matrices: Vec<Var>,
    /// Per-layer input matrices (B).
    b_matrices: Vec<Var>,
    /// Per-layer output matrices (C).
    c_matrices: Vec<Var>,
    /// Per-layer skip-connection vectors (D).
    d_vectors: Vec<Var>,
    output_proj: Var,
    /// Per-layer layer-norm scale and shift parameters.
    ln_gamma: Vec<Var>,
    ln_beta: Vec<Var>,
    varmap: VarMap,
}

impl TrainableSSM {
    pub fn new(config: KizzasiConfig, training_config: TrainingConfig) -> CoreResult<Self> {
        let device = training_config.device_config.create_device()?;

        let dtype = training_config.mixed_precision.to_dtype();

        let hidden_dim = config.get_hidden_dim();
        let state_dim = config.get_state_dim();
        let num_layers = config.get_num_layers();
        let input_dim = config.get_input_dim();
        let output_dim = config.get_output_dim();

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, &device);

        let embedding_weight_tensor = vb
            .get_with_hints(
                (input_dim, hidden_dim),
                "embedding.weight",
                candle_nn::init::DEFAULT_KAIMING_NORMAL,
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create embedding: {}", e)))?;
        let embedding_weight = Var::from_tensor(&embedding_weight_tensor)
            .map_err(|e| CoreError::Generic(format!("Failed to create embedding var: {}", e)))?;

        let mut a_matrices = Vec::with_capacity(num_layers);
        let mut b_matrices = Vec::with_capacity(num_layers);
        let mut c_matrices = Vec::with_capacity(num_layers);
        let mut d_vectors = Vec::with_capacity(num_layers);
        let mut ln_gamma = Vec::with_capacity(num_layers);
        let mut ln_beta = Vec::with_capacity(num_layers);

        for layer_idx in 0..num_layers {
            let a_tensor = vb
                .get_with_hints(
                    (hidden_dim, state_dim),
                    &format!("ssm.layer_{}.a", layer_idx),
                    candle_nn::init::Init::Const(-0.5),
                )
                .map_err(|e| CoreError::Generic(format!("Failed to create A matrix: {}", e)))?;
            let a = Var::from_tensor(&a_tensor)
                .map_err(|e| CoreError::Generic(format!("Failed to create A var: {}", e)))?;
            a_matrices.push(a);

            let b_tensor = vb
                .get_with_hints(
                    (hidden_dim, state_dim),
                    &format!("ssm.layer_{}.b", layer_idx),
                    candle_nn::init::DEFAULT_KAIMING_NORMAL,
                )
                .map_err(|e| CoreError::Generic(format!("Failed to create B matrix: {}", e)))?;
            let b = Var::from_tensor(&b_tensor)
                .map_err(|e| CoreError::Generic(format!("Failed to create B var: {}", e)))?;
            b_matrices.push(b);

            let c_tensor = vb
                .get_with_hints(
                    (hidden_dim, state_dim),
                    &format!("ssm.layer_{}.c", layer_idx),
                    candle_nn::init::DEFAULT_KAIMING_NORMAL,
                )
                .map_err(|e| CoreError::Generic(format!("Failed to create C matrix: {}", e)))?;
            let c = Var::from_tensor(&c_tensor)
                .map_err(|e| CoreError::Generic(format!("Failed to create C var: {}", e)))?;
            c_matrices.push(c);

            let d_tensor = vb
                .get_with_hints(
                    hidden_dim,
                    &format!("ssm.layer_{}.d", layer_idx),
                    candle_nn::init::Init::Const(1.0),
                )
                .map_err(|e| CoreError::Generic(format!("Failed to create D vector: {}", e)))?;
            let d = Var::from_tensor(&d_tensor)
                .map_err(|e| CoreError::Generic(format!("Failed to create D var: {}", e)))?;
            d_vectors.push(d);

            let gamma_tensor = vb
                .get_with_hints(
                    hidden_dim,
                    &format!("ln.layer_{}.gamma", layer_idx),
                    candle_nn::init::Init::Const(1.0),
                )
                .map_err(|e| CoreError::Generic(format!("Failed to create LN gamma: {}", e)))?;
            let gamma = Var::from_tensor(&gamma_tensor)
                .map_err(|e| CoreError::Generic(format!("Failed to create LN gamma var: {}", e)))?;
            ln_gamma.push(gamma);

            let beta_tensor = vb
                .get_with_hints(
                    hidden_dim,
                    &format!("ln.layer_{}.beta", layer_idx),
                    candle_nn::init::Init::Const(0.0),
                )
                .map_err(|e| CoreError::Generic(format!("Failed to create LN beta: {}", e)))?;
            let beta = Var::from_tensor(&beta_tensor)
                .map_err(|e| CoreError::Generic(format!("Failed to create LN beta var: {}", e)))?;
            ln_beta.push(beta);
        }

        let output_proj_tensor = vb
            .get_with_hints(
                (hidden_dim, output_dim),
                "output.proj",
                candle_nn::init::DEFAULT_KAIMING_NORMAL,
            )
            .map_err(|e| {
                CoreError::Generic(format!("Failed to create output projection: {}", e))
            })?;
        let output_proj = Var::from_tensor(&output_proj_tensor)
            .map_err(|e| CoreError::Generic(format!("Failed to create output proj var: {}", e)))?;

        Ok(Self {
            config,
            training_config,
            device,
            dtype,
            embedding_weight,
            a_matrices,
            b_matrices,
            c_matrices,
            d_vectors,
            output_proj,
            ln_gamma,
            ln_beta,
            varmap,
        })
    }

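    /// Run the forward pass on an input of shape `(batch, seq_len, input_dim)`,
    /// returning predictions of shape `(batch, seq_len, output_dim)`.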
    pub fn forward(&self, input: &Tensor) -> CoreResult<Tensor> {
        let batch_size = input
            .dim(0)
            .map_err(|e| CoreError::Generic(format!("Failed to get batch dimension: {}", e)))?;
        let seq_len = input
            .dim(1)
            .map_err(|e| CoreError::Generic(format!("Failed to get sequence dimension: {}", e)))?;
        let input_dim = input
            .dim(2)
            .map_err(|e| CoreError::Generic(format!("Failed to get input dimension: {}", e)))?;

        let x_flat = input
            .reshape((batch_size * seq_len, input_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape input: {}", e)))?;

        let hidden_dim = self.config.get_hidden_dim();
        let x_embedded = x_flat
            .matmul(self.embedding_weight.as_tensor())
            .map_err(|e| CoreError::Generic(format!("Embedding forward failed: {}", e)))?;

        let x = x_embedded
            .reshape((batch_size, seq_len, hidden_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape embedded: {}", e)))?;

        let state_dim = self.config.get_state_dim();

        let mut h = Tensor::zeros(
            (batch_size, hidden_dim, state_dim),
            self.dtype,
            &self.device,
        )
        .map_err(|e| CoreError::Generic(format!("Failed to create hidden state: {}", e)))?;

        let mut x = x;

        for layer_idx in 0..self.config.get_num_layers() {
            x = self.layer_norm(&x, layer_idx)?;
            x = self.ssm_layer(&x, &mut h, layer_idx)?;
        }

        let x_flat = x
            .reshape((batch_size * seq_len, hidden_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape for output: {}", e)))?;

        let output_dim = self.config.get_output_dim();
        let output_flat = x_flat
            .matmul(self.output_proj.as_tensor())
            .map_err(|e| CoreError::Generic(format!("Output projection failed: {}", e)))?;

        let output = output_flat
            .reshape((batch_size, seq_len, output_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape output: {}", e)))?;

        Ok(output)
    }

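    /// Layer normalization over the last (feature) dimension with learned scale and shift.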
    fn layer_norm(&self, x: &Tensor, layer_idx: usize) -> CoreResult<Tensor> {
        const EPS: f64 = 1e-5;

        let mean = x
            .mean_keepdim(candle_core::D::Minus1)
            .map_err(|e| CoreError::Generic(format!("Layer norm mean failed: {}", e)))?;
        let x_centered = x.broadcast_sub(&mean).map_err(|e| {
            CoreError::Generic(format!("Layer norm variance computation failed: {}", e))
        })?;
        let variance = x_centered
            .sqr()
            .map_err(|e| CoreError::Generic(format!("Layer norm variance sqr failed: {}", e)))?
            .mean_keepdim(candle_core::D::Minus1)
            .map_err(|e| CoreError::Generic(format!("Layer norm variance mean failed: {}", e)))?;

        let std = (variance.affine(1.0, EPS))
            .map_err(|e| CoreError::Generic(format!("Layer norm variance add eps failed: {}", e)))?
            .sqrt()
            .map_err(|e| CoreError::Generic(format!("Layer norm sqrt failed: {}", e)))?;

        let normalized = x_centered
            .broadcast_div(&std)
            .map_err(|e| CoreError::Generic(format!("Layer norm division failed: {}", e)))?;

        let gamma = self.ln_gamma[layer_idx].as_tensor();
        let beta = self.ln_beta[layer_idx].as_tensor();

        normalized
            .broadcast_mul(gamma)
            .map_err(|e| CoreError::Generic(format!("Layer norm gamma mul failed: {}", e)))?
            .broadcast_add(beta)
            .map_err(|e| CoreError::Generic(format!("Layer norm beta add failed: {}", e)))
    }

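    /// Single SSM layer. Note: the A, B, and C parameters are created but not yet
    /// used here; the current implementation only applies the learned D skip
    /// connection (an elementwise scaling of the input).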
    fn ssm_layer(&self, x: &Tensor, _h: &mut Tensor, layer_idx: usize) -> CoreResult<Tensor> {
        let _a = self.a_matrices[layer_idx].as_tensor();
        let _b = self.b_matrices[layer_idx].as_tensor();
        let _c = self.c_matrices[layer_idx].as_tensor();
        let d = self.d_vectors[layer_idx].as_tensor();

        let y = x
            .broadcast_mul(d)
            .map_err(|e| CoreError::Generic(format!("Skip connection failed: {}", e)))?;

        Ok(y)
    }

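    /// Build an AdamW optimizer over all variables registered in the `VarMap`.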
    pub fn create_optimizer(&self) -> CoreResult<AdamW> {
        let params = ParamsAdamW {
            lr: self.training_config.learning_rate,
            beta1: self.training_config.beta1,
            beta2: self.training_config.beta2,
            eps: self.training_config.eps,
            weight_decay: self.training_config.weight_decay,
        };

        AdamW::new(self.varmap.all_vars(), params)
            .map_err(|e| CoreError::Generic(format!("Failed to create optimizer: {}", e)))
    }

    pub fn varmap(&self) -> &VarMap {
        &self.varmap
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

    pub fn dtype(&self) -> DType {
        self.dtype
    }

    pub fn save_weights<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
        self.varmap
            .save(path)
            .map_err(|e| CoreError::Generic(format!("Failed to save weights: {}", e)))
    }

    pub fn load_weights<P: AsRef<std::path::Path>>(&mut self, path: P) -> CoreResult<()> {
        self.varmap
            .load(path)
            .map_err(|e| CoreError::Generic(format!("Failed to load weights: {}", e)))
    }
}

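/// Adds a weighted penalty for constraint violations on top of a task loss.
///
/// A minimal usage sketch (hedged: `violation_score` is a hypothetical user-supplied
/// helper mapping a prediction tensor to a non-negative violation value):
///
/// ```ignore
/// let constrained = ConstraintLoss::new(0.5);
/// let total = constrained.compute(&task_loss, &predictions, |pred| {
///     // Any measure of how badly `pred` violates the constraint; 0.0 means it holds.
///     Ok(violation_score(pred))
/// })?;
/// ```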
pub struct ConstraintLoss {
    constraint_weight: f32,
}

impl ConstraintLoss {
    pub fn new(constraint_weight: f32) -> Self {
        Self { constraint_weight }
    }

    pub fn compute<F>(
        &self,
        task_loss: &Tensor,
        prediction: &Tensor,
        constraint_fn: F,
    ) -> CoreResult<Tensor>
    where
        F: Fn(&Tensor) -> CoreResult<f32>,
    {
        let violation = constraint_fn(prediction)?;

        // The violation is a plain scalar, so the penalty shifts the loss value
        // but does not contribute gradients of its own.
        let penalty_value = self.constraint_weight * violation;

        task_loss
            .affine(1.0, penalty_value as f64)
            .map_err(|e| CoreError::Generic(format!("Failed to add constraint penalty: {}", e)))
    }
}

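/// Common regression and classification loss functions.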
pub struct Loss;

impl Loss {
    pub fn mse(predictions: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
        predictions
            .sub(targets)
            .map_err(|e| CoreError::Generic(format!("MSE subtraction failed: {}", e)))?
            .sqr()
            .map_err(|e| CoreError::Generic(format!("MSE square failed: {}", e)))?
            .mean_all()
            .map_err(|e| CoreError::Generic(format!("MSE mean failed: {}", e)))
    }

    pub fn mae(predictions: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
        predictions
            .sub(targets)
            .map_err(|e| CoreError::Generic(format!("MAE subtraction failed: {}", e)))?
            .abs()
            .map_err(|e| CoreError::Generic(format!("MAE abs failed: {}", e)))?
            .mean_all()
            .map_err(|e| CoreError::Generic(format!("MAE mean failed: {}", e)))
    }

    pub fn huber(predictions: &Tensor, targets: &Tensor, delta: f64) -> CoreResult<Tensor> {
        let diff = predictions
            .sub(targets)
            .map_err(|e| CoreError::Generic(format!("Huber subtraction failed: {}", e)))?;
        let abs_diff = diff
            .abs()
            .map_err(|e| CoreError::Generic(format!("Huber abs failed: {}", e)))?;

        // Quadratic branch: 0.5 * diff^2, used where |diff| <= delta.
        let squared = diff
            .sqr()
            .map_err(|e| CoreError::Generic(format!("Huber square failed: {}", e)))?
            .affine(0.5, 0.0)
            .map_err(|e| CoreError::Generic(format!("Huber mul 0.5 failed: {}", e)))?;

        // Linear branch: delta * |diff| - 0.5 * delta^2, used where |diff| > delta.
        let linear_offset = delta * delta * 0.5;
        let linear = abs_diff
            .affine(delta, -linear_offset)
            .map_err(|e| CoreError::Generic(format!("Huber linear computation failed: {}", e)))?;

        let mask = abs_diff
            .le(delta)
            .map_err(|e| CoreError::Generic(format!("Huber comparison failed: {}", e)))?
            .to_dtype(predictions.dtype())
            .map_err(|e| CoreError::Generic(format!("Huber mask conversion failed: {}", e)))?;

        let inv_mask = mask
            .affine(-1.0, 1.0)
            .map_err(|e| CoreError::Generic(format!("Huber mask inversion failed: {}", e)))?;

        let loss = squared
            .mul(&mask)
            .map_err(|e| CoreError::Generic(format!("Huber squared mul failed: {}", e)))?
            .add(
                &linear
                    .mul(&inv_mask)
                    .map_err(|e| CoreError::Generic(format!("Huber linear mul failed: {}", e)))?,
            )
            .map_err(|e| CoreError::Generic(format!("Huber final add failed: {}", e)))?;

        loss.mean_all()
            .map_err(|e| CoreError::Generic(format!("Huber mean failed: {}", e)))
    }

    pub fn cross_entropy(logits: &Tensor, targets: &Tensor) -> CoreResult<Tensor> {
        // Targets are expected to be one-hot (or probability) distributions over the last dimension.
        let log_probs = candle_nn::ops::log_softmax(logits, candle_core::D::Minus1)
            .map_err(|e| CoreError::Generic(format!("Log softmax failed: {}", e)))?;

        let nll = log_probs
            .mul(targets)
            .map_err(|e| CoreError::Generic(format!("NLL multiplication failed: {}", e)))?
            .sum_all()
            .map_err(|e| CoreError::Generic(format!("NLL sum failed: {}", e)))?
            .neg()
            .map_err(|e| CoreError::Generic(format!("NLL negation failed: {}", e)))?;

        let batch_size = logits
            .dim(0)
            .map_err(|e| CoreError::Generic(format!("Failed to get batch size: {}", e)))?;
        nll.affine(1.0 / batch_size as f64, 0.0)
            .map_err(|e| CoreError::Generic(format!("Cross entropy division failed: {}", e)))
    }
}

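/// Drives the training loop: forward pass, loss computation, optimizer step,
/// learning-rate scheduling, metric tracking, early stopping, and checkpointing.
///
/// A minimal usage sketch (hedged: assumes the `KizzasiConfig` builder methods used
/// in the tests below and a pre-batched dataset; not a complete pipeline):
///
/// ```ignore
/// let model_config = KizzasiConfig::new()
///     .input_dim(3)
///     .output_dim(3)
///     .hidden_dim(64)
///     .state_dim(8)
///     .num_layers(2);
/// let training_config = TrainingConfig::default().with_early_stopping(3);
/// let model = TrainableSSM::new(model_config, training_config.clone())?;
/// let mut trainer = Trainer::new(model, training_config)?;
///
/// // Pre-batched (input, target) pairs of shape (batch, seq_len, dim).
/// let batches: Vec<(Tensor, Tensor)> = vec![/* ... */];
/// let train_loss = trainer.train_epoch(&batches, |pred, tgt| Loss::mse(pred, tgt))?;
/// trainer.save_checkpoint_auto("./checkpoints")?;
/// ```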
pub struct Trainer {
    model: TrainableSSM,
    optimizer: AdamW,
    config: TrainingConfig,
    scheduler: Option<Box<dyn LRScheduler>>,
    metrics: TrainingMetrics,
    logger: MetricsLogger,
    current_step: usize,
}

impl Trainer {
    pub fn new(model: TrainableSSM, config: TrainingConfig) -> CoreResult<Self> {
        let optimizer = model.create_optimizer()?;

        let scheduler = Self::create_scheduler(&config);

        let metrics = TrainingMetrics::new();

        let logger = MetricsLogger::new()
            .with_verbose(config.track_metrics)
            .with_log_interval(config.log_interval);

        Ok(Self {
            model,
            optimizer,
            config,
            scheduler,
            metrics,
            logger,
            current_step: 0,
        })
    }

    fn create_scheduler(config: &TrainingConfig) -> Option<Box<dyn LRScheduler>> {
        use crate::scheduler::*;

        config.scheduler.as_ref().map(|sched_type| {
            // Rough estimate of the total number of optimizer steps; assumes ~100 batches per epoch.
            let total_steps = config.epochs * 100;
            match sched_type {
                SchedulerType::Constant => {
                    Box::new(ConstantScheduler::new(config.learning_rate)) as Box<dyn LRScheduler>
                }
                SchedulerType::Linear {
                    warmup_steps,
                    final_lr,
                } => Box::new(LinearScheduler::new(
                    config.learning_rate,
                    *final_lr,
                    total_steps,
                    *warmup_steps,
                )) as Box<dyn LRScheduler>,
                SchedulerType::Cosine {
                    warmup_steps,
                    min_lr,
                } => Box::new(
                    CosineScheduler::new(config.learning_rate, total_steps, *warmup_steps)
                        .with_min_lr(*min_lr),
                ) as Box<dyn LRScheduler>,
                SchedulerType::Step {
                    milestones,
                    decay_factor,
                } => Box::new(StepScheduler::new(
                    config.learning_rate,
                    *decay_factor,
                    milestones.clone(),
                )) as Box<dyn LRScheduler>,
                SchedulerType::Exponential {
                    decay_rate,
                    decay_steps,
                } => Box::new(ExponentialScheduler::new(
                    config.learning_rate,
                    *decay_rate,
                    *decay_steps,
                )) as Box<dyn LRScheduler>,
                SchedulerType::OneCycle { warmup_pct } => Box::new(
                    OneCycleScheduler::new(config.learning_rate, total_steps)
                        .with_warmup_pct(*warmup_pct),
                ) as Box<dyn LRScheduler>,
                SchedulerType::Polynomial { final_lr, power } => Box::new(PolynomialScheduler::new(
                    config.learning_rate,
                    *final_lr,
                    total_steps,
                    *power,
                )) as Box<dyn LRScheduler>,
            }
        })
    }

    fn get_current_lr(&self) -> f64 {
        self.scheduler
            .as_ref()
            .map(|s| s.get_lr(self.current_step))
            .unwrap_or(self.config.learning_rate)
    }

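    /// Train for one epoch over pre-batched `(input, target)` pairs, returning the mean batch loss.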
    pub fn train_epoch<F>(
        &mut self,
        data_loader: &[(Tensor, Tensor)],
        loss_fn: F,
    ) -> CoreResult<f32>
    where
        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor>,
    {
        let mut total_loss = 0.0;
        let num_batches = data_loader.len();
        let epoch = self.current_step / num_batches.max(1);

        for (batch_idx, (inputs, targets)) in data_loader.iter().enumerate() {
            let lr = self.get_current_lr();
            if self.config.track_metrics {
                self.metrics.record_learning_rate(lr);
            }

            let predictions = self.model.forward(inputs)?;

            let loss = loss_fn(&predictions, targets)?;

            self.optimizer
                .backward_step(&loss)
                .map_err(|e| CoreError::Generic(format!("Backward step failed: {}", e)))?;

            let loss_val = loss
                .to_vec0::<f32>()
                .map_err(|e| CoreError::Generic(format!("Failed to extract loss value: {}", e)))?;
            total_loss += loss_val;

            if self.config.track_metrics {
                self.metrics.record_train_loss(epoch, loss_val);
                self.logger.log_batch(epoch, batch_idx, loss_val);

                let grad_norm = self.compute_grad_norm()?;
                self.metrics.record_grad_norm(grad_norm);
            }

            if let Some(max_norm) = self.config.grad_clip {
                self.clip_gradients(max_norm)?;
            }

            self.current_step += 1;
        }

        Ok(total_loss / num_batches as f32)
    }

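    /// Placeholder: gradient norms are not yet computed; a fixed value of 1.0 is reported.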
    fn compute_grad_norm(&self) -> CoreResult<f32> {
        Ok(1.0)
    }

    /// Placeholder: gradient clipping is not yet applied to the optimizer's gradients.
    fn clip_gradients(&self, _max_norm: f32) -> CoreResult<()> {
        Ok(())
    }

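    /// Compute the mean loss over a batched dataset without updating parameters.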
    pub fn evaluate<F>(&self, data_loader: &[(Tensor, Tensor)], loss_fn: F) -> CoreResult<f32>
    where
        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor>,
    {
        let mut total_loss = 0.0;
        let num_batches = data_loader.len();

        for (inputs, targets) in data_loader {
            let predictions = self.model.forward(inputs)?;

            let loss = loss_fn(&predictions, targets)?;

            let loss_val = loss
                .to_vec0::<f32>()
                .map_err(|e| CoreError::Generic(format!("Failed to extract loss value: {}", e)))?;
            total_loss += loss_val;
        }

        Ok(total_loss / num_batches as f32)
    }

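    /// Full training loop over `epochs`, with optional validation and early stopping.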
    pub fn fit<F>(
        &mut self,
        mut train_loader: TimeSeriesDataLoader,
        mut val_loader: Option<TimeSeriesDataLoader>,
        loss_fn: F,
    ) -> CoreResult<()>
    where
        F: Fn(&Tensor, &Tensor) -> CoreResult<Tensor> + Copy,
    {
        use std::time::Instant;

        for epoch in 0..self.config.epochs {
            let epoch_start = Instant::now();

            train_loader.shuffle();

            // Placeholder: materializing batches from the data loader is not yet implemented,
            // so each epoch currently runs over an empty batch list.
            let train_batches: Vec<(Tensor, Tensor)> = Vec::new();

            let train_loss = self.train_epoch(&train_batches, loss_fn)?;

            let val_loss = if let Some(ref mut _val_data) = val_loader {
                // Placeholder: validation batches are not yet drawn from the loader.
                let val_batches: Vec<(Tensor, Tensor)> = Vec::new();
                let val_loss = self.evaluate(&val_batches, loss_fn)?;

                if self.config.track_metrics {
                    self.metrics.record_val_loss(epoch, val_loss);
                }

                Some(val_loss)
            } else {
                None
            };

            let epoch_duration = epoch_start.elapsed().as_secs_f64();
            if self.config.track_metrics {
                self.metrics.record_epoch_duration(epoch, epoch_duration);
            }

            let current_lr = self.get_current_lr();
            self.logger
                .log_epoch(epoch, train_loss, val_loss, current_lr);

            if let Some(patience) = self.config.early_stopping_patience {
                if !self.metrics.is_improving(patience) {
                    tracing::info!("Early stopping triggered at epoch {}", epoch);
                    break;
                }
            }
        }

        if self.config.track_metrics {
            let summary = self.metrics.summary();
            self.logger.log_summary(&summary);
        }

        Ok(())
    }

    pub fn model(&self) -> &TrainableSSM {
        &self.model
    }

    pub fn model_mut(&mut self) -> &mut TrainableSSM {
        &mut self.model
    }

    pub fn metrics(&self) -> &TrainingMetrics {
        &self.metrics
    }

    pub fn metrics_mut(&mut self) -> &mut TrainingMetrics {
        &mut self.metrics
    }

    pub fn current_step(&self) -> usize {
        self.current_step
    }

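    /// Save model weights (`<name>.safetensors`) and training metadata (`<name>.json`)
    /// into the given checkpoint directory.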
    pub fn save_checkpoint<P: AsRef<std::path::Path>>(
        &self,
        path: P,
        name: &str,
    ) -> CoreResult<()> {
        use std::fs;
        use std::path::PathBuf;

        let checkpoint_dir = path.as_ref();
        fs::create_dir_all(checkpoint_dir).map_err(|e| {
            CoreError::Generic(format!("Failed to create checkpoint directory: {}", e))
        })?;

        let weights_path: PathBuf = checkpoint_dir.join(format!("{}.safetensors", name));
        self.model
            .save_weights(&weights_path)
            .map_err(|e| CoreError::Generic(format!("Failed to save model weights: {}", e)))?;

        let metadata = CheckpointMetadata {
            version: env!("CARGO_PKG_VERSION").to_string(),
            timestamp: chrono::Utc::now().to_rfc3339(),
            current_step: self.current_step,
            current_epoch: self.metrics.summary().total_epochs,
            config: self.config.clone(),
            metrics: self.metrics.clone(),
        };

        let metadata_path: PathBuf = checkpoint_dir.join(format!("{}.json", name));
        let metadata_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
            CoreError::Generic(format!("Failed to serialize checkpoint metadata: {}", e))
        })?;

        fs::write(&metadata_path, metadata_json).map_err(|e| {
            CoreError::Generic(format!("Failed to write checkpoint metadata: {}", e))
        })?;

        tracing::info!(
            "Checkpoint saved: weights={}, metadata={}",
            weights_path.display(),
            metadata_path.display()
        );

        Ok(())
    }

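    /// Restore a `Trainer` from a previously saved checkpoint, rebuilding the model
    /// from `model_config` and the training configuration stored in the metadata.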
    pub fn load_checkpoint<P: AsRef<std::path::Path>>(
        path: P,
        name: &str,
        model_config: KizzasiConfig,
    ) -> CoreResult<Self> {
        use std::fs;
        use std::path::PathBuf;

        let checkpoint_dir = path.as_ref();

        let metadata_path: PathBuf = checkpoint_dir.join(format!("{}.json", name));
        let metadata_json = fs::read_to_string(&metadata_path).map_err(|e| {
            CoreError::Generic(format!("Failed to read checkpoint metadata: {}", e))
        })?;

        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json).map_err(|e| {
            CoreError::Generic(format!("Failed to parse checkpoint metadata: {}", e))
        })?;

        let weights_path: PathBuf = checkpoint_dir.join(format!("{}.safetensors", name));
        let mut model = TrainableSSM::new(model_config, metadata.config.clone())?;
        model
            .load_weights(&weights_path)
            .map_err(|e| CoreError::Generic(format!("Failed to load model weights: {}", e)))?;

        let optimizer = model.create_optimizer()?;
        let scheduler = Self::create_scheduler(&metadata.config);

        let logger = MetricsLogger::new()
            .with_verbose(metadata.config.track_metrics)
            .with_log_interval(metadata.config.log_interval);

        tracing::info!(
            "Checkpoint loaded: version={}, step={}, epoch={}",
            metadata.version,
            metadata.current_step,
            metadata.current_epoch
        );

        Ok(Self {
            model,
            optimizer,
            config: metadata.config,
            scheduler,
            metrics: metadata.metrics,
            logger,
            current_step: metadata.current_step,
        })
    }

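    /// Save a checkpoint named after the current epoch, e.g. `checkpoint_epoch_3`.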
    pub fn save_checkpoint_auto<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
        let current_epoch = self.metrics.summary().total_epochs;
        let name = format!("checkpoint_epoch_{}", current_epoch);
        self.save_checkpoint(path, &name)
    }

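    /// Save a checkpoint named `best` only if the most recent epoch achieved the best
    /// validation loss seen so far; otherwise this is a no-op.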
    pub fn save_best_checkpoint<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
        let summary = self.metrics.summary();

        if let (Some(best_epoch), Some(_best_loss)) = (summary.best_epoch, summary.best_val_loss) {
            let current_epoch = summary.total_epochs.saturating_sub(1);
            if current_epoch == best_epoch {
                tracing::info!("New best validation loss! Saving best checkpoint");
                return self.save_checkpoint(path, "best");
            }
        }

        Ok(())
    }
}

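/// Metadata stored alongside checkpoint weights so training can be resumed.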
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Crate version that produced the checkpoint.
    pub version: String,
    /// RFC 3339 timestamp of when the checkpoint was written.
    pub timestamp: String,
    /// Optimizer step count at save time.
    pub current_step: usize,
    /// Epoch count at save time.
    pub current_epoch: usize,
    /// Training configuration in effect.
    pub config: TrainingConfig,
    /// Accumulated training metrics.
    pub metrics: TrainingMetrics,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trainable_ssm_creation() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();

        let model = TrainableSSM::new(config, training_config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_forward_pass() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();

        let model = TrainableSSM::new(config, training_config).unwrap();
        let device = model.device().clone();

        let input = Tensor::randn(0f32, 1.0, (2, 10, 3), &device).unwrap();

        let output = model.forward(&input);
        if let Err(e) = &output {
            panic!("Forward pass failed: {:?}", e);
        }

        let output = output.unwrap();
        assert_eq!(output.dims(), &[2, 10, 3]);
    }

    #[test]
    fn test_mse_loss() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let loss = Loss::mse(&predictions, &targets).unwrap();
        let loss_val = loss.to_vec0::<f32>().unwrap();

        assert!((loss_val - 0.25).abs() < 1e-5);
    }

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.learning_rate, 1e-4);
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.epochs, 10);
        assert!(config.track_metrics);
        assert_eq!(config.validation_split, 0.2);
        assert_eq!(config.early_stopping_patience, Some(5));
    }

    #[test]
    fn test_training_config_with_scheduler() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Cosine {
            warmup_steps: 100,
            min_lr: 1e-6,
        });

        assert!(config.scheduler.is_some());
        if let Some(SchedulerType::Cosine {
            warmup_steps,
            min_lr,
        }) = config.scheduler
        {
            assert_eq!(warmup_steps, 100);
            assert_eq!(min_lr, 1e-6);
        } else {
            panic!("Expected Cosine scheduler");
        }
    }

    #[test]
    fn test_training_config_builder() {
        let config = TrainingConfig::default()
            .with_validation_split(0.15)
            .with_early_stopping(10)
            .without_metrics();

        assert_eq!(config.validation_split, 0.15);
        assert_eq!(config.early_stopping_patience, Some(10));
        assert!(!config.track_metrics);
    }

    #[test]
    fn test_trainer_with_scheduler() {
        let model_config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default().with_scheduler(SchedulerType::Linear {
            warmup_steps: 50,
            final_lr: 1e-6,
        });

        let model = TrainableSSM::new(model_config, training_config.clone()).unwrap();
        let trainer = Trainer::new(model, training_config);

        assert!(trainer.is_ok());
        let trainer = trainer.unwrap();
        assert!(trainer.scheduler.is_some());
    }

    #[test]
    fn test_trainer_metrics_tracking() {
        let model_config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();
        let model = TrainableSSM::new(model_config, training_config.clone()).unwrap();
        let trainer = Trainer::new(model, training_config).unwrap();

        assert_eq!(trainer.metrics().current_step(), 0);
        assert_eq!(trainer.current_step(), 0);
    }

    #[test]
    fn test_scheduler_type_constant() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Constant);

        assert!(config.scheduler.is_some());
    }

    #[test]
    fn test_scheduler_type_step() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Step {
            milestones: vec![100, 200, 300],
            decay_factor: 0.1,
        });

        if let Some(SchedulerType::Step {
            milestones,
            decay_factor,
        }) = config.scheduler
        {
            assert_eq!(milestones, vec![100, 200, 300]);
            assert_eq!(decay_factor, 0.1);
        } else {
            panic!("Expected Step scheduler");
        }
    }

    #[test]
    fn test_scheduler_type_onecycle() {
        let config =
            TrainingConfig::default().with_scheduler(SchedulerType::OneCycle { warmup_pct: 0.3 });

        if let Some(SchedulerType::OneCycle { warmup_pct }) = config.scheduler {
            assert_eq!(warmup_pct, 0.3);
        } else {
            panic!("Expected OneCycle scheduler");
        }
    }

    #[test]
    fn test_mae_loss() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let loss = Loss::mae(&predictions, &targets).unwrap();
        let loss_val = loss.to_vec0::<f32>().unwrap();

        assert!((loss_val - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_huber_loss() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 5.0], &device).unwrap();
        let targets = Tensor::new(&[1.1f32, 2.1, 3.0], &device).unwrap();

        let loss = Loss::huber(&predictions, &targets, 1.0).unwrap();
        let loss_val = loss.to_vec0::<f32>().unwrap();

        assert!(loss_val > 0.0);
        assert!(loss_val < 2.0);
    }

    #[test]
    fn test_constraint_loss_creation() {
        let constraint_loss = ConstraintLoss::new(0.5);
        assert_eq!(constraint_loss.constraint_weight, 0.5);
    }

    #[test]
    fn test_constraint_loss_no_violation() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let task_loss = Loss::mse(&predictions, &targets).unwrap();
        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();

        let constraint_loss = ConstraintLoss::new(0.5);

        let total_loss = constraint_loss
            .compute(&task_loss, &predictions, |_pred| Ok(0.0))
            .unwrap();
        let total_loss_val = total_loss.to_vec0::<f32>().unwrap();

        assert!((total_loss_val - task_loss_val).abs() < 1e-5);
    }

    #[test]
    fn test_constraint_loss_with_violation() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let task_loss = Loss::mse(&predictions, &targets).unwrap();
        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();

        let constraint_loss = ConstraintLoss::new(0.5);

        let total_loss = constraint_loss
            .compute(&task_loss, &predictions, |_pred| Ok(1.0))
            .unwrap();
        let total_loss_val = total_loss.to_vec0::<f32>().unwrap();

        let expected = task_loss_val + 0.5;
        assert!((total_loss_val - expected).abs() < 1e-5);
    }

    #[test]
    fn test_constraint_loss_scaling() {
        let device = Device::Cpu;
        let predictions = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let targets = Tensor::new(&[1.5f32, 2.5, 3.5], &device).unwrap();

        let task_loss = Loss::mse(&predictions, &targets).unwrap();
        let task_loss_val = task_loss.to_vec0::<f32>().unwrap();

        let weights = [0.1, 0.5, 1.0, 2.0];
        let violation = 1.5;

        for &weight in &weights {
            let constraint_loss = ConstraintLoss::new(weight);
            let total_loss = constraint_loss
                .compute(&task_loss, &predictions, |_pred| Ok(violation))
                .unwrap();
            let total_loss_val = total_loss.to_vec0::<f32>().unwrap();

            let expected = task_loss_val + weight * violation;
            assert!(
                (total_loss_val - expected).abs() < 1e-4,
                "Weight {} failed: got {}, expected {}",
                weight,
                total_loss_val,
                expected
            );
        }
    }

    #[test]
    fn test_checkpoint_save_load() {
        use std::env;
        use std::fs;

        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_test");
        fs::create_dir_all(&temp_dir).unwrap();

        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig {
            epochs: 5,
            learning_rate: 1e-3,
            ..Default::default()
        };

        let model = TrainableSSM::new(config.clone(), training_config.clone()).unwrap();
        let trainer = Trainer::new(model, training_config).unwrap();

        trainer
            .save_checkpoint(&temp_dir, "test_checkpoint")
            .unwrap();

        assert!(temp_dir.join("test_checkpoint.safetensors").exists());
        assert!(temp_dir.join("test_checkpoint.json").exists());

        let loaded_trainer =
            Trainer::load_checkpoint(&temp_dir, "test_checkpoint", config).unwrap();

        assert_eq!(loaded_trainer.config.epochs, 5);
        assert_eq!(loaded_trainer.config.learning_rate, 1e-3);
        assert_eq!(loaded_trainer.current_step, 0);

        fs::remove_dir_all(&temp_dir).unwrap();
    }

    #[test]
    fn test_checkpoint_auto_save() {
        use std::env;
        use std::fs;

        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_auto_test");
        fs::create_dir_all(&temp_dir).unwrap();

        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();
        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
        let mut trainer = Trainer::new(model, training_config).unwrap();

        trainer.metrics.record_train_loss(0, 0.5);

        trainer.save_checkpoint_auto(&temp_dir).unwrap();

        assert!(temp_dir.join("checkpoint_epoch_1.safetensors").exists());
        assert!(temp_dir.join("checkpoint_epoch_1.json").exists());

        fs::remove_dir_all(&temp_dir).unwrap();
    }

    #[test]
    fn test_checkpoint_best_save() {
        use std::env;
        use std::fs;

        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_best_test");
        fs::create_dir_all(&temp_dir).unwrap();

        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();
        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
        let mut trainer = Trainer::new(model, training_config).unwrap();

        trainer.metrics.record_train_loss(0, 1.2);
        trainer.metrics.record_val_loss(0, 1.0);
        trainer.save_best_checkpoint(&temp_dir).unwrap();

        assert!(temp_dir.join("best.safetensors").exists());
        assert!(temp_dir.join("best.json").exists());

        trainer.metrics.record_train_loss(1, 0.9);
        trainer.metrics.record_val_loss(1, 1.2);

        fs::remove_file(temp_dir.join("best.safetensors")).unwrap();
        fs::remove_file(temp_dir.join("best.json")).unwrap();

        trainer.save_best_checkpoint(&temp_dir).unwrap();
        assert!(!temp_dir.join("best.safetensors").exists());

        fs::remove_dir_all(&temp_dir).unwrap();
    }

    #[test]
    fn test_checkpoint_metadata() {
        use std::env;
        use std::fs;

        let temp_dir = env::temp_dir().join("kizzasi_checkpoint_metadata_test");
        fs::create_dir_all(&temp_dir).unwrap();

        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();
        let model = TrainableSSM::new(config, training_config.clone()).unwrap();
        let mut trainer = Trainer::new(model, training_config).unwrap();

        trainer.metrics.record_train_loss(0, 0.5);
        trainer.metrics.record_val_loss(0, 0.45);

        trainer.save_checkpoint(&temp_dir, "metadata_test").unwrap();

        let metadata_path = temp_dir.join("metadata_test.json");
        let metadata_json = fs::read_to_string(&metadata_path).unwrap();
        let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json).unwrap();

        assert_eq!(metadata.version, env!("CARGO_PKG_VERSION"));
        assert!(!metadata.timestamp.is_empty());
        assert_eq!(metadata.current_step, 0);
        assert!(metadata.metrics.val_loss(0).is_some());
        assert_eq!(metadata.metrics.val_loss(0).unwrap(), 0.45);

        fs::remove_dir_all(&temp_dir).unwrap();
    }
}