runmat_runtime/builtins/common/
random.rs1use std::f64::consts::PI;
2use std::sync::{Mutex, OnceLock};
3
4use crate::{build_runtime_error, BuiltinResult, RuntimeError};
5
/// Internal LCG state of the default generator configuration; `mix_seed(0)` also maps here.
pub(crate) const DEFAULT_RNG_SEED: u64 = 0x9E37_79B9_7F4A_7C15;
/// Seed reported alongside the default state (`Some(DEFAULT_USER_SEED)`).
pub(crate) const DEFAULT_USER_SEED: u64 = 0;
/// Multiplier `a` of the 64-bit LCG step `state = a * state + c (mod 2^64)`.
const RNG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
/// Increment `c` of the LCG step.
const RNG_INCREMENT: u64 = 1;
/// Low-order state bits discarded so the remaining high bits fit exactly in an f64 mantissa.
const RNG_SHIFT: u32 = u64::BITS - f64::MANTISSA_DIGITS;
/// Scale (2^-53) mapping a 53-bit integer onto `[0, 1)`.
const RNG_SCALE: f64 = 1.0 / (1u64 << f64::MANTISSA_DIGITS) as f64;
/// Smallest positive normal f64; replaces a zero uniform draw before its logarithm is taken.
const MIN_UNIFORM: f64 = f64::MIN_POSITIVE;
13
14fn random_error(label: &str, message: impl Into<String>) -> RuntimeError {
15 build_runtime_error(message).with_builtin(label).build()
16}
17
/// Generator algorithms the runtime can report.
///
/// Only one exists today: the 64-bit linear congruential generator stepped by
/// `next_uniform_state`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum RngAlgorithm {
    /// 64-bit LCG: `state = state * RNG_MULTIPLIER + RNG_INCREMENT (mod 2^64)`.
    RunMatLcg,
}

impl RngAlgorithm {
    /// Name reported for this algorithm.
    ///
    /// NOTE(review): the LCG is reported as `"twister"`. This is presumably so code inspecting
    /// the generator type sees MATLAB's default generator name; confirm before changing.
    pub(crate) fn as_str(&self) -> &'static str {
        match *self {
            Self::RunMatLcg => "twister",
        }
    }
}
30
31#[derive(Clone, Copy, Debug)]
32pub(crate) struct RngSnapshot {
33 pub state: u64,
34 pub seed: Option<u64>,
35 pub algorithm: RngAlgorithm,
36}
37
38impl RngSnapshot {
39 pub(crate) fn new(state: u64, seed: Option<u64>, algorithm: RngAlgorithm) -> Self {
40 Self {
41 state,
42 seed,
43 algorithm,
44 }
45 }
46}
47
48#[derive(Clone, Copy)]
49struct GlobalRng {
50 state: u64,
51 seed: Option<u64>,
52 algorithm: RngAlgorithm,
53}
54
55impl GlobalRng {
56 fn new() -> Self {
57 Self {
58 state: DEFAULT_RNG_SEED,
59 seed: Some(DEFAULT_USER_SEED),
60 algorithm: RngAlgorithm::RunMatLcg,
61 }
62 }
63
64 fn snapshot(&self) -> RngSnapshot {
65 RngSnapshot {
66 state: self.state,
67 seed: self.seed,
68 algorithm: self.algorithm,
69 }
70 }
71}
72
73impl From<RngSnapshot> for GlobalRng {
74 fn from(snapshot: RngSnapshot) -> Self {
75 Self {
76 state: snapshot.state,
77 seed: snapshot.seed,
78 algorithm: snapshot.algorithm,
79 }
80 }
81}
82
83static RNG_STATE: OnceLock<Mutex<GlobalRng>> = OnceLock::new();
84
85fn rng_state() -> &'static Mutex<GlobalRng> {
86 RNG_STATE.get_or_init(|| Mutex::new(GlobalRng::new()))
87}
88
89fn mix_seed(seed: u64) -> u64 {
90 if seed == 0 {
91 return DEFAULT_RNG_SEED;
92 }
93 let mut z = seed.wrapping_add(0x9e3779b97f4a7c15);
94 z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
95 z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
96 let mixed = z ^ (z >> 31);
97 if mixed == 0 {
98 DEFAULT_RNG_SEED
99 } else {
100 mixed
101 }
102}
103
104pub(crate) fn snapshot() -> BuiltinResult<RngSnapshot> {
105 rng_state()
106 .lock()
107 .map(|guard| guard.snapshot())
108 .map_err(|_| random_error("rng", "rng: failed to acquire RNG lock"))
109}
110
111pub(crate) fn apply_snapshot(snapshot: RngSnapshot) -> BuiltinResult<RngSnapshot> {
112 let mut guard = rng_state()
113 .lock()
114 .map_err(|_| random_error("rng", "rng: failed to acquire RNG lock"))?;
115 let previous = guard.snapshot();
116 guard.state = snapshot.state;
117 guard.seed = snapshot.seed;
118 guard.algorithm = snapshot.algorithm;
119 Ok(previous)
120}
121
122pub(crate) fn set_seed(seed: u64) -> BuiltinResult<RngSnapshot> {
123 let state = mix_seed(seed);
124 apply_snapshot(RngSnapshot::new(state, Some(seed), RngAlgorithm::RunMatLcg))
125}
126
127pub(crate) fn set_default() -> BuiltinResult<RngSnapshot> {
128 apply_snapshot(default_snapshot())
129}
130
131pub(crate) fn default_snapshot() -> RngSnapshot {
132 RngSnapshot::new(
133 DEFAULT_RNG_SEED,
134 Some(DEFAULT_USER_SEED),
135 RngAlgorithm::RunMatLcg,
136 )
137}
138
139pub(crate) fn generate_uniform(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
140 let mut guard = rng_state()
141 .lock()
142 .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
143 let mut out = Vec::with_capacity(len);
144 for _ in 0..len {
145 out.push(next_uniform_state(&mut guard.state));
146 }
147 Ok(out)
148}
149
150pub(crate) fn generate_uniform_single(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
151 generate_uniform(len, label).map(|data| {
152 data.into_iter()
153 .map(|v| {
154 let value = v as f32;
155 value as f64
156 })
157 .collect()
158 })
159}
160
161pub(crate) fn skip_uniform(len: usize, label: &str) -> BuiltinResult<()> {
162 if len == 0 {
163 return Ok(());
164 }
165 let mut guard = rng_state()
166 .lock()
167 .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
168 guard.state = advance_state(guard.state, len as u64);
169 Ok(())
170}
171
172fn advance_state(state: u64, mut delta: u64) -> u64 {
173 if delta == 0 {
174 return state;
175 }
176 let mut cur_mult = RNG_MULTIPLIER;
177 let mut cur_plus = RNG_INCREMENT;
178 let mut acc_mult = 1u64;
179 let mut acc_plus = 0u64;
180 while delta > 0 {
181 if (delta & 1) != 0 {
182 acc_mult = acc_mult.wrapping_mul(cur_mult);
183 acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus);
184 }
185 cur_plus = cur_plus.wrapping_mul(cur_mult.wrapping_add(1));
186 cur_mult = cur_mult.wrapping_mul(cur_mult);
187 delta >>= 1;
188 }
189 acc_mult.wrapping_mul(state).wrapping_add(acc_plus)
190}
191
192pub(crate) fn generate_complex(len: usize, label: &str) -> BuiltinResult<Vec<(f64, f64)>> {
193 let mut guard = rng_state()
194 .lock()
195 .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
196 let mut out = Vec::with_capacity(len);
197 for _ in 0..len {
198 let re = next_uniform_state(&mut guard.state);
199 let im = next_uniform_state(&mut guard.state);
200 out.push((re, im));
201 }
202 Ok(out)
203}
204
205pub(crate) fn next_uniform_state(state: &mut u64) -> f64 {
206 *state = state
207 .wrapping_mul(RNG_MULTIPLIER)
208 .wrapping_add(RNG_INCREMENT);
209 let bits = *state >> RNG_SHIFT;
210 (bits as f64) * RNG_SCALE
211}
212
213fn next_normal_pair(state: &mut u64) -> (f64, f64) {
214 let mut u1 = next_uniform_state(state);
215 if u1 <= 0.0 {
216 u1 = MIN_UNIFORM;
217 }
218 let u2 = next_uniform_state(state);
219 let radius = (-2.0 * u1.ln()).sqrt();
220 let angle = 2.0 * PI * u2;
221 (radius * angle.cos(), radius * angle.sin())
222}
223
224pub(crate) fn generate_normal(len: usize, label: &str) -> BuiltinResult<Vec<f64>> {
225 let mut guard = rng_state()
226 .lock()
227 .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
228 let mut out = Vec::with_capacity(len);
229 while out.len() < len {
230 let (z0, z1) = next_normal_pair(&mut guard.state);
231 out.push(z0);
232 if out.len() < len {
233 out.push(z1);
234 }
235 }
236 Ok(out)
237}
238
239pub(crate) fn generate_normal_complex(len: usize, label: &str) -> BuiltinResult<Vec<(f64, f64)>> {
240 let mut guard = rng_state()
241 .lock()
242 .map_err(|_| random_error(label, format!("{label}: failed to acquire RNG lock")))?;
243 let mut out = Vec::with_capacity(len);
244 for _ in 0..len {
245 let (re, im) = next_normal_pair(&mut guard.state);
246 out.push((re, im));
247 }
248 Ok(out)
249}
250
/// Test helper: restores the global generator to `default_snapshot()`.
///
/// Initialisation goes through `rng_state()` (a single `get_or_init`) rather than a separate
/// `get` / `set` pair. As a result, a concurrent first use of the generator cannot make `set`
/// fail silently and leave the state un-reset. A poisoned mutex is still ignored: the reset is
/// best effort.
#[cfg(test)]
pub(crate) fn reset_rng() {
    if let Ok(mut guard) = rng_state().lock() {
        *guard = GlobalRng::from(default_snapshot());
    }
}
261
/// Test oracle: the first `count` uniform draws starting from `DEFAULT_RNG_SEED`.
#[cfg(test)]
pub(crate) fn expected_uniform_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count).map(|_| next_uniform_state(&mut state)).collect()
}
271
/// Test oracle: the first `count` uniform complex samples starting from `DEFAULT_RNG_SEED`.
///
/// Each element takes two consecutive draws, real part first.
#[cfg(test)]
pub(crate) fn expected_complex_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    let mut draw = || next_uniform_state(&mut state);
    (0..count)
        .map(|_| {
            let re = draw();
            let im = draw();
            (re, im)
        })
        .collect()
}
283
/// Test oracle: the first `count` normal samples starting from `DEFAULT_RNG_SEED`.
///
/// Pairs are flattened in order, and the surplus second value of a final pair (odd `count`)
/// is dropped.
#[cfg(test)]
pub(crate) fn expected_normal_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    let pairs = count / 2 + count % 2;
    let mut seq: Vec<f64> = (0..pairs)
        .flat_map(|_| {
            let (z0, z1) = next_normal_pair(&mut state);
            [z0, z1]
        })
        .collect();
    seq.truncate(count);
    seq
}
297
/// Test oracle: the first `count` complex normal samples starting from `DEFAULT_RNG_SEED`.
///
/// Each element is one full normal pair.
#[cfg(test)]
pub(crate) fn expected_complex_normal_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    std::iter::repeat_with(|| next_normal_pair(&mut state))
        .take(count)
        .collect()
}
307
/// Process-wide lock for tests, presumably held to serialise access to the global RNG.
#[cfg(test)]
static TEST_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();

/// Returns the shared test lock, creating it on first use.
#[cfg(test)]
pub(crate) fn test_lock() -> &'static Mutex<()> {
    TEST_MUTEX.get_or_init(Mutex::default)
}