Skip to main content

gam_solve/reml/
mod.rs

1use self::inner_strategy::GeometryBackendKind;
2use super::*;
3use gam_linalg::sparse_exact::SparseExactFactor;
4use crate::pirls::PIRLS_CACHE_BYTE_BUDGET;
5use crate::pirls::assemble_and_factor_sparse_penalized_system;
6use gam_terms::basis::LocalDesignJacobianProvider;
7use gam_problem::SasLinkState;
8use gam_problem::OuterEval;
9use ndarray::{Array1, Array2, s};
10use std::collections::{HashMap, VecDeque};
11use std::ops::Range;
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
13use std::sync::{Arc, RwLock};
14
15pub mod assembly;
16pub mod atoms;
17pub(crate) mod continuation;
18pub(crate) mod eval;
19mod firth;
20pub(super) mod hyper;
21mod inner_strategy;
22// #1521 carve: promoted `pub(crate)` -> `pub` so the extracted
23// `gam-custom-family` crate reaches the Jeffreys-subspace items it consumes.
24pub mod jeffreys_subspace;
25pub mod outer_eval;
26pub mod penalty_logdet;
27pub mod per_atom_efs;
28pub mod reml_outer_engine;
29mod rho_key;
30mod sparse_exact_penalty;
31mod trace;
32
33pub(crate) use sparse_exact_penalty::sparse_penalty_block_count_from_canonical;
34
/// Byte budget for the dense τ-τ exact-Hessian cache plan (512 MiB); reported
/// to callers as `TauTauHessianPolicy::budget_bytes`.
pub(crate) const EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES: usize = 512 << 20;
/// Largest observation count `n` admitted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_OBSERVATIONS: usize = 20_000;
/// Largest coefficient count `p` admitted by `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_COEFFICIENTS: usize = 256;
/// Upper bound on the `n · p` work estimate in `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_LINEAR_WORK: usize = 2_000_000;
/// Upper bound on the `n · p²` work estimate in `firth_problem_scale_allows`.
pub(crate) const FIRTH_MAX_QUADRATIC_WORK: usize = 100_000_000;
/// Default number of entries kept by `PersistentLatentValuesCache`.
pub(crate) const PERSISTENT_LATENT_VALUES_CACHE_CAPACITY: usize = 8;
41
42#[derive(Debug)]
43pub(crate) struct PersistentLatentValuesCache {
44    pub(crate) entries: HashMap<String, Array2<f64>>,
45    pub(crate) lru: VecDeque<String>,
46    pub(crate) capacity: usize,
47}
48
49impl Default for PersistentLatentValuesCache {
50    fn default() -> Self {
51        Self {
52            entries: HashMap::new(),
53            lru: VecDeque::new(),
54            capacity: PERSISTENT_LATENT_VALUES_CACHE_CAPACITY,
55        }
56    }
57}
58
59impl PersistentLatentValuesCache {
60    pub(crate) fn lookup(
61        &mut self,
62        key: &str,
63        n_obs: usize,
64        latent_dim: usize,
65    ) -> Option<Array2<f64>> {
66        let values = self.entries.get(key)?;
67        if values.dim() != (n_obs, latent_dim) {
68            return None;
69        }
70        let values = values.clone();
71        self.touch(key.to_string());
72        Some(values)
73    }
74
75    pub(crate) fn insert(&mut self, key: String, values: Array2<f64>) {
76        if values.iter().any(|value| !value.is_finite()) {
77            return;
78        }
79        self.entries.insert(key.clone(), values);
80        self.touch(key);
81        while self.entries.len() > self.capacity {
82            let Some(evicted) = self.lru.pop_front() else {
83                break;
84            };
85            self.entries.remove(&evicted);
86        }
87    }
88
89    pub(crate) fn touch(&mut self, key: String) {
90        if let Some(index) = self.lru.iter().position(|queued| queued == &key) {
91            self.lru.remove(index);
92        }
93        self.lru.push_back(key);
94    }
95}
96
97/// Cached state from the most recent successful PIRLS solve, populated by
98/// `updatewarm_start_from` and consumed by the IFT-based warm-start
99/// predictor (`RemlState::predict_warm_start_beta_ift_with_outcome`).
100/// See the field doc on `RemlState::ift_warm_start_cache` for the math.
101#[derive(Clone)]
102pub(crate) struct IftWarmStartCache {
103    /// β at the converged solve, in ORIGINAL basis. Mirror of
104    /// `warm_start_beta` stashed alongside the H factor for atomic
105    /// consistency under concurrent reads (the predictor needs both
106    /// β and H from the SAME solve; reading them from two locks risks
107    /// a torn pair if a fresh solve lands between reads).
108    pub beta_original: ndarray::Array1<f64>,
109    /// ρ at which the solve occurred. Mirror of `warm_start_rho`,
110    /// stashed for the same atomic-consistency reason.
111    pub rho: ndarray::Array1<f64>,
112    /// Penalized Hessian H_pen at the converged β, in TRANSFORMED basis.
113    /// The IFT predictor factors this on demand; basis transforms run in
114    /// transformed basis for numerical stability.
115    pub penalized_hessian_transformed: gam_linalg::matrix::SymmetricMatrix,
116    /// Reparameterization matrix qs converting between transformed
117    /// (column) basis and original basis: `β_orig = qs · β_tfd`,
118    /// `H_orig = qs · H_tfd · qs^T`.
119    pub qs: ndarray::Array2<f64>,
120    /// True when the PIRLS result was already in original basis
121    /// (`OriginalSparseNative`) — in which case `qs` is the identity
122    /// and the IFT predictor can skip the basis-conversion ops.
123    pub frame_was_original: bool,
124    /// Per-penalty precomputation `S_k · β_cur[cp.col_range]`,
125    /// indexed in lockstep with `RemlObjectiveState::canonical_penalties`.
126    /// Each entry is the local-block mat-vec the IFT predictor would
127    /// otherwise recompute on every predict call. With H_pen factor
128    /// caching (commit ec18559d) the per-call cost dropped from
129    /// `O(p³)` Cholesky to `O(p²) ≈ k · O(block²)` rhs construction;
130    /// at large-scale CTN (p ≈ several thousand) that's several ms
131    /// per predict call still being paid. By stashing `S_k · β_cur`
132    /// at cache-write time the predictor's per-call work drops to
133    /// just the `Δρ_k · e^{ρ_k} · sb_block` accumulation, which is
134    /// `O(p)` rather than `O(p²)`.
135    ///
136    /// `None` when the cache predates this commit's writer hook (e.g.,
137    /// transient state during invalidation); the predictor falls back
138    /// to recomputing the mat-vec when this is `None` or the length
139    /// mismatches `canonical_penalties.len()`.
140    pub lambda_s_beta_blocks: Option<Vec<ndarray::Array1<f64>>>,
141}
142
/// Per-component byte estimate for the buffers a τ-τ evaluation plan would
/// allocate. Components are independent; [`TauTauPlanEstimate::total_bytes`]
/// adds them up.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub(crate) struct TauTauPlanEstimate {
    pub(crate) dense_x_bytes: usize,
    pub(crate) first_order_tau_bytes: usize,
    pub(crate) second_order_tau_bytes: usize,
    pub(crate) penalty_first_bytes: usize,
    pub(crate) penalty_pair_bytes: usize,
    pub(crate) rho_tau_penalty_bytes: usize,
    pub(crate) vector_cache_bytes: usize,
    pub(crate) weighted_scratch_bytes: usize,
}

impl TauTauPlanEstimate {
    /// Sum of every component, clamped at `usize::MAX` instead of overflowing.
    pub(crate) fn total_bytes(self) -> usize {
        // Exhaustive destructuring: adding a field forces this sum to be updated.
        let Self {
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        } = self;
        [
            dense_x_bytes,
            first_order_tau_bytes,
            second_order_tau_bytes,
            penalty_first_bytes,
            penalty_pair_bytes,
            rho_tau_penalty_bytes,
            vector_cache_bytes,
            weighted_scratch_bytes,
        ]
        .iter()
        .fold(0usize, |total, &component| total.saturating_add(component))
    }
}
167
168#[derive(Clone, Copy, Debug, PartialEq, Eq)]
169pub(crate) struct TauTauHessianPolicy {
170    pub(crate) any_has_implicit: bool,
171    pub(crate) implicit_multidim_duchon: bool,
172    pub(crate) estimated_dense_tau_cache_bytes: usize,
173    pub(crate) gradient_plan: TauTauPlanEstimate,
174    pub(crate) hessian_plan: TauTauPlanEstimate,
175    pub(crate) budget_bytes: usize,
176    pub(crate) firth_pair_terms_unavailable: bool,
177}
178
179impl TauTauHessianPolicy {
180    /// True when the τ-τ exact-Hessian path cannot be assembled at all and the
181    /// eval must fall back to value-and-gradient mode (forcing
182    /// `HessianResult::Unavailable`).
183    ///
184    /// This is the *only* remaining capability gate: the previous
185    /// implementation also forced gradient-only when the design used implicit
186    /// multi-dim Duchon storage or when the dense τ-cache plan would exceed
187    /// the budget.  Both of those are now *cost* gates, not capability gates
188    /// — the unified evaluator's `prefer_outer_hessian_operator(n, p, k)`
189    /// selects the matrix-free `HessianResult::Operator` representation in
190    /// exactly the regimes where the dense cache would be unaffordable, and
191    /// the planner routes operator returns through `run_operator_trust_region`
192    /// (or basis-probes them when `dim ≤ OUTER_HVP_MATERIALIZE_MAX_DIM`).
193    /// Forcing gradient-only would have prevented the operator representation
194    /// from ever being requested, defeating that routing; hence the
195    /// `implicit_multidim_duchon` and cost-bytes clauses are deliberately
196    /// gone.
197    ///
198    /// Firth-pair-terms-unavailability remains a capability gate: when the
199    /// Firth-aware derivative provider cannot produce the τ-τ pair
200    /// corrections at all, no representation choice can substitute.  At every
201    /// production call site this flag is hardcoded `false` (the
202    /// `hphi_tau_tau_partial_apply` + `d_beta_hphi_tau_partial_apply`
203    /// primitives now cover the gap), so this method effectively returns
204    /// `false` in production.  We retain the field and method signature
205    /// unchanged so future Firth corner cases have a single, surfaced place
206    /// to land.
207    pub(crate) fn prefer_gradient_only(self) -> bool {
208        self.firth_pair_terms_unavailable
209    }
210}
211
212pub(crate) fn exact_tau_tau_hessian_policy_with_firth(
213    n_obs: usize,
214    p_coeff: usize,
215    hyper_dirs: &[DirectionalHyperParam],
216    firth_pair_terms_unavailable: bool,
217) -> TauTauHessianPolicy {
218    let f64_bytes = std::mem::size_of::<f64>();
219    let dense_matrix_bytes =
220        |rows: usize, cols: usize| -> usize { rows.saturating_mul(cols).saturating_mul(f64_bytes) };
221    let dense_design_bytes = dense_matrix_bytes(n_obs, p_coeff);
222    let dense_penalty_bytes = dense_matrix_bytes(p_coeff, p_coeff);
223    let psi_dim = hyper_dirs.len();
224    let implicit_n_axes = hyper_dirs
225        .iter()
226        .find_map(DirectionalHyperParam::implicit_axis_count_hint)
227        .unwrap_or(0);
228    let gradient_uses_implicit_design = hyper_dirs
229        .iter()
230        .any(DirectionalHyperParam::has_implicit_operator)
231        && gam_terms::basis::should_use_implicit_operators_with_policy(
232            n_obs,
233            p_coeff,
234            implicit_n_axes,
235            &gam_runtime::resource::ResourcePolicy::default_library(),
236        );
237    let dense_first_order_count = hyper_dirs
238        .iter()
239        .filter(|dir| !dir.has_implicit_operator())
240        .count();
241    let first_penalty_component_count = hyper_dirs
242        .iter()
243        .map(DirectionalHyperParam::penalty_first_component_count)
244        .sum::<usize>();
245
246    let mut dense_second_order_count = 0usize;
247    let mut penalty_pair_count = 0usize;
248    for i in 0..psi_dim {
249        for j in i..psi_dim {
250            if hyper_dirs[i]
251                .x_tau_tau_entry_at(j)
252                .or_else(|| hyper_dirs[j].x_tau_tau_entry_at(i))
253                .is_some_and(|entry| !entry.uses_implicit_storage())
254            {
255                dense_second_order_count += if i == j { 1 } else { 2 };
256            }
257            if hyper_dirs[i].has_penaltysecond_pair_at(j)
258                || hyper_dirs[j].has_penaltysecond_pair_at(i)
259            {
260                penalty_pair_count += if i == j { 1 } else { 2 };
261            }
262        }
263    }
264
265    let gradient_dense_first_order_count = if gradient_uses_implicit_design {
266        dense_first_order_count
267    } else {
268        psi_dim
269    };
270    let gradient_needs_dense_x =
271        firth_pair_terms_unavailable || gradient_dense_first_order_count > 0;
272    let gradient_plan = TauTauPlanEstimate {
273        dense_x_bytes: if gradient_needs_dense_x {
274            dense_design_bytes
275        } else {
276            0
277        },
278        first_order_tau_bytes: if gradient_dense_first_order_count > 0 {
279            dense_design_bytes
280        } else {
281            0
282        },
283        second_order_tau_bytes: 0,
284        penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
285        penalty_pair_bytes: 0,
286        rho_tau_penalty_bytes: 0,
287        vector_cache_bytes: n_obs.saturating_mul(f64_bytes),
288        weighted_scratch_bytes: dense_penalty_bytes,
289    };
290    let hessian_plan = TauTauPlanEstimate {
291        dense_x_bytes: if psi_dim > 0 { dense_design_bytes } else { 0 },
292        first_order_tau_bytes: dense_first_order_count.saturating_mul(dense_design_bytes),
293        second_order_tau_bytes: dense_second_order_count.saturating_mul(dense_design_bytes),
294        penalty_first_bytes: psi_dim.saturating_mul(dense_penalty_bytes),
295        penalty_pair_bytes: penalty_pair_count.saturating_mul(dense_penalty_bytes),
296        rho_tau_penalty_bytes: first_penalty_component_count
297            .saturating_mul(2)
298            .saturating_mul(dense_penalty_bytes),
299        vector_cache_bytes: psi_dim.saturating_mul(n_obs).saturating_mul(f64_bytes),
300        weighted_scratch_bytes: dense_penalty_bytes,
301    };
302    let any_has_implicit = hyper_dirs
303        .iter()
304        .any(DirectionalHyperParam::has_implicit_operator);
305    let implicit_multidim_duchon = hyper_dirs
306        .iter()
307        .any(DirectionalHyperParam::has_implicit_multidim_duchon);
308    let estimated_dense_tau_cache_bytes = hessian_plan
309        .first_order_tau_bytes
310        .saturating_add(hessian_plan.second_order_tau_bytes);
311    TauTauHessianPolicy {
312        any_has_implicit,
313        implicit_multidim_duchon,
314        estimated_dense_tau_cache_bytes,
315        gradient_plan,
316        hessian_plan,
317        budget_bytes: EXACT_TAU_TAU_HESSIAN_DENSE_CACHE_BUDGET_BYTES,
318        firth_pair_terms_unavailable: firth_pair_terms_unavailable && !hyper_dirs.is_empty(),
319    }
320}
321
322pub(crate) fn firth_problem_scale_allows(n_obs: usize, p_coeff: usize) -> bool {
323    let linear_work = n_obs.saturating_mul(p_coeff);
324    let quadratic_work = linear_work.saturating_mul(p_coeff);
325    n_obs <= FIRTH_MAX_OBSERVATIONS
326        && p_coeff <= FIRTH_MAX_COEFFICIENTS
327        && linear_work <= FIRTH_MAX_LINEAR_WORK
328        && quadratic_work <= FIRTH_MAX_QUADRATIC_WORK
329}
330
331#[cfg(test)]
332mod tests {
333    use super::{
334        DirectionalHyperParam, EvalCacheManager, EvalShared, HyperDesignDerivative,
335        HyperPenaltyDerivative, ImplicitDerivLevel, RemlConfig, RemlState,
336    };
337    use crate::estimate::EstimationError;
338    use gam_linalg::faer_ndarray::FaerCholesky;
339    use gam_linalg::matrix::symmetrize_in_place;
340    use crate::pirls::PirlsCoordinateFrame;
341    use gam_terms::basis::{ImplicitDesignPsiDerivative, RadialScalarKind};
342    use gam_problem::{
343        GlmLikelihoodSpec, InverseLink, LikelihoodSpec, ResponseFamily, StandardLink,
344    };
345    use faer::Side;
346    use gam_problem::{HessianResult, OuterEval};
347    use ndarray::{Array1, Array2, array, s};
348    use std::sync::Arc;
349
350    /// Shorthand for the canonical Binomial-Logit `GlmLikelihoodSpec` used by
351    /// the REML test fixtures.
352    pub(crate) fn binomial_logit_glm_spec() -> GlmLikelihoodSpec {
353        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
354            ResponseFamily::Binomial,
355            InverseLink::Standard(StandardLink::Logit),
356        ))
357    }
358
359    /// Shorthand for the canonical Gaussian-Identity `GlmLikelihoodSpec` used
360    /// by the REML test fixtures.
361    pub(crate) fn gaussian_identity_glm_spec() -> GlmLikelihoodSpec {
362        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
363            ResponseFamily::Gaussian,
364            InverseLink::Standard(StandardLink::Identity),
365        ))
366    }
367
368    impl DirectionalHyperParam {
369        pub(super) fn new(
370            x_tau_original: Array2<f64>,
371            penalty_first_components: Vec<(usize, Array2<f64>)>,
372            x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
373            penaltysecond_components: Option<Vec<Option<Vec<(usize, Array2<f64>)>>>>,
374        ) -> Result<Self, EstimationError> {
375            let x_tau_tau_original = x_tau_tau_original.map(|rows| {
376                rows.into_iter()
377                    .map(|entry| entry.map(HyperDesignDerivative::from))
378                    .collect::<Vec<_>>()
379            });
380            let penalty_first_components = penalty_first_components
381                .into_iter()
382                .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
383                .collect();
384            let penaltysecond_components = penaltysecond_components.map(|rows| {
385                rows.into_iter()
386                    .map(|row| {
387                        row.map(|components| {
388                            components
389                                .into_iter()
390                                .map(|(idx, matrix)| (idx, HyperPenaltyDerivative::from(matrix)))
391                                .collect::<Vec<_>>()
392                        })
393                    })
394                    .collect::<Vec<_>>()
395            });
396            Self::new_compact(
397                HyperDesignDerivative::from(x_tau_original),
398                penalty_first_components,
399                x_tau_tau_original,
400                penaltysecond_components,
401            )
402        }
403
404        pub(super) fn single_penalty(
405            penalty_index: usize,
406            x_tau_original: Array2<f64>,
407            s_tau_original: Array2<f64>,
408            x_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
409            s_tau_tau_original: Option<Vec<Option<Array2<f64>>>>,
410        ) -> Result<Self, EstimationError> {
411            let penaltysecond_components = s_tau_tau_original.map(|rows| {
412                rows.into_iter()
413                    .map(|mat| mat.map(|mat| vec![(penalty_index, mat)]))
414                    .collect::<Vec<_>>()
415            });
416            Self::new(
417                x_tau_original,
418                vec![(penalty_index, s_tau_original)],
419                x_tau_tau_original,
420                penaltysecond_components,
421            )
422        }
423    }
424
425    #[test]
426    pub(crate) fn firth_problem_scale_gate_blocks_large_quadratic_work() {
427        assert!(super::firth_problem_scale_allows(2_000, 200));
428        assert!(!super::firth_problem_scale_allows(4_800, 241));
429        assert!(!super::firth_problem_scale_allows(4_800, 433));
430    }
431
432    #[test]
433    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_implicit_tau() {
434        let operator = ImplicitDesignPsiDerivative::new(
435            array![1.0, 2.0, 3.0, 4.0],
436            array![0.5, -1.0, 1.5, 2.0],
437            array![0.1, 0.2, 0.3, 0.4],
438            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
439            None,
440            None,
441            2,
442            2,
443            1,
444            2,
445        );
446        let dir = DirectionalHyperParam::new_compact(
447            HyperDesignDerivative::from_implicit(
448                Arc::new(operator),
449                ImplicitDerivLevel::First(0),
450                1..4,
451                5,
452            ),
453            Vec::new(),
454            None,
455            None,
456        )
457        .expect("implicit directional hyperparam");
458        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
459        assert!(policy.any_has_implicit);
460        assert_eq!(
461            policy.gradient_plan.dense_x_bytes,
462            10 * 5 * std::mem::size_of::<f64>()
463        );
464        assert!(!policy.prefer_gradient_only());
465    }
466
467    #[test]
468    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_for_implicit_multidim_duchon()
469    {
470        // Multi-dim Duchon implicit storage used to force gradient-only,
471        // because the τ-cache materialization plan was infeasible.  The
472        // unified evaluator now elects the matrix-free
473        // `HessianResult::Operator` representation in this regime via
474        // `prefer_outer_hessian_operator`, so the planner can route to the
475        // operator trust-region (or basis-probe to dense for small K) — the
476        // capability is preserved and gradient-only must NOT engage.
477        let operator = ImplicitDesignPsiDerivative::new_streaming(
478            Arc::new(array![[0.0, 0.0], [1.0, 0.2]]),
479            Arc::new(array![[0.0, 0.0], [1.0, 1.0]]),
480            vec![0.0, 0.0],
481            RadialScalarKind::PureDuchon {
482                block_order: 1,
483                p_order: 0,
484                s_order: 0,
485                dim: 2,
486            },
487            None,
488            None,
489            0,
490        );
491        let dir = DirectionalHyperParam::new_compact(
492            HyperDesignDerivative::from_implicit(
493                Arc::new(operator),
494                ImplicitDerivLevel::First(0),
495                0..2,
496                2,
497            ),
498            Vec::new(),
499            None,
500            None,
501        )
502        .expect("implicit duchon directional hyperparam");
503        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], false);
504        assert!(policy.any_has_implicit);
505        assert!(policy.implicit_multidim_duchon);
506        assert!(!policy.prefer_gradient_only());
507    }
508
509    #[test]
510    pub(crate) fn tau_tau_hessian_policy_does_not_force_gradient_only_when_cache_budget_is_exceeded()
511     {
512        // The dense τ-cache plan exceeds the budget, but cost is no longer a
513        // capability gate: the eval-side selects the matrix-free operator
514        // representation in exactly this regime, and the planner routes
515        // accordingly.  `prefer_gradient_only` must NOT force `Unavailable`
516        // here.
517        let dirs = (0..16)
518            .map(|_| {
519                DirectionalHyperParam::new_compact(
520                    HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
521                    Vec::new(),
522                    None,
523                    None,
524                )
525                .expect("dense directional hyperparam")
526            })
527            .collect::<Vec<_>>();
528        let policy = super::exact_tau_tau_hessian_policy_with_firth(320_000, 71, &dirs, false);
529        assert!(!policy.any_has_implicit);
530        assert!(policy.hessian_plan.total_bytes() > policy.budget_bytes);
531        assert!(policy.hessian_plan.total_bytes() > policy.gradient_plan.total_bytes());
532        assert!(!policy.prefer_gradient_only());
533    }
534
535    #[test]
536    pub(crate) fn tau_tau_hessian_policy_prefers_gradient_only_for_firth_pair_gap() {
537        let dir = DirectionalHyperParam::new_compact(
538            HyperDesignDerivative::from(Array2::<f64>::zeros((2, 2))),
539            Vec::new(),
540            None,
541            None,
542        )
543        .expect("dense directional hyperparam");
544        let policy = super::exact_tau_tau_hessian_policy_with_firth(10, 5, &[dir], true);
545        assert!(policy.firth_pair_terms_unavailable);
546        assert!(policy.prefer_gradient_only());
547    }
548
549    /// Common shape for the design-motion + penalty-motion REML test fixtures
550    /// (Gaussian-identity and binomial-logit at present): both carry the
551    /// same `(y, w, X, S0, cfg, ρ)` plus a perturbation pair, and need the
552    /// same three helpers (`state`, `state_perturbed`, `fd_directional_gradient`).
553    /// Per-fixture `new()` constructors fill the fields with family-specific
554    /// data; the helpers below are shared via default impls so every fixture
555    /// pays the boilerplate once.
556    trait LogitDesignMotionFixture {
557        fn y(&self) -> &Array1<f64>;
558        fn w(&self) -> &Array1<f64>;
559        fn x(&self) -> &Array2<f64>;
560        fn s0(&self) -> &Array2<f64>;
561        fn cfg(&self) -> &RemlConfig;
562        fn rho(&self) -> &Array1<f64>;
563
564        fn state(&self) -> RemlState<'_> {
565            build_logit_state(self.y(), self.w(), self.x(), self.s0(), self.cfg())
566        }
567
568        fn state_perturbed(
569            &self,
570            x_tau: &Array2<f64>,
571            s_tau: &Array2<f64>,
572            eps: f64,
573        ) -> (RemlState<'_>, RemlState<'_>) {
574            let x_plus = self.x() + &x_tau.mapv(|v| eps * v);
575            let x_minus = self.x() - &x_tau.mapv(|v| eps * v);
576            let s_plus = self.s0() + &s_tau.mapv(|v| eps * v);
577            let s_minus = self.s0() - &s_tau.mapv(|v| eps * v);
578            (
579                build_logit_state(self.y(), self.w(), &x_plus, &s_plus, self.cfg()),
580                build_logit_state(self.y(), self.w(), &x_minus, &s_minus, self.cfg()),
581            )
582        }
583
584        /// Central FD approximation to the directional cost derivative at ρ.
585        fn fd_directional_gradient(&self, x_tau: &Array2<f64>, s_tau: &Array2<f64>) -> f64 {
586            let h = 2e-5;
587            let (state_plus, state_minus) = self.state_perturbed(x_tau, s_tau, h);
588            let v_plus = state_plus.compute_cost(self.rho()).expect("cost+");
589            let v_minus = state_minus.compute_cost(self.rho()).expect("cost-");
590            (v_plus - v_minus) / (2.0 * h)
591        }
592    }
593
594    pub(crate) fn build_logit_state<'a>(
595        y: &'a Array1<f64>,
596        w: &'a Array1<f64>,
597        x: &Array2<f64>,
598        s: &Array2<f64>,
599        cfg: &'a RemlConfig,
600    ) -> RemlState<'a> {
601        use crate::estimate::PenaltySpec;
602        let p = x.ncols();
603        let offset = Array1::<f64>::zeros(y.len());
604        let spec = PenaltySpec::Dense(s.clone());
605        let canonical = gam_terms::construction::canonicalize_penalty_specs(&[spec], &[1], p, "test")
606            .map(|(canonical, _)| canonical)
607            .expect("canonicalize");
608        RemlState::newwith_offset(
609            y.view(),
610            x.clone(),
611            w.view(),
612            offset.view(),
613            canonical,
614            p,
615            cfg,
616            Some(vec![1]),
617            None,
618            None,
619        )
620        .expect("state")
621    }
622
623    #[test]
624    fn repeated_penalty_ranges_keep_analytic_outer_hessian() {
625        let y = array![0.2, -0.1, 0.3, 0.0];
626        let w = Array1::<f64>::ones(y.len());
627        let x = array![[1.0, -0.7], [1.0, -0.2], [1.0, 0.3], [1.0, 0.9]];
628        let offset = Array1::<f64>::zeros(y.len());
629        let cfg = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
630        let p = x.ncols();
631        let canonical = vec![
632            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0]], p),
633            gam_terms::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0]], p),
634        ];
635        let state = RemlState::newwith_offset(
636            y.view(),
637            x,
638            w.view(),
639            offset.view(),
640            canonical,
641            p,
642            &cfg,
643            Some(vec![1, 1]),
644            None,
645            None,
646        )
647        .expect("state");
648
649        assert!(
650            state.analytic_outer_hessian_enabled(),
651            "double-penalty-style repeated coefficient ranges must still route to exact Hessian"
652        );
653    }
654
655    pub(crate) fn poisson_log_glm_spec() -> GlmLikelihoodSpec {
656        GlmLikelihoodSpec::canonical(LikelihoodSpec::new(
657            ResponseFamily::Poisson,
658            InverseLink::Standard(StandardLink::Log),
659        ))
660    }
661
662    /// Regression (issue #893): for a fixed-dispersion family a uniform prior
663    /// weight `w = c` is *exact* `c`-fold row replication. The two encodings must
664    /// therefore present a byte-identical LAML smoothing-selection surface — both
665    /// the cost `V(ρ)` and its gradient `∇V(ρ)` — because every term (penalised
666    /// deviance `D_p`, the working cross-product `XᵀWX`, the log-determinants)
667    /// is a sum of per-observation contributions that is identical whether a row
668    /// carries weight `c` or is stacked `c` times. This locks the *surface*
669    /// invariant that #893 ultimately reduces to: when the surfaces coincide,
670    /// the only remaining requirement for `λ̂(w=c) = λ̂(c×)` is that the outer
671    /// optimiser resolve the shared optimum (handled by the tightened outer
672    /// tolerance in `workflow.rs`). A regression that reintroduced a
673    /// row-count-vs-weight-sum asymmetry into the inner solve or the cost would
674    /// break this directly, independent of the optimiser tolerance.
675    #[test]
676    pub(crate) fn fixed_dispersion_laml_surface_is_replication_invariant() {
677        let n = 200usize;
678        let p = 8usize;
679        let c = 3usize;
680        let mut x = Array2::<f64>::zeros((n, p));
681        let mut y = Array1::<f64>::zeros(n);
682        for i in 0..n {
683            let t = (i as f64) / ((n - 1) as f64);
684            let tau = std::f64::consts::TAU;
685            x[[i, 0]] = 1.0;
686            x[[i, 1]] = t;
687            x[[i, 2]] = (tau * t).sin();
688            x[[i, 3]] = (tau * t).cos();
689            x[[i, 4]] = (2.0 * tau * t).sin();
690            x[[i, 5]] = (2.0 * tau * t).cos();
691            x[[i, 6]] = (3.0 * tau * t).sin();
692            x[[i, 7]] = (3.0 * tau * t).cos();
693            let eta = 0.3 + 0.9 * (1.4 * (t - 0.5)).sin();
694            // Deterministic non-negative integer counts near exp(eta).
695            y[i] = (eta.exp() + 0.5 * ((i as f64) * 2.399_963).sin())
696                .round()
697                .max(0.0);
698        }
699        let mut s = Array2::<f64>::zeros((p, p));
700        for j in 1..p {
701            s[[j, j]] = 1.0;
702        }
703
704        // Replicated design (c literal copies of each row).
705        let mut x_rep = Array2::<f64>::zeros((n * c, p));
706        let mut y_rep = Array1::<f64>::zeros(n * c);
707        for r in 0..c {
708            for i in 0..n {
709                let row = r * n + i;
710                for j in 0..p {
711                    x_rep[[row, j]] = x[[i, j]];
712                }
713                y_rep[row] = y[i];
714            }
715        }
716
717        let w_weighted = Array1::<f64>::from_elem(n, c as f64);
718        let w_rep = Array1::<f64>::ones(n * c);
719
720        let cfg = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
721        let st_w = build_logit_state(&y, &w_weighted, &x, &s, &cfg);
722        let st_r = build_logit_state(&y_rep, &w_rep, &x_rep, &s, &cfg);
723
724        for &rho in &[-2.0_f64, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0] {
725            let r = Array1::from_elem(1, rho);
726            let cw = st_w.compute_cost(&r).expect("weighted cost");
727            let cr = st_r.compute_cost(&r).expect("replicated cost");
728            let gw = st_w.compute_gradient(&r).expect("weighted grad");
729            let gr = st_r.compute_gradient(&r).expect("replicated grad");
730            // Costs and gradients must coincide to optimiser precision; the only
731            // admissible difference is f64 summation order over n vs c·n rows.
732            assert!(
733                (cw - cr).abs() <= 1e-9 * (1.0 + cw.abs()),
734                "LAML cost differs between w=c and c× replication at rho={rho}: \
735                 cost_w={cw:.12e} cost_r={cr:.12e} diff={:.3e}",
736                cw - cr
737            );
738            assert!(
739                (gw[0] - gr[0]).abs() <= 1e-9 * (1.0 + gw[0].abs()),
740                "LAML gradient differs between w=c and c× replication at rho={rho}: \
741                 g_w={:.12e} g_r={:.12e} diff={:.3e}",
742                gw[0],
743                gr[0],
744                gw[0] - gr[0]
745            );
746        }
747    }
748
749    /// Regression (issue #893): the geometric-mean log-weight ρ-anchor
750    /// ([`RemlState::rho_weight_anchor`]) is a *profiled*-dispersion construct
751    /// (issue #877). For a fixed-dispersion family the optimum does not slide by
752    /// `log c` under a weight rescale in a way the prior should track, and a
753    /// nonzero anchor would evaluate the regularising ρ-prior at *different*
754    /// coordinates for the `w=c` vs `c×` encodings — breaking the very
755    /// equivalence #893 requires. The anchor must therefore be exactly `0` for a
756    /// fixed-dispersion family and the geometric mean for Gaussian-identity.
757    #[test]
758    pub(crate) fn rho_weight_anchor_is_zero_for_fixed_dispersion() {
759        let n = 50usize;
760        let p = 3usize;
761        let mut x = Array2::<f64>::zeros((n, p));
762        let mut y = Array1::<f64>::zeros(n);
763        for i in 0..n {
764            let t = (i as f64) / ((n - 1) as f64);
765            x[[i, 0]] = 1.0;
766            x[[i, 1]] = t;
767            x[[i, 2]] = t * t;
768            y[i] = (1.0 + (3.0 * t).sin()).round().max(0.0);
769        }
770        let mut s = Array2::<f64>::zeros((p, p));
771        s[[2, 2]] = 1.0;
772        // All weights = c: geometric-mean log-weight = ln(c) ≠ 0.
773        let c = 4.0_f64;
774        let w = Array1::<f64>::from_elem(n, c);
775
776        let cfg_pois = RemlConfig::external(poisson_log_glm_spec(), 1e-10, false);
777        let st_pois = build_logit_state(&y, &w, &x, &s, &cfg_pois);
778        assert_eq!(
779            st_pois.rho_weight_anchor(),
780            0.0,
781            "fixed-dispersion (Poisson) anchor must be 0, not the geometric-mean log-weight"
782        );
783
784        let cfg_gauss = RemlConfig::external(gaussian_identity_glm_spec(), 1e-10, false);
785        let st_gauss = build_logit_state(&y, &w, &x, &s, &cfg_gauss);
786        assert!(
787            (st_gauss.rho_weight_anchor() - c.ln()).abs() <= 1e-12,
788            "Gaussian-identity (profiled) anchor must be the geometric-mean log-weight ln(c)={:.6}, got {:.6}",
789            c.ln(),
790            st_gauss.rho_weight_anchor()
791        );
792    }
793
794    pub(crate) fn beta_original_from_bundle(bundle: &EvalShared) -> Array1<f64> {
795        let pr = bundle.pirls_result.as_ref();
796        match pr.coordinate_frame {
797            PirlsCoordinateFrame::OriginalSparseNative => pr.beta_transformed.as_ref().clone(),
798            PirlsCoordinateFrame::TransformedQs => {
799                pr.reparam_result.qs.dot(pr.beta_transformed.as_ref())
800            }
801        }
802    }
803
804    pub(crate) fn compute_joint_hypercostgradienthessian(
805        state: &RemlState<'_>,
806        theta: &Array1<f64>,
807        rho_dim: usize,
808        hyper_dirs: &[DirectionalHyperParam],
809    ) -> Result<(f64, Array1<f64>, Array2<f64>), EstimationError> {
810        let (cost, gradient, hessian) = state.compute_joint_hyper_eval_with_order(
811            theta,
812            rho_dim,
813            hyper_dirs,
814            crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
815        )?;
816        Ok((
817            cost,
818            gradient,
819            hessian
820                .materialize_dense()
821                .map_err(EstimationError::RemlOptimizationFailed)?
822                .ok_or_else(|| {
823                    EstimationError::RemlOptimizationFailed(
824                        "joint hyper Hessian requested but unavailable".to_string(),
825                    )
826                })?,
827        ))
828    }
829
830    pub(crate) fn h_original_from_bundle(bundle: &EvalShared) -> Array2<f64> {
831        let pr = bundle.pirls_result.as_ref();
832        match pr.coordinate_frame {
833            PirlsCoordinateFrame::OriginalSparseNative => bundle.h_total.as_ref().clone(),
834            PirlsCoordinateFrame::TransformedQs => {
835                let qs = &pr.reparam_result.qs;
836                let tmp = gam_linalg::faer_ndarray::fast_ab(qs, bundle.h_total.as_ref());
837                gam_linalg::faer_ndarray::fast_abt(&tmp, qs)
838            }
839        }
840    }
841
842    pub(crate) fn single_directional_tau_gradient(
843        state: &RemlState<'_>,
844        rho: &Array1<f64>,
845        hyper: DirectionalHyperParam,
846    ) -> Result<f64, EstimationError> {
847        let mut theta = Array1::<f64>::zeros(rho.len() + 1);
848        theta.slice_mut(s![..rho.len()]).assign(rho);
849        let (_, gradient, _) = state.compute_joint_hyper_eval_with_order(
850            &theta,
851            rho.len(),
852            &[hyper],
853            crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
854        )?;
855        Ok(gradient[rho.len()])
856    }
857
858    pub(crate) fn fd_directional_tau_cost_gradient(
859        y: &Array1<f64>,
860        w: &Array1<f64>,
861        x: &Array2<f64>,
862        s0: &Array2<f64>,
863        cfg: &RemlConfig,
864        rho: &Array1<f64>,
865        x_tau: &Array2<f64>,
866        s_tau: &Array2<f64>,
867    ) -> f64 {
868        let h = 2e-5;
869        let x_plus = x + &x_tau.mapv(|v| h * v);
870        let x_minus = x - &x_tau.mapv(|v| h * v);
871        let s_plus = s0 + &s_tau.mapv(|v| h * v);
872        let s_minus = s0 - &s_tau.mapv(|v| h * v);
873        let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
874        let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
875        let v_plus = state_plus.compute_cost(rho).expect("cost+");
876        let v_minus = state_minus.compute_cost(rho).expect("cost-");
877        (v_plus - v_minus) / (2.0 * h)
878    }
879
880    pub(crate) fn directional_tau_hessian_fd_reference(
881        y: &Array1<f64>,
882        w: &Array1<f64>,
883        x: &Array2<f64>,
884        s0: &Array2<f64>,
885        cfg: &RemlConfig,
886        rho: &Array1<f64>,
887        hyper_dirs: &[DirectionalHyperParam],
888        x_tau_mats: &[Array2<f64>],
889        s_tau_mats: &[Array2<f64>],
890    ) -> Array2<f64> {
891        assert_eq!(hyper_dirs.len(), x_tau_mats.len());
892        assert_eq!(hyper_dirs.len(), s_tau_mats.len());
893
894        const TARGET_PHYSICAL_STEP: f64 = 1e-5;
895
896        let n_dirs = hyper_dirs.len();
897        let mut h_ttfd = Array2::<f64>::zeros((n_dirs, n_dirs));
898        for j in 0..n_dirs {
899            let direction_scale = x_tau_mats[j]
900                .iter()
901                .chain(s_tau_mats[j].iter())
902                .fold(0.0_f64, |acc, value| acc.max(value.abs()));
903            let h = if direction_scale > 0.0 {
904                TARGET_PHYSICAL_STEP / direction_scale
905            } else {
906                TARGET_PHYSICAL_STEP
907            };
908
909            let x_plus = x + &x_tau_mats[j].mapv(|v| h * v);
910            let x_minus = x - &x_tau_mats[j].mapv(|v| h * v);
911            let s_plus = s0 + &s_tau_mats[j].mapv(|v| h * v);
912            let s_minus = s0 - &s_tau_mats[j].mapv(|v| h * v);
913
914            let state_plus = build_logit_state(y, w, &x_plus, &s_plus, cfg);
915            let state_minus = build_logit_state(y, w, &x_minus, &s_minus, cfg);
916            for i in 0..n_dirs {
917                let g_plus =
918                    single_directional_tau_gradient(&state_plus, rho, hyper_dirs[i].clone())
919                        .expect("g+ for FD");
920                let g_minus =
921                    single_directional_tau_gradient(&state_minus, rho, hyper_dirs[i].clone())
922                        .expect("g- for FD");
923                h_ttfd[[i, j]] = (g_plus - g_minus) / (2.0 * h);
924            }
925        }
926        symmetrize_in_place(&mut h_ttfd);
927        h_ttfd
928    }
929
930    #[test]
931    pub(crate) fn eval_cache_manager_stores_first_order_outer_eval() {
932        let cache = EvalCacheManager::new();
933        let rho = array![0.25, -0.0];
934        let rho_key = EvalCacheManager::sanitized_rhokey(&rho);
935        let eval = OuterEval {
936            cost: 3.5,
937            gradient: array![1.0, -2.0],
938            hessian: HessianResult::Unavailable,
939            inner_beta_hint: None,
940        };
941
942        cache.store_outer_eval(&rho_key, &eval);
943
944        let cached = cache
945            .cached_outer_eval(&rho_key)
946            .expect("first-order outer eval should be cached");
947        assert_eq!(cached.cost, eval.cost);
948        assert_eq!(cached.gradient, eval.gradient);
949        assert!(matches!(cached.hessian, HessianResult::Unavailable));
950
951        cache.invalidate_eval_bundle();
952        assert!(
953            cache.cached_outer_eval(&rho_key).is_none(),
954            "invalidating the bundle should clear the outer-eval cache too"
955        );
956    }
957
958    #[test]
959    pub(crate) fn reset_outer_seed_state_clears_pirls_cache() {
960        // Build a minimal logit RemlState, populate the cross-call PIRLS LRU
961        // by evaluating the outer objective at one rho, then verify that
962        // reset_outer_seed_state wipes that LRU (alongside the eval bundle
963        // and warm-start signals). This pins down the cross-attempt
964        // cleanup contract that a budget-bump retry relies on.
965        let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
966        let w = Array1::<f64>::ones(y.len());
967        let x = array![
968            [1.0, -1.0, 0.2],
969            [1.0, -0.5, -0.4],
970            [1.0, 0.0, 0.7],
971            [1.0, 0.4, -0.3],
972            [1.0, 0.9, 0.1],
973            [1.0, 1.3, -0.6],
974        ];
975        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
976        let rho = array![0.0];
977        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
978        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
979
980        // Trigger a full outer eval so execute_pirls_if_needed inserts at
981        // least one entry into the cross-call PIRLS LRU.
982        state
983            .compute_outer_eval_with_order(
984                &rho,
985                crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
986            )
987            .expect("outer eval should succeed");
988
989        let populated_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
990        assert!(
991            populated_len > 0,
992            "evaluating the outer objective should populate the PIRLS LRU, got {populated_len}"
993        );
994
995        state.reset_outer_seed_state();
996
997        let cleared_len = state.cache_manager.pirls_cache.read().unwrap().map.len();
998        assert_eq!(
999            cleared_len, 0,
1000            "reset_outer_seed_state must clear the cross-call PIRLS LRU; got {cleared_len} entries"
1001        );
1002    }
1003
1004    #[test]
1005    pub(crate) fn reset_outer_seed_state_preserves_frozen_negbin_theta_1448() {
1006        // #1448 regression: the NB outer θ↔λ alternation loop
1007        // (solver/estimate/optimizer.rs) re-runs the ρ search after each θ
1008        // refresh by (a) re-freezing the λ-search θ at θ_final into
1009        // `frozen_negbin_theta`, then (b) calling `reset_outer_seed_state()` to
1010        // drop the caches keyed to the old θ. Step (b) MUST NOT clear the freeze
1011        // set in step (a): the capture in `solve_for_unified_rho` only writes the
1012        // frozen slot when it is 0, so if the reset zeroed it the next round would
1013        // re-derive θ from the seed η and the loop would never reach the (ρ, θ)
1014        // joint fixed point — silently regressing #1448 back to a single
1015        // freeze→refresh pass.
1016        //
1017        // This pins the load-bearing distinction between `reset_outer_seed_state`
1018        // (alternation-round reset, freeze SURVIVES) and the surface-refresh reset
1019        // (new design, freeze re-zeroed). The end-to-end convergence on a real NB
1020        // fit is exercised by the public-API path; here we lock the invariant the
1021        // loop depends on, next to `reset_outer_seed_state_clears_pirls_cache`.
1022        use std::sync::atomic::Ordering;
1023
1024        let y = array![0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
1025        let w = Array1::<f64>::ones(y.len());
1026        let x = array![
1027            [1.0, -1.0, 0.2],
1028            [1.0, -0.5, -0.4],
1029            [1.0, 0.0, 0.7],
1030            [1.0, 0.4, -0.3],
1031            [1.0, 0.9, 0.1],
1032            [1.0, 1.3, -0.6],
1033        ];
1034        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.1, 0.15], [0.0, 0.15, 0.8],];
1035        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false);
1036        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1037
1038        // Simulate the alternation loop's re-freeze step: pin θ_final.
1039        let theta_final_bits = 2.5_f64.to_bits();
1040        state
1041            .frozen_negbin_theta
1042            .store(theta_final_bits, Ordering::Relaxed);
1043        assert_eq!(
1044            state.frozen_negbin_theta.load(Ordering::Relaxed),
1045            theta_final_bits,
1046            "precondition: the re-freeze stores θ_final into the frozen slot"
1047        );
1048
1049        // The alternation loop's per-round reset.
1050        state.reset_outer_seed_state();
1051
1052        assert_eq!(
1053            state.frozen_negbin_theta.load(Ordering::Relaxed),
1054            theta_final_bits,
1055            "reset_outer_seed_state (alternation-round reset) must PRESERVE the \
1056             re-frozen NB θ; clearing it would defeat the #1448 θ↔λ alternation \
1057             (the next ρ search would re-derive θ from the seed and never reach \
1058             the joint fixed point)"
1059        );
1060    }
1061
    /// An implicit hyper-design derivative built with
    /// `HyperDesignDerivative::from_implicit` must agree exactly with the
    /// explicit embedding of its materialised local block built with
    /// `HyperDesignDerivative::from_embedded`. Both are placed in columns `1..4`
    /// of a 5-column model.
    ///
    /// Agreement is checked with exact `assert_eq!` for shape, materialisation,
    /// forward and transpose products, and the `Qs`-transformed variant of each.
    #[test]
    pub(crate) fn implicit_hyper_design_derivative_respects_full_model_embedding() {
        // Small operator fixture. `ImplicitDesignPsiDerivative::new` (not visible
        // here) interprets the positional arguments; all this test relies on is
        // that its local first derivative has 3 columns, as asserted below.
        let operator = ImplicitDesignPsiDerivative::new(
            array![1.0, 2.0, 3.0, 4.0],
            array![0.5, -1.0, 1.5, 2.0],
            array![0.1, 0.2, 0.3, 0.4],
            array![[1.0, 0.2], [0.5, 0.1], [1.5, 0.3], [2.0, 0.4]],
            None,
            None,
            2,
            2,
            1,
            2,
        );
        // Dense first derivative at index 0, in operator-local columns.
        let local = operator
            .materialize_first(0)
            .expect("materialized first derivative");
        assert_eq!(
            local.ncols(),
            3,
            "operator-local derivative should stay smooth-local"
        );

        // The same derivative built two ways: lazily through the operator, and
        // as the dense local block. Both are embedded at columns 1..4 of a
        // 5-column model.
        let implicit = HyperDesignDerivative::from_implicit(
            Arc::new(operator),
            ImplicitDerivLevel::First(0),
            1..4,
            5,
        );
        let embedded = HyperDesignDerivative::from_embedded(local.clone(), 1..4, 5);

        // Shape and full materialisation must match.
        assert_eq!(implicit.nrows(), embedded.nrows());
        assert_eq!(implicit.ncols(), 5);
        assert_eq!(implicit.materialize(), embedded.materialize());

        // `u` spans all 5 model columns (forward product). `v` (length 2) is
        // applied on the row side (transpose product).
        let u = array![7.0, 1.5, -2.0, 0.25, -3.0];
        let v = array![0.75, -1.25];
        assert_eq!(
            implicit.forward_mul_original(&u).expect("implicit forward"),
            embedded.forward_mul_original(&u).expect("embedded forward")
        );
        assert_eq!(
            implicit
                .transpose_mul_original(&v)
                .expect("implicit transpose"),
            embedded
                .transpose_mul_original(&v)
                .expect("embedded transpose")
        );

        // A 5×3 Qs-style map into a 3-dimensional transformed frame. The
        // transformed materialisation and products must agree as well; the
        // optional second argument is left as `None` throughout.
        let qs = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.5, 0.5],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0],
        ];
        assert_eq!(
            implicit
                .transformed(&qs, None)
                .expect("implicit transformed"),
            embedded
                .transformed(&qs, None)
                .expect("embedded transformed")
        );
        let u_transformed = array![1.0, -0.5, 2.0];
        assert_eq!(
            implicit
                .transformed_forward_mul(&qs, None, &u_transformed)
                .expect("implicit transformed forward"),
            embedded
                .transformed_forward_mul(&qs, None, &u_transformed)
                .expect("embedded transformed forward")
        );
        assert_eq!(
            implicit
                .transformed_transpose_mul(&qs, None, &v)
                .expect("implicit transformed transpose"),
            embedded
                .transformed_transpose_mul(&qs, None, &v)
                .expect("embedded transformed transpose")
        );
    }
1145
1146    #[test]
1147    pub(crate) fn directional_hyper_identities_match_finite_differences_logit() {
1148        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1149        let w = Array1::<f64>::ones(y.len());
1150        let x = array![
1151            [1.0, -1.2, 0.3],
1152            [1.0, -0.8, -0.4],
1153            [1.0, -0.3, 0.7],
1154            [1.0, 0.1, -0.9],
1155            [1.0, 0.5, 0.2],
1156            [1.0, 0.9, -0.1],
1157            [1.0, 1.3, 0.8],
1158            [1.0, 1.7, -0.6],
1159        ];
1160        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1161
1162        // Use one directional hyperparameter τ with a penalty perturbation:
1163        // S(τ) = S + τ S_τ.
1164        // Keep X_τ = 0 so this identity test remains valid in both non-Firth
1165        // and Firth-logit modes.
1166        let x_tau = Array2::<f64>::zeros(x.raw_dim());
1167        let s_tau = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15],];
1168        let hyper =
1169            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
1170                .expect("single-penalty hyper direction");
1171        let rho = array![0.0];
1172
1173        // Tight inner tolerance: the envelope theorem requires an exact inner
1174        // P-IRLS optimum; 1e-10 leaves enough residual gradient to cause ~12%
1175        // V_tau mismatch on this small (n=8) logistic problem.
1176        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
1177        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1178        let bundle = state.obtain_eval_bundle(&rho).expect("bundle");
1179        let pr = bundle.pirls_result.as_ref();
1180
1181        let beta = beta_original_from_bundle(&bundle);
1182        let h_orig = h_original_from_bundle(&bundle);
1183        let u = &pr.solveweights * &(&pr.solveworking_response - &pr.final_eta);
1184
1185        // B from implicit solve:
1186        //   H B = X_τ^T g - X^T W(X_τ β̂) - S_τ β̂.
1187        let x_tau_beta = gam_linalg::faer_ndarray::fast_av(&x_tau, &beta);
1188        let weighted_x_tau_beta = &pr.finalweights * &x_tau_beta;
1189        let rhs = gam_linalg::faer_ndarray::fast_atv(&x_tau, &u)
1190            - gam_linalg::faer_ndarray::fast_atv(&x, &weighted_x_tau_beta)
1191            - s_tau.dot(&beta);
1192        let chol = h_orig.cholesky(Side::Lower).expect("chol(H)");
1193        let b_analytic = chol.solvevec(&rhs);
1194
1195        // H_τ from exact total derivative:
1196        //   H_τ = X_τ^T W X + X^T W X_τ + X^T W_τ X + S_τ,
1197        // with W_τ provided by the family directional curvature callback.
1198        let eta_dot = &x_tau_beta + &gam_linalg::faer_ndarray::fast_av(&x, &b_analytic);
1199        let w_direction = crate::pirls::directionalworking_curvature_from_c_array(
1200            &pr.solve_c_array,
1201            &pr.finalweights,
1202            &eta_dot,
1203        );
1204        let wx = RemlState::row_scale(&x, &pr.finalweights);
1205        let wx_tau = RemlState::row_scale(&x_tau, &pr.finalweights);
1206        let mut xwtau_x = x.clone();
1207        match w_direction {
1208            crate::pirls::DirectionalWorkingCurvature::Diagonal(diag) => {
1209                xwtau_x = RemlState::row_scale(&xwtau_x, &diag);
1210            }
1211        }
1212        let mut h_tau_analytic = gam_linalg::faer_ndarray::fast_atb(&x_tau, &wx);
1213        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &wx_tau);
1214        h_tau_analytic += &gam_linalg::faer_ndarray::fast_atb(&x, &xwtau_x);
1215        h_tau_analytic += &s_tau;
1216
1217        // Fit-block stationarity cancellation:
1218        //   -ℓ_β^T B + β̂^T S B = 0.
1219        // Here S is the effective penalty in the inner Hessian surface:
1220        //   S = H - X^T W X.
1221        let ell_beta = gam_linalg::faer_ndarray::fast_atv(&x, &u);
1222        let s_eff = &h_orig - &gam_linalg::faer_ndarray::fast_atb(&x, &wx);
1223        let cancellation = -ell_beta.dot(&b_analytic) + beta.dot(&s_eff.dot(&b_analytic));
1224
1225        // Finite differences in τ against re-fit objective and mode.
1226        let h = 2e-5;
1227        let x_plus = &x + &(x_tau.mapv(|v| h * v));
1228        let x_minus = &x - &(x_tau.mapv(|v| h * v));
1229        let s_plus = &s0 + &(s_tau.mapv(|v| h * v));
1230        let s_minus = &s0 - &(s_tau.mapv(|v| h * v));
1231
1232        let state_plus = build_logit_state(&y, &w, &x_plus, &s_plus, &cfg);
1233        let state_minus = build_logit_state(&y, &w, &x_minus, &s_minus, &cfg);
1234        let bundle_plus = state_plus.obtain_eval_bundle(&rho).expect("bundle+");
1235        let bundle_minus = state_minus.obtain_eval_bundle(&rho).expect("bundle-");
1236        let beta_plus = beta_original_from_bundle(&bundle_plus);
1237        let beta_minus = beta_original_from_bundle(&bundle_minus);
1238        let bfd = (&beta_plus - &beta_minus).mapv(|v| v / (2.0 * h));
1239
1240        let h_plus = h_original_from_bundle(&bundle_plus);
1241        let h_minus = h_original_from_bundle(&bundle_minus);
1242        let h_taufd = (&h_plus - &h_minus).mapv(|v| v / (2.0 * h));
1243
1244        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
1245        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
1246        let v_taufd = (v_plus - v_minus) / (2.0 * h);
1247
1248        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper.clone())
1249            .expect("analytic directional gradient");
1250
1251        let b_num = (&b_analytic - &bfd).mapv(|v| v * v).sum().sqrt();
1252        let b_den = bfd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1253        let b_rel = b_num / b_den;
1254        for i in 0..b_analytic.len() {
1255            assert_eq!(
1256                b_analytic[i].signum(),
1257                bfd[i].signum(),
1258                "B sign mismatch at i={i}: analytic={} fd={}",
1259                b_analytic[i],
1260                bfd[i]
1261            );
1262        }
1263        assert!(
1264            b_rel < 2e-2,
1265            "B implicit solve mismatch vs FD: rel={b_rel:.3e}, num={b_num:.3e}, den={b_den:.3e}"
1266        );
1267
1268        let dh_num = (&h_tau_analytic - &h_taufd).mapv(|v| v * v).sum().sqrt();
1269        let dh_den = h_taufd.mapv(|v| v * v).sum().sqrt().max(1e-12);
1270        let dh_rel = dh_num / dh_den;
1271        for i in 0..h_tau_analytic.nrows() {
1272            for j in 0..h_tau_analytic.ncols() {
1273                assert_eq!(
1274                    h_tau_analytic[[i, j]].signum(),
1275                    h_taufd[[i, j]].signum(),
1276                    "H_tau sign mismatch at ({i},{j}): analytic={} fd={}",
1277                    h_tau_analytic[[i, j]],
1278                    h_taufd[[i, j]]
1279                );
1280            }
1281        }
1282        assert!(
1283            dh_rel < 3e-2,
1284            "H_tau mismatch vs FD: rel={dh_rel:.3e}, num={dh_num:.3e}, den={dh_den:.3e}"
1285        );
1286
1287        let v_abs = (v_tau_analytic - v_taufd).abs();
1288        let v_rel = v_abs / v_taufd.abs().max(1e-10);
1289        assert_eq!(
1290            v_tau_analytic.signum(),
1291            v_taufd.signum(),
1292            "V_tau sign mismatch: analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1293        );
1294        assert!(
1295            v_rel < 2e-2,
1296            "V_tau mismatch vs FD: rel={v_rel:.3e}, abs={v_abs:.3e}, analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
1297        );
1298
1299        assert!(
1300            cancellation.abs() < 1e-10,
1301            "stationarity cancellation failed: | -ell_beta^T B + beta^T S B | = {:.3e}",
1302            cancellation.abs()
1303        );
1304    }
1305
    /// Firth-logit on a rank-deficient design, with two smoothing parameters.
    ///
    /// The test checks three things:
    /// * the state reports the analytic outer Hessian as enabled;
    /// * the outer planner returns an analytic Hessian for `ValueGradientHessian`;
    /// * `compute_lamlhessian_exact_from_bundle` yields a finite 2×2 matrix.
    #[test]
    pub(crate) fn firth_exacthessian_includes_analytic_tk_second_derivatives() {
        // Rank-deficient X: the 4th column is 2x the 2nd column.
        let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let w = Array1::<f64>::ones(y.len());
        let x = array![
            [1.0, -1.2, 0.4, -2.4],
            [1.0, -0.9, -0.1, -1.8],
            [1.0, -0.6, 0.3, -1.2],
            [1.0, -0.2, -0.4, -0.4],
            [1.0, 0.1, 0.5, 0.2],
            [1.0, 0.4, -0.6, 0.8],
            [1.0, 0.8, 0.2, 1.6],
            [1.0, 1.1, -0.3, 2.2],
            [1.0, 1.4, 0.7, 2.8],
            [1.0, 1.7, -0.2, 3.4],
        ];
        // Two dense penalties. Both leave column 0 (the intercept) unpenalised.
        let s0 = array![
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 1.5, 0.2, 0.0],
            [0.0, 0.2, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.5],
        ];
        let s1 = array![
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.8, -0.1, 0.0],
            [0.0, -0.1, 0.6, 0.0],
            [0.0, 0.0, 0.0, 0.3],
        ];
        let offset = Array1::<f64>::zeros(y.len());
        // Rank-deficient Firth logit needs more inner iterations to converge
        // tightly enough for the envelope-theorem derivative tests.
        // (The `true` flag is what selects Firth mode here, per this comment and
        // the test name.)
        let cfg =
            RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
        let p = x.ncols();
        use crate::estimate::PenaltySpec;
        // Canonicalise the two penalties. The `&[1, 1]` argument (one entry per
        // penalty) is interpreted by `canonicalize_penalty_specs`.
        let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
        let canonical = gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p, "test")
            .map(|(canonical, _)| canonical)
            .expect("canonicalize");
        let state = RemlState::newwith_offset(
            y.view(),
            x.clone(),
            w.view(),
            offset.view(),
            canonical,
            p,
            &cfg,
            Some(vec![1, 1]),
            None,
            None,
        )
        .expect("state");
        // One log-smoothing parameter per penalty.
        let rho = array![0.1, -0.2];
        assert!(
            state.analytic_outer_hessian_enabled(),
            "Firth logit should no longer disable analytic outer Hessian planning"
        );
        // The planner path must request an analytic Hessian and get one back.
        let outer = state
            .compute_outer_eval_with_order(
                &rho,
                crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
            )
            .expect("outer Hessian eval should succeed");
        assert!(
            outer.hessian.is_analytic(),
            "outer planner should request and return an analytic Hessian"
        );
        // Direct exact-Hessian path from the same ρ's eval bundle: 2×2 and finite.
        let bundle = state.obtain_eval_bundle(&rho).expect("exact firth bundle");
        let h_dense = state
            .compute_lamlhessian_exact_from_bundle(&rho, &bundle)
            .expect("Firth exact Hessian should include analytic TK second derivatives");
        assert_eq!(h_dense.raw_dim(), ndarray::Ix2(2, 2));
        assert!(
            h_dense.iter().all(|value| value.is_finite()),
            "Hessian should be finite: {h_dense:?}"
        );
    }
1384
    /// Firth-logit outer Hessian vs. a central finite difference of the analytic
    /// outer gradient.
    ///
    /// Each entry (row, col) of the dense analytic Hessian at ρ must match
    /// (g(ρ + δ·e_col) − g(ρ − δ·e_col))_row / (2δ), with δ = 2e-5, to a relative
    /// tolerance of 2e-3. The test name indicates it is meant to cover the TK
    /// terms of the Firth Hessian.
    #[test]
    pub(crate) fn firth_outer_hessian_matches_gradient_finite_difference_with_tk_terms() {
        let y = array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
        let w = Array1::<f64>::ones(y.len());
        let x = array![
            [1.0, -1.0, 0.3],
            [1.0, -0.7, -0.2],
            [1.0, -0.3, 0.4],
            [1.0, 0.0, -0.5],
            [1.0, 0.2, 0.6],
            [1.0, 0.6, -0.4],
            [1.0, 0.9, 0.2],
            [1.0, 1.3, -0.1],
        ];
        // Two dense penalties. Both leave column 0 (the intercept) unpenalised.
        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.1], [0.0, 0.1, 0.7],];
        let s1 = array![[0.0, 0.0, 0.0], [0.0, 0.4, -0.05], [0.0, -0.05, 0.9],];
        // Same Firth-logit configuration as the rank-deficient exact-Hessian test
        // (`true` flag, max 500 inner iterations).
        let cfg =
            RemlConfig::external(binomial_logit_glm_spec(), 1e-9, true).with_max_iterations(500);
        let p_dim = x.ncols();
        use crate::estimate::PenaltySpec;
        let specs = vec![PenaltySpec::Dense(s0), PenaltySpec::Dense(s1)];
        let canonical =
            gam_terms::construction::canonicalize_penalty_specs(&specs, &[1, 1], p_dim, "test")
                .map(|(canonical, _)| canonical)
                .expect("canonicalize");
        let offset = Array1::<f64>::zeros(y.len());
        let state = RemlState::newwith_offset(
            y.view(),
            x.clone(),
            w.view(),
            offset.view(),
            canonical,
            p_dim,
            &cfg,
            Some(vec![1, 1]),
            None,
            None,
        )
        .expect("state");
        let rho = array![0.15, -0.25];
        // The analytic Hessian must come back as a dense matrix. Operator or
        // missing Hessians fail the test.
        let eval = state
            .compute_outer_eval_with_order(
                &rho,
                crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian,
            )
            .expect("analytic Hessian eval");
        let h = match eval.hessian {
            HessianResult::Analytic(hessian) => hessian,
            HessianResult::Operator(_) | HessianResult::Unavailable => {
                panic!("expected dense analytic Hessian")
            }
        };
        // Column-by-column central differences of the gradient. The ± evaluations
        // reuse the same `state` (and therefore its caches) in + then − order.
        let delta = 2.0e-5;
        for col in 0..rho.len() {
            let mut rp = rho.clone();
            let mut rm = rho.clone();
            rp[col] += delta;
            rm[col] -= delta;
            let gp = state
                .compute_outer_eval_with_order(
                    &rp,
                    crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
                )
                .expect("plus grad")
                .gradient;
            let gm = state
                .compute_outer_eval_with_order(
                    &rm,
                    crate::rho_optimizer::OuterEvalOrder::ValueAndGradient,
                )
                .expect("minus grad")
                .gradient;
            for row in 0..rho.len() {
                let fd = (gp[row] - gm[row]) / (2.0 * delta);
                let an = h[[row, col]];
                // Relative error, floored at 1e-6 so near-zero entries don't blow up.
                let rel = (fd - an).abs() / fd.abs().max(an.abs()).max(1e-6);
                assert!(
                    rel < 2.0e-3,
                    "Hessian mismatch ({row},{col}): analytic={an:.9e}, fd={fd:.9e}, rel={rel:.3e}"
                );
            }
        }
    }
1468
1469    #[test]
1470    pub(crate) fn firthgradient_lives_in_design_column_space_under_rank_deficiency() {
1471        // Rank-deficient design: col4 = 2*col2.
1472        let x = array![
1473            [1.0, -1.2, 0.4, -2.4],
1474            [1.0, -0.9, -0.1, -1.8],
1475            [1.0, -0.6, 0.3, -1.2],
1476            [1.0, -0.2, -0.4, -0.4],
1477            [1.0, 0.1, 0.5, 0.2],
1478            [1.0, 0.4, -0.6, 0.8],
1479            [1.0, 0.8, 0.2, 1.6],
1480            [1.0, 1.1, -0.3, 2.2],
1481        ];
1482        let beta = array![0.1, -0.2, 0.3, 0.05];
1483        let eta = x.dot(&beta);
1484        let op = super::RemlState::build_firth_dense_operator_for_link(
1485            &gam_problem::InverseLink::Standard(gam_problem::StandardLink::Logit),
1486            &x,
1487            &eta,
1488            ndarray::Array1::ones(x.nrows()).view(),
1489        )
1490        .expect("firth operator");
1491
1492        // Exact reduced-space Firth gradient:
1493        //   gradPhi = 0.5 Xᵀ (w' ⊙ h), with h = diag(X_r K_r X_rᵀ).
1494        let gradphi = 0.5 * x.t().dot(&(&op.w1 * &op.h_diag));
1495
1496        // Check (I - QQᵀ) gradPhi ≈ 0.
1497        let q = &op.q_basis;
1498        let proj = q.dot(&q.t().dot(&gradphi));
1499        let resid = &gradphi - &proj;
1500        let rel =
1501            resid.mapv(|v| v * v).sum().sqrt() / gradphi.mapv(|v| v * v).sum().sqrt().max(1e-12);
1502        assert!(
1503            rel < 1e-10,
1504            "Firth gradient should lie in Col(Xᵀ): rel residual={rel:.3e}"
1505        );
1506    }
1507
1508    #[test]
1509    pub(crate) fn firth_logit_directional_hypergradient_accepts_penalty_only_with_full_tk_gradient()
1510    {
1511        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1512        let w = Array1::<f64>::ones(y.len());
1513        let x = array![
1514            [1.0, -1.1, 0.2],
1515            [1.0, -0.6, -0.3],
1516            [1.0, -0.1, 0.5],
1517            [1.0, 0.3, -0.7],
1518            [1.0, 0.8, 0.1],
1519            [1.0, 1.2, -0.4],
1520        ];
1521        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1522        let hyper = DirectionalHyperParam::single_penalty(
1523            0,
1524            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1525            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1526            None,
1527            None,
1528        )
1529        .expect("single-penalty hyper direction");
1530        let rho = array![0.0];
1531        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1532        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1533        let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1534            .expect("Firth penalty-only directional gradient should use analytic TK propagation");
1535        assert!(gradient.is_finite(), "gradient={gradient}");
1536        let fd = fd_directional_tau_cost_gradient(
1537            &y,
1538            &w,
1539            &x,
1540            &s0,
1541            &cfg,
1542            &rho,
1543            &Array2::<f64>::zeros((x.nrows(), x.ncols())),
1544            &array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1545        );
1546        let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1547        assert!(
1548            rel < 1.0e-3,
1549            "Firth penalty-only directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1550        );
1551
1552        let efs_hyper = DirectionalHyperParam::single_penalty(
1553            0,
1554            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1555            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.03], [0.0, 0.03, 0.12],],
1556            None,
1557            None,
1558        )
1559        .expect("single-penalty EFS hyper direction");
1560        let efs = state
1561            .compute_efs_steps_with_psi_ext(&rho, &[efs_hyper])
1562            .expect("Firth penalty-only EFS should use analytic TK propagation");
1563        assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1564    }
1565
1566    #[test]
1567    pub(crate) fn firth_logit_directional_hypergradient_accepts_design_moving_with_full_tk_gradient()
1568     {
1569        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1570        let w = Array1::<f64>::ones(y.len());
1571        let x = array![
1572            [1.0, -1.1, 0.2],
1573            [1.0, -0.6, -0.3],
1574            [1.0, -0.1, 0.5],
1575            [1.0, 0.3, -0.7],
1576            [1.0, 0.8, 0.1],
1577            [1.0, 1.2, -0.4],
1578        ];
1579        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1580        let hyper = DirectionalHyperParam::single_penalty(
1581            0,
1582            Array2::from_elem((x.nrows(), x.ncols()), 1e-3),
1583            Array2::<f64>::zeros((x.ncols(), x.ncols())),
1584            None,
1585            None,
1586        )
1587        .expect("single-penalty hyper direction");
1588        let rho = array![0.0];
1589        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1590        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1591        let gradient = single_directional_tau_gradient(&state, &rho, hyper)
1592            .expect("Firth design-moving directional gradient should use analytic TK propagation");
1593        assert!(gradient.is_finite(), "gradient={gradient}");
1594        let x_tau = Array2::from_elem((x.nrows(), x.ncols()), 1e-3);
1595        let s_tau = Array2::<f64>::zeros((x.ncols(), x.ncols()));
1596        let fd = fd_directional_tau_cost_gradient(&y, &w, &x, &s0, &cfg, &rho, &x_tau, &s_tau);
1597        let rel = (gradient - fd).abs() / gradient.abs().max(fd.abs()).max(1.0e-10);
1598        assert!(
1599            rel < 2.0e-2,
1600            "Firth design-moving directional gradient mismatch: analytic={gradient:.12e}, fd={fd:.12e}, rel={rel:.3e}"
1601        );
1602    }
1603
1604    #[test]
1605    pub(crate) fn firth_logit_hybrid_efs_accepts_full_tk_psi_gradient() {
1606        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
1607        let w = Array1::<f64>::ones(y.len());
1608        let x = array![
1609            [1.0, -1.1, 0.2],
1610            [1.0, -0.6, -0.3],
1611            [1.0, -0.1, 0.5],
1612            [1.0, 0.3, -0.7],
1613            [1.0, 0.8, 0.1],
1614            [1.0, 1.2, -0.4],
1615        ];
1616        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1617        let hyper_dirs = vec![
1618            DirectionalHyperParam::single_penalty(
1619                0,
1620                Array2::from_shape_fn((x.nrows(), x.ncols()), |(i, j)| {
1621                    1e-3 * ((i + 1) as f64) * ((j + 2) as f64)
1622                }),
1623                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1624                None,
1625                None,
1626            )
1627            .expect("design-moving hyper direction"),
1628        ];
1629        let rho = array![0.0];
1630        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-8, true);
1631        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1632
1633        let full = state
1634            .evaluate_unified_with_psi_ext(
1635                &rho,
1636                None,
1637                crate::estimate::reml::reml_outer_engine::EvalMode::ValueAndGradient,
1638                &hyper_dirs,
1639            )
1640            .expect("full Firth psi gradient should use analytic TK propagation");
1641        assert!(full.cost.is_finite(), "full cost={}", full.cost);
1642        let full_grad = full.gradient.expect("gradient should be present");
1643        assert!(
1644            full_grad.iter().all(|value| value.is_finite()),
1645            "full gradient={full_grad:?}"
1646        );
1647
1648        let efs = state
1649            .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
1650            .expect("hybrid EFS should use analytic TK propagation");
1651        assert!(efs.cost.is_finite(), "efs cost={}", efs.cost);
1652    }
1653
1654    #[test]
1655    pub(crate) fn joint_hyperhessianwires_mixed_blocks() {
1656        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1657        let w = Array1::<f64>::ones(y.len());
1658        let x = array![
1659            [1.0, -1.2, 0.3],
1660            [1.0, -0.8, -0.4],
1661            [1.0, -0.3, 0.7],
1662            [1.0, 0.1, -0.9],
1663            [1.0, 0.5, 0.2],
1664            [1.0, 0.9, -0.1],
1665            [1.0, 1.3, 0.8],
1666            [1.0, 1.7, -0.6],
1667        ];
1668        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1669        let cfg =
1670            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1671        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1672        let rho = array![0.0];
1673        let theta = array![0.0, 0.0, 0.0];
1674        let hyper_dirs = vec![
1675            DirectionalHyperParam::single_penalty(
1676                0,
1677                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1678                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1679                None,
1680                None,
1681            )
1682            .expect("single-penalty hyper direction"),
1683            DirectionalHyperParam::single_penalty(
1684                0,
1685                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1686                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1687                None,
1688                None,
1689            )
1690            .expect("single-penalty hyper direction"),
1691        ];
1692
1693        let (_, _, h) =
1694            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1695                .expect("joint hyper cost+gradient+hessian");
1696        assert_eq!(h.nrows(), theta.len());
1697        assert_eq!(h.ncols(), theta.len());
1698        assert!(h.iter().all(|v| v.is_finite()));
1699        for i in 0..h.nrows() {
1700            for j in 0..i {
1701                let diff = (h[[i, j]] - h[[j, i]]).abs();
1702                assert!(
1703                    diff < 1e-6,
1704                    "joint hessian asymmetry at ({i},{j}): {diff:.3e}"
1705                );
1706            }
1707        }
1708        // Mixed block must be nontrivial for at least one supplied direction.
1709        let mixed_0 = h[[0, 1]];
1710        let mixed_1 = h[[0, 2]];
1711        assert!(
1712            mixed_0.is_finite() && mixed_1.is_finite(),
1713            "mixed blocks must be finite"
1714        );
1715    }
1716
1717    #[test]
1718    pub(crate) fn joint_tau_tau_linear_dirs_matchfd_reference_away_fromzero_psi() {
1719        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1720        let w = Array1::<f64>::ones(y.len());
1721        let x = array![
1722            [1.0, -1.2, 0.3],
1723            [1.0, -0.8, -0.4],
1724            [1.0, -0.3, 0.7],
1725            [1.0, 0.1, -0.9],
1726            [1.0, 0.5, 0.2],
1727            [1.0, 0.9, -0.1],
1728            [1.0, 1.3, 0.8],
1729            [1.0, 1.7, -0.6],
1730        ];
1731        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1732        let cfg =
1733            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1734        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1735        let rho = array![0.0];
1736        let psi = array![0.7, -0.4];
1737        let theta = array![rho[0], psi[0], psi[1]];
1738        let hyper_dirs = vec![
1739            DirectionalHyperParam::single_penalty(
1740                0,
1741                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1742                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1743                None,
1744                None,
1745            )
1746            .expect("linear tau direction"),
1747            DirectionalHyperParam::single_penalty(
1748                0,
1749                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1750                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1751                None,
1752                None,
1753            )
1754            .expect("linear tau direction"),
1755        ];
1756
1757        let (_, _, h_full) =
1758            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1759                .expect("joint hyper cost+gradient+hessian");
1760        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
1761
1762        // FD via physical perturbation of design/penalty matrices (matching
1763        // the V_tau FD pattern).  For column j we perturb X and S₀ along
1764        // direction j, build fresh states, and evaluate the τ-gradient for
1765        // every direction i at those perturbed states.
1766        let x_tau_mats: Vec<Array2<f64>> = vec![
1767            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1768            Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1769        ];
1770        let s_tau_mats: Vec<Array2<f64>> = vec![
1771            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
1772            Array2::<f64>::zeros((x.ncols(), x.ncols())),
1773        ];
1774
1775        let h_ttfd = directional_tau_hessian_fd_reference(
1776            &y,
1777            &w,
1778            &x,
1779            &s0,
1780            &cfg,
1781            &rho,
1782            &hyper_dirs,
1783            &x_tau_mats,
1784            &s_tau_mats,
1785        );
1786
1787        let num = (&h_tt_analytic - &h_ttfd)
1788            .iter()
1789            .map(|v| v * v)
1790            .sum::<f64>()
1791            .sqrt();
1792        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
1793        let rel = num / den;
1794        assert!(
1795            rel < 1e-4,
1796            "linear-dir joint tau-tau block deviates from FD reference away from zero psi: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
1797        );
1798    }
1799
1800    #[test]
1801    pub(crate) fn joint_hypervalidation_rejects_out_of_boundssecond_order_penalty_index() {
1802        // The hyper direction declares a second-order penalty derivative
1803        // against base penalty index 1, but the configured ρ vector has
1804        // dimension 1 (so only index 0 is valid).  The pair-callback
1805        // builder in `build_tau_penalty_derivative_data` is responsible for
1806        // validating both first- and second-order penalty indices against
1807        // `rho.len()`; this test pins that contract.
1808        //
1809        // We deliberately keep `firth_bias_reduction = true` here so the
1810        // call site exercises the full Firth/Tierney–Kadane outer pipeline:
1811        // PIRLS + ext-coord construction + pair-callback assembly.  With
1812        // analytic c/d propagation now wired in
1813        // `tk_direct_gradient_from_cd_and_design`, there is no longer any
1814        // FD-fallback rejection on this path, so the out-of-bounds error
1815        // fired by the pair-callback builder is the first failure the
1816        // joint evaluator surfaces — and that is exactly what we want this
1817        // test to assert.
1818        let y = array![0.0, 1.0, 0.0, 1.0];
1819        let w = Array1::<f64>::ones(y.len());
1820        let x = array![
1821            [1.0, -0.5, 0.2],
1822            [1.0, -0.1, -0.3],
1823            [1.0, 0.4, 0.6],
1824            [1.0, 0.9, -0.2],
1825        ];
1826        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.1], [0.0, 0.1, 0.8],];
1827        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-10, true);
1828        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1829        let theta = array![0.0, 0.0];
1830        let hyper_dirs = vec![
1831            DirectionalHyperParam::new(
1832                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1833                vec![(0, Array2::<f64>::zeros((x.ncols(), x.ncols())))],
1834                None,
1835                Some(vec![Some(vec![(1, Array2::<f64>::eye(x.ncols()))])]),
1836            )
1837            .expect("hyper direction with invalid second-order penalty index"),
1838        ];
1839
1840        let msg = match compute_joint_hypercostgradienthessian(&state, &theta, 1, &hyper_dirs) {
1841            Ok(_) => panic!("invalid second-order penalty index should be rejected"),
1842            Err(err) => err.to_string(),
1843        };
1844        assert!(
1845            msg.contains("out of bounds") || msg.contains("penalty_index"),
1846            "unexpected validation error: {msg}"
1847        );
1848    }
1849
1850    #[test]
1851    pub(crate) fn joint_tau_tau_analytic_matchesfd_reference() {
1852        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
1853        let w = Array1::<f64>::ones(y.len());
1854        let x = array![
1855            [1.0, -1.2, 0.3],
1856            [1.0, -0.8, -0.4],
1857            [1.0, -0.3, 0.7],
1858            [1.0, 0.1, -0.9],
1859            [1.0, 0.5, 0.2],
1860            [1.0, 0.9, -0.1],
1861            [1.0, 1.3, 0.8],
1862            [1.0, 1.7, -0.6],
1863        ];
1864        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9],];
1865        let cfg =
1866            RemlConfig::external(binomial_logit_glm_spec(), 1e-10, false).with_max_iterations(500);
1867        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
1868        let rho = array![0.0];
1869        let psi = array![0.0, 0.0];
1870        let hyper_dirs = vec![
1871            DirectionalHyperParam::single_penalty(
1872                0,
1873                Array2::<f64>::zeros((x.nrows(), x.ncols())),
1874                array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15],],
1875                None,
1876                None,
1877            )
1878            .expect("single-penalty hyper direction"),
1879            DirectionalHyperParam::single_penalty(
1880                0,
1881                Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1882                Array2::<f64>::zeros((x.ncols(), x.ncols())),
1883                None,
1884                None,
1885            )
1886            .expect("single-penalty hyper direction"),
1887        ];
1888
1889        let theta = {
1890            let mut t = Array1::<f64>::zeros(rho.len() + psi.len());
1891            t.slice_mut(s![..rho.len()]).assign(&rho);
1892            t.slice_mut(s![rho.len()..]).assign(&psi);
1893            t
1894        };
1895        let (_, _, h_full) =
1896            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
1897                .expect("joint hyper cost+gradient+hessian");
1898        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
1899        assert_eq!(h_tt_analytic.nrows(), hyper_dirs.len());
1900        assert_eq!(h_tt_analytic.ncols(), hyper_dirs.len());
1901
1902        // FD via physical perturbation of design/penalty matrices (matching
1903        // the V_tau FD pattern).  For column j we perturb X and S₀ along
1904        // direction j, build fresh states, and evaluate the τ-gradient for
1905        // every direction i at those perturbed states.
1906        let x_tau_mats: Vec<Array2<f64>> = vec![
1907            Array2::<f64>::zeros((x.nrows(), x.ncols())),
1908            Array2::from_elem((x.nrows(), x.ncols()), 2e-4),
1909        ];
1910        let s_tau_mats: Vec<Array2<f64>> = vec![
1911            array![[0.0, 0.0, 0.0], [0.0, 0.2, 0.01], [0.0, 0.01, 0.15]],
1912            Array2::<f64>::zeros((x.ncols(), x.ncols())),
1913        ];
1914
1915        let h_ttfd = directional_tau_hessian_fd_reference(
1916            &y,
1917            &w,
1918            &x,
1919            &s0,
1920            &cfg,
1921            &rho,
1922            &hyper_dirs,
1923            &x_tau_mats,
1924            &s_tau_mats,
1925        );
1926
1927        let num = (&h_tt_analytic - &h_ttfd)
1928            .iter()
1929            .map(|v| v * v)
1930            .sum::<f64>()
1931            .sqrt();
1932        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
1933        let rel = num / den;
1934        assert!(
1935            rel < 1e-4,
1936            "analytic tau-tau block deviates from FD reference: rel={rel:.3e}, analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
1937        );
1938    }
1939
1940    // ── Profiled Gaussian REML coverage for design-moving τ-directions ──
1941    //
1942    // The existing directional-hyper tests all use BinomialLogit, which has
1943    // DispersionHandling::Fixed.  These tests validate the profiled Gaussian
1944    // path (DispersionHandling::ProfiledGaussian) with design-moving
1945    // τ-directions, where the profiled scale φ̂ = D_p/(n−M) depends on ρ
1946    // and the envelope-theorem rescaling by (n−M)/D_p must be correct.
1947
1948    /// Shared test fixture for profiled Gaussian REML tests.
1949    pub(crate) struct GaussianRemlFixture {
1950        pub(crate) y: Array1<f64>,
1951        pub(crate) w: Array1<f64>,
1952        pub(crate) x: Array2<f64>,
1953        pub(crate) s0: Array2<f64>,
1954        pub(crate) cfg: RemlConfig,
1955        pub(crate) rho: Array1<f64>,
1956        /// Design-moving τ-direction (non-zero X_τ, zero S_τ).
1957        pub(crate) x_tau_design: Array2<f64>,
1958        /// Penalty-only τ-direction (zero X_τ, non-zero S_τ).
1959        pub(crate) s_tau_penalty: Array2<f64>,
1960    }
1961
1962    impl GaussianRemlFixture {
1963        pub(crate) fn new() -> Self {
1964            let y = array![0.5, 1.2, -0.3, 0.8, 1.1, -0.6, 0.9, 0.1, -0.2, 0.7];
1965            let x = array![
1966                [1.0, -1.2, 0.3],
1967                [1.0, -0.8, -0.4],
1968                [1.0, -0.3, 0.7],
1969                [1.0, 0.1, -0.9],
1970                [1.0, 0.5, 0.2],
1971                [1.0, 0.9, -0.1],
1972                [1.0, 1.3, 0.8],
1973                [1.0, 1.7, -0.6],
1974                [1.0, -0.5, 0.5],
1975                [1.0, 0.3, -0.3],
1976            ];
1977            Self {
1978                w: Array1::<f64>::ones(y.len()),
1979                y,
1980                x: x.clone(),
1981                s0: array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]],
1982                cfg: RemlConfig::external(gaussian_identity_glm_spec(), 1e-14, false),
1983                rho: array![0.0],
1984                x_tau_design: array![
1985                    [0.0, 1e-3, -2e-3],
1986                    [0.0, -3e-3, 1e-3],
1987                    [0.0, 2e-3, 0.5e-3],
1988                    [0.0, -1e-3, 3e-3],
1989                    [0.0, 0.5e-3, -1e-3],
1990                    [0.0, 1.5e-3, 2e-3],
1991                    [0.0, -2e-3, -0.5e-3],
1992                    [0.0, 3e-3, 1e-3],
1993                    [0.0, -0.5e-3, 2e-3],
1994                    [0.0, 1e-3, -1.5e-3],
1995                ],
1996                s_tau_penalty: array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]],
1997            }
1998        }
1999    }
2000
2001    impl LogitDesignMotionFixture for GaussianRemlFixture {
2002        fn y(&self) -> &Array1<f64> {
2003            &self.y
2004        }
2005        fn w(&self) -> &Array1<f64> {
2006            &self.w
2007        }
2008        fn x(&self) -> &Array2<f64> {
2009            &self.x
2010        }
2011        fn s0(&self) -> &Array2<f64> {
2012            &self.s0
2013        }
2014        fn cfg(&self) -> &RemlConfig {
2015            &self.cfg
2016        }
2017        fn rho(&self) -> &Array1<f64> {
2018            &self.rho
2019        }
2020    }
2021
2022    #[test]
2023    pub(crate) fn profiled_gaussian_design_moving_gradient_matches_fd() {
2024        let f = GaussianRemlFixture::new();
2025        let state = f.state();
2026        let s_tau = Array2::<f64>::zeros((3, 3));
2027        let hyper = DirectionalHyperParam::single_penalty(
2028            0,
2029            f.x_tau_design.clone(),
2030            s_tau.clone(),
2031            None,
2032            None,
2033        )
2034        .expect("design-moving hyper direction");
2035
2036        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2037            .expect("analytic directional gradient");
2038        let v_taufd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2039
2040        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2041        assert!(
2042            v_rel < 1e-3,
2043            "Gaussian REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2044             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2045        );
2046    }
2047
2048    #[test]
2049    pub(crate) fn profiled_gaussian_penalty_only_gradient_matches_fd() {
2050        let f = GaussianRemlFixture::new();
2051        let state = f.state();
2052        let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2053        let hyper = DirectionalHyperParam::single_penalty(
2054            0,
2055            x_tau.clone(),
2056            f.s_tau_penalty.clone(),
2057            None,
2058            None,
2059        )
2060        .expect("penalty-only hyper direction");
2061
2062        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2063            .expect("analytic directional gradient");
2064        let v_taufd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2065
2066        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2067        assert!(
2068            v_rel < 1e-3,
2069            "Gaussian REML penalty-only V_tau mismatch: rel={v_rel:.3e}, \
2070             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2071        );
2072    }
2073
2074    #[test]
2075    pub(crate) fn profiled_gaussian_joint_hessian_matches_fd() {
2076        // Validate the ττ Hessian block under profiled Gaussian REML with
2077        // both a penalty-only and a design-moving direction.
2078        let f = GaussianRemlFixture::new();
2079        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2080        let s_tau_0 = f.s_tau_penalty.clone();
2081        let x_tau_1 = f.x_tau_design.clone();
2082        let s_tau_1 = Array2::<f64>::zeros((3, 3));
2083
2084        let hyper_dirs = vec![
2085            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2086                .expect("penalty-only direction"),
2087            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2088                .expect("design-moving direction"),
2089        ];
2090
2091        let state = f.state();
2092        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2093        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2094        let (_, _, h_full) =
2095            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2096                .expect("joint cost+gradient+hessian");
2097        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2098
2099        // Finite-difference Hessian: perturb each direction, re-evaluate
2100        // gradient of all directions at perturbed states.
2101        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2102        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2103        let h_ttfd = directional_tau_hessian_fd_reference(
2104            &f.y,
2105            &f.w,
2106            &f.x,
2107            &f.s0,
2108            &f.cfg,
2109            &f.rho,
2110            &hyper_dirs,
2111            &x_tau_mats,
2112            &s_tau_mats,
2113        );
2114
2115        let num = (&h_tt_analytic - &h_ttfd)
2116            .iter()
2117            .map(|v| v * v)
2118            .sum::<f64>()
2119            .sqrt();
2120        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2121        let rel = num / den;
2122        assert!(
2123            rel < 1e-4,
2124            "Gaussian REML tau-tau Hessian mismatch: rel={rel:.3e}, \
2125             analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2126        );
2127    }
2128
2129    // ── Non-Gaussian + design-motion: IFT Hessian-drift coverage ────────
2130    //
2131    // For non-Gaussian links (logit, probit, cloglog, ...), H = X'W(η)X + S
2132    // depends on β̂ through η = Xβ̂.  When ψ moves the design, the total
2133    // Hessian drift dH/dψ includes an IFT contribution from dβ̂/dψ:
2134    //
2135    //   dH/dψ = [explicit at fixed β] + X' diag(c ⊙ X(-v_i)) X
2136    //
    // where v_i = H⁻¹ g_i.  The standard GLM path handles this via
    // `hessian_derivative_correction(v_i)`.  The tests below check the
    // gradient and the Hessian for logit + design-moving ψ.  Both would fail
    // if the IFT correction were missing.
2141
2142    #[test]
2143    pub(crate) fn logit_design_moving_gradient_matches_fd() {
2144        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2145        let w = Array1::<f64>::ones(y.len());
2146        let x = array![
2147            [1.0, -1.2, 0.3],
2148            [1.0, -0.8, -0.4],
2149            [1.0, -0.3, 0.7],
2150            [1.0, 0.1, -0.9],
2151            [1.0, 0.5, 0.2],
2152            [1.0, 0.9, -0.1],
2153            [1.0, 1.3, 0.8],
2154            [1.0, 1.7, -0.6],
2155            [1.0, -0.5, 0.5],
2156            [1.0, 0.3, -0.3],
2157        ];
2158        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2159        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2160        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2161        let rho = array![0.0];
2162
2163        // Design-moving direction with non-zero X_τ.
2164        let x_tau = array![
2165            [0.0, 1e-3, -2e-3],
2166            [0.0, -3e-3, 1e-3],
2167            [0.0, 2e-3, 0.5e-3],
2168            [0.0, -1e-3, 3e-3],
2169            [0.0, 0.5e-3, -1e-3],
2170            [0.0, 1.5e-3, 2e-3],
2171            [0.0, -2e-3, -0.5e-3],
2172            [0.0, 3e-3, 1e-3],
2173            [0.0, -0.5e-3, 2e-3],
2174            [0.0, 1e-3, -1.5e-3],
2175        ];
2176        let s_tau = Array2::<f64>::zeros((3, 3));
2177        let hyper =
2178            DirectionalHyperParam::single_penalty(0, x_tau.clone(), s_tau.clone(), None, None)
2179                .expect("design-moving hyper direction");
2180
2181        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2182            .expect("analytic directional gradient");
2183
2184        let h = 2e-5;
2185        let x_plus = &x + &x_tau.mapv(|v| h * v);
2186        let x_minus = &x - &x_tau.mapv(|v| h * v);
2187        let state_plus = build_logit_state(&y, &w, &x_plus, &s0, &cfg);
2188        let state_minus = build_logit_state(&y, &w, &x_minus, &s0, &cfg);
2189        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2190        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2191        let v_taufd = (v_plus - v_minus) / (2.0 * h);
2192
2193        let v_rel = (v_tau_analytic - v_taufd).abs() / v_taufd.abs().max(1e-10);
2194        assert!(
2195            v_rel < 1e-3,
2196            "Logit REML design-moving V_tau mismatch: rel={v_rel:.3e}, \
2197             analytic={v_tau_analytic:.6e}, fd={v_taufd:.6e}"
2198        );
2199    }
2200
2201    #[test]
2202    pub(crate) fn logit_design_moving_hessian_matches_fd() {
2203        // Hessian-level validation for logit + design-motion.
2204        // The IFT correction enters the trace term through
2205        // hessian_derivative_correction(v_i), so the Hessian is the most
2206        // sensitive test of whether the correction is applied correctly.
2207        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0];
2208        let w = Array1::<f64>::ones(y.len());
2209        let x = array![
2210            [1.0, -1.2, 0.3],
2211            [1.0, -0.8, -0.4],
2212            [1.0, -0.3, 0.7],
2213            [1.0, 0.1, -0.9],
2214            [1.0, 0.5, 0.2],
2215            [1.0, 0.9, -0.1],
2216            [1.0, 1.3, 0.8],
2217            [1.0, 1.7, -0.6],
2218            [1.0, -0.5, 0.5],
2219            [1.0, 0.3, -0.3],
2220        ];
2221        let s0 = array![[0.0, 0.0, 0.0], [0.0, 1.2, 0.2], [0.0, 0.2, 0.9]];
2222        let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2223        let rho = array![0.0];
2224
2225        // Two directions: one penalty-only, one design-moving.
2226        let x_tau_0 = Array2::<f64>::zeros(x.raw_dim());
2227        let s_tau_0 = array![[0.0, 0.0, 0.0], [0.0, 0.25, 0.04], [0.0, 0.04, 0.15]];
2228        let x_tau_1 = array![
2229            [0.0, 1e-3, -2e-3],
2230            [0.0, -3e-3, 1e-3],
2231            [0.0, 2e-3, 0.5e-3],
2232            [0.0, -1e-3, 3e-3],
2233            [0.0, 0.5e-3, -1e-3],
2234            [0.0, 1.5e-3, 2e-3],
2235            [0.0, -2e-3, -0.5e-3],
2236            [0.0, 3e-3, 1e-3],
2237            [0.0, -0.5e-3, 2e-3],
2238            [0.0, 1e-3, -1.5e-3],
2239        ];
2240        let s_tau_1 = Array2::<f64>::zeros((3, 3));
2241
2242        let hyper_dirs = vec![
2243            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2244                .expect("penalty-only direction"),
2245            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2246                .expect("design-moving direction"),
2247        ];
2248
2249        let state = build_logit_state(&y, &w, &x, &s0, &cfg);
2250        let mut theta = Array1::<f64>::zeros(rho.len() + hyper_dirs.len());
2251        theta.slice_mut(s![..rho.len()]).assign(&rho);
2252        let (_, _, h_full) =
2253            compute_joint_hypercostgradienthessian(&state, &theta, rho.len(), &hyper_dirs)
2254                .expect("joint cost+gradient+hessian");
2255        let h_tt_analytic = h_full.slice(s![rho.len().., rho.len()..]).to_owned();
2256
2257        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2258        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2259        let h_ttfd = directional_tau_hessian_fd_reference(
2260            &y,
2261            &w,
2262            &x,
2263            &s0,
2264            &cfg,
2265            &rho,
2266            &hyper_dirs,
2267            &x_tau_mats,
2268            &s_tau_mats,
2269        );
2270
2271        let num = (&h_tt_analytic - &h_ttfd)
2272            .iter()
2273            .map(|v| v * v)
2274            .sum::<f64>()
2275            .sqrt();
2276        let den = h_ttfd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2277        let rel = num / den;
2278        assert!(
2279            rel < 1e-4,
2280            "Logit REML design-moving tau-tau Hessian mismatch: rel={rel:.3e}, \
2281             analytic={h_tt_analytic:?}, fd={h_ttfd:?}"
2282        );
2283    }
2284
2285    // ── Larger non-Gaussian + design-motion fixture (n=30, p=5) ────────
2286    //
2287    // Validates the IFT correction (hessian_derivative_correction) at a
2288    // scale large enough that the correction is numerically non-trivial:
    // with n=30 and p=5, the logistic working weights W(η) are far from the identity
2290    // and the IFT term dβ̂/dψ contributes meaningfully.
2291
    /// Shared test fixture for binomial-logit REML with design-moving
    /// ψ-coordinates, n=30, p=5.
    ///
    /// Built by `BinomialLogitDesignMotionFixture::new`; the sibling tests
    /// compare analytic τ-derivatives against finite differences on it.
    pub(crate) struct BinomialLogitDesignMotionFixture {
        /// Binary (0/1) response, one entry per design row.
        pub(crate) y: Array1<f64>,
        /// Prior observation weights; `new` sets every entry to 1.
        pub(crate) w: Array1<f64>,
        /// Design matrix; `new` uses an all-ones intercept column followed by
        /// four covariate columns.
        pub(crate) x: Array2<f64>,
        /// Base penalty matrix; `new` leaves the intercept row and column at zero.
        pub(crate) s0: Array2<f64>,
        /// REML configuration; `new` uses the binomial-logit GLM spec.
        pub(crate) cfg: RemlConfig,
        /// Log smoothing parameters; `new` uses a single ρ = 0.
        pub(crate) rho: Array1<f64>,
        /// Design-moving τ-direction: non-zero X_τ, zero S_τ.
        pub(crate) x_tau_design: Array2<f64>,
        /// Penalty-only τ-direction: zero X_τ, non-zero S_τ.
        pub(crate) s_tau_penalty: Array2<f64>,
    }
2306
2307    impl BinomialLogitDesignMotionFixture {
2308        pub(crate) fn new() -> Self {
2309            // Binary response with roughly balanced classes.
2310            let y = array![
2311                1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0,
2312                1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
2313            ];
2314            // Design matrix: intercept + 4 covariate columns with varied magnitudes.
2315            let x = array![
2316                [1.0, -1.50, 0.42, 0.88, -0.31],
2317                [1.0, -1.12, -0.65, 0.14, 1.23],
2318                [1.0, -0.80, 1.10, -0.53, 0.07],
2319                [1.0, -0.55, -0.22, 1.40, -0.90],
2320                [1.0, -0.30, 0.73, -1.05, 0.44],
2321                [1.0, -0.05, -1.33, 0.60, 0.81],
2322                [1.0, 0.18, 0.55, -0.27, -1.15],
2323                [1.0, 0.42, -0.90, 1.12, 0.33],
2324                [1.0, 0.70, 1.28, -0.78, -0.56],
2325                [1.0, 0.95, -0.18, 0.45, 1.40],
2326                [1.0, 1.20, 0.66, -1.30, -0.02],
2327                [1.0, 1.45, -1.05, 0.22, 0.68],
2328                [1.0, -1.35, 0.90, 0.55, -0.43],
2329                [1.0, -0.98, -0.40, -0.88, 1.05],
2330                [1.0, -0.62, 1.42, 0.30, -0.70],
2331                [1.0, -0.28, -0.77, -1.18, 0.52],
2332                [1.0, 0.05, 0.15, 0.95, -1.35],
2333                [1.0, 0.33, -1.20, -0.40, 0.18],
2334                [1.0, 0.60, 0.82, 1.25, -0.85],
2335                [1.0, 0.88, -0.50, -0.65, 1.10],
2336                [1.0, 1.15, 1.05, 0.10, -0.22],
2337                [1.0, -1.22, -0.95, 0.72, 0.90],
2338                [1.0, -0.75, 0.38, -1.42, 0.15],
2339                [1.0, -0.42, -1.15, 0.50, -1.08],
2340                [1.0, -0.10, 0.60, -0.15, 0.75],
2341                [1.0, 0.25, -0.28, 1.05, -0.48],
2342                [1.0, 0.52, 1.35, -0.92, 0.30],
2343                [1.0, 0.80, -0.70, 0.38, 1.20],
2344                [1.0, 1.08, 0.48, -0.60, -0.95],
2345                [1.0, 1.35, -0.55, 0.85, 0.42]
2346            ];
2347            // Penalty matrix: zero on intercept, SPD on remaining 4 columns.
2348            let s0 = array![
2349                [0.0, 0.0, 0.0, 0.0, 0.0],
2350                [0.0, 1.40, 0.15, 0.05, -0.10],
2351                [0.0, 0.15, 1.10, -0.20, 0.08],
2352                [0.0, 0.05, -0.20, 0.95, 0.12],
2353                [0.0, -0.10, 0.08, 0.12, 1.25]
2354            ];
2355            let cfg = RemlConfig::external(binomial_logit_glm_spec(), 1e-14, false);
2356            // Design-moving direction: perturb covariate columns, leave
2357            // intercept untouched.
2358            let x_tau_design = array![
2359                [0.0, 1.2e-3, -0.8e-3, 0.5e-3, -1.5e-3],
2360                [0.0, -2.0e-3, 1.4e-3, -0.3e-3, 0.9e-3],
2361                [0.0, 0.6e-3, -1.1e-3, 1.8e-3, -0.4e-3],
2362                [0.0, -1.3e-3, 0.7e-3, -1.0e-3, 2.1e-3],
2363                [0.0, 0.9e-3, -0.5e-3, 0.2e-3, -0.8e-3],
2364                [0.0, -0.4e-3, 1.8e-3, -1.5e-3, 0.3e-3],
2365                [0.0, 1.5e-3, -1.3e-3, 0.8e-3, -1.1e-3],
2366                [0.0, -0.7e-3, 0.4e-3, -2.0e-3, 1.6e-3],
2367                [0.0, 2.2e-3, -0.9e-3, 1.3e-3, -0.6e-3],
2368                [0.0, -1.0e-3, 1.6e-3, -0.7e-3, 0.5e-3],
2369                [0.0, 0.3e-3, -2.1e-3, 1.1e-3, -1.8e-3],
2370                [0.0, -1.8e-3, 0.2e-3, -0.4e-3, 1.3e-3],
2371                [0.0, 1.1e-3, -1.5e-3, 2.0e-3, -0.2e-3],
2372                [0.0, -0.5e-3, 0.9e-3, -1.2e-3, 0.7e-3],
2373                [0.0, 1.7e-3, -0.3e-3, 0.6e-3, -2.0e-3],
2374                [0.0, -1.4e-3, 1.1e-3, -0.9e-3, 0.4e-3],
2375                [0.0, 0.8e-3, -1.7e-3, 1.5e-3, -0.1e-3],
2376                [0.0, -0.2e-3, 0.6e-3, -1.8e-3, 1.0e-3],
2377                [0.0, 1.4e-3, -0.4e-3, 0.3e-3, -1.3e-3],
2378                [0.0, -0.9e-3, 2.0e-3, -0.5e-3, 0.8e-3],
2379                [0.0, 0.5e-3, -1.0e-3, 1.6e-3, -0.7e-3],
2380                [0.0, -2.1e-3, 0.3e-3, -0.8e-3, 1.5e-3],
2381                [0.0, 0.7e-3, -1.8e-3, 0.9e-3, -0.3e-3],
2382                [0.0, -0.6e-3, 1.3e-3, -2.2e-3, 1.1e-3],
2383                [0.0, 1.9e-3, -0.7e-3, 0.4e-3, -0.9e-3],
2384                [0.0, -1.1e-3, 0.5e-3, -1.4e-3, 2.2e-3],
2385                [0.0, 0.4e-3, -1.6e-3, 1.2e-3, -0.5e-3],
2386                [0.0, -1.6e-3, 0.8e-3, -0.1e-3, 0.6e-3],
2387                [0.0, 1.3e-3, -2.2e-3, 0.7e-3, -1.4e-3],
2388                [0.0, -0.3e-3, 1.0e-3, -1.6e-3, 1.8e-3]
2389            ];
2390            // Penalty-only direction: non-zero S_τ, symmetric, zero on intercept.
2391            let s_tau_penalty = array![
2392                [0.0, 0.0, 0.0, 0.0, 0.0],
2393                [0.0, 0.30, 0.05, -0.02, 0.04],
2394                [0.0, 0.05, 0.22, 0.03, -0.01],
2395                [0.0, -0.02, 0.03, 0.18, 0.06],
2396                [0.0, 0.04, -0.01, 0.06, 0.26]
2397            ];
2398            Self {
2399                w: Array1::<f64>::ones(y.len()),
2400                y,
2401                x,
2402                s0,
2403                cfg,
2404                rho: array![0.0],
2405                x_tau_design,
2406                s_tau_penalty,
2407            }
2408        }
2409    }
2410
    // Exposes the fixture's raw inputs to the shared `LogitDesignMotionFixture`
    // trait, which is defined outside this view. NOTE(review): presumably the
    // trait's default methods (`state`, `state_perturbed`,
    // `fd_directional_gradient`, as called by the tests below) are built from
    // these accessors; confirm against the trait definition.
    impl LogitDesignMotionFixture for BinomialLogitDesignMotionFixture {
        // Binary response vector.
        fn y(&self) -> &Array1<f64> {
            &self.y
        }
        // Prior observation weights.
        fn w(&self) -> &Array1<f64> {
            &self.w
        }
        // Unperturbed design matrix.
        fn x(&self) -> &Array2<f64> {
            &self.x
        }
        // Unperturbed penalty matrix.
        fn s0(&self) -> &Array2<f64> {
            &self.s0
        }
        // REML configuration.
        fn cfg(&self) -> &RemlConfig {
            &self.cfg
        }
        // Log smoothing parameters.
        fn rho(&self) -> &Array1<f64> {
            &self.rho
        }
    }
2431
2432    // ── n=30, p=5 binomial-logit design-motion gradient tests ────────
2433
2434    #[test]
2435    pub(crate) fn binomial_logit_n30_design_moving_gradient_matches_fd() {
2436        // Pure design-motion: X_τ ≠ 0, S_τ = 0.
2437        // The IFT correction is essential here: because the family is
2438        // binomial-logit, the working weights W(η) depend on β̂, so
2439        // when X moves with ψ, the implicit derivative dβ̂/dψ enters
2440        // the total Hessian drift.  Without hessian_derivative_correction
2441        // the analytic gradient would disagree with FD.
2442        let f = BinomialLogitDesignMotionFixture::new();
2443        let state = f.state();
2444        let s_tau = Array2::<f64>::zeros((5, 5));
2445        let hyper = DirectionalHyperParam::single_penalty(
2446            0,
2447            f.x_tau_design.clone(),
2448            s_tau.clone(),
2449            None,
2450            None,
2451        )
2452        .expect("design-moving hyper direction");
2453
2454        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2455            .expect("analytic directional gradient");
2456        let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &s_tau);
2457
2458        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2459        assert!(
2460            v_rel < 1e-3,
2461            "Binomial-logit n=30 design-moving gradient mismatch: rel={v_rel:.3e}, \
2462             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2463        );
2464    }
2465
2466    #[test]
2467    pub(crate) fn binomial_logit_n30_penalty_only_gradient_matches_fd() {
2468        // Penalty-only direction: X_τ = 0, S_τ ≠ 0.
2469        // Serves as a baseline: the IFT correction should still be
2470        // present (since H depends on β̂ through W(η)), but the
2471        // explicit X_τ contribution is zero.
2472        let f = BinomialLogitDesignMotionFixture::new();
2473        let state = f.state();
2474        let x_tau = Array2::<f64>::zeros(f.x.raw_dim());
2475        let hyper = DirectionalHyperParam::single_penalty(
2476            0,
2477            x_tau.clone(),
2478            f.s_tau_penalty.clone(),
2479            None,
2480            None,
2481        )
2482        .expect("penalty-only hyper direction");
2483
2484        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2485            .expect("analytic directional gradient");
2486        let v_tau_fd = f.fd_directional_gradient(&x_tau, &f.s_tau_penalty);
2487
2488        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2489        assert!(
2490            v_rel < 1e-3,
2491            "Binomial-logit n=30 penalty-only gradient mismatch: rel={v_rel:.3e}, \
2492             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2493        );
2494    }
2495
2496    #[test]
2497    pub(crate) fn binomial_logit_n30_joint_design_penalty_gradient_matches_fd() {
2498        // Joint direction: both X_τ ≠ 0 and S_τ ≠ 0 simultaneously.
2499        // This is the hardest case: the analytic gradient must correctly
2500        // combine the explicit penalty drift, the explicit design drift,
2501        // and the IFT Hessian-drift correction.
2502        let f = BinomialLogitDesignMotionFixture::new();
2503        let state = f.state();
2504        let hyper = DirectionalHyperParam::single_penalty(
2505            0,
2506            f.x_tau_design.clone(),
2507            f.s_tau_penalty.clone(),
2508            None,
2509            None,
2510        )
2511        .expect("joint design+penalty hyper direction");
2512
2513        let v_tau_analytic = single_directional_tau_gradient(&state, &f.rho, hyper)
2514            .expect("analytic directional gradient");
2515        let v_tau_fd = f.fd_directional_gradient(&f.x_tau_design, &f.s_tau_penalty);
2516
2517        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2518        assert!(
2519            v_rel < 1e-3,
2520            "Binomial-logit n=30 joint design+penalty gradient mismatch: rel={v_rel:.3e}, \
2521             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2522        );
2523    }
2524
2525    #[test]
2526    pub(crate) fn binomial_logit_n30_design_moving_hessian_matches_fd() {
2527        // Hessian-level validation with two τ-directions: one
2528        // penalty-only and one design-moving.  The ττ Hessian block is
2529        // the most sensitive test of the IFT correction because errors
2530        // in the correction accumulate quadratically in the trace term.
2531        let f = BinomialLogitDesignMotionFixture::new();
2532        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2533        let s_tau_0 = f.s_tau_penalty.clone();
2534        let x_tau_1 = f.x_tau_design.clone();
2535        let s_tau_1 = Array2::<f64>::zeros((5, 5));
2536
2537        let hyper_dirs = vec![
2538            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2539                .expect("penalty-only direction"),
2540            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2541                .expect("design-moving direction"),
2542        ];
2543
2544        let state = f.state();
2545        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2546        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2547        let (_, _, h_full) =
2548            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2549                .expect("joint cost+gradient+hessian");
2550        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2551
2552        let x_tau_mats = vec![x_tau_0.clone(), x_tau_1.clone()];
2553        let s_tau_mats = vec![s_tau_0.clone(), s_tau_1.clone()];
2554        let h_tt_fd = directional_tau_hessian_fd_reference(
2555            &f.y,
2556            &f.w,
2557            &f.x,
2558            &f.s0,
2559            &f.cfg,
2560            &f.rho,
2561            &hyper_dirs,
2562            &x_tau_mats,
2563            &s_tau_mats,
2564        );
2565
2566        let num = (&h_tt_analytic - &h_tt_fd)
2567            .iter()
2568            .map(|v| v * v)
2569            .sum::<f64>()
2570            .sqrt();
2571        let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2572        let rel = num / den;
2573        assert!(
2574            rel < 1e-4,
2575            "Binomial-logit n=30 tau-tau Hessian mismatch: rel={rel:.3e}, \
2576             analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2577        );
2578    }
2579
2580    #[test]
2581    pub(crate) fn binomial_logit_n30_nonzero_rho_design_moving_gradient_matches_fd() {
2582        // Validate at a non-trivial smoothing parameter ρ = log(λ) = 1.5,
2583        // so the penalty term λS is scaled up and the balance between
2584        // likelihood and penalty is different from ρ=0.
2585        let f = BinomialLogitDesignMotionFixture::new();
2586        let rho = array![1.5];
2587        let s_tau = Array2::<f64>::zeros((5, 5));
2588
2589        let state = f.state();
2590        let hyper = DirectionalHyperParam::single_penalty(
2591            0,
2592            f.x_tau_design.clone(),
2593            s_tau.clone(),
2594            None,
2595            None,
2596        )
2597        .expect("design-moving hyper direction");
2598
2599        let v_tau_analytic = single_directional_tau_gradient(&state, &rho, hyper)
2600            .expect("analytic directional gradient");
2601
2602        // FD at the shifted ρ: perturb X, re-solve inner, evaluate cost.
2603        let h = 2e-5;
2604        let (state_plus, state_minus) = f.state_perturbed(&f.x_tau_design, &s_tau, h);
2605        let v_plus = state_plus.compute_cost(&rho).expect("cost+");
2606        let v_minus = state_minus.compute_cost(&rho).expect("cost-");
2607        let v_tau_fd = (v_plus - v_minus) / (2.0 * h);
2608
2609        let v_rel = (v_tau_analytic - v_tau_fd).abs() / v_tau_fd.abs().max(1e-10);
2610        assert!(
2611            v_rel < 1e-3,
2612            "Binomial-logit n=30 rho=1.5 design-moving gradient mismatch: rel={v_rel:.3e}, \
2613             analytic={v_tau_analytic:.6e}, fd={v_tau_fd:.6e}"
2614        );
2615    }
2616
2617    #[test]
2618    pub(crate) fn binomial_logit_n30_rank_deficient_hessian_matches_cost_fd() {
2619        // Regression lock for the `PenaltySubspaceTrace` pseudo-logdet
2620        // kernel installed by the rank-deficient LAML fix (see
2621        // `PenaltySubspaceTrace` and `intrinsic_hessian_pseudo_logdet_parts`;
2622        // since #901 the cost is the intrinsic `½ log|H_pen|₊` and the kernel
2623        // is the spectral `H_pen⁺`, exact for every drift direction).
2624        //
2625        // The sibling `binomial_logit_n30_design_moving_hessian_matches_fd`
2626        // passes pre- AND post-fix because its FD reference differentiates
2627        // the *analytic gradient* — any self-consistent (if wrong) gradient
2628        // kernel gives a self-consistent Hessian under re-differentiation,
2629        // so that test cannot distinguish full-space from projected traces.
2630        // It passed under the buggy kernel because the same leakage entered
2631        // both sides of the ratio and cancelled.
2632        //
2633        // Here we FD-differentiate `compute_cost` TWICE and compare against
2634        // the analytic Hessian.  Central second differences expose every
2635        // disagreement between `½ log|U_Sᵀ H U_S|_+` (used by the cost) and
2636        // `½ tr(G_ε(H) · Ḣ)` / `−½ tr(G_ε Ḣ_i G_ε Ḣ_j)` (the full-space
2637        // traces that the gradient and Hessian used before the projection
2638        // fix).  Under the buggy kernel the IFT correction
2639        // `D_β H[v] = X' diag(c ⊙ X v) X` leaks onto `null(S)` — X's
2640        // all-ones intercept column sits there — and that leakage enters
2641        // the analytic Hessian but not the cost's projected logdet.
2642        //
2643        // Direction mix chosen to maximise the null-space leakage pathway:
2644        //   τ_0 = penalty-only (X_τ = 0, S_τ ≠ 0)  → v_0 = H⁻¹(−S_τ β̂) is
2645        //         concentrated in range(S_+), but `D_β H[v_0]` has rows and
2646        //         columns on the intercept because `X[:,0] = 1_n`.
2647        //   τ_1 = design-moving (X_τ ≠ 0 on non-intercept columns, S_τ = 0)
2648        //         → `v_1` also picks up the intercept via `X'WX_τβ̂`, and
2649        //         the base drift `X_τᵀWX + XᵀWX_τ` straddles range(S_+) /
2650        //         null(S).
2651        // Both pure directions AND the mixed partial load the Schur correction,
2652        // so any of the three entries can catch a regression.
2653        let f = BinomialLogitDesignMotionFixture::new();
2654        let x_tau_0 = Array2::<f64>::zeros(f.x.raw_dim());
2655        let s_tau_0 = f.s_tau_penalty.clone();
2656        let x_tau_1 = f.x_tau_design.clone();
2657        let s_tau_1 = Array2::<f64>::zeros((5, 5));
2658
2659        let hyper_dirs = vec![
2660            DirectionalHyperParam::single_penalty(0, x_tau_0.clone(), s_tau_0.clone(), None, None)
2661                .expect("penalty-only direction"),
2662            DirectionalHyperParam::single_penalty(0, x_tau_1.clone(), s_tau_1.clone(), None, None)
2663                .expect("design-moving direction"),
2664        ];
2665
2666        // Analytic Hessian block.
2667        let state = f.state();
2668        let mut theta = Array1::<f64>::zeros(f.rho.len() + hyper_dirs.len());
2669        theta.slice_mut(s![..f.rho.len()]).assign(&f.rho);
2670        let (_, _, h_full) =
2671            compute_joint_hypercostgradienthessian(&state, &theta, f.rho.len(), &hyper_dirs)
2672                .expect("joint cost+gradient+hessian");
2673        let h_tt_analytic = h_full.slice(s![f.rho.len().., f.rho.len()..]).to_owned();
2674
2675        // Cost-level FD reference.  Central second differences give O(h²)
2676        // accuracy; the step is sized so the physical perturbation on X / S
2677        // stays near `1e-5` (same scale as the gradient tests).
2678        const TARGET_PHYSICAL_STEP: f64 = 1e-5;
2679        let x_tau_mats = [&x_tau_0, &x_tau_1];
2680        let s_tau_mats = [&s_tau_0, &s_tau_1];
2681        let steps: [f64; 2] = {
2682            let mut steps = [0.0; 2];
2683            for (j, step) in steps.iter_mut().enumerate() {
2684                let scale = x_tau_mats[j]
2685                    .iter()
2686                    .chain(s_tau_mats[j].iter())
2687                    .fold(0.0_f64, |acc, value| acc.max(value.abs()));
2688                *step = if scale > 0.0 {
2689                    TARGET_PHYSICAL_STEP / scale
2690                } else {
2691                    TARGET_PHYSICAL_STEP
2692                };
2693            }
2694            steps
2695        };
2696
2697        // Evaluate `compute_cost` at `(a · τ_0, b · τ_1)` multipliers.
2698        let eval_cost = |a: f64, b: f64| -> f64 {
2699            let x_eval = &f.x
2700                + &x_tau_mats[0].mapv(|v| a * steps[0] * v)
2701                + &x_tau_mats[1].mapv(|v| b * steps[1] * v);
2702            let s_eval = &f.s0
2703                + &s_tau_mats[0].mapv(|v| a * steps[0] * v)
2704                + &s_tau_mats[1].mapv(|v| b * steps[1] * v);
2705            let st = build_logit_state(&f.y, &f.w, &x_eval, &s_eval, &f.cfg);
2706            st.compute_cost(&f.rho).expect("cost eval")
2707        };
2708
2709        let v_00 = eval_cost(0.0, 0.0);
2710        let v_p0 = eval_cost(1.0, 0.0);
2711        let v_m0 = eval_cost(-1.0, 0.0);
2712        let v_0p = eval_cost(0.0, 1.0);
2713        let v_0m = eval_cost(0.0, -1.0);
2714        let v_pp = eval_cost(1.0, 1.0);
2715        let v_pm = eval_cost(1.0, -1.0);
2716        let v_mp = eval_cost(-1.0, 1.0);
2717        let v_mm = eval_cost(-1.0, -1.0);
2718
2719        let h00_fd = (v_p0 - 2.0 * v_00 + v_m0) / (steps[0] * steps[0]);
2720        let h11_fd = (v_0p - 2.0 * v_00 + v_0m) / (steps[1] * steps[1]);
2721        let h01_fd = (v_pp - v_pm - v_mp + v_mm) / (4.0 * steps[0] * steps[1]);
2722
2723        let h_tt_fd = array![[h00_fd, h01_fd], [h01_fd, h11_fd]];
2724
2725        let num = (&h_tt_analytic - &h_tt_fd)
2726            .iter()
2727            .map(|v| v * v)
2728            .sum::<f64>()
2729            .sqrt();
2730        let den = h_tt_fd.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-10);
2731        let rel = num / den;
2732
2733        assert!(
2734            rel < 3e-3,
2735            "Binomial-logit n=30 rank-deficient Hessian vs cost-FD mismatch: rel={rel:.3e}, \
2736             analytic={h_tt_analytic:?}, fd={h_tt_fd:?}"
2737        );
2738    }
2739}
2740
/// Selects the penalized-geometry backend used for a REML evaluation.
///
/// NOTE(review): variant meanings are inferred from their names only.
/// `DenseSpectral` is presumably the dense eigendecomposition path, and
/// `SparseExactSpd` is presumably the sparse exact SPD factorization path
/// (cf. the `SparseExactFactor` import). Confirm against the call sites.
#[derive(Clone, Copy, Debug)]
pub(crate) enum RemlGeometry {
    DenseSpectral,
    SparseExactSpd,
}
2746
/// Common interface for penalized-geometry backends.
trait PenalizedGeometry {
    /// Reports which `GeometryBackendKind` this geometry implements.
    fn backend_kind(&self) -> GeometryBackendKind;
}
2750
/// Storage behind a hyper-parameter derivative matrix, used in either the
/// design role or the penalty role.
///
/// Each variant's payload implements `DerivativeStorageBackend`; callers reach
/// the payload through `storage_dispatch!`, which matches all five variants.
#[derive(Clone)]
pub(crate) enum DerivativeMatrixStorage {
    /// Fully materialised dense matrix.
    Dense(Array2<f64>),
    /// All-zero matrix, recorded only by its shape (`rows` × `cols`).
    Zero(ZeroDerivativeMatrix),
    /// NOTE(review): `EmbeddedDerivativeMatrix` is defined elsewhere. The
    /// `DerivativeStorageBackend` docs describe it as the one variant whose
    /// design view (local rows) differs from its penalty view (a `total_dim`
    /// square). Confirm against its definition.
    Embedded(EmbeddedDerivativeMatrix),
    /// Lazy operator over an `ImplicitDesignPsiDerivative`; see `ImplicitDerivativeOp`.
    Implicit(ImplicitDerivativeOp),
    /// Lazy operator for one flattened latent-coordinate axis; see
    /// `LatentCoordDerivativeOp`.
    LatentCoord(LatentCoordDerivativeOp),
}
2759
2760/// Mechanical surface every `DerivativeMatrixStorage` variant must expose so
2761/// the `HyperDesignDerivative` / `HyperPenaltyDerivative` wrappers can dispatch
2762/// with a single per-call `storage_dispatch!`. Each backend owns its variant's
2763/// substantive math; the wrappers contain only one-line routing.
2764///
2765/// `design_*` variants treat the backend as an X-style operator (rows index
2766/// data, columns index coefficients); `penalty_*` variants treat the backend
2767/// as a square `p×p` penalty in the global coefficient frame. The Embedded
2768/// case is the only variant whose two views genuinely differ (local rows vs
2769/// total_dim square), which is why the two role-specific methods both live in
2770/// one trait rather than two parallel traits.
2771trait DerivativeStorageBackend {
2772    fn resident_byte_count(&self) -> usize;
2773    fn design_nrows(&self) -> usize;
2774    fn design_ncols(&self) -> usize;
2775    fn penalty_dim(&self) -> usize;
2776    fn uses_implicit_storage(&self) -> bool;
2777    fn any_nonzero(&self) -> bool;
2778    fn materialize(&self) -> Array2<f64>;
2779    fn implicit_first_axis_info(
2780        &self,
2781    ) -> Option<(
2782        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
2783        usize,
2784    )>;
2785    fn implicit_axis_count_hint(&self) -> Option<usize>;
2786    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError>;
2787    fn design_transpose_mul_original(
2788        &self,
2789        v: &Array1<f64>,
2790    ) -> Result<Array1<f64>, EstimationError>;
2791    fn design_transformed(
2792        &self,
2793        qs: &Array2<f64>,
2794        free_basis_opt: Option<&Array2<f64>>,
2795    ) -> Result<Array2<f64>, EstimationError>;
2796    /// Default materialises through `design_transformed` then `.dot(u)`;
2797    /// implicit/latent-coordinate backends override with a direct-operator
2798    /// path that skips the dense materialisation.
2799    fn design_transformed_forward_mul(
2800        &self,
2801        qs: &Array2<f64>,
2802        free_basis_opt: Option<&Array2<f64>>,
2803        u: &Array1<f64>,
2804    ) -> Result<Array1<f64>, EstimationError> {
2805        Ok(self.design_transformed(qs, free_basis_opt)?.dot(u))
2806    }
2807    /// Default materialises through `design_transformed` then `.t().dot(v)`;
2808    /// implicit/latent-coordinate backends override with a direct path.
2809    fn design_transformed_transpose_mul(
2810        &self,
2811        qs: &Array2<f64>,
2812        free_basis_opt: Option<&Array2<f64>>,
2813        v: &Array1<f64>,
2814    ) -> Result<Array1<f64>, EstimationError> {
2815        Ok(self.design_transformed(qs, free_basis_opt)?.t().dot(v))
2816    }
2817    fn penalty_transformed(
2818        &self,
2819        qs: &Array2<f64>,
2820        free_basis_opt: Option<&Array2<f64>>,
2821    ) -> Result<Array2<f64>, EstimationError>;
2822    fn penalty_scaled_add_to(
2823        &self,
2824        target: &mut Array2<f64>,
2825        amp: f64,
2826    ) -> Result<(), EstimationError>;
2827}
2828
/// Fans `$body` out over the five `DerivativeMatrixStorage` variants in one
/// place, binding each variant's payload to `$backend` in its arm, so every
/// wrapper method is a single dispatch line. The generated `match` is
/// exhaustive. Adding a new variant therefore produces one hard error at this
/// site rather than a silent miss in any of the (currently 16) per-method
/// ladders.
///
/// Invocation shape: `storage_dispatch!(storage_expr, backend => backend.method(..))`.
macro_rules! storage_dispatch {
    ($scrutinee:expr, $backend:ident => $body:expr) => {
        match $scrutinee {
            DerivativeMatrixStorage::Dense($backend) => $body,
            DerivativeMatrixStorage::Zero($backend) => $body,
            DerivativeMatrixStorage::Embedded($backend) => $body,
            DerivativeMatrixStorage::Implicit($backend) => $body,
            DerivativeMatrixStorage::LatentCoord($backend) => $body,
        }
    };
}
2844
/// Shape-only stand-in for an all-zero derivative matrix: no entries are
/// stored, only the dimensions.
#[derive(Clone)]
pub(crate) struct ZeroDerivativeMatrix {
    /// Number of rows of the (implicit) zero matrix.
    rows: usize,
    /// Number of columns of the (implicit) zero matrix.
    cols: usize,
}
2850
impl ZeroDerivativeMatrix {
    /// Creates a zero-matrix placeholder of shape `rows × cols`.
    pub(crate) fn new(rows: usize, cols: usize) -> Self {
        Self { rows, cols }
    }
}
2856
/// Which derivative level the implicit operator should compute.
///
/// Stored in `ImplicitDerivativeOp::level`. `materialize_local` dispatches on
/// it to the operator's `materialize_first` / `materialize_second_diag` / …
/// routines. Axis payloads are ψ-axis indices.
#[derive(Clone, Copy, Debug)]
pub enum ImplicitDerivLevel {
    /// ∂X/∂ψ_d
    First(usize),
    /// ∂²X/∂ψ_d²
    SecondDiag(usize),
    /// ∂²X/∂ψ_d∂ψ_e
    SecondCross(usize, usize),
}
2867
/// Lazy implicit operator storage: delegates matvecs to the
/// `ImplicitDesignPsiDerivative` and materializes dense form only on demand.
#[derive(Clone)]
pub(crate) struct ImplicitDerivativeOp {
    /// Shared implicit ψ-derivative operator that performs the evaluation.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
    /// Which derivative (first, second diagonal, second cross) to evaluate.
    pub(crate) level: ImplicitDerivLevel,
    /// NOTE(review): presumably the columns of the `total_dim`-wide global
    /// coefficient frame occupied by the operator's local columns (mirrors
    /// `LatentCoordDerivativeOp`). Confirm against this type's impl.
    pub(crate) global_range: Range<usize>,
    /// NOTE(review): presumably the width of the global coefficient frame.
    /// Confirm against this type's impl.
    pub(crate) total_dim: usize,
    /// Cached dense materialization (lazy, populated on first call to ops that need the full matrix).
    ///
    /// Rayon-safe: `materialize_local` calls `materialize_first` / `_second_diag`
    /// / `_second_cross` on the implicit basis-derivative operator, which for
    /// streaming bases dispatches `(0..nc).into_par_iter().for_each(...)`. A plain
    /// `std::sync::OnceLock` here would deadlock if `materialize_dense` were first
    /// called concurrently from inside another rayon par_iter — racing workers
    /// would park on the OnceLock's OS condvar, leaving the leader's nested
    /// par_iter without workers. `RayonSafeOnce` runs init lock-free.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
2887
/// Lazy derivative operator for a single flattened latent-coordinate axis of a
/// `LatentCoordDesignDerivative`, embedded into a wider coefficient frame.
#[derive(Clone)]
pub(crate) struct LatentCoordDerivativeOp {
    /// Shared operator that evaluates, multiplies by, and materializes axis derivatives.
    pub(crate) operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
    /// Axis index passed to the operator's `*_axis` methods.
    pub(crate) flat_axis: usize,
    /// Columns of the `total_dim`-wide frame that hold the operator's local
    /// columns. Dense output and transpose products are written here, and
    /// forward products read their input from here.
    pub(crate) global_range: Range<usize>,
    /// Width of the global coefficient frame.
    pub(crate) total_dim: usize,
    /// Lazily filled embedded dense matrix, populated by `materialize_dense`
    /// via `RayonSafeOnce::get_or_compute`.
    pub(crate) cached_dense: std::sync::Arc<gam_runtime::resource::RayonSafeOnce<Array2<f64>>>,
}
2896
2897impl LatentCoordDerivativeOp {
2898    pub(crate) fn materialize_local(&self) -> Array2<f64> {
2899        self.operator.materialize_axis(self.flat_axis).expect(
2900            "radial scalar evaluation failed during latent-coordinate derivative materialization",
2901        )
2902    }
2903
2904    pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
2905        self.cached_dense.get_or_compute(|| {
2906            let local = self.materialize_local();
2907            let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
2908            out.slice_mut(s![.., self.global_range.clone()])
2909                .assign(&local);
2910            out
2911        })
2912    }
2913
2914    pub(crate) fn nrows(&self) -> usize {
2915        self.operator.n_data()
2916    }
2917
2918    pub(crate) fn ncols(&self) -> usize {
2919        self.total_dim
2920    }
2921
2922    pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
2923        let local = self
2924            .operator
2925            .transpose_mul_axis(self.flat_axis, &v.view())
2926            .expect(
2927                "radial scalar evaluation failed during latent-coordinate derivative transpose_mul",
2928            );
2929        let mut out = Array1::<f64>::zeros(self.total_dim);
2930        out.slice_mut(s![self.global_range.clone()]).assign(&local);
2931        out
2932    }
2933
2934    pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
2935        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
2936        self.operator
2937            .forward_mul_axis(self.flat_axis, &u_local.view())
2938            .expect(
2939                "radial scalar evaluation failed during latent-coordinate derivative forward_mul",
2940            )
2941    }
2942}
2943
2944impl ImplicitDerivativeOp {
2945    pub(crate) fn materialize_local(&self) -> Array2<f64> {
2946        match self.level {
2947            ImplicitDerivLevel::First(axis) => self.operator.materialize_first(axis).expect(
2948                "radial scalar evaluation failed during implicit derivative materialization",
2949            ),
2950            ImplicitDerivLevel::SecondDiag(axis) => {
2951                self.operator.materialize_second_diag(axis).expect(
2952                    "radial scalar evaluation failed during implicit derivative materialization",
2953                )
2954            }
2955            ImplicitDerivLevel::SecondCross(d, e) => {
2956                self.operator.materialize_second_cross(d, e).expect(
2957                    "radial scalar evaluation failed during implicit derivative materialization",
2958                )
2959            }
2960        }
2961    }
2962
2963    pub(crate) fn materialize_dense(&self) -> &Array2<f64> {
2964        self.cached_dense.get_or_compute(|| {
2965            let local = self.materialize_local();
2966            let mut out = Array2::<f64>::zeros((local.nrows(), self.total_dim));
2967            out.slice_mut(s![.., self.global_range.clone()])
2968                .assign(&local);
2969            out
2970        })
2971    }
2972
2973    pub(crate) fn nrows(&self) -> usize {
2974        self.operator.n_data()
2975    }
2976
2977    pub(crate) fn ncols(&self) -> usize {
2978        self.total_dim
2979    }
2980
2981    pub(crate) fn transpose_mul(&self, v: &Array1<f64>) -> Array1<f64> {
2982        let local = match self.level {
2983            ImplicitDerivLevel::First(axis) => self
2984                .operator
2985                .transpose_mul(axis, &v.view())
2986                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
2987            ImplicitDerivLevel::SecondDiag(axis) => self
2988                .operator
2989                .transpose_mul_second_diag(axis, &v.view())
2990                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
2991            ImplicitDerivLevel::SecondCross(d, e) => self
2992                .operator
2993                .transpose_mul_second_cross(d, e, &v.view())
2994                .expect("radial scalar evaluation failed during implicit derivative transpose_mul"),
2995        };
2996        let mut out = Array1::<f64>::zeros(self.total_dim);
2997        out.slice_mut(s![self.global_range.clone()]).assign(&local);
2998        out
2999    }
3000
3001    pub(crate) fn forward_mul(&self, u: &Array1<f64>) -> Array1<f64> {
3002        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3003        match self.level {
3004            ImplicitDerivLevel::First(axis) => self
3005                .operator
3006                .forward_mul(axis, &u_local.view())
3007                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3008            ImplicitDerivLevel::SecondDiag(axis) => self
3009                .operator
3010                .forward_mul_second_diag(axis, &u_local.view())
3011                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3012            ImplicitDerivLevel::SecondCross(d, e) => self
3013                .operator
3014                .forward_mul_second_cross(d, e, &u_local.view())
3015                .expect("radial scalar evaluation failed during implicit derivative forward_mul"),
3016        }
3017    }
3018}
3019
3020#[derive(Clone)]
3021pub(crate) struct EmbeddedDerivativeMatrix {
3022    pub(crate) local: Array2<f64>,
3023    pub(crate) global_range: Range<usize>,
3024    pub(crate) total_dim: usize,
3025}
3026
3027impl EmbeddedDerivativeMatrix {
3028    pub(crate) fn new(local: Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
3029        Self {
3030            local,
3031            global_range,
3032            total_dim,
3033        }
3034    }
3035}
3036
3037impl DerivativeStorageBackend for Array2<f64> {
3038    fn resident_byte_count(&self) -> usize {
3039        self.len().saturating_mul(std::mem::size_of::<f64>())
3040    }
3041    fn design_nrows(&self) -> usize {
3042        Array2::nrows(self)
3043    }
3044    fn design_ncols(&self) -> usize {
3045        Array2::ncols(self)
3046    }
3047    fn penalty_dim(&self) -> usize {
3048        Array2::nrows(self)
3049    }
3050    fn uses_implicit_storage(&self) -> bool {
3051        false
3052    }
3053    fn any_nonzero(&self) -> bool {
3054        self.iter().any(|v| *v != 0.0)
3055    }
3056    fn materialize(&self) -> Array2<f64> {
3057        self.clone()
3058    }
3059    fn implicit_first_axis_info(
3060        &self,
3061    ) -> Option<(
3062        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3063        usize,
3064    )> {
3065        None
3066    }
3067    fn implicit_axis_count_hint(&self) -> Option<usize> {
3068        None
3069    }
3070
3071    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3072        if Array2::ncols(self) != u.len() {
3073            crate::bail_invalid_estim!(
3074                "dense hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3075                Array2::nrows(self),
3076                Array2::ncols(self),
3077                u.len()
3078            );
3079        }
3080        Ok(self.dot(u))
3081    }
3082
3083    fn design_transpose_mul_original(
3084        &self,
3085        v: &Array1<f64>,
3086    ) -> Result<Array1<f64>, EstimationError> {
3087        if Array2::nrows(self) != v.len() {
3088            crate::bail_invalid_estim!(
3089                "dense hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3090                Array2::nrows(self),
3091                Array2::ncols(self),
3092                v.len()
3093            );
3094        }
3095        Ok(self.t().dot(v))
3096    }
3097
3098    fn design_transformed(
3099        &self,
3100        qs: &Array2<f64>,
3101        free_basis_opt: Option<&Array2<f64>>,
3102    ) -> Result<Array2<f64>, EstimationError> {
3103        Ok(gam_linalg::matrix::DenseRightProductView::new(self)
3104            .with_factor(qs)
3105            .with_optional_factor(free_basis_opt)
3106            .materialize())
3107    }
3108
3109    fn penalty_transformed(
3110        &self,
3111        qs: &Array2<f64>,
3112        free_basis_opt: Option<&Array2<f64>>,
3113    ) -> Result<Array2<f64>, EstimationError> {
3114        let mut transformed = qs.t().dot(self).dot(qs);
3115        if let Some(z) = free_basis_opt {
3116            transformed = z.t().dot(&transformed).dot(z);
3117        }
3118        Ok(transformed)
3119    }
3120
3121    fn penalty_scaled_add_to(
3122        &self,
3123        target: &mut Array2<f64>,
3124        amp: f64,
3125    ) -> Result<(), EstimationError> {
3126        if target.raw_dim() != self.raw_dim() {
3127            crate::bail_invalid_estim!(
3128                "dense hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3129                target.nrows(),
3130                target.ncols(),
3131                Array2::nrows(self),
3132                Array2::ncols(self)
3133            );
3134        }
3135        target.scaled_add(amp, self);
3136        Ok(())
3137    }
3138}
3139
3140impl DerivativeStorageBackend for ZeroDerivativeMatrix {
3141    fn resident_byte_count(&self) -> usize {
3142        0
3143    }
3144    fn design_nrows(&self) -> usize {
3145        self.rows
3146    }
3147    fn design_ncols(&self) -> usize {
3148        self.cols
3149    }
3150    fn penalty_dim(&self) -> usize {
3151        self.cols
3152    }
3153    fn uses_implicit_storage(&self) -> bool {
3154        false
3155    }
3156    fn any_nonzero(&self) -> bool {
3157        false
3158    }
3159    fn materialize(&self) -> Array2<f64> {
3160        Array2::<f64>::zeros((self.rows, self.cols))
3161    }
3162    fn implicit_first_axis_info(
3163        &self,
3164    ) -> Option<(
3165        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3166        usize,
3167    )> {
3168        None
3169    }
3170    fn implicit_axis_count_hint(&self) -> Option<usize> {
3171        None
3172    }
3173
3174    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3175        if self.cols != u.len() {
3176            crate::bail_invalid_estim!(
3177                "zero hyper design derivative forward_mul_original width mismatch: matrix={}x{}, vector={}",
3178                self.rows,
3179                self.cols,
3180                u.len()
3181            );
3182        }
3183        Ok(Array1::<f64>::zeros(self.rows))
3184    }
3185
3186    fn design_transpose_mul_original(
3187        &self,
3188        v: &Array1<f64>,
3189    ) -> Result<Array1<f64>, EstimationError> {
3190        if self.rows != v.len() {
3191            crate::bail_invalid_estim!(
3192                "zero hyper design derivative transpose_mul_original height mismatch: matrix={}x{}, vector={}",
3193                self.rows,
3194                self.cols,
3195                v.len()
3196            );
3197        }
3198        Ok(Array1::<f64>::zeros(self.cols))
3199    }
3200
3201    fn design_transformed(
3202        &self,
3203        qs: &Array2<f64>,
3204        free_basis_opt: Option<&Array2<f64>>,
3205    ) -> Result<Array2<f64>, EstimationError> {
3206        if self.cols != qs.nrows() {
3207            crate::bail_invalid_estim!(
3208                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3209                self.cols,
3210                qs.nrows()
3211            );
3212        }
3213        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3214        Ok(Array2::<f64>::zeros((self.rows, cols)))
3215    }
3216
3217    fn design_transformed_forward_mul(
3218        &self,
3219        qs: &Array2<f64>,
3220        free_basis_opt: Option<&Array2<f64>>,
3221        u: &Array1<f64>,
3222    ) -> Result<Array1<f64>, EstimationError> {
3223        if self.cols != qs.nrows() {
3224            crate::bail_invalid_estim!(
3225                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3226                self.cols,
3227                qs.nrows()
3228            );
3229        }
3230        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3231        if u.len() != cols {
3232            crate::bail_invalid_estim!(
3233                "zero design derivative transformed forward width mismatch: expected {}, vector={}",
3234                cols,
3235                u.len()
3236            );
3237        }
3238        Ok(Array1::<f64>::zeros(self.rows))
3239    }
3240
3241    fn design_transformed_transpose_mul(
3242        &self,
3243        qs: &Array2<f64>,
3244        free_basis_opt: Option<&Array2<f64>>,
3245        v: &Array1<f64>,
3246    ) -> Result<Array1<f64>, EstimationError> {
3247        if self.rows != v.len() {
3248            crate::bail_invalid_estim!(
3249                "zero design derivative transpose height mismatch: matrix rows={}, vector={}",
3250                self.rows,
3251                v.len()
3252            );
3253        }
3254        if self.cols != qs.nrows() {
3255            crate::bail_invalid_estim!(
3256                "zero design derivative width mismatch: total_cols={}, qs rows={}",
3257                self.cols,
3258                qs.nrows()
3259            );
3260        }
3261        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3262        Ok(Array1::<f64>::zeros(cols))
3263    }
3264
3265    fn penalty_transformed(
3266        &self,
3267        qs: &Array2<f64>,
3268        free_basis_opt: Option<&Array2<f64>>,
3269    ) -> Result<Array2<f64>, EstimationError> {
3270        if self.cols != qs.nrows() {
3271            crate::bail_invalid_estim!(
3272                "zero penalty derivative width mismatch: total_dim={}, qs rows={}",
3273                self.cols,
3274                qs.nrows()
3275            );
3276        }
3277        let cols = free_basis_opt.map_or(qs.ncols(), |z| z.ncols());
3278        Ok(Array2::<f64>::zeros((cols, cols)))
3279    }
3280
3281    fn penalty_scaled_add_to(
3282        &self,
3283        target: &mut Array2<f64>,
3284        amp: f64,
3285    ) -> Result<(), EstimationError> {
3286        // Zero penalty derivative: `amp · 0 = 0`, so `amp` scales nothing and
3287        // `target` is left unchanged. Validate it is finite so a bad scale
3288        // surfaces here rather than silently no-op'ing on a NaN/inf amplitude.
3289        if !amp.is_finite() {
3290            crate::bail_invalid_estim!(
3291                "zero hyper penalty derivative received non-finite amp={amp}"
3292            );
3293        }
3294        if target.nrows() != self.cols || target.ncols() != self.cols {
3295            crate::bail_invalid_estim!(
3296                "zero hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3297                target.nrows(),
3298                target.ncols(),
3299                self.cols,
3300                self.cols
3301            );
3302        }
3303        Ok(())
3304    }
3305}
3306
3307impl DerivativeStorageBackend for EmbeddedDerivativeMatrix {
3308    fn resident_byte_count(&self) -> usize {
3309        self.local.len().saturating_mul(std::mem::size_of::<f64>())
3310    }
3311    fn design_nrows(&self) -> usize {
3312        self.local.nrows()
3313    }
3314    fn design_ncols(&self) -> usize {
3315        self.total_dim
3316    }
3317    fn penalty_dim(&self) -> usize {
3318        self.total_dim
3319    }
3320    fn uses_implicit_storage(&self) -> bool {
3321        false
3322    }
3323    fn any_nonzero(&self) -> bool {
3324        self.local.iter().any(|v| *v != 0.0)
3325    }
3326    fn materialize(&self) -> Array2<f64> {
3327        let mut dense = Array2::<f64>::zeros((self.local.nrows(), self.total_dim));
3328        dense
3329            .slice_mut(s![.., self.global_range.clone()])
3330            .assign(&self.local);
3331        dense
3332    }
3333    fn implicit_first_axis_info(
3334        &self,
3335    ) -> Option<(
3336        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3337        usize,
3338    )> {
3339        None
3340    }
3341    fn implicit_axis_count_hint(&self) -> Option<usize> {
3342        None
3343    }
3344
3345    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3346        if self.total_dim != u.len() {
3347            crate::bail_invalid_estim!(
3348                "embedded hyper design derivative forward_mul_original width mismatch: total_dim={}, vector={}",
3349                self.total_dim,
3350                u.len()
3351            );
3352        }
3353        let u_local = u.slice(s![self.global_range.clone()]).to_owned();
3354        Ok(self.local.dot(&u_local))
3355    }
3356
3357    fn design_transpose_mul_original(
3358        &self,
3359        v: &Array1<f64>,
3360    ) -> Result<Array1<f64>, EstimationError> {
3361        if self.local.nrows() != v.len() {
3362            crate::bail_invalid_estim!(
3363                "embedded hyper design derivative transpose_mul_original height mismatch: local_rows={}, vector={}",
3364                self.local.nrows(),
3365                v.len()
3366            );
3367        }
3368        let mut out = Array1::<f64>::zeros(self.total_dim);
3369        let pulled = self.local.t().dot(v);
3370        out.slice_mut(s![self.global_range.clone()]).assign(&pulled);
3371        Ok(out)
3372    }
3373
3374    fn design_transformed(
3375        &self,
3376        qs: &Array2<f64>,
3377        free_basis_opt: Option<&Array2<f64>>,
3378    ) -> Result<Array2<f64>, EstimationError> {
3379        if self.total_dim != qs.nrows() {
3380            crate::bail_invalid_estim!(
3381                "embedded design derivative width mismatch: total_cols={}, qs rows={}",
3382                self.total_dim,
3383                qs.nrows()
3384            );
3385        }
3386        let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3387        let mut transformed = self.local.dot(&qs_local);
3388        if let Some(z) = free_basis_opt {
3389            transformed = transformed.dot(z);
3390        }
3391        Ok(transformed)
3392    }
3393
3394    fn penalty_transformed(
3395        &self,
3396        qs: &Array2<f64>,
3397        free_basis_opt: Option<&Array2<f64>>,
3398    ) -> Result<Array2<f64>, EstimationError> {
3399        if self.total_dim != qs.nrows() {
3400            crate::bail_invalid_estim!(
3401                "embedded penalty derivative width mismatch: total_dim={}, qs rows={}",
3402                self.total_dim,
3403                qs.nrows()
3404            );
3405        }
3406        let qs_local = qs.slice(s![self.global_range.clone(), ..]);
3407        let mut transformed = qs_local.t().dot(&self.local).dot(&qs_local);
3408        if let Some(z) = free_basis_opt {
3409            transformed = z.t().dot(&transformed).dot(z);
3410        }
3411        Ok(transformed)
3412    }
3413
3414    fn penalty_scaled_add_to(
3415        &self,
3416        target: &mut Array2<f64>,
3417        amp: f64,
3418    ) -> Result<(), EstimationError> {
3419        if target.nrows() != self.total_dim || target.ncols() != self.total_dim {
3420            crate::bail_invalid_estim!(
3421                "embedded hyper penalty derivative shape mismatch: target={}x{}, expected {}x{}",
3422                target.nrows(),
3423                target.ncols(),
3424                self.total_dim,
3425                self.total_dim
3426            );
3427        }
3428        target
3429            .slice_mut(s![self.global_range.clone(), self.global_range.clone()])
3430            .scaled_add(amp, &self.local);
3431        Ok(())
3432    }
3433}
3434
3435impl DerivativeStorageBackend for ImplicitDerivativeOp {
3436    fn resident_byte_count(&self) -> usize {
3437        0
3438    }
3439    fn design_nrows(&self) -> usize {
3440        self.nrows()
3441    }
3442    fn design_ncols(&self) -> usize {
3443        self.ncols()
3444    }
3445    fn penalty_dim(&self) -> usize {
3446        self.nrows()
3447    }
3448    fn uses_implicit_storage(&self) -> bool {
3449        true
3450    }
3451    fn any_nonzero(&self) -> bool {
3452        true
3453    }
3454    fn materialize(&self) -> Array2<f64> {
3455        self.materialize_dense().clone()
3456    }
3457    fn implicit_first_axis_info(
3458        &self,
3459    ) -> Option<(
3460        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3461        usize,
3462    )> {
3463        match self.level {
3464            ImplicitDerivLevel::First(axis) => Some((self.operator.clone(), axis)),
3465            _ => None,
3466        }
3467    }
3468    fn implicit_axis_count_hint(&self) -> Option<usize> {
3469        Some(self.operator.n_axes())
3470    }
3471
3472    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3473        if self.ncols() != u.len() {
3474            crate::bail_invalid_estim!(
3475                "implicit hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3476                self.ncols(),
3477                u.len()
3478            );
3479        }
3480        Ok(self.forward_mul(u))
3481    }
3482
3483    fn design_transpose_mul_original(
3484        &self,
3485        v: &Array1<f64>,
3486    ) -> Result<Array1<f64>, EstimationError> {
3487        if self.nrows() != v.len() {
3488            crate::bail_invalid_estim!(
3489                "implicit hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3490                self.nrows(),
3491                v.len()
3492            );
3493        }
3494        Ok(self.transpose_mul(v))
3495    }
3496
3497    fn design_transformed(
3498        &self,
3499        qs: &Array2<f64>,
3500        free_basis_opt: Option<&Array2<f64>>,
3501    ) -> Result<Array2<f64>, EstimationError> {
3502        let dense = self.materialize_dense();
3503        Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3504            .with_factor(qs)
3505            .with_optional_factor(free_basis_opt)
3506            .materialize())
3507    }
3508
3509    fn design_transformed_forward_mul(
3510        &self,
3511        qs: &Array2<f64>,
3512        free_basis_opt: Option<&Array2<f64>>,
3513        u: &Array1<f64>,
3514    ) -> Result<Array1<f64>, EstimationError> {
3515        let mut right = if let Some(z) = free_basis_opt {
3516            z.dot(u)
3517        } else {
3518            u.clone()
3519        };
3520        right = qs.dot(&right);
3521        Ok(self.forward_mul(&right))
3522    }
3523
3524    fn design_transformed_transpose_mul(
3525        &self,
3526        qs: &Array2<f64>,
3527        free_basis_opt: Option<&Array2<f64>>,
3528        v: &Array1<f64>,
3529    ) -> Result<Array1<f64>, EstimationError> {
3530        let mut pulled = qs.t().dot(&self.transpose_mul(v));
3531        if let Some(z) = free_basis_opt {
3532            pulled = z.t().dot(&pulled);
3533        }
3534        Ok(pulled)
3535    }
3536
3537    fn penalty_transformed(
3538        &self,
3539        qs: &Array2<f64>,
3540        free_basis_opt: Option<&Array2<f64>>,
3541    ) -> Result<Array2<f64>, EstimationError> {
3542        let dense = self.materialize_dense();
3543        let mut transformed = qs.t().dot(dense).dot(qs);
3544        if let Some(z) = free_basis_opt {
3545            transformed = z.t().dot(&transformed).dot(z);
3546        }
3547        Ok(transformed)
3548    }
3549
3550    fn penalty_scaled_add_to(
3551        &self,
3552        target: &mut Array2<f64>,
3553        amp: f64,
3554    ) -> Result<(), EstimationError> {
3555        let dense = self.materialize_dense();
3556        if target.raw_dim() != dense.raw_dim() {
3557            crate::bail_invalid_estim!(
3558                "implicit hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3559                target.nrows(),
3560                target.ncols(),
3561                dense.nrows(),
3562                dense.ncols()
3563            );
3564        }
3565        target.scaled_add(amp, dense);
3566        Ok(())
3567    }
3568}
3569
3570impl DerivativeStorageBackend for LatentCoordDerivativeOp {
3571    fn resident_byte_count(&self) -> usize {
3572        0
3573    }
3574    fn design_nrows(&self) -> usize {
3575        self.nrows()
3576    }
3577    fn design_ncols(&self) -> usize {
3578        self.ncols()
3579    }
3580    fn penalty_dim(&self) -> usize {
3581        self.nrows()
3582    }
3583    fn uses_implicit_storage(&self) -> bool {
3584        true
3585    }
3586    fn any_nonzero(&self) -> bool {
3587        true
3588    }
3589    fn materialize(&self) -> Array2<f64> {
3590        self.materialize_dense().clone()
3591    }
3592    fn implicit_first_axis_info(
3593        &self,
3594    ) -> Option<(
3595        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3596        usize,
3597    )> {
3598        None
3599    }
3600    fn implicit_axis_count_hint(&self) -> Option<usize> {
3601        Some(self.operator.n_axes())
3602    }
3603
3604    fn design_forward_mul_original(&self, u: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
3605        if self.ncols() != u.len() {
3606            crate::bail_invalid_estim!(
3607                "latent-coordinate hyper design derivative forward_mul_original width mismatch: operator_cols={}, vector={}",
3608                self.ncols(),
3609                u.len()
3610            );
3611        }
3612        Ok(self.forward_mul(u))
3613    }
3614
3615    fn design_transpose_mul_original(
3616        &self,
3617        v: &Array1<f64>,
3618    ) -> Result<Array1<f64>, EstimationError> {
3619        if self.nrows() != v.len() {
3620            crate::bail_invalid_estim!(
3621                "latent-coordinate hyper design derivative transpose_mul_original height mismatch: operator_rows={}, vector={}",
3622                self.nrows(),
3623                v.len()
3624            );
3625        }
3626        Ok(self.transpose_mul(v))
3627    }
3628
3629    fn design_transformed(
3630        &self,
3631        qs: &Array2<f64>,
3632        free_basis_opt: Option<&Array2<f64>>,
3633    ) -> Result<Array2<f64>, EstimationError> {
3634        let dense = self.materialize_dense();
3635        Ok(gam_linalg::matrix::DenseRightProductView::new(dense)
3636            .with_factor(qs)
3637            .with_optional_factor(free_basis_opt)
3638            .materialize())
3639    }
3640
3641    fn design_transformed_forward_mul(
3642        &self,
3643        qs: &Array2<f64>,
3644        free_basis_opt: Option<&Array2<f64>>,
3645        u: &Array1<f64>,
3646    ) -> Result<Array1<f64>, EstimationError> {
3647        let mut right = if let Some(z) = free_basis_opt {
3648            z.dot(u)
3649        } else {
3650            u.clone()
3651        };
3652        right = qs.dot(&right);
3653        Ok(self.forward_mul(&right))
3654    }
3655
3656    fn design_transformed_transpose_mul(
3657        &self,
3658        qs: &Array2<f64>,
3659        free_basis_opt: Option<&Array2<f64>>,
3660        v: &Array1<f64>,
3661    ) -> Result<Array1<f64>, EstimationError> {
3662        let mut pulled = qs.t().dot(&self.transpose_mul(v));
3663        if let Some(z) = free_basis_opt {
3664            pulled = z.t().dot(&pulled);
3665        }
3666        Ok(pulled)
3667    }
3668
3669    fn penalty_transformed(
3670        &self,
3671        qs: &Array2<f64>,
3672        free_basis_opt: Option<&Array2<f64>>,
3673    ) -> Result<Array2<f64>, EstimationError> {
3674        let dense = self.materialize_dense();
3675        let mut transformed = qs.t().dot(dense).dot(qs);
3676        if let Some(z) = free_basis_opt {
3677            transformed = z.t().dot(&transformed).dot(z);
3678        }
3679        Ok(transformed)
3680    }
3681
3682    fn penalty_scaled_add_to(
3683        &self,
3684        target: &mut Array2<f64>,
3685        amp: f64,
3686    ) -> Result<(), EstimationError> {
3687        let dense = self.materialize_dense();
3688        if target.raw_dim() != dense.raw_dim() {
3689            crate::bail_invalid_estim!(
3690                "latent-coordinate hyper penalty derivative shape mismatch: target={}x{}, matrix={}x{}",
3691                target.nrows(),
3692                target.ncols(),
3693                dense.nrows(),
3694                dense.ncols()
3695            );
3696        }
3697        target.scaled_add(amp, dense);
3698        Ok(())
3699    }
3700}
3701
3702#[derive(Clone)]
3703pub struct HyperDesignDerivative {
3704    pub(crate) storage: DerivativeMatrixStorage,
3705}
3706
3707impl HyperDesignDerivative {
3708    pub fn zero(nrows: usize, ncols: usize) -> Self {
3709        Self {
3710            storage: DerivativeMatrixStorage::Zero(ZeroDerivativeMatrix::new(nrows, ncols)),
3711        }
3712    }
3713
3714    pub fn from_embedded(
3715        local: Array2<f64>,
3716        global_range: Range<usize>,
3717        total_cols: usize,
3718    ) -> Self {
3719        Self {
3720            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3721                local,
3722                global_range,
3723                total_cols,
3724            )),
3725        }
3726    }
3727
3728    pub fn from_implicit(
3729        operator: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
3730        level: ImplicitDerivLevel,
3731        global_range: Range<usize>,
3732        total_cols: usize,
3733    ) -> Self {
3734        Self {
3735            storage: DerivativeMatrixStorage::Implicit(ImplicitDerivativeOp {
3736                operator,
3737                level,
3738                global_range,
3739                total_dim: total_cols,
3740                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3741            }),
3742        }
3743    }
3744
3745    pub fn from_latent_coord(
3746        operator: std::sync::Arc<gam_terms::basis::LatentCoordDesignDerivative>,
3747        flat_axis: usize,
3748        global_range: Range<usize>,
3749        total_cols: usize,
3750    ) -> Self {
3751        Self {
3752            storage: DerivativeMatrixStorage::LatentCoord(LatentCoordDerivativeOp {
3753                operator,
3754                flat_axis,
3755                global_range,
3756                total_dim: total_cols,
3757                cached_dense: std::sync::Arc::new(gam_runtime::resource::RayonSafeOnce::new()),
3758            }),
3759        }
3760    }
3761
3762    pub(crate) fn resident_byte_count(&self) -> usize {
3763        storage_dispatch!(&self.storage, b => b.resident_byte_count())
3764    }
3765
3766    pub(crate) fn nrows(&self) -> usize {
3767        storage_dispatch!(&self.storage, b => b.design_nrows())
3768    }
3769
3770    pub(crate) fn ncols(&self) -> usize {
3771        storage_dispatch!(&self.storage, b => b.design_ncols())
3772    }
3773
3774    pub(crate) fn uses_implicit_storage(&self) -> bool {
3775        storage_dispatch!(&self.storage, b => b.uses_implicit_storage())
3776    }
3777
3778    pub(crate) fn materialize(&self) -> Array2<f64> {
3779        storage_dispatch!(&self.storage, b => b.materialize())
3780    }
3781
3782    pub(crate) fn any_nonzero(&self) -> bool {
3783        storage_dispatch!(&self.storage, b => b.any_nonzero())
3784    }
3785
3786    pub(crate) fn forward_mul_original(
3787        &self,
3788        u: &Array1<f64>,
3789    ) -> Result<Array1<f64>, EstimationError> {
3790        storage_dispatch!(&self.storage, b => b.design_forward_mul_original(u))
3791    }
3792
3793    pub(crate) fn transpose_mul_original(
3794        &self,
3795        v: &Array1<f64>,
3796    ) -> Result<Array1<f64>, EstimationError> {
3797        storage_dispatch!(&self.storage, b => b.design_transpose_mul_original(v))
3798    }
3799
    /// Dense design derivative expressed in the transformed basis `qs`,
    /// optionally further restricted by `free_basis_opt`. Presumably this is
    /// `X_τ · Qs` (times the free basis when given); confirm in the storage
    /// implementation.
    ///
    /// # Errors
    /// Propagates any `EstimationError` from `design_transformed`.
    pub(crate) fn transformed(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
    ) -> Result<Array2<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed(qs, free_basis_opt))
    }
3807
    /// Forward product with `u`, where `u` lives in the transformed basis
    /// described by `qs` / `free_basis_opt`. The product is formed by the
    /// storage without this method building the transformed matrix itself.
    ///
    /// # Errors
    /// Propagates any `EstimationError` from
    /// `design_transformed_forward_mul`.
    pub(crate) fn transformed_forward_mul(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
        u: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed_forward_mul(qs, free_basis_opt, u))
    }
3816
    /// Transposed product with `v`, returning a vector in the transformed
    /// basis described by `qs` / `free_basis_opt`.
    ///
    /// # Errors
    /// Propagates any `EstimationError` from
    /// `design_transformed_transpose_mul`.
    pub(crate) fn transformed_transpose_mul(
        &self,
        qs: &Array2<f64>,
        free_basis_opt: Option<&Array2<f64>>,
        v: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        storage_dispatch!(&self.storage, b => b.design_transformed_transpose_mul(qs, free_basis_opt, v))
    }
3825
    /// If this derivative uses implicit storage at the first-derivative level,
    /// return the shared implicit operator and the axis index.
    ///
    /// Returns `None` for dense/embedded storage or for second-derivative levels.
    /// The decision is made by the active storage variant.
    pub(crate) fn implicit_first_axis_info(
        &self,
    ) -> Option<(
        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
        usize,
    )> {
        storage_dispatch!(&self.storage, b => b.implicit_first_axis_info())
    }
3838
    /// Optional axis-count hint reported by the storage. `None` means the
    /// storage does not provide one; presumably that covers non-implicit
    /// variants (TODO confirm).
    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
        storage_dispatch!(&self.storage, b => b.implicit_axis_count_hint())
    }
3842}
3843
3844impl From<Array2<f64>> for HyperDesignDerivative {
3845    fn from(value: Array2<f64>) -> Self {
3846        Self {
3847            storage: DerivativeMatrixStorage::Dense(value),
3848        }
3849    }
3850}
3851
/// Derivative of a penalty matrix with respect to a hyper-parameter.
///
/// It is held in one of the `DerivativeMatrixStorage` forms, for example a
/// dense matrix (via `From<Array2<f64>>`) or an embedded local block (via
/// `HyperPenaltyDerivative::from_embedded`).
#[derive(Clone)]
pub struct HyperPenaltyDerivative {
    pub(crate) storage: DerivativeMatrixStorage,
}
3856
3857impl HyperPenaltyDerivative {
3858    pub fn from_embedded(
3859        local: Array2<f64>,
3860        global_range: Range<usize>,
3861        total_dim: usize,
3862    ) -> Self {
3863        Self {
3864            storage: DerivativeMatrixStorage::Embedded(EmbeddedDerivativeMatrix::new(
3865                local,
3866                global_range,
3867                total_dim,
3868            )),
3869        }
3870    }
3871
3872    pub(crate) fn resident_byte_count(&self) -> usize {
3873        storage_dispatch!(&self.storage, b => b.resident_byte_count())
3874    }
3875
3876    pub(crate) fn nrows(&self) -> usize {
3877        storage_dispatch!(&self.storage, b => b.penalty_dim())
3878    }
3879
3880    pub(crate) fn ncols(&self) -> usize {
3881        self.nrows()
3882    }
3883
3884    pub(crate) fn scaled_materialize(&self, amp: f64) -> Array2<f64> {
3885        let mut out = Array2::<f64>::zeros((self.nrows(), self.ncols()));
3886        self.scaled_add_to(&mut out, amp)
3887            .expect("scaled materialize uses matching target shape");
3888        out
3889    }
3890
3891    pub(crate) fn transformed(
3892        &self,
3893        qs: &Array2<f64>,
3894        free_basis_opt: Option<&Array2<f64>>,
3895    ) -> Result<Array2<f64>, EstimationError> {
3896        storage_dispatch!(&self.storage, b => b.penalty_transformed(qs, free_basis_opt))
3897    }
3898
3899    pub(crate) fn scaled_add_to(
3900        &self,
3901        target: &mut Array2<f64>,
3902        amp: f64,
3903    ) -> Result<(), EstimationError> {
3904        storage_dispatch!(&self.storage, b => b.penalty_scaled_add_to(target, amp))
3905    }
3906}
3907
3908impl From<Array2<f64>> for HyperPenaltyDerivative {
3909    fn from(value: Array2<f64>) -> Self {
3910        Self {
3911            storage: DerivativeMatrixStorage::Dense(value),
3912        }
3913    }
3914}
3915
/// One base-penalty term of a hyper-derivative: a derivative matrix paired
/// with the index of the base penalty it belongs to.
#[derive(Clone)]
pub struct PenaltyDerivativeComponent {
    /// Index of the base penalty. `DirectionalHyperParam::penalty_total_at`
    /// uses it to index `rho` and scales the matrix by `exp(rho[index])`.
    pub penalty_index: usize,
    /// Derivative matrix contributed for that base penalty.
    pub matrix: HyperPenaltyDerivative,
}
3921
/// One outer hyper-coordinate direction (a τ or ψ coordinate), described by
/// its design derivative and a canonical per-base-penalty decomposition of
/// its penalty derivative.
///
/// Construct with `DirectionalHyperParam::new_compact`, which rejects
/// duplicate penalty indices within any component list.
#[derive(Clone)]
pub struct DirectionalHyperParam {
    // First design derivative X_tau in original coordinates.
    pub(crate) x_tau_original: HyperDesignDerivative,
    // Canonical penalty representation: every tau direction is decomposed into
    // base-penalty derivatives. There is no separate "assembled total" path.
    pub(crate) penalty_first_components: Vec<PenaltyDerivativeComponent>,
    // Optional pairwise second hyper-derivatives against all tau directions.
    // If provided, each vector must have length psi_dim and hold an optional
    // X_{tau_i,tau_j} entry in original coordinates.
    pub(crate) x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
    // Pairwise second derivatives are stored in the same canonical base-penalty
    // decomposition as the first derivatives.
    pub(crate) penaltysecond_components: Option<Vec<Option<Vec<PenaltyDerivativeComponent>>>>,
    // Lazy fallback for the second-order penalty components of partner j.
    // `penaltysecond_components_for(j)` calls it only when no stored row
    // exists for j.
    pub(crate) penaltysecond_component_provider: Option<
        std::sync::Arc<
            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
                + Send
                + Sync
                + 'static,
        >,
    >,
    // Partner indices j declared to have a second-order pair even when no row
    // is stored (consulted by `has_penaltysecond_pair_at`).
    pub(crate) penaltysecond_partner_indices: Option<std::sync::Arc<[usize]>>,
    /// Whether this coordinate is penalty-like (B_i = ∂H/∂τ_i is PSD).
    /// True for τ (penalty scaling) coordinates; false for ψ (design-moving,
    /// anisotropic length-scale) coordinates. Controls EFS eligibility.
    pub(crate) is_penalty_like: bool,
}
3949
3950impl DirectionalHyperParam {
3951    pub(crate) fn resident_byte_count(&self) -> usize {
3952        let mut bytes = self.x_tau_original.resident_byte_count();
3953        for component in &self.penalty_first_components {
3954            bytes = bytes.saturating_add(component.matrix.resident_byte_count());
3955        }
3956        if let Some(entries) = self.x_tau_tau_original.as_ref() {
3957            for entry in entries.iter().flatten() {
3958                bytes = bytes.saturating_add(entry.resident_byte_count());
3959            }
3960        }
3961        if let Some(rows) = self.penaltysecond_components.as_ref() {
3962            for components in rows.iter().flatten() {
3963                for component in components {
3964                    bytes = bytes.saturating_add(component.matrix.resident_byte_count());
3965                }
3966            }
3967        }
3968        bytes
3969    }
3970
3971    pub(crate) fn canonicalize_penalty_components(
3972        components: Vec<(usize, HyperPenaltyDerivative)>,
3973    ) -> Result<Vec<PenaltyDerivativeComponent>, EstimationError> {
3974        let mut out: Vec<PenaltyDerivativeComponent> = Vec::with_capacity(components.len());
3975        for (penalty_index, matrix) in components {
3976            if out.iter().any(|c| c.penalty_index == penalty_index) {
3977                crate::bail_invalid_estim!(
3978                    "duplicate penalty derivative component for penalty {}",
3979                    penalty_index
3980                );
3981            }
3982            out.push(PenaltyDerivativeComponent {
3983                penalty_index,
3984                matrix,
3985            });
3986        }
3987        Ok(out)
3988    }
3989
3990    pub fn new_compact(
3991        x_tau_original: HyperDesignDerivative,
3992        penalty_first_components: Vec<(usize, HyperPenaltyDerivative)>,
3993        x_tau_tau_original: Option<Vec<Option<HyperDesignDerivative>>>,
3994        penaltysecond_components: Option<Vec<Option<Vec<(usize, HyperPenaltyDerivative)>>>>,
3995    ) -> Result<Self, EstimationError> {
3996        let is_penalty_like = !x_tau_original.any_nonzero();
3997        let penalty_first_components =
3998            Self::canonicalize_penalty_components(penalty_first_components)?;
3999        let penaltysecond_components = match penaltysecond_components {
4000            Some(rows) => {
4001                let mut out = Vec::with_capacity(rows.len());
4002                for row in rows {
4003                    out.push(match row {
4004                        Some(components) => {
4005                            Some(Self::canonicalize_penalty_components(components)?)
4006                        }
4007                        None => None,
4008                    });
4009                }
4010                Some(out)
4011            }
4012            None => None,
4013        };
4014        Ok(Self {
4015            x_tau_original,
4016            penalty_first_components,
4017            x_tau_tau_original,
4018            penaltysecond_components,
4019            penaltysecond_component_provider: None,
4020            penaltysecond_partner_indices: None,
4021            is_penalty_like,
4022        })
4023    }
4024
4025    /// Mark this coordinate as non-penalty-like (design-moving).
4026    /// EFS will skip it; use Newton/BFGS for these coordinates.
4027    pub fn not_penalty_like(mut self) -> Self {
4028        self.is_penalty_like = false;
4029        self
4030    }
4031
4032    pub fn with_penaltysecond_component_provider(
4033        mut self,
4034        provider: std::sync::Arc<
4035            dyn Fn(usize) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError>
4036                + Send
4037                + Sync
4038                + 'static,
4039        >,
4040    ) -> Self {
4041        self.penaltysecond_component_provider = Some(provider);
4042        self
4043    }
4044
4045    pub fn with_penaltysecond_partner_indices(mut self, partners: Vec<usize>) -> Self {
4046        self.penaltysecond_partner_indices = Some(std::sync::Arc::from(partners));
4047        self
4048    }
4049
4050    pub(crate) fn x_tau_dense(&self) -> Array2<f64> {
4051        self.x_tau_original.materialize()
4052    }
4053
4054    pub(crate) fn transformed_x_tau(
4055        &self,
4056        qs: &Array2<f64>,
4057        free_basis_opt: Option<&Array2<f64>>,
4058    ) -> Result<Array2<f64>, EstimationError> {
4059        self.x_tau_original.transformed(qs, free_basis_opt)
4060    }
4061
4062    pub(crate) fn x_tau_tau_entry_at(&self, j: usize) -> Option<HyperDesignDerivative> {
4063        self.x_tau_tau_original
4064            .as_ref()
4065            .and_then(|rows| rows.get(j))
4066            .and_then(|entry| entry.clone())
4067    }
4068
4069    /// Whether this coordinate's design derivative uses implicit storage at the
4070    /// first-derivative level.
4071    pub(crate) fn has_implicit_operator(&self) -> bool {
4072        self.x_tau_original.uses_implicit_storage()
4073    }
4074
4075    pub(crate) fn has_implicit_multidim_duchon(&self) -> bool {
4076        self.implicit_first_axis_info()
4077            .is_some_and(|(op, _)| op.n_axes() > 1 && op.is_duchon_family())
4078    }
4079
4080    /// Extract the implicit design derivative operator and axis, if available.
4081    pub(crate) fn implicit_first_axis_info(
4082        &self,
4083    ) -> Option<(
4084        std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
4085        usize,
4086    )> {
4087        self.x_tau_original.implicit_first_axis_info()
4088    }
4089
4090    pub(crate) fn implicit_axis_count_hint(&self) -> Option<usize> {
4091        self.x_tau_original.implicit_axis_count_hint()
4092    }
4093
4094    pub(crate) fn penalty_first_components(&self) -> &[PenaltyDerivativeComponent] {
4095        &self.penalty_first_components
4096    }
4097
4098    pub(crate) fn penalty_total_at(
4099        &self,
4100        rho: &Array1<f64>,
4101        p: usize,
4102    ) -> Result<Array2<f64>, EstimationError> {
4103        let mut out = Array2::<f64>::zeros((p, p));
4104        for component in &self.penalty_first_components {
4105            if component.matrix.nrows() != p || component.matrix.ncols() != p {
4106                crate::bail_invalid_estim!(
4107                    "S_tau shape mismatch for penalty {}: expected {}x{}, got {}x{}",
4108                    component.penalty_index,
4109                    p,
4110                    p,
4111                    component.matrix.nrows(),
4112                    component.matrix.ncols()
4113                );
4114            }
4115            if component.penalty_index >= rho.len() {
4116                crate::bail_invalid_estim!(
4117                    "penalty_index {} out of bounds for rho dimension {}",
4118                    component.penalty_index,
4119                    rho.len()
4120                );
4121            }
4122            component
4123                .matrix
4124                .scaled_add_to(&mut out, rho[component.penalty_index].exp())?;
4125        }
4126        Ok(out)
4127    }
4128
4129    pub(crate) fn penaltysecond_components_for(
4130        &self,
4131        j: usize,
4132    ) -> Result<Option<Vec<PenaltyDerivativeComponent>>, EstimationError> {
4133        if let Some(components) = self
4134            .penaltysecond_components
4135            .as_ref()
4136            .and_then(|rows| rows.get(j))
4137            .and_then(|row| row.clone())
4138        {
4139            return Ok(Some(components));
4140        }
4141        if let Some(provider) = self.penaltysecond_component_provider.as_ref() {
4142            return provider(j);
4143        }
4144        Ok(None)
4145    }
4146
4147    pub(crate) fn penaltysecond_componentrows(
4148        &self,
4149    ) -> Option<&[Option<Vec<PenaltyDerivativeComponent>>]> {
4150        self.penaltysecond_components.as_deref()
4151    }
4152
4153    pub(crate) fn penalty_first_component_count(&self) -> usize {
4154        self.penalty_first_components.len()
4155    }
4156
4157    pub(crate) fn has_penaltysecond_pair_at(&self, j: usize) -> bool {
4158        self.penaltysecond_components
4159            .as_ref()
4160            .and_then(|rows| rows.get(j))
4161            .is_some_and(Option::is_some)
4162            || self
4163                .penaltysecond_partner_indices
4164                .as_ref()
4165                .is_some_and(|partners| partners.contains(&j))
4166    }
4167}
4168
/// Outcome of choosing a REML geometry, together with a static reason string
/// and the size statistics recorded alongside it.
#[derive(Clone, Debug)]
pub(crate) struct SparseRemlDecision {
    // Geometry that was selected.
    pub(crate) geometry: RemlGeometry,
    // Static human-readable justification for the choice.
    pub(crate) reason: &'static str,
    // Coefficient dimension (by name; confirm at the construction site).
    pub(crate) p: usize,
    // Non-zero count of the design X (by name; confirm).
    pub(crate) nnz_x: usize,
    // Estimated non-zeros / density of the upper triangle of H, when an
    // estimate was computed (by name; confirm).
    pub(crate) nnz_h_upper_est: Option<usize>,
    pub(crate) density_h_upper_est: Option<f64>,
}
4178
/// Sparse-exact factorization data for one evaluation point. It is held
/// optionally by `EvalShared::sparse_exact`.
#[derive(Clone)]
pub(crate) struct SparseExactEvalData {
    // Sparse factorization shared via `Arc` (presumably of the penalized
    // Hessian; confirm at the construction site).
    pub(crate) factor: Arc<SparseExactFactor>,
    // Optional Takahashi (selected) inverse, when it has been computed.
    pub(crate) takahashi: Option<Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    // log|H| and the positive-part penalty log-determinant (by name; confirm).
    pub(crate) logdet_h: f64,
    pub(crate) logdet_s_pos: f64,
    // Rank of the penalty used for `logdet_s_pos` (by name; confirm).
    pub(crate) penalty_rank: usize,
    // First-derivative values of the log-det term(s) (by name; confirm
    // which term and coordinates they refer to).
    pub(crate) det1_values: Arc<Array1<f64>>,
}
4188
/// Dense exact Firth/Jeffreys operator for a possibly rank-deficient design.
/// Quantities are stored as reduced identifiable-subspace factors; the exact
/// definitions are given in the comments below.
#[derive(Clone)]
pub struct FirthDenseOperator {
    // Exact Firth/Jeffreys objects on the identifiable subspace.
    //
    // Let X in R^{n×p} potentially be rank-deficient with rank r.
    // With optional fixed observation weights a_i >= 0 we define A = diag(a),
    // choose an orthonormal coefficient-space basis Q for the identifiable
    // subspace of A^{1/2} X, and set:
    //   X_r := A^{1/2} X Q          (A = I when no fixed observation weights),
    //   W   := diag(w), with w_i = mu_i (1 - mu_i), 0 < w_i <= 1/4 for finite logit eta,
    //   I_r := X_rᵀ W X_r,
    //   S_r := X_rᵀ X_r.
    //
    // Firth term is represented as:
    //   Phi(beta) = 0.5 log |I_r(beta)| - 0.5 log |S_r|,
    // which is exactly
    //   0.5 log |Uᵀ W U|
    // for the canonical orthonormalized identifiable design
    //   U = X_r S_r^{-1/2}.
    // This removes the raw-basis term from explicit reduced designs while
    // keeping the same identifiable-subspace hat matrix and beta derivatives,
    // because S_r is fixed with respect to beta.
    //
    // Mapping back to the full p-space uses:
    //   I_+^dagger = Q I_r^{-1} Qᵀ.
    //
    // We store reduced-space factors so all derivatives can be evaluated exactly
    // without materializing dense n×n matrices M = X K Xᵀ or P = M⊙M.
    //
    // Dense design X and (by name) its cached transpose Xᵀ.
    pub(crate) x_dense: Array2<f64>,
    pub(crate) x_dense_t: Array2<f64>,
    // Orthonormal coefficient-space basis for the identifiable subspace,
    // built from the retained eigenspace of (A^{1/2} X)ᵀ(A^{1/2} X).
    pub(crate) q_basis: Array2<f64>,
    // Reduced identifiable design. With fixed observation weights a_i this is
    // diag(sqrt(a_i)) X Q; otherwise it is X Q.
    pub(crate) x_reduced: Array2<f64>,
    // Optional fixed case-weight square roots used when the Jeffreys/Firth
    // operator is formed from Xᵀ diag(case_weight ⊙ w(η)) X rather than
    // Xᵀ diag(w(η)) X. The exact directional tau derivatives must project and
    // row-scale with the same weights so the reduced Fisher, hat diagonals,
    // and tau kernels all live on one consistent identifiable subspace.
    pub(crate) observation_weight_sqrt: Option<Array1<f64>>,
    // I_r^{-1}
    pub(crate) k_reduced: Array2<f64>,
    // diag(S_r^{-1}) with S_r = X_rᵀ X_r. In the current canonical reduced
    // basis this completely characterizes the metric inverse, because Q
    // diagonalizes the design Gram. It is used to remove the reduced-coordinate
    // basis term from Phi_tau when the design moves.
    pub(crate) x_metric_reduced_inv_diag: Array1<f64>,
    // 0.5 (log|I_r| - log|S_r|) at the current eta.
    pub(crate) half_log_det: f64,
    // h = diag(M), M = X_r K_r X_r'
    pub(crate) h_diag: Array1<f64>,
    // Logistic Fisher weight w (w_i = mu_i (1 - mu_i)) and its eta-derivatives
    // w', w'', w''', w'''' stored as w1..w4; each is an n-vector.
    pub(crate) w: Array1<f64>,
    pub(crate) w1: Array1<f64>,
    pub(crate) w2: Array1<f64>,
    pub(crate) w3: Array1<f64>,
    pub(crate) w4: Array1<f64>,
    // B = diag(w') X used in D Hphi and D^2 Hphi contractions.
    pub(crate) b_base: Array2<f64>,
    // Cached invariant contraction P*B where P = (X_r K_r X_r') ⊙ (X_r K_r X_r').
    // This avoids recomputing the same O(n r^2 p) block in every directional call.
    pub(crate) p_b_base: Array2<f64>,
}
4254
/// Firth/Jeffreys quantities cached along one β-direction `u`, so that
/// directional contractions can reuse them.
#[derive(Clone)]
pub(crate) struct FirthDirection {
    // Linear-predictor change δη_u along the direction (by name; confirm).
    pub(crate) deta: Array1<f64>,
    // Reduced-space matrices for the direction; their exact definitions live
    // in the builder, which is not visible here.
    pub(crate) g_u_reduced: Array2<f64>,
    pub(crate) a_u_reduced: Array2<f64>,
    // Directional change of the hat diagonal h (by name; confirm).
    pub(crate) dh: Array1<f64>,
    // B_u = diag(w'' ⊙ δη_u) X is represented by the row-scaling vector only.
    pub(crate) b_uvec: Array1<f64>,
}
4264
/// β-fixed partial drifts of the Firth/Jeffreys pieces along one τ direction.
/// This is the optional `tau_kernel` carried by `FirthTauExactKernel`.
#[derive(Clone)]
pub(crate) struct FirthTauPartialKernel {
    // Partial η drift at fixed β (presumably X_τ β; confirm).
    pub(super) deta_partial: Array1<f64>,
    // τ-drifts of the weight derivatives w1 and w2 (by name; confirm).
    pub(crate) dotw1: Array1<f64>,
    pub(crate) dotw2: Array1<f64>,
    // Partial drift ḣ of the hat diagonal at fixed β.
    pub(crate) dot_h_partial: Array1<f64>,
    // Reduced design drift X_{tau,r} = X_tau Q used in exact design-moving
    // Hadamard-Gram contractions.
    pub(crate) x_tau_reduced: Array2<f64>,
    // Partial drift İ_r of the reduced Fisher information at fixed β.
    pub(super) dot_i_partial: Array2<f64>,
    // Reduced Fisher inverse drift:
    //   dot(K_r) = -K_r dot(I_r) K_r
    // where dot(I_r) includes explicit X_tau and weight drift at beta-fixed.
    pub(crate) dot_k_reduced: Array2<f64>,
}
4280
/// Exact β-fixed Firth contributions for one τ direction.
///
/// `FirthTauTauExactKernel` is the pair-level (τ_i × τ_j) analogue of this
/// bundle.
#[derive(Clone)]
pub(crate) struct FirthTauExactKernel {
    // (gphi)_τ|β as a p-vector.
    pub(crate) gphi_tau: Array1<f64>,
    // Scalar Phi_τ|β.
    pub(crate) phi_tau_partial: f64,
    // Optional reduced drifts for Hessian-level contractions.
    pub(crate) tau_kernel: Option<FirthTauPartialKernel>,
}
4287
/// Pair-level (τ_i × τ_j) exact Firth bundle at fixed β.
///
/// Mirrors `FirthTauExactKernel` but for the 2nd-order cross
/// derivatives:
///   Phi_{τ_i τ_j}|β  (scalar, `phi_tau_tau_partial`)
///   (gphi)_{τ_i τ_j}|β (p-vector, `gphi_tau_tau`)
///
/// Carries an optional `tau_tau_kernel` so pair callbacks can chain
/// into Primitive A (`hphi_tau_tau_partial_apply`) for the operator-
/// valued Hessian 2nd drift without recomputing shared reduced Grams.
#[derive(Clone)]
pub(crate) struct FirthTauTauExactKernel {
    // Scalar Phi_{τ_i τ_j}|β.
    pub(super) phi_tau_tau_partial: f64,
    // p-vector (gphi)_{τ_i τ_j}|β.
    pub(super) gphi_tau_tau: Array1<f64>,
    // Optional prepared state for Primitive A.
    pub(super) tau_tau_kernel: Option<FirthTauTauPartialKernel>,
}
4305
/// Prepared state for `∂²H_φ/∂τ_i ∂τ_j |_β` (Primitive A).
///
/// Carries both τ-direction reduced designs, their η̇ vectors, and the
/// reduced-coordinate drifts (İ, K̇, ḣ) for i and j so the apply step can
/// form M̈_{ij}, K̈_{ij}, ḧ_{ij}, Γ̈_{ij}, and B̈_{ij} matrix-free. Fields
/// are populated by the prepare routines named below. The type is
/// `Clone + Default` so downstream pair callbacks can hold the kernel
/// across the pair dispatch.
///
/// Wired into the pair-callback's `b_operator` via
/// `FirthAugmentedPairHyperOperator`, and produced by both
/// `hphi_tau_tau_partial_prepare_from_partials` and
/// `exact_tau_tau_kernel` (the scalar/p-vector companion).
#[derive(Clone, Default)]
pub(crate) struct FirthTauTauPartialKernel {
    // Per-direction (i, j) reduced designs, η̇, ḣ, K̇ and İ drifts.
    pub(super) x_tau_i_reduced: Array2<f64>,
    pub(super) x_tau_j_reduced: Array2<f64>,
    pub(super) deta_i_partial: Array1<f64>,
    pub(super) deta_j_partial: Array1<f64>,
    pub(super) dot_h_i_partial: Array1<f64>,
    pub(super) dot_h_j_partial: Array1<f64>,
    pub(super) dot_k_i_reduced: Array2<f64>,
    pub(super) dot_k_j_reduced: Array2<f64>,
    pub(super) dot_i_i_partial: Array2<f64>,
    pub(super) dot_i_j_partial: Array2<f64>,
    // Optional cross terms. These are presumably present only when a second
    // design derivative X_{τ_i τ_j} exists (TODO confirm).
    pub(super) x_tau_tau_reduced: Option<Array2<f64>>,
    pub(super) deta_ij_partial: Option<Array1<f64>>,
}
4333
/// Prepared state for `D_β((H_φ)_τ|_β)[v]` (Primitive B).
///
/// Carries the τ-kernel pieces (x_tau_reduced, İ, K̇, ḣ), the
/// β-direction quantities (δη_v, A_v, dh_v, b-chain), and the mixed
/// β-τ pieces (D_β(K̇_τ)[v], D_β(ḣ_τ)[v], δη_{τ,v}) so the apply
/// step collapses to the 9-term β-τ expansion without recomputing
/// shared reduced Grams.
///
/// By name, the fields map as follows:
/// - τ-kernel: `x_tau_reduced`, `deta_partial`, `dot_i_partial` (İ),
///   `dot_k_reduced` (K̇), `dot_h_partial` (ḣ).
/// - β-direction: `deta_v`, `a_v_reduced`, `dh_v`, `b_vvec`.
/// - mixed: `d_beta_dot_k`, `d_beta_dot_h`, `deta_tau_v`.
#[derive(Clone, Default)]
pub(crate) struct FirthTauBetaPartialKernel {
    pub(super) x_tau_reduced: Array2<f64>,
    pub(super) deta_partial: Array1<f64>,
    pub(super) dot_h_partial: Array1<f64>,
    pub(super) dot_i_partial: Array2<f64>,
    pub(super) dot_k_reduced: Array2<f64>,
    pub(super) deta_v: Array1<f64>,
    pub(super) deta_tau_v: Array1<f64>,
    pub(super) a_v_reduced: Array2<f64>,
    pub(super) dh_v: Array1<f64>,
    pub(super) b_vvec: Array1<f64>,
    pub(super) d_beta_dot_k: Array2<f64>,
    pub(super) d_beta_dot_h: Array1<f64>,
}
4356
/// Per-evaluation-point bundle shared by the outer REML cost, gradient, and
/// Hessian assembly at a single ρ.
///
/// It carries:
/// - the converged inner (PIRLS) solution;
/// - the ridge passport and geometry chosen for this point;
/// - the exact `H_total` used by the LAML cost;
/// - optional sparse-exact and Firth operator data;
/// - several lazily filled per-point caches.
///
/// The caches are `std::sync::OnceLock` cells. They can therefore be filled
/// through `&self`, and each holds at most one value per bundle.
///
/// `key` identifies the evaluation point and is compared by
/// `EvalShared::matches`. It is presumably the same sanitized-ρ bit key that
/// `PirlsLruCache` uses (TODO confirm).
#[derive(Clone)]
pub(crate) struct EvalShared {
    pub(crate) key: Option<Vec<u64>>,
    pub(crate) pirls_result: Arc<PirlsResult>,
    pub(crate) ridge_passport: RidgePassport,
    pub(crate) geometry: RemlGeometry,
    /// The exact H_total matrix used for LAML cost computation.
    /// For Firth: effective Hessian minus hphi (plus any barrier curvature).
    /// For non-Firth: the effective Hessian itself (plus any barrier curvature).
    pub(crate) h_total: Arc<Array2<f64>>,
    pub(crate) sparse_exact: Option<Arc<SparseExactEvalData>>,
    pub(crate) firth_dense_operator: Option<Arc<FirthDenseOperator>>,
    /// Cached FirthDenseOperator built from the original (non-reparameterized)
    /// design matrix, for use by the sparse evaluation path.
    pub(crate) firth_dense_operator_original: Option<Arc<FirthDenseOperator>>,
    /// The ONE original-frame penalty pseudo-logdet factorization for this
    /// evaluation point (#931 atom discipline). `log|Σ λ_k S_k|₊`'s VALUE,
    /// ρ-derivatives, τ/ψ components, and ρ×τ cross blocks are all
    /// contractions of this single eigendecomposition; the ρ-side criterion
    /// assembly (`dense_penalty_logdet_derivs`, the sparse det2 path) and the
    /// original-basis hyper-coordinate builders share it through
    /// [`EvalShared::penalty_pseudologdet_original`]. Building a second
    /// factorization of the same Sλ for the same evaluation point is the
    /// objective↔gradient desync surface (#748/#752/#901) this cell removes:
    /// the ridge and positive-eigenspace threshold are decided exactly once.
    /// (The transformed-frame pair-callback path builds its own object — it
    /// factorizes the canonical-TRANSFORMED, possibly constraint-projected
    /// penalties, a genuinely different matrix, not a duplicate of this one.)
    pub(crate) penalty_pseudologdet: std::sync::OnceLock<Arc<penalty_logdet::PenaltyPseudologdet>>,
    /// Per-evaluation-point cache of the canonical penalty score vectors
    /// `S_k β̂` evaluated at this bundle's inner mode `β̂ =
    /// pirls_result.beta_transformed` (unscaled by λ_k). These depend ONLY
    /// on the inner solution carried by this bundle and the `RemlState`'s
    /// fixed `canonical_penalties` — never on which ρ-coordinate or eval
    /// mode the assembly is running — so they are computed exactly once per
    /// inner solution and shared by every assemble call that reuses the
    /// bundle (cost + gradient evaluations at the same ρ, EFS, synthetic-ext
    /// value probes). Exact hoist, not an approximation: every consumer sees
    /// literally the same vectors it previously recomputed. Initialized via
    /// plain ndarray matvecs (no rayon inside the `OnceLock` closure — the
    /// `get_or_init`+`into_par_iter` deadlock trap does not apply).
    pub(crate) penalty_scores_at_mode: std::sync::OnceLock<Arc<Vec<Array1<f64>>>>,
    /// Per-evaluation-point cache of the #784 block-local Laplace-to-sampling
    /// correction `TkCorrectionTerms { value, gradient }`. The correction is a
    /// deterministic function of ONLY this bundle's converged inner state
    /// (`pirls_result`, `h_total`), the `RemlState`'s fixed
    /// `canonical_penalties`, and the bundle's ρ — never of the eval `mode`:
    /// the diagnostic eigendecomposition, the fixed-seed importance sampler,
    /// and the (b)–(d) gradient channels all read mode-invariant fields, and
    /// the term carries no Hessian, so the value+gradient are identical for the
    /// value-only, value+gradient, and value+gradient+Hessian assemble calls
    /// that share this bundle at a single ρ. The expensive path (eigendecomp +
    /// O(draws·n·m) sampler) previously reran on every one of those 2–3 calls
    /// per outer iteration; hoisting it onto the bundle computes it exactly
    /// once per inner solution (exact hoist, identical values — #784, #1082).
    /// Keyed only on the external-coordinate count `n_ext`: with no ψ
    /// coordinates (`n_ext == 0`) the correction engages; with ψ present the
    /// seam declines (returns the cheap zero), and n_ext is fixed for a fit, so
    /// a single cell suffices.
    pub(crate) block_local_correction:
        std::sync::OnceLock<(usize, Arc<outer_eval::TkCorrectionTerms>)>,
}
4429
4430impl EvalShared {
4431    pub(crate) fn matches(&self, key: &Option<Vec<u64>>) -> bool {
4432        match (&self.key, key) {
4433            (None, None) => true,
4434            (Some(a), Some(b)) => a == b,
4435            _ => false,
4436        }
4437    }
4438
4439    /// Lazily build — once per evaluation point — the original-frame
4440    /// [`PenaltyPseudologdet`](penalty_logdet::PenaltyPseudologdet) of
4441    /// `Σ λ_k S_k` and hand every caller the SAME factorization.
4442    ///
4443    /// This is the #931 port of the penalty-logdet term: value, ρ-first /
4444    /// ρ-second derivatives, τ-gradient components, τ×τ and ρ×τ Hessian
4445    /// blocks are all projections of one eigendecomposition, so no pair of
4446    /// consumers can disagree about the ridge or the positive-eigenspace
4447    /// threshold. The ridge is read from this bundle's `ridge_passport` —
4448    /// the single place that convention is decided.
4449    ///
4450    /// `lambdas` must be the λ = exp(ρ) vector of this bundle's evaluation
4451    /// point and `p` the original-basis coefficient dimension; on a cache
4452    /// hit both are checked against the stored object where representable.
4453    pub(crate) fn penalty_pseudologdet_original(
4454        &self,
4455        canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
4456        lambdas: &[f64],
4457        p: usize,
4458    ) -> Result<Arc<penalty_logdet::PenaltyPseudologdet>, EstimationError> {
4459        if let Some(pld) = self.penalty_pseudologdet.get() {
4460            if pld.dim() != p {
4461                return Err(EstimationError::LayoutError(format!(
4462                    "shared penalty pseudo-logdet frame mismatch: cached p={}, requested p={}",
4463                    pld.dim(),
4464                    p
4465                )));
4466            }
4467            return Ok(Arc::clone(pld));
4468        }
4469        let pld = Arc::new(
4470            penalty_logdet::PenaltyPseudologdet::from_penalties(
4471                canonical_penalties,
4472                lambdas,
4473                self.ridge_passport.penalty_logdet_ridge(),
4474                p,
4475            )
4476            .map_err(EstimationError::InvalidInput)?,
4477        );
4478        match self.penalty_pseudologdet.set(Arc::clone(&pld)) {
4479            Ok(()) => Ok(pld),
4480            // A concurrent caller initialized the cell first; both objects
4481            // were built from identical inputs — return the canonical winner
4482            // so every consumer holds literally the same factorization.
4483            Err(_) => Ok(Arc::clone(
4484                self.penalty_pseudologdet
4485                    .get()
4486                    .expect("OnceLock set raced, so it is initialized"),
4487            )),
4488        }
4489    }
4490}
4491
4492impl PenalizedGeometry for EvalShared {
4493    fn backend_kind(&self) -> GeometryBackendKind {
4494        match self.geometry {
4495            RemlGeometry::DenseSpectral => GeometryBackendKind::DenseSpectral,
4496            RemlGeometry::SparseExactSpd => GeometryBackendKind::SparseExactSpd,
4497        }
4498    }
4499}
4500
/// LRU cache keyed by sanitized ρ vectors that holds compacted PIRLS results
/// for warm-starting outer line searches and revisited evaluations.
///
/// Eviction is byte-budgeted rather than entry-count-budgeted: each entry
/// records its own estimated footprint (the surviving n-length vectors plus
/// the two p×p Hessians plus per-entry overhead) and the cache evicts in
/// LRU order until the running total fits under the budget. An entry that
/// individually exceeds the budget is rejected silently rather than poisoning
/// the cache.
pub(crate) struct PirlsLruCache {
    /// Stored tuple: (compacted result, last-touched clock, estimated bytes).
    pub(crate) map: HashMap<Vec<u64>, (Arc<PirlsResult>, u64, usize)>,
    /// Maximum total of estimated bytes; `new` clamps it to at least 1.
    pub(crate) byte_budget: usize,
    /// Running sum of the estimated bytes of the entries currently in `map`.
    pub(crate) current_bytes: usize,
    /// Logical clock, bumped on every `get` hit and every `insert`; the
    /// entry with the smallest stamp is evicted first.
    pub(crate) clock: u64,
}
4517
4518impl PirlsLruCache {
4519    pub(crate) fn new(byte_budget: usize) -> Self {
4520        Self {
4521            map: HashMap::new(),
4522            byte_budget: byte_budget.max(1),
4523            current_bytes: 0,
4524            clock: 0,
4525        }
4526    }
4527
4528    pub(crate) fn get(&mut self, key: &Vec<u64>) -> Option<Arc<PirlsResult>> {
4529        if let Some(entry) = self.map.get_mut(key) {
4530            self.clock += 1;
4531            entry.1 = self.clock;
4532            Some(entry.0.clone())
4533        } else {
4534            None
4535        }
4536    }
4537
4538    pub(crate) fn insert(&mut self, key: Vec<u64>, value: Arc<PirlsResult>) {
4539        self.clock += 1;
4540        let bytes = pirls_result_cache_bytes(&value);
4541        // Refuse entries that on their own already exceed the entire budget;
4542        // caching one would force eviction of every other entry without
4543        // leaving room for the new one anyway.
4544        if bytes > self.byte_budget {
4545            if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4546                self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4547            }
4548            return;
4549        }
4550        if let Some((_, _, prev_bytes)) = self.map.remove(&key) {
4551            self.current_bytes = self.current_bytes.saturating_sub(prev_bytes);
4552        }
4553        while self.current_bytes + bytes > self.byte_budget {
4554            let evict_key = self
4555                .map
4556                .iter()
4557                .min_by_key(|(_, (_, ts, _))| *ts)
4558                .map(|(k, _)| k.clone());
4559            match evict_key {
4560                Some(k) => {
4561                    if let Some((_, _, evict_bytes)) = self.map.remove(&k) {
4562                        self.current_bytes = self.current_bytes.saturating_sub(evict_bytes);
4563                    }
4564                }
4565                None => break,
4566            }
4567        }
4568        self.current_bytes += bytes;
4569        self.map.insert(key, (value, self.clock, bytes));
4570    }
4571
4572    pub(crate) fn clear(&mut self) {
4573        self.map.clear();
4574        self.current_bytes = 0;
4575    }
4576}
4577
4578#[derive(Clone, Copy, PartialEq, Eq)]
4579pub(crate) struct PenaltySubspaceCacheKey {
4580    pub(crate) penalty_matrix_fingerprint: u64,
4581    pub(crate) ridge_passport_signature: u64,
4582}
4583
4584pub(crate) struct PenaltySubspaceCache {
4585    pub(crate) entry: Option<(PenaltySubspaceCacheKey, Arc<outer_eval::PenaltySubspace>)>,
4586}
4587
4588impl PenaltySubspaceCache {
4589    pub(crate) fn new() -> Self {
4590        Self { entry: None }
4591    }
4592
4593    pub(crate) fn get(
4594        &self,
4595        key: &PenaltySubspaceCacheKey,
4596    ) -> Option<Arc<outer_eval::PenaltySubspace>> {
4597        self.entry
4598            .as_ref()
4599            .filter(|(cached_key, _)| cached_key == key)
4600            .map(|(_, value)| value.clone())
4601    }
4602
4603    pub(crate) fn insert(
4604        &mut self,
4605        key: PenaltySubspaceCacheKey,
4606        value: Arc<outer_eval::PenaltySubspace>,
4607    ) {
4608        self.entry = Some((key, value));
4609    }
4610
4611    pub(crate) fn clear(&mut self) {
4612        self.entry = None;
4613    }
4614}
4615
4616impl PenaltySubspaceCacheKey {
4617    /// Build a cache key from the transformed-E matrix and ridge passport.
4618    /// `E` is hashed by exact f64 bits (column-major), so the key is bit-exact
4619    /// and avoids float-Hash issues; the ridge passport is hashed via its
4620    /// `Hash` impl. Two calls at the same `(E, ridge)` yield equal keys.
4621    pub(crate) fn from_inputs(
4622        e_transformed: &ndarray::Array2<f64>,
4623        ridge_passport: &gam_problem::RidgePassport,
4624    ) -> Self {
4625        use std::collections::hash_map::DefaultHasher;
4626        use std::hash::{Hash, Hasher};
4627        let mut hasher = DefaultHasher::new();
4628        e_transformed.nrows().hash(&mut hasher);
4629        e_transformed.ncols().hash(&mut hasher);
4630        for value in e_transformed.iter() {
4631            value.to_bits().hash(&mut hasher);
4632        }
4633        let penalty_matrix_fingerprint = hasher.finish();
4634        let mut ridge_hasher = DefaultHasher::new();
4635        ridge_passport.delta.to_bits().hash(&mut ridge_hasher);
4636        (ridge_passport.matrix_form as u8).hash(&mut ridge_hasher);
4637        ridge_passport
4638            .policy
4639            .include_penalty_logdet
4640            .hash(&mut ridge_hasher);
4641        ridge_passport
4642            .policy
4643            .include_laplacehessian
4644            .hash(&mut ridge_hasher);
4645        let ridge_passport_signature = ridge_hasher.finish();
4646        Self {
4647            penalty_matrix_fingerprint,
4648            ridge_passport_signature,
4649        }
4650    }
4651}
4652
4653/// Estimate the in-cache footprint of a (compacted) PIRLS result.
4654///
4655/// Mirrors what `compact_for_reml_cache` keeps:
4656/// * six surviving n-length f64 arrays (final_eta, solveweights,
4657///   solveworking_response, solvemu, solve_c_array, solve_d_array);
4658/// * the p-length coefficient vector;
4659/// * the two p×p Hessians (dense or CSC sparse);
4660/// * the `ReparamResult` payload — the dominant scaling term beyond n, since
4661///   it carries `s_transformed`, `qs`, and `e_transformed` as p×p / rank×p
4662///   matrices.
4663/// A small constant overhead absorbs scalar fields, enum discriminants, and
4664/// the HashMap entry. This errs on the conservative side: overestimation
4665/// causes earlier eviction, never under-counting that would let the cache
4666/// silently exceed the byte budget.
4667pub(crate) fn pirls_result_cache_bytes(result: &PirlsResult) -> usize {
4668    use std::mem::size_of;
4669    let n_array_elems = result.final_eta.len()
4670        + result.solveweights.len()
4671        + result.solveworking_response.len()
4672        + result.solvemu.len()
4673        + result.solve_c_array.len()
4674        + result.solve_d_array.len();
4675    let p = result.beta_transformed.0.len();
4676    let pen_h = symmetric_matrix_cache_bytes(&result.penalized_hessian_transformed);
4677    let stab_h = symmetric_matrix_cache_bytes(&result.stabilizedhessian_transformed);
4678    let reparam = (result.reparam_result.s_transformed.len()
4679        + result.reparam_result.qs.len()
4680        + result.reparam_result.e_transformed.len()
4681        + result.reparam_result.det1.len())
4682        * size_of::<f64>();
4683    n_array_elems * size_of::<f64>() + p * size_of::<f64>() + pen_h + stab_h + reparam + 1024
4684}
4685
4686pub(crate) fn symmetric_matrix_cache_bytes(m: &gam_linalg::matrix::SymmetricMatrix) -> usize {
4687    use gam_linalg::matrix::SymmetricMatrix;
4688    use std::mem::size_of;
4689    match m {
4690        SymmetricMatrix::Dense(a) => a.len() * size_of::<f64>(),
4691        SymmetricMatrix::Sparse(s) => {
4692            // CSC sparse: f64 values + usize row indices + usize column pointers.
4693            let (symbolic, values) = s.parts();
4694            values.len() * (size_of::<f64>() + size_of::<usize>())
4695                + std::mem::size_of_val(symbolic.col_ptr())
4696        }
4697    }
4698}
4699
4700/// Centralized cache/memoization owner for REML evaluations.
4701///
4702/// This keeps cache-key identity, bundle reuse, and invalidation policy out of
4703/// the math kernels so objective/derivative routines can stay algebra-focused.
4704pub(crate) struct EvalCacheManager {
4705    pub(crate) pirls_cache: RwLock<PirlsLruCache>,
4706    pub(crate) penalty_subspace_cache: RwLock<PenaltySubspaceCache>,
4707    pub(crate) current_eval_bundle: RwLock<Option<EvalShared>>,
4708    pub(crate) current_outer_eval: RwLock<Option<(Vec<u64>, OuterEval)>>,
4709    pub(crate) pirls_cache_enabled: AtomicBool,
4710}
4711
4712impl EvalCacheManager {
4713    pub(crate) fn new() -> Self {
4714        Self {
4715            pirls_cache: RwLock::new(PirlsLruCache::new(PIRLS_CACHE_BYTE_BUDGET)),
4716            penalty_subspace_cache: RwLock::new(PenaltySubspaceCache::new()),
4717            current_eval_bundle: RwLock::new(None),
4718            current_outer_eval: RwLock::new(None),
4719            pirls_cache_enabled: AtomicBool::new(true),
4720        }
4721    }
4722
4723    /// Creates a sanitized cache key from rho values.
4724    /// Returns None if any component is NaN, in which case caching is skipped.
4725    /// Maps -0.0 to 0.0 to ensure key stability.
4726    pub(crate) fn sanitized_rhokey(rho: &Array1<f64>) -> Option<Vec<u64>> {
4727        self::rho_key::sanitized_rhokey(rho)
4728    }
4729
4730    /// Memoizing wrapper for `PenaltySubspace` construction.
4731    ///
4732    /// The penalty-subspace eigendecomposition is shape-invariant: any two
4733    /// outer evaluations at the same `(E_transformed, ridge_passport)` produce
4734    /// bit-identical subspaces. The single-slot cache amortizes consecutive
4735    /// fixed-S queries (rank, logdet, trace) within a single outer iter.
4736    pub(super) fn cached_penalty_subspace<F>(
4737        &self,
4738        e_transformed: &ndarray::Array2<f64>,
4739        ridge_passport: &gam_problem::RidgePassport,
4740        build: F,
4741    ) -> Result<Arc<outer_eval::PenaltySubspace>, EstimationError>
4742    where
4743        F: FnOnce() -> Result<outer_eval::PenaltySubspace, EstimationError>,
4744    {
4745        let key = PenaltySubspaceCacheKey::from_inputs(e_transformed, ridge_passport);
4746        if let Some(hit) = self.penalty_subspace_cache.read().unwrap().get(&key) {
4747            return Ok(hit);
4748        }
4749        let value = Arc::new(build()?);
4750        self.penalty_subspace_cache
4751            .write()
4752            .unwrap()
4753            .insert(key, value.clone());
4754        Ok(value)
4755    }
4756
4757    pub(crate) fn cached_eval_bundle(&self, key: &Option<Vec<u64>>) -> Option<EvalShared> {
4758        let guard = self.current_eval_bundle.read().unwrap();
4759        let bundle: &EvalShared = guard.as_ref()?;
4760        bundle.matches(key).then(|| bundle.clone())
4761    }
4762
4763    pub(crate) fn store_eval_bundle(&self, bundle: EvalShared) {
4764        *self.current_eval_bundle.write().unwrap() = Some(bundle);
4765    }
4766
4767    pub(crate) fn cached_outer_eval(&self, key: &Option<Vec<u64>>) -> Option<OuterEval> {
4768        let key = key.as_ref()?;
4769        let guard = self.current_outer_eval.read().unwrap();
4770        let (cached_key, eval): &(Vec<u64>, OuterEval) = guard.as_ref()?;
4771        (cached_key == key).then(|| eval.clone())
4772    }
4773
4774    pub(crate) fn store_outer_eval(&self, key: &Option<Vec<u64>>, eval: &OuterEval) {
4775        if let Some(key) = key.clone() {
4776            *self.current_outer_eval.write().unwrap() = Some((key, eval.clone()));
4777        }
4778    }
4779
4780    pub(crate) fn invalidate_eval_bundle(&self) {
4781        self.current_eval_bundle.write().unwrap().take();
4782        self.current_outer_eval.write().unwrap().take();
4783    }
4784
4785    pub(crate) fn clear_eval_and_factor_caches(&self) {
4786        self.invalidate_eval_bundle();
4787        self.penalty_subspace_cache.write().unwrap().clear();
4788    }
4789}
4790
4791/// Reusable scratch/runtime memory that should not be part of mathematical
4792/// state invariants.
4793pub(crate) struct RemlArena {
4794    pub(crate) cost_eval_count: RwLock<u64>,
4795    pub(crate) lastgradient_used_stochastic_fallback: AtomicBool,
4796}
4797
4798impl RemlArena {
4799    pub(crate) fn new() -> Self {
4800        Self {
4801            cost_eval_count: RwLock::new(0),
4802            lastgradient_used_stochastic_fallback: AtomicBool::new(false),
4803        }
4804    }
4805}
4806
4807pub(crate) struct AloFrozenNuisance {
4808    pub(crate) n_obs: usize,
4809    pub(crate) influence_scale: Vec<f64>,
4810    pub(crate) phi: f64,
4811}
4812
4813pub(crate) struct RemlState<'a> {
4814    pub(crate) y: ArrayView1<'a, f64>,
4815    pub(crate) x: DesignMatrix,
4816    pub(crate) weights: ArrayView1<'a, f64>,
4817    pub(crate) offset: Array1<f64>,
4818    /// Canonicalized block-local penalties with pre-computed roots.
4819    /// This is the single canonical penalty representation — no full-width
4820    /// `rank × p` roots are stored separately.
4821    pub(crate) canonical_penalties: Arc<Vec<gam_terms::construction::CanonicalPenalty>>,
4822    pub(crate) balanced_penalty_root: Array2<f64>,
4823    pub(crate) reparam_invariant: ReparamInvariant,
4824    pub(crate) sparse_penalty_block_count: Option<usize>,
4825    pub(crate) p: usize,
4826    pub(crate) config: Arc<RemlConfig>,
4827    pub(crate) runtime_mixture_link_state: Option<gam_problem::MixtureLinkState>,
4828    pub(crate) runtime_sas_link_state: Option<SasLinkState>,
4829    pub(crate) nullspace_dims: Vec<usize>,
4830    pub(crate) coefficient_lower_bounds: Option<Array1<f64>>,
4831    pub(crate) linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
4832    /// Relative shrinkage floor for penalized block eigenvalues (rho-independent).
4833    pub(crate) penalty_shrinkage_floor: Option<f64>,
4834    /// Explicit prior on log smoothing parameters used by the REML/LAML objective.
4835    pub(crate) rho_prior: gam_problem::RhoPrior,
4836
4837    pub(crate) cache_manager: EvalCacheManager,
4838    pub(crate) arena: RemlArena,
4839    pub(crate) warm_start_beta: RwLock<Option<Coefficients>>,
4840    /// Two-point ρ-trajectory used for second-order warm-start
4841    /// extrapolation: when the outer optimizer asks for a fit at a new
4842    /// ρ, we have `β(ρ_k)` (in `warm_start_beta`) and `β(ρ_{k-1})` (in
4843    /// `prev_warm_start_beta`). The implicit β(ρ) trajectory is locally
4844    /// linear under the FOC ∇F(β,ρ)=0, so a tangent-line prediction
4845    /// `β_predict(ρ_new) = β_k + α · (β_k − β_{k-1})` where α is the
4846    /// projection of `(ρ_new − ρ_k)` onto `(ρ_k − ρ_{k-1})` gives a
4847    /// better seed than the flat `β_k` alone — replacing PIRLS warm-
4848    /// start "use last β as-is" with a real tangent-prediction step.
4849    pub(crate) warm_start_rho: RwLock<Option<Array1<f64>>>,
4850    pub(crate) prev_warm_start_beta: RwLock<Option<Coefficients>>,
4851    pub(crate) prev_warm_start_rho: RwLock<Option<Array1<f64>>>,
4852    pub(crate) warm_start_enabled: AtomicBool,
4853    pub(crate) screening_max_inner_iterations: Arc<AtomicUsize>,
4854    /// Outer-aware inner-PIRLS iteration cap for the main descent loop.
4855    ///
4856    /// Distinct from `screening_max_inner_iterations`, which is used during
4857    /// seed selection and toggles a side-effect bundle (cache writes,
4858    /// warm-start updates, KKT enforcement all suppressed). This atomic is
4859    /// purely a cap — when nonzero, the inner Newton loop is capped at
4860    /// `min(this, full_max_iterations)`, but cache writes and warm-start
4861    /// updates remain enabled. Driven by the outer optimizer to coarsen
4862    /// inner solves at early outer iterations when ρ is far from converged,
4863    /// and lifted back to full at the final accepted iter (otherwise the
4864    /// returned β would be biased by the loose cap).
4865    ///
4866    /// Both atomics are honored together as `min(screening_cap, outer_cap)`
4867    /// when both are nonzero. Default 0 (no cap from this source).
4868    pub(crate) outer_inner_cap: Arc<AtomicUsize>,
4869
4870    /// Inner-PIRLS feedback signal driven by `execute_pirls_if_needed` after
4871    /// each NON-screening solve. Stores the iteration count at which the
4872    /// inner Newton stopped, plus a flag indicating whether it converged
4873    /// (vs. hit the iteration cap). The outer first-/second-order bridges
4874    /// read these atomics to drive an adaptive `inner_cap_schedule`: the
4875    /// next outer iter's inner cap becomes `last_iters + small_margin`
4876    /// when the previous solve converged, or a geometric backoff when it
4877    /// hit the cap. This replaces the older hardcoded iter-tier schedule
4878    /// (3/5/10/20) with a cap that follows the inner solver's actual
4879    /// convergence behavior — Eisenstat-Walker style for the inner
4880    /// quadratic loop. Default 0 / false (no signal yet — first outer
4881    /// iter falls back to a coarse iter-count tier).
4882    pub(crate) last_inner_iters: Arc<AtomicUsize>,
4883    pub(crate) last_inner_converged: Arc<AtomicBool>,
4884
4885    /// Cached state from the most recent successful PIRLS solve, used by
4886    /// the IFT-based warm-start predictor.
4887    ///
4888    /// The implicit-function theorem applied to the FOC ∇_β F(β,ρ)=0
4889    /// gives `dβ/dρ_k = -H_pen^{-1} · (e^{ρ_k} · S_k · β)`. A first-order
4890    /// Taylor predictor reads
4891    /// `β_predict(ρ_new) = β_cur − Σ_k Δρ_k · H_pen^{-1} · (e^{ρ_cur_k} · S_k · β_cur)`.
4892    /// This is a strict superset of the tangent-line predictor's
4893    /// requirements: works after a single successful solve (tangent-line
4894    /// needs two prior fits), and gives the EXACT first-order Jacobian
4895    /// of the implicit β(ρ) trajectory rather than a secant proxy along one
4896    /// ρ-direction.
4897    ///
4898    /// Populated in `updatewarm_start_from` when PIRLS converges; cleared
4899    /// on failure, on `reset_surface`, and on link-state changes.
4900    pub(crate) ift_warm_start_cache: RwLock<Option<IftWarmStartCache>>,
4901
4902    /// Persisted Levenberg-Marquardt damping coefficient from the most
4903    /// recent successful PIRLS solve, bit-packed into an `AtomicU64`
4904    /// (`f64::to_bits` low 64 bits). Read at the start of
4905    /// `execute_pirls_if_needed` and written into the
4906    /// `PirlsConfig::initial_lm_lambda` hint so the inner Newton seeds
4907    /// `λ_LM` near the damping the previous solve discovered, instead
4908    /// of cold-starting at `1e-6` and burning 4-6 halving steps to
4909    /// recover. `0` (the default) signals "no hint"; the inner solver
4910    /// clamps any positive hint into `[1e-6, 1e-3]` so a stale value
4911    /// cannot destabilize the next solve. Reset on `reset_surface` and
4912    /// on failed solves.
4913    pub(crate) last_pirls_lm_lambda: Arc<AtomicU64>,
4914
4915    /// Negative-Binomial overdispersion `theta` frozen for the smoothing-
4916    /// parameter (λ) search (#1082), bit-packed `f64` (`f64::to_bits`). `0`
4917    /// (the default) signals "not yet frozen". On the first non-screening
4918    /// λ-search inner solve of an estimated-θ NB fit, the seed's
4919    /// maximum-likelihood θ is computed once and stored here; every subsequent
4920    /// λ-search evaluation pins the inner solve to this value via
4921    /// `GlmLikelihoodSpec::with_negbin_theta_frozen_for_search`, so the REML
4922    /// criterion `F(ρ) = REML(ρ, θ_frozen)` is a stationary function of ρ and
4923    /// the outer optimizer converges instead of chasing the per-eval θ drift
4924    /// that the estimated path injects. The single final reported fit still
4925    /// ML-refreshes θ at the converged η. Reset on `reset_surface`.
4926    pub(crate) frozen_negbin_theta: Arc<AtomicU64>,
4927
4928    /// Last observed IFT-prediction residual (`‖β_converged − β_predicted‖
4929    /// / ‖β_converged‖`) from the most recent non-screening solve where
4930    /// the predictor was actually consumed. Bit-packed `f64` (low 64
4931    /// bits via `f64::to_bits`).
4932    ///
4933    /// "No signal yet" is encoded as a NaN bit-pattern
4934    /// (`IFT_RESIDUAL_NO_SIGNAL_BITS`). The original `0` sentinel
4935    /// collided with `f64::to_bits(0.0) == 0` — a true residual of
4936    /// exactly 0 (degenerate but mathematically possible if every
4937    /// β_predicted_i matched β_converged_i to bit-equality) would
4938    /// have been indistinguishable from "predictor never reported".
4939    /// NaN's self-inequality makes the sentinel unambiguous: any
4940    /// stored finite non-negative value is genuine signal.
4941    ///
4942    /// Read by `predict_warm_start_beta_ift_with_outcome` to drive the adaptive
4943    /// |Δρ| cap (`adaptive_ift_max_drho`): a small residual loosens
4944    /// the cap, a large one tightens it. Replaces the previous
4945    /// hardcoded `IFT_WARM_START_MAX_DRHO = 2.0` constant with a
4946    /// data-driven policy, so the predictor adapts to the empirical
4947    /// faithfulness of the linearization at this surface's scale.
4948    /// Reset on `reset_surface` and on failed solves.
4949    pub(crate) last_ift_prediction_residual: Arc<AtomicU64>,
4950
4951    /// Last observed gain ratio of the accepted LM step
4952    /// (`actual_reduction / predicted_reduction`) from the most recent
4953    /// non-screening PIRLS solve. Bit-packed `f64` with the same NaN
4954    /// sentinel discipline as `last_ift_prediction_residual`: NaN bits
4955    /// (`IFT_RESIDUAL_NO_SIGNAL_BITS`) encode "no signal yet" so a
4956    /// recorded ratio of exactly 0 (degenerate but possible) doesn't
4957    /// collide with the no-signal token.
4958    ///
4959    /// Used by `first_order_inner_cap_schedule` as a third quality
4960    /// signal alongside `last_iters` and `last_converged`. A small
4961    /// `accept_rho` (model overstating predicted reduction) is a hint
4962    /// the next iter's inner Newton may need extra margin even when
4963    /// the previous solve converged in few iters. Reset on
4964    /// `reset_surface` and on failed solves.
4965    pub(crate) last_pirls_accept_rho: Arc<AtomicU64>,
4966
4967    /// Cached Cholesky factorization of `IftWarmStartCache::penalized_hessian_transformed`.
4968    /// Lazily computed on the first IFT predict call after a fresh
4969    /// `updatewarm_start_from`, then reused by every subsequent
4970    /// predict call until the IFT cache is invalidated. At large-scale
4971    /// scale where p can reach several thousand, the dense Cholesky
4972    /// is O(p³)/3 — multiple seconds per refactor — so caching saves
4973    /// real wall time across the typical 5-10 IFT predict calls per
4974    /// outer fit. Reset jointly with `ift_warm_start_cache` (on
4975    /// reset_surface, on link-state changes, on failed PIRLS solves,
4976    /// and whenever a new H_pen replaces the cached one).
4977    pub(crate) ift_cached_factor: RwLock<Option<Arc<dyn gam_linalg::matrix::FactorizedSystem>>>,
4978
4979    /// When set, the penalties have Kronecker (tensor-product) structure and
4980    /// the REML evaluator can use O(∏q_j) logdet instead of O(p³) eigendecomposition.
4981    /// Populated via `set_kronecker_penalty_system` after construction.
4982    pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
4983    /// Full Kronecker factored basis (marginal designs + penalties + dims).
4984    /// Used by P-IRLS for factored reparameterization.
4985    pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,
4986
4987    /// Precomputed `(XᵀWX, XᵀW(y − offset))` for the Gaussian + Identity
4988    /// outer REML loop, populated once before the outer optimizer when the
4989    /// family / link / constraint preconditions hold and the design supports
4990    /// the Identity short-circuit at `pirls.rs:6237`. When present, each
4991    /// inner `solve_penalized_least_squares_implicit` reads these matrices
4992    /// instead of restreaming the O(N·p²) GEMM and O(N·p) matvec per outer
4993    /// iteration — the penalty `λ·S` is still added per-λ.
4994    ///
4995    /// Invalidated jointly with the design in `reset_surface`.
4996    pub(crate) gaussian_fixed_cache: RwLock<Option<Arc<crate::pirls::GaussianFixedCache>>>,
4997    /// Conditioned-frame exact ψ-derivatives `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
4998    /// for the SINGLE design-moving spatial hyperparameter (#1033b), assembled
4999    /// n-free from the certified Chebyshev ψ-Gram tensor and installed beside
5000    /// `gaussian_fixed_cache` at the same in-window trial. When present the
5001    /// Gaussian-identity ψ-gradient HyperCoord (`a_j`, `g_j`, dense `B_j`) is
5002    /// formed from these k×k objects instead of realizing and contracting the
5003    /// n×k ∂X/∂ψ slab — retiring the second per-trial n-pass. Lives in the
5004    /// SAME conditioned column frame as `gaussian_fixed_cache.xtwx_orig`, so
5005    /// the hyper-coord builder transforms it by the per-eval Qs/free-basis the
5006    /// same way it transforms the streamed Gram. Invalidated with the design.
5007    pub(crate) gaussian_psi_gram_deriv:
5008        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
5009    /// Conditioned-frame exact ψ-derivative pair `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
5010    /// for the SINGLE design-moving spatial hyperparameter in the GLM (frozen-W)
5011    /// lane (#1033 / #1111), assembled n-free from
5012    /// [`crate::glm_sufficient_lane::FrozenWeightGramTensor::gradient_pair_if_sound`]
5013    /// and installed beside `glm_first_step_gram` at the same in-window
5014    /// drift-OK trial. When present, the GLM ψ-gradient HyperCoord serves its
5015    /// envelope `a_j` and score `g_j` from these k×k objects instead of
5016    /// realizing and contracting the n×k ∂X/∂ψ slab — the second per-trial
5017    /// n-pass. Unlike the Gaussian lane the Hessian curvature `B_j` is NOT
5018    /// served from the tensor: for a GLM the per-trial `B_j` term
5019    /// `X_τᵀWX + XᵀWX_τ` is irreducibly n-dependent (the moving working weight
5020    /// `W` does not factor out of a frozen-W k×k object), so `B_j` keeps the
5021    /// exact streamed slab (#1033). Lives in the SAME conditioned column frame
5022    /// as `glm_first_step_gram` / `gaussian_fixed_cache.xtwx_orig`, so the
5023    /// hyper-coord builder transforms it by the per-eval Qs/free-basis the same
5024    /// way. NOT family-gated (the GLM lane's own slot). Invalidated with the
5025    /// design.
5026    pub(crate) glm_psi_gram_deriv:
5027        RwLock<Option<Arc<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>,
5028    /// Frozen-weight first-Fisher-step data-fit Gram `XᵀWX` for the GLM
5029    /// design-moving ψ-sweep (#1111 / #1033 mechanism (c)), in the conditioned
5030    /// (original / `x_fit`) column frame — the SAME frame as
5031    /// `gaussian_fixed_cache.xtwx_orig`.
5032    ///
5033    /// Assembled n-free per in-window ψ-trial from the certified frozen-weight
5034    /// Chebyshev tensor ([`crate::glm_sufficient_lane::FrozenWeightGramTensor`])
5035    /// and installed only when the trial's converged working weight has not
5036    /// drifted past tolerance from the frozen snapshot. When present, the GLM
5037    /// inner P-IRLS serves its FIRST Fisher-scoring iteration's `XᵀWX` from this
5038    /// cache instead of restreaming the O(N·p²) weighted cross-product — the
5039    /// dominant per-trial n-term in a large-n Poisson/Binomial κ-sweep. The
5040    /// penalty `Sλ` is still added per-λ on top, and every subsequent inner
5041    /// iteration restreams the true (moving) `W`, so the converged β̂ is
5042    /// unchanged; only the first-iteration Gram build is elided. Unlike
5043    /// `gaussian_fixed_cache` this is NOT family-gated — it is the GLM lane's
5044    /// own slot, consumed once per inner solve. Invalidated with the design in
5045    /// `reset_surface`.
5046    pub(crate) glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
5047    /// Previous successful non-Gaussian fixed-design data-fit Gram `XᵀWX` in
5048    /// the conditioned original frame, keyed to `warm_start_beta`.
5049    ///
5050    /// When the next outer trial uses a flat warm start, its first PIRLS
5051    /// curvature build evaluates at the same `η = Xβ` as the previous converged
5052    /// solve, so the Hessian weights and `XᵀWX` are identical. Reusing this
5053    /// original-frame Gram skips one dense `O(n·p²)` pass per warm-started
5054    /// trial while still letting later PIRLS iterations restream the moving
5055    /// weights. IFT/tangent-predicted starts do not consume this cache.
5056    pub(crate) flat_glm_first_step_gram: RwLock<Option<Arc<ndarray::Array2<f64>>>>,
5057    /// Frozen ALO robustness weights for this REML surface.
5058    ///
5059    /// The PSIS influence scale is a non-smooth function of the current hat
5060    /// diagonals. Once the high-leverage ALO objective activates, it is frozen
5061    /// for the current surface so the analytic gradient differentiates the
5062    /// same fixed-weight objective the cost evaluates.
5063    pub(crate) alo_frozen_nuisance: RwLock<Option<AloFrozenNuisance>>,
5064
5065    /// Stable disk-cache key for the current realized REML surface. Computed
5066    /// lazily because it hashes the row-chunked design and data vectors.
5067    pub(crate) persistent_warm_start_key: RwLock<Option<String>>,
5068    pub(crate) persistent_latent_values_fingerprint: Option<u64>,
5069    pub(crate) persistent_latent_values_cache: RwLock<PersistentLatentValuesCache>,
5070    pub(crate) analytic_penalty_registry_fingerprint: u64,
5071    /// Ensures the process attempts at most one disk restore per surface.
5072    pub(crate) persistent_warm_start_loaded: AtomicBool,
5073    /// Scoped counter disabling disk writes from cost-only posterior/probe
5074    /// evaluations. In-memory warm starts still update; only JSON/bin
5075    /// persistence and eviction sweeps are suppressed.
5076    pub(crate) persistent_warm_start_store_suppression: AtomicUsize,
5077    /// Scoped counter disabling the Gaussian-identity ALO-stabilization
5078    /// augmentation (#979). The leverage barrier `Σ_i (h_i − τ)₊²` is an OUTER
5079    /// OPTIMIZER aid (#813/#821) that keeps the smoothing-parameter search off
5080    /// pathological high-leverage λ regions. The marginal smoothing-parameter
5081    /// posterior `π(ρ|y) ∝ exp(−LAML(ρ))` (#938) is a property of the genuine
5082    /// model criterion, sampled against a Laplace proposal built from the BASE
5083    /// REML Hessian, so the certificate / NUTS evaluations suppress the
5084    /// augmentation (see `without_alo_stabilization`) — both for proposal↔target
5085    /// consistency and to drop the per-leapfrog ALO diagnostic suite.
5086    pub(crate) alo_stabilization_suppression: AtomicUsize,
5087    /// Whether the cross-process ON-DISK warm-start layer is engaged at all.
5088    ///
5089    /// Default `false`: the optimizer's IN-MEMORY warm start (the actual
5090    /// speed lever) is always on, but the disk checkpoint — `load_record`
5091    /// at fit start and `store_record` at finalize, each of which opens the
5092    /// shared `WarmStartStore` and pays an eviction/dir scan that is O(cache
5093    /// entries) on a network filesystem — is skipped. Disk persistence has
5094    /// reuse value only ACROSS processes or across repeated identical fits;
5095    /// a single in-process fit (and a fortiori a loop of distinct throwaway
5096    /// fits, e.g. CI-coverage replicates each on different data, #1082/#1114)
5097    /// gets zero benefit from it and pays the per-fit open/scan/save in full.
5098    /// `FitConfig::persist_warm_start_disk` flips this to `true` only when the
5099    /// caller explicitly asks for cross-process / repeat-fit persistence.
5100    pub(crate) persistent_warm_start_disk_enabled: AtomicBool,
5101}