use crate::custom_family::{
BlockwiseFitOptions, ParameterBlockState, PenaltyMatrix, fit_custom_family_with_rho_prior,
};
use crate::multinomial_reml::MultinomialFamily;
use crate::penalized_vector_glm::{PenalizedVectorGlmInputs, fit_penalized_vector_glm};
use crate::vector_response::{MultinomialLogitLikelihood, validate_multinomial_simplex};
use gam_terms::inference::formula_dsl::parse_formula;
use crate::model_types::EstimationError;
use crate::fit_orchestration::{
FitConfig, build_termspec_with_geometry_and_overrides, resolved_resource_policy,
};
use gam_terms::smooth::{
PenaltyBlockInfo, TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
};
use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
use gam_terms::term_builder::resolve_role_col;
use gam_problem::ResponseColumnKind;
use gam_data::ColumnKindTag;
use gam_data::EncodedDataset;
use gam_runtime::resource::ProblemHints;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Ridge floor forwarded as `BlockwiseFitOptions::ridge_floor` by
/// `fit_penalized_multinomial_formula`.
const MULTINOMIAL_FORMULA_RIDGE_FLOOR: f64 = 1.0e-4;
/// Lower bound on the inner-solver tolerance of the formula-driven fit: the
/// caller-supplied `tol` can only loosen it (`inner_tol = max(this, tol)`).
const MULTINOMIAL_FORMULA_INNER_TOL: f64 = 1.0e-5;
/// Returns the factor `2 (K - 1) / K^2` applied to every formula penalty
/// matrix, where `K` is the number of classes.
///
/// Class counts below 2 are treated as 2, so the result is always finite
/// and positive.
fn multinomial_formula_penalty_scale(n_classes: usize) -> f64 {
    let classes = n_classes.max(2) as f64;
    let numerator = 2.0 * (classes - 1.0);
    let denominator = classes * classes;
    numerator / denominator
}
/// Largest total smoothing-parameter dimension for which the formula fit
/// still requests the outer Hessian.
const MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM: usize = 16;

/// Decides whether the outer optimizer should be asked to use the Hessian:
/// `true` exactly when `total_rho_dim` does not exceed
/// [`MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM`].
fn multinomial_formula_use_outer_hessian(total_rho_dim: usize) -> bool {
    let within_budget = total_rho_dim <= MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM;
    within_budget
}
/// Magnitude of a finite linear predictor at or above which a non-converged
/// fit is reported as (quasi-)separated.
const MULTINOMIAL_SEPARATION_ETA_THRESHOLD: f64 = 25.0;
/// Smallest outer tolerance the formula fit will use; a caller `tol` below
/// this (or a non-finite / non-positive one) is replaced by it.
const MULTINOMIAL_OUTER_REML_TOL: f64 = 1e-7;
/// Cap on outer iterations for the first ("unbiased probe") solve of the
/// formula fit.
const MULTINOMIAL_UNBIASED_PROBE_OUTER_MAX_ITER: usize = 20;
/// Multiplier applied to both pseudo-observation counts below when they are
/// converted into the lambda floor in `multinomial_formula_min_lambda`.
const MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS: f64 = 0.25;
/// Pseudo-observation count that defines the base (smallest) lambda floor.
const MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS: f64 = 8.0e-4;
/// Minimum class count below which the lambda floor starts to grow
/// (the floor is scaled by `this / min_class_count`, never below 1).
const MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT: f64 = 50.0;
/// Pseudo-observation count that defines the largest lambda floor the
/// sparse-class scaling may reach.
const MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX: f64 = 4.0e-3;
/// Computes the lower bound on lambda used by the formula fit (its natural
/// log becomes `rho_lower_bound`).
///
/// The bound starts at a base floor and is scaled up by
/// `MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / min_class_count` (never
/// below 1), then clamped to the `[base, sparse]` interval. When the smallest
/// column sum of `y_one_hot` is non-finite or not strictly positive (this
/// includes a matrix with no columns), the base floor is returned unchanged.
fn multinomial_formula_min_lambda(y_one_hot: ArrayView2<'_, f64>) -> f64 {
    let floor = MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
    let ceiling =
        MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;

    // Smallest per-class total; stays at +inf when there are no columns.
    let mut smallest_class_total = f64::INFINITY;
    for class in 0..y_one_hot.ncols() {
        smallest_class_total = smallest_class_total.min(y_one_hot.column(class).sum());
    }

    if smallest_class_total.is_finite() && smallest_class_total > 0.0 {
        let sparsity_factor =
            (MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / smallest_class_total).max(1.0);
        (floor * sparsity_factor).clamp(floor, ceiling)
    } else {
        floor
    }
}
/// Finds the entry of `eta` with the largest absolute value.
///
/// Returns `(|value|, row, column)`. Ties keep the first entry in iteration
/// order, NaN entries never win the comparison, and an empty or all-zero
/// matrix yields `(0.0, 0, 0)`.
fn max_abs_eta_location(eta: ArrayView2<'_, f64>) -> (f64, usize, usize) {
    eta.indexed_iter().fold(
        (0.0_f64, 0usize, 0usize),
        |leader, ((row, active_class), &value)| {
            let magnitude = value.abs();
            if magnitude > leader.0 {
                (magnitude, row, active_class)
            } else {
                leader
            }
        },
    )
}
/// Builds a `MultinomialSeparationDetected` error when any block's linear
/// predictor contains a non-finite value.
///
/// Blocks are scanned in order and, within a block, rows in order; the first
/// non-finite entry is the one reported. The reported iteration is the larger
/// of `inner_cycles` and `outer_iterations`. Returns `None` when every `eta`
/// entry is finite.
fn multinomial_formula_separation_diagnostic(
    inner_cycles: usize,
    outer_iterations: usize,
    block_states: &[ParameterBlockState],
) -> Option<EstimationError> {
    let first_offender = block_states
        .iter()
        .enumerate()
        .find_map(|(active_class, state)| {
            state
                .eta
                .iter()
                .enumerate()
                .find(|(_, value)| !value.is_finite())
                .map(|(row, &value)| (value, row, active_class))
        });

    let (value, row_index, active_class_index) = first_offender?;
    Some(EstimationError::MultinomialSeparationDetected {
        iteration: inner_cycles.max(outer_iterations),
        max_abs_eta: value.abs(),
        active_class_index,
        row_index,
    })
}
/// Describes the first non-finite linear-predictor entry, if any.
///
/// Blocks are scanned in order, rows within each block in order. Returns
/// `None` when every `eta` entry of every block is finite.
fn multinomial_formula_separation_evidence(block_states: &[ParameterBlockState]) -> Option<String> {
    block_states
        .iter()
        .enumerate()
        .find_map(|(active_class, state)| {
            state
                .eta
                .iter()
                .enumerate()
                .find(|(_, value)| !value.is_finite())
                .map(|(row, &value)| {
                    format!(
                        "non-finite logit eta[row {row}, active class {active_class}] = {value}"
                    )
                })
        })
}
/// Separation evidence for a probe fit that was not accepted outright.
///
/// A non-finite `eta` entry (as reported by
/// `multinomial_formula_separation_evidence`) takes precedence. Otherwise the
/// largest finite `|eta|` over all blocks is compared against
/// `MULTINOMIAL_SEPARATION_ETA_THRESHOLD`; reaching the threshold produces a
/// description of that entry, and staying below it yields `None`.
fn multinomial_formula_unresolved_probe_separation_evidence(
    block_states: &[ParameterBlockState],
) -> Option<String> {
    multinomial_formula_separation_evidence(block_states).or_else(|| {
        // Strictly-greater comparison keeps the first entry among ties.
        let (peak, peak_row, peak_class) = block_states
            .iter()
            .enumerate()
            .flat_map(|(active_class, state)| {
                state
                    .eta
                    .iter()
                    .enumerate()
                    .map(move |(row, &value)| (value.abs(), row, active_class))
            })
            .fold((0.0_f64, 0usize, 0usize), |leader, candidate| {
                if candidate.0 > leader.0 {
                    candidate
                } else {
                    leader
                }
            });

        if peak >= MULTINOMIAL_SEPARATION_ETA_THRESHOLD {
            Some(format!(
                "separation-scale finite logit |eta[row {}, active class {}]| = {:.3e} \
                 after capped unbiased probe",
                peak_row, peak_class, peak
            ))
        } else {
            None
        }
    })
}
/// Borrowed inputs for `fit_penalized_multinomial`.
///
/// Below, `N` is the number of design rows and `K` the number of columns of
/// `y_one_hot`; the shape requirements are the ones `fit_penalized_multinomial`
/// validates before fitting.
#[derive(Debug, Clone)]
pub struct MultinomialFitInputs<'a> {
// Design matrix with `N` rows.
pub design: ArrayView2<'a, f64>,
// Response matrix with `N` rows and `K >= 2` columns; checked with
// `validate_multinomial_simplex`.
pub y_one_hot: ArrayView2<'a, f64>,
// Penalty matrix forwarded unchanged to the vector-GLM solver.
pub penalty: ArrayView2<'a, f64>,
// Smoothing parameters; length must be `K - 1`.
pub lambdas: ArrayView1<'a, f64>,
// Optional per-row weights; length `N`, each finite and >= 0.
pub row_weights: Option<ArrayView1<'a, f64>>,
// Optional override of shape `(N, K - 1, K - 1)`, forwarded to the solver.
pub fisher_w_override: Option<ArrayView3<'a, f64>>,
// Iteration cap forwarded to the solver.
pub max_iter: usize,
// Convergence tolerance forwarded to the solver.
pub tol: f64,
}
/// Result of `fit_penalized_multinomial`.
#[derive(Debug, Clone)]
pub struct MultinomialFitOutputs {
// Coefficients returned by the vector-GLM solver for the active classes.
pub coefficients_active: Array2<f64>,
// Probabilities the likelihood computes from the fitted linear predictor.
pub fitted_probabilities: Array2<f64>,
// Iteration count reported by the solver.
pub iterations: usize,
// Convergence flag reported by the solver.
pub converged: bool,
// `-log_likelihood + penalty_term` as reported by the solver.
pub penalized_neg_log_likelihood: f64,
// `-2 * log_likelihood`.
pub deviance: f64,
}
/// Fits a penalized multinomial-logit model at fixed smoothing parameters.
///
/// The inputs are validated (row counts, `K >= 2`, `lambdas` length `K - 1`,
/// `fisher_w_override` shape, `row_weights` length and values, and the
/// simplex check on `y_one_hot`), then the problem is handed to
/// `fit_penalized_vector_glm` with a diagonal class-penalty metric.
///
/// # Errors
///
/// Returns an invalid-input error for any failed validation, propagates
/// solver errors, and returns `MultinomialSeparationDetected` when the solver
/// did not converge and the largest `|eta|` reached
/// `MULTINOMIAL_SEPARATION_ETA_THRESHOLD`.
pub fn fit_penalized_multinomial(
    inputs: MultinomialFitInputs<'_>,
) -> Result<MultinomialFitOutputs, EstimationError> {
    let MultinomialFitInputs {
        design,
        y_one_hot,
        penalty,
        lambdas,
        row_weights,
        fisher_w_override,
        max_iter,
        tol,
    } = inputs;

    // --- shape validation -------------------------------------------------
    let n_obs = design.nrows();
    let (y_rows, k) = y_one_hot.dim();
    if y_rows != n_obs {
        crate::bail_invalid_estim!(
            "fit_penalized_multinomial: y rows {y_rows} ≠ design rows {n_obs}"
        );
    }
    if k < 2 {
        crate::bail_invalid_estim!(
            "fit_penalized_multinomial: need at least 2 classes (got K={k})"
        );
    }
    let m = k - 1;
    if lambdas.len() != m {
        crate::bail_invalid_estim!(
            "fit_penalized_multinomial: lambdas length {} ≠ K-1 = {m}",
            lambdas.len()
        );
    }
    if let Some(fw) = &fisher_w_override {
        let expected_shape = (n_obs, m, m);
        if fw.dim() != expected_shape {
            crate::bail_invalid_estim!(
                "fit_penalized_multinomial: fisher_w_override shape {:?} ≠ (N, K-1, K-1) = ({n_obs}, {m}, {m})",
                fw.dim()
            );
        }
    }

    // --- row-weight validation --------------------------------------------
    if let Some(w) = &row_weights {
        if w.len() != n_obs {
            crate::bail_invalid_estim!(
                "fit_penalized_multinomial: row_weights length {} ≠ N = {n_obs}",
                w.len()
            );
        }
        let first_invalid = w
            .iter()
            .copied()
            .enumerate()
            .find(|&(_, weight)| !weight.is_finite() || weight < 0.0);
        if let Some((i, v)) = first_invalid {
            crate::bail_invalid_estim!(
                "fit_penalized_multinomial: row_weights[{i}] must be finite and ≥ 0 (got {v})"
            );
        }
    }
    validate_multinomial_simplex(y_one_hot, "fit_penalized_multinomial")?;

    // --- likelihood and solve ---------------------------------------------
    let unweighted = MultinomialLogitLikelihood::with_classes(k)?;
    let likelihood = match &row_weights {
        Some(w) => unweighted.with_row_weights(w.to_owned())?,
        None => unweighted,
    };
    let glm_inputs = PenalizedVectorGlmInputs {
        design,
        y: y_one_hot,
        penalty,
        lambdas,
        fisher_w_override,
        max_iter,
        tol,
        class_penalty_metric: crate::penalized_vector_glm::ClassPenaltyMetric::Diagonal,
    };
    let fit = fit_penalized_vector_glm(glm_inputs, &likelihood, "fit_penalized_multinomial")?;

    // A non-converged fit with a very large linear predictor is reported as
    // separation rather than returned as a result.
    let (max_abs_eta, row_index, active_class_index) = max_abs_eta_location(fit.eta.view());
    let looks_separated = !fit.converged && max_abs_eta >= MULTINOMIAL_SEPARATION_ETA_THRESHOLD;
    if looks_separated {
        return Err(EstimationError::MultinomialSeparationDetected {
            iteration: fit.iterations,
            max_abs_eta,
            active_class_index,
            row_index,
        });
    }

    let fitted_probabilities = likelihood.probabilities(fit.eta.view());
    let penalized_neg_log_likelihood = -fit.log_likelihood + fit.penalty_term;
    let deviance = -2.0 * fit.log_likelihood;
    Ok(MultinomialFitOutputs {
        coefficients_active: fit.coefficients,
        fitted_probabilities,
        iterations: fit.iterations,
        converged: fit.converged,
        penalized_neg_log_likelihood,
        deviance,
    })
}
/// Serializable result of `fit_penalized_multinomial_formula`, carrying what
/// `predict_multinomial_formula` needs to rebuild the design and predict.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSavedModel {
// Formula string the model was fitted with.
pub formula: String,
// Response class labels.
pub class_levels: Vec<String>,
// Set by the fit to `class_levels.len() - 1` (the last class).
pub reference_class_index: usize,
// Frozen term specification used to rebuild the design at predict time.
pub resolved_termspec: TermCollectionSpec,
// Row-major flattening of the `(p_per_class, n_active_classes)` coefficient
// matrix; see `coefficients_active`.
pub coefficients_flat: Vec<f64>,
// Number of coefficients per active class (design columns).
pub p_per_class: usize,
// Number of active (non-reference) classes, `K - 1`.
pub n_active_classes: usize,
// Column headers of the training data; used to re-align predict-time
// columns by name.
pub training_headers: Vec<String>,
// Fitted lambdas of all blocks, concatenated in block order.
pub lambdas: Vec<f64>,
// Number of lambdas contributed by each block.
pub lambdas_per_block: Vec<usize>,
// Inner-cycle count reported by the fit.
pub iterations: usize,
// Outer-convergence flag reported by the fit.
pub converged: bool,
// `-log_likelihood + 0.5 * stable_penalty_term` from the fit.
pub penalized_neg_log_likelihood: f64,
// `-2 * log_likelihood` from the fit.
pub deviance: f64,
// Per-class effective degrees of freedom, when the fit exposed penalty
// traces of the expected length.
#[serde(default)]
pub edf_per_class: Option<Vec<f64>>,
// Flattened `(d, d)` coefficient covariance with `d = p_per_class *
// n_active_classes`, when available.
#[serde(default)]
pub coefficient_covariance_flat: Option<Vec<f64>>,
// Flattened `(d, d)` influence matrix, when available.
#[serde(default)]
pub coefficient_influence_flat: Option<Vec<f64>>,
// Column spans of the penalized terms (deduplicated by span).
#[serde(default)]
pub smooth_term_spans: Vec<MultinomialSmoothTermSpan>,
// One human-readable label per design penalty.
#[serde(default)]
pub lambda_labels: Vec<String>,
}
/// Column span of one penalized term inside a single class's coefficient
/// block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSmoothTermSpan {
// Term name, or `s<index>` when the design carries no name for it.
pub label: String,
// First column of the span (inclusive).
pub col_start: usize,
// End of the span (exclusive).
pub col_end: usize,
// Null-space dimension recorded for the term, capped at the span width.
pub nullspace_dim: usize,
}
/// Builds the display label for one penalty component.
///
/// The base is the term name from `info`, falling back to `s{pen_idx}` when
/// no info or no name is available. Every penalty source other than
/// `Primary` appends a bracketed role, e.g. `term [null space]`.
fn penalty_component_label(info: Option<&PenaltyBlockInfo>, pen_idx: usize) -> String {
    use gam_terms::basis::PenaltySource;

    // Without metadata there is neither a name nor a role.
    let Some(block_info) = info else {
        return format!("s{pen_idx}");
    };

    let term = match block_info.termname.clone() {
        Some(name) => name,
        None => format!("s{pen_idx}"),
    };

    let role: Option<String> = match &block_info.penalty.source {
        PenaltySource::Primary => None,
        PenaltySource::DoublePenaltyNullspace => Some(String::from("null space")),
        PenaltySource::OperatorMass => Some(String::from("mass")),
        PenaltySource::OperatorTension => Some(String::from("tension")),
        PenaltySource::OperatorStiffness => Some(String::from("stiffness")),
        PenaltySource::OperatorRelevance { axis } => Some(format!("axis {axis}")),
        PenaltySource::TensorMarginal { dim } => Some(format!("margin {dim}")),
        PenaltySource::TensorSeparable { penalized_margins } => {
            Some(format!("separable {penalized_margins:?}"))
        }
        PenaltySource::TensorGlobalRidge => Some(String::from("ridge")),
        PenaltySource::Other(s) => Some(s.clone()),
    };

    if let Some(role) = role {
        format!("{term} [{role}]")
    } else {
        term
    }
}
impl MultinomialSavedModel {
    /// Returns the coefficients as a `(p_per_class, n_active_classes)` matrix.
    ///
    /// # Panics
    ///
    /// Panics when `coefficients_flat.len()` differs from
    /// `p_per_class * n_active_classes`.
    pub fn coefficients_active(&self) -> Array2<f64> {
        Array2::from_shape_vec(
            (self.p_per_class, self.n_active_classes),
            self.coefficients_flat.clone(),
        )
        .expect(
            "MultinomialSavedModel.coefficients_flat length must equal p_per_class * n_active_classes",
        )
    }

    /// Predicts class probabilities for each row of `x_new`.
    ///
    /// The result has one row per input row and `n_active_classes + 1`
    /// columns. Each row's active-class linear predictors are passed to
    /// `MultinomialLogitLikelihood::softmax_with_baseline`.
    ///
    /// # Panics
    ///
    /// Panics when `x_new` does not have `p_per_class` columns, or when the
    /// stored coefficients have the wrong length (see
    /// [`Self::coefficients_active`]).
    pub fn predict_probabilities(&self, x_new: ArrayView2<'_, f64>) -> Array2<f64> {
        let n_new = x_new.nrows();
        let p = self.p_per_class;
        let m = self.n_active_classes;
        let k = m + 1;
        assert_eq!(
            x_new.ncols(),
            p,
            "MultinomialSavedModel.predict_probabilities: X has {} cols, expected {p}",
            x_new.ncols()
        );
        let beta = self.coefficients_active();
        let mut probs = Array2::<f64>::zeros((n_new, k));
        // Scratch buffers reused across rows.
        let mut eta_active = vec![0.0_f64; m];
        let mut row_probs = vec![0.0_f64; k];
        for row in 0..n_new {
            for a in 0..m {
                let mut v = 0.0_f64;
                for i in 0..p {
                    v += x_new[[row, i]] * beta[[i, a]];
                }
                eta_active[a] = v;
            }
            MultinomialLogitLikelihood::softmax_with_baseline(&eta_active, &mut row_probs);
            for c in 0..k {
                probs[[row, c]] = row_probs[c];
            }
        }
        probs
    }

    /// Returns the stored covariance as a `(d, d)` matrix with
    /// `d = p_per_class * n_active_classes`, or `None` when it is absent,
    /// has the wrong length, or `d` overflows.
    pub fn coefficient_covariance(&self) -> Option<Array2<f64>> {
        let d = self.p_per_class.checked_mul(self.n_active_classes)?;
        let flat = self.coefficient_covariance_flat.as_ref()?;
        Array2::from_shape_vec((d, d), flat.clone()).ok()
    }

    /// Returns the stored influence matrix as a `(d, d)` matrix with
    /// `d = p_per_class * n_active_classes`, or `None` when it is absent,
    /// has the wrong length, or `d` overflows.
    pub fn coefficient_influence(&self) -> Option<Array2<f64>> {
        let d = self.p_per_class.checked_mul(self.n_active_classes)?;
        let flat = self.coefficient_influence_flat.as_ref()?;
        Array2::from_shape_vec((d, d), flat.clone()).ok()
    }

    /// Predicts probabilities together with delta-method standard errors.
    ///
    /// The second element is `None` when no usable covariance is stored.
    /// Otherwise it has the same shape as the probabilities; each entry is
    /// `sqrt(max(g' V g, 0))`, where `V` is the covariance and `g` the
    /// gradient of that probability with respect to the class-major stacked
    /// coefficient vector (class `a` occupies indices `a * p .. (a + 1) * p`).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::predict_probabilities`].
    pub fn predict_probabilities_with_se(
        &self,
        x_new: ArrayView2<'_, f64>,
    ) -> (Array2<f64>, Option<Array2<f64>>) {
        let probs = self.predict_probabilities(x_new);
        let Some(cov) = self.coefficient_covariance() else {
            return (probs, None);
        };
        let n_new = x_new.nrows();
        let p = self.p_per_class;
        let m = self.n_active_classes;
        let k = m + 1;
        let d = p * m;
        let mut prob_se = Array2::<f64>::zeros((n_new, k));
        let mut grad = vec![0.0_f64; d];
        for row in 0..n_new {
            let prow = probs.row(row);
            for c in 0..k {
                let pc = prow[c];
                // d p_c / d eta_a = p_c * (delta_{ca} - p_a); every slot of
                // `grad` is overwritten for each (row, c) pair.
                for a in 0..m {
                    let pa = prow[a];
                    let factor = pc * (if c == a { 1.0 - pa } else { -pa });
                    let base = a * p;
                    for i in 0..p {
                        grad[base + i] = x_new[[row, i]] * factor;
                    }
                }
                // Quadratic form g' V g, skipping rows with a zero gradient.
                let mut var = 0.0_f64;
                for r in 0..d {
                    let gr = grad[r];
                    if gr == 0.0 {
                        continue;
                    }
                    let mut acc = 0.0_f64;
                    for s in 0..d {
                        acc += cov[[r, s]] * grad[s];
                    }
                    var += gr * acc;
                }
                prob_se[[row, c]] = var.max(0.0).sqrt();
            }
        }
        (probs, Some(prob_se))
    }

    /// Runs `wood_smooth_test` for every (active class, smooth term) pair.
    ///
    /// Returns an empty vector when no usable covariance is stored or when
    /// there are no smooth-term spans. A span is skipped when it is empty or
    /// inverted (`col_start >= col_end`) or extends past `p_per_class`. The
    /// effective degrees of freedom of a term are the sum of the influence
    /// matrix diagonal over its span when that sum is finite and positive,
    /// and the span width otherwise.
    ///
    /// # Panics
    ///
    /// Panics when the stored coefficients have the wrong length (see
    /// [`Self::coefficients_active`]).
    pub fn smooth_significance(&self) -> Vec<MultinomialSmoothSignificance> {
        let mut out = Vec::new();
        let p = self.p_per_class;
        let m = self.n_active_classes;
        let Some(cov) = self.coefficient_covariance() else {
            return out;
        };
        if self.smooth_term_spans.is_empty() {
            return out;
        }
        let beta = self.coefficients_active();
        let d = p * m;
        // Class-major stacking, matching the covariance layout.
        let mut theta = Array1::<f64>::zeros(d);
        for a in 0..m {
            for i in 0..p {
                theta[a * p + i] = beta[[i, a]];
            }
        }
        let influence = self.coefficient_influence();
        for a in 0..m {
            let class_label = self
                .class_levels
                .get(a)
                .cloned()
                .unwrap_or_else(|| format!("class{a}"));
            let base = a * p;
            for span in &self.smooth_term_spans {
                // Spans come from (possibly deserialized) model data: reject
                // empty/inverted spans before subtracting, and spans that do
                // not fit inside one class block.
                if span.col_start >= span.col_end || span.col_end > p {
                    continue;
                }
                let start = base + span.col_start;
                let end = base + span.col_end;
                let block_len = (span.col_end - span.col_start) as f64;
                let edf = influence
                    .as_ref()
                    .map(|f| (start..end).map(|i| f[[i, i]]).sum::<f64>())
                    .filter(|v| v.is_finite() && *v > 0.0)
                    .unwrap_or(block_len);
                let result = gam_terms::inference::smooth_test::wood_smooth_test(
                    gam_terms::inference::smooth_test::SmoothTestInput {
                        beta: theta.view(),
                        covariance: &cov,
                        influence_matrix: influence.as_ref(),
                        coeff_range: start..end,
                        edf,
                        nullspace_dim: span.nullspace_dim,
                        residual_df: f64::INFINITY,
                        scale: gam_terms::inference::smooth_test::SmoothTestScale::Known,
                    },
                );
                if let Some(res) = result {
                    out.push(MultinomialSmoothSignificance {
                        class_label: class_label.clone(),
                        term_label: span.label.clone(),
                        edf,
                        ref_df: res.ref_df,
                        statistic: res.statistic,
                        p_value: res.p_value,
                    });
                }
            }
        }
        out
    }
}
/// One row of the smooth-term significance table produced by
/// `MultinomialSavedModel::smooth_significance`.
#[derive(Debug, Clone)]
pub struct MultinomialSmoothSignificance {
// Label of the active class (`class<index>` when no level name exists).
pub class_label: String,
// Label of the smooth term.
pub term_label: String,
// Effective degrees of freedom passed to the test.
pub edf: f64,
// Reference degrees of freedom returned by the test.
pub ref_df: f64,
// Test statistic returned by the test.
pub statistic: f64,
// p-value returned by the test.
pub p_value: f64,
}
/// One-hot encodes the response column `y_col`.
///
/// The levels come from the schema entry of `y_col` (an absent entry counts
/// as zero levels). Each encoded value is rounded to the nearest integer and
/// used as the class index. Returns the `(rows, levels)` indicator matrix and
/// the level names.
///
/// # Errors
///
/// Fails when fewer than two levels are available, when a row is non-finite,
/// or when a rounded code falls outside `0..levels`.
fn one_hot_categorical_response(
    data: &EncodedDataset,
    y_col: usize,
    response_name: &str,
) -> Result<(Array2<f64>, Vec<String>), EstimationError> {
    let levels: Vec<String> = match data.schema.columns.get(y_col) {
        Some(column) => column.levels.clone(),
        None => Vec::new(),
    };
    let k = levels.len();
    if k < 2 {
        crate::bail_invalid_estim!(
            "multinomial response '{response_name}' must have at least 2 categorical levels (got {})",
            levels.len()
        );
    }

    let n = data.values.nrows();
    let mut y_one_hot = Array2::<f64>::zeros((n, k));
    for row in 0..n {
        let encoded = data.values[[row, y_col]];
        if !encoded.is_finite() {
            crate::bail_invalid_estim!(
                "multinomial response '{response_name}' row {row} is non-finite ({encoded})"
            );
        }
        let class_idx = encoded.round() as i64;
        let in_range = class_idx >= 0 && (class_idx as usize) < k;
        if !in_range {
            crate::bail_invalid_estim!(
                "multinomial response '{response_name}' row {row} encoded as {encoded} \
                 is outside the level range 0..{k}"
            );
        }
        y_one_hot[[row, class_idx as usize]] = 1.0;
    }
    Ok((y_one_hot, levels))
}
/// Parses `formula` and builds the term specification and design for it.
///
/// Returns, in order: the term specification, the design built from it, the
/// index of the response column, the response name from the formula, and the
/// response column kind.
///
/// # Errors
///
/// Every failure (formula parsing, response lookup, term-spec construction,
/// design construction) is reported as `EstimationError::InvalidInput` with a
/// `multinomial fit:` prefix.
fn build_formula_design_for_multinomial(
    formula: &str,
    data: &EncodedDataset,
    config: &FitConfig,
) -> Result<
    (
        TermCollectionSpec,
        TermCollectionDesign,
        usize,
        String,
        ResponseColumnKind,
    ),
    EstimationError,
> {
    let parsed = match parse_formula(formula) {
        Ok(parsed) => parsed,
        Err(err) => {
            return Err(EstimationError::InvalidInput(format!(
                "multinomial fit: failed to parse formula {formula:?}: {err}"
            )));
        }
    };

    // Locate the response column and classify it.
    let col_map = data.column_map();
    let y_col = match resolve_role_col(&col_map, &parsed.response, "response") {
        Ok(index) => index,
        Err(err) => {
            return Err(EstimationError::InvalidInput(format!(
                "multinomial fit: {err}"
            )));
        }
    };
    let y_kind = crate::fit_orchestration::response_column_kind(data, y_col);

    // Term specification; the notes the builder collects are not used here.
    let policy = resolved_resource_policy(config, data, ProblemHints::default());
    let mut inference_notes: Vec<String> = Vec::new();
    let spec = match build_termspec_with_geometry_and_overrides(
        &parsed.terms,
        data,
        &col_map,
        &mut inference_notes,
        config.scale_dimensions,
        &policy,
        config.smooth_overrides.as_ref(),
    ) {
        Ok(spec) => spec,
        Err(err) => {
            return Err(EstimationError::InvalidInput(format!(
                "multinomial fit: build termspec: {err}"
            )));
        }
    };

    let design = match build_term_collection_design(data.values.view(), &spec) {
        Ok(design) => design,
        Err(err) => {
            return Err(EstimationError::InvalidInput(format!(
                "multinomial fit: build design: {err}"
            )));
        }
    };

    Ok((spec, design, y_col, parsed.response, y_kind))
}
/// Multiplies a penalty by `scale`, preserving its representation.
///
/// Dense and blockwise penalties have every entry scaled; a
/// Kronecker-factored penalty has only its `left` factor scaled; labeled and
/// fixed wrappers are rebuilt around the scaled inner penalty.
fn scale_multinomial_formula_penalty(penalty: PenaltyMatrix, scale: f64) -> PenaltyMatrix {
    let rescale = move |entry: f64| entry * scale;
    match penalty {
        PenaltyMatrix::Dense(matrix) => {
            let scaled = matrix.mapv(rescale);
            PenaltyMatrix::Dense(scaled)
        }
        PenaltyMatrix::KroneckerFactored { left, right } => {
            let left = left.mapv(rescale);
            PenaltyMatrix::KroneckerFactored { left, right }
        }
        PenaltyMatrix::Blockwise {
            local,
            col_range,
            total_dim,
        } => {
            let local = local.mapv(rescale);
            PenaltyMatrix::Blockwise {
                local,
                col_range,
                total_dim,
            }
        }
        PenaltyMatrix::Labeled { label, inner } => {
            let inner = Box::new(scale_multinomial_formula_penalty(*inner, scale));
            PenaltyMatrix::Labeled { label, inner }
        }
        PenaltyMatrix::Fixed { log_lambda, inner } => {
            let inner = Box::new(scale_multinomial_formula_penalty(*inner, scale));
            PenaltyMatrix::Fixed { log_lambda, inner }
        }
    }
}
/// Clones `blocks` with their initial log-lambdas replaced by `log_lambdas`.
///
/// `log_lambdas` is consumed in block order, each block taking as many values
/// as it has initial log-lambdas. Returns `None` when the blocks have no
/// log-lambdas at all, when the lengths do not add up, or when any supplied
/// value is non-finite.
fn warm_start_blocks_from_log_lambdas(
    blocks: &[crate::custom_family::ParameterBlockSpec],
    log_lambdas: &[f64],
) -> Option<Vec<crate::custom_family::ParameterBlockSpec>> {
    let expected: usize = blocks
        .iter()
        .map(|block| block.initial_log_lambdas.len())
        .sum();
    let usable = expected != 0
        && log_lambdas.len() == expected
        && log_lambdas.iter().all(|value| value.is_finite());
    if !usable {
        return None;
    }

    let mut warm = blocks.to_vec();
    let mut remaining = log_lambdas;
    for block in &mut warm {
        // The length check above guarantees each split stays in bounds.
        let (chunk, rest) = remaining.split_at(block.initial_log_lambdas.len());
        for (slot, &value) in block.initial_log_lambdas.iter_mut().zip(chunk) {
            *slot = value;
        }
        remaining = rest;
    }
    Some(warm)
}
pub fn fit_penalized_multinomial_formula(
data: &EncodedDataset,
formula: &str,
config: &FitConfig,
init_lambda: f64,
max_iter: usize,
tol: f64,
) -> Result<MultinomialSavedModel, EstimationError> {
if !(init_lambda.is_finite() && init_lambda > 0.0) {
crate::bail_invalid_estim!(
"multinomial fit: init_lambda must be finite and > 0 (got {init_lambda})"
);
}
let (raw_spec, design, y_col, response_name, y_kind) =
build_formula_design_for_multinomial(formula, data, config)?;
let spec = freeze_term_collection_from_design(&raw_spec, &design)?;
let class_levels = match y_kind {
ResponseColumnKind::Categorical { levels } => levels,
ResponseColumnKind::Binary => vec!["0".to_string(), "1".to_string()],
ResponseColumnKind::Numeric => {
crate::bail_invalid_estim!(
"multinomial fit: response '{response_name}' is numeric, not categorical; \
use family='gaussian'/'binomial'/... or convert the column to a categorical type"
);
}
};
if data.column_kinds.get(y_col) == Some(&ColumnKindTag::Binary) {
} else if data.column_kinds.get(y_col) != Some(&ColumnKindTag::Categorical) {
crate::bail_invalid_estim!(
"multinomial fit: response '{response_name}' must be a categorical column \
(got column kind {:?})",
data.column_kinds.get(y_col)
);
}
let (y_one_hot, _) = one_hot_categorical_response(data, y_col, &response_name)?;
let mut x_dense = design
.design
.try_to_dense_by_chunks("multinomial fit design")
.map_err(EstimationError::InvalidInput)?;
let parametric_standardization: Vec<(usize, f64, f64)> =
if design.coefficient_lower_bounds.is_some() || design.linear_constraints.is_some() {
Vec::new()
} else {
let p_total = x_dense.ncols();
let mut penalized = vec![false; p_total];
for bp in &design.penalties {
for col in bp.col_range.clone() {
if col < p_total {
penalized[col] = true;
}
}
}
let has_intercept = !design.intercept_range.is_empty();
let n_rows = x_dense.nrows().max(1) as f64;
let mut standardized = Vec::new();
for (_, range) in &design.linear_ranges {
for col in range.clone() {
if col >= p_total || penalized[col] {
continue;
}
let column = x_dense.column(col);
let mean = column.sum() / n_rows;
let var = column.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n_rows;
let scale = var.sqrt();
if !(scale.is_finite() && scale > 1e-8 * (mean.abs() + 1.0)) {
continue;
}
let center = if has_intercept { mean } else { 0.0 };
for v in x_dense.column_mut(col).iter_mut() {
*v = (*v - center) / scale;
}
standardized.push((col, center, scale));
}
}
standardized
};
let k = y_one_hot.ncols();
let m = k - 1;
let n_obs = y_one_hot.nrows();
let penalty_scale = multinomial_formula_penalty_scale(k);
let per_term_penalties: Vec<PenaltyMatrix> = design
.penalties_as_penalty_matrix()
.into_iter()
.map(|penalty| scale_multinomial_formula_penalty(penalty, penalty_scale))
.collect();
let per_term_nullspace_dims = design.nullspace_dims.clone();
let design_arc = Arc::new(x_dense);
let penalties_arc = Arc::new(per_term_penalties);
let nullspace_dims_arc = Arc::new(per_term_nullspace_dims);
let weights = Array1::<f64>::ones(n_obs);
let family = MultinomialFamily::new(
y_one_hot.clone(),
weights,
k,
design_arc.clone(),
penalties_arc.clone(),
nullspace_dims_arc.clone(),
)
.map_err(EstimationError::InvalidInput)?
.with_joint_jeffreys_term(false);
let mut blocks = family.build_block_specs();
let log_init = init_lambda.ln();
for spec_block in blocks.iter_mut() {
for v in spec_block.initial_log_lambdas.iter_mut() {
*v = log_init;
}
}
let total_rho_dim = m.saturating_mul(penalties_arc.len());
let use_outer_hessian = multinomial_formula_use_outer_hessian(total_rho_dim);
let outer_max_iter = max_iter.max(1);
let outer_tol = if tol.is_finite() && tol > 0.0 {
tol.max(MULTINOMIAL_OUTER_REML_TOL)
} else {
MULTINOMIAL_OUTER_REML_TOL
};
let outer_rel_cost_tol = Some(BlockwiseFitOptions::default().outer_tol);
let inner_tol = MULTINOMIAL_FORMULA_INNER_TOL.max(tol.max(0.0));
let options = BlockwiseFitOptions {
inner_max_cycles: crate::custom_family::DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES,
inner_tol,
outer_max_iter,
outer_tol,
outer_rel_cost_tol,
rho_lower_bound: multinomial_formula_min_lambda(y_one_hot.view()).ln(),
ridge_floor: MULTINOMIAL_FORMULA_RIDGE_FLOOR,
ridge_policy: gam_problem::RidgePolicy::solver_only(),
use_outer_hessian,
screen_initial_rho: false,
compute_covariance: true,
..BlockwiseFitOptions::default()
};
let mut unbiased_probe_options = options.clone();
unbiased_probe_options.outer_max_iter = unbiased_probe_options
.outer_max_iter
.min(MULTINOMIAL_UNBIASED_PROBE_OUTER_MAX_ITER);
let firth_refit_options = &options;
let run_firth_refit = |evidence: String| {
let firth_family = family.clone().with_joint_jeffreys_term(true);
fit_custom_family_with_rho_prior(
&firth_family,
&blocks,
firth_refit_options,
gam_problem::RhoPrior::Flat,
)
.map_err(|err| {
EstimationError::InvalidInput(format!(
"multinomial REML: Firth/Jeffreys-armed refit (separation evidence: \
{evidence}) failed: {err}"
))
})
};
let probe_attempt = fit_custom_family_with_rho_prior(
&family,
&blocks,
&unbiased_probe_options,
gam_problem::RhoPrior::Flat,
);
let fit = match probe_attempt {
Ok(probe_fit) => {
let separation = multinomial_formula_separation_evidence(&probe_fit.block_states);
if probe_fit.outer_converged && separation.is_none() {
probe_fit
} else if let Some(evidence) =
multinomial_formula_unresolved_probe_separation_evidence(&probe_fit.block_states)
{
run_firth_refit(format!(
"unbiased-criterion REML probe did not converge after {} outer iterations; {evidence}",
probe_fit.outer_iterations
))?
} else if separation.is_none() {
let warm_blocks = warm_start_blocks_from_log_lambdas(
&blocks,
probe_fit.log_lambdas.as_slice().unwrap_or(&[]),
);
let resolve_blocks = warm_blocks.as_deref().unwrap_or(&blocks);
match fit_custom_family_with_rho_prior(
&family,
resolve_blocks,
&options,
gam_problem::RhoPrior::Flat,
) {
Ok(full_unbiased_fit) => {
let full_separation = multinomial_formula_separation_evidence(
&full_unbiased_fit.block_states,
);
if full_unbiased_fit.outer_converged && full_separation.is_none() {
full_unbiased_fit
} else {
let evidence = full_separation.unwrap_or_else(|| {
format!(
"full unbiased-criterion REML solve did not converge after {} outer iterations",
full_unbiased_fit.outer_iterations
)
});
run_firth_refit(evidence)?
}
}
Err(err) => run_firth_refit(format!(
"full unbiased-criterion REML solve failed: {err}"
))?,
}
} else {
let evidence = separation.unwrap_or_else(|| {
format!(
"unbiased-criterion REML probe did not converge after {} outer iterations",
probe_fit.outer_iterations
)
});
run_firth_refit(evidence)?
}
}
Err(err) => run_firth_refit(format!("unbiased-criterion REML solve failed: {err}"))?,
};
if let Some(err) = multinomial_formula_separation_diagnostic(
fit.inner_cycles,
fit.outer_iterations,
&fit.block_states,
) {
return Err(err);
}
if fit.blocks.len() != m {
crate::bail_invalid_estim!(
"multinomial REML: expected {m} fitted blocks (K-1), got {}",
fit.blocks.len()
);
}
let p_per_class = fit.blocks[0].beta.len();
let mut coefficients_active = Array2::<f64>::zeros((p_per_class, m));
for (a, block) in fit.blocks.iter().enumerate() {
if block.beta.len() != p_per_class {
crate::bail_invalid_estim!(
"multinomial REML: block {a} has {} coefs, expected {p_per_class}",
block.beta.len()
);
}
for i in 0..p_per_class {
coefficients_active[[i, a]] = block.beta[i];
}
}
if !parametric_standardization.is_empty() {
let intercept_col = design.intercept_range.clone().next();
for a in 0..m {
let mut intercept_adjust = 0.0;
for &(col, center, scale) in ¶metric_standardization {
if col < p_per_class {
let raw = coefficients_active[[col, a]] / scale;
coefficients_active[[col, a]] = raw;
intercept_adjust += raw * center;
}
}
if let Some(i0) = intercept_col
&& i0 < p_per_class
{
coefficients_active[[i0, a]] -= intercept_adjust;
}
}
}
let lambdas_per_block: Vec<usize> = fit.blocks.iter().map(|b| b.lambdas.len()).collect();
let lambdas_flat: Vec<f64> = fit
.blocks
.iter()
.flat_map(|b| b.lambdas.iter().copied())
.collect();
let edf_per_class = fit.inference.as_ref().and_then(|info| {
let traces = &info.penalty_block_trace;
if traces.len() != lambdas_per_block.iter().sum::<usize>() {
return None;
}
let mut per_class = Vec::with_capacity(m);
let mut cursor = 0usize;
for &n_blocks in &lambdas_per_block {
let class_trace: f64 = traces[cursor..cursor + n_blocks].iter().sum();
per_class.push((p_per_class as f64 - class_trace).clamp(0.0, p_per_class as f64));
cursor += n_blocks;
}
Some(per_class)
});
let coefficients_flat: Vec<f64> = coefficients_active.iter().copied().collect();
let expected_joint = p_per_class.saturating_mul(m);
let intercept_col0 = design.intercept_range.clone().next();
let build_per_class_affine = |amat: &mut Array2<f64>| {
for &(col, center, scale) in ¶metric_standardization {
if col >= p_per_class {
continue;
}
amat[[col, col]] = 1.0 / scale;
if let Some(i0) = intercept_col0
&& i0 < p_per_class
{
amat[[i0, col]] = -center / scale;
}
}
};
let coefficient_covariance_flat = fit
.covariance_conditional
.as_ref()
.filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)
.map(|cov_std| {
if parametric_standardization.is_empty() {
return cov_std.iter().copied().collect::<Vec<f64>>();
}
let mut a_joint = Array2::<f64>::eye(expected_joint);
let mut a_class = Array2::<f64>::eye(p_per_class);
build_per_class_affine(&mut a_class);
for a in 0..m {
let base = a * p_per_class;
for i in 0..p_per_class {
for j in 0..p_per_class {
a_joint[[base + i, base + j]] = a_class[[i, j]];
}
}
}
let cov_raw = a_joint.dot(cov_std).dot(&a_joint.t());
cov_raw.iter().copied().collect::<Vec<f64>>()
});
let coefficient_influence_flat = fit
.covariance_conditional
.as_ref()
.filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)
.and_then(|hinv| {
if fit.blocks.len() != m {
return None;
}
let mut s_lambda = Array2::<f64>::zeros((expected_joint, expected_joint));
for (a, block) in fit.blocks.iter().enumerate() {
if block.lambdas.len() != penalties_arc.len() {
return None;
}
let base = a * p_per_class;
for (t, pen) in penalties_arc.iter().enumerate() {
let lam = block.lambdas[t];
if lam == 0.0 {
continue;
}
let dense = pen.to_dense();
if dense.nrows() != p_per_class || dense.ncols() != p_per_class {
return None;
}
for i in 0..p_per_class {
for j in 0..p_per_class {
s_lambda[[base + i, base + j]] += lam * dense[[i, j]];
}
}
}
}
let hinv_s = hinv.dot(&s_lambda);
let mut f = Array2::<f64>::eye(expected_joint);
f -= &hinv_s;
Some(f.iter().copied().collect::<Vec<f64>>())
});
let mut smooth_term_spans: Vec<MultinomialSmoothTermSpan> = Vec::new();
for (pen_idx, bp) in design.penalties.iter().enumerate() {
let col_start = bp.col_range.start;
let col_end = bp.col_range.end;
if col_start >= col_end || col_end > p_per_class {
continue;
}
if smooth_term_spans
.iter()
.any(|s| s.col_start == col_start && s.col_end == col_end)
{
continue;
}
let label = design
.penaltyinfo
.get(pen_idx)
.and_then(|info| info.termname.clone())
.unwrap_or_else(|| format!("s{pen_idx}"));
let nullspace_dim = design
.nullspace_dims
.get(pen_idx)
.copied()
.unwrap_or(0)
.min(col_end - col_start);
smooth_term_spans.push(MultinomialSmoothTermSpan {
label,
col_start,
col_end,
nullspace_dim,
});
}
let lambda_labels: Vec<String> = design
.penalties
.iter()
.enumerate()
.map(|(pen_idx, _)| penalty_component_label(design.penaltyinfo.get(pen_idx), pen_idx))
.collect();
let deviance = -2.0 * fit.log_likelihood;
Ok(MultinomialSavedModel {
formula: formula.to_string(),
class_levels: class_levels.clone(),
reference_class_index: class_levels.len() - 1,
resolved_termspec: spec,
coefficients_flat,
p_per_class,
n_active_classes: m,
training_headers: data.headers.clone(),
lambdas: lambdas_flat,
lambdas_per_block,
iterations: fit.inner_cycles,
converged: fit.outer_converged,
penalized_neg_log_likelihood: -fit.log_likelihood + 0.5 * fit.stable_penalty_term,
deviance,
edf_per_class,
coefficient_covariance_flat,
coefficient_influence_flat,
smooth_term_spans,
lambda_labels,
})
}
pub fn predict_multinomial_formula(
model: &MultinomialSavedModel,
data: &EncodedDataset,
) -> Result<Array2<f64>, EstimationError> {
let predict_columns = data.column_map();
let realigned = model.resolved_termspec.remap_feature_columns(
|index| -> Result<usize, EstimationError> {
let name = model.training_headers.get(index).ok_or_else(|| {
EstimationError::InvalidInput(format!(
"multinomial predict: saved training column index {index} is out of bounds \
for {} training headers",
model.training_headers.len()
))
})?;
resolve_role_col(&predict_columns, name, "feature")
.map_err(|err| EstimationError::InvalidInput(err.to_string()))
},
)?;
let design = build_term_collection_design(data.values.view(), &realigned).map_err(|err| {
EstimationError::InvalidInput(format!(
"multinomial predict: rebuild design from saved termspec: {err}"
))
})?;
let x_dense = design
.design
.try_to_dense_by_chunks("multinomial predict design")
.map_err(EstimationError::InvalidInput)?;
if x_dense.ncols() != model.p_per_class {
crate::bail_invalid_estim!(
"multinomial predict: predict design has {} cols, saved model expects {}",
x_dense.ncols(),
model.p_per_class
);
}
Ok(model.predict_probabilities(x_dense.view()))
}
pub fn predict_multinomial_formula_with_se(
model: &MultinomialSavedModel,
data: &EncodedDataset,
) -> Result<(Array2<f64>, Option<Array2<f64>>), EstimationError> {
let predict_columns = data.column_map();
let realigned = model.resolved_termspec.remap_feature_columns(
|index| -> Result<usize, EstimationError> {
let name = model.training_headers.get(index).ok_or_else(|| {
EstimationError::InvalidInput(format!(
"multinomial predict: saved training column index {index} is out of bounds \
for {} training headers",
model.training_headers.len()
))
})?;
resolve_role_col(&predict_columns, name, "feature")
.map_err(|err| EstimationError::InvalidInput(err.to_string()))
},
)?;
let design = build_term_collection_design(data.values.view(), &realigned).map_err(|err| {
EstimationError::InvalidInput(format!(
"multinomial predict: rebuild design from saved termspec: {err}"
))
})?;
let x_dense = design
.design
.try_to_dense_by_chunks("multinomial predict design")
.map_err(EstimationError::InvalidInput)?;
if x_dense.ncols() != model.p_per_class {
crate::bail_invalid_estim!(
"multinomial predict: predict design has {} cols, saved model expects {}",
x_dense.ncols(),
model.p_per_class
);
}
Ok(model.predict_probabilities_with_se(x_dense.view()))
}
#[cfg(test)]
mod fisher_override_tests {
use super::*;
use ndarray::Array3;
/// Builds the small deterministic fixture shared by the tests in this module.
///
/// Returns `(design, y_one_hot, penalty, lambdas)`:
/// * `design` is 15x2; column 0 is all ones and column 1 holds `cos(row + 2)`.
/// * `y_one_hot` is 15x3; row `i` has a single `1.0` in column `i % 3`.
/// * `penalty` is the 2x2 identity matrix.
/// * `lambdas` has `3 - 1 = 2` entries, each `0.5`.
fn toy() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array1<f64>) {
    const ROWS: usize = 15;
    const COLS: usize = 2;
    const CLASSES: usize = 3;
    let design = Array2::<f64>::from_shape_fn((ROWS, COLS), |(row, col)| match col {
        0 => 1.0,
        _ => ((row + 2) as f64).cos(),
    });
    // Classes cycle 0, 1, 2, 0, ... down the rows.
    let y = Array2::<f64>::from_shape_fn((ROWS, CLASSES), |(row, class)| {
        if class == row % CLASSES { 1.0 } else { 0.0 }
    });
    let penalty = Array2::<f64>::eye(COLS);
    let lambdas = Array1::<f64>::from_elem(CLASSES - 1, 0.5);
    (design, y, penalty, lambdas)
}
/// Fits the toy problem twice with `fisher_w_override: None` and requires the
/// active coefficients of the two fits to be exactly equal, entry by entry.
///
/// Both calls pass `None`, so this pins repeatability of the no-override path.
#[test]
fn fisher_override_none_reproduces_analytic() {
    let (design, y, penalty, lambdas) = toy();
    let fit_with = |fisher_w_override: Option<ndarray::ArrayView3<'_, f64>>| {
        let inputs = MultinomialFitInputs {
            design: design.view(),
            y_one_hot: y.view(),
            penalty: penalty.view(),
            lambdas: lambdas.view(),
            row_weights: None,
            fisher_w_override,
            max_iter: 50,
            tol: 1.0e-9,
        };
        fit_penalized_multinomial(inputs).expect("fit must succeed")
    };
    let first = fit_with(None);
    let second = fit_with(None);
    // Exact (bitwise-value) equality is intended here, not a tolerance check.
    first
        .coefficients_active
        .iter()
        .zip(second.coefficients_active.iter())
        .for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
}
/// Passes a Fisher-weight override whose two trailing axes both have length
/// `y.ncols()` and requires the fit to fail with an error whose text contains
/// `"fisher_w_override shape"`.
#[test]
fn fisher_override_wrong_shape_is_rejected() {
    let (design, y, penalty, lambdas) = toy();
    let rows = design.nrows();
    // Full class count from `y`; the passing override test in this module
    // uses `y.ncols() - 1` instead, so this shape is deliberately off by one.
    let classes = y.ncols();
    let misshapen = Array3::<f64>::zeros((rows, classes, classes));
    let inputs = MultinomialFitInputs {
        design: design.view(),
        y_one_hot: y.view(),
        penalty: penalty.view(),
        lambdas: lambdas.view(),
        row_weights: None,
        fisher_w_override: Some(misshapen.view()),
        max_iter: 50,
        tol: 1.0e-9,
    };
    let err = fit_penalized_multinomial(inputs)
        .expect_err("wrong active-block shape must error");
    assert!(err.to_string().contains("fisher_w_override shape"));
}
/// Checks that `multinomial_formula_use_outer_hessian` returns `true` for
/// total rho dimensions 8 and 12.
#[test]
fn formula_outer_route_uses_exact_curvature_for_medium_d() {
    // Each case pairs a rho dimension with the message shown if it fails.
    let cases: [(usize, &str); 2] = [
        (
            8,
            "D=8 loaded multinomial fits need exact curvature to avoid over-smoothed lambda caps",
        ),
        (
            12,
            "D=12 (3 double-penalty smooth terms, K=3) stays on exact curvature",
        ),
    ];
    for &(rho_dim, why) in cases.iter() {
        assert!(multinomial_formula_use_outer_hessian(rho_dim), "{why}");
    }
}
/// Checks that `multinomial_formula_use_outer_hessian` still returns `true`
/// at a total rho dimension of 16.
#[test]
fn formula_outer_route_uses_exact_curvature_for_d16_penguin_fixture() {
    let rho_dim: usize = 16;
    let uses_exact_hessian = multinomial_formula_use_outer_hessian(rho_dim);
    assert!(
        uses_exact_hessian,
        "D=16 multinomial fits need exact ARC curvature for the #1082 stall halt"
    );
}
/// Pins the values `multinomial_formula_min_lambda` returns for two-class
/// responses whose second class has a chosen number of rows.
///
/// Checked here: the two derived endpoint floors (2e-4 and 1e-3), the exact
/// endpoint values at counts 50, 200, 10 and 5, near-continuity between counts
/// 49 and 50, strict interpolation at count 25, and the
/// `base * reference_support / count` form for counts 12 through 40.
#[test]
fn formula_min_lambda_floor_is_continuous_and_information_scaled() {
    // One-hot response with 1000 rows of class 0 followed by `count` rows of
    // class 1, fed straight to the floor function.
    fn floor_for_min_count(count: usize) -> f64 {
        const MAJORITY_ROWS: usize = 1000;
        let y = Array2::<f64>::from_shape_fn((MAJORITY_ROWS + count, 2), |(row, col)| {
            let hot = if row < MAJORITY_ROWS { 0 } else { 1 };
            if col == hot { 1.0 } else { 0.0 }
        });
        multinomial_formula_min_lambda(y.view())
    }

    let fisher = MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
    let base = MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS * fisher;
    let sparse = MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX * fisher;
    assert!(
        (base - 2.0e-4).abs() < 1e-18,
        "derived base floor must equal the calibrated 2e-4"
    );
    assert!(
        (sparse - 1.0e-3).abs() < 1e-18,
        "derived sparse floor must equal the calibrated 1e-3"
    );

    // Counts at or beyond either end of the band sit exactly on an endpoint.
    let endpoint_cases = [(50usize, base), (200, base), (10, sparse), (5, sparse)];
    for &(count, endpoint) in endpoint_cases.iter() {
        assert!((floor_for_min_count(count) - endpoint).abs() < 1e-18);
    }

    let f49 = floor_for_min_count(49);
    let f50 = floor_for_min_count(50);
    assert!(
        f49 >= f50 && f49 <= f50 * 1.05,
        "floor must be continuous across c0, got {f49} vs {f50}"
    );
    let f25 = floor_for_min_count(25);
    assert!(
        f25 > f50 && f25 < floor_for_min_count(10),
        "mid-support floor must interpolate strictly between the two endpoints"
    );

    // Inside the band the floor is the base value scaled by reference/count.
    for n_c in [12usize, 16, 20, 30, 40].iter().copied() {
        let expected = base * (MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / n_c as f64);
        assert!(
            (floor_for_min_count(n_c) - expected).abs() < 1e-15,
            "floor at n_c={n_c} must be τ·I₁·n_ref/n_c = {expected}, got {}",
            floor_for_min_count(n_c)
        );
    }
    assert!(
        (floor_for_min_count(20) - 2.0 * floor_for_min_count(40)).abs() < 1e-15,
        "floor must scale like 1/n_c (effective Fisher information) in the interior band"
    );
}
/// Pins `multinomial_formula_penalty_scale` at two and three classes
/// (`1/2` and `4/9`) and checks that five classes give a smaller scale than
/// three.
#[test]
fn formula_penalty_scale_tracks_softmax_fisher_curvature() {
    let binary = multinomial_formula_penalty_scale(2);
    let three_class = multinomial_formula_penalty_scale(3);
    let five_class = multinomial_formula_penalty_scale(5);
    assert!(
        (binary - 0.5).abs() < 1.0e-12,
        "binary-logit neutral-simplex curvature scale should remain at 1/2"
    );
    assert!(
        (three_class - 4.0 / 9.0).abs() < 1.0e-12,
        "three-class softmax penalties should be calibrated to 2*(K-1)/K^2"
    );
    assert!(
        five_class < three_class,
        "active-class Fisher curvature decreases as the simplex gains classes"
    );
}
/// Fits an unpenalized three-class problem whose labels are a deterministic
/// function of the single covariate and requires the fit to fail with
/// `EstimationError::MultinomialSeparationDetected`.
///
/// The rendered error must mention `"separation"` and `"active class-"`, and
/// must not contain `"binary outcomes"`.
#[test]
fn fixed_lambda_multinomial_reports_complete_separation() {
    let n: usize = 90;
    let last = (n - 1) as f64;
    // Intercept column plus a covariate running evenly from -3 to 3.
    let design = Array2::<f64>::from_shape_fn((n, 2), |(row, col)| {
        if col == 0 {
            1.0
        } else {
            -3.0 + 6.0 * (row as f64) / last
        }
    });
    // Class is decided purely by which third of the covariate range a row is in.
    let y = Array2::<f64>::from_shape_fn((n, 3), |(row, class)| {
        let x = design[[row, 1]];
        let label = if x < -1.0 {
            0
        } else if x > 1.0 {
            1
        } else {
            2
        };
        if class == label { 1.0 } else { 0.0 }
    });
    // Zero penalty matrix and zero lambdas: nothing regularizes the fit.
    let penalty = Array2::<f64>::zeros((2, 2));
    let lambdas = Array1::<f64>::zeros(2);
    let inputs = MultinomialFitInputs {
        design: design.view(),
        y_one_hot: y.view(),
        penalty: penalty.view(),
        lambdas: lambdas.view(),
        row_weights: None,
        fisher_w_override: None,
        max_iter: 80,
        tol: 1.0e-12,
    };
    let err = fit_penalized_multinomial(inputs)
        .expect_err("complete softmax separation must be a hard diagnostic");
    assert!(
        matches!(err, EstimationError::MultinomialSeparationDetected { .. }),
        "expected MultinomialSeparationDetected, got {err:?}"
    );
    let rendered = err.to_string();
    assert!(
        rendered.contains("separation"),
        "diagnostic should mention separation, got {err}"
    );
    assert!(
        rendered.contains("active class-"),
        "diagnostic should name the separated active class logit, got {err}"
    );
    assert!(
        !rendered.contains("binary outcomes"),
        "multinomial diagnostic must not reuse the binary separation text, got {err}"
    );
}
/// Checks `multinomial_formula_separation_diagnostic` on two hand-built state
/// sets: one whose largest logit is finite (25.5) must yield `None`, and one
/// containing an infinite logit must yield a `MultinomialSeparationDetected`
/// error carrying iteration 17, active class index 1 and row index 1.
#[test]
fn formula_multinomial_accepts_finite_saturated_logits() {
    // Small constructor so each fixture reads as a (beta, eta) pair.
    let state = |beta: [f64; 2], eta: [f64; 3]| ParameterBlockState {
        beta: Array1::from_vec(beta.to_vec()),
        eta: Array1::from_vec(eta.to_vec()),
    };

    let saturated_states = vec![
        state([1.0, 2.0], [0.2, 4.0, -7.0]),
        state([-1.0, 3.0], [1.0, 25.5, -0.1]),
    ];
    let finite_verdict = multinomial_formula_separation_diagnostic(17, 9, &saturated_states);
    assert!(
        finite_verdict.is_none(),
        "a finite (even saturated, |eta|>25) formula optimum is a valid fit, \
         not a separation diagnostic"
    );

    // Same fixture, but the second block's middle logit is now infinite.
    let blown_up = vec![
        state([1.0, 2.0], [0.2, 4.0, -7.0]),
        state([-1.0, 3.0], [1.0, f64::INFINITY, -0.1]),
    ];
    let err = multinomial_formula_separation_diagnostic(17, 9, &blown_up)
        .expect("a non-finite formula logit must raise the separation diagnostic");
    assert!(
        matches!(
            err,
            EstimationError::MultinomialSeparationDetected {
                iteration: 17,
                max_abs_eta,
                active_class_index: 1,
                row_index: 1,
            } if !max_abs_eta.is_finite()
        ),
        "expected typed multinomial separation diagnostic at the non-finite channel, got {err:?}"
    );
}
/// Checks `multinomial_formula_separation_evidence` on four hand-built state
/// sets: moderate logits, a finite logit of 25.5, and logits of +/-24.9 must
/// all yield `None`; a NaN logit must yield evidence text containing
/// `"non-finite logit"` and `"row 1"`.
#[test]
fn separation_evidence_gate_arms_firth_only_on_blowup() {
    // Small constructor so each fixture reads as a (beta, eta) pair.
    let state = |beta: [f64; 2], eta: [f64; 3]| ParameterBlockState {
        beta: Array1::from_vec(beta.to_vec()),
        eta: Array1::from_vec(eta.to_vec()),
    };

    let interior = vec![
        state([1.0, 2.0], [0.2, 4.0, -7.0]),
        state([-1.0, 3.0], [1.0, -3.5, -0.1]),
    ];
    assert!(
        multinomial_formula_separation_evidence(&interior).is_none(),
        "an interior finite mode must not arm the Firth refit"
    );

    let saturated = vec![
        state([1.0, 2.0], [0.2, 4.0, -7.0]),
        state([-1.0, 3.0], [1.0, 25.5, -0.1]),
    ];
    assert!(
        multinomial_formula_separation_evidence(&saturated).is_none(),
        "a finite saturated formula-mode logit must not arm the Firth refit"
    );

    // A single block whose middle logit is NaN.
    let blown_up = vec![state([1.0, 2.0], [0.2, f64::NAN, -7.0])];
    let evidence = multinomial_formula_separation_evidence(&blown_up)
        .expect("a non-finite logit is separation evidence");
    let names_the_logit = evidence.contains("non-finite logit") && evidence.contains("row 1");
    assert!(
        names_the_logit,
        "evidence must name the non-finite logit, got {evidence}"
    );

    let near = vec![state([1.0, 2.0], [0.2, 24.9, -24.9])];
    assert!(
        multinomial_formula_separation_evidence(&near).is_none(),
        "logits below the saturation threshold must not arm the Firth refit"
    );
}
/// Contrasts the two evidence functions on a state set whose largest logit is
/// a finite 25.5: `multinomial_formula_separation_evidence` must return `None`,
/// while `multinomial_formula_unresolved_probe_separation_evidence` must return
/// text naming `"separation-scale finite logit"`, `"row 1"` and
/// `"active class 1"`. Logits of +/-24.9 must yield `None` from the latter.
#[test]
fn unresolved_probe_evidence_arms_firth_on_saturated_finite_logits() {
    // Small constructor so each fixture reads as a (beta, eta) pair.
    let state = |beta: [f64; 2], eta: [f64; 3]| ParameterBlockState {
        beta: Array1::from_vec(beta.to_vec()),
        eta: Array1::from_vec(eta.to_vec()),
    };

    let saturated = vec![
        state([1.0, 2.0], [0.2, 4.0, -7.0]),
        state([-1.0, 3.0], [1.0, 25.5, -0.1]),
    ];
    assert!(
        multinomial_formula_separation_evidence(&saturated).is_none(),
        "a converged finite saturated formula optimum remains unbiased"
    );
    let evidence = multinomial_formula_unresolved_probe_separation_evidence(&saturated)
        .expect("a non-converged saturated probe should arm the Firth refit");
    let names_the_channel = evidence.contains("separation-scale finite logit")
        && evidence.contains("row 1")
        && evidence.contains("active class 1");
    assert!(
        names_the_channel,
        "unresolved-probe evidence should name the saturated channel, got {evidence}"
    );

    let near = vec![state([1.0, 2.0], [0.2, 24.9, -24.9])];
    assert!(
        multinomial_formula_unresolved_probe_separation_evidence(&near).is_none(),
        "finite logits below the separation threshold still get the full unbiased retry"
    );
}
/// Runs a single iteration (`max_iter: 1`) of the toy fit twice, once with a
/// Fisher-weight override and once without, and requires at least one active
/// coefficient to differ by more than `1e-6`.
///
/// The override is four times the matrix with `pk * (1 - pk)` on the diagonal
/// and `-pk * pk` off it, where `pk = 1 / y.ncols()`.
#[test]
fn scaled_fisher_override_changes_first_step() {
    let (design, y, penalty, lambdas) = toy();
    let n = design.nrows();
    let m = y.ncols() - 1;
    let pk = 1.0 / (y.ncols() as f64);
    // The same (m x m) block is repeated for every row.
    let over = Array3::<f64>::from_shape_fn((n, m, m), |(_, a, b)| {
        let analytic = if a == b { pk * (1.0 - pk) } else { -pk * pk };
        4.0 * analytic
    });

    let one_step_fit = |fisher_w_override: Option<ndarray::ArrayView3<'_, f64>>,
                        failure: &str| {
        fit_penalized_multinomial(MultinomialFitInputs {
            design: design.view(),
            y_one_hot: y.view(),
            penalty: penalty.view(),
            lambdas: lambdas.view(),
            row_weights: None,
            fisher_w_override,
            max_iter: 1,
            tol: 1.0e-9,
        })
        .expect(failure)
    };
    let scaled = one_step_fit(Some(over.view()), "override fit must succeed");
    let analytic = one_step_fit(None, "analytic fit must succeed");

    let differs = scaled
        .coefficients_active
        .iter()
        .zip(analytic.coefficients_active.iter())
        .any(|(with_override, without)| (with_override - without).abs() > 1.0e-6);
    assert!(differs, "scaled curvature must change the first step");
}
}
#[cfg(test)]
mod reference_class_invariance_tests {
use super::*;
use gam_data::load_dataset_projected;
use std::fmt::Write as _;
use std::fs;
use tempfile::tempdir;
/// Minimal SplitMix64 pseudo-random generator; the wrapped `u64` is the
/// running state, so equal seeds reproduce equal sequences.
struct SplitMix64(u64);

impl SplitMix64 {
    /// Advances the state by a fixed odd increment and returns the mixed
    /// 64-bit output. All arithmetic wraps modulo 2^64.
    fn next_u64(&mut self) -> u64 {
        const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
        const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
        const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

        self.0 = self.0.wrapping_add(INCREMENT);
        let stage_one = (self.0 ^ (self.0 >> 30)).wrapping_mul(MIX_A);
        let stage_two = (stage_one ^ (stage_one >> 27)).wrapping_mul(MIX_B);
        stage_two ^ (stage_two >> 31)
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next
    /// output, i.e. an integer below 2^53 divided by 2^53.
    fn unit(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 / (1u64 << 53) as f64
    }
}
fn sample_classes(seed: u64, n: usize) -> (Vec<f64>, Vec<usize>) {
let mut rng = SplitMix64(seed.wrapping_add(0x1234_5678));
let mut x = Vec::with_capacity(n);
let mut cls = Vec::with_capacity(n);
for _ in 0..n {
let xi = -2.0 + 4.0 * rng.unit();
let eta = [0.5 + 0.8 * xi, -0.3 - 0.5 * xi, 0.0];
let mut p = [eta[0].exp(), eta[1].exp(), eta[2].exp()];
let s: f64 = p.iter().sum();
for v in &mut p {
*v /= s;
}
let u = rng.unit();
let c = if u < p[0] {
0
} else if u < p[0] + p[1] {
1
} else {
2
};
x.push(xi);
cls.push(c);
}
(x, cls)
}
/// Writes `x` and `y` as a two-column CSV (`x,y` header) to
/// `dir/data_{tag}.csv` and loads it back with `load_dataset_projected`,
/// projecting onto the columns `x` and `y`.
///
/// Rows are paired with `zip`, so extra entries in the longer slice are
/// dropped. Panics if the file cannot be written or the dataset cannot be
/// loaded.
fn dataset_xy(dir: &std::path::Path, tag: &str, x: &[f64], y: &[String]) -> gam_data::EncodedDataset {
    let csv_path = dir.join(format!("data_{tag}.csv"));
    let mut contents = String::from("x,y\n");
    x.iter()
        .zip(y)
        .for_each(|(xi, yi)| writeln!(contents, "{xi},{yi}").unwrap());
    fs::write(&csv_path, contents).expect("write training csv");
    let wanted = ["x".to_string(), "y".to_string()];
    load_dataset_projected(&csv_path, &wanted).expect("load training dataset")
}
/// Fits `y ~ s(x)` on the data labelled through `name_map`, predicts on
/// `grid`, and returns the predicted probabilities with columns reordered so
/// that column `c` corresponds to original class index `c`.
///
/// * `cls[i]` is the original class index of row `i`; its label is
///   `name_map[cls[i]]`.
/// * The grid dataset needs a `y` column to be written, so it is filled with
///   `name_map[0]` for every row.
/// * The model's `class_levels` are asserted to equal the sorted labels; the
///   reordering relies on that to map sorted positions back to original ones.
///
/// Panics if the fit, the prediction, or the `class_levels` check fails.
fn fit_predict_aligned(
    dir: &std::path::Path,
    tag: &str,
    x: &[f64],
    cls: &[usize],
    name_map: [&str; 3],
    grid: &[f64],
) -> Array2<f64> {
    let train_labels: Vec<String> = cls
        .iter()
        .map(|&class| name_map[class].to_string())
        .collect();
    let train = dataset_xy(dir, tag, x, &train_labels);
    let config = FitConfig::default();
    let model = fit_penalized_multinomial_formula(&train, "y ~ s(x)", &config, 1.0, 60, 1e-6)
        .expect("multinomial formula fit must succeed");

    let placeholder_labels = vec![name_map[0].to_string(); grid.len()];
    let grid_tag = format!("{tag}_grid");
    let grid_ds = dataset_xy(dir, &grid_tag, grid, &placeholder_labels);
    let probs = predict_multinomial_formula(&model, &grid_ds)
        .expect("multinomial predict must succeed");

    let mut sorted: Vec<&str> = name_map.to_vec();
    sorted.sort_unstable();
    // For every original class, the column it occupies in sorted label order.
    let col_of_orig: Vec<usize> = name_map
        .iter()
        .map(|orig| sorted.iter().position(|level| level == orig).unwrap())
        .collect();
    let expected_levels = sorted.iter().map(|s| s.to_string()).collect::<Vec<_>>();
    assert_eq!(
        model.class_levels, expected_levels,
        "class_levels must be the sorted label order"
    );

    Array2::<f64>::from_shape_fn((grid.len(), 3), |(row, orig_class)| {
        probs[[row, col_of_orig[orig_class]]]
    })
}
/// Returns the largest absolute difference between paired elements of `a` and
/// `b`, pairing them in iteration order. The result starts from `0.0`, so
/// empty inputs give `0.0`.
fn max_abs_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
    let mut worst = 0.0_f64;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        worst = worst.max((lhs - rhs).abs());
    }
    worst
}
/// End-to-end check that predicted probabilities do not depend on how the
/// three classes are named.
///
/// The same 300 sampled rows are fitted under three cyclic relabelings plus a
/// repeat of the first. The repeat must agree to within `1e-6`, and the largest
/// pairwise difference across relabelings must stay below `1e-3`. The test is
/// ignored by default; the reason is given in the `#[ignore]` text.
#[test]
#[ignore = "gam#1587: the production multinomial REML path still applies the \
            reference-anchored per-block (ALR/Diagonal) smoothing penalty, so this \
            invariance assertion FAILS (cross-labeling drift ~1e-2 ≫ 1e-3). The \
            `MultinomialFamily::joint_penalty_specs()` hook that returns the \
            reference-symmetric centered `M⊗S_t` penalty is defined but not yet \
            consumed by the custom-family outer REML loop (no call sites in \
            gam-custom-family). Un-ignore once that hook is wired through the outer \
            ρ-layout + per-eval JointPenaltyBundle + outer penalty_coords/logdet. \
            Also slow (~minutes): an opt-in end-to-end fit guard, not a fast CI unit."]
fn multinomial_fit_is_invariant_to_reference_class_1587() {
    let td = tempdir().expect("tempdir");
    let dir = td.path();
    let (x, cls) = sample_classes(0, 300);
    // Seven evenly spaced evaluation points on [-1.5, 1.5].
    let grid: Vec<f64> = (0..7).map(|i| -1.5 + 3.0 * (i as f64) / 6.0).collect();
    let fit = |tag: &str, names: [&str; 3]| fit_predict_aligned(dir, tag, &x, &cls, names, &grid);

    let a = fit("abc", ["A", "B", "C"]);
    let b = fit("bca", ["B", "C", "A"]);
    let c = fit("cab", ["C", "A", "B"]);
    // Same labeling as `a`, fitted again under a different file tag.
    let a2 = fit("abc2", ["A", "B", "C"]);

    let refit_noise = max_abs_diff(&a, &a2);
    assert!(
        refit_noise < 1e-6,
        "refitting the same labeling must be deterministic (got {refit_noise:.3e})"
    );

    let a_vs_b = max_abs_diff(&a, &b);
    let a_vs_c = max_abs_diff(&a, &c);
    let b_vs_c = max_abs_diff(&b, &c);
    let drift = a_vs_b.max(a_vs_c).max(b_vs_c);
    assert!(
        drift < 1e-3,
        "predicted probabilities must be invariant to the reference class; \
         cross-labeling drift = {drift:.3e} (refit noise = {refit_noise:.3e})"
    );
}
}