peft_rs/training.rs

//! Training utilities for PEFT adapters.
//!
//! This module provides functionality for:
//! - Learning rate schedules for adapter training
//! - Training state management
//! - Parameter counting helpers
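//!
//! # Example
//!
//! A minimal sketch of a training loop driving this module; it assumes the
//! crate exposes this module publicly as `peft_rs::training`.
//!
//! ```
//! use peft_rs::training::{AdapterTrainingConfig, AdapterTrainingState, LrSchedule};
//!
//! let config = AdapterTrainingConfig {
//!     learning_rate: 1e-4,
//!     lr_schedule: LrSchedule::LinearWarmup { warmup_steps: 10 },
//!     ..Default::default()
//! };
//! let mut state = AdapterTrainingState::new(config);
//!
//! for _batch in 0..4 {
//!     // forward/backward pass on the adapter would go here
//!     if state.step() {
//!         // gradient accumulation is complete: take an optimizer step
//!         // at the schedule-adjusted learning rate
//!         let _lr = state.current_lr();
//!     }
//! }
//! ```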

// Allow usize to f64 casts for learning rate calculations - this is standard in ML code
#![allow(clippy::cast_precision_loss)]

/// Learning rate schedule strategies.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum LrSchedule {
    /// Constant learning rate
    #[default]
    Constant,
    /// Linear warmup from 0 to max LR
    LinearWarmup {
        /// Number of warmup steps
        warmup_steps: usize,
    },
    /// Cosine annealing from max LR to min LR
    CosineAnnealing {
        /// Total number of steps
        total_steps: usize,
        /// Minimum learning rate
        min_lr: f64,
    },
    /// Linear decay from max LR to min LR
    LinearDecay {
        /// Total number of steps
        total_steps: usize,
        /// Minimum learning rate
        min_lr: f64,
    },
}

impl LrSchedule {
    /// Compute the learning rate for the given step.
    ///
    /// # Arguments
    /// * `step` - Current training step (0-indexed)
    /// * `base_lr` - Base learning rate
    ///
    /// # Returns
    /// The learning rate for this step
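    ///
    /// # Examples
    ///
    /// A small sketch of the warmup behaviour (the path assumes this module
    /// is exposed publicly as `peft_rs::training`):
    ///
    /// ```
    /// use peft_rs::training::LrSchedule;
    ///
    /// let warmup = LrSchedule::LinearWarmup { warmup_steps: 100 };
    /// // Halfway through warmup the learning rate is half the base rate.
    /// assert!((warmup.get_lr(50, 1e-3) - 5e-4).abs() < 1e-10);
    /// // After warmup the base learning rate is used unchanged.
    /// assert!((warmup.get_lr(500, 1e-3) - 1e-3).abs() < 1e-10);
    /// ```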
    #[must_use]
    pub fn get_lr(&self, step: usize, base_lr: f64) -> f64 {
        match self {
            Self::Constant => base_lr,
            Self::LinearWarmup { warmup_steps } => {
                if *warmup_steps == 0 || step >= *warmup_steps {
                    base_lr
                } else {
                    base_lr * (step as f64 / *warmup_steps as f64)
                }
            }
            Self::CosineAnnealing {
                total_steps,
                min_lr,
            } => {
                if *total_steps == 0 || step >= *total_steps {
                    *min_lr
                } else {
                    let progress = step as f64 / *total_steps as f64;
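                    // f64::midpoint(1.0, cos(pi * progress)) == (1 + cos(pi * progress)) / 2,
                    // which starts at 1.0 for step 0 and falls smoothly towards 0.0 at total_steps.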
                    let cosine_decay = f64::midpoint(1.0, (std::f64::consts::PI * progress).cos());
                    min_lr + (base_lr - min_lr) * cosine_decay
                }
            }
            Self::LinearDecay {
                total_steps,
                min_lr,
            } => {
                if *total_steps == 0 || step >= *total_steps {
                    *min_lr
                } else {
                    let progress = step as f64 / *total_steps as f64;
                    base_lr - (base_lr - min_lr) * progress
                }
            }
        }
    }
}

/// Configuration for adapter training.
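///
/// # Examples
///
/// A small sketch of overriding one field while keeping the remaining
/// defaults (it assumes this module is exposed publicly as
/// `peft_rs::training`):
///
/// ```
/// use peft_rs::training::AdapterTrainingConfig;
///
/// let config = AdapterTrainingConfig {
///     weight_decay: 0.01,
///     ..Default::default()
/// };
/// assert_eq!(config.gradient_accumulation_steps, 1);
/// ```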
#[derive(Debug, Clone)]
pub struct AdapterTrainingConfig {
    /// Base learning rate
    pub learning_rate: f64,
    /// Learning rate schedule
    pub lr_schedule: LrSchedule,
    /// Weight decay (L2 regularization)
    pub weight_decay: f64,
    /// Gradient accumulation steps
    pub gradient_accumulation_steps: usize,
    /// Maximum gradient norm for clipping (None = no clipping)
    pub max_grad_norm: Option<f64>,
}

impl Default for AdapterTrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-4,
            lr_schedule: LrSchedule::Constant,
            weight_decay: 0.0,
            gradient_accumulation_steps: 1,
            max_grad_norm: Some(1.0),
        }
    }
}

/// Training state for adapter fine-tuning.
#[derive(Debug, Clone)]
pub struct AdapterTrainingState {
    /// Current global step
    pub global_step: usize,
    /// Current epoch
    pub epoch: usize,
    /// Steps within current epoch
    pub steps_in_epoch: usize,
    /// Accumulated gradient steps (for gradient accumulation)
    pub accumulated_steps: usize,
    /// Best validation loss seen
    pub best_val_loss: Option<f64>,
    /// Training configuration
    config: AdapterTrainingConfig,
}

impl AdapterTrainingState {
    /// Create new training state with the given configuration.
    #[must_use]
    pub fn new(config: AdapterTrainingConfig) -> Self {
        Self {
            global_step: 0,
            epoch: 0,
            steps_in_epoch: 0,
            accumulated_steps: 0,
            best_val_loss: None,
            config,
        }
    }

    /// Get the current learning rate based on the configured schedule.
    #[must_use]
    pub fn current_lr(&self) -> f64 {
        self.config
            .lr_schedule
            .get_lr(self.global_step, self.config.learning_rate)
    }

    /// Check if gradient accumulation is complete.
    #[must_use]
    pub fn should_update(&self) -> bool {
        self.accumulated_steps >= self.config.gradient_accumulation_steps
    }

    /// Record one processed batch and advance the training state.
    ///
    /// Returns `true` if an optimizer step should be taken.
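    ///
    /// # Examples
    ///
    /// A small sketch of gradient accumulation over two micro-batches (it
    /// assumes this module is exposed publicly as `peft_rs::training`):
    ///
    /// ```
    /// use peft_rs::training::{AdapterTrainingConfig, AdapterTrainingState};
    ///
    /// let config = AdapterTrainingConfig {
    ///     gradient_accumulation_steps: 2,
    ///     ..Default::default()
    /// };
    /// let mut state = AdapterTrainingState::new(config);
    /// assert!(!state.step()); // first micro-batch: keep accumulating gradients
    /// assert!(state.step());  // second micro-batch: take an optimizer step
    /// assert_eq!(state.global_step, 1);
    /// ```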
    pub fn step(&mut self) -> bool {
        self.accumulated_steps += 1;
        self.steps_in_epoch += 1;

        if self.should_update() {
            self.global_step += 1;
            self.accumulated_steps = 0;
            true
        } else {
            false
        }
    }

    /// Start a new epoch.
    pub fn new_epoch(&mut self) {
        self.epoch += 1;
        self.steps_in_epoch = 0;
    }

    /// Update best validation loss.
    ///
    /// Returns `true` if this is the new best loss.
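    ///
    /// # Examples
    ///
    /// A small sketch of early-stopping-style bookkeeping (it assumes this
    /// module is exposed publicly as `peft_rs::training`):
    ///
    /// ```
    /// use peft_rs::training::{AdapterTrainingConfig, AdapterTrainingState};
    ///
    /// let mut state = AdapterTrainingState::new(AdapterTrainingConfig::default());
    /// assert!(state.update_best_val_loss(1.0));  // first value is always the best
    /// assert!(!state.update_best_val_loss(1.2)); // a worse loss is not recorded
    /// assert!(state.update_best_val_loss(0.7));  // an improvement replaces the best
    /// ```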
    pub fn update_best_val_loss(&mut self, val_loss: f64) -> bool {
        match self.best_val_loss {
            Some(best) if val_loss >= best => false,
            _ => {
                self.best_val_loss = Some(val_loss);
                true
            }
        }
    }

    /// Get gradient accumulation steps.
    #[must_use]
    pub fn gradient_accumulation_steps(&self) -> usize {
        self.config.gradient_accumulation_steps
    }

    /// Get maximum gradient norm for clipping.
    #[must_use]
    pub fn max_grad_norm(&self) -> Option<f64> {
        self.config.max_grad_norm
    }

    /// Get weight decay.
    #[must_use]
    pub fn weight_decay(&self) -> f64 {
        self.config.weight_decay
    }
}

/// Count trainable parameters in an adapter.
///
/// # Arguments
/// * `adapter` - The adapter to count parameters for
///
/// # Returns
/// Number of trainable parameters
#[must_use]
pub fn count_trainable_parameters<A: crate::traits::Adapter>(adapter: &A) -> usize {
    adapter.num_parameters()
}

/// Format parameter count with appropriate units.
///
/// # Arguments
/// * `count` - Number of parameters
///
/// # Returns
/// Human-readable string with two decimals (e.g., "12.35K", "1.50M", "2.10B")
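///
/// # Examples
///
/// A quick sketch of the formatting (it assumes this function is exposed
/// publicly as `peft_rs::training::format_parameter_count`):
///
/// ```
/// use peft_rs::training::format_parameter_count;
///
/// assert_eq!(format_parameter_count(950), "950");
/// assert_eq!(format_parameter_count(4_194_304), "4.19M");
/// ```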
#[must_use]
pub fn format_parameter_count(count: usize) -> String {
    if count >= 1_000_000_000 {
        format!("{:.2}B", count as f64 / 1_000_000_000.0)
    } else if count >= 1_000_000 {
        format!("{:.2}M", count as f64 / 1_000_000.0)
    } else if count >= 1_000 {
        format!("{:.2}K", count as f64 / 1_000.0)
    } else {
        count.to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_lr() {
        let schedule = LrSchedule::Constant;
        assert!((schedule.get_lr(0, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(100, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(1000, 0.001) - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_linear_warmup() {
        let schedule = LrSchedule::LinearWarmup { warmup_steps: 100 };
        assert!((schedule.get_lr(0, 0.001) - 0.0).abs() < 1e-10);
        assert!((schedule.get_lr(50, 0.001) - 0.0005).abs() < 1e-10);
        assert!((schedule.get_lr(100, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(200, 0.001) - 0.001).abs() < 1e-10);
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn test_cosine_annealing() {
        let schedule = LrSchedule::CosineAnnealing {
            total_steps: 100,
            min_lr: 0.0001,
        };

        // At step 0, should be at max LR
        let lr_0 = schedule.get_lr(0, 0.001);
        assert!((lr_0 - 0.001).abs() < 1e-10);

        // At halfway, should be at average of max and min
        let lr_50 = schedule.get_lr(50, 0.001);
        let expected_50 = 0.0001 + (0.001 - 0.0001) * 0.5;
        assert!((lr_50 - expected_50).abs() < 1e-6);

        // At end, should be at min LR
        let lr_100 = schedule.get_lr(100, 0.001);
        assert!((lr_100 - 0.0001).abs() < 1e-10);
    }

    #[test]
    fn test_linear_decay() {
        let schedule = LrSchedule::LinearDecay {
            total_steps: 100,
            min_lr: 0.0001,
        };

        assert!((schedule.get_lr(0, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(50, 0.001) - 0.00055).abs() < 1e-10);
        assert!((schedule.get_lr(100, 0.001) - 0.0001).abs() < 1e-10);
    }

    #[test]
    fn test_training_state_step() {
        let config = AdapterTrainingConfig {
            gradient_accumulation_steps: 4,
            ..Default::default()
        };
        let mut state = AdapterTrainingState::new(config);

        assert!(!state.step()); // 1/4
        assert!(!state.step()); // 2/4
        assert!(!state.step()); // 3/4
        assert!(state.step()); // 4/4 - should update
        assert_eq!(state.global_step, 1);
        assert_eq!(state.accumulated_steps, 0);

        assert!(!state.step()); // 1/4
        assert!(!state.step()); // 2/4
        assert!(!state.step()); // 3/4
        assert!(state.step()); // 4/4 - should update
        assert_eq!(state.global_step, 2);
    }

    #[test]
    fn test_training_state_epoch() {
        let config = AdapterTrainingConfig::default();
        let mut state = AdapterTrainingState::new(config);

        state.step();
        state.step();
        assert_eq!(state.steps_in_epoch, 2);

        state.new_epoch();
        assert_eq!(state.epoch, 1);
        assert_eq!(state.steps_in_epoch, 0);
    }

    #[test]
    fn test_best_val_loss() {
        let config = AdapterTrainingConfig::default();
        let mut state = AdapterTrainingState::new(config);

        assert!(state.update_best_val_loss(1.0));
        assert_eq!(state.best_val_loss, Some(1.0));

        assert!(state.update_best_val_loss(0.5));
        assert_eq!(state.best_val_loss, Some(0.5));

        assert!(!state.update_best_val_loss(0.8));
        assert_eq!(state.best_val_loss, Some(0.5));
    }

    #[test]
    fn test_format_parameter_count() {
        assert_eq!(format_parameter_count(100), "100");
        assert_eq!(format_parameter_count(1_234), "1.23K");
        assert_eq!(format_parameter_count(12_345_678), "12.35M");
        assert_eq!(format_parameter_count(1_234_567_890), "1.23B");
    }

    #[test]
    fn test_current_lr_with_schedule() {
        let config = AdapterTrainingConfig {
            learning_rate: 0.001,
            lr_schedule: LrSchedule::LinearWarmup { warmup_steps: 10 },
            ..Default::default()
        };
        let mut state = AdapterTrainingState::new(config);

        // At step 0, LR should be 0
        assert!((state.current_lr() - 0.0).abs() < 1e-10);

        // Advance 5 steps
        for _ in 0..5 {
            state.step();
        }
        assert!((state.current_lr() - 0.0005).abs() < 1e-10);

        // Advance 5 more steps
        for _ in 0..5 {
            state.step();
        }
        assert!((state.current_lr() - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_zero_warmup_steps() {
        // Edge case: zero warmup steps should return base_lr immediately
        let schedule = LrSchedule::LinearWarmup { warmup_steps: 0 };
        assert!((schedule.get_lr(0, 0.001) - 0.001).abs() < 1e-10);
        assert!((schedule.get_lr(100, 0.001) - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_zero_total_steps_cosine() {
        // Edge case: zero total steps should return min_lr immediately
        let schedule = LrSchedule::CosineAnnealing {
            total_steps: 0,
            min_lr: 0.0001,
        };
        assert!((schedule.get_lr(0, 0.001) - 0.0001).abs() < 1e-10);
    }

    #[test]
    fn test_zero_total_steps_linear_decay() {
        // Edge case: zero total steps should return min_lr immediately
        let schedule = LrSchedule::LinearDecay {
            total_steps: 0,
            min_lr: 0.0001,
        };
        assert!((schedule.get_lr(0, 0.001) - 0.0001).abs() < 1e-10);
    }
}