use crate::faer_ndarray::{FaerArrayView, array2_to_matmut, factorize_symmetricwith_fallback};
use crate::solver::estimate::EstimationError;
use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3};
#[derive(Debug, Clone)]
pub struct BinomialMultiFitInputs<'a> {
pub design: ArrayView2<'a, f64>,
pub y: ArrayView2<'a, f64>,
pub penalty: ArrayView2<'a, f64>,
pub lambdas: ArrayView1<'a, f64>,
pub row_weights: Option<ArrayView1<'a, f64>>,
pub fisher_w_override: Option<ArrayView3<'a, f64>>,
pub max_iter: usize,
pub tol: f64,
}
#[derive(Debug, Clone)]
pub struct BinomialMultiFitOutputs {
pub coefficients: Array2<f64>,
pub fitted_probabilities: Array2<f64>,
pub iterations: usize,
pub iterations_per_response: Vec<usize>,
pub converged: bool,
pub penalized_neg_log_likelihood: f64,
pub deviance: f64,
}
#[inline]
fn sigmoid_stable(eta: f64) -> f64 {
if eta >= 0.0 {
let e = (-eta).exp();
1.0 / (1.0 + e)
} else {
let e = eta.exp();
e / (1.0 + e)
}
}
fn binomial_log_lik_column(
eta_col: ArrayView1<'_, f64>,
y_col: ArrayView1<'_, f64>,
row_weights: Option<ArrayView1<'_, f64>>,
) -> f64 {
let mut acc = 0.0_f64;
for (i, &eta_i) in eta_col.iter().enumerate() {
let mu = sigmoid_stable(eta_i).clamp(1.0e-12, 1.0 - 1.0e-12);
let y = y_col[i];
let w = row_weights.as_ref().map(|w| w[i]).unwrap_or(1.0);
acc += w * (y * mu.ln() + (1.0 - y) * (1.0 - mu).ln());
}
acc
}
pub fn fit_penalized_binomial_multi(
inputs: BinomialMultiFitInputs<'_>,
) -> Result<BinomialMultiFitOutputs, EstimationError> {
let BinomialMultiFitInputs {
design,
y,
penalty,
lambdas,
row_weights,
fisher_w_override,
max_iter,
tol,
} = inputs;
let n_obs = design.nrows();
let p = design.ncols();
if n_obs == 0 || p == 0 {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: design must be nonempty (got {n_obs}x{p})"
);
}
let (y_rows, k) = y.dim();
if y_rows != n_obs {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: y rows {y_rows} ≠ design rows {n_obs}"
);
}
if k == 0 {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: y must have at least one column (got K=0)"
);
}
if lambdas.len() != k {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: lambdas length {} ≠ K = {k}",
lambdas.len()
);
}
if penalty.dim() != (p, p) {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: penalty shape {:?} ≠ (P, P) = ({p}, {p})",
penalty.dim()
);
}
for (i, &v) in lambdas.iter().enumerate() {
if !(v.is_finite() && v >= 0.0) {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: lambdas[{i}] must be finite and ≥ 0 (got {v})"
);
}
}
if let Some(fw) = fisher_w_override.as_ref() {
if fw.dim() != (n_obs, k, k) {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: fisher_w_override shape {:?} ≠ (N, K, K) = ({n_obs}, {k}, {k})",
fw.dim()
);
}
}
if let Some(w) = row_weights.as_ref() {
if w.len() != n_obs {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: row_weights length {} ≠ N = {n_obs}",
w.len()
);
}
for (i, &v) in w.iter().enumerate() {
if !(v.is_finite() && v >= 0.0) {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: row_weights[{i}] must be finite and ≥ 0 (got {v})"
);
}
}
}
for ((i, j), &v) in y.indexed_iter() {
if !v.is_finite() {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: y[{i},{j}] must be finite (got {v})"
);
}
}
for ((i, j), &v) in design.indexed_iter() {
if !v.is_finite() {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: design[{i},{j}] must be finite (got {v})"
);
}
}
let mut beta = Array2::<f64>::zeros((p, k));
let mut eta = Array2::<f64>::zeros((n_obs, k));
let mut fitted = Array2::<f64>::zeros((n_obs, k));
let mut iterations_per_response = vec![0_usize; k];
let mut all_converged = true;
for a in 0..k {
let lambda_a = lambdas[a];
let y_col = y.column(a);
let mut converged_a = false;
let mut last_objective_a = f64::INFINITY;
for iter in 0..max_iter {
iterations_per_response[a] = iter + 1;
for row in 0..n_obs {
let mut eta_val = 0.0_f64;
for i in 0..p {
eta_val += design[[row, i]] * beta[[i, a]];
}
eta[[row, a]] = eta_val;
fitted[[row, a]] = sigmoid_stable(eta_val);
}
let mut hess_diag = Array1::<f64>::zeros(n_obs);
let mut grad = Array1::<f64>::zeros(p);
for row in 0..n_obs {
let mu = fitted[[row, a]];
let w = row_weights.as_ref().map(|w| w[row]).unwrap_or(1.0);
hess_diag[row] = match fisher_w_override.as_ref() {
Some(fw) => fw[[row, a, a]],
None => w * mu * (1.0 - mu),
};
let resid = w * (mu - y_col[row]);
if resid != 0.0 {
for i in 0..p {
grad[i] += design[[row, i]] * resid;
}
}
}
if lambda_a != 0.0 {
let beta_col = beta.column(a);
for i in 0..p {
let mut s_beta_i = 0.0_f64;
for j in 0..p {
s_beta_i += penalty[[i, j]] * beta_col[j];
}
grad[i] += lambda_a * s_beta_i;
}
}
let mut hessian = Array2::<f64>::zeros((p, p));
for row in 0..n_obs {
let w = hess_diag[row];
if w == 0.0 {
continue;
}
for i in 0..p {
let xi = design[[row, i]];
if xi == 0.0 {
continue;
}
let scaled = w * xi;
for j in 0..p {
hessian[[i, j]] += scaled * design[[row, j]];
}
}
}
if lambda_a != 0.0 {
for i in 0..p {
for j in 0..p {
hessian[[i, j]] += lambda_a * penalty[[i, j]];
}
}
}
for i in 0..p {
for j in (i + 1)..p {
let avg = 0.5 * (hessian[[i, j]] + hessian[[j, i]]);
hessian[[i, j]] = avg;
hessian[[j, i]] = avg;
}
}
let factor = factorize_symmetricwith_fallback(
FaerArrayView::new(&hessian).as_ref(),
Side::Lower,
)
.map_err(|err| {
EstimationError::InvalidInput(format!(
"fit_penalized_binomial_multi: Hessian factorization failed at response {a}, iter {iter}: {err}"
))
})?;
let mut rhs = Array2::<f64>::zeros((p, 1));
for i in 0..p {
rhs[[i, 0]] = -grad[i];
}
{
let rhs_view = array2_to_matmut(&mut rhs);
factor.solve_in_place(rhs_view);
}
let mut delta = Array1::<f64>::zeros(p);
for i in 0..p {
delta[i] = rhs[[i, 0]];
}
if delta.iter().any(|v| !v.is_finite()) {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: Newton step is non-finite at response {a}, iter {iter}"
);
}
let evaluate_objective = |beta_trial_col: &Array1<f64>| -> f64 {
let mut eta_trial = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut v = 0.0_f64;
for i in 0..p {
v += design[[row, i]] * beta_trial_col[i];
}
eta_trial[row] = v;
}
let ll =
binomial_log_lik_column(eta_trial.view(), y_col.view(), row_weights);
let mut pen = 0.0_f64;
if lambda_a != 0.0 {
let mut quad = 0.0_f64;
for i in 0..p {
let mut s_beta_i = 0.0_f64;
for j in 0..p {
s_beta_i += penalty[[i, j]] * beta_trial_col[j];
}
quad += beta_trial_col[i] * s_beta_i;
}
pen = 0.5 * lambda_a * quad;
}
-ll + pen
};
if iter == 0 {
let beta_init: Array1<f64> = beta.column(a).to_owned();
last_objective_a = evaluate_objective(&beta_init);
if !last_objective_a.is_finite() {
crate::bail_invalid_estim!(
"fit_penalized_binomial_multi: non-finite objective at response {a}, β = 0"
);
}
}
let propose_beta = |alpha: f64| -> Array1<f64> {
let mut out: Array1<f64> = beta.column(a).to_owned();
for i in 0..p {
out[i] += alpha * delta[i];
}
out
};
let mut alpha = 1.0_f64;
let mut accepted_beta_a = propose_beta(alpha);
let mut new_objective = evaluate_objective(&accepted_beta_a);
let mut backtrack = 0_usize;
while (!new_objective.is_finite()
|| new_objective > last_objective_a + 1.0e-12)
&& backtrack < 8
{
alpha *= 0.5;
accepted_beta_a = propose_beta(alpha);
new_objective = evaluate_objective(&accepted_beta_a);
backtrack += 1;
}
let mut step_norm_sq = 0.0_f64;
let mut beta_norm_sq = 0.0_f64;
for i in 0..p {
let d = accepted_beta_a[i] - beta[[i, a]];
step_norm_sq += d * d;
let v = accepted_beta_a[i];
beta_norm_sq += v * v;
}
for i in 0..p {
beta[[i, a]] = accepted_beta_a[i];
}
last_objective_a = new_objective;
let step_norm = step_norm_sq.sqrt();
let beta_norm = beta_norm_sq.sqrt();
if step_norm <= tol * (1.0 + beta_norm) {
converged_a = true;
break;
}
}
if !converged_a {
all_converged = false;
}
}
for a in 0..k {
for row in 0..n_obs {
let mut v = 0.0_f64;
for i in 0..p {
v += design[[row, i]] * beta[[i, a]];
}
eta[[row, a]] = v;
fitted[[row, a]] = sigmoid_stable(v);
}
}
let mut log_lik = 0.0_f64;
for a in 0..k {
log_lik += binomial_log_lik_column(eta.column(a), y.column(a), row_weights);
}
let mut pen = 0.0_f64;
for a in 0..k {
let la = lambdas[a];
if la == 0.0 {
continue;
}
let beta_col = beta.column(a);
let mut quad = 0.0_f64;
for i in 0..p {
let mut s_beta_i = 0.0_f64;
for j in 0..p {
s_beta_i += penalty[[i, j]] * beta_col[j];
}
quad += beta_col[i] * s_beta_i;
}
pen += 0.5 * la * quad;
}
let penalized_neg_log_likelihood = -log_lik + pen;
let deviance = -2.0 * log_lik;
let iterations: usize = iterations_per_response.iter().sum();
Ok(BinomialMultiFitOutputs {
coefficients: beta,
fitted_probabilities: fitted,
iterations,
iterations_per_response,
converged: all_converged,
penalized_neg_log_likelihood,
deviance,
})
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array3;
fn toy_inputs() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array1<f64>) {
let n = 12;
let p = 2;
let k = 2;
let design = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
if j == 0 { 1.0 } else { ((i + 1) as f64).sin() }
});
let y = Array2::<f64>::from_shape_fn((n, k), |(i, a)| {
if (i + a) % 2 == 0 { 1.0 } else { 0.0 }
});
let penalty = Array2::<f64>::eye(p);
let lambdas = Array1::<f64>::from_elem(k, 0.5);
(design, y, penalty, lambdas)
}
#[test]
fn fisher_override_none_reproduces_analytic_bit_for_bit() {
let (design, y, penalty, lambdas) = toy_inputs();
let base = fit_penalized_binomial_multi(BinomialMultiFitInputs {
design: design.view(),
y: y.view(),
penalty: penalty.view(),
lambdas: lambdas.view(),
row_weights: None,
fisher_w_override: None,
max_iter: 50,
tol: 1.0e-9,
})
.expect("analytic fit must succeed");
let again = fit_penalized_binomial_multi(BinomialMultiFitInputs {
design: design.view(),
y: y.view(),
penalty: penalty.view(),
lambdas: lambdas.view(),
row_weights: None,
fisher_w_override: None,
max_iter: 50,
tol: 1.0e-9,
})
.expect("analytic fit must succeed");
for (a, b) in base.coefficients.iter().zip(again.coefficients.iter()) {
assert_eq!(a, b, "None override must be deterministic");
}
}
#[test]
fn fisher_override_shape_mismatch_is_rejected() {
let (design, y, penalty, lambdas) = toy_inputs();
let n = design.nrows();
let k = y.ncols();
let bad = Array3::<f64>::zeros((n, k + 1, k + 1));
let err = fit_penalized_binomial_multi(BinomialMultiFitInputs {
design: design.view(),
y: y.view(),
penalty: penalty.view(),
lambdas: lambdas.view(),
row_weights: None,
fisher_w_override: Some(bad.view()),
max_iter: 50,
tol: 1.0e-9,
})
.expect_err("mismatched override shape must error");
assert!(format!("{err}").contains("fisher_w_override shape"));
}
#[test]
fn fisher_override_replaces_curvature_diagonal() {
let (design, y, penalty, lambdas) = toy_inputs();
let n = design.nrows();
let k = y.ncols();
let mut over = Array3::<f64>::zeros((n, k, k));
for row in 0..n {
for a in 0..k {
over[[row, a, a]] = 0.25 * 4.0; }
}
let scaled = fit_penalized_binomial_multi(BinomialMultiFitInputs {
design: design.view(),
y: y.view(),
penalty: penalty.view(),
lambdas: lambdas.view(),
row_weights: None,
fisher_w_override: Some(over.view()),
max_iter: 1,
tol: 1.0e-9,
})
.expect("override fit must succeed");
let analytic = fit_penalized_binomial_multi(BinomialMultiFitInputs {
design: design.view(),
y: y.view(),
penalty: penalty.view(),
lambdas: lambdas.view(),
row_weights: None,
fisher_w_override: None,
max_iter: 1,
tol: 1.0e-9,
})
.expect("analytic fit must succeed");
let differs = scaled
.coefficients
.iter()
.zip(analytic.coefficients.iter())
.any(|(a, b)| (a - b).abs() > 1.0e-6);
assert!(differs, "scaled curvature override must change the step");
}
}