// opensrdk-probability 0.8.9
//
// Standard probability library for the OpenSRDK toolchain.
// Documentation
use opensrdk_linear_algebra::{pp::trf::PPTRF, Matrix, SymmetricPackedMatrix};
use std::fmt::Debug;

use crate::DistributionError;

pub trait RandomVariable: Clone + Debug + Send + Sync {
    type RestoreInfo: Eq;
    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo);
    fn len(&self) -> usize;
    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError>;
}

impl RandomVariable for () {
    type RestoreInfo = ();

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        (vec![], ())
    }

    fn len(&self) -> usize {
        0
    }

    fn restore(v: &[f64], _: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        if v.len() != 0 {
            return Err(DistributionError::InvalidRestoreVector);
        }
        Ok(())
    }
}

impl RandomVariable for f64 {
    type RestoreInfo = ();

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        (vec![*self], ())
    }

    fn len(&self) -> usize {
        1
    }

    fn restore(v: &[f64], _: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        if v.len() != 1 {
            return Err(DistributionError::InvalidRestoreVector);
        }
        Ok(v[0])
    }
}

impl RandomVariable for u64 {
    type RestoreInfo = u64;

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        (vec![], *self)
    }

    fn len(&self) -> usize {
        1
    }

    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        if v.len() != 0 {
            return Err(DistributionError::InvalidRestoreVector);
        }
        Ok(*info)
    }
}

impl RandomVariable for usize {
    type RestoreInfo = usize;

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        (vec![], *self)
    }

    fn len(&self) -> usize {
        1
    }

    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        if v.len() != 0 {
            return Err(DistributionError::InvalidRestoreVector);
        }
        Ok(*info)
    }
}

impl RandomVariable for bool {
    type RestoreInfo = bool;

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        (vec![], *self)
    }

    fn len(&self) -> usize {
        1
    }

    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        if v.len() != 0 {
            return Err(DistributionError::InvalidRestoreVector);
        }
        Ok(*info)
    }
}

impl RandomVariable for Matrix {
    type RestoreInfo = usize;

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        let rows = self.rows();
        (self.clone().vec(), rows)
    }

    fn len(&self) -> usize {
        self.cols() * self.rows()
    }

    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        if v.len() != info * info {
            return Err(DistributionError::InvalidRestoreVector);
        }
        Ok(Matrix::from(*info, v.to_vec()).unwrap())
    }
}

impl RandomVariable for PPTRF {
    type RestoreInfo = usize;

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        let n = self.0.dim();
        (self.0.elems().to_vec(), n)
    }

    fn len(&self) -> usize {
        self.0.elems().len()
    }

    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        if v.len() != info + (info + 1) / 2 {
            return Err(DistributionError::InvalidRestoreVector);
        }
        Ok(PPTRF(
            SymmetricPackedMatrix::from(*info, v.to_vec()).unwrap(),
        ))
    }
}

impl<T, U> RandomVariable for (T, U)
where
    T: RandomVariable,
    U: RandomVariable,
{
    type RestoreInfo = (usize, T::RestoreInfo, U::RestoreInfo);

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        let t = self.0.transform_vec();
        let u = self.1.transform_vec();
        let len = t.0.len();

        ([t.0, u.0].concat(), (len, t.1, u.1))
    }

    fn len(&self) -> usize {
        self.0.len() + self.1.len()
    }

    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        let (len, t_1, u_1) = info;
        let t_0 = &v[0..*len];
        let u_0 = &v[*len..];

        Ok((T::restore(t_0, t_1)?, U::restore(u_0, u_1)?))
    }
}

impl<T> RandomVariable for Vec<T>
where
    T: RandomVariable,
{
    type RestoreInfo = (Vec<usize>, Vec<T::RestoreInfo>);

    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        let len = self.len();
        let mut t_0_vec = vec![];
        let mut len_vec = vec![];
        let mut t_1_vec = vec![];
        for i in 0..len {
            t_0_vec = [t_0_vec, self[i].transform_vec().0].concat();
            len_vec.push(self[i].transform_vec().0.len());
            t_1_vec.push(self[i].transform_vec().1);
        }
        (t_0_vec, (len_vec, t_1_vec))
    }

    fn len(&self) -> usize {
        self.iter().map(|self_i| self_i.len()).sum::<usize>()
    }

    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        let len_vec = &info.0;
        let t_1_vec = &info.1;
        if len_vec.len() != t_1_vec.len() {
            return Err(DistributionError::InvalidRestoreVector);
        }
        let len = len_vec.len();
        let mut t_vec = vec![];
        let mut n = 0;

        for i in 0..len {
            let len_i = len_vec[i];
            let t_1_i = &t_1_vec[i];
            t_vec.push(T::restore(&v[n..n + len_i], t_1_i)?);
            n += len_vec[i];
        }
        Ok(t_vec)
    }
}