runmat_runtime/builtins/common/
random.rs1use std::f64::consts::PI;
2use std::sync::{Mutex, OnceLock};
3
/// Initial internal generator state; `mix_seed` also falls back to this value
/// for a zero seed or a zero mixing result.
pub(crate) const DEFAULT_RNG_SEED: u64 = 0x9e37_79b9_7f4a_7c15;
/// User-visible seed recorded alongside the default generator state.
pub(crate) const DEFAULT_USER_SEED: u64 = 0;
/// Multiplier of the wrapping 64-bit linear congruential step.
const RNG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
/// Additive increment of the linear congruential step.
const RNG_INCREMENT: u64 = 1;
/// Number of low-order state bits dropped per sample, leaving 53 bits.
const RNG_SHIFT: u32 = 11;
/// Scale factor `2^-53` that maps the retained 53 bits onto `[0, 1)`.
const RNG_SCALE: f64 = 1.0 / (1u64 << 53) as f64;
/// Substitute for a zero uniform sample before its logarithm is taken.
const MIN_UNIFORM: f64 = f64::MIN_POSITIVE;
11
/// Identifies the algorithm that drives the global generator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RngAlgorithm {
    /// The only generator implemented in this module.
    RunMatLcg,
}

impl RngAlgorithm {
    /// Returns the name string reported for this algorithm.
    ///
    /// NOTE(review): the LCG variant is reported as `"twister"`; presumably
    /// this mirrors MATLAB's default generator name — confirm with callers.
    pub(crate) fn as_str(&self) -> &'static str {
        match *self {
            Self::RunMatLcg => "twister",
        }
    }
}
24
25#[derive(Clone, Copy, Debug)]
26pub(crate) struct RngSnapshot {
27 pub state: u64,
28 pub seed: Option<u64>,
29 pub algorithm: RngAlgorithm,
30}
31
32impl RngSnapshot {
33 pub(crate) fn new(state: u64, seed: Option<u64>, algorithm: RngAlgorithm) -> Self {
34 Self {
35 state,
36 seed,
37 algorithm,
38 }
39 }
40}
41
42#[derive(Clone, Copy)]
43struct GlobalRng {
44 state: u64,
45 seed: Option<u64>,
46 algorithm: RngAlgorithm,
47}
48
49impl GlobalRng {
50 fn new() -> Self {
51 Self {
52 state: DEFAULT_RNG_SEED,
53 seed: Some(DEFAULT_USER_SEED),
54 algorithm: RngAlgorithm::RunMatLcg,
55 }
56 }
57
58 fn snapshot(&self) -> RngSnapshot {
59 RngSnapshot {
60 state: self.state,
61 seed: self.seed,
62 algorithm: self.algorithm,
63 }
64 }
65}
66
67impl From<RngSnapshot> for GlobalRng {
68 fn from(snapshot: RngSnapshot) -> Self {
69 Self {
70 state: snapshot.state,
71 seed: snapshot.seed,
72 algorithm: snapshot.algorithm,
73 }
74 }
75}
76
77static RNG_STATE: OnceLock<Mutex<GlobalRng>> = OnceLock::new();
78
79fn rng_state() -> &'static Mutex<GlobalRng> {
80 RNG_STATE.get_or_init(|| Mutex::new(GlobalRng::new()))
81}
82
83fn mix_seed(seed: u64) -> u64 {
84 if seed == 0 {
85 return DEFAULT_RNG_SEED;
86 }
87 let mut z = seed.wrapping_add(0x9e3779b97f4a7c15);
88 z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
89 z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
90 let mixed = z ^ (z >> 31);
91 if mixed == 0 {
92 DEFAULT_RNG_SEED
93 } else {
94 mixed
95 }
96}
97
98pub(crate) fn snapshot() -> Result<RngSnapshot, String> {
99 rng_state()
100 .lock()
101 .map(|guard| guard.snapshot())
102 .map_err(|_| "rng: failed to acquire RNG lock".to_string())
103}
104
105pub(crate) fn apply_snapshot(snapshot: RngSnapshot) -> Result<RngSnapshot, String> {
106 let mut guard = rng_state()
107 .lock()
108 .map_err(|_| "rng: failed to acquire RNG lock".to_string())?;
109 let previous = guard.snapshot();
110 guard.state = snapshot.state;
111 guard.seed = snapshot.seed;
112 guard.algorithm = snapshot.algorithm;
113 Ok(previous)
114}
115
116pub(crate) fn set_seed(seed: u64) -> Result<RngSnapshot, String> {
117 let state = mix_seed(seed);
118 apply_snapshot(RngSnapshot::new(state, Some(seed), RngAlgorithm::RunMatLcg))
119}
120
121pub(crate) fn set_default() -> Result<RngSnapshot, String> {
122 apply_snapshot(default_snapshot())
123}
124
125pub(crate) fn default_snapshot() -> RngSnapshot {
126 RngSnapshot::new(
127 DEFAULT_RNG_SEED,
128 Some(DEFAULT_USER_SEED),
129 RngAlgorithm::RunMatLcg,
130 )
131}
132
133pub(crate) fn generate_uniform(len: usize, label: &str) -> Result<Vec<f64>, String> {
134 let mut guard = rng_state()
135 .lock()
136 .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
137 let mut out = Vec::with_capacity(len);
138 for _ in 0..len {
139 out.push(next_uniform_state(&mut guard.state));
140 }
141 Ok(out)
142}
143
144pub(crate) fn generate_uniform_single(len: usize, label: &str) -> Result<Vec<f64>, String> {
145 generate_uniform(len, label).map(|data| {
146 data.into_iter()
147 .map(|v| {
148 let value = v as f32;
149 value as f64
150 })
151 .collect()
152 })
153}
154
155pub(crate) fn skip_uniform(len: usize, label: &str) -> Result<(), String> {
156 if len == 0 {
157 return Ok(());
158 }
159 let mut guard = rng_state()
160 .lock()
161 .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
162 guard.state = advance_state(guard.state, len as u64);
163 Ok(())
164}
165
166fn advance_state(state: u64, mut delta: u64) -> u64 {
167 if delta == 0 {
168 return state;
169 }
170 let mut cur_mult = RNG_MULTIPLIER;
171 let mut cur_plus = RNG_INCREMENT;
172 let mut acc_mult = 1u64;
173 let mut acc_plus = 0u64;
174 while delta > 0 {
175 if (delta & 1) != 0 {
176 acc_mult = acc_mult.wrapping_mul(cur_mult);
177 acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus);
178 }
179 cur_plus = cur_plus.wrapping_mul(cur_mult.wrapping_add(1));
180 cur_mult = cur_mult.wrapping_mul(cur_mult);
181 delta >>= 1;
182 }
183 acc_mult.wrapping_mul(state).wrapping_add(acc_plus)
184}
185
186pub(crate) fn generate_complex(len: usize, label: &str) -> Result<Vec<(f64, f64)>, String> {
187 let mut guard = rng_state()
188 .lock()
189 .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
190 let mut out = Vec::with_capacity(len);
191 for _ in 0..len {
192 let re = next_uniform_state(&mut guard.state);
193 let im = next_uniform_state(&mut guard.state);
194 out.push((re, im));
195 }
196 Ok(out)
197}
198
199pub(crate) fn next_uniform_state(state: &mut u64) -> f64 {
200 *state = state
201 .wrapping_mul(RNG_MULTIPLIER)
202 .wrapping_add(RNG_INCREMENT);
203 let bits = *state >> RNG_SHIFT;
204 (bits as f64) * RNG_SCALE
205}
206
207fn next_normal_pair(state: &mut u64) -> (f64, f64) {
208 let mut u1 = next_uniform_state(state);
209 if u1 <= 0.0 {
210 u1 = MIN_UNIFORM;
211 }
212 let u2 = next_uniform_state(state);
213 let radius = (-2.0 * u1.ln()).sqrt();
214 let angle = 2.0 * PI * u2;
215 (radius * angle.cos(), radius * angle.sin())
216}
217
218pub(crate) fn generate_normal(len: usize, label: &str) -> Result<Vec<f64>, String> {
219 let mut guard = rng_state()
220 .lock()
221 .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
222 let mut out = Vec::with_capacity(len);
223 while out.len() < len {
224 let (z0, z1) = next_normal_pair(&mut guard.state);
225 out.push(z0);
226 if out.len() < len {
227 out.push(z1);
228 }
229 }
230 Ok(out)
231}
232
233pub(crate) fn generate_normal_complex(len: usize, label: &str) -> Result<Vec<(f64, f64)>, String> {
234 let mut guard = rng_state()
235 .lock()
236 .map_err(|_| format!("{label}: failed to acquire RNG lock"))?;
237 let mut out = Vec::with_capacity(len);
238 for _ in 0..len {
239 let (re, im) = next_normal_pair(&mut guard.state);
240 out.push((re, im));
241 }
242 Ok(out)
243}
244
/// Test helper: puts the global generator back into its default configuration.
///
/// If the generator has never been created it is installed fresh; otherwise
/// the existing one is overwritten in place. A poisoned lock is silently left
/// untouched.
#[cfg(test)]
pub(crate) fn reset_rng() {
    match RNG_STATE.get() {
        None => {
            // Losing a race to another initialiser is fine; ignore the result.
            let _ = RNG_STATE.set(Mutex::new(GlobalRng::new()));
        }
        Some(shared) => {
            if let Ok(mut rng) = shared.lock() {
                *rng = default_snapshot().into();
            }
        }
    }
}
255
/// Test helper: the first `count` uniform samples produced from the default
/// state, computed on a private copy so the global generator is untouched.
#[cfg(test)]
pub(crate) fn expected_uniform_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count)
        .map(|_| next_uniform_state(&mut state))
        .collect()
}
265
/// Test helper: the first `count` complex uniform samples produced from the
/// default state, real part drawn before the imaginary part.
#[cfg(test)]
pub(crate) fn expected_complex_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count)
        .map(|_| {
            let re = next_uniform_state(&mut state);
            let im = next_uniform_state(&mut state);
            (re, im)
        })
        .collect()
}
277
/// Test helper: the first `count` normal deviates produced from the default
/// state. Deviates come in pairs; for an odd `count` the second deviate of the
/// last pair is dropped.
#[cfg(test)]
pub(crate) fn expected_normal_sequence(count: usize) -> Vec<f64> {
    let mut state = DEFAULT_RNG_SEED;
    std::iter::repeat_with(|| next_normal_pair(&mut state))
        .flat_map(|(z0, z1)| [z0, z1])
        .take(count)
        .collect()
}
291
/// Test helper: the first `count` complex normal samples produced from the
/// default state, one `next_normal_pair` call per sample.
#[cfg(test)]
pub(crate) fn expected_complex_normal_sequence(count: usize) -> Vec<(f64, f64)> {
    let mut state = DEFAULT_RNG_SEED;
    (0..count).map(|_| next_normal_pair(&mut state)).collect()
}
301
/// Lazily created mutex handed out by `test_lock`.
#[cfg(test)]
static TEST_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();

/// Test helper: returns the single shared test mutex, creating it on first use.
///
/// NOTE(review): presumably tests hold this lock to serialise access to the
/// global generator — confirm against the call sites.
#[cfg(test)]
pub(crate) fn test_lock() -> &'static Mutex<()> {
    TEST_MUTEX.get_or_init(Mutex::default)
}