#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::impl_display;
use crate::misc::vec_to_string;
use crate::traits::*;
use once_cell::sync::OnceCell;
use rand::Rng;
use rand_distr::Gamma as RGamma;
use special::Gamma as _;
use std::fmt;
mod categorical_prior;
/// Symmetric Dirichlet distribution over length-`k` probability vectors in
/// which every component shares the same concentration `alpha`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct SymmetricDirichlet {
    /// Concentration parameter shared by all `k` components.
    alpha: f64,
    /// Number of components; supported vectors must have exactly this length.
    k: usize,
    /// Lazily computed `ln Γ(alpha)`. It is skipped by serde and replaced
    /// with an empty cell whenever `alpha` is changed through the setters.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_gamma_alpha: OnceCell<f64>,
}
impl PartialEq for SymmetricDirichlet {
    /// Two symmetric Dirichlets are equal when both their dimension and their
    /// concentration match. The lazily filled `ln Γ(alpha)` cache is
    /// deliberately ignored, so a primed and an unprimed instance compare
    /// equal.
    fn eq(&self, other: &Self) -> bool {
        let same_dimension = self.k == other.k;
        same_dimension && self.alpha == other.alpha
    }
}
/// Errors produced when constructing or updating a [`SymmetricDirichlet`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum SymmetricDirichletError {
    /// The requested number of components `k` was zero.
    KIsZero,
    /// `alpha` was less than or equal to zero (checked before finiteness,
    /// so negative infinity lands here).
    AlphaTooLow { alpha: f64 },
    /// `alpha` was non-finite: positive infinity or NaN.
    AlphaNotFinite { alpha: f64 },
}
impl SymmetricDirichlet {
    /// Create a symmetric Dirichlet with concentration `alpha` shared by all
    /// `k` components.
    ///
    /// # Errors
    ///
    /// - [`SymmetricDirichletError::KIsZero`] if `k == 0` (checked first).
    /// - [`SymmetricDirichletError::AlphaTooLow`] if `alpha <= 0.0`.
    /// - [`SymmetricDirichletError::AlphaNotFinite`] if `alpha` is positive
    ///   infinity or NaN.
    #[inline]
    pub fn new(alpha: f64, k: usize) -> Result<Self, SymmetricDirichletError> {
        if k == 0 {
            return Err(SymmetricDirichletError::KIsZero);
        }
        Self::validate_alpha(alpha)?;
        Ok(Self::new_unchecked(alpha, k))
    }

    /// Create a symmetric Dirichlet without validating `alpha` or `k`.
    ///
    /// The caller is responsible for passing `k > 0` and a positive, finite
    /// `alpha`; other methods (e.g. sampling) assume those invariants.
    #[inline]
    pub fn new_unchecked(alpha: f64, k: usize) -> Self {
        Self {
            alpha,
            k,
            ln_gamma_alpha: OnceCell::new(),
        }
    }

    /// The Jeffreys prior: a symmetric Dirichlet with `alpha = 0.5`.
    ///
    /// # Errors
    ///
    /// [`SymmetricDirichletError::KIsZero`] if `k == 0`.
    #[inline]
    pub fn jeffreys(k: usize) -> Result<Self, SymmetricDirichletError> {
        if k == 0 {
            Err(SymmetricDirichletError::KIsZero)
        } else {
            Ok(Self::new_unchecked(0.5, k))
        }
    }

    /// The shared concentration parameter.
    #[inline]
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Set the concentration parameter after validating it.
    ///
    /// On error `self` is left unchanged.
    ///
    /// # Errors
    ///
    /// - [`SymmetricDirichletError::AlphaTooLow`] if `alpha <= 0.0`.
    /// - [`SymmetricDirichletError::AlphaNotFinite`] if `alpha` is positive
    ///   infinity or NaN.
    #[inline]
    pub fn set_alpha(
        &mut self,
        alpha: f64,
    ) -> Result<(), SymmetricDirichletError> {
        Self::validate_alpha(alpha)?;
        // `set_alpha_unchecked` already invalidates the ln Γ(α) cache, so a
        // second reset here would be redundant.
        self.set_alpha_unchecked(alpha);
        Ok(())
    }

    /// Set the concentration parameter without validation and invalidate
    /// the cached `ln Γ(alpha)`.
    #[inline]
    pub fn set_alpha_unchecked(&mut self, alpha: f64) {
        self.alpha = alpha;
        self.ln_gamma_alpha = OnceCell::new();
    }

    /// Number of components.
    #[inline]
    pub fn k(&self) -> usize {
        self.k
    }

    /// `ln Γ(alpha)`, computed on first use and cached until `alpha` is
    /// changed through one of the setters.
    #[inline]
    fn ln_gamma_alpha(&self) -> f64 {
        *self.ln_gamma_alpha.get_or_init(|| self.alpha.ln_gamma().0)
    }

    /// Shared validation for a candidate concentration parameter.
    ///
    /// The `<= 0.0` test runs first. NaN fails every comparison, so it falls
    /// through to `AlphaNotFinite`, matching the behavior of both public
    /// checked entry points.
    #[inline]
    fn validate_alpha(alpha: f64) -> Result<(), SymmetricDirichletError> {
        if alpha <= 0.0 {
            Err(SymmetricDirichletError::AlphaTooLow { alpha })
        } else if !alpha.is_finite() {
            Err(SymmetricDirichletError::AlphaNotFinite { alpha })
        } else {
            Ok(())
        }
    }
}
impl From<&SymmetricDirichlet> for String {
    /// Human-readable summary of the form `SymmetricDirichlet(<k>; α: <alpha>)`.
    fn from(sd: &SymmetricDirichlet) -> String {
        let (k, alpha) = (sd.k, sd.alpha);
        format!("SymmetricDirichlet({}; α: {})", k, alpha)
    }
}
// Generate a `Display` impl via the crate-level `impl_display!` macro.
// NOTE(review): presumably this forwards to the
// `From<&SymmetricDirichlet> for String` impl above; confirm against the macro.
impl_display!(SymmetricDirichlet);
impl Rv<Vec<f64>> for SymmetricDirichlet {
    /// Draw a length-`k` probability vector by sampling `k` independent
    /// Gamma(alpha, 1) variates and dividing each by their total.
    fn draw<R: Rng>(&self, rng: &mut R) -> Vec<f64> {
        let gamma = RGamma::new(self.alpha, 1.0).unwrap();
        let mut draws: Vec<f64> = Vec::with_capacity(self.k);
        for _ in 0..self.k {
            draws.push(rng.sample(&gamma));
        }
        let total: f64 = draws.iter().sum();
        for d in draws.iter_mut() {
            *d /= total;
        }
        draws
    }

    /// Log density: `sum_i (alpha - 1) ln x_i - [k ln Γ(alpha) - ln Γ(k alpha)]`.
    ///
    /// The per-component `ln Γ(alpha)` comes from the instance cache.
    fn ln_f(&self, x: &Vec<f64>) -> f64 {
        let k_float = self.k as f64;
        let log_gamma_terms = self.ln_gamma_alpha() * k_float;
        let log_gamma_total = (self.alpha * k_float).ln_gamma().0;
        let alpha_minus_one = self.alpha - 1.0;

        let mut kernel = 0.0_f64;
        for &xi in x.iter() {
            kernel = alpha_minus_one.mul_add(xi.ln(), kernel);
        }
        kernel - (log_gamma_terms - log_gamma_total)
    }
}
/// Errors produced when constructing a [`Dirichlet`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum DirichletError {
    /// `k` was zero (returned by `Dirichlet::symmetric` and `Dirichlet::jeffreys`).
    KIsZero,
    /// The `alphas` vector passed to `Dirichlet::new` was empty.
    AlphasEmpty,
    /// The alpha at index `ix` was less than or equal to zero.
    /// `Dirichlet::symmetric` always reports `ix: 0`.
    AlphaTooLow { ix: usize, alpha: f64 },
    /// The alpha at index `ix` was non-finite (positive infinity or NaN;
    /// negative infinity is reported as `AlphaTooLow` first).
    AlphaNotFinite { ix: usize, alpha: f64 },
}
/// General Dirichlet distribution over probability vectors whose length
/// equals the number of concentration parameters.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Dirichlet {
    /// One concentration parameter per component; `k` is `alphas.len()`.
    pub(crate) alphas: Vec<f64>,
}
impl From<SymmetricDirichlet> for Dirichlet {
fn from(symdir: SymmetricDirichlet) -> Self {
Dirichlet::new_unchecked(vec![symdir.alpha; symdir.k])
}
}
impl From<&SymmetricDirichlet> for Dirichlet {
fn from(symdir: &SymmetricDirichlet) -> Self {
Dirichlet::new_unchecked(vec![symdir.alpha; symdir.k])
}
}
impl Dirichlet {
    /// Create a Dirichlet from a vector of concentration parameters.
    ///
    /// # Errors
    ///
    /// - [`DirichletError::AlphasEmpty`] if `alphas` is empty.
    /// - [`DirichletError::AlphaTooLow`] for the first alpha `<= 0.0`.
    /// - [`DirichletError::AlphaNotFinite`] for the first alpha that is
    ///   positive infinity or NaN.
    pub fn new(alphas: Vec<f64>) -> Result<Self, DirichletError> {
        if alphas.is_empty() {
            return Err(DirichletError::AlphasEmpty);
        }
        for (ix, &alpha) in alphas.iter().enumerate() {
            Self::check_alpha(ix, alpha)?;
        }
        Ok(Self::new_unchecked(alphas))
    }

    /// Create a Dirichlet without validating `alphas`.
    #[inline]
    pub fn new_unchecked(alphas: Vec<f64>) -> Self {
        Self { alphas }
    }

    /// Create a Dirichlet with `k` components that all share `alpha`.
    ///
    /// # Errors
    ///
    /// - [`DirichletError::KIsZero`] if `k == 0` (checked first).
    /// - [`DirichletError::AlphaTooLow`] / [`DirichletError::AlphaNotFinite`]
    ///   with `ix: 0` if `alpha` is invalid.
    #[inline]
    pub fn symmetric(alpha: f64, k: usize) -> Result<Self, DirichletError> {
        if k == 0 {
            return Err(DirichletError::KIsZero);
        }
        Self::check_alpha(0, alpha)?;
        Ok(Self::new_unchecked(vec![alpha; k]))
    }

    /// The Jeffreys prior: `k` components, each with `alpha = 0.5`.
    ///
    /// # Errors
    ///
    /// [`DirichletError::KIsZero`] if `k == 0`.
    #[inline]
    pub fn jeffreys(k: usize) -> Result<Self, DirichletError> {
        // 0.5 always passes the alpha checks, so only `k` can fail here.
        Self::symmetric(0.5, k)
    }

    /// Number of components.
    #[inline]
    pub fn k(&self) -> usize {
        self.alphas.len()
    }

    /// The concentration parameters.
    #[inline]
    pub fn alphas(&self) -> &Vec<f64> {
        &self.alphas
    }

    /// Validate one concentration parameter, tagging any error with `ix`.
    ///
    /// NaN fails the `<= 0.0` comparison and is reported as non-finite.
    #[inline]
    fn check_alpha(ix: usize, alpha: f64) -> Result<(), DirichletError> {
        if alpha <= 0.0 {
            Err(DirichletError::AlphaTooLow { ix, alpha })
        } else if !alpha.is_finite() {
            Err(DirichletError::AlphaNotFinite { ix, alpha })
        } else {
            Ok(())
        }
    }
}
impl From<&Dirichlet> for String {
    /// Human-readable summary of the form `Dir(α: ...)`. The alphas are
    /// rendered by the crate's `vec_to_string` helper with argument `5`.
    /// NOTE(review): presumably `5` caps how many entries are shown; confirm.
    fn from(dir: &Dirichlet) -> String {
        let rendered_alphas = vec_to_string(&dir.alphas, 5);
        format!("Dir(α: {})", rendered_alphas)
    }
}
// Generate a `Display` impl via the crate-level `impl_display!` macro.
// NOTE(review): presumably this forwards to the `From<&Dirichlet> for String`
// impl above; confirm against the macro.
impl_display!(Dirichlet);
// Empty impl: SymmetricDirichlet uses only the default methods of
// `ContinuousDistr`, built on its `Rv` and `Support` impls.
impl ContinuousDistr<Vec<f64>> for SymmetricDirichlet {}
impl Support<Vec<f64>> for SymmetricDirichlet {
    /// `x` is supported when it has exactly `k` entries, every entry is
    /// strictly positive, and the entries sum to one within `1E-12`.
    fn supports(&self, x: &Vec<f64>) -> bool {
        if x.len() != self.k {
            return false;
        }
        let total = x.iter().fold(0.0, |acc, &xi| acc + xi);
        let strictly_positive = x.iter().all(|&xi| xi > 0.0);
        strictly_positive && (1.0 - total).abs() < 1E-12
    }
}
impl Rv<Vec<f64>> for Dirichlet {
    /// Draw a probability vector by sampling one Gamma(alpha_i, 1) variate
    /// per component and dividing each by their total.
    ///
    /// All Gamma samplers are built before any sampling happens.
    fn draw<R: Rng>(&self, rng: &mut R) -> Vec<f64> {
        let samplers: Vec<RGamma<f64>> = self
            .alphas
            .iter()
            .map(|&a| RGamma::new(a, 1.0).unwrap())
            .collect();

        let mut xs: Vec<f64> = Vec::with_capacity(samplers.len());
        for sampler in &samplers {
            xs.push(rng.sample(sampler));
        }

        let total: f64 = xs.iter().sum();
        for xi in &mut xs {
            *xi /= total;
        }
        xs
    }

    /// Log density:
    /// `sum_i (alpha_i - 1) ln x_i - [sum_i ln Γ(alpha_i) - ln Γ(sum_i alpha_i)]`.
    ///
    /// Entries of `x` are paired with alphas positionally; pairing stops at
    /// the shorter of the two.
    fn ln_f(&self, x: &Vec<f64>) -> f64 {
        let mut sum_ln_gamma = 0.0_f64;
        for &alpha in &self.alphas {
            sum_ln_gamma += alpha.ln_gamma().0;
        }

        let alpha_total: f64 = self.alphas.iter().sum();
        let ln_gamma_sum = alpha_total.ln_gamma().0;

        let mut kernel = 0.0_f64;
        for (&xi, &alpha) in x.iter().zip(&self.alphas) {
            kernel = (alpha - 1.0).mul_add(xi.ln(), kernel);
        }
        kernel - (sum_ln_gamma - ln_gamma_sum)
    }
}
// Empty impl: Dirichlet uses only the default methods of `ContinuousDistr`,
// built on its `Rv` and `Support` impls.
impl ContinuousDistr<Vec<f64>> for Dirichlet {}
impl Support<Vec<f64>> for Dirichlet {
    /// `x` is supported when it has one entry per alpha, every entry is
    /// strictly positive, and the entries sum to one within `1E-12`.
    fn supports(&self, x: &Vec<f64>) -> bool {
        let right_length = x.len() == self.alphas.len();
        right_length && {
            let total = x.iter().fold(0.0, |acc, &xi| acc + xi);
            x.iter().all(|&xi| xi > 0.0) && (1.0 - total).abs() < 1E-12
        }
    }
}
// Both error enums derive `Debug` and implement `Display` below, so the
// default `std::error::Error` methods are sufficient.
impl std::error::Error for SymmetricDirichletError {}
impl std::error::Error for DirichletError {}
impl fmt::Display for SymmetricDirichletError {
    /// Render a short, human-readable description of the validation failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::KIsZero => f.write_str("k must be greater than zero"),
            Self::AlphaTooLow { alpha } => {
                write!(f, "alpha ({}) must be greater than zero", alpha)
            }
            Self::AlphaNotFinite { alpha } => {
                write!(f, "alpha ({}) was non-finite", alpha)
            }
        }
    }
}
impl fmt::Display for DirichletError {
    /// Render a short, human-readable description of the validation failure,
    /// including the offending index for per-alpha errors.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::AlphaTooLow { ix, alpha } => {
                write!(f, "Invalid alpha at index {}: {} <= 0.0", ix, alpha)
            }
            Self::AlphaNotFinite { ix, alpha } => {
                write!(f, "Non-finite alpha at index {}: {}", ix, alpha)
            }
            Self::AlphasEmpty => f.write_str("alphas vector was empty"),
            Self::KIsZero => f.write_str("k must be greater than zero"),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Dirichlet` and `SymmetricDirichlet`: support checks,
    //! sampling, log-density values, constructor/setter validation, and
    //! cache invalidation.
    use super::*;
    use crate::{test_basic_impls, verify_cache_resets};

    const TOL: f64 = 1E-12;

    mod dir {
        use super::*;

        test_basic_impls!(Dirichlet::jeffreys(4).unwrap(), vec![0.25_f64; 4]);

        #[test]
        fn properly_sized_points_on_simplex_should_be_in_support() {
            let dir = Dirichlet::symmetric(1.0, 4).unwrap();
            assert!(dir.supports(&vec![0.25, 0.25, 0.25, 0.25]));
            assert!(dir.supports(&vec![0.1, 0.2, 0.3, 0.4]));
        }

        #[test]
        fn improperly_sized_points_should_not_be_in_support() {
            let dir = Dirichlet::symmetric(1.0, 3).unwrap();
            assert!(!dir.supports(&vec![0.25, 0.25, 0.25, 0.25]));
            assert!(!dir.supports(&vec![0.1, 0.2, 0.7, 0.4]));
        }

        #[test]
        fn properly_sized_points_off_simplex_should_not_be_in_support() {
            let dir = Dirichlet::symmetric(1.0, 4).unwrap();
            assert!(!dir.supports(&vec![0.25, 0.25, 0.26, 0.25]));
            assert!(!dir.supports(&vec![0.1, 0.3, 0.3, 0.4]));
        }

        #[test]
        fn draws_should_be_in_support() {
            let mut rng = rand::thread_rng();
            let dir = Dirichlet::jeffreys(10).unwrap();
            for _ in 0..100 {
                let x = dir.draw(&mut rng);
                assert!(dir.supports(&x));
            }
        }

        #[test]
        fn sample_should_return_the_proper_number_of_draws() {
            let mut rng = rand::thread_rng();
            let dir = Dirichlet::jeffreys(3).unwrap();
            let xs: Vec<Vec<f64>> = dir.sample(88, &mut rng);
            assert_eq!(xs.len(), 88);
        }

        // A flat Dirichlet(1, 1, 1) has constant density Γ(3) = 2.
        #[test]
        fn log_pdf_symmetric() {
            let dir = Dirichlet::symmetric(1.0, 3).unwrap();
            assert::close(
                dir.ln_pdf(&vec![0.2, 0.3, 0.5]),
                std::f64::consts::LN_2,
                TOL,
            );
        }

        #[test]
        fn log_pdf_jeffreys() {
            let dir = Dirichlet::jeffreys(3).unwrap();
            assert::close(
                dir.ln_pdf(&vec![0.2, 0.3, 0.5]),
                -0.084_598_117_749_354_22,
                TOL,
            );
        }

        #[test]
        fn log_pdf() {
            let dir = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
            assert::close(
                dir.ln_pdf(&vec![0.2, 0.3, 0.5]),
                1.504_077_396_776_273_7,
                TOL,
            );
        }

        #[test]
        fn new_rejects_empty_alphas() {
            assert_eq!(Dirichlet::new(vec![]), Err(DirichletError::AlphasEmpty));
        }

        #[test]
        fn new_reports_first_invalid_alpha_with_its_index() {
            assert_eq!(
                Dirichlet::new(vec![1.0, -1.0, 0.0]),
                Err(DirichletError::AlphaTooLow { ix: 1, alpha: -1.0 })
            );
            assert_eq!(
                Dirichlet::new(vec![1.0, 2.0, f64::INFINITY]),
                Err(DirichletError::AlphaNotFinite {
                    ix: 2,
                    alpha: f64::INFINITY
                })
            );
        }

        #[test]
        fn symmetric_and_jeffreys_reject_zero_k() {
            assert_eq!(Dirichlet::symmetric(1.0, 0), Err(DirichletError::KIsZero));
            assert_eq!(Dirichlet::jeffreys(0), Err(DirichletError::KIsZero));
        }

        #[test]
        fn from_symmetric_dirichlet_matches_symmetric_constructor() {
            let symdir = SymmetricDirichlet::new(1.5, 3).unwrap();
            let expected = Dirichlet::symmetric(1.5, 3).unwrap();
            assert_eq!(Dirichlet::from(&symdir), expected);
            assert_eq!(Dirichlet::from(symdir), expected);
        }
    }

    mod symdir {
        use std::f64::consts::PI;

        use super::*;

        test_basic_impls!(
            SymmetricDirichlet::jeffreys(4).unwrap(),
            vec![0.25_f64; 4]
        );

        #[test]
        fn sample_should_return_the_proper_number_of_draws() {
            let mut rng = rand::thread_rng();
            let symdir = SymmetricDirichlet::jeffreys(3).unwrap();
            let xs: Vec<Vec<f64>> = symdir.sample(88, &mut rng);
            assert_eq!(xs.len(), 88);
        }

        // Must agree with the general Dirichlet's Jeffreys value above.
        #[test]
        fn log_pdf_jeffreys() {
            let symdir = SymmetricDirichlet::jeffreys(3).unwrap();
            assert::close(
                symdir.ln_pdf(&vec![0.2, 0.3, 0.5]),
                -0.084_598_117_749_354_22,
                TOL,
            );
        }

        #[test]
        fn properly_sized_points_off_simplex_should_not_be_in_support() {
            let symdir = SymmetricDirichlet::new(1.0, 4).unwrap();
            assert!(!symdir.supports(&vec![0.25, 0.25, 0.26, 0.25]));
            assert!(!symdir.supports(&vec![0.1, 0.3, 0.3, 0.4]));
        }

        #[test]
        fn draws_should_be_in_support() {
            let mut rng = rand::thread_rng();
            let symdir = SymmetricDirichlet::jeffreys(10).unwrap();
            for _ in 0..100 {
                let x: Vec<f64> = symdir.draw(&mut rng);
                assert!(symdir.supports(&x));
            }
        }

        #[test]
        fn new_rejects_invalid_parameters() {
            assert_eq!(
                SymmetricDirichlet::new(1.0, 0),
                Err(SymmetricDirichletError::KIsZero)
            );
            assert_eq!(
                SymmetricDirichlet::new(0.0, 2),
                Err(SymmetricDirichletError::AlphaTooLow { alpha: 0.0 })
            );
            assert_eq!(
                SymmetricDirichlet::new(f64::INFINITY, 2),
                Err(SymmetricDirichletError::AlphaNotFinite {
                    alpha: f64::INFINITY
                })
            );
        }

        #[test]
        fn set_alpha_rejects_invalid_alpha_and_keeps_old_value() {
            let mut symdir = SymmetricDirichlet::new(1.2, 2).unwrap();
            assert_eq!(
                symdir.set_alpha(-1.0),
                Err(SymmetricDirichletError::AlphaTooLow { alpha: -1.0 })
            );
            assert!(matches!(
                symdir.set_alpha(f64::NAN),
                Err(SymmetricDirichletError::AlphaNotFinite { .. })
            ));
            assert_eq!(symdir.alpha(), 1.2);
        }

        verify_cache_resets!(
            [unchecked],
            ln_f_is_same_after_reset_unchecked_alpha_identically,
            set_alpha_unchecked,
            SymmetricDirichlet::new(1.2, 2).unwrap(),
            vec![0.1_f64, 0.9_f64],
            1.2,
            PI
        );

        verify_cache_resets!(
            [checked],
            ln_f_is_same_after_reset_checked_alpha_identically,
            set_alpha,
            SymmetricDirichlet::new(1.2, 2).unwrap(),
            vec![0.1_f64, 0.9_f64],
            1.2,
            PI
        );
    }
}