use super::family::GlmFamily;
use super::poisson_link::PoissonLink;
/// Poisson error distribution for a generalized linear model (GLM),
/// configured with one of the `PoissonLink` link functions.
///
/// The type is `Copy`, so it can be passed around by value freely.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PoissonFamily {
    /// Link function that this family's `GlmFamily::link`, `link_inverse`
    /// and `link_derivative` implementations delegate to.
    pub link: PoissonLink,
}
impl Default for PoissonFamily {
fn default() -> Self {
Self::log()
}
}
impl PoissonFamily {
pub fn new(link: PoissonLink) -> Self {
Self { link }
}
pub fn log() -> Self {
Self {
link: PoissonLink::Log,
}
}
pub fn identity() -> Self {
Self {
link: PoissonLink::Identity,
}
}
pub fn sqrt() -> Self {
Self {
link: PoissonLink::Sqrt,
}
}
pub fn is_canonical_link(&self) -> bool {
self.link == PoissonLink::Log
}
}
/// Tolerance used by the Poisson family. It is the floor applied to `mu` in
/// the variance and deviance. It is also the threshold below which an
/// observation is treated as a zero count in the deviance.
const POISSON_EPS: f64 = 1e-10;

/// Lower bound on the starting means produced by `initialize_mu`.
const POISSON_START_FLOOR: f64 = 1e-3;

impl GlmFamily for PoissonFamily {
    /// Poisson variance function `V(mu) = mu`.
    ///
    /// The result is floored at `POISSON_EPS`, so it is never zero or
    /// negative. `f64::max` also maps a NaN `mu` to the floor.
    fn variance(&self, mu: f64) -> f64 {
        f64::max(mu, POISSON_EPS)
    }

    /// Maps a mean to the linear predictor via the configured link.
    fn link(&self, mu: f64) -> f64 {
        self.link.link(mu)
    }

    /// Maps a linear predictor back to the mean via the configured link.
    fn link_inverse(&self, eta: f64) -> f64 {
        self.link.link_inverse(eta)
    }

    /// Derivative of the configured link with respect to the mean.
    fn link_derivative(&self, mu: f64) -> f64 {
        self.link.link_derivative(mu)
    }

    /// Per-observation Poisson deviance `2 * (y * ln(y / mu) - (y - mu))`,
    /// with `mu` floored at `POISSON_EPS`.
    ///
    /// Any `y` below `POISSON_EPS` gets `2 * mu`. This includes zero and any
    /// negative values.
    fn unit_deviance(&self, y: f64, mu: f64) -> f64 {
        let safe_mu = f64::max(mu, POISSON_EPS);
        // For a (near-)zero count the y * ln(y / mu) term vanishes in the
        // limit. It is skipped rather than evaluated as 0 * ln(0).
        if y < POISSON_EPS {
            return 2.0 * safe_mu;
        }
        let log_ratio = (y / safe_mu).ln();
        2.0 * (y * log_ratio - (y - safe_mu))
    }

    /// Starting means for IRLS: each observation averaged with the overall
    /// mean of `y`. Both the overall mean and each start are floored at
    /// `POISSON_START_FLOOR`.
    ///
    /// For an empty `y` the mean is `0 / 0 = NaN`. `f64::max` turns that into
    /// the floor, and an empty vector is returned.
    fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        let total: f64 = y.iter().sum();
        let grand_mean = f64::max(total / y.len() as f64, POISSON_START_FLOOR);
        let mut starts = Vec::with_capacity(y.len());
        for &obs in y {
            starts.push(f64::max(0.5 * (obs + grand_mean), POISSON_START_FLOOR));
        }
        starts
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_family() {
        let fam = PoissonFamily::log();
        assert!(fam.is_canonical_link());
        assert_eq!(fam.link, PoissonLink::Log);
    }

    #[test]
    fn test_identity_family() {
        let fam = PoissonFamily::identity();
        assert!(!fam.is_canonical_link());
        assert_eq!(fam.link, PoissonLink::Identity);
    }

    #[test]
    fn test_sqrt_family() {
        let fam = PoissonFamily::sqrt();
        assert!(!fam.is_canonical_link());
        assert_eq!(fam.link, PoissonLink::Sqrt);
    }

    #[test]
    fn test_default_is_log() {
        assert_eq!(PoissonFamily::default(), PoissonFamily::log());
    }

    #[test]
    fn test_variance() {
        let fam = PoissonFamily::log();
        assert!((fam.variance(1.0) - 1.0).abs() < 1e-10);
        assert!((fam.variance(5.0) - 5.0).abs() < 1e-10);
        assert!((fam.variance(10.0) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_variance_is_floored() {
        let fam = PoissonFamily::log();
        // Non-positive means are clamped to a small positive floor.
        assert_eq!(fam.variance(0.0), 1e-10);
        assert_eq!(fam.variance(-3.0), 1e-10);
    }

    #[test]
    fn test_link_roundtrip() {
        let families = [
            PoissonFamily::log(),
            PoissonFamily::identity(),
            PoissonFamily::sqrt(),
        ];
        for fam in &families {
            for mu in [0.5, 1.0, 2.0, 5.0, 10.0] {
                let eta = fam.link(mu);
                let mu_back = fam.link_inverse(eta);
                assert!(
                    (mu - mu_back).abs() < 1e-6,
                    "Roundtrip failed for {:?} at mu={}",
                    fam.link,
                    mu
                );
            }
        }
    }

    #[test]
    fn test_unit_deviance_perfect_fit() {
        let fam = PoissonFamily::log();
        assert!(fam.unit_deviance(5.0, 5.0).abs() < 1e-10);
        assert!(fam.unit_deviance(1.0, 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_unit_deviance_zero() {
        let fam = PoissonFamily::log();
        let dev = fam.unit_deviance(0.0, 1.0);
        assert!((dev - 2.0).abs() < 1e-10);
        let dev = fam.unit_deviance(0.0, 5.0);
        assert!((dev - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_unit_deviance_zero_mean_is_finite() {
        let fam = PoissonFamily::log();
        // mu is floored internally, so the log ratio never divides by zero.
        assert!(fam.unit_deviance(1.0, 0.0).is_finite());
        assert!(fam.unit_deviance(0.0, 0.0).is_finite());
    }

    #[test]
    fn test_unit_deviance_nonzero() {
        let fam = PoissonFamily::log();
        let dev = fam.unit_deviance(5.0, 4.0);
        let expected = 2.0 * (5.0 * (5.0_f64 / 4.0).ln() - 1.0);
        assert!((dev - expected).abs() < 1e-6);
    }

    #[test]
    fn test_deviance() {
        let fam = PoissonFamily::log();
        let y = vec![1.0, 2.0, 3.0, 4.0];
        let mu = vec![1.0, 2.0, 3.0, 4.0];
        let dev = fam.deviance(&y, &mu);
        assert!(dev < 1e-8);
    }

    #[test]
    fn test_initialize_mu() {
        let fam = PoissonFamily::log();
        let y = vec![0.0, 1.0, 5.0, 10.0];
        // mean(y) = 4, so each start is (y_i + 4) / 2.
        let mu_init = fam.initialize_mu(&y);
        assert_eq!(mu_init, vec![2.0, 2.5, 4.5, 7.0]);
        assert!(mu_init.iter().all(|&mu| mu > 0.0));
    }

    #[test]
    fn test_initialize_mu_floors_small_means() {
        let fam = PoissonFamily::log();
        // An all-zero response would give zero starting means without the floor.
        let mu_init = fam.initialize_mu(&[0.0, 0.0, 0.0]);
        assert_eq!(mu_init, vec![1e-3, 1e-3, 1e-3]);
    }

    #[test]
    fn test_initialize_mu_empty() {
        let fam = PoissonFamily::log();
        assert!(fam.initialize_mu(&[]).is_empty());
    }

    #[test]
    fn test_irls_weight_log() {
        let fam = PoissonFamily::log();
        let w = fam.irls_weight(1.0);
        assert!((w - 1.0).abs() < 1e-10);
        let w = fam.irls_weight(4.0);
        assert!((w - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_working_response() {
        let fam = PoissonFamily::log();
        let y = 5.0;
        let mu = 4.0;
        let eta = fam.link(mu);
        let z = fam.working_response(y, mu, eta);
        let expected = eta + (y - mu) * fam.link_derivative(mu);
        assert!((z - expected).abs() < 1e-10);
    }

    #[test]
    fn test_null_deviance() {
        let fam = PoissonFamily::log();
        let y = vec![1.0, 2.0, 3.0, 4.0];
        let null_dev = fam.null_deviance(&y);
        assert!(null_dev > 0.0);
    }
}