Skip to main content

burn_optim/lr_scheduler/
base.rs

1pub(super) use alloc::string::String;
2use burn_core as burn;
3
4use burn::record::Record;
5use burn::tensor::backend::Backend;
6
7use crate::LearningRate;
8
9/// Learning rate scheduler defines how the learning rate will evolve during training.
10pub trait LrScheduler: Clone + Send + Sync {
11    /// Scheduler associative type to be used when saving and loading the state.
12    type Record<B: Backend>: Record<B> + Clone + 'static;
13
14    /// Perform the scheduler step, potentially updating its state, and returning the effective
15    /// learning rate.
16    fn step(&mut self) -> LearningRate;
17
18    /// Get the current state of the scheduler as a [record](Record).
19    fn to_record<B: Backend>(&self) -> Self::Record<B>;
20
21    /// Load the state of the scheduler as a [record](Record).
22    fn load_record<B: Backend>(self, record: Self::Record<B>) -> Self;
23}
24
#[cfg(test)]
pub(super) mod test_utils {
    use super::*;
    use crate::TestBackend;

    // A small tolerance for learning rate comparisons. Depending on how learning rates are
    // computed, floating-point arithmetic error might exceed f64::EPSILON, so a larger value is
    // used here.
    const LOOSE_EPSILON: LearningRate = 1e-10;

    /// Steps `scheduler` once per item of `expected_lrs` and asserts that every produced learning
    /// rate is within `LOOSE_EPSILON` of the corresponding expected value.
    ///
    /// # Panics
    ///
    /// Panics at the first step whose learning rate is not approximately equal to the expected
    /// one.
    pub fn check_lr_sequence<I, S>(mut scheduler: S, expected_lrs: I)
    where
        I: IntoIterator<Item = LearningRate>,
        S: LrScheduler,
    {
        for (i, expected) in expected_lrs.into_iter().enumerate() {
            let lr = scheduler.step();
            assert!(
                (lr - expected).abs() < LOOSE_EPSILON,
                "Scheduled learning rate {lr} is not approximately equal to the expected value \
                 {expected} at step {i}",
            );
        }
    }

    /// Checks that the scheduler state survives a save/load round trip.
    ///
    /// The scheduler is stepped `save_at_step` times and its state is saved as a record. The
    /// record is then loaded into an *unstepped* copy of the original scheduler. That restored
    /// scheduler must produce the same learning rates as a reference scheduler that was never
    /// saved, for another `save_at_step` steps. When `save_at_step` is 0, no steps are compared.
    ///
    /// # Panics
    ///
    /// Panics if the restored scheduler diverges from the reference scheduler.
    pub fn check_save_load<S>(mut scheduler: S, save_at_step: usize)
    where
        S: LrScheduler,
    {
        // Loading into a copy that has NOT been stepped (rather than back into the scheduler
        // that produced the record) makes the check fail when `load_record` ignores some or all
        // of the record. The previous approach could never detect that, because the target
        // already held the stepped state.
        let unstepped = scheduler.clone();
        let mut truth = scheduler.clone();

        // Consume some steps before saving.
        for _ in 0..save_at_step {
            truth.step();
            scheduler.step();
        }
        let record = scheduler.to_record::<TestBackend>();
        let mut restored = unstepped.load_record::<TestBackend>(record);

        // Validate that the restored scheduler resumes from where the saved one left off.
        compare_steps(&mut restored, &mut truth, save_at_step);
    }

    /// Steps `a` and `b` in lockstep `num_steps` times and asserts that both produce
    /// approximately equal learning rates (within `LOOSE_EPSILON`) at every position.
    ///
    /// # Panics
    ///
    /// Panics at the first position where the two learning rates differ.
    pub fn compare_steps<S: LrScheduler>(a: &mut S, b: &mut S, num_steps: usize) {
        for i in 0..num_steps {
            let lr_a = a.step();
            let lr_b = b.step();
            assert!(
                (lr_a - lr_b).abs() < LOOSE_EPSILON,
                "The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \
                 sequences are not approximately equal",
            );
        }
    }
}