use crate::{AprenderError, Result};
/// State of a Beta-Binomial conjugate model.
///
/// Holds the two shape parameters of a Beta distribution over a success
/// probability. The fields are private; read them through the accessor
/// methods on this type.
#[derive(Clone, Debug)]
pub struct BetaBinomial {
    /// First shape parameter of the Beta distribution.
    alpha: f32,
    /// Second shape parameter of the Beta distribution.
    beta: f32,
}
impl BetaBinomial {
    /// Creates a model with the uniform prior `Beta(1, 1)`.
    #[must_use]
    pub fn uniform() -> Self {
        Self {
            alpha: 1.0,
            beta: 1.0,
        }
    }

    /// Creates a model with the Jeffreys prior `Beta(0.5, 0.5)`.
    #[must_use]
    pub fn jeffreys() -> Self {
        Self {
            alpha: 0.5,
            beta: 0.5,
        }
    }

    /// Creates a model with a custom `Beta(alpha, beta)` prior.
    ///
    /// # Errors
    ///
    /// Returns [`AprenderError::InvalidHyperparameter`] unless both
    /// parameters are finite and strictly positive. NaN and infinite values
    /// are rejected because they would make every derived statistic NaN.
    pub fn new(alpha: f32, beta: f32) -> Result<Self> {
        // `is_finite` filters NaN and +/-inf, which a plain `<= 0.0` test
        // lets through (every comparison with NaN is false).
        let is_valid = |shape: f32| shape.is_finite() && shape > 0.0;
        if !(is_valid(alpha) && is_valid(beta)) {
            return Err(AprenderError::InvalidHyperparameter {
                param: "alpha, beta".to_string(),
                value: format!("({alpha}, {beta})"),
                constraint: "both > 0 and finite".to_string(),
            });
        }
        Ok(Self { alpha, beta })
    }

    /// Returns the current `alpha` shape parameter.
    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Returns the current `beta` shape parameter.
    #[must_use]
    pub fn beta(&self) -> f32 {
        self.beta
    }

    /// Folds observed data into the model.
    ///
    /// Adds `successes` to `alpha` and `trials - successes` to `beta`.
    ///
    /// # Panics
    ///
    /// Panics if `successes > trials`.
    pub fn update(&mut self, successes: u32, trials: u32) {
        assert!(successes <= trials, "Successes cannot exceed total trials");
        let failures = trials - successes;
        // Counts above 2^24 are not exactly representable in `f32`; the
        // rounding is accepted here.
        #[allow(clippy::cast_precision_loss)]
        {
            self.alpha += successes as f32;
            self.beta += failures as f32;
        }
    }

    /// Returns the posterior mean `alpha / (alpha + beta)`.
    #[must_use]
    pub fn posterior_mean(&self) -> f32 {
        self.alpha / (self.alpha + self.beta)
    }

    /// Returns the posterior mode `(alpha - 1) / (alpha + beta - 2)`.
    ///
    /// Yields `None` unless both `alpha > 1` and `beta > 1`.
    #[must_use]
    pub fn posterior_mode(&self) -> Option<f32> {
        (self.alpha > 1.0 && self.beta > 1.0)
            .then(|| (self.alpha - 1.0) / (self.alpha + self.beta - 2.0))
    }

    /// Returns the posterior variance
    /// `alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))`.
    #[must_use]
    pub fn posterior_variance(&self) -> f32 {
        let total = self.alpha + self.beta;
        (self.alpha * self.beta) / (total * total * (total + 1.0))
    }

    /// Returns the posterior predictive probability of a success on the next
    /// trial, which is the posterior mean.
    #[must_use]
    pub fn posterior_predictive(&self) -> f32 {
        self.posterior_mean()
    }

    /// Returns an approximate central credible interval `(lower, upper)`.
    ///
    /// The interval is `mean +/- z * std`, a normal approximation to the
    /// Beta posterior, clamped to `[0, 1]`. `z` is the two-sided standard
    /// normal critical value for `confidence`.
    ///
    /// # Errors
    ///
    /// Returns [`AprenderError::InvalidHyperparameter`] unless `confidence`
    /// lies strictly between 0 and 1 (NaN is rejected as well).
    pub fn credible_interval(&self, confidence: f32) -> Result<(f32, f32)> {
        // Both endpoints are excluded, matching the "(0, 1)" constraint text.
        if confidence.is_nan() || confidence <= 0.0 || confidence >= 1.0 {
            return Err(AprenderError::InvalidHyperparameter {
                param: "confidence".to_string(),
                value: confidence.to_string(),
                constraint: "in (0, 1)".to_string(),
            });
        }
        let mean = self.posterior_mean();
        let std = self.posterior_variance().sqrt();
        let z = Self::z_for_confidence(confidence);
        let lower = (mean - z * std).max(0.0);
        let upper = (mean + z * std).min(1.0);
        Ok((lower, upper))
    }

    /// Two-sided standard normal critical value for a level in `(0, 1)`.
    ///
    /// The conventional 95%, 99% and 90% levels return the usual textbook
    /// constants so results for those levels stay exactly as before. Any
    /// other level is computed with the rational approximation of
    /// Abramowitz & Stegun 26.2.23 (absolute error below `4.5e-4`) instead of
    /// silently falling back to the 95% value.
    fn z_for_confidence(confidence: f32) -> f32 {
        const TABLE: [(f32, f32); 3] = [(0.95, 1.96), (0.99, 2.576), (0.90, 1.645)];
        // Only absorbs floating-point noise around the tabulated levels.
        const TABLE_TOLERANCE: f32 = 1.0e-3;

        if let Some(&(_, z)) = TABLE
            .iter()
            .find(|&&(level, _)| (confidence - level).abs() < TABLE_TOLERANCE)
        {
            return z;
        }

        // Probability mass in one tail; lies in (0, 0.5) for valid input.
        let tail = 0.5 * (1.0 - confidence);
        let t = (-2.0 * tail.ln()).sqrt();
        let numerator = 2.515_517 + t * (0.802_853 + t * 0.010_328);
        let denominator = 1.0 + t * (1.432_788 + t * (0.189_269 + t * 0.001_308));
        // Near confidence = 0 the approximation error can push the result a
        // hair below zero; clamp so the interval never inverts.
        (t - numerator / denominator).max(0.0)
    }
}
/// Parameter pair for the Gamma-Poisson model.
///
/// NOTE(review): the methods for this type are defined outside this part of
/// the file; presumably `alpha` and `beta` are the Gamma shape and rate
/// parameters — confirm against the impl.
#[derive(Clone, Debug)]
pub struct GammaPoisson {
    /// First Gamma parameter.
    alpha: f32,
    /// Second Gamma parameter.
    beta: f32,
}
// The following source files are pasted into this module at compile time by
// `include!`, so their items share this module's scope (including the `use`
// lines at the top of this file).
include!("normal_inverse_gamma.rs");
include!("normal_inverse_gamma_methods.rs");
include!("dirichlet_multinomial.rs");
// Test-only child module: compiled only under `cfg(test)`, with its source
// taken from `tests_conjugate_contract.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "tests_conjugate_contract.rs"]
mod tests_conjugate_contract;