use crate::error::SearchError;
use crate::search::{search_monotone, SEARCH_BOUND};
use crate::special::gamma_inc;
use crate::special::gamma_log;
use crate::traits::{Discrete, DiscreteCdf, Mean, Variance};
use thiserror::Error;
/// A Poisson distribution parameterised by its rate `lambda` (λ).
///
/// Instances are built through `Poisson::new` / `Poisson::try_new`, which reject
/// rates that are negative or non-finite. λ = 0 is accepted.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Poisson {
    /// Rate parameter λ; finite and non-negative when constructed via `try_new`.
    lambda: f64,
}
#[derive(Debug, Clone, Copy, PartialEq, Error)]
pub enum PoissonError {
#[error("lambda must be ≥ 0, got {0}")]
LambdaNegative(f64),
#[error("lambda must be finite, got {0}")]
LambdaNotFinite(f64),
#[error("probability {0} outside [0..1]")]
PNotInRange(f64),
#[error("probability {0} outside [0..1]")]
QNotInRange(f64),
#[error("p ({p}) and q ({q}) are not complementary: |p + q - 1| > 3ε")]
PQSumNotOne { p: f64, q: f64 },
#[error(transparent)]
Search(#[from] SearchError),
}
impl Poisson {
    /// Creates a Poisson distribution with rate `lambda`.
    ///
    /// # Panics
    ///
    /// Panics if `lambda` is negative, NaN, or infinite. Use [`Poisson::try_new`]
    /// for a non-panicking constructor.
    #[inline]
    #[track_caller]
    pub fn new(lambda: f64) -> Self {
        // `#[track_caller]` + an explicit `panic!` makes the reported panic location
        // the caller's line, and the message uses the error's Display text.
        match Self::try_new(lambda) {
            Ok(dist) => dist,
            Err(e) => panic!("invalid Poisson parameter: {e}"),
        }
    }

    /// Creates a Poisson distribution with rate `lambda`, validating the input.
    ///
    /// # Errors
    ///
    /// - [`PoissonError::LambdaNotFinite`] if `lambda` is NaN or ±∞ (checked first).
    /// - [`PoissonError::LambdaNegative`] if `lambda < 0`.
    #[inline]
    pub fn try_new(lambda: f64) -> Result<Self, PoissonError> {
        if !lambda.is_finite() {
            return Err(PoissonError::LambdaNotFinite(lambda));
        }
        if lambda < 0.0 {
            return Err(PoissonError::LambdaNegative(lambda));
        }
        Ok(Self { lambda })
    }

    /// Returns the rate parameter λ.
    #[inline]
    pub const fn lambda(&self) -> f64 {
        self.lambda
    }

    /// Finds the rate λ for which `P(X ≤ s) = p` (equivalently `P(X > s) = q`).
    ///
    /// Both `p` and its complement `q` are supplied so the search can match
    /// whichever tail is smaller, avoiding cancellation when `p` is near 0 or 1.
    ///
    /// # Errors
    ///
    /// - [`PoissonError::PNotInRange`] / [`PoissonError::QNotInRange`] if either
    ///   probability is NaN or outside `[0, 1]`.
    /// - [`PoissonError::PQSumNotOne`] if `p + q` differs from 1 by more than 3ε.
    /// - [`PoissonError::Search`] if the underlying root search fails.
    #[inline]
    pub fn search_lambda(p: f64, q: f64, s: u64) -> Result<f64, PoissonError> {
        check_pq(p, q)?;
        let shape = s as f64 + 1.0;
        let f = |lambda: f64| {
            // First component is P(X > s), second is P(X ≤ s) — the same
            // convention the `DiscreteCdf` impl uses for `ccdf` / `cdf`.
            let (sf, cdf) = gamma_inc(shape, lambda);
            // Solve on the smaller tail for precision.
            if p <= q {
                cdf - p
            } else {
                sf - q
            }
        };
        // NOTE(review): argument meaning depends on `search_monotone`'s signature;
        // presumably a [0, SEARCH_BOUND] bracket with an initial guess of 5.0.
        Ok(search_monotone(
            0.0,
            SEARCH_BOUND,
            5.0,
            0.0,
            SEARCH_BOUND,
            f,
        )?)
    }
}
/// Validates that `p` is a probability in the closed interval [0, 1].
///
/// NaN and infinities are rejected with [`PoissonError::PNotInRange`].
#[inline]
fn check_p(p: f64) -> Result<(), PoissonError> {
    let valid = p.is_finite() && (0.0..=1.0).contains(&p);
    if valid {
        Ok(())
    } else {
        Err(PoissonError::PNotInRange(p))
    }
}
/// Validates that `q` is a probability in the closed interval [0, 1].
///
/// NaN and infinities are rejected with [`PoissonError::QNotInRange`].
#[inline]
fn check_q(q: f64) -> Result<(), PoissonError> {
    if q.is_finite() && (0.0..=1.0).contains(&q) {
        return Ok(());
    }
    Err(PoissonError::QNotInRange(q))
}
/// Validates a complementary probability pair `(p, q)`.
///
/// Each value must individually lie in [0, 1] (`p` is checked first), and
/// `|p + q - 1|` must not exceed three machine epsilons.
#[inline]
fn check_pq(p: f64, q: f64) -> Result<(), PoissonError> {
    check_p(p)?;
    check_q(q)?;
    let tolerance = 3.0 * f64::EPSILON;
    let deviation = (p + q - 1.0).abs();
    if deviation > tolerance {
        Err(PoissonError::PQSumNotOne { p, q })
    } else {
        Ok(())
    }
}
impl DiscreteCdf for Poisson {
    type Error = PoissonError;

    /// `P(X ≤ s)`, taken from the second component of `gamma_inc(s + 1, λ)`.
    #[inline]
    fn cdf(&self, s: u64) -> f64 {
        gamma_inc(s as f64 + 1.0, self.lambda).1
    }

    /// `P(X > s)`, taken from the first component of `gamma_inc(s + 1, λ)`.
    #[inline]
    fn ccdf(&self, s: u64) -> f64 {
        gamma_inc(s as f64 + 1.0, self.lambda).0
    }

    /// Smallest integer `s` with `cdf(s) ≥ p`.
    ///
    /// `p == 0` maps to 0 and `p == 1` to `u64::MAX`. If no bracket can be found
    /// before the upper end would overflow, `u64::MAX` is returned.
    ///
    /// # Errors
    ///
    /// [`PoissonError::PNotInRange`] if `p` is NaN or outside [0, 1].
    #[inline]
    fn inverse_cdf(&self, p: f64) -> Result<u64, PoissonError> {
        check_p(p)?;
        if p == 0.0 {
            return Ok(0);
        }
        if p == 1.0 {
            return Ok(u64::MAX);
        }

        // `true` while `k` is still strictly left of the target quantile.
        let below = |k: u64| self.cdf(k) < p;

        // Initial guess: mean + 10 standard deviations (+10 for small λ), then
        // keep doubling until the CDF reaches `p` or doubling would overflow.
        let mut upper = (self.lambda + 10.0 * self.lambda.sqrt() + 10.0).ceil() as u64;
        while below(upper) && upper < u64::MAX / 2 {
            upper *= 2;
        }
        if below(upper) {
            return Ok(u64::MAX);
        }

        // Bisection on [lower, upper] for the first index where `below` is false.
        let mut lower = 0u64;
        while lower < upper {
            let mid = lower + (upper - lower) / 2;
            if below(mid) {
                lower = mid + 1;
            } else {
                upper = mid;
            }
        }
        Ok(lower)
    }
}
impl Poisson {
    /// Continuous inverse of the survival function.
    ///
    /// Returns a real-valued `s` such that the incomplete-gamma extension of
    /// `P(X > s)` equals `q`; the result need not be an integer. The search
    /// matches whichever tail (`1 - q` or `q`) is smaller, for precision.
    ///
    /// # Errors
    ///
    /// - [`PoissonError::QNotInRange`] if `q` is NaN or outside [0, 1].
    /// - [`PoissonError::Search`] if the root search fails.
    #[inline]
    pub fn inverse_ccdf(&self, q: f64) -> Result<f64, PoissonError> {
        check_q(q)?;
        let p = 1.0 - q;
        let rate = self.lambda;
        let residual = |s: f64| {
            // (upper-tail mass, lower-tail mass) of the shifted count.
            let (sf, cdf) = gamma_inc(s + 1.0, rate);
            if p <= q {
                cdf - p
            } else {
                sf - q
            }
        };
        let s = search_monotone(0.0, SEARCH_BOUND, 5.0, 0.0, SEARCH_BOUND, residual)?;
        Ok(s)
    }
}
impl Discrete for Poisson {
    /// Probability mass `P(X = s)`, computed as `exp(ln_pmf(s))`.
    #[inline]
    fn pmf(&self, s: u64) -> f64 {
        self.ln_pmf(s).exp()
    }

    /// Natural log of the probability mass: `s·ln λ − λ − ln Γ(s + 1)`.
    ///
    /// When λ = 0 the distribution is a point mass at 0, so this returns `0.0`
    /// for `s == 0` and `-∞` otherwise. The general formula would evaluate
    /// `0 · ln 0 = 0 · (−∞) = NaN` at `s == 0`.
    #[inline]
    fn ln_pmf(&self, s: u64) -> f64 {
        if self.lambda == 0.0 {
            // Also catches -0.0, which `try_new` accepts since `-0.0 < 0.0` is false.
            return if s == 0 { 0.0 } else { f64::NEG_INFINITY };
        }
        let sf = s as f64;
        sf * self.lambda.ln() - self.lambda - gamma_log(sf + 1.0)
    }
}
impl Mean for Poisson {
    /// The mean of a Poisson distribution equals its rate λ.
    #[inline]
    fn mean(&self) -> f64 {
        let Self { lambda } = *self;
        lambda
    }
}
impl Variance for Poisson {
#[inline]
fn variance(&self) -> f64 {
self.lambda
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor validation: negative → LambdaNegative, NaN/∞ → LambdaNotFinite,
    // and λ = 0 is accepted.
    #[test]
    fn new_rejects_bad_lambda() {
        assert!(matches!(
            Poisson::try_new(-1.0),
            Err(PoissonError::LambdaNegative(_))
        ));
        assert!(Poisson::try_new(0.0).is_ok());
        assert!(matches!(
            Poisson::try_new(f64::NAN),
            Err(PoissonError::LambdaNotFinite(_))
        ));
        assert!(matches!(
            Poisson::try_new(f64::INFINITY),
            Err(PoissonError::LambdaNotFinite(_))
        ));
    }

    // Round trip: inverting ccdf at an integer point recovers that integer.
    #[test]
    fn inverse_ccdf_matches_integer_quantile_at_integer_boundary() {
        let p = Poisson::new(3.0);
        let q_target = p.ccdf(2);
        let s = p.inverse_ccdf(q_target).unwrap();
        assert!((s - 2.0).abs() < 1e-6, "got s = {s}");
    }

    // A target strictly between ccdf(3) and ccdf(2) must invert to a
    // non-integer s in (2, 3), exercising the continuous extension.
    #[test]
    fn inverse_ccdf_between_integers() {
        let p = Poisson::new(3.0);
        let hi_sf = p.ccdf(2);
        let lo_sf = p.ccdf(3);
        let q_target = 0.5 * (lo_sf + hi_sf);
        let s = p.inverse_ccdf(q_target).unwrap();
        assert!(s > 2.0 && s < 3.0, "got s = {s}");
    }

    // `p` is validated before `q`, so every case here reports PNotInRange,
    // including the one where `q` is also invalid.
    #[test]
    fn search_lambda_rejects_bad_p() {
        assert!(matches!(
            Poisson::search_lambda(-0.1, 1.1, 3),
            Err(PoissonError::PNotInRange(_))
        ));
        assert!(matches!(
            Poisson::search_lambda(1.5, -0.5, 3),
            Err(PoissonError::PNotInRange(_))
        ));
        assert!(matches!(
            Poisson::search_lambda(f64::NAN, 0.5, 3),
            Err(PoissonError::PNotInRange(_))
        ));
    }

    // p == 0 is short-circuited to quantile 0.
    #[test]
    fn inverse_cdf_p_zero_returns_zero() {
        let p = Poisson::new(5.0);
        assert_eq!(p.inverse_cdf(0.0).unwrap(), 0);
    }

    // Out-of-range probabilities on either side are rejected.
    #[test]
    fn inverse_cdf_rejects_bad_p() {
        let p = Poisson::new(5.0);
        assert!(matches!(
            p.inverse_cdf(-0.1),
            Err(PoissonError::PNotInRange(_))
        ));
        assert!(matches!(
            p.inverse_cdf(1.1),
            Err(PoissonError::PNotInRange(_))
        ));
    }

    // Round trip for search_lambda: feed (cdf, ccdf) of a known λ and check
    // that λ is recovered. Skipped under Miri.
    #[cfg(not(miri))]
    #[test]
    fn search_lambda_uses_precision_pivot_at_both_tails() {
        let lambda = 5.0_f64;
        let s = 10u64;
        let dist = Poisson::new(lambda);
        let p_target = dist.cdf(s);
        let q_target = dist.ccdf(s);
        let recovered = Poisson::search_lambda(p_target, q_target, s).unwrap();
        assert!(
            (recovered - lambda).abs() < 1e-5,
            "p_target={p_target}, recovered={recovered}"
        );
    }

    // Far right tail (λ = 200, s = 285). The ccdf tolerance (1e-22) is far below
    // f64 resolution near 1.0, so this only passes if the upper tail is computed
    // directly rather than as 1 − cdf.
    #[test]
    fn extreme_right_tail_matches_high_precision_reference() {
        let p = Poisson::new(200.0);
        let expected_cdf = 0.999_999_993_591_493_9;
        let expected_sf = 6.408_506_071_899_014e-9;
        assert!((p.cdf(285) - expected_cdf).abs() < 1e-15);
        assert!((p.ccdf(285) - expected_sf).abs() < 1e-22);
    }

    // Mean and variance both equal λ; ln_pmf is finite for a regular λ.
    #[test]
    fn moments_match_lambda() {
        let p = Poisson::new(4.0);
        assert_eq!(p.mean(), 4.0);
        assert_eq!(p.variance(), 4.0);
        assert!(p.ln_pmf(3).is_finite());
    }
}