burn_optim/lr_scheduler/exponential.rs

use burn_core as burn;

use super::{LrScheduler, String};
use crate::LearningRate;
use burn::config::Config;
use burn::tensor::backend::Backend;

/// The configuration for creating an [exponential learning rate scheduler](ExponentialLrScheduler).
///
/// The scheduler multiplies the learning rate by `gamma` on every step, starting from
/// `initial_lr`, so the learning rate at step `i` (zero-based) is `initial_lr * gamma^i`.
#[derive(Config, Debug)]
pub struct ExponentialLrSchedulerConfig {
    /// The initial learning rate.
    initial_lr: LearningRate,
    /// The constant factor the learning rate is multiplied by on each step.
    gamma: f64,
}

impl ExponentialLrSchedulerConfig {
    /// Initializes an [exponential learning rate scheduler](ExponentialLrScheduler).
    ///
    /// # Errors
    ///
    /// Returns an error if `initial_lr` or `gamma` is out of range
    /// (each must be greater than 0 and at most 1).
    pub fn init(&self) -> Result<ExponentialLrScheduler, String> {
        if self.initial_lr <= 0. || self.initial_lr > 1. {
            return Err("Initial learning rate must be greater than 0 and at most 1".into());
        }
        if self.gamma <= 0. || self.gamma > 1. {
            return Err("Gamma must be greater than 0 and at most 1".into());
        }

        Ok(ExponentialLrScheduler {
            // Pre-divide by `gamma` so the first call to `step` undoes the division and
            // returns `initial_lr`, without special-casing the first step.
            previous_lr: self.initial_lr / self.gamma,
            gamma: self.gamma,
        })
    }
}

/// An exponential learning rate scheduler, constructed via
/// [`ExponentialLrSchedulerConfig::init`].
#[derive(Clone, Copy, Debug)]
pub struct ExponentialLrScheduler {
    /// The learning rate returned by the previous call to `step`.
    previous_lr: LearningRate,
    /// The constant factor the learning rate is multiplied by on each step.
    gamma: f64,
}

impl LrScheduler for ExponentialLrScheduler {
    type Record<B: Backend> = LearningRate;

    fn step(&mut self) -> LearningRate {
        // Decay the previous learning rate by the constant factor and return it.
        self.previous_lr *= self.gamma;
        self.previous_lr
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        self.previous_lr
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.previous_lr = record;
        self
    }
}

#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    #[test]
    fn config_initial_lr_too_low() {
        let r = ExponentialLrSchedulerConfig::new(0., 0.5).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Initial learning rate must be greater than 0 and at most 1",
            "Error messages should match",
        );
    }

    #[test]
    fn config_initial_lr_too_high() {
        let r = ExponentialLrSchedulerConfig::new(1.5, 0.5).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Initial learning rate must be greater than 0 and at most 1",
            "Error messages should match",
        );
    }

    #[test]
    fn config_gamma_too_low() {
        let r = ExponentialLrSchedulerConfig::new(0.5, 0.0).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Gamma must be greater than 0 and at most 1",
            "Error messages should match",
        );
    }

    #[test]
    fn config_gamma_too_high() {
        let r = ExponentialLrSchedulerConfig::new(0.5, 1.5).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Gamma must be greater than 0 and at most 1",
            "Error messages should match",
        );
    }
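
    // A hedged addition (not in the original suite): the upper bounds in `init` are
    // inclusive, so `initial_lr == 1.0` and `gamma == 1.0` both pass validation.
    #[test]
    fn config_boundary_values_accepted() {
        assert!(ExponentialLrSchedulerConfig::new(1.0, 1.0).init().is_ok());
    }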

    #[test]
    fn test_lr_change() {
        let scheduler = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap();
        let expected_lrs = [0.8, 0.08, 0.008, 0.0008, 0.00008];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }
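
    // A usage sketch added here (not in the original suite): because `init` pre-divides
    // by `gamma`, the very first `step` call returns `initial_lr`, and decay only
    // applies from the second call onward. Compared with a tolerance since the
    // divide-then-multiply round trip may not be bit-exact.
    #[test]
    fn test_first_step_returns_initial_lr() {
        let mut scheduler = ExponentialLrSchedulerConfig::new(0.5, 0.9).init().unwrap();
        assert!((scheduler.step() - 0.5).abs() < 1e-12);
        assert!((scheduler.step() - 0.45).abs() < 1e-12);
    }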

    #[test]
    fn test_save_and_load() {
        let scheduler = ExponentialLrSchedulerConfig::new(0.083, 0.3)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, 7);
    }
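
    // A minimal sketch (not in the original suite) of the record API itself:
    // `to_record` captures the current learning rate and `load_record` restores it,
    // so a restored scheduler resumes the same decay sequence. Written generically
    // over the backend because this file never names a concrete one.
    #[allow(dead_code)]
    fn record_round_trip_sketch<B: Backend>() {
        let mut original = ExponentialLrSchedulerConfig::new(0.4, 0.5).init().unwrap();
        original.step(); // 0.4
        original.step(); // 0.2

        let mut restored = ExponentialLrSchedulerConfig::new(0.4, 0.5)
            .init()
            .unwrap()
            .load_record::<B>(original.to_record::<B>());
        // Both schedulers now produce the same next value (0.1).
        assert_eq!(restored.step(), original.step());
    }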
}