#![cfg(feature = "alloc")]
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal, multi::MultiDistribution};
use core::fmt;
use num_traits::{Float, NumCast};
use rand::Rng;
#[cfg(feature = "serde")]
use serde_with::serde_as;
use alloc::{boxed::Box, vec, vec::Vec};
/// Dirichlet sampler that draws one Gamma variate per component and then
/// rescales the draws by the reciprocal of their sum.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", serde_as)]
struct DirichletFromGamma<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// One `Gamma` sampler per concentration parameter, in the order of `alpha`.
samplers: Vec<Gamma<F>>,
}
/// Error returned by `DirichletFromGamma::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromGammaError {
/// `Gamma::new` rejected one of the concentration parameters.
// NOTE(review): the variant name is misspelled ("Gammma"); renaming it also
// requires updating `DirichletFromGamma::new`, which constructs it.
GammmaNewFailed,
}
impl<F> DirichletFromGamma<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Builds a sampler holding one `Gamma` distribution per entry of `alpha`.
    ///
    /// Every `Gamma` uses the corresponding `alpha` entry as its first
    /// parameter and `F::one()` as its second.
    ///
    /// # Errors
    ///
    /// Returns `DirichletFromGammaError::GammmaNewFailed` as soon as
    /// `Gamma::new` rejects one of the entries.
    #[inline]
    fn new(alpha: &[F]) -> Result<DirichletFromGamma<F>, DirichletFromGammaError> {
        let samplers = alpha
            .iter()
            .map(|&shape| {
                Gamma::new(shape, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)
            })
            // Collecting into `Result` stops at the first failing entry.
            .collect::<Result<Vec<_>, _>>()?;

        Ok(DirichletFromGamma { samplers })
    }
}
impl<F> MultiDistribution<F> for DirichletFromGamma<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Number of components in each sample: one per stored Gamma sampler.
    #[inline]
    fn sample_len(&self) -> usize {
        self.samplers.len()
    }

    /// Fills `output` with one Gamma draw per component and rescales the
    /// draws by the reciprocal of their sum.
    ///
    /// # Panics
    ///
    /// Panics if `output.len()` differs from `self.sample_len()`.
    fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
        assert_eq!(output.len(), self.sample_len());

        // First pass: store the raw Gamma draws and accumulate their total.
        let mut total = F::zero();
        for (gamma, slot) in self.samplers.iter().zip(output.iter_mut()) {
            let draw: F = gamma.sample(rng);
            *slot = draw;
            total = total + draw;
        }

        // Second pass: multiply by the reciprocal, computed once.
        let scale = F::one() / total;
        output.iter_mut().for_each(|slot| *slot = *slot * scale);
    }
}
/// Dirichlet sampler built from a chain of Beta distributions: each component
/// is a Beta draw times the mass left over by the preceding components, and
/// the final component receives whatever mass remains.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct DirichletFromBeta<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// `alpha.len() - 1` Beta samplers; the `k`-th one is built from `alpha[k]`
/// and the sum of all later entries of `alpha`.
samplers: Box<[Beta<F>]>,
}
/// Error returned by `DirichletFromBeta::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromBetaError {
/// `Beta::new` rejected one of the parameter pairs.
BetaNewFailed,
}
impl<F> DirichletFromBeta<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Builds the chain of `Beta` samplers.
    ///
    /// With `n = alpha.len()`, the `k`-th sampler (for `k` in `0..n - 1`) is
    /// `Beta::new(alpha[k], alpha[k + 1] + ... + alpha[n - 1])`.
    ///
    /// `alpha` must hold at least two entries: the index arithmetic below
    /// panics for shorter slices. The only caller in this file,
    /// `Dirichlet::new`, checks the length beforehand.
    ///
    /// # Errors
    ///
    /// Returns `DirichletFromBetaError::BetaNewFailed` as soon as `Beta::new`
    /// rejects one of the parameter pairs.
    #[inline]
    fn new(alpha: &[F]) -> Result<DirichletFromBeta<F>, DirichletFromBetaError> {
        let n = alpha.len();

        // `tail_sums[k]` ends up as the sum of `alpha[k + 1..]`, accumulated
        // starting from the last entry. Every slot starts as `alpha[n - 1]`,
        // which is already the final value for the last slot.
        let mut tail_sums = vec![alpha[n - 1]; n - 1];
        for k in (0..(n - 2)).rev() {
            tail_sums[k] = tail_sums[k + 1] + alpha[k + 1];
        }

        let samplers = alpha[..(n - 1)]
            .iter()
            .zip(&tail_sums)
            .map(|(&a, &b)| Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed))
            // Collecting into `Result` stops at the first failing pair.
            .collect::<Result<Vec<_>, _>>()?
            .into_boxed_slice();

        Ok(DirichletFromBeta { samplers })
    }
}
impl<F> MultiDistribution<F> for DirichletFromBeta<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Number of components in each sample: one more than the number of
    /// stored Beta samplers.
    #[inline]
    fn sample_len(&self) -> usize {
        self.samplers.len() + 1
    }

    /// Fills `output` component by component: each Beta draw takes a fraction
    /// of the mass still remaining, and the last slot gets what is left.
    ///
    /// # Panics
    ///
    /// Panics if `output.len()` differs from `self.sample_len()`.
    fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
        assert_eq!(output.len(), self.sample_len());

        // `sample_len()` is at least 1, so `output` is non-empty past the assertion.
        let (last, leading) = output
            .split_last_mut()
            .expect("output holds at least one element");

        let mut remaining = F::one();
        for (beta, slot) in self.samplers.iter().zip(leading.iter_mut()) {
            let fraction: F = beta.sample(rng);
            *slot = remaining * fraction;
            remaining = remaining * (F::one() - fraction);
        }
        *last = remaining;
    }
}
/// Internal sampling strategy of a `Dirichlet`, selected by `Dirichlet::new`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", serde_as)]
enum DirichletRepr<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Gamma-based sampler; chosen when at least one entry of `alpha` exceeds 0.1.
FromGamma(DirichletFromGamma<F>),
/// Beta-based sampler; chosen when every entry of `alpha` is at most 0.1.
FromBeta(DirichletFromBeta<F>),
}
/// The Dirichlet distribution, parameterized by a slice of concentration
/// parameters `alpha`.
///
/// Each sample has one entry per element of `alpha`. `Dirichlet::new`
/// validates `alpha` and chooses between a Gamma-based and a Beta-based
/// internal sampler depending on the magnitude of the entries.
#[cfg_attr(feature = "serde", serde_as)]
#[derive(Clone, Debug, PartialEq)]
pub struct Dirichlet<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// The sampling strategy selected at construction time.
repr: DirichletRepr<F>,
}
/// Error type returned from `Dirichlet::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `alpha` has fewer than two entries.
AlphaTooShort,
/// An entry of `alpha` is not greater than zero; NaN entries are reported
/// with this variant as well.
AlphaTooSmall,
/// An entry of `alpha` is positive and finite but not a normal
/// floating-point number.
AlphaSubnormal,
/// An entry of `alpha` is infinite.
AlphaInfinite,
/// Constructing the Gamma-based sampler failed.
FailedToCreateGamma,
/// Constructing the Beta-based sampler failed.
FailedToCreateBeta,
/// Fewer than two dimensions were requested.
// NOTE(review): nothing in this file returns this variant; presumably it is
// kept for other callers or for compatibility — confirm before removing.
SizeTooSmall,
}
impl fmt::Display for Error {
    /// Writes the fixed, human-readable description of this error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            // Both variants describe the same problem and share one message.
            Error::AlphaTooShort | Error::SizeTooSmall => {
                "less than 2 dimensions in Dirichlet distribution"
            }
            Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution",
            Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution",
            Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution",
            Error::FailedToCreateGamma => {
                "failed to create required Gamma distribution for Dirichlet distribution"
            }
            Error::FailedToCreateBeta => {
                "failed to create required Beta distribution for Dirichlet distribution"
            }
        };
        f.write_str(message)
    }
}
// Empty impl of `std::error::Error` for `Error`, compiled only when the `std`
// feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> Dirichlet<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
#[inline]
pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
if alpha.len() < 2 {
return Err(Error::AlphaTooShort);
}
for &ai in alpha.iter() {
if !(ai > F::zero()) {
return Err(Error::AlphaTooSmall);
}
if ai.is_infinite() {
return Err(Error::AlphaInfinite);
}
if !ai.is_normal() {
return Err(Error::AlphaSubnormal);
}
}
if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) {
let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?;
Ok(Dirichlet {
repr: DirichletRepr::FromBeta(dist),
})
} else {
let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?;
Ok(Dirichlet {
repr: DirichletRepr::FromGamma(dist),
})
}
}
}
impl<F> MultiDistribution<F> for Dirichlet<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Number of components in each sample, as reported by the active sampler.
    #[inline]
    fn sample_len(&self) -> usize {
        match self.repr {
            DirichletRepr::FromGamma(ref inner) => inner.sample_len(),
            DirichletRepr::FromBeta(ref inner) => inner.sample_len(),
        }
    }

    /// Writes one sample into `output` by delegating to the active sampler.
    ///
    /// # Panics
    ///
    /// Panics if `output.len()` differs from `self.sample_len()`; the check
    /// is performed by the sampler that receives the call.
    fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
        match self.repr {
            DirichletRepr::FromGamma(ref inner) => inner.sample_to_slice(rng, output),
            DirichletRepr::FromBeta(ref inner) => inner.sample_to_slice(rng, output),
        }
    }
}
// Allows a `Dirichlet<F>` to be sampled as an owned `Vec<F>`; this impl
// additionally requires `F: Default`.
impl<F> Distribution<Vec<F>> for Dirichlet<F>
where
F: Float + Default,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
// NOTE(review): the body is generated by `distribution_impl!`, which is not
// defined in this file; presumably it implements `sample` on top of the
// `MultiDistribution` methods — confirm against the macro definition.
distribution_impl!(F);
}
#[cfg(test)]
mod test {
    use super::*;

    /// Every component of a sample from a valid distribution is positive.
    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
        let mut rng = crate::test::rng(221);
        let samples = d.sample(&mut rng);
        assert!(samples.into_iter().all(|x: f64| x > 0.0));
    }

    // The rejection tests below assert the exact error variant rather than
    // relying on `#[should_panic]`, which would accept any panic at all.

    #[test]
    fn test_dirichlet_invalid_length() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.5]).unwrap_err(),
            Error::AlphaTooShort
        );
    }

    #[test]
    fn test_dirichlet_alpha_zero() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.1, 0.0, 0.3]).unwrap_err(),
            Error::AlphaTooSmall
        );
    }

    #[test]
    fn test_dirichlet_alpha_negative() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.1, -1.5, 0.3]).unwrap_err(),
            Error::AlphaTooSmall
        );
    }

    #[test]
    fn test_dirichlet_alpha_nan() {
        // NaN fails the "greater than zero" check, so it is reported as too small.
        assert_eq!(
            Dirichlet::<f64>::new(&[0.5, f64::NAN, 0.25]).unwrap_err(),
            Error::AlphaTooSmall
        );
    }

    #[test]
    fn test_dirichlet_alpha_subnormal() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.5, 1.5e-321, 0.25]).unwrap_err(),
            Error::AlphaSubnormal
        );
    }

    #[test]
    fn test_dirichlet_alpha_inf() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.5, f64::INFINITY, 0.25]).unwrap_err(),
            Error::AlphaInfinite
        );
    }

    #[test]
    fn dirichlet_distributions_can_be_compared() {
        assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0]));
    }

    /// Draws `n` samples from `Dirichlet(alpha)` and checks that each
    /// component's sample mean is within `rtol` of `alpha[i] / sum(alpha)`.
    fn check_dirichlet_means<const N: usize>(alpha: [f64; N], n: i32, rtol: f64, seed: u64) {
        let d = Dirichlet::new(&alpha).unwrap();
        let mut rng = crate::test::rng(seed);

        let mut sums = [0.0; N];
        for _ in 0..n {
            let samples: Vec<f64> = d.sample(&mut rng);
            for (sum, x) in sums.iter_mut().zip(samples.iter()) {
                *sum += *x;
            }
        }
        let sample_mean = sums.map(|x| x / n as f64);

        let alpha_sum: f64 = alpha.iter().sum();
        let expected_mean = alpha.map(|x| x / alpha_sum);

        for (observed, expected) in sample_mean.iter().zip(expected_mean.iter()) {
            average::assert_almost_eq!(*observed, *expected, rtol);
        }
    }

    #[test]
    fn test_dirichlet_means() {
        let n = 20000;
        let rtol = 2e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means([0.5, 0.25], n, rtol, seed);
        check_dirichlet_means([123.0, 75.0], n, rtol, seed);
        check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed);
        check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed);
    }

    #[test]
    fn test_dirichlet_means_very_small_alpha() {
        // Every alpha is at most 0.1, so this exercises the Beta-based sampler.
        let alpha = [0.001; 3];
        let n = 10000;
        let rtol = 1e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }

    #[test]
    fn test_dirichlet_means_small_alpha() {
        // Every alpha is at most 0.1, so this exercises the Beta-based sampler.
        let alpha = [0.05, 0.025, 0.075, 0.05];
        let n = 150000;
        let rtol = 1e-3;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }
}