burn_optim/lr_scheduler/
base.rs1pub(super) use alloc::string::String;
2use burn_core as burn;
3
4use burn::record::Record;
5use burn::tensor::backend::Backend;
6
7use crate::LearningRate;
8
/// Definition of a learning-rate scheduler: a stateful policy that produces
/// the learning rate to use at each training step and that can be
/// checkpointed through the record system.
pub trait LrScheduler: Clone + Send + Sync {
    /// Scheduler state as a serializable record, generic over the backend
    /// used to (de)serialize it. `'static` so records can be stored/moved
    /// independently of the scheduler's lifetime.
    type Record<B: Backend>: Record<B> + Clone + 'static;

    /// Advances the scheduler by one step and returns the learning rate to
    /// apply for that step. Note: mutates internal state.
    fn step(&mut self) -> LearningRate;

    /// Snapshots the current scheduler state into a record (for saving).
    fn to_record<B: Backend>(&self) -> Self::Record<B>;

    /// Consumes `self` and returns a scheduler with its state restored from
    /// `record` (for loading a checkpoint).
    fn load_record<B: Backend>(self, record: Self::Record<B>) -> Self;
}
24
#[cfg(test)]
pub(super) mod test_utils {
    use super::*;
    use crate::TestBackend;

    /// Tolerance used for approximate floating-point comparison of
    /// learning rates in these helpers.
    const LOOSE_EPSILON: LearningRate = 1e-10;

    /// Asserts that stepping `scheduler` yields each value of `expected_lrs`
    /// in order, within `LOOSE_EPSILON`.
    ///
    /// Panics with a descriptive message at the first step whose learning
    /// rate deviates from the expected value.
    pub fn check_lr_sequence<I, S>(mut scheduler: S, expected_lrs: I)
    where
        I: IntoIterator<Item = LearningRate>,
        S: LrScheduler,
    {
        for (i, expected) in expected_lrs.into_iter().enumerate() {
            let lr = scheduler.step();
            assert!(
                (lr - expected).abs() < LOOSE_EPSILON,
                "Scheduled learning rate {lr} is not approximately equal to the expected value \
                 {expected} at step {i}",
            );
        }
    }

    /// Verifies that saving and loading a scheduler preserves its state:
    /// steps `scheduler` `save_at_step` times, round-trips it through its
    /// record, then checks it keeps producing the same learning rates as an
    /// untouched clone advanced by the same number of steps.
    // NOTE: `LrScheduler` already requires `Clone` as a supertrait, so no
    // extra `Clone` bound is needed here (the original `Clone + LrScheduler`
    // bound was redundant).
    pub fn check_save_load<S>(mut scheduler: S, save_at_step: usize)
    where
        S: LrScheduler,
    {
        let mut truth = scheduler.clone();
        // Advance both schedulers to the checkpoint position.
        for _ in 0..save_at_step {
            truth.step();
            scheduler.step();
        }
        // Round-trip `scheduler` through its record (save then load).
        let rec = scheduler.to_record::<TestBackend>();
        scheduler = scheduler.load_record::<TestBackend>(rec);

        // The restored scheduler must continue exactly like the untouched one.
        compare_steps(&mut scheduler, &mut truth, save_at_step);
    }

    /// Asserts that two schedulers produce approximately equal learning
    /// rates for the next `num_steps` steps.
    pub fn compare_steps<S: LrScheduler>(a: &mut S, b: &mut S, num_steps: usize) {
        for i in 0..num_steps {
            let lr_a = a.step();
            let lr_b = b.step();
            assert!(
                (lr_a - lr_b).abs() < LOOSE_EPSILON,
                "The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \
                 sequences are not approximately equal",
            );
        }
    }
}