use crate::{Distribution, StandardUniform};
use core::fmt;
use num_traits::Float;
use rand::{Rng, RngExt, distr::OpenClosed01};
/// The Zeta distribution with exponent `s > 1`.
///
/// Construct it with [`Zeta::new`]. Samples are whole-number floats `>= 1`
/// (or `+inf` when the sampler's proposal overflows).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Zeta<F: Float>
where
    StandardUniform: Distribution<F>,
    OpenClosed01: Distribution<F>,
{
    /// `s - 1`; both the proposal exponent and the acceptance test use it.
    s_minus_1: F,
    /// `2^(s - 1)`, precomputed once so the acceptance test need not recompute it.
    b: F,
}
/// Error type returned when constructing a [`Zeta`] distribution fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The exponent `s` was not strictly greater than 1 (this includes NaN).
    STooSmall,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::STooSmall => "s <= 1 or is NaN in Zeta distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> Zeta<F>
where
F: Float,
StandardUniform: Distribution<F>,
OpenClosed01: Distribution<F>,
{
#[inline]
pub fn new(s: F) -> Result<Zeta<F>, Error> {
if !(s > F::one()) {
return Err(Error::STooSmall);
}
let s_minus_1 = s - F::one();
let two = F::one() + F::one();
Ok(Zeta {
s_minus_1,
b: two.powf(s_minus_1),
})
}
}
impl<F> Distribution<F> for Zeta<F>
where
    F: Float,
    StandardUniform: Distribution<F>,
    OpenClosed01: Distribution<F>,
{
    /// Draws one value by rejection sampling.
    ///
    /// Each round proposes `floor(u^(-1/(s-1)))` with `u` drawn from `(0, 1]`.
    /// The proposal is accepted or rejected using a second draw `v` from `[0, 1)`,
    /// and rejected proposals are redrawn. The result is a whole number `>= 1`,
    /// or `+inf` if the proposal overflowed.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let one = F::one();
        // Both values are the same on every iteration, so compute them once.
        let exponent = -one / self.s_minus_1;
        let b_minus_1 = self.b - one;
        loop {
            let u: F = rng.sample(OpenClosed01);
            // `u` is in (0, 1] and the exponent is negative, so the power is >= 1.
            let candidate = u.powf(exponent).floor();
            debug_assert!(candidate >= one);
            // When `s - 1` is tiny the proposal overflows for almost every `u`.
            // Rejecting infinity would then loop (practically) forever, so it is
            // returned as-is.
            if candidate.is_infinite() {
                return candidate;
            }
            let t = (one + one / candidate).powf(self.s_minus_1);
            let v: F = rng.sample(StandardUniform);
            let accepted = v * candidate * (t - one) * self.b <= t * b_minus_1;
            if accepted {
                return candidate;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four samples from `distr` with a fixed-seed RNG and checks them
    /// against `expected` (guards against accidental value changes).
    fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
        let mut rng = crate::test::rng(213);
        let mut buf = [zero; 4];
        for x in &mut buf {
            *x = rng.sample(&distr);
        }
        assert_eq!(buf, expected);
    }

    #[test]
    #[should_panic]
    fn zeta_invalid() {
        Zeta::new(1.).unwrap();
    }

    #[test]
    #[should_panic]
    fn zeta_nan() {
        Zeta::new(f64::NAN).unwrap();
    }

    #[test]
    fn zeta_invalid_reports_s_too_small() {
        // Pin the specific error variant, not just "it panics on unwrap".
        assert_eq!(Zeta::new(1.0), Err(Error::STooSmall));
        assert_eq!(Zeta::new(0.5), Err(Error::STooSmall));
        assert_eq!(Zeta::new(f64::NAN), Err(Error::STooSmall));
        assert_eq!(Zeta::new(-3.0f32), Err(Error::STooSmall));
    }

    #[test]
    fn zeta_sample() {
        let s = 2.0;
        let d = Zeta::new(s).unwrap();
        let mut rng = crate::test::rng(1);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
    }

    #[test]
    fn zeta_small_a() {
        let s = 1. + 1e-15;
        let d = Zeta::new(s).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert!(r >= 1.);
        }
    }

    #[test]
    fn zeta_infinite_s_always_yields_one() {
        // With s = +inf the proposal exponent is -0, so every proposal is 1 and is accepted.
        let d = Zeta::new(f64::INFINITY).unwrap();
        let mut rng = crate::test::rng(3);
        for _ in 0..100 {
            assert_eq!(d.sample(&mut rng), 1.0);
        }
    }

    #[test]
    fn zeta_value_stability() {
        test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]);
        test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn zeta_distributions_can_be_compared() {
        // Use valid parameters so `Zeta`'s `PartialEq` is actually exercised;
        // `Zeta::new(1.0)` is an `Err`, which would only compare error values.
        assert!(Zeta::new(2.0).is_ok());
        assert_eq!(Zeta::new(2.0), Zeta::new(2.0));
        assert_ne!(Zeta::new(2.0), Zeta::new(3.0));
    }
}