use crate::distributions::traits::DiscreteDistribution;
use crate::error::{StatsError, StatsResult};
use serde::{Deserialize, Serialize};
/// Geometric distribution described by a single probability parameter `p`.
///
/// The `DiscreteDistribution` impl in this file evaluates the mass function
/// as `p * (1 - p)^(k - 1)` for `k >= 1` and assigns zero mass to `k = 0`.
///
/// NOTE(review): `p` is a public field and the type derives `Deserialize`,
/// so values built with a struct literal or via deserialization do not pass
/// through the range check performed by `Geometric::new`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Geometric {
/// Probability parameter; `Geometric::new` only accepts values in `(0, 1]`.
pub p: f64,
}
impl Geometric {
    /// Creates a geometric distribution with parameter `p`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] unless `0 < p <= 1`. A NaN `p`
    /// fails both comparisons and is therefore rejected as well.
    pub fn new(p: f64) -> StatsResult<Self> {
        if p > 0.0 && p <= 1.0 {
            Ok(Self { p })
        } else {
            Err(StatsError::InvalidInput {
                message: String::from("Geometric::new: p must be in (0, 1]"),
            })
        }
    }

    /// Estimates `p` from observations as the reciprocal of the sample mean.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] when `data` is empty or when any
    /// value is not an integer `>= 1` (NaN and infinities fail the
    /// fractional-part test and are rejected too).
    pub fn fit(data: &[f64]) -> StatsResult<Self> {
        let invalid = |message: &str| StatsError::InvalidInput {
            message: message.to_string(),
        };

        let count = data.len();
        if count == 0 {
            return Err(invalid("Geometric::fit: data must not be empty"));
        }

        let is_positive_integer = |x: &f64| *x >= 1.0 && x.fract() == 0.0;
        if !data.iter().all(is_positive_integer) {
            return Err(invalid(
                "Geometric::fit: all data values must be positive integers (≥ 1)",
            ));
        }

        let sample_mean = data.iter().sum::<f64>() / count as f64;
        // Every observation is >= 1, so the reciprocal cannot exceed 1; the
        // lower clamp keeps an enormous mean from producing a `p` of zero,
        // which `new` would reject.
        Self::new(sample_mean.recip().clamp(1e-15, 1.0))
    }
}
impl DiscreteDistribution for Geometric {
fn name(&self) -> &str {
"Geometric"
}
fn num_params(&self) -> usize {
1
}
fn pmf(&self, k: u64) -> StatsResult<f64> {
if k == 0 {
return Ok(0.0);
}
Ok(self.logpmf(k)?.exp())
}
fn logpmf(&self, k: u64) -> StatsResult<f64> {
if k == 0 {
return Ok(f64::NEG_INFINITY);
}
Ok(self.p.ln() + (k - 1) as f64 * (1.0 - self.p).ln())
}
fn cdf(&self, k: u64) -> StatsResult<f64> {
if k == 0 {
return Ok(0.0);
}
Ok(1.0 - (1.0 - self.p).powf(k as f64))
}
fn inverse_cdf(&self, p: f64) -> crate::error::StatsResult<u64> {
use crate::error::StatsError;
if !(0.0..=1.0).contains(&p) {
return Err(StatsError::InvalidInput {
message: format!("Geometric::inverse_cdf: p must be in [0, 1], got {p}"),
});
}
if p == 0.0 {
return Ok(0);
}
if p == 1.0 || self.p == 1.0 {
return Ok(1);
}
let k = (1.0 - p).ln() / (1.0 - self.p).ln();
Ok(k.ceil().max(1.0) as u64)
}
fn mean(&self) -> f64 {
1.0 / self.p
}
fn variance(&self) -> f64 {
(1.0 - self.p) / (self.p * self.p)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Mean and variance for p = 0.25 should be 4 and 12.
    #[test]
    fn test_geometric_mean_variance() {
        let dist = Geometric::new(0.25).unwrap();
        let mean_err = (dist.mean() - 4.0).abs();
        let var_err = (dist.variance() - 12.0).abs();
        assert!(mean_err < TOL);
        assert!(var_err < TOL);
    }

    /// The mass at k = 1 equals p.
    #[test]
    fn test_geometric_pmf_k1() {
        let dist = Geometric::new(0.5).unwrap();
        let mass = dist.pmf(1).unwrap();
        assert!((mass - 0.5).abs() < TOL);
    }

    /// The CDF approaches 1 for large k.
    #[test]
    fn test_geometric_cdf_large_k() {
        let dist = Geometric::new(0.5).unwrap();
        let cumulative = dist.cdf(100).unwrap();
        assert!(cumulative > 0.999_999);
    }

    /// `logpmf` agrees with the log of `pmf`.
    #[test]
    fn test_geometric_logpmf() {
        let dist = Geometric::new(0.3).unwrap();
        let from_pmf = dist.pmf(3).unwrap().ln();
        let direct = dist.logpmf(3).unwrap();
        assert!((direct - from_pmf).abs() < TOL);
    }

    /// `fit` recovers p as n / sum(data).
    #[test]
    fn test_geometric_fit() {
        let samples = [1.0, 2.0, 1.0, 3.0, 1.0, 2.0, 4.0, 1.0];
        let fitted = Geometric::fit(&samples).unwrap();
        let total: f64 = samples.iter().sum();
        let expected_p = samples.len() as f64 / total;
        assert!((fitted.p - expected_p).abs() < TOL);
    }

    /// Parameters outside (0, 1] are rejected.
    #[test]
    fn test_geometric_invalid() {
        for bad_p in [0.0, -0.1, 1.1] {
            assert!(Geometric::new(bad_p).is_err());
        }
    }
}