use crate::{Distribution, Open01};
use core::fmt;
use num_traits::Float;
use rand::{Rng, RngExt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// The rejection-sampling scheme used by a `Beta`, together with the
/// constants precomputed for it.
///
/// The variant names and formulas match algorithms BB and BC of R. C. H. Cheng
/// (1978), "Generating beta variates with nonintegral shape parameters"
/// (NOTE(review): reference inferred from the names and formulas — confirm).
/// `Beta::new` picks `BB` when the smaller shape parameter exceeds 1, and
/// `BC` otherwise.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
enum BetaAlgorithm<N> {
    BB(BB<N>),
    BC(BC<N>),
}
/// Constants for the BB scheme, computed once in `Beta::new` from the
/// reordered shape parameters `a <= b` (both greater than 1).
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct BB<N> {
    /// `a + b`.
    alpha: N,
    /// `sqrt((alpha - 2) / (2 * a * b - alpha))`.
    beta: N,
    /// `a + 1 / beta`.
    gamma: N,
}
/// Constants for the BC scheme, computed once in `Beta::new`.
///
/// Here `a` is the larger (or equal) and `b` the smaller shape parameter,
/// with `b <= 1`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct BC<N> {
    /// `a + b`.
    alpha: N,
    /// `1 / b`.
    beta: N,
    /// Quick-rejection threshold for draws with `u1 < 0.5`:
    /// `delta * (1/72 + (3/72) * b) / (a * beta - 14/18)`, where `delta = 1 + a - b`.
    kappa1: N,
    /// Quick-rejection threshold for draws with `u1 >= 0.5`:
    /// `1/4 + (1/2 + 1 / (4 * delta)) * b`.
    kappa2: N,
}
/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution)
/// `Beta(alpha, beta)` on the unit interval, for shape parameters
/// `alpha > 0` and `beta > 0`.
///
/// Construct it with [`Beta::new`] and draw values through the
/// `Distribution` implementation.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Beta<F>
where
    F: Float,
    Open01: Distribution<F>,
{
    /// First reordered shape parameter: the smaller (or equal) one for the BB
    /// scheme, the larger (or equal) one for the BC scheme.
    a: F,
    /// The other reordered shape parameter.
    b: F,
    /// `false` exactly when `a` holds the caller's `alpha`; used to map the
    /// accepted candidate back to the requested distribution.
    switched_params: bool,
    /// The chosen sampling scheme and its precomputed constants.
    algorithm: BetaAlgorithm<F>,
}
/// Error returned by [`Beta::new`] when a shape parameter is invalid.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Error {
    /// `alpha` is not strictly positive (NaN is rejected too).
    AlphaTooSmall,
    /// `beta` is not strictly positive (NaN is rejected too).
    BetaTooSmall,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::AlphaTooSmall => "alpha is not positive in beta distribution",
Error::BetaTooSmall => "beta is not positive in beta distribution",
})
}
}
// `std::error::Error` is only available with the standard library, so this
// impl is gated on the crate's `std` feature. The existing `Debug` derive and
// `Display` impl for `Error` provide everything the trait requires.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> Beta<F>
where
    F: Float,
    Open01: Distribution<F>,
{
    /// Constructs the `Beta(alpha, beta)` distribution.
    ///
    /// The shape parameters are reordered internally, and the constants for
    /// one of two rejection-sampling schemes are precomputed: `BB` when both
    /// parameters exceed 1, `BC` otherwise.
    ///
    /// # Errors
    ///
    /// * [`Error::AlphaTooSmall`] if `alpha` is not strictly positive (NaN included).
    /// * [`Error::BetaTooSmall`] if `beta` is not strictly positive (NaN included).
    ///
    /// `alpha` is validated first, so it is the one reported when both are invalid.
    pub fn new(alpha: F, beta: F) -> Result<Beta<F>, Error> {
        // Negated `>` comparisons so that NaN is rejected as well.
        if !(alpha > F::zero()) {
            return Err(Error::AlphaTooSmall);
        }
        if !(beta > F::zero()) {
            return Err(Error::BetaTooSmall);
        }

        let one = F::one();

        // Sort the parameters so that `lo <= hi`. `swapped` records whether
        // `lo` came from `beta`; ties count as swapped. Both values are known
        // to be non-NaN here, so `<=` is a total comparison.
        let swapped = beta <= alpha;
        let (lo, hi) = if swapped { (beta, alpha) } else { (alpha, beta) };

        if lo > one {
            // BB scheme: operates on a = lo, b = hi.
            let two = F::from(2.).unwrap();
            let sum = lo + hi;
            let spread = ((sum - two) / (two * lo * hi - sum)).sqrt();
            let offset = lo + one / spread;
            return Ok(Beta {
                a: lo,
                b: hi,
                switched_params: swapped,
                algorithm: BetaAlgorithm::BB(BB {
                    alpha: sum,
                    beta: spread,
                    gamma: offset,
                }),
            });
        }

        // BC scheme: operates on a = hi, b = lo, so the swap flag flips.
        // The arithmetic keeps the original evaluation order so results are
        // bit-for-bit identical.
        let quarter = F::from(0.25).unwrap();
        let sum = hi + lo;
        let inv_lo = one / lo;
        let delta = one + hi - lo;
        let kappa1 = delta
            * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * lo)
            / (hi * inv_lo - F::from(14. / 18.).unwrap());
        let kappa2 = quarter + (F::from(0.5).unwrap() + quarter / delta) * lo;

        Ok(Beta {
            a: hi,
            b: lo,
            switched_params: !swapped,
            algorithm: BetaAlgorithm::BC(BC {
                alpha: sum,
                beta: inv_lo,
                kappa1,
                kappa2,
            }),
        })
    }
}
impl<F> Distribution<F> for Beta<F>
where
    F: Float,
    Open01: Distribution<F>,
{
    /// Draws one Beta variate by rejection sampling, using whichever of the
    /// BB / BC schemes `Beta::new` selected.
    ///
    /// Each loop iteration draws two `Open01` values from `rng`. The number of
    /// iterations, and hence of draws, is random.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        // Accepted candidate; assigned on every path that leaves the loops below.
        let mut w;
        match self.algorithm {
            BetaAlgorithm::BB(algo) => {
                loop {
                    // `Open01` presumably excludes 0 and 1, keeping
                    // `u1 / (1 - u1)` finite and positive.
                    let u1 = rng.sample(Open01);
                    let u2 = rng.sample(Open01);
                    // v = beta * logit(u1); candidate w = a * e^v.
                    let v = algo.beta * (u1 / (F::one() - u1)).ln();
                    w = self.a * v.exp();
                    let z = u1 * u1 * u2;
                    let r = algo.gamma * v - F::from(4.).unwrap().ln();
                    let s = self.a + r - w;
                    // Cheap acceptance test that needs no logarithm of `z`.
                    if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
                        break;
                    }
                    // Second acceptance test, against ln(z).
                    let t = z.ln();
                    if s >= t {
                        break;
                    }
                    // Full acceptance test. Written as a negated `<`, so a NaN
                    // left-hand side also accepts. Otherwise reject and redraw.
                    if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
                        break;
                    }
                }
            }
            BetaAlgorithm::BC(algo) => {
                loop {
                    // Assigned in exactly one of the two branches below.
                    let z;
                    let u1 = rng.sample(Open01);
                    let u2 = rng.sample(Open01);
                    if u1 < F::from(0.5).unwrap() {
                        // Lower half: quick rejection against `kappa1`.
                        let y = u1 * u2;
                        z = u1 * y;
                        if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
                            continue;
                        }
                    } else {
                        // Upper half: z = u1^2 * u2.
                        z = u1 * u1 * u2;
                        // Immediate acceptance: build the candidate and stop.
                        if z <= F::from(0.25).unwrap() {
                            let v = algo.beta * (u1 / (F::one() - u1)).ln();
                            w = self.a * v.exp();
                            break;
                        }
                        // Quick rejection against `kappa2`.
                        if z >= algo.kappa2 {
                            continue;
                        }
                    }
                    // Neither quick test decided. Build the candidate and apply
                    // the full test (negated `<`, so a NaN left-hand side accepts).
                    let v = algo.beta * (u1 / (F::one() - u1)).ln();
                    w = self.a * v.exp();
                    if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
                        - F::from(4.).unwrap().ln()
                        < z.ln())
                    {
                        break;
                    };
                }
            }
        };
        // Map the accepted candidate to the output. `switched_params` is false
        // exactly when `self.a` holds the caller's `alpha` (see `Beta::new`).
        if !self.switched_params {
            // `w / (b + w)` with `w = +inf` would be inf / inf = NaN; for finite
            // `b` the limiting value is 1.
            if w == F::infinity() {
                return F::one();
            }
            w / (self.b + w)
        } else {
            self.b / (self.b + w)
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Mean of `n` samples from `dist`, drawn with the crate's seeded test RNG.
    fn sample_mean(dist: &Beta<f64>, seed: u64, n: usize) -> f64 {
        let mut rng = crate::test::rng(seed);
        let total: f64 = (0..n).map(|_| dist.sample(&mut rng)).sum();
        total / n as f64
    }

    #[test]
    fn test_beta() {
        let beta = Beta::new(1.0, 2.0).unwrap();
        let mut rng = crate::test::rng(201);
        for _ in 0..1000 {
            let x: f64 = beta.sample(&mut rng);
            assert!((0.0..=1.0).contains(&x), "sample {x} outside [0, 1]");
        }
    }

    #[test]
    #[should_panic]
    fn test_beta_invalid_dof() {
        Beta::new(0., 0.).unwrap();
    }

    #[test]
    fn test_beta_invalid_params_report_which_parameter() {
        // `alpha` is validated first, so it wins when both are invalid.
        assert_eq!(Beta::<f64>::new(0.0, 0.0), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::<f64>::new(-1.0, 1.0), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::<f64>::new(f64::NAN, 1.0), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::<f64>::new(1.0, 0.0), Err(Error::BetaTooSmall));
        assert_eq!(Beta::<f64>::new(1.0, f64::NAN), Err(Error::BetaTooSmall));
    }

    #[test]
    fn test_beta_selects_algorithm() {
        // Both shape parameters above 1 -> BB; otherwise -> BC.
        assert!(matches!(
            Beta::<f64>::new(2.0, 5.0).unwrap().algorithm,
            BetaAlgorithm::BB(_)
        ));
        assert!(matches!(
            Beta::<f64>::new(5.0, 2.0).unwrap().algorithm,
            BetaAlgorithm::BB(_)
        ));
        assert!(matches!(
            Beta::<f64>::new(1.0, 2.0).unwrap().algorithm,
            BetaAlgorithm::BC(_)
        ));
        assert!(matches!(
            Beta::<f64>::new(0.5, 3.0).unwrap().algorithm,
            BetaAlgorithm::BC(_)
        ));
    }

    #[test]
    fn test_beta_sample_means() {
        // The mean of Beta(a, b) is a / (a + b). These cases cover both schemes,
        // each with and without internal parameter switching. With 10_000
        // samples the standard error is < 0.003, well under the 0.02 tolerance.
        let cases: [(f64, f64, u64); 5] = [
            (1.0, 2.0, 210),
            (2.0, 1.0, 211),
            (2.0, 5.0, 212),
            (5.0, 2.0, 213),
            (3.0, 3.0, 214),
        ];
        for &(a, b, seed) in &cases {
            let dist = Beta::new(a, b).unwrap();
            let mean = sample_mean(&dist, seed, 10_000);
            let expected = a / (a + b);
            assert!(
                (mean - expected).abs() < 0.02,
                "Beta({a}, {b}): sample mean {mean}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_beta_small_param() {
        let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
        let mut rng = crate::test::rng(206);
        for i in 0..1000 {
            let x = beta.sample(&mut rng);
            assert!(!x.is_nan(), "failed at i={i}");
            assert!((0.0..=1.0).contains(&x), "sample {x} outside [0, 1] at i={i}");
        }
    }

    #[test]
    fn beta_distributions_can_be_compared() {
        assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
    }
}