
kizzasi_core/training_core.rs

//! Training core — SchedulerType, MixedPrecision, TrainingConfig, TrainableSSM
//!
//! This module contains the foundational training infrastructure:
//!
//! - [`SchedulerType`] — learning rate scheduler variants
//! - [`MixedPrecision`] — FP16/BF16 mixed precision modes
//! - [`TrainingConfig`] — full training hyperparameter configuration
//! - [`TrainableSSM`] — differentiable SSM model with candle Var parameters

use crate::config::KizzasiConfig;
use crate::device::DeviceConfig;
use crate::error::{CoreError, CoreResult};
use candle_core::{DType, Device, Tensor, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use serde::{Deserialize, Serialize};

/// Scheduler type enumeration
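///
/// The variants only carry hyperparameters; the step-to-learning-rate mapping is
/// applied by the training loop. As an illustration (not the exact formula used
/// elsewhere in the crate), a cosine schedule with warmup is typically computed as:
///
/// ```rust,ignore
/// fn cosine_lr(step: usize, total_steps: usize, warmup_steps: usize, base_lr: f64, min_lr: f64) -> f64 {
///     if step < warmup_steps {
///         // Linear warmup from 0 up to base_lr
///         return base_lr * step as f64 / warmup_steps as f64;
///     }
///     // Cosine decay from base_lr down to min_lr over the remaining steps
///     let progress = (step - warmup_steps) as f64 / (total_steps - warmup_steps) as f64;
///     min_lr + 0.5 * (base_lr - min_lr) * (1.0 + (std::f64::consts::PI * progress).cos())
/// }
/// ```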
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SchedulerType {
    /// Constant learning rate throughout training
    Constant,
    /// Linear decay toward `final_lr` after `warmup_steps` of warmup
    Linear {
        warmup_steps: usize,
        final_lr: f64,
    },
    /// Cosine annealing toward `min_lr` after `warmup_steps` of warmup
    Cosine {
        warmup_steps: usize,
        min_lr: f64,
    },
    /// Piecewise-constant decay: multiply by `decay_factor` at each step in `milestones`
    Step {
        milestones: Vec<usize>,
        decay_factor: f64,
    },
    /// Exponential decay governed by `decay_rate` and `decay_steps`
    Exponential {
        decay_rate: f64,
        decay_steps: usize,
    },
    /// One-cycle policy; `warmup_pct` is the fraction of steps spent ramping up
    OneCycle {
        warmup_pct: f64,
    },
    /// Polynomial decay toward `final_lr` with exponent `power`
    Polynomial {
        final_lr: f64,
        power: f64,
    },
}

/// Mixed precision training mode
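///
/// Selecting FP16 typically requires loss scaling to avoid gradient underflow
/// (see [`TrainingConfig::loss_scale`]). A rough sketch of how a training loop
/// might use the mode; illustrative only, with `compute_loss` and `config`
/// standing in for the actual loop:
///
/// ```rust,ignore
/// let mode = MixedPrecision::FP16;
/// let dtype = mode.to_dtype();                       // DType::F16
/// let input = input.to_dtype(dtype)?;                // cast inputs to the compute dtype
/// let loss = compute_loss(&model, &input)?;
/// let scaled = loss.affine(config.loss_scale as f64, 0.0)?; // scale before backward
/// // ... backward on `scaled`, then unscale gradients before the optimizer step
/// ```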
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MixedPrecision {
    /// Full precision (FP32)
    None,
    /// Half precision (FP16) - faster but less stable
    FP16,
    /// Brain float 16 (BF16) - better stability than FP16
    BF16,
}

impl MixedPrecision {
    /// Convert to candle DType
    pub fn to_dtype(&self) -> DType {
        match self {
            MixedPrecision::None => DType::F32,
            MixedPrecision::FP16 => DType::F16,
            MixedPrecision::BF16 => DType::BF16,
        }
    }

    /// Check if mixed precision is enabled
    pub fn is_enabled(&self) -> bool {
        !matches!(self, MixedPrecision::None)
    }
}

/// Configuration for training
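///
/// `TrainingConfig` is a plain data struct with builder-style helpers. A typical
/// setup (the values are illustrative):
///
/// ```rust,ignore
/// let config = TrainingConfig::default()
///     .with_scheduler(SchedulerType::Cosine { warmup_steps: 100, min_lr: 1e-6 })
///     .with_validation_split(0.1)
///     .with_early_stopping(3)
///     .with_bf16();
/// ```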
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Device configuration (CPU/CUDA/Metal)
    pub device_config: DeviceConfig,
    /// Learning rate (initial for schedulers)
    pub learning_rate: f64,
    /// Batch size
    pub batch_size: usize,
    /// Number of epochs
    pub epochs: usize,
    /// Weight decay (L2 regularization)
    pub weight_decay: f64,
    /// Gradient clipping threshold
    pub grad_clip: Option<f32>,
    /// Beta1 for Adam optimizer
    pub beta1: f64,
    /// Beta2 for Adam optimizer
    pub beta2: f64,
    /// Epsilon for Adam optimizer
    pub eps: f64,
    /// Learning rate scheduler type
    pub scheduler: Option<SchedulerType>,
    /// Enable metrics tracking
    pub track_metrics: bool,
    /// Log interval (batches)
    pub log_interval: usize,
    /// Validation split (0.0 to 1.0)
    pub validation_split: f32,
    /// Early stopping patience (epochs)
    pub early_stopping_patience: Option<usize>,
    /// Enable gradient checkpointing (saves memory by recomputing activations)
    pub use_gradient_checkpointing: bool,
    /// Checkpoint every N layers (None = checkpoint all layers)
    pub checkpoint_segment_size: Option<usize>,
    /// Mixed precision training mode
    pub mixed_precision: MixedPrecision,
    /// Loss scaling factor for mixed precision (to prevent underflow)
    pub loss_scale: f32,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            device_config: DeviceConfig::default(),
            learning_rate: 1e-4,
            batch_size: 32,
            epochs: 10,
            weight_decay: 1e-2,
            grad_clip: Some(1.0),
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            scheduler: None,
            track_metrics: true,
            log_interval: 10,
            validation_split: 0.2,
            early_stopping_patience: Some(5),
            use_gradient_checkpointing: false,
            checkpoint_segment_size: Some(2), // Checkpoint every 2 layers by default
            mixed_precision: MixedPrecision::None,
            loss_scale: 1.0, // No scaling by default
        }
    }
}

impl TrainingConfig {
    /// Set scheduler type
    pub fn with_scheduler(mut self, scheduler: SchedulerType) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    /// Disable metrics tracking
    pub fn without_metrics(mut self) -> Self {
        self.track_metrics = false;
        self
    }

    /// Set validation split
    pub fn with_validation_split(mut self, split: f32) -> Self {
        self.validation_split = split;
        self
    }

    /// Set early stopping patience
    pub fn with_early_stopping(mut self, patience: usize) -> Self {
        self.early_stopping_patience = Some(patience);
        self
    }

    /// Disable early stopping
    pub fn without_early_stopping(mut self) -> Self {
        self.early_stopping_patience = None;
        self
    }

    /// Enable gradient checkpointing for memory-efficient training
    pub fn with_gradient_checkpointing(mut self, segment_size: Option<usize>) -> Self {
        self.use_gradient_checkpointing = true;
        self.checkpoint_segment_size = segment_size;
        self
    }

    /// Disable gradient checkpointing
    pub fn without_gradient_checkpointing(mut self) -> Self {
        self.use_gradient_checkpointing = false;
        self
    }

    /// Enable mixed precision training (FP16)
    pub fn with_fp16(mut self) -> Self {
        self.mixed_precision = MixedPrecision::FP16;
        self.loss_scale = 128.0; // Default loss scale for FP16
        self
    }

    /// Enable mixed precision training (BF16)
    pub fn with_bf16(mut self) -> Self {
        self.mixed_precision = MixedPrecision::BF16;
        self.loss_scale = 1.0; // BF16 is more stable, doesn't need scaling
        self
    }

    /// Set mixed precision mode
    pub fn with_mixed_precision(mut self, mode: MixedPrecision, loss_scale: f32) -> Self {
        self.mixed_precision = mode;
        self.loss_scale = loss_scale;
        self
    }

    /// Disable mixed precision training
    pub fn without_mixed_precision(mut self) -> Self {
        self.mixed_precision = MixedPrecision::None;
        self.loss_scale = 1.0;
        self
    }
}

/// Trainable Selective SSM using candle Tensors
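///
/// A minimal end-to-end sketch of how the model is intended to be trained.
/// The data tensors (`inputs`, `targets`) and the MSE loss are placeholders,
/// not part of this module:
///
/// ```rust,ignore
/// let config = KizzasiConfig::new()
///     .input_dim(3)
///     .output_dim(3)
///     .hidden_dim(64)
///     .state_dim(8)
///     .num_layers(2);
/// let model = TrainableSSM::new(config, TrainingConfig::default())?;
/// let mut optimizer = model.create_optimizer()?;
///
/// // inputs/targets: [batch, seq, dim] tensors allocated on `model.device()`
/// let predictions = model.forward(&inputs)?;
/// let loss = candle_nn::loss::mse(&predictions, &targets)?;
/// optimizer.backward_step(&loss)?; // computes gradients and updates the Vars
/// ```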
pub struct TrainableSSM {
    pub(crate) config: KizzasiConfig,
    pub(crate) training_config: TrainingConfig,
    pub(crate) device: Device,
    pub(crate) dtype: DType,
    // Learnable parameters
    pub(crate) embedding_weight: Var,
    pub(crate) a_matrices: Vec<Var>,
    pub(crate) b_matrices: Vec<Var>,
    pub(crate) c_matrices: Vec<Var>,
    pub(crate) d_vectors: Vec<Var>,
    pub(crate) output_proj: Var,
    // Layer normalization parameters
    pub(crate) ln_gamma: Vec<Var>,
    pub(crate) ln_beta: Vec<Var>,
    // Variable map for optimizer
    pub(crate) varmap: VarMap,
}

impl TrainableSSM {
    /// Create a new trainable SSM model
    pub fn new(config: KizzasiConfig, training_config: TrainingConfig) -> CoreResult<Self> {
        // Create device from configuration
        let device = training_config.device_config.create_device()?;

        // Use mixed precision dtype from training config
        let dtype = training_config.mixed_precision.to_dtype();

        let hidden_dim = config.get_hidden_dim();
        let state_dim = config.get_state_dim();
        let num_layers = config.get_num_layers();
        let input_dim = config.get_input_dim();
        let output_dim = config.get_output_dim();

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, &device);

        // Fetch the Var that the VarBuilder registered in the VarMap. Storing the
        // registered Var (rather than wrapping the returned tensor in
        // `Var::from_tensor`, which would create a detached copy) keeps the
        // parameters used in `forward` identical to the ones returned by
        // `varmap.all_vars()`, so the optimizer and weight save/load stay in sync.
        let get_var = |name: &str| -> CoreResult<Var> {
            varmap
                .data()
                .lock()
                .unwrap()
                .get(name)
                .cloned()
                .ok_or_else(|| CoreError::Generic(format!("Missing variable: {}", name)))
        };

        // Initialize embedding layer
        vb.get_with_hints(
            (input_dim, hidden_dim),
            "embedding.weight",
            candle_nn::init::DEFAULT_KAIMING_NORMAL,
        )
        .map_err(|e| CoreError::Generic(format!("Failed to create embedding: {}", e)))?;
        let embedding_weight = get_var("embedding.weight")?;

        // Initialize SSM matrices for each layer
        let mut a_matrices = Vec::with_capacity(num_layers);
        let mut b_matrices = Vec::with_capacity(num_layers);
        let mut c_matrices = Vec::with_capacity(num_layers);
        let mut d_vectors = Vec::with_capacity(num_layers);
        let mut ln_gamma = Vec::with_capacity(num_layers);
        let mut ln_beta = Vec::with_capacity(num_layers);

        for layer_idx in 0..num_layers {
            // A matrix: state transition (initialized negative for stability)
            let a_name = format!("ssm.layer_{}.a", layer_idx);
            vb.get_with_hints(
                (hidden_dim, state_dim),
                &a_name,
                candle_nn::init::Init::Const(-0.5),
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create A matrix: {}", e)))?;
            a_matrices.push(get_var(&a_name)?);

            // B matrix: input projection to state
            let b_name = format!("ssm.layer_{}.b", layer_idx);
            vb.get_with_hints(
                (hidden_dim, state_dim),
                &b_name,
                candle_nn::init::DEFAULT_KAIMING_NORMAL,
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create B matrix: {}", e)))?;
            b_matrices.push(get_var(&b_name)?);

            // C matrix: state to output projection
            let c_name = format!("ssm.layer_{}.c", layer_idx);
            vb.get_with_hints(
                (hidden_dim, state_dim),
                &c_name,
                candle_nn::init::DEFAULT_KAIMING_NORMAL,
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create C matrix: {}", e)))?;
            c_matrices.push(get_var(&c_name)?);

            // D vector: skip connection
            let d_name = format!("ssm.layer_{}.d", layer_idx);
            vb.get_with_hints(
                hidden_dim,
                &d_name,
                candle_nn::init::Init::Const(1.0),
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create D vector: {}", e)))?;
            d_vectors.push(get_var(&d_name)?);

            // Layer normalization parameters
            let gamma_name = format!("ln.layer_{}.gamma", layer_idx);
            vb.get_with_hints(
                hidden_dim,
                &gamma_name,
                candle_nn::init::Init::Const(1.0),
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create LN gamma: {}", e)))?;
            ln_gamma.push(get_var(&gamma_name)?);

            let beta_name = format!("ln.layer_{}.beta", layer_idx);
            vb.get_with_hints(
                hidden_dim,
                &beta_name,
                candle_nn::init::Init::Const(0.0),
            )
            .map_err(|e| CoreError::Generic(format!("Failed to create LN beta: {}", e)))?;
            ln_beta.push(get_var(&beta_name)?);
        }

        // Output projection
        vb.get_with_hints(
            (hidden_dim, output_dim),
            "output.proj",
            candle_nn::init::DEFAULT_KAIMING_NORMAL,
        )
        .map_err(|e| CoreError::Generic(format!("Failed to create output projection: {}", e)))?;
        let output_proj = get_var("output.proj")?;

        Ok(Self {
            config,
            training_config,
            device,
            dtype,
            embedding_weight,
            a_matrices,
            b_matrices,
            c_matrices,
            d_vectors,
            output_proj,
            ln_gamma,
            ln_beta,
            varmap,
        })
    }

    /// Forward pass for training (tracks gradients)
    ///
    /// # Arguments
    /// * `input` - Input tensor of shape [batch_size, seq_len, input_dim]
    ///
    /// # Returns
    /// Output tensor of shape [batch_size, seq_len, output_dim]
    pub fn forward(&self, input: &Tensor) -> CoreResult<Tensor> {
        // Embed input: [batch, seq, input_dim] -> [batch, seq, hidden_dim]
        // Reshape input to [batch * seq, input_dim] for matmul, then reshape back
        let batch_size = input
            .dim(0)
            .map_err(|e| CoreError::Generic(format!("Failed to get batch dimension: {}", e)))?;
        let seq_len = input
            .dim(1)
            .map_err(|e| CoreError::Generic(format!("Failed to get sequence dimension: {}", e)))?;
        let input_dim = input
            .dim(2)
            .map_err(|e| CoreError::Generic(format!("Failed to get input dimension: {}", e)))?;

        let x_flat = input
            .reshape((batch_size * seq_len, input_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape input: {}", e)))?;

        let hidden_dim = self.config.get_hidden_dim();
        let x_embedded = x_flat
            .matmul(self.embedding_weight.as_tensor())
            .map_err(|e| CoreError::Generic(format!("Embedding forward failed: {}", e)))?;

        let x = x_embedded
            .reshape((batch_size, seq_len, hidden_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape embedded: {}", e)))?;

        // Initialize hidden state
        let state_dim = self.config.get_state_dim();

        let mut h = Tensor::zeros(
            (batch_size, hidden_dim, state_dim),
            self.dtype,
            &self.device,
        )
        .map_err(|e| CoreError::Generic(format!("Failed to create hidden state: {}", e)))?;

        let mut x = x;

        // Process through each layer
        for layer_idx in 0..self.config.get_num_layers() {
            x = self.layer_norm(&x, layer_idx)?;
            x = self.ssm_layer(&x, &mut h, layer_idx)?;
        }

        // Project to output dimension: [batch, seq, hidden_dim] -> [batch, seq, output_dim]
        // Reshape to [batch * seq, hidden_dim], matmul, then reshape back
        let x_flat = x
            .reshape((batch_size * seq_len, hidden_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape for output: {}", e)))?;

        let output_dim = self.config.get_output_dim();
        let output_flat = x_flat
            .matmul(self.output_proj.as_tensor())
            .map_err(|e| CoreError::Generic(format!("Output projection failed: {}", e)))?;

        let output = output_flat
            .reshape((batch_size, seq_len, output_dim))
            .map_err(|e| CoreError::Generic(format!("Failed to reshape output: {}", e)))?;

        Ok(output)
    }

    /// Apply layer normalization
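    ///
    /// Normalizes over the last (hidden) dimension:
    /// `y = gamma * (x - mean) / sqrt(var + eps) + beta`, with `eps = 1e-5`.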
    fn layer_norm(&self, x: &Tensor, layer_idx: usize) -> CoreResult<Tensor> {
        const EPS: f64 = 1e-5;

        // Compute mean and variance along the last dimension
        let mean = x
            .mean_keepdim(candle_core::D::Minus1)
            .map_err(|e| CoreError::Generic(format!("Layer norm mean failed: {}", e)))?;
        let x_centered = x
            .broadcast_sub(&mean)
            .map_err(|e| CoreError::Generic(format!("Layer norm centering failed: {}", e)))?;
        let variance = x_centered
            .sqr()
            .map_err(|e| CoreError::Generic(format!("Layer norm variance sqr failed: {}", e)))?
            .mean_keepdim(candle_core::D::Minus1)
            .map_err(|e| CoreError::Generic(format!("Layer norm variance mean failed: {}", e)))?;

        // Normalize: (x - mean) / sqrt(variance + eps)
        let std = (variance.affine(1.0, EPS))
            .map_err(|e| CoreError::Generic(format!("Layer norm variance add eps failed: {}", e)))?
            .sqrt()
            .map_err(|e| CoreError::Generic(format!("Layer norm sqrt failed: {}", e)))?;

        let normalized = x_centered
            .broadcast_div(&std)
            .map_err(|e| CoreError::Generic(format!("Layer norm division failed: {}", e)))?;

        // Apply affine transformation
        let gamma = self.ln_gamma[layer_idx].as_tensor();
        let beta = self.ln_beta[layer_idx].as_tensor();

        normalized
            .broadcast_mul(gamma)
            .map_err(|e| CoreError::Generic(format!("Layer norm gamma mul failed: {}", e)))?
            .broadcast_add(beta)
            .map_err(|e| CoreError::Generic(format!("Layer norm beta add failed: {}", e)))
    }

    /// SSM layer computation
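    ///
    /// Currently applies only the `D` skip connection (`y = D * x`). The full
    /// selective scan, once implemented, would evolve a per-step state, roughly
    /// `h_t = A_bar * h_{t-1} + B_bar * x_t` and `y_t = C * h_t + D * x_t`, with
    /// the discretization step and `B`/`C` derived from the input (the selective part).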
    fn ssm_layer(&self, x: &Tensor, _h: &mut Tensor, layer_idx: usize) -> CoreResult<Tensor> {
        let _a = self.a_matrices[layer_idx].as_tensor();
        let _b = self.b_matrices[layer_idx].as_tensor();
        let _c = self.c_matrices[layer_idx].as_tensor();
        let d = self.d_vectors[layer_idx].as_tensor();

        // Simplified SSM step (full implementation would include selective scan)
        // For now, implementing a basic skip connection
        // TODO: Implement proper selective scan mechanism with state evolution

        // For training, we process the entire sequence in parallel (teacher forcing)
        // Output: y = D * x (simplified - full version uses state)
        let y = x
            .broadcast_mul(d)
            .map_err(|e| CoreError::Generic(format!("Skip connection failed: {}", e)))?;

        Ok(y)
    }

    /// Create an optimizer for this model
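    ///
    /// The AdamW optimizer is built from the hyperparameters in [`TrainingConfig`]
    /// and updates every Var registered in the model's VarMap.
    ///
    /// ```rust,ignore
    /// let mut optimizer = model.create_optimizer()?;
    /// optimizer.backward_step(&loss)?;
    /// ```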
    pub fn create_optimizer(&self) -> CoreResult<AdamW> {
        let params = ParamsAdamW {
            lr: self.training_config.learning_rate,
            beta1: self.training_config.beta1,
            beta2: self.training_config.beta2,
            eps: self.training_config.eps,
            weight_decay: self.training_config.weight_decay,
        };

        AdamW::new(self.varmap.all_vars(), params)
            .map_err(|e| CoreError::Generic(format!("Failed to create optimizer: {}", e)))
    }

    /// Get the variable map for loading/saving weights
    pub fn varmap(&self) -> &VarMap {
        &self.varmap
    }

    /// Get device
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Get dtype
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Save model weights to a safetensors file
    ///
    /// # Arguments
    /// * `path` - Path to save the safetensors file
    ///
    /// # Example
    /// ```rust,ignore
    /// model.save_weights("model.safetensors")?;
    /// ```
    pub fn save_weights<P: AsRef<std::path::Path>>(&self, path: P) -> CoreResult<()> {
        self.varmap
            .save(path)
            .map_err(|e| CoreError::Generic(format!("Failed to save weights: {}", e)))
    }

    /// Load model weights from a safetensors file
    ///
    /// # Arguments
    /// * `path` - Path to the safetensors file
    ///
    /// # Example
    /// ```rust,ignore
    /// model.load_weights("model.safetensors")?;
    /// ```
    pub fn load_weights<P: AsRef<std::path::Path>>(&mut self, path: P) -> CoreResult<()> {
        self.varmap
            .load(path)
            .map_err(|e| CoreError::Generic(format!("Failed to load weights: {}", e)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Tensor;

    #[test]
    fn test_trainable_ssm_creation() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();

        let model = TrainableSSM::new(config, training_config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_forward_pass() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let training_config = TrainingConfig::default();

        let model = TrainableSSM::new(config, training_config).unwrap();
        let device = model.device().clone();

        // Create dummy input: [batch=2, seq=10, input_dim=3]
        let input = Tensor::randn(0f32, 1.0, (2, 10, 3), &device).unwrap();

        let output = model.forward(&input);
        if let Err(e) = &output {
            panic!("Forward pass failed: {:?}", e);
        }

        let output = output.unwrap();
        assert_eq!(output.dims(), &[2, 10, 3]);
    }

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.learning_rate, 1e-4);
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.epochs, 10);
        assert!(config.track_metrics);
        assert_eq!(config.validation_split, 0.2);
        assert_eq!(config.early_stopping_patience, Some(5));
    }

    #[test]
    fn test_training_config_with_scheduler() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Cosine {
            warmup_steps: 100,
            min_lr: 1e-6,
        });

        assert!(config.scheduler.is_some());
        if let Some(SchedulerType::Cosine {
            warmup_steps,
            min_lr,
        }) = config.scheduler
        {
            assert_eq!(warmup_steps, 100);
            assert_eq!(min_lr, 1e-6);
        } else {
            panic!("Expected Cosine scheduler");
        }
    }

    #[test]
    fn test_training_config_builder() {
        let config = TrainingConfig::default()
            .with_validation_split(0.15)
            .with_early_stopping(10)
            .without_metrics();

        assert_eq!(config.validation_split, 0.15);
        assert_eq!(config.early_stopping_patience, Some(10));
        assert!(!config.track_metrics);
    }

    #[test]
    fn test_scheduler_type_constant() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Constant);

        assert!(config.scheduler.is_some());
    }

    #[test]
    fn test_scheduler_type_step() {
        let config = TrainingConfig::default().with_scheduler(SchedulerType::Step {
            milestones: vec![100, 200, 300],
            decay_factor: 0.1,
        });

        if let Some(SchedulerType::Step {
            milestones,
            decay_factor,
        }) = config.scheduler
        {
            assert_eq!(milestones, vec![100, 200, 300]);
            assert_eq!(decay_factor, 0.1);
        } else {
            panic!("Expected Step scheduler");
        }
    }

    #[test]
    fn test_scheduler_type_onecycle() {
        let config =
            TrainingConfig::default().with_scheduler(SchedulerType::OneCycle { warmup_pct: 0.3 });

        if let Some(SchedulerType::OneCycle { warmup_pct }) = config.scheduler {
            assert_eq!(warmup_pct, 0.3);
        } else {
            panic!("Expected OneCycle scheduler");
        }
    }
}