Skip to main content

gam_models/
multinomial.rs

1//! Penalized multinomial-logit (softmax) GLM driver — fixed-λ inner solve.
2//!
3//! This is the principled vector-response companion to the scalar PIRLS path:
4//! the inner-loop Newton solver for a multi-class GAM at fixed smoothing
5//! parameters λ, using the canonical multinomial-logit likelihood
6//! ([`MultinomialLogitLikelihood`]) and the existing dense block-Fisher
7//! assembly in [`gam_solve::pirls::dense_block_xtwx`] /
8//! [`gam_solve::pirls::dense_block_xtwy`].
9//!
10//! # What this module does
11//!
12//! Solve, for the reference-coded multinomial-logit GAM with `K` classes and
13//! design matrix `X ∈ ℝ^{N×P}`,
14//!
15//! ```text
16//!     β̂ = argmin_β { − log L(β) + ½ Σ_{a=0}^{K-2} λ_a · β_a^T S β_a }
17//! ```
18//!
19//! where `β = [β_0; β_1; …; β_{K-2}]` is the stacked coefficient vector in
20//! output-major order (`β_a ∈ ℝ^P` is the coefficient block for class `a`),
21//! `S ∈ ℝ^{P×P}` is the smoothing penalty matrix (shared across classes,
22//! replicated as `I_{K-1} ⊗ S` over the full parameter space), and `λ_a` is
23//! a per-class smoothing parameter.
24//!
25//! The likelihood uses class `K - 1` as the reference (`η_{K-1} ≡ 0`), so the
26//! softmax gauge is fixed at the η level and no additional sum-to-zero
27//! projection is required.
28//!
29//! # Layering
30//!
31//! * **Fixed-λ inner solve** — [`fit_penalized_multinomial`] is the canonical
32//!   coefficient-space Newton solver at *given* smoothing parameters `λ`,
33//!   built on the shared [`crate::penalized_vector_glm`] engine.
34//!
35//! * **REML / LAML smoothing-parameter selection** — [`fit_penalized_multinomial_formula`]
36//!   routes through [`crate::custom_family::fit_custom_family_with_rho_prior`]
37//!   so the per-active-class `λ_a` are selected by the outer REML/LAML loop;
38//!   the caller's `init_lambda` is only a warm-start seed. The multinomial
39//!   [`crate::multinomial_reml::MultinomialFamily`] `CustomFamily`
40//!   impl calls the fixed-λ math above as its inner solve at each ρ trial and
41//!   supplies the dense per-row Hessian block for the outer trace terms.
42//!
43//! * **Formula → design integration** — `build_formula_design_for_multinomial`
44//!   parses the Wilkinson formula and assembles `X` and the per-term `S`
45//!   blocks; the `fit_multinomial_formula_pyfunc` FFI shim wires the Python
46//!   `gamfit.fit(..., family='multinomial')` entry straight to this path.
47//!
48//! # Convergence
49//!
50//! The damped-Newton-with-backtracking scaffold lives once in the shared
51//! [`crate::penalized_vector_glm`] engine: at each iteration the
//! assembled penalized Hessian `H + diag(λ) ⊗ S` is factored via faer's
53//! symmetric-PD-with-fallback path, the full Newton step `δ = −H^{-1} ∇F` is
54//! computed, and accepted with step halving if the objective fails to decrease
55//! (up to a small backtracking budget). The convergence test is the relative
56//! coefficient step norm `‖δ‖ / (1 + ‖β‖) ≤ tol`, matching the existing pyffi
57//! reference path. This module is the softmax adapter over that engine: it
58//! supplies the dense `(K-1)×(K-1)` Fisher block, the residual, and the
59//! log-likelihood through [`MultinomialLogitLikelihood`], and owns the
60//! class-count / simplex preconditions. The independent-binomial sibling
61//! [`crate::binomial_multi`] is the same engine with a row-diagonal
62//! Fisher block instead.
63
64use crate::custom_family::{
65    BlockwiseFitOptions, ParameterBlockState, PenaltyMatrix, fit_custom_family_with_rho_prior,
66};
67use crate::multinomial_reml::MultinomialFamily;
68use crate::penalized_vector_glm::{PenalizedVectorGlmInputs, fit_penalized_vector_glm};
69use crate::vector_response::{MultinomialLogitLikelihood, validate_multinomial_simplex};
70use gam_terms::inference::formula_dsl::parse_formula;
71use crate::model_types::EstimationError;
72use crate::fit_orchestration::{
73    FitConfig, build_termspec_with_geometry_and_overrides, resolved_resource_policy,
74};
75use gam_terms::smooth::{
76    PenaltyBlockInfo, TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
77};
78use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
79use gam_terms::term_builder::resolve_role_col;
80use gam_problem::ResponseColumnKind;
81use gam_data::ColumnKindTag;
82use gam_data::EncodedDataset;
83use gam_runtime::resource::ProblemHints;
84use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3};
85use serde::{Deserialize, Serialize};
86use std::sync::Arc;
87
/// Numerical-stabilization floor (`δ = 1e-4`) used only inside the linear
/// solve of the formula-driven multinomial REML inner Newton step (gam#747).
///
/// It is installed through
/// [`RidgePolicy::solver_only`](gam_problem::RidgePolicy::solver_only), so it
/// conditions the joint-Newton system but never appears in the REML
/// objective, the penalty log-determinant, or the Laplace Hessian.
///
/// # Why a floor is needed
///
/// Multinomial smoothing penalties are rank-deficient on purpose: every smooth
/// keeps an unpenalized polynomial null space, and a formula may also carry a
/// fully unpenalized parametric term (`x3` / `body_mass`). When hard labels are
/// nearly separable, the softmax curvature along those directions is
/// ill-conditioned and the bare Newton step `H⁻¹∇` becomes huge. Raising the
/// smallest Hessian eigenvalue to `δ` caps the step at
/// `‖(H+δI)⁻¹∇‖ ≤ ‖∇‖/δ`. That keeps screening iterates finite and prevents
/// the softmax from producing `inf − inf = NaN`.
///
/// # What it deliberately is not
///
/// No `½·δ·‖β‖²` term is added to the objective, and no `δ`-shift reaches the
/// REML log-determinant. The earlier `explicit_stabilization_pospart` policy
/// folded both into the criterion. That turned `1e-4` into a fixed-λ Gaussian
/// prior which shrank every identified coefficient off the MLE and biased
/// smoothing-parameter selection. Its value then had to be tuned between
/// under-stabilization (NaN seeds) and over-shrinkage (lost VGAM match).
///
/// As a solver-only floor that tradeoff disappears: nothing is shrunk, the
/// optimized objective is the true penalized REML criterion, and `δ` only has
/// to be large enough to keep the linear algebra finite.
///
/// # Separation is handled elsewhere
///
/// The separation defect (#753) is not this floor's job. When the multinomial
/// MLE of an unpenalized / null-space direction lies at infinity (complete or
/// quasi-complete separation), no solver floor can make that estimate finite.
///
/// The formula REML path instead arms the full-span Jeffreys/Firth correction
/// CONDITIONALLY, on separation evidence only (see
/// [`multinomial_formula_separation_evidence`] and the two-attempt logic in
/// [`fit_penalized_multinomial_formula`]). An interior, well-identified fit
/// therefore optimizes the unbiased penalized-REML criterion with no Firth
/// shrinkage toward the uniform simplex. A (quasi-)separated geometry gets the
/// proper prior, which is the only thing able to bound its penalty-null
/// directions (#715 real-data arm).
///
/// The bare fixed-λ inner driver [`fit_penalized_multinomial`] has no outer
/// REML and no Jeffreys term. It reports the explicit
/// `MultinomialSeparationDetected` diagnostic instead.
const MULTINOMIAL_FORMULA_RIDGE_FLOOR: f64 = 1e-4;

/// KKT tolerance for the inner joint-Newton solve on the multinomial formula
/// path: `1e-5` instead of the default `inner_tol = 1e-6`.
///
/// On saturated rows the softmax Fisher weight `W = diag(p) − ppᵀ` collapses.
/// Near-separable fits (penguins, #715) therefore reach the objective's f64
/// noise floor before they reach a `1e-6` KKT residual.
///
/// Measured on the penguins arm with standardized columns:
/// * the trust region shrinks to 1e-12;
/// * successive attempts move the objective by ~+2e-9 at |obj| ≈ 1e2
///   (≈ 1e-11 relative, i.e. pure rounding);
/// * the KKT residual stalls at 2.8e-5–9.4e-5, against a scaled tolerance of
///   ~1.9e-5.
///
/// A target below the floating-point noise floor can never be certified: the
/// stall guard rejects every evaluation and the whole fit fails. `1e-5`
/// certifies the measured plateaus and still resolves β to ~1e-6 in the
/// relevant metric. The LAML criterion consumes β̂ with error
/// O(residual²/curvature), far below anything the outer ρ-search can read.
const MULTINOMIAL_FORMULA_INNER_TOL: f64 = 1e-5;
146
/// Scale applied to term-builder penalties for multinomial softmax REML, as a
/// function of the class count `K`.
///
/// The term builder normalizes penalties against single-response,
/// Gaussian-style score curvature. A reference-coded softmax class block
/// instead sees a per-row active-class Fisher diagonal `p_a(1-p_a)`, plus
/// negative cross-class coupling. At the neutral simplex (`p_k = 1/K`) that
/// diagonal equals `(K-1)/K²`, giving the calibration `2·(K-1)/K²`:
/// * `1/2` for the binary logit (`K = 2`);
/// * `4/9` for three classes, rather than the historical hard-coded `1/2`.
///
/// Tying the scale to `K` keeps the physical smoothness prior matched to the
/// likelihood curvature. A fixed scale would over-penalize every class as the
/// simplex gains categories.
///
/// `n_classes < 2` is degenerate and is calibrated as if it were binary.
fn multinomial_formula_penalty_scale(n_classes: usize) -> f64 {
    let classes = if n_classes < 2 { 2.0 } else { n_classes as f64 };
    let numerator = 2.0 * (classes - 1.0);
    numerator / (classes * classes)
}
162
/// Largest smoothing-parameter dimension `D` for which multinomial formula
/// fits pay for the exact dense outer Hessian.
///
/// `D = (K - 1) * n_penalties`. Fits with `D <= 16` use exact curvature; larger
/// `D` takes the gradient-only route.
///
/// # History of the bound
///
/// 1. **`D <= 6`.** Exact curvature keeps the optimizer from wandering into
///    over-smoothed lambda caps on near-boundary softmax surfaces. The bound
///    was first calibrated at `D <= 6`, when each `s()` term carried ONE
///    penalty.
/// 2. **`D <= 12`.** The double-penalty migration (wiggliness + null-space
///    shrinkage per term, mgcv `select=TRUE` semantics) doubled `D` for the
///    SAME models. The reference formula fits (2 smooths, K = 3) went from
///    `D = 4` to `D = 8` and silently landed on the gradient-only route. There
///    the #715 quality arm showed every wiggliness ρ driven onto the ±10 box
///    bound: smooths collapsed toward their polynomial null space and
///    truth-RMSE fell behind VGAM. `12 = 2 × 6` restores the original
///    classification boundary under the doubled penalty count.
/// 3. **`D <= 16` (current).** The bound is raised beyond 12 so that the
///    four-smooth penguin species quality fixture (`D = 16`) also stays on the
///    exact ARC path. On that model first-order BFGS can cycle along the
///    near-separable lambda-to-zero ridge until the wall-clock budget expires
///    (#1082). ARC observes the exact curvature and can halt through the
///    bound-aware cost-stall guard once the REML surface stops making useful
///    progress.
const MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM: usize = 16;

/// Returns `true` when a multinomial formula fit with `total_rho_dim`
/// smoothing parameters should use the exact dense outer Hessian, i.e. when
/// `total_rho_dim <= MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM`.
fn multinomial_formula_use_outer_hessian(total_rho_dim: usize) -> bool {
    total_rho_dim <= MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM
}
188
/// Logit magnitude `|η|` at which fitted softmax probabilities are treated as
/// saturated for separation diagnostics (`e^{-25} ≈ 1.4e-11`).
///
/// The threshold is consulted only for NON-converged iterates:
///
/// * **Bare fixed-λ driver.** [`fit_penalized_multinomial`] has no outer REML
///   state. It rejects a non-converged solve whose largest `|η|` reaches this
///   value as `MultinomialSeparationDetected`.
/// * **Formula REML path.**
///   [`multinomial_formula_unresolved_probe_separation_evidence`] is intended
///   for a non-converged capped unbiased probe. It treats a logit at or above
///   this magnitude as separation evidence and routes the fit to the
///   Firth/Jeffreys refit.
///
/// A CONVERGED finite saturated formula fit is not rejected or Firth-refit on
/// this threshold alone. Once smoothing parameters are selected, a finite
/// saturated surface can be the valid near-separated optimum, and it is scored
/// directly.
const MULTINOMIAL_SEPARATION_ETA_THRESHOLD: f64 = 25.0;
197
/// Convergence tolerance for the OUTER REML/LAML smoothing-parameter search on
/// the formula multinomial path.
///
/// It matches the primary GLM REML outer loop:
/// `solver::fit_orchestration::materialize` runs with `tol = 1e-7`, and the
/// `LOG_LAMBDA_TOL` / `KKT_TOL_*` constants across the REML stack mirror it.
///
/// * **Tight enough:** the selected λ reaches the genuine REML optimum, so the
///   recovered probability surface matches the mature reference.
/// * **Loose enough:** the optimizer does not grind surface-irrelevant ρ
///   digits down to the inner-KKT scale (the #1082 wall-clock overrun).
///
/// The caller's `tol` is floored at this value for the OUTER loop only. The
/// INNER joint-Newton KKT target still uses the caller's `tol` unchanged.
const MULTINOMIAL_OUTER_REML_TOL: f64 = 1.0e-7;

/// Outer-iteration cap for the first multinomial formula solve, which doubles
/// as a separation probe.
///
/// The probe is accepted when the unbiased REML criterion converges to a
/// finite interior iterate. On near-separable data such as the penguin fixture
/// it would otherwise spend the caller's full outer budget on an iterate that
/// is discarded before the Firth/Jeffreys refit.
///
/// The cap leaves enough iterations for ordinary interior fits to certify
/// quickly, while handing slow or non-interior probes to the proper-prior
/// refit promptly.
const MULTINOMIAL_UNBIASED_PROBE_OUTER_MAX_ITER: usize = 20;
217
/// Per-observation softmax Fisher information `I₁ = ¼`: the unit in which the
/// λ-floor is expressed.
///
/// The penalty enters the criterion as `½ λ βᵀ S β` with a Frobenius-normalized
/// `S` (`‖S‖_F = 1`; see the term-builder calibration referenced by
/// [`multinomial_formula_penalty_scale`]). The ridge `λ S` is therefore directly
/// comparable to data Fisher information.
///
/// A single observation contributes `p(1−p)` of information in a class's logit
/// direction, and that quantity peaks at `¼` when `p = ½`. Measuring the floor
/// in these units lets its strength be read as a count of **pseudo-observations**
/// of prior: a ridge of `τ · ¼ · ‖S‖_F` carries the same logit-direction
/// curvature as `τ` real rows sitting at the most-informative point of the
/// likelihood.
///
/// This unit is intentionally independent of `K`. The `K`-dependence of the
/// softmax block curvature already lives in the penalty matrix through
/// [`multinomial_formula_penalty_scale`]. The floor bounds the multiplier of
/// that already-scaled penalty, so it must not count the `K`-dependence twice.
const MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS: f64 = 1.0 / 4.0;

/// Prior strength `τ` of the λ-floor for a WELL-SUPPORTED class, measured in
/// pseudo-observations.
///
/// The floor keeps the unbiased REML optimizer away from the zero-penalty
/// boundary, where a boundary-overfit smooth, or a Firth switch on finite data,
/// would otherwise be accepted. It does so with a prior worth a small, fixed
/// fraction of one observation.
///
/// `8e-4` pseudo-observations reproduces, at the calibration point, exactly the
/// large-support floor that used to be fixture-calibrated (`τ · ¼ = 2e-4`). It
/// is now stated as an effective prior strength instead of a tuned λ value.
const MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS: f64 = 8e-4;

/// Reference class support `n_ref`: the per-class effective sample size at
/// which the data information `n_c · I₁` is large enough for the floor to sit
/// at its well-supported value.
///
/// Below `n_ref` the per-class data information shrinks like `n_c`. To stop the
/// prior from vanishing *relative to* that shrinking data, the effective
/// pseudo-observation count is multiplied by `n_ref / n_c`. The prior is thus a
/// fixed fraction of the data information rather than a fixed absolute λ.
///
/// At `n_c = n_ref` the multiplier is exactly 1.
const MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT: f64 = 50.0;

/// Ceiling `τ_max` on the floor's prior strength in the very-sparse limit, in
/// pseudo-observations.
///
/// The `n_ref / n_c` multiplier diverges as `n_c → 0`. This cap holds the prior
/// at `4e-3` pseudo-observations, which is `τ_max · ¼ = 1e-3` at the
/// calibration point (the previously tuned strong-floor value).
///
/// The cap keeps the floor a proper prior. Without it, the floor could become
/// a hard constraint that dominates the likelihood of a class with only a
/// handful of rows.
const MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX: f64 = 4e-3;
262
263/// Continuous, Fisher-information-scaled lower λ floor for the formula path,
264/// derived from the minority class's effective sample size `n_c`.
265///
266/// # Derivation (effective-prior-strength / Fisher geometry)
267///
268/// The penalty `½ λ βᵀ S β` with `‖S‖_F = 1` adds curvature `λ` to the class
269/// logit direction; one observation adds at most `I₁ = ¼` there. So a floor that
270/// sets `λ_floor = τ_eff · I₁` gives the smooth a prior worth `τ_eff`
271/// pseudo-observations. We want a fixed *absolute* prior `τ` for a well-supported
272/// class, but for a minority class with only `n_c` effective observations the
273/// data information in its block is `n_c · I₁`; holding the prior to a fixed
274/// *fraction* of that shrinking data information requires
275///
276/// ```text
277///     τ_eff(n_c) = τ · max(1, n_ref / n_c),   clamped to [τ, τ_max]
278///     λ_floor(n_c) = τ_eff(n_c) · I₁
279/// ```
280///
281/// This is the *same* `base · max(1, c0/c)` envelope as before — but `base`,
282/// `sparse`, and `c0` are no longer fixture-tuned magic numbers: `base = τ·I₁`,
283/// `sparse = τ_max·I₁`, and `c0 = n_ref` are an effective-prior-strength of
284/// `τ`/`τ_max` pseudo-observations against the maximal per-observation softmax
285/// information `I₁ = ¼`. Properties preserved by construction:
286///   * reduces EXACTLY to `τ·I₁` for well-supported classes (`n_c ≥ n_ref`);
287///   * reduces EXACTLY to `τ_max·I₁` for very sparse classes
288///     (`n_c ≤ n_ref·τ/τ_max`, here `n_c ≤ 10`);
289///   * interpolates monotonically and continuously between them in the middle —
290///     no cliff at `n_c = n_ref`.
291/// At the calibration point the endpoints equal the previous `2e-4` / `1e-3`, so
292/// fixtures whose smallest class has `n_c ≥ 50` (penguins, the vgam softmax
293/// arms) are unaffected — they sit at `τ·I₁ = 2e-4` exactly as before.
294fn multinomial_formula_min_lambda(y_one_hot: ArrayView2<'_, f64>) -> f64 {
295    let base = MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
296    let sparse =
297        MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
298    let min_class_count = (0..y_one_hot.ncols())
299        .map(|class| y_one_hot.column(class).sum())
300        .fold(f64::INFINITY, f64::min);
301    if !min_class_count.is_finite() || min_class_count <= 0.0 {
302        return base;
303    }
304    // Effective pseudo-observation prior strength: held to a fixed fraction of
305    // the shrinking per-class data information once n_c falls below n_ref.
306    let pseudo_obs_scale =
307        (MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / min_class_count).max(1.0);
308    (base * pseudo_obs_scale).clamp(base, sparse)
309}
310
311fn max_abs_eta_location(eta: ArrayView2<'_, f64>) -> (f64, usize, usize) {
312    let mut best = (0.0_f64, 0usize, 0usize);
313    for ((row, active_class), &value) in eta.indexed_iter() {
314        let abs = value.abs();
315        if abs > best.0 {
316            best = (abs, row, active_class);
317        }
318    }
319    best
320}
321
322/// Separation gate for the REML/LAML **formula** path.
323///
324/// Unlike the bare fixed-λ driver [`fit_penalized_multinomial`] (which has no
325/// outer REML state and so must reject a saturated, non-converged iterate as a
326/// separation artifact at the [`MULTINOMIAL_SEPARATION_ETA_THRESHOLD`] logit
327/// magnitude), the formula path can return a finite saturated mode after the
328/// coupled outer optimizer has selected smoothing parameters. A `|η| >= 25`
329/// gate is therefore wrong here: the penguins arm can legitimately have large
330/// fitted logits while still producing finite probabilities and a usable REML
331/// mode.
332///
333/// Only a genuinely NON-FINITE `η` (a NaN/Inf blow-up in the inner linear
334/// algebra) is a real formula-path failure. A finite, even saturated, `η` is
335/// accepted so the truth-recovery / match-or-beat bars are evaluated against the
336/// actual fitted surface instead of an adapter diagnostic.
337fn multinomial_formula_separation_diagnostic(
338    inner_cycles: usize,
339    outer_iterations: usize,
340    block_states: &[ParameterBlockState],
341) -> Option<EstimationError> {
342    let mut nonfinite: Option<(f64, usize, usize)> = None;
343    for (active_class, state) in block_states.iter().enumerate() {
344        for (row, &value) in state.eta.iter().enumerate() {
345            if !value.is_finite() {
346                nonfinite = Some((value, row, active_class));
347                break;
348            }
349        }
350        if nonfinite.is_some() {
351            break;
352        }
353    }
354    nonfinite.map(|(value, row_index, active_class_index)| {
355        EstimationError::MultinomialSeparationDetected {
356            iteration: inner_cycles.max(outer_iterations),
357            max_abs_eta: value.abs(),
358            active_class_index,
359            row_index,
360        }
361    })
362}
363
364/// Separation EVIDENCE gate for the conditional Firth/Jeffreys engagement on
365/// the formula REML path (#715 / #753).
366///
367/// The structural mathematics (#715 issue thread): for any coefficient
368/// direction `v` with `S v = 0` (a penalty-null direction — intercept, a
369/// smooth's polynomial null component, an unpenalized parametric term), the
370/// penalized joint Hessian satisfies `(H + S_λ) v = H v` for EVERY smoothing
371/// parameter ρ. When the data (quasi-)separate, the softmax Fisher weight
372/// `W = diag(p) − p pᵀ → 0` on the saturated rows, so `H v = JᵀWJ v → 0` along
373/// the penalty-null directions those rows support: `(H + S_λ) v ≈ 0` for every
374/// ρ — NO λ can repair it, the inner Newton can never certify a KKT point
375/// there, and every outer REML startup seed is rejected (the penguins
376/// real-data arm). The only principled cure is a PROPER prior on that
377/// quotient-null subspace — the Jeffreys/Firth term `Φ = ½ log|ZᵀHZ|`, whose
378/// Gauss–Newton curvature supplies the missing `O(1)` bound.
379///
380/// But the Firth prior is not free on interior data: unconditionally armed, it
381/// shrinks fitted class probabilities toward the uniform simplex `1/K`
382/// (an `O(1/n)` pull that the synthetic match-or-beat arm of #715 measured as
383/// a real truth-RMSE loss vs the unbiased criterion). So the formula path
384/// engages it ONLY on separation evidence, mirroring the #753 "diagnose, then
385/// arm" split:
386///
387/// * a NON-FINITE logit — the inner linear algebra blew up along an unbounded
388///   direction.
389///
390/// Returns `Some(description)` naming the witnessing logit when evidence is
391/// found, `None` for a finite fit (which is then accepted as-is, with zero
392/// Firth bias). A FAILED unbiased solve (`Err` from the rho-prior driver, e.g.
393/// "no startup seed passed") is the second evidence form and is handled
394/// directly at the call site in [`fit_penalized_multinomial_formula`].
395fn multinomial_formula_separation_evidence(block_states: &[ParameterBlockState]) -> Option<String> {
396    for (active_class, state) in block_states.iter().enumerate() {
397        for (row, &value) in state.eta.iter().enumerate() {
398            if !value.is_finite() {
399                return Some(format!(
400                    "non-finite logit eta[row {row}, active class {active_class}] = {value}"
401                ));
402            }
403        }
404    }
405    None
406}
407
408/// Extra evidence used only for a NON-CONVERGED capped unbiased probe.
409///
410/// A converged finite saturated formula fit is still a valid optimum and must be
411/// scored without Firth bias. A capped probe that failed to converge while it
412/// already carries separation-scale logits is different: spending the full
413/// unbiased outer budget on the same lambda-to-zero surface is the #1082
414/// timeout. Route that case straight to the proper-prior refit.
415fn multinomial_formula_unresolved_probe_separation_evidence(
416    block_states: &[ParameterBlockState],
417) -> Option<String> {
418    if let Some(evidence) = multinomial_formula_separation_evidence(block_states) {
419        return Some(evidence);
420    }
421
422    let mut best = (0.0_f64, 0usize, 0usize);
423    for (active_class, state) in block_states.iter().enumerate() {
424        for (row, &value) in state.eta.iter().enumerate() {
425            let abs = value.abs();
426            if abs > best.0 {
427                best = (abs, row, active_class);
428            }
429        }
430    }
431    if best.0 >= MULTINOMIAL_SEPARATION_ETA_THRESHOLD {
432        Some(format!(
433            "separation-scale finite logit |eta[row {}, active class {}]| = {:.3e} \
434             after capped unbiased probe",
435            best.1, best.2, best.0
436        ))
437    } else {
438        None
439    }
440}
441
/// Inputs to [`fit_penalized_multinomial`].
///
/// Dimensions: `N` observations, `P` design columns, `K` classes. Classes
/// `0..K-1` are active; class `K - 1` is the reference (`η_{K-1} ≡ 0`).
///
/// The penalty matrix `S` is shared across classes. The per-class smoothing
/// parameters `lambdas` (length `K - 1`) scale `S` independently for each
/// active class, so the full block-replicated penalty is `diag_a(λ_a) ⊗ S`.
/// That is exactly what [`gam_solve::arrow_schur::KroneckerPenaltyOp`]
/// expresses in matrix-free form when this driver is later lifted into the
/// arrow-Schur loop.
#[derive(Debug, Clone)]
pub struct MultinomialFitInputs<'a> {
    /// Design matrix `X ∈ ℝ^{N×P}` (one row per observation). Its row count
    /// must equal the row count of `y_one_hot`.
    pub design: ArrayView2<'a, f64>,
    /// Categorical response `Y ∈ ℝ^{N×K}`. Each row must be a point on the
    /// probability simplex (`y_c ≥ 0`, `Σ_c y_c = 1`): either a one-hot
    /// indicator for hard classification or a label-smoothed probability
    /// vector. Rows whose mass departs from 1 are rejected. The softmax
    /// residual gradient and Fisher block are the derivatives of
    /// `Σ_c y_c log p_c` only under the simplex constraint (see
    /// `validate_multinomial_simplex`). `K ≥ 2` is required.
    pub y_one_hot: ArrayView2<'a, f64>,
    /// Shared smoothing penalty `S ∈ ℝ^{P×P}` (symmetric, PSD).
    pub penalty: ArrayView2<'a, f64>,
    /// Per-active-class smoothing parameter `λ_a` (length `K - 1`). The shared
    /// engine checks that each value is finite and non-negative.
    pub lambdas: ArrayView1<'a, f64>,
    /// Optional per-row weights (length `N`); `None` ⇒ uniform 1.0. Each
    /// weight must be finite and ≥ 0.
    pub row_weights: Option<ArrayView1<'a, f64>>,
    /// Optional per-row Fisher-block override, shape `(N, K-1, K-1)`, in the
    /// active-class gauge (the reference class `K-1` is dropped).
    ///
    /// When `Some`, each Newton step uses this block as the curvature `W` in
    /// place of the analytic softmax Fisher `w_n (δ_ab p_a − p_a p_b)`. The
    /// gradient/residual path stays analytic, so this is a curvature-only
    /// override: the research escape hatch for latent multinomial fits
    /// (issue #349).
    ///
    /// Each per-row block must be symmetric, PSD, and finite. Only the shape
    /// is checked by the driver; the FFI boundary discharges the remaining
    /// preconditions before constructing this view.
    pub fisher_w_override: Option<ArrayView3<'a, f64>>,
    /// Maximum Newton iterations; recommend 50.
    pub max_iter: usize,
    /// Relative-step convergence tolerance for `‖δ‖ / (1 + ‖β‖) ≤ tol`;
    /// recommend 1e-7.
    pub tol: f64,
}
481
/// Outputs of [`fit_penalized_multinomial`].
#[derive(Debug, Clone)]
pub struct MultinomialFitOutputs {
    /// Active-class coefficient block, shape `(P, K-1)` (column `a` is `β_a`).
    /// The reference class `K - 1` has `β_{K-1} ≡ 0` by construction and is
    /// not stored.
    pub coefficients_active: Array2<f64>,
    /// Fitted probabilities, shape `(N, K)`, including the reference class.
    /// They are computed by the softmax likelihood from the final `η = X β̂`.
    pub fitted_probabilities: Array2<f64>,
    /// Number of Newton iterations executed (including the final step that
    /// satisfied the tolerance), as reported by the shared engine.
    pub iterations: usize,
    /// `true` if the relative-step test was satisfied; `false` if the
    /// solver exhausted `max_iter`. A non-converged solve is still
    /// returned (unless it was rejected as a separation artifact); the caller
    /// decides whether to escalate.
    pub converged: bool,
    /// Penalized negative log-likelihood at the returned `β̂`:
    /// `−log L(β̂) + ½ Σ_a λ_a · β̂_a^T S β̂_a`.
    pub penalized_neg_log_likelihood: f64,
    /// `−2 log L(β̂)`, unpenalized, for diagnostic reporting.
    ///
    /// Assuming the per-row log-likelihood is `Σ_c y_c log p_c`, this equals
    /// the saturated-model deviance only for hard one-hot labels. For
    /// label-smoothed rows the saturated term `Σ_c y_c log y_c` is non-zero,
    /// so the value is not the textbook deviance there.
    pub deviance: f64,
}
504
505/// Fit a penalized multinomial-logit GAM at fixed `λ`.
506///
507/// See the module docs for the optimization problem and conventions. This
508/// function is the canonical inner solve: the outer REML/LAML loop, when
509/// added, calls this at each `ρ = log λ` trial.
510pub fn fit_penalized_multinomial(
511    inputs: MultinomialFitInputs<'_>,
512) -> Result<MultinomialFitOutputs, EstimationError> {
513    let MultinomialFitInputs {
514        design,
515        y_one_hot,
516        penalty,
517        lambdas,
518        row_weights,
519        fisher_w_override,
520        max_iter,
521        tol,
522    } = inputs;
523
524    // ──────────────────────── family-specific validation ───────────────────
525    // The shared engine re-validates the geometry common to every vector-GLM
526    // (nonempty design, penalty shape, λ finiteness/non-negativity, override
527    // `(N, M, M)` shape, finite design). The multinomial family owns the
528    // class-count contract (`K ≥ 2`, λ length `K − 1`), the per-row simplex
529    // precondition under which the softmax residual/Fisher are the exact
530    // derivatives of `Σ_c y_c log p_c`, and the row-weight check the likelihood
531    // adapter consumes.
532    let n_obs = design.nrows();
533    let (y_rows, k) = y_one_hot.dim();
534    if y_rows != n_obs {
535        crate::bail_invalid_estim!(
536            "fit_penalized_multinomial: y rows {y_rows} ≠ design rows {n_obs}"
537        );
538    }
539    if k < 2 {
540        crate::bail_invalid_estim!(
541            "fit_penalized_multinomial: need at least 2 classes (got K={k})"
542        );
543    }
544    let m = k - 1;
545    if lambdas.len() != m {
546        crate::bail_invalid_estim!(
547            "fit_penalized_multinomial: lambdas length {} ≠ K-1 = {m}",
548            lambdas.len()
549        );
550    }
551    if let Some(fw) = fisher_w_override.as_ref() {
552        if fw.dim() != (n_obs, m, m) {
553            crate::bail_invalid_estim!(
554                "fit_penalized_multinomial: fisher_w_override shape {:?} ≠ (N, K-1, K-1) = ({n_obs}, {m}, {m})",
555                fw.dim()
556            );
557        }
558    }
559    if let Some(w) = row_weights.as_ref() {
560        if w.len() != n_obs {
561            crate::bail_invalid_estim!(
562                "fit_penalized_multinomial: row_weights length {} ≠ N = {n_obs}",
563                w.len()
564            );
565        }
566        for (i, &v) in w.iter().enumerate() {
567            if !(v.is_finite() && v >= 0.0) {
568                crate::bail_invalid_estim!(
569                    "fit_penalized_multinomial: row_weights[{i}] must be finite and ≥ 0 (got {v})"
570                );
571            }
572        }
573    }
574    validate_multinomial_simplex(y_one_hot, "fit_penalized_multinomial")?;
575
576    // ────────────────────────── likelihood construction ───────────────────
577    let mut likelihood = MultinomialLogitLikelihood::with_classes(k)?;
578    if let Some(w) = row_weights.as_ref() {
579        likelihood = likelihood.with_row_weights(w.to_owned())?;
580    }
581
582    // ─────────────────── shared penalized vector-GLM solve ─────────────────
583    // The softmax Fisher block is dense across the `M = K − 1` active classes;
584    // the engine assembles the coupled `(P·M)×(P·M)` penalized Hessian, runs
585    // the damped Newton loop, and returns the converged `β̂` and `η = X β̂`.
586    let fit = fit_penalized_vector_glm(
587        PenalizedVectorGlmInputs {
588            design,
589            y: y_one_hot,
590            penalty,
591            lambdas,
592            fisher_w_override,
593            max_iter,
594            tol,
595            // #1587: production multinomial still uses the per-class Diagonal
596            // metric pending the REML per-class→per-term λ re-key that the
597            // reference-symmetric Centered metric requires (shared λ). The
598            // Centered engine path + its invariance proof land first.
599            class_penalty_metric: crate::penalized_vector_glm::ClassPenaltyMetric::Diagonal,
600        },
601        &likelihood,
602        "fit_penalized_multinomial",
603    )?;
604
605    let (max_abs_eta, row_index, active_class_index) = max_abs_eta_location(fit.eta.view());
606    if !fit.converged && max_abs_eta >= MULTINOMIAL_SEPARATION_ETA_THRESHOLD {
607        return Err(EstimationError::MultinomialSeparationDetected {
608            iteration: fit.iterations,
609            max_abs_eta,
610            active_class_index,
611            row_index,
612        });
613    }
614
615    let fitted_probabilities = likelihood.probabilities(fit.eta.view());
616
617    Ok(MultinomialFitOutputs {
618        coefficients_active: fit.coefficients,
619        fitted_probabilities,
620        iterations: fit.iterations,
621        converged: fit.converged,
622        penalized_neg_log_likelihood: -fit.log_likelihood + fit.penalty_term,
623        deviance: -2.0 * fit.log_likelihood,
624    })
625}
626
627// ---------------------------------------------------------------------------
628// Formula-driven multinomial pipeline
629// ---------------------------------------------------------------------------
630//
// Slice A of the multinomial integration: a single public entry that takes
// a parsed `EncodedDataset`, a Wilkinson-style formula, and a uniform initial
// smoothing parameter, then runs the full
//
//     parse → termspec → design (X, S blocks) → one-hot Y → REML λ-selection
//
// pipeline. `fit_penalized_multinomial_formula` drives the outer REML/LAML
// loop (via the custom-family path) to select an independent λ per (class,
// penalty component) — one per smooth term in the common case, several for
// double-penalty, tensor, or operator smooths; `init_lambda` (default 1.0) is
// only the warm-start seed for every block. The reference class is the last
// level of the categorical response column as recorded in the dataset schema.
642
/// Saved-model payload for a multinomial fit driven by a Wilkinson formula.
///
/// This is what the FFI returns to Python. It carries everything the Python
/// `MultinomialModel.predict` path needs to evaluate `softmax(X_new · β)` on
/// fresh data using the *training* basis / penalty structure (no refit on
/// predict, no re-derivation of class levels).
///
/// NOTE: this struct is `Serialize`/`Deserialize`, so renaming or reordering
/// fields can break previously saved payloads (field order matters for
/// non-self-describing serde formats). The `#[serde(default)]` fields let
/// payloads written before those fields existed still deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSavedModel {
    /// The training formula, verbatim. Stored so Python's `summary()` and
    /// any round-trip persistence path can echo what was fit.
    pub formula: String,
    /// Names of the *training* response levels in canonical order. The last
    /// entry is the reference class (η = 0); the first `K - 1` carry the
    /// active linear-predictor blocks. Class permutations are forbidden:
    /// this list is fixed at fit time and predictions emit columns in the
    /// same order.
    pub class_levels: Vec<String>,
    /// Index of the reference class within `class_levels` — currently always
    /// `class_levels.len() - 1`, exposed as a field so future "user-pinned
    /// reference" gauges (e.g. `family='multinomial', reference='setosa'`)
    /// can land without changing the on-disk shape.
    pub reference_class_index: usize,
    /// Resolved term-collection spec used to build `X` at fit time. Replayed
    /// on predict via [`gam_terms::smooth::build_term_collection_design`].
    pub resolved_termspec: TermCollectionSpec,
    /// Active-class coefficient block, shape `(P, K-1)`. Column `a` is the
    /// coefficient vector for class `class_levels[a]`. Stored flat in
    /// row-major order (`coefficients_flat[i * (K-1) + a] = β[i, a]`) to keep
    /// the serde payload self-describing.
    pub coefficients_flat: Vec<f64>,
    /// `P` — coefficient count per active class. Matches the column count of
    /// the design matrix the saved `resolved_termspec` produces.
    pub p_per_class: usize,
    /// Number of active classes (`K - 1`).
    pub n_active_classes: usize,
    /// Original training column headers, in dataset-column order. Needed at
    /// predict time so the FFI can align a fresh `Dataset` to the training
    /// schema before evaluating the basis.
    pub training_headers: Vec<String>,
    /// REML/LAML-selected smoothing parameters, one per `(active class, penalty
    /// component)`, flattened in block-major order: all of class 0's λ, then
    /// class 1's, and so on. Per-term penalties (#561) mean each active class
    /// block selects an *independent* λ for every penalty component of every
    /// smooth term. A single smooth term can contribute MORE than one component
    /// (the Marra–Wood double penalty adds a null-space shrinkage λ; tensor /
    /// operator smooths add one per margin / operator — see
    /// [`Self::lambda_labels`]), so the length is `Σ_a lambdas_per_block[a]`,
    /// which can exceed `(K − 1) · #smooth terms`. Use
    /// [`MultinomialSavedModel::lambdas_per_block`] to segment it by class. An
    /// unpenalized model (no smooth terms) yields an empty vector.
    pub lambdas: Vec<f64>,
    /// Number of smoothing parameters (penalty components — not smooth terms,
    /// since one term may carry several) in each active class block, parallel
    /// to `class_levels[0..K-1]`. Segments the flat `lambdas` vector: class
    /// `a`'s λ are `lambdas[Σ_{b<a} lambdas_per_block[b] ..][..
    /// lambdas_per_block[a]]`. Every entry is identical in the shared-design
    /// architecture (all classes share the same term structure), but it is
    /// stored explicitly so consumers never have to assume that.
    pub lambdas_per_block: Vec<usize>,
    /// Newton iterations executed; recorded for the summary report.
    pub iterations: usize,
    /// `true` if the inner Newton solver hit the relative-step tolerance.
    pub converged: bool,
    /// Penalized negative log-likelihood at the returned `β̂`.
    pub penalized_neg_log_likelihood: f64,
    /// Unpenalized deviance `−2 log L(β̂)`.
    pub deviance: f64,
    /// Per-active-class effective degrees of freedom (hat-matrix trace),
    /// length `K - 1`. Populated when the REML driver reports an
    /// inference block; falls back to `None` for the legacy fixed-λ path.
    #[serde(default)]
    pub edf_per_class: Option<Vec<f64>>,
    /// Per-PENALTY effective degrees of freedom, one entry per smoothing
    /// parameter (length `== lambdas.len()`), aligned block-major with the flat
    /// [`Self::lambdas`] / [`Self::lambdas_per_block`] layout. Each entry is the
    /// penalty-block trace EDF `rank(S_k) − λ_k·tr(H⁻¹ S_k)`, clamped to
    /// `[0, rank(S_k)]`. This is the per-(class, term, penalty) resolution that
    /// the per-class [`Self::edf_per_class`] SUM deliberately hides: only the
    /// per-penalty vector reveals whether an individual smooth collapsed onto its
    /// polynomial null space (its wiggliness λ driven to the λ-cap), which a
    /// per-class total cannot show. Populated whenever the REML driver reports an
    /// inference block; `None` on the legacy fixed-λ path or when the trace
    /// channel is mis-shaped. Unlike `edf_per_class`, the entries do NOT sum to
    /// the model EDF when several penalties share one coefficient range (a
    /// double-penalty smooth has `Σ_k rank(S_k) > p_per_class`).
    #[serde(default)]
    pub edf_per_penalty: Option<Vec<f64>>,
    /// Joint posterior coefficient covariance `H⁻¹` (#1101), block-ordered to
    /// match the stacked active-class coefficient vector `β = [β_0; …; β_{K-2}]`
    /// (class `a`'s `P` coefficients occupy rows/cols `a·P .. (a+1)·P`). This is
    /// the Laplace covariance the REML driver already computes from the factored
    /// penalized Hessian; storing it gives the predict path delta-method
    /// per-class probability standard errors and the summary its Wald
    /// smooth-term tests. Flattened row-major over the `(P·M)×(P·M)` matrix.
    /// `None` for a model fitted before covariance was surfaced.
    #[serde(default)]
    pub coefficient_covariance_flat: Option<Vec<f64>>,
    /// Joint coefficient-space influence matrix `F = H⁻¹ X'WX` (#1101),
    /// block-ordered identically to [`Self::coefficient_covariance_flat`].
    /// Its per-term diagonal block trace is the term's effective degrees of
    /// freedom and its `tr(F_jj)²/tr(F_jj²)` the Wood reference d.f., feeding
    /// the rank-truncated Wald smooth-term test in `summary()`. Flattened
    /// row-major over the `(P·M)×(P·M)` matrix. `None` when unavailable.
    #[serde(default)]
    pub coefficient_influence_flat: Option<Vec<f64>>,
    /// Per-(active class, smooth term) coefficient column range and unpenalized
    /// nullspace dimension within the `P`-wide class block (#1101). Parallel to
    /// the smooth terms the design produced; replicated across classes by the
    /// shared-design architecture. Drives the Wald smooth-term table in
    /// `summary()`. Empty for a wholly parametric (no-smooth) model.
    #[serde(default)]
    pub smooth_term_spans: Vec<MultinomialSmoothTermSpan>,
    /// One descriptive label per *penalty component* within a single active-class
    /// block, parallel to that block's λ slice (i.e. length
    /// `lambdas_per_block[0]`). The Marra–Wood double penalty (and tensor /
    /// operator smooths) emit **more than one** penalty component — hence more
    /// than one λ — per smooth term, so this is NOT 1:1 with
    /// [`Self::smooth_term_spans`]: a single `s(x)` term contributes a primary
    /// wiggliness λ labelled `s(x)` and a null-space shrinkage λ labelled
    /// `s(x) [null space]`. The summary renderer pairs `lambdas` with these
    /// labels component-for-component so no λ is ever dropped (#1544). Built from
    /// the per-component term name + penalty role at fit time; empty for a
    /// wholly parametric model or a model serialized before this field existed.
    #[serde(default)]
    pub lambda_labels: Vec<String>,
}
763
/// One smooth term's coefficient span within a class block, plus its
/// unpenalized nullspace dimension and a display label (#1101). The Wald
/// smooth-significance test in `summary()` slices the joint covariance /
/// influence at `a·P + col_start .. a·P + col_end` for active class `a`.
///
/// The span is expressed relative to ONE class block (`0..P`); the same span is
/// reused for every active class. Field order is part of the serde payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSmoothTermSpan {
    /// Human-readable term label (the smooth's formula token), for the table.
    pub label: String,
    /// Start column (inclusive) of the term within the per-class `P`-wide
    /// coefficient block.
    pub col_start: usize,
    /// End column (exclusive) of the term within the per-class block; a
    /// well-formed span has `col_start < col_end <= P`.
    pub col_end: usize,
    /// Leading unpenalized (polynomial nullspace) dimension within the term.
    pub nullspace_dim: usize,
}
779
780/// Descriptive label for one penalty *component* (one λ) within a class block,
781/// for the `summary()` per-class λ rollup (#1544). A smooth term can emit
782/// several penalty components — the Marra–Wood double penalty splits `s(x)`
783/// into a primary wiggliness penalty and a null-space shrinkage penalty, and
784/// tensor / operator smooths emit a component per margin / differential
785/// operator — each with its own independently-selected λ. The label is the
786/// term name (from `PenaltyBlockInfo::termname`) plus a role suffix derived
787/// from the penalty's [`PenaltySource`], so each λ in the summary names both
788/// the term it smooths and the role it plays. `pen_idx` is the global penalty
789/// index, used only as a last-resort fallback label.
790fn penalty_component_label(info: Option<&PenaltyBlockInfo>, pen_idx: usize) -> String {
791    use gam_terms::basis::PenaltySource;
792    let term = info
793        .and_then(|i| i.termname.clone())
794        .unwrap_or_else(|| format!("s{pen_idx}"));
795    let role = match info.map(|i| &i.penalty.source) {
796        // The primary wiggliness penalty is the term's "main" λ; show the bare
797        // term name so the common single-penalty case reads cleanly.
798        Some(PenaltySource::Primary) | None => None,
799        Some(PenaltySource::DoublePenaltyNullspace) => Some("null space".to_string()),
800        Some(PenaltySource::OperatorMass) => Some("mass".to_string()),
801        Some(PenaltySource::OperatorTension) => Some("tension".to_string()),
802        Some(PenaltySource::OperatorStiffness) => Some("stiffness".to_string()),
803        Some(PenaltySource::OperatorRelevance { axis }) => Some(format!("axis {axis}")),
804        Some(PenaltySource::TensorMarginal { dim }) => Some(format!("margin {dim}")),
805        Some(PenaltySource::TensorSeparable { penalized_margins }) => {
806            Some(format!("separable {penalized_margins:?}"))
807        }
808        Some(PenaltySource::TensorGlobalRidge) => Some("ridge".to_string()),
809        Some(PenaltySource::Other(s)) => Some(s.clone()),
810    };
811    match role {
812        Some(role) => format!("{term} [{role}]"),
813        None => term,
814    }
815}
816
817impl MultinomialSavedModel {
818    /// Active-class coefficient block as an `(P, K-1)` `ndarray` view.
819    pub fn coefficients_active(&self) -> Array2<f64> {
820        Array2::from_shape_vec(
821            (self.p_per_class, self.n_active_classes),
822            self.coefficients_flat.clone(),
823        )
824        .expect(
825            "MultinomialSavedModel.coefficients_flat length must equal p_per_class * n_active_classes",
826        )
827    }
828
829    /// Evaluate `softmax(X · β)` at fresh data rows. `X_new` must have
830    /// `self.p_per_class` columns (i.e. it was built from the same
831    /// `resolved_termspec` as fit time). Returns an `(N_new, K)` matrix
832    /// with rows summing to 1; column order matches `self.class_levels`.
833    pub fn predict_probabilities(&self, x_new: ArrayView2<'_, f64>) -> Array2<f64> {
834        let n_new = x_new.nrows();
835        let p = self.p_per_class;
836        let m = self.n_active_classes;
837        let k = m + 1;
838        assert_eq!(
839            x_new.ncols(),
840            p,
841            "MultinomialSavedModel.predict_probabilities: X has {} cols, expected {p}",
842            x_new.ncols()
843        );
844        let beta = self.coefficients_active();
845        let mut probs = Array2::<f64>::zeros((n_new, k));
846        let mut eta_active = vec![0.0_f64; m];
847        let mut row_probs = vec![0.0_f64; k];
848        for row in 0..n_new {
849            for a in 0..m {
850                let mut v = 0.0_f64;
851                for i in 0..p {
852                    v += x_new[[row, i]] * beta[[i, a]];
853                }
854                eta_active[a] = v;
855            }
856            MultinomialLogitLikelihood::softmax_with_baseline(&eta_active, &mut row_probs);
857            for c in 0..k {
858                probs[[row, c]] = row_probs[c];
859            }
860        }
861        probs
862    }
863
864    /// Reconstruct the joint posterior covariance `H⁻¹` as a `(P·M)×(P·M)`
865    /// `ndarray`, block-ordered to match the stacked coefficient vector
866    /// `θ[a·P + i] = β[i, a]` (#1101). `None` when the model was fitted before
867    /// covariance was surfaced (legacy payload).
868    pub fn coefficient_covariance(&self) -> Option<Array2<f64>> {
869        let d = self.p_per_class.checked_mul(self.n_active_classes)?;
870        let flat = self.coefficient_covariance_flat.as_ref()?;
871        Array2::from_shape_vec((d, d), flat.clone()).ok()
872    }
873
874    /// Reconstruct the joint influence matrix `F = H⁻¹ X'WX` as a
875    /// `(P·M)×(P·M)` `ndarray`, block-ordered like
876    /// [`Self::coefficient_covariance`] (#1101). `None` when unavailable.
877    pub fn coefficient_influence(&self) -> Option<Array2<f64>> {
878        let d = self.p_per_class.checked_mul(self.n_active_classes)?;
879        let flat = self.coefficient_influence_flat.as_ref()?;
880        Array2::from_shape_vec((d, d), flat.clone()).ok()
881    }
882
883    /// Evaluate `softmax(X·β)` AND its delta-method per-class probability
884    /// standard error at fresh data rows (#1101).
885    ///
886    /// For active classes `b ∈ 0..M` the softmax Jacobian is
887    /// `∂p_c/∂η_b = p_c (δ_{cb} − p_b)`, and `∂η_b/∂β[i,a] = X[i]·δ_{ab}`, so the
888    /// gradient of class-`c` probability w.r.t. the block-ordered coefficient
889    /// vector is `g_c[a·P + i] = X[i]·p_c (δ_{ca} − p_a)` (active `a`; the
890    /// reference class `M` contributes `p_c(0 − p_a)` via every active block).
891    /// The delta-method variance is `Var(p_c) = g_cᵀ Σ g_c` with `Σ = H⁻¹` the
892    /// joint posterior covariance, and `SE(p_c) = √Var(p_c)`. Returns
893    /// `(probs (N,K), prob_se (N,K))`; `prob_se` is `None` when no covariance is
894    /// stored. The simplex `[0,1]` clamp is applied by the interval consumer, not
895    /// here (the SE itself is unclamped).
896    pub fn predict_probabilities_with_se(
897        &self,
898        x_new: ArrayView2<'_, f64>,
899    ) -> (Array2<f64>, Option<Array2<f64>>) {
900        let probs = self.predict_probabilities(x_new);
901        let Some(cov) = self.coefficient_covariance() else {
902            return (probs, None);
903        };
904        let n_new = x_new.nrows();
905        let p = self.p_per_class;
906        let m = self.n_active_classes;
907        let k = m + 1;
908        let d = p * m;
909        let mut prob_se = Array2::<f64>::zeros((n_new, k));
910        let mut grad = vec![0.0_f64; d];
911        for row in 0..n_new {
912            let prow = probs.row(row);
913            for c in 0..k {
914                let pc = prow[c];
915                // g_c[a·P + i] = X[i] · p_c · (δ_{ca} − p_a), a active.
916                for a in 0..m {
917                    let pa = prow[a];
918                    let factor = pc * (if c == a { 1.0 - pa } else { -pa });
919                    let base = a * p;
920                    for i in 0..p {
921                        grad[base + i] = x_new[[row, i]] * factor;
922                    }
923                }
924                // Var = gᵀ Σ g.
925                let mut var = 0.0_f64;
926                for r in 0..d {
927                    let gr = grad[r];
928                    if gr == 0.0 {
929                        continue;
930                    }
931                    let mut acc = 0.0_f64;
932                    for s in 0..d {
933                        acc += cov[[r, s]] * grad[s];
934                    }
935                    var += gr * acc;
936                }
937                prob_se[[row, c]] = var.max(0.0).sqrt();
938            }
939        }
940        (probs, Some(prob_se))
941    }
942
943    /// Wood (2013) rank-truncated Wald smooth-significance test per
944    /// `(active class, smooth term)` (#1101), reusing the exact scalar-summary
945    /// kernel [`gam_terms::inference::smooth_test::wood_smooth_test`]. For active
946    /// class `a` and term span `[c0, c1)` within the class block, the global
947    /// coefficient range is `a·P + c0 .. a·P + c1`; the joint covariance and
948    /// influence are sliced there. The term EDF is the influence-block trace
949    /// `tr(F_jj)` (when present) and the reference d.f. uses `tr(F_jj)²/tr(F_jj²)`,
950    /// exactly as the scalar path. The multinomial softmax is a known-dispersion
951    /// family, so the χ²_{ref_df} branch applies. Returns one row per
952    /// `(class label, term label, edf, ref_df, statistic, p_value)`; empty when
953    /// no covariance/smooth terms are available.
954    pub fn smooth_significance(&self) -> Vec<MultinomialSmoothSignificance> {
955        let mut out = Vec::new();
956        let p = self.p_per_class;
957        let m = self.n_active_classes;
958        let Some(cov) = self.coefficient_covariance() else {
959            return out;
960        };
961        if self.smooth_term_spans.is_empty() {
962            return out;
963        }
964        let beta = self.coefficients_active();
965        // Block-ordered θ = [β_0; …; β_{M-1}], θ[a·P + i] = β[i, a].
966        let d = p * m;
967        let mut theta = Array1::<f64>::zeros(d);
968        for a in 0..m {
969            for i in 0..p {
970                theta[a * p + i] = beta[[i, a]];
971            }
972        }
973        let influence = self.coefficient_influence();
974        for a in 0..m {
975            let class_label = self
976                .class_levels
977                .get(a)
978                .cloned()
979                .unwrap_or_else(|| format!("class{a}"));
980            let base = a * p;
981            for span in &self.smooth_term_spans {
982                if span.col_end > p {
983                    continue;
984                }
985                let start = base + span.col_start;
986                let end = base + span.col_end;
987                // Term EDF = tr(F_jj); without an influence matrix fall back to
988                // the block coefficient count (full-rank Wald on the span).
989                let block_len = (span.col_end - span.col_start) as f64;
990                let edf = influence
991                    .as_ref()
992                    .map(|f| (start..end).map(|i| f[[i, i]]).sum::<f64>())
993                    .filter(|v| v.is_finite() && *v > 0.0)
994                    .unwrap_or(block_len);
995                let result = gam_terms::inference::smooth_test::wood_smooth_test(
996                    gam_terms::inference::smooth_test::SmoothTestInput {
997                        beta: theta.view(),
998                        covariance: &cov,
999                        influence_matrix: influence.as_ref(),
1000                        coeff_range: start..end,
1001                        edf,
1002                        nullspace_dim: span.nullspace_dim,
1003                        residual_df: f64::INFINITY,
1004                        scale: gam_terms::inference::smooth_test::SmoothTestScale::Known,
1005                    },
1006                );
1007                if let Some(res) = result {
1008                    out.push(MultinomialSmoothSignificance {
1009                        class_label: class_label.clone(),
1010                        term_label: span.label.clone(),
1011                        edf,
1012                        ref_df: res.ref_df,
1013                        statistic: res.statistic,
1014                        p_value: res.p_value,
1015                    });
1016                }
1017            }
1018        }
1019        out
1020    }
1021}
1022
/// One row of the multinomial smooth-significance table (#1101): the Wood
/// rank-truncated Wald test for one `(active class, smooth term)` pair.
/// Produced by `MultinomialSavedModel::smooth_significance`.
#[derive(Debug, Clone)]
pub struct MultinomialSmoothSignificance {
    /// Active class name (`class_levels[a]`, or `class{a}` if the level list
    /// is shorter than the number of active classes).
    pub class_label: String,
    /// Smooth term label, copied from the term's span.
    pub term_label: String,
    /// Effective degrees of freedom used for the test: the influence-block
    /// trace when available and positive/finite, otherwise the span's
    /// coefficient count.
    pub edf: f64,
    /// Reference degrees of freedom reported by the Wald test kernel.
    pub ref_df: f64,
    /// Wald test statistic reported by the kernel.
    pub statistic: f64,
    /// p-value reported by the kernel (known-scale / χ² branch).
    pub p_value: f64,
}
1034
1035/// One-hot-encode the categorical response column and return both the
1036/// encoding and the captured level names. The level order matches the order
1037/// recorded in the dataset schema, which is the canonical (lexicographically
1038/// sorted) factor order produced by inferred-schema construction (#1319) — so
1039/// it is a deterministic function of the label *set*, independent of training
1040/// row order (no silent class permutation under a row shuffle), and matches the
1041/// R `factor()` / pandas `Categorical` convention.
1042fn one_hot_categorical_response(
1043    data: &EncodedDataset,
1044    y_col: usize,
1045    response_name: &str,
1046) -> Result<(Array2<f64>, Vec<String>), EstimationError> {
1047    let levels: Vec<String> = data
1048        .schema
1049        .columns
1050        .get(y_col)
1051        .map(|sc| sc.levels.clone())
1052        .unwrap_or_default();
1053    if levels.len() < 2 {
1054        crate::bail_invalid_estim!(
1055            "multinomial response '{response_name}' must have at least 2 categorical levels (got {})",
1056            levels.len()
1057        );
1058    }
1059    let n = data.values.nrows();
1060    let k = levels.len();
1061    let mut y_one_hot = Array2::<f64>::zeros((n, k));
1062    for row in 0..n {
1063        let encoded = data.values[[row, y_col]];
1064        if !encoded.is_finite() {
1065            crate::bail_invalid_estim!(
1066                "multinomial response '{response_name}' row {row} is non-finite ({encoded})"
1067            );
1068        }
1069        let class_idx = encoded.round() as i64;
1070        if class_idx < 0 || (class_idx as usize) >= k {
1071            crate::bail_invalid_estim!(
1072                "multinomial response '{response_name}' row {row} encoded as {encoded} \
1073                 is outside the level range 0..{k}"
1074            );
1075        }
1076        y_one_hot[[row, class_idx as usize]] = 1.0;
1077    }
1078    Ok((y_one_hot, levels))
1079}
1080
1081/// Build `(TermCollectionSpec, TermCollectionDesign)` from a formula against
1082/// a categorical-response dataset. Mirrors the early scaffolding inside
1083/// `materialize_standard` (response role resolution, geometry-aware spec
1084/// build) without touching the scalar-family resolution path — multinomial
1085/// owns its own response kind check.
1086fn build_formula_design_for_multinomial(
1087    formula: &str,
1088    data: &EncodedDataset,
1089    config: &FitConfig,
1090) -> Result<
1091    (
1092        TermCollectionSpec,
1093        TermCollectionDesign,
1094        usize,
1095        String,
1096        ResponseColumnKind,
1097    ),
1098    EstimationError,
1099> {
1100    let parsed = parse_formula(formula).map_err(|err| {
1101        EstimationError::InvalidInput(format!(
1102            "multinomial fit: failed to parse formula {formula:?}: {err}"
1103        ))
1104    })?;
1105    let col_map = data.column_map();
1106    let y_col = resolve_role_col(&col_map, &parsed.response, "response")
1107        .map_err(|err| EstimationError::InvalidInput(format!("multinomial fit: {err}")))?;
1108    let y_kind = crate::fit_orchestration::response_column_kind(data, y_col);
1109    let policy = resolved_resource_policy(config, data, ProblemHints::default());
1110    let mut inference_notes: Vec<String> = Vec::new();
1111    let spec = build_termspec_with_geometry_and_overrides(
1112        &parsed.terms,
1113        data,
1114        &col_map,
1115        &mut inference_notes,
1116        config.scale_dimensions,
1117        &policy,
1118        config.smooth_overrides.as_ref(),
1119    )
1120    .map_err(|err| {
1121        EstimationError::InvalidInput(format!("multinomial fit: build termspec: {err}"))
1122    })?;
1123    let design = build_term_collection_design(data.values.view(), &spec).map_err(|err| {
1124        EstimationError::InvalidInput(format!("multinomial fit: build design: {err}"))
1125    })?;
1126    Ok((spec, design, y_col, parsed.response, y_kind))
1127}
1128
1129fn scale_multinomial_formula_penalty(penalty: PenaltyMatrix, scale: f64) -> PenaltyMatrix {
1130    match penalty {
1131        PenaltyMatrix::Dense(matrix) => PenaltyMatrix::Dense(matrix.mapv(|v| v * scale)),
1132        PenaltyMatrix::KroneckerFactored { left, right } => PenaltyMatrix::KroneckerFactored {
1133            left: left.mapv(|v| v * scale),
1134            right,
1135        },
1136        PenaltyMatrix::Blockwise {
1137            local,
1138            col_range,
1139            total_dim,
1140        } => PenaltyMatrix::Blockwise {
1141            local: local.mapv(|v| v * scale),
1142            col_range,
1143            total_dim,
1144        },
1145        PenaltyMatrix::Labeled { label, inner } => PenaltyMatrix::Labeled {
1146            label,
1147            inner: Box::new(scale_multinomial_formula_penalty(*inner, scale)),
1148        },
1149        PenaltyMatrix::Fixed { log_lambda, inner } => PenaltyMatrix::Fixed {
1150            log_lambda,
1151            inner: Box::new(scale_multinomial_formula_penalty(*inner, scale)),
1152        },
1153    }
1154}
1155
1156/// Build a warm-started copy of `blocks` whose per-block `initial_log_lambdas`
1157/// are seeded from a previously-selected flat `log_lambdas` vector (#1082).
1158///
1159/// The flat `log_lambdas` returned by [`fit_custom_family_with_rho_prior`]
1160/// concatenates each block's penalty log-λ in block order — the same order
1161/// `build_block_specs()` emits the blocks and the same per-block penalty order
1162/// the spec carries — so it splits back across blocks by each block's penalty
1163/// count. Warm-starting the OUTER ρ-search from a prior iterate changes only the
1164/// optimizer's starting point, never the penalized objective or its optimum, so
1165/// the converged fit is identical; it just resumes near the prior iterate
1166/// instead of restarting from the cold `init_lambda` seed.
1167///
1168/// Returns `None` (caller falls back to the cold blocks) if the flat vector does
1169/// not have exactly one entry per penalty across all blocks, or carries a
1170/// non-finite value — i.e. anything that would make the seed unsafe.
1171fn warm_start_blocks_from_log_lambdas(
1172    blocks: &[crate::custom_family::ParameterBlockSpec],
1173    log_lambdas: &[f64],
1174) -> Option<Vec<crate::custom_family::ParameterBlockSpec>> {
1175    let total: usize = blocks.iter().map(|b| b.initial_log_lambdas.len()).sum();
1176    if total == 0 || log_lambdas.len() != total {
1177        return None;
1178    }
1179    if log_lambdas.iter().any(|v| !v.is_finite()) {
1180        return None;
1181    }
1182    let mut warm = blocks.to_vec();
1183    let mut offset = 0usize;
1184    for block in warm.iter_mut() {
1185        let k = block.initial_log_lambdas.len();
1186        for slot in 0..k {
1187            block.initial_log_lambdas[slot] = log_lambdas[offset + slot];
1188        }
1189        offset += k;
1190    }
1191    Some(warm)
1192}
1193
1194/// Top-level formula-driven multinomial fit.
1195///
1196/// Routes through [`fit_custom_family_with_rho_prior`] so the per-active-class
1197/// smoothing parameters `λ_a` (one per class block, shared-penalty
1198/// architecture) are selected by the outer REML/LAML loop rather than pinned
1199/// by the caller. `init_lambda` survives as a warm-start hint that seeds
1200/// every block's `initial_log_lambdas`. `max_iter` / `tol` drive the OUTER
1201/// REML/LAML smoothing-parameter search (`outer_max_iter` / `outer_tol`); the
1202/// inner joint-Newton solve runs on the framework's principled production cycle
1203/// budget at the default KKT tolerance so an ill-conditioned, LM-damped
1204/// near-simplex-boundary solve can certify a stationary point instead of being
1205/// declared non-converged after only `max_iter` cycles (#715).
1206///
1207/// The Jeffreys/Firth proper prior is engaged CONDITIONALLY: attempt 1 runs
1208/// the unbiased penalized-REML criterion; only on separation evidence (a failed
1209/// solve or a non-finite logit; see [`multinomial_formula_separation_evidence`])
1210/// is the fit re-solved once with the full-span Firth prior armed, which bounds
1211/// the penalty-null directions no smoothing parameter can (`S v = 0` ⇒
1212/// `(H + S_λ) v = H v → 0` when the softmax likelihood has no finite mode).
1213///
1214/// The categorical response column is recognised via the dataset schema
1215/// (`ColumnKindTag::Categorical`); reference class = last level. Returns a
1216/// [`MultinomialSavedModel`] that can be serialised to bytes for the Python
1217/// wrapper or used in-process for `predict_probabilities`.
1218pub fn fit_penalized_multinomial_formula(
1219    data: &EncodedDataset,
1220    formula: &str,
1221    config: &FitConfig,
1222    init_lambda: f64,
1223    max_iter: usize,
1224    tol: f64,
1225) -> Result<MultinomialSavedModel, EstimationError> {
1226    if !(init_lambda.is_finite() && init_lambda > 0.0) {
1227        crate::bail_invalid_estim!(
1228            "multinomial fit: init_lambda must be finite and > 0 (got {init_lambda})"
1229        );
1230    }
1231    let (raw_spec, design, y_col, response_name, y_kind) =
1232        build_formula_design_for_multinomial(formula, data, config)?;
1233    // Freeze the data-derived basis state (B-spline knot vectors, by-factor
1234    // level sets, spatial centers, joint-null rotations, residualization
1235    // charts) from the fit design back onto the spec. The raw geometry spec
1236    // records only *which* columns and *what kind* of basis each smooth uses;
1237    // the actual column count and basis evaluation depend on quantities the
1238    // builder derives from the training data (knot placement, the distinct
1239    // by-factor levels, etc.). Saving the raw spec made predict re-derive those
1240    // from the (smaller, differently-distributed) predict frame, so the rebuilt
1241    // design had a different column count than the fitted one — the panic
1242    // "predict design has 42 cols, saved model expects 191" for an `s(x,
1243    // by=group)` smooth-by-factor model. Every other family's persistence path
1244    // freezes the spec the same way (see `freeze_term_collection_from_design`
1245    // call sites in `main_parts`); multinomial was the lone exception.
1246    let spec = freeze_term_collection_from_design(&raw_spec, &design)?;
1247    let class_levels = match y_kind {
1248        ResponseColumnKind::Categorical { levels } => levels,
1249        ResponseColumnKind::Binary => vec!["0".to_string(), "1".to_string()],
1250        ResponseColumnKind::Numeric => {
1251            crate::bail_invalid_estim!(
1252                "multinomial fit: response '{response_name}' is numeric, not categorical; \
1253                 use family='gaussian'/'binomial'/... or convert the column to a categorical type"
1254            );
1255        }
1256    };
1257    if data.column_kinds.get(y_col) == Some(&ColumnKindTag::Binary) {
1258        // Promote to a 2-level categorical for the multinomial driver; the
1259        // caller explicitly asked for multinomial, so we route through the
1260        // K-1 = 1 active-class softmax (equivalent math to logistic).
1261    } else if data.column_kinds.get(y_col) != Some(&ColumnKindTag::Categorical) {
1262        crate::bail_invalid_estim!(
1263            "multinomial fit: response '{response_name}' must be a categorical column \
1264             (got column kind {:?})",
1265            data.column_kinds.get(y_col)
1266        );
1267    }
1268    let (y_one_hot, _) = one_hot_categorical_response(data, y_col, &response_name)?;
1269    // Build the global X dense (the design is a DesignMatrix abstraction).
1270    let mut x_dense = design
1271        .design
1272        .try_to_dense_by_chunks("multinomial fit design")
1273        .map_err(EstimationError::InvalidInput)?;
1274
1275    // ── #715 real-data conditioning: standardize unpenalized parametric
1276    // columns. Raw-unit linear covariates (penguins `body_mass_g` ~ 4e3 grams)
1277    // inflate the joint Newton information by the squared column scale (a κ(H)
1278    // multiplier of ~s² ≈ 1e7 against the intercept), which is what turns the
1279    // near-separable LM-damped inner solve into a geometric grind that
1280    // exhausts its cycle budgets — the adapter-level face of "all REML startup
1281    // seeds rejected". Because these columns are UNPENALIZED (parametric terms
1282    // carry no default ridge, #749), the affine reparameterization
1283    // `x_j ↦ (x_j − m_j)/s_j` is EXACT for the whole criterion: the optimized
1284    // REML/LAML objective, the fitted η, the selected λ, and the separation
1285    // diagnostics are all invariant — only the conditioning of `H` changes.
1286    // Fitted coefficients are mapped back to raw units at repack below, so the
1287    // saved model and the (raw-design) predict path are untouched. Penalized
1288    // columns are left alone (a penalty makes the rescaling non-equivalent),
1289    // and nothing is touched when explicit coefficient bounds/constraints
1290    // exist (those are stated in raw units).
1291    let parametric_standardization: Vec<(usize, f64, f64)> =
1292        if design.coefficient_lower_bounds.is_some() || design.linear_constraints.is_some() {
1293            Vec::new()
1294        } else {
1295            let p_total = x_dense.ncols();
1296            let mut penalized = vec![false; p_total];
1297            for bp in &design.penalties {
1298                for col in bp.col_range.clone() {
1299                    if col < p_total {
1300                        penalized[col] = true;
1301                    }
1302                }
1303            }
1304            let has_intercept = !design.intercept_range.is_empty();
1305            let n_rows = x_dense.nrows().max(1) as f64;
1306            let mut standardized = Vec::new();
1307            for (_, range) in &design.linear_ranges {
1308                for col in range.clone() {
1309                    if col >= p_total || penalized[col] {
1310                        continue;
1311                    }
1312                    let column = x_dense.column(col);
1313                    let mean = column.sum() / n_rows;
1314                    let var = column.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n_rows;
1315                    let scale = var.sqrt();
1316                    // Skip near-constant or degenerate columns: no conditioning to
1317                    // be gained and the back-map would divide by ~0.
1318                    if !(scale.is_finite() && scale > 1e-8 * (mean.abs() + 1.0)) {
1319                        continue;
1320                    }
1321                    // Centering shifts mass onto the intercept; without one the
1322                    // shift is not representable, so scale only.
1323                    let center = if has_intercept { mean } else { 0.0 };
1324                    for v in x_dense.column_mut(col).iter_mut() {
1325                        *v = (*v - center) / scale;
1326                    }
1327                    standardized.push((col, center, scale));
1328                }
1329            }
1330            standardized
1331        };
1332    // Preserve the per-smooth-term penalty block structure (#561): each smooth
1333    // term `t` contributes its own `P × P` penalty component (`Blockwise` with
1334    // `total_dim = P`, the term's local `S_t` embedded at its `col_range`), and
1335    // every active class block receives the FULL list. The outer REML/LAML loop
1336    // then selects an independent smoothing parameter λ_{a,t} per (class, term),
1337    // matching mgcv/VGAM. Pre-summing the terms into one fused `S` (the prior
1338    // behaviour) forced a single λ per class that scales `Σ_t S_t`, so one
1339    // shared λ had to over-smooth a rough term while under-smoothing a smooth
1340    // one — biasing any multi-term class-probability surface.
1341    let k = y_one_hot.ncols();
1342    let m = k - 1;
1343    let n_obs = y_one_hot.nrows();
1344    let penalty_scale = multinomial_formula_penalty_scale(k);
1345    let per_term_penalties: Vec<PenaltyMatrix> = design
1346        .penalties_as_penalty_matrix()
1347        .into_iter()
1348        .map(|penalty| scale_multinomial_formula_penalty(penalty, penalty_scale))
1349        .collect();
1350    let per_term_nullspace_dims = design.nullspace_dims.clone();
1351
1352    // ── Custom-family driven REML/LAML path ───────────────────────────────
1353    // Each active class becomes one ParameterBlockSpec, all sharing X and the
1354    // per-term penalty list. `initial_log_lambdas` is seeded from the caller's
1355    // `init_lambda` (one entry per term).
1356    let design_arc = Arc::new(x_dense);
1357    let penalties_arc = Arc::new(per_term_penalties);
1358    let nullspace_dims_arc = Arc::new(per_term_nullspace_dims);
1359    let weights = Array1::<f64>::ones(n_obs);
1360    // First attempt runs the UNBIASED penalized-REML criterion (no Firth
1361    // shrinkage toward the uniform simplex); the Jeffreys/Firth proper prior is
1362    // armed conditionally below, only on separation evidence (#715/#753 — see
1363    // `multinomial_formula_separation_evidence`).
1364    let log_init = init_lambda.ln();
1365    let family = MultinomialFamily::new(
1366        y_one_hot.clone(),
1367        weights,
1368        k,
1369        design_arc.clone(),
1370        penalties_arc.clone(),
1371        nullspace_dims_arc.clone(),
1372    )
1373    .map_err(EstimationError::InvalidInput)?
1374    .with_joint_jeffreys_term(false)
1375    // gam#1587: the per-block smooth penalties are emptied (the centered `M⊗S_t`
1376    // joint penalty is the sole smoothing carrier), so the `init_lambda` warm
1377    // start must seed the JOINT penalty's `initial_log_lambda` — the per-block
1378    // `initial_log_lambdas` loop below is now a no-op (empty per-block list).
1379    .with_initial_log_lambda(log_init);
1380    let mut blocks = family.build_block_specs();
1381    for spec_block in blocks.iter_mut() {
1382        for v in spec_block.initial_log_lambdas.iter_mut() {
1383            *v = log_init;
1384        }
1385    }
1386
1387    // ── Outer-derivative policy: dimension-gated exact curvature ────────────
1388    // The total smoothing-parameter dimension is `D = (K−1) · n_terms`.
1389    // Medium-D formula fits need exact curvature to keep lambda selection away
1390    // from over-smoothed caps, while smooth-by-factor `D = 8` models still avoid
1391    // the O(D²) dense Hessian path.
1392    let total_rho_dim = m.saturating_mul(penalties_arc.len());
1393    let use_outer_hessian = multinomial_formula_use_outer_hessian(total_rho_dim);
1394
1395    // ── Inner-vs-outer control split (#715 non-convergence root cause) ────────
1396    // The legacy `max_iter` / `tol` parameters are the *outer* REML/LAML
1397    // smoothing-parameter optimization controls — "how hard to search λ". The
1398    // earlier wiring routed them straight into `inner_max_cycles` / `inner_tol`,
1399    // capping the joint-Newton inner solve at `max_iter` (=50 in the quality
1400    // suite) cycles with a `tol`-tight (=1e-8) KKT target. That is the #715
1401    // hang: near the simplex boundary the softmax Fisher weight
1402    // `W = diag(p) − p pᵀ` collapses, so `H = JᵀWJ + S_λ` is full-rank but
1403    // ILL-CONDITIONED. The self-vanishing Levenberg–Marquardt damping
1404    // (`levenberg_on_ill_conditioning()`) that keeps the inner solve from
1405    // oscillating on those near-singular modes makes it converge only
1406    // GEOMETRICALLY (linearly), not quadratically. Reaching a 1e-8 relative KKT
1407    // residual under geometric descent needs FAR more than 50 cycles, so the
1408    // inner returned `converged = false` on every outer ρ-evaluation; with the
1409    // exact-Hessian outer optimizer on `FallbackPolicy::Disabled` that rejects
1410    // every ρ-step — each rejected eval still paying a near-full 50-cycle inner
1411    // solve plus the O(D²) pairwise outer-Hessian directional work — so the
1412    // outer never certifies and the fit runs unbounded (the observed >8-minute
1413    // non-termination). The certificate cannot be reached, not merely slow.
1414    //
1415    // Fix: give the INNER joint-Newton the framework's principled production
1416    // budget (`DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES` cycles at the default
1417    // `inner_tol`), which exists precisely so an ill-conditioned LM-damped solve
1418    // can certify a stationary KKT point instead of being declared non-converged
1419    // prematurely — and the KKT/objective certificates still exit in a handful
1420    // of cycles on the well-conditioned interior fits, so this is free there.
1421    // The caller's `max_iter` / `tol` become the OUTER controls they were always
1422    // meant to be (smoothing-parameter search depth / accuracy). The inner KKT
1423    // target is kept no tighter than the outer accuracy can consume — and no
1424    // tighter than the softmax objective's f64 noise floor on near-separable
1425    // fits (see `MULTINOMIAL_FORMULA_INNER_TOL`).
1426    let outer_max_iter = max_iter.max(1);
1427    // The OUTER REML/LAML smoothing-parameter search must converge to a
1428    // well-calibrated ρ-gradient tolerance, NOT to the caller's (typically very
1429    // tight) INNER KKT tolerance. The #715 control-split repurposed the caller's
1430    // `tol` as the outer control, but feeding an inner-scale `tol = 1e-8`
1431    // straight into `outer_tol` makes REML grind dozens of extra exact-gradient
1432    // outer iterations (each an O(D·p³) Laplace-derivative assembly over the full
1433    // P·M joint design) to squeeze ρ digits that no longer move the fitted
1434    // surface — the smooth-by-factor 269s wall-clock overrun (#1082).
1435    //
1436    // The right target is the framework's CALIBRATED REML convergence tolerance,
1437    // `MULTINOMIAL_OUTER_REML_TOL = 1e-7` — the same value the primary GLM REML
1438    // outer uses (`solver::fit_orchestration::materialize` `tol: 1e-7`, mirrored by the
1439    // `LOG_LAMBDA_TOL`/`KKT_TOL_*` constants across the REML stack). At 1e-7 the
1440    // λ-search reaches the genuine REML optimum (so the recovered probability
1441    // surface matches the mature reference), but it does NOT chase the last
1442    // surface-irrelevant ρ digits down to 1e-8. The earlier 1e-5 floor (the
1443    // generic `BlockwiseFitOptions` default) was too LOOSE: the optimizer halted
1444    // in a low-curvature region with λ still well above its optimum, UNDER-fitting
1445    // the smooth-by-factor surface (truth-RMSE 0.164 vs VGAM's 0.061). So the
1446    // outer tolerance is floored at the calibrated REML tol — never tighter than
1447    // it (perf), never looser (accuracy) — while the caller's `tol` continues to
1448    // drive the INNER joint-Newton KKT target (`inner_tol` below), where its
1449    // precision actually matters.
1450    let outer_tol = if tol.is_finite() && tol > 0.0 {
1451        tol.max(MULTINOMIAL_OUTER_REML_TOL)
1452    } else {
1453        MULTINOMIAL_OUTER_REML_TOL
1454    };
1455    // #1082 root cause: the outer convergence test derives BOTH the absolute
1456    // projected-gradient floor (`max(outer_tol, n·1e-9)`) AND the relative-cost
1457    // stop (`rel_cost = outer_tol`) from the single `outer_tol`. The accuracy of
1458    // the smooth-by-factor surface is governed by the ABSOLUTE floor reaching the
1459    // n-scaled REML resolution `n·1e-9` (≈ 1.8e-6 at n = 1800) — that is why the
1460    // earlier 1e-5 floor UNDER-fit (its absolute floor was pinned at 1e-5, well
1461    // above the genuine optimum's gradient) and why 1e-7 recovered accuracy (it
1462    // unpins the floor down to the n-scaled 1.8e-6). But tightening `outer_tol`
1463    // to 1e-7 ALSO tightened the rel-cost stop to 1e-7, which on this family's
1464    // dead-flat REML ridge NEVER trips — so the optimizer no longer converges and
1465    // grinds all the way to `outer_max_iter`, each surplus step an O(D·p³) Laplace-
1466    // derivative assembly over the 382-dim joint design (the >600s wall-clock
1467    // overrun; tightening tol REINTRODUCED the crawl the 1e-5 floor had removed).
1468    //
1469    // The two requirements live on two different criteria, so they must be set
1470    // independently. Keep `outer_tol = 1e-7` (drives the accurate absolute floor)
1471    // but FLOOR the relative-cost stop at the framework default 1e-5 (the loose,
1472    // fast value that resolves the cost-decrease plateau without chasing the flat
1473    // tail). The absolute n·1e-9 floor still gates final λ accuracy; the rel-cost
1474    // stop just lets the optimizer DECLARE convergence on the flat ridge instead
1475    // of crawling to the iteration cap.
1476    let outer_rel_cost_tol = Some(BlockwiseFitOptions::default().outer_tol);
1477    let inner_tol = MULTINOMIAL_FORMULA_INNER_TOL.max(tol.max(0.0));
1478
1479    let options = BlockwiseFitOptions {
1480        inner_max_cycles: crate::custom_family::DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES,
1481        inner_tol,
1482        outer_max_iter,
1483        outer_tol,
1484        outer_rel_cost_tol,
1485        rho_lower_bound: multinomial_formula_min_lambda(y_one_hot.view()).ln(),
1486        ridge_floor: MULTINOMIAL_FORMULA_RIDGE_FLOOR,
1487        // #747: the stabilization floor is SOLVER-ONLY — it keeps the inner
1488        // joint-Newton linear solve finite during screening (bounding the step
1489        // `(H+δI)⁻¹∇` away from a near-separable, rank-deficient curvature) but
1490        // is excluded from the REML objective, the penalty log-determinant, and
1491        // the Laplace Hessian. The earlier default (`explicit_stabilization_pospart`)
1492        // folded `½·δ·‖β‖²` and a `δ`-shift of the log-determinant into the
1493        // criterion, shrinking every identified coefficient off the MLE and
1494        // perturbing smoothing-parameter selection — a fixed-λ prior masking
1495        // separation, not a numerical stabilizer. With the floor solver-only the
1496        // optimized objective is the true penalized REML criterion (value tracks
1497        // its analytic gradient), and the smooth directions remain governed
1498        // solely by their own REML-selected `λ`.
1499        ridge_policy: gam_problem::RidgePolicy::solver_only(),
1500        use_outer_hessian,
1501        // #715 real-data arm ("canonical-gauge null direction rejects all REML
1502        // seeds"): skip the multi-seed outer screening cascade and let the
1503        // pinned `init_lambda` ρ flow straight to the outer optimizer.
1504        //
1505        // The multinomial family declares `levenberg_on_ill_conditioning() ->
1506        // true`: near the simplex boundary (the near-separable penguins regime)
1507        // the softmax Fisher weight `W = diag(p) − p pᵀ → 0`, so the joint
1508        // information `H = JᵀWJ + S_λ` can become full-rank but
1509        // ILL-CONDITIONED. The self-vanishing LM damping that keeps the inner
1510        // joint-Newton from oscillating on those near-singular modes converges
1511        // only GEOMETRICALLY. The default screening policy ranks candidate seeds
1512        // with a 2-cycle inner cap (`outer_seed_config`); under geometric
1513        // LM-damped descent two cycles never reach a finite, meaningful proxy
1514        // objective, so EVERY capped seed can collapse to non-finite cost and
1515        // the cascade escalates to ×4, ×16, then an UNCAPPED full inner solve
1516        // PER SEED on the near-singular Hessian. That is the adapter-level face
1517        // of "all REML startup seeds rejected" and the multi-minute timeout.
1518        //
1519        // The pinned seed is already principled here: `init_lambda` gives every
1520        // (class, term) ρ a sensible moderate warm start, and the per-term
1521        // effective-df-floor upper bounds (`effective_df_floor_rho_upper_bounds`,
1522        // #715 arm (a)) keep any λ from collapsing the smooth onto its polynomial
1523        // null space. So the outer ARC/BFGS optimizer performs the real REML ρ
1524        // search from this seed; screening only adds the cascade cost and, on the
1525        // near-separable arm, the rejection stall.
1526        screen_initial_rho: false,
1527        // #1101: compute the joint Laplace posterior covariance `H⁻¹` (and the
1528        // influence matrix `F = H⁻¹ X'WX`) at the converged mode so the saved
1529        // model can surface delta-method per-class probability standard errors
1530        // and Wald smooth-term p-values. The driver factorizes the penalized
1531        // Hessian during the inner solve regardless; this only asks it to keep
1532        // and invert the factor instead of discarding it.
1533        compute_covariance: true,
1534        ..BlockwiseFitOptions::default()
1535    };
1536    // ── Conditional Firth/Jeffreys engagement (#715 arm (b) / #753) ──────────
1537    // Attempt 1: the unbiased criterion (Jeffreys disarmed above). If the
1538    // returned mode is converged, finite, and interior, it is the exact penalized-REML
1539    // optimum with zero Firth bias — accept it (this is the synthetic-arm /
1540    // interior-data path, #715 arm (a)). If the solve FAILS (e.g. the
1541    // (quasi-)separated penguins geometry where `(H + S_λ)v ≈ 0` along
1542    // penalty-null directions for EVERY ρ rejects every REML startup seed) or
1543    // returns a non-finite artifact, that is direct separation evidence:
1544    // re-solve once with the full-span Jeffreys/Firth proper prior armed, which
1545    // supplies the O(1) curvature on the quotient-null subspace that smoothing
1546    // parameters mathematically cannot (`Sv = 0` ⇒ λ never touches `v`). The
1547    // Firth refit is the accepted result only when the unbiased formula solve
1548    // failed, did not converge on its full budget, or blew up; finite
1549    // formula-path logits can be large on valid near-separated optima and
1550    // should not be shrunk toward the uniform simplex once the unbiased outer
1551    // solve has actually certified.
1552    let mut unbiased_probe_options = options.clone();
1553    unbiased_probe_options.outer_max_iter = unbiased_probe_options
1554        .outer_max_iter
1555        .min(MULTINOMIAL_UNBIASED_PROBE_OUTER_MAX_ITER);
1556    // The FINAL accepted Firth/Jeffreys refit runs to the caller's full outer
1557    // budget: it is the result we ship, so it must reach the genuine REML
1558    // optimum, not a truncated iterate. The near-separable penguin refit that
1559    // motivated #1082's wall-clock concern is now halted honestly at its true
1560    // bound optimum by the KKT-stationary-at-bound guard
1561    // (`CostStallGuard`, #1082 / 64711ed82) and the Newton-decrement residual
1562    // certificate (363af9b56 / 2c9580b1f): on separable data the outer ARC
1563    // certifies and stops early on its own, so no artificial iteration cap is
1564    // needed to land in budget. On non-separable data (e.g. the
1565    // `vgam_smooth_by_factor` double-penalty arm) the refit needs the caller's
1566    // full budget to converge, which a `.min(20)` cap would cut off — accepting
1567    // a non-converged fit, which is dishonest. So the refit keeps `options`
1568    // unchanged. Only the discarded unbiased separation probe above is capped.
1569    let firth_refit_options = &options;
1570
1571    let run_firth_refit = |evidence: String| {
1572        let firth_family = family.clone().with_joint_jeffreys_term(true);
1573        fit_custom_family_with_rho_prior(
1574            &firth_family,
1575            &blocks,
1576            firth_refit_options,
1577            gam_problem::RhoPrior::Flat,
1578        )
1579        .map_err(|err| {
1580            EstimationError::InvalidInput(format!(
1581                "multinomial REML: Firth/Jeffreys-armed refit (separation evidence: \
1582                 {evidence}) failed: {err}"
1583            ))
1584        })
1585    };
1586
1587    // #1082: the capped unbiased probe and the (separable-path) Firth decision
1588    // are driven by separation scans over the full P×M logit block. The previous
1589    // match recomputed `multinomial_formula_separation_evidence` /
1590    // `..._unresolved_probe_separation_evidence` in BOTH the match guard AND the
1591    // arm body — three to four full logit walks per fit, paid on the hot
1592    // near-separable penguin path where this branch fires every iterate. Run the
1593    // probe once, evaluate each scan once into a binding, and branch on the
1594    // precomputed results. Behaviour is identical (same scans, same order of
1595    // precedence: converged-interior, unresolved-probe-separation,
1596    // no-separation-needs-full-solve, otherwise-Firth); only the duplicate
1597    // O(n·classes) scans are removed.
1598    let probe_attempt = fit_custom_family_with_rho_prior(
1599        &family,
1600        &blocks,
1601        &unbiased_probe_options,
1602        gam_problem::RhoPrior::Flat,
1603    );
1604    let fit = match probe_attempt {
1605        Ok(probe_fit) => {
1606            let separation = multinomial_formula_separation_evidence(&probe_fit.block_states);
1607            if probe_fit.outer_converged && separation.is_none() {
1608                // Interior, converged, no separation: accept the probe directly.
1609                probe_fit
1610            } else if let Some(evidence) =
1611                multinomial_formula_unresolved_probe_separation_evidence(&probe_fit.block_states)
1612            {
1613                // Non-converged probe already carrying separation-scale logits:
1614                // hand straight to the proper-prior Firth refit (do not spend the
1615                // full unbiased budget grinding the λ→0 separable ridge).
1616                run_firth_refit(format!(
1617                    "unbiased-criterion REML probe did not converge after {} outer iterations; {evidence}",
1618                    probe_fit.outer_iterations
1619                ))?
1620            } else if separation.is_none() {
1621                // Interior but the capped probe ran out of iterations without
1622                // certifying: re-solve at the caller's full outer budget.
1623                //
1624                // #1082 wall-clock: the capped probe is a strict prefix of this
1625                // solve from the same family/seed, so a COLD restart repeats the
1626                // probe's outer iterations. WARM-START the re-solve from the ρ the
1627                // probe already reached — seed each block's `initial_log_lambdas`
1628                // from the probe's selected `log_lambdas` (same block/penalty
1629                // order: the flat vector concatenates per-block penalties in block
1630                // order, exactly the order `build_block_specs()` emits them). This
1631                // changes only the optimizer's STARTING point, never the objective
1632                // or its optimum, but lets the full solve resume near the probe's
1633                // last iterate instead of crawling up from `init_lambda` again —
1634                // removing the probe-iterations double-pay on the non-separable
1635                // (e.g. `vgam_smooth_by_factor`) arm. If the probe's λ vector does
1636                // not line up with the block layout (it always should), fall back
1637                // to the cold `blocks` seed.
1638                let warm_blocks = warm_start_blocks_from_log_lambdas(
1639                    &blocks,
1640                    probe_fit.log_lambdas.as_slice().unwrap_or(&[]),
1641                );
1642                let resolve_blocks = warm_blocks.as_deref().unwrap_or(&blocks);
1643                match fit_custom_family_with_rho_prior(
1644                    &family,
1645                    resolve_blocks,
1646                    &options,
1647                    gam_problem::RhoPrior::Flat,
1648                ) {
1649                    Ok(full_unbiased_fit) => {
1650                        let full_separation = multinomial_formula_separation_evidence(
1651                            &full_unbiased_fit.block_states,
1652                        );
1653                        if full_unbiased_fit.outer_converged && full_separation.is_none() {
1654                            full_unbiased_fit
1655                        } else {
1656                            let evidence = full_separation.unwrap_or_else(|| {
1657                                format!(
1658                                    "full unbiased-criterion REML solve did not converge after {} outer iterations",
1659                                    full_unbiased_fit.outer_iterations
1660                                )
1661                            });
1662                            run_firth_refit(evidence)?
1663                        }
1664                    }
1665                    Err(err) => run_firth_refit(format!(
1666                        "full unbiased-criterion REML solve failed: {err}"
1667                    ))?,
1668                }
1669            } else {
1670                // Probe converged (or capped) but shows interior separation
1671                // evidence: Firth refit using the already-computed scan.
1672                let evidence = separation.unwrap_or_else(|| {
1673                    format!(
1674                        "unbiased-criterion REML probe did not converge after {} outer iterations",
1675                        probe_fit.outer_iterations
1676                    )
1677                });
1678                run_firth_refit(evidence)?
1679            }
1680        }
1681        Err(err) => run_firth_refit(format!("unbiased-criterion REML solve failed: {err}"))?,
1682    };
1683    if let Some(err) = multinomial_formula_separation_diagnostic(
1684        fit.inner_cycles,
1685        fit.outer_iterations,
1686        &fit.block_states,
1687    ) {
1688        return Err(err);
1689    }
1690
1691    // ── Repack coefficients (P, K-1) from per-block β vectors ─────────────
1692    if fit.blocks.len() != m {
1693        crate::bail_invalid_estim!(
1694            "multinomial REML: expected {m} fitted blocks (K-1), got {}",
1695            fit.blocks.len()
1696        );
1697    }
1698    let p_per_class = fit.blocks[0].beta.len();
1699    let mut coefficients_active = Array2::<f64>::zeros((p_per_class, m));
1700    for (a, block) in fit.blocks.iter().enumerate() {
1701        if block.beta.len() != p_per_class {
1702            crate::bail_invalid_estim!(
1703                "multinomial REML: block {a} has {} coefs, expected {p_per_class}",
1704                block.beta.len()
1705            );
1706        }
1707        for i in 0..p_per_class {
1708            coefficients_active[[i, a]] = block.beta[i];
1709        }
1710    }
1711    // Map the standardized-column coefficients back to raw units (the exact
1712    // inverse of the conditioning reparameterization above): β_raw = b/s, with
1713    // the centering mass `Σ_j b_j·m_j/s_j` returned to the intercept.
1714    if !parametric_standardization.is_empty() {
1715        let intercept_col = design.intercept_range.clone().next();
1716        for a in 0..m {
1717            let mut intercept_adjust = 0.0;
1718            for &(col, center, scale) in &parametric_standardization {
1719                if col < p_per_class {
1720                    let raw = coefficients_active[[col, a]] / scale;
1721                    coefficients_active[[col, a]] = raw;
1722                    intercept_adjust += raw * center;
1723                }
1724            }
1725            if let Some(i0) = intercept_col
1726                && i0 < p_per_class
1727            {
1728                coefficients_active[[i0, a]] -= intercept_adjust;
1729            }
1730        }
1731    }
1732    // Flatten every (class, term) smoothing parameter in block-major order
1733    // (class 0's terms, then class 1's, …). With per-term penalties each block
1734    // now carries one λ per smooth term, so a single λ per class would discard
1735    // the independent per-term selection that fixes #561. `lambdas_per_block`
1736    // segments the flat vector by class so callers can recover per-term λ.
1737    let lambdas_per_block: Vec<usize> = fit.blocks.iter().map(|b| b.lambdas.len()).collect();
1738    let lambdas_flat: Vec<f64> = fit
1739        .blocks
1740        .iter()
1741        .flat_map(|b| b.lambdas.iter().copied())
1742        .collect();
1743    // Per-active-class effective degrees of freedom, length `K-1`, summing to
1744    // the model `edf_total`. The REML inference block reports `edf_by_block` as
1745    // ONE entry per *penalty block* (per (class, term, penalty)), each computed
1746    // as `rank(S_kk) − tr(H⁻¹ λ_kk S_kk)`. That per-block sum OVER-COUNTS the
1747    // model EDF whenever several penalties share one coefficient range — a
1748    // double-penalty / te / ti / adaptive smooth has ≥2 penalty blocks over the
1749    // same columns, so `Σ_kk rank(S_kk) > p` and `Σ_kk edf_by_block > edf_total`
1750    // (the observed ~79 for a ~24-coefficient model). Handing that raw per-block
1751    // vector out as the documented length-(K-1) per-class EDF is therefore both
1752    // the wrong LENGTH (it is `Σ_a n_blocks_a`, not `K-1`) and an over-count.
1753    //
1754    // The honest per-class EDF is the influence-matrix trace over each class's
1755    // coefficient block. Classes occupy DISJOINT `p_per_class`-wide coefficient
1756    // ranges, and the per-block traces `tr_kk = tr(H⁻¹ λ_kk S_kk)` are additive
1757    // (no rank double-counting), so class `a`'s EDF is
1758    // `p_per_class − Σ_{kk ∈ class a} tr_kk`, and `Σ_a edf_a = m·p_per_class −
1759    // Σ_kk tr_kk = p − Σ tr_kk = edf_total` exactly. Segment the block-major
1760    // `penalty_block_trace` by `lambdas_per_block` (the same per-class λ-count
1761    // segmentation `lambdas_flat` uses). Fall back to `None` when the trace
1762    // channel is unavailable or mis-shaped (legacy fixed-λ path), exactly as the
1763    // raw `edf_by_block` map did before.
1764    let edf_per_class = fit.inference.as_ref().and_then(|info| {
1765        let traces = &info.penalty_block_trace;
1766        if traces.len() != lambdas_per_block.iter().sum::<usize>() {
1767            // Trace channel absent or not aligned with the per-class block
1768            // segmentation — cannot assemble an honest per-class EDF.
1769            return None;
1770        }
1771        let mut per_class = Vec::with_capacity(m);
1772        let mut cursor = 0usize;
1773        for &n_blocks in &lambdas_per_block {
1774            let class_trace: f64 = traces[cursor..cursor + n_blocks].iter().sum();
1775            // `tr(F)` over a class block ∈ [0, p_per_class]; clamp away
1776            // round-off so a reported EDF can never be negative or exceed the
1777            // class's own coefficient count.
1778            per_class.push((p_per_class as f64 - class_trace).clamp(0.0, p_per_class as f64));
1779            cursor += n_blocks;
1780        }
1781        Some(per_class)
1782    });
1783    // Per-PENALTY EDF: the inference layer's `edf_by_block` is already the
1784    // clamped per-penalty-block trace EDF `rank(S_k) − λ_k·tr(H⁻¹ S_k)`, one
1785    // entry per smoothing parameter and block-major aligned 1:1 with the flat
1786    // `lambdas`. Surface it verbatim (guarding only on the length contract) so
1787    // consumers can inspect per-(class, term, penalty) collapse onto the null
1788    // space — a signal the per-class EDF SUM hides. This is NOT a per-class
1789    // total: with double-penalty smooths `Σ_k rank(S_k) > p_per_class`, so the
1790    // entries deliberately need not sum to the model EDF (the per-class field
1791    // carries that contract instead).
1792    let edf_per_penalty = fit.inference.as_ref().and_then(|info| {
1793        if info.edf_by_block.len() != lambdas_flat.len() {
1794            return None;
1795        }
1796        Some(
1797            info.edf_by_block
1798                .iter()
1799                .map(|&e| e.max(0.0))
1800                .collect::<Vec<f64>>(),
1801        )
1802    });
1803    let coefficients_flat: Vec<f64> = coefficients_active.iter().copied().collect();
1804
1805    // #1101: surface the joint Laplace posterior covariance `H⁻¹` (block-ordered
1806    // [β_0; …; β_{K-2}]) and the influence matrix `F = H⁻¹ X'WX` the REML driver
1807    // computed at the converged mode. These power the predict path's delta-method
1808    // per-class probability standard errors and the summary's Wald smooth-term
1809    // tests. The joint matrices are `(P·M)×(P·M)`. The covariance is mapped back
1810    // to RAW units (see below) so it pairs with the raw predict design; the
1811    // influence is kept in the fitted basis (the Wald table only slices penalized
1812    // columns, which the standardization affine leaves identity-mapped).
1813    let expected_joint = p_per_class.saturating_mul(m);
1814    // The joint Hessian (and thus `H⁻¹`) was assembled in the STANDARDIZED
1815    // parametric basis used during fitting, while the saved coefficients and the
1816    // raw predict design are in raw units. Map the covariance to raw units with
1817    // the same exact affine reparameterization `β_raw = A β_std`: for each
1818    // standardized parametric column `col`, `β_raw[col] = β_std[col]/scale` and
1819    // the intercept absorbs `−Σ_col (center/scale)·β_std[col]`. So `A = I` except
1820    // `A[col,col] = 1/scale` and `A[i0,col] = −center/scale`, replicated
1821    // block-diagonally per active class, and `Cov_raw = A Cov_std Aᵀ`. With no
1822    // standardization (`parametric_standardization` empty) `A = I` and this is a
1823    // no-op. The smooth-term (penalized) columns are untouched by `A`, so the
1824    // Wald table's per-term blocks are identical in both bases.
1825    let intercept_col0 = design.intercept_range.clone().next();
1826    let build_per_class_affine = |amat: &mut Array2<f64>| {
1827        for &(col, center, scale) in &parametric_standardization {
1828            if col >= p_per_class {
1829                continue;
1830            }
1831            amat[[col, col]] = 1.0 / scale;
1832            if let Some(i0) = intercept_col0
1833                && i0 < p_per_class
1834            {
1835                amat[[i0, col]] = -center / scale;
1836            }
1837        }
1838    };
1839    let coefficient_covariance_flat = fit
1840        .covariance_conditional
1841        .as_ref()
1842        .filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)
1843        .map(|cov_std| {
1844            if parametric_standardization.is_empty() {
1845                return cov_std.iter().copied().collect::<Vec<f64>>();
1846            }
1847            // Block-diagonal joint A (same per active class).
1848            let mut a_joint = Array2::<f64>::eye(expected_joint);
1849            let mut a_class = Array2::<f64>::eye(p_per_class);
1850            build_per_class_affine(&mut a_class);
1851            for a in 0..m {
1852                let base = a * p_per_class;
1853                for i in 0..p_per_class {
1854                    for j in 0..p_per_class {
1855                        a_joint[[base + i, base + j]] = a_class[[i, j]];
1856                    }
1857                }
1858            }
1859            let cov_raw = a_joint.dot(cov_std).dot(&a_joint.t());
1860            cov_raw.iter().copied().collect::<Vec<f64>>()
1861        });
1862    // The influence matrix `F = H⁻¹ X'WX = H⁻¹(H − S_λ) = I − H⁻¹ S_λ`. The
1863    // exact-Newton multinomial blocks carry no IRLS pseudo-data, so the generic
1864    // inference path does not export `coefficient_influence`; reconstruct it
1865    // exactly here from the joint covariance `H⁻¹` (above) and the REML-selected
1866    // per-(class, term) `λ` scaling the shared penalties. Block-diagonal `S_λ`:
1867    // class `a`'s block is `Σ_t λ_{a,t} · S_t`, embedded at `a·P .. (a+1)·P`.
1868    let coefficient_influence_flat = fit
1869        .covariance_conditional
1870        .as_ref()
1871        .filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)
1872        .and_then(|hinv| {
1873            if fit.blocks.len() != m {
1874                return None;
1875            }
1876            // Joint S_λ (block-diagonal across active classes).
1877            let mut s_lambda = Array2::<f64>::zeros((expected_joint, expected_joint));
1878            for (a, block) in fit.blocks.iter().enumerate() {
1879                if block.lambdas.len() != penalties_arc.len() {
1880                    return None;
1881                }
1882                let base = a * p_per_class;
1883                for (t, pen) in penalties_arc.iter().enumerate() {
1884                    let lam = block.lambdas[t];
1885                    if lam == 0.0 {
1886                        continue;
1887                    }
1888                    let dense = pen.to_dense();
1889                    if dense.nrows() != p_per_class || dense.ncols() != p_per_class {
1890                        return None;
1891                    }
1892                    for i in 0..p_per_class {
1893                        for j in 0..p_per_class {
1894                            s_lambda[[base + i, base + j]] += lam * dense[[i, j]];
1895                        }
1896                    }
1897                }
1898            }
1899            // F = I − H⁻¹ S_λ.
1900            let hinv_s = hinv.dot(&s_lambda);
1901            let mut f = Array2::<f64>::eye(expected_joint);
1902            f -= &hinv_s;
1903            Some(f.iter().copied().collect::<Vec<f64>>())
1904        });
1905
1906    // Per-(smooth term) coefficient span within a single class block, deduped by
1907    // col_range (the #561 double-penalty migration emits two penalty blocks per
1908    // term sharing one col_range; the Wald test covers the whole term block once).
1909    let mut smooth_term_spans: Vec<MultinomialSmoothTermSpan> = Vec::new();
1910    for (pen_idx, bp) in design.penalties.iter().enumerate() {
1911        let col_start = bp.col_range.start;
1912        let col_end = bp.col_range.end;
1913        if col_start >= col_end || col_end > p_per_class {
1914            continue;
1915        }
1916        if smooth_term_spans
1917            .iter()
1918            .any(|s| s.col_start == col_start && s.col_end == col_end)
1919        {
1920            continue;
1921        }
1922        let label = design
1923            .penaltyinfo
1924            .get(pen_idx)
1925            .and_then(|info| info.termname.clone())
1926            .unwrap_or_else(|| format!("s{pen_idx}"));
1927        let nullspace_dim = design
1928            .nullspace_dims
1929            .get(pen_idx)
1930            .copied()
1931            .unwrap_or(0)
1932            .min(col_end - col_start);
1933        smooth_term_spans.push(MultinomialSmoothTermSpan {
1934            label,
1935            col_start,
1936            col_end,
1937            nullspace_dim,
1938        });
1939    }
1940
1941    // One descriptive label per penalty *component* within a single class block,
1942    // parallel to that block's λ slice (#1544). `design.penalties` is index-
1943    // parallel to every active class's `block.lambdas` (each block carries the
1944    // full per-component penalty list, validated above by
1945    // `block.lambdas.len() == penalties_arc.len()`), so iterating it in order
1946    // yields exactly `lambdas_per_block[0]` labels aligned with the per-block λ.
1947    // This is deliberately NOT deduped by col_range (unlike `smooth_term_spans`):
1948    // the double penalty's primary and null-space components share one col_range
1949    // but select independent λ, and each must keep its own label so the summary
1950    // renderer never collapses or drops a λ.
1951    let lambda_labels: Vec<String> = design
1952        .penalties
1953        .iter()
1954        .enumerate()
1955        .map(|(pen_idx, _)| penalty_component_label(design.penaltyinfo.get(pen_idx), pen_idx))
1956        .collect();
1957
1958    // Unpenalized deviance read directly from the converged unpenalized
1959    // log-likelihood the rho-prior driver already computed (issue #348):
1960    // MultinomialFamily::evaluate sets FamilyEvaluation.log_likelihood =
1961    // log_lik(η, y) with no penalty term, and that value flows unchanged into
1962    // UnifiedFitResult.log_likelihood. This reproduces the legacy fixed-λ
1963    // path's `deviance = -2 · log_lik` contract bit-for-bit, so the previous
1964    // row-by-row η = Xβ rebuild and softmax recompute were pure dead work.
1965    let deviance = -2.0 * fit.log_likelihood;
1966
1967    Ok(MultinomialSavedModel {
1968        formula: formula.to_string(),
1969        class_levels: class_levels.clone(),
1970        reference_class_index: class_levels.len() - 1,
1971        resolved_termspec: spec,
1972        coefficients_flat,
1973        p_per_class,
1974        n_active_classes: m,
1975        training_headers: data.headers.clone(),
1976        lambdas: lambdas_flat,
1977        lambdas_per_block,
1978        iterations: fit.inner_cycles,
1979        converged: fit.outer_converged,
1980        penalized_neg_log_likelihood: -fit.log_likelihood + 0.5 * fit.stable_penalty_term,
1981        deviance,
1982        edf_per_class,
1983        edf_per_penalty,
1984        coefficient_covariance_flat,
1985        coefficient_influence_flat,
1986        smooth_term_spans,
1987        lambda_labels,
1988    })
1989}
1990
1991/// Replay the saved termspec to build the predict-time design on a fresh
1992/// dataset, then evaluate softmax probabilities. The predict dataset must carry
1993/// the same feature columns the training data did, matched **by name** — it need
1994/// not reproduce the training column order, and in particular need not carry the
1995/// response column (prediction is for label-free new data).
1996pub fn predict_multinomial_formula(
1997    model: &MultinomialSavedModel,
1998    data: &EncodedDataset,
1999) -> Result<Array2<f64>, EstimationError> {
2000    // The saved termspec stores feature columns as absolute indices into the
2001    // *training* table `[response, features...]`. Replaying it verbatim only
2002    // works if the predict frame reproduces that exact layout — i.e. carries the
2003    // (unknown, at predict time) response column in the same position. Realign
2004    // the indices onto this dataset's columns by name instead, so prediction
2005    // works on label-free new data exactly as every other family's predict path
2006    // does. The response column is simply never referenced by any term, so its
2007    // absence is a non-issue once resolution is by name (issue #803).
2008    let predict_columns = data.column_map();
2009    let realigned = model.resolved_termspec.remap_feature_columns(
2010        |index| -> Result<usize, EstimationError> {
2011            let name = model.training_headers.get(index).ok_or_else(|| {
2012                EstimationError::InvalidInput(format!(
2013                    "multinomial predict: saved training column index {index} is out of bounds \
2014                     for {} training headers",
2015                    model.training_headers.len()
2016                ))
2017            })?;
2018            resolve_role_col(&predict_columns, name, "feature")
2019                .map_err(|err| EstimationError::InvalidInput(err.to_string()))
2020        },
2021    )?;
2022    let design = build_term_collection_design(data.values.view(), &realigned).map_err(|err| {
2023        EstimationError::InvalidInput(format!(
2024            "multinomial predict: rebuild design from saved termspec: {err}"
2025        ))
2026    })?;
2027    let x_dense = design
2028        .design
2029        .try_to_dense_by_chunks("multinomial predict design")
2030        .map_err(EstimationError::InvalidInput)?;
2031    if x_dense.ncols() != model.p_per_class {
2032        crate::bail_invalid_estim!(
2033            "multinomial predict: predict design has {} cols, saved model expects {}",
2034            x_dense.ncols(),
2035            model.p_per_class
2036        );
2037    }
2038    Ok(model.predict_probabilities(x_dense.view()))
2039}
2040
2041/// Predict class probabilities AND delta-method per-class probability standard
2042/// errors for a saved multinomial model on fresh data (#1101). Replays the
2043/// saved termspec to build the predict design exactly as
2044/// [`predict_multinomial_formula`], then applies the softmax-Jacobian delta
2045/// method against the stored joint posterior covariance. Returns
2046/// `(probs (N,K), prob_se (N,K) | None)`; `prob_se` is `None` for a legacy
2047/// model fitted before covariance was surfaced.
2048pub fn predict_multinomial_formula_with_se(
2049    model: &MultinomialSavedModel,
2050    data: &EncodedDataset,
2051) -> Result<(Array2<f64>, Option<Array2<f64>>), EstimationError> {
2052    let predict_columns = data.column_map();
2053    let realigned = model.resolved_termspec.remap_feature_columns(
2054        |index| -> Result<usize, EstimationError> {
2055            let name = model.training_headers.get(index).ok_or_else(|| {
2056                EstimationError::InvalidInput(format!(
2057                    "multinomial predict: saved training column index {index} is out of bounds \
2058                     for {} training headers",
2059                    model.training_headers.len()
2060                ))
2061            })?;
2062            resolve_role_col(&predict_columns, name, "feature")
2063                .map_err(|err| EstimationError::InvalidInput(err.to_string()))
2064        },
2065    )?;
2066    let design = build_term_collection_design(data.values.view(), &realigned).map_err(|err| {
2067        EstimationError::InvalidInput(format!(
2068            "multinomial predict: rebuild design from saved termspec: {err}"
2069        ))
2070    })?;
2071    let x_dense = design
2072        .design
2073        .try_to_dense_by_chunks("multinomial predict design")
2074        .map_err(EstimationError::InvalidInput)?;
2075    if x_dense.ncols() != model.p_per_class {
2076        crate::bail_invalid_estim!(
2077            "multinomial predict: predict design has {} cols, saved model expects {}",
2078            x_dense.ncols(),
2079            model.p_per_class
2080        );
2081    }
2082    Ok(model.predict_probabilities_with_se(x_dense.view()))
2083}
2084
2085#[cfg(test)]
2086mod fisher_override_tests {
2087    use super::*;
2088    use ndarray::Array3;
2089
2090    fn toy() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array1<f64>) {
2091        let n = 15;
2092        let p = 2;
2093        let k = 3;
2094        let design =
2095            Array2::<f64>::from_shape_fn(
2096                (n, p),
2097                |(i, j)| {
2098                    if j == 0 { 1.0 } else { ((i + 2) as f64).cos() }
2099                },
2100            );
2101        let mut y = Array2::<f64>::zeros((n, k));
2102        for i in 0..n {
2103            y[[i, i % k]] = 1.0;
2104        }
2105        let penalty = Array2::<f64>::eye(p);
2106        let lambdas = Array1::<f64>::from_elem(k - 1, 0.5);
2107        (design, y, penalty, lambdas)
2108    }
2109
2110    #[test]
2111    fn fisher_override_none_reproduces_analytic() {
2112        // Issue #349: None override is exactly the analytic fit.
2113        let (design, y, penalty, lambdas) = toy();
2114        let mk = |over: Option<ndarray::ArrayView3<'_, f64>>| {
2115            fit_penalized_multinomial(MultinomialFitInputs {
2116                design: design.view(),
2117                y_one_hot: y.view(),
2118                penalty: penalty.view(),
2119                lambdas: lambdas.view(),
2120                row_weights: None,
2121                fisher_w_override: over,
2122                max_iter: 50,
2123                tol: 1.0e-9,
2124            })
2125            .expect("fit must succeed")
2126        };
2127        let a = mk(None);
2128        let b = mk(None);
2129        for (x, z) in a
2130            .coefficients_active
2131            .iter()
2132            .zip(b.coefficients_active.iter())
2133        {
2134            assert_eq!(x, z);
2135        }
2136    }
2137
2138    #[test]
2139    fn fisher_override_wrong_shape_is_rejected() {
2140        let (design, y, penalty, lambdas) = toy();
2141        let n = design.nrows();
2142        let m = y.ncols(); // K, not K-1 — deliberately wrong
2143        let bad = Array3::<f64>::zeros((n, m, m));
2144        let err = fit_penalized_multinomial(MultinomialFitInputs {
2145            design: design.view(),
2146            y_one_hot: y.view(),
2147            penalty: penalty.view(),
2148            lambdas: lambdas.view(),
2149            row_weights: None,
2150            fisher_w_override: Some(bad.view()),
2151            max_iter: 50,
2152            tol: 1.0e-9,
2153        })
2154        .expect_err("wrong active-block shape must error");
2155        assert!(format!("{err}").contains("fisher_w_override shape"));
2156    }
2157
2158    #[test]
2159    fn formula_outer_route_uses_exact_curvature_for_medium_d() {
2160        // The 2-smooth reference formula fit (K = 3, double-penalty terms) is
2161        // D = (K-1) * 2 terms * 2 penalties = 8 and needs exact curvature to
2162        // avoid over-smoothed lambda caps (#715 arm (a)).
2163        assert!(
2164            multinomial_formula_use_outer_hessian(8),
2165            "D=8 loaded multinomial fits need exact curvature to avoid over-smoothed lambda caps"
2166        );
2167        assert!(
2168            multinomial_formula_use_outer_hessian(12),
2169            "D=12 (3 double-penalty smooth terms, K=3) stays on exact curvature"
2170        );
2171    }
2172
2173    #[test]
2174    fn formula_outer_route_uses_exact_curvature_for_d16_penguin_fixture() {
2175        // Four k=10 penguin smooths (K = 3) are D = 16 under double-penalty
2176        // terms. They must reach the exact ARC route so the #1082 cost-stall
2177        // halt is available on the near-separable lambda-to-zero ridge.
2178        assert!(
2179            multinomial_formula_use_outer_hessian(16),
2180            "D=16 multinomial fits need exact ARC curvature for the #1082 stall halt"
2181        );
2182    }
2183
2184    #[test]
2185    fn formula_min_lambda_floor_is_continuous_and_information_scaled() {
2186        // Build a one-hot label matrix whose smallest class carries `count` rows.
2187        fn floor_for_min_count(count: usize) -> f64 {
2188            // Two classes: a large one (1000 rows) and a minority one (`count`).
2189            let n = 1000 + count;
2190            let mut y = Array2::<f64>::zeros((n, 2));
2191            for r in 0..1000 {
2192                y[[r, 0]] = 1.0;
2193            }
2194            for r in 1000..n {
2195                y[[r, 1]] = 1.0;
2196            }
2197            multinomial_formula_min_lambda(y.view())
2198        }
2199
2200        // The floor's endpoints are now DERIVED from a target prior strength in
2201        // pseudo-observations against the maximal per-observation softmax Fisher
2202        // information I₁ = ¼ (base = τ·I₁, sparse = τ_max·I₁). Pin them to the
2203        // previously fixture-calibrated values so the near-separable quality arms
2204        // (penguins, vgam softmax) — whose smallest class has n_c ≥ 50 — are
2205        // byte-for-byte unaffected: the derivation REDUCES TO the old constants
2206        // at the calibration point.
2207        let base = MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
2208        let sparse = MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX
2209            * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
2210        assert!(
2211            (base - 2.0e-4).abs() < 1e-18,
2212            "derived base floor must equal the calibrated 2e-4"
2213        );
2214        assert!(
2215            (sparse - 1.0e-3).abs() < 1e-18,
2216            "derived sparse floor must equal the calibrated 1e-3"
2217        );
2218
2219        // Well-supported (n_c >= n_ref=50) sits exactly at the base floor.
2220        assert!((floor_for_min_count(50) - base).abs() < 1e-18);
2221        assert!((floor_for_min_count(200) - base).abs() < 1e-18);
2222        // Very sparse (n_c <= n_ref·base/sparse = 10) clamps to the strong floor.
2223        assert!((floor_for_min_count(10) - sparse).abs() < 1e-18);
2224        assert!((floor_for_min_count(5) - sparse).abs() < 1e-18);
2225        // No cliff at the old hard threshold: 49 vs 50 differ by < 5% (the old
2226        // step jumped 5x). Floor is monotone non-increasing in support.
2227        let f49 = floor_for_min_count(49);
2228        let f50 = floor_for_min_count(50);
2229        assert!(
2230            f49 >= f50 && f49 <= f50 * 1.05,
2231            "floor must be continuous across c0, got {f49} vs {f50}"
2232        );
2233        let f25 = floor_for_min_count(25);
2234        assert!(
2235            f25 > f50 && f25 < floor_for_min_count(10),
2236            "mid-support floor must interpolate strictly between the two endpoints"
2237        );
2238
2239        // FIRST-PRINCIPLES SCALING: in the interpolating regime the floor equals
2240        // exactly τ·I₁·(n_ref/n_c) — the effective-pseudo-observation prior held
2241        // to a fixed fraction of the per-class data information n_c·I₁. Halving
2242        // the effective sample size doubles the floor (until the cap), and the
2243        // absolute value matches the closed-form n_c-scaled prior.
2244        for &n_c in &[12usize, 16, 20, 30, 40] {
2245            let expected = base * (MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / n_c as f64);
2246            assert!(
2247                (floor_for_min_count(n_c) - expected).abs() < 1e-15,
2248                "floor at n_c={n_c} must be τ·I₁·n_ref/n_c = {expected}, got {}",
2249                floor_for_min_count(n_c)
2250            );
2251        }
2252        // Inverse scaling with effective sample size: n_c -> n_c/2 doubles the
2253        // floor inside the unclamped band (20 and 40 are both interior; 40 < 50
2254        // so it is scaled, 20 > 10 so it is not capped).
2255        assert!(
2256            (floor_for_min_count(20) - 2.0 * floor_for_min_count(40)).abs() < 1e-15,
2257            "floor must scale like 1/n_c (effective Fisher information) in the interior band"
2258        );
2259    }
2260
2261    #[test]
2262    fn formula_penalty_scale_tracks_softmax_fisher_curvature() {
2263        assert!(
2264            (multinomial_formula_penalty_scale(2) - 0.5).abs() < 1.0e-12,
2265            "binary-logit neutral-simplex curvature scale should remain at 1/2"
2266        );
2267        assert!(
2268            (multinomial_formula_penalty_scale(3) - 4.0 / 9.0).abs() < 1.0e-12,
2269            "three-class softmax penalties should be calibrated to 2*(K-1)/K^2"
2270        );
2271        assert!(
2272            multinomial_formula_penalty_scale(5) < multinomial_formula_penalty_scale(3),
2273            "active-class Fisher curvature decreases as the simplex gains classes"
2274        );
2275    }
2276
2277    #[test]
2278    fn fixed_lambda_multinomial_reports_complete_separation() {
2279        let n = 90;
2280        let design = Array2::<f64>::from_shape_fn((n, 2), |(row, col)| match col {
2281            0 => 1.0,
2282            _ => -3.0 + 6.0 * (row as f64) / ((n - 1) as f64),
2283        });
2284        let mut y = Array2::<f64>::zeros((n, 3));
2285        for row in 0..n {
2286            let x = design[[row, 1]];
2287            let class = if x < -1.0 {
2288                0
2289            } else if x > 1.0 {
2290                1
2291            } else {
2292                2
2293            };
2294            y[[row, class]] = 1.0;
2295        }
2296        let penalty = Array2::<f64>::zeros((2, 2));
2297        let lambdas = Array1::<f64>::zeros(2);
2298        let err = fit_penalized_multinomial(MultinomialFitInputs {
2299            design: design.view(),
2300            y_one_hot: y.view(),
2301            penalty: penalty.view(),
2302            lambdas: lambdas.view(),
2303            row_weights: None,
2304            fisher_w_override: None,
2305            max_iter: 80,
2306            tol: 1.0e-12,
2307        })
2308        .expect_err("complete softmax separation must be a hard diagnostic");
2309        assert!(
2310            matches!(err, EstimationError::MultinomialSeparationDetected { .. }),
2311            "expected MultinomialSeparationDetected, got {err:?}"
2312        );
2313        assert!(
2314            err.to_string().contains("separation"),
2315            "diagnostic should mention separation, got {err}"
2316        );
2317        assert!(
2318            err.to_string().contains("active class-"),
2319            "diagnostic should name the separated active class logit, got {err}"
2320        );
2321        assert!(
2322            !err.to_string().contains("binary outcomes"),
2323            "multinomial diagnostic must not reuse the binary separation text, got {err}"
2324        );
2325    }
2326
2327    #[test]
2328    fn formula_multinomial_accepts_finite_saturated_logits() {
2329        // A saturated-but-FINITE logit surface can be a valid formula REML mode
2330        // (the #715 penguins regime: bill/flipper cleanly separate the species,
2331        // so fitted logits can legitimately exceed ±25). `outer_converged ==
2332        // false` then signals only that the driver auto-escalated to never-fail
2333        // posterior sampling about that finite mode (gam#860), NOT a separation
2334        // artifact — the adapter must accept it, never raise
2335        // `MultinomialSeparationDetected`.
2336        let saturated_states = vec![
2337            ParameterBlockState {
2338                beta: Array1::from_vec(vec![1.0, 2.0]),
2339                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
2340            },
2341            ParameterBlockState {
2342                beta: Array1::from_vec(vec![-1.0, 3.0]),
2343                eta: Array1::from_vec(vec![1.0, 25.5, -0.1]),
2344            },
2345        ];
2346        assert!(
2347            multinomial_formula_separation_diagnostic(17, 9, &saturated_states).is_none(),
2348            "a finite (even saturated, |eta|>25) formula optimum is a valid fit, \
2349             not a separation diagnostic"
2350        );
2351
2352        // Only a genuinely NON-FINITE logit — a NaN/Inf blow-up in the inner
2353        // linear algebra with no finite mode to sample about — is a real
2354        // formula-path failure.
2355        let blown_up = vec![
2356            ParameterBlockState {
2357                beta: Array1::from_vec(vec![1.0, 2.0]),
2358                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
2359            },
2360            ParameterBlockState {
2361                beta: Array1::from_vec(vec![-1.0, 3.0]),
2362                eta: Array1::from_vec(vec![1.0, f64::INFINITY, -0.1]),
2363            },
2364        ];
2365        let err = multinomial_formula_separation_diagnostic(17, 9, &blown_up)
2366            .expect("a non-finite formula logit must raise the separation diagnostic");
2367        assert!(
2368            matches!(
2369                err,
2370                EstimationError::MultinomialSeparationDetected {
2371                    iteration: 17,
2372                    max_abs_eta,
2373                    active_class_index: 1,
2374                    row_index: 1,
2375                } if !max_abs_eta.is_finite()
2376            ),
2377            "expected typed multinomial separation diagnostic at the non-finite channel, got {err:?}"
2378        );
2379    }
2380
2381    #[test]
2382    fn separation_evidence_gate_arms_firth_only_on_blowup() {
2383        // Interior fit: finite logits well inside the saturation threshold ⇒ NO
2384        // separation evidence ⇒ the unbiased criterion's mode is accepted as-is
2385        // and the Firth/Jeffreys prior stays disarmed (#715 arm (a): no 1/K
2386        // shrinkage on well-identified data).
2387        let interior = vec![
2388            ParameterBlockState {
2389                beta: Array1::from_vec(vec![1.0, 2.0]),
2390                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
2391            },
2392            ParameterBlockState {
2393                beta: Array1::from_vec(vec![-1.0, 3.0]),
2394                eta: Array1::from_vec(vec![1.0, -3.5, -0.1]),
2395            },
2396        ];
2397        assert!(
2398            multinomial_formula_separation_evidence(&interior).is_none(),
2399            "an interior finite mode must not arm the Firth refit"
2400        );
2401
2402        // Saturated but finite logits are valid formula-path modes on
2403        // near-separated real data. They must not arm the Firth refit because
2404        // the Jeffreys pull can over-regularize the held-out probabilities.
2405        let saturated = vec![
2406            ParameterBlockState {
2407                beta: Array1::from_vec(vec![1.0, 2.0]),
2408                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
2409            },
2410            ParameterBlockState {
2411                beta: Array1::from_vec(vec![-1.0, 3.0]),
2412                eta: Array1::from_vec(vec![1.0, 25.5, -0.1]),
2413            },
2414        ];
2415        assert!(
2416            multinomial_formula_separation_evidence(&saturated).is_none(),
2417            "a finite saturated formula-mode logit must not arm the Firth refit"
2418        );
2419
2420        // Non-finite logit ⇒ inner blow-up along an unbounded direction ⇒
2421        // separation evidence.
2422        let blown_up = vec![ParameterBlockState {
2423            beta: Array1::from_vec(vec![1.0, 2.0]),
2424            eta: Array1::from_vec(vec![0.2, f64::NAN, -7.0]),
2425        }];
2426        let evidence = multinomial_formula_separation_evidence(&blown_up)
2427            .expect("a non-finite logit is separation evidence");
2428        assert!(
2429            evidence.contains("non-finite logit") && evidence.contains("row 1"),
2430            "evidence must name the non-finite logit, got {evidence}"
2431        );
2432
2433        // Large finite logits below the fixed-lambda diagnostic threshold are
2434        // likewise accepted on the formula path.
2435        let near = vec![ParameterBlockState {
2436            beta: Array1::from_vec(vec![1.0, 2.0]),
2437            eta: Array1::from_vec(vec![0.2, 24.9, -24.9]),
2438        }];
2439        assert!(
2440            multinomial_formula_separation_evidence(&near).is_none(),
2441            "logits below the saturation threshold must not arm the Firth refit"
2442        );
2443    }
2444
2445    #[test]
2446    fn unresolved_probe_evidence_arms_firth_on_saturated_finite_logits() {
2447        let saturated = vec![
2448            ParameterBlockState {
2449                beta: Array1::from_vec(vec![1.0, 2.0]),
2450                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
2451            },
2452            ParameterBlockState {
2453                beta: Array1::from_vec(vec![-1.0, 3.0]),
2454                eta: Array1::from_vec(vec![1.0, 25.5, -0.1]),
2455            },
2456        ];
2457
2458        assert!(
2459            multinomial_formula_separation_evidence(&saturated).is_none(),
2460            "a converged finite saturated formula optimum remains unbiased"
2461        );
2462        let evidence = multinomial_formula_unresolved_probe_separation_evidence(&saturated)
2463            .expect("a non-converged saturated probe should arm the Firth refit");
2464        assert!(
2465            evidence.contains("separation-scale finite logit")
2466                && evidence.contains("row 1")
2467                && evidence.contains("active class 1"),
2468            "unresolved-probe evidence should name the saturated channel, got {evidence}"
2469        );
2470
2471        let near = vec![ParameterBlockState {
2472            beta: Array1::from_vec(vec![1.0, 2.0]),
2473            eta: Array1::from_vec(vec![0.2, 24.9, -24.9]),
2474        }];
2475        assert!(
2476            multinomial_formula_unresolved_probe_separation_evidence(&near).is_none(),
2477            "finite logits below the separation threshold still get the full unbiased retry"
2478        );
2479    }
2480
2481    #[test]
2482    fn scaled_fisher_override_changes_first_step() {
2483        // Curvature scaled by 4× shrinks the first Newton step relative to the
2484        // analytic fit, so a single-iteration fit must differ.
2485        let (design, y, penalty, lambdas) = toy();
2486        let n = design.nrows();
2487        let m = y.ncols() - 1;
2488        // Analytic block at β = 0: p_a = 1/K = 1/3, so diag = p_a(1−p_a),
2489        // off-diag = −p_a p_b. Scale that exact block by 4.
2490        let pk = 1.0 / (y.ncols() as f64);
2491        let mut over = Array3::<f64>::zeros((n, m, m));
2492        for row in 0..n {
2493            for a in 0..m {
2494                for b in 0..m {
2495                    let analytic = if a == b { pk * (1.0 - pk) } else { -pk * pk };
2496                    over[[row, a, b]] = 4.0 * analytic;
2497                }
2498            }
2499        }
2500        let scaled = fit_penalized_multinomial(MultinomialFitInputs {
2501            design: design.view(),
2502            y_one_hot: y.view(),
2503            penalty: penalty.view(),
2504            lambdas: lambdas.view(),
2505            row_weights: None,
2506            fisher_w_override: Some(over.view()),
2507            max_iter: 1,
2508            tol: 1.0e-9,
2509        })
2510        .expect("override fit must succeed");
2511        let analytic = fit_penalized_multinomial(MultinomialFitInputs {
2512            design: design.view(),
2513            y_one_hot: y.view(),
2514            penalty: penalty.view(),
2515            lambdas: lambdas.view(),
2516            row_weights: None,
2517            fisher_w_override: None,
2518            max_iter: 1,
2519            tol: 1.0e-9,
2520        })
2521        .expect("analytic fit must succeed");
2522        let differs = scaled
2523            .coefficients_active
2524            .iter()
2525            .zip(analytic.coefficients_active.iter())
2526            .any(|(a, b)| (a - b).abs() > 1.0e-6);
2527        assert!(differs, "scaled curvature must change the first step");
2528    }
2529}
2530
#[cfg(test)]
mod reference_class_invariance_tests {
    //! Regression for #1587: a penalized multinomial-logit GAM fit must be
    //! invariant to which class is the (arbitrary) softmax reference/baseline.
    //!
    //! The production REML path (`fit_penalized_multinomial_formula`) reference-
    //! codes the `K` classes (the last sorted label is the baseline). With the
    //! legacy `Diagonal` penalty metric it penalizes only the `K−1`
    //! reference-anchored ALR contrasts `½ Σ_a λ_a β_aᵀ S β_a`. Relabeling the
    //! response so a *different* class sorts last then penalizes a different
    //! frame of log-odds contrasts. Under that metric the predicted
    //! probabilities drift (~1e-2 absolute), even though they are
    //! mathematically independent of the reference choice.
    //!
    //! This test fits the SAME 3-class softmax sample under three cyclic
    //! relabelings, each making a different original class the baseline. It
    //! realigns the predicted probability columns back to the original class
    //! identities and asserts the cross-labeling drift is below `1e-3`. For
    //! scale: the defect is ~1e-2, and refitting the same labeling twice agrees
    //! to ~1e-12 (the test only requires agreement within `1e-6`).
    //!
    //! Every comparison propagates NaN. A non-finite prediction therefore fails
    //! the guard instead of reading as zero drift.
    //!
    //! It is the Rust-level sibling of
    //! `tests/bug_hunt_multinomial_fit_depends_on_reference_class_test.py`.

    use super::*;
    use gam_data::load_dataset_projected;
    use std::fmt::Write as _;
    use std::fs;
    use tempfile::tempdir;

    /// Deterministic `splitmix64` → `[0,1)` uniform stream (no external RNG dep;
    /// the only requirement is a well-distributed, reproducible draw).
    struct SplitMix64(u64);
    impl SplitMix64 {
        fn next_u64(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = self.0;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }
        fn unit(&mut self) -> f64 {
            // 53-bit mantissa uniform in [0, 1).
            (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
        }
    }

    /// Draw a clean 3-class softmax regression sample (the issue's generator).
    /// Returns `(x, class)` with integer classes `0/1/2`.
    fn sample_classes(seed: u64, n: usize) -> (Vec<f64>, Vec<usize>) {
        let mut rng = SplitMix64(seed.wrapping_add(0x1234_5678));
        let mut x = Vec::with_capacity(n);
        let mut cls = Vec::with_capacity(n);
        for _ in 0..n {
            let xi = -2.0 + 4.0 * rng.unit();
            let eta = [0.5 + 0.8 * xi, -0.3 - 0.5 * xi, 0.0];
            let mut p = [eta[0].exp(), eta[1].exp(), eta[2].exp()];
            let s: f64 = p.iter().sum();
            for v in &mut p {
                *v /= s;
            }
            // Inverse-CDF draw into one of the 3 classes.
            let u = rng.unit();
            let c = if u < p[0] {
                0
            } else if u < p[0] + p[1] {
                1
            } else {
                2
            };
            x.push(xi);
            cls.push(c);
        }
        (x, cls)
    }

    /// Build an `EncodedDataset` with columns `x` (numeric) and `y`
    /// (categorical, from the given string labels) by round-tripping a CSV
    /// written to `dir/data_{tag}.csv`.
    fn dataset_xy(
        dir: &std::path::Path,
        tag: &str,
        x: &[f64],
        y: &[String],
    ) -> gam_data::EncodedDataset {
        let path = dir.join(format!("data_{tag}.csv"));
        let mut csv = String::from("x,y\n");
        for (xi, yi) in x.iter().zip(y.iter()) {
            writeln!(csv, "{xi},{yi}").expect("writing to a String cannot fail");
        }
        fs::write(&path, csv).expect("write training csv");
        load_dataset_projected(&path, &["x".to_string(), "y".to_string()])
            .expect("load training dataset")
    }

    /// Fit `y ~ s(x)` under the relabeling `name_map` (original class `c` gets
    /// label `name_map[c]`), predict on `grid`, and return the predicted
    /// probabilities **realigned to the original class order** 0/1/2, shape
    /// `(grid.len(), 3)`.
    ///
    /// # Panics
    /// Panics if the fit or prediction fails, if the model's class levels are
    /// not the sorted labels, or if any realigned probability is non-finite.
    fn fit_predict_aligned(
        dir: &std::path::Path,
        tag: &str,
        x: &[f64],
        cls: &[usize],
        name_map: [&str; 3],
        grid: &[f64],
    ) -> Array2<f64> {
        let labels: Vec<String> = cls.iter().map(|&c| name_map[c].to_string()).collect();
        let train = dataset_xy(dir, tag, x, &labels);
        let config = FitConfig::default();
        let model = fit_penalized_multinomial_formula(&train, "y ~ s(x)", &config, 1.0, 60, 1e-6)
            .expect("multinomial formula fit must succeed");

        // Predict on the grid. The categorical `y` column is not needed for
        // prediction, but the schema is simplest if we supply a dummy.
        let grid_y: Vec<String> = grid.iter().map(|_| name_map[0].to_string()).collect();
        let grid_ds = dataset_xy(dir, &format!("{tag}_grid"), grid, &grid_y);
        let probs = predict_multinomial_formula(&model, &grid_ds)
            .expect("multinomial predict must succeed");

        // `model.class_levels` is the sorted label order; the column for original
        // class `c` is at the rank of `name_map[c]` among the sorted labels.
        let mut sorted: Vec<&str> = name_map.to_vec();
        sorted.sort_unstable();
        let col_of_orig: Vec<usize> = (0..3)
            .map(|c| {
                sorted
                    .iter()
                    .position(|l| *l == name_map[c])
                    .expect("sorted is a permutation of name_map, so every label is present")
            })
            .collect();
        // Sanity: the model's class_levels must match the sorted labels.
        assert_eq!(
            model.class_levels,
            sorted.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
            "class_levels must be the sorted label order"
        );
        let n = grid.len();
        let mut aligned = Array2::<f64>::zeros((n, 3));
        for r in 0..n {
            for c in 0..3 {
                aligned[[r, c]] = probs[[r, col_of_orig[c]]];
            }
        }
        // Fail loudly on a numerical blow-up rather than letting a NaN/Inf
        // flow into the drift comparisons below.
        assert!(
            aligned.iter().all(|p| p.is_finite()),
            "predicted probabilities must be finite for labeling {tag}"
        );
        aligned
    }

    /// NaN-propagating maximum. Unlike `f64::max`, which returns the non-NaN
    /// operand, a NaN in either argument yields NaN, so it cannot vanish from
    /// a running max and make a `< tol` guard pass spuriously.
    fn nan_max(acc: f64, d: f64) -> f64 {
        if acc.is_nan() || d.is_nan() {
            f64::NAN
        } else {
            acc.max(d)
        }
    }

    /// Largest element-wise absolute difference between two equally shaped
    /// matrices. Returns NaN if any element difference is NaN.
    ///
    /// # Panics
    /// Panics if the shapes differ, since `zip` would otherwise silently
    /// compare only the common prefix.
    fn max_abs_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        assert_eq!(
            a.dim(),
            b.dim(),
            "max_abs_diff operands must have identical shapes"
        );
        a.iter()
            .zip(b.iter())
            .map(|(p, q)| (p - q).abs())
            .fold(0.0_f64, nan_max)
    }

    // gam#1587: the reference-symmetric centered `M⊗S_t` joint penalty is now
    // wired through the custom-family outer REML loop (per-eval
    // `JointPenaltyBundle` + outer penalty_coords/logdet/operator). The
    // production multinomial fit is therefore invariant to the arbitrary
    // reference class, and this guard runs by default. The opt-in skip
    // attribute it carried while the fix was pending is also forbidden by the
    // build.rs ban-scanner. It is an end-to-end fit guard (a handful of full
    // softmax `y ~ s(x)` fits): slower than a unit test, but a true
    // production-path regression.
    #[test]
    fn multinomial_fit_is_invariant_to_reference_class_1587() {
        let td = tempdir().expect("tempdir");
        let dir = td.path();
        // The reference-class drift is STRUCTURAL (it does not shrink with n,
        // see the issue table). A modest n exposes it just as cleanly as n=900
        // while keeping this an affordable CI guard.
        let (x, cls) = sample_classes(0, 300);
        let grid: Vec<f64> = (0..7).map(|i| -1.5 + 3.0 * (i as f64) / 6.0).collect();

        // Three labelings that each make a DIFFERENT original class the baseline
        // (the class whose label sorts LAST is the reference K−1):
        //   ["A","B","C"] → ref = class 2
        //   ["B","C","A"] → ref = class 1
        //   ["C","A","B"] → ref = class 0
        let a = fit_predict_aligned(dir, "abc", &x, &cls, ["A", "B", "C"], &grid);
        let b = fit_predict_aligned(dir, "bca", &x, &cls, ["B", "C", "A"], &grid);
        let c = fit_predict_aligned(dir, "cab", &x, &cls, ["C", "A", "B"], &grid);

        // Refitting the SAME labeling twice must agree to ~machine precision.
        // This separates optimizer noise from the structural reference drift.
        let a2 = fit_predict_aligned(dir, "abc2", &x, &cls, ["A", "B", "C"], &grid);
        let refit_noise = max_abs_diff(&a, &a2);
        assert!(
            refit_noise < 1e-6,
            "refitting the same labeling must be deterministic (got {refit_noise:.3e})"
        );

        // NaN-propagating combine: `f64::max` would drop a NaN pairwise diff.
        let drift = nan_max(
            nan_max(max_abs_diff(&a, &b), max_abs_diff(&a, &c)),
            max_abs_diff(&b, &c),
        );
        assert!(
            drift < 1e-3,
            "predicted probabilities must be invariant to the reference class; \
             cross-labeling drift = {drift:.3e} (refit noise = {refit_noise:.3e})"
        );
    }
}