Skip to main content

roma_lib/utils/
random.rs

use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

/// Process-wide counter folded into every time-based seed.
///
/// `seed_from_time` advances it atomically on each call, so seeds requested
/// within the same clock tick still differ. The starting value is simply a
/// fixed non-zero 64-bit constant.
static SEED_COUNTER: AtomicU64 = AtomicU64::new(0xA076_1D64_78BD_642F);

/// Scrambles a 64-bit word with a SplitMix-style finalizer.
///
/// Kept private so the public API of `Random` stays unchanged while seeding
/// and output generation share one mixing routine. The routine is not
/// cryptographic. Each step (an xor with a right-shifted copy, or a
/// multiplication by an odd constant modulo 2^64) is invertible, so distinct
/// inputs always map to distinct outputs, and `0` maps to `0`.
#[inline]
fn mix64(z: u64) -> u64 {
    const MUL_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MUL_B: u64 = 0x94D0_49BB_1331_11EB;

    let first = (z ^ (z >> 30)).wrapping_mul(MUL_A);
    let second = (first ^ (first >> 27)).wrapping_mul(MUL_B);
    second ^ (second >> 31)
}

16/// Generates a time-based seed and folds in cheap per-process variability.
17///
18/// A monotonic atomic counter is mixed in so back-to-back calls still produce
19/// distinct seeds even when the system clock resolution is coarse.
20pub fn seed_from_time() -> u64 {
21    // Gather several cheap, system-dependent sources of variability and
22    // fold them into a single 64-bit value before mixing.
23    let now = SystemTime::now()
24        .duration_since(UNIX_EPOCH)
25        .unwrap()
26        .as_nanos(); // u128
27
28    let now_lo = now as u64;
29    let now_hi = (now >> 64) as u64;
30    let pid = std::process::id() as u64;
31    let counter = SEED_COUNTER.fetch_add(0x9E3779B97F4A7C15, Ordering::Relaxed);
32    // address of a local value gives additional low-cost entropy between
33    // rapidly repeated calls (stack pointer / ASLR differences)
34    let stack_addr = (&now as *const _ as usize) as u64;
35
36    let mut seed = now_lo ^ now_hi;
37    seed = seed.wrapping_add(pid.wrapping_mul(0x9E3779B97F4A7C15));
38    seed ^= counter;
39    seed ^= stack_addr.wrapping_mul(0xBF58476D1CE4E5B9);
40    mix64(seed)
41}
42
/// Small deterministic pseudo-random number generator used across Roma.
///
/// The generator is intentionally dependency-free. It exposes its internal
/// state (via [`Random::state`] and [`Random::set_state`]) so algorithms can
/// checkpoint and resume runs exactly. Two generators compare equal exactly
/// when their states are equal, i.e. when they will produce the same future
/// stream.
///
/// Not cryptographically secure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Random {
    /// Complete generator state. Each draw advances it by a fixed increment.
    state: u64,
}

52impl Random {
53    /// Creates a generator initialized with `seed`.
54    pub fn new(seed: u64) -> Self {
55        Self { state: seed }
56    }
57
58    /// Derives a reproducible stream seed from a base seed and stream id.
59    #[inline]
60    pub fn derive_seed(base_seed: u64, stream: u64) -> u64 {
61        // Use the same mixer as the generator to ensure small changes in
62        // (base_seed, stream) produce dramatically different derived seeds.
63        mix64(base_seed ^ stream.wrapping_mul(0x9E3779B97F4A7C15))
64    }
65
66    /// Returns the configured seed or a fresh time-based seed when absent.
67    #[inline]
68    pub fn resolve_seed(random_seed: Option<u64>) -> u64 {
69        random_seed.unwrap_or_else(seed_from_time)
70    }
71
72    /// Returns the current internal generator state.
73    #[inline]
74    pub fn state(&self) -> u64 {
75        self.state
76    }
77
78    /// Replaces the internal generator state.
79    #[inline]
80    pub fn set_state(&mut self, state: u64) {
81        self.state = state;
82    }
83
84    /// Advances the generator and returns a 64-bit pseudo-random value.
85    #[inline]
86    pub fn next_u64(&mut self) -> u64 {
87        // Increment state (as in SplitMix) then mix.
88        self.state = self.state.wrapping_add(0x9E3779B97F4A7C15);
89        mix64(self.state)
90    }
91
92    /// Advances the generator and returns the low 32 bits.
93    #[inline]
94    pub fn next_u32(&mut self) -> u32 {
95        self.next_u64() as u32
96    }
97
98    /// Returns a floating-point value in the half-open interval `[0, 1)`.
99    #[inline]
100    pub fn next_f64(&mut self) -> f64 {
101        // Take the top 53 bits and scale into [0,1).
102        let x = (self.next_u64() >> 11) as u64;
103        (x as f64) * (1.0 / 9007199254740992.0) // 1 / 2^53
104    }
105
106    /// Returns an integer in the half-open interval `[0, max)`.
107    ///
108    /// When `max` is zero the function returns `0`.
109    #[inline]
110    pub fn range(&mut self, max: u64) -> u64 {
111        if max == 0 {
112            return 0;
113        }
114        if max == 1 {
115            return 0;
116        }
117
118        if max.is_power_of_two() {
119            return self.next_u64() & (max - 1);
120        }
121
122        // Lemire-style unbiased reduction: use the high half of a 128-bit
123        // product and only retry when the low half falls in the small biased
124        // zone near zero.
125        let threshold = max.wrapping_neg() % max;
126        loop {
127            let random = self.next_u64();
128            let product = (random as u128) * (max as u128);
129            let low = product as u64;
130            if low >= threshold {
131                return (product >> 64) as u64;
132            }
133        }
134    }
135
136    /// Returns an integer in the half-open interval `[min, max)`.
137    ///
138    /// In debug builds, invalid or empty intervals trigger a debug assertion.
139    /// In release builds, `min` is returned as a defensive fallback.
140    #[inline]
141    pub fn range_between(&mut self, min: u64, max: u64) -> u64 {
142        debug_assert!(max > min, "Random::range_between requires max > min");
143        if max <= min {
144            return min;
145        }
146        min + self.range(max - min)
147    }
148
149    /// Returns `true` with probability `p`.
150    #[inline]
151    pub fn chance(&mut self, p: f64) -> bool {
152        self.next_f64() < p
153    }
154
155    /// Returns `true` with probability `0.5`.
156    #[inline]
157    pub fn coin_flip(&mut self) -> bool {
158        self.chance(0.5)
159    }
160}
161
#[cfg(test)]
mod test {
    use super::{seed_from_time, Random};

    /// Fixed seeds (including both extremes) keep these tests reproducible:
    /// a failure can be replayed exactly, unlike one seeded from the clock.
    const SEEDS: [u64; 4] = [0, 1, 42, u64::MAX];

    #[test]
    fn range_between_test() {
        let min: u64 = 100;
        let max: u64 = 200;
        for seed in SEEDS {
            let mut rng = Random::new(seed);
            for _ in 0..1_000 {
                let x = rng.range_between(min, max);
                assert!((min..max).contains(&x), "{} outside [{}, {})", x, min, max);
            }
        }
    }

    #[test]
    fn range_between_single_value_interval_returns_min() {
        let mut rng = Random::new(7);
        for _ in 0..16 {
            assert_eq!(rng.range_between(5, 6), 5);
        }
    }

    #[test]
    fn coin_flip_test() {
        for seed in SEEDS {
            let mut rng = Random::new(seed);
            for _ in 0..100 {
                assert!(!rng.chance(0.0));
                assert!(rng.chance(1.0));
            }
        }
    }

    #[test]
    fn coin_flip_is_roughly_balanced() {
        let mut rng = Random::new(42);
        let heads = (0..10_000).filter(|_| rng.coin_flip()).count();
        // 10k fair flips have a standard deviation of 50 heads, so this window
        // is about 20 standard deviations wide on each side.
        assert!((4_000..=6_000).contains(&heads), "heads = {}", heads);
    }

    #[test]
    fn random_determinism_test() {
        for seed in SEEDS {
            let mut rng_1 = Random::new(seed);
            let mut rng_2 = Random::new(seed);
            for _ in 0..32 {
                assert_eq!(
                    rng_1.coin_flip(),
                    rng_2.coin_flip(),
                    "Structure Random with the same seed should give the same result"
                );
                assert_eq!(
                    rng_1.next_f64(),
                    rng_2.next_f64(),
                    "Structure Random with the same seed should give the same result"
                );
                assert_eq!(
                    rng_1.next_u32(),
                    rng_2.next_u32(),
                    "Structure Random with the same seed should give the same result"
                );
                assert_eq!(
                    rng_1.next_u64(),
                    rng_2.next_u64(),
                    "Structure Random with the same seed should give the same result"
                );
                assert_eq!(
                    rng_1.range(1_000),
                    rng_2.range(1_000),
                    "Structure Random with the same seed should give the same result"
                );
            }
        }
    }

    #[test]
    fn zero_seed_test() {
        let mut rng = Random::new(0);
        let max: u64 = 200;
        for _ in 0..1_000 {
            assert!(rng.range(max) < max);
        }
    }

    #[test]
    fn highest_seed_test() {
        let mut rng = Random::new(u64::MAX);
        let min: u64 = 100;
        let max: u64 = 200;
        for _ in 0..1_000 {
            let x = rng.range_between(min, max);
            assert!(x >= min && x < max);
        }
    }

    #[test]
    fn seed_from_time_back_to_back_calls_produce_distinct_seeds() {
        let first = seed_from_time();
        let second = seed_from_time();

        assert_ne!(first, second);
    }

    #[test]
    fn range_handles_power_of_two_upper_bound() {
        let mut rng = Random::new(42);
        let upper_bound = 1024;

        for _ in 0..1024 {
            assert!(rng.range(upper_bound) < upper_bound);
        }
    }

    #[test]
    fn range_handles_large_upper_bound() {
        let mut rng = Random::new(42);
        let upper_bound = u64::MAX;

        for _ in 0..1024 {
            assert!(rng.range(upper_bound) < upper_bound);
        }
    }

    #[test]
    fn range_with_degenerate_bounds_returns_zero_without_advancing() {
        let mut rng = Random::new(42);
        let before = rng.state();
        assert_eq!(rng.range(0), 0);
        assert_eq!(rng.range(1), 0);
        assert_eq!(rng.state(), before);
    }

    #[test]
    fn next_f64_stays_in_unit_interval() {
        for seed in SEEDS {
            let mut rng = Random::new(seed);
            for _ in 0..1_000 {
                let x = rng.next_f64();
                assert!((0.0..1.0).contains(&x), "{} outside [0, 1)", x);
            }
        }
    }

    #[test]
    fn state_round_trip_resumes_stream() {
        let mut rng = Random::new(42);
        let _ = rng.next_u64();
        let checkpoint = rng.state();
        let expected: Vec<u64> = (0..8).map(|_| rng.next_u64()).collect();

        let mut resumed = Random::new(0);
        resumed.set_state(checkpoint);
        let replayed: Vec<u64> = (0..8).map(|_| resumed.next_u64()).collect();
        assert_eq!(replayed, expected);
    }

    #[test]
    fn derive_seed_is_deterministic_and_separates_streams() {
        assert_eq!(Random::derive_seed(42, 7), Random::derive_seed(42, 7));
        assert_ne!(Random::derive_seed(42, 0), Random::derive_seed(42, 1));
        assert_ne!(Random::derive_seed(1, 0), Random::derive_seed(2, 0));
    }

    #[test]
    fn resolve_seed_prefers_explicit_seed() {
        assert_eq!(Random::resolve_seed(Some(1234)), 1234);
    }
}