#![allow(
clippy::suboptimal_flops,
clippy::float_cmp,
clippy::neg_cmp_op_on_partial_ord
)]
macro_rules! impl_gamma_poisson {
    ($name:ident, $builder:ident, $ty:ty) => {
        #[doc = concat!("Gamma-Poisson conjugate estimator of a Poisson event rate (`", stringify!($ty), "` precision).")]
        ///
        /// The posterior over the rate is `Gamma(alpha, beta)` in shape/rate form. Each
        /// [`update`](Self::update) adds the event count to `alpha` and the exposure to
        /// `beta`, so [`rate`](Self::rate) is the posterior mean `alpha / beta` and
        /// [`variance`](Self::variance) is `alpha / beta^2`.
        ///
        /// # Examples
        ///
        /// ```
        #[doc = concat!("use nexus_stats::estimation::", stringify!($name), ";")]
        ///
        #[doc = concat!("let mut gp = ", stringify!($name), "::new();")]
        /// gp.update(100, 10.0).unwrap();
        /// let rate = gp.rate();
        #[doc = concat!("assert!((rate - 9.18 as ", stringify!($ty), ").abs() < 0.01 as ", stringify!($ty), ");")]
        /// ```
        #[derive(Debug, Clone)]
        pub struct $name {
            /// Posterior shape: prior shape plus all observed counts.
            alpha: $ty,
            /// Posterior rate: prior rate plus all observed exposure.
            beta: $ty,
            /// Prior shape, restored by `reset`.
            prior_alpha: $ty,
            /// Prior rate, restored by `reset`.
            prior_beta: $ty,
            /// Exact number of events observed since construction or the last `reset`.
            observed_count: u64,
            /// Sum of exposure observed since construction or the last `reset`.
            observed_exposure: $ty,
        }

        #[doc = concat!("Builder for [`", stringify!($name), "`] that configures the Gamma prior.")]
        ///
        /// Both `alpha` (shape) and `beta` (rate) default to `1.0`.
        #[derive(Debug, Clone)]
        pub struct $builder {
            alpha: $ty,
            beta: $ty,
        }

        impl $name {
            /// Returns a builder preset to the default `Gamma(1, 1)` prior.
            #[inline]
            #[must_use]
            pub fn builder() -> $builder {
                $builder {
                    alpha: 1.0 as $ty,
                    beta: 1.0 as $ty,
                }
            }

            /// Creates an estimator with the default `Gamma(1, 1)` prior and no observations.
            #[inline]
            #[must_use]
            pub fn new() -> Self {
                Self::from_prior(1.0 as $ty, 1.0 as $ty)
            }

            /// Creates an estimator with a `Gamma(alpha, beta)` prior.
            ///
            /// Equivalent to `Self::builder().alpha(alpha).beta(beta).build()`.
            ///
            /// # Errors
            ///
            /// Returns `ConfigError::Invalid` if `alpha` or `beta` is not strictly
            /// positive (NaN included) or is infinite.
            #[inline]
            pub fn with_prior(alpha: $ty, beta: $ty) -> Result<Self, crate::ConfigError> {
                Self::builder().alpha(alpha).beta(beta).build()
            }

            /// Builds an estimator from an already-validated prior, with no observations.
            #[inline]
            fn from_prior(alpha: $ty, beta: $ty) -> Self {
                Self {
                    alpha,
                    beta,
                    prior_alpha: alpha,
                    prior_beta: beta,
                    observed_count: 0,
                    observed_exposure: 0.0 as $ty,
                }
            }

            /// Records `count` events observed over `exposure` units of exposure.
            ///
            /// # Errors
            ///
            /// Returns a `DataError` (e.g. `NotANumber` or `Infinite`) if `exposure` is
            /// not finite. Nothing is recorded in that case.
            ///
            /// NOTE(review): negative exposure is not rejected here and would shrink
            /// `beta`. Confirm callers never pass it.
            #[inline]
            pub fn update(&mut self, count: u64, exposure: $ty) -> Result<(), crate::DataError> {
                check_finite!(exposure);
                self.alpha += count as $ty;
                self.beta += exposure;
                // Keep the raw observations separately. Recovering them later as
                // `alpha - prior_alpha` loses precision (badly for f32 and large counts),
                // so a truncating `as u64` could report the wrong count.
                self.observed_count = self.observed_count.saturating_add(count);
                self.observed_exposure += exposure;
                Ok(())
            }

            /// Posterior mean of the rate: `alpha / beta`.
            #[inline]
            #[must_use]
            pub fn rate(&self) -> $ty {
                self.alpha / self.beta
            }

            /// Posterior variance of the rate: `alpha / beta^2`.
            #[inline]
            #[must_use]
            pub fn variance(&self) -> $ty {
                self.alpha / (self.beta * self.beta)
            }

            /// Approximate two-sided credible interval for the rate at level `confidence`.
            ///
            /// Uses a normal approximation to the posterior: centre `rate()`, standard
            /// deviation `sqrt(variance())`. The normal quantile comes from the rational
            /// approximation in Abramowitz & Stegun 26.2.23. The lower bound is clamped
            /// at zero.
            ///
            /// Returns `None` if no exposure has been observed, or if `confidence` is not
            /// strictly inside `(0, 1)` (NaN included).
            #[cfg(any(feature = "std", feature = "libm"))]
            #[inline]
            #[must_use]
            pub fn credible_interval(&self, confidence: $ty) -> Option<($ty, $ty)> {
                if self.total_exposure() <= 0.0 as $ty {
                    return Option::None;
                }
                if !(confidence > 0.0 as $ty && confidence < 1.0 as $ty) {
                    return Option::None;
                }
                // Upper-tail probability. It always lies in (0, 0.5) given the check above,
                // which is the domain of the A&S 26.2.23 approximation.
                let tail = (1.0 as $ty - confidence) / 2.0 as $ty;
                // The math helpers work in f64. Narrowing back to `$ty` (f32) is intended.
                #[allow(clippy::cast_possible_truncation)]
                let t = crate::math::sqrt(-2.0 * crate::math::ln(tail as f64)) as $ty;
                // z = t - (c0 + c1 t + c2 t^2) / (1 + d1 t + d2 t^2 + d3 t^3)
                let z = t
                    - (2.515517 as $ty + 0.802853 as $ty * t + 0.010328 as $ty * t * t)
                        / (1.0 as $ty
                            + 1.432788 as $ty * t
                            + 0.189269 as $ty * t * t
                            + 0.001308 as $ty * t * t * t);
                // The math helpers work in f64. Narrowing back to `$ty` (f32) is intended.
                #[allow(clippy::cast_possible_truncation)]
                let std_dev = crate::math::sqrt(self.variance() as f64) as $ty;
                let mean = self.rate();
                let lower = (mean - z * std_dev).max(0.0 as $ty);
                Option::Some((lower, mean + z * std_dev))
            }

            /// Total events observed since construction or the last `reset`, as a float.
            #[inline]
            #[must_use]
            pub fn total_count(&self) -> $ty {
                self.observed_count as $ty
            }

            /// Total exposure observed since construction or the last `reset`.
            #[inline]
            #[must_use]
            pub fn total_exposure(&self) -> $ty {
                self.observed_exposure
            }

            /// Exact number of events observed since construction or the last `reset`.
            #[inline]
            #[must_use]
            pub fn count(&self) -> u64 {
                self.observed_count
            }

            /// `true` once a strictly positive total exposure has been observed.
            #[inline]
            #[must_use]
            pub fn is_primed(&self) -> bool {
                self.total_exposure() > 0.0 as $ty
            }

            /// Discards all observations and restores the prior.
            #[inline]
            pub fn reset(&mut self) {
                self.alpha = self.prior_alpha;
                self.beta = self.prior_beta;
                self.observed_count = 0;
                self.observed_exposure = 0.0 as $ty;
            }
        }

        impl Default for $name {
            /// Same as `new()`: a `Gamma(1, 1)` prior with no observations.
            #[inline]
            fn default() -> Self {
                Self::new()
            }
        }

        impl $builder {
            /// Sets the prior shape `alpha`. It must be finite and `> 0`.
            #[inline]
            #[must_use]
            pub fn alpha(mut self, alpha: $ty) -> Self {
                self.alpha = alpha;
                self
            }

            /// Sets the prior rate `beta`. It must be finite and `> 0`.
            #[inline]
            #[must_use]
            pub fn beta(mut self, beta: $ty) -> Self {
                self.beta = beta;
                self
            }

            /// Validates the prior and builds the estimator.
            ///
            /// # Errors
            ///
            /// Returns `ConfigError::Invalid` if `alpha` or `beta` is not strictly
            /// positive (NaN included) or is infinite. Positivity is checked for both
            /// parameters before finiteness.
            #[inline]
            pub fn build(self) -> Result<$name, crate::ConfigError> {
                // `!(x > 0)` (rather than `x <= 0`) so NaN is rejected as well.
                if !(self.alpha > 0.0 as $ty) {
                    return Err(crate::ConfigError::Invalid("alpha must be > 0"));
                }
                if !(self.beta > 0.0 as $ty) {
                    return Err(crate::ConfigError::Invalid("beta must be > 0"));
                }
                // An infinite prior makes the posterior degenerate (rate stuck at inf or 0).
                if !self.alpha.is_finite() {
                    return Err(crate::ConfigError::Invalid("alpha must be finite"));
                }
                if !self.beta.is_finite() {
                    return Err(crate::ConfigError::Invalid("beta must be finite"));
                }
                Ok($name::from_prior(self.alpha, self.beta))
            }
        }
    };
}
// Concrete estimators for both float widths. They share a single implementation
// through `impl_gamma_poisson!` and differ only in their arithmetic type.
impl_gamma_poisson!(GammaPoissonF64, GammaPoissonF64Builder, f64);
impl_gamma_poisson!(GammaPoissonF32, GammaPoissonF32Builder, f32);
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rate_after_observation() {
        let mut gp = GammaPoissonF64::new();
        gp.update(100, 10.0).unwrap();
        let rate = gp.rate();
        assert!((rate - 101.0 / 11.0).abs() < 1e-10);
    }

    #[test]
    fn variance_decreases_with_exposure() {
        let mut gp = GammaPoissonF64::new();
        gp.update(10, 1.0).unwrap();
        let v1 = gp.variance();
        gp.update(100, 10.0).unwrap();
        let v2 = gp.variance();
        assert!(v2 < v1, "variance should decrease with more exposure");
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn credible_interval_narrows_with_data() {
        let mut gp = GammaPoissonF64::new();
        gp.update(10, 1.0).unwrap();
        let (lo1, hi1) = gp.credible_interval(0.95).unwrap();
        let width1 = hi1 - lo1;
        gp.update(1000, 100.0).unwrap();
        let (lo2, hi2) = gp.credible_interval(0.95).unwrap();
        let width2 = hi2 - lo2;
        assert!(width2 < width1, "interval should narrow with more data");
        let rate = gp.rate();
        assert!(rate >= lo2 && rate <= hi2);
    }

    #[test]
    fn reset_restores_prior() {
        let mut gp = GammaPoissonF64::with_prior(2.0, 3.0).unwrap();
        gp.update(50, 5.0).unwrap();
        assert!(gp.count() > 0);
        gp.reset();
        assert_eq!(gp.count(), 0);
        assert_eq!(gp.total_exposure(), 0.0);
        assert_eq!(gp.rate(), 2.0 / 3.0);
    }

    #[test]
    fn with_prior_validation() {
        assert!(GammaPoissonF64::with_prior(0.0, 1.0).is_err());
        assert!(GammaPoissonF64::with_prior(-1.0, 1.0).is_err());
        assert!(GammaPoissonF64::with_prior(1.0, 0.0).is_err());
        assert!(GammaPoissonF64::with_prior(1.0, -1.0).is_err());
        assert!(GammaPoissonF64::with_prior(f64::NAN, 1.0).is_err());
        assert!(GammaPoissonF64::with_prior(1.0, f64::NAN).is_err());
        assert!(GammaPoissonF64::with_prior(1.0, 1.0).is_ok());
    }

    #[test]
    fn with_prior_matches_builder() {
        let a = GammaPoissonF64::with_prior(5.0, 2.0).unwrap();
        let b = GammaPoissonF64::builder().alpha(5.0).beta(2.0).build().unwrap();
        assert_eq!(a.rate(), b.rate());
        assert_eq!(a.variance(), b.variance());
    }

    #[test]
    fn f32_variant() {
        let mut gp = GammaPoissonF32::new();
        gp.update(50, 5.0).unwrap();
        let rate = gp.rate();
        assert!((rate - 8.5_f32).abs() < 0.01);
    }

    #[test]
    fn default_is_new() {
        let a = GammaPoissonF64::new();
        let b = GammaPoissonF64::default();
        assert_eq!(a.rate(), b.rate());
        assert_eq!(a.variance(), b.variance());
    }

    #[test]
    fn batch_observation_accumulates() {
        let mut gp = GammaPoissonF64::new();
        gp.update(10, 1.0).unwrap();
        gp.update(20, 2.0).unwrap();
        gp.update(30, 3.0).unwrap();
        assert_eq!(gp.count(), 60);
        assert!((gp.total_exposure() - 6.0).abs() < 1e-10);
        assert!((gp.rate() - 61.0 / 7.0).abs() < 1e-10);
    }

    #[test]
    fn is_primed_tracks_exposure() {
        let mut gp = GammaPoissonF64::new();
        assert!(!gp.is_primed());
        gp.update(3, 2.0).unwrap();
        assert!(gp.is_primed());
        gp.reset();
        assert!(!gp.is_primed());
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn credible_interval_none_without_exposure() {
        let gp = GammaPoissonF64::new();
        assert!(gp.credible_interval(0.95).is_none());
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn credible_interval_none_for_invalid_confidence() {
        let mut gp = GammaPoissonF64::new();
        gp.update(10, 1.0).unwrap();
        assert!(gp.credible_interval(0.0).is_none());
        assert!(gp.credible_interval(1.0).is_none());
        assert!(gp.credible_interval(-0.5).is_none());
        assert!(gp.credible_interval(1.5).is_none());
        assert!(gp.credible_interval(f64::NAN).is_none());
    }

    #[cfg(any(feature = "std", feature = "libm"))]
    #[test]
    fn credible_interval_f32_contains_rate() {
        let mut gp = GammaPoissonF32::new();
        gp.update(10, 1.0).unwrap();
        let (lo, hi) = gp.credible_interval(0.95).unwrap();
        let rate = gp.rate();
        assert!(lo >= 0.0);
        assert!(lo <= rate && rate <= hi);
    }

    #[test]
    fn builder_defaults() {
        let gp = GammaPoissonF64::builder().build().unwrap();
        assert_eq!(gp.rate(), 1.0);
    }

    #[test]
    fn builder_custom_prior() {
        let gp = GammaPoissonF64::builder()
            .alpha(5.0)
            .beta(2.0)
            .build()
            .unwrap();
        assert_eq!(gp.rate(), 2.5);
    }

    #[test]
    fn builder_validation() {
        assert!(GammaPoissonF64::builder().alpha(0.0).build().is_err());
        assert!(GammaPoissonF64::builder().beta(-1.0).build().is_err());
        assert!(GammaPoissonF64::builder().beta(0.0).build().is_err());
        assert!(GammaPoissonF64::builder().alpha(f64::NAN).build().is_err());
        assert!(GammaPoissonF64::builder().beta(f64::NAN).build().is_err());
        assert!(GammaPoissonF64::builder().alpha(3.0).beta(0.5).build().is_ok());
    }

    #[test]
    fn rejects_nan_and_inf_exposure() {
        let mut gp = GammaPoissonF64::new();
        assert_eq!(
            gp.update(10, f64::NAN),
            Err(crate::DataError::NotANumber)
        );
        assert_eq!(
            gp.update(10, f64::INFINITY),
            Err(crate::DataError::Infinite)
        );
        assert_eq!(gp.count(), 0);
    }
}