Skip to main content

gam_models/
multinomial.rs

1//! Penalized multinomial-logit (softmax) GLM driver — fixed-λ inner solve.
2//!
3//! This is the principled vector-response companion to the scalar PIRLS path:
4//! the inner-loop Newton solver for a multi-class GAM at fixed smoothing
5//! parameters λ, using the canonical multinomial-logit likelihood
6//! ([`MultinomialLogitLikelihood`]) and the existing dense block-Fisher
7//! assembly in [`gam_solve::pirls::dense_block_xtwx`] /
8//! [`gam_solve::pirls::dense_block_xtwy`].
9//!
10//! # What this module does
11//!
12//! Solve, for the reference-coded multinomial-logit GAM with `K` classes and
13//! design matrix `X ∈ ℝ^{N×P}`,
14//!
15//! ```text
16//!     β̂ = argmin_β { − log L(β) + ½ Σ_{a=0}^{K-2} λ_a · β_a^T S β_a }
17//! ```
18//!
19//! where `β = [β_0; β_1; …; β_{K-2}]` is the stacked coefficient vector in
20//! output-major order (`β_a ∈ ℝ^P` is the coefficient block for class `a`),
21//! `S ∈ ℝ^{P×P}` is the smoothing penalty matrix (shared across classes,
22//! replicated as `I_{K-1} ⊗ S` over the full parameter space), and `λ_a` is
23//! a per-class smoothing parameter.
24//!
25//! The likelihood uses class `K - 1` as the reference (`η_{K-1} ≡ 0`), so the
26//! softmax gauge is fixed at the η level and no additional sum-to-zero
27//! projection is required.
28//!
29//! # Layering
30//!
31//! * **Fixed-λ inner solve** — [`fit_penalized_multinomial`] is the canonical
32//!   coefficient-space Newton solver at *given* smoothing parameters `λ`,
33//!   built on the shared [`crate::penalized_vector_glm`] engine.
34//!
35//! * **REML / LAML smoothing-parameter selection** — [`fit_penalized_multinomial_formula`]
36//!   routes through [`crate::custom_family::fit_custom_family_with_rho_prior`]
37//!   so the per-active-class `λ_a` are selected by the outer REML/LAML loop;
38//!   the caller's `init_lambda` is only a warm-start seed. The multinomial
39//!   [`crate::multinomial_reml::MultinomialFamily`] `CustomFamily`
40//!   impl calls the fixed-λ math above as its inner solve at each ρ trial and
41//!   supplies the dense per-row Hessian block for the outer trace terms.
42//!
43//! * **Formula → design integration** — `build_formula_design_for_multinomial`
44//!   parses the Wilkinson formula and assembles `X` and the per-term `S`
45//!   blocks; the `fit_multinomial_formula_pyfunc` FFI shim wires the Python
46//!   `gamfit.fit(..., family='multinomial')` entry straight to this path.
47//!
48//! # Convergence
49//!
50//! The damped-Newton-with-backtracking scaffold lives once in the shared
51//! [`crate::penalized_vector_glm`] engine: at each iteration the
//! assembled penalized Hessian `H + diag_a(λ_a) ⊗ S` is factored via faer's
53//! symmetric-PD-with-fallback path, the full Newton step `δ = −H^{-1} ∇F` is
54//! computed, and accepted with step halving if the objective fails to decrease
55//! (up to a small backtracking budget). The convergence test is the relative
56//! coefficient step norm `‖δ‖ / (1 + ‖β‖) ≤ tol`, matching the existing pyffi
57//! reference path. This module is the softmax adapter over that engine: it
58//! supplies the dense `(K-1)×(K-1)` Fisher block, the residual, and the
59//! log-likelihood through [`MultinomialLogitLikelihood`], and owns the
60//! class-count / simplex preconditions. The independent-binomial sibling
61//! [`crate::binomial_multi`] is the same engine with a row-diagonal
62//! Fisher block instead.
63
64use crate::custom_family::{
65    BlockwiseFitOptions, ParameterBlockState, PenaltyMatrix, fit_custom_family_with_rho_prior,
66};
67use crate::multinomial_reml::MultinomialFamily;
68use crate::penalized_vector_glm::{PenalizedVectorGlmInputs, fit_penalized_vector_glm};
69use crate::vector_response::{MultinomialLogitLikelihood, validate_multinomial_simplex};
70use gam_terms::inference::formula_dsl::parse_formula;
71use crate::model_types::EstimationError;
72use crate::fit_orchestration::{
73    FitConfig, build_termspec_with_geometry_and_overrides, resolved_resource_policy,
74};
75use gam_terms::smooth::{
76    PenaltyBlockInfo, TermCollectionDesign, TermCollectionSpec, build_term_collection_design,
77};
78use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
79use gam_terms::term_builder::resolve_role_col;
80use gam_problem::ResponseColumnKind;
81use gam_data::ColumnKindTag;
82use gam_data::EncodedDataset;
83use gam_runtime::resource::ProblemHints;
84use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3};
85use serde::{Deserialize, Serialize};
86use std::sync::Arc;
87
/// Numerical stabilization floor applied only inside the linear solve of the
/// formula-driven multinomial REML inner loop (gam#747).
///
/// It is installed through
/// [`RidgePolicy::solver_only`](gam_problem::RidgePolicy::solver_only), which
/// confines it to the inner joint-Newton **linear solve**: it never reaches the
/// REML objective, the penalty log-determinant, or the Laplace Hessian.
///
/// # Why a floor is needed
///
/// Multinomial smoothing penalties are rank-deficient on purpose: each smooth
/// keeps an unpenalized polynomial null space. A formula may also contain a
/// completely unpenalized parametric term (`x3` / `body_mass`). With
/// near-separable hard labels, the softmax curvature is badly conditioned along
/// those directions, so the raw Newton step `H⁻¹∇` explodes. Raising the
/// smallest Hessian eigenvalue to `δ` caps the step at
/// `‖(H+δI)⁻¹∇‖ ≤ ‖∇‖/δ`. Screening iterates therefore stay finite, and the
/// softmax is never fed `inf − inf = NaN`.
///
/// # What it intentionally leaves alone
///
/// No `½·δ·‖β‖²` term is added to the objective, and no `δ`-shift is applied to
/// the REML log-determinant. The retired `explicit_stabilization_pospart`
/// policy put both into the criterion. That turned `1e-4` into a fixed-λ
/// Gaussian prior, which pulled every identified coefficient away from the MLE
/// and biased smoothing-parameter selection. The value then had to be tuned
/// between under-stabilization (NaN seeds) and over-shrinkage (lost VGAM
/// match).
///
/// As a solver-only floor, that tradeoff disappears: nothing is shrunk, the
/// optimized objective is the true penalized REML criterion, and the floor only
/// has to keep the linear algebra finite.
///
/// # Separation is handled elsewhere
///
/// The separation defect (#753) is not this floor's responsibility. When the
/// multinomial MLE lies at infinity along an unpenalized/null-space direction
/// (complete or quasi-complete separation), no solver floor can make that
/// estimate finite.
///
/// The formula REML path instead arms the full-span Jeffreys/Firth correction
/// only when there is separation evidence (see
/// [`multinomial_formula_separation_evidence`] and the two-attempt logic in
/// [`fit_penalized_multinomial_formula`]). A well-identified interior fit thus
/// optimizes the unbiased penalized-REML criterion without Firth shrinkage
/// toward the uniform simplex. A (quasi-)separated geometry instead receives
/// the proper prior that alone can bound its penalty-null directions (#715
/// real-data arm).
///
/// The bare fixed-λ inner driver [`fit_penalized_multinomial`] has no outer
/// REML and no Jeffreys term. It therefore reports the explicit
/// `MultinomialSeparationDetected` diagnostic instead.
const MULTINOMIAL_FORMULA_RIDGE_FLOOR: f64 = 1e-4;
130
/// KKT tolerance for the inner joint-Newton solve on the multinomial formula
/// path.
///
/// On saturated rows the softmax Fisher weight `W = diag(p) − ppᵀ` collapses.
/// Near-separable fits (penguins, #715) therefore hit the f64 noise floor of
/// the OBJECTIVE before they can reach the default `inner_tol = 1e-6` KKT
/// target.
///
/// Measured on the penguins arm with standardized columns:
/// * the trust region shrinks to 1e-12;
/// * each attempt changes the objective by ~+2e-9 at |obj| ≈ 1e2 (≈ 1e-11
///   relative, i.e. pure rounding);
/// * the KKT residual stalls at 2.8e-5–9.4e-5, against a scaled tolerance of
///   ~1.9e-5.
///
/// A target below the floating-point noise floor can never be certified: the
/// stall guard rejects every evaluation and the whole fit fails. `1e-5`
/// certifies the observed plateaus while still resolving β to ~1e-6 in the
/// relevant metric. The LAML criterion sees β̂ errors only as
/// O(residual²/curvature), far below anything the outer ρ-search can resolve.
const MULTINOMIAL_FORMULA_INNER_TOL: f64 = 1e-5;
146
/// Penalty calibration factor applied by the formula adapter for multinomial
/// softmax REML, as a function of the class count `K`.
///
/// The term builder normalizes penalties against single-response,
/// Gaussian-style score curvature. A reference-coded softmax class block
/// instead sees a per-row active-class Fisher diagonal `p_a(1-p_a)`, plus
/// negative cross-class coupling.
///
/// At the neutral simplex `p_k = 1/K`, that diagonal equals `(K-1)/K²`. The
/// calibration is therefore `2·(K-1)/K²`: `1/2` for binary logit and `4/9` for
/// three classes, instead of the historical fixed `1/2`. Tying the scale to `K`
/// keeps the physical smoothness prior matched to the likelihood curvature, so
/// classes are not increasingly over-penalized as categories are added.
///
/// Class counts below 2 are treated as `K = 2`.
fn multinomial_formula_penalty_scale(n_classes: usize) -> f64 {
    // Degenerate class counts fall back to the binary-logit calibration.
    let classes = if n_classes < 2 { 2.0 } else { n_classes as f64 };
    let active = classes - 1.0;
    2.0 * active / (classes * classes)
}
162
/// Largest smoothing-parameter dimension for which multinomial formula fits
/// pay for the exact dense outer (REML/LAML) Hessian.
///
/// `D = (K - 1) * n_penalties`. Models with `D` up to this bound (inclusive)
/// get exact curvature, so the optimizer does not drift into over-smoothed λ
/// caps on near-boundary softmax surfaces. Larger models use the
/// gradient-only route.
///
/// History: the cut-off was first calibrated at `D <= 6`, when each `s()` term
/// carried ONE penalty. The double-penalty migration (wiggliness + null-space
/// shrinkage per term, mgcv `select=TRUE` semantics) doubled `D` for the SAME
/// models. That silently moved the reference formula fits (2 smooths, K = 3:
/// old `D = 4`, now `D = 8`) onto the gradient-only route. There, the #715
/// quality arm saw every wiggliness ρ pinned to the ±10 box bound: smooths
/// collapsed toward their polynomial null space, and truth-RMSE fell behind
/// VGAM.
///
/// Doubling the original bound to `12 = 2 × 6` would restore that
/// classification. The bound is set higher, to `16`, so that the four-smooth
/// penguin species quality fixture (`D = 16`) also stays on the exact ARC
/// path. Under first-order BFGS, that model can cycle along the near-separable
/// λ-to-zero ridge until the wall-clock budget expires (#1082). ARC sees the
/// same exact curvature and can stop via the bound-aware cost-stall guard once
/// the REML surface stops making useful progress.
const MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM: usize = 16;

/// Whether a multinomial formula fit with `total_rho_dim` smoothing parameters
/// should use the exact dense outer Hessian. The bound
/// [`MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM`] is inclusive.
fn multinomial_formula_use_outer_hessian(total_rho_dim: usize) -> bool {
    total_rho_dim <= MULTINOMIAL_EXACT_OUTER_HESSIAN_MAX_DIM
}
188
/// Logit magnitude `|η|` at or above which fitted probabilities are treated as
/// saturated for separation diagnostics (`e^{−25} ≈ 1.4e-11`).
///
/// * The bare fixed-λ driver has no outer REML state. It uses this threshold
///   to reject a non-converged saturated iterate as a separation artifact.
/// * On the formula REML path, a CONVERGED fit is not gated on this
///   magnitude. Once smoothing parameters are selected, a finite saturated
///   surface can be the valid near-separated optimum and is scored directly
///   (see [`multinomial_formula_separation_diagnostic`]).
/// * The formula path does consult it for a NON-CONVERGED capped unbiased
///   probe. [`multinomial_formula_unresolved_probe_separation_evidence`]
///   treats such a probe as separated when its logits reach this magnitude,
///   and routes it straight to the Firth/Jeffreys refit instead of spending
///   the full unbiased budget (#1082).
const MULTINOMIAL_SEPARATION_ETA_THRESHOLD: f64 = 25.0;
197
/// Convergence tolerance for the OUTER REML/LAML smoothing-parameter search on
/// the formula multinomial path.
///
/// It matches the primary GLM REML outer loop:
/// `solver::fit_orchestration::materialize` uses `tol = 1e-7`, and the
/// `LOG_LAMBDA_TOL` / `KKT_TOL_*` constants across the REML stack mirror it.
///
/// The value is tight enough for the selected λ to reach the genuine REML
/// optimum, so the recovered probability surface matches the mature reference.
/// It is loose enough that the optimizer does not grind out
/// surface-irrelevant ρ digits down to the inner KKT scale (the #1082
/// wall-clock overrun).
///
/// The caller's `tol` is floored at this value for the OUTER loop only. The
/// INNER joint-Newton KKT target still follows the caller's `tol` unchanged.
const MULTINOMIAL_OUTER_REML_TOL: f64 = 1.0e-7;
209
/// Outer-iteration cap for the first (unbiased) multinomial formula solve,
/// which doubles as a separation probe.
///
/// The probe is accepted when the unbiased REML criterion converges to a
/// finite interior iterate. Without a cap, near-separable data (e.g. the
/// penguin fixture) would burn the caller's whole outer budget on an iterate
/// that is discarded before the Firth/Jeffreys refit. The cap gives ordinary
/// interior fits enough iterations to certify quickly, and hands slow or
/// non-interior probes to the proper-prior refit without delay.
const MULTINOMIAL_UNBIASED_PROBE_OUTER_MAX_ITER: usize = 20;
217
/// Per-observation softmax Fisher information, used as the unit of the λ
/// floor.
///
/// The penalty contributes `½ λ βᵀ S β` with a Frobenius-normalized `S`
/// (`‖S‖_F = 1`; see the term-builder calibration described at
/// [`multinomial_formula_penalty_scale`]). The ridge `λ S` can therefore be
/// compared directly with data Fisher information.
///
/// A single observation adds information `p(1−p)` along a class's logit
/// direction. That is at most the logistic peak `¼`, attained at `p = ½`.
/// Taking this maximum as the unit lets the floor be read as a count of
/// **pseudo-observations**: a ridge equal to `τ · ¼ · ‖S‖_F` supplies the same
/// logit-direction curvature as `τ` real rows sitting at the likelihood's most
/// informative point.
///
/// The value is deliberately independent of `K`. The softmax block's
/// `K`-dependence is already built into the penalty matrix through
/// [`multinomial_formula_penalty_scale`]. The floor bounds the multiplier of
/// that already-scaled penalty, so scaling it by `K` again would double-count.
const MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS: f64 = 1.0 / 4.0;
235
/// Prior strength of the λ floor for a WELL-SUPPORTED class, measured in
/// pseudo-observations.
///
/// The floor keeps the unbiased REML optimizer away from the zero-penalty
/// boundary. Without it, a boundary-overfit smooth, or a Firth switch on
/// finite data, would otherwise be accepted. The prior it imposes is worth
/// only a small fixed fraction of one observation.
///
/// At the calibration point, `8e-4` pseudo-observations gives
/// `τ · ¼ = 2e-4`. That is exactly the previously fixture-calibrated
/// large-support floor, now expressed as an effective prior strength instead
/// of a hand-tuned λ.
const MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS: f64 = 8e-4;
245
/// Reference per-class support `n_ref`: the effective class sample size at
/// and above which the λ floor sits at its well-supported value.
///
/// A class with `n_c` effective observations carries data Fisher information
/// `n_c · I₁`. Below `n_ref`, that information shrinks in proportion to
/// `n_c`. The pseudo-observation count is therefore multiplied by
/// `n_ref / n_c`, keeping the prior a fixed fraction of the data information
/// rather than a fixed absolute λ. At `n_c = n_ref` the multiplier is exactly
/// 1.
const MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT: f64 = 50.0;
254
/// Upper bound on the λ floor's prior strength for very sparse classes,
/// measured in pseudo-observations.
///
/// The `n_ref / n_c` multiplier diverges as `n_c → 0`. This cap pins the prior
/// at `4e-3` pseudo-observations (`τ_max · ¼ = 1e-3` at the calibration point,
/// the previously tuned strong-floor value). The floor thus stays a proper
/// prior, rather than a hard constraint that would overwhelm the likelihood
/// for a class with only a handful of rows.
const MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX: f64 = 4e-3;
262
263/// Continuous, Fisher-information-scaled lower λ floor for the formula path,
264/// derived from the minority class's effective sample size `n_c`.
265///
266/// # Derivation (effective-prior-strength / Fisher geometry)
267///
268/// The penalty `½ λ βᵀ S β` with `‖S‖_F = 1` adds curvature `λ` to the class
269/// logit direction; one observation adds at most `I₁ = ¼` there. So a floor that
270/// sets `λ_floor = τ_eff · I₁` gives the smooth a prior worth `τ_eff`
271/// pseudo-observations. We want a fixed *absolute* prior `τ` for a well-supported
272/// class, but for a minority class with only `n_c` effective observations the
273/// data information in its block is `n_c · I₁`; holding the prior to a fixed
274/// *fraction* of that shrinking data information requires
275///
276/// ```text
277///     τ_eff(n_c) = τ · max(1, n_ref / n_c),   clamped to [τ, τ_max]
278///     λ_floor(n_c) = τ_eff(n_c) · I₁
279/// ```
280///
281/// This is the *same* `base · max(1, c0/c)` envelope as before — but `base`,
282/// `sparse`, and `c0` are no longer fixture-tuned magic numbers: `base = τ·I₁`,
283/// `sparse = τ_max·I₁`, and `c0 = n_ref` are an effective-prior-strength of
284/// `τ`/`τ_max` pseudo-observations against the maximal per-observation softmax
285/// information `I₁ = ¼`. Properties preserved by construction:
286///   * reduces EXACTLY to `τ·I₁` for well-supported classes (`n_c ≥ n_ref`);
287///   * reduces EXACTLY to `τ_max·I₁` for very sparse classes
288///     (`n_c ≤ n_ref·τ/τ_max`, here `n_c ≤ 10`);
289///   * interpolates monotonically and continuously between them in the middle —
290///     no cliff at `n_c = n_ref`.
291/// At the calibration point the endpoints equal the previous `2e-4` / `1e-3`, so
292/// fixtures whose smallest class has `n_c ≥ 50` (penguins, the vgam softmax
293/// arms) are unaffected — they sit at `τ·I₁ = 2e-4` exactly as before.
294fn multinomial_formula_min_lambda(y_one_hot: ArrayView2<'_, f64>) -> f64 {
295    let base = MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
296    let sparse =
297        MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
298    let min_class_count = (0..y_one_hot.ncols())
299        .map(|class| y_one_hot.column(class).sum())
300        .fold(f64::INFINITY, f64::min);
301    if !min_class_count.is_finite() || min_class_count <= 0.0 {
302        return base;
303    }
304    // Effective pseudo-observation prior strength: held to a fixed fraction of
305    // the shrinking per-class data information once n_c falls below n_ref.
306    let pseudo_obs_scale =
307        (MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / min_class_count).max(1.0);
308    (base * pseudo_obs_scale).clamp(base, sparse)
309}
310
311fn max_abs_eta_location(eta: ArrayView2<'_, f64>) -> (f64, usize, usize) {
312    let mut best = (0.0_f64, 0usize, 0usize);
313    for ((row, active_class), &value) in eta.indexed_iter() {
314        let abs = value.abs();
315        if abs > best.0 {
316            best = (abs, row, active_class);
317        }
318    }
319    best
320}
321
322/// Separation gate for the REML/LAML **formula** path.
323///
324/// Unlike the bare fixed-λ driver [`fit_penalized_multinomial`] (which has no
325/// outer REML state and so must reject a saturated, non-converged iterate as a
326/// separation artifact at the [`MULTINOMIAL_SEPARATION_ETA_THRESHOLD`] logit
327/// magnitude), the formula path can return a finite saturated mode after the
328/// coupled outer optimizer has selected smoothing parameters. A `|η| >= 25`
329/// gate is therefore wrong here: the penguins arm can legitimately have large
330/// fitted logits while still producing finite probabilities and a usable REML
331/// mode.
332///
333/// Only a genuinely NON-FINITE `η` (a NaN/Inf blow-up in the inner linear
334/// algebra) is a real formula-path failure. A finite, even saturated, `η` is
335/// accepted so the truth-recovery / match-or-beat bars are evaluated against the
336/// actual fitted surface instead of an adapter diagnostic.
337fn multinomial_formula_separation_diagnostic(
338    inner_cycles: usize,
339    outer_iterations: usize,
340    block_states: &[ParameterBlockState],
341) -> Option<EstimationError> {
342    let mut nonfinite: Option<(f64, usize, usize)> = None;
343    for (active_class, state) in block_states.iter().enumerate() {
344        for (row, &value) in state.eta.iter().enumerate() {
345            if !value.is_finite() {
346                nonfinite = Some((value, row, active_class));
347                break;
348            }
349        }
350        if nonfinite.is_some() {
351            break;
352        }
353    }
354    nonfinite.map(|(value, row_index, active_class_index)| {
355        EstimationError::MultinomialSeparationDetected {
356            iteration: inner_cycles.max(outer_iterations),
357            max_abs_eta: value.abs(),
358            active_class_index,
359            row_index,
360        }
361    })
362}
363
364/// Separation EVIDENCE gate for the conditional Firth/Jeffreys engagement on
365/// the formula REML path (#715 / #753).
366///
367/// The structural mathematics (#715 issue thread): for any coefficient
368/// direction `v` with `S v = 0` (a penalty-null direction — intercept, a
369/// smooth's polynomial null component, an unpenalized parametric term), the
370/// penalized joint Hessian satisfies `(H + S_λ) v = H v` for EVERY smoothing
371/// parameter ρ. When the data (quasi-)separate, the softmax Fisher weight
372/// `W = diag(p) − p pᵀ → 0` on the saturated rows, so `H v = JᵀWJ v → 0` along
373/// the penalty-null directions those rows support: `(H + S_λ) v ≈ 0` for every
374/// ρ — NO λ can repair it, the inner Newton can never certify a KKT point
375/// there, and every outer REML startup seed is rejected (the penguins
376/// real-data arm). The only principled cure is a PROPER prior on that
377/// quotient-null subspace — the Jeffreys/Firth term `Φ = ½ log|ZᵀHZ|`, whose
378/// Gauss–Newton curvature supplies the missing `O(1)` bound.
379///
380/// But the Firth prior is not free on interior data: unconditionally armed, it
381/// shrinks fitted class probabilities toward the uniform simplex `1/K`
382/// (an `O(1/n)` pull that the synthetic match-or-beat arm of #715 measured as
383/// a real truth-RMSE loss vs the unbiased criterion). So the formula path
384/// engages it ONLY on separation evidence, mirroring the #753 "diagnose, then
385/// arm" split:
386///
387/// * a NON-FINITE logit — the inner linear algebra blew up along an unbounded
388///   direction.
389///
390/// Returns `Some(description)` naming the witnessing logit when evidence is
391/// found, `None` for a finite fit (which is then accepted as-is, with zero
392/// Firth bias). A FAILED unbiased solve (`Err` from the rho-prior driver, e.g.
393/// "no startup seed passed") is the second evidence form and is handled
394/// directly at the call site in [`fit_penalized_multinomial_formula`].
395fn multinomial_formula_separation_evidence(block_states: &[ParameterBlockState]) -> Option<String> {
396    for (active_class, state) in block_states.iter().enumerate() {
397        for (row, &value) in state.eta.iter().enumerate() {
398            if !value.is_finite() {
399                return Some(format!(
400                    "non-finite logit eta[row {row}, active class {active_class}] = {value}"
401                ));
402            }
403        }
404    }
405    None
406}
407
408/// Extra evidence used only for a NON-CONVERGED capped unbiased probe.
409///
410/// A converged finite saturated formula fit is still a valid optimum and must be
411/// scored without Firth bias. A capped probe that failed to converge while it
412/// already carries separation-scale logits is different: spending the full
413/// unbiased outer budget on the same lambda-to-zero surface is the #1082
414/// timeout. Route that case straight to the proper-prior refit.
415fn multinomial_formula_unresolved_probe_separation_evidence(
416    block_states: &[ParameterBlockState],
417) -> Option<String> {
418    if let Some(evidence) = multinomial_formula_separation_evidence(block_states) {
419        return Some(evidence);
420    }
421
422    let mut best = (0.0_f64, 0usize, 0usize);
423    for (active_class, state) in block_states.iter().enumerate() {
424        for (row, &value) in state.eta.iter().enumerate() {
425            let abs = value.abs();
426            if abs > best.0 {
427                best = (abs, row, active_class);
428            }
429        }
430    }
431    if best.0 >= MULTINOMIAL_SEPARATION_ETA_THRESHOLD {
432        Some(format!(
433            "separation-scale finite logit |eta[row {}, active class {}]| = {:.3e} \
434             after capped unbiased probe",
435            best.1, best.2, best.0
436        ))
437    } else {
438        None
439    }
440}
441
/// Inputs to [`fit_penalized_multinomial`].
///
/// Shape conventions: `N` observations, `P` design columns, and `K` classes
/// (the column count of `y_one_hot`). Thus `design` is `N×P`, `y_one_hot` is
/// `N×K`, `penalty` is `P×P`, and `lambdas` has length `K - 1`. The optional
/// `row_weights` / `fisher_w_override` are indexed by the same `N` rows.
///
/// The penalty matrix `S` is shared across classes. The per-class smoothing
/// parameters `lambdas` (length `K - 1`) scale `S` independently for each
/// active class, so the full block-replicated penalty is `diag_a(λ_a) ⊗ S`.
/// This is the structure [`gam_solve::arrow_schur::KroneckerPenaltyOp`]
/// expresses in matrix-free form, relevant if this driver is lifted into the
/// arrow-Schur loop.
#[derive(Debug, Clone)]
pub struct MultinomialFitInputs<'a> {
    /// Design matrix `X ∈ ℝ^{N×P}` (one row per observation).
    pub design: ArrayView2<'a, f64>,
    /// Categorical response `Y ∈ ℝ^{N×K}`. Each row must be a point on the
    /// probability simplex (`y_c ≥ 0`, `Σ_c y_c = 1`): a one-hot indicator for
    /// hard classification, or a label-smoothed probability vector.
    ///
    /// Rows whose mass departs from 1 are rejected, because the softmax
    /// residual gradient and Fisher block are the derivatives of
    /// `Σ_c y_c log p_c` only under the simplex constraint (see
    /// `validate_multinomial_simplex`). The last column, `K - 1`, is the
    /// reference class.
    pub y_one_hot: ArrayView2<'a, f64>,
    /// Shared smoothing penalty `S ∈ ℝ^{P×P}` (symmetric, PSD). Its dimension
    /// must match the design's column count `P`.
    pub penalty: ArrayView2<'a, f64>,
    /// Per-active-class smoothing parameter `λ_a` (length `K - 1`; entry `a`
    /// scales the penalty on coefficient block `β_a`).
    pub lambdas: ArrayView1<'a, f64>,
    /// Optional per-row weights (length `N`); `None` ⇒ uniform 1.0.
    pub row_weights: Option<ArrayView1<'a, f64>>,
    /// Optional per-row Fisher-block override, shape `(N, K-1, K-1)` in the
    /// active-class gauge (the reference class `K-1` is dropped).
    ///
    /// When `Some`, each Newton step uses this block as the curvature `W` in
    /// place of the analytic softmax Fisher `w_n (δ_ab p_a − p_a p_b)`. The
    /// gradient/residual path stays analytic, so this is a curvature-only
    /// override: the research escape-hatch for latent multinomial fits (issue
    /// #349).
    ///
    /// Each per-row block must be symmetric, PSD, and finite. The FFI boundary
    /// discharges these preconditions before constructing this view.
    pub fisher_w_override: Option<ArrayView3<'a, f64>>,
    /// Maximum Newton iterations; recommend 50.
    pub max_iter: usize,
    /// Relative-step convergence tolerance; recommend 1e-7. The solve is
    /// declared converged when `‖δ‖ / (1 + ‖β‖) ≤ tol` (see the module-level
    /// Convergence notes).
    pub tol: f64,
}
481
/// Outputs of [`fit_penalized_multinomial`].
///
/// Notation: `N` fitted rows, `P` coefficients per class, `K` classes, and
/// `M = K − 1` active (non-reference) classes.
#[derive(Debug, Clone)]
pub struct MultinomialFitOutputs {
    /// Active-class coefficient block, shape `(P, K-1)` (column `a` is `β_a`).
    /// The reference class `K - 1` has `β_{K-1} ≡ 0` by construction and is
    /// not stored.
    pub coefficients_active: Array2<f64>,
    /// Fitted probabilities, shape `(N, K)`, including the reference class
    /// column.
    pub fitted_probabilities: Array2<f64>,
    /// Number of Newton iterations executed, including the final step that
    /// satisfied the tolerance.
    pub iterations: usize,
    /// `true` if the relative-step test was satisfied; `false` if the solver
    /// exhausted `max_iter`. A non-converged solve is still returned, and the
    /// caller decides whether to escalate.
    pub converged: bool,
    /// Penalized negative log-likelihood at the returned `β̂`:
    /// `−log L(β̂) + ½ Σ_a λ_a · β̂_a^T S β̂_a`.
    pub penalized_neg_log_likelihood: f64,
    /// Unpenalized deviance `−2 log L(β̂)`, for diagnostic reporting.
    pub deviance: f64,
    /// Joint Laplace posterior coefficient covariance `H⁻¹` at the converged
    /// `β̂`, with shape `(P·(K−1))×(P·(K−1))` (#1101).
    ///
    /// It is block-ordered to match the stacked active-class coefficient vector
    /// `β = [β_0; …; β_{K-2}]`. Active class `a`'s `P` coefficients occupy
    /// rows/cols `a·P .. (a+1)·P`, indexed as `θ[a·P + i] = β̂[i, a]`.
    ///
    /// This is the Laplace covariance from the factored penalized Hessian
    /// `XᵀWX + diag_a(λ_a)⊗S`. On the fixed-λ inner-solve path it drives the
    /// delta-method per-class probability standard errors
    /// ([`Self::predict_probabilities_with_se`]).
    pub coefficient_covariance: Array2<f64>,
}
513
514impl MultinomialFitOutputs {
515    /// Number of active classes `M = K − 1` (columns of
516    /// [`Self::coefficients_active`]).
517    pub fn n_active_classes(&self) -> usize {
518        self.coefficients_active.ncols()
519    }
520
521    /// Per-class coefficient dimension `P` (rows of
522    /// [`Self::coefficients_active`]).
523    pub fn p_per_class(&self) -> usize {
524        self.coefficients_active.nrows()
525    }
526
527    /// Evaluate `softmax(X·β̂)` AND its delta-method per-class probability
528    /// standard error at fresh design rows `X_new` (#1101), using the joint
529    /// Laplace covariance [`Self::coefficient_covariance`].
530    ///
531    /// The softmax Jacobian is `∂p_c/∂η_b = p_c (δ_{cb} − p_b)` for active class
532    /// `b ∈ 0..M`, and `∂η_b/∂β[i,a] = X[i]·δ_{ab}`, so the gradient of the
533    /// class-`c` probability w.r.t. the block-ordered coefficient vector is
534    /// `g_c[a·P + i] = X[i]·p_c (δ_{ca} − p_a)` (the reference class `M`
535    /// contributes only through `−p_a` in every active block). The delta-method
536    /// variance is `Var(p_c) = g_cᵀ Σ g_c` with `Σ = H⁻¹`, and
537    /// `SE(p_c) = √Var(p_c)`. Returns `(probs (N,K), prob_se (N,K))`. `X_new`
538    /// must have `P` columns (the same design basis used at fit time); its row
539    /// count sets `N`. The SE is unclamped (the interval consumer applies the
540    /// simplex `[0,1]` clamp).
541    pub fn predict_probabilities_with_se(
542        &self,
543        x_new: ArrayView2<'_, f64>,
544    ) -> Result<(Array2<f64>, Array2<f64>), EstimationError> {
545        let p = self.p_per_class();
546        let m = self.n_active_classes();
547        let k = m + 1;
548        if x_new.ncols() != p {
549            crate::bail_invalid_estim!(
550                "predict_probabilities_with_se: X has {} cols, expected P={p}",
551                x_new.ncols()
552            );
553        }
554        let d = p * m;
555        let cov = &self.coefficient_covariance;
556        if cov.dim() != (d, d) {
557            crate::bail_invalid_estim!(
558                "predict_probabilities_with_se: covariance shape {:?} ≠ (P·M, P·M) = ({d}, {d})",
559                cov.dim()
560            );
561        }
562        let n_new = x_new.nrows();
563        let beta = &self.coefficients_active;
564        let mut probs = Array2::<f64>::zeros((n_new, k));
565        let mut prob_se = Array2::<f64>::zeros((n_new, k));
566        let mut eta_active = vec![0.0_f64; m];
567        let mut row_probs = vec![0.0_f64; k];
568        let mut grad = vec![0.0_f64; d];
569        for row in 0..n_new {
570            for a in 0..m {
571                let mut v = 0.0_f64;
572                for i in 0..p {
573                    v += x_new[[row, i]] * beta[[i, a]];
574                }
575                eta_active[a] = v;
576            }
577            MultinomialLogitLikelihood::softmax_with_baseline(&eta_active, &mut row_probs);
578            for c in 0..k {
579                probs[[row, c]] = row_probs[c];
580            }
581            for c in 0..k {
582                let pc = row_probs[c];
583                // g_c[a·P + i] = X[i] · p_c · (δ_{ca} − p_a), a active.
584                for a in 0..m {
585                    let pa = row_probs[a];
586                    let factor = pc * (if c == a { 1.0 - pa } else { -pa });
587                    let base = a * p;
588                    for i in 0..p {
589                        grad[base + i] = x_new[[row, i]] * factor;
590                    }
591                }
592                // Var = gᵀ Σ g.
593                let mut var = 0.0_f64;
594                for r in 0..d {
595                    let gr = grad[r];
596                    if gr == 0.0 {
597                        continue;
598                    }
599                    let mut acc = 0.0_f64;
600                    for s in 0..d {
601                        acc += cov[[r, s]] * grad[s];
602                    }
603                    var += gr * acc;
604                }
605                prob_se[[row, c]] = var.max(0.0).sqrt();
606            }
607        }
608        Ok((probs, prob_se))
609    }
610}
611
612/// Fit a penalized multinomial-logit GAM at fixed `λ`.
613///
614/// See the module docs for the optimization problem and conventions. This
615/// function is the canonical inner solve: the outer REML/LAML loop, when
616/// added, calls this at each `ρ = log λ` trial.
617pub fn fit_penalized_multinomial(
618    inputs: MultinomialFitInputs<'_>,
619) -> Result<MultinomialFitOutputs, EstimationError> {
620    let MultinomialFitInputs {
621        design,
622        y_one_hot,
623        penalty,
624        lambdas,
625        row_weights,
626        fisher_w_override,
627        max_iter,
628        tol,
629    } = inputs;
630
631    // ──────────────────────── family-specific validation ───────────────────
632    // The shared engine re-validates the geometry common to every vector-GLM
633    // (nonempty design, penalty shape, λ finiteness/non-negativity, override
634    // `(N, M, M)` shape, finite design). The multinomial family owns the
635    // class-count contract (`K ≥ 2`, λ length `K − 1`), the per-row simplex
636    // precondition under which the softmax residual/Fisher are the exact
637    // derivatives of `Σ_c y_c log p_c`, and the row-weight check the likelihood
638    // adapter consumes.
639    let n_obs = design.nrows();
640    let (y_rows, k) = y_one_hot.dim();
641    if y_rows != n_obs {
642        crate::bail_invalid_estim!(
643            "fit_penalized_multinomial: y rows {y_rows} ≠ design rows {n_obs}"
644        );
645    }
646    if k < 2 {
647        crate::bail_invalid_estim!(
648            "fit_penalized_multinomial: need at least 2 classes (got K={k})"
649        );
650    }
651    let m = k - 1;
652    if lambdas.len() != m {
653        crate::bail_invalid_estim!(
654            "fit_penalized_multinomial: lambdas length {} ≠ K-1 = {m}",
655            lambdas.len()
656        );
657    }
658    if let Some(fw) = fisher_w_override.as_ref() {
659        if fw.dim() != (n_obs, m, m) {
660            crate::bail_invalid_estim!(
661                "fit_penalized_multinomial: fisher_w_override shape {:?} ≠ (N, K-1, K-1) = ({n_obs}, {m}, {m})",
662                fw.dim()
663            );
664        }
665    }
666    if let Some(w) = row_weights.as_ref() {
667        if w.len() != n_obs {
668            crate::bail_invalid_estim!(
669                "fit_penalized_multinomial: row_weights length {} ≠ N = {n_obs}",
670                w.len()
671            );
672        }
673        for (i, &v) in w.iter().enumerate() {
674            if !(v.is_finite() && v >= 0.0) {
675                crate::bail_invalid_estim!(
676                    "fit_penalized_multinomial: row_weights[{i}] must be finite and ≥ 0 (got {v})"
677                );
678            }
679        }
680    }
681    validate_multinomial_simplex(y_one_hot, "fit_penalized_multinomial")?;
682
683    // ────────────────────────── likelihood construction ───────────────────
684    let mut likelihood = MultinomialLogitLikelihood::with_classes(k)?;
685    if let Some(w) = row_weights.as_ref() {
686        likelihood = likelihood.with_row_weights(w.to_owned())?;
687    }
688
689    // ─────────────────── shared penalized vector-GLM solve ─────────────────
690    // The softmax Fisher block is dense across the `M = K − 1` active classes;
691    // the engine assembles the coupled `(P·M)×(P·M)` penalized Hessian, runs
692    // the damped Newton loop, and returns the converged `β̂` and `η = X β̂`.
693    let fit = fit_penalized_vector_glm(
694        PenalizedVectorGlmInputs {
695            design,
696            y: y_one_hot,
697            penalty,
698            lambdas,
699            fisher_w_override,
700            max_iter,
701            tol,
702            // #1587: production multinomial still uses the per-class Diagonal
703            // metric pending the REML per-class→per-term λ re-key that the
704            // reference-symmetric Centered metric requires (shared λ). The
705            // Centered engine path + its invariance proof land first.
706            class_penalty_metric: crate::penalized_vector_glm::ClassPenaltyMetric::Diagonal,
707        },
708        &likelihood,
709        "fit_penalized_multinomial",
710    )?;
711
712    let (max_abs_eta, row_index, active_class_index) = max_abs_eta_location(fit.eta.view());
713    if !fit.converged && max_abs_eta >= MULTINOMIAL_SEPARATION_ETA_THRESHOLD {
714        // Perfect / quasi-perfect separation (#1854): the UNBIASED softmax MLE is
715        // not finite along `active_class_index`'s saturated logit direction, so
716        // the fixed-λ Newton above ran away (`|η| ≥ 25`, no convergence). A
717        // penalty-null direction `v` (`S v = 0`, e.g. an unpenalized intercept /
718        // linear-covariate column) under softmax saturation has
719        // `(XᵀWX + λS) v → 0` for EVERY λ, so no smoothing parameter can bound it
720        // — only a proper prior on that quotient-null subspace can. Rather than
721        // hard-erroring, engage the Firth/Jeffreys proper prior automatically
722        // (magic-by-default): the full-span `½ log|I(β)|` correction supplies the
723        // `O(1)` curvature that keeps the estimate finite on exactly those
724        // separated directions while leaving well-identified fits untouched. This
725        // reuses the same coupled joint-Newton Jeffreys machinery the formula
726        // REML path arms on separation evidence (see
727        // `fit_penalized_multinomial_formula`), only here at the caller's fixed λ.
728        // Engage the fallback, but never let an internal consistency panic in
729        // the coupled joint-Newton assembly (e.g. the #1395 logdet-collapse
730        // guard) escape as a process abort: convert any panic into the
731        // documented hard separation diagnostic, exactly as if the refit had
732        // returned Err. This mirrors the catch_unwind panic-to-typed-error
733        // boundary already used around the faer / cudarc entry points, and keeps
734        // the separation path no worse than the pre-#1854 clean error while the
735        // Firth refit is still being hardened.
736        let firth = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
737            fit_penalized_multinomial_firth_fallback(
738                design,
739                y_one_hot,
740                penalty,
741                lambdas,
742                row_weights,
743                max_iter,
744                tol,
745            )
746        }));
747        match firth {
748            Ok(Ok(out)) => return Ok(out),
749            // Firth refit errored, or an internal consistency guard panicked:
750            // fall back to the explicit hard separation diagnostic.
751            Ok(Err(_)) | Err(_) => {
752                return Err(EstimationError::MultinomialSeparationDetected {
753                    iteration: fit.iterations,
754                    max_abs_eta,
755                    active_class_index,
756                    row_index,
757                });
758            }
759        }
760    }
761
762    let fitted_probabilities = likelihood.probabilities(fit.eta.view());
763
764    Ok(MultinomialFitOutputs {
765        coefficients_active: fit.coefficients,
766        fitted_probabilities,
767        iterations: fit.iterations,
768        converged: fit.converged,
769        penalized_neg_log_likelihood: -fit.log_likelihood + fit.penalty_term,
770        deviance: -2.0 * fit.log_likelihood,
771        coefficient_covariance: fit.coefficient_covariance,
772    })
773}
774
775/// Firth/Jeffreys-penalized multinomial refit engaged automatically when the
776/// unbiased softmax MLE separates (#1854).
777///
778/// The unbiased fixed-λ solve ([`fit_penalized_multinomial`]) runs away on
779/// (quasi-)separated data because the softmax likelihood has no finite mode along
780/// the saturated logit direction and the smoothing penalty `S` cannot bound a
781/// penalty-null direction (`S v = 0` ⇒ `(XᵀWX + λS) v → 0` for every λ). This
782/// refit arms the full-span Jeffreys/Firth proper prior `½ log|I(β)|` on the
783/// coupled joint softmax information, which supplies the `O(1)` curvature that
784/// bounds exactly those directions and keeps the estimate finite.
785///
786/// # The estimator
787///
788/// It maximizes the penalized Firth objective at the caller's *fixed* `λ`
789///
790/// ```text
791///   ℓ*(β) = Σ_n w_n Σ_c y_{nc} log p_{nc}
792///           − ½ Σ_a λ_a βₐᵀ S βₐ
793///           + ½ log det I(β)
794/// ```
795///
796/// where `I(β)` is the coupled `(P·M)×(P·M)` softmax Fisher information (block
797/// `(a,b)` is `Σ_n w_n (δ_{ab} p_{na} − p_{na} p_{nb}) x_n x_nᵀ`, block-ordered so
798/// `θ[a·P+i] = β[i,a]`) and `M = K−1` active classes carry the reference-coded
799/// logits (`η_{ref} ≡ 0`). The Jeffreys term `½ log det I(β)` is the standard
800/// Firth penalty: it diverges to `−∞` as any fitted probability approaches the
801/// simplex boundary (`I → 0`), so its maximizer is interior and finite on exactly
802/// the separated directions that defeat every smoothing `λ`.
803///
804/// # Why this fixed-λ solver rather than the outer-REML formula path
805///
806/// The direct entry ([`fit_penalized_multinomial`]) is a fixed-λ inner solve — it
807/// carries no outer smoothing selection — so the natural Firth engagement is a
808/// fixed-λ Firth Newton, not the formula path's outer-REML joint-Newton machinery
809/// (which is armed instead by [`fit_penalized_multinomial_formula`] on separation
810/// evidence). Solving the Firth objective directly here keeps the separation
811/// contract self-contained and independent of the shared trust-region/KKT
812/// certificate machinery.
813///
814/// # The iteration
815///
816/// A Fisher-scoring Newton on `ℓ*`: the ascent direction is
817/// `Δ = (I + Λ⊗S)⁻¹ U*`, where `U*` is the Firth-adjusted penalized score
818///
819/// ```text
820///   U*[(c,s)] = Σ_n w_n x_{ns} (y_{nc} − p_{nc})       (data score)
821///             − λ_c (S β_c)_s                           (smoothing penalty)
822///             + ½ Σ_n w_n x_{ns} h^c_n                  (Firth adjustment)
823/// ```
824///
825/// and the Firth adjustment uses `h^c_n = Σ_{a,b} G^c_{n,ab} Q_{n,ab}` with the
826/// per-row information "hat" `Q_{n,ab} = x_nᵀ [I⁻¹]_{(a,b)} x_n` and the softmax
827/// third-derivative tensor
828/// `G^c_{ab} = δ_{ab} p_a (δ_{ac} − p_c) − p_a p_b (δ_{ac} + δ_{bc} − 2 p_c)`.
829/// This `½ Σ tr(I⁻¹ ∂I/∂β)` is exactly `∇[½ log det I]` (finite-difference
830/// verified). Each step is globalized by backtracking on `ℓ*`, so a step that
831/// would push a probability to the boundary (making `I` non-PD) is rejected and
832/// the fit stays interior. Convergence is the Newton decrement `½ U*ᵀΔ`.
833fn fit_penalized_multinomial_firth_fallback(
834    design: ArrayView2<'_, f64>,
835    y_one_hot: ArrayView2<'_, f64>,
836    penalty: ArrayView2<'_, f64>,
837    lambdas: ArrayView1<'_, f64>,
838    row_weights: Option<ArrayView1<'_, f64>>,
839    max_iter: usize,
840    tol: f64,
841) -> Result<MultinomialFitOutputs, EstimationError> {
842    use faer::Side;
843    use gam_linalg::faer_ndarray::{
844        FaerArrayView, array1_to_col_matmut, array2_to_matmut, factorize_symmetricwith_fallback,
845    };
846    use gam_linalg::matrix::FactorizedSystem;
847
848    let n_obs = design.nrows();
849    let p = design.ncols();
850    let k = y_one_hot.ncols();
851    let m = k - 1;
852    let d = p * m;
853
854    // Local softmax likelihood mirroring the caller's row weights, used to map the
855    // fitted η back to probabilities.
856    let mut likelihood = MultinomialLogitLikelihood::with_classes(k)?;
857    if let Some(w) = row_weights.as_ref() {
858        likelihood = likelihood.with_row_weights(w.to_owned())?;
859    }
860    let weight = |row: usize| -> f64 { row_weights.as_ref().map_or(1.0, |w| w[row]) };
861
862    let max_iter = max_iter.max(1);
863    let tol_eff = if tol.is_finite() && tol > 0.0 { tol } else { 1e-8 };
864
865    // Probabilities (N, K), active classes 0..M then the pinned reference at M.
866    let probs_at = |beta: &Array2<f64>| -> Array2<f64> {
867        let eta = design.dot(beta);
868        likelihood.probabilities(eta.view())
869    };
870
871    // Coupled softmax Fisher information I (d×d), block-ordered θ[a·P+i] = β[i,a].
872    let assemble_info = |probs: &Array2<f64>| -> Array2<f64> {
873        let mut info = Array2::<f64>::zeros((d, d));
874        for row in 0..n_obs {
875            let w = weight(row);
876            if w == 0.0 {
877                continue;
878            }
879            for a in 0..m {
880                let pa = probs[[row, a]];
881                let ao = a * p;
882                for b in 0..m {
883                    let pb = probs[[row, b]];
884                    let wab = w * (if a == b { pa - pa * pb } else { -pa * pb });
885                    if wab == 0.0 {
886                        continue;
887                    }
888                    let bo = b * p;
889                    for i in 0..p {
890                        let xi = design[[row, i]];
891                        if xi == 0.0 {
892                            continue;
893                        }
894                        let cc = wab * xi;
895                        for j in 0..p {
896                            info[[ao + i, bo + j]] += cc * design[[row, j]];
897                        }
898                    }
899                }
900            }
901        }
902        info
903    };
904
905    // Factor a symmetric matrix (with escalating ridge only if it is not SPD) and
906    // return its inverse and log-determinant.
907    //
908    // The ridge ladder is a standard relative-jitter Cholesky recovery, not a
909    // tuned knob: (a) the base jitter is scaled to the matrix by `max_diag`
910    // (`max_diag · ε` with ε at the double-precision Cholesky floor ~1e-10) so it
911    // is invariant to the overall scale of the Fisher information, falling back
912    // to an absolute floor only when the diagonal is degenerate; (b) it is tried
913    // first at ridge 0 so an already-SPD matrix is factored unperturbed; (c) it
914    // grows geometrically (×4) to span the ~120 dB from the base jitter to O(1)
915    // in a bounded number of steps; (d) the attempt count is capped so a
916    // genuinely singular information (e.g. an exactly rank-deficient Fisher block)
917    // surfaces as an explicit error rather than an unbounded loop.
918    let invert_spd = |mat: &Array2<f64>, context: &str| -> Result<(Array2<f64>, f64), EstimationError> {
919        let max_diag = (0..d).fold(0.0_f64, |acc, i| acc.max(mat[[i, i]].abs()));
920        let base = if max_diag.is_finite() && max_diag > 0.0 {
921            max_diag * 1e-10
922        } else {
923            1e-10
924        };
925        let mut ridge = 0.0_f64;
926        for _ in 0..=60 {
927            let mut ridged = mat.clone();
928            if ridge > 0.0 {
929                for i in 0..d {
930                    ridged[[i, i]] += ridge;
931                }
932            }
933            if let Ok(factor) =
934                factorize_symmetricwith_fallback(FaerArrayView::new(&ridged).as_ref(), Side::Lower)
935            {
936                let logdet = factor.logdet();
937                if logdet.is_finite() {
938                    let mut rhs = Array2::<f64>::eye(d);
939                    {
940                        let v = array2_to_matmut(&mut rhs);
941                        factor.solve_in_place(v);
942                    }
943                    if rhs.iter().all(|x| x.is_finite()) {
944                        let mut inv = Array2::<f64>::zeros((d, d));
945                        for i in 0..d {
946                            for j in 0..d {
947                                inv[[i, j]] = 0.5 * (rhs[[i, j]] + rhs[[j, i]]);
948                            }
949                        }
950                        return Ok((inv, logdet));
951                    }
952                }
953            }
954            ridge = if ridge > 0.0 { ridge * 4.0 } else { base };
955        }
956        Err(EstimationError::InvalidInput(format!(
957            "multinomial Firth fallback: {context} not invertible (max_diag={max_diag:.3e})"
958        )))
959    };
960
961    // SPD log-determinant only (no ridge): used by the backtracking line search to
962    // reject any candidate that pushes a fitted probability to the simplex
963    // boundary (where I loses positive-definiteness and the Firth term → −∞).
964    let spd_logdet = |mat: &Array2<f64>| -> Option<f64> {
965        factorize_symmetricwith_fallback(FaerArrayView::new(mat).as_ref(), Side::Lower)
966            .ok()
967            .map(|factor| factor.logdet())
968            .filter(|ld| ld.is_finite())
969    };
970
971    // Penalized Firth objective ℓ* (MAXIMIZED), given probabilities, β, and the
972    // precomputed log det I(β).
973    let objective = |probs: &Array2<f64>, beta: &Array2<f64>, logdet_info: f64| -> f64 {
974        let mut ll = 0.0_f64;
975        for row in 0..n_obs {
976            let w = weight(row);
977            if w == 0.0 {
978                continue;
979            }
980            for c in 0..k {
981                let ycn = y_one_hot[[row, c]];
982                if ycn != 0.0 {
983                    ll += w * ycn * probs[[row, c]].max(f64::MIN_POSITIVE).ln();
984                }
985            }
986        }
987        let mut pen = 0.0_f64;
988        for a in 0..m {
989            let la = lambdas[a];
990            if la != 0.0 {
991                let bcol = beta.column(a);
992                let sbeta = penalty.dot(&bcol);
993                pen += 0.5 * la * bcol.dot(&sbeta);
994            }
995        }
996        ll - pen + 0.5 * logdet_info
997    };
998
999    // Firth-adjusted penalized score U* (length d, block-ordered).
1000    let firth_score = |probs: &Array2<f64>, beta: &Array2<f64>, iinv: &Array2<f64>| -> Array1<f64> {
1001        let mut u = Array1::<f64>::zeros(d);
1002        let mut xn = vec![0.0_f64; p];
1003        let mut pa = vec![0.0_f64; m];
1004        let mut q = vec![0.0_f64; m * m];
1005        for row in 0..n_obs {
1006            let w = weight(row);
1007            if w == 0.0 {
1008                continue;
1009            }
1010            for i in 0..p {
1011                xn[i] = design[[row, i]];
1012            }
1013            for a in 0..m {
1014                pa[a] = probs[[row, a]];
1015            }
1016            // Data score: U[(a,i)] += w x_{ni} (y_{na} − p_{na}).
1017            for a in 0..m {
1018                let resid = y_one_hot[[row, a]] - pa[a];
1019                let ao = a * p;
1020                for i in 0..p {
1021                    u[ao + i] += w * xn[i] * resid;
1022                }
1023            }
1024            // Per-row information hat Q_{ab} = x_nᵀ [I⁻¹]_{(a,b)} x_n.
1025            for a in 0..m {
1026                let ao = a * p;
1027                for b in 0..m {
1028                    let bo = b * p;
1029                    let mut s = 0.0_f64;
1030                    for i in 0..p {
1031                        let xi = xn[i];
1032                        if xi == 0.0 {
1033                            continue;
1034                        }
1035                        let mut inner = 0.0_f64;
1036                        for j in 0..p {
1037                            inner += iinv[[ao + i, bo + j]] * xn[j];
1038                        }
1039                        s += xi * inner;
1040                    }
1041                    q[a * m + b] = s;
1042                }
1043            }
1044            // Firth adjustment: U[(c,s)] += ½ w x_{ns} h^c_n.
1045            for c in 0..m {
1046                let pc = pa[c];
1047                let mut h = 0.0_f64;
1048                for a in 0..m {
1049                    for b in 0..m {
1050                        let dab = if a == b { 1.0 } else { 0.0 };
1051                        let dac = if a == c { 1.0 } else { 0.0 };
1052                        let dbc = if b == c { 1.0 } else { 0.0 };
1053                        let g = dab * pa[a] * (dac - pc)
1054                            - pa[a] * pa[b] * (dac + dbc - 2.0 * pc);
1055                        h += g * q[a * m + b];
1056                    }
1057                }
1058                let co = c * p;
1059                for s in 0..p {
1060                    u[co + s] += 0.5 * w * h * xn[s];
1061                }
1062            }
1063        }
1064        // Smoothing penalty gradient: U[(a,i)] −= λ_a (S β_a)_i.
1065        for a in 0..m {
1066            let la = lambdas[a];
1067            if la != 0.0 {
1068                let sbeta = penalty.dot(&beta.column(a));
1069                let ao = a * p;
1070                for i in 0..p {
1071                    u[ao + i] -= la * sbeta[i];
1072                }
1073            }
1074        }
1075        u
1076    };
1077
1078    // Penalized Hessian H = I + blockdiag_a(λ_a S) (positive definite).
1079    let penalized_hessian = |info: &Array2<f64>| -> Array2<f64> {
1080        let mut h = info.clone();
1081        for a in 0..m {
1082            let la = lambdas[a];
1083            if la != 0.0 {
1084                let ao = a * p;
1085                for i in 0..p {
1086                    for j in 0..p {
1087                        h[[ao + i, ao + j]] += la * penalty[[i, j]];
1088                    }
1089                }
1090            }
1091        }
1092        h
1093    };
1094
1095    // Solve H Δ = U* for the SPD penalized Hessian, ridge-escalating only on
1096    // factorization failure. Same relative-jitter Cholesky-recovery ladder as
1097    // `invert_spd` above (see its comment for the rationale); the base jitter is
1098    // one decade tighter (`max_diag · 1e-12`) because the penalized Hessian
1099    // solved here is better conditioned than the Fisher information inverted
1100    // there, so a smaller perturbation suffices before escalating.
1101    let solve_spd = |mat: &Array2<f64>, rhs: &Array1<f64>| -> Result<Array1<f64>, EstimationError> {
1102        let max_diag = (0..d).fold(0.0_f64, |acc, i| acc.max(mat[[i, i]].abs()));
1103        let base = if max_diag.is_finite() && max_diag > 0.0 {
1104            max_diag * 1e-12
1105        } else {
1106            1e-12
1107        };
1108        let mut ridge = 0.0_f64;
1109        for _ in 0..=60 {
1110            let mut ridged = mat.clone();
1111            if ridge > 0.0 {
1112                for i in 0..d {
1113                    ridged[[i, i]] += ridge;
1114                }
1115            }
1116            if let Ok(factor) =
1117                factorize_symmetricwith_fallback(FaerArrayView::new(&ridged).as_ref(), Side::Lower)
1118            {
1119                let mut sol = rhs.clone();
1120                {
1121                    let v = array1_to_col_matmut(&mut sol);
1122                    factor.solve_in_place(v);
1123                }
1124                if sol.iter().all(|x| x.is_finite()) {
1125                    return Ok(sol);
1126                }
1127            }
1128            ridge = if ridge > 0.0 { ridge * 4.0 } else { base };
1129        }
1130        Err(EstimationError::InvalidInput(
1131            "multinomial Firth fallback: penalized Hessian solve failed".to_string(),
1132        ))
1133    };
1134
1135    // ─────────────────────────── Firth Newton loop ────────────────────────────
1136    let mut beta = Array2::<f64>::zeros((p, m));
1137    let mut converged = false;
1138    let mut iterations = 0_usize;
1139    for it in 0..max_iter {
1140        iterations = it + 1;
1141        let probs = probs_at(&beta);
1142        let info = assemble_info(&probs);
1143        let (iinv, logdet_info) = invert_spd(&info, "Fisher information")?;
1144        let u = firth_score(&probs, &beta, &iinv);
1145        let hmat = penalized_hessian(&info);
1146        let step_vec = solve_spd(&hmat, &u)?;
1147
1148        // Newton decrement ½ U*ᵀ H⁻¹ U* = ½ U*ᵀ Δ (≥ 0, scale-aware stop).
1149        let decrement = u.dot(&step_vec);
1150        if 0.5 * decrement.abs() < tol_eff {
1151            converged = true;
1152            break;
1153        }
1154
1155        // Δ as (P, M): delta[i, a] = step_vec[a·P + i].
1156        let mut delta = Array2::<f64>::zeros((p, m));
1157        for a in 0..m {
1158            let ao = a * p;
1159            for i in 0..p {
1160                delta[[i, a]] = step_vec[ao + i];
1161            }
1162        }
1163
1164        // Backtracking line search on ℓ* (ascent). Reject any candidate whose I is
1165        // not SPD (boundary), so the iterate stays interior.
1166        let o0 = objective(&probs, &beta, logdet_info);
1167        let mut step = 1.0_f64;
1168        let mut accepted = false;
1169        for _ in 0..60 {
1170            let cand = &beta + &(&delta * step);
1171            let cand_probs = probs_at(&cand);
1172            let cand_info = assemble_info(&cand_probs);
1173            if let Some(cand_logdet) = spd_logdet(&cand_info) {
1174                let o1 = objective(&cand_probs, &cand, cand_logdet);
1175                if o1 >= o0 - 1e-12 {
1176                    beta = cand;
1177                    accepted = true;
1178                    break;
1179                }
1180            }
1181            step *= 0.5;
1182        }
1183        if !accepted {
1184            // Backtracking exhausted 60 halvings without an admissible ascent
1185            // step. This is convergence ONLY if the iterate is already first-order
1186            // stationary; a line-search stall at a non-stationary point is a
1187            // solver failure and must be reported as such, never papered over as
1188            // `converged = true` (#2066 — SPEC: do not report a non-converged
1189            // iterate as success).
1190            //
1191            // The verdict is the loop's OWN stationarity test — the Newton
1192            // decrement `½·Uᵀ H⁻¹ U` against `tol_eff`, the same criterion the top
1193            // of the loop uses to break as converged. A true interior mode never
1194            // reaches this branch: an infinitesimal step (`step → 0`) leaves the
1195            // iterate SPD with `o1 ≈ o0`, so it is accepted; a numerically flat
1196            // mode is caught by the `max_step` test below after that accepted
1197            // tiny step. Reaching here therefore means Newton still sees a
1198            // meaningful ascent direction it cannot realize (boundary / near-
1199            // singular Fisher information), i.e. a genuine stall → not converged.
1200            converged = 0.5 * decrement.abs() < tol_eff;
1201            break;
1202        }
1203
1204        let max_step = step * delta.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
1205        let scale = 1.0 + beta.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
1206        if max_step < tol_eff * scale {
1207            converged = true;
1208            break;
1209        }
1210    }
1211
1212    // ─────────────────────────── final quantities ─────────────────────────────
1213    for (idx, &v) in beta.iter().enumerate() {
1214        if !v.is_finite() {
1215            crate::bail_invalid_estim!(
1216                "multinomial Firth fallback: non-finite coefficient at flat index {idx} = {v}"
1217            );
1218        }
1219    }
1220    let coefficients_active = beta;
1221
1222    let probs = probs_at(&coefficients_active);
1223    let info = assemble_info(&probs);
1224    // Laplace covariance H⁻¹ at the converged mode (block-ordered θ[a·P+i]).
1225    let hmat = penalized_hessian(&info);
1226    let coefficient_covariance = match invert_spd(&hmat, "penalized Hessian covariance") {
1227        Ok((cov, _)) => cov,
1228        Err(_) => Array2::<f64>::zeros((d, d)),
1229    };
1230
1231    let fitted_probabilities = probs;
1232    let mut log_likelihood = 0.0_f64;
1233    for row in 0..n_obs {
1234        let w = weight(row);
1235        for c in 0..k {
1236            let ycn = y_one_hot[[row, c]];
1237            if ycn != 0.0 {
1238                log_likelihood +=
1239                    w * ycn * fitted_probabilities[[row, c]].max(f64::MIN_POSITIVE).ln();
1240            }
1241        }
1242    }
1243
1244    let mut penalty_term = 0.0_f64;
1245    for a in 0..m {
1246        let beta_col = coefficients_active.column(a);
1247        let sbeta = penalty.dot(&beta_col);
1248        penalty_term += 0.5 * lambdas[a] * beta_col.dot(&sbeta);
1249    }
1250
1251    Ok(MultinomialFitOutputs {
1252        coefficients_active,
1253        fitted_probabilities,
1254        iterations,
1255        converged,
1256        penalized_neg_log_likelihood: -log_likelihood + penalty_term,
1257        deviance: -2.0 * log_likelihood,
1258        coefficient_covariance,
1259    })
1260}
1261
1262// ---------------------------------------------------------------------------
1263// Formula-driven multinomial pipeline
1264// ---------------------------------------------------------------------------
1265//
1266// Slice A of the multinomial integration: a single public entry that takes
1267// a parsed `EncodedDataset`, a Wilkinson-style formula, and a uniform initial
1268// smoothing parameter, then runs the full
1269//
1270//     parse → termspec → design (X, S blocks) → one-hot Y → REML λ-selection
1271//
1272// pipeline. `fit_penalized_multinomial_formula` drives the outer REML/LAML
1273// loop (via the custom-family path) to select an independent λ per (class,
1274// term); `init_lambda` (default 1.0) is only the warm-start seed for every
1275// block. The reference class is the last level of the categorical response
1276// column as recorded in the dataset schema.
1277
/// Saved-model payload for a multinomial fit driven by a Wilkinson formula.
///
/// This is what the FFI returns to Python. It carries everything the Python
/// `MultinomialModel.predict` path needs to evaluate `softmax(X_new · β)` on
/// fresh data using the *training* basis / penalty structure (no refit on
/// predict, no re-derivation of class levels).
///
/// Backward compatibility: every field tagged `#[serde(default)]` may be
/// absent from payloads written before that field existed; such payloads
/// deserialize it to its `Default` (`None` / empty `Vec`). New optional fields
/// should follow the same pattern. NOTE(review): if the persisted encoding is
/// positional (non-self-describing), field order is part of the on-disk
/// layout — confirm which format the FFI uses before reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSavedModel {
    /// The training formula, verbatim. Stored so Python's `summary()` and
    /// any round-trip persistence path can echo what was fit.
    pub formula: String,
    /// Names of the *training* response levels in canonical order. The last
    /// entry is the reference class (η = 0); the first `K - 1` carry the
    /// active linear-predictor blocks. Class permutations are forbidden:
    /// this list is fixed at fit time and predictions emit columns in the
    /// same order.
    pub class_levels: Vec<String>,
    /// Index of the reference class within `class_levels` — currently always
    /// `class_levels.len() - 1`, exposed as a field so future "user-pinned
    /// reference" gauges (e.g. `family='multinomial', reference='setosa'`)
    /// can land without changing the on-disk shape. Note that
    /// `predict_probabilities` does not read this field; it relies on the
    /// reference being the last class.
    pub reference_class_index: usize,
    /// Resolved term-collection spec used to build `X` at fit time. Replayed
    /// on predict via [`gam_terms::smooth::build_term_collection_design`].
    pub resolved_termspec: TermCollectionSpec,
    /// Active-class coefficient block, shape `(P, K-1)`. Column `a` is the
    /// coefficient vector for class `class_levels[a]`. Stored flat in
    /// row-major order to keep the serde payload self-describing, so entry
    /// `(i, a)` lives at index `i * (K-1) + a`. The length must equal
    /// `p_per_class * n_active_classes`.
    pub coefficients_flat: Vec<f64>,
    /// `P` — coefficient count per active class. Matches the column count of
    /// the design matrix the saved `resolved_termspec` produces.
    pub p_per_class: usize,
    /// Number of active classes (`K - 1`).
    pub n_active_classes: usize,
    /// Original training column headers, in dataset-column order. Needed at
    /// predict time so the FFI can align a fresh `Dataset` to the training
    /// schema before evaluating the basis.
    pub training_headers: Vec<String>,
    /// REML/LAML-selected smoothing parameters, one per `(active class, penalty
    /// component)`, flattened in block-major order: all of class 0's λ, then
    /// class 1's, and so on. Per-term penalties (#561) mean each active class
    /// block selects an *independent* λ for every penalty component. A single
    /// smooth term may contribute more than one component: the Marra–Wood
    /// double penalty adds a null-space shrinkage λ next to the wiggliness λ,
    /// and tensor / operator smooths add one per margin / operator. The length
    /// is therefore `Σ_a lambdas_per_block[a]`. That equals
    /// `(K − 1) · #smooth terms` only when every term has exactly one penalty
    /// component. Use [`MultinomialSavedModel::lambdas_per_block`] to segment
    /// it by class. An unpenalized model (no smooth terms) yields an empty
    /// vector.
    pub lambdas: Vec<f64>,
    /// Number of smoothing parameters (penalty components — possibly more than
    /// the number of smooth terms) in each active class block, parallel to
    /// `class_levels[0..K-1]`. Segments the flat `lambdas` vector: class `a`'s
    /// λ are `lambdas[Σ_{b<a} lambdas_per_block[b] ..][..lambdas_per_block[a]]`.
    /// Every entry is identical in the shared-design architecture (all classes
    /// share the same term structure), but it is stored explicitly so
    /// consumers never have to assume that.
    pub lambdas_per_block: Vec<usize>,
    /// Newton iterations executed; recorded for the summary report.
    pub iterations: usize,
    /// `true` if the inner Newton solver hit the relative-step tolerance.
    pub converged: bool,
    /// Penalized negative log-likelihood at the returned `β̂`.
    pub penalized_neg_log_likelihood: f64,
    /// Unpenalized deviance `−2 log L(β̂)`.
    pub deviance: f64,
    /// Per-active-class effective degrees of freedom (hat-matrix trace),
    /// length `K - 1`. Populated when the REML driver reports an
    /// inference block; falls back to `None` for the legacy fixed-λ path.
    #[serde(default)]
    pub edf_per_class: Option<Vec<f64>>,
    /// Per-PENALTY effective degrees of freedom, one entry per smoothing
    /// parameter (length `== lambdas.len()`), aligned block-major with the flat
    /// [`Self::lambdas`] / [`Self::lambdas_per_block`] layout. Each entry is the
    /// penalty-block trace EDF `rank(S_k) − λ_k·tr(H⁻¹ S_k)`, clamped to
    /// `[0, rank(S_k)]`. This is the per-(class, term, penalty) resolution that
    /// the per-class [`Self::edf_per_class`] SUM deliberately hides. Only the
    /// per-penalty vector reveals whether an individual smooth collapsed onto
    /// its polynomial null space (its wiggliness λ driven to the λ-cap); a
    /// per-class total cannot show that. Populated whenever the REML driver
    /// reports an inference block; `None` on the legacy fixed-λ path or when
    /// the trace channel is mis-shaped. Unlike `edf_per_class`, the entries do
    /// NOT sum to the model EDF when several penalties share one coefficient
    /// range (a double-penalty smooth has `Σ_k rank(S_k) > p_per_class`).
    #[serde(default)]
    pub edf_per_penalty: Option<Vec<f64>>,
    /// Joint posterior coefficient covariance `H⁻¹` (#1101), block-ordered to
    /// match the stacked active-class coefficient vector `β = [β_0; …; β_{K-2}]`
    /// (class `a`'s `P` coefficients occupy rows/cols `a·P .. (a+1)·P`). This is
    /// the Laplace covariance the REML driver already computes from the factored
    /// penalized Hessian. Storing it gives the predict path delta-method
    /// per-class probability standard errors and the summary its Wald
    /// smooth-term tests. Flattened row-major over the `(P·M)×(P·M)` matrix.
    /// `None` for a model fitted before covariance was surfaced.
    #[serde(default)]
    pub coefficient_covariance_flat: Option<Vec<f64>>,
    /// Joint coefficient-space influence matrix `F = H⁻¹ X'WX` (#1101),
    /// block-ordered identically to [`Self::coefficient_covariance_flat`].
    /// Its per-term diagonal block trace is the term's effective degrees of
    /// freedom, and its `tr(F_jj)²/tr(F_jj²)` is the Wood reference d.f.; both
    /// feed the rank-truncated Wald smooth-term test in `summary()`. Flattened
    /// row-major over the `(P·M)×(P·M)` matrix. `None` when unavailable.
    #[serde(default)]
    pub coefficient_influence_flat: Option<Vec<f64>>,
    /// Per-(active class, smooth term) coefficient column range and unpenalized
    /// nullspace dimension within the `P`-wide class block (#1101). Parallel to
    /// the smooth terms the design produced; replicated across classes by the
    /// shared-design architecture. Drives the Wald smooth-term table in
    /// `summary()`. Empty for a wholly parametric (no-smooth) model.
    #[serde(default)]
    pub smooth_term_spans: Vec<MultinomialSmoothTermSpan>,
    /// One descriptive label per *penalty component* within a single active-class
    /// block, parallel to that block's λ slice (i.e. length
    /// `lambdas_per_block[0]`). The Marra–Wood double penalty (and tensor /
    /// operator smooths) emit **more than one** penalty component — hence more
    /// than one λ — per smooth term, so this is NOT 1:1 with
    /// [`Self::smooth_term_spans`]: a single `s(x)` term contributes a primary
    /// wiggliness λ labelled `s(x)` and a null-space shrinkage λ labelled
    /// `s(x) [null space]`. The summary renderer pairs `lambdas` with these
    /// labels component-for-component so no λ is ever dropped (#1544). Built from
    /// the per-component term name + penalty role at fit time; empty for a
    /// wholly parametric model or a model serialized before this field existed.
    #[serde(default)]
    pub lambda_labels: Vec<String>,
}
1398
/// One smooth term's coefficient span within a class block, plus its
/// unpenalized nullspace dimension and a display label (#1101). The Wald
/// smooth-significance test in `summary()` slices the joint covariance /
/// influence at `a·P + col_start .. a·P + col_end` for active class `a`.
///
/// The span is expected to satisfy `col_start <= col_end <= P`, where `P` is
/// `MultinomialSavedModel::p_per_class`. Values come from a deserialized
/// payload, so consumers should validate that invariant rather than assume
/// it; `smooth_significance` skips spans with `col_end > P`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultinomialSmoothTermSpan {
    /// Human-readable term label (the smooth's formula token), for the table.
    pub label: String,
    /// Start column (inclusive) of the term within the per-class `P`-wide
    /// coefficient block.
    pub col_start: usize,
    /// End column (exclusive) of the term within the per-class block.
    pub col_end: usize,
    /// Leading unpenalized (polynomial nullspace) dimension within the term.
    pub nullspace_dim: usize,
}
1414
1415/// Descriptive label for one penalty *component* (one λ) within a class block,
1416/// for the `summary()` per-class λ rollup (#1544). A smooth term can emit
1417/// several penalty components — the Marra–Wood double penalty splits `s(x)`
1418/// into a primary wiggliness penalty and a null-space shrinkage penalty, and
1419/// tensor / operator smooths emit a component per margin / differential
1420/// operator — each with its own independently-selected λ. The label is the
1421/// term name (from `PenaltyBlockInfo::termname`) plus a role suffix derived
1422/// from the penalty's [`PenaltySource`], so each λ in the summary names both
1423/// the term it smooths and the role it plays. `pen_idx` is the global penalty
1424/// index, used only as a last-resort fallback label.
1425fn penalty_component_label(info: Option<&PenaltyBlockInfo>, pen_idx: usize) -> String {
1426    use gam_terms::basis::PenaltySource;
1427    let term = info
1428        .and_then(|i| i.termname.clone())
1429        .unwrap_or_else(|| format!("s{pen_idx}"));
1430    let role = match info.map(|i| &i.penalty.source) {
1431        // The primary wiggliness penalty is the term's "main" λ; show the bare
1432        // term name so the common single-penalty case reads cleanly.
1433        Some(PenaltySource::Primary) | None => None,
1434        Some(PenaltySource::DoublePenaltyNullspace) => Some("null space".to_string()),
1435        Some(PenaltySource::OperatorMass) => Some("mass".to_string()),
1436        Some(PenaltySource::OperatorTension) => Some("tension".to_string()),
1437        Some(PenaltySource::OperatorStiffness) => Some("stiffness".to_string()),
1438        Some(PenaltySource::OperatorRelevance { axis }) => Some(format!("axis {axis}")),
1439        Some(PenaltySource::TensorMarginal { dim }) => Some(format!("margin {dim}")),
1440        Some(PenaltySource::TensorSeparable { penalized_margins }) => {
1441            Some(format!("separable {penalized_margins:?}"))
1442        }
1443        Some(PenaltySource::TensorGlobalRidge) => Some("ridge".to_string()),
1444        Some(PenaltySource::Other(s)) => Some(s.clone()),
1445    };
1446    match role {
1447        Some(role) => format!("{term} [{role}]"),
1448        None => term,
1449    }
1450}
1451
1452impl MultinomialSavedModel {
1453    /// Active-class coefficient block as an `(P, K-1)` `ndarray` view.
1454    pub fn coefficients_active(&self) -> Array2<f64> {
1455        Array2::from_shape_vec(
1456            (self.p_per_class, self.n_active_classes),
1457            self.coefficients_flat.clone(),
1458        )
1459        .expect(
1460            "MultinomialSavedModel.coefficients_flat length must equal p_per_class * n_active_classes",
1461        )
1462    }
1463
1464    /// Evaluate `softmax(X · β)` at fresh data rows. `X_new` must have
1465    /// `self.p_per_class` columns (i.e. it was built from the same
1466    /// `resolved_termspec` as fit time). Returns an `(N_new, K)` matrix
1467    /// with rows summing to 1; column order matches `self.class_levels`.
1468    pub fn predict_probabilities(&self, x_new: ArrayView2<'_, f64>) -> Array2<f64> {
1469        let n_new = x_new.nrows();
1470        let p = self.p_per_class;
1471        let m = self.n_active_classes;
1472        let k = m + 1;
1473        assert_eq!(
1474            x_new.ncols(),
1475            p,
1476            "MultinomialSavedModel.predict_probabilities: X has {} cols, expected {p}",
1477            x_new.ncols()
1478        );
1479        let beta = self.coefficients_active();
1480        let mut probs = Array2::<f64>::zeros((n_new, k));
1481        let mut eta_active = vec![0.0_f64; m];
1482        let mut row_probs = vec![0.0_f64; k];
1483        for row in 0..n_new {
1484            for a in 0..m {
1485                let mut v = 0.0_f64;
1486                for i in 0..p {
1487                    v += x_new[[row, i]] * beta[[i, a]];
1488                }
1489                eta_active[a] = v;
1490            }
1491            MultinomialLogitLikelihood::softmax_with_baseline(&eta_active, &mut row_probs);
1492            for c in 0..k {
1493                probs[[row, c]] = row_probs[c];
1494            }
1495        }
1496        probs
1497    }
1498
1499    /// Reconstruct the joint posterior covariance `H⁻¹` as a `(P·M)×(P·M)`
1500    /// `ndarray`, block-ordered to match the stacked coefficient vector
1501    /// `θ[a·P + i] = β[i, a]` (#1101). `None` when the model was fitted before
1502    /// covariance was surfaced (legacy payload).
1503    pub fn coefficient_covariance(&self) -> Option<Array2<f64>> {
1504        let d = self.p_per_class.checked_mul(self.n_active_classes)?;
1505        let flat = self.coefficient_covariance_flat.as_ref()?;
1506        Array2::from_shape_vec((d, d), flat.clone()).ok()
1507    }
1508
1509    /// Reconstruct the joint influence matrix `F = H⁻¹ X'WX` as a
1510    /// `(P·M)×(P·M)` `ndarray`, block-ordered like
1511    /// [`Self::coefficient_covariance`] (#1101). `None` when unavailable.
1512    pub fn coefficient_influence(&self) -> Option<Array2<f64>> {
1513        let d = self.p_per_class.checked_mul(self.n_active_classes)?;
1514        let flat = self.coefficient_influence_flat.as_ref()?;
1515        Array2::from_shape_vec((d, d), flat.clone()).ok()
1516    }
1517
1518    /// Evaluate `softmax(X·β)` AND its delta-method per-class probability
1519    /// standard error at fresh data rows (#1101).
1520    ///
1521    /// For active classes `b ∈ 0..M` the softmax Jacobian is
1522    /// `∂p_c/∂η_b = p_c (δ_{cb} − p_b)`, and `∂η_b/∂β[i,a] = X[i]·δ_{ab}`, so the
1523    /// gradient of class-`c` probability w.r.t. the block-ordered coefficient
1524    /// vector is `g_c[a·P + i] = X[i]·p_c (δ_{ca} − p_a)` (active `a`; the
1525    /// reference class `M` contributes `p_c(0 − p_a)` via every active block).
1526    /// The delta-method variance is `Var(p_c) = g_cᵀ Σ g_c` with `Σ = H⁻¹` the
1527    /// joint posterior covariance, and `SE(p_c) = √Var(p_c)`. Returns
1528    /// `(probs (N,K), prob_se (N,K))`; `prob_se` is `None` when no covariance is
1529    /// stored. The simplex `[0,1]` clamp is applied by the interval consumer, not
1530    /// here (the SE itself is unclamped).
1531    pub fn predict_probabilities_with_se(
1532        &self,
1533        x_new: ArrayView2<'_, f64>,
1534    ) -> (Array2<f64>, Option<Array2<f64>>) {
1535        let probs = self.predict_probabilities(x_new);
1536        let Some(cov) = self.coefficient_covariance() else {
1537            return (probs, None);
1538        };
1539        let n_new = x_new.nrows();
1540        let p = self.p_per_class;
1541        let m = self.n_active_classes;
1542        let k = m + 1;
1543        let d = p * m;
1544        let mut prob_se = Array2::<f64>::zeros((n_new, k));
1545        let mut grad = vec![0.0_f64; d];
1546        for row in 0..n_new {
1547            let prow = probs.row(row);
1548            for c in 0..k {
1549                let pc = prow[c];
1550                // g_c[a·P + i] = X[i] · p_c · (δ_{ca} − p_a), a active.
1551                for a in 0..m {
1552                    let pa = prow[a];
1553                    let factor = pc * (if c == a { 1.0 - pa } else { -pa });
1554                    let base = a * p;
1555                    for i in 0..p {
1556                        grad[base + i] = x_new[[row, i]] * factor;
1557                    }
1558                }
1559                // Var = gᵀ Σ g.
1560                let mut var = 0.0_f64;
1561                for r in 0..d {
1562                    let gr = grad[r];
1563                    if gr == 0.0 {
1564                        continue;
1565                    }
1566                    let mut acc = 0.0_f64;
1567                    for s in 0..d {
1568                        acc += cov[[r, s]] * grad[s];
1569                    }
1570                    var += gr * acc;
1571                }
1572                prob_se[[row, c]] = var.max(0.0).sqrt();
1573            }
1574        }
1575        (probs, Some(prob_se))
1576    }
1577
1578    /// Wood (2013) rank-truncated Wald smooth-significance test per
1579    /// `(active class, smooth term)` (#1101), reusing the exact scalar-summary
1580    /// kernel [`gam_terms::inference::smooth_test::wood_smooth_test`]. For active
1581    /// class `a` and term span `[c0, c1)` within the class block, the global
1582    /// coefficient range is `a·P + c0 .. a·P + c1`; the joint covariance and
1583    /// influence are sliced there. The term EDF is the influence-block trace
1584    /// `tr(F_jj)` (when present) and the reference d.f. uses `tr(F_jj)²/tr(F_jj²)`,
1585    /// exactly as the scalar path. The multinomial softmax is a known-dispersion
1586    /// family, so the χ²_{ref_df} branch applies. Returns one row per
1587    /// `(class label, term label, edf, ref_df, statistic, p_value)`; empty when
1588    /// no covariance/smooth terms are available.
1589    pub fn smooth_significance(&self) -> Vec<MultinomialSmoothSignificance> {
1590        let mut out = Vec::new();
1591        let p = self.p_per_class;
1592        let m = self.n_active_classes;
1593        let Some(cov) = self.coefficient_covariance() else {
1594            return out;
1595        };
1596        if self.smooth_term_spans.is_empty() {
1597            return out;
1598        }
1599        let beta = self.coefficients_active();
1600        // Block-ordered θ = [β_0; …; β_{M-1}], θ[a·P + i] = β[i, a].
1601        let d = p * m;
1602        let mut theta = Array1::<f64>::zeros(d);
1603        for a in 0..m {
1604            for i in 0..p {
1605                theta[a * p + i] = beta[[i, a]];
1606            }
1607        }
1608        let influence = self.coefficient_influence();
1609        for a in 0..m {
1610            let class_label = self
1611                .class_levels
1612                .get(a)
1613                .cloned()
1614                .unwrap_or_else(|| format!("class{a}"));
1615            let base = a * p;
1616            for span in &self.smooth_term_spans {
1617                if span.col_end > p {
1618                    continue;
1619                }
1620                let start = base + span.col_start;
1621                let end = base + span.col_end;
1622                // Term EDF = tr(F_jj); without an influence matrix fall back to
1623                // the block coefficient count (full-rank Wald on the span).
1624                let block_len = (span.col_end - span.col_start) as f64;
1625                let edf = influence
1626                    .as_ref()
1627                    .map(|f| (start..end).map(|i| f[[i, i]]).sum::<f64>())
1628                    .filter(|v| v.is_finite() && *v > 0.0)
1629                    .unwrap_or(block_len);
1630                let result = gam_terms::inference::smooth_test::wood_smooth_test(
1631                    gam_terms::inference::smooth_test::SmoothTestInput {
1632                        beta: theta.view(),
1633                        covariance: &cov,
1634                        influence_matrix: influence.as_ref(),
1635                        coeff_range: start..end,
1636                        edf,
1637                        nullspace_dim: span.nullspace_dim,
1638                        residual_df: f64::INFINITY,
1639                        scale: gam_terms::inference::smooth_test::SmoothTestScale::Known,
1640                    },
1641                );
1642                if let Some(res) = result {
1643                    out.push(MultinomialSmoothSignificance {
1644                        class_label: class_label.clone(),
1645                        term_label: span.label.clone(),
1646                        edf,
1647                        ref_df: res.ref_df,
1648                        statistic: res.statistic,
1649                        p_value: res.p_value,
1650                    });
1651                }
1652            }
1653        }
1654        out
1655    }
1656
1657    /// Draw `n_draws` posterior-predictive replicate class assignments at fresh
1658    /// rows (#1101). Each draw independently samples every row's class from
1659    /// `Categorical(p_row)` with `p = softmax(X·β̂)` — the plug-in predictive
1660    /// distribution, i.e. the multinomial observation noise wrapped around the
1661    /// fitted mean (the categorical analogue of the scalar families'
1662    /// `sample_replicates`). The returned `(n_draws, N)` matrix holds class
1663    /// INDICES `0..K`, aligned to [`Self::class_levels`]. The draw stream is a
1664    /// `StdRng` seeded by `seed`, so `(x_new, n_draws, seed)` reproduce
1665    /// bit-identically — the engine for posterior-predictive checks and
1666    /// simulation-based calibration. `x_new` must have `self.p_per_class`
1667    /// columns (built from the same `resolved_termspec` as fit time).
1668    pub fn sample_replicate_classes(
1669        &self,
1670        x_new: ArrayView2<'_, f64>,
1671        n_draws: usize,
1672        seed: u64,
1673    ) -> Array2<u32> {
1674        use rand::{RngExt, SeedableRng};
1675        let probs = self.predict_probabilities(x_new);
1676        let n = probs.nrows();
1677        let k = probs.ncols();
1678        let mut out = Array2::<u32>::zeros((n_draws, n));
1679        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
1680        for d in 0..n_draws {
1681            for row in 0..n {
1682                let u: f64 = rng.random::<f64>();
1683                // Inverse-CDF categorical draw over the K simplex weights.
1684                let mut acc = 0.0_f64;
1685                let mut chosen = k - 1; // numerical fallback = reference class
1686                for c in 0..k {
1687                    acc += probs[[row, c]];
1688                    if u < acc {
1689                        chosen = c;
1690                        break;
1691                    }
1692                }
1693                out[[d, row]] = chosen as u32;
1694            }
1695        }
1696        out
1697    }
1698}
1699
/// One row of the multinomial smooth-significance table (#1101): the Wood
/// rank-truncated Wald test for one `(active class, smooth term)` pair.
/// Produced by `MultinomialSavedModel::smooth_significance`.
#[derive(Debug, Clone)]
pub struct MultinomialSmoothSignificance {
    /// Name of the active class (`class_levels[a]`), or `class{a}` when the
    /// saved level list is shorter than the number of active classes.
    pub class_label: String,
    /// Display label of the smooth term (`MultinomialSmoothTermSpan::label`).
    pub term_label: String,
    /// Term effective degrees of freedom: the influence-block trace
    /// `tr(F_jj)` when available, finite and positive; otherwise the span's
    /// coefficient count.
    pub edf: f64,
    /// Reference degrees of freedom reported by the Wood smooth test.
    pub ref_df: f64,
    /// Wald test statistic reported by the Wood smooth test.
    pub statistic: f64,
    /// p-value reported by the Wood smooth test (known-scale branch).
    pub p_value: f64,
}
1711
1712/// One-hot-encode the categorical response column and return both the
1713/// encoding and the captured level names. The level order matches the order
1714/// recorded in the dataset schema, which is the canonical (lexicographically
1715/// sorted) factor order produced by inferred-schema construction (#1319) — so
1716/// it is a deterministic function of the label *set*, independent of training
1717/// row order (no silent class permutation under a row shuffle), and matches the
1718/// R `factor()` / pandas `Categorical` convention.
1719fn one_hot_categorical_response(
1720    data: &EncodedDataset,
1721    y_col: usize,
1722    response_name: &str,
1723) -> Result<(Array2<f64>, Vec<String>), EstimationError> {
1724    let levels: Vec<String> = data
1725        .schema
1726        .columns
1727        .get(y_col)
1728        .map(|sc| sc.levels.clone())
1729        .unwrap_or_default();
1730    if levels.len() < 2 {
1731        crate::bail_invalid_estim!(
1732            "multinomial response '{response_name}' must have at least 2 categorical levels (got {})",
1733            levels.len()
1734        );
1735    }
1736    let n = data.values.nrows();
1737    let k = levels.len();
1738    let mut y_one_hot = Array2::<f64>::zeros((n, k));
1739    for row in 0..n {
1740        let encoded = data.values[[row, y_col]];
1741        if !encoded.is_finite() {
1742            crate::bail_invalid_estim!(
1743                "multinomial response '{response_name}' row {row} is non-finite ({encoded})"
1744            );
1745        }
1746        let class_idx = encoded.round() as i64;
1747        if class_idx < 0 || (class_idx as usize) >= k {
1748            crate::bail_invalid_estim!(
1749                "multinomial response '{response_name}' row {row} encoded as {encoded} \
1750                 is outside the level range 0..{k}"
1751            );
1752        }
1753        y_one_hot[[row, class_idx as usize]] = 1.0;
1754    }
1755    Ok((y_one_hot, levels))
1756}
1757
1758/// Build `(TermCollectionSpec, TermCollectionDesign)` from a formula against
1759/// a categorical-response dataset. Mirrors the early scaffolding inside
1760/// `materialize_standard` (response role resolution, geometry-aware spec
1761/// build) without touching the scalar-family resolution path — multinomial
1762/// owns its own response kind check.
1763fn build_formula_design_for_multinomial(
1764    formula: &str,
1765    data: &EncodedDataset,
1766    config: &FitConfig,
1767) -> Result<
1768    (
1769        TermCollectionSpec,
1770        TermCollectionDesign,
1771        usize,
1772        String,
1773        ResponseColumnKind,
1774    ),
1775    EstimationError,
1776> {
1777    let parsed = parse_formula(formula).map_err(|err| {
1778        EstimationError::InvalidInput(format!(
1779            "multinomial fit: failed to parse formula {formula:?}: {err}"
1780        ))
1781    })?;
1782    let col_map = data.column_map();
1783    let y_col = resolve_role_col(&col_map, &parsed.response, "response")
1784        .map_err(|err| EstimationError::InvalidInput(format!("multinomial fit: {err}")))?;
1785    let y_kind = crate::fit_orchestration::response_column_kind(data, y_col);
1786    let policy = resolved_resource_policy(config, data, ProblemHints::default());
1787    let mut inference_notes: Vec<String> = Vec::new();
1788    let spec = build_termspec_with_geometry_and_overrides(
1789        &parsed.terms,
1790        data,
1791        &col_map,
1792        &mut inference_notes,
1793        config.scale_dimensions,
1794        &policy,
1795        config.smooth_overrides.as_ref(),
1796    )
1797    .map_err(|err| {
1798        EstimationError::InvalidInput(format!("multinomial fit: build termspec: {err}"))
1799    })?;
1800    let design = build_term_collection_design(data.values.view(), &spec).map_err(|err| {
1801        EstimationError::InvalidInput(format!("multinomial fit: build design: {err}"))
1802    })?;
1803    Ok((spec, design, y_col, parsed.response, y_kind))
1804}
1805
1806fn scale_multinomial_formula_penalty(penalty: PenaltyMatrix, scale: f64) -> PenaltyMatrix {
1807    match penalty {
1808        PenaltyMatrix::Dense(matrix) => PenaltyMatrix::Dense(matrix.mapv(|v| v * scale)),
1809        PenaltyMatrix::KroneckerFactored { left, right } => PenaltyMatrix::KroneckerFactored {
1810            left: left.mapv(|v| v * scale),
1811            right,
1812        },
1813        PenaltyMatrix::Blockwise {
1814            local,
1815            col_range,
1816            total_dim,
1817        } => PenaltyMatrix::Blockwise {
1818            local: local.mapv(|v| v * scale),
1819            col_range,
1820            total_dim,
1821        },
1822        PenaltyMatrix::Labeled { label, inner } => PenaltyMatrix::Labeled {
1823            label,
1824            inner: Box::new(scale_multinomial_formula_penalty(*inner, scale)),
1825        },
1826        PenaltyMatrix::Fixed { log_lambda, inner } => PenaltyMatrix::Fixed {
1827            log_lambda,
1828            inner: Box::new(scale_multinomial_formula_penalty(*inner, scale)),
1829        },
1830    }
1831}
1832
1833/// Build a warm-started copy of `blocks` whose per-block `initial_log_lambdas`
1834/// are seeded from a previously-selected flat `log_lambdas` vector (#1082).
1835///
1836/// The flat `log_lambdas` returned by [`fit_custom_family_with_rho_prior`]
1837/// concatenates each block's penalty log-λ in block order — the same order
1838/// `build_block_specs()` emits the blocks and the same per-block penalty order
1839/// the spec carries — so it splits back across blocks by each block's penalty
1840/// count. Warm-starting the OUTER ρ-search from a prior iterate changes only the
1841/// optimizer's starting point, never the penalized objective or its optimum, so
1842/// the converged fit is identical; it just resumes near the prior iterate
1843/// instead of restarting from the cold `init_lambda` seed.
1844///
1845/// Returns `None` (caller falls back to the cold blocks) if the flat vector does
1846/// not have exactly one entry per penalty across all blocks, or carries a
1847/// non-finite value — i.e. anything that would make the seed unsafe.
1848fn warm_start_blocks_from_log_lambdas(
1849    blocks: &[crate::custom_family::ParameterBlockSpec],
1850    log_lambdas: &[f64],
1851) -> Option<Vec<crate::custom_family::ParameterBlockSpec>> {
1852    let total: usize = blocks.iter().map(|b| b.initial_log_lambdas.len()).sum();
1853    if total == 0 || log_lambdas.len() != total {
1854        return None;
1855    }
1856    if log_lambdas.iter().any(|v| !v.is_finite()) {
1857        return None;
1858    }
1859    let mut warm = blocks.to_vec();
1860    let mut offset = 0usize;
1861    for block in warm.iter_mut() {
1862        let k = block.initial_log_lambdas.len();
1863        for slot in 0..k {
1864            block.initial_log_lambdas[slot] = log_lambdas[offset + slot];
1865        }
1866        offset += k;
1867    }
1868    Some(warm)
1869}
1870
1871/// Top-level formula-driven multinomial fit.
1872///
1873/// Routes through [`fit_custom_family_with_rho_prior`] so the per-active-class
1874/// smoothing parameters `λ_a` (one per class block, shared-penalty
1875/// architecture) are selected by the outer REML/LAML loop rather than pinned
1876/// by the caller. `init_lambda` survives as a warm-start hint that seeds
1877/// every block's `initial_log_lambdas`. `max_iter` / `tol` drive the OUTER
1878/// REML/LAML smoothing-parameter search (`outer_max_iter` / `outer_tol`); the
1879/// inner joint-Newton solve runs on the framework's principled production cycle
1880/// budget at the default KKT tolerance so an ill-conditioned, LM-damped
1881/// near-simplex-boundary solve can certify a stationary point instead of being
1882/// declared non-converged after only `max_iter` cycles (#715).
1883///
1884/// The Jeffreys/Firth proper prior is engaged CONDITIONALLY: attempt 1 runs
1885/// the unbiased penalized-REML criterion; only on separation evidence (a failed
1886/// solve or a non-finite logit; see [`multinomial_formula_separation_evidence`])
1887/// is the fit re-solved once with the full-span Firth prior armed, which bounds
1888/// the penalty-null directions no smoothing parameter can (`S v = 0` ⇒
1889/// `(H + S_λ) v = H v → 0` when the softmax likelihood has no finite mode).
1890///
1891/// The categorical response column is recognised via the dataset schema
1892/// (`ColumnKindTag::Categorical`); reference class = last level. Returns a
1893/// [`MultinomialSavedModel`] that can be serialised to bytes for the Python
1894/// wrapper or used in-process for `predict_probabilities`.
1895pub fn fit_penalized_multinomial_formula(
1896    data: &EncodedDataset,
1897    formula: &str,
1898    config: &FitConfig,
1899    init_lambda: f64,
1900    max_iter: usize,
1901    tol: f64,
1902) -> Result<MultinomialSavedModel, EstimationError> {
1903    if !(init_lambda.is_finite() && init_lambda > 0.0) {
1904        crate::bail_invalid_estim!(
1905            "multinomial fit: init_lambda must be finite and > 0 (got {init_lambda})"
1906        );
1907    }
1908    let (raw_spec, design, y_col, response_name, y_kind) =
1909        build_formula_design_for_multinomial(formula, data, config)?;
1910    // Freeze the data-derived basis state (B-spline knot vectors, by-factor
1911    // level sets, spatial centers, joint-null rotations, residualization
1912    // charts) from the fit design back onto the spec. The raw geometry spec
1913    // records only *which* columns and *what kind* of basis each smooth uses;
1914    // the actual column count and basis evaluation depend on quantities the
1915    // builder derives from the training data (knot placement, the distinct
1916    // by-factor levels, etc.). Saving the raw spec made predict re-derive those
1917    // from the (smaller, differently-distributed) predict frame, so the rebuilt
1918    // design had a different column count than the fitted one — the panic
1919    // "predict design has 42 cols, saved model expects 191" for an `s(x,
1920    // by=group)` smooth-by-factor model. Every other family's persistence path
1921    // freezes the spec the same way (see `freeze_term_collection_from_design`
1922    // call sites in `main_parts`); multinomial was the lone exception.
1923    let spec = freeze_term_collection_from_design(&raw_spec, &design)?;
1924    let class_levels = match y_kind {
1925        ResponseColumnKind::Categorical { levels } => levels,
1926        ResponseColumnKind::Binary => vec!["0".to_string(), "1".to_string()],
1927        ResponseColumnKind::Numeric => {
1928            crate::bail_invalid_estim!(
1929                "multinomial fit: response '{response_name}' is numeric, not categorical; \
1930                 use family='gaussian'/'binomial'/... or convert the column to a categorical type"
1931            );
1932        }
1933    };
1934    if data.column_kinds.get(y_col) == Some(&ColumnKindTag::Binary) {
1935        // Promote to a 2-level categorical for the multinomial driver; the
1936        // caller explicitly asked for multinomial, so we route through the
1937        // K-1 = 1 active-class softmax (equivalent math to logistic).
1938    } else if data.column_kinds.get(y_col) != Some(&ColumnKindTag::Categorical) {
1939        crate::bail_invalid_estim!(
1940            "multinomial fit: response '{response_name}' must be a categorical column \
1941             (got column kind {:?})",
1942            data.column_kinds.get(y_col)
1943        );
1944    }
1945    let (y_one_hot, _) = one_hot_categorical_response(data, y_col, &response_name)?;
1946    // Build the global X dense (the design is a DesignMatrix abstraction).
1947    let mut x_dense = design
1948        .design
1949        .try_to_dense_by_chunks("multinomial fit design")
1950        .map_err(EstimationError::InvalidInput)?;
1951
1952    // ── #715 real-data conditioning: standardize unpenalized parametric
1953    // columns. Raw-unit linear covariates (penguins `body_mass_g` ~ 4e3 grams)
1954    // inflate the joint Newton information by the squared column scale (a κ(H)
1955    // multiplier of ~s² ≈ 1e7 against the intercept), which is what turns the
1956    // near-separable LM-damped inner solve into a geometric grind that
1957    // exhausts its cycle budgets — the adapter-level face of "all REML startup
1958    // seeds rejected". Because these columns are UNPENALIZED (parametric terms
1959    // carry no default ridge, #749), the affine reparameterization
1960    // `x_j ↦ (x_j − m_j)/s_j` is EXACT for the whole criterion: the optimized
1961    // REML/LAML objective, the fitted η, the selected λ, and the separation
1962    // diagnostics are all invariant — only the conditioning of `H` changes.
1963    // Fitted coefficients are mapped back to raw units at repack below, so the
1964    // saved model and the (raw-design) predict path are untouched. Penalized
1965    // columns are left alone (a penalty makes the rescaling non-equivalent),
1966    // and nothing is touched when explicit coefficient bounds/constraints
1967    // exist (those are stated in raw units).
1968    let parametric_standardization: Vec<(usize, f64, f64)> =
1969        if design.coefficient_lower_bounds.is_some() || design.linear_constraints.is_some() {
1970            Vec::new()
1971        } else {
1972            let p_total = x_dense.ncols();
1973            let mut penalized = vec![false; p_total];
1974            for bp in &design.penalties {
1975                for col in bp.col_range.clone() {
1976                    if col < p_total {
1977                        penalized[col] = true;
1978                    }
1979                }
1980            }
1981            let has_intercept = !design.intercept_range.is_empty();
1982            let n_rows = x_dense.nrows().max(1) as f64;
1983            let mut standardized = Vec::new();
1984            for (_, range) in &design.linear_ranges {
1985                for col in range.clone() {
1986                    if col >= p_total || penalized[col] {
1987                        continue;
1988                    }
1989                    let column = x_dense.column(col);
1990                    let mean = column.sum() / n_rows;
1991                    let var = column.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n_rows;
1992                    let scale = var.sqrt();
1993                    // Skip near-constant or degenerate columns: no conditioning to
1994                    // be gained and the back-map would divide by ~0.
1995                    if !(scale.is_finite() && scale > 1e-8 * (mean.abs() + 1.0)) {
1996                        continue;
1997                    }
1998                    // Centering shifts mass onto the intercept; without one the
1999                    // shift is not representable, so scale only.
2000                    let center = if has_intercept { mean } else { 0.0 };
2001                    for v in x_dense.column_mut(col).iter_mut() {
2002                        *v = (*v - center) / scale;
2003                    }
2004                    standardized.push((col, center, scale));
2005                }
2006            }
2007            standardized
2008        };
2009    // Preserve the per-smooth-term penalty block structure (#561): each smooth
2010    // term `t` contributes its own `P × P` penalty component (`Blockwise` with
2011    // `total_dim = P`, the term's local `S_t` embedded at its `col_range`), and
2012    // every active class block receives the FULL list. The outer REML/LAML loop
2013    // then selects an independent smoothing parameter λ_{a,t} per (class, term),
2014    // matching mgcv/VGAM. Pre-summing the terms into one fused `S` (the prior
2015    // behaviour) forced a single λ per class that scales `Σ_t S_t`, so one
2016    // shared λ had to over-smooth a rough term while under-smoothing a smooth
2017    // one — biasing any multi-term class-probability surface.
2018    let k = y_one_hot.ncols();
2019    let m = k - 1;
2020    let n_obs = y_one_hot.nrows();
2021    let penalty_scale = multinomial_formula_penalty_scale(k);
2022    let per_term_penalties: Vec<PenaltyMatrix> = design
2023        .penalties_as_penalty_matrix()
2024        .into_iter()
2025        .map(|penalty| scale_multinomial_formula_penalty(penalty, penalty_scale))
2026        .collect();
2027    let per_term_nullspace_dims = design.nullspace_dims.clone();
2028
2029    // ── Custom-family driven REML/LAML path ───────────────────────────────
2030    // Each active class becomes one ParameterBlockSpec, all sharing X and the
2031    // per-term penalty list. `initial_log_lambdas` is seeded from the caller's
2032    // `init_lambda` (one entry per term).
2033    let design_arc = Arc::new(x_dense);
2034    let penalties_arc = Arc::new(per_term_penalties);
2035    let nullspace_dims_arc = Arc::new(per_term_nullspace_dims);
2036    let weights = Array1::<f64>::ones(n_obs);
2037    // First attempt runs the UNBIASED penalized-REML criterion (no Firth
2038    // shrinkage toward the uniform simplex); the Jeffreys/Firth proper prior is
2039    // armed conditionally below, only on separation evidence (#715/#753 — see
2040    // `multinomial_formula_separation_evidence`).
2041    let log_init = init_lambda.ln();
2042    let family = MultinomialFamily::new(
2043        y_one_hot.clone(),
2044        weights,
2045        k,
2046        design_arc.clone(),
2047        penalties_arc.clone(),
2048        nullspace_dims_arc.clone(),
2049    )
2050    .map_err(EstimationError::InvalidInput)?
2051    .with_joint_jeffreys_term(false)
2052    // gam#1587: the per-block smooth penalties are emptied (the centered `M⊗S_t`
2053    // joint penalty is the sole smoothing carrier), so the `init_lambda` warm
2054    // start must seed the JOINT penalty's `initial_log_lambda` — the per-block
2055    // `initial_log_lambdas` loop below is now a no-op (empty per-block list).
2056    .with_initial_log_lambda(log_init);
2057    let mut blocks = family.build_block_specs();
2058    for spec_block in blocks.iter_mut() {
2059        for v in spec_block.initial_log_lambdas.iter_mut() {
2060            *v = log_init;
2061        }
2062    }
2063
2064    // ── Outer-derivative policy: dimension-gated exact curvature ────────────
2065    // The total smoothing-parameter dimension is `D = (K−1) · n_terms`.
2066    // Medium-D formula fits need exact curvature to keep lambda selection away
2067    // from over-smoothed caps, while smooth-by-factor `D = 8` models still avoid
2068    // the O(D²) dense Hessian path.
2069    let total_rho_dim = m.saturating_mul(penalties_arc.len());
2070    let use_outer_hessian = multinomial_formula_use_outer_hessian(total_rho_dim);
2071
2072    // ── Inner-vs-outer control split (#715 non-convergence root cause) ────────
2073    // The legacy `max_iter` / `tol` parameters are the *outer* REML/LAML
2074    // smoothing-parameter optimization controls — "how hard to search λ". The
2075    // earlier wiring routed them straight into `inner_max_cycles` / `inner_tol`,
2076    // capping the joint-Newton inner solve at `max_iter` (=50 in the quality
2077    // suite) cycles with a `tol`-tight (=1e-8) KKT target. That is the #715
2078    // hang: near the simplex boundary the softmax Fisher weight
2079    // `W = diag(p) − p pᵀ` collapses, so `H = JᵀWJ + S_λ` is full-rank but
2080    // ILL-CONDITIONED. The self-vanishing Levenberg–Marquardt damping
2081    // (`levenberg_on_ill_conditioning()`) that keeps the inner solve from
2082    // oscillating on those near-singular modes makes it converge only
2083    // GEOMETRICALLY (linearly), not quadratically. Reaching a 1e-8 relative KKT
2084    // residual under geometric descent needs FAR more than 50 cycles, so the
2085    // inner returned `converged = false` on every outer ρ-evaluation; with the
2086    // exact-Hessian outer optimizer on `FallbackPolicy::Disabled` that rejects
2087    // every ρ-step — each rejected eval still paying a near-full 50-cycle inner
2088    // solve plus the O(D²) pairwise outer-Hessian directional work — so the
2089    // outer never certifies and the fit runs unbounded (the observed >8-minute
2090    // non-termination). The certificate cannot be reached, not merely slow.
2091    //
2092    // Fix: give the INNER joint-Newton the framework's principled production
2093    // budget (`DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES` cycles at the default
2094    // `inner_tol`), which exists precisely so an ill-conditioned LM-damped solve
2095    // can certify a stationary KKT point instead of being declared non-converged
2096    // prematurely — and the KKT/objective certificates still exit in a handful
2097    // of cycles on the well-conditioned interior fits, so this is free there.
2098    // The caller's `max_iter` / `tol` become the OUTER controls they were always
2099    // meant to be (smoothing-parameter search depth / accuracy). The inner KKT
2100    // target is kept no tighter than the outer accuracy can consume — and no
2101    // tighter than the softmax objective's f64 noise floor on near-separable
2102    // fits (see `MULTINOMIAL_FORMULA_INNER_TOL`).
2103    let outer_max_iter = max_iter.max(1);
2104    // The OUTER REML/LAML smoothing-parameter search must converge to a
2105    // well-calibrated ρ-gradient tolerance, NOT to the caller's (typically very
2106    // tight) INNER KKT tolerance. The #715 control-split repurposed the caller's
2107    // `tol` as the outer control, but feeding an inner-scale `tol = 1e-8`
2108    // straight into `outer_tol` makes REML grind dozens of extra exact-gradient
2109    // outer iterations (each an O(D·p³) Laplace-derivative assembly over the full
2110    // P·M joint design) to squeeze ρ digits that no longer move the fitted
2111    // surface — the smooth-by-factor 269s wall-clock overrun (#1082).
2112    //
2113    // The right target is the framework's CALIBRATED REML convergence tolerance,
2114    // `MULTINOMIAL_OUTER_REML_TOL = 1e-7` — the same value the primary GLM REML
2115    // outer uses (`solver::fit_orchestration::materialize` `tol: 1e-7`, mirrored by the
2116    // `LOG_LAMBDA_TOL`/`KKT_TOL_*` constants across the REML stack). At 1e-7 the
2117    // λ-search reaches the genuine REML optimum (so the recovered probability
2118    // surface matches the mature reference), but it does NOT chase the last
2119    // surface-irrelevant ρ digits down to 1e-8. The earlier 1e-5 floor (the
2120    // generic `BlockwiseFitOptions` default) was too LOOSE: the optimizer halted
2121    // in a low-curvature region with λ still well above its optimum, UNDER-fitting
2122    // the smooth-by-factor surface (truth-RMSE 0.164 vs VGAM's 0.061). So the
2123    // outer tolerance is floored at the calibrated REML tol — never tighter than
2124    // it (perf); a caller `tol` looser than that floor is honoured as-is. The
2125    // caller's `tol` also floors the INNER joint-Newton KKT target (`inner_tol`
2126    // below), which is never tighter than `MULTINOMIAL_FORMULA_INNER_TOL`.
2127    let outer_tol = if tol.is_finite() && tol > 0.0 {
2128        tol.max(MULTINOMIAL_OUTER_REML_TOL)
2129    } else {
2130        MULTINOMIAL_OUTER_REML_TOL
2131    };
2132    // #1082 root cause: the outer convergence test derives BOTH the absolute
2133    // projected-gradient floor (`max(outer_tol, n·1e-9)`) AND the relative-cost
2134    // stop (`rel_cost = outer_tol`) from the single `outer_tol`. The accuracy of
2135    // the smooth-by-factor surface is governed by the ABSOLUTE floor reaching the
2136    // n-scaled REML resolution `n·1e-9` (≈ 1.8e-6 at n = 1800) — that is why the
2137    // earlier 1e-5 floor UNDER-fit (its absolute floor was pinned at 1e-5, well
2138    // above the genuine optimum's gradient) and why 1e-7 recovered accuracy (it
2139    // unpins the floor down to the n-scaled 1.8e-6). But tightening `outer_tol`
2140    // to 1e-7 ALSO tightened the rel-cost stop to 1e-7, which on this family's
2141    // dead-flat REML ridge NEVER trips — so the optimizer no longer converges and
2142    // grinds all the way to `outer_max_iter`, each surplus step an O(D·p³) Laplace-
2143    // derivative assembly over the 382-dim joint design (the >600s wall-clock
2144    // overrun; tightening tol REINTRODUCED the crawl the 1e-5 floor had removed).
2145    //
2146    // The two requirements live on two different criteria, so they must be set
2147    // independently. Keep `outer_tol = 1e-7` (drives the accurate absolute floor)
2148    // but FLOOR the relative-cost stop at the framework default 1e-5 (the loose,
2149    // fast value that resolves the cost-decrease plateau without chasing the flat
2150    // tail). The absolute n·1e-9 floor still gates final λ accuracy; the rel-cost
2151    // stop just lets the optimizer DECLARE convergence on the flat ridge instead
2152    // of crawling to the iteration cap.
2153    let outer_rel_cost_tol = Some(BlockwiseFitOptions::default().outer_tol);
2154    let inner_tol = MULTINOMIAL_FORMULA_INNER_TOL.max(tol.max(0.0));
2155
2156    let options = BlockwiseFitOptions {
2157        inner_max_cycles: crate::custom_family::DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES,
2158        inner_tol,
2159        outer_max_iter,
2160        outer_tol,
2161        outer_rel_cost_tol,
2162        rho_lower_bound: multinomial_formula_min_lambda(y_one_hot.view()).ln(),
2163        ridge_floor: MULTINOMIAL_FORMULA_RIDGE_FLOOR,
2164        // #747: the stabilization floor is SOLVER-ONLY — it keeps the inner
2165        // joint-Newton linear solve finite during screening (bounding the step
2166        // `(H+δI)⁻¹∇` away from a near-separable, rank-deficient curvature) but
2167        // is excluded from the REML objective, the penalty log-determinant, and
2168        // the Laplace Hessian. The earlier default (`explicit_stabilization_pospart`)
2169        // folded `½·δ·‖β‖²` and a `δ`-shift of the log-determinant into the
2170        // criterion, shrinking every identified coefficient off the MLE and
2171        // perturbing smoothing-parameter selection — a fixed-λ prior masking
2172        // separation, not a numerical stabilizer. With the floor solver-only the
2173        // optimized objective is the true penalized REML criterion (value tracks
2174        // its analytic gradient), and the smooth directions remain governed
2175        // solely by their own REML-selected `λ`.
2176        ridge_policy: gam_problem::RidgePolicy::solver_only(),
2177        use_outer_hessian,
2178        // #715 real-data arm ("canonical-gauge null direction rejects all REML
2179        // seeds"): skip the multi-seed outer screening cascade and let the
2180        // pinned `init_lambda` ρ flow straight to the outer optimizer.
2181        //
2182        // The multinomial family declares `levenberg_on_ill_conditioning() ->
2183        // true`: near the simplex boundary (the near-separable penguins regime)
2184        // the softmax Fisher weight `W = diag(p) − p pᵀ → 0`, so the joint
2185        // information `H = JᵀWJ + S_λ` can become full-rank but
2186        // ILL-CONDITIONED. The self-vanishing LM damping that keeps the inner
2187        // joint-Newton from oscillating on those near-singular modes converges
2188        // only GEOMETRICALLY. The default screening policy ranks candidate seeds
2189        // with a 2-cycle inner cap (`outer_seed_config`); under geometric
2190        // LM-damped descent two cycles never reach a finite, meaningful proxy
2191        // objective, so EVERY capped seed can collapse to non-finite cost and
2192        // the cascade escalates to ×4, ×16, then an UNCAPPED full inner solve
2193        // PER SEED on the near-singular Hessian. That is the adapter-level face
2194        // of "all REML startup seeds rejected" and the multi-minute timeout.
2195        //
2196        // The pinned seed is already principled here: `init_lambda` gives every
2197        // (class, term) ρ a sensible moderate warm start, and the per-term
2198        // effective-df-floor upper bounds (`effective_df_floor_rho_upper_bounds`,
2199        // #715 arm (a)) keep any λ from collapsing the smooth onto its polynomial
2200        // null space. So the outer ARC/BFGS optimizer performs the real REML ρ
2201        // search from this seed; screening only adds the cascade cost and, on the
2202        // near-separable arm, the rejection stall.
2203        screen_initial_rho: false,
2204        // #1101: compute the joint Laplace posterior covariance `H⁻¹` (and the
2205        // influence matrix `F = H⁻¹ X'WX`) at the converged mode so the saved
2206        // model can surface delta-method per-class probability standard errors
2207        // and Wald smooth-term p-values. The driver factorizes the penalized
2208        // Hessian during the inner solve regardless; this only asks it to keep
2209        // and invert the factor instead of discarding it.
2210        compute_covariance: true,
2211        ..BlockwiseFitOptions::default()
2212    };
2213    // ── Conditional Firth/Jeffreys engagement (#715 arm (b) / #753) ──────────
2214    // Attempt 1: the unbiased criterion (Jeffreys disarmed above). If the
2215    // returned mode is converged, finite, and interior, it is the exact penalized-REML
2216    // optimum with zero Firth bias — accept it (this is the synthetic-arm /
2217    // interior-data path, #715 arm (a)). If the solve FAILS (e.g. the
2218    // (quasi-)separated penguins geometry where `(H + S_λ)v ≈ 0` along
2219    // penalty-null directions for EVERY ρ rejects every REML startup seed) or
2220    // returns a non-finite artifact, that is direct separation evidence:
2221    // re-solve once with the full-span Jeffreys/Firth proper prior armed, which
2222    // supplies the O(1) curvature on the quotient-null subspace that smoothing
2223    // parameters mathematically cannot (`Sv = 0` ⇒ λ never touches `v`). The
2224    // Firth refit is the accepted result only when the unbiased formula solve
2225    // failed, did not converge on its full budget, or blew up; finite
2226    // formula-path logits can be large on valid near-separated optima and
2227    // should not be shrunk toward the uniform simplex once the unbiased outer
2228    // solve has actually certified.
2229    let mut unbiased_probe_options = options.clone();
2230    unbiased_probe_options.outer_max_iter = unbiased_probe_options
2231        .outer_max_iter
2232        .min(MULTINOMIAL_UNBIASED_PROBE_OUTER_MAX_ITER);
2233    // The FINAL accepted Firth/Jeffreys refit runs to the caller's full outer
2234    // budget: it is the result we ship, so it must reach the genuine REML
2235    // optimum, not a truncated iterate. The near-separable penguin refit that
2236    // motivated #1082's wall-clock concern is now halted honestly at its true
2237    // bound optimum by the KKT-stationary-at-bound guard
2238    // (`CostStallGuard`, #1082 / 64711ed82) and the Newton-decrement residual
2239    // certificate (363af9b56 / 2c9580b1f): on separable data the outer ARC
2240    // certifies and stops early on its own, so no artificial iteration cap is
2241    // needed to land in budget. On non-separable data (e.g. the
2242    // `vgam_smooth_by_factor` double-penalty arm) the refit needs the caller's
2243    // full budget to converge, which a `.min(20)` cap would cut off — accepting
2244    // a non-converged fit, which is dishonest. So the refit keeps `options`
2245    // unchanged. Only the discarded unbiased separation probe above is capped.
2246    let firth_refit_options = &options;
2247
2248    let run_firth_refit = |evidence: String| {
2249        let firth_family = family.clone().with_joint_jeffreys_term(true);
2250        fit_custom_family_with_rho_prior(
2251            &firth_family,
2252            &blocks,
2253            firth_refit_options,
2254            gam_problem::RhoPrior::Flat,
2255        )
2256        .map_err(|err| {
2257            EstimationError::InvalidInput(format!(
2258                "multinomial REML: Firth/Jeffreys-armed refit (separation evidence: \
2259                 {evidence}) failed: {err}"
2260            ))
2261        })
2262    };
2263
2264    // #1082: the capped unbiased probe and the (separable-path) Firth decision
2265    // are driven by separation scans over the full P×M logit block. The previous
2266    // match recomputed `multinomial_formula_separation_evidence` /
2267    // `..._unresolved_probe_separation_evidence` in BOTH the match guard AND the
2268    // arm body — three to four full logit walks per fit, paid on the hot
2269    // near-separable penguin path where this branch fires every iterate. Run the
2270    // probe once, evaluate each scan once into a binding, and branch on the
2271    // precomputed results. Behaviour is identical (same scans, same order of
2272    // precedence: converged-interior, unresolved-probe-separation,
2273    // no-separation-needs-full-solve, otherwise-Firth); only the duplicate
2274    // O(n·classes) scans are removed.
2275    let probe_attempt = fit_custom_family_with_rho_prior(
2276        &family,
2277        &blocks,
2278        &unbiased_probe_options,
2279        gam_problem::RhoPrior::Flat,
2280    );
2281    let fit = match probe_attempt {
2282        Ok(probe_fit) => {
2283            let separation = multinomial_formula_separation_evidence(&probe_fit.block_states);
2284            if probe_fit.outer_converged && separation.is_none() {
2285                // Interior, converged, no separation: accept the probe directly.
2286                probe_fit
2287            } else if let Some(evidence) =
2288                multinomial_formula_unresolved_probe_separation_evidence(&probe_fit.block_states)
2289            {
2290                // Non-converged probe already carrying separation-scale logits:
2291                // hand straight to the proper-prior Firth refit (do not spend the
2292                // full unbiased budget grinding the λ→0 separable ridge).
2293                run_firth_refit(format!(
2294                    "unbiased-criterion REML probe did not converge after {} outer iterations; {evidence}",
2295                    probe_fit.outer_iterations
2296                ))?
2297            } else if separation.is_none() {
2298                // Interior but the capped probe ran out of iterations without
2299                // certifying: re-solve at the caller's full outer budget.
2300                //
2301                // #1082 wall-clock: the capped probe is a strict prefix of this
2302                // solve from the same family/seed, so a COLD restart repeats the
2303                // probe's outer iterations. WARM-START the re-solve from the ρ the
2304                // probe already reached — seed each block's `initial_log_lambdas`
2305                // from the probe's selected `log_lambdas` (same block/penalty
2306                // order: the flat vector concatenates per-block penalties in block
2307                // order, exactly the order `build_block_specs()` emits them). This
2308                // changes only the optimizer's STARTING point, never the objective
2309                // or its optimum, but lets the full solve resume near the probe's
2310                // last iterate instead of crawling up from `init_lambda` again —
2311                // removing the probe-iterations double-pay on the non-separable
2312                // (e.g. `vgam_smooth_by_factor`) arm. If the probe's λ vector does
2313                // not line up with the block layout (it always should), fall back
2314                // to the cold `blocks` seed.
2315                let warm_blocks = warm_start_blocks_from_log_lambdas(
2316                    &blocks,
2317                    probe_fit.log_lambdas.as_slice().unwrap_or(&[]),
2318                );
2319                let resolve_blocks = warm_blocks.as_deref().unwrap_or(&blocks);
2320                match fit_custom_family_with_rho_prior(
2321                    &family,
2322                    resolve_blocks,
2323                    &options,
2324                    gam_problem::RhoPrior::Flat,
2325                ) {
2326                    Ok(full_unbiased_fit) => {
2327                        let full_separation = multinomial_formula_separation_evidence(
2328                            &full_unbiased_fit.block_states,
2329                        );
2330                        if full_unbiased_fit.outer_converged && full_separation.is_none() {
2331                            full_unbiased_fit
2332                        } else {
2333                            let evidence = full_separation.unwrap_or_else(|| {
2334                                format!(
2335                                    "full unbiased-criterion REML solve did not converge after {} outer iterations",
2336                                    full_unbiased_fit.outer_iterations
2337                                )
2338                            });
2339                            run_firth_refit(evidence)?
2340                        }
2341                    }
2342                    Err(err) => run_firth_refit(format!(
2343                        "full unbiased-criterion REML solve failed: {err}"
2344                    ))?,
2345                }
2346            } else {
2347                // Probe converged (or capped) but shows interior separation
2348                // evidence: Firth refit using the already-computed scan.
2349                let evidence = separation.unwrap_or_else(|| {
2350                    format!(
2351                        "unbiased-criterion REML probe did not converge after {} outer iterations",
2352                        probe_fit.outer_iterations
2353                    )
2354                });
2355                run_firth_refit(evidence)?
2356            }
2357        }
2358        Err(err) => run_firth_refit(format!("unbiased-criterion REML solve failed: {err}"))?,
2359    };
2360    if let Some(err) = multinomial_formula_separation_diagnostic(
2361        fit.inner_cycles,
2362        fit.outer_iterations,
2363        &fit.block_states,
2364    ) {
2365        return Err(err);
2366    }
2367
2368    // ── Repack coefficients (P, K-1) from per-block β vectors ─────────────
2369    if fit.blocks.len() != m {
2370        crate::bail_invalid_estim!(
2371            "multinomial REML: expected {m} fitted blocks (K-1), got {}",
2372            fit.blocks.len()
2373        );
2374    }
2375    let p_per_class = fit.blocks[0].beta.len();
2376    let mut coefficients_active = Array2::<f64>::zeros((p_per_class, m));
2377    for (a, block) in fit.blocks.iter().enumerate() {
2378        if block.beta.len() != p_per_class {
2379            crate::bail_invalid_estim!(
2380                "multinomial REML: block {a} has {} coefs, expected {p_per_class}",
2381                block.beta.len()
2382            );
2383        }
2384        for i in 0..p_per_class {
2385            coefficients_active[[i, a]] = block.beta[i];
2386        }
2387    }
2388    // Map the standardized-column coefficients back to raw units (the exact
2389    // inverse of the conditioning reparameterization above): β_raw = b/s, with
2390    // the centering mass `Σ_j b_j·m_j/s_j` returned to the intercept.
2391    if !parametric_standardization.is_empty() {
2392        let intercept_col = design.intercept_range.clone().next();
2393        for a in 0..m {
2394            let mut intercept_adjust = 0.0;
2395            for &(col, center, scale) in &parametric_standardization {
2396                if col < p_per_class {
2397                    let raw = coefficients_active[[col, a]] / scale;
2398                    coefficients_active[[col, a]] = raw;
2399                    intercept_adjust += raw * center;
2400                }
2401            }
2402            if let Some(i0) = intercept_col
2403                && i0 < p_per_class
2404            {
2405                coefficients_active[[i0, a]] -= intercept_adjust;
2406            }
2407        }
2408    }
2409    // Flatten every (class, term) smoothing parameter in block-major order
2410    // (class 0's terms, then class 1's, …). With per-term penalties each block
2411    // now carries one λ per smooth term, so a single λ per class would discard
2412    // the independent per-term selection that fixes #561. `lambdas_per_block`
2413    // segments the flat vector by class so callers can recover per-term λ.
2414    // ── gam#1587/#561 joint-penalty reconstruction ───────────────────────────
2415    // Under the #1587 centered-metric architecture every active class block
2416    // leaves its per-block penalty list EMPTY — the entire fit's smoothing rides
2417    // on a single full-width JOINT penalty `S_λ = Σ_t λ_t (M ⊗ S_t)` whose one
2418    // shared `λ_t` per smooth component is selected by the outer REML loop and
2419    // surfaced on `fit.artifacts.joint_log_lambdas`. So `fit.blocks[a].lambdas`
2420    // is `[]`, the inference layer's per-block trace channel is empty, and the
2421    // older per-block reporting (`lambdas_per_block = [0, 0]`, `edf_per_class =
2422    // None`, …) collapsed (#561 reopen).
2423    //
2424    // Reconstruct the per-(class, component) λ and the influence-matrix EDF
2425    // directly from the selected joint `λ_t` and the COUPLED penalty
2426    // `S_λ = Σ_t λ_t (M ⊗ S_t)` (NOT a block-diagonal `Σ_t λ_{a,t} S_t`: the
2427    // centered metric `M` couples classes off the block diagonal, so a
2428    // block-diagonal `S_λ` would mis-state both the influence matrix and every
2429    // trace). With `H⁻¹ = fit.covariance_conditional` now assembled WITH the
2430    // joint penalty (the `compute_joint_covariance` fix), the influence matrix is
2431    // exactly `F = I − H⁻¹ S_λ`, its per-class diagonal-block trace is the honest
2432    // per-class EDF, and `Σ_a edf_a = tr(F) = edf_total`.
2433    let joint_recon = fit.artifacts.joint_log_lambdas.as_ref().and_then(|jll| {
2434        let n_components = penalties_arc.len();
2435        if jll.len() != n_components || n_components == 0 {
2436            return None;
2437        }
2438        let expected_joint = p_per_class.saturating_mul(m);
2439        let hinv = fit
2440            .covariance_conditional
2441            .as_ref()
2442            .filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)?;
2443        // The coupled joint penalty components `M ⊗ S_t` at the selected `λ_t`,
2444        // in raw stacked (class-major) coordinates — exactly the operator the
2445        // inner solve and the now-fixed covariance path penalize with.
2446        let joint_specs = family.centered_joint_penalty_specs();
2447        if joint_specs.len() != n_components {
2448            return None;
2449        }
2450        let lam: Vec<f64> = jll.iter().map(|&l| l.exp()).collect();
2451        // Per-component `H⁻¹ (M ⊗ S_t)` (full mp×mp), reused for both the joint
2452        // influence matrix and the per-(class, component) trace decomposition.
2453        let mut hinv_st: Vec<Array2<f64>> = Vec::with_capacity(n_components);
2454        for spec in &joint_specs {
2455            if spec.matrix.nrows() != expected_joint || spec.matrix.ncols() != expected_joint {
2456                return None;
2457            }
2458            hinv_st.push(hinv.dot(&spec.matrix));
2459        }
2460        // F = I − H⁻¹ S_λ = I − Σ_t λ_t H⁻¹ (M ⊗ S_t).
2461        let mut f = Array2::<f64>::eye(expected_joint);
2462        for (t, hs) in hinv_st.iter().enumerate() {
2463            f.scaled_add(-lam[t], hs);
2464        }
2465        // Per-class diagonal-block trace of F (the honest per-class EDF), and the
2466        // per-(class, component) penalty trace `tr_{a,t} = λ_t · Σ_{i∈class a}
2467        // (H⁻¹ (M⊗S_t))[i,i]` for the per-penalty EDF rollup.
2468        let mut edf_per_class = Vec::with_capacity(m);
2469        // class-major per-penalty EDF (class 0's components, then class 1's, …),
2470        // aligned 1:1 with the flat per-component λ replicated per class.
2471        let mut edf_per_penalty = Vec::with_capacity(m * n_components);
2472        for a in 0..m {
2473            let base = a * p_per_class;
2474            let mut class_trace = 0.0_f64;
2475            for t in 0..n_components {
2476                let mut tr_at = 0.0_f64;
2477                for i in 0..p_per_class {
2478                    tr_at += hinv_st[t][[base + i, base + i]];
2479                }
2480                tr_at *= lam[t];
2481                class_trace += tr_at;
2482                // A single component's per-class trace EDF `rank(S_t) − tr_{a,t}`,
2483                // bounded by its local rank (≤ p_per_class).
2484                let ns_t = nullspace_dims_arc.get(t).copied().unwrap_or(0);
2485                let rank_t = (p_per_class as f64 - ns_t as f64).max(0.0);
2486                edf_per_penalty.push((rank_t - tr_at).clamp(0.0, p_per_class as f64));
2487            }
2488            edf_per_class
2489                .push((p_per_class as f64 - class_trace).clamp(0.0, p_per_class as f64));
2490        }
2491        Some((f, edf_per_class, edf_per_penalty, n_components, lam))
2492    });
2493
2494    // Flatten every (class, component) smoothing parameter in class-major order.
2495    // Under the joint-penalty architecture each active class carries the SAME
2496    // per-component λ set (the centered metric ties `λ_t` across classes for
2497    // reference-class invariance), so the flat vector is the selected `λ_t`
2498    // replicated `K-1` times and `lambdas_per_block = [n_components; K-1]`. When
2499    // the joint reconstruction is unavailable (legacy fixed-λ path or absent
2500    // covariance) fall back to the raw — now empty — per-block λ lists.
2501    let (lambdas_per_block, lambdas_flat): (Vec<usize>, Vec<f64>) = match joint_recon.as_ref() {
2502        Some((_, _, _, n_components, lam)) => {
2503            let per_block = vec![*n_components; m];
2504            let mut flat = Vec::with_capacity(m * n_components);
2505            for _ in 0..m {
2506                flat.extend(lam.iter().copied());
2507            }
2508            (per_block, flat)
2509        }
2510        None => {
2511            let per_block: Vec<usize> = fit.blocks.iter().map(|b| b.lambdas.len()).collect();
2512            let flat: Vec<f64> = fit
2513                .blocks
2514                .iter()
2515                .flat_map(|b| b.lambdas.iter().copied())
2516                .collect();
2517            (per_block, flat)
2518        }
2519    };
2520    // Per-active-class effective degrees of freedom, length `K-1`, summing to
2521    // the model `edf_total`. The REML inference block reports `edf_by_block` as
2522    // ONE entry per *penalty block* (per (class, term, penalty)), each computed
2523    // as `rank(S_kk) − tr(H⁻¹ λ_kk S_kk)`. That per-block sum OVER-COUNTS the
2524    // model EDF whenever several penalties share one coefficient range — a
2525    // double-penalty / te / ti / adaptive smooth has ≥2 penalty blocks over the
2526    // same columns, so `Σ_kk rank(S_kk) > p` and `Σ_kk edf_by_block > edf_total`
2527    // (the observed ~79 for a ~24-coefficient model). Handing that raw per-block
2528    // vector out as the documented length-(K-1) per-class EDF is therefore both
2529    // the wrong LENGTH (it is `Σ_a n_blocks_a`, not `K-1`) and an over-count.
2530    //
2531    // The honest per-class EDF is the influence-matrix trace over each class's
2532    // coefficient block. Classes occupy DISJOINT `p_per_class`-wide coefficient
2533    // ranges, and the per-block traces `tr_kk = tr(H⁻¹ λ_kk S_kk)` are additive
2534    // (no rank double-counting), so class `a`'s EDF is
2535    // `p_per_class − Σ_{kk ∈ class a} tr_kk`, and `Σ_a edf_a = m·p_per_class −
2536    // Σ_kk tr_kk = p − Σ tr_kk = edf_total` exactly. Segment the block-major
2537    // `penalty_block_trace` by `lambdas_per_block` (the same per-class λ-count
2538    // segmentation `lambdas_flat` uses). Fall back to `None` when the trace
2539    // channel is unavailable or mis-shaped (legacy fixed-λ path), exactly as the
2540    // raw `edf_by_block` map did before.
2541    let edf_per_class = joint_recon
2542        .as_ref()
2543        .map(|(_, epc, _, _, _)| epc.clone())
2544        .or_else(|| {
2545            // Legacy per-block trace path (fixed-λ / pre-#1587 fits whose
2546            // smoothing is still carried per block). Segment the block-major
2547            // `penalty_block_trace` by `lambdas_per_block`, exactly as before.
2548            fit.inference.as_ref().and_then(|info| {
2549                let traces = &info.penalty_block_trace;
2550                if traces.len() != lambdas_per_block.iter().sum::<usize>() {
2551                    return None;
2552                }
2553                let mut per_class = Vec::with_capacity(m);
2554                let mut cursor = 0usize;
2555                for &n_blocks in &lambdas_per_block {
2556                    let class_trace: f64 = traces[cursor..cursor + n_blocks].iter().sum();
2557                    per_class
2558                        .push((p_per_class as f64 - class_trace).clamp(0.0, p_per_class as f64));
2559                    cursor += n_blocks;
2560                }
2561                Some(per_class)
2562            })
2563        });
2564    // Per-PENALTY EDF: the inference layer's `edf_by_block` is already the
2565    // clamped per-penalty-block trace EDF `rank(S_k) − λ_k·tr(H⁻¹ S_k)`, one
2566    // entry per smoothing parameter and block-major aligned 1:1 with the flat
2567    // `lambdas`. Surface it verbatim (guarding only on the length contract) so
2568    // consumers can inspect per-(class, term, penalty) collapse onto the null
2569    // space — a signal the per-class EDF SUM hides. This is NOT a per-class
2570    // total: with double-penalty smooths `Σ_k rank(S_k) > p_per_class`, so the
2571    // entries deliberately need not sum to the model EDF (the per-class field
2572    // carries that contract instead).
2573    let edf_per_penalty = joint_recon
2574        .as_ref()
2575        .map(|(_, _, epp, _, _)| epp.clone())
2576        .or_else(|| {
2577            // Legacy per-block path: the inference layer's `edf_by_block` is
2578            // already the clamped per-penalty-block trace EDF, aligned 1:1 with
2579            // the flat `lambdas`.
2580            fit.inference.as_ref().and_then(|info| {
2581                if info.edf_by_block.len() != lambdas_flat.len() {
2582                    return None;
2583                }
2584                Some(
2585                    info.edf_by_block
2586                        .iter()
2587                        .map(|&e| e.max(0.0))
2588                        .collect::<Vec<f64>>(),
2589                )
2590            })
2591        });
2592    let coefficients_flat: Vec<f64> = coefficients_active.iter().copied().collect();
2593
2594    // #1101: surface the joint Laplace posterior covariance `H⁻¹` (block-ordered
2595    // [β_0; …; β_{K-2}]) and the influence matrix `F = H⁻¹ X'WX` the REML driver
2596    // computed at the converged mode. These power the predict path's delta-method
2597    // per-class probability standard errors and the summary's Wald smooth-term
2598    // tests. The joint matrices are `(P·M)×(P·M)`. The covariance is mapped back
2599    // to RAW units (see below) so it pairs with the raw predict design; the
2600    // influence is kept in the fitted basis (the Wald table only slices penalized
2601    // columns, which the standardization affine leaves identity-mapped).
2602    let expected_joint = p_per_class.saturating_mul(m);
2603    // The joint Hessian (and thus `H⁻¹`) was assembled in the STANDARDIZED
2604    // parametric basis used during fitting, while the saved coefficients and the
2605    // raw predict design are in raw units. Map the covariance to raw units with
2606    // the same exact affine reparameterization `β_raw = A β_std`: for each
2607    // standardized parametric column `col`, `β_raw[col] = β_std[col]/scale` and
2608    // the intercept absorbs `−Σ_col (center/scale)·β_std[col]`. So `A = I` except
2609    // `A[col,col] = 1/scale` and `A[i0,col] = −center/scale`, replicated
2610    // block-diagonally per active class, and `Cov_raw = A Cov_std Aᵀ`. With no
2611    // standardization (`parametric_standardization` empty) `A = I` and this is a
2612    // no-op. The smooth-term (penalized) columns are untouched by `A`, so the
2613    // Wald table's per-term blocks are identical in both bases.
2614    let intercept_col0 = design.intercept_range.clone().next();
2615    let build_per_class_affine = |amat: &mut Array2<f64>| {
2616        for &(col, center, scale) in &parametric_standardization {
2617            if col >= p_per_class {
2618                continue;
2619            }
2620            amat[[col, col]] = 1.0 / scale;
2621            if let Some(i0) = intercept_col0
2622                && i0 < p_per_class
2623            {
2624                amat[[i0, col]] = -center / scale;
2625            }
2626        }
2627    };
2628    let coefficient_covariance_flat = fit
2629        .covariance_conditional
2630        .as_ref()
2631        .filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)
2632        .map(|cov_std| {
2633            if parametric_standardization.is_empty() {
2634                return cov_std.iter().copied().collect::<Vec<f64>>();
2635            }
2636            // Block-diagonal joint A (same per active class).
2637            let mut a_joint = Array2::<f64>::eye(expected_joint);
2638            let mut a_class = Array2::<f64>::eye(p_per_class);
2639            build_per_class_affine(&mut a_class);
2640            for a in 0..m {
2641                let base = a * p_per_class;
2642                for i in 0..p_per_class {
2643                    for j in 0..p_per_class {
2644                        a_joint[[base + i, base + j]] = a_class[[i, j]];
2645                    }
2646                }
2647            }
2648            let cov_raw = a_joint.dot(cov_std).dot(&a_joint.t());
2649            cov_raw.iter().copied().collect::<Vec<f64>>()
2650        });
2651    // The influence matrix `F = H⁻¹ X'WX = H⁻¹(H − S_λ) = I − H⁻¹ S_λ`. The
2652    // exact-Newton multinomial blocks carry no IRLS pseudo-data, so the generic
2653    // inference path does not export `coefficient_influence`; reconstruct it
2654    // exactly here. Under the #1587 joint-penalty architecture the penalty is the
2655    // COUPLED centered metric `S_λ = Σ_t λ_t (M ⊗ S_t)` (off the class-block
2656    // diagonal), already assembled in `joint_recon` above, so reuse that exact
2657    // `F`. Only fall back to the legacy block-diagonal `Σ_t λ_{a,t} S_t`
2658    // reconstruction when the joint reconstruction is unavailable (pre-#1587
2659    // per-block fits whose class blocks still carry their own penalties).
2660    let coefficient_influence_flat = match joint_recon.as_ref() {
2661        Some((f, _, _, _, _)) => Some(f.iter().copied().collect::<Vec<f64>>()),
2662        None => fit
2663            .covariance_conditional
2664            .as_ref()
2665            .filter(|c| c.nrows() == expected_joint && c.ncols() == expected_joint)
2666            .and_then(|hinv| {
2667                if fit.blocks.len() != m {
2668                    return None;
2669                }
2670                // Joint S_λ (block-diagonal across active classes).
2671                let mut s_lambda = Array2::<f64>::zeros((expected_joint, expected_joint));
2672                for (a, block) in fit.blocks.iter().enumerate() {
2673                    if block.lambdas.len() != penalties_arc.len() {
2674                        return None;
2675                    }
2676                    let base = a * p_per_class;
2677                    for (t, pen) in penalties_arc.iter().enumerate() {
2678                        let lam = block.lambdas[t];
2679                        if lam == 0.0 {
2680                            continue;
2681                        }
2682                        let dense = pen.to_dense();
2683                        if dense.nrows() != p_per_class || dense.ncols() != p_per_class {
2684                            return None;
2685                        }
2686                        for i in 0..p_per_class {
2687                            for j in 0..p_per_class {
2688                                s_lambda[[base + i, base + j]] += lam * dense[[i, j]];
2689                            }
2690                        }
2691                    }
2692                }
2693                // F = I − H⁻¹ S_λ.
2694                let hinv_s = hinv.dot(&s_lambda);
2695                let mut f = Array2::<f64>::eye(expected_joint);
2696                f -= &hinv_s;
2697                Some(f.iter().copied().collect::<Vec<f64>>())
2698            }),
2699    };
2700
2701    // Per-(smooth term) coefficient span within a single class block, deduped by
2702    // col_range (the #561 double-penalty migration emits two penalty blocks per
2703    // term sharing one col_range; the Wald test covers the whole term block once).
2704    let mut smooth_term_spans: Vec<MultinomialSmoothTermSpan> = Vec::new();
2705    for (pen_idx, bp) in design.penalties.iter().enumerate() {
2706        let col_start = bp.col_range.start;
2707        let col_end = bp.col_range.end;
2708        if col_start >= col_end || col_end > p_per_class {
2709            continue;
2710        }
2711        if smooth_term_spans
2712            .iter()
2713            .any(|s| s.col_start == col_start && s.col_end == col_end)
2714        {
2715            continue;
2716        }
2717        let label = design
2718            .penaltyinfo
2719            .get(pen_idx)
2720            .and_then(|info| info.termname.clone())
2721            .unwrap_or_else(|| format!("s{pen_idx}"));
2722        let nullspace_dim = design
2723            .nullspace_dims
2724            .get(pen_idx)
2725            .copied()
2726            .unwrap_or(0)
2727            .min(col_end - col_start);
2728        smooth_term_spans.push(MultinomialSmoothTermSpan {
2729            label,
2730            col_start,
2731            col_end,
2732            nullspace_dim,
2733        });
2734    }
2735
2736    // One descriptive label per penalty *component* within a single class block,
2737    // parallel to that block's λ slice (#1544). `design.penalties` is index-
2738    // parallel to every active class's `block.lambdas` (each block carries the
2739    // full per-component penalty list, validated above by
2740    // `block.lambdas.len() == penalties_arc.len()`), so iterating it in order
2741    // yields exactly `lambdas_per_block[0]` labels aligned with the per-block λ.
2742    // This is deliberately NOT deduped by col_range (unlike `smooth_term_spans`):
2743    // the double penalty's primary and null-space components share one col_range
2744    // but select independent λ, and each must keep its own label so the summary
2745    // renderer never collapses or drops a λ.
2746    let lambda_labels: Vec<String> = design
2747        .penalties
2748        .iter()
2749        .enumerate()
2750        .map(|(pen_idx, _)| penalty_component_label(design.penaltyinfo.get(pen_idx), pen_idx))
2751        .collect();
2752
2753    // Unpenalized deviance read directly from the converged unpenalized
2754    // log-likelihood the rho-prior driver already computed (issue #348):
2755    // MultinomialFamily::evaluate sets FamilyEvaluation.log_likelihood =
2756    // log_lik(η, y) with no penalty term, and that value flows unchanged into
2757    // UnifiedFitResult.log_likelihood. This reproduces the legacy fixed-λ
2758    // path's `deviance = -2 · log_lik` contract bit-for-bit, so the previous
2759    // row-by-row η = Xβ rebuild and softmax recompute were pure dead work.
2760    let deviance = -2.0 * fit.log_likelihood;
2761
2762    Ok(MultinomialSavedModel {
2763        formula: formula.to_string(),
2764        class_levels: class_levels.clone(),
2765        reference_class_index: class_levels.len() - 1,
2766        resolved_termspec: spec,
2767        coefficients_flat,
2768        p_per_class,
2769        n_active_classes: m,
2770        training_headers: data.headers.clone(),
2771        lambdas: lambdas_flat,
2772        lambdas_per_block,
2773        iterations: fit.inner_cycles,
2774        converged: fit.outer_converged,
2775        penalized_neg_log_likelihood: -fit.log_likelihood + 0.5 * fit.stable_penalty_term,
2776        deviance,
2777        edf_per_class,
2778        edf_per_penalty,
2779        coefficient_covariance_flat,
2780        coefficient_influence_flat,
2781        smooth_term_spans,
2782        lambda_labels,
2783    })
2784}
2785
2786/// Replay the saved termspec to build the predict-time dense design `X` on a
2787/// fresh dataset, realigning feature columns **by name** so the predict frame
2788/// need not reproduce the training column order or carry the response column.
2789/// Shared by every multinomial predict path (probabilities, SE bands, and the
2790/// posterior-predictive replicate draws).
2791fn build_multinomial_predict_design(
2792    model: &MultinomialSavedModel,
2793    data: &EncodedDataset,
2794) -> Result<Array2<f64>, EstimationError> {
2795    // The saved termspec stores feature columns as absolute indices into the
2796    // *training* table `[response, features...]`. Realign them onto this
2797    // dataset's columns by name, so prediction works on label-free new data
2798    // (the response column is never referenced by any term; issue #803).
2799    let predict_columns = data.column_map();
2800    let realigned = model.resolved_termspec.remap_feature_columns(
2801        |index| -> Result<usize, EstimationError> {
2802            let name = model.training_headers.get(index).ok_or_else(|| {
2803                EstimationError::InvalidInput(format!(
2804                    "multinomial predict: saved training column index {index} is out of bounds \
2805                     for {} training headers",
2806                    model.training_headers.len()
2807                ))
2808            })?;
2809            resolve_role_col(&predict_columns, name, "feature")
2810                .map_err(|err| EstimationError::InvalidInput(err.to_string()))
2811        },
2812    )?;
2813    let design = build_term_collection_design(data.values.view(), &realigned).map_err(|err| {
2814        EstimationError::InvalidInput(format!(
2815            "multinomial predict: rebuild design from saved termspec: {err}"
2816        ))
2817    })?;
2818    let x_dense = design
2819        .design
2820        .try_to_dense_by_chunks("multinomial predict design")
2821        .map_err(EstimationError::InvalidInput)?;
2822    if x_dense.ncols() != model.p_per_class {
2823        crate::bail_invalid_estim!(
2824            "multinomial predict: predict design has {} cols, saved model expects {}",
2825            x_dense.ncols(),
2826            model.p_per_class
2827        );
2828    }
2829    Ok(x_dense)
2830}
2831
2832/// Replay the saved termspec to build the predict-time design on a fresh
2833/// dataset, then evaluate softmax probabilities. The predict dataset must carry
2834/// the same feature columns the training data did, matched **by name** — it need
2835/// not reproduce the training column order, and in particular need not carry the
2836/// response column (prediction is for label-free new data).
2837pub fn predict_multinomial_formula(
2838    model: &MultinomialSavedModel,
2839    data: &EncodedDataset,
2840) -> Result<Array2<f64>, EstimationError> {
2841    let x_dense = build_multinomial_predict_design(model, data)?;
2842    Ok(model.predict_probabilities(x_dense.view()))
2843}
2844
2845/// Draw `n_draws` posterior-predictive replicate class-label assignments for a
2846/// saved multinomial model on fresh data (#1101). Rebuilds the predict design
2847/// exactly as [`predict_multinomial_formula`], then samples each row's class
2848/// from `Categorical(softmax(X·β̂))` (see
2849/// [`MultinomialSavedModel::sample_replicate_classes`]). Returns an
2850/// `(n_draws, N)` matrix of class INDICES `0..K` aligned to `model.class_levels`,
2851/// deterministic in `seed`.
2852pub fn posterior_predict_multinomial_formula(
2853    model: &MultinomialSavedModel,
2854    data: &EncodedDataset,
2855    n_draws: usize,
2856    seed: u64,
2857) -> Result<Array2<u32>, EstimationError> {
2858    if n_draws == 0 {
2859        crate::bail_invalid_estim!("multinomial posterior_predict: n_draws must be >= 1");
2860    }
2861    let x_dense = build_multinomial_predict_design(model, data)?;
2862    Ok(model.sample_replicate_classes(x_dense.view(), n_draws, seed))
2863}
2864
2865/// Predict class probabilities AND delta-method per-class probability standard
2866/// errors for a saved multinomial model on fresh data (#1101). Replays the
2867/// saved termspec to build the predict design exactly as
2868/// [`predict_multinomial_formula`], then applies the softmax-Jacobian delta
2869/// method against the stored joint posterior covariance. Returns
2870/// `(probs (N,K), prob_se (N,K) | None)`; `prob_se` is `None` for a legacy
2871/// model fitted before covariance was surfaced.
2872pub fn predict_multinomial_formula_with_se(
2873    model: &MultinomialSavedModel,
2874    data: &EncodedDataset,
2875) -> Result<(Array2<f64>, Option<Array2<f64>>), EstimationError> {
2876    let x_dense = build_multinomial_predict_design(model, data)?;
2877    Ok(model.predict_probabilities_with_se(x_dense.view()))
2878}
2879
2880#[cfg(test)]
2881mod fisher_override_tests {
2882    use super::*;
2883    use ndarray::Array3;
2884
2885    fn toy() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array1<f64>) {
2886        let n = 15;
2887        let p = 2;
2888        let k = 3;
2889        let design =
2890            Array2::<f64>::from_shape_fn(
2891                (n, p),
2892                |(i, j)| {
2893                    if j == 0 { 1.0 } else { ((i + 2) as f64).cos() }
2894                },
2895            );
2896        let mut y = Array2::<f64>::zeros((n, k));
2897        for i in 0..n {
2898            y[[i, i % k]] = 1.0;
2899        }
2900        let penalty = Array2::<f64>::eye(p);
2901        let lambdas = Array1::<f64>::from_elem(k - 1, 0.5);
2902        (design, y, penalty, lambdas)
2903    }
2904
2905    #[test]
2906    fn fisher_override_none_reproduces_analytic() {
2907        // Issue #349: None override is exactly the analytic fit.
2908        let (design, y, penalty, lambdas) = toy();
2909        let mk = |over: Option<ndarray::ArrayView3<'_, f64>>| {
2910            fit_penalized_multinomial(MultinomialFitInputs {
2911                design: design.view(),
2912                y_one_hot: y.view(),
2913                penalty: penalty.view(),
2914                lambdas: lambdas.view(),
2915                row_weights: None,
2916                fisher_w_override: over,
2917                max_iter: 50,
2918                tol: 1.0e-9,
2919            })
2920            .expect("fit must succeed")
2921        };
2922        let a = mk(None);
2923        let b = mk(None);
2924        for (x, z) in a
2925            .coefficients_active
2926            .iter()
2927            .zip(b.coefficients_active.iter())
2928        {
2929            assert_eq!(x, z);
2930        }
2931    }
2932
2933    #[test]
2934    fn fisher_override_wrong_shape_is_rejected() {
2935        let (design, y, penalty, lambdas) = toy();
2936        let n = design.nrows();
2937        let m = y.ncols(); // K, not K-1 — deliberately wrong
2938        let bad = Array3::<f64>::zeros((n, m, m));
2939        let err = fit_penalized_multinomial(MultinomialFitInputs {
2940            design: design.view(),
2941            y_one_hot: y.view(),
2942            penalty: penalty.view(),
2943            lambdas: lambdas.view(),
2944            row_weights: None,
2945            fisher_w_override: Some(bad.view()),
2946            max_iter: 50,
2947            tol: 1.0e-9,
2948        })
2949        .expect_err("wrong active-block shape must error");
2950        assert!(format!("{err}").contains("fisher_w_override shape"));
2951    }
2952
2953    /// #1101 regression: the fixed-λ inner solve now surfaces the joint Laplace
2954    /// coefficient covariance `H⁻¹`, and the multinomial predictor derives
2955    /// finite delta-method per-class probability standard errors from it. Before
2956    /// this change `MultinomialFitOutputs` carried NO covariance at all, so the
2957    /// covariance-dimension / predictor assertions below could not even compile
2958    /// (fail-before). Asserts, with un-weakened bounds:
2959    ///   1. covariance is `(P·(K−1))²`, all-finite, symmetric, and PSD (every
2960    ///      diagonal ≥ 0 and `vᵀΣv ≥ 0` on probe vectors);
2961    ///   2. the delta-method per-class probability SEs are finite and within
2962    ///      `[0, 1]` (a probability SE can never exceed the unit interval);
2963    ///   3. predicted probabilities are finite, in `[0, 1]`, and each row sums
2964    ///      to 1 (simplex).
2965    #[test]
2966    fn covariance_and_delta_method_se_are_finite_and_wellformed_1101() {
2967        let (design, y, penalty, lambdas) = toy();
2968        let p = design.ncols();
2969        let k = y.ncols();
2970        let m = k - 1;
2971        let d = p * m;
2972
2973        let fit = fit_penalized_multinomial(MultinomialFitInputs {
2974            design: design.view(),
2975            y_one_hot: y.view(),
2976            penalty: penalty.view(),
2977            lambdas: lambdas.view(),
2978            row_weights: None,
2979            fisher_w_override: None,
2980            max_iter: 50,
2981            tol: 1.0e-9,
2982        })
2983        .expect("fit must succeed");
2984        assert!(fit.converged, "toy multinomial fit must converge");
2985
2986        // (1) Covariance shape, finiteness, symmetry.
2987        let cov = &fit.coefficient_covariance;
2988        assert_eq!(
2989            cov.dim(),
2990            (d, d),
2991            "covariance must be (P·(K−1))² = ({d},{d})"
2992        );
2993        for &v in cov.iter() {
2994            assert!(v.is_finite(), "covariance entry must be finite (got {v})");
2995        }
2996        for i in 0..d {
2997            for j in 0..d {
2998                let asym = (cov[[i, j]] - cov[[j, i]]).abs();
2999                assert!(
3000                    asym <= 1e-9 * (1.0 + cov[[i, j]].abs()),
3001                    "covariance must be symmetric at ({i},{j}): |Σ_ij − Σ_ji| = {asym:.3e}"
3002                );
3003            }
3004        }
3005        // PSD: diagonal ≥ 0 and quadratic forms on deterministic probe vectors
3006        // (unit axes and the all-ones vector) are non-negative. `H = XᵀWX + λS`
3007        // with W PSD (softmax Fisher) and S PSD (identity here) is positive
3008        // definite, so its inverse is PD; these probes must all be positive.
3009        for i in 0..d {
3010            assert!(
3011                cov[[i, i]] >= 0.0,
3012                "covariance diagonal[{i}] must be ≥ 0 (got {})",
3013                cov[[i, i]]
3014            );
3015        }
3016        let mut probes: Vec<Vec<f64>> = Vec::new();
3017        for i in 0..d {
3018            let mut e = vec![0.0_f64; d];
3019            e[i] = 1.0;
3020            probes.push(e);
3021        }
3022        probes.push(vec![1.0_f64; d]);
3023        for v in &probes {
3024            let mut q = 0.0_f64;
3025            for i in 0..d {
3026                for j in 0..d {
3027                    q += v[i] * cov[[i, j]] * v[j];
3028                }
3029            }
3030            assert!(
3031                q >= -1e-9,
3032                "covariance must be PSD: vᵀΣv = {q:.3e} < 0"
3033            );
3034        }
3035
3036        // (2) & (3) Delta-method SEs and simplex probabilities on the training
3037        // design (any P-column matrix in the fitted basis works).
3038        let (probs, prob_se) = fit
3039            .predict_probabilities_with_se(design.view())
3040            .expect("delta-method SE must succeed");
3041        let n = design.nrows();
3042        assert_eq!(probs.dim(), (n, k));
3043        assert_eq!(prob_se.dim(), (n, k));
3044        for row in 0..n {
3045            let mut rowsum = 0.0_f64;
3046            for c in 0..k {
3047                let pc = probs[[row, c]];
3048                assert!(pc.is_finite() && (0.0..=1.0).contains(&pc), "prob[{row},{c}]={pc}");
3049                rowsum += pc;
3050                let se = prob_se[[row, c]];
3051                assert!(se.is_finite(), "prob_se[{row},{c}] must be finite (got {se})");
3052                assert!(
3053                    (0.0..=1.0).contains(&se),
3054                    "prob_se[{row},{c}] must be in [0,1] (got {se})"
3055                );
3056            }
3057            assert!(
3058                (rowsum - 1.0).abs() < 1e-9,
3059                "row {row} probabilities must sum to 1 (got {rowsum})"
3060            );
3061        }
3062    }
3063
3064    #[test]
3065    fn formula_outer_route_uses_exact_curvature_for_medium_d() {
3066        // The 2-smooth reference formula fit (K = 3, double-penalty terms) is
3067        // D = (K-1) * 2 terms * 2 penalties = 8 and needs exact curvature to
3068        // avoid over-smoothed lambda caps (#715 arm (a)).
3069        assert!(
3070            multinomial_formula_use_outer_hessian(8),
3071            "D=8 loaded multinomial fits need exact curvature to avoid over-smoothed lambda caps"
3072        );
3073        assert!(
3074            multinomial_formula_use_outer_hessian(12),
3075            "D=12 (3 double-penalty smooth terms, K=3) stays on exact curvature"
3076        );
3077    }
3078
3079    #[test]
3080    fn formula_outer_route_uses_exact_curvature_for_d16_penguin_fixture() {
3081        // Four k=10 penguin smooths (K = 3) are D = 16 under double-penalty
3082        // terms. They must reach the exact ARC route so the #1082 cost-stall
3083        // halt is available on the near-separable lambda-to-zero ridge.
3084        assert!(
3085            multinomial_formula_use_outer_hessian(16),
3086            "D=16 multinomial fits need exact ARC curvature for the #1082 stall halt"
3087        );
3088    }
3089
3090    #[test]
3091    fn formula_min_lambda_floor_is_continuous_and_information_scaled() {
3092        // Build a one-hot label matrix whose smallest class carries `count` rows.
3093        fn floor_for_min_count(count: usize) -> f64 {
3094            // Two classes: a large one (1000 rows) and a minority one (`count`).
3095            let n = 1000 + count;
3096            let mut y = Array2::<f64>::zeros((n, 2));
3097            for r in 0..1000 {
3098                y[[r, 0]] = 1.0;
3099            }
3100            for r in 1000..n {
3101                y[[r, 1]] = 1.0;
3102            }
3103            multinomial_formula_min_lambda(y.view())
3104        }
3105
3106        // The floor's endpoints are now DERIVED from a target prior strength in
3107        // pseudo-observations against the maximal per-observation softmax Fisher
3108        // information I₁ = ¼ (base = τ·I₁, sparse = τ_max·I₁). Pin them to the
3109        // previously fixture-calibrated values so the near-separable quality arms
3110        // (penguins, vgam softmax) — whose smallest class has n_c ≥ 50 — are
3111        // byte-for-byte unaffected: the derivation REDUCES TO the old constants
3112        // at the calibration point.
3113        let base = MULTINOMIAL_FORMULA_PRIOR_PSEUDO_OBS * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
3114        let sparse = MULTINOMIAL_FORMULA_SPARSE_PRIOR_PSEUDO_OBS_MAX
3115            * MULTINOMIAL_FORMULA_FISHER_INFO_PER_OBS;
3116        assert!(
3117            (base - 2.0e-4).abs() < 1e-18,
3118            "derived base floor must equal the calibrated 2e-4"
3119        );
3120        assert!(
3121            (sparse - 1.0e-3).abs() < 1e-18,
3122            "derived sparse floor must equal the calibrated 1e-3"
3123        );
3124
3125        // Well-supported (n_c >= n_ref=50) sits exactly at the base floor.
3126        assert!((floor_for_min_count(50) - base).abs() < 1e-18);
3127        assert!((floor_for_min_count(200) - base).abs() < 1e-18);
3128        // Very sparse (n_c <= n_ref·base/sparse = 10) clamps to the strong floor.
3129        assert!((floor_for_min_count(10) - sparse).abs() < 1e-18);
3130        assert!((floor_for_min_count(5) - sparse).abs() < 1e-18);
3131        // No cliff at the old hard threshold: 49 vs 50 differ by < 5% (the old
3132        // step jumped 5x). Floor is monotone non-increasing in support.
3133        let f49 = floor_for_min_count(49);
3134        let f50 = floor_for_min_count(50);
3135        assert!(
3136            f49 >= f50 && f49 <= f50 * 1.05,
3137            "floor must be continuous across c0, got {f49} vs {f50}"
3138        );
3139        let f25 = floor_for_min_count(25);
3140        assert!(
3141            f25 > f50 && f25 < floor_for_min_count(10),
3142            "mid-support floor must interpolate strictly between the two endpoints"
3143        );
3144
3145        // FIRST-PRINCIPLES SCALING: in the interpolating regime the floor equals
3146        // exactly τ·I₁·(n_ref/n_c) — the effective-pseudo-observation prior held
3147        // to a fixed fraction of the per-class data information n_c·I₁. Halving
3148        // the effective sample size doubles the floor (until the cap), and the
3149        // absolute value matches the closed-form n_c-scaled prior.
3150        for &n_c in &[12usize, 16, 20, 30, 40] {
3151            let expected = base * (MULTINOMIAL_FORMULA_SPARSE_REFERENCE_SUPPORT / n_c as f64);
3152            assert!(
3153                (floor_for_min_count(n_c) - expected).abs() < 1e-15,
3154                "floor at n_c={n_c} must be τ·I₁·n_ref/n_c = {expected}, got {}",
3155                floor_for_min_count(n_c)
3156            );
3157        }
3158        // Inverse scaling with effective sample size: n_c -> n_c/2 doubles the
3159        // floor inside the unclamped band (20 and 40 are both interior; 40 < 50
3160        // so it is scaled, 20 > 10 so it is not capped).
3161        assert!(
3162            (floor_for_min_count(20) - 2.0 * floor_for_min_count(40)).abs() < 1e-15,
3163            "floor must scale like 1/n_c (effective Fisher information) in the interior band"
3164        );
3165    }
3166
3167    #[test]
3168    fn formula_penalty_scale_tracks_softmax_fisher_curvature() {
3169        assert!(
3170            (multinomial_formula_penalty_scale(2) - 0.5).abs() < 1.0e-12,
3171            "binary-logit neutral-simplex curvature scale should remain at 1/2"
3172        );
3173        assert!(
3174            (multinomial_formula_penalty_scale(3) - 4.0 / 9.0).abs() < 1.0e-12,
3175            "three-class softmax penalties should be calibrated to 2*(K-1)/K^2"
3176        );
3177        assert!(
3178            multinomial_formula_penalty_scale(5) < multinomial_formula_penalty_scale(3),
3179            "active-class Fisher curvature decreases as the simplex gains classes"
3180        );
3181    }
3182
3183    #[test]
3184    fn fixed_lambda_multinomial_firth_keeps_complete_separation_finite() {
3185        // #1854: complete softmax separation used to be a HARD diagnostic
3186        // (`MultinomialSeparationDetected`). It now automatically engages the
3187        // Firth/Jeffreys proper prior (`½ log|I(β)|`, magic-by-default) so the fit
3188        // stays finite instead of running away — the same guarantee the formula
3189        // REML path already provided. The class regions are cleanly separated by
3190        // `x`, so the unbiased MLE is at infinity; the Firth-penalized fit must
3191        // still converge to a finite mode and recover the region structure.
3192        let n = 90;
3193        let design = Array2::<f64>::from_shape_fn((n, 2), |(row, col)| match col {
3194            0 => 1.0,
3195            _ => -3.0 + 6.0 * (row as f64) / ((n - 1) as f64),
3196        });
3197        let mut y = Array2::<f64>::zeros((n, 3));
3198        for row in 0..n {
3199            let x = design[[row, 1]];
3200            let class = if x < -1.0 {
3201                0
3202            } else if x > 1.0 {
3203                1
3204            } else {
3205                2
3206            };
3207            y[[row, class]] = 1.0;
3208        }
3209        let penalty = Array2::<f64>::zeros((2, 2));
3210        let lambdas = Array1::<f64>::zeros(2);
3211        let out = fit_penalized_multinomial(MultinomialFitInputs {
3212            design: design.view(),
3213            y_one_hot: y.view(),
3214            penalty: penalty.view(),
3215            lambdas: lambdas.view(),
3216            row_weights: None,
3217            fisher_w_override: None,
3218            max_iter: 80,
3219            tol: 1.0e-12,
3220        })
3221        .expect("Firth/Jeffreys prior keeps the separated multinomial fit finite (#1854)");
3222        assert!(
3223            out.converged,
3224            "the Firth-penalized separation refit must report convergence"
3225        );
3226        // Every coefficient is finite — the whole point of the Firth prior on the
3227        // separated (unpenalized) logit directions.
3228        for &b in out.coefficients_active.iter() {
3229            assert!(
3230                b.is_finite(),
3231                "Firth-penalized coefficients must be finite, got {b}"
3232            );
3233        }
3234        // Fitted probabilities remain a valid simplex per row.
3235        for row in 0..n {
3236            let mut mass = 0.0_f64;
3237            for c in 0..3 {
3238                let p = out.fitted_probabilities[[row, c]];
3239                assert!(
3240                    p.is_finite() && (0.0..=1.0 + 1e-9).contains(&p),
3241                    "row {row} class {c} probability {p} out of [0,1]"
3242                );
3243                mass += p;
3244            }
3245            assert!(
3246                (mass - 1.0).abs() < 1e-6,
3247                "row {row} probabilities must sum to 1, got {mass}"
3248            );
3249        }
3250        // The finite fit still recovers the separated structure: on a clearly
3251        // interior representative of each region the predicted class is correct.
3252        let predict = |x: f64| -> usize {
3253            let mut eta = [0.0_f64; 3];
3254            for a in 0..2 {
3255                eta[a] = out.coefficients_active[[0, a]] + out.coefficients_active[[1, a]] * x;
3256            }
3257            let mut best = 0usize;
3258            for c in 1..3 {
3259                if eta[c] > eta[best] {
3260                    best = c;
3261                }
3262            }
3263            best
3264        };
3265        assert_eq!(predict(-2.5), 0, "deep-left region should predict class 0");
3266        assert_eq!(predict(2.5), 1, "deep-right region should predict class 1");
3267        assert_eq!(predict(0.0), 2, "central region should predict class 2");
3268    }
3269
3270    #[test]
3271    fn formula_multinomial_accepts_finite_saturated_logits() {
3272        // A saturated-but-FINITE logit surface can be a valid formula REML mode
3273        // (the #715 penguins regime: bill/flipper cleanly separate the species,
3274        // so fitted logits can legitimately exceed ±25). `outer_converged ==
3275        // false` then signals only that the driver auto-escalated to never-fail
3276        // posterior sampling about that finite mode (gam#860), NOT a separation
3277        // artifact — the adapter must accept it, never raise
3278        // `MultinomialSeparationDetected`.
3279        let saturated_states = vec![
3280            ParameterBlockState {
3281                beta: Array1::from_vec(vec![1.0, 2.0]),
3282                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
3283            },
3284            ParameterBlockState {
3285                beta: Array1::from_vec(vec![-1.0, 3.0]),
3286                eta: Array1::from_vec(vec![1.0, 25.5, -0.1]),
3287            },
3288        ];
3289        assert!(
3290            multinomial_formula_separation_diagnostic(17, 9, &saturated_states).is_none(),
3291            "a finite (even saturated, |eta|>25) formula optimum is a valid fit, \
3292             not a separation diagnostic"
3293        );
3294
3295        // Only a genuinely NON-FINITE logit — a NaN/Inf blow-up in the inner
3296        // linear algebra with no finite mode to sample about — is a real
3297        // formula-path failure.
3298        let blown_up = vec![
3299            ParameterBlockState {
3300                beta: Array1::from_vec(vec![1.0, 2.0]),
3301                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
3302            },
3303            ParameterBlockState {
3304                beta: Array1::from_vec(vec![-1.0, 3.0]),
3305                eta: Array1::from_vec(vec![1.0, f64::INFINITY, -0.1]),
3306            },
3307        ];
3308        let err = multinomial_formula_separation_diagnostic(17, 9, &blown_up)
3309            .expect("a non-finite formula logit must raise the separation diagnostic");
3310        assert!(
3311            matches!(
3312                err,
3313                EstimationError::MultinomialSeparationDetected {
3314                    iteration: 17,
3315                    max_abs_eta,
3316                    active_class_index: 1,
3317                    row_index: 1,
3318                } if !max_abs_eta.is_finite()
3319            ),
3320            "expected typed multinomial separation diagnostic at the non-finite channel, got {err:?}"
3321        );
3322    }
3323
3324    #[test]
3325    fn separation_evidence_gate_arms_firth_only_on_blowup() {
3326        // Interior fit: finite logits well inside the saturation threshold ⇒ NO
3327        // separation evidence ⇒ the unbiased criterion's mode is accepted as-is
3328        // and the Firth/Jeffreys prior stays disarmed (#715 arm (a): no 1/K
3329        // shrinkage on well-identified data).
3330        let interior = vec![
3331            ParameterBlockState {
3332                beta: Array1::from_vec(vec![1.0, 2.0]),
3333                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
3334            },
3335            ParameterBlockState {
3336                beta: Array1::from_vec(vec![-1.0, 3.0]),
3337                eta: Array1::from_vec(vec![1.0, -3.5, -0.1]),
3338            },
3339        ];
3340        assert!(
3341            multinomial_formula_separation_evidence(&interior).is_none(),
3342            "an interior finite mode must not arm the Firth refit"
3343        );
3344
3345        // Saturated but finite logits are valid formula-path modes on
3346        // near-separated real data. They must not arm the Firth refit because
3347        // the Jeffreys pull can over-regularize the held-out probabilities.
3348        let saturated = vec![
3349            ParameterBlockState {
3350                beta: Array1::from_vec(vec![1.0, 2.0]),
3351                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
3352            },
3353            ParameterBlockState {
3354                beta: Array1::from_vec(vec![-1.0, 3.0]),
3355                eta: Array1::from_vec(vec![1.0, 25.5, -0.1]),
3356            },
3357        ];
3358        assert!(
3359            multinomial_formula_separation_evidence(&saturated).is_none(),
3360            "a finite saturated formula-mode logit must not arm the Firth refit"
3361        );
3362
3363        // Non-finite logit ⇒ inner blow-up along an unbounded direction ⇒
3364        // separation evidence.
3365        let blown_up = vec![ParameterBlockState {
3366            beta: Array1::from_vec(vec![1.0, 2.0]),
3367            eta: Array1::from_vec(vec![0.2, f64::NAN, -7.0]),
3368        }];
3369        let evidence = multinomial_formula_separation_evidence(&blown_up)
3370            .expect("a non-finite logit is separation evidence");
3371        assert!(
3372            evidence.contains("non-finite logit") && evidence.contains("row 1"),
3373            "evidence must name the non-finite logit, got {evidence}"
3374        );
3375
3376        // Large finite logits below the fixed-lambda diagnostic threshold are
3377        // likewise accepted on the formula path.
3378        let near = vec![ParameterBlockState {
3379            beta: Array1::from_vec(vec![1.0, 2.0]),
3380            eta: Array1::from_vec(vec![0.2, 24.9, -24.9]),
3381        }];
3382        assert!(
3383            multinomial_formula_separation_evidence(&near).is_none(),
3384            "logits below the saturation threshold must not arm the Firth refit"
3385        );
3386    }
3387
3388    #[test]
3389    fn unresolved_probe_evidence_arms_firth_on_saturated_finite_logits() {
3390        let saturated = vec![
3391            ParameterBlockState {
3392                beta: Array1::from_vec(vec![1.0, 2.0]),
3393                eta: Array1::from_vec(vec![0.2, 4.0, -7.0]),
3394            },
3395            ParameterBlockState {
3396                beta: Array1::from_vec(vec![-1.0, 3.0]),
3397                eta: Array1::from_vec(vec![1.0, 25.5, -0.1]),
3398            },
3399        ];
3400
3401        assert!(
3402            multinomial_formula_separation_evidence(&saturated).is_none(),
3403            "a converged finite saturated formula optimum remains unbiased"
3404        );
3405        let evidence = multinomial_formula_unresolved_probe_separation_evidence(&saturated)
3406            .expect("a non-converged saturated probe should arm the Firth refit");
3407        assert!(
3408            evidence.contains("separation-scale finite logit")
3409                && evidence.contains("row 1")
3410                && evidence.contains("active class 1"),
3411            "unresolved-probe evidence should name the saturated channel, got {evidence}"
3412        );
3413
3414        let near = vec![ParameterBlockState {
3415            beta: Array1::from_vec(vec![1.0, 2.0]),
3416            eta: Array1::from_vec(vec![0.2, 24.9, -24.9]),
3417        }];
3418        assert!(
3419            multinomial_formula_unresolved_probe_separation_evidence(&near).is_none(),
3420            "finite logits below the separation threshold still get the full unbiased retry"
3421        );
3422    }
3423
3424    #[test]
3425    fn scaled_fisher_override_changes_first_step() {
3426        // Curvature scaled by 4× shrinks the first Newton step relative to the
3427        // analytic fit, so a single-iteration fit must differ.
3428        let (design, y, penalty, lambdas) = toy();
3429        let n = design.nrows();
3430        let m = y.ncols() - 1;
3431        // Analytic block at β = 0: p_a = 1/K = 1/3, so diag = p_a(1−p_a),
3432        // off-diag = −p_a p_b. Scale that exact block by 4.
3433        let pk = 1.0 / (y.ncols() as f64);
3434        let mut over = Array3::<f64>::zeros((n, m, m));
3435        for row in 0..n {
3436            for a in 0..m {
3437                for b in 0..m {
3438                    let analytic = if a == b { pk * (1.0 - pk) } else { -pk * pk };
3439                    over[[row, a, b]] = 4.0 * analytic;
3440                }
3441            }
3442        }
3443        let scaled = fit_penalized_multinomial(MultinomialFitInputs {
3444            design: design.view(),
3445            y_one_hot: y.view(),
3446            penalty: penalty.view(),
3447            lambdas: lambdas.view(),
3448            row_weights: None,
3449            fisher_w_override: Some(over.view()),
3450            max_iter: 1,
3451            tol: 1.0e-9,
3452        })
3453        .expect("override fit must succeed");
3454        let analytic = fit_penalized_multinomial(MultinomialFitInputs {
3455            design: design.view(),
3456            y_one_hot: y.view(),
3457            penalty: penalty.view(),
3458            lambdas: lambdas.view(),
3459            row_weights: None,
3460            fisher_w_override: None,
3461            max_iter: 1,
3462            tol: 1.0e-9,
3463        })
3464        .expect("analytic fit must succeed");
3465        let differs = scaled
3466            .coefficients_active
3467            .iter()
3468            .zip(analytic.coefficients_active.iter())
3469            .any(|(a, b)| (a - b).abs() > 1.0e-6);
3470        assert!(differs, "scaled curvature must change the first step");
3471    }
3472}
3473
3474#[cfg(test)]
3475mod separation_firth_tests {
3476    //! Regression for #1854: on (quasi-)perfect separation the fixed-λ direct
3477    //! multinomial solve must engage the Firth/Jeffreys penalty and return a
3478    //! finite, converged, well-behaved fit instead of hard-erroring with
3479    //! `MultinomialSeparationDetected`.
3480    use super::*;
3481
3482    /// A perfectly linearly separable 3-class problem with an UNPENALIZED design
3483    /// (`S = 0`), so no smoothing `λ` can bound the saturated logits — only the
3484    /// Firth prior `½ log det I(β)` keeps the estimate finite. The unbiased MLE
3485    /// here runs `|η| → ∞` (separation), which is exactly the #1854 trigger.
3486    fn separated_three_class() -> (Array2<f64>, Array2<f64>, Array2<f64>, Array1<f64>) {
3487        let n = 21;
3488        let p = 2; // intercept + ordering covariate x
3489        let k = 3;
3490        let mut design = Array2::<f64>::zeros((n, p));
3491        let mut y = Array2::<f64>::zeros((n, k));
3492        for i in 0..n {
3493            let x = -3.0 + 6.0 * (i as f64) / ((n - 1) as f64);
3494            design[[i, 0]] = 1.0;
3495            design[[i, 1]] = x;
3496            let cls = if x < -1.0 {
3497                0
3498            } else if x < 1.0 {
3499                1
3500            } else {
3501                2
3502            };
3503            y[[i, cls]] = 1.0;
3504        }
3505        // S = 0: no smoothing direction can bound the separated logits.
3506        let penalty = Array2::<f64>::zeros((p, p));
3507        let lambdas = Array1::<f64>::from_elem(k - 1, 1.0);
3508        (design, y, penalty, lambdas)
3509    }
3510
3511    #[test]
3512    fn separation_engages_firth_finite_converged_fit() {
3513        let (design, y, penalty, lambdas) = separated_three_class();
3514        let out = fit_penalized_multinomial(MultinomialFitInputs {
3515            design: design.view(),
3516            y_one_hot: y.view(),
3517            penalty: penalty.view(),
3518            lambdas: lambdas.view(),
3519            row_weights: None,
3520            fisher_w_override: None,
3521            max_iter: 300,
3522            tol: 1e-10,
3523        })
3524        .expect("separated multinomial must engage Firth and return a fit, not error");
3525
3526        assert!(out.converged, "Firth-engaged separation fit must converge");
3527        assert!(
3528            out.coefficients_active.iter().all(|v| v.is_finite()),
3529            "all coefficients must be finite under the Firth prior"
3530        );
3531        assert!(out.deviance.is_finite(), "deviance must be finite");
3532
3533        // The runaway MLE would drive fitted probabilities to the {0,1} boundary;
3534        // the Firth prior keeps them strictly interior.
3535        for v in out.fitted_probabilities.iter() {
3536            assert!(
3537                *v > 0.0 && *v < 1.0,
3538                "Firth fit must stay interior, got p={v}"
3539            );
3540        }
3541
3542        // Perfect separation ⇒ every training row classified to its true class.
3543        let n = design.nrows();
3544        let k = y.ncols();
3545        for i in 0..n {
3546            let mut best = 0usize;
3547            for c in 1..k {
3548                if out.fitted_probabilities[[i, c]] > out.fitted_probabilities[[i, best]] {
3549                    best = c;
3550                }
3551            }
3552            let truth = (0..k)
3553                .find(|&c| y[[i, c]] == 1.0)
3554                .expect("one-hot truth class");
3555            assert_eq!(best, truth, "row {i} misclassified under separation");
3556        }
3557    }
3558
3559    #[test]
3560    fn separation_firth_returns_finite_wellshaped_covariance() {
3561        // Distinct angle: the Firth separation path must also expose a finite,
3562        // correctly-shaped (P·M × P·M) Laplace coefficient covariance — the
3563        // downstream SE machinery consumes it. A runaway MLE would have a
3564        // singular (non-invertible) information here.
3565        let (design, y, penalty, lambdas) = separated_three_class();
3566        let p = design.ncols();
3567        let k = y.ncols();
3568        let m = k - 1;
3569        let out = fit_penalized_multinomial(MultinomialFitInputs {
3570            design: design.view(),
3571            y_one_hot: y.view(),
3572            penalty: penalty.view(),
3573            lambdas: lambdas.view(),
3574            row_weights: None,
3575            fisher_w_override: None,
3576            max_iter: 300,
3577            tol: 1e-10,
3578        })
3579        .expect("separated multinomial must return a Firth fit");
3580
3581        assert_eq!(
3582            out.coefficient_covariance.dim(),
3583            (p * m, p * m),
3584            "covariance must be P·M square"
3585        );
3586        assert!(
3587            out.coefficient_covariance.iter().all(|v| v.is_finite()),
3588            "Firth covariance entries must be finite"
3589        );
3590        // A genuine Laplace covariance is PSD ⇒ non-negative diagonal.
3591        for i in 0..(p * m) {
3592            assert!(
3593                out.coefficient_covariance[[i, i]] >= -1e-9,
3594                "covariance diagonal must be non-negative, got {}",
3595                out.coefficient_covariance[[i, i]]
3596            );
3597        }
3598    }
3599
3600    #[test]
3601    fn firth_solver_does_not_over_report_convergence_when_truncated() {
3602        // #2066 (convergence honesty): the Firth Newton loop must report
3603        // `converged` according to a genuine stationarity criterion — never as a
3604        // side effect of simply stopping. Before the fix, a line-search stall set
3605        // `converged = true` unconditionally; more broadly, the flag must be
3606        // false whenever the solve is stopped short of stationarity.
3607        //
3608        // Angle: run the SAME separated problem that converges under a full
3609        // budget (`separation_engages_firth_finite_converged_fit`) but starve the
3610        // iteration budget so it provably cannot reach the interior Firth mode.
3611        // The honest report is `converged = false`; the coefficients must still be
3612        // finite (no NaN leak from the truncated iterate).
3613        let (design, y, penalty, lambdas) = separated_three_class();
3614
3615        let truncated = fit_penalized_multinomial_firth_fallback(
3616            design.view(),
3617            y.view(),
3618            penalty.view(),
3619            lambdas.view(),
3620            None,
3621            1, // one Newton iteration — far from the separated mode
3622            1e-12,
3623        )
3624        .expect("Firth fallback must return a (non-converged) fit, not error");
3625        assert!(
3626            !truncated.converged,
3627            "a Firth solve stopped after one iteration on a separated problem \
3628             must report converged=false, not paper over non-convergence"
3629        );
3630        assert!(
3631            truncated.coefficients_active.iter().all(|v| v.is_finite()),
3632            "truncated Firth iterate must remain finite"
3633        );
3634
3635        // Contrast: with a full budget the same problem does reach stationarity
3636        // and is honestly reported as converged — so the flag tracks the solve,
3637        // not the exit.
3638        let full = fit_penalized_multinomial_firth_fallback(
3639            design.view(),
3640            y.view(),
3641            penalty.view(),
3642            lambdas.view(),
3643            None,
3644            300,
3645            1e-10,
3646        )
3647        .expect("Firth fallback must converge under a full budget");
3648        assert!(
3649            full.converged,
3650            "full-budget Firth solve on the separated problem must converge"
3651        );
3652    }
3653}
3654
3655#[cfg(test)]
3656mod reference_class_invariance_tests {
3657    //! Regression for #1587: a penalized multinomial-logit GAM fit must be
3658    //! invariant to which class is the (arbitrary) softmax reference/baseline.
3659    //!
3660    //! The production REML path (`fit_penalized_multinomial_formula`) reference-
3661    //! codes the `K` classes (the last sorted label is the baseline) and, with
3662    //! the legacy `Diagonal` penalty metric, penalizes only the `K−1`
3663    //! reference-anchored ALR contrasts `½ Σ_a λ_a β_aᵀ S β_a`. Relabeling the
3664    //! response so a *different* class sorts last penalizes a different frame of
3665    //! log-odds contrasts, so the predicted probabilities drift (~1e-2 absolute)
3666    //! even though they are mathematically independent of the reference choice.
3667    //!
3668    //! This test fits the SAME 3-class softmax sample under three cyclic
3669    //! relabelings — each making a different original class the baseline —
3670    //! realigns the predicted probability columns back to the original class
3671    //! identities, and asserts the cross-labeling drift is below `1e-3`
3672    //! (the defect is ~1e-2; refitting the same labeling twice agrees to
3673    //! ~1e-12). It is the Rust-level sibling of
3674    //! `tests/bug_hunt_multinomial_fit_depends_on_reference_class_test.py`.
3675
3676    use super::*;
3677    use gam_data::load_dataset_projected;
3678    use std::fmt::Write as _;
3679    use std::fs;
3680    use tempfile::tempdir;
3681
3682    /// Deterministic `splitmix64` → `[0,1)` uniform stream (no external RNG dep;
3683    /// the only requirement is a well-distributed, reproducible draw).
3684    struct SplitMix64(u64);
3685    impl SplitMix64 {
3686        fn next_u64(&mut self) -> u64 {
3687            self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
3688            let mut z = self.0;
3689            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
3690            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
3691            z ^ (z >> 31)
3692        }
3693        fn unit(&mut self) -> f64 {
3694            // 53-bit mantissa uniform in [0, 1).
3695            (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
3696        }
3697    }
3698
3699    /// Draw a clean 3-class softmax regression sample (the issue's generator).
3700    /// Returns `(x, class)` with integer classes `0/1/2`.
3701    fn sample_classes(seed: u64, n: usize) -> (Vec<f64>, Vec<usize>) {
3702        let mut rng = SplitMix64(seed.wrapping_add(0x1234_5678));
3703        let mut x = Vec::with_capacity(n);
3704        let mut cls = Vec::with_capacity(n);
3705        for _ in 0..n {
3706            let xi = -2.0 + 4.0 * rng.unit();
3707            let eta = [0.5 + 0.8 * xi, -0.3 - 0.5 * xi, 0.0];
3708            let mut p = [eta[0].exp(), eta[1].exp(), eta[2].exp()];
3709            let s: f64 = p.iter().sum();
3710            for v in &mut p {
3711                *v /= s;
3712            }
3713            // Inverse-CDF draw into one of the 3 classes.
3714            let u = rng.unit();
3715            let c = if u < p[0] {
3716                0
3717            } else if u < p[0] + p[1] {
3718                1
3719            } else {
3720                2
3721            };
3722            x.push(xi);
3723            cls.push(c);
3724        }
3725        (x, cls)
3726    }
3727
3728    /// Build an `EncodedDataset` with columns `x` (numeric) and `y`
3729    /// (categorical, from the given string labels) by round-tripping a CSV.
3730    fn dataset_xy(dir: &std::path::Path, tag: &str, x: &[f64], y: &[String]) -> gam_data::EncodedDataset {
3731        let path = dir.join(format!("data_{tag}.csv"));
3732        let mut csv = String::from("x,y\n");
3733        for (xi, yi) in x.iter().zip(y.iter()) {
3734            writeln!(csv, "{xi},{yi}").unwrap();
3735        }
3736        fs::write(&path, csv).expect("write training csv");
3737        load_dataset_projected(&path, &["x".to_string(), "y".to_string()])
3738            .expect("load training dataset")
3739    }
3740
3741    /// Fit `y ~ s(x)` under the relabeling `name_map` (original class `c` gets
3742    /// label `name_map[c]`), predict on `grid`, and return the predicted
3743    /// probabilities **realigned to the original class order** 0/1/2, shape
3744    /// `(grid.len(), 3)`.
3745    fn fit_predict_aligned(
3746        dir: &std::path::Path,
3747        tag: &str,
3748        x: &[f64],
3749        cls: &[usize],
3750        name_map: [&str; 3],
3751        grid: &[f64],
3752    ) -> Array2<f64> {
3753        let labels: Vec<String> = cls.iter().map(|&c| name_map[c].to_string()).collect();
3754        let train = dataset_xy(dir, tag, x, &labels);
3755        let config = FitConfig::default();
3756        let model = fit_penalized_multinomial_formula(&train, "y ~ s(x)", &config, 1.0, 60, 1e-6)
3757            .expect("multinomial formula fit must succeed");
3758
3759        // Predict on the grid. The categorical `y` column is not needed for
3760        // prediction, but the schema is simplest if we supply a dummy.
3761        let grid_y: Vec<String> = grid.iter().map(|_| name_map[0].to_string()).collect();
3762        let grid_ds = dataset_xy(dir, &format!("{tag}_grid"), grid, &grid_y);
3763        let probs = predict_multinomial_formula(&model, &grid_ds)
3764            .expect("multinomial predict must succeed");
3765
3766        // `model.class_levels` is the sorted label order; the column for original
3767        // class `c` is at the rank of `name_map[c]` among the sorted labels.
3768        let mut sorted: Vec<&str> = name_map.to_vec();
3769        sorted.sort_unstable();
3770        let col_of_orig: Vec<usize> = (0..3)
3771            .map(|c| sorted.iter().position(|l| *l == name_map[c]).unwrap())
3772            .collect();
3773        // Sanity: the model's class_levels must match the sorted labels.
3774        assert_eq!(
3775            model.class_levels,
3776            sorted.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
3777            "class_levels must be the sorted label order"
3778        );
3779        let n = grid.len();
3780        let mut aligned = Array2::<f64>::zeros((n, 3));
3781        for r in 0..n {
3782            for c in 0..3 {
3783                aligned[[r, c]] = probs[[r, col_of_orig[c]]];
3784            }
3785        }
3786        aligned
3787    }
3788
3789    fn max_abs_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
3790        a.iter()
3791            .zip(b.iter())
3792            .map(|(p, q)| (p - q).abs())
3793            .fold(0.0_f64, f64::max)
3794    }
3795
3796    // gam#1587: now that the reference-symmetric centered `M⊗S_t` joint penalty
3797    // is wired through the custom-family outer REML loop (per-eval
3798    // `JointPenaltyBundle` + outer penalty_coords/logdet/operator), the
3799    // production multinomial fit is invariant to the arbitrary reference class,
3800    // so this guard runs by default (the opt-in skip attribute it carried while
3801    // the fix was pending is also forbidden by the build.rs ban-scanner). It is
3802    // an end-to-end fit guard (a handful of full softmax `y ~ s(x)` fits) —
3803    // slower than a unit test but a true production-path regression.
3804    #[test]
3805    fn multinomial_fit_is_invariant_to_reference_class_1587() {
3806        let td = tempdir().expect("tempdir");
3807        let dir = td.path();
3808        // The reference-class drift is STRUCTURAL (it does not shrink with n, see
3809        // the issue table), so a modest n exposes it just as cleanly as n=900
3810        // while keeping this an affordable CI guard.
3811        let (x, cls) = sample_classes(0, 300);
3812        let grid: Vec<f64> = (0..7).map(|i| -1.5 + 3.0 * (i as f64) / 6.0).collect();
3813
3814        // Three labelings that each make a DIFFERENT original class the baseline
3815        // (the class whose label sorts LAST is the reference K−1):
3816        //   ["A","B","C"] → ref = class 2
3817        //   ["B","C","A"] → ref = class 1
3818        //   ["C","A","B"] → ref = class 0
3819        let a = fit_predict_aligned(dir, "abc", &x, &cls, ["A", "B", "C"], &grid);
3820        let b = fit_predict_aligned(dir, "bca", &x, &cls, ["B", "C", "A"], &grid);
3821        let c = fit_predict_aligned(dir, "cab", &x, &cls, ["C", "A", "B"], &grid);
3822
3823        // Refitting the SAME labeling twice must agree to ~machine precision —
3824        // this isolates optimizer noise from the structural reference drift.
3825        let a2 = fit_predict_aligned(dir, "abc2", &x, &cls, ["A", "B", "C"], &grid);
3826        let refit_noise = max_abs_diff(&a, &a2);
3827        assert!(
3828            refit_noise < 1e-6,
3829            "refitting the same labeling must be deterministic (got {refit_noise:.3e})"
3830        );
3831
3832        let drift = max_abs_diff(&a, &b)
3833            .max(max_abs_diff(&a, &c))
3834            .max(max_abs_diff(&b, &c));
3835        assert!(
3836            drift < 1e-3,
3837            "predicted probabilities must be invariant to the reference class; \
3838             cross-labeling drift = {drift:.3e} (refit noise = {refit_noise:.3e})"
3839        );
3840    }
3841}