Skip to main content

gam_solve/reml/
mod.rs

1use self::inner_strategy::GeometryBackendKind;
2use super::*;
3use gam_linalg::sparse_exact::SparseExactFactor;
4use crate::pirls::PIRLS_CACHE_BYTE_BUDGET;
5use crate::pirls::assemble_and_factor_sparse_penalized_system;
6use gam_terms::basis::LocalDesignJacobianProvider;
7use gam_problem::SasLinkState;
8use gam_problem::OuterEval;
9use ndarray::{Array1, Array2, s};
10use std::collections::{HashMap, VecDeque};
11use std::ops::Range;
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
13use std::sync::{Arc, RwLock};
14
15pub mod assembly;
16pub mod atoms;
17pub(crate) mod continuation;
18pub(crate) mod eval;
19mod firth;
20pub(super) mod hyper;
21mod inner_strategy;
22// #1521 carve: promoted `pub(crate)` -> `pub` so the extracted
23// `gam-custom-family` crate reaches the Jeffreys-subspace items it consumes.
24pub mod jeffreys_subspace;
25pub mod outer_eval;
26pub mod penalty_logdet;
27pub mod per_atom_efs;
28pub mod reml_outer_engine;
29mod rho_key;
30mod sparse_exact_penalty;
31mod trace;
32
33pub(crate) use sparse_exact_penalty::sparse_penalty_block_count_from_canonical;
34
/// Byte budget (512 MiB) recorded as `TauTauHessianPolicy::budget_bytes` by
/// `exact_tau_tau_hessian_policy_with_firth`.
pub(crate) const EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES: usize = 512 * 1024 * 1024;
/// Largest observation count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_OBSERVATIONS: usize = 20_000;
/// Largest coefficient count accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_COEFFICIENTS: usize = 256;
/// Upper bound on `n_obs * p_coeff` accepted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_LINEAR_WORK: usize = 2_000_000;
/// Upper bound on `n_obs * p_coeff * p_coeff` accepted by
/// `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_QUADRATIC_WORK: usize = 100_000_000;
/// Capacity that `PersistentLatentValuesCache::default()` starts with.
pub(crate) const PERSISTENT_LATENT_VALUES_CACHE_CAPACITY: usize = 8;
41
42#[derive(Debug)]
43pub(crate) struct PersistentLatentValuesCache {
44    pub(crate) entries: HashMap<String, Array2<f64>>,
45    pub(crate) lru: VecDeque<String>,
46    pub(crate) capacity: usize,
47}
48
49impl Default for PersistentLatentValuesCache {
50    fn default() -> Self {
51        Self {
52            entries: HashMap::new(),
53            lru: VecDeque::new(),
54            capacity: PERSISTENT_LATENT_VALUES_CACHE_CAPACITY,
55        }
56    }
57}
58
59impl PersistentLatentValuesCache {
60    pub(crate) fn lookup(
61        &mut self,
62        key: &str,
63        n_obs: usize,
64        latent_dim: usize,
65    ) -> Option<Array2<f64>> {
66        let values = self.entries.get(key)?;
67        if values.dim() != (n_obs, latent_dim) {
68            return None;
69        }
70        let values = values.clone();
71        self.touch(key.to_string());
72        Some(values)
73    }
74
75    pub(crate) fn insert(&mut self, key: String, values: Array2<f64>) {
76        if values.iter().any(|value| !value.is_finite()) {
77            return;
78        }
79        self.entries.insert(key.clone(), values);
80        self.touch(key);
81        while self.entries.len() > self.capacity {
82            let Some(evicted) = self.lru.pop_front() else {
83                break;
84            };
85            self.entries.remove(&evicted);
86        }
87    }
88
89    pub(crate) fn touch(&mut self, key: String) {
90        if let Some(index) = self.lru.iter().position(|queued| queued == &key) {
91            self.lru.remove(index);
92        }
93        self.lru.push_back(key);
94    }
95}
96
97/// Cached state from the most recent successful PIRLS solve, populated by
98/// `updatewarm_start_from` and consumed by the IFT-based warm-start
99/// predictor (`RemlState::predict_warm_start_beta_ift_with_outcome`).
100/// See the field doc on `RemlState::ift_warm_start_cache` for the math.
101#[derive(Clone)]
102pub(crate) struct IftWarmStartCache {
103    /// β at the converged solve, in ORIGINAL basis. Mirror of
104    /// `warm_start_beta` stashed alongside the H factor for atomic
105    /// consistency under concurrent reads (the predictor needs both
106    /// β and H from the SAME solve; reading them from two locks risks
107    /// a torn pair if a fresh solve lands between reads).
108    pub beta_original: ndarray::Array1<f64>,
109    /// ρ at which the solve occurred. Mirror of `warm_start_rho`,
110    /// stashed for the same atomic-consistency reason.
111    pub rho: ndarray::Array1<f64>,
112    /// Penalized Hessian H_pen at the converged β, in TRANSFORMED basis.
113    /// The IFT predictor factors this on demand; basis transforms run in
114    /// transformed basis for numerical stability.
115    pub penalized_hessian_transformed: gam_linalg::matrix::SymmetricMatrix,
116    /// Reparameterization matrix qs converting between transformed
117    /// (column) basis and original basis: `β_orig = qs · β_tfd`,
118    /// `H_orig = qs · H_tfd · qs^T`.
119    pub qs: ndarray::Array2<f64>,
120    /// True when the PIRLS result was already in original basis
121    /// (`OriginalSparseNative`) — in which case `qs` is the identity
122    /// and the IFT predictor can skip the basis-conversion ops.
123    pub frame_was_original: bool,
124    /// Per-penalty precomputation `S_k · β_cur[cp.col_range]`,
125    /// indexed in lockstep with `RemlObjectiveState::canonical_penalties`.
126    /// Each entry is the local-block mat-vec the IFT predictor would
127    /// otherwise recompute on every predict call. With H_pen factor
128    /// caching (commit ec18559d) the per-call cost dropped from
129    /// `O(p³)` Cholesky to `O(p²) ≈ k · O(block²)` rhs construction;
130    /// at large-scale CTN (p ≈ several thousand) that's several ms
131    /// per predict call still being paid. By stashing `S_k · β_cur`
132    /// at cache-write time the predictor's per-call work drops to
133    /// just the `Δρ_k · e^{ρ_k} · sb_block` accumulation, which is
134    /// `O(p)` rather than `O(p²)`.
135    ///
136    /// `None` when the cache predates this commit's writer hook (e.g.,
137    /// transient state during invalidation); the predictor falls back
138    /// to recomputing the mat-vec when this is `None` or the length
139    /// mismatches `canonical_penalties.len()`.
140    pub lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>>,
141}
142
/// Byte estimate for one τ-τ evaluation plan, broken down by storage
/// category. Instances are filled in by
/// `exact_tau_tau_hessian_policy_with_firth`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) struct TauTauPlanEstimate {
    /// Bytes for the dense design matrix.
    pub(crate) dense_x_bytes: usize,
    /// Bytes for dense first-order τ design derivatives.
    pub(crate) first_order_tau_bytes: usize,
    /// Bytes for dense second-order τ design derivatives.
    pub(crate) second_order_tau_bytes: usize,
    /// Bytes for first-order penalty derivatives.
    pub(crate) penalty_first_bytes: usize,
    /// Bytes for second-order (pair) penalty derivatives.
    pub(crate) penalty_pair_bytes: usize,
    /// Bytes for the ρ-τ penalty terms.
    pub(crate) rho_tau_penalty_bytes: usize,
    /// Bytes for cached per-observation vectors.
    pub(crate) vector_cache_bytes: usize,
    /// Bytes for the weighted scratch matrix.
    pub(crate) weighted_scratch_bytes: usize,
}

impl TauTauPlanEstimate {
    /// Sums every category, saturating at `usize::MAX` instead of wrapping.
    pub(crate) fn total_bytes(self) -> usize {
        let Self {
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        } = self;
        [
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        ]
        .iter()
        .fold(0usize, |total, &bytes| total.saturating_add(bytes))
    }
}
167
168#[derive(Clone, Copy, Debug, PartialEq, Eq)]
169pub(crate) struct TauTauHessianPolicy {
170    pub(crate) any_has_implicit: bool,
171    pub(crate) implicit_multidim_duchon: bool,
172    pub(crate) estimated_dense_tau_cache_bytes: usize,
173    pub(crate) gradient_plan: TauTauPlanEstimate,
174    pub(crate) hessian_plan: TauTauPlanEstimate,
175    pub(crate) budget_bytes: usize,
176    pub(crate) firth_pair_terms_unavailable: bool,
177}
178
179impl TauTauHessianPolicy {
180    /// True when the τ-τ exact-Hessian path cannot be assembled at all and the
181    /// eval must fall back to value-and-gradient mode (forcing
182    /// `HessianResult::Unavailable`).
183    ///
184    /// This is the *only* remaining capability gate: the previous
185    /// implementation also forced gradient-only when the design used implicit
186    /// multi-dim Duchon storage or when the dense τ-cache plan would exceed
187    /// the budget.  Both of those are now *cost* gates, not capability gates
188    /// — the unified evaluator's `prefer_outer_hessian_operator(n, p, k)`
189    /// selects the matrix-free `HessianResult::Operator` representation in
190    /// exactly the regimes where the dense cache would be unaffordable, and
191    /// the planner routes operator returns through `run_operator_trust_region`
192    /// (or basis-probes them when `dim ≤ OUTER_HVP_MATERIALIZE_MAX_DIM`).
193    /// Forcing gradient-only would have prevented the operator representation
194    /// from ever being requested, defeating that routing; hence the
195    /// `implicit_multidim_duchon` and cost-bytes clauses are deliberately
196    /// gone.
197    ///
198    /// Firth-pair-terms-unavailability remains a capability gate: when the
199    /// Firth-aware derivative provider cannot produce the τ-τ pair
200    /// corrections at all, no representation choice can substitute.  At every
201    /// production call site this flag is hardcoded `false` (the
202    /// `hphi_tau_tau_partial_apply` + `d_beta_hphi_tau_partial_apply`
203    /// primitives now cover the gap), so this method effectively returns
204    /// `false` in production.  We retain the field and method signature
205    /// unchanged so future Firth corner cases have a single, surfaced place
206    /// to land.
207    pub(crate) fn prefer_gradient_only(self) -> bool {
208        self.firth_pair_terms_unavailable
209    }
210}
211
212pub(crate) fn exact_tau_tau_hessian_policy_with_firth(
213    n_obs: usize,
214    p_coeff: usize,
215    hyper_dirs: &[DirectionalHyperParam],
216    firth_pair_terms_unavailable: bool,
217) -> TauTauHessianPolicy {
218    let f64_bytes = std::mem::size_of::<f64>();
219    let dense_matrix_bytes =
220        |rows: usize, cols: usize| -> usize { rows.saturating_mul(cols).saturating_mul(f64_bytes) };
221    let dense_design_bytes = dense_matrix_bytes(n_obs, p_coeff);
222    let dense_penalty_bytes = dense_matrix_bytes(p_coeff, p_coeff);
223    let psi_dim = hyper_dirs.len();
224    let implicit_n_axes = hyper_dirs
225        .iter()
226        .find_map(DirectionalHyperParam::implicit_axis_count_hint)
227        .unwrap_or(0);
228    let gradient_uses_implicit_design = hyper_dirs
229        .iter()
230        .any(DirectionalHyperParam::has_implicit_operator)
231        && gam_terms::basis::should_use_implicit_operators_with_policy(
232            n_obs,
233            p_coeff,
234            implicit_n_axes,
235            &gam_runtime::resource::ResourcePolicy::default_library(),
236        );
237    let dense_first_order_count = hyper_dirs
238        .iter()
239        .filter(|dir| !dir.has_implicit_operator())
240        .count();
241    let first_penalty_component_count = hyper_dirs
242        .iter()
243        .map(DirectionalHyperParam::penalty_first_component_count)
244        .sum::<usize>();
245
246    let mut dense_second_order_count = 0usize;
247    let mut penalty_pair_count = 0usize;
248    for i in 0..psi_dim {
249        for j in i..psi_dim {
250            if hyper_dirs[i]
251                .x_tau_tau_entry_at(j)
252                .or_else(|| hyper_dirs[j].x_tau_tau_entry_at(i))
253                .is_some_and(|entry| !entry.uses_implicit_storage())
254            {
255                dense_second_order_count += if i == j { 1 } else { 2 };
256            }
257            if hyper_dirs[i].has_penaltysecond_pair_at(j)
258                || hyper_dirs[j].has_penaltysecond_pair_at(i)
259            {
260                penalty_pair_count += if i == j { 1 } else { 2 };
261            }
262        }
263    }
264
265    let gradient_dense_first_order_count = if gradient_uses_implicit_design {
266        dense_first_order_count
267    } else {
268        psi_dim
269    };
270    let gradient_needs_dense_x =
271        firth_pair_terms_unavailable || gradient_dense_first_order_count > 0;
272    let gradient_plan = TauTauPlanEstimate {
273        dense_x_bytes: if gradient_needs_dense_x {
274            dense_design_bytes
275        } else {
276            0
277        },
278        first_order_tau_bytes: if gradient_dense_first_order_count > 0 {
279            dense_design_bytes
280        } else {
281            0
282        },
283        second_order_tau_bytes: 0,
284        penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
285        penalty_pair_bytes: 0,
286        rho_tau_penalty_bytes: 0,
287        vector_cache_bytes: n_obs.saturating_mul(f64_bytes),
288        weighted_scratch_bytes: dense_penalty_bytes,
289    };
290    let hessian_plan = TauTauPlanEstimate {
291        dense_x_bytes: if psi_dim > 0 { dense_design_bytes } else { 0 },
292        first_order_tau_bytes: dense_first_order_count.saturating_mul(dense_design_bytes),
293        second_order_tau_bytes: dense_second_order_count.saturating_mul(dense_design_bytes),
294        penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
295        penalty_pair_bytes: penalty_pair_count.saturating_mul(dense_penalty_bytes),
296        rho_tau_penalty_bytes: first_penalty_component_count
297            .saturating_mul(2)
298            .saturating_mul(dense_penalty_bytes),
299        vector_cache_bytes: psi_dim.saturating_mul(n_obs).saturating_mul(f64_bytes),
300        weighted_scratch_bytes: dense_penalty_bytes,
301    };
302    let any_has_implicit = hyper_dirs
303        .iter()
304        .any(DirectionalHyperParam::has_implicit_operator);
305    let implicit_multidim_duchon = hyper_dirs
306        .iter()
307        .any(DirectionalHyperParam::has_implicit_multidim_duchon);
308    let estimated_dense_tau_cache_bytes = hessian_plan
309        .first_order_tau_bytes
310        .saturating_add(hessian_plan.second_order_tau_bytes);
311    TauTauHessianPolicy {
312        any_has_implicit,
313        implicit_multidim_duchon,
314        estimated_dense_tau_cache_bytes,
315        gradient_plan,
316        hessian_plan,
317        budget_bytes: EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES,
318        firth_pair_terms_unavailable: firth_pair_terms_unavailable && !hyper_dirs.is_empty(),
319    }
320}
321
322pub(crate) fn firth_problem_scale_allows(n_obs: usize, p_coeff: usize) -> bool {
323    let linear_work = n_obs.saturating_mul(p_coeff);
324    let quadratic_work = linear_work.saturating_mul(p_coeff);
325    n_obs <= FIRTH_MAX_OBSERVATIONS
326        && p_coeff <= FIRTH_MAX_COEFFICIENTS
327        && linear_work <= FIRTH_MAX_LINEAR_WORK
328        && quadratic_work <= FIRTH_MAX_QUADRATIC_WORK
329}
330
331#[cfg(test)]
332mod tests {
333    use super::{
334        DirectionalHyperParam, EvalCacheManager, EvalShared, HyperDesignDerivative,
335        HyperPenaltyDerivative, ImplicitDerivLevel, RemlConfig, RemlState,
336    };
337    use crate::estimate::EstimationError;
338    use gam_linalg::faer_ndarray::FaerCholesky;
339    use gam_linalg::matrix::symmetrize_in_place;
340    use crate::pirls::PirlsCoordinateFrame;
341    use gam_terms::basis::{ImplicitDesignPsiDerivative, RadialScalarKind};
342    use gam_problem::{
343        GlmLikelihoodSpec, InverseLink, LikelihoodSpec, ResponseFamily, StandardLink,
344    };
345    use faer::Side;
346    use gam_problem::{HessianResult, OuterEval};
347    use ndarray::{Array1, Array2, array, s};
348    use std::sync::Arc;
349
350    /// Shorthand for the canonical Binomial-Logit `GlmLikelihoodSpec` used by
351    /// the REML test fixtures.
352    pub(crate) fn binomial_logit_glm_spec() -> GlmLikelihoodSpec {
353        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
354            ResponseFamily::Binomial,
355            InverseLink::Standard(StandardLink::Logit),
356        ))
357    }
358
359    /// Shorthand for the canonical Gaussian-Identity `GlmLikelihoodSpec` used
360    /// by the REML test fixtures.
361    pub(crate) fn gaussian_identity_glm_spec() -> GlmLikelihoodSpec {
362        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
363            ResponseFamily::Gaussian,
364            InverseLink::Standard(StandardLink::Identity),
365        ))
366    }
367
368    impl DirectionalHyperParam {
369        pub(super) fn new(
370            x_tau_original: Array2<f64>,
371            penalty_first_components: Vec<(usize, Array2<f64>)>,
372            x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
373            penaltysecond_components: Option<Vec<Option<Vec<(usize, Array2<f64>)>>>>,
374        ) -> Result<Self, EstimationError> {
375            let x_tau_tau_original = x_tau_tau_original.map(|rows| {
376                rows.into_iter()
377                    .map(|entry| entry.map(HyperDesignDerivative::from))
378                    .collect::<Vec<_>>()
379            });
380            let penalty_first_components = penalty_first_components
381                .into_iter()
382                .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
383                .collect();
384            let penaltysecond_components = penaltysecond_components.map(|rows| {
385                rows.into_iter()
386                    .map(|row| {
387                        row.map(|components| {
388                            components
389                                .into_iter()
390                                .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
391                                .collect::<Vec<_>>()
392                        })
393                    })
394                    .collect::<Vec<_>>()
395            });
396            Self::new_compact(
397                HyperDesignDerivative::from(x_tau_original),
398                penalty_first_components,
399                x_tau_tau_original,
400                penaltysecond_components,
401            )
402        }
403
404        pub(super) fn single_penalty(
405            penalty_index: usize,
406            x_tau_original: Array2<f64>,
407            s_tau_original: Array2<f64>,
408            x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
409            s_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
410        ) -> Result<Self, EstimationError> {
411            let penaltysecond_components = s_tau_tau_original.map(|rows| {
412                rows.into_iter()
413                    .map(|mat| mat.map(|mat| vec![(penalty_index, mat)]))
414                    .collect::<Vec<_>>()
415            });
416            Self::new(
417                x_tau_original,
418                vec![(penalty_index, s_tau_original)],
419                x_tau_tau_original,
420                penaltysecond_components,
421            )
422        }
423    }
424
425    #[test]
426    pub(crate) fn firth_problem_scale_gate_blocks_large_quadratic_work() {
427        assert!(super::firth_problem_scale_allows(2_000, 200));
428        assert!(!super::firth_problem_scale_allows(4_800, 241));
429        assert!(!super::firth_problem_scale_allows(4_800, 433));
430    }
431
432    #[test]
433    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_implicit_tau() {
434        let operator = ImplicitDesignPsiDerivative::new(
435            array![1.0, 2.0, 3.0, 4.0],
436            array![0.5, -1.0, 1.5, 2.0],
437            array![0.1, 0.2, 0.3, 0.4],
438            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
439            None,
440            None,
441            2,
442            2,
443            1,
444            2,
445        );
446        let dir = DirectionalHyperParam::new_compact(
447            HyperDesignDerivative::from_implicit(
448                Arc::new(operator),
449                ImplicitDerivLevel::First(0),
450                1..4,
451                5,
452            ),
453            Vec::new(),
454            None,
455            None,
456        )
457        .expect("implicit directional hyperparam");
458        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
459        assert!(policy.any_has_implicit);
460        assert_eq!(
461            policy.gradient_plan.dense_x_bytes,
462            10 * 5 * std::mem::size_of::<f64>()
463        );
464        assert!(!policy.prefer_gradient_only());
465    }
466
467    #[test]
468    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_for_implicit_multidim_duchon()
469    {
470        // Multi-dim Duchon implicit storage used to force gradient-only,
471        // because the τ-cache materialization plan was infeasible.  The
472        // unified evaluator now elects the matrix-free
473        // `HessianResult::Operator` representation in this regime via
474        // `prefer_outer_hessian_operator`, so the planner can route to the
475        // operator trust-region (or basis-probe to dense for small K) — the
476        // capability is preserved and gradient-only must NOT engage.
477        let operator = ImplicitDesignPsiDerivative::new_streaming(
478            Arc::new(array![[0.0, 0.0], [1.0, 0.2]]),
479            Arc::new(array![[0.0, 0.0], [1.0, 1.0]]),
480            vec![0.0, 0.0],
481            RadialScalarKind::PureDuchon {
482                block_order: 1,
483                p_order: 0,
484                s_order: 0,
485                dim: 2,
486            },
487            None,
488            None,
489            0,
490        );
491        let dir = DirectionalHyperParam::new_compact(
492            HyperDesignDerivative::from_implicit(
493                Arc::new(operator),
494                ImplicitDerivLevel::First(0),
495                0..2,
496                2,
497            ),
498            Vec::new(),
499            None,
500            None,
501        )
502        .expect("implicit duchon directional hyperparam");
503        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
504        assert!(policy.any_has_implicit);
505        assert!(policy.implicit_multidim_duchon);
506        assert!(!policy.prefer_gradient_only());
507    }
508
509    #[test]
510    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_when_cache_budget_is_exceeded()
511     {
512        // The dense τ-cache plan exceeds the budget, but cost is no longer a
513        // capability gate: the eval-side selects the matrix-free operator
514        // representation in exactly this regime, and the planner routes
515        // accordingly.  `prefer_gradient_only` must NOT force `Unavailable`
516        // here.
517        let dirs = (0..16)
518            .map(|_| {
519                DirectionalHyperParam::new_compact(
520                    HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
521                    Vec::new(),
522                    None,
523                    None,
524                )
525                .expect("dense directional hyperparam")
526            })
527            .collect::<Vec<_>>();
528        let policy = super::exact_tau_tau_hessian_policy_with_firth(320_000, 71, &dirs, false);
529        assert!(!policy.any_has_implicit);
530        assert!(policy.hessian_plan.total_bytes() > policy.budget_bytes);
531        assert!(policy.hessian_plan.total_bytes() > policy.gradient_plan.total_bytes());
532        assert!(!policy.prefer_gradient_only());
533    }
534
535    #[test]
536    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_firth_pair_gap() {
537        let dir = DirectionalHyperParam::new_compact(
538            HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
539            Vec::new(),
540            None,
541            None,
542        )
543        .expect("dense directional hyperparam");
544        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], true);
545        assert!(policy.firth_pair_terms_unavailable);
546        assert!(policy.prefer_gradient_only());
547    }
548
549    /// Common shape for the design-motion + penalty-motion REML test fixtures
550    /// (Gaussian-identity and binomial-logit at present): both carry the
551    /// same `(y, w, X, S0, cfg, ρ)` plus a perturbation pair, and need the
552    /// same three helpers (`state`, `state_perturbed`, `fd_directional_gradient`).
553    /// Per-fixture `new()` constructors fill the fields with family-specific
554    /// data; the helpers below are shared via default impls so every fixture
555    /// pays the boilerplate once.
556    trait LogitDesignMotionFixture {
557        fn y(&self) -> &Array1<f64>;
558        fn w(&self) -> &Array1<f64>;
559        fn x(&self) -> &Array2<f64>;
560        fn s0(&self) -> &Array2<f64>;
561        fn cfg(&self) -> &RemlConfig;
562        fn rho(&self) -> &Array1<f64>;
563
564        fn state(&self) -> RemlState<'_> {
565            build_logit_state(self.y(), self.w(), self.x(), self.s0(), self.cfg())
566        }
567
568        fn state_perturbed(
569            &self,
570            x_tau: &Array2<f64>,
571            s_tau: &Array2<f64>,
572            eps: f64,
573        ) -> (RemlState<'_>, RemlState<'_>) {
574            let x_plus = self.x() + &x_tau.mapv(|v| eps * v);
575            let x_minus = self.x() - &x_tau.mapv(|v| eps * v);
576            let s_plus = self.s0() + &s_tau.mapv(|v| eps * v);
577            let s_minus = self.s0() - &s_tau.mapv(|v| eps * v);
578            (
579                build_logit_state(self.y(), self.w(), &x_plus, &s_plus, self.cfg()),
580                build_logit_state(self.y(), self.w(), &x_minus, &s_minus, self.cfg()),
581            )
582        }
583
584        /// Central FD approximation to the directional cost derivative at ρ.
585        fn fd_directional_gradient(&self, x_tau: &Array2<f64>, s_tau: &Array2<f64>) -> f64 {
586            let h = 2e-5;
587            let (state_plus, state_minus) = self.state_perturbed(x_tau, s_tau, h);
588            let v_plus = state_plus.compute_cost(self.rho()).expect("cost+");
589            let v_minus = state_minus.compute_cost(self.rho()).expect("cost-");
590            (v_plus - v_minus) / (2.0 * h)
591        }
592    }
593
594    pub(crate) fn build_logit_state<'a>(
595        y: &'a Array1<f64>,
596        w: &'a Array1<f64>,
597        x: &Array2<f64>,
598        s: &Array2<f64>,
599        cfg: &'a RemlConfig,
600    ) -> RemlState<'a> {
601        use crate::estimate::PenaltySpec;
602        let p = x.ncols();
603        let offset = Array1::<f64>::zeros(y.len());
604        let spec = PenaltySpec::Dense(s.clone());
605        let canonical = gam_terms::construction::canonicalize_penalty_specs(&[spec], &[1], p, "test")
606            .map(|(canonical, _)| canonical)
607            .expect("canonicalize");
608        RemlState::newwith_offset(
609            y.view(),
610            x.clone(),
611            w.view(),
612            offset.view(),
613            canonical,
614            p,
615            cfg,
616            Some(vec![1]),
617            None,
618            None,
619        )
620        .expect("state")
621    }
622
623    #[test]
624    fn repeated_penalty_ranges_keep_analytic_outer_hessian() {
625        let y = array![0.2, -0.1, 0.3, 0.0];
626        let w = Array1::<f64>::ones(y.len());
627        let x = array![[1.0, -0.7], [1.0, -0.2], [1.0, 0.3], [1.0, 0.9]];
628        let offset = Array1::<f64>::zeros(y.len());
629        let cfg = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
630        let p = x.ncols();
631        let canonical = vec![
632            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0]], p),
633            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0]], p),
634        ];
635        let state = RemlState::newwith_offset(
636            y.view(),
637            x,
638            w.view(),
639            offset.view(),
640            canonical,
641            p,
642            &cfg,
643            Some(vec![1, 1]),
644            None,
645            None,
646        )
647        .expect("state");
648
649        assert!(
650            state.analytic_outer_hessian_enabled(),
651            "double-penalty-style repeated coefficient ranges must still route to exact Hessian"
652        );
653    }
654
655    pub(crate) fn poisson_log_glm_spec() -> GlmLikelihoodSpec {
656        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
657            ResponseFamily::Poisson,
658            InverseLink::Standard(StandardLink::Log),
659        ))
660    }
661
662    /// Regression (issue #893): for a fixed-dispersion family a uniform prior
663    /// weight `w = c` is *exact* `c`-fold row replication. The two encodings must
664    /// therefore present a byte-identical LAML smoothing-selection surface — both
665    /// the cost `V(ρ)` and its gradient `∇V(ρ)` — because every term (penalised
666    /// deviance `D_p`, the working cross-product `XᵀWX`, the log-determinants)
667    /// is a sum of per-observation contributions that is identical whether a row
668    /// carries weight `c` or is stacked `c` times. This locks the *surface*
669    /// invariant that #893 ultimately reduces to: when the surfaces coincide,
670    /// the only remaining requirement for `λ̂(w=c) = λ̂(c×)` is that the outer
671    /// optimiser resolve the shared optimum (handled by the tightened outer
672    /// tolerance in `workflow.rs`). A regression that reintroduced a
673    /// row-count-vs-weight-sum asymmetry into the inner solve or the cost would
674    /// break this directly, independent of the optimiser tolerance.
675    #[test]
676    pub(crate) fn fixed_dispersion_laml_surface_is_replication_invariant() {
677        let n = 200usize;
678        let p = 8usize;
679        let c = 3usize;
680        let mut x = Array2::<f64>::zeros((n, p));
681        let mut y = Array1::<f64>::zeros(n);
682        for i in 0..n {
683            let t = (i as f64) / ((n - 1) as f64);
684            let tau = std::f64::consts::TAU;
685            x[[i, 0]] = 1.0;
686            x[[i, 1]] = t;
687            x[[i, 2]] = (tau * t).sin();
688            x[[i, 3]] = (tau * t).cos();
689            x[[i, 4]] = (2.0 * tau * t).sin();
690            x[[i, 5]] = (2.0 * tau * t).cos();
691            x[[i, 6]] = (3.0 * tau * t).sin();
692            x[[i, 7]] = (3.0 * tau * t).cos();
693            let eta = 0.3 + 0.9 * (1.4 * (t - 0.5)).sin();
694            // Deterministic non-negative integer counts near exp(eta).
695            y[i] = (eta.exp() + 0.5 * ((i as f64) * 2.399_963).sin())
696                .round()
697                .max(0.0);
698        }
699        let mut s = Array2::<f64>::zeros((p, p));
700        for j in 1..p {
701            s[[j, j]] = 1.0;
702        }
703
704        // Replicated design (c literal copies of each row).
705        let mut x_rep = Array2::<f64>::zeros((n * c, p));
706        let mut y_rep = Array1::<f64>::zeros(n * c);
707        for r in 0..c {
708            for i in 0..n {
709                let row = r * n + i;
710                for j in 0..p {
711                    x_rep[[row, j]] = x[[i, j]];
712                }
713                y_rep[row] = y[i];
714            }
715        }
716
717        let w_weighted = Array1::<f64>::from_elem(n, c as f64);
718        let w_rep = Array1::<f64>::ones(n * c);
719
720        let cfg = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
721        let st_w = build_logit_state(&y, &w_weighted, &x, &s, &cfg);
722        let st_r = build_logit_state(&y_rep, &w_rep, &x_rep, &s, &cfg);
723
724        for &rho in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
725            let r = Array1::from_elem(1, rho);
726            let cw = st_w.compute_cost(&r).expect("weighted cost");
727            let cr = st_r.compute_cost(&r).expect("replicated cost");
728            let gw = st_w.compute_gradient(&r).expect("weighted grad");
729            let gr = st_r.compute_gradient(&r).expect("replicated grad");
730            // Costs and gradients must coincide to optimiser precision; the only
731            // admissible difference is f64 summation order over n vs c·n rows.
732            assert!(
733                (cw - cr).abs() <= 1e-9 * (1.0 + cw.abs()),
734                "LAML cost differs between w=c and c× replication at rho={rho}: \
735                 cost_w={cw:.12e} cost_r={cr:.12e} diff={:.3e}",
736                cw - cr
737            );
738            assert!(
739                (gw[0] - gr[0]).abs() <= 1e-9 * (1.0 + gw[0].abs()),
740                "LAML gradient differs between w=c and c× replication at rho={rho}: \
741                 g_w={:.12e} g_r={:.12e} diff={:.3e}",
742                gw[0],
743                gr[0],
744                gw[0] - gr[0]
745            );
746        }
747    }
748
749    /// Regression (issue #893): the geometric-mean log-weight ρ-anchor
750    /// ([`RemlState::rho_weight_anchor`]) is a *profiled*-dispersion construct
751    /// (issue #877). For a fixed-dispersion family the optimum does not slide by
752    /// `log c` under a weight rescale in a way the prior should track, and a
753    /// nonzero anchor would evaluate the regularising ρ-prior at *different*
754    /// coordinates for the `w=c` vs `c×` encodings — breaking the very
755    /// equivalence #893 requires. The anchor must therefore be exactly `0` for a
756    /// fixed-dispersion family and the geometric mean for Gaussian-identity.
757    #[test]
758    pub(crate) fn rho_weight_anchor_is_zero_for_fixed_dispersion() {
759        let n = 50usize;
760        let p = 3usize;
761        let mut x = Array2::<f64>::zeros((n, p));
762        let mut y = Array1::<f64>::zeros(n);
763        for i in 0..n {
764            let t = (i as f64) / ((n - 1) as f64);
765            x[[i, 0]] = 1.0;
766            x[[i, 1]] = t;
767            x[[i, 2]] = t * t;
768            y[i] = (1.0 + (3.0 * t).sin()).round().max(0.0);
769        }
770        let mut s = Array2::<f64>::zeros((p, p));
771        s[[2, 2]] = 1.0;
772        // All weights = c: geometric-mean log-weight = ln(c) ≠ 0.
773        let c = 4.0_f64;
774        let w = Array1::<f64>::from_elem(n, c);
775
776        let cfg_pois = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
777        let st_pois = build_logit_state(&y, &w, &x, &s, &cfg_pois);
778        assert_eq!(
779            st_pois.rho_weight_anchor(),
780            0.0,
781            "fixed-dispersion (Poisson) anchor must be 0, not the geometric-mean log-weight"
782        );
783
784        let cfg_gauss = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
785        let st_gauss = build_logit_state(&y, &w, &x, &s, &cfg_gauss);
786        assert!(
787            (st_gauss.rho_weight_anchor() - c.ln()).abs() <= 1e-12,
788            "Gaussian-identity (profiled) anchor must be the geometric-mean log-weight ln(c)={:.6}, got {:.6}",
789            c.ln(),
790            st_gauss.rho_weight_anchor()
791        );
792    }
793
794    pub(crate) fn beta_original_from_bundle(bundle: &EvalShared) -> Array1<f64> {
795        let pr = bundle.pirls_result.as_ref();
796        match pr.coordinate_frame {
797            PirlsCoordinateFrame::OriginalSparseNative => pr.beta_transformed.as_ref().clone(),
798            PirlsCoordinateFrame::TransformedQs => {
799                pr.reparam_result.qs.dot(pr.beta_transformed.as_ref())
800            }
801        }
802    }
803
804    pub(crate) fn compute_joint_hypercostgradienthessian(
805        state: &RemlState<'_>,
806        theta: &Array1<f64>,
807        rho_dim: usize,
808        hyper_dirs: &[DirectionalHyperParam],
809    ) -> Result<(f64, Array1<f64>, Array2<f64>), EstimationError> {
810        let (cost, gradient, hessian) = state.compute_joint_hyper_eval_with_order(
811            theta,
812            rho_dim,
813            hyper_dirs,
814            crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
815        )?;
816        Ok((
817            cost,
818            gradient,
819            hessian
820                .materialize_dense()
821                .map_err(EstimationError::RemlOptimizationFailed)?
822                .ok_or_else(|| {
823                    EstimationError::RemlOptimizationFailed(
824                        "joint hyper Hessian requested but unavailable".to_string(),
825                    )
826                })?,
827        ))
828    }
829
830    pub(crate) fn h_original_from_bundle(bundle: &EvalShared) -> Array2<f64> {
831        let pr = bundle.pirls_result.as_ref();
832        match pr.coordinate_frame {
833            PirlsCoordinateFrame::OriginalSparseNative => bundle.h_total.as_ref().clone(),
834            PirlsCoordinateFrame::TransformedQs => {
835                let qs = &pr.reparam_result.qs;
836                let tmp = gam_linalg::faer_ndarray::fast_ab(qs, bundle.h_total.as_ref());
837                gam_linalg::faer_ndarray::fast_abt(&tmp, qs)
838            }
839        }
840    }
841
842    pub(crate) fn single_directional_tau_gradient(
843        state: &RemlState<'_>,
844        rho: &Array1<f64>,
845        hyper: DirectionalHyperParam,
846    ) -> Result<f64, EstimationError> {
847        let mut theta = Array1::<f64>::zeros(rho.len() + 1);
848        theta.slice_mut(s![..rho.len()]).assign(rho);
849        let (_, gradient, _) = state.compute_joint_hyper_eval_with_order(
850            &theta,
851            rho.len(),
852            &[hyper],
853            crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
854        )?;
855        Ok(gradient[rho.len()])
856    }
857
858    pub(crate) fn fd_directional_tau_cost_gradient(
859        y: &Array1<f64>,
860        w: &Array1<f64>,
861        x: &Array2<f64>,
862        s0: &Array2<f64>,
863        cfg: &RemlConfig,
864        rho: &Array1<f64>,
865        x_tau: &Array2<f64>,
866        s_tau: &Array2<f64>,
867    ) -> f64 {
868        let h = 2e-5;
869        let x_plus = x + &x_tau.mapv(|v| h * v);
870        let x_minus = x - &x_tau.mapv(|v| h * v);
871        let s_plus = s0 + &s_tau.mapv(|v| h * v);
872        let s_minus = s0 - &s_tau.mapv(|v| h * v);
873        let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
874        let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
875        let v_plus = state_plus.compute_cost(rho).expect("cost+");
876        let v_minus = state_minus.compute_cost(rho).expect("cost-");
877        (v_plus - v_minus) / (2.0 * h)
878    }
879
880    pub(crate) fn directional_tau_hessian_fd_reference(
881        y: &Array1<f64>,
882        w: &Array1<f64>,
883        x: &Array2<f64>,
884        s0: &Array2<f64>,
885        cfg: &RemlConfig,
886        rho: &Array1<f64>,
887        hyper_dirs: &[DirectionalHyperParam],
888        x_tau_mats: &[Array2<f64>],
889        s_tau_mats: &[Array2<f64>],
890    ) -> Array2<f64> {
891        assert_eq!(hyper_dirs.len(), x_tau_mats.len());
892        assert_eq!(hyper_dirs.len(), s_tau_mats.len());
893
894        const TARGET_PHYSICAL_STEP: f64 = 1e-5;
895
896        let n_dirs = hyper_dirs.len();
897        let mut h_ttfd = Array2::<f64>::zeros((n_dirs, n_dirs));
898        for j in 0..n_dirs {
899            let direction_scale = x_tau_mats[j]
900                .iter()
901                .chain(s_tau_mats[j].iter())
902                .fold(0.0_f64, |acc, value| acc.max(value.abs()));
903            let h = if direction_scale > 0.0 {
904                TARGET_PHYSICAL_STEP / direction_scale
905            } else {
906                TARGET_PHYSICAL_STEP
907            };
908
909            let x_plus = x + &x_tau_mats[j].mapv(|v| h * v);
910            let x_minus = x - &x_tau_mats[j].mapv(|v| h * v);
911            let s_plus = s0 + &s_tau_mats[j].mapv(|v| h * v);
912            let s_minus = s0 - &s_tau_mats[j].mapv(|v| h * v);
913
914            let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
915            let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
916            for i in 0..n_dirs {
917                let g_plus =
918                    single_directional_tau_gradient(&state_plus, rho, hyper_dirs[i].clone())
919                        .expect("g+ for FD");
920                let g_minus =
921                    single_directional_tau_gradient(&state_minus, rho, hyper_dirs[i].clone())
922                        .expect("g- for FD");
923                h_ttfd[[i, j]] = (g_plus - g_minus) / (2.0 * h);
924            }
925        }
926        symmetrize_in_place(&mut h_ttfd);
927        h_ttfd
928    }
929
930    #[test]
931    pub(crate) fn eval_cache_manager_stores_first_order_outer_eval() {
932        let cache = EvalCacheManager::new();
933        let rho = array![0.25, -0.0];
934        let rho_key = EvalCacheManager::sanitized_rhokey(&rho);
935        let eval = OuterEval {
936            cost: 3.5,
937            gradient: array![1.0, -2.0],
938            hessian: HessianResult::Unavailable,
939            inner_beta_hint: None,
940        };
941
942        cache.store_outer_eval(&rho_key, &eval);
943
944        let cached = cache
945            .cached_outer_eval(&rho_key)
946            .expect("first-order outer eval should be cached");
947        assert_eq!(cached.cost, eval.cost);
948        assert_eq!(cached.gradient, eval.gradient);
949        assert!(matches!(cached.hessian, HessianResult::Unavailable));
950
951        cache.invalidate_eval_bundle();
952        assert!(
953            cache.cached_outer_eval(&rho_key).is_none(),
954            "invalidating the bundle should clear the outer-eval cache too"
955        );
956    }
957
958    /// #1575 multi-slot outer-eval cache correctness oracle.
959    ///
960    /// A memoization is only safe if a hit returns *exactly* what the miss path
961    /// stored. This pins three properties of the bounded LRU:
962    ///   1. round-trip fidelity — a hit is `f64::to_bits`-identical in cost AND
963    ///      every gradient component to the value stored on the miss path;
964    ///   2. no aliasing — distinct rho-keys never return each other's eval;
965    ///   3. honest eviction — once an evicted key is requested again it MISSES
966    ///      (so the caller recomputes) rather than returning a stale neighbour.
967    #[test]
968    pub(crate) fn outer_eval_lru_hit_is_bit_identical_and_evicts_honestly_1575() {
969        use super::OUTER_EVAL_LRU_CAPACITY;
970
971        // Helper: a deterministic OuterEval whose bits encode `seed`, so any
972        // cross-key contamination is detectable bit-for-bit.
973        let make_eval = |seed: f64| OuterEval {
974            cost: (seed * std::f64::consts::PI).sin() / 3.0 - seed,
975            gradient: array![seed, -seed * 2.0, seed.recip()],
976            hessian: HessianResult::Unavailable,
977            inner_beta_hint: Some(array![seed + 0.5, seed - 0.5]),
978        };
979        let bits_eq = |a: &OuterEval, b: &OuterEval| -> bool {
980            a.cost.to_bits() == b.cost.to_bits()
981                && a.gradient.len() == b.gradient.len()
982                && a.gradient
983                    .iter()
984                    .zip(b.gradient.iter())
985                    .all(|(x, y)| x.to_bits() == y.to_bits())
986        };
987
988        let cache = EvalCacheManager::new();
989
990        // (1) Round-trip fidelity: store at rho_a, then a forced hit must equal
991        // the stored eval bit-for-bit (the "hit == miss" guarantee).
992        let rho_a = array![0.25, -1.5];
993        let key_a = EvalCacheManager::sanitized_rhokey(&rho_a);
994        let eval_a = make_eval(0.25);
995        cache.store_outer_eval(&key_a, &eval_a);
996        let hit_a = cache
997            .cached_outer_eval(&key_a)
998            .expect("stored rho_a must hit");
999        assert!(
1000            bits_eq(&hit_a, &eval_a),
1001            "cache hit must be bit-identical (cost+gradient) to the stored miss-path eval"
1002        );
1003        assert_eq!(
1004            hit_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
1005            eval_a.inner_beta_hint.as_ref().map(|b| b.to_vec()),
1006            "inner_beta_hint must round-trip unchanged"
1007        );
1008
1009        // (2) No aliasing: a second, distinct rho must return its OWN eval, and
1010        // the first key must still return the first eval untouched.
1011        let rho_b = array![0.25, -1.4999999999999998];
1012        let key_b = EvalCacheManager::sanitized_rhokey(&rho_b);
1013        assert_ne!(key_a, key_b, "the two rho-keys must differ");
1014        let eval_b = make_eval(7.0);
1015        cache.store_outer_eval(&key_b, &eval_b);
1016        assert!(
1017            bits_eq(
1018                &cache.cached_outer_eval(&key_b).expect("rho_b must hit"),
1019                &eval_b
1020            ),
1021            "rho_b must return its own eval, not rho_a's"
1022        );
1023        assert!(
1024            bits_eq(
1025                &cache.cached_outer_eval(&key_a).expect("rho_a must still hit"),
1026                &eval_a
1027            ),
1028            "rho_a must be unaffected by the rho_b insert"
1029        );
1030
1031        // (3) Honest eviction: overflow the LRU with fresh keys. The
1032        // least-recently-used entry must be evicted and then MISS (forcing a
1033        // recompute), while a still-resident key returns its exact stored bits.
1034        let cache = EvalCacheManager::new();
1035        let mut keys = Vec::new();
1036        let mut evals = Vec::new();
1037        for i in 0..OUTER_EVAL_LRU_CAPACITY {
1038            let rho = array![i as f64, -(i as f64)];
1039            let key = EvalCacheManager::sanitized_rhokey(&rho);
1040            let eval = make_eval(i as f64 + 0.123);
1041            cache.store_outer_eval(&key, &eval);
1042            keys.push(key);
1043            evals.push(eval);
1044        }
1045        // Cache is exactly full; key[0] is the least-recently-used.
1046        assert_eq!(
1047            cache.outer_eval_lru.read().unwrap().entries.len(),
1048            OUTER_EVAL_LRU_CAPACITY
1049        );
1050        // One more distinct key evicts the LRU (key[0]).
1051        let rho_overflow = array![999.0, -999.0];
1052        let key_overflow = EvalCacheManager::sanitized_rhokey(&rho_overflow);
1053        let eval_overflow = make_eval(42.0);
1054        cache.store_outer_eval(&key_overflow, &eval_overflow);
1055        assert_eq!(
1056            cache.outer_eval_lru.read().unwrap().entries.len(),
1057            OUTER_EVAL_LRU_CAPACITY,
1058            "capacity must stay bounded"
1059        );
1060        assert!(
1061            cache.cached_outer_eval(&keys[0]).is_none(),
1062            "the least-recently-used key must be evicted and now MISS (recompute), not return stale"
1063        );
1064        assert!(
1065            bits_eq(
1066                &cache
1067                    .cached_outer_eval(&keys[1])
1068                    .expect("a still-resident key must hit"),
1069                &evals[1]
1070            ),
1071            "a still-resident key must return its exact stored bits"
1072        );
1073        assert!(
1074            bits_eq(
1075                &cache
1076                    .cached_outer_eval(&key_overflow)
1077                    .expect("the freshest key must hit"),
1078                &eval_overflow
1079            ),
1080            "the freshest key must hit with its own eval"
1081        );
1082    }
1083
1084    #[test]
1085    pub(crate) fn reset_outer_seed_state_clears_pirls_cache() {
1086        // Build a minimal logit RemlState, populate the cross-call PIRLS LRU
1087        // by evaluating the outer objective at one rho, then verify that
1088        // reset_outer_seed_state wipes that LRU (alongside the eval bundle
1089        // and warm-start signals). This pins down the cross-attempt
1090        // cleanup contract that a budget-bump retry relies on.
1091        let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1092        let w = Array1::<f64>::ones(y.len());
1093        let x = array![
1094            [1.0, -1.0, 0.2],
1095            [1.0, -0.5, -0.4],
1096            [1.0, 0.0, 0.7],
1097            [1.0, 0.4, -0.3],
1098            [1.0, 0.9, 0.1],
1099            [1.0, 1.3, -0.6],
1100        ];
1101        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1102        let rho = array![0.0];
1103        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1104        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1105
1106        // Trigger a full outer eval so execute_pirls_if_needed inserts at
1107        // least one entry into the cross-call PIRLS LRU.
1108        state
1109            .compute_outer_eval_with_order(
1110                &rho,
1111                crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1112            )
1113            .expect("outer eval should succeed");
1114
1115        let populated_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1116        assert!(
1117            populated_len > 0,
1118            "evaluating the outer objective should populate the PIRLS LRU, got {populated_len}"
1119        );
1120
1121        state.reset_outer_seed_state();
1122
1123        let cleared_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
1124        assert_eq!(
1125            cleared_len, 0,
1126            "reset_outer_seed_state must clear the cross-call PIRLS LRU; got {cleared_len} entries"
1127        );
1128    }
1129
1130    #[test]
1131    pub(crate) fn reset_outer_seed_state_preserves_frozen_negbin_theta_1448() {
1132        // #1448 regression: the NB outer θ↔λ alternation loop
1133        // (solver/estimate/optimizer.rs) re-runs the ρ search after each θ
1134        // refresh by (a) re-freezing the λ-search θ at θ_final into
1135        // `frozen_negbin_theta`, then (b) calling `reset_outer_seed_state()` to
1136        // drop the caches keyed to the old θ. Step (b) MUST NOT clear the freeze
1137        // set in step (a): the capture in `solve_for_unified_rho` only writes the
1138        // frozen slot when it is 0, so if the reset zeroed it the next round would
1139        // re-derive θ from the seed η and the loop would never reach the (ρ, θ)
1140        // joint fixed point — silently regressing #1448 back to a single
1141        // freeze→refresh pass.
1142        //
1143        // This pins the load-bearing distinction between `reset_outer_seed_state`
1144        // (alternation-round reset, freeze SURVIVES) and the surface-refresh reset
1145        // (new design, freeze re-zeroed). The end-to-end convergence on a real NB
1146        // fit is exercised by the public-API path; here we lock the invariant the
1147        // loop depends on, next to `reset_outer_seed_state_clears_pirls_cache`.
1148        use std::sync::atomic::Ordering;
1149
1150        let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1151        let w = Array1::<f64>::ones(y.len());
1152        let x = array![
1153            [1.0, -1.0, 0.2],
1154            [1.0, -0.5, -0.4],
1155            [1.0, 0.0, 0.7],
1156            [1.0, 0.4, -0.3],
1157            [1.0, 0.9, 0.1],
1158            [1.0, 1.3, -0.6],
1159        ];
1160        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1161        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1162        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1163
1164        // Simulate the alternation loop's re-freeze step: pin θ_final.
1165        let theta_final_bits = 2.5_f64.to_bits();
1166        state
1167            .frozen_negbin_theta
1168            .store(theta_final_bits, Ordering::Relaxed);
1169        assert_eq!(
1170            state.frozen_negbin_theta.load(Ordering::Relaxed),
1171            theta_final_bits,
1172            "precondition: the re-freeze stores θ_final into the frozen slot"
1173        );
1174
1175        // The alternation loop's per-round reset.
1176        state.reset_outer_seed_state();
1177
1178        assert_eq!(
1179            state.frozen_negbin_theta.load(Ordering::Relaxed),
1180            theta_final_bits,
1181            "reset_outer_seed_state (alternation-round reset) must PRESERVE the \
1182             re-frozen NB θ; clearing it would defeat the #1448 θ↔λ alternation \
1183             (the next ρ search would re-derive θ from the seed and never reach \
1184             the joint fixed point)"
1185        );
1186    }
1187
1188    #[test]
1189    pub(crate) fn implicit_hyper_design_derivative_respects_full_model_embedding() {
1190        let operator = ImplicitDesignPsiDerivative::new(
1191            array![1.0, 2.0, 3.0, 4.0],
1192            array![0.5, -1.0, 1.5, 2.0],
1193            array![0.1, 0.2, 0.3, 0.4],
1194            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
1195            None,
1196            None,
1197            2,
1198            2,
1199            1,
1200            2,
1201        );
1202        let local = operator
1203            .materialize_first(0)
1204            .expect("materialized first derivative");
1205        assert_eq!(
1206            local.ncols(),
1207            3,
1208            "operator-local derivative should stay smooth-local"
1209        );
1210
1211        let implicit = HyperDesignDerivative::from_implicit(
1212            Arc::new(operator),
1213            ImplicitDerivLevel::First(0),
1214            1..4,
1215            5,
1216        );
1217        let embedded = HyperDesignDerivative::from_embedded(local.clone(), 1..4, 5);
1218
1219        assert_eq!(implicit.nrows(), embedded.nrows());
1220        assert_eq!(implicit.ncols(), 5);
1221        assert_eq!(implicit.materialize(), embedded.materialize());
1222
1223        let u = array![7.0, 1.5, -2.0, 0.25, -3.0];
1224        let v = array![0.75, -1.25];
1225        assert_eq!(
1226            implicit.forward_mul_original(&u).expect("implicit forward"),
1227            embedded.forward_mul_original(&u).expect("embedded forward")
1228        );
1229        assert_eq!(
1230            implicit
1231                .transpose_mul_original(&v)
1232                .expect("implicit transpose"),
1233            embedded
1234                .transpose_mul_original(&v)
1235                .expect("embedded transpose")
1236        );
1237
1238        let qs = array![
1239            [1.0, 0.0, 0.0],
1240            [0.0, 1.0, 0.0],
1241            [0.0, 0.5, 0.5],
1242            [0.0, 0.0, 1.0],
1243            [0.0, 0.0, 0.0],
1244        ];
1245        assert_eq!(
1246            implicit
1247                .transformed(&qs, None)
1248                .expect("implicit transformed"),
1249            embedded
1250                .transformed(&qs, None)
1251                .expect("embedded transformed")
1252        );
1253        let u_transformed = array![1.0, -0.5, 2.0];
1254        assert_eq!(
1255            implicit
1256                .transformed_forward_mul(&qs, None, &u_transformed)
1257                .expect("implicit transformed forward"),
1258            embedded
1259                .transformed_forward_mul(&qs, None, &u_transformed)
1260                .expect("embedded transformed forward")
1261        );
1262        assert_eq!(
1263            implicit
1264                .transformed_transpose_mul(&qs, None, &v)
1265                .expect("implicit transformed transpose"),
1266            embedded
1267                .transformed_transpose_mul(&qs, None, &v)
1268                .expect("embedded transformed transpose")
1269        );
1270    }
1271
1272    #[test]
1273    pub(crate) fn directional_hyper_identities_match_finite_differences_logit() {
1274        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1275        let w = Array1::<f64>::ones(y.len());
1276        let x = array![
1277            [1.0, -1.2, 0.3],
1278            [1.0, -0.8, -0.4],
1279            [1.0, -0.3, 0.7],
1280            [1.0, 0.1, -0.9],
1281            [1.0, 0.5, 0.2],
1282            [1.0, 0.9, -0.1],
1283            [1.0, 1.3, 0.8],
1284            [1.0, 1.7, -0.6],
1285        ];
1286        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1287
1288        // Use one directional hyperparameter τ with a penalty perturbation:
1289        // S(τ) = S + τ S_τ.
1290        // Keep X_τ = 0 so this identity test remains valid in both non-Firth
1291        // and Firth-logit modes.
1292        let x_tau = Array2::<f64>::zeros(x.raw_dim());
1293        let s_tau = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15],];
1294        let hyper =
1295            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
1296                .expect("single-penalty hyper direction");
1297        let rho = array![0.0];
1298
1299        // Tight inner tolerance: the envelope theorem requires an exact inner
1300        // P-IRLS optimum; 1e-10 leaves enough residual gradient to cause ~12%
1301        // V_tau mismatch on this small (n=8) logistic problem.
1302        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
1303        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1304        let bundle = state.obtain_eval_bundle(&rho).expect("bundle");
1305        let pr = bundle.pirls_result.as_ref();
1306
1307        let beta = beta_original_from_bundle(&bundle);
1308        let h_orig = h_original_from_bundle(&bundle);
1309        let u = &pr.solveweights * &(&pr.solveworking_response - &pr.final_eta);
1310
1311        // B from implicit solve:
1312        //   H B = X_τ^T g - X^T W(X_τ β̂) - S_τ β̂.
1313        let x_tau_beta = gam_linalg::faer_ndarray::fast_av(&x_tau, &beta);
1314        let weighted_x_tau_beta = &pr.finalweights * &x_tau_beta;
1315        let rhs = gam_linalg::faer_ndarray::fast_atv(&x_tau, &u)
1316            - gam_linalg::faer_ndarray::fast_atv(&x, &weighted_x_tau_beta)
1317            - s_tau.dot(&beta);
1318        let chol = h_orig.cholesky(Side::Lower).expect("chol(H)");
1319        let b_analytic = chol.solvevec(&rhs);
1320
1321        // H_τ from exact total derivative:
1322        //   H_τ = X_τ^T W X + X^T W X_τ + X^T W_τ X + S_τ,
1323        // with W_τ provided by the family directional curvature callback.
1324        let eta_dot = &x_tau_beta + &gam_linalg::faer_ndarray::fast_av(&x, &b_analytic);
1325        let w_direction = crate::pirls::directionalworking_curvature_from_c_array(
1326            &pr.solve_c_array,
1327            &pr.finalweights,
1328            &eta_dot,
1329        );
1330        let wx = RemlState::row_scale(&x, &pr.finalweights);
1331        let wx_tau = RemlState::row_scale(&x_tau, &pr.finalweights);
1332        let mut xwtau_x = x.clone();
1333        match w_direction {
1334            crate::pirls::DirectionalWorkingCurvature::Diagonal(diag) => {
1335                xwtau_x = RemlState::row_scale(&xwtau_x, &diag);
1336            }
1337        }
1338        let mut h_tau_analytic = gam_linalg::faer_ndarray::fast_atb(&x_tau, &wx);
1339        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &wx_tau);
1340        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &xwtau_x);
1341        h_tau_analytic += &s_tau;
1342
1343        // Fit-block stationarity cancellation:
1344        //   -ℓ_β^T B + β̂^T S B = 0.
1345        // Here S is the effective penalty in the inner Hessian surface:
1346        //   S = H - X^T W X.
1347        let ell_beta = gam_linalg::faer_ndarray::fast_atv(&x, &u);
1348        let s_eff = &h_orig - &gam_linalg::faer_ndarray::fast_atb(&x, &wx);
1349        let cancellation = -ell_beta.dot(&b_analytic) + beta.dot(&s_eff.dot(&b_analytic));
1350
1351        // Finite differences in τ against re-fit objective and mode.
1352        let h = 2e-5;
1353        let x_plus = &x + &(x_tau.mapv(|v| h * v));
1354        let x_minus = &x - &(x_tau.mapv(|v| h * v));
1355        let s_plus = &s0 + &(s_tau.mapv(|v| h * v));
1356        let s_minus = &s0 - &(s_tau.mapv(|v| h * v));
1357
1358        let state_plus = build_logit_state(&y, &w, &x_plus, &s_plus, &cfg);
1359        let state_minus = build_logit_state(&y, &w, &x_minus, &s_minus, &cfg);
1360        let bundle_plus = state_plus.obtain_eval_bundle(&rho).expect("bundle+");
1361        let bundle_minus = state_minus.obtain_eval_bundle(&rho).expect("bundle-");
1362        let beta_plus = beta_original_from_bundle(&bundle_plus);
1363        let beta_minus = beta_original_from_bundle(&bundle_minus);
1364        let bfd = (&beta_plus - &beta_minus).mapv(|v| v / (2.0 * h));
1365
1366        let h_plus = h_original_from_bundle(&bundle_plus);
1367        let h_minus = h_original_from_bundle(&bundle_minus);
1368        let h_taufd = (&h_plus - &h_minus).mapv(|v| v / (2.0 * h));
1369
1370        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
1371        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
1372        let v_taufd = (v_plus - v_minus) / (2.0 * h);
1373
1374        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper.clone())
1375            .expect("analytic directional gradient");
1376
1377        let b_num = (&b_analytic - &bfd).mapv(|v| v * v).sum().sqrt();
1378        let b_den = bfd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1379        let b_rel = b_num / b_den;
1380        for i in 0..b_analytic.len() {
1381            assert_eq!(
1382                b_analytic[i].signum(),
1383                bfd[i].signum(),
1384                "B sign mismatch at i={i}: analytic={} fd={}",
1385                b_analytic[i],
1386                bfd[i]
1387            );
1388        }
1389        assert!(
1390            b_rel < 2e-2,
1391            "B implicit solve mismatch vs FD: rel={b_rel:.3e}, num={b_num:.3e}, den={b_den:.3e}"
1392        );
1393
1394        let dh_num = (&h_tau_analytic - &h_taufd).mapv(|v| v * v).sum().sqrt();
1395        let dh_den = h_taufd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1396        let dh_rel = dh_num / dh_den;
1397        for i in 0..h_tau_analytic.nrows() {
1398            for j in 0..h_tau_analytic.ncols() {
1399                assert_eq!(
1400                    h_tau_analytic[[i, j]].signum(),
1401                    h_taufd[[i, j]].signum(),
1402                    "H_tau sign mismatch at ({i},{j}): analytic={} fd={}",
1403                    h_tau_analytic[[i, j]],
1404                    h_taufd[[i, j]]
1405                );
1406            }
1407        }
1408        assert!(
1409            dh_rel < 3e-2,
1410            "H_tau mismatch vs FD: rel={dh_rel:.3e}, num={dh_num:.3e}, den={dh_den:.3e}"
1411        );
1412
1413        let v_abs = (v_tau_analytic - v_taufd).abs();
1414        let v_rel = v_abs / v_taufd.abs().max(1e-10);
1415        assert_eq!(
1416            v_tau_analytic.signum(),
1417            v_taufd.signum(),
1418            "V_tau sign mismatch: analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1419        );
1420        assert!(
1421            v_rel < 2e-2,
1422            "V_tau mismatch vs FD: rel={v_rel:.3e}, abs={v_abs:.3e}, analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1423        );
1424
1425        assert!(
1426            cancellation.abs() < 1e-10,
1427            "stationarity cancellation failed: | -ell_beta^T B + beta^T S B | = {:.3e}",
1428            cancellation.abs()
1429        );
1430    }
1431
1432    #[test]
1433    pub(crate) fn firth_exacthessian_includes_analytic_tk_second_derivatives() {
1434        // Rank-deficient X: the 4th column is 2x the 2nd column.
1435        let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
1436        let w = Array1::<f64>::ones(y.len());
1437        let x = array![
1438            [1.0, -1.2, 0.4, -2.4],
1439            [1.0, -0.9, -0.1, -1.8],
1440            [1.0, -0.6, 0.3, -1.2],
1441            [1.0, -0.2, -0.4, -0.4],
1442            [1.0, 0.1, 0.5, 0.2],
1443            [1.0, 0.4, -0.6, 0.8],
1444            [1.0, 0.8, 0.2, 1.6],
1445            [1.0, 1.1, -0.3, 2.2],
1446            [1.0, 1.4, 0.7, 2.8],
1447            [1.0, 1.7, -0.2, 3.4],
1448        ];
1449        let s0 = array![
1450            [0.0, 0.0, 0.0, 0.0],
1451            [0.0, 1.5, 0.2, 0.0],
1452            [0.0, 0.2, 1.0, 0.0],
1453            [0.0, 0.0, 0.0, 0.5],
1454        ];
1455        let s1 = array![
1456            [0.0, 0.0, 0.0, 0.0],
1457            [0.0, 0.8, -0.1, 0.0],
1458            [0.0, -0.1, 0.6, 0.0],
1459            [0.0, 0.0, 0.0, 0.3],
1460        ];
1461        let offset = Array1::<f64>::zeros(y.len());
1462        // Rank-deficient Firth logit needs more inner iterations to converge
1463        // tightly enough for the envelope-theorem derivative tests.
1464        let cfg =
1465            RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1466        let p = x.ncols();
1467        use crate::estimate::PenaltySpec;
1468        let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1469        let canonical = gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p, "test")
1470            .map(|(canonical, _)| canonical)
1471            .expect("canonicalize");
1472        let state = RemlState::newwith_offset(
1473            y.view(),
1474            x.clone(),
1475            w.view(),
1476            offset.view(),
1477            canonical,
1478            p,
1479            &cfg,
1480            Some(vec![1, 1]),
1481            None,
1482            None,
1483        )
1484        .expect("state");
1485        let rho = array![0.1, -0.2];
1486        assert!(
1487            state.analytic_outer_hessian_enabled(),
1488            "Firth logit should no longer disable analytic outer Hessian planning"
1489        );
1490        let outer = state
1491            .compute_outer_eval_with_order(
1492                &rho,
1493                crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1494            )
1495            .expect("outer Hessian eval should succeed");
1496        assert!(
1497            outer.hessian.is_analytic(),
1498            "outer planner should request and return an analytic Hessian"
1499        );
1500        let bundle = state.obtain_eval_bundle(&rho).expect("exact firth bundle");
1501        let h_dense = state
1502            .compute_lamlhessian_exact_from_bundle(&rho, &bundle)
1503            .expect("Firth exact Hessian should include analytic TK second derivatives");
1504        assert_eq!(h_dense.raw_dim(), ndarray::Ix2(2, 2));
1505        assert!(
1506            h_dense.iter().all(|value| value.is_finite()),
1507            "Hessian should be finite: {h_dense:?}"
1508        );
1509    }
1510
1511    #[test]
1512    pub(crate) fn firth_outer_hessian_matches_gradient_finite_difference_with_tk_terms() {
1513        let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
1514        let w = Array1::<f64>::ones(y.len());
1515        let x = array![
1516            [1.0, -1.0, 0.3],
1517            [1.0, -0.7, -0.2],
1518            [1.0, -0.3, 0.4],
1519            [1.0, 0.0, -0.5],
1520            [1.0, 0.2, 0.6],
1521            [1.0, 0.6, -0.4],
1522            [1.0, 0.9, 0.2],
1523            [1.0, 1.3, -0.1],
1524        ];
1525        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.1], [0.0, 0.1, 0.7],];
1526        let s1 = array![[0.0, 0.0, 0.0], [0.0, 0.4, -0.05], [0.0, -0.05, 0.9],];
1527        let cfg =
1528            RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
1529        let p_dim = x.ncols();
1530        use crate::estimate::PenaltySpec;
1531        let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
1532        let canonical =
1533            gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p_dim, "test")
1534                .map(|(canonical, _)| canonical)
1535                .expect("canonicalize");
1536        let offset = Array1::<f64>::zeros(y.len());
1537        let state = RemlState::newwith_offset(
1538            y.view(),
1539            x.clone(),
1540            w.view(),
1541            offset.view(),
1542            canonical,
1543            p_dim,
1544            &cfg,
1545            Some(vec![1, 1]),
1546            None,
1547            None,
1548        )
1549        .expect("state");
1550        let rho = array![0.15, -0.25];
1551        let eval = state
1552            .compute_outer_eval_with_order(
1553                &rho,
1554                crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
1555            )
1556            .expect("analytic Hessian eval");
1557        let h = match eval.hessian {
1558            HessianResult::Analytic(hessian) => hessian,
1559            HessianResult::Operator(_) | HessianResult::Unavailable => {
1560                panic!("expected dense analytic Hessian")
1561            }
1562        };
1563        let delta = 2.0e-5;
1564        for col in 0..rho.len() {
1565            let mut rp = rho.clone();
1566            let mut rm = rho.clone();
1567            rp[col] += delta;
1568            rm[col] -= delta;
1569            let gp = state
1570                .compute_outer_eval_with_order(
1571                    &rp,
1572                    crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1573                )
1574                .expect("plus grad")
1575                .gradient;
1576            let gm = state
1577                .compute_outer_eval_with_order(
1578                    &rm,
1579                    crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
1580                )
1581                .expect("minus grad")
1582                .gradient;
1583            for row in 0..rho.len() {
1584                let fd = (gp[row] - gm[row]) / (2.0 * delta);
1585                let an = h[[row, col]];
1586                let rel = (fd - an).abs() / fd.abs().max(an.abs()).max(1e-6);
1587                assert!(
1588                    rel < 2.0e-3,
1589                    "Hessian mismatch ({row},{col}): analytic={an:.9e}, fd={fd:.9e}, rel={rel:.3e}"
1590                );
1591            }
1592        }
1593    }
1594
1595    #[test]
1596    pub(crate) fn firthgradient_lives_in_design_column_space_under_rank_deficiency() {
1597        // Rank-deficient design: col4 = 2*col2.
1598        let x = array![
1599            [1.0, -1.2, 0.4, -2.4],
1600            [1.0, -0.9, -0.1, -1.8],
1601            [1.0, -0.6, 0.3, -1.2],
1602            [1.0, -0.2, -0.4, -0.4],
1603            [1.0, 0.1, 0.5, 0.2],
1604            [1.0, 0.4, -0.6, 0.8],
1605            [1.0, 0.8, 0.2, 1.6],
1606            [1.0, 1.1, -0.3, 2.2],
1607        ];
1608        let beta = array![0.1, -0.2, 0.3, 0.05];
1609        let eta = x.dot(&beta);
1610        let op = super::RemlState::build_firth_dense_operator_for_link(
1611            &gam_problem::InverseLink::Standard(gam_problem::StandardLink::Logit),
1612            &x,
1613            &eta,
1614            ndarray::Array1::ones(x.nrows()).view(),
1615        )
1616        .expect("firth operator");
1617
1618        // Exact reduced-space Firth gradient:
1619        //   gradPhi = 0.5 Xᵀ (w' ⊙ h), with h = diag(X_r K_r X_rᵀ).
1620        let gradphi = 0.5 * x.t().dot(&(&op.w1 * &op.h_diag));
1621
1622        // Check (I - QQᵀ) gradPhi ≈ 0.
1623        let q = &op.q_basis;
1624        let proj = q.dot(&q.t().dot(&gradphi));
1625        let resid = &gradphi - &proj;
1626        let rel =
1627            resid.mapv(|v| v * v).sum().sqrt() / gradphi.mapv(|v| v * v).sum().sqrt().max(1e-12);
1628        assert!(
1629            rel < 1e-10,
1630            "Firth gradient should lie in Col(Xᵀ): rel residual={rel:.3e}"
1631        );
1632    }
1633
1634    #[test]
1635    pub(crate) fn firth_logit_directional_hypergradient_accepts_penalty_only_with_full_tk_gradient()
1636    {
1637        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1638        let w = Array1::<f64>::ones(y.len());
1639        let x = array![
1640            [1.0, -1.1, 0.2],
1641            [1.0, -0.6, -0.3],
1642            [1.0, -0.1, 0.5],
1643            [1.0, 0.3, -0.7],
1644            [1.0, 0.8, 0.1],
1645            [1.0, 1.2, -0.4],
1646        ];
1647        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1648        let hyper = DirectionalHyperParam::single_penalty(
1649            0,
1650            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1651            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1652            None,
1653            None,
1654        )
1655        .expect("single-penalty hyper direction");
1656        let rho = array![0.0];
1657        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1658        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1659        let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1660            .expect("Firth penalty-only directional gradient should use analytic TK propagation");
1661        assert!(gradient.is_finite(), "gradient={gradient}");
1662        let fd = fd_directional_tau_cost_gradient(
1663            &y,
1664            &w,
1665            &x,
1666            &s0,
1667            &cfg,
1668            &rho,
1669            &Array2::<f64>::zeros((x.nrows(), x.ncols())),
1670            &array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1671        );
1672        let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1673        assert!(
1674            rel < 1.0e-3,
1675            "Firth penalty-only directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1676        );
1677
1678        let efs_hyper = DirectionalHyperParam::single_penalty(
1679            0,
1680            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1681            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1682            None,
1683            None,
1684        )
1685        .expect("single-penalty EFS hyper direction");
1686        let efs = state
1687            .compute_efs_steps_with_psi_ext(&rho, &[efs_hyper])
1688            .expect("Firth penalty-only EFS should use analytic TK propagation");
1689        assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1690    }
1691
1692    #[test]
1693    pub(crate) fn firth_logit_directional_hypergradient_accepts_design_moving_with_full_tk_gradient()
1694     {
1695        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1696        let w = Array1::<f64>::ones(y.len());
1697        let x = array![
1698            [1.0, -1.1, 0.2],
1699            [1.0, -0.6, -0.3],
1700            [1.0, -0.1, 0.5],
1701            [1.0, 0.3, -0.7],
1702            [1.0, 0.8, 0.1],
1703            [1.0, 1.2, -0.4],
1704        ];
1705        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1706        let hyper = DirectionalHyperParam::single_penalty(
1707            0,
1708            Array2::from_elem((x.nrows(), x.ncols()), 1e-3),
1709            Array2::<f64>::zeros((x.ncols(), x.ncols())),
1710            None,
1711            None,
1712        )
1713        .expect("single-penalty hyper direction");
1714        let rho = array![0.0];
1715        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1716        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1717        let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1718            .expect("Firth design-moving directional gradient should use analytic TK propagation");
1719        assert!(gradient.is_finite(), "gradient={gradient}");
1720        let x_tau = Array2::from_elem((x.nrows(), x.ncols()), 1e-3);
1721        let s_tau = Array2::<f64>::zeros((x.ncols(), x.ncols()));
1722        let fd = fd_directional_tau_cost_gradient(&y, &w, &x, &s0, &cfg, &rho, &x_tau, &s_tau);
1723        let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1724        assert!(
1725            rel < 2.0e-2,
1726            "Firth design-moving directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1727        );
1728    }
1729
1730    #[test]
1731    pub(crate) fn firth_logit_hybrid_efs_accepts_full_tk_psi_gradient() {
1732        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1733        let w = Array1::<f64>::ones(y.len());
1734        let x = array![
1735            [1.0, -1.1, 0.2],
1736            [1.0, -0.6, -0.3],
1737            [1.0, -0.1, 0.5],
1738            [1.0, 0.3, -0.7],
1739            [1.0, 0.8, 0.1],
1740            [1.0, 1.2, -0.4],
1741        ];
1742        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1743        let hyper_dirs = vec![
1744            DirectionalHyperParam::single_penalty(
1745                0,
1746                Array2::from_shape_fn((x.nrows(), x.ncols()), |(i, j)| {
1747                    1e-3 * ((i + 1) as f64) * ((j + 2) as f64)
1748                }),
1749                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1750                None,
1751                None,
1752            )
1753            .expect("design-moving hyper direction"),
1754        ];
1755        let rho = array![0.0];
1756        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1757        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1758
1759        let full = state
1760            .evaluate_unified_with_psi_ext(
1761                &rho,
1762                None,
1763                crate::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
1764                &hyper_dirs,
1765            )
1766            .expect("full Firth psi gradient should use analytic TK propagation");
1767        assert!(full.cost.is_finite(), "full cost={}", full.cost);
1768        let full_grad = full.gradient.expect("gradient should be present");
1769        assert!(
1770            full_grad.iter().all(|value| value.is_finite()),
1771            "full gradient={full_grad:?}"
1772        );
1773
1774        let efs = state
1775            .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
1776            .expect("hybrid EFS should use analytic TK propagation");
1777        assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1778    }
1779
1780    #[test]
1781    pub(crate) fn joint_hyperhessianwires_mixed_blocks() {
1782        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1783        let w = Array1::<f64>::ones(y.len());
1784        let x = array![
1785            [1.0, -1.2, 0.3],
1786            [1.0, -0.8, -0.4],
1787            [1.0, -0.3, 0.7],
1788            [1.0, 0.1, -0.9],
1789            [1.0, 0.5, 0.2],
1790            [1.0, 0.9, -0.1],
1791            [1.0, 1.3, 0.8],
1792            [1.0, 1.7, -0.6],
1793        ];
1794        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1795        let cfg =
1796            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1797        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1798        let rho = array![0.0];
1799        let theta = array![0.0, 0.0, 0.0];
1800        let hyper_dirs = vec![
1801            DirectionalHyperParam::single_penalty(
1802                0,
1803                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1804                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1805                None,
1806                None,
1807            )
1808            .expect("single-penalty hyper direction"),
1809            DirectionalHyperParam::single_penalty(
1810                0,
1811                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1812                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1813                None,
1814                None,
1815            )
1816            .expect("single-penalty hyper direction"),
1817        ];
1818
1819        let (_, _, h) =
1820            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1821                .expect("joint hyper cost+gradient+hessian");
1822        assert_eq!(h.nrows(), theta.len());
1823        assert_eq!(h.ncols(), theta.len());
1824        assert!(h.iter().all(|v| v.is_finite()));
1825        for i in 0..h.nrows() {
1826            for j in 0..i {
1827                let diff = (h[[i, j]] - h[[j, i]]).abs();
1828                assert!(
1829                    diff < 1e-6,
1830                    "joint hessian asymmetry at ({i},{j}): {diff:.3e}"
1831                );
1832            }
1833        }
1834        // Mixed block must be nontrivial for at least one supplied direction.
1835        let mixed_0 = h[[0, 1]];
1836        let mixed_1 = h[[0, 2]];
1837        assert!(
1838            mixed_0.is_finite() && mixed_1.is_finite(),
1839            "mixed blocks must be finite"
1840        );
1841    }
1842
1843    #[test]
1844    pub(crate) fn joint_tau_tau_linear_dirs_matchfd_reference_away_fromzero_psi() {
1845        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1846        let w = Array1::<f64>::ones(y.len());
1847        let x = array![
1848            [1.0, -1.2, 0.3],
1849            [1.0, -0.8, -0.4],
1850            [1.0, -0.3, 0.7],
1851            [1.0, 0.1, -0.9],
1852            [1.0, 0.5, 0.2],
1853            [1.0, 0.9, -0.1],
1854            [1.0, 1.3, 0.8],
1855            [1.0, 1.7, -0.6],
1856        ];
1857        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1858        let cfg =
1859            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1860        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1861        let rho = array![0.0];
1862        let psi = array![0.7, -0.4];
1863        let theta = array![rho[0], psi[0], psi[1]];
1864        let hyper_dirs = vec![
1865            DirectionalHyperParam::single_penalty(
1866                0,
1867                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1868                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1869                None,
1870                None,
1871            )
1872            .expect("linear tau direction"),
1873            DirectionalHyperParam::single_penalty(
1874                0,
1875                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1876                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1877                None,
1878                None,
1879            )
1880            .expect("linear tau direction"),
1881        ];
1882
1883        let (_, _, h_full) =
1884            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1885                .expect("joint hyper cost+gradient+hessian");
1886        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
1887
1888        // FD via physical perturbation of design/penalty matrices (matching
1889        // the V_tau FD pattern).  For column j we perturb X and S₀ along
1890        // direction j, build fresh states, and evaluate the τ-gradient for
1891        // every direction i at those perturbed states.
1892        let x_tau_mats: Vec<Array2<f64>> = vec![
1893            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1894            Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1895        ];
1896        let s_tau_mats: Vec<Array2<f64>> = vec![
1897            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
1898            Array2::<f64>::zeros((x.ncols(), x.ncols())),
1899        ];
1900
1901        let h_ttfd = directional_tau_hessian_fd_reference(
1902            &y,
1903            &w,
1904            &x,
1905            &s0,
1906            &cfg,
1907            &rho,
1908            &hyper_dirs,
1909            &x_tau_mats,
1910            &s_tau_mats,
1911        );
1912
1913        let num = (&h_tt_analytic - &h_ttfd)
1914            .iter()
1915            .map(|v| v * v)
1916            .sum::<f64>()
1917            .sqrt();
1918        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
1919        let rel = num / den;
1920        assert!(
1921            rel < 1e-4,
1922            "linear-dir joint tau-tau block deviates from FD reference away from zero psi: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
1923        );
1924    }
1925
1926    #[test]
1927    pub(crate) fn joint_hypervalidation_rejects_out_of_boundssecond_order_penalty_index() {
1928        // The hyper direction declares a second-order penalty derivative
1929        // against base penalty index 1, but the configured ρ vector has
1930        // dimension 1 (so only index 0 is valid).  The pair-callback
1931        // builder in `build_tau_penalty_derivative_data` is responsible for
1932        // validating both first- and second-order penalty indices against
1933        // `rho.len()`; this test pins that contract.
1934        //
1935        // We deliberately keep `firth_bias_reduction = true` here so the
1936        // call site exercises the full Firth/Tierney–Kadane outer pipeline:
1937        // PIRLS + ext-coord construction + pair-callback assembly.  With
1938        // analytic c/d propagation now wired in
1939        // `tk_direct_gradient_from_cd_and_design`, there is no longer any
1940        // FD-fallback rejection on this path, so the out-of-bounds error
1941        // fired by the pair-callback builder is the first failure the
1942        // joint evaluator surfaces — and that is exactly what we want this
1943        // test to assert.
1944        let y = array![0.0, 1.0, 0.0, 1.0];
1945        let w = Array1::<f64>::ones(y.len());
1946        let x = array![
1947            [1.0, -0.5, 0.2],
1948            [1.0, -0.1, -0.3],
1949            [1.0, 0.4, 0.6],
1950            [1.0, 0.9, -0.2],
1951        ];
1952        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1953        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
1954        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1955        let theta = array![0.0, 0.0];
1956        let hyper_dirs = vec![
1957            DirectionalHyperParam::new(
1958                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1959                vec![(0, Array2::<f64>::zeros((x.ncols(), x.ncols())))],
1960                None,
1961                Some(vec![Some(vec![(1, Array2::<f64>::eye(x.ncols()))])]),
1962            )
1963            .expect("hyper direction with invalid second-order penalty index"),
1964        ];
1965
1966        let msg = match compute_joint_hypercostgradienthessian(&state, &theta, 1, &hyper_dirs) {
1967            Ok(_) => panic!("invalid second-order penalty index should be rejected"),
1968            Err(err) => err.to_string(),
1969        };
1970        assert!(
1971            msg.contains("out of bounds") || msg.contains("penalty_index"),
1972            "unexpected validation error: {msg}"
1973        );
1974    }
1975
1976    #[test]
1977    pub(crate) fn joint_tau_tau_analytic_matchesfd_reference() {
1978        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1979        let w = Array1::<f64>::ones(y.len());
1980        let x = array![
1981            [1.0, -1.2, 0.3],
1982            [1.0, -0.8, -0.4],
1983            [1.0, -0.3, 0.7],
1984            [1.0, 0.1, -0.9],
1985            [1.0, 0.5, 0.2],
1986            [1.0, 0.9, -0.1],
1987            [1.0, 1.3, 0.8],
1988            [1.0, 1.7, -0.6],
1989        ];
1990        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1991        let cfg =
1992            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1993        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1994        let rho = array![0.0];
1995        let psi = array![0.0, 0.0];
1996        let hyper_dirs = vec![
1997            DirectionalHyperParam::single_penalty(
1998                0,
1999                Array2::<f64>::zeros((x.nrows(), x.ncols())),
2000                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
2001                None,
2002                None,
2003            )
2004            .expect("single-penalty hyper direction"),
2005            DirectionalHyperParam::single_penalty(
2006                0,
2007                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2008                Array2::<f64>::zeros((x.ncols(), x.ncols())),
2009                None,
2010                None,
2011            )
2012            .expect("single-penalty hyper direction"),
2013        ];
2014
2015        let theta = {
2016            let mut t = Array1::<f64>::zeros(rho.len() + psi.len());
2017            t.slice_mut(s![..rho.len()]).assign(&rho);
2018            t.slice_mut(s![rho.len()..]).assign(&psi);
2019            t
2020        };
2021        let (_, _, h_full) =
2022            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2023                .expect("joint hyper cost+gradient+hessian");
2024        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2025        assert_eq!(h_tt_analytic.nrows(), hyper_dirs.len());
2026        assert_eq!(h_tt_analytic.ncols(), hyper_dirs.len());
2027
2028        // FD via physical perturbation of design/penalty matrices (matching
2029        // the V_tau FD pattern).  For column j we perturb X and S₀ along
2030        // direction j, build fresh states, and evaluate the τ-gradient for
2031        // every direction i at those perturbed states.
2032        let x_tau_mats: Vec<Array2<f64>> = vec![
2033            Array2::<f64>::zeros((x.nrows(), x.ncols())),
2034            Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
2035        ];
2036        let s_tau_mats: Vec<Array2<f64>> = vec![
2037            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
2038            Array2::<f64>::zeros((x.ncols(), x.ncols())),
2039        ];
2040
2041        let h_ttfd = directional_tau_hessian_fd_reference(
2042            &y,
2043            &w,
2044            &x,
2045            &s0,
2046            &cfg,
2047            &rho,
2048            &hyper_dirs,
2049            &x_tau_mats,
2050            &s_tau_mats,
2051        );
2052
2053        let num = (&h_tt_analytic - &h_ttfd)
2054            .iter()
2055            .map(|v| v * v)
2056            .sum::<f64>()
2057            .sqrt();
2058        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2059        let rel = num / den;
2060        assert!(
2061            rel < 1e-4,
2062            "analytic tau-tau block deviates from FD reference: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2063        );
2064    }
2065
2066    // ── Profiled Gaussian REML coverage for design-moving τ-directions ──
2067    //
2068    // The existing directional-hyper tests all use BinomialLogit, which has
2069    // DispersionHandling::Fixed.  These tests validate the profiled Gaussian
2070    // path (DispersionHandling::ProfiledGaussian) with design-moving
2071    // τ-directions, where the profiled scale φ̂ = D_p/(n−M) depends on ρ
2072    // and the envelope-theorem rescaling by (n−M)/D_p must be correct.
2073
2074    /// Shared test fixture for profiled Gaussian REML tests.
2075    pub(crate) struct GaussianRemlFixture {
2076        pub(crate) y: Array1<f64>,
2077        pub(crate) w: Array1<f64>,
2078        pub(crate) x: Array2<f64>,
2079        pub(crate) s0: Array2<f64>,
2080        pub(crate) cfg: RemlConfig,
2081        pub(crate) rho: Array1<f64>,
2082        /// Design-moving τ-direction (non-zero X_τ, zero S_τ).
2083        pub(crate) x_tau_design: Array2<f64>,
2084        /// Penalty-only τ-direction (zero X_τ, non-zero S_τ).
2085        pub(crate) s_tau_penalty: Array2<f64>,
2086    }
2087
2088    impl GaussianRemlFixture {
2089        pub(crate) fn new() -> Self {
2090            let y = array![0.5, 1.2, -0.3, 0.8, 1.1, -0.6, 0.9, 0.1, -0.2, 0.7];
2091            let x = array![
2092                [1.0, -1.2, 0.3],
2093                [1.0, -0.8, -0.4],
2094                [1.0, -0.3, 0.7],
2095                [1.0, 0.1, -0.9],
2096                [1.0, 0.5, 0.2],
2097                [1.0, 0.9, -0.1],
2098                [1.0, 1.3, 0.8],
2099                [1.0, 1.7, -0.6],
2100                [1.0, -0.5, 0.5],
2101                [1.0, 0.3, -0.3],
2102            ];
2103            Self {
2104                w: Array1::<f64>::ones(y.len()),
2105                y,
2106                x: x.clone(),
2107                s0: array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]],
2108                cfg: RemlConfig::external(gaussian_identity_glm_spec(), 1e-14, false),
2109                rho: array![0.0],
2110                x_tau_design: array![
2111                    [0.0, 1e-3, -2e-3],
2112                    [0.0, -3e-3, 1e-3],
2113                    [0.0, 2e-3, 0.5e-3],
2114                    [0.0, -1e-3, 3e-3],
2115                    [0.0, 0.5e-3, -1e-3],
2116                    [0.0, 1.5e-3, 2e-3],
2117                    [0.0, -2e-3, -0.5e-3],
2118                    [0.0, 3e-3, 1e-3],
2119                    [0.0, -0.5e-3, 2e-3],
2120                    [0.0, 1e-3, -1.5e-3],
2121                ],
2122                s_tau_penalty: array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
2123            }
2124        }
2125    }
2126
2127    impl LogitDesignMotionFixture for GaussianRemlFixture {
2128        fn y(&self) -> &Array1<f64> {
2129            &self.y
2130        }
2131        fn w(&self) -> &Array1<f64> {
2132            &self.w
2133        }
2134        fn x(&self) -> &Array2<f64> {
2135            &self.x
2136        }
2137        fn s0(&self) -> &Array2<f64> {
2138            &self.s0
2139        }
2140        fn cfg(&self) -> &RemlConfig {
2141            &self.cfg
2142        }
2143        fn rho(&self) -> &Array1<f64> {
2144            &self.rho
2145        }
2146    }
2147
2148    #[test]
2149    pub(crate) fn profiled_gaussian_design_moving_gradient_matches_fd() {
2150        let f = GaussianRemlFixture::new();
2151        let state = f.state();
2152        let s_tau = Array2::<f64>::zeros((3, 3));
2153        let hyper = DirectionalHyperParam::single_penalty(
2154            0,
2155            f.x_tau_design.clone(),
2156            s_tau.clone(),
2157            None,
2158            None,
2159        )
2160        .expect("design-moving hyper direction");
2161
2162        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2163            .expect("analytic directional gradient");
2164        let v_taufd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2165
2166        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2167        assert!(
2168            v_rel < 1e-3,
2169            "Gaussian REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2170             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2171        );
2172    }
2173
2174    #[test]
2175    pub(crate) fn profiled_gaussian_penalty_only_gradient_matches_fd() {
2176        let f = GaussianRemlFixture::new();
2177        let state = f.state();
2178        let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2179        let hyper = DirectionalHyperParam::single_penalty(
2180            0,
2181            x_tau.clone(),
2182            f.s_tau_penalty.clone(),
2183            None,
2184            None,
2185        )
2186        .expect("penalty-only hyper direction");
2187
2188        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2189            .expect("analytic directional gradient");
2190        let v_taufd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2191
2192        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2193        assert!(
2194            v_rel < 1e-3,
2195            "Gaussian REML penalty-only V_tau mismatch: rel={v_rel:.3e}, \
2196             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2197        );
2198    }
2199
2200    #[test]
2201    pub(crate) fn profiled_gaussian_joint_hessian_matches_fd() {
2202        // Validate the ττ Hessian block under profiled Gaussian REML with
2203        // both a penalty-only and a design-moving direction.
2204        let f = GaussianRemlFixture::new();
2205        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2206        let s_tau_0 = f.s_tau_penalty.clone();
2207        let x_tau_1 = f.x_tau_design.clone();
2208        let s_tau_1 = Array2::<f64>::zeros((3, 3));
2209
2210        let hyper_dirs = vec![
2211            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2212                .expect("penalty-only direction"),
2213            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2214                .expect("design-moving direction"),
2215        ];
2216
2217        let state = f.state();
2218        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2219        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2220        let (_, _, h_full) =
2221            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2222                .expect("joint cost+gradient+hessian");
2223        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2224
2225        // Finite-difference Hessian: perturb each direction, re-evaluate
2226        // gradient of all directions at perturbed states.
2227        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2228        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2229        let h_ttfd = directional_tau_hessian_fd_reference(
2230            &f.y,
2231            &f.w,
2232            &f.x,
2233            &f.s0,
2234            &f.cfg,
2235            &f.rho,
2236            &hyper_dirs,
2237            &x_tau_mats,
2238            &s_tau_mats,
2239        );
2240
2241        let num = (&h_tt_analytic - &h_ttfd)
2242            .iter()
2243            .map(|v| v * v)
2244            .sum::<f64>()
2245            .sqrt();
2246        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2247        let rel = num / den;
2248        assert!(
2249            rel < 1e-4,
2250            "Gaussian REML tau-tau Hessian mismatch: rel={rel:.3e}, \
2251             analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2252        );
2253    }
2254
2255    // ── Non-Gaussian + design-motion: IFT Hessian-drift coverage ────────
2256    //
2257    // For non-Gaussian links (logit, probit, cloglog, ...), H = X'W(η)X + S
2258    // depends on β̂ through η = Xβ̂.  When ψ moves the design, the total
2259    // Hessian drift dH/dψ includes an IFT contribution from dβ̂/dψ:
2260    //
2261    //   dH/dψ = [explicit at fixed β] + X' diag(c ⊙ X(-v_i)) X
2262    //
2263    // where v_i = H⁻¹ g_i.  The standard GLM path handles this via
2264    // `hessian_derivative_correction(v_i)`.  This test validates that the
2265    // gradient is correct for logit + design-moving ψ, which would fail if
2266    // the IFT correction were missing.
2267
2268    #[test]
2269    pub(crate) fn logit_design_moving_gradient_matches_fd() {
2270        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2271        let w = Array1::<f64>::ones(y.len());
2272        let x = array![
2273            [1.0, -1.2, 0.3],
2274            [1.0, -0.8, -0.4],
2275            [1.0, -0.3, 0.7],
2276            [1.0, 0.1, -0.9],
2277            [1.0, 0.5, 0.2],
2278            [1.0, 0.9, -0.1],
2279            [1.0, 1.3, 0.8],
2280            [1.0, 1.7, -0.6],
2281            [1.0, -0.5, 0.5],
2282            [1.0, 0.3, -0.3],
2283        ];
2284        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2285        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2286        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2287        let rho = array![0.0];
2288
2289        // Design-moving direction with non-zero X_τ.
2290        let x_tau = array![
2291            [0.0, 1e-3, -2e-3],
2292            [0.0, -3e-3, 1e-3],
2293            [0.0, 2e-3, 0.5e-3],
2294            [0.0, -1e-3, 3e-3],
2295            [0.0, 0.5e-3, -1e-3],
2296            [0.0, 1.5e-3, 2e-3],
2297            [0.0, -2e-3, -0.5e-3],
2298            [0.0, 3e-3, 1e-3],
2299            [0.0, -0.5e-3, 2e-3],
2300            [0.0, 1e-3, -1.5e-3],
2301        ];
2302        let s_tau = Array2::<f64>::zeros((3, 3));
2303        let hyper =
2304            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
2305                .expect("design-moving hyper direction");
2306
2307        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2308            .expect("analytic directional gradient");
2309
2310        let h = 2e-5;
2311        let x_plus = &x + &x_tau.mapv(|v| h * v);
2312        let x_minus = &x - &x_tau.mapv(|v| h * v);
2313        let state_plus = build_logit_state(&y, &w, &x_plus, &s0, &cfg);
2314        let state_minus = build_logit_state(&y, &w, &x_minus, &s0, &cfg);
2315        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2316        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2317        let v_taufd = (v_plus - v_minus) / (2.0 * h);
2318
2319        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2320        assert!(
2321            v_rel < 1e-3,
2322            "Logit REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2323             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2324        );
2325    }
2326
2327    #[test]
2328    pub(crate) fn logit_design_moving_hessian_matches_fd() {
2329        // Hessian-level validation for logit + design-motion.
2330        // The IFT correction enters the trace term through
2331        // hessian_derivative_correction(v_i), so the Hessian is the most
2332        // sensitive test of whether the correction is applied correctly.
2333        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2334        let w = Array1::<f64>::ones(y.len());
2335        let x = array![
2336            [1.0, -1.2, 0.3],
2337            [1.0, -0.8, -0.4],
2338            [1.0, -0.3, 0.7],
2339            [1.0, 0.1, -0.9],
2340            [1.0, 0.5, 0.2],
2341            [1.0, 0.9, -0.1],
2342            [1.0, 1.3, 0.8],
2343            [1.0, 1.7, -0.6],
2344            [1.0, -0.5, 0.5],
2345            [1.0, 0.3, -0.3],
2346        ];
2347        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2348        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2349        let rho = array![0.0];
2350
2351        // Two directions: one penalty-only, one design-moving.
2352        let x_tau_0 = Array2::<f64>::zeros(x.raw_dim());
2353        let s_tau_0 = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]];
2354        let x_tau_1 = array![
2355            [0.0, 1e-3, -2e-3],
2356            [0.0, -3e-3, 1e-3],
2357            [0.0, 2e-3, 0.5e-3],
2358            [0.0, -1e-3, 3e-3],
2359            [0.0, 0.5e-3, -1e-3],
2360            [0.0, 1.5e-3, 2e-3],
2361            [0.0, -2e-3, -0.5e-3],
2362            [0.0, 3e-3, 1e-3],
2363            [0.0, -0.5e-3, 2e-3],
2364            [0.0, 1e-3, -1.5e-3],
2365        ];
2366        let s_tau_1 = Array2::<f64>::zeros((3, 3));
2367
2368        let hyper_dirs = vec![
2369            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2370                .expect("penalty-only direction"),
2371            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2372                .expect("design-moving direction"),
2373        ];
2374
2375        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2376        let mut theta = Array1::<f64>::zeros(rho.len() + hyper_dirs.len());
2377        theta.slice_mut(s![..rho.len()]).assign(&rho);
2378        let (_, _, h_full) =
2379            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2380                .expect("joint cost+gradient+hessian");
2381        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2382
2383        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2384        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2385        let h_ttfd = directional_tau_hessian_fd_reference(
2386            &y,
2387            &w,
2388            &x,
2389            &s0,
2390            &cfg,
2391            &rho,
2392            &hyper_dirs,
2393            &x_tau_mats,
2394            &s_tau_mats,
2395        );
2396
2397        let num = (&h_tt_analytic - &h_ttfd)
2398            .iter()
2399            .map(|v| v * v)
2400            .sum::<f64>()
2401            .sqrt();
2402        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2403        let rel = num / den;
2404        assert!(
2405            rel < 1e-4,
2406            "Logit REML design-moving tau-tau Hessian mismatch: rel={rel:.3e}, \
2407             analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2408        );
2409    }
2410
2411    // ── Larger non-Gaussian + design-motion fixture (n=30, p=5) ────────
2412    //
2413    // Validates the IFT correction (hessian_derivative_correction) at a
2414    // scale large enough that the correction is numerically non-trivial:
2415    // with n=30 and p=5, the logistic Hessian W(η) is far from identity
2416    // and the IFT term dβ̂/dψ contributes meaningfully.
2417
    /// Shared test fixture for binomial-logit REML with design-moving
    /// ψ-coordinates, n=30, p=5.
    ///
    /// All fields are populated with hard-coded values by
    /// `BinomialLogitDesignMotionFixture::new`.
    pub(crate) struct BinomialLogitDesignMotionFixture {
        /// Binary response vector (30 entries of 0.0 / 1.0).
        pub(crate) y: Array1<f64>,
        /// Observation weights; `new` sets every entry to one.
        pub(crate) w: Array1<f64>,
        /// Design matrix: an all-ones intercept column plus four covariates.
        pub(crate) x: Array2<f64>,
        /// Penalty matrix; its intercept row and column are zero.
        pub(crate) s0: Array2<f64>,
        /// REML configuration built from `binomial_logit_glm_spec()`.
        pub(crate) cfg: RemlConfig,
        /// Smoothing-parameter vector ρ; `new` sets it to the single value 0.0.
        pub(crate) rho: Array1<f64>,
        /// Design-moving τ-direction: non-zero X_τ, zero S_τ.
        pub(crate) x_tau_design: Array2<f64>,
        /// Penalty-only τ-direction: zero X_τ, non-zero S_τ.
        pub(crate) s_tau_penalty: Array2<f64>,
    }
2432
2433    impl BinomialLogitDesignMotionFixture {
2434        pub(crate) fn new() -> Self {
2435            // Binary response with roughly balanced classes.
2436            let y = array![
2437                1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
2438                1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
2439            ];
2440            // Design matrix: intercept + 4 covariate columns with varied magnitudes.
2441            let x = array![
2442                [1.0, -1.50, 0.42, 0.88, -0.31],
2443                [1.0, -1.12, -0.65, 0.14, 1.23],
2444                [1.0, -0.80, 1.10, -0.53, 0.07],
2445                [1.0, -0.55, -0.22, 1.40, -0.90],
2446                [1.0, -0.30, 0.73, -1.05, 0.44],
2447                [1.0, -0.05, -1.33, 0.60, 0.81],
2448                [1.0, 0.18, 0.55, -0.27, -1.15],
2449                [1.0, 0.42, -0.90, 1.12, 0.33],
2450                [1.0, 0.70, 1.28, -0.78, -0.56],
2451                [1.0, 0.95, -0.18, 0.45, 1.40],
2452                [1.0, 1.20, 0.66, -1.30, -0.02],
2453                [1.0, 1.45, -1.05, 0.22, 0.68],
2454                [1.0, -1.35, 0.90, 0.55, -0.43],
2455                [1.0, -0.98, -0.40, -0.88, 1.05],
2456                [1.0, -0.62, 1.42, 0.30, -0.70],
2457                [1.0, -0.28, -0.77, -1.18, 0.52],
2458                [1.0, 0.05, 0.15, 0.95, -1.35],
2459                [1.0, 0.33, -1.20, -0.40, 0.18],
2460                [1.0, 0.60, 0.82, 1.25, -0.85],
2461                [1.0, 0.88, -0.50, -0.65, 1.10],
2462                [1.0, 1.15, 1.05, 0.10, -0.22],
2463                [1.0, -1.22, -0.95, 0.72, 0.90],
2464                [1.0, -0.75, 0.38, -1.42, 0.15],
2465                [1.0, -0.42, -1.15, 0.50, -1.08],
2466                [1.0, -0.10, 0.60, -0.15, 0.75],
2467                [1.0, 0.25, -0.28, 1.05, -0.48],
2468                [1.0, 0.52, 1.35, -0.92, 0.30],
2469                [1.0, 0.80, -0.70, 0.38, 1.20],
2470                [1.0, 1.08, 0.48, -0.60, -0.95],
2471                [1.0, 1.35, -0.55, 0.85, 0.42]
2472            ];
2473            // Penalty matrix: zero on intercept, SPD on remaining 4 columns.
2474            let s0 = array![
2475                [0.0, 0.0, 0.0, 0.0, 0.0],
2476                [0.0, 1.40, 0.15, 0.05, -0.10],
2477                [0.0, 0.15, 1.10, -0.20, 0.08],
2478                [0.0, 0.05, -0.20, 0.95, 0.12],
2479                [0.0, -0.10, 0.08, 0.12, 1.25]
2480            ];
2481            let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2482            // Design-moving direction: perturb covariate columns, leave
2483            // intercept untouched.
2484            let x_tau_design = array![
2485                [0.0, 1.2e-3, -0.8e-3, 0.5e-3, -1.5e-3],
2486                [0.0, -2.0e-3, 1.4e-3, -0.3e-3, 0.9e-3],
2487                [0.0, 0.6e-3, -1.1e-3, 1.8e-3, -0.4e-3],
2488                [0.0, -1.3e-3, 0.7e-3, -1.0e-3, 2.1e-3],
2489                [0.0, 0.9e-3, -0.5e-3, 0.2e-3, -0.8e-3],
2490                [0.0, -0.4e-3, 1.8e-3, -1.5e-3, 0.3e-3],
2491                [0.0, 1.5e-3, -1.3e-3, 0.8e-3, -1.1e-3],
2492                [0.0, -0.7e-3, 0.4e-3, -2.0e-3, 1.6e-3],
2493                [0.0, 2.2e-3, -0.9e-3, 1.3e-3, -0.6e-3],
2494                [0.0, -1.0e-3, 1.6e-3, -0.7e-3, 0.5e-3],
2495                [0.0, 0.3e-3, -2.1e-3, 1.1e-3, -1.8e-3],
2496                [0.0, -1.8e-3, 0.2e-3, -0.4e-3, 1.3e-3],
2497                [0.0, 1.1e-3, -1.5e-3, 2.0e-3, -0.2e-3],
2498                [0.0, -0.5e-3, 0.9e-3, -1.2e-3, 0.7e-3],
2499                [0.0, 1.7e-3, -0.3e-3, 0.6e-3, -2.0e-3],
2500                [0.0, -1.4e-3, 1.1e-3, -0.9e-3, 0.4e-3],
2501                [0.0, 0.8e-3, -1.7e-3, 1.5e-3, -0.1e-3],
2502                [0.0, -0.2e-3, 0.6e-3, -1.8e-3, 1.0e-3],
2503                [0.0, 1.4e-3, -0.4e-3, 0.3e-3, -1.3e-3],
2504                [0.0, -0.9e-3, 2.0e-3, -0.5e-3, 0.8e-3],
2505                [0.0, 0.5e-3, -1.0e-3, 1.6e-3, -0.7e-3],
2506                [0.0, -2.1e-3, 0.3e-3, -0.8e-3, 1.5e-3],
2507                [0.0, 0.7e-3, -1.8e-3, 0.9e-3, -0.3e-3],
2508                [0.0, -0.6e-3, 1.3e-3, -2.2e-3, 1.1e-3],
2509                [0.0, 1.9e-3, -0.7e-3, 0.4e-3, -0.9e-3],
2510                [0.0, -1.1e-3, 0.5e-3, -1.4e-3, 2.2e-3],
2511                [0.0, 0.4e-3, -1.6e-3, 1.2e-3, -0.5e-3],
2512                [0.0, -1.6e-3, 0.8e-3, -0.1e-3, 0.6e-3],
2513                [0.0, 1.3e-3, -2.2e-3, 0.7e-3, -1.4e-3],
2514                [0.0, -0.3e-3, 1.0e-3, -1.6e-3, 1.8e-3]
2515            ];
2516            // Penalty-only direction: non-zero S_τ, symmetric, zero on intercept.
2517            let s_tau_penalty = array![
2518                [0.0, 0.0, 0.0, 0.0, 0.0],
2519                [0.0, 0.30, 0.05, -0.02, 0.04],
2520                [0.0, 0.05, 0.22, 0.03, -0.01],
2521                [0.0, -0.02, 0.03, 0.18, 0.06],
2522                [0.0, 0.04, -0.01, 0.06, 0.26]
2523            ];
2524            Self {
2525                w: Array1::<f64>::ones(y.len()),
2526                y,
2527                x,
2528                s0,
2529                cfg,
2530                rho: array![0.0],
2531                x_tau_design,
2532                s_tau_penalty,
2533            }
2534        }
2535    }
2536
    // Wires the n=30, p=5 fixture into the `LogitDesignMotionFixture` trait:
    // each accessor simply borrows the field of the same name.
    // NOTE(review): the tests below call `state()`, `fd_directional_gradient()`
    // and `state_perturbed()` on this fixture; presumably those are provided
    // methods of the trait built on these accessors — confirm at the trait
    // definition.
    impl LogitDesignMotionFixture for BinomialLogitDesignMotionFixture {
        // Response vector.
        fn y(&self) -> &Array1<f64> {
            &self.y
        }
        // Observation weights.
        fn w(&self) -> &Array1<f64> {
            &self.w
        }
        // Design matrix.
        fn x(&self) -> &Array2<f64> {
            &self.x
        }
        // Penalty matrix.
        fn s0(&self) -> &Array2<f64> {
            &self.s0
        }
        // REML configuration.
        fn cfg(&self) -> &RemlConfig {
            &self.cfg
        }
        // Smoothing-parameter vector ρ.
        fn rho(&self) -> &Array1<f64> {
            &self.rho
        }
    }
2557
2558    // ── n=30, p=5 binomial-logit design-motion gradient tests ────────
2559
2560    #[test]
2561    pub(crate) fn binomial_logit_n30_design_moving_gradient_matches_fd() {
2562        // Pure design-motion: X_τ ≠ 0, S_τ = 0.
2563        // The IFT correction is essential here: because the family is
2564        // binomial-logit, the working weights W(η) depend on β̂, so
2565        // when X moves with ψ, the implicit derivative dβ̂/dψ enters
2566        // the total Hessian drift.  Without hessian_derivative_correction
2567        // the analytic gradient would disagree with FD.
2568        let f = BinomialLogitDesignMotionFixture::new();
2569        let state = f.state();
2570        let s_tau = Array2::<f64>::zeros((5, 5));
2571        let hyper = DirectionalHyperParam::single_penalty(
2572            0,
2573            f.x_tau_design.clone(),
2574            s_tau.clone(),
2575            None,
2576            None,
2577        )
2578        .expect("design-moving hyper direction");
2579
2580        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2581            .expect("analytic directional gradient");
2582        let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2583
2584        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2585        assert!(
2586            v_rel < 1e-3,
2587            "Binomial-logit n=30 design-moving gradient mismatch: rel={v_rel:.3e}, \
2588             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2589        );
2590    }
2591
2592    #[test]
2593    pub(crate) fn binomial_logit_n30_penalty_only_gradient_matches_fd() {
2594        // Penalty-only direction: X_τ = 0, S_τ ≠ 0.
2595        // Serves as a baseline: the IFT correction should still be
2596        // present (since H depends on β̂ through W(η)), but the
2597        // explicit X_τ contribution is zero.
2598        let f = BinomialLogitDesignMotionFixture::new();
2599        let state = f.state();
2600        let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2601        let hyper = DirectionalHyperParam::single_penalty(
2602            0,
2603            x_tau.clone(),
2604            f.s_tau_penalty.clone(),
2605            None,
2606            None,
2607        )
2608        .expect("penalty-only hyper direction");
2609
2610        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2611            .expect("analytic directional gradient");
2612        let v_tau_fd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2613
2614        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2615        assert!(
2616            v_rel < 1e-3,
2617            "Binomial-logit n=30 penalty-only gradient mismatch: rel={v_rel:.3e}, \
2618             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2619        );
2620    }
2621
2622    #[test]
2623    pub(crate) fn binomial_logit_n30_joint_design_penalty_gradient_matches_fd() {
2624        // Joint direction: both X_τ ≠ 0 and S_τ ≠ 0 simultaneously.
2625        // This is the hardest case: the analytic gradient must correctly
2626        // combine the explicit penalty drift, the explicit design drift,
2627        // and the IFT Hessian-drift correction.
2628        let f = BinomialLogitDesignMotionFixture::new();
2629        let state = f.state();
2630        let hyper = DirectionalHyperParam::single_penalty(
2631            0,
2632            f.x_tau_design.clone(),
2633            f.s_tau_penalty.clone(),
2634            None,
2635            None,
2636        )
2637        .expect("joint design+penalty hyper direction");
2638
2639        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2640            .expect("analytic directional gradient");
2641        let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &f.s_tau_penalty);
2642
2643        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2644        assert!(
2645            v_rel < 1e-3,
2646            "Binomial-logit n=30 joint design+penalty gradient mismatch: rel={v_rel:.3e}, \
2647             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2648        );
2649    }
2650
2651    #[test]
2652    pub(crate) fn binomial_logit_n30_design_moving_hessian_matches_fd() {
2653        // Hessian-level validation with two τ-directions: one
2654        // penalty-only and one design-moving.  The ττ Hessian block is
2655        // the most sensitive test of the IFT correction because errors
2656        // in the correction accumulate quadratically in the trace term.
2657        let f = BinomialLogitDesignMotionFixture::new();
2658        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2659        let s_tau_0 = f.s_tau_penalty.clone();
2660        let x_tau_1 = f.x_tau_design.clone();
2661        let s_tau_1 = Array2::<f64>::zeros((5, 5));
2662
2663        let hyper_dirs = vec![
2664            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2665                .expect("penalty-only direction"),
2666            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2667                .expect("design-moving direction"),
2668        ];
2669
2670        let state = f.state();
2671        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2672        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2673        let (_, _, h_full) =
2674            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2675                .expect("joint cost+gradient+hessian");
2676        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2677
2678        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2679        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2680        let h_tt_fd = directional_tau_hessian_fd_reference(
2681            &f.y,
2682            &f.w,
2683            &f.x,
2684            &f.s0,
2685            &f.cfg,
2686            &f.rho,
2687            &hyper_dirs,
2688            &x_tau_mats,
2689            &s_tau_mats,
2690        );
2691
2692        let num = (&h_tt_analytic - &h_tt_fd)
2693            .iter()
2694            .map(|v| v * v)
2695            .sum::<f64>()
2696            .sqrt();
2697        let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2698        let rel = num / den;
2699        assert!(
2700            rel < 1e-4,
2701            "Binomial-logit n=30 tau-tau Hessian mismatch: rel={rel:.3e}, \
2702             analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2703        );
2704    }
2705
2706    #[test]
2707    pub(crate) fn binomial_logit_n30_nonzero_rho_design_moving_gradient_matches_fd() {
2708        // Validate at a non-trivial smoothing parameter ρ = log(λ) = 1.5,
2709        // so the penalty term λS is scaled up and the balance between
2710        // likelihood and penalty is different from ρ=0.
2711        let f = BinomialLogitDesignMotionFixture::new();
2712        let rho = array![1.5];
2713        let s_tau = Array2::<f64>::zeros((5, 5));
2714
2715        let state = f.state();
2716        let hyper = DirectionalHyperParam::single_penalty(
2717            0,
2718            f.x_tau_design.clone(),
2719            s_tau.clone(),
2720            None,
2721            None,
2722        )
2723        .expect("design-moving hyper direction");
2724
2725        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2726            .expect("analytic directional gradient");
2727
2728        // FD at the shifted ρ: perturb X, re-solve inner, evaluate cost.
2729        let h = 2e-5;
2730        let (state_plus, state_minus) = f.state_perturbed(&f.x_tau_design, &s_tau, h);
2731        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2732        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2733        let v_tau_fd = (v_plus - v_minus) / (2.0 * h);
2734
2735        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2736        assert!(
2737            v_rel < 1e-3,
2738            "Binomial-logit n=30 rho=1.5 design-moving gradient mismatch: rel={v_rel:.3e}, \
2739             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2740        );
2741    }
2742
2743    #[test]
2744    pub(crate) fn binomial_logit_n30_rank_deficient_hessian_matches_cost_fd() {
2745        // Regression lock for the `PenaltySubspaceTrace` pseudo-logdet
2746        // kernel installed by the rank-deficient LAML fix (see
2747        // `PenaltySubspaceTrace` and `intrinsic_hessian_pseudo_logdet_parts`;
2748        // since #901 the cost is the intrinsic `½ log|H_pen|₊` and the kernel
2749        // is the spectral `H_pen⁺`, exact for every drift direction).
2750        //
2751        // The sibling `binomial_logit_n30_design_moving_hessian_matches_fd`
2752        // passes pre- AND post-fix because its FD reference differentiates
2753        // the *analytic gradient* — any self-consistent (if wrong) gradient
2754        // kernel gives a self-consistent Hessian under re-differentiation,
2755        // so that test cannot distinguish full-space from projected traces.
2756        // It passed under the buggy kernel because the same leakage entered
2757        // both sides of the ratio and cancelled.
2758        //
2759        // Here we FD-differentiate `compute_cost` TWICE and compare against
2760        // the analytic Hessian.  Central second differences expose every
2761        // disagreement between `½ log|U_Sᵀ H U_S|_+` (used by the cost) and
2762        // `½ tr(G_ε(H) · Ḣ)` / `−½ tr(G_ε Ḣ_i G_ε Ḣ_j)` (the full-space
2763        // traces that the gradient and Hessian used before the projection
2764        // fix).  Under the buggy kernel the IFT correction
2765        // `D_β H[v] = X' diag(c ⊙ X v) X` leaks onto `null(S)` — X's
2766        // all-ones intercept column sits there — and that leakage enters
2767        // the analytic Hessian but not the cost's projected logdet.
2768        //
2769        // Direction mix chosen to maximise the null-space leakage pathway:
2770        //   τ_0 = penalty-only (X_τ = 0, S_τ ≠ 0)  → v_0 = H⁻¹(−S_τ β̂) is
2771        //         concentrated in range(S_+), but `D_β H[v_0]` has rows and
2772        //         columns on the intercept because `X[:,0] = 1_n`.
2773        //   τ_1 = design-moving (X_τ ≠ 0 on non-intercept columns, S_τ = 0)
2774        //         → `v_1` also picks up the intercept via `X'WX_τβ̂`, and
2775        //         the base drift `X_τᵀWX + XᵀWX_τ` straddles range(S_+) /
2776        //         null(S).
2777        // Both pure directions AND the mixed partial load the Schur correction,
2778        // so any of the three entries can catch a regression.
2779        let f = BinomialLogitDesignMotionFixture::new();
2780        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2781        let s_tau_0 = f.s_tau_penalty.clone();
2782        let x_tau_1 = f.x_tau_design.clone();
2783        let s_tau_1 = Array2::<f64>::zeros((5, 5));
2784
2785        let hyper_dirs = vec![
2786            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2787                .expect("penalty-only direction"),
2788            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2789                .expect("design-moving direction"),
2790        ];
2791
2792        // Analytic Hessian block.
2793        let state = f.state();
2794        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2795        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2796        let (_, _, h_full) =
2797            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2798                .expect("joint cost+gradient+hessian");
2799        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2800
2801        // Cost-level FD reference.  Central second differences give O(h²)
2802        // accuracy; the step is sized so the physical perturbation on X / S
2803        // stays near `1e-5` (same scale as the gradient tests).
2804        const TARGET_PHYSICAL_STEP: f64 = 1e-5;
2805        let x_tau_mats = [&x_tau_0, &x_tau_1];
2806        let s_tau_mats = [&s_tau_0, &s_tau_1];
2807        let steps: [f64; 2] = {
2808            let mut steps = [0.0; 2];
2809            for (j, step) in steps.iter_mut().enumerate() {
2810                let scale = x_tau_mats[j]
2811                    .iter()
2812                    .chain(s_tau_mats[j].iter())
2813                    .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2814                *step = if scale > 0.0 {
2815                    TARGET_PHYSICAL_STEP / scale
2816                } else {
2817                    TARGET_PHYSICAL_STEP
2818                };
2819            }
2820            steps
2821        };
2822
2823        // Evaluate `compute_cost` at `(a · τ_0, b · τ_1)` multipliers.
2824        let eval_cost = |a: f64, b: f64| -> f64 {
2825            let x_eval = &f.x
2826                + &x_tau_mats[0].mapv(|v| a * steps[0] * v)
2827                + &x_tau_mats[1].mapv(|v| b * steps[1] * v);
2828            let s_eval = &f.s0
2829                + &s_tau_mats[0].mapv(|v| a * steps[0] * v)
2830                + &s_tau_mats[1].mapv(|v| b * steps[1] * v);
2831            let st = build_logit_state(&f.y, &f.w, &x_eval, &s_eval, &f.cfg);
2832            st.compute_cost(&f.rho).expect("cost eval")
2833        };
2834
2835        let v_00 = eval_cost(0.0, 0.0);
2836        let v_p0 = eval_cost(1.0, 0.0);
2837        let v_m0 = eval_cost(-1.0, 0.0);
2838        let v_0p = eval_cost(0.0, 1.0);
2839        let v_0m = eval_cost(0.0, -1.0);
2840        let v_pp = eval_cost(1.0, 1.0);
2841        let v_pm = eval_cost(1.0, -1.0);
2842        let v_mp = eval_cost(-1.0, 1.0);
2843        let v_mm = eval_cost(-1.0, -1.0);
2844
2845        let h00_fd = (v_p0 - 2.0 * v_00 + v_m0) / (steps[0] * steps[0]);
2846        let h11_fd = (v_0p - 2.0 * v_00 + v_0m) / (steps[1] * steps[1]);
2847        let h01_fd = (v_pp - v_pm - v_mp + v_mm) / (4.0 * steps[0] * steps[1]);
2848
2849        let h_tt_fd = array![[h00_fd, h01_fd], [h01_fd, h11_fd]];
2850
2851        let num = (&h_tt_analytic - &h_tt_fd)
2852            .iter()
2853            .map(|v| v * v)
2854            .sum::<f64>()
2855            .sqrt();
2856        let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2857        let rel = num / den;
2858
2859        assert!(
2860            rel < 3e-3,
2861            "Binomial-logit n=30 rank-deficient Hessian vs cost-FD mismatch: rel={rel:.3e}, \
2862             analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2863        );
2864    }
2865}
2866
/// Tag selecting one of two REML geometry variants.
///
/// NOTE(review): judging by the variant names, `DenseSpectral` presumably
/// refers to a dense spectral (eigen-decomposition) path and `SparseExactSpd`
/// to a sparse exact factorization of an SPD system — confirm against the
/// code that matches on this enum.
#[derive(Clone, Copy, Debug)]
pub(crate) enum RemlGeometry {
    DenseSpectral,
    SparseExactSpd,
}
2872
/// Module-private trait whose single method reports a
/// `GeometryBackendKind` (imported from `inner_strategy`).
///
/// NOTE(review): implementors are not visible in this part of the file;
/// presumably each penalized-geometry backend implements it — confirm.
trait PenalizedGeometry {
    /// Returns the backend kind of this geometry.
    fn backend_kind(&self) -> GeometryBackendKind;
}
2876
/// Storage representation for a derivative matrix, with one variant per
/// backend payload type.
///
/// Per the `DerivativeStorageBackend` trait docs, every variant exposes the
/// same mechanical surface so that the `HyperDesignDerivative` /
/// `HyperPenaltyDerivative` wrappers can dispatch over it.
#[derive(Clone)]
pub(crate) enum DerivativeMatrixStorage {
    /// Explicitly stored dense matrix.
    Dense(Array2<f64>),
    /// NOTE(review): presumably an all-zero derivative represented without
    /// dense storage — confirm at `ZeroDerivativeMatrix`.
    Zero(ZeroDerivativeMatrix),
    /// Embedded payload; the `DerivativeStorageBackend` docs describe this as
    /// the only variant whose design and penalty views genuinely differ.
    Embedded(EmbeddedDerivativeMatrix),
    /// Implicit-operator payload (see `ImplicitDerivativeOp`).
    Implicit(ImplicitDerivativeOp),
    /// Latent-coordinate operator payload (see `LatentCoordDerivativeOp`).
    LatentCoord(LatentCoordDerivativeOp),
}
2885
2886/// Mechanical surface every `DerivativeMatrixStorage` variant must expose so
2887/// the `HyperDesignDerivative` / `HyperPenaltyDerivative` wrappers can dispatch
2888/// with a single per-call `storage_dispatch!`. Each backend owns its variant's
2889/// substantive math; the wrappers contain only one-line routing.
2890///
2891/// `design_*` variants treat the backend as an X-style operator (rows index
2892/// data, columns index coefficients); `penalty_*` variants treat the backend
2893/// as a square `p×p` penalty in the global coefficient frame. The Embedded
2894/// case is the only variant whose two views genuinely differ (local rows vs
2895/// total_dim square), which is why the two role-specific methods both live in
2896/// one trait rather than two parallel traits.
2897trait DerivativeStorageBackend {
2898    fn resident_byte_count(&self) -> usize;
2899    fn design_nrows(&self) -> usize;
2900    fn design_ncols(&self) -> usize;
2901    fn penalty_dim(&self) -> usize;
2902    fn uses_implicit_storage(&self) -> bool;
2903    fn any_nonzero(&self) -> bool;
2904    fn materialize(&self) -> Array2<f64>;
2905    fn implicit_first_axis_info(
2906        &self,
2907    ) -> Option<(
2908        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
2909        usize,
2910    )>;
2911    fn implicit_axis_count_hint(&self) -> Option<usize>;
2912    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;
2913    fn design_transpose_mul_original(
2914        &self,
2915        v: &Array1<f64>,
2916    ) -> Result<Array1<f64>, EstimationError>;
2917    fn design_transformed(
2918        &self,
2919        qs: &Array2<f64>,
2920        free_basis_opt: Option<&Array2<f64>>,
2921    ) -> Result<Array2<f64>, EstimationError>;
2922    /// Default materialises through `design_transformed` then `.dot(u)`;
2923    /// implicit/latent-coordinate backends override with a direct-operator
2924    /// path that skips the dense materialisation.
2925    fn design_transformed_forward_mul(
2926        &self,
2927        qs: &Array2<f64>,
2928        free_basis_opt: Option<&Array2<f64>>,
2929        u: &Array1<f64>,
2930    ) -> Result<Array1<f64>, EstimationError> {
2931        Ok(self.design_transformed(qs, free_basis_opt)?.dot(u))
2932    }
2933    /// Default materialises through `design_transformed` then `.t().dot(v)`;
2934    /// implicit/latent-coordinate backends override with a direct path.
2935    fn design_transformed_transpose_mul(
2936        &self,
2937        qs: &Array2<f64>,
2938        free_basis_opt: Option<&Array2<f64>>,
2939        v: &Array1<f64>,
2940    ) -> Result<Array1<f64>, EstimationError> {
2941        Ok(self.design_transformed(qs, free_basis_opt)?.t().dot(v))
2942    }
2943    fn penalty_transformed(
2944        &self,
2945        qs: &Array2<f64>,
2946        free_basis_opt: Option<&Array2<f64>>,
2947    ) -> Result<Array2<f64>, EstimationError>;
2948    fn penalty_scaled_add_to(
2949        &self,
2950        target: &mut Array2<f64>,
2951        amp: f64,
2952    ) -> Result<(), EstimationError>;
2953}
2954
/// Fans `expr` over the five `DerivativeMatrixStorage` variants in one place
/// so every wrapper method is a single dispatch line — the compiler enforces
/// exhaustiveness here, so adding a new variant produces one hard error at
/// this site rather than a silent miss in any of the (currently 16) ladders.
///
/// `$scrutinee` is the storage value being matched, `$backend` is the
/// identifier bound to the active variant's payload, and `$body` is expanded
/// once per arm with that binding in scope.
macro_rules! storage_dispatch {
    ($scrutinee:expr, $backend:ident => $body:expr) => {
        match $scrutinee {
            DerivativeMatrixStorage::Dense($backend) => $body,
            DerivativeMatrixStorage::Zero($backend) => $body,
            DerivativeMatrixStorage::Embedded($backend) => $body,
            DerivativeMatrixStorage::Implicit($backend) => $body,
            DerivativeMatrixStorage::LatentCoord($backend) => $body,
        }
    };
}
2970
/// Shape-only stand-in for a derivative matrix whose entries are all zero.
#[derive(Clone)]
pub(crate) struct ZeroDerivativeMatrix {
    /// Row count of the represented matrix.
    rows: usize,
    /// Column count of the represented matrix.
    cols: usize,
}

impl ZeroDerivativeMatrix {
    /// Records the shape of an all-zero matrix; no storage is allocated.
    pub(crate) fn new(rows: usize, cols: usize) -> Self {
        ZeroDerivativeMatrix { rows, cols }
    }
}
2982
/// Which derivative level the implicit operator should compute.
///
/// The `usize` payloads are the axis indices forwarded to the corresponding
/// `ImplicitDesignPsiDerivative` methods.
#[derive(Clone, Copy, Debug)]
pub enum ImplicitDerivLevel {
    /// ∂X/∂ψ_d, carrying the axis index `d`.
    First(usize),
    /// ∂²X/∂ψ_d², carrying the axis index `d`.
    SecondDiag(usize),
    /// ∂²X/∂ψ_d∂ψ_e, carrying the axis indices `(d, e)`.
    SecondCross(usize, usize),
}
2993
/// Lazy implicit operator storage: delegates matvecs to the
/// `ImplicitDesignPsiDerivative` and materializes dense form only on demand.
#[derive(Clone)]
pub(crate) struct ImplicitDerivativeOp {
    /// Shared operator that performs the actual derivative evaluations.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
    /// Which derivative (first, second diagonal, second cross) this op stands for.
    pub(crate) level: ImplicitDerivLevel,
    /// Index range of the global coefficient vector that the operator's local
    /// block reads from and writes into.
    pub(crate) global_range: Range<usize>,
    /// Length of the global coefficient vector; also the column count of the
    /// dense form.
    pub(crate) total_dim: usize,
    /// Cached dense materialization (lazy, populated on first call to ops that need the full matrix).
    ///
    /// Rayon-safe: `materialize_local` calls `materialize_first` / `_second_diag`
    /// / `_second_cross` on the implicit basis-derivative operator, which for
    /// streaming bases dispatches `(0..nc).into_par_iter().for_each(...)`. A plain
    /// `std::sync::OnceLock` here would deadlock if `materialize_dense` were first
    /// called concurrently from inside another rayon par_iter — racing workers
    /// would park on the OnceLock's OS condvar, leaving the leader's nested
    /// par_iter without workers. `RayonSafeOnce` runs init lock-free.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3013
/// Lazy storage for one axis of a `LatentCoordDesignDerivative`: matvecs go
/// to the operator, and a dense form is built and cached only when requested.
#[derive(Clone)]
pub(crate) struct LatentCoordDerivativeOp {
    /// Shared operator that performs the per-axis evaluations.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
    /// Axis index handed to the operator's `*_axis` methods.
    pub(crate) flat_axis: usize,
    /// Index range of the global coefficient vector that the operator's local
    /// block reads from and writes into.
    pub(crate) global_range: Range<usize>,
    /// Length of the global coefficient vector; also the column count of the
    /// dense form.
    pub(crate) total_dim: usize,
    /// Lazily computed dense form, shared between clones through the `Arc`.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
3022
3023impl LatentCoordDerivativeOp {
3024    pub(crate) fn materialize_local(&self) -> Array2<f64> {
3025        self.operator.materialize_axis(self.flat_axis).expect(
3026            "radial scalar evaluation failed during latent-coordinate derivative materialization",
3027        )
3028    }
3029
3030    pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3031        self.cached_dense.get_or_compute(|| {
3032            let local = self.materialize_local();
3033            let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3034            out.slice_mut(s![.., self.global_range.clone()])
3035                .assign(&local);
3036            out
3037        })
3038    }
3039
3040    pub(crate) fn nrows(&self) -> usize {
3041        self.operator.n_data()
3042    }
3043
3044    pub(crate) fn ncols(&self) -> usize {
3045        self.total_dim
3046    }
3047
3048    pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3049        let local = self
3050            .operator
3051            .transpose_mul_axis(self.flat_axis, &v.view())
3052            .expect(
3053                "radial scalar evaluation failed during latent-coordinate derivative transpose_mul",
3054            );
3055        let mut out = Array1::<f64>::zeros(self.total_dim);
3056        out.slice_mut(s![self.global_range.clone()]).assign(&local);
3057        out
3058    }
3059
3060    pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3061        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3062        self.operator
3063            .forward_mul_axis(self.flat_axis, &u_local.view())
3064            .expect(
3065                "radial scalar evaluation failed during latent-coordinate derivative forward_mul",
3066            )
3067    }
3068}
3069
3070impl ImplicitDerivativeOp {
3071    pub(crate) fn materialize_local(&self) -> Array2<f64> {
3072        match self.level {
3073            ImplicitDerivLevel::First(axis) => self.operator.materialize_first(axis).expect(
3074                "radial scalar evaluation failed during implicit derivative materialization",
3075            ),
3076            ImplicitDerivLevel::SecondDiag(axis) => {
3077                self.operator.materialize_second_diag(axis).expect(
3078                    "radial scalar evaluation failed during implicit derivative materialization",
3079                )
3080            }
3081            ImplicitDerivLevel::SecondCross(d, e) => {
3082                self.operator.materialize_second_cross(d, e).expect(
3083                    "radial scalar evaluation failed during implicit derivative materialization",
3084                )
3085            }
3086        }
3087    }
3088
3089    pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
3090        self.cached_dense.get_or_compute(|| {
3091            let local = self.materialize_local();
3092            let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
3093            out.slice_mut(s![.., self.global_range.clone()])
3094                .assign(&local);
3095            out
3096        })
3097    }
3098
3099    pub(crate) fn nrows(&self) -> usize {
3100        self.operator.n_data()
3101    }
3102
3103    pub(crate) fn ncols(&self) -> usize {
3104        self.total_dim
3105    }
3106
3107    pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
3108        let local = match self.level {
3109            ImplicitDerivLevel::First(axis) => self
3110                .operator
3111                .transpose_mul(axis, &v.view())
3112                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3113            ImplicitDerivLevel::SecondDiag(axis) => self
3114                .operator
3115                .transpose_mul_second_diag(axis, &v.view())
3116                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3117            ImplicitDerivLevel::SecondCross(d, e) => self
3118                .operator
3119                .transpose_mul_second_cross(d, e, &v.view())
3120                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
3121        };
3122        let mut out = Array1::<f64>::zeros(self.total_dim);
3123        out.slice_mut(s![self.global_range.clone()]).assign(&local);
3124        out
3125    }
3126
3127    pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3128        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3129        match self.level {
3130            ImplicitDerivLevel::First(axis) => self
3131                .operator
3132                .forward_mul(axis, &u_local.view())
3133                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3134            ImplicitDerivLevel::SecondDiag(axis) => self
3135                .operator
3136                .forward_mul_second_diag(axis, &u_local.view())
3137                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3138            ImplicitDerivLevel::SecondCross(d, e) => self
3139                .operator
3140                .forward_mul_second_cross(d, e, &u_local.view())
3141                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3142        }
3143    }
3144}
3145
3146#[derive(Clone)]
3147pub(crate) struct EmbeddedDerivativeMatrix {
3148    pub(crate) local: Array2<f64>,
3149    pub(crate) global_range: Range<usize>,
3150    pub(crate) total_dim: usize,
3151}
3152
3153impl EmbeddedDerivativeMatrix {
3154    pub(crate) fn new(local: Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
3155        Self {
3156            local,
3157            global_range,
3158            total_dim,
3159        }
3160    }
3161}
3162
3163impl DerivativeStorageBackend for Array2<f64> {
3164    fn resident_byte_count(&self) -> usize {
3165        self.len().saturating_mul(std::mem::size_of::<f64>())
3166    }
3167    fn design_nrows(&self) -> usize {
3168        Array2::nrows(self)
3169    }
3170    fn design_ncols(&self) -> usize {
3171        Array2::ncols(self)
3172    }
3173    fn penalty_dim(&self) -> usize {
3174        Array2::nrows(self)
3175    }
3176    fn uses_implicit_storage(&self) -> bool {
3177        false
3178    }
3179    fn any_nonzero(&self) -> bool {
3180        self.iter().any(|v| *v != 0.0)
3181    }
3182    fn materialize(&self) -> Array2<f64> {
3183        self.clone()
3184    }
3185    fn implicit_first_axis_info(
3186        &self,
3187    ) -> Option<(
3188        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3189        usize,
3190    )> {
3191        None
3192    }
3193    fn implicit_axis_count_hint(&self) -> Option<usize> {
3194        None
3195    }
3196
3197    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3198        if Array2::ncols(self) != u.len() {
3199            crate::bail_invalid_estim!(
3200                "dense hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3201                Array2::nrows(self),
3202                Array2::ncols(self),
3203                u.len()
3204            );
3205        }
3206        Ok(self.dot(u))
3207    }
3208
3209    fn design_transpose_mul_original(
3210        &self,
3211        v: &Array1<f64>,
3212    ) -> Result<Array1<f64>, EstimationError> {
3213        if Array2::nrows(self) != v.len() {
3214            crate::bail_invalid_estim!(
3215                "dense hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3216                Array2::nrows(self),
3217                Array2::ncols(self),
3218                v.len()
3219            );
3220        }
3221        Ok(self.t().dot(v))
3222    }
3223
3224    fn design_transformed(
3225        &self,
3226        qs: &Array2<f64>,
3227        free_basis_opt: Option<&Array2<f64>>,
3228    ) -> Result<Array2<f64>, EstimationError> {
3229        Ok(gam_linalg::matrix::DenseRightProductView::new(self)
3230            .with_factor(qs)
3231            .with_optional_factor(free_basis_opt)
3232            .materialize())
3233    }
3234
3235    fn penalty_transformed(
3236        &self,
3237        qs: &Array2<f64>,
3238        free_basis_opt: Option<&Array2<f64>>,
3239    ) -> Result<Array2<f64>, EstimationError> {
3240        let mut transformed = qs.t().dot(self).dot(qs);
3241        if let Some(z) = free_basis_opt {
3242            transformed = z.t().dot(&transformed).dot(z);
3243        }
3244        Ok(transformed)
3245    }
3246
3247    fn penalty_scaled_add_to(
3248        &self,
3249        target: &mut Array2<f64>,
3250        amp: f64,
3251    ) -> Result<(), EstimationError> {
3252        if target.raw_dim() != self.raw_dim() {
3253            crate::bail_invalid_estim!(
3254                "dense hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3255                target.nrows(),
3256                target.ncols(),
3257                Array2::nrows(self),
3258                Array2::ncols(self)
3259            );
3260        }
3261        target.scaled_add(amp, self);
3262        Ok(())
3263    }
3264}
3265
3266impl DerivativeStorageBackend for ZeroDerivativeMatrix {
3267    fn resident_byte_count(&self) -> usize {
3268        0
3269    }
3270    fn design_nrows(&self) -> usize {
3271        self.rows
3272    }
3273    fn design_ncols(&self) -> usize {
3274        self.cols
3275    }
3276    fn penalty_dim(&self) -> usize {
3277        self.cols
3278    }
3279    fn uses_implicit_storage(&self) -> bool {
3280        false
3281    }
3282    fn any_nonzero(&self) -> bool {
3283        false
3284    }
3285    fn materialize(&self) -> Array2<f64> {
3286        Array2::<f64>::zeros((self.rows, self.cols))
3287    }
3288    fn implicit_first_axis_info(
3289        &self,
3290    ) -> Option<(
3291        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3292        usize,
3293    )> {
3294        None
3295    }
3296    fn implicit_axis_count_hint(&self) -> Option<usize> {
3297        None
3298    }
3299
3300    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3301        if self.cols != u.len() {
3302            crate::bail_invalid_estim!(
3303                "zero hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3304                self.rows,
3305                self.cols,
3306                u.len()
3307            );
3308        }
3309        Ok(Array1::<f64>::zeros(self.rows))
3310    }
3311
3312    fn design_transpose_mul_original(
3313        &self,
3314        v: &Array1<f64>,
3315    ) -> Result<Array1<f64>, EstimationError> {
3316        if self.rows != v.len() {
3317            crate::bail_invalid_estim!(
3318                "zero hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3319                self.rows,
3320                self.cols,
3321                v.len()
3322            );
3323        }
3324        Ok(Array1::<f64>::zeros(self.cols))
3325    }
3326
3327    fn design_transformed(
3328        &self,
3329        qs: &Array2<f64>,
3330        free_basis_opt: Option<&Array2<f64>>,
3331    ) -> Result<Array2<f64>, EstimationError> {
3332        if self.cols != qs.nrows() {
3333            crate::bail_invalid_estim!(
3334                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3335                self.cols,
3336                qs.nrows()
3337            );
3338        }
3339        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3340        Ok(Array2::<f64>::zeros((self.rows, cols)))
3341    }
3342
3343    fn design_transformed_forward_mul(
3344        &self,
3345        qs: &Array2<f64>,
3346        free_basis_opt: Option<&Array2<f64>>,
3347        u: &Array1<f64>,
3348    ) -> Result<Array1<f64>, EstimationError> {
3349        if self.cols != qs.nrows() {
3350            crate::bail_invalid_estim!(
3351                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3352                self.cols,
3353                qs.nrows()
3354            );
3355        }
3356        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3357        if u.len() != cols {
3358            crate::bail_invalid_estim!(
3359                "zero design derivative transformed forward width mismatch: expected {}, vector={}",
3360                cols,
3361                u.len()
3362            );
3363        }
3364        Ok(Array1::<f64>::zeros(self.rows))
3365    }
3366
3367    fn design_transformed_transpose_mul(
3368        &self,
3369        qs: &Array2<f64>,
3370        free_basis_opt: Option<&Array2<f64>>,
3371        v: &Array1<f64>,
3372    ) -> Result<Array1<f64>, EstimationError> {
3373        if self.rows != v.len() {
3374            crate::bail_invalid_estim!(
3375                "zero design derivative transpose height mismatch: matrix rows={}, vector={}",
3376                self.rows,
3377                v.len()
3378            );
3379        }
3380        if self.cols != qs.nrows() {
3381            crate::bail_invalid_estim!(
3382                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3383                self.cols,
3384                qs.nrows()
3385            );
3386        }
3387        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3388        Ok(Array1::<f64>::zeros(cols))
3389    }
3390
3391    fn penalty_transformed(
3392        &self,
3393        qs: &Array2<f64>,
3394        free_basis_opt: Option<&Array2<f64>>,
3395    ) -> Result<Array2<f64>, EstimationError> {
3396        if self.cols != qs.nrows() {
3397            crate::bail_invalid_estim!(
3398                "zero penalty derivative width mismatch: total_dim={}, qs rows={}",
3399                self.cols,
3400                qs.nrows()
3401            );
3402        }
3403        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3404        Ok(Array2::<f64>::zeros((cols, cols)))
3405    }
3406
3407    fn penalty_scaled_add_to(
3408        &self,
3409        target: &mut Array2<f64>,
3410        amp: f64,
3411    ) -> Result<(), EstimationError> {
3412        // Zero penalty derivative: `amp · 0 = 0`, so `amp` scales nothing and
3413        // `target` is left unchanged. Validate it is finite so a bad scale
3414        // surfaces here rather than silently no-op'ing on a NaN/inf amplitude.
3415        if !amp.is_finite() {
3416            crate::bail_invalid_estim!(
3417                "zero hyper penalty derivative received non-finite amp={amp}"
3418            );
3419        }
3420        if target.nrows() != self.cols || target.ncols() != self.cols {
3421            crate::bail_invalid_estim!(
3422                "zero hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3423                target.nrows(),
3424                target.ncols(),
3425                self.cols,
3426                self.cols
3427            );
3428        }
3429        Ok(())
3430    }
3431}
3432
3433impl DerivativeStorageBackend for EmbeddedDerivativeMatrix {
3434    fn resident_byte_count(&self) -> usize {
3435        self.local.len().saturating_mul(std::mem::size_of::<f64>())
3436    }
3437    fn design_nrows(&self) -> usize {
3438        self.local.nrows()
3439    }
3440    fn design_ncols(&self) -> usize {
3441        self.total_dim
3442    }
3443    fn penalty_dim(&self) -> usize {
3444        self.total_dim
3445    }
3446    fn uses_implicit_storage(&self) -> bool {
3447        false
3448    }
3449    fn any_nonzero(&self) -> bool {
3450        self.local.iter().any(|v| *v != 0.0)
3451    }
3452    fn materialize(&self) -> Array2<f64> {
3453        let mut dense = Array2::<f64>::zeros((self.local.nrows(), self.total_dim));
3454        dense
3455            .slice_mut(s![.., self.global_range.clone()])
3456            .assign(&self.local);
3457        dense
3458    }
3459    fn implicit_first_axis_info(
3460        &self,
3461    ) -> Option<(
3462        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3463        usize,
3464    )> {
3465        None
3466    }
3467    fn implicit_axis_count_hint(&self) -> Option<usize> {
3468        None
3469    }
3470
3471    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3472        if self.total_dim != u.len() {
3473            crate::bail_invalid_estim!(
3474                "embedded hyper design derivative forward_mul_original width mismatch: total_dim={}, vector={}",
3475                self.total_dim,
3476                u.len()
3477            );
3478        }
3479        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3480        Ok(self.local.dot(&u_local))
3481    }
3482
3483    fn design_transpose_mul_original(
3484        &self,
3485        v: &Array1<f64>,
3486    ) -> Result<Array1<f64>, EstimationError> {
3487        if self.local.nrows() != v.len() {
3488            crate::bail_invalid_estim!(
3489                "embedded hyper design derivative transpose_mul_original height mismatch: local_rows={}, vector={}",
3490                self.local.nrows(),
3491                v.len()
3492            );
3493        }
3494        let mut out = Array1::<f64>::zeros(self.total_dim);
3495        let pulled = self.local.t().dot(v);
3496        out.slice_mut(s![self.global_range.clone()]).assign(&pulled);
3497        Ok(out)
3498    }
3499
3500    fn design_transformed(
3501        &self,
3502        qs: &Array2<f64>,
3503        free_basis_opt: Option<&Array2<f64>>,
3504    ) -> Result<Array2<f64>, EstimationError> {
3505        if self.total_dim != qs.nrows() {
3506            crate::bail_invalid_estim!(
3507                "embedded design derivative width mismatch: total_cols={}, qs rows={}",
3508                self.total_dim,
3509                qs.nrows()
3510            );
3511        }
3512        let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3513        let mut transformed = self.local.dot(&qs_local);
3514        if let Some(z) = free_basis_opt {
3515            transformed = transformed.dot(z);
3516        }
3517        Ok(transformed)
3518    }
3519
3520    fn penalty_transformed(
3521        &self,
3522        qs: &Array2<f64>,
3523        free_basis_opt: Option<&Array2<f64>>,
3524    ) -> Result<Array2<f64>, EstimationError> {
3525        if self.total_dim != qs.nrows() {
3526            crate::bail_invalid_estim!(
3527                "embedded penalty derivative width mismatch: total_dim={}, qs rows={}",
3528                self.total_dim,
3529                qs.nrows()
3530            );
3531        }
3532        let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3533        let mut transformed = qs_local.t().dot(&self.local).dot(&qs_local);
3534        if let Some(z) = free_basis_opt {
3535            transformed = z.t().dot(&transformed).dot(z);
3536        }
3537        Ok(transformed)
3538    }
3539
3540    fn penalty_scaled_add_to(
3541        &self,
3542        target: &mut Array2<f64>,
3543        amp: f64,
3544    ) -> Result<(), EstimationError> {
3545        if target.nrows() != self.total_dim || target.ncols() != self.total_dim {
3546            crate::bail_invalid_estim!(
3547                "embedded hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3548                target.nrows(),
3549                target.ncols(),
3550                self.total_dim,
3551                self.total_dim
3552            );
3553        }
3554        target
3555            .slice_mut(s![self.global_range.clone(), self.global_range.clone()])
3556            .scaled_add(amp, &self.local);
3557        Ok(())
3558    }
3559}
3560
3561impl DerivativeStorageBackend for ImplicitDerivativeOp {
3562    fn resident_byte_count(&self) -> usize {
3563        0
3564    }
3565    fn design_nrows(&self) -> usize {
3566        self.nrows()
3567    }
3568    fn design_ncols(&self) -> usize {
3569        self.ncols()
3570    }
3571    fn penalty_dim(&self) -> usize {
3572        self.nrows()
3573    }
3574    fn uses_implicit_storage(&self) -> bool {
3575        true
3576    }
3577    fn any_nonzero(&self) -> bool {
3578        true
3579    }
3580    fn materialize(&self) -> Array2<f64> {
3581        self.materialize_dense().clone()
3582    }
3583    fn implicit_first_axis_info(
3584        &self,
3585    ) -> Option<(
3586        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3587        usize,
3588    )> {
3589        match self.level {
3590            ImplicitDerivLevel::First(axis) => Some((self.operator.clone(), axis)),
3591            _ => None,
3592        }
3593    }
3594    fn implicit_axis_count_hint(&self) -> Option<usize> {
3595        Some(self.operator.n_axes())
3596    }
3597
3598    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3599        if self.ncols() != u.len() {
3600            crate::bail_invalid_estim!(
3601                "implicit hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3602                self.ncols(),
3603                u.len()
3604            );
3605        }
3606        Ok(self.forward_mul(u))
3607    }
3608
3609    fn design_transpose_mul_original(
3610        &self,
3611        v: &Array1<f64>,
3612    ) -> Result<Array1<f64>, EstimationError> {
3613        if self.nrows() != v.len() {
3614            crate::bail_invalid_estim!(
3615                "implicit hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3616                self.nrows(),
3617                v.len()
3618            );
3619        }
3620        Ok(self.transpose_mul(v))
3621    }
3622
3623    fn design_transformed(
3624        &self,
3625        qs: &Array2<f64>,
3626        free_basis_opt: Option<&Array2<f64>>,
3627    ) -> Result<Array2<f64>, EstimationError> {
3628        let dense = self.materialize_dense();
3629        Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3630            .with_factor(qs)
3631            .with_optional_factor(free_basis_opt)
3632            .materialize())
3633    }
3634
3635    fn design_transformed_forward_mul(
3636        &self,
3637        qs: &Array2<f64>,
3638        free_basis_opt: Option<&Array2<f64>>,
3639        u: &Array1<f64>,
3640    ) -> Result<Array1<f64>, EstimationError> {
3641        let mut right = if let Some(z) = free_basis_opt {
3642            z.dot(u)
3643        } else {
3644            u.clone()
3645        };
3646        right = qs.dot(&right);
3647        Ok(self.forward_mul(&right))
3648    }
3649
3650    fn design_transformed_transpose_mul(
3651        &self,
3652        qs: &Array2<f64>,
3653        free_basis_opt: Option<&Array2<f64>>,
3654        v: &Array1<f64>,
3655    ) -> Result<Array1<f64>, EstimationError> {
3656        let mut pulled = qs.t().dot(&self.transpose_mul(v));
3657        if let Some(z) = free_basis_opt {
3658            pulled = z.t().dot(&pulled);
3659        }
3660        Ok(pulled)
3661    }
3662
3663    fn penalty_transformed(
3664        &self,
3665        qs: &Array2<f64>,
3666        free_basis_opt: Option<&Array2<f64>>,
3667    ) -> Result<Array2<f64>, EstimationError> {
3668        let dense = self.materialize_dense();
3669        let mut transformed = qs.t().dot(dense).dot(qs);
3670        if let Some(z) = free_basis_opt {
3671            transformed = z.t().dot(&transformed).dot(z);
3672        }
3673        Ok(transformed)
3674    }
3675
3676    fn penalty_scaled_add_to(
3677        &self,
3678        target: &mut Array2<f64>,
3679        amp: f64,
3680    ) -> Result<(), EstimationError> {
3681        let dense = self.materialize_dense();
3682        if target.raw_dim() != dense.raw_dim() {
3683            crate::bail_invalid_estim!(
3684                "implicit hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3685                target.nrows(),
3686                target.ncols(),
3687                dense.nrows(),
3688                dense.ncols()
3689            );
3690        }
3691        target.scaled_add(amp, dense);
3692        Ok(())
3693    }
3694}
3695
3696impl DerivativeStorageBackend for LatentCoordDerivativeOp {
3697    fn resident_byte_count(&self) -> usize {
3698        0
3699    }
3700    fn design_nrows(&self) -> usize {
3701        self.nrows()
3702    }
3703    fn design_ncols(&self) -> usize {
3704        self.ncols()
3705    }
3706    fn penalty_dim(&self) -> usize {
3707        self.nrows()
3708    }
3709    fn uses_implicit_storage(&self) -> bool {
3710        true
3711    }
3712    fn any_nonzero(&self) -> bool {
3713        true
3714    }
3715    fn materialize(&self) -> Array2<f64> {
3716        self.materialize_dense().clone()
3717    }
3718    fn implicit_first_axis_info(
3719        &self,
3720    ) -> Option<(
3721        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3722        usize,
3723    )> {
3724        None
3725    }
3726    fn implicit_axis_count_hint(&self) -> Option<usize> {
3727        Some(self.operator.n_axes())
3728    }
3729
3730    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3731        if self.ncols() != u.len() {
3732            crate::bail_invalid_estim!(
3733                "latent-coordinate hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3734                self.ncols(),
3735                u.len()
3736            );
3737        }
3738        Ok(self.forward_mul(u))
3739    }
3740
3741    fn design_transpose_mul_original(
3742        &self,
3743        v: &Array1<f64>,
3744    ) -> Result<Array1<f64>, EstimationError> {
3745        if self.nrows() != v.len() {
3746            crate::bail_invalid_estim!(
3747                "latent-coordinate hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3748                self.nrows(),
3749                v.len()
3750            );
3751        }
3752        Ok(self.transpose_mul(v))
3753    }
3754
3755    fn design_transformed(
3756        &self,
3757        qs: &Array2<f64>,
3758        free_basis_opt: Option<&Array2<f64>>,
3759    ) -> Result<Array2<f64>, EstimationError> {
3760        let dense = self.materialize_dense();
3761        Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3762            .with_factor(qs)
3763            .with_optional_factor(free_basis_opt)
3764            .materialize())
3765    }
3766
3767    fn design_transformed_forward_mul(
3768        &self,
3769        qs: &Array2<f64>,
3770        free_basis_opt: Option<&Array2<f64>>,
3771        u: &Array1<f64>,
3772    ) -> Result<Array1<f64>, EstimationError> {
3773        let mut right = if let Some(z) = free_basis_opt {
3774            z.dot(u)
3775        } else {
3776            u.clone()
3777        };
3778        right = qs.dot(&right);
3779        Ok(self.forward_mul(&right))
3780    }
3781
3782    fn design_transformed_transpose_mul(
3783        &self,
3784        qs: &Array2<f64>,
3785        free_basis_opt: Option<&Array2<f64>>,
3786        v: &Array1<f64>,
3787    ) -> Result<Array1<f64>, EstimationError> {
3788        let mut pulled = qs.t().dot(&self.transpose_mul(v));
3789        if let Some(z) = free_basis_opt {
3790            pulled = z.t().dot(&pulled);
3791        }
3792        Ok(pulled)
3793    }
3794
3795    fn penalty_transformed(
3796        &self,
3797        qs: &Array2<f64>,
3798        free_basis_opt: Option<&Array2<f64>>,
3799    ) -> Result<Array2<f64>, EstimationError> {
3800        let dense = self.materialize_dense();
3801        let mut transformed = qs.t().dot(dense).dot(qs);
3802        if let Some(z) = free_basis_opt {
3803            transformed = z.t().dot(&transformed).dot(z);
3804        }
3805        Ok(transformed)
3806    }
3807
3808    fn penalty_scaled_add_to(
3809        &self,
3810        target: &mut Array2<f64>,
3811        amp: f64,
3812    ) -> Result<(), EstimationError> {
3813        let dense = self.materialize_dense();
3814        if target.raw_dim() != dense.raw_dim() {
3815            crate::bail_invalid_estim!(
3816                "latent-coordinate hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3817                target.nrows(),
3818                target.ncols(),
3819                dense.nrows(),
3820                dense.ncols()
3821            );
3822        }
3823        target.scaled_add(amp, dense);
3824        Ok(())
3825    }
3826}
3827
/// Design-matrix derivative for one hyper-parameter direction, held behind a
/// [`DerivativeMatrixStorage`] backend.
///
/// The constructors in this file build the `Zero`, `Embedded`, `Implicit`,
/// `LatentCoord` and `Dense` storage variants.
#[derive(Clone)]
pub struct HyperDesignDerivative {
    // Backing representation; every accessor dispatches on this value.
    pub(crate) storage: DerivativeMatrixStorage,
}
3832
impl HyperDesignDerivative {
    /// Builds a derivative backed by a `ZeroDerivativeMatrix` constructed
    /// from `nrows` and `ncols`.
    pub fn zero(nrows: usize, ncols: usize) -> Self {
        Self {
            storage: DerivativeMatrixStorage::Zero(ZeroDerivativeMatrix::new(nrows, ncols)),
        }
    }

    /// Builds a derivative backed by an `EmbeddedDerivativeMatrix` constructed
    /// from `local`, `global_range` and `total_cols`.
    pub fn from_embedded(
        local: Array2<f64>,
        global_range: Range<usize>,
        total_cols: usize,
    ) -> Self {
        Self {
            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
                local,
                global_range,
                total_cols,
            )),
        }
    }

    /// Builds a derivative backed by an `ImplicitDerivativeOp` around the
    /// shared `operator` at the given derivative `level`.
    ///
    /// `total_cols` is stored as the op's `total_dim`, and the dense cache
    /// cell starts out as a fresh `RayonSafeOnce`.
    pub fn from_implicit(
        operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
        level: ImplicitDerivLevel,
        global_range: Range<usize>,
        total_cols: usize,
    ) -> Self {
        Self {
            storage: DerivativeMatrixStorage::Implicit(ImplicitDerivativeOp {
                operator,
                level,
                global_range,
                total_dim: total_cols,
                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
            }),
        }
    }

    /// Builds a derivative backed by a `LatentCoordDerivativeOp` around the
    /// shared latent-coordinate `operator` for `flat_axis`.
    ///
    /// `total_cols` is stored as the op's `total_dim`, and the dense cache
    /// cell starts out as a fresh `RayonSafeOnce`.
    pub fn from_latent_coord(
        operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
        flat_axis: usize,
        global_range: Range<usize>,
        total_cols: usize,
    ) -> Self {
        Self {
            storage: DerivativeMatrixStorage::LatentCoord(LatentCoordDerivativeOp {
                operator,
                flat_axis,
                global_range,
                total_dim: total_cols,
                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
            }),
        }
    }

    /// Resident byte count reported by the active storage backend.
    pub(crate) fn resident_byte_count(&self) -> usize {
        storage_dispatch!(&self.storage, b => b.resident_byte_count())
    }

    /// Row count of the design derivative, as reported by the backend.
    pub(crate) fn nrows(&self) -> usize {
        storage_dispatch!(&self.storage, b => b.design_nrows())
    }

    /// Column count of the design derivative, as reported by the backend.
    pub(crate) fn ncols(&self) -> usize {
        storage_dispatch!(&self.storage, b => b.design_ncols())
    }

    /// Whether the active backend identifies itself as implicit storage.
    pub(crate) fn uses_implicit_storage(&self) -> bool {
        storage_dispatch!(&self.storage, b => b.uses_implicit_storage())
    }

    /// Returns an owned dense matrix produced by the backend's `materialize`.
    pub(crate) fn materialize(&self) -> Array2<f64> {
        storage_dispatch!(&self.storage, b => b.materialize())
    }

    /// Backend's answer to whether the derivative may contain nonzero entries.
    pub(crate) fn any_nonzero(&self) -> bool {
        storage_dispatch!(&self.storage, b => b.any_nonzero())
    }

    /// Applies the derivative to `u` in original coordinates via the backend;
    /// errors are whatever the backend reports.
    pub(crate) fn forward_mul_original(
        &self,
        u: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_forward_mul_original(u))
    }

    /// Applies the transposed derivative to `v` in original coordinates via
    /// the backend; errors are whatever the backend reports.
    pub(crate) fn transpose_mul_original(
        &self,
        v: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transpose_mul_original(v))
    }

    /// Returns the derivative transformed by `qs` and the optional free
    /// basis, as computed by the backend's `design_transformed`.
    pub(crate) fn transformed(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
    ) -> Result<Array2<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed(qs, free_basis_opt))
    }

    /// Matrix-vector product with the transformed derivative, delegated to
    /// the backend's `design_transformed_forward_mul`.
    pub(crate) fn transformed_forward_mul(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
        u: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed_forward_mul(qs, free_basis_opt, u))
    }

    /// Transposed matrix-vector product with the transformed derivative,
    /// delegated to the backend's `design_transformed_transpose_mul`.
    pub(crate) fn transformed_transpose_mul(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
        v: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed_transpose_mul(qs, free_basis_opt, v))
    }

    /// If this derivative uses implicit storage at the first-derivative level,
    /// return the shared implicit operator and the axis index.
    ///
    /// Returns `None` for dense/embedded storage or for second-derivative levels.
    pub(crate) fn implicit_first_axis_info(
        &self,
    ) -> Option<(
        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
        usize,
    )> {
        storage_dispatch!(&self.storage, b => b.implicit_first_axis_info())
    }

    /// Axis-count hint from the backend, if it provides one.
    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
        storage_dispatch!(&self.storage, b => b.implicit_axis_count_hint())
    }
}
3969
3970impl From<Array2<f64>> for HyperDesignDerivative {
3971    fn from(value: Array2<f64>) -> Self {
3972        Self {
3973            storage: DerivativeMatrixStorage::Dense(value),
3974        }
3975    }
3976}
3977
/// Penalty-matrix derivative for one hyper-parameter direction, held behind a
/// [`DerivativeMatrixStorage`] backend.
///
/// Its accessors report a single dimension (`ncols()` returns `nrows()`), so
/// the matrix is treated as square.
#[derive(Clone)]
pub struct HyperPenaltyDerivative {
    // Backing representation; every accessor dispatches on this value.
    pub(crate) storage: DerivativeMatrixStorage,
}
3982
3983impl HyperPenaltyDerivative {
3984    pub fn from_embedded(
3985        local: Array2<f64>,
3986        global_range: Range<usize>,
3987        total_dim: usize,
3988    ) -> Self {
3989        Self {
3990            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3991                local,
3992                global_range,
3993                total_dim,
3994            )),
3995        }
3996    }
3997
3998    pub(crate) fn resident_byte_count(&self) -> usize {
3999        storage_dispatch!(&self.storage, b => b.resident_byte_count())
4000    }
4001
4002    pub(crate) fn nrows(&self) -> usize {
4003        storage_dispatch!(&self.storage, b => b.penalty_dim())
4004    }
4005
4006    pub(crate) fn ncols(&self) -> usize {
4007        self.nrows()
4008    }
4009
4010    pub(crate) fn scaled_materialize(&self, amp: f64) -> Array2<f64> {
4011        let mut out = Array2::<f64>::zeros((self.nrows(), self.ncols()));
4012        self.scaled_add_to(&mut out, amp)
4013            .expect("scaled materialize uses matching target shape");
4014        out
4015    }
4016
4017    pub(crate) fn transformed(
4018        &self,
4019        qs: &Array2<f64>,
4020        free_basis_opt: Option<&Array2<f64>>,
4021    ) -> Result<Array2<f64>, EstimationError> {
4022        storage_dispatch!(&self.storage, b => b.penalty_transformed(qs, free_basis_opt))
4023    }
4024
4025    pub(crate) fn scaled_add_to(
4026        &self,
4027        target: &mut Array2<f64>,
4028        amp: f64,
4029    ) -> Result<(), EstimationError> {
4030        storage_dispatch!(&self.storage, b => b.penalty_scaled_add_to(target, amp))
4031    }
4032}
4033
4034impl From<Array2<f64>> for HyperPenaltyDerivative {
4035    fn from(value: Array2<f64>) -> Self {
4036        Self {
4037            storage: DerivativeMatrixStorage::Dense(value),
4038        }
4039    }
4040}
4041
/// One term of a penalty derivative decomposition: a derivative matrix tagged
/// with the index of the penalty it belongs to.
#[derive(Clone)]
pub struct PenaltyDerivativeComponent {
    // Index of the associated penalty; `DirectionalHyperParam::penalty_total_at`
    // uses it to look up the matching entry of `rho`.
    pub penalty_index: usize,
    // Derivative matrix contributed for that penalty.
    pub matrix: HyperPenaltyDerivative,
}
4047
/// First- and optional second-order derivative data for a single
/// hyper-parameter direction: the design derivative plus penalty derivatives
/// in per-penalty component form.
#[derive(Clone)]
pub struct DirectionalHyperParam {
    // Design derivative for this direction, in original coordinates.
    pub(crate) x_tau_original: HyperDesignDerivative,
    // Canonical penalty representation: every tau direction is decomposed into
    // base-penalty derivatives. There is no separate "assembled total" path.
    pub(crate) penalty_first_components: Vec<PenaltyDerivativeComponent>,
    // Optional pairwise second hyper-derivatives against all tau directions.
    // If provided, each vector must have length psi_dim and hold an optional
    // X_{tau_i,tau_j} entry in original coordinates.
    pub(crate) x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
    // Pairwise second derivatives are stored in the same canonical base-penalty
    // decomposition as the first derivatives.
    pub(crate) penaltysecond_components: Option<Vec<Option<Vec<PenaltyDerivativeComponent>>>>,
    // Fallback callback: `penaltysecond_components_for(j)` calls it with `j`
    // when no stored row is available for that partner.
    pub(crate) penaltysecond_component_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
                + Send
                + Sync
                + 'static,
        >,
    >,
    // Partner indices for which `has_penaltysecond_pair_at` reports a pair even
    // when no stored row exists.
    pub(crate) penaltysecond_partner_indices: Option<std::sync::Arc<[usize]>>,
    /// Whether this coordinate is penalty-like (B_i = ∂H/∂τ_i is PSD).
    /// True for τ (penalty scaling) coordinates; false for ψ (design-moving,
    /// anisotropic length-scale) coordinates. Controls EFS eligibility.
    pub(crate) is_penalty_like: bool,
}
4075
4076impl DirectionalHyperParam {
4077    pub(crate) fn resident_byte_count(&self) -> usize {
4078        let mut bytes = self.x_tau_original.resident_byte_count();
4079        for component in &self.penalty_first_components {
4080            bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4081        }
4082        if let Some(entries) = self.x_tau_tau_original.as_ref() {
4083            for entry in entries.iter().flatten() {
4084                bytes = bytes.saturating_add(entry.resident_byte_count());
4085            }
4086        }
4087        if let Some(rows) = self.penaltysecond_components.as_ref() {
4088            for components in rows.iter().flatten() {
4089                for component in components {
4090                    bytes = bytes.saturating_add(component.matrix.resident_byte_count());
4091                }
4092            }
4093        }
4094        bytes
4095    }
4096
4097    pub(crate) fn canonicalize_penalty_components(
4098        components: Vec<(usize, HyperPenaltyDerivative)>,
4099    ) -> Result<Vec<PenaltyDerivativeComponent>, EstimationError> {
4100        let mut out: Vec<PenaltyDerivativeComponent> = Vec::with_capacity(components.len());
4101        for (penalty_index, matrix) in components {
4102            if out.iter().any(|c| c.penalty_index == penalty_index) {
4103                crate::bail_invalid_estim!(
4104                    "duplicate penalty derivative component for penalty {}",
4105                    penalty_index
4106                );
4107            }
4108            out.push(PenaltyDerivativeComponent {
4109                penalty_index,
4110                matrix,
4111            });
4112        }
4113        Ok(out)
4114    }
4115
4116    pub fn new_compact(
4117        x_tau_original: HyperDesignDerivative,
4118        penalty_first_components: Vec<(usize, HyperPenaltyDerivative)>,
4119        x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
4120        penaltysecond_components: Option<Vec<Option<Vec<(usize, HyperPenaltyDerivative)>>>>,
4121    ) -> Result<Self, EstimationError> {
4122        let is_penalty_like = !x_tau_original.any_nonzero();
4123        let penalty_first_components =
4124            Self::canonicalize_penalty_components(penalty_first_components)?;
4125        let penaltysecond_components = match penaltysecond_components {
4126            Some(rows) => {
4127                let mut out = Vec::with_capacity(rows.len());
4128                for row in rows {
4129                    out.push(match row {
4130                        Some(components) => {
4131                            Some(Self::canonicalize_penalty_components(components)?)
4132                        }
4133                        None => None,
4134                    });
4135                }
4136                Some(out)
4137            }
4138            None => None,
4139        };
4140        Ok(Self {
4141            x_tau_original,
4142            penalty_first_components,
4143            x_tau_tau_original,
4144            penaltysecond_components,
4145            penaltysecond_component_provider: None,
4146            penaltysecond_partner_indices: None,
4147            is_penalty_like,
4148        })
4149    }
4150
4151    /// Mark this coordinate as non-penalty-like (design-moving).
4152    /// EFS will skip it; use Newton/BFGS for these coordinates.
4153    pub fn not_penalty_like(mut self) -> Self {
4154        self.is_penalty_like = false;
4155        self
4156    }
4157
4158    pub fn with_penaltysecond_component_provider(
4159        mut self,
4160        provider: std::sync::Arc<
4161            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
4162                + Send
4163                + Sync
4164                + 'static,
4165        >,
4166    ) -> Self {
4167        self.penaltysecond_component_provider = Some(provider);
4168        self
4169    }
4170
4171    pub fn with_penaltysecond_partner_indices(mut self, partners: Vec<usize>) -> Self {
4172        self.penaltysecond_partner_indices = Some(std::sync::Arc::from(partners));
4173        self
4174    }
4175
4176    pub(crate) fn x_tau_dense(&self) -> Array2<f64> {
4177        self.x_tau_original.materialize()
4178    }
4179
4180    pub(crate) fn transformed_x_tau(
4181        &self,
4182        qs: &Array2<f64>,
4183        free_basis_opt: Option<&Array2<f64>>,
4184    ) -> Result<Array2<f64>, EstimationError> {
4185        self.x_tau_original.transformed(qs, free_basis_opt)
4186    }
4187
4188    pub(crate) fn x_tau_tau_entry_at(&self, j: usize) -> Option<HyperDesignDerivative> {
4189        self.x_tau_tau_original
4190            .as_ref()
4191            .and_then(|rows| rows.get(j))
4192            .and_then(|entry| entry.clone())
4193    }
4194
4195    /// Whether this coordinate's design derivative uses implicit storage at the
4196    /// first-derivative level.
4197    pub(crate) fn has_implicit_operator(&self) -> bool {
4198        self.x_tau_original.uses_implicit_storage()
4199    }
4200
4201    pub(crate) fn has_implicit_multidim_duchon(&self) -> bool {
4202        self.implicit_first_axis_info()
4203            .is_some_and(|(op, _)| op.n_axes() > 1 && op.is_duchon_family())
4204    }
4205
4206    /// Extract the implicit design derivative operator and axis, if available.
4207    pub(crate) fn implicit_first_axis_info(
4208        &self,
4209    ) -> Option<(
4210        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4211        usize,
4212    )> {
4213        self.x_tau_original.implicit_first_axis_info()
4214    }
4215
4216    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4217        self.x_tau_original.implicit_axis_count_hint()
4218    }
4219
4220    pub(crate) fn penalty_first_components(&self) -> &[PenaltyDerivativeComponent] {
4221        &self.penalty_first_components
4222    }
4223
4224    pub(crate) fn penalty_total_at(
4225        &self,
4226        rho: &Array1<f64>,
4227        p: usize,
4228    ) -> Result<Array2<f64>, EstimationError> {
4229        let mut out = Array2::<f64>::zeros((p, p));
4230        for component in &self.penalty_first_components {
4231            if component.matrix.nrows() != p || component.matrix.ncols() != p {
4232                crate::bail_invalid_estim!(
4233                    "S_tau shape mismatch for penalty {}: expected {}x{}, got {}x{}",
4234                    component.penalty_index,
4235                    p,
4236                    p,
4237                    component.matrix.nrows(),
4238                    component.matrix.ncols()
4239                );
4240            }
4241            if component.penalty_index >= rho.len() {
4242                crate::bail_invalid_estim!(
4243                    "penalty_index {} out of bounds for rho dimension {}",
4244                    component.penalty_index,
4245                    rho.len()
4246                );
4247            }
4248            component
4249                .matrix
4250                .scaled_add_to(&mut out, rho[component.penalty_index].exp())?;
4251        }
4252        Ok(out)
4253    }
4254
4255    pub(crate) fn penaltysecond_components_for(
4256        &self,
4257        j: usize,
4258    ) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError> {
4259        if let Some(components) = self
4260            .penaltysecond_components
4261            .as_ref()
4262            .and_then(|rows| rows.get(j))
4263            .and_then(|row| row.clone())
4264        {
4265            return Ok(Some(components));
4266        }
4267        if let Some(provider) = self.penaltysecond_component_provider.as_ref() {
4268            return provider(j);
4269        }
4270        Ok(None)
4271    }
4272
4273    pub(crate) fn penaltysecond_componentrows(
4274        &self,
4275    ) -> Option<&[Option<Vec<PenaltyDerivativeComponent>>]> {
4276        self.penaltysecond_components.as_deref()
4277    }
4278
4279    pub(crate) fn penalty_first_component_count(&self) -> usize {
4280        self.penalty_first_components.len()
4281    }
4282
4283    pub(crate) fn has_penaltysecond_pair_at(&self, j: usize) -> bool {
4284        self.penaltysecond_components
4285            .as_ref()
4286            .and_then(|rows| rows.get(j))
4287            .is_some_and(Option::is_some)
4288            || self
4289                .penaltysecond_partner_indices
4290                .as_ref()
4291                .is_some_and(|partners| partners.contains(&j))
4292    }
4293}
4294
/// Record of which REML geometry was selected, together with the reason and
/// the size/sparsity figures carried alongside that choice.
#[derive(Clone, Debug)]
pub(crate) struct SparseRemlDecision {
    // Geometry that was selected.
    pub(crate) geometry: RemlGeometry,
    // Static description of why `geometry` was selected.
    pub(crate) reason: &'static str,
    // NOTE(review): presumably the coefficient count — confirm at the producer.
    pub(crate) p: usize,
    // NOTE(review): presumably the nonzero count of the design matrix X — confirm.
    pub(crate) nnz_x: usize,
    // NOTE(review): by name, an estimated nonzero count for the upper triangle
    // of H; `None` when no estimate is available — confirm.
    pub(crate) nnz_h_upper_est: Option<usize>,
    // NOTE(review): by name, the corresponding estimated density — confirm.
    pub(crate) density_h_upper_est: Option<f64>,
}
4304
/// Per-evaluation data for the sparse exact path: a shared sparse factor, an
/// optional Takahashi inverse, and scalar/vector quantities stored with them.
#[derive(Clone)]
pub(crate) struct SparseExactEvalData {
    // Shared sparse exact factorization.
    pub(crate) factor: Arc<SparseExactFactor>,
    // Shared Takahashi inverse, when one has been built.
    pub(crate) takahashi: Option<Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    // NOTE(review): by name, log-determinant of the Hessian H — confirm.
    pub(crate) logdet_h: f64,
    // NOTE(review): by name, log pseudo-determinant of the penalty over its
    // positive eigenspace — confirm.
    pub(crate) logdet_s_pos: f64,
    // NOTE(review): by name, rank of the penalty matrix — confirm.
    pub(crate) penalty_rank: usize,
    // NOTE(review): by name, first derivatives of a log-determinant term,
    // shared behind an `Arc` — confirm which one at the producer.
    pub(crate) det1_values: Arc<Array1<f64>>,
}
4314
/// Dense Firth/Jeffreys operator state on the identifiable subspace of the
/// design: the full and reduced designs, reduced Fisher inverse, hat diagonal,
/// logistic weight derivatives, and cached contractions.
#[derive(Clone)]
pub struct FirthDenseOperator {
    // Exact Firth/Jeffreys objects on the identifiable subspace.
    //
    // Let X in R^{n×p} potentially be rank-deficient with rank r.
    // With optional fixed observation weights a_i >= 0 we define A = diag(a),
    // choose an orthonormal coefficient-space basis Q for the identifiable
    // subspace of A^{1/2} X, and set:
    //   X_r := A^{1/2} X Q          (A = I when no fixed observation weights),
    //   W   := diag(w), with w_i = mu_i (1 - mu_i), 0 < w_i <= 1/4 for finite logit eta,
    //   I_r := X_rᵀ W X_r,
    //   S_r := X_rᵀ X_r.
    //
    // Firth term is represented as:
    //   Phi(beta) = 0.5 log |I_r(beta)| - 0.5 log |S_r|,
    // which is exactly
    //   0.5 log |Uᵀ W U|
    // for the canonical orthonormalized identifiable design
    //   U = X_r S_r^{-1/2}.
    // This removes the raw-basis term from explicit reduced designs while
    // keeping the same identifiable-subspace hat matrix and beta derivatives,
    // because S_r is fixed with respect to beta.
    //
    // Mapping back to the full p-space uses:
    //   I_+^dagger = Q I_r^{-1} Qᵀ.
    //
    // We store reduced-space factors so all derivatives can be evaluated exactly
    // without materializing dense n×n matrices M = X K Xᵀ or P = M⊙M.
    //
    // Dense design X and (by name) its transpose, kept side by side.
    pub(crate) x_dense: Array2<f64>,
    pub(crate) x_dense_t: Array2<f64>,
    // Orthonormal coefficient-space basis for the identifiable subspace,
    // built from the retained eigenspace of (A^{1/2} X)ᵀ(A^{1/2} X).
    pub(crate) q_basis: Array2<f64>,
    // Reduced identifiable design. With fixed observation weights a_i this is
    // diag(sqrt(a_i)) X Q; otherwise it is X Q.
    pub(crate) x_reduced: Array2<f64>,
    // Optional fixed case-weight square roots used when the Jeffreys/Firth
    // operator is formed from Xᵀ diag(case_weight ⊙ w(η)) X rather than
    // Xᵀ diag(w(η)) X. The exact directional tau derivatives must project and
    // row-scale with the same weights so the reduced Fisher, hat diagonals,
    // and tau kernels all live on one consistent identifiable subspace.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    // I_r^{-1}
    pub(crate) k_reduced: Array2<f64>,
    // diag(S_r^{-1}) with S_r = X_rᵀ X_r. In the current canonical reduced
    // basis this completely characterizes the metric inverse, because Q
    // diagonalizes the design Gram. It is used to remove the reduced-coordinate
    // basis term from Phi_tau when the design moves.
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    // 0.5 (log|I_r| - log|S_r|) at the current eta.
    pub(crate) half_log_det: f64,
    // h = diag(M), M = X_r K_r X_r'
    pub(crate) h_diag: Array1<f64>,
    // Logistic Fisher weights `w` and their eta-derivatives w', w'', w''',
    // w'''' as n-vectors (presumably `w1`..`w4` in that order — confirm).
    pub(crate) w: Array1<f64>,
    pub(crate) w1: Array1<f64>,
    pub(crate) w2: Array1<f64>,
    pub(crate) w3: Array1<f64>,
    pub(crate) w4: Array1<f64>,
    // B = diag(w') X used in D Hphi and D^2 Hphi contractions.
    pub(crate) b_base: Array2<f64>,
    // Cached invariant contraction P*B where P = (X_r K_r X_r') ⊙ (X_r K_r X_r').
    // This avoids recomputing the same O(n r^2 p) block in every directional call.
    pub(crate) p_b_base: Array2<f64>,
}
4380
/// Quantities prepared for one direction `u` of the Firth operator.
///
/// NOTE(review): field meanings below are read off the names and the one
/// inline comment; confirm against the code that fills this struct.
#[derive(Clone)]
pub(crate) struct FirthDirection {
    // Directional change of the linear predictor (δη_u in the comment below).
    pub(crate) deta: Array1<f64>,
    // Reduced-coordinate matrices for this direction — presumably the `G_u`
    // and `A_u` terms of the derivation; confirm.
    pub(crate) g_u_reduced: Array2<f64>,
    pub(crate) a_u_reduced: Array2<f64>,
    // Presumably the directional change of the hat diagonal h — confirm.
    pub(crate) dh: Array1<f64>,
    // B_u = diag(w'' ⊙ δη_u) X is represented by the row-scaling vector only.
    pub(crate) b_uvec: Array1<f64>,
}
4390
/// Prepared τ-direction partial quantities for the Firth operator at fixed β:
/// reduced design drift, reduced Fisher drift and its inverse drift, plus
/// per-observation vectors.
#[derive(Clone)]
pub(crate) struct FirthTauPartialKernel {
    // Presumably the τ-partial of the linear predictor η at fixed β — confirm.
    pub(super) deta_partial: Array1<f64>,
    // NOTE(review): by name, τ-drifts of the weight derivatives w' and w'' —
    // confirm against the producer.
    pub(crate) dotw1: Array1<f64>,
    pub(crate) dotw2: Array1<f64>,
    // Presumably the τ-partial of the hat diagonal h — confirm.
    pub(crate) dot_h_partial: Array1<f64>,
    // Reduced design drift X_{tau,r} = X_tau Q used in exact design-moving
    // Hadamard-Gram contractions.
    pub(crate) x_tau_reduced: Array2<f64>,
    // dot(I_r) as referenced in the comment below (reduced Fisher drift).
    pub(super) dot_i_partial: Array2<f64>,
    // Reduced Fisher inverse drift:
    //   dot(K_r) = -K_r dot(I_r) K_r
    // where dot(I_r) includes explicit X_tau and weight drift at beta-fixed.
    pub(crate) dot_k_reduced: Array2<f64>,
}
4406
/// Single-direction (τ) exact Firth bundle: a scalar partial, a gradient-like
/// vector, and an optional prepared partial kernel.
#[derive(Clone)]
pub(crate) struct FirthTauExactKernel {
    // NOTE(review): by name, the τ-derivative of the Firth gradient gphi —
    // confirm at the producer.
    pub(crate) gphi_tau: Array1<f64>,
    // NOTE(review): by name, the τ-partial of the Firth term Phi — confirm.
    pub(crate) phi_tau_partial: f64,
    // Prepared τ partial kernel, when one was built.
    pub(crate) tau_kernel: Option<FirthTauPartialKernel>,
}
4413
/// Pair-level (τ_i × τ_j) exact Firth bundle at fixed β.
///
/// Mirrors `FirthTauExactKernel` but for the 2nd-order cross
/// derivatives:
///   Phi_{τ_i τ_j}|β  (scalar, `phi_tau_tau_partial`)
///   (gphi)_{τ_i τ_j}|β (p-vector, `gphi_tau_tau`)
///
/// Carries an optional `tau_tau_kernel` so pair callbacks can chain
/// into Primitive A (`hphi_tau_tau_partial_apply`) for the operator-
/// valued Hessian 2nd drift without recomputing shared reduced Grams.
#[derive(Clone)]
pub(crate) struct FirthTauTauExactKernel {
    // Scalar Phi_{τ_i τ_j}|β.
    pub(super) phi_tau_tau_partial: f64,
    // p-vector (gphi)_{τ_i τ_j}|β.
    pub(super) gphi_tau_tau: Array1<f64>,
    // Optional prepared pair kernel for chaining into Primitive A.
    pub(super) tau_tau_kernel: Option<FirthTauTauPartialKernel>,
}
4431
/// Prepared state for `∂²H_φ/∂τ_i ∂τ_j |_β` (Primitive A).
///
/// Carries both τ-direction reduced designs, their η̇ vectors, and the
/// reduced-coordinate drifts (İ, K̇, ḣ) for i and j so the apply step can
/// form M̈_{ij}, K̈_{ij}, ḧ_{ij}, Γ̈_{ij}, and B̈_{ij} matrix-free.  Fields
/// are filled in by 13b; kept with a neutral internal shape so downstream
/// pair callbacks can hold the kernel across the pair dispatch.
///
/// Wired into the pair-callback's `b_operator` via
/// `FirthAugmentedPairHyperOperator`, and produced by both
/// `hphi_tau_tau_partial_prepare_from_partials` and
/// `exact_tau_tau_kernel` (the scalar/p-vector companion).
#[derive(Clone, Default)]
pub(crate) struct FirthTauTauPartialKernel {
    // Reduced designs for directions i and j.
    pub(super) x_tau_i_reduced: Array2<f64>,
    pub(super) x_tau_j_reduced: Array2<f64>,
    // η̇ vectors for directions i and j.
    pub(super) deta_i_partial: Array1<f64>,
    pub(super) deta_j_partial: Array1<f64>,
    // ḣ drifts for directions i and j.
    pub(super) dot_h_i_partial: Array1<f64>,
    pub(super) dot_h_j_partial: Array1<f64>,
    // K̇ drifts for directions i and j.
    pub(super) dot_k_i_reduced: Array2<f64>,
    pub(super) dot_k_j_reduced: Array2<f64>,
    // İ drifts for directions i and j.
    pub(super) dot_i_i_partial: Array2<f64>,
    pub(super) dot_i_j_partial: Array2<f64>,
    // Optional pair-level pieces — presumably the reduced second design
    // derivative X_{τ_i τ_j} and its η contribution; confirm at the producer.
    pub(super) x_tau_tau_reduced: Option<Array2<f64>>,
    pub(super) deta_ij_partial: Option<Array1<f64>>,
}
4459
/// Prepared state for `D_β((H_φ)_τ|_β)[v]` (Primitive B).
///
/// Carries the τ-kernel pieces (x_tau_reduced, İ, K̇, ḣ), the
/// β-direction quantities (δη_v, A_v, dh_v, b-chain), and the mixed
/// β-τ pieces (D_β(K̇_τ)[v], D_β(ḣ_τ)[v], δη_{τ,v}) so the apply
/// step collapses to the 9-term β-τ expansion without recomputing
/// shared reduced Grams.
#[derive(Clone, Default)]
pub(crate) struct FirthTauBetaPartialKernel {
    // τ-kernel pieces: reduced design drift, η partial, ḣ, İ, K̇.
    pub(super) x_tau_reduced: Array2<f64>,
    pub(super) deta_partial: Array1<f64>,
    pub(super) dot_h_partial: Array1<f64>,
    pub(super) dot_i_partial: Array2<f64>,
    pub(super) dot_k_reduced: Array2<f64>,
    // β-direction quantity δη_v.
    pub(super) deta_v: Array1<f64>,
    // Mixed β-τ quantity δη_{τ,v}.
    pub(super) deta_tau_v: Array1<f64>,
    // β-direction quantities A_v, dh_v and the b-chain row-scaling vector.
    pub(super) a_v_reduced: Array2<f64>,
    pub(super) dh_v: Array1<f64>,
    pub(super) b_vvec: Array1<f64>,
    // Mixed β-τ pieces D_β(K̇_τ)[v] and D_β(ḣ_τ)[v].
    pub(super) d_beta_dot_k: Array2<f64>,
    pub(super) d_beta_dot_h: Array1<f64>,
}
4482
/// Shared bundle of results for one outer evaluation point.
///
/// Holds the inner P-IRLS result, the total Hessian used for the LAML cost,
/// optional sparse-exact and Firth operator data, and `OnceLock` cells that
/// are computed once per inner solution and reused by every assemble call
/// sharing this bundle.
///
/// NOTE(review): the previous header on this struct described a `cache`
/// field with `RefCell` interior mutability and a `cost_andgrad` closure for
/// the BFGS optimizer. No such field exists here, so that text presumably
/// belongs to the outer REML state type — confirm and move it there.
#[derive(Clone)]
pub(crate) struct EvalShared {
    // Optional key compared by `matches` to decide whether this bundle can be
    // reused for a requested evaluation point.
    pub(crate) key: Option<Vec<u64>>,
    // Shared inner P-IRLS result for this evaluation point.
    pub(crate) pirls_result: Arc<PirlsResult>,
    pub(crate) ridge_passport: RidgePassport,
    pub(crate) geometry: RemlGeometry,
    /// The exact H_total matrix used for LAML cost computation.
    /// For Firth: effective Hessian minus hphi (plus any barrier curvature).
    /// For non-Firth: the effective Hessian itself (plus any barrier curvature).
    pub(crate) h_total: Arc<Array2<f64>>,
    // Sparse exact evaluation data, when present.
    pub(crate) sparse_exact: Option<Arc<SparseExactEvalData>>,
    // Dense Firth operator, when present.
    pub(crate) firth_dense_operator: Option<Arc<FirthDenseOperator>>,
    /// Cached FirthDenseOperator built from the original (non-reparameterized)
    /// design matrix, for use by the sparse evaluation path.
    pub(crate) firth_dense_operator_original: Option<Arc<FirthDenseOperator>>,
    /// The ONE original-frame penalty pseudo-logdet factorization for this
    /// evaluation point (#931 atom discipline). `log|Σ λ_k S_k|₊`'s VALUE,
    /// ρ-derivatives, τ/ψ components, and ρ×τ cross blocks are all
    /// contractions of this single eigendecomposition; the ρ-side criterion
    /// assembly (`dense_penalty_logdet_derivs`, the sparse det2 path) and the
    /// original-basis hyper-coordinate builders share it through
    /// [`EvalShared::penalty_pseudologdet_original`]. Building a second
    /// factorization of the same Sλ for the same evaluation point is the
    /// objective↔gradient desync surface (#748/#752/#901) this cell removes:
    /// the ridge and positive-eigenspace threshold are decided exactly once.
    /// (The transformed-frame pair-callback path builds its own object — it
    /// factorizes the canonical-TRANSFORMED, possibly constraint-projected
    /// penalties, a genuinely different matrix, not a duplicate of this one.)
    pub(crate) penalty_pseudologdet: std::sync::OnceLock<Arc<penalty_logdet::PenaltyPseudologdet>>,
    /// Per-evaluation-point cache of the canonical penalty score vectors
    /// `S_k β̂` evaluated at this bundle's inner mode `β̂ =
    /// pirls_result.beta_transformed` (unscaled by λ_k). These depend ONLY
    /// on the inner solution carried by this bundle and the `RemlState`'s
    /// fixed `canonical_penalties` — never on which ρ-coordinate or eval
    /// mode the assembly is running — so they are computed exactly once per
    /// inner solution and shared by every assemble call that reuses the
    /// bundle (cost + gradient evaluations at the same ρ, EFS, synthetic-ext
    /// value probes). Exact hoist, not an approximation: every consumer sees
    /// literally the same vectors it previously recomputed. Initialized via
    /// plain ndarray matvecs (no rayon inside the `OnceLock` closure — the
    /// `get_or_init`+`into_par_iter` deadlock trap does not apply).
    pub(crate) penalty_scores_at_mode: std::sync::OnceLock<Arc<Vec<Array1<f64>>>>,
    /// Per-evaluation-point cache of the #784 block-local Laplace-to-sampling
    /// correction `TkCorrectionTerms { value, gradient }`. The correction is a
    /// deterministic function of ONLY this bundle's converged inner state
    /// (`pirls_result`, `h_total`), the `RemlState`'s fixed
    /// `canonical_penalties`, and the bundle's ρ — never of the eval `mode`:
    /// the diagnostic eigendecomposition, the fixed-seed importance sampler,
    /// and the (b)–(d) gradient channels all read mode-invariant fields, and
    /// the term carries no Hessian, so the value+gradient are identical for the
    /// value-only, value+gradient, and value+gradient+Hessian assemble calls
    /// that share this bundle at a single ρ. The expensive path (eigendecomp +
    /// O(draws·n·m) sampler) previously reran on every one of those 2–3 calls
    /// per outer iteration; hoisting it onto the bundle computes it exactly
    /// once per inner solution (exact hoist, identical values — #784, #1082).
    /// Keyed only on the external-coordinate count `n_ext`: with no ψ
    /// coordinates (`n_ext == 0`) the correction engages; with ψ present the
    /// seam declines (returns the cheap zero), and n_ext is fixed for a fit, so
    /// a single cell suffices.
    pub(crate) block_local_correction:
        std::sync::OnceLock<(usize, Arc<outer_eval::TkCorrectionTerms>)>,
}
4555
4556impl EvalShared {
4557    pub(crate) fn matches(&self, key: &Option<Vec<u64>>) -> bool {
4558        match (&self.key, key) {
4559            (None, None) => true,
4560            (Some(a), Some(b)) => a == b,
4561            _ => false,
4562        }
4563    }
4564
4565    /// Lazily build — once per evaluation point — the original-frame
4566    /// [`PenaltyPseudologdet`](penalty_logdet::PenaltyPseudologdet) of
4567    /// `Σ λ_k S_k` and hand every caller the SAME factorization.
4568    ///
4569    /// This is the #931 port of the penalty-logdet term: value, ρ-first /
4570    /// ρ-second derivatives, τ-gradient components, τ×τ and ρ×τ Hessian
4571    /// blocks are all projections of one eigendecomposition, so no pair of
4572    /// consumers can disagree about the ridge or the positive-eigenspace
4573    /// threshold. The ridge is read from this bundle's `ridge_passport` —
4574    /// the single place that convention is decided.
4575    ///
4576    /// `lambdas` must be the λ = exp(ρ) vector of this bundle's evaluation
4577    /// point and `p` the original-basis coefficient dimension; on a cache
4578    /// hit both are checked against the stored object where representable.
4579    pub(crate) fn penalty_pseudologdet_original(
4580        &self,
4581        canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
4582        lambdas: &[f64],
4583        p: usize,
4584    ) -> Result<Arc<penalty_logdet::PenaltyPseudologdet>, EstimationError> {
4585        if let Some(pld) = self.penalty_pseudologdet.get() {
4586            if pld.dim() != p {
4587                return Err(EstimationError::LayoutError(format!(
4588                    "shared penalty pseudo-logdet frame mismatch: cached p={}, requested p={}",
4589                    pld.dim(),
4590                    p
4591                )));
4592            }
4593            return Ok(Arc::clone(pld));
4594        }
4595        let pld = Arc::new(
4596            penalty_logdet::PenaltyPseudologdet::from_penalties(
4597                canonical_penalties,
4598                lambdas,
4599                self.ridge_passport.penalty_logdet_ridge(),
4600                p,
4601            )
4602            .map_err(EstimationError::InvalidInput)?,
4603        );
4604        match self.penalty_pseudologdet.set(Arc::clone(&pld)) {
4605            Ok(()) => Ok(pld),
4606            // A concurrent caller initialized the cell first; both objects
4607            // were built from identical inputs — return the canonical winner
4608            // so every consumer holds literally the same factorization.
4609            Err(_) => Ok(Arc::clone(
4610                self.penalty_pseudologdet
4611                    .get()
4612                    .expect("OnceLock set raced, so it is initialized"),
4613            )),
4614        }
4615    }
4616}
4617
impl PenalizedGeometry for EvalShared {
    /// Reports which geometry backend this bundle uses by mapping its
    /// `RemlGeometry` onto the `GeometryBackendKind` variant of the same name.
    ///
    /// The match is exhaustive on purpose: adding a `RemlGeometry` variant
    /// becomes a compile error here until a backend kind is chosen for it.
    fn backend_kind(&self) -> GeometryBackendKind {
        match self.geometry {
            RemlGeometry::DenseSpectral => GeometryBackendKind::DenseSpectral,
            RemlGeometry::SparseExactSpd => GeometryBackendKind::SparseExactSpd,
        }
    }
}
4626
4627/// LRU cache keyed by sanitized ρ vectors that holds compacted PIRLS results
4628/// for warm-starting outer line searches and revisited evaluations.
4629///
4630/// Eviction is byte-budgeted rather than entry-count-budgeted: each entry
4631/// records its own estimated footprint (the surviving n-length vectors plus
4632/// the two p×p Hessians plus per-entry overhead) and the cache evicts in
4633/// LRU order until the running total fits under the budget. An entry that
4634/// individually exceeds the budget is rejected silently rather than poisoning
4635/// the cache.
4636pub(crate) struct PirlsLruCache {
4637    // Stored tuple: (compacted result, last-touched clock, estimated bytes).
4638    pub(crate) map: HashMap<Vec<u64>, (Arc<PirlsResult>, u64, usize)>,
4639    pub(crate) byte_budget: usize,
4640    pub(crate) current_bytes: usize,
4641    pub(crate) clock: u64,
4642}
4643
4644impl PirlsLruCache {
4645    pub(crate) fn new(byte_budget: usize) -> Self {
4646        Self {
4647            map: HashMap::new(),
4648            byte_budget: byte_budget.max(1),
4649            current_bytes: 0,
4650            clock: 0,
4651        }
4652    }
4653
4654    pub(crate) fn get(&mut self, key: &Vec<u64>) -> Option<Arc<PirlsResult>> {
4655        if let Some(entry) = self.map.get_mut(key) {
4656            self.clock += 1;
4657            entry.1 = self.clock;
4658            Some(entry.0.clone())
4659        } else {
4660            None
4661        }
4662    }
4663
4664    pub(crate) fn insert(&mut self, key: Vec<u64>, value: Arc<PirlsResult>) {
4665        self.clock += 1;
4666        let bytes = pirls_result_cache_bytes(&value);
4667        // Refuse entries that on their own already exceed the entire budget;
4668        // caching one would force eviction of every other entry without
4669        // leaving room for the new one anyway.
4670        if bytes > self.byte_budget {
4671            if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4672                self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4673            }
4674            return;
4675        }
4676        if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4677            self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4678        }
4679        while self.current_bytes + bytes > self.byte_budget {
4680            let evict_key = self
4681                .map
4682                .iter()
4683                .min_by_key(|(_, (_, ts, _))| *ts)
4684                .map(|(k, _)| k.clone());
4685            match evict_key {
4686                Some(k) => {
4687                    if let Some((_, _, evict_bytes)) = self.map.remove(&k) {
4688                        self.current_bytes = self.current_bytes.saturating_sub(evict_bytes);
4689                    }
4690                }
4691                None => break,
4692            }
4693        }
4694        self.current_bytes += bytes;
4695        self.map.insert(key, (value, self.clock, bytes));
4696    }
4697
4698    pub(crate) fn clear(&mut self) {
4699        self.map.clear();
4700        self.current_bytes = 0;
4701    }
4702}
4703
/// Identity of a cached penalty subspace: one 64-bit hash per input.
///
/// Equality compares the two hashes only, so distinct inputs whose hashes
/// happened to collide would be treated as the same key.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct PenaltySubspaceCacheKey {
    /// Hash of the penalty matrix (see `from_inputs` for what is hashed).
    pub(crate) penalty_matrix_fingerprint: u64,
    /// Hash of the ridge-passport fields selected by `from_inputs`.
    pub(crate) ridge_passport_signature: u64,
}
4709
4710pub(crate) struct PenaltySubspaceCache {
4711    pub(crate) entry: Option<(PenaltySubspaceCacheKey, Arc<outer_eval::PenaltySubspace>)>,
4712}
4713
4714impl PenaltySubspaceCache {
4715    pub(crate) fn new() -> Self {
4716        Self { entry: None }
4717    }
4718
4719    pub(crate) fn get(
4720        &self,
4721        key: &PenaltySubspaceCacheKey,
4722    ) -> Option<Arc<outer_eval::PenaltySubspace>> {
4723        self.entry
4724            .as_ref()
4725            .filter(|(cached_key, _)| cached_key == key)
4726            .map(|(_, value)| value.clone())
4727    }
4728
4729    pub(crate) fn insert(
4730        &mut self,
4731        key: PenaltySubspaceCacheKey,
4732        value: Arc<outer_eval::PenaltySubspace>,
4733    ) {
4734        self.entry = Some((key, value));
4735    }
4736
4737    pub(crate) fn clear(&mut self) {
4738        self.entry = None;
4739    }
4740}
4741
4742impl PenaltySubspaceCacheKey {
4743    /// Build a cache key from the transformed-E matrix and ridge passport.
4744    /// `E` is hashed by exact f64 bits (column-major), so the key is bit-exact
4745    /// and avoids float-Hash issues; the ridge passport is hashed via its
4746    /// `Hash` impl. Two calls at the same `(E, ridge)` yield equal keys.
4747    pub(crate) fn from_inputs(
4748        e_transformed: &ndarray::Array2<f64>,
4749        ridge_passport: &gam_problem::RidgePassport,
4750    ) -> Self {
4751        use std::collections::hash_map::DefaultHasher;
4752        use std::hash::{Hash, Hasher};
4753        let mut hasher = DefaultHasher::new();
4754        e_transformed.nrows().hash(&mut hasher);
4755        e_transformed.ncols().hash(&mut hasher);
4756        for value in e_transformed.iter() {
4757            value.to_bits().hash(&mut hasher);
4758        }
4759        let penalty_matrix_fingerprint = hasher.finish();
4760        let mut ridge_hasher = DefaultHasher::new();
4761        ridge_passport.delta.to_bits().hash(&mut ridge_hasher);
4762        (ridge_passport.matrix_form as u8).hash(&mut ridge_hasher);
4763        ridge_passport
4764            .policy
4765            .include_penalty_logdet
4766            .hash(&mut ridge_hasher);
4767        ridge_passport
4768            .policy
4769            .include_laplacehessian
4770            .hash(&mut ridge_hasher);
4771        let ridge_passport_signature = ridge_hasher.finish();
4772        Self {
4773            penalty_matrix_fingerprint,
4774            ridge_passport_signature,
4775        }
4776    }
4777}
4778
4779/// Estimate the in-cache footprint of a (compacted) PIRLS result.
4780///
4781/// Mirrors what `compact_for_reml_cache` keeps:
4782/// * six surviving n-length f64 arrays (final_eta, solveweights,
4783///   solveworking_response, solvemu, solve_c_array, solve_d_array);
4784/// * the p-length coefficient vector;
4785/// * the two p×p Hessians (dense or CSC sparse);
4786/// * the `ReparamResult` payload — the dominant scaling term beyond n, since
4787///   it carries `s_transformed`, `qs`, and `e_transformed` as p×p / rank×p
4788///   matrices.
4789/// A small constant overhead absorbs scalar fields, enum discriminants, and
4790/// the HashMap entry. This errs on the conservative side: overestimation
4791/// causes earlier eviction, never under-counting that would let the cache
4792/// silently exceed the byte budget.
4793pub(crate) fn pirls_result_cache_bytes(result: &PirlsResult) -> usize {
4794    use std::mem::size_of;
4795    let n_array_elems = result.final_eta.len()
4796        + result.solveweights.len()
4797        + result.solveworking_response.len()
4798        + result.solvemu.len()
4799        + result.solve_c_array.len()
4800        + result.solve_d_array.len();
4801    let p = result.beta_transformed.0.len();
4802    let pen_h = symmetric_matrix_cache_bytes(&result.penalized_hessian_transformed);
4803    let stab_h = symmetric_matrix_cache_bytes(&result.stabilizedhessian_transformed);
4804    let reparam = (result.reparam_result.s_transformed.len()
4805        + result.reparam_result.qs.len()
4806        + result.reparam_result.e_transformed.len()
4807        + result.reparam_result.det1.len())
4808        * size_of::<f64>();
4809    n_array_elems * size_of::<f64>() + p * size_of::<f64>() + pen_h + stab_h + reparam + 1024
4810}
4811
4812pub(crate) fn symmetric_matrix_cache_bytes(m: &gam_linalg::matrix::SymmetricMatrix) -> usize {
4813    use gam_linalg::matrix::SymmetricMatrix;
4814    use std::mem::size_of;
4815    match m {
4816        SymmetricMatrix::Dense(a) => a.len() * size_of::<f64>(),
4817        SymmetricMatrix::Sparse(s) => {
4818            // CSC sparse: f64 values + usize row indices + usize column pointers.
4819            let (symbolic, values) = s.parts();
4820            values.len() * (size_of::<f64>() + size_of::<usize>())
4821                + std::mem::size_of_val(symbolic.col_ptr())
4822        }
4823    }
4824}
4825
/// Capacity (number of distinct rho-points) of the outer-eval reuse LRU.
///
/// Sized to comfortably span a binomial seed grid's local revisit window
/// (baseline + isotropic shifts + per-axis refinements) plus a few
/// line-search trial points without unbounded growth. Each slot holds one
/// `OuterEval` (a scalar cost, a length-k gradient, an optional inner-beta
/// hint, and a usually-`Unavailable` Hessian), so the footprint is tiny.
///
/// `EvalCacheManager::new` passes this value to `OuterEvalLru::new`.
pub(crate) const OUTER_EVAL_LRU_CAPACITY: usize = 8;
4834
4835/// Bounded least-recently-used cache of converged outer REML evaluations,
4836/// keyed by sanitized rho-bits.
4837///
4838/// CORRECTNESS: the key (`Vec<u64>` of `f64::to_bits`, with ±0 canonicalized)
4839/// is the complete result-determining input for a fixed `RemlState`. Every
4840/// other input to `OuterEval` — design matrix, prior weights, offset, penalty
4841/// structure, link/SAS/mixture state, Firth/Jeffreys configuration, and the
4842/// rho-prior — is immutable for the lifetime of the state that owns the cache,
4843/// so the stored cost / gradient / inner-beta hint depend only on rho. A hit
4844/// therefore returns exactly the value a recompute at that rho would converge
4845/// to (to the solver's own tolerance, identical to the trust the pre-existing
4846/// single-slot cache already placed in rho-only keying). Distinct rho-points
4847/// never alias: lookups compare the full key vector.
4848pub(crate) struct OuterEvalLru {
4849    capacity: usize,
4850    /// Front = least-recently-used, back = most-recently-used.
4851    entries: std::collections::VecDeque<(Vec<u64>, OuterEval)>,
4852}
4853
4854impl OuterEvalLru {
4855    pub(crate) fn new(capacity: usize) -> Self {
4856        Self {
4857            capacity: capacity.max(1),
4858            entries: std::collections::VecDeque::new(),
4859        }
4860    }
4861
4862    /// Returns a clone of the eval stored under `key`, if present, promoting it
4863    /// to most-recently-used. A miss returns `None` so the caller recomputes —
4864    /// never a stale value from a different key.
4865    pub(crate) fn get(&mut self, key: &[u64]) -> Option<OuterEval> {
4866        let pos = self
4867            .entries
4868            .iter()
4869            .position(|(k, _)| k.as_slice() == key)?;
4870        let entry = self.entries.remove(pos)?;
4871        let eval = entry.1.clone();
4872        self.entries.push_back(entry);
4873        Some(eval)
4874    }
4875
4876    /// Inserts (or refreshes) the eval for `key` as most-recently-used,
4877    /// evicting the least-recently-used entry once capacity is exceeded.
4878    pub(crate) fn insert(&mut self, key: Vec<u64>, eval: OuterEval) {
4879        if let Some(pos) = self
4880            .entries
4881            .iter()
4882            .position(|(k, _)| k.as_slice() == key.as_slice())
4883        {
4884            self.entries.remove(pos);
4885        }
4886        self.entries.push_back((key, eval));
4887        while self.entries.len() > self.capacity {
4888            self.entries.pop_front();
4889        }
4890    }
4891
4892    pub(crate) fn clear(&mut self) {
4893        self.entries.clear();
4894    }
4895}
4896
4897/// Centralized cache/memoization owner for REML evaluations.
4898///
4899/// This keeps cache-key identity, bundle reuse, and invalidation policy out of
4900/// the math kernels so objective/derivative routines can stay algebra-focused.
4901pub(crate) struct EvalCacheManager {
4902    pub(crate) pirls_cache: RwLock<PirlsLruCache>,
4903    pub(crate) penalty_subspace_cache: RwLock<PenaltySubspaceCache>,
4904    pub(crate) current_eval_bundle: RwLock<Option<EvalShared>>,
4905    /// Most-recently-*stored* outer eval (single slot). Retained verbatim so
4906    /// `previous_outer_gradient_norm` keeps its exact "immediately previous
4907    /// distinct eval" semantics, independent of the multi-slot reuse cache.
4908    pub(crate) current_outer_eval: RwLock<Option<(Vec<u64>, OuterEval)>>,
4909    /// Bounded multi-slot LRU of converged outer evaluations keyed by the
4910    /// sanitized rho-bits (#1575).
4911    ///
4912    /// For a frozen `RemlState` (fixed design, prior weights, offset, penalty
4913    /// structure, link state, Firth/Jeffreys configuration, and rho-prior — all
4914    /// of which are immutable for the lifetime of the state that owns this
4915    /// manager and therefore the lifetime of the cache), the outer objective
4916    /// value, its gradient, and the inner-beta hint are deterministic functions
4917    /// of rho alone. The sanitized rho-bits are thus the complete result key.
4918    /// The binomial REML fit performs ~20-32 seed-grid pre-solves plus
4919    /// line-search revisits; with only the single `current_outer_eval` slot,
4920    /// any revisit to an earlier rho re-ran a full n-sized P-IRLS. This LRU
4921    /// returns the stored cost/gradient for those revisited rho-points.
4922    pub(crate) outer_eval_lru: RwLock<OuterEvalLru>,
4923    pub(crate) pirls_cache_enabled: AtomicBool,
4924}
4925
4926impl EvalCacheManager {
4927    pub(crate) fn new() -> Self {
4928        Self {
4929            pirls_cache: RwLock::new(PirlsLruCache::new(PIRLS_CACHE_BYTE_BUDGET)),
4930            penalty_subspace_cache: RwLock::new(PenaltySubspaceCache::new()),
4931            current_eval_bundle: RwLock::new(None),
4932            current_outer_eval: RwLock::new(None),
4933            outer_eval_lru: RwLock::new(OuterEvalLru::new(OUTER_EVAL_LRU_CAPACITY)),
4934            pirls_cache_enabled: AtomicBool::new(true),
4935        }
4936    }
4937
4938    /// Creates a sanitized cache key from rho values.
4939    /// Returns None if any component is NaN, in which case caching is skipped.
4940    /// Maps -0.0 to 0.0 to ensure key stability.
4941    pub(crate) fn sanitized_rhokey(rho: &Array1<f64>) -> Option<Vec<u64>> {
4942        self::rho_key::sanitized_rhokey(rho)
4943    }
4944
4945    /// Memoizing wrapper for `PenaltySubspace` construction.
4946    ///
4947    /// The penalty-subspace eigendecomposition is shape-invariant: any two
4948    /// outer evaluations at the same `(E_transformed, ridge_passport)` produce
4949    /// bit-identical subspaces. The single-slot cache amortizes consecutive
4950    /// fixed-S queries (rank, logdet, trace) within a single outer iter.
4951    pub(super) fn cached_penalty_subspace<F>(
4952        &self,
4953        e_transformed: &ndarray::Array2<f64>,
4954        ridge_passport: &gam_problem::RidgePassport,
4955        build: F,
4956    ) -> Result<Arc<outer_eval::PenaltySubspace>, EstimationError>
4957    where
4958        F: FnOnce() -> Result<outer_eval::PenaltySubspace, EstimationError>,
4959    {
4960        let key = PenaltySubspaceCacheKey::from_inputs(e_transformed, ridge_passport);
4961        if let Some(hit) = self.penalty_subspace_cache.read().unwrap().get(&key) {
4962            return Ok(hit);
4963        }
4964        let value = Arc::new(build()?);
4965        self.penalty_subspace_cache
4966            .write()
4967            .unwrap()
4968            .insert(key, value.clone());
4969        Ok(value)
4970    }
4971
4972    pub(crate) fn cached_eval_bundle(&self, key: &Option<Vec<u64>>) -> Option<EvalShared> {
4973        let guard = self.current_eval_bundle.read().unwrap();
4974        let bundle: &EvalShared = guard.as_ref()?;
4975        bundle.matches(key).then(|| bundle.clone())
4976    }
4977
4978    pub(crate) fn store_eval_bundle(&self, bundle: EvalShared) {
4979        *self.current_eval_bundle.write().unwrap() = Some(bundle);
4980    }
4981
4982    pub(crate) fn cached_outer_eval(&self, key: &Option<Vec<u64>>) -> Option<OuterEval> {
4983        let key = key.as_ref()?;
4984        // The LRU is the authoritative multi-slot store; it always contains the
4985        // most-recently-stored eval too (kept in sync by `store_outer_eval`), so
4986        // a single LRU probe subsumes the old single-slot fast path while also
4987        // serving revisited (non-immediate) rho-points. `get` is a tiny linear
4988        // scan (capacity is `OUTER_EVAL_LRU_CAPACITY`) that promotes the hit to
4989        // most-recently-used; hence the write lock.
4990        self.outer_eval_lru.write().unwrap().get(key)
4991    }
4992
4993    pub(crate) fn store_outer_eval(&self, key: &Option<Vec<u64>>, eval: &OuterEval) {
4994        if let Some(key) = key.clone() {
4995            // Keep the single-slot mirror for `previous_outer_gradient_norm`,
4996            // whose "immediately previous distinct eval" contract reads it
4997            // directly and must stay byte-for-byte unchanged.
4998            *self.current_outer_eval.write().unwrap() = Some((key.clone(), eval.clone()));
4999            self.outer_eval_lru.write().unwrap().insert(key, eval.clone());
5000        }
5001    }
5002
5003    pub(crate) fn invalidate_eval_bundle(&self) {
5004        self.current_eval_bundle.write().unwrap().take();
5005        self.current_outer_eval.write().unwrap().take();
5006        self.outer_eval_lru.write().unwrap().clear();
5007    }
5008
5009    pub(crate) fn clear_eval_and_factor_caches(&self) {
5010        self.invalidate_eval_bundle();
5011        self.penalty_subspace_cache.write().unwrap().clear();
5012    }
5013}
5014
/// Reusable scratch/runtime memory that should not be part of mathematical
/// state invariants.
pub(crate) struct RemlArena {
    /// Number of outer cost/gradient REQUESTS, including single-slot cache
    /// hits and prior short-circuits (contrast with
    /// `inner_pirls_solve_count` below).
    pub(crate) cost_eval_count: RwLock<u64>,
    /// Number of *actual* full-n inner P-IRLS solves performed (#1575).
    ///
    /// Distinct from `cost_eval_count`, which counts every outer cost/gradient
    /// REQUEST including single-slot cache hits and prior short-circuits. This
    /// counts only the cache-missing `prepare_eval_bundlewithkey` calls — i.e.
    /// the genuinely expensive `O(n·p²)` inner solves the #1575 slowdown is
    /// about ("~150 outer cost evals each running a full n-sized P-IRLS"). A
    /// healthy warm-started fit performs roughly 2 inner solves per outer
    /// cost-eval (one value, one gradient/Hessian), so a large ratio between
    /// the two signals broken warm-starting or duplicate solving. This is pure
    /// observability: it never feeds back into the optimization and changes no
    /// fitted value.
    pub(crate) inner_pirls_solve_count: AtomicU64,
    /// NOTE(review): presumably set when the most recent gradient evaluation
    /// fell back to a stochastic estimator — the writer is outside this view,
    /// confirm there. Starts out `false`.
    pub(crate) lastgradient_used_stochastic_fallback: AtomicBool,
}

impl RemlArena {
    /// Creates an arena with both counters at zero and the stochastic-fallback
    /// flag cleared.
    pub(crate) fn new() -> Self {
        Self {
            cost_eval_count: RwLock::new(0),
            inner_pirls_solve_count: AtomicU64::new(0),
            lastgradient_used_stochastic_fallback: AtomicBool::new(false),
        }
    }
}
5044
/// Plain data holder for ALO nuisance quantities.
///
/// NOTE(review): field meanings below are inferred from names only; the code
/// that fills and reads this struct is outside this view — confirm there.
pub(crate) struct AloFrozenNuisance {
    /// Presumably the number of observations the values were frozen for.
    pub(crate) n_obs: usize,
    /// Presumably one influence scaling factor per observation — TODO confirm
    /// the length equals `n_obs`.
    pub(crate) influence_scale: Vec<f64>,
    /// Presumably the dispersion/scale parameter φ — TODO confirm.
    pub(crate) phi: f64,
}
5050
5051pub(crate) struct RemlState<'a> {
5052    pub(crate) y: ArrayView1<'a, f64>,
5053    pub(crate) x: DesignMatrix,
5054    pub(crate) weights: ArrayView1<'a, f64>,
5055    pub(crate) offset: Array1<f64>,
5056    /// Canonicalized block-local penalties with pre-computed roots.
5057    /// This is the single canonical penalty representation — no full-width
5058    /// `rank × p` roots are stored separately.
5059    pub(crate) canonical_penalties: Arc<Vec<gam_terms::construction::CanonicalPenalty>>,
5060    pub(crate) balanced_penalty_root: Array2<f64>,
5061    pub(crate) reparam_invariant: ReparamInvariant,
5062    pub(crate) sparse_penalty_block_count: Option<usize>,
5063    pub(crate) p: usize,
5064    pub(crate) config: Arc<RemlConfig>,
5065    pub(crate) runtime_mixture_link_state: Option<gam_problem::MixtureLinkState>,
5066    pub(crate) runtime_sas_link_state: Option<SasLinkState>,
5067    pub(crate) nullspace_dims: Vec<usize>,
5068    pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
5069    pub(crate) linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
5070    /// Relative shrinkage floor for penalized block eigenvalues (rho-independent).
5071    pub(crate) penalty_shrinkage_floor: Option<f64>,
5072    /// Explicit prior on log smoothing parameters used by the REML/LAML objective.
5073    pub(crate) rho_prior: gam_problem::RhoPrior,
5074
5075    pub(crate) cache_manager: EvalCacheManager,
5076    pub(crate) arena: RemlArena,
5077    pub(crate) warm_start_beta: RwLock<Option<Coefficients>>,
5078    /// Two-point ρ-trajectory used for second-order warm-start
5079    /// extrapolation: when the outer optimizer asks for a fit at a new
5080    /// ρ, we have `β(ρ_k)` (in `warm_start_beta`) and `β(ρ_{k-1})` (in
5081    /// `prev_warm_start_beta`). The implicit β(ρ) trajectory is locally
5082    /// linear under the FOC ∇F(β,ρ)=0, so a tangent-line prediction
5083    /// `β_predict(ρ_new) = β_k + α · (β_k − β_{k-1})` where α is the
5084    /// projection of `(ρ_new − ρ_k)` onto `(ρ_k − ρ_{k-1})` gives a
5085    /// better seed than the flat `β_k` alone — replacing PIRLS warm-
5086    /// start "use last β as-is" with a real tangent-prediction step.
5087    pub(crate) warm_start_rho: RwLock<Option<Array1<f64>>>,
5088    pub(crate) prev_warm_start_beta: RwLock<Option<Coefficients>>,
5089    pub(crate) prev_warm_start_rho: RwLock<Option<Array1<f64>>>,
5090    pub(crate) warm_start_enabled: AtomicBool,
5091    pub(crate) screening_max_inner_iterations: Arc<AtomicUsize>,
5092    /// Outer-aware inner-PIRLS iteration cap for the main descent loop.
5093    ///
5094    /// Distinct from `screening_max_inner_iterations`, which is used during
5095    /// seed selection and toggles a side-effect bundle (cache writes,
5096    /// warm-start updates, KKT enforcement all suppressed). This atomic is
5097    /// purely a cap — when nonzero, the inner Newton loop is capped at
5098    /// `min(this, full_max_iterations)`, but cache writes and warm-start
5099    /// updates remain enabled. Driven by the outer optimizer to coarsen
5100    /// inner solves at early outer iterations when ρ is far from converged,
5101    /// and lifted back to full at the final accepted iter (otherwise the
5102    /// returned β would be biased by the loose cap).
5103    ///
5104    /// Both atomics are honored together as `min(screening_cap, outer_cap)`
5105    /// when both are nonzero. Default 0 (no cap from this source).
5106    pub(crate) outer_inner_cap: Arc<AtomicUsize>,
5107
5108    /// Inner-PIRLS feedback signal driven by `execute_pirls_if_needed` after
5109    /// each NON-screening solve. Stores the iteration count at which the
5110    /// inner Newton stopped, plus a flag indicating whether it converged
5111    /// (vs. hit the iteration cap). The outer first-/second-order bridges
5112    /// read these atomics to drive an adaptive `inner_cap_schedule`: the
5113    /// next outer iter's inner cap becomes `last_iters + small_margin`
5114    /// when the previous solve converged, or a geometric backoff when it
5115    /// hit the cap. This replaces the older hardcoded iter-tier schedule
5116    /// (3/5/10/20) with a cap that follows the inner solver's actual
5117    /// convergence behavior — Eisenstat-Walker style for the inner
5118    /// quadratic loop. Default 0 / false (no signal yet — first outer
5119    /// iter falls back to a coarse iter-count tier).
5120    pub(crate) last_inner_iters: Arc<AtomicUsize>,
5121    pub(crate) last_inner_converged: Arc<AtomicBool>,
5122
5123    /// Cached state from the most recent successful PIRLS solve, used by
5124    /// the IFT-based warm-start predictor.
5125    ///
5126    /// The implicit-function theorem applied to the FOC ∇_β F(β,ρ)=0
5127    /// gives `dβ/dρ_k = -H_pen^{-1} · (e^{ρ_k} · S_k · β)`. A first-order
5128    /// Taylor predictor reads
5129    /// `β_predict(ρ_new) = β_cur − Σ_k Δρ_k · H_pen^{-1} · (e^{ρ_cur_k} · S_k · β_cur)`.
5130    /// This is a strict superset of the tangent-line predictor's
5131    /// requirements: works after a single successful solve (tangent-line
5132    /// needs two prior fits), and gives the EXACT first-order Jacobian
5133    /// of the implicit β(ρ) trajectory rather than a secant proxy along one
5134    /// ρ-direction.
5135    ///
5136    /// Populated in `updatewarm_start_from` when PIRLS converges; cleared
5137    /// on failure, on `reset_surface`, and on link-state changes.
5138    pub(crate) ift_warm_start_cache: RwLock<Option<IftWarmStartCache>>,
5139
5140    /// Persisted Levenberg-Marquardt damping coefficient from the most
5141    /// recent successful PIRLS solve, bit-packed into an `AtomicU64`
5142    /// (`f64::to_bits` low 64 bits). Read at the start of
5143    /// `execute_pirls_if_needed` and written into the
5144    /// `PirlsConfig::initial_lm_lambda` hint so the inner Newton seeds
5145    /// `λ_LM` near the damping the previous solve discovered, instead
5146    /// of cold-starting at `1e-6` and burning 4-6 halving steps to
5147    /// recover. `0` (the default) signals "no hint"; the inner solver
5148    /// clamps any positive hint into `[1e-6, 1e-3]` so a stale value
5149    /// cannot destabilize the next solve. Reset on `reset_surface` and
5150    /// on failed solves.
5151    pub(crate) last_pirls_lm_lambda: Arc<AtomicU64>,
5152
5153    /// Negative-Binomial overdispersion `theta` frozen for the smoothing-
5154    /// parameter (λ) search (#1082), bit-packed `f64` (`f64::to_bits`). `0`
5155    /// (the default) signals "not yet frozen". On the first non-screening
5156    /// λ-search inner solve of an estimated-θ NB fit, the seed's
5157    /// maximum-likelihood θ is computed once and stored here; every subsequent
5158    /// λ-search evaluation pins the inner solve to this value via
5159    /// `GlmLikelihoodSpec::with_negbin_theta_frozen_for_search`, so the REML
5160    /// criterion `F(ρ) = REML(ρ, θ_frozen)` is a stationary function of ρ and
5161    /// the outer optimizer converges instead of chasing the per-eval θ drift
5162    /// that the estimated path injects. The single final reported fit still
5163    /// ML-refreshes θ at the converged η. Reset on `reset_surface`.
5164    pub(crate) frozen_negbin_theta: Arc<AtomicU64>,
5165
5166    /// Last observed IFT-prediction residual (`‖β_converged − β_predicted‖
5167    /// / ‖β_converged‖`) from the most recent non-screening solve where
5168    /// the predictor was actually consumed. Bit-packed `f64` (low 64
5169    /// bits via `f64::to_bits`).
5170    ///
5171    /// "No signal yet" is encoded as a NaN bit-pattern
5172    /// (`IFT_RESIDUAL_NO_SIGNAL_BITS`). The original `0` sentinel
5173    /// collided with `f64::to_bits(0.0) == 0` — a true residual of
5174    /// exactly 0 (degenerate but mathematically possible if every
5175    /// β_predicted_i matched β_converged_i to bit-equality) would
5176    /// have been indistinguishable from "predictor never reported".
5177    /// NaN's self-inequality makes the sentinel unambiguous: any
5178    /// stored finite non-negative value is genuine signal.
5179    ///
5180    /// Read by `predict_warm_start_beta_ift_with_outcome` to drive the adaptive
5181    /// |Δρ| cap (`adaptive_ift_max_drho`): a small residual loosens
5182    /// the cap, a large one tightens it. Replaces the previous
5183    /// hardcoded `IFT_WARM_START_MAX_DRHO = 2.0` constant with a
5184    /// data-driven policy, so the predictor adapts to the empirical
5185    /// faithfulness of the linearization at this surface's scale.
5186    /// Reset on `reset_surface` and on failed solves.
5187    pub(crate) last_ift_prediction_residual: Arc<AtomicU64>,
5188
5189    /// Last observed gain ratio of the accepted LM step
5190    /// (`actual_reduction / predicted_reduction`) from the most recent
5191    /// non-screening PIRLS solve. Bit-packed `f64` with the same NaN
5192    /// sentinel discipline as `last_ift_prediction_residual`: NaN bits
5193    /// (`IFT_RESIDUAL_NO_SIGNAL_BITS`) encode "no signal yet" so a
5194    /// recorded ratio of exactly 0 (degenerate but possible) doesn't
5195    /// collide with the no-signal token.
5196    ///
5197    /// Used by `first_order_inner_cap_schedule` as a third quality
5198    /// signal alongside `last_iters` and `last_converged`. A small
5199    /// `accept_rho` (model overstating predicted reduction) is a hint
5200    /// the next iter's inner Newton may need extra margin even when
5201    /// the previous solve converged in few iters. Reset on
5202    /// `reset_surface` and on failed solves.
5203    pub(crate) last_pirls_accept_rho: Arc<AtomicU64>,
5204
5205    /// Cached Cholesky factorization of `IftWarmStartCache::penalized_hessian_transformed`.
5206    /// Lazily computed on the first IFT predict call after a fresh
5207    /// `updatewarm_start_from`, then reused by every subsequent
5208    /// predict call until the IFT cache is invalidated. At large scale,
5209    /// where p can reach several thousand, the dense Cholesky costs
5210    /// ~p³/3 flops — multiple seconds per refactor — so caching saves
5211    /// real wall time across the typical 5-10 IFT predict calls per
5212    /// outer fit. Reset jointly with `ift_warm_start_cache` (on
5213    /// `reset_surface`, on link-state changes, on failed PIRLS solves,
5214    /// and whenever a new H_pen replaces the cached one).
5215    pub(crate) ift_cached_factor: RwLock<Option<Arc<dyn gam_linalg::matrix::FactorizedSystem>>>,
5216
5217    /// When set, the penalties have Kronecker (tensor-product) structure and
5218    /// the REML evaluator can use O(∏q_j) logdet instead of O(p³) eigendecomposition.
5219    /// Populated via `set_kronecker_penalty_system` after construction.
5220    pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
5221    /// Full Kronecker factored basis (marginal designs + penalties + dims).
5222    /// Used by P-IRLS for factored reparameterization.
5223    pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,
5224
5225    /// Precomputed `(XᵀWX, XᵀW(y − offset))` for the Gaussian + Identity
5226    /// outer REML loop, populated once before the outer optimizer when the
5227    /// family / link / constraint preconditions hold and the design supports
5228    /// the Identity short-circuit at `pirls.rs:6237`. When present, each
5229    /// inner `solve_penalized_least_squares_implicit` reads these matrices
5230    /// instead of restreaming the O(N·p²) GEMM and O(N·p) matvec per outer
5231    /// iteration — the penalty `λ·S` is still added per-λ.
5232    ///
5233    /// Invalidated jointly with the design in `reset_surface`.
5234    pub(crate) gaussian_fixed_cache: RwLock<Option<Arc<crate::pirls::GaussianFixedCache>>>,
5235    /// Conditioned-frame exact ψ-derivatives `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
5236    /// for the SINGLE design-moving spatial hyperparameter (#1033b), assembled
5237    /// n-free from the certified Chebyshev ψ-Gram tensor and installed beside
5238    /// `gaussian_fixed_cache` at the same in-window trial. When present the
5239    /// Gaussian-identity ψ-gradient HyperCoord (`a_j`, `g_j`, dense `B_j`) is
5240    /// formed from these k×k objects instead of realizing and contracting the
5241    /// n×k ∂X/∂ψ slab — retiring the second per-trial n-pass. Lives in the
5242    /// SAME conditioned column frame as `gaussian_fixed_cache.xtwx_orig`, so
5243    /// the hyper-coord builder transforms it by the per-eval Qs/free-basis the
5244    /// same way it transforms the streamed Gram. Invalidated with the design.
5245    pub(crate) gaussian_psi_gram_deriv:
5246        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
5247    /// Conditioned-frame exact ψ-derivative pair `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
5248    /// for the SINGLE design-moving spatial hyperparameter in the GLM (frozen-W)
5249    /// lane (#1033 / #1111), assembled n-free from
5250    /// [`crate::glm_sufficient_lane::FrozenWeightGramTensor::gradient_pair_if_sound`]
5251    /// and installed beside `glm_first_step_gram` at the same in-window
5252    /// drift-OK trial. When present, the GLM ψ-gradient HyperCoord serves its
5253    /// envelope `a_j` and score `g_j` from these k×k objects instead of
5254    /// realizing and contracting the n×k ∂X/∂ψ slab — the second per-trial
5255    /// n-pass. Unlike the Gaussian lane the Hessian curvature `B_j` is NOT
5256    /// served from the tensor: for a GLM the per-trial `B_j` term
5257    /// `X_τᵀWX + XᵀWX_τ` is irreducibly n-dependent (the moving working weight
5258    /// `W` does not factor out of a frozen-W k×k object), so `B_j` keeps the
5259    /// exact streamed slab (#1033). Lives in the SAME conditioned column frame
5260    /// as `glm_first_step_gram` / `gaussian_fixed_cache.xtwx_orig`, so the
5261    /// hyper-coord builder transforms it by the per-eval Qs/free-basis the same
5262    /// way. NOT family-gated (the GLM lane's own slot). Invalidated with the
5263    /// design.
5264    pub(crate) glm_psi_gram_deriv:
5265        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
5266    /// Frozen-weight first-Fisher-step data-fit Gram `XᵀWX` for the GLM
5267    /// design-moving ψ-sweep (#1111 / #1033 mechanism (c)), in the conditioned
5268    /// (original / `x_fit`) column frame — the SAME frame as
5269    /// `gaussian_fixed_cache.xtwx_orig`.
5270    ///
5271    /// Assembled n-free per in-window ψ-trial from the certified frozen-weight
5272    /// Chebyshev tensor ([`crate::glm_sufficient_lane::FrozenWeightGramTensor`])
5273    /// and installed only when the trial's converged working weight has not
5274    /// drifted past tolerance from the frozen snapshot. When present, the GLM
5275    /// inner P-IRLS serves its FIRST Fisher-scoring iteration's `XᵀWX` from this
5276    /// cache instead of restreaming the O(N·p²) weighted cross-product — the
5277    /// dominant per-trial n-term in a large-n Poisson/Binomial κ-sweep. The
5278    /// penalty `Sλ` is still added per-λ on top, and every subsequent inner
5279    /// iteration restreams the true (moving) `W`, so the converged β̂ is
5280    /// unchanged; only the first-iteration Gram build is elided. Unlike
5281    /// `gaussian_fixed_cache` this is NOT family-gated — it is the GLM lane's
5282    /// own slot, consumed once per inner solve. Invalidated with the design in
5283    /// `reset_surface`.
5284    pub(crate) glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
5285    /// Previous successful non-Gaussian fixed-design data-fit Gram `XᵀWX` in
5286    /// the conditioned original frame, keyed to `warm_start_beta`.
5287    ///
5288    /// When the next outer trial uses a flat warm start, its first PIRLS
5289    /// curvature build evaluates at the same `η = Xβ` as the previous converged
5290    /// solve, so the Hessian weights and `XᵀWX` are identical. Reusing this
5291    /// original-frame Gram skips one dense `O(n·p²)` pass per warm-started
5292    /// trial while still letting later PIRLS iterations restream the moving
5293    /// weights. IFT/tangent-predicted starts do not consume this cache.
5294    pub(crate) flat_glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
5295    /// Frozen ALO robustness weights for this REML surface.
5296    ///
5297    /// The PSIS influence scale is a non-smooth function of the current hat
5298    /// diagonals. Once the high-leverage ALO objective activates, it is frozen
5299    /// for the current surface so the analytic gradient differentiates the
5300    /// same fixed-weight objective the cost evaluates.
5301    pub(crate) alo_frozen_nuisance: RwLock<Option<AloFrozenNuisance>>,
5302
5303    /// Stable disk-cache key for the current realized REML surface. Computed
5304    /// lazily because it hashes the row-chunked design and data vectors.
5305    pub(crate) persistent_warm_start_key: RwLock<Option<String>>,
5306    pub(crate) persistent_latent_values_fingerprint: Option<u64>,
5307    pub(crate) persistent_latent_values_cache: RwLock<PersistentLatentValuesCache>,
5308    pub(crate) analytic_penalty_registry_fingerprint: u64,
5309    /// Ensures the process attempts at most one disk restore per surface.
5310    pub(crate) persistent_warm_start_loaded: AtomicBool,
5311    /// Scoped counter disabling disk writes from cost-only posterior/probe
5312    /// evaluations. In-memory warm starts still update; only JSON/bin
5313    /// persistence and eviction sweeps are suppressed.
5314    pub(crate) persistent_warm_start_store_suppression: AtomicUsize,
5315    /// Scoped counter disabling the Gaussian-identity ALO-stabilization
5316    /// augmentation (#979). The leverage barrier `Σ_i (h_i − τ)₊²` is an OUTER
5317    /// OPTIMIZER aid (#813/#821) that keeps the smoothing-parameter search off
5318    /// pathological high-leverage λ regions. The marginal smoothing-parameter
5319    /// posterior `π(ρ|y) ∝ exp(−LAML(ρ))` (#938) is a property of the genuine
5320    /// model criterion, sampled against a Laplace proposal built from the BASE
5321    /// REML Hessian, so the certificate / NUTS evaluations suppress the
5322    /// augmentation (see `without_alo_stabilization`) — both for proposal↔target
5323    /// consistency and to drop the per-leapfrog ALO diagnostic suite.
5324    pub(crate) alo_stabilization_suppression: AtomicUsize,
5325    /// Whether the cross-process ON-DISK warm-start layer is engaged at all.
5326    ///
5327    /// Default `false`: the optimizer's IN-MEMORY warm start (the actual
5328    /// speed lever) is always on, but the disk checkpoint — `load_record`
5329    /// at fit start and `store_record` at finalize, each of which opens the
5330    /// shared `WarmStartStore` and pays an eviction/dir scan that is O(cache
5331    /// entries) on a network filesystem — is skipped. Disk persistence has
5332    /// reuse value only ACROSS processes or across repeated identical fits;
5333    /// a single in-process fit (and a fortiori a loop of distinct throwaway
5334    /// fits, e.g. CI-coverage replicates each on different data, #1082/#1114)
5335    /// gets zero benefit from it and pays the per-fit open/scan/save in full.
5336    /// `FitConfig::persist_warm_start_disk` flips this to `true` only when the
5337    /// caller explicitly asks for cross-process / repeat-fit persistence.
5338    pub(crate) persistent_warm_start_disk_enabled: AtomicBool,
5339}