use super::special;
use crate::stats::distribution::Distribution;
use crate::stats::error::{StatsError, StatsResult};
/// A Dirichlet distribution over probability vectors of length `k`
/// (strictly positive components summing to 1), parameterised by a vector of
/// concentration parameters.
///
/// Instances are built with [`Dirichlet::new`], which validates `alpha` and
/// caches the derived quantities stored alongside it. `PartialEq` compares
/// every stored field, so two distributions built from the same `alpha`
/// compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct Dirichlet {
    /// Concentration parameters; `new` requires at least two entries, each
    /// finite and strictly positive.
    alpha: Vec<f64>,
    /// Number of categories, set to `alpha.len()` by `new`.
    k: usize,
    /// Sum of all concentration parameters (often written `alpha_0`).
    alpha_sum: f64,
    /// Natural log of the multivariate beta function,
    /// `sum_i lnΓ(alpha_i) - lnΓ(alpha_sum)`; subtracted in the log-density.
    log_beta: f64,
}
impl Dirichlet {
    /// Builds a Dirichlet distribution from its concentration vector `alpha`.
    ///
    /// The total concentration and the log normaliser `ln B(alpha)` are
    /// computed once here, so density evaluations do not repeat that work.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] when `alpha` has fewer than two
    /// entries, or when any entry is not a finite, strictly positive number.
    /// In the latter case the error names the first offending index.
    pub fn new(alpha: Vec<f64>) -> StatsResult<Self> {
        if alpha.len() < 2 {
            return Err(StatsError::InvalidParameter {
                name: "alpha".to_string(),
                value: alpha.len() as f64,
                reason: "Dirichlet requires at least 2 categories".to_string(),
            });
        }

        // Locate the first concentration that is non-positive, infinite or NaN.
        if let Some((idx, &bad)) = alpha
            .iter()
            .enumerate()
            .find(|&(_, &a)| a <= 0.0 || !a.is_finite())
        {
            return Err(StatsError::InvalidParameter {
                name: format!("alpha[{}]", idx),
                value: bad,
                reason: "concentration parameter must be positive and finite".to_string(),
            });
        }

        let alpha_sum: f64 = alpha.iter().sum();
        // ln B(alpha) = sum_i lnΓ(alpha_i) - lnΓ(sum_i alpha_i)
        let log_gamma_total: f64 = alpha.iter().map(|&a| special::lgamma(a)).sum();
        let log_beta = log_gamma_total - special::lgamma(alpha_sum);

        Ok(Self {
            k: alpha.len(),
            alpha_sum,
            log_beta,
            alpha,
        })
    }

    /// Borrows the concentration parameters in the order they were supplied.
    pub fn alpha(&self) -> &[f64] {
        self.alpha.as_slice()
    }

    /// Number of categories (the length of `alpha`).
    pub fn k(&self) -> usize {
        self.k
    }

    /// Sum of the concentration parameters, `alpha_0`.
    pub fn alpha_sum(&self) -> f64 {
        self.alpha_sum
    }

    /// Natural logarithm of the density at `x`.
    ///
    /// Returns `f64::NEG_INFINITY` when the components of `x` do not sum to 1
    /// within an absolute tolerance of `1e-6`, or when any component is
    /// `<= 0.0`. Otherwise returns
    /// `-ln B(alpha) + sum_i (alpha_i - 1) ln x_i`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the number of categories `k`.
    pub fn log_pdf(&self, x: &[f64]) -> f64 {
        assert_eq!(x.len(), self.k, "x must have length k");
        let total: f64 = x.iter().sum();
        let off_simplex = (total - 1.0).abs() > 1e-6 || x.iter().any(|&xi| xi <= 0.0);
        if off_simplex {
            return f64::NEG_INFINITY;
        }
        x.iter()
            .zip(&self.alpha)
            .fold(-self.log_beta, |acc, (&xi, &ai)| acc + (ai - 1.0) * xi.ln())
    }

    /// Density at `x`, i.e. `exp(log_pdf(x))`; `0.0` wherever `log_pdf` is
    /// negative infinity.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`Dirichlet::log_pdf`].
    pub fn pdf(&self, x: &[f64]) -> f64 {
        f64::exp(self.log_pdf(x))
    }

    /// Mean vector: `alpha_i / alpha_0` for each component.
    pub fn mean_vec(&self) -> Vec<f64> {
        let total = self.alpha_sum;
        self.alpha.iter().map(|a| a / total).collect()
    }

    /// Per-component variances:
    /// `alpha_i (alpha_0 - alpha_i) / (alpha_0^2 (alpha_0 + 1))`.
    pub fn var_vec(&self) -> Vec<f64> {
        let a0 = self.alpha_sum;
        let scale = a0 * a0 * (a0 + 1.0);
        self.alpha.iter().map(|&ai| ai * (a0 - ai) / scale).collect()
    }

    /// Full `k x k` covariance matrix.
    ///
    /// The diagonal holds the variances from [`Dirichlet::var_vec`]. The
    /// off-diagonal entry `(i, j)` is `-alpha_i alpha_j / (alpha_0^2 (alpha_0 + 1))`.
    pub fn cov_matrix(&self) -> Vec<Vec<f64>> {
        let a0 = self.alpha_sum;
        let scale = a0 * a0 * (a0 + 1.0);
        self.alpha
            .iter()
            .enumerate()
            .map(|(row, &ai)| {
                self.alpha
                    .iter()
                    .enumerate()
                    .map(|(col, &aj)| {
                        if row == col {
                            ai * (a0 - ai) / scale
                        } else {
                            -ai * aj / scale
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Joint mode `(alpha_i - 1) / (alpha_0 - k)`, defined only when every
    /// `alpha_i > 1`; returns `None` otherwise.
    pub fn mode_vec(&self) -> Option<Vec<f64>> {
        let all_above_one = !self.alpha.iter().any(|&a| a <= 1.0);
        all_above_one.then(|| {
            let denom = self.alpha_sum - self.k as f64;
            self.alpha.iter().map(|&a| (a - 1.0) / denom).collect()
        })
    }
}
// The `Distribution` trait describes a scalar distribution. Apart from
// `entropy` (which is the joint entropy), the summaries below describe the
// marginal of the first component, X_1 ~ Beta(a, b) with a = alpha[0] and
// b = alpha_sum - alpha[0].
impl Distribution for Dirichlet {
    /// Mean of the first marginal, `alpha[0] / alpha_sum`.
    fn mean(&self) -> f64 {
        self.alpha[0] / self.alpha_sum
    }

    /// Variance of the first marginal,
    /// `alpha[0] (alpha_sum - alpha[0]) / (alpha_sum^2 (alpha_sum + 1))`.
    fn var(&self) -> f64 {
        let a0 = self.alpha_sum;
        self.alpha[0] * (a0 - self.alpha[0]) / (a0 * a0 * (a0 + 1.0))
    }

    /// Differential entropy of the *joint* Dirichlet distribution (not of a
    /// marginal):
    /// `ln B(alpha) + (alpha_sum - k) ψ(alpha_sum) - sum_i (alpha_i - 1) ψ(alpha_i)`.
    fn entropy(&self) -> f64 {
        let a0 = self.alpha_sum;
        let mut h = self.log_beta + (a0 - self.k as f64) * special::digamma(a0);
        for &ai in &self.alpha {
            h -= (ai - 1.0) * special::digamma(ai);
        }
        h
    }

    /// Returns the first marginal's *mean* as a stand-in for its median.
    ///
    /// NOTE(review): the median of `Beta(a, b)` generally differs from its
    /// mean. They coincide when `a == b`. An exact value would need the
    /// inverse regularized incomplete beta function.
    fn median(&self) -> f64 {
        self.mean()
    }

    /// Mode of the first marginal `Beta(a, b)`, with `a = alpha[0]` and
    /// `b = alpha_sum - alpha[0]`.
    ///
    /// * `a > 1` and `b > 1`: interior mode `(a - 1) / (alpha_sum - 2)`.
    /// * `a > 1, b <= 1` or `a == 1, b < 1`: the density is non-decreasing
    ///   towards 1, so the mode is `1.0`.
    /// * Otherwise `0.0`. For `a <= 1, b >= 1` the density is non-increasing
    ///   from 0. For `a < 1, b < 1` the marginal is U-shaped and 0 is one of its
    ///   two boundary modes.
    ///
    /// The previous joint-mode formula `(a - 1) / (alpha_sum - k)` could
    /// return values outside `[0, 1]` (or infinity) when `a > 1` but some
    /// other `alpha_j <= 1`. For example, `alpha = [2, 0.5]` gave `2.0`.
    fn mode(&self) -> f64 {
        let a = self.alpha[0];
        let b = self.alpha_sum - a;
        if a > 1.0 && b > 1.0 {
            (a - 1.0) / (self.alpha_sum - 2.0)
        } else if (a > 1.0 && b <= 1.0) || (a >= 1.0 && b < 1.0) {
            1.0
        } else {
            0.0
        }
    }

    /// Skewness of the first marginal `Beta(a, b)`:
    /// `2 (b - a) sqrt(a + b + 1) / ((a + b + 2) sqrt(a b))`.
    fn skewness(&self) -> f64 {
        let a0 = self.alpha_sum;
        let ai = self.alpha[0];
        let b = a0 - ai;
        2.0 * (b - ai) * (a0 + 1.0).sqrt() / ((a0 + 2.0) * (ai * b).sqrt())
    }

    /// Excess kurtosis of the first marginal `Beta(a, b)`. The value is 0 for a
    /// normal distribution and -1.2 for `Beta(1, 1)`.
    fn kurtosis(&self) -> f64 {
        let a0 = self.alpha_sum;
        let ai = self.alpha[0];
        let b = a0 - ai;
        let num = 6.0
            * (ai.powi(3) - ai.powi(2) * (2.0 * b - 1.0) + b.powi(2) * (b + 1.0)
                - 2.0 * ai * b * (b + 2.0));
        let den = ai * b * (a0 + 2.0) * (a0 + 3.0);
        num / den
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison in this module.
    const TOL: f64 = 1e-10;

    /// Returns `true` when `actual` lies strictly within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_dirichlet_creation() {
        let d = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(d.k(), 3);
        assert!(close(d.alpha_sum(), 6.0));
        // Too few categories, a negative concentration, and a zero concentration.
        for bad in [vec![1.0], vec![1.0, -1.0], vec![0.0, 1.0]] {
            assert!(Dirichlet::new(bad).is_err());
        }
    }

    #[test]
    fn test_dirichlet_mean() {
        let d = Dirichlet::new(vec![2.0, 3.0, 5.0]).unwrap();
        let mean = d.mean_vec();
        for (i, &want) in [0.2, 0.3, 0.5].iter().enumerate() {
            assert!(close(mean[i], want));
        }
    }

    #[test]
    fn test_dirichlet_pdf_at_mean() {
        let d = Dirichlet::new(vec![2.0; 3]).unwrap();
        let density = d.pdf(&d.mean_vec());
        assert!(density > 0.0);
        assert!(density.is_finite());
    }

    #[test]
    fn test_dirichlet_pdf_outside_simplex() {
        let d = Dirichlet::new(vec![2.0, 3.0]).unwrap();
        // Components sum to 0.6 rather than 1.
        assert_eq!(d.pdf(&[0.3, 0.3]), 0.0);
        // Components sum to 1 but one is negative.
        assert_eq!(d.pdf(&[-0.1, 1.1]), 0.0);
    }

    #[test]
    fn test_dirichlet_uniform() {
        // With all concentrations equal to 1 the density is flat on the simplex.
        let flat = Dirichlet::new(vec![1.0; 3]).unwrap();
        let first = flat.pdf(&[0.2, 0.3, 0.5]);
        let second = flat.pdf(&[0.1, 0.1, 0.8]);
        assert!(close(first, second));
    }

    #[test]
    fn test_dirichlet_covariance() {
        let cov = Dirichlet::new(vec![1.0, 1.0]).unwrap().cov_matrix();
        assert!(close(cov[0][0], 1.0 / 12.0));
        assert!(close(cov[0][1], -1.0 / 12.0));
    }

    #[test]
    fn test_dirichlet_mode() {
        let d = Dirichlet::new(vec![3.0, 5.0, 2.0]).unwrap();
        let mode = d.mode_vec().unwrap();
        for (i, &want) in [2.0 / 7.0, 4.0 / 7.0, 1.0 / 7.0].iter().enumerate() {
            assert!(close(mode[i], want));
        }
        // alpha[0] < 1, so there is no interior joint mode.
        assert!(Dirichlet::new(vec![0.5, 2.0]).unwrap().mode_vec().is_none());
    }

    #[test]
    fn test_dirichlet_entropy() {
        // Dirichlet(1, 1) is uniform on [0, 1], whose differential entropy is 0.
        let d = Dirichlet::new(vec![1.0, 1.0]).unwrap();
        assert!(close(d.entropy(), 0.0));
    }

    #[test]
    fn test_dirichlet_symmetric_skewness() {
        let d = Dirichlet::new(vec![5.0, 5.0]).unwrap();
        assert!(close(d.skewness(), 0.0));
    }
}