// quantrs2_ml/pytorch_api/schedulers.rs

/// Shared interface implemented by every learning-rate scheduler in this module.
///
/// Callers query the current rate with [`LRScheduler::get_lr`] and move the
/// schedule forward with [`LRScheduler::step`] or [`LRScheduler::set_epoch`].
pub trait LRScheduler {
    /// Learning rate that applies to the scheduler's present state.
    fn get_lr(&self) -> f64;

    /// Moves the schedule forward by a single step.
    ///
    /// Schedulers driven by another signal (such as a validation metric) may
    /// implement this as a no-op.
    fn step(&mut self);

    /// Repositions the schedule at the given `epoch`.
    ///
    /// Schedulers that do not track epochs may ignore the argument.
    fn set_epoch(&mut self, epoch: usize);
}
12
13pub struct StepLR {
15 base_lr: f64,
16 step_size: usize,
17 gamma: f64,
18 current_epoch: usize,
19}
20
21impl StepLR {
22 pub fn new(base_lr: f64, step_size: usize, gamma: f64) -> Self {
24 Self {
25 base_lr,
26 step_size,
27 gamma,
28 current_epoch: 0,
29 }
30 }
31}
32
33impl LRScheduler for StepLR {
34 fn get_lr(&self) -> f64 {
35 self.base_lr
36 * self
37 .gamma
38 .powi((self.current_epoch / self.step_size) as i32)
39 }
40
41 fn step(&mut self) {
42 self.current_epoch += 1;
43 }
44
45 fn set_epoch(&mut self, epoch: usize) {
46 self.current_epoch = epoch;
47 }
48}
49
50pub struct ExponentialLR {
52 base_lr: f64,
53 gamma: f64,
54 current_epoch: usize,
55}
56
57impl ExponentialLR {
58 pub fn new(base_lr: f64, gamma: f64) -> Self {
60 Self {
61 base_lr,
62 gamma,
63 current_epoch: 0,
64 }
65 }
66}
67
68impl LRScheduler for ExponentialLR {
69 fn get_lr(&self) -> f64 {
70 self.base_lr * self.gamma.powi(self.current_epoch as i32)
71 }
72
73 fn step(&mut self) {
74 self.current_epoch += 1;
75 }
76
77 fn set_epoch(&mut self, epoch: usize) {
78 self.current_epoch = epoch;
79 }
80}
81
82pub struct CosineAnnealingLR {
84 base_lr: f64,
85 t_max: usize,
86 eta_min: f64,
87 current_epoch: usize,
88}
89
90impl CosineAnnealingLR {
91 pub fn new(base_lr: f64, t_max: usize) -> Self {
93 Self {
94 base_lr,
95 t_max,
96 eta_min: 0.0,
97 current_epoch: 0,
98 }
99 }
100
101 pub fn eta_min(mut self, eta_min: f64) -> Self {
103 self.eta_min = eta_min;
104 self
105 }
106}
107
108impl LRScheduler for CosineAnnealingLR {
109 fn get_lr(&self) -> f64 {
110 self.eta_min
111 + (self.base_lr - self.eta_min)
112 * (1.0
113 + (std::f64::consts::PI * self.current_epoch as f64 / self.t_max as f64).cos())
114 / 2.0
115 }
116
117 fn step(&mut self) {
118 self.current_epoch += 1;
119 }
120
121 fn set_epoch(&mut self, epoch: usize) {
122 self.current_epoch = epoch;
123 }
124}
125
126pub struct ReduceLROnPlateau {
128 base_lr: f64,
129 factor: f64,
130 patience: usize,
131 min_lr: f64,
132 best_score: f64,
133 num_bad_epochs: usize,
134 current_lr: f64,
135}
136
137impl ReduceLROnPlateau {
138 pub fn new(base_lr: f64) -> Self {
140 Self {
141 base_lr,
142 factor: 0.1,
143 patience: 10,
144 min_lr: 1e-8,
145 best_score: f64::INFINITY,
146 num_bad_epochs: 0,
147 current_lr: base_lr,
148 }
149 }
150
151 pub fn factor(mut self, factor: f64) -> Self {
153 self.factor = factor;
154 self
155 }
156
157 pub fn patience(mut self, patience: usize) -> Self {
159 self.patience = patience;
160 self
161 }
162
163 pub fn step_with_metric(&mut self, metric: f64) {
165 if metric < self.best_score {
166 self.best_score = metric;
167 self.num_bad_epochs = 0;
168 } else {
169 self.num_bad_epochs += 1;
170 if self.num_bad_epochs >= self.patience {
171 self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
172 self.num_bad_epochs = 0;
173 }
174 }
175 }
176}
177
178impl LRScheduler for ReduceLROnPlateau {
179 fn get_lr(&self) -> f64 {
180 self.current_lr
181 }
182
183 fn step(&mut self) {
184 }
186
187 fn set_epoch(&mut self, _epoch: usize) {
188 }
190}