/// Interface for a generalized-linear-model family: a variance function, a
/// link function with its inverse and derivative, and a unit deviance.
///
/// Implementors supply the six required methods; the remaining methods have
/// default implementations built on top of them and may be overridden.
pub trait GlmFamily {
    /// Variance function evaluated at the mean `mu`.
    fn variance(&self, mu: f64) -> f64;

    /// Link function: maps a mean `mu` to the linear predictor.
    fn link(&self, mu: f64) -> f64;

    /// Inverse link: maps a linear predictor `eta` back to a mean.
    fn link_inverse(&self, eta: f64) -> f64;

    /// Derivative of the link function with respect to `mu`.
    fn link_derivative(&self, mu: f64) -> f64;

    /// Returns `true` when `mu` is finite and strictly positive.
    fn valid_mu(&self, mu: f64) -> bool {
        mu > 0.0 && mu.is_finite()
    }

    /// Returns `true` when `eta` is finite.
    fn valid_eta(&self, eta: f64) -> bool {
        eta.is_finite()
    }

    /// Returns `true` when every entry of `mu` passes [`GlmFamily::valid_mu`]
    /// (vacuously `true` for an empty slice).
    fn all_valid_mu(&self, mu: &[f64]) -> bool {
        !mu.iter().any(|&value| !self.valid_mu(value))
    }

    /// Returns `true` when every entry of `eta` passes
    /// [`GlmFamily::valid_eta`] (vacuously `true` for an empty slice).
    fn all_valid_eta(&self, eta: &[f64]) -> bool {
        !eta.iter().any(|&value| !self.valid_eta(value))
    }

    /// Restricts `mu` to the closed interval `[1e-10, 1e10]`.
    fn clamp_mu(&self, mu: f64) -> f64 {
        const LOWER: f64 = 1e-10;
        const UPPER: f64 = 1e10;
        mu.clamp(LOWER, UPPER)
    }

    /// Computes `1 / (variance(mu) * link_derivative(mu)^2)`.
    ///
    /// When either the variance or the link derivative has magnitude below
    /// `1e-14`, the fixed value `1e-10` is returned instead of dividing.
    fn irls_weight(&self, mu: f64) -> f64 {
        let var = self.variance(mu);
        let d_link = self.link_derivative(mu);
        let degenerate = var.abs() < 1e-14 || d_link.abs() < 1e-14;
        if degenerate {
            1e-10
        } else {
            1.0 / (var * d_link * d_link)
        }
    }

    /// Computes `eta + (y - mu) * link_derivative(mu)`.
    fn working_response(&self, y: f64, mu: f64, eta: f64) -> f64 {
        eta + (y - mu) * self.link_derivative(mu)
    }

    /// Deviance contribution of a single observation `y` with mean `mu`.
    fn unit_deviance(&self, y: f64, mu: f64) -> f64;

    /// Sum of unit deviances over paired entries of `y` and `mu`.
    ///
    /// Pairing stops at the end of the shorter slice.
    fn deviance(&self, y: &[f64], mu: &[f64]) -> f64 {
        y.iter()
            .zip(mu)
            .map(|(observed, fitted)| self.unit_deviance(*observed, *fitted))
            .sum()
    }

    /// Produces starting means, one per entry of `y`.
    fn initialize_mu(&self, y: &[f64]) -> Vec<f64>;

    /// Sum of unit deviances of `y` against its own arithmetic mean.
    fn null_deviance(&self, y: &[f64]) -> f64 {
        let total: f64 = y.iter().sum::<f64>();
        let grand_mean = total / y.len() as f64;
        y.iter()
            .map(|&observed| self.unit_deviance(observed, grand_mean))
            .sum()
    }
}
/// A GLM family described by two exponents: one for the variance function and
/// one for the link function.
#[derive(Clone, Copy, Debug)]
pub struct TweedieFamily {
    /// Exponent `p` of the variance function `mu^p`.
    pub var_power: f64,
    /// Exponent of the power link; `0.0` selects the logarithmic link.
    pub link_power: f64,
}
impl Default for TweedieFamily {
    /// Returns the family with variance power `1.5` and link power `0.0`
    /// (the logarithmic link).
    fn default() -> Self {
        let var_power = 1.5;
        let link_power = 0.0;
        Self {
            var_power,
            link_power,
        }
    }
}
impl TweedieFamily {
    /// Creates a family with variance power `var_power` and link power
    /// `link_power`.
    ///
    /// A `link_power` of `0.0` selects the logarithmic link; any other value
    /// selects the power link `eta = mu^link_power`.
    ///
    /// # Panics
    ///
    /// Panics when `var_power` lies strictly between 0 and 1.
    pub fn new(var_power: f64, link_power: f64) -> Self {
        assert!(
            !(var_power > 0.0 && var_power < 1.0),
            "var_power in (0, 1) is not allowed (no valid distribution)"
        );
        Self {
            var_power,
            link_power,
        }
    }

    /// Variance power 0 with the identity link (`link_power = 1`).
    pub fn gaussian() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Variance power 1 with the log link.
    pub fn poisson() -> Self {
        Self::new(1.0, 0.0)
    }

    /// Variance power 2 with the log link.
    pub fn gamma() -> Self {
        Self::new(2.0, 0.0)
    }

    /// Variance power 3 with the log link.
    pub fn inverse_gaussian() -> Self {
        Self::new(3.0, 0.0)
    }

    /// Variance power strictly between 1 and 2 with the log link.
    ///
    /// # Panics
    ///
    /// Panics unless `1 < var_power < 2`.
    pub fn compound_poisson_gamma(var_power: f64) -> Self {
        assert!(
            var_power > 1.0 && var_power < 2.0,
            "Compound Poisson-Gamma requires var_power in (1, 2)"
        );
        Self::new(var_power, 0.0)
    }

    /// Variance function `mu^var_power`, with closed forms for powers 0, 1
    /// and 2.
    #[inline]
    pub fn variance(&self, mu: f64) -> f64 {
        if self.var_power == 0.0 {
            1.0
        } else if self.var_power == 1.0 {
            mu
        } else if self.var_power == 2.0 {
            mu * mu
        } else {
            mu.powf(self.var_power)
        }
    }

    /// Derivative of the variance function with respect to `mu`:
    /// `var_power * mu^(var_power - 1)`, with closed forms for powers 0, 1
    /// and 2.
    #[inline]
    pub fn variance_derivative(&self, mu: f64) -> f64 {
        if self.var_power == 0.0 {
            0.0
        } else if self.var_power == 1.0 {
            1.0
        } else if self.var_power == 2.0 {
            2.0 * mu
        } else {
            self.var_power * mu.powf(self.var_power - 1.0)
        }
    }

    /// Link function: `ln(mu)` when `link_power == 0`, otherwise
    /// `mu^link_power` (identity and reciprocal are special-cased).
    #[inline]
    pub fn link(&self, mu: f64) -> f64 {
        if self.link_power == 0.0 {
            mu.ln()
        } else if self.link_power == 1.0 {
            mu
        } else if self.link_power == -1.0 {
            1.0 / mu
        } else {
            mu.powf(self.link_power)
        }
    }

    /// Inverse link: `exp(eta)` when `link_power == 0`, otherwise
    /// `eta^(1 / link_power)`.
    #[inline]
    pub fn link_inverse(&self, eta: f64) -> f64 {
        if self.link_power == 0.0 {
            eta.exp()
        } else if self.link_power == 1.0 {
            eta
        } else if self.link_power == -1.0 {
            1.0 / eta
        } else {
            eta.powf(1.0 / self.link_power)
        }
    }

    /// Derivative of the inverse link with respect to `eta`.
    #[inline]
    pub fn link_inverse_derivative(&self, eta: f64) -> f64 {
        if self.link_power == 0.0 {
            eta.exp()
        } else if self.link_power == 1.0 {
            1.0
        } else if self.link_power == -1.0 {
            -1.0 / (eta * eta)
        } else {
            (1.0 / self.link_power) * eta.powf(1.0 / self.link_power - 1.0)
        }
    }

    /// Derivative of the link with respect to `mu`.
    #[inline]
    pub fn link_derivative(&self, mu: f64) -> f64 {
        if self.link_power == 0.0 {
            1.0 / mu
        } else if self.link_power == 1.0 {
            1.0
        } else if self.link_power == -1.0 {
            -1.0 / (mu * mu)
        } else {
            self.link_power * mu.powf(self.link_power - 1.0)
        }
    }

    /// Computes `1 / (variance(mu) * link_derivative(mu)^2)`.
    ///
    /// Returns the fixed value `1e-10` when either factor has magnitude below
    /// `1e-14`, so the division is never attempted on a near-zero product.
    #[inline]
    pub fn irls_weight(&self, mu: f64) -> f64 {
        let v = self.variance(mu);
        let link_deriv = self.link_derivative(mu);
        if v.abs() < 1e-14 || link_deriv.abs() < 1e-14 {
            return 1e-10;
        }
        1.0 / (v * link_deriv * link_deriv)
    }

    /// Computes `eta + (y - mu) * link_derivative(mu)`.
    #[inline]
    pub fn working_response(&self, y: f64, mu: f64, eta: f64) -> f64 {
        let link_deriv = self.link_derivative(mu);
        eta + (y - mu) * link_deriv
    }

    /// Unit deviance of observation `y` under mean `mu`.
    ///
    /// * `var_power == 0`: `(y - mu)^2`.
    /// * `var_power` within `1e-10` of 1: `2 * (y * ln(y / mu) - (y - mu))`,
    ///   or `2 * mu` when `y` is not positive.
    /// * `var_power` within `1e-10` of 2: `2 * ((y - mu) / mu - ln(y / mu))`.
    /// * otherwise:
    ///   `2 * (y^(2-p) / ((1-p)(2-p)) - y * mu^(1-p) / (1-p) + mu^(2-p) / (2-p))`.
    ///
    /// In the general case a non-positive `y` is handled as follows:
    ///
    /// * `var_power > 2`: `y^(2-p)` diverges as `y` approaches zero from
    ///   above, so `f64::INFINITY` is returned rather than a finite
    ///   (previously negative) value.
    /// * `var_power < 0` and `y < 0`: the `y^(2-p)` term is taken as zero but
    ///   the term linear in `y` is retained.
    /// * otherwise both `y`-dependent terms are taken as zero.
    pub fn unit_deviance(&self, y: f64, mu: f64) -> f64 {
        let p = self.var_power;
        if p == 0.0 {
            (y - mu).powi(2)
        } else if (p - 1.0).abs() < 1e-10 {
            if y > 0.0 {
                2.0 * (y * (y / mu).ln() - (y - mu))
            } else {
                2.0 * mu
            }
        } else if (p - 2.0).abs() < 1e-10 {
            2.0 * (-(y / mu).ln() + (y - mu) / mu)
        } else {
            if y <= 0.0 && p > 2.0 {
                // The y^(2-p) term has a negative exponent here and blows up
                // at zero; dropping it would yield a negative deviance.
                return f64::INFINITY;
            }
            let term1 = if y > 0.0 {
                y.powf(2.0 - p) / ((1.0 - p) * (2.0 - p))
            } else {
                0.0
            };
            // The linear term must survive for negative observations when
            // p < 0; only the power term vanishes there.
            let term2 = if y > 0.0 || (y < 0.0 && p < 0.0) {
                -y * mu.powf(1.0 - p) / (1.0 - p)
            } else {
                0.0
            };
            let term3 = mu.powf(2.0 - p) / (2.0 - p);
            2.0 * (term1 + term2 + term3)
        }
    }

    /// Sum of unit deviances over paired entries of `y` and `mu`.
    ///
    /// Pairing stops at the end of the shorter slice.
    pub fn deviance(&self, y: &[f64], mu: &[f64]) -> f64 {
        y.iter()
            .zip(mu.iter())
            .map(|(&yi, &mui)| self.unit_deviance(yi, mui))
            .sum()
    }

    /// Sum of unit deviances of `y` against its own arithmetic mean.
    pub fn null_deviance(&self, y: &[f64]) -> f64 {
        let y_mean: f64 = y.iter().sum::<f64>() / y.len() as f64;
        y.iter().map(|&yi| self.unit_deviance(yi, y_mean)).sum()
    }

    /// Produces starting means, one per entry of `y`.
    ///
    /// For `var_power >= 1` each start is the midpoint between the
    /// observation and the sample mean; a non-positive midpoint is replaced
    /// by `max(sample mean, 0.1)`. For every other variance power the
    /// observations are returned unchanged.
    pub fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        if self.var_power >= 1.0 {
            let y_mean: f64 = y.iter().sum::<f64>() / y.len() as f64;
            y.iter()
                .map(|&yi| {
                    let mu = (yi + y_mean) / 2.0;
                    if mu <= 0.0 {
                        y_mean.max(0.1)
                    } else {
                        mu
                    }
                })
                .collect()
        } else {
            y.to_vec()
        }
    }

    /// Returns `false` exactly when `var_power` lies strictly between 0 and 1.
    pub fn is_valid(&self) -> bool {
        !(self.var_power > 0.0 && self.var_power < 1.0)
    }

    /// Returns `1 - var_power`.
    pub fn canonical_link_power(&self) -> f64 {
        1.0 - self.var_power
    }

    /// Returns `true` when `link_power` is within `1e-10` of
    /// [`TweedieFamily::canonical_link_power`].
    pub fn is_canonical_link(&self) -> bool {
        (self.link_power - self.canonical_link_power()).abs() < 1e-10
    }
}
/// Trait implementation that forwards the required methods to the inherent
/// `TweedieFamily` methods of the same name and overrides the mean checks so
/// that a variance power of exactly `0` admits non-positive means.
impl GlmFamily for TweedieFamily {
    /// Forwards to the inherent `variance`.
    fn variance(&self, mu: f64) -> f64 {
        TweedieFamily::variance(self, mu)
    }

    /// Forwards to the inherent `link`.
    fn link(&self, mu: f64) -> f64 {
        TweedieFamily::link(self, mu)
    }

    /// Forwards to the inherent `link_inverse`.
    fn link_inverse(&self, eta: f64) -> f64 {
        TweedieFamily::link_inverse(self, eta)
    }

    /// Forwards to the inherent `link_derivative`.
    fn link_derivative(&self, mu: f64) -> f64 {
        TweedieFamily::link_derivative(self, mu)
    }

    /// Forwards to the inherent `unit_deviance`.
    fn unit_deviance(&self, y: f64, mu: f64) -> f64 {
        TweedieFamily::unit_deviance(self, y, mu)
    }

    /// Forwards to the inherent `initialize_mu`.
    fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        TweedieFamily::initialize_mu(self, y)
    }

    /// A mean must be finite; it must also be strictly positive unless
    /// `var_power` is exactly `0`.
    fn valid_mu(&self, mu: f64) -> bool {
        mu.is_finite() && (self.var_power == 0.0 || mu > 0.0)
    }

    /// Clamps `mu` to `[-1e10, 1e10]` when `var_power` is exactly `0`, and to
    /// `[1e-10, 1e10]` otherwise.
    fn clamp_mu(&self, mu: f64) -> f64 {
        let lower_bound = if self.var_power == 0.0 { -1e10 } else { 1e-10 };
        mu.clamp(lower_bound, 1e10)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by most floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// Returns `true` when `actual` and `expected` differ by less than `tol`.
    fn within(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Returns `true` when `actual` and `expected` differ by less than `TOL`.
    fn close(actual: f64, expected: f64) -> bool {
        within(actual, expected, TOL)
    }

    // ---- Named constructors -------------------------------------------------

    #[test]
    fn test_gaussian_family() {
        let fam = TweedieFamily::gaussian();
        assert!(close(fam.var_power, 0.0));
        assert!(close(fam.link_power, 1.0));
        for &mu in &[1.0, 5.0] {
            assert!(close(fam.variance(mu), 1.0));
        }
        assert!(close(fam.link(3.0), 3.0));
        assert!(close(fam.link_inverse(3.0), 3.0));
    }

    #[test]
    fn test_poisson_family() {
        let fam = TweedieFamily::poisson();
        let e = std::f64::consts::E;
        assert!(close(fam.var_power, 1.0));
        for &mu in &[2.0, 5.0] {
            assert!(close(fam.variance(mu), mu));
        }
        assert!(close(fam.link(e), 1.0));
        assert!(close(fam.link_inverse(1.0), e));
    }

    #[test]
    fn test_gamma_family() {
        let fam = TweedieFamily::gamma();
        assert!(close(fam.var_power, 2.0));
        for &(mu, expected) in &[(2.0, 4.0), (3.0, 9.0)] {
            assert!(close(fam.variance(mu), expected));
        }
    }

    #[test]
    fn test_inverse_gaussian_family() {
        let fam = TweedieFamily::inverse_gaussian();
        assert!(close(fam.var_power, 3.0));
        assert!(close(fam.variance(2.0), 8.0));
    }

    #[test]
    fn test_compound_poisson_gamma() {
        let fam = TweedieFamily::compound_poisson_gamma(1.5);
        assert!(fam.var_power > 1.0 && fam.var_power < 2.0);
        let mu: f64 = 4.0;
        assert!(close(fam.variance(mu), mu.powf(1.5)));
    }

    // ---- IRLS weight and working response (Poisson) -------------------------

    #[test]
    fn test_irls_weight() {
        let mu = 2.0;
        let weight = TweedieFamily::poisson().irls_weight(mu);
        assert!(close(weight, mu));
    }

    #[test]
    fn test_working_response() {
        let fam = TweedieFamily::poisson();
        let (y, mu): (f64, f64) = (3.0, 2.0);
        let eta = fam.link(mu);
        let expected = eta + (y - mu) / mu;
        assert!(close(fam.working_response(y, mu, eta), expected));
    }

    // ---- Unit deviance and deviance -----------------------------------------

    #[test]
    fn test_unit_deviance_normal() {
        let dev = TweedieFamily::gaussian().unit_deviance(3.0, 2.0);
        assert!(close(dev, 1.0));
    }

    #[test]
    fn test_unit_deviance_poisson() {
        let (y, mu): (f64, f64) = (3.0, 2.0);
        let expected = 2.0 * (y * (y / mu).ln() - (y - mu));
        let dev = TweedieFamily::poisson().unit_deviance(y, mu);
        assert!(close(dev, expected));
    }

    #[test]
    fn test_deviance() {
        let fam = TweedieFamily::gaussian();
        let y = [1.0, 2.0, 3.0];
        let mu = [1.0, 2.0, 3.0];
        let dev = fam.deviance(&y, &mu);
        assert!(dev.abs() < TOL);
    }

    // ---- Canonical link and constructor validation --------------------------

    #[test]
    fn test_canonical_link() {
        assert!(TweedieFamily::gaussian().is_canonical_link());
        assert!(TweedieFamily::poisson().is_canonical_link());
        assert!(!TweedieFamily::gamma().is_canonical_link());
    }

    #[test]
    #[should_panic(expected = "var_power in (0, 1) is not allowed")]
    fn test_invalid_var_power() {
        let _ = TweedieFamily::new(0.5, 0.0);
    }

    #[test]
    fn test_link_inverse_roundtrip() {
        let families = [
            TweedieFamily::gaussian(),
            TweedieFamily::poisson(),
            TweedieFamily::gamma(),
            TweedieFamily::inverse_gaussian(),
            TweedieFamily::compound_poisson_gamma(1.5),
        ];
        for fam in families.iter() {
            let mu = 2.5;
            let recovered = fam.link_inverse(fam.link(mu));
            assert!(
                close(mu, recovered),
                "Roundtrip failed for var_power={}",
                fam.var_power
            );
        }
    }

    // ---- Variance derivative ------------------------------------------------

    #[test]
    fn test_variance_derivative_normal() {
        let fam = TweedieFamily::gaussian();
        for &mu in &[1.0, 5.0, 100.0] {
            assert!(close(fam.variance_derivative(mu), 0.0));
        }
    }

    #[test]
    fn test_variance_derivative_poisson() {
        let fam = TweedieFamily::poisson();
        for &mu in &[0.5, 2.0, 10.0] {
            assert!(close(fam.variance_derivative(mu), 1.0));
        }
    }

    #[test]
    fn test_variance_derivative_gamma() {
        let fam = TweedieFamily::gamma();
        for &(mu, expected) in &[(1.0, 2.0), (3.0, 6.0), (5.0, 10.0)] {
            assert!(close(fam.variance_derivative(mu), expected));
        }
    }

    #[test]
    fn test_variance_derivative_general() {
        let mu: f64 = 2.0;

        let inverse_gaussian = TweedieFamily::inverse_gaussian();
        assert!(close(
            inverse_gaussian.variance_derivative(mu),
            3.0 * mu.powf(2.0)
        ));

        let compound = TweedieFamily::compound_poisson_gamma(1.5);
        assert!(close(
            compound.variance_derivative(mu),
            1.5 * mu.powf(0.5)
        ));
    }

    // ---- Inverse-link derivative --------------------------------------------

    #[test]
    fn test_link_inverse_derivative_log() {
        let eta: f64 = 2.0;
        let fam = TweedieFamily::poisson();
        assert!(close(fam.link_inverse_derivative(eta), eta.exp()));
    }

    #[test]
    fn test_link_inverse_derivative_identity() {
        let fam = TweedieFamily::gaussian();
        for &eta in &[5.0, -3.0] {
            assert!(close(fam.link_inverse_derivative(eta), 1.0));
        }
    }

    #[test]
    fn test_link_inverse_derivative_inverse() {
        let eta = 2.0;
        let fam = TweedieFamily::new(2.0, -1.0);
        let expected = -1.0 / (eta * eta);
        assert!(close(fam.link_inverse_derivative(eta), expected));
    }

    #[test]
    fn test_link_inverse_derivative_general() {
        let eta: f64 = 4.0;
        let fam = TweedieFamily::new(1.0, 0.5);
        let expected = (1.0 / 0.5) * eta.powf(1.0 / 0.5 - 1.0);
        assert!(close(fam.link_inverse_derivative(eta), expected));
    }

    // ---- Null deviance ------------------------------------------------------

    #[test]
    fn test_null_deviance_normal() {
        let y = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let y_mean: f64 = 3.0;
        let expected: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
        let null_dev = TweedieFamily::gaussian().null_deviance(&y);
        assert!(close(null_dev, expected));
    }

    #[test]
    fn test_null_deviance_poisson() {
        let y = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let y_mean: f64 = 3.0;
        let per_observation = |yi: f64| -> f64 {
            if yi > 0.0 {
                2.0 * (yi * (yi / y_mean).ln() - (yi - y_mean))
            } else {
                2.0 * y_mean
            }
        };
        let expected: f64 = y.iter().map(|&yi| per_observation(yi)).sum();
        let null_dev = TweedieFamily::poisson().null_deviance(&y);
        assert!(close(null_dev, expected));
    }

    #[test]
    fn test_null_deviance_gamma() {
        let y = [1.0, 2.0, 3.0, 4.0, 5.0];
        let null_dev = TweedieFamily::gamma().null_deviance(&y);
        assert!(null_dev.is_finite());
        assert!(null_dev >= 0.0);
    }

    // ---- Canonical link power and validity ----------------------------------

    #[test]
    fn test_canonical_link_power() {
        let cases = [
            (TweedieFamily::gaussian(), 1.0),
            (TweedieFamily::poisson(), 0.0),
            (TweedieFamily::gamma(), -1.0),
            (TweedieFamily::inverse_gaussian(), -2.0),
        ];
        for &(fam, expected) in cases.iter() {
            assert!(close(fam.canonical_link_power(), expected));
        }
    }

    #[test]
    fn test_is_valid() {
        assert!(TweedieFamily::gaussian().is_valid());
        assert!(TweedieFamily::poisson().is_valid());
        assert!(TweedieFamily::gamma().is_valid());
        for &(var_power, link_power) in &[
            (0.0, 1.0),
            (1.0, 0.0),
            (1.5, 0.0),
            (3.0, 0.0),
            (-1.0, 1.0),
        ] {
            assert!(TweedieFamily::new(var_power, link_power).is_valid());
        }
    }

    // ---- Unit deviance, remaining branches ----------------------------------

    #[test]
    fn test_unit_deviance_gamma_case() {
        let (y, mu): (f64, f64) = (2.0, 3.0);
        let expected = 2.0 * (-(y / mu).ln() + (y - mu) / mu);
        let dev = TweedieFamily::gamma().unit_deviance(y, mu);
        assert!(close(dev, expected));
    }

    #[test]
    fn test_unit_deviance_general_tweedie() {
        let (y, mu, p): (f64, f64, f64) = (2.0, 1.5, 3.0);
        let term1 = y.powf(2.0 - p) / ((1.0 - p) * (2.0 - p));
        let term2 = -y * mu.powf(1.0 - p) / (1.0 - p);
        let term3 = mu.powf(2.0 - p) / (2.0 - p);
        let expected = 2.0 * (term1 + term2 + term3);
        let dev = TweedieFamily::inverse_gaussian().unit_deviance(y, mu);
        assert!(close(dev, expected));
    }

    #[test]
    fn test_unit_deviance_zero_y_poisson() {
        let dev = TweedieFamily::poisson().unit_deviance(0.0, 2.0);
        assert!(close(dev, 4.0));
    }

    #[test]
    fn test_unit_deviance_zero_y_general_tweedie() {
        let (mu, p): (f64, f64) = (2.0, 1.5);
        let fam = TweedieFamily::compound_poisson_gamma(1.5);
        let expected = 2.0 * mu.powf(2.0 - p) / (2.0 - p);
        assert!(close(fam.unit_deviance(0.0, mu), expected));
    }

    // ---- IRLS weight, remaining families ------------------------------------

    #[test]
    fn test_irls_weight_gaussian() {
        let weight = TweedieFamily::gaussian().irls_weight(5.0);
        assert!(close(weight, 1.0));
    }

    #[test]
    fn test_irls_weight_gamma() {
        let weight = TweedieFamily::gamma().irls_weight(2.0);
        assert!(close(weight, 1.0));
    }

    #[test]
    fn test_irls_weight_inverse_gaussian() {
        let weight = TweedieFamily::inverse_gaussian().irls_weight(2.0);
        assert!(close(weight, 0.5));
    }

    // ---- Working response, remaining families -------------------------------

    #[test]
    fn test_working_response_gaussian() {
        let fam = TweedieFamily::gaussian();
        let (y, mu) = (5.0, 3.0);
        let eta = fam.link(mu);
        assert!(close(fam.working_response(y, mu, eta), 5.0));
    }

    #[test]
    fn test_working_response_gamma() {
        let fam = TweedieFamily::gamma();
        let (y, mu): (f64, f64) = (4.0, 2.0);
        let eta = fam.link(mu);
        let expected = mu.ln() + (y - mu) / mu;
        assert!(close(fam.working_response(y, mu, eta), expected));
    }

    // ---- Link function branches ---------------------------------------------

    #[test]
    fn test_link_derivative_all_branches() {
        let mu: f64 = 2.0;
        let cases = [
            (TweedieFamily::poisson(), 0.5),
            (TweedieFamily::gaussian(), 1.0),
            (TweedieFamily::new(2.0, -1.0), -0.25),
            (TweedieFamily::new(1.0, 0.5), 0.5 * mu.powf(-0.5)),
        ];
        for &(fam, expected) in cases.iter() {
            assert!(close(fam.link_derivative(mu), expected));
        }
    }

    #[test]
    fn test_link_all_branches() {
        let mu: f64 = 2.0;
        let cases = [
            (TweedieFamily::poisson(), mu.ln()),
            (TweedieFamily::gaussian(), mu),
            (TweedieFamily::new(2.0, -1.0), 0.5),
            (TweedieFamily::new(1.0, 0.5), mu.powf(0.5)),
        ];
        for &(fam, expected) in cases.iter() {
            assert!(close(fam.link(mu), expected));
        }
    }

    #[test]
    fn test_link_inverse_all_branches() {
        let eta: f64 = 2.0;
        let cases = [
            (TweedieFamily::poisson(), eta.exp()),
            (TweedieFamily::gaussian(), eta),
            (TweedieFamily::new(2.0, -1.0), 0.5),
            (TweedieFamily::new(1.0, 0.5), eta.powf(2.0)),
        ];
        for &(fam, expected) in cases.iter() {
            assert!(close(fam.link_inverse(eta), expected));
        }
    }

    // ---- Starting values ----------------------------------------------------

    #[test]
    fn test_initialize_mu_normal() {
        let y = [1.0, 2.0, 3.0, -1.0, 5.0];
        let mu = TweedieFamily::gaussian().initialize_mu(&y);
        assert_eq!(mu.len(), y.len());
        for (&start, &observed) in mu.iter().zip(y.iter()) {
            assert!(close(start, observed));
        }
    }

    #[test]
    fn test_initialize_mu_poisson() {
        let y = [0.0, 1.0, 2.0, 3.0, 4.0];
        let mu = TweedieFamily::poisson().initialize_mu(&y);
        assert_eq!(mu.len(), y.len());
        for (i, &mu_i) in mu.iter().enumerate() {
            assert!(mu_i > 0.0, "mu[{}] should be positive", i);
        }
        assert!(close(mu[0], 1.0));
    }

    #[test]
    fn test_initialize_mu_with_negative() {
        let y = [-1.0, -2.0, 0.0, 1.0, 2.0];
        let mu = TweedieFamily::poisson().initialize_mu(&y);
        for (i, &mu_i) in mu.iter().enumerate() {
            assert!(mu_i > 0.0, "mu[{}] = {} should be positive", i, mu_i);
        }
    }

    // ---- Variance branches and trait dispatch -------------------------------

    #[test]
    fn test_variance_all_branches() {
        let mu: f64 = 2.0;
        let cases = [
            (TweedieFamily::gaussian(), 1.0),
            (TweedieFamily::poisson(), mu),
            (TweedieFamily::gamma(), mu * mu),
            (TweedieFamily::inverse_gaussian(), mu.powf(3.0)),
        ];
        for &(fam, expected) in cases.iter() {
            assert!(close(fam.variance(mu), expected));
        }
    }

    #[test]
    fn test_glm_family_trait() {
        let fam = TweedieFamily::poisson();
        let mu = 2.0;
        assert!(close(GlmFamily::variance(&fam, mu), fam.variance(mu)));
        assert!(close(GlmFamily::link(&fam, mu), fam.link(mu)));
        assert!(close(GlmFamily::link_inverse(&fam, fam.link(mu)), mu));
        assert!(close(
            GlmFamily::link_derivative(&fam, mu),
            fam.link_derivative(mu)
        ));
    }

    #[test]
    fn test_deviance_misfit() {
        let y = [1.0, 2.0, 3.0];
        let mu = [2.0, 2.0, 2.0];
        let dev = TweedieFamily::gaussian().deviance(&y, &mu);
        assert!(close(dev, 2.0));
    }

    // ---- Validity checks on slices ------------------------------------------

    #[test]
    fn test_all_valid_mu_poisson() {
        let fam = TweedieFamily::poisson();
        assert!(fam.all_valid_mu(&[1.0, 2.0, 3.0, 0.1]));
        let rejected = [
            [1.0, 0.0, 3.0],
            [1.0, -1.0, 3.0],
            [1.0, f64::NAN, 3.0],
            [1.0, f64::INFINITY, 3.0],
        ];
        for mu in rejected.iter() {
            assert!(!fam.all_valid_mu(mu));
        }
    }

    #[test]
    fn test_all_valid_mu_gaussian() {
        let fam = TweedieFamily::gaussian();
        assert!(fam.all_valid_mu(&[-1.0, 0.0, 1.0, 100.0]));
        assert!(!fam.all_valid_mu(&[1.0, f64::NAN, 3.0]));
    }

    #[test]
    fn test_all_valid_eta() {
        let fam = TweedieFamily::poisson();
        assert!(fam.all_valid_eta(&[-100.0, -1.0, 0.0, 1.0, 100.0]));
        let rejected = [
            [1.0, f64::NAN, 3.0],
            [1.0, f64::INFINITY, 3.0],
            [1.0, f64::NEG_INFINITY, 3.0],
        ];
        for eta in rejected.iter() {
            assert!(!fam.all_valid_eta(eta));
        }
    }

    #[test]
    fn test_all_valid_eta_empty() {
        let no_values: [f64; 0] = [];
        assert!(TweedieFamily::poisson().all_valid_eta(&no_values));
    }

    #[test]
    fn test_all_valid_mu_empty() {
        let no_values: [f64; 0] = [];
        assert!(TweedieFamily::poisson().all_valid_mu(&no_values));
    }

    // ---- Clamping -----------------------------------------------------------

    #[test]
    fn test_clamp_mu_gaussian() {
        let fam = TweedieFamily::gaussian();
        assert!(within(fam.clamp_mu(0.0), 0.0, 1e-10));
        assert!(within(fam.clamp_mu(-100.0), -100.0, 1e-10));
        assert!(within(fam.clamp_mu(1e20), 1e10, 1e-5));
        assert!(within(fam.clamp_mu(-1e20), -1e10, 1e-5));
    }

    #[test]
    fn test_clamp_mu_poisson() {
        let fam = TweedieFamily::poisson();
        assert!(within(fam.clamp_mu(0.0), 1e-10, 1e-15));
        assert!(within(fam.clamp_mu(-100.0), 1e-10, 1e-15));
        assert!(within(fam.clamp_mu(1e20), 1e10, 1e-5));
        assert!(within(fam.clamp_mu(5.0), 5.0, 1e-10));
    }
}