Skip to main content

moonpool_sim/sim/
rng.rs

1//! Thread-local random number generation for simulation.
2//!
3//! This module provides deterministic randomness through thread-local storage,
4//! enabling clean API design without passing RNG through the simulation state.
5//! Each thread maintains its own RNG state, ensuring deterministic behavior
6//! within each simulation run while supporting parallel test execution.
7
8use rand::SeedableRng;
9use rand::{
10    RngExt,
11    distr::{Distribution, StandardUniform, uniform::SampleUniform},
12};
13use rand_chacha::ChaCha8Rng;
14use std::cell::{Cell, RefCell};
15use std::collections::VecDeque;
16
17thread_local! {
18    /// Thread-local random number generator for simulation.
19    ///
20    /// Uses ChaCha8Rng for deterministic, reproducible randomness.
21    /// Each thread maintains independent state for parallel test execution.
22    static SIM_RNG: RefCell<ChaCha8Rng> = RefCell::new(ChaCha8Rng::seed_from_u64(0));
23
24    /// Thread-local storage for the current simulation seed.
25    ///
26    /// This stores the last seed set via [`set_sim_seed`] to enable
27    /// error reporting with seed information.
28    static CURRENT_SEED: RefCell<u64> = const { RefCell::new(0) };
29
30    /// Thread-local counter tracking RNG calls since last reset.
31    ///
32    /// Used by the exploration framework to record fork points and
33    /// enable deterministic replay via breakpoints.
34    static RNG_CALL_COUNT: Cell<u64> = const { Cell::new(0) };
35
36    /// Thread-local queue of RNG breakpoints, sorted by target call count.
37    ///
38    /// Each entry is `(target_count, new_seed)`. When the call count exceeds
39    /// `target_count`, the RNG reseeds with `new_seed` and the count resets to 1.
40    static RNG_BREAKPOINTS: RefCell<VecDeque<(u64, u64)>> = const { RefCell::new(VecDeque::new()) };
41}
42
43/// Increment the RNG call counter and check for breakpoints.
44///
45/// Called before every RNG sample. If the current call count exceeds
46/// a breakpoint's target, reseeds the RNG and resets the counter.
47fn pre_sample() {
48    RNG_CALL_COUNT.with(|c| c.set(c.get() + 1));
49    check_rng_breakpoint();
50}
51
52/// Check and trigger any pending RNG breakpoints.
53///
54/// Pops breakpoints whose target count has been exceeded (using `>`),
55/// reseeding the RNG for each. The count resets to 1 because the
56/// current call is the first call of the new seed segment.
57fn check_rng_breakpoint() {
58    RNG_BREAKPOINTS.with(|bp| {
59        let mut breakpoints = bp.borrow_mut();
60        while let Some(&(target_count, new_seed)) = breakpoints.front() {
61            let count = RNG_CALL_COUNT.with(|c| c.get());
62            if count > target_count {
63                breakpoints.pop_front();
64                SIM_RNG.with(|rng| {
65                    *rng.borrow_mut() = ChaCha8Rng::seed_from_u64(new_seed);
66                });
67                CURRENT_SEED.with(|s| {
68                    *s.borrow_mut() = new_seed;
69                });
70                RNG_CALL_COUNT.with(|c| c.set(1));
71            } else {
72                break;
73            }
74        }
75    });
76}
77
78/// Generate a random value using the thread-local simulation RNG.
79///
80/// This function provides deterministic randomness based on the seed set
81/// via [`set_sim_seed`]. The same seed will always produce the same sequence
82/// of random values within a single thread.
83///
84/// # Type Parameters
85///
86/// * `T` - The type to generate. Must implement the Standard distribution.
87///
88/// Generate a random value using the thread-local simulation RNG.
89pub fn sim_random<T>() -> T
90where
91    StandardUniform: Distribution<T>,
92{
93    pre_sample();
94    SIM_RNG.with(|rng| rng.borrow_mut().sample(StandardUniform))
95}
96
97/// Generate a random value within a specified range using the thread-local simulation RNG.
98///
99/// This function provides deterministic randomness for values within a range.
100/// The same seed will always produce the same sequence of values.
101///
102/// # Type Parameters
103///
104/// * `T` - The type to generate. Must implement SampleUniform.
105///
106/// # Parameters
107///
108/// * `range` - The range to sample from (exclusive upper bound).
109///
110/// Generate a random value within a specified range.
111pub fn sim_random_range<T>(range: std::ops::Range<T>) -> T
112where
113    T: SampleUniform + PartialOrd,
114{
115    pre_sample();
116    SIM_RNG.with(|rng| rng.borrow_mut().random_range(range))
117}
118
119/// Generate a random value within the given range, returning the start value if the range is empty.
120///
121/// This is a safe version of [`sim_random_range`] that handles empty ranges gracefully
122/// by returning the start value when start == end.
123///
124/// # Parameters
125///
126/// * `range` - The range to sample from (start..end)
127///
128/// # Returns
129///
130/// A random value within the range, or the start value if the range is empty.
131///
132/// Generate a random value in range or return start value if range is empty.
133pub fn sim_random_range_or_default<T>(range: std::ops::Range<T>) -> T
134where
135    T: SampleUniform + PartialOrd + Clone,
136{
137    if range.start >= range.end {
138        range.start
139    } else {
140        sim_random_range(range)
141    }
142}
143
144/// Set the seed for the thread-local simulation RNG.
145///
146/// This function initializes the thread-local RNG with a specific seed,
147/// ensuring deterministic behavior. The same seed will always produce
148/// the same sequence of random values.
149///
150/// # Parameters
151///
152/// * `seed` - The seed value to use for deterministic randomness.
153///
154/// Set the seed for the thread-local simulation RNG.
155pub fn set_sim_seed(seed: u64) {
156    SIM_RNG.with(|rng| {
157        *rng.borrow_mut() = ChaCha8Rng::seed_from_u64(seed);
158    });
159    CURRENT_SEED.with(|current| {
160        *current.borrow_mut() = seed;
161    });
162}
163
164/// Generate a random f64 in the range [0.0, 1.0) using the simulation RNG.
165///
166/// This is a convenience function matching FDB's `deterministicRandom()->random01()`.
167///
168/// # Returns
169///
170/// A random f64 value in [0.0, 1.0).
171pub fn sim_random_f64() -> f64 {
172    pre_sample();
173    SIM_RNG.with(|rng| rng.borrow_mut().sample(StandardUniform))
174}
175
176/// Get the current simulation seed.
177///
178/// Returns the seed that was last set via [`set_sim_seed`].
179/// This is useful for error reporting to help reproduce failing test cases.
180///
181/// # Returns
182///
183/// The current simulation seed, or 0 if no seed has been set.
184///
185/// Get the current simulation seed.
186pub fn get_current_sim_seed() -> u64 {
187    CURRENT_SEED.with(|current| *current.borrow())
188}
189
190/// Reset the thread-local simulation RNG to a fresh state.
191///
192/// This function clears any existing RNG state and initializes with entropy.
193/// It should be called before setting a new seed to ensure clean state
194/// between consecutive simulation runs on the same thread.
195///
196/// Reset the thread-local simulation RNG to a fresh state.
197pub fn reset_sim_rng() {
198    SIM_RNG.with(|rng| {
199        *rng.borrow_mut() = ChaCha8Rng::seed_from_u64(0);
200    });
201    CURRENT_SEED.with(|current| {
202        *current.borrow_mut() = 0;
203    });
204    RNG_CALL_COUNT.with(|c| c.set(0));
205    RNG_BREAKPOINTS.with(|bp| bp.borrow_mut().clear());
206}
207
208/// Get the current RNG call count.
209///
210/// Returns the number of RNG calls made since the last seed set or reset.
211/// Used by the exploration framework to record fork points.
212pub fn get_rng_call_count() -> u64 {
213    RNG_CALL_COUNT.with(|c| c.get())
214}
215
216/// Reset the RNG call count to zero.
217///
218/// Used when reseeding to start a new counting segment.
219pub fn reset_rng_call_count() {
220    RNG_CALL_COUNT.with(|c| c.set(0));
221}
222
223/// Set RNG breakpoints for deterministic replay.
224///
225/// Each breakpoint is a `(target_count, new_seed)` pair. When the RNG call
226/// count exceeds `target_count`, the RNG is reseeded with `new_seed` and
227/// the count resets to 1.
228///
229/// Breakpoints must be sorted by `target_count` in ascending order.
230///
231/// # Parameters
232///
233/// * `breakpoints` - Sorted list of (target_count, new_seed) pairs.
234pub fn set_rng_breakpoints(breakpoints: Vec<(u64, u64)>) {
235    RNG_BREAKPOINTS.with(|bp| {
236        *bp.borrow_mut() = VecDeque::from(breakpoints);
237    });
238}
239
240/// Clear all RNG breakpoints.
241pub fn clear_rng_breakpoints() {
242    RNG_BREAKPOINTS.with(|bp| bp.borrow_mut().clear());
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deterministic_randomness() {
        // Set seed and generate some values
        set_sim_seed(42);
        let value1: f64 = sim_random();
        let value2: u32 = sim_random();
        let value3: bool = sim_random();

        // Reset to same seed and verify same sequence
        set_sim_seed(42);
        assert_eq!(value1, sim_random::<f64>());
        assert_eq!(value2, sim_random::<u32>());
        assert_eq!(value3, sim_random::<bool>());
    }

    #[test]
    fn test_different_seeds_produce_different_values() {
        // Generate values with first seed
        set_sim_seed(1);
        let value1_seed1: f64 = sim_random();
        let value2_seed1: f64 = sim_random();

        // Generate values with different seed
        set_sim_seed(2);
        let value1_seed2: f64 = sim_random();
        let value2_seed2: f64 = sim_random();

        // Values should be different
        assert_ne!(value1_seed1, value1_seed2);
        assert_ne!(value2_seed1, value2_seed2);
    }

    #[test]
    fn test_sim_random_range() {
        set_sim_seed(42);

        // Test integer range
        for _ in 0..100 {
            let value = sim_random_range(10..20);
            assert!(value >= 10);
            assert!(value < 20);
        }

        // Test f64 range
        for _ in 0..100 {
            let value = sim_random_range(0.0..1.0);
            assert!(value >= 0.0);
            assert!(value < 1.0);
        }
    }

    #[test]
    fn test_range_determinism() {
        set_sim_seed(123);
        let value1 = sim_random_range(100..1000);
        let value2 = sim_random_range(0.0..10.0);

        set_sim_seed(123);
        assert_eq!(value1, sim_random_range(100..1000));
        assert_eq!(value2, sim_random_range(0.0..10.0));
    }

    #[test]
    fn test_reset_clears_state() {
        // Set seed and advance RNG
        set_sim_seed(42);
        let _advance1: f64 = sim_random();
        let _advance2: f64 = sim_random();
        let after_advance: f64 = sim_random();

        // Reset and set same seed - should get first value, not third
        reset_sim_rng();
        set_sim_seed(42);
        let first_value: f64 = sim_random();

        // Should be different because reset cleared the advanced state
        assert_ne!(after_advance, first_value);
    }

    #[test]
    fn test_sequence_persistence_within_thread() {
        set_sim_seed(42);
        let value1: f64 = sim_random();
        let value2: f64 = sim_random();
        let value3: f64 = sim_random();

        // Values should form a deterministic sequence
        set_sim_seed(42);
        assert_eq!(value1, sim_random::<f64>());
        assert_eq!(value2, sim_random::<f64>());
        assert_eq!(value3, sim_random::<f64>());
    }

    #[test]
    fn test_multiple_resets_and_seeds() {
        // Test multiple reset/seed cycles
        for seed in [1, 42, 12345] {
            reset_sim_rng();
            set_sim_seed(seed);
            let first: f64 = sim_random();

            reset_sim_rng();
            set_sim_seed(seed);
            assert_eq!(first, sim_random::<f64>());
        }
    }

    #[test]
    fn test_get_current_sim_seed() {
        // Test getting current seed after setting
        set_sim_seed(12345);
        assert_eq!(get_current_sim_seed(), 12345);

        set_sim_seed(98765);
        assert_eq!(get_current_sim_seed(), 98765);

        // Test that reset clears the seed
        reset_sim_rng();
        assert_eq!(get_current_sim_seed(), 0);
    }

    #[test]
    fn test_call_counting() {
        reset_sim_rng();
        set_sim_seed(42);
        assert_eq!(get_rng_call_count(), 0);

        let _: f64 = sim_random();
        assert_eq!(get_rng_call_count(), 1);

        let _: u32 = sim_random();
        assert_eq!(get_rng_call_count(), 2);

        let _ = sim_random_range(0..100);
        assert_eq!(get_rng_call_count(), 3);

        let _ = sim_random_f64();
        assert_eq!(get_rng_call_count(), 4);

        // sim_random_range_or_default with valid range delegates to sim_random_range
        let _ = sim_random_range_or_default(0..100);
        assert_eq!(get_rng_call_count(), 5);

        // sim_random_range_or_default with empty range does NOT consume RNG
        let _ = sim_random_range_or_default(100..100);
        assert_eq!(get_rng_call_count(), 5);
    }

    #[test]
    fn test_set_sim_seed_keeps_call_count() {
        reset_sim_rng();
        set_sim_seed(7);
        let _: f64 = sim_random();
        let _: f64 = sim_random();

        // Reseeding by hand does not start a new counting segment...
        set_sim_seed(8);
        assert_eq!(get_rng_call_count(), 2);

        // ...that requires an explicit reset.
        reset_rng_call_count();
        assert_eq!(get_rng_call_count(), 0);
    }

    #[test]
    fn test_breakpoint_reseed() {
        reset_sim_rng();
        set_sim_seed(100);

        // Record first 5 values with seed 100
        let mut old_values = Vec::new();
        for _ in 0..5 {
            old_values.push(sim_random::<f64>());
        }

        // Record first value with seed 200 from scratch
        reset_sim_rng();
        set_sim_seed(200);
        let new_seed_first: f64 = sim_random();

        // Replay: seed 100, breakpoint at count=5 to reseed to 200
        reset_sim_rng();
        set_sim_seed(100);
        set_rng_breakpoints(vec![(5, 200)]);

        // First 5 calls should match old seed
        for (i, expected) in old_values.iter().enumerate() {
            let actual: f64 = sim_random();
            assert_eq!(*expected, actual, "Mismatch at call {}", i + 1);
        }

        // Call 6 triggers breakpoint (count 6 > 5), reseeds to 200
        let after_breakpoint: f64 = sim_random();
        assert_eq!(after_breakpoint, new_seed_first);
        assert_eq!(get_rng_call_count(), 1);
        assert_eq!(get_current_sim_seed(), 200);
    }

    #[test]
    fn test_chained_breakpoints() {
        reset_sim_rng();
        set_sim_seed(10);
        // Targets are relative to each segment, so this list is intentionally
        // not sorted by target_count.
        set_rng_breakpoints(vec![(3, 20), (2, 30)]);

        // 3 calls with seed 10
        let _: f64 = sim_random(); // count=1
        let _: f64 = sim_random(); // count=2
        let _: f64 = sim_random(); // count=3
        assert_eq!(get_current_sim_seed(), 10);

        // Call 4: count becomes 4 > 3, breakpoint fires: reseed to 20, count=1
        let _: f64 = sim_random();
        assert_eq!(get_current_sim_seed(), 20);
        assert_eq!(get_rng_call_count(), 1);

        // 1 more call with seed 20
        let _: f64 = sim_random(); // count=2

        // Call 3 of seed 20: count becomes 3 > 2, breakpoint fires: reseed to 30, count=1
        let _: f64 = sim_random();
        assert_eq!(get_current_sim_seed(), 30);
        assert_eq!(get_rng_call_count(), 1);
    }

    #[test]
    fn test_replay_determinism() {
        // Run 1: record a "recipe" — seed 42, fork at call 3 to seed 99
        reset_sim_rng();
        set_sim_seed(42);
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        let fork_count = get_rng_call_count();
        set_sim_seed(99);
        reset_rng_call_count();
        let post_fork_1: f64 = sim_random();
        let post_fork_2: f64 = sim_random();

        // Run 2: replay using breakpoints
        reset_sim_rng();
        set_sim_seed(42);
        set_rng_breakpoints(vec![(fork_count, 99)]);
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        // Breakpoint triggers on next call (count 4 > 3)
        let replay_1: f64 = sim_random();
        let replay_2: f64 = sim_random();

        assert_eq!(post_fork_1, replay_1);
        assert_eq!(post_fork_2, replay_2);
    }

    #[test]
    fn test_reset_clears_everything_including_breakpoints() {
        set_sim_seed(42);
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        // Target 0 fires on the very next draw, so a breakpoint that survived
        // the reset could not go unnoticed below. (A large target such as 10
        // would never fire on a single draw, making the check vacuous.)
        set_rng_breakpoints(vec![(0, 99)]);

        assert_eq!(get_rng_call_count(), 2);

        reset_sim_rng();

        assert_eq!(get_rng_call_count(), 0);
        assert_eq!(get_current_sim_seed(), 0);

        // Verify breakpoints were cleared: a surviving (0, 99) entry would
        // reseed to 99 on this draw.
        set_sim_seed(42);
        let _: f64 = sim_random();
        assert_eq!(get_rng_call_count(), 1);
        assert_eq!(get_current_sim_seed(), 42); // no breakpoint triggered
    }
}