use crate::{Density, Domain, domain::SDomain};
use nalgebra::{Dim, OVector, RealField, SVector, U1, VectorView};
use rand::Rng;
use rand_distr::{Uniform, uniform::SampleUniform};
use serde::{Deserialize, Serialize};
/// Uniform probability density over a one-dimensional interval between `a` and `b`.
///
/// The wrapped domain is constructed as `SDomain::Bounded(a, b)` by
/// [`UniformDensity::new`].
///
/// NOTE(review): the derived `Deserialize` fills the field directly without going
/// through `new`, so deserialized values are not checked for `a <= b` or for being
/// the `Bounded` variant. Confirm whether callers rely on that invariant.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct UniformDensity<T: RealField>(SDomain<T>);
impl<T> UniformDensity<T>
where
    T: RealField,
{
    /// Creates a uniform density on the interval from `a` to `b`.
    ///
    /// Returns `None` in two cases:
    /// - the bounds are out of order (`a > b`);
    /// - either bound is not finite (NaN or ±infinity).
    ///
    /// `a == b` is accepted and produces a degenerate interval. In that case the
    /// density value `1 / (b - a)` divides by zero.
    pub fn new(a: T, b: T) -> Option<Self> {
        // NaN fails every ordering comparison. A check written as `a > b` would
        // therefore accept it. Testing finiteness and `a <= b` rejects NaN as well
        // as infinite bounds, which the sampler cannot handle.
        if a.is_finite() && b.is_finite() && a <= b {
            Some(Self(SDomain::Bounded(a, b)))
        } else {
            None
        }
    }
}
/// Owned receiver for the uniform density.
///
/// Every method except `domain` forwards to the implementation on
/// `&UniformDensity<T>`, so owned and borrowed receivers share one code path.
impl<T> Density<T, U1> for UniformDensity<T>
where
    T: RealField + SampleUniform,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
{
    /// Forwards to the borrowed implementation.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        <&Self as Density<T, U1>>::density(&self, sample)
    }

    /// Returns an owned copy of the wrapped domain.
    fn domain(&self) -> impl Domain<T, U1> + 'static {
        Clone::clone(&self.0)
    }

    /// Forwards to the borrowed implementation.
    fn center(&self) -> SVector<T, 1> {
        <&Self as Density<T, U1>>::center(&self)
    }

    /// Forwards to the borrowed implementation.
    fn is_constant(&self) -> OVector<bool, U1> {
        <&Self as Density<T, U1>>::is_constant(&self)
    }

    /// Forwards to the borrowed implementation.
    fn sample(&self, rng: &mut impl Rng, max_attempts: usize) -> Option<SVector<T, 1>> {
        <&Self as Density<T, U1>>::sample(&self, rng, max_attempts)
    }
}
/// Borrowed receiver for the uniform density; this impl holds the actual logic.
impl<T> Density<T, U1> for &UniformDensity<T>
where
    T: RealField + SampleUniform,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
{
    /// Returns `1 / (max - min)` for samples the domain contains, and `None` otherwise.
    ///
    /// For a degenerate interval (`min == max`) this divides by zero.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        if !self.0.contains(sample) {
            return None;
        }
        let lower = self.0.minimum_values().unwrap()[0].clone();
        let upper = self.0.maximum_values().unwrap()[0].clone();
        Some(T::one() / (upper - lower))
    }

    /// Returns an owned copy of the wrapped domain.
    fn domain(&self) -> impl Domain<T, U1> + 'static {
        self.0.clone()
    }

    /// Returns the midpoint `(a + b) / 2` of the interval.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped domain is not `SDomain::Bounded`. `new` never builds
    /// such a value.
    fn center(&self) -> SVector<T, 1> {
        match &self.0 {
            SDomain::Bounded(a, b) => {
                let mean = (a.clone() + b.clone()) / T::from_usize(2).unwrap();
                SVector::from([mean])
            }
            _ => unreachable!("UniformDensity is always constructed with a bounded domain"),
        }
    }

    /// The single dimension is always reported as non-constant, even when `a == b`.
    fn is_constant(&self) -> OVector<bool, U1> {
        OVector::<bool, U1>::from_element(false)
    }

    /// Draws one value uniformly from the inclusive range `[min, max]`.
    ///
    /// `max_attempts` is unused because sampling never rejects. Returns `None` when
    /// `Uniform::new_inclusive` refuses the range. For example, with finite bounds
    /// whose span overflows, such as `(-f64::MAX, f64::MAX)`, it reports a
    /// non-finite scale.
    fn sample(&self, rng: &mut impl Rng, _max_attempts: usize) -> Option<SVector<T, 1>> {
        let lower = self.0.minimum_values().unwrap()[0].clone();
        let upper = self.0.maximum_values().unwrap()[0].clone();
        // Report an unsamplable range through the `Option` instead of panicking.
        let uniform = Uniform::new_inclusive(lower, upper).ok()?;
        Some(SVector::from([rng.sample(uniform)]))
    }
}