prodef 0.1.0

A simple Rust crate for handling probability distributions.
Documentation
use crate::{Density, Domain, domain::SDomain};
use nalgebra::{Dim, OVector, RealField, SVector, U1, VectorView};
use rand::Rng;
use rand_distr::{Distribution, StandardNormal, uniform::SampleUniform};
use serde::{Deserialize, Serialize};

/// Univariate normal (Gaussian) probability density.
///
/// The tuple fields are, in order: the mean, the standard deviation, and the
/// [`SDomain`] the density is restricted to (samples outside it have no density
/// and are rejected when sampling).
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct NormalDensity<T: RealField>(T, T, SDomain<T>);

impl<T> NormalDensity<T>
where
    T: PartialOrd + RealField,
{
    /// Argument magnitude at which the error-function evaluation switches from the
    /// power series to the continued fraction for the complementary error function.
    const TAIL_SWITCH: f64 = 3.0;

    /// Create a new [`NormalDensity`] with the given `mean` and `std_dev`, restricted to
    /// the domain built from the optional bounds `opt_a` and `opt_b`.
    ///
    /// Returns `None` if `std_dev` is not strictly positive (this includes NaN), if
    /// [`SDomain::new`] returns `None` for the given bounds, or if the resulting domain
    /// is an [`SDomain::Constant`].
    pub fn new(mean: T, std_dev: T, opt_a: Option<T>, opt_b: Option<T>) -> Option<Self> {
        // Any comparison with NaN is false, so this also rejects a NaN standard deviation.
        let std_dev_is_positive = std_dev > T::zero();

        if !std_dev_is_positive {
            return None;
        }

        let domain = SDomain::new(opt_a, opt_b)?;

        if matches!(domain, SDomain::Constant(_)) {
            return None;
        }

        Some(Self(mean, std_dev, domain))
    }

    /// Evaluates the cumulative distribution function of the (unrestricted) normal
    /// distribution at `x`. The domain is not taken into account.
    pub fn cdf(&self, x: T) -> T {
        let half = T::from_f64(0.5).unwrap();
        let z = (x - self.0.clone()) / (self.1.clone() * T::from_f64(2.0).unwrap().sqrt());

        if z <= -T::from_f64(Self::TAIL_SWITCH).unwrap() {
            // Far lower tail: `1 + erf(z)` would cancel catastrophically, so evaluate
            // Φ = erfc(-z) / 2 directly to keep relative accuracy.
            half * Self::erfc_continued_fraction(-z)
        } else {
            half * (T::one() + Self::erf(z))
        }
    }

    /// Evaluates the error function at `z`.
    ///
    /// Uses the everywhere-convergent positive-term series (DLMF 7.6.2) for
    /// `|z| < 3` and `1 - erfc(|z|)` via a continued fraction (DLMF 7.9.2)
    /// beyond that. The function is odd, so negative arguments are reflected.
    /// NaN propagates.
    pub fn erf(z: T) -> T {
        if z < T::zero() {
            return -Self::erf(-z);
        }

        if z < T::from_f64(Self::TAIL_SWITCH).unwrap() {
            Self::erf_series(z)
        } else {
            T::one() - Self::erfc_continued_fraction(z)
        }
    }

    /// Series `erf(z) = 2/√π · e^{-z²} · Σ_n 2ⁿ z^{2n+1} / (1·3·…·(2n+1))`, intended
    /// for moderate non-negative `z`. All terms share a sign, so no cancellation occurs.
    fn erf_series(z: T) -> T {
        // Safety cap; for |z| < 3 the series converges in well under 100 terms.
        const MAX_TERMS: usize = 200;

        let two = T::from_f64(2.0).unwrap();
        let z2 = z.clone() * z.clone();
        let mut term = z * (-z2.clone()).exp();
        let mut sum = term.clone();
        // Holds 2n + 1 for the term most recently added.
        let mut denom = T::one();

        for _ in 0..MAX_TERMS {
            denom = denom + two.clone();
            term = term * two.clone() * z2.clone() / denom.clone();

            let next = sum.clone() + term.clone();

            // Stop once the term no longer changes the sum at the working precision.
            if next == sum {
                break;
            }

            sum = next;
        }

        two / T::pi().sqrt() * sum
    }

    /// Complementary error function for `z ≥ 3` (intended range) via the continued fraction
    /// `√π e^{z²} erfc(z) = 1/(z + (1/2)/(z + 1/(z + (3/2)/(z + …))))`, evaluated bottom-up.
    fn erfc_continued_fraction(z: T) -> T {
        // Converges very quickly for z ≥ 3; 100 levels is far more than needed for f64.
        const DEPTH: usize = 100;

        let half = T::from_f64(0.5).unwrap();
        let mut tail = z.clone();

        for n in (1..=DEPTH).rev() {
            tail = z.clone() + T::from_usize(n).unwrap() * half.clone() / tail;
        }

        (-(z.clone() * z)).exp() / (T::pi().sqrt() * tail)
    }
}

impl<T> Density<T, U1> for NormalDensity<T>
where
    T: RealField + SampleUniform,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
    StandardNormal: Distribution<T>,
{
    /// Delegates to the `&NormalDensity<T>` implementation.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        (&self).density(sample)
    }

    /// Returns a clone of the stored domain.
    fn domain(&self) -> impl Domain<T, U1> + 'static {
        self.2.clone()
    }

    /// Delegates to the `&NormalDensity<T>` implementation (the mean).
    fn center(&self) -> SVector<T, 1> {
        (&self).center()
    }

    /// Delegates to the `&NormalDensity<T>` implementation.
    fn is_constant(&self) -> OVector<bool, U1> {
        (&self).is_constant()
    }

    /// Delegates to the `&NormalDensity<T>` implementation.
    fn sample(&self, rng: &mut impl Rng, max_attempts: usize) -> Option<SVector<T, 1>> {
        (&self).sample(rng, max_attempts)
    }
}

impl<T> Density<T, U1> for &NormalDensity<T>
where
    T: RealField + SampleUniform,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
    StandardNormal: Distribution<T>,
{
    /// Evaluates the normal probability density at `sample`, or returns `None` when
    /// `sample` lies outside the domain.
    ///
    /// NOTE(review): the value is not renormalised by the probability mass of the domain,
    /// so for a restricted domain it is not a normalised truncated-normal density —
    /// confirm this is intended.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        if !self.2.contains(sample) {
            return None;
        }

        let two = T::from_f64(2.0).unwrap();
        let z = (sample[0].clone() - self.0.clone()) / self.1.clone();
        let normalization = self.1.clone() * (two.clone() * T::pi()).sqrt();

        Some((-(z.clone() * z) / two).exp() / normalization)
    }

    /// Returns a clone of the stored domain.
    fn domain(&self) -> impl Domain<T, U1> + 'static {
        self.2.clone()
    }

    /// Returns the mean.
    fn center(&self) -> SVector<T, 1> {
        SVector::from([self.0.clone()])
    }

    /// Always `false`: [`NormalDensity::new`] rejects constant domains.
    fn is_constant(&self) -> OVector<bool, U1> {
        OVector::<bool, U1>::from_element(false)
    }

    /// Draws a sample by rejection: candidates are drawn from the unrestricted normal
    /// distribution until one falls inside the domain.
    ///
    /// At most `max_attempts + 1` candidates are tried (one initial draw plus
    /// `max_attempts` retries); `None` is returned if none of them is accepted.
    fn sample(&self, rng: &mut impl Rng, max_attempts: usize) -> Option<SVector<T, 1>> {
        for _ in 0..=max_attempts {
            // Every candidate, retries included, must be scaled and shifted from the
            // standard normal to N(mean, std_dev²).
            let standard: T = rng.sample(StandardNormal);
            let candidate = SVector::from([self.1.clone() * standard + self.0.clone()]);

            if self.2.contains::<U1, U1>(&candidate.as_view()) {
                return Some(candidate);
            }
        }

        None
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, ulps_eq};

    #[test]
    fn test_normal() {
        let normal = NormalDensity::new(0.1, 0.2, None, None).unwrap();

        // One standard deviation below the mean: Φ(-1).
        assert!(abs_diff_eq!(normal.cdf(-0.1), 0.15865525393145707, epsilon = 1e-12));
        assert!(ulps_eq!(normal.cdf(0.1), 0.5));
        assert!(abs_diff_eq!(NormalDensity::erf(1.0), 0.8427007929497149, epsilon = 1e-13));
    }

    #[test]
    fn test_erf_accuracy_and_tails() {
        assert!(ulps_eq!(NormalDensity::erf(0.0), 0.0));
        assert!(abs_diff_eq!(NormalDensity::erf(2.0), 0.9953222650189527, epsilon = 1e-13));
        assert!(abs_diff_eq!(NormalDensity::erf(-2.0), -0.9953222650189527, epsilon = 1e-13));
        assert!(abs_diff_eq!(NormalDensity::erf(3.0), 0.9999779095030014, epsilon = 1e-13));
        assert!(abs_diff_eq!(NormalDensity::erf(30.0), 1.0, epsilon = 1e-15));
        assert!(abs_diff_eq!(NormalDensity::erf(-30.0), -1.0, epsilon = 1e-15));
    }

    #[test]
    fn test_cdf_tails() {
        let normal = NormalDensity::new(0.1, 0.2, None, None).unwrap();

        // Ten standard deviations below the mean: Φ(-10) ≈ 7.6199e-24, which must be
        // computed without cancellation.
        let lower = normal.cdf(-1.9);
        assert!((lower / 7.619853024160527e-24 - 1.0).abs() < 1e-9);

        // Ten standard deviations above the mean: the CDF must stay within [0, 1].
        let upper = normal.cdf(2.1);
        assert!(upper <= 1.0 && abs_diff_eq!(upper, 1.0, epsilon = 1e-15));
    }

    #[test]
    fn test_new_rejects_invalid_std_dev() {
        assert!(NormalDensity::new(0.0, 0.0, None, None).is_none());
        assert!(NormalDensity::new(0.0, -1.0, None, None).is_none());
        assert!(NormalDensity::new(0.0, f64::NAN, None, None).is_none());
    }
}