/// Decay law used by [`GumbelTemperatureSchedule`] to lower the temperature as
/// iterations advance.
///
/// Every formula below produces a raw value that the schedule then clamps from
/// below to `tau_min`. The variants carry only plain numeric parameters, so the
/// enum is cheap to copy and compare; `Copy` and `PartialEq` are derived for that.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScheduleKind {
    /// `tau(t) = tau_start * rate^t`; validation requires a finite `rate` in `(0, 1)`.
    Geometric { rate: f64 },
    /// Straight-line interpolation from `tau_start` at `t = 0` to `tau_min` at
    /// `t = steps`, holding at `tau_min` afterwards; validation requires `steps > 0`.
    Linear { steps: usize },
    /// `tau(t) = tau_start / (1 + t)`.
    ReciprocalIter,
}
8
9#[derive(Debug, Clone)]
17pub struct GumbelTemperatureSchedule {
18 pub tau_start: f64,
19 pub tau_min: f64,
20 pub decay: ScheduleKind,
21 pub iter_count: usize,
22}
23
24impl GumbelTemperatureSchedule {
25 #[must_use = "build error must be handled"]
26 pub fn new(tau_start: f64, tau_min: f64, decay: ScheduleKind) -> Result<Self, String> {
27 let sched = Self {
28 tau_start,
29 tau_min,
30 decay,
31 iter_count: 0,
32 };
33 sched.validate()?;
34 Ok(sched)
35 }
36
37 pub fn validate(&self) -> Result<(), String> {
38 if !(self.tau_start.is_finite() && self.tau_start > 0.0) {
39 return Err(format!(
40 "GumbelTemperatureSchedule: tau_start must be finite and positive; got {}",
41 self.tau_start
42 ));
43 }
44 if !(self.tau_min.is_finite() && self.tau_min > 0.0) {
45 return Err(format!(
46 "GumbelTemperatureSchedule: tau_min must be finite and positive; got {}",
47 self.tau_min
48 ));
49 }
50 if self.tau_min > self.tau_start {
51 return Err(format!(
52 "GumbelTemperatureSchedule: tau_min ({}) cannot exceed tau_start ({})",
53 self.tau_min, self.tau_start
54 ));
55 }
56 match self.decay {
57 ScheduleKind::Geometric { rate } => {
58 if !(rate.is_finite() && rate > 0.0 && rate < 1.0) {
59 return Err(format!(
60 "GumbelTemperatureSchedule::Geometric: rate must be in (0, 1); got {rate}"
61 ));
62 }
63 }
64 ScheduleKind::Linear { steps } => {
65 if steps == 0 {
66 return Err("GumbelTemperatureSchedule::Linear: steps must be positive".into());
67 }
68 }
69 ScheduleKind::ReciprocalIter => {}
70 }
71 Ok(())
72 }
73
74 pub fn current_tau(&self, iter: usize) -> f64 {
75 let raw = match self.decay {
76 ScheduleKind::Geometric { rate } => self.tau_start * rate.powf(iter as f64),
77 ScheduleKind::Linear { steps } => {
78 if iter >= steps {
79 self.tau_min
80 } else {
81 let frac = iter as f64 / steps as f64;
82 self.tau_start + frac * (self.tau_min - self.tau_start)
83 }
84 }
85 ScheduleKind::ReciprocalIter => self.tau_start / (1.0 + iter as f64),
86 };
87 raw.max(self.tau_min)
88 }
89
90 pub fn step(&mut self) -> f64 {
91 let tau = self.current_tau(self.iter_count);
92 self.iter_count += 1;
93 tau
94 }
95}
96
/// Whether a search uses one fixed setting or sweeps over an explicit list of values.
///
/// The sweep values are supplied by the caller as-is. The variant name suggests
/// exponentially spaced values, but nothing in this type enforces that.
#[derive(Clone, PartialEq, Debug)]
pub enum SearchStrategy {
    /// No sweep; `sweep_values` reports `None` for this variant.
    Fixed,
    /// Sweep over the listed `values`, in the order given.
    ExponentialSweep { values: Vec<f64> },
}
102
103impl SearchStrategy {
104 #[must_use]
105 pub fn is_fixed(&self) -> bool {
106 matches!(self, Self::Fixed)
107 }
108
109 #[must_use]
110 pub fn sweep_values(&self) -> Option<&[f64]> {
111 match self {
112 Self::Fixed => None,
113 Self::ExponentialSweep { values } => Some(values),
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Geometric schedule from 1.0 down to a floor of 0.01 with the given rate.
    fn geometric(rate: f64) -> GumbelTemperatureSchedule {
        GumbelTemperatureSchedule::new(1.0, 0.01, ScheduleKind::Geometric { rate }).unwrap()
    }

    // ---- construction / validation ----

    #[test]
    fn new_ok_for_valid_geometric() {
        assert!(GumbelTemperatureSchedule::new(
            1.0,
            0.1,
            ScheduleKind::Geometric { rate: 0.9 }
        )
        .is_ok());
    }

    #[test]
    fn new_ok_when_tau_min_equals_tau_start() {
        let s = GumbelTemperatureSchedule::new(1.0, 1.0, ScheduleKind::Geometric { rate: 0.5 })
            .unwrap();
        // The floor equals the start, so the temperature never moves.
        assert!((s.current_tau(3) - 1.0).abs() < 1e-14);
    }

    #[test]
    fn new_err_for_non_positive_tau_start() {
        assert!(GumbelTemperatureSchedule::new(0.0, 0.1, ScheduleKind::ReciprocalIter).is_err());
        assert!(GumbelTemperatureSchedule::new(-1.0, 0.1, ScheduleKind::ReciprocalIter).is_err());
        assert!(GumbelTemperatureSchedule::new(f64::NAN, 0.1, ScheduleKind::ReciprocalIter)
            .is_err());
        assert!(
            GumbelTemperatureSchedule::new(f64::INFINITY, 0.1, ScheduleKind::ReciprocalIter)
                .is_err()
        );
    }

    #[test]
    fn new_err_for_invalid_tau_min() {
        assert!(GumbelTemperatureSchedule::new(1.0, 0.0, ScheduleKind::ReciprocalIter).is_err());
        assert!(GumbelTemperatureSchedule::new(1.0, -0.1, ScheduleKind::ReciprocalIter).is_err());
        assert!(GumbelTemperatureSchedule::new(1.0, f64::NAN, ScheduleKind::ReciprocalIter)
            .is_err());
    }

    #[test]
    fn new_err_for_tau_min_exceeds_tau_start() {
        let err = GumbelTemperatureSchedule::new(
            0.5,
            1.0,
            ScheduleKind::Geometric { rate: 0.9 }
        )
        .unwrap_err();
        assert!(err.contains("cannot exceed"));
    }

    #[test]
    fn new_err_for_geometric_rate_out_of_range() {
        for rate in [1.0, 0.0, -0.5, 1.5, f64::NAN, f64::INFINITY] {
            assert!(
                GumbelTemperatureSchedule::new(1.0, 0.1, ScheduleKind::Geometric { rate })
                    .is_err(),
                "rate {rate} should be rejected"
            );
        }
    }

    #[test]
    fn new_err_for_linear_zero_steps() {
        assert!(
            GumbelTemperatureSchedule::new(1.0, 0.1, ScheduleKind::Linear { steps: 0 }).is_err()
        );
    }

    #[test]
    fn validate_detects_invalid_mutation() {
        let mut s = geometric(0.5);
        assert!(s.validate().is_ok());
        // Public fields can be changed after construction; validate must catch it.
        s.tau_min = 2.0;
        assert!(s.validate().is_err());
        s.tau_min = 0.01;
        s.decay = ScheduleKind::Linear { steps: 0 };
        assert!(s.validate().is_err());
    }

    // ---- geometric ----

    #[test]
    fn geometric_iter_zero_returns_tau_start() {
        let s = geometric(0.5);
        assert!((s.current_tau(0) - 1.0).abs() < 1e-14);
    }

    #[test]
    fn geometric_decays_by_rate_each_step() {
        let s = geometric(0.5);
        assert!((s.current_tau(1) - 0.5).abs() < 1e-12);
        assert!((s.current_tau(2) - 0.25).abs() < 1e-12);
    }

    #[test]
    fn geometric_clamps_at_tau_min() {
        let s = GumbelTemperatureSchedule::new(
            1.0,
            0.5,
            ScheduleKind::Geometric { rate: 0.1 },
        )
        .unwrap();
        assert!((s.current_tau(5) - 0.5).abs() < 1e-14);
    }

    // ---- linear ----

    #[test]
    fn linear_iter_zero_returns_tau_start() {
        let s = GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(0) - 2.0).abs() < 1e-14);
    }

    #[test]
    fn linear_interpolates_at_midpoint() {
        let s = GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        // Halfway between 2.0 and 0.5.
        assert!((s.current_tau(5) - 1.25).abs() < 1e-12);
    }

    #[test]
    fn linear_at_steps_returns_tau_min() {
        let s = GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(10) - 0.5).abs() < 1e-14);
    }

    #[test]
    fn linear_past_steps_stays_at_tau_min() {
        let s = GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(1000) - 0.5).abs() < 1e-14);
    }

    // ---- reciprocal ----

    #[test]
    fn reciprocal_iter_zero_returns_tau_start() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        assert!((s.current_tau(0) - 4.0).abs() < 1e-14);
    }

    #[test]
    fn reciprocal_iter_one_halves_tau_start() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        assert!((s.current_tau(1) - 2.0).abs() < 1e-14);
    }

    #[test]
    fn reciprocal_clamps_at_tau_min() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        // 4 / 1001 is well below the 0.1 floor.
        assert!((s.current_tau(1000) - 0.1).abs() < 1e-14);
    }

    // ---- step ----

    #[test]
    fn step_increments_iter_count() {
        let mut s = geometric(0.5);
        assert_eq!(s.iter_count, 0);
        s.step();
        assert_eq!(s.iter_count, 1);
        s.step();
        assert_eq!(s.iter_count, 2);
    }

    #[test]
    fn step_returns_tau_for_current_iter_before_advancing() {
        let mut s = geometric(0.5);
        assert!((s.step() - 1.0).abs() < 1e-14);
        assert!((s.step() - 0.5).abs() < 1e-14);
        assert!((s.step() - 0.25).abs() < 1e-14);
        assert_eq!(s.iter_count, 3);
    }

    // ---- SearchStrategy ----

    #[test]
    fn fixed_is_fixed_and_has_no_sweep_values() {
        let s = SearchStrategy::Fixed;
        assert!(s.is_fixed());
        assert!(s.sweep_values().is_none());
    }

    #[test]
    fn exponential_sweep_is_not_fixed_and_returns_values() {
        let s = SearchStrategy::ExponentialSweep { values: vec![1.0, 2.0, 3.0] };
        assert!(!s.is_fixed());
        assert_eq!(s.sweep_values().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn empty_exponential_sweep_is_still_not_fixed() {
        let s = SearchStrategy::ExponentialSweep { values: vec![] };
        assert!(!s.is_fixed());
        assert_eq!(s.sweep_values(), Some(&[][..]));
    }
}