use super::family::GlmFamily;
/// Negative binomial (NB2) GLM family with a fixed dispersion parameter `theta`.
///
/// Under this family the variance is `mu + mu^2 / theta`: smaller `theta`
/// means more overdispersion, and very large `theta` behaves like Poisson.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct NegativeBinomialFamily {
    /// Dispersion ("size") parameter. [`NegativeBinomialFamily::new`] requires it
    /// to be strictly positive; constructing the struct directly does not check.
    pub theta: f64,
}

impl Default for NegativeBinomialFamily {
    /// Equivalent to `NegativeBinomialFamily::new(1.0)`.
    fn default() -> Self {
        NegativeBinomialFamily::new(1.0)
    }
}

impl NegativeBinomialFamily {
    /// Creates a family with dispersion `theta`.
    ///
    /// # Panics
    ///
    /// Panics with "theta must be positive" unless `theta > 0.0`
    /// (so NaN is rejected as well).
    pub fn new(theta: f64) -> Self {
        assert!(theta > 0.0, "theta must be positive");
        NegativeBinomialFamily { theta }
    }

    /// Family with `theta = 1e6`, whose variance is numerically very close to
    /// the Poisson variance `mu`.
    pub fn poisson_like() -> Self {
        const NEAR_POISSON_THETA: f64 = 1e6;
        NegativeBinomialFamily::new(NEAR_POISSON_THETA)
    }

    /// Returns a new family with dispersion `new_theta`.
    ///
    /// `self` is neither modified nor consulted; this is a convenience for
    /// call sites that already hold a family value.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`NegativeBinomialFamily::new`].
    pub fn with_theta(&self, new_theta: f64) -> Self {
        NegativeBinomialFamily::new(new_theta)
    }

    /// Ratio of the NB2 variance to the Poisson variance at mean `mu`:
    /// `(mu + mu^2 / theta) / mu = 1 + mu / theta`.
    pub fn overdispersion_ratio(&self, mu: f64) -> f64 {
        mu / self.theta + 1.0
    }
}
impl GlmFamily for NegativeBinomialFamily {
    /// NB2 variance function `V(mu) = mu + mu^2 / theta`.
    ///
    /// `mu` is floored at `1e-10` so the variance stays strictly positive.
    fn variance(&self, mu: f64) -> f64 {
        let mu_safe = mu.max(1e-10);
        mu_safe + mu_safe * mu_safe / self.theta
    }

    /// Log link `eta = ln(mu)`; `mu` is floored at `1e-10` to avoid `-inf`.
    fn link(&self, mu: f64) -> f64 {
        mu.max(1e-10).ln()
    }

    /// Inverse log link `mu = exp(eta)`.
    ///
    /// `eta > 30` saturates at `exp(30)` and `eta < -30` returns `1e-14`, so the
    /// result stays finite and strictly positive.
    fn link_inverse(&self, eta: f64) -> f64 {
        if eta > 30.0 {
            (30.0_f64).exp()
        } else if eta < -30.0 {
            1e-14
        } else {
            eta.exp().max(1e-14)
        }
    }

    /// Derivative of the log link, `d eta / d mu = 1 / mu` (`mu` floored at `1e-10`).
    fn link_derivative(&self, mu: f64) -> f64 {
        1.0 / mu.max(1e-10)
    }

    /// NB2 unit deviance
    /// `2 * [ y ln(y / mu) - (y + theta) ln((y + theta) / (mu + theta)) ]`.
    ///
    /// For `y` (effectively) zero the first term vanishes and the deviance
    /// reduces to `2 * theta * ln(1 + mu / theta)`, which is non-negative.
    fn unit_deviance(&self, y: f64, mu: f64) -> f64 {
        let mu_safe = mu.max(1e-10);
        let theta = self.theta;
        if y < 1e-10 {
            // -2θ·ln(θ/(μ+θ)) == 2θ·ln(1 + μ/θ); `ln_1p` keeps this accurate when
            // μ/θ is tiny (e.g. the near-Poisson family).
            2.0 * theta * (mu_safe / theta).ln_1p()
        } else {
            let term1 = y * (y / mu_safe).ln();
            let term2 = (y + theta) * ((y + theta) / (mu_safe + theta)).ln();
            2.0 * (term1 - term2)
        }
    }

    /// Starting means: each `mu_i` is the midpoint of `y_i` and the sample mean
    /// of `y` (mean floored at `1e-3`), itself floored at `1e-3` so zero counts
    /// still start from a strictly positive mean.
    fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        let y_mean = y.iter().sum::<f64>() / y.len() as f64;
        let y_mean = y_mean.max(1e-3);
        y.iter()
            .map(|&yi| {
                let mu = (yi + y_mean) / 2.0;
                mu.max(1e-3)
            })
            .collect()
    }
}

/// Moment-based estimate of the NB2 dispersion `theta`.
///
/// Uses the Pearson statistic `chi2 = sum((y - mu)^2 / mu)`. Under NB2,
/// `E[(y - mu)^2 / mu] = 1 + mu / theta`, so `theta ≈ mean(mu) / (chi2 / n - 1)`.
/// The excess dispersion is floored at `0.01` and the result at `0.1`; empty
/// input returns `0.1`.
///
/// `y` and `mu` are expected to have the same length (presumably `mu` holds
/// the fitted mean for each observation — confirm at call sites).
pub fn estimate_theta_moments(y: &[f64], mu: &[f64]) -> f64 {
    if y.is_empty() {
        // No data: fall back to the lower bound of the estimate.
        return 0.1;
    }
    let n = y.len() as f64;
    let chi2: f64 = y
        .iter()
        .zip(mu)
        .map(|(&yi, &mui)| {
            let mui = mui.max(1e-10);
            (yi - mui).powi(2) / mui
        })
        .sum();
    let mu_bar = mu.iter().sum::<f64>() / n;
    let overdispersion = (chi2 / n - 1.0).max(0.01);
    (mu_bar / overdispersion).max(0.1)
}

/// Maximum-likelihood estimate of the NB2 dispersion `theta`, holding the
/// means `mu` fixed.
///
/// Starts from [`estimate_theta_moments`] clamped to `[0.1, 1e6]` and takes up
/// to `max_iter` Newton steps on the score in `theta`, keeping each iterate in
/// `[0.01, 1e8]`. Stops once a step is smaller than `tol * max(theta, 1)`, or
/// when a step is not finite (e.g. NaN inputs), in which case the last finite
/// estimate is returned.
pub fn estimate_theta_ml(y: &[f64], mu: &[f64], max_iter: usize, tol: f64) -> f64 {
    let mut theta = estimate_theta_moments(y, mu).clamp(0.1, 1e6);
    for _ in 0..max_iter {
        let (score, info) = theta_score_and_info(y, mu, theta);
        let delta = score / info;
        if !delta.is_finite() {
            break;
        }
        let new_theta = (theta + delta).clamp(0.01, 1e8);
        let converged = (new_theta - theta).abs() < tol * theta.max(1.0);
        theta = new_theta;
        if converged {
            break;
        }
    }
    theta
}

/// Score and (absolute) observed information of the NB2 log-likelihood with
/// respect to `theta`, holding `mu` fixed.
///
/// Per observation the log-likelihood is
/// `lnΓ(y+θ) - lnΓ(θ) - lnΓ(y+1) + θ ln θ + y ln μ - (y+θ) ln(μ+θ)`, giving
///
/// * score: `ψ(y+θ) - ψ(θ) + ln(θ/(μ+θ)) + (μ - y)/(μ+θ)`
/// * observed information: `ψ'(θ) - ψ'(y+θ) - 1/θ + 2/(μ+θ) - (y+θ)/(μ+θ)²`
///
/// The summed information is returned as an absolute value floored at `1e-10`,
/// so the Newton step `score / info` never divides by zero and always points
/// along the score.
fn theta_score_and_info(y: &[f64], mu: &[f64], theta: f64) -> (f64, f64) {
    // ψ(θ) and ψ'(θ) are the same for every observation.
    let digamma_theta = digamma(theta);
    let trigamma_theta = trigamma(theta);
    let (score, info) = y
        .iter()
        .zip(mu)
        .fold((0.0_f64, 0.0_f64), |(score, info), (&yi, &mui)| {
            let mui = mui.max(1e-10);
            let mu_plus_theta = mui + theta;
            let y_plus_theta = yi + theta;
            // ln(θ/(μ+θ)) == -ln(1 + μ/θ), accurate for large θ.
            let score_i = digamma(y_plus_theta) - digamma_theta - (mui / theta).ln_1p()
                + (mui - yi) / mu_plus_theta;
            let info_i = trigamma_theta - trigamma(y_plus_theta) - 1.0 / theta
                + 2.0 / mu_plus_theta
                - y_plus_theta / (mu_plus_theta * mu_plus_theta);
            (score + score_i, info + info_i)
        });
    (score, info.abs().max(1e-10))
}

/// Digamma function `ψ(x) = d/dx ln Γ(x)` for `x > 0`; returns NaN otherwise.
///
/// Shifts the argument to at least 6 via `ψ(x) = ψ(x + 1) - 1/x`, then applies
/// the asymptotic series
/// `ln x - 1/(2x) - 1/(12x²) + 1/(120x⁴) - 1/(252x⁶) + 1/(240x⁸)`,
/// whose truncation error is about 1e-10 or less for `x ≥ 6`.
fn digamma(x: f64) -> f64 {
    // Also rejects NaN, and avoids a non-terminating shift loop for huge
    // negative arguments.
    if x.is_nan() || x <= 0.0 {
        return f64::NAN;
    }
    let mut x = x;
    let mut acc = 0.0;
    while x < 6.0 {
        acc -= 1.0 / x;
        x += 1.0;
    }
    let inv = 1.0 / x;
    let inv2 = inv * inv;
    let series =
        inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 * (1.0 / 252.0 - inv2 / 240.0)));
    acc + x.ln() - 0.5 * inv - series
}

/// Trigamma function `ψ'(x)` for `x > 0`; returns NaN otherwise.
///
/// Shifts the argument to at least 6 via `ψ'(x) = ψ'(x + 1) + 1/x²`, then
/// applies the asymptotic series
/// `1/x + 1/(2x²) + 1/(6x³) - 1/(30x⁵) + 1/(42x⁷) - 1/(30x⁹)`.
fn trigamma(x: f64) -> f64 {
    if x.is_nan() || x <= 0.0 {
        return f64::NAN;
    }
    let mut x = x;
    let mut acc = 0.0;
    while x < 6.0 {
        acc += 1.0 / (x * x);
        x += 1.0;
    }
    let inv = 1.0 / x;
    let inv2 = inv * inv;
    let series =
        inv * inv2 * (1.0 / 6.0 - inv2 * (1.0 / 30.0 - inv2 * (1.0 / 42.0 - inv2 / 30.0)));
    acc + inv + 0.5 * inv2 + series
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_variance() {
        let nb = NegativeBinomialFamily::new(2.0);
        assert!((nb.variance(4.0) - 12.0).abs() < 1e-10);
        assert!((nb.variance(1.0) - 1.5).abs() < 1e-10);
    }
    #[test]
    fn test_approaches_poisson() {
        let nb_high_theta = NegativeBinomialFamily::new(1e6);
        let mu = 5.0;
        let var = nb_high_theta.variance(mu);
        assert!((var - mu).abs() / mu < 0.01);
    }
    #[test]
    fn test_link_roundtrip() {
        let nb = NegativeBinomialFamily::new(1.0);
        for mu in [0.5, 1.0, 2.0, 5.0, 10.0] {
            let eta = nb.link(mu);
            let mu_back = nb.link_inverse(eta);
            assert!((mu - mu_back).abs() < 1e-8, "Failed for mu={}", mu);
        }
    }
    #[test]
    fn test_unit_deviance_perfect_fit() {
        let nb = NegativeBinomialFamily::new(2.0);
        let dev = nb.unit_deviance(5.0, 5.0);
        assert!(dev.abs() < 1e-10);
    }
    #[test]
    fn test_unit_deviance_zero() {
        let nb = NegativeBinomialFamily::new(2.0);
        let dev = nb.unit_deviance(0.0, 1.0);
        // 2 * theta * ln((mu + theta) / theta) with theta = 2, mu = 1.
        let expected = 2.0 * 2.0 * (3.0 / 2.0_f64).ln();
        assert!(dev > 0.0);
        assert!((dev - expected).abs() < 1e-6);
    }
    #[test]
    fn test_unit_deviance_non_negative() {
        for theta in [0.5, 2.0, 10.0] {
            let nb = NegativeBinomialFamily::new(theta);
            for y in [0.0, 1.0, 3.0, 10.0] {
                for mu in [0.5, 2.0, 7.0] {
                    let dev = nb.unit_deviance(y, mu);
                    assert!(
                        dev >= 0.0,
                        "negative deviance {} for y={}, mu={}, theta={}",
                        dev,
                        y,
                        mu,
                        theta
                    );
                }
            }
        }
    }
    #[test]
    fn test_unit_deviance_zero_poisson_limit() {
        // With huge theta the y = 0 deviance approaches the Poisson value 2 * mu.
        let nb = NegativeBinomialFamily::poisson_like();
        let dev = nb.unit_deviance(0.0, 5.0);
        assert!((dev - 10.0).abs() < 1e-3);
    }
    #[test]
    fn test_deviance() {
        let nb = NegativeBinomialFamily::new(2.0);
        let y = vec![1.0, 2.0, 3.0, 4.0];
        let mu = vec![1.0, 2.0, 3.0, 4.0];
        let dev = nb.deviance(&y, &mu);
        assert!(dev < 1e-8);
    }
    #[test]
    fn test_initialize_mu() {
        let nb = NegativeBinomialFamily::new(1.0);
        let y = vec![0.0, 1.0, 5.0, 10.0];
        let mu_init = nb.initialize_mu(&y);
        for &mu in &mu_init {
            assert!(mu > 0.0);
        }
    }
    #[test]
    fn test_overdispersion_ratio() {
        let nb = NegativeBinomialFamily::new(2.0);
        assert!((nb.overdispersion_ratio(4.0) - 3.0).abs() < 1e-10);
    }
    #[test]
    fn test_estimate_theta_moments() {
        let y = vec![0.0, 1.0, 0.0, 5.0, 2.0, 0.0, 8.0, 1.0, 0.0, 3.0];
        let mu = vec![2.0; 10];
        let theta = estimate_theta_moments(&y, &mu);
        assert!(theta > 0.0);
    }
    #[test]
    fn test_estimate_theta_moments_empty() {
        assert!((estimate_theta_moments(&[], &[]) - 0.1).abs() < 1e-12);
    }
    #[test]
    fn test_digamma_trigamma_known_values() {
        let euler_gamma = 0.577_215_664_901_532_9_f64;
        let pi = std::f64::consts::PI;
        assert!((digamma(1.0) + euler_gamma).abs() < 1e-9);
        assert!((digamma(0.5) + euler_gamma + 2.0 * 2.0_f64.ln()).abs() < 1e-9);
        assert!((trigamma(1.0) - pi * pi / 6.0).abs() < 1e-9);
        assert!((trigamma(0.5) - pi * pi / 2.0).abs() < 1e-9);
        assert!(digamma(0.0).is_nan());
        assert!(trigamma(-1.0).is_nan());
    }
    #[test]
    fn test_theta_score_and_info_integer_counts() {
        // For integer y: psi(y + t) - psi(t) = sum_{k<y} 1/(t + k) and
        // psi'(t) - psi'(y + t) = sum_{k<y} 1/(t + k)^2.
        let (y, mu, t) = (3.0_f64, 2.0_f64, 1.5_f64);
        let (score, info) = theta_score_and_info(&[y], &[mu], t);
        let harmonic: f64 = (0..3).map(|k| 1.0 / (t + k as f64)).sum();
        let harmonic_sq: f64 = (0..3).map(|k| 1.0 / (t + k as f64).powi(2)).sum();
        let expected_score = harmonic + (t / (mu + t)).ln() + (mu - y) / (mu + t);
        let expected_info =
            harmonic_sq - 1.0 / t + 2.0 / (mu + t) - (y + t) / (mu + t).powi(2);
        assert!((score - expected_score).abs() < 1e-8);
        assert!((info - expected_info.abs()).abs() < 1e-8);
    }
    #[test]
    fn test_estimate_theta_ml_reaches_stationary_point() {
        let y = vec![0.0, 1.0, 0.0, 5.0, 2.0, 0.0, 8.0, 1.0, 0.0, 3.0];
        let mu = vec![2.0; 10];
        let theta = estimate_theta_ml(&y, &mu, 100, 1e-10);
        // Overdispersed data: the MLE is interior (about 0.69), not a clamp bound.
        assert!(theta > 0.5 && theta < 0.9, "theta = {}", theta);
        let (score, _) = theta_score_and_info(&y, &mu, theta);
        assert!(score.abs() < 1e-6, "score = {}", score);
    }
    #[test]
    fn test_irls_weight() {
        let nb = NegativeBinomialFamily::new(2.0);
        let w = nb.irls_weight(2.0);
        assert!((w - 1.0).abs() < 1e-10);
    }
}