use crate::dist::{Discrete, DistError, Distribution, Moments};
use crate::num;
use crate::rng::RngCore;
#[derive(Debug, Clone, Copy)]
pub struct Poisson {
lambda: f64,
}
impl Poisson {
pub fn new(lambda: f64) -> Result<Self, DistError> {
if !(lambda.is_finite() && lambda > 0.0) {
return Err(DistError::InvalidParameter);
}
Ok(Self { lambda })
}
#[inline]
pub fn lambda(&self) -> f64 {
self.lambda
}
#[inline]
fn pmf_rec_start(&self) -> f64 {
(-self.lambda).exp()
}
fn pmf_via_recurrence(&self, k: i64) -> f64 {
if k < 0 {
return 0.0;
}
let k = k as u64;
let mut p = self.pmf_rec_start();
for i in 1..=k {
p *= self.lambda / (i as f64);
}
p
}
fn cdf_via_recurrence(&self, k: i64) -> f64 {
if k < 0 {
return 0.0;
}
let k = k as u64;
let mut p = self.pmf_rec_start();
let mut acc = p;
for i in 1..=k {
p *= self.lambda / (i as f64);
acc += p;
}
acc
}
}
impl Distribution for Poisson {
type Value = i64;
fn cdf(&self, x: Self::Value) -> f64 {
self.cdf_via_recurrence(x)
}
fn in_support(&self, x: Self::Value) -> bool {
x >= 0
}
fn sample<R: RngCore>(&self, rng: &mut R) -> Self::Value {
if self.lambda < 30.0 {
let mut k: i64 = 0;
let mut p = self.pmf_rec_start();
let mut c = p;
let u = rng.next_f64();
while u > c {
k += 1;
p *= self.lambda / (k as f64);
c += p;
}
return k;
}
let lambda = self.lambda;
let m = lambda.floor() as i64; if lambda < 400.0 {
let log_p_m = (m as f64) * lambda.ln() - lambda - ln_factorial_u64(m as u64);
let p_m = log_p_m.exp();
if p_m.partial_cmp(&0.0) != Some(std::cmp::Ordering::Greater) {
let mut k: i64 = 0;
let mut p = self.pmf_rec_start();
let mut c = p;
let u = rng.next_f64();
while u > c {
k += 1;
p *= self.lambda / (k as f64);
c += p;
}
return k;
}
let u = rng.next_f64();
let mut c = p_m;
if u <= c {
return m;
}
let mut left = p_m;
let mut right = p_m;
let mut i: i64 = 1;
loop {
if i <= m {
left *= (m - (i - 1)) as f64 / lambda; c += left;
if u <= c {
return m - i;
}
} else {
left = 0.0;
}
right *= lambda / (m + i) as f64; c += right;
if u <= c {
return m + i;
}
i += 1;
}
}
let u_anchor = rng.next_f64();
let z = num::standard_normal_inv_cdf(u_anchor);
let mut k0 = (lambda + z * lambda.sqrt()).floor() as i64;
if k0 < 0 {
k0 = 0;
}
let log_p0 = (k0 as f64) * lambda.ln() - lambda - ln_factorial_u64(k0 as u64);
let p0 = log_p0.exp();
if !(p0 > 0.0 && p0.is_finite()) {
let log_p_m = (m as f64) * lambda.ln() - lambda - ln_factorial_u64(m as u64);
let p_m = log_p_m.exp();
let u = rng.next_f64();
let mut c = p_m;
if u <= c {
return m;
}
let mut left = p_m;
let mut right = p_m;
let mut i: i64 = 1;
loop {
if i <= m {
left *= (m - (i - 1)) as f64 / lambda;
c += left;
if u <= c {
return m - i;
}
} else {
left = 0.0;
}
right *= lambda / (m + i) as f64;
c += right;
if u <= c {
return m + i;
}
i += 1;
}
}
let u = rng.next_f64();
let mut c = p0;
if u <= c {
return k0;
}
let mut left = p0;
let mut right = p0;
let mut i: i64 = 1;
loop {
if i <= k0 {
left *= (k0 - (i - 1)) as f64 / lambda;
c += left;
if u <= c {
return k0 - i;
}
} else {
left = 0.0;
}
right *= lambda / (k0 + i) as f64;
c += right;
if u <= c {
return k0 + i;
}
i += 1;
}
}
}
impl Discrete for Poisson {
fn pmf(&self, x: Self::Value) -> f64 {
self.pmf_via_recurrence(x)
}
fn inv_cdf(&self, p: f64) -> Self::Value {
debug_assert!((0.0..=1.0).contains(&p));
if p <= 0.0 {
return 0;
}
if p >= 1.0 {
return i64::MAX;
}
let mut k: i64 = 0;
let mut pk = self.pmf_rec_start();
let mut acc = pk;
while acc < p {
k += 1;
pk *= self.lambda / (k as f64);
acc += pk;
}
k
}
}
impl Moments for Poisson {
fn mean(&self) -> f64 {
self.lambda
}
fn variance(&self) -> f64 {
self.lambda
}
fn skewness(&self) -> f64 { 1.0 / self.lambda.sqrt() }
fn kurtosis(&self) -> f64 { 1.0 / self.lambda }
fn entropy(&self) -> f64 {
let lambda = self.lambda;
let mean = lambda;
let std = lambda.sqrt();
let mut k_min = (mean - 10.0 * std).floor() as i64;
if k_min < 0 { k_min = 0; }
let k_max = (mean + 10.0 * std).ceil() as i64;
let mut h = 0.0;
for k in k_min..=k_max {
let pk = self.pmf_via_recurrence(k);
if pk > 0.0 { h -= pk * pk.ln(); }
}
h
}
}
#[inline]
fn ln_factorial_u64(n: u64) -> f64 {
const LN_FACT_SMALL: [f64; 21] = [
0.0,
0.0,
std::f64::consts::LN_2,
1.791759469228055,
3.1780538303479458,
4.787491742782046,
6.579251212010101,
8.525161361065415,
10.60460290274525,
12.80182748008147,
15.104412573075516,
17.502307845873887,
19.98721449566189,
22.552163853123425,
25.19122118273868,
27.899271383840894,
30.671860106080675,
33.50507345013689,
36.39544520803305,
39.339884187199495,
42.335616460753485,
];
if n <= 20 {
return LN_FACT_SMALL[n as usize];
}
let x = n as f64;
let inv = 1.0 / x;
let inv2 = inv * inv;
let inv3 = inv2 * inv;
let inv5 = inv3 * inv2;
x * x.ln() - x + 0.5 * ((2.0 * std::f64::consts::PI * x).ln()) + (inv / 12.0) - (inv3 / 360.0)
+ (inv5 / 1260.0)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn pmf_values() {
let p = Poisson::new(3.0).unwrap();
let e3 = (-3.0f64).exp();
assert!((p.pmf(0) - e3).abs() < 1e-15);
assert!((p.pmf(3) - 4.5 * e3).abs() < 1e-12);
}
#[test]
fn cdf_bounds() {
let p = Poisson::new(1.5).unwrap();
assert_eq!(p.cdf(-1), 0.0);
assert!(p.cdf(0) > 0.0);
assert!(p.cdf(10) < 1.0);
}
#[test]
fn inv_cdf_roundtrip() {
let pois = Poisson::new(2.5).unwrap();
let ps = [0.1, 0.3, 0.5, 0.8, 0.95];
for &prob in &ps {
let k = pois.inv_cdf(prob);
assert!(pois.cdf(k) >= prob - 1e-15);
if k > 0 {
assert!(pois.cdf(k - 1) <= prob + 1e-15);
}
}
}
#[test]
fn sampling_deterministic() {
let pois = Poisson::new(4.0).unwrap();
let mut r1 = crate::rng::SplitMix64::seed_from_u64(7);
let mut r2 = crate::rng::SplitMix64::seed_from_u64(7);
let x1 = pois.sample(&mut r1);
let x2 = pois.sample(&mut r2);
assert_eq!(x1, x2);
}
#[test]
fn moments_higher() {
let p = Poisson::new(4.0).unwrap();
assert!((p.skewness() - 0.5).abs() < 1e-15);
assert!((p.kurtosis() - 0.25).abs() < 1e-15);
assert!((p.kurtosis_full() - 3.25).abs() < 1e-15);
}
}