// quantrs2_ml/keras_api/schedules.rs
//
// Keras-style learning-rate schedules: each type maps a training step index
// to the learning rate to use at that step.

/// A learning-rate schedule: maps a training step to a learning rate.
///
/// The `Send + Sync` supertraits allow a schedule to be shared between
/// threads, e.g. behind an `Arc<dyn LearningRateSchedule>`.
pub trait LearningRateSchedule: Send + Sync {
    /// Returns the learning rate to use at training step `step`.
    fn get_lr(&self, step: usize) -> f64;
}

9pub struct ExponentialDecay {
11 initial_lr: f64,
13 decay_steps: usize,
15 decay_rate: f64,
17 staircase: bool,
19}
20
21impl ExponentialDecay {
22 pub fn new(initial_lr: f64, decay_steps: usize, decay_rate: f64) -> Self {
24 Self {
25 initial_lr,
26 decay_steps,
27 decay_rate,
28 staircase: false,
29 }
30 }
31
32 pub fn staircase(mut self, staircase: bool) -> Self {
34 self.staircase = staircase;
35 self
36 }
37}
38
39impl LearningRateSchedule for ExponentialDecay {
40 fn get_lr(&self, step: usize) -> f64 {
41 let progress = if self.staircase {
42 (step / self.decay_steps) as f64
43 } else {
44 step as f64 / self.decay_steps as f64
45 };
46 self.initial_lr * self.decay_rate.powf(progress)
47 }
48}
49
50pub struct PiecewiseConstantDecay {
52 boundaries: Vec<usize>,
54 values: Vec<f64>,
56}
57
58impl PiecewiseConstantDecay {
59 pub fn new(boundaries: Vec<usize>, values: Vec<f64>) -> Self {
61 Self { boundaries, values }
62 }
63}
64
65impl LearningRateSchedule for PiecewiseConstantDecay {
66 fn get_lr(&self, step: usize) -> f64 {
67 for (i, &boundary) in self.boundaries.iter().enumerate() {
68 if step < boundary {
69 return self.values[i];
70 }
71 }
72 *self.values.last().unwrap_or(&0.001)
73 }
74}
75
76pub struct PolynomialDecay {
78 initial_lr: f64,
80 decay_steps: usize,
82 end_lr: f64,
84 power: f64,
86}
87
88impl PolynomialDecay {
89 pub fn new(initial_lr: f64, decay_steps: usize, end_lr: f64, power: f64) -> Self {
91 Self {
92 initial_lr,
93 decay_steps,
94 end_lr,
95 power,
96 }
97 }
98}
99
100impl LearningRateSchedule for PolynomialDecay {
101 fn get_lr(&self, step: usize) -> f64 {
102 let step = step.min(self.decay_steps);
103 let decay = (1.0 - step as f64 / self.decay_steps as f64).powf(self.power);
104 (self.initial_lr - self.end_lr) * decay + self.end_lr
105 }
106}
107
108pub struct CosineDecay {
110 initial_lr: f64,
112 decay_steps: usize,
114 alpha: f64,
116}
117
118impl CosineDecay {
119 pub fn new(initial_lr: f64, decay_steps: usize) -> Self {
121 Self {
122 initial_lr,
123 decay_steps,
124 alpha: 0.0,
125 }
126 }
127
128 pub fn alpha(mut self, alpha: f64) -> Self {
130 self.alpha = alpha;
131 self
132 }
133}
134
135impl LearningRateSchedule for CosineDecay {
136 fn get_lr(&self, step: usize) -> f64 {
137 let step = step.min(self.decay_steps);
138 let progress = step as f64 / self.decay_steps as f64;
139 let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
140 self.initial_lr * (self.alpha + (1.0 - self.alpha) * cosine_decay)
141 }
142}