quantrs2_ml/keras_api/
schedules.rs

//! Learning rate schedules for Keras-like API

/// A schedule mapping a global training step to a learning rate.
///
/// Implementors must be `Send + Sync` so a schedule can be shared across
/// threads (e.g. stored inside an optimizer driven from a parallel
/// training loop).
pub trait LearningRateSchedule: Send + Sync {
    /// Returns the learning rate to use at the given (0-based) step.
    fn get_lr(&self, step: usize) -> f64;
}

/// Exponentially decaying learning rate.
///
/// Starts at `initial_lr` and is multiplied by one factor of `decay_rate`
/// per `decay_steps` steps — smoothly by default, or in discrete jumps
/// when `staircase` is enabled.
pub struct ExponentialDecay {
    /// Learning rate at step 0.
    initial_lr: f64,
    /// Number of steps over which one factor of `decay_rate` is applied.
    decay_steps: usize,
    /// Multiplicative decay factor per `decay_steps` steps.
    decay_rate: f64,
    /// When true, decay jumps only at whole multiples of `decay_steps`.
    staircase: bool,
}

impl ExponentialDecay {
    /// Builds a smooth (non-staircase) exponential decay schedule.
    pub fn new(initial_lr: f64, decay_steps: usize, decay_rate: f64) -> Self {
        Self { initial_lr, decay_steps, decay_rate, staircase: false }
    }

    /// Builder-style toggle for staircase (step-wise) decay.
    pub fn staircase(mut self, staircase: bool) -> Self {
        self.staircase = staircase;
        self
    }
}

39impl LearningRateSchedule for ExponentialDecay {
40    fn get_lr(&self, step: usize) -> f64 {
41        let progress = if self.staircase {
42            (step / self.decay_steps) as f64
43        } else {
44            step as f64 / self.decay_steps as f64
45        };
46        self.initial_lr * self.decay_rate.powf(progress)
47    }
48}
49
/// Piecewise constant learning rate.
///
/// `values` holds the rate for each interval delimited by `boundaries`:
/// `values[0]` up to the first boundary, `values[i]` between successive
/// boundaries, and `values.last()` after the final boundary. `values`
/// must therefore contain exactly one more element than `boundaries`.
pub struct PiecewiseConstantDecay {
    /// Step boundaries, expected in increasing order.
    boundaries: Vec<usize>,
    /// Learning rate per interval; `len() == boundaries.len() + 1`.
    values: Vec<f64>,
}

impl PiecewiseConstantDecay {
    /// Creates a piecewise constant schedule.
    ///
    /// # Panics
    ///
    /// Panics when `values.len() != boundaries.len() + 1`, mirroring the
    /// `ValueError` Keras raises for mismatched lengths. Previously a
    /// mismatch was silently accepted and could index out of bounds (or
    /// fall back to a hard-coded rate) at lookup time.
    pub fn new(boundaries: Vec<usize>, values: Vec<f64>) -> Self {
        assert_eq!(
            values.len(),
            boundaries.len() + 1,
            "PiecewiseConstantDecay: values must have exactly one more element than boundaries"
        );
        Self { boundaries, values }
    }
}

65impl LearningRateSchedule for PiecewiseConstantDecay {
66    fn get_lr(&self, step: usize) -> f64 {
67        for (i, &boundary) in self.boundaries.iter().enumerate() {
68            if step < boundary {
69                return self.values[i];
70            }
71        }
72        *self.values.last().unwrap_or(&0.001)
73    }
74}
75
/// Polynomially decaying learning rate.
///
/// Interpolates from `initial_lr` down to `end_lr` over `decay_steps`
/// steps following `(1 - step / decay_steps)^power`, then holds `end_lr`.
pub struct PolynomialDecay {
    /// Learning rate at step 0.
    initial_lr: f64,
    /// Number of steps over which to decay.
    decay_steps: usize,
    /// Learning rate reached at (and held after) `decay_steps`.
    end_lr: f64,
    /// Exponent of the decay curve (1.0 = linear).
    power: f64,
}

impl PolynomialDecay {
    /// Creates a polynomial decay schedule.
    pub fn new(initial_lr: f64, decay_steps: usize, end_lr: f64, power: f64) -> Self {
        Self { initial_lr, decay_steps, end_lr, power }
    }
}

100impl LearningRateSchedule for PolynomialDecay {
101    fn get_lr(&self, step: usize) -> f64 {
102        let step = step.min(self.decay_steps);
103        let decay = (1.0 - step as f64 / self.decay_steps as f64).powf(self.power);
104        (self.initial_lr - self.end_lr) * decay + self.end_lr
105    }
106}
107
/// Cosine-annealed learning rate.
///
/// Follows half a cosine wave from `initial_lr` at step 0 down to
/// `alpha * initial_lr` at `decay_steps`, holding that floor afterwards.
pub struct CosineDecay {
    /// Learning rate at step 0.
    initial_lr: f64,
    /// Number of steps over which to anneal.
    decay_steps: usize,
    /// Minimum rate expressed as a fraction of `initial_lr`.
    alpha: f64,
}

impl CosineDecay {
    /// Creates a cosine decay schedule that anneals all the way to zero.
    pub fn new(initial_lr: f64, decay_steps: usize) -> Self {
        Self { initial_lr, decay_steps, alpha: 0.0 }
    }

    /// Builder-style setter for the minimum-rate fraction `alpha`.
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }
}

135impl LearningRateSchedule for CosineDecay {
136    fn get_lr(&self, step: usize) -> f64 {
137        let step = step.min(self.decay_steps);
138        let progress = step as f64 / self.decay_steps as f64;
139        let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
140        self.initial_lr * (self.alpha + (1.0 - self.alpha) * cosine_decay)
141    }
142}