use crate::{Distribution, Exp1, Normal, StandardNormal, StandardUniform};
use core::fmt;
use num_traits::{Float, FloatConst};
use rand::{Rng, RngExt};
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Poisson<F>(Method<F>)
where
F: Float + FloatConst,
StandardUniform: Distribution<F>;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
ShapeTooSmall,
NonFinite,
ShapeTooLarge,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::ShapeTooSmall => "lambda is not positive in Poisson distribution",
Error::NonFinite => "lambda is infinite or nan in Poisson distribution",
Error::ShapeTooLarge => {
"lambda is too large in Poisson distribution, see Poisson::MAX_LAMBDA"
}
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for Error {}
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct KnuthMethod<F> {
exp_lambda: F,
}
impl<F: Float> KnuthMethod<F> {
pub(crate) fn new(lambda: F) -> Self {
KnuthMethod {
exp_lambda: (-lambda).exp(),
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct RejectionMethod<F> {
lambda: F,
s: F,
d: F,
l: F,
c: F,
c0: F,
c1: F,
c2: F,
c3: F,
omega: F,
}
impl<F: Float + FloatConst> RejectionMethod<F> {
pub(crate) fn new(lambda: F) -> Self {
let b1 = F::from(1.0 / 24.0).unwrap() / lambda;
let b2 = F::from(0.3).unwrap() * b1 * b1;
let c3 = F::from(1.0 / 7.0).unwrap() * b1 * b2;
let c2 = b2 - F::from(15).unwrap() * c3;
let c1 = b1 - F::from(6).unwrap() * b2 + F::from(45).unwrap() * c3;
let c0 = F::one() - b1 + F::from(3).unwrap() * b2 - F::from(15).unwrap() * c3;
RejectionMethod {
lambda,
s: lambda.sqrt(),
d: F::from(6.0).unwrap() * lambda.powi(2),
l: (lambda - F::from(1.1484).unwrap()).floor(),
c: F::from(0.1069).unwrap() / lambda,
c0,
c1,
c2,
c3,
omega: F::one() / (F::from(2).unwrap() * F::PI()).sqrt() / lambda.sqrt(),
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum Method<F> {
Knuth(KnuthMethod<F>),
Rejection(RejectionMethod<F>),
}
impl<F> Poisson<F>
where
F: Float + FloatConst,
StandardUniform: Distribution<F>,
{
pub fn new(lambda: F) -> Result<Poisson<F>, Error> {
if !lambda.is_finite() {
return Err(Error::NonFinite);
}
if !(lambda > F::zero()) {
return Err(Error::ShapeTooSmall);
}
let method = if lambda < F::from(12.0).unwrap() {
Method::Knuth(KnuthMethod::new(lambda))
} else {
if lambda > F::from(Self::MAX_LAMBDA).unwrap() {
return Err(Error::ShapeTooLarge);
}
Method::Rejection(RejectionMethod::new(lambda))
};
Ok(Poisson(method))
}
pub const MAX_LAMBDA: f64 = 1.844e19;
}
impl<F> Distribution<F> for KnuthMethod<F>
where
F: Float + FloatConst,
StandardUniform: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let mut result = F::one();
let mut p = rng.random::<F>();
while p > self.exp_lambda {
p = p * rng.random::<F>();
result = result + F::one();
}
result - F::one()
}
}
impl<F> Distribution<F> for RejectionMethod<F>
where
F: Float + FloatConst,
StandardUniform: Distribution<F>,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let f = |k: F| {
const FACT: [f64; 10] = [
1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0,
]; const A: [f64; 10] = [
-0.5000000002,
0.3333333343,
-0.2499998565,
0.1999997049,
-0.1666848753,
0.1428833286,
-0.1241963125,
0.1101687109,
-0.1142650302,
0.1055093006,
]; let (px, py) = if k < F::from(10.0).unwrap() {
let px = -self.lambda;
let py = self.lambda.powf(k) / F::from(FACT[k.to_usize().unwrap()]).unwrap();
(px, py)
} else {
let delta = (F::from(12.0).unwrap() * k).recip();
let delta = delta - F::from(4.8).unwrap() * delta.powi(3);
let v = (self.lambda - k) / k;
let px = if v.abs() <= F::from(0.25).unwrap() {
k * v.powi(2)
* A.iter()
.rev()
.fold(F::zero(), |acc, &a| {
acc * v + F::from(a).unwrap()
}) - delta
} else {
k * (F::one() + v).ln() - (self.lambda - k) - delta
};
let py = F::one() / (F::from(2.0).unwrap() * F::PI()).sqrt() / k.sqrt();
(px, py)
};
let x = (k - self.lambda + F::from(0.5).unwrap()) / self.s;
let fx = -F::from(0.5).unwrap() * x * x;
let fy =
self.omega * (((self.c3 * x * x + self.c2) * x * x + self.c1) * x * x + self.c0);
(px, py, fx, fy)
};
let normal = Normal::new(self.lambda, self.s).unwrap();
let g = normal.sample(rng);
if g >= F::zero() {
let k1 = g.floor();
if k1 >= self.l {
return k1;
}
let u: F = rng.random();
if self.d * u >= (self.lambda - k1).powi(3) {
return k1;
}
let (px, py, fx, fy) = f(k1);
if fy * (F::one() - u) <= py * (px - fx).exp() {
return k1;
}
}
loop {
let e = Exp1.sample(rng);
let u: F = rng.random() * F::from(2.0).unwrap() - F::one();
let t = F::from(1.8).unwrap() + e * u.signum();
if t > F::from(-0.6744).unwrap() {
let k2 = (self.lambda + self.s * t).floor();
let (px, py, fx, fy) = f(k2);
if self.c * u.abs() <= py * (px + e).exp() - fy * (fx + e).exp() {
return k2;
}
}
}
}
}
impl<F> Distribution<F> for Poisson<F>
where
F: Float + FloatConst,
StandardUniform: Distribution<F>,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
match &self.0 {
Method::Knuth(method) => method.sample(rng),
Method::Rejection(method) => method.sample(rng),
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
#[should_panic]
fn test_poisson_invalid_lambda_zero() {
Poisson::new(0.0).unwrap();
}
#[test]
#[should_panic]
fn test_poisson_invalid_lambda_infinity() {
Poisson::new(f64::INFINITY).unwrap();
}
#[test]
#[should_panic]
fn test_poisson_invalid_lambda_neg() {
Poisson::new(-10.0).unwrap();
}
#[test]
fn poisson_distributions_can_be_compared() {
assert_eq!(Poisson::new(1.0), Poisson::new(1.0));
}
}