use crate::{Density, RejectionSampler, SamplingMode, domain::Domain, macros::tval};
use nalgebra::{Dim, OVector, RealField, SVector, Scalar, U1, VectorView};
use rand::RngExt;
use rand_distr::{Distribution, StandardNormal};
use serde::{Deserialize, Serialize};
/// Univariate normal density paired with a one-dimensional domain.
///
/// The positional fields, in the order `NormalDensity::new` fills them, are:
///
/// * `0` – the mean (location parameter),
/// * `1` – the standard deviation (scale parameter),
/// * `2` – the [`Domain`] built from the optional lower and upper bounds.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct NormalDensity<T>(T, T, Domain<T, U1>)
where
T: Scalar;
impl<T> NormalDensity<T>
where
    T: RealField,
{
    /// Evaluates the normal cumulative distribution function at `x`.
    ///
    /// Computes `0.5 * (1 + erf((x - mean) / (std_dev * sqrt(2))))`. Only the
    /// mean and standard deviation are used; the stored domain is not consulted.
    pub fn cdf(&self, x: T) -> T {
        let z = (x - self.0.clone()) / (self.1.clone() * tval!(2, usize).sqrt());
        tval!(0.5, f64) * (T::one() + Self::erf(z))
    }

    /// Evaluates the error function `erf(z)`.
    ///
    /// Uses the series
    /// `erf(z) = 2/sqrt(pi) * exp(-z^2) * sum_{n>=0} 2^n z^(2n+1) / (1*3*...*(2n+1))`,
    /// whose terms all share the sign of `z`, so it converges without
    /// cancellation. A fixed-degree Maclaurin polynomial, by contrast, diverges
    /// once `|z|` exceeds roughly 2 and would push [`Self::cdf`] far outside
    /// `[0, 1]`.
    ///
    /// For `|z| >= 6` the result is saturated to `-1` or `+1`; the true value
    /// differs from that by less than about `2e-17`.
    pub fn erf(z: T) -> T {
        // Upper bound on the number of series terms; about 100 are needed at
        // |z| = 6 for double precision, so this leaves ample headroom.
        const MAX_TERMS: usize = 256;

        let saturation: T = tval!(6, usize);
        if z >= saturation {
            return T::one();
        }
        if z <= -saturation {
            return -T::one();
        }

        let two: T = tval!(2, usize);
        let z_sq = z.clone() * z.clone();

        // `term` holds 2^n z^(2n+1) / (2n+1)!!, starting at n = 0.
        let mut term = z;
        let mut sum = term.clone();
        let mut odd = T::one();
        for _ in 0..MAX_TERMS {
            odd = odd + two.clone();
            term = term * two.clone() * z_sq.clone() / odd.clone();
            let next = sum.clone() + term.clone();
            // Stop as soon as another term no longer changes the partial sum.
            if next == sum {
                break;
            }
            sum = next;
        }

        two / T::pi().sqrt() * (-z_sq).exp() * sum
    }

    /// Builds a normal density with the given mean, standard deviation and
    /// optional lower (`opt_a`) and upper (`opt_b`) bounds.
    ///
    /// Returns `None` when `std_dev` is not strictly positive, or when both
    /// bounds are supplied and the lower bound is not strictly below the upper
    /// bound. A missing bound leaves that side unconstrained and is never
    /// compared against a default value.
    pub fn new(mean: T, std_dev: T, opt_a: Option<T>, opt_b: Option<T>) -> Option<Self> {
        if std_dev <= T::zero() {
            return None;
        }
        // The ordering check only makes sense when both ends are present.
        if matches!((&opt_a, &opt_b), (Some(a), Some(b)) if a >= b) {
            return None;
        }
        let domain = Domain::new_mdomain(OVector::from_element_generic(U1, U1, (opt_a, opt_b)));
        Some(Self(mean, std_dev, domain))
    }

    /// Returns the upper bound of the domain, or `None` if there is none.
    ///
    /// # Panics
    ///
    /// Panics if the domain's `inner()` yields no value.
    pub fn maximum(&self) -> Option<T> {
        match &self.2.inner().unwrap() {
            (_, Some(max)) => Some(max.clone()),
            _ => None,
        }
    }

    /// Returns the lower bound of the domain, or `None` if there is none.
    ///
    /// # Panics
    ///
    /// Panics if the domain's `inner()` yields no value.
    pub fn minimum(&self) -> Option<T> {
        match &self.2.inner().unwrap() {
            (Some(min), _) => Some(min.clone()),
            _ => None,
        }
    }
}
impl<T> Density<T, U1> for NormalDensity<T>
where
    T: RealField,
    StandardNormal: Distribution<T>,
{
    /// Evaluates the normal probability density at `sample`.
    ///
    /// Returns `None` when the domain does not contain `sample`; otherwise
    /// `exp(-((x - mean) / std_dev)^2 / 2) / (std_dev * sqrt(2 * pi))`.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        self.2.contains(sample).then(|| {
            let sigma = self.1.clone();
            let standardized = (sample[0].clone() - self.0.clone()) / sigma.clone();
            let scale = T::one() / (sigma * tval!(2.0 * std::f64::consts::PI, f64).sqrt());
            let exponent = -standardized.powi(2) / tval!(2, usize);
            scale * exponent.exp()
        })
    }

    /// Returns a copy of the stored domain.
    fn domain(&self) -> Domain<T, U1> {
        self.2.clone()
    }

    /// Returns the mean parameter, moved onto the nearest domain bound when it
    /// lies outside the domain.
    fn mean(&self) -> SVector<T, 1> {
        let center = self.0.clone();
        let located = match (self.minimum(), self.maximum()) {
            (Some(lo), Some(hi)) => {
                if lo <= center && center <= hi {
                    center
                } else if center < lo {
                    lo
                } else {
                    hi
                }
            }
            (Some(lo), None) if center < lo => lo,
            (None, Some(hi)) if center > hi => hi,
            _ => center,
        };
        SVector::from([located])
    }

    /// Draws a sample by delegating to `rejection_sample` with the given mode.
    fn sample(&self, rng: &mut impl RngExt, mode: &SamplingMode) -> Option<SVector<T, 1>> {
        self.rejection_sample(rng, mode)
    }

    /// Produces an iterator of draws from the normal distribution.
    ///
    /// Each item is `Some(point)` when the draw lies inside the domain and
    /// `None` when it does not.
    fn sample_iter(&self, rng: &mut impl RngExt) -> impl Iterator<Item = Option<SVector<T, 1>>> {
        rng.sample_iter(StandardNormal).map(move |z| {
            let candidate = self.1.clone() * z + self.0.clone();
            let point: SVector<T, 1> = SVector::from([candidate]);
            let inside = self.2.contains::<U1, U1>(&point.as_view());
            inside.then_some(point)
        })
    }

    /// Returns the square of the standard deviation; the domain is not taken
    /// into account.
    fn variance(&self) -> SVector<T, 1> {
        let sigma = self.1.clone();
        SVector::from([sigma.powi(2)])
    }
}
impl<T> RejectionSampler<T, U1> for &NormalDensity<T>
where
T: RealField,
StandardNormal: Distribution<T>,
{
fn generate_candidate(&self, rng: &mut impl RngExt) -> SVector<T, 1> {
let z = rng.sample(StandardNormal);
OVector::from([self.1.clone() * z + self.0.clone()])
}
}
impl<T: RealField> TryFrom<crate::univariate::UnivariateDensity<T>> for NormalDensity<T> {
    type Error = ();

    /// Extracts the wrapped [`NormalDensity`] from the `Normal` variant.
    ///
    /// Any other variant yields `Err(())`.
    fn try_from(value: crate::univariate::UnivariateDensity<T>) -> Result<Self, Self::Error> {
        let crate::univariate::UnivariateDensity::Normal(pdf) = value else {
            return Err(());
        };
        Ok(pdf)
    }
}