use crate::{DependentJoint, Distribution, IndependentJoint, RandomVariable};
use crate::{DistributionError, SampleableDistribution};
use rand::prelude::*;
use rand_distr::Dirichlet as RandDirichlet;
use rayon::{iter::IntoParallelIterator, prelude::*};
use std::{ops::BitAnd, ops::Mul};
/// The Dirichlet distribution over vectors of `f64`.
///
/// This is a stateless marker type: the concentration parameters α are not stored here but
/// are passed to every call as a [`DirichletParams`] condition.
#[derive(Debug, Clone)]
pub struct Dirichlet;
#[derive(thiserror::Error, Debug)]
pub enum DirichletError {
#[error("Dimension mismatch")]
DimensionMismatch,
#[error("Length of 'α' must be >= 2")]
AlphaLenMustBeGTE2,
#[error("'α' must be positibe")]
AlphaMustBePositive,
#[error("Unknown error")]
Unknown,
}
impl Distribution for Dirichlet {
    type Value = Vec<f64>;
    type Condition = DirichletParams;

    /// Unnormalized Dirichlet density at `x`: the product of `x_i^(α_i − 1)`.
    ///
    /// No Beta-function normalizer is applied, and `x` is not checked to lie on the simplex.
    ///
    /// # Errors
    ///
    /// Returns `DistributionError::InvalidParameters` wrapping
    /// `DirichletError::DimensionMismatch` when `x` and `α` differ in length.
    fn p_kernel(&self, x: &Self::Value, theta: &Self::Condition) -> Result<f64, DistributionError> {
        let alpha = theta.alpha();
        if x.len() == alpha.len() {
            // Per-component factors are computed in parallel and reduced by rayon.
            let kernel = x
                .into_par_iter()
                .zip(alpha.into_par_iter())
                .map(|(x_i, alpha_i)| x_i.powf(alpha_i - 1.0))
                .product::<f64>();
            Ok(kernel)
        } else {
            Err(DistributionError::InvalidParameters(
                DirichletError::DimensionMismatch.into(),
            ))
        }
    }
}
/// Concentration parameters α of a [`Dirichlet`] distribution.
///
/// Construct through `DirichletParams::new`, which checks the length and sign of α.
#[derive(Debug, Clone, PartialEq)]
pub struct DirichletParams {
    /// One concentration parameter per component.
    alpha: Vec<f64>,
}
impl DirichletParams {
    /// Builds validated Dirichlet parameters from the concentration vector `alpha`.
    ///
    /// # Errors
    ///
    /// Returns `DistributionError::InvalidParameters` wrapping
    /// - `DirichletError::AlphaLenMustBeGTE2` if `alpha` has fewer than two entries;
    /// - `DirichletError::AlphaMustBePositive` if any entry is zero, negative, or NaN.
    pub fn new(alpha: Vec<f64>) -> Result<Self, DistributionError> {
        if alpha.len() < 2 {
            return Err(DistributionError::InvalidParameters(
                DirichletError::AlphaLenMustBeGTE2.into(),
            ));
        }
        // NaN must be rejected explicitly: `NaN <= 0.0` is false, so a plain sign test would
        // accept it and later make `p_kernel` return NaN.
        if alpha.iter().any(|&alpha_i| alpha_i.is_nan() || alpha_i <= 0.0) {
            return Err(DistributionError::InvalidParameters(
                DirichletError::AlphaMustBePositive.into(),
            ));
        }
        Ok(Self { alpha })
    }

    /// Concentration parameters α, one per component.
    pub fn alpha(&self) -> &[f64] {
        &self.alpha
    }
}
impl RandomVariable for DirichletParams {
    /// Number of components, recorded by `transform_vec` so `restore` can validate its input.
    type RestoreInfo = usize;

    /// Flattens the parameters into the α vector itself, paired with its length.
    fn transform_vec(&self) -> (Vec<f64>, Self::RestoreInfo) {
        // Clone only the field that is returned rather than the whole struct.
        (self.alpha.clone(), self.alpha.len())
    }

    fn len(&self) -> usize {
        self.alpha.len()
    }

    /// Rebuilds parameters from a vector produced by `transform_vec`.
    ///
    /// # Errors
    ///
    /// Returns `DistributionError::InvalidRestoreVector` if `info` is below 2 or `v` does not
    /// have exactly `info` elements. Otherwise it propagates any error from
    /// `DirichletParams::new`, e.g. for non-positive entries.
    fn restore(v: &[f64], info: &Self::RestoreInfo) -> Result<Self, DistributionError> {
        let n = *info;
        // Without the length check, a wrong-sized vector would silently become a
        // different-dimensional parameter set.
        if n < 2 || v.len() != n {
            return Err(DistributionError::InvalidRestoreVector);
        }
        Self::new(v.to_vec())
    }
}
/// `dirichlet * rhs` combines this distribution with another distribution that takes the same
/// `DirichletParams` condition, yielding an [`IndependentJoint`].
impl<R, V> Mul<R> for Dirichlet
where
    R: Distribution<Value = V, Condition = DirichletParams>,
    V: RandomVariable,
{
    type Output = IndependentJoint<Dirichlet, R, Vec<f64>, V, DirichletParams>;

    fn mul(self, rhs: R) -> Self::Output {
        IndependentJoint::new(self, rhs)
    }
}
/// `dirichlet & rhs` combines this distribution with a distribution whose value is a
/// `DirichletParams` (the condition this distribution needs), yielding a [`DependentJoint`].
impl<R, C> BitAnd<R> for Dirichlet
where
    R: Distribution<Value = DirichletParams, Condition = C>,
    C: RandomVariable,
{
    type Output = DependentJoint<Dirichlet, R, Vec<f64>, DirichletParams, C>;

    fn bitand(self, rhs: R) -> Self::Output {
        DependentJoint::new(self, rhs)
    }
}
impl SampleableDistribution for Dirichlet {
    /// Draws one vector using `rand_distr`'s Dirichlet sampler with `theta`'s α.
    ///
    /// # Errors
    ///
    /// Returns `DistributionError::Others` wrapping the `rand_distr` error if the sampler
    /// rejects α.
    fn sample(
        &self,
        theta: &Self::Condition,
        rng: &mut dyn RngCore,
    ) -> Result<Self::Value, DistributionError> {
        let sampler = RandDirichlet::new(theta.alpha())
            .map_err(|err| DistributionError::Others(err.into()))?;
        Ok(rng.sample(sampler))
    }
}
#[cfg(test)]
mod tests {
    use crate::{Dirichlet, DirichletParams, Distribution, SampleableDistribution};
    use rand::prelude::*;

    /// A sample must have one coordinate per α, each in [0, 1], summing to 1.
    #[test]
    fn it_works() {
        let dirichlet = Dirichlet;
        let mut rng = StdRng::from_seed([1; 32]);
        let alpha = vec![1.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let dim = alpha.len();
        let x = dirichlet
            .sample(&DirichletParams::new(alpha).unwrap(), &mut rng)
            .unwrap();
        assert_eq!(x.len(), dim);
        assert!(x.iter().all(|&xi| (0.0..=1.0).contains(&xi)));
        let total: f64 = x.iter().sum();
        assert!((total - 1.0).abs() < 1e-10, "sample does not sum to 1: {}", total);
    }

    /// With α = (2, 3) and x = (0.5, 0.5) the kernel is 0.5^1 · 0.5^2 = 0.125.
    #[test]
    fn p_kernel_matches_closed_form() {
        let params = DirichletParams::new(vec![2.0, 3.0]).unwrap();
        let k = Dirichlet.p_kernel(&vec![0.5, 0.5], &params).unwrap();
        assert!((k - 0.125).abs() < 1e-12);
    }

    #[test]
    fn p_kernel_rejects_dimension_mismatch() {
        let params = DirichletParams::new(vec![1.0, 2.0]).unwrap();
        assert!(Dirichlet
            .p_kernel(&vec![0.5, 0.25, 0.25], &params)
            .is_err());
    }

    #[test]
    fn params_reject_invalid_alpha() {
        assert!(DirichletParams::new(vec![1.0]).is_err());
        assert!(DirichletParams::new(vec![1.0, 0.0]).is_err());
        assert!(DirichletParams::new(vec![1.0, -2.0]).is_err());
        assert!(DirichletParams::new(vec![1.0, 2.0]).is_ok());
    }
}