/// Chooses which generator the `fill_*_like` helpers draw their samples from.
///
/// Variant names are part of the serialized form when the `serialize` feature
/// is enabled, so rename them with care.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum RngBackend {
    /// Counter-based `Philox4x32` stream; the default backend.
    #[default]
    Philox,
    /// `MinstdRand0` engine shaped by `StdNormalDist`; presumably meant to
    /// reproduce ONNX Runtime's random ops (inferred from the name — confirm).
    Ort,
    /// No sampling at all: every output element is written as `0.0`.
    Zero,
}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq)]
52#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
53pub struct RngOptions {
54 pub seed: u64,
56 pub backend: RngBackend,
57}
58
59impl Default for RngOptions {
60 fn default() -> Self {
61 Self {
62 seed: 42,
63 backend: RngBackend::Philox,
64 }
65 }
66}
67
68impl RngOptions {
69 pub const fn new(seed: u64, backend: RngBackend) -> Self {
70 Self { seed, backend }
71 }
72
73 pub fn philox(seed: u64) -> Self {
74 Self {
75 seed,
76 backend: RngBackend::Philox,
77 }
78 }
79
80 pub fn ort(seed: u64) -> Self {
81 Self {
82 seed,
83 backend: RngBackend::Ort,
84 }
85 }
86
87 pub fn zero() -> Self {
88 Self {
89 seed: 0,
90 backend: RngBackend::Zero,
91 }
92 }
93}
94
/// Mixes a per-call `key` into the `global` seed.
///
/// `key` is scaled by the 64-bit golden-ratio constant and added to `global`,
/// both with wrapping arithmetic. Consecutive keys therefore land one
/// `0x9E37_79B9_7F4A_7C15` step apart, and `key == 0` returns `global` unchanged.
///
/// This is an affine map, not a hash: `combine_seed(g + GAMMA, 0)` equals
/// `combine_seed(g, 1)`, so distinct `(global, key)` pairs can collide.
pub fn combine_seed(global: u64, key: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    let spread = key.wrapping_mul(GOLDEN_GAMMA);
    spread.wrapping_add(global)
}
99
/// Derives the 32-bit seed for the `Ort` backend's engine.
///
/// An explicit per-op `op_seed` takes precedence. It is converted with Rust's
/// saturating float cast: the fraction truncates, negatives and NaN become
/// `0`, and values above `u32::MAX` clamp. In that case `global` and `key` are
/// ignored.
///
/// Without an `op_seed`, the wrapping sum `global + key` is truncated to its
/// low 32 bits.
pub fn ort_engine_seed(global: u64, key: u64, op_seed: Option<f32>) -> u32 {
    match op_seed {
        Some(explicit) => explicit as u32,
        None => global.wrapping_add(key) as u32,
    }
}
108
109pub fn fill_normal_like(
111 out: &mut [f32],
112 mean: f32,
113 scale: f32,
114 opts: RngOptions,
115 key: u64,
116 op_seed: Option<f32>,
117) {
118 match opts.backend {
119 RngBackend::Zero => out.fill(0.0),
120 RngBackend::Philox => {
121 let mut rng = Philox4x32::new(combine_seed(opts.seed, key));
122 for v in out.iter_mut() {
123 *v = mean + scale * rng.normal();
124 }
125 }
126 RngBackend::Ort => {
127 let mut eng = MinstdRand0::new(ort_engine_seed(opts.seed, key, op_seed));
128 let mut dist = StdNormalDist::new(mean, scale);
129 for v in out.iter_mut() {
130 *v = dist.sample(&mut eng);
131 }
132 }
133 }
134}
135
136pub fn fill_uniform_like(
138 out: &mut [f32],
139 low: f32,
140 high: f32,
141 opts: RngOptions,
142 key: u64,
143 op_seed: Option<f32>,
144) {
145 match opts.backend {
146 RngBackend::Zero => out.fill(0.0),
147 RngBackend::Philox => {
148 let mut rng = Philox4x32::new(combine_seed(opts.seed, key));
149 for v in out.iter_mut() {
150 *v = rng.uniform(low, high);
151 }
152 }
153 RngBackend::Ort => {
154 let mut eng = MinstdRand0::new(ort_engine_seed(opts.seed, key, op_seed));
155 for v in out.iter_mut() {
156 *v = low + (high - low) * eng.unit_f32();
157 }
158 }
159 }
160}
161
/// Counter-based Philox4x32-10 generator (Salmon et al., "Parallel Random
/// Numbers: As Easy as 1, 2, 3", SC'11; constants as in Random123).
///
/// The 64-bit seed supplies the two 32-bit key words. A 128-bit counter starts
/// at zero and is bumped once per block. Each block runs ten Philox rounds and
/// yields four `u32` outputs, which are handed out in order before the next
/// block is computed.
#[derive(Debug, Clone, Copy)]
pub struct Philox4x32 {
    /// Key words: `[low 32 bits of seed, high 32 bits of seed]`.
    seed: [u32; 2],
    /// 128-bit block counter, least-significant word first.
    counter: [u32; 4],
    /// Outputs of the most recently computed block.
    buffer: [u32; 4],
    /// Index of the next unread word in `buffer`; `4` means "exhausted".
    cursor: u8,
}

impl Philox4x32 {
    /// Round multipliers (`PHILOX_M4x32_0` / `PHILOX_M4x32_1` in Random123).
    /// `M0` was previously `0xD256_1A75`, which is not the Philox4x32 constant.
    const M0: u32 = 0xD251_1F53;
    const M1: u32 = 0xCD9E_8D57;
    /// Weyl key-schedule increments (`PHILOX_W32_0` / `PHILOX_W32_1`).
    const W0: u32 = 0x9E37_79B9;
    const W1: u32 = 0xBB67_AE85;
    /// Round count of the standard Philox4x32-10 variant.
    const ROUNDS: usize = 10;

    /// Creates a generator keyed by `seed` with the counter at zero. No block
    /// is computed until the first draw.
    pub const fn new(seed: u64) -> Self {
        Self {
            // Truncating casts intentionally split the seed into two words.
            seed: [seed as u32, (seed >> 32) as u32],
            counter: [0; 4],
            buffer: [0; 4],
            cursor: 4,
        }
    }

    /// Returns the `(high, low)` 32-bit halves of the 64-bit product `a * b`.
    fn mulhilo(a: u32, b: u32) -> (u32, u32) {
        let product = u64::from(a) * u64::from(b);
        ((product >> 32) as u32, product as u32)
    }

    /// Applies one Philox round to `state` in place, using `key`.
    fn round(state: &mut [u32; 4], key: [u32; 2]) {
        let (hi0, lo0) = Self::mulhilo(Self::M0, state[0]);
        let (hi1, lo1) = Self::mulhilo(Self::M1, state[2]);
        *state = [
            hi1 ^ state[1] ^ key[0],
            lo1,
            hi0 ^ state[3] ^ key[1],
            lo0,
        ];
    }

    /// Computes the block for the current counter, rewinds the cursor, then
    /// advances the counter.
    fn fill_buffer(&mut self) {
        let mut block = self.counter;
        let mut key = self.seed;
        for _ in 0..Self::ROUNDS {
            Self::round(&mut block, key);
            key[0] = key[0].wrapping_add(Self::W0);
            key[1] = key[1].wrapping_add(Self::W1);
        }
        self.buffer = block;
        self.cursor = 0;
        self.increment_counter();
    }

    /// Adds one to the 128-bit counter, carrying across words. The counter
    /// wraps to zero after the all-ones value.
    fn increment_counter(&mut self) {
        for word in &mut self.counter {
            let (next, overflowed) = word.overflowing_add(1);
            *word = next;
            if !overflowed {
                break;
            }
        }
    }

    /// Returns the next raw 32-bit output.
    pub fn next_u32(&mut self) -> u32 {
        if self.cursor >= 4 {
            self.fill_buffer();
        }
        let v = self.buffer[usize::from(self.cursor)];
        self.cursor += 1;
        v
    }

    /// Uniform `f32` in `[0, 1)` built from the top 24 bits of one output.
    /// Every result is exactly representable.
    pub fn next_f32(&mut self) -> f32 {
        let bits = self.next_u32() >> 8;
        bits as f32 / (1u32 << 24) as f32
    }

    /// Returns `lo + u * (hi - lo)` with `u` from [`Self::next_f32`].
    pub fn uniform(&mut self, lo: f32, hi: f32) -> f32 {
        lo + self.next_f32() * (hi - lo)
    }

    /// Standard-normal sample via the Box–Muller transform.
    ///
    /// Only the cosine branch is used; the sine partner is discarded, so each
    /// call consumes two outputs.
    pub fn normal(&mut self) -> f32 {
        // Keep u1 strictly positive so `ln` stays finite.
        let u1 = self.next_f32().max(f32::MIN_POSITIVE);
        let u2 = self.next_f32();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        r * theta.cos()
    }

    /// Overwrites `out` with successive [`Self::next_f32`] samples.
    pub fn fill_uniform(&mut self, out: &mut [f32]) {
        for v in out {
            *v = self.next_f32();
        }
    }

    /// Overwrites `out` with successive [`Self::normal`] samples.
    pub fn fill_normal(&mut self, out: &mut [f32]) {
        for v in out {
            *v = self.normal();
        }
    }
}
275
/// Multiplicative linear-congruential engine: `x <- A * x mod (2^31 - 1)`.
///
/// NOTE(review): despite the name, `A = 48_271` is the multiplier of C++
/// `std::minstd_rand`; `std::minstd_rand0` uses `16_807`. `minstd_rand` is
/// libc++'s `std::default_random_engine`, while libstdc++ uses `minstd_rand0`.
/// Confirm which ONNX Runtime build the `Ort` backend targets before changing
/// the constant.
#[derive(Debug, Clone, Copy)]
struct MinstdRand0 {
    /// Current LCG state. It is always in `1..M` when built via `new`, because
    /// `M` is prime and `A` is not a multiple of it.
    state: u32,
}

impl MinstdRand0 {
    const A: u32 = 48_271;
    const M: u32 = 2_147_483_647;

    /// Seeds the engine the way C++ `linear_congruential_engine::seed` does
    /// for an increment of 0: the seed is reduced modulo `M`, and a zero
    /// residue is replaced by `1`.
    ///
    /// Without that replacement, seeds such as `0`, `M` or `2 * M` leave the
    /// state at 0. Zero is a fixed point of the recurrence, so every draw would
    /// be 0 and `StdNormalDist::sample` would never leave its rejection loop.
    fn new(seed: u32) -> Self {
        let reduced = seed % Self::M;
        Self {
            state: if reduced == 0 { 1 } else { reduced },
        }
    }

    /// Advances the engine and returns the new state, in `1..M`.
    fn next_u32(&mut self) -> u32 {
        let next = (u64::from(self.state) * u64::from(Self::A)) % u64::from(Self::M);
        self.state = next as u32;
        self.state
    }

    /// Maps a raw engine output to `[0, 1]` the way C++
    /// `std::generate_canonical<float>` does for this engine.
    ///
    /// It subtracts the engine minimum (`1`) and divides by the range size
    /// (`M - 1`), both in `f32`. The previous code omitted the `- 1`. Outputs
    /// near the top can still round to `1.0` in `f32`.
    fn canonical(raw: u32) -> f32 {
        // `raw >= 1` by the state invariant documented on the struct.
        (raw - 1) as f32 / (Self::M - 1) as f32
    }

    /// Next draw mapped to `[0, 1]` via [`Self::canonical`].
    fn unit_f32(&mut self) -> f32 {
        let raw = self.next_u32();
        Self::canonical(raw)
    }
}
301}
302
303#[derive(Debug, Clone, Copy)]
305struct StdNormalDist {
306 mean: f32,
307 scale: f32,
308 spare: f32,
309 has_spare: bool,
310}
311
312impl StdNormalDist {
313 fn new(mean: f32, scale: f32) -> Self {
314 Self {
315 mean,
316 scale,
317 spare: 0.0,
318 has_spare: false,
319 }
320 }
321
322 fn sample(&mut self, eng: &mut MinstdRand0) -> f32 {
323 if self.has_spare {
324 self.has_spare = false;
325 return self.spare;
326 }
327 loop {
328 let u1 = 2.0 * eng.unit_f32() - 1.0;
329 let u2 = 2.0 * eng.unit_f32() - 1.0;
330 let s = u1 * u1 + u2 * u2;
331 if s >= 1.0 || s == 0.0 {
332 continue;
333 }
334 let factor = (-2.0 * s.ln() / s).sqrt();
335 self.spare = u2 * factor * self.scale + self.mean;
336 self.has_spare = true;
337 return u1 * factor * self.scale + self.mean;
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `fill_normal_like` (no op seed) into a fresh zeroed buffer of `len`.
    fn normal_like(len: usize, mean: f32, scale: f32, opts: RngOptions, key: u64) -> Vec<f32> {
        let mut buf = vec![0f32; len];
        fill_normal_like(&mut buf, mean, scale, opts, key, None);
        buf
    }

    #[test]
    fn same_seed_same_sequence() {
        let (mut a, mut b) = (Philox4x32::new(0x1234_5678), Philox4x32::new(0x1234_5678));
        (0..256).for_each(|_| assert_eq!(a.next_u32(), b.next_u32()));
    }

    #[test]
    fn different_seed_different_sequence() {
        let mut a = Philox4x32::new(1);
        let mut b = Philox4x32::new(2);
        let diffs = (0..16).filter(|_| a.next_u32() != b.next_u32()).count();
        assert!(
            diffs >= 14,
            "two distinct seeds should disagree on >=14/16 samples"
        );
    }

    #[test]
    fn next_f32_in_unit_interval() {
        let mut r = Philox4x32::new(42);
        for v in std::iter::repeat_with(|| r.next_f32()).take(1000) {
            assert!((0.0..1.0).contains(&v), "{v} not in [0, 1)");
        }
    }

    #[test]
    fn fill_uniform_is_deterministic() {
        let fill = |seed: u64| {
            let mut buf = vec![0f32; 64];
            Philox4x32::new(seed).fill_uniform(&mut buf);
            buf
        };
        assert_eq!(fill(7), fill(7));
    }

    #[test]
    fn normal_mean_is_near_zero() {
        let mut r = Philox4x32::new(123);
        let n = 10_000;
        let sum = (0..n).fold(0f32, |acc, _| acc + r.normal());
        let mean = sum / n as f32;
        assert!(mean.abs() < 0.1, "mean {mean} too far from 0");
    }

    #[test]
    fn zero_backend_fills_zeros() {
        let mut out = vec![1.0; 8];
        fill_normal_like(&mut out, 0.0, 1.0, RngOptions::zero(), 0xABC, None);
        assert!(!out.iter().any(|&v| v != 0.0));
    }

    #[test]
    fn philox_backend_is_deterministic() {
        let opts = RngOptions::philox(99);
        let first = normal_like(32, 0.0, 0.5, opts, 123);
        let second = normal_like(32, 0.0, 0.5, opts, 123);
        assert_eq!(first, second);
    }

    #[test]
    fn ort_backend_is_deterministic() {
        let opts = RngOptions::ort(7);
        let first = normal_like(64, 0.1, 2.0, opts, 555);
        let second = normal_like(64, 0.1, 2.0, opts, 555);
        assert_eq!(first, second);
    }

    #[test]
    fn backends_disagree() {
        let philox = normal_like(16, 0.0, 1.0, RngOptions::philox(42), 1);
        let ort = normal_like(16, 0.0, 1.0, RngOptions::ort(42), 1);
        assert_ne!(philox, ort);
    }
}