use crate::array::Array;
use crate::new_modules::probabilistic::{
validate_non_negative, validate_positive, ProbabilisticError, Result,
};
use scirs2_core::ndarray::Array1;
use scirs2_core::random::{Distribution, Rng, RngExt};
use std::f64::consts::PI;
/// Beta distribution on the closed unit interval `[0, 1]`.
///
/// Parameterised by two positive shape parameters `alpha` and `beta`; the log of the
/// normalising constant `B(alpha, beta)` is computed once at construction.
#[derive(Debug, Clone)]
pub struct BetaDistribution {
    alpha: f64,
    beta: f64,
    /// Cached `ln B(alpha, beta) = lnΓ(alpha) + lnΓ(beta) − lnΓ(alpha + beta)`.
    log_beta: f64,
}
impl BetaDistribution {
    /// Creates a Beta distribution with shape parameters `alpha` and `beta`.
    ///
    /// # Errors
    /// Returns the error produced by `validate_positive` if either parameter is rejected.
    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
        validate_positive(alpha, "alpha")?;
        validate_positive(beta, "beta")?;
        let log_beta = gamma_ln(alpha) + gamma_ln(beta) - gamma_ln(alpha + beta);
        Ok(Self {
            alpha,
            beta,
            log_beta,
        })
    }
    /// First shape parameter.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }
    /// Second shape parameter.
    pub fn beta(&self) -> f64 {
        self.beta
    }
    /// Probability density at `x`.
    ///
    /// Returns 0 outside `[0, 1]` (and for NaN). At an endpoint the limiting value is returned:
    /// `+∞` when the shape parameter governing that endpoint is below 1, the other shape
    /// parameter when it equals 1, and 0 when it exceeds 1.
    pub fn pdf(&self, x: f64) -> f64 {
        if !(0.0..=1.0).contains(&x) {
            return 0.0;
        }
        if x == 0.0 {
            return if self.alpha < 1.0 {
                f64::INFINITY
            } else if self.alpha == 1.0 {
                self.beta
            } else {
                0.0
            };
        }
        if x == 1.0 {
            return if self.beta < 1.0 {
                f64::INFINITY
            } else if self.beta == 1.0 {
                self.alpha
            } else {
                0.0
            };
        }
        (self.log_pdf(x)).exp()
    }
    /// Natural logarithm of the density at `x`.
    ///
    /// Returns `-∞` outside `[0, 1]`. The endpoints are handled explicitly so that
    /// `log_pdf(x) == pdf(x).ln()` there too; the general formula would otherwise evaluate
    /// `0 · ln 0 = NaN` when a shape parameter equals 1.
    pub fn log_pdf(&self, x: f64) -> f64 {
        if !(0.0..=1.0).contains(&x) {
            return f64::NEG_INFINITY;
        }
        if x == 0.0 {
            // The x^(alpha - 1) factor decides the limit at 0.
            return if self.alpha < 1.0 {
                f64::INFINITY
            } else if self.alpha == 1.0 {
                self.beta.ln()
            } else {
                f64::NEG_INFINITY
            };
        }
        if x == 1.0 {
            // The (1 - x)^(beta - 1) factor decides the limit at 1.
            return if self.beta < 1.0 {
                f64::INFINITY
            } else if self.beta == 1.0 {
                self.alpha.ln()
            } else {
                f64::NEG_INFINITY
            };
        }
        (self.alpha - 1.0) * x.ln() + (self.beta - 1.0) * (1.0 - x).ln() - self.log_beta
    }
    /// Draws a sample as `X / (X + Y)` with `X ~ Gamma(alpha, 1)` and `Y ~ Gamma(beta, 1)`.
    ///
    /// # Errors
    /// Propagates any error from constructing or sampling the auxiliary Gamma distributions.
    pub fn sample<R: Rng>(&self, rng: &mut R) -> Result<f64> {
        let gamma1 = GammaDistribution::new(self.alpha, 1.0)?;
        let gamma2 = GammaDistribution::new(self.beta, 1.0)?;
        let x = gamma1.sample(rng)?;
        let y = gamma2.sample(rng)?;
        Ok(x / (x + y))
    }
    /// Mean `alpha / (alpha + beta)`.
    pub fn mean(&self) -> f64 {
        self.alpha / (self.alpha + self.beta)
    }
    /// Variance `alpha·beta / ((alpha + beta)² · (alpha + beta + 1))`.
    pub fn variance(&self) -> f64 {
        let sum = self.alpha + self.beta;
        (self.alpha * self.beta) / (sum * sum * (sum + 1.0))
    }
}
/// Gamma distribution with shape `alpha` and rate `beta`.
///
/// Density: `beta^alpha / Γ(alpha) · x^(alpha − 1) · e^(−beta·x)` for `x > 0`.
#[derive(Debug, Clone)]
pub struct GammaDistribution {
    alpha: f64,
    beta: f64,
    /// Cached log normalising constant `alpha·ln(beta) − lnΓ(alpha)`.
    log_norm: f64,
}
impl GammaDistribution {
    /// Creates a Gamma distribution with shape `alpha` and rate `beta`.
    ///
    /// # Errors
    /// Returns the error produced by `validate_positive` if either parameter is rejected.
    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
        validate_positive(alpha, "alpha")?;
        validate_positive(beta, "beta")?;
        let log_norm = alpha * beta.ln() - gamma_ln(alpha);
        Ok(Self {
            alpha,
            beta,
            log_norm,
        })
    }
    /// Shape parameter.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }
    /// Rate parameter.
    pub fn beta(&self) -> f64 {
        self.beta
    }
    /// Probability density at `x`.
    ///
    /// Returns 0 for negative `x`. At `x == 0` the limiting value is returned: `+∞` when
    /// `alpha < 1`, `beta` when `alpha == 1` (exponential case), and 0 when `alpha > 1`.
    pub fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            return 0.0;
        }
        if x == 0.0 {
            return if self.alpha < 1.0 {
                f64::INFINITY
            } else if self.alpha == 1.0 {
                self.beta
            } else {
                0.0
            };
        }
        self.log_pdf(x).exp()
    }
    /// Natural logarithm of the density at `x`; equals `pdf(x).ln()` including at `x == 0`.
    pub fn log_pdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            return f64::NEG_INFINITY;
        }
        if x == 0.0 {
            // The x^(alpha - 1) factor decides the limit at 0; the general formula would
            // produce NaN (0 · ln 0) when alpha == 1.
            return if self.alpha < 1.0 {
                f64::INFINITY
            } else if self.alpha == 1.0 {
                self.beta.ln()
            } else {
                f64::NEG_INFINITY
            };
        }
        self.log_norm + (self.alpha - 1.0) * x.ln() - self.beta * x
    }
    /// Draws a sample using the Marsaglia–Tsang squeeze/rejection method.
    ///
    /// For `alpha < 1` a `Gamma(alpha + 1)` variate is scaled by `U^(1/alpha)`.
    ///
    /// # Errors
    /// Currently never fails; the `Result` keeps the signature uniform with the other samplers.
    pub fn sample<R: Rng>(&self, rng: &mut R) -> Result<f64> {
        let (alpha_adj, correction) = if self.alpha < 1.0 {
            // `1.0 - u` maps a draw in [0, 1) onto (0, 1] so a zero draw cannot force the
            // sample to exactly 0 (assumes `random::<f64>()` is uniform on [0, 1) — confirm for
            // scirs2_core). Tiny `alpha` can still underflow the factor.
            let u = 1.0 - rng.random::<f64>();
            (self.alpha + 1.0, u.powf(1.0 / self.alpha))
        } else {
            (self.alpha, 1.0)
        };
        let d = alpha_adj - 1.0 / 3.0;
        let c = 1.0 / (9.0 * d).sqrt();
        loop {
            let z = sample_standard_normal(rng);
            let v = (1.0 + c * z).powi(3);
            if v <= 0.0 {
                continue;
            }
            let u: f64 = rng.random();
            let z2 = z * z;
            // Cheap squeeze first, then the exact log-acceptance test.
            if u < 1.0 - 0.0331 * z2 * z2 || u.ln() < 0.5 * z2 + d * (1.0 - v + v.ln()) {
                return Ok(d * v * correction / self.beta);
            }
        }
    }
    /// Mean `alpha / beta`.
    pub fn mean(&self) -> f64 {
        self.alpha / self.beta
    }
    /// Variance `alpha / beta²`.
    pub fn variance(&self) -> f64 {
        self.alpha / (self.beta * self.beta)
    }
}
/// Dirichlet distribution over the probability simplex, parameterised by positive
/// concentrations `alpha`.
#[derive(Debug, Clone)]
pub struct DirichletDistribution {
    alpha: Vec<f64>,
    /// Cached `alpha.iter().sum()`.
    alpha_sum: f64,
}
impl DirichletDistribution {
    /// Creates a Dirichlet distribution.
    ///
    /// # Errors
    /// `ProbabilisticError::InvalidParameter` if `alpha` is empty, or the error from
    /// `validate_positive` for the first rejected element (named `alpha[i]`).
    pub fn new(alpha: Vec<f64>) -> Result<Self> {
        if alpha.is_empty() {
            return Err(ProbabilisticError::InvalidParameter {
                parameter: "alpha".to_string(),
                message: "alpha vector cannot be empty".to_string(),
            });
        }
        for (i, &a) in alpha.iter().enumerate() {
            validate_positive(a, &format!("alpha[{}]", i))?;
        }
        let alpha_sum = alpha.iter().sum();
        Ok(Self { alpha, alpha_sum })
    }
    /// Concentration parameters.
    pub fn alpha(&self) -> &Vec<f64> {
        &self.alpha
    }
    /// Log density at a point `x` of the simplex.
    ///
    /// A coordinate `x[i] == 0` sends the density to 0 (`-∞` returned) when `alpha[i] > 1`
    /// and makes it unbounded (`+∞` returned) when `alpha[i] < 1`; if both situations occur,
    /// `-∞` wins. Coordinates with `alpha[i] == 1` contribute nothing at 0.
    ///
    /// # Errors
    /// - `ProbabilisticError::DimensionMismatch` if `x.len() != alpha.len()`.
    /// - `ProbabilisticError::InvalidParameter` if `x` does not sum to 1 within `1e-10`
    ///   (including a NaN sum) or any element is negative.
    pub fn log_pdf(&self, x: &[f64]) -> Result<f64> {
        if x.len() != self.alpha.len() {
            return Err(ProbabilisticError::DimensionMismatch {
                expected: vec![self.alpha.len()],
                actual: vec![x.len()],
                operation: "Dirichlet log_pdf".to_string(),
            });
        }
        let sum: f64 = x.iter().sum();
        // A NaN sum compares false against the tolerance, so it must be rejected explicitly.
        if sum.is_nan() || (sum - 1.0).abs() > 1e-10 {
            return Err(ProbabilisticError::InvalidParameter {
                parameter: "x".to_string(),
                message: format!("x must sum to 1, got sum = {}", sum),
            });
        }
        for (i, &xi) in x.iter().enumerate() {
            if xi < 0.0 {
                return Err(ProbabilisticError::InvalidParameter {
                    parameter: format!("x[{}]", i),
                    message: "elements must be non-negative".to_string(),
                });
            }
        }
        let mut log_prob = gamma_ln(self.alpha_sum);
        let mut unbounded = false;
        for (&a, &xi) in self.alpha.iter().zip(x) {
            log_prob -= gamma_ln(a);
            if xi > 0.0 {
                log_prob += (a - 1.0) * xi.ln();
            } else if a > 1.0 {
                return Ok(f64::NEG_INFINITY);
            } else if a < 1.0 {
                // x^(a - 1) diverges at 0; keep scanning in case a later zero forces -∞.
                unbounded = true;
            }
        }
        Ok(if unbounded { f64::INFINITY } else { log_prob })
    }
    /// Draws a sample by normalising independent `Gamma(alpha[i], 1)` variates.
    ///
    /// # Errors
    /// Propagates any error from constructing or sampling the Gamma distributions.
    pub fn sample<R: Rng>(&self, rng: &mut R) -> Result<Vec<f64>> {
        let mut samples = Vec::with_capacity(self.alpha.len());
        let mut sum = 0.0;
        for &alpha_i in &self.alpha {
            let gamma = GammaDistribution::new(alpha_i, 1.0)?;
            let x = gamma.sample(rng)?;
            samples.push(x);
            sum += x;
        }
        for sample in &mut samples {
            *sample /= sum;
        }
        Ok(samples)
    }
    /// Mean vector `alpha[i] / sum(alpha)`.
    pub fn mean(&self) -> Vec<f64> {
        self.alpha.iter().map(|&a| a / self.alpha_sum).collect()
    }
}
/// Standard Student's t distribution (location 0, scale 1) with `nu` degrees of freedom.
#[derive(Debug, Clone)]
pub struct StudentTDistribution {
    nu: f64,
    /// Log of the normalising constant `Γ((ν + 1)/2) / (√(νπ) · Γ(ν/2))`.
    log_norm: f64,
}
impl StudentTDistribution {
    /// Builds the distribution, rejecting `nu` values refused by `validate_positive`.
    pub fn new(nu: f64) -> Result<Self> {
        validate_positive(nu, "nu")?;
        let log_gamma_upper = gamma_ln((nu + 1.0) / 2.0);
        let log_sqrt_nu_pi = 0.5 * (nu * PI).ln();
        let log_norm = log_gamma_upper - log_sqrt_nu_pi - gamma_ln(nu / 2.0);
        Ok(Self { nu, log_norm })
    }
    /// Density at `x`, obtained by exponentiating [`Self::log_pdf`].
    pub fn pdf(&self, x: f64) -> f64 {
        let log_density = self.log_pdf(x);
        log_density.exp()
    }
    /// Log density: `log_norm − ((ν + 1)/2) · ln(1 + x²/ν)`.
    pub fn log_pdf(&self, x: f64) -> f64 {
        let exponent = (self.nu + 1.0) / 2.0;
        let kernel = (1.0 + x * x / self.nu).ln();
        self.log_norm - exponent * kernel
    }
    /// Draws `Z / √(V/ν)` with `Z` standard normal and `V` chi-squared with `ν` degrees of
    /// freedom (a Gamma with shape `ν/2` and rate `1/2`).
    pub fn sample<R: Rng>(&self, rng: &mut R) -> Result<f64> {
        let numerator = sample_standard_normal(rng);
        let chi_sq = GammaDistribution::new(self.nu / 2.0, 0.5)?.sample(rng)?;
        let scale = (chi_sq / self.nu).sqrt();
        Ok(numerator / scale)
    }
    /// Mean, defined (as 0) only for `ν > 1`.
    pub fn mean(&self) -> Option<f64> {
        (self.nu > 1.0).then_some(0.0)
    }
    /// Variance `ν / (ν − 2)`, reported only for `ν > 2`.
    pub fn variance(&self) -> Option<f64> {
        (self.nu > 2.0).then(|| self.nu / (self.nu - 2.0))
    }
}
#[derive(Debug, Clone)]
pub struct LogNormalDistribution {
mu: f64,
sigma: f64,
}
impl LogNormalDistribution {
pub fn new(mu: f64, sigma: f64) -> Result<Self> {
validate_positive(sigma, "sigma")?;
Ok(Self { mu, sigma })
}
pub fn pdf(&self, x: f64) -> f64 {
if x <= 0.0 {
return 0.0;
}
self.log_pdf(x).exp()
}
pub fn log_pdf(&self, x: f64) -> f64 {
if x <= 0.0 {
return f64::NEG_INFINITY;
}
let log_x = x.ln();
-log_x
- 0.5 * (2.0 * PI).ln()
- self.sigma.ln()
- 0.5 * ((log_x - self.mu) / self.sigma).powi(2)
}
pub fn sample<R: Rng>(&self, rng: &mut R) -> Result<f64> {
let z = sample_standard_normal(rng);
Ok((self.mu + self.sigma * z).exp())
}
pub fn mean(&self) -> f64 {
(self.mu + 0.5 * self.sigma * self.sigma).exp()
}
pub fn variance(&self) -> f64 {
let exp_sigma2 = (self.sigma * self.sigma).exp();
(2.0 * self.mu).exp() * exp_sigma2 * (exp_sigma2 - 1.0)
}
}
#[derive(Debug, Clone)]
pub struct VonMisesDistribution {
mu: f64,
kappa: f64,
}
impl VonMisesDistribution {
pub fn new(mu: f64, kappa: f64) -> Result<Self> {
validate_non_negative(kappa, "kappa")?;
let mu_normalized = mu.rem_euclid(2.0 * PI);
Ok(Self {
mu: mu_normalized,
kappa,
})
}
pub fn pdf(&self, x: f64) -> f64 {
let x_normalized = x.rem_euclid(2.0 * PI);
let i0_kappa = bessel_i0(self.kappa);
(self.kappa * (x_normalized - self.mu).cos()).exp() / (2.0 * PI * i0_kappa)
}
pub fn log_pdf(&self, x: f64) -> f64 {
let x_normalized = x.rem_euclid(2.0 * PI);
let log_i0_kappa = bessel_i0(self.kappa).ln();
self.kappa * (x_normalized - self.mu).cos() - (2.0 * PI).ln() - log_i0_kappa
}
pub fn sample<R: Rng>(&self, rng: &mut R) -> Result<f64> {
if self.kappa < 1e-6 {
return Ok(rng.random::<f64>() * 2.0 * PI);
}
let tau = 1.0 + (1.0 + 4.0 * self.kappa * self.kappa).sqrt();
let rho = (tau - (2.0 * tau).sqrt()) / (2.0 * self.kappa);
let r = (1.0 + rho * rho) / (2.0 * rho);
loop {
let u1: f64 = rng.random();
let u2: f64 = rng.random();
let u3: f64 = rng.random();
if u2 < 1e-10 {
continue;
}
let z = (u1 - 0.5) / u2;
let f = (1.0 + r * z) / (r + z);
if !f.is_finite() || !(-1.0..=1.0).contains(&f) {
continue;
}
let c = self.kappa * (r - f);
if c * (2.0 - c) - u3 > 0.0 || (c > 0.0 && c.ln() - c + 1.0 - u3 >= 0.0) {
let theta = self.mu + f.acos() * if rng.random::<bool>() { 1.0 } else { -1.0 };
return Ok(theta.rem_euclid(2.0 * PI));
}
}
}
pub fn mean_direction(&self) -> f64 {
self.mu
}
pub fn circular_variance(&self) -> f64 {
1.0 - bessel_i1(self.kappa) / bessel_i0(self.kappa)
}
}
fn gamma_ln(x: f64) -> f64 {
if x < 0.5 {
let pi = std::f64::consts::PI;
return (pi / ((pi * x).sin())).ln() - gamma_ln(1.0 - x);
}
if x < 12.0 {
let mut result = 0.0;
let mut x_curr = x;
while x_curr < 12.0 {
result -= x_curr.ln();
x_curr += 1.0;
}
result + gamma_ln_stirling(x_curr)
} else {
gamma_ln_stirling(x)
}
}
/// Stirling's asymptotic series for `lnΓ(x)`, accurate for large `x` (used for `x >= 12`).
///
/// Includes the Bernoulli correction terms through `x⁻⁷`:
/// `1/(12x) − 1/(360x³) + 1/(1260x⁵) − 1/(1680x⁷)`. At `x = 12` the first omitted term,
/// `1/(1188x⁹)`, is about 1.6e-13, versus roughly 3e-9 when truncating after `x⁻³`.
fn gamma_ln_stirling(x: f64) -> f64 {
    let log_sqrt_2pi = 0.5 * (2.0 * std::f64::consts::PI).ln();
    let inv = x.recip();
    let inv2 = inv * inv;
    // Correction series in Horner form in 1/x².
    let series =
        inv * (1.0 / 12.0 - inv2 * (1.0 / 360.0 - inv2 * (1.0 / 1260.0 - inv2 / 1680.0)));
    (x - 0.5) * x.ln() - x + log_sqrt_2pi + series
}
/// Draws one standard-normal variate with the Box–Muller transform.
///
/// Only the cosine output of the transform is used; the sine partner is discarded.
fn sample_standard_normal<R: Rng>(rng: &mut R) -> f64 {
    // `1.0 - u` maps a draw in [0, 1) onto (0, 1] so `ln` stays finite; a zero draw would
    // otherwise give an infinite radius. (Assumes `random::<f64>()` is uniform on [0, 1) —
    // confirm for scirs2_core.)
    let u1 = 1.0 - rng.random::<f64>();
    let u2: f64 = rng.random();
    let radius = (-2.0 * u1.ln()).sqrt();
    let theta = 2.0 * PI * u2;
    radius * theta.cos()
}
/// Modified Bessel function of the first kind of order zero, `I0(x)`.
///
/// Uses the Abramowitz & Stegun polynomial approximations 9.8.1 (`|x| < 3.75`) and 9.8.2
/// (`|x| >= 3.75`). The large-argument branch multiplies by `exp(|x|)` and therefore
/// overflows to infinity once `|x|` exceeds roughly 709.
fn bessel_i0(x: f64) -> f64 {
    // Coefficients listed from the constant term upwards.
    const SMALL: [f64; 7] = [
        1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.0360768, 0.0045813,
    ];
    const LARGE: [f64; 9] = [
        0.39894228,
        0.01328592,
        0.00225319,
        -0.00157565,
        0.00916281,
        -0.02057706,
        0.02635537,
        -0.01647633,
        0.00392377,
    ];
    // Horner's rule, innermost (highest-degree) coefficient first.
    let horner = |t: f64, coeffs: &[f64]| coeffs.iter().rev().fold(0.0, |acc, &c| acc * t + c);

    let magnitude = x.abs();
    if magnitude < 3.75 {
        let ratio = x / 3.75;
        return horner(ratio.powi(2), &SMALL);
    }
    let inv = 3.75 / magnitude;
    (magnitude.exp() / magnitude.sqrt()) * horner(inv, &LARGE)
}
/// Modified Bessel function of the first kind of order one, `I1(x)`.
///
/// Uses the Abramowitz & Stegun polynomial approximations 9.8.3 (`|x| < 3.75`) and 9.8.4
/// (`|x| >= 3.75`), evaluated on `|x|` and then given the sign of `x` (`I1` is odd).
/// The large-argument branch multiplies by `exp(|x|)` and overflows past roughly 709.
fn bessel_i1(x: f64) -> f64 {
    // Coefficients listed from the constant term upwards.
    const SMALL: [f64; 7] = [
        0.5, 0.87890594, 0.51498869, 0.15084934, 0.02658733, 0.00301532, 0.00032411,
    ];
    const LARGE: [f64; 9] = [
        0.39894228,
        -0.03988024,
        -0.00362018,
        0.00163801,
        -0.01031555,
        0.02282967,
        -0.02895312,
        0.01787654,
        -0.00420059,
    ];
    // Horner's rule, innermost (highest-degree) coefficient first.
    let horner = |t: f64, coeffs: &[f64]| coeffs.iter().rev().fold(0.0, |acc, &c| acc * t + c);

    let magnitude = x.abs();
    let unsigned = if magnitude < 3.75 {
        let ratio = x / 3.75;
        magnitude * horner(ratio.powi(2), &SMALL)
    } else {
        let inv = 3.75 / magnitude;
        (magnitude.exp() / magnitude.sqrt()) * horner(inv, &LARGE)
    };
    if x < 0.0 {
        -unsigned
    } else {
        unsigned
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the distributions and numeric helpers in this module.
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::random::thread_rng;
    #[test]
    fn test_beta_distribution_creation() {
        let beta = BetaDistribution::new(2.0, 5.0);
        assert!(beta.is_ok());
        let beta = BetaDistribution::new(-1.0, 5.0);
        assert!(beta.is_err());
        let beta = BetaDistribution::new(2.0, 0.0);
        assert!(beta.is_err());
    }
    #[test]
    fn test_beta_distribution_pdf() {
        let beta =
            BetaDistribution::new(2.0, 5.0).expect("test: valid Beta distribution parameters");
        assert_eq!(beta.pdf(-0.1), 0.0);
        assert_eq!(beta.pdf(1.1), 0.0);
        assert!(beta.pdf(0.5) > 0.0);
        assert!(beta.pdf(0.3) > 0.0);
        // Beta(2, 5) has density 30·x·(1 − x)^4, so f(0.5) = 0.9375.
        assert_relative_eq!(beta.pdf(0.5), 0.9375, epsilon = 1e-6);
        // Both shape parameters exceed 1, so the density vanishes at the endpoints.
        assert_eq!(beta.pdf(0.0), 0.0);
        assert_eq!(beta.pdf(1.0), 0.0);
    }
    #[test]
    fn test_beta_distribution_moments() {
        let beta =
            BetaDistribution::new(2.0, 5.0).expect("test: valid Beta distribution parameters");
        assert_relative_eq!(beta.mean(), 2.0 / 7.0, epsilon = 1e-10);
        let expected_var = (2.0 * 5.0) / (49.0 * 8.0);
        assert_relative_eq!(beta.variance(), expected_var, epsilon = 1e-10);
    }
    #[test]
    fn test_beta_distribution_sampling() {
        let beta =
            BetaDistribution::new(2.0, 5.0).expect("test: valid Beta distribution parameters");
        let mut rng = thread_rng();
        for _ in 0..100 {
            let sample = beta.sample(&mut rng).expect("test: valid Beta sample");
            assert!((0.0..=1.0).contains(&sample));
        }
    }
    #[test]
    fn test_gamma_distribution_creation() {
        let gamma = GammaDistribution::new(2.0, 1.0);
        assert!(gamma.is_ok());
        let gamma = GammaDistribution::new(0.0, 1.0);
        assert!(gamma.is_err());
    }
    #[test]
    fn test_gamma_distribution_pdf() {
        let gamma =
            GammaDistribution::new(2.0, 1.0).expect("test: valid Gamma distribution parameters");
        assert_eq!(gamma.pdf(0.0), 0.0);
        assert_eq!(gamma.pdf(-1.0), 0.0);
        assert!(gamma.pdf(1.0) > 0.0);
        assert!(gamma.pdf(2.0) > 0.0);
        // Gamma(2, 1) has density x·e^(−x).
        assert_relative_eq!(gamma.pdf(1.0), (-1.0f64).exp(), epsilon = 1e-6);
    }
    #[test]
    fn test_gamma_distribution_moments() {
        let gamma =
            GammaDistribution::new(2.0, 1.0).expect("test: valid Gamma distribution parameters");
        assert_relative_eq!(gamma.mean(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(gamma.variance(), 2.0, epsilon = 1e-10);
    }
    #[test]
    fn test_gamma_distribution_sampling() {
        let gamma =
            GammaDistribution::new(2.0, 1.0).expect("test: valid Gamma distribution parameters");
        let mut rng = thread_rng();
        for _ in 0..100 {
            let sample = gamma.sample(&mut rng).expect("test: valid Gamma sample");
            assert!(sample > 0.0);
            assert!(sample.is_finite());
        }
    }
    #[test]
    fn test_dirichlet_distribution_creation() {
        let dir = DirichletDistribution::new(vec![1.0, 2.0, 3.0]);
        assert!(dir.is_ok());
        let dir = DirichletDistribution::new(vec![1.0, -1.0, 3.0]);
        assert!(dir.is_err());
        let dir = DirichletDistribution::new(vec![]);
        assert!(dir.is_err());
    }
    #[test]
    fn test_dirichlet_distribution_mean() {
        let dir = DirichletDistribution::new(vec![1.0, 2.0, 3.0])
            .expect("test: valid Dirichlet distribution parameters");
        let mean = dir.mean();
        assert_eq!(mean.len(), 3);
        assert_relative_eq!(mean[0], 1.0 / 6.0, epsilon = 1e-10);
        assert_relative_eq!(mean[1], 2.0 / 6.0, epsilon = 1e-10);
        assert_relative_eq!(mean[2], 3.0 / 6.0, epsilon = 1e-10);
    }
    #[test]
    fn test_dirichlet_distribution_log_pdf() {
        let uniform = DirichletDistribution::new(vec![1.0, 1.0, 1.0])
            .expect("test: valid Dirichlet distribution parameters");
        // With every alpha = 1 the density is the constant Γ(3) = 2 on the simplex.
        let log_p = uniform
            .log_pdf(&[0.2, 0.3, 0.5])
            .expect("test: valid simplex point");
        assert_relative_eq!(log_p, 2.0f64.ln(), epsilon = 1e-6);
        // Wrong dimension, wrong sum, and a negative coordinate are all rejected.
        assert!(uniform.log_pdf(&[0.5, 0.5]).is_err());
        assert!(uniform.log_pdf(&[0.5, 0.6, 0.1]).is_err());
        assert!(uniform.log_pdf(&[-0.5, 1.0, 0.5]).is_err());
    }
    #[test]
    fn test_dirichlet_distribution_sampling() {
        let dir = DirichletDistribution::new(vec![1.0, 2.0, 3.0])
            .expect("test: valid Dirichlet distribution parameters");
        let mut rng = thread_rng();
        for _ in 0..100 {
            let sample = dir.sample(&mut rng).expect("test: valid Dirichlet sample");
            let sum: f64 = sample.iter().sum();
            assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
            for &x in &sample {
                assert!(x >= 0.0);
            }
        }
    }
    #[test]
    fn test_student_t_distribution_creation() {
        let t = StudentTDistribution::new(5.0);
        assert!(t.is_ok());
        let t = StudentTDistribution::new(0.0);
        assert!(t.is_err());
    }
    #[test]
    fn test_student_t_distribution_pdf() {
        let cauchy =
            StudentTDistribution::new(1.0).expect("test: valid Student-T distribution parameters");
        // With nu = 1 the t distribution is the standard Cauchy, whose density at 0 is 1/π.
        assert_relative_eq!(cauchy.pdf(0.0), 1.0 / PI, epsilon = 1e-6);
        // The density is symmetric about 0.
        assert_eq!(cauchy.pdf(1.3), cauchy.pdf(-1.3));
    }
    #[test]
    fn test_student_t_distribution_moments() {
        let t =
            StudentTDistribution::new(5.0).expect("test: valid Student-T distribution parameters");
        assert_eq!(t.mean(), Some(0.0));
        assert_relative_eq!(
            t.variance().expect("test: variance exists for df>2"),
            5.0 / 3.0,
            epsilon = 1e-10
        );
        let cauchy =
            StudentTDistribution::new(1.0).expect("test: valid Student-T distribution parameters");
        assert_eq!(cauchy.mean(), None);
        assert_eq!(cauchy.variance(), None);
    }
    #[test]
    fn test_student_t_distribution_sampling() {
        let t =
            StudentTDistribution::new(5.0).expect("test: valid Student-T distribution parameters");
        let mut rng = thread_rng();
        for _ in 0..100 {
            let sample = t.sample(&mut rng).expect("test: valid Student-T sample");
            assert!(sample.is_finite());
        }
    }
    #[test]
    fn test_lognormal_distribution_creation() {
        let ln = LogNormalDistribution::new(0.0, 1.0);
        assert!(ln.is_ok());
        let ln = LogNormalDistribution::new(0.0, -1.0);
        assert!(ln.is_err());
    }
    #[test]
    fn test_lognormal_distribution_pdf() {
        let ln = LogNormalDistribution::new(0.0, 1.0)
            .expect("test: valid LogNormal distribution parameters");
        assert_eq!(ln.pdf(0.0), 0.0);
        assert_eq!(ln.pdf(-1.0), 0.0);
        assert!(ln.pdf(1.0) > 0.0);
        // For mu = 0, sigma = 1 the density at x = 1 is 1/√(2π).
        assert_relative_eq!(ln.pdf(1.0), 1.0 / (2.0 * PI).sqrt(), epsilon = 1e-10);
    }
    #[test]
    fn test_lognormal_distribution_sampling() {
        let ln = LogNormalDistribution::new(0.0, 1.0)
            .expect("test: valid LogNormal distribution parameters");
        let mut rng = thread_rng();
        for _ in 0..100 {
            let sample = ln.sample(&mut rng).expect("test: valid LogNormal sample");
            assert!(sample > 0.0);
            assert!(sample.is_finite());
        }
    }
    #[test]
    fn test_von_mises_distribution_creation() {
        let vm = VonMisesDistribution::new(0.0, 1.0);
        assert!(vm.is_ok());
        let vm = VonMisesDistribution::new(0.0, -1.0);
        assert!(vm.is_err());
    }
    #[test]
    fn test_von_mises_distribution_pdf() {
        let vm = VonMisesDistribution::new(0.0, 1.0)
            .expect("test: valid VonMises distribution parameters");
        assert!(vm.pdf(0.0) > 0.0);
        assert!(vm.pdf(PI) > 0.0);
        // The mode is at the mean direction.
        assert!(vm.pdf(0.0) > vm.pdf(PI));
        // kappa = 0 is the uniform distribution on the circle.
        let uniform = VonMisesDistribution::new(0.0, 0.0)
            .expect("test: kappa = 0 is a valid concentration");
        assert_relative_eq!(uniform.pdf(1.0), 1.0 / (2.0 * PI), epsilon = 1e-12);
    }
    #[test]
    fn test_von_mises_distribution_sampling() {
        let vm = VonMisesDistribution::new(0.0, 1.0)
            .expect("test: valid VonMises distribution parameters");
        let mut rng = thread_rng();
        for _ in 0..100 {
            let sample = vm.sample(&mut rng).expect("test: valid VonMises sample");
            assert!(sample.is_finite());
            let normalized = sample.rem_euclid(2.0 * PI);
            assert!((0.0..2.0 * PI).contains(&normalized));
        }
    }
    #[test]
    fn test_helper_functions() {
        let g1 = gamma_ln(1.0);
        assert_relative_eq!(g1, 0.0, epsilon = 1e-6);
        let g2 = gamma_ln(2.0);
        assert_relative_eq!(g2, 0.0, epsilon = 1e-6);
        // Γ(1/2) = √π and Γ(5) = 4! = 24.
        assert_relative_eq!(gamma_ln(0.5), 0.5 * PI.ln(), epsilon = 1e-6);
        assert_relative_eq!(gamma_ln(5.0), 24.0f64.ln(), epsilon = 1e-6);
        let i0_0 = bessel_i0(0.0);
        assert_relative_eq!(i0_0, 1.0, epsilon = 1e-6);
        let i1_0 = bessel_i1(0.0);
        assert_relative_eq!(i1_0, 0.0, epsilon = 1e-6);
        // I0 is even and I1 is odd.
        assert_eq!(bessel_i0(-2.0), bessel_i0(2.0));
        assert_eq!(bessel_i1(-2.0), -bessel_i1(2.0));
    }
    #[test]
    fn test_sample_standard_normal() {
        let mut rng = thread_rng();
        for _ in 0..100 {
            let sample = sample_standard_normal(&mut rng);
            assert!(sample.is_finite());
        }
    }
}