Skip to main content

gam_solve/reml/
mod.rs

1use self::inner_strategy::GeometryBackendKind;
2use super::*;
3use gam_linalg::sparse_exact::SparseExactFactor;
4use crate::pirls::PIRLS_CACHE_BYTE_BUDGET;
5use crate::pirls::assemble_and_factor_sparse_penalized_system;
6use gam_terms::basis::LocalDesignJacobianProvider;
7use gam_problem::SasLinkState;
8use gam_problem::OuterEval;
9use ndarray::{Array1, Array2, s};
10use std::collections::{HashMap, VecDeque};
11use std::ops::Range;
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
13use std::sync::{Arc, RwLock};
14
15pub mod assembly;
16pub mod atoms;
17pub(crate) mod continuation;
18pub(crate) mod eval;
19mod firth;
20pub(super) mod hyper;
21mod inner_strategy;
22// #1521 carve: promoted `pub(crate)` -> `pub` so the extracted
23// `gam-custom-family` crate reaches the Jeffreys-subspace items it consumes.
24pub mod jeffreys_subspace;
25pub mod outer_eval;
26pub mod penalty_logdet;
27pub mod per_atom_efs;
28pub mod reml_outer_engine;
29mod rho_key;
30mod sparse_exact_penalty;
31mod trace;
32
33pub(crate) use sparse_exact_penalty::sparse_penalty_block_count_from_canonical;
34
/// Byte budget (512 MiB) recorded as `budget_bytes` on the policy returned by
/// `exact_tau_tau_hessian_policy_with_firth`.
pub(crate) const EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES: usize = 512 << 20;
/// Largest observation count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_OBSERVATIONS: usize = 20_000;
/// Largest coefficient count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_COEFFICIENTS: usize = 256;
/// Ceiling on `n_obs * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_LINEAR_WORK: usize = 2_000_000;
/// Ceiling on `n_obs * p_coeff * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_QUADRATIC_WORK: usize = 100_000_000;
/// Default entry capacity of `PersistentLatentValuesCache`.
pub(crate) const PERSISTENT_LATENT_VALUES_CACHE_CAPACITY: usize = 8;
41
42#[derive(Debug)]
43pub(crate) struct PersistentLatentValuesCache {
44    pub(crate) entries: HashMap<String, Array2<f64>>,
45    pub(crate) lru: VecDeque<String>,
46    pub(crate) capacity: usize,
47}
48
49impl Default for PersistentLatentValuesCache {
50    fn default() -> Self {
51        Self {
52            entries: HashMap::new(),
53            lru: VecDeque::new(),
54            capacity: PERSISTENT_LATENT_VALUES_CACHE_CAPACITY,
55        }
56    }
57}
58
59impl PersistentLatentValuesCache {
60    pub(crate) fn lookup(
61        &mut self,
62        key: &str,
63        n_obs: usize,
64        latent_dim: usize,
65    ) -> Option<Array2<f64>> {
66        let values = self.entries.get(key)?;
67        if values.dim() != (n_obs, latent_dim) {
68            return None;
69        }
70        let values = values.clone();
71        self.touch(key.to_string());
72        Some(values)
73    }
74
75    pub(crate) fn insert(&mut self, key: String, values: Array2<f64>) {
76        if values.iter().any(|value| !value.is_finite()) {
77            return;
78        }
79        self.entries.insert(key.clone(), values);
80        self.touch(key);
81        while self.entries.len() > self.capacity {
82            let Some(evicted) = self.lru.pop_front() else {
83                break;
84            };
85            self.entries.remove(&evicted);
86        }
87    }
88
89    pub(crate) fn touch(&mut self, key: String) {
90        if let Some(index) = self.lru.iter().position(|queued| queued == &key) {
91            self.lru.remove(index);
92        }
93        self.lru.push_back(key);
94    }
95}
96
97/// Cached state from the most recent successful PIRLS solve, populated by
98/// `updatewarm_start_from` and consumed by the IFT-based warm-start
99/// predictor (`RemlState::predict_warm_start_beta_ift_with_outcome`).
100/// See the field doc on `RemlState::ift_warm_start_cache` for the math.
101#[derive(Clone)]
102pub(crate) struct IftWarmStartCache {
103    /// β at the converged solve, in ORIGINAL basis. Mirror of
104    /// `warm_start_beta` stashed alongside the H factor for atomic
105    /// consistency under concurrent reads (the predictor needs both
106    /// β and H from the SAME solve; reading them from two locks risks
107    /// a torn pair if a fresh solve lands between reads).
108    pub beta_original: ndarray::Array1<f64>,
109    /// ρ at which the solve occurred. Mirror of `warm_start_rho`,
110    /// stashed for the same atomic-consistency reason.
111    pub rho: ndarray::Array1<f64>,
112    /// Penalized Hessian H_pen at the converged β, in TRANSFORMED basis.
113    /// The IFT predictor factors this on demand; basis transforms run in
114    /// transformed basis for numerical stability.
115    pub penalized_hessian_transformed: gam_linalg::matrix::SymmetricMatrix,
116    /// Reparameterization matrix qs converting between transformed
117    /// (column) basis and original basis: `β_orig = qs · β_tfd`,
118    /// `H_orig = qs · H_tfd · qs^T`.
119    pub qs: ndarray::Array2<f64>,
120    /// True when the PIRLS result was already in original basis
121    /// (`OriginalSparseNative`) — in which case `qs` is the identity
122    /// and the IFT predictor can skip the basis-conversion ops.
123    pub frame_was_original: bool,
124    /// Per-penalty precomputation `S_k · β_cur[cp.col_range]`,
125    /// indexed in lockstep with `RemlObjectiveState::canonical_penalties`.
126    /// Each entry is the local-block mat-vec the IFT predictor would
127    /// otherwise recompute on every predict call. With H_pen factor
128    /// caching (commit ec18559d) the per-call cost dropped from
129    /// `O(p³)` Cholesky to `O(p²) ≈ k · O(block²)` rhs construction;
130    /// at large-scale CTN (p ≈ several thousand) that's several ms
131    /// per predict call still being paid. By stashing `S_k · β_cur`
132    /// at cache-write time the predictor's per-call work drops to
133    /// just the `Δρ_k · e^{ρ_k} · sb_block` accumulation, which is
134    /// `O(p)` rather than `O(p²)`.
135    ///
136    /// `None` when the cache predates this commit's writer hook (e.g.,
137    /// transient state during invalidation); the predictor falls back
138    /// to recomputing the mat-vec when this is `None` or the length
139    /// mismatches `canonical_penalties.len()`.
140    pub lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>>,
141}
142
/// Byte estimates, split by category, for one τ-τ evaluation plan.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) struct TauTauPlanEstimate {
    pub(crate) dense_x_bytes: usize,
    pub(crate) first_order_tau_bytes: usize,
    pub(crate) second_order_tau_bytes: usize,
    pub(crate) penalty_first_bytes: usize,
    pub(crate) penalty_pair_bytes: usize,
    pub(crate) rho_tau_penalty_bytes: usize,
    pub(crate) vector_cache_bytes: usize,
    pub(crate) weighted_scratch_bytes: usize,
}

impl TauTauPlanEstimate {
    /// Adds up every category, clamping at `usize::MAX` rather than
    /// overflowing.
    pub(crate) fn total_bytes(self) -> usize {
        // Exhaustive destructuring: a newly added field must be handled here.
        let Self {
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        } = self;
        [
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        ]
        .iter()
        .fold(0usize, |running, &part| running.saturating_add(part))
    }
}
167
168#[derive(Clone, Copy, Debug, PartialEq, Eq)]
169pub(crate) struct TauTauHessianPolicy {
170    pub(crate) any_has_implicit: bool,
171    pub(crate) implicit_multidim_duchon: bool,
172    pub(crate) estimated_dense_tau_cache_bytes: usize,
173    pub(crate) gradient_plan: TauTauPlanEstimate,
174    pub(crate) hessian_plan: TauTauPlanEstimate,
175    pub(crate) budget_bytes: usize,
176    pub(crate) firth_pair_terms_unavailable: bool,
177}
178
179impl TauTauHessianPolicy {
180    /// True when the τ-τ exact-Hessian path cannot be assembled at all and the
181    /// eval must fall back to value-and-gradient mode (forcing
182    /// `HessianResult::Unavailable`).
183    ///
184    /// This is the *only* remaining capability gate: the previous
185    /// implementation also forced gradient-only when the design used implicit
186    /// multi-dim Duchon storage or when the dense τ-cache plan would exceed
187    /// the budget.  Both of those are now *cost* gates, not capability gates
188    /// — the unified evaluator's `prefer_outer_hessian_operator(n, p, k)`
189    /// selects the matrix-free `HessianResult::Operator` representation in
190    /// exactly the regimes where the dense cache would be unaffordable, and
191    /// the planner routes operator returns through `run_operator_trust_region`
192    /// (or basis-probes them when `dim ≤ OUTER_HVP_MATERIALIZE_MAX_DIM`).
193    /// Forcing gradient-only would have prevented the operator representation
194    /// from ever being requested, defeating that routing; hence the
195    /// `implicit_multidim_duchon` and cost-bytes clauses are deliberately
196    /// gone.
197    ///
198    /// Firth-pair-terms-unavailability remains a capability gate: when the
199    /// Firth-aware derivative provider cannot produce the τ-τ pair
200    /// corrections at all, no representation choice can substitute.  At every
201    /// production call site this flag is hardcoded `false` (the
202    /// `hphi_tau_tau_partial_apply` + `d_beta_hphi_tau_partial_apply`
203    /// primitives now cover the gap), so this method effectively returns
204    /// `false` in production.  We retain the field and method signature
205    /// unchanged so future Firth corner cases have a single, surfaced place
206    /// to land.
207    pub(crate) fn prefer_gradient_only(self) -> bool {
208        self.firth_pair_terms_unavailable
209    }
210}
211
212pub(crate) fn exact_tau_tau_hessian_policy_with_firth(
213    n_obs: usize,
214    p_coeff: usize,
215    hyper_dirs: &[DirectionalHyperParam],
216    firth_pair_terms_unavailable: bool,
217) -> TauTauHessianPolicy {
218    let f64_bytes = std::mem::size_of::<f64>();
219    let dense_matrix_bytes =
220        |rows: usize, cols: usize| -> usize { rows.saturating_mul(cols).saturating_mul(f64_bytes) };
221    let dense_design_bytes = dense_matrix_bytes(n_obs, p_coeff);
222    let dense_penalty_bytes = dense_matrix_bytes(p_coeff, p_coeff);
223    let psi_dim = hyper_dirs.len();
224    let implicit_n_axes = hyper_dirs
225        .iter()
226        .find_map(DirectionalHyperParam::implicit_axis_count_hint)
227        .unwrap_or(0);
228    let gradient_uses_implicit_design = hyper_dirs
229        .iter()
230        .any(DirectionalHyperParam::has_implicit_operator)
231        && gam_terms::basis::should_use_implicit_operators_with_policy(
232            n_obs,
233            p_coeff,
234            implicit_n_axes,
235            &gam_runtime::resource::ResourcePolicy::default_library(),
236        );
237    let dense_first_order_count = hyper_dirs
238        .iter()
239        .filter(|dir| !dir.has_implicit_operator())
240        .count();
241    let first_penalty_component_count = hyper_dirs
242        .iter()
243        .map(DirectionalHyperParam::penalty_first_component_count)
244        .sum::<usize>();
245
246    let mut dense_second_order_count = 0usize;
247    let mut penalty_pair_count = 0usize;
248    for i in 0..psi_dim {
249        for j in i..psi_dim {
250            if hyper_dirs[i]
251                .x_tau_tau_entry_at(j)
252                .or_else(|| hyper_dirs[j].x_tau_tau_entry_at(i))
253                .is_some_and(|entry| !entry.uses_implicit_storage())
254            {
255                dense_second_order_count += if i == j { 1 } else { 2 };
256            }
257            if hyper_dirs[i].has_penaltysecond_pair_at(j)
258                || hyper_dirs[j].has_penaltysecond_pair_at(i)
259            {
260                penalty_pair_count += if i == j { 1 } else { 2 };
261            }
262        }
263    }
264
265    let gradient_dense_first_order_count = if gradient_uses_implicit_design {
266        dense_first_order_count
267    } else {
268        psi_dim
269    };
270    let gradient_needs_dense_x =
271        firth_pair_terms_unavailable || gradient_dense_first_order_count > 0;
272    let gradient_plan = TauTauPlanEstimate {
273        dense_x_bytes: if gradient_needs_dense_x {
274            dense_design_bytes
275        } else {
276            0
277        },
278        first_order_tau_bytes: if gradient_dense_first_order_count > 0 {
279            dense_design_bytes
280        } else {
281            0
282        },
283        second_order_tau_bytes: 0,
284        penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
285        penalty_pair_bytes: 0,
286        rho_tau_penalty_bytes: 0,
287        vector_cache_bytes: n_obs.saturating_mul(f64_bytes),
288        weighted_scratch_bytes: dense_penalty_bytes,
289    };
290    let hessian_plan = TauTauPlanEstimate {
291        dense_x_bytes: if psi_dim > 0 { dense_design_bytes } else { 0 },
292        first_order_tau_bytes: dense_first_order_count.saturating_mul(dense_design_bytes),
293        second_order_tau_bytes: dense_second_order_count.saturating_mul(dense_design_bytes),
294        penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
295        penalty_pair_bytes: penalty_pair_count.saturating_mul(dense_penalty_bytes),
296        rho_tau_penalty_bytes: first_penalty_component_count
297            .saturating_mul(2)
298            .saturating_mul(dense_penalty_bytes),
299        vector_cache_bytes: psi_dim.saturating_mul(n_obs).saturating_mul(f64_bytes),
300        weighted_scratch_bytes: dense_penalty_bytes,
301    };
302    let any_has_implicit = hyper_dirs
303        .iter()
304        .any(DirectionalHyperParam::has_implicit_operator);
305    let implicit_multidim_duchon = hyper_dirs
306        .iter()
307        .any(DirectionalHyperParam::has_implicit_multidim_duchon);
308    let estimated_dense_tau_cache_bytes = hessian_plan
309        .first_order_tau_bytes
310        .saturating_add(hessian_plan.second_order_tau_bytes);
311    TauTauHessianPolicy {
312        any_has_implicit,
313        implicit_multidim_duchon,
314        estimated_dense_tau_cache_bytes,
315        gradient_plan,
316        hessian_plan,
317        budget_bytes: EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES,
318        firth_pair_terms_unavailable: firth_pair_terms_unavailable && !hyper_dirs.is_empty(),
319    }
320}
321
322pub(crate) fn firth_problem_scale_allows(n_obs: usize, p_coeff: usize) -> bool {
323    let linear_work = n_obs.saturating_mul(p_coeff);
324    let quadratic_work = linear_work.saturating_mul(p_coeff);
325    n_obs <= FIRTH_MAX_OBSERVATIONS
326        && p_coeff <= FIRTH_MAX_COEFFICIENTS
327        && linear_work <= FIRTH_MAX_LINEAR_WORK
328        && quadratic_work <= FIRTH_MAX_QUADRATIC_WORK
329}
330
331#[cfg(test)]
332mod tests {
333    use super::{
334        DirectionalHyperParam, EvalCacheManager, EvalShared, HyperDesignDerivative,
335        HyperPenaltyDerivative, ImplicitDerivLevel, RemlConfig, RemlState,
336    };
337    use crate::estimate::EstimationError;
338    use gam_linalg::faer_ndarray::FaerCholesky;
339    use gam_linalg::matrix::symmetrize_in_place;
340    use crate::pirls::PirlsCoordinateFrame;
341    use gam_terms::basis::{ImplicitDesignPsiDerivative, RadialScalarKind};
342    use gam_problem::{
343        GlmLikelihoodSpec, InverseLink, LikelihoodSpec, ResponseFamily, StandardLink,
344    };
345    use faer::Side;
346    use gam_problem::{HessianResult, OuterEval};
347    use ndarray::{Array1, Array2, array, s};
348    use std::sync::Arc;
349
350    /// Shorthand for the canonical Binomial-Logit `GlmLikelihoodSpec` used by
351    /// the REML test fixtures.
352    pub(crate) fn binomial_logit_glm_spec() -> GlmLikelihoodSpec {
353        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
354            ResponseFamily::Binomial,
355            InverseLink::Standard(StandardLink::Logit),
356        ))
357    }
358
359    /// Shorthand for the canonical Gaussian-Identity `GlmLikelihoodSpec` used
360    /// by the REML test fixtures.
361    pub(crate) fn gaussian_identity_glm_spec() -> GlmLikelihoodSpec {
362        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
363            ResponseFamily::Gaussian,
364            InverseLink::Standard(StandardLink::Identity),
365        ))
366    }
367
368    impl DirectionalHyperParam {
369        pub(super) fn new(
370            x_tau_original: Array2<f64>,
371            penalty_first_components: Vec<(usize, Array2<f64>)>,
372            x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
373            penaltysecond_components: Option<Vec<Option<Vec<(usize, Array2<f64>)>>>>,
374        ) -> Result<Self, EstimationError> {
375            let x_tau_tau_original = x_tau_tau_original.map(|rows| {
376                rows.into_iter()
377                    .map(|entry| entry.map(HyperDesignDerivative::from))
378                    .collect::<Vec<_>>()
379            });
380            let penalty_first_components = penalty_first_components
381                .into_iter()
382                .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
383                .collect();
384            let penaltysecond_components = penaltysecond_components.map(|rows| {
385                rows.into_iter()
386                    .map(|row| {
387                        row.map(|components| {
388                            components
389                                .into_iter()
390                                .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
391                                .collect::<Vec<_>>()
392                        })
393                    })
394                    .collect::<Vec<_>>()
395            });
396            Self::new_compact(
397                HyperDesignDerivative::from(x_tau_original),
398                penalty_first_components,
399                x_tau_tau_original,
400                penaltysecond_components,
401            )
402        }
403
404        pub(super) fn single_penalty(
405            penalty_index: usize,
406            x_tau_original: Array2<f64>,
407            s_tau_original: Array2<f64>,
408            x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
409            s_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
410        ) -> Result<Self, EstimationError> {
411            let penaltysecond_components = s_tau_tau_original.map(|rows| {
412                rows.into_iter()
413                    .map(|mat| mat.map(|mat| vec![(penalty_index, mat)]))
414                    .collect::<Vec<_>>()
415            });
416            Self::new(
417                x_tau_original,
418                vec![(penalty_index, s_tau_original)],
419                x_tau_tau_original,
420                penaltysecond_components,
421            )
422        }
423    }
424
425    #[test]
426    pub(crate) fn firth_problem_scale_gate_blocks_large_quadratic_work() {
427        assert!(super::firth_problem_scale_allows(2_000, 200));
428        assert!(!super::firth_problem_scale_allows(4_800, 241));
429        assert!(!super::firth_problem_scale_allows(4_800, 433));
430    }
431
432    #[test]
433    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_implicit_tau() {
434        let operator = ImplicitDesignPsiDerivative::new(
435            array![1.0, 2.0, 3.0, 4.0],
436            array![0.5, -1.0, 1.5, 2.0],
437            array![0.1, 0.2, 0.3, 0.4],
438            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
439            None,
440            None,
441            2,
442            2,
443            1,
444            2,
445        );
446        let dir = DirectionalHyperParam::new_compact(
447            HyperDesignDerivative::from_implicit(
448                Arc::new(operator),
449                ImplicitDerivLevel::First(0),
450                1..4,
451                5,
452            ),
453            Vec::new(),
454            None,
455            None,
456        )
457        .expect("implicit directional hyperparam");
458        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
459        assert!(policy.any_has_implicit);
460        assert_eq!(
461            policy.gradient_plan.dense_x_bytes,
462            10 * 5 * std::mem::size_of::<f64>()
463        );
464        assert!(!policy.prefer_gradient_only());
465    }
466
467    #[test]
468    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_for_implicit_multidim_duchon()
469    {
470        // Multi-dim Duchon implicit storage used to force gradient-only,
471        // because the τ-cache materialization plan was infeasible.  The
472        // unified evaluator now elects the matrix-free
473        // `HessianResult::Operator` representation in this regime via
474        // `prefer_outer_hessian_operator`, so the planner can route to the
475        // operator trust-region (or basis-probe to dense for small K) — the
476        // capability is preserved and gradient-only must NOT engage.
477        let operator = ImplicitDesignPsiDerivative::new_streaming(
478            Arc::new(array![[0.0, 0.0], [1.0, 0.2]]),
479            Arc::new(array![[0.0, 0.0], [1.0, 1.0]]),
480            vec![0.0, 0.0],
481            RadialScalarKind::PureDuchon {
482                block_order: 1,
483                p_order: 0,
484                s_order: 0,
485                dim: 2,
486            },
487            None,
488            None,
489            0,
490        );
491        let dir = DirectionalHyperParam::new_compact(
492            HyperDesignDerivative::from_implicit(
493                Arc::new(operator),
494                ImplicitDerivLevel::First(0),
495                0..2,
496                2,
497            ),
498            Vec::new(),
499            None,
500            None,
501        )
502        .expect("implicit duchon directional hyperparam");
503        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
504        assert!(policy.any_has_implicit);
505        assert!(policy.implicit_multidim_duchon);
506        assert!(!policy.prefer_gradient_only());
507    }
508
509    #[test]
510    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_when_cache_budget_is_exceeded()
511     {
512        // The dense τ-cache plan exceeds the budget, but cost is no longer a
513        // capability gate: the eval-side selects the matrix-free operator
514        // representation in exactly this regime, and the planner routes
515        // accordingly.  `prefer_gradient_only` must NOT force `Unavailable`
516        // here.
517        let dirs = (0..16)
518            .map(|_| {
519                DirectionalHyperParam::new_compact(
520                    HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
521                    Vec::new(),
522                    None,
523                    None,
524                )
525                .expect("dense directional hyperparam")
526            })
527            .collect::<Vec<_>>();
528        let policy = super::exact_tau_tau_hessian_policy_with_firth(320_000, 71, &dirs, false);
529        assert!(!policy.any_has_implicit);
530        assert!(policy.hessian_plan.total_bytes() > policy.budget_bytes);
531        assert!(policy.hessian_plan.total_bytes() > policy.gradient_plan.total_bytes());
532        assert!(!policy.prefer_gradient_only());
533    }
534
535    #[test]
536    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_firth_pair_gap() {
537        let dir = DirectionalHyperParam::new_compact(
538            HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
539            Vec::new(),
540            None,
541            None,
542        )
543        .expect("dense directional hyperparam");
544        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], true);
545        assert!(policy.firth_pair_terms_unavailable);
546        assert!(policy.prefer_gradient_only());
547    }
548
549    /// Common shape for the design-motion + penalty-motion REML test fixtures
550    /// (Gaussian-identity and binomial-logit at present): both carry the
551    /// same `(y, w, X, S0, cfg, ρ)` plus a perturbation pair, and need the
552    /// same three helpers (`state`, `state_perturbed`, `fd_directional_gradient`).
553    /// Per-fixture `new()` constructors fill the fields with family-specific
554    /// data; the helpers below are shared via default impls so every fixture
555    /// pays the boilerplate once.
556    trait LogitDesignMotionFixture {
557        fn y(&self) -> &Array1<f64>;
558        fn w(&self) -> &Array1<f64>;
559        fn x(&self) -> &Array2<f64>;
560        fn s0(&self) -> &Array2<f64>;
561        fn cfg(&self) -> &RemlConfig;
562        fn rho(&self) -> &Array1<f64>;
563
564        fn state(&self) -> RemlState<'_> {
565            build_logit_state(self.y(), self.w(), self.x(), self.s0(), self.cfg())
566        }
567
568        fn state_perturbed(
569            &self,
570            x_tau: &Array2<f64>,
571            s_tau: &Array2<f64>,
572            eps: f64,
573        ) -> (RemlState<'_>, RemlState<'_>) {
574            let x_plus = self.x() + &x_tau.mapv(|v| eps * v);
575            let x_minus = self.x() - &x_tau.mapv(|v| eps * v);
576            let s_plus = self.s0() + &s_tau.mapv(|v| eps * v);
577            let s_minus = self.s0() - &s_tau.mapv(|v| eps * v);
578            (
579                build_logit_state(self.y(), self.w(), &x_plus, &s_plus, self.cfg()),
580                build_logit_state(self.y(), self.w(), &x_minus, &s_minus, self.cfg()),
581            )
582        }
583
584        /// Central FD approximation to the directional cost derivative at ρ.
585        fn fd_directional_gradient(&self, x_tau: &Array2<f64>, s_tau: &Array2<f64>) -> f64 {
586            let h = 2e-5;
587            let (state_plus, state_minus) = self.state_perturbed(x_tau, s_tau, h);
588            let v_plus = state_plus.compute_cost(self.rho()).expect("cost+");
589            let v_minus = state_minus.compute_cost(self.rho()).expect("cost-");
590            (v_plus - v_minus) / (2.0 * h)
591        }
592    }
593
594    pub(crate) fn build_logit_state<'a>(
595        y: &'a Array1<f64>,
596        w: &'a Array1<f64>,
597        x: &Array2<f64>,
598        s: &Array2<f64>,
599        cfg: &'a RemlConfig,
600    ) -> RemlState<'a> {
601        use crate::estimate::PenaltySpec;
602        let p = x.ncols();
603        let offset = Array1::<f64>::zeros(y.len());
604        let spec = PenaltySpec::Dense(s.clone());
605        let canonical = gam_terms::construction::canonicalize_penalty_specs(&[spec], &[1], p, "test")
606            .map(|(canonical, _)| canonical)
607            .expect("canonicalize");
608        RemlState::newwith_offset(
609            y.view(),
610            x.clone(),
611            w.view(),
612            offset.view(),
613            canonical,
614            p,
615            cfg,
616            Some(vec![1]),
617            None,
618            None,
619        )
620        .expect("state")
621    }
622
623    #[test]
624    fn repeated_penalty_ranges_keep_analytic_outer_hessian() {
625        let y = array![0.2, -0.1, 0.3, 0.0];
626        let w = Array1::<f64>::ones(y.len());
627        let x = array![[1.0, -0.7], [1.0, -0.2], [1.0, 0.3], [1.0, 0.9]];
628        let offset = Array1::<f64>::zeros(y.len());
629        let cfg = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
630        let p = x.ncols();
631        let canonical = vec![
632            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0]], p),
633            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0]], p),
634        ];
635        let state = RemlState::newwith_offset(
636            y.view(),
637            x,
638            w.view(),
639            offset.view(),
640            canonical,
641            p,
642            &cfg,
643            Some(vec![1, 1]),
644            None,
645            None,
646        )
647        .expect("state");
648
649        assert!(
650            state.analytic_outer_hessian_enabled(),
651            "double-penalty-style repeated coefficient ranges must still route to exact Hessian"
652        );
653    }
654
655    #[test]
656    fn canonical_logit_firth_declines_exact_tk_hessian_when_row_pair_work_is_large() {
657        let n = 2_000usize;
658        let p = 28usize;
659        let y = Array1::from_iter((0..n).map(|i| if i % 3 == 0 { 1.0 } else { 0.0 }));
660        let w = Array1::<f64>::ones(n);
661        let mut x = Array2::<f64>::zeros((n, p));
662        for i in 0..n {
663            let t = (i as f64 + 0.5) / n as f64;
664            x[[i, 0]] = 1.0;
665            for j in 1..p {
666                x[[i, j]] = ((j as f64) * std::f64::consts::TAU * t).sin()
667                    + 0.25 * (((j + 1) as f64) * std::f64::consts::TAU * t).cos();
668            }
669        }
670        let mut s = Array2::<f64>::zeros((p, p));
671        for j in 1..p {
672            s[[j, j]] = 1.0;
673        }
674        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
675        let state = build_logit_state(&y, &w, &x, &s, &cfg);
676
677        assert!(
678            !RemlState::firth_tk_exact_hessian_scale_allows(n, p),
679            "fixture must sit beyond the O(n²·p) exact-Hessian budget"
680        );
681        assert!(
682            !state.analytic_outer_hessian_enabled(),
683            "large canonical-logit Firth fits should keep exact value/gradient but route outer curvature to BFGS"
684        );
685    }
686
687    #[test]
688    fn canonical_logit_firth_keeps_exact_tk_hessian_for_small_separation_guards() {
689        let n = 40usize;
690        let p = 6usize;
691        let y = Array1::from_iter((0..n).map(|i| if i >= n / 2 { 1.0 } else { 0.0 }));
692        let w = Array1::<f64>::ones(n);
693        let mut x = Array2::<f64>::zeros((n, p));
694        for i in 0..n {
695            let t = (i as f64) / (n - 1) as f64;
696            x[[i, 0]] = 1.0;
697            for j in 1..p {
698                x[[i, j]] = t.powi(j as i32);
699            }
700        }
701        let mut s = Array2::<f64>::zeros((p, p));
702        for j in 1..p {
703            s[[j, j]] = 1.0;
704        }
705        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
706        let state = build_logit_state(&y, &w, &x, &s, &cfg);
707
708        assert!(RemlState::firth_tk_exact_hessian_scale_allows(n, p));
709        assert!(
710            state.analytic_outer_hessian_enabled(),
711            "small Firth rescue fits should keep exact TK Hessian curvature"
712        );
713    }
714
715    pub(crate) fn poisson_log_glm_spec() -> GlmLikelihoodSpec {
716        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
717            ResponseFamily::Poisson,
718            InverseLink::Standard(StandardLink::Log),
719        ))
720    }
721
722    /// Regression (issue #893): for a fixed-dispersion family a uniform prior
723    /// weight `w = c` is *exact* `c`-fold row replication. The two encodings must
724    /// therefore present a byte-identical LAML smoothing-selection surface — both
725    /// the cost `V(ρ)` and its gradient `∇V(ρ)` — because every term (penalised
726    /// deviance `D_p`, the working cross-product `XᵀWX`, the log-determinants)
727    /// is a sum of per-observation contributions that is identical whether a row
728    /// carries weight `c` or is stacked `c` times. This locks the *surface*
729    /// invariant that #893 ultimately reduces to: when the surfaces coincide,
730    /// the only remaining requirement for `λ̂(w=c) = λ̂(c×)` is that the outer
731    /// optimiser resolve the shared optimum (handled by the tightened outer
732    /// tolerance in `workflow.rs`). A regression that reintroduced a
733    /// row-count-vs-weight-sum asymmetry into the inner solve or the cost would
734    /// break this directly, independent of the optimiser tolerance.
735    #[test]
736    pub(crate) fn fixed_dispersion_laml_surface_is_replication_invariant() {
737        let n = 200usize;
738        let p = 8usize;
739        let c = 3usize;
740        let mut x = Array2::<f64>::zeros((n, p));
741        let mut y = Array1::<f64>::zeros(n);
742        for i in 0..n {
743            let t = (i as f64) / ((n - 1) as f64);
744            let tau = std::f64::consts::TAU;
745            x[[i, 0]] = 1.0;
746            x[[i, 1]] = t;
747            x[[i, 2]] = (tau * t).sin();
748            x[[i, 3]] = (tau * t).cos();
749            x[[i, 4]] = (2.0 * tau * t).sin();
750            x[[i, 5]] = (2.0 * tau * t).cos();
751            x[[i, 6]] = (3.0 * tau * t).sin();
752            x[[i, 7]] = (3.0 * tau * t).cos();
753            let eta = 0.3 + 0.9 * (1.4 * (t - 0.5)).sin();
754            // Deterministic non-negative integer counts near exp(eta).
755            y[i] = (eta.exp() + 0.5 * ((i as f64) * 2.399_963).sin())
756                .round()
757                .max(0.0);
758        }
759        let mut s = Array2::<f64>::zeros((p, p));
760        for j in 1..p {
761            s[[j, j]] = 1.0;
762        }
763
764        // Replicated design (c literal copies of each row).
765        let mut x_rep = Array2::<f64>::zeros((n * c, p));
766        let mut y_rep = Array1::<f64>::zeros(n * c);
767        for r in 0..c {
768            for i in 0..n {
769                let row = r * n + i;
770                for j in 0..p {
771                    x_rep[[row, j]] = x[[i, j]];
772                }
773                y_rep[row] = y[i];
774            }
775        }
776
777        let w_weighted = Array1::<f64>::from_elem(n, c as f64);
778        let w_rep = Array1::<f64>::ones(n * c);
779
780        let cfg = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
781        let st_w = build_logit_state(&y, &w_weighted, &x, &s, &cfg);
782        let st_r = build_logit_state(&y_rep, &w_rep, &x_rep, &s, &cfg);
783
784        for &rho in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
785            let r = Array1::from_elem(1, rho);
786            let cw = st_w.compute_cost(&r).expect("weighted cost");
787            let cr = st_r.compute_cost(&r).expect("replicated cost");
788            let gw = st_w.compute_gradient(&r).expect("weighted grad");
789            let gr = st_r.compute_gradient(&r).expect("replicated grad");
790            // Costs and gradients must coincide to optimiser precision; the only
791            // admissible difference is f64 summation order over n vs c·n rows.
792            assert!(
793                (cw - cr).abs() <= 1e-9 * (1.0 + cw.abs()),
794                "LAML cost differs between w=c and c× replication at rho={rho}: \
795                 cost_w={cw:.12e} cost_r={cr:.12e} diff={:.3e}",
796                cw - cr
797            );
798            assert!(
799                (gw[0] - gr[0]).abs() <= 1e-9 * (1.0 + gw[0].abs()),
800                "LAML gradient differs between w=c and c× replication at rho={rho}: \
801                 g_w={:.12e} g_r={:.12e} diff={:.3e}",
802                gw[0],
803                gr[0],
804                gw[0] - gr[0]
805            );
806        }
807    }
808
809    /// Regression (issue #893): the geometric-mean log-weight ρ-anchor
810    /// ([`RemlState::rho_weight_anchor`]) is a *profiled*-dispersion construct
811    /// (issue #877). For a fixed-dispersion family the optimum does not slide by
812    /// `log c` under a weight rescale in a way the prior should track, and a
813    /// nonzero anchor would evaluate the regularising ρ-prior at *different*
814    /// coordinates for the `w=c` vs `c×` encodings — breaking the very
815    /// equivalence #893 requires. The anchor must therefore be exactly `0` for a
816    /// fixed-dispersion family and the geometric mean for Gaussian-identity.
817    #[test]
818    pub(crate) fn rho_weight_anchor_is_zero_for_fixed_dispersion() {
819        let n = 50usize;
820        let p = 3usize;
821        let mut x = Array2::<f64>::zeros((n, p));
822        let mut y = Array1::<f64>::zeros(n);
823        for i in 0..n {
824            let t = (i as f64) / ((n - 1) as f64);
825            x[[i, 0]] = 1.0;
826            x[[i, 1]] = t;
827            x[[i, 2]] = t * t;
828            y[i] = (1.0 + (3.0 * t).sin()).round().max(0.0);
829        }
830        let mut s = Array2::<f64>::zeros((p, p));
831        s[[2, 2]] = 1.0;
832        // All weights = c: geometric-mean log-weight = ln(c) ≠ 0.
833        let c = 4.0_f64;
834        let w = Array1::<f64>::from_elem(n, c);
835
836        let cfg_pois = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
837        let st_pois = build_logit_state(&y, &w, &x, &s, &cfg_pois);
838        assert_eq!(
839            st_pois.rho_weight_anchor(),
840            0.0,
841            "fixed-dispersion (Poisson) anchor must be 0, not the geometric-mean log-weight"
842        );
843
844        let cfg_gauss = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
845        let st_gauss = build_logit_state(&y, &w, &x, &s, &cfg_gauss);
846        assert!(
847            (st_gauss.rho_weight_anchor() - c.ln()).abs() <= 1e-12,
848            "Gaussian-identity (profiled) anchor must be the geometric-mean log-weight ln(c)={:.6}, got {:.6}",
849            c.ln(),
850            st_gauss.rho_weight_anchor()
851        );
852    }
853
854    pub(crate) fn beta_original_from_bundle(bundle: &EvalShared) -> Array1<f64> {
855        let pr = bundle.pirls_result.as_ref();
856        match pr.coordinate_frame {
857            PirlsCoordinateFrame::OriginalSparseNative => pr.beta_transformed.as_ref().clone(),
858            PirlsCoordinateFrame::TransformedQs => {
859                pr.reparam_result.qs.dot(pr.beta_transformed.as_ref())
860            }
861        }
862    }
863
864    pub(crate) fn compute_joint_hypercostgradienthessian(
865        state: &RemlState<'_>,
866        theta: &Array1<f64>,
867        rho_dim: usize,
868        hyper_dirs: &[DirectionalHyperParam],
869    ) -> Result<(f64, Array1<f64>, Array2<f64>), EstimationError> {
870        let (cost, gradient, hessian) = state.compute_joint_hyper_eval_with_order(
871            theta,
872            rho_dim,
873            hyper_dirs,
874            crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
875        )?;
876        Ok((
877            cost,
878            gradient,
879            hessian
880                .materialize_dense()
881                .map_err(EstimationError::RemlOptimizationFailed)?
882                .ok_or_else(|| {
883                    EstimationError::RemlOptimizationFailed(
884                        "joint hyper Hessian requested but unavailable".to_string(),
885                    )
886                })?,
887        ))
888    }
889
890    pub(crate) fn h_original_from_bundle(bundle: &EvalShared) -> Array2<f64> {
891        let pr = bundle.pirls_result.as_ref();
892        match pr.coordinate_frame {
893            PirlsCoordinateFrame::OriginalSparseNative => bundle.h_total.as_ref().clone(),
894            PirlsCoordinateFrame::TransformedQs => {
895                let qs = &pr.reparam_result.qs;
896                let tmp = gam_linalg::faer_ndarray::fast_ab(qs, bundle.h_total.as_ref());
897                gam_linalg::faer_ndarray::fast_abt(&tmp, qs)
898            }
899        }
900    }
901
902    pub(crate) fn single_directional_tau_gradient(
903        state: &RemlState<'_>,
904        rho: &Array1<f64>,
905        hyper: DirectionalHyperParam,
906    ) -> Result<f64, EstimationError> {
907        let mut theta = Array1::<f64>::zeros(rho.len() + 1);
908        theta.slice_mut(s![..rho.len()]).assign(rho);
909        let (_, gradient, _) = state.compute_joint_hyper_eval_with_order(
910            &theta,
911            rho.len(),
912            &[hyper],
913            crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
914        )?;
915        Ok(gradient[rho.len()])
916    }
917
918    pub(crate) fn fd_directional_tau_cost_gradient(
919        y: &Array1<f64>,
920        w: &Array1<f64>,
921        x: &Array2<f64>,
922        s0: &Array2<f64>,
923        cfg: &RemlConfig,
924        rho: &Array1<f64>,
925        x_tau: &Array2<f64>,
926        s_tau: &Array2<f64>,
927    ) -> f64 {
928        let h = 2e-5;
929        let x_plus = x + &x_tau.mapv(|v| h * v);
930        let x_minus = x - &x_tau.mapv(|v| h * v);
931        let s_plus = s0 + &s_tau.mapv(|v| h * v);
932        let s_minus = s0 - &s_tau.mapv(|v| h * v);
933        let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
934        let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
935        let v_plus = state_plus.compute_cost(rho).expect("cost+");
936        let v_minus = state_minus.compute_cost(rho).expect("cost-");
937        (v_plus - v_minus) / (2.0 * h)
938    }
939
940    pub(crate) fn directional_tau_hessian_fd_reference(
941        y: &Array1<f64>,
942        w: &Array1<f64>,
943        x: &Array2<f64>,
944        s0: &Array2<f64>,
945        cfg: &RemlConfig,
946        rho: &Array1<f64>,
947        hyper_dirs: &[DirectionalHyperParam],
948        x_tau_mats: &[Array2<f64>],
949        s_tau_mats: &[Array2<f64>],
950    ) -> Array2<f64> {
951        assert_eq!(hyper_dirs.len(), x_tau_mats.len());
952        assert_eq!(hyper_dirs.len(), s_tau_mats.len());
953
954        const TARGET_PHYSICAL_STEP: f64 = 1e-5;
955
956        let n_dirs = hyper_dirs.len();
957        let mut h_ttfd = Array2::<f64>::zeros((n_dirs, n_dirs));
958        for j in 0..n_dirs {
959            let direction_scale = x_tau_mats[j]
960                .iter()
961                .chain(s_tau_mats[j].iter())
962                .fold(0.0_f64, |acc, value| acc.max(value.abs()));
963            let h = if direction_scale > 0.0 {
964                TARGET_PHYSICAL_STEP / direction_scale
965            } else {
966                TARGET_PHYSICAL_STEP
967            };
968
969            let x_plus = x + &x_tau_mats[j].mapv(|v| h * v);
970            let x_minus = x - &x_tau_mats[j].mapv(|v| h * v);
971            let s_plus = s0 + &s_tau_mats[j].mapv(|v| h * v);
972            let s_minus = s0 - &s_tau_mats[j].mapv(|v| h * v);
973
974            let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
975            let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
976            for i in 0..n_dirs {
977                let g_plus =
978                    single_directional_tau_gradient(&state_plus, rho, hyper_dirs[i].clone())
979                        .expect("g+ for FD");
980                let g_minus =
981                    single_directional_tau_gradient(&state_minus, rho, hyper_dirs[i].clone())
982                        .expect("g- for FD");
983                h_ttfd[[i, j]] = (g_plus - g_minus) / (2.0 * h);
984            }
985        }
986        symmetrize_in_place(&mut h_ttfd);
987        h_ttfd
988    }
989
990    #[test]
991    pub(crate) fn eval_cache_manager_stores_first_order_outer_eval() {
992        let cache = EvalCacheManager::new();
993        let rho = array![0.25, -0.0];
994        let rho_key = EvalCacheManager::sanitized_rhokey(&rho);
995        let eval = OuterEval {
996            cost: 3.5,
997            gradient: array![1.0, -2.0],
998            hessian: HessianResult::Unavailable,
999            inner_beta_hint: None,
1000        };
1001
1002        cache.store_outer_eval(&rho_key, &eval);
1003
1004        let cached = cache
1005            .cached_outer_eval(&rho_key)
1006            .expect("first-order outer eval should be cached");
1007        assert_eq!(cached.cost, eval.cost);
1008        assert_eq!(cached.gradient, eval.gradient);
1009        assert!(matches!(cached.hessian, HessianResult::Unavailable));
1010
1011        cache.invalidate_eval_bundle();
1012        assert!(
1013            cache.cached_outer_eval(&rho_key).is_none(),
1014            "invalidating the bundle should clear the outer-eval cache too"
1015        );
1016    }
1017
1018    /// #1575 multi-slot outer-eval cache correctness oracle.
1019    ///
1020    /// A memoization is only safe if a hit returns *exactly* what the miss path
1021    /// stored. This pins three properties of the bounded LRU:
1022    ///   1. round-trip fidelity — a hit is `f64::to_bits`-identical in cost AND
1023    ///      every gradient component to the value stored on the miss path;
1024    ///   2. no aliasing — distinct rho-keys never return each other's eval;
1025    ///   3. honest eviction — once an evicted key is requested again it MISSES
1026    ///      (so the caller recomputes) rather than returning a stale neighbour.
1027    #[test]
1028    pub(crate) fn outer_eval_lru_hit_is_bit_identical_and_evicts_honestly_1575() {
1029        use super::OUTER_EVAL_LRU_CAPACITY;
1030
1031        // Helper: a deterministic OuterEval whose bits encode `seed`, so any
1032        // cross-key contamination is detectable bit-for-bit.
1033        let make_eval = |seed: f64| OuterEval {
1034            cost: (seed * std::f64::consts::PI).sin() / 3.0 - seed,
1035            gradient: array![seed, -seed * 2.0, seed.recip()],
1036            hessian: HessianResult::Unavailable,
1037            inner_beta_hint: Some(array![seed + 0.5, seed - 0.5]),
1038        };
1039        let bits_eq = |a: &OuterEval, b: &OuterEval| -> bool {
1040            a.cost.to_bits() == b.cost.to_bits()
1041                && a.gradient.len() == b.gradient.len()
1042                && a.gradient
1043                    .iter()
1044                    .zip(b.gradient.iter())
1045                    .all(|(x, y)| x.to_bits() == y.to_bits())
1046        };
1047
1048        let cache = EvalCacheManager::new();
1049
1050        // (1) Round-trip fidelity: store at rho_a, then a forced hit must equal
1051        // the stored eval bit-for-bit (the "hit == miss" guarantee).
1052        let rho_a = array![0.25, -1.5];
1053        let key_a = EvalCacheManager::sanitized_rhokey(&rho_a);
1054        let eval_a = make_eval(0.25);
1055        cache.store_outer_eval(&key_a, &eval_a);
1056        let hit_a = cache
1057            .cached_outer_eval(&key_a)
1058            .expect("stored rho_a must hit");
1059        assert!(
1060            bits_eq(&hit_a, &eval_a),
1061            "cache hit must be bit-identical (cost+gradient) to the stored miss-path eval"
1062        );
1063        assert_eq!(
1064            hit_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
1065            eval_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
1066            "inner_beta_hint must round-trip unchanged"
1067        );
1068
1069        // (2) No aliasing: a second, distinct rho must return its OWN eval, and
1070        // the first key must still return the first eval untouched.
1071        let rho_b = array![0.25, -1.4999999999999998];
1072        let key_b = EvalCacheManager::sanitized_rhokey(&rho_b);
1073        assert_ne!(key_a, key_b, "the two rho-keys must differ");
1074        let eval_b = make_eval(7.0);
1075        cache.store_outer_eval(&key_b, &eval_b);
1076        assert!(
1077            bits_eq(
1078                &cache.cached_outer_eval(&key_b).expect("rho_b must hit"),
1079                &eval_b
1080            ),
1081            "rho_b must return its own eval, not rho_a's"
1082        );
1083        assert!(
1084            bits_eq(
1085                &cache.cached_outer_eval(&key_a).expect("rho_a must still hit"),
1086                &eval_a
1087            ),
1088            "rho_a must be unaffected by the rho_b insert"
1089        );
1090
1091        // (3) Honest eviction: overflow the LRU with fresh keys. The
1092        // least-recently-used entry must be evicted and then MISS (forcing a
1093        // recompute), while a still-resident key returns its exact stored bits.
1094        let cache = EvalCacheManager::new();
1095        let mut keys = Vec::new();
1096        let mut evals = Vec::new();
1097        for i in 0..OUTER_EVAL_LRU_CAPACITY {
1098            let rho = array![i as f64, -(i as f64)];
1099            let key = EvalCacheManager::sanitized_rhokey(&rho);
1100            let eval = make_eval(i as f64 + 0.123);
1101            cache.store_outer_eval(&key, &eval);
1102            keys.push(key);
1103            evals.push(eval);
1104        }
1105        // Cache is exactly full; key[0] is the least-recently-used.
1106        assert_eq!(
1107            cache.outer_eval_lru.read().unwrap().entries.len(),
1108            OUTER_EVAL_LRU_CAPACITY
1109        );
1110        // One more distinct key evicts the LRU (key[0]).
1111        let rho_overflow = array![999.0, -999.0];
1112        let key_overflow = EvalCacheManager::sanitized_rhokey(&rho_overflow);
1113        let eval_overflow = make_eval(42.0);
1114        cache.store_outer_eval(&key_overflow, &eval_overflow);
1115        assert_eq!(
1116            cache.outer_eval_lru.read().unwrap().entries.len(),
1117            OUTER_EVAL_LRU_CAPACITY,
1118            "capacity must stay bounded"
1119        );
1120        assert!(
1121            cache.cached_outer_eval(&keys[0]).is_none(),
1122            "the least-recently-used key must be evicted and now MISS (recompute), not return stale"
1123        );
1124        assert!(
1125            bits_eq(
1126                &cache
1127                    .cached_outer_eval(&keys[1])
1128                    .expect("a still-resident key must hit"),
1129                &evals[1]
1130            ),
1131            "a still-resident key must return its exact stored bits"
1132        );
1133        assert!(
1134            bits_eq(
1135                &cache
1136                    .cached_outer_eval(&key_overflow)
1137                    .expect("the freshest key must hit"),
1138                &eval_overflow
1139            ),
1140            "the freshest key must hit with its own eval"
1141        );
1142    }
1143
1144    #[test]
1145    pub(crate) fn reset_outer_seed_state_clears_pirls_cache() {
1146        // Build a minimal logit RemlState, populate the cross-call PIRLS LRU
1147        // by evaluating the outer objective at one rho, then verify that
1148        // reset_outer_seed_state wipes that LRU (alongside the eval bundle
1149        // and warm-start signals). This pins down the cross-attempt
1150        // cleanup contract that a budget-bump retry relies on.
1151        let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1152        let w = Array1::<f64>::ones(y.len());
1153        let x = array![
1154            [1.0, -1.0, 0.2],
1155            [1.0, -0.5, -0.4],
1156            [1.0, 0.0, 0.7],
1157            [1.0, 0.4, -0.3],
1158            [1.0, 0.9, 0.1],
1159            [1.0, 1.3, -0.6],
1160        ];
1161        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1162        let rho = array![0.0];
1163        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1164        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1165
1166        // Trigger a full outer eval so execute_pirls_if_needed inserts at
1167        // least one entry into the cross-call PIRLS LRU.
1168        state
1169            .compute_outer_eval_with_order(
1170                &rho,
1171                crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1172            )
1173            .expect("outer eval should succeed");
1174
1175        let populated_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1176        assert!(
1177            populated_len > 0,
1178            "evaluating the outer objective should populate the PIRLS LRU, got {populated_len}"
1179        );
1180
1181        state.reset_outer_seed_state();
1182
1183        let cleared_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1184        assert_eq!(
1185            cleared_len, 0,
1186            "reset_outer_seed_state must clear the cross-call PIRLS LRU; got {cleared_len} entries"
1187        );
1188    }
1189
1190    #[test]
1191    pub(crate) fn reset_outer_seed_state_preserves_frozen_negbin_theta_1448() {
1192        // #1448 regression: the NB outer θ↔λ alternation loop
1193        // (solver/estimate/optimizer.rs) re-runs the ρ search after each θ
1194        // refresh by (a) re-freezing the λ-search θ at θ_final into
1195        // `frozen_negbin_theta`, then (b) calling `reset_outer_seed_state()` to
1196        // drop the caches keyed to the old θ. Step (b) MUST NOT clear the freeze
1197        // set in step (a): the capture in `solve_for_unified_rho` only writes the
1198        // frozen slot when it is 0, so if the reset zeroed it the next round would
1199        // re-derive θ from the seed η and the loop would never reach the (ρ, θ)
1200        // joint fixed point — silently regressing #1448 back to a single
1201        // freeze→refresh pass.
1202        //
1203        // This pins the load-bearing distinction between `reset_outer_seed_state`
1204        // (alternation-round reset, freeze SURVIVES) and the surface-refresh reset
1205        // (new design, freeze re-zeroed). The end-to-end convergence on a real NB
1206        // fit is exercised by the public-API path; here we lock the invariant the
1207        // loop depends on, next to `reset_outer_seed_state_clears_pirls_cache`.
1208        use std::sync::atomic::Ordering;
1209
1210        let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1211        let w = Array1::<f64>::ones(y.len());
1212        let x = array![
1213            [1.0, -1.0, 0.2],
1214            [1.0, -0.5, -0.4],
1215            [1.0, 0.0, 0.7],
1216            [1.0, 0.4, -0.3],
1217            [1.0, 0.9, 0.1],
1218            [1.0, 1.3, -0.6],
1219        ];
1220        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1221        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1222        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1223
1224        // Simulate the alternation loop's re-freeze step: pin θ_final.
1225        let theta_final_bits = 2.5_f64.to_bits();
1226        state
1227            .frozen_negbin_theta
1228            .store(theta_final_bits, Ordering::Relaxed);
1229        assert_eq!(
1230            state.frozen_negbin_theta.load(Ordering::Relaxed),
1231            theta_final_bits,
1232            "precondition: the re-freeze stores θ_final into the frozen slot"
1233        );
1234
1235        // The alternation loop's per-round reset.
1236        state.reset_outer_seed_state();
1237
1238        assert_eq!(
1239            state.frozen_negbin_theta.load(Ordering::Relaxed),
1240            theta_final_bits,
1241            "reset_outer_seed_state (alternation-round reset) must PRESERVE the \
1242             re-frozen NB θ; clearing it would defeat the #1448 θ↔λ alternation \
1243             (the next ρ search would re-derive θ from the seed and never reach \
1244             the joint fixed point)"
1245        );
1246    }
1247
1248    #[test]
1249    pub(crate) fn implicit_hyper_design_derivative_respects_full_model_embedding() {
1250        let operator = ImplicitDesignPsiDerivative::new(
1251            array![1.0, 2.0, 3.0, 4.0],
1252            array![0.5, -1.0, 1.5, 2.0],
1253            array![0.1, 0.2, 0.3, 0.4],
1254            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
1255            None,
1256            None,
1257            2,
1258            2,
1259            1,
1260            2,
1261        );
1262        let local = operator
1263            .materialize_first(0)
1264            .expect("materialized first derivative");
1265        assert_eq!(
1266            local.ncols(),
1267            3,
1268            "operator-local derivative should stay smooth-local"
1269        );
1270
1271        let implicit = HyperDesignDerivative::from_implicit(
1272            Arc::new(operator),
1273            ImplicitDerivLevel::First(0),
1274            1..4,
1275            5,
1276        );
1277        let embedded = HyperDesignDerivative::from_embedded(local.clone(), 1..4, 5);
1278
1279        assert_eq!(implicit.nrows(), embedded.nrows());
1280        assert_eq!(implicit.ncols(), 5);
1281        assert_eq!(implicit.materialize(), embedded.materialize());
1282
1283        let u = array![7.0, 1.5, -2.0, 0.25, -3.0];
1284        let v = array![0.75, -1.25];
1285        assert_eq!(
1286            implicit.forward_mul_original(&u).expect("implicit forward"),
1287            embedded.forward_mul_original(&u).expect("embedded forward")
1288        );
1289        assert_eq!(
1290            implicit
1291                .transpose_mul_original(&v)
1292                .expect("implicit transpose"),
1293            embedded
1294                .transpose_mul_original(&v)
1295                .expect("embedded transpose")
1296        );
1297
1298        let qs = array![
1299            [1.0, 0.0, 0.0],
1300            [0.0, 1.0, 0.0],
1301            [0.0, 0.5, 0.5],
1302            [0.0, 0.0, 1.0],
1303            [0.0, 0.0, 0.0],
1304        ];
1305        assert_eq!(
1306            implicit
1307                .transformed(&qs, None)
1308                .expect("implicit transformed"),
1309            embedded
1310                .transformed(&qs, None)
1311                .expect("embedded transformed")
1312        );
1313        let u_transformed = array![1.0, -0.5, 2.0];
1314        assert_eq!(
1315            implicit
1316                .transformed_forward_mul(&qs, None, &u_transformed)
1317                .expect("implicit transformed forward"),
1318            embedded
1319                .transformed_forward_mul(&qs, None, &u_transformed)
1320                .expect("embedded transformed forward")
1321        );
1322        assert_eq!(
1323            implicit
1324                .transformed_transpose_mul(&qs, None, &v)
1325                .expect("implicit transformed transpose"),
1326            embedded
1327                .transformed_transpose_mul(&qs, None, &v)
1328                .expect("embedded transformed transpose")
1329        );
1330    }
1331
    /// Cross-checks the analytic directional-hyperparameter identities against
    /// central finite differences on a small (n = 8, p = 3) logit problem.
    ///
    /// For a single direction τ that perturbs only the penalty
    /// (`S(τ) = S + τ S_τ`, `X_τ = 0`) the test compares, at `rho = [0.0]`:
    ///   * `B = dβ̂/dτ` from the implicit solve vs. the FD of the re-fitted mode
    ///     (relative error < 2e-2, matching signs per component);
    ///   * `H_τ` from the total-derivative formula vs. the FD of the re-fitted
    ///     Hessian (relative error < 3e-2, matching signs per entry);
    ///   * `V_τ` from `single_directional_tau_gradient` vs. the FD of the cost
    ///     (relative error < 2e-2, matching sign);
    ///   * the fit-block stationarity cancellation `-ℓ_βᵀ B + β̂ᵀ S B`
    ///     (absolute value < 1e-10).
    ///
    /// All finite differences use step `h = 2e-5` and rebuild the state at
    /// `τ = ±h`.
    #[test]
    pub(crate) fn directional_hyper_identities_match_finite_differences_logit() {
        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let w = Array1::<f64>::ones(y.len());
        let x = array![
            [1.0, -1.2, 0.3],
            [1.0, -0.8, -0.4],
            [1.0, -0.3, 0.7],
            [1.0, 0.1, -0.9],
            [1.0, 0.5, 0.2],
            [1.0, 0.9, -0.1],
            [1.0, 1.3, 0.8],
            [1.0, 1.7, -0.6],
        ];
        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];

        // Use one directional hyperparameter τ with a penalty perturbation:
        // S(τ) = S + τ S_τ.
        // Keep X_τ = 0 so this identity test remains valid in both non-Firth
        // and Firth-logit modes.
        let x_tau = Array2::<f64>::zeros(x.raw_dim());
        let s_tau = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15],];
        let hyper =
            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
                .expect("single-penalty hyper direction");
        let rho = array![0.0];

        // Tight inner tolerance: the envelope theorem requires an exact inner
        // P-IRLS optimum; 1e-10 leaves enough residual gradient to cause ~12%
        // V_tau mismatch on this small (n=8) logistic problem.
        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
        let bundle = state.obtain_eval_bundle(&rho).expect("bundle");
        let pr = bundle.pirls_result.as_ref();

        // Mode and Hessian mapped back to the original coordinate frame.
        let beta = beta_original_from_bundle(&bundle);
        let h_orig = h_original_from_bundle(&bundle);
        // Elementwise: solve weights times (working response − final η).
        let u = &pr.solveweights * &(&pr.solveworking_response - &pr.final_eta);

        // B from implicit solve:
        //   H B = X_τ^T g - X^T W(X_τ β̂) - S_τ β̂.
        // With X_τ = 0 in this test the first two terms are zero; they are kept
        // so the expression matches the formula above.
        let x_tau_beta = gam_linalg::faer_ndarray::fast_av(&x_tau, &beta);
        let weighted_x_tau_beta = &pr.finalweights * &x_tau_beta;
        let rhs = gam_linalg::faer_ndarray::fast_atv(&x_tau, &u)
            - gam_linalg::faer_ndarray::fast_atv(&x, &weighted_x_tau_beta)
            - s_tau.dot(&beta);
        let chol = h_orig.cholesky(Side::Lower).expect("chol(H)");
        let b_analytic = chol.solvevec(&rhs);

        // H_τ from exact total derivative:
        //   H_τ = X_τ^T W X + X^T W X_τ + X^T W_τ X + S_τ,
        // with W_τ provided by the family directional curvature callback.
        let eta_dot = &x_tau_beta + &gam_linalg::faer_ndarray::fast_av(&x, &b_analytic);
        let w_direction = crate::pirls::directionalworking_curvature_from_c_array(
            &pr.solve_c_array,
            &pr.finalweights,
            &eta_dot,
        );
        let wx = RemlState::row_scale(&x, &pr.finalweights);
        let wx_tau = RemlState::row_scale(&x_tau, &pr.finalweights);
        // Build W_τ X by row-scaling a copy of X with the directional diagonal.
        let mut xwtau_x = x.clone();
        match w_direction {
            crate::pirls::DirectionalWorkingCurvature::Diagonal(diag) => {
                xwtau_x = RemlState::row_scale(&xwtau_x, &diag);
            }
        }
        let mut h_tau_analytic = gam_linalg::faer_ndarray::fast_atb(&x_tau, &wx);
        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &wx_tau);
        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &xwtau_x);
        h_tau_analytic += &s_tau;

        // Fit-block stationarity cancellation:
        //   -ℓ_β^T B + β̂^T S B = 0.
        // Here S is the effective penalty in the inner Hessian surface:
        //   S = H - X^T W X.
        let ell_beta = gam_linalg::faer_ndarray::fast_atv(&x, &u);
        let s_eff = &h_orig - &gam_linalg::faer_ndarray::fast_atb(&x, &wx);
        let cancellation = -ell_beta.dot(&b_analytic) + beta.dot(&s_eff.dot(&b_analytic));

        // Finite differences in τ against re-fit objective and mode.
        let h = 2e-5;
        let x_plus = &x + &(x_tau.mapv(|v| h * v));
        let x_minus = &x - &(x_tau.mapv(|v| h * v));
        let s_plus = &s0 + &(s_tau.mapv(|v| h * v));
        let s_minus = &s0 - &(s_tau.mapv(|v| h * v));

        let state_plus = build_logit_state(&y, &w, &x_plus, &s_plus, &cfg);
        let state_minus = build_logit_state(&y, &w, &x_minus, &s_minus, &cfg);
        let bundle_plus = state_plus.obtain_eval_bundle(&rho).expect("bundle+");
        let bundle_minus = state_minus.obtain_eval_bundle(&rho).expect("bundle-");
        // Central difference of the re-fitted mode.
        let beta_plus = beta_original_from_bundle(&bundle_plus);
        let beta_minus = beta_original_from_bundle(&bundle_minus);
        let bfd = (&beta_plus - &beta_minus).mapv(|v| v / (2.0 * h));

        // Central difference of the re-fitted Hessian.
        let h_plus = h_original_from_bundle(&bundle_plus);
        let h_minus = h_original_from_bundle(&bundle_minus);
        let h_taufd = (&h_plus - &h_minus).mapv(|v| v / (2.0 * h));

        // Central difference of the cost.
        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
        let v_taufd = (v_plus - v_minus) / (2.0 * h);

        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper.clone())
            .expect("analytic directional gradient");

        // B: relative L2 error (denominator floored at 1e-12) plus a
        // per-component sign check.
        let b_num = (&b_analytic - &bfd).mapv(|v| v * v).sum().sqrt();
        let b_den = bfd.mapv(|v| v * v).sum().sqrt().max(1e-12);
        let b_rel = b_num / b_den;
        for i in 0..b_analytic.len() {
            assert_eq!(
                b_analytic[i].signum(),
                bfd[i].signum(),
                "B sign mismatch at i={i}: analytic={} fd={}",
                b_analytic[i],
                bfd[i]
            );
        }
        assert!(
            b_rel < 2e-2,
            "B implicit solve mismatch vs FD: rel={b_rel:.3e}, num={b_num:.3e}, den={b_den:.3e}"
        );

        // H_τ: relative Frobenius error (denominator floored at 1e-12) plus a
        // per-entry sign check.
        let dh_num = (&h_tau_analytic - &h_taufd).mapv(|v| v * v).sum().sqrt();
        let dh_den = h_taufd.mapv(|v| v * v).sum().sqrt().max(1e-12);
        let dh_rel = dh_num / dh_den;
        for i in 0..h_tau_analytic.nrows() {
            for j in 0..h_tau_analytic.ncols() {
                assert_eq!(
                    h_tau_analytic[[i, j]].signum(),
                    h_taufd[[i, j]].signum(),
                    "H_tau sign mismatch at ({i},{j}): analytic={} fd={}",
                    h_tau_analytic[[i, j]],
                    h_taufd[[i, j]]
                );
            }
        }
        assert!(
            dh_rel < 3e-2,
            "H_tau mismatch vs FD: rel={dh_rel:.3e}, num={dh_num:.3e}, den={dh_den:.3e}"
        );

        // V_τ: relative error (denominator floored at 1e-10) plus a sign check.
        let v_abs = (v_tau_analytic - v_taufd).abs();
        let v_rel = v_abs / v_taufd.abs().max(1e-10);
        assert_eq!(
            v_tau_analytic.signum(),
            v_taufd.signum(),
            "V_tau sign mismatch: analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
        );
        assert!(
            v_rel < 2e-2,
            "V_tau mismatch vs FD: rel={v_rel:.3e}, abs={v_abs:.3e}, analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
        );

        assert!(
            cancellation.abs() < 1e-10,
            "stationarity cancellation failed: | -ell_beta^T B + beta^T S B | = {:.3e}",
            cancellation.abs()
        );
    }
1491
1492    #[test]
1493    pub(crate) fn firth_exacthessian_includes_analytic_tk_second_derivatives() {
1494        // Rank-deficient X: the 4th column is 2x the 2nd column.
1495        let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
1496        let w = Array1::<f64>::ones(y.len());
1497        let x = array![
1498            [1.0, -1.2, 0.4, -2.4],
1499            [1.0, -0.9, -0.1, -1.8],
1500            [1.0, -0.6, 0.3, -1.2],
1501            [1.0, -0.2, -0.4, -0.4],
1502            [1.0, 0.1, 0.5, 0.2],
1503            [1.0, 0.4, -0.6, 0.8],
1504            [1.0, 0.8, 0.2, 1.6],
1505            [1.0, 1.1, -0.3, 2.2],
1506            [1.0, 1.4, 0.7, 2.8],
1507            [1.0, 1.7, -0.2, 3.4],
1508        ];
1509        let s0 = array![
1510            [0.0, 0.0, 0.0, 0.0],
1511            [0.0, 1.5, 0.2, 0.0],
1512            [0.0, 0.2, 1.0, 0.0],
1513            [0.0, 0.0, 0.0, 0.5],
1514        ];
1515        let s1 = array![
1516            [0.0, 0.0, 0.0, 0.0],
1517            [0.0, 0.8, -0.1, 0.0],
1518            [0.0, -0.1, 0.6, 0.0],
1519            [0.0, 0.0, 0.0, 0.3],
1520        ];
1521        let offset = Array1::<f64>::zeros(y.len());
1522        // Rank-deficient Firth logit needs more inner iterations to converge
1523        // tightly enough for the envelope-theorem derivative tests.
1524        let cfg =
1525            RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1526        let p = x.ncols();
1527        use crate::estimate::PenaltySpec;
1528        let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1529        let canonical = gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p, "test")
1530            .map(|(canonical, _)| canonical)
1531            .expect("canonicalize");
1532        let state = RemlState::newwith_offset(
1533            y.view(),
1534            x.clone(),
1535            w.view(),
1536            offset.view(),
1537            canonical,
1538            p,
1539            &cfg,
1540            Some(vec![1, 1]),
1541            None,
1542            None,
1543        )
1544        .expect("state");
1545        let rho = array![0.1, -0.2];
1546        assert!(
1547            state.analytic_outer_hessian_enabled(),
1548            "Firth logit should no longer disable analytic outer Hessian planning"
1549        );
1550        let outer = state
1551            .compute_outer_eval_with_order(
1552                &rho,
1553                crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1554            )
1555            .expect("outer Hessian eval should succeed");
1556        assert!(
1557            outer.hessian.is_analytic(),
1558            "outer planner should request and return an analytic Hessian"
1559        );
1560        let bundle = state.obtain_eval_bundle(&rho).expect("exact firth bundle");
1561        let h_dense = state
1562            .compute_lamlhessian_exact_from_bundle(&rho, &bundle)
1563            .expect("Firth exact Hessian should include analytic TK second derivatives");
1564        assert_eq!(h_dense.raw_dim(), ndarray::Ix2(2, 2));
1565        assert!(
1566            h_dense.iter().all(|value| value.is_finite()),
1567            "Hessian should be finite: {h_dense:?}"
1568        );
1569    }
1570
1571    #[test]
1572    pub(crate) fn firth_outer_hessian_matches_gradient_finite_difference_with_tk_terms() {
1573        let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
1574        let w = Array1::<f64>::ones(y.len());
1575        let x = array![
1576            [1.0, -1.0, 0.3],
1577            [1.0, -0.7, -0.2],
1578            [1.0, -0.3, 0.4],
1579            [1.0, 0.0, -0.5],
1580            [1.0, 0.2, 0.6],
1581            [1.0, 0.6, -0.4],
1582            [1.0, 0.9, 0.2],
1583            [1.0, 1.3, -0.1],
1584        ];
1585        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.1], [0.0, 0.1, 0.7],];
1586        let s1 = array![[0.0, 0.0, 0.0], [0.0, 0.4, -0.05], [0.0, -0.05, 0.9],];
1587        let cfg =
1588            RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1589        let p_dim = x.ncols();
1590        use crate::estimate::PenaltySpec;
1591        let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1592        let canonical =
1593            gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p_dim, "test")
1594                .map(|(canonical, _)| canonical)
1595                .expect("canonicalize");
1596        let offset = Array1::<f64>::zeros(y.len());
1597        let state = RemlState::newwith_offset(
1598            y.view(),
1599            x.clone(),
1600            w.view(),
1601            offset.view(),
1602            canonical,
1603            p_dim,
1604            &cfg,
1605            Some(vec![1, 1]),
1606            None,
1607            None,
1608        )
1609        .expect("state");
1610        let rho = array![0.15, -0.25];
1611        let eval = state
1612            .compute_outer_eval_with_order(
1613                &rho,
1614                crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1615            )
1616            .expect("analytic Hessian eval");
1617        let h = match eval.hessian {
1618            HessianResult::Analytic(hessian) => hessian,
1619            HessianResult::Operator(_) | HessianResult::Unavailable => {
1620                panic!("expected dense analytic Hessian")
1621            }
1622        };
1623        let delta = 2.0e-5;
1624        for col in 0..rho.len() {
1625            let mut rp = rho.clone();
1626            let mut rm = rho.clone();
1627            rp[col] += delta;
1628            rm[col] -= delta;
1629            let gp = state
1630                .compute_outer_eval_with_order(
1631                    &rp,
1632                    crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1633                )
1634                .expect("plus grad")
1635                .gradient;
1636            let gm = state
1637                .compute_outer_eval_with_order(
1638                    &rm,
1639                    crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1640                )
1641                .expect("minus grad")
1642                .gradient;
1643            for row in 0..rho.len() {
1644                let fd = (gp[row] - gm[row]) / (2.0 * delta);
1645                let an = h[[row, col]];
1646                let rel = (fd - an).abs() / fd.abs().max(an.abs()).max(1e-6);
1647                assert!(
1648                    rel < 2.0e-3,
1649                    "Hessian mismatch ({row},{col}): analytic={an:.9e}, fd={fd:.9e}, rel={rel:.3e}"
1650                );
1651            }
1652        }
1653    }
1654
1655    #[test]
1656    pub(crate) fn firthgradient_lives_in_design_column_space_under_rank_deficiency() {
1657        // Rank-deficient design: col4 = 2*col2.
1658        let x = array![
1659            [1.0, -1.2, 0.4, -2.4],
1660            [1.0, -0.9, -0.1, -1.8],
1661            [1.0, -0.6, 0.3, -1.2],
1662            [1.0, -0.2, -0.4, -0.4],
1663            [1.0, 0.1, 0.5, 0.2],
1664            [1.0, 0.4, -0.6, 0.8],
1665            [1.0, 0.8, 0.2, 1.6],
1666            [1.0, 1.1, -0.3, 2.2],
1667        ];
1668        let beta = array![0.1, -0.2, 0.3, 0.05];
1669        let eta = x.dot(&beta);
1670        let op = super::RemlState::build_firth_dense_operator_for_link(
1671            &gam_problem::InverseLink::Standard(gam_problem::StandardLink::Logit),
1672            &x,
1673            &eta,
1674            ndarray::Array1::ones(x.nrows()).view(),
1675        )
1676        .expect("firth operator");
1677
1678        // Exact reduced-space Firth gradient:
1679        //   gradPhi = 0.5 Xᵀ (w' ⊙ h), with h = diag(X_r K_r X_rᵀ).
1680        let gradphi = 0.5 * x.t().dot(&(&op.w1 * &op.h_diag));
1681
1682        // Check (I - QQᵀ) gradPhi ≈ 0.
1683        let q = &op.q_basis;
1684        let proj = q.dot(&q.t().dot(&gradphi));
1685        let resid = &gradphi - &proj;
1686        let rel =
1687            resid.mapv(|v| v * v).sum().sqrt() / gradphi.mapv(|v| v * v).sum().sqrt().max(1e-12);
1688        assert!(
1689            rel < 1e-10,
1690            "Firth gradient should lie in Col(Xᵀ): rel residual={rel:.3e}"
1691        );
1692    }
1693
1694    #[test]
1695    pub(crate) fn firth_logit_directional_hypergradient_accepts_penalty_only_with_full_tk_gradient()
1696    {
1697        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1698        let w = Array1::<f64>::ones(y.len());
1699        let x = array![
1700            [1.0, -1.1, 0.2],
1701            [1.0, -0.6, -0.3],
1702            [1.0, -0.1, 0.5],
1703            [1.0, 0.3, -0.7],
1704            [1.0, 0.8, 0.1],
1705            [1.0, 1.2, -0.4],
1706        ];
1707        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1708        let hyper = DirectionalHyperParam::single_penalty(
1709            0,
1710            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1711            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1712            None,
1713            None,
1714        )
1715        .expect("single-penalty hyper direction");
1716        let rho = array![0.0];
1717        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1718        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1719        let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1720            .expect("Firth penalty-only directional gradient should use analytic TK propagation");
1721        assert!(gradient.is_finite(), "gradient={gradient}");
1722        let fd = fd_directional_tau_cost_gradient(
1723            &y,
1724            &w,
1725            &x,
1726            &s0,
1727            &cfg,
1728            &rho,
1729            &Array2::<f64>::zeros((x.nrows(), x.ncols())),
1730            &array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1731        );
1732        let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1733        assert!(
1734            rel < 1.0e-3,
1735            "Firth penalty-only directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1736        );
1737
1738        let efs_hyper = DirectionalHyperParam::single_penalty(
1739            0,
1740            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1741            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1742            None,
1743            None,
1744        )
1745        .expect("single-penalty EFS hyper direction");
1746        let efs = state
1747            .compute_efs_steps_with_psi_ext(&rho, &[efs_hyper])
1748            .expect("Firth penalty-only EFS should use analytic TK propagation");
1749        assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1750    }
1751
1752    #[test]
1753    pub(crate) fn firth_logit_directional_hypergradient_accepts_design_moving_with_full_tk_gradient()
1754     {
1755        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1756        let w = Array1::<f64>::ones(y.len());
1757        let x = array![
1758            [1.0, -1.1, 0.2],
1759            [1.0, -0.6, -0.3],
1760            [1.0, -0.1, 0.5],
1761            [1.0, 0.3, -0.7],
1762            [1.0, 0.8, 0.1],
1763            [1.0, 1.2, -0.4],
1764        ];
1765        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1766        let hyper = DirectionalHyperParam::single_penalty(
1767            0,
1768            Array2::from_elem((x.nrows(), x.ncols()), 1e-3),
1769            Array2::<f64>::zeros((x.ncols(), x.ncols())),
1770            None,
1771            None,
1772        )
1773        .expect("single-penalty hyper direction");
1774        let rho = array![0.0];
1775        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1776        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1777        let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1778            .expect("Firth design-moving directional gradient should use analytic TK propagation");
1779        assert!(gradient.is_finite(), "gradient={gradient}");
1780        let x_tau = Array2::from_elem((x.nrows(), x.ncols()), 1e-3);
1781        let s_tau = Array2::<f64>::zeros((x.ncols(), x.ncols()));
1782        let fd = fd_directional_tau_cost_gradient(&y, &w, &x, &s0, &cfg, &rho, &x_tau, &s_tau);
1783        let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1784        assert!(
1785            rel < 2.0e-2,
1786            "Firth design-moving directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1787        );
1788    }
1789
1790    #[test]
1791    pub(crate) fn firth_logit_hybrid_efs_accepts_full_tk_psi_gradient() {
1792        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1793        let w = Array1::<f64>::ones(y.len());
1794        let x = array![
1795            [1.0, -1.1, 0.2],
1796            [1.0, -0.6, -0.3],
1797            [1.0, -0.1, 0.5],
1798            [1.0, 0.3, -0.7],
1799            [1.0, 0.8, 0.1],
1800            [1.0, 1.2, -0.4],
1801        ];
1802        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1803        let hyper_dirs = vec![
1804            DirectionalHyperParam::single_penalty(
1805                0,
1806                Array2::from_shape_fn((x.nrows(), x.ncols()), |(i, j)| {
1807                    1e-3 * ((i + 1) as f64) * ((j + 2) as f64)
1808                }),
1809                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1810                None,
1811                None,
1812            )
1813            .expect("design-moving hyper direction"),
1814        ];
1815        let rho = array![0.0];
1816        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1817        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1818
1819        let full = state
1820            .evaluate_unified_with_psi_ext(
1821                &rho,
1822                None,
1823                crate::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
1824                &hyper_dirs,
1825            )
1826            .expect("full Firth psi gradient should use analytic TK propagation");
1827        assert!(full.cost.is_finite(), "full cost={}", full.cost);
1828        let full_grad = full.gradient.expect("gradient should be present");
1829        assert!(
1830            full_grad.iter().all(|value| value.is_finite()),
1831            "full gradient={full_grad:?}"
1832        );
1833
1834        let efs = state
1835            .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
1836            .expect("hybrid EFS should use analytic TK propagation");
1837        assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1838    }
1839
1840    #[test]
1841    pub(crate) fn joint_hyperhessianwires_mixed_blocks() {
1842        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1843        let w = Array1::<f64>::ones(y.len());
1844        let x = array![
1845            [1.0, -1.2, 0.3],
1846            [1.0, -0.8, -0.4],
1847            [1.0, -0.3, 0.7],
1848            [1.0, 0.1, -0.9],
1849            [1.0, 0.5, 0.2],
1850            [1.0, 0.9, -0.1],
1851            [1.0, 1.3, 0.8],
1852            [1.0, 1.7, -0.6],
1853        ];
1854        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1855        let cfg =
1856            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1857        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1858        let rho = array![0.0];
1859        let theta = array![0.0, 0.0, 0.0];
1860        let hyper_dirs = vec![
1861            DirectionalHyperParam::single_penalty(
1862                0,
1863                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1864                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1865                None,
1866                None,
1867            )
1868            .expect("single-penalty hyper direction"),
1869            DirectionalHyperParam::single_penalty(
1870                0,
1871                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1872                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1873                None,
1874                None,
1875            )
1876            .expect("single-penalty hyper direction"),
1877        ];
1878
1879        let (_, _, h) =
1880            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1881                .expect("joint hyper cost+gradient+hessian");
1882        assert_eq!(h.nrows(), theta.len());
1883        assert_eq!(h.ncols(), theta.len());
1884        assert!(h.iter().all(|v| v.is_finite()));
1885        for i in 0..h.nrows() {
1886            for j in 0..i {
1887                let diff = (h[[i, j]] - h[[j, i]]).abs();
1888                assert!(
1889                    diff < 1e-6,
1890                    "joint hessian asymmetry at ({i},{j}): {diff:.3e}"
1891                );
1892            }
1893        }
1894        // Mixed block must be nontrivial for at least one supplied direction.
1895        let mixed_0 = h[[0, 1]];
1896        let mixed_1 = h[[0, 2]];
1897        assert!(
1898            mixed_0.is_finite() && mixed_1.is_finite(),
1899            "mixed blocks must be finite"
1900        );
1901    }
1902
1903    #[test]
1904    pub(crate) fn joint_tau_tau_linear_dirs_matchfd_reference_away_fromzero_psi() {
1905        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1906        let w = Array1::<f64>::ones(y.len());
1907        let x = array![
1908            [1.0, -1.2, 0.3],
1909            [1.0, -0.8, -0.4],
1910            [1.0, -0.3, 0.7],
1911            [1.0, 0.1, -0.9],
1912            [1.0, 0.5, 0.2],
1913            [1.0, 0.9, -0.1],
1914            [1.0, 1.3, 0.8],
1915            [1.0, 1.7, -0.6],
1916        ];
1917        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1918        let cfg =
1919            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1920        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1921        let rho = array![0.0];
1922        let psi = array![0.7, -0.4];
1923        let theta = array![rho[0], psi[0], psi[1]];
1924        let hyper_dirs = vec![
1925            DirectionalHyperParam::single_penalty(
1926                0,
1927                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1928                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1929                None,
1930                None,
1931            )
1932            .expect("linear tau direction"),
1933            DirectionalHyperParam::single_penalty(
1934                0,
1935                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1936                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1937                None,
1938                None,
1939            )
1940            .expect("linear tau direction"),
1941        ];
1942
1943        let (_, _, h_full) =
1944            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1945                .expect("joint hyper cost+gradient+hessian");
1946        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
1947
1948        // FD via physical perturbation of design/penalty matrices (matching
1949        // the V_tau FD pattern).  For column j we perturb X and S₀ along
1950        // direction j, build fresh states, and evaluate the τ-gradient for
1951        // every direction i at those perturbed states.
1952        let x_tau_mats: Vec<Array2<f64>> = vec![
1953            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1954            Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1955        ];
1956        let s_tau_mats: Vec<Array2<f64>> = vec![
1957            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
1958            Array2::<f64>::zeros((x.ncols(), x.ncols())),
1959        ];
1960
1961        let h_ttfd = directional_tau_hessian_fd_reference(
1962            &y,
1963            &w,
1964            &x,
1965            &s0,
1966            &cfg,
1967            &rho,
1968            &hyper_dirs,
1969            &x_tau_mats,
1970            &s_tau_mats,
1971        );
1972
1973        let num = (&h_tt_analytic - &h_ttfd)
1974            .iter()
1975            .map(|v| v * v)
1976            .sum::<f64>()
1977            .sqrt();
1978        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
1979        let rel = num / den;
1980        assert!(
1981            rel < 1e-4,
1982            "linear-dir joint tau-tau block deviates from FD reference away from zero psi: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
1983        );
1984    }
1985
1986    #[test]
1987    pub(crate) fn joint_hypervalidation_rejects_out_of_boundssecond_order_penalty_index() {
1988        // The hyper direction declares a second-order penalty derivative
1989        // against base penalty index 1, but the configured ρ vector has
1990        // dimension 1 (so only index 0 is valid).  The pair-callback
1991        // builder in `build_tau_penalty_derivative_data` is responsible for
1992        // validating both first- and second-order penalty indices against
1993        // `rho.len()`; this test pins that contract.
1994        //
1995        // We deliberately keep `firth_bias_reduction = true` here so the
1996        // call site exercises the full Firth/Tierney–Kadane outer pipeline:
1997        // PIRLS + ext-coord construction + pair-callback assembly.  With
1998        // analytic c/d propagation now wired in
1999        // `tk_direct_gradient_from_cd_and_design`, there is no longer any
2000        // FD-fallback rejection on this path, so the out-of-bounds error
2001        // fired by the pair-callback builder is the first failure the
2002        // joint evaluator surfaces — and that is exactly what we want this
2003        // test to assert.
2004        let y = array![0.0, 1.0, 0.0, 1.0];
2005        let w = Array1::<f64>::ones(y.len());
2006        let x = array![
2007            [1.0, -0.5, 0.2],
2008            [1.0, -0.1, -0.3],
2009            [1.0, 0.4, 0.6],
2010            [1.0, 0.9, -0.2],
2011        ];
2012        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
2013        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
2014        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2015        let theta = array![0.0, 0.0];
2016        let hyper_dirs = vec![
2017            DirectionalHyperParam::new(
2018                Array2::<f64>::zeros((x.nrows(), x.ncols())),
2019                vec![(0, Array2::<f64>::zeros((x.ncols(), x.ncols())))],
2020                None,
2021                Some(vec![Some(vec![(1, Array2::<f64>::eye(x.ncols()))])]),
2022            )
2023            .expect("hyper direction with invalid second-order penalty index"),
2024        ];
2025
2026        let msg = match compute_joint_hypercostgradienthessian(&state, &theta, 1, &hyper_dirs) {
2027            Ok(_) => panic!("invalid second-order penalty index should be rejected"),
2028            Err(err) => err.to_string(),
2029        };
2030        assert!(
2031            msg.contains("out of bounds") || msg.contains("penalty_index"),
2032            "unexpected validation error: {msg}"
2033        );
2034    }
2035
2036    #[test]
2037    pub(crate) fn joint_tau_tau_analytic_matchesfd_reference() {
2038        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
2039        let w = Array1::<f64>::ones(y.len());
2040        let x = array![
2041            [1.0, -1.2, 0.3],
2042            [1.0, -0.8, -0.4],
2043            [1.0, -0.3, 0.7],
2044            [1.0, 0.1, -0.9],
2045            [1.0, 0.5, 0.2],
2046            [1.0, 0.9, -0.1],
2047            [1.0, 1.3, 0.8],
2048            [1.0, 1.7, -0.6],
2049        ];
2050        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
2051        let cfg =
2052            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
2053        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2054        let rho = array![0.0];
2055        let psi = array![0.0, 0.0];
2056        let hyper_dirs = vec![
2057            DirectionalHyperParam::single_penalty(
2058                0,
2059                Array2::<f64>::zeros((x.nrows(), x.ncols())),
2060                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
2061                None,
2062                None,
2063            )
2064            .expect("single-penalty hyper direction"),
2065            DirectionalHyperParam::single_penalty(
2066                0,
2067                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2068                Array2::<f64>::zeros((x.ncols(), x.ncols())),
2069                None,
2070                None,
2071            )
2072            .expect("single-penalty hyper direction"),
2073        ];
2074
2075        let theta = {
2076            let mut t = Array1::<f64>::zeros(rho.len() + psi.len());
2077            t.slice_mut(s![..rho.len()]).assign(&rho);
2078            t.slice_mut(s![rho.len()..]).assign(&psi);
2079            t
2080        };
2081        let (_, _, h_full) =
2082            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2083                .expect("joint hyper cost+gradient+hessian");
2084        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2085        assert_eq!(h_tt_analytic.nrows(), hyper_dirs.len());
2086        assert_eq!(h_tt_analytic.ncols(), hyper_dirs.len());
2087
2088        // FD via physical perturbation of design/penalty matrices (matching
2089        // the V_tau FD pattern).  For column j we perturb X and S₀ along
2090        // direction j, build fresh states, and evaluate the τ-gradient for
2091        // every direction i at those perturbed states.
2092        let x_tau_mats: Vec<Array2<f64>> = vec![
2093            Array2::<f64>::zeros((x.nrows(), x.ncols())),
2094            Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2095        ];
2096        let s_tau_mats: Vec<Array2<f64>> = vec![
2097            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
2098            Array2::<f64>::zeros((x.ncols(), x.ncols())),
2099        ];
2100
2101        let h_ttfd = directional_tau_hessian_fd_reference(
2102            &y,
2103            &w,
2104            &x,
2105            &s0,
2106            &cfg,
2107            &rho,
2108            &hyper_dirs,
2109            &x_tau_mats,
2110            &s_tau_mats,
2111        );
2112
2113        let num = (&h_tt_analytic - &h_ttfd)
2114            .iter()
2115            .map(|v| v * v)
2116            .sum::<f64>()
2117            .sqrt();
2118        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2119        let rel = num / den;
2120        assert!(
2121            rel < 1e-4,
2122            "analytic tau-tau block deviates from FD reference: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2123        );
2124    }
2125
2126    // ── Profiled Gaussian REML coverage for design-moving τ-directions ──
2127    //
2128    // The existing directional-hyper tests all use BinomialLogit, which has
2129    // DispersionHandling::Fixed.  These tests validate the profiled Gaussian
2130    // path (DispersionHandling::ProfiledGaussian) with design-moving
2131    // τ-directions, where the profiled scale φ̂ = D_p/(n−M) depends on ρ
2132    // and the envelope-theorem rescaling by (n−M)/D_p must be correct.
2133
    /// Shared test fixture for profiled Gaussian REML tests.
    ///
    /// Bundles the data, penalty, configuration and base `rho` used to build
    /// a state, together with two ready-made τ-direction matrices (one
    /// design-moving, one penalty-only).
    pub(crate) struct GaussianRemlFixture {
        /// Observation vector (presumably the response — confirm against the
        /// state builder that consumes it).
        pub(crate) y: Array1<f64>,
        /// Per-observation vector with the same length as `y` (presumably
        /// prior weights).
        pub(crate) w: Array1<f64>,
        /// Design matrix.
        pub(crate) x: Array2<f64>,
        /// Base penalty matrix.
        pub(crate) s0: Array2<f64>,
        /// REML configuration used when building states from this fixture.
        pub(crate) cfg: RemlConfig,
        /// Base point `rho` at which the tests evaluate.
        pub(crate) rho: Array1<f64>,
        /// Design-moving τ-direction (non-zero X_τ, zero S_τ).
        pub(crate) x_tau_design: Array2<f64>,
        /// Penalty-only τ-direction (zero X_τ, non-zero S_τ).
        pub(crate) s_tau_penalty: Array2<f64>,
    }
2147
2148    impl GaussianRemlFixture {
2149        pub(crate) fn new() -> Self {
2150            let y = array![0.5, 1.2, -0.3, 0.8, 1.1, -0.6, 0.9, 0.1, -0.2, 0.7];
2151            let x = array![
2152                [1.0, -1.2, 0.3],
2153                [1.0, -0.8, -0.4],
2154                [1.0, -0.3, 0.7],
2155                [1.0, 0.1, -0.9],
2156                [1.0, 0.5, 0.2],
2157                [1.0, 0.9, -0.1],
2158                [1.0, 1.3, 0.8],
2159                [1.0, 1.7, -0.6],
2160                [1.0, -0.5, 0.5],
2161                [1.0, 0.3, -0.3],
2162            ];
2163            Self {
2164                w: Array1::<f64>::ones(y.len()),
2165                y,
2166                x: x.clone(),
2167                s0: array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]],
2168                cfg: RemlConfig::external(gaussian_identity_glm_spec(), 1e-14, false),
2169                rho: array![0.0],
2170                x_tau_design: array![
2171                    [0.0, 1e-3, -2e-3],
2172                    [0.0, -3e-3, 1e-3],
2173                    [0.0, 2e-3, 0.5e-3],
2174                    [0.0, -1e-3, 3e-3],
2175                    [0.0, 0.5e-3, -1e-3],
2176                    [0.0, 1.5e-3, 2e-3],
2177                    [0.0, -2e-3, -0.5e-3],
2178                    [0.0, 3e-3, 1e-3],
2179                    [0.0, -0.5e-3, 2e-3],
2180                    [0.0, 1e-3, -1.5e-3],
2181                ],
2182                s_tau_penalty: array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
2183            }
2184        }
2185    }
2186
    // Exposes the fixture's fields through the `LogitDesignMotionFixture`
    // accessors; every method is a plain borrow of the field of the same name.
    // NOTE(review): the tests call `state()` and `fd_directional_gradient()` on
    // this fixture; presumably those are provided by the trait on top of these
    // accessors — the trait definition is outside this view.
    impl LogitDesignMotionFixture for GaussianRemlFixture {
        fn y(&self) -> &Array1<f64> {
            &self.y
        }
        fn w(&self) -> &Array1<f64> {
            &self.w
        }
        fn x(&self) -> &Array2<f64> {
            &self.x
        }
        fn s0(&self) -> &Array2<f64> {
            &self.s0
        }
        fn cfg(&self) -> &RemlConfig {
            &self.cfg
        }
        fn rho(&self) -> &Array1<f64> {
            &self.rho
        }
    }
2207
2208    #[test]
2209    pub(crate) fn profiled_gaussian_design_moving_gradient_matches_fd() {
2210        let f = GaussianRemlFixture::new();
2211        let state = f.state();
2212        let s_tau = Array2::<f64>::zeros((3, 3));
2213        let hyper = DirectionalHyperParam::single_penalty(
2214            0,
2215            f.x_tau_design.clone(),
2216            s_tau.clone(),
2217            None,
2218            None,
2219        )
2220        .expect("design-moving hyper direction");
2221
2222        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2223            .expect("analytic directional gradient");
2224        let v_taufd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2225
2226        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2227        assert!(
2228            v_rel < 1e-3,
2229            "Gaussian REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2230             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2231        );
2232    }
2233
2234    #[test]
2235    pub(crate) fn profiled_gaussian_penalty_only_gradient_matches_fd() {
2236        let f = GaussianRemlFixture::new();
2237        let state = f.state();
2238        let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2239        let hyper = DirectionalHyperParam::single_penalty(
2240            0,
2241            x_tau.clone(),
2242            f.s_tau_penalty.clone(),
2243            None,
2244            None,
2245        )
2246        .expect("penalty-only hyper direction");
2247
2248        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2249            .expect("analytic directional gradient");
2250        let v_taufd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2251
2252        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2253        assert!(
2254            v_rel < 1e-3,
2255            "Gaussian REML penalty-only V_tau mismatch: rel={v_rel:.3e}, \
2256             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2257        );
2258    }
2259
2260    #[test]
2261    pub(crate) fn profiled_gaussian_joint_hessian_matches_fd() {
2262        // Validate the ττ Hessian block under profiled Gaussian REML with
2263        // both a penalty-only and a design-moving direction.
2264        let f = GaussianRemlFixture::new();
2265        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2266        let s_tau_0 = f.s_tau_penalty.clone();
2267        let x_tau_1 = f.x_tau_design.clone();
2268        let s_tau_1 = Array2::<f64>::zeros((3, 3));
2269
2270        let hyper_dirs = vec![
2271            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2272                .expect("penalty-only direction"),
2273            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2274                .expect("design-moving direction"),
2275        ];
2276
2277        let state = f.state();
2278        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2279        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2280        let (_, _, h_full) =
2281            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2282                .expect("joint cost+gradient+hessian");
2283        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2284
2285        // Finite-difference Hessian: perturb each direction, re-evaluate
2286        // gradient of all directions at perturbed states.
2287        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2288        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2289        let h_ttfd = directional_tau_hessian_fd_reference(
2290            &f.y,
2291            &f.w,
2292            &f.x,
2293            &f.s0,
2294            &f.cfg,
2295            &f.rho,
2296            &hyper_dirs,
2297            &x_tau_mats,
2298            &s_tau_mats,
2299        );
2300
2301        let num = (&h_tt_analytic - &h_ttfd)
2302            .iter()
2303            .map(|v| v * v)
2304            .sum::<f64>()
2305            .sqrt();
2306        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2307        let rel = num / den;
2308        assert!(
2309            rel < 1e-4,
2310            "Gaussian REML tau-tau Hessian mismatch: rel={rel:.3e}, \
2311             analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2312        );
2313    }
2314
2315    // ── Non-Gaussian + design-motion: IFT Hessian-drift coverage ────────
2316    //
2317    // For non-Gaussian links (logit, probit, cloglog, ...), H = X'W(η)X + S
2318    // depends on β̂ through η = Xβ̂.  When ψ moves the design, the total
2319    // Hessian drift dH/dψ includes an IFT contribution from dβ̂/dψ:
2320    //
2321    //   dH/dψ = [explicit at fixed β] + X' diag(c ⊙ X(-v_i)) X
2322    //
2323    // where v_i = H⁻¹ g_i.  The standard GLM path handles this via
2324    // `hessian_derivative_correction(v_i)`.  This test validates that the
2325    // gradient is correct for logit + design-moving ψ, which would fail if
2326    // the IFT correction were missing.
2327
2328    #[test]
2329    pub(crate) fn logit_design_moving_gradient_matches_fd() {
2330        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2331        let w = Array1::<f64>::ones(y.len());
2332        let x = array![
2333            [1.0, -1.2, 0.3],
2334            [1.0, -0.8, -0.4],
2335            [1.0, -0.3, 0.7],
2336            [1.0, 0.1, -0.9],
2337            [1.0, 0.5, 0.2],
2338            [1.0, 0.9, -0.1],
2339            [1.0, 1.3, 0.8],
2340            [1.0, 1.7, -0.6],
2341            [1.0, -0.5, 0.5],
2342            [1.0, 0.3, -0.3],
2343        ];
2344        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2345        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2346        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2347        let rho = array![0.0];
2348
2349        // Design-moving direction with non-zero X_τ.
2350        let x_tau = array![
2351            [0.0, 1e-3, -2e-3],
2352            [0.0, -3e-3, 1e-3],
2353            [0.0, 2e-3, 0.5e-3],
2354            [0.0, -1e-3, 3e-3],
2355            [0.0, 0.5e-3, -1e-3],
2356            [0.0, 1.5e-3, 2e-3],
2357            [0.0, -2e-3, -0.5e-3],
2358            [0.0, 3e-3, 1e-3],
2359            [0.0, -0.5e-3, 2e-3],
2360            [0.0, 1e-3, -1.5e-3],
2361        ];
2362        let s_tau = Array2::<f64>::zeros((3, 3));
2363        let hyper =
2364            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
2365                .expect("design-moving hyper direction");
2366
2367        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2368            .expect("analytic directional gradient");
2369
2370        let h = 2e-5;
2371        let x_plus = &x + &x_tau.mapv(|v| h * v);
2372        let x_minus = &x - &x_tau.mapv(|v| h * v);
2373        let state_plus = build_logit_state(&y, &w, &x_plus, &s0, &cfg);
2374        let state_minus = build_logit_state(&y, &w, &x_minus, &s0, &cfg);
2375        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2376        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2377        let v_taufd = (v_plus - v_minus) / (2.0 * h);
2378
2379        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2380        assert!(
2381            v_rel < 1e-3,
2382            "Logit REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2383             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2384        );
2385    }
2386
2387    #[test]
2388    pub(crate) fn logit_design_moving_hessian_matches_fd() {
2389        // Hessian-level validation for logit + design-motion.
2390        // The IFT correction enters the trace term through
2391        // hessian_derivative_correction(v_i), so the Hessian is the most
2392        // sensitive test of whether the correction is applied correctly.
2393        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2394        let w = Array1::<f64>::ones(y.len());
2395        let x = array![
2396            [1.0, -1.2, 0.3],
2397            [1.0, -0.8, -0.4],
2398            [1.0, -0.3, 0.7],
2399            [1.0, 0.1, -0.9],
2400            [1.0, 0.5, 0.2],
2401            [1.0, 0.9, -0.1],
2402            [1.0, 1.3, 0.8],
2403            [1.0, 1.7, -0.6],
2404            [1.0, -0.5, 0.5],
2405            [1.0, 0.3, -0.3],
2406        ];
2407        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2408        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2409        let rho = array![0.0];
2410
2411        // Two directions: one penalty-only, one design-moving.
2412        let x_tau_0 = Array2::<f64>::zeros(x.raw_dim());
2413        let s_tau_0 = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]];
2414        let x_tau_1 = array![
2415            [0.0, 1e-3, -2e-3],
2416            [0.0, -3e-3, 1e-3],
2417            [0.0, 2e-3, 0.5e-3],
2418            [0.0, -1e-3, 3e-3],
2419            [0.0, 0.5e-3, -1e-3],
2420            [0.0, 1.5e-3, 2e-3],
2421            [0.0, -2e-3, -0.5e-3],
2422            [0.0, 3e-3, 1e-3],
2423            [0.0, -0.5e-3, 2e-3],
2424            [0.0, 1e-3, -1.5e-3],
2425        ];
2426        let s_tau_1 = Array2::<f64>::zeros((3, 3));
2427
2428        let hyper_dirs = vec![
2429            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2430                .expect("penalty-only direction"),
2431            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2432                .expect("design-moving direction"),
2433        ];
2434
2435        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2436        let mut theta = Array1::<f64>::zeros(rho.len() + hyper_dirs.len());
2437        theta.slice_mut(s![..rho.len()]).assign(&rho);
2438        let (_, _, h_full) =
2439            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2440                .expect("joint cost+gradient+hessian");
2441        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2442
2443        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2444        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2445        let h_ttfd = directional_tau_hessian_fd_reference(
2446            &y,
2447            &w,
2448            &x,
2449            &s0,
2450            &cfg,
2451            &rho,
2452            &hyper_dirs,
2453            &x_tau_mats,
2454            &s_tau_mats,
2455        );
2456
2457        let num = (&h_tt_analytic - &h_ttfd)
2458            .iter()
2459            .map(|v| v * v)
2460            .sum::<f64>()
2461            .sqrt();
2462        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2463        let rel = num / den;
2464        assert!(
2465            rel < 1e-4,
2466            "Logit REML design-moving tau-tau Hessian mismatch: rel={rel:.3e}, \
2467             analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2468        );
2469    }
2470
2471    // ── Larger non-Gaussian + design-motion fixture (n=30, p=5) ────────
2472    //
2473    // Validates the IFT correction (hessian_derivative_correction) at a
2474    // scale large enough that the correction is numerically non-trivial:
    // with n=30 and p=5, the logistic working weights W(η) are far from the
    // identity and the IFT term dβ̂/dψ contributes meaningfully.
2477
2478    /// Shared test fixture for binomial-logit REML with design-moving
2479    /// ψ-coordinates, n=30, p=5.
2480    pub(crate) struct BinomialLogitDesignMotionFixture {
2481        pub(crate) y: Array1<f64>,
2482        pub(crate) w: Array1<f64>,
2483        pub(crate) x: Array2<f64>,
2484        pub(crate) s0: Array2<f64>,
2485        pub(crate) cfg: RemlConfig,
2486        pub(crate) rho: Array1<f64>,
2487        /// Design-moving τ-direction: non-zero X_τ, zero S_τ.
2488        pub(crate) x_tau_design: Array2<f64>,
2489        /// Penalty-only τ-direction: zero X_τ, non-zero S_τ.
2490        pub(crate) s_tau_penalty: Array2<f64>,
2491    }
2492
2493    impl BinomialLogitDesignMotionFixture {
2494        pub(crate) fn new() -> Self {
2495            // Binary response with roughly balanced classes.
2496            let y = array![
2497                1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
2498                1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
2499            ];
2500            // Design matrix: intercept + 4 covariate columns with varied magnitudes.
2501            let x = array![
2502                [1.0, -1.50, 0.42, 0.88, -0.31],
2503                [1.0, -1.12, -0.65, 0.14, 1.23],
2504                [1.0, -0.80, 1.10, -0.53, 0.07],
2505                [1.0, -0.55, -0.22, 1.40, -0.90],
2506                [1.0, -0.30, 0.73, -1.05, 0.44],
2507                [1.0, -0.05, -1.33, 0.60, 0.81],
2508                [1.0, 0.18, 0.55, -0.27, -1.15],
2509                [1.0, 0.42, -0.90, 1.12, 0.33],
2510                [1.0, 0.70, 1.28, -0.78, -0.56],
2511                [1.0, 0.95, -0.18, 0.45, 1.40],
2512                [1.0, 1.20, 0.66, -1.30, -0.02],
2513                [1.0, 1.45, -1.05, 0.22, 0.68],
2514                [1.0, -1.35, 0.90, 0.55, -0.43],
2515                [1.0, -0.98, -0.40, -0.88, 1.05],
2516                [1.0, -0.62, 1.42, 0.30, -0.70],
2517                [1.0, -0.28, -0.77, -1.18, 0.52],
2518                [1.0, 0.05, 0.15, 0.95, -1.35],
2519                [1.0, 0.33, -1.20, -0.40, 0.18],
2520                [1.0, 0.60, 0.82, 1.25, -0.85],
2521                [1.0, 0.88, -0.50, -0.65, 1.10],
2522                [1.0, 1.15, 1.05, 0.10, -0.22],
2523                [1.0, -1.22, -0.95, 0.72, 0.90],
2524                [1.0, -0.75, 0.38, -1.42, 0.15],
2525                [1.0, -0.42, -1.15, 0.50, -1.08],
2526                [1.0, -0.10, 0.60, -0.15, 0.75],
2527                [1.0, 0.25, -0.28, 1.05, -0.48],
2528                [1.0, 0.52, 1.35, -0.92, 0.30],
2529                [1.0, 0.80, -0.70, 0.38, 1.20],
2530                [1.0, 1.08, 0.48, -0.60, -0.95],
2531                [1.0, 1.35, -0.55, 0.85, 0.42]
2532            ];
2533            // Penalty matrix: zero on intercept, SPD on remaining 4 columns.
2534            let s0 = array![
2535                [0.0, 0.0, 0.0, 0.0, 0.0],
2536                [0.0, 1.40, 0.15, 0.05, -0.10],
2537                [0.0, 0.15, 1.10, -0.20, 0.08],
2538                [0.0, 0.05, -0.20, 0.95, 0.12],
2539                [0.0, -0.10, 0.08, 0.12, 1.25]
2540            ];
2541            let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2542            // Design-moving direction: perturb covariate columns, leave
2543            // intercept untouched.
2544            let x_tau_design = array![
2545                [0.0, 1.2e-3, -0.8e-3, 0.5e-3, -1.5e-3],
2546                [0.0, -2.0e-3, 1.4e-3, -0.3e-3, 0.9e-3],
2547                [0.0, 0.6e-3, -1.1e-3, 1.8e-3, -0.4e-3],
2548                [0.0, -1.3e-3, 0.7e-3, -1.0e-3, 2.1e-3],
2549                [0.0, 0.9e-3, -0.5e-3, 0.2e-3, -0.8e-3],
2550                [0.0, -0.4e-3, 1.8e-3, -1.5e-3, 0.3e-3],
2551                [0.0, 1.5e-3, -1.3e-3, 0.8e-3, -1.1e-3],
2552                [0.0, -0.7e-3, 0.4e-3, -2.0e-3, 1.6e-3],
2553                [0.0, 2.2e-3, -0.9e-3, 1.3e-3, -0.6e-3],
2554                [0.0, -1.0e-3, 1.6e-3, -0.7e-3, 0.5e-3],
2555                [0.0, 0.3e-3, -2.1e-3, 1.1e-3, -1.8e-3],
2556                [0.0, -1.8e-3, 0.2e-3, -0.4e-3, 1.3e-3],
2557                [0.0, 1.1e-3, -1.5e-3, 2.0e-3, -0.2e-3],
2558                [0.0, -0.5e-3, 0.9e-3, -1.2e-3, 0.7e-3],
2559                [0.0, 1.7e-3, -0.3e-3, 0.6e-3, -2.0e-3],
2560                [0.0, -1.4e-3, 1.1e-3, -0.9e-3, 0.4e-3],
2561                [0.0, 0.8e-3, -1.7e-3, 1.5e-3, -0.1e-3],
2562                [0.0, -0.2e-3, 0.6e-3, -1.8e-3, 1.0e-3],
2563                [0.0, 1.4e-3, -0.4e-3, 0.3e-3, -1.3e-3],
2564                [0.0, -0.9e-3, 2.0e-3, -0.5e-3, 0.8e-3],
2565                [0.0, 0.5e-3, -1.0e-3, 1.6e-3, -0.7e-3],
2566                [0.0, -2.1e-3, 0.3e-3, -0.8e-3, 1.5e-3],
2567                [0.0, 0.7e-3, -1.8e-3, 0.9e-3, -0.3e-3],
2568                [0.0, -0.6e-3, 1.3e-3, -2.2e-3, 1.1e-3],
2569                [0.0, 1.9e-3, -0.7e-3, 0.4e-3, -0.9e-3],
2570                [0.0, -1.1e-3, 0.5e-3, -1.4e-3, 2.2e-3],
2571                [0.0, 0.4e-3, -1.6e-3, 1.2e-3, -0.5e-3],
2572                [0.0, -1.6e-3, 0.8e-3, -0.1e-3, 0.6e-3],
2573                [0.0, 1.3e-3, -2.2e-3, 0.7e-3, -1.4e-3],
2574                [0.0, -0.3e-3, 1.0e-3, -1.6e-3, 1.8e-3]
2575            ];
2576            // Penalty-only direction: non-zero S_τ, symmetric, zero on intercept.
2577            let s_tau_penalty = array![
2578                [0.0, 0.0, 0.0, 0.0, 0.0],
2579                [0.0, 0.30, 0.05, -0.02, 0.04],
2580                [0.0, 0.05, 0.22, 0.03, -0.01],
2581                [0.0, -0.02, 0.03, 0.18, 0.06],
2582                [0.0, 0.04, -0.01, 0.06, 0.26]
2583            ];
2584            Self {
2585                w: Array1::<f64>::ones(y.len()),
2586                y,
2587                x,
2588                s0,
2589                cfg,
2590                rho: array![0.0],
2591                x_tau_design,
2592                s_tau_penalty,
2593            }
2594        }
2595    }
2596
2597    impl LogitDesignMotionFixture for BinomialLogitDesignMotionFixture {
2598        fn y(&self) -> &Array1<f64> {
2599            &self.y
2600        }
2601        fn w(&self) -> &Array1<f64> {
2602            &self.w
2603        }
2604        fn x(&self) -> &Array2<f64> {
2605            &self.x
2606        }
2607        fn s0(&self) -> &Array2<f64> {
2608            &self.s0
2609        }
2610        fn cfg(&self) -> &RemlConfig {
2611            &self.cfg
2612        }
2613        fn rho(&self) -> &Array1<f64> {
2614            &self.rho
2615        }
2616    }
2617
2618    // ── n=30, p=5 binomial-logit design-motion gradient tests ────────
2619
2620    #[test]
2621    pub(crate) fn binomial_logit_n30_design_moving_gradient_matches_fd() {
2622        // Pure design-motion: X_τ ≠ 0, S_τ = 0.
2623        // The IFT correction is essential here: because the family is
2624        // binomial-logit, the working weights W(η) depend on β̂, so
2625        // when X moves with ψ, the implicit derivative dβ̂/dψ enters
2626        // the total Hessian drift.  Without hessian_derivative_correction
2627        // the analytic gradient would disagree with FD.
2628        let f = BinomialLogitDesignMotionFixture::new();
2629        let state = f.state();
2630        let s_tau = Array2::<f64>::zeros((5, 5));
2631        let hyper = DirectionalHyperParam::single_penalty(
2632            0,
2633            f.x_tau_design.clone(),
2634            s_tau.clone(),
2635            None,
2636            None,
2637        )
2638        .expect("design-moving hyper direction");
2639
2640        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2641            .expect("analytic directional gradient");
2642        let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2643
2644        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2645        assert!(
2646            v_rel < 1e-3,
2647            "Binomial-logit n=30 design-moving gradient mismatch: rel={v_rel:.3e}, \
2648             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2649        );
2650    }
2651
2652    #[test]
2653    pub(crate) fn binomial_logit_n30_penalty_only_gradient_matches_fd() {
2654        // Penalty-only direction: X_τ = 0, S_τ ≠ 0.
2655        // Serves as a baseline: the IFT correction should still be
2656        // present (since H depends on β̂ through W(η)), but the
2657        // explicit X_τ contribution is zero.
2658        let f = BinomialLogitDesignMotionFixture::new();
2659        let state = f.state();
2660        let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2661        let hyper = DirectionalHyperParam::single_penalty(
2662            0,
2663            x_tau.clone(),
2664            f.s_tau_penalty.clone(),
2665            None,
2666            None,
2667        )
2668        .expect("penalty-only hyper direction");
2669
2670        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2671            .expect("analytic directional gradient");
2672        let v_tau_fd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2673
2674        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2675        assert!(
2676            v_rel < 1e-3,
2677            "Binomial-logit n=30 penalty-only gradient mismatch: rel={v_rel:.3e}, \
2678             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2679        );
2680    }
2681
2682    #[test]
2683    pub(crate) fn binomial_logit_n30_joint_design_penalty_gradient_matches_fd() {
2684        // Joint direction: both X_τ ≠ 0 and S_τ ≠ 0 simultaneously.
2685        // This is the hardest case: the analytic gradient must correctly
2686        // combine the explicit penalty drift, the explicit design drift,
2687        // and the IFT Hessian-drift correction.
2688        let f = BinomialLogitDesignMotionFixture::new();
2689        let state = f.state();
2690        let hyper = DirectionalHyperParam::single_penalty(
2691            0,
2692            f.x_tau_design.clone(),
2693            f.s_tau_penalty.clone(),
2694            None,
2695            None,
2696        )
2697        .expect("joint design+penalty hyper direction");
2698
2699        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2700            .expect("analytic directional gradient");
2701        let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &f.s_tau_penalty);
2702
2703        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2704        assert!(
2705            v_rel < 1e-3,
2706            "Binomial-logit n=30 joint design+penalty gradient mismatch: rel={v_rel:.3e}, \
2707             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2708        );
2709    }
2710
2711    #[test]
2712    pub(crate) fn binomial_logit_n30_design_moving_hessian_matches_fd() {
2713        // Hessian-level validation with two τ-directions: one
2714        // penalty-only and one design-moving.  The ττ Hessian block is
2715        // the most sensitive test of the IFT correction because errors
2716        // in the correction accumulate quadratically in the trace term.
2717        let f = BinomialLogitDesignMotionFixture::new();
2718        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2719        let s_tau_0 = f.s_tau_penalty.clone();
2720        let x_tau_1 = f.x_tau_design.clone();
2721        let s_tau_1 = Array2::<f64>::zeros((5, 5));
2722
2723        let hyper_dirs = vec![
2724            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2725                .expect("penalty-only direction"),
2726            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2727                .expect("design-moving direction"),
2728        ];
2729
2730        let state = f.state();
2731        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2732        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2733        let (_, _, h_full) =
2734            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2735                .expect("joint cost+gradient+hessian");
2736        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2737
2738        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2739        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2740        let h_tt_fd = directional_tau_hessian_fd_reference(
2741            &f.y,
2742            &f.w,
2743            &f.x,
2744            &f.s0,
2745            &f.cfg,
2746            &f.rho,
2747            &hyper_dirs,
2748            &x_tau_mats,
2749            &s_tau_mats,
2750        );
2751
2752        let num = (&h_tt_analytic - &h_tt_fd)
2753            .iter()
2754            .map(|v| v * v)
2755            .sum::<f64>()
2756            .sqrt();
2757        let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2758        let rel = num / den;
2759        assert!(
2760            rel < 1e-4,
2761            "Binomial-logit n=30 tau-tau Hessian mismatch: rel={rel:.3e}, \
2762             analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2763        );
2764    }
2765
2766    #[test]
2767    pub(crate) fn binomial_logit_n30_nonzero_rho_design_moving_gradient_matches_fd() {
2768        // Validate at a non-trivial smoothing parameter ρ = log(λ) = 1.5,
2769        // so the penalty term λS is scaled up and the balance between
2770        // likelihood and penalty is different from ρ=0.
2771        let f = BinomialLogitDesignMotionFixture::new();
2772        let rho = array![1.5];
2773        let s_tau = Array2::<f64>::zeros((5, 5));
2774
2775        let state = f.state();
2776        let hyper = DirectionalHyperParam::single_penalty(
2777            0,
2778            f.x_tau_design.clone(),
2779            s_tau.clone(),
2780            None,
2781            None,
2782        )
2783        .expect("design-moving hyper direction");
2784
2785        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2786            .expect("analytic directional gradient");
2787
2788        // FD at the shifted ρ: perturb X, re-solve inner, evaluate cost.
2789        let h = 2e-5;
2790        let (state_plus, state_minus) = f.state_perturbed(&f.x_tau_design, &s_tau, h);
2791        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2792        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2793        let v_tau_fd = (v_plus - v_minus) / (2.0 * h);
2794
2795        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2796        assert!(
2797            v_rel < 1e-3,
2798            "Binomial-logit n=30 rho=1.5 design-moving gradient mismatch: rel={v_rel:.3e}, \
2799             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2800        );
2801    }
2802
2803    #[test]
2804    pub(crate) fn binomial_logit_n30_rank_deficient_hessian_matches_cost_fd() {
2805        // Regression lock for the `PenaltySubspaceTrace` pseudo-logdet
2806        // kernel installed by the rank-deficient LAML fix (see
2807        // `PenaltySubspaceTrace` and `intrinsic_hessian_pseudo_logdet_parts`;
2808        // since #901 the cost is the intrinsic `½ log|H_pen|₊` and the kernel
2809        // is the spectral `H_pen⁺`, exact for every drift direction).
2810        //
2811        // The sibling `binomial_logit_n30_design_moving_hessian_matches_fd`
2812        // passes pre- AND post-fix because its FD reference differentiates
2813        // the *analytic gradient* — any self-consistent (if wrong) gradient
2814        // kernel gives a self-consistent Hessian under re-differentiation,
2815        // so that test cannot distinguish full-space from projected traces.
2816        // It passed under the buggy kernel because the same leakage entered
2817        // both sides of the ratio and cancelled.
2818        //
2819        // Here we FD-differentiate `compute_cost` TWICE and compare against
2820        // the analytic Hessian.  Central second differences expose every
2821        // disagreement between `½ log|U_Sᵀ H U_S|_+` (used by the cost) and
2822        // `½ tr(G_ε(H) · Ḣ)` / `−½ tr(G_ε Ḣ_i G_ε Ḣ_j)` (the full-space
2823        // traces that the gradient and Hessian used before the projection
2824        // fix).  Under the buggy kernel the IFT correction
2825        // `D_β H[v] = X' diag(c ⊙ X v) X` leaks onto `null(S)` — X's
2826        // all-ones intercept column sits there — and that leakage enters
2827        // the analytic Hessian but not the cost's projected logdet.
2828        //
2829        // Direction mix chosen to maximise the null-space leakage pathway:
2830        //   τ_0 = penalty-only (X_τ = 0, S_τ ≠ 0)  → v_0 = H⁻¹(−S_τ β̂) is
2831        //         concentrated in range(S_+), but `D_β H[v_0]` has rows and
2832        //         columns on the intercept because `X[:,0] = 1_n`.
2833        //   τ_1 = design-moving (X_τ ≠ 0 on non-intercept columns, S_τ = 0)
2834        //         → `v_1` also picks up the intercept via `X'WX_τβ̂`, and
2835        //         the base drift `X_τᵀWX + XᵀWX_τ` straddles range(S_+) /
2836        //         null(S).
2837        // Both pure directions AND the mixed partial load the Schur correction,
2838        // so any of the three entries can catch a regression.
2839        let f = BinomialLogitDesignMotionFixture::new();
2840        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2841        let s_tau_0 = f.s_tau_penalty.clone();
2842        let x_tau_1 = f.x_tau_design.clone();
2843        let s_tau_1 = Array2::<f64>::zeros((5, 5));
2844
2845        let hyper_dirs = vec![
2846            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2847                .expect("penalty-only direction"),
2848            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2849                .expect("design-moving direction"),
2850        ];
2851
2852        // Analytic Hessian block.
2853        let state = f.state();
2854        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2855        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2856        let (_, _, h_full) =
2857            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2858                .expect("joint cost+gradient+hessian");
2859        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2860
2861        // Cost-level FD reference.  Central second differences give O(h²)
2862        // accuracy; the step is sized so the physical perturbation on X / S
2863        // stays near `1e-5` (same scale as the gradient tests).
2864        const TARGET_PHYSICAL_STEP: f64 = 1e-5;
2865        let x_tau_mats = [&x_tau_0, &x_tau_1];
2866        let s_tau_mats = [&s_tau_0, &s_tau_1];
2867        let steps: [f64; 2] = {
2868            let mut steps = [0.0; 2];
2869            for (j, step) in steps.iter_mut().enumerate() {
2870                let scale = x_tau_mats[j]
2871                    .iter()
2872                    .chain(s_tau_mats[j].iter())
2873                    .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2874                *step = if scale > 0.0 {
2875                    TARGET_PHYSICAL_STEP / scale
2876                } else {
2877                    TARGET_PHYSICAL_STEP
2878                };
2879            }
2880            steps
2881        };
2882
2883        // Evaluate `compute_cost` at `(a · τ_0, b · τ_1)` multipliers.
2884        let eval_cost = |a: f64, b: f64| -> f64 {
2885            let x_eval = &f.x
2886                + &x_tau_mats[0].mapv(|v| a * steps[0] * v)
2887                + &x_tau_mats[1].mapv(|v| b * steps[1] * v);
2888            let s_eval = &f.s0
2889                + &s_tau_mats[0].mapv(|v| a * steps[0] * v)
2890                + &s_tau_mats[1].mapv(|v| b * steps[1] * v);
2891            let st = build_logit_state(&f.y, &f.w, &x_eval, &s_eval, &f.cfg);
2892            st.compute_cost(&f.rho).expect("cost eval")
2893        };
2894
2895        let v_00 = eval_cost(0.0, 0.0);
2896        let v_p0 = eval_cost(1.0, 0.0);
2897        let v_m0 = eval_cost(-1.0, 0.0);
2898        let v_0p = eval_cost(0.0, 1.0);
2899        let v_0m = eval_cost(0.0, -1.0);
2900        let v_pp = eval_cost(1.0, 1.0);
2901        let v_pm = eval_cost(1.0, -1.0);
2902        let v_mp = eval_cost(-1.0, 1.0);
2903        let v_mm = eval_cost(-1.0, -1.0);
2904
2905        let h00_fd = (v_p0 - 2.0 * v_00 + v_m0) / (steps[0] * steps[0]);
2906        let h11_fd = (v_0p - 2.0 * v_00 + v_0m) / (steps[1] * steps[1]);
2907        let h01_fd = (v_pp - v_pm - v_mp + v_mm) / (4.0 * steps[0] * steps[1]);
2908
2909        let h_tt_fd = array![[h00_fd, h01_fd], [h01_fd, h11_fd]];
2910
2911        let num = (&h_tt_analytic - &h_tt_fd)
2912            .iter()
2913            .map(|v| v * v)
2914            .sum::<f64>()
2915            .sqrt();
2916        let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2917        let rel = num / den;
2918
2919        assert!(
2920            rel < 3e-3,
2921            "Binomial-logit n=30 rank-deficient Hessian vs cost-FD mismatch: rel={rel:.3e}, \
2922             analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2923        );
2924    }
2925}
2926
/// Tag naming a REML geometry flavour.
///
/// The enum is a plain field-less tag, so it also derives equality and
/// hashing: values can be compared with `==` and used as map/set keys.
///
/// NOTE(review): what each variant selects is decided by code outside this
/// part of the file; the per-variant notes are read off the names only.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) enum RemlGeometry {
    /// Presumably the dense spectral path — confirm against the consumers.
    DenseSpectral,
    /// Presumably the sparse exact SPD path — confirm against the consumers.
    SparseExactSpd,
}
2932
/// Lets a geometry object report which `GeometryBackendKind` it is.
trait PenalizedGeometry {
    /// Returns the backend kind of this geometry.
    fn backend_kind(&self) -> GeometryBackendKind;
}
2936
/// Storage representations for a derivative matrix.
///
/// Each payload type implements `DerivativeStorageBackend`;
/// `storage_dispatch!` forwards an expression to whichever payload is held.
#[derive(Clone)]
pub(crate) enum DerivativeMatrixStorage {
    /// Fully materialized dense matrix.
    Dense(Array2<f64>),
    /// All-zero matrix that stores only its shape.
    Zero(ZeroDerivativeMatrix),
    /// Dense local block together with the global index range it occupies.
    Embedded(EmbeddedDerivativeMatrix),
    /// Lazy operator backed by an `ImplicitDesignPsiDerivative`.
    Implicit(ImplicitDerivativeOp),
    /// Lazy operator backed by one axis of a `LatentCoordDesignDerivative`.
    LatentCoord(LatentCoordDerivativeOp),
}
2945
2946/// Mechanical surface every `DerivativeMatrixStorage` variant must expose so
2947/// the `HyperDesignDerivative` / `HyperPenaltyDerivative` wrappers can dispatch
2948/// with a single per-call `storage_dispatch!`. Each backend owns its variant's
2949/// substantive math; the wrappers contain only one-line routing.
2950///
2951/// `design_*` variants treat the backend as an X-style operator (rows index
2952/// data, columns index coefficients); `penalty_*` variants treat the backend
2953/// as a square `p×p` penalty in the global coefficient frame. The Embedded
2954/// case is the only variant whose two views genuinely differ (local rows vs
2955/// total_dim square), which is why the two role-specific methods both live in
2956/// one trait rather than two parallel traits.
2957trait DerivativeStorageBackend {
2958    fn resident_byte_count(&self) -> usize;
2959    fn design_nrows(&self) -> usize;
2960    fn design_ncols(&self) -> usize;
2961    fn penalty_dim(&self) -> usize;
2962    fn uses_implicit_storage(&self) -> bool;
2963    fn any_nonzero(&self) -> bool;
2964    fn materialize(&self) -> Array2<f64>;
2965    fn implicit_first_axis_info(
2966        &self,
2967    ) -> Option<(
2968        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
2969        usize,
2970    )>;
2971    fn implicit_axis_count_hint(&self) -> Option<usize>;
2972    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;
2973    fn design_transpose_mul_original(
2974        &self,
2975        v: &Array1<f64>,
2976    ) -> Result<Array1<f64>, EstimationError>;
2977    fn design_transformed(
2978        &self,
2979        qs: &Array2<f64>,
2980        free_basis_opt: Option<&Array2<f64>>,
2981    ) -> Result<Array2<f64>, EstimationError>;
2982    /// Default materialises through `design_transformed` then `.dot(u)`;
2983    /// implicit/latent-coordinate backends override with a direct-operator
2984    /// path that skips the dense materialisation.
2985    fn design_transformed_forward_mul(
2986        &self,
2987        qs: &Array2<f64>,
2988        free_basis_opt: Option<&Array2<f64>>,
2989        u: &Array1<f64>,
2990    ) -> Result<Array1<f64>, EstimationError> {
2991        Ok(self.design_transformed(qs, free_basis_opt)?.dot(u))
2992    }
2993    /// Default materialises through `design_transformed` then `.t().dot(v)`;
2994    /// implicit/latent-coordinate backends override with a direct path.
2995    fn design_transformed_transpose_mul(
2996        &self,
2997        qs: &Array2<f64>,
2998        free_basis_opt: Option<&Array2<f64>>,
2999        v: &Array1<f64>,
3000    ) -> Result<Array1<f64>, EstimationError> {
3001        Ok(self.design_transformed(qs, free_basis_opt)?.t().dot(v))
3002    }
3003    fn penalty_transformed(
3004        &self,
3005        qs: &Array2<f64>,
3006        free_basis_opt: Option<&Array2<f64>>,
3007    ) -> Result<Array2<f64>, EstimationError>;
3008    fn penalty_scaled_add_to(
3009        &self,
3010        target: &mut Array2<f64>,
3011        amp: f64,
3012    ) -> Result<(), EstimationError>;
3013}
3014
/// Evaluates `$action` against the payload of whichever
/// `DerivativeMatrixStorage` variant `$storage` holds, with that payload bound
/// to the identifier given as `$payload`.
///
/// All five variants are listed here, so each wrapper method stays a single
/// dispatch line. The `match` has no wildcard arm: adding a variant produces
/// one hard compile error at this site rather than a silent miss in an
/// individual wrapper.
macro_rules! storage_dispatch {
    ($storage:expr, $payload:ident => $action:expr) => {
        match $storage {
            DerivativeMatrixStorage::Dense($payload) => $action,
            DerivativeMatrixStorage::Zero($payload) => $action,
            DerivativeMatrixStorage::Embedded($payload) => $action,
            DerivativeMatrixStorage::Implicit($payload) => $action,
            DerivativeMatrixStorage::LatentCoord($payload) => $action,
        }
    };
}
3030
/// Shape-only stand-in for a derivative matrix whose entries are all zero.
///
/// Only the two dimensions are stored. The type derives `Debug` and equality
/// so it can appear in diagnostics and be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct ZeroDerivativeMatrix {
    /// Row count of the represented matrix.
    rows: usize,
    /// Column count of the represented matrix.
    cols: usize,
}

impl ZeroDerivativeMatrix {
    /// Creates the placeholder for a `rows × cols` all-zero matrix.
    pub(crate) fn new(rows: usize, cols: usize) -> Self {
        Self { rows, cols }
    }
}
3042
/// Which derivative level the implicit operator should compute.
///
/// The `usize` payloads are the axis indices handed to the implicit operator.
/// Equality and hashing are structural: `SecondCross(d, e)` and
/// `SecondCross(e, d)` compare as different values even though they name the
/// same pair of axes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ImplicitDerivLevel {
    /// ∂X/∂ψ_d
    First(usize),
    /// ∂²X/∂ψ_d²
    SecondDiag(usize),
    /// ∂²X/∂ψ_d∂ψ_e
    SecondCross(usize, usize),
}
3053
/// Lazy implicit operator storage: delegates matvecs to the
/// `ImplicitDesignPsiDerivative` and materializes dense form only on demand.
#[derive(Clone)]
pub(crate) struct ImplicitDerivativeOp {
    /// Shared implicit operator that performs the matvecs and the local
    /// materialization.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
    /// Derivative level, including the axis indices, that this storage
    /// represents.
    pub(crate) level: ImplicitDerivLevel,
    /// Column range of the global matrix occupied by the operator's local
    /// block.
    pub(crate) global_range: Range<usize>,
    /// Column count of the global (embedded) matrix.
    pub(crate) total_dim: usize,
    /// Cached dense materialization (lazy, populated on first call to ops that need the full matrix).
    ///
    /// Rayon-safe: `materialize_local` calls `materialize_first` / `_second_diag`
    /// / `_second_cross` on the implicit basis-derivative operator, which for
    /// streaming bases dispatches `(0..nc).into_par_iter().for_each(...)`. A plain
    /// `std::sync::OnceLock` here would deadlock if `materialize_dense` were first
    /// called concurrently from inside another rayon par_iter — racing workers
    /// would park on the OnceLock's OS condvar, leaving the leader's nested
    /// par_iter without workers. `RayonSafeOnce` runs init lock-free.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3073
/// Lazy storage for one axis of a `LatentCoordDesignDerivative`: matvecs are
/// delegated to the operator and the dense form is built only on demand.
#[derive(Clone)]
pub(crate) struct LatentCoordDerivativeOp {
    /// Shared latent-coordinate operator that performs the matvecs and the
    /// local materialization.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
    /// Axis index passed to every operator call.
    pub(crate) flat_axis: usize,
    /// Column range of the global matrix occupied by the operator's local
    /// block.
    pub(crate) global_range: Range<usize>,
    /// Column count of the global (embedded) matrix.
    pub(crate) total_dim: usize,
    /// Lazily computed dense embedding, shared between clones through the
    /// `Arc`.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3082
3083impl LatentCoordDerivativeOp {
3084    pub(crate) fn materialize_local(&self) -> Array2<f64> {
3085        self.operator.materialize_axis(self.flat_axis).expect(
3086            "radial scalar evaluation failed during latent-coordinate derivative materialization",
3087        )
3088    }
3089
3090    pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3091        self.cached_dense.get_or_compute(|| {
3092            let local = self.materialize_local();
3093            let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3094            out.slice_mut(s![.., self.global_range.clone()])
3095                .assign(&local);
3096            out
3097        })
3098    }
3099
3100    pub(crate) fn nrows(&self) -> usize {
3101        self.operator.n_data()
3102    }
3103
3104    pub(crate) fn ncols(&self) -> usize {
3105        self.total_dim
3106    }
3107
3108    pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3109        let local = self
3110            .operator
3111            .transpose_mul_axis(self.flat_axis, &v.view())
3112            .expect(
3113                "radial scalar evaluation failed during latent-coordinate derivative transpose_mul",
3114            );
3115        let mut out = Array1::<f64>::zeros(self.total_dim);
3116        out.slice_mut(s![self.global_range.clone()]).assign(&local);
3117        out
3118    }
3119
3120    pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3121        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3122        self.operator
3123            .forward_mul_axis(self.flat_axis, &u_local.view())
3124            .expect(
3125                "radial scalar evaluation failed during latent-coordinate derivative forward_mul",
3126            )
3127    }
3128}
3129
3130impl ImplicitDerivativeOp {
3131    pub(crate) fn materialize_local(&self) -> Array2<f64> {
3132        match self.level {
3133            ImplicitDerivLevel::First(axis) => self.operator.materialize_first(axis).expect(
3134                "radial scalar evaluation failed during implicit derivative materialization",
3135            ),
3136            ImplicitDerivLevel::SecondDiag(axis) => {
3137                self.operator.materialize_second_diag(axis).expect(
3138                    "radial scalar evaluation failed during implicit derivative materialization",
3139                )
3140            }
3141            ImplicitDerivLevel::SecondCross(d, e) => {
3142                self.operator.materialize_second_cross(d, e).expect(
3143                    "radial scalar evaluation failed during implicit derivative materialization",
3144                )
3145            }
3146        }
3147    }
3148
3149    pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3150        self.cached_dense.get_or_compute(|| {
3151            let local = self.materialize_local();
3152            let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3153            out.slice_mut(s![.., self.global_range.clone()])
3154                .assign(&local);
3155            out
3156        })
3157    }
3158
3159    pub(crate) fn nrows(&self) -> usize {
3160        self.operator.n_data()
3161    }
3162
3163    pub(crate) fn ncols(&self) -> usize {
3164        self.total_dim
3165    }
3166
3167    pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3168        let local = match self.level {
3169            ImplicitDerivLevel::First(axis) => self
3170                .operator
3171                .transpose_mul(axis, &v.view())
3172                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3173            ImplicitDerivLevel::SecondDiag(axis) => self
3174                .operator
3175                .transpose_mul_second_diag(axis, &v.view())
3176                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3177            ImplicitDerivLevel::SecondCross(d, e) => self
3178                .operator
3179                .transpose_mul_second_cross(d, e, &v.view())
3180                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3181        };
3182        let mut out = Array1::<f64>::zeros(self.total_dim);
3183        out.slice_mut(s![self.global_range.clone()]).assign(&local);
3184        out
3185    }
3186
3187    pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3188        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3189        match self.level {
3190            ImplicitDerivLevel::First(axis) => self
3191                .operator
3192                .forward_mul(axis, &u_local.view())
3193                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3194            ImplicitDerivLevel::SecondDiag(axis) => self
3195                .operator
3196                .forward_mul_second_diag(axis, &u_local.view())
3197                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3198            ImplicitDerivLevel::SecondCross(d, e) => self
3199                .operator
3200                .forward_mul_second_cross(d, e, &u_local.view())
3201                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3202        }
3203    }
3204}
3205
3206#[derive(Clone)]
3207pub(crate) struct EmbeddedDerivativeMatrix {
3208    pub(crate) local: Array2<f64>,
3209    pub(crate) global_range: Range<usize>,
3210    pub(crate) total_dim: usize,
3211}
3212
3213impl EmbeddedDerivativeMatrix {
3214    pub(crate) fn new(local: Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
3215        Self {
3216            local,
3217            global_range,
3218            total_dim,
3219        }
3220    }
3221}
3222
3223impl DerivativeStorageBackend for Array2<f64> {
3224    fn resident_byte_count(&self) -> usize {
3225        self.len().saturating_mul(std::mem::size_of::<f64>())
3226    }
3227    fn design_nrows(&self) -> usize {
3228        Array2::nrows(self)
3229    }
3230    fn design_ncols(&self) -> usize {
3231        Array2::ncols(self)
3232    }
3233    fn penalty_dim(&self) -> usize {
3234        Array2::nrows(self)
3235    }
3236    fn uses_implicit_storage(&self) -> bool {
3237        false
3238    }
3239    fn any_nonzero(&self) -> bool {
3240        self.iter().any(|v| *v != 0.0)
3241    }
3242    fn materialize(&self) -> Array2<f64> {
3243        self.clone()
3244    }
3245    fn implicit_first_axis_info(
3246        &self,
3247    ) -> Option<(
3248        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3249        usize,
3250    )> {
3251        None
3252    }
3253    fn implicit_axis_count_hint(&self) -> Option<usize> {
3254        None
3255    }
3256
3257    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3258        if Array2::ncols(self) != u.len() {
3259            crate::bail_invalid_estim!(
3260                "dense hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3261                Array2::nrows(self),
3262                Array2::ncols(self),
3263                u.len()
3264            );
3265        }
3266        Ok(self.dot(u))
3267    }
3268
3269    fn design_transpose_mul_original(
3270        &self,
3271        v: &Array1<f64>,
3272    ) -> Result<Array1<f64>, EstimationError> {
3273        if Array2::nrows(self) != v.len() {
3274            crate::bail_invalid_estim!(
3275                "dense hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3276                Array2::nrows(self),
3277                Array2::ncols(self),
3278                v.len()
3279            );
3280        }
3281        Ok(self.t().dot(v))
3282    }
3283
3284    fn design_transformed(
3285        &self,
3286        qs: &Array2<f64>,
3287        free_basis_opt: Option<&Array2<f64>>,
3288    ) -> Result<Array2<f64>, EstimationError> {
3289        Ok(gam_linalg::matrix::DenseRightProductView::new(self)
3290            .with_factor(qs)
3291            .with_optional_factor(free_basis_opt)
3292            .materialize())
3293    }
3294
3295    fn penalty_transformed(
3296        &self,
3297        qs: &Array2<f64>,
3298        free_basis_opt: Option<&Array2<f64>>,
3299    ) -> Result<Array2<f64>, EstimationError> {
3300        let mut transformed = qs.t().dot(self).dot(qs);
3301        if let Some(z) = free_basis_opt {
3302            transformed = z.t().dot(&transformed).dot(z);
3303        }
3304        Ok(transformed)
3305    }
3306
3307    fn penalty_scaled_add_to(
3308        &self,
3309        target: &mut Array2<f64>,
3310        amp: f64,
3311    ) -> Result<(), EstimationError> {
3312        if target.raw_dim() != self.raw_dim() {
3313            crate::bail_invalid_estim!(
3314                "dense hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3315                target.nrows(),
3316                target.ncols(),
3317                Array2::nrows(self),
3318                Array2::ncols(self)
3319            );
3320        }
3321        target.scaled_add(amp, self);
3322        Ok(())
3323    }
3324}
3325
3326impl DerivativeStorageBackend for ZeroDerivativeMatrix {
3327    fn resident_byte_count(&self) -> usize {
3328        0
3329    }
3330    fn design_nrows(&self) -> usize {
3331        self.rows
3332    }
3333    fn design_ncols(&self) -> usize {
3334        self.cols
3335    }
3336    fn penalty_dim(&self) -> usize {
3337        self.cols
3338    }
3339    fn uses_implicit_storage(&self) -> bool {
3340        false
3341    }
3342    fn any_nonzero(&self) -> bool {
3343        false
3344    }
3345    fn materialize(&self) -> Array2<f64> {
3346        Array2::<f64>::zeros((self.rows, self.cols))
3347    }
3348    fn implicit_first_axis_info(
3349        &self,
3350    ) -> Option<(
3351        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3352        usize,
3353    )> {
3354        None
3355    }
3356    fn implicit_axis_count_hint(&self) -> Option<usize> {
3357        None
3358    }
3359
3360    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3361        if self.cols != u.len() {
3362            crate::bail_invalid_estim!(
3363                "zero hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3364                self.rows,
3365                self.cols,
3366                u.len()
3367            );
3368        }
3369        Ok(Array1::<f64>::zeros(self.rows))
3370    }
3371
3372    fn design_transpose_mul_original(
3373        &self,
3374        v: &Array1<f64>,
3375    ) -> Result<Array1<f64>, EstimationError> {
3376        if self.rows != v.len() {
3377            crate::bail_invalid_estim!(
3378                "zero hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3379                self.rows,
3380                self.cols,
3381                v.len()
3382            );
3383        }
3384        Ok(Array1::<f64>::zeros(self.cols))
3385    }
3386
3387    fn design_transformed(
3388        &self,
3389        qs: &Array2<f64>,
3390        free_basis_opt: Option<&Array2<f64>>,
3391    ) -> Result<Array2<f64>, EstimationError> {
3392        if self.cols != qs.nrows() {
3393            crate::bail_invalid_estim!(
3394                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3395                self.cols,
3396                qs.nrows()
3397            );
3398        }
3399        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3400        Ok(Array2::<f64>::zeros((self.rows, cols)))
3401    }
3402
3403    fn design_transformed_forward_mul(
3404        &self,
3405        qs: &Array2<f64>,
3406        free_basis_opt: Option<&Array2<f64>>,
3407        u: &Array1<f64>,
3408    ) -> Result<Array1<f64>, EstimationError> {
3409        if self.cols != qs.nrows() {
3410            crate::bail_invalid_estim!(
3411                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3412                self.cols,
3413                qs.nrows()
3414            );
3415        }
3416        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3417        if u.len() != cols {
3418            crate::bail_invalid_estim!(
3419                "zero design derivative transformed forward width mismatch: expected {}, vector={}",
3420                cols,
3421                u.len()
3422            );
3423        }
3424        Ok(Array1::<f64>::zeros(self.rows))
3425    }
3426
3427    fn design_transformed_transpose_mul(
3428        &self,
3429        qs: &Array2<f64>,
3430        free_basis_opt: Option<&Array2<f64>>,
3431        v: &Array1<f64>,
3432    ) -> Result<Array1<f64>, EstimationError> {
3433        if self.rows != v.len() {
3434            crate::bail_invalid_estim!(
3435                "zero design derivative transpose height mismatch: matrix rows={}, vector={}",
3436                self.rows,
3437                v.len()
3438            );
3439        }
3440        if self.cols != qs.nrows() {
3441            crate::bail_invalid_estim!(
3442                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3443                self.cols,
3444                qs.nrows()
3445            );
3446        }
3447        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3448        Ok(Array1::<f64>::zeros(cols))
3449    }
3450
3451    fn penalty_transformed(
3452        &self,
3453        qs: &Array2<f64>,
3454        free_basis_opt: Option<&Array2<f64>>,
3455    ) -> Result<Array2<f64>, EstimationError> {
3456        if self.cols != qs.nrows() {
3457            crate::bail_invalid_estim!(
3458                "zero penalty derivative width mismatch: total_dim={}, qs rows={}",
3459                self.cols,
3460                qs.nrows()
3461            );
3462        }
3463        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3464        Ok(Array2::<f64>::zeros((cols, cols)))
3465    }
3466
3467    fn penalty_scaled_add_to(
3468        &self,
3469        target: &mut Array2<f64>,
3470        amp: f64,
3471    ) -> Result<(), EstimationError> {
3472        // Zero penalty derivative: `amp · 0 = 0`, so `amp` scales nothing and
3473        // `target` is left unchanged. Validate it is finite so a bad scale
3474        // surfaces here rather than silently no-op'ing on a NaN/inf amplitude.
3475        if !amp.is_finite() {
3476            crate::bail_invalid_estim!(
3477                "zero hyper penalty derivative received non-finite amp={amp}"
3478            );
3479        }
3480        if target.nrows() != self.cols || target.ncols() != self.cols {
3481            crate::bail_invalid_estim!(
3482                "zero hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3483                target.nrows(),
3484                target.ncols(),
3485                self.cols,
3486                self.cols
3487            );
3488        }
3489        Ok(())
3490    }
3491}
3492
3493impl DerivativeStorageBackend for EmbeddedDerivativeMatrix {
3494    fn resident_byte_count(&self) -> usize {
3495        self.local.len().saturating_mul(std::mem::size_of::<f64>())
3496    }
3497    fn design_nrows(&self) -> usize {
3498        self.local.nrows()
3499    }
3500    fn design_ncols(&self) -> usize {
3501        self.total_dim
3502    }
3503    fn penalty_dim(&self) -> usize {
3504        self.total_dim
3505    }
3506    fn uses_implicit_storage(&self) -> bool {
3507        false
3508    }
3509    fn any_nonzero(&self) -> bool {
3510        self.local.iter().any(|v| *v != 0.0)
3511    }
3512    fn materialize(&self) -> Array2<f64> {
3513        let mut dense = Array2::<f64>::zeros((self.local.nrows(), self.total_dim));
3514        dense
3515            .slice_mut(s![.., self.global_range.clone()])
3516            .assign(&self.local);
3517        dense
3518    }
3519    fn implicit_first_axis_info(
3520        &self,
3521    ) -> Option<(
3522        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3523        usize,
3524    )> {
3525        None
3526    }
3527    fn implicit_axis_count_hint(&self) -> Option<usize> {
3528        None
3529    }
3530
3531    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3532        if self.total_dim != u.len() {
3533            crate::bail_invalid_estim!(
3534                "embedded hyper design derivative forward_mul_original width mismatch: total_dim={}, vector={}",
3535                self.total_dim,
3536                u.len()
3537            );
3538        }
3539        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3540        Ok(self.local.dot(&u_local))
3541    }
3542
3543    fn design_transpose_mul_original(
3544        &self,
3545        v: &Array1<f64>,
3546    ) -> Result<Array1<f64>, EstimationError> {
3547        if self.local.nrows() != v.len() {
3548            crate::bail_invalid_estim!(
3549                "embedded hyper design derivative transpose_mul_original height mismatch: local_rows={}, vector={}",
3550                self.local.nrows(),
3551                v.len()
3552            );
3553        }
3554        let mut out = Array1::<f64>::zeros(self.total_dim);
3555        let pulled = self.local.t().dot(v);
3556        out.slice_mut(s![self.global_range.clone()]).assign(&pulled);
3557        Ok(out)
3558    }
3559
3560    fn design_transformed(
3561        &self,
3562        qs: &Array2<f64>,
3563        free_basis_opt: Option<&Array2<f64>>,
3564    ) -> Result<Array2<f64>, EstimationError> {
3565        if self.total_dim != qs.nrows() {
3566            crate::bail_invalid_estim!(
3567                "embedded design derivative width mismatch: total_cols={}, qs rows={}",
3568                self.total_dim,
3569                qs.nrows()
3570            );
3571        }
3572        let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3573        let mut transformed = self.local.dot(&qs_local);
3574        if let Some(z) = free_basis_opt {
3575            transformed = transformed.dot(z);
3576        }
3577        Ok(transformed)
3578    }
3579
3580    fn penalty_transformed(
3581        &self,
3582        qs: &Array2<f64>,
3583        free_basis_opt: Option<&Array2<f64>>,
3584    ) -> Result<Array2<f64>, EstimationError> {
3585        if self.total_dim != qs.nrows() {
3586            crate::bail_invalid_estim!(
3587                "embedded penalty derivative width mismatch: total_dim={}, qs rows={}",
3588                self.total_dim,
3589                qs.nrows()
3590            );
3591        }
3592        let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3593        let mut transformed = qs_local.t().dot(&self.local).dot(&qs_local);
3594        if let Some(z) = free_basis_opt {
3595            transformed = z.t().dot(&transformed).dot(z);
3596        }
3597        Ok(transformed)
3598    }
3599
3600    fn penalty_scaled_add_to(
3601        &self,
3602        target: &mut Array2<f64>,
3603        amp: f64,
3604    ) -> Result<(), EstimationError> {
3605        if target.nrows() != self.total_dim || target.ncols() != self.total_dim {
3606            crate::bail_invalid_estim!(
3607                "embedded hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3608                target.nrows(),
3609                target.ncols(),
3610                self.total_dim,
3611                self.total_dim
3612            );
3613        }
3614        target
3615            .slice_mut(s![self.global_range.clone(), self.global_range.clone()])
3616            .scaled_add(amp, &self.local);
3617        Ok(())
3618    }
3619}
3620
3621impl DerivativeStorageBackend for ImplicitDerivativeOp {
3622    fn resident_byte_count(&self) -> usize {
3623        0
3624    }
3625    fn design_nrows(&self) -> usize {
3626        self.nrows()
3627    }
3628    fn design_ncols(&self) -> usize {
3629        self.ncols()
3630    }
3631    fn penalty_dim(&self) -> usize {
3632        self.nrows()
3633    }
3634    fn uses_implicit_storage(&self) -> bool {
3635        true
3636    }
3637    fn any_nonzero(&self) -> bool {
3638        true
3639    }
3640    fn materialize(&self) -> Array2<f64> {
3641        self.materialize_dense().clone()
3642    }
3643    fn implicit_first_axis_info(
3644        &self,
3645    ) -> Option<(
3646        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3647        usize,
3648    )> {
3649        match self.level {
3650            ImplicitDerivLevel::First(axis) => Some((self.operator.clone(), axis)),
3651            _ => None,
3652        }
3653    }
3654    fn implicit_axis_count_hint(&self) -> Option<usize> {
3655        Some(self.operator.n_axes())
3656    }
3657
3658    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3659        if self.ncols() != u.len() {
3660            crate::bail_invalid_estim!(
3661                "implicit hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3662                self.ncols(),
3663                u.len()
3664            );
3665        }
3666        Ok(self.forward_mul(u))
3667    }
3668
3669    fn design_transpose_mul_original(
3670        &self,
3671        v: &Array1<f64>,
3672    ) -> Result<Array1<f64>, EstimationError> {
3673        if self.nrows() != v.len() {
3674            crate::bail_invalid_estim!(
3675                "implicit hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3676                self.nrows(),
3677                v.len()
3678            );
3679        }
3680        Ok(self.transpose_mul(v))
3681    }
3682
3683    fn design_transformed(
3684        &self,
3685        qs: &Array2<f64>,
3686        free_basis_opt: Option<&Array2<f64>>,
3687    ) -> Result<Array2<f64>, EstimationError> {
3688        let dense = self.materialize_dense();
3689        Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3690            .with_factor(qs)
3691            .with_optional_factor(free_basis_opt)
3692            .materialize())
3693    }
3694
3695    fn design_transformed_forward_mul(
3696        &self,
3697        qs: &Array2<f64>,
3698        free_basis_opt: Option<&Array2<f64>>,
3699        u: &Array1<f64>,
3700    ) -> Result<Array1<f64>, EstimationError> {
3701        let mut right = if let Some(z) = free_basis_opt {
3702            z.dot(u)
3703        } else {
3704            u.clone()
3705        };
3706        right = qs.dot(&right);
3707        Ok(self.forward_mul(&right))
3708    }
3709
3710    fn design_transformed_transpose_mul(
3711        &self,
3712        qs: &Array2<f64>,
3713        free_basis_opt: Option<&Array2<f64>>,
3714        v: &Array1<f64>,
3715    ) -> Result<Array1<f64>, EstimationError> {
3716        let mut pulled = qs.t().dot(&self.transpose_mul(v));
3717        if let Some(z) = free_basis_opt {
3718            pulled = z.t().dot(&pulled);
3719        }
3720        Ok(pulled)
3721    }
3722
3723    fn penalty_transformed(
3724        &self,
3725        qs: &Array2<f64>,
3726        free_basis_opt: Option<&Array2<f64>>,
3727    ) -> Result<Array2<f64>, EstimationError> {
3728        let dense = self.materialize_dense();
3729        let mut transformed = qs.t().dot(dense).dot(qs);
3730        if let Some(z) = free_basis_opt {
3731            transformed = z.t().dot(&transformed).dot(z);
3732        }
3733        Ok(transformed)
3734    }
3735
3736    fn penalty_scaled_add_to(
3737        &self,
3738        target: &mut Array2<f64>,
3739        amp: f64,
3740    ) -> Result<(), EstimationError> {
3741        let dense = self.materialize_dense();
3742        if target.raw_dim() != dense.raw_dim() {
3743            crate::bail_invalid_estim!(
3744                "implicit hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3745                target.nrows(),
3746                target.ncols(),
3747                dense.nrows(),
3748                dense.ncols()
3749            );
3750        }
3751        target.scaled_add(amp, dense);
3752        Ok(())
3753    }
3754}
3755
3756impl DerivativeStorageBackend for LatentCoordDerivativeOp {
3757    fn resident_byte_count(&self) -> usize {
3758        0
3759    }
3760    fn design_nrows(&self) -> usize {
3761        self.nrows()
3762    }
3763    fn design_ncols(&self) -> usize {
3764        self.ncols()
3765    }
3766    fn penalty_dim(&self) -> usize {
3767        self.nrows()
3768    }
3769    fn uses_implicit_storage(&self) -> bool {
3770        true
3771    }
3772    fn any_nonzero(&self) -> bool {
3773        true
3774    }
3775    fn materialize(&self) -> Array2<f64> {
3776        self.materialize_dense().clone()
3777    }
3778    fn implicit_first_axis_info(
3779        &self,
3780    ) -> Option<(
3781        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3782        usize,
3783    )> {
3784        None
3785    }
3786    fn implicit_axis_count_hint(&self) -> Option<usize> {
3787        Some(self.operator.n_axes())
3788    }
3789
3790    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3791        if self.ncols() != u.len() {
3792            crate::bail_invalid_estim!(
3793                "latent-coordinate hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3794                self.ncols(),
3795                u.len()
3796            );
3797        }
3798        Ok(self.forward_mul(u))
3799    }
3800
3801    fn design_transpose_mul_original(
3802        &self,
3803        v: &Array1<f64>,
3804    ) -> Result<Array1<f64>, EstimationError> {
3805        if self.nrows() != v.len() {
3806            crate::bail_invalid_estim!(
3807                "latent-coordinate hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3808                self.nrows(),
3809                v.len()
3810            );
3811        }
3812        Ok(self.transpose_mul(v))
3813    }
3814
3815    fn design_transformed(
3816        &self,
3817        qs: &Array2<f64>,
3818        free_basis_opt: Option<&Array2<f64>>,
3819    ) -> Result<Array2<f64>, EstimationError> {
3820        let dense = self.materialize_dense();
3821        Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3822            .with_factor(qs)
3823            .with_optional_factor(free_basis_opt)
3824            .materialize())
3825    }
3826
3827    fn design_transformed_forward_mul(
3828        &self,
3829        qs: &Array2<f64>,
3830        free_basis_opt: Option<&Array2<f64>>,
3831        u: &Array1<f64>,
3832    ) -> Result<Array1<f64>, EstimationError> {
3833        let mut right = if let Some(z) = free_basis_opt {
3834            z.dot(u)
3835        } else {
3836            u.clone()
3837        };
3838        right = qs.dot(&right);
3839        Ok(self.forward_mul(&right))
3840    }
3841
3842    fn design_transformed_transpose_mul(
3843        &self,
3844        qs: &Array2<f64>,
3845        free_basis_opt: Option<&Array2<f64>>,
3846        v: &Array1<f64>,
3847    ) -> Result<Array1<f64>, EstimationError> {
3848        let mut pulled = qs.t().dot(&self.transpose_mul(v));
3849        if let Some(z) = free_basis_opt {
3850            pulled = z.t().dot(&pulled);
3851        }
3852        Ok(pulled)
3853    }
3854
3855    fn penalty_transformed(
3856        &self,
3857        qs: &Array2<f64>,
3858        free_basis_opt: Option<&Array2<f64>>,
3859    ) -> Result<Array2<f64>, EstimationError> {
3860        let dense = self.materialize_dense();
3861        let mut transformed = qs.t().dot(dense).dot(qs);
3862        if let Some(z) = free_basis_opt {
3863            transformed = z.t().dot(&transformed).dot(z);
3864        }
3865        Ok(transformed)
3866    }
3867
3868    fn penalty_scaled_add_to(
3869        &self,
3870        target: &mut Array2<f64>,
3871        amp: f64,
3872    ) -> Result<(), EstimationError> {
3873        let dense = self.materialize_dense();
3874        if target.raw_dim() != dense.raw_dim() {
3875            crate::bail_invalid_estim!(
3876                "latent-coordinate hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3877                target.nrows(),
3878                target.ncols(),
3879                dense.nrows(),
3880                dense.ncols()
3881            );
3882        }
3883        target.scaled_add(amp, dense);
3884        Ok(())
3885    }
3886}
3887
/// A design-matrix derivative with respect to a hyperparameter, held behind
/// one of the `DerivativeMatrixStorage` variants.
#[derive(Clone)]
pub struct HyperDesignDerivative {
    // Backing representation; every accessor dispatches on this variant.
    pub(crate) storage: DerivativeMatrixStorage,
}
3892
3893impl HyperDesignDerivative {
3894    pub fn zero(nrows: usize, ncols: usize) -> Self {
3895        Self {
3896            storage: DerivativeMatrixStorage::Zero(ZeroDerivativeMatrix::new(nrows, ncols)),
3897        }
3898    }
3899
3900    pub fn from_embedded(
3901        local: Array2<f64>,
3902        global_range: Range<usize>,
3903        total_cols: usize,
3904    ) -> Self {
3905        Self {
3906            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3907                local,
3908                global_range,
3909                total_cols,
3910            )),
3911        }
3912    }
3913
3914    pub fn from_implicit(
3915        operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3916        level: ImplicitDerivLevel,
3917        global_range: Range<usize>,
3918        total_cols: usize,
3919    ) -> Self {
3920        Self {
3921            storage: DerivativeMatrixStorage::Implicit(ImplicitDerivativeOp {
3922                operator,
3923                level,
3924                global_range,
3925                total_dim: total_cols,
3926                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3927            }),
3928        }
3929    }
3930
3931    pub fn from_latent_coord(
3932        operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
3933        flat_axis: usize,
3934        global_range: Range<usize>,
3935        total_cols: usize,
3936    ) -> Self {
3937        Self {
3938            storage: DerivativeMatrixStorage::LatentCoord(LatentCoordDerivativeOp {
3939                operator,
3940                flat_axis,
3941                global_range,
3942                total_dim: total_cols,
3943                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3944            }),
3945        }
3946    }
3947
3948    pub(crate) fn resident_byte_count(&self) -> usize {
3949        storage_dispatch!(&self.storage, b => b.resident_byte_count())
3950    }
3951
3952    pub(crate) fn nrows(&self) -> usize {
3953        storage_dispatch!(&self.storage, b => b.design_nrows())
3954    }
3955
3956    pub(crate) fn ncols(&self) -> usize {
3957        storage_dispatch!(&self.storage, b => b.design_ncols())
3958    }
3959
3960    pub(crate) fn uses_implicit_storage(&self) -> bool {
3961        storage_dispatch!(&self.storage, b => b.uses_implicit_storage())
3962    }
3963
3964    pub(crate) fn materialize(&self) -> Array2<f64> {
3965        storage_dispatch!(&self.storage, b => b.materialize())
3966    }
3967
3968    pub(crate) fn any_nonzero(&self) -> bool {
3969        storage_dispatch!(&self.storage, b => b.any_nonzero())
3970    }
3971
3972    pub(crate) fn forward_mul_original(
3973        &self,
3974        u: &Array1<f64>,
3975    ) -> Result<Array1<f64>, EstimationError> {
3976        storage_dispatch!(&self.storage, b => b.design_forward_mul_original(u))
3977    }
3978
3979    pub(crate) fn transpose_mul_original(
3980        &self,
3981        v: &Array1<f64>,
3982    ) -> Result<Array1<f64>, EstimationError> {
3983        storage_dispatch!(&self.storage, b => b.design_transpose_mul_original(v))
3984    }
3985
3986    pub(crate) fn transformed(
3987        &self,
3988        qs: &Array2<f64>,
3989        free_basis_opt: Option<&Array2<f64>>,
3990    ) -> Result<Array2<f64>, EstimationError> {
3991        storage_dispatch!(&self.storage, b => b.design_transformed(qs, free_basis_opt))
3992    }
3993
3994    pub(crate) fn transformed_forward_mul(
3995        &self,
3996        qs: &Array2<f64>,
3997        free_basis_opt: Option<&Array2<f64>>,
3998        u: &Array1<f64>,
3999    ) -> Result<Array1<f64>, EstimationError> {
4000        storage_dispatch!(&self.storage, b => b.design_transformed_forward_mul(qs, free_basis_opt, u))
4001    }
4002
4003    pub(crate) fn transformed_transpose_mul(
4004        &self,
4005        qs: &Array2<f64>,
4006        free_basis_opt: Option<&Array2<f64>>,
4007        v: &Array1<f64>,
4008    ) -> Result<Array1<f64>, EstimationError> {
4009        storage_dispatch!(&self.storage, b => b.design_transformed_transpose_mul(qs, free_basis_opt, v))
4010    }
4011
4012    /// If this derivative uses implicit storage at the first-derivative level,
4013    /// return the shared implicit operator and the axis index.
4014    ///
4015    /// Returns `None` for dense/embedded storage or for second-derivative levels.
4016    pub(crate) fn implicit_first_axis_info(
4017        &self,
4018    ) -> Option<(
4019        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4020        usize,
4021    )> {
4022        storage_dispatch!(&self.storage, b => b.implicit_first_axis_info())
4023    }
4024
4025    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4026        storage_dispatch!(&self.storage, b => b.implicit_axis_count_hint())
4027    }
4028}
4029
4030impl From<Array2<f64>> for HyperDesignDerivative {
4031    fn from(value: Array2<f64>) -> Self {
4032        Self {
4033            storage: DerivativeMatrixStorage::Dense(value),
4034        }
4035    }
4036}
4037
/// A penalty-matrix derivative with respect to a hyperparameter, held behind
/// one of the `DerivativeMatrixStorage` variants.
#[derive(Clone)]
pub struct HyperPenaltyDerivative {
    // Backing representation; every accessor dispatches on this variant.
    pub(crate) storage: DerivativeMatrixStorage,
}
4042
4043impl HyperPenaltyDerivative {
4044    pub fn from_embedded(
4045        local: Array2<f64>,
4046        global_range: Range<usize>,
4047        total_dim: usize,
4048    ) -> Self {
4049        Self {
4050            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
4051                local,
4052                global_range,
4053                total_dim,
4054            )),
4055        }
4056    }
4057
4058    pub(crate) fn resident_byte_count(&self) -> usize {
4059        storage_dispatch!(&self.storage, b => b.resident_byte_count())
4060    }
4061
4062    pub(crate) fn nrows(&self) -> usize {
4063        storage_dispatch!(&self.storage, b => b.penalty_dim())
4064    }
4065
4066    pub(crate) fn ncols(&self) -> usize {
4067        self.nrows()
4068    }
4069
4070    pub(crate) fn scaled_materialize(&self, amp: f64) -> Array2<f64> {
4071        let mut out = Array2::<f64>::zeros((self.nrows(), self.ncols()));
4072        self.scaled_add_to(&mut out, amp)
4073            .expect("scaled materialize uses matching target shape");
4074        out
4075    }
4076
4077    pub(crate) fn transformed(
4078        &self,
4079        qs: &Array2<f64>,
4080        free_basis_opt: Option<&Array2<f64>>,
4081    ) -> Result<Array2<f64>, EstimationError> {
4082        storage_dispatch!(&self.storage, b => b.penalty_transformed(qs, free_basis_opt))
4083    }
4084
4085    pub(crate) fn scaled_add_to(
4086        &self,
4087        target: &mut Array2<f64>,
4088        amp: f64,
4089    ) -> Result<(), EstimationError> {
4090        storage_dispatch!(&self.storage, b => b.penalty_scaled_add_to(target, amp))
4091    }
4092}
4093
4094impl From<Array2<f64>> for HyperPenaltyDerivative {
4095    fn from(value: Array2<f64>) -> Self {
4096        Self {
4097            storage: DerivativeMatrixStorage::Dense(value),
4098        }
4099    }
4100}
4101
/// One term of a penalty derivative: a derivative matrix paired with the
/// index of the penalty it belongs to.
#[derive(Clone)]
pub struct PenaltyDerivativeComponent {
    // Index of the penalty; `penalty_total_at` uses it to index into `rho`.
    pub penalty_index: usize,
    // Derivative matrix for that penalty.
    pub matrix: HyperPenaltyDerivative,
}
4107
/// Derivative data for a single hyperparameter direction: the design
/// derivative, the penalty derivative components, and optional pairwise
/// second derivatives against other directions.
#[derive(Clone)]
pub struct DirectionalHyperParam {
    // First design derivative for this direction, in original coordinates.
    pub(crate) x_tau_original: HyperDesignDerivative,
    // Canonical penalty representation: every tau direction is decomposed into
    // base-penalty derivatives. There is no separate "assembled total" path.
    pub(crate) penalty_first_components: Vec<PenaltyDerivativeComponent>,
    // Optional pairwise second hyper-derivatives against all tau directions.
    // If provided, each vector must have length psi_dim and hold an optional
    // X_{tau_i,tau_j} entry in original coordinates.
    pub(crate) x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
    // Pairwise second derivatives are stored in the same canonical base-penalty
    // decomposition as the first derivatives.
    pub(crate) penaltysecond_components: Option<Vec<Option<Vec<PenaltyDerivativeComponent>>>>,
    // Callback consulted by `penaltysecond_components_for(j)` when no stored
    // entry exists for partner index `j`.
    pub(crate) penaltysecond_component_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
                + Send
                + Sync
                + 'static,
        >,
    >,
    // Partner indices for which `has_penaltysecond_pair_at` reports a pair even
    // when no stored entry exists.
    pub(crate) penaltysecond_partner_indices: Option<std::sync::Arc<[usize]>>,
    /// Whether this coordinate is penalty-like (B_i = ∂H/∂τ_i is PSD).
    /// True for τ (penalty scaling) coordinates; false for ψ (design-moving,
    /// anisotropic length-scale) coordinates. Controls EFS eligibility.
    pub(crate) is_penalty_like: bool,
}
4135
4136impl DirectionalHyperParam {
4137    pub(crate) fn resident_byte_count(&self) -> usize {
4138        let mut bytes = self.x_tau_original.resident_byte_count();
4139        for component in &self.penalty_first_components {
4140            bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4141        }
4142        if let Some(entries) = self.x_tau_tau_original.as_ref() {
4143            for entry in entries.iter().flatten() {
4144                bytes = bytes.saturating_add(entry.resident_byte_count());
4145            }
4146        }
4147        if let Some(rows) = self.penaltysecond_components.as_ref() {
4148            for components in rows.iter().flatten() {
4149                for component in components {
4150                    bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4151                }
4152            }
4153        }
4154        bytes
4155    }
4156
4157    pub(crate) fn canonicalize_penalty_components(
4158        components: Vec<(usize, HyperPenaltyDerivative)>,
4159    ) -> Result<Vec<PenaltyDerivativeComponent>, EstimationError> {
4160        let mut out: Vec<PenaltyDerivativeComponent> = Vec::with_capacity(components.len());
4161        for (penalty_index, matrix) in components {
4162            if out.iter().any(|c| c.penalty_index == penalty_index) {
4163                crate::bail_invalid_estim!(
4164                    "duplicate penalty derivative component for penalty {}",
4165                    penalty_index
4166                );
4167            }
4168            out.push(PenaltyDerivativeComponent {
4169                penalty_index,
4170                matrix,
4171            });
4172        }
4173        Ok(out)
4174    }
4175
4176    pub fn new_compact(
4177        x_tau_original: HyperDesignDerivative,
4178        penalty_first_components: Vec<(usize, HyperPenaltyDerivative)>,
4179        x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
4180        penaltysecond_components: Option<Vec<Option<Vec<(usize, HyperPenaltyDerivative)>>>>,
4181    ) -> Result<Self, EstimationError> {
4182        let is_penalty_like = !x_tau_original.any_nonzero();
4183        let penalty_first_components =
4184            Self::canonicalize_penalty_components(penalty_first_components)?;
4185        let penaltysecond_components = match penaltysecond_components {
4186            Some(rows) => {
4187                let mut out = Vec::with_capacity(rows.len());
4188                for row in rows {
4189                    out.push(match row {
4190                        Some(components) => {
4191                            Some(Self::canonicalize_penalty_components(components)?)
4192                        }
4193                        None => None,
4194                    });
4195                }
4196                Some(out)
4197            }
4198            None => None,
4199        };
4200        Ok(Self {
4201            x_tau_original,
4202            penalty_first_components,
4203            x_tau_tau_original,
4204            penaltysecond_components,
4205            penaltysecond_component_provider: None,
4206            penaltysecond_partner_indices: None,
4207            is_penalty_like,
4208        })
4209    }
4210
4211    /// Mark this coordinate as non-penalty-like (design-moving).
4212    /// EFS will skip it; use Newton/BFGS for these coordinates.
4213    pub fn not_penalty_like(mut self) -> Self {
4214        self.is_penalty_like = false;
4215        self
4216    }
4217
4218    pub fn with_penaltysecond_component_provider(
4219        mut self,
4220        provider: std::sync::Arc<
4221            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
4222                + Send
4223                + Sync
4224                + 'static,
4225        >,
4226    ) -> Self {
4227        self.penaltysecond_component_provider = Some(provider);
4228        self
4229    }
4230
4231    pub fn with_penaltysecond_partner_indices(mut self, partners: Vec<usize>) -> Self {
4232        self.penaltysecond_partner_indices = Some(std::sync::Arc::from(partners));
4233        self
4234    }
4235
4236    pub(crate) fn x_tau_dense(&self) -> Array2<f64> {
4237        self.x_tau_original.materialize()
4238    }
4239
4240    pub(crate) fn transformed_x_tau(
4241        &self,
4242        qs: &Array2<f64>,
4243        free_basis_opt: Option<&Array2<f64>>,
4244    ) -> Result<Array2<f64>, EstimationError> {
4245        self.x_tau_original.transformed(qs, free_basis_opt)
4246    }
4247
4248    pub(crate) fn x_tau_tau_entry_at(&self, j: usize) -> Option<HyperDesignDerivative> {
4249        self.x_tau_tau_original
4250            .as_ref()
4251            .and_then(|rows| rows.get(j))
4252            .and_then(|entry| entry.clone())
4253    }
4254
4255    /// Whether this coordinate's design derivative uses implicit storage at the
4256    /// first-derivative level.
4257    pub(crate) fn has_implicit_operator(&self) -> bool {
4258        self.x_tau_original.uses_implicit_storage()
4259    }
4260
4261    pub(crate) fn has_implicit_multidim_duchon(&self) -> bool {
4262        self.implicit_first_axis_info()
4263            .is_some_and(|(op, _)| op.n_axes() > 1 && op.is_duchon_family())
4264    }
4265
4266    /// Extract the implicit design derivative operator and axis, if available.
4267    pub(crate) fn implicit_first_axis_info(
4268        &self,
4269    ) -> Option<(
4270        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4271        usize,
4272    )> {
4273        self.x_tau_original.implicit_first_axis_info()
4274    }
4275
4276    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4277        self.x_tau_original.implicit_axis_count_hint()
4278    }
4279
4280    pub(crate) fn penalty_first_components(&self) -> &[PenaltyDerivativeComponent] {
4281        &self.penalty_first_components
4282    }
4283
4284    pub(crate) fn penalty_total_at(
4285        &self,
4286        rho: &Array1<f64>,
4287        p: usize,
4288    ) -> Result<Array2<f64>, EstimationError> {
4289        let mut out = Array2::<f64>::zeros((p, p));
4290        for component in &self.penalty_first_components {
4291            if component.matrix.nrows() != p || component.matrix.ncols() != p {
4292                crate::bail_invalid_estim!(
4293                    "S_tau shape mismatch for penalty {}: expected {}x{}, got {}x{}",
4294                    component.penalty_index,
4295                    p,
4296                    p,
4297                    component.matrix.nrows(),
4298                    component.matrix.ncols()
4299                );
4300            }
4301            if component.penalty_index >= rho.len() {
4302                crate::bail_invalid_estim!(
4303                    "penalty_index {} out of bounds for rho dimension {}",
4304                    component.penalty_index,
4305                    rho.len()
4306                );
4307            }
4308            component
4309                .matrix
4310                .scaled_add_to(&mut out, rho[component.penalty_index].exp())?;
4311        }
4312        Ok(out)
4313    }
4314
4315    pub(crate) fn penaltysecond_components_for(
4316        &self,
4317        j: usize,
4318    ) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError> {
4319        if let Some(components) = self
4320            .penaltysecond_components
4321            .as_ref()
4322            .and_then(|rows| rows.get(j))
4323            .and_then(|row| row.clone())
4324        {
4325            return Ok(Some(components));
4326        }
4327        if let Some(provider) = self.penaltysecond_component_provider.as_ref() {
4328            return provider(j);
4329        }
4330        Ok(None)
4331    }
4332
4333    pub(crate) fn penaltysecond_componentrows(
4334        &self,
4335    ) -> Option<&[Option<Vec<PenaltyDerivativeComponent>>]> {
4336        self.penaltysecond_components.as_deref()
4337    }
4338
4339    pub(crate) fn penalty_first_component_count(&self) -> usize {
4340        self.penalty_first_components.len()
4341    }
4342
4343    pub(crate) fn has_penaltysecond_pair_at(&self, j: usize) -> bool {
4344        self.penaltysecond_components
4345            .as_ref()
4346            .and_then(|rows| rows.get(j))
4347            .is_some_and(Option::is_some)
4348            || self
4349                .penaltysecond_partner_indices
4350                .as_ref()
4351                .is_some_and(|partners| partners.contains(&j))
4352    }
4353}
4354
/// Outcome of choosing a REML geometry, together with the quantities recorded
/// alongside that choice.
// NOTE(review): the field descriptions below are inferred from the field names
// only; the producer is not visible in this part of the file — confirm.
#[derive(Clone, Debug)]
pub(crate) struct SparseRemlDecision {
    // Selected geometry.
    pub(crate) geometry: RemlGeometry,
    // Static label describing why `geometry` was selected.
    pub(crate) reason: &'static str,
    // Presumably the coefficient dimension.
    pub(crate) p: usize,
    // Presumably the nonzero count of the design matrix X.
    pub(crate) nnz_x: usize,
    // Presumably an estimated nonzero count for the upper triangle of H, when available.
    pub(crate) nnz_h_upper_est: Option<usize>,
    // Presumably the matching estimated density, when available.
    pub(crate) density_h_upper_est: Option<f64>,
}
4364
/// Bundle of shared results from a sparse exact evaluation.
// NOTE(review): the field descriptions below are inferred from names and types
// only; the producer is not visible in this part of the file — confirm.
#[derive(Clone)]
pub(crate) struct SparseExactEvalData {
    // Shared sparse exact factorization.
    pub(crate) factor: Arc<SparseExactFactor>,
    // Optional shared Takahashi inverse.
    pub(crate) takahashi: Option<Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    // Presumably log|H| for the factored system.
    pub(crate) logdet_h: f64,
    // Presumably the log-determinant of S restricted to its positive part.
    pub(crate) logdet_s_pos: f64,
    // Presumably the rank of the penalty.
    pub(crate) penalty_rank: usize,
    // Shared vector of first-derivative determinant values (exact meaning: TODO confirm).
    pub(crate) det1_values: Arc<Array1<f64>>,
}
4374
/// Dense Firth/Jeffreys operator: the design, its identifiable-subspace
/// reduction, and the cached reduced-space factors described below.
#[derive(Clone)]
pub struct FirthDenseOperator {
    // Exact Firth/Jeffreys objects on the identifiable subspace.
    //
    // Let X in R^{n×p} potentially be rank-deficient with rank r.
    // With optional fixed observation weights a_i >= 0 we define A = diag(a),
    // choose an orthonormal coefficient-space basis Q for the identifiable
    // subspace of A^{1/2} X, and set:
    //   X_r := A^{1/2} X Q          (A = I when no fixed observation weights),
    //   W   := diag(w), with w_i = mu_i (1 - mu_i), 0 < w_i <= 1/4 for finite logit eta,
    //   I_r := X_rᵀ W X_r,
    //   S_r := X_rᵀ X_r.
    //
    // Firth term is represented as:
    //   Phi(beta) = 0.5 log |I_r(beta)| - 0.5 log |S_r|,
    // which is exactly
    //   0.5 log |Uᵀ W U|
    // for the canonical orthonormalized identifiable design
    //   U = X_r S_r^{-1/2}.
    // This removes the raw-basis term from explicit reduced designs while
    // keeping the same identifiable-subspace hat matrix and beta derivatives,
    // because S_r is fixed with respect to beta.
    //
    // Mapping back to the full p-space uses:
    //   I_+^dagger = Q I_r^{-1} Qᵀ.
    //
    // We store reduced-space factors so all derivatives can be evaluated exactly
    // without materializing dense n×n matrices M = X K Xᵀ or P = M⊙M.
    pub(crate) x_dense: Array2<f64>,
    pub(crate) x_dense_t: Array2<f64>,
    // Orthonormal coefficient-space basis for the identifiable subspace,
    // built from the retained eigenspace of (A^{1/2} X)ᵀ(A^{1/2} X).
    pub(crate) q_basis: Array2<f64>,
    // Reduced identifiable design. With fixed observation weights a_i this is
    // diag(sqrt(a_i)) X Q; otherwise it is X Q.
    pub(crate) x_reduced: Array2<f64>,
    // Optional fixed case-weight square roots used when the Jeffreys/Firth
    // operator is formed from Xᵀ diag(case_weight ⊙ w(η)) X rather than
    // Xᵀ diag(w(η)) X. The exact directional tau derivatives must project and
    // row-scale with the same weights so the reduced Fisher, hat diagonals,
    // and tau kernels all live on one consistent identifiable subspace.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    // I_r^{-1}
    pub(crate) k_reduced: Array2<f64>,
    // diag(S_r^{-1}) with S_r = X_rᵀ X_r. In the current canonical reduced
    // basis this completely characterizes the metric inverse, because Q
    // diagonalizes the design Gram. It is used to remove the reduced-coordinate
    // basis term from Phi_tau when the design moves.
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    // 0.5 (log|I_r| - log|S_r|) at the current eta.
    pub(crate) half_log_det: f64,
    // h = diag(M), M = X_r K_r X_r'
    pub(crate) h_diag: Array1<f64>,
    // Logistic Fisher weight and its eta-derivatives, stored as n-vectors.
    // Presumably `w` is the weight itself and `w1`..`w4` are w', w'', w''',
    // w'''' respectively — TODO confirm against the builder.
    pub(crate) w: Array1<f64>,
    pub(crate) w1: Array1<f64>,
    pub(crate) w2: Array1<f64>,
    pub(crate) w3: Array1<f64>,
    pub(crate) w4: Array1<f64>,
    // B = diag(w') X used in D Hphi and D^2 Hphi contractions.
    pub(crate) b_base: Array2<f64>,
    // Cached invariant contraction P*B where P = (X_r K_r X_r') ⊙ (X_r K_r X_r').
    // This avoids recomputing the same O(n r^2 p) block in every directional call.
    pub(crate) p_b_base: Array2<f64>,
}
4440
/// β-independent (design-only) factor of the Firth/Jeffreys operator.
///
/// Everything stored here depends ONLY on the fixed design `X` and the fixed
/// prior/observation weights `a_i` — NOT on the current linear predictor `η`
/// (i.e. NOT on β). For a single inner PIRLS solve the design and prior weights
/// are constant while `η` changes every Newton iteration, so this factor can be
/// built once per solve and reused, hoisting the O(n·p²) Gram, the O(p³)
/// identifiable-subspace eigendecomposition, and the two n×p design clones out
/// of the per-iteration hot path (#1575).
///
/// The β-dependent remainder (Fisher weights `w(η)`, reduced Fisher
/// `I_r = X_rᵀ W X_r`, its inverse `K_r`, the hat diagonal `h`, and the
/// half-log-determinant) is rebuilt per iteration from this factor via
/// [`FirthDenseOperator::build_from_design_factor`] (full operator) or
/// [`FirthDenseOperator::pirls_diagnostics_from_factor`] (the three PIRLS
/// diagnostics only). Both reproduce the un-hoisted build bit-for-bit.
///
/// The fields `x_dense`, `x_dense_t`, `q_basis`, `x_reduced`,
/// `observation_weight_sqrt`, and `x_metric_reduced_inv_diag` share names and
/// types with the corresponding fields of [`FirthDenseOperator`].
#[derive(Clone)]
pub(crate) struct FirthDesignFactor {
    // Raw design and its transpose (the operator stores owned copies).
    pub(crate) x_dense: Array2<f64>,
    pub(crate) x_dense_t: Array2<f64>,
    // Orthonormal identifiable-subspace basis Q of (A^{1/2} X)ᵀ(A^{1/2} X).
    pub(crate) q_basis: Array2<f64>,
    // Reduced identifiable design X_r = A^{1/2} X Q.
    pub(crate) x_reduced: Array2<f64>,
    // Fixed case-weight square roots (sqrt(a_i)), if any.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    // Retained positive spectrum of the design Gram = S_r diagonal.
    pub(crate) metric_spectrum: Array1<f64>,
    // diag(S_r^{-1}); precomputed reciprocal of `metric_spectrum`.
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    // rank r = ncols(q_basis); n = nrows(x_dense).
    pub(crate) r: usize,
    pub(crate) n: usize,
}
4476
/// Quantities prepared for one direction `u` of the Firth operator.
// NOTE(review): the field meanings below (other than `b_uvec`) are inferred
// from names only; the code that fills them is not visible here — confirm.
#[derive(Clone)]
pub(crate) struct FirthDirection {
    // Presumably the linear-predictor perturbation δη_u.
    pub(crate) deta: Array1<f64>,
    // Reduced-coordinate matrices for the direction (exact definitions: TODO confirm).
    pub(crate) g_u_reduced: Array2<f64>,
    pub(crate) a_u_reduced: Array2<f64>,
    // Presumably the directional change of the hat diagonal.
    pub(crate) dh: Array1<f64>,
    // B_u = diag(w'' ⊙ δη_u) X is represented by the row-scaling vector only.
    pub(crate) b_uvec: Array1<f64>,
}
4486
/// Prepared reduced-coordinate drifts for one τ direction of the Firth
/// operator.
// NOTE(review): meanings of the uncommented fields are inferred from their
// names — confirm against the code that fills them.
#[derive(Clone)]
pub(crate) struct FirthTauPartialKernel {
    // Presumably the partial drift of η along τ.
    pub(super) deta_partial: Array1<f64>,
    // Presumably partial drifts of weight derivatives (exact definitions: TODO confirm).
    pub(crate) dotw1: Array1<f64>,
    pub(crate) dotw2: Array1<f64>,
    // Presumably the partial drift of the hat diagonal h.
    pub(crate) dot_h_partial: Array1<f64>,
    // Reduced design drift X_{tau,r} = X_tau Q used in exact design-moving
    // Hadamard-Gram contractions.
    pub(crate) x_tau_reduced: Array2<f64>,
    // Presumably dot(I_r), the drift of the reduced Fisher matrix used below.
    pub(super) dot_i_partial: Array2<f64>,
    // Reduced Fisher inverse drift:
    //   dot(K_r) = -K_r dot(I_r) K_r
    // where dot(I_r) includes explicit X_tau and weight drift at beta-fixed.
    pub(crate) dot_k_reduced: Array2<f64>,
}
4502
/// Single-direction (τ) exact Firth bundle: a scalar, a vector, and an
/// optional prepared partial kernel.
// NOTE(review): presumably the first-order counterparts of the pair-level
// `FirthTauTauExactKernel` fields (Phi_τ|β and (gphi)_τ|β) — confirm.
#[derive(Clone)]
pub(crate) struct FirthTauExactKernel {
    // Vector-valued τ derivative term.
    pub(crate) gphi_tau: Array1<f64>,
    // Scalar τ derivative term.
    pub(crate) phi_tau_partial: f64,
    // Prepared partial kernel for this direction, when one is carried along.
    pub(crate) tau_kernel: Option<FirthTauPartialKernel>,
}
4509
/// Pair-level (τ_i × τ_j) exact Firth bundle at fixed β.
///
/// Mirrors `FirthTauExactKernel` but for the 2nd-order cross
/// derivatives:
///   Phi_{τ_i τ_j}|β  (scalar, `phi_tau_tau_partial`)
///   (gphi)_{τ_i τ_j}|β (p-vector, `gphi_tau_tau`)
///
/// Carries an optional `tau_tau_kernel` so pair callbacks can chain
/// into Primitive A (`hphi_tau_tau_partial_apply`) for the operator-
/// valued Hessian 2nd drift without recomputing shared reduced Grams.
#[derive(Clone)]
pub(crate) struct FirthTauTauExactKernel {
    // Scalar cross derivative Phi_{τ_i τ_j}|β.
    pub(super) phi_tau_tau_partial: f64,
    // p-vector cross derivative (gphi)_{τ_i τ_j}|β.
    pub(super) gphi_tau_tau: Array1<f64>,
    // Optional prepared state for Primitive A.
    pub(super) tau_tau_kernel: Option<FirthTauTauPartialKernel>,
}
4527
/// Prepared state for `∂²H_φ/∂τ_i ∂τ_j |_β` (Primitive A).
///
/// Carries both τ-direction reduced designs, their η̇ vectors, and the
/// reduced-coordinate drifts (İ, K̇, ḣ) for i and j so the apply step can
/// form M̈_{ij}, K̈_{ij}, ḧ_{ij}, Γ̈_{ij}, and B̈_{ij} matrix-free.  Fields
/// are filled in by 13b; kept with a neutral internal shape so downstream
/// pair callbacks can hold the kernel across the pair dispatch.
///
/// Wired into the pair-callback's `b_operator` via
/// `FirthAugmentedPairHyperOperator`, and produced by both
/// `hphi_tau_tau_partial_prepare_from_partials` and
/// `exact_tau_tau_kernel` (the scalar/p-vector companion).
#[derive(Clone, Default)]
pub(crate) struct FirthTauTauPartialKernel {
    // Reduced designs for the two τ directions.
    pub(super) x_tau_i_reduced: Array2<f64>,
    pub(super) x_tau_j_reduced: Array2<f64>,
    // η̇ vectors for the two τ directions.
    pub(super) deta_i_partial: Array1<f64>,
    pub(super) deta_j_partial: Array1<f64>,
    // ḣ drifts for i and j.
    pub(super) dot_h_i_partial: Array1<f64>,
    pub(super) dot_h_j_partial: Array1<f64>,
    // K̇ drifts for i and j.
    pub(super) dot_k_i_reduced: Array2<f64>,
    pub(super) dot_k_j_reduced: Array2<f64>,
    // İ drifts for i and j.
    pub(super) dot_i_i_partial: Array2<f64>,
    pub(super) dot_i_j_partial: Array2<f64>,
    // Optional second-order (pair) terms; presumably `None` when the pair has
    // no second design derivative — TODO confirm.
    pub(super) x_tau_tau_reduced: Option<Array2<f64>>,
    pub(super) deta_ij_partial: Option<Array1<f64>>,
}
4555
4556/// Prepared state for `D_β((H_φ)_τ|_β)[v]` (Primitive B).
4557///
4558/// Carries the τ-kernel pieces (x_tau_reduced, İ, K̇, ḣ), the
4559/// β-direction quantities (δη_v, A_v, dh_v, b-chain), and the mixed
4560/// β-τ pieces (D_β(K̇_τ)[v], D_β(ḣ_τ)[v], δη_{τ,v}) so the apply
4561/// step collapses to the 9-term β-τ expansion without recomputing
4562/// shared reduced Grams.
4563#[derive(Clone, Default)]
4564pub(crate) struct FirthTauBetaPartialKernel {
4565    pub(super) x_tau_reduced: Array2<f64>,
4566    pub(super) deta_partial: Array1<f64>,
4567    pub(super) dot_h_partial: Array1<f64>,
4568    pub(super) dot_i_partial: Array2<f64>,
4569    pub(super) dot_k_reduced: Array2<f64>,
4570    pub(super) deta_v: Array1<f64>,
4571    pub(super) deta_tau_v: Array1<f64>,
4572    pub(super) a_v_reduced: Array2<f64>,
4573    pub(super) dh_v: Array1<f64>,
4574    pub(super) b_vvec: Array1<f64>,
4575    pub(super) d_beta_dot_k: Array2<f64>,
4576    pub(super) d_beta_dot_h: Array1<f64>,
4577}
4578
4579/// Holds the state for the outer REML optimization and supplies cost and
4580/// gradient evaluations to the `opt` optimizer.
4581///
4582/// The `cache` field uses `RefCell` to enable interior mutability. This is a crucial
4583/// performance optimization. The `cost_andgrad` closure required by the BFGS
4584/// optimizer takes an immutable reference `&self`. However, we want to cache the
4585/// results of the expensive P-IRLS computation to avoid re-calculating the fit
4586/// for the same `rho` vector, which can happen during the line search.
4587/// `RefCell` allows us to mutate the cache through a `&self` reference,
4588/// making this optimization possible while adhering to the optimizer's API.
4589#[derive(Clone)]
4590pub(crate) struct EvalShared {
4591    pub(crate) key: Option<Vec<u64>>,
4592    pub(crate) pirls_result: Arc<PirlsResult>,
4593    pub(crate) ridge_passport: RidgePassport,
4594    pub(crate) geometry: RemlGeometry,
4595    /// The exact H_total matrix used for LAML cost computation.
4596    /// For Firth: effective Hessian minus hphi (plus any barrier curvature).
4597    /// For non-Firth: the effective Hessian itself (plus any barrier curvature).
4598    pub(crate) h_total: Arc<Array2<f64>>,
4599    pub(crate) sparse_exact: Option<Arc<SparseExactEvalData>>,
4600    pub(crate) firth_dense_operator: Option<Arc<FirthDenseOperator>>,
4601    /// Cached FirthDenseOperator built from the original (non-reparameterized)
4602    /// design matrix, for use by the sparse evaluation path.
4603    pub(crate) firth_dense_operator_original: Option<Arc<FirthDenseOperator>>,
4604    /// The ONE original-frame penalty pseudo-logdet factorization for this
4605    /// evaluation point (#931 atom discipline). `log|Σ λ_k S_k|₊`'s VALUE,
4606    /// ρ-derivatives, τ/ψ components, and ρ×τ cross blocks are all
4607    /// contractions of this single eigendecomposition; the ρ-side criterion
4608    /// assembly (`dense_penalty_logdet_derivs`, the sparse det2 path) and the
4609    /// original-basis hyper-coordinate builders share it through
4610    /// [`EvalShared::penalty_pseudologdet_original`]. Building a second
4611    /// factorization of the same Sλ for the same evaluation point is the
4612    /// objective↔gradient desync surface (#748/#752/#901) this cell removes:
4613    /// the ridge and positive-eigenspace threshold are decided exactly once.
4614    /// (The transformed-frame pair-callback path builds its own object — it
4615    /// factorizes the canonical-TRANSFORMED, possibly constraint-projected
4616    /// penalties, a genuinely different matrix, not a duplicate of this one.)
4617    pub(crate) penalty_pseudologdet: std::sync::OnceLock<Arc<penalty_logdet::PenaltyPseudologdet>>,
4618    /// Per-evaluation-point cache of the canonical penalty score vectors
4619    /// `S_k β̂` evaluated at this bundle's inner mode `β̂ =
4620    /// pirls_result.beta_transformed` (unscaled by λ_k). These depend ONLY
4621    /// on the inner solution carried by this bundle and the `RemlState`'s
4622    /// fixed `canonical_penalties` — never on which ρ-coordinate or eval
4623    /// mode the assembly is running — so they are computed exactly once per
4624    /// inner solution and shared by every assemble call that reuses the
4625    /// bundle (cost + gradient evaluations at the same ρ, EFS, synthetic-ext
4626    /// value probes). Exact hoist, not an approximation: every consumer sees
4627    /// literally the same vectors it previously recomputed. Initialized via
4628    /// plain ndarray matvecs (no rayon inside the `OnceLock` closure — the
4629    /// `get_or_init`+`into_par_iter` deadlock trap does not apply).
4630    pub(crate) penalty_scores_at_mode: std::sync::OnceLock<Arc<Vec<Array1<f64>>>>,
4631    /// Per-evaluation-point cache of the #784 block-local Laplace-to-sampling
4632    /// correction `TkCorrectionTerms { value, gradient }`. The correction is a
4633    /// deterministic function of ONLY this bundle's converged inner state
4634    /// (`pirls_result`, `h_total`), the `RemlState`'s fixed
4635    /// `canonical_penalties`, and the bundle's ρ — never of the eval `mode`:
4636    /// the diagnostic eigendecomposition, the fixed-seed importance sampler,
4637    /// and the (b)–(d) gradient channels all read mode-invariant fields, and
4638    /// the term carries no Hessian, so the value+gradient are identical for the
4639    /// value-only, value+gradient, and value+gradient+Hessian assemble calls
4640    /// that share this bundle at a single ρ. The expensive path (eigendecomp +
4641    /// O(draws·n·m) sampler) previously reran on every one of those 2–3 calls
4642    /// per outer iteration; hoisting it onto the bundle computes it exactly
4643    /// once per inner solution (exact hoist, identical values — #784, #1082).
4644    /// Keyed only on the external-coordinate count `n_ext`: with no ψ
4645    /// coordinates (`n_ext == 0`) the correction engages; with ψ present the
4646    /// seam declines (returns the cheap zero), and n_ext is fixed for a fit, so
4647    /// a single cell suffices.
4648    pub(crate) block_local_correction:
4649        std::sync::OnceLock<(usize, Arc<outer_eval::TkCorrectionTerms>)>,
4650}
4651
4652impl EvalShared {
4653    pub(crate) fn matches(&self, key: &Option<Vec<u64>>) -> bool {
4654        match (&self.key, key) {
4655            (None, None) => true,
4656            (Some(a), Some(b)) => a == b,
4657            _ => false,
4658        }
4659    }
4660
4661    /// Lazily build — once per evaluation point — the original-frame
4662    /// [`PenaltyPseudologdet`](penalty_logdet::PenaltyPseudologdet) of
4663    /// `Σ λ_k S_k` and hand every caller the SAME factorization.
4664    ///
4665    /// This is the #931 port of the penalty-logdet term: value, ρ-first /
4666    /// ρ-second derivatives, τ-gradient components, τ×τ and ρ×τ Hessian
4667    /// blocks are all projections of one eigendecomposition, so no pair of
4668    /// consumers can disagree about the ridge or the positive-eigenspace
4669    /// threshold. The ridge is read from this bundle's `ridge_passport` —
4670    /// the single place that convention is decided.
4671    ///
4672    /// `lambdas` must be the λ = exp(ρ) vector of this bundle's evaluation
4673    /// point and `p` the original-basis coefficient dimension; on a cache
4674    /// hit both are checked against the stored object where representable.
4675    pub(crate) fn penalty_pseudologdet_original(
4676        &self,
4677        canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
4678        lambdas: &[f64],
4679        p: usize,
4680    ) -> Result<Arc<penalty_logdet::PenaltyPseudologdet>, EstimationError> {
4681        if let Some(pld) = self.penalty_pseudologdet.get() {
4682            if pld.dim() != p {
4683                return Err(EstimationError::LayoutError(format!(
4684                    "shared penalty pseudo-logdet frame mismatch: cached p={}, requested p={}",
4685                    pld.dim(),
4686                    p
4687                )));
4688            }
4689            return Ok(Arc::clone(pld));
4690        }
4691        let pld = Arc::new(
4692            penalty_logdet::PenaltyPseudologdet::from_penalties(
4693                canonical_penalties,
4694                lambdas,
4695                self.ridge_passport.penalty_logdet_ridge(),
4696                p,
4697            )
4698            .map_err(EstimationError::InvalidInput)?,
4699        );
4700        match self.penalty_pseudologdet.set(Arc::clone(&pld)) {
4701            Ok(()) => Ok(pld),
4702            // A concurrent caller initialized the cell first; both objects
4703            // were built from identical inputs — return the canonical winner
4704            // so every consumer holds literally the same factorization.
4705            Err(_) => Ok(Arc::clone(
4706                self.penalty_pseudologdet
4707                    .get()
4708                    .expect("OnceLock set raced, so it is initialized"),
4709            )),
4710        }
4711    }
4712}
4713
4714impl PenalizedGeometry for EvalShared {
4715    fn backend_kind(&self) -> GeometryBackendKind {
4716        match self.geometry {
4717            RemlGeometry::DenseSpectral => GeometryBackendKind::DenseSpectral,
4718            RemlGeometry::SparseExactSpd => GeometryBackendKind::SparseExactSpd,
4719        }
4720    }
4721}
4722
4723/// LRU cache keyed by sanitized ρ vectors that holds compacted PIRLS results
4724/// for warm-starting outer line searches and revisited evaluations.
4725///
4726/// Eviction is byte-budgeted rather than entry-count-budgeted: each entry
4727/// records its own estimated footprint (the surviving n-length vectors plus
4728/// the two p×p Hessians plus per-entry overhead) and the cache evicts in
4729/// LRU order until the running total fits under the budget. An entry that
4730/// individually exceeds the budget is rejected silently rather than poisoning
4731/// the cache.
4732pub(crate) struct PirlsLruCache {
4733    // Stored tuple: (compacted result, last-touched clock, estimated bytes).
4734    pub(crate) map: HashMap<Vec<u64>, (Arc<PirlsResult>, u64, usize)>,
4735    pub(crate) byte_budget: usize,
4736    pub(crate) current_bytes: usize,
4737    pub(crate) clock: u64,
4738}
4739
4740impl PirlsLruCache {
4741    pub(crate) fn new(byte_budget: usize) -> Self {
4742        Self {
4743            map: HashMap::new(),
4744            byte_budget: byte_budget.max(1),
4745            current_bytes: 0,
4746            clock: 0,
4747        }
4748    }
4749
4750    pub(crate) fn get(&mut self, key: &Vec<u64>) -> Option<Arc<PirlsResult>> {
4751        if let Some(entry) = self.map.get_mut(key) {
4752            self.clock += 1;
4753            entry.1 = self.clock;
4754            Some(entry.0.clone())
4755        } else {
4756            None
4757        }
4758    }
4759
4760    pub(crate) fn insert(&mut self, key: Vec<u64>, value: Arc<PirlsResult>) {
4761        self.clock += 1;
4762        let bytes = pirls_result_cache_bytes(&value);
4763        // Refuse entries that on their own already exceed the entire budget;
4764        // caching one would force eviction of every other entry without
4765        // leaving room for the new one anyway.
4766        if bytes > self.byte_budget {
4767            if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4768                self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4769            }
4770            return;
4771        }
4772        if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4773            self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4774        }
4775        while self.current_bytes + bytes > self.byte_budget {
4776            let evict_key = self
4777                .map
4778                .iter()
4779                .min_by_key(|(_, (_, ts, _))| *ts)
4780                .map(|(k, _)| k.clone());
4781            match evict_key {
4782                Some(k) => {
4783                    if let Some((_, _, evict_bytes)) = self.map.remove(&k) {
4784                        self.current_bytes = self.current_bytes.saturating_sub(evict_bytes);
4785                    }
4786                }
4787                None => break,
4788            }
4789        }
4790        self.current_bytes += bytes;
4791        self.map.insert(key, (value, self.clock, bytes));
4792    }
4793
4794    pub(crate) fn clear(&mut self) {
4795        self.map.clear();
4796        self.current_bytes = 0;
4797    }
4798}
4799
/// Identity of a cached penalty subspace: one 64-bit fingerprint of the
/// penalty matrix plus one 64-bit signature of the ridge passport. Two keys
/// are equal only when both components are equal.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct PenaltySubspaceCacheKey {
    pub(crate) penalty_matrix_fingerprint: u64,
    pub(crate) ridge_passport_signature: u64,
}
4805
4806pub(crate) struct PenaltySubspaceCache {
4807    pub(crate) entry: Option<(PenaltySubspaceCacheKey, Arc<outer_eval::PenaltySubspace>)>,
4808}
4809
4810impl PenaltySubspaceCache {
4811    pub(crate) fn new() -> Self {
4812        Self { entry: None }
4813    }
4814
4815    pub(crate) fn get(
4816        &self,
4817        key: &PenaltySubspaceCacheKey,
4818    ) -> Option<Arc<outer_eval::PenaltySubspace>> {
4819        self.entry
4820            .as_ref()
4821            .filter(|(cached_key, _)| cached_key == key)
4822            .map(|(_, value)| value.clone())
4823    }
4824
4825    pub(crate) fn insert(
4826        &mut self,
4827        key: PenaltySubspaceCacheKey,
4828        value: Arc<outer_eval::PenaltySubspace>,
4829    ) {
4830        self.entry = Some((key, value));
4831    }
4832
4833    pub(crate) fn clear(&mut self) {
4834        self.entry = None;
4835    }
4836}
4837
4838impl PenaltySubspaceCacheKey {
4839    /// Build a cache key from the transformed-E matrix and ridge passport.
4840    /// `E` is hashed by exact f64 bits (column-major), so the key is bit-exact
4841    /// and avoids float-Hash issues; the ridge passport is hashed via its
4842    /// `Hash` impl. Two calls at the same `(E, ridge)` yield equal keys.
4843    pub(crate) fn from_inputs(
4844        e_transformed: &ndarray::Array2<f64>,
4845        ridge_passport: &gam_problem::RidgePassport,
4846    ) -> Self {
4847        use std::collections::hash_map::DefaultHasher;
4848        use std::hash::{Hash, Hasher};
4849        let mut hasher = DefaultHasher::new();
4850        e_transformed.nrows().hash(&mut hasher);
4851        e_transformed.ncols().hash(&mut hasher);
4852        for value in e_transformed.iter() {
4853            value.to_bits().hash(&mut hasher);
4854        }
4855        let penalty_matrix_fingerprint = hasher.finish();
4856        let mut ridge_hasher = DefaultHasher::new();
4857        ridge_passport.delta.to_bits().hash(&mut ridge_hasher);
4858        (ridge_passport.matrix_form as u8).hash(&mut ridge_hasher);
4859        ridge_passport
4860            .policy
4861            .include_penalty_logdet
4862            .hash(&mut ridge_hasher);
4863        ridge_passport
4864            .policy
4865            .include_laplacehessian
4866            .hash(&mut ridge_hasher);
4867        let ridge_passport_signature = ridge_hasher.finish();
4868        Self {
4869            penalty_matrix_fingerprint,
4870            ridge_passport_signature,
4871        }
4872    }
4873}
4874
4875/// Estimate the in-cache footprint of a (compacted) PIRLS result.
4876///
4877/// Mirrors what `compact_for_reml_cache` keeps:
4878/// * six surviving n-length f64 arrays (final_eta, solveweights,
4879///   solveworking_response, solvemu, solve_c_array, solve_d_array);
4880/// * the p-length coefficient vector;
4881/// * the two p×p Hessians (dense or CSC sparse);
4882/// * the `ReparamResult` payload — the dominant scaling term beyond n, since
4883///   it carries `s_transformed`, `qs`, and `e_transformed` as p×p / rank×p
4884///   matrices.
4885/// A small constant overhead absorbs scalar fields, enum discriminants, and
4886/// the HashMap entry. This errs on the conservative side: overestimation
4887/// causes earlier eviction, never under-counting that would let the cache
4888/// silently exceed the byte budget.
4889pub(crate) fn pirls_result_cache_bytes(result: &PirlsResult) -> usize {
4890    use std::mem::size_of;
4891    let n_array_elems = result.final_eta.len()
4892        + result.solveweights.len()
4893        + result.solveworking_response.len()
4894        + result.solvemu.len()
4895        + result.solve_c_array.len()
4896        + result.solve_d_array.len();
4897    let p = result.beta_transformed.0.len();
4898    let pen_h = symmetric_matrix_cache_bytes(&result.penalized_hessian_transformed);
4899    let stab_h = symmetric_matrix_cache_bytes(&result.stabilizedhessian_transformed);
4900    let reparam = (result.reparam_result.s_transformed.len()
4901        + result.reparam_result.qs.len()
4902        + result.reparam_result.e_transformed.len()
4903        + result.reparam_result.det1.len())
4904        * size_of::<f64>();
4905    n_array_elems * size_of::<f64>() + p * size_of::<f64>() + pen_h + stab_h + reparam + 1024
4906}
4907
4908pub(crate) fn symmetric_matrix_cache_bytes(m: &gam_linalg::matrix::SymmetricMatrix) -> usize {
4909    use gam_linalg::matrix::SymmetricMatrix;
4910    use std::mem::size_of;
4911    match m {
4912        SymmetricMatrix::Dense(a) => a.len() * size_of::<f64>(),
4913        SymmetricMatrix::Sparse(s) => {
4914            // CSC sparse: f64 values + usize row indices + usize column pointers.
4915            let (symbolic, values) = s.parts();
4916            values.len() * (size_of::<f64>() + size_of::<usize>())
4917                + std::mem::size_of_val(symbolic.col_ptr())
4918        }
4919    }
4920}
4921
/// Number of distinct rho-points the outer-eval reuse LRU can hold.
///
/// Chosen to cover the local revisit window of a binomial seed grid
/// (baseline + isotropic shifts + per-axis refinements) with room for a few
/// line-search trial points, while keeping growth bounded. A slot stores a
/// single `OuterEval` — a scalar cost, a length-k gradient, an optional
/// inner-beta hint, and a usually-`Unavailable` Hessian — so the memory cost
/// is tiny.
pub(crate) const OUTER_EVAL_LRU_CAPACITY: usize = 8;
4930
4931/// Bounded least-recently-used cache of converged outer REML evaluations,
4932/// keyed by sanitized rho-bits.
4933///
4934/// CORRECTNESS: the key (`Vec<u64>` of `f64::to_bits`, with ±0 canonicalized)
4935/// is the complete result-determining input for a fixed `RemlState`. Every
4936/// other input to `OuterEval` — design matrix, prior weights, offset, penalty
4937/// structure, link/SAS/mixture state, Firth/Jeffreys configuration, and the
4938/// rho-prior — is immutable for the lifetime of the state that owns the cache,
4939/// so the stored cost / gradient / inner-beta hint depend only on rho. A hit
4940/// therefore returns exactly the value a recompute at that rho would converge
4941/// to (to the solver's own tolerance, identical to the trust the pre-existing
4942/// single-slot cache already placed in rho-only keying). Distinct rho-points
4943/// never alias: lookups compare the full key vector.
4944pub(crate) struct OuterEvalLru {
4945    capacity: usize,
4946    /// Front = least-recently-used, back = most-recently-used.
4947    entries: std::collections::VecDeque<(Vec<u64>, OuterEval)>,
4948}
4949
4950impl OuterEvalLru {
4951    pub(crate) fn new(capacity: usize) -> Self {
4952        Self {
4953            capacity: capacity.max(1),
4954            entries: std::collections::VecDeque::new(),
4955        }
4956    }
4957
4958    /// Returns a clone of the eval stored under `key`, if present, promoting it
4959    /// to most-recently-used. A miss returns `None` so the caller recomputes —
4960    /// never a stale value from a different key.
4961    pub(crate) fn get(&mut self, key: &[u64]) -> Option<OuterEval> {
4962        let pos = self
4963            .entries
4964            .iter()
4965            .position(|(k, _)| k.as_slice() == key)?;
4966        let entry = self.entries.remove(pos)?;
4967        let eval = entry.1.clone();
4968        self.entries.push_back(entry);
4969        Some(eval)
4970    }
4971
4972    /// Inserts (or refreshes) the eval for `key` as most-recently-used,
4973    /// evicting the least-recently-used entry once capacity is exceeded.
4974    pub(crate) fn insert(&mut self, key: Vec<u64>, eval: OuterEval) {
4975        if let Some(pos) = self
4976            .entries
4977            .iter()
4978            .position(|(k, _)| k.as_slice() == key.as_slice())
4979        {
4980            self.entries.remove(pos);
4981        }
4982        self.entries.push_back((key, eval));
4983        while self.entries.len() > self.capacity {
4984            self.entries.pop_front();
4985        }
4986    }
4987
4988    pub(crate) fn clear(&mut self) {
4989        self.entries.clear();
4990    }
4991}
4992
4993/// Centralized cache/memoization owner for REML evaluations.
4994///
4995/// This keeps cache-key identity, bundle reuse, and invalidation policy out of
4996/// the math kernels so objective/derivative routines can stay algebra-focused.
4997pub(crate) struct EvalCacheManager {
4998    pub(crate) pirls_cache: RwLock<PirlsLruCache>,
4999    pub(crate) penalty_subspace_cache: RwLock<PenaltySubspaceCache>,
5000    pub(crate) current_eval_bundle: RwLock<Option<EvalShared>>,
5001    /// Most-recently-*stored* outer eval (single slot). Retained verbatim so
5002    /// `previous_outer_gradient_norm` keeps its exact "immediately previous
5003    /// distinct eval" semantics, independent of the multi-slot reuse cache.
5004    pub(crate) current_outer_eval: RwLock<Option<(Vec<u64>, OuterEval)>>,
5005    /// Bounded multi-slot LRU of converged outer evaluations keyed by the
5006    /// sanitized rho-bits (#1575).
5007    ///
5008    /// For a frozen `RemlState` (fixed design, prior weights, offset, penalty
5009    /// structure, link state, Firth/Jeffreys configuration, and rho-prior — all
5010    /// of which are immutable for the lifetime of the state that owns this
5011    /// manager and therefore the lifetime of the cache), the outer objective
5012    /// value, its gradient, and the inner-beta hint are deterministic functions
5013    /// of rho alone. The sanitized rho-bits are thus the complete result key.
5014    /// The binomial REML fit performs ~20-32 seed-grid pre-solves plus
5015    /// line-search revisits; with only the single `current_outer_eval` slot,
5016    /// any revisit to an earlier rho re-ran a full n-sized P-IRLS. This LRU
5017    /// returns the stored cost/gradient for those revisited rho-points.
5018    pub(crate) outer_eval_lru: RwLock<OuterEvalLru>,
5019    pub(crate) pirls_cache_enabled: AtomicBool,
5020}
5021
5022impl EvalCacheManager {
5023    pub(crate) fn new() -> Self {
5024        Self {
5025            pirls_cache: RwLock::new(PirlsLruCache::new(PIRLS_CACHE_BYTE_BUDGET)),
5026            penalty_subspace_cache: RwLock::new(PenaltySubspaceCache::new()),
5027            current_eval_bundle: RwLock::new(None),
5028            current_outer_eval: RwLock::new(None),
5029            outer_eval_lru: RwLock::new(OuterEvalLru::new(OUTER_EVAL_LRU_CAPACITY)),
5030            pirls_cache_enabled: AtomicBool::new(true),
5031        }
5032    }
5033
5034    /// Creates a sanitized cache key from rho values.
5035    /// Returns None if any component is NaN, in which case caching is skipped.
5036    /// Maps -0.0 to 0.0 to ensure key stability.
5037    pub(crate) fn sanitized_rhokey(rho: &Array1<f64>) -> Option<Vec<u64>> {
5038        self::rho_key::sanitized_rhokey(rho)
5039    }
5040
5041    /// Memoizing wrapper for `PenaltySubspace` construction.
5042    ///
5043    /// The penalty-subspace eigendecomposition is shape-invariant: any two
5044    /// outer evaluations at the same `(E_transformed, ridge_passport)` produce
5045    /// bit-identical subspaces. The single-slot cache amortizes consecutive
5046    /// fixed-S queries (rank, logdet, trace) within a single outer iter.
5047    pub(super) fn cached_penalty_subspace<F>(
5048        &self,
5049        e_transformed: &ndarray::Array2<f64>,
5050        ridge_passport: &gam_problem::RidgePassport,
5051        build: F,
5052    ) -> Result<Arc<outer_eval::PenaltySubspace>, EstimationError>
5053    where
5054        F: FnOnce() -> Result<outer_eval::PenaltySubspace, EstimationError>,
5055    {
5056        let key = PenaltySubspaceCacheKey::from_inputs(e_transformed, ridge_passport);
5057        if let Some(hit) = self.penalty_subspace_cache.read().unwrap().get(&key) {
5058            return Ok(hit);
5059        }
5060        let value = Arc::new(build()?);
5061        self.penalty_subspace_cache
5062            .write()
5063            .unwrap()
5064            .insert(key, value.clone());
5065        Ok(value)
5066    }
5067
5068    pub(crate) fn cached_eval_bundle(&self, key: &Option<Vec<u64>>) -> Option<EvalShared> {
5069        let guard = self.current_eval_bundle.read().unwrap();
5070        let bundle: &EvalShared = guard.as_ref()?;
5071        bundle.matches(key).then(|| bundle.clone())
5072    }
5073
5074    pub(crate) fn store_eval_bundle(&self, bundle: EvalShared) {
5075        *self.current_eval_bundle.write().unwrap() = Some(bundle);
5076    }
5077
5078    pub(crate) fn cached_outer_eval(&self, key: &Option<Vec<u64>>) -> Option<OuterEval> {
5079        let key = key.as_ref()?;
5080        // The LRU is the authoritative multi-slot store; it always contains the
5081        // most-recently-stored eval too (kept in sync by `store_outer_eval`), so
5082        // a single LRU probe subsumes the old single-slot fast path while also
5083        // serving revisited (non-immediate) rho-points. `get` is a tiny linear
5084        // scan (capacity is `OUTER_EVAL_LRU_CAPACITY`) that promotes the hit to
5085        // most-recently-used; hence the write lock.
5086        self.outer_eval_lru.write().unwrap().get(key)
5087    }
5088
5089    pub(crate) fn store_outer_eval(&self, key: &Option<Vec<u64>>, eval: &OuterEval) {
5090        if let Some(key) = key.clone() {
5091            // Keep the single-slot mirror for `previous_outer_gradient_norm`,
5092            // whose "immediately previous distinct eval" contract reads it
5093            // directly and must stay byte-for-byte unchanged.
5094            *self.current_outer_eval.write().unwrap() = Some((key.clone(), eval.clone()));
5095            self.outer_eval_lru.write().unwrap().insert(key, eval.clone());
5096        }
5097    }
5098
5099    pub(crate) fn invalidate_eval_bundle(&self) {
5100        self.current_eval_bundle.write().unwrap().take();
5101        self.current_outer_eval.write().unwrap().take();
5102        self.outer_eval_lru.write().unwrap().clear();
5103    }
5104
5105    pub(crate) fn clear_eval_and_factor_caches(&self) {
5106        self.invalidate_eval_bundle();
5107        self.penalty_subspace_cache.write().unwrap().clear();
5108    }
5109}
5110
/// Scratch/runtime bookkeeping that is deliberately kept out of the
/// mathematical state invariants.
pub(crate) struct RemlArena {
    pub(crate) cost_eval_count: RwLock<u64>,
    /// Count of *actual* full-n inner P-IRLS solves (#1575).
    ///
    /// This differs from `cost_eval_count`, which counts every outer
    /// cost/gradient REQUEST, single-slot cache hits and prior short-circuits
    /// included. Only cache-missing `prepare_eval_bundlewithkey` calls are
    /// counted here — the genuinely expensive `O(n·p²)` inner solves behind
    /// the #1575 slowdown ("~150 outer cost evals each running a full n-sized
    /// P-IRLS"). A healthy warm-started fit needs roughly 2 inner solves per
    /// outer cost-eval (one value, one gradient/Hessian); a large ratio
    /// between the two counters points to broken warm-starting or duplicate
    /// solving. The counter is pure observability: it never feeds back into
    /// the optimization and changes no fitted value.
    pub(crate) inner_pirls_solve_count: AtomicU64,
    pub(crate) lastgradient_used_stochastic_fallback: AtomicBool,
}
5130
5131impl RemlArena {
5132    pub(crate) fn new() -> Self {
5133        Self {
5134            cost_eval_count: RwLock::new(0),
5135            inner_pirls_solve_count: AtomicU64::new(0),
5136            lastgradient_used_stochastic_fallback: AtomicBool::new(false),
5137        }
5138    }
5139}
5140
/// Plain data carrier with an observation count, a vector of influence
/// scales, and a scalar `phi`.
// NOTE(review): the name suggests nuisance quantities frozen for ALO
// (approximate leave-one-out) computations, and `influence_scale` presumably
// has `n_obs` entries — neither is established by the code visible here;
// confirm against the producer.
pub(crate) struct AloFrozenNuisance {
    pub(crate) n_obs: usize,
    pub(crate) influence_scale: Vec<f64>,
    pub(crate) phi: f64,
}
5146
5147pub(crate) struct RemlState<'a> {
5148    pub(crate) y: ArrayView1<'a, f64>,
5149    pub(crate) x: DesignMatrix,
5150    pub(crate) weights: ArrayView1<'a, f64>,
5151    pub(crate) offset: Array1<f64>,
5152    /// Canonicalized block-local penalties with pre-computed roots.
5153    /// This is the single canonical penalty representation — no full-width
5154    /// `rank × p` roots are stored separately.
5155    pub(crate) canonical_penalties: Arc<Vec<gam_terms::construction::CanonicalPenalty>>,
5156    pub(crate) balanced_penalty_root: Array2<f64>,
5157    pub(crate) reparam_invariant: ReparamInvariant,
5158    pub(crate) sparse_penalty_block_count: Option<usize>,
5159    pub(crate) p: usize,
5160    pub(crate) config: Arc<RemlConfig>,
5161    pub(crate) runtime_mixture_link_state: Option<gam_problem::MixtureLinkState>,
5162    pub(crate) runtime_sas_link_state: Option<SasLinkState>,
5163    pub(crate) nullspace_dims: Vec<usize>,
5164    pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
5165    pub(crate) linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
5166    /// Relative shrinkage floor for penalized block eigenvalues (rho-independent).
5167    pub(crate) penalty_shrinkage_floor: Option<f64>,
5168    /// Explicit prior on log smoothing parameters used by the REML/LAML objective.
5169    pub(crate) rho_prior: gam_problem::RhoPrior,
5170
5171    pub(crate) cache_manager: EvalCacheManager,
5172    pub(crate) arena: RemlArena,
5173    pub(crate) warm_start_beta: RwLock<Option<Coefficients>>,
5174    /// Two-point ρ-trajectory used for second-order warm-start
5175    /// extrapolation: when the outer optimizer asks for a fit at a new
5176    /// ρ, we have `β(ρ_k)` (in `warm_start_beta`) and `β(ρ_{k-1})` (in
5177    /// `prev_warm_start_beta`). The implicit β(ρ) trajectory is locally
5178    /// linear under the FOC ∇F(β,ρ)=0, so a tangent-line prediction
5179    /// `β_predict(ρ_new) = β_k + α · (β_k − β_{k-1})` where α is the
5180    /// projection of `(ρ_new − ρ_k)` onto `(ρ_k − ρ_{k-1})` gives a
5181    /// better seed than the flat `β_k` alone — replacing PIRLS warm-
5182    /// start "use last β as-is" with a real tangent-prediction step.
5183    pub(crate) warm_start_rho: RwLock<Option<Array1<f64>>>,
5184    pub(crate) prev_warm_start_beta: RwLock<Option<Coefficients>>,
5185    pub(crate) prev_warm_start_rho: RwLock<Option<Array1<f64>>>,
5186    pub(crate) warm_start_enabled: AtomicBool,
5187    pub(crate) screening_max_inner_iterations: Arc<AtomicUsize>,
5188    /// Outer-aware inner-PIRLS iteration cap for the main descent loop.
5189    ///
5190    /// Distinct from `screening_max_inner_iterations`, which is used during
5191    /// seed selection and toggles a side-effect bundle (cache writes,
5192    /// warm-start updates, KKT enforcement all suppressed). This atomic is
5193    /// purely a cap — when nonzero, the inner Newton loop is capped at
5194    /// `min(this, full_max_iterations)`, but cache writes and warm-start
5195    /// updates remain enabled. Driven by the outer optimizer to coarsen
5196    /// inner solves at early outer iterations when ρ is far from converged,
5197    /// and lifted back to full at the final accepted iter (otherwise the
5198    /// returned β would be biased by the loose cap).
5199    ///
5200    /// Both atomics are honored together as `min(screening_cap, outer_cap)`
5201    /// when both are nonzero. Default 0 (no cap from this source).
5202    pub(crate) outer_inner_cap: Arc<AtomicUsize>,
5203
5204    /// Inner-PIRLS feedback signal driven by `execute_pirls_if_needed` after
5205    /// each NON-screening solve. Stores the iteration count at which the
5206    /// inner Newton stopped, plus a flag indicating whether it converged
5207    /// (vs. hit the iteration cap). The outer first-/second-order bridges
5208    /// read these atomics to drive an adaptive `inner_cap_schedule`: the
5209    /// next outer iter's inner cap becomes `last_iters + small_margin`
5210    /// when the previous solve converged, or a geometric backoff when it
5211    /// hit the cap. This replaces the older hardcoded iter-tier schedule
5212    /// (3/5/10/20) with a cap that follows the inner solver's actual
5213    /// convergence behavior — Eisenstat-Walker style for the inner
5214    /// quadratic loop. Default 0 / false (no signal yet — first outer
5215    /// iter falls back to a coarse iter-count tier).
5216    pub(crate) last_inner_iters: Arc<AtomicUsize>,
5217    pub(crate) last_inner_converged: Arc<AtomicBool>,
5218
5219    /// Cached state from the most recent successful PIRLS solve, used by
5220    /// the IFT-based warm-start predictor.
5221    ///
5222    /// The implicit-function theorem applied to the FOC ∇_β F(β,ρ)=0
5223    /// gives `dβ/dρ_k = -H_pen^{-1} · (e^{ρ_k} · S_k · β)`. A first-order
5224    /// Taylor predictor reads
5225    /// `β_predict(ρ_new) = β_cur − Σ_k Δρ_k · H_pen^{-1} · (e^{ρ_cur_k} · S_k · β_cur)`.
5226    /// This strictly improves on the tangent-line predictor: it
5227    /// works after a single successful solve (tangent-line
5228    /// needs two prior fits), and gives the EXACT first-order Jacobian
5229    /// of the implicit β(ρ) trajectory rather than a secant proxy along one
5230    /// ρ-direction.
5231    ///
5232    /// Populated in `updatewarm_start_from` when PIRLS converges; cleared
5233    /// on failure, on `reset_surface`, and on link-state changes.
5234    pub(crate) ift_warm_start_cache: RwLock<Option<IftWarmStartCache>>,
5235
5236    /// Persisted Levenberg-Marquardt damping coefficient from the most
5237    /// recent successful PIRLS solve, bit-packed into an `AtomicU64`
5238    /// (`f64::to_bits` low 64 bits). Read at the start of
5239    /// `execute_pirls_if_needed` and written into the
5240    /// `PirlsConfig::initial_lm_lambda` hint so the inner Newton seeds
5241    /// `λ_LM` near the damping the previous solve discovered, instead
5242    /// of cold-starting at `1e-6` and burning 4-6 halving steps to
5243    /// recover. `0` (the default) signals "no hint"; the inner solver
5244    /// clamps any positive hint into `[1e-6, 1e-3]` so a stale value
5245    /// cannot destabilize the next solve. Reset on `reset_surface` and
5246    /// on failed solves.
5247    pub(crate) last_pirls_lm_lambda: Arc<AtomicU64>,
5248
5249    /// Negative-Binomial overdispersion `theta` frozen for the smoothing-
5250    /// parameter (λ) search (#1082), bit-packed `f64` (`f64::to_bits`). `0`
5251    /// (the default) signals "not yet frozen". On the first non-screening
5252    /// λ-search inner solve of an estimated-θ NB fit, the seed's
5253    /// maximum-likelihood θ is computed once and stored here; every subsequent
5254    /// λ-search evaluation pins the inner solve to this value via
5255    /// `GlmLikelihoodSpec::with_negbin_theta_frozen_for_search`, so the REML
5256    /// criterion `F(ρ) = REML(ρ, θ_frozen)` is a stationary function of ρ and
5257    /// the outer optimizer converges instead of chasing the per-eval θ drift
5258    /// that the estimated path injects. The single final reported fit still
5259    /// ML-refreshes θ at the converged η. Reset on `reset_surface`.
5260    pub(crate) frozen_negbin_theta: Arc<AtomicU64>,
5261
5262    /// Tweedie exponential-dispersion `phi` frozen for the smoothing-parameter
5263    /// (λ) search (#1477), bit-packed `f64` (`f64::to_bits`). `0` (the default)
5264    /// signals "not yet frozen". On the first non-screening λ-search inner solve
5265    /// of an estimated-φ Tweedie fit, the seed's Pearson `phî` is captured once
5266    /// and stored here; every subsequent λ-search evaluation pins the inner
5267    /// solve to this value via
5268    /// `GlmLikelihoodSpec::with_tweedie_phi_frozen_for_search`, so the REML
5269    /// criterion `F(ρ) = REML(ρ, φ_frozen)` is a stationary function of ρ. The
5270    /// Tweedie LAML omits the `phi`-dependent saddlepoint normalizer, so a `phi`
5271    /// drifting with each warm-start η lets the criterion reward dispersion
5272    /// inflation and rail a double-penalty null-space `λ` to the box bound (the
5273    /// #1477 boundary blow-up). The single final reported fit still
5274    /// Pearson-refreshes `phi` at the converged η. Reset on `reset_surface`.
5275    pub(crate) frozen_tweedie_phi: Arc<AtomicU64>,
5276
5277    /// Gamma shape `k = 1/φ` frozen for the smoothing-parameter (λ) search
5278    /// (#1074), bit-packed `f64` (`f64::to_bits`). `0` (the default) signals
5279    /// "not yet frozen". On the first non-screening λ-search inner solve of an
5280    /// estimated-shape Gamma fit, the seed's converged-η MLE `k̂` is captured
5281    /// once and stored here; every subsequent λ-search evaluation pins the inner
5282    /// solve to this value via
5283    /// `GlmLikelihoodSpec::with_gamma_shape_frozen_for_search`, so the REML
5284    /// criterion `F(ρ) = REML(ρ, k_frozen)` is a stationary function of ρ. With
5285    /// `k` estimated the inner solver re-derives it from each warm-start η, and
5286    /// because the Gamma working weight is `W = prior·k` and the
5287    /// omitting-constants log-likelihood is `−k·½D`, a `k` swinging with η makes
5288    /// BOTH the curvature `H = k·XᵀX + λS` and the data-fit `k·½D` jump with ρ —
5289    /// the criterion grows deterministic spikes that floor the projected
5290    /// gradient and rail `λ` to the over-smoothed corner (the #1074 te/Gamma
5291    /// tensor under-recovery). The single final reported fit still ML-refreshes
5292    /// `k` at the converged η. Reset on `reset_surface`.
5293    pub(crate) frozen_gamma_shape: Arc<AtomicU64>,
5294
5295    /// Last observed IFT-prediction residual (`‖β_converged − β_predicted‖
5296    /// / ‖β_converged‖`) from the most recent non-screening solve where
5297    /// the predictor was actually consumed. Bit-packed `f64` (low 64
5298    /// bits via `f64::to_bits`).
5299    ///
5300    /// "No signal yet" is encoded as a NaN bit-pattern
5301    /// (`IFT_RESIDUAL_NO_SIGNAL_BITS`). The original `0` sentinel
5302    /// collided with `f64::to_bits(0.0) == 0` — a true residual of
5303    /// exactly 0 (degenerate but mathematically possible if every
5304    /// β_predicted_i matched β_converged_i to bit-equality) would
5305    /// have been indistinguishable from "predictor never reported".
5306    /// NaN's self-inequality makes the sentinel unambiguous: any
5307    /// stored finite non-negative value is genuine signal.
5308    ///
5309    /// Read by `predict_warm_start_beta_ift_with_outcome` to drive the adaptive
5310    /// |Δρ| cap (`adaptive_ift_max_drho`): a small residual loosens
5311    /// the cap, a large one tightens it. Replaces the previous
5312    /// hardcoded `IFT_WARM_START_MAX_DRHO = 2.0` constant with a
5313    /// data-driven policy, so the predictor adapts to the empirical
5314    /// faithfulness of the linearization at this surface's scale.
5315    /// Reset on `reset_surface` and on failed solves.
5316    pub(crate) last_ift_prediction_residual: Arc<AtomicU64>,
5317
5318    /// Last observed gain ratio of the accepted LM step
5319    /// (`actual_reduction / predicted_reduction`) from the most recent
5320    /// non-screening PIRLS solve. Bit-packed `f64` with the same NaN
5321    /// sentinel discipline as `last_ift_prediction_residual`: NaN bits
5322    /// (`IFT_RESIDUAL_NO_SIGNAL_BITS`) encode "no signal yet" so a
5323    /// recorded ratio of exactly 0 (degenerate but possible) doesn't
5324    /// collide with the no-signal token.
5325    ///
5326    /// Used by `first_order_inner_cap_schedule` as a third quality
5327    /// signal alongside `last_iters` and `last_converged`. A small
5328    /// `accept_rho` (model overstating predicted reduction) is a hint
5329    /// the next iter's inner Newton may need extra margin even when
5330    /// the previous solve converged in few iters. Reset on
5331    /// `reset_surface` and on failed solves.
5332    pub(crate) last_pirls_accept_rho: Arc<AtomicU64>,
5333
5334    /// Cached Cholesky factorization of `IftWarmStartCache::penalized_hessian_transformed`.
5335    /// Lazily computed on the first IFT predict call after a fresh
5336    /// `updatewarm_start_from`, then reused by every subsequent
5337    /// predict call until the IFT cache is invalidated. At large
5338    /// scale, where p can reach several thousand, the dense Cholesky
5339    /// is O(p³)/3 — multiple seconds per refactor — so caching saves
5340    /// real wall time across the typical 5-10 IFT predict calls per
5341    /// outer fit. Reset jointly with `ift_warm_start_cache` (on
5342    /// reset_surface, on link-state changes, on failed PIRLS solves,
5343    /// and whenever a new H_pen replaces the cached one).
5344    pub(crate) ift_cached_factor: RwLock<Option<Arc<dyn gam_linalg::matrix::FactorizedSystem>>>,
5345
5346    /// When set, the penalties have Kronecker (tensor-product) structure and
5347    /// the REML evaluator can use O(∏q_j) logdet instead of O(p³) eigendecomposition.
5348    /// Populated via `set_kronecker_penalty_system` after construction.
5349    pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
5350    /// Full Kronecker factored basis (marginal designs + penalties + dims).
5351    /// Used by P-IRLS for factored reparameterization.
5352    pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,
5353
5354    /// Precomputed `(XᵀWX, XᵀW(y − offset))` for the Gaussian + Identity
5355    /// outer REML loop, populated once before the outer optimizer when the
5356    /// family / link / constraint preconditions hold and the design supports
5357    /// the Identity short-circuit at `pirls.rs:6237`. When present, each
5358    /// inner `solve_penalized_least_squares_implicit` reads these matrices
5359    /// instead of restreaming the O(N·p²) GEMM and O(N·p) matvec per outer
5360    /// iteration — the penalty `λ·S` is still added per-λ.
5361    ///
5362    /// Invalidated jointly with the design in `reset_surface`.
5363    pub(crate) gaussian_fixed_cache: RwLock<Option<Arc<crate::pirls::GaussianFixedCache>>>,
5364    /// Conditioned-frame exact ψ-derivatives `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
5365    /// for the SINGLE design-moving spatial hyperparameter (#1033b), assembled
5366    /// n-free from the certified Chebyshev ψ-Gram tensor and installed beside
5367    /// `gaussian_fixed_cache` at the same in-window trial. When present the
5368    /// Gaussian-identity ψ-gradient HyperCoord (`a_j`, `g_j`, dense `B_j`) is
5369    /// formed from these k×k objects instead of realizing and contracting the
5370    /// n×k ∂X/∂ψ slab — retiring the second per-trial n-pass. Lives in the
5371    /// SAME conditioned column frame as `gaussian_fixed_cache.xtwx_orig`, so
5372    /// the hyper-coord builder transforms it by the per-eval Qs/free-basis the
5373    /// same way it transforms the streamed Gram. Invalidated with the design.
5374    pub(crate) gaussian_psi_gram_deriv:
5375        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
5376    /// Conditioned-frame exact ψ-derivative pair `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
5377    /// for the SINGLE design-moving spatial hyperparameter in the GLM (frozen-W)
5378    /// lane (#1033 / #1111), assembled n-free from
5379    /// [`crate::glm_sufficient_lane::FrozenWeightGramTensor::gradient_pair_if_sound`]
5380    /// and installed beside `glm_first_step_gram` at the same in-window
5381    /// drift-OK trial. When present, the GLM ψ-gradient HyperCoord serves its
5382    /// envelope `a_j` and score `g_j` from these k×k objects instead of
5383    /// realizing and contracting the n×k ∂X/∂ψ slab — the second per-trial
5384    /// n-pass. Unlike the Gaussian lane the Hessian curvature `B_j` is NOT
5385    /// served from the tensor: for a GLM the per-trial `B_j` term
5386    /// `X_τᵀWX + XᵀWX_τ` is irreducibly n-dependent (the moving working weight
5387    /// `W` does not factor out of a frozen-W k×k object), so `B_j` keeps the
5388    /// exact streamed slab (#1033). Lives in the SAME conditioned column frame
5389    /// as `glm_first_step_gram` / `gaussian_fixed_cache.xtwx_orig`, so the
5390    /// hyper-coord builder transforms it by the per-eval Qs/free-basis the same
5391    /// way. NOT family-gated (the GLM lane's own slot). Invalidated with the
5392    /// design.
5393    pub(crate) glm_psi_gram_deriv:
5394        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
5395    /// Frozen-weight first-Fisher-step data-fit Gram `XᵀWX` for the GLM
5396    /// design-moving ψ-sweep (#1111 / #1033 mechanism (c)), in the conditioned
5397    /// (original / `x_fit`) column frame — the SAME frame as
5398    /// `gaussian_fixed_cache.xtwx_orig`.
5399    ///
5400    /// Assembled n-free per in-window ψ-trial from the certified frozen-weight
5401    /// Chebyshev tensor ([`crate::glm_sufficient_lane::FrozenWeightGramTensor`])
5402    /// and installed only when the trial's converged working weight has not
5403    /// drifted past tolerance from the frozen snapshot. When present, the GLM
5404    /// inner P-IRLS serves its FIRST Fisher-scoring iteration's `XᵀWX` from this
5405    /// cache instead of restreaming the O(N·p²) weighted cross-product — the
5406    /// dominant per-trial n-term in a large-n Poisson/Binomial κ-sweep. The
5407    /// penalty `Sλ` is still added per-λ on top, and every subsequent inner
5408    /// iteration restreams the true (moving) `W`, so the converged β̂ is
5409    /// unchanged; only the first-iteration Gram build is elided. Unlike
5410    /// `gaussian_fixed_cache` this is NOT family-gated — it is the GLM lane's
5411    /// own slot, consumed once per inner solve. Invalidated with the design in
5412    /// `reset_surface`.
5413    pub(crate) glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
5414    /// Previous successful non-Gaussian fixed-design data-fit Gram `XᵀWX` in
5415    /// the conditioned original frame, keyed to `warm_start_beta`.
5416    ///
5417    /// When the next outer trial uses a flat warm start, its first PIRLS
5418    /// curvature build evaluates at the same `η = Xβ` as the previous converged
5419    /// solve, so the Hessian weights and `XᵀWX` are identical. Reusing this
5420    /// original-frame Gram skips one dense `O(n·p²)` pass per warm-started
5421    /// trial while still letting later PIRLS iterations restream the moving
5422    /// weights. IFT/tangent-predicted starts do not consume this cache.
5423    pub(crate) flat_glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
5424    /// Frozen ALO robustness weights for this REML surface.
5425    ///
5426    /// The PSIS influence scale is a non-smooth function of the current hat
5427    /// diagonals. Once the high-leverage ALO objective activates, it is frozen
5428    /// for the current surface so the analytic gradient differentiates the
5429    /// same fixed-weight objective the cost evaluates.
5430    pub(crate) alo_frozen_nuisance: RwLock<Option<AloFrozenNuisance>>,
5431
5432    /// ρ-independent certificate that the Gaussian-identity ALO-stabilization
5433    /// augmentation can never activate on this surface (#1689).
5434    ///
5435    /// The augmentation engages only when some row's *penalized* leverage
5436    /// `h_i = w_i · xᵢᵀ H_λ⁻¹ xᵢ` reaches `ALO_MAX_LEVERAGE_THRESHOLD`, where
5437    /// `H_λ = XᵀWX + S_λ + ridge·I ⪰ XᵀWX`. Because `S_λ + ridge·I ⪰ 0` we have
5438    /// `H_λ⁻¹ ⪯ (XᵀWX)⁻¹`, so `h_i ≤ w_i · xᵢᵀ (XᵀWX)⁻¹ xᵢ` — the *unpenalized*
5439    /// weighted hat diagonal, which (for Gaussian identity, where W is the fixed
5440    /// prior-weight diagonal) is independent of ρ. If the max of that bound is
5441    /// below the activation threshold, no ρ can trip the gate, so the entire
5442    /// per-outer-evaluation O(n·p²) ALO leverage diagnostic — recomputed and
5443    /// discarded on every cost/gradient eval otherwise — is skipped.
5444    ///
5445    /// `Some(true)`  → provably inactive everywhere, skip the diagnostic.
5446    /// `Some(false)` → bound is ≥ threshold or XᵀWX is rank-deficient/ill-
5447    ///                 conditioned (bound not certifiable); fall through to the
5448    ///                 exact per-eval gate. Computed lazily, at most once.
5449    pub(crate) alo_provably_inactive: RwLock<Option<bool>>,
5450
5451    /// Stable disk-cache key for the current realized REML surface. Computed
5452    /// lazily because it hashes the row-chunked design and data vectors.
5453    pub(crate) persistent_warm_start_key: RwLock<Option<String>>,
5454    pub(crate) persistent_latent_values_fingerprint: Option<u64>,
5455    pub(crate) persistent_latent_values_cache: RwLock<PersistentLatentValuesCache>,
5456    pub(crate) analytic_penalty_registry_fingerprint: u64,
5457    /// Ensures the process attempts at most one disk restore per surface.
5458    pub(crate) persistent_warm_start_loaded: AtomicBool,
5459    /// Scoped counter disabling disk writes from cost-only posterior/probe
5460    /// evaluations. In-memory warm starts still update; only JSON/bin
5461    /// persistence and eviction sweeps are suppressed.
5462    pub(crate) persistent_warm_start_store_suppression: AtomicUsize,
5463    /// Scoped counter disabling the Gaussian-identity ALO-stabilization
5464    /// augmentation (#979). The leverage barrier `Σ_i (h_i − τ)₊²` is an OUTER
5465    /// OPTIMIZER aid (#813/#821) that keeps the smoothing-parameter search off
5466    /// pathological high-leverage λ regions. The marginal smoothing-parameter
5467    /// posterior `π(ρ|y) ∝ exp(−LAML(ρ))` (#938) is a property of the genuine
5468    /// model criterion, sampled against a Laplace proposal built from the BASE
5469    /// REML Hessian, so the certificate / NUTS evaluations suppress the
5470    /// augmentation (see `without_alo_stabilization`) — both for proposal↔target
5471    /// consistency and to drop the per-leapfrog ALO diagnostic suite.
5472    pub(crate) alo_stabilization_suppression: AtomicUsize,
5473    /// Whether the cross-process ON-DISK warm-start layer is engaged at all.
5474    ///
5475    /// Default `false`: the optimizer's IN-MEMORY warm start (the actual
5476    /// speed lever) is always on, but the disk checkpoint — `load_record`
5477    /// at fit start and `store_record` at finalize, each of which opens the
5478    /// shared `WarmStartStore` and pays an eviction/dir scan that is O(cache
5479    /// entries) on a network filesystem — is skipped. Disk persistence has
5480    /// reuse value only ACROSS processes or across repeated identical fits;
5481    /// a single in-process fit (and a fortiori a loop of distinct throwaway
5482    /// fits, e.g. CI-coverage replicates each on different data, #1082/#1114)
5483    /// gets zero benefit from it and pays the per-fit open/scan/save in full.
5484    /// `FitConfig::persist_warm_start_disk` flips this to `true` only when the
5485    /// caller explicitly asks for cross-process / repeat-fit persistence.
5486    pub(crate) persistent_warm_start_disk_enabled: AtomicBool,
5487    /// #1033: memoized fit-invariant O(n) response/weight scalars.
5488    ///
5489    /// `gaussian_weight_log_sum_half` (`½·Σ log wᵢ`) and `gaussian_dp_floor_scale`
5490    /// (the weighted null deviance `D₀`) are pure functions of the borrowed
5491    /// `(y, weights)` — fields `reset_surface` NEVER reassigns — so they are
5492    /// constant for the whole life of the `RemlState`. They were the last O(n)
5493    /// passes the n-free κ outer loop ran on EVERY `assemble_and_evaluate`
5494    /// (each an n-length scalar reduction, no k factor). Memoizing them once per
5495    /// fit makes the per-trial eval touch only k-dim objects, completing the
5496    /// issue's sufficient-statistic invariant literally. Plain scalar closures,
5497    /// no rayon inside the `get_or_init` (no deadlock trap).
5498    pub(crate) gaussian_weight_log_sum_half_cache: std::sync::OnceLock<f64>,
5499    pub(crate) gaussian_dp_floor_scale_cache: std::sync::OnceLock<f64>,
5500}