vecmat/vector/
distr.rs

1use crate::{
2    distr::{NonZero, Normal, Unit},
3    Vector,
4};
5use core::marker::PhantomData;
6use num_traits::Float;
7use rand_::{distributions::Distribution, Rng};
8
9/// Per-component vector distribution.
10pub struct VectorDistribution<D: Distribution<T>, T, const N: usize> {
11    pub inner: D,
12    phantom: PhantomData<Vector<T, N>>,
13}
14
15impl<D: Distribution<T>, T, const N: usize> VectorDistribution<D, T, N> {
16    pub fn new(inner: D) -> Self {
17        Self {
18            inner,
19            phantom: PhantomData,
20        }
21    }
22}
23
24impl<D: Distribution<T>, T, const N: usize> Distribution<Vector<T, N>>
25    for VectorDistribution<D, T, N>
26{
27    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
28        Vector::init(|| rng.sample(&self.inner))
29    }
30}
31
32impl<T, const N: usize> Distribution<Vector<T, N>> for Normal
33where
34    Normal: Distribution<T>,
35{
36    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
37        rng.sample(&VectorDistribution::new(self))
38    }
39}
40
41impl<T: Float, const N: usize> Distribution<Vector<T, N>> for NonZero
42where
43    Normal: Distribution<Vector<T, N>>,
44{
45    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
46        loop {
47            let x = rng.sample(&Normal);
48            if x.length() > T::epsilon() {
49                break x;
50            }
51        }
52    }
53}
54
55impl<T: Float, const N: usize> Distribution<Vector<T, N>> for Unit
56where
57    NonZero: Distribution<Vector<T, N>>,
58{
59    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
60        rng.sample(&NonZero).normalize()
61    }
62}