// prodef 0.1.0
//
// A simple Rust crate for handling probability distributions.
use crate::{Density, Domain, domain::SDomain};
use nalgebra::{Dim, OVector, RealField, SVector, U1, VectorView};
use rand::Rng;
use rand_distr::{Uniform, uniform::SampleUniform};
use serde::{Deserialize, Serialize};

/// A probability density proportional to `cos(x)` on a bounded interval `[min, max]`.
///
/// The interval must lie inside `[-π/2, π/2]`, where `cos` is non-negative and `sin` is
/// strictly increasing. Sampling inverts `sin` and relies on this. Build instances with
/// [`CosineDensity::new`], which validates the bounds.
///
/// NOTE(review): the derived `Deserialize` builds the inner domain directly and does not
/// go through `new`'s validation.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct CosineDensity<T: RealField>(SDomain<T>);

impl<T> CosineDensity<T>
where
    T: RealField,
{
    /// Create a new [`CosineDensity`] supported on `[minimum, maximum]`.
    ///
    /// Returns [`None`] for an invalid domain. That is the case when:
    /// - `minimum < -π/2` or `maximum > π/2`. Outside this range `cos` can be negative
    ///   and `sin` is not invertible, which sampling relies on.
    /// - `minimum >= maximum`, giving an inverted or single-point interval. An inverted
    ///   interval makes the sampling range `[sin(minimum), sin(maximum)]` invalid.
    /// - Either bound is NaN, since every comparison with NaN is `false`.
    pub fn new(minimum: T, maximum: T) -> Option<Self> {
        let half_pi = T::frac_pi_2();
        // The endpoints ±π/2 themselves are allowed: the documented range is closed.
        let within_range = minimum >= -half_pi.clone() && maximum <= half_pi;
        let ordered = minimum < maximum;

        if within_range && ordered {
            Some(Self(SDomain::Bounded(minimum, maximum)))
        } else {
            None
        }
    }
}

/// Owned form of the cosine density.
///
/// Every method except `domain` forwards to the implementation for `&CosineDensity<T>`.
/// The fully-qualified calls make the forwarding target explicit. `domain` clones the
/// stored domain directly.
impl<T> Density<T, U1> for CosineDensity<T>
where
    T: RealField + SampleUniform,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
{
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        <&Self as Density<T, U1>>::density(&self, sample)
    }

    fn domain(&self) -> impl Domain<T, U1> + 'static {
        self.0.clone()
    }

    fn center(&self) -> SVector<T, 1> {
        <&Self as Density<T, U1>>::center(&self)
    }

    fn is_constant(&self) -> OVector<bool, U1> {
        <&Self as Density<T, U1>>::is_constant(&self)
    }

    fn sample(&self, rng: &mut impl Rng, max_attempts: usize) -> Option<SVector<T, 1>> {
        <&Self as Density<T, U1>>::sample(&self, rng, max_attempts)
    }
}

impl<T> Density<T, U1> for &CosineDensity<T>
where
    T: RealField + SampleUniform,
    SDomain<T>: Domain<T, U1>,
    for<'a> &'a SDomain<T>: Domain<T, U1>,
{
    /// Evaluates the cosine density at `sample`.
    ///
    /// Returns [`None`] when `sample` is not contained in the domain. Otherwise returns
    /// `cos(x)` as is. It is not divided by the normalization constant
    /// `sin(max) - sin(min)`, so the value is an unnormalized density.
    fn density<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> Option<T> {
        if !self.0.contains(sample) {
            return None;
        }

        Some(sample[0].clone().cos())
    }

    /// Returns a copy of the stored domain.
    fn domain(&self) -> impl Domain<T, U1> + 'static {
        self.0.clone()
    }

    /// Returns the mean of the distribution with density proportional to `cos(x)` on
    /// `[a, b]`:
    ///
    /// `E[X] = ∫ x·cos(x) dx / ∫ cos(x) dx
    ///       = (b·sin(b) + cos(b) − a·sin(a) − cos(a)) / (sin(b) − sin(a))`.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) if the stored domain is not `SDomain::Bounded`.
    /// `CosineDensity::new` only ever builds bounded domains.
    fn center(&self) -> SVector<T, 1> {
        match &self.0 {
            SDomain::Bounded(min, max) => {
                let (a, b) = (min.clone(), max.clone());
                // Total mass of cos over [a, b]; this is the normalization constant.
                let mass = b.clone().sin() - a.clone().sin();

                // For a degenerate interval, or one so close to ±π/2 that sin(a) and
                // sin(b) round to the same value, the ratio below would be 0/0. Use
                // the midpoint instead, which is the limit of the mean as the interval
                // shrinks.
                if mass <= T::zero() {
                    return SVector::from([(a + b) / (T::one() + T::one())]);
                }

                // First moment: [x·sin(x) + cos(x)] evaluated from a to b.
                let moment = b.clone() * b.clone().sin() + b.cos()
                    - a.clone() * a.clone().sin()
                    - a.cos();
                SVector::from([moment / mass])
            }
            _ => unreachable!(),
        }
    }

    /// The single coordinate is never reported as constant.
    fn is_constant(&self) -> OVector<bool, U1> {
        OVector::<bool, U1>::from_element(false)
    }

    /// Draws one sample by inverse-CDF sampling.
    ///
    /// On `[lo, hi] ⊆ [-π/2, π/2]` the CDF is `(sin(x) − sin(lo)) / (sin(hi) − sin(lo))`.
    /// So `x = asin(u)` with `u ~ Uniform[sin(lo), sin(hi)]` is exact. No rejection loop
    /// is needed, and `_max_attempts` is unused.
    ///
    /// Returns [`None`] in two cases:
    /// - The domain is not bounded.
    /// - The sampling range is invalid, e.g. inverted bounds.
    ///
    /// Either can only happen when the value did not come from `CosineDensity::new`,
    /// e.g. through deserialization.
    fn sample(&self, rng: &mut impl Rng, _max_attempts: usize) -> Option<SVector<T, 1>> {
        let SDomain::Bounded(lo, hi) = &self.0 else {
            return None;
        };

        let uniform = Uniform::new_inclusive(lo.clone().sin(), hi.clone().sin()).ok()?;

        // asin(sin(x)) does not round-trip exactly in floating point. Keep the sample
        // inside the domain so that `density` of a drawn sample is never `None`.
        let x = rng.sample(uniform).asin().max(lo.clone()).min(hi.clone());

        Some(SVector::from([x]))
    }
}