Skip to main content

gam_solve/reml/
mod.rs

1use self::inner_strategy::GeometryBackendKind;
2use super::*;
3use gam_linalg::sparse_exact::SparseExactFactor;
4use crate::pirls::PIRLS_CACHE_BYTE_BUDGET;
5use crate::pirls::assemble_and_factor_sparse_penalized_system;
6use gam_terms::basis::LocalDesignJacobianProvider;
7use gam_problem::SasLinkState;
8use gam_problem::OuterEval;
9use ndarray::{Array1, Array2, s};
10use std::collections::{HashMap, VecDeque};
11use std::ops::Range;
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
13use std::sync::{Arc, RwLock};
14
15pub mod assembly;
16pub mod atoms;
17pub(crate) mod continuation;
18pub(crate) mod eval;
19mod firth;
20pub(super) mod hyper;
21mod inner_strategy;
22// #1521 carve: promoted `pub(crate)` -> `pub` so the extracted
23// `gam-custom-family` crate reaches the Jeffreys-subspace items it consumes.
24pub mod jeffreys_subspace;
25pub mod outer_eval;
26pub mod penalty_logdet;
27pub mod per_atom_efs;
28pub mod reml_outer_engine;
29mod rho_key;
30mod sparse_exact_penalty;
31mod trace;
32
33pub(crate) use sparse_exact_penalty::sparse_penalty_block_count_from_canonical;
34
/// Nominal byte budget (512 MiB) for dense τ-τ Hessian caches; reported as
/// `TauTauHessianPolicy::budget_bytes` by `exact_tau_tau_hessian_policy_with_firth`.
pub(crate) const EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES: usize = 512 << 20;
/// Observation-count cap applied by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_OBSERVATIONS: usize = 20000;
/// Coefficient-count cap applied by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_COEFFICIENTS: usize = 256;
/// Cap on `n_obs · p_coeff` applied by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_LINEAR_WORK: usize = 2_000 * 1_000;
/// Cap on `n_obs · p_coeff²` applied by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_QUADRATIC_WORK: usize = 100 * 1_000_000;
/// Default entry count of `PersistentLatentValuesCache` before
/// least-recently-used eviction kicks in.
pub(crate) const PERSISTENT_LATENT_VALUES_CACHE_CAPACITY: usize = 8;
41
42#[derive(Debug)]
43pub(crate) struct PersistentLatentValuesCache {
44    pub(crate) entries: HashMap<String, Array2<f64>>,
45    pub(crate) lru: VecDeque<String>,
46    pub(crate) capacity: usize,
47}
48
49impl Default for PersistentLatentValuesCache {
50    fn default() -> Self {
51        Self {
52            entries: HashMap::new(),
53            lru: VecDeque::new(),
54            capacity: PERSISTENT_LATENT_VALUES_CACHE_CAPACITY,
55        }
56    }
57}
58
59impl PersistentLatentValuesCache {
60    pub(crate) fn lookup(
61        &mut self,
62        key: &str,
63        n_obs: usize,
64        latent_dim: usize,
65    ) -> Option<Array2<f64>> {
66        let values = self.entries.get(key)?;
67        if values.dim() != (n_obs, latent_dim) {
68            return None;
69        }
70        let values = values.clone();
71        self.touch(key.to_string());
72        Some(values)
73    }
74
75    pub(crate) fn insert(&mut self, key: String, values: Array2<f64>) {
76        if values.iter().any(|value| !value.is_finite()) {
77            return;
78        }
79        self.entries.insert(key.clone(), values);
80        self.touch(key);
81        while self.entries.len() > self.capacity {
82            let Some(evicted) = self.lru.pop_front() else {
83                break;
84            };
85            self.entries.remove(&evicted);
86        }
87    }
88
89    pub(crate) fn touch(&mut self, key: String) {
90        if let Some(index) = self.lru.iter().position(|queued| queued == &key) {
91            self.lru.remove(index);
92        }
93        self.lru.push_back(key);
94    }
95}
96
97/// Cached state from the most recent successful PIRLS solve, populated by
98/// `updatewarm_start_from` and consumed by the IFT-based warm-start
99/// predictor (`RemlState::predict_warm_start_beta_ift_with_outcome`).
100/// See the field doc on `RemlState::ift_warm_start_cache` for the math.
101#[derive(Clone)]
102pub(crate) struct IftWarmStartCache {
103    /// β at the converged solve, in ORIGINAL basis. Mirror of
104    /// `warm_start_beta` stashed alongside the H factor for atomic
105    /// consistency under concurrent reads (the predictor needs both
106    /// β and H from the SAME solve; reading them from two locks risks
107    /// a torn pair if a fresh solve lands between reads).
108    pub beta_original: ndarray::Array1<f64>,
109    /// ρ at which the solve occurred. Mirror of `warm_start_rho`,
110    /// stashed for the same atomic-consistency reason.
111    pub rho: ndarray::Array1<f64>,
112    /// Penalized Hessian H_pen at the converged β, in TRANSFORMED basis.
113    /// The IFT predictor factors this on demand; basis transforms run in
114    /// transformed basis for numerical stability.
115    pub penalized_hessian_transformed: gam_linalg::matrix::SymmetricMatrix,
116    /// Reparameterization matrix qs converting between transformed
117    /// (column) basis and original basis: `β_orig = qs · β_tfd`,
118    /// `H_orig = qs · H_tfd · qs^T`.
119    pub qs: ndarray::Array2<f64>,
120    /// True when the PIRLS result was already in original basis
121    /// (`OriginalSparseNative`) — in which case `qs` is the identity
122    /// and the IFT predictor can skip the basis-conversion ops.
123    pub frame_was_original: bool,
124    /// Per-penalty precomputation `S_k · β_cur[cp.col_range]`,
125    /// indexed in lockstep with `RemlObjectiveState::canonical_penalties`.
126    /// Each entry is the local-block mat-vec the IFT predictor would
127    /// otherwise recompute on every predict call. With H_pen factor
128    /// caching (commit ec18559d) the per-call cost dropped from
129    /// `O(p³)` Cholesky to `O(p²) ≈ k · O(block²)` rhs construction;
130    /// at large-scale CTN (p ≈ several thousand) that's several ms
131    /// per predict call still being paid. By stashing `S_k · β_cur`
132    /// at cache-write time the predictor's per-call work drops to
133    /// just the `Δρ_k · e^{ρ_k} · sb_block` accumulation, which is
134    /// `O(p)` rather than `O(p²)`.
135    ///
136    /// `None` when the cache predates this commit's writer hook (e.g.,
137    /// transient state during invalidation); the predictor falls back
138    /// to recomputing the mat-vec when this is `None` or the length
139    /// mismatches `canonical_penalties.len()`.
140    pub lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>>,
141}
142
/// Byte-count estimate for one τ-τ evaluation plan (the policy builds one for
/// the gradient path and one for the exact-Hessian path), split by storage
/// category.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) struct TauTauPlanEstimate {
    pub(crate) dense_x_bytes: usize,
    pub(crate) first_order_tau_bytes: usize,
    pub(crate) second_order_tau_bytes: usize,
    pub(crate) penalty_first_bytes: usize,
    pub(crate) penalty_pair_bytes: usize,
    pub(crate) rho_tau_penalty_bytes: usize,
    pub(crate) vector_cache_bytes: usize,
    pub(crate) weighted_scratch_bytes: usize,
}

impl TauTauPlanEstimate {
    /// Sum of every category, clamped at `usize::MAX` rather than overflowing.
    pub(crate) fn total_bytes(self) -> usize {
        let parts = [
            self.dense_x_bytes,
            self.first_order_tau_bytes,
            self.second_order_tau_bytes,
            self.penalty_first_bytes,
            self.penalty_pair_bytes,
            self.rho_tau_penalty_bytes,
            self.vector_cache_bytes,
            self.weighted_scratch_bytes,
        ];
        // All parts are non-negative, so saturating accumulation yields
        // min(true_sum, usize::MAX) regardless of order.
        parts
            .iter()
            .fold(0usize, |total, &part| total.saturating_add(part))
    }
}
167
168#[derive(Clone, Copy, Debug, PartialEq, Eq)]
169pub(crate) struct TauTauHessianPolicy {
170    pub(crate) any_has_implicit: bool,
171    pub(crate) implicit_multidim_duchon: bool,
172    pub(crate) estimated_dense_tau_cache_bytes: usize,
173    pub(crate) gradient_plan: TauTauPlanEstimate,
174    pub(crate) hessian_plan: TauTauPlanEstimate,
175    pub(crate) budget_bytes: usize,
176    pub(crate) firth_pair_terms_unavailable: bool,
177}
178
179impl TauTauHessianPolicy {
180    /// True when the τ-τ exact-Hessian path cannot be assembled at all and the
181    /// eval must fall back to value-and-gradient mode (forcing
182    /// `HessianResult::Unavailable`).
183    ///
184    /// This is the *only* remaining capability gate: the previous
185    /// implementation also forced gradient-only when the design used implicit
186    /// multi-dim Duchon storage or when the dense τ-cache plan would exceed
187    /// the budget.  Both of those are now *cost* gates, not capability gates
188    /// — the unified evaluator's `prefer_outer_hessian_operator(n, p, k)`
189    /// selects the matrix-free `HessianResult::Operator` representation in
190    /// exactly the regimes where the dense cache would be unaffordable, and
191    /// the planner routes operator returns through `run_operator_trust_region`
192    /// (or basis-probes them when `dim ≤ OUTER_HVP_MATERIALIZE_MAX_DIM`).
193    /// Forcing gradient-only would have prevented the operator representation
194    /// from ever being requested, defeating that routing; hence the
195    /// `implicit_multidim_duchon` and cost-bytes clauses are deliberately
196    /// gone.
197    ///
198    /// Firth-pair-terms-unavailability remains a capability gate: when the
199    /// Firth-aware derivative provider cannot produce the τ-τ pair
200    /// corrections at all, no representation choice can substitute.  At every
201    /// production call site this flag is hardcoded `false` (the
202    /// `hphi_tau_tau_partial_apply` + `d_beta_hphi_tau_partial_apply`
203    /// primitives now cover the gap), so this method effectively returns
204    /// `false` in production.  We retain the field and method signature
205    /// unchanged so future Firth corner cases have a single, surfaced place
206    /// to land.
207    pub(crate) fn prefer_gradient_only(self) -> bool {
208        self.firth_pair_terms_unavailable
209    }
210}
211
212pub(crate) fn exact_tau_tau_hessian_policy_with_firth(
213    n_obs: usize,
214    p_coeff: usize,
215    hyper_dirs: &[DirectionalHyperParam],
216    firth_pair_terms_unavailable: bool,
217) -> TauTauHessianPolicy {
218    let f64_bytes = std::mem::size_of::<f64>();
219    let dense_matrix_bytes =
220        |rows: usize, cols: usize| -> usize { rows.saturating_mul(cols).saturating_mul(f64_bytes) };
221    let dense_design_bytes = dense_matrix_bytes(n_obs, p_coeff);
222    let dense_penalty_bytes = dense_matrix_bytes(p_coeff, p_coeff);
223    let psi_dim = hyper_dirs.len();
224    let implicit_n_axes = hyper_dirs
225        .iter()
226        .find_map(DirectionalHyperParam::implicit_axis_count_hint)
227        .unwrap_or(0);
228    let gradient_uses_implicit_design = hyper_dirs
229        .iter()
230        .any(DirectionalHyperParam::has_implicit_operator)
231        && gam_terms::basis::should_use_implicit_operators_with_policy(
232            n_obs,
233            p_coeff,
234            implicit_n_axes,
235            &gam_runtime::resource::ResourcePolicy::default_library(),
236        );
237    let dense_first_order_count = hyper_dirs
238        .iter()
239        .filter(|dir| !dir.has_implicit_operator())
240        .count();
241    let first_penalty_component_count = hyper_dirs
242        .iter()
243        .map(DirectionalHyperParam::penalty_first_component_count)
244        .sum::<usize>();
245
246    let mut dense_second_order_count = 0usize;
247    let mut penalty_pair_count = 0usize;
248    for i in 0..psi_dim {
249        for j in i..psi_dim {
250            if hyper_dirs[i]
251                .x_tau_tau_entry_at(j)
252                .or_else(|| hyper_dirs[j].x_tau_tau_entry_at(i))
253                .is_some_and(|entry| !entry.uses_implicit_storage())
254            {
255                dense_second_order_count += if i == j { 1 } else { 2 };
256            }
257            if hyper_dirs[i].has_penaltysecond_pair_at(j)
258                || hyper_dirs[j].has_penaltysecond_pair_at(i)
259            {
260                penalty_pair_count += if i == j { 1 } else { 2 };
261            }
262        }
263    }
264
265    let gradient_dense_first_order_count = if gradient_uses_implicit_design {
266        dense_first_order_count
267    } else {
268        psi_dim
269    };
270    let gradient_needs_dense_x =
271        firth_pair_terms_unavailable || gradient_dense_first_order_count > 0;
272    let gradient_plan = TauTauPlanEstimate {
273        dense_x_bytes: if gradient_needs_dense_x {
274            dense_design_bytes
275        } else {
276            0
277        },
278        first_order_tau_bytes: if gradient_dense_first_order_count > 0 {
279            dense_design_bytes
280        } else {
281            0
282        },
283        second_order_tau_bytes: 0,
284        penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
285        penalty_pair_bytes: 0,
286        rho_tau_penalty_bytes: 0,
287        vector_cache_bytes: n_obs.saturating_mul(f64_bytes),
288        weighted_scratch_bytes: dense_penalty_bytes,
289    };
290    let hessian_plan = TauTauPlanEstimate {
291        dense_x_bytes: if psi_dim > 0 { dense_design_bytes } else { 0 },
292        first_order_tau_bytes: dense_first_order_count.saturating_mul(dense_design_bytes),
293        second_order_tau_bytes: dense_second_order_count.saturating_mul(dense_design_bytes),
294        penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
295        penalty_pair_bytes: penalty_pair_count.saturating_mul(dense_penalty_bytes),
296        rho_tau_penalty_bytes: first_penalty_component_count
297            .saturating_mul(2)
298            .saturating_mul(dense_penalty_bytes),
299        vector_cache_bytes: psi_dim.saturating_mul(n_obs).saturating_mul(f64_bytes),
300        weighted_scratch_bytes: dense_penalty_bytes,
301    };
302    let any_has_implicit = hyper_dirs
303        .iter()
304        .any(DirectionalHyperParam::has_implicit_operator);
305    let implicit_multidim_duchon = hyper_dirs
306        .iter()
307        .any(DirectionalHyperParam::has_implicit_multidim_duchon);
308    let estimated_dense_tau_cache_bytes = hessian_plan
309        .first_order_tau_bytes
310        .saturating_add(hessian_plan.second_order_tau_bytes);
311    TauTauHessianPolicy {
312        any_has_implicit,
313        implicit_multidim_duchon,
314        estimated_dense_tau_cache_bytes,
315        gradient_plan,
316        hessian_plan,
317        budget_bytes: EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES,
318        firth_pair_terms_unavailable: firth_pair_terms_unavailable && !hyper_dirs.is_empty(),
319    }
320}
321
322pub(crate) fn firth_problem_scale_allows(n_obs: usize, p_coeff: usize) -> bool {
323    let linear_work = n_obs.saturating_mul(p_coeff);
324    let quadratic_work = linear_work.saturating_mul(p_coeff);
325    n_obs <= FIRTH_MAX_OBSERVATIONS
326        && p_coeff <= FIRTH_MAX_COEFFICIENTS
327        && linear_work <= FIRTH_MAX_LINEAR_WORK
328        && quadratic_work <= FIRTH_MAX_QUADRATIC_WORK
329}
330
331#[cfg(test)]
332mod tests {
333    use super::{
334        DirectionalHyperParam, EvalCacheManager, EvalShared, HyperDesignDerivative,
335        HyperPenaltyDerivative, ImplicitDerivLevel, RemlConfig, RemlState,
336    };
337    use crate::estimate::EstimationError;
338    use gam_linalg::faer_ndarray::FaerCholesky;
339    use gam_linalg::matrix::symmetrize_in_place;
340    use crate::pirls::PirlsCoordinateFrame;
341    use gam_terms::basis::{ImplicitDesignPsiDerivative, RadialScalarKind};
342    use gam_problem::{
343        GlmLikelihoodSpec, InverseLink, LikelihoodSpec, ResponseFamily, StandardLink,
344    };
345    use faer::Side;
346    use gam_problem::{HessianResult, OuterEval};
347    use ndarray::{Array1, Array2, array, s};
348    use std::sync::Arc;
349
350    /// Shorthand for the canonical Binomial-Logit `GlmLikelihoodSpec` used by
351    /// the REML test fixtures.
352    pub(crate) fn binomial_logit_glm_spec() -> GlmLikelihoodSpec {
353        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
354            ResponseFamily::Binomial,
355            InverseLink::Standard(StandardLink::Logit),
356        ))
357    }
358
359    /// Shorthand for the canonical Gaussian-Identity `GlmLikelihoodSpec` used
360    /// by the REML test fixtures.
361    pub(crate) fn gaussian_identity_glm_spec() -> GlmLikelihoodSpec {
362        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
363            ResponseFamily::Gaussian,
364            InverseLink::Standard(StandardLink::Identity),
365        ))
366    }
367
368    impl DirectionalHyperParam {
369        pub(super) fn new(
370            x_tau_original: Array2<f64>,
371            penalty_first_components: Vec<(usize, Array2<f64>)>,
372            x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
373            penaltysecond_components: Option<Vec<Option<Vec<(usize, Array2<f64>)>>>>,
374        ) -> Result<Self, EstimationError> {
375            let x_tau_tau_original = x_tau_tau_original.map(|rows| {
376                rows.into_iter()
377                    .map(|entry| entry.map(HyperDesignDerivative::from))
378                    .collect::<Vec<_>>()
379            });
380            let penalty_first_components = penalty_first_components
381                .into_iter()
382                .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
383                .collect();
384            let penaltysecond_components = penaltysecond_components.map(|rows| {
385                rows.into_iter()
386                    .map(|row| {
387                        row.map(|components| {
388                            components
389                                .into_iter()
390                                .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
391                                .collect::<Vec<_>>()
392                        })
393                    })
394                    .collect::<Vec<_>>()
395            });
396            Self::new_compact(
397                HyperDesignDerivative::from(x_tau_original),
398                penalty_first_components,
399                x_tau_tau_original,
400                penaltysecond_components,
401            )
402        }
403
404        pub(super) fn single_penalty(
405            penalty_index: usize,
406            x_tau_original: Array2<f64>,
407            s_tau_original: Array2<f64>,
408            x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
409            s_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
410        ) -> Result<Self, EstimationError> {
411            let penaltysecond_components = s_tau_tau_original.map(|rows| {
412                rows.into_iter()
413                    .map(|mat| mat.map(|mat| vec![(penalty_index, mat)]))
414                    .collect::<Vec<_>>()
415            });
416            Self::new(
417                x_tau_original,
418                vec![(penalty_index, s_tau_original)],
419                x_tau_tau_original,
420                penaltysecond_components,
421            )
422        }
423    }
424
425    #[test]
426    pub(crate) fn firth_problem_scale_gate_blocks_large_quadratic_work() {
427        assert!(super::firth_problem_scale_allows(2_000, 200));
428        assert!(!super::firth_problem_scale_allows(4_800, 241));
429        assert!(!super::firth_problem_scale_allows(4_800, 433));
430    }
431
432    #[test]
433    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_implicit_tau() {
434        let operator = ImplicitDesignPsiDerivative::new(
435            array![1.0, 2.0, 3.0, 4.0],
436            array![0.5, -1.0, 1.5, 2.0],
437            array![0.1, 0.2, 0.3, 0.4],
438            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
439            None,
440            None,
441            2,
442            2,
443            1,
444            2,
445        );
446        let dir = DirectionalHyperParam::new_compact(
447            HyperDesignDerivative::from_implicit(
448                Arc::new(operator),
449                ImplicitDerivLevel::First(0),
450                1..4,
451                5,
452            ),
453            Vec::new(),
454            None,
455            None,
456        )
457        .expect("implicit directional hyperparam");
458        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
459        assert!(policy.any_has_implicit);
460        assert_eq!(
461            policy.gradient_plan.dense_x_bytes,
462            10 * 5 * std::mem::size_of::<f64>()
463        );
464        assert!(!policy.prefer_gradient_only());
465    }
466
467    #[test]
468    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_for_implicit_multidim_duchon()
469    {
470        // Multi-dim Duchon implicit storage used to force gradient-only,
471        // because the τ-cache materialization plan was infeasible.  The
472        // unified evaluator now elects the matrix-free
473        // `HessianResult::Operator` representation in this regime via
474        // `prefer_outer_hessian_operator`, so the planner can route to the
475        // operator trust-region (or basis-probe to dense for small K) — the
476        // capability is preserved and gradient-only must NOT engage.
477        let operator = ImplicitDesignPsiDerivative::new_streaming(
478            Arc::new(array![[0.0, 0.0], [1.0, 0.2]]),
479            Arc::new(array![[0.0, 0.0], [1.0, 1.0]]),
480            vec![0.0, 0.0],
481            RadialScalarKind::PureDuchon {
482                block_order: 1,
483                p_order: 0,
484                s_order: 0,
485                dim: 2,
486            },
487            None,
488            None,
489            0,
490        );
491        let dir = DirectionalHyperParam::new_compact(
492            HyperDesignDerivative::from_implicit(
493                Arc::new(operator),
494                ImplicitDerivLevel::First(0),
495                0..2,
496                2,
497            ),
498            Vec::new(),
499            None,
500            None,
501        )
502        .expect("implicit duchon directional hyperparam");
503        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
504        assert!(policy.any_has_implicit);
505        assert!(policy.implicit_multidim_duchon);
506        assert!(!policy.prefer_gradient_only());
507    }
508
509    #[test]
510    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_when_cache_budget_is_exceeded()
511     {
512        // The dense τ-cache plan exceeds the budget, but cost is no longer a
513        // capability gate: the eval-side selects the matrix-free operator
514        // representation in exactly this regime, and the planner routes
515        // accordingly.  `prefer_gradient_only` must NOT force `Unavailable`
516        // here.
517        let dirs = (0..16)
518            .map(|_| {
519                DirectionalHyperParam::new_compact(
520                    HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
521                    Vec::new(),
522                    None,
523                    None,
524                )
525                .expect("dense directional hyperparam")
526            })
527            .collect::<Vec<_>>();
528        let policy = super::exact_tau_tau_hessian_policy_with_firth(320_000, 71, &dirs, false);
529        assert!(!policy.any_has_implicit);
530        assert!(policy.hessian_plan.total_bytes() > policy.budget_bytes);
531        assert!(policy.hessian_plan.total_bytes() > policy.gradient_plan.total_bytes());
532        assert!(!policy.prefer_gradient_only());
533    }
534
535    #[test]
536    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_firth_pair_gap() {
537        let dir = DirectionalHyperParam::new_compact(
538            HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
539            Vec::new(),
540            None,
541            None,
542        )
543        .expect("dense directional hyperparam");
544        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], true);
545        assert!(policy.firth_pair_terms_unavailable);
546        assert!(policy.prefer_gradient_only());
547    }
548
549    /// Common shape for the design-motion + penalty-motion REML test fixtures
550    /// (Gaussian-identity and binomial-logit at present): both carry the
551    /// same `(y, w, X, S0, cfg, ρ)` plus a perturbation pair, and need the
552    /// same three helpers (`state`, `state_perturbed`, `fd_directional_gradient`).
553    /// Per-fixture `new()` constructors fill the fields with family-specific
554    /// data; the helpers below are shared via default impls so every fixture
555    /// pays the boilerplate once.
556    trait LogitDesignMotionFixture {
557        fn y(&self) -> &Array1<f64>;
558        fn w(&self) -> &Array1<f64>;
559        fn x(&self) -> &Array2<f64>;
560        fn s0(&self) -> &Array2<f64>;
561        fn cfg(&self) -> &RemlConfig;
562        fn rho(&self) -> &Array1<f64>;
563
564        fn state(&self) -> RemlState<'_> {
565            build_logit_state(self.y(), self.w(), self.x(), self.s0(), self.cfg())
566        }
567
568        fn state_perturbed(
569            &self,
570            x_tau: &Array2<f64>,
571            s_tau: &Array2<f64>,
572            eps: f64,
573        ) -> (RemlState<'_>, RemlState<'_>) {
574            let x_plus = self.x() + &x_tau.mapv(|v| eps * v);
575            let x_minus = self.x() - &x_tau.mapv(|v| eps * v);
576            let s_plus = self.s0() + &s_tau.mapv(|v| eps * v);
577            let s_minus = self.s0() - &s_tau.mapv(|v| eps * v);
578            (
579                build_logit_state(self.y(), self.w(), &x_plus, &s_plus, self.cfg()),
580                build_logit_state(self.y(), self.w(), &x_minus, &s_minus, self.cfg()),
581            )
582        }
583
584        /// Central FD approximation to the directional cost derivative at ρ.
585        fn fd_directional_gradient(&self, x_tau: &Array2<f64>, s_tau: &Array2<f64>) -> f64 {
586            let h = 2e-5;
587            let (state_plus, state_minus) = self.state_perturbed(x_tau, s_tau, h);
588            let v_plus = state_plus.compute_cost(self.rho()).expect("cost+");
589            let v_minus = state_minus.compute_cost(self.rho()).expect("cost-");
590            (v_plus - v_minus) / (2.0 * h)
591        }
592    }
593
594    pub(crate) fn build_logit_state<'a>(
595        y: &'a Array1<f64>,
596        w: &'a Array1<f64>,
597        x: &Array2<f64>,
598        s: &Array2<f64>,
599        cfg: &'a RemlConfig,
600    ) -> RemlState<'a> {
601        use crate::estimate::PenaltySpec;
602        let p = x.ncols();
603        let offset = Array1::<f64>::zeros(y.len());
604        let spec = PenaltySpec::Dense(s.clone());
605        let canonical = gam_terms::construction::canonicalize_penalty_specs(&[spec], &[1], p, "test")
606            .map(|(canonical, _)| canonical)
607            .expect("canonicalize");
608        RemlState::newwith_offset(
609            y.view(),
610            x.clone(),
611            w.view(),
612            offset.view(),
613            canonical,
614            p,
615            cfg,
616            Some(vec![1]),
617            None,
618            None,
619        )
620        .expect("state")
621    }
622
623    #[test]
624    fn repeated_penalty_ranges_keep_analytic_outer_hessian() {
625        let y = array![0.2, -0.1, 0.3, 0.0];
626        let w = Array1::<f64>::ones(y.len());
627        let x = array![[1.0, -0.7], [1.0, -0.2], [1.0, 0.3], [1.0, 0.9]];
628        let offset = Array1::<f64>::zeros(y.len());
629        let cfg = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
630        let p = x.ncols();
631        let canonical = vec![
632            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0]], p),
633            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0]], p),
634        ];
635        let state = RemlState::newwith_offset(
636            y.view(),
637            x,
638            w.view(),
639            offset.view(),
640            canonical,
641            p,
642            &cfg,
643            Some(vec![1, 1]),
644            None,
645            None,
646        )
647        .expect("state");
648
649        assert!(
650            state.analytic_outer_hessian_enabled(),
651            "double-penalty-style repeated coefficient ranges must still route to exact Hessian"
652        );
653    }
654
655    #[test]
656    fn canonical_logit_firth_declines_exact_tk_hessian_when_row_pair_work_is_large() {
657        let n = 2_000usize;
658        let p = 28usize;
659        let y = Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1.0 } else { 0.0 }));
660        let w = Array1::<f64>::ones(n);
661        let mut x = Array2::<f64>::zeros((n, p));
662        for i in 0..n {
663            let t = (i as f64 + 0.5) / n as f64;
664            x[[i, 0]] = 1.0;
665            for j in 1..p {
666                x[[i, j]] = ((j as f64) * std::f64::consts::TAU * t).sin()
667                    + 0.25 * (((j + 1) as f64) * std::f64::consts::TAU * t).cos();
668            }
669        }
670        let mut s = Array2::<f64>::zeros((p, p));
671        for j in 1..p {
672            s[[j, j]] = 1.0;
673        }
674        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
675        let state = build_logit_state(&y, &w, &x, &s, &cfg);
676
677        assert!(
678            !RemlState::firth_tk_exact_hessian_scale_allows(n, p),
679            "fixture must sit beyond the O(n²·p) exact-Hessian budget"
680        );
681        assert!(
682            !state.analytic_outer_hessian_enabled(),
683            "large canonical-logit Firth fits should keep exact value/gradient but route outer curvature to BFGS"
684        );
685    }
686
687    #[test]
688    fn canonical_logit_firth_keeps_exact_tk_hessian_for_small_separation_guards() {
689        let n = 40usize;
690        let p = 6usize;
691        let y = Array1::from_iter((0..n).map(|i| if i >= n / 2 { 1.0 } else { 0.0 }));
692        let w = Array1::<f64>::ones(n);
693        let mut x = Array2::<f64>::zeros((n, p));
694        for i in 0..n {
695            let t = (i as f64) / (n - 1) as f64;
696            x[[i, 0]] = 1.0;
697            for j in 1..p {
698                x[[i, j]] = t.powi(j as i32);
699            }
700        }
701        let mut s = Array2::<f64>::zeros((p, p));
702        for j in 1..p {
703            s[[j, j]] = 1.0;
704        }
705        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
706        let state = build_logit_state(&y, &w, &x, &s, &cfg);
707
708        assert!(RemlState::firth_tk_exact_hessian_scale_allows(n, p));
709        assert!(
710            state.analytic_outer_hessian_enabled(),
711            "small Firth rescue fits should keep exact TK Hessian curvature"
712        );
713    }
714
715    pub(crate) fn poisson_log_glm_spec() -> GlmLikelihoodSpec {
716        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
717            ResponseFamily::Poisson,
718            InverseLink::Standard(StandardLink::Log),
719        ))
720    }
721
722    /// Regression (issue #893): for a fixed-dispersion family a uniform prior
723    /// weight `w = c` is *exact* `c`-fold row replication. The two encodings must
724    /// therefore present a byte-identical LAML smoothing-selection surface — both
725    /// the cost `V(ρ)` and its gradient `∇V(ρ)` — because every term (penalised
726    /// deviance `D_p`, the working cross-product `XᵀWX`, the log-determinants)
727    /// is a sum of per-observation contributions that is identical whether a row
728    /// carries weight `c` or is stacked `c` times. This locks the *surface*
729    /// invariant that #893 ultimately reduces to: when the surfaces coincide,
730    /// the only remaining requirement for `λ̂(w=c) = λ̂(c×)` is that the outer
731    /// optimiser resolve the shared optimum (handled by the tightened outer
732    /// tolerance in `workflow.rs`). A regression that reintroduced a
733    /// row-count-vs-weight-sum asymmetry into the inner solve or the cost would
734    /// break this directly, independent of the optimiser tolerance.
735    #[test]
736    pub(crate) fn fixed_dispersion_laml_surface_is_replication_invariant() {
737        let n = 200usize;
738        let p = 8usize;
739        let c = 3usize;
740        let mut x = Array2::<f64>::zeros((n, p));
741        let mut y = Array1::<f64>::zeros(n);
742        for i in 0..n {
743            let t = (i as f64) / ((n - 1) as f64);
744            let tau = std::f64::consts::TAU;
745            x[[i, 0]] = 1.0;
746            x[[i, 1]] = t;
747            x[[i, 2]] = (tau * t).sin();
748            x[[i, 3]] = (tau * t).cos();
749            x[[i, 4]] = (2.0 * tau * t).sin();
750            x[[i, 5]] = (2.0 * tau * t).cos();
751            x[[i, 6]] = (3.0 * tau * t).sin();
752            x[[i, 7]] = (3.0 * tau * t).cos();
753            let eta = 0.3 + 0.9 * (1.4 * (t - 0.5)).sin();
754            // Deterministic non-negative integer counts near exp(eta).
755            y[i] = (eta.exp() + 0.5 * ((i as f64) * 2.399_963).sin())
756                .round()
757                .max(0.0);
758        }
759        let mut s = Array2::<f64>::zeros((p, p));
760        for j in 1..p {
761            s[[j, j]] = 1.0;
762        }
763
764        // Replicated design (c literal copies of each row).
765        let mut x_rep = Array2::<f64>::zeros((n * c, p));
766        let mut y_rep = Array1::<f64>::zeros(n * c);
767        for r in 0..c {
768            for i in 0..n {
769                let row = r * n + i;
770                for j in 0..p {
771                    x_rep[[row, j]] = x[[i, j]];
772                }
773                y_rep[row] = y[i];
774            }
775        }
776
777        let w_weighted = Array1::<f64>::from_elem(n, c as f64);
778        let w_rep = Array1::<f64>::ones(n * c);
779
780        let cfg = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
781        let st_w = build_logit_state(&y, &w_weighted, &x, &s, &cfg);
782        let st_r = build_logit_state(&y_rep, &w_rep, &x_rep, &s, &cfg);
783
784        for &rho in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
785            let r = Array1::from_elem(1, rho);
786            let cw = st_w.compute_cost(&r).expect("weighted cost");
787            let cr = st_r.compute_cost(&r).expect("replicated cost");
788            let gw = st_w.compute_gradient(&r).expect("weighted grad");
789            let gr = st_r.compute_gradient(&r).expect("replicated grad");
790            // Costs and gradients must coincide to optimiser precision; the only
791            // admissible difference is f64 summation order over n vs c·n rows.
792            assert!(
793                (cw - cr).abs() <= 1e-9 * (1.0 + cw.abs()),
794                "LAML cost differs between w=c and c× replication at rho={rho}: \
795                 cost_w={cw:.12e} cost_r={cr:.12e} diff={:.3e}",
796                cw - cr
797            );
798            assert!(
799                (gw[0] - gr[0]).abs() <= 1e-9 * (1.0 + gw[0].abs()),
800                "LAML gradient differs between w=c and c× replication at rho={rho}: \
801                 g_w={:.12e} g_r={:.12e} diff={:.3e}",
802                gw[0],
803                gr[0],
804                gw[0] - gr[0]
805            );
806        }
807    }
808
809    /// Regression (issue #893): the geometric-mean log-weight ρ-anchor
810    /// ([`RemlState::rho_weight_anchor`]) is a *profiled*-dispersion construct
811    /// (issue #877). For a fixed-dispersion family the optimum does not slide by
812    /// `log c` under a weight rescale in a way the prior should track, and a
813    /// nonzero anchor would evaluate the regularising ρ-prior at *different*
814    /// coordinates for the `w=c` vs `c×` encodings — breaking the very
815    /// equivalence #893 requires. The anchor must therefore be exactly `0` for a
816    /// fixed-dispersion family and the geometric mean for Gaussian-identity.
817    #[test]
818    pub(crate) fn rho_weight_anchor_is_zero_for_fixed_dispersion() {
819        let n = 50usize;
820        let p = 3usize;
821        let mut x = Array2::<f64>::zeros((n, p));
822        let mut y = Array1::<f64>::zeros(n);
823        for i in 0..n {
824            let t = (i as f64) / ((n - 1) as f64);
825            x[[i, 0]] = 1.0;
826            x[[i, 1]] = t;
827            x[[i, 2]] = t * t;
828            y[i] = (1.0 + (3.0 * t).sin()).round().max(0.0);
829        }
830        let mut s = Array2::<f64>::zeros((p, p));
831        s[[2, 2]] = 1.0;
832        // All weights = c: geometric-mean log-weight = ln(c) ≠ 0.
833        let c = 4.0_f64;
834        let w = Array1::<f64>::from_elem(n, c);
835
836        let cfg_pois = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
837        let st_pois = build_logit_state(&y, &w, &x, &s, &cfg_pois);
838        assert_eq!(
839            st_pois.rho_weight_anchor(),
840            0.0,
841            "fixed-dispersion (Poisson) anchor must be 0, not the geometric-mean log-weight"
842        );
843
844        let cfg_gauss = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
845        let st_gauss = build_logit_state(&y, &w, &x, &s, &cfg_gauss);
846        assert!(
847            (st_gauss.rho_weight_anchor() - c.ln()).abs() <= 1e-12,
848            "Gaussian-identity (profiled) anchor must be the geometric-mean log-weight ln(c)={:.6}, got {:.6}",
849            c.ln(),
850            st_gauss.rho_weight_anchor()
851        );
852    }
853
854    pub(crate) fn beta_original_from_bundle(bundle: &EvalShared) -> Array1<f64> {
855        let pr = bundle.pirls_result.as_ref();
856        match pr.coordinate_frame {
857            PirlsCoordinateFrame::OriginalSparseNative => pr.beta_transformed.as_ref().clone(),
858            PirlsCoordinateFrame::TransformedQs => {
859                pr.reparam_result.qs.dot(pr.beta_transformed.as_ref())
860            }
861        }
862    }
863
864    pub(crate) fn compute_joint_hypercostgradienthessian(
865        state: &RemlState<'_>,
866        theta: &Array1<f64>,
867        rho_dim: usize,
868        hyper_dirs: &[DirectionalHyperParam],
869    ) -> Result<(f64, Array1<f64>, Array2<f64>), EstimationError> {
870        let (cost, gradient, hessian) = state.compute_joint_hyper_eval_with_order(
871            theta,
872            rho_dim,
873            hyper_dirs,
874            crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
875        )?;
876        Ok((
877            cost,
878            gradient,
879            hessian
880                .materialize_dense()
881                .map_err(EstimationError::RemlOptimizationFailed)?
882                .ok_or_else(|| {
883                    EstimationError::RemlOptimizationFailed(
884                        "joint hyper Hessian requested but unavailable".to_string(),
885                    )
886                })?,
887        ))
888    }
889
890    pub(crate) fn h_original_from_bundle(bundle: &EvalShared) -> Array2<f64> {
891        let pr = bundle.pirls_result.as_ref();
892        match pr.coordinate_frame {
893            PirlsCoordinateFrame::OriginalSparseNative => bundle.h_total.as_ref().clone(),
894            PirlsCoordinateFrame::TransformedQs => {
895                let qs = &pr.reparam_result.qs;
896                let tmp = gam_linalg::faer_ndarray::fast_ab(qs, bundle.h_total.as_ref());
897                gam_linalg::faer_ndarray::fast_abt(&tmp, qs)
898            }
899        }
900    }
901
902    pub(crate) fn single_directional_tau_gradient(
903        state: &RemlState<'_>,
904        rho: &Array1<f64>,
905        hyper: DirectionalHyperParam,
906    ) -> Result<f64, EstimationError> {
907        let mut theta = Array1::<f64>::zeros(rho.len() + 1);
908        theta.slice_mut(s![..rho.len()]).assign(rho);
909        let (_, gradient, _) = state.compute_joint_hyper_eval_with_order(
910            &theta,
911            rho.len(),
912            &[hyper],
913            crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
914        )?;
915        Ok(gradient[rho.len()])
916    }
917
918    pub(crate) fn fd_directional_tau_cost_gradient(
919        y: &Array1<f64>,
920        w: &Array1<f64>,
921        x: &Array2<f64>,
922        s0: &Array2<f64>,
923        cfg: &RemlConfig,
924        rho: &Array1<f64>,
925        x_tau: &Array2<f64>,
926        s_tau: &Array2<f64>,
927    ) -> f64 {
928        let h = 2e-5;
929        let x_plus = x + &x_tau.mapv(|v| h * v);
930        let x_minus = x - &x_tau.mapv(|v| h * v);
931        let s_plus = s0 + &s_tau.mapv(|v| h * v);
932        let s_minus = s0 - &s_tau.mapv(|v| h * v);
933        let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
934        let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
935        let v_plus = state_plus.compute_cost(rho).expect("cost+");
936        let v_minus = state_minus.compute_cost(rho).expect("cost-");
937        (v_plus - v_minus) / (2.0 * h)
938    }
939
940    pub(crate) fn directional_tau_hessian_fd_reference(
941        y: &Array1<f64>,
942        w: &Array1<f64>,
943        x: &Array2<f64>,
944        s0: &Array2<f64>,
945        cfg: &RemlConfig,
946        rho: &Array1<f64>,
947        hyper_dirs: &[DirectionalHyperParam],
948        x_tau_mats: &[Array2<f64>],
949        s_tau_mats: &[Array2<f64>],
950    ) -> Array2<f64> {
951        assert_eq!(hyper_dirs.len(), x_tau_mats.len());
952        assert_eq!(hyper_dirs.len(), s_tau_mats.len());
953
954        const TARGET_PHYSICAL_STEP: f64 = 1e-5;
955
956        let n_dirs = hyper_dirs.len();
957        let mut h_ttfd = Array2::<f64>::zeros((n_dirs, n_dirs));
958        for j in 0..n_dirs {
959            let direction_scale = x_tau_mats[j]
960                .iter()
961                .chain(s_tau_mats[j].iter())
962                .fold(0.0_f64, |acc, value| acc.max(value.abs()));
963            let h = if direction_scale > 0.0 {
964                TARGET_PHYSICAL_STEP / direction_scale
965            } else {
966                TARGET_PHYSICAL_STEP
967            };
968
969            let x_plus = x + &x_tau_mats[j].mapv(|v| h * v);
970            let x_minus = x - &x_tau_mats[j].mapv(|v| h * v);
971            let s_plus = s0 + &s_tau_mats[j].mapv(|v| h * v);
972            let s_minus = s0 - &s_tau_mats[j].mapv(|v| h * v);
973
974            let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
975            let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
976            for i in 0..n_dirs {
977                let g_plus =
978                    single_directional_tau_gradient(&state_plus, rho, hyper_dirs[i].clone())
979                        .expect("g+ for FD");
980                let g_minus =
981                    single_directional_tau_gradient(&state_minus, rho, hyper_dirs[i].clone())
982                        .expect("g- for FD");
983                h_ttfd[[i, j]] = (g_plus - g_minus) / (2.0 * h);
984            }
985        }
986        symmetrize_in_place(&mut h_ttfd);
987        h_ttfd
988    }
989
990    #[test]
991    pub(crate) fn eval_cache_manager_stores_first_order_outer_eval() {
992        let cache = EvalCacheManager::new();
993        let rho = array![0.25, -0.0];
994        let rho_key = EvalCacheManager::sanitized_rhokey(&rho);
995        let eval = OuterEval {
996            cost: 3.5,
997            gradient: array![1.0, -2.0],
998            hessian: HessianResult::Unavailable,
999            inner_beta_hint: None,
1000        };
1001
1002        cache.store_outer_eval(&rho_key, &eval);
1003
1004        let cached = cache
1005            .cached_outer_eval(&rho_key)
1006            .expect("first-order outer eval should be cached");
1007        assert_eq!(cached.cost, eval.cost);
1008        assert_eq!(cached.gradient, eval.gradient);
1009        assert!(matches!(cached.hessian, HessianResult::Unavailable));
1010
1011        cache.invalidate_eval_bundle();
1012        assert!(
1013            cache.cached_outer_eval(&rho_key).is_none(),
1014            "invalidating the bundle should clear the outer-eval cache too"
1015        );
1016    }
1017
    /// #1575 multi-slot outer-eval cache correctness oracle.
    ///
    /// A memoization is only safe if a hit returns *exactly* what the miss path
    /// stored. This pins three properties of the bounded LRU:
    ///   1. round-trip fidelity — a hit is `f64::to_bits`-identical in cost AND
    ///      every gradient component to the value stored on the miss path;
    ///   2. no aliasing — distinct rho-keys never return each other's eval;
    ///   3. honest eviction — once an evicted key is requested again it MISSES
    ///      (so the caller recomputes) rather than returning a stale neighbour.
    ///
    /// Only cost and gradient are compared bitwise. `inner_beta_hint` is checked
    /// (by value) for the first key only. Every stored Hessian here is
    /// `HessianResult::Unavailable`, so it is not compared.
    #[test]
    pub(crate) fn outer_eval_lru_hit_is_bit_identical_and_evicts_honestly_1575() {
        use super::OUTER_EVAL_LRU_CAPACITY;

        // Helper: a deterministic OuterEval whose bits encode `seed`, so any
        // cross-key contamination is detectable bit-for-bit.
        let make_eval = |seed: f64| OuterEval {
            cost: (seed * std::f64::consts::PI).sin() / 3.0 - seed,
            gradient: array![seed, -seed * 2.0, seed.recip()],
            hessian: HessianResult::Unavailable,
            inner_beta_hint: Some(array![seed + 0.5, seed - 0.5]),
        };
        // Bitwise equality over `cost` and every gradient entry. Gradient
        // lengths must match too. Hessian and `inner_beta_hint` are not compared.
        let bits_eq = |a: &OuterEval, b: &OuterEval| -> bool {
            a.cost.to_bits() == b.cost.to_bits()
                && a.gradient.len() == b.gradient.len()
                && a.gradient
                    .iter()
                    .zip(b.gradient.iter())
                    .all(|(x, y)| x.to_bits() == y.to_bits())
        };

        let cache = EvalCacheManager::new();

        // (1) Round-trip fidelity: store at rho_a, then a forced hit must equal
        // the stored eval bit-for-bit (the "hit == miss" guarantee).
        let rho_a = array![0.25, -1.5];
        let key_a = EvalCacheManager::sanitized_rhokey(&rho_a);
        let eval_a = make_eval(0.25);
        cache.store_outer_eval(&key_a, &eval_a);
        let hit_a = cache
            .cached_outer_eval(&key_a)
            .expect("stored rho_a must hit");
        assert!(
            bits_eq(&hit_a, &eval_a),
            "cache hit must be bit-identical (cost+gradient) to the stored miss-path eval"
        );
        assert_eq!(
            hit_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
            eval_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
            "inner_beta_hint must round-trip unchanged"
        );

        // (2) No aliasing: a second, distinct rho must return its OWN eval, and
        // the first key must still return the first eval untouched.
        // rho_b's second component is the adjacent f64 to -1.5 (one ULP toward
        // zero). Any key quantisation coarser than one ULP would make the two
        // keys collide, and the assert_ne below guards against that.
        let rho_b = array![0.25, -1.4999999999999998];
        let key_b = EvalCacheManager::sanitized_rhokey(&rho_b);
        assert_ne!(key_a, key_b, "the two rho-keys must differ");
        let eval_b = make_eval(7.0);
        cache.store_outer_eval(&key_b, &eval_b);
        assert!(
            bits_eq(
                &cache.cached_outer_eval(&key_b).expect("rho_b must hit"),
                &eval_b
            ),
            "rho_b must return its own eval, not rho_a's"
        );
        assert!(
            bits_eq(
                &cache.cached_outer_eval(&key_a).expect("rho_a must still hit"),
                &eval_a
            ),
            "rho_a must be unaffected by the rho_b insert"
        );

        // (3) Honest eviction: overflow the LRU with fresh keys. The
        // least-recently-used entry must be evicted and then MISS (forcing a
        // recompute), while a still-resident key returns its exact stored bits.
        // A fresh cache is used so entries from (1)/(2) do not occupy slots.
        // No lookups happen between the fill loop and the overflow insert, so
        // insertion order alone decides which key is least recently used.
        let cache = EvalCacheManager::new();
        let mut keys = Vec::new();
        let mut evals = Vec::new();
        for i in 0..OUTER_EVAL_LRU_CAPACITY {
            let rho = array![i as f64, -(i as f64)];
            let key = EvalCacheManager::sanitized_rhokey(&rho);
            let eval = make_eval(i as f64 + 0.123);
            cache.store_outer_eval(&key, &eval);
            keys.push(key);
            evals.push(eval);
        }
        // Cache is exactly full; key[0] is the least-recently-used.
        assert_eq!(
            cache.outer_eval_lru.read().unwrap().entries.len(),
            OUTER_EVAL_LRU_CAPACITY
        );
        // One more distinct key evicts the LRU (key[0]).
        let rho_overflow = array![999.0, -999.0];
        let key_overflow = EvalCacheManager::sanitized_rhokey(&rho_overflow);
        let eval_overflow = make_eval(42.0);
        cache.store_outer_eval(&key_overflow, &eval_overflow);
        assert_eq!(
            cache.outer_eval_lru.read().unwrap().entries.len(),
            OUTER_EVAL_LRU_CAPACITY,
            "capacity must stay bounded"
        );
        assert!(
            cache.cached_outer_eval(&keys[0]).is_none(),
            "the least-recently-used key must be evicted and now MISS (recompute), not return stale"
        );
        // NOTE(review): indexing keys[1] assumes OUTER_EVAL_LRU_CAPACITY >= 2.
        assert!(
            bits_eq(
                &cache
                    .cached_outer_eval(&keys[1])
                    .expect("a still-resident key must hit"),
                &evals[1]
            ),
            "a still-resident key must return its exact stored bits"
        );
        assert!(
            bits_eq(
                &cache
                    .cached_outer_eval(&key_overflow)
                    .expect("the freshest key must hit"),
                &eval_overflow
            ),
            "the freshest key must hit with its own eval"
        );
    }
1143
1144    #[test]
1145    pub(crate) fn reset_outer_seed_state_clears_pirls_cache() {
1146        // Build a minimal logit RemlState, populate the cross-call PIRLS LRU
1147        // by evaluating the outer objective at one rho, then verify that
1148        // reset_outer_seed_state wipes that LRU (alongside the eval bundle
1149        // and warm-start signals). This pins down the cross-attempt
1150        // cleanup contract that a budget-bump retry relies on.
1151        let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1152        let w = Array1::<f64>::ones(y.len());
1153        let x = array![
1154            [1.0, -1.0, 0.2],
1155            [1.0, -0.5, -0.4],
1156            [1.0, 0.0, 0.7],
1157            [1.0, 0.4, -0.3],
1158            [1.0, 0.9, 0.1],
1159            [1.0, 1.3, -0.6],
1160        ];
1161        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1162        let rho = array![0.0];
1163        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1164        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1165
1166        // Trigger a full outer eval so execute_pirls_if_needed inserts at
1167        // least one entry into the cross-call PIRLS LRU.
1168        state
1169            .compute_outer_eval_with_order(
1170                &rho,
1171                crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1172            )
1173            .expect("outer eval should succeed");
1174
1175        let populated_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1176        assert!(
1177            populated_len > 0,
1178            "evaluating the outer objective should populate the PIRLS LRU, got {populated_len}"
1179        );
1180
1181        state.reset_outer_seed_state();
1182
1183        let cleared_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1184        assert_eq!(
1185            cleared_len, 0,
1186            "reset_outer_seed_state must clear the cross-call PIRLS LRU; got {cleared_len} entries"
1187        );
1188    }
1189
1190    #[test]
1191    pub(crate) fn reset_outer_seed_state_preserves_frozen_negbin_theta_1448() {
1192        // #1448 regression: the NB outer θ↔λ alternation loop
1193        // (solver/estimate/optimizer.rs) re-runs the ρ search after each θ
1194        // refresh by (a) re-freezing the λ-search θ at θ_final into
1195        // `frozen_negbin_theta`, then (b) calling `reset_outer_seed_state()` to
1196        // drop the caches keyed to the old θ. Step (b) MUST NOT clear the freeze
1197        // set in step (a): the capture in `solve_for_unified_rho` only writes the
1198        // frozen slot when it is 0, so if the reset zeroed it the next round would
1199        // re-derive θ from the seed η and the loop would never reach the (ρ, θ)
1200        // joint fixed point — silently regressing #1448 back to a single
1201        // freeze→refresh pass.
1202        //
1203        // This pins the load-bearing distinction between `reset_outer_seed_state`
1204        // (alternation-round reset, freeze SURVIVES) and the surface-refresh reset
1205        // (new design, freeze re-zeroed). The end-to-end convergence on a real NB
1206        // fit is exercised by the public-API path; here we lock the invariant the
1207        // loop depends on, next to `reset_outer_seed_state_clears_pirls_cache`.
1208        use std::sync::atomic::Ordering;
1209
1210        let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1211        let w = Array1::<f64>::ones(y.len());
1212        let x = array![
1213            [1.0, -1.0, 0.2],
1214            [1.0, -0.5, -0.4],
1215            [1.0, 0.0, 0.7],
1216            [1.0, 0.4, -0.3],
1217            [1.0, 0.9, 0.1],
1218            [1.0, 1.3, -0.6],
1219        ];
1220        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1221        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1222        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1223
1224        // Simulate the alternation loop's re-freeze step: pin θ_final.
1225        let theta_final_bits = 2.5_f64.to_bits();
1226        state
1227            .frozen_negbin_theta
1228            .store(theta_final_bits, Ordering::Relaxed);
1229        assert_eq!(
1230            state.frozen_negbin_theta.load(Ordering::Relaxed),
1231            theta_final_bits,
1232            "precondition: the re-freeze stores θ_final into the frozen slot"
1233        );
1234
1235        // The alternation loop's per-round reset.
1236        state.reset_outer_seed_state();
1237
1238        assert_eq!(
1239            state.frozen_negbin_theta.load(Ordering::Relaxed),
1240            theta_final_bits,
1241            "reset_outer_seed_state (alternation-round reset) must PRESERVE the \
1242             re-frozen NB θ; clearing it would defeat the #1448 θ↔λ alternation \
1243             (the next ρ search would re-derive θ from the seed and never reach \
1244             the joint fixed point)"
1245        );
1246    }
1247
1248    #[test]
1249    pub(crate) fn implicit_hyper_design_derivative_respects_full_model_embedding() {
1250        let operator = ImplicitDesignPsiDerivative::new(
1251            array![1.0, 2.0, 3.0, 4.0],
1252            array![0.5, -1.0, 1.5, 2.0],
1253            array![0.1, 0.2, 0.3, 0.4],
1254            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
1255            None,
1256            None,
1257            2,
1258            2,
1259            1,
1260            2,
1261        );
1262        let local = operator
1263            .materialize_first(0)
1264            .expect("materialized first derivative");
1265        assert_eq!(
1266            local.ncols(),
1267            3,
1268            "operator-local derivative should stay smooth-local"
1269        );
1270
1271        let implicit = HyperDesignDerivative::from_implicit(
1272            Arc::new(operator),
1273            ImplicitDerivLevel::First(0),
1274            1..4,
1275            5,
1276        );
1277        let embedded = HyperDesignDerivative::from_embedded(local.clone(), 1..4, 5);
1278
1279        assert_eq!(implicit.nrows(), embedded.nrows());
1280        assert_eq!(implicit.ncols(), 5);
1281        assert_eq!(implicit.materialize(), embedded.materialize());
1282
1283        let u = array![7.0, 1.5, -2.0, 0.25, -3.0];
1284        let v = array![0.75, -1.25];
1285        assert_eq!(
1286            implicit.forward_mul_original(&u).expect("implicit forward"),
1287            embedded.forward_mul_original(&u).expect("embedded forward")
1288        );
1289        assert_eq!(
1290            implicit
1291                .transpose_mul_original(&v)
1292                .expect("implicit transpose"),
1293            embedded
1294                .transpose_mul_original(&v)
1295                .expect("embedded transpose")
1296        );
1297
1298        let qs = array![
1299            [1.0, 0.0, 0.0],
1300            [0.0, 1.0, 0.0],
1301            [0.0, 0.5, 0.5],
1302            [0.0, 0.0, 1.0],
1303            [0.0, 0.0, 0.0],
1304        ];
1305        assert_eq!(
1306            implicit
1307                .transformed(&qs, None)
1308                .expect("implicit transformed"),
1309            embedded
1310                .transformed(&qs, None)
1311                .expect("embedded transformed")
1312        );
1313        let u_transformed = array![1.0, -0.5, 2.0];
1314        assert_eq!(
1315            implicit
1316                .transformed_forward_mul(&qs, None, &u_transformed)
1317                .expect("implicit transformed forward"),
1318            embedded
1319                .transformed_forward_mul(&qs, None, &u_transformed)
1320                .expect("embedded transformed forward")
1321        );
1322        assert_eq!(
1323            implicit
1324                .transformed_transpose_mul(&qs, None, &v)
1325                .expect("implicit transformed transpose"),
1326            embedded
1327                .transformed_transpose_mul(&qs, None, &v)
1328                .expect("embedded transformed transpose")
1329        );
1330    }
1331
1332    #[test]
1333    pub(crate) fn directional_hyper_identities_match_finite_differences_logit() {
1334        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1335        let w = Array1::<f64>::ones(y.len());
1336        let x = array![
1337            [1.0, -1.2, 0.3],
1338            [1.0, -0.8, -0.4],
1339            [1.0, -0.3, 0.7],
1340            [1.0, 0.1, -0.9],
1341            [1.0, 0.5, 0.2],
1342            [1.0, 0.9, -0.1],
1343            [1.0, 1.3, 0.8],
1344            [1.0, 1.7, -0.6],
1345        ];
1346        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1347
1348        // Use one directional hyperparameter τ with a penalty perturbation:
1349        // S(τ) = S + τ S_τ.
1350        // Keep X_τ = 0 so this identity test remains valid in both non-Firth
1351        // and Firth-logit modes.
1352        let x_tau = Array2::<f64>::zeros(x.raw_dim());
1353        let s_tau = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15],];
1354        let hyper =
1355            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
1356                .expect("single-penalty hyper direction");
1357        let rho = array![0.0];
1358
1359        // Tight inner tolerance: the envelope theorem requires an exact inner
1360        // P-IRLS optimum; 1e-10 leaves enough residual gradient to cause ~12%
1361        // V_tau mismatch on this small (n=8) logistic problem.
1362        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
1363        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1364        let bundle = state.obtain_eval_bundle(&rho).expect("bundle");
1365        let pr = bundle.pirls_result.as_ref();
1366
1367        let beta = beta_original_from_bundle(&bundle);
1368        let h_orig = h_original_from_bundle(&bundle);
1369        let u = &pr.solveweights * &(&pr.solveworking_response - &pr.final_eta);
1370
1371        // B from implicit solve:
1372        //   H B = X_τ^T g - X^T W(X_τ β̂) - S_τ β̂.
1373        let x_tau_beta = gam_linalg::faer_ndarray::fast_av(&x_tau, &beta);
1374        let weighted_x_tau_beta = &pr.finalweights * &x_tau_beta;
1375        let rhs = gam_linalg::faer_ndarray::fast_atv(&x_tau, &u)
1376            - gam_linalg::faer_ndarray::fast_atv(&x, &weighted_x_tau_beta)
1377            - s_tau.dot(&beta);
1378        let chol = h_orig.cholesky(Side::Lower).expect("chol(H)");
1379        let b_analytic = chol.solvevec(&rhs);
1380
1381        // H_τ from exact total derivative:
1382        //   H_τ = X_τ^T W X + X^T W X_τ + X^T W_τ X + S_τ,
1383        // with W_τ provided by the family directional curvature callback.
1384        let eta_dot = &x_tau_beta + &gam_linalg::faer_ndarray::fast_av(&x, &b_analytic);
1385        let w_direction = crate::pirls::directionalworking_curvature_from_c_array(
1386            &pr.solve_c_array,
1387            &pr.finalweights,
1388            &eta_dot,
1389        );
1390        let wx = RemlState::row_scale(&x, &pr.finalweights);
1391        let wx_tau = RemlState::row_scale(&x_tau, &pr.finalweights);
1392        let mut xwtau_x = x.clone();
1393        match w_direction {
1394            crate::pirls::DirectionalWorkingCurvature::Diagonal(diag) => {
1395                xwtau_x = RemlState::row_scale(&xwtau_x, &diag);
1396            }
1397        }
1398        let mut h_tau_analytic = gam_linalg::faer_ndarray::fast_atb(&x_tau, &wx);
1399        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &wx_tau);
1400        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &xwtau_x);
1401        h_tau_analytic += &s_tau;
1402
1403        // Fit-block stationarity cancellation:
1404        //   -ℓ_β^T B + β̂^T S B = 0.
1405        // Here S is the effective penalty in the inner Hessian surface:
1406        //   S = H - X^T W X.
1407        let ell_beta = gam_linalg::faer_ndarray::fast_atv(&x, &u);
1408        let s_eff = &h_orig - &gam_linalg::faer_ndarray::fast_atb(&x, &wx);
1409        let cancellation = -ell_beta.dot(&b_analytic) + beta.dot(&s_eff.dot(&b_analytic));
1410
1411        // Finite differences in τ against re-fit objective and mode.
1412        let h = 2e-5;
1413        let x_plus = &x + &(x_tau.mapv(|v| h * v));
1414        let x_minus = &x - &(x_tau.mapv(|v| h * v));
1415        let s_plus = &s0 + &(s_tau.mapv(|v| h * v));
1416        let s_minus = &s0 - &(s_tau.mapv(|v| h * v));
1417
1418        let state_plus = build_logit_state(&y, &w, &x_plus, &s_plus, &cfg);
1419        let state_minus = build_logit_state(&y, &w, &x_minus, &s_minus, &cfg);
1420        let bundle_plus = state_plus.obtain_eval_bundle(&rho).expect("bundle+");
1421        let bundle_minus = state_minus.obtain_eval_bundle(&rho).expect("bundle-");
1422        let beta_plus = beta_original_from_bundle(&bundle_plus);
1423        let beta_minus = beta_original_from_bundle(&bundle_minus);
1424        let bfd = (&beta_plus - &beta_minus).mapv(|v| v / (2.0 * h));
1425
1426        let h_plus = h_original_from_bundle(&bundle_plus);
1427        let h_minus = h_original_from_bundle(&bundle_minus);
1428        let h_taufd = (&h_plus - &h_minus).mapv(|v| v / (2.0 * h));
1429
1430        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
1431        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
1432        let v_taufd = (v_plus - v_minus) / (2.0 * h);
1433
1434        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper.clone())
1435            .expect("analytic directional gradient");
1436
1437        let b_num = (&b_analytic - &bfd).mapv(|v| v * v).sum().sqrt();
1438        let b_den = bfd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1439        let b_rel = b_num / b_den;
1440        for i in 0..b_analytic.len() {
1441            assert_eq!(
1442                b_analytic[i].signum(),
1443                bfd[i].signum(),
1444                "B sign mismatch at i={i}: analytic={} fd={}",
1445                b_analytic[i],
1446                bfd[i]
1447            );
1448        }
1449        assert!(
1450            b_rel < 2e-2,
1451            "B implicit solve mismatch vs FD: rel={b_rel:.3e}, num={b_num:.3e}, den={b_den:.3e}"
1452        );
1453
1454        let dh_num = (&h_tau_analytic - &h_taufd).mapv(|v| v * v).sum().sqrt();
1455        let dh_den = h_taufd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1456        let dh_rel = dh_num / dh_den;
1457        for i in 0..h_tau_analytic.nrows() {
1458            for j in 0..h_tau_analytic.ncols() {
1459                assert_eq!(
1460                    h_tau_analytic[[i, j]].signum(),
1461                    h_taufd[[i, j]].signum(),
1462                    "H_tau sign mismatch at ({i},{j}): analytic={} fd={}",
1463                    h_tau_analytic[[i, j]],
1464                    h_taufd[[i, j]]
1465                );
1466            }
1467        }
1468        assert!(
1469            dh_rel < 3e-2,
1470            "H_tau mismatch vs FD: rel={dh_rel:.3e}, num={dh_num:.3e}, den={dh_den:.3e}"
1471        );
1472
1473        let v_abs = (v_tau_analytic - v_taufd).abs();
1474        let v_rel = v_abs / v_taufd.abs().max(1e-10);
1475        assert_eq!(
1476            v_tau_analytic.signum(),
1477            v_taufd.signum(),
1478            "V_tau sign mismatch: analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1479        );
1480        assert!(
1481            v_rel < 2e-2,
1482            "V_tau mismatch vs FD: rel={v_rel:.3e}, abs={v_abs:.3e}, analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1483        );
1484
1485        assert!(
1486            cancellation.abs() < 1e-10,
1487            "stationarity cancellation failed: | -ell_beta^T B + beta^T S B | = {:.3e}",
1488            cancellation.abs()
1489        );
1490    }
1491
1492    #[test]
1493    pub(crate) fn firth_exacthessian_includes_analytic_tk_second_derivatives() {
1494        // Rank-deficient X: the 4th column is 2x the 2nd column.
1495        let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
1496        let w = Array1::<f64>::ones(y.len());
1497        let x = array![
1498            [1.0, -1.2, 0.4, -2.4],
1499            [1.0, -0.9, -0.1, -1.8],
1500            [1.0, -0.6, 0.3, -1.2],
1501            [1.0, -0.2, -0.4, -0.4],
1502            [1.0, 0.1, 0.5, 0.2],
1503            [1.0, 0.4, -0.6, 0.8],
1504            [1.0, 0.8, 0.2, 1.6],
1505            [1.0, 1.1, -0.3, 2.2],
1506            [1.0, 1.4, 0.7, 2.8],
1507            [1.0, 1.7, -0.2, 3.4],
1508        ];
1509        let s0 = array![
1510            [0.0, 0.0, 0.0, 0.0],
1511            [0.0, 1.5, 0.2, 0.0],
1512            [0.0, 0.2, 1.0, 0.0],
1513            [0.0, 0.0, 0.0, 0.5],
1514        ];
1515        let s1 = array![
1516            [0.0, 0.0, 0.0, 0.0],
1517            [0.0, 0.8, -0.1, 0.0],
1518            [0.0, -0.1, 0.6, 0.0],
1519            [0.0, 0.0, 0.0, 0.3],
1520        ];
1521        let offset = Array1::<f64>::zeros(y.len());
1522        // Rank-deficient Firth logit needs more inner iterations to converge
1523        // tightly enough for the envelope-theorem derivative tests.
1524        let cfg =
1525            RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1526        let p = x.ncols();
1527        use crate::estimate::PenaltySpec;
1528        let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1529        let canonical = gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p, "test")
1530            .map(|(canonical, _)| canonical)
1531            .expect("canonicalize");
1532        let state = RemlState::newwith_offset(
1533            y.view(),
1534            x.clone(),
1535            w.view(),
1536            offset.view(),
1537            canonical,
1538            p,
1539            &cfg,
1540            Some(vec![1, 1]),
1541            None,
1542            None,
1543        )
1544        .expect("state");
1545        let rho = array![0.1, -0.2];
1546        assert!(
1547            state.analytic_outer_hessian_enabled(),
1548            "Firth logit should no longer disable analytic outer Hessian planning"
1549        );
1550        let outer = state
1551            .compute_outer_eval_with_order(
1552                &rho,
1553                crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1554            )
1555            .expect("outer Hessian eval should succeed");
1556        assert!(
1557            outer.hessian.is_analytic(),
1558            "outer planner should request and return an analytic Hessian"
1559        );
1560        let bundle = state.obtain_eval_bundle(&rho).expect("exact firth bundle");
1561        let h_dense = state
1562            .compute_lamlhessian_exact_from_bundle(&rho, &bundle)
1563            .expect("Firth exact Hessian should include analytic TK second derivatives");
1564        assert_eq!(h_dense.raw_dim(), ndarray::Ix2(2, 2));
1565        assert!(
1566            h_dense.iter().all(|value| value.is_finite()),
1567            "Hessian should be finite: {h_dense:?}"
1568        );
1569    }
1570
1571    #[test]
1572    pub(crate) fn firth_outer_hessian_matches_gradient_finite_difference_with_tk_terms() {
1573        let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
1574        let w = Array1::<f64>::ones(y.len());
1575        let x = array![
1576            [1.0, -1.0, 0.3],
1577            [1.0, -0.7, -0.2],
1578            [1.0, -0.3, 0.4],
1579            [1.0, 0.0, -0.5],
1580            [1.0, 0.2, 0.6],
1581            [1.0, 0.6, -0.4],
1582            [1.0, 0.9, 0.2],
1583            [1.0, 1.3, -0.1],
1584        ];
1585        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.1], [0.0, 0.1, 0.7],];
1586        let s1 = array![[0.0, 0.0, 0.0], [0.0, 0.4, -0.05], [0.0, -0.05, 0.9],];
1587        let cfg =
1588            RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1589        let p_dim = x.ncols();
1590        use crate::estimate::PenaltySpec;
1591        let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1592        let canonical =
1593            gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p_dim, "test")
1594                .map(|(canonical, _)| canonical)
1595                .expect("canonicalize");
1596        let offset = Array1::<f64>::zeros(y.len());
1597        let state = RemlState::newwith_offset(
1598            y.view(),
1599            x.clone(),
1600            w.view(),
1601            offset.view(),
1602            canonical,
1603            p_dim,
1604            &cfg,
1605            Some(vec![1, 1]),
1606            None,
1607            None,
1608        )
1609        .expect("state");
1610        let rho = array![0.15, -0.25];
1611        let eval = state
1612            .compute_outer_eval_with_order(
1613                &rho,
1614                crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1615            )
1616            .expect("analytic Hessian eval");
1617        let h = match eval.hessian {
1618            HessianResult::Analytic(hessian) => hessian,
1619            HessianResult::Operator(_) | HessianResult::Unavailable => {
1620                panic!("expected dense analytic Hessian")
1621            }
1622        };
1623        let delta = 2.0e-5;
1624        for col in 0..rho.len() {
1625            let mut rp = rho.clone();
1626            let mut rm = rho.clone();
1627            rp[col] += delta;
1628            rm[col] -= delta;
1629            let gp = state
1630                .compute_outer_eval_with_order(
1631                    &rp,
1632                    crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1633                )
1634                .expect("plus grad")
1635                .gradient;
1636            let gm = state
1637                .compute_outer_eval_with_order(
1638                    &rm,
1639                    crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1640                )
1641                .expect("minus grad")
1642                .gradient;
1643            for row in 0..rho.len() {
1644                let fd = (gp[row] - gm[row]) / (2.0 * delta);
1645                let an = h[[row, col]];
1646                let rel = (fd - an).abs() / fd.abs().max(an.abs()).max(1e-6);
1647                assert!(
1648                    rel < 2.0e-3,
1649                    "Hessian mismatch ({row},{col}): analytic={an:.9e}, fd={fd:.9e}, rel={rel:.3e}"
1650                );
1651            }
1652        }
1653    }
1654
1655    #[test]
1656    pub(crate) fn firthgradient_lives_in_design_column_space_under_rank_deficiency() {
1657        // Rank-deficient design: col4 = 2*col2.
1658        let x = array![
1659            [1.0, -1.2, 0.4, -2.4],
1660            [1.0, -0.9, -0.1, -1.8],
1661            [1.0, -0.6, 0.3, -1.2],
1662            [1.0, -0.2, -0.4, -0.4],
1663            [1.0, 0.1, 0.5, 0.2],
1664            [1.0, 0.4, -0.6, 0.8],
1665            [1.0, 0.8, 0.2, 1.6],
1666            [1.0, 1.1, -0.3, 2.2],
1667        ];
1668        let beta = array![0.1, -0.2, 0.3, 0.05];
1669        let eta = x.dot(&beta);
1670        let op = super::RemlState::build_firth_dense_operator_for_link(
1671            &gam_problem::InverseLink::Standard(gam_problem::StandardLink::Logit),
1672            &x,
1673            &eta,
1674            ndarray::Array1::ones(x.nrows()).view(),
1675        )
1676        .expect("firth operator");
1677
1678        // Exact reduced-space Firth gradient:
1679        //   gradPhi = 0.5 Xᵀ (w' ⊙ h), with h = diag(X_r K_r X_rᵀ).
1680        let gradphi = 0.5 * x.t().dot(&(&op.w1 * &op.h_diag));
1681
1682        // Check (I - QQᵀ) gradPhi ≈ 0.
1683        let q = &op.q_basis;
1684        let proj = q.dot(&q.t().dot(&gradphi));
1685        let resid = &gradphi - &proj;
1686        let rel =
1687            resid.mapv(|v| v * v).sum().sqrt() / gradphi.mapv(|v| v * v).sum().sqrt().max(1e-12);
1688        assert!(
1689            rel < 1e-10,
1690            "Firth gradient should lie in Col(Xᵀ): rel residual={rel:.3e}"
1691        );
1692    }
1693
1694    #[test]
1695    pub(crate) fn firth_logit_directional_hypergradient_accepts_penalty_only_with_full_tk_gradient()
1696    {
1697        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1698        let w = Array1::<f64>::ones(y.len());
1699        let x = array![
1700            [1.0, -1.1, 0.2],
1701            [1.0, -0.6, -0.3],
1702            [1.0, -0.1, 0.5],
1703            [1.0, 0.3, -0.7],
1704            [1.0, 0.8, 0.1],
1705            [1.0, 1.2, -0.4],
1706        ];
1707        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1708        let hyper = DirectionalHyperParam::single_penalty(
1709            0,
1710            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1711            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1712            None,
1713            None,
1714        )
1715        .expect("single-penalty hyper direction");
1716        let rho = array![0.0];
1717        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1718        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1719        let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1720            .expect("Firth penalty-only directional gradient should use analytic TK propagation");
1721        assert!(gradient.is_finite(), "gradient={gradient}");
1722        let fd = fd_directional_tau_cost_gradient(
1723            &y,
1724            &w,
1725            &x,
1726            &s0,
1727            &cfg,
1728            &rho,
1729            &Array2::<f64>::zeros((x.nrows(), x.ncols())),
1730            &array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1731        );
1732        let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1733        assert!(
1734            rel < 1.0e-3,
1735            "Firth penalty-only directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1736        );
1737
1738        let efs_hyper = DirectionalHyperParam::single_penalty(
1739            0,
1740            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1741            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1742            None,
1743            None,
1744        )
1745        .expect("single-penalty EFS hyper direction");
1746        let efs = state
1747            .compute_efs_steps_with_psi_ext(&rho, &[efs_hyper])
1748            .expect("Firth penalty-only EFS should use analytic TK propagation");
1749        assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1750    }
1751
1752    /// Regression for gam#1821: the analytic ρ-gradient of the Firth-corrected
1753    /// LAML cost must equal the central finite difference of that SAME cost,
1754    /// evaluated END-TO-END through the inner P-IRLS solve (`compute_cost` /
1755    /// `compute_gradient`), for a genuinely Firth-active (near-separable) fit.
1756    ///
1757    /// This exercises the branch the earlier operator-level Firth FD tests never
1758    /// touched: the LM line-search that produces β̂. The dense LAML gradient uses
1759    /// the envelope identity, which holds ONLY when β̂ satisfies the *Firth*-KKT
1760    /// stationarity `∇(−ℓ+½βᵀSβ) = ∇Φ`. When `GamWorkingModel::update_candidate`
1761    /// built line-search candidates with Firth disabled, the candidate/accepted
1762    /// `WorkingState` dropped the `−2·½log|XᵀWX|` Jeffreys term, the objective
1763    /// the line search compared (candidate vs `current_penalized`) was
1764    /// inconsistent, and — because the accepted state IS the candidate and
1765    /// convergence is certified on `accepted_state.gradient` — the inner solve
1766    /// settled at the ordinary penalized MLE (`∇(−ℓ+½βᵀSβ)=0`) instead. At that
1767    /// wrong mode the envelope breaks and the analytic ρ-gradient disagrees with
1768    /// the FD of the cost by `O((∇Φ)ᵀ∂β̂/∂ρ)` — a large (percent-level) desync
1769    /// that appears ONLY under `firth_bias_reduction`. A tight FD-vs-analytic
1770    /// bound therefore fails iff the inner solve regresses off the Firth mode.
1771    #[test]
1772    pub(crate) fn firth_logit_rho_gradient_matches_finite_difference_through_inner_solve() {
1773        // Near-separable n=3 logit: Firth is materially active (the ordinary
1774        // penalized MLE and the Firth-penalized mode differ enough that any
1775        // envelope break shows up far above the 1e-4 tolerance).
1776        let x = array![[1.0, -6.0], [1.0, 0.2], [1.0, 5.8]];
1777        let y = array![0.0, 0.0, 1.0];
1778        let w = Array1::<f64>::ones(y.len());
1779        // Full-rank identity penalty on both coefficients (single ρ).
1780        let s0 = array![[1.0, 0.0], [0.0, 1.0]];
1781        // Tight inner tolerance so β̂(ρ ± δ) and β̂(ρ) are all at the converged
1782        // Firth-KKT mode; otherwise the FD would capture β̂'s residual motion.
1783        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-12, true);
1784        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1785        let delta = 1e-4_f64;
1786        for &rho in &[-0.6_f64, -0.3, 0.0, 0.3, 0.6] {
1787            let r = array![rho];
1788            let analytic = state
1789                .compute_gradient(&r)
1790                .expect("Firth LAML ρ-gradient should evaluate")[0];
1791            let cost_plus = state
1792                .compute_cost(&array![rho + delta])
1793                .expect("Firth LAML cost(ρ+δ) should evaluate");
1794            let cost_minus = state
1795                .compute_cost(&array![rho - delta])
1796                .expect("Firth LAML cost(ρ−δ) should evaluate");
1797            let fd = (cost_plus - cost_minus) / (2.0 * delta);
1798            let rel = (fd - analytic).abs() / fd.abs().max(1e-3);
1799            assert!(
1800                analytic.is_finite() && fd.is_finite(),
1801                "non-finite Firth ρ-gradient at rho={rho:+.3}: fd={fd:+.6e}, analytic={analytic:+.6e}"
1802            );
1803            assert!(
1804                rel < 1e-4,
1805                "Firth ρ-gradient FD desync at rho={rho:+.3}: fd={fd:+.6e}, analytic={analytic:+.6e}, rel={rel:.3e} (>= 1e-4). \
1806                 The inner P-IRLS likely converged off the Firth-KKT mode (gam#1821)."
1807            );
1808        }
1809    }
1810
1811    #[test]
1812    pub(crate) fn firth_logit_directional_hypergradient_accepts_design_moving_with_full_tk_gradient()
1813     {
1814        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1815        let w = Array1::<f64>::ones(y.len());
1816        let x = array![
1817            [1.0, -1.1, 0.2],
1818            [1.0, -0.6, -0.3],
1819            [1.0, -0.1, 0.5],
1820            [1.0, 0.3, -0.7],
1821            [1.0, 0.8, 0.1],
1822            [1.0, 1.2, -0.4],
1823        ];
1824        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1825        let hyper = DirectionalHyperParam::single_penalty(
1826            0,
1827            Array2::from_elem((x.nrows(), x.ncols()), 1e-3),
1828            Array2::<f64>::zeros((x.ncols(), x.ncols())),
1829            None,
1830            None,
1831        )
1832        .expect("single-penalty hyper direction");
1833        let rho = array![0.0];
1834        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1835        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1836        let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1837            .expect("Firth design-moving directional gradient should use analytic TK propagation");
1838        assert!(gradient.is_finite(), "gradient={gradient}");
1839        let x_tau = Array2::from_elem((x.nrows(), x.ncols()), 1e-3);
1840        let s_tau = Array2::<f64>::zeros((x.ncols(), x.ncols()));
1841        let fd = fd_directional_tau_cost_gradient(&y, &w, &x, &s0, &cfg, &rho, &x_tau, &s_tau);
1842        let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1843        assert!(
1844            rel < 2.0e-2,
1845            "Firth design-moving directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1846        );
1847    }
1848
1849    #[test]
1850    pub(crate) fn firth_logit_hybrid_efs_accepts_full_tk_psi_gradient() {
1851        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1852        let w = Array1::<f64>::ones(y.len());
1853        let x = array![
1854            [1.0, -1.1, 0.2],
1855            [1.0, -0.6, -0.3],
1856            [1.0, -0.1, 0.5],
1857            [1.0, 0.3, -0.7],
1858            [1.0, 0.8, 0.1],
1859            [1.0, 1.2, -0.4],
1860        ];
1861        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1862        let hyper_dirs = vec![
1863            DirectionalHyperParam::single_penalty(
1864                0,
1865                Array2::from_shape_fn((x.nrows(), x.ncols()), |(i, j)| {
1866                    1e-3 * ((i + 1) as f64) * ((j + 2) as f64)
1867                }),
1868                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1869                None,
1870                None,
1871            )
1872            .expect("design-moving hyper direction"),
1873        ];
1874        let rho = array![0.0];
1875        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1876        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1877
1878        let full = state
1879            .evaluate_unified_with_psi_ext(
1880                &rho,
1881                None,
1882                crate::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
1883                &hyper_dirs,
1884            )
1885            .expect("full Firth psi gradient should use analytic TK propagation");
1886        assert!(full.cost.is_finite(), "full cost={}", full.cost);
1887        let full_grad = full.gradient.expect("gradient should be present");
1888        assert!(
1889            full_grad.iter().all(|value| value.is_finite()),
1890            "full gradient={full_grad:?}"
1891        );
1892
1893        let efs = state
1894            .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
1895            .expect("hybrid EFS should use analytic TK propagation");
1896        assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1897    }
1898
1899    #[test]
1900    pub(crate) fn joint_hyperhessianwires_mixed_blocks() {
1901        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1902        let w = Array1::<f64>::ones(y.len());
1903        let x = array![
1904            [1.0, -1.2, 0.3],
1905            [1.0, -0.8, -0.4],
1906            [1.0, -0.3, 0.7],
1907            [1.0, 0.1, -0.9],
1908            [1.0, 0.5, 0.2],
1909            [1.0, 0.9, -0.1],
1910            [1.0, 1.3, 0.8],
1911            [1.0, 1.7, -0.6],
1912        ];
1913        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1914        let cfg =
1915            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1916        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1917        let rho = array![0.0];
1918        let theta = array![0.0, 0.0, 0.0];
1919        let hyper_dirs = vec![
1920            DirectionalHyperParam::single_penalty(
1921                0,
1922                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1923                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1924                None,
1925                None,
1926            )
1927            .expect("single-penalty hyper direction"),
1928            DirectionalHyperParam::single_penalty(
1929                0,
1930                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1931                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1932                None,
1933                None,
1934            )
1935            .expect("single-penalty hyper direction"),
1936        ];
1937
1938        let (_, _, h) =
1939            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1940                .expect("joint hyper cost+gradient+hessian");
1941        assert_eq!(h.nrows(), theta.len());
1942        assert_eq!(h.ncols(), theta.len());
1943        assert!(h.iter().all(|v| v.is_finite()));
1944        for i in 0..h.nrows() {
1945            for j in 0..i {
1946                let diff = (h[[i, j]] - h[[j, i]]).abs();
1947                assert!(
1948                    diff < 1e-6,
1949                    "joint hessian asymmetry at ({i},{j}): {diff:.3e}"
1950                );
1951            }
1952        }
1953        // Mixed block must be nontrivial for at least one supplied direction.
1954        let mixed_0 = h[[0, 1]];
1955        let mixed_1 = h[[0, 2]];
1956        assert!(
1957            mixed_0.is_finite() && mixed_1.is_finite(),
1958            "mixed blocks must be finite"
1959        );
1960    }
1961
1962    #[test]
1963    pub(crate) fn joint_tau_tau_linear_dirs_matchfd_reference_away_fromzero_psi() {
1964        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1965        let w = Array1::<f64>::ones(y.len());
1966        let x = array![
1967            [1.0, -1.2, 0.3],
1968            [1.0, -0.8, -0.4],
1969            [1.0, -0.3, 0.7],
1970            [1.0, 0.1, -0.9],
1971            [1.0, 0.5, 0.2],
1972            [1.0, 0.9, -0.1],
1973            [1.0, 1.3, 0.8],
1974            [1.0, 1.7, -0.6],
1975        ];
1976        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1977        let cfg =
1978            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1979        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1980        let rho = array![0.0];
1981        let psi = array![0.7, -0.4];
1982        let theta = array![rho[0], psi[0], psi[1]];
1983        let hyper_dirs = vec![
1984            DirectionalHyperParam::single_penalty(
1985                0,
1986                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1987                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1988                None,
1989                None,
1990            )
1991            .expect("linear tau direction"),
1992            DirectionalHyperParam::single_penalty(
1993                0,
1994                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1995                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1996                None,
1997                None,
1998            )
1999            .expect("linear tau direction"),
2000        ];
2001
2002        let (_, _, h_full) =
2003            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2004                .expect("joint hyper cost+gradient+hessian");
2005        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2006
2007        // FD via physical perturbation of design/penalty matrices (matching
2008        // the V_tau FD pattern).  For column j we perturb X and S₀ along
2009        // direction j, build fresh states, and evaluate the τ-gradient for
2010        // every direction i at those perturbed states.
2011        let x_tau_mats: Vec<Array2<f64>> = vec![
2012            Array2::<f64>::zeros((x.nrows(), x.ncols())),
2013            Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2014        ];
2015        let s_tau_mats: Vec<Array2<f64>> = vec![
2016            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
2017            Array2::<f64>::zeros((x.ncols(), x.ncols())),
2018        ];
2019
2020        let h_ttfd = directional_tau_hessian_fd_reference(
2021            &y,
2022            &w,
2023            &x,
2024            &s0,
2025            &cfg,
2026            &rho,
2027            &hyper_dirs,
2028            &x_tau_mats,
2029            &s_tau_mats,
2030        );
2031
2032        let num = (&h_tt_analytic - &h_ttfd)
2033            .iter()
2034            .map(|v| v * v)
2035            .sum::<f64>()
2036            .sqrt();
2037        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2038        let rel = num / den;
2039        assert!(
2040            rel < 1e-4,
2041            "linear-dir joint tau-tau block deviates from FD reference away from zero psi: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2042        );
2043    }
2044
2045    #[test]
2046    pub(crate) fn joint_hypervalidation_rejects_out_of_boundssecond_order_penalty_index() {
2047        // The hyper direction declares a second-order penalty derivative
2048        // against base penalty index 1, but the configured ρ vector has
2049        // dimension 1 (so only index 0 is valid).  The pair-callback
2050        // builder in `build_tau_penalty_derivative_data` is responsible for
2051        // validating both first- and second-order penalty indices against
2052        // `rho.len()`; this test pins that contract.
2053        //
2054        // We deliberately keep `firth_bias_reduction = true` here so the
2055        // call site exercises the full Firth/Tierney–Kadane outer pipeline:
2056        // PIRLS + ext-coord construction + pair-callback assembly.  With
2057        // analytic c/d propagation now wired in
2058        // `tk_direct_gradient_from_cd_and_design`, there is no longer any
2059        // FD-fallback rejection on this path, so the out-of-bounds error
2060        // fired by the pair-callback builder is the first failure the
2061        // joint evaluator surfaces — and that is exactly what we want this
2062        // test to assert.
2063        let y = array![0.0, 1.0, 0.0, 1.0];
2064        let w = Array1::<f64>::ones(y.len());
2065        let x = array![
2066            [1.0, -0.5, 0.2],
2067            [1.0, -0.1, -0.3],
2068            [1.0, 0.4, 0.6],
2069            [1.0, 0.9, -0.2],
2070        ];
2071        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
2072        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
2073        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2074        let theta = array![0.0, 0.0];
2075        let hyper_dirs = vec![
2076            DirectionalHyperParam::new(
2077                Array2::<f64>::zeros((x.nrows(), x.ncols())),
2078                vec![(0, Array2::<f64>::zeros((x.ncols(), x.ncols())))],
2079                None,
2080                Some(vec![Some(vec![(1, Array2::<f64>::eye(x.ncols()))])]),
2081            )
2082            .expect("hyper direction with invalid second-order penalty index"),
2083        ];
2084
2085        let msg = match compute_joint_hypercostgradienthessian(&state, &theta, 1, &hyper_dirs) {
2086            Ok(_) => panic!("invalid second-order penalty index should be rejected"),
2087            Err(err) => err.to_string(),
2088        };
2089        assert!(
2090            msg.contains("out of bounds") || msg.contains("penalty_index"),
2091            "unexpected validation error: {msg}"
2092        );
2093    }
2094
2095    #[test]
2096    pub(crate) fn joint_tau_tau_analytic_matchesfd_reference() {
2097        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
2098        let w = Array1::<f64>::ones(y.len());
2099        let x = array![
2100            [1.0, -1.2, 0.3],
2101            [1.0, -0.8, -0.4],
2102            [1.0, -0.3, 0.7],
2103            [1.0, 0.1, -0.9],
2104            [1.0, 0.5, 0.2],
2105            [1.0, 0.9, -0.1],
2106            [1.0, 1.3, 0.8],
2107            [1.0, 1.7, -0.6],
2108        ];
2109        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
2110        let cfg =
2111            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
2112        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2113        let rho = array![0.0];
2114        let psi = array![0.0, 0.0];
2115        let hyper_dirs = vec![
2116            DirectionalHyperParam::single_penalty(
2117                0,
2118                Array2::<f64>::zeros((x.nrows(), x.ncols())),
2119                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
2120                None,
2121                None,
2122            )
2123            .expect("single-penalty hyper direction"),
2124            DirectionalHyperParam::single_penalty(
2125                0,
2126                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2127                Array2::<f64>::zeros((x.ncols(), x.ncols())),
2128                None,
2129                None,
2130            )
2131            .expect("single-penalty hyper direction"),
2132        ];
2133
2134        let theta = {
2135            let mut t = Array1::<f64>::zeros(rho.len() + psi.len());
2136            t.slice_mut(s![..rho.len()]).assign(&rho);
2137            t.slice_mut(s![rho.len()..]).assign(&psi);
2138            t
2139        };
2140        let (_, _, h_full) =
2141            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2142                .expect("joint hyper cost+gradient+hessian");
2143        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2144        assert_eq!(h_tt_analytic.nrows(), hyper_dirs.len());
2145        assert_eq!(h_tt_analytic.ncols(), hyper_dirs.len());
2146
2147        // FD via physical perturbation of design/penalty matrices (matching
2148        // the V_tau FD pattern).  For column j we perturb X and S₀ along
2149        // direction j, build fresh states, and evaluate the τ-gradient for
2150        // every direction i at those perturbed states.
2151        let x_tau_mats: Vec<Array2<f64>> = vec![
2152            Array2::<f64>::zeros((x.nrows(), x.ncols())),
2153            Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2154        ];
2155        let s_tau_mats: Vec<Array2<f64>> = vec![
2156            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
2157            Array2::<f64>::zeros((x.ncols(), x.ncols())),
2158        ];
2159
2160        let h_ttfd = directional_tau_hessian_fd_reference(
2161            &y,
2162            &w,
2163            &x,
2164            &s0,
2165            &cfg,
2166            &rho,
2167            &hyper_dirs,
2168            &x_tau_mats,
2169            &s_tau_mats,
2170        );
2171
2172        let num = (&h_tt_analytic - &h_ttfd)
2173            .iter()
2174            .map(|v| v * v)
2175            .sum::<f64>()
2176            .sqrt();
2177        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2178        let rel = num / den;
2179        assert!(
2180            rel < 1e-4,
2181            "analytic tau-tau block deviates from FD reference: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2182        );
2183    }
2184
2185    // ── Profiled Gaussian REML coverage for design-moving τ-directions ──
2186    //
2187    // The existing directional-hyper tests all use BinomialLogit, which has
2188    // DispersionHandling::Fixed.  These tests validate the profiled Gaussian
2189    // path (DispersionHandling::ProfiledGaussian) with design-moving
2190    // τ-directions, where the profiled scale φ̂ = D_p/(n−M) depends on ρ
2191    // and the envelope-theorem rescaling by (n−M)/D_p must be correct.
2192
2193    /// Shared test fixture for profiled Gaussian REML tests.
2194    pub(crate) struct GaussianRemlFixture {
2195        pub(crate) y: Array1<f64>,
2196        pub(crate) w: Array1<f64>,
2197        pub(crate) x: Array2<f64>,
2198        pub(crate) s0: Array2<f64>,
2199        pub(crate) cfg: RemlConfig,
2200        pub(crate) rho: Array1<f64>,
2201        /// Design-moving τ-direction (non-zero X_τ, zero S_τ).
2202        pub(crate) x_tau_design: Array2<f64>,
2203        /// Penalty-only τ-direction (zero X_τ, non-zero S_τ).
2204        pub(crate) s_tau_penalty: Array2<f64>,
2205    }
2206
2207    impl GaussianRemlFixture {
2208        pub(crate) fn new() -> Self {
2209            let y = array![0.5, 1.2, -0.3, 0.8, 1.1, -0.6, 0.9, 0.1, -0.2, 0.7];
2210            let x = array![
2211                [1.0, -1.2, 0.3],
2212                [1.0, -0.8, -0.4],
2213                [1.0, -0.3, 0.7],
2214                [1.0, 0.1, -0.9],
2215                [1.0, 0.5, 0.2],
2216                [1.0, 0.9, -0.1],
2217                [1.0, 1.3, 0.8],
2218                [1.0, 1.7, -0.6],
2219                [1.0, -0.5, 0.5],
2220                [1.0, 0.3, -0.3],
2221            ];
2222            Self {
2223                w: Array1::<f64>::ones(y.len()),
2224                y,
2225                x: x.clone(),
2226                s0: array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]],
2227                cfg: RemlConfig::external(gaussian_identity_glm_spec(), 1e-14, false),
2228                rho: array![0.0],
2229                x_tau_design: array![
2230                    [0.0, 1e-3, -2e-3],
2231                    [0.0, -3e-3, 1e-3],
2232                    [0.0, 2e-3, 0.5e-3],
2233                    [0.0, -1e-3, 3e-3],
2234                    [0.0, 0.5e-3, -1e-3],
2235                    [0.0, 1.5e-3, 2e-3],
2236                    [0.0, -2e-3, -0.5e-3],
2237                    [0.0, 3e-3, 1e-3],
2238                    [0.0, -0.5e-3, 2e-3],
2239                    [0.0, 1e-3, -1.5e-3],
2240                ],
2241                s_tau_penalty: array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
2242            }
2243        }
2244    }
2245
2246    impl LogitDesignMotionFixture for GaussianRemlFixture {
2247        fn y(&self) -> &Array1<f64> {
2248            &self.y
2249        }
2250        fn w(&self) -> &Array1<f64> {
2251            &self.w
2252        }
2253        fn x(&self) -> &Array2<f64> {
2254            &self.x
2255        }
2256        fn s0(&self) -> &Array2<f64> {
2257            &self.s0
2258        }
2259        fn cfg(&self) -> &RemlConfig {
2260            &self.cfg
2261        }
2262        fn rho(&self) -> &Array1<f64> {
2263            &self.rho
2264        }
2265    }
2266
2267    #[test]
2268    pub(crate) fn profiled_gaussian_design_moving_gradient_matches_fd() {
2269        let f = GaussianRemlFixture::new();
2270        let state = f.state();
2271        let s_tau = Array2::<f64>::zeros((3, 3));
2272        let hyper = DirectionalHyperParam::single_penalty(
2273            0,
2274            f.x_tau_design.clone(),
2275            s_tau.clone(),
2276            None,
2277            None,
2278        )
2279        .expect("design-moving hyper direction");
2280
2281        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2282            .expect("analytic directional gradient");
2283        let v_taufd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2284
2285        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2286        assert!(
2287            v_rel < 1e-3,
2288            "Gaussian REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2289             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2290        );
2291    }
2292
2293    #[test]
2294    pub(crate) fn profiled_gaussian_penalty_only_gradient_matches_fd() {
2295        let f = GaussianRemlFixture::new();
2296        let state = f.state();
2297        let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2298        let hyper = DirectionalHyperParam::single_penalty(
2299            0,
2300            x_tau.clone(),
2301            f.s_tau_penalty.clone(),
2302            None,
2303            None,
2304        )
2305        .expect("penalty-only hyper direction");
2306
2307        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2308            .expect("analytic directional gradient");
2309        let v_taufd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2310
2311        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2312        assert!(
2313            v_rel < 1e-3,
2314            "Gaussian REML penalty-only V_tau mismatch: rel={v_rel:.3e}, \
2315             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2316        );
2317    }
2318
2319    #[test]
2320    pub(crate) fn profiled_gaussian_joint_hessian_matches_fd() {
2321        // Validate the ττ Hessian block under profiled Gaussian REML with
2322        // both a penalty-only and a design-moving direction.
2323        let f = GaussianRemlFixture::new();
2324        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2325        let s_tau_0 = f.s_tau_penalty.clone();
2326        let x_tau_1 = f.x_tau_design.clone();
2327        let s_tau_1 = Array2::<f64>::zeros((3, 3));
2328
2329        let hyper_dirs = vec![
2330            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2331                .expect("penalty-only direction"),
2332            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2333                .expect("design-moving direction"),
2334        ];
2335
2336        let state = f.state();
2337        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2338        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2339        let (_, _, h_full) =
2340            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2341                .expect("joint cost+gradient+hessian");
2342        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2343
2344        // Finite-difference Hessian: perturb each direction, re-evaluate
2345        // gradient of all directions at perturbed states.
2346        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2347        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2348        let h_ttfd = directional_tau_hessian_fd_reference(
2349            &f.y,
2350            &f.w,
2351            &f.x,
2352            &f.s0,
2353            &f.cfg,
2354            &f.rho,
2355            &hyper_dirs,
2356            &x_tau_mats,
2357            &s_tau_mats,
2358        );
2359
2360        let num = (&h_tt_analytic - &h_ttfd)
2361            .iter()
2362            .map(|v| v * v)
2363            .sum::<f64>()
2364            .sqrt();
2365        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2366        let rel = num / den;
2367        assert!(
2368            rel < 1e-4,
2369            "Gaussian REML tau-tau Hessian mismatch: rel={rel:.3e}, \
2370             analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2371        );
2372    }
2373
2374    // ── Non-Gaussian + design-motion: IFT Hessian-drift coverage ────────
2375    //
2376    // For non-Gaussian links (logit, probit, cloglog, ...), H = X'W(η)X + S
2377    // depends on β̂ through η = Xβ̂.  When ψ moves the design, the total
2378    // Hessian drift dH/dψ includes an IFT contribution from dβ̂/dψ:
2379    //
2380    //   dH/dψ = [explicit at fixed β] + X' diag(c ⊙ X(-v_i)) X
2381    //
2382    // where v_i = H⁻¹ g_i.  The standard GLM path handles this via
2383    // `hessian_derivative_correction(v_i)`.  This test validates that the
2384    // gradient is correct for logit + design-moving ψ, which would fail if
2385    // the IFT correction were missing.
2386
2387    #[test]
2388    pub(crate) fn logit_design_moving_gradient_matches_fd() {
2389        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2390        let w = Array1::<f64>::ones(y.len());
2391        let x = array![
2392            [1.0, -1.2, 0.3],
2393            [1.0, -0.8, -0.4],
2394            [1.0, -0.3, 0.7],
2395            [1.0, 0.1, -0.9],
2396            [1.0, 0.5, 0.2],
2397            [1.0, 0.9, -0.1],
2398            [1.0, 1.3, 0.8],
2399            [1.0, 1.7, -0.6],
2400            [1.0, -0.5, 0.5],
2401            [1.0, 0.3, -0.3],
2402        ];
2403        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2404        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2405        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2406        let rho = array![0.0];
2407
2408        // Design-moving direction with non-zero X_τ.
2409        let x_tau = array![
2410            [0.0, 1e-3, -2e-3],
2411            [0.0, -3e-3, 1e-3],
2412            [0.0, 2e-3, 0.5e-3],
2413            [0.0, -1e-3, 3e-3],
2414            [0.0, 0.5e-3, -1e-3],
2415            [0.0, 1.5e-3, 2e-3],
2416            [0.0, -2e-3, -0.5e-3],
2417            [0.0, 3e-3, 1e-3],
2418            [0.0, -0.5e-3, 2e-3],
2419            [0.0, 1e-3, -1.5e-3],
2420        ];
2421        let s_tau = Array2::<f64>::zeros((3, 3));
2422        let hyper =
2423            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
2424                .expect("design-moving hyper direction");
2425
2426        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2427            .expect("analytic directional gradient");
2428
2429        let h = 2e-5;
2430        let x_plus = &x + &x_tau.mapv(|v| h * v);
2431        let x_minus = &x - &x_tau.mapv(|v| h * v);
2432        let state_plus = build_logit_state(&y, &w, &x_plus, &s0, &cfg);
2433        let state_minus = build_logit_state(&y, &w, &x_minus, &s0, &cfg);
2434        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2435        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2436        let v_taufd = (v_plus - v_minus) / (2.0 * h);
2437
2438        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2439        assert!(
2440            v_rel < 1e-3,
2441            "Logit REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2442             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2443        );
2444    }
2445
2446    #[test]
2447    pub(crate) fn logit_design_moving_hessian_matches_fd() {
2448        // Hessian-level validation for logit + design-motion.
2449        // The IFT correction enters the trace term through
2450        // hessian_derivative_correction(v_i), so the Hessian is the most
2451        // sensitive test of whether the correction is applied correctly.
2452        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2453        let w = Array1::<f64>::ones(y.len());
2454        let x = array![
2455            [1.0, -1.2, 0.3],
2456            [1.0, -0.8, -0.4],
2457            [1.0, -0.3, 0.7],
2458            [1.0, 0.1, -0.9],
2459            [1.0, 0.5, 0.2],
2460            [1.0, 0.9, -0.1],
2461            [1.0, 1.3, 0.8],
2462            [1.0, 1.7, -0.6],
2463            [1.0, -0.5, 0.5],
2464            [1.0, 0.3, -0.3],
2465        ];
2466        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2467        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2468        let rho = array![0.0];
2469
2470        // Two directions: one penalty-only, one design-moving.
2471        let x_tau_0 = Array2::<f64>::zeros(x.raw_dim());
2472        let s_tau_0 = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]];
2473        let x_tau_1 = array![
2474            [0.0, 1e-3, -2e-3],
2475            [0.0, -3e-3, 1e-3],
2476            [0.0, 2e-3, 0.5e-3],
2477            [0.0, -1e-3, 3e-3],
2478            [0.0, 0.5e-3, -1e-3],
2479            [0.0, 1.5e-3, 2e-3],
2480            [0.0, -2e-3, -0.5e-3],
2481            [0.0, 3e-3, 1e-3],
2482            [0.0, -0.5e-3, 2e-3],
2483            [0.0, 1e-3, -1.5e-3],
2484        ];
2485        let s_tau_1 = Array2::<f64>::zeros((3, 3));
2486
2487        let hyper_dirs = vec![
2488            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2489                .expect("penalty-only direction"),
2490            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2491                .expect("design-moving direction"),
2492        ];
2493
2494        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2495        let mut theta = Array1::<f64>::zeros(rho.len() + hyper_dirs.len());
2496        theta.slice_mut(s![..rho.len()]).assign(&rho);
2497        let (_, _, h_full) =
2498            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2499                .expect("joint cost+gradient+hessian");
2500        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2501
2502        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2503        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2504        let h_ttfd = directional_tau_hessian_fd_reference(
2505            &y,
2506            &w,
2507            &x,
2508            &s0,
2509            &cfg,
2510            &rho,
2511            &hyper_dirs,
2512            &x_tau_mats,
2513            &s_tau_mats,
2514        );
2515
2516        let num = (&h_tt_analytic - &h_ttfd)
2517            .iter()
2518            .map(|v| v * v)
2519            .sum::<f64>()
2520            .sqrt();
2521        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2522        let rel = num / den;
2523        assert!(
2524            rel < 1e-4,
2525            "Logit REML design-moving tau-tau Hessian mismatch: rel={rel:.3e}, \
2526             analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2527        );
2528    }
2529
2530    // ── Larger non-Gaussian + design-motion fixture (n=30, p=5) ────────
2531    //
    // Validates the IFT correction (hessian_derivative_correction) at a
    // scale large enough that the correction is numerically non-trivial:
    // with n=30 and p=5, the logistic working weights W(η) vary
    // substantially across observations (W is far from a constant multiple
    // of the identity), so the IFT term dβ̂/dψ contributes meaningfully.
2536
2537    /// Shared test fixture for binomial-logit REML with design-moving
2538    /// ψ-coordinates, n=30, p=5.
2539    pub(crate) struct BinomialLogitDesignMotionFixture {
2540        pub(crate) y: Array1<f64>,
2541        pub(crate) w: Array1<f64>,
2542        pub(crate) x: Array2<f64>,
2543        pub(crate) s0: Array2<f64>,
2544        pub(crate) cfg: RemlConfig,
2545        pub(crate) rho: Array1<f64>,
2546        /// Design-moving τ-direction: non-zero X_τ, zero S_τ.
2547        pub(crate) x_tau_design: Array2<f64>,
2548        /// Penalty-only τ-direction: zero X_τ, non-zero S_τ.
2549        pub(crate) s_tau_penalty: Array2<f64>,
2550    }
2551
2552    impl BinomialLogitDesignMotionFixture {
2553        pub(crate) fn new() -> Self {
2554            // Binary response with roughly balanced classes.
2555            let y = array![
2556                1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
2557                1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
2558            ];
2559            // Design matrix: intercept + 4 covariate columns with varied magnitudes.
2560            let x = array![
2561                [1.0, -1.50, 0.42, 0.88, -0.31],
2562                [1.0, -1.12, -0.65, 0.14, 1.23],
2563                [1.0, -0.80, 1.10, -0.53, 0.07],
2564                [1.0, -0.55, -0.22, 1.40, -0.90],
2565                [1.0, -0.30, 0.73, -1.05, 0.44],
2566                [1.0, -0.05, -1.33, 0.60, 0.81],
2567                [1.0, 0.18, 0.55, -0.27, -1.15],
2568                [1.0, 0.42, -0.90, 1.12, 0.33],
2569                [1.0, 0.70, 1.28, -0.78, -0.56],
2570                [1.0, 0.95, -0.18, 0.45, 1.40],
2571                [1.0, 1.20, 0.66, -1.30, -0.02],
2572                [1.0, 1.45, -1.05, 0.22, 0.68],
2573                [1.0, -1.35, 0.90, 0.55, -0.43],
2574                [1.0, -0.98, -0.40, -0.88, 1.05],
2575                [1.0, -0.62, 1.42, 0.30, -0.70],
2576                [1.0, -0.28, -0.77, -1.18, 0.52],
2577                [1.0, 0.05, 0.15, 0.95, -1.35],
2578                [1.0, 0.33, -1.20, -0.40, 0.18],
2579                [1.0, 0.60, 0.82, 1.25, -0.85],
2580                [1.0, 0.88, -0.50, -0.65, 1.10],
2581                [1.0, 1.15, 1.05, 0.10, -0.22],
2582                [1.0, -1.22, -0.95, 0.72, 0.90],
2583                [1.0, -0.75, 0.38, -1.42, 0.15],
2584                [1.0, -0.42, -1.15, 0.50, -1.08],
2585                [1.0, -0.10, 0.60, -0.15, 0.75],
2586                [1.0, 0.25, -0.28, 1.05, -0.48],
2587                [1.0, 0.52, 1.35, -0.92, 0.30],
2588                [1.0, 0.80, -0.70, 0.38, 1.20],
2589                [1.0, 1.08, 0.48, -0.60, -0.95],
2590                [1.0, 1.35, -0.55, 0.85, 0.42]
2591            ];
2592            // Penalty matrix: zero on intercept, SPD on remaining 4 columns.
2593            let s0 = array![
2594                [0.0, 0.0, 0.0, 0.0, 0.0],
2595                [0.0, 1.40, 0.15, 0.05, -0.10],
2596                [0.0, 0.15, 1.10, -0.20, 0.08],
2597                [0.0, 0.05, -0.20, 0.95, 0.12],
2598                [0.0, -0.10, 0.08, 0.12, 1.25]
2599            ];
2600            let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2601            // Design-moving direction: perturb covariate columns, leave
2602            // intercept untouched.
2603            let x_tau_design = array![
2604                [0.0, 1.2e-3, -0.8e-3, 0.5e-3, -1.5e-3],
2605                [0.0, -2.0e-3, 1.4e-3, -0.3e-3, 0.9e-3],
2606                [0.0, 0.6e-3, -1.1e-3, 1.8e-3, -0.4e-3],
2607                [0.0, -1.3e-3, 0.7e-3, -1.0e-3, 2.1e-3],
2608                [0.0, 0.9e-3, -0.5e-3, 0.2e-3, -0.8e-3],
2609                [0.0, -0.4e-3, 1.8e-3, -1.5e-3, 0.3e-3],
2610                [0.0, 1.5e-3, -1.3e-3, 0.8e-3, -1.1e-3],
2611                [0.0, -0.7e-3, 0.4e-3, -2.0e-3, 1.6e-3],
2612                [0.0, 2.2e-3, -0.9e-3, 1.3e-3, -0.6e-3],
2613                [0.0, -1.0e-3, 1.6e-3, -0.7e-3, 0.5e-3],
2614                [0.0, 0.3e-3, -2.1e-3, 1.1e-3, -1.8e-3],
2615                [0.0, -1.8e-3, 0.2e-3, -0.4e-3, 1.3e-3],
2616                [0.0, 1.1e-3, -1.5e-3, 2.0e-3, -0.2e-3],
2617                [0.0, -0.5e-3, 0.9e-3, -1.2e-3, 0.7e-3],
2618                [0.0, 1.7e-3, -0.3e-3, 0.6e-3, -2.0e-3],
2619                [0.0, -1.4e-3, 1.1e-3, -0.9e-3, 0.4e-3],
2620                [0.0, 0.8e-3, -1.7e-3, 1.5e-3, -0.1e-3],
2621                [0.0, -0.2e-3, 0.6e-3, -1.8e-3, 1.0e-3],
2622                [0.0, 1.4e-3, -0.4e-3, 0.3e-3, -1.3e-3],
2623                [0.0, -0.9e-3, 2.0e-3, -0.5e-3, 0.8e-3],
2624                [0.0, 0.5e-3, -1.0e-3, 1.6e-3, -0.7e-3],
2625                [0.0, -2.1e-3, 0.3e-3, -0.8e-3, 1.5e-3],
2626                [0.0, 0.7e-3, -1.8e-3, 0.9e-3, -0.3e-3],
2627                [0.0, -0.6e-3, 1.3e-3, -2.2e-3, 1.1e-3],
2628                [0.0, 1.9e-3, -0.7e-3, 0.4e-3, -0.9e-3],
2629                [0.0, -1.1e-3, 0.5e-3, -1.4e-3, 2.2e-3],
2630                [0.0, 0.4e-3, -1.6e-3, 1.2e-3, -0.5e-3],
2631                [0.0, -1.6e-3, 0.8e-3, -0.1e-3, 0.6e-3],
2632                [0.0, 1.3e-3, -2.2e-3, 0.7e-3, -1.4e-3],
2633                [0.0, -0.3e-3, 1.0e-3, -1.6e-3, 1.8e-3]
2634            ];
2635            // Penalty-only direction: non-zero S_τ, symmetric, zero on intercept.
2636            let s_tau_penalty = array![
2637                [0.0, 0.0, 0.0, 0.0, 0.0],
2638                [0.0, 0.30, 0.05, -0.02, 0.04],
2639                [0.0, 0.05, 0.22, 0.03, -0.01],
2640                [0.0, -0.02, 0.03, 0.18, 0.06],
2641                [0.0, 0.04, -0.01, 0.06, 0.26]
2642            ];
2643            Self {
2644                w: Array1::<f64>::ones(y.len()),
2645                y,
2646                x,
2647                s0,
2648                cfg,
2649                rho: array![0.0],
2650                x_tau_design,
2651                s_tau_penalty,
2652            }
2653        }
2654    }
2655
2656    impl LogitDesignMotionFixture for BinomialLogitDesignMotionFixture {
2657        fn y(&self) -> &Array1<f64> {
2658            &self.y
2659        }
2660        fn w(&self) -> &Array1<f64> {
2661            &self.w
2662        }
2663        fn x(&self) -> &Array2<f64> {
2664            &self.x
2665        }
2666        fn s0(&self) -> &Array2<f64> {
2667            &self.s0
2668        }
2669        fn cfg(&self) -> &RemlConfig {
2670            &self.cfg
2671        }
2672        fn rho(&self) -> &Array1<f64> {
2673            &self.rho
2674        }
2675    }
2676
2677    // ── n=30, p=5 binomial-logit design-motion gradient tests ────────
2678
2679    #[test]
2680    pub(crate) fn binomial_logit_n30_design_moving_gradient_matches_fd() {
2681        // Pure design-motion: X_τ ≠ 0, S_τ = 0.
2682        // The IFT correction is essential here: because the family is
2683        // binomial-logit, the working weights W(η) depend on β̂, so
2684        // when X moves with ψ, the implicit derivative dβ̂/dψ enters
2685        // the total Hessian drift.  Without hessian_derivative_correction
2686        // the analytic gradient would disagree with FD.
2687        let f = BinomialLogitDesignMotionFixture::new();
2688        let state = f.state();
2689        let s_tau = Array2::<f64>::zeros((5, 5));
2690        let hyper = DirectionalHyperParam::single_penalty(
2691            0,
2692            f.x_tau_design.clone(),
2693            s_tau.clone(),
2694            None,
2695            None,
2696        )
2697        .expect("design-moving hyper direction");
2698
2699        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2700            .expect("analytic directional gradient");
2701        let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2702
2703        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2704        assert!(
2705            v_rel < 1e-3,
2706            "Binomial-logit n=30 design-moving gradient mismatch: rel={v_rel:.3e}, \
2707             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2708        );
2709    }
2710
2711    #[test]
2712    pub(crate) fn binomial_logit_n30_penalty_only_gradient_matches_fd() {
2713        // Penalty-only direction: X_τ = 0, S_τ ≠ 0.
2714        // Serves as a baseline: the IFT correction should still be
2715        // present (since H depends on β̂ through W(η)), but the
2716        // explicit X_τ contribution is zero.
2717        let f = BinomialLogitDesignMotionFixture::new();
2718        let state = f.state();
2719        let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2720        let hyper = DirectionalHyperParam::single_penalty(
2721            0,
2722            x_tau.clone(),
2723            f.s_tau_penalty.clone(),
2724            None,
2725            None,
2726        )
2727        .expect("penalty-only hyper direction");
2728
2729        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2730            .expect("analytic directional gradient");
2731        let v_tau_fd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2732
2733        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2734        assert!(
2735            v_rel < 1e-3,
2736            "Binomial-logit n=30 penalty-only gradient mismatch: rel={v_rel:.3e}, \
2737             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2738        );
2739    }
2740
2741    #[test]
2742    pub(crate) fn binomial_logit_n30_joint_design_penalty_gradient_matches_fd() {
2743        // Joint direction: both X_τ ≠ 0 and S_τ ≠ 0 simultaneously.
2744        // This is the hardest case: the analytic gradient must correctly
2745        // combine the explicit penalty drift, the explicit design drift,
2746        // and the IFT Hessian-drift correction.
2747        let f = BinomialLogitDesignMotionFixture::new();
2748        let state = f.state();
2749        let hyper = DirectionalHyperParam::single_penalty(
2750            0,
2751            f.x_tau_design.clone(),
2752            f.s_tau_penalty.clone(),
2753            None,
2754            None,
2755        )
2756        .expect("joint design+penalty hyper direction");
2757
2758        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2759            .expect("analytic directional gradient");
2760        let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &f.s_tau_penalty);
2761
2762        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2763        assert!(
2764            v_rel < 1e-3,
2765            "Binomial-logit n=30 joint design+penalty gradient mismatch: rel={v_rel:.3e}, \
2766             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2767        );
2768    }
2769
2770    #[test]
2771    pub(crate) fn binomial_logit_n30_design_moving_hessian_matches_fd() {
2772        // Hessian-level validation with two τ-directions: one
2773        // penalty-only and one design-moving.  The ττ Hessian block is
2774        // the most sensitive test of the IFT correction because errors
2775        // in the correction accumulate quadratically in the trace term.
2776        let f = BinomialLogitDesignMotionFixture::new();
2777        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2778        let s_tau_0 = f.s_tau_penalty.clone();
2779        let x_tau_1 = f.x_tau_design.clone();
2780        let s_tau_1 = Array2::<f64>::zeros((5, 5));
2781
2782        let hyper_dirs = vec![
2783            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2784                .expect("penalty-only direction"),
2785            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2786                .expect("design-moving direction"),
2787        ];
2788
2789        let state = f.state();
2790        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2791        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2792        let (_, _, h_full) =
2793            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2794                .expect("joint cost+gradient+hessian");
2795        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2796
2797        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2798        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2799        let h_tt_fd = directional_tau_hessian_fd_reference(
2800            &f.y,
2801            &f.w,
2802            &f.x,
2803            &f.s0,
2804            &f.cfg,
2805            &f.rho,
2806            &hyper_dirs,
2807            &x_tau_mats,
2808            &s_tau_mats,
2809        );
2810
2811        let num = (&h_tt_analytic - &h_tt_fd)
2812            .iter()
2813            .map(|v| v * v)
2814            .sum::<f64>()
2815            .sqrt();
2816        let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2817        let rel = num / den;
2818        assert!(
2819            rel < 1e-4,
2820            "Binomial-logit n=30 tau-tau Hessian mismatch: rel={rel:.3e}, \
2821             analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2822        );
2823    }
2824
2825    #[test]
2826    pub(crate) fn binomial_logit_n30_nonzero_rho_design_moving_gradient_matches_fd() {
2827        // Validate at a non-trivial smoothing parameter ρ = log(λ) = 1.5,
2828        // so the penalty term λS is scaled up and the balance between
2829        // likelihood and penalty is different from ρ=0.
2830        let f = BinomialLogitDesignMotionFixture::new();
2831        let rho = array![1.5];
2832        let s_tau = Array2::<f64>::zeros((5, 5));
2833
2834        let state = f.state();
2835        let hyper = DirectionalHyperParam::single_penalty(
2836            0,
2837            f.x_tau_design.clone(),
2838            s_tau.clone(),
2839            None,
2840            None,
2841        )
2842        .expect("design-moving hyper direction");
2843
2844        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2845            .expect("analytic directional gradient");
2846
2847        // FD at the shifted ρ: perturb X, re-solve inner, evaluate cost.
2848        let h = 2e-5;
2849        let (state_plus, state_minus) = f.state_perturbed(&f.x_tau_design, &s_tau, h);
2850        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2851        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2852        let v_tau_fd = (v_plus - v_minus) / (2.0 * h);
2853
2854        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2855        assert!(
2856            v_rel < 1e-3,
2857            "Binomial-logit n=30 rho=1.5 design-moving gradient mismatch: rel={v_rel:.3e}, \
2858             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2859        );
2860    }
2861
2862    #[test]
2863    pub(crate) fn binomial_logit_n30_rank_deficient_hessian_matches_cost_fd() {
2864        // Regression lock for the `PenaltySubspaceTrace` pseudo-logdet
2865        // kernel installed by the rank-deficient LAML fix (see
2866        // `PenaltySubspaceTrace` and `intrinsic_hessian_pseudo_logdet_parts`;
2867        // since #901 the cost is the intrinsic `½ log|H_pen|₊` and the kernel
2868        // is the spectral `H_pen⁺`, exact for every drift direction).
2869        //
2870        // The sibling `binomial_logit_n30_design_moving_hessian_matches_fd`
2871        // passes pre- AND post-fix because its FD reference differentiates
2872        // the *analytic gradient* — any self-consistent (if wrong) gradient
2873        // kernel gives a self-consistent Hessian under re-differentiation,
2874        // so that test cannot distinguish full-space from projected traces.
2875        // It passed under the buggy kernel because the same leakage entered
2876        // both sides of the ratio and cancelled.
2877        //
2878        // Here we FD-differentiate `compute_cost` TWICE and compare against
2879        // the analytic Hessian.  Central second differences expose every
2880        // disagreement between `½ log|U_Sᵀ H U_S|_+` (used by the cost) and
2881        // `½ tr(G_ε(H) · Ḣ)` / `−½ tr(G_ε Ḣ_i G_ε Ḣ_j)` (the full-space
2882        // traces that the gradient and Hessian used before the projection
2883        // fix).  Under the buggy kernel the IFT correction
2884        // `D_β H[v] = X' diag(c ⊙ X v) X` leaks onto `null(S)` — X's
2885        // all-ones intercept column sits there — and that leakage enters
2886        // the analytic Hessian but not the cost's projected logdet.
2887        //
2888        // Direction mix chosen to maximise the null-space leakage pathway:
2889        //   τ_0 = penalty-only (X_τ = 0, S_τ ≠ 0)  → v_0 = H⁻¹(−S_τ β̂) is
2890        //         concentrated in range(S_+), but `D_β H[v_0]` has rows and
2891        //         columns on the intercept because `X[:,0] = 1_n`.
2892        //   τ_1 = design-moving (X_τ ≠ 0 on non-intercept columns, S_τ = 0)
2893        //         → `v_1` also picks up the intercept via `X'WX_τβ̂`, and
2894        //         the base drift `X_τᵀWX + XᵀWX_τ` straddles range(S_+) /
2895        //         null(S).
2896        // Both pure directions AND the mixed partial load the Schur correction,
2897        // so any of the three entries can catch a regression.
2898        let f = BinomialLogitDesignMotionFixture::new();
2899        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2900        let s_tau_0 = f.s_tau_penalty.clone();
2901        let x_tau_1 = f.x_tau_design.clone();
2902        let s_tau_1 = Array2::<f64>::zeros((5, 5));
2903
2904        let hyper_dirs = vec![
2905            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2906                .expect("penalty-only direction"),
2907            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2908                .expect("design-moving direction"),
2909        ];
2910
2911        // Analytic Hessian block.
2912        let state = f.state();
2913        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2914        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2915        let (_, _, h_full) =
2916            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2917                .expect("joint cost+gradient+hessian");
2918        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2919
2920        // Cost-level FD reference.  Central second differences give O(h²)
2921        // accuracy; the step is sized so the physical perturbation on X / S
2922        // stays near `1e-5` (same scale as the gradient tests).
2923        const TARGET_PHYSICAL_STEP: f64 = 1e-5;
2924        let x_tau_mats = [&x_tau_0, &x_tau_1];
2925        let s_tau_mats = [&s_tau_0, &s_tau_1];
2926        let steps: [f64; 2] = {
2927            let mut steps = [0.0; 2];
2928            for (j, step) in steps.iter_mut().enumerate() {
2929                let scale = x_tau_mats[j]
2930                    .iter()
2931                    .chain(s_tau_mats[j].iter())
2932                    .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2933                *step = if scale > 0.0 {
2934                    TARGET_PHYSICAL_STEP / scale
2935                } else {
2936                    TARGET_PHYSICAL_STEP
2937                };
2938            }
2939            steps
2940        };
2941
2942        // Evaluate `compute_cost` at `(a · τ_0, b · τ_1)` multipliers.
2943        let eval_cost = |a: f64, b: f64| -> f64 {
2944            let x_eval = &f.x
2945                + &x_tau_mats[0].mapv(|v| a * steps[0] * v)
2946                + &x_tau_mats[1].mapv(|v| b * steps[1] * v);
2947            let s_eval = &f.s0
2948                + &s_tau_mats[0].mapv(|v| a * steps[0] * v)
2949                + &s_tau_mats[1].mapv(|v| b * steps[1] * v);
2950            let st = build_logit_state(&f.y, &f.w, &x_eval, &s_eval, &f.cfg);
2951            st.compute_cost(&f.rho).expect("cost eval")
2952        };
2953
2954        let v_00 = eval_cost(0.0, 0.0);
2955        let v_p0 = eval_cost(1.0, 0.0);
2956        let v_m0 = eval_cost(-1.0, 0.0);
2957        let v_0p = eval_cost(0.0, 1.0);
2958        let v_0m = eval_cost(0.0, -1.0);
2959        let v_pp = eval_cost(1.0, 1.0);
2960        let v_pm = eval_cost(1.0, -1.0);
2961        let v_mp = eval_cost(-1.0, 1.0);
2962        let v_mm = eval_cost(-1.0, -1.0);
2963
2964        let h00_fd = (v_p0 - 2.0 * v_00 + v_m0) / (steps[0] * steps[0]);
2965        let h11_fd = (v_0p - 2.0 * v_00 + v_0m) / (steps[1] * steps[1]);
2966        let h01_fd = (v_pp - v_pm - v_mp + v_mm) / (4.0 * steps[0] * steps[1]);
2967
2968        let h_tt_fd = array![[h00_fd, h01_fd], [h01_fd, h11_fd]];
2969
2970        let num = (&h_tt_analytic - &h_tt_fd)
2971            .iter()
2972            .map(|v| v * v)
2973            .sum::<f64>()
2974            .sqrt();
2975        let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2976        let rel = num / den;
2977
2978        assert!(
2979            rel < 3e-3,
2980            "Binomial-logit n=30 rank-deficient Hessian vs cost-FD mismatch: rel={rel:.3e}, \
2981             analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2982        );
2983    }
2984}
2985
2986#[derive(Clone, Copy, Debug)]
2987pub(crate) enum RemlGeometry {
2988    DenseSpectral,
2989    SparseExactSpd,
2990}
2991
2992trait PenalizedGeometry {
2993    fn backend_kind(&self) -> GeometryBackendKind;
2994}
2995
2996#[derive(Clone)]
2997pub(crate) enum DerivativeMatrixStorage {
2998    Dense(Array2<f64>),
2999    Zero(ZeroDerivativeMatrix),
3000    Embedded(EmbeddedDerivativeMatrix),
3001    Implicit(ImplicitDerivativeOp),
3002    LatentCoord(LatentCoordDerivativeOp),
3003}
3004
3005/// Mechanical surface every `DerivativeMatrixStorage` variant must expose so
3006/// the `HyperDesignDerivative` / `HyperPenaltyDerivative` wrappers can dispatch
3007/// with a single per-call `storage_dispatch!`. Each backend owns its variant's
3008/// substantive math; the wrappers contain only one-line routing.
3009///
3010/// `design_*` variants treat the backend as an X-style operator (rows index
3011/// data, columns index coefficients); `penalty_*` variants treat the backend
3012/// as a square `p×p` penalty in the global coefficient frame. The Embedded
3013/// case is the only variant whose two views genuinely differ (local rows vs
3014/// total_dim square), which is why the two role-specific methods both live in
3015/// one trait rather than two parallel traits.
3016trait DerivativeStorageBackend {
3017    fn resident_byte_count(&self) -> usize;
3018    fn design_nrows(&self) -> usize;
3019    fn design_ncols(&self) -> usize;
3020    fn penalty_dim(&self) -> usize;
3021    fn uses_implicit_storage(&self) -> bool;
3022    fn any_nonzero(&self) -> bool;
3023    fn materialize(&self) -> Array2<f64>;
3024    fn implicit_first_axis_info(
3025        &self,
3026    ) -> Option<(
3027        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3028        usize,
3029    )>;
3030    fn implicit_axis_count_hint(&self) -> Option<usize>;
3031    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;
3032    fn design_transpose_mul_original(
3033        &self,
3034        v: &Array1<f64>,
3035    ) -> Result<Array1<f64>, EstimationError>;
3036    fn design_transformed(
3037        &self,
3038        qs: &Array2<f64>,
3039        free_basis_opt: Option<&Array2<f64>>,
3040    ) -> Result<Array2<f64>, EstimationError>;
3041    /// Default materialises through `design_transformed` then `.dot(u)`;
3042    /// implicit/latent-coordinate backends override with a direct-operator
3043    /// path that skips the dense materialisation.
3044    fn design_transformed_forward_mul(
3045        &self,
3046        qs: &Array2<f64>,
3047        free_basis_opt: Option<&Array2<f64>>,
3048        u: &Array1<f64>,
3049    ) -> Result<Array1<f64>, EstimationError> {
3050        Ok(self.design_transformed(qs, free_basis_opt)?.dot(u))
3051    }
3052    /// Default materialises through `design_transformed` then `.t().dot(v)`;
3053    /// implicit/latent-coordinate backends override with a direct path.
3054    fn design_transformed_transpose_mul(
3055        &self,
3056        qs: &Array2<f64>,
3057        free_basis_opt: Option<&Array2<f64>>,
3058        v: &Array1<f64>,
3059    ) -> Result<Array1<f64>, EstimationError> {
3060        Ok(self.design_transformed(qs, free_basis_opt)?.t().dot(v))
3061    }
3062    fn penalty_transformed(
3063        &self,
3064        qs: &Array2<f64>,
3065        free_basis_opt: Option<&Array2<f64>>,
3066    ) -> Result<Array2<f64>, EstimationError>;
3067    fn penalty_scaled_add_to(
3068        &self,
3069        target: &mut Array2<f64>,
3070        amp: f64,
3071    ) -> Result<(), EstimationError>;
3072}
3073
3074/// Fans `expr` over the four `DerivativeMatrixStorage` variants in one place
3075/// so every wrapper method is a single dispatch line — the compiler enforces
3076/// exhaustiveness here, so adding a new variant produces one hard error at
3077/// this site rather than a silent miss in any of the (currently 16) ladders.
3078macro_rules! storage_dispatch {
3079    ($scrutinee:expr, $backend:ident => $body:expr) => {
3080        match $scrutinee {
3081            DerivativeMatrixStorage::Dense($backend) => $body,
3082            DerivativeMatrixStorage::Zero($backend) => $body,
3083            DerivativeMatrixStorage::Embedded($backend) => $body,
3084            DerivativeMatrixStorage::Implicit($backend) => $body,
3085            DerivativeMatrixStorage::LatentCoord($backend) => $body,
3086        }
3087    };
3088}
3089
3090#[derive(Clone)]
3091pub(crate) struct ZeroDerivativeMatrix {
3092    rows: usize,
3093    cols: usize,
3094}
3095
3096impl ZeroDerivativeMatrix {
3097    pub(crate) fn new(rows: usize, cols: usize) -> Self {
3098        Self { rows, cols }
3099    }
3100}
3101
3102/// Which derivative level the implicit operator should compute.
3103#[derive(Clone, Copy, Debug)]
3104pub enum ImplicitDerivLevel {
3105    /// ∂X/∂ψ_d
3106    First(usize),
3107    /// ∂²X/∂ψ_d²
3108    SecondDiag(usize),
3109    /// ∂²X/∂ψ_d∂ψ_e
3110    SecondCross(usize, usize),
3111}
3112
3113/// Lazy implicit operator storage: delegates matvecs to the
3114/// `ImplicitDesignPsiDerivative` and materializes dense form only on demand.
3115#[derive(Clone)]
3116pub(crate) struct ImplicitDerivativeOp {
3117    pub(crate) operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3118    pub(crate) level: ImplicitDerivLevel,
3119    pub(crate) global_range: Range<usize>,
3120    pub(crate) total_dim: usize,
3121    /// Cached dense materialization (lazy, populated on first call to ops that need the full matrix).
3122    ///
3123    /// Rayon-safe: `materialize_local` calls `materialize_first` / `_second_diag`
3124    /// / `_second_cross` on the implicit basis-derivative operator, which for
3125    /// streaming bases dispatches `(0..nc).into_par_iter().for_each(...)`. A plain
3126    /// `std::sync::OnceLock` here would deadlock if `materialize_dense` were first
3127    /// called concurrently from inside another rayon par_iter — racing workers
3128    /// would park on the OnceLock's OS condvar, leaving the leader's nested
3129    /// par_iter without workers. `RayonSafeOnce` runs init lock-free.
3130    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
3131}
3132
3133#[derive(Clone)]
3134pub(crate) struct LatentCoordDerivativeOp {
3135    pub(crate) operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
3136    pub(crate) flat_axis: usize,
3137    pub(crate) global_range: Range<usize>,
3138    pub(crate) total_dim: usize,
3139    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
3140}
3141
3142impl LatentCoordDerivativeOp {
3143    pub(crate) fn materialize_local(&self) -> Array2<f64> {
3144        self.operator.materialize_axis(self.flat_axis).expect(
3145            "radial scalar evaluation failed during latent-coordinate derivative materialization",
3146        )
3147    }
3148
3149    pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3150        self.cached_dense.get_or_compute(|| {
3151            let local = self.materialize_local();
3152            let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3153            out.slice_mut(s![.., self.global_range.clone()])
3154                .assign(&local);
3155            out
3156        })
3157    }
3158
3159    pub(crate) fn nrows(&self) -> usize {
3160        self.operator.n_data()
3161    }
3162
3163    pub(crate) fn ncols(&self) -> usize {
3164        self.total_dim
3165    }
3166
3167    pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3168        let local = self
3169            .operator
3170            .transpose_mul_axis(self.flat_axis, &v.view())
3171            .expect(
3172                "radial scalar evaluation failed during latent-coordinate derivative transpose_mul",
3173            );
3174        let mut out = Array1::<f64>::zeros(self.total_dim);
3175        out.slice_mut(s![self.global_range.clone()]).assign(&local);
3176        out
3177    }
3178
3179    pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3180        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3181        self.operator
3182            .forward_mul_axis(self.flat_axis, &u_local.view())
3183            .expect(
3184                "radial scalar evaluation failed during latent-coordinate derivative forward_mul",
3185            )
3186    }
3187}
3188
3189impl ImplicitDerivativeOp {
3190    pub(crate) fn materialize_local(&self) -> Array2<f64> {
3191        match self.level {
3192            ImplicitDerivLevel::First(axis) => self.operator.materialize_first(axis).expect(
3193                "radial scalar evaluation failed during implicit derivative materialization",
3194            ),
3195            ImplicitDerivLevel::SecondDiag(axis) => {
3196                self.operator.materialize_second_diag(axis).expect(
3197                    "radial scalar evaluation failed during implicit derivative materialization",
3198                )
3199            }
3200            ImplicitDerivLevel::SecondCross(d, e) => {
3201                self.operator.materialize_second_cross(d, e).expect(
3202                    "radial scalar evaluation failed during implicit derivative materialization",
3203                )
3204            }
3205        }
3206    }
3207
3208    pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3209        self.cached_dense.get_or_compute(|| {
3210            let local = self.materialize_local();
3211            let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3212            out.slice_mut(s![.., self.global_range.clone()])
3213                .assign(&local);
3214            out
3215        })
3216    }
3217
3218    pub(crate) fn nrows(&self) -> usize {
3219        self.operator.n_data()
3220    }
3221
3222    pub(crate) fn ncols(&self) -> usize {
3223        self.total_dim
3224    }
3225
3226    pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3227        let local = match self.level {
3228            ImplicitDerivLevel::First(axis) => self
3229                .operator
3230                .transpose_mul(axis, &v.view())
3231                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3232            ImplicitDerivLevel::SecondDiag(axis) => self
3233                .operator
3234                .transpose_mul_second_diag(axis, &v.view())
3235                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3236            ImplicitDerivLevel::SecondCross(d, e) => self
3237                .operator
3238                .transpose_mul_second_cross(d, e, &v.view())
3239                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3240        };
3241        let mut out = Array1::<f64>::zeros(self.total_dim);
3242        out.slice_mut(s![self.global_range.clone()]).assign(&local);
3243        out
3244    }
3245
3246    pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3247        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3248        match self.level {
3249            ImplicitDerivLevel::First(axis) => self
3250                .operator
3251                .forward_mul(axis, &u_local.view())
3252                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3253            ImplicitDerivLevel::SecondDiag(axis) => self
3254                .operator
3255                .forward_mul_second_diag(axis, &u_local.view())
3256                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3257            ImplicitDerivLevel::SecondCross(d, e) => self
3258                .operator
3259                .forward_mul_second_cross(d, e, &u_local.view())
3260                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3261        }
3262    }
3263}
3264
3265#[derive(Clone)]
3266pub(crate) struct EmbeddedDerivativeMatrix {
3267    pub(crate) local: Array2<f64>,
3268    pub(crate) global_range: Range<usize>,
3269    pub(crate) total_dim: usize,
3270}
3271
3272impl EmbeddedDerivativeMatrix {
3273    pub(crate) fn new(local: Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
3274        Self {
3275            local,
3276            global_range,
3277            total_dim,
3278        }
3279    }
3280}
3281
3282impl DerivativeStorageBackend for Array2<f64> {
3283    fn resident_byte_count(&self) -> usize {
3284        self.len().saturating_mul(std::mem::size_of::<f64>())
3285    }
3286    fn design_nrows(&self) -> usize {
3287        Array2::nrows(self)
3288    }
3289    fn design_ncols(&self) -> usize {
3290        Array2::ncols(self)
3291    }
3292    fn penalty_dim(&self) -> usize {
3293        Array2::nrows(self)
3294    }
3295    fn uses_implicit_storage(&self) -> bool {
3296        false
3297    }
3298    fn any_nonzero(&self) -> bool {
3299        self.iter().any(|v| *v != 0.0)
3300    }
3301    fn materialize(&self) -> Array2<f64> {
3302        self.clone()
3303    }
3304    fn implicit_first_axis_info(
3305        &self,
3306    ) -> Option<(
3307        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3308        usize,
3309    )> {
3310        None
3311    }
3312    fn implicit_axis_count_hint(&self) -> Option<usize> {
3313        None
3314    }
3315
3316    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3317        if Array2::ncols(self) != u.len() {
3318            crate::bail_invalid_estim!(
3319                "dense hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3320                Array2::nrows(self),
3321                Array2::ncols(self),
3322                u.len()
3323            );
3324        }
3325        Ok(self.dot(u))
3326    }
3327
3328    fn design_transpose_mul_original(
3329        &self,
3330        v: &Array1<f64>,
3331    ) -> Result<Array1<f64>, EstimationError> {
3332        if Array2::nrows(self) != v.len() {
3333            crate::bail_invalid_estim!(
3334                "dense hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3335                Array2::nrows(self),
3336                Array2::ncols(self),
3337                v.len()
3338            );
3339        }
3340        Ok(self.t().dot(v))
3341    }
3342
3343    fn design_transformed(
3344        &self,
3345        qs: &Array2<f64>,
3346        free_basis_opt: Option<&Array2<f64>>,
3347    ) -> Result<Array2<f64>, EstimationError> {
3348        Ok(gam_linalg::matrix::DenseRightProductView::new(self)
3349            .with_factor(qs)
3350            .with_optional_factor(free_basis_opt)
3351            .materialize())
3352    }
3353
3354    fn penalty_transformed(
3355        &self,
3356        qs: &Array2<f64>,
3357        free_basis_opt: Option<&Array2<f64>>,
3358    ) -> Result<Array2<f64>, EstimationError> {
3359        let mut transformed = qs.t().dot(self).dot(qs);
3360        if let Some(z) = free_basis_opt {
3361            transformed = z.t().dot(&transformed).dot(z);
3362        }
3363        Ok(transformed)
3364    }
3365
3366    fn penalty_scaled_add_to(
3367        &self,
3368        target: &mut Array2<f64>,
3369        amp: f64,
3370    ) -> Result<(), EstimationError> {
3371        if target.raw_dim() != self.raw_dim() {
3372            crate::bail_invalid_estim!(
3373                "dense hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3374                target.nrows(),
3375                target.ncols(),
3376                Array2::nrows(self),
3377                Array2::ncols(self)
3378            );
3379        }
3380        target.scaled_add(amp, self);
3381        Ok(())
3382    }
3383}
3384
3385impl DerivativeStorageBackend for ZeroDerivativeMatrix {
3386    fn resident_byte_count(&self) -> usize {
3387        0
3388    }
3389    fn design_nrows(&self) -> usize {
3390        self.rows
3391    }
3392    fn design_ncols(&self) -> usize {
3393        self.cols
3394    }
3395    fn penalty_dim(&self) -> usize {
3396        self.cols
3397    }
3398    fn uses_implicit_storage(&self) -> bool {
3399        false
3400    }
3401    fn any_nonzero(&self) -> bool {
3402        false
3403    }
3404    fn materialize(&self) -> Array2<f64> {
3405        Array2::<f64>::zeros((self.rows, self.cols))
3406    }
3407    fn implicit_first_axis_info(
3408        &self,
3409    ) -> Option<(
3410        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3411        usize,
3412    )> {
3413        None
3414    }
3415    fn implicit_axis_count_hint(&self) -> Option<usize> {
3416        None
3417    }
3418
3419    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3420        if self.cols != u.len() {
3421            crate::bail_invalid_estim!(
3422                "zero hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3423                self.rows,
3424                self.cols,
3425                u.len()
3426            );
3427        }
3428        Ok(Array1::<f64>::zeros(self.rows))
3429    }
3430
3431    fn design_transpose_mul_original(
3432        &self,
3433        v: &Array1<f64>,
3434    ) -> Result<Array1<f64>, EstimationError> {
3435        if self.rows != v.len() {
3436            crate::bail_invalid_estim!(
3437                "zero hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3438                self.rows,
3439                self.cols,
3440                v.len()
3441            );
3442        }
3443        Ok(Array1::<f64>::zeros(self.cols))
3444    }
3445
3446    fn design_transformed(
3447        &self,
3448        qs: &Array2<f64>,
3449        free_basis_opt: Option<&Array2<f64>>,
3450    ) -> Result<Array2<f64>, EstimationError> {
3451        if self.cols != qs.nrows() {
3452            crate::bail_invalid_estim!(
3453                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3454                self.cols,
3455                qs.nrows()
3456            );
3457        }
3458        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3459        Ok(Array2::<f64>::zeros((self.rows, cols)))
3460    }
3461
3462    fn design_transformed_forward_mul(
3463        &self,
3464        qs: &Array2<f64>,
3465        free_basis_opt: Option<&Array2<f64>>,
3466        u: &Array1<f64>,
3467    ) -> Result<Array1<f64>, EstimationError> {
3468        if self.cols != qs.nrows() {
3469            crate::bail_invalid_estim!(
3470                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3471                self.cols,
3472                qs.nrows()
3473            );
3474        }
3475        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3476        if u.len() != cols {
3477            crate::bail_invalid_estim!(
3478                "zero design derivative transformed forward width mismatch: expected {}, vector={}",
3479                cols,
3480                u.len()
3481            );
3482        }
3483        Ok(Array1::<f64>::zeros(self.rows))
3484    }
3485
3486    fn design_transformed_transpose_mul(
3487        &self,
3488        qs: &Array2<f64>,
3489        free_basis_opt: Option<&Array2<f64>>,
3490        v: &Array1<f64>,
3491    ) -> Result<Array1<f64>, EstimationError> {
3492        if self.rows != v.len() {
3493            crate::bail_invalid_estim!(
3494                "zero design derivative transpose height mismatch: matrix rows={}, vector={}",
3495                self.rows,
3496                v.len()
3497            );
3498        }
3499        if self.cols != qs.nrows() {
3500            crate::bail_invalid_estim!(
3501                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3502                self.cols,
3503                qs.nrows()
3504            );
3505        }
3506        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3507        Ok(Array1::<f64>::zeros(cols))
3508    }
3509
3510    fn penalty_transformed(
3511        &self,
3512        qs: &Array2<f64>,
3513        free_basis_opt: Option<&Array2<f64>>,
3514    ) -> Result<Array2<f64>, EstimationError> {
3515        if self.cols != qs.nrows() {
3516            crate::bail_invalid_estim!(
3517                "zero penalty derivative width mismatch: total_dim={}, qs rows={}",
3518                self.cols,
3519                qs.nrows()
3520            );
3521        }
3522        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3523        Ok(Array2::<f64>::zeros((cols, cols)))
3524    }
3525
3526    fn penalty_scaled_add_to(
3527        &self,
3528        target: &mut Array2<f64>,
3529        amp: f64,
3530    ) -> Result<(), EstimationError> {
3531        // Zero penalty derivative: `amp · 0 = 0`, so `amp` scales nothing and
3532        // `target` is left unchanged. Validate it is finite so a bad scale
3533        // surfaces here rather than silently no-op'ing on a NaN/inf amplitude.
3534        if !amp.is_finite() {
3535            crate::bail_invalid_estim!(
3536                "zero hyper penalty derivative received non-finite amp={amp}"
3537            );
3538        }
3539        if target.nrows() != self.cols || target.ncols() != self.cols {
3540            crate::bail_invalid_estim!(
3541                "zero hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3542                target.nrows(),
3543                target.ncols(),
3544                self.cols,
3545                self.cols
3546            );
3547        }
3548        Ok(())
3549    }
3550}
3551
3552impl DerivativeStorageBackend for EmbeddedDerivativeMatrix {
3553    fn resident_byte_count(&self) -> usize {
3554        self.local.len().saturating_mul(std::mem::size_of::<f64>())
3555    }
3556    fn design_nrows(&self) -> usize {
3557        self.local.nrows()
3558    }
3559    fn design_ncols(&self) -> usize {
3560        self.total_dim
3561    }
3562    fn penalty_dim(&self) -> usize {
3563        self.total_dim
3564    }
3565    fn uses_implicit_storage(&self) -> bool {
3566        false
3567    }
3568    fn any_nonzero(&self) -> bool {
3569        self.local.iter().any(|v| *v != 0.0)
3570    }
3571    fn materialize(&self) -> Array2<f64> {
3572        let mut dense = Array2::<f64>::zeros((self.local.nrows(), self.total_dim));
3573        dense
3574            .slice_mut(s![.., self.global_range.clone()])
3575            .assign(&self.local);
3576        dense
3577    }
3578    fn implicit_first_axis_info(
3579        &self,
3580    ) -> Option<(
3581        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3582        usize,
3583    )> {
3584        None
3585    }
3586    fn implicit_axis_count_hint(&self) -> Option<usize> {
3587        None
3588    }
3589
3590    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3591        if self.total_dim != u.len() {
3592            crate::bail_invalid_estim!(
3593                "embedded hyper design derivative forward_mul_original width mismatch: total_dim={}, vector={}",
3594                self.total_dim,
3595                u.len()
3596            );
3597        }
3598        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3599        Ok(self.local.dot(&u_local))
3600    }
3601
3602    fn design_transpose_mul_original(
3603        &self,
3604        v: &Array1<f64>,
3605    ) -> Result<Array1<f64>, EstimationError> {
3606        if self.local.nrows() != v.len() {
3607            crate::bail_invalid_estim!(
3608                "embedded hyper design derivative transpose_mul_original height mismatch: local_rows={}, vector={}",
3609                self.local.nrows(),
3610                v.len()
3611            );
3612        }
3613        let mut out = Array1::<f64>::zeros(self.total_dim);
3614        let pulled = self.local.t().dot(v);
3615        out.slice_mut(s![self.global_range.clone()]).assign(&pulled);
3616        Ok(out)
3617    }
3618
3619    fn design_transformed(
3620        &self,
3621        qs: &Array2<f64>,
3622        free_basis_opt: Option<&Array2<f64>>,
3623    ) -> Result<Array2<f64>, EstimationError> {
3624        if self.total_dim != qs.nrows() {
3625            crate::bail_invalid_estim!(
3626                "embedded design derivative width mismatch: total_cols={}, qs rows={}",
3627                self.total_dim,
3628                qs.nrows()
3629            );
3630        }
3631        let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3632        let mut transformed = self.local.dot(&qs_local);
3633        if let Some(z) = free_basis_opt {
3634            transformed = transformed.dot(z);
3635        }
3636        Ok(transformed)
3637    }
3638
3639    fn penalty_transformed(
3640        &self,
3641        qs: &Array2<f64>,
3642        free_basis_opt: Option<&Array2<f64>>,
3643    ) -> Result<Array2<f64>, EstimationError> {
3644        if self.total_dim != qs.nrows() {
3645            crate::bail_invalid_estim!(
3646                "embedded penalty derivative width mismatch: total_dim={}, qs rows={}",
3647                self.total_dim,
3648                qs.nrows()
3649            );
3650        }
3651        let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3652        let mut transformed = qs_local.t().dot(&self.local).dot(&qs_local);
3653        if let Some(z) = free_basis_opt {
3654            transformed = z.t().dot(&transformed).dot(z);
3655        }
3656        Ok(transformed)
3657    }
3658
3659    fn penalty_scaled_add_to(
3660        &self,
3661        target: &mut Array2<f64>,
3662        amp: f64,
3663    ) -> Result<(), EstimationError> {
3664        if target.nrows() != self.total_dim || target.ncols() != self.total_dim {
3665            crate::bail_invalid_estim!(
3666                "embedded hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3667                target.nrows(),
3668                target.ncols(),
3669                self.total_dim,
3670                self.total_dim
3671            );
3672        }
3673        target
3674            .slice_mut(s![self.global_range.clone(), self.global_range.clone()])
3675            .scaled_add(amp, &self.local);
3676        Ok(())
3677    }
3678}
3679
3680impl DerivativeStorageBackend for ImplicitDerivativeOp {
3681    fn resident_byte_count(&self) -> usize {
3682        0
3683    }
3684    fn design_nrows(&self) -> usize {
3685        self.nrows()
3686    }
3687    fn design_ncols(&self) -> usize {
3688        self.ncols()
3689    }
3690    fn penalty_dim(&self) -> usize {
3691        self.nrows()
3692    }
3693    fn uses_implicit_storage(&self) -> bool {
3694        true
3695    }
3696    fn any_nonzero(&self) -> bool {
3697        true
3698    }
3699    fn materialize(&self) -> Array2<f64> {
3700        self.materialize_dense().clone()
3701    }
3702    fn implicit_first_axis_info(
3703        &self,
3704    ) -> Option<(
3705        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3706        usize,
3707    )> {
3708        match self.level {
3709            ImplicitDerivLevel::First(axis) => Some((self.operator.clone(), axis)),
3710            _ => None,
3711        }
3712    }
3713    fn implicit_axis_count_hint(&self) -> Option<usize> {
3714        Some(self.operator.n_axes())
3715    }
3716
3717    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3718        if self.ncols() != u.len() {
3719            crate::bail_invalid_estim!(
3720                "implicit hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3721                self.ncols(),
3722                u.len()
3723            );
3724        }
3725        Ok(self.forward_mul(u))
3726    }
3727
3728    fn design_transpose_mul_original(
3729        &self,
3730        v: &Array1<f64>,
3731    ) -> Result<Array1<f64>, EstimationError> {
3732        if self.nrows() != v.len() {
3733            crate::bail_invalid_estim!(
3734                "implicit hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3735                self.nrows(),
3736                v.len()
3737            );
3738        }
3739        Ok(self.transpose_mul(v))
3740    }
3741
3742    fn design_transformed(
3743        &self,
3744        qs: &Array2<f64>,
3745        free_basis_opt: Option<&Array2<f64>>,
3746    ) -> Result<Array2<f64>, EstimationError> {
3747        let dense = self.materialize_dense();
3748        Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3749            .with_factor(qs)
3750            .with_optional_factor(free_basis_opt)
3751            .materialize())
3752    }
3753
3754    fn design_transformed_forward_mul(
3755        &self,
3756        qs: &Array2<f64>,
3757        free_basis_opt: Option<&Array2<f64>>,
3758        u: &Array1<f64>,
3759    ) -> Result<Array1<f64>, EstimationError> {
3760        let mut right = if let Some(z) = free_basis_opt {
3761            z.dot(u)
3762        } else {
3763            u.clone()
3764        };
3765        right = qs.dot(&right);
3766        Ok(self.forward_mul(&right))
3767    }
3768
3769    fn design_transformed_transpose_mul(
3770        &self,
3771        qs: &Array2<f64>,
3772        free_basis_opt: Option<&Array2<f64>>,
3773        v: &Array1<f64>,
3774    ) -> Result<Array1<f64>, EstimationError> {
3775        let mut pulled = qs.t().dot(&self.transpose_mul(v));
3776        if let Some(z) = free_basis_opt {
3777            pulled = z.t().dot(&pulled);
3778        }
3779        Ok(pulled)
3780    }
3781
3782    fn penalty_transformed(
3783        &self,
3784        qs: &Array2<f64>,
3785        free_basis_opt: Option<&Array2<f64>>,
3786    ) -> Result<Array2<f64>, EstimationError> {
3787        let dense = self.materialize_dense();
3788        let mut transformed = qs.t().dot(dense).dot(qs);
3789        if let Some(z) = free_basis_opt {
3790            transformed = z.t().dot(&transformed).dot(z);
3791        }
3792        Ok(transformed)
3793    }
3794
3795    fn penalty_scaled_add_to(
3796        &self,
3797        target: &mut Array2<f64>,
3798        amp: f64,
3799    ) -> Result<(), EstimationError> {
3800        let dense = self.materialize_dense();
3801        if target.raw_dim() != dense.raw_dim() {
3802            crate::bail_invalid_estim!(
3803                "implicit hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3804                target.nrows(),
3805                target.ncols(),
3806                dense.nrows(),
3807                dense.ncols()
3808            );
3809        }
3810        target.scaled_add(amp, dense);
3811        Ok(())
3812    }
3813}
3814
3815impl DerivativeStorageBackend for LatentCoordDerivativeOp {
3816    fn resident_byte_count(&self) -> usize {
3817        0
3818    }
3819    fn design_nrows(&self) -> usize {
3820        self.nrows()
3821    }
3822    fn design_ncols(&self) -> usize {
3823        self.ncols()
3824    }
3825    fn penalty_dim(&self) -> usize {
3826        self.nrows()
3827    }
3828    fn uses_implicit_storage(&self) -> bool {
3829        true
3830    }
3831    fn any_nonzero(&self) -> bool {
3832        true
3833    }
3834    fn materialize(&self) -> Array2<f64> {
3835        self.materialize_dense().clone()
3836    }
3837    fn implicit_first_axis_info(
3838        &self,
3839    ) -> Option<(
3840        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3841        usize,
3842    )> {
3843        None
3844    }
3845    fn implicit_axis_count_hint(&self) -> Option<usize> {
3846        Some(self.operator.n_axes())
3847    }
3848
3849    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3850        if self.ncols() != u.len() {
3851            crate::bail_invalid_estim!(
3852                "latent-coordinate hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3853                self.ncols(),
3854                u.len()
3855            );
3856        }
3857        Ok(self.forward_mul(u))
3858    }
3859
3860    fn design_transpose_mul_original(
3861        &self,
3862        v: &Array1<f64>,
3863    ) -> Result<Array1<f64>, EstimationError> {
3864        if self.nrows() != v.len() {
3865            crate::bail_invalid_estim!(
3866                "latent-coordinate hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3867                self.nrows(),
3868                v.len()
3869            );
3870        }
3871        Ok(self.transpose_mul(v))
3872    }
3873
3874    fn design_transformed(
3875        &self,
3876        qs: &Array2<f64>,
3877        free_basis_opt: Option<&Array2<f64>>,
3878    ) -> Result<Array2<f64>, EstimationError> {
3879        let dense = self.materialize_dense();
3880        Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3881            .with_factor(qs)
3882            .with_optional_factor(free_basis_opt)
3883            .materialize())
3884    }
3885
3886    fn design_transformed_forward_mul(
3887        &self,
3888        qs: &Array2<f64>,
3889        free_basis_opt: Option<&Array2<f64>>,
3890        u: &Array1<f64>,
3891    ) -> Result<Array1<f64>, EstimationError> {
3892        let mut right = if let Some(z) = free_basis_opt {
3893            z.dot(u)
3894        } else {
3895            u.clone()
3896        };
3897        right = qs.dot(&right);
3898        Ok(self.forward_mul(&right))
3899    }
3900
3901    fn design_transformed_transpose_mul(
3902        &self,
3903        qs: &Array2<f64>,
3904        free_basis_opt: Option<&Array2<f64>>,
3905        v: &Array1<f64>,
3906    ) -> Result<Array1<f64>, EstimationError> {
3907        let mut pulled = qs.t().dot(&self.transpose_mul(v));
3908        if let Some(z) = free_basis_opt {
3909            pulled = z.t().dot(&pulled);
3910        }
3911        Ok(pulled)
3912    }
3913
3914    fn penalty_transformed(
3915        &self,
3916        qs: &Array2<f64>,
3917        free_basis_opt: Option<&Array2<f64>>,
3918    ) -> Result<Array2<f64>, EstimationError> {
3919        let dense = self.materialize_dense();
3920        let mut transformed = qs.t().dot(dense).dot(qs);
3921        if let Some(z) = free_basis_opt {
3922            transformed = z.t().dot(&transformed).dot(z);
3923        }
3924        Ok(transformed)
3925    }
3926
3927    fn penalty_scaled_add_to(
3928        &self,
3929        target: &mut Array2<f64>,
3930        amp: f64,
3931    ) -> Result<(), EstimationError> {
3932        let dense = self.materialize_dense();
3933        if target.raw_dim() != dense.raw_dim() {
3934            crate::bail_invalid_estim!(
3935                "latent-coordinate hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3936                target.nrows(),
3937                target.ncols(),
3938                dense.nrows(),
3939                dense.ncols()
3940            );
3941        }
3942        target.scaled_add(amp, dense);
3943        Ok(())
3944    }
3945}
3946
/// Derivative of the design matrix along one hyperparameter direction.
///
/// The data lives in a `DerivativeMatrixStorage` backend. Visible
/// constructors produce zero, embedded-local-block, implicit-operator,
/// latent-coordinate, or dense (via `From<Array2<f64>>`) storage.
#[derive(Clone)]
pub struct HyperDesignDerivative {
    // Backend holding the derivative; all queries dispatch through it.
    pub(crate) storage: DerivativeMatrixStorage,
}
3951
3952impl HyperDesignDerivative {
3953    pub fn zero(nrows: usize, ncols: usize) -> Self {
3954        Self {
3955            storage: DerivativeMatrixStorage::Zero(ZeroDerivativeMatrix::new(nrows, ncols)),
3956        }
3957    }
3958
3959    pub fn from_embedded(
3960        local: Array2<f64>,
3961        global_range: Range<usize>,
3962        total_cols: usize,
3963    ) -> Self {
3964        Self {
3965            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3966                local,
3967                global_range,
3968                total_cols,
3969            )),
3970        }
3971    }
3972
3973    pub fn from_implicit(
3974        operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3975        level: ImplicitDerivLevel,
3976        global_range: Range<usize>,
3977        total_cols: usize,
3978    ) -> Self {
3979        Self {
3980            storage: DerivativeMatrixStorage::Implicit(ImplicitDerivativeOp {
3981                operator,
3982                level,
3983                global_range,
3984                total_dim: total_cols,
3985                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3986            }),
3987        }
3988    }
3989
3990    pub fn from_latent_coord(
3991        operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
3992        flat_axis: usize,
3993        global_range: Range<usize>,
3994        total_cols: usize,
3995    ) -> Self {
3996        Self {
3997            storage: DerivativeMatrixStorage::LatentCoord(LatentCoordDerivativeOp {
3998                operator,
3999                flat_axis,
4000                global_range,
4001                total_dim: total_cols,
4002                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
4003            }),
4004        }
4005    }
4006
4007    pub(crate) fn resident_byte_count(&self) -> usize {
4008        storage_dispatch!(&self.storage, b => b.resident_byte_count())
4009    }
4010
4011    pub(crate) fn nrows(&self) -> usize {
4012        storage_dispatch!(&self.storage, b => b.design_nrows())
4013    }
4014
4015    pub(crate) fn ncols(&self) -> usize {
4016        storage_dispatch!(&self.storage, b => b.design_ncols())
4017    }
4018
4019    pub(crate) fn uses_implicit_storage(&self) -> bool {
4020        storage_dispatch!(&self.storage, b => b.uses_implicit_storage())
4021    }
4022
4023    pub(crate) fn materialize(&self) -> Array2<f64> {
4024        storage_dispatch!(&self.storage, b => b.materialize())
4025    }
4026
4027    pub(crate) fn any_nonzero(&self) -> bool {
4028        storage_dispatch!(&self.storage, b => b.any_nonzero())
4029    }
4030
4031    pub(crate) fn forward_mul_original(
4032        &self,
4033        u: &Array1<f64>,
4034    ) -> Result<Array1<f64>, EstimationError> {
4035        storage_dispatch!(&self.storage, b => b.design_forward_mul_original(u))
4036    }
4037
4038    pub(crate) fn transpose_mul_original(
4039        &self,
4040        v: &Array1<f64>,
4041    ) -> Result<Array1<f64>, EstimationError> {
4042        storage_dispatch!(&self.storage, b => b.design_transpose_mul_original(v))
4043    }
4044
4045    pub(crate) fn transformed(
4046        &self,
4047        qs: &Array2<f64>,
4048        free_basis_opt: Option<&Array2<f64>>,
4049    ) -> Result<Array2<f64>, EstimationError> {
4050        storage_dispatch!(&self.storage, b => b.design_transformed(qs, free_basis_opt))
4051    }
4052
4053    pub(crate) fn transformed_forward_mul(
4054        &self,
4055        qs: &Array2<f64>,
4056        free_basis_opt: Option<&Array2<f64>>,
4057        u: &Array1<f64>,
4058    ) -> Result<Array1<f64>, EstimationError> {
4059        storage_dispatch!(&self.storage, b => b.design_transformed_forward_mul(qs, free_basis_opt, u))
4060    }
4061
4062    pub(crate) fn transformed_transpose_mul(
4063        &self,
4064        qs: &Array2<f64>,
4065        free_basis_opt: Option<&Array2<f64>>,
4066        v: &Array1<f64>,
4067    ) -> Result<Array1<f64>, EstimationError> {
4068        storage_dispatch!(&self.storage, b => b.design_transformed_transpose_mul(qs, free_basis_opt, v))
4069    }
4070
4071    /// If this derivative uses implicit storage at the first-derivative level,
4072    /// return the shared implicit operator and the axis index.
4073    ///
4074    /// Returns `None` for dense/embedded storage or for second-derivative levels.
4075    pub(crate) fn implicit_first_axis_info(
4076        &self,
4077    ) -> Option<(
4078        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4079        usize,
4080    )> {
4081        storage_dispatch!(&self.storage, b => b.implicit_first_axis_info())
4082    }
4083
4084    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4085        storage_dispatch!(&self.storage, b => b.implicit_axis_count_hint())
4086    }
4087}
4088
4089impl From<Array2<f64>> for HyperDesignDerivative {
4090    fn from(value: Array2<f64>) -> Self {
4091        Self {
4092            storage: DerivativeMatrixStorage::Dense(value),
4093        }
4094    }
4095}
4096
/// Derivative of a penalty matrix along one hyperparameter direction.
///
/// Visible constructors produce embedded-local-block storage
/// (`from_embedded`) or dense storage (`From<Array2<f64>>`). The accessors
/// treat the matrix as square: `ncols()` returns `nrows()`.
#[derive(Clone)]
pub struct HyperPenaltyDerivative {
    // Backend holding the derivative; all queries dispatch through it.
    pub(crate) storage: DerivativeMatrixStorage,
}
4101
4102impl HyperPenaltyDerivative {
4103    pub fn from_embedded(
4104        local: Array2<f64>,
4105        global_range: Range<usize>,
4106        total_dim: usize,
4107    ) -> Self {
4108        Self {
4109            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
4110                local,
4111                global_range,
4112                total_dim,
4113            )),
4114        }
4115    }
4116
4117    pub(crate) fn resident_byte_count(&self) -> usize {
4118        storage_dispatch!(&self.storage, b => b.resident_byte_count())
4119    }
4120
4121    pub(crate) fn nrows(&self) -> usize {
4122        storage_dispatch!(&self.storage, b => b.penalty_dim())
4123    }
4124
4125    pub(crate) fn ncols(&self) -> usize {
4126        self.nrows()
4127    }
4128
4129    pub(crate) fn scaled_materialize(&self, amp: f64) -> Array2<f64> {
4130        let mut out = Array2::<f64>::zeros((self.nrows(), self.ncols()));
4131        self.scaled_add_to(&mut out, amp)
4132            .expect("scaled materialize uses matching target shape");
4133        out
4134    }
4135
4136    pub(crate) fn transformed(
4137        &self,
4138        qs: &Array2<f64>,
4139        free_basis_opt: Option<&Array2<f64>>,
4140    ) -> Result<Array2<f64>, EstimationError> {
4141        storage_dispatch!(&self.storage, b => b.penalty_transformed(qs, free_basis_opt))
4142    }
4143
4144    pub(crate) fn scaled_add_to(
4145        &self,
4146        target: &mut Array2<f64>,
4147        amp: f64,
4148    ) -> Result<(), EstimationError> {
4149        storage_dispatch!(&self.storage, b => b.penalty_scaled_add_to(target, amp))
4150    }
4151}
4152
4153impl From<Array2<f64>> for HyperPenaltyDerivative {
4154    fn from(value: Array2<f64>) -> Self {
4155        Self {
4156            storage: DerivativeMatrixStorage::Dense(value),
4157        }
4158    }
4159}
4160
/// One base-penalty term of a directional penalty derivative.
///
/// `DirectionalHyperParam::penalty_total_at` scales `matrix` by
/// `exp(rho[penalty_index])` when summing components.
#[derive(Clone)]
pub struct PenaltyDerivativeComponent {
    // Index of the base penalty this term belongs to (also indexes `rho`).
    pub penalty_index: usize,
    // The derivative matrix contributed for that base penalty.
    pub matrix: HyperPenaltyDerivative,
}
4166
/// Derivative data for one hyperparameter direction (a τ or ψ coordinate).
///
/// It holds three things:
/// - the design derivative `X_tau`;
/// - the penalty derivative, decomposed into per-base-penalty components;
/// - optional pairwise second derivatives against the other directions.
///
/// The pairwise penalty terms are either stored eagerly or supplied lazily by
/// a provider closure.
#[derive(Clone)]
pub struct DirectionalHyperParam {
    // Design derivative X_tau in original (untransformed) coordinates.
    pub(crate) x_tau_original: HyperDesignDerivative,
    // Canonical penalty representation: every tau direction is decomposed into
    // base-penalty derivatives. There is no separate "assembled total" path.
    pub(crate) penalty_first_components: Vec<PenaltyDerivativeComponent>,
    // Optional pairwise second hyper-derivatives against all tau directions.
    // If provided, the vector must have length psi_dim (one slot per
    // direction), and slot j holds an optional X_{tau_i,tau_j} entry in
    // original coordinates.
    pub(crate) x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
    // Pairwise second derivatives are stored in the same canonical base-penalty
    // decomposition as the first derivatives.
    pub(crate) penaltysecond_components: Option<Vec<Option<Vec<PenaltyDerivativeComponent>>>>,
    // Lazy fallback for pairwise penalty second derivatives. It is called with
    // the partner index j when no explicit row j is stored in
    // `penaltysecond_components`.
    pub(crate) penaltysecond_component_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
                + Send
                + Sync
                + 'static,
        >,
    >,
    // Partner indices j declared to have a pairwise penalty second derivative.
    // Consulted together with the explicit rows when answering
    // pair-existence queries.
    pub(crate) penaltysecond_partner_indices: Option<std::sync::Arc<[usize]>>,
    /// Whether this coordinate is penalty-like (B_i = ∂H/∂τ_i is PSD).
    /// True for τ (penalty scaling) coordinates; false for ψ (design-moving,
    /// anisotropic length-scale) coordinates. Controls EFS eligibility.
    /// `new_compact` initializes it to `!x_tau_original.any_nonzero()`.
    pub(crate) is_penalty_like: bool,
}
4194
4195impl DirectionalHyperParam {
4196    pub(crate) fn resident_byte_count(&self) -> usize {
4197        let mut bytes = self.x_tau_original.resident_byte_count();
4198        for component in &self.penalty_first_components {
4199            bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4200        }
4201        if let Some(entries) = self.x_tau_tau_original.as_ref() {
4202            for entry in entries.iter().flatten() {
4203                bytes = bytes.saturating_add(entry.resident_byte_count());
4204            }
4205        }
4206        if let Some(rows) = self.penaltysecond_components.as_ref() {
4207            for components in rows.iter().flatten() {
4208                for component in components {
4209                    bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4210                }
4211            }
4212        }
4213        bytes
4214    }
4215
4216    pub(crate) fn canonicalize_penalty_components(
4217        components: Vec<(usize, HyperPenaltyDerivative)>,
4218    ) -> Result<Vec<PenaltyDerivativeComponent>, EstimationError> {
4219        let mut out: Vec<PenaltyDerivativeComponent> = Vec::with_capacity(components.len());
4220        for (penalty_index, matrix) in components {
4221            if out.iter().any(|c| c.penalty_index == penalty_index) {
4222                crate::bail_invalid_estim!(
4223                    "duplicate penalty derivative component for penalty {}",
4224                    penalty_index
4225                );
4226            }
4227            out.push(PenaltyDerivativeComponent {
4228                penalty_index,
4229                matrix,
4230            });
4231        }
4232        Ok(out)
4233    }
4234
4235    pub fn new_compact(
4236        x_tau_original: HyperDesignDerivative,
4237        penalty_first_components: Vec<(usize, HyperPenaltyDerivative)>,
4238        x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
4239        penaltysecond_components: Option<Vec<Option<Vec<(usize, HyperPenaltyDerivative)>>>>,
4240    ) -> Result<Self, EstimationError> {
4241        let is_penalty_like = !x_tau_original.any_nonzero();
4242        let penalty_first_components =
4243            Self::canonicalize_penalty_components(penalty_first_components)?;
4244        let penaltysecond_components = match penaltysecond_components {
4245            Some(rows) => {
4246                let mut out = Vec::with_capacity(rows.len());
4247                for row in rows {
4248                    out.push(match row {
4249                        Some(components) => {
4250                            Some(Self::canonicalize_penalty_components(components)?)
4251                        }
4252                        None => None,
4253                    });
4254                }
4255                Some(out)
4256            }
4257            None => None,
4258        };
4259        Ok(Self {
4260            x_tau_original,
4261            penalty_first_components,
4262            x_tau_tau_original,
4263            penaltysecond_components,
4264            penaltysecond_component_provider: None,
4265            penaltysecond_partner_indices: None,
4266            is_penalty_like,
4267        })
4268    }
4269
4270    /// Mark this coordinate as non-penalty-like (design-moving).
4271    /// EFS will skip it; use Newton/BFGS for these coordinates.
4272    pub fn not_penalty_like(mut self) -> Self {
4273        self.is_penalty_like = false;
4274        self
4275    }
4276
4277    pub fn with_penaltysecond_component_provider(
4278        mut self,
4279        provider: std::sync::Arc<
4280            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
4281                + Send
4282                + Sync
4283                + 'static,
4284        >,
4285    ) -> Self {
4286        self.penaltysecond_component_provider = Some(provider);
4287        self
4288    }
4289
4290    pub fn with_penaltysecond_partner_indices(mut self, partners: Vec<usize>) -> Self {
4291        self.penaltysecond_partner_indices = Some(std::sync::Arc::from(partners));
4292        self
4293    }
4294
4295    pub(crate) fn x_tau_dense(&self) -> Array2<f64> {
4296        self.x_tau_original.materialize()
4297    }
4298
4299    pub(crate) fn transformed_x_tau(
4300        &self,
4301        qs: &Array2<f64>,
4302        free_basis_opt: Option<&Array2<f64>>,
4303    ) -> Result<Array2<f64>, EstimationError> {
4304        self.x_tau_original.transformed(qs, free_basis_opt)
4305    }
4306
4307    pub(crate) fn x_tau_tau_entry_at(&self, j: usize) -> Option<HyperDesignDerivative> {
4308        self.x_tau_tau_original
4309            .as_ref()
4310            .and_then(|rows| rows.get(j))
4311            .and_then(|entry| entry.clone())
4312    }
4313
4314    /// Whether this coordinate's design derivative uses implicit storage at the
4315    /// first-derivative level.
4316    pub(crate) fn has_implicit_operator(&self) -> bool {
4317        self.x_tau_original.uses_implicit_storage()
4318    }
4319
4320    pub(crate) fn has_implicit_multidim_duchon(&self) -> bool {
4321        self.implicit_first_axis_info()
4322            .is_some_and(|(op, _)| op.n_axes() > 1 && op.is_duchon_family())
4323    }
4324
4325    /// Extract the implicit design derivative operator and axis, if available.
4326    pub(crate) fn implicit_first_axis_info(
4327        &self,
4328    ) -> Option<(
4329        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4330        usize,
4331    )> {
4332        self.x_tau_original.implicit_first_axis_info()
4333    }
4334
4335    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4336        self.x_tau_original.implicit_axis_count_hint()
4337    }
4338
4339    pub(crate) fn penalty_first_components(&self) -> &[PenaltyDerivativeComponent] {
4340        &self.penalty_first_components
4341    }
4342
4343    pub(crate) fn penalty_total_at(
4344        &self,
4345        rho: &Array1<f64>,
4346        p: usize,
4347    ) -> Result<Array2<f64>, EstimationError> {
4348        let mut out = Array2::<f64>::zeros((p, p));
4349        for component in &self.penalty_first_components {
4350            if component.matrix.nrows() != p || component.matrix.ncols() != p {
4351                crate::bail_invalid_estim!(
4352                    "S_tau shape mismatch for penalty {}: expected {}x{}, got {}x{}",
4353                    component.penalty_index,
4354                    p,
4355                    p,
4356                    component.matrix.nrows(),
4357                    component.matrix.ncols()
4358                );
4359            }
4360            if component.penalty_index >= rho.len() {
4361                crate::bail_invalid_estim!(
4362                    "penalty_index {} out of bounds for rho dimension {}",
4363                    component.penalty_index,
4364                    rho.len()
4365                );
4366            }
4367            component
4368                .matrix
4369                .scaled_add_to(&mut out, rho[component.penalty_index].exp())?;
4370        }
4371        Ok(out)
4372    }
4373
4374    pub(crate) fn penaltysecond_components_for(
4375        &self,
4376        j: usize,
4377    ) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError> {
4378        if let Some(components) = self
4379            .penaltysecond_components
4380            .as_ref()
4381            .and_then(|rows| rows.get(j))
4382            .and_then(|row| row.clone())
4383        {
4384            return Ok(Some(components));
4385        }
4386        if let Some(provider) = self.penaltysecond_component_provider.as_ref() {
4387            return provider(j);
4388        }
4389        Ok(None)
4390    }
4391
4392    pub(crate) fn penaltysecond_componentrows(
4393        &self,
4394    ) -> Option<&[Option<Vec<PenaltyDerivativeComponent>>]> {
4395        self.penaltysecond_components.as_deref()
4396    }
4397
4398    pub(crate) fn penalty_first_component_count(&self) -> usize {
4399        self.penalty_first_components.len()
4400    }
4401
4402    pub(crate) fn has_penaltysecond_pair_at(&self, j: usize) -> bool {
4403        self.penaltysecond_components
4404            .as_ref()
4405            .and_then(|rows| rows.get(j))
4406            .is_some_and(Option::is_some)
4407            || self
4408                .penaltysecond_partner_indices
4409                .as_ref()
4410                .is_some_and(|partners| partners.contains(&j))
4411    }
4412}
4413
/// Record of the REML geometry that was selected, with the reason and sizing
/// figures behind the choice.
// NOTE(review): field meanings below are inferred from their names — confirm
// against the code that fills this struct.
#[derive(Clone, Debug)]
pub(crate) struct SparseRemlDecision {
    // Selected geometry.
    pub(crate) geometry: RemlGeometry,
    // Static human-readable explanation for the choice.
    pub(crate) reason: &'static str,
    // Presumably the number of coefficients.
    pub(crate) p: usize,
    // Presumably the number of stored nonzeros in the design X.
    pub(crate) nnz_x: usize,
    // Presumably an estimate of nonzeros in the upper triangle of H, when one
    // was computed.
    pub(crate) nnz_h_upper_est: Option<usize>,
    // Presumably the corresponding estimated fill density of H's upper
    // triangle, when computed.
    pub(crate) density_h_upper_est: Option<f64>,
}
4423
/// Shared results of a sparse exact factorization used for one REML
/// evaluation.
///
/// The heavy members are behind `Arc`, so cloning this struct shares them
/// rather than copying them.
// NOTE(review): the scalar field meanings below are inferred from names —
// confirm against the code that builds this struct.
#[derive(Clone)]
pub(crate) struct SparseExactEvalData {
    // Sparse factorization (presumably of the penalized Hessian H).
    pub(crate) factor: Arc<SparseExactFactor>,
    // Optional Takahashi (selected) inverse computed from the factor.
    pub(crate) takahashi: Option<Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    // log|H| from the factorization (name-based; confirm).
    pub(crate) logdet_h: f64,
    // Log pseudo-determinant of the positive part of the penalty (name-based;
    // confirm).
    pub(crate) logdet_s_pos: f64,
    // Rank of the penalty (name-based; confirm).
    pub(crate) penalty_rank: usize,
    // Presumably first derivatives of a log-determinant term w.r.t. the
    // smoothing parameters — TODO confirm.
    pub(crate) det1_values: Arc<Array1<f64>>,
}
4433
/// Dense exact Firth/Jeffreys operator evaluated at one linear predictor `eta`.
///
/// It stores the design, the identifiable-subspace basis, and the reduced-space
/// factors described in the comment block inside the struct. Per the
/// `FirthDesignFactor` docs, the design-only pieces can be built once per solve,
/// while the β-dependent pieces (weights, `K_r`, `h`, half-log-determinant) are
/// rebuilt per iteration.
#[derive(Clone)]
pub struct FirthDenseOperator {
    // Exact Firth/Jeffreys objects on the identifiable subspace.
    //
    // Let X in R^{n×p} potentially be rank-deficient with rank r.
    // With optional fixed observation weights a_i >= 0 we define A = diag(a),
    // choose an orthonormal coefficient-space basis Q for the identifiable
    // subspace of A^{1/2} X, and set:
    //   X_r := A^{1/2} X Q          (A = I when no fixed observation weights),
    //   W   := diag(w), with w_i = mu_i (1 - mu_i), 0 < w_i <= 1/4 for finite logit eta,
    //   I_r := X_rᵀ W X_r,
    //   S_r := X_rᵀ X_r.
    //
    // Firth term is represented as:
    //   Phi(beta) = 0.5 log |I_r(beta)| - 0.5 log |S_r|,
    // which is exactly
    //   0.5 log |Uᵀ W U|
    // for the canonical orthonormalized identifiable design
    //   U = X_r S_r^{-1/2}.
    // This removes the raw-basis term from explicit reduced designs while
    // keeping the same identifiable-subspace hat matrix and beta derivatives,
    // because S_r is fixed with respect to beta.
    //
    // Mapping back to the full p-space uses:
    //   I_+^dagger = Q I_r^{-1} Qᵀ.
    //
    // We store reduced-space factors so all derivatives can be evaluated exactly
    // without materializing dense n×n matrices M = X K Xᵀ or P = M⊙M.
    pub(crate) x_dense: Array2<f64>,
    pub(crate) x_dense_t: Array2<f64>,
    // Orthonormal coefficient-space basis for the identifiable subspace,
    // built from the retained eigenspace of (A^{1/2} X)ᵀ(A^{1/2} X).
    pub(crate) q_basis: Array2<f64>,
    // Reduced identifiable design. With fixed observation weights a_i this is
    // diag(sqrt(a_i)) X Q; otherwise it is X Q.
    pub(crate) x_reduced: Array2<f64>,
    // Optional fixed case-weight square roots used when the Jeffreys/Firth
    // operator is formed from Xᵀ diag(case_weight ⊙ w(η)) X rather than
    // Xᵀ diag(w(η)) X. The exact directional tau derivatives must project and
    // row-scale with the same weights so the reduced Fisher, hat diagonals,
    // and tau kernels all live on one consistent identifiable subspace.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    // I_r^{-1}
    pub(crate) k_reduced: Array2<f64>,
    // diag(S_r^{-1}) with S_r = X_rᵀ X_r. In the current canonical reduced
    // basis this completely characterizes the metric inverse, because Q
    // diagonalizes the design Gram. It is used to remove the reduced-coordinate
    // basis term from Phi_tau when the design moves.
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    // 0.5 (log|I_r| - log|S_r|) at the current eta.
    pub(crate) half_log_det: f64,
    // h = diag(M), M = X_r K_r X_r'
    pub(crate) h_diag: Array1<f64>,
    // Logistic Fisher weights w (w_i = mu_i (1 - mu_i)) followed by their
    // eta-derivatives w', w'', w''', w'''' in w1..w4, all as n-vectors.
    pub(crate) w: Array1<f64>,
    pub(crate) w1: Array1<f64>,
    pub(crate) w2: Array1<f64>,
    pub(crate) w3: Array1<f64>,
    pub(crate) w4: Array1<f64>,
    // B = diag(w') X used in D Hphi and D^2 Hphi contractions.
    pub(crate) b_base: Array2<f64>,
    // Cached invariant contraction P*B where P = (X_r K_r X_r') ⊙ (X_r K_r X_r').
    // This avoids recomputing the same O(n r^2 p) block in every directional call.
    pub(crate) p_b_base: Array2<f64>,
}
4499
/// β-independent (design-only) factor of the Firth/Jeffreys operator.
///
/// Everything stored here depends ONLY on the fixed design `X` and the fixed
/// prior/observation weights `a_i` — NOT on the current linear predictor `η`
/// (i.e. NOT on β). For a single inner PIRLS solve the design and prior weights
/// are constant while `η` changes every Newton iteration, so this factor can be
/// built once per solve and reused, hoisting the O(n·p²) Gram, the O(p³)
/// identifiable-subspace eigendecomposition, and the two n×p design clones out
/// of the per-iteration hot path (#1575).
///
/// The β-dependent remainder (Fisher weights `w(η)`, reduced Fisher
/// `I_r = X_rᵀ W X_r`, its inverse `K_r`, the hat diagonal `h`, and the
/// half-log-determinant) is rebuilt per iteration from this factor via
/// [`FirthDenseOperator::build_from_design_factor`] (full operator) or
/// [`FirthDenseOperator::pirls_diagnostics_from_factor`] (the three PIRLS
/// diagnostics only). Both reproduce the un-hoisted build bit-for-bit.
///
/// All fields are owned `ndarray` arrays (no `Arc`), so `Clone` deep-copies
/// every matrix — clone sparingly.
#[derive(Clone)]
pub(crate) struct FirthDesignFactor {
    // Raw design and its transpose (the operator stores owned copies).
    pub(crate) x_dense: Array2<f64>,
    pub(crate) x_dense_t: Array2<f64>,
    // Orthonormal identifiable-subspace basis Q of (A^{1/2} X)ᵀ(A^{1/2} X).
    // Since X_r = A^{1/2} X Q below, Q has one row per design column and
    // `r` columns.
    pub(crate) q_basis: Array2<f64>,
    // Reduced identifiable design X_r = A^{1/2} X Q.
    pub(crate) x_reduced: Array2<f64>,
    // Fixed case-weight square roots (sqrt(a_i)), if any.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    // Retained positive spectrum of the design Gram = S_r diagonal.
    pub(crate) metric_spectrum: Array1<f64>,
    // diag(S_r^{-1}); precomputed reciprocal of `metric_spectrum` (finite
    // because only the positive part of the spectrum is retained).
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    // rank r = ncols(q_basis); n = nrows(x_dense).
    pub(crate) r: usize,
    pub(crate) n: usize,
}
4530    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
4531    // rank r = ncols(q_basis); n = nrows(x_dense).
4532    pub(crate) r: usize,
4533    pub(crate) n: usize,
4534}
4535
4536#[derive(Clone)]
4537pub(crate) struct FirthDirection {
4538    pub(crate) deta: Array1<f64>,
4539    pub(crate) g_u_reduced: Array2<f64>,
4540    pub(crate) a_u_reduced: Array2<f64>,
4541    pub(crate) dh: Array1<f64>,
4542    // B_u = diag(w'' ⊙ δη_u) X is represented by the row-scaling vector only.
4543    pub(crate) b_uvec: Array1<f64>,
4544}
4545
4546#[derive(Clone)]
4547pub(crate) struct FirthTauPartialKernel {
4548    pub(super) deta_partial: Array1<f64>,
4549    pub(crate) dotw1: Array1<f64>,
4550    pub(crate) dotw2: Array1<f64>,
4551    pub(crate) dot_h_partial: Array1<f64>,
4552    // Reduced design drift X_{tau,r} = X_tau Q used in exact design-moving
4553    // Hadamard-Gram contractions.
4554    pub(crate) x_tau_reduced: Array2<f64>,
4555    pub(super) dot_i_partial: Array2<f64>,
4556    // Reduced Fisher inverse drift:
4557    //   dot(K_r) = -K_r dot(I_r) K_r
4558    // where dot(I_r) includes explicit X_tau and weight drift at beta-fixed.
4559    pub(crate) dot_k_reduced: Array2<f64>,
4560}
4561
4562#[derive(Clone)]
4563pub(crate) struct FirthTauExactKernel {
4564    pub(crate) gphi_tau: Array1<f64>,
4565    pub(crate) phi_tau_partial: f64,
4566    pub(crate) tau_kernel: Option<FirthTauPartialKernel>,
4567}
4568
4569/// Pair-level (τ_i × τ_j) exact Firth bundle at fixed β.
4570///
4571/// Mirrors `FirthTauExactKernel` but for the 2nd-order cross
4572/// derivatives:
4573///   Phi_{τ_i τ_j}|β  (scalar, `phi_tau_tau_partial`)
4574///   (gphi)_{τ_i τ_j}|β (p-vector, `gphi_tau_tau`)
4575///
4576/// Carries an optional `tau_tau_kernel` so pair callbacks can chain
4577/// into Primitive A (`hphi_tau_tau_partial_apply`) for the operator-
4578/// valued Hessian 2nd drift without recomputing shared reduced Grams.
4579///
4580#[derive(Clone)]
4581pub(crate) struct FirthTauTauExactKernel {
4582    pub(super) phi_tau_tau_partial: f64,
4583    pub(super) gphi_tau_tau: Array1<f64>,
4584    pub(super) tau_tau_kernel: Option<FirthTauTauPartialKernel>,
4585}
4586
4587/// Prepared state for `∂²H_φ/∂τ_i ∂τ_j |_β` (Primitive A).
4588///
4589/// Carries both τ-direction reduced designs, their η̇ vectors, and the
4590/// reduced-coordinate drifts (İ, K̇, ḣ) for i and j so the apply step can
4591/// form M̈_{ij}, K̈_{ij}, ḧ_{ij}, Γ̈_{ij}, and B̈_{ij} matrix-free.  Fields
4592/// are filled in by 13b; kept with a neutral internal shape so downstream
4593/// pair callbacks can hold the kernel across the pair dispatch.
4594///
4595/// Wired into the pair-callback's `b_operator` via
4596/// `FirthAugmentedPairHyperOperator`, and produced by both
4597/// `hphi_tau_tau_partial_prepare_from_partials` and
4598/// `exact_tau_tau_kernel` (the scalar/p-vector companion).
4599#[derive(Clone, Default)]
4600pub(crate) struct FirthTauTauPartialKernel {
4601    pub(super) x_tau_i_reduced: Array2<f64>,
4602    pub(super) x_tau_j_reduced: Array2<f64>,
4603    pub(super) deta_i_partial: Array1<f64>,
4604    pub(super) deta_j_partial: Array1<f64>,
4605    pub(super) dot_h_i_partial: Array1<f64>,
4606    pub(super) dot_h_j_partial: Array1<f64>,
4607    pub(super) dot_k_i_reduced: Array2<f64>,
4608    pub(super) dot_k_j_reduced: Array2<f64>,
4609    pub(super) dot_i_i_partial: Array2<f64>,
4610    pub(super) dot_i_j_partial: Array2<f64>,
4611    pub(super) x_tau_tau_reduced: Option<Array2<f64>>,
4612    pub(super) deta_ij_partial: Option<Array1<f64>>,
4613}
4614
4615/// Prepared state for `D_β((H_φ)_τ|_β)[v]` (Primitive B).
4616///
4617/// Carries the τ-kernel pieces (x_tau_reduced, İ, K̇, ḣ), the
4618/// β-direction quantities (δη_v, A_v, dh_v, b-chain), and the mixed
4619/// β-τ pieces (D_β(K̇_τ)[v], D_β(ḣ_τ)[v], δη_{τ,v}) so the apply
4620/// step collapses to the 9-term β-τ expansion without recomputing
4621/// shared reduced Grams.
4622#[derive(Clone, Default)]
4623pub(crate) struct FirthTauBetaPartialKernel {
4624    pub(super) x_tau_reduced: Array2<f64>,
4625    pub(super) deta_partial: Array1<f64>,
4626    pub(super) dot_h_partial: Array1<f64>,
4627    pub(super) dot_i_partial: Array2<f64>,
4628    pub(super) dot_k_reduced: Array2<f64>,
4629    pub(super) deta_v: Array1<f64>,
4630    pub(super) deta_tau_v: Array1<f64>,
4631    pub(super) a_v_reduced: Array2<f64>,
4632    pub(super) dh_v: Array1<f64>,
4633    pub(super) b_vvec: Array1<f64>,
4634    pub(super) d_beta_dot_k: Array2<f64>,
4635    pub(super) d_beta_dot_h: Array1<f64>,
4636}
4637
4638/// Holds the state for the outer REML optimization and supplies cost and
4639/// gradient evaluations to the `opt` optimizer.
4640///
4641/// The `cache` field uses `RefCell` to enable interior mutability. This is a crucial
4642/// performance optimization. The `cost_andgrad` closure required by the BFGS
4643/// optimizer takes an immutable reference `&self`. However, we want to cache the
4644/// results of the expensive P-IRLS computation to avoid re-calculating the fit
4645/// for the same `rho` vector, which can happen during the line search.
4646/// `RefCell` allows us to mutate the cache through a `&self` reference,
4647/// making this optimization possible while adhering to the optimizer's API.
4648#[derive(Clone)]
4649pub(crate) struct EvalShared {
4650    pub(crate) key: Option<Vec<u64>>,
4651    pub(crate) pirls_result: Arc<PirlsResult>,
4652    pub(crate) ridge_passport: RidgePassport,
4653    pub(crate) geometry: RemlGeometry,
4654    /// The exact H_total matrix used for LAML cost computation.
4655    /// For Firth: effective Hessian minus hphi (plus any barrier curvature).
4656    /// For non-Firth: the effective Hessian itself (plus any barrier curvature).
4657    pub(crate) h_total: Arc<Array2<f64>>,
4658    pub(crate) sparse_exact: Option<Arc<SparseExactEvalData>>,
4659    pub(crate) firth_dense_operator: Option<Arc<FirthDenseOperator>>,
4660    /// Cached FirthDenseOperator built from the original (non-reparameterized)
4661    /// design matrix, for use by the sparse evaluation path.
4662    pub(crate) firth_dense_operator_original: Option<Arc<FirthDenseOperator>>,
4663    /// The ONE original-frame penalty pseudo-logdet factorization for this
4664    /// evaluation point (#931 atom discipline). `log|Σ λ_k S_k|₊`'s VALUE,
4665    /// ρ-derivatives, τ/ψ components, and ρ×τ cross blocks are all
4666    /// contractions of this single eigendecomposition; the ρ-side criterion
4667    /// assembly (`dense_penalty_logdet_derivs`, the sparse det2 path) and the
4668    /// original-basis hyper-coordinate builders share it through
4669    /// [`EvalShared::penalty_pseudologdet_original`]. Building a second
4670    /// factorization of the same Sλ for the same evaluation point is the
4671    /// objective↔gradient desync surface (#748/#752/#901) this cell removes:
4672    /// the ridge and positive-eigenspace threshold are decided exactly once.
4673    /// (The transformed-frame pair-callback path builds its own object — it
4674    /// factorizes the canonical-TRANSFORMED, possibly constraint-projected
4675    /// penalties, a genuinely different matrix, not a duplicate of this one.)
4676    pub(crate) penalty_pseudologdet: std::sync::OnceLock<Arc<penalty_logdet::PenaltyPseudologdet>>,
4677    /// Per-evaluation-point cache of the canonical penalty score vectors
4678    /// `S_k β̂` evaluated at this bundle's inner mode `β̂ =
4679    /// pirls_result.beta_transformed` (unscaled by λ_k). These depend ONLY
4680    /// on the inner solution carried by this bundle and the `RemlState`'s
4681    /// fixed `canonical_penalties` — never on which ρ-coordinate or eval
4682    /// mode the assembly is running — so they are computed exactly once per
4683    /// inner solution and shared by every assemble call that reuses the
4684    /// bundle (cost + gradient evaluations at the same ρ, EFS, synthetic-ext
4685    /// value probes). Exact hoist, not an approximation: every consumer sees
4686    /// literally the same vectors it previously recomputed. Initialized via
4687    /// plain ndarray matvecs (no rayon inside the `OnceLock` closure — the
4688    /// `get_or_init`+`into_par_iter` deadlock trap does not apply).
4689    pub(crate) penalty_scores_at_mode: std::sync::OnceLock<Arc<Vec<Array1<f64>>>>,
4690    /// Per-evaluation-point cache of the #784 block-local Laplace-to-sampling
4691    /// correction `TkCorrectionTerms { value, gradient }`. The correction is a
4692    /// deterministic function of ONLY this bundle's converged inner state
4693    /// (`pirls_result`, `h_total`), the `RemlState`'s fixed
4694    /// `canonical_penalties`, and the bundle's ρ — never of the eval `mode`:
4695    /// the diagnostic eigendecomposition, the fixed-seed importance sampler,
4696    /// and the (b)–(d) gradient channels all read mode-invariant fields, and
4697    /// the term carries no Hessian, so the value+gradient are identical for the
4698    /// value-only, value+gradient, and value+gradient+Hessian assemble calls
4699    /// that share this bundle at a single ρ. The expensive path (eigendecomp +
4700    /// O(draws·n·m) sampler) previously reran on every one of those 2–3 calls
4701    /// per outer iteration; hoisting it onto the bundle computes it exactly
4702    /// once per inner solution (exact hoist, identical values — #784, #1082).
4703    /// Keyed only on the external-coordinate count `n_ext`: with no ψ
4704    /// coordinates (`n_ext == 0`) the correction engages; with ψ present the
4705    /// seam declines (returns the cheap zero), and n_ext is fixed for a fit, so
4706    /// a single cell suffices.
4707    pub(crate) block_local_correction:
4708        std::sync::OnceLock<(usize, Arc<outer_eval::TkCorrectionTerms>)>,
4709}
4710
4711impl EvalShared {
4712    pub(crate) fn matches(&self, key: &Option<Vec<u64>>) -> bool {
4713        match (&self.key, key) {
4714            (None, None) => true,
4715            (Some(a), Some(b)) => a == b,
4716            _ => false,
4717        }
4718    }
4719
4720    /// Lazily build — once per evaluation point — the original-frame
4721    /// [`PenaltyPseudologdet`](penalty_logdet::PenaltyPseudologdet) of
4722    /// `Σ λ_k S_k` and hand every caller the SAME factorization.
4723    ///
4724    /// This is the #931 port of the penalty-logdet term: value, ρ-first /
4725    /// ρ-second derivatives, τ-gradient components, τ×τ and ρ×τ Hessian
4726    /// blocks are all projections of one eigendecomposition, so no pair of
4727    /// consumers can disagree about the ridge or the positive-eigenspace
4728    /// threshold. The ridge is read from this bundle's `ridge_passport` —
4729    /// the single place that convention is decided.
4730    ///
4731    /// `lambdas` must be the λ = exp(ρ) vector of this bundle's evaluation
4732    /// point and `p` the original-basis coefficient dimension; on a cache
4733    /// hit both are checked against the stored object where representable.
4734    pub(crate) fn penalty_pseudologdet_original(
4735        &self,
4736        canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
4737        lambdas: &[f64],
4738        p: usize,
4739    ) -> Result<Arc<penalty_logdet::PenaltyPseudologdet>, EstimationError> {
4740        if let Some(pld) = self.penalty_pseudologdet.get() {
4741            if pld.dim() != p {
4742                return Err(EstimationError::LayoutError(format!(
4743                    "shared penalty pseudo-logdet frame mismatch: cached p={}, requested p={}",
4744                    pld.dim(),
4745                    p
4746                )));
4747            }
4748            return Ok(Arc::clone(pld));
4749        }
4750        let pld = Arc::new(
4751            penalty_logdet::PenaltyPseudologdet::from_penalties(
4752                canonical_penalties,
4753                lambdas,
4754                self.ridge_passport.penalty_logdet_ridge(),
4755                p,
4756            )
4757            .map_err(EstimationError::InvalidInput)?,
4758        );
4759        match self.penalty_pseudologdet.set(Arc::clone(&pld)) {
4760            Ok(()) => Ok(pld),
4761            // A concurrent caller initialized the cell first; both objects
4762            // were built from identical inputs — return the canonical winner
4763            // so every consumer holds literally the same factorization.
4764            Err(_) => Ok(Arc::clone(
4765                self.penalty_pseudologdet
4766                    .get()
4767                    .expect("OnceLock set raced, so it is initialized"),
4768            )),
4769        }
4770    }
4771}
4772
4773impl PenalizedGeometry for EvalShared {
4774    fn backend_kind(&self) -> GeometryBackendKind {
4775        match self.geometry {
4776            RemlGeometry::DenseSpectral => GeometryBackendKind::DenseSpectral,
4777            RemlGeometry::SparseExactSpd => GeometryBackendKind::SparseExactSpd,
4778        }
4779    }
4780}
4781
4782/// LRU cache keyed by sanitized ρ vectors that holds compacted PIRLS results
4783/// for warm-starting outer line searches and revisited evaluations.
4784///
4785/// Eviction is byte-budgeted rather than entry-count-budgeted: each entry
4786/// records its own estimated footprint (the surviving n-length vectors plus
4787/// the two p×p Hessians plus per-entry overhead) and the cache evicts in
4788/// LRU order until the running total fits under the budget. An entry that
4789/// individually exceeds the budget is rejected silently rather than poisoning
4790/// the cache.
4791pub(crate) struct PirlsLruCache {
4792    // Stored tuple: (compacted result, last-touched clock, estimated bytes).
4793    pub(crate) map: HashMap<Vec<u64>, (Arc<PirlsResult>, u64, usize)>,
4794    pub(crate) byte_budget: usize,
4795    pub(crate) current_bytes: usize,
4796    pub(crate) clock: u64,
4797}
4798
4799impl PirlsLruCache {
4800    pub(crate) fn new(byte_budget: usize) -> Self {
4801        Self {
4802            map: HashMap::new(),
4803            byte_budget: byte_budget.max(1),
4804            current_bytes: 0,
4805            clock: 0,
4806        }
4807    }
4808
4809    pub(crate) fn get(&mut self, key: &Vec<u64>) -> Option<Arc<PirlsResult>> {
4810        if let Some(entry) = self.map.get_mut(key) {
4811            self.clock += 1;
4812            entry.1 = self.clock;
4813            Some(entry.0.clone())
4814        } else {
4815            None
4816        }
4817    }
4818
4819    pub(crate) fn insert(&mut self, key: Vec<u64>, value: Arc<PirlsResult>) {
4820        self.clock += 1;
4821        let bytes = pirls_result_cache_bytes(&value);
4822        // Refuse entries that on their own already exceed the entire budget;
4823        // caching one would force eviction of every other entry without
4824        // leaving room for the new one anyway.
4825        if bytes > self.byte_budget {
4826            if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4827                self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4828            }
4829            return;
4830        }
4831        if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4832            self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4833        }
4834        while self.current_bytes + bytes > self.byte_budget {
4835            let evict_key = self
4836                .map
4837                .iter()
4838                .min_by_key(|(_, (_, ts, _))| *ts)
4839                .map(|(k, _)| k.clone());
4840            match evict_key {
4841                Some(k) => {
4842                    if let Some((_, _, evict_bytes)) = self.map.remove(&k) {
4843                        self.current_bytes = self.current_bytes.saturating_sub(evict_bytes);
4844                    }
4845                }
4846                None => break,
4847            }
4848        }
4849        self.current_bytes += bytes;
4850        self.map.insert(key, (value, self.clock, bytes));
4851    }
4852
4853    pub(crate) fn clear(&mut self) {
4854        self.map.clear();
4855        self.current_bytes = 0;
4856    }
4857}
4858
4859#[derive(Clone, Copy, PartialEq, Eq)]
4860pub(crate) struct PenaltySubspaceCacheKey {
4861    pub(crate) penalty_matrix_fingerprint: u64,
4862    pub(crate) ridge_passport_signature: u64,
4863}
4864
4865pub(crate) struct PenaltySubspaceCache {
4866    pub(crate) entry: Option<(PenaltySubspaceCacheKey, Arc<outer_eval::PenaltySubspace>)>,
4867}
4868
4869impl PenaltySubspaceCache {
4870    pub(crate) fn new() -> Self {
4871        Self { entry: None }
4872    }
4873
4874    pub(crate) fn get(
4875        &self,
4876        key: &PenaltySubspaceCacheKey,
4877    ) -> Option<Arc<outer_eval::PenaltySubspace>> {
4878        self.entry
4879            .as_ref()
4880            .filter(|(cached_key, _)| cached_key == key)
4881            .map(|(_, value)| value.clone())
4882    }
4883
4884    pub(crate) fn insert(
4885        &mut self,
4886        key: PenaltySubspaceCacheKey,
4887        value: Arc<outer_eval::PenaltySubspace>,
4888    ) {
4889        self.entry = Some((key, value));
4890    }
4891
4892    pub(crate) fn clear(&mut self) {
4893        self.entry = None;
4894    }
4895}
4896
4897impl PenaltySubspaceCacheKey {
4898    /// Build a cache key from the transformed-E matrix and ridge passport.
4899    /// `E` is hashed by exact f64 bits (column-major), so the key is bit-exact
4900    /// and avoids float-Hash issues; the ridge passport is hashed via its
4901    /// `Hash` impl. Two calls at the same `(E, ridge)` yield equal keys.
4902    pub(crate) fn from_inputs(
4903        e_transformed: &ndarray::Array2<f64>,
4904        ridge_passport: &gam_problem::RidgePassport,
4905    ) -> Self {
4906        use std::collections::hash_map::DefaultHasher;
4907        use std::hash::{Hash, Hasher};
4908        let mut hasher = DefaultHasher::new();
4909        e_transformed.nrows().hash(&mut hasher);
4910        e_transformed.ncols().hash(&mut hasher);
4911        for value in e_transformed.iter() {
4912            value.to_bits().hash(&mut hasher);
4913        }
4914        let penalty_matrix_fingerprint = hasher.finish();
4915        let mut ridge_hasher = DefaultHasher::new();
4916        ridge_passport.delta.to_bits().hash(&mut ridge_hasher);
4917        (ridge_passport.matrix_form as u8).hash(&mut ridge_hasher);
4918        ridge_passport
4919            .policy
4920            .include_penalty_logdet
4921            .hash(&mut ridge_hasher);
4922        ridge_passport
4923            .policy
4924            .include_laplacehessian
4925            .hash(&mut ridge_hasher);
4926        let ridge_passport_signature = ridge_hasher.finish();
4927        Self {
4928            penalty_matrix_fingerprint,
4929            ridge_passport_signature,
4930        }
4931    }
4932}
4933
4934/// Estimate the in-cache footprint of a (compacted) PIRLS result.
4935///
4936/// Mirrors what `compact_for_reml_cache` keeps:
4937/// * six surviving n-length f64 arrays (final_eta, solveweights,
4938///   solveworking_response, solvemu, solve_c_array, solve_d_array);
4939/// * the p-length coefficient vector;
4940/// * the two p×p Hessians (dense or CSC sparse);
4941/// * the `ReparamResult` payload — the dominant scaling term beyond n, since
4942///   it carries `s_transformed`, `qs`, and `e_transformed` as p×p / rank×p
4943///   matrices.
4944/// A small constant overhead absorbs scalar fields, enum discriminants, and
4945/// the HashMap entry. This errs on the conservative side: overestimation
4946/// causes earlier eviction, never under-counting that would let the cache
4947/// silently exceed the byte budget.
4948pub(crate) fn pirls_result_cache_bytes(result: &PirlsResult) -> usize {
4949    use std::mem::size_of;
4950    let n_array_elems = result.final_eta.len()
4951        + result.solveweights.len()
4952        + result.solveworking_response.len()
4953        + result.solvemu.len()
4954        + result.solve_c_array.len()
4955        + result.solve_d_array.len();
4956    let p = result.beta_transformed.0.len();
4957    let pen_h = symmetric_matrix_cache_bytes(&result.penalized_hessian_transformed);
4958    let stab_h = symmetric_matrix_cache_bytes(&result.stabilizedhessian_transformed);
4959    let reparam = (result.reparam_result.s_transformed.len()
4960        + result.reparam_result.qs.len()
4961        + result.reparam_result.e_transformed.len()
4962        + result.reparam_result.det1.len())
4963        * size_of::<f64>();
4964    n_array_elems * size_of::<f64>() + p * size_of::<f64>() + pen_h + stab_h + reparam + 1024
4965}
4966
4967pub(crate) fn symmetric_matrix_cache_bytes(m: &gam_linalg::matrix::SymmetricMatrix) -> usize {
4968    use gam_linalg::matrix::SymmetricMatrix;
4969    use std::mem::size_of;
4970    match m {
4971        SymmetricMatrix::Dense(a) => a.len() * size_of::<f64>(),
4972        SymmetricMatrix::Sparse(s) => {
4973            // CSC sparse: f64 values + usize row indices + usize column pointers.
4974            let (symbolic, values) = s.parts();
4975            values.len() * (size_of::<f64>() + size_of::<usize>())
4976                + std::mem::size_of_val(symbolic.col_ptr())
4977        }
4978    }
4979}
4980
4981/// Capacity (number of distinct rho-points) of the outer-eval reuse LRU.
4982///
4983/// Sized to comfortably span a binomial seed grid's local revisit window
4984/// (baseline + isotropic shifts + per-axis refinements) plus a few
4985/// line-search trial points without unbounded growth. Each slot holds one
4986/// `OuterEval` (a scalar cost, a length-k gradient, an optional inner-beta
4987/// hint, and a usually-`Unavailable` Hessian), so the footprint is tiny.
4988pub(crate) const OUTER_EVAL_LRU_CAPACITY: usize = 8;
4989
4990/// Bounded least-recently-used cache of converged outer REML evaluations,
4991/// keyed by sanitized rho-bits.
4992///
4993/// CORRECTNESS: the key (`Vec<u64>` of `f64::to_bits`, with ±0 canonicalized)
4994/// is the complete result-determining input for a fixed `RemlState`. Every
4995/// other input to `OuterEval` — design matrix, prior weights, offset, penalty
4996/// structure, link/SAS/mixture state, Firth/Jeffreys configuration, and the
4997/// rho-prior — is immutable for the lifetime of the state that owns the cache,
4998/// so the stored cost / gradient / inner-beta hint depend only on rho. A hit
4999/// therefore returns exactly the value a recompute at that rho would converge
5000/// to (to the solver's own tolerance, identical to the trust the pre-existing
5001/// single-slot cache already placed in rho-only keying). Distinct rho-points
5002/// never alias: lookups compare the full key vector.
5003pub(crate) struct OuterEvalLru {
5004    capacity: usize,
5005    /// Front = least-recently-used, back = most-recently-used.
5006    entries: std::collections::VecDeque<(Vec<u64>, OuterEval)>,
5007}
5008
5009impl OuterEvalLru {
5010    pub(crate) fn new(capacity: usize) -> Self {
5011        Self {
5012            capacity: capacity.max(1),
5013            entries: std::collections::VecDeque::new(),
5014        }
5015    }
5016
5017    /// Returns a clone of the eval stored under `key`, if present, promoting it
5018    /// to most-recently-used. A miss returns `None` so the caller recomputes —
5019    /// never a stale value from a different key.
5020    pub(crate) fn get(&mut self, key: &[u64]) -> Option<OuterEval> {
5021        let pos = self
5022            .entries
5023            .iter()
5024            .position(|(k, _)| k.as_slice() == key)?;
5025        let entry = self.entries.remove(pos)?;
5026        let eval = entry.1.clone();
5027        self.entries.push_back(entry);
5028        Some(eval)
5029    }
5030
5031    /// Inserts (or refreshes) the eval for `key` as most-recently-used,
5032    /// evicting the least-recently-used entry once capacity is exceeded.
5033    pub(crate) fn insert(&mut self, key: Vec<u64>, eval: OuterEval) {
5034        if let Some(pos) = self
5035            .entries
5036            .iter()
5037            .position(|(k, _)| k.as_slice() == key.as_slice())
5038        {
5039            self.entries.remove(pos);
5040        }
5041        self.entries.push_back((key, eval));
5042        while self.entries.len() > self.capacity {
5043            self.entries.pop_front();
5044        }
5045    }
5046
5047    pub(crate) fn clear(&mut self) {
5048        self.entries.clear();
5049    }
5050}
5051
5052/// Centralized cache/memoization owner for REML evaluations.
5053///
5054/// This keeps cache-key identity, bundle reuse, and invalidation policy out of
5055/// the math kernels so objective/derivative routines can stay algebra-focused.
5056pub(crate) struct EvalCacheManager {
5057    pub(crate) pirls_cache: RwLock<PirlsLruCache>,
5058    pub(crate) penalty_subspace_cache: RwLock<PenaltySubspaceCache>,
5059    pub(crate) current_eval_bundle: RwLock<Option<EvalShared>>,
5060    /// Most-recently-*stored* outer eval (single slot). Retained verbatim so
5061    /// `previous_outer_gradient_norm` keeps its exact "immediately previous
5062    /// distinct eval" semantics, independent of the multi-slot reuse cache.
5063    pub(crate) current_outer_eval: RwLock<Option<(Vec<u64>, OuterEval)>>,
5064    /// Bounded multi-slot LRU of converged outer evaluations keyed by the
5065    /// sanitized rho-bits (#1575).
5066    ///
5067    /// For a frozen `RemlState` (fixed design, prior weights, offset, penalty
5068    /// structure, link state, Firth/Jeffreys configuration, and rho-prior — all
5069    /// of which are immutable for the lifetime of the state that owns this
5070    /// manager and therefore the lifetime of the cache), the outer objective
5071    /// value, its gradient, and the inner-beta hint are deterministic functions
5072    /// of rho alone. The sanitized rho-bits are thus the complete result key.
5073    /// The binomial REML fit performs ~20-32 seed-grid pre-solves plus
5074    /// line-search revisits; with only the single `current_outer_eval` slot,
5075    /// any revisit to an earlier rho re-ran a full n-sized P-IRLS. This LRU
5076    /// returns the stored cost/gradient for those revisited rho-points.
5077    pub(crate) outer_eval_lru: RwLock<OuterEvalLru>,
5078    pub(crate) pirls_cache_enabled: AtomicBool,
5079}
5080
5081impl EvalCacheManager {
5082    pub(crate) fn new() -> Self {
5083        Self {
5084            pirls_cache: RwLock::new(PirlsLruCache::new(PIRLS_CACHE_BYTE_BUDGET)),
5085            penalty_subspace_cache: RwLock::new(PenaltySubspaceCache::new()),
5086            current_eval_bundle: RwLock::new(None),
5087            current_outer_eval: RwLock::new(None),
5088            outer_eval_lru: RwLock::new(OuterEvalLru::new(OUTER_EVAL_LRU_CAPACITY)),
5089            pirls_cache_enabled: AtomicBool::new(true),
5090        }
5091    }
5092
5093    /// Creates a sanitized cache key from rho values.
5094    /// Returns None if any component is NaN, in which case caching is skipped.
5095    /// Maps -0.0 to 0.0 to ensure key stability.
5096    pub(crate) fn sanitized_rhokey(rho: &Array1<f64>) -> Option<Vec<u64>> {
5097        self::rho_key::sanitized_rhokey(rho)
5098    }
5099
5100    /// Memoizing wrapper for `PenaltySubspace` construction.
5101    ///
5102    /// The penalty-subspace eigendecomposition is shape-invariant: any two
5103    /// outer evaluations at the same `(E_transformed, ridge_passport)` produce
5104    /// bit-identical subspaces. The single-slot cache amortizes consecutive
5105    /// fixed-S queries (rank, logdet, trace) within a single outer iter.
5106    pub(super) fn cached_penalty_subspace<F>(
5107        &self,
5108        e_transformed: &ndarray::Array2<f64>,
5109        ridge_passport: &gam_problem::RidgePassport,
5110        build: F,
5111    ) -> Result<Arc<outer_eval::PenaltySubspace>, EstimationError>
5112    where
5113        F: FnOnce() -> Result<outer_eval::PenaltySubspace, EstimationError>,
5114    {
5115        let key = PenaltySubspaceCacheKey::from_inputs(e_transformed, ridge_passport);
5116        if let Some(hit) = self.penalty_subspace_cache.read().unwrap().get(&key) {
5117            return Ok(hit);
5118        }
5119        let value = Arc::new(build()?);
5120        self.penalty_subspace_cache
5121            .write()
5122            .unwrap()
5123            .insert(key, value.clone());
5124        Ok(value)
5125    }
5126
5127    pub(crate) fn cached_eval_bundle(&self, key: &Option<Vec<u64>>) -> Option<EvalShared> {
5128        let guard = self.current_eval_bundle.read().unwrap();
5129        let bundle: &EvalShared = guard.as_ref()?;
5130        bundle.matches(key).then(|| bundle.clone())
5131    }
5132
5133    pub(crate) fn store_eval_bundle(&self, bundle: EvalShared) {
5134        *self.current_eval_bundle.write().unwrap() = Some(bundle);
5135    }
5136
5137    pub(crate) fn cached_outer_eval(&self, key: &Option<Vec<u64>>) -> Option<OuterEval> {
5138        let key = key.as_ref()?;
5139        // The LRU is the authoritative multi-slot store; it always contains the
5140        // most-recently-stored eval too (kept in sync by `store_outer_eval`), so
5141        // a single LRU probe subsumes the old single-slot fast path while also
5142        // serving revisited (non-immediate) rho-points. `get` is a tiny linear
5143        // scan (capacity is `OUTER_EVAL_LRU_CAPACITY`) that promotes the hit to
5144        // most-recently-used; hence the write lock.
5145        self.outer_eval_lru.write().unwrap().get(key)
5146    }
5147
5148    pub(crate) fn store_outer_eval(&self, key: &Option<Vec<u64>>, eval: &OuterEval) {
5149        if let Some(key) = key.clone() {
5150            // Keep the single-slot mirror for `previous_outer_gradient_norm`,
5151            // whose "immediately previous distinct eval" contract reads it
5152            // directly and must stay byte-for-byte unchanged.
5153            *self.current_outer_eval.write().unwrap() = Some((key.clone(), eval.clone()));
5154            self.outer_eval_lru.write().unwrap().insert(key, eval.clone());
5155        }
5156    }
5157
5158    pub(crate) fn invalidate_eval_bundle(&self) {
5159        self.current_eval_bundle.write().unwrap().take();
5160        self.current_outer_eval.write().unwrap().take();
5161        self.outer_eval_lru.write().unwrap().clear();
5162    }
5163
5164    pub(crate) fn clear_eval_and_factor_caches(&self) {
5165        self.invalidate_eval_bundle();
5166        self.penalty_subspace_cache.write().unwrap().clear();
5167    }
5168}
5169
5170/// Reusable scratch/runtime memory that should not be part of mathematical
5171/// state invariants.
5172pub(crate) struct RemlArena {
5173    pub(crate) cost_eval_count: RwLock<u64>,
5174    /// Number of *actual* full-n inner P-IRLS solves performed (#1575).
5175    ///
5176    /// Distinct from `cost_eval_count`, which counts every outer cost/gradient
5177    /// REQUEST including single-slot cache hits and prior short-circuits. This
5178    /// counts only the cache-missing `prepare_eval_bundlewithkey` calls — i.e.
5179    /// the genuinely expensive `O(n·p²)` inner solves the #1575 slowdown is
5180    /// about ("~150 outer cost evals each running a full n-sized P-IRLS"). A
5181    /// healthy warm-started fit performs roughly 2 inner solves per outer
5182    /// cost-eval (one value, one gradient/Hessian), so a large ratio between
5183    /// the two signals broken warm-starting or duplicate solving. This is pure
5184    /// observability: it never feeds back into the optimization and changes no
5185    /// fitted value.
5186    pub(crate) inner_pirls_solve_count: AtomicU64,
5187    pub(crate) lastgradient_used_stochastic_fallback: AtomicBool,
5188}
5189
5190impl RemlArena {
5191    pub(crate) fn new() -> Self {
5192        Self {
5193            cost_eval_count: RwLock::new(0),
5194            inner_pirls_solve_count: AtomicU64::new(0),
5195            lastgradient_used_stochastic_fallback: AtomicBool::new(false),
5196        }
5197    }
5198}
5199
/// ALO robustness quantities frozen for a REML surface once the
/// high-leverage ALO objective activates, so the cost and its analytic
/// gradient see the same fixed-weight objective.
pub(crate) struct AloFrozenNuisance {
    /// Observation count recorded alongside the frozen weights.
    /// NOTE(review): presumably used to detect a mismatch with the current
    /// data size — confirm with the reader of this field.
    pub(crate) n_obs: usize,
    /// Frozen PSIS influence scale. Presumably one entry per observation
    /// (TODO confirm `influence_scale.len() == n_obs`).
    pub(crate) influence_scale: Vec<f64>,
    /// Frozen scalar captured at freeze time. NOTE(review): assumed to be the
    /// dispersion φ — confirm with the writer of this struct.
    pub(crate) phi: f64,
}
5205
/// Evaluation state for one REML/LAML smoothing-parameter fit.
///
/// The struct bundles several groups of data:
/// * the borrowed data vectors (`y`, `weights`) and the design;
/// * the canonical penalties and the configuration;
/// * the caches and warm-start state that the outer optimizer reads and
///   updates across evaluations.
///
/// Mutable state is held behind `RwLock`s, atomics and `OnceLock`s (interior
/// mutability).
pub(crate) struct RemlState<'a> {
    /// Response vector, borrowed for the lifetime `'a`.
    pub(crate) y: ArrayView1<'a, f64>,
    /// Model design matrix.
    pub(crate) x: DesignMatrix,
    /// Per-observation prior weights, borrowed for the lifetime `'a`.
    pub(crate) weights: ArrayView1<'a, f64>,
    /// Per-observation offset (the Gaussian fixed cache works with
    /// `y − offset`).
    pub(crate) offset: Array1<f64>,
    /// Canonicalized block-local penalties with pre-computed roots.
    /// This is the single canonical penalty representation — no full-width
    /// `rank × p` roots are stored separately.
    pub(crate) canonical_penalties: Arc<Vec<gam_terms::construction::CanonicalPenalty>>,
    /// NOTE(review): presumably a balanced (rescaled) root of the combined
    /// penalty — confirm with the constructor.
    pub(crate) balanced_penalty_root: Array2<f64>,
    pub(crate) reparam_invariant: ReparamInvariant,
    /// NOTE(review): presumably derived via
    /// `sparse_penalty_block_count_from_canonical`; `None` when not applicable.
    pub(crate) sparse_penalty_block_count: Option<usize>,
    /// Number of model coefficients (presumably the design's column count —
    /// TODO confirm).
    pub(crate) p: usize,
    /// Shared REML configuration.
    pub(crate) config: Arc<RemlConfig>,
    /// Runtime link state for mixture links, when present.
    pub(crate) runtime_mixture_link_state: Option<gam_problem::MixtureLinkState>,
    /// Runtime link state for SAS links, when present.
    pub(crate) runtime_sas_link_state: Option<SasLinkState>,
    /// Penalty null-space dimensions (presumably one entry per penalty —
    /// TODO confirm).
    pub(crate) nullspace_dims: Vec<usize>,
    /// Optional per-coefficient lower bounds.
    pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
    /// Optional linear inequality constraints on the coefficients.
    pub(crate) linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
    /// Relative shrinkage floor for penalized block eigenvalues (rho-independent).
    pub(crate) penalty_shrinkage_floor: Option<f64>,
    /// Explicit prior on log smoothing parameters used by the REML/LAML objective.
    pub(crate) rho_prior: gam_problem::RhoPrior,

    pub(crate) cache_manager: EvalCacheManager,
    /// Runtime counters and flags kept outside the mathematical state.
    pub(crate) arena: RemlArena,
    /// `β(ρ_k)`: coefficients from the most recent fit, used to seed the next
    /// inner solve.
    pub(crate) warm_start_beta: RwLock<Option<Coefficients>>,
    /// Two-point ρ-trajectory used for secant (tangent-line) warm-start
    /// extrapolation: when the outer optimizer asks for a fit at a new
    /// ρ, we have `β(ρ_k)` (in `warm_start_beta`) and `β(ρ_{k-1})` (in
    /// `prev_warm_start_beta`). The implicit β(ρ) trajectory is locally
    /// linear under the FOC ∇F(β,ρ)=0, so a tangent-line prediction
    /// `β_predict(ρ_new) = β_k + α · (β_k − β_{k-1})` where α is the
    /// projection of `(ρ_new − ρ_k)` onto `(ρ_k − ρ_{k-1})` gives a
    /// better seed than the flat `β_k` alone — replacing PIRLS warm-
    /// start "use last β as-is" with a real tangent-prediction step.
    pub(crate) warm_start_rho: RwLock<Option<Array1<f64>>>,
    /// `β(ρ_{k-1})`, the older point of the two-point warm-start trajectory.
    pub(crate) prev_warm_start_beta: RwLock<Option<Coefficients>>,
    /// `ρ_{k-1}`, paired with `prev_warm_start_beta`.
    pub(crate) prev_warm_start_rho: RwLock<Option<Array1<f64>>>,
    /// Toggle for in-memory warm-starting.
    pub(crate) warm_start_enabled: AtomicBool,
    /// Inner-PIRLS iteration cap used during seed screening. While active it
    /// also suppresses cache writes, warm-start updates and KKT enforcement
    /// (contrast `outer_inner_cap`).
    pub(crate) screening_max_inner_iterations: Arc<AtomicUsize>,
    /// Outer-aware inner-PIRLS iteration cap for the main descent loop.
    ///
    /// Distinct from `screening_max_inner_iterations`, which is used during
    /// seed selection and toggles a side-effect bundle (cache writes,
    /// warm-start updates, KKT enforcement all suppressed). This atomic is
    /// purely a cap — when nonzero, the inner Newton loop is capped at
    /// `min(this, full_max_iterations)`, but cache writes and warm-start
    /// updates remain enabled. Driven by the outer optimizer to coarsen
    /// inner solves at early outer iterations when ρ is far from converged,
    /// and lifted back to full at the final accepted iter (otherwise the
    /// returned β would be biased by the loose cap).
    ///
    /// Both atomics are honored together as `min(screening_cap, outer_cap)`
    /// when both are nonzero. Default 0 (no cap from this source).
    pub(crate) outer_inner_cap: Arc<AtomicUsize>,

    /// Inner-PIRLS feedback signal driven by `execute_pirls_if_needed` after
    /// each NON-screening solve. Stores the iteration count at which the
    /// inner Newton stopped, plus a flag indicating whether it converged
    /// (vs. hit the iteration cap). The outer first-/second-order bridges
    /// read these atomics to drive an adaptive `inner_cap_schedule`: the
    /// next outer iter's inner cap becomes `last_iters + small_margin`
    /// when the previous solve converged, or a geometric backoff when it
    /// hit the cap. This replaces the older hardcoded iter-tier schedule
    /// (3/5/10/20) with a cap that follows the inner solver's actual
    /// convergence behavior — Eisenstat-Walker style for the inner
    /// quadratic loop. Default 0 / false (no signal yet — first outer
    /// iter falls back to a coarse iter-count tier).
    pub(crate) last_inner_iters: Arc<AtomicUsize>,
    pub(crate) last_inner_converged: Arc<AtomicBool>,

    /// Cached state from the most recent successful PIRLS solve, used by
    /// the IFT-based warm-start predictor.
    ///
    /// The implicit-function theorem applied to the FOC ∇_β F(β,ρ)=0
    /// gives `dβ/dρ_k = -H_pen^{-1} · (e^{ρ_k} · S_k · β)`. A first-order
    /// Taylor predictor reads
    /// `β_predict(ρ_new) = β_cur − Σ_k Δρ_k · H_pen^{-1} · (e^{ρ_cur_k} · S_k · β_cur)`.
    /// This is a strict superset of the tangent-line predictor's
    /// requirements: works after a single successful solve (tangent-line
    /// needs two prior fits), and gives the EXACT first-order Jacobian
    /// of the implicit β(ρ) trajectory rather than a secant proxy along one
    /// ρ-direction.
    ///
    /// Populated in `updatewarm_start_from` when PIRLS converges; cleared
    /// on failure, on `reset_surface`, and on link-state changes.
    pub(crate) ift_warm_start_cache: RwLock<Option<IftWarmStartCache>>,

    /// Persisted Levenberg-Marquardt damping coefficient from the most
    /// recent successful PIRLS solve, bit-packed into an `AtomicU64`
    /// via `f64::to_bits`. Read at the start of
    /// `execute_pirls_if_needed` and written into the
    /// `PirlsConfig::initial_lm_lambda` hint so the inner Newton seeds
    /// `λ_LM` near the damping the previous solve discovered, instead
    /// of cold-starting at `1e-6` and burning 4-6 halving steps to
    /// recover. `0` (the default) signals "no hint"; the inner solver
    /// clamps any positive hint into `[1e-6, 1e-3]` so a stale value
    /// cannot destabilize the next solve. Reset on `reset_surface` and
    /// on failed solves.
    pub(crate) last_pirls_lm_lambda: Arc<AtomicU64>,

    /// Negative-Binomial overdispersion `theta` frozen for the smoothing-
    /// parameter (λ) search (#1082), bit-packed `f64` (`f64::to_bits`). `0`
    /// (the default) signals "not yet frozen". On the first non-screening
    /// λ-search inner solve of an estimated-θ NB fit, the seed's
    /// maximum-likelihood θ is computed once and stored here; every subsequent
    /// λ-search evaluation pins the inner solve to this value via
    /// `GlmLikelihoodSpec::with_negbin_theta_frozen_for_search`, so the REML
    /// criterion `F(ρ) = REML(ρ, θ_frozen)` is a stationary function of ρ and
    /// the outer optimizer converges instead of chasing the per-eval θ drift
    /// that the estimated path injects. The single final reported fit still
    /// ML-refreshes θ at the converged η. Reset on `reset_surface`.
    pub(crate) frozen_negbin_theta: Arc<AtomicU64>,

    /// Tweedie exponential-dispersion `phi` frozen for the smoothing-parameter
    /// (λ) search (#1477), bit-packed `f64` (`f64::to_bits`). `0` (the default)
    /// signals "not yet frozen". On the first non-screening λ-search inner solve
    /// of an estimated-φ Tweedie fit, the seed's Pearson `phî` is captured once
    /// and stored here; every subsequent λ-search evaluation pins the inner
    /// solve to this value via
    /// `GlmLikelihoodSpec::with_tweedie_phi_frozen_for_search`, so the REML
    /// criterion `F(ρ) = REML(ρ, φ_frozen)` is a stationary function of ρ. The
    /// Tweedie LAML omits the `phi`-dependent saddlepoint normalizer, so a `phi`
    /// drifting with each warm-start η lets the criterion reward dispersion
    /// inflation and rail a double-penalty null-space `λ` to the box bound (the
    /// #1477 boundary blow-up). The single final reported fit still
    /// Pearson-refreshes `phi` at the converged η. Reset on `reset_surface`.
    pub(crate) frozen_tweedie_phi: Arc<AtomicU64>,

    /// Gamma shape `k = 1/φ` frozen for the smoothing-parameter (λ) search
    /// (#1074), bit-packed `f64` (`f64::to_bits`). `0` (the default) signals
    /// "not yet frozen". On the first non-screening λ-search inner solve of an
    /// estimated-shape Gamma fit, the seed's converged-η MLE `k̂` is captured
    /// once and stored here; every subsequent λ-search evaluation pins the inner
    /// solve to this value via
    /// `GlmLikelihoodSpec::with_gamma_shape_frozen_for_search`, so the REML
    /// criterion `F(ρ) = REML(ρ, k_frozen)` is a stationary function of ρ. With
    /// `k` estimated the inner solver re-derives it from each warm-start η, and
    /// because the Gamma working weight is `W = prior·k` and the
    /// omitting-constants log-likelihood is `−k·½D`, a `k` swinging with η makes
    /// BOTH the curvature `H = k·XᵀX + λS` and the data-fit `k·½D` jump with ρ —
    /// the criterion grows deterministic spikes that floor the projected
    /// gradient and rail `λ` to the over-smoothed corner (the #1074 te/Gamma
    /// tensor under-recovery). The single final reported fit still ML-refreshes
    /// `k` at the converged η. Reset on `reset_surface`.
    pub(crate) frozen_gamma_shape: Arc<AtomicU64>,

    /// Last observed IFT-prediction residual (`‖β_converged − β_predicted‖
    /// / ‖β_converged‖`) from the most recent non-screening solve where
    /// the predictor was actually consumed. Bit-packed `f64` (via
    /// `f64::to_bits`).
    ///
    /// "No signal yet" is encoded as a NaN bit-pattern
    /// (`IFT_RESIDUAL_NO_SIGNAL_BITS`). The original `0` sentinel
    /// collided with `f64::to_bits(0.0) == 0` — a true residual of
    /// exactly 0 (degenerate but mathematically possible if every
    /// β_predicted_i matched β_converged_i to bit-equality) would
    /// have been indistinguishable from "predictor never reported".
    /// NaN's self-inequality makes the sentinel unambiguous: any
    /// stored finite non-negative value is genuine signal.
    ///
    /// Read by `predict_warm_start_beta_ift_with_outcome` to drive the adaptive
    /// |Δρ| cap (`adaptive_ift_max_drho`): a small residual loosens
    /// the cap, a large one tightens it. Replaces the previous
    /// hardcoded `IFT_WARM_START_MAX_DRHO = 2.0` constant with a
    /// data-driven policy, so the predictor adapts to the empirical
    /// faithfulness of the linearization at this surface's scale.
    /// Reset on `reset_surface` and on failed solves.
    pub(crate) last_ift_prediction_residual: Arc<AtomicU64>,

    /// Last observed gain ratio of the accepted LM step
    /// (`actual_reduction / predicted_reduction`) from the most recent
    /// non-screening PIRLS solve. Bit-packed `f64` with the same NaN
    /// sentinel discipline as `last_ift_prediction_residual`: NaN bits
    /// (`IFT_RESIDUAL_NO_SIGNAL_BITS`) encode "no signal yet" so a
    /// recorded ratio of exactly 0 (degenerate but possible) doesn't
    /// collide with the no-signal token.
    ///
    /// Used by `first_order_inner_cap_schedule` as a third quality
    /// signal alongside `last_iters` and `last_converged`. A small
    /// `accept_rho` (model overstating predicted reduction) is a hint
    /// the next iter's inner Newton may need extra margin even when
    /// the previous solve converged in few iters. Reset on
    /// `reset_surface` and on failed solves.
    pub(crate) last_pirls_accept_rho: Arc<AtomicU64>,

    /// Cached Cholesky factorization of `IftWarmStartCache::penalized_hessian_transformed`.
    /// Lazily computed on the first IFT predict call after a fresh
    /// `updatewarm_start_from`, then reused by every subsequent
    /// predict call until the IFT cache is invalidated. At large
    /// scale, where p can reach several thousand, the dense Cholesky
    /// is O(p³)/3 — multiple seconds per refactor — so caching saves
    /// real wall time across the typical 5-10 IFT predict calls per
    /// outer fit. Reset jointly with `ift_warm_start_cache` (on
    /// reset_surface, on link-state changes, on failed PIRLS solves,
    /// and whenever a new H_pen replaces the cached one).
    pub(crate) ift_cached_factor: RwLock<Option<Arc<dyn gam_linalg::matrix::FactorizedSystem>>>,

    /// When set, the penalties have Kronecker (tensor-product) structure and
    /// the REML evaluator can use O(∏q_j) logdet instead of O(p³) eigendecomposition.
    /// Populated via `set_kronecker_penalty_system` after construction.
    pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
    /// Full Kronecker factored basis (marginal designs + penalties + dims).
    /// Used by P-IRLS for factored reparameterization.
    pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,

    /// Precomputed `(XᵀWX, XᵀW(y − offset))` for the Gaussian + Identity
    /// outer REML loop, populated once before the outer optimizer when the
    /// family / link / constraint preconditions hold and the design supports
    /// the Identity short-circuit at `pirls.rs:6237`. When present, each
    /// inner `solve_penalized_least_squares_implicit` reads these matrices
    /// instead of restreaming the O(N·p²) GEMM and O(N·p) matvec per outer
    /// iteration — the penalty `λ·S` is still added per-λ.
    ///
    /// Invalidated jointly with the design in `reset_surface`.
    pub(crate) gaussian_fixed_cache: RwLock<Option<Arc<crate::pirls::GaussianFixedCache>>>,
    /// Conditioned-frame exact ψ-derivatives `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
    /// for the SINGLE design-moving spatial hyperparameter (#1033b), assembled
    /// n-free from the certified Chebyshev ψ-Gram tensor and installed beside
    /// `gaussian_fixed_cache` at the same in-window trial. When present the
    /// Gaussian-identity ψ-gradient HyperCoord (`a_j`, `g_j`, dense `B_j`) is
    /// formed from these k×k objects instead of realizing and contracting the
    /// n×k ∂X/∂ψ slab — retiring the second per-trial n-pass. Lives in the
    /// SAME conditioned column frame as `gaussian_fixed_cache.xtwx_orig`, so
    /// the hyper-coord builder transforms it by the per-eval Qs/free-basis the
    /// same way it transforms the streamed Gram. Invalidated with the design.
    pub(crate) gaussian_psi_gram_deriv:
        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
    /// Conditioned-frame exact ψ-derivative pair `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
    /// for the SINGLE design-moving spatial hyperparameter in the GLM (frozen-W)
    /// lane (#1033 / #1111), assembled n-free from
    /// [`crate::glm_sufficient_lane::FrozenWeightGramTensor::gradient_pair_if_sound`]
    /// and installed beside `glm_first_step_gram` at the same in-window
    /// drift-OK trial. When present, the GLM ψ-gradient HyperCoord serves its
    /// envelope `a_j` and score `g_j` from these k×k objects instead of
    /// realizing and contracting the n×k ∂X/∂ψ slab — the second per-trial
    /// n-pass. Unlike the Gaussian lane the Hessian curvature `B_j` is NOT
    /// served from the tensor: for a GLM the per-trial `B_j` term
    /// `X_τᵀWX + XᵀWX_τ` is irreducibly n-dependent (the moving working weight
    /// `W` does not factor out of a frozen-W k×k object), so `B_j` keeps the
    /// exact streamed slab (#1033). Lives in the SAME conditioned column frame
    /// as `glm_first_step_gram` / `gaussian_fixed_cache.xtwx_orig`, so the
    /// hyper-coord builder transforms it by the per-eval Qs/free-basis the same
    /// way. NOT family-gated (the GLM lane's own slot). Invalidated with the
    /// design.
    pub(crate) glm_psi_gram_deriv:
        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
    /// Frozen-weight first-Fisher-step data-fit Gram `XᵀWX` for the GLM
    /// design-moving ψ-sweep (#1111 / #1033 mechanism (c)), in the conditioned
    /// (original / `x_fit`) column frame — the SAME frame as
    /// `gaussian_fixed_cache.xtwx_orig`.
    ///
    /// Assembled n-free per in-window ψ-trial from the certified frozen-weight
    /// Chebyshev tensor ([`crate::glm_sufficient_lane::FrozenWeightGramTensor`])
    /// and installed only when the trial's converged working weight has not
    /// drifted past tolerance from the frozen snapshot. When present, the GLM
    /// inner P-IRLS serves its FIRST Fisher-scoring iteration's `XᵀWX` from this
    /// cache instead of restreaming the O(N·p²) weighted cross-product — the
    /// dominant per-trial n-term in a large-n Poisson/Binomial κ-sweep. The
    /// penalty `Sλ` is still added per-λ on top, and every subsequent inner
    /// iteration restreams the true (moving) `W`, so the converged β̂ is
    /// unchanged; only the first-iteration Gram build is elided. Unlike
    /// `gaussian_fixed_cache` this is NOT family-gated — it is the GLM lane's
    /// own slot, consumed once per inner solve. Invalidated with the design in
    /// `reset_surface`.
    pub(crate) glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
    /// Previous successful non-Gaussian fixed-design data-fit Gram `XᵀWX` in
    /// the conditioned original frame, keyed to `warm_start_beta`.
    ///
    /// When the next outer trial uses a flat warm start, its first PIRLS
    /// curvature build evaluates at the same `η = Xβ` as the previous converged
    /// solve, so the Hessian weights and `XᵀWX` are identical. Reusing this
    /// original-frame Gram skips one dense `O(n·p²)` pass per warm-started
    /// trial while still letting later PIRLS iterations restream the moving
    /// weights. IFT/tangent-predicted starts do not consume this cache.
    pub(crate) flat_glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
    /// Frozen ALO robustness weights for this REML surface.
    ///
    /// The PSIS influence scale is a non-smooth function of the current hat
    /// diagonals. Once the high-leverage ALO objective activates, it is frozen
    /// for the current surface so the analytic gradient differentiates the
    /// same fixed-weight objective the cost evaluates.
    pub(crate) alo_frozen_nuisance: RwLock<Option<AloFrozenNuisance>>,

    /// ρ-independent certificate that the Gaussian-identity ALO-stabilization
    /// augmentation can never activate on this surface (#1689).
    ///
    /// The augmentation engages only when some row's *penalized* leverage
    /// `h_i = w_i · xᵢᵀ H_λ⁻¹ xᵢ` reaches `ALO_MAX_LEVERAGE_THRESHOLD`, where
    /// `H_λ = XᵀWX + S_λ + ridge·I ⪰ XᵀWX`. Because `S_λ + ridge·I ⪰ 0` we have
    /// `H_λ⁻¹ ⪯ (XᵀWX)⁻¹`, so `h_i ≤ w_i · xᵢᵀ (XᵀWX)⁻¹ xᵢ` — the *unpenalized*
    /// weighted hat diagonal, which (for Gaussian identity, where W is the fixed
    /// prior-weight diagonal) is independent of ρ. If the max of that bound is
    /// below the activation threshold, no ρ can trip the gate, so the entire
    /// per-outer-evaluation O(n·p²) ALO leverage diagnostic — recomputed and
    /// discarded on every cost/gradient eval otherwise — is skipped.
    ///
    /// `Some(true)`  → provably inactive everywhere, skip the diagnostic.
    /// `Some(false)` → bound is ≥ threshold or XᵀWX is rank-deficient/ill-
    ///                 conditioned (bound not certifiable); fall through to the
    ///                 exact per-eval gate. Computed lazily, at most once.
    pub(crate) alo_provably_inactive: RwLock<Option<bool>>,

    /// Stable disk-cache key for the current realized REML surface. Computed
    /// lazily because it hashes the row-chunked design and data vectors.
    pub(crate) persistent_warm_start_key: RwLock<Option<String>>,
    /// NOTE(review): presumably a fingerprint identifying the latent-values
    /// inputs for `persistent_latent_values_cache` — confirm with the writer.
    pub(crate) persistent_latent_values_fingerprint: Option<u64>,
    /// In-memory LRU of latent-value matrices keyed by string (default
    /// capacity `PERSISTENT_LATENT_VALUES_CACHE_CAPACITY`).
    pub(crate) persistent_latent_values_cache: RwLock<PersistentLatentValuesCache>,
    /// NOTE(review): presumably a fingerprint of the analytic penalty registry
    /// used when keying cached state — confirm with the writer.
    pub(crate) analytic_penalty_registry_fingerprint: u64,
    /// Ensures the process attempts at most one disk restore per surface.
    pub(crate) persistent_warm_start_loaded: AtomicBool,
    /// Scoped counter disabling disk writes from cost-only posterior/probe
    /// evaluations. In-memory warm starts still update; only JSON/bin
    /// persistence and eviction sweeps are suppressed.
    pub(crate) persistent_warm_start_store_suppression: AtomicUsize,
    /// Scoped counter disabling the Gaussian-identity ALO-stabilization
    /// augmentation (#979). The leverage barrier `Σ_i (h_i − τ)₊²` is an OUTER
    /// OPTIMIZER aid (#813/#821) that keeps the smoothing-parameter search off
    /// pathological high-leverage λ regions. The marginal smoothing-parameter
    /// posterior `π(ρ|y) ∝ exp(−LAML(ρ))` (#938) is a property of the genuine
    /// model criterion, sampled against a Laplace proposal built from the BASE
    /// REML Hessian, so the certificate / NUTS evaluations suppress the
    /// augmentation (see `without_alo_stabilization`) — both for proposal↔target
    /// consistency and to drop the per-leapfrog ALO diagnostic suite.
    pub(crate) alo_stabilization_suppression: AtomicUsize,
    /// Whether the cross-process ON-DISK warm-start layer is engaged at all.
    ///
    /// Default `false`: the optimizer's IN-MEMORY warm start (the actual
    /// speed lever) is always on, but the disk checkpoint — `load_record`
    /// at fit start and `store_record` at finalize, each of which opens the
    /// shared `WarmStartStore` and pays an eviction/dir scan that is O(cache
    /// entries) on a network filesystem — is skipped. Disk persistence has
    /// reuse value only ACROSS processes or across repeated identical fits;
    /// a single in-process fit (and a fortiori a loop of distinct throwaway
    /// fits, e.g. CI-coverage replicates each on different data, #1082/#1114)
    /// gets zero benefit from it and pays the per-fit open/scan/save in full.
    /// `FitConfig::persist_warm_start_disk` flips this to `true` only when the
    /// caller explicitly asks for cross-process / repeat-fit persistence.
    pub(crate) persistent_warm_start_disk_enabled: AtomicBool,
    /// #1033: memoized fit-invariant O(n) response/weight scalars.
    ///
    /// `gaussian_weight_log_sum_half` (`½·Σ log wᵢ`) and `gaussian_dp_floor_scale`
    /// (the weighted null deviance `D₀`) are pure functions of the borrowed
    /// `(y, weights)` — fields `reset_surface` NEVER reassigns — so they are
    /// constant for the whole life of the `RemlState`. They were the last O(n)
    /// passes the n-free κ outer loop ran on EVERY `assemble_and_evaluate`
    /// (each an n-length scalar reduction, no k factor). Memoizing them once per
    /// fit makes the per-trial eval touch only k-dim objects, completing the
    /// issue's sufficient-statistic invariant literally. Plain scalar closures,
    /// no rayon inside the `get_or_init` (no deadlock trap).
    pub(crate) gaussian_weight_log_sum_half_cache: std::sync::OnceLock<f64>,
    pub(crate) gaussian_dp_floor_scale_cache: std::sync::OnceLock<f64>,
}