//! Generalized linear models (GLMs): Poisson, Negative Binomial, Gamma, and
//! Binomial regression fitted by iteratively reweighted least squares (IRLS).
use crate::error::{AprenderError, Result};
use crate::primitives::{Matrix, Vector};
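/// Exponential-family response distributions supported by [`GLM`].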
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Family {
Poisson,
NegativeBinomial,
Gamma,
Binomial,
}
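/// Link functions g mapping the mean to the linear predictor, eta = g(mu).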
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Link {
Log,
Inverse,
Logit,
Identity,
}
impl Family {
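    /// Returns the canonical link for the family: log for Poisson and
    /// Negative Binomial, inverse for Gamma, logit for Binomial.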
#[must_use]
pub const fn canonical_link(&self) -> Link {
match self {
Self::Poisson | Self::NegativeBinomial => Link::Log,
Self::Gamma => Link::Inverse,
Self::Binomial => Link::Logit,
}
}
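    /// Variance function V(mu). Only the Negative Binomial uses `dispersion`,
    /// as the overdispersion parameter alpha in V(mu) = mu + alpha * mu^2.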
fn variance(self, mu: f32, dispersion: f32) -> f32 {
match self {
            Self::Poisson => mu,
            Self::NegativeBinomial => mu + dispersion * mu * mu,
            Self::Gamma => mu * mu,
            Self::Binomial => mu * (1.0 - mu),
        }
}
const fn name(self) -> &'static str {
match self {
Self::Poisson => "Poisson",
Self::NegativeBinomial => "Negative Binomial",
Self::Gamma => "Gamma",
Self::Binomial => "Binomial",
}
}
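    /// Returns `true` if `val` lies in the family's valid response domain.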
fn is_valid_response(self, val: f32) -> bool {
match self {
Self::Poisson | Self::NegativeBinomial => val >= 0.0,
Self::Gamma => val > 0.0,
Self::Binomial => (0.0..=1.0).contains(&val),
}
}
const fn constraint_description(self) -> &'static str {
match self {
Self::Poisson | Self::NegativeBinomial => "non-negative counts",
Self::Gamma => "positive values",
Self::Binomial => "values in [0,1]",
}
}
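    /// Clamps a raw mean into the family's open domain so links, logs, and
    /// variances stay finite during IRLS.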
fn clamp_mu(self, mu_raw: f32) -> f32 {
match self {
Self::Poisson | Self::NegativeBinomial | Self::Gamma => mu_raw.max(1e-6),
Self::Binomial => mu_raw.clamp(1e-6, 1.0 - 1e-6),
}
}
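    /// Checks every response value against the family's domain.
    ///
    /// # Errors
    ///
    /// Returns an error naming the family and its constraint for the first
    /// invalid value.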
fn validate_response(self, y: &Vector<f32>) -> Result<()> {
for &val in y.as_slice() {
if !self.is_valid_response(val) {
return Err(AprenderError::Other(format!(
"{} requires {}, got {val}",
self.name(),
self.constraint_description(),
)));
}
}
Ok(())
}
}
impl Link {
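    /// Applies the link function: eta = g(mu).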
fn link(self, mu: f32) -> f32 {
match self {
Self::Log => mu.ln(),
Self::Inverse => 1.0 / mu,
Self::Logit => (mu / (1.0 - mu)).ln(),
Self::Identity => mu,
}
}
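    /// Applies the inverse link: mu = g^{-1}(eta).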
fn inverse_link(self, eta: f32) -> f32 {
match self {
Self::Log => eta.exp(),
Self::Inverse => 1.0 / eta,
Self::Logit => 1.0 / (1.0 + (-eta).exp()),
Self::Identity => eta,
}
}
    /// Derivative of the link, g'(mu) = d(eta)/d(mu), written as a function
    /// of eta. IRLS uses it for the working response
    /// z = eta + (y - mu) * g'(mu) and the weights w = 1 / (V(mu) * g'(mu)^2).
    fn derivative(self, eta: f32) -> f32 {
        match self {
            // g(mu) = ln(mu) and mu = exp(eta), so g'(mu) = 1/mu = exp(-eta).
            Self::Log => (-eta).exp(),
            // g(mu) = 1/mu and mu = 1/eta, so g'(mu) = -1/mu^2 = -eta^2.
            Self::Inverse => -(eta * eta),
            // g(mu) = logit(mu), so g'(mu) = 1 / (mu * (1 - mu)); the clamp
            // keeps the derivative finite for extreme eta.
            Self::Logit => {
                let mu = self.inverse_link(eta).clamp(1e-6, 1.0 - 1e-6);
                1.0 / (mu * (1.0 - mu))
            }
            Self::Identity => 1.0,
        }
    }
}
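/// Generalized linear model fitted by iteratively reweighted least squares
/// (IRLS), configured through builder-style setters.
///
/// # Example
///
/// A minimal usage sketch. It is marked `ignore` because the crate's public
/// module paths are not shown in this file; it assumes only the
/// `Matrix::from_vec(rows, cols, data)` and `Vector::from_vec(data)`
/// constructors already used in this module.
///
/// ```ignore
/// let x = Matrix::from_vec(4, 1, vec![0.0, 1.0, 2.0, 3.0]).unwrap();
/// let y = Vector::from_vec(vec![1.0, 2.0, 4.0, 8.0]);
/// let mut model = GLM::new(Family::Poisson); // canonical log link
/// model.fit(&x, &y).unwrap();
/// let mu = model.predict(&x).unwrap(); // predictions on the mean scale
/// ```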
#[derive(Debug, Clone)]
pub struct GLM {
family: Family,
link: Link,
max_iter: usize,
tol: f32,
dispersion: f32,
coefficients: Option<Vec<f32>>,
intercept: Option<f32>,
}
impl GLM {
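    /// Creates a model for `family` with its canonical link, at most 1000
    /// IRLS iterations, convergence tolerance 1e-3, and unit dispersion.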
#[must_use]
pub fn new(family: Family) -> Self {
Self {
family,
link: family.canonical_link(),
            max_iter: 1000,
            tol: 1e-3,
            dispersion: 1.0,
            coefficients: None,
intercept: None,
}
}
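    /// Overrides the canonical link chosen by [`GLM::new`].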
#[must_use]
pub fn with_link(mut self, link: Link) -> Self {
self.link = link;
self
}
#[must_use]
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
#[must_use]
pub fn with_tolerance(mut self, tol: f32) -> Self {
self.tol = tol;
self
}
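    /// Sets the overdispersion parameter used by the Negative Binomial
    /// variance function; other families ignore it.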
#[must_use]
pub fn with_dispersion(mut self, dispersion: f32) -> Self {
self.dispersion = dispersion;
self
}
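    /// Returns the fitted coefficients, or `None` before a successful fit.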
#[must_use]
pub fn coefficients(&self) -> Option<&[f32]> {
self.coefficients.as_deref()
}
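    /// Returns the fitted intercept, or `None` before a successful fit.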
#[must_use]
pub fn intercept(&self) -> Option<f32> {
self.intercept
}
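    /// Fits the model by IRLS: repeatedly solves a weighted least-squares
    /// problem until the largest coefficient change falls below `tol`.
    ///
    /// # Errors
    ///
    /// Returns an error if `x` and `y` disagree on sample count, if `y`
    /// violates the family's response domain, or if IRLS does not converge
    /// within `max_iter` iterations.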
pub fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
let n = x.n_rows();
let p = x.n_cols();
if n != y.len() {
return Err(AprenderError::DimensionMismatch {
expected: format!("{n} samples in X"),
actual: format!("{} samples in y", y.len()),
});
}
self.family.validate_response(y)?;
let (mut beta, mut intercept, mut eta) = self.initialize_irls(y, n, p);
for _iter in 0..self.max_iter {
let (beta_new, intercept_new, eta_new, max_change) =
self.irls_iteration(x, y, &beta, intercept, &eta, n, p)?;
beta = beta_new;
intercept = intercept_new;
eta = eta_new;
if max_change < self.tol {
self.coefficients = Some(beta);
self.intercept = Some(intercept);
return Ok(());
}
}
Err(AprenderError::Other(format!(
"GLM IRLS did not converge in {} iterations",
self.max_iter
)))
}
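    /// Starts IRLS at zero coefficients with the intercept set to the link
    /// of the (domain-clamped) mean response.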
fn initialize_irls(&self, y: &Vector<f32>, n: usize, p: usize) -> (Vec<f32>, f32, Vec<f32>) {
let beta = vec![0.0_f32; p];
let y_mean = y.as_slice().iter().sum::<f32>() / n as f32;
        // Clamp the mean into the family's domain so the link is finite
        // (e.g. the log of a zero mean); a Poisson or Gamma mean above 1
        // must not be truncated here.
        let y_mean_safe = self.family.clamp_mu(y_mean);
        let intercept = self.link.link(y_mean_safe);
let eta = vec![intercept; n];
(beta, intercept, eta)
}
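    /// Performs one damped IRLS step: maps eta to clamped means, builds the
    /// working response and weights, solves the weighted least squares for
    /// the augmented coefficients, and blends old and new estimates.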
#[allow(clippy::needless_range_loop)]
fn irls_iteration(
&self,
x: &Matrix<f32>,
y: &Vector<f32>,
beta: &[f32],
intercept: f32,
eta: &[f32],
n: usize,
p: usize,
) -> Result<(Vec<f32>, f32, Vec<f32>, f32)> {
let mu: Vec<f32> = eta
.iter()
.map(|&e| self.family.clamp_mu(self.link.inverse_link(e)))
.collect();
let (z, weights) = self.compute_working_response_and_weights(y, &mu, eta, n);
let beta_aug = solve_weighted_least_squares(x, &weights, &z, n, p)?;
let intercept_new = beta_aug[0];
let beta_new = &beta_aug.as_slice()[1..];
let step_size = self.damping_factor();
let intercept_damped = intercept + step_size * (intercept_new - intercept);
let beta_damped: Vec<f32> = beta
.iter()
.zip(beta_new)
.map(|(old, new)| old + step_size * (new - old))
.collect();
let eta_new = compute_linear_predictor(x, &beta_damped, intercept_damped, n, p);
let max_change = compute_max_change(beta, &beta_damped, intercept, intercept_damped);
Ok((beta_damped, intercept_damped, eta_new, max_change))
}
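    /// Builds the IRLS working response z_i = eta_i + (y_i - mu_i) * g'(mu_i)
    /// and weights w_i = 1 / (V(mu_i) * g'(mu_i)^2), both clamped for
    /// numerical stability.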
#[allow(clippy::needless_range_loop)]
fn compute_working_response_and_weights(
&self,
y: &Vector<f32>,
mu: &[f32],
eta: &[f32],
n: usize,
) -> (Vec<f32>, Vec<f32>) {
let mut z = Vec::with_capacity(n);
let mut weights = Vec::with_capacity(n);
for i in 0..n {
let deriv = self.link.derivative(eta[i]);
z.push(eta[i] + (y[i] - mu[i]) * deriv);
let var = self.family.variance(mu[i], self.dispersion).max(1e-10);
let weight = 1.0 / (var * deriv * deriv + 1e-10);
weights.push(weight.clamp(1e-6, 1e6));
}
(z, weights)
}
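    /// Step-size damping: the log link takes half steps because exp() can
    /// overflow on aggressive early updates; other links take full steps.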
fn damping_factor(&self) -> f32 {
match self.link {
Link::Log => 0.5,
_ => 1.0,
}
}
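    /// Predicts mean responses mu = g^{-1}(intercept + x beta) for new rows.
    ///
    /// # Errors
    ///
    /// Returns an error if the model has not been fitted or if `x_test` has
    /// a different number of features than the fitted coefficients.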
#[allow(clippy::needless_range_loop)]
pub fn predict(&self, x_test: &Matrix<f32>) -> Result<Vector<f32>> {
let beta = self.coefficients.as_ref().ok_or_else(|| {
AprenderError::Other("Model not fitted yet. Call fit() first.".into())
})?;
let intercept = self.intercept.ok_or_else(|| {
AprenderError::Other("Model not fitted yet. Call fit() first.".into())
})?;
let n = x_test.n_rows();
let p = x_test.n_cols();
if p != beta.len() {
return Err(AprenderError::DimensionMismatch {
expected: format!("{} features", beta.len()),
actual: format!("{p} columns in x_test"),
});
}
let mut predictions = Vec::with_capacity(n);
for i in 0..n {
let mut eta = intercept;
for j in 0..p {
eta += x_test.get(i, j) * beta[j];
}
let mu = self.link.inverse_link(eta);
predictions.push(mu);
}
Ok(Vector::from_vec(predictions))
}
}
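/// Solves the weighted least-squares normal equations (X'WX) b = X'Wz for the
/// intercept-augmented design, adding a small ridge term before the Cholesky
/// solve.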
#[allow(clippy::needless_range_loop)]
fn solve_weighted_least_squares(
x: &Matrix<f32>,
weights: &[f32],
z: &[f32],
n: usize,
p: usize,
) -> Result<Vector<f32>> {
let x_aug = build_augmented_matrix(x, n, p)?;
let xtw = build_xtw(&x_aug, weights, n, p)?;
let wx = build_wx(&x_aug, weights, n, p)?;
let wz: Vec<f32> = z
.iter()
.enumerate()
.map(|(i, &zi)| weights[i].sqrt() * zi)
.collect();
let wz_vec = Vector::from_vec(wz);
let mut xtwx = xtw
.matmul(&wx)
.map_err(|e| AprenderError::Other(format!("X'WX computation failed: {e}")))?;
let xtwz = xtw
.matvec(&wz_vec)
.map_err(|e| AprenderError::Other(format!("X'Wz computation failed: {e}")))?;
add_ridge_regularization(&mut xtwx, p);
xtwx.cholesky_solve(&xtwz)
.map_err(|e| AprenderError::Other(format!("Cholesky solve failed: {e}")))
}
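/// Prepends a column of ones to `x` so the intercept is estimated jointly
/// with the coefficients.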
#[allow(clippy::needless_range_loop)]
fn build_augmented_matrix(x: &Matrix<f32>, n: usize, p: usize) -> Result<Matrix<f32>> {
let mut data = Vec::with_capacity(n * (p + 1));
for i in 0..n {
data.push(1.0);
for j in 0..p {
data.push(x.get(i, j));
}
}
Matrix::from_vec(n, p + 1, data)
.map_err(|e| AprenderError::Other(format!("Augmented matrix error: {e}")))
}
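/// Builds (W^{1/2} X)' with shape (p+1) x n; multiplied by W^{1/2} X it
/// yields X'WX without materializing the n x n weight matrix.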
#[allow(clippy::needless_range_loop)]
fn build_xtw(x_aug: &Matrix<f32>, weights: &[f32], n: usize, p: usize) -> Result<Matrix<f32>> {
let mut data = Vec::with_capacity((p + 1) * n);
for j in 0..=p {
for i in 0..n {
data.push(x_aug.get(i, j) * weights[i].sqrt());
}
}
Matrix::from_vec(p + 1, n, data)
.map_err(|e| AprenderError::Other(format!("X'W matrix error: {e}")))
}
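/// Builds W^{1/2} X with shape n x (p+1).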
#[allow(clippy::needless_range_loop)]
fn build_wx(x_aug: &Matrix<f32>, weights: &[f32], n: usize, p: usize) -> Result<Matrix<f32>> {
let mut data = Vec::with_capacity(n * (p + 1));
for i in 0..n {
for j in 0..=p {
data.push(weights[i].sqrt() * x_aug.get(i, j));
}
}
Matrix::from_vec(n, p + 1, data)
.map_err(|e| AprenderError::Other(format!("WX matrix error: {e}")))
}
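/// Adds a scale-aware ridge term to the diagonal of X'WX so the Cholesky
/// factorization succeeds even when the weighted design is near singular.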
fn add_ridge_regularization(xtwx: &mut Matrix<f32>, p: usize) {
let max_diag = (0..=p)
.map(|i| xtwx.get(i, i).abs())
.fold(0.0_f32, f32::max);
let ridge = (max_diag * 1e-6).max(1e-8);
for i in 0..=p {
let old_val = xtwx.get(i, i);
xtwx.set(i, i, old_val + ridge);
}
}
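/// Computes eta_i = intercept + x_i . beta for every row of `x`.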
#[allow(clippy::needless_range_loop)]
fn compute_linear_predictor(
x: &Matrix<f32>,
beta: &[f32],
intercept: f32,
n: usize,
p: usize,
) -> Vec<f32> {
let mut eta = Vec::with_capacity(n);
for i in 0..n {
let mut val = intercept;
for j in 0..p {
val += x.get(i, j) * beta[j];
}
eta.push(val);
}
eta
}
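/// Returns the largest absolute change across the intercept and all
/// coefficients; IRLS stops once this falls below the tolerance.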
fn compute_max_change(
beta_old: &[f32],
beta_new: &[f32],
intercept_old: f32,
intercept_new: f32,
) -> f32 {
let mut max_change = (intercept_new - intercept_old).abs();
for (old, new) in beta_old.iter().zip(beta_new) {
max_change = max_change.max((new - old).abs());
}
max_change
}
#[cfg(test)]
#[path = "glm_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "tests_glm_contract.rs"]
mod tests_glm_contract;