use crate::error::PramanaError;
use crate::math;
use crate::math::erf;
use crate::rng::Rng;
use serde::{Deserialize, Serialize};
use std::f64::consts::{PI, SQRT_2};
/// Common interface for univariate probability distributions.
pub trait Distribution {
    /// Probability density (continuous) or mass (discrete) at `x`.
    fn pdf(&self, x: f64) -> f64;
    /// Cumulative distribution function: P(X <= x).
    fn cdf(&self, x: f64) -> f64;
    /// Distribution mean; NaN where undefined (see `Cauchy`, `StudentT`).
    fn mean(&self) -> f64;
    /// Distribution variance; NaN where undefined.
    fn variance(&self) -> f64;
    /// Draws one random variate using `rng`.
    fn sample(&self, rng: &mut impl Rng) -> f64;
}
/// Normal (Gaussian) distribution parameterized by mean and standard deviation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Normal {
    /// Location parameter (the distribution mean).
    pub mean: f64,
    /// Scale parameter; `Normal::new` rejects non-positive values.
    pub std_dev: f64,
}
impl Normal {
    /// Creates a normal distribution with the given mean and standard
    /// deviation.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when `std_dev` is not
    /// strictly positive.
    pub fn new(mean: f64, std_dev: f64) -> Result<Self, PramanaError> {
        if std_dev <= 0.0 {
            Err(PramanaError::InvalidParameter(
                "std_dev must be positive".into(),
            ))
        } else {
            Ok(Self { mean, std_dev })
        }
    }
}
impl Distribution for Normal {
    /// Gaussian density at `x`.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        let z = (x - self.mean) / self.std_dev;
        let norm = self.std_dev * (2.0 * PI).sqrt();
        (1.0 / norm) * (-0.5 * z * z).exp()
    }
    /// CDF via the error function: Phi(x) = (1 + erf(z)) / 2.
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        let scaled = (x - self.mean) / (self.std_dev * SQRT_2);
        0.5 * (1.0 + erf(scaled))
    }
    /// Mean is the location parameter.
    #[inline]
    fn mean(&self) -> f64 {
        self.mean
    }
    /// Variance is the squared standard deviation.
    #[inline]
    fn variance(&self) -> f64 {
        let s = self.std_dev;
        s * s
    }
    /// Box-Muller transform (cosine branch only; the sine variate is
    /// discarded).
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        // Clamp the first uniform away from zero so ln(u1) stays finite.
        let u1 = rng.next_f64().max(f64::MIN_POSITIVE);
        let u2 = rng.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * PI * u2;
        self.mean + self.std_dev * (radius * angle.cos())
    }
}
/// Continuous uniform distribution on the closed interval `[min, max]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Uniform {
    /// Lower bound; `Uniform::new` requires `min < max`.
    pub min: f64,
    /// Upper bound.
    pub max: f64,
}
impl Uniform {
    /// Creates a uniform distribution over `[min, max]`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` unless `min < max`.
    pub fn new(min: f64, max: f64) -> Result<Self, PramanaError> {
        if min >= max {
            Err(PramanaError::InvalidParameter(
                "min must be less than max".into(),
            ))
        } else {
            Ok(Self { min, max })
        }
    }
}
impl Distribution for Uniform {
    /// Density: constant 1/(max - min) inside `[min, max]`, zero outside.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if (self.min..=self.max).contains(&x) {
            1.0 / (self.max - self.min)
        } else {
            0.0
        }
    }
    /// CDF: linear ramp from 0 at `min` to 1 at `max`.
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        if x < self.min {
            return 0.0;
        }
        if x > self.max {
            return 1.0;
        }
        (x - self.min) / (self.max - self.min)
    }
    /// Midpoint of the interval.
    #[inline]
    fn mean(&self) -> f64 {
        (self.min + self.max) / 2.0
    }
    /// (max - min)^2 / 12.
    #[inline]
    fn variance(&self) -> f64 {
        let width = self.max - self.min;
        width * width / 12.0
    }
    /// Affine transform of a unit uniform draw from `rng`.
    #[inline]
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let width = self.max - self.min;
        self.min + rng.next_f64() * width
    }
}
/// Exponential distribution with rate parameter `lambda` (mean `1/lambda`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Exponential {
    /// Rate parameter; `Exponential::new` rejects non-positive values.
    pub lambda: f64,
}
impl Exponential {
    /// Creates an exponential distribution with rate `lambda`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when `lambda` is not
    /// strictly positive.
    pub fn new(lambda: f64) -> Result<Self, PramanaError> {
        if lambda <= 0.0 {
            Err(PramanaError::InvalidParameter(
                "lambda must be positive".into(),
            ))
        } else {
            Ok(Self { lambda })
        }
    }
}
impl Distribution for Exponential {
    /// Density: `lambda * exp(-lambda * x)` for `x >= 0`, else 0.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            0.0
        } else {
            self.lambda * (-self.lambda * x).exp()
        }
    }
    /// CDF: `1 - exp(-lambda * x)` for `x >= 0`, else 0.
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            0.0
        } else {
            1.0 - (-self.lambda * x).exp()
        }
    }
    /// Mean is `1 / lambda`.
    #[inline]
    fn mean(&self) -> f64 {
        1.0 / self.lambda
    }
    /// Variance is `1 / lambda^2`.
    #[inline]
    fn variance(&self) -> f64 {
        1.0 / (self.lambda * self.lambda)
    }
    /// Inverse-transform sampling: `X = -ln(U) / lambda`.
    ///
    /// The clamp keeps `u` strictly positive so `u.ln()` is finite whatever
    /// the RNG's endpoint convention. The previous form `-(1.0 - u).ln()`
    /// guarded the wrong endpoint: the clamp protected against `u == 0`
    /// (harmless there) while `u == 1.0` would still give `ln(0) = -inf`
    /// and an infinite sample.
    #[inline]
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let u = rng.next_f64().max(f64::MIN_POSITIVE);
        -u.ln() / self.lambda
    }
}
/// Poisson distribution with rate `lambda` (mean and variance both `lambda`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Poisson {
    /// Rate parameter; `Poisson::new` rejects non-positive values.
    pub lambda: f64,
}
impl Poisson {
    /// Creates a Poisson distribution with rate `lambda`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when `lambda` is not
    /// strictly positive.
    pub fn new(lambda: f64) -> Result<Self, PramanaError> {
        if lambda <= 0.0 {
            Err(PramanaError::InvalidParameter(
                "lambda must be positive".into(),
            ))
        } else {
            Ok(Self { lambda })
        }
    }
}
impl Distribution for Poisson {
    /// PMF at `x`; zero for negative or non-integer `x`.
    ///
    /// Evaluated in log space so `k!` cannot overflow.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 || x.fract() != 0.0 {
            return 0.0;
        }
        let k = x as u64;
        let log_pmf = (k as f64) * self.lambda.ln() - self.lambda - ln_factorial(k);
        log_pmf.exp()
    }
    /// CDF: sum of PMF terms for `k = 0..=floor(x)`.
    ///
    /// `ln(k!)` is accumulated incrementally, making the whole sum O(n);
    /// calling `ln_factorial(k)` afresh for each term (as before) made it
    /// O(n^2). Each term is numerically identical to the old computation
    /// because the running sum adds the same logs in the same order.
    fn cdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            return 0.0;
        }
        let n = x.floor() as u64;
        let ln_lambda = self.lambda.ln();
        let mut ln_fact = 0.0; // running ln(k!)
        let mut sum = 0.0;
        for k in 0..=n {
            if k >= 2 {
                ln_fact += (k as f64).ln();
            }
            sum += ((k as f64) * ln_lambda - self.lambda - ln_fact).exp();
        }
        // Floating-point rounding can nudge the partial sum just above 1.
        sum.min(1.0)
    }
    /// Mean equals `lambda`.
    #[inline]
    fn mean(&self) -> f64 {
        self.lambda
    }
    /// Variance equals `lambda`.
    #[inline]
    fn variance(&self) -> f64 {
        self.lambda
    }
    /// Draws a Poisson variate.
    ///
    /// For `lambda > 30` a rounded, zero-clamped normal approximation
    /// N(lambda, lambda) is used; otherwise Knuth's product-of-uniforms
    /// method. The cutoff also keeps `exp(-lambda)` in Knuth's method well
    /// away from underflow.
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        if self.lambda > 30.0 {
            // Normal approximation via one Box-Muller draw.
            let u1 = rng.next_f64().max(f64::MIN_POSITIVE);
            let u2 = rng.next_f64();
            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
            let sample = self.lambda + self.lambda.sqrt() * z;
            return sample.round().max(0.0);
        }
        // Knuth: count uniforms until their product drops below e^{-lambda}.
        let l = (-self.lambda).exp();
        let mut k: u64 = 0;
        let mut p = 1.0;
        loop {
            k += 1;
            p *= rng.next_f64();
            if p <= l {
                break;
            }
        }
        (k - 1) as f64
    }
}
/// Natural log of `n!`, computed as the sum of `ln(i)` for `i = 2..=n`.
///
/// Returns 0 for `n = 0` and `n = 1` (the empty sum), matching
/// `ln(0!) = ln(1!) = 0`.
#[inline]
fn ln_factorial(n: u64) -> f64 {
    (2..=n).map(|i| (i as f64).ln()).sum()
}
/// Binomial distribution: number of successes in `n` independent trials,
/// each succeeding with probability `p`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Binomial {
    /// Number of trials; `Binomial::new` rejects 0.
    pub n: u64,
    /// Per-trial success probability in `[0, 1]`.
    pub p: f64,
}
impl Binomial {
    /// Creates a binomial distribution over `n` trials with success
    /// probability `p`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when `n == 0` or `p` is
    /// outside `[0, 1]`.
    pub fn new(n: u64, p: f64) -> Result<Self, PramanaError> {
        match (n, p) {
            (0, _) => Err(PramanaError::InvalidParameter("n must be positive".into())),
            (_, p) if !(0.0..=1.0).contains(&p) => {
                Err(PramanaError::InvalidParameter("p must be in [0, 1]".into()))
            }
            _ => Ok(Self { n, p }),
        }
    }
}
impl Distribution for Binomial {
    /// PMF at `x`; zero for negative, non-integer, or out-of-range `x`.
    ///
    /// Evaluated in log space so factorials cannot overflow. `p == 0` and
    /// `p == 1` are handled explicitly: `new` allows them, and in log space
    /// they would otherwise produce `0 * ln(0) = 0 * -inf = NaN`.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 || x.fract() != 0.0 || x > self.n as f64 {
            return 0.0;
        }
        let k = x as u64;
        if self.p == 0.0 {
            return if k == 0 { 1.0 } else { 0.0 };
        }
        if self.p == 1.0 {
            return if k == self.n { 1.0 } else { 0.0 };
        }
        let log_pmf = ln_binomial_coeff(self.n, k)
            + (k as f64) * self.p.ln()
            + ((self.n - k) as f64) * (1.0 - self.p).ln();
        log_pmf.exp()
    }
    /// CDF: sum of PMF terms for `k = 0..=floor(x)`.
    ///
    /// Degenerate `p` values short-circuit (all mass at 0 or at `n`), which
    /// also avoids the `0 * -inf = NaN` issue described on `pdf`. A prefix
    /// table of `ln(k!)` is built once so each term is O(1); calling
    /// `ln_binomial_coeff` per term (as before) made the loop O(n^2). Table
    /// entries match `ln_factorial` exactly (same summation order).
    fn cdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            return 0.0;
        }
        if x >= self.n as f64 {
            return 1.0;
        }
        // Here 0 <= x < n.
        if self.p == 0.0 {
            return 1.0; // all mass at k = 0
        }
        if self.p == 1.0 {
            return 0.0; // all mass at k = n, which lies above x
        }
        let upper = x.floor() as u64;
        // ln_fact[i] == ln(i!); index 0 and 1 stay 0.
        let mut ln_fact = vec![0.0; self.n as usize + 1];
        for i in 2..=self.n {
            ln_fact[i as usize] = ln_fact[i as usize - 1] + (i as f64).ln();
        }
        let ln_p = self.p.ln();
        let ln_q = (1.0 - self.p).ln();
        let ln_fact_n = ln_fact[self.n as usize];
        let mut sum = 0.0;
        for k in 0..=upper {
            let log_pmf = ln_fact_n - ln_fact[k as usize] - ln_fact[(self.n - k) as usize]
                + (k as f64) * ln_p
                + ((self.n - k) as f64) * ln_q;
            sum += log_pmf.exp();
        }
        // Floating-point rounding can nudge the partial sum just above 1.
        sum.min(1.0)
    }
    /// Mean is `n * p`.
    #[inline]
    fn mean(&self) -> f64 {
        self.n as f64 * self.p
    }
    /// Variance is `n * p * (1 - p)`.
    #[inline]
    fn variance(&self) -> f64 {
        self.n as f64 * self.p * (1.0 - self.p)
    }
    /// Direct simulation: counts successes over `n` Bernoulli trials (O(n)).
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let mut successes: u64 = 0;
        for _ in 0..self.n {
            if rng.next_f64() < self.p {
                successes += 1;
            }
        }
        successes as f64
    }
}
/// Natural log of the binomial coefficient C(n, k), via log-factorials.
///
/// Returns negative infinity for `k > n`, where the coefficient is 0.
#[inline]
fn ln_binomial_coeff(n: u64, k: u64) -> f64 {
    if k > n {
        f64::NEG_INFINITY
    } else {
        ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k)
    }
}
/// Bernoulli distribution: a single trial succeeding with probability `p`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Bernoulli {
    /// Success probability in `[0, 1]`.
    pub p: f64,
}
impl Bernoulli {
    /// Creates a Bernoulli distribution with success probability `p`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when `p` is outside `[0, 1]`.
    pub fn new(p: f64) -> Result<Self, PramanaError> {
        if (0.0..=1.0).contains(&p) {
            Ok(Self { p })
        } else {
            Err(PramanaError::InvalidParameter("p must be in [0, 1]".into()))
        }
    }
}
impl Distribution for Bernoulli {
    /// PMF: mass `1 - p` at 0 and `p` at 1 (epsilon-tolerant comparison).
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if (x - 0.0).abs() < f64::EPSILON {
            return 1.0 - self.p;
        }
        if (x - 1.0).abs() < f64::EPSILON {
            return self.p;
        }
        0.0
    }
    /// CDF: 0 below 0, `1 - p` on `[0, 1)`, 1 from 1 upward.
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            return 0.0;
        }
        if x < 1.0 {
            return 1.0 - self.p;
        }
        1.0
    }
    /// Mean is `p`.
    #[inline]
    fn mean(&self) -> f64 {
        self.p
    }
    /// Variance is `p * (1 - p)`.
    #[inline]
    fn variance(&self) -> f64 {
        self.p * (1.0 - self.p)
    }
    /// Returns 1.0 with probability `p`, otherwise 0.0.
    #[inline]
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let success = rng.next_f64() < self.p;
        if success { 1.0 } else { 0.0 }
    }
}
/// Gamma distribution in the shape/scale parameterization:
/// mean is `alpha * beta` (see `mean`/`variance` below).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Gamma {
    /// Shape parameter; must be positive.
    pub alpha: f64,
    /// Scale parameter (not rate); must be positive.
    pub beta: f64,
}
impl Gamma {
    /// Creates a gamma distribution with shape `alpha` and scale `beta`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when either parameter is
    /// not strictly positive.
    pub fn new(alpha: f64, beta: f64) -> Result<Self, PramanaError> {
        for (value, name) in [(alpha, "alpha"), (beta, "beta")] {
            if value <= 0.0 {
                return Err(PramanaError::InvalidParameter(format!(
                    "{name} must be positive"
                )));
            }
        }
        Ok(Self { alpha, beta })
    }
}
impl Distribution for Gamma {
    /// Density of Gamma(shape = alpha, scale = beta); zero for `x <= 0`.
    /// Evaluated in log space to avoid overflow of Gamma(alpha).
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            return 0.0;
        }
        let log_density = (self.alpha - 1.0) * x.ln()
            - x / self.beta
            - self.alpha * self.beta.ln()
            - math::ln_gamma(self.alpha);
        log_density.exp()
    }
    /// CDF via the regularized lower incomplete gamma P(alpha, x / beta).
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            0.0
        } else {
            math::regularized_lower_incomplete_gamma(self.alpha, x / self.beta)
        }
    }
    /// Mean: `alpha * beta` (shape/scale parameterization).
    #[inline]
    fn mean(&self) -> f64 {
        self.alpha * self.beta
    }
    /// Variance: `alpha * beta^2`.
    #[inline]
    fn variance(&self) -> f64 {
        self.alpha * self.beta * self.beta
    }
    /// Scales a standard (scale = 1) gamma draw by `beta`.
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let standard = sample_gamma_standard(self.alpha, rng);
        standard * self.beta
    }
}
/// Draws from a standard (scale = 1) gamma distribution with shape `alpha`
/// using the Marsaglia-Tsang rejection method.
///
/// For `alpha < 1` the shape is boosted to `alpha + 1` and the result is
/// multiplied by `u^(1/alpha)`, since the core rejection loop requires a
/// shape of at least 1.
fn sample_gamma_standard(alpha: f64, rng: &mut impl Rng) -> f64 {
    let (effective_alpha, boost) = if alpha < 1.0 {
        // Clamp the uniform away from 0 so powf of 1/alpha stays finite.
        let u = rng.next_f64().max(f64::MIN_POSITIVE);
        (alpha + 1.0, u.powf(1.0 / alpha))
    } else {
        (alpha, 1.0)
    };
    // Marsaglia-Tsang constants: d = a - 1/3, c = 1 / sqrt(9d).
    let d = effective_alpha - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    let sample = loop {
        // Draw a standard normal x via Box-Muller until v = (1 + c*x)^3 > 0.
        let (x, v) = loop {
            let u1 = rng.next_f64().max(f64::MIN_POSITIVE);
            let u2 = rng.next_f64();
            let x = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
            let v = 1.0 + c * x;
            if v > 0.0 {
                break (x, v * v * v);
            }
        };
        let u = rng.next_f64().max(f64::MIN_POSITIVE);
        // Fast "squeeze" acceptance check: avoids the logarithm most of the time.
        if u < 1.0 - 0.0331 * x * x * x * x {
            break d * v;
        }
        // Full logarithmic acceptance check.
        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
            break d * v;
        }
    };
    sample * boost
}
/// Beta distribution on the unit interval with shape parameters
/// `alpha` and `beta`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Beta {
    /// First shape parameter; must be positive.
    pub alpha: f64,
    /// Second shape parameter; must be positive.
    pub beta: f64,
}
impl Beta {
    /// Creates a beta distribution with shape parameters `alpha` and `beta`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when either parameter is
    /// not strictly positive.
    pub fn new(alpha: f64, beta: f64) -> Result<Self, PramanaError> {
        for (value, name) in [(alpha, "alpha"), (beta, "beta")] {
            if value <= 0.0 {
                return Err(PramanaError::InvalidParameter(format!(
                    "{name} must be positive"
                )));
            }
        }
        Ok(Self { alpha, beta })
    }
}
impl Distribution for Beta {
    /// Density on the open interval (0, 1); zero at and beyond the endpoints.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if x <= 0.0 || x >= 1.0 {
            return 0.0;
        }
        let log_density = (self.alpha - 1.0) * x.ln() + (self.beta - 1.0) * (1.0 - x).ln()
            - math::ln_beta(self.alpha, self.beta);
        log_density.exp()
    }
    /// CDF via the regularized incomplete beta function I_x(alpha, beta).
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            0.0
        } else if x >= 1.0 {
            1.0
        } else {
            math::regularized_incomplete_beta(x, self.alpha, self.beta)
        }
    }
    /// Mean: `alpha / (alpha + beta)`.
    #[inline]
    fn mean(&self) -> f64 {
        self.alpha / (self.alpha + self.beta)
    }
    /// Variance: `alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))`.
    #[inline]
    fn variance(&self) -> f64 {
        let total = self.alpha + self.beta;
        self.alpha * self.beta / (total * total * (total + 1.0))
    }
    /// Ratio-of-gammas construction: X / (X + Y) with X ~ Gamma(alpha, 1)
    /// and Y ~ Gamma(beta, 1).
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let gamma_a = sample_gamma_standard(self.alpha, rng);
        let gamma_b = sample_gamma_standard(self.beta, rng);
        gamma_a / (gamma_a + gamma_b)
    }
}
/// Chi-squared distribution with `df` degrees of freedom.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChiSquared {
    /// Degrees of freedom; must be positive.
    pub df: f64,
}
impl ChiSquared {
    /// Creates a chi-squared distribution with `df` degrees of freedom.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when `df` is not strictly
    /// positive.
    pub fn new(df: f64) -> Result<Self, PramanaError> {
        if df <= 0.0 {
            Err(PramanaError::InvalidParameter("df must be positive".into()))
        } else {
            Ok(Self { df })
        }
    }
}
impl Distribution for ChiSquared {
    /// Density of chi-squared with `df` degrees of freedom; zero for `x <= 0`.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            return 0.0;
        }
        let k = self.df / 2.0;
        let log_density = (k - 1.0) * x.ln() - x / 2.0 - k * 2.0_f64.ln() - math::ln_gamma(k);
        log_density.exp()
    }
    /// CDF via the regularized lower incomplete gamma P(df/2, x/2).
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            0.0
        } else {
            math::regularized_lower_incomplete_gamma(self.df / 2.0, x / 2.0)
        }
    }
    /// Mean equals the degrees of freedom.
    #[inline]
    fn mean(&self) -> f64 {
        self.df
    }
    /// Variance is `2 * df`.
    #[inline]
    fn variance(&self) -> f64 {
        2.0 * self.df
    }
    /// Chi-squared(df) is Gamma(shape = df/2, scale = 2).
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        sample_gamma_standard(self.df / 2.0, rng) * 2.0
    }
}
/// Student's t distribution with `df` degrees of freedom.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StudentT {
    /// Degrees of freedom; must be positive.
    pub df: f64,
}
impl StudentT {
    /// Creates a Student's t distribution with `df` degrees of freedom.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when `df` is not strictly
    /// positive.
    pub fn new(df: f64) -> Result<Self, PramanaError> {
        if df <= 0.0 {
            Err(PramanaError::InvalidParameter("df must be positive".into()))
        } else {
            Ok(Self { df })
        }
    }
}
impl Distribution for StudentT {
    /// Density of Student's t with `df` degrees of freedom.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        let a = (self.df + 1.0) / 2.0;
        let log_density = math::ln_gamma(a)
            - 0.5 * (self.df * PI).ln()
            - math::ln_gamma(self.df / 2.0)
            - a * (1.0 + x * x / self.df).ln();
        log_density.exp()
    }
    /// CDF expressed through the regularized incomplete beta function,
    /// split by the sign of `x` to exploit the distribution's symmetry.
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        let tail =
            math::regularized_incomplete_beta(self.df / (self.df + x * x), self.df / 2.0, 0.5);
        if x >= 0.0 { 1.0 - 0.5 * tail } else { 0.5 * tail }
    }
    /// Mean is 0 for `df > 1`; undefined (NaN) otherwise.
    #[inline]
    fn mean(&self) -> f64 {
        if self.df > 1.0 { 0.0 } else { f64::NAN }
    }
    /// Variance: `df / (df - 2)` for `df > 2`; infinite for `1 < df <= 2`;
    /// NaN otherwise.
    #[inline]
    fn variance(&self) -> f64 {
        if self.df > 2.0 {
            self.df / (self.df - 2.0)
        } else if self.df > 1.0 {
            f64::INFINITY
        } else {
            f64::NAN
        }
    }
    /// Samples as Z / sqrt(V / df), where Z is standard normal (Box-Muller)
    /// and V ~ chi-squared(df).
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let u1 = rng.next_f64().max(f64::MIN_POSITIVE);
        let u2 = rng.next_f64();
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
        let chi_sq = sample_gamma_standard(self.df / 2.0, rng) * 2.0;
        z / (chi_sq / self.df).sqrt()
    }
}
/// Fisher-Snedecor F distribution with `d1` and `d2` degrees of freedom.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct FDistribution {
    /// Numerator degrees of freedom; must be positive.
    pub d1: f64,
    /// Denominator degrees of freedom; must be positive.
    pub d2: f64,
}
impl FDistribution {
    /// Creates an F distribution with `d1` and `d2` degrees of freedom.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when either parameter is
    /// not strictly positive.
    pub fn new(d1: f64, d2: f64) -> Result<Self, PramanaError> {
        for (value, name) in [(d1, "d1"), (d2, "d2")] {
            if value <= 0.0 {
                return Err(PramanaError::InvalidParameter(format!(
                    "{name} must be positive"
                )));
            }
        }
        Ok(Self { d1, d2 })
    }
}
impl Distribution for FDistribution {
    /// Density of the F distribution; zero for `x <= 0`.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            return 0.0;
        }
        let a = self.d1 / 2.0;
        let b = self.d2 / 2.0;
        let log_density = a * (self.d1 / self.d2).ln() + (a - 1.0) * x.ln()
            - (a + b) * (1.0 + self.d1 * x / self.d2).ln()
            - math::ln_beta(a, b);
        log_density.exp()
    }
    /// CDF via I_u(d1/2, d2/2) with u = d1 x / (d1 x + d2).
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            return 0.0;
        }
        let u = self.d1 * x / (self.d1 * x + self.d2);
        math::regularized_incomplete_beta(u, self.d1 / 2.0, self.d2 / 2.0)
    }
    /// Mean: `d2 / (d2 - 2)` for `d2 > 2`; NaN otherwise.
    #[inline]
    fn mean(&self) -> f64 {
        if self.d2 > 2.0 {
            self.d2 / (self.d2 - 2.0)
        } else {
            f64::NAN
        }
    }
    /// Variance: defined only for `d2 > 4`; NaN otherwise.
    #[inline]
    fn variance(&self) -> f64 {
        if self.d2 <= 4.0 {
            return f64::NAN;
        }
        2.0 * self.d2 * self.d2 * (self.d1 + self.d2 - 2.0)
            / (self.d1 * (self.d2 - 2.0) * (self.d2 - 2.0) * (self.d2 - 4.0))
    }
    /// Ratio of two scaled chi-squared draws: (X1 / d1) / (X2 / d2).
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let chi1 = sample_gamma_standard(self.d1 / 2.0, rng) * 2.0;
        let chi2 = sample_gamma_standard(self.d2 / 2.0, rng) * 2.0;
        (chi1 / self.d1) / (chi2 / self.d2)
    }
}
/// Cauchy (Lorentz) distribution with location `x0` and scale `gamma`.
/// Its mean and variance are undefined.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Cauchy {
    /// Location parameter (the median and mode).
    pub x0: f64,
    /// Scale parameter (half-width at half-maximum); must be positive.
    pub gamma: f64,
}
impl Cauchy {
    /// Creates a Cauchy distribution with location `x0` and scale `gamma`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when `gamma` is not
    /// strictly positive.
    pub fn new(x0: f64, gamma: f64) -> Result<Self, PramanaError> {
        if gamma <= 0.0 {
            Err(PramanaError::InvalidParameter(
                "gamma must be positive".into(),
            ))
        } else {
            Ok(Self { x0, gamma })
        }
    }
}
impl Distribution for Cauchy {
    /// Lorentzian density centered at `x0` with half-width `gamma`.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        let t = (x - self.x0) / self.gamma;
        let denom = PI * self.gamma * (1.0 + t * t);
        1.0 / denom
    }
    /// CDF in arctangent form; equals 1/2 at the location parameter.
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        let scaled = (x - self.x0) / self.gamma;
        (1.0 / PI) * scaled.atan() + 0.5
    }
    /// The Cauchy mean is undefined.
    #[inline]
    fn mean(&self) -> f64 {
        f64::NAN
    }
    /// The Cauchy variance is undefined.
    #[inline]
    fn variance(&self) -> f64 {
        f64::NAN
    }
    /// Inverse-transform sampling through the tangent function.
    #[inline]
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let u = rng.next_f64();
        self.x0 + self.gamma * (PI * (u - 0.5)).tan()
    }
}
/// Weibull distribution with shape `k` and scale `lambda`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Weibull {
    /// Shape parameter; must be positive.
    pub k: f64,
    /// Scale parameter; must be positive.
    pub lambda: f64,
}
impl Weibull {
    /// Creates a Weibull distribution with shape `k` and scale `lambda`.
    ///
    /// # Errors
    /// Returns `PramanaError::InvalidParameter` when either parameter is
    /// not strictly positive.
    pub fn new(k: f64, lambda: f64) -> Result<Self, PramanaError> {
        for (value, name) in [(k, "k"), (lambda, "lambda")] {
            if value <= 0.0 {
                return Err(PramanaError::InvalidParameter(format!(
                    "{name} must be positive"
                )));
            }
        }
        Ok(Self { k, lambda })
    }
}
impl Distribution for Weibull {
    /// Density of Weibull(k, lambda); zero for `x < 0`.
    ///
    /// `x == 0` is handled explicitly so the boundary value is pinned down
    /// for every shape regime: `1/lambda` for `k == 1`, 0 for `k > 1`, and
    /// +inf for `k < 1`.
    #[inline]
    fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            return 0.0;
        }
        if x == 0.0 {
            return if self.k == 1.0 {
                1.0 / self.lambda
            } else if self.k > 1.0 {
                0.0
            } else {
                f64::INFINITY
            };
        }
        let z = x / self.lambda;
        (self.k / self.lambda) * z.powf(self.k - 1.0) * (-z.powf(self.k)).exp()
    }
    /// CDF: `1 - exp(-(x / lambda)^k)` for `x > 0`, else 0.
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            return 0.0;
        }
        1.0 - (-(x / self.lambda).powf(self.k)).exp()
    }
    /// Mean: `lambda * Gamma(1 + 1/k)`, computed via `exp(ln_gamma)`.
    #[inline]
    fn mean(&self) -> f64 {
        self.lambda * math::ln_gamma(1.0 + 1.0 / self.k).exp()
    }
    /// Variance: `lambda^2 * (Gamma(1 + 2/k) - Gamma(1 + 1/k)^2)`.
    #[inline]
    fn variance(&self) -> f64 {
        let g1 = math::ln_gamma(1.0 + 1.0 / self.k).exp();
        let g2 = math::ln_gamma(1.0 + 2.0 / self.k).exp();
        self.lambda * self.lambda * (g2 - g1 * g1)
    }
    /// Inverse-transform sampling: `X = lambda * (-ln U)^(1/k)`.
    ///
    /// The clamp keeps `u` strictly positive so `u.ln()` is finite whatever
    /// the RNG's endpoint convention. The previous form `(1.0 - u).ln()`
    /// guarded the wrong endpoint: `u == 1.0` would still give
    /// `ln(0) = -inf` and an infinite sample.
    #[inline]
    fn sample(&self, rng: &mut impl Rng) -> f64 {
        let u = rng.next_f64().max(f64::MIN_POSITIVE);
        self.lambda * (-u.ln()).powf(1.0 / self.k)
    }
}
/// Multivariate normal distribution.
///
/// The Cholesky factor and log-determinant of the covariance are computed
/// once in `new` and cached for density evaluation and sampling.
///
/// NOTE(review): `Deserialize` fills the cached private fields directly from
/// the input and bypasses `new`'s validation — deserialized data is trusted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct MultivariateNormal {
    /// Mean vector of length `dim`.
    pub mean: Vec<f64>,
    /// Covariance matrix (`dim` x `dim`, symmetric positive-definite).
    pub covariance: Vec<Vec<f64>>,
    // Cached Cholesky factor of the covariance (from `hisab::num::cholesky`).
    cholesky_l: Vec<Vec<f64>>,
    // Cached ln(det(covariance)), derived from the Cholesky diagonal.
    log_det: f64,
    // Cached dimensionality; kept equal to `mean.len()`.
    dim: usize,
}
impl MultivariateNormal {
    /// Builds a multivariate normal from a mean vector and covariance matrix.
    ///
    /// Validates that the covariance is square with side `mean.len()`,
    /// symmetric to within 1e-10, and positive-definite (established by a
    /// successful Cholesky factorization), then caches the factor and the
    /// log-determinant.
    ///
    /// # Errors
    /// `DimensionMismatch` for shape mismatches; `InvalidParameter` for an
    /// empty mean, an asymmetric covariance, or a non-positive-definite
    /// covariance.
    pub fn new(mean: Vec<f64>, covariance: Vec<Vec<f64>>) -> Result<Self, PramanaError> {
        let dim = mean.len();
        if dim == 0 {
            return Err(PramanaError::InvalidParameter(
                "mean vector must be non-empty".into(),
            ));
        }
        if covariance.len() != dim {
            return Err(PramanaError::DimensionMismatch(format!(
                "covariance has {} rows, expected {dim}",
                covariance.len()
            )));
        }
        for (i, row) in covariance.iter().enumerate() {
            if row.len() != dim {
                return Err(PramanaError::DimensionMismatch(format!(
                    "covariance row {i} has length {}, expected {dim}",
                    row.len()
                )));
            }
        }
        // Symmetry check: compare each strictly-upper entry to its mirror.
        for (i, row_i) in covariance.iter().enumerate() {
            for (j, &cov_ij) in row_i.iter().enumerate().skip(i + 1) {
                if (cov_ij - covariance[j][i]).abs() > 1e-10 {
                    return Err(PramanaError::InvalidParameter(format!(
                        "covariance not symmetric: [{i}][{j}]={cov_ij} != [{j}][{i}]={}",
                        covariance[j][i]
                    )));
                }
            }
        }
        let cholesky_l = hisab::num::cholesky(&covariance).map_err(|e| {
            PramanaError::InvalidParameter(format!("covariance is not positive-definite: {e}"))
        })?;
        // ln(det(Sigma)) = 2 * sum of ln of the Cholesky diagonal entries.
        let log_det: f64 = cholesky_l
            .iter()
            .enumerate()
            .map(|(i, row)| row[i].ln())
            .sum::<f64>()
            * 2.0;
        Ok(Self {
            mean,
            covariance,
            cholesky_l,
            log_det,
            dim,
        })
    }
    /// Dimensionality of the distribution (length of the mean vector).
    #[must_use]
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }
    /// Density at `x`; thin wrapper over `log_pdf`.
    ///
    /// # Errors
    /// Propagates the errors of [`MultivariateNormal::log_pdf`].
    #[must_use = "returns the density value"]
    pub fn pdf(&self, x: &[f64]) -> Result<f64, PramanaError> {
        Ok(self.log_pdf(x)?.exp())
    }
    /// Log-density at `x`:
    /// `-0.5 * (dim * ln(2*pi) + ln|Sigma| + (x-mu)^T Sigma^{-1} (x-mu))`.
    ///
    /// # Errors
    /// `DimensionMismatch` when `x.len() != dim`; `ComputationError` if the
    /// triangular solve fails.
    #[must_use = "returns the log-density value"]
    pub fn log_pdf(&self, x: &[f64]) -> Result<f64, PramanaError> {
        if x.len() != self.dim {
            return Err(PramanaError::DimensionMismatch(format!(
                "x has length {}, expected {}",
                x.len(),
                self.dim
            )));
        }
        let diff: Vec<f64> = x.iter().zip(&self.mean).map(|(xi, mi)| xi - mi).collect();
        // Solve Sigma * y = diff via the cached Cholesky factor, so that
        // diff . y is the squared Mahalanobis distance.
        let y = hisab::num::cholesky_solve(&self.cholesky_l, &diff)
            .map_err(|e| PramanaError::ComputationError(format!("cholesky_solve failed: {e}")))?;
        let mahalanobis_sq: f64 = diff.iter().zip(&y).map(|(d, yi)| d * yi).sum();
        let d = self.dim as f64;
        Ok(-0.5 * (d * (2.0 * PI).ln() + self.log_det + mahalanobis_sq))
    }
    /// Draws one sample: `mu + L * z` with `z` a vector of independent
    /// standard normals and `L` the cached Cholesky factor.
    pub fn sample(&self, rng: &mut impl Rng) -> Vec<f64> {
        let mut z = Vec::with_capacity(self.dim);
        // Box-Muller generates normals in pairs; for odd `dim` the final
        // iteration keeps only the cosine variate, so z.len() == dim.
        let mut i = 0;
        while i < self.dim {
            let u1 = rng.next_f64().max(f64::MIN_POSITIVE);
            let u2 = rng.next_f64();
            let r = (-2.0 * u1.ln()).sqrt();
            let theta = 2.0 * PI * u2;
            z.push(r * theta.cos());
            if i + 1 < self.dim {
                z.push(r * theta.sin());
            }
            i += 2;
        }
        let mut result = self.mean.clone();
        // result[i] += (L * z)[i]; the zip against z[..=i] truncates row i to
        // its first i+1 entries, which suffices for a lower-triangular L.
        for (i, (res_i, row)) in result.iter_mut().zip(&self.cholesky_l).enumerate() {
            let sum: f64 = row.iter().zip(&z[..=i]).map(|(l, zi)| l * zi).sum();
            *res_i += sum;
        }
        result
    }
    /// Borrow of the mean vector.
    #[must_use]
    #[inline]
    pub fn mean_vec(&self) -> &[f64] {
        &self.mean
    }
    /// Borrow of the covariance matrix rows.
    #[must_use]
    #[inline]
    pub fn covariance_matrix(&self) -> &[Vec<f64>] {
        &self.covariance
    }
}
#[cfg(test)]
mod tests {
    // Unit tests: analytic spot checks against closed-form values, sampling
    // sanity/convergence checks with a seeded RNG, and serde round-trips for
    // every distribution.
    use super::*;
    use crate::rng::SimpleRng;
    // --- Normal ---
    #[test]
    fn normal_pdf_at_mean() {
        let n = Normal::new(0.0, 1.0).unwrap();
        let expected = 1.0 / (2.0 * PI).sqrt();
        assert!((n.pdf(0.0) - expected).abs() < 1e-10);
    }
    #[test]
    fn normal_cdf_at_mean() {
        let n = Normal::new(0.0, 1.0).unwrap();
        assert!((n.cdf(0.0) - 0.5).abs() < 1e-6);
    }
    #[test]
    fn normal_invalid_std_dev() {
        assert!(Normal::new(0.0, 0.0).is_err());
        assert!(Normal::new(0.0, -1.0).is_err());
    }
    // --- Uniform / Exponential / discrete distributions ---
    #[test]
    fn uniform_pdf_and_cdf() {
        let u = Uniform::new(0.0, 10.0).unwrap();
        assert!((u.pdf(5.0) - 0.1).abs() < 1e-10);
        assert!((u.cdf(5.0) - 0.5).abs() < 1e-10);
        assert_eq!(u.pdf(-1.0), 0.0);
        assert_eq!(u.cdf(-1.0), 0.0);
        assert_eq!(u.cdf(11.0), 1.0);
    }
    #[test]
    fn exponential_mean_and_variance() {
        let e = Exponential::new(2.0).unwrap();
        assert!((e.mean() - 0.5).abs() < 1e-10);
        assert!((e.variance() - 0.25).abs() < 1e-10);
    }
    #[test]
    fn poisson_pmf() {
        let p = Poisson::new(2.0).unwrap();
        // P(X = 3) for lambda = 2: 2^3 e^{-2} / 3!
        let expected = 8.0 * std::f64::consts::E.powf(-2.0) / 6.0;
        assert!((p.pdf(3.0) - expected).abs() < 1e-10);
    }
    #[test]
    fn binomial_mean_and_variance() {
        let b = Binomial::new(10, 0.3).unwrap();
        assert!((b.mean() - 3.0).abs() < 1e-10);
        assert!((b.variance() - 2.1).abs() < 1e-10);
    }
    #[test]
    fn bernoulli_pdf() {
        let b = Bernoulli::new(0.7).unwrap();
        assert!((b.pdf(1.0) - 0.7).abs() < 1e-10);
        assert!((b.pdf(0.0) - 0.3).abs() < 1e-10);
    }
    #[test]
    fn normal_sample_finite() {
        let n = Normal::new(0.0, 1.0).unwrap();
        let mut rng = SimpleRng::new(42);
        for _ in 0..1000 {
            let s = n.sample(&mut rng);
            assert!(s.is_finite());
        }
    }
    // --- Serde round-trips ---
    #[test]
    fn serde_roundtrip_normal() {
        let n = Normal::new(1.5, 2.3).unwrap();
        let json = serde_json::to_string(&n).unwrap();
        let n2: Normal = serde_json::from_str(&json).unwrap();
        assert_eq!(n.mean, n2.mean);
        assert_eq!(n.std_dev, n2.std_dev);
    }
    #[test]
    fn serde_roundtrip_uniform() {
        let u = Uniform::new(-1.0, 5.0).unwrap();
        let json = serde_json::to_string(&u).unwrap();
        let u2: Uniform = serde_json::from_str(&json).unwrap();
        assert_eq!(u.min, u2.min);
        assert_eq!(u.max, u2.max);
    }
    #[test]
    fn serde_roundtrip_exponential() {
        let e = Exponential::new(2.5).unwrap();
        let json = serde_json::to_string(&e).unwrap();
        let e2: Exponential = serde_json::from_str(&json).unwrap();
        assert_eq!(e.lambda, e2.lambda);
    }
    #[test]
    fn serde_roundtrip_poisson() {
        let p = Poisson::new(3.5).unwrap();
        let json = serde_json::to_string(&p).unwrap();
        let p2: Poisson = serde_json::from_str(&json).unwrap();
        assert_eq!(p.lambda, p2.lambda);
    }
    #[test]
    fn serde_roundtrip_binomial() {
        let b = Binomial::new(20, 0.4).unwrap();
        let json = serde_json::to_string(&b).unwrap();
        let b2: Binomial = serde_json::from_str(&json).unwrap();
        assert_eq!(b.n, b2.n);
        assert_eq!(b.p, b2.p);
    }
    #[test]
    fn serde_roundtrip_bernoulli() {
        let b = Bernoulli::new(0.7).unwrap();
        let json = serde_json::to_string(&b).unwrap();
        let b2: Bernoulli = serde_json::from_str(&json).unwrap();
        assert_eq!(b.p, b2.p);
    }
    // Exercises the normal-approximation branch of Poisson::sample.
    #[test]
    fn poisson_large_lambda_sample() {
        let p = Poisson::new(100.0).unwrap();
        let mut rng = SimpleRng::new(42);
        let mut sum = 0.0;
        let n = 10_000;
        for _ in 0..n {
            let s = p.sample(&mut rng);
            assert!(s >= 0.0);
            assert!(s.is_finite());
            sum += s;
        }
        let sample_mean = sum / n as f64;
        assert!(
            (sample_mean - 100.0).abs() < 5.0,
            "sample mean {sample_mean} too far from lambda=100"
        );
    }
    // --- Gamma ---
    #[test]
    fn gamma_pdf_known() {
        let g = Gamma::new(1.0, 1.0).unwrap();
        assert!((g.pdf(1.0) - (-1.0_f64).exp()).abs() < 1e-8);
    }
    #[test]
    fn gamma_cdf_known() {
        let g = Gamma::new(1.0, 1.0).unwrap();
        let expected = 1.0 - (-1.0_f64).exp();
        assert!((g.cdf(1.0) - expected).abs() < 1e-6, "got {}", g.cdf(1.0));
    }
    #[test]
    fn gamma_mean_and_variance() {
        let g = Gamma::new(3.0, 2.0).unwrap();
        assert!((g.mean() - 6.0).abs() < 1e-10);
        assert!((g.variance() - 12.0).abs() < 1e-10);
    }
    #[test]
    fn gamma_sample_mean() {
        let g = Gamma::new(5.0, 2.0).unwrap();
        let mut rng = SimpleRng::new(42);
        let n = 50_000;
        let sum: f64 = (0..n).map(|_| g.sample(&mut rng)).sum();
        let sample_mean = sum / n as f64;
        assert!(
            (sample_mean - 10.0).abs() < 0.5,
            "sample mean {sample_mean} too far from 10"
        );
    }
    #[test]
    fn gamma_invalid_params() {
        assert!(Gamma::new(0.0, 1.0).is_err());
        assert!(Gamma::new(1.0, 0.0).is_err());
        assert!(Gamma::new(-1.0, 1.0).is_err());
    }
    // Exercises the alpha < 1 boosting path of sample_gamma_standard.
    #[test]
    fn gamma_small_alpha_sample() {
        let g = Gamma::new(0.1, 1.0).unwrap();
        let mut rng = SimpleRng::new(42);
        let n = 10_000;
        let mut sum = 0.0;
        for _ in 0..n {
            let s = g.sample(&mut rng);
            assert!(s >= 0.0);
            assert!(s.is_finite());
            sum += s;
        }
        let sample_mean = sum / n as f64;
        assert!(
            (sample_mean - 0.1).abs() < 0.05,
            "sample mean {sample_mean} too far from 0.1"
        );
    }
    #[test]
    fn serde_roundtrip_gamma() {
        let g = Gamma::new(3.0, 2.0).unwrap();
        let json = serde_json::to_string(&g).unwrap();
        let g2: Gamma = serde_json::from_str(&json).unwrap();
        assert_eq!(g.alpha, g2.alpha);
        assert_eq!(g.beta, g2.beta);
    }
    // --- Beta ---
    #[test]
    fn beta_mean_and_variance() {
        let b = Beta::new(2.0, 5.0).unwrap();
        assert!((b.mean() - 2.0 / 7.0).abs() < 1e-10);
        let expected_var = 2.0 * 5.0 / (49.0 * 8.0);
        assert!((b.variance() - expected_var).abs() < 1e-10);
    }
    #[test]
    fn beta_cdf_endpoints() {
        let b = Beta::new(2.0, 3.0).unwrap();
        assert_eq!(b.cdf(0.0), 0.0);
        assert_eq!(b.cdf(1.0), 1.0);
    }
    #[test]
    fn beta_uniform_is_beta_1_1() {
        let b = Beta::new(1.0, 1.0).unwrap();
        assert!((b.pdf(0.5) - 1.0).abs() < 1e-8);
        assert!((b.cdf(0.5) - 0.5).abs() < 1e-6);
    }
    #[test]
    fn beta_sample_in_range() {
        let b = Beta::new(2.0, 5.0).unwrap();
        let mut rng = SimpleRng::new(42);
        for _ in 0..10_000 {
            let s = b.sample(&mut rng);
            assert!((0.0..=1.0).contains(&s), "out of range: {s}");
        }
    }
    #[test]
    fn beta_small_alpha_sample() {
        let b = Beta::new(0.5, 0.5).unwrap();
        let mut rng = SimpleRng::new(42);
        for _ in 0..10_000 {
            let s = b.sample(&mut rng);
            assert!((0.0..=1.0).contains(&s), "out of range: {s}");
        }
    }
    #[test]
    fn serde_roundtrip_beta() {
        let b = Beta::new(2.0, 5.0).unwrap();
        let json = serde_json::to_string(&b).unwrap();
        let b2: Beta = serde_json::from_str(&json).unwrap();
        assert_eq!(b.alpha, b2.alpha);
        assert_eq!(b.beta, b2.beta);
    }
    // --- ChiSquared ---
    #[test]
    fn chi_squared_mean_and_variance() {
        let c = ChiSquared::new(5.0).unwrap();
        assert!((c.mean() - 5.0).abs() < 1e-10);
        assert!((c.variance() - 10.0).abs() < 1e-10);
    }
    #[test]
    fn chi_squared_cdf_known() {
        let c = ChiSquared::new(2.0).unwrap();
        let expected = 1.0 - (-1.0_f64).exp();
        assert!((c.cdf(2.0) - expected).abs() < 1e-4, "got {}", c.cdf(2.0));
    }
    #[test]
    fn chi_squared_sample_mean() {
        let c = ChiSquared::new(10.0).unwrap();
        let mut rng = SimpleRng::new(42);
        let n = 50_000;
        let sum: f64 = (0..n).map(|_| c.sample(&mut rng)).sum();
        let sample_mean = sum / n as f64;
        assert!(
            (sample_mean - 10.0).abs() < 0.5,
            "sample mean {sample_mean}"
        );
    }
    #[test]
    fn serde_roundtrip_chi_squared() {
        let c = ChiSquared::new(5.0).unwrap();
        let json = serde_json::to_string(&c).unwrap();
        let c2: ChiSquared = serde_json::from_str(&json).unwrap();
        assert_eq!(c.df, c2.df);
    }
    // --- StudentT ---
    #[test]
    fn student_t_symmetric_pdf() {
        let t = StudentT::new(5.0).unwrap();
        assert!((t.pdf(1.0) - t.pdf(-1.0)).abs() < 1e-10);
        assert!((t.cdf(0.0) - 0.5).abs() < 1e-6, "CDF at 0 = {}", t.cdf(0.0));
    }
    #[test]
    fn student_t_mean_and_variance() {
        let t = StudentT::new(5.0).unwrap();
        assert!((t.mean()).abs() < 1e-10);
        assert!((t.variance() - 5.0 / 3.0).abs() < 1e-10);
        let t1 = StudentT::new(1.0).unwrap();
        assert!(t1.mean().is_nan());
        assert!(t1.variance().is_nan());
        let t15 = StudentT::new(1.5).unwrap();
        assert!((t15.mean()).abs() < 1e-10);
        assert!(t15.variance().is_infinite());
    }
    #[test]
    fn student_t_sample_symmetric() {
        let t = StudentT::new(10.0).unwrap();
        let mut rng = SimpleRng::new(42);
        let n = 50_000;
        let sum: f64 = (0..n).map(|_| t.sample(&mut rng)).sum();
        let sample_mean = sum / n as f64;
        assert!(
            sample_mean.abs() < 0.1,
            "sample mean {sample_mean} should be near 0"
        );
    }
    #[test]
    fn serde_roundtrip_student_t() {
        let t = StudentT::new(5.0).unwrap();
        let json = serde_json::to_string(&t).unwrap();
        let t2: StudentT = serde_json::from_str(&json).unwrap();
        assert_eq!(t.df, t2.df);
    }
    // --- FDistribution ---
    #[test]
    fn f_distribution_mean() {
        let f = FDistribution::new(5.0, 10.0).unwrap();
        assert!((f.mean() - 10.0 / 8.0).abs() < 1e-10);
        let f2 = FDistribution::new(5.0, 2.0).unwrap();
        assert!(f2.mean().is_nan());
    }
    #[test]
    fn f_distribution_variance() {
        let f = FDistribution::new(5.0, 10.0).unwrap();
        let expected = 2.0 * 100.0 * 13.0 / (5.0 * 64.0 * 6.0);
        assert!(
            (f.variance() - expected).abs() < 1e-10,
            "variance = {}",
            f.variance()
        );
        let f2 = FDistribution::new(5.0, 4.0).unwrap();
        assert!(f2.variance().is_nan());
    }
    #[test]
    fn f_distribution_cdf_nonneg() {
        let f = FDistribution::new(5.0, 10.0).unwrap();
        assert_eq!(f.cdf(0.0), 0.0);
        assert!(f.cdf(100.0) > 0.99);
    }
    #[test]
    fn f_distribution_sample_positive() {
        let f = FDistribution::new(5.0, 10.0).unwrap();
        let mut rng = SimpleRng::new(42);
        for _ in 0..10_000 {
            let s = f.sample(&mut rng);
            assert!(s > 0.0, "F sample must be positive: {s}");
        }
    }
    #[test]
    fn serde_roundtrip_f_distribution() {
        let f = FDistribution::new(5.0, 10.0).unwrap();
        let json = serde_json::to_string(&f).unwrap();
        let f2: FDistribution = serde_json::from_str(&json).unwrap();
        assert_eq!(f.d1, f2.d1);
        assert_eq!(f.d2, f2.d2);
    }
    // --- Cauchy ---
    #[test]
    fn cauchy_pdf_and_cdf() {
        let c = Cauchy::new(0.0, 1.0).unwrap();
        assert!((c.pdf(0.0) - 1.0 / PI).abs() < 1e-10);
        assert!((c.cdf(0.0) - 0.5).abs() < 1e-10);
        assert!((c.cdf(1.0) + c.cdf(-1.0) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn cauchy_undefined_moments() {
        let c = Cauchy::new(0.0, 1.0).unwrap();
        assert!(c.mean().is_nan());
        assert!(c.variance().is_nan());
    }
    #[test]
    fn cauchy_sample_finite() {
        let c = Cauchy::new(0.0, 1.0).unwrap();
        let mut rng = SimpleRng::new(42);
        for _ in 0..10_000 {
            let s = c.sample(&mut rng);
            assert!(s.is_finite());
        }
    }
    #[test]
    fn serde_roundtrip_cauchy() {
        let c = Cauchy::new(2.0, 3.0).unwrap();
        let json = serde_json::to_string(&c).unwrap();
        let c2: Cauchy = serde_json::from_str(&json).unwrap();
        assert_eq!(c.x0, c2.x0);
        assert_eq!(c.gamma, c2.gamma);
    }
    // --- Weibull ---
    #[test]
    fn weibull_exponential_case() {
        // Weibull with k = 1 reduces to Exponential(1/lambda).
        let w = Weibull::new(1.0, 2.0).unwrap();
        assert!((w.mean() - 2.0).abs() < 1e-8);
        let expected = 1.0 - (-1.0_f64).exp();
        assert!((w.cdf(2.0) - expected).abs() < 1e-8);
    }
    #[test]
    fn weibull_mean_and_variance() {
        let w = Weibull::new(2.0, 1.0).unwrap();
        let expected_mean = std::f64::consts::PI.sqrt() / 2.0;
        assert!(
            (w.mean() - expected_mean).abs() < 1e-6,
            "mean = {}",
            w.mean()
        );
    }
    #[test]
    fn weibull_sample_nonnegative() {
        let w = Weibull::new(2.0, 3.0).unwrap();
        let mut rng = SimpleRng::new(42);
        for _ in 0..10_000 {
            let s = w.sample(&mut rng);
            assert!(s >= 0.0, "Weibull sample must be non-negative: {s}");
        }
    }
    #[test]
    fn serde_roundtrip_weibull() {
        let w = Weibull::new(2.0, 3.0).unwrap();
        let json = serde_json::to_string(&w).unwrap();
        let w2: Weibull = serde_json::from_str(&json).unwrap();
        assert_eq!(w.k, w2.k);
        assert_eq!(w.lambda, w2.lambda);
    }
    // --- MultivariateNormal ---
    #[test]
    fn mvn_1d_matches_univariate() {
        let mvn = MultivariateNormal::new(vec![2.0], vec![vec![4.0]]).unwrap();
        let n = Normal::new(2.0, 2.0).unwrap(); let x = 3.0;
        let mvn_pdf = mvn.pdf(&[x]).unwrap();
        let n_pdf = n.pdf(x);
        assert!(
            (mvn_pdf - n_pdf).abs() < 1e-10,
            "mvn_pdf={mvn_pdf}, n_pdf={n_pdf}"
        );
    }
    #[test]
    fn mvn_2d_pdf_at_mean() {
        let cov = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let mvn = MultivariateNormal::new(vec![0.0, 0.0], cov).unwrap();
        let pdf_at_mean = mvn.pdf(&[0.0, 0.0]).unwrap();
        let expected = 1.0 / (2.0 * PI); assert!(
            (pdf_at_mean - expected).abs() < 1e-10,
            "pdf={pdf_at_mean}, expected={expected}"
        );
    }
    #[test]
    fn mvn_sample_mean_convergence() {
        let mean = vec![1.0, -2.0];
        let cov = vec![vec![1.0, 0.3], vec![0.3, 2.0]];
        let mvn = MultivariateNormal::new(mean.clone(), cov).unwrap();
        let mut rng = SimpleRng::new(42);
        let n = 50_000;
        let mut sum = [0.0; 2];
        for _ in 0..n {
            let s = mvn.sample(&mut rng);
            assert_eq!(s.len(), 2);
            sum[0] += s[0];
            sum[1] += s[1];
        }
        let sample_mean = [sum[0] / n as f64, sum[1] / n as f64];
        assert!(
            (sample_mean[0] - 1.0).abs() < 0.1,
            "mean[0]={}",
            sample_mean[0]
        );
        assert!(
            (sample_mean[1] + 2.0).abs() < 0.1,
            "mean[1]={}",
            sample_mean[1]
        );
    }
    #[test]
    fn mvn_invalid_params() {
        assert!(MultivariateNormal::new(vec![], vec![]).is_err());
        assert!(MultivariateNormal::new(vec![0.0], vec![vec![1.0, 0.0]]).is_err());
        assert!(
            MultivariateNormal::new(vec![0.0, 0.0], vec![vec![1.0, 0.5], vec![0.0, 1.0]]).is_err()
        );
        assert!(
            MultivariateNormal::new(vec![0.0, 0.0], vec![vec![1.0, 2.0], vec![2.0, 1.0]]).is_err()
        );
    }
    #[test]
    fn mvn_dimension_mismatch_pdf() {
        let mvn =
            MultivariateNormal::new(vec![0.0, 0.0], vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();
        assert!(mvn.pdf(&[1.0]).is_err());
        assert!(mvn.pdf(&[1.0, 2.0, 3.0]).is_err());
    }
    #[test]
    fn mvn_log_pdf_consistency() {
        let cov = vec![vec![2.0, 0.5], vec![0.5, 3.0]];
        let mvn = MultivariateNormal::new(vec![1.0, -1.0], cov).unwrap();
        let x = [2.0, 0.0];
        let log_pdf = mvn.log_pdf(&x).unwrap();
        let pdf = mvn.pdf(&x).unwrap();
        assert!(
            (log_pdf - pdf.ln()).abs() < 1e-10,
            "log_pdf={log_pdf}, ln(pdf)={}",
            pdf.ln()
        );
    }
    #[test]
    fn serde_roundtrip_mvn() {
        let mvn =
            MultivariateNormal::new(vec![1.0, 2.0], vec![vec![4.0, 1.0], vec![1.0, 9.0]]).unwrap();
        let json = serde_json::to_string(&mvn).unwrap();
        let mvn2: MultivariateNormal = serde_json::from_str(&json).unwrap();
        assert_eq!(mvn.mean, mvn2.mean);
        assert_eq!(mvn.covariance, mvn2.covariance);
        assert_eq!(mvn.dim, mvn2.dim);
    }
}