use crate::error::{StatsError, StatsResult};
/// A continuous probability distribution over `f64` values.
///
/// Implementors provide the density, CDF, quantile function and the first two
/// moments. The log-density, standard deviation, log-likelihood and the AIC/BIC
/// information criteria have default implementations built on top of those.
pub trait Distribution {
    /// Human-readable name of the distribution.
    fn name(&self) -> &str;

    /// Number of free parameters; used as `k` by [`Distribution::aic`] and
    /// [`Distribution::bic`].
    fn num_params(&self) -> usize;

    /// Probability density at `x`.
    fn pdf(&self, x: f64) -> StatsResult<f64>;

    /// Natural logarithm of the density at `x`.
    ///
    /// The default applies `ln` to [`Distribution::pdf`], so a density of zero
    /// yields `-inf`. Implementors may override it with a direct formula.
    fn logpdf(&self, x: f64) -> StatsResult<f64> {
        self.pdf(x).map(f64::ln)
    }

    /// Cumulative distribution function evaluated at `x`.
    fn cdf(&self, x: f64) -> StatsResult<f64>;

    /// Quantile function (inverse of [`Distribution::cdf`]) evaluated at `p`.
    fn inverse_cdf(&self, p: f64) -> StatsResult<f64>;

    /// Mean of the distribution.
    fn mean(&self) -> f64;

    /// Variance of the distribution.
    fn variance(&self) -> f64;

    /// Standard deviation, computed as the square root of [`Distribution::variance`].
    fn std_dev(&self) -> f64 {
        f64::sqrt(self.variance())
    }

    /// Sum of [`Distribution::logpdf`] over every sample in `data`.
    ///
    /// An empty slice gives `0.0`.
    ///
    /// # Errors
    /// Returns the first error produced by `logpdf`; later samples are not evaluated.
    fn log_likelihood(&self, data: &[f64]) -> StatsResult<f64> {
        data.iter()
            .try_fold(0.0_f64, |total, &x| -> StatsResult<f64> {
                Ok(total + self.logpdf(x)?)
            })
    }

    /// Akaike information criterion: `2k - 2 ln L`, where `k` is
    /// [`Distribution::num_params`] and `ln L` is the log-likelihood of `data`.
    ///
    /// # Errors
    /// Propagates any error from [`Distribution::log_likelihood`].
    fn aic(&self, data: &[f64]) -> StatsResult<f64> {
        self.log_likelihood(data)
            .map(|log_lik| 2.0 * self.num_params() as f64 - 2.0 * log_lik)
    }

    /// Bayesian information criterion: `k ln(n) - 2 ln L`, where `n` is
    /// `data.len()`.
    ///
    /// With empty `data`, `ln(n)` evaluates to `-inf`.
    ///
    /// # Errors
    /// Propagates any error from [`Distribution::log_likelihood`].
    fn bic(&self, data: &[f64]) -> StatsResult<f64> {
        self.log_likelihood(data).map(|log_lik| {
            let sample_size = data.len() as f64;
            self.num_params() as f64 * sample_size.ln() - 2.0 * log_lik
        })
    }
}
/// A discrete probability distribution over `u64` outcomes.
///
/// Implementors provide the mass function, CDF and the first two moments. The
/// log-mass, quantile function, standard deviation, log-likelihood and the
/// AIC/BIC information criteria have default implementations built on top of
/// those.
pub trait DiscreteDistribution {
    /// Human-readable name of the distribution.
    fn name(&self) -> &str;

    /// Number of free parameters; used as `k` by [`DiscreteDistribution::aic`]
    /// and [`DiscreteDistribution::bic`].
    fn num_params(&self) -> usize;

    /// Probability mass at `k`.
    fn pmf(&self, k: u64) -> StatsResult<f64>;

    /// Natural logarithm of the probability mass at `k`.
    ///
    /// The default applies `ln` to [`DiscreteDistribution::pmf`], so zero mass
    /// yields `-inf`.
    fn logpmf(&self, k: u64) -> StatsResult<f64> {
        self.pmf(k).map(|p| p.ln())
    }

    /// Cumulative distribution function at `k`.
    ///
    /// Expected to be non-decreasing in `k`. The default
    /// [`DiscreteDistribution::inverse_cdf`] relies on this.
    fn cdf(&self, k: u64) -> StatsResult<f64>;

    /// Quantile function: the smallest `k` with `cdf(k) >= p`.
    ///
    /// The default implementation works in two phases, assuming `cdf` is
    /// non-decreasing:
    /// 1. Bracket the answer by probing `hi = 1, 2, 4, …, 2^63` and finally
    ///    `u64::MAX` until `cdf(hi) >= p`.
    /// 2. Binary-search inside that bracket.
    ///
    /// The number of `cdf` calls grows logarithmically with the returned value.
    ///
    /// # Errors
    /// - `StatsError::InvalidInput` if `p` is outside `[0, 1]` or is NaN.
    /// - `StatsError::NumericalError` if even `cdf(u64::MAX) < p`, i.e. no
    ///   `u64` quantile exists. This can also happen for `p == 1.0` when an
    ///   implementor's `cdf` never rounds to exactly `1.0`.
    /// - Any error returned by `cdf` is propagated.
    fn inverse_cdf(&self, p: f64) -> StatsResult<u64> {
        // `contains` is false for NaN, so NaN is rejected here too.
        if !(0.0..=1.0).contains(&p) {
            return Err(StatsError::InvalidInput {
                message: format!("inverse_cdf: p must be in [0, 1], got {p}"),
            });
        }
        if p == 0.0 {
            return Ok(0);
        }
        // Exponential search. Invariant: every k < lo has cdf(k) < p.
        let mut lo: u64 = 0;
        let mut hi: u64 = 1;
        while self.cdf(hi)? < p {
            // u64::MAX has just been probed and is still below p: no quantile fits.
            if hi == u64::MAX {
                return Err(StatsError::NumericalError {
                    message: "inverse_cdf: quantile exceeds u64::MAX".to_string(),
                });
            }
            // cdf(hi) < p, so the answer lies strictly above hi. hi < u64::MAX
            // here, so `hi + 1` cannot overflow.
            lo = hi + 1;
            // Saturation maps 2^63 * 2 to u64::MAX, so the top of the range is
            // still probed on the next iteration rather than skipped.
            hi = hi.saturating_mul(2);
        }
        // Binary search for the smallest k in [lo, hi] with cdf(k) >= p.
        // cdf(hi) >= p holds on entry.
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if self.cdf(mid)? < p {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        Ok(lo)
    }

    /// Mean of the distribution.
    fn mean(&self) -> f64;

    /// Variance of the distribution.
    fn variance(&self) -> f64;

    /// Standard deviation, computed as the square root of
    /// [`DiscreteDistribution::variance`].
    fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Sum of [`DiscreteDistribution::logpmf`] over every observation in `data`.
    ///
    /// An empty slice gives `0.0`.
    ///
    /// # Errors
    /// Returns the first error produced by `logpmf`; later observations are not
    /// evaluated.
    fn log_likelihood(&self, data: &[u64]) -> StatsResult<f64> {
        data.iter()
            .try_fold(0.0_f64, |total, &k| -> StatsResult<f64> {
                Ok(total + self.logpmf(k)?)
            })
    }

    /// Akaike information criterion: `2k - 2 ln L`.
    ///
    /// # Errors
    /// Propagates any error from [`DiscreteDistribution::log_likelihood`].
    fn aic(&self, data: &[u64]) -> StatsResult<f64> {
        let ll = self.log_likelihood(data)?;
        Ok(2.0 * self.num_params() as f64 - 2.0 * ll)
    }

    /// Bayesian information criterion: `k ln(n) - 2 ln L`, where `n` is
    /// `data.len()`.
    ///
    /// With empty `data`, `ln(n)` evaluates to `-inf`.
    ///
    /// # Errors
    /// Propagates any error from [`DiscreteDistribution::log_likelihood`].
    fn bic(&self, data: &[u64]) -> StatsResult<f64> {
        let ll = self.log_likelihood(data)?;
        let n = data.len() as f64;
        Ok(self.num_params() as f64 * n.ln() - 2.0 * ll)
    }
}