1use rand::RngCore;
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7use std::sync::Mutex;
8use std::sync::atomic::Ordering;
9
10pub mod multichain;
11
/// A `u128` cell with an atomic-style API (`load`, `compare_exchange`).
///
/// The standard library has no stable `AtomicU128`, so the value lives behind a
/// [`Mutex`]. The `Ordering` arguments exist only for signature parity with
/// `std::sync::atomic` and are ignored; the mutex serializes every access.
pub struct AtomicMetricU128 {
    inner: Mutex<u128>,
}

impl AtomicMetricU128 {
    /// Creates a cell holding `value`.
    pub fn new(value: u128) -> Self {
        Self {
            inner: Mutex::new(value),
        }
    }

    /// Returns the current value. `order` is ignored.
    pub fn load(&self, order: Ordering) -> u128 {
        let _ = order;
        *self.inner.lock().unwrap()
    }

    /// Stores `new` if the current value equals `current`.
    ///
    /// Returns `Ok(current)` when the swap happened, or `Err(actual)` carrying the value
    /// that was found instead. `success` and `failure` are ignored.
    pub fn compare_exchange(
        &self,
        current: u128,
        new: u128,
        success: Ordering,
        failure: Ordering,
    ) -> Result<u128, u128> {
        let _ = (success, failure);
        let mut guard = self.inner.lock().unwrap();
        if *guard == current {
            *guard = new;
            Ok(current)
        } else {
            Err(*guard)
        }
    }
}

/// The lowest-cost value seen so far, shareable between threads.
///
/// Lower `cost` is better: [`Best::try_update`] only replaces the value on a strictly
/// smaller cost.
pub struct Best<T> {
    /// Cost of the value currently stored in `value`.
    pub cost: AtomicMetricU128,
    /// The value whose cost is `cost`.
    pub value: Mutex<T>,
}

impl<T: Clone> Best<T> {
    /// Creates a tracker starting at `initial_cost` with `value`.
    pub fn new(initial_cost: u128, value: T) -> Self {
        Self {
            cost: AtomicMetricU128::new(initial_cost),
            value: Mutex::new(value),
        }
    }

    /// Replaces the stored value with `new_value` if `new_cost` is strictly lower than
    /// the current cost. Returns whether the replacement happened; ties keep the old value.
    pub fn try_update(&self, new_cost: u128, new_value: T) -> bool {
        // Hold the value lock for the whole check-and-set. Publishing the cost first and
        // writing the value afterwards lets two updaters interleave:
        //   A swaps cost -> 5, B swaps cost -> 0, B writes, A writes.
        // That leaves cost 0 paired with A's value.
        let mut value = self.value.lock().unwrap();
        let mut current = self.cost.load(Ordering::SeqCst);
        while new_cost < current {
            match self
                .cost
                .compare_exchange(current, new_cost, Ordering::SeqCst, Ordering::SeqCst)
            {
                Ok(_) => {
                    *value = new_value;
                    return true;
                }
                // `cost` is a public field and may have been changed without this lock;
                // re-check against the freshly observed cost.
                Err(observed) => current = observed,
            }
        }
        false
    }

    /// Returns a clone of the current best value.
    pub fn get(&self) -> T {
        self.value.lock().unwrap().clone()
    }
}
89
/// Running counters summarizing the outcomes of MCMC iterations.
///
/// `K` keys the per-kind acceptance counts (presumably a transform/edit kind — confirm
/// against callers).
#[derive(Debug)]
pub struct McmcStats<K> {
    /// Iterations whose outcome was `Accepted`.
    pub accepted_overall: usize,
    /// Iterations whose outcome was `ApplyFailure`.
    pub rejected_apply_fail: usize,
    /// Iterations whose outcome was `CandidateFailure`.
    pub rejected_candidate_fail: usize,
    /// Iterations whose outcome was `OracleFailure`.
    pub rejected_oracle: usize,
    /// Iterations whose outcome was `MetropolisReject`.
    pub rejected_metro: usize,
    /// Accepted or Metropolis-rejected iterations that reported non-zero oracle time.
    pub oracle_verified: usize,
    /// Sum of every iteration's reported oracle time (microseconds, per the name).
    pub total_oracle_time_micros: u128,
    /// Number of accepted iterations per kind.
    pub accepted_edits_by_kind: HashMap<K, usize>,
    /// Iterations whose outcome was `SimFailure`.
    pub rejected_sim_fail: usize,
    /// Sum of reported time for `SimFailure` iterations (microseconds, per the name).
    pub total_sim_time_micros: u128,
}

// Implemented by hand: `#[derive(Default)]` would add an unnecessary `K: Default` bound.
impl<K> Default for McmcStats<K> {
    fn default() -> Self {
        Self {
            accepted_edits_by_kind: HashMap::new(),
            total_oracle_time_micros: 0,
            total_sim_time_micros: 0,
            accepted_overall: 0,
            oracle_verified: 0,
            rejected_apply_fail: 0,
            rejected_candidate_fail: 0,
            rejected_oracle: 0,
            rejected_metro: 0,
            rejected_sim_fail: 0,
        }
    }
}

impl<K> McmcStats<K>
where
    K: Eq + Hash,
{
    /// Adds every counter of `other` into `self`; per-kind counts are summed key by key.
    pub fn merge_from(&mut self, other: McmcStats<K>) {
        // Exhaustive destructuring: a new field added to `McmcStats` but not merged here
        // becomes a compile error instead of being silently dropped.
        let McmcStats {
            accepted_overall,
            rejected_apply_fail,
            rejected_candidate_fail,
            rejected_oracle,
            rejected_metro,
            oracle_verified,
            total_oracle_time_micros,
            accepted_edits_by_kind,
            rejected_sim_fail,
            total_sim_time_micros,
        } = other;

        self.accepted_overall += accepted_overall;
        self.rejected_apply_fail += rejected_apply_fail;
        self.rejected_candidate_fail += rejected_candidate_fail;
        self.rejected_oracle += rejected_oracle;
        self.rejected_metro += rejected_metro;
        self.oracle_verified += oracle_verified;
        self.total_oracle_time_micros += total_oracle_time_micros;
        self.rejected_sim_fail += rejected_sim_fail;
        self.total_sim_time_micros += total_sim_time_micros;

        for (kind, count) in accepted_edits_by_kind {
            *self.accepted_edits_by_kind.entry(kind).or_default() += count;
        }
    }
}
143
/// How a single MCMC iteration ended.
///
/// Each variant maps to exactly one counter in `McmcStats::update_for_iteration`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IterationOutcomeDetails<K> {
    /// Counted as `rejected_candidate_fail`.
    CandidateFailure,
    /// Counted as `rejected_apply_fail`.
    ApplyFailure,
    /// Counted as `rejected_sim_fail`; the iteration's time is added to `total_sim_time_micros`.
    SimFailure,
    /// Counted as `rejected_oracle`. In paranoid mode this is fatal for a transform flagged
    /// `transform_always_equivalent`.
    OracleFailure,
    /// Counted as `rejected_metro`.
    MetropolisReject,
    /// Counted as `accepted_overall` and under `kind` in `accepted_edits_by_kind`.
    Accepted { kind: K },
}

/// Everything one MCMC iteration reports back to the statistics/driver code.
#[derive(Debug)]
pub struct McmcIterationOutput<S, C, K> {
    /// State after the iteration (presumably the accepted state or the unchanged input — confirm).
    pub output_state: S,
    /// Cost associated with `output_state`.
    pub output_cost: C,
    /// Presumably whether this iteration improved the global best — confirm at the producer.
    pub best_updated: bool,
    /// How the iteration ended.
    pub outcome: IterationOutcomeDetails<K>,
    /// Whether the applied transform is expected to always preserve equivalence; together
    /// with `OracleFailure` this triggers a panic in paranoid mode.
    pub transform_always_equivalent: bool,
    /// The transform attempted this iteration, if any (reported in the paranoid-mode panic).
    pub transform: Option<K>,
    /// Reported oracle time in microseconds; a non-zero value marks the iteration as
    /// oracle-verified when it was accepted or Metropolis-rejected.
    pub oracle_time_micros: u128,
}
165
/// Configuration for an MCMC run.
#[derive(Debug, Clone)]
pub struct McmcOptions {
    /// Presumably the number of iterations between resets of a SAT/oracle backend — confirm.
    pub sat_reset_interval: u64,
    /// Starting temperature, presumably fed to `metropolis_accept` as `temp` — confirm.
    pub initial_temperature: f64,
    /// First iteration index (presumably non-zero when resuming a run) — confirm.
    pub start_iteration: u64,
    /// Iteration budget; presumably `None` means no fixed limit — confirm.
    pub total_iters: Option<u64>,
}
177
178impl<K> McmcStats<K>
179where
180 K: Eq + Hash + Ord + Clone + fmt::Debug,
181{
182 pub fn update_for_iteration<S, C>(
187 &mut self,
188 iteration: &McmcIterationOutput<S, C, K>,
189 paranoid: bool,
190 iteration_index: u64,
191 ) {
192 self.total_oracle_time_micros += iteration.oracle_time_micros;
193
194 match &iteration.outcome {
195 IterationOutcomeDetails::Accepted { kind } => {
196 self.accepted_overall += 1;
197 *self.accepted_edits_by_kind.entry(kind.clone()).or_insert(0) += 1;
198 if iteration.oracle_time_micros > 0 {
199 self.oracle_verified += 1;
200 }
201 }
202 IterationOutcomeDetails::CandidateFailure => {
203 self.rejected_candidate_fail += 1;
204 }
205 IterationOutcomeDetails::ApplyFailure => {
206 self.rejected_apply_fail += 1;
207 }
208 IterationOutcomeDetails::SimFailure => {
209 self.rejected_sim_fail += 1;
210 self.total_sim_time_micros += iteration.oracle_time_micros;
211 }
212 IterationOutcomeDetails::OracleFailure => {
213 self.rejected_oracle += 1;
214 if paranoid && iteration.transform_always_equivalent {
215 panic!(
216 "[mcmc] equivalence failure for always-equivalent transform at iteration {}; transform: {:?} should always be equivalent",
217 iteration_index, iteration.transform
218 );
219 }
220 }
221 IterationOutcomeDetails::MetropolisReject => {
222 self.rejected_metro += 1;
223 if iteration.oracle_time_micros > 0 {
224 self.oracle_verified += 1;
225 }
226 }
227 }
228 }
229}
230
/// Smallest allowed temperature ratio (1e-5).
///
/// Presumably the floor for the annealed temperature relative to the initial temperature;
/// the use site is not visible here — confirm.
pub const MIN_TEMPERATURE_RATIO: f64 = 1e-5;
234
235pub fn metropolis_accept<R: RngCore + ?Sized>(
241 current_metric: f64,
242 new_metric: f64,
243 temp: f64,
244 rng: &mut R,
245) -> bool {
246 if new_metric < current_metric {
247 return true;
248 }
249
250 let accept_prob = ((current_metric - new_metric) / temp).exp();
251 let raw = rng.next_u64();
252
253 let u01 = (raw >> 11) as f64 * 2.0_f64.powi(-53);
262 u01 < accept_prob
263}
264
#[cfg(test)]
mod tests {
    use super::metropolis_accept;
    use rand::RngCore;

    /// Deterministic RNG that yields the same 64-bit word on every draw.
    struct FixedU64Rng {
        v: u64,
    }

    impl RngCore for FixedU64Rng {
        fn next_u32(&mut self) -> u32 {
            (self.v >> 32) as u32
        }

        fn next_u64(&mut self) -> u64 {
            self.v
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            for (i, b) in dest.iter_mut().enumerate() {
                *b = (self.v >> ((i % 8) * 8)) as u8;
            }
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
            self.fill_bytes(dest);
            Ok(())
        }
    }

    /// RNG that fails the test if any randomness is requested.
    struct PanickingRng;

    impl RngCore for PanickingRng {
        fn next_u32(&mut self) -> u32 {
            panic!("rng must not be consulted");
        }

        fn next_u64(&mut self) -> u64 {
            panic!("rng must not be consulted");
        }

        fn fill_bytes(&mut self, _dest: &mut [u8]) {
            panic!("rng must not be consulted");
        }

        fn try_fill_bytes(&mut self, _dest: &mut [u8]) -> Result<(), rand::Error> {
            panic!("rng must not be consulted");
        }
    }

    /// Same `u64 -> [0, 1)` mapping that `metropolis_accept` applies to its draw.
    fn to_unit_interval(raw: u64) -> f64 {
        (raw >> 11) as f64 * 2.0_f64.powi(-53)
    }

    #[test]
    fn u01_conversion_never_reaches_one() {
        let mut rng = FixedU64Rng { v: u64::MAX };
        let u01 = to_unit_interval(rng.next_u64());
        assert!(u01 < 1.0, "u01 must be in [0,1), got {u01}");
    }

    #[test]
    fn u01_conversion_zero_is_zero() {
        let mut rng = FixedU64Rng { v: 0 };
        let u01 = to_unit_interval(rng.next_u64());
        assert_eq!(u01, 0.0);
    }

    #[test]
    fn strict_improvement_is_accepted_without_drawing() {
        assert!(metropolis_accept(2.0, 1.0, 1.0, &mut PanickingRng));
    }

    #[test]
    fn equal_metric_is_accepted_even_for_max_draw() {
        // exp(0) = 1 and the largest draw maps strictly below 1.
        let mut rng = FixedU64Rng { v: u64::MAX };
        assert!(metropolis_accept(1.0, 1.0, 1.0, &mut rng));
    }

    #[test]
    fn worse_metric_with_zero_draw_is_accepted() {
        let mut rng = FixedU64Rng { v: 0 };
        assert!(metropolis_accept(1.0, 2.0, 1.0, &mut rng));
    }

    #[test]
    fn worse_metric_compared_against_boltzmann_factor() {
        // 1 << 63 maps to exactly u = 0.5.
        let mut rng = FixedU64Rng { v: 1u64 << 63 };
        assert_eq!(to_unit_interval(1u64 << 63), 0.5);
        // exp(-0.5) ~= 0.607 > 0.5 -> accept.
        assert!(metropolis_accept(1.0, 1.5, 1.0, &mut rng));
        // exp(-1) ~= 0.368 < 0.5 -> reject.
        assert!(!metropolis_accept(1.0, 2.0, 1.0, &mut rng));
        // Higher temperature: exp(-0.25) ~= 0.779 > 0.5 -> accept.
        assert!(metropolis_accept(1.0, 2.0, 4.0, &mut rng));
    }
}