use crate::{
distr::{NonZero, Normal, Unit},
Vector,
};
use core::marker::PhantomData;
use num_traits::Float;
use rand_::{distributions::Distribution, Rng};
pub struct VectorDistribution<D: Distribution<T>, T, const N: usize> {
pub inner: D,
phantom: PhantomData<Vector<T, N>>,
}
impl<D: Distribution<T>, T, const N: usize> VectorDistribution<D, T, N> {
pub fn new(inner: D) -> Self {
Self {
inner,
phantom: PhantomData,
}
}
}
impl<D: Distribution<T>, T, const N: usize> Distribution<Vector<T, N>>
    for VectorDistribution<D, T, N>
{
    /// Builds a vector by drawing every component from `self.inner`.
    ///
    /// `Vector::init` is presumably called once per component, in order; each
    /// call pulls a fresh sample from `rng`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
        let draw_component = || self.inner.sample(rng);
        Vector::init(draw_component)
    }
}
impl<T, const N: usize> Distribution<Vector<T, N>> for Normal
where
    Normal: Distribution<T>,
{
    /// Samples a vector whose components are each drawn from the scalar
    /// `Normal` distribution, via the component-wise [`VectorDistribution`].
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
        // `self` is `&Normal`; rand's blanket impl makes `&D` a distribution
        // whenever `D` is one.
        let componentwise = VectorDistribution::<_, T, N>::new(self);
        componentwise.sample(rng)
    }
}
impl<T: Float, const N: usize> Distribution<Vector<T, N>> for NonZero
where
    Normal: Distribution<Vector<T, N>>,
{
    /// Samples a vector from `Normal`, rejecting and redrawing it until its
    /// length exceeds `T::epsilon()`.
    ///
    /// # Panics
    ///
    /// Panics if `N == 0`. A zero-dimensional vector has no components, so its
    /// length (presumably always zero) can never exceed the threshold. Without
    /// this check the rejection loop would never terminate.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
        assert!(
            N > 0,
            "cannot sample a non-zero vector with zero components"
        );
        loop {
            let candidate: Vector<T, N> = rng.sample(&Normal);
            // Reject (near-)zero vectors so callers can safely divide by the length.
            if candidate.length() > T::epsilon() {
                break candidate;
            }
        }
    }
}
impl<T: Float, const N: usize> Distribution<Vector<T, N>> for Unit
where
    NonZero: Distribution<Vector<T, N>>,
{
    /// Samples a vector from `NonZero` and returns it normalized.
    ///
    /// Drawing from `NonZero` first guarantees the vector being normalized
    /// is longer than `T::epsilon()`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vector<T, N> {
        let direction: Vector<T, N> = NonZero.sample(rng);
        direction.normalize()
    }
}