retrofire_core/math/rand.rs
1//! Pseudo-random number generation and distributions.
2
3use core::{array, fmt::Debug, ops::Range};
4
5use super::vec::{Vec2, Vec3, Vector};
6
7//
8// Traits and types
9//
10
11pub type DefaultRng = Xorshift64;
12
13/// Trait for generating values sampled from a probability distribution.
14pub trait Distrib<R = DefaultRng>: Clone {
15 /// The type of the elements of the sample space of `Self`, also called
16 /// "outcomes".
17 type Sample;
18
19 /// Returns a pseudo-random value sampled from `self`.
20 ///
21 /// # Examples
22 /// ```
23 /// use retrofire_core::math::rand::*;
24 /// // Simulate rolling a six-sided die
25 /// let mut rng = DefaultRng::default();
26 /// let d6 = Uniform(1..7).sample(&mut rng);
27 /// assert_eq!(d6, 3);
28 /// ```
29 fn sample(&self, rng: &mut R) -> Self::Sample;
30
31 /// Returns an iterator that yields samples from `self`.
32 ///
33 /// # Examples
34 /// ```
35 /// use retrofire_core::math::rand::*;
36 /// // Simulate rolling a six-sided die
37 /// let rng = DefaultRng::default();
38 /// let mut iter = Uniform(1..7).samples(rng);
39 /// assert_eq!(iter.next(), Some(3));
40 /// assert_eq!(iter.next(), Some(2));
41 /// assert_eq!(iter.next(), Some(4));
42 /// ```
43 fn samples(&self, rng: R) -> Iter<Self, R> {
44 Iter(self.clone(), rng)
45 }
46}
47
/// A pseudo-random number generator (PRNG) that uses a Xorshift algorithm[^1]
/// to generate 64 bits of randomness at a time, represented by a `u64`.
///
/// Xorshift is a type of linear-feedback shift register that uses only three
/// bit shifts (two left, one right) and three xor operations per generated
/// number, making it very efficient. Xorshift64 has a period of
/// 2<sup>64</sup>-1: starting from any nonzero state, it yields every number
/// in the interval [1, 2<sup>64</sup>) exactly once before repeating.
///
/// The all-zero state is a fixed point of the algorithm: a generator in that
/// state only ever yields zeros. [`Xorshift64::from_seed`] rejects a zero
/// seed, but note that the tuple field is public, so `Xorshift64(0)` can still
/// be constructed directly.
///
/// [^1]: Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software,
/// 8(14), 1–6. <https://doi.org/10.18637/jss.v008.i14>
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Xorshift64(pub u64);
61
/// A uniform distribution over the values in a half-open range.
///
/// The sampled type, and what exactly "uniform" means for it, are
/// determined by the [`Distrib`] impl for the particular `Uniform<T>`.
#[derive(Debug, Clone)]
pub struct Uniform<T>(pub Range<T>);
65
/// A uniform distribution of 2-vectors on the (perimeter of the) unit circle.
///
/// Derives `Default` like the other unit-shape distributions in this module.
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitCircle;
69
/// A uniform distribution of 2-vectors inside the closed unit disk, that is,
/// of vectors whose length is at most 1.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnitDisk;
73
/// A uniform distribution of 3-vectors on the surface of the unit sphere,
/// that is, of vectors whose length is 1.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnitSphere;
77
/// A uniform distribution of 3-vectors inside the closed unit ball, that is,
/// of vectors whose length is at most 1.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnitBall;
81
/// A Bernoulli distribution with success probability p.
///
/// Generates boolean values such that
/// * P(true) = p, and
/// * P(false) = 1 - p,
///
/// where the parameter p = `self.0` is expected to lie in [0.0, 1.0].
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli(pub f32);
91
/// Iterator returned by [`Distrib::samples()`], pairing a distribution (the
/// first field) with the random number generator it draws from (the second).
#[derive(Clone, Copy, Debug)]
pub struct Iter<D, R>(D, R);
95
96//
97// Inherent impls
98//
99
100impl Xorshift64 {
101 /// A random 64-bit prime, used to initialize the generator returned by
102 /// [`Xorshift64::default()`].
103 pub const DEFAULT_SEED: u64 = 378682147834061;
104
105 /// Returns a new `Xorshift64` seeded by the given number.
106 ///
107 /// Two `Xorshift64` instances generate the same sequence of pseudo-random
108 /// numbers if and only if they were created with the same seed.
109 /// (Technically, every `Xorshift64` instance yields values from the same
110 /// sequence; the seed determines the starting point in the sequence).
111 ///
112 /// # Examples
113 /// ```
114 /// # use retrofire_core::math::rand::Xorshift64;
115 /// let mut g = Xorshift64::from_seed(123);
116 /// assert_eq!(g.next_bits(), 133101616827);
117 /// assert_eq!(g.next_bits(), 12690785413091508870);
118 /// assert_eq!(g.next_bits(), 7516749944291143043);
119 /// ```
120 ///
121 /// # Panics
122 ///
123 /// If `seed` equals 0.
124 pub fn from_seed(seed: u64) -> Self {
125 assert_ne!(seed, 0, "xorshift seed cannot be zero");
126 Self(seed)
127 }
128
129 /// Returns a new `Xorshift64` seeded by the current system time.
130 ///
131 /// Note that depending on the precision of the system clock, two or more
132 /// calls to this function in quick succession *may* return instances seeded
133 /// by the same number.
134 ///
135 /// # Examples
136 /// ```
137 /// # use std::thread;
138 /// # use retrofire_core::math::rand::Xorshift64;
139 /// let mut g = Xorshift64::from_time();
140 /// thread::sleep_ms(1); // Just to be sure
141 /// let mut h = Xorshift64::from_time();
142 /// assert_ne!(g.next_bits(), h.next_bits());
143 /// ```
144 #[cfg(feature = "std")]
145 pub fn from_time() -> Self {
146 let t = std::time::SystemTime::UNIX_EPOCH
147 .elapsed()
148 .unwrap();
149 Self(t.as_micros() as u64)
150 }
151
152 /// Returns 64 bits of pseudo-randomness.
153 ///
154 /// Successive calls to this function (with the same `self`) will yield
155 /// every value in the interval [1, 2<sup>64</sup>) exactly once before
156 /// starting to repeat the sequence.
157 pub fn next_bits(&mut self) -> u64 {
158 let Self(x) = self;
159 *x ^= *x << 13;
160 *x ^= *x >> 7;
161 *x ^= *x << 17;
162 *x
163 }
164}
165
166//
167// Foreign trait impls
168//
169
170/// An infinite iterator of pseudorandom values sampled from a distribution.
171///
172/// This type is returned by [`Distrib::samples`].
173impl<D: Distrib> Iterator for Iter<D, DefaultRng> {
174 type Item = D::Sample;
175
176 /// Returns the next pseudorandom sample from this iterator.
177 ///
178 /// This method never returns `None`.
179 fn next(&mut self) -> Option<Self::Item> {
180 Some(self.0.sample(&mut self.1))
181 }
182}
183
184impl Default for Xorshift64 {
185 /// Returns a `Xorshift64` seeded with [`DEFAULT_SEED`](Self::DEFAULT_SEED).
186 ///
187 /// # Examples
188 /// ```
189 /// use retrofire_core::math::rand::Xorshift64;
190 /// let mut g = Xorshift64::default();
191 /// assert_eq!(g.next_bits(), 11039719294064252060);
192 /// ```
193 fn default() -> Self {
194 // Random 64-bit prime
195 Self::from_seed(Self::DEFAULT_SEED)
196 }
197}
198
199//
200// Local trait impls
201//
202
203/// Uniformly distributed integers.
204impl Distrib for Uniform<i32> {
205 type Sample = i32;
206
207 /// Returns a uniformly distributed `i32` in the range.
208 ///
209 /// # Examples
210 /// ```
211 /// use retrofire_core::math::rand::*;
212 /// let rng = DefaultRng::default();
213 ///
214 /// // Simulate rolling a six-sided die
215 /// let mut iter = Uniform(1..7).samples(rng);
216 /// assert_eq!(iter.next(), Some(3));
217 /// assert_eq!(iter.next(), Some(2));
218 /// assert_eq!(iter.next(), Some(4));
219 /// ```
220 fn sample(&self, rng: &mut DefaultRng) -> i32 {
221 let bits = rng.next_bits() as i32;
222 // TODO rem introduces slight bias
223 bits.rem_euclid(self.0.end - self.0.start) + self.0.start
224 }
225}
226
227/// Uniformly distributed floats.
228impl Distrib for Uniform<f32> {
229 type Sample = f32;
230
231 /// Returns a uniformly distributed `f32` in the range.
232 ///
233 /// # Examples
234 /// ```
235 /// use retrofire_core::math::rand::*;
236 /// let rng = DefaultRng::default();
237 ///
238 /// // Floats in the interval [-1, 1)
239 /// let mut iter = Uniform(-1.0..1.0).samples(rng);
240 /// assert_eq!(iter.next(), Some(0.19692874));
241 /// assert_eq!(iter.next(), Some(-0.7686298));
242 /// assert_eq!(iter.next(), Some(0.91969657));
243 /// ```
244 fn sample(&self, rng: &mut DefaultRng) -> f32 {
245 let Range { start, end } = self.0;
246 // Bit repr of a random f32 in range 1.0..2.0
247 // Leaves a lot of precision unused near zero, but it's okay.
248 let (exp, mantissa) = (127 << 23, rng.next_bits() >> 41);
249 let unit = f32::from_bits(exp | mantissa as u32) - 1.0;
250 unit * (end - start) + start
251 }
252}
253
254impl<T, const N: usize> Distrib for Uniform<[T; N]>
255where
256 T: Copy,
257 Uniform<T>: Distrib<Sample = T>,
258{
259 type Sample = [T; N];
260
261 /// Returns the coordinates of a uniformly distributed point within
262 /// the N-dimensional rectangular volume bounded by the range `self.0`.
263 ///
264 /// # Examples
265 /// ```
266 /// use retrofire_core::math::rand::*;
267 /// let rng = DefaultRng::default();
268 ///
269 /// // Pairs of integers [X, Y] such that 0 <= X < 4 and -2 <= Y <= 3
270 /// let mut iter = Uniform([0, -2]..[4, 3]).samples(rng);
271 /// assert_eq!(iter.next(), Some([0, -1]));
272 /// assert_eq!(iter.next(), Some([1, 0]));
273 /// assert_eq!(iter.next(), Some([3, 1]));
274 /// ```
275 fn sample(&self, rng: &mut DefaultRng) -> [T; N] {
276 let Range { start, end } = self.0;
277 array::from_fn(|i| Uniform(start[i]..end[i]).sample(rng))
278 }
279}
280
281/// Uniformly distributed vectors within a rectangular volume.
282impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Vector<[Sc; DIM], Sp>>
283where
284 Sc: Copy,
285 Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
286{
287 type Sample = Vector<[Sc; DIM], Sp>;
288
289 /// Returns a uniformly distributed vector within the rectangular volume
290 /// bounded by the range `self.0`.
291 fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
292 Uniform(self.0.start.0..self.0.end.0)
293 .sample(rng)
294 .into()
295 }
296}
297
#[cfg(feature = "fp")]
impl Distrib for UnitCircle {
    type Sample = Vec2;

    /// Returns a 2-vector uniformly distributed on the unit circle.
    ///
    /// A point is first drawn uniformly from the unit disk by rejection
    /// sampling and then normalized. Normalizing a point drawn from the
    /// enclosing square instead would favor the directions toward the
    /// square's corners, making the result non-uniform.
    fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
        let square = Uniform([-1.0f32; 2]..[1.0; 2]);
        loop {
            let v = Vec2::from(square.sample(rng));
            let len_sqr = v.len_sqr();
            // Also reject points very close to the origin: their direction is
            // numerically unreliable, and a zero vector has none at all.
            // Excluding a small disk around the origin keeps the result
            // rotationally symmetric.
            if len_sqr > 1e-6 && len_sqr <= 1.0 {
                return v.normalize();
            }
        }
    }
}
308
309impl Distrib for UnitDisk {
310 type Sample = Vec2;
311
312 /// Returns a 2-vector uniformly distributed within the unit disk.
313 fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
314 let d = Uniform([-1.0f32; 2]..[1.0; 2]);
315 loop {
316 let v = Vec2::from(d.sample(rng));
317 if v.len_sqr() <= 1.0 {
318 return v;
319 }
320 }
321 }
322}
323
#[cfg(feature = "fp")]
impl Distrib for UnitSphere {
    type Sample = Vec3;

    /// Returns a vector uniformly distributed on the unit sphere.
    ///
    /// A point is first drawn uniformly from the unit ball by rejection
    /// sampling and then normalized. Normalizing a point drawn from the
    /// enclosing cube instead would favor the directions toward the cube's
    /// corners, making the result non-uniform.
    fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
        let cube = Uniform([-1.0f32; 3]..[1.0; 3]);
        loop {
            let v = Vec3::from(cube.sample(rng));
            let len_sqr = v.len_sqr();
            // Also reject points very close to the origin: their direction is
            // numerically unreliable, and a zero vector has none at all.
            // Excluding a small ball around the origin keeps the result
            // rotationally symmetric.
            if len_sqr > 1e-6 && len_sqr <= 1.0 {
                return v.normalize();
            }
        }
    }
}
334
335impl Distrib for UnitBall {
336 type Sample = Vec3;
337
338 /// Returns a vector uniformly distributed within the unit ball.
339 fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
340 let d = Uniform([-1.0; 3]..[1.0; 3]);
341 loop {
342 let v = Vec3::from(d.sample(rng));
343 if v.len_sqr() <= 1.0 {
344 return v;
345 }
346 }
347 }
348}
349
350impl Distrib for Bernoulli {
351 type Sample = bool;
352
353 /// Returns boolean values sampled from a Bernoulli distribution.
354 fn sample(&self, rng: &mut DefaultRng) -> bool {
355 Uniform(0.0f32..1.0).sample(rng) < self.0
356 }
357}
358
359impl<D: Distrib, E: Distrib> Distrib for (D, E) {
360 type Sample = (D::Sample, E::Sample);
361
362 /// Returns a pair of samples, sampled from two separate distributions.
363 fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
364 (self.0.sample(rng), self.1.sample(rng))
365 }
366}
367
#[cfg(test)]
#[allow(clippy::manual_range_contains)]
mod tests {
    use crate::assert_approx_eq;
    use crate::math::vec::vec3;

    use super::*;

    /// Number of samples drawn in the iterator-based tests.
    const COUNT: usize = 1000;

    /// Returns a deterministically seeded generator, so the exact expected
    /// values asserted below are reproducible.
    fn rng() -> DefaultRng {
        Default::default()
    }

    #[test]
    fn uniform_i32() {
        let dist = Uniform(-123..456);
        for r in dist.samples(rng()).take(COUNT) {
            assert!(-123 <= r && r < 456);
        }
    }

    #[test]
    fn uniform_f32() {
        let dist = Uniform(-1.23..4.56);
        for r in dist.samples(rng()).take(COUNT) {
            assert!(-1.23 <= r && r < 4.56);
        }
    }

    #[test]
    fn uniform_i32_array() {
        let dist = Uniform([0, -10]..[10, 15]);

        let sum = dist
            .samples(rng())
            .take(COUNT)
            .inspect(|&[x, y]| {
                assert!(0 <= x && x < 10);
                assert!(-10 <= y && y < 15);
            })
            .fold([0, 0], |[ax, ay], [x, y]| [ax + x, ay + y]);

        assert_eq!(sum, [4531, 1652]);
    }

    #[test]
    fn uniform_vec3() {
        let dist =
            Uniform(vec3::<f32, ()>(-2.0, 0.0, -1.0)..vec3(1.0, 2.0, 3.0));

        let mean = dist
            .samples(rng())
            .take(COUNT)
            .inspect(|v| {
                assert!(-2.0 <= v.x() && v.x() < 1.0);
                assert!(0.0 <= v.y() && v.y() < 2.0);
                assert!(-1.0 <= v.z() && v.z() < 3.0);
            })
            .sum::<Vec3>()
            / COUNT as f32;

        assert_eq!(mean, vec3(-0.46046025, 1.0209353, 0.9742225));
    }

    #[test]
    fn bernoulli() {
        let approx_100 = Bernoulli(0.1)
            .samples(rng())
            .take(COUNT)
            .filter(|&b| b)
            .count();
        assert_eq!(approx_100, 82);
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_circle() {
        for v in UnitCircle.samples(rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn unit_disk() {
        for v in UnitDisk.samples(rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_sphere() {
        for v in UnitSphere.samples(rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn unit_ball() {
        for v in UnitBall.samples(rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[test]
    fn zipped_pair() {
        let mut rng = rng();
        let dist = (Bernoulli(0.8), Uniform(0..4));
        assert_eq!(dist.sample(&mut rng), (true, 1));
        assert_eq!(dist.sample(&mut rng), (false, 3));
        assert_eq!(dist.sample(&mut rng), (true, 2));
    }
}
481}