/// Decay law used by `GumbelTemperatureSchedule` to lower the temperature per iteration.
///
/// Every law's output is floored at the schedule's `tau_min`.
#[derive(Debug, Clone, PartialEq)]
pub enum ScheduleKind {
    /// `tau_start * rate^iter`; `rate` must be finite and lie in the open interval `(0, 1)`.
    Geometric { rate: f64 },
    /// Straight line from `tau_start` (iteration 0) to `tau_min` (iteration `steps`),
    /// then held at `tau_min`; `steps` must be non-zero.
    Linear { steps: usize },
    /// `tau_start / (1 + iter)`.
    ReciprocalIter,
}
8
9#[derive(Debug, Clone)]
17pub struct GumbelTemperatureSchedule {
18 pub tau_start: f64,
19 pub tau_min: f64,
20 pub decay: ScheduleKind,
21 pub iter_count: usize,
22}
23
24impl GumbelTemperatureSchedule {
25 #[must_use = "build error must be handled"]
26 pub fn new(tau_start: f64, tau_min: f64, decay: ScheduleKind) -> Result<Self, String> {
27 let sched = Self {
28 tau_start,
29 tau_min,
30 decay,
31 iter_count: 0,
32 };
33 sched.validate()?;
34 Ok(sched)
35 }
36
37 pub fn validate(&self) -> Result<(), String> {
38 if !(self.tau_start.is_finite() && self.tau_start > 0.0) {
39 return Err(format!(
40 "GumbelTemperatureSchedule: tau_start must be finite and positive; got {}",
41 self.tau_start
42 ));
43 }
44 if !(self.tau_min.is_finite() && self.tau_min > 0.0) {
45 return Err(format!(
46 "GumbelTemperatureSchedule: tau_min must be finite and positive; got {}",
47 self.tau_min
48 ));
49 }
50 if self.tau_min > self.tau_start {
51 return Err(format!(
52 "GumbelTemperatureSchedule: tau_min ({}) cannot exceed tau_start ({})",
53 self.tau_min, self.tau_start
54 ));
55 }
56 match self.decay {
57 ScheduleKind::Geometric { rate } => {
58 if !(rate.is_finite() && rate > 0.0 && rate < 1.0) {
59 return Err(format!(
60 "GumbelTemperatureSchedule::Geometric: rate must be in (0, 1); got {rate}"
61 ));
62 }
63 }
64 ScheduleKind::Linear { steps } => {
65 if steps == 0 {
66 return Err("GumbelTemperatureSchedule::Linear: steps must be positive".into());
67 }
68 }
69 ScheduleKind::ReciprocalIter => {}
70 }
71 Ok(())
72 }
73
74 pub fn current_tau(&self, iter: usize) -> f64 {
75 let raw = match self.decay {
76 ScheduleKind::Geometric { rate } => self.tau_start * rate.powf(iter as f64),
77 ScheduleKind::Linear { steps } => {
78 if iter >= steps {
79 self.tau_min
80 } else {
81 let frac = iter as f64 / steps as f64;
82 self.tau_start + frac * (self.tau_min - self.tau_start)
83 }
84 }
85 ScheduleKind::ReciprocalIter => self.tau_start / (1.0 + iter as f64),
86 };
87 raw.max(self.tau_min)
88 }
89
90 pub fn step(&mut self) -> f64 {
91 let tau = self.current_tau(self.iter_count);
92 self.iter_count += 1;
93 tau
94 }
95}
96
/// Either a single fixed setting or a sweep over an explicit list of `f64` candidates.
///
/// NOTE(review): nothing here enforces exponential spacing of `values`; the variant name
/// suggests the producer builds them that way — confirm with callers.
#[derive(Clone, Debug, PartialEq)]
pub enum SearchStrategy {
    /// No sweep; carries no candidate values.
    Fixed,
    /// Sweep over the given candidate values, in the stored order.
    ExponentialSweep { values: Vec<f64> },
}
102
103impl SearchStrategy {
104 #[must_use]
105 pub fn is_fixed(&self) -> bool {
106 matches!(self, Self::Fixed)
107 }
108
109 #[must_use]
110 pub fn sweep_values(&self) -> Option<&[f64]> {
111 match self {
112 Self::Fixed => None,
113 Self::ExponentialSweep { values } => Some(values),
114 }
115 }
116}