use super::family::GlmFamily;
use super::link::BinomialLink;
/// Binomial family for a generalized linear model, parameterised by its link function.
///
/// Valid means lie strictly inside the open interval `(0, 1)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BinomialFamily {
    /// Link function relating the mean `mu` to the linear predictor `eta`.
    pub link: BinomialLink,
}
impl Default for BinomialFamily {
fn default() -> Self {
Self::logistic()
}
}
impl BinomialFamily {
pub fn new(link: BinomialLink) -> Self {
Self { link }
}
pub fn logistic() -> Self {
Self {
link: BinomialLink::Logit,
}
}
pub fn probit() -> Self {
Self {
link: BinomialLink::Probit,
}
}
pub fn cloglog() -> Self {
Self {
link: BinomialLink::Cloglog,
}
}
pub fn is_canonical_link(&self) -> bool {
self.link == BinomialLink::Logit
}
}
impl GlmFamily for BinomialFamily {
    /// Binomial variance function `V(mu) = mu * (1 - mu)`.
    ///
    /// `mu` is first clamped with [`GlmFamily::clamp_mu`], so the variance stays strictly
    /// positive even at (or beyond) the boundaries 0 and 1.
    fn variance(&self, mu: f64) -> f64 {
        let mu = self.clamp_mu(mu);
        mu * (1.0 - mu)
    }

    /// Maps the mean `mu` to the linear predictor `eta` using the configured link.
    fn link(&self, mu: f64) -> f64 {
        self.link.link(mu)
    }

    /// Maps the linear predictor `eta` back to the mean `mu` using the configured link.
    fn link_inverse(&self, eta: f64) -> f64 {
        self.link.link_inverse(eta)
    }

    /// Derivative of the link function with respect to `mu`, delegated to the configured link.
    fn link_derivative(&self, mu: f64) -> f64 {
        self.link.link_derivative(mu)
    }

    /// A mean is valid only if it is finite and strictly inside `(0, 1)`.
    fn valid_mu(&self, mu: f64) -> bool {
        mu.is_finite() && mu > 0.0 && mu < 1.0
    }

    /// Clamps `mu` into `[1e-10, 1 - 1e-10]` so that logs and divisions by `mu` or
    /// `1 - mu` stay finite. A NaN input is returned unchanged (NaN).
    fn clamp_mu(&self, mu: f64) -> f64 {
        mu.clamp(1e-10, 1.0 - 1e-10)
    }

    /// Unit deviance `2 * [y ln(y / mu) + (1 - y) ln((1 - y) / (1 - mu))]` for one observation.
    ///
    /// A term whose `y` (or `1 - y`) is within `1e-10` of zero is taken as 0, following the
    /// `0 * ln 0 = 0` convention. `mu` is clamped with [`GlmFamily::clamp_mu`]. Small negative
    /// results caused by rounding are clipped to 0.
    ///
    /// Returns NaN if `y` or `mu` is NaN.
    fn unit_deviance(&self, y: f64, mu: f64) -> f64 {
        // `f64::max` below returns its non-NaN operand. Without this guard a NaN input would be
        // reported as deviance 0 (a "perfect fit") and could hide a diverged fit from callers.
        if y.is_nan() || mu.is_nan() {
            return f64::NAN;
        }
        const Y_EPS: f64 = 1e-10;
        let mu = self.clamp_mu(mu);
        let term1 = if y > Y_EPS { y * (y / mu).ln() } else { 0.0 };
        let term2 = if y < 1.0 - Y_EPS {
            (1.0 - y) * ((1.0 - y) / (1.0 - mu)).ln()
        } else {
            0.0
        };
        (2.0 * (term1 + term2)).max(0.0)
    }

    /// Starting means `(y + 0.5) / 2`, i.e. each response moved halfway towards 0.5.
    ///
    /// For `y` in `[0, 1]` this lies in `[0.25, 0.75]` and is returned unchanged. The value
    /// is also passed through [`GlmFamily::clamp_mu`], so an out-of-range response still
    /// yields a usable starting mean.
    fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        y.iter().map(|&yi| self.clamp_mu((yi + 0.5) / 2.0)).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `actual` lies strictly within `tol` of `expected`.
    fn approx_eq(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_logistic_family() {
        let family = BinomialFamily::logistic();
        assert!(family.is_canonical_link());
        assert_eq!(family.link, BinomialLink::Logit);
    }

    #[test]
    fn test_probit_family() {
        let family = BinomialFamily::probit();
        assert!(!family.is_canonical_link());
        assert_eq!(family.link, BinomialLink::Probit);
    }

    #[test]
    fn test_cloglog_family() {
        let family = BinomialFamily::cloglog();
        assert!(!family.is_canonical_link());
        assert_eq!(family.link, BinomialLink::Cloglog);
    }

    #[test]
    fn test_variance() {
        let family = BinomialFamily::logistic();
        // (mu, expected mu * (1 - mu)); the variance is symmetric about 0.5.
        for (mu, expected) in [(0.5, 0.25), (0.2, 0.16), (0.8, 0.16)] {
            assert!(approx_eq(family.variance(mu), expected, 1e-10));
        }
    }

    #[test]
    fn test_link_roundtrip() {
        let families = [
            BinomialFamily::logistic(),
            BinomialFamily::probit(),
            BinomialFamily::cloglog(),
        ];
        for family in &families {
            for mu in [0.1, 0.3, 0.5, 0.7, 0.9] {
                let mu_back = family.link_inverse(family.link(mu));
                assert!(
                    approx_eq(mu, mu_back, 1e-6),
                    "Roundtrip failed for {:?} at mu={}",
                    family.link,
                    mu
                );
            }
        }
    }

    #[test]
    fn test_unit_deviance() {
        let family = BinomialFamily::logistic();
        // A perfect fit contributes no deviance.
        assert!(approx_eq(family.unit_deviance(0.5, 0.5), 0.0, 1e-10));
        // (y = 1, mu = 0.9) and (y = 0, mu = 0.1) mirror each other: both give 2 ln(1 / 0.9).
        let expected = 2.0 * (1.0 / 0.9_f64).ln();
        assert!(approx_eq(family.unit_deviance(1.0, 0.9), expected, 1e-6));
        assert!(approx_eq(family.unit_deviance(0.0, 0.1), expected, 1e-6));
    }

    #[test]
    fn test_deviance() {
        let family = BinomialFamily::logistic();
        let y = vec![0.0, 0.0, 1.0, 1.0];
        let mu = vec![0.0001, 0.0001, 0.9999, 0.9999];
        // Fitted values nearly equal the responses, so the total deviance is tiny.
        assert!(family.deviance(&y, &mu) < 0.01);
    }

    #[test]
    fn test_initialize_mu() {
        let family = BinomialFamily::logistic();
        let y = vec![0.0, 1.0, 0.0, 1.0];
        for mu in family.initialize_mu(&y) {
            assert!(mu > 0.0 && mu < 1.0);
            assert!(mu > 0.2 && mu < 0.8);
        }
    }

    #[test]
    fn test_irls_weight_logistic() {
        let family = BinomialFamily::logistic();
        assert!(approx_eq(family.irls_weight(0.5), 0.25, 1e-10));
    }

    #[test]
    fn test_working_response() {
        let family = BinomialFamily::logistic();
        let (y, mu) = (1.0, 0.5);
        let eta = family.link(mu);
        let z = family.working_response(y, mu, eta);
        let expected = eta + (y - mu) * family.link_derivative(mu);
        assert!(approx_eq(z, expected, 1e-10));
    }

    #[test]
    fn test_null_deviance() {
        let family = BinomialFamily::logistic();
        let y = vec![0.0, 0.0, 1.0, 1.0];
        assert!(family.null_deviance(&y) > 0.0);
    }

    #[test]
    fn test_valid_mu() {
        let family = BinomialFamily::logistic();
        for mu in [0.5, 0.01, 0.99] {
            assert!(family.valid_mu(mu));
        }
        for mu in [0.0, -0.1, 1.0, 1.1, f64::NAN, f64::INFINITY] {
            assert!(!family.valid_mu(mu));
        }
    }

    #[test]
    fn test_clamp_mu() {
        let family = BinomialFamily::logistic();
        let lower = 1e-10;
        let upper = 1.0 - 1e-10;
        for (input, expected) in [(0.5, 0.5), (0.0, lower), (-1.0, lower), (1.0, upper), (2.0, upper)] {
            assert!(approx_eq(family.clamp_mu(input), expected, 1e-15));
        }
    }

    #[test]
    fn test_all_valid_mu() {
        let family = BinomialFamily::logistic();
        assert!(family.all_valid_mu(&[0.1, 0.5, 0.9]));
        assert!(!family.all_valid_mu(&[0.1, 0.0, 0.5]));
        assert!(!family.all_valid_mu(&[0.1, 1.0, 0.5]));
    }

    #[test]
    fn test_new_constructor() {
        let family = BinomialFamily::new(BinomialLink::Probit);
        assert_eq!(family.link, BinomialLink::Probit);
        assert!(!family.is_canonical_link());
    }

    #[test]
    fn test_default() {
        let family = BinomialFamily::default();
        assert_eq!(family.link, BinomialLink::Logit);
        assert!(family.is_canonical_link());
    }
}