//! Trainable state-space model (SSM): training configuration, parameter
//! initialization, and the forward pass, built on `candle`.

use crate::config::KizzasiConfig;
use crate::device::DeviceConfig;
use crate::error::{CoreError, CoreResult};
use candle_core::{DType, Device, Tensor, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use serde::{Deserialize, Serialize};

/// Learning-rate schedule selection; the schedule itself is applied by the
/// training loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SchedulerType {
    /// Constant learning rate.
    Constant,
    /// Linear warmup for `warmup_steps`, then linear decay to `final_lr`.
    Linear {
        warmup_steps: usize,
        final_lr: f64,
    },
    /// Linear warmup for `warmup_steps`, then cosine decay to `min_lr`.
    Cosine {
        warmup_steps: usize,
        min_lr: f64,
    },
    /// Multiplies the learning rate by `decay_factor` at each milestone step.
    Step {
        milestones: Vec<usize>,
        decay_factor: f64,
    },
    /// Multiplies the learning rate by `decay_rate` every `decay_steps` steps.
    Exponential {
        decay_rate: f64,
        decay_steps: usize,
    },
    /// One-cycle policy spending `warmup_pct` of training ramping up.
    OneCycle {
        warmup_pct: f64,
    },
    /// Polynomial decay to `final_lr` with exponent `power`.
    Polynomial {
        final_lr: f64,
        power: f64,
    },
}

/// Numeric precision used for model parameters and activations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MixedPrecision {
    /// Full f32 precision (mixed precision disabled).
    None,
    /// IEEE half precision; usually paired with loss scaling.
    FP16,
    /// bfloat16: f32 dynamic range with a reduced mantissa.
    BF16,
}

impl MixedPrecision {
    /// Maps the precision mode to the corresponding candle `DType`.
    pub fn to_dtype(&self) -> DType {
        match self {
            MixedPrecision::None => DType::F32,
            MixedPrecision::FP16 => DType::F16,
            MixedPrecision::BF16 => DType::BF16,
        }
    }

    /// Returns `true` when any reduced-precision mode is active.
    pub fn is_enabled(&self) -> bool {
        !matches!(self, MixedPrecision::None)
    }
}

/// Hyper-parameters and runtime options for training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Device selection used to allocate parameters.
    pub device_config: DeviceConfig,
    /// Base learning rate for AdamW.
    pub learning_rate: f64,
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Number of passes over the training set.
    pub epochs: usize,
    /// AdamW weight decay.
    pub weight_decay: f64,
    /// Optional gradient-norm clipping threshold.
    pub grad_clip: Option<f32>,
    /// AdamW first-moment decay.
    pub beta1: f64,
    /// AdamW second-moment decay.
    pub beta2: f64,
    /// AdamW numerical-stability epsilon.
    pub eps: f64,
    /// Optional learning-rate schedule.
    pub scheduler: Option<SchedulerType>,
    /// Whether to record training metrics.
    pub track_metrics: bool,
    /// Log every `log_interval` steps.
    pub log_interval: usize,
    /// Fraction of the data held out for validation.
    pub validation_split: f32,
    /// Stop after this many epochs without validation improvement.
    pub early_stopping_patience: Option<usize>,
    /// Recompute activations during backprop to save memory.
    pub use_gradient_checkpointing: bool,
    /// Number of layers per checkpointed segment.
    pub checkpoint_segment_size: Option<usize>,
    /// Numeric precision mode.
    pub mixed_precision: MixedPrecision,
    /// Loss-scaling factor used with FP16 training.
    pub loss_scale: f32,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            device_config: DeviceConfig::default(),
            learning_rate: 1e-4,
            batch_size: 32,
            epochs: 10,
            weight_decay: 1e-2,
            grad_clip: Some(1.0),
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            scheduler: None,
            track_metrics: true,
            log_interval: 10,
            validation_split: 0.2,
            early_stopping_patience: Some(5),
            use_gradient_checkpointing: false,
            checkpoint_segment_size: Some(2),
            mixed_precision: MixedPrecision::None,
            loss_scale: 1.0,
        }
    }
}

impl TrainingConfig {
    /// Sets the learning-rate scheduler.
    pub fn with_scheduler(mut self, scheduler: SchedulerType) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    /// Disables metric tracking.
    pub fn without_metrics(mut self) -> Self {
        self.track_metrics = false;
        self
    }

    /// Sets the validation split fraction.
    pub fn with_validation_split(mut self, split: f32) -> Self {
        self.validation_split = split;
        self
    }

    /// Enables early stopping with the given patience (in epochs).
    pub fn with_early_stopping(mut self, patience: usize) -> Self {
        self.early_stopping_patience = Some(patience);
        self
    }

    /// Disables early stopping.
    pub fn without_early_stopping(mut self) -> Self {
        self.early_stopping_patience = None;
        self
    }

    /// Enables gradient checkpointing with an optional segment size.
    pub fn with_gradient_checkpointing(mut self, segment_size: Option<usize>) -> Self {
        self.use_gradient_checkpointing = true;
        self.checkpoint_segment_size = segment_size;
        self
    }

    /// Disables gradient checkpointing.
    pub fn without_gradient_checkpointing(mut self) -> Self {
        self.use_gradient_checkpointing = false;
        self
    }

    /// Enables FP16 mixed precision with a default loss scale of 128.0, a
    /// common starting point to keep small f16 gradients representable.
    pub fn with_fp16(mut self) -> Self {
        self.mixed_precision = MixedPrecision::FP16;
        self.loss_scale = 128.0;
        self
    }

    /// Enables BF16 mixed precision. BF16 keeps the f32 exponent range, so
    /// no loss scaling is needed.
    pub fn with_bf16(mut self) -> Self {
        self.mixed_precision = MixedPrecision::BF16;
        self.loss_scale = 1.0;
        self
    }

    /// Sets an explicit precision mode and loss scale.
    pub fn with_mixed_precision(mut self, mode: MixedPrecision, loss_scale: f32) -> Self {
        self.mixed_precision = mode;
        self.loss_scale = loss_scale;
        self
    }

    /// Disables mixed precision and resets the loss scale.
    pub fn without_mixed_precision(mut self) -> Self {
        self.mixed_precision = MixedPrecision::None;
        self.loss_scale = 1.0;
        self
    }
}

/// A trainable SSM whose parameters are registered in a [`VarMap`] so an
/// optimizer can update them.
pub struct TrainableSSM {
    pub(crate) config: KizzasiConfig,
    pub(crate) training_config: TrainingConfig,
    pub(crate) device: Device,
    pub(crate) dtype: DType,
    /// Input embedding, shape `(input_dim, hidden_dim)`.
    pub(crate) embedding_weight: Var,
    /// Per-layer state matrices `A`, shape `(hidden_dim, state_dim)`.
    pub(crate) a_matrices: Vec<Var>,
    /// Per-layer input matrices `B`, shape `(hidden_dim, state_dim)`.
    pub(crate) b_matrices: Vec<Var>,
    /// Per-layer output matrices `C`, shape `(hidden_dim, state_dim)`.
    pub(crate) c_matrices: Vec<Var>,
    /// Per-layer skip vectors `D`, shape `(hidden_dim,)`.
    pub(crate) d_vectors: Vec<Var>,
    /// Output projection, shape `(hidden_dim, output_dim)`.
    pub(crate) output_proj: Var,
    /// Per-layer layer-norm scales.
    pub(crate) ln_gamma: Vec<Var>,
    /// Per-layer layer-norm shifts.
    pub(crate) ln_beta: Vec<Var>,
    /// Registry of all trainable parameters.
    pub(crate) varmap: VarMap,
}

impl TrainableSSM {
    /// Builds a new model, allocating and initializing all parameters on the
    /// device selected by `training_config`.
    pub fn new(config: KizzasiConfig, training_config: TrainingConfig) -> CoreResult<Self> {
        let device = training_config.device_config.create_device()?;
        let dtype = training_config.mixed_precision.to_dtype();

        let hidden_dim = config.get_hidden_dim();
        let state_dim = config.get_state_dim();
        let num_layers = config.get_num_layers();
        let input_dim = config.get_input_dim();
        let output_dim = config.get_output_dim();

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, &device);

        // `vb.get_with_hints` registers each parameter in `varmap`. We fetch
        // the backing `Var` from the map afterwards so that `forward` and the
        // optimizer (which steps `varmap.all_vars()`) share the same storage;
        // wrapping the returned tensor in `Var::from_tensor` would copy the
        // data and silently decouple the two.
        let get_var = |name: &str| -> CoreResult<Var> {
            varmap
                .data()
                .lock()
                .map_err(|_| CoreError::Generic("VarMap mutex poisoned".to_string()))?
                .get(name)
                .cloned()
                .ok_or_else(|| CoreError::Generic(format!("Missing variable: {}", name)))
        };

        vb.get_with_hints(
            (input_dim, hidden_dim),
            "embedding.weight",
            candle_nn::init::DEFAULT_KAIMING_NORMAL,
        )
        .map_err(|e| CoreError::Generic(format!("Failed to create embedding: {}", e)))?;
        let embedding_weight = get_var("embedding.weight")?;

        let mut a_matrices = Vec::with_capacity(num_layers);
        let mut b_matrices = Vec::with_capacity(num_layers);
        let mut c_matrices = Vec::with_capacity(num_layers);
        let mut d_vectors = Vec::with_capacity(num_layers);
        let mut ln_gamma = Vec::with_capacity(num_layers);
        let mut ln_beta = Vec::with_capacity(num_layers);

        for layer_idx in 0..num_layers {
            // State matrix A, initialized to a small negative constant for a
            // stable (decaying) recurrence.
            let a_name = format!("ssm.layer_{}.a", layer_idx);
            vb.get_with_hints(
                (hidden_dim, state_dim),
                &a_name,
                candle_nn::init::Init::Const(-0.5),
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create A matrix: {}", e)))?;
            a_matrices.push(get_var(&a_name)?);

            let b_name = format!("ssm.layer_{}.b", layer_idx);
            vb.get_with_hints(
                (hidden_dim, state_dim),
                &b_name,
                candle_nn::init::DEFAULT_KAIMING_NORMAL,
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create B matrix: {}", e)))?;
            b_matrices.push(get_var(&b_name)?);

            let c_name = format!("ssm.layer_{}.c", layer_idx);
            vb.get_with_hints(
                (hidden_dim, state_dim),
                &c_name,
                candle_nn::init::DEFAULT_KAIMING_NORMAL,
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create C matrix: {}", e)))?;
            c_matrices.push(get_var(&c_name)?);

            // Skip-connection vector D, initialized to 1 (identity pass-through).
            let d_name = format!("ssm.layer_{}.d", layer_idx);
            vb.get_with_hints(hidden_dim, &d_name, candle_nn::init::Init::Const(1.0))
                .map_err(|e| CoreError::Generic(format!("Failed to create D vector: {}", e)))?;
            d_vectors.push(get_var(&d_name)?);

            // Layer-norm scale and shift, initialized to the identity transform.
            let gamma_name = format!("ln.layer_{}.gamma", layer_idx);
            vb.get_with_hints(hidden_dim, &gamma_name, candle_nn::init::Init::Const(1.0))
                .map_err(|e| CoreError::Generic(format!("Failed to create LN gamma: {}", e)))?;
            ln_gamma.push(get_var(&gamma_name)?);

            let beta_name = format!("ln.layer_{}.beta", layer_idx);
            vb.get_with_hints(hidden_dim, &beta_name, candle_nn::init::Init::Const(0.0))
                .map_err(|e| CoreError::Generic(format!("Failed to create LN beta: {}", e)))?;
            ln_beta.push(get_var(&beta_name)?);
        }

        vb.get_with_hints(
            (hidden_dim, output_dim),
            "output.proj",
            candle_nn::init::DEFAULT_KAIMING_NORMAL,
        )
        .map_err(|e| CoreError::Generic(format!("Failed to create output projection: {}", e)))?;
        let output_proj = get_var("output.proj")?;

        Ok(Self {
            config,
            training_config,
            device,
            dtype,
            embedding_weight,
            a_matrices,
            b_matrices,
            c_matrices,
            d_vectors,
            output_proj,
            ln_gamma,
            ln_beta,
            varmap,
        })
    }

    /// Forward pass: `input` has shape `(batch, seq_len, input_dim)`; the
    /// returned tensor has shape `(batch, seq_len, output_dim)`.
    pub fn forward(&self, input: &Tensor) -> CoreResult<Tensor> {
        let batch_size = input
            .dim(0)
            .map_err(|e| CoreError::Generic(format!("Failed to get batch dimension: {}", e)))?;
        let seq_len = input
            .dim(1)
            .map_err(|e| CoreError::Generic(format!("Failed to get sequence dimension: {}", e)))?;
        let input_dim = input
            .dim(2)
            .map_err(|e| CoreError::Generic(format!("Failed to get input dimension: {}", e)))?;

        // Embed: flatten to (batch * seq, input_dim), project, reshape back.
        let x_flat = input
            .reshape((batch_size * seq_len, input_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape input: {}", e)))?;

        let hidden_dim = self.config.get_hidden_dim();
        let x_embedded = x_flat
            .matmul(self.embedding_weight.as_tensor())
            .map_err(|e| CoreError::Generic(format!("Embedding forward failed: {}", e)))?;

        let mut x = x_embedded
            .reshape((batch_size, seq_len, hidden_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape embedded: {}", e)))?;

        let state_dim = self.config.get_state_dim();

        // Recurrent state shared across layers; currently unused by the
        // placeholder `ssm_layer` but kept for the full recurrence.
        let mut h = Tensor::zeros(
            (batch_size, hidden_dim, state_dim),
            self.dtype,
            &self.device,
        )
        .map_err(|e| CoreError::Generic(format!("Failed to create hidden state: {}", e)))?;

        for layer_idx in 0..self.config.get_num_layers() {
            x = self.layer_norm(&x, layer_idx)?;
            x = self.ssm_layer(&x, &mut h, layer_idx)?;
        }

        // Project back to the output dimension.
        let x_flat = x
            .reshape((batch_size * seq_len, hidden_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape for output: {}", e)))?;

        let output_dim = self.config.get_output_dim();
        let output_flat = x_flat
            .matmul(self.output_proj.as_tensor())
            .map_err(|e| CoreError::Generic(format!("Output projection failed: {}", e)))?;

        let output = output_flat
            .reshape((batch_size, seq_len, output_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape output: {}", e)))?;

        Ok(output)
    }

    /// Standard layer normalization over the last dimension with learned
    /// scale (`gamma`) and shift (`beta`).
    fn layer_norm(&self, x: &Tensor, layer_idx: usize) -> CoreResult<Tensor> {
        const EPS: f64 = 1e-5;

        let mean = x
            .mean_keepdim(candle_core::D::Minus1)
            .map_err(|e| CoreError::Generic(format!("Layer norm mean failed: {}", e)))?;
        let x_centered = x
            .broadcast_sub(&mean)
            .map_err(|e| CoreError::Generic(format!("Layer norm centering failed: {}", e)))?;
        let variance = x_centered
            .sqr()
            .map_err(|e| CoreError::Generic(format!("Layer norm variance sqr failed: {}", e)))?
            .mean_keepdim(candle_core::D::Minus1)
            .map_err(|e| CoreError::Generic(format!("Layer norm variance mean failed: {}", e)))?;

        // std = sqrt(variance + EPS); `affine(1.0, EPS)` adds the epsilon.
        let std = variance
            .affine(1.0, EPS)
            .map_err(|e| CoreError::Generic(format!("Layer norm add eps failed: {}", e)))?
            .sqrt()
            .map_err(|e| CoreError::Generic(format!("Layer norm sqrt failed: {}", e)))?;

        let normalized = x_centered
            .broadcast_div(&std)
            .map_err(|e| CoreError::Generic(format!("Layer norm division failed: {}", e)))?;

        let gamma = self.ln_gamma[layer_idx].as_tensor();
        let beta = self.ln_beta[layer_idx].as_tensor();

        normalized
            .broadcast_mul(gamma)
            .map_err(|e| CoreError::Generic(format!("Layer norm gamma mul failed: {}", e)))?
            .broadcast_add(beta)
            .map_err(|e| CoreError::Generic(format!("Layer norm beta add failed: {}", e)))
    }

    /// Placeholder SSM mixing layer. Only the learned skip connection `D` is
    /// applied for now; the `A`/`B`/`C` state-space recurrence over `_h` is
    /// not yet implemented.
    fn ssm_layer(&self, x: &Tensor, _h: &mut Tensor, layer_idx: usize) -> CoreResult<Tensor> {
        let _a = self.a_matrices[layer_idx].as_tensor();
        let _b = self.b_matrices[layer_idx].as_tensor();
        let _c = self.c_matrices[layer_idx].as_tensor();
        let d = self.d_vectors[layer_idx].as_tensor();

        // y = x * D, broadcast over the hidden dimension.
        let y = x
            .broadcast_mul(d)
            .map_err(|e| CoreError::Generic(format!("Skip connection failed: {}", e)))?;

        Ok(y)
    }

    /// Creates an AdamW optimizer over every parameter registered in the
    /// [`VarMap`].
    pub fn create_optimizer(&self) -> CoreResult<AdamW> {
        let params = ParamsAdamW {
            lr: self.training_config.learning_rate,
            beta1: self.training_config.beta1,
            beta2: self.training_config.beta2,
            eps: self.training_config.eps,
            weight_decay: self.training_config.weight_decay,
        };

        AdamW::new(self.varmap.all_vars(), params)
            .map_err(|e| CoreError::Generic(format!("Failed to create optimizer: {}", e)))
    }

    /// Returns the underlying parameter registry.
    pub fn varmap(&self) -> &VarMap {
        &self.varmap
    }

    /// Returns the device the model lives on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Returns the parameter dtype (f32, f16, or bf16).
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Saves all parameters to a safetensors file at `path`.
    pub fn save_weights<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
        self.varmap
            .save(path)
            .map_err(|e| CoreError::Generic(format!("Failed to save weights: {}", e)))
    }

    /// Loads parameters from a safetensors file at `path`; tensor names and
    /// shapes must match the current configuration.
    pub fn load_weights<P: AsRef<std::path::Path>>(&mut self, path: P) -> CoreResult<()> {
        self.varmap
            .load(path)
            .map_err(|e| CoreError::Generic(format!("Failed to load weights: {}", e)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Tensor;

    #[test]
    fn test_trainable_ssm_creation() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();

        let model = TrainableSSM::new(config, training_config);
        assert!(model.is_ok());
    }
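
    // Sketch: the optimizer should see every registered parameter. The
    // expected count assumes one var per `get_with_hints` call: embedding +
    // output projection + 2 layers x (A, B, C, D, gamma, beta) = 14.
    #[test]
    fn test_create_optimizer() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let model = TrainableSSM::new(config, TrainingConfig::default()).unwrap();
        assert_eq!(model.varmap().all_vars().len(), 14);
        assert!(model.create_optimizer().is_ok());
    }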

    #[test]
    fn test_forward_pass() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();

        let model = TrainableSSM::new(config, training_config).unwrap();
        let device = model.device().clone();

        let input = Tensor::randn(0f32, 1.0, (2, 10, 3), &device).unwrap();

        let output = model.forward(&input);
        if let Err(e) = &output {
            panic!("Forward pass failed: {:?}", e);
        }

        let output = output.unwrap();
        assert_eq!(output.dims(), &[2, 10, 3]);
    }
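
    // Single-step training sketch: forward pass, mean-squared-error loss,
    // one AdamW update via `backward_step`. Assumes the default device
    // config resolves to a usable device, as in `test_forward_pass`.
    #[test]
    fn test_single_training_step() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let model = TrainableSSM::new(config, TrainingConfig::default()).unwrap();
        let device = model.device().clone();
        let mut optimizer = model.create_optimizer().unwrap();

        let input = Tensor::randn(0f32, 1.0, (2, 10, 3), &device).unwrap();
        let target = Tensor::randn(0f32, 1.0, (2, 10, 3), &device).unwrap();

        let output = model.forward(&input).unwrap();
        let loss = output
            .sub(&target)
            .unwrap()
            .sqr()
            .unwrap()
            .mean_all()
            .unwrap();
        assert!(optimizer.backward_step(&loss).is_ok());
    }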

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.learning_rate, 1e-4);
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.epochs, 10);
        assert!(config.track_metrics);
        assert_eq!(config.validation_split, 0.2);
        assert_eq!(config.early_stopping_patience, Some(5));
    }
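
    // Builder sketch for the precision modes: `with_fp16` pairs FP16 with a
    // default loss scale of 128.0, `with_bf16` needs no scaling, and
    // `without_mixed_precision` resets both.
    #[test]
    fn test_mixed_precision_builders() {
        let fp16 = TrainingConfig::default().with_fp16();
        assert_eq!(fp16.mixed_precision, MixedPrecision::FP16);
        assert_eq!(fp16.loss_scale, 128.0);
        assert_eq!(fp16.mixed_precision.to_dtype(), DType::F16);

        let bf16 = TrainingConfig::default().with_bf16();
        assert_eq!(bf16.mixed_precision, MixedPrecision::BF16);
        assert!(bf16.mixed_precision.is_enabled());

        let off = bf16.without_mixed_precision();
        assert_eq!(off.mixed_precision, MixedPrecision::None);
        assert_eq!(off.loss_scale, 1.0);
        assert!(!off.mixed_precision.is_enabled());
    }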

    #[test]
    fn test_training_config_with_scheduler() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Cosine {
            warmup_steps: 100,
            min_lr: 1e-6,
        });

        assert!(config.scheduler.is_some());
        if let Some(SchedulerType::Cosine {
            warmup_steps,
            min_lr,
        }) = config.scheduler
        {
            assert_eq!(warmup_steps, 100);
            assert_eq!(min_lr, 1e-6);
        } else {
            panic!("Expected Cosine scheduler");
        }
    }

    #[test]
    fn test_training_config_builder() {
        let config = TrainingConfig::default()
            .with_validation_split(0.15)
            .with_early_stopping(10)
            .without_metrics();

        assert_eq!(config.validation_split, 0.15);
        assert_eq!(config.early_stopping_patience, Some(10));
        assert!(!config.track_metrics);
    }
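
    // Builder sketch for gradient checkpointing: enabling sets the flag and
    // segment size; disabling clears the flag.
    #[test]
    fn test_gradient_checkpointing_builder() {
        let config = TrainingConfig::default().with_gradient_checkpointing(Some(4));
        assert!(config.use_gradient_checkpointing);
        assert_eq!(config.checkpoint_segment_size, Some(4));

        let config = config.without_gradient_checkpointing();
        assert!(!config.use_gradient_checkpointing);
    }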

    #[test]
    fn test_scheduler_type_constant() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Constant);

        assert!(config.scheduler.is_some());
    }

    #[test]
    fn test_scheduler_type_step() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Step {
            milestones: vec![100, 200, 300],
            decay_factor: 0.1,
        });

        if let Some(SchedulerType::Step {
            milestones,
            decay_factor,
        }) = config.scheduler
        {
            assert_eq!(milestones, vec![100, 200, 300]);
            assert_eq!(decay_factor, 0.1);
        } else {
            panic!("Expected Step scheduler");
        }
    }
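
    // Mirrors the Step/OneCycle tests above for the Exponential variant.
    #[test]
    fn test_scheduler_type_exponential() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Exponential {
            decay_rate: 0.96,
            decay_steps: 1000,
        });

        if let Some(SchedulerType::Exponential {
            decay_rate,
            decay_steps,
        }) = config.scheduler
        {
            assert_eq!(decay_rate, 0.96);
            assert_eq!(decay_steps, 1000);
        } else {
            panic!("Expected Exponential scheduler");
        }
    }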

    #[test]
    fn test_scheduler_type_onecycle() {
        let config =
            TrainingConfig::default().with_scheduler(SchedulerType::OneCycle { warmup_pct: 0.3 });

        if let Some(SchedulerType::OneCycle { warmup_pct }) = config.scheduler {
            assert_eq!(warmup_pct, 0.3);
        } else {
            panic!("Expected OneCycle scheduler");
        }
    }
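
    // Persistence round-trip sketch: save to the OS temp dir, then load into
    // a freshly built model with the same configuration. The file name is
    // arbitrary (chosen here for illustration).
    #[test]
    fn test_save_load_weights() {
        let make_config = || {
            KizzasiConfig::new()
                .input_dim(3)
                .output_dim(3)
                .hidden_dim(64)
                .state_dim(8)
                .num_layers(2)
        };

        let model = TrainableSSM::new(make_config(), TrainingConfig::default()).unwrap();
        let path = std::env::temp_dir().join("trainable_ssm_roundtrip_test.safetensors");
        model.save_weights(&path).unwrap();

        let mut reloaded = TrainableSSM::new(make_config(), TrainingConfig::default()).unwrap();
        reloaded.load_weights(&path).unwrap();
        let _ = std::fs::remove_file(&path);
    }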
}