Skip to main content

runmat_runtime/builtins/common/
random.rs

1use std::f64::consts::PI;
2use std::sync::{Mutex, OnceLock};
3
4use crate::{build_runtime_error, BuiltinResult, RuntimeError};
5
/// Internal LCG state used before any seeding; `mix_seed(0)` also maps to it.
pub(crate) const DEFAULT_RNG_SEED: u64 = 0x9e37_79b9_7f4a_7c15;
/// Seed recorded in the default generator configuration.
pub(crate) const DEFAULT_USER_SEED: u64 = 0;
/// Multiplier `a` of the 64-bit linear congruential step `x -> a*x + c (mod 2^64)`.
const RNG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
/// Increment `c` of the linear congruential step.
const RNG_INCREMENT: u64 = 1;
/// Number of low-order state bits dropped before conversion; the top 53 bits remain.
const RNG_SHIFT: u32 = 11;
/// Scales a `(64 - RNG_SHIFT)`-bit integer into `[0, 1)`, i.e. `2^-53`.
const RNG_SCALE: f64 = 1.0 / ((1u64 << (64 - RNG_SHIFT)) as f64);
/// Smallest positive normal `f64`; substituted for a zero uniform before taking `ln`.
const MIN_UNIFORM: f64 = f64::MIN_POSITIVE;
13
14fn random_error(label: &str, message: impl Into<String>) -> RuntimeError {
15    build_runtime_error(message).with_builtin(label).build()
16}
17
/// Generator algorithms known to the runtime.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum RngAlgorithm {
    /// The 64-bit linear congruential generator implemented in this module.
    RunMatLcg,
}

impl RngAlgorithm {
    /// Name reported for the algorithm.
    ///
    /// NOTE(review): the LCG reports itself as `"twister"`, presumably so code that
    /// inspects `rng` sees MATLAB's familiar default generator name — confirm intent.
    pub(crate) fn as_str(&self) -> &'static str {
        match *self {
            Self::RunMatLcg => "twister",
        }
    }
}
30
31#[derive(Clone, Copy, Debug)]
32pub(crate) struct RngSnapshot {
33    pub state: u64,
34    pub seed: Option<u64>,
35    pub algorithm: RngAlgorithm,
36}
37
38impl RngSnapshot {
39    pub(crate) fn new(state: u64, seed: Option<u64>, algorithm: RngAlgorithm) -> Self {
40        Self {
41            state,
42            seed,
43            algorithm,
44        }
45    }
46}
47
48#[derive(Clone, Copy)]
49struct GlobalRng {
50    state: u64,
51    seed: Option<u64>,
52    algorithm: RngAlgorithm,
53}
54
55impl GlobalRng {
56    fn new() -> Self {
57        Self {
58            state: DEFAULT_RNG_SEED,
59            seed: Some(DEFAULT_USER_SEED),
60            algorithm: RngAlgorithm::RunMatLcg,
61        }
62    }
63
64    fn snapshot(&self) -> RngSnapshot {
65        RngSnapshot {
66            state: self.state,
67            seed: self.seed,
68            algorithm: self.algorithm,
69        }
70    }
71}
72
73impl From<RngSnapshot> for GlobalRng {
74    fn from(snapshot: RngSnapshot) -> Self {
75        Self {
76            state: snapshot.state,
77            seed: snapshot.seed,
78            algorithm: snapshot.algorithm,
79        }
80    }
81}
82
83static RNG_STATE: OnceLock<Mutex<GlobalRng>> = OnceLock::new();
84
85fn rng_state() -> &'static Mutex<GlobalRng> {
86    RNG_STATE.get_or_init(|| Mutex::new(GlobalRng::new()))
87}
88
89fn mix_seed(seed: u64) -> u64 {
90    if seed == 0 {
91        return DEFAULT_RNG_SEED;
92    }
93    let mut z = seed.wrapping_add(0x9e3779b97f4a7c15);
94    z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
95    z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
96    let mixed = z ^ (z >> 31);
97    if mixed == 0 {
98        DEFAULT_RNG_SEED
99    } else {
100        mixed
101    }
102}
103
104pub(crate) fn snapshot() -> BuiltinResult<RngSnapshot> {
105    rng_state()
106        .lock()
107        .map(|guard| guard.snapshot())
108        .map_err(|_| random_error("rng", "rng: failed to acquire RNG lock"))
109}
110
111pub(crate) fn apply_snapshot(snapshot: RngSnapshot) -> BuiltinResult<RngSnapshot> {
112    let mut guard = rng_state()
113        .lock()
114        .map_err(|_| random_error("rng", "rng: failed to acquire RNG lock"))?;
115    let previous = guard.snapshot();
116    guard.state = snapshot.state;
117    guard.seed = snapshot.seed;
118    guard.algorithm = snapshot.algorithm;
119    Ok(previous)
120}
121
122pub(crate) fn set_seed(seed: u64) -> BuiltinResult<RngSnapshot> {
123    let state = mix_seed(seed);
124    apply_snapshot(RngSnapshot::new(state, Some(seed), RngAlgorithm::RunMatLcg))
125}
126
127pub(crate) fn set_default() -> BuiltinResult<RngSnapshot> {
128    apply_snapshot(default_snapshot())
129}
130
131pub(crate) fn default_snapshot() -> RngSnapshot {
132    RngSnapshot::new(
133        DEFAULT_RNG_SEED,
134        Some(DEFAULT_USER_SEED),
135        RngAlgorithm::RunMatLcg,
136    )
137}
138
139pub(crate) fn generate_uniform(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
140    let mut guard = rng_state()
141        .lock()
142        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
143    let mut out = Vec::with_capacity(len);
144    for _ in 0..len {
145        out.push(next_uniform_state(&mut guard.state));
146    }
147    Ok(out)
148}
149
150pub(crate) fn generate_uniform_single(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
151    generate_uniform(len, label).map(|data| {
152        data.into_iter()
153            .map(|v| {
154                let value = v as f32;
155                value as f64
156            })
157            .collect()
158    })
159}
160
161pub(crate) fn skip_uniform(len: usize, label: &str) -> BuiltinResult<()> {
162    if len == 0 {
163        return Ok(());
164    }
165    let mut guard = rng_state()
166        .lock()
167        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
168    guard.state = advance_state(guard.state, len as u64);
169    Ok(())
170}
171
172fn advance_state(state: u64, mut delta: u64) -> u64 {
173    if delta == 0 {
174        return state;
175    }
176    let mut cur_mult = RNG_MULTIPLIER;
177    let mut cur_plus = RNG_INCREMENT;
178    let mut acc_mult = 1u64;
179    let mut acc_plus = 0u64;
180    while delta > 0 {
181        if (delta & 1) != 0 {
182            acc_mult = acc_mult.wrapping_mul(cur_mult);
183            acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus);
184        }
185        cur_plus = cur_plus.wrapping_mul(cur_mult.wrapping_add(1));
186        cur_mult = cur_mult.wrapping_mul(cur_mult);
187        delta >>= 1;
188    }
189    acc_mult.wrapping_mul(state).wrapping_add(acc_plus)
190}
191
192pub(crate) fn generate_complex(len: usize, label: &str) -> BuiltinResult<Vec<(f64, f64)>> {
193    let mut guard = rng_state()
194        .lock()
195        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
196    let mut out = Vec::with_capacity(len);
197    for _ in 0..len {
198        let re = next_uniform_state(&mut guard.state);
199        let im = next_uniform_state(&mut guard.state);
200        out.push((re, im));
201    }
202    Ok(out)
203}
204
205pub(crate) fn next_uniform_state(state: &mut u64) -> f64 {
206    *state = state
207        .wrapping_mul(RNG_MULTIPLIER)
208        .wrapping_add(RNG_INCREMENT);
209    let bits = *state >> RNG_SHIFT;
210    (bits as f64) * RNG_SCALE
211}
212
213fn next_normal_pair(state: &mut u64) -> (f64, f64) {
214    let mut u1 = next_uniform_state(state);
215    if u1 <= 0.0 {
216        u1 = MIN_UNIFORM;
217    }
218    let u2 = next_uniform_state(state);
219    let radius = (-2.0 * u1.ln()).sqrt();
220    let angle = 2.0 * PI * u2;
221    (radius * angle.cos(), radius * angle.sin())
222}
223
224pub(crate) fn generate_exponential(mu: f64, len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
225    let mut guard = rng_state()
226        .lock()
227        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
228    let mut out = Vec::with_capacity(len);
229    for _ in 0..len {
230        let u = next_uniform_state(&mut guard.state).max(MIN_UNIFORM);
231        out.push(-mu * u.ln());
232    }
233    Ok(out)
234}
235
236pub(crate) fn generate_normal_scaled(
237    mu: f64,
238    sigma: f64,
239    len: usize,
240    label: &str,
241) -> BuiltinResult<Vec<f64>> {
242    let mut guard = rng_state()
243        .lock()
244        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
245    let mut out = Vec::with_capacity(len);
246    while out.len() < len {
247        let (z0, z1) = next_normal_pair(&mut guard.state);
248        out.push(mu + sigma * z0);
249        if out.len() < len {
250            out.push(mu + sigma * z1);
251        }
252    }
253    Ok(out)
254}
255
256pub(crate) fn generate_uniform_scaled(
257    a: f64,
258    b: f64,
259    len: usize,
260    label: &str,
261) -> BuiltinResult<Vec<f64>> {
262    let mut guard = rng_state()
263        .lock()
264        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
265    let mut out = Vec::with_capacity(len);
266    for _ in 0..len {
267        out.push(a + (b - a) * next_uniform_state(&mut guard.state));
268    }
269    Ok(out)
270}
271
272pub(crate) fn generate_normal(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
273    let mut guard = rng_state()
274        .lock()
275        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
276    let mut out = Vec::with_capacity(len);
277    while out.len() < len {
278        let (z0, z1) = next_normal_pair(&mut guard.state);
279        out.push(z0);
280        if out.len() < len {
281            out.push(z1);
282        }
283    }
284    Ok(out)
285}
286
287pub(crate) fn generate_normal_complex(len: usize, label: &str) -> BuiltinResult<Vec<(f64, f64)>> {
288    let mut guard = rng_state()
289        .lock()
290        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
291    let mut out = Vec::with_capacity(len);
292    for _ in 0..len {
293        let (re, im) = next_normal_pair(&mut guard.state);
294        out.push((re, im));
295    }
296    Ok(out)
297}
298
/// Test helper: returns the global generator to its default configuration.
///
/// If the mutex is poisoned, the reset is silently skipped.
#[cfg(test)]
pub(crate) fn reset_rng() {
    match RNG_STATE.get() {
        None => {
            let _ = RNG_STATE.set(Mutex::new(GlobalRng::new()));
        }
        Some(mutex) => {
            if let Ok(mut guard) = mutex.lock() {
                *guard = default_snapshot().into();
            }
        }
    }
}
309
/// Test oracle: recomputes the exponential draws starting from `DEFAULT_RNG_SEED`.
///
/// Uses the same formula as `generate_exponential`.
#[cfg(test)]
pub(crate) fn expected_exponential_sequence(mu: f64, count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count)
        .map(|_| -mu * next_uniform_state(&mut state).max(MIN_UNIFORM).ln())
        .collect()
}
320
/// Test oracle: the default-state normal sequence mapped through `mu + sigma * z`.
#[cfg(test)]
pub(crate) fn expected_normal_scaled_sequence(mu: f64, sigma: f64, count: usize) -> Vec<f64> {
    expected_normal_sequence(count)
        .into_iter()
        .map(|z| mu + sigma * z)
        .collect()
}
334
/// Test oracle: the default-state uniform sequence mapped onto `a + (b - a) * u`.
#[cfg(test)]
pub(crate) fn expected_uniform_scaled_sequence(a: f64, b: f64, count: usize) -> Vec<f64> {
    let width = b - a;
    expected_uniform_sequence(count)
        .into_iter()
        .map(|u| a + width * u)
        .collect()
}
344
/// Test oracle: the first `count` uniform draws starting from `DEFAULT_RNG_SEED`.
#[cfg(test)]
pub(crate) fn expected_uniform_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    std::iter::repeat_with(|| next_uniform_state(&mut state))
        .take(count)
        .collect()
}
354
/// Test oracle: `count` complex values from the default state.
///
/// Each value is built from two consecutive uniform draws, real part first.
#[cfg(test)]
pub(crate) fn expected_complex_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    std::iter::repeat_with(|| {
        let re = next_uniform_state(&mut state);
        let im = next_uniform_state(&mut state);
        (re, im)
    })
    .take(count)
    .collect()
}
366
/// Test oracle: the first `count` standard-normal deviates from the default state.
///
/// The deviates are Box–Muller pairs flattened in order. An odd `count` drops the
/// last pair's second value.
#[cfg(test)]
pub(crate) fn expected_normal_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    std::iter::repeat_with(|| next_normal_pair(&mut state))
        .flat_map(|(z0, z1)| vec![z0, z1])
        .take(count)
        .collect()
}
380
/// Test oracle: `count` Box–Muller pairs from the default state, one pair per value.
#[cfg(test)]
pub(crate) fn expected_complex_normal_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count).map(|_| next_normal_pair(&mut state)).collect()
}
390
/// Lazily created mutex shared across tests.
#[cfg(test)]
static TEST_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();

/// Returns the shared test mutex.
///
/// Presumably tests hold it to serialise their use of the global RNG; confirm against
/// the callers.
#[cfg(test)]
pub(crate) fn test_lock() -> &'static Mutex<()> {
    TEST_MUTEX.get_or_init(Mutex::default)
}