use crate::{Distribution, StandardUniform};
use core::fmt;
use num_traits::Float;
use rand::{Rng, RngExt};
/// The Zipf distribution over `{1, …, n}` with exponent `s`.
///
/// Sampling is done by rejection from a continuous envelope; the fields hold
/// the envelope constants so that [`Zipf::new`] computes them only once.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Zipf<F: Float>
where
    StandardUniform: Distribution<F>,
{
    /// Exponent of the weights `k^-s` (validated to be `>= 0` by `new`).
    s: F,
    /// Total mass of the envelope: `(n^(1-s) - s) / (1 - s)`, or `1 + ln(n)`
    /// when `s == 1`, or `1` when `s` is infinite.
    t: F,
    /// Cached `1 / (1 - s)`; stored as zero and never read when `s == 1`.
    q: F,
}
/// Reasons why [`Zipf::new`] can reject its parameters.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `s < 0`, or `s` is NaN.
    STooSmall,
    /// `n < 1`, or `n` is NaN.
    NTooSmall,
    /// `n` is infinite while `s <= 1`: the weights `k^-s` then have an
    /// infinite sum, so no distribution exists.
    IllDefined,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match *self {
            Self::STooSmall => "s < 0 or is NaN in Zipf distribution",
            Self::NTooSmall => "n < 1 or is NaN in Zipf distribution",
            Self::IllDefined => "n = inf and s <= 1 in Zipf distribution",
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> Zipf<F>
where
    F: Float,
    StandardUniform: Distribution<F>,
{
    /// Construct a new `Zipf` distribution for a set with `n` elements and
    /// exponent `s`.
    ///
    /// `n` is usually integral, but it is taken as `F` so that very large
    /// sizes (including `n = ∞` when `s > 1`) can be expressed.
    ///
    /// # Errors
    ///
    /// * [`Error::STooSmall`] if `s < 0` or `s` is NaN (checked first);
    /// * [`Error::NTooSmall`] if `n < 1` or `n` is NaN;
    /// * [`Error::IllDefined`] if `n` is infinite and `s <= 1`.
    #[inline]
    pub fn new(n: F, s: F) -> Result<Zipf<F>, Error> {
        let zero = F::zero();
        let one = F::one();

        // Negated `>=` comparisons so that NaN fails validation as well.
        if !(s >= zero) {
            return Err(Error::STooSmall);
        }
        if !(n >= one) {
            return Err(Error::NTooSmall);
        }
        if n.is_infinite() && s <= one {
            return Err(Error::IllDefined);
        }

        let s_is_one = s == one;
        // `1 / (1 - s)`, computed once here; it is never read when `s == 1`.
        let q = if s_is_one { zero } else { one / (one - s) };
        // Total mass of the envelope: 1 on `[0, 1]` plus the integral of
        // `x^-s` over `[1, n]` (which vanishes when `s` is infinite).
        let t = if s == F::infinity() {
            one
        } else if s_is_one {
            one + n.ln()
        } else {
            (n.powf(one - s) - s) * q
        };
        debug_assert!(t > zero);
        Ok(Zipf { s, t, q })
    }

    /// Inverse of the envelope's cumulative mass function.
    ///
    /// `p` is expected in `[0, 1)` (the sampler passes a `StandardUniform`
    /// draw). The mass `p * t` is inverted directly on the flat part
    /// `[0, 1]` of the envelope, and analytically on its `x^-s` tail.
    #[inline]
    fn inv_cdf(&self, p: F) -> F {
        let one = F::one();
        let mass = p * self.t;
        if mass <= one {
            // Flat part of the envelope: mass and position coincide.
            return mass;
        }
        if self.s == one {
            // Solves `1 + ln(x) = mass`.
            (mass - one).exp()
        } else {
            // Solves `1 + (x^(1-s) - 1) / (1 - s) = mass`.
            (mass * (one - self.s) + self.s).powf(self.q)
        }
    }
}
impl<F> Distribution<F> for Zipf<F>
where
    F: Float,
    StandardUniform: Distribution<F>,
{
    /// Draws one integer-valued sample (`>= 1`) by rejection sampling.
    ///
    /// A candidate `b` is taken from the continuous envelope via `inv_cdf`
    /// and mapped to `x = floor(b) + 1`. It is accepted with probability
    /// `x^-s / b^-s`, which is exactly `1` when `x == 1`. Two uniform draws
    /// are consumed per attempt: the candidate first, then the acceptance
    /// test.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let one = F::one();
        loop {
            let inv_b = self.inv_cdf(rng.sample(StandardUniform));
            // Floor *before* adding one: `floor(b) + 1` is exact. Writing
            // `(b + 1).floor()` instead lets `b + 1` round up to the next
            // integer when `b` sits just below one and the sum lands in a
            // coarser binade. For example, with `n == 1` `inv_cdf` returns `p`
            // itself, and `p = 1 - 2^-53` (f64) gives `p + 1 == 2.0` under
            // round-half-to-even, i.e. a sample of 2 from a one-element set.
            // NOTE(review): `inv_cdf` may itself still round to a value >= n
            // for n > 1; clamping would require storing `n` — TODO confirm.
            let x = inv_b.floor() + one;
            let mut ratio = x.powf(-self.s);
            if x > one {
                // Envelope height at `b` is `b^-s` on the tail.
                ratio = ratio * inv_b.powf(self.s);
            }
            let y = rng.sample(StandardUniform);
            if y < ratio {
                return x;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four samples from `distr` with a fixed seed and compares them to
    /// `expected`, guarding against unintended changes to the algorithm.
    fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
        let mut rng = crate::test::rng(213);
        let mut buf = [zero; 4];
        for x in &mut buf {
            *x = rng.sample(&distr);
        }
        assert_eq!(buf, expected);
    }

    // Parameter validation asserts the exact error variant. `#[should_panic]`
    // would also pass on an unrelated panic or on the wrong variant.

    #[test]
    fn zipf_s_too_small() {
        assert_eq!(Zipf::new(10., -1.), Err(Error::STooSmall));
    }

    #[test]
    fn zipf_n_too_small() {
        assert_eq!(Zipf::new(0., 1.), Err(Error::NTooSmall));
    }

    #[test]
    fn zipf_nan() {
        assert_eq!(Zipf::new(10., f64::NAN), Err(Error::STooSmall));
        assert_eq!(Zipf::new(f64::NAN, 1.), Err(Error::NTooSmall));
        // `s` is validated before `n`.
        assert_eq!(Zipf::new(f64::NAN, f64::NAN), Err(Error::STooSmall));
    }

    #[test]
    fn zipf_infinite_n() {
        assert_eq!(Zipf::new(f64::INFINITY, 0.), Err(Error::IllDefined));
        assert_eq!(Zipf::new(f64::INFINITY, 0.5), Err(Error::IllDefined));
        assert_eq!(Zipf::new(f64::INFINITY, 1.), Err(Error::IllDefined));
        assert!(Zipf::new(f64::INFINITY, 1.5).is_ok());
    }

    #[test]
    fn zipf_sample() {
        let d = Zipf::new(10., 0.5).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
    }

    #[test]
    fn zipf_sample_s_1() {
        let d = Zipf::new(10., 1.).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
    }

    #[test]
    fn zipf_sample_s_0() {
        let d = Zipf::new(10., 0.).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
    }

    #[test]
    fn zipf_sample_s_inf() {
        let d = Zipf::new(10., f64::infinity()).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r == 1.);
        }
    }

    #[test]
    fn zipf_sample_n_1() {
        // With a single element every sample must be 1, whatever the exponent.
        let mut rng = crate::test::rng(2);
        for &s in &[0., 0.5, 1., 2.] {
            let d = Zipf::new(1., s).unwrap();
            for _ in 0..100 {
                assert_eq!(d.sample(&mut rng), 1.);
            }
        }
    }

    #[test]
    fn zipf_sample_large_n() {
        let d = Zipf::new(f64::MAX, 1.5).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
    }

    #[test]
    fn zipf_sample_infinite_n() {
        let d = Zipf::new(f64::INFINITY, 1.5).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
    }

    #[test]
    fn zipf_value_stability() {
        test_samples(Zipf::new(10., 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]);
        test_samples(Zipf::new(10., 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]);
    }

    #[test]
    fn zipf_distributions_can_be_compared() {
        assert_eq!(Zipf::new(1.0, 2.0), Zipf::new(1.0, 2.0));
        assert_ne!(Zipf::new(1.0, 2.0), Zipf::new(1.0, 1.5));
    }
}