burn_optim/lr_scheduler/
linear.rs1use burn_core as burn;
2
3use super::{LrScheduler, String};
4use crate::LearningRate;
5use burn::config::Config;
6use burn::tensor::backend::Backend;
7
8#[derive(Config, Debug)]
15pub struct LinearLrSchedulerConfig {
16 initial_lr: LearningRate,
18 final_lr: LearningRate,
20 num_iters: usize,
22}
23
24impl LinearLrSchedulerConfig {
25 pub fn init(&self) -> Result<LinearLrScheduler, String> {
35 if self.initial_lr <= 0. || self.initial_lr > 1. {
36 return Err("Initial learning rate must be greater than 0 and at most 1".into());
37 }
38 if self.final_lr < 0. || self.final_lr > 1. {
39 return Err("Final learning rate must be at least 0 and at most 1".into());
40 }
41 if self.num_iters == 0 {
42 return Err("Number of iterations must be at least 1".into());
43 }
44
45 Ok(LinearLrScheduler {
46 final_lr: self.final_lr,
47 step_size: (self.final_lr - self.initial_lr) / self.num_iters as f64,
48 remaining_iters: self.num_iters + 1,
49 })
50 }
51}
52
53#[derive(Clone, Copy, Debug)]
57pub struct LinearLrScheduler {
58 final_lr: LearningRate,
60 step_size: f64,
62 remaining_iters: usize,
64}
65
66impl LrScheduler for LinearLrScheduler {
67 type Record<B: Backend> = usize;
68
69 fn step(&mut self) -> LearningRate {
70 self.remaining_iters -= (self.remaining_iters != 0) as usize;
71 self.final_lr - self.step_size * self.remaining_iters as f64
72 }
73
74 fn to_record<B: Backend>(&self) -> Self::Record<B> {
75 self.remaining_iters
76 }
77
78 fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
79 self.remaining_iters = record;
80 self
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    /// Initializes `config` and checks that it is rejected with exactly `expected_msg`.
    fn assert_init_error(config: LinearLrSchedulerConfig, expected_msg: &str) {
        let result = config.init();
        assert!(result.is_err(), "Should return an error");
        assert_eq!(
            result.unwrap_err(),
            expected_msg,
            "Error messages should match",
        );
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_init_error(
            LinearLrSchedulerConfig::new(0., 0.5, 100),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_init_error(
            LinearLrSchedulerConfig::new(1.5, 0.5, 100),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_final_lr_too_low() {
        assert_init_error(
            LinearLrSchedulerConfig::new(0.5, -0.5, 100),
            "Final learning rate must be at least 0 and at most 1",
        );
    }

    #[test]
    fn config_final_lr_too_high() {
        assert_init_error(
            LinearLrSchedulerConfig::new(0.5, 1.5, 100),
            "Final learning rate must be at least 0 and at most 1",
        );
    }

    #[test]
    fn config_num_iters_too_low() {
        assert_init_error(
            LinearLrSchedulerConfig::new(0.9, 0.1, 0),
            "Number of iterations must be at least 1",
        );
    }

    #[test]
    fn test_lr_decreasing() {
        let expected_lrs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.5];
        test_utils::check_lr_sequence(
            LinearLrSchedulerConfig::new(0.9, 0.5, 4).init().unwrap(),
            expected_lrs,
        );
    }

    #[test]
    fn test_lr_increasing() {
        let expected_lrs = [0.01, 0.02, 0.03, 0.04, 0.04];
        test_utils::check_lr_sequence(
            LinearLrSchedulerConfig::new(0.01, 0.04, 3).init().unwrap(),
            expected_lrs,
        );
    }

    #[test]
    fn test_lr_unchanging() {
        let expected_lrs = [0.3, 0.3, 0.3, 0.3];
        test_utils::check_lr_sequence(
            LinearLrSchedulerConfig::new(0.3, 0.3, 2).init().unwrap(),
            expected_lrs,
        );
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 6;
        let checkpoint_iter = NUM_ITERS / 3 * 2;
        let scheduler = LinearLrSchedulerConfig::new(1.0, 0.01, NUM_ITERS)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, checkpoint_iter);
    }
}