1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
//! Learning rate schedule configuration
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{Result, TuneError};
/// Learning rate schedule.
///
/// Each variant is evaluated by [`LRSchedule::get_lr`]. Some variants are driven by the
/// global `step` counter and others by the `epoch` counter; the per-field docs say which.
/// Call [`LRSchedule::validate`] before training to reject parameter values that would
/// make the schedule degenerate (e.g. zero periods).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
pub enum LRSchedule {
    /// Constant learning rate (always `base_lr`).
    Constant,
    /// Linear warmup from 0 to `base_lr`, then constant.
    LinearWarmup {
        /// Number of warmup steps (compared against the `step` counter).
        warmup_steps: usize,
    },
    /// Step decay: multiply by `gamma` once every `step_size` epochs.
    StepDecay {
        /// Epochs between reductions.
        step_size: usize,
        /// Decay factor (e.g., 0.1 = reduce to 10%).
        gamma: f32,
    },
    /// Exponential decay: `base_lr * gamma^epoch`.
    ExponentialDecay {
        /// Decay rate per epoch.
        gamma: f32,
    },
    /// Cosine annealing from `base_lr` down towards `min_lr`, restarting every
    /// `t_max` epochs.
    CosineAnnealing {
        /// Minimum learning rate.
        min_lr: f32,
        /// Period in epochs.
        t_max: usize,
    },
    /// Linear warmup (in steps) followed by cosine annealing (in epochs).
    CosineAnnealingWarmup {
        /// Number of warmup steps (compared against the `step` counter).
        warmup_steps: usize,
        /// Minimum learning rate.
        min_lr: f32,
        /// Period in epochs (after warmup).
        t_max: usize,
    },
    /// One-cycle policy: rise linearly from `base_lr` to `max_lr`, then decay
    /// linearly towards `base_lr / 100`.
    OneCycle {
        /// Maximum (peak) learning rate.
        max_lr: f32,
        /// Fraction of the cycle, in `(0, 1)`, spent increasing the LR
        /// (e.g. `0.3` = first 30% of `total_steps`).
        pct_start: f32,
        /// Total steps in the cycle (compared against the `step` counter).
        total_steps: usize,
    },
}
impl Default for LRSchedule {
fn default() -> Self {
Self::CosineAnnealingWarmup {
warmup_steps: 100,
min_lr: 1e-6,
t_max: 100,
}
}
}
impl LRSchedule {
    /// Validate the schedule parameters to catch division-by-zero and other
    /// invalid configurations before training begins.
    ///
    /// # Errors
    ///
    /// Returns `TuneError::InvalidConfig` naming the first offending parameter when:
    /// - a count used as a divisor (`warmup_steps`, `step_size`, `t_max`,
    ///   `total_steps`) is zero — note that `CosineAnnealingWarmup` also rejects
    ///   `warmup_steps == 0`, even though [`LRSchedule::get_lr`] tolerates it;
    /// - a decay factor `gamma` is non-finite or not strictly positive;
    /// - a `min_lr` is non-finite or negative;
    /// - `OneCycle::max_lr` is non-finite or not strictly positive;
    /// - `OneCycle::pct_start` is non-finite or outside the open interval `(0, 1)`.
    pub fn validate(&self) -> Result<()> {
        match self {
            LRSchedule::Constant => {}
            LRSchedule::LinearWarmup { warmup_steps } => {
                if *warmup_steps == 0 {
                    return Err(TuneError::InvalidConfig(
                        "LinearWarmup: warmup_steps must be > 0".into(),
                    ));
                }
            }
            LRSchedule::StepDecay { step_size, gamma } => {
                if *step_size == 0 {
                    return Err(TuneError::InvalidConfig(
                        "StepDecay: step_size must be > 0".into(),
                    ));
                }
                if !gamma.is_finite() || *gamma <= 0.0 {
                    return Err(TuneError::InvalidConfig(
                        "StepDecay: gamma must be finite and > 0".into(),
                    ));
                }
            }
            LRSchedule::ExponentialDecay { gamma } => {
                if !gamma.is_finite() || *gamma <= 0.0 {
                    return Err(TuneError::InvalidConfig(
                        "ExponentialDecay: gamma must be finite and > 0".into(),
                    ));
                }
            }
            LRSchedule::CosineAnnealing { min_lr, t_max } => {
                if *t_max == 0 {
                    return Err(TuneError::InvalidConfig(
                        "CosineAnnealing: t_max must be > 0".into(),
                    ));
                }
                if !min_lr.is_finite() || *min_lr < 0.0 {
                    return Err(TuneError::InvalidConfig(
                        "CosineAnnealing: min_lr must be finite and >= 0".into(),
                    ));
                }
            }
            LRSchedule::CosineAnnealingWarmup {
                warmup_steps,
                min_lr,
                t_max,
            } => {
                if *t_max == 0 {
                    return Err(TuneError::InvalidConfig(
                        "CosineAnnealingWarmup: t_max must be > 0".into(),
                    ));
                }
                if *warmup_steps == 0 {
                    return Err(TuneError::InvalidConfig(
                        "CosineAnnealingWarmup: warmup_steps must be > 0".into(),
                    ));
                }
                if !min_lr.is_finite() || *min_lr < 0.0 {
                    return Err(TuneError::InvalidConfig(
                        "CosineAnnealingWarmup: min_lr must be finite and >= 0".into(),
                    ));
                }
            }
            LRSchedule::OneCycle {
                max_lr,
                pct_start,
                total_steps,
            } => {
                if *total_steps == 0 {
                    return Err(TuneError::InvalidConfig(
                        "OneCycle: total_steps must be > 0".into(),
                    ));
                }
                if !max_lr.is_finite() || *max_lr <= 0.0 {
                    return Err(TuneError::InvalidConfig(
                        "OneCycle: max_lr must be finite and > 0".into(),
                    ));
                }
                if !pct_start.is_finite() || *pct_start <= 0.0 || *pct_start >= 1.0 {
                    return Err(TuneError::InvalidConfig(
                        "OneCycle: pct_start must be finite and in (0, 1)".into(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Calculate the learning rate at a given point in training.
    ///
    /// * `base_lr` - the optimizer's configured learning rate.
    /// * `step` - global step counter; read by `LinearWarmup`, the warmup phase of
    ///   `CosineAnnealingWarmup`, and `OneCycle`.
    /// * `epoch` - epoch counter; read by `StepDecay`, `ExponentialDecay`, and the
    ///   cosine phases of `CosineAnnealing` / `CosineAnnealingWarmup`.
    ///
    /// This does not panic or divide by zero on unvalidated parameters: a zero
    /// period/step count makes the affected phase fall back to `base_lr`. For
    /// `OneCycle`, steps past `total_steps` hold the final learning rate
    /// (`base_lr / 100`) instead of extrapolating below it.
    pub fn get_lr(&self, base_lr: f32, step: usize, epoch: usize) -> f32 {
        match self {
            LRSchedule::Constant => base_lr,
            LRSchedule::LinearWarmup { warmup_steps } => {
                if *warmup_steps == 0 || step >= *warmup_steps {
                    base_lr
                } else {
                    base_lr * (step as f32 / *warmup_steps as f32)
                }
            }
            LRSchedule::StepDecay { step_size, gamma } => {
                if *step_size == 0 {
                    base_lr
                } else {
                    let num_decays = epoch / step_size;
                    // Saturate instead of wrapping: `as i32` on a huge count would
                    // produce a negative exponent and an exploding LR.
                    base_lr * gamma.powi(num_decays.min(i32::MAX as usize) as i32)
                }
            }
            LRSchedule::ExponentialDecay { gamma } => {
                // Same saturation as StepDecay to avoid i32 wrap-around.
                base_lr * gamma.powi(epoch.min(i32::MAX as usize) as i32)
            }
            LRSchedule::CosineAnnealing { min_lr, t_max } => {
                if *t_max == 0 {
                    return base_lr;
                }
                // `epoch % t_max` restarts the cosine at base_lr every t_max epochs.
                let progress = (epoch % t_max) as f32 / *t_max as f32;
                *min_lr
                    + (base_lr - *min_lr) * (1.0 + (progress * std::f32::consts::PI).cos()) / 2.0
            }
            LRSchedule::CosineAnnealingWarmup {
                warmup_steps,
                min_lr,
                t_max,
            } => {
                if *warmup_steps > 0 && step < *warmup_steps {
                    base_lr * (step as f32 / *warmup_steps as f32)
                } else if *t_max == 0 {
                    base_lr
                } else {
                    // The cosine phase is driven by `epoch` and is not shifted by the
                    // warmup length (warmup is measured in steps, not epochs).
                    let progress = (epoch % t_max) as f32 / *t_max as f32;
                    *min_lr
                        + (base_lr - *min_lr) * (1.0 + (progress * std::f32::consts::PI).cos())
                            / 2.0
                }
            }
            LRSchedule::OneCycle {
                max_lr,
                pct_start,
                total_steps,
            } => {
                if *total_steps == 0 {
                    // Degenerate cycle: avoid 0/0 (NaN) and mirror the zero-period
                    // fallback used by the other variants.
                    return base_lr;
                }
                // Clamp so steps past the end of the cycle hold the final LR instead
                // of extrapolating the linear decay below zero.
                let pct = (step as f32 / *total_steps as f32).min(1.0);
                if pct < *pct_start {
                    // Increasing phase: base_lr -> max_lr.
                    let phase_pct = pct / *pct_start;
                    base_lr + (*max_lr - base_lr) * phase_pct
                } else if *pct_start >= 1.0 {
                    // No decay phase exists (and `1.0 - pct_start` would be <= 0).
                    *max_lr
                } else {
                    // Decreasing phase: max_lr -> base_lr / 100.
                    let phase_pct = (pct - *pct_start) / (1.0 - *pct_start);
                    *max_lr * (1.0 - phase_pct) + base_lr * 0.01 * phase_pct
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison for learning-rate values.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn test_lr_schedule() {
        // Constant
        let constant = LRSchedule::Constant;
        assert_eq!(constant.get_lr(0.01, 100, 10), 0.01);
        // Linear warmup
        let warmup = LRSchedule::LinearWarmup { warmup_steps: 100 };
        assert!(warmup.get_lr(0.01, 0, 0) < 0.001); // Start low
        assert!(warmup.get_lr(0.01, 50, 0) < 0.01); // Mid warmup
        assert_eq!(warmup.get_lr(0.01, 100, 1), 0.01); // After warmup
        // Step decay
        let step = LRSchedule::StepDecay {
            step_size: 10,
            gamma: 0.1,
        };
        assert_eq!(step.get_lr(0.01, 0, 0), 0.01);
        assert!((step.get_lr(0.01, 0, 10) - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing() {
        let cosine = LRSchedule::CosineAnnealing {
            min_lr: 0.0,
            t_max: 10,
        };
        assert!(approx_eq(cosine.get_lr(0.01, 0, 0), 0.01)); // Starts at base_lr
        assert!(approx_eq(cosine.get_lr(0.01, 0, 5), 0.005)); // Halfway down
        assert!(approx_eq(cosine.get_lr(0.01, 0, 10), 0.01)); // Restarts after t_max
    }

    #[test]
    fn test_one_cycle() {
        let one_cycle = LRSchedule::OneCycle {
            max_lr: 0.1,
            pct_start: 0.3,
            total_steps: 100,
        };
        assert!(approx_eq(one_cycle.get_lr(0.01, 0, 0), 0.01)); // Starts at base_lr
        assert!(approx_eq(one_cycle.get_lr(0.01, 30, 0), 0.1)); // Peaks at max_lr
        assert!(approx_eq(one_cycle.get_lr(0.01, 100, 0), 0.0001)); // Ends at base_lr / 100
    }

    #[test]
    fn test_validate() {
        assert!(LRSchedule::default().validate().is_ok());
        assert!(LRSchedule::Constant.validate().is_ok());
        assert!(LRSchedule::StepDecay {
            step_size: 0,
            gamma: 0.1
        }
        .validate()
        .is_err());
        assert!(LRSchedule::ExponentialDecay { gamma: f32::NAN }
            .validate()
            .is_err());
        assert!(LRSchedule::CosineAnnealing {
            min_lr: -1.0,
            t_max: 10
        }
        .validate()
        .is_err());
        assert!(LRSchedule::OneCycle {
            max_lr: 0.1,
            pct_start: 1.0,
            total_steps: 100
        }
        .validate()
        .is_err());
    }
}