retrofire_core/math/
rand.rs

1//! Pseudo-random number generation and distributions.
2
3use core::{array, fmt::Debug, ops::Range};
4
5use super::vec::{Vec2, Vec3, Vector};
6
7//
8// Traits and types
9//
10
11pub type DefaultRng = Xorshift64;
12
13/// Trait for generating values sampled from a probability distribution.
14pub trait Distrib<R = DefaultRng>: Clone {
15    /// The type of the elements of the sample space of `Self`, also called
16    /// "outcomes".
17    type Sample;
18
19    /// Returns a pseudo-random value sampled from `self`.
20    ///
21    /// # Examples
22    /// ```
23    /// use retrofire_core::math::rand::*;
24    /// // Simulate rolling a six-sided die
25    /// let mut rng = DefaultRng::default();
26    /// let d6 = Uniform(1..7).sample(&mut rng);
27    /// assert_eq!(d6, 3);
28    /// ```
29    fn sample(&self, rng: &mut R) -> Self::Sample;
30
31    /// Returns an iterator that yields samples from `self`.
32    ///
33    /// # Examples
34    /// ```
35    /// use retrofire_core::math::rand::*;
36    /// // Simulate rolling a six-sided die
37    /// let rng = DefaultRng::default();
38    /// let mut iter = Uniform(1..7).samples(rng);
39    /// assert_eq!(iter.next(), Some(3));
40    /// assert_eq!(iter.next(), Some(2));
41    /// assert_eq!(iter.next(), Some(4));
42    /// ```
43    fn samples(&self, rng: R) -> Iter<Self, R> {
44        Iter(self.clone(), rng)
45    }
46}
47
/// A pseudo-random number generator (PRNG) that uses a [Xorshift algorithm][^1]
/// to generate 64 bits of randomness at a time, represented by a `u64`.
///
/// Xorshift is a type of linear-feedback shift register that uses only three
/// bit shifts (two left, one right) and three xor operations per generated
/// number, making it very efficient. Xorshift64 has a period of
/// 2<sup>64</sup>-1: it yields every number in the interval
/// [1, 2<sup>64</sup>) exactly once before repeating.
///
/// The state must never be zero. Zero is a fixed point of the xorshift step,
/// so a zero-state generator yields only zeros. [`Xorshift64::from_seed`]
/// rejects a zero seed, but constructing the tuple struct directly, as in
/// `Xorshift64(0)`, bypasses that check.
///
/// [^1]: Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software,
///     8(14), 1–6. <https://doi.org/10.18637/jss.v008.i14>
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Xorshift64(pub u64);
61
/// A distribution that gives every value in the half-open range `self.0`
/// (written `start..end`) the same probability.
#[derive(Debug, Clone)]
pub struct Uniform<T>(pub Range<T>);
65
/// A uniform distribution of 2-vectors on the unit circle (its perimeter,
/// not its interior).
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitCircle;

/// A uniform distribution of 2-vectors inside the closed unit disk.
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitDisk;

/// A uniform distribution of 3-vectors on the surface of the unit sphere.
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitSphere;

/// A uniform distribution of 3-vectors inside the closed unit ball.
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitBall;
81
/// A Bernoulli distribution whose success probability `p` is stored in
/// `self.0`.
///
/// Generates boolean values such that:
/// * P(true) = p
/// * P(false) = 1 - p
///
/// where the parameter p is expected to lie in [0.0, 1.0].
#[derive(Copy, Clone, Debug)]
pub struct Bernoulli(pub f32);
91
/// Iterator returned by the [`Distrib::samples()`] method.
///
/// Holds the distribution being sampled (`self.0`) and the random number
/// generator it samples with (`self.1`).
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Copy, Clone, Debug)]
pub struct Iter<D, R>(D, R);
95
96//
97// Inherent impls
98//
99
100impl Xorshift64 {
101    /// A random 64-bit prime, used to initialize the generator returned by
102    /// [`Xorshift64::default()`].
103    pub const DEFAULT_SEED: u64 = 378682147834061;
104
105    /// Returns a new `Xorshift64` seeded by the given number.
106    ///
107    /// Two `Xorshift64` instances generate the same sequence of pseudo-random
108    /// numbers if and only if they were created with the same seed.
109    /// (Technically, every `Xorshift64` instance yields values from the same
110    /// sequence; the seed determines the starting point in the sequence).
111    ///
112    /// # Examples
113    /// ```
114    /// # use retrofire_core::math::rand::Xorshift64;
115    /// let mut g = Xorshift64::from_seed(123);
116    /// assert_eq!(g.next_bits(), 133101616827);
117    /// assert_eq!(g.next_bits(), 12690785413091508870);
118    /// assert_eq!(g.next_bits(), 7516749944291143043);
119    /// ```
120    ///
121    /// # Panics
122    ///
123    /// If `seed` equals 0.
124    pub fn from_seed(seed: u64) -> Self {
125        assert_ne!(seed, 0, "xorshift seed cannot be zero");
126        Self(seed)
127    }
128
129    /// Returns a new `Xorshift64` seeded by the current system time.
130    ///
131    /// Note that depending on the precision of the system clock, two or more
132    /// calls to this function in quick succession *may* return instances seeded
133    /// by the same number.
134    ///
135    /// #  Examples
136    /// ```
137    /// # use std::thread;
138    /// # use retrofire_core::math::rand::Xorshift64;
139    /// let mut g = Xorshift64::from_time();
140    /// thread::sleep_ms(1); // Just to be sure
141    /// let mut h = Xorshift64::from_time();
142    /// assert_ne!(g.next_bits(), h.next_bits());
143    /// ```
144    #[cfg(feature = "std")]
145    pub fn from_time() -> Self {
146        let t = std::time::SystemTime::UNIX_EPOCH
147            .elapsed()
148            .unwrap();
149        Self(t.as_micros() as u64)
150    }
151
152    /// Returns 64 bits of pseudo-randomness.
153    ///
154    /// Successive calls to this function (with the same `self`) will yield
155    /// every value in the interval [1, 2<sup>64</sup>) exactly once before
156    /// starting to repeat the sequence.
157    pub fn next_bits(&mut self) -> u64 {
158        let Self(x) = self;
159        *x ^= *x << 13;
160        *x ^= *x >> 7;
161        *x ^= *x << 17;
162        *x
163    }
164}
165
166//
167// Foreign trait impls
168//
169
170/// An infinite iterator of pseudorandom values sampled from a distribution.
171///
172/// This type is returned by [`Distrib::samples`].
173impl<D: Distrib> Iterator for Iter<D, DefaultRng> {
174    type Item = D::Sample;
175
176    /// Returns the next pseudorandom sample from this iterator.
177    ///
178    /// This method never returns `None`.
179    fn next(&mut self) -> Option<Self::Item> {
180        Some(self.0.sample(&mut self.1))
181    }
182}
183
184impl Default for Xorshift64 {
185    /// Returns a `Xorshift64` seeded with [`DEFAULT_SEED`](Self::DEFAULT_SEED).
186    ///
187    /// # Examples
188    /// ```
189    /// use retrofire_core::math::rand::Xorshift64;
190    /// let mut g = Xorshift64::default();
191    /// assert_eq!(g.next_bits(), 11039719294064252060);
192    /// ```
193    fn default() -> Self {
194        // Random 64-bit prime
195        Self::from_seed(Self::DEFAULT_SEED)
196    }
197}
198
199//
200// Local trait impls
201//
202
203/// Uniformly distributed integers.
204impl Distrib for Uniform<i32> {
205    type Sample = i32;
206
207    /// Returns a uniformly distributed `i32` in the range.
208    ///
209    /// # Examples
210    /// ```
211    /// use retrofire_core::math::rand::*;
212    /// let rng = DefaultRng::default();
213    ///
214    /// // Simulate rolling a six-sided die
215    /// let mut iter = Uniform(1..7).samples(rng);
216    /// assert_eq!(iter.next(), Some(3));
217    /// assert_eq!(iter.next(), Some(2));
218    /// assert_eq!(iter.next(), Some(4));
219    /// ```
220    fn sample(&self, rng: &mut DefaultRng) -> i32 {
221        let bits = rng.next_bits() as i32;
222        // TODO rem introduces slight bias
223        bits.rem_euclid(self.0.end - self.0.start) + self.0.start
224    }
225}
226
227/// Uniformly distributed floats.
228impl Distrib for Uniform<f32> {
229    type Sample = f32;
230
231    /// Returns a uniformly distributed `f32` in the range.
232    ///
233    /// # Examples
234    /// ```
235    /// use retrofire_core::math::rand::*;
236    /// let rng = DefaultRng::default();
237    ///
238    /// // Floats in the interval [-1, 1)
239    /// let mut iter = Uniform(-1.0..1.0).samples(rng);
240    /// assert_eq!(iter.next(), Some(0.19692874));
241    /// assert_eq!(iter.next(), Some(-0.7686298));
242    /// assert_eq!(iter.next(), Some(0.91969657));
243    /// ```
244    fn sample(&self, rng: &mut DefaultRng) -> f32 {
245        let Range { start, end } = self.0;
246        // Bit repr of a random f32 in range 1.0..2.0
247        // Leaves a lot of precision unused near zero, but it's okay.
248        let (exp, mantissa) = (127 << 23, rng.next_bits() >> 41);
249        let unit = f32::from_bits(exp | mantissa as u32) - 1.0;
250        unit * (end - start) + start
251    }
252}
253
254impl<T, const N: usize> Distrib for Uniform<[T; N]>
255where
256    T: Copy,
257    Uniform<T>: Distrib<Sample = T>,
258{
259    type Sample = [T; N];
260
261    /// Returns the coordinates of a uniformly distributed point within
262    /// the N-dimensional rectangular volume bounded by the range `self.0`.
263    ///
264    /// # Examples
265    /// ```
266    /// use retrofire_core::math::rand::*;
267    /// let rng = DefaultRng::default();
268    ///
269    /// // Pairs of integers [X, Y] such that 0 <= X < 4 and -2 <= Y <= 3
270    /// let mut iter = Uniform([0, -2]..[4, 3]).samples(rng);
271    /// assert_eq!(iter.next(), Some([0, -1]));
272    /// assert_eq!(iter.next(), Some([1, 0]));
273    /// assert_eq!(iter.next(), Some([3, 1]));
274    /// ```
275    fn sample(&self, rng: &mut DefaultRng) -> [T; N] {
276        let Range { start, end } = self.0;
277        array::from_fn(|i| Uniform(start[i]..end[i]).sample(rng))
278    }
279}
280
281/// Uniformly distributed vectors within a rectangular volume.
282impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Vector<[Sc; DIM], Sp>>
283where
284    Sc: Copy,
285    Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
286{
287    type Sample = Vector<[Sc; DIM], Sp>;
288
289    /// Returns a uniformly distributed vector within the rectangular volume
290    /// bounded by the range `self.0`.
291    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
292        Uniform(self.0.start.0..self.0.end.0)
293            .sample(rng)
294            .into()
295    }
296}
297
#[cfg(feature = "fp")]
impl Distrib for UnitCircle {
    type Sample = Vec2;

    /// Returns a 2-vector uniformly distributed on the unit circle.
    ///
    /// Points are drawn uniformly from the unit disk by rejection sampling
    /// and then normalized. Normalizing points drawn from the enclosing
    /// square instead would not be uniform: directions towards the corners
    /// of the square would be over-represented.
    fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
        let square = Uniform([-1.0f32; 2]..[1.0; 2]);
        loop {
            let v = Vec2::from(square.sample(rng));
            let len_sqr = v.len_sqr();
            // Reject points outside the disk, which would bias the direction,
            // and points too close to the origin to have a well-defined
            // direction after normalization.
            if 1e-6 < len_sqr && len_sqr <= 1.0 {
                return v.normalize();
            }
        }
    }
}
308
309impl Distrib for UnitDisk {
310    type Sample = Vec2;
311
312    /// Returns a 2-vector uniformly distributed within the unit disk.
313    fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
314        let d = Uniform([-1.0f32; 2]..[1.0; 2]);
315        loop {
316            let v = Vec2::from(d.sample(rng));
317            if v.len_sqr() <= 1.0 {
318                return v;
319            }
320        }
321    }
322}
323
#[cfg(feature = "fp")]
impl Distrib for UnitSphere {
    type Sample = Vec3;

    /// Returns a vector uniformly distributed on the unit sphere.
    ///
    /// Points are drawn uniformly from the unit ball by rejection sampling
    /// and then normalized. Normalizing points drawn from the enclosing cube
    /// instead would not be uniform: directions towards the corners of the
    /// cube would be over-represented.
    fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
        let cube = Uniform([-1.0f32; 3]..[1.0; 3]);
        loop {
            let v = Vec3::from(cube.sample(rng));
            let len_sqr = v.len_sqr();
            // Reject points outside the ball, which would bias the direction,
            // and points too close to the origin to have a well-defined
            // direction after normalization.
            if 1e-6 < len_sqr && len_sqr <= 1.0 {
                return v.normalize();
            }
        }
    }
}
334
335impl Distrib for UnitBall {
336    type Sample = Vec3;
337
338    /// Returns a vector uniformly distributed within the unit ball.
339    fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
340        let d = Uniform([-1.0; 3]..[1.0; 3]);
341        loop {
342            let v = Vec3::from(d.sample(rng));
343            if v.len_sqr() <= 1.0 {
344                return v;
345            }
346        }
347    }
348}
349
350impl Distrib for Bernoulli {
351    type Sample = bool;
352
353    /// Returns boolean values sampled from a Bernoulli distribution.
354    fn sample(&self, rng: &mut DefaultRng) -> bool {
355        Uniform(0.0f32..1.0).sample(rng) < self.0
356    }
357}
358
359impl<D: Distrib, E: Distrib> Distrib for (D, E) {
360    type Sample = (D::Sample, E::Sample);
361
362    /// Returns a pair of samples, sampled from two separate distributions.
363    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
364        (self.0.sample(rng), self.1.sample(rng))
365    }
366}
367
#[cfg(test)]
// The explicit `a <= x && x < b` form mirrors the half-open ranges under test.
#[allow(clippy::manual_range_contains)]
mod tests {
    use crate::assert_approx_eq;
    use crate::math::vec::vec3;

    use super::*;

    const COUNT: usize = 1000;

    fn rng() -> DefaultRng {
        Default::default()
    }

    #[test]
    fn uniform_i32() {
        let dist = Uniform(-123..456);
        for r in dist.samples(rng()).take(COUNT) {
            assert!(-123 <= r && r < 456);
        }
    }

    #[test]
    fn uniform_f32() {
        let dist = Uniform(-1.23..4.56);
        for r in dist.samples(rng()).take(COUNT) {
            assert!(-1.23 <= r && r < 4.56);
        }
    }

    #[test]
    fn uniform_i32_array() {
        let dist = Uniform([0, -10]..[10, 15]);

        let sum = dist
            .samples(rng())
            .take(COUNT)
            .inspect(|&[x, y]| {
                assert!(0 <= x && x < 10);
                // Each coordinate is checked against its own bounds.
                assert!(-10 <= y && y < 15);
            })
            .fold([0, 0], |[ax, ay], [x, y]| [ax + x, ay + y]);

        assert_eq!(sum, [4531, 1652]);
    }

    #[test]
    fn uniform_vec3() {
        let dist =
            Uniform(vec3::<f32, ()>(-2.0, 0.0, -1.0)..vec3(1.0, 2.0, 3.0));

        let mean = dist
            .samples(rng())
            .take(COUNT)
            .inspect(|v| {
                assert!(-2.0 <= v.x() && v.x() < 1.0);
                assert!(0.0 <= v.y() && v.y() < 2.0);
                assert!(-1.0 <= v.z() && v.z() < 3.0);
            })
            .sum::<Vec3>()
            / COUNT as f32;

        assert_eq!(mean, vec3(-0.46046025, 1.0209353, 0.9742225));
    }

    #[test]
    fn bernoulli() {
        let approx_100 = Bernoulli(0.1)
            .samples(rng())
            .take(COUNT)
            .filter(|&b| b)
            .count();
        assert_eq!(approx_100, 82);
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_circle() {
        for v in UnitCircle.samples(rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn unit_disk() {
        for v in UnitDisk.samples(rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_sphere() {
        for v in UnitSphere.samples(rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn unit_ball() {
        for v in UnitBall.samples(rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[test]
    fn zipped_pair() {
        let mut rng = rng();
        let dist = (Bernoulli(0.8), Uniform(0..4));
        assert_eq!(dist.sample(&mut rng), (true, 1));
        assert_eq!(dist.sample(&mut rng), (false, 3));
        assert_eq!(dist.sample(&mut rng), (true, 2));
    }
}