retrofire_core/math/
rand.rs

1//! Pseudo-random number generation and distributions.
2
3use core::{array, fmt::Debug, ops::Range};
4
5use super::{Color, Point, Point2, Point3, Vec2, Vec3, Vector};
6
7//
8// Traits and types
9//
10
/// The pseudo-random number generator type accepted by [`Distrib::sample`]
/// and [`Distrib::samples`]. Currently an alias for [`Xorshift64`].
pub type DefaultRng = Xorshift64;
12
/// Trait for generating values sampled from a probability distribution.
///
/// Implementors must be `Clone` because [`samples`](Self::samples) stores
/// a copy of the distribution inside the iterator it returns.
pub trait Distrib: Clone {
    /// The type of the elements of the sample space of `Self`, also called
    /// "outcomes".
    type Sample;

    /// Returns a pseudo-random value sampled from `self`, drawing the
    /// required randomness from `rng`.
    ///
    /// # Examples
    /// ```
    /// use retrofire_core::math::rand::*;
    ///
    /// // Simulate rolling a six-sided die
    /// let rng = &mut DefaultRng::default();
    /// let d6 = Uniform(1..7).sample(rng);
    /// assert_eq!(d6, 3);
    /// ```
    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample;

    /// Returns an iterator that yields samples from `self` indefinitely.
    ///
    /// The iterator holds a clone of `self` and the mutable borrow of `rng`,
    /// so `rng` cannot be used elsewhere until the iterator is dropped.
    ///
    /// # Examples
    /// ```
    /// use retrofire_core::math::rand::*;
    ///
    /// // Simulate rolling a six-sided die three times
    /// let rng = &mut DefaultRng::default();
    /// let mut iter = Uniform(1u32..7).samples(rng);
    ///
    /// assert_eq!(iter.next(), Some(1));
    /// assert_eq!(iter.next(), Some(2));
    /// assert_eq!(iter.next(), Some(4));
    /// ```
    fn samples(
        &self,
        rng: &mut DefaultRng,
    ) -> impl Iterator<Item = Self::Sample> {
        Iter(self.clone(), rng)
    }
}
53
/// A pseudo-random number generator (PRNG) that uses a Xorshift algorithm[^1]
/// to generate 64 bits of randomness at a time, represented by a `u64`.
///
/// Xorshift is a type of linear-feedback shift register that uses only three
/// bit shifts (here: left, right, left) and three xor operations per
/// generated number, making it very efficient. Xorshift64 has a period of
/// 2<sup>64</sup>-1: it yields every number in the interval
/// [1, 2<sup>64</sup>) exactly once before repeating.
///
/// The wrapped `u64` is the generator state. It must be nonzero: shifting
/// and xoring zero only ever produces zero again.
///
/// [^1]: Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software,
///     8(14), 1–6. <https://doi.org/10.18637/jss.v008.i14>
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Xorshift64(pub u64);
67
/// A uniform distribution of values in a range.
///
/// The wrapped [`Range`] is half-open: `start` is a possible outcome,
/// `end` is not.
#[derive(Clone, Debug)]
pub struct Uniform<T>(pub Range<T>);
71
/// A uniform distribution of unit 2-vectors, that is, of points on the
/// circumference of the unit circle.
///
/// The `Distrib` impl of this type requires the `fp` feature.
// `Default` is derived for parity with the other unit-struct distributions.
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitCircle;
75
/// A uniform distribution of unit 3-vectors, that is, of points on the
/// surface of the unit sphere.
///
/// The `Distrib` impl of this type requires the `fp` feature.
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitSphere;
79
/// A uniform distribution of 2-vectors inside the (closed) unit disk.
///
/// Every sample `v` satisfies `v.len_sqr() <= 1.0`.
#[derive(Copy, Clone, Debug, Default)]
pub struct VectorsOnUnitDisk;
83
/// A uniform distribution of 3-vectors inside the (closed) unit ball.
///
/// Every sample `v` satisfies `v.len_sqr() <= 1.0`.
#[derive(Copy, Clone, Debug, Default)]
pub struct VectorsInUnitBall;
87
/// A uniform distribution of 2-points inside the (closed) unit disk.
///
/// The point counterpart of [`VectorsOnUnitDisk`].
#[derive(Copy, Clone, Debug, Default)]
pub struct PointsOnUnitDisk;
91
/// A uniform distribution of 3-points inside the (closed) unit ball.
///
/// The point counterpart of [`VectorsInUnitBall`].
#[derive(Copy, Clone, Debug, Default)]
pub struct PointsInUnitBall;
95
/// A Bernoulli distribution.
///
/// Generates boolean values such that:
/// * P(true) = p
/// * P(false) = 1 - p.
///
/// given a parameter p ∈ [0.0, 1.0].
///
/// The parameter is not validated. Because a sample is `true` exactly when
/// a uniform value in [0.0, 1.0) is less than p, any p ≤ 0.0 (or NaN) never
/// yields `true` and any p ≥ 1.0 always yields `true`.
#[derive(Copy, Clone, Debug)]
pub struct Bernoulli(pub f32);
105
/// Iterator returned by the [`Distrib::samples()`] method.
///
/// Field `0` is the distribution to sample from, field `1` the random
/// number generator (a `&mut DefaultRng` in practice) that feeds it.
#[derive(Copy, Clone, Debug)]
struct Iter<D, R>(D, R);
109
110//
111// Inherent impls
112//
113
114impl Xorshift64 {
115    /// A random 64-bit prime, used to initialize the generator returned by
116    /// [`Xorshift64::default()`].
117    pub const DEFAULT_SEED: u64 = 378682147834061;
118
119    /// Returns a new `Xorshift64` seeded by the given number.
120    ///
121    /// Two `Xorshift64` instances generate the same sequence of pseudo-random
122    /// numbers if and only if they were created with the same seed.
123    /// (Technically, every `Xorshift64` instance yields values from the same
124    /// sequence; the seed determines the starting point in the sequence).
125    ///
126    /// # Examples
127    /// ```
128    /// use retrofire_core::math::rand::Xorshift64;
129    ///
130    /// let mut g = Xorshift64::from_seed(123);
131    /// assert_eq!(g.next_bits(), 133101616827);
132    /// assert_eq!(g.next_bits(), 12690785413091508870);
133    /// assert_eq!(g.next_bits(), 7516749944291143043);
134    /// ```
135    ///
136    /// # Panics
137    ///
138    /// If `seed` equals 0.
139    pub fn from_seed(seed: u64) -> Self {
140        assert_ne!(seed, 0, "xorshift seed cannot be zero");
141        Self(seed)
142    }
143
144    /// Returns a new `Xorshift64` seeded by the current system time.
145    ///
146    /// Note that depending on the precision of the system clock, two or more
147    /// calls to this function in quick succession *may* return instances seeded
148    /// by the same number.
149    ///
150    /// #  Examples
151    /// ```
152    /// use std::thread;
153    /// use retrofire_core::math::rand::Xorshift64;
154    ///
155    /// let mut g = Xorshift64::from_time();
156    /// thread::sleep_ms(1); // Just to be sure
157    /// let mut h = Xorshift64::from_time();
158    /// assert_ne!(g.next_bits(), h.next_bits());
159    /// ```
160    #[cfg(feature = "std")]
161    pub fn from_time() -> Self {
162        let t = std::time::SystemTime::UNIX_EPOCH
163            .elapsed()
164            .unwrap();
165        Self(t.as_micros() as u64)
166    }
167
168    /// Returns 64 bits of pseudo-randomness.
169    ///
170    /// Successive calls to this function (with the same `self`) will yield
171    /// every value in the interval [1, 2<sup>64</sup>) exactly once before
172    /// starting to repeat the sequence.
173    pub fn next_bits(&mut self) -> u64 {
174        let Self(x) = self;
175        *x ^= *x << 13;
176        *x ^= *x >> 7;
177        *x ^= *x << 17;
178        *x
179    }
180}
181
182//
183// Foreign trait impls
184//
185
186/// An infinite iterator of pseudorandom values sampled from a distribution.
187///
188/// This type is returned by [`Distrib::samples`].
189impl<D: Distrib> Iterator for Iter<D, &'_ mut DefaultRng> {
190    type Item = D::Sample;
191
192    /// Returns the next pseudorandom sample from this iterator.
193    ///
194    /// This method never returns `None`.
195    fn next(&mut self) -> Option<Self::Item> {
196        Some(self.0.sample(self.1))
197    }
198}
199
200impl Default for Xorshift64 {
201    /// Returns a `Xorshift64` seeded with [`DEFAULT_SEED`](Self::DEFAULT_SEED).
202    ///
203    /// # Examples
204    /// ```
205    /// use retrofire_core::math::rand::Xorshift64;
206    ///
207    /// let mut g = Xorshift64::default();
208    /// assert_eq!(g.next_bits(), 11039719294064252060);
209    /// ```
210    fn default() -> Self {
211        // Random 64-bit prime
212        Self::from_seed(Self::DEFAULT_SEED)
213    }
214}
215
216//
217// Local trait impls
218//
219
220/// Uniformly distributed signed integers.
221impl Distrib for Uniform<i32> {
222    type Sample = i32;
223
224    /// Returns a uniformly distributed `i32` in the range.
225    ///
226    /// # Examples
227    /// ```
228    /// use retrofire_core::math::rand::*;
229    /// let rng = &mut DefaultRng::default();
230    ///
231    ///
232    /// let mut iter = Uniform(-5i32..6).samples(rng);
233    /// assert_eq!(iter.next(), Some(0));
234    /// assert_eq!(iter.next(), Some(4));
235    /// assert_eq!(iter.next(), Some(5));
236    /// ```
237    fn sample(&self, rng: &mut DefaultRng) -> i32 {
238        let bits = rng.next_bits() as i32;
239        // TODO rem introduces slight bias
240        bits.rem_euclid(self.0.end - self.0.start) + self.0.start
241    }
242}
243/// Uniformly distributed unsigned integers.
244impl Distrib for Uniform<u32> {
245    type Sample = u32;
246
247    /// Returns a uniformly distributed `u32` in the range.
248    ///
249    /// # Examples
250    /// ```
251    /// use retrofire_core::math::rand::*;
252    /// let rng = &mut DefaultRng::from_seed(1234);
253    ///
254    /// // Simulate rolling a six-sided die
255    /// let mut rolls: Vec<_>  = Uniform(1u32..7)
256    ///     .samples(rng)
257    ///     .take(6)
258    ///     .collect();
259    /// assert_eq!(rolls, [2, 4, 6, 6, 3, 1]);
260    /// ```
261    fn sample(&self, rng: &mut DefaultRng) -> u32 {
262        let bits = rng.next_bits() as u32;
263        // TODO rem introduces slight bias
264        bits.rem_euclid(self.0.end - self.0.start) + self.0.start
265    }
266}
267
268/// Uniformly distributed indices.
269impl Distrib for Uniform<usize> {
270    type Sample = usize;
271
272    /// Returns a uniformly distributed `usize` in the range.
273    ///
274    /// # Examples
275    /// ```
276    /// use retrofire_core::math::rand::*;
277    /// let rng = &mut DefaultRng::default();
278    ///
279    /// // Randomly sample elements from a list (with replacement)
280    /// let beverages = ["water", "tea", "coffee", "Coke", "Red Bull"];
281    /// let mut x: Vec<_> = Uniform(0..beverages.len())
282    ///     .samples(rng)
283    ///     .take(3)
284    ///     .map(|i| beverages[i])
285    ///     .collect();
286    ///
287    /// assert_eq!(x, ["water", "tea", "Red Bull"]);
288    /// ```
289    fn sample(&self, rng: &mut DefaultRng) -> usize {
290        let bits = rng.next_bits() as usize;
291        // TODO rem introduces slight bias
292        bits.rem_euclid(self.0.end - self.0.start) + self.0.start
293    }
294}
295
296/// Uniformly distributed floats.
297impl Distrib for Uniform<f32> {
298    type Sample = f32;
299
300    /// Returns a uniformly distributed `f32` in the range.
301    ///
302    /// # Examples
303    /// ```
304    /// use retrofire_core::math::rand::*;
305    /// let rng = &mut DefaultRng::default();
306    ///
307    /// // Floats in the interval [-1, 1)
308    /// let mut iter = Uniform(-1.0..1.0).samples(rng);
309    /// assert_eq!(iter.next(), Some(0.19692874));
310    /// assert_eq!(iter.next(), Some(-0.7686298));
311    /// assert_eq!(iter.next(), Some(0.91969657));
312    /// ```
313    fn sample(&self, rng: &mut DefaultRng) -> f32 {
314        let Range { start, end } = self.0;
315        // Bit repr of a random f32 in range 1.0..2.0
316        // Leaves a lot of precision unused near zero, but it's okay.
317        let (exp, mantissa) = (127 << 23, rng.next_bits() >> 41);
318        let unit = f32::from_bits(exp | mantissa as u32) - 1.0;
319        unit * (end - start) + start
320    }
321}
322
323impl<T, const N: usize> Distrib for Uniform<[T; N]>
324where
325    T: Copy,
326    Uniform<T>: Distrib<Sample = T>,
327{
328    type Sample = [T; N];
329
330    /// Returns the coordinates of a point sampled from a uniform distribution
331    /// within the N-dimensional rectangular volume bounded by `self.0`.
332    ///
333    /// # Examples
334    /// ```
335    /// use retrofire_core::math::rand::*;
336    /// let rng = &mut DefaultRng::default();
337    ///
338    /// // Pairs of integers [X, Y] such that 0 <= X < 4 and -2 <= Y <= 3
339    /// let mut int_pairs = Uniform([0, -2]..[4, 3]).samples(rng);
340    ///
341    /// assert_eq!(int_pairs.next(), Some([0, -1]));
342    /// assert_eq!(int_pairs.next(), Some([1, 0]));
343    /// assert_eq!(int_pairs.next(), Some([3, 1]));
344    /// ```
345    fn sample(&self, rng: &mut DefaultRng) -> [T; N] {
346        let Range { start, end } = self.0;
347        array::from_fn(|i| Uniform(start[i]..end[i]).sample(rng))
348    }
349}
350
351/// Uniformly distributed vectors within a rectangular volume.
352impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Vector<[Sc; DIM], Sp>>
353where
354    Sc: Copy,
355    Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
356{
357    type Sample = Vector<[Sc; DIM], Sp>;
358
359    /// Returns a vector uniformly sampled from the rectangular volume
360    /// bounded by `self.0`.
361    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
362        Uniform(self.0.start.0..self.0.end.0)
363            .sample(rng)
364            .into()
365    }
366}
367
368/// Uniformly distributed points within a rectangular volume.
369impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Point<[Sc; DIM], Sp>>
370where
371    Sc: Copy,
372    Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
373{
374    type Sample = Point<[Sc; DIM], Sp>;
375
376    /// Returns a point uniformly sampled from the rectangular volume
377    /// bounded by `self.0`.
378    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
379        Uniform(self.0.start.0..self.0.end.0)
380            .sample(rng)
381            .into()
382    }
383}
384impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Color<[Sc; DIM], Sp>>
385where
386    Sc: Copy,
387    Sp: Clone, // TODO Color needs manual Clone etc impls like Vector
388    Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
389{
390    type Sample = Point<[Sc; DIM], Sp>;
391
392    /// Returns a point uniformly sampled from the rectangular volume
393    /// bounded by `self.0`.
394    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
395        Uniform(self.0.start.0..self.0.end.0)
396            .sample(rng)
397            .into()
398    }
399}
400
#[cfg(feature = "fp")]
impl Distrib for UnitCircle {
    type Sample = Vec2;

    /// Returns a unit 2-vector uniformly sampled from the unit circle.
    ///
    /// The direction is taken from a vector sampled uniformly from the unit
    /// disk. Because the disk is rotationally symmetric, every direction is
    /// equally likely. (Normalizing a point sampled from the enclosing
    /// *square* would instead favor the diagonals.)
    ///
    /// # Example
    /// ```
    /// use retrofire_core::math::{ApproxEq, rand::*};
    /// let rng = &mut DefaultRng::default();
    ///
    /// let vec = UnitCircle.sample(rng);
    /// assert!(vec.len_sqr().approx_eq(&1.0));
    /// ```
    fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
        // Vectors shorter than 1e-3 are rejected: normalizing them is
        // numerically unreliable, and dropping a centered disk does not
        // disturb the uniformity of directions.
        const MIN_LEN_SQR: f32 = 1e-6;
        loop {
            let v = VectorsOnUnitDisk.sample(rng);
            if v.len_sqr() >= MIN_LEN_SQR {
                return v.normalize();
            }
        }
    }
}
421
422impl Distrib for VectorsOnUnitDisk {
423    type Sample = Vec2;
424
425    /// Returns a 2-vector uniformly sampled from the unit disk.
426    ///
427    /// # Example
428    /// ```
429    /// use retrofire_core::math::rand::*;
430    /// let rng = &mut DefaultRng::default();
431    ///
432    /// let vec = VectorsOnUnitDisk.sample(rng);
433    /// assert!(vec.len_sqr() <= 1.0);
434    /// ```
435    fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
436        let d = Uniform([-1.0f32; 2]..[1.0; 2]);
437        loop {
438            // Rejection sampling
439            let v = Vec2::from(d.sample(rng));
440            if v.len_sqr() <= 1.0 {
441                return v;
442            }
443        }
444    }
445}
446
#[cfg(feature = "fp")]
impl Distrib for UnitSphere {
    type Sample = Vec3;

    /// Returns a unit 3-vector uniformly sampled from the unit sphere.
    ///
    /// The direction is taken from a vector sampled uniformly from the unit
    /// ball. Because the ball is rotationally symmetric, every direction is
    /// equally likely. (Normalizing a point sampled from the enclosing
    /// *cube* would instead favor the directions of the cube's corners.)
    ///
    /// # Example
    /// ```
    /// use retrofire_core::assert_approx_eq;
    /// use retrofire_core::math::rand::*;
    /// let rng = &mut DefaultRng::default();
    ///
    /// let vec = UnitSphere.sample(rng);
    /// assert_approx_eq!(vec.len_sqr(), 1.0);
    /// ```
    fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
        // Vectors shorter than 1e-3 are rejected: normalizing them is
        // numerically unreliable, and dropping a centered ball does not
        // disturb the uniformity of directions.
        const MIN_LEN_SQR: f32 = 1e-6;
        loop {
            let v = VectorsInUnitBall.sample(rng);
            if v.len_sqr() >= MIN_LEN_SQR {
                return v.normalize();
            }
        }
    }
}
467
468impl Distrib for VectorsInUnitBall {
469    type Sample = Vec3;
470
471    /// Returns a 3-vector uniformly sampled from the unit ball.
472    ///
473    /// # Example
474    /// ```
475    /// use retrofire_core::math::rand::*;
476    /// let rng = &mut DefaultRng::default();
477    ///
478    /// let vec = VectorsInUnitBall.sample(rng);
479    /// assert!(vec.len_sqr() <= 1.0);
480    /// ```
481    fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
482        let d = Uniform([-1.0; 3]..[1.0; 3]);
483        loop {
484            // Rejection sampling
485            let v = Vec3::from(d.sample(rng));
486            if v.len_sqr() <= 1.0 {
487                return v;
488            }
489        }
490    }
491}
492
493impl Distrib for PointsOnUnitDisk {
494    type Sample = Point2;
495
496    /// Returns a 2-point uniformly sampled from the unit disk.
497    ///
498    /// See [`VectorsOnUnitDisk::sample`].
499    fn sample(&self, rng: &mut DefaultRng) -> Point2 {
500        VectorsOnUnitDisk.sample(rng).to_pt()
501    }
502}
503
504impl Distrib for PointsInUnitBall {
505    type Sample = Point3;
506
507    /// Returns a 3-point uniformly sampled from the unit ball.
508    ///
509    /// See [`VectorsInUnitBall::sample`].
510    fn sample(&self, rng: &mut DefaultRng) -> Point3 {
511        VectorsInUnitBall.sample(rng).to_pt()
512    }
513}
514
515impl Distrib for Bernoulli {
516    type Sample = bool;
517
518    /// Returns booleans sampled from a Bernoulli distribution.
519    ///
520    /// The result is `true` with probability `self.0` and false
521    /// with probability 1 - `self.0`.
522    ///
523    /// # Example
524    /// ```
525    /// use core::array;
526    /// use retrofire_core::math::rand::*;
527    /// let rng = &mut DefaultRng::default();
528    ///
529    /// let bern = Bernoulli(0.6); // P(true) = 0.6
530    /// let bools = array::from_fn(|_| bern.sample(rng));
531    /// assert_eq!(bools, [true, true, false, true, false, true]);
532    /// ```
533    fn sample(&self, rng: &mut DefaultRng) -> bool {
534        Uniform(0.0f32..1.0).sample(rng) < self.0
535    }
536}
537
538impl<D: Distrib, E: Distrib> Distrib for (D, E) {
539    type Sample = (D::Sample, E::Sample);
540
541    /// Returns a pair of samples, sampled from two separate distributions.
542    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
543        (self.0.sample(rng), self.1.sample(rng))
544    }
545}
546
/// Unit tests for the distributions in this module.
///
/// All tests draw from a default-seeded generator, so the exact expected
/// values asserted below are deterministic.
#[cfg(test)]
// Range checks are spelled out as `lo <= x && x < hi` for readability.
#[allow(clippy::manual_range_contains)]
mod tests {
    use crate::math::vec3;

    use super::*;

    /// Number of samples drawn per test.
    const COUNT: usize = 1000;

    /// Returns a fresh, default-seeded generator.
    fn rng() -> DefaultRng {
        Default::default()
    }

    #[test]
    fn uniform_i32() {
        let dist = Uniform(-123i32..456);
        for r in dist.samples(&mut rng()).take(COUNT) {
            assert!(-123 <= r && r < 456);
        }
    }

    #[test]
    fn uniform_f32() {
        let dist = Uniform(-1.23..4.56);
        for r in dist.samples(&mut rng()).take(COUNT) {
            assert!(-1.23 <= r && r < 4.56);
        }
    }

    #[test]
    fn uniform_i32_array() {
        let dist = Uniform([0, -10]..[10, 15]);

        let sum = dist
            .samples(&mut rng())
            .take(COUNT)
            .inspect(|&[x, y]| {
                assert!(0 <= x && x < 10);
                // Both bounds must be checked against `y`.
                assert!(-10 <= y && y < 15);
            })
            .fold([0, 0], |[ax, ay], [x, y]| [ax + x, ay + y]);

        assert_eq!(sum, [4531, 1652]);
    }

    #[test]
    fn uniform_vec3() {
        let dist =
            Uniform(vec3::<f32, ()>(-2.0, 0.0, -1.0)..vec3(1.0, 2.0, 3.0));

        let mean = dist
            .samples(&mut rng())
            .take(COUNT)
            .inspect(|v| {
                assert!(-2.0 <= v.x() && v.x() < 1.0);
                assert!(0.0 <= v.y() && v.y() < 2.0);
                assert!(-1.0 <= v.z() && v.z() < 3.0);
            })
            .sum::<Vec3>()
            / COUNT as f32;

        assert_eq!(mean, vec3(-0.46046025, 1.0209353, 0.9742225));
    }

    #[test]
    fn bernoulli() {
        let rng = &mut rng();
        let bools = Bernoulli(0.1).samples(rng).take(COUNT);
        // The expectation is 100; the exact count depends on the seed.
        let approx_100 = bools.filter(|&b| b).count();
        assert_eq!(approx_100, 82);
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_circle() {
        use crate::assert_approx_eq;
        for v in UnitCircle.samples(&mut rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn vectors_on_unit_disk() {
        for v in VectorsOnUnitDisk.samples(&mut rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_sphere() {
        use crate::assert_approx_eq;
        for v in UnitSphere.samples(&mut rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn vectors_in_unit_ball() {
        for v in VectorsInUnitBall.samples(&mut rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[test]
    fn zipped_pair() {
        let rng = &mut rng();
        let dist = (Bernoulli(0.8), Uniform(0..4));
        assert_eq!(dist.sample(rng), (true, 1));
        assert_eq!(dist.sample(rng), (false, 3));
        assert_eq!(dist.sample(rng), (true, 2));
    }
}
659}