#![allow(
clippy::suboptimal_flops,
clippy::float_cmp,
clippy::neg_cmp_op_on_partial_ord
)]
macro_rules! impl_beta_binomial {
($name:ident, $builder:ident, $ty:ty) => {
#[derive(Debug, Clone)]
pub struct $name {
alpha: $ty,
beta: $ty,
prior_alpha: $ty,
prior_beta: $ty,
}
#[doc = stringify!($name)]
#[derive(Debug, Clone)]
pub struct $builder {
alpha: $ty,
beta: $ty,
}
impl $name {
#[inline]
#[must_use]
pub const fn new() -> Self {
Self {
alpha: 1.0 as $ty,
beta: 1.0 as $ty,
prior_alpha: 1.0 as $ty,
prior_beta: 1.0 as $ty,
}
}
#[inline]
pub fn with_prior(alpha: $ty, beta: $ty) -> Result<Self, crate::ConfigError> {
if !(alpha > 0.0 as $ty) {
return Err(crate::ConfigError::Invalid("alpha must be positive"));
}
if !(beta > 0.0 as $ty) {
return Err(crate::ConfigError::Invalid("beta must be positive"));
}
Ok(Self {
alpha,
beta,
prior_alpha: alpha,
prior_beta: beta,
})
}
#[inline]
#[must_use]
pub fn builder() -> $builder {
$builder {
alpha: 1.0 as $ty,
beta: 1.0 as $ty,
}
}
#[inline]
pub fn update(&mut self, success: bool) {
if success {
self.alpha += 1.0 as $ty;
} else {
self.beta += 1.0 as $ty;
}
}
#[inline]
pub fn update_batch(&mut self, successes: u64, failures: u64) {
self.alpha += successes as $ty;
self.beta += failures as $ty;
}
#[inline]
#[must_use]
pub fn mean(&self) -> $ty {
self.alpha / (self.alpha + self.beta)
}
#[inline]
#[must_use]
pub fn variance(&self) -> $ty {
let sum = self.alpha + self.beta;
(self.alpha * self.beta) / (sum * sum * (sum + 1.0 as $ty))
}
#[inline]
#[must_use]
pub fn mode(&self) -> Option<$ty> {
if self.alpha <= 1.0 as $ty || self.beta <= 1.0 as $ty {
Option::None
} else {
Option::Some((self.alpha - 1.0 as $ty) / (self.alpha + self.beta - 2.0 as $ty))
}
}
#[inline]
#[must_use]
#[cfg(any(feature = "std", feature = "libm"))]
pub fn credible_interval(&self, confidence: $ty) -> Option<($ty, $ty)> {
if self.total() == 0.0 as $ty {
return Option::None;
}
if !(confidence > 0.0 as $ty && confidence < 1.0 as $ty) {
return Option::None;
}
let tail = (1.0 as $ty - confidence) / 2.0 as $ty;
#[allow(clippy::cast_possible_truncation)]
let t = crate::math::sqrt((-2.0 as f64) * crate::math::ln(tail as f64)) as $ty;
let z = t
- (2.515517 as $ty + 0.802853 as $ty * t + 0.010328 as $ty * t * t)
/ (1.0 as $ty
+ 1.432788 as $ty * t
+ 0.189269 as $ty * t * t
+ 0.001308 as $ty * t * t * t);
let mean = self.mean();
#[allow(clippy::cast_possible_truncation)]
let std_dev = crate::math::sqrt(self.variance() as f64) as $ty;
let half_width = z * std_dev;
let lower = if mean - half_width < 0.0 as $ty {
0.0 as $ty
} else {
mean - half_width
};
let upper = if mean + half_width > 1.0 as $ty {
1.0 as $ty
} else {
mean + half_width
};
Option::Some((lower, upper))
}
#[inline]
#[must_use]
pub fn successes(&self) -> $ty {
self.alpha - self.prior_alpha
}
#[inline]
#[must_use]
pub fn failures(&self) -> $ty {
self.beta - self.prior_beta
}
#[inline]
#[must_use]
pub fn total(&self) -> $ty {
self.successes() + self.failures()
}
#[inline]
#[must_use]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
pub fn count(&self) -> u64 {
self.total() as u64
}
#[inline]
#[must_use]
pub fn is_primed(&self) -> bool {
self.total() > 0.0 as $ty
}
#[inline]
pub fn reset(&mut self) {
self.alpha = self.prior_alpha;
self.beta = self.prior_beta;
}
}
impl Default for $name {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl $builder {
#[inline]
#[must_use]
pub fn alpha(mut self, alpha: $ty) -> Self {
self.alpha = alpha;
self
}
#[inline]
#[must_use]
pub fn beta(mut self, beta: $ty) -> Self {
self.beta = beta;
self
}
#[inline]
pub fn build(self) -> Result<$name, crate::ConfigError> {
if !(self.alpha > 0.0 as $ty) {
return Err(crate::ConfigError::Invalid("alpha must be positive"));
}
if !(self.beta > 0.0 as $ty) {
return Err(crate::ConfigError::Invalid("beta must be positive"));
}
Ok($name {
alpha: self.alpha,
beta: self.beta,
prior_alpha: self.alpha,
prior_beta: self.beta,
})
}
}
};
}
impl_beta_binomial!(BetaBinomialF64, BetaBinomialF64Builder, f64);
impl_beta_binomial!(BetaBinomialF32, BetaBinomialF32Builder, f32);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn uniform_prior_balanced_observations() {
let mut bb = BetaBinomialF64::new();
for _ in 0..50 {
bb.update(true);
bb.update(false);
}
let mean = bb.mean();
assert!((mean - 0.5).abs() < 0.01, "expected ~0.5, got {mean}");
}
#[test]
fn informative_prior() {
let bb = BetaBinomialF64::with_prior(10.0, 1.0).unwrap();
let mean = bb.mean();
assert!(
(mean - 10.0 / 11.0).abs() < 1e-10,
"expected ~0.909, got {mean}"
);
}
#[test]
fn variance_decreases_with_observations() {
let mut bb = BetaBinomialF64::new();
let initial_var = bb.variance();
for _ in 0..100 {
bb.update(true);
}
let final_var = bb.variance();
assert!(
final_var < initial_var,
"variance should decrease: {initial_var} → {final_var}"
);
}
#[test]
fn mode_none_for_uniform_prior() {
let bb = BetaBinomialF64::new();
assert!(bb.mode().is_none(), "mode undefined for Beta(1, 1)");
}
#[test]
fn mode_some_for_informative_prior() {
let bb = BetaBinomialF64::with_prior(10.0, 10.0).unwrap();
let mode = bb.mode().unwrap();
assert!((mode - 0.5).abs() < 1e-10, "expected 0.5, got {mode}");
}
#[cfg(any(feature = "std", feature = "libm"))]
#[test]
fn credible_interval_narrows() {
let mut bb = BetaBinomialF64::with_prior(2.0, 2.0).unwrap();
for _ in 0..10 {
bb.update(true);
bb.update(false);
}
let (lo1, hi1) = bb.credible_interval(0.95).unwrap();
let width1 = hi1 - lo1;
for _ in 0..200 {
bb.update(true);
bb.update(false);
}
let (lo2, hi2) = bb.credible_interval(0.95).unwrap();
let width2 = hi2 - lo2;
assert!(
width2 < width1,
"interval should narrow: {width1:.4} → {width2:.4}"
);
}
#[test]
fn observe_batch_equivalence() {
let mut single = BetaBinomialF64::new();
for _ in 0..30 {
single.update(true);
}
for _ in 0..20 {
single.update(false);
}
let mut batch = BetaBinomialF64::new();
batch.update_batch(30, 20);
assert!(
(single.mean() - batch.mean()).abs() < 1e-10,
"single={} batch={}",
single.mean(),
batch.mean()
);
assert_eq!(single.count(), batch.count());
}
#[test]
fn reset_restores_prior() {
let mut bb = BetaBinomialF64::with_prior(5.0, 3.0).unwrap();
let mean_before = bb.mean();
for _ in 0..100 {
bb.update(true);
}
bb.reset();
assert!(
(bb.mean() - mean_before).abs() < 1e-10,
"reset should restore prior mean"
);
assert_eq!(bb.count(), 0);
assert!(!bb.is_primed());
}
#[test]
fn f32_variant() {
let mut bb = BetaBinomialF32::new();
bb.update(true);
bb.update(false);
let mean = bb.mean();
assert!((mean - 0.5).abs() < 1e-5, "expected ~0.5, got {mean}");
}
#[test]
fn default_is_new() {
let a = BetaBinomialF64::new();
let b = BetaBinomialF64::default();
assert!(
(a.mean() - b.mean()).abs() < 1e-10,
"default and new should be identical"
);
}
#[test]
fn with_prior_validation() {
assert!(BetaBinomialF64::with_prior(0.0, 1.0).is_err());
assert!(BetaBinomialF64::with_prior(1.0, 0.0).is_err());
assert!(BetaBinomialF64::with_prior(-1.0, 1.0).is_err());
assert!(BetaBinomialF64::with_prior(1.0, -1.0).is_err());
assert!(BetaBinomialF64::with_prior(0.5, 0.5).is_ok());
}
}