burn_core/lr_scheduler/
base.rs1pub(super) use alloc::string::String;
2
3use burn_tensor::backend::Backend;
4
5use crate::{LearningRate, record::Record};
6
7pub trait LrScheduler: Send + Sync {
9 type Record<B: Backend>: Record<B>;
11
12 fn step(&mut self) -> LearningRate;
15
16 fn to_record<B: Backend>(&self) -> Self::Record<B>;
18
19 fn load_record<B: Backend>(self, record: Self::Record<B>) -> Self;
21}
22
#[cfg(test)]
pub(super) mod test_utils {
    //! Shared assertions for exercising [`LrScheduler`] implementations in unit tests.

    use super::*;
    use crate::TestBackend;

    /// Absolute (not relative) tolerance used when comparing two learning rates.
    ///
    /// Scheduled values come from floating-point arithmetic, so exact equality is too strict.
    const LOOSE_EPSILON: LearningRate = 1e-10;

    /// Steps `scheduler` once per item of `expected_lrs` and asserts that each returned
    /// learning rate matches the corresponding expected value within `LOOSE_EPSILON`.
    ///
    /// # Panics
    ///
    /// Panics at the first step whose learning rate differs from the expected value by
    /// `LOOSE_EPSILON` or more. A NaN on either side also fails, because comparisons with
    /// NaN are always false.
    pub fn check_lr_sequence<I, S>(mut scheduler: S, expected_lrs: I)
    where
        I: IntoIterator<Item = LearningRate>,
        S: LrScheduler,
    {
        for (i, expected) in expected_lrs.into_iter().enumerate() {
            let lr = scheduler.step();
            assert!(
                (lr - expected).abs() < LOOSE_EPSILON,
                "Scheduled learning rate {lr} is not approximately equal to the expected value \
                 {expected} at step {i}",
            );
        }
    }

    /// Checks that a scheduler's state survives a `to_record` / `load_record` round trip.
    ///
    /// The procedure has three stages:
    ///
    /// 1. A reference copy (`truth`) and `scheduler` itself are both stepped `save_at_step`
    ///    times.
    /// 2. The stepped scheduler's record is loaded into a fresh clone of the *initial*
    ///    (never-stepped) scheduler. This mirrors resuming from a saved checkpoint.
    /// 3. The restored scheduler and `truth` are stepped `save_at_step` more times and
    ///    compared position by position.
    ///
    /// The record is loaded into an un-stepped clone rather than back into the scheduler
    /// that produced it. The producing scheduler already holds the saved state, so an
    /// implementation whose `load_record` ignored the record would still pass.
    ///
    /// # Panics
    ///
    /// Panics if any pair of learning rates produced after the round trip differs by
    /// `LOOSE_EPSILON` or more.
    pub fn check_save_load<S>(mut scheduler: S, save_at_step: usize)
    where
        S: Clone + LrScheduler,
    {
        // Initial-state snapshot that the saved record will be loaded into.
        let fresh = scheduler.clone();
        let mut truth = scheduler.clone();

        // Consume some steps before saving.
        for _ in 0..save_at_step {
            truth.step();
            scheduler.step();
        }

        let record = scheduler.to_record::<TestBackend>();
        let mut restored = fresh.load_record::<TestBackend>(record);

        // The restored scheduler must resume exactly where the saved one left off.
        compare_steps(&mut restored, &mut truth, save_at_step);
    }

    /// Steps `a` and `b` in lockstep `num_steps` times and asserts that they produce
    /// approximately equal learning rates (within `LOOSE_EPSILON`) at every position.
    ///
    /// # Panics
    ///
    /// Panics at the first position where the two learning rates differ by
    /// `LOOSE_EPSILON` or more.
    pub fn compare_steps<S: LrScheduler>(a: &mut S, b: &mut S, num_steps: usize) {
        for i in 0..num_steps {
            let lr_a = a.step();
            let lr_b = b.step();
            assert!(
                (lr_a - lr_b).abs() < LOOSE_EPSILON,
                "The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \
                 sequences are not approximately equal",
            );
        }
    }
}