use crate::faer_ndarray::{FaerArrayView, array2_to_matmut, factorize_symmetricwith_fallback};
use crate::families::custom_family::{BlockwiseFitOptions, fit_custom_family_with_rho_prior};
use crate::families::multinomial_reml::MultinomialFamily;
use crate::families::vector_response::{MultinomialLogitLikelihood, VectorLikelihood};
use crate::inference::data::EncodedDataset;
use crate::inference::formula_dsl::parse_formula;
use crate::inference::model::ColumnKindTag;
use crate::pirls::dense_block_xtwx;
use crate::resource::ProblemHints;
use crate::solver::estimate::EstimationError;
use crate::solver::workflow::{
FitConfig, build_termspec_with_geometry_and_overrides, resolved_resource_policy,
};
use crate::terms::smooth::{
TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
weighted_blockwise_penalty_sum,
};
use crate::terms::term_builder::resolve_role_col;
use crate::types::ResponseColumnKind;
use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Inputs to [`fit_penalized_multinomial`].
///
/// Shapes below use `N` = observations, `P` = coefficients per class and
/// `K` = classes. The fitter rejects inputs whose shapes disagree.
#[derive(Debug, Clone)]
pub struct MultinomialFitInputs<'a> {
    /// Design matrix `X`, shape `N × P`; every entry must be finite.
    pub design: ArrayView2<'a, f64>,
    /// Response indicators, shape `N × K` with `K ≥ 2`; entries must be finite.
    /// Expected to be one-hot encoded (only finiteness is checked).
    pub y_one_hot: ArrayView2<'a, f64>,
    /// Penalty matrix `S`, shape `P × P`, applied to the coefficients of every
    /// active class.
    pub penalty: ArrayView2<'a, f64>,
    /// One smoothing parameter per active class (length `K - 1`); each must be
    /// finite and `≥ 0`. A zero disables the penalty for that class.
    pub lambdas: ArrayView1<'a, f64>,
    /// Optional per-observation weights (length `N`, finite, `≥ 0`) forwarded to
    /// the likelihood via `with_row_weights`.
    pub row_weights: Option<ArrayView1<'a, f64>>,
    /// Optional replacement for the analytic Fisher blocks, shape
    /// `N × (K-1) × (K-1)`. When present it is used verbatim at every iteration
    /// (NOTE(review): row weights are not re-applied to it — confirm callers
    /// pre-weight it if needed).
    pub fisher_w_override: Option<ArrayView3<'a, f64>>,
    /// Maximum number of Newton iterations.
    pub max_iter: usize,
    /// Relative step tolerance: the fit is converged once
    /// `‖Δβ‖ ≤ tol · (1 + ‖β‖)`.
    pub tol: f64,
}
/// Result of [`fit_penalized_multinomial`].
#[derive(Debug, Clone)]
pub struct MultinomialFitOutputs {
    /// Fitted coefficients, shape `P × (K-1)`; column `a` belongs to active class `a`.
    pub coefficients_active: Array2<f64>,
    /// Class probabilities at the final coefficients, as returned by
    /// `MultinomialLogitLikelihood::probabilities` (presumably `N × K` — confirm).
    pub fitted_probabilities: Array2<f64>,
    /// Number of Newton iterations started.
    pub iterations: usize,
    /// Whether the relative step-size criterion was met within `max_iter` iterations.
    pub converged: bool,
    /// `-ℓ(β) + Σ_a ½ λ_a β_aᵀ S β_a` at the final coefficients.
    pub penalized_neg_log_likelihood: f64,
    /// `-2 ℓ(β)` at the final coefficients.
    pub deviance: f64,
}
pub fn fit_penalized_multinomial(
inputs: MultinomialFitInputs<'_>,
) -> Result<MultinomialFitOutputs, EstimationError> {
let MultinomialFitInputs {
design,
y_one_hot,
penalty,
lambdas,
row_weights,
fisher_w_override,
max_iter,
tol,
} = inputs;
let n_obs = design.nrows();
let p = design.ncols();
if n_obs == 0 || p == 0 {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: design must be nonempty (got {n_obs}x{p})"
);
}
let (y_rows, k) = y_one_hot.dim();
if y_rows != n_obs {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: y rows {y_rows} ≠ design rows {n_obs}"
);
}
if k < 2 {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: need at least 2 classes (got K={k})"
);
}
let m = k - 1;
if lambdas.len() != m {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: lambdas length {} ≠ K-1 = {m}",
lambdas.len()
);
}
if penalty.dim() != (p, p) {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: penalty shape {:?} ≠ (P, P) = ({p}, {p})",
penalty.dim()
);
}
for (i, &v) in lambdas.iter().enumerate() {
if !(v.is_finite() && v >= 0.0) {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: lambdas[{i}] must be finite and ≥ 0 (got {v})"
);
}
}
if let Some(fw) = fisher_w_override.as_ref() {
if fw.dim() != (n_obs, m, m) {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: fisher_w_override shape {:?} ≠ (N, K-1, K-1) = ({n_obs}, {m}, {m})",
fw.dim()
);
}
}
if let Some(w) = row_weights.as_ref() {
if w.len() != n_obs {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: row_weights length {} ≠ N = {n_obs}",
w.len()
);
}
for (i, &v) in w.iter().enumerate() {
if !(v.is_finite() && v >= 0.0) {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: row_weights[{i}] must be finite and ≥ 0 (got {v})"
);
}
}
}
for ((i, j), &v) in y_one_hot.indexed_iter() {
if !v.is_finite() {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: y[{i},{j}] must be finite (got {v})"
);
}
}
for ((i, j), &v) in design.indexed_iter() {
if !v.is_finite() {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: design[{i},{j}] must be finite (got {v})"
);
}
}
let mut likelihood = MultinomialLogitLikelihood::with_classes(k)?;
if let Some(w) = row_weights.as_ref() {
likelihood = likelihood.with_row_weights(w.to_owned())?;
}
let mut beta = Array2::<f64>::zeros((p, m));
let mut eta = Array2::<f64>::zeros((n_obs, m));
let mut iterations = 0usize;
let mut converged = false;
let mut last_objective = f64::INFINITY;
let beta_flat_dim = p * m;
for iter in 0..max_iter {
iterations = iter + 1;
for a in 0..m {
let beta_col = beta.column(a);
for row in 0..n_obs {
let mut eta_val = 0.0_f64;
for i in 0..p {
eta_val += design[[row, i]] * beta_col[i];
}
eta[[row, a]] = eta_val;
}
}
let analytic_fisher = fisher_w_override
.as_ref()
.map_or_else(|| Some(likelihood.hess_block(eta.view(), y_one_hot)), |_| None);
let fisher_blocks = match fisher_w_override.as_ref() {
Some(fw) => *fw,
None => analytic_fisher
.as_ref()
.expect("analytic Fisher computed when no override")
.view(),
};
let grad_eta_logl = likelihood.grad_eta(eta.view(), y_one_hot);
let residual_active = grad_eta_logl.mapv(|v| -v);
let mut hessian = dense_block_xtwx(design, fisher_blocks, None)?;
if hessian.nrows() != beta_flat_dim || hessian.ncols() != beta_flat_dim {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: assembled Hessian shape {:?} ≠ ({beta_flat_dim}, {beta_flat_dim})",
hessian.dim()
);
}
for a in 0..m {
let la = lambdas[a];
if la == 0.0 {
continue;
}
let base = a * p;
for i in 0..p {
for j in 0..p {
hessian[[base + i, base + j]] += la * penalty[[i, j]];
}
}
}
let mut grad_flat = Array1::<f64>::zeros(beta_flat_dim);
for a in 0..m {
for i in 0..p {
let mut acc = 0.0_f64;
for row in 0..n_obs {
acc += design[[row, i]] * residual_active[[row, a]];
}
grad_flat[a * p + i] = acc;
}
}
for a in 0..m {
let la = lambdas[a];
if la == 0.0 {
continue;
}
let beta_col = beta.column(a);
for i in 0..p {
let mut s_beta_i = 0.0_f64;
for j in 0..p {
s_beta_i += penalty[[i, j]] * beta_col[j];
}
grad_flat[a * p + i] += la * s_beta_i;
}
}
let factor = factorize_symmetricwith_fallback(
FaerArrayView::new(&hessian).as_ref(),
Side::Lower,
)
.map_err(|err| {
EstimationError::InvalidInput(format!(
"fit_penalized_multinomial: Hessian factorization failed at iter {iter}: {err}"
))
})?;
let mut rhs = Array2::<f64>::zeros((beta_flat_dim, 1));
for i in 0..beta_flat_dim {
rhs[[i, 0]] = -grad_flat[i];
}
{
let rhs_view = array2_to_matmut(&mut rhs);
factor.solve_in_place(rhs_view);
}
let mut delta = Array1::<f64>::zeros(beta_flat_dim);
for i in 0..beta_flat_dim {
delta[i] = rhs[[i, 0]];
}
if delta.iter().any(|v| !v.is_finite()) {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: Newton step is non-finite at iter {iter}"
);
}
let proposed_beta = |alpha: f64| -> Array2<f64> {
let mut out = beta.clone();
for a in 0..m {
for i in 0..p {
out[[i, a]] += alpha * delta[a * p + i];
}
}
out
};
let evaluate_objective = |beta_trial: &Array2<f64>| -> f64 {
let mut eta_trial = Array2::<f64>::zeros((n_obs, m));
for a in 0..m {
let beta_col = beta_trial.column(a);
for row in 0..n_obs {
let mut v = 0.0_f64;
for i in 0..p {
v += design[[row, i]] * beta_col[i];
}
eta_trial[[row, a]] = v;
}
}
let ll = likelihood.log_lik(eta_trial.view(), y_one_hot);
let mut pen = 0.0_f64;
for a in 0..m {
let la = lambdas[a];
if la == 0.0 {
continue;
}
let beta_col = beta_trial.column(a);
let mut quad = 0.0_f64;
for i in 0..p {
let mut s_beta_i = 0.0_f64;
for j in 0..p {
s_beta_i += penalty[[i, j]] * beta_col[j];
}
quad += beta_col[i] * s_beta_i;
}
pen += 0.5 * la * quad;
}
-ll + pen
};
if iter == 0 {
last_objective = evaluate_objective(&beta);
if !last_objective.is_finite() {
crate::bail_invalid_estim!(
"fit_penalized_multinomial: non-finite objective at β = 0"
);
}
}
let mut alpha = 1.0_f64;
let mut accepted_beta = proposed_beta(alpha);
let mut new_objective = evaluate_objective(&accepted_beta);
let mut backtrack = 0usize;
while (!new_objective.is_finite() || new_objective > last_objective + 1.0e-12)
&& backtrack < 8
{
alpha *= 0.5;
accepted_beta = proposed_beta(alpha);
new_objective = evaluate_objective(&accepted_beta);
backtrack += 1;
}
let mut step_norm_sq = 0.0_f64;
let mut beta_norm_sq = 0.0_f64;
for a in 0..m {
for i in 0..p {
let d = accepted_beta[[i, a]] - beta[[i, a]];
step_norm_sq += d * d;
let v = accepted_beta[[i, a]];
beta_norm_sq += v * v;
}
}
beta = accepted_beta;
last_objective = new_objective;
let step_norm = step_norm_sq.sqrt();
let beta_norm = beta_norm_sq.sqrt();
if step_norm <= tol * (1.0 + beta_norm) {
converged = true;
break;
}
}
for a in 0..m {
let beta_col = beta.column(a);
for row in 0..n_obs {
let mut v = 0.0_f64;
for i in 0..p {
v += design[[row, i]] * beta_col[i];
}
eta[[row, a]] = v;
}
}
let fitted_probabilities = likelihood.probabilities(eta.view());
let log_lik = likelihood.log_lik(eta.view(), y_one_hot);
let mut pen = 0.0_f64;
for a in 0..m {
let la = lambdas[a];
if la == 0.0 {
continue;
}
let beta_col = beta.column(a);
let mut quad = 0.0_f64;
for i in 0..p {
let mut s_beta_i = 0.0_f64;
for j in 0..p {
s_beta_i += penalty[[i, j]] * beta_col[j];
}
quad += beta_col[i] * s_beta_i;
}
pen += 0.5 * la * quad;
}
Ok(MultinomialFitOutputs {
coefficients_active: beta,
fitted_probabilities,
iterations,
converged,
penalized_neg_log_likelihood: -log_lik + pen,
deviance: -2.0 * log_lik,
})
}
/// Serializable multinomial model produced by [`fit_penalized_multinomial_formula`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSavedModel {
    /// Formula the model was fitted with.
    pub formula: String,
    /// Class labels of the response.
    pub class_levels: Vec<String>,
    /// Index into `class_levels` of the reference (non-active) class; the formula
    /// fitter sets this to the last level.
    pub reference_class_index: usize,
    /// Term specification resolved at fit time; reused to rebuild the design at
    /// prediction time.
    pub resolved_termspec: TermCollectionSpec,
    /// Active-class coefficients, flattened row-major from a
    /// `p_per_class × n_active_classes` matrix (see `coefficients_active`).
    pub coefficients_flat: Vec<f64>,
    /// Number of coefficients per active class (design columns).
    pub p_per_class: usize,
    /// Number of active (non-reference) classes, i.e. `K - 1`.
    pub n_active_classes: usize,
    /// Column headers of the training dataset.
    pub training_headers: Vec<String>,
    /// Per-class smoothing parameter: the first λ of each fitted block.
    pub lambdas: Vec<f64>,
    /// Inner cycles reported by the REML fit.
    pub iterations: usize,
    /// Outer-loop convergence flag reported by the REML fit.
    pub converged: bool,
    /// Penalized negative log-likelihood at the fitted coefficients.
    pub penalized_neg_log_likelihood: f64,
    /// `-2 ×` the fitted log-likelihood.
    pub deviance: f64,
    /// Effective degrees of freedom per block, when the fit reported inference
    /// information; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub edf_per_class: Option<Vec<f64>>,
}
impl MultinomialSavedModel {
    /// Reshapes `coefficients_flat` into the `p_per_class × n_active_classes`
    /// coefficient matrix (row-major).
    ///
    /// # Panics
    /// If `coefficients_flat.len() != p_per_class * n_active_classes`.
    pub fn coefficients_active(&self) -> Array2<f64> {
        let shape = (self.p_per_class, self.n_active_classes);
        let flat = self.coefficients_flat.clone();
        Array2::from_shape_vec(shape, flat).expect(
            "MultinomialSavedModel.coefficients_flat length must equal p_per_class * n_active_classes",
        )
    }

    /// Computes class probabilities for each row of `x_new`.
    ///
    /// Each row's active-class linear predictor `η = xᵀβ` is mapped through
    /// `MultinomialLogitLikelihood::softmax_with_baseline`. The result has
    /// `n_active_classes + 1` columns.
    ///
    /// # Panics
    /// If `x_new` does not have `p_per_class` columns, or if the stored
    /// coefficients have the wrong length.
    pub fn predict_probabilities(&self, x_new: ArrayView2<'_, f64>) -> Array2<f64> {
        let p = self.p_per_class;
        let n_active = self.n_active_classes;
        let n_classes = n_active + 1;
        assert_eq!(
            x_new.ncols(),
            p,
            "MultinomialSavedModel.predict_probabilities: X has {} cols, expected {p}",
            x_new.ncols()
        );
        let coefficients = self.coefficients_active();
        let mut probabilities = Array2::<f64>::zeros((x_new.nrows(), n_classes));
        // Scratch buffers reused across rows.
        let mut eta_row = vec![0.0_f64; n_active];
        let mut prob_row = vec![0.0_f64; n_classes];
        for (x_row, mut out_row) in x_new.outer_iter().zip(probabilities.outer_iter_mut()) {
            for (class, slot) in eta_row.iter_mut().enumerate() {
                *slot = x_row
                    .iter()
                    .zip(coefficients.column(class).iter())
                    .fold(0.0_f64, |acc, (&x, &b)| acc + x * b);
            }
            MultinomialLogitLikelihood::softmax_with_baseline(&eta_row, &mut prob_row);
            for (dst, &src) in out_row.iter_mut().zip(prob_row.iter()) {
                *dst = src;
            }
        }
        probabilities
    }
}
/// Expands the categorical response column `y_col` of `data` into a one-hot matrix.
///
/// Level names come from `data.schema.columns[y_col].levels`; an unknown column
/// yields no levels. Each row's encoded value is rounded to the nearest integer
/// and used as the class index.
///
/// Returns the `n × K` indicator matrix together with the level names.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when:
/// - fewer than 2 levels exist;
/// - a row's encoded value is non-finite;
/// - a row's encoded value rounds outside `0..K`.
fn one_hot_categorical_response(
    data: &EncodedDataset,
    y_col: usize,
    response_name: &str,
) -> Result<(Array2<f64>, Vec<String>), EstimationError> {
    let levels: Vec<String> = match data.schema.columns.get(y_col) {
        Some(column) => column.levels.clone(),
        None => Vec::new(),
    };
    let n_classes = levels.len();
    if n_classes < 2 {
        crate::bail_invalid_estim!(
            "multinomial response '{response_name}' must have at least 2 categorical levels (got {})",
            n_classes
        );
    }
    let n_rows = data.values.nrows();
    let mut indicators = Array2::<f64>::zeros((n_rows, n_classes));
    for (row, mut indicator_row) in indicators.outer_iter_mut().enumerate() {
        let encoded = data.values[[row, y_col]];
        if !encoded.is_finite() {
            crate::bail_invalid_estim!(
                "multinomial response '{response_name}' row {row} is non-finite ({encoded})"
            );
        }
        let class_idx = encoded.round() as i64;
        let in_range = class_idx >= 0 && (class_idx as usize) < n_classes;
        if !in_range {
            crate::bail_invalid_estim!(
                "multinomial response '{response_name}' row {row} encoded as {encoded} \
                 is outside the level range 0..{n_classes}"
            );
        }
        indicator_row[class_idx as usize] = 1.0;
    }
    Ok((indicators, levels))
}
/// Parses `formula`, resolves its response column in `data`, and builds the term
/// specification and design matrix for the formula's terms.
///
/// Returns a tuple of:
/// - the term spec;
/// - the design;
/// - the response column index;
/// - the response name;
/// - the response column kind.
///
/// Inference notes emitted while building the term spec are discarded.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when any of these fails:
/// - formula parsing;
/// - response resolution;
/// - term-spec construction;
/// - design assembly.
fn build_formula_design_for_multinomial(
    formula: &str,
    data: &EncodedDataset,
    config: &FitConfig,
) -> Result<
    (
        TermCollectionSpec,
        TermCollectionDesign,
        usize,
        String,
        ResponseColumnKind,
    ),
    EstimationError,
> {
    let ast = match parse_formula(formula) {
        Ok(ast) => ast,
        Err(err) => {
            return Err(EstimationError::InvalidInput(format!(
                "multinomial fit: failed to parse formula {formula:?}: {err}"
            )));
        }
    };
    let columns = data.column_map();
    let response_col = match resolve_role_col(&columns, &ast.response, "response") {
        Ok(col) => col,
        Err(err) => {
            return Err(EstimationError::InvalidInput(format!("multinomial fit: {err}")));
        }
    };
    let response_kind = crate::solver::workflow::response_column_kind(data, response_col);
    let policy = resolved_resource_policy(config, data, ProblemHints::default());
    // Notes are required by the builder's signature but not surfaced here.
    let mut discarded_notes: Vec<String> = Vec::new();
    let term_spec = build_termspec_with_geometry_and_overrides(
        &ast.terms,
        data,
        &columns,
        &mut discarded_notes,
        config.scale_dimensions,
        &policy,
        config.smooth_overrides.as_ref(),
    )
    .map_err(|err| {
        EstimationError::InvalidInput(format!("multinomial fit: build termspec: {err}"))
    })?;
    let term_design = build_term_collection_design(data.values.view(), &term_spec)
        .map_err(|err| {
            EstimationError::InvalidInput(format!("multinomial fit: build design: {err}"))
        })?;
    Ok((term_spec, term_design, response_col, ast.response, response_kind))
}
/// Fits a multinomial-logit model described by `formula` with REML-selected
/// smoothing parameters, and packages it as a [`MultinomialSavedModel`].
///
/// The response must be a categorical (or binary) column. Its level names become
/// `class_levels` and the last level is the reference class. The design's
/// penalties are summed with unit weights into a single `S` shared by all
/// `K - 1` active classes. Every block's initial log-λ is set to
/// `ln(init_lambda)`. `max_iter` and `tol` configure the inner
/// (coefficient-update) cycles of the blockwise fit.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when:
/// - `init_lambda` is not finite and positive;
/// - the response is numeric or not a categorical/binary column;
/// - the declared levels disagree with the number of encoded classes;
/// - design construction or the REML fit fails;
/// - the fit returns an unexpected number of blocks or coefficients.
pub fn fit_penalized_multinomial_formula(
    data: &EncodedDataset,
    formula: &str,
    config: &FitConfig,
    init_lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Result<MultinomialSavedModel, EstimationError> {
    if !(init_lambda.is_finite() && init_lambda > 0.0) {
        crate::bail_invalid_estim!(
            "multinomial fit: init_lambda must be finite and > 0 (got {init_lambda})"
        );
    }
    let (spec, design, y_col, response_name, y_kind) =
        build_formula_design_for_multinomial(formula, data, config)?;
    let class_levels = match y_kind {
        ResponseColumnKind::Categorical { levels } => levels,
        ResponseColumnKind::Binary => vec!["0".to_string(), "1".to_string()],
        ResponseColumnKind::Numeric => {
            crate::bail_invalid_estim!(
                "multinomial fit: response '{response_name}' is numeric, not categorical; \
                 use family='gaussian'/'binomial'/... or convert the column to a categorical type"
            );
        }
    };
    let column_kind = data.column_kinds.get(y_col);
    if column_kind != Some(&ColumnKindTag::Binary)
        && column_kind != Some(&ColumnKindTag::Categorical)
    {
        crate::bail_invalid_estim!(
            "multinomial fit: response '{response_name}' must be a categorical column \
             (got column kind {:?})",
            column_kind
        );
    }
    let (y_one_hot, _) = one_hot_categorical_response(data, y_col, &response_name)?;
    let (n_obs, k) = y_one_hot.dim();
    // The saved labels must line up with the probability columns produced at
    // prediction time (K = n_active_classes + 1); otherwise the reference index
    // and per-class labels would silently refer to the wrong classes.
    if class_levels.len() != k {
        crate::bail_invalid_estim!(
            "multinomial fit: response '{response_name}' declares {} class levels but was encoded with {k} classes",
            class_levels.len()
        );
    }
    // `one_hot_categorical_response` guarantees k ≥ 2.
    let m = k - 1;

    let x_dense = design
        .design
        .try_to_dense_by_chunks("multinomial fit design")
        .map_err(EstimationError::InvalidInput)?;
    let p_total = x_dense.ncols();
    // Unit weights: the per-class λ selected by REML scales this combined penalty.
    let unit_lambdas = vec![1.0_f64; design.penalties.len()];
    let s_total = weighted_blockwise_penalty_sum(&design.penalties, &unit_lambdas, p_total);
    let family = MultinomialFamily::new(
        y_one_hot,
        Array1::<f64>::ones(n_obs),
        k,
        Arc::new(x_dense),
        Arc::new(s_total),
        0,
    )
    .map_err(EstimationError::InvalidInput)?;

    let log_init = init_lambda.ln();
    let mut blocks = family.build_block_specs();
    for block_spec in blocks.iter_mut() {
        for v in block_spec.initial_log_lambdas.iter_mut() {
            *v = log_init;
        }
    }
    let options = BlockwiseFitOptions {
        inner_max_cycles: max_iter,
        inner_tol: tol,
        ..BlockwiseFitOptions::default()
    };
    let fit = fit_custom_family_with_rho_prior(
        &family,
        &blocks,
        &options,
        crate::types::RhoPrior::Flat,
    )
    .map_err(|err| EstimationError::InvalidInput(format!("multinomial REML: {err}")))?;
    if fit.blocks.len() != m {
        crate::bail_invalid_estim!(
            "multinomial REML: expected {m} fitted blocks (K-1), got {}",
            fit.blocks.len()
        );
    }
    if fit.block_states.len() != m {
        crate::bail_invalid_estim!(
            "multinomial REML: expected {m} fitted block states (K-1), got {}",
            fit.block_states.len()
        );
    }

    let p_per_class = fit.blocks[0].beta.len();
    let mut coefficients_active = Array2::<f64>::zeros((p_per_class, m));
    for (a, block) in fit.blocks.iter().enumerate() {
        if block.beta.len() != p_per_class {
            crate::bail_invalid_estim!(
                "multinomial REML: block {a} has {} coefs, expected {p_per_class}",
                block.beta.len()
            );
        }
        for (i, &b) in block.beta.iter().enumerate() {
            coefficients_active[[i, a]] = b;
        }
    }
    // Row-major flattening; `MultinomialSavedModel::coefficients_active` inverts it.
    let coefficients_flat: Vec<f64> = coefficients_active.iter().copied().collect();
    let lambdas_per_class: Vec<f64> = fit
        .blocks
        .iter()
        .map(|b| b.lambdas.iter().copied().next().unwrap_or(init_lambda))
        .collect();
    let edf_per_class = fit
        .inference
        .as_ref()
        .map(|info| info.edf_by_block.clone());
    let reference_class_index = class_levels.len() - 1;
    Ok(MultinomialSavedModel {
        formula: formula.to_string(),
        class_levels,
        reference_class_index,
        resolved_termspec: spec,
        coefficients_flat,
        p_per_class,
        n_active_classes: m,
        training_headers: data.headers.clone(),
        lambdas: lambdas_per_class,
        iterations: fit.inner_cycles,
        converged: fit.outer_converged,
        penalized_neg_log_likelihood: -fit.log_likelihood + 0.5 * fit.stable_penalty_term,
        deviance: -2.0 * fit.log_likelihood,
        edf_per_class,
    })
}
/// Predicts class probabilities for `data` with a saved multinomial model.
///
/// The design is rebuilt from the model's resolved term specification, so new
/// data is expanded exactly like the training data. Probabilities are computed
/// by `MultinomialSavedModel::predict_probabilities`.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when:
/// - the design cannot be rebuilt or densified;
/// - the rebuilt design's column count differs from `model.p_per_class`.
pub fn predict_multinomial_formula(
    model: &MultinomialSavedModel,
    data: &EncodedDataset,
) -> Result<Array2<f64>, EstimationError> {
    let rebuilt = build_term_collection_design(data.values.view(), &model.resolved_termspec)
        .map_err(|err| {
            EstimationError::InvalidInput(format!(
                "multinomial predict: rebuild design from saved termspec: {err}"
            ))
        })?;
    let dense = rebuilt
        .design
        .try_to_dense_by_chunks("multinomial predict design")
        .map_err(EstimationError::InvalidInput)?;
    let (actual_cols, expected_cols) = (dense.ncols(), model.p_per_class);
    if actual_cols != expected_cols {
        crate::bail_invalid_estim!(
            "multinomial predict: predict design has {} cols, saved model expects {}",
            actual_cols,
            expected_cols
        );
    }
    Ok(model.predict_probabilities(dense.view()))
}
#[cfg(test)]
mod fisher_override_tests {
    use super::*;
    use ndarray::Array3;

    // Toy problem dimensions: observations, coefficients per class, classes.
    const N: usize = 15;
    const P: usize = 2;
    const K: usize = 3;

    /// Small deterministic problem: an intercept plus a cosine covariate, cyclic
    /// class labels, an identity penalty and λ = 0.5 for every active class.
    fn toy() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array1<f64>) {
        let design = Array2::<f64>::from_shape_fn((N, P), |(row, col)| {
            if col == 0 { 1.0 } else { ((row + 2) as f64).cos() }
        });
        let y = Array2::<f64>::from_shape_fn((N, K), |(row, class)| {
            if row % K == class { 1.0 } else { 0.0 }
        });
        let penalty = Array2::<f64>::eye(P);
        let lambdas = Array1::<f64>::from_elem(K - 1, 0.5);
        (design, y, penalty, lambdas)
    }

    /// Fits the toy problem with the given curvature override and iteration cap.
    fn fit_toy(
        fisher_w_override: Option<ArrayView3<'_, f64>>,
        max_iter: usize,
    ) -> Result<MultinomialFitOutputs, EstimationError> {
        let (design, y, penalty, lambdas) = toy();
        fit_penalized_multinomial(MultinomialFitInputs {
            design: design.view(),
            y_one_hot: y.view(),
            penalty: penalty.view(),
            lambdas: lambdas.view(),
            row_weights: None,
            fisher_w_override,
            max_iter,
            tol: 1.0e-9,
        })
    }

    #[test]
    fn fisher_override_none_reproduces_analytic() {
        let first = fit_toy(None, 50).expect("fit must succeed");
        let second = fit_toy(None, 50).expect("fit must succeed");
        for (lhs, rhs) in first
            .coefficients_active
            .iter()
            .zip(second.coefficients_active.iter())
        {
            assert_eq!(lhs, rhs);
        }
    }

    #[test]
    fn fisher_override_wrong_shape_is_rejected() {
        // (N, K, K) instead of the required (N, K-1, K-1).
        let bad = Array3::<f64>::zeros((N, K, K));
        let err = fit_toy(Some(bad.view()), 50)
            .expect_err("wrong active-block shape must error");
        assert!(err.to_string().contains("fisher_w_override shape"));
    }

    #[test]
    fn scaled_fisher_override_changes_first_step() {
        let m = K - 1;
        let pk = 1.0 / K as f64;
        // Four times the analytic Fisher blocks at β = 0, where every class has
        // probability 1/K.
        let over = Array3::<f64>::from_shape_fn((N, m, m), |(_, a, b)| {
            let analytic = if a == b { pk * (1.0 - pk) } else { -pk * pk };
            4.0 * analytic
        });
        let scaled = fit_toy(Some(over.view()), 1).expect("override fit must succeed");
        let analytic = fit_toy(None, 1).expect("analytic fit must succeed");
        let differs = scaled
            .coefficients_active
            .iter()
            .zip(analytic.coefficients_active.iter())
            .any(|(a, b)| (a - b).abs() > 1.0e-6);
        assert!(differs, "scaled curvature must change the first step");
    }
}