use crate::error::{StatsError, StatsResult};
use crate::sampling::SampleableDistribution;
use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{Distribution, Gamma as RandGamma};
use std::fmt::Debug;
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// Small integer arguments (`x <= 20`) are evaluated exactly as a sum of
/// logarithms, `ln((x - 1)!)`. Every other argument uses the Lanczos
/// approximation (g = 7, 9 coefficients); arguments below `0.5` are first
/// mapped into the accurate region with the reflection formula
/// `Γ(x) Γ(1 - x) = π / sin(π x)`.
///
/// The evaluation is non-recursive apart from a single reflection step, so
/// it terminates for every positive input regardless of magnitude.
///
/// # Panics
///
/// Panics if `x <= 0.0`.
// Kept from the original: the helper may be unused in some build configurations.
#[allow(dead_code)]
fn lgamma(x: f64) -> f64 {
    /// Lanczos shift parameter matching the coefficient table below.
    const LANCZOS_G: f64 = 7.0;
    /// Lanczos series coefficients for g = 7, n = 9.
    const LANCZOS_COEFFS: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    if x <= 0.0 {
        panic!("lgamma requires positive input");
    }
    // ln Γ grows without bound; short-circuit to avoid `inf - inf` below.
    if x.is_infinite() {
        return f64::INFINITY;
    }

    // Exact path: ln Γ(n) = ln((n - 1)!) = Σ_{i=2}^{n-1} ln(i).
    // The range is empty for n = 1 and n = 2, giving exactly 0.0.
    if x.fract() == 0.0 && x <= 20.0 {
        let n = x as usize;
        return (2..n).fold(0.0, |acc, i| acc + (i as f64).ln());
    }

    let pi = std::f64::consts::PI;

    // Reflection: 1 - x lies in (0.5, 1], so this recurses at most once.
    if x < 0.5 {
        return (pi / (pi * x).sin()).ln() - lgamma(1.0 - x);
    }

    // Lanczos approximation evaluated in log space.
    let shifted = x - 1.0;
    let t = shifted + LANCZOS_G + 0.5;
    let series = LANCZOS_COEFFS
        .iter()
        .enumerate()
        .skip(1)
        .fold(LANCZOS_COEFFS[0], |acc, (k, &coef)| {
            acc + coef / (shifted + k as f64)
        });

    0.5 * (2.0 * pi).ln() + (shifted + 0.5) * t.ln() - t + series.ln()
}
/// Dirichlet distribution parameterised by a vector of concentration
/// parameters.
#[derive(Debug, Clone)]
pub struct Dirichlet {
/// Concentration parameters; the constructor rejects non-positive entries.
pub alpha: Array1<f64>,
/// Number of concentration parameters (`alpha.len()` at construction time).
pub dim: usize,
/// Cached `Σ lgamma(alpha_i) - lgamma(Σ alpha_i)`, subtracted when the log
/// density is evaluated.
log_norm_const: f64,
}
impl Dirichlet {
    /// Builds a Dirichlet distribution from its concentration parameters.
    ///
    /// The log normalising constant
    /// `Σ lgamma(alpha_i) - lgamma(Σ alpha_i)` is computed once here and
    /// cached for later density evaluations.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::DomainError`] when `alpha` is empty, when any
    /// entry is NaN or not strictly positive, or when the entries (or their
    /// sum) are not finite.
    pub fn new<D>(alpha: ArrayBase<D, Ix1>) -> StatsResult<Self>
    where
        D: Data<Elem = f64>,
    {
        // An empty parameter vector would otherwise reach `lgamma(0.0)`,
        // which panics.
        if alpha.is_empty() {
            return Err(StatsError::DomainError(
                "At least one concentration parameter is required".to_string(),
            ));
        }
        // NaN compares false against everything, so test for it explicitly.
        if alpha.iter().any(|&a| a.is_nan() || a <= 0.0) {
            return Err(StatsError::DomainError(
                "All concentration parameters must be positive".to_string(),
            ));
        }
        // All entries are positive here, so a non-finite sum means an
        // infinite entry or an overflowing total.
        let alpha_sum = alpha.sum();
        if !alpha_sum.is_finite() {
            return Err(StatsError::DomainError(
                "Concentration parameters and their sum must be finite".to_string(),
            ));
        }

        let alpha_owned = alpha.to_owned();
        let dim = alpha_owned.len();
        let log_norm_const =
            alpha_owned.iter().map(|&a| lgamma(a)).sum::<f64>() - lgamma(alpha_sum);

        Ok(Dirichlet {
            alpha: alpha_owned,
            dim,
            log_norm_const,
        })
    }

    /// Probability density at `x`.
    ///
    /// Returns `0.0` when `x` has the wrong length, does not sum to one
    /// (within `1e-10`), or has a component outside the open interval
    /// `(0, 1)`. Those are exactly the cases where [`Dirichlet::logpdf`]
    /// yields negative infinity, whose exponential is `0.0`.
    pub fn pdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
    where
        D: Data<Elem = f64>,
    {
        self.logpdf(x).exp()
    }

    /// Natural logarithm of the probability density at `x`.
    ///
    /// Returns `f64::NEG_INFINITY` when `x` has the wrong length, does not
    /// sum to one (within `1e-10`), or has a component outside the open
    /// interval `(0, 1)`.
    pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
    where
        D: Data<Elem = f64>,
    {
        if x.len() != self.dim {
            return f64::NEG_INFINITY;
        }
        let total: f64 = x.iter().sum();
        if (total - 1.0).abs() > 1e-10 {
            return f64::NEG_INFINITY;
        }
        if x.iter().any(|&v| v <= 0.0 || v >= 1.0) {
            return f64::NEG_INFINITY;
        }

        // -log B(alpha) + Σ (alpha_i - 1) ln x_i, accumulated left to right.
        self.alpha
            .iter()
            .zip(x.iter())
            .fold(-self.log_norm_const, |acc, (&a, &xi)| {
                acc + (a - 1.0) * xi.ln()
            })
    }

    /// Draws `size` independent samples.
    ///
    /// Each sample is obtained by drawing one `Gamma(alpha_i, 1)` variate per
    /// component and dividing by their sum. If every variate of a draw
    /// underflows to zero (possible for very small concentration parameters)
    /// the draw is repeated instead of dividing by zero.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::ComputationError`] when a gamma distribution
    /// cannot be constructed, or when no draw with a positive, finite sum is
    /// obtained after a bounded number of attempts.
    pub fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
        /// Upper bound on redraws for a single sample before giving up.
        const MAX_ATTEMPTS: usize = 100;

        // The component distributions depend only on `alpha`; build them once
        // instead of once per sample.
        let gammas = self
            .alpha
            .iter()
            .map(|&a| {
                RandGamma::new(a, 1.0).map_err(|_| {
                    StatsError::ComputationError(
                        "Failed to create gamma distribution".to_string(),
                    )
                })
            })
            .collect::<StatsResult<Vec<_>>>()?;

        let mut rng = thread_rng();
        let mut samples = Vec::with_capacity(size);

        for _ in 0..size {
            let mut sample = Array1::<f64>::zeros(self.dim);
            let mut normalized = false;

            for _ in 0..MAX_ATTEMPTS {
                let mut total = 0.0;
                for (slot, gamma) in sample.iter_mut().zip(&gammas) {
                    let draw = gamma.sample(&mut rng);
                    *slot = draw;
                    total += draw;
                }
                // Only normalise when the division is well defined.
                if total > 0.0 && total.is_finite() {
                    sample.mapv_inplace(|v| v / total);
                    normalized = true;
                    break;
                }
            }

            if !normalized {
                return Err(StatsError::ComputationError(
                    "Failed to draw a Dirichlet sample: gamma variates could not be normalized"
                        .to_string(),
                ));
            }
            samples.push(sample);
        }

        Ok(samples)
    }

    /// Draws a single sample.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Dirichlet::rvs`].
    pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
        // Move the sample out of the vector rather than cloning it.
        self.rvs(1)?.pop().ok_or_else(|| {
            StatsError::ComputationError("Failed to draw a Dirichlet sample".to_string())
        })
    }
}
/// Convenience constructor: builds a [`Dirichlet`] from borrowed
/// concentration parameters.
///
/// # Errors
///
/// Returns whatever error [`Dirichlet::new`] reports for `alpha`.
// Kept from the original: the helper may be unused within the crate itself.
#[allow(dead_code)]
pub fn dirichlet<D>(alpha: &ArrayBase<D, Ix1>) -> StatsResult<Dirichlet>
where
    D: Data<Elem = f64>,
{
    // `Dirichlet::new` makes its own owned copy, so pass a borrowed view
    // rather than allocating a second, throwaway copy here.
    Dirichlet::new(alpha.view())
}
impl SampleableDistribution<Array1<f64>> for Dirichlet {
fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
self.rvs(size)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Valid parameters are stored unchanged and the dimension is recorded.
    #[test]
    fn test_dirichlet_creation() {
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha.clone()).expect("valid alpha must construct");
        assert_eq!(dirichlet.dim, 3);
        assert_eq!(dirichlet.alpha, alpha);

        let alpha2 = array![2.0, 3.0, 4.0];
        let dirichlet2 = Dirichlet::new(alpha2.clone()).expect("valid alpha must construct");
        assert_eq!(dirichlet2.dim, 3);
        assert_eq!(dirichlet2.alpha, alpha2);
    }

    /// Zero and negative concentration parameters are rejected.
    #[test]
    fn test_dirichlet_creation_errors() {
        let alpha = array![1.0, 0.0, 1.0];
        assert!(Dirichlet::new(alpha).is_err());

        let alpha = array![1.0, -1.0, 1.0];
        assert!(Dirichlet::new(alpha).is_err());
    }

    /// The flat Dirichlet has constant density 2 on the 2-simplex; a
    /// concentrated one peaks at the centre.
    #[test]
    fn test_dirichlet_pdf() {
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha).expect("valid alpha must construct");
        let point1 = array![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let point2 = array![0.2, 0.3, 0.5];
        assert_relative_eq!(dirichlet.pdf(&point1), 2.0, epsilon = 1e-10);
        assert_relative_eq!(dirichlet.pdf(&point2), 2.0, epsilon = 1e-10);

        let alpha = array![5.0, 5.0, 5.0];
        let concentrated = Dirichlet::new(alpha).expect("valid alpha must construct");
        let center = array![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let edge = array![0.01, 0.01, 0.98];
        assert!(concentrated.pdf(&center) > concentrated.pdf(&edge));
    }

    /// Points off the open simplex have zero density.
    #[test]
    fn test_dirichlet_pdf_edge_cases() {
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha).expect("valid alpha must construct");

        // Sums to less than one.
        let invalid1 = array![0.3, 0.3, 0.3];
        // Sums to more than one.
        let invalid2 = array![0.5, 0.6, 0.2];
        // Contains a zero component.
        let invalid3 = array![0.0, 0.5, 0.5];
        // Sits on a vertex of the simplex.
        let invalid4 = array![1.0, 0.0, 0.0];

        assert_eq!(dirichlet.pdf(&invalid1), 0.0);
        assert_eq!(dirichlet.pdf(&invalid2), 0.0);
        assert_eq!(dirichlet.pdf(&invalid3), 0.0);
        assert_eq!(dirichlet.pdf(&invalid4), 0.0);
    }

    /// The log density of the flat Dirichlet is ln 2 and agrees with `pdf`.
    #[test]
    fn test_dirichlet_logpdf() {
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha).expect("valid alpha must construct");
        let point = array![0.3, 0.3, 0.4];
        assert_relative_eq!(dirichlet.logpdf(&point), 2.0_f64.ln(), epsilon = 1e-10);
        assert_relative_eq!(
            dirichlet.logpdf(&point).exp(),
            dirichlet.pdf(&point),
            epsilon = 1e-10
        );
    }

    /// Samples lie on the simplex and their mean approaches alpha / sum(alpha).
    #[test]
    fn test_dirichlet_rvs() {
        let alpha = array![1.0, 2.0, 3.0];
        let dirichlet = Dirichlet::new(alpha.clone()).expect("valid alpha must construct");
        let n_samples = 1000;
        let samples = dirichlet.rvs(n_samples).expect("sampling must succeed");
        assert_eq!(samples.len(), n_samples);

        for sample in &samples {
            let sum: f64 = sample.iter().sum();
            assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
            for &val in sample.iter() {
                assert!((0.0..=1.0).contains(&val));
            }
        }

        let mut component_totals = [0.0; 3];
        for sample in &samples {
            for (total, &val) in component_totals.iter_mut().zip(sample.iter()) {
                *total += val;
            }
        }

        let alpha_sum = alpha.sum();
        for (&total, &a) in component_totals.iter().zip(alpha.iter()) {
            let sample_mean = total / n_samples as f64;
            let expected_mean = a / alpha_sum;
            assert_relative_eq!(sample_mean, expected_mean, epsilon = 0.05);
        }
    }

    /// A single draw has the right length and lies on the simplex.
    #[test]
    fn test_dirichlet_rvs_single() {
        let alpha = array![1.0, 2.0, 3.0];
        let dirichlet = Dirichlet::new(alpha).expect("valid alpha must construct");
        let sample = dirichlet.rvs_single().expect("sampling must succeed");
        assert_eq!(sample.len(), 3);

        let sum: f64 = sample.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
        for &val in sample.iter() {
            assert!((0.0..=1.0).contains(&val));
        }
    }
}