1use crate::{
2 distr::{NonZero, Normal, Unit},
3 Vector,
4};
5use core::marker::PhantomData;
6use num_traits::Float;
7use rand_::{distributions::Distribution, Rng};
8
9pub struct VectorDistribution<D: Distribution<T>, T, const N: usize> {
11 pub inner: D,
12 phantom: PhantomData<Vector<T, N>>,
13}
14
15impl<D: Distribution<T>, T, const N: usize> VectorDistribution<D, T, N> {
16 pub fn new(inner: D) -> Self {
17 Self {
18 inner,
19 phantom: PhantomData,
20 }
21 }
22}
23
24impl<D: Distribution<T>, T, const N: usize> Distribution<Vector<T, N>>
25 for VectorDistribution<D, T, N>
26{
27 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
28 Vector::init(|| rng.sample(&self.inner))
29 }
30}
31
32impl<T, const N: usize> Distribution<Vector<T, N>> for Normal
33where
34 Normal: Distribution<T>,
35{
36 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
37 rng.sample(&VectorDistribution::new(self))
38 }
39}
40
41impl<T: Float, const N: usize> Distribution<Vector<T, N>> for NonZero
42where
43 Normal: Distribution<Vector<T, N>>,
44{
45 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
46 loop {
47 let x = rng.sample(&Normal);
48 if x.length() > T::epsilon() {
49 break x;
50 }
51 }
52 }
53}
54
55impl<T: Float, const N: usize> Distribution<Vector<T, N>> for Unit
56where
57 NonZero: Distribution<Vector<T, N>>,
58{
59 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
60 rng.sample(&NonZero).normalize()
61 }
62}