Skip to main content

runmat_runtime/builtins/common/
random.rs

1use std::f64::consts::PI;
2use std::sync::{Mutex, OnceLock};
3
4use crate::{build_runtime_error, BuiltinResult, RuntimeError};
5
/// Raw generator state used for the default RNG configuration; `mix_seed`
/// also falls back to it whenever mixing would otherwise yield zero.
pub(crate) const DEFAULT_RNG_SEED: u64 = 0x9e37_79b9_7f4a_7c15;
/// User-facing seed recorded alongside the default generator state.
pub(crate) const DEFAULT_USER_SEED: u64 = 0;
/// Multiplier of the 64-bit linear congruential step.
const RNG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
/// Additive increment of the linear congruential step.
const RNG_INCREMENT: u64 = 1;
/// Number of low-order state bits discarded before converting to a float.
const RNG_SHIFT: u32 = 11;
/// Scale applied to the remaining `64 - RNG_SHIFT` (= 53) bits; equals 2^-53.
const RNG_SCALE: f64 = 1.0 / ((1u64 << (u64::BITS - RNG_SHIFT)) as f64);
/// Smallest positive normal `f64`, substituted for a zero uniform sample
/// before its logarithm is taken.
const MIN_UNIFORM: f64 = f64::MIN_POSITIVE;
13
14fn random_error(label: &str, message: impl Into<String>) -> RuntimeError {
15    build_runtime_error(message).with_builtin(label).build()
16}
17
/// Identifies the generator algorithm recorded in RNG state and snapshots.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) enum RngAlgorithm {
    /// The only algorithm currently defined.
    RunMatLcg,
}

impl RngAlgorithm {
    /// Returns the string name reported for this algorithm.
    ///
    /// `RunMatLcg` is reported as `"twister"`.
    /// NOTE(review): presumably kept for compatibility with callers that
    /// expect that name — confirm before changing.
    pub(crate) fn as_str(&self) -> &'static str {
        match *self {
            Self::RunMatLcg => "twister",
        }
    }
}
30
31#[derive(Clone, Copy, Debug)]
32pub(crate) struct RngSnapshot {
33    pub state: u64,
34    pub seed: Option<u64>,
35    pub algorithm: RngAlgorithm,
36}
37
38impl RngSnapshot {
39    pub(crate) fn new(state: u64, seed: Option<u64>, algorithm: RngAlgorithm) -> Self {
40        Self {
41            state,
42            seed,
43            algorithm,
44        }
45    }
46}
47
48#[derive(Clone, Copy)]
49struct GlobalRng {
50    state: u64,
51    seed: Option<u64>,
52    algorithm: RngAlgorithm,
53}
54
55impl GlobalRng {
56    fn new() -> Self {
57        Self {
58            state: DEFAULT_RNG_SEED,
59            seed: Some(DEFAULT_USER_SEED),
60            algorithm: RngAlgorithm::RunMatLcg,
61        }
62    }
63
64    fn snapshot(&self) -> RngSnapshot {
65        RngSnapshot {
66            state: self.state,
67            seed: self.seed,
68            algorithm: self.algorithm,
69        }
70    }
71}
72
73impl From<RngSnapshot> for GlobalRng {
74    fn from(snapshot: RngSnapshot) -> Self {
75        Self {
76            state: snapshot.state,
77            seed: snapshot.seed,
78            algorithm: snapshot.algorithm,
79        }
80    }
81}
82
83static RNG_STATE: OnceLock<Mutex<GlobalRng>> = OnceLock::new();
84
85fn rng_state() -> &'static Mutex<GlobalRng> {
86    RNG_STATE.get_or_init(|| Mutex::new(GlobalRng::new()))
87}
88
89fn mix_seed(seed: u64) -> u64 {
90    if seed == 0 {
91        return DEFAULT_RNG_SEED;
92    }
93    let mut z = seed.wrapping_add(0x9e3779b97f4a7c15);
94    z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
95    z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
96    let mixed = z ^ (z >> 31);
97    if mixed == 0 {
98        DEFAULT_RNG_SEED
99    } else {
100        mixed
101    }
102}
103
104pub(crate) fn snapshot() -> BuiltinResult<RngSnapshot> {
105    rng_state()
106        .lock()
107        .map(|guard| guard.snapshot())
108        .map_err(|_| random_error("rng", "rng: failed to acquire RNG lock"))
109}
110
111pub(crate) fn apply_snapshot(snapshot: RngSnapshot) -> BuiltinResult<RngSnapshot> {
112    let mut guard = rng_state()
113        .lock()
114        .map_err(|_| random_error("rng", "rng: failed to acquire RNG lock"))?;
115    let previous = guard.snapshot();
116    guard.state = snapshot.state;
117    guard.seed = snapshot.seed;
118    guard.algorithm = snapshot.algorithm;
119    Ok(previous)
120}
121
122pub(crate) fn set_seed(seed: u64) -> BuiltinResult<RngSnapshot> {
123    let state = mix_seed(seed);
124    apply_snapshot(RngSnapshot::new(state, Some(seed), RngAlgorithm::RunMatLcg))
125}
126
127pub(crate) fn set_default() -> BuiltinResult<RngSnapshot> {
128    apply_snapshot(default_snapshot())
129}
130
131pub(crate) fn default_snapshot() -> RngSnapshot {
132    RngSnapshot::new(
133        DEFAULT_RNG_SEED,
134        Some(DEFAULT_USER_SEED),
135        RngAlgorithm::RunMatLcg,
136    )
137}
138
139pub(crate) fn generate_uniform(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
140    let mut guard = rng_state()
141        .lock()
142        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
143    let mut out = Vec::with_capacity(len);
144    for _ in 0..len {
145        out.push(next_uniform_state(&mut guard.state));
146    }
147    Ok(out)
148}
149
150pub(crate) fn generate_uniform_single(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
151    generate_uniform(len, label).map(|data| {
152        data.into_iter()
153            .map(|v| {
154                let value = v as f32;
155                value as f64
156            })
157            .collect()
158    })
159}
160
161pub(crate) fn skip_uniform(len: usize, label: &str) -> BuiltinResult<()> {
162    if len == 0 {
163        return Ok(());
164    }
165    let mut guard = rng_state()
166        .lock()
167        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
168    guard.state = advance_state(guard.state, len as u64);
169    Ok(())
170}
171
172fn advance_state(state: u64, mut delta: u64) -> u64 {
173    if delta == 0 {
174        return state;
175    }
176    let mut cur_mult = RNG_MULTIPLIER;
177    let mut cur_plus = RNG_INCREMENT;
178    let mut acc_mult = 1u64;
179    let mut acc_plus = 0u64;
180    while delta > 0 {
181        if (delta & 1) != 0 {
182            acc_mult = acc_mult.wrapping_mul(cur_mult);
183            acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus);
184        }
185        cur_plus = cur_plus.wrapping_mul(cur_mult.wrapping_add(1));
186        cur_mult = cur_mult.wrapping_mul(cur_mult);
187        delta >>= 1;
188    }
189    acc_mult.wrapping_mul(state).wrapping_add(acc_plus)
190}
191
192pub(crate) fn generate_complex(len: usize, label: &str) -> BuiltinResult<Vec<(f64, f64)>> {
193    let mut guard = rng_state()
194        .lock()
195        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
196    let mut out = Vec::with_capacity(len);
197    for _ in 0..len {
198        let re = next_uniform_state(&mut guard.state);
199        let im = next_uniform_state(&mut guard.state);
200        out.push((re, im));
201    }
202    Ok(out)
203}
204
205pub(crate) fn next_uniform_state(state: &mut u64) -> f64 {
206    *state = state
207        .wrapping_mul(RNG_MULTIPLIER)
208        .wrapping_add(RNG_INCREMENT);
209    let bits = *state >> RNG_SHIFT;
210    (bits as f64) * RNG_SCALE
211}
212
213fn next_normal_pair(state: &mut u64) -> (f64, f64) {
214    let mut u1 = next_uniform_state(state);
215    if u1 <= 0.0 {
216        u1 = MIN_UNIFORM;
217    }
218    let u2 = next_uniform_state(state);
219    let radius = (-2.0 * u1.ln()).sqrt();
220    let angle = 2.0 * PI * u2;
221    (radius * angle.cos(), radius * angle.sin())
222}
223
224pub(crate) fn generate_normal(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
225    let mut guard = rng_state()
226        .lock()
227        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
228    let mut out = Vec::with_capacity(len);
229    while out.len() < len {
230        let (z0, z1) = next_normal_pair(&mut guard.state);
231        out.push(z0);
232        if out.len() < len {
233            out.push(z1);
234        }
235    }
236    Ok(out)
237}
238
239pub(crate) fn generate_normal_complex(len: usize, label: &str) -> BuiltinResult<Vec<(f64, f64)>> {
240    let mut guard = rng_state()
241        .lock()
242        .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
243    let mut out = Vec::with_capacity(len);
244    for _ in 0..len {
245        let (re, im) = next_normal_pair(&mut guard.state);
246        out.push((re, im));
247    }
248    Ok(out)
249}
250
/// Test helper: restores the shared generator to [`default_snapshot`].
///
/// Initialisation goes through `rng_state()` rather than a separate
/// `get()`/`set()` pair. With the pair, another thread could initialise (and
/// advance) the generator between the two calls, and the failed `set` would
/// silently leave it un-reset.
///
/// A lock that cannot be acquired is ignored, so the helper never panics.
#[cfg(test)]
pub(crate) fn reset_rng() {
    if let Ok(mut guard) = rng_state().lock() {
        *guard = GlobalRng::from(default_snapshot());
    }
}
261
/// Test helper: the first `count` uniform samples produced when starting from
/// `DEFAULT_RNG_SEED`, computed on a local state without touching the shared
/// generator.
#[cfg(test)]
pub(crate) fn expected_uniform_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count)
        .map(|_| next_uniform_state(&mut state))
        .collect()
}
271
/// Test helper: the first `count` pairs of uniform samples produced when
/// starting from `DEFAULT_RNG_SEED`, computed on a local state without
/// touching the shared generator.
#[cfg(test)]
pub(crate) fn expected_complex_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count)
        .map(|_| {
            let re = next_uniform_state(&mut state);
            let im = next_uniform_state(&mut state);
            (re, im)
        })
        .collect()
}
283
/// Test helper: the first `count` normal samples produced when starting from
/// `DEFAULT_RNG_SEED`, computed on a local state without touching the shared
/// generator.
///
/// Samples come in pairs; for an odd `count` the second value of the last
/// pair is dropped.
#[cfg(test)]
pub(crate) fn expected_normal_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    // Number of pairs needed to cover `count` samples (rounded up).
    let pairs = count / 2 + count % 2;
    (0..pairs)
        .flat_map(|_| {
            let (z0, z1) = next_normal_pair(&mut state);
            [z0, z1]
        })
        .take(count)
        .collect()
}
297
/// Test helper: the first `count` normal pairs produced when starting from
/// `DEFAULT_RNG_SEED`, computed on a local state without touching the shared
/// generator.
#[cfg(test)]
pub(crate) fn expected_complex_normal_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count)
        .map(|_| next_normal_pair(&mut state))
        .collect()
}
307
/// Lazily initialised storage for the mutex handed out by [`test_lock`].
#[cfg(test)]
static TEST_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();

/// Test helper: returns a process-wide mutex, created on first use.
///
/// NOTE(review): presumably held by tests to serialise access to the shared
/// generator — confirm against the call sites.
#[cfg(test)]
pub(crate) fn test_lock() -> &'static Mutex<()> {
    TEST_MUTEX.get_or_init(Mutex::default)
}