burn_core/lr_scheduler/base.rs

pub(super) use alloc::string::String;

use burn_tensor::backend::Backend;

use crate::{LearningRate, record::Record};

/// A learning rate scheduler defines how the learning rate evolves during training.
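///
/// # Example
///
/// A minimal sketch of an implementor (the `ConstantLr` type below is illustrative and
/// not part of this module; it assumes `()` implements [`Record`], as it is stateless):
///
/// ```ignore
/// struct ConstantLr(LearningRate);
///
/// impl LrScheduler for ConstantLr {
///     // Stateless, so there is nothing to save or load.
///     type Record<B: Backend> = ();
///
///     fn step(&mut self) -> LearningRate {
///         self.0
///     }
///
///     fn to_record<B: Backend>(&self) -> Self::Record<B> {}
///
///     fn load_record<B: Backend>(self, _record: Self::Record<B>) -> Self {
///         self
///     }
/// }
/// ```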
pub trait LrScheduler: Send + Sync {
    /// Associated type used when saving and loading the scheduler state.
    type Record<B: Backend>: Record<B>;

    /// Perform a scheduler step, potentially updating internal state, and return the effective
    /// learning rate.
    fn step(&mut self) -> LearningRate;

    /// Get the current state of the scheduler as a [record](Record).
    fn to_record<B: Backend>(&self) -> Self::Record<B>;

    /// Load the state of the scheduler from a [record](Record).
    fn load_record<B: Backend>(self, record: Self::Record<B>) -> Self;
}

#[cfg(test)]
pub(super) mod test_utils {
    use super::*;
    use crate::TestBackend;

    // A small tolerance for learning rate comparisons. Depending on how learning rates are
    // computed, floating-point arithmetic error might exceed f64::EPSILON, so a larger value is
    // used here.
    const LOOSE_EPSILON: LearningRate = 1e-10;

    // Assert that a scheduler produces the expected sequence of learning rates.
    pub fn check_lr_sequence<I, S>(mut scheduler: S, expected_lrs: I)
    where
        I: IntoIterator<Item = LearningRate>,
        S: LrScheduler,
    {
        expected_lrs
            .into_iter()
            .enumerate()
            .for_each(|(i, expected)| {
                let lr = scheduler.step();
                assert!(
                    (lr - expected).abs() < LOOSE_EPSILON,
                    "Scheduled learning rate {lr} is not approximately equal to the expected value \
                     {expected} at step {i}",
                );
            });
    }

    // Check that a scheduler resumes correctly after saving and loading back its state.
    // save_at_step is the number of steps to run the scheduler before saving.
    pub fn check_save_load<S>(mut scheduler: S, save_at_step: usize)
    where
        S: Clone + LrScheduler,
    {
        let mut truth = scheduler.clone();
        // Consume some steps before saving and loading back.
        (0..save_at_step).for_each(|_| {
            truth.step();
            scheduler.step();
        });
        let rec = scheduler.to_record::<TestBackend>();
        scheduler = scheduler.load_record::<TestBackend>(rec);

        // Validate that the scheduler resumes from where it left off.
        compare_steps(&mut scheduler, &mut truth, save_at_step);
    }

    // Check that two schedulers produce the same learning rate sequence over the specified
    // number of steps.
    pub fn compare_steps<S: LrScheduler>(a: &mut S, b: &mut S, num_steps: usize) {
        (0..num_steps).for_each(|i| {
            let lr_a = a.step();
            let lr_b = b.step();
            assert!(
                (lr_a - lr_b).abs() < LOOSE_EPSILON,
                "The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \
                 sequences are not approximately equal",
            );
        });
    }
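
    // Illustrative sketch, not part of the upstream test suite: a tiny halving scheduler
    // showing how the helpers above are exercised. The HalvingScheduler type and the test
    // below are assumptions for demonstration; the record is a plain LearningRate, which
    // assumes the crate implements Record for primitive float types.
    #[derive(Clone)]
    struct HalvingScheduler {
        lr: LearningRate,
    }

    impl LrScheduler for HalvingScheduler {
        type Record<B: Backend> = LearningRate;

        fn step(&mut self) -> LearningRate {
            let current = self.lr;
            self.lr /= 2.0;
            current
        }

        fn to_record<B: Backend>(&self) -> Self::Record<B> {
            self.lr
        }

        fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
            self.lr = record;
            self
        }
    }

    #[test]
    fn halving_scheduler_demo() {
        // Halving 0.8 only shifts the floating-point exponent, so the comparison is exact.
        check_lr_sequence(HalvingScheduler { lr: 0.8 }, [0.8, 0.4, 0.2]);
        check_save_load(HalvingScheduler { lr: 0.8 }, 2);
    }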
}