// runmat_runtime/builtins/common/random.rs

1use std::f64::consts::PI;
2use std::sync::{Mutex, OnceLock};
3
/// Initial LCG state of the global generator; `mix_seed` also maps seed `0` to this value.
pub(crate) const DEFAULT_RNG_SEED: u64 = 0x9e37_79b9_7f4a_7c15;
/// User-visible seed recorded for the default generator state.
pub(crate) const DEFAULT_USER_SEED: u64 = 0;
/// Multiplier of the 64-bit linear congruential step `state * a + c`.
const RNG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
/// Increment `c` of the linear congruential step.
const RNG_INCREMENT: u64 = 1;
/// Right shift applied to the state so that its top 53 bits remain.
const RNG_SHIFT: u32 = 11;
/// 2^-53: scales the 53-bit value `state >> RNG_SHIFT` onto `[0, 1)`.
const RNG_SCALE: f64 = 1.0 / ((1u64 << 53) as f64);
/// Substitute for a zero uniform draw before its logarithm is taken (Box–Muller).
const MIN_UNIFORM: f64 = f64::MIN_POSITIVE;
11
/// Generator algorithm backing the global RNG. Only one algorithm exists today.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum RngAlgorithm {
    /// RunMat's 64-bit linear congruential generator.
    RunMatLcg,
}

impl RngAlgorithm {
    /// Name reported for this algorithm.
    ///
    /// NOTE(review): the LCG is reported as `"twister"`, presumably for compatibility
    /// with callers that expect MATLAB's default generator name — confirm.
    pub(crate) fn as_str(&self) -> &'static str {
        match *self {
            Self::RunMatLcg => "twister",
        }
    }
}
24
25#[derive(Clone, Copy, Debug)]
26pub(crate) struct RngSnapshot {
27    pub state: u64,
28    pub seed: Option<u64>,
29    pub algorithm: RngAlgorithm,
30}
31
32impl RngSnapshot {
33    pub(crate) fn new(state: u64, seed: Option<u64>, algorithm: RngAlgorithm) -> Self {
34        Self {
35            state,
36            seed,
37            algorithm,
38        }
39    }
40}
41
42#[derive(Clone, Copy)]
43struct GlobalRng {
44    state: u64,
45    seed: Option<u64>,
46    algorithm: RngAlgorithm,
47}
48
49impl GlobalRng {
50    fn new() -> Self {
51        Self {
52            state: DEFAULT_RNG_SEED,
53            seed: Some(DEFAULT_USER_SEED),
54            algorithm: RngAlgorithm::RunMatLcg,
55        }
56    }
57
58    fn snapshot(&self) -> RngSnapshot {
59        RngSnapshot {
60            state: self.state,
61            seed: self.seed,
62            algorithm: self.algorithm,
63        }
64    }
65}
66
67impl From<RngSnapshot> for GlobalRng {
68    fn from(snapshot: RngSnapshot) -> Self {
69        Self {
70            state: snapshot.state,
71            seed: snapshot.seed,
72            algorithm: snapshot.algorithm,
73        }
74    }
75}
76
77static RNG_STATE: OnceLock<Mutex<GlobalRng>> = OnceLock::new();
78
79fn rng_state() -> &'static Mutex<GlobalRng> {
80    RNG_STATE.get_or_init(|| Mutex::new(GlobalRng::new()))
81}
82
83fn mix_seed(seed: u64) -> u64 {
84    if seed == 0 {
85        return DEFAULT_RNG_SEED;
86    }
87    let mut z = seed.wrapping_add(0x9e3779b97f4a7c15);
88    z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
89    z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
90    let mixed = z ^ (z >> 31);
91    if mixed == 0 {
92        DEFAULT_RNG_SEED
93    } else {
94        mixed
95    }
96}
97
98pub(crate) fn snapshot() -> Result<RngSnapshot, String> {
99    rng_state()
100        .lock()
101        .map(|guard| guard.snapshot())
102        .map_err(|_| "rng: failed to acquire RNG lock".to_string())
103}
104
105pub(crate) fn apply_snapshot(snapshot: RngSnapshot) -> Result<RngSnapshot, String> {
106    let mut guard = rng_state()
107        .lock()
108        .map_err(|_| "rng: failed to acquire RNG lock".to_string())?;
109    let previous = guard.snapshot();
110    guard.state = snapshot.state;
111    guard.seed = snapshot.seed;
112    guard.algorithm = snapshot.algorithm;
113    Ok(previous)
114}
115
116pub(crate) fn set_seed(seed: u64) -> Result<RngSnapshot, String> {
117    let state = mix_seed(seed);
118    apply_snapshot(RngSnapshot::new(state, Some(seed), RngAlgorithm::RunMatLcg))
119}
120
121pub(crate) fn set_default() -> Result<RngSnapshot, String> {
122    apply_snapshot(default_snapshot())
123}
124
125pub(crate) fn default_snapshot() -> RngSnapshot {
126    RngSnapshot::new(
127        DEFAULT_RNG_SEED,
128        Some(DEFAULT_USER_SEED),
129        RngAlgorithm::RunMatLcg,
130    )
131}
132
133pub(crate) fn generate_uniform(len: usize, label: &str) -> Result<Vec<f64>, String> {
134    let mut guard = rng_state()
135        .lock()
136        .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
137    let mut out = Vec::with_capacity(len);
138    for _ in 0..len {
139        out.push(next_uniform_state(&mut guard.state));
140    }
141    Ok(out)
142}
143
144pub(crate) fn generate_uniform_single(len: usize, label: &str) -> Result<Vec<f64>, String> {
145    generate_uniform(len, label).map(|data| {
146        data.into_iter()
147            .map(|v| {
148                let value = v as f32;
149                value as f64
150            })
151            .collect()
152    })
153}
154
155pub(crate) fn skip_uniform(len: usize, label: &str) -> Result<(), String> {
156    if len == 0 {
157        return Ok(());
158    }
159    let mut guard = rng_state()
160        .lock()
161        .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
162    guard.state = advance_state(guard.state, len as u64);
163    Ok(())
164}
165
166fn advance_state(state: u64, mut delta: u64) -> u64 {
167    if delta == 0 {
168        return state;
169    }
170    let mut cur_mult = RNG_MULTIPLIER;
171    let mut cur_plus = RNG_INCREMENT;
172    let mut acc_mult = 1u64;
173    let mut acc_plus = 0u64;
174    while delta > 0 {
175        if (delta & 1) != 0 {
176            acc_mult = acc_mult.wrapping_mul(cur_mult);
177            acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus);
178        }
179        cur_plus = cur_plus.wrapping_mul(cur_mult.wrapping_add(1));
180        cur_mult = cur_mult.wrapping_mul(cur_mult);
181        delta >>= 1;
182    }
183    acc_mult.wrapping_mul(state).wrapping_add(acc_plus)
184}
185
186pub(crate) fn generate_complex(len: usize, label: &str) -> Result<Vec<(f64, f64)>, String> {
187    let mut guard = rng_state()
188        .lock()
189        .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
190    let mut out = Vec::with_capacity(len);
191    for _ in 0..len {
192        let re = next_uniform_state(&mut guard.state);
193        let im = next_uniform_state(&mut guard.state);
194        out.push((re, im));
195    }
196    Ok(out)
197}
198
199pub(crate) fn next_uniform_state(state: &mut u64) -> f64 {
200    *state = state
201        .wrapping_mul(RNG_MULTIPLIER)
202        .wrapping_add(RNG_INCREMENT);
203    let bits = *state >> RNG_SHIFT;
204    (bits as f64) * RNG_SCALE
205}
206
207fn next_normal_pair(state: &mut u64) -> (f64, f64) {
208    let mut u1 = next_uniform_state(state);
209    if u1 <= 0.0 {
210        u1 = MIN_UNIFORM;
211    }
212    let u2 = next_uniform_state(state);
213    let radius = (-2.0 * u1.ln()).sqrt();
214    let angle = 2.0 * PI * u2;
215    (radius * angle.cos(), radius * angle.sin())
216}
217
218pub(crate) fn generate_normal(len: usize, label: &str) -> Result<Vec<f64>, String> {
219    let mut guard = rng_state()
220        .lock()
221        .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
222    let mut out = Vec::with_capacity(len);
223    while out.len() < len {
224        let (z0, z1) = next_normal_pair(&mut guard.state);
225        out.push(z0);
226        if out.len() < len {
227            out.push(z1);
228        }
229    }
230    Ok(out)
231}
232
233pub(crate) fn generate_normal_complex(len: usize, label: &str) -> Result<Vec<(f64, f64)>, String> {
234    let mut guard = rng_state()
235        .lock()
236        .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
237    let mut out = Vec::with_capacity(len);
238    for _ in 0..len {
239        let (re, im) = next_normal_pair(&mut guard.state);
240        out.push((re, im));
241    }
242    Ok(out)
243}
244
/// Test helper: restores the global generator to `default_snapshot()`, or creates it in the
/// default configuration if it does not exist yet.
///
/// If the RNG mutex is poisoned, the reset is silently skipped.
#[cfg(test)]
pub(crate) fn reset_rng() {
    match RNG_STATE.get() {
        Some(mutex) => {
            if let Ok(mut guard) = mutex.lock() {
                *guard = default_snapshot().into();
            }
        }
        None => {
            // Losing an initialisation race is harmless: the winner was also built
            // with `GlobalRng::new()`.
            let _ = RNG_STATE.set(Mutex::new(GlobalRng::new()));
        }
    }
}
255
/// Test helper: the first `count` uniform draws from a fresh `DEFAULT_RNG_SEED` state.
#[cfg(test)]
pub(crate) fn expected_uniform_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    std::iter::repeat_with(|| next_uniform_state(&mut state))
        .take(count)
        .collect()
}
265
/// Test helper: the first `count` complex uniform samples (real drawn before imaginary)
/// from a fresh `DEFAULT_RNG_SEED` state.
#[cfg(test)]
pub(crate) fn expected_complex_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count)
        .map(|_| {
            let re = next_uniform_state(&mut state);
            let im = next_uniform_state(&mut state);
            (re, im)
        })
        .collect()
}
277
/// Test helper: the first `count` normal draws from a fresh `DEFAULT_RNG_SEED` state.
///
/// Pairs are consumed whole, so for odd `count` the final pair's second value is dropped.
#[cfg(test)]
pub(crate) fn expected_normal_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    let mut seq = Vec::with_capacity(count);
    while seq.len() < count {
        let (first, second) = next_normal_pair(&mut state);
        let room = count - seq.len();
        seq.extend([first, second].iter().copied().take(room));
    }
    seq
}
291
/// Test helper: the first `count` complex normal samples (one Box–Muller pair each) from a
/// fresh `DEFAULT_RNG_SEED` state.
#[cfg(test)]
pub(crate) fn expected_complex_normal_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count).map(|_| next_normal_pair(&mut state)).collect()
}
301
/// Lazily-created mutex shared by tests.
#[cfg(test)]
static TEST_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();

/// Returns the shared test mutex, creating it on first use.
/// NOTE(review): presumably held by tests that touch the global RNG so they don't interleave — confirm.
#[cfg(test)]
pub(crate) fn test_lock() -> &'static Mutex<()> {
    TEST_MUTEX.get_or_init(Default::default)
}