use crate::families::penalized_vector_glm::{PenalizedVectorGlmInputs, fit_penalized_vector_glm};
use crate::families::vector_response::VectorLikelihood;
use crate::model_types::EstimationError;
use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
/// Borrowed inputs for [`fit_penalized_binomial_multi`].
///
/// Below, `N` is the number of rows of `design` and `K` the number of
/// columns of `y`. The shape and range requirements listed on each field are
/// the ones `fit_penalized_binomial_multi` checks before fitting; everything
/// else is forwarded unchanged to `fit_penalized_vector_glm`.
#[derive(Debug, Clone)]
pub struct BinomialMultiFitInputs<'a> {
/// Design matrix with one row per observation (`N` rows).
pub design: ArrayView2<'a, f64>,
/// Responses, shape `(N, K)` with `K ≥ 1`; every entry must be finite and
/// lie in `[0, 1]`.
pub y: ArrayView2<'a, f64>,
/// Penalty matrix, forwarded to the generic fitter without validation here.
/// NOTE(review): presumably square with one row/column per design column
/// (the unit tests pass an identity of that size) — confirm against
/// `fit_penalized_vector_glm`.
pub penalty: ArrayView2<'a, f64>,
/// Penalty strengths; must have length `K`.
pub lambdas: ArrayView1<'a, f64>,
/// Optional per-observation weights of length `N`, each finite and `≥ 0`.
/// `None` behaves as a weight of `1.0` for every row.
pub row_weights: Option<ArrayView1<'a, f64>>,
/// Optional curvature override of shape `(N, K, K)`. Every off-diagonal
/// entry (`a != b`) must be exactly zero. How the fitter consumes the
/// override is defined by `fit_penalized_vector_glm`.
pub fisher_w_override: Option<ArrayView3<'a, f64>>,
/// Iteration limit, forwarded unchanged to the generic fitter.
pub max_iter: usize,
/// Convergence tolerance, forwarded unchanged to the generic fitter.
/// NOTE(review): the quantity it is compared against is not visible here.
pub tol: f64,
}
/// Owned results returned by [`fit_penalized_binomial_multi`].
#[derive(Debug, Clone)]
pub struct BinomialMultiFitOutputs {
/// Coefficients exactly as returned by `fit_penalized_vector_glm`.
/// NOTE(review): layout (presumably one column per response) is not
/// visible in this file — confirm against the generic fitter.
pub coefficients: Array2<f64>,
/// Logistic function applied element-wise to the fitter's final linear
/// predictor, so it has the same shape as that predictor.
pub fitted_probabilities: Array2<f64>,
/// Iteration count reported by the generic fitter.
pub iterations: usize,
/// Convergence flag reported by the generic fitter.
pub converged: bool,
/// `-log_likelihood + penalty_term`, both taken from the generic fitter.
pub penalized_neg_log_likelihood: f64,
/// `-2 × log_likelihood` as reported by the generic fitter.
/// NOTE(review): no saturated-model term is subtracted, so for fractional
/// responses this is not the classical binomial deviance.
pub deviance: f64,
}
/// Logistic function `1 / (1 + exp(-eta))`, evaluated without overflow.
///
/// `exp` is only ever taken of `-|eta|`, which is never positive, so the
/// intermediate stays in `[0, 1]` for every finite input. The numerator is
/// `1` on the non-negative side and the decayed exponential on the negative
/// side, which are the two algebraically equivalent forms of the logistic.
#[inline]
fn sigmoid_stable(eta: f64) -> f64 {
    let decay = (-(eta.abs())).exp();
    let numerator = if eta >= 0.0 { 1.0 } else { decay };
    numerator / (1.0 + decay)
}
/// Likelihood object passed to `fit_penalized_vector_glm` by
/// `fit_penalized_binomial_multi`; its only state is the optional
/// per-observation weighting.
struct BinomialMultiLikelihood {
/// Owned copy of the caller's per-row weights, indexed by observation row.
/// `None` means every row is weighted `1.0`.
row_weights: Option<Array1<f64>>,
}
impl BinomialMultiLikelihood {
    /// Weight applied to observation row `n`: the stored weight when a weight
    /// vector is present, `1.0` otherwise.
    ///
    /// Panics if weights are present and `n` is out of bounds for them.
    #[inline]
    fn row_weight(&self, n: usize) -> f64 {
        match &self.row_weights {
            Some(weights) => weights[n],
            None => 1.0,
        }
    }
}
/// Row-weighted binomial likelihood with a logistic mean, where every
/// response column is treated independently of the others.
///
/// All methods take `eta` and `y` of identical shape `(N, K)` and panic
/// otherwise, so a mis-shaped response can never be silently read as a prefix
/// of a larger array.
impl VectorLikelihood for BinomialMultiLikelihood {
    /// Weighted log-likelihood
    /// `Σ w_row · [ y·ln(μ) + (1 − y)·ln(1 − μ) ]` with `μ = sigmoid(eta)`.
    ///
    /// `μ` is clamped to `[1e-12, 1 − 1e-12]` before taking logarithms so a
    /// saturated predictor yields a large finite value instead of `-inf`.
    /// Terms are accumulated row by row, column by column.
    ///
    /// # Panics
    /// If `eta` and `y` differ in shape.
    fn log_lik(&self, eta: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> f64 {
        assert_eq!(eta.dim(), y.dim(), "y must match eta shape (N, K)");
        let (n, k) = eta.dim();
        let mut total = 0.0_f64;
        for row in 0..n {
            let w = self.row_weight(row);
            for col in 0..k {
                let mu = sigmoid_stable(eta[[row, col]]).clamp(1.0e-12, 1.0 - 1.0e-12);
                let yv = y[[row, col]];
                total += w * (yv * mu.ln() + (1.0 - yv) * (1.0 - mu).ln());
            }
        }
        total
    }

    /// Element-wise derivative of the log-likelihood with respect to `eta`:
    /// `w_row · (y − μ)` using the unclamped `μ = sigmoid(eta)`.
    ///
    /// # Panics
    /// If `eta` and `y` differ in shape.
    fn grad_eta(&self, eta: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2<f64> {
        assert_eq!(eta.dim(), y.dim(), "y must match eta shape (N, K)");
        Array2::from_shape_fn(eta.dim(), |(row, col)| {
            let mu = sigmoid_stable(eta[[row, col]]);
            self.row_weight(row) * (y[[row, col]] - mu)
        })
    }

    /// Element-wise curvature `w_row · μ · (1 − μ)` with `μ = sigmoid(eta)`.
    ///
    /// The values do not depend on `y`; it is only used for the shape check.
    ///
    /// # Panics
    /// If `eta` and `y` differ in shape.
    fn hess_diag(&self, eta: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2<f64> {
        assert_eq!(eta.dim(), y.dim(), "y must match eta shape (N, K)");
        Array2::from_shape_fn(eta.dim(), |(row, col)| {
            let mu = sigmoid_stable(eta[[row, col]]);
            self.row_weight(row) * mu * (1.0 - mu)
        })
    }

    /// Per-row `K × K` curvature blocks of shape `(N, K, K)`.
    ///
    /// Each block carries the corresponding row of [`Self::hess_diag`] on its
    /// diagonal and zeros elsewhere, because the columns do not interact.
    ///
    /// # Panics
    /// If `eta` and `y` differ in shape (checked inside `hess_diag`).
    fn hess_block(&self, eta: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array3<f64> {
        let diag = self.hess_diag(eta, y);
        let (n, k) = diag.dim();
        let mut blocks = Array3::<f64>::zeros((n, k, k));
        for row in 0..n {
            for col in 0..k {
                blocks[[row, col, col]] = diag[[row, col]];
            }
        }
        blocks
    }
}
/// Checks every shape and value precondition of
/// [`fit_penalized_binomial_multi`] and reports the first violation found.
///
/// Checks run in this order: `y` row count, `K ≥ 1`, `lambdas` length,
/// `fisher_w_override` shape and entries, `row_weights` length and entries,
/// and finally the range of every response.
fn validate_binomial_multi_inputs(
    inputs: &BinomialMultiFitInputs<'_>,
) -> Result<(), EstimationError> {
    let n_obs = inputs.design.nrows();
    let (y_rows, k) = inputs.y.dim();
    if y_rows != n_obs {
        crate::bail_invalid_estim!(
            "fit_penalized_binomial_multi: y rows {y_rows} ≠ design rows {n_obs}"
        );
    }
    if k == 0 {
        crate::bail_invalid_estim!(
            "fit_penalized_binomial_multi: y must have at least one column (got K=0)"
        );
    }
    if inputs.lambdas.len() != k {
        crate::bail_invalid_estim!(
            "fit_penalized_binomial_multi: lambdas length {} ≠ K = {k}",
            inputs.lambdas.len()
        );
    }
    if let Some(fw) = inputs.fisher_w_override.as_ref() {
        if fw.dim() != (n_obs, k, k) {
            crate::bail_invalid_estim!(
                "fit_penalized_binomial_multi: fisher_w_override shape {:?} ≠ (N, K, K) = ({n_obs}, {k}, {k})",
                fw.dim()
            );
        }
        for ((n_idx, a, b), &v) in fw.indexed_iter() {
            if a != b {
                // `NaN != 0.0` is true, so non-finite off-diagonals are rejected here too.
                if v != 0.0 {
                    crate::bail_invalid_estim!(
                        "fit_penalized_binomial_multi: fisher_w_override[{n_idx},{a},{b}] must be zero \
                         (independent columns have a row-diagonal Fisher block); got {v}"
                    );
                }
            } else if !(v.is_finite() && v >= 0.0) {
                // A diagonal Fisher block is only a valid curvature when every
                // diagonal weight is a finite, non-negative number.
                crate::bail_invalid_estim!(
                    "fit_penalized_binomial_multi: fisher_w_override[{n_idx},{a},{a}] must be finite and ≥ 0 (got {v})"
                );
            }
        }
    }
    if let Some(w) = inputs.row_weights.as_ref() {
        if w.len() != n_obs {
            crate::bail_invalid_estim!(
                "fit_penalized_binomial_multi: row_weights length {} ≠ N = {n_obs}",
                w.len()
            );
        }
        for (i, &v) in w.iter().enumerate() {
            if !(v.is_finite() && v >= 0.0) {
                crate::bail_invalid_estim!(
                    "fit_penalized_binomial_multi: row_weights[{i}] must be finite and ≥ 0 (got {v})"
                );
            }
        }
    }
    for ((i, j), &v) in inputs.y.indexed_iter() {
        if !(v.is_finite() && (0.0..=1.0).contains(&v)) {
            crate::bail_invalid_estim!(
                "fit_penalized_binomial_multi: y[{i},{j}] must be a binomial proportion in [0,1] (got {v})"
            );
        }
    }
    Ok(())
}

/// Fits `K` independent penalized binomial (logistic-mean) responses that
/// share one design matrix.
///
/// The inputs are validated, wrapped in a row-weighted binomial likelihood,
/// and handed to `fit_penalized_vector_glm`; the optimisation itself happens
/// there. The returned probabilities are the logistic function of the
/// fitter's final linear predictor.
///
/// # Errors
/// Returns an invalid-input error when, with `N = design.nrows()` and
/// `K = y.ncols()`:
/// - `y` does not have `N` rows, or `K == 0`;
/// - `lambdas` does not have length `K`;
/// - `fisher_w_override` is present and is not `(N, K, K)`, has a non-zero
///   off-diagonal entry, or has a diagonal entry that is not finite and `≥ 0`;
/// - `row_weights` is present and is not of length `N`, or has an entry that
///   is not finite and `≥ 0`;
/// - any response is not finite or lies outside `[0, 1]`.
///
/// Any error produced by `fit_penalized_vector_glm` is propagated unchanged.
pub fn fit_penalized_binomial_multi(
    inputs: BinomialMultiFitInputs<'_>,
) -> Result<BinomialMultiFitOutputs, EstimationError> {
    validate_binomial_multi_inputs(&inputs)?;

    let BinomialMultiFitInputs {
        design,
        y,
        penalty,
        lambdas,
        row_weights,
        fisher_w_override,
        max_iter,
        tol,
    } = inputs;

    // The likelihood object must own its weights; the view only borrows them.
    let likelihood = BinomialMultiLikelihood {
        row_weights: row_weights.map(|w| w.to_owned()),
    };
    let fit = fit_penalized_vector_glm(
        PenalizedVectorGlmInputs {
            design,
            y,
            penalty,
            lambdas,
            fisher_w_override,
            max_iter,
            tol,
        },
        &likelihood,
        "fit_penalized_binomial_multi",
    )?;

    let fitted_probabilities = fit.eta.mapv(sigmoid_stable);
    Ok(BinomialMultiFitOutputs {
        coefficients: fit.coefficients,
        fitted_probabilities,
        iterations: fit.iterations,
        converged: fit.converged,
        penalized_neg_log_likelihood: -fit.log_likelihood + fit.penalty_term,
        deviance: -2.0 * fit.log_likelihood,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array3;

    /// Builds the small deterministic problem shared by all tests:
    /// 12 observations, an intercept plus one `sin` covariate, two response
    /// columns with alternating 0/1 labels, an identity penalty and
    /// `lambda = 0.5` for each column.
    fn toy_inputs() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array1<f64>) {
        let (n, p, k) = (12, 2, 2);
        let design = Array2::<f64>::from_shape_fn((n, p), |(i, j)| match j {
            0 => 1.0,
            _ => ((i + 1) as f64).sin(),
        });
        let y = Array2::<f64>::from_shape_fn((n, k), |(i, a)| match (i + a) % 2 {
            0 => 1.0,
            _ => 0.0,
        });
        let penalty = Array2::<f64>::eye(p);
        let lambdas = Array1::<f64>::from_elem(k, 0.5);
        (design, y, penalty, lambdas)
    }

    /// Runs the fitter with no row weights and `tol = 1e-9`, leaving the
    /// response, the optional curvature override and the iteration cap to the
    /// individual test.
    fn run_fit(
        design: &Array2<f64>,
        y: &Array2<f64>,
        penalty: &Array2<f64>,
        lambdas: &Array1<f64>,
        fisher_w_override: Option<&Array3<f64>>,
        max_iter: usize,
    ) -> Result<BinomialMultiFitOutputs, EstimationError> {
        fit_penalized_binomial_multi(BinomialMultiFitInputs {
            design: design.view(),
            y: y.view(),
            penalty: penalty.view(),
            lambdas: lambdas.view(),
            row_weights: None,
            fisher_w_override: fisher_w_override.map(|w| w.view()),
            max_iter,
            tol: 1.0e-9,
        })
    }

    /// Two identical fits without an override must agree coefficient for
    /// coefficient, exactly.
    #[test]
    fn fisher_override_none_reproduces_analytic_bit_for_bit() {
        let (design, y, penalty, lambdas) = toy_inputs();
        let fit_once = || {
            run_fit(&design, &y, &penalty, &lambdas, None, 50).expect("analytic fit must succeed")
        };
        let base = fit_once();
        let again = fit_once();
        for (a, b) in base.coefficients.iter().zip(again.coefficients.iter()) {
            assert_eq!(a, b, "None override must be deterministic");
        }
    }

    /// A response above 1 or below 0 is refused with the range message.
    #[test]
    fn out_of_range_response_is_rejected() {
        let (design, y, penalty, lambdas) = toy_inputs();
        let cases = [
            ((0, 0), 2.0, "out-of-range response must error"),
            ((1, 1), -0.5, "negative response must error"),
        ];
        for ((row, col), value, why) in cases {
            let mut corrupted = y.clone();
            corrupted[[row, col]] = value;
            let err = run_fit(&design, &corrupted, &penalty, &lambdas, None, 50).expect_err(why);
            assert!(format!("{err}").contains("binomial proportion in [0,1]"));
        }
    }

    /// An override whose trailing dimensions are `K + 1` is refused with the
    /// shape message.
    #[test]
    fn fisher_override_shape_mismatch_is_rejected() {
        let (design, y, penalty, lambdas) = toy_inputs();
        let n = design.nrows();
        let k = y.ncols();
        let bad = Array3::<f64>::zeros((n, k + 1, k + 1));
        let err = run_fit(&design, &y, &penalty, &lambdas, Some(&bad), 50)
            .expect_err("mismatched override shape must error");
        assert!(format!("{err}").contains("fisher_w_override shape"));
    }

    /// With a single iteration, a constant diagonal override of `0.25 * 4.0`
    /// must produce coefficients that differ from the un-overridden fit.
    #[test]
    fn fisher_override_replaces_curvature_diagonal() {
        let (design, y, penalty, lambdas) = toy_inputs();
        let n = design.nrows();
        let k = y.ncols();
        let over = Array3::<f64>::from_shape_fn((n, k, k), |(_, a, b)| {
            if a == b { 0.25 * 4.0 } else { 0.0 }
        });
        let scaled = run_fit(&design, &y, &penalty, &lambdas, Some(&over), 1)
            .expect("override fit must succeed");
        let analytic = run_fit(&design, &y, &penalty, &lambdas, None, 1)
            .expect("analytic fit must succeed");
        let differs = scaled
            .coefficients
            .iter()
            .zip(analytic.coefficients.iter())
            .any(|(a, b)| (a - b).abs() > 1.0e-6);
        assert!(differs, "scaled curvature override must change the step");
    }
}