Skip to main content

xlsynth_mcmc/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use rand::RngCore;
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7use std::sync::Mutex;
8use std::sync::atomic::Ordering;
9
10pub mod multichain;
11
/// Mutex-backed stand-in for an atomic `u128` cell.
///
/// This toolchain has no stable `AtomicU128`, so the type mirrors the
/// `load` / `compare_exchange` surface of the std atomics while routing every
/// access through a [`Mutex`].
pub struct AtomicMetricU128 {
    inner: Mutex<u128>,
}

impl AtomicMetricU128 {
    /// Creates a cell that initially holds `value`.
    pub fn new(value: u128) -> Self {
        let inner = Mutex::new(value);
        Self { inner }
    }

    /// Returns the value currently stored in the cell.
    ///
    /// `order` exists only for signature parity with the std atomics; the
    /// mutex provides the synchronization, so the argument is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn load(&self, order: Ordering) -> u128 {
        let _ = order;
        *self.slot()
    }

    /// Stores `new` if and only if the cell currently holds `current`.
    ///
    /// Returns `Ok(previous)` when the swap happened and `Err(actual)` with the
    /// value found in the cell when it did not. Both ordering arguments are
    /// accepted for signature parity and ignored.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn compare_exchange(
        &self,
        current: u128,
        new: u128,
        success: Ordering,
        failure: Ordering,
    ) -> Result<u128, u128> {
        let _ = (success, failure);
        let mut slot = self.slot();
        let observed = *slot;
        if observed != current {
            return Err(observed);
        }
        *slot = new;
        Ok(observed)
    }

    /// Locks the backing mutex and hands out the guard.
    fn slot(&self) -> std::sync::MutexGuard<'_, u128> {
        self.inner.lock().unwrap()
    }
}
49
50/// Shared best-so-far candidate across threads.
51pub struct Best<T> {
52    pub cost: AtomicMetricU128,
53    pub value: Mutex<T>,
54}
55
56impl<T: Clone> Best<T> {
57    pub fn new(initial_cost: u128, value: T) -> Self {
58        Self {
59            cost: AtomicMetricU128::new(initial_cost),
60            value: Mutex::new(value),
61        }
62    }
63
64    /// Attempts to update the best-so-far candidate.
65    ///
66    /// Returns `true` if this call updated the global best, `false` otherwise.
67    pub fn try_update(&self, new_cost: u128, new_value: T) -> bool {
68        let mut current = self.cost.load(Ordering::SeqCst);
69        while new_cost < current {
70            match self
71                .cost
72                .compare_exchange(current, new_cost, Ordering::SeqCst, Ordering::SeqCst)
73            {
74                Ok(_) => {
75                    let mut v = self.value.lock().unwrap();
76                    *v = new_value;
77                    return true;
78                }
79                Err(v) => current = v,
80            }
81        }
82        false
83    }
84
85    pub fn get(&self) -> T {
86        self.value.lock().unwrap().clone()
87    }
88}
89
/// Running tallies for an MCMC run, keyed by edit kind `K`.
#[derive(Debug)]
pub struct McmcStats<K> {
    /// Iterations whose move was accepted.
    pub accepted_overall: usize,
    /// Iterations rejected with an apply failure.
    pub rejected_apply_fail: usize,
    /// Iterations rejected with a candidate failure.
    pub rejected_candidate_fail: usize,
    /// Iterations rejected with an oracle failure.
    pub rejected_oracle: usize,
    /// Iterations rejected by the Metropolis criterion.
    pub rejected_metro: usize,
    /// Accepted or Metropolis-rejected iterations that reported a non-zero
    /// oracle time.
    pub oracle_verified: usize,
    /// Sum of the per-iteration oracle times, in microseconds.
    pub total_oracle_time_micros: u128,
    /// Number of accepted moves per edit kind.
    pub accepted_edits_by_kind: HashMap<K, usize>,
    /// Iterations rejected with a simulation failure.
    pub rejected_sim_fail: usize,
    /// Time attributed to simulation-failure iterations, in microseconds.
    pub total_sim_time_micros: u128,
}

// Written by hand rather than derived so that `K` does not need to implement
// `Default`.
impl<K> Default for McmcStats<K> {
    /// Returns statistics with every counter at zero and no per-kind entries.
    fn default() -> Self {
        Self {
            accepted_overall: 0,
            rejected_apply_fail: 0,
            rejected_candidate_fail: 0,
            rejected_oracle: 0,
            rejected_metro: 0,
            oracle_verified: 0,
            total_oracle_time_micros: 0,
            accepted_edits_by_kind: HashMap::default(),
            rejected_sim_fail: 0,
            total_sim_time_micros: 0,
        }
    }
}

impl<K> McmcStats<K>
where
    K: Eq + Hash,
{
    /// Folds `other` into `self`, e.g. to combine the statistics of several
    /// segments or chains.
    ///
    /// Scalar counters are summed; per-kind acceptance counts are summed key by
    /// key, creating entries for kinds `self` has not seen yet.
    pub fn merge_from(&mut self, other: McmcStats<K>) {
        // Exhaustive destructuring: adding a field to the struct without
        // merging it here becomes a compile error.
        let McmcStats {
            accepted_overall,
            rejected_apply_fail,
            rejected_candidate_fail,
            rejected_oracle,
            rejected_metro,
            oracle_verified,
            total_oracle_time_micros,
            accepted_edits_by_kind,
            rejected_sim_fail,
            total_sim_time_micros,
        } = other;

        self.accepted_overall += accepted_overall;
        self.rejected_apply_fail += rejected_apply_fail;
        self.rejected_candidate_fail += rejected_candidate_fail;
        self.rejected_oracle += rejected_oracle;
        self.rejected_metro += rejected_metro;
        self.oracle_verified += oracle_verified;
        self.total_oracle_time_micros += total_oracle_time_micros;
        self.rejected_sim_fail += rejected_sim_fail;
        self.total_sim_time_micros += total_sim_time_micros;

        for (kind, count) in accepted_edits_by_kind {
            *self.accepted_edits_by_kind.entry(kind).or_default() += count;
        }
    }
}
143
/// Details of what occurred during a single MCMC iteration attempt.
///
/// `K` identifies the kind of edit that was accepted. The derives make
/// outcomes printable, clonable and comparable, so callers can log them and
/// assert on them in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IterationOutcomeDetails<K> {
    /// The attempt was rejected with a candidate failure.
    CandidateFailure,
    /// The attempt was rejected with an apply failure.
    ApplyFailure,
    /// The attempt was rejected with a simulation failure.
    SimFailure,
    /// The attempt was rejected with an oracle failure.
    OracleFailure,
    /// The attempt was rejected by the Metropolis criterion.
    MetropolisReject,
    /// The attempt was accepted; `kind` names the edit that was applied.
    Accepted { kind: K },
}
153
154/// Output of a single MCMC iteration.
155pub struct McmcIterationOutput<S, C, K> {
156    pub output_state: S,
157    pub output_cost: C,
158    pub best_updated: bool,
159    pub outcome: IterationOutcomeDetails<K>,
160    pub transform_always_equivalent: bool,
161    pub transform: Option<K>,
162    /// Time spent in oracle or simulation, 0 if not run.
163    pub oracle_time_micros: u128,
164}
165
/// Knobs that configure an MCMC run at the engine level.
#[derive(Clone, Debug)]
pub struct McmcOptions {
    /// Interval used for SAT resets.
    // NOTE(review): presumably counted in iterations; confirm against the
    // engine that consumes it.
    pub sat_reset_interval: u64,
    /// Temperature the run starts from.
    pub initial_temperature: f64,
    /// Global index of the first iteration, so that a run split into several
    /// segments keeps a continuous iteration count.
    pub start_iteration: u64,
    /// Planned number of iterations for the *whole* run, summed over all
    /// segments. With `None` the temperature stays constant (no cooling).
    pub total_iters: Option<u64>,
}
177
178impl<K> McmcStats<K>
179where
180    K: Eq + Hash + Ord + Clone + fmt::Debug,
181{
182    /// Update statistics based on the outcome of a single iteration.
183    ///
184    /// `iteration_index` is the human-readable global iteration number used
185    /// only for panic messages in paranoid mode.
186    pub fn update_for_iteration<S, C>(
187        &mut self,
188        iteration: &McmcIterationOutput<S, C, K>,
189        paranoid: bool,
190        iteration_index: u64,
191    ) {
192        self.total_oracle_time_micros += iteration.oracle_time_micros;
193
194        match &iteration.outcome {
195            IterationOutcomeDetails::Accepted { kind } => {
196                self.accepted_overall += 1;
197                *self.accepted_edits_by_kind.entry(kind.clone()).or_insert(0) += 1;
198                if iteration.oracle_time_micros > 0 {
199                    self.oracle_verified += 1;
200                }
201            }
202            IterationOutcomeDetails::CandidateFailure => {
203                self.rejected_candidate_fail += 1;
204            }
205            IterationOutcomeDetails::ApplyFailure => {
206                self.rejected_apply_fail += 1;
207            }
208            IterationOutcomeDetails::SimFailure => {
209                self.rejected_sim_fail += 1;
210                self.total_sim_time_micros += iteration.oracle_time_micros;
211            }
212            IterationOutcomeDetails::OracleFailure => {
213                self.rejected_oracle += 1;
214                if paranoid && iteration.transform_always_equivalent {
215                    panic!(
216                        "[mcmc] equivalence failure for always-equivalent transform at iteration {}; transform: {:?} should always be equivalent",
217                        iteration_index, iteration.transform
218                    );
219                }
220            }
221            IterationOutcomeDetails::MetropolisReject => {
222                self.rejected_metro += 1;
223                if iteration.oracle_time_micros > 0 {
224                    self.oracle_verified += 1;
225                }
226            }
227        }
228    }
229}
230
/// Lower bound on the temperature, expressed as a fraction of the initial
/// temperature.
///
/// Keeping the ratio above this floor avoids underflow and other numeric
/// trouble while cooling.
pub const MIN_TEMPERATURE_RATIO: f64 = 1.0e-5;
234
235/// Decide whether to accept a candidate move under the Metropolis rule.
236///
237/// `current_metric` and `new_metric` are scalar objective values (lower is
238/// better).  When `new_metric < current_metric`, the move is always accepted.
239/// Otherwise it is accepted with probability `exp((current - new) / temp)`.
240pub fn metropolis_accept<R: RngCore + ?Sized>(
241    current_metric: f64,
242    new_metric: f64,
243    temp: f64,
244    rng: &mut R,
245) -> bool {
246    if new_metric < current_metric {
247        return true;
248    }
249
250    let accept_prob = ((current_metric - new_metric) / temp).exp();
251    let raw = rng.next_u64();
252
253    // Generate a uniform floating-point value in [0, 1) with correct IEEE-754
254    // semantics.
255    //
256    // Using `(raw as f64) / (u64::MAX as f64)` can produce `1.0` when `raw ==
257    // u64::MAX`, which introduces a tiny rejection bias in Metropolis-Hastings.
258    //
259    // We instead take the top 53 bits (the precision of f64’s mantissa) and scale
260    // by 2^-53, yielding values in [0, 1).
261    let u01 = (raw >> 11) as f64 * 2.0_f64.powi(-53);
262    u01 < accept_prob
263}
264
#[cfg(test)]
mod tests {
    use super::metropolis_accept;
    use rand::RngCore;

    /// RNG stub that yields the same 64-bit word on every draw, so the tests
    /// can pin the uniform sample used by `metropolis_accept`.
    struct FixedU64Rng {
        v: u64,
    }

    impl RngCore for FixedU64Rng {
        fn next_u32(&mut self) -> u32 {
            (self.v >> 32) as u32
        }

        fn next_u64(&mut self) -> u64 {
            self.v
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            for (i, b) in dest.iter_mut().enumerate() {
                *b = (self.v >> ((i % 8) * 8)) as u8;
            }
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
            self.fill_bytes(dest);
            Ok(())
        }
    }

    /// Mirrors the raw-word to unit-interval conversion performed inside
    /// `metropolis_accept`.
    fn u01_from_raw(raw: u64) -> f64 {
        (raw >> 11) as f64 * 2.0_f64.powi(-53)
    }

    #[test]
    fn u01_conversion_never_reaches_one() {
        let mut rng = FixedU64Rng { v: u64::MAX };
        let u01 = u01_from_raw(rng.next_u64());
        assert!(u01 < 1.0, "u01 must be in [0,1), got {u01}");
    }

    #[test]
    fn u01_conversion_zero_is_zero() {
        let mut rng = FixedU64Rng { v: 0 };
        let u01 = u01_from_raw(rng.next_u64());
        assert_eq!(u01, 0.0);
    }

    // The tests below exercise the production function itself rather than a
    // copy of its arithmetic, so a regression in `metropolis_accept` is caught.

    #[test]
    fn improving_move_is_accepted_for_any_draw() {
        for v in [0, 1, u64::MAX / 2, u64::MAX] {
            let mut rng = FixedU64Rng { v };
            assert!(metropolis_accept(2.0, 1.0, 1.0, &mut rng));
        }
    }

    #[test]
    fn equal_cost_move_is_accepted_even_for_max_draw() {
        // Acceptance probability is exp(0) == 1.0; the largest possible draw
        // must still land strictly below it.
        let mut rng = FixedU64Rng { v: u64::MAX };
        assert!(metropolis_accept(1.0, 1.0, 1.0, &mut rng));
    }

    #[test]
    fn worse_move_follows_the_uniform_draw() {
        // Acceptance probability is exp(-1), roughly 0.37.
        let mut lowest = FixedU64Rng { v: 0 };
        assert!(metropolis_accept(1.0, 2.0, 1.0, &mut lowest));

        let mut highest = FixedU64Rng { v: u64::MAX };
        assert!(!metropolis_accept(1.0, 2.0, 1.0, &mut highest));
    }
}