use crate::DType;
use crate::stats::continuous::special;
use crate::stats::distribution::{DiscreteDistribution, Distribution};
use crate::stats::error::{StatsError, StatsResult};
use numr::algorithm::special::SpecialFunctions;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Poisson distribution over the non-negative integers with rate (and mean) `lambda`.
///
/// Construct it with `Poisson::new`, which rejects non-positive and non-finite rates.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Poisson {
    /// Rate parameter λ; `Poisson::new` only accepts finite values strictly greater than zero.
    lambda: f64,
}
impl Poisson {
    /// Creates a Poisson distribution with rate (mean) `lambda`.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::InvalidParameter` for parameter `"lambda"` when:
    /// - `lambda <= 0.0` (this also covers `-inf`), with reason `"rate must be positive"`;
    /// - `lambda` is NaN or `+inf`, with reason `"must be finite"`.
    pub fn new(lambda: f64) -> StatsResult<Self> {
        let invalid = |reason: &str| StatsError::InvalidParameter {
            name: "lambda".to_string(),
            value: lambda,
            reason: reason.to_string(),
        };

        // Order matters: NaN fails the `<= 0.0` comparison and falls through to the
        // finiteness check, while `-inf` is reported by the positivity check.
        if lambda <= 0.0 {
            return Err(invalid("rate must be positive"));
        }
        if !lambda.is_finite() {
            return Err(invalid("must be finite"));
        }

        Ok(Self { lambda })
    }

    /// The rate parameter λ.
    pub fn lambda(&self) -> f64 {
        self.lambda
    }

    /// Alias for [`Poisson::lambda`]: the expected number of events per interval.
    pub fn rate(&self) -> f64 {
        self.lambda()
    }
}
impl Distribution for Poisson {
fn mean(&self) -> f64 {
self.lambda
}
fn var(&self) -> f64 {
self.lambda
}
fn entropy(&self) -> f64 {
0.5 * (2.0 * std::f64::consts::PI * std::f64::consts::E * self.lambda).ln()
}
fn median(&self) -> f64 {
(self.lambda + 1.0 / 3.0 - 0.02 / self.lambda).floor()
}
fn mode(&self) -> f64 {
self.lambda.floor()
}
fn skewness(&self) -> f64 {
1.0 / self.lambda.sqrt()
}
fn kurtosis(&self) -> f64 {
1.0 / self.lambda }
}
impl DiscreteDistribution for Poisson {
    /// Probability mass `P(X = k) = λ^k e^{-λ} / k!`, computed as `exp(log_pmf(k))`.
    fn pmf(&self, k: u64) -> f64 {
        self.log_pmf(k).exp()
    }

    /// Log probability mass `k ln λ − λ − ln Γ(k + 1)`.
    fn log_pmf(&self, k: u64) -> f64 {
        let k_f = k as f64;
        k_f * self.lambda.ln() - self.lambda - special::lgamma(k_f + 1.0)
    }

    /// Cumulative probability `P(X ≤ k) = Q(k + 1, λ)`, where `Q` is the regularized upper
    /// incomplete gamma function (`special::gammaincc`).
    fn cdf(&self, k: u64) -> f64 {
        // Add 1 in floating point: the integer `k + 1` overflows at `k == u64::MAX`,
        // which is exactly what `ppf(1.0)` returns.
        special::gammaincc(k as f64 + 1.0, self.lambda)
    }

    /// Survival function `P(X > k) = P(k + 1, λ)`, the regularized lower incomplete gamma
    /// function (`special::gammainc`).
    fn sf(&self, k: u64) -> f64 {
        // Floating-point `+ 1.0` for the same overflow reason as in `cdf`.
        special::gammainc(k as f64 + 1.0, self.lambda)
    }

    /// Quantile function: the smallest `k` with `P(X ≤ k) ≥ prob`.
    ///
    /// Returns `0` for `prob == 0.0` and `u64::MAX` for `prob == 1.0`.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::InvalidProbability` when `prob` is NaN or outside `[0, 1]`.
    fn ppf(&self, prob: f64) -> StatsResult<u64> {
        if !(0.0..=1.0).contains(&prob) {
            return Err(StatsError::InvalidProbability { value: prob });
        }
        if prob == 0.0 {
            return Ok(0);
        }
        if prob == 1.0 {
            return Ok(u64::MAX);
        }

        // Starting guess only. NOTE(review): presumably `gammaincinv(a, p)` inverts the
        // regularized lower incomplete gamma in its second argument (a Gamma(λ, 1) quantile).
        // The two searches below make the result exact whatever the guess is.
        let initial = special::gammaincinv(self.lambda, prob);
        // A non-finite guess would saturate to u64::MAX (or 0) and make the linear search
        // needlessly long; λ is always a reasonable place to start.
        let start = if initial.is_finite() {
            initial
        } else {
            self.lambda
        };
        // `as` saturates, so a negative guess becomes 0.
        let mut k = start.floor() as u64;

        // Move up until the CDF reaches `prob`...
        while self.cdf(k) < prob && k < u64::MAX - 1 {
            k += 1;
        }
        // ...then down to the smallest such k.
        while k > 0 && self.cdf(k - 1) >= prob {
            k -= 1;
        }
        Ok(k)
    }

    /// Element-wise probability mass, computed as `exp(log_pmf_tensor(k))`.
    fn pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let log_pmf = self.log_pmf_tensor(k, client)?;
        client.exp(&log_pmf)
    }

    /// Element-wise log probability mass `k ln λ − λ − lgamma(k + 1)`.
    ///
    /// NOTE(review): unlike `cdf_tensor`, `k` is not floored here, so non-integer entries
    /// yield the continuous lgamma extension rather than `-inf`. Confirm that callers only
    /// pass integer-valued tensors.
    fn log_pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let ln_lambda = self.lambda.ln();
        let k_times_ln_lambda = client.mul_scalar(k, ln_lambda)?;
        let k_plus_1 = client.add_scalar(k, 1.0)?;
        let lgamma_k_plus_1 = client.lgamma(&k_plus_1)?;
        let result = client.sub_scalar(&k_times_ln_lambda, self.lambda)?;
        client.sub(&result, &lgamma_k_plus_1)
    }

    /// Element-wise `P(X ≤ floor(k)) = gammaincc(floor(k) + 1, λ)`.
    ///
    /// NOTE(review): for `k < 0` the first argument passed to `gammaincc` is `<= 0`. The
    /// result then depends on numr's edge-case handling; confirm it yields 0.
    fn cdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let k_floor = client.floor(k)?;
        let k_plus_1 = client.add_scalar(&k_floor, 1.0)?;
        let shape = k_plus_1.shape();
        // `gammaincc` takes two tensors, so λ is broadcast by filling a same-shaped tensor.
        let lambda_tensor = client.fill(shape, self.lambda, k_plus_1.dtype())?;
        client.gammaincc(&k_plus_1, &lambda_tensor)
    }

    /// Element-wise `P(X > floor(k)) = gammainc(floor(k) + 1, λ)`.
    fn sf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let k_floor = client.floor(k)?;
        let k_plus_1 = client.add_scalar(&k_floor, 1.0)?;
        let shape = k_plus_1.shape();
        let lambda_tensor = client.fill(shape, self.lambda, k_plus_1.dtype())?;
        client.gammainc(&k_plus_1, &lambda_tensor)
    }

    /// Element-wise quantile *estimate*: `floor(gammaincinv(λ, p))`.
    ///
    /// NOTE(review): this is only the starting guess used by scalar `ppf`, without its
    /// CDF-based correction, so results can differ from `ppf`. For example, with λ = 0.5 and
    /// p = 0.95 the exact quantile is 2. Assuming `gammaincinv` inverts the lower regularized
    /// incomplete gamma, however, this returns floor(≈1.92) = 1. A correction step is needed
    /// for exact results.
    fn ppf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        p: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let shape = p.shape();
        let lambda_tensor = client.fill(shape, self.lambda, p.dtype())?;
        let initial = client.gammaincinv(&lambda_tensor, p)?;
        client.floor(&initial)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_poisson_creation() {
        let dist = Poisson::new(5.0).unwrap();
        assert!(close(dist.lambda(), 5.0, 1e-10));
        assert!(close(dist.rate(), 5.0, 1e-10));
        for bad in [0.0, -1.0] {
            assert!(Poisson::new(bad).is_err());
        }
    }

    #[test]
    fn test_poisson_moments() {
        let dist = Poisson::new(4.0).unwrap();
        let checks = [
            (dist.mean(), 4.0),
            (dist.var(), 4.0),
            (dist.std(), 2.0),
            (dist.skewness(), 0.5),
            (dist.kurtosis(), 0.25),
        ];
        for (actual, expected) in checks {
            assert!(close(actual, expected, 1e-10));
        }
    }

    #[test]
    fn test_poisson_pmf() {
        let dist = Poisson::new(3.0).unwrap();
        let e_neg3 = (-3.0_f64).exp();
        assert!(close(dist.pmf(0), e_neg3, 1e-10));
        // P(X = 3) = 3^3 e^{-3} / 3!
        assert!(close(dist.pmf(3), 27.0 * e_neg3 / 6.0, 1e-10));
        let total: f64 = (0..50).map(|k| dist.pmf(k)).sum();
        assert!(close(total, 1.0, 1e-10));
    }

    #[test]
    fn test_poisson_cdf() {
        let dist = Poisson::new(3.0).unwrap();
        assert!(close(dist.cdf(0), (-3.0_f64).exp(), 1e-10));
        let summed: f64 = (0..=3).map(|k| dist.pmf(k)).sum();
        assert!(close(dist.cdf(3), summed, 1e-6));
        assert!((0..10).all(|k| dist.cdf(k) <= dist.cdf(k + 1)));
    }

    #[test]
    fn test_poisson_ppf() {
        let dist = Poisson::new(5.0).unwrap();
        for prob in (0..15).map(|k| dist.cdf(k)) {
            let quantile = dist.ppf(prob).unwrap();
            assert!(dist.cdf(quantile) >= prob);
        }
    }

    #[test]
    fn test_poisson_sf() {
        let dist = Poisson::new(3.0).unwrap();
        for k in 0..10 {
            assert!(close(dist.sf(k) + dist.cdf(k), 1.0, 1e-10));
        }
    }
}