use crate::{Density, Domain, domain::SDomain};
use nalgebra::{Dim, OVector, RealField, SVector, U1, VectorView};
use rand::Rng;
use serde::{Deserialize, Serialize};
/// Density over a one-dimensional domain, stored as a newtype around [`SDomain<T>`].
///
/// [`ConstantDensity::new`] always stores the `SDomain::Constant` variant; the
/// accessors in this file assume that variant and hit `unreachable!()` otherwise.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct ConstantDensity<T: RealField>(SDomain<T>);
impl<T: RealField> ConstantDensity<T> {
    /// Creates a density whose wrapped domain is `SDomain::Constant(constant)`.
    pub fn new(constant: T) -> Self {
        let domain = SDomain::Constant(constant);
        Self(domain)
    }

    /// Returns a clone of the value held by the wrapped `SDomain::Constant`.
    ///
    /// # Panics
    ///
    /// Panics through `unreachable!()` if the wrapped domain is not
    /// `SDomain::Constant`. [`ConstantDensity::new`] only ever builds that variant.
    // NOTE(review): the derived `Deserialize` may admit other `SDomain` variants,
    // if any exist -- confirm against the `SDomain` definition.
    pub fn constant(&self) -> T {
        let SDomain::Constant(value) = &self.0 else {
            unreachable!()
        };
        value.clone()
    }
}
/// By-value implementation of [`Density`] for [`ConstantDensity`].
///
/// Every method except `domain` forwards to the implementation on
/// `&ConstantDensity<T>`. The implementing type is spelled out as `&Self` at each
/// call site so the call resolves to the by-reference impl and does not recurse
/// into this one; `&self` has type `&&ConstantDensity<T>`, which is exactly the
/// receiver type that impl expects.
impl<T> Density<T, U1> for ConstantDensity<T>
where
    T: RealField,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
{
    /// Forwards to `<&Self as Density<T, U1>>::density`.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        <&Self as Density<T, U1>>::density(&self, sample)
    }

    /// Returns a clone of the wrapped `SDomain<T>`.
    fn domain(&self) -> impl Domain<T, U1> + 'static {
        let Self(inner) = self;
        inner.clone()
    }

    /// Forwards to `<&Self as Density<T, U1>>::center`.
    fn center(&self) -> SVector<T, 1> {
        <&Self as Density<T, U1>>::center(&self)
    }

    /// Forwards to `<&Self as Density<T, U1>>::is_constant`.
    fn is_constant(&self) -> OVector<bool, U1> {
        <&Self as Density<T, U1>>::is_constant(&self)
    }

    /// Forwards to `<&Self as Density<T, U1>>::sample`, passing `rng` and
    /// `max_attempts` through unchanged.
    fn sample(&self, rng: &mut impl Rng, max_attempts: usize) -> Option<SVector<T, 1>> {
        <&Self as Density<T, U1>>::sample(&self, rng, max_attempts)
    }
}
/// By-reference implementation of [`Density`] for [`ConstantDensity`].
///
/// This impl holds the actual logic; it reads the wrapped `SDomain<T>` directly.
impl<T> Density<T, U1> for &ConstantDensity<T>
where
    T: RealField,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
{
    /// Returns `Some(T::one())` when the wrapped domain contains `sample`,
    /// and `None` otherwise.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        self.0.contains(sample).then(T::one)
    }

    /// Returns a clone of the wrapped `SDomain<T>`.
    fn domain(&self) -> impl Domain<T, U1> + 'static {
        self.0.clone()
    }

    /// Returns a one-element vector holding a clone of the stored constant.
    ///
    /// # Panics
    ///
    /// Panics through `unreachable!()` if the wrapped domain is not
    /// `SDomain::Constant`.
    fn center(&self) -> SVector<T, 1> {
        let SDomain::Constant(value) = &self.0 else {
            unreachable!()
        };
        SVector::from([value.clone()])
    }

    /// Returns a one-element vector whose single entry is `true`.
    fn is_constant(&self) -> OVector<bool, U1> {
        OVector::<bool, U1>::from_element(true)
    }

    /// Returns `Some` of a one-element vector holding a clone of the stored constant.
    ///
    /// Neither `_rng` nor `_max_attempts` is used: the result is the same
    /// vector that `center` produces.
    ///
    /// # Panics
    ///
    /// Panics through `unreachable!()` if the wrapped domain is not
    /// `SDomain::Constant`.
    fn sample(&self, _rng: &mut impl Rng, _max_attempts: usize) -> Option<SVector<T, 1>> {
        // `center` builds exactly the vector the sample consists of, with the
        // same panic on a non-`Constant` domain.
        Some(<Self as Density<T, U1>>::center(self))
    }
}