use crate::error::{StatsError, StatsResult};
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
/// Configuration for an exponential distribution, holding its rate parameter.
///
/// `lambda` may be stored in any type; [`ExponentialConfig::new`] (available when
/// `T: ToPrimitive`) checks that it converts to a strictly positive `f64`.
///
/// Note: `lambda` is public and the type derives `Deserialize`, so values built with
/// a struct literal or deserialized with serde do not go through that validation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct ExponentialConfig<T> {
    /// Rate parameter λ.
    pub lambda: T,
}
impl<T> ExponentialConfig<T>
where
    T: ToPrimitive,
{
    /// Builds a configuration after checking that `lambda` is a usable rate.
    ///
    /// # Errors
    /// * [`StatsError::ConversionError`] when `lambda` has no `f64` representation.
    /// * [`StatsError::InvalidInput`] when `lambda` is not strictly greater than zero
    ///   (NaN is rejected as well).
    pub fn new(lambda: T) -> StatsResult<Self> {
        let rate = match lambda.to_f64() {
            Some(value) => value,
            None => {
                return Err(StatsError::ConversionError {
                    message: "ExponentialConfig::new: Failed to convert lambda to f64".to_string(),
                });
            }
        };
        // NaN compares false with everything, so it is tested explicitly.
        if rate.is_nan() || rate <= 0.0 {
            return Err(StatsError::InvalidInput {
                message: "ExponentialConfig::new: lambda must be positive".to_string(),
            });
        }
        Ok(Self { lambda })
    }
}
/// Evaluates the probability density function of the exponential distribution.
///
/// Computes `f(x; λ) = λ·e^(−λx)` for `x ≥ 0`.
///
/// # Arguments
/// * `x` - Point at which to evaluate the density; must be non-negative.
/// * `lambda` - Rate parameter λ; must be strictly positive.
///
/// # Errors
/// * [`StatsError::ConversionError`] if `x` or `lambda` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `x` is negative or NaN, or if `lambda` is not
///   strictly positive (NaN included).
#[inline]
pub fn exponential_pdf<T>(x: T, lambda: T) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "exponential_pdf: Failed to convert x to f64".to_string(),
    })?;
    // `x_64 < 0.0` alone is false for NaN, so NaN is rejected explicitly.
    if x_64.is_nan() || x_64 < 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_pdf: x must be non-negative".to_string(),
        });
    }
    let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "exponential_pdf: Failed to convert lambda to f64".to_string(),
    })?;
    // `lambda_64 <= 0.0` alone is false for NaN, so NaN is rejected explicitly.
    if lambda_64.is_nan() || lambda_64 <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_pdf: lambda must be positive".to_string(),
        });
    }
    Ok(if x_64 == 0.0 {
        // Return λ directly so an infinite rate does not evaluate ∞·0 = NaN.
        lambda_64
    } else {
        lambda_64 * (-lambda_64 * x_64).exp()
    })
}
/// Evaluates the cumulative distribution function of the exponential distribution.
///
/// Computes `F(x; λ) = 1 − e^(−λx)` for `x ≥ 0`. The value is evaluated as
/// `−expm1(−λx)`, which keeps full relative precision when `λx` is tiny; the naive
/// `1 − exp(−λx)` cancels to exactly zero once `λx` is on the order of `1e-16`.
///
/// # Arguments
/// * `x` - Upper limit of integration; must be non-negative.
/// * `lambda` - Rate parameter λ; must be strictly positive.
///
/// # Errors
/// * [`StatsError::ConversionError`] if `x` or `lambda` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `x` is negative or NaN, or if `lambda` is not
///   strictly positive (NaN included).
#[inline]
pub fn exponential_cdf<T>(x: T, lambda: T) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "exponential_cdf: Failed to convert x to f64".to_string(),
    })?;
    // `x_64 < 0.0` alone is false for NaN, so NaN is rejected explicitly.
    if x_64.is_nan() || x_64 < 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_cdf: x must be non-negative".to_string(),
        });
    }
    let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "exponential_cdf: Failed to convert lambda to f64".to_string(),
    })?;
    if lambda_64.is_nan() || lambda_64 <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_cdf: lambda must be positive".to_string(),
        });
    }
    // 1 − e^(−λx) == −(e^(−λx) − 1) == −expm1(−λx), without cancellation.
    Ok(-f64::exp_m1(-lambda_64 * x_64))
}
/// Evaluates the quantile function (inverse CDF) of the exponential distribution.
///
/// Computes `F⁻¹(p; λ) = −ln(1 − p) / λ`. The logarithm is evaluated as
/// `ln_1p(−p)`, which stays accurate for small `p` where `1 − p` would round to 1.
/// `p = 0` yields `0` and `p = 1` yields `+∞`.
///
/// # Arguments
/// * `p` - Probability in the closed interval `[0, 1]`.
/// * `lambda` - Rate parameter λ; must be strictly positive.
///
/// # Errors
/// * [`StatsError::ConversionError`] if `p` or `lambda` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `p` is outside `[0, 1]` (NaN included), or if
///   `lambda` is not strictly positive (NaN included).
#[inline]
pub fn exponential_inverse_cdf<T>(p: T, lambda: T) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "exponential_inverse_cdf: Failed to convert p to f64".to_string(),
    })?;
    // `RangeInclusive::contains` is false for NaN, so NaN is rejected here too.
    if !(0.0..=1.0).contains(&p_64) {
        return Err(StatsError::InvalidInput {
            message: "exponential_inverse_cdf: p must be between 0 and 1".to_string(),
        });
    }
    let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "exponential_inverse_cdf: Failed to convert lambda to f64".to_string(),
    })?;
    // `lambda_64 <= 0.0` alone is false for NaN, so NaN is rejected explicitly.
    if lambda_64.is_nan() || lambda_64 <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_inverse_cdf: lambda must be positive".to_string(),
        });
    }
    // ln(1 − p) == ln_1p(−p); for p == 1 this is −∞, giving +∞.
    Ok(-f64::ln_1p(-p_64) / lambda_64)
}
/// Returns the mean `1 / λ` of the exponential distribution.
///
/// # Arguments
/// * `lambda` - Rate parameter λ; must be strictly positive.
///
/// # Errors
/// * [`StatsError::ConversionError`] if `lambda` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `lambda` is not strictly positive (NaN included).
#[inline]
pub fn exponential_mean<T>(lambda: T) -> StatsResult<f64>
where
    T: ToPrimitive,
{
    let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
        message: "exponential_mean: Failed to convert lambda to f64".to_string(),
    })?;
    // `lambda_64 <= 0.0` alone is false for NaN, so NaN is rejected explicitly.
    if lambda_64.is_nan() || lambda_64 <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_mean: lambda must be positive".to_string(),
        });
    }
    Ok(1.0 / lambda_64)
}
/// Returns the variance `1 / λ²` of the exponential distribution.
///
/// # Arguments
/// * `lambda` - Rate parameter λ; must be strictly positive.
///
/// # Errors
/// * [`StatsError::InvalidInput`] if `lambda` is not strictly positive (NaN included).
#[inline]
pub fn exponential_variance(lambda: f64) -> StatsResult<f64> {
    // `lambda <= 0.0` alone is false for NaN, so NaN is rejected explicitly.
    if lambda.is_nan() || lambda <= 0.0 {
        return Err(StatsError::InvalidInput {
            message: "exponential_variance: lambda must be positive".to_string(),
        });
    }
    Ok(1.0 / (lambda * lambda))
}
/// Exponential distribution parameterised by its rate `lambda` (λ).
///
/// Prefer [`Exponential::new`] or [`Exponential::fit`], which validate the rate.
/// The field is public, so a struct literal bypasses that validation.
#[derive(Clone, Copy, Debug)]
pub struct Exponential {
    /// Rate parameter λ; the distribution's mean is `1 / lambda`.
    pub lambda: f64,
}
impl Exponential {
    /// Creates an exponential distribution with rate `lambda`.
    ///
    /// # Errors
    /// Returns [`StatsError::InvalidInput`] if `lambda` is not strictly positive
    /// (NaN included).
    pub fn new(lambda: f64) -> StatsResult<Self> {
        // `lambda <= 0.0` alone is false for NaN, so NaN is rejected explicitly.
        if lambda.is_nan() || lambda <= 0.0 {
            return Err(StatsError::InvalidInput {
                message: "Exponential::new: lambda must be positive".to_string(),
            });
        }
        Ok(Self { lambda })
    }

    /// Fits an exponential distribution to `data` by maximum likelihood.
    ///
    /// The maximum-likelihood estimate of the rate is the reciprocal of the
    /// sample mean.
    ///
    /// # Errors
    /// Returns [`StatsError::InvalidInput`] in any of these cases:
    /// * `data` is empty.
    /// * Any value is negative or NaN.
    /// * The sample mean is zero (every observation is zero). No finite rate
    ///   estimate exists in that case.
    pub fn fit(data: &[f64]) -> StatsResult<Self> {
        if data.is_empty() {
            return Err(StatsError::InvalidInput {
                message: "Exponential::fit: data must not be empty".to_string(),
            });
        }
        // NaN fails every comparison, so it has to be checked for explicitly.
        if data.iter().any(|&x| x.is_nan() || x < 0.0) {
            return Err(StatsError::InvalidInput {
                message: "Exponential::fit: all data values must be non-negative".to_string(),
            });
        }
        let mean = data.iter().sum::<f64>() / data.len() as f64;
        // All values are non-negative here, so `mean <= 0.0` means the mean is ±0.
        // 1 / 0 would otherwise produce an infinite rate.
        if mean <= 0.0 {
            return Err(StatsError::InvalidInput {
                message: "Exponential::fit: data mean must be positive".to_string(),
            });
        }
        Self::new(1.0 / mean)
    }
}
impl crate::distributions::traits::Distribution for Exponential {
    /// Human-readable distribution name.
    fn name(&self) -> &str {
        "Exponential"
    }

    /// The exponential distribution has a single parameter, the rate λ.
    fn num_params(&self) -> usize {
        1
    }

    /// Density at `x`; delegates to [`exponential_pdf`].
    fn pdf(&self, x: f64) -> StatsResult<f64> {
        exponential_pdf(x, self.lambda)
    }

    /// Natural log of the density, `ln λ − λx`.
    ///
    /// Returns `InvalidInput` for negative or NaN `x`. It also returns `InvalidInput`
    /// for a rate that is not strictly positive, which can happen because `lambda`
    /// is a public field.
    fn logpdf(&self, x: f64) -> StatsResult<f64> {
        if x.is_nan() || x < 0.0 {
            return Err(StatsError::InvalidInput {
                message: "Exponential::logpdf: x must be non-negative".to_string(),
            });
        }
        if self.lambda.is_nan() || self.lambda <= 0.0 {
            return Err(StatsError::InvalidInput {
                message: "Exponential::logpdf: lambda must be positive".to_string(),
            });
        }
        let log_rate = self.lambda.ln();
        // At x = 0 return ln λ directly so an infinite rate does not evaluate ∞·0 = NaN.
        Ok(if x == 0.0 {
            log_rate
        } else {
            log_rate - self.lambda * x
        })
    }

    /// Cumulative probability up to `x`; delegates to [`exponential_cdf`].
    fn cdf(&self, x: f64) -> StatsResult<f64> {
        exponential_cdf(x, self.lambda)
    }

    /// Quantile for probability `p`; delegates to [`exponential_inverse_cdf`].
    fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
        exponential_inverse_cdf(p, self.lambda)
    }

    /// Mean `1 / λ` (not validated; assumes a positive rate).
    fn mean(&self) -> f64 {
        1.0 / self.lambda
    }

    /// Variance `1 / λ²` (not validated; assumes a positive rate).
    fn variance(&self) -> f64 {
        1.0 / (self.lambda * self.lambda)
    }
}
/// Unit tests for the exponential-distribution free functions and config.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPSILON: f64 = 1e-10;

    /// Asserts that `actual` lies within `EPSILON` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPSILON);
    }

    /// Asserts that `result` failed with a `StatsError::InvalidInput`.
    fn assert_invalid_input(result: StatsResult<f64>) {
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    /// Asserts that `result` failed with `InvalidInput` whose message contains `needle`.
    fn assert_invalid_input_message(result: StatsResult<f64>, needle: &str) {
        assert!(result.is_err());
        match result {
            Err(StatsError::InvalidInput { message }) => assert!(message.contains(needle)),
            _ => panic!("Expected InvalidInput error"),
        }
    }

    #[test]
    fn test_exponential_pdf() {
        let lambda = 2.0_f64;
        assert_eq!(exponential_pdf(0.0, lambda).unwrap(), lambda);
        for x in [1.0, 0.5] {
            assert_close(
                exponential_pdf(x, lambda).unwrap(),
                lambda * (-lambda * x).exp(),
            );
        }
    }

    #[test]
    fn test_exponential_cdf() {
        let lambda = 2.0_f64;
        assert_close(exponential_cdf(0.0, lambda).unwrap(), 0.0);
        for x in [1.0, 0.5] {
            assert_close(
                exponential_cdf(x, lambda).unwrap(),
                1.0 - (-lambda * x).exp(),
            );
        }
    }

    #[test]
    fn test_exponential_inverse_cdf() {
        let lambda = 2.0_f64;
        for p in [0.1, 0.25, 0.5, 0.75, 0.9] {
            let quantile = exponential_inverse_cdf(p, lambda).unwrap();
            let round_trip = exponential_cdf(quantile, lambda).unwrap();
            assert!(
                (round_trip - p).abs() < EPSILON,
                "Inverse CDF failed for p = {}: got {}, expected {}",
                p,
                round_trip,
                p
            );
        }
    }

    #[test]
    fn test_exponential_mean() {
        let lambda = 2.0_f64;
        assert_close(exponential_mean(lambda).unwrap(), 1.0 / lambda);
    }

    #[test]
    fn test_exponential_variance() {
        let lambda = 2.0_f64;
        assert_close(exponential_variance(lambda).unwrap(), 1.0 / (lambda * lambda));
    }

    #[test]
    fn test_exponential_pdf_invalid_lambda() {
        assert_invalid_input_message(exponential_pdf(1.0, -2.0), "lambda must be positive");
    }

    #[test]
    fn test_exponential_pdf_invalid_x() {
        assert_invalid_input_message(exponential_pdf(-1.0, 2.0), "x must be non-negative");
    }

    #[test]
    fn test_exponential_config() {
        assert!(ExponentialConfig::new(2.0).is_ok());
        for bad_lambda in [0.0, -1.0] {
            assert!(ExponentialConfig::new(bad_lambda).is_err());
        }
    }

    #[test]
    fn test_exponential_inverse_cdf_p_negative() {
        assert_invalid_input(exponential_inverse_cdf(-0.1, 2.0));
    }

    #[test]
    fn test_exponential_inverse_cdf_p_greater_than_one() {
        assert_invalid_input(exponential_inverse_cdf(1.5, 2.0));
    }

    #[test]
    fn test_exponential_cdf_invalid_lambda() {
        assert_invalid_input(exponential_cdf(1.0, -2.0));
    }

    #[test]
    fn test_exponential_cdf_invalid_x() {
        assert_invalid_input(exponential_cdf(-1.0, 2.0));
    }

    #[test]
    fn test_exponential_mean_invalid_lambda() {
        assert_invalid_input(exponential_mean(0.0));
    }

    #[test]
    fn test_exponential_variance_invalid_lambda() {
        assert_invalid_input(exponential_variance(0.0));
    }

    #[test]
    fn test_exponential_pdf_x_positive() {
        let (lambda, x) = (2.0_f64, 0.5_f64);
        assert_close(exponential_pdf(x, lambda).unwrap(), lambda * (-lambda * x).exp());
    }

    #[test]
    fn test_exponential_inverse_cdf_p_zero() {
        assert_eq!(exponential_inverse_cdf(0.0, 2.0).unwrap(), 0.0);
    }

    #[test]
    fn test_exponential_inverse_cdf_p_one() {
        let quantile = exponential_inverse_cdf(1.0, 2.0).unwrap();
        assert!(quantile.is_infinite() || quantile > 1e10);
    }

    #[test]
    fn test_exponential_inverse_cdf_lambda_zero() {
        assert_invalid_input(exponential_inverse_cdf(0.5, 0.0));
    }

    #[test]
    fn test_exponential_inverse_cdf_lambda_negative() {
        assert_invalid_input(exponential_inverse_cdf(0.5, -1.0));
    }
}