retrofire_core/math/rand.rs
1//! Pseudo-random number generation and distributions.
2
3use core::{array, fmt::Debug, ops::Range};
4
5use super::{Color, Point, Point2, Point3, Vec2, Vec3, Vector};
6
7//
8// Traits and types
9//
10
11pub type DefaultRng = Xorshift64;
12
13/// Trait for generating values sampled from a probability distribution.
14pub trait Distrib: Clone {
15 /// The type of the elements of the sample space of `Self`, also called
16 /// "outcomes".
17 type Sample;
18
19 /// Returns a pseudo-random value sampled from `self`.
20 ///
21 /// # Examples
22 /// ```
23 /// use retrofire_core::math::rand::*;
24 ///
25 /// // Simulate rolling a six-sided die
26 /// let rng = &mut DefaultRng::default();
27 /// let d6 = Uniform(1..7).sample(rng);
28 /// assert_eq!(d6, 3);
29 /// ```
30 fn sample(&self, rng: &mut DefaultRng) -> Self::Sample;
31
32 /// Returns an iterator that yields samples from `self` indefinitely.
33 ///
34 /// # Examples
35 /// ```
36 /// use retrofire_core::math::rand::*;
37 ///
38 /// // Simulate rolling a six-sided die three times
39 /// let rng = &mut DefaultRng::default();
40 /// let mut iter = Uniform(1u32..7).samples(rng);
41 ///
42 /// assert_eq!(iter.next(), Some(1));
43 /// assert_eq!(iter.next(), Some(2));
44 /// assert_eq!(iter.next(), Some(4));
45 /// ```
46 fn samples(
47 &self,
48 rng: &mut DefaultRng,
49 ) -> impl Iterator<Item = Self::Sample> {
50 Iter(self.clone(), rng)
51 }
52}
53
/// A pseudo-random number generator (PRNG) that uses a Xorshift algorithm[^1]
/// to generate 64 bits of randomness at a time, represented by a `u64`.
///
/// Xorshift is a type of linear-feedback shift register that uses only three
/// bit shifts and three xor operations per generated number, making it very
/// efficient. Xorshift64 has a period of 2<sup>64</sup>-1: it yields every
/// number in the interval [1, 2<sup>64</sup>) exactly once before repeating.
///
/// The state must be nonzero: zero is a fixed point of the algorithm, so a
/// generator in that state only ever yields zeros. Constructing generators
/// with [`Xorshift64::from_seed`], which rejects zero, rather than through
/// the public tuple field guards against this.
///
/// [^1]: Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software,
/// 8(14), 1–6. <https://doi.org/10.18637/jss.v008.i14>
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Xorshift64(pub u64);

/// A uniform distribution of values in a range.
///
/// Samples are drawn from `start..end`; see the individual [`Distrib`]
/// impls for the exact behavior for each value type.
#[derive(Clone, Debug)]
pub struct Uniform<T>(pub Range<T>);

/// A uniform distribution of unit 2-vectors, that is, of directions in
/// the plane (points on the unit circle).
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitCircle;

/// A uniform distribution of unit 3-vectors, that is, of directions in
/// space (points on the unit sphere).
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitSphere;

/// A uniform distribution of 2-vectors inside the (closed) unit disk.
///
/// Despite the name, samples may lie anywhere in the disk, not only on its
/// boundary; for unit vectors use [`UnitCircle`].
#[derive(Copy, Clone, Debug, Default)]
pub struct VectorsOnUnitDisk;

/// A uniform distribution of 3-vectors inside the (closed) unit ball.
#[derive(Copy, Clone, Debug, Default)]
pub struct VectorsInUnitBall;

/// A uniform distribution of 2-points inside the (closed) unit disk.
///
/// Despite the name, samples may lie anywhere in the disk, not only on its
/// boundary circle.
#[derive(Copy, Clone, Debug, Default)]
pub struct PointsOnUnitDisk;

/// A uniform distribution of 3-points inside the (closed) unit ball.
#[derive(Copy, Clone, Debug, Default)]
pub struct PointsInUnitBall;

/// A Bernoulli distribution.
///
/// Generates boolean values such that
/// * P(true) = p, and
/// * P(false) = 1 - p,
///
/// given a parameter p ∈ [0.0, 1.0].
#[derive(Copy, Clone, Debug)]
pub struct Bernoulli(pub f32);

/// Iterator returned by the [`Distrib::samples()`] method.
///
/// In practice `D` is a clone of the sampled distribution and `R` a mutable
/// reference to the generator.
#[derive(Copy, Clone, Debug)]
struct Iter<D, R>(D, R);
109
110//
111// Inherent impls
112//
113
114impl Xorshift64 {
115 /// A random 64-bit prime, used to initialize the generator returned by
116 /// [`Xorshift64::default()`].
117 pub const DEFAULT_SEED: u64 = 378682147834061;
118
119 /// Returns a new `Xorshift64` seeded by the given number.
120 ///
121 /// Two `Xorshift64` instances generate the same sequence of pseudo-random
122 /// numbers if and only if they were created with the same seed.
123 /// (Technically, every `Xorshift64` instance yields values from the same
124 /// sequence; the seed determines the starting point in the sequence).
125 ///
126 /// # Examples
127 /// ```
128 /// use retrofire_core::math::rand::Xorshift64;
129 ///
130 /// let mut g = Xorshift64::from_seed(123);
131 /// assert_eq!(g.next_bits(), 133101616827);
132 /// assert_eq!(g.next_bits(), 12690785413091508870);
133 /// assert_eq!(g.next_bits(), 7516749944291143043);
134 /// ```
135 ///
136 /// # Panics
137 ///
138 /// If `seed` equals 0.
139 pub fn from_seed(seed: u64) -> Self {
140 assert_ne!(seed, 0, "xorshift seed cannot be zero");
141 Self(seed)
142 }
143
144 /// Returns a new `Xorshift64` seeded by the current system time.
145 ///
146 /// Note that depending on the precision of the system clock, two or more
147 /// calls to this function in quick succession *may* return instances seeded
148 /// by the same number.
149 ///
150 /// # Examples
151 /// ```
152 /// use std::thread;
153 /// use retrofire_core::math::rand::Xorshift64;
154 ///
155 /// let mut g = Xorshift64::from_time();
156 /// thread::sleep_ms(1); // Just to be sure
157 /// let mut h = Xorshift64::from_time();
158 /// assert_ne!(g.next_bits(), h.next_bits());
159 /// ```
160 #[cfg(feature = "std")]
161 pub fn from_time() -> Self {
162 let t = std::time::SystemTime::UNIX_EPOCH
163 .elapsed()
164 .unwrap();
165 Self(t.as_micros() as u64)
166 }
167
168 /// Returns 64 bits of pseudo-randomness.
169 ///
170 /// Successive calls to this function (with the same `self`) will yield
171 /// every value in the interval [1, 2<sup>64</sup>) exactly once before
172 /// starting to repeat the sequence.
173 pub fn next_bits(&mut self) -> u64 {
174 let Self(x) = self;
175 *x ^= *x << 13;
176 *x ^= *x >> 7;
177 *x ^= *x << 17;
178 *x
179 }
180}
181
182//
183// Foreign trait impls
184//
185
186/// An infinite iterator of pseudorandom values sampled from a distribution.
187///
188/// This type is returned by [`Distrib::samples`].
189impl<D: Distrib> Iterator for Iter<D, &'_ mut DefaultRng> {
190 type Item = D::Sample;
191
192 /// Returns the next pseudorandom sample from this iterator.
193 ///
194 /// This method never returns `None`.
195 fn next(&mut self) -> Option<Self::Item> {
196 Some(self.0.sample(self.1))
197 }
198}
199
200impl Default for Xorshift64 {
201 /// Returns a `Xorshift64` seeded with [`DEFAULT_SEED`](Self::DEFAULT_SEED).
202 ///
203 /// # Examples
204 /// ```
205 /// use retrofire_core::math::rand::Xorshift64;
206 ///
207 /// let mut g = Xorshift64::default();
208 /// assert_eq!(g.next_bits(), 11039719294064252060);
209 /// ```
210 fn default() -> Self {
211 // Random 64-bit prime
212 Self::from_seed(Self::DEFAULT_SEED)
213 }
214}
215
216//
217// Local trait impls
218//
219
220/// Uniformly distributed signed integers.
221impl Distrib for Uniform<i32> {
222 type Sample = i32;
223
224 /// Returns a uniformly distributed `i32` in the range.
225 ///
226 /// # Examples
227 /// ```
228 /// use retrofire_core::math::rand::*;
229 /// let rng = &mut DefaultRng::default();
230 ///
231 ///
232 /// let mut iter = Uniform(-5i32..6).samples(rng);
233 /// assert_eq!(iter.next(), Some(0));
234 /// assert_eq!(iter.next(), Some(4));
235 /// assert_eq!(iter.next(), Some(5));
236 /// ```
237 fn sample(&self, rng: &mut DefaultRng) -> i32 {
238 let bits = rng.next_bits() as i32;
239 // TODO rem introduces slight bias
240 bits.rem_euclid(self.0.end - self.0.start) + self.0.start
241 }
242}
243/// Uniformly distributed unsigned integers.
244impl Distrib for Uniform<u32> {
245 type Sample = u32;
246
247 /// Returns a uniformly distributed `u32` in the range.
248 ///
249 /// # Examples
250 /// ```
251 /// use retrofire_core::math::rand::*;
252 /// let rng = &mut DefaultRng::from_seed(1234);
253 ///
254 /// // Simulate rolling a six-sided die
255 /// let mut rolls: Vec<_> = Uniform(1u32..7)
256 /// .samples(rng)
257 /// .take(6)
258 /// .collect();
259 /// assert_eq!(rolls, [2, 4, 6, 6, 3, 1]);
260 /// ```
261 fn sample(&self, rng: &mut DefaultRng) -> u32 {
262 let bits = rng.next_bits() as u32;
263 // TODO rem introduces slight bias
264 bits.rem_euclid(self.0.end - self.0.start) + self.0.start
265 }
266}
267
268/// Uniformly distributed indices.
269impl Distrib for Uniform<usize> {
270 type Sample = usize;
271
272 /// Returns a uniformly distributed `usize` in the range.
273 ///
274 /// # Examples
275 /// ```
276 /// use retrofire_core::math::rand::*;
277 /// let rng = &mut DefaultRng::default();
278 ///
279 /// // Randomly sample elements from a list (with replacement)
280 /// let beverages = ["water", "tea", "coffee", "Coke", "Red Bull"];
281 /// let mut x: Vec<_> = Uniform(0..beverages.len())
282 /// .samples(rng)
283 /// .take(3)
284 /// .map(|i| beverages[i])
285 /// .collect();
286 ///
287 /// assert_eq!(x, ["water", "tea", "Red Bull"]);
288 /// ```
289 fn sample(&self, rng: &mut DefaultRng) -> usize {
290 let bits = rng.next_bits() as usize;
291 // TODO rem introduces slight bias
292 bits.rem_euclid(self.0.end - self.0.start) + self.0.start
293 }
294}
295
296/// Uniformly distributed floats.
297impl Distrib for Uniform<f32> {
298 type Sample = f32;
299
300 /// Returns a uniformly distributed `f32` in the range.
301 ///
302 /// # Examples
303 /// ```
304 /// use retrofire_core::math::rand::*;
305 /// let rng = &mut DefaultRng::default();
306 ///
307 /// // Floats in the interval [-1, 1)
308 /// let mut iter = Uniform(-1.0..1.0).samples(rng);
309 /// assert_eq!(iter.next(), Some(0.19692874));
310 /// assert_eq!(iter.next(), Some(-0.7686298));
311 /// assert_eq!(iter.next(), Some(0.91969657));
312 /// ```
313 fn sample(&self, rng: &mut DefaultRng) -> f32 {
314 let Range { start, end } = self.0;
315 // Bit repr of a random f32 in range 1.0..2.0
316 // Leaves a lot of precision unused near zero, but it's okay.
317 let (exp, mantissa) = (127 << 23, rng.next_bits() >> 41);
318 let unit = f32::from_bits(exp | mantissa as u32) - 1.0;
319 unit * (end - start) + start
320 }
321}
322
323impl<T, const N: usize> Distrib for Uniform<[T; N]>
324where
325 T: Copy,
326 Uniform<T>: Distrib<Sample = T>,
327{
328 type Sample = [T; N];
329
330 /// Returns the coordinates of a point sampled from a uniform distribution
331 /// within the N-dimensional rectangular volume bounded by `self.0`.
332 ///
333 /// # Examples
334 /// ```
335 /// use retrofire_core::math::rand::*;
336 /// let rng = &mut DefaultRng::default();
337 ///
338 /// // Pairs of integers [X, Y] such that 0 <= X < 4 and -2 <= Y <= 3
339 /// let mut int_pairs = Uniform([0, -2]..[4, 3]).samples(rng);
340 ///
341 /// assert_eq!(int_pairs.next(), Some([0, -1]));
342 /// assert_eq!(int_pairs.next(), Some([1, 0]));
343 /// assert_eq!(int_pairs.next(), Some([3, 1]));
344 /// ```
345 fn sample(&self, rng: &mut DefaultRng) -> [T; N] {
346 let Range { start, end } = self.0;
347 array::from_fn(|i| Uniform(start[i]..end[i]).sample(rng))
348 }
349}
350
351/// Uniformly distributed vectors within a rectangular volume.
352impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Vector<[Sc; DIM], Sp>>
353where
354 Sc: Copy,
355 Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
356{
357 type Sample = Vector<[Sc; DIM], Sp>;
358
359 /// Returns a vector uniformly sampled from the rectangular volume
360 /// bounded by `self.0`.
361 fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
362 Uniform(self.0.start.0..self.0.end.0)
363 .sample(rng)
364 .into()
365 }
366}
367
368/// Uniformly distributed points within a rectangular volume.
369impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Point<[Sc; DIM], Sp>>
370where
371 Sc: Copy,
372 Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
373{
374 type Sample = Point<[Sc; DIM], Sp>;
375
376 /// Returns a point uniformly sampled from the rectangular volume
377 /// bounded by `self.0`.
378 fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
379 Uniform(self.0.start.0..self.0.end.0)
380 .sample(rng)
381 .into()
382 }
383}
384impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Color<[Sc; DIM], Sp>>
385where
386 Sc: Copy,
387 Sp: Clone, // TODO Color needs manual Clone etc impls like Vector
388 Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
389{
390 type Sample = Point<[Sc; DIM], Sp>;
391
392 /// Returns a point uniformly sampled from the rectangular volume
393 /// bounded by `self.0`.
394 fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
395 Uniform(self.0.start.0..self.0.end.0)
396 .sample(rng)
397 .into()
398 }
399}
400
#[cfg(feature = "fp")]
impl Distrib for UnitCircle {
    type Sample = Vec2;

    /// Returns a unit 2-vector uniformly sampled from the unit circle.
    ///
    /// # Example
    /// ```
    /// use retrofire_core::math::{ApproxEq, rand::*};
    /// let rng = &mut DefaultRng::default();
    ///
    /// let vec = UnitCircle.sample(rng);
    /// assert!(vec.len_sqr().approx_eq(&1.0));
    /// ```
    fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
        // Normalizing a point drawn uniformly from the *square* [-1, 1)²
        // does NOT give uniformly distributed directions: the square's
        // corners make diagonal directions more likely. Points inside the
        // unit disk are rotationally symmetric, so reject all others first.
        // Very short vectors are also rejected so that normalization is
        // numerically stable and never divides by zero.
        let d = Uniform([-1.0f32; 2]..[1.0; 2]);
        loop {
            let v = Vec2::from(d.sample(rng));
            let len_sqr = v.len_sqr();
            if 1e-6 < len_sqr && len_sqr <= 1.0 {
                return v.normalize();
            }
        }
    }
}
421
422impl Distrib for VectorsOnUnitDisk {
423 type Sample = Vec2;
424
425 /// Returns a 2-vector uniformly sampled from the unit disk.
426 ///
427 /// # Example
428 /// ```
429 /// use retrofire_core::math::rand::*;
430 /// let rng = &mut DefaultRng::default();
431 ///
432 /// let vec = VectorsOnUnitDisk.sample(rng);
433 /// assert!(vec.len_sqr() <= 1.0);
434 /// ```
435 fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
436 let d = Uniform([-1.0f32; 2]..[1.0; 2]);
437 loop {
438 // Rejection sampling
439 let v = Vec2::from(d.sample(rng));
440 if v.len_sqr() <= 1.0 {
441 return v;
442 }
443 }
444 }
445}
446
#[cfg(feature = "fp")]
impl Distrib for UnitSphere {
    type Sample = Vec3;

    /// Returns a unit 3-vector uniformly sampled from the unit sphere.
    ///
    /// # Example
    /// ```
    /// use retrofire_core::assert_approx_eq;
    /// use retrofire_core::math::rand::*;
    /// let rng = &mut DefaultRng::default();
    ///
    /// let vec = UnitSphere.sample(rng);
    /// assert_approx_eq!(vec.len_sqr(), 1.0);
    /// ```
    fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
        // Normalizing a point drawn uniformly from the *cube* [-1, 1)³ does
        // NOT give uniformly distributed directions: the cube's corners make
        // diagonal directions more likely. Points inside the unit ball are
        // rotationally symmetric, so reject all others first. Very short
        // vectors are also rejected so that normalization is numerically
        // stable and never divides by zero.
        let d = Uniform([-1.0f32; 3]..[1.0; 3]);
        loop {
            let v = Vec3::from(d.sample(rng));
            let len_sqr = v.len_sqr();
            if 1e-6 < len_sqr && len_sqr <= 1.0 {
                return v.normalize();
            }
        }
    }
}
467
468impl Distrib for VectorsInUnitBall {
469 type Sample = Vec3;
470
471 /// Returns a 3-vector uniformly sampled from the unit ball.
472 ///
473 /// # Example
474 /// ```
475 /// use retrofire_core::math::rand::*;
476 /// let rng = &mut DefaultRng::default();
477 ///
478 /// let vec = VectorsInUnitBall.sample(rng);
479 /// assert!(vec.len_sqr() <= 1.0);
480 /// ```
481 fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
482 let d = Uniform([-1.0; 3]..[1.0; 3]);
483 loop {
484 // Rejection sampling
485 let v = Vec3::from(d.sample(rng));
486 if v.len_sqr() <= 1.0 {
487 return v;
488 }
489 }
490 }
491}
492
493impl Distrib for PointsOnUnitDisk {
494 type Sample = Point2;
495
496 /// Returns a 2-point uniformly sampled from the unit disk.
497 ///
498 /// See [`VectorsOnUnitDisk::sample`].
499 fn sample(&self, rng: &mut DefaultRng) -> Point2 {
500 VectorsOnUnitDisk.sample(rng).to_pt()
501 }
502}
503
504impl Distrib for PointsInUnitBall {
505 type Sample = Point3;
506
507 /// Returns a 3-point uniformly sampled from the unit ball.
508 ///
509 /// See [`VectorsInUnitBall::sample`].
510 fn sample(&self, rng: &mut DefaultRng) -> Point3 {
511 VectorsInUnitBall.sample(rng).to_pt()
512 }
513}
514
515impl Distrib for Bernoulli {
516 type Sample = bool;
517
518 /// Returns booleans sampled from a Bernoulli distribution.
519 ///
520 /// The result is `true` with probability `self.0` and false
521 /// with probability 1 - `self.0`.
522 ///
523 /// # Example
524 /// ```
525 /// use core::array;
526 /// use retrofire_core::math::rand::*;
527 /// let rng = &mut DefaultRng::default();
528 ///
529 /// let bern = Bernoulli(0.6); // P(true) = 0.6
530 /// let bools = array::from_fn(|_| bern.sample(rng));
531 /// assert_eq!(bools, [true, true, false, true, false, true]);
532 /// ```
533 fn sample(&self, rng: &mut DefaultRng) -> bool {
534 Uniform(0.0f32..1.0).sample(rng) < self.0
535 }
536}
537
538impl<D: Distrib, E: Distrib> Distrib for (D, E) {
539 type Sample = (D::Sample, E::Sample);
540
541 /// Returns a pair of samples, sampled from two separate distributions.
542 fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
543 (self.0.sample(rng), self.1.sample(rng))
544 }
545}
546
#[cfg(test)]
#[allow(clippy::manual_range_contains)]
mod tests {
    use crate::math::vec3;

    use super::*;

    const COUNT: usize = 1000;

    fn rng() -> DefaultRng {
        Default::default()
    }

    #[test]
    fn uniform_i32() {
        let dist = Uniform(-123i32..456);
        for r in dist.samples(&mut rng()).take(COUNT) {
            assert!(-123 <= r && r < 456);
        }
    }

    #[test]
    fn uniform_f32() {
        let dist = Uniform(-1.23..4.56);
        for r in dist.samples(&mut rng()).take(COUNT) {
            assert!(-1.23 <= r && r < 4.56);
        }
    }

    #[test]
    fn uniform_i32_array() {
        let dist = Uniform([0, -10]..[10, 15]);

        let sum = dist
            .samples(&mut rng())
            .take(COUNT)
            .inspect(|&[x, y]| {
                // Each coordinate must respect its own bounds.
                assert!(0 <= x && x < 10);
                assert!(-10 <= y && y < 15);
            })
            .fold([0, 0], |[ax, ay], [x, y]| [ax + x, ay + y]);

        assert_eq!(sum, [4531, 1652]);
    }

    #[test]
    fn uniform_vec3() {
        let dist =
            Uniform(vec3::<f32, ()>(-2.0, 0.0, -1.0)..vec3(1.0, 2.0, 3.0));

        let mean = dist
            .samples(&mut rng())
            .take(COUNT)
            .inspect(|v| {
                assert!(-2.0 <= v.x() && v.x() < 1.0);
                assert!(0.0 <= v.y() && v.y() < 2.0);
                assert!(-1.0 <= v.z() && v.z() < 3.0);
            })
            .sum::<Vec3>()
            / COUNT as f32;

        assert_eq!(mean, vec3(-0.46046025, 1.0209353, 0.9742225));
    }

    #[test]
    fn bernoulli() {
        let rng = &mut rng();
        let bools = Bernoulli(0.1).samples(rng).take(COUNT);
        let approx_100 = bools.filter(|&b| b).count();
        assert_eq!(approx_100, 82);
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_circle() {
        use crate::assert_approx_eq;
        for v in UnitCircle.samples(&mut rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn vectors_on_unit_disk() {
        for v in VectorsOnUnitDisk.samples(&mut rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_sphere() {
        use crate::assert_approx_eq;
        for v in UnitSphere.samples(&mut rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn vectors_in_unit_ball() {
        for v in VectorsInUnitBall.samples(&mut rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[test]
    fn zipped_pair() {
        let rng = &mut rng();
        let dist = (Bernoulli(0.8), Uniform(0..4));
        assert_eq!(dist.sample(rng), (true, 1));
        assert_eq!(dist.sample(rng), (false, 3));
        assert_eq!(dist.sample(rng), (true, 2));
    }
}