use crate::error::{StatsError, StatsResult};
/// Probability density of the exponential distribution with rate `lambda`, evaluated at `x`.
///
/// Computes `lambda * exp(-lambda * x)`.
///
/// # Errors
///
/// Returns [`StatsError::InvalidInput`] if `x` is negative or NaN, or if `lambda` is not a
/// positive, finite number (zero, negative, NaN and infinity are all rejected).
#[inline]
fn exponential_pdf(x: f64, lambda: f64) -> StatsResult<f64> {
    // A bare `x < 0.0` is false for NaN, so NaN must be rejected explicitly.
    if x.is_nan() || x < 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_pdf: x must be non-negative".to_string(),
        });
    }
    // A bare `lambda <= 0.0` lets NaN through, and +inf produces `inf * 0 = NaN` densities.
    if !lambda.is_finite() || lambda <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_pdf: lambda must be positive and finite".to_string(),
        });
    }
    // exp(-0.0) == 1.0 exactly, so x == 0 yields `lambda` without a special case.
    Ok(lambda * (-lambda * x).exp())
}
/// Cumulative distribution function of the exponential distribution with rate `lambda`.
///
/// Computes `P(X <= x) = 1 - exp(-lambda * x)`.
///
/// # Errors
///
/// Returns [`StatsError::InvalidInput`] if `x` is negative or NaN, or if `lambda` is not a
/// positive, finite number.
#[inline]
fn exponential_cdf(x: f64, lambda: f64) -> StatsResult<f64> {
    // A bare `x < 0.0` is false for NaN, so NaN must be rejected explicitly.
    if x.is_nan() || x < 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_cdf: x must be non-negative".to_string(),
        });
    }
    // Reject NaN and infinity as well as non-positive rates.
    if !lambda.is_finite() || lambda <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_cdf: lambda must be positive and finite".to_string(),
        });
    }
    // `1 - exp(-t)` cancels catastrophically when `t` is tiny (it rounds to 0).
    // `-expm1(-t)` is the same quantity computed to full precision.
    Ok(-(-lambda * x).exp_m1())
}
/// Quantile function (inverse CDF) of the exponential distribution with rate `lambda`.
///
/// Computes `-ln(1 - p) / lambda`. `p == 1` maps to `+inf`.
///
/// # Errors
///
/// Returns [`StatsError::InvalidInput`] if `p` is outside `[0, 1]` (including NaN), or if
/// `lambda` is not a positive, finite number.
#[inline]
fn exponential_inverse_cdf(p: f64, lambda: f64) -> StatsResult<f64> {
    // `RangeInclusive::contains` is false for NaN, so NaN `p` is rejected here too.
    if !(0.0..=1.0).contains(&p) {
        return Err(StatsError::InvalidInput {
            message: "exponential_inverse_cdf: p must be between 0 and 1".to_string(),
        });
    }
    // Reject NaN and infinity as well as non-positive rates.
    if !lambda.is_finite() || lambda <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_inverse_cdf: lambda must be positive and finite".to_string(),
        });
    }
    // `ln(1 - p)` loses all precision for tiny `p` because `1 - p` rounds to 1.
    // `ln_1p(-p)` computes the same value accurately.
    Ok(-(-p).ln_1p() / lambda)
}
/// Exponential distribution with rate parameter `lambda` (mean `1 / lambda`).
///
/// Prefer [`Exponential::new`] or [`Exponential::fit`], which check `lambda`. The field is
/// public, so building the struct literally bypasses that check.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Exponential {
    /// Rate parameter; expected to be strictly positive.
    pub lambda: f64,
}
impl Exponential {
    /// Creates an exponential distribution with rate `lambda`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] if `lambda` is not a positive, finite number
    /// (zero, negative values, NaN and infinity are all rejected).
    pub fn new(lambda: f64) -> StatsResult<Self> {
        // A bare `lambda <= 0.0` would accept NaN and +inf, both of which break pdf/cdf.
        if !lambda.is_finite() || lambda <= 0.0 {
            return Err(StatsError::InvalidInput {
                message: "Exponential::new: lambda must be positive and finite".to_string(),
            });
        }
        Ok(Self { lambda })
    }

    /// Fits the rate to `data` as the reciprocal of the sample mean (`lambda = 1 / mean`).
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] if `data` is empty, contains a negative or NaN
    /// value, or is all zeros. It also errors if the resulting rate is not positive and finite.
    pub fn fit(data: &[f64]) -> StatsResult<Self> {
        if data.is_empty() {
            return Err(StatsError::InvalidInput {
                message: "Exponential::fit: data must not be empty".to_string(),
            });
        }
        // `x < 0.0` alone is false for NaN, which would silently poison the mean.
        if data.iter().any(|&x| x.is_nan() || x < 0.0) {
            return Err(StatsError::InvalidInput {
                message: "Exponential::fit: all data values must be non-negative".to_string(),
            });
        }
        let mean = data.iter().sum::<f64>() / data.len() as f64;
        // All-zero data would give lambda = 1 / 0 = inf; report the real cause instead.
        if mean == 0.0 {
            return Err(StatsError::InvalidInput {
                message: "Exponential::fit: data must contain at least one positive value"
                    .to_string(),
            });
        }
        Self::new(1.0 / mean)
    }
}
impl crate::distributions::traits::Distribution for Exponential {
    type X = f64;

    fn name(&self) -> &str {
        "Exponential"
    }

    /// The exponential distribution has a single parameter, `lambda`.
    fn num_params(&self) -> usize {
        1
    }

    fn pdf(&self, x: f64) -> StatsResult<f64> {
        exponential_pdf(x, self.lambda)
    }

    /// Natural log of the density, `ln(lambda) - lambda * x`.
    ///
    /// This is computed directly rather than as `ln(pdf)`, so it stays finite where the density
    /// itself underflows to zero.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] if `x` is negative or NaN, or if `self.lambda` is
    /// not a positive, finite number.
    fn logpdf(&self, x: f64) -> StatsResult<f64> {
        // A bare `x < 0.0` is false for NaN, so NaN must be rejected explicitly.
        if x.is_nan() || x < 0.0 {
            return Err(StatsError::InvalidInput {
                message: "Exponential::logpdf: x must be non-negative".to_string(),
            });
        }
        // `lambda` is a public field and may not have gone through `Exponential::new`.
        // Validate it here, as `exponential_pdf` validates its `lambda`, instead of returning NaN.
        if !self.lambda.is_finite() || self.lambda <= 0.0 {
            return Err(StatsError::InvalidInput {
                message: "Exponential::logpdf: lambda must be positive and finite".to_string(),
            });
        }
        Ok(self.lambda.ln() - self.lambda * x)
    }

    fn cdf(&self, x: f64) -> StatsResult<f64> {
        exponential_cdf(x, self.lambda)
    }

    fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
        exponential_inverse_cdf(p, self.lambda)
    }

    /// Mean of the distribution, `1 / lambda`.
    fn mean(&self) -> f64 {
        1.0 / self.lambda
    }

    /// Variance of the distribution, `1 / lambda^2`.
    fn variance(&self) -> f64 {
        1.0 / (self.lambda * self.lambda)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPSILON: f64 = 1e-10;

    /// True when `actual` lies within [`EPSILON`] of `expected`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPSILON
    }

    /// Asserts that `result` is an `InvalidInput` error and returns its message.
    fn invalid_input_message(result: StatsResult<f64>) -> String {
        assert!(result.is_err());
        match result {
            Err(StatsError::InvalidInput { message }) => message,
            _ => panic!("Expected InvalidInput error"),
        }
    }

    #[test]
    fn test_exponential_pdf() {
        let lambda: f64 = 2.0;
        assert_eq!(exponential_pdf(0.0, lambda).unwrap(), lambda);
        for &x in &[1.0, 0.5] {
            let density = exponential_pdf(x, lambda).unwrap();
            assert!(approx_eq(density, lambda * (-lambda * x).exp()));
        }
    }

    #[test]
    fn test_exponential_cdf() {
        let lambda: f64 = 2.0;
        assert!(approx_eq(exponential_cdf(0.0, lambda).unwrap(), 0.0));
        for &x in &[1.0, 0.5] {
            let probability = exponential_cdf(x, lambda).unwrap();
            assert!(approx_eq(probability, 1.0 - (-lambda * x).exp()));
        }
    }

    #[test]
    fn test_exponential_inverse_cdf() {
        let lambda: f64 = 2.0;
        for &p in &[0.1, 0.25, 0.5, 0.75, 0.9] {
            let quantile = exponential_inverse_cdf(p, lambda).unwrap();
            let cdf = exponential_cdf(quantile, lambda).unwrap();
            assert!(
                approx_eq(cdf, p),
                "Inverse CDF failed for p = {}: got {}, expected {}",
                p,
                cdf,
                p
            );
        }
    }

    #[test]
    fn test_exponential_mean_via_trait() {
        use crate::distributions::traits::Distribution;
        let dist = Exponential::new(2.0).unwrap();
        assert!(approx_eq(dist.mean(), 0.5));
        assert!(approx_eq(dist.variance(), 0.25));
    }

    #[test]
    fn test_exponential_pdf_invalid_lambda() {
        let message = invalid_input_message(exponential_pdf(1.0, -2.0));
        assert!(message.contains("lambda must be positive"));
    }

    #[test]
    fn test_exponential_pdf_invalid_x() {
        let message = invalid_input_message(exponential_pdf(-1.0, 2.0));
        assert!(message.contains("x must be non-negative"));
    }

    #[test]
    fn test_exponential_inverse_cdf_p_negative() {
        invalid_input_message(exponential_inverse_cdf(-0.1, 2.0));
    }

    #[test]
    fn test_exponential_inverse_cdf_p_greater_than_one() {
        invalid_input_message(exponential_inverse_cdf(1.5, 2.0));
    }

    #[test]
    fn test_exponential_cdf_invalid_lambda() {
        invalid_input_message(exponential_cdf(1.0, -2.0));
    }

    #[test]
    fn test_exponential_cdf_invalid_x() {
        invalid_input_message(exponential_cdf(-1.0, 2.0));
    }

    #[test]
    fn test_exponential_pdf_x_positive() {
        let (lambda, x): (f64, f64) = (2.0, 0.5);
        let expected = lambda * (-lambda * x).exp();
        assert!(approx_eq(exponential_pdf(x, lambda).unwrap(), expected));
    }

    #[test]
    fn test_exponential_inverse_cdf_p_zero() {
        assert_eq!(exponential_inverse_cdf(0.0, 2.0).unwrap(), 0.0);
    }

    #[test]
    fn test_exponential_inverse_cdf_p_one() {
        let quantile = exponential_inverse_cdf(1.0, 2.0).unwrap();
        assert!(quantile.is_infinite() || quantile > 1e10);
    }

    #[test]
    fn test_exponential_inverse_cdf_lambda_zero() {
        invalid_input_message(exponential_inverse_cdf(0.5, 0.0));
    }

    #[test]
    fn test_exponential_inverse_cdf_lambda_negative() {
        invalid_input_message(exponential_inverse_cdf(0.5, -1.0));
    }
}