/// Decay curve used by [`GumbelTemperatureSchedule`] to lower the temperature
/// (`tau`) as the iteration index grows. Whatever the curve produces is
/// clamped from below at the schedule's `tau_min`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScheduleKind {
    /// `tau = tau_start * rate^iter`; `rate` must be finite and lie in the
    /// open interval (0, 1).
    Geometric { rate: f64 },
    /// Straight-line interpolation from `tau_start` (iteration 0) to `tau_min`
    /// (iteration `steps`), constant at `tau_min` afterwards; `steps` must be
    /// non-zero.
    Linear { steps: usize },
    /// `tau = tau_start / (1 + iter)`.
    ReciprocalIter,
}
8
9#[derive(Debug, Clone)]
17pub struct GumbelTemperatureSchedule {
18 pub tau_start: f64,
19 pub tau_min: f64,
20 pub decay: ScheduleKind,
21 pub iter_count: usize,
22}
23
24impl GumbelTemperatureSchedule {
25 #[must_use = "build error must be handled"]
26 pub fn new(tau_start: f64, tau_min: f64, decay: ScheduleKind) -> Result<Self, String> {
27 let sched = Self {
28 tau_start,
29 tau_min,
30 decay,
31 iter_count: 0,
32 };
33 sched.validate()?;
34 Ok(sched)
35 }
36
37 pub fn validate(&self) -> Result<(), String> {
38 if !(self.tau_start.is_finite() && self.tau_start > 0.0) {
39 return Err(format!(
40 "GumbelTemperatureSchedule: tau_start must be finite and positive; got {}",
41 self.tau_start
42 ));
43 }
44 if !(self.tau_min.is_finite() && self.tau_min > 0.0) {
45 return Err(format!(
46 "GumbelTemperatureSchedule: tau_min must be finite and positive; got {}",
47 self.tau_min
48 ));
49 }
50 if self.tau_min > self.tau_start {
51 return Err(format!(
52 "GumbelTemperatureSchedule: tau_min ({}) cannot exceed tau_start ({})",
53 self.tau_min, self.tau_start
54 ));
55 }
56 match self.decay {
57 ScheduleKind::Geometric { rate } => {
58 if !(rate.is_finite() && rate > 0.0 && rate < 1.0) {
59 return Err(format!(
60 "GumbelTemperatureSchedule::Geometric: rate must be in (0, 1); got {rate}"
61 ));
62 }
63 }
64 ScheduleKind::Linear { steps } => {
65 if steps == 0 {
66 return Err("GumbelTemperatureSchedule::Linear: steps must be positive".into());
67 }
68 }
69 ScheduleKind::ReciprocalIter => {}
70 }
71 Ok(())
72 }
73
74 pub fn current_tau(&self, iter: usize) -> f64 {
75 let raw = match self.decay {
76 ScheduleKind::Geometric { rate } => self.tau_start * rate.powf(iter as f64),
77 ScheduleKind::Linear { steps } => {
78 if iter >= steps {
79 self.tau_min
80 } else {
81 let frac = iter as f64 / steps as f64;
82 self.tau_start + frac * (self.tau_min - self.tau_start)
83 }
84 }
85 ScheduleKind::ReciprocalIter => self.tau_start / (1.0 + iter as f64),
86 };
87 raw.max(self.tau_min)
88 }
89
90 pub fn step(&mut self) -> f64 {
91 let tau = self.current_tau(self.iter_count);
92 self.iter_count += 1;
93 tau
94 }
95}
96
/// Search mode: either no search ([`SearchStrategy::Fixed`]) or an explicit
/// list of values to sweep over.
///
/// NOTE(review): which quantity is searched is decided by callers that are
/// not visible here (presumably a temperature, given the neighbouring
/// schedule type) — confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum SearchStrategy {
    /// No sweep; `sweep_values` reports `None`.
    Fixed,
    /// Sweep over `values` in the stored order. Nothing here checks that the
    /// values actually form an exponential grid.
    ExponentialSweep { values: Vec<f64> },
}
102
103impl SearchStrategy {
104 #[must_use]
105 pub fn is_fixed(&self) -> bool {
106 matches!(self, Self::Fixed)
107 }
108
109 #[must_use]
110 pub fn sweep_values(&self) -> Option<&[f64]> {
111 match self {
112 Self::Fixed => None,
113 Self::ExponentialSweep { values } => Some(values),
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Geometric schedule from 1.0 down to 0.01 with the given rate.
    fn geometric(rate: f64) -> GumbelTemperatureSchedule {
        GumbelTemperatureSchedule::new(1.0, 0.01, ScheduleKind::Geometric { rate }).unwrap()
    }

    // ---- construction / validation ----

    #[test]
    fn new_ok_for_valid_geometric() {
        assert!(
            GumbelTemperatureSchedule::new(1.0, 0.1, ScheduleKind::Geometric { rate: 0.9 }).is_ok()
        );
    }

    #[test]
    fn new_err_for_non_positive_tau_start() {
        assert!(GumbelTemperatureSchedule::new(0.0, 0.1, ScheduleKind::ReciprocalIter).is_err());
        assert!(GumbelTemperatureSchedule::new(-1.0, 0.1, ScheduleKind::ReciprocalIter).is_err());
        assert!(
            GumbelTemperatureSchedule::new(f64::NAN, 0.1, ScheduleKind::ReciprocalIter).is_err()
        );
        assert!(
            GumbelTemperatureSchedule::new(f64::INFINITY, 0.1, ScheduleKind::ReciprocalIter)
                .is_err()
        );
    }

    #[test]
    fn new_err_for_invalid_tau_min() {
        for tau_min in [0.0, -0.1, f64::NAN, f64::INFINITY] {
            assert!(
                GumbelTemperatureSchedule::new(1.0, tau_min, ScheduleKind::ReciprocalIter)
                    .is_err(),
                "tau_min = {} should be rejected",
                tau_min
            );
        }
    }

    #[test]
    fn new_err_for_tau_min_exceeds_tau_start() {
        assert!(
            GumbelTemperatureSchedule::new(0.5, 1.0, ScheduleKind::Geometric { rate: 0.9 })
                .is_err()
        );
    }

    #[test]
    fn new_ok_when_tau_min_equals_tau_start() {
        assert!(GumbelTemperatureSchedule::new(1.0, 1.0, ScheduleKind::ReciprocalIter).is_ok());
    }

    #[test]
    fn new_err_for_geometric_rate_out_of_range() {
        for rate in [1.0, 0.0, -0.5, 1.5, f64::NAN] {
            assert!(
                GumbelTemperatureSchedule::new(1.0, 0.1, ScheduleKind::Geometric { rate })
                    .is_err(),
                "rate = {} should be rejected",
                rate
            );
        }
    }

    #[test]
    fn new_err_for_linear_zero_steps() {
        assert!(
            GumbelTemperatureSchedule::new(1.0, 0.1, ScheduleKind::Linear { steps: 0 }).is_err()
        );
    }

    // ---- geometric ----

    #[test]
    fn geometric_iter_zero_returns_tau_start() {
        let s = geometric(0.5);
        assert!((s.current_tau(0) - 1.0).abs() < 1e-14);
    }

    #[test]
    fn geometric_decays_by_rate_each_step() {
        let s = geometric(0.5);
        assert!((s.current_tau(2) - 0.25).abs() < 1e-12);
    }

    #[test]
    fn geometric_clamps_at_tau_min() {
        let s = GumbelTemperatureSchedule::new(1.0, 0.5, ScheduleKind::Geometric { rate: 0.1 })
            .unwrap();
        assert!((s.current_tau(5) - 0.5).abs() < 1e-14);
    }

    // ---- linear ----

    #[test]
    fn linear_iter_zero_returns_tau_start() {
        let s =
            GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(0) - 2.0).abs() < 1e-14);
    }

    #[test]
    fn linear_midpoint_interpolates() {
        let s =
            GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        // Halfway between 2.0 and 0.5.
        assert!((s.current_tau(5) - 1.25).abs() < 1e-12);
    }

    #[test]
    fn linear_at_steps_returns_tau_min() {
        let s =
            GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(10) - 0.5).abs() < 1e-14);
    }

    #[test]
    fn linear_past_steps_stays_at_tau_min() {
        let s =
            GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(1_000) - 0.5).abs() < 1e-14);
    }

    // ---- reciprocal ----

    #[test]
    fn reciprocal_iter_zero_returns_tau_start() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        assert!((s.current_tau(0) - 4.0).abs() < 1e-14);
    }

    #[test]
    fn reciprocal_iter_one_halves_tau_start() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        assert!((s.current_tau(1) - 2.0).abs() < 1e-14);
    }

    #[test]
    fn reciprocal_clamps_at_tau_min() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        // 4 / 1001 is well below the 0.1 floor.
        assert!((s.current_tau(1_000) - 0.1).abs() < 1e-14);
    }

    // ---- stepping ----

    #[test]
    fn step_increments_iter_count() {
        let mut s = geometric(0.5);
        assert_eq!(s.iter_count, 0);
        s.step();
        assert_eq!(s.iter_count, 1);
        s.step();
        assert_eq!(s.iter_count, 2);
    }

    #[test]
    fn step_returns_tau_before_advancing() {
        let mut s = geometric(0.5);
        assert!((s.step() - 1.0).abs() < 1e-14);
        assert!((s.step() - 0.5).abs() < 1e-14);
        assert!((s.step() - 0.25).abs() < 1e-14);
        assert_eq!(s.iter_count, 3);
    }

    // ---- search strategy ----

    #[test]
    fn fixed_is_fixed_and_has_no_sweep_values() {
        let s = SearchStrategy::Fixed;
        assert!(s.is_fixed());
        assert!(s.sweep_values().is_none());
    }

    #[test]
    fn exponential_sweep_is_not_fixed_and_returns_values() {
        let s = SearchStrategy::ExponentialSweep {
            values: vec![1.0, 2.0, 3.0],
        };
        assert!(!s.is_fixed());
        assert_eq!(s.sweep_values().unwrap(), &[1.0, 2.0, 3.0]);
    }
}