use crate::DType;
use crate::stats::distribution::{DiscreteDistribution, Distribution};
use crate::stats::error::{StatsError, StatsResult};
use numr::algorithm::special::SpecialFunctions;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Geometric distribution parameterised by a per-trial success probability.
///
/// The probability mass function implemented for this type is `q^k * p` for
/// `k = 0, 1, 2, ...`, i.e. `k` counts the failures that precede the first
/// success.
#[derive(Debug, Clone, Copy)]
pub struct Geometric {
/// Success probability per trial; `new` validates it against `(0, 1]`.
p: f64,
/// Failure probability per trial, stored as `1.0 - p` at construction time.
q: f64,
}
impl Geometric {
    /// Creates a geometric distribution with success probability `p`.
    ///
    /// The complementary probability `q = 1 - p` is computed once and cached.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] when `p` is NaN or lies
    /// outside the half-open interval `(0, 1]`.
    pub fn new(p: f64) -> StatsResult<Self> {
        // NaN compares false against every bound, so a plain range test
        // (`p <= 0.0 || p > 1.0`) would silently accept it; reject it
        // explicitly before the range comparison.
        if p.is_nan() || p <= 0.0 || p > 1.0 {
            return Err(StatsError::InvalidParameter {
                name: "p".to_string(),
                value: p,
                reason: "probability must be in (0, 1]".to_string(),
            });
        }
        Ok(Self { p, q: 1.0 - p })
    }

    /// Returns the success probability `p` the distribution was built with.
    pub fn p(&self) -> f64 {
        self.p
    }
}
impl Distribution for Geometric {
    /// Mean of the distribution, `q / p`.
    fn mean(&self) -> f64 {
        let (p, q) = (self.p, self.q);
        q / p
    }

    /// Variance of the distribution, `q / p^2`.
    fn var(&self) -> f64 {
        let (p, q) = (self.p, self.q);
        q / (p * p)
    }

    /// Entropy in nats, `(-q ln q - p ln p) / p`.
    ///
    /// Returns `0` when `q == 0`, which also avoids evaluating `0 * ln 0`.
    fn entropy(&self) -> f64 {
        let (p, q) = (self.p, self.q);
        if q == 0.0 {
            0.0
        } else {
            (-q * q.ln() - p * p.ln()) / p
        }
    }

    /// Median, `ceil(-1 / log2(q)) - 1`, floored at zero.
    ///
    /// Returns `0` when `q == 0`.
    fn median(&self) -> f64 {
        let q = self.q;
        if q == 0.0 {
            0.0
        } else {
            let unclamped = (-1.0 / q.log2()).ceil() - 1.0;
            f64::max(unclamped, 0.0)
        }
    }

    /// Mode of the distribution; always `0`.
    fn mode(&self) -> f64 {
        0.0
    }

    /// Skewness, `(2 - p) / sqrt(q)`.
    fn skewness(&self) -> f64 {
        let numerator = 2.0 - self.p;
        numerator / self.q.sqrt()
    }

    /// Kurtosis term `6 + p^2 / q` (the excess-kurtosis form for the
    /// geometric distribution).
    fn kurtosis(&self) -> f64 {
        let p_squared = self.p * self.p;
        6.0 + p_squared / self.q
    }
}
impl DiscreteDistribution for Geometric {
    /// Probability mass `P(X = k) = q^k * p`.
    ///
    /// When `q == 0` all mass sits on `k = 0`.
    fn pmf(&self, k: u64) -> f64 {
        if self.q == 0.0 {
            return if k == 0 { 1.0 } else { 0.0 };
        }
        // The exponent is carried as f64 rather than cast with `as i32`:
        // that cast wraps for k > i32::MAX and would produce a wrong (even
        // infinite) power.
        self.q.powf(k as f64) * self.p
    }

    /// Natural log of the probability mass, `k * ln(q) + ln(p)`.
    ///
    /// When `q == 0` this is `0` at `k = 0` and negative infinity elsewhere.
    fn log_pmf(&self, k: u64) -> f64 {
        if self.q == 0.0 {
            return if k == 0 { 0.0 } else { f64::NEG_INFINITY };
        }
        (k as f64) * self.q.ln() + self.p.ln()
    }

    /// Cumulative probability `P(X <= k) = 1 - q^(k + 1)`.
    fn cdf(&self, k: u64) -> f64 {
        if self.q == 0.0 {
            return 1.0;
        }
        // `k + 1` is formed in f64 so that k == u64::MAX cannot overflow and
        // large k cannot wrap through an `as i32` cast.
        1.0 - self.q.powf(k as f64 + 1.0)
    }

    /// Survival probability `P(X > k) = q^(k + 1)`.
    fn sf(&self, k: u64) -> f64 {
        if self.q == 0.0 {
            return 0.0;
        }
        // Same f64 exponent as `cdf`, for the same overflow/wrap reasons.
        self.q.powf(k as f64 + 1.0)
    }

    /// Quantile function: `ceil(ln(1 - prob) / ln(q)) - 1`, floored at zero.
    ///
    /// Special cases:
    /// * `prob == 0` or `q == 0` yields `0` (with `q == 0` every quantile,
    ///   including `prob == 1`, is `0` because all mass is at zero);
    /// * `prob == 1` with `q > 0` yields `u64::MAX` as a stand-in for an
    ///   unbounded quantile.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidProbability`] when `prob` is NaN or
    /// outside `[0, 1]`.
    fn ppf(&self, prob: f64) -> StatsResult<u64> {
        if !(0.0..=1.0).contains(&prob) {
            return Err(StatsError::InvalidProbability { value: prob });
        }
        // The degenerate case is tested before `prob == 1.0`; otherwise a
        // point mass at zero would report `u64::MAX` as its upper quantile.
        if prob == 0.0 || self.q == 0.0 {
            return Ok(0);
        }
        if prob == 1.0 {
            return Ok(u64::MAX);
        }
        let k = ((1.0 - prob).ln() / self.q.ln()).ceil() - 1.0;
        // Float-to-integer `as` saturates, so huge quantiles clamp to u64::MAX.
        Ok(k.max(0.0) as u64)
    }

    /// Element-wise probability mass, computed as `exp(log_pmf_tensor(k))`.
    fn pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let log_pmf = self.log_pmf_tensor(k, client)?;
        client.exp(&log_pmf)
    }

    /// Element-wise log probability mass, `k * ln(q) + ln(p)`.
    ///
    /// NOTE(review): unlike the scalar `log_pmf`, there is no special case
    /// for `q == 0`; `ln(q)` is then negative infinity and elements with
    /// `k == 0` evaluate `0 * -inf`. Also, `k` is not floored here although
    /// `cdf_tensor`/`sf_tensor` floor it; presumably callers pass
    /// integer-valued elements. Confirm both are intended.
    fn log_pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let log_p = self.p.ln();
        let log_q = self.q.ln();
        let k_times_log_q = client.mul_scalar(k, log_q)?;
        client.add_scalar(&k_times_log_q, log_p)
    }

    /// Element-wise cumulative probability, `1 - q^(floor(k) + 1)`.
    ///
    /// The power is evaluated as `exp((floor(k) + 1) * ln(q))`.
    fn cdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let k_floor = client.floor(k)?;
        let k_plus_1 = client.add_scalar(&k_floor, 1.0)?;
        let log_q = self.q.ln();
        let k_plus_1_log_q = client.mul_scalar(&k_plus_1, log_q)?;
        let q_to_k_plus_1 = client.exp(&k_plus_1_log_q)?;
        // Same `1 - x` idiom as `ppf_tensor`, replacing a negate-then-subtract
        // pair with a single reverse subtraction.
        client.rsub_scalar(&q_to_k_plus_1, 1.0)
    }

    /// Element-wise survival probability, `q^(floor(k) + 1)`.
    ///
    /// The power is evaluated as `exp((floor(k) + 1) * ln(q))`.
    fn sf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let k_floor = client.floor(k)?;
        let k_plus_1 = client.add_scalar(&k_floor, 1.0)?;
        let log_q = self.q.ln();
        let k_plus_1_log_q = client.mul_scalar(&k_plus_1, log_q)?;
        client.exp(&k_plus_1_log_q)
    }

    /// Element-wise quantile, `ceil(ln(1 - p) / ln(q)) - 1`, clamped to be
    /// non-negative.
    ///
    /// NOTE(review): the scalar `ppf` validates its input and special-cases
    /// `prob == 1` and `q == 0`; this tensor path applies the formula to
    /// every element without those checks. Confirm callers expect that.
    fn ppf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        p: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let log_q = self.q.ln();
        let one_minus_p = client.rsub_scalar(p, 1.0)?;
        let ln_one_minus_p = client.log(&one_minus_p)?;
        let ratio = client.div_scalar(&ln_one_minus_p, log_q)?;
        let ceiled = client.ceil(&ratio)?;
        let result = client.sub_scalar(&ceiled, 1.0)?;
        client.clamp(&result, 0.0, f64::INFINITY)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by most floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// Reports whether `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Builds a distribution from a probability that is known to be valid.
    fn geometric(p: f64) -> Geometric {
        Geometric::new(p).unwrap()
    }

    /// Construction accepts `(0, 1]` and rejects values outside it.
    #[test]
    fn test_geometric_creation() {
        let dist = geometric(0.3);
        assert!(close(dist.p(), 0.3, TOL));

        for &invalid in &[0.0, -0.1, 1.1] {
            assert!(Geometric::new(invalid).is_err());
        }
        assert!(Geometric::new(1.0).is_ok());
    }

    /// Mean, variance and mode for `p = 0.25`.
    #[test]
    fn test_geometric_moments() {
        let dist = geometric(0.25);
        assert!(close(dist.mean(), 3.0, TOL));
        assert!(close(dist.var(), 12.0, TOL));
        assert!(close(dist.mode(), 0.0, TOL));
    }

    /// Point masses for `p = 0.5`, plus the partial sum over the first 30.
    #[test]
    fn test_geometric_pmf() {
        let dist = geometric(0.5);
        for &(k, expected) in &[(0u64, 0.5), (1, 0.25), (2, 0.125)] {
            assert!(close(dist.pmf(k), expected, TOL));
        }

        let mut total = 0.0;
        for k in 0..30 {
            total += dist.pmf(k);
        }
        assert!(close(total, 1.0, 1e-9));
    }

    /// Cumulative values for `p = 0.5` and monotonicity over a short range.
    #[test]
    fn test_geometric_cdf() {
        let dist = geometric(0.5);
        for &(k, expected) in &[(0u64, 0.5), (1, 0.75), (2, 0.875)] {
            assert!(close(dist.cdf(k), expected, TOL));
        }
        assert!((0..10u64).all(|k| dist.cdf(k) <= dist.cdf(k + 1)));
    }

    /// The quantile of `cdf(k)` must reach at least that probability.
    #[test]
    fn test_geometric_ppf() {
        let dist = geometric(0.3);
        for k in 0..10 {
            let target = dist.cdf(k);
            let quantile = dist.ppf(target).unwrap();
            let reached = dist.cdf(quantile);
            assert!(
                reached >= target,
                "k={}, prob={}, result={}, cdf={}",
                k,
                target,
                quantile,
                reached
            );
        }
    }

    /// Survival and cumulative probabilities sum to one.
    #[test]
    fn test_geometric_sf() {
        let dist = geometric(0.5);
        for k in 0..10 {
            assert!(close(dist.sf(k) + dist.cdf(k), 1.0, TOL));
        }
    }

    /// With `p = 1` all mass sits on zero.
    #[test]
    fn test_geometric_p_equals_1() {
        let dist = geometric(1.0);
        assert!(close(dist.pmf(0), 1.0, TOL));
        assert!(close(dist.pmf(1), 0.0, TOL));
        assert!(close(dist.cdf(0), 1.0, TOL));
        assert!(close(dist.mean(), 0.0, TOL));
    }
}