use core::{array, fmt::Debug, ops::Range};
use super::{Color, Point, Point2, Point3, Vec2, Vec3, Vector};
/// The random number generator type taken by the [`Distrib`] sampling
/// methods; currently an alias for [`Xorshift64`].
pub type DefaultRng = Xorshift64;
/// A source of random values of some type, driven by a [`DefaultRng`].
///
/// Implementors must be `Clone` because [`samples`](Distrib::samples)
/// stores a copy of the distribution inside the iterator it returns.
pub trait Distrib: Clone {
    /// The type of the values produced by this distribution.
    type Sample;

    /// Draws a single sample, using `rng` as the source of randomness.
    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample;

    /// Returns an iterator that draws a new sample from `rng` on every
    /// call to `next`.
    ///
    /// The iterator owns a clone of `self` and mutably borrows `rng` for
    /// as long as it is alive.
    fn samples(&self, rng: &mut DefaultRng) -> impl Iterator<Item = Self::Sample> {
        let distrib = self.clone();
        Iter(distrib, rng)
    }
}
/// A xorshift pseudo-random number generator with 64 bits of state.
///
/// The wrapped `u64` is the current state. `from_seed` rejects a zero seed;
/// note that the field is public, so the state can also be set directly.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Xorshift64(pub u64);
/// A distribution that draws values from the wrapped [`Range`].
///
/// [`Distrib`] is implemented for a selection of element types in this
/// module (integers, `f32`, arrays, and vector-like wrappers of arrays).
#[derive(Clone, Debug)]
pub struct Uniform<T>(pub Range<T>);
/// A distribution of unit-length 2D vectors, that is, points on the
/// unit circle.
///
/// Derives `Default` like the other unit-struct distributions in this
/// module, so it can be used wherever a `Default` distribution is expected.
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitCircle;
/// A distribution of unit-length 3D vectors, that is, points on the
/// unit sphere.
#[derive(Copy, Clone, Debug, Default)]
pub struct UnitSphere;
/// A distribution of 2D vectors whose length is at most one, that is,
/// vectors inside or on the boundary of the unit disk.
#[derive(Copy, Clone, Debug, Default)]
pub struct VectorsOnUnitDisk;
/// A distribution of 3D vectors whose length is at most one, that is,
/// vectors inside or on the boundary of the unit ball.
#[derive(Copy, Clone, Debug, Default)]
pub struct VectorsInUnitBall;
/// A distribution of 2D points obtained by converting samples of
/// [`VectorsOnUnitDisk`] to points.
#[derive(Copy, Clone, Debug, Default)]
pub struct PointsOnUnitDisk;
/// A distribution of 3D points obtained by converting samples of
/// [`VectorsInUnitBall`] to points.
#[derive(Copy, Clone, Debug, Default)]
pub struct PointsInUnitBall;
/// A distribution of booleans.
///
/// A sample is `true` when a value drawn from `Uniform(0.0..1.0)` is less
/// than the wrapped `f32`, so the wrapped value acts as the probability of
/// getting `true`.
#[derive(Copy, Clone, Debug)]
pub struct Bernoulli(pub f32);
/// Pairs a distribution (`D`) with a handle to a random number generator
/// (`R`). Returned by `Distrib::samples` as an iterator over samples.
#[derive(Copy, Clone, Debug)]
struct Iter<D, R>(D, R);
impl Xorshift64 {
pub const DEFAULT_SEED: u64 = 378682147834061;
pub fn from_seed(seed: u64) -> Self {
assert_ne!(seed, 0, "xorshift seed cannot be zero");
Self(seed)
}
#[cfg(feature = "std")]
pub fn from_time() -> Self {
let t = std::time::SystemTime::UNIX_EPOCH
.elapsed()
.unwrap();
Self(t.as_micros() as u64)
}
pub fn next_bits(&mut self) -> u64 {
let Self(x) = self;
*x ^= *x << 13;
*x ^= *x >> 7;
*x ^= *x << 17;
*x
}
}
/// An endless stream of samples: each call to `next` draws one sample from
/// the stored distribution using the borrowed generator.
impl<D: Distrib> Iterator for Iter<D, &'_ mut DefaultRng> {
    type Item = D::Sample;

    /// Draws the next sample. Always returns `Some`.
    fn next(&mut self) -> Option<D::Sample> {
        let Self(distrib, rng) = self;
        Some(distrib.sample(rng))
    }
}
impl Default for Xorshift64 {
    /// Returns a generator seeded with [`Xorshift64::DEFAULT_SEED`].
    fn default() -> Self {
        let seed = Self::DEFAULT_SEED;
        Self::from_seed(seed)
    }
}
impl Distrib for Uniform<i32> {
    type Sample = i32;

    /// Returns an integer in the half-open range `start..end`.
    ///
    /// The low 32 bits of the generator output are reduced modulo the
    /// width of the range, which introduces a slight bias when the width
    /// does not divide 2^32.
    ///
    /// # Panics
    /// If the range is empty (`start == end`). In builds with overflow
    /// checks, also if `end - start` or the final addition overflows.
    fn sample(&self, rng: &mut DefaultRng) -> i32 {
        let (lo, hi) = (self.0.start, self.0.end);
        let raw = rng.next_bits() as i32;
        let width = hi - lo;
        raw.rem_euclid(width) + lo
    }
}
impl Distrib for Uniform<u32> {
type Sample = u32;
fn sample(&self, rng: &mut DefaultRng) -> u32 {
let bits = rng.next_bits() as u32;
bits.rem_euclid(self.0.end - self.0.start) + self.0.start
}
}
impl Distrib for Uniform<usize> {
    type Sample = usize;

    /// Returns an integer in the half-open range `start..end`.
    ///
    /// The generator output, cast to `usize`, is reduced modulo the width
    /// of the range, which introduces a slight bias when the width is not
    /// a power of two.
    ///
    /// # Panics
    /// If the range is empty (`start == end`). In builds with overflow
    /// checks, also if `end < start`.
    fn sample(&self, rng: &mut DefaultRng) -> usize {
        let bounds = &self.0;
        let raw = rng.next_bits() as usize;
        let width = bounds.end - bounds.start;
        raw.rem_euclid(width) + bounds.start
    }
}
impl Distrib for Uniform<f32> {
type Sample = f32;
fn sample(&self, rng: &mut DefaultRng) -> f32 {
let Range { start, end } = self.0;
let (exp, mantissa) = (127 << 23, rng.next_bits() >> 41);
let unit = f32::from_bits(exp | mantissa as u32) - 1.0;
unit * (end - start) + start
}
}
impl<T, const N: usize> Distrib for Uniform<[T; N]>
where
T: Copy,
Uniform<T>: Distrib<Sample = T>,
{
type Sample = [T; N];
fn sample(&self, rng: &mut DefaultRng) -> [T; N] {
let Range { start, end } = self.0;
array::from_fn(|i| Uniform(start[i]..end[i]).sample(rng))
}
}
impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Vector<[Sc; DIM], Sp>>
where
Sc: Copy,
Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
{
type Sample = Vector<[Sc; DIM], Sp>;
fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
Uniform(self.0.start.0..self.0.end.0)
.sample(rng)
.into()
}
}
impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Point<[Sc; DIM], Sp>>
where
    Sc: Copy,
    Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
{
    type Sample = Point<[Sc; DIM], Sp>;

    /// Returns a point whose coordinates are sampled from the range
    /// between the coordinate arrays of the start and end points.
    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
        let bounds = &self.0;
        let (lo, hi) = (bounds.start.0, bounds.end.0);
        // Delegate to the array distribution, then wrap the result.
        let coords = Uniform(lo..hi).sample(rng);
        coords.into()
    }
}
impl<Sc, Sp, const DIM: usize> Distrib for Uniform<Color<[Sc; DIM], Sp>>
where
    Sc: Copy,
    Sp: Clone,
    Uniform<[Sc; DIM]>: Distrib<Sample = [Sc; DIM]>,
{
    // A distribution over a range of colors yields colors, matching the
    // `Vector` and `Point` impls, which each yield their own type.
    type Sample = Color<[Sc; DIM], Sp>;

    /// Returns a color whose components are sampled from the range
    /// between the component arrays of the start and end colors.
    ///
    /// The sampled component array is converted to a `Color` via `Into`.
    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
        Uniform(self.0.start.0..self.0.end.0)
            .sample(rng)
            .into()
    }
}
#[cfg(feature = "fp")]
impl Distrib for UnitCircle {
    type Sample = Vec2;

    /// Returns a unit 2-vector whose direction is uniformly distributed.
    ///
    /// Candidates are drawn from the square `[-1, 1)²`, and those outside
    /// the unit disk are rejected before normalizing. Normalizing points
    /// of the whole square directly would over-represent directions
    /// towards the corners. The zero vector is rejected as well because
    /// it has no direction.
    fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
        let square = Uniform([-1.0f32; 2]..[1.0; 2]);
        loop {
            let v = Vec2::from(square.sample(rng));
            let len_sqr = v.len_sqr();
            if 0.0 < len_sqr && len_sqr <= 1.0 {
                return v.normalize();
            }
        }
    }
}
impl Distrib for VectorsOnUnitDisk {
    type Sample = Vec2;

    /// Returns a 2-vector of length at most one.
    ///
    /// Uses rejection sampling: candidates are drawn from the square
    /// `[-1, 1)²` until one falls within the unit disk.
    fn sample(&self, rng: &mut DefaultRng) -> Vec2 {
        let square = Uniform([-1.0f32; 2]..[1.0; 2]);
        loop {
            let candidate = Vec2::from(square.sample(rng));
            if candidate.len_sqr() <= 1.0 {
                break candidate;
            }
        }
    }
}
#[cfg(feature = "fp")]
impl Distrib for UnitSphere {
    type Sample = Vec3;

    /// Returns a unit 3-vector whose direction is uniformly distributed.
    ///
    /// Candidates are drawn from the cube `[-1, 1)³`, and those outside
    /// the unit ball are rejected before normalizing. Normalizing points
    /// of the whole cube directly would over-represent directions
    /// towards the corners. The zero vector is rejected as well because
    /// it has no direction.
    fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
        let cube = Uniform([-1.0f32; 3]..[1.0; 3]);
        loop {
            let v = Vec3::from(cube.sample(rng));
            let len_sqr = v.len_sqr();
            if 0.0 < len_sqr && len_sqr <= 1.0 {
                return v.normalize();
            }
        }
    }
}
impl Distrib for VectorsInUnitBall {
    type Sample = Vec3;

    /// Returns a 3-vector of length at most one.
    ///
    /// Uses rejection sampling: candidates are drawn from the cube
    /// `[-1, 1)³` until one falls within the unit ball.
    fn sample(&self, rng: &mut DefaultRng) -> Vec3 {
        let cube = Uniform([-1.0; 3]..[1.0; 3]);
        loop {
            let candidate = Vec3::from(cube.sample(rng));
            if candidate.len_sqr() <= 1.0 {
                break candidate;
            }
        }
    }
}
impl Distrib for PointsOnUnitDisk {
    type Sample = Point2;

    /// Returns a sample of [`VectorsOnUnitDisk`] converted to a point.
    fn sample(&self, rng: &mut DefaultRng) -> Point2 {
        let v = VectorsOnUnitDisk.sample(rng);
        v.to_pt()
    }
}
impl Distrib for PointsInUnitBall {
    type Sample = Point3;

    /// Returns a sample of [`VectorsInUnitBall`] converted to a point.
    fn sample(&self, rng: &mut DefaultRng) -> Point3 {
        let v = VectorsInUnitBall.sample(rng);
        v.to_pt()
    }
}
impl Distrib for Bernoulli {
    type Sample = bool;

    /// Returns `true` when a value sampled from `Uniform(0.0..1.0)` is
    /// strictly less than the wrapped probability, otherwise `false`.
    fn sample(&self, rng: &mut DefaultRng) -> bool {
        let Self(p) = *self;
        let roll = Uniform(0.0f32..1.0).sample(rng);
        roll < p
    }
}
/// A pair of distributions is itself a distribution of pairs.
impl<D: Distrib, E: Distrib> Distrib for (D, E) {
    type Sample = (D::Sample, E::Sample);

    /// Samples the first distribution and then the second, both from the
    /// same `rng`, and returns the two results as a tuple.
    fn sample(&self, rng: &mut DefaultRng) -> Self::Sample {
        let (first, second) = self;
        let a = first.sample(rng);
        let b = second.sample(rng);
        (a, b)
    }
}
#[cfg(test)]
// The explicit `lo <= x && x < hi` comparisons below mirror the half-open
// ranges being tested, so the lint suggesting `contains` is silenced.
#[allow(clippy::manual_range_contains)]
mod tests {
    use crate::math::vec3;

    use super::*;

    /// Number of samples drawn in each test.
    const COUNT: usize = 1000;

    /// Returns a generator with the default seed so that every test sees
    /// the same deterministic sequence.
    fn rng() -> DefaultRng {
        Default::default()
    }

    #[test]
    fn uniform_i32() {
        let dist = Uniform(-123i32..456);
        for r in dist.samples(&mut rng()).take(COUNT) {
            assert!(-123 <= r && r < 456);
        }
    }

    #[test]
    fn uniform_f32() {
        let dist = Uniform(-1.23..4.56);
        for r in dist.samples(&mut rng()).take(COUNT) {
            assert!(-1.23 <= r && r < 4.56);
        }
    }

    #[test]
    fn uniform_i32_array() {
        let dist = Uniform([0, -10]..[10, 15]);

        let sum = dist
            .samples(&mut rng())
            .take(COUNT)
            .inspect(|&[x, y]| {
                assert!(0 <= x && x < 10);
                // Check both bounds of the second component.
                assert!(-10 <= y && y < 15);
            })
            .fold([0, 0], |[ax, ay], [x, y]| [ax + x, ay + y]);

        // Golden value for the default seed.
        assert_eq!(sum, [4531, 1652]);
    }

    #[test]
    fn uniform_vec3() {
        let dist =
            Uniform(vec3::<f32, ()>(-2.0, 0.0, -1.0)..vec3(1.0, 2.0, 3.0));

        let mean = dist
            .samples(&mut rng())
            .take(COUNT)
            .inspect(|v| {
                assert!(-2.0 <= v.x() && v.x() < 1.0);
                assert!(0.0 <= v.y() && v.y() < 2.0);
                assert!(-1.0 <= v.z() && v.z() < 3.0);
            })
            .sum::<Vec3>()
            / COUNT as f32;

        // Golden value for the default seed.
        assert_eq!(mean, vec3(-0.46046025, 1.0209353, 0.9742225));
    }

    #[test]
    fn bernoulli() {
        let rng = &mut rng();
        let bools = Bernoulli(0.1).samples(rng).take(COUNT);
        let approx_100 = bools.filter(|&b| b).count();
        // Golden value for the default seed; the expected count is 100.
        assert_eq!(approx_100, 82);
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_circle() {
        use crate::assert_approx_eq;
        for v in UnitCircle.samples(&mut rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn vectors_on_unit_disk() {
        for v in VectorsOnUnitDisk.samples(&mut rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[cfg(feature = "fp")]
    #[test]
    fn unit_sphere() {
        use crate::assert_approx_eq;
        for v in UnitSphere.samples(&mut rng()).take(COUNT) {
            assert_approx_eq!(v.len_sqr(), 1.0, "non-unit vector: {v:?}");
        }
    }

    #[test]
    fn vectors_in_unit_ball() {
        for v in VectorsInUnitBall.samples(&mut rng()).take(COUNT) {
            assert!(v.len_sqr() <= 1.0, "vector of len > 1.0: {v:?}");
        }
    }

    #[test]
    fn zipped_pair() {
        let rng = &mut rng();
        let dist = (Bernoulli(0.8), Uniform(0..4));
        // Golden values for the default seed.
        assert_eq!(dist.sample(rng), (true, 1));
        assert_eq!(dist.sample(rng), (false, 3));
        assert_eq!(dist.sample(rng), (true, 2));
    }
}