quantrs2_ml/pytorch_api/
schedulers.rs

//! Learning rate schedulers for PyTorch-like API

/// Learning rate scheduler trait
pub trait LRScheduler {
    /// Get current learning rate
    fn get_lr(&self) -> f64;
    /// Step the scheduler
    fn step(&mut self);
    /// Set the epoch
    fn set_epoch(&mut self, epoch: usize);
}
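
// A minimal usage sketch (illustrative only): any concrete scheduler defined
// below can be driven through `dyn LRScheduler`, e.g. from a training loop
// that only knows about the trait.
#[cfg(test)]
mod lr_scheduler_trait_example {
    use super::*;

    // Advances any scheduler by `epochs` steps and returns the resulting rate.
    fn run_epochs(scheduler: &mut dyn LRScheduler, epochs: usize) -> f64 {
        for _ in 0..epochs {
            scheduler.step();
        }
        scheduler.get_lr()
    }

    #[test]
    fn trait_object_dispatch() {
        // StepLR with step_size = 5 and gamma = 0.1: after 5 epochs the rate
        // drops from 0.1 to 0.1 * 0.1 = 0.01.
        let mut sched = StepLR::new(0.1, 5, 0.1);
        assert!((run_epochs(&mut sched, 5) - 0.01).abs() < 1e-12);
    }
}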

/// Step learning rate scheduler: multiplies the learning rate by `gamma`
/// every `step_size` epochs.
pub struct StepLR {
    base_lr: f64,
    step_size: usize,
    gamma: f64,
    current_epoch: usize,
}

impl StepLR {
    /// Create new step LR scheduler
    pub fn new(base_lr: f64, step_size: usize, gamma: f64) -> Self {
        Self {
            base_lr,
            step_size,
            gamma,
            current_epoch: 0,
        }
    }
}

impl LRScheduler for StepLR {
    fn get_lr(&self) -> f64 {
        self.base_lr * self.gamma.powi((self.current_epoch / self.step_size) as i32)
    }

    fn step(&mut self) {
        self.current_epoch += 1;
    }

    fn set_epoch(&mut self, epoch: usize) {
        self.current_epoch = epoch;
    }
}
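
// A minimal usage sketch for `StepLR` (illustrative only): `set_epoch` jumps
// the schedule directly, which is useful when resuming training.
#[cfg(test)]
mod step_lr_example {
    use super::*;

    #[test]
    fn set_epoch_jumps_the_schedule() {
        let mut sched = StepLR::new(0.1, 10, 0.5);
        // The rate starts at base_lr.
        assert!((sched.get_lr() - 0.1).abs() < 1e-12);
        // Epoch 25 lies past two decay boundaries: 0.1 * 0.5^2 = 0.025.
        sched.set_epoch(25);
        assert!((sched.get_lr() - 0.025).abs() < 1e-12);
    }
}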

/// Exponential learning rate scheduler: multiplies the learning rate by
/// `gamma` every epoch.
pub struct ExponentialLR {
    base_lr: f64,
    gamma: f64,
    current_epoch: usize,
}

impl ExponentialLR {
    /// Create new exponential LR scheduler
    pub fn new(base_lr: f64, gamma: f64) -> Self {
        Self {
            base_lr,
            gamma,
            current_epoch: 0,
        }
    }
}

impl LRScheduler for ExponentialLR {
    fn get_lr(&self) -> f64 {
        self.base_lr * self.gamma.powi(self.current_epoch as i32)
    }

    fn step(&mut self) {
        self.current_epoch += 1;
    }

    fn set_epoch(&mut self, epoch: usize) {
        self.current_epoch = epoch;
    }
}
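
// A minimal usage sketch for `ExponentialLR` (illustrative only): the decay
// compounds multiplicatively with every call to `step`.
#[cfg(test)]
mod exponential_lr_example {
    use super::*;

    #[test]
    fn decays_multiplicatively_each_epoch() {
        let mut sched = ExponentialLR::new(1.0, 0.9);
        sched.step();
        sched.step();
        // After two epochs: 1.0 * 0.9^2 = 0.81.
        assert!((sched.get_lr() - 0.81).abs() < 1e-12);
    }
}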

/// Cosine annealing learning rate scheduler: anneals the learning rate from
/// `base_lr` down to `eta_min` over `t_max` epochs along a cosine curve.
pub struct CosineAnnealingLR {
    base_lr: f64,
    t_max: usize,
    eta_min: f64,
    current_epoch: usize,
}

impl CosineAnnealingLR {
    /// Create new cosine annealing LR scheduler
    pub fn new(base_lr: f64, t_max: usize) -> Self {
        Self {
            base_lr,
            t_max,
            eta_min: 0.0,
            current_epoch: 0,
        }
    }

    /// Set minimum learning rate
    pub fn eta_min(mut self, eta_min: f64) -> Self {
        self.eta_min = eta_min;
        self
    }
}

impl LRScheduler for CosineAnnealingLR {
    fn get_lr(&self) -> f64 {
        let progress = self.current_epoch as f64 / self.t_max as f64;
        let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
        self.eta_min + (self.base_lr - self.eta_min) * cosine
    }

    fn step(&mut self) {
        self.current_epoch += 1;
    }

    fn set_epoch(&mut self, epoch: usize) {
        self.current_epoch = epoch;
    }
}
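
// A minimal usage sketch for `CosineAnnealingLR` (illustrative only): the rate
// follows a half cosine wave from `base_lr` at epoch 0 to `eta_min` at `t_max`.
#[cfg(test)]
mod cosine_annealing_lr_example {
    use super::*;

    #[test]
    fn anneals_from_base_lr_to_eta_min() {
        let mut sched = CosineAnnealingLR::new(0.1, 100).eta_min(0.001);
        // Start of the cycle: full base learning rate.
        assert!((sched.get_lr() - 0.1).abs() < 1e-12);
        // Halfway through: midpoint between base_lr and eta_min.
        sched.set_epoch(50);
        assert!((sched.get_lr() - 0.0505).abs() < 1e-12);
        // End of the cycle: eta_min.
        sched.set_epoch(100);
        assert!((sched.get_lr() - 0.001).abs() < 1e-12);
    }
}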

/// ReduceLROnPlateau scheduler: reduces the learning rate by `factor` once a
/// monitored metric has stopped improving for `patience` epochs.
pub struct ReduceLROnPlateau {
    base_lr: f64,
    factor: f64,
    patience: usize,
    min_lr: f64,
    best_score: f64,
    num_bad_epochs: usize,
    current_lr: f64,
}

impl ReduceLROnPlateau {
    /// Create new ReduceLROnPlateau scheduler
    pub fn new(base_lr: f64) -> Self {
        Self {
            base_lr,
            factor: 0.1,
            patience: 10,
            min_lr: 1e-8,
            best_score: f64::INFINITY,
            num_bad_epochs: 0,
            current_lr: base_lr,
        }
    }

    /// Set reduction factor
    pub fn factor(mut self, factor: f64) -> Self {
        self.factor = factor;
        self
    }

    /// Set patience
    pub fn patience(mut self, patience: usize) -> Self {
        self.patience = patience;
        self
    }

    /// Step based on a monitored metric (e.g. validation loss); lower is better
    pub fn step_with_metric(&mut self, metric: f64) {
        if metric < self.best_score {
            self.best_score = metric;
            self.num_bad_epochs = 0;
        } else {
            self.num_bad_epochs += 1;
            if self.num_bad_epochs >= self.patience {
                self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
                self.num_bad_epochs = 0;
            }
        }
    }
}

impl LRScheduler for ReduceLROnPlateau {
    fn get_lr(&self) -> f64 {
        self.current_lr
    }

    fn step(&mut self) {
        // No-op for ReduceLROnPlateau - use step_with_metric instead
    }

    fn set_epoch(&mut self, _epoch: usize) {
        // No-op for ReduceLROnPlateau
    }
}
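
// A minimal usage sketch for `ReduceLROnPlateau` (illustrative only): the rate
// is cut by `factor` after `patience` consecutive epochs without improvement.
#[cfg(test)]
mod reduce_lr_on_plateau_example {
    use super::*;

    #[test]
    fn reduces_after_patience_bad_epochs() {
        let mut sched = ReduceLROnPlateau::new(0.1).factor(0.5).patience(2);
        // The first observation becomes the best score; the rate is unchanged.
        sched.step_with_metric(1.0);
        assert!((sched.get_lr() - 0.1).abs() < 1e-12);
        // Two epochs without improvement exhaust the patience and halve the rate.
        sched.step_with_metric(1.0);
        sched.step_with_metric(1.0);
        assert!((sched.get_lr() - 0.05).abs() < 1e-12);
    }
}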