burn_core/lr_scheduler/
linear.rs1use super::{LrScheduler, String};
2use crate as burn;
3use crate::{LearningRate, config::Config};
4use burn_tensor::backend::Backend;
5
6#[derive(Config)]
13pub struct LinearLrSchedulerConfig {
14 initial_lr: LearningRate,
16 final_lr: LearningRate,
18 num_iters: usize,
20}
21
22impl LinearLrSchedulerConfig {
23 pub fn init(&self) -> Result<LinearLrScheduler, String> {
33 if self.initial_lr <= 0. || self.initial_lr > 1. {
34 return Err("Initial learning rate must be greater than 0 and at most 1".into());
35 }
36 if self.final_lr < 0. || self.final_lr > 1. {
37 return Err("Final learning rate must be at least 0 and at most 1".into());
38 }
39 if self.num_iters == 0 {
40 return Err("Number of iterations must be at least 1".into());
41 }
42
43 Ok(LinearLrScheduler {
44 final_lr: self.final_lr,
45 step_size: (self.final_lr - self.initial_lr) / self.num_iters as f64,
46 remaining_iters: self.num_iters + 1,
47 })
48 }
49}
50
51#[derive(Clone, Copy, Debug)]
55pub struct LinearLrScheduler {
56 final_lr: LearningRate,
58 step_size: f64,
60 remaining_iters: usize,
62}
63
64impl LrScheduler for LinearLrScheduler {
65 type Record<B: Backend> = usize;
66
67 fn step(&mut self) -> LearningRate {
68 self.remaining_iters -= (self.remaining_iters != 0) as usize;
69 self.final_lr - self.step_size * self.remaining_iters as f64
70 }
71
72 fn to_record<B: Backend>(&self) -> Self::Record<B> {
73 self.remaining_iters
74 }
75
76 fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
77 self.remaining_iters = record;
78 self
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    const INITIAL_LR_MSG: &str = "Initial learning rate must be greater than 0 and at most 1";
    const FINAL_LR_MSG: &str = "Final learning rate must be at least 0 and at most 1";
    const NUM_ITERS_MSG: &str = "Number of iterations must be at least 1";

    /// Checks that `init` rejects `config` and reports exactly `expected_msg`.
    fn assert_init_rejects(config: LinearLrSchedulerConfig, expected_msg: &str) {
        let outcome = config.init();
        assert!(outcome.is_err(), "Should return an error");
        assert_eq!(
            outcome.unwrap_err(),
            expected_msg,
            "Error messages should match",
        );
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_init_rejects(LinearLrSchedulerConfig::new(0., 0.5, 100), INITIAL_LR_MSG);
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_init_rejects(LinearLrSchedulerConfig::new(1.5, 0.5, 100), INITIAL_LR_MSG);
    }

    #[test]
    fn config_final_lr_too_low() {
        assert_init_rejects(LinearLrSchedulerConfig::new(0.5, -0.5, 100), FINAL_LR_MSG);
    }

    #[test]
    fn config_final_lr_too_high() {
        assert_init_rejects(LinearLrSchedulerConfig::new(0.5, 1.5, 100), FINAL_LR_MSG);
    }

    #[test]
    fn config_num_iters_too_low() {
        assert_init_rejects(LinearLrSchedulerConfig::new(0.9, 0.1, 0), NUM_ITERS_MSG);
    }

    #[test]
    fn test_lr_decreasing() {
        let expected_lrs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.5];
        let scheduler = LinearLrSchedulerConfig::new(0.9, 0.5, 4).init().unwrap();
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_lr_increasing() {
        let expected_lrs = [0.01, 0.02, 0.03, 0.04, 0.04];
        let scheduler = LinearLrSchedulerConfig::new(0.01, 0.04, 3).init().unwrap();
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_lr_unchanging() {
        let expected_lrs = [0.3, 0.3, 0.3, 0.3];
        let scheduler = LinearLrSchedulerConfig::new(0.3, 0.3, 2).init().unwrap();
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 6;
        // Save partway through the schedule (after two thirds of the iterations).
        let save_at = NUM_ITERS / 3 * 2;
        let scheduler = LinearLrSchedulerConfig::new(1.0, 0.01, NUM_ITERS)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, save_at);
    }
}