#![allow(clippy::cast_precision_loss)]
//! Learning-rate schedules and bookkeeping utilities for adapter training:
//! [`LrSchedule`], [`AdapterTrainingConfig`], and [`AdapterTrainingState`].

/// Learning-rate schedule applied on top of a base learning rate.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum LrSchedule {
    /// Use the base learning rate unchanged at every step.
    #[default]
    Constant,
    /// Ramp linearly from zero to the base learning rate over `warmup_steps`,
    /// then hold the base learning rate.
    LinearWarmup {
        /// Number of optimizer steps spent warming up.
        warmup_steps: usize,
    },
    /// Decay from the base learning rate to `min_lr` along a half-cosine curve
    /// over `total_steps`, then hold `min_lr`.
    CosineAnnealing {
        /// Step at which the decay reaches `min_lr`.
        total_steps: usize,
        /// Final learning rate after the decay completes.
        min_lr: f64,
    },
    /// Decay linearly from the base learning rate to `min_lr` over
    /// `total_steps`, then hold `min_lr`.
    LinearDecay {
        /// Step at which the decay reaches `min_lr`.
        total_steps: usize,
        /// Final learning rate after the decay completes.
        min_lr: f64,
    },
}

impl LrSchedule {
    /// Compute the learning rate for optimizer step `step`, given the
    /// configured base learning rate `base_lr`.
    ///
    /// Steps past the end of a warmup or decay window clamp to the final
    /// value: `base_lr` after warmup, `min_lr` after either decay.
    #[must_use]
    pub fn get_lr(&self, step: usize, base_lr: f64) -> f64 {
        match self {
            Self::Constant => base_lr,
            Self::LinearWarmup { warmup_steps } => {
                if *warmup_steps == 0 || step >= *warmup_steps {
                    base_lr
                } else {
                    base_lr * (step as f64 / *warmup_steps as f64)
                }
            }
            Self::CosineAnnealing {
                total_steps,
                min_lr,
            } => {
                if *total_steps == 0 || step >= *total_steps {
                    *min_lr
                } else {
                    let progress = step as f64 / *total_steps as f64;
                    // midpoint(1, cos(pi * progress)) = (1 + cos(pi * progress)) / 2,
                    // the usual cosine-annealing factor that falls from 1 to 0.
                    let cosine_decay =
                        f64::midpoint(1.0, (std::f64::consts::PI * progress).cos());
                    min_lr + (base_lr - min_lr) * cosine_decay
                }
            }
            Self::LinearDecay {
                total_steps,
                min_lr,
            } => {
                if *total_steps == 0 || step >= *total_steps {
                    *min_lr
                } else {
                    let progress = step as f64 / *total_steps as f64;
                    base_lr - (base_lr - min_lr) * progress
                }
            }
        }
    }
}

/// Hyperparameters for adapter training.
#[derive(Debug, Clone)]
pub struct AdapterTrainingConfig {
    /// Base learning rate before the schedule is applied.
    pub learning_rate: f64,
    /// Schedule used to derive the per-step learning rate.
    pub lr_schedule: LrSchedule,
    /// Weight-decay coefficient (`0.0` disables weight decay).
    pub weight_decay: f64,
    /// Number of micro-batches accumulated per optimizer update.
    pub gradient_accumulation_steps: usize,
    /// Gradient-norm clipping threshold, or `None` to disable clipping.
    pub max_grad_norm: Option<f64>,
}

impl Default for AdapterTrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-4,
            lr_schedule: LrSchedule::Constant,
            weight_decay: 0.0,
            gradient_accumulation_steps: 1,
            max_grad_norm: Some(1.0),
        }
    }
}

/// Mutable bookkeeping for an adapter training run.
#[derive(Debug, Clone)]
pub struct AdapterTrainingState {
    /// Number of optimizer updates performed so far.
    pub global_step: usize,
    /// Zero-based index of the current epoch.
    pub epoch: usize,
    /// Micro-batches processed in the current epoch.
    pub steps_in_epoch: usize,
    /// Micro-batches accumulated since the last optimizer update.
    pub accumulated_steps: usize,
    /// Lowest validation loss seen so far, if any.
    pub best_val_loss: Option<f64>,
    config: AdapterTrainingConfig,
}

impl AdapterTrainingState {
    /// Create a fresh training state for the given configuration.
    #[must_use]
    pub fn new(config: AdapterTrainingConfig) -> Self {
        Self {
            global_step: 0,
            epoch: 0,
            steps_in_epoch: 0,
            accumulated_steps: 0,
            best_val_loss: None,
            config,
        }
    }

    /// Learning rate for the current optimizer step under the configured schedule.
    #[must_use]
    pub fn current_lr(&self) -> f64 {
        self.config
            .lr_schedule
            .get_lr(self.global_step, self.config.learning_rate)
    }

    /// Whether enough micro-batches have accumulated to run an optimizer update.
    #[must_use]
    pub fn should_update(&self) -> bool {
        self.accumulated_steps >= self.config.gradient_accumulation_steps
    }

    /// Record one micro-batch. Returns `true` when an optimizer update should
    /// be performed, in which case `global_step` is advanced and the
    /// accumulation counter is reset.
    pub fn step(&mut self) -> bool {
        self.accumulated_steps += 1;
        self.steps_in_epoch += 1;

        if self.should_update() {
            self.global_step += 1;
            self.accumulated_steps = 0;
            true
        } else {
            false
        }
    }

    /// Advance to the next epoch and reset the per-epoch step counter.
    pub fn new_epoch(&mut self) {
        self.epoch += 1;
        self.steps_in_epoch = 0;
    }

    /// Record a validation loss. Returns `true` if it is a new best (strictly
    /// lower than any previously recorded loss).
    pub fn update_best_val_loss(&mut self, val_loss: f64) -> bool {
        match self.best_val_loss {
            Some(best) if val_loss >= best => false,
            _ => {
                self.best_val_loss = Some(val_loss);
                true
            }
        }
    }

    /// Configured number of micro-batches per optimizer update.
    #[must_use]
    pub fn gradient_accumulation_steps(&self) -> usize {
        self.config.gradient_accumulation_steps
    }

    /// Configured gradient-norm clipping threshold, if any.
    #[must_use]
    pub fn max_grad_norm(&self) -> Option<f64> {
        self.config.max_grad_norm
    }

    /// Configured weight-decay coefficient.
    #[must_use]
    pub fn weight_decay(&self) -> f64 {
        self.config.weight_decay
    }
}

/// Number of trainable parameters in `adapter`, as reported by the adapter itself.
#[must_use]
pub fn count_trainable_parameters<A: crate::traits::Adapter>(adapter: &A) -> usize {
    adapter.num_parameters()
}

/// Format a parameter count with a `K`/`M`/`B` suffix, e.g. `12_345_678` -> `"12.35M"`.
#[must_use]
pub fn format_parameter_count(count: usize) -> String {
    if count >= 1_000_000_000 {
        format!("{:.2}B", count as f64 / 1_000_000_000.0)
    } else if count >= 1_000_000 {
        format!("{:.2}M", count as f64 / 1_000_000.0)
    } else if count >= 1_000 {
        format!("{:.2}K", count as f64 / 1_000.0)
    } else {
        count.to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_lr() {
        let schedule = LrSchedule::Constant;
        assert!((schedule.get_lr(0, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(100, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(1000, 0.001) - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_linear_warmup() {
        let schedule = LrSchedule::LinearWarmup { warmup_steps: 100 };
        assert!((schedule.get_lr(0, 0.001) - 0.0).abs() < 1e-10);
        assert!((schedule.get_lr(50, 0.001) - 0.0005).abs() < 1e-10);
        assert!((schedule.get_lr(100, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(200, 0.001) - 0.001).abs() < 1e-10);
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn test_cosine_annealing() {
        let schedule = LrSchedule::CosineAnnealing {
            total_steps: 100,
            min_lr: 0.0001,
        };

        let lr_0 = schedule.get_lr(0, 0.001);
        assert!((lr_0 - 0.001).abs() < 1e-10);

        let lr_50 = schedule.get_lr(50, 0.001);
        let expected_50 = 0.0001 + (0.001 - 0.0001) * 0.5;
        assert!((lr_50 - expected_50).abs() < 1e-6);

        let lr_100 = schedule.get_lr(100, 0.001);
        assert!((lr_100 - 0.0001).abs() < 1e-10);
    }

    #[test]
    fn test_linear_decay() {
        let schedule = LrSchedule::LinearDecay {
            total_steps: 100,
            min_lr: 0.0001,
        };

        assert!((schedule.get_lr(0, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(50, 0.001) - 0.00055).abs() < 1e-10);
        assert!((schedule.get_lr(100, 0.001) - 0.0001).abs() < 1e-10);
    }
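
    // Added check (not in the original suite): both decay schedules clamp to
    // `min_lr` once `step` passes `total_steps`, so a run that overshoots its
    // planned length keeps a well-defined learning rate.
    #[test]
    fn test_decay_clamps_past_total_steps() {
        let cosine = LrSchedule::CosineAnnealing {
            total_steps: 100,
            min_lr: 0.0001,
        };
        let linear = LrSchedule::LinearDecay {
            total_steps: 100,
            min_lr: 0.0001,
        };
        assert!((cosine.get_lr(250, 0.001) - 0.0001).abs() < 1e-10);
        assert!((linear.get_lr(250, 0.001) - 0.0001).abs() < 1e-10);
    }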

    #[test]
    fn test_training_state_step() {
        let config = AdapterTrainingConfig {
            gradient_accumulation_steps: 4,
            ..Default::default()
        };
        let mut state = AdapterTrainingState::new(config);

        assert!(!state.step());
        assert!(!state.step());
        assert!(!state.step());
        assert!(state.step());
        assert_eq!(state.global_step, 1);
        assert_eq!(state.accumulated_steps, 0);

        assert!(!state.step());
        assert!(!state.step());
        assert!(!state.step());
        assert!(state.step());
        assert_eq!(state.global_step, 2);
    }
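
    // Added check (not in the original suite): with the default
    // `gradient_accumulation_steps` of 1, every call to `step()` is an
    // optimizer update, so `global_step` tracks micro-batches one-to-one.
    #[test]
    fn test_training_state_step_without_accumulation() {
        let mut state = AdapterTrainingState::new(AdapterTrainingConfig::default());
        assert!(state.step());
        assert!(state.step());
        assert_eq!(state.global_step, 2);
        assert_eq!(state.accumulated_steps, 0);
    }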

    #[test]
    fn test_training_state_epoch() {
        let config = AdapterTrainingConfig::default();
        let mut state = AdapterTrainingState::new(config);

        state.step();
        state.step();
        assert_eq!(state.steps_in_epoch, 2);

        state.new_epoch();
        assert_eq!(state.epoch, 1);
        assert_eq!(state.steps_in_epoch, 0);
    }

    #[test]
    fn test_best_val_loss() {
        let config = AdapterTrainingConfig::default();
        let mut state = AdapterTrainingState::new(config);

        assert!(state.update_best_val_loss(1.0));
        assert_eq!(state.best_val_loss, Some(1.0));

        assert!(state.update_best_val_loss(0.5));
        assert_eq!(state.best_val_loss, Some(0.5));

        assert!(!state.update_best_val_loss(0.8));
        assert_eq!(state.best_val_loss, Some(0.5));
    }
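
    // Illustrative sketch (added) of how the pieces compose in a training
    // loop: `step()` per micro-batch, `new_epoch()` per epoch, and
    // `current_lr()` read whenever an optimizer update fires. The epoch and
    // batch counts are arbitrary, chosen only to exercise the bookkeeping.
    #[test]
    fn test_training_loop_bookkeeping() {
        let config = AdapterTrainingConfig {
            gradient_accumulation_steps: 2,
            ..Default::default()
        };
        let mut state = AdapterTrainingState::new(config);

        for _epoch in 0..3 {
            for _batch in 0..4 {
                if state.step() {
                    // An optimizer update would run here, using `current_lr()`.
                    let _lr = state.current_lr();
                }
            }
            state.new_epoch();
        }

        // 3 epochs * 4 micro-batches / 2 accumulation steps = 6 updates.
        assert_eq!(state.global_step, 6);
        assert_eq!(state.epoch, 3);
        assert_eq!(state.steps_in_epoch, 0);
    }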

    #[test]
    fn test_format_parameter_count() {
        assert_eq!(format_parameter_count(100), "100");
        assert_eq!(format_parameter_count(1_234), "1.23K");
        assert_eq!(format_parameter_count(12_345_678), "12.35M");
        assert_eq!(format_parameter_count(1_234_567_890), "1.23B");
    }
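
    // Added boundary checks for `format_parameter_count` at the unit cutoffs.
    #[test]
    fn test_format_parameter_count_boundaries() {
        assert_eq!(format_parameter_count(0), "0");
        assert_eq!(format_parameter_count(999), "999");
        assert_eq!(format_parameter_count(1_000), "1.00K");
        assert_eq!(format_parameter_count(1_000_000), "1.00M");
        assert_eq!(format_parameter_count(1_000_000_000), "1.00B");
    }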

    #[test]
    fn test_current_lr_with_schedule() {
        let config = AdapterTrainingConfig {
            learning_rate: 0.001,
            lr_schedule: LrSchedule::LinearWarmup { warmup_steps: 10 },
            ..Default::default()
        };
        let mut state = AdapterTrainingState::new(config);

        assert!((state.current_lr() - 0.0).abs() < 1e-10);

        for _ in 0..5 {
            state.step();
        }
        assert!((state.current_lr() - 0.0005).abs() < 1e-10);

        for _ in 0..5 {
            state.step();
        }
        assert!((state.current_lr() - 0.001).abs() < 1e-10);
    }
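
    // Added check: schedules are evaluated against `global_step` (optimizer
    // updates), so with gradient accumulation the warmup advances once per
    // accumulation window rather than once per micro-batch.
    #[test]
    fn test_current_lr_counts_optimizer_steps() {
        let config = AdapterTrainingConfig {
            learning_rate: 0.001,
            lr_schedule: LrSchedule::LinearWarmup { warmup_steps: 10 },
            gradient_accumulation_steps: 2,
            ..Default::default()
        };
        let mut state = AdapterTrainingState::new(config);

        // 10 micro-batches at an accumulation of 2 => 5 optimizer updates.
        for _ in 0..10 {
            state.step();
        }
        assert_eq!(state.global_step, 5);
        assert!((state.current_lr() - 0.0005).abs() < 1e-10);
    }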

    #[test]
    fn test_zero_warmup_steps() {
        let schedule = LrSchedule::LinearWarmup { warmup_steps: 0 };
        assert!((schedule.get_lr(0, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(100, 0.001) - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_zero_total_steps_cosine() {
        let schedule = LrSchedule::CosineAnnealing {
            total_steps: 0,
            min_lr: 0.0001,
        };
        assert!((schedule.get_lr(0, 0.001) - 0.0001).abs() < 1e-10);
    }

    #[test]
    fn test_zero_total_steps_linear_decay() {
        let schedule = LrSchedule::LinearDecay {
            total_steps: 0,
            min_lr: 0.0001,
        };
        assert!((schedule.get_lr(0, 0.001) - 0.0001).abs() < 1e-10);
    }
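
    // Added check: the defaults set by `AdapterTrainingConfig::default()`.
    #[test]
    fn test_default_config() {
        let config = AdapterTrainingConfig::default();
        assert!((config.learning_rate - 1e-4).abs() < 1e-12);
        assert_eq!(config.lr_schedule, LrSchedule::Constant);
        assert!(config.weight_decay.abs() < f64::EPSILON);
        assert_eq!(config.gradient_accumulation_steps, 1);
        assert!((config.max_grad_norm.unwrap() - 1.0).abs() < f64::EPSILON);
    }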
}