// gam_problem/schedule.rs

/// Decay law for deterministic Gumbel/concrete assignment temperature.
///
/// Each variant maps an iteration index `t` (starting at 0) to a raw
/// temperature; [`GumbelTemperatureSchedule::current_tau`] then clamps that
/// raw value from below at `tau_min`.
#[derive(Debug, Clone, PartialEq)]
pub enum ScheduleKind {
    /// `tau_start * rate^t`; `rate` must lie strictly inside `(0, 1)`.
    Geometric { rate: f64 },
    /// Straight line from `tau_start` at `t = 0` to `tau_min` at `t = steps`,
    /// constant at `tau_min` afterwards; `steps` must be positive.
    Linear { steps: usize },
    /// `tau_start / (1 + t)`.
    ReciprocalIter,
}

/// Outer-state temperature annealing for SAE assignment relaxations.
///
/// Annealing drives the continuous concrete/softmax assignment toward the
/// discrete argmax or IBP active-set solution while PIRLS solves smooth
/// positive-temperature subproblems. In the zero-floor limit, softmax becomes
/// argmax and the IBP-MAP sigmoid active set becomes exact; a positive
/// `tau_min` optimizes the corresponding near-discrete MAP problem.
///
/// Build with [`GumbelTemperatureSchedule::new`], which validates the
/// parameters. The fields are public, so a schedule mutated after
/// construction can be re-checked with [`GumbelTemperatureSchedule::validate`].
#[derive(Debug, Clone, PartialEq)]
pub struct GumbelTemperatureSchedule {
    /// Temperature at iteration 0; must be finite and positive.
    pub tau_start: f64,
    /// Floor on every returned temperature; finite, positive, and `<= tau_start`.
    pub tau_min: f64,
    /// Decay law applied between `tau_start` and `tau_min`.
    pub decay: ScheduleKind,
    /// Iteration that the next [`GumbelTemperatureSchedule::step`] call evaluates;
    /// incremented (saturating at `usize::MAX`) by each `step`.
    pub iter_count: usize,
}

impl GumbelTemperatureSchedule {
    /// Builds a schedule positioned at iteration 0 after validating its parameters.
    ///
    /// # Errors
    ///
    /// Returns the message produced by [`Self::validate`] when any parameter is
    /// out of range.
    #[must_use = "build error must be handled"]
    pub fn new(tau_start: f64, tau_min: f64, decay: ScheduleKind) -> Result<Self, String> {
        let sched = Self {
            tau_start,
            tau_min,
            decay,
            iter_count: 0,
        };
        sched.validate()?;
        Ok(sched)
    }

    /// Checks that the schedule's parameters describe a well-defined decay.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when `tau_start` or `tau_min` is not finite
    /// and positive, when `tau_min > tau_start`, when a geometric `rate` is not
    /// strictly inside `(0, 1)`, or when a linear schedule has `steps == 0`.
    pub fn validate(&self) -> Result<(), String> {
        // `!(x.is_finite() && x > 0.0)` also rejects NaN, which fails every comparison.
        if !(self.tau_start.is_finite() && self.tau_start > 0.0) {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_start must be finite and positive; got {}",
                self.tau_start
            ));
        }
        if !(self.tau_min.is_finite() && self.tau_min > 0.0) {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_min must be finite and positive; got {}",
                self.tau_min
            ));
        }
        if self.tau_min > self.tau_start {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_min ({}) cannot exceed tau_start ({})",
                self.tau_min, self.tau_start
            ));
        }
        match self.decay {
            ScheduleKind::Geometric { rate } => {
                if !(rate.is_finite() && rate > 0.0 && rate < 1.0) {
                    return Err(format!(
                        "GumbelTemperatureSchedule::Geometric: rate must be in (0, 1); got {rate}"
                    ));
                }
            }
            ScheduleKind::Linear { steps } => {
                if steps == 0 {
                    return Err("GumbelTemperatureSchedule::Linear: steps must be positive".into());
                }
            }
            ScheduleKind::ReciprocalIter => {}
        }
        Ok(())
    }

    /// Temperature at iteration `iter`, clamped from below at `tau_min`.
    ///
    /// Pure: it neither reads nor modifies `iter_count`, so any iteration can be
    /// queried directly.
    #[must_use]
    pub fn current_tau(&self, iter: usize) -> f64 {
        let raw = match self.decay {
            ScheduleKind::Geometric { rate } => self.tau_start * rate.powf(iter as f64),
            ScheduleKind::Linear { steps } => {
                if iter >= steps {
                    self.tau_min
                } else {
                    // frac is in [0, 1), so this stays between tau_min and tau_start.
                    let frac = iter as f64 / steps as f64;
                    self.tau_start + frac * (self.tau_min - self.tau_start)
                }
            }
            ScheduleKind::ReciprocalIter => self.tau_start / (1.0 + iter as f64),
        };
        raw.max(self.tau_min)
    }

    /// Returns the temperature for the current `iter_count`, then advances it.
    ///
    /// The counter saturates at `usize::MAX` instead of overflowing (which would
    /// panic in debug builds); once saturated, the returned temperature stops
    /// changing.
    pub fn step(&mut self) -> f64 {
        let tau = self.current_tau(self.iter_count);
        self.iter_count = self.iter_count.saturating_add(1);
        tau
    }
}
96
/// How a tunable value is chosen: a single `Fixed` value, or an
/// `ExponentialSweep` over an explicit list of candidates.
///
/// NOTE(review): the variant name suggests exponentially spaced candidates;
/// this type stores whatever `values` it is given and does not check spacing.
#[derive(Debug, Clone, PartialEq)]
pub enum SearchStrategy {
    /// No sweep; there are no candidate values to iterate over.
    Fixed,
    /// Sweep over the given candidate values, kept in the order supplied.
    ExponentialSweep { values: Vec<f64> },
}

impl SearchStrategy {
    /// `true` exactly when this is [`SearchStrategy::Fixed`].
    #[must_use]
    pub fn is_fixed(&self) -> bool {
        match self {
            Self::Fixed => true,
            Self::ExponentialSweep { .. } => false,
        }
    }

    /// Borrows the candidate list of an [`SearchStrategy::ExponentialSweep`]
    /// (possibly empty); `None` for [`SearchStrategy::Fixed`].
    #[must_use]
    pub fn sweep_values(&self) -> Option<&[f64]> {
        if let Self::ExponentialSweep { values } = self {
            Some(values.as_slice())
        } else {
            None
        }
    }
}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Schedule from `tau_start = 1.0` down to `tau_min = 0.01` with the given geometric rate.
    fn geometric(rate: f64) -> GumbelTemperatureSchedule {
        GumbelTemperatureSchedule::new(1.0, 0.01, ScheduleKind::Geometric { rate }).unwrap()
    }

    // ── GumbelTemperatureSchedule validation ──────────────────────────────────

    #[test]
    fn new_ok_for_valid_geometric() {
        assert!(GumbelTemperatureSchedule::new(
            1.0,
            0.1,
            ScheduleKind::Geometric { rate: 0.9 }
        )
        .is_ok());
    }

    #[test]
    fn new_err_for_non_positive_tau_start() {
        assert!(GumbelTemperatureSchedule::new(0.0, 0.1, ScheduleKind::ReciprocalIter).is_err());
        assert!(GumbelTemperatureSchedule::new(f64::NAN, 0.1, ScheduleKind::ReciprocalIter)
            .is_err());
    }

    #[test]
    fn new_err_for_infinite_tau_start() {
        assert!(
            GumbelTemperatureSchedule::new(f64::INFINITY, 0.1, ScheduleKind::ReciprocalIter)
                .is_err()
        );
    }

    #[test]
    fn new_err_for_non_positive_or_nan_tau_min() {
        for tau_min in [0.0, -0.1, f64::NAN] {
            assert!(
                GumbelTemperatureSchedule::new(1.0, tau_min, ScheduleKind::ReciprocalIter)
                    .is_err(),
                "tau_min = {} should be rejected",
                tau_min
            );
        }
    }

    #[test]
    fn new_err_for_tau_min_exceeds_tau_start() {
        assert!(GumbelTemperatureSchedule::new(
            0.5,
            1.0,
            ScheduleKind::Geometric { rate: 0.9 }
        )
        .is_err());
    }

    #[test]
    fn new_ok_when_tau_min_equals_tau_start() {
        assert!(GumbelTemperatureSchedule::new(1.0, 1.0, ScheduleKind::ReciprocalIter).is_ok());
    }

    #[test]
    fn new_err_for_geometric_rate_out_of_range() {
        assert!(GumbelTemperatureSchedule::new(
            1.0,
            0.1,
            ScheduleKind::Geometric { rate: 1.0 }
        )
        .is_err());
        assert!(GumbelTemperatureSchedule::new(
            1.0,
            0.1,
            ScheduleKind::Geometric { rate: 0.0 }
        )
        .is_err());
    }

    #[test]
    fn new_err_for_geometric_rate_negative_or_nan() {
        for rate in [-0.5, f64::NAN] {
            assert!(
                GumbelTemperatureSchedule::new(1.0, 0.1, ScheduleKind::Geometric { rate })
                    .is_err(),
                "rate = {} should be rejected",
                rate
            );
        }
    }

    #[test]
    fn new_err_for_linear_zero_steps() {
        assert!(
            GumbelTemperatureSchedule::new(1.0, 0.1, ScheduleKind::Linear { steps: 0 }).is_err()
        );
    }

    #[test]
    fn validate_catches_fields_mutated_after_construction() {
        let mut s = geometric(0.5);
        // Public fields let callers break the invariant; validate() must notice.
        s.tau_min = 2.0;
        assert!(s.validate().is_err());
    }

    // ── current_tau: Geometric ────────────────────────────────────────────────

    #[test]
    fn geometric_iter_zero_returns_tau_start() {
        let s = geometric(0.5);
        assert!((s.current_tau(0) - 1.0).abs() < 1e-14);
    }

    #[test]
    fn geometric_decays_by_rate_each_step() {
        let s = geometric(0.5);
        // iter 2: 1.0 * 0.5^2 = 0.25
        assert!((s.current_tau(2) - 0.25).abs() < 1e-12);
    }

    #[test]
    fn geometric_clamps_at_tau_min() {
        let s = GumbelTemperatureSchedule::new(
            1.0,
            0.5,
            ScheduleKind::Geometric { rate: 0.1 },
        )
        .unwrap();
        // 1.0 * 0.1^5 = 1e-5 < tau_min=0.5 → clamped
        assert!((s.current_tau(5) - 0.5).abs() < 1e-14);
    }

    // ── current_tau: Linear ───────────────────────────────────────────────────

    #[test]
    fn linear_iter_zero_returns_tau_start() {
        let s = GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(0) - 2.0).abs() < 1e-14);
    }

    #[test]
    fn linear_midpoint_interpolates() {
        let s = GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        // 2.0 + (5 / 10) * (0.5 - 2.0) = 1.25
        assert!((s.current_tau(5) - 1.25).abs() < 1e-12);
    }

    #[test]
    fn linear_at_steps_returns_tau_min() {
        let s = GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(10) - 0.5).abs() < 1e-14);
    }

    #[test]
    fn linear_past_steps_stays_at_tau_min() {
        let s = GumbelTemperatureSchedule::new(2.0, 0.5, ScheduleKind::Linear { steps: 10 }).unwrap();
        assert!((s.current_tau(1_000) - 0.5).abs() < 1e-14);
    }

    // ── current_tau: ReciprocalIter ───────────────────────────────────────────

    #[test]
    fn reciprocal_iter_zero_returns_tau_start() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        assert!((s.current_tau(0) - 4.0).abs() < 1e-14);
    }

    #[test]
    fn reciprocal_iter_one_halves_tau_start() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        assert!((s.current_tau(1) - 2.0).abs() < 1e-14);
    }

    #[test]
    fn reciprocal_iter_clamps_at_tau_min() {
        let s = GumbelTemperatureSchedule::new(4.0, 0.1, ScheduleKind::ReciprocalIter).unwrap();
        // 4.0 / 1001 ≈ 0.004 < tau_min = 0.1 → clamped
        assert!((s.current_tau(1_000) - 0.1).abs() < 1e-14);
    }

    // ── current_tau: shared properties ────────────────────────────────────────

    #[test]
    fn current_tau_is_non_increasing_and_floored() {
        let kinds = [
            ScheduleKind::Geometric { rate: 0.8 },
            ScheduleKind::Linear { steps: 7 },
            ScheduleKind::ReciprocalIter,
        ];
        for kind in kinds {
            let s = GumbelTemperatureSchedule::new(3.0, 0.2, kind).unwrap();
            let mut prev = s.current_tau(0);
            for iter in 1..50 {
                let tau = s.current_tau(iter);
                assert!(tau <= prev, "{:?}: tau rose at iter {}", s.decay, iter);
                assert!(tau >= s.tau_min, "{:?}: tau below floor at iter {}", s.decay, iter);
                prev = tau;
            }
        }
    }

    // ── step() increments iter_count ──────────────────────────────────────────

    #[test]
    fn step_increments_iter_count() {
        let mut s = geometric(0.5);
        assert_eq!(s.iter_count, 0);
        s.step();
        assert_eq!(s.iter_count, 1);
        s.step();
        assert_eq!(s.iter_count, 2);
    }

    #[test]
    fn step_returns_tau_before_advancing() {
        let mut s = geometric(0.5);
        assert!((s.step() - 1.0).abs() < 1e-14);
        assert!((s.step() - 0.5).abs() < 1e-14);
        assert!((s.step() - 0.25).abs() < 1e-14);
        assert_eq!(s.iter_count, 3);
    }

    // ── SearchStrategy ────────────────────────────────────────────────────────

    #[test]
    fn fixed_is_fixed_and_has_no_sweep_values() {
        let s = SearchStrategy::Fixed;
        assert!(s.is_fixed());
        assert!(s.sweep_values().is_none());
    }

    #[test]
    fn exponential_sweep_is_not_fixed_and_returns_values() {
        let s = SearchStrategy::ExponentialSweep { values: vec![1.0, 2.0, 3.0] };
        assert!(!s.is_fixed());
        assert_eq!(s.sweep_values().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn empty_exponential_sweep_returns_empty_slice() {
        let s = SearchStrategy::ExponentialSweep { values: Vec::new() };
        assert!(!s.is_fixed());
        assert!(s.sweep_values().unwrap().is_empty());
    }
}