opensrdk-probability 0.8.9

Standard probability library for the OpenSRDK toolchain.
Documentation
use crate::{
    DependentJoint, Distribution, IndependentJoint, RandomVariable, SampleableDistribution,
};
use crate::{DistributionError, Event};
use rand::prelude::*;
use std::{
    collections::HashMap,
    fmt::Debug,
    ops::{BitAnd, Mul},
};

pub mod params;

pub use params::*;

/// A wrapper around a base distribution `D` whose condition can be given either
/// directly or as a `u32` key that is looked up in a borrowed table.
///
/// Both the base distribution and the table are held by reference for `'a`;
/// this type owns neither of them.
///
/// Type parameters:
/// * `D` - the wrapped base distribution.
/// * `T` - the value type of `D`.
/// * `U` - the condition type of `D`, which is also the value type of the table.
#[derive(Clone, Debug)]
pub struct SwitchedDistribution<'a, D, T, U>
where
    D: Distribution<Value = T, Condition = U>,
    T: RandomVariable,
    U: Clone + Debug + Send + Sync,
{
    /// The base distribution that kernel evaluation and sampling are forwarded to.
    distribution: &'a D,
    /// Lookup table from a `u32` key to a condition of the base distribution.
    map: &'a HashMap<u32, U>,
}

/// Errors specific to resolving the condition of a `SwitchedDistribution`.
#[derive(thiserror::Error, Debug)]
pub enum SwitchedError {
    /// A `SwitchedParams::Key` named a key that is absent from the lookup table.
    #[error("Key not found")]
    KeyNotFound,
    /// Catch-all variant; nothing in the visible part of this file constructs it.
    #[error("Unknown error")]
    Unknown,
}

impl<'a, D, T, U> SwitchedDistribution<'a, D, T, U>
where
    D: Distribution<Value = T, Condition = U>,
    T: RandomVariable,
    U: Event,
{
    pub fn new(distribution: &'a D, map: &'a HashMap<u32, U>) -> Self {
        Self { distribution, map }
    }

    pub fn distribution(&self) -> &D {
        &self.distribution
    }
}

impl<'a, D, T, U> Distribution for SwitchedDistribution<'a, D, T, U>
where
    D: Distribution<Value = T, Condition = U>,
    T: RandomVariable,
    U: Event,
{
    type Value = T;
    type Condition = SwitchedParams<U>;

    /// Evaluates the base distribution's kernel at `x` under the condition selected by `theta`.
    ///
    /// A `SwitchedParams::Direct` condition is used as given; a `SwitchedParams::Key`
    /// is first looked up in the table this distribution was built with.
    ///
    /// # Errors
    ///
    /// Returns `DistributionError::InvalidParameters` wrapping
    /// `SwitchedError::KeyNotFound` when the key is not in the table. Any error
    /// from the base distribution's own `p_kernel` is passed through unchanged.
    fn p_kernel(&self, x: &Self::Value, theta: &Self::Condition) -> Result<f64, DistributionError> {
        // Resolve the effective condition first so there is a single delegated call below.
        let resolved = match theta {
            SwitchedParams::Direct(direct) => direct,
            SwitchedParams::Key(key) => self.map.get(key).ok_or_else(|| {
                DistributionError::InvalidParameters(SwitchedError::KeyNotFound.into())
            })?,
        };

        self.distribution.p_kernel(x, resolved)
    }
}

/// Extension trait that lets a distribution be wrapped in a `SwitchedDistribution`.
///
/// `U` is the value type of the lookup table passed to `switch`.
pub trait SwitchableDistribution<U>: Distribution + Sized
where
    U: Event,
{
    /// Wraps `self` together with `map` in a `SwitchedDistribution`.
    ///
    /// Both `self` and `map` are borrowed for the same lifetime `'a`, which is
    /// also the lifetime of the returned wrapper.
    ///
    /// NOTE(review): the return type is spelled with `Self::Condition` rather
    /// than `U`; the blanket impl in this file ties the two together through
    /// its `Condition = U` bound.
    fn switch<'a>(
        &'a self,
        map: &'a HashMap<u32, U>,
    ) -> SwitchedDistribution<'a, Self, Self::Value, Self::Condition>;
}

/// Blanket implementation: every distribution whose condition type is `U`
/// can be switched over a `HashMap<u32, U>`.
impl<D, T, U> SwitchableDistribution<U> for D
where
    D: Distribution<Value = T, Condition = U>,
    T: RandomVariable,
    U: Event,
{
    /// Borrows `self` and `map` for `'a` and wraps them in a `SwitchedDistribution`.
    fn switch<'a>(&'a self, map: &'a HashMap<u32, U>) -> SwitchedDistribution<'a, D, T, U> {
        SwitchedDistribution::new(self, map)
    }
}

/// `switched * rhs` pairs this distribution with another distribution that has
/// the same `SwitchedParams<U>` condition type, producing an `IndependentJoint`.
impl<'a, D, T, U, Rhs, TRhs> Mul<Rhs> for SwitchedDistribution<'a, D, T, U>
where
    D: Distribution<Value = T, Condition = U>,
    T: RandomVariable,
    U: Event,
    Rhs: Distribution<Value = TRhs, Condition = SwitchedParams<U>>,
    TRhs: RandomVariable,
{
    type Output = IndependentJoint<Self, Rhs, T, TRhs, SwitchedParams<U>>;

    /// Consumes both operands and wraps them in an `IndependentJoint`.
    fn mul(self, rhs: Rhs) -> Self::Output {
        <Self::Output>::new(self, rhs)
    }
}

/// `switched & rhs` pairs this distribution with a distribution whose *value*
/// type is this distribution's `SwitchedParams<U>` condition, producing a
/// `DependentJoint` conditioned on `URhs`.
impl<'a, D, T, U, Rhs, URhs> BitAnd<Rhs> for SwitchedDistribution<'a, D, T, U>
where
    D: Distribution<Value = T, Condition = U>,
    T: RandomVariable,
    U: Event,
    Rhs: Distribution<Value = SwitchedParams<U>, Condition = URhs>,
    URhs: Event,
{
    type Output = DependentJoint<Self, Rhs, T, SwitchedParams<U>, URhs>;

    /// Consumes both operands and wraps them in a `DependentJoint`.
    fn bitand(self, rhs: Rhs) -> Self::Output {
        <Self::Output>::new(self, rhs)
    }
}

impl<'a, D, T, U> SampleableDistribution for SwitchedDistribution<'a, D, T, U>
where
    D: SampleableDistribution<Value = T, Condition = U>,
    T: RandomVariable,
    U: Event,
{
    /// Draws a sample from the base distribution under the condition selected by `theta`.
    ///
    /// A `SwitchedParams::Direct` condition is used as given; a `SwitchedParams::Key`
    /// is first looked up in the table this distribution was built with.
    ///
    /// # Errors
    ///
    /// Returns `DistributionError::InvalidParameters` wrapping
    /// `SwitchedError::KeyNotFound` when the key is not in the table. Any error
    /// from the base distribution's own `sample` is passed through unchanged.
    fn sample(
        &self,
        theta: &Self::Condition,
        rng: &mut dyn RngCore,
    ) -> Result<Self::Value, DistributionError> {
        // Resolve the effective condition first; `rng` is only touched by the base sampler.
        let resolved = match theta {
            SwitchedParams::Direct(direct) => direct,
            SwitchedParams::Key(key) => self.map.get(key).ok_or_else(|| {
                DistributionError::InvalidParameters(SwitchedError::KeyNotFound.into())
            })?,
        };

        self.distribution.sample(resolved, rng)
    }
}

#[cfg(test)]
mod tests {
    use crate::*;
    use std::collections::HashMap;

    /// Selecting a condition by key must give exactly the same kernel value as
    /// passing that condition straight to the base distribution.
    #[test]
    fn it_works() {
        let sigma = 2.0;
        // Keys 1..=4 map to Normal parameters with mean equal to the key.
        let theta: HashMap<u32, _> = (1u32..=4)
            .map(|key| (key, NormalParams::new(f64::from(key), sigma).unwrap()))
            .collect();
        let x = 0f64;

        let expected = Normal
            .p_kernel(&x, &NormalParams::new(1.0, sigma).unwrap())
            .unwrap();
        let actual = Normal
            .switch(&theta)
            .p_kernel(&x, &SwitchedParams::Key(1u32))
            .unwrap();

        assert_eq!(actual, expected);
    }
}