use crate::error::{StatsError, StatsResult};
use crate::traits::{ContinuousDistribution, Distribution};
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
use std::fmt::Debug;
/// Standard normal density `φ(x) = exp(-x² / 2) / √(2π)`.
fn standard_normal_pdf<F: Float>(x: F) -> F {
    let exponent = -x * x / F::from(2.0).expect("F::from should not fail for 2.0");
    let norm = F::from(2.506_628_274_631_001).expect("F::from should not fail for sqrt(2π)");
    exponent.exp() / norm
}
/// Standard normal CDF `Φ(x) = ½ · (1 + erf(x / √2))`.
///
/// Accuracy is limited by the `erf` approximation it delegates to.
fn standard_normal_cdf<F: Float>(x: F) -> F {
    let root_two = F::from(std::f64::consts::SQRT_2).expect("F::from should not fail for SQRT_2");
    let shifted = F::one() + erf(x / root_two);
    F::from(0.5).expect("F::from should not fail for 0.5") * shifted
}
/// Error function via Abramowitz & Stegun formula 7.1.26.
///
/// For `x >= 0`:
/// `erf(x) ≈ 1 − t·(a₁ + t·(a₂ + t·(a₃ + t·(a₄ + t·a₅))))·exp(−x²)`, with `t = 1 / (1 + p·x)`.
/// Negative arguments use the odd symmetry `erf(−x) = −erf(x)`.
/// The approximation's maximum absolute error is about `1.5e-7`.
fn erf<F: Float>(x: F) -> F {
    let negative = x < F::zero();
    let ax = x.abs();
    let p = F::from(0.3275911).expect("F::from should not fail for 0.3275911");
    let t = F::one() / (F::one() + p * ax);
    let a1 = F::from(0.254829592).expect("F::from should not fail for 0.254829592");
    let a2 = F::from(-0.284496736).expect("F::from should not fail for -0.284496736");
    let a3 = F::from(1.421413741).expect("F::from should not fail for 1.421413741");
    let a4 = F::from(-1.453152027).expect("F::from should not fail for -1.453152027");
    let a5 = F::from(1.061405429).expect("F::from should not fail for 1.061405429");
    // Horner's scheme, starting from the innermost coefficient a₅ and working outwards.
    let nested = [a4, a3, a2, a1].iter().fold(a5, |acc, &a| a + t * acc);
    let poly = t * nested;
    let magnitude = F::one() - poly * (-(ax * ax)).exp();
    if negative {
        -magnitude
    } else {
        magnitude
    }
}
fn standard_normal_ppf<F: Float>(p: F) -> F {
let half = F::from(0.5).unwrap_or(F::zero());
if p <= F::zero() {
return F::from(-8.0).unwrap_or(F::neg_infinity());
}
if p >= F::one() {
return F::from(8.0).unwrap_or(F::infinity());
}
if p == half {
return F::zero();
}
let p_low = F::from(0.02425).unwrap_or(F::zero());
let p_high = F::one() - p_low;
if p < p_low {
let q = (-F::from(2.0).unwrap_or(F::zero()) * p.ln()).sqrt();
let c = [
F::from(-7.784894002430293e-03).unwrap_or(F::zero()),
F::from(-3.223964580411365e-01).unwrap_or(F::zero()),
F::from(-2.400758277161838e+00).unwrap_or(F::zero()),
F::from(-2.549732539343734e+00).unwrap_or(F::zero()),
F::from( 4.374664141464968e+00).unwrap_or(F::zero()),
F::from( 2.938163982698783e+00).unwrap_or(F::zero()),
];
let d = [
F::from( 7.784695709041462e-03).unwrap_or(F::zero()),
F::from( 3.224671290700398e-01).unwrap_or(F::zero()),
F::from( 2.445134137142996e+00).unwrap_or(F::zero()),
F::from( 3.754408661907416e+00).unwrap_or(F::zero()),
];
let num = ((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5];
let den = (((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + F::one();
num / den
} else if p <= p_high {
let q = p - half;
let r = q * q;
let a = [
F::from(-3.969683028665376e+01).unwrap_or(F::zero()),
F::from( 2.209460984245205e+02).unwrap_or(F::zero()),
F::from(-2.759285104469687e+02).unwrap_or(F::zero()),
F::from( 1.383577518672690e+02).unwrap_or(F::zero()),
F::from(-3.066479806614716e+01).unwrap_or(F::zero()),
F::from( 2.506628277459239e+00).unwrap_or(F::zero()),
];
let b = [
F::from(-5.447609879822406e+01).unwrap_or(F::zero()),
F::from( 1.615858368580409e+02).unwrap_or(F::zero()),
F::from(-1.556989798598866e+02).unwrap_or(F::zero()),
F::from( 6.680131188771972e+01).unwrap_or(F::zero()),
F::from(-1.328068155288572e+01).unwrap_or(F::zero()),
];
let num = (((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5]) * q;
let den = ((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + F::one();
num / den
} else {
-standard_normal_ppf(F::one() - p)
}
}
/// Normal distribution `N(mu, sigma²)` truncated to the closed interval `[lower, upper]`.
///
/// Either bound may be infinite. For example, `mu = 0`, `lower = 0`, `upper = +∞` gives the
/// half-normal distribution. The standardized bounds and the normalization constant are
/// computed once by [`TruncatedNormal::new`].
///
/// `Φ` is evaluated through the Abramowitz & Stegun 7.1.26 approximation of `erf`
/// (absolute error up to about `1.5e-7`). Intervals far out in a tail, where
/// `Φ(beta) − Φ(alpha)` is of comparable size, are therefore inaccurate or rejected.
#[derive(Debug, Clone)]
pub struct TruncatedNormal<F: Float> {
    /// Mean of the parent (untruncated) normal distribution.
    pub mu: F,
    /// Standard deviation of the parent normal distribution.
    pub sigma: F,
    /// Lower truncation bound (may be `-∞`).
    pub lower: F,
    /// Upper truncation bound (may be `+∞`).
    pub upper: F,
    /// Standardized lower bound `(lower - mu) / sigma`.
    alpha: F,
    /// Standardized upper bound `(upper - mu) / sigma`.
    beta: F,
    /// Normalization constant `Φ(beta) - Φ(alpha)`.
    z: F,
}
impl<F: Float + NumCast + Debug + std::fmt::Display> TruncatedNormal<F> {
    /// Creates a `N(mu, sigma²)` distribution truncated to `[lower, upper]`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidArgument`] if:
    /// * `sigma` is not positive;
    /// * `lower >= upper`;
    /// * `sigma`, `lower` or `upper` is NaN;
    /// * the normalization constant `Φ(beta) − Φ(alpha)` is zero or NaN (e.g. when `mu` is
    ///   NaN).
    pub fn new(mu: F, sigma: F, lower: F, upper: F) -> StatsResult<Self> {
        // Ordered comparisons are all false for NaN, so NaN is tested explicitly;
        // otherwise a NaN parameter would pass validation and poison every result.
        if sigma.is_nan() || sigma <= F::zero() {
            return Err(StatsError::InvalidArgument(
                "sigma must be positive".to_string(),
            ));
        }
        if lower.is_nan() || upper.is_nan() || lower >= upper {
            return Err(StatsError::InvalidArgument(
                "lower bound must be strictly less than upper bound".to_string(),
            ));
        }
        let alpha = (lower - mu) / sigma;
        let beta = (upper - mu) / sigma;
        let z = standard_normal_cdf(beta) - standard_normal_cdf(alpha);
        if z.is_nan() || z <= F::zero() {
            return Err(StatsError::InvalidArgument(
                "Normalization constant (Φ(β) - Φ(α)) is zero; \
                 truncation interval likely too narrow relative to sigma"
                    .to_string(),
            ));
        }
        Ok(Self {
            mu,
            sigma,
            lower,
            upper,
            alpha,
            beta,
            z,
        })
    }

    /// Probability density at `x`; zero outside `[lower, upper]`.
    pub fn pdf(&self, x: F) -> F {
        if x < self.lower || x > self.upper {
            return F::zero();
        }
        let xi = (x - self.mu) / self.sigma;
        standard_normal_pdf(xi) / (self.sigma * self.z)
    }

    /// Natural log of the density at `x`; `-∞` outside `[lower, upper]`.
    ///
    /// Computed directly in log space:
    /// `−xi²/2 − ln√(2π) − ln sigma − ln z`, with `xi = (x − mu)/sigma`.
    /// Points far from `mu` inside a wide interval therefore keep a finite log-density
    /// instead of `ln(0)` from an underflowed `pdf`.
    pub fn logpdf(&self, x: F) -> F {
        if x < self.lower || x > self.upper {
            return F::neg_infinity();
        }
        let xi = (x - self.mu) / self.sigma;
        let half = F::from(0.5).expect("F::from should not fail for 0.5");
        let ln_sqrt_2pi =
            F::from(0.918_938_533_204_672_7).expect("F::from should not fail for ln(sqrt(2π))");
        -half * xi * xi - ln_sqrt_2pi - self.sigma.ln() - self.z.ln()
    }

    /// Cumulative distribution function: `(Φ(xi) − Φ(alpha)) / z` inside the support,
    /// `0` at or below `lower` and `1` at or above `upper`.
    pub fn cdf(&self, x: F) -> F {
        if x <= self.lower {
            return F::zero();
        }
        if x >= self.upper {
            return F::one();
        }
        let xi = (x - self.mu) / self.sigma;
        (standard_normal_cdf(xi) - standard_normal_cdf(self.alpha)) / self.z
    }

    /// Quantile function: `x = mu + sigma · Φ⁻¹(Φ(alpha) + p·z)`.
    ///
    /// The result is clamped to `[lower, upper]`. The approximate `Φ` and `Φ⁻¹` used here
    /// are not exact inverses and can otherwise place quantiles near `p = 0` or `p = 1`
    /// marginally outside the support.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidArgument`] if `p` is NaN or outside `[0, 1]`.
    pub fn ppf(&self, p: F) -> StatsResult<F> {
        if p.is_nan() || p < F::zero() || p > F::one() {
            return Err(StatsError::InvalidArgument(
                "probability p must be in [0, 1]".to_string(),
            ));
        }
        if p == F::zero() {
            return Ok(self.lower);
        }
        if p == F::one() {
            return Ok(self.upper);
        }
        let target = standard_normal_cdf(self.alpha) + p * self.z;
        let x = self.mu + self.sigma * standard_normal_ppf(target);
        Ok(x.max(self.lower).min(self.upper))
    }

    /// Draws `size` samples by inverse-transform sampling: `ppf` applied to uniform draws
    /// in `[0, 1)`.
    ///
    /// With `seed = None`, the generator is seeded from the system clock in nanoseconds,
    /// falling back to `42` if the clock reads before the Unix epoch. The `as u64` cast
    /// truncates, which is harmless for a seed.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TruncatedNormal::ppf`].
    pub fn rvs(&self, size: usize, seed: Option<u64>) -> StatsResult<Array1<F>> {
        let mut rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => {
                let s = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_nanos() as u64)
                    .unwrap_or(42);
                StdRng::seed_from_u64(s)
            }
        };
        let mut samples = Vec::with_capacity(size);
        for _ in 0..size {
            let u: f64 = rng.random();
            let p = F::from(u).expect("F::from should not fail for f64 uniform sample in [0,1)");
            let x = self.ppf(p)?;
            samples.push(x);
        }
        Ok(Array1::from_vec(samples))
    }

    /// Mean: `mu + sigma · (φ(alpha) − φ(beta)) / z`.
    pub fn mean(&self) -> F {
        let phi_alpha = standard_normal_pdf(self.alpha);
        let phi_beta = standard_normal_pdf(self.beta);
        self.mu + self.sigma * (phi_alpha - phi_beta) / self.z
    }

    /// Variance: `sigma² · [1 + (alpha·φ(alpha) − beta·φ(beta))/z − ((φ(alpha) − φ(beta))/z)²]`.
    ///
    /// Infinite bounds are supported: `t·φ(t) → 0` as `t → ±∞`, but IEEE arithmetic
    /// evaluates `∞ · 0` as NaN, so those terms are taken as zero explicitly.
    /// The result is clamped at zero. Cancellation in very narrow intervals can otherwise
    /// make the expression marginally negative.
    pub fn var(&self) -> F {
        let phi_alpha = standard_normal_pdf(self.alpha);
        let phi_beta = standard_normal_pdf(self.beta);
        let ratio = (phi_alpha - phi_beta) / self.z;
        let tail_term = |t: F, phi_t: F| if t.is_infinite() { F::zero() } else { t * phi_t };
        let correction =
            (tail_term(self.alpha, phi_alpha) - tail_term(self.beta, phi_beta)) / self.z;
        let v = self.sigma * self.sigma * (F::one() + correction - ratio * ratio);
        if v < F::zero() {
            F::zero()
        } else {
            v
        }
    }
}
/// Exponential distribution with rate `rate`, truncated to `[lower, upper]` with
/// `0 <= lower < upper`.
///
/// By the memoryless property, `X − lower` is an exponential variable truncated to
/// `[0, upper − lower]`. Every quantity is therefore computed from the width
/// `upper − lower` rather than from `exp(−rate · lower)`. This keeps results accurate when
/// `rate · lower` is large, where that exponential underflows or `1 − exp(−rate · lower)`
/// rounds to 1. It also allows `upper = +∞`.
#[derive(Debug, Clone)]
pub struct TruncatedExponential<F: Float> {
    /// Rate parameter `λ` of the parent exponential distribution.
    pub rate: F,
    /// Lower truncation bound (`>= 0`).
    pub lower: F,
    /// Upper truncation bound (may be `+∞`).
    pub upper: F,
    /// `1 − exp(−rate · (upper − lower))`: the mass of an untruncated exponential on
    /// `[0, upper − lower]`.
    z: F,
}
impl<F: Float + NumCast + Debug + std::fmt::Display> TruncatedExponential<F> {
    /// Creates an exponential distribution with rate `rate` truncated to `[lower, upper]`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidArgument`] if:
    /// * `rate` is not positive;
    /// * `lower < 0`;
    /// * `lower >= upper`;
    /// * any argument is NaN;
    /// * the normalization constant evaluates to zero (e.g. `rate · (upper − lower)`
    ///   underflows).
    pub fn new(rate: F, lower: F, upper: F) -> StatsResult<Self> {
        // Explicit NaN checks: ordered comparisons are false for NaN.
        if rate.is_nan() || rate <= F::zero() {
            return Err(StatsError::InvalidArgument(
                "rate must be positive".to_string(),
            ));
        }
        if lower.is_nan() || lower < F::zero() {
            return Err(StatsError::InvalidArgument(
                "lower bound must be >= 0 for exponential distribution".to_string(),
            ));
        }
        if upper.is_nan() || lower >= upper {
            return Err(StatsError::InvalidArgument(
                "lower bound must be strictly less than upper bound".to_string(),
            ));
        }
        // `exp_m1` keeps full relative precision when rate·(upper − lower) is tiny;
        // for `upper = +∞` this evaluates to exactly 1.
        let z = -(-rate * (upper - lower)).exp_m1();
        if z.is_nan() || z <= F::zero() {
            return Err(StatsError::InvalidArgument(
                "Normalization constant is zero".to_string(),
            ));
        }
        Ok(Self {
            rate,
            lower,
            upper,
            z,
        })
    }

    /// Below this value of `u = rate · (upper − lower)`, the closed-form mean and variance
    /// lose precision to cancellation, so series expansions in `u` are used instead.
    fn series_threshold() -> F {
        F::from(0.05).expect("F::from should not fail for 0.05")
    }

    /// Probability density `rate · exp(−rate · (x − lower)) / z`; zero outside the support.
    pub fn pdf(&self, x: F) -> F {
        if x < self.lower || x > self.upper {
            return F::zero();
        }
        self.rate * (-self.rate * (x - self.lower)).exp() / self.z
    }

    /// Log-density `ln rate − rate · (x − lower) − ln z`; `-∞` outside the support.
    pub fn logpdf(&self, x: F) -> F {
        if x < self.lower || x > self.upper {
            return F::neg_infinity();
        }
        self.rate.ln() - self.rate * (x - self.lower) - self.z.ln()
    }

    /// CDF `(1 − exp(−rate · (x − lower))) / z`, `0` at or below `lower`, `1` at or above `upper`.
    pub fn cdf(&self, x: F) -> F {
        if x <= self.lower {
            return F::zero();
        }
        if x >= self.upper {
            return F::one();
        }
        let mass = -(-self.rate * (x - self.lower)).exp_m1();
        let ratio = mass / self.z;
        // Guard against rounding nudging the ratio just above 1.
        if ratio > F::one() {
            F::one()
        } else {
            ratio
        }
    }

    /// Quantile function: `lower − ln(1 − p·z) / rate`, clamped to `upper`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidArgument`] if `p` is NaN or outside `[0, 1]`.
    pub fn ppf(&self, p: F) -> StatsResult<F> {
        if p.is_nan() || p < F::zero() || p > F::one() {
            return Err(StatsError::InvalidArgument(
                "probability p must be in [0, 1]".to_string(),
            ));
        }
        if p == F::zero() {
            return Ok(self.lower);
        }
        if p == F::one() {
            return Ok(self.upper);
        }
        // Solve 1 − exp(−rate·(x − lower)) = p·z; `ln_1p` is accurate for small p·z.
        let x = self.lower - (-p * self.z).ln_1p() / self.rate;
        Ok(x.min(self.upper))
    }

    /// Draws `size` samples by inverse-transform sampling: `ppf` applied to uniform draws
    /// in `[0, 1)`.
    ///
    /// With `seed = None`, the generator is seeded from the system clock in nanoseconds,
    /// falling back to `42` if the clock reads before the Unix epoch.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TruncatedExponential::ppf`].
    pub fn rvs(&self, size: usize, seed: Option<u64>) -> StatsResult<Array1<F>> {
        let mut rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => {
                let s = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_nanos() as u64)
                    .unwrap_or(42);
                StdRng::seed_from_u64(s)
            }
        };
        let mut samples = Vec::with_capacity(size);
        for _ in 0..size {
            let u: f64 = rng.random();
            let p = F::from(u).expect("F::from should not fail for f64 uniform sample in [0,1)");
            let x = self.ppf(p)?;
            samples.push(x);
        }
        Ok(Array1::from_vec(samples))
    }

    /// Mean of the distribution.
    ///
    /// With width `w = upper − lower` and `u = rate · w`:
    /// * general case: `lower + 1/rate − w / (eᵘ − 1)`;
    /// * `u < 0.05`: the Taylor series `lower + w·(1/2 − u/12 + u³/720 − u⁵/30240)`;
    /// * `upper = +∞`: `lower + 1/rate`.
    pub fn mean(&self) -> F {
        let width = self.upper - self.lower;
        if width.is_infinite() {
            return self.lower + F::one() / self.rate;
        }
        let u = self.rate * width;
        let excess = if u < Self::series_threshold() {
            let c = |v: f64| F::from(v).expect("F::from should not fail for series coefficient");
            let u2 = u * u;
            width * (c(0.5) - u / c(12.0) + u * u2 / c(720.0) - u * u2 * u2 / c(30240.0))
        } else {
            F::one() / self.rate - width / u.exp_m1()
        };
        self.lower + excess
    }

    /// Variance of the distribution (invariant under the shift by `lower`).
    ///
    /// With `w` and `u` as in [`Self::mean`]:
    /// * general case: `1/rate² − (w / (2·sinh(u/2)))²`;
    /// * `u < 0.05`: the series `w²·(1/12 − u²/240 + u⁴/6048)`;
    /// * `upper = +∞`: `1/rate²`.
    ///
    /// Negative values from rounding are clamped to zero.
    pub fn var(&self) -> F {
        let inv_rate = F::one() / self.rate;
        let width = self.upper - self.lower;
        if width.is_infinite() {
            return inv_rate * inv_rate;
        }
        let u = self.rate * width;
        let c = |v: f64| F::from(v).expect("F::from should not fail for series coefficient");
        let raw = if u < Self::series_threshold() {
            let u2 = u * u;
            width * width * (F::one() / c(12.0) - u2 / c(240.0) + u2 * u2 / c(6048.0))
        } else {
            let two = c(2.0);
            let ratio = width / (two * (u / two).sinh());
            inv_rate * inv_rate - ratio * ratio
        };
        if raw < F::zero() {
            F::zero()
        } else {
            raw
        }
    }
}
/// Gamma distribution with `shape` `k` and `rate` `λ`, truncated to `[lower, upper]`.
///
/// Requires `0 <= lower < upper < ∞`. The normalization constant
/// `P(k, λ·upper) − P(k, λ·lower)` is computed once at construction. Mean and variance use
/// a 200-panel midpoint rule over the support, and the quantile function bisects the CDF;
/// both need a finite interval.
#[derive(Debug, Clone)]
pub struct TruncatedGamma<F: Float> {
    /// Shape parameter `k` of the parent gamma distribution.
    pub shape: F,
    /// Rate parameter `λ` (inverse scale) of the parent gamma distribution.
    pub rate: F,
    /// Lower truncation bound (`>= 0`).
    pub lower: F,
    /// Upper truncation bound (finite).
    pub upper: F,
    /// Normalization constant: gamma CDF mass on `[lower, upper]`.
    z: F,
}
impl<F: Float + NumCast + Debug + std::fmt::Display> TruncatedGamma<F> {
    /// Creates a gamma distribution truncated to `[lower, upper]`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidArgument`] if:
    /// * `shape` or `rate` is not positive;
    /// * `lower < 0` or `lower >= upper`;
    /// * `upper` is infinite;
    /// * any argument is NaN;
    /// * the gamma mass on the interval is at most `1e-15` or NaN.
    pub fn new(shape: F, rate: F, lower: F, upper: F) -> StatsResult<Self> {
        // Explicit NaN checks: ordered comparisons are false for NaN.
        if shape.is_nan() || shape <= F::zero() {
            return Err(StatsError::InvalidArgument(
                "shape must be positive".to_string(),
            ));
        }
        if rate.is_nan() || rate <= F::zero() {
            return Err(StatsError::InvalidArgument(
                "rate must be positive".to_string(),
            ));
        }
        if lower.is_nan() || lower < F::zero() {
            return Err(StatsError::InvalidArgument(
                "lower bound must be >= 0 for gamma distribution".to_string(),
            ));
        }
        if upper.is_nan() || lower >= upper {
            return Err(StatsError::InvalidArgument(
                "lower bound must be strictly less than upper bound".to_string(),
            ));
        }
        // The midpoint-rule moments and the bisection quantile need a finite bracket.
        if upper.is_infinite() {
            return Err(StatsError::InvalidArgument(
                "upper bound must be finite for truncated gamma distribution".to_string(),
            ));
        }
        let z = gamma_cdf(shape, rate, upper) - gamma_cdf(shape, rate, lower);
        if z.is_nan() || z <= F::from(1e-15).expect("F::from should not fail for 1e-15") {
            return Err(StatsError::InvalidArgument(
                "Normalization constant is effectively zero; \
                 truncation interval may be too narrow or outside main support"
                    .to_string(),
            ));
        }
        Ok(Self {
            shape,
            rate,
            lower,
            upper,
            z,
        })
    }

    /// Probability density at `x`; zero outside `[lower, upper]`.
    pub fn pdf(&self, x: F) -> F {
        if x < self.lower || x > self.upper {
            return F::zero();
        }
        gamma_pdf(self.shape, self.rate, x) / self.z
    }

    /// Log-density at `x`; `-∞` outside `[lower, upper]`.
    pub fn logpdf(&self, x: F) -> F {
        if x < self.lower || x > self.upper {
            return F::neg_infinity();
        }
        gamma_logpdf(self.shape, self.rate, x) - self.z.ln()
    }

    /// CDF: `0` at or below `lower`, `1` at or above `upper`, renormalized gamma CDF between.
    pub fn cdf(&self, x: F) -> F {
        if x <= self.lower {
            return F::zero();
        }
        if x >= self.upper {
            return F::one();
        }
        let cdf_lower = gamma_cdf(self.shape, self.rate, self.lower);
        (gamma_cdf(self.shape, self.rate, x) - cdf_lower) / self.z
    }

    /// Quantile function by bisection on the CDF.
    ///
    /// Runs at most 100 steps or until the bracket is narrower than `1e-10`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidArgument`] if `p` is NaN or outside `[0, 1]`.
    pub fn ppf(&self, p: F) -> StatsResult<F> {
        if p.is_nan() || p < F::zero() || p > F::one() {
            return Err(StatsError::InvalidArgument(
                "probability p must be in [0, 1]".to_string(),
            ));
        }
        if p == F::zero() {
            return Ok(self.lower);
        }
        if p == F::one() {
            return Ok(self.upper);
        }
        let two = F::from(2.0).expect("F::from should not fail for 2.0");
        let eps = F::from(1e-10).expect("F::from should not fail for 1e-10");
        // Loop-invariant: `cdf` would otherwise recompute it on every bisection step.
        let cdf_lower = gamma_cdf(self.shape, self.rate, self.lower);
        let (mut lo, mut hi) = (self.lower, self.upper);
        for _ in 0..100 {
            let mid = (lo + hi) / two;
            let cdf_mid = (gamma_cdf(self.shape, self.rate, mid) - cdf_lower) / self.z;
            if cdf_mid < p {
                lo = mid;
            } else {
                hi = mid;
            }
            if hi - lo < eps {
                break;
            }
        }
        Ok((lo + hi) / two)
    }

    /// Draws `size` samples by inverse-transform sampling: `ppf` applied to uniform draws
    /// in `[0, 1)`.
    ///
    /// With `seed = None`, the generator is seeded from the system clock in nanoseconds,
    /// falling back to `42` if the clock reads before the Unix epoch.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TruncatedGamma::ppf`].
    pub fn rvs(&self, size: usize, seed: Option<u64>) -> StatsResult<Array1<F>> {
        let mut rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => {
                let s = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_nanos() as u64)
                    .unwrap_or(42);
                StdRng::seed_from_u64(s)
            }
        };
        let mut samples = Vec::with_capacity(size);
        for _ in 0..size {
            let u: f64 = rng.random();
            let p = F::from(u).expect("F::from should not fail for f64 uniform sample in [0,1)");
            let x = self.ppf(p)?;
            samples.push(x);
        }
        Ok(Array1::from_vec(samples))
    }

    /// 200-panel midpoint-rule approximation of `∫ g(x)·pdf(x) dx` over `[lower, upper]`.
    fn midpoint_expectation(&self, g: impl Fn(F) -> F) -> F {
        let panels = 200usize;
        let h = (self.upper - self.lower)
            / F::from(panels).expect("F::from should not fail for usize n");
        let half_h = h / F::from(2.0).expect("F::from should not fail for 2.0");
        let sum = (0..panels).fold(F::zero(), |acc, i| {
            let x = self.lower + h * F::from(i).expect("F::from should not fail for usize i") + half_h;
            acc + g(x) * self.pdf(x)
        });
        sum * h
    }

    /// Mean, approximated by the midpoint rule.
    pub fn mean(&self) -> F {
        self.midpoint_expectation(|x| x)
    }

    /// Variance, approximated by the midpoint rule and clamped at zero.
    pub fn var(&self) -> F {
        let mu = self.mean();
        let v = self.midpoint_expectation(|x| {
            let diff = x - mu;
            diff * diff
        });
        if v < F::zero() {
            F::zero()
        } else {
            v
        }
    }
}
/// Density of the Gamma(shape, rate) distribution: `exp(gamma_logpdf(..))`, zero for `x <= 0`.
fn gamma_pdf<F: Float>(shape: F, rate: F, x: F) -> F {
    if x <= F::zero() {
        F::zero()
    } else {
        let log_density = gamma_logpdf(shape, rate, x);
        log_density.exp()
    }
}
/// Log-density of Gamma(shape, rate): `k·ln λ − ln Γ(k) + (k − 1)·ln x − λ·x`.
///
/// Returns `-∞` for `x <= 0`.
fn gamma_logpdf<F: Float>(shape: F, rate: F, x: F) -> F {
    if x <= F::zero() {
        F::neg_infinity()
    } else {
        let log_norm = shape * rate.ln() - log_gamma(shape);
        log_norm + (shape - F::one()) * x.ln() - rate * x
    }
}
/// Gamma(shape, rate) CDF: the regularized lower incomplete gamma `P(shape, rate·x)`.
///
/// Returns `0` for `x <= 0`.
fn gamma_cdf<F: Float>(shape: F, rate: F, x: F) -> F {
    if x <= F::zero() {
        F::zero()
    } else {
        regularized_inc_gamma(shape, rate * x)
    }
}
/// Regularized lower incomplete gamma function `P(a, x) = γ(a, x) / Γ(a)`.
///
/// Uses the power series for `x < a + 1`. Otherwise it uses the continued fraction for
/// `Q = 1 − P`.
///
/// Returns `0` for `x <= 0` and `1` for `x = +∞`; NaN inputs propagate.
fn regularized_inc_gamma<F: Float>(a: F, x: F) -> F {
    if x <= F::zero() {
        return F::zero();
    }
    // The continued fraction evaluates ∞ − ∞ and 0·∞ internally at x = +∞ and would
    // return NaN; the limit P(a, ∞) = 1 is returned directly instead.
    if x == F::infinity() {
        return F::one();
    }
    if x < a + F::one() {
        inc_gamma_series(a, x)
    } else {
        F::one() - inc_gamma_cf(a, x)
    }
}
/// Power-series evaluation of the regularized lower incomplete gamma `P(a, x)`.
///
/// Algorithm:
/// 1. Sum `1/a + x/(a(a+1)) + x²/(a(a+1)(a+2)) + …`.
/// 2. Stop once a term falls below `F::epsilon()` relative to the running sum, or after
///    500 terms.
/// 3. Multiply the sum by `exp(−x + a·ln x − ln Γ(a))`.
///
/// Intended for `x < a + 1`. The result is capped at 1.
fn inc_gamma_series<F: Float>(a: F, x: F) -> F {
    // Generous cap: the break below normally triggers far earlier, but large `a` with
    // `x` close to `a` needs more terms to reach full working precision.
    let max_iter = 500;
    // Converge to the working precision of `F` rather than a fixed single-precision
    // tolerance, so `f64` callers get double-precision results.
    let eps = F::epsilon();
    let log_gam_a = log_gamma(a);
    let mut denom = a;
    let mut term = F::one() / a;
    let mut sum = term;
    for _ in 0..max_iter {
        denom = denom + F::one();
        term = term * x / denom;
        sum = sum + term;
        if term.abs() < sum.abs() * eps {
            break;
        }
    }
    let result = sum * (-x + a * x.ln() - log_gam_a).exp();
    if result > F::one() {
        F::one()
    } else {
        result
    }
}
/// Continued-fraction evaluation, by the modified Lentz method, of the regularized upper
/// incomplete gamma `Q(a, x) = 1 − P(a, x)`.
///
/// Intended for `x >= a + 1`, where the fraction converges quickly. The result is capped
/// at 1.
fn inc_gamma_cf<F: Float>(a: F, x: F) -> F {
    let max_iter = 200;
    // Converge to the working precision of `F`; `<=` so a step of exactly one ulp counts.
    let eps = F::epsilon();
    // Smallest magnitude allowed for the Lentz intermediates. It is derived from `F`
    // because a literal like `1e-300` rounds to 0 in `f32`, which would disable the guard.
    let fpmin = F::min_positive_value() / F::epsilon();
    let two = F::from(2.0).expect("F::from should not fail for 2.0");
    let log_gam_a = log_gamma(a);
    let mut b = x + F::one() - a;
    let mut c = F::one() / fpmin;
    let mut d = F::one() / b;
    let mut h = d;
    for i in 1..=max_iter {
        let fi = F::from(i).expect("F::from should not fail for loop index i");
        let an = -fi * (fi - a);
        b = b + two;
        d = an * d + b;
        if d.abs() < fpmin {
            d = fpmin;
        }
        c = b + an / c;
        if c.abs() < fpmin {
            c = fpmin;
        }
        d = F::one() / d;
        let del = d * c;
        h = h * del;
        if (del - F::one()).abs() <= eps {
            break;
        }
    }
    let result = h * (-x + a * x.ln() - log_gam_a).exp();
    if result > F::one() {
        F::one()
    } else {
        result
    }
}
/// Natural logarithm of the absolute value of the gamma function, `ln|Γ(x)|`.
///
/// The value is computed in `f64` by `log_gamma_f64` and converted back to `F`.
/// Arguments below `0.5` go through the reflection formula
/// `ln|Γ(x)| = ln π − ln|sin(πx)| − ln Γ(1 − x)`.
fn log_gamma<F: Float>(x: F) -> F {
    use std::f64::consts::PI;
    // NOTE(review): an `x` that cannot be represented as `f64` silently becomes 1.0 here.
    let x_f64: f64 = NumCast::from(x).unwrap_or(1.0);
    if x_f64 < 0.5 {
        // `abs` keeps negative non-integer arguments finite (ln|Γ|) instead of NaN.
        let reflected = PI.ln() - (PI * x_f64).sin().abs().ln() - log_gamma_f64(1.0 - x_f64);
        return F::from(reflected).expect("F::from should not fail for log_gamma result");
    }
    F::from(log_gamma_f64(x_f64)).expect("F::from should not fail for log_gamma_f64 result")
}
/// `ln|Γ(x)|` for `f64` arguments, via the Lanczos approximation (g = 7, nine coefficients).
///
/// Arguments below `0.5` are mapped through the reflection formula
/// `Γ(x)·Γ(1 − x) = π / sin(πx)`, i.e. `ln|Γ(x)| = ln π − ln|sin(πx)| − ln Γ(1 − x)`.
/// The Lanczos series itself is therefore only evaluated for `x >= 0.5`.
fn log_gamma_f64(x: f64) -> f64 {
    use std::f64::consts::PI;
    const LANCZOS: [f64; 9] = [
        0.99999999999980993_f64,
        676.5203681218851_f64,
        -1259.1392167224028_f64,
        771.3234287776531_f64,
        -176.6150291621406_f64,
        12.507343278686905_f64,
        -0.13857109526572012_f64,
        9.984369578019572e-6_f64,
        1.5056327351493116e-7_f64,
    ];
    if x < 0.5 {
        // Reflection in log form needs ln π, not π. `abs` yields ln|Γ(x)| for negative
        // non-integer x, where sin(πx) < 0 would otherwise give NaN.
        return PI.ln() - (PI * x).sin().abs().ln() - log_gamma_f64(1.0 - x);
    }
    let x = x - 1.0;
    let series = LANCZOS[1..]
        .iter()
        .enumerate()
        .fold(LANCZOS[0], |acc, (i, &c)| acc + c / (x + (i + 1) as f64));
    let t = x + 7.5;
    (2.0 * PI).sqrt().ln() + t.ln() * (x + 0.5) - t + series.ln()
}
#[derive(Debug, Clone)]
pub struct TruncatedBeta<F: Float> {
pub alpha: F,
pub beta_param: F,
pub lower: F,
pub upper: F,
z: F,
}
impl<F: Float + NumCast + Debug + std::fmt::Display> TruncatedBeta<F> {
pub fn new(alpha: F, beta_param: F, lower: F, upper: F) -> StatsResult<Self> {
if alpha <= F::zero() {
return Err(StatsError::InvalidArgument(
"alpha must be positive".to_string(),
));
}
if beta_param <= F::zero() {
return Err(StatsError::InvalidArgument(
"beta must be positive".to_string(),
));
}
if lower < F::zero() || lower >= F::one() {
return Err(StatsError::InvalidArgument(
"lower must be in [0, 1)".to_string(),
));
}
if upper <= lower || upper > F::one() {
return Err(StatsError::InvalidArgument(
"upper must be in (lower, 1]".to_string(),
));
}
let cdf_lower = beta_cdf(alpha, beta_param, lower);
let cdf_upper = beta_cdf(alpha, beta_param, upper);
let z = cdf_upper - cdf_lower;
if z <= F::from(1e-15).expect("F::from should not fail for 1e-15") {
return Err(StatsError::InvalidArgument(
"Normalization constant is effectively zero".to_string(),
));
}
Ok(Self {
alpha,
beta_param,
lower,
upper,
z,
})
}
pub fn pdf(&self, x: F) -> F {
if x < self.lower || x > self.upper {
return F::zero();
}
beta_pdf(self.alpha, self.beta_param, x) / self.z
}
pub fn cdf(&self, x: F) -> F {
if x <= self.lower {
return F::zero();
}
if x >= self.upper {
return F::one();
}
let cdf_lower = beta_cdf(self.alpha, self.beta_param, self.lower);
(beta_cdf(self.alpha, self.beta_param, x) - cdf_lower) / self.z
}
pub fn ppf(&self, p: F) -> StatsResult<F> {
if p < F::zero() || p > F::one() {
return Err(StatsError::InvalidArgument(
"probability p must be in [0, 1]".to_string(),
));
}
if p == F::zero() {
return Ok(self.lower);
}
if p == F::one() {
return Ok(self.upper);
}
let mut lo = self.lower;
let mut hi = self.upper;
let eps = F::from(1e-10).expect("F::from should not fail for 1e-10");
for _ in 0..100 {
let mid = (lo + hi) / F::from(2.0).expect("F::from should not fail for 2.0");
if self.cdf(mid) < p {
lo = mid;
} else {
hi = mid;
}
if hi - lo < eps {
break;
}
}
Ok((lo + hi) / F::from(2.0).expect("F::from should not fail for 2.0"))
}
pub fn rvs(&self, size: usize, seed: Option<u64>) -> StatsResult<Array1<F>> {
let mut rng = match seed {
Some(s) => StdRng::seed_from_u64(s),
None => {
let s = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u64)
.unwrap_or(42);
StdRng::seed_from_u64(s)
}
};
let mut samples = Vec::with_capacity(size);
for _ in 0..size {
let u: f64 = rng.random();
let p = F::from(u).expect("F::from should not fail for f64 uniform sample in [0,1)");
let x = self.ppf(p)?;
samples.push(x);
}
Ok(Array1::from_vec(samples))
}
pub fn mean(&self) -> F {
let n = 200usize;
let h = (self.upper - self.lower) / F::from(n).expect("F::from should not fail for usize n");
let mut sum = F::zero();
for i in 0..n {
let x = self.lower + h * F::from(i).expect("F::from should not fail for usize i") + h / F::from(2.0).expect("F::from should not fail for 2.0");
sum = sum + x * self.pdf(x);
}
sum * h
}
pub fn var(&self) -> F {
let mu = self.mean();
let n = 200usize;
let h = (self.upper - self.lower) / F::from(n).expect("F::from should not fail for usize n");
let mut sum = F::zero();
for i in 0..n {
let x = self.lower + h * F::from(i).expect("F::from should not fail for usize i") + h / F::from(2.0).expect("F::from should not fail for 2.0");
let diff = x - mu;
sum = sum + diff * diff * self.pdf(x);
}
let v = sum * h;
if v < F::zero() { F::zero() } else { v }
}
}
/// Beta(a, b) density at `x`: `x^(a−1)·(1−x)^(b−1) / B(a, b)`, evaluated in log space.
///
/// Returns zero for `x <= 0` and `x >= 1`.
fn beta_pdf<F: Float>(a: F, b: F, x: F) -> F {
    if x <= F::zero() || x >= F::one() {
        F::zero()
    } else {
        let one = F::one();
        let log_norm = log_beta(a, b);
        let log_kernel = (a - one) * x.ln() + (b - one) * (one - x).ln();
        (log_kernel - log_norm).exp()
    }
}
/// Beta(a, b) CDF: `0` for `x <= 0`, `1` for `x >= 1`, otherwise `I_x(a, b)`.
fn beta_cdf<F: Float>(a: F, b: F, x: F) -> F {
    if x <= F::zero() {
        F::zero()
    } else if x >= F::one() {
        F::one()
    } else {
        regularized_inc_beta(a, b, x)
    }
}
/// Log of the beta function: `ln B(a, b) = ln Γ(a) + ln Γ(b) − ln Γ(a + b)`.
fn log_beta<F: Float>(a: F, b: F) -> F {
    let numerator = log_gamma(a) + log_gamma(b);
    numerator - log_gamma(a + b)
}
/// Regularized incomplete beta function `I_x(a, b)` for `0 < x < 1`.
///
/// The continued fraction is evaluated directly when `x < (a + 1) / (a + b + 2)`.
/// Otherwise the symmetry `I_x(a, b) = 1 − I_{1−x}(b, a)` is used.
fn regularized_inc_beta<F: Float>(a: F, b: F, x: F) -> F {
    let switch_point = (a + F::one()) / (a + b + F::from(2.0).expect("F::from should not fail for 2.0"));
    if x < switch_point {
        inc_beta_cf(a, b, x) * beta_prefactor(a, b, x)
    } else {
        let mirrored = F::one() - x;
        F::one() - inc_beta_cf(b, a, mirrored) * beta_prefactor(b, a, mirrored)
    }
}
/// Prefactor `x^a · (1 − x)^b / (a · B(a, b))` of the incomplete-beta continued fraction,
/// computed in log space.
fn beta_prefactor<F: Float>(a: F, b: F, x: F) -> F {
    let log_term = a * x.ln() + b * (F::one() - x).ln() - log_beta(a, b);
    log_term.exp() / a
}
/// Continued fraction for the incomplete beta function, evaluated by the modified Lentz
/// method.
///
/// Multiplying the result by `beta_prefactor(a, b, x)` gives `I_x(a, b)`. It converges
/// quickly for `x < (a + 1) / (a + b + 2)`. Each iteration applies one even and one odd
/// step of the fraction, stopping once the odd-step factor is within `F::epsilon()` of 1,
/// or after 200 iterations.
fn inc_beta_cf<F: Float>(a: F, b: F, x: F) -> F {
    let max_iter = 200;
    // Converge to the working precision of `F`; `<=` so a step of exactly one ulp counts.
    let eps = F::epsilon();
    // Smallest magnitude allowed for the Lentz intermediates. It is derived from `F`
    // because a literal like `1e-300` rounds to 0 in `f32`, which would disable the guard.
    let fpmin = F::min_positive_value() / F::epsilon();
    // Replace near-zero intermediates with `fpmin` to avoid division by zero.
    let guard = |v: F| if v.abs() < fpmin { fpmin } else { v };
    let qab = a + b;
    let qap = a + F::one();
    let qam = a - F::one();
    let mut c = F::one();
    let mut d = F::one() / guard(F::one() - qab * x / qap);
    let mut h = d;
    for m in 1..=max_iter {
        let mf = F::from(m).expect("F::from should not fail for loop index m");
        let two_mf = F::from(2 * m).expect("F::from should not fail for 2*m");
        // Even step.
        let aa = mf * (b - mf) * x / ((qam + two_mf) * (a + two_mf));
        d = F::one() / guard(F::one() + aa * d);
        c = guard(F::one() + aa / c);
        h = h * d * c;
        // Odd step.
        let aa = -(a + mf) * (qab + mf) * x / ((a + two_mf) * (qap + two_mf));
        d = F::one() / guard(F::one() + aa * d);
        c = guard(F::one() + aa / c);
        let del = d * c;
        h = h * del;
        if (del - F::one()).abs() <= eps {
            break;
        }
    }
    h
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Composite midpoint rule with `n` equal panels covering `[start, start + width]`.
    fn midpoint_integral(start: f64, width: f64, n: usize, f: impl Fn(f64) -> f64) -> f64 {
        let h = width / n as f64;
        let mut acc = 0.0f64;
        for i in 0..n {
            acc += f(start + h * i as f64 + h / 2.0);
        }
        acc * h
    }

    #[test]
    fn test_truncated_normal_pdf_zero_outside() {
        let dist = TruncatedNormal::new(0.0f64, 1.0, -2.0, 2.0).unwrap();
        for x in [-3.0, 3.0] {
            assert_eq!(dist.pdf(x), 0.0);
        }
    }

    #[test]
    fn test_truncated_normal_pdf_positive_inside() {
        let dist = TruncatedNormal::new(0.0f64, 1.0, -2.0, 2.0).unwrap();
        for x in [0.0, -1.0, 1.0] {
            assert!(dist.pdf(x) > 0.0);
        }
    }

    #[test]
    fn test_truncated_normal_pdf_integrates_to_one() {
        let dist = TruncatedNormal::new(0.0f64, 1.0, -2.0, 2.0).unwrap();
        let sum = midpoint_integral(-2.0, 4.0, 1000, |x| dist.pdf(x));
        assert!((sum - 1.0).abs() < 1e-5, "Integral = {}", sum);
    }

    #[test]
    fn test_truncated_normal_cdf_bounds() {
        let dist = TruncatedNormal::new(0.0f64, 1.0, -2.0, 2.0).unwrap();
        for (x, expected) in [(-3.0, 0.0), (-2.0, 0.0), (2.0, 1.0), (3.0, 1.0)] {
            assert_eq!(dist.cdf(x), expected);
        }
        let mid = dist.cdf(0.0);
        assert!(mid > 0.0 && mid < 1.0, "CDF at mean = {}", mid);
    }

    #[test]
    fn test_truncated_normal_cdf_monotone() {
        let dist = TruncatedNormal::new(0.0f64, 1.0, -2.0, 2.0).unwrap();
        let mut prev = 0.0f64;
        for x in [-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5] {
            let c = dist.cdf(x);
            assert!(c > prev, "CDF not monotone at x={}: {} <= {}", x, c, prev);
            prev = c;
        }
    }

    #[test]
    fn test_truncated_normal_mean() {
        let mean = TruncatedNormal::new(0.0f64, 1.0, 0.0, 10.0).unwrap().mean();
        assert!(
            (mean - 0.7979).abs() < 0.01,
            "Half-normal mean = {}, expected ≈ 0.7979",
            mean
        );
    }

    #[test]
    fn test_truncated_normal_ppf_roundtrip() {
        let dist = TruncatedNormal::new(0.0f64, 1.0, -2.0, 2.0).unwrap();
        for p in [0.1, 0.25, 0.5, 0.75, 0.9] {
            let x = dist.ppf(p).unwrap();
            let cdf_x = dist.cdf(x);
            assert!(
                (cdf_x - p).abs() < 1e-5,
                "PPF roundtrip failed: ppf({}) = {}, cdf = {}",
                p,
                x,
                cdf_x
            );
        }
    }

    #[test]
    fn test_truncated_normal_rvs() {
        let dist = TruncatedNormal::new(0.0f64, 1.0, -2.0, 2.0).unwrap();
        let samples = dist.rvs(500, Some(42)).unwrap();
        assert_eq!(samples.len(), 500);
        for &s in samples.iter() {
            assert!(
                (-2.0..=2.0).contains(&s),
                "Sample {} outside bounds [-2, 2]",
                s
            );
        }
    }

    #[test]
    fn test_truncated_normal_invalid_params() {
        let cases = [
            (0.0f64, -1.0, -2.0, 2.0),
            (0.0, 1.0, 2.0, -2.0),
            (0.0, 1.0, 2.0, 2.0),
        ];
        for (mu, sigma, lo, hi) in cases {
            assert!(TruncatedNormal::new(mu, sigma, lo, hi).is_err());
        }
    }

    #[test]
    fn test_truncated_exponential_pdf_bounds() {
        let dist = TruncatedExponential::new(1.0f64, 0.0, 3.0).unwrap();
        assert_eq!(dist.pdf(-1.0), 0.0);
        assert_eq!(dist.pdf(4.0), 0.0);
        assert!(dist.pdf(1.0) > 0.0);
    }

    #[test]
    fn test_truncated_exponential_cdf_bounds() {
        let dist = TruncatedExponential::new(1.0f64, 0.0, 3.0).unwrap();
        for (x, expected) in [(-1.0, 0.0), (0.0, 0.0), (3.0, 1.0)] {
            assert_eq!(dist.cdf(x), expected);
        }
    }

    #[test]
    fn test_truncated_exponential_pdf_integrates_to_one() {
        let dist = TruncatedExponential::new(2.0f64, 0.5, 3.0).unwrap();
        let sum = midpoint_integral(0.5, 2.5, 1000, |x| dist.pdf(x));
        assert!((sum - 1.0).abs() < 1e-5, "Integral = {}", sum);
    }

    #[test]
    fn test_truncated_exponential_mean() {
        let mean = TruncatedExponential::new(1.0f64, 0.0, 3.0)
            .expect("valid truncated exponential")
            .mean();
        assert!(
            (mean - 0.8428).abs() < 0.005,
            "Mean = {}, expected ≈ 0.8428",
            mean
        );
    }

    #[test]
    fn test_truncated_exponential_ppf_roundtrip() {
        let dist = TruncatedExponential::new(1.5f64, 0.2, 4.0).unwrap();
        for p in [0.1, 0.3, 0.5, 0.7, 0.9] {
            let x = dist.ppf(p).unwrap();
            let cdf_x = dist.cdf(x);
            assert!(
                (cdf_x - p).abs() < 1e-8,
                "PPF roundtrip: ppf({}) = {}, cdf = {}",
                p,
                x,
                cdf_x
            );
        }
    }

    #[test]
    fn test_truncated_exponential_rvs() {
        let dist = TruncatedExponential::new(1.0f64, 0.5, 5.0).unwrap();
        let samples = dist.rvs(200, Some(123)).unwrap();
        assert_eq!(samples.len(), 200);
        for &s in samples.iter() {
            assert!((0.5..=5.0).contains(&s), "Sample {} outside [0.5, 5]", s);
        }
    }

    #[test]
    fn test_truncated_gamma_pdf_bounds() {
        let dist = TruncatedGamma::new(2.0f64, 1.0, 0.0, 5.0).unwrap();
        assert_eq!(dist.pdf(-1.0), 0.0);
        assert_eq!(dist.pdf(6.0), 0.0);
        assert!(dist.pdf(2.0) > 0.0);
    }

    #[test]
    fn test_truncated_gamma_pdf_integrates_to_one() {
        let dist = TruncatedGamma::new(2.0f64, 1.0, 0.5, 5.0).unwrap();
        let sum = midpoint_integral(0.5, 4.5, 1000, |x| dist.pdf(x));
        assert!((sum - 1.0).abs() < 1e-4, "Integral = {}", sum);
    }

    #[test]
    fn test_truncated_gamma_cdf_monotone() {
        let dist = TruncatedGamma::new(3.0f64, 0.5, 1.0, 8.0).unwrap();
        let mut prev = 0.0f64;
        for x in [2.0, 3.0, 4.0, 5.0, 6.0] {
            let c = dist.cdf(x);
            assert!(c > prev, "CDF not monotone at {}", x);
            prev = c;
        }
    }

    #[test]
    fn test_truncated_gamma_ppf_roundtrip() {
        let dist = TruncatedGamma::new(2.0f64, 0.5, 0.0, 10.0).unwrap();
        for p in [0.2, 0.4, 0.6, 0.8] {
            let x = dist.ppf(p).unwrap();
            let c = dist.cdf(x);
            assert!(
                (c - p).abs() < 1e-5,
                "PPF roundtrip: ppf({}) = {}, cdf = {}",
                p,
                x,
                c
            );
        }
    }

    #[test]
    fn test_truncated_gamma_rvs() {
        let dist = TruncatedGamma::new(2.0f64, 1.0, 0.5, 5.0).unwrap();
        let samples = dist.rvs(200, Some(99)).unwrap();
        assert_eq!(samples.len(), 200);
        for &s in samples.iter() {
            assert!((0.5..=5.0).contains(&s), "Sample {} outside [0.5, 5]", s);
        }
    }

    #[test]
    fn test_truncated_beta_pdf_bounds() {
        let dist = TruncatedBeta::new(2.0f64, 5.0, 0.1, 0.9).unwrap();
        assert_eq!(dist.pdf(-0.1), 0.0);
        assert_eq!(dist.pdf(0.95), 0.0);
        assert!(dist.pdf(0.3) > 0.0);
    }

    #[test]
    fn test_truncated_beta_pdf_integrates_to_one() {
        let dist = TruncatedBeta::new(2.0f64, 3.0, 0.1, 0.9).unwrap();
        let sum = midpoint_integral(0.1, 0.8, 1000, |x| dist.pdf(x));
        assert!((sum - 1.0).abs() < 1e-4, "Integral = {}", sum);
    }

    #[test]
    fn test_truncated_beta_ppf_roundtrip() {
        let dist = TruncatedBeta::new(3.0f64, 2.0, 0.0, 0.8).unwrap();
        for p in [0.1, 0.3, 0.5, 0.7, 0.9] {
            let x = dist.ppf(p).unwrap();
            let c = dist.cdf(x);
            assert!(
                (c - p).abs() < 1e-5,
                "PPF roundtrip: ppf({}) = {}, cdf = {}",
                p,
                x,
                c
            );
        }
    }
}