gam_models/multinomial.rs
1//! Penalized multinomial-logit (softmax) GLM driver — fixed-λ inner solve.
2//!
3//! This is the principled vector-response companion to the scalar PIRLS path:
4//! the inner-loop Newton solver for a multi-class GAM at fixed smoothing
5//! parameters λ, using the canonical multinomial-logit likelihood
6//! ([`MultinomialLogitLikelihood`]) and the existing dense block-Fisher
7//! assembly in [`gam_solve::pirls::dense_block_xtwx`] /
8//! [`gam_solve::pirls::dense_block_xtwy`].
9//!
10//! # What this module does
11//!
12//! Solve, for the reference-coded multinomial-logit GAM with `K` classes and
13//! design matrix `X ∈ ℝ^{N×P}`,
14//!
15//! ```text
16//! β̂ = argmin_β { − log L(β) + ½ Σ_{a=0}^{K-2} λ_a · β_a^T S β_a }
17//! ```
18//!
19//! where `β = [β_0; β_1; …; β_{K-2}]` is the stacked coefficient vector in
20//! output-major order (`β_a ∈ ℝ^P` is the coefficient block for class `a`),
21//! `S ∈ ℝ^{P×P}` is the smoothing penalty matrix (shared across classes,
22//! replicated as `I_{K-1} ⊗ S` over the full parameter space), and `λ_a` is
23//! a per-class smoothing parameter.
24//!
25//! The likelihood uses class `K - 1` as the reference (`η_{K-1} ≡ 0`), so the
26//! softmax gauge is fixed at the η level and no additional sum-to-zero
27//! projection is required.
28//!
29//! # Layering
30//!
31//! * **Fixed-λ inner solve** — [`fit_penalized_multinomial`] is the canonical
32//! coefficient-space Newton solver at *given* smoothing parameters `λ`,
33//! built on the shared [`crate::penalized_vector_glm`] engine.
34//!
35//! * **REML / LAML smoothing-parameter selection** — [`fit_penalized_multinomial_formula`]
36//! routes through [`crate::custom_family::fit_custom_family_with_rho_prior`]
37//! so the per-active-class `λ_a` are selected by the outer REML/LAML loop;
38//! the caller's `init_lambda` is only a warm-start seed. The multinomial
39//! [`crate::multinomial_reml::MultinomialFamily`] `CustomFamily`
40//! impl calls the fixed-λ math above as its inner solve at each ρ trial and
41//! supplies the dense per-row Hessian block for the outer trace terms.
42//!
43//! * **Formula → design integration** — `build_formula_design_for_multinomial`
44//! parses the Wilkinson formula and assembles `X` and the per-term `S`
45//! blocks; the `fit_multinomial_formula_pyfunc` FFI shim wires the Python
46//! `gamfit.fit(..., family='multinomial')` entry straight to this path.
47//!
48//! # Convergence
49//!
50//! The damped-Newton-with-backtracking scaffold lives once in the shared
51//! [`crate::penalized_vector_glm`] engine: at each iteration the
//! assembled penalized Hessian `H_λ = H + diag_a(λ_a) ⊗ S` is factored via
//! faer's symmetric-PD-with-fallback path, the full Newton step
//! `δ = −H_λ^{-1} ∇F` is computed, and accepted with step halving if the
//! objective fails to decrease
55//! (up to a small backtracking budget). The convergence test is the relative
56//! coefficient step norm `‖δ‖ / (1 + ‖β‖) ≤ tol`, matching the existing pyffi
57//! reference path. This module is the softmax adapter over that engine: it
58//! supplies the dense `(K-1)×(K-1)` Fisher block, the residual, and the
59//! log-likelihood through [`MultinomialLogitLikelihood`], and owns the
60//! class-count / simplex preconditions. The independent-binomial sibling
61//! [`crate::binomial_multi`] is the same engine with a row-diagonal
62//! Fisher block instead.
63
64use crate::custom_family::{
65 BlockwiseFitOptions, ParameterBlockState, PenaltyMatrix, fit_custom_family_with_rho_prior,
66};
67use crate::multinomial_reml::MultinomialFamily;
68use crate::penalized_vector_glm::{PenalizedVectorGlmInputs, fit_penalized_vector_glm};
69use crate::vector_response::{MultinomialLogitLikelihood, validate_multinomial_simplex};
70use gam_terms::inference::formula_dsl::parse_formula;
71use crate::model_types::EstimationError;
72use crate::fit_orchestration::{
73 FitConfig, build_termspec_with_geometry_and_overrides, resolved_resource_policy,
74};
75use gam_terms::smooth::{
76 PenaltyBlockInfo, TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
77};
78use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
79use gam_terms::term_builder::resolve_role_col;
80use gam_problem::ResponseColumnKind;
81use gam_data::ColumnKindTag;
82use gam_data::EncodedDataset;
83use gam_runtime::resource::ProblemHints;
84use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3};
85use serde::{Deserialize, Serialize};
86use std::sync::Arc;
87
/// Numerical stabilization floor for the inner linear solve of the
/// formula-driven multinomial REML fit (gam#747).
///
/// The floor is installed through
/// [`RidgePolicy::solver_only`](gam_problem::RidgePolicy::solver_only). It
/// therefore conditions the inner joint-Newton **linear solve** only; it is
/// excluded from the REML objective, from the penalty log-determinant, and
/// from the Laplace Hessian.
///
/// # Why a floor is needed
///
/// Multinomial smoothing penalties are rank-deficient by design: every smooth
/// keeps an unpenalized polynomial null space, and the formula may contribute
/// a fully unpenalized parametric term (`x3` / `body_mass`). With
/// near-separable hard labels the softmax curvature is ill-conditioned along
/// exactly those directions, which makes the bare Newton step `H⁻¹∇` huge.
/// Raising the smallest Hessian eigenvalue to `δ` bounds the step
/// (`‖(H+δI)⁻¹∇‖ ≤ ‖∇‖/δ`), so the screening iterates stay finite and the
/// softmax is never poisoned by `inf − inf = NaN`.
///
/// # What it deliberately does NOT do
///
/// No `½·δ·‖β‖²` term is added to the objective, and the REML
/// log-determinant receives no `δ`-shift. The earlier
/// `explicit_stabilization_pospart` policy folded both into the criterion.
/// That turned `1e-4` into a fixed-λ Gaussian prior which shrank every
/// identified coefficient off the MLE and biased smoothing-parameter
/// selection — a value that had to be tuned *between* under-stabilization
/// (NaN seeds) and over-shrinkage (lost VGAM match). With a solver-only floor
/// that tradeoff disappears: over-shrinkage cannot happen (nothing is
/// shrunk), the optimized objective is the true penalized REML criterion, and
/// the floor only needs to be large enough to keep the linear algebra finite.
///
/// # Separation is handled elsewhere
///
/// The separation defect (#753) is no longer this floor's job. When the
/// multinomial MLE is genuinely at infinity along an unpenalized/null-space
/// direction (complete/quasi-complete separation), no solver floor can make
/// that direction's estimate finite. The formula REML path arms the
/// full-span Jeffreys/Firth correction CONDITIONALLY — only on separation
/// evidence (see [`multinomial_formula_separation_evidence`] and the
/// two-attempt logic in [`fit_penalized_multinomial_formula`]). An interior,
/// well-identified fit therefore optimizes the unbiased penalized-REML
/// criterion with no Firth shrinkage toward the uniform simplex, while a
/// (quasi-)separated geometry receives the proper prior that is the only
/// thing able to bound its penalty-null directions (#715 real-data arm). The
/// bare fixed-λ inner driver [`fit_penalized_multinomial`] (no outer REML, no
/// Jeffreys term) surfaces the explicit `MultinomialSeparationDetected`
/// diagnostic for the path that has no proper prior to lean on.
const MULTINOMIAL_FORMULA_RIDGE_FLOOR: f64 = 1e-4;
130
/// KKT tolerance for the inner joint-Newton solve on the multinomial formula
/// path.
///
/// On saturated rows the softmax Fisher weight `W = diag(p) − ppᵀ` collapses,
/// so near-separable fits (penguins, #715) hit the f64 noise floor of the
/// OBJECTIVE before they can meet the default `inner_tol = 1e-6` KKT target.
/// Measurements on the penguins arm (standardized columns):
///
/// * the trust region collapses to 1e-12;
/// * per-attempt objective changes are ~+2e-9 on |obj| ≈ 1e2 (≈ 1e-11
///   relative — pure rounding);
/// * the KKT residual plateaus at 2.8e-5–9.4e-5 against a scaled tolerance of
///   ~1.9e-5.
///
/// A residual target below the floating-point noise floor can never be
/// certified: the stall guard rejects every eval and the whole fit fails.
/// `1e-5` certifies the measured plateaus while still resolving β to ~1e-6 in
/// the relevant metric — the LAML criterion consumes β̂ with error
/// O(residual²/curvature), far below any quantity the outer ρ-search can
/// read.
const MULTINOMIAL_FORMULA_INNER_TOL: f64 = 1e-5;
146
/// Penalty calibration factor applied by the formula adapter for multinomial
/// softmax REML.
///
/// The term builder normalizes penalties against single-response
/// Gaussian-style score curvature. A reference-coded softmax class block
/// instead sees the per-row active-class Fisher diagonal `p_a(1-p_a)` together
/// with negative cross-class coupling. At the neutral simplex (`p_k = 1/K`)
/// that diagonal equals `(K-1)/K²`, which gives the calibration
/// `2·(K-1)/K²`: `1/2` for binary logit and `4/9` for three classes (the
/// historical hard-coded value was `1/2`). Tying the scale to `K` keeps the
/// physical smoothness prior matched to the likelihood curvature instead of
/// over-penalizing every class as the simplex gains categories.
///
/// `n_classes` below 2 is evaluated as if it were 2, so the result is always
/// finite and positive.
fn multinomial_formula_penalty_scale(n_classes: usize) -> f64 {
    // Guard against K < 2, for which the softmax calibration is undefined.
    let class_count = n_classes.max(2) as f64;
    let numerator = 2.0 * (class_count - 1.0);
    let denominator = class_count * class_count;
    numerator / denominator
}
162
/// Largest smoothing-parameter dimension where exact dense outer curvature is
/// still worth paying for multinomial formula fits. The bound is inclusive
/// (see [`multinomial_formula_use_outer_hessian`]).
///
/// `D = (K - 1) * n_penalties`. Medium-size loaded models use exact curvature
/// so the optimizer does not wander into over-smoothed lambda caps on
/// near-boundary softmax surfaces. The threshold was originally calibrated at
/// `D <= 6` when each `s()` term carried ONE penalty; the double-penalty
/// migration (wiggliness + null-space shrinkage per term, mgcv `select=TRUE`
/// semantics) doubled `D` for the SAME models, silently flipping the
/// reference formula fits (2 smooths, K = 3: old `D = 4`, now `D = 8`) onto
/// the gradient-only route — where the #715 quality arm showed every
/// wiggliness ρ driven onto the ±10 box bound (smooths collapsed toward their
/// polynomial null space, truth-RMSE behind VGAM). A limit of `12 = 2 × 6`
/// would preserve the original classification boundary under the doubled
/// penalty count; the limit is `16` so that the four-smooth penguin species
/// quality fixture also stays on the exact ARC path: that model is `D = 16`,
/// and first-order BFGS can cycle along the near-separable lambda-to-zero
/// ridge until the wall-clock budget expires (#1082). ARC observes the same
/// exact curvature and can halt through the bound-aware cost-stall guard once
/// the REML surface stops making useful progress.
///
/// NOTE(review): the previous wording of this comment motivated `12` while the
/// value below is `16`; the `D = 16` fixture requirement stated above is only
/// satisfiable with a limit of at least 16 — confirm this is the intended
/// rationale.
const MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM: usize = 16;
184
185fn multinomial_formula_use_outer_hessian(total_rho_dim: usize) -> bool {
186 total_rho_dim <= MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM
187}
188
/// Logit magnitude at which fitted probabilities count as saturated, at
/// ordinary double-precision diagnostic scale.
///
/// * The bare fixed-λ driver has no outer REML state; it still uses this
///   threshold to reject a non-converged saturated iterate as a separation
///   artifact.
/// * The formula REML path does not use it as a Firth trigger for a converged
///   fit: with smoothing parameters selected, a finite saturated surface can
///   be the valid near-separated optimum that should be scored directly. The
///   only formula-path check against this threshold in this part of the file
///   is [`multinomial_formula_unresolved_probe_separation_evidence`], which
///   applies it to a non-converged capped probe.
const MULTINOMIAL_SEPARATION_ETA_THRESHOLD: f64 = 25.0;
197
/// Convergence tolerance for the OUTER REML/LAML smoothing-parameter search
/// on the formula multinomial path.
///
/// The value matches the primary GLM REML outer loop
/// (`solver::fit_orchestration::materialize` uses `tol = 1e-7`, mirrored by
/// the `LOG_LAMBDA_TOL` / `KKT_TOL_*` constants across the REML stack). It is
///
/// * tight enough that the selected λ reaches the genuine REML optimum (the
///   recovered probability surface matches the mature reference), and
/// * loose enough that the optimizer does not grind surface-irrelevant ρ
///   digits down to the inner KKT scale (the #1082 wall-clock overrun).
///
/// The caller's `tol` is floored at this value for the OUTER loop only; it
/// continues to drive the INNER joint-Newton KKT target unchanged.
const MULTINOMIAL_OUTER_REML_TOL: f64 = 1.0e-7;
209
/// Outer-iteration budget for the first (unbiased) multinomial formula solve.
///
/// That first solve acts as a separation probe: it is accepted when the
/// unbiased REML criterion converges to a finite interior iterate. Without a
/// cap, near-separable data such as the penguin fixture spend the caller's
/// full outer budget on an iterate that is discarded before the
/// Firth/Jeffreys refit. The budget is large enough for ordinary interior
/// fits to certify quickly, and small enough that slow/non-interior probes
/// are handed to the proper-prior refit promptly.
const MULTINOMIAL_UNBIASED_PROBE_OUTER_MAX_ITER: usize = 20;
217
/// Unit of the λ-floor: the maximal softmax Fisher information contributed by
/// one observation.
///
/// The penalty enters the criterion as `½ λ βᵀ S β` with a
/// Frobenius-normalized `S` (`‖S‖_F = 1`, see the term-builder calibration
/// referenced by [`multinomial_formula_penalty_scale`]), which makes the ridge
/// `λ S` directly comparable to data Fisher information. A single observation
/// contributes softmax information `p(1−p)` in a class's logit direction, and
/// that is bounded by the logistic peak `p(1−p) ≤ ¼` at `p = ½`.
///
/// Measuring the floor in this unit lets its strength be read as a count of
/// equivalent **pseudo-observations** of prior: a ridge equal to
/// `τ · ¼ · ‖S‖_F` carries the same logit-direction curvature as `τ` real
/// rows sitting at the most-informative point of the likelihood.
///
/// The scale is `K`-independent on purpose. The `K`-dependence of the softmax
/// block curvature already lives in the penalty matrix via
/// [`multinomial_formula_penalty_scale`], so the floor (a bound on the
/// multiplier of that already-scaled penalty) must not double-count it.
const MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS: f64 = 0.25;
235
/// Prior strength `τ` of the λ-floor for a WELL-SUPPORTED class, measured in
/// pseudo-observations.
///
/// The floor keeps the unbiased REML optimizer away from the zero-penalty
/// boundary (where a boundary-overfit smooth or a Firth switch on finite data
/// would otherwise be accepted) using a prior worth a fixed small fraction of
/// one observation. `8e-4` pseudo-observations reproduces the previously
/// fixture-calibrated large-support floor `τ · ¼ = 2e-4` exactly at the
/// calibration point; the number is now stated as an effective prior strength
/// rather than as a tuned λ value.
const MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS: f64 = 8e-4;
245
/// Reference class support `n_ref` for the λ-floor.
///
/// `n_ref` is the effective per-class sample size at which the data Fisher
/// information `n_c · I₁` is large enough for the floor to sit at its
/// well-supported value. Below `n_ref` the per-class data information shrinks
/// like `n_c`; to stop the floor's prior from vanishing *relative to* that
/// shrinking data, the effective pseudo-observation count is scaled up by
/// `n_ref / n_c` (the prior is held to a fixed fraction of the data
/// information, not to a fixed absolute λ). At `n_c = n_ref` the scale is
/// exactly 1.
const MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT: f64 = 50.0;
254
/// Upper bound `τ_max` on the λ-floor's prior strength in the very-sparse
/// limit, measured in pseudo-observations.
///
/// The `n_ref / n_c` scaling diverges as `n_c → 0`. This cap holds the prior
/// at `4e-3` pseudo-observations (`τ_max · ¼ = 1e-3` at the calibration
/// point, the previously-tuned strong-floor value), so the floor remains a
/// proper prior instead of becoming a hard constraint that would dominate the
/// likelihood for a handful-of-rows class.
const MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX: f64 = 4e-3;
262
263/// Continuous, Fisher-information-scaled lower λ floor for the formula path,
264/// derived from the minority class's effective sample size `n_c`.
265///
266/// # Derivation (effective-prior-strength / Fisher geometry)
267///
268/// The penalty `½ λ βᵀ S β` with `‖S‖_F = 1` adds curvature `λ` to the class
269/// logit direction; one observation adds at most `I₁ = ¼` there. So a floor that
270/// sets `λ_floor = τ_eff · I₁` gives the smooth a prior worth `τ_eff`
271/// pseudo-observations. We want a fixed *absolute* prior `τ` for a well-supported
272/// class, but for a minority class with only `n_c` effective observations the
273/// data information in its block is `n_c · I₁`; holding the prior to a fixed
274/// *fraction* of that shrinking data information requires
275///
276/// ```text
277/// τ_eff(n_c) = τ · max(1, n_ref / n_c), clamped to [τ, τ_max]
278/// λ_floor(n_c) = τ_eff(n_c) · I₁
279/// ```
280///
281/// This is the *same* `base · max(1, c0/c)` envelope as before — but `base`,
282/// `sparse`, and `c0` are no longer fixture-tuned magic numbers: `base = τ·I₁`,
283/// `sparse = τ_max·I₁`, and `c0 = n_ref` are an effective-prior-strength of
284/// `τ`/`τ_max` pseudo-observations against the maximal per-observation softmax
285/// information `I₁ = ¼`. Properties preserved by construction:
286/// * reduces EXACTLY to `τ·I₁` for well-supported classes (`n_c ≥ n_ref`);
287/// * reduces EXACTLY to `τ_max·I₁` for very sparse classes
288/// (`n_c ≤ n_ref·τ/τ_max`, here `n_c ≤ 10`);
289/// * interpolates monotonically and continuously between them in the middle —
290/// no cliff at `n_c = n_ref`.
291/// At the calibration point the endpoints equal the previous `2e-4` / `1e-3`, so
292/// fixtures whose smallest class has `n_c ≥ 50` (penguins, the vgam softmax
293/// arms) are unaffected — they sit at `τ·I₁ = 2e-4` exactly as before.
294fn multinomial_formula_min_lambda(y_one_hot: ArrayView2<'_, f64>) -> f64 {
295 let base = MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
296 let sparse =
297 MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
298 let min_class_count = (0..y_one_hot.ncols())
299 .map(|class| y_one_hot.column(class).sum())
300 .fold(f64::INFINITY, f64::min);
301 if !min_class_count.is_finite() || min_class_count <= 0.0 {
302 return base;
303 }
304 // Effective pseudo-observation prior strength: held to a fixed fraction of
305 // the shrinking per-class data information once n_c falls below n_ref.
306 let pseudo_obs_scale =
307 (MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / min_class_count).max(1.0);
308 (base * pseudo_obs_scale).clamp(base, sparse)
309}
310
311fn max_abs_eta_location(eta: ArrayView2<'_, f64>) -> (f64, usize, usize) {
312 let mut best = (0.0_f64, 0usize, 0usize);
313 for ((row, active_class), &value) in eta.indexed_iter() {
314 let abs = value.abs();
315 if abs > best.0 {
316 best = (abs, row, active_class);
317 }
318 }
319 best
320}
321
322/// Separation gate for the REML/LAML **formula** path.
323///
324/// Unlike the bare fixed-λ driver [`fit_penalized_multinomial`] (which has no
325/// outer REML state and so must reject a saturated, non-converged iterate as a
326/// separation artifact at the [`MULTINOMIAL_SEPARATION_ETA_THRESHOLD`] logit
327/// magnitude), the formula path can return a finite saturated mode after the
328/// coupled outer optimizer has selected smoothing parameters. A `|η| >= 25`
329/// gate is therefore wrong here: the penguins arm can legitimately have large
330/// fitted logits while still producing finite probabilities and a usable REML
331/// mode.
332///
333/// Only a genuinely NON-FINITE `η` (a NaN/Inf blow-up in the inner linear
334/// algebra) is a real formula-path failure. A finite, even saturated, `η` is
335/// accepted so the truth-recovery / match-or-beat bars are evaluated against the
336/// actual fitted surface instead of an adapter diagnostic.
337fn multinomial_formula_separation_diagnostic(
338 inner_cycles: usize,
339 outer_iterations: usize,
340 block_states: &[ParameterBlockState],
341) -> Option<EstimationError> {
342 let mut nonfinite: Option<(f64, usize, usize)> = None;
343 for (active_class, state) in block_states.iter().enumerate() {
344 for (row, &value) in state.eta.iter().enumerate() {
345 if !value.is_finite() {
346 nonfinite = Some((value, row, active_class));
347 break;
348 }
349 }
350 if nonfinite.is_some() {
351 break;
352 }
353 }
354 nonfinite.map(|(value, row_index, active_class_index)| {
355 EstimationError::MultinomialSeparationDetected {
356 iteration: inner_cycles.max(outer_iterations),
357 max_abs_eta: value.abs(),
358 active_class_index,
359 row_index,
360 }
361 })
362}
363
364/// Separation EVIDENCE gate for the conditional Firth/Jeffreys engagement on
365/// the formula REML path (#715 / #753).
366///
367/// The structural mathematics (#715 issue thread): for any coefficient
368/// direction `v` with `S v = 0` (a penalty-null direction — intercept, a
369/// smooth's polynomial null component, an unpenalized parametric term), the
370/// penalized joint Hessian satisfies `(H + S_λ) v = H v` for EVERY smoothing
371/// parameter ρ. When the data (quasi-)separate, the softmax Fisher weight
372/// `W = diag(p) − p pᵀ → 0` on the saturated rows, so `H v = JᵀWJ v → 0` along
373/// the penalty-null directions those rows support: `(H + S_λ) v ≈ 0` for every
374/// ρ — NO λ can repair it, the inner Newton can never certify a KKT point
375/// there, and every outer REML startup seed is rejected (the penguins
376/// real-data arm). The only principled cure is a PROPER prior on that
377/// quotient-null subspace — the Jeffreys/Firth term `Φ = ½ log|ZᵀHZ|`, whose
378/// Gauss–Newton curvature supplies the missing `O(1)` bound.
379///
380/// But the Firth prior is not free on interior data: unconditionally armed, it
381/// shrinks fitted class probabilities toward the uniform simplex `1/K`
382/// (an `O(1/n)` pull that the synthetic match-or-beat arm of #715 measured as
383/// a real truth-RMSE loss vs the unbiased criterion). So the formula path
384/// engages it ONLY on separation evidence, mirroring the #753 "diagnose, then
385/// arm" split:
386///
387/// * a NON-FINITE logit — the inner linear algebra blew up along an unbounded
388/// direction.
389///
390/// Returns `Some(description)` naming the witnessing logit when evidence is
391/// found, `None` for a finite fit (which is then accepted as-is, with zero
392/// Firth bias). A FAILED unbiased solve (`Err` from the rho-prior driver, e.g.
393/// "no startup seed passed") is the second evidence form and is handled
394/// directly at the call site in [`fit_penalized_multinomial_formula`].
395fn multinomial_formula_separation_evidence(block_states: &[ParameterBlockState]) -> Option<String> {
396 for (active_class, state) in block_states.iter().enumerate() {
397 for (row, &value) in state.eta.iter().enumerate() {
398 if !value.is_finite() {
399 return Some(format!(
400 "non-finite logit eta[row {row}, active class {active_class}] = {value}"
401 ));
402 }
403 }
404 }
405 None
406}
407
408/// Extra evidence used only for a NON-CONVERGED capped unbiased probe.
409///
410/// A converged finite saturated formula fit is still a valid optimum and must be
411/// scored without Firth bias. A capped probe that failed to converge while it
412/// already carries separation-scale logits is different: spending the full
413/// unbiased outer budget on the same lambda-to-zero surface is the #1082
414/// timeout. Route that case straight to the proper-prior refit.
415fn multinomial_formula_unresolved_probe_separation_evidence(
416 block_states: &[ParameterBlockState],
417) -> Option<String> {
418 if let Some(evidence) = multinomial_formula_separation_evidence(block_states) {
419 return Some(evidence);
420 }
421
422 let mut best = (0.0_f64, 0usize, 0usize);
423 for (active_class, state) in block_states.iter().enumerate() {
424 for (row, &value) in state.eta.iter().enumerate() {
425 let abs = value.abs();
426 if abs > best.0 {
427 best = (abs, row, active_class);
428 }
429 }
430 }
431 if best.0 >= MULTINOMIAL_SEPARATION_ETA_THRESHOLD {
432 Some(format!(
433 "separation-scale finite logit |eta[row {}, active class {}]| = {:.3e} \
434 after capped unbiased probe",
435 best.1, best.2, best.0
436 ))
437 } else {
438 None
439 }
440}
441
/// Inputs to [`fit_penalized_multinomial`].
///
/// The penalty matrix `S` is shared across classes; per-class smoothing
/// parameters `lambdas` (length `K - 1`) scale `S` independently for each
/// active class. The full block-replicated penalty is `diag_a(λ_a) ⊗ S`,
/// which is exactly what [`gam_solve::arrow_schur::KroneckerPenaltyOp`]
/// expresses in matrix-free form when this driver is later lifted into the
/// arrow-Schur loop.
///
/// All array fields are borrowed views with lifetime `'a`; the struct owns no
/// data. [`fit_penalized_multinomial`] validates the class-count, length,
/// shape, row-weight, and simplex preconditions noted on the fields below
/// before solving.
#[derive(Debug, Clone)]
pub struct MultinomialFitInputs<'a> {
    /// Design matrix `X ∈ ℝ^{N×P}` (one row per observation).
    pub design: ArrayView2<'a, f64>,
    /// Categorical response `Y ∈ ℝ^{N×K}`. Each row must be a point on the
    /// probability simplex (`y_c ≥ 0`, `Σ_c y_c = 1`): a one-hot indicator for
    /// hard classification, or a label-smoothed probability vector. Rows whose
    /// mass departs from 1 are rejected — the softmax residual gradient and
    /// Fisher block are the derivatives of `Σ_c y_c log p_c` only under the
    /// simplex constraint (see `validate_multinomial_simplex`).
    ///
    /// The row count must equal the row count of `design`, and there must be
    /// at least two columns (`K ≥ 2`); the solver rejects anything else.
    pub y_one_hot: ArrayView2<'a, f64>,
    /// Shared smoothing penalty `S ∈ ℝ^{P×P}` (symmetric, PSD).
    pub penalty: ArrayView2<'a, f64>,
    /// Per-active-class smoothing parameter `λ_a` (length `K - 1`). A length
    /// other than `K - 1` is rejected by the solver.
    pub lambdas: ArrayView1<'a, f64>,
    /// Optional per-row weights (length `N`); `None` ⇒ uniform 1.0. When
    /// supplied, the length must be `N` and every weight must be finite and
    /// `≥ 0`, otherwise the solver returns an error.
    pub row_weights: Option<ArrayView1<'a, f64>>,
    /// Optional per-row Fisher-block override, shape `(N, K-1, K-1)` in the
    /// active-class gauge (the reference class `K-1` is dropped). When `Some`,
    /// each Newton step uses this block as the curvature `W` in place of the
    /// analytic softmax Fisher `w_n (δ_ab p_a − p_a p_b)`; the gradient/residual
    /// path stays analytic, so this is a curvature-only override (the
    /// research escape-hatch for latent multinomial fits, issue #349). Each
    /// per-row block must be symmetric, PSD, and finite — preconditions the
    /// FFI boundary discharges before constructing this view. The solver
    /// itself checks only the `(N, K-1, K-1)` shape in the validation code
    /// visible alongside this struct.
    pub fisher_w_override: Option<ArrayView3<'a, f64>>,
    /// Maximum Newton iterations; recommend 50.
    pub max_iter: usize,
    /// Relative-step convergence tolerance; recommend 1e-7.
    pub tol: f64,
}
481
/// Outputs of [`fit_penalized_multinomial`].
///
/// Throughout, `K` is the number of classes, `M = K − 1` the number of active
/// (non-reference) classes, `P` the number of design columns, and `N` the
/// number of fitted rows.
#[derive(Debug, Clone)]
pub struct MultinomialFitOutputs {
    /// Active-class coefficient block, shape `(P, K-1)` (column `a` is `β_a`).
    /// The reference class `K - 1` has `β_{K-1} ≡ 0` by construction and is
    /// not stored.
    pub coefficients_active: Array2<f64>,
    /// Fitted probabilities, shape `(N, K)`.
    pub fitted_probabilities: Array2<f64>,
    /// Number of Newton iterations executed (including the final step that
    /// satisfied the tolerance).
    pub iterations: usize,
    /// `true` if the relative-step test was satisfied; `false` if the
    /// solver exhausted `max_iter`. (A non-converged solve is still
    /// returned; the caller decides whether to escalate.)
    pub converged: bool,
    /// Penalized negative log-likelihood at the returned `β̂`:
    /// `−log L(β̂) + ½ Σ_a λ_a · β̂_a^T S β̂_a`.
    pub penalized_neg_log_likelihood: f64,
    /// Unpenalized deviance `−2 log L(β̂)` for diagnostic reporting.
    pub deviance: f64,
    /// Joint Laplace posterior coefficient covariance `H⁻¹` at the converged
    /// `β̂`, shape `(P·(K−1))×(P·(K−1))` (#1101). Block-ordered to match the
    /// stacked active-class coefficient vector `β = [β_0; …; β_{K-2}]`: active
    /// class `a`'s `P` coefficients occupy rows/cols `a·P .. (a+1)·P`, indexed
    /// `θ[a·P + i] = β̂[i, a]`. This is the Laplace covariance from the factored
    /// penalized Hessian `XᵀWX + diag_a(λ_a)⊗S`; it drives the delta-method
    /// per-class probability standard errors ([`Self::predict_probabilities_with_se`])
    /// on the fixed-λ inner-solve path. That method rejects a covariance whose
    /// shape is not `(P·M, P·M)` for the `P`, `M` implied by
    /// [`Self::coefficients_active`].
    pub coefficient_covariance: Array2<f64>,
}
513
514impl MultinomialFitOutputs {
515 /// Number of active classes `M = K − 1` (columns of
516 /// [`Self::coefficients_active`]).
517 pub fn n_active_classes(&self) -> usize {
518 self.coefficients_active.ncols()
519 }
520
521 /// Per-class coefficient dimension `P` (rows of
522 /// [`Self::coefficients_active`]).
523 pub fn p_per_class(&self) -> usize {
524 self.coefficients_active.nrows()
525 }
526
527 /// Evaluate `softmax(X·β̂)` AND its delta-method per-class probability
528 /// standard error at fresh design rows `X_new` (#1101), using the joint
529 /// Laplace covariance [`Self::coefficient_covariance`].
530 ///
531 /// The softmax Jacobian is `∂p_c/∂η_b = p_c (δ_{cb} − p_b)` for active class
532 /// `b ∈ 0..M`, and `∂η_b/∂β[i,a] = X[i]·δ_{ab}`, so the gradient of the
533 /// class-`c` probability w.r.t. the block-ordered coefficient vector is
534 /// `g_c[a·P + i] = X[i]·p_c (δ_{ca} − p_a)` (the reference class `M`
535 /// contributes only through `−p_a` in every active block). The delta-method
536 /// variance is `Var(p_c) = g_cᵀ Σ g_c` with `Σ = H⁻¹`, and
537 /// `SE(p_c) = √Var(p_c)`. Returns `(probs (N,K), prob_se (N,K))`. `X_new`
538 /// must have `P` columns (the same design basis used at fit time); its row
539 /// count sets `N`. The SE is unclamped (the interval consumer applies the
540 /// simplex `[0,1]` clamp).
541 pub fn predict_probabilities_with_se(
542 &self,
543 x_new: ArrayView2<'_, f64>,
544 ) -> Result<(Array2<f64>, Array2<f64>), EstimationError> {
545 let p = self.p_per_class();
546 let m = self.n_active_classes();
547 let k = m + 1;
548 if x_new.ncols() != p {
549 crate::bail_invalid_estim!(
550 "predict_probabilities_with_se: X has {} cols, expected P={p}",
551 x_new.ncols()
552 );
553 }
554 let d = p * m;
555 let cov = &self.coefficient_covariance;
556 if cov.dim() != (d, d) {
557 crate::bail_invalid_estim!(
558 "predict_probabilities_with_se: covariance shape {:?} ≠ (P·M, P·M) = ({d}, {d})",
559 cov.dim()
560 );
561 }
562 let n_new = x_new.nrows();
563 let beta = &self.coefficients_active;
564 let mut probs = Array2::<f64>::zeros((n_new, k));
565 let mut prob_se = Array2::<f64>::zeros((n_new, k));
566 let mut eta_active = vec![0.0_f64; m];
567 let mut row_probs = vec![0.0_f64; k];
568 let mut grad = vec![0.0_f64; d];
569 for row in 0..n_new {
570 for a in 0..m {
571 let mut v = 0.0_f64;
572 for i in 0..p {
573 v += x_new[[row, i]] * beta[[i, a]];
574 }
575 eta_active[a] = v;
576 }
577 MultinomialLogitLikelihood::softmax_with_baseline(&eta_active, &mut row_probs);
578 for c in 0..k {
579 probs[[row, c]] = row_probs[c];
580 }
581 for c in 0..k {
582 let pc = row_probs[c];
583 // g_c[a·P + i] = X[i] · p_c · (δ_{ca} − p_a), a active.
584 for a in 0..m {
585 let pa = row_probs[a];
586 let factor = pc * (if c == a { 1.0 - pa } else { -pa });
587 let base = a * p;
588 for i in 0..p {
589 grad[base + i] = x_new[[row, i]] * factor;
590 }
591 }
592 // Var = gᵀ Σ g.
593 let mut var = 0.0_f64;
594 for r in 0..d {
595 let gr = grad[r];
596 if gr == 0.0 {
597 continue;
598 }
599 let mut acc = 0.0_f64;
600 for s in 0..d {
601 acc += cov[[r, s]] * grad[s];
602 }
603 var += gr * acc;
604 }
605 prob_se[[row, c]] = var.max(0.0).sqrt();
606 }
607 }
608 Ok((probs, prob_se))
609 }
610}
611
612/// Fit a penalized multinomial-logit GAM at fixed `λ`.
613///
/// See the module docs for the optimization problem and conventions. This
/// function is the canonical fixed-λ inner solve; smoothing-parameter
/// selection over `ρ = log λ` is layered on top of it by the REML/LAML formula
/// path described in the module docs.
617pub fn fit_penalized_multinomial(
618 inputs: MultinomialFitInputs<'_>,
619) -> Result<MultinomialFitOutputs, EstimationError> {
620 let MultinomialFitInputs {
621 design,
622 y_one_hot,
623 penalty,
624 lambdas,
625 row_weights,
626 fisher_w_override,
627 max_iter,
628 tol,
629 } = inputs;
630
631 // ──────────────────────── family-specific validation ───────────────────
632 // The shared engine re-validates the geometry common to every vector-GLM
633 // (nonempty design, penalty shape, λ finiteness/non-negativity, override
634 // `(N, M, M)` shape, finite design). The multinomial family owns the
635 // class-count contract (`K ≥ 2`, λ length `K − 1`), the per-row simplex
636 // precondition under which the softmax residual/Fisher are the exact
637 // derivatives of `Σ_c y_c log p_c`, and the row-weight check the likelihood
638 // adapter consumes.
639 let n_obs = design.nrows();
640 let (y_rows, k) = y_one_hot.dim();
641 if y_rows != n_obs {
642 crate::bail_invalid_estim!(
643 "fit_penalized_multinomial: y rows {y_rows} ≠ design rows {n_obs}"
644 );
645 }
646 if k < 2 {
647 crate::bail_invalid_estim!(
648 "fit_penalized_multinomial: need at least 2 classes (got K={k})"
649 );
650 }
651 let m = k - 1;
652 if lambdas.len() != m {
653 crate::bail_invalid_estim!(
654 "fit_penalized_multinomial: lambdas length {} ≠ K-1 = {m}",
655 lambdas.len()
656 );
657 }
658 if let Some(fw) = fisher_w_override.as_ref() {
659 if fw.dim() != (n_obs, m, m) {
660 crate::bail_invalid_estim!(
661 "fit_penalized_multinomial: fisher_w_override shape {:?} ≠ (N, K-1, K-1) = ({n_obs}, {m}, {m})",
662 fw.dim()
663 );
664 }
665 }
666 if let Some(w) = row_weights.as_ref() {
667 if w.len() != n_obs {
668 crate::bail_invalid_estim!(
669 "fit_penalized_multinomial: row_weights length {} ≠ N = {n_obs}",
670 w.len()
671 );
672 }
673 for (i, &v) in w.iter().enumerate() {
674 if !(v.is_finite() && v >= 0.0) {
675 crate::bail_invalid_estim!(
676 "fit_penalized_multinomial: row_weights[{i}] must be finite and ≥ 0 (got {v})"
677 );
678 }
679 }
680 }
681 validate_multinomial_simplex(y_one_hot, "fit_penalized_multinomial")?;
682
683 // ────────────────────────── likelihood construction ───────────────────
684 let mut likelihood = MultinomialLogitLikelihood::with_classes(k)?;
685 if let Some(w) = row_weights.as_ref() {
686 likelihood = likelihood.with_row_weights(w.to_owned())?;
687 }
688
689 // ─────────────────── shared penalized vector-GLM solve ─────────────────
690 // The softmax Fisher block is dense across the `M = K − 1` active classes;
691 // the engine assembles the coupled `(P·M)×(P·M)` penalized Hessian, runs
692 // the damped Newton loop, and returns the converged `β̂` and `η = X β̂`.
693 let fit = fit_penalized_vector_glm(
694 PenalizedVectorGlmInputs {
695 design,
696 y: y_one_hot,
697 penalty,
698 lambdas,
699 fisher_w_override,
700 max_iter,
701 tol,
702 // #1587: production multinomial still uses the per-class Diagonal
703 // metric pending the REML per-class→per-term λ re-key that the
704 // reference-symmetric Centered metric requires (shared λ). The
705 // Centered engine path + its invariance proof land first.
706 class_penalty_metric: crate::penalized_vector_glm::ClassPenaltyMetric::Diagonal,
707 },
708 &likelihood,
709 "fit_penalized_multinomial",
710 )?;
711
712 let (max_abs_eta, row_index, active_class_index) = max_abs_eta_location(fit.eta.view());
713 if !fit.converged && max_abs_eta >= MULTINOMIAL_SEPARATION_ETA_THRESHOLD {
714 // Perfect / quasi-perfect separation (#1854): the UNBIASED softmax MLE is
715 // not finite along `active_class_index`'s saturated logit direction, so
716 // the fixed-λ Newton above ran away (`|η| ≥ 25`, no convergence). A
717 // penalty-null direction `v` (`S v = 0`, e.g. an unpenalized intercept /
718 // linear-covariate column) under softmax saturation has
719 // `(XᵀWX + λS) v → 0` for EVERY λ, so no smoothing parameter can bound it
720 // — only a proper prior on that quotient-null subspace can. Rather than
721 // hard-erroring, engage the Firth/Jeffreys proper prior automatically
722 // (magic-by-default): the full-span `½ log|I(β)|` correction supplies the
723 // `O(1)` curvature that keeps the estimate finite on exactly those
724 // separated directions while leaving well-identified fits untouched. This
725 // reuses the same coupled joint-Newton Jeffreys machinery the formula
726 // REML path arms on separation evidence (see
727 // `fit_penalized_multinomial_formula`), only here at the caller's fixed λ.
728 // Engage the fallback, but never let an internal consistency panic in
729 // the coupled joint-Newton assembly (e.g. the #1395 logdet-collapse
730 // guard) escape as a process abort: convert any panic into the
731 // documented hard separation diagnostic, exactly as if the refit had
732 // returned Err. This mirrors the catch_unwind panic-to-typed-error
733 // boundary already used around the faer / cudarc entry points, and keeps
734 // the separation path no worse than the pre-#1854 clean error while the
735 // Firth refit is still being hardened.
736 let firth = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
737 fit_penalized_multinomial_firth_fallback(
738 design,
739 y_one_hot,
740 penalty,
741 lambdas,
742 row_weights,
743 max_iter,
744 tol,
745 )
746 }));
747 match firth {
748 Ok(Ok(out)) => return Ok(out),
749 // Firth refit errored, or an internal consistency guard panicked:
750 // fall back to the explicit hard separation diagnostic.
751 Ok(Err(_)) | Err(_) => {
752 return Err(EstimationError::MultinomialSeparationDetected {
753 iteration: fit.iterations,
754 max_abs_eta,
755 active_class_index,
756 row_index,
757 });
758 }
759 }
760 }
761
762 let fitted_probabilities = likelihood.probabilities(fit.eta.view());
763
764 Ok(MultinomialFitOutputs {
765 coefficients_active: fit.coefficients,
766 fitted_probabilities,
767 iterations: fit.iterations,
768 converged: fit.converged,
769 penalized_neg_log_likelihood: -fit.log_likelihood + fit.penalty_term,
770 deviance: -2.0 * fit.log_likelihood,
771 coefficient_covariance: fit.coefficient_covariance,
772 })
773}
774
775/// Firth/Jeffreys-penalized multinomial refit engaged automatically when the
776/// unbiased softmax MLE separates (#1854).
777///
778/// The unbiased fixed-λ solve ([`fit_penalized_multinomial`]) runs away on
779/// (quasi-)separated data because the softmax likelihood has no finite mode along
780/// the saturated logit direction and the smoothing penalty `S` cannot bound a
781/// penalty-null direction (`S v = 0` ⇒ `(XᵀWX + λS) v → 0` for every λ). This
782/// refit arms the full-span Jeffreys/Firth proper prior `½ log|I(β)|` on the
783/// coupled joint softmax information, which supplies the `O(1)` curvature that
784/// bounds exactly those directions and keeps the estimate finite.
785///
786/// It reuses the SAME coupled joint-Newton REML machinery the formula multinomial
787/// entry arms on separation evidence (`run_firth_refit` in
788/// [`fit_penalized_multinomial_formula`]): a [`MultinomialFamily`] carrying the
789/// single shared penalty `S` as one smooth term, with the joint Jeffreys/Firth
790/// term armed and the reference-symmetric ("centered") joint penalty as the
791/// smoothing carrier. The caller's per-active-class `lambdas` (the centered metric
792/// ties one shared `λ` across classes; every wine / penguin-style separation
793/// arrives with a single shared smooth λ, for which this is exact) seed the outer
794/// ρ; the outer REML loop then selects the smoothing parameter around that seed.
795/// The Firth curvature — not the smoothing λ — is what keeps the separated
796/// unpenalized directions finite, so the exact λ selection is immaterial to
797/// finiteness here.
798fn fit_penalized_multinomial_firth_fallback(
799 design: ArrayView2<'_, f64>,
800 y_one_hot: ArrayView2<'_, f64>,
801 penalty: ArrayView2<'_, f64>,
802 lambdas: ArrayView1<'_, f64>,
803 row_weights: Option<ArrayView1<'_, f64>>,
804 max_iter: usize,
805 tol: f64,
806) -> Result<MultinomialFitOutputs, EstimationError> {
807 let n_obs = design.nrows();
808 let p = design.ncols();
809 let k = y_one_hot.ncols();
810 let m = k - 1;
811
812 // Local softmax likelihood mirroring the caller's row weights, used only to
813 // map the fitted η back to probabilities for the returned outputs.
814 let mut likelihood = MultinomialLogitLikelihood::with_classes(k)?;
815 if let Some(w) = row_weights.as_ref() {
816 likelihood = likelihood.with_row_weights(w.to_owned())?;
817 }
818
819 // Representative shared smooth λ seeding the centered joint metric. Floor away
820 // from zero so a fully-unpenalized smooth still seeds a finite `log λ`; the
821 // Firth prior — not this λ — is what bounds the separated directions.
822 let lambda_repr = {
823 let sum: f64 = lambdas.iter().copied().sum();
824 let mean = if m > 0 { sum / m as f64 } else { 0.0 };
825 mean.max(multinomial_formula_min_lambda(y_one_hot))
826 };
827
828 let weights = match row_weights.as_ref() {
829 Some(w) => w.to_owned(),
830 None => Array1::<f64>::ones(n_obs),
831 };
832 let design_arc = Arc::new(design.to_owned());
833 let penalties_arc = Arc::new(vec![PenaltyMatrix::Dense(penalty.to_owned())]);
834 let nullspace_arc = Arc::new(vec![0usize]);
835
836 let family = MultinomialFamily::new(
837 y_one_hot.to_owned(),
838 weights,
839 k,
840 design_arc,
841 penalties_arc,
842 nullspace_arc,
843 )
844 .map_err(EstimationError::InvalidInput)?
845 // Arm the full-span Firth/Jeffreys correction: this is the whole point of the
846 // fallback — supply the `O(1)` curvature that bounds the separated logits.
847 .with_joint_jeffreys_term(true)
848 .with_initial_log_lambda(lambda_repr.ln());
849 let blocks = family.build_block_specs();
850
851 // Mirror the formula path's Firth refit options: full inner budget so the
852 // ill-conditioned LM-damped joint-Newton can certify its KKT point, calibrated
853 // outer REML tolerance, solver-only stabilization (the ridge never enters the
854 // fitted objective), and no seed screening (the pinned seed flows straight to
855 // the outer optimizer).
856 // One smooth term (the shared penalty `S`) per active class.
857 let total_rho_dim = m;
858 let options = BlockwiseFitOptions {
859 inner_max_cycles: crate::custom_family::DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES,
860 inner_tol: MULTINOMIAL_FORMULA_INNER_TOL.max(tol.max(0.0)),
861 outer_max_iter: max_iter.max(1),
862 outer_tol: if tol.is_finite() && tol > 0.0 {
863 tol.max(MULTINOMIAL_OUTER_REML_TOL)
864 } else {
865 MULTINOMIAL_OUTER_REML_TOL
866 },
867 outer_rel_cost_tol: Some(BlockwiseFitOptions::default().outer_tol),
868 rho_lower_bound: multinomial_formula_min_lambda(y_one_hot).ln(),
869 ridge_floor: MULTINOMIAL_FORMULA_RIDGE_FLOOR,
870 ridge_policy: gam_problem::RidgePolicy::solver_only(),
871 use_outer_hessian: multinomial_formula_use_outer_hessian(total_rho_dim),
872 screen_initial_rho: false,
873 compute_covariance: true,
874 ..BlockwiseFitOptions::default()
875 };
876
877 let fit = fit_custom_family_with_rho_prior(
878 &family,
879 &blocks,
880 &options,
881 gam_problem::RhoPrior::Flat,
882 )
883 .map_err(|err| {
884 EstimationError::InvalidInput(format!(
885 "multinomial Firth/Jeffreys separation fallback (#1854) failed: {err}"
886 ))
887 })?;
888
889 if fit.blocks.len() != m {
890 crate::bail_invalid_estim!(
891 "multinomial Firth fallback: expected {m} fitted blocks (K-1), got {}",
892 fit.blocks.len()
893 );
894 }
895 let mut coefficients_active = Array2::<f64>::zeros((p, m));
896 for (a, block) in fit.blocks.iter().enumerate() {
897 if block.beta.len() != p {
898 crate::bail_invalid_estim!(
899 "multinomial Firth fallback: block {a} has {} coefs, expected {p}",
900 block.beta.len()
901 );
902 }
903 for i in 0..p {
904 let v = block.beta[i];
905 if !v.is_finite() {
906 crate::bail_invalid_estim!(
907 "multinomial Firth fallback: non-finite coefficient β[{i}, {a}] = {v}"
908 );
909 }
910 coefficients_active[[i, a]] = v;
911 }
912 }
913
914 // η = X · β̂, fitted probabilities, and the penalized objective on the direct
915 // path's own (diagonal-metric) reporting convention.
916 let mut eta = Array2::<f64>::zeros((n_obs, m));
917 for a in 0..m {
918 for row in 0..n_obs {
919 let mut acc = 0.0_f64;
920 for i in 0..p {
921 acc += design[[row, i]] * coefficients_active[[i, a]];
922 }
923 eta[[row, a]] = acc;
924 }
925 }
926 let fitted_probabilities = likelihood.probabilities(eta.view());
927
928 let mut log_likelihood = 0.0_f64;
929 for n in 0..n_obs {
930 let w = row_weights.as_ref().map_or(1.0, |rw| rw[n]);
931 for c in 0..k {
932 let ycn = y_one_hot[[n, c]];
933 if ycn != 0.0 {
934 log_likelihood += w * ycn * fitted_probabilities[[n, c]].max(f64::MIN_POSITIVE).ln();
935 }
936 }
937 }
938
939 let mut penalty_term = 0.0_f64;
940 for a in 0..m {
941 let beta_col = coefficients_active.column(a);
942 let mut quad = 0.0_f64;
943 for i in 0..p {
944 let mut s_beta_i = 0.0_f64;
945 for j in 0..p {
946 s_beta_i += penalty[[i, j]] * beta_col[j];
947 }
948 quad += beta_col[i] * s_beta_i;
949 }
950 penalty_term += 0.5 * lambda_repr * quad;
951 }
952
953 let d = p * m;
954 let coefficient_covariance = fit
955 .covariance_conditional
956 .filter(|c| c.dim() == (d, d))
957 .unwrap_or_else(|| Array2::<f64>::zeros((d, d)));
958
959 Ok(MultinomialFitOutputs {
960 coefficients_active,
961 fitted_probabilities,
962 iterations: fit.inner_cycles.max(1),
963 converged: true,
964 penalized_neg_log_likelihood: -log_likelihood + penalty_term,
965 deviance: -2.0 * log_likelihood,
966 coefficient_covariance,
967 })
968}
969
970// ---------------------------------------------------------------------------
971// Formula-driven multinomial pipeline
972// ---------------------------------------------------------------------------
973//
974// Slice A of the multinomial integration: a single public entry that takes
975// a parsed `EncodedDataset`, a Wilkinson-style formula, and a uniform initial
976// smoothing parameter, then runs the full
977//
978// parse → termspec → design (X, S blocks) → one-hot Y → REML λ-selection
979//
980// pipeline. `fit_penalized_multinomial_formula` drives the outer REML/LAML
981// loop (via the custom-family path) to select an independent λ per (class,
982// term); `init_lambda` (default 1.0) is only the warm-start seed for every
983// block. The reference class is the last level of the categorical response
984// column as recorded in the dataset schema.
985
/// Saved-model payload for a multinomial fit driven by a Wilkinson formula.
///
/// This is what the FFI returns to Python. It carries everything the Python
/// `MultinomialModel.predict` path needs to evaluate `softmax(X_new · β)` on
/// fresh data using the *training* basis / penalty structure (no refit on
/// predict, no re-derivation of class levels).
///
/// Matrices are stored as flat `Vec<f64>` buffers next to their dimensions;
/// the accessor methods on this type (`coefficients_active`,
/// `coefficient_covariance`, `coefficient_influence`) rebuild the `ndarray`
/// form. Fields marked `#[serde(default)]` may be missing from a serialized
/// payload, in which case they deserialize to `None` / an empty vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSavedModel {
    /// The training formula, verbatim. Stored so Python's `summary()` and
    /// any round-trip persistence path can echo what was fit.
    pub formula: String,
    /// Names of the *training* response levels in canonical order. The last
    /// entry is the reference class (η = 0); the first `K - 1` carry the
    /// active linear-predictor blocks. Class permutations are forbidden:
    /// this list is fixed at fit time and predictions emit columns in the
    /// same order.
    pub class_levels: Vec<String>,
    /// Index of the reference class within `class_levels` — currently always
    /// `class_levels.len() - 1`, exposed as a field so future "user-pinned
    /// reference" gauges (e.g. `family='multinomial', reference='setosa'`)
    /// can land without changing the on-disk shape.
    pub reference_class_index: usize,
    /// Resolved term-collection spec used to build `X` at fit time. Replayed
    /// on predict via [`gam_terms::smooth::build_term_collection_design`].
    pub resolved_termspec: TermCollectionSpec,
    /// Active-class coefficient block, shape `(P, K-1)`. Column `a` is the
    /// coefficient vector for class `class_levels[a]`. Stored flat in
    /// row-major order to keep the serde payload self-describing, so
    /// `β[i, a]` sits at flat index `i * n_active_classes + a`.
    /// [`Self::coefficients_active`] panics when the length is not
    /// `p_per_class * n_active_classes`.
    pub coefficients_flat: Vec<f64>,
    /// `P` — coefficient count per active class. Matches the column count of
    /// the design matrix the saved `resolved_termspec` produces.
    pub p_per_class: usize,
    /// Number of active classes (`K - 1`).
    pub n_active_classes: usize,
    /// Original training column headers, in dataset-column order. Needed at
    /// predict time so the FFI can align a fresh `Dataset` to the training
    /// schema before evaluating the basis.
    pub training_headers: Vec<String>,
    /// REML/LAML-selected smoothing parameters, one per `(active class, smooth
    /// term)`, flattened in block-major order: all of class 0's per-term λ,
    /// then class 1's, and so on. Per-term penalties (#561) mean each active
    /// class block selects an *independent* λ for every smooth term, so this
    /// vector has length `Σ_a (#terms in class a)` = `(K − 1) · #terms`. Use
    /// [`MultinomialSavedModel::lambdas_per_block`] to segment it by class. An
    /// unpenalized model (no smooth terms) yields an empty vector.
    ///
    /// NOTE(review): the [`Self::lambda_labels`] docs state that one smooth
    /// term can carry several penalty components, each with its own λ. In that
    /// case "term" above presumably means "penalty component" and the length is
    /// `Σ_a lambdas_per_block[a]` — confirm against the fit path.
    pub lambdas: Vec<f64>,
    /// Number of smoothing parameters (smooth terms) in each active class
    /// block, parallel to `class_levels[0..K-1]`. Segments the flat `lambdas`
    /// vector: class `a`'s λ are `lambdas[Σ_{b<a} lambdas_per_block[b] ..][..
    /// lambdas_per_block[a]]`. Every entry is identical in the shared-design
    /// architecture (all classes share the same term structure), but it is
    /// stored explicitly so consumers never have to assume that.
    pub lambdas_per_block: Vec<usize>,
    /// Newton iterations executed; recorded for the summary report.
    pub iterations: usize,
    /// `true` if the inner Newton solver hit the relative-step tolerance.
    pub converged: bool,
    /// Penalized negative log-likelihood at the returned `β̂`.
    pub penalized_neg_log_likelihood: f64,
    /// Unpenalized deviance `−2 log L(β̂)`.
    pub deviance: f64,
    /// Per-active-class effective degrees of freedom (hat-matrix trace),
    /// length `K - 1`. Populated when the REML driver reports an
    /// inference block; falls back to `None` for the legacy fixed-λ path.
    /// Also `None` when the field is absent from a deserialized payload.
    #[serde(default)]
    pub edf_per_class: Option<Vec<f64>>,
    /// Per-PENALTY effective degrees of freedom, one entry per smoothing
    /// parameter (length `== lambdas.len()`), aligned block-major with the flat
    /// [`Self::lambdas`] / [`Self::lambdas_per_block`] layout. Each entry is the
    /// penalty-block trace EDF `rank(S_k) − λ_k·tr(H⁻¹ S_k)`, clamped to
    /// `[0, rank(S_k)]`. This is the per-(class, term, penalty) resolution that
    /// the per-class [`Self::edf_per_class`] SUM deliberately hides: only the
    /// per-penalty vector reveals whether an individual smooth collapsed onto its
    /// polynomial null space (its wiggliness λ driven to the λ-cap), which a
    /// per-class total cannot show. Populated whenever the REML driver reports an
    /// inference block; `None` on the legacy fixed-λ path or when the trace
    /// channel is mis-shaped. Unlike `edf_per_class`, the entries do NOT sum to
    /// the model EDF when several penalties share one coefficient range (a
    /// double-penalty smooth has `Σ_k rank(S_k) > p_per_class`).
    #[serde(default)]
    pub edf_per_penalty: Option<Vec<f64>>,
    /// Joint posterior coefficient covariance `H⁻¹` (#1101), block-ordered to
    /// match the stacked active-class coefficient vector `β = [β_0; …; β_{K-2}]`
    /// (class `a`'s `P` coefficients occupy rows/cols `a·P .. (a+1)·P`). This is
    /// the Laplace covariance the REML driver already computes from the factored
    /// penalized Hessian; storing it gives the predict path delta-method
    /// per-class probability standard errors and the summary its Wald
    /// smooth-term tests. Flattened row-major over the `(P·M)×(P·M)` matrix.
    /// `None` for a model fitted before covariance was surfaced.
    /// [`Self::coefficient_covariance`] also yields `None` when the stored
    /// length does not match `(P·M)²`.
    #[serde(default)]
    pub coefficient_covariance_flat: Option<Vec<f64>>,
    /// Joint coefficient-space influence matrix `F = H⁻¹ X'WX` (#1101),
    /// block-ordered identically to [`Self::coefficient_covariance_flat`].
    /// Its per-term diagonal block trace is the term's effective degrees of
    /// freedom and its `tr(F_jj)²/tr(F_jj²)` the Wood reference d.f., feeding
    /// the rank-truncated Wald smooth-term test in `summary()`. Flattened
    /// row-major over the `(P·M)×(P·M)` matrix. `None` when unavailable.
    /// [`Self::coefficient_influence`] also yields `None` when the stored
    /// length does not match `(P·M)²`.
    #[serde(default)]
    pub coefficient_influence_flat: Option<Vec<f64>>,
    /// Per-(active class, smooth term) coefficient column range and unpenalized
    /// nullspace dimension within the `P`-wide class block (#1101). Parallel to
    /// the smooth terms the design produced; replicated across classes by the
    /// shared-design architecture. Drives the Wald smooth-term table in
    /// `summary()`. Empty for a wholly parametric (no-smooth) model, and for a
    /// payload that lacks the field.
    #[serde(default)]
    pub smooth_term_spans: Vec<MultinomialSmoothTermSpan>,
    /// One descriptive label per *penalty component* within a single active-class
    /// block, parallel to that block's λ slice (i.e. length
    /// `lambdas_per_block[0]`). The Marra–Wood double penalty (and tensor /
    /// operator smooths) emit **more than one** penalty component — hence more
    /// than one λ — per smooth term, so this is NOT 1:1 with
    /// [`Self::smooth_term_spans`]: a single `s(x)` term contributes a primary
    /// wiggliness λ labelled `s(x)` and a null-space shrinkage λ labelled
    /// `s(x) [null space]`. The summary renderer pairs `lambdas` with these
    /// labels component-for-component so no λ is ever dropped (#1544). Built from
    /// the per-component term name + penalty role at fit time; empty for a
    /// wholly parametric model or a model serialized before this field existed.
    #[serde(default)]
    pub lambda_labels: Vec<String>,
}
1106
/// One smooth term's coefficient span within a class block, plus its
/// unpenalized nullspace dimension and a display label (#1101). The Wald
/// smooth-significance test in `summary()` slices the joint covariance /
/// influence at `a·P + col_start .. a·P + col_end` for active class `a`.
///
/// The span is the half-open column range `[col_start, col_end)` relative to
/// the start of a class block. `MultinomialSavedModel::smooth_significance`
/// skips any span whose `col_end` exceeds the per-class width `P`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSmoothTermSpan {
    /// Human-readable term label (the smooth's formula token), for the table.
    pub label: String,
    /// Start column (inclusive) of the term within the per-class `P`-wide
    /// coefficient block.
    pub col_start: usize,
    /// End column (exclusive) of the term within the per-class block.
    pub col_end: usize,
    /// Leading unpenalized (polynomial nullspace) dimension within the term.
    pub nullspace_dim: usize,
}
1122
1123/// Descriptive label for one penalty *component* (one λ) within a class block,
1124/// for the `summary()` per-class λ rollup (#1544). A smooth term can emit
1125/// several penalty components — the Marra–Wood double penalty splits `s(x)`
1126/// into a primary wiggliness penalty and a null-space shrinkage penalty, and
1127/// tensor / operator smooths emit a component per margin / differential
1128/// operator — each with its own independently-selected λ. The label is the
1129/// term name (from `PenaltyBlockInfo::termname`) plus a role suffix derived
1130/// from the penalty's [`PenaltySource`], so each λ in the summary names both
1131/// the term it smooths and the role it plays. `pen_idx` is the global penalty
1132/// index, used only as a last-resort fallback label.
1133fn penalty_component_label(info: Option<&PenaltyBlockInfo>, pen_idx: usize) -> String {
1134 use gam_terms::basis::PenaltySource;
1135 let term = info
1136 .and_then(|i| i.termname.clone())
1137 .unwrap_or_else(|| format!("s{pen_idx}"));
1138 let role = match info.map(|i| &i.penalty.source) {
1139 // The primary wiggliness penalty is the term's "main" λ; show the bare
1140 // term name so the common single-penalty case reads cleanly.
1141 Some(PenaltySource::Primary) | None => None,
1142 Some(PenaltySource::DoublePenaltyNullspace) => Some("null space".to_string()),
1143 Some(PenaltySource::OperatorMass) => Some("mass".to_string()),
1144 Some(PenaltySource::OperatorTension) => Some("tension".to_string()),
1145 Some(PenaltySource::OperatorStiffness) => Some("stiffness".to_string()),
1146 Some(PenaltySource::OperatorRelevance { axis }) => Some(format!("axis {axis}")),
1147 Some(PenaltySource::TensorMarginal { dim }) => Some(format!("margin {dim}")),
1148 Some(PenaltySource::TensorSeparable { penalized_margins }) => {
1149 Some(format!("separable {penalized_margins:?}"))
1150 }
1151 Some(PenaltySource::TensorGlobalRidge) => Some("ridge".to_string()),
1152 Some(PenaltySource::Other(s)) => Some(s.clone()),
1153 };
1154 match role {
1155 Some(role) => format!("{term} [{role}]"),
1156 None => term,
1157 }
1158}
1159
1160impl MultinomialSavedModel {
1161 /// Active-class coefficient block as an `(P, K-1)` `ndarray` view.
1162 pub fn coefficients_active(&self) -> Array2<f64> {
1163 Array2::from_shape_vec(
1164 (self.p_per_class, self.n_active_classes),
1165 self.coefficients_flat.clone(),
1166 )
1167 .expect(
1168 "MultinomialSavedModel.coefficients_flat length must equal p_per_class * n_active_classes",
1169 )
1170 }
1171
1172 /// Evaluate `softmax(X · β)` at fresh data rows. `X_new` must have
1173 /// `self.p_per_class` columns (i.e. it was built from the same
1174 /// `resolved_termspec` as fit time). Returns an `(N_new, K)` matrix
1175 /// with rows summing to 1; column order matches `self.class_levels`.
1176 pub fn predict_probabilities(&self, x_new: ArrayView2<'_, f64>) -> Array2<f64> {
1177 let n_new = x_new.nrows();
1178 let p = self.p_per_class;
1179 let m = self.n_active_classes;
1180 let k = m + 1;
1181 assert_eq!(
1182 x_new.ncols(),
1183 p,
1184 "MultinomialSavedModel.predict_probabilities: X has {} cols, expected {p}",
1185 x_new.ncols()
1186 );
1187 let beta = self.coefficients_active();
1188 let mut probs = Array2::<f64>::zeros((n_new, k));
1189 let mut eta_active = vec![0.0_f64; m];
1190 let mut row_probs = vec![0.0_f64; k];
1191 for row in 0..n_new {
1192 for a in 0..m {
1193 let mut v = 0.0_f64;
1194 for i in 0..p {
1195 v += x_new[[row, i]] * beta[[i, a]];
1196 }
1197 eta_active[a] = v;
1198 }
1199 MultinomialLogitLikelihood::softmax_with_baseline(&eta_active, &mut row_probs);
1200 for c in 0..k {
1201 probs[[row, c]] = row_probs[c];
1202 }
1203 }
1204 probs
1205 }
1206
1207 /// Reconstruct the joint posterior covariance `H⁻¹` as a `(P·M)×(P·M)`
1208 /// `ndarray`, block-ordered to match the stacked coefficient vector
1209 /// `θ[a·P + i] = β[i, a]` (#1101). `None` when the model was fitted before
1210 /// covariance was surfaced (legacy payload).
1211 pub fn coefficient_covariance(&self) -> Option<Array2<f64>> {
1212 let d = self.p_per_class.checked_mul(self.n_active_classes)?;
1213 let flat = self.coefficient_covariance_flat.as_ref()?;
1214 Array2::from_shape_vec((d, d), flat.clone()).ok()
1215 }
1216
1217 /// Reconstruct the joint influence matrix `F = H⁻¹ X'WX` as a
1218 /// `(P·M)×(P·M)` `ndarray`, block-ordered like
1219 /// [`Self::coefficient_covariance`] (#1101). `None` when unavailable.
1220 pub fn coefficient_influence(&self) -> Option<Array2<f64>> {
1221 let d = self.p_per_class.checked_mul(self.n_active_classes)?;
1222 let flat = self.coefficient_influence_flat.as_ref()?;
1223 Array2::from_shape_vec((d, d), flat.clone()).ok()
1224 }
1225
1226 /// Evaluate `softmax(X·β)` AND its delta-method per-class probability
1227 /// standard error at fresh data rows (#1101).
1228 ///
1229 /// For active classes `b ∈ 0..M` the softmax Jacobian is
1230 /// `∂p_c/∂η_b = p_c (δ_{cb} − p_b)`, and `∂η_b/∂β[i,a] = X[i]·δ_{ab}`, so the
1231 /// gradient of class-`c` probability w.r.t. the block-ordered coefficient
1232 /// vector is `g_c[a·P + i] = X[i]·p_c (δ_{ca} − p_a)` (active `a`; the
1233 /// reference class `M` contributes `p_c(0 − p_a)` via every active block).
1234 /// The delta-method variance is `Var(p_c) = g_cᵀ Σ g_c` with `Σ = H⁻¹` the
1235 /// joint posterior covariance, and `SE(p_c) = √Var(p_c)`. Returns
1236 /// `(probs (N,K), prob_se (N,K))`; `prob_se` is `None` when no covariance is
1237 /// stored. The simplex `[0,1]` clamp is applied by the interval consumer, not
1238 /// here (the SE itself is unclamped).
1239 pub fn predict_probabilities_with_se(
1240 &self,
1241 x_new: ArrayView2<'_, f64>,
1242 ) -> (Array2<f64>, Option<Array2<f64>>) {
1243 let probs = self.predict_probabilities(x_new);
1244 let Some(cov) = self.coefficient_covariance() else {
1245 return (probs, None);
1246 };
1247 let n_new = x_new.nrows();
1248 let p = self.p_per_class;
1249 let m = self.n_active_classes;
1250 let k = m + 1;
1251 let d = p * m;
1252 let mut prob_se = Array2::<f64>::zeros((n_new, k));
1253 let mut grad = vec![0.0_f64; d];
1254 for row in 0..n_new {
1255 let prow = probs.row(row);
1256 for c in 0..k {
1257 let pc = prow[c];
1258 // g_c[a·P + i] = X[i] · p_c · (δ_{ca} − p_a), a active.
1259 for a in 0..m {
1260 let pa = prow[a];
1261 let factor = pc * (if c == a { 1.0 - pa } else { -pa });
1262 let base = a * p;
1263 for i in 0..p {
1264 grad[base + i] = x_new[[row, i]] * factor;
1265 }
1266 }
1267 // Var = gᵀ Σ g.
1268 let mut var = 0.0_f64;
1269 for r in 0..d {
1270 let gr = grad[r];
1271 if gr == 0.0 {
1272 continue;
1273 }
1274 let mut acc = 0.0_f64;
1275 for s in 0..d {
1276 acc += cov[[r, s]] * grad[s];
1277 }
1278 var += gr * acc;
1279 }
1280 prob_se[[row, c]] = var.max(0.0).sqrt();
1281 }
1282 }
1283 (probs, Some(prob_se))
1284 }
1285
1286 /// Wood (2013) rank-truncated Wald smooth-significance test per
1287 /// `(active class, smooth term)` (#1101), reusing the exact scalar-summary
1288 /// kernel [`gam_terms::inference::smooth_test::wood_smooth_test`]. For active
1289 /// class `a` and term span `[c0, c1)` within the class block, the global
1290 /// coefficient range is `a·P + c0 .. a·P + c1`; the joint covariance and
1291 /// influence are sliced there. The term EDF is the influence-block trace
1292 /// `tr(F_jj)` (when present) and the reference d.f. uses `tr(F_jj)²/tr(F_jj²)`,
1293 /// exactly as the scalar path. The multinomial softmax is a known-dispersion
1294 /// family, so the χ²_{ref_df} branch applies. Returns one row per
1295 /// `(class label, term label, edf, ref_df, statistic, p_value)`; empty when
1296 /// no covariance/smooth terms are available.
1297 pub fn smooth_significance(&self) -> Vec<MultinomialSmoothSignificance> {
1298 let mut out = Vec::new();
1299 let p = self.p_per_class;
1300 let m = self.n_active_classes;
1301 let Some(cov) = self.coefficient_covariance() else {
1302 return out;
1303 };
1304 if self.smooth_term_spans.is_empty() {
1305 return out;
1306 }
1307 let beta = self.coefficients_active();
1308 // Block-ordered θ = [β_0; …; β_{M-1}], θ[a·P + i] = β[i, a].
1309 let d = p * m;
1310 let mut theta = Array1::<f64>::zeros(d);
1311 for a in 0..m {
1312 for i in 0..p {
1313 theta[a * p + i] = beta[[i, a]];
1314 }
1315 }
1316 let influence = self.coefficient_influence();
1317 for a in 0..m {
1318 let class_label = self
1319 .class_levels
1320 .get(a)
1321 .cloned()
1322 .unwrap_or_else(|| format!("class{a}"));
1323 let base = a * p;
1324 for span in &self.smooth_term_spans {
1325 if span.col_end > p {
1326 continue;
1327 }
1328 let start = base + span.col_start;
1329 let end = base + span.col_end;
1330 // Term EDF = tr(F_jj); without an influence matrix fall back to
1331 // the block coefficient count (full-rank Wald on the span).
1332 let block_len = (span.col_end - span.col_start) as f64;
1333 let edf = influence
1334 .as_ref()
1335 .map(|f| (start..end).map(|i| f[[i, i]]).sum::<f64>())
1336 .filter(|v| v.is_finite() && *v > 0.0)
1337 .unwrap_or(block_len);
1338 let result = gam_terms::inference::smooth_test::wood_smooth_test(
1339 gam_terms::inference::smooth_test::SmoothTestInput {
1340 beta: theta.view(),
1341 covariance: &cov,
1342 influence_matrix: influence.as_ref(),
1343 coeff_range: start..end,
1344 edf,
1345 nullspace_dim: span.nullspace_dim,
1346 residual_df: f64::INFINITY,
1347 scale: gam_terms::inference::smooth_test::SmoothTestScale::Known,
1348 },
1349 );
1350 if let Some(res) = result {
1351 out.push(MultinomialSmoothSignificance {
1352 class_label: class_label.clone(),
1353 term_label: span.label.clone(),
1354 edf,
1355 ref_df: res.ref_df,
1356 statistic: res.statistic,
1357 p_value: res.p_value,
1358 });
1359 }
1360 }
1361 }
1362 out
1363 }
1364
1365 /// Draw `n_draws` posterior-predictive replicate class assignments at fresh
1366 /// rows (#1101). Each draw independently samples every row's class from
1367 /// `Categorical(p_row)` with `p = softmax(X·β̂)` — the plug-in predictive
1368 /// distribution, i.e. the multinomial observation noise wrapped around the
1369 /// fitted mean (the categorical analogue of the scalar families'
1370 /// `sample_replicates`). The returned `(n_draws, N)` matrix holds class
1371 /// INDICES `0..K`, aligned to [`Self::class_levels`]. The draw stream is a
1372 /// `StdRng` seeded by `seed`, so `(x_new, n_draws, seed)` reproduce
1373 /// bit-identically — the engine for posterior-predictive checks and
1374 /// simulation-based calibration. `x_new` must have `self.p_per_class`
1375 /// columns (built from the same `resolved_termspec` as fit time).
1376 pub fn sample_replicate_classes(
1377 &self,
1378 x_new: ArrayView2<'_, f64>,
1379 n_draws: usize,
1380 seed: u64,
1381 ) -> Array2<u32> {
1382 use rand::{RngExt, SeedableRng};
1383 let probs = self.predict_probabilities(x_new);
1384 let n = probs.nrows();
1385 let k = probs.ncols();
1386 let mut out = Array2::<u32>::zeros((n_draws, n));
1387 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
1388 for d in 0..n_draws {
1389 for row in 0..n {
1390 let u: f64 = rng.random::<f64>();
1391 // Inverse-CDF categorical draw over the K simplex weights.
1392 let mut acc = 0.0_f64;
1393 let mut chosen = k - 1; // numerical fallback = reference class
1394 for c in 0..k {
1395 acc += probs[[row, c]];
1396 if u < acc {
1397 chosen = c;
1398 break;
1399 }
1400 }
1401 out[[d, row]] = chosen as u32;
1402 }
1403 }
1404 out
1405 }
1406}
1407
/// One row of the multinomial smooth-significance table (#1101): the Wood
/// rank-truncated Wald test for one `(active class, smooth term)` pair.
///
/// Rows are plain data (two labels plus four floats) and compare field by
/// field via `PartialEq`, so result tables can be checked with `==` /
/// `assert_eq!`. As with any `f64` comparison, a row holding a NaN is not
/// equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct MultinomialSmoothSignificance {
    /// Label of the active class the tested coefficient block belongs to.
    pub class_label: String,
    /// Label of the smooth term being tested.
    pub term_label: String,
    /// Effective degrees of freedom attributed to the term.
    pub edf: f64,
    /// Reference degrees of freedom reported by the test.
    pub ref_df: f64,
    /// Value of the Wald test statistic.
    pub statistic: f64,
    /// p-value of the test statistic.
    pub p_value: f64,
}
1419
1420/// One-hot-encode the categorical response column and return both the
1421/// encoding and the captured level names. The level order matches the order
1422/// recorded in the dataset schema, which is the canonical (lexicographically
1423/// sorted) factor order produced by inferred-schema construction (#1319) — so
1424/// it is a deterministic function of the label *set*, independent of training
1425/// row order (no silent class permutation under a row shuffle), and matches the
1426/// R `factor()` / pandas `Categorical` convention.
1427fn one_hot_categorical_response(
1428 data: &EncodedDataset,
1429 y_col: usize,
1430 response_name: &str,
1431) -> Result<(Array2<f64>, Vec<String>), EstimationError> {
1432 let levels: Vec<String> = data
1433 .schema
1434 .columns
1435 .get(y_col)
1436 .map(|sc| sc.levels.clone())
1437 .unwrap_or_default();
1438 if levels.len() < 2 {
1439 crate::bail_invalid_estim!(
1440 "multinomial response '{response_name}' must have at least 2 categorical levels (got {})",
1441 levels.len()
1442 );
1443 }
1444 let n = data.values.nrows();
1445 let k = levels.len();
1446 let mut y_one_hot = Array2::<f64>::zeros((n, k));
1447 for row in 0..n {
1448 let encoded = data.values[[row, y_col]];
1449 if !encoded.is_finite() {
1450 crate::bail_invalid_estim!(
1451 "multinomial response '{response_name}' row {row} is non-finite ({encoded})"
1452 );
1453 }
1454 let class_idx = encoded.round() as i64;
1455 if class_idx < 0 || (class_idx as usize) >= k {
1456 crate::bail_invalid_estim!(
1457 "multinomial response '{response_name}' row {row} encoded as {encoded} \
1458 is outside the level range 0..{k}"
1459 );
1460 }
1461 y_one_hot[[row, class_idx as usize]] = 1.0;
1462 }
1463 Ok((y_one_hot, levels))
1464}
1465
1466/// Build `(TermCollectionSpec, TermCollectionDesign)` from a formula against
1467/// a categorical-response dataset. Mirrors the early scaffolding inside
1468/// `materialize_standard` (response role resolution, geometry-aware spec
1469/// build) without touching the scalar-family resolution path — multinomial
1470/// owns its own response kind check.
1471fn build_formula_design_for_multinomial(
1472 formula: &str,
1473 data: &EncodedDataset,
1474 config: &FitConfig,
1475) -> Result<
1476 (
1477 TermCollectionSpec,
1478 TermCollectionDesign,
1479 usize,
1480 String,
1481 ResponseColumnKind,
1482 ),
1483 EstimationError,
1484> {
1485 let parsed = parse_formula(formula).map_err(|err| {
1486 EstimationError::InvalidInput(format!(
1487 "multinomial fit: failed to parse formula {formula:?}: {err}"
1488 ))
1489 })?;
1490 let col_map = data.column_map();
1491 let y_col = resolve_role_col(&col_map, &parsed.response, "response")
1492 .map_err(|err| EstimationError::InvalidInput(format!("multinomial fit: {err}")))?;
1493 let y_kind = crate::fit_orchestration::response_column_kind(data, y_col);
1494 let policy = resolved_resource_policy(config, data, ProblemHints::default());
1495 let mut inference_notes: Vec<String> = Vec::new();
1496 let spec = build_termspec_with_geometry_and_overrides(
1497 &parsed.terms,
1498 data,
1499 &col_map,
1500 &mut inference_notes,
1501 config.scale_dimensions,
1502 &policy,
1503 config.smooth_overrides.as_ref(),
1504 )
1505 .map_err(|err| {
1506 EstimationError::InvalidInput(format!("multinomial fit: build termspec: {err}"))
1507 })?;
1508 let design = build_term_collection_design(data.values.view(), &spec).map_err(|err| {
1509 EstimationError::InvalidInput(format!("multinomial fit: build design: {err}"))
1510 })?;
1511 Ok((spec, design, y_col, parsed.response, y_kind))
1512}
1513
1514fn scale_multinomial_formula_penalty(penalty: PenaltyMatrix, scale: f64) -> PenaltyMatrix {
1515 match penalty {
1516 PenaltyMatrix::Dense(matrix) => PenaltyMatrix::Dense(matrix.mapv(|v| v * scale)),
1517 PenaltyMatrix::KroneckerFactored { left, right } => PenaltyMatrix::KroneckerFactored {
1518 left: left.mapv(|v| v * scale),
1519 right,
1520 },
1521 PenaltyMatrix::Blockwise {
1522 local,
1523 col_range,
1524 total_dim,
1525 } => PenaltyMatrix::Blockwise {
1526 local: local.mapv(|v| v * scale),
1527 col_range,
1528 total_dim,
1529 },
1530 PenaltyMatrix::Labeled { label, inner } => PenaltyMatrix::Labeled {
1531 label,
1532 inner: Box::new(scale_multinomial_formula_penalty(*inner, scale)),
1533 },
1534 PenaltyMatrix::Fixed { log_lambda, inner } => PenaltyMatrix::Fixed {
1535 log_lambda,
1536 inner: Box::new(scale_multinomial_formula_penalty(*inner, scale)),
1537 },
1538 }
1539}
1540
1541/// Build a warm-started copy of `blocks` whose per-block `initial_log_lambdas`
1542/// are seeded from a previously-selected flat `log_lambdas` vector (#1082).
1543///
1544/// The flat `log_lambdas` returned by [`fit_custom_family_with_rho_prior`]
1545/// concatenates each block's penalty log-λ in block order — the same order
1546/// `build_block_specs()` emits the blocks and the same per-block penalty order
1547/// the spec carries — so it splits back across blocks by each block's penalty
1548/// count. Warm-starting the OUTER ρ-search from a prior iterate changes only the
1549/// optimizer's starting point, never the penalized objective or its optimum, so
1550/// the converged fit is identical; it just resumes near the prior iterate
1551/// instead of restarting from the cold `init_lambda` seed.
1552///
1553/// Returns `None` (caller falls back to the cold blocks) if the flat vector does
1554/// not have exactly one entry per penalty across all blocks, or carries a
1555/// non-finite value — i.e. anything that would make the seed unsafe.
1556fn warm_start_blocks_from_log_lambdas(
1557 blocks: &[crate::custom_family::ParameterBlockSpec],
1558 log_lambdas: &[f64],
1559) -> Option<Vec<crate::custom_family::ParameterBlockSpec>> {
1560 let total: usize = blocks.iter().map(|b| b.initial_log_lambdas.len()).sum();
1561 if total == 0 || log_lambdas.len() != total {
1562 return None;
1563 }
1564 if log_lambdas.iter().any(|v| !v.is_finite()) {
1565 return None;
1566 }
1567 let mut warm = blocks.to_vec();
1568 let mut offset = 0usize;
1569 for block in warm.iter_mut() {
1570 let k = block.initial_log_lambdas.len();
1571 for slot in 0..k {
1572 block.initial_log_lambdas[slot] = log_lambdas[offset + slot];
1573 }
1574 offset += k;
1575 }
1576 Some(warm)
1577}
1578
1579/// Top-level formula-driven multinomial fit.
1580///
1581/// Routes through [`fit_custom_family_with_rho_prior`] so the per-active-class
1582/// smoothing parameters `λ_a` (one per class block, shared-penalty
1583/// architecture) are selected by the outer REML/LAML loop rather than pinned
1584/// by the caller. `init_lambda` survives as a warm-start hint that seeds
1585/// every block's `initial_log_lambdas`. `max_iter` / `tol` drive the OUTER
1586/// REML/LAML smoothing-parameter search (`outer_max_iter` / `outer_tol`); the
1587/// inner joint-Newton solve runs on the framework's principled production cycle
1588/// budget at the default KKT tolerance so an ill-conditioned, LM-damped
1589/// near-simplex-boundary solve can certify a stationary point instead of being
1590/// declared non-converged after only `max_iter` cycles (#715).
1591///
1592/// The Jeffreys/Firth proper prior is engaged CONDITIONALLY: attempt 1 runs
1593/// the unbiased penalized-REML criterion; only on separation evidence (a failed
1594/// solve or a non-finite logit; see [`multinomial_formula_separation_evidence`])
1595/// is the fit re-solved once with the full-span Firth prior armed, which bounds
1596/// the penalty-null directions no smoothing parameter can (`S v = 0` ⇒
1597/// `(H + S_λ) v = H v → 0` when the softmax likelihood has no finite mode).
1598///
1599/// The categorical response column is recognised via the dataset schema
1600/// (`ColumnKindTag::Categorical`); reference class = last level. Returns a
1601/// [`MultinomialSavedModel`] that can be serialised to bytes for the Python
1602/// wrapper or used in-process for `predict_probabilities`.
1603pub fn fit_penalized_multinomial_formula(
1604 data: &EncodedDataset,
1605 formula: &str,
1606 config: &FitConfig,
1607 init_lambda: f64,
1608 max_iter: usize,
1609 tol: f64,
1610) -> Result<MultinomialSavedModel, EstimationError> {
1611 if !(init_lambda.is_finite() && init_lambda > 0.0) {
1612 crate::bail_invalid_estim!(
1613 "multinomial fit: init_lambda must be finite and > 0 (got {init_lambda})"
1614 );
1615 }
1616 let (raw_spec, design, y_col, response_name, y_kind) =
1617 build_formula_design_for_multinomial(formula, data, config)?;
1618 // Freeze the data-derived basis state (B-spline knot vectors, by-factor
1619 // level sets, spatial centers, joint-null rotations, residualization
1620 // charts) from the fit design back onto the spec. The raw geometry spec
1621 // records only *which* columns and *what kind* of basis each smooth uses;
1622 // the actual column count and basis evaluation depend on quantities the
1623 // builder derives from the training data (knot placement, the distinct
1624 // by-factor levels, etc.). Saving the raw spec made predict re-derive those
1625 // from the (smaller, differently-distributed) predict frame, so the rebuilt
1626 // design had a different column count than the fitted one — the panic
1627 // "predict design has 42 cols, saved model expects 191" for an `s(x,
1628 // by=group)` smooth-by-factor model. Every other family's persistence path
1629 // freezes the spec the same way (see `freeze_term_collection_from_design`
1630 // call sites in `main_parts`); multinomial was the lone exception.
1631 let spec = freeze_term_collection_from_design(&raw_spec, &design)?;
1632 let class_levels = match y_kind {
1633 ResponseColumnKind::Categorical { levels } => levels,
1634 ResponseColumnKind::Binary => vec!["0".to_string(), "1".to_string()],
1635 ResponseColumnKind::Numeric => {
1636 crate::bail_invalid_estim!(
1637 "multinomial fit: response '{response_name}' is numeric, not categorical; \
1638 use family='gaussian'/'binomial'/... or convert the column to a categorical type"
1639 );
1640 }
1641 };
1642 if data.column_kinds.get(y_col) == Some(&ColumnKindTag::Binary) {
1643 // Promote to a 2-level categorical for the multinomial driver; the
1644 // caller explicitly asked for multinomial, so we route through the
1645 // K-1 = 1 active-class softmax (equivalent math to logistic).
1646 } else if data.column_kinds.get(y_col) != Some(&ColumnKindTag::Categorical) {
1647 crate::bail_invalid_estim!(
1648 "multinomial fit: response '{response_name}' must be a categorical column \
1649 (got column kind {:?})",
1650 data.column_kinds.get(y_col)
1651 );
1652 }
1653 let (y_one_hot, _) = one_hot_categorical_response(data, y_col, &response_name)?;
1654 // Build the global X dense (the design is a DesignMatrix abstraction).
1655 let mut x_dense = design
1656 .design
1657 .try_to_dense_by_chunks("multinomial fit design")
1658 .map_err(EstimationError::InvalidInput)?;
1659
1660 // ── #715 real-data conditioning: standardize unpenalized parametric
1661 // columns. Raw-unit linear covariates (penguins `body_mass_g` ~ 4e3 grams)
1662 // inflate the joint Newton information by the squared column scale (a κ(H)
1663 // multiplier of ~s² ≈ 1e7 against the intercept), which is what turns the
1664 // near-separable LM-damped inner solve into a geometric grind that
1665 // exhausts its cycle budgets — the adapter-level face of "all REML startup
1666 // seeds rejected". Because these columns are UNPENALIZED (parametric terms
1667 // carry no default ridge, #749), the affine reparameterization
1668 // `x_j ↦ (x_j − m_j)/s_j` is EXACT for the whole criterion: the optimized
1669 // REML/LAML objective, the fitted η, the selected λ, and the separation
1670 // diagnostics are all invariant — only the conditioning of `H` changes.
1671 // Fitted coefficients are mapped back to raw units at repack below, so the
1672 // saved model and the (raw-design) predict path are untouched. Penalized
1673 // columns are left alone (a penalty makes the rescaling non-equivalent),
1674 // and nothing is touched when explicit coefficient bounds/constraints
1675 // exist (those are stated in raw units).
1676 let parametric_standardization: Vec<(usize, f64, f64)> =
1677 if design.coefficient_lower_bounds.is_some() || design.linear_constraints.is_some() {
1678 Vec::new()
1679 } else {
1680 let p_total = x_dense.ncols();
1681 let mut penalized = vec![false; p_total];
1682 for bp in &design.penalties {
1683 for col in bp.col_range.clone() {
1684 if col < p_total {
1685 penalized[col] = true;
1686 }
1687 }
1688 }
1689 let has_intercept = !design.intercept_range.is_empty();
1690 let n_rows = x_dense.nrows().max(1) as f64;
1691 let mut standardized = Vec::new();
1692 for (_, range) in &design.linear_ranges {
1693 for col in range.clone() {
1694 if col >= p_total || penalized[col] {
1695 continue;
1696 }
1697 let column = x_dense.column(col);
1698 let mean = column.sum() / n_rows;
1699 let var = column.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n_rows;
1700 let scale = var.sqrt();
1701 // Skip near-constant or degenerate columns: no conditioning to
1702 // be gained and the back-map would divide by ~0.
1703 if !(scale.is_finite() && scale > 1e-8 * (mean.abs() + 1.0)) {
1704 continue;
1705 }
1706 // Centering shifts mass onto the intercept; without one the
1707 // shift is not representable, so scale only.
1708 let center = if has_intercept { mean } else { 0.0 };
1709 for v in x_dense.column_mut(col).iter_mut() {
1710 *v = (*v - center) / scale;
1711 }
1712 standardized.push((col, center, scale));
1713 }
1714 }
1715 standardized
1716 };
1717 // Preserve the per-smooth-term penalty block structure (#561): each smooth
1718 // term `t` contributes its own `P × P` penalty component (`Blockwise` with
1719 // `total_dim = P`, the term's local `S_t` embedded at its `col_range`), and
1720 // every active class block receives the FULL list. The outer REML/LAML loop
1721 // then selects an independent smoothing parameter λ_{a,t} per (class, term),
1722 // matching mgcv/VGAM. Pre-summing the terms into one fused `S` (the prior
1723 // behaviour) forced a single λ per class that scales `Σ_t S_t`, so one
1724 // shared λ had to over-smooth a rough term while under-smoothing a smooth
1725 // one — biasing any multi-term class-probability surface.
1726 let k = y_one_hot.ncols();
1727 let m = k - 1;
1728 let n_obs = y_one_hot.nrows();
1729 let penalty_scale = multinomial_formula_penalty_scale(k);
1730 let per_term_penalties: Vec<PenaltyMatrix> = design
1731 .penalties_as_penalty_matrix()
1732 .into_iter()
1733 .map(|penalty| scale_multinomial_formula_penalty(penalty, penalty_scale))
1734 .collect();
1735 let per_term_nullspace_dims = design.nullspace_dims.clone();
1736
1737 // ── Custom-family driven REML/LAML path ───────────────────────────────
1738 // Each active class becomes one ParameterBlockSpec, all sharing X and the
1739 // per-term penalty list. `initial_log_lambdas` is seeded from the caller's
1740 // `init_lambda` (one entry per term).
1741 let design_arc = Arc::new(x_dense);
1742 let penalties_arc = Arc::new(per_term_penalties);
1743 let nullspace_dims_arc = Arc::new(per_term_nullspace_dims);
1744 let weights = Array1::<f64>::ones(n_obs);
1745 // First attempt runs the UNBIASED penalized-REML criterion (no Firth
1746 // shrinkage toward the uniform simplex); the Jeffreys/Firth proper prior is
1747 // armed conditionally below, only on separation evidence (#715/#753 — see
1748 // `multinomial_formula_separation_evidence`).
1749 let log_init = init_lambda.ln();
1750 let family = MultinomialFamily::new(
1751 y_one_hot.clone(),
1752 weights,
1753 k,
1754 design_arc.clone(),
1755 penalties_arc.clone(),
1756 nullspace_dims_arc.clone(),
1757 )
1758 .map_err(EstimationError::InvalidInput)?
1759 .with_joint_jeffreys_term(false)
1760 // gam#1587: the per-block smooth penalties are emptied (the centered `M⊗S_t`
1761 // joint penalty is the sole smoothing carrier), so the `init_lambda` warm
1762 // start must seed the JOINT penalty's `initial_log_lambda` — the per-block
1763 // `initial_log_lambdas` loop below is now a no-op (empty per-block list).
1764 .with_initial_log_lambda(log_init);
1765 let mut blocks = family.build_block_specs();
1766 for spec_block in blocks.iter_mut() {
1767 for v in spec_block.initial_log_lambdas.iter_mut() {
1768 *v = log_init;
1769 }
1770 }
1771
1772 // ── Outer-derivative policy: dimension-gated exact curvature ────────────
1773 // The total smoothing-parameter dimension is `D = (K−1) · n_terms`.
1774 // Medium-D formula fits need exact curvature to keep lambda selection away
1775 // from over-smoothed caps, while smooth-by-factor `D = 8` models still avoid
1776 // the O(D²) dense Hessian path.
1777 let total_rho_dim = m.saturating_mul(penalties_arc.len());
1778 let use_outer_hessian = multinomial_formula_use_outer_hessian(total_rho_dim);
1779
1780 // ── Inner-vs-outer control split (#715 non-convergence root cause) ────────
1781 // The legacy `max_iter` / `tol` parameters are the *outer* REML/LAML
1782 // smoothing-parameter optimization controls — "how hard to search λ". The
1783 // earlier wiring routed them straight into `inner_max_cycles` / `inner_tol`,
1784 // capping the joint-Newton inner solve at `max_iter` (=50 in the quality
1785 // suite) cycles with a `tol`-tight (=1e-8) KKT target. That is the #715
1786 // hang: near the simplex boundary the softmax Fisher weight
1787 // `W = diag(p) − p pᵀ` collapses, so `H = JᵀWJ + S_λ` is full-rank but
1788 // ILL-CONDITIONED. The self-vanishing Levenberg–Marquardt damping
1789 // (`levenberg_on_ill_conditioning()`) that keeps the inner solve from
1790 // oscillating on those near-singular modes makes it converge only
1791 // GEOMETRICALLY (linearly), not quadratically. Reaching a 1e-8 relative KKT
1792 // residual under geometric descent needs FAR more than 50 cycles, so the
1793 // inner returned `converged = false` on every outer ρ-evaluation; with the
1794 // exact-Hessian outer optimizer on `FallbackPolicy::Disabled` that rejects
1795 // every ρ-step — each rejected eval still paying a near-full 50-cycle inner
1796 // solve plus the O(D²) pairwise outer-Hessian directional work — so the
1797 // outer never certifies and the fit runs unbounded (the observed >8-minute
1798 // non-termination). The certificate cannot be reached, not merely slow.
1799 //
1800 // Fix: give the INNER joint-Newton the framework's principled production
1801 // budget (`DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES` cycles at the default
1802 // `inner_tol`), which exists precisely so an ill-conditioned LM-damped solve
1803 // can certify a stationary KKT point instead of being declared non-converged
1804 // prematurely — and the KKT/objective certificates still exit in a handful
1805 // of cycles on the well-conditioned interior fits, so this is free there.
1806 // The caller's `max_iter` / `tol` become the OUTER controls they were always
1807 // meant to be (smoothing-parameter search depth / accuracy). The inner KKT
1808 // target is kept no tighter than the outer accuracy can consume — and no
1809 // tighter than the softmax objective's f64 noise floor on near-separable
1810 // fits (see `MULTINOMIAL_FORMULA_INNER_TOL`).
1811 let outer_max_iter = max_iter.max(1);
1812 // The OUTER REML/LAML smoothing-parameter search must converge to a
1813 // well-calibrated ρ-gradient tolerance, NOT to the caller's (typically very
1814 // tight) INNER KKT tolerance. The #715 control-split repurposed the caller's
1815 // `tol` as the outer control, but feeding an inner-scale `tol = 1e-8`
1816 // straight into `outer_tol` makes REML grind dozens of extra exact-gradient
1817 // outer iterations (each an O(D·p³) Laplace-derivative assembly over the full
1818 // P·M joint design) to squeeze ρ digits that no longer move the fitted
1819 // surface — the smooth-by-factor 269s wall-clock overrun (#1082).
1820 //
1821 // The right target is the framework's CALIBRATED REML convergence tolerance,
1822 // `MULTINOMIAL_OUTER_REML_TOL = 1e-7` — the same value the primary GLM REML
1823 // outer uses (`solver::fit_orchestration::materialize` `tol: 1e-7`, mirrored by the
1824 // `LOG_LAMBDA_TOL`/`KKT_TOL_*` constants across the REML stack). At 1e-7 the
1825 // λ-search reaches the genuine REML optimum (so the recovered probability
1826 // surface matches the mature reference), but it does NOT chase the last
1827 // surface-irrelevant ρ digits down to 1e-8. The earlier 1e-5 floor (the
1828 // generic `BlockwiseFitOptions` default) was too LOOSE: the optimizer halted
1829 // in a low-curvature region with λ still well above its optimum, UNDER-fitting
1830 // the smooth-by-factor surface (truth-RMSE 0.164 vs VGAM's 0.061). So the
1831 // outer tolerance is floored at the calibrated REML tol — never tighter than
1832 // it (perf), never looser (accuracy) — while the caller's `tol` continues to
1833 // drive the INNER joint-Newton KKT target (`inner_tol` below), where its
1834 // precision actually matters.
1835 let outer_tol = if tol.is_finite() && tol > 0.0 {
1836 tol.max(MULTINOMIAL_OUTER_REML_TOL)
1837 } else {
1838 MULTINOMIAL_OUTER_REML_TOL
1839 };
1840 // #1082 root cause: the outer convergence test derives BOTH the absolute
1841 // projected-gradient floor (`max(outer_tol, n·1e-9)`) AND the relative-cost
1842 // stop (`rel_cost = outer_tol`) from the single `outer_tol`. The accuracy of
1843 // the smooth-by-factor surface is governed by the ABSOLUTE floor reaching the
1844 // n-scaled REML resolution `n·1e-9` (≈ 1.8e-6 at n = 1800) — that is why the
1845 // earlier 1e-5 floor UNDER-fit (its absolute floor was pinned at 1e-5, well
1846 // above the genuine optimum's gradient) and why 1e-7 recovered accuracy (it
1847 // unpins the floor down to the n-scaled 1.8e-6). But tightening `outer_tol`
1848 // to 1e-7 ALSO tightened the rel-cost stop to 1e-7, which on this family's
1849 // dead-flat REML ridge NEVER trips — so the optimizer no longer converges and
1850 // grinds all the way to `outer_max_iter`, each surplus step an O(D·p³) Laplace-
1851 // derivative assembly over the 382-dim joint design (the >600s wall-clock
1852 // overrun; tightening tol REINTRODUCED the crawl the 1e-5 floor had removed).
1853 //
1854 // The two requirements live on two different criteria, so they must be set
1855 // independently. Keep `outer_tol = 1e-7` (drives the accurate absolute floor)
1856 // but FLOOR the relative-cost stop at the framework default 1e-5 (the loose,
1857 // fast value that resolves the cost-decrease plateau without chasing the flat
1858 // tail). The absolute n·1e-9 floor still gates final λ accuracy; the rel-cost
1859 // stop just lets the optimizer DECLARE convergence on the flat ridge instead
1860 // of crawling to the iteration cap.
1861 let outer_rel_cost_tol = Some(BlockwiseFitOptions::default().outer_tol);
1862 let inner_tol = MULTINOMIAL_FORMULA_INNER_TOL.max(tol.max(0.0));
1863
1864 let options = BlockwiseFitOptions {
1865 inner_max_cycles: crate::custom_family::DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES,
1866 inner_tol,
1867 outer_max_iter,
1868 outer_tol,
1869 outer_rel_cost_tol,
1870 rho_lower_bound: multinomial_formula_min_lambda(y_one_hot.view()).ln(),
1871 ridge_floor: MULTINOMIAL_FORMULA_RIDGE_FLOOR,
1872 // #747: the stabilization floor is SOLVER-ONLY — it keeps the inner
1873 // joint-Newton linear solve finite during screening (bounding the step
1874 // `(H+δI)⁻¹∇` away from a near-separable, rank-deficient curvature) but
1875 // is excluded from the REML objective, the penalty log-determinant, and
1876 // the Laplace Hessian. The earlier default (`explicit_stabilization_pospart`)
1877 // folded `½·δ·‖β‖²` and a `δ`-shift of the log-determinant into the
1878 // criterion, shrinking every identified coefficient off the MLE and
1879 // perturbing smoothing-parameter selection — a fixed-λ prior masking
1880 // separation, not a numerical stabilizer. With the floor solver-only the
1881 // optimized objective is the true penalized REML criterion (value tracks
1882 // its analytic gradient), and the smooth directions remain governed
1883 // solely by their own REML-selected `λ`.
1884 ridge_policy: gam_problem::RidgePolicy::solver_only(),
1885 use_outer_hessian,
1886 // #715 real-data arm ("canonical-gauge null direction rejects all REML
1887 // seeds"): skip the multi-seed outer screening cascade and let the
1888 // pinned `init_lambda` ρ flow straight to the outer optimizer.
1889 //
1890 // The multinomial family declares `levenberg_on_ill_conditioning() ->
1891 // true`: near the simplex boundary (the near-separable penguins regime)
1892 // the softmax Fisher weight `W = diag(p) − p pᵀ → 0`, so the joint
1893 // information `H = JᵀWJ + S_λ` can become full-rank but
1894 // ILL-CONDITIONED. The self-vanishing LM damping that keeps the inner
1895 // joint-Newton from oscillating on those near-singular modes converges
1896 // only GEOMETRICALLY. The default screening policy ranks candidate seeds
1897 // with a 2-cycle inner cap (`outer_seed_config`); under geometric
1898 // LM-damped descent two cycles never reach a finite, meaningful proxy
1899 // objective, so EVERY capped seed can collapse to non-finite cost and
1900 // the cascade escalates to ×4, ×16, then an UNCAPPED full inner solve
1901 // PER SEED on the near-singular Hessian. That is the adapter-level face
1902 // of "all REML startup seeds rejected" and the multi-minute timeout.
1903 //
1904 // The pinned seed is already principled here: `init_lambda` gives every
1905 // (class, term) ρ a sensible moderate warm start, and the per-term
1906 // effective-df-floor upper bounds (`effective_df_floor_rho_upper_bounds`,
1907 // #715 arm (a)) keep any λ from collapsing the smooth onto its polynomial
1908 // null space. So the outer ARC/BFGS optimizer performs the real REML ρ
1909 // search from this seed; screening only adds the cascade cost and, on the
1910 // near-separable arm, the rejection stall.
1911 screen_initial_rho: false,
1912 // #1101: compute the joint Laplace posterior covariance `H⁻¹` (and the
1913 // influence matrix `F = H⁻¹ X'WX`) at the converged mode so the saved
1914 // model can surface delta-method per-class probability standard errors
1915 // and Wald smooth-term p-values. The driver factorizes the penalized
1916 // Hessian during the inner solve regardless; this only asks it to keep
1917 // and invert the factor instead of discarding it.
1918 compute_covariance: true,
1919 ..BlockwiseFitOptions::default()
1920 };
1921 // ── Conditional Firth/Jeffreys engagement (#715 arm (b) / #753) ──────────
1922 // Attempt 1: the unbiased criterion (Jeffreys disarmed above). If the
1923 // returned mode is converged, finite, and interior, it is the exact penalized-REML
1924 // optimum with zero Firth bias — accept it (this is the synthetic-arm /
1925 // interior-data path, #715 arm (a)). If the solve FAILS (e.g. the
1926 // (quasi-)separated penguins geometry where `(H + S_λ)v ≈ 0` along
1927 // penalty-null directions for EVERY ρ rejects every REML startup seed) or
1928 // returns a non-finite artifact, that is direct separation evidence:
1929 // re-solve once with the full-span Jeffreys/Firth proper prior armed, which
1930 // supplies the O(1) curvature on the quotient-null subspace that smoothing
1931 // parameters mathematically cannot (`Sv = 0` ⇒ λ never touches `v`). The
1932 // Firth refit is the accepted result only when the unbiased formula solve
1933 // failed, did not converge on its full budget, or blew up; finite
1934 // formula-path logits can be large on valid near-separated optima and
1935 // should not be shrunk toward the uniform simplex once the unbiased outer
1936 // solve has actually certified.
1937 let mut unbiased_probe_options = options.clone();
1938 unbiased_probe_options.outer_max_iter = unbiased_probe_options
1939 .outer_max_iter
1940 .min(MULTINOMIAL_UNBIASED_PROBE_OUTER_MAX_ITER);
1941 // The FINAL accepted Firth/Jeffreys refit runs to the caller's full outer
1942 // budget: it is the result we ship, so it must reach the genuine REML
1943 // optimum, not a truncated iterate. The near-separable penguin refit that
1944 // motivated #1082's wall-clock concern is now halted honestly at its true
1945 // bound optimum by the KKT-stationary-at-bound guard
1946 // (`CostStallGuard`, #1082 / 64711ed82) and the Newton-decrement residual
1947 // certificate (363af9b56 / 2c9580b1f): on separable data the outer ARC
1948 // certifies and stops early on its own, so no artificial iteration cap is
1949 // needed to land in budget. On non-separable data (e.g. the
1950 // `vgam_smooth_by_factor` double-penalty arm) the refit needs the caller's
1951 // full budget to converge, which a `.min(20)` cap would cut off — accepting
1952 // a non-converged fit, which is dishonest. So the refit keeps `options`
1953 // unchanged. Only the discarded unbiased separation probe above is capped.
1954 let firth_refit_options = &options;
1955
1956 let run_firth_refit = |evidence: String| {
1957 let firth_family = family.clone().with_joint_jeffreys_term(true);
1958 fit_custom_family_with_rho_prior(
1959 &firth_family,
1960 &blocks,
1961 firth_refit_options,
1962 gam_problem::RhoPrior::Flat,
1963 )
1964 .map_err(|err| {
1965 EstimationError::InvalidInput(format!(
1966 "multinomial REML: Firth/Jeffreys-armed refit (separation evidence: \
1967 {evidence}) failed: {err}"
1968 ))
1969 })
1970 };
1971
1972 // #1082: the capped unbiased probe and the (separable-path) Firth decision
1973 // are driven by separation scans over the full P×M logit block. The previous
1974 // match recomputed `multinomial_formula_separation_evidence` /
1975 // `..._unresolved_probe_separation_evidence` in BOTH the match guard AND the
1976 // arm body — three to four full logit walks per fit, paid on the hot
1977 // near-separable penguin path where this branch fires every iterate. Run the
1978 // probe once, evaluate each scan once into a binding, and branch on the
1979 // precomputed results. Behaviour is identical (same scans, same order of
1980 // precedence: converged-interior, unresolved-probe-separation,
1981 // no-separation-needs-full-solve, otherwise-Firth); only the duplicate
1982 // O(n·classes) scans are removed.
1983 let probe_attempt = fit_custom_family_with_rho_prior(
1984 &family,
1985 &blocks,
1986 &unbiased_probe_options,
1987 gam_problem::RhoPrior::Flat,
1988 );
1989 let fit = match probe_attempt {
1990 Ok(probe_fit) => {
1991 let separation = multinomial_formula_separation_evidence(&probe_fit.block_states);
1992 if probe_fit.outer_converged && separation.is_none() {
1993 // Interior, converged, no separation: accept the probe directly.
1994 probe_fit
1995 } else if let Some(evidence) =
1996 multinomial_formula_unresolved_probe_separation_evidence(&probe_fit.block_states)
1997 {
1998 // Non-converged probe already carrying separation-scale logits:
1999 // hand straight to the proper-prior Firth refit (do not spend the
2000 // full unbiased budget grinding the λ→0 separable ridge).
2001 run_firth_refit(format!(
2002 "unbiased-criterion REML probe did not converge after {} outer iterations; {evidence}",
2003 probe_fit.outer_iterations
2004 ))?
2005 } else if separation.is_none() {
2006 // Interior but the capped probe ran out of iterations without
2007 // certifying: re-solve at the caller's full outer budget.
2008 //
2009 // #1082 wall-clock: the capped probe is a strict prefix of this
2010 // solve from the same family/seed, so a COLD restart repeats the
2011 // probe's outer iterations. WARM-START the re-solve from the ρ the
2012 // probe already reached — seed each block's `initial_log_lambdas`
2013 // from the probe's selected `log_lambdas` (same block/penalty
2014 // order: the flat vector concatenates per-block penalties in block
2015 // order, exactly the order `build_block_specs()` emits them). This
2016 // changes only the optimizer's STARTING point, never the objective
2017 // or its optimum, but lets the full solve resume near the probe's
2018 // last iterate instead of crawling up from `init_lambda` again —
2019 // removing the probe-iterations double-pay on the non-separable
2020 // (e.g. `vgam_smooth_by_factor`) arm. If the probe's λ vector does
2021 // not line up with the block layout (it always should), fall back
2022 // to the cold `blocks` seed.
2023 let warm_blocks = warm_start_blocks_from_log_lambdas(
2024 &blocks,
2025 probe_fit.log_lambdas.as_slice().unwrap_or(&[]),
2026 );
2027 let resolve_blocks = warm_blocks.as_deref().unwrap_or(&blocks);
2028 match fit_custom_family_with_rho_prior(
2029 &family,
2030 resolve_blocks,
2031 &options,
2032 gam_problem::RhoPrior::Flat,
2033 ) {
2034 Ok(full_unbiased_fit) => {
2035 let full_separation = multinomial_formula_separation_evidence(
2036 &full_unbiased_fit.block_states,
2037 );
2038 if full_unbiased_fit.outer_converged && full_separation.is_none() {
2039 full_unbiased_fit
2040 } else {
2041 let evidence = full_separation.unwrap_or_else(|| {
2042 format!(
2043 "full unbiased-criterion REML solve did not converge after {} outer iterations",
2044 full_unbiased_fit.outer_iterations
2045 )
2046 });
2047 run_firth_refit(evidence)?
2048 }
2049 }
2050 Err(err) => run_firth_refit(format!(
2051 "full unbiased-criterion REML solve failed: {err}"
2052 ))?,
2053 }
2054 } else {
2055 // Probe converged (or capped) but shows interior separation
2056 // evidence: Firth refit using the already-computed scan.
2057 let evidence = separation.unwrap_or_else(|| {
2058 format!(
2059 "unbiased-criterion REML probe did not converge after {} outer iterations",
2060 probe_fit.outer_iterations
2061 )
2062 });
2063 run_firth_refit(evidence)?
2064 }
2065 }
2066 Err(err) => run_firth_refit(format!("unbiased-criterion REML solve failed: {err}"))?,
2067 };
2068 if let Some(err) = multinomial_formula_separation_diagnostic(
2069 fit.inner_cycles,
2070 fit.outer_iterations,
2071 &fit.block_states,
2072 ) {
2073 return Err(err);
2074 }
2075
2076 // ── Repack coefficients (P, K-1) from per-block β vectors ─────────────
2077 if fit.blocks.len() != m {
2078 crate::bail_invalid_estim!(
2079 "multinomial REML: expected {m} fitted blocks (K-1), got {}",
2080 fit.blocks.len()
2081 );
2082 }
2083 let p_per_class = fit.blocks[0].beta.len();
2084 let mut coefficients_active = Array2::<f64>::zeros((p_per_class, m));
2085 for (a, block) in fit.blocks.iter().enumerate() {
2086 if block.beta.len() != p_per_class {
2087 crate::bail_invalid_estim!(
2088 "multinomial REML: block {a} has {} coefs, expected {p_per_class}",
2089 block.beta.len()
2090 );
2091 }
2092 for i in 0..p_per_class {
2093 coefficients_active[[i, a]] = block.beta[i];
2094 }
2095 }
2096 // Map the standardized-column coefficients back to raw units (the exact
2097 // inverse of the conditioning reparameterization above): β_raw = b/s, with
2098 // the centering mass `Σ_j b_j·m_j/s_j` returned to the intercept.
2099 if !parametric_standardization.is_empty() {
2100 let intercept_col = design.intercept_range.clone().next();
2101 for a in 0..m {
2102 let mut intercept_adjust = 0.0;
2103 for &(col, center, scale) in ¶metric_standardization {
2104 if col < p_per_class {
2105 let raw = coefficients_active[[col, a]] / scale;
2106 coefficients_active[[col, a]] = raw;
2107 intercept_adjust += raw * center;
2108 }
2109 }
2110 if let Some(i0) = intercept_col
2111 && i0 < p_per_class
2112 {
2113 coefficients_active[[i0, a]] -= intercept_adjust;
2114 }
2115 }
2116 }
2117 // Flatten every (class, term) smoothing parameter in block-major order
2118 // (class 0's terms, then class 1's, …). With per-term penalties each block
2119 // now carries one λ per smooth term, so a single λ per class would discard
2120 // the independent per-term selection that fixes #561. `lambdas_per_block`
2121 // segments the flat vector by class so callers can recover per-term λ.
2122 // ── gam#1587/#561 joint-penalty reconstruction ───────────────────────────
2123 // Under the #1587 centered-metric architecture every active class block
2124 // leaves its per-block penalty list EMPTY — the entire fit's smoothing rides
2125 // on a single full-width JOINT penalty `S_λ = Σ_t λ_t (M ⊗ S_t)` whose one
2126 // shared `λ_t` per smooth component is selected by the outer REML loop and
2127 // surfaced on `fit.artifacts.joint_log_lambdas`. So `fit.blocks[a].lambdas`
2128 // is `[]`, the inference layer's per-block trace channel is empty, and the
2129 // older per-block reporting (`lambdas_per_block = [0, 0]`, `edf_per_class =
2130 // None`, …) collapsed (#561 reopen).
2131 //
2132 // Reconstruct the per-(class, component) λ and the influence-matrix EDF
2133 // directly from the selected joint `λ_t` and the COUPLED penalty
2134 // `S_λ = Σ_t λ_t (M ⊗ S_t)` (NOT a block-diagonal `Σ_t λ_{a,t} S_t`: the
2135 // centered metric `M` couples classes off the block diagonal, so a
2136 // block-diagonal `S_λ` would mis-state both the influence matrix and every
2137 // trace). With `H⁻¹ = fit.covariance_conditional` now assembled WITH the
2138 // joint penalty (the `compute_joint_covariance` fix), the influence matrix is
2139 // exactly `F = I − H⁻¹ S_λ`, its per-class diagonal-block trace is the honest
2140 // per-class EDF, and `Σ_a edf_a = tr(F) = edf_total`.
2141 let joint_recon = fit.artifacts.joint_log_lambdas.as_ref().and_then(|jll| {
2142 let n_components = penalties_arc.len();
2143 if jll.len() != n_components || n_components == 0 {
2144 return None;
2145 }
2146 let expected_joint = p_per_class.saturating_mul(m);
2147 let hinv = fit
2148 .covariance_conditional
2149 .as_ref()
2150 .filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)?;
2151 // The coupled joint penalty components `M ⊗ S_t` at the selected `λ_t`,
2152 // in raw stacked (class-major) coordinates — exactly the operator the
2153 // inner solve and the now-fixed covariance path penalize with.
2154 let joint_specs = family.centered_joint_penalty_specs();
2155 if joint_specs.len() != n_components {
2156 return None;
2157 }
2158 let lam: Vec<f64> = jll.iter().map(|&l| l.exp()).collect();
2159 // Per-component `H⁻¹ (M ⊗ S_t)` (full mp×mp), reused for both the joint
2160 // influence matrix and the per-(class, component) trace decomposition.
2161 let mut hinv_st: Vec<Array2<f64>> = Vec::with_capacity(n_components);
2162 for spec in &joint_specs {
2163 if spec.matrix.nrows() != expected_joint || spec.matrix.ncols() != expected_joint {
2164 return None;
2165 }
2166 hinv_st.push(hinv.dot(&spec.matrix));
2167 }
2168 // F = I − H⁻¹ S_λ = I − Σ_t λ_t H⁻¹ (M ⊗ S_t).
2169 let mut f = Array2::<f64>::eye(expected_joint);
2170 for (t, hs) in hinv_st.iter().enumerate() {
2171 f.scaled_add(-lam[t], hs);
2172 }
2173 // Per-class diagonal-block trace of F (the honest per-class EDF), and the
2174 // per-(class, component) penalty trace `tr_{a,t} = λ_t · Σ_{i∈class a}
2175 // (H⁻¹ (M⊗S_t))[i,i]` for the per-penalty EDF rollup.
2176 let mut edf_per_class = Vec::with_capacity(m);
2177 // class-major per-penalty EDF (class 0's components, then class 1's, …),
2178 // aligned 1:1 with the flat per-component λ replicated per class.
2179 let mut edf_per_penalty = Vec::with_capacity(m * n_components);
2180 for a in 0..m {
2181 let base = a * p_per_class;
2182 let mut class_trace = 0.0_f64;
2183 for t in 0..n_components {
2184 let mut tr_at = 0.0_f64;
2185 for i in 0..p_per_class {
2186 tr_at += hinv_st[t][[base + i, base + i]];
2187 }
2188 tr_at *= lam[t];
2189 class_trace += tr_at;
2190 // A single component's per-class trace EDF `rank(S_t) − tr_{a,t}`,
2191 // bounded by its local rank (≤ p_per_class).
2192 let ns_t = nullspace_dims_arc.get(t).copied().unwrap_or(0);
2193 let rank_t = (p_per_class as f64 - ns_t as f64).max(0.0);
2194 edf_per_penalty.push((rank_t - tr_at).clamp(0.0, p_per_class as f64));
2195 }
2196 edf_per_class
2197 .push((p_per_class as f64 - class_trace).clamp(0.0, p_per_class as f64));
2198 }
2199 Some((f, edf_per_class, edf_per_penalty, n_components, lam))
2200 });
2201
2202 // Flatten every (class, component) smoothing parameter in class-major order.
2203 // Under the joint-penalty architecture each active class carries the SAME
2204 // per-component λ set (the centered metric ties `λ_t` across classes for
2205 // reference-class invariance), so the flat vector is the selected `λ_t`
2206 // replicated `K-1` times and `lambdas_per_block = [n_components; K-1]`. When
2207 // the joint reconstruction is unavailable (legacy fixed-λ path or absent
2208 // covariance) fall back to the raw — now empty — per-block λ lists.
2209 let (lambdas_per_block, lambdas_flat): (Vec<usize>, Vec<f64>) = match joint_recon.as_ref() {
2210 Some((_, _, _, n_components, lam)) => {
2211 let per_block = vec![*n_components; m];
2212 let mut flat = Vec::with_capacity(m * n_components);
2213 for _ in 0..m {
2214 flat.extend(lam.iter().copied());
2215 }
2216 (per_block, flat)
2217 }
2218 None => {
2219 let per_block: Vec<usize> = fit.blocks.iter().map(|b| b.lambdas.len()).collect();
2220 let flat: Vec<f64> = fit
2221 .blocks
2222 .iter()
2223 .flat_map(|b| b.lambdas.iter().copied())
2224 .collect();
2225 (per_block, flat)
2226 }
2227 };
2228 // Per-active-class effective degrees of freedom, length `K-1`, summing to
2229 // the model `edf_total`. The REML inference block reports `edf_by_block` as
2230 // ONE entry per *penalty block* (per (class, term, penalty)), each computed
2231 // as `rank(S_kk) − tr(H⁻¹ λ_kk S_kk)`. That per-block sum OVER-COUNTS the
2232 // model EDF whenever several penalties share one coefficient range — a
2233 // double-penalty / te / ti / adaptive smooth has ≥2 penalty blocks over the
2234 // same columns, so `Σ_kk rank(S_kk) > p` and `Σ_kk edf_by_block > edf_total`
2235 // (the observed ~79 for a ~24-coefficient model). Handing that raw per-block
2236 // vector out as the documented length-(K-1) per-class EDF is therefore both
2237 // the wrong LENGTH (it is `Σ_a n_blocks_a`, not `K-1`) and an over-count.
2238 //
2239 // The honest per-class EDF is the influence-matrix trace over each class's
2240 // coefficient block. Classes occupy DISJOINT `p_per_class`-wide coefficient
2241 // ranges, and the per-block traces `tr_kk = tr(H⁻¹ λ_kk S_kk)` are additive
2242 // (no rank double-counting), so class `a`'s EDF is
2243 // `p_per_class − Σ_{kk ∈ class a} tr_kk`, and `Σ_a edf_a = m·p_per_class −
2244 // Σ_kk tr_kk = p − Σ tr_kk = edf_total` exactly. Segment the block-major
2245 // `penalty_block_trace` by `lambdas_per_block` (the same per-class λ-count
2246 // segmentation `lambdas_flat` uses). Fall back to `None` when the trace
2247 // channel is unavailable or mis-shaped (legacy fixed-λ path), exactly as the
2248 // raw `edf_by_block` map did before.
2249 let edf_per_class = joint_recon
2250 .as_ref()
2251 .map(|(_, epc, _, _, _)| epc.clone())
2252 .or_else(|| {
2253 // Legacy per-block trace path (fixed-λ / pre-#1587 fits whose
2254 // smoothing is still carried per block). Segment the block-major
2255 // `penalty_block_trace` by `lambdas_per_block`, exactly as before.
2256 fit.inference.as_ref().and_then(|info| {
2257 let traces = &info.penalty_block_trace;
2258 if traces.len() != lambdas_per_block.iter().sum::<usize>() {
2259 return None;
2260 }
2261 let mut per_class = Vec::with_capacity(m);
2262 let mut cursor = 0usize;
2263 for &n_blocks in &lambdas_per_block {
2264 let class_trace: f64 = traces[cursor..cursor + n_blocks].iter().sum();
2265 per_class
2266 .push((p_per_class as f64 - class_trace).clamp(0.0, p_per_class as f64));
2267 cursor += n_blocks;
2268 }
2269 Some(per_class)
2270 })
2271 });
2272 // Per-PENALTY EDF: the inference layer's `edf_by_block` is already the
2273 // clamped per-penalty-block trace EDF `rank(S_k) − λ_k·tr(H⁻¹ S_k)`, one
2274 // entry per smoothing parameter and block-major aligned 1:1 with the flat
2275 // `lambdas`. Surface it verbatim (guarding only on the length contract) so
2276 // consumers can inspect per-(class, term, penalty) collapse onto the null
2277 // space — a signal the per-class EDF SUM hides. This is NOT a per-class
2278 // total: with double-penalty smooths `Σ_k rank(S_k) > p_per_class`, so the
2279 // entries deliberately need not sum to the model EDF (the per-class field
2280 // carries that contract instead).
2281 let edf_per_penalty = joint_recon
2282 .as_ref()
2283 .map(|(_, _, epp, _, _)| epp.clone())
2284 .or_else(|| {
2285 // Legacy per-block path: the inference layer's `edf_by_block` is
2286 // already the clamped per-penalty-block trace EDF, aligned 1:1 with
2287 // the flat `lambdas`.
2288 fit.inference.as_ref().and_then(|info| {
2289 if info.edf_by_block.len() != lambdas_flat.len() {
2290 return None;
2291 }
2292 Some(
2293 info.edf_by_block
2294 .iter()
2295 .map(|&e| e.max(0.0))
2296 .collect::<Vec<f64>>(),
2297 )
2298 })
2299 });
2300 let coefficients_flat: Vec<f64> = coefficients_active.iter().copied().collect();
2301
2302 // #1101: surface the joint Laplace posterior covariance `H⁻¹` (block-ordered
2303 // [β_0; …; β_{K-2}]) and the influence matrix `F = H⁻¹ X'WX` the REML driver
2304 // computed at the converged mode. These power the predict path's delta-method
2305 // per-class probability standard errors and the summary's Wald smooth-term
2306 // tests. The joint matrices are `(P·M)×(P·M)`. The covariance is mapped back
2307 // to RAW units (see below) so it pairs with the raw predict design; the
2308 // influence is kept in the fitted basis (the Wald table only slices penalized
2309 // columns, which the standardization affine leaves identity-mapped).
2310 let expected_joint = p_per_class.saturating_mul(m);
2311 // The joint Hessian (and thus `H⁻¹`) was assembled in the STANDARDIZED
2312 // parametric basis used during fitting, while the saved coefficients and the
2313 // raw predict design are in raw units. Map the covariance to raw units with
2314 // the same exact affine reparameterization `β_raw = A β_std`: for each
2315 // standardized parametric column `col`, `β_raw[col] = β_std[col]/scale` and
2316 // the intercept absorbs `−Σ_col (center/scale)·β_std[col]`. So `A = I` except
2317 // `A[col,col] = 1/scale` and `A[i0,col] = −center/scale`, replicated
2318 // block-diagonally per active class, and `Cov_raw = A Cov_std Aᵀ`. With no
2319 // standardization (`parametric_standardization` empty) `A = I` and this is a
2320 // no-op. The smooth-term (penalized) columns are untouched by `A`, so the
2321 // Wald table's per-term blocks are identical in both bases.
2322 let intercept_col0 = design.intercept_range.clone().next();
2323 let build_per_class_affine = |amat: &mut Array2<f64>| {
2324 for &(col, center, scale) in ¶metric_standardization {
2325 if col >= p_per_class {
2326 continue;
2327 }
2328 amat[[col, col]] = 1.0 / scale;
2329 if let Some(i0) = intercept_col0
2330 && i0 < p_per_class
2331 {
2332 amat[[i0, col]] = -center / scale;
2333 }
2334 }
2335 };
2336 let coefficient_covariance_flat = fit
2337 .covariance_conditional
2338 .as_ref()
2339 .filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)
2340 .map(|cov_std| {
2341 if parametric_standardization.is_empty() {
2342 return cov_std.iter().copied().collect::<Vec<f64>>();
2343 }
2344 // Block-diagonal joint A (same per active class).
2345 let mut a_joint = Array2::<f64>::eye(expected_joint);
2346 let mut a_class = Array2::<f64>::eye(p_per_class);
2347 build_per_class_affine(&mut a_class);
2348 for a in 0..m {
2349 let base = a * p_per_class;
2350 for i in 0..p_per_class {
2351 for j in 0..p_per_class {
2352 a_joint[[base + i, base + j]] = a_class[[i, j]];
2353 }
2354 }
2355 }
2356 let cov_raw = a_joint.dot(cov_std).dot(&a_joint.t());
2357 cov_raw.iter().copied().collect::<Vec<f64>>()
2358 });
2359 // The influence matrix `F = H⁻¹ X'WX = H⁻¹(H − S_λ) = I − H⁻¹ S_λ`. The
2360 // exact-Newton multinomial blocks carry no IRLS pseudo-data, so the generic
2361 // inference path does not export `coefficient_influence`; reconstruct it
2362 // exactly here. Under the #1587 joint-penalty architecture the penalty is the
2363 // COUPLED centered metric `S_λ = Σ_t λ_t (M ⊗ S_t)` (off the class-block
2364 // diagonal), already assembled in `joint_recon` above, so reuse that exact
2365 // `F`. Only fall back to the legacy block-diagonal `Σ_t λ_{a,t} S_t`
2366 // reconstruction when the joint reconstruction is unavailable (pre-#1587
2367 // per-block fits whose class blocks still carry their own penalties).
2368 let coefficient_influence_flat = match joint_recon.as_ref() {
2369 Some((f, _, _, _, _)) => Some(f.iter().copied().collect::<Vec<f64>>()),
2370 None => fit
2371 .covariance_conditional
2372 .as_ref()
2373 .filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)
2374 .and_then(|hinv| {
2375 if fit.blocks.len() != m {
2376 return None;
2377 }
2378 // Joint S_λ (block-diagonal across active classes).
2379 let mut s_lambda = Array2::<f64>::zeros((expected_joint, expected_joint));
2380 for (a, block) in fit.blocks.iter().enumerate() {
2381 if block.lambdas.len() != penalties_arc.len() {
2382 return None;
2383 }
2384 let base = a * p_per_class;
2385 for (t, pen) in penalties_arc.iter().enumerate() {
2386 let lam = block.lambdas[t];
2387 if lam == 0.0 {
2388 continue;
2389 }
2390 let dense = pen.to_dense();
2391 if dense.nrows() != p_per_class || dense.ncols() != p_per_class {
2392 return None;
2393 }
2394 for i in 0..p_per_class {
2395 for j in 0..p_per_class {
2396 s_lambda[[base + i, base + j]] += lam * dense[[i, j]];
2397 }
2398 }
2399 }
2400 }
2401 // F = I − H⁻¹ S_λ.
2402 let hinv_s = hinv.dot(&s_lambda);
2403 let mut f = Array2::<f64>::eye(expected_joint);
2404 f -= &hinv_s;
2405 Some(f.iter().copied().collect::<Vec<f64>>())
2406 }),
2407 };
2408
2409 // Per-(smooth term) coefficient span within a single class block, deduped by
2410 // col_range (the #561 double-penalty migration emits two penalty blocks per
2411 // term sharing one col_range; the Wald test covers the whole term block once).
2412 let mut smooth_term_spans: Vec<MultinomialSmoothTermSpan> = Vec::new();
2413 for (pen_idx, bp) in design.penalties.iter().enumerate() {
2414 let col_start = bp.col_range.start;
2415 let col_end = bp.col_range.end;
2416 if col_start >= col_end || col_end > p_per_class {
2417 continue;
2418 }
2419 if smooth_term_spans
2420 .iter()
2421 .any(|s| s.col_start == col_start && s.col_end == col_end)
2422 {
2423 continue;
2424 }
2425 let label = design
2426 .penaltyinfo
2427 .get(pen_idx)
2428 .and_then(|info| info.termname.clone())
2429 .unwrap_or_else(|| format!("s{pen_idx}"));
2430 let nullspace_dim = design
2431 .nullspace_dims
2432 .get(pen_idx)
2433 .copied()
2434 .unwrap_or(0)
2435 .min(col_end - col_start);
2436 smooth_term_spans.push(MultinomialSmoothTermSpan {
2437 label,
2438 col_start,
2439 col_end,
2440 nullspace_dim,
2441 });
2442 }
2443
2444 // One descriptive label per penalty *component* within a single class block,
2445 // parallel to that block's λ slice (#1544). `design.penalties` is index-
2446 // parallel to every active class's `block.lambdas` (each block carries the
2447 // full per-component penalty list, validated above by
2448 // `block.lambdas.len() == penalties_arc.len()`), so iterating it in order
2449 // yields exactly `lambdas_per_block[0]` labels aligned with the per-block λ.
2450 // This is deliberately NOT deduped by col_range (unlike `smooth_term_spans`):
2451 // the double penalty's primary and null-space components share one col_range
2452 // but select independent λ, and each must keep its own label so the summary
2453 // renderer never collapses or drops a λ.
2454 let lambda_labels: Vec<String> = design
2455 .penalties
2456 .iter()
2457 .enumerate()
2458 .map(|(pen_idx, _)| penalty_component_label(design.penaltyinfo.get(pen_idx), pen_idx))
2459 .collect();
2460
2461 // Unpenalized deviance read directly from the converged unpenalized
2462 // log-likelihood the rho-prior driver already computed (issue #348):
2463 // MultinomialFamily::evaluate sets FamilyEvaluation.log_likelihood =
2464 // log_lik(η, y) with no penalty term, and that value flows unchanged into
2465 // UnifiedFitResult.log_likelihood. This reproduces the legacy fixed-λ
2466 // path's `deviance = -2 · log_lik` contract bit-for-bit, so the previous
2467 // row-by-row η = Xβ rebuild and softmax recompute were pure dead work.
2468 let deviance = -2.0 * fit.log_likelihood;
2469
2470 Ok(MultinomialSavedModel {
2471 formula: formula.to_string(),
2472 class_levels: class_levels.clone(),
2473 reference_class_index: class_levels.len() - 1,
2474 resolved_termspec: spec,
2475 coefficients_flat,
2476 p_per_class,
2477 n_active_classes: m,
2478 training_headers: data.headers.clone(),
2479 lambdas: lambdas_flat,
2480 lambdas_per_block,
2481 iterations: fit.inner_cycles,
2482 converged: fit.outer_converged,
2483 penalized_neg_log_likelihood: -fit.log_likelihood + 0.5 * fit.stable_penalty_term,
2484 deviance,
2485 edf_per_class,
2486 edf_per_penalty,
2487 coefficient_covariance_flat,
2488 coefficient_influence_flat,
2489 smooth_term_spans,
2490 lambda_labels,
2491 })
2492}
2493
2494/// Replay the saved termspec to build the predict-time dense design `X` on a
2495/// fresh dataset, realigning feature columns **by name** so the predict frame
2496/// need not reproduce the training column order or carry the response column.
2497/// Shared by every multinomial predict path (probabilities, SE bands, and the
2498/// posterior-predictive replicate draws).
2499fn build_multinomial_predict_design(
2500 model: &MultinomialSavedModel,
2501 data: &EncodedDataset,
2502) -> Result<Array2<f64>, EstimationError> {
2503 // The saved termspec stores feature columns as absolute indices into the
2504 // *training* table `[response, features...]`. Realign them onto this
2505 // dataset's columns by name, so prediction works on label-free new data
2506 // (the response column is never referenced by any term; issue #803).
2507 let predict_columns = data.column_map();
2508 let realigned = model.resolved_termspec.remap_feature_columns(
2509 |index| -> Result<usize, EstimationError> {
2510 let name = model.training_headers.get(index).ok_or_else(|| {
2511 EstimationError::InvalidInput(format!(
2512 "multinomial predict: saved training column index {index} is out of bounds \
2513 for {} training headers",
2514 model.training_headers.len()
2515 ))
2516 })?;
2517 resolve_role_col(&predict_columns, name, "feature")
2518 .map_err(|err| EstimationError::InvalidInput(err.to_string()))
2519 },
2520 )?;
2521 let design = build_term_collection_design(data.values.view(), &realigned).map_err(|err| {
2522 EstimationError::InvalidInput(format!(
2523 "multinomial predict: rebuild design from saved termspec: {err}"
2524 ))
2525 })?;
2526 let x_dense = design
2527 .design
2528 .try_to_dense_by_chunks("multinomial predict design")
2529 .map_err(EstimationError::InvalidInput)?;
2530 if x_dense.ncols() != model.p_per_class {
2531 crate::bail_invalid_estim!(
2532 "multinomial predict: predict design has {} cols, saved model expects {}",
2533 x_dense.ncols(),
2534 model.p_per_class
2535 );
2536 }
2537 Ok(x_dense)
2538}
2539
2540/// Replay the saved termspec to build the predict-time design on a fresh
2541/// dataset, then evaluate softmax probabilities. The predict dataset must carry
2542/// the same feature columns the training data did, matched **by name** — it need
2543/// not reproduce the training column order, and in particular need not carry the
2544/// response column (prediction is for label-free new data).
2545pub fn predict_multinomial_formula(
2546 model: &MultinomialSavedModel,
2547 data: &EncodedDataset,
2548) -> Result<Array2<f64>, EstimationError> {
2549 let x_dense = build_multinomial_predict_design(model, data)?;
2550 Ok(model.predict_probabilities(x_dense.view()))
2551}
2552
2553/// Draw `n_draws` posterior-predictive replicate class-label assignments for a
2554/// saved multinomial model on fresh data (#1101). Rebuilds the predict design
2555/// exactly as [`predict_multinomial_formula`], then samples each row's class
2556/// from `Categorical(softmax(X·β̂))` (see
2557/// [`MultinomialSavedModel::sample_replicate_classes`]). Returns an
2558/// `(n_draws, N)` matrix of class INDICES `0..K` aligned to `model.class_levels`,
2559/// deterministic in `seed`.
2560pub fn posterior_predict_multinomial_formula(
2561 model: &MultinomialSavedModel,
2562 data: &EncodedDataset,
2563 n_draws: usize,
2564 seed: u64,
2565) -> Result<Array2<u32>, EstimationError> {
2566 if n_draws == 0 {
2567 crate::bail_invalid_estim!("multinomial posterior_predict: n_draws must be >= 1");
2568 }
2569 let x_dense = build_multinomial_predict_design(model, data)?;
2570 Ok(model.sample_replicate_classes(x_dense.view(), n_draws, seed))
2571}
2572
2573/// Predict class probabilities AND delta-method per-class probability standard
2574/// errors for a saved multinomial model on fresh data (#1101). Replays the
2575/// saved termspec to build the predict design exactly as
2576/// [`predict_multinomial_formula`], then applies the softmax-Jacobian delta
2577/// method against the stored joint posterior covariance. Returns
2578/// `(probs (N,K), prob_se (N,K) | None)`; `prob_se` is `None` for a legacy
2579/// model fitted before covariance was surfaced.
2580pub fn predict_multinomial_formula_with_se(
2581 model: &MultinomialSavedModel,
2582 data: &EncodedDataset,
2583) -> Result<(Array2<f64>, Option<Array2<f64>>), EstimationError> {
2584 let x_dense = build_multinomial_predict_design(model, data)?;
2585 Ok(model.predict_probabilities_with_se(x_dense.view()))
2586}
2587
2588#[cfg(test)]
2589mod fisher_override_tests {
2590 use super::*;
2591 use ndarray::Array3;
2592
2593 fn toy() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array1<f64>) {
2594 let n = 15;
2595 let p = 2;
2596 let k = 3;
2597 let design =
2598 Array2::<f64>::from_shape_fn(
2599 (n, p),
2600 |(i, j)| {
2601 if j == 0 { 1.0 } else { ((i + 2) as f64).cos() }
2602 },
2603 );
2604 let mut y = Array2::<f64>::zeros((n, k));
2605 for i in 0..n {
2606 y[[i, i % k]] = 1.0;
2607 }
2608 let penalty = Array2::<f64>::eye(p);
2609 let lambdas = Array1::<f64>::from_elem(k - 1, 0.5);
2610 (design, y, penalty, lambdas)
2611 }
2612
2613 #[test]
2614 fn fisher_override_none_reproduces_analytic() {
2615 // Issue #349: None override is exactly the analytic fit.
2616 let (design, y, penalty, lambdas) = toy();
2617 let mk = |over: Option<ndarray::ArrayView3<'_, f64>>| {
2618 fit_penalized_multinomial(MultinomialFitInputs {
2619 design: design.view(),
2620 y_one_hot: y.view(),
2621 penalty: penalty.view(),
2622 lambdas: lambdas.view(),
2623 row_weights: None,
2624 fisher_w_override: over,
2625 max_iter: 50,
2626 tol: 1.0e-9,
2627 })
2628 .expect("fit must succeed")
2629 };
2630 let a = mk(None);
2631 let b = mk(None);
2632 for (x, z) in a
2633 .coefficients_active
2634 .iter()
2635 .zip(b.coefficients_active.iter())
2636 {
2637 assert_eq!(x, z);
2638 }
2639 }
2640
2641 #[test]
2642 fn fisher_override_wrong_shape_is_rejected() {
2643 let (design, y, penalty, lambdas) = toy();
2644 let n = design.nrows();
2645 let m = y.ncols(); // K, not K-1 — deliberately wrong
2646 let bad = Array3::<f64>::zeros((n, m, m));
2647 let err = fit_penalized_multinomial(MultinomialFitInputs {
2648 design: design.view(),
2649 y_one_hot: y.view(),
2650 penalty: penalty.view(),
2651 lambdas: lambdas.view(),
2652 row_weights: None,
2653 fisher_w_override: Some(bad.view()),
2654 max_iter: 50,
2655 tol: 1.0e-9,
2656 })
2657 .expect_err("wrong active-block shape must error");
2658 assert!(format!("{err}").contains("fisher_w_override shape"));
2659 }
2660
/// Regression guard for #1101: the fixed-λ inner solve exposes the joint
/// Laplace coefficient covariance `H⁻¹`, and the fitted model turns it into
/// delta-method standard errors for the per-class probabilities. Per the
/// original regression note, `MultinomialFitOutputs` carried no covariance
/// before that change, so nothing below could compile against the old API
/// (fail-before).
///
/// Checked on the `toy()` fixture:
/// 1. the covariance is `(P·(K−1)) × (P·(K−1))`, entirely finite, symmetric
///    up to a relative `1e-9`, has a non-negative diagonal, and satisfies
///    `vᵀΣv ≥ −1e-9` on every unit axis and on the all-ones vector;
/// 2. each delta-method probability SE is finite and inside `[0, 1]`;
/// 3. each predicted probability is finite and inside `[0, 1]`, and every
///    row sums to 1 within `1e-9`.
#[test]
fn covariance_and_delta_method_se_are_finite_and_wellformed_1101() {
    let (design, y, penalty, lambdas) = toy();
    let k = y.ncols();
    // Stacked coefficient dimension: P design columns per active class (K−1 of them).
    let d = design.ncols() * (k - 1);

    let inputs = MultinomialFitInputs {
        design: design.view(),
        y_one_hot: y.view(),
        penalty: penalty.view(),
        lambdas: lambdas.view(),
        row_weights: None,
        fisher_w_override: None,
        max_iter: 50,
        tol: 1.0e-9,
    };
    let fit = fit_penalized_multinomial(inputs).expect("fit must succeed");
    assert!(fit.converged, "toy multinomial fit must converge");

    // (1) Shape first, then entry-wise finiteness, then symmetry.
    let cov = &fit.coefficient_covariance;
    assert_eq!(
        cov.dim(),
        (d, d),
        "covariance must be (P·(K−1))² = ({d},{d})"
    );
    cov.iter()
        .for_each(|&v| assert!(v.is_finite(), "covariance entry must be finite (got {v})"));
    // Every ordered pair is visited, so the relative tolerance is applied
    // with both Σ_ij and Σ_ji as the scale.
    for (i, j) in (0..d).flat_map(|i| (0..d).map(move |j| (i, j))) {
        let asym = (cov[[i, j]] - cov[[j, i]]).abs();
        let allowed = 1e-9 * (1.0 + cov[[i, j]].abs());
        assert!(
            asym <= allowed,
            "covariance must be symmetric at ({i},{j}): |Σ_ij − Σ_ji| = {asym:.3e}"
        );
    }

    // PSD evidence. The original note argues `H = XᵀWX + λS` is positive
    // definite (W is the softmax Fisher weight, S the fixture's penalty), so
    // its inverse is too; the probes only need to come out non-negative up
    // to round-off.
    for i in 0..d {
        let diag = cov[[i, i]];
        assert!(
            diag >= 0.0,
            "covariance diagonal[{i}] must be ≥ 0 (got {})",
            diag
        );
    }
    // Deterministic probes: each unit axis, followed by the all-ones vector.
    let probes: Vec<Vec<f64>> = (0..d)
        .map(|axis| {
            let mut unit = vec![0.0_f64; d];
            unit[axis] = 1.0;
            unit
        })
        .chain(std::iter::once(vec![1.0_f64; d]))
        .collect();
    for v in &probes {
        // Row-major accumulation of vᵀΣv.
        let q = (0..d).fold(0.0_f64, |acc, i| {
            (0..d).fold(acc, |acc, j| acc + v[i] * cov[[i, j]] * v[j])
        });
        assert!(
            q >= -1e-9,
            "covariance must be PSD: vᵀΣv = {q:.3e} < 0"
        );
    }

    // (2) & (3) Predict on the training design itself; the original note
    // says any P-column matrix in the fitted basis would do.
    let (probs, prob_se) = fit
        .predict_probabilities_with_se(design.view())
        .expect("delta-method SE must succeed");
    let n = design.nrows();
    assert_eq!(probs.dim(), (n, k));
    assert_eq!(prob_se.dim(), (n, k));
    for row in 0..n {
        let rowsum = (0..k).fold(0.0_f64, |partial, c| {
            let pc = probs[[row, c]];
            assert!(pc.is_finite() && (0.0..=1.0).contains(&pc), "prob[{row},{c}]={pc}");
            let se = prob_se[[row, c]];
            assert!(se.is_finite(), "prob_se[{row},{c}] must be finite (got {se})");
            assert!(
                (0.0..=1.0).contains(&se),
                "prob_se[{row},{c}] must be in [0,1] (got {se})"
            );
            partial + pc
        });
        assert!(
            (rowsum - 1.0).abs() < 1e-9,
            "row {row} probabilities must sum to 1 (got {rowsum})"
        );
    }
}
2771
#[test]
fn formula_outer_route_uses_exact_curvature_for_medium_d() {
    // Outer dimensions that must stay on the exact-curvature route, each
    // paired with its failure message. D = 8 is the 2-smooth reference
    // formula fit (K = 3, double-penalty terms: (K-1) * 2 terms * 2
    // penalties); it needs exact curvature to avoid over-smoothed lambda
    // caps (#715 arm (a)). D = 12 is the 3-smooth analogue.
    let exact_curvature_cases = [
        (
            8,
            "D=8 loaded multinomial fits need exact curvature to avoid over-smoothed lambda caps",
        ),
        (
            12,
            "D=12 (3 double-penalty smooth terms, K=3) stays on exact curvature",
        ),
    ];
    for (outer_dim, why) in exact_curvature_cases {
        assert!(multinomial_formula_use_outer_hessian(outer_dim), "{why}");
    }
}
2786
#[test]
fn formula_outer_route_uses_exact_curvature_for_d16_penguin_fixture() {
    // The penguin fixture has four k=10 smooths with K = 3, which is D = 16
    // under double-penalty terms. That size must still be routed to the
    // exact ARC path, because the #1082 cost-stall halt (needed on the
    // near-separable lambda-to-zero ridge) is only available there.
    let penguin_outer_dim = 16;
    let routed_to_exact_arc = multinomial_formula_use_outer_hessian(penguin_outer_dim);
    assert!(
        routed_to_exact_arc,
        "D=16 multinomial fits need exact ARC curvature for the #1082 stall halt"
    );
}
2797
#[test]
fn formula_min_lambda_floor_is_continuous_and_information_scaled() {
    // Row count of the dominant class in every synthetic label matrix.
    const MAJORITY_ROWS: usize = 1000;

    // λ floor for a two-class one-hot response: `MAJORITY_ROWS` rows of
    // class 0 followed by `count` rows of class 1 (the minority class for
    // every count used below).
    fn floor_for_min_count(count: usize) -> f64 {
        let total = MAJORITY_ROWS + count;
        let labels = Array2::<f64>::from_shape_fn((total, 2), |(row, col)| {
            let hot = if row < MAJORITY_ROWS { 0 } else { 1 };
            if col == hot {
                1.0
            } else {
                0.0
            }
        });
        multinomial_formula_min_lambda(labels.view())
    }

    // Strict absolute-tolerance comparison shared by the checks below.
    let within = |lhs: f64, rhs: f64, tol: f64| (lhs - rhs).abs() < tol;

    // Both endpoints are derived as (prior strength in pseudo-observations)
    // × (per-observation softmax Fisher information): base = τ·I₁ and
    // sparse = τ_max·I₁. They are pinned to the earlier fixture-calibrated
    // constants so fits whose smallest class has n_c ≥ 50 (the penguins and
    // vgam softmax quality arms, per the original note) see no change.
    let base = MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
    let sparse = MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX
        * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
    assert!(
        within(base, 2.0e-4, 1e-18),
        "derived base floor must equal the calibrated 2e-4"
    );
    assert!(
        within(sparse, 1.0e-3, 1e-18),
        "derived sparse floor must equal the calibrated 1e-3"
    );

    // At or above the reference support (n_ref = 50) the floor is the base value.
    assert!(within(floor_for_min_count(50), base, 1e-18));
    assert!(within(floor_for_min_count(200), base, 1e-18));
    // At or below n_ref·base/sparse = 10 the floor clamps to the strong value.
    assert!(within(floor_for_min_count(10), sparse, 1e-18));
    assert!(within(floor_for_min_count(5), sparse, 1e-18));

    // Continuity at the former hard threshold: 49 and 50 differ by under 5%
    // (the old step was a 5x jump), and the floor never rises with support.
    let f49 = floor_for_min_count(49);
    let f50 = floor_for_min_count(50);
    assert!(
        f49 >= f50 && f49 <= f50 * 1.05,
        "floor must be continuous across c0, got {f49} vs {f50}"
    );
    let f25 = floor_for_min_count(25);
    assert!(
        f25 > f50 && f25 < floor_for_min_count(10),
        "mid-support floor must interpolate strictly between the two endpoints"
    );

    // Closed form in the interpolating band: floor = τ·I₁·(n_ref / n_c),
    // i.e. the prior is held to a fixed fraction of the per-class data
    // information n_c·I₁.
    for n_c in [12usize, 16, 20, 30, 40] {
        let expected = base * (MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / n_c as f64);
        let actual = floor_for_min_count(n_c);
        assert!(
            within(actual, expected, 1e-15),
            "floor at n_c={n_c} must be τ·I₁·n_ref/n_c = {expected}, got {}",
            actual
        );
    }
    // Halving the support doubles the floor while both points are interior:
    // 40 < 50 so it is scaled, 20 > 10 so it is not capped.
    let doubled_from_40 = 2.0 * floor_for_min_count(40);
    assert!(
        within(floor_for_min_count(20), doubled_from_40, 1e-15),
        "floor must scale like 1/n_c (effective Fisher information) in the interior band"
    );
}
2874
#[test]
fn formula_penalty_scale_tracks_softmax_fisher_curvature() {
    // Absolute tolerance for the two closed-form comparisons.
    let tol = 1.0e-12;

    // K = 2: the scale must stay at 1/2.
    let binary_scale = multinomial_formula_penalty_scale(2);
    assert!(
        (binary_scale - 0.5).abs() < tol,
        "binary-logit neutral-simplex curvature scale should remain at 1/2"
    );

    // K = 3: 2·(K−1)/K² = 4/9.
    let ternary_scale = multinomial_formula_penalty_scale(3);
    assert!(
        (ternary_scale - 4.0 / 9.0).abs() < tol,
        "three-class softmax penalties should be calibrated to 2*(K-1)/K^2"
    );

    // Going from three to five classes must strictly lower the scale.
    let five_class_scale = multinomial_formula_penalty_scale(5);
    assert!(
        five_class_scale < multinomial_formula_penalty_scale(3),
        "active-class Fisher curvature decreases as the simplex gains classes"
    );
}
2890
#[test]
fn fixed_lambda_multinomial_firth_keeps_complete_separation_finite() {
    // #1854: complete softmax separation used to surface as a hard
    // `MultinomialSeparationDetected` diagnostic. Per the original note the
    // fixed-λ solve now engages the Firth/Jeffreys proper prior
    // (`½ log|I(β)|`) automatically, matching what the formula REML path
    // already guaranteed. The three class regions here are split cleanly by
    // `x` and nothing is penalized (zero penalty, zero λ), so the unbiased
    // MLE sits at infinity; the fit must still converge to a finite mode and
    // classify each region correctly.
    let n = 90;
    let last_row = (n - 1) as f64;
    // Column 0 is the intercept; column 1 is an even grid on [-3, 3].
    let design = Array2::<f64>::from_shape_fn((n, 2), |(row, col)| {
        if col == 0 {
            1.0
        } else {
            -3.0 + 6.0 * (row as f64) / last_row
        }
    });
    // One-hot labels: x < -1 is class 0, x > 1 is class 1, the middle band is class 2.
    let mut y = Array2::<f64>::zeros((n, 3));
    for row in 0..n {
        let class = match design[[row, 1]] {
            x if x < -1.0 => 0,
            x if x > 1.0 => 1,
            _ => 2,
        };
        y[[row, class]] = 1.0;
    }
    let penalty = Array2::<f64>::zeros((2, 2));
    let lambdas = Array1::<f64>::zeros(2);
    let inputs = MultinomialFitInputs {
        design: design.view(),
        y_one_hot: y.view(),
        penalty: penalty.view(),
        lambdas: lambdas.view(),
        row_weights: None,
        fisher_w_override: None,
        max_iter: 80,
        tol: 1.0e-12,
    };
    let out = fit_penalized_multinomial(inputs)
        .expect("Firth/Jeffreys prior keeps the separated multinomial fit finite (#1854)");
    assert!(
        out.converged,
        "the Firth-penalized separation refit must report convergence"
    );

    // Finite coefficients are the point of the prior on the separated
    // (unpenalized) logit directions.
    out.coefficients_active.iter().for_each(|&b| {
        assert!(
            b.is_finite(),
            "Firth-penalized coefficients must be finite, got {b}"
        )
    });

    // Each fitted row must still be a valid probability simplex.
    for row in 0..n {
        let mass = (0..3).fold(0.0_f64, |partial, c| {
            let p = out.fitted_probabilities[[row, c]];
            assert!(
                p.is_finite() && (0.0..=1.0 + 1e-9).contains(&p),
                "row {row} class {c} probability {p} out of [0,1]"
            );
            partial + p
        });
        assert!(
            (mass - 1.0).abs() < 1e-6,
            "row {row} probabilities must sum to 1, got {mass}"
        );
    }

    // Arg-max class at covariate `x`. The two active logits come from the
    // fitted coefficients; the reference class (index 2) has logit 0. Ties
    // go to the lower class index.
    let predict = |x: f64| -> usize {
        let active_logit =
            |a: usize| out.coefficients_active[[0, a]] + out.coefficients_active[[1, a]] * x;
        let eta = [active_logit(0), active_logit(1), 0.0_f64];
        (1..3).fold(0usize, |best, c| if eta[c] > eta[best] { c } else { best })
    };
    // One clearly interior representative per region.
    assert_eq!(predict(-2.5), 0, "deep-left region should predict class 0");
    assert_eq!(predict(2.5), 1, "deep-right region should predict class 1");
    assert_eq!(predict(0.0), 2, "central region should predict class 2");
}
2977
#[test]
fn formula_multinomial_accepts_finite_saturated_logits() {
    // Builds one per-class block state from literal coefficients and logits.
    let block = |beta: [f64; 2], eta: [f64; 3]| ParameterBlockState {
        beta: Array1::from_vec(beta.to_vec()),
        eta: Array1::from_vec(eta.to_vec()),
    };

    // A saturated but FINITE logit surface can be a legitimate formula REML
    // mode (the #715 penguins regime, where fitted logits can exceed ±25).
    // Per the original note, `outer_converged == false` then only means the
    // driver escalated to posterior sampling about that finite mode
    // (gam#860), so the adapter must accept the fit and never raise
    // `MultinomialSeparationDetected`.
    let saturated_states = vec![
        block([1.0, 2.0], [0.2, 4.0, -7.0]),
        block([-1.0, 3.0], [1.0, 25.5, -0.1]),
    ];
    let saturated_diagnostic = multinomial_formula_separation_diagnostic(17, 9, &saturated_states);
    assert!(
        saturated_diagnostic.is_none(),
        "a finite (even saturated, |eta|>25) formula optimum is a valid fit, \
         not a separation diagnostic"
    );

    // Only a genuinely NON-FINITE logit is a real formula-path failure. The
    // infinite entry sits in block 1 (the second active class) at row 1.
    let blown_up = vec![
        block([1.0, 2.0], [0.2, 4.0, -7.0]),
        block([-1.0, 3.0], [1.0, f64::INFINITY, -0.1]),
    ];
    let err = multinomial_formula_separation_diagnostic(17, 9, &blown_up)
        .expect("a non-finite formula logit must raise the separation diagnostic");
    let names_non_finite_channel = matches!(
        err,
        EstimationError::MultinomialSeparationDetected {
            iteration: 17,
            max_abs_eta,
            active_class_index: 1,
            row_index: 1,
        } if !max_abs_eta.is_finite()
    );
    assert!(
        names_non_finite_channel,
        "expected typed multinomial separation diagnostic at the non-finite channel, got {err:?}"
    );
}
3031
#[test]
fn separation_evidence_gate_arms_firth_only_on_blowup() {
    // Builds one per-class block state from literal coefficients and logits.
    let block = |beta: [f64; 2], eta: [f64; 3]| ParameterBlockState {
        beta: Array1::from_vec(beta.to_vec()),
        eta: Array1::from_vec(eta.to_vec()),
    };

    // Case 1: interior fit. Every logit is finite and small, so there is no
    // separation evidence and the Firth/Jeffreys prior stays disarmed
    // (#715 arm (a): no 1/K shrinkage on well-identified data).
    let interior = vec![
        block([1.0, 2.0], [0.2, 4.0, -7.0]),
        block([-1.0, 3.0], [1.0, -3.5, -0.1]),
    ];
    let interior_evidence = multinomial_formula_separation_evidence(&interior);
    assert!(
        interior_evidence.is_none(),
        "an interior finite mode must not arm the Firth refit"
    );

    // Case 2: saturated but finite (|eta| = 25.5). Such modes are valid on
    // near-separated real data, and per the original note arming Firth here
    // could over-regularize the held-out probabilities.
    let saturated = vec![
        block([1.0, 2.0], [0.2, 4.0, -7.0]),
        block([-1.0, 3.0], [1.0, 25.5, -0.1]),
    ];
    let saturated_evidence = multinomial_formula_separation_evidence(&saturated);
    assert!(
        saturated_evidence.is_none(),
        "a finite saturated formula-mode logit must not arm the Firth refit"
    );

    // Case 3: a NaN logit at row 1 counts as an inner blow-up and therefore
    // as separation evidence; the evidence text must name it.
    let blown_up = vec![block([1.0, 2.0], [0.2, f64::NAN, -7.0])];
    let evidence = multinomial_formula_separation_evidence(&blown_up)
        .expect("a non-finite logit is separation evidence");
    let mentions_non_finite = evidence.contains("non-finite logit");
    let mentions_row = evidence.contains("row 1");
    assert!(
        mentions_non_finite && mentions_row,
        "evidence must name the non-finite logit, got {evidence}"
    );

    // Case 4: large finite logits (±24.9), below the fixed-lambda diagnostic
    // threshold, are accepted on the formula path as well.
    let near = vec![block([1.0, 2.0], [0.2, 24.9, -24.9])];
    let near_evidence = multinomial_formula_separation_evidence(&near);
    assert!(
        near_evidence.is_none(),
        "logits below the saturation threshold must not arm the Firth refit"
    );
}
3095
#[test]
fn unresolved_probe_evidence_arms_firth_on_saturated_finite_logits() {
    // Builds one per-class block state from literal coefficients and logits.
    let block = |beta: [f64; 2], eta: [f64; 3]| ParameterBlockState {
        beta: Array1::from_vec(beta.to_vec()),
        eta: Array1::from_vec(eta.to_vec()),
    };

    // Finite but saturated: block 1 carries a 25.5 logit at row 1.
    let saturated = vec![
        block([1.0, 2.0], [0.2, 4.0, -7.0]),
        block([-1.0, 3.0], [1.0, 25.5, -0.1]),
    ];

    // The converged-fit gate ignores this state...
    let converged_gate = multinomial_formula_separation_evidence(&saturated);
    assert!(
        converged_gate.is_none(),
        "a converged finite saturated formula optimum remains unbiased"
    );
    // ...while the unresolved-probe gate reports it and names the channel.
    let evidence = multinomial_formula_unresolved_probe_separation_evidence(&saturated)
        .expect("a non-converged saturated probe should arm the Firth refit");
    let mentions_scale = evidence.contains("separation-scale finite logit");
    let mentions_row = evidence.contains("row 1");
    let mentions_class = evidence.contains("active class 1");
    assert!(
        mentions_scale && mentions_row && mentions_class,
        "unresolved-probe evidence should name the saturated channel, got {evidence}"
    );

    // Logits of ±24.9 stay silent even under the unresolved-probe gate.
    let near = vec![block([1.0, 2.0], [0.2, 24.9, -24.9])];
    let near_evidence = multinomial_formula_unresolved_probe_separation_evidence(&near);
    assert!(
        near_evidence.is_none(),
        "finite logits below the separation threshold still get the full unbiased retry"
    );
}
3131
#[test]
fn scaled_fisher_override_changes_first_step() {
    // Supplying a curvature 4× the analytic one shrinks the first Newton
    // step, so two single-iteration fits (override vs analytic) must land on
    // different coefficients.
    let (design, y, penalty, lambdas) = toy();
    let n = design.nrows();
    let m = y.ncols() - 1;

    // Analytic Fisher block for uniform class probabilities p_a = 1/K
    // (1/3 for the toy fixture, as at β = 0 per the original note):
    // diagonal p_a(1−p_a), off-diagonal −p_a·p_b, every entry scaled by 4.
    let pk = 1.0 / (y.ncols() as f64);
    let over = Array3::<f64>::from_shape_fn((n, m, m), |(_, a, b)| {
        let analytic = if a == b { pk * (1.0 - pk) } else { -pk * pk };
        4.0 * analytic
    });

    let scaled_inputs = MultinomialFitInputs {
        design: design.view(),
        y_one_hot: y.view(),
        penalty: penalty.view(),
        lambdas: lambdas.view(),
        row_weights: None,
        fisher_w_override: Some(over.view()),
        max_iter: 1,
        tol: 1.0e-9,
    };
    let scaled = fit_penalized_multinomial(scaled_inputs).expect("override fit must succeed");

    let analytic_inputs = MultinomialFitInputs {
        design: design.view(),
        y_one_hot: y.view(),
        penalty: penalty.view(),
        lambdas: lambdas.view(),
        row_weights: None,
        fisher_w_override: None,
        max_iter: 1,
        tol: 1.0e-9,
    };
    let analytic =
        fit_penalized_multinomial(analytic_inputs).expect("analytic fit must succeed");

    // True as soon as one coefficient pair differs by more than 1e-6.
    let mut differs = false;
    for (with_override, without) in scaled
        .coefficients_active
        .iter()
        .zip(analytic.coefficients_active.iter())
    {
        if (with_override - without).abs() > 1.0e-6 {
            differs = true;
            break;
        }
    }
    assert!(differs, "scaled curvature must change the first step");
}
3180}
3181
#[cfg(test)]
mod reference_class_invariance_tests {
    //! Regression for #1587: a penalized multinomial-logit GAM fit must be
    //! invariant to which class is the (arbitrary) softmax reference/baseline.
    //!
    //! The production REML path (`fit_penalized_multinomial_formula`) reference-
    //! codes the `K` classes (the last sorted label is the baseline). With the
    //! legacy `Diagonal` penalty metric it penalized only the `K−1`
    //! reference-anchored ALR contrasts `½ Σ_a λ_a β_aᵀ S β_a`, so relabeling
    //! the response to make a *different* class sort last penalized a different
    //! frame of log-odds contrasts and the predicted probabilities drifted
    //! (~1e-2 absolute), even though they are mathematically independent of the
    //! reference choice. That drift is the defect this module guards against.
    //!
    //! The test fits the SAME 3-class softmax sample under three cyclic
    //! relabelings (each making a different original class the baseline),
    //! realigns the predicted probability columns back to the original class
    //! identities, and asserts the cross-labeling drift is below `1e-3`. The
    //! defect was ~1e-2; the issue reports that refitting one labeling twice
    //! agrees to ~1e-12, and the test enforces `< 1e-6` for that. It is the
    //! Rust-level sibling of
    //! `tests/bug_hunt_multinomial_fit_depends_on_reference_class_test.py`.
    //!
    //! All comparisons propagate NaN, so a fit that produces non-finite
    //! probabilities fails the guard instead of passing it silently.

    use super::*;
    use gam_data::load_dataset_projected;
    use std::fmt::Write as _;
    use std::fs;
    use tempfile::tempdir;

    /// Deterministic `splitmix64` → `[0,1)` uniform stream (no external RNG dep;
    /// the only requirement is a well-distributed, reproducible draw).
    struct SplitMix64(u64);
    impl SplitMix64 {
        /// Advances the state by the golden-ratio increment and returns the
        /// mixed 64-bit output.
        fn next_u64(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = self.0;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }

        /// Next draw as an `f64` in `[0, 1)`, built from the top 53 bits so
        /// every value is exactly representable.
        fn unit(&mut self) -> f64 {
            (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
        }
    }

    /// Draw a clean 3-class softmax regression sample (the issue's generator).
    ///
    /// `x` is uniform on `[-2, 2)`; the class logits are
    /// `[0.5 + 0.8x, −0.3 − 0.5x, 0]`. Returns `(x, class)` with integer
    /// classes `0/1/2`, each vector of length `n`.
    fn sample_classes(seed: u64, n: usize) -> (Vec<f64>, Vec<usize>) {
        let mut rng = SplitMix64(seed.wrapping_add(0x1234_5678));
        let mut x = Vec::with_capacity(n);
        let mut cls = Vec::with_capacity(n);
        for _ in 0..n {
            let xi = -2.0 + 4.0 * rng.unit();
            let eta = [0.5 + 0.8 * xi, -0.3 - 0.5 * xi, 0.0];
            let mut p = [eta[0].exp(), eta[1].exp(), eta[2].exp()];
            let s: f64 = p.iter().sum();
            for v in &mut p {
                *v /= s;
            }
            // Inverse-CDF draw into one of the 3 classes.
            let u = rng.unit();
            let c = if u < p[0] {
                0
            } else if u < p[0] + p[1] {
                1
            } else {
                2
            };
            x.push(xi);
            cls.push(c);
        }
        (x, cls)
    }

    /// Build an `EncodedDataset` with columns `x` (numeric) and `y`
    /// (categorical, from the given string labels) by round-tripping a CSV
    /// named `data_{tag}.csv` inside `dir`.
    fn dataset_xy(
        dir: &std::path::Path,
        tag: &str,
        x: &[f64],
        y: &[String],
    ) -> gam_data::EncodedDataset {
        let path = dir.join(format!("data_{tag}.csv"));
        let mut csv = String::from("x,y\n");
        for (xi, yi) in x.iter().zip(y.iter()) {
            writeln!(csv, "{xi},{yi}").expect("writing to a String cannot fail");
        }
        fs::write(&path, csv).expect("write training csv");
        load_dataset_projected(&path, &["x".to_string(), "y".to_string()])
            .expect("load training dataset")
    }

    /// Fit `y ~ s(x)` under the relabeling `name_map` (original class `c` gets
    /// label `name_map[c]`), predict on `grid`, and return the predicted
    /// probabilities **realigned to the original class order** 0/1/2, shape
    /// `(grid.len(), 3)`.
    fn fit_predict_aligned(
        dir: &std::path::Path,
        tag: &str,
        x: &[f64],
        cls: &[usize],
        name_map: [&str; 3],
        grid: &[f64],
    ) -> Array2<f64> {
        let labels: Vec<String> = cls.iter().map(|&c| name_map[c].to_string()).collect();
        let train = dataset_xy(dir, tag, x, &labels);
        let config = FitConfig::default();
        let model = fit_penalized_multinomial_formula(&train, "y ~ s(x)", &config, 1.0, 60, 1e-6)
            .expect("multinomial formula fit must succeed");

        // Predict on the grid. The categorical `y` column is not needed for
        // prediction, but the schema is simplest if we supply a dummy.
        let grid_y = vec![name_map[0].to_string(); grid.len()];
        let grid_ds = dataset_xy(dir, &format!("{tag}_grid"), grid, &grid_y);
        let probs = predict_multinomial_formula(&model, &grid_ds)
            .expect("multinomial predict must succeed");

        // `model.class_levels` is the sorted label order; the column for original
        // class `c` is at the rank of `name_map[c]` among the sorted labels.
        let mut sorted: Vec<&str> = name_map.to_vec();
        sorted.sort_unstable();
        let col_of_orig: Vec<usize> = (0..3)
            .map(|c| {
                sorted
                    .iter()
                    .position(|l| *l == name_map[c])
                    .expect("every label of name_map appears in its own sorted copy")
            })
            .collect();
        // Sanity: the model's class_levels must match the sorted labels.
        assert_eq!(
            model.class_levels,
            sorted.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
            "class_levels must be the sorted label order"
        );
        let n = grid.len();
        let mut aligned = Array2::<f64>::zeros((n, 3));
        for r in 0..n {
            for c in 0..3 {
                aligned[[r, c]] = probs[[r, col_of_orig[c]]];
            }
        }
        aligned
    }

    /// Maximum of two values that yields NaN when either operand is NaN.
    ///
    /// `f64::max` returns the non-NaN operand instead, which would let a NaN
    /// difference vanish from a running maximum.
    fn nan_propagating_max(lhs: f64, rhs: f64) -> f64 {
        if lhs.is_nan() || rhs.is_nan() {
            f64::NAN
        } else {
            lhs.max(rhs)
        }
    }

    /// Largest absolute element-wise difference between two equally shaped
    /// probability tables.
    ///
    /// Returns NaN if any difference is NaN (a NaN entry, or `inf − inf`), so
    /// a following `diff < tol` assertion fails rather than passing on a
    /// bogus zero.
    ///
    /// # Panics
    /// Panics when the shapes differ, since `zip` would otherwise truncate
    /// silently to the shorter table.
    fn max_abs_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        assert_eq!(
            a.dim(),
            b.dim(),
            "probability tables must share a shape before comparison"
        );
        a.iter()
            .zip(b.iter())
            .map(|(p, q)| (p - q).abs())
            .fold(0.0_f64, nan_propagating_max)
    }

    // gam#1587: now that the reference-symmetric centered `M⊗S_t` joint penalty
    // is wired through the custom-family outer REML loop (per-eval
    // `JointPenaltyBundle` + outer penalty_coords/logdet/operator), the
    // production multinomial fit is invariant to the arbitrary reference class,
    // so this guard runs by default (the opt-in skip attribute it carried while
    // the fix was pending is also forbidden by the build.rs ban-scanner). It is
    // an end-to-end fit guard (a handful of full softmax `y ~ s(x)` fits),
    // slower than a unit test but a true production-path regression.
    #[test]
    fn multinomial_fit_is_invariant_to_reference_class_1587() {
        let td = tempdir().expect("tempdir");
        let dir = td.path();
        // The reference-class drift is STRUCTURAL (it does not shrink with n, see
        // the issue table), so a modest n exposes it just as cleanly as n=900
        // while keeping this an affordable CI guard.
        let (x, cls) = sample_classes(0, 300);
        let grid: Vec<f64> = (0..7).map(|i| -1.5 + 3.0 * (i as f64) / 6.0).collect();

        // Three labelings that each make a DIFFERENT original class the baseline
        // (the class whose label sorts LAST is the reference K−1):
        //   ["A","B","C"] → ref = class 2
        //   ["B","C","A"] → ref = class 1
        //   ["C","A","B"] → ref = class 0
        let a = fit_predict_aligned(dir, "abc", &x, &cls, ["A", "B", "C"], &grid);
        let b = fit_predict_aligned(dir, "bca", &x, &cls, ["B", "C", "A"], &grid);
        let c = fit_predict_aligned(dir, "cab", &x, &cls, ["C", "A", "B"], &grid);

        // Refitting the SAME labeling twice must agree far below the drift
        // threshold; this separates optimizer noise from the structural
        // reference drift.
        let a2 = fit_predict_aligned(dir, "abc2", &x, &cls, ["A", "B", "C"], &grid);
        let refit_noise = max_abs_diff(&a, &a2);
        assert!(
            refit_noise < 1e-6,
            "refitting the same labeling must be deterministic (got {refit_noise:.3e})"
        );

        // Worst pairwise drift. A NaN from any pair survives the combination
        // and fails the `<` comparison below.
        let drift = [
            max_abs_diff(&a, &b),
            max_abs_diff(&a, &c),
            max_abs_diff(&b, &c),
        ]
        .into_iter()
        .fold(0.0_f64, nan_propagating_max);
        assert!(
            drift < 1e-3,
            "predicted probabilities must be invariant to the reference class; \
             cross-labeling drift = {drift:.3e} (refit noise = {refit_noise:.3e})"
        );
    }
}