use crate::{Density, Domain, domain::SDomain};
use nalgebra::{Dim, OVector, RealField, SVector, U1, VectorView};
use rand::Rng;
use rand_distr::{Uniform, uniform::SampleUniform};
use serde::{Deserialize, Serialize};
/// A univariate density whose value at a point `x` of its support is `cos(x)`.
///
/// The wrapped [`SDomain`] describes the support. [`CosineDensity::new`] only ever builds it as
/// `SDomain::Bounded(minimum, maximum)`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct CosineDensity<T: RealField>(SDomain<T>);
impl<T> CosineDensity<T>
where
    T: RealField,
{
    /// Creates a cosine density supported on `[minimum, maximum]`.
    ///
    /// Returns `None` unless `-π/2 < minimum < maximum < π/2`.
    ///
    /// The bounds are kept strictly inside `(-π/2, π/2)`, where `cos` is positive and `sin` is
    /// strictly increasing. The interval must also be non-empty. A reversed or zero-width
    /// interval has no usable support and is rejected. NaN bounds fail every comparison, so they
    /// are rejected too.
    pub fn new(minimum: T, maximum: T) -> Option<Self> {
        let inside_half_period = minimum > -T::frac_pi_2() && maximum < T::frac_pi_2();
        // Ordering check: previously a reversed interval was accepted and only failed later.
        if inside_half_period && minimum < maximum {
            Some(Self(SDomain::Bounded(minimum, maximum)))
        } else {
            None
        }
    }
}
/// Owned-receiver implementation of [`Density`] for [`CosineDensity`].
///
/// Every method except `domain` forwards to the implementation on `&CosineDensity<T>`, so the
/// numerical logic is written only once.
impl<T> Density<T, U1> for CosineDensity<T>
where
    T: RealField + SampleUniform,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
{
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        // The fully-qualified path names the target impl explicitly. A plain
        // `self.density(..)` would resolve back to this impl and recurse forever.
        <&CosineDensity<T> as Density<T, U1>>::density(&self, sample)
    }

    fn domain(&self) -> impl Domain<T, U1> + 'static {
        self.0.clone()
    }

    fn center(&self) -> SVector<T, 1> {
        <&CosineDensity<T> as Density<T, U1>>::center(&self)
    }

    fn is_constant(&self) -> OVector<bool, U1> {
        <&CosineDensity<T> as Density<T, U1>>::is_constant(&self)
    }

    fn sample(&self, rng: &mut impl Rng, max_attempts: usize) -> Option<SVector<T, 1>> {
        <&CosineDensity<T> as Density<T, U1>>::sample(&self, rng, max_attempts)
    }
}
/// Borrowed-receiver implementation of [`Density`] for [`CosineDensity`].
///
/// This impl holds the actual numerics. The owned-receiver impl forwards to it.
impl<T> Density<T, U1> for &CosineDensity<T>
where
    T: RealField + SampleUniform,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
{
    /// Returns `cos(x)` for a sample `x` inside the domain, or `None` outside it.
    ///
    /// The value is not divided by the normalising constant `sin(max) - sin(min)`.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        if !self.0.contains(sample) {
            return None;
        }
        Some(sample[0].clone().cos())
    }

    /// Returns an owned copy of the domain.
    fn domain(&self) -> impl Domain<T, U1> + 'static {
        self.0.clone()
    }

    /// Returns the mean of the distribution with density `cos(x) / (sin(max) - sin(min))`
    /// on `[min, max]`. That is the distribution `sample` draws from.
    ///
    /// Because `x sin(x) + cos(x)` is an antiderivative of `x cos(x)`, the mean is
    /// `[x sin(x) + cos(x)]_min^max / (sin(max) - sin(min))`.
    ///
    /// The previous formula `(sin(max) - sin(min)) / (max - min)` gave the average *value* of
    /// `cos` over the interval instead, which need not even lie inside the domain. If the
    /// normalising constant is exactly zero (a degenerate interval), the interval midpoint is
    /// returned.
    ///
    /// # Panics
    /// Panics if the domain is not `SDomain::Bounded`, e.g. for a deserialized value that
    /// bypassed `new`.
    fn center(&self) -> SVector<T, 1> {
        match &self.0 {
            SDomain::Bounded(min, max) => {
                let (lo, hi) = (min.clone(), max.clone());
                // Normalising constant: integral of cos(x) over [lo, hi].
                let mass = hi.clone().sin() - lo.clone().sin();
                if mass == T::zero() {
                    return SVector::from([(lo + hi) / (T::one() + T::one())]);
                }
                // First moment: integral of x cos(x) over [lo, hi].
                let antiderivative = |x: T| x.clone() * x.clone().sin() + x.cos();
                let moment = antiderivative(hi) - antiderivative(lo);
                SVector::from([moment / mass])
            }
            _ => unreachable!("CosineDensity always wraps a bounded domain"),
        }
    }

    /// The single dimension is never constant.
    fn is_constant(&self) -> OVector<bool, U1> {
        OVector::<bool, U1>::from_element(false)
    }

    /// Draws a sample by inverse-transform sampling.
    ///
    /// The method draws `u ~ Uniform[sin(min), sin(max)]` and returns `asin(u)`, whose density
    /// is proportional to `cos(x)` on `[min, max]`. No rejection step is used, so
    /// `_max_attempts` is ignored.
    ///
    /// Returns `None` instead of panicking when the uniform range is invalid (e.g. reversed or
    /// non-finite bounds).
    ///
    /// # Panics
    /// Panics if the domain is not `SDomain::Bounded`.
    fn sample(&self, rng: &mut impl Rng, _max_attempts: usize) -> Option<SVector<T, 1>> {
        match &self.0 {
            SDomain::Bounded(min, max) => {
                let uniform = Uniform::new_inclusive(min.clone().sin(), max.clone().sin()).ok()?;
                Some(SVector::from([rng.sample(uniform).asin()]))
            }
            _ => unreachable!("CosineDensity always wraps a bounded domain"),
        }
    }
}