gam_solve/estimate/joint_hyper.rs
1use super::*;
2
3pub(crate) fn validate_joint_hyper_direction_shapes(
4 x: &DesignMatrix,
5 canonical_len: usize,
6 theta: &Array1<f64>,
7 rho_dim: usize,
8 hyper_dirs: &[DirectionalHyperParam],
9) -> Result<(), EstimationError> {
10 if rho_dim > theta.len() {
11 crate::bail_invalid_estim!(
12 "rho_dim {} exceeds theta dimension {}",
13 rho_dim,
14 theta.len()
15 );
16 }
17
18 let p = x.ncols();
19 let psi_dim = theta.len() - rho_dim;
20 if hyper_dirs.len() != psi_dim {
21 crate::bail_invalid_estim!(
22 "joint hyper-gradient derivative count mismatch: psi_dim={}, hyper_dirs={}",
23 psi_dim,
24 hyper_dirs.len()
25 );
26 }
27
28 for (idx, hyper_dir) in hyper_dirs.iter().enumerate() {
29 for component in hyper_dir.penalty_first_components() {
30 if component.penalty_index >= canonical_len {
31 crate::bail_invalid_estim!(
32 "penalty_index for dir {idx} out of bounds: {} >= {}",
33 component.penalty_index,
34 canonical_len
35 );
36 }
37 }
38 if hyper_dir.x_tau_original.nrows() != x.nrows() || hyper_dir.x_tau_original.ncols() != p {
39 crate::bail_invalid_estim!(
40 "X_tau[{idx}] must be {}x{}, got {}x{}",
41 x.nrows(),
42 p,
43 hyper_dir.x_tau_original.nrows(),
44 hyper_dir.x_tau_original.ncols()
45 );
46 }
47 RemlState::validate_penalty_component_shapes(
48 hyper_dir.penalty_first_components(),
49 p,
50 &format!("S_tau[{idx}]"),
51 )?;
52 if let Some(x2) = hyper_dir.x_tau_tau_original.as_ref() {
53 if x2.len() != psi_dim {
54 crate::bail_invalid_estim!(
55 "X_tau_tau[{idx}] length mismatch: expected {}, got {}",
56 psi_dim,
57 x2.len()
58 );
59 }
60 for (j, x_ij) in x2.iter().enumerate() {
61 let Some(x_ij) = x_ij.as_ref() else {
62 continue;
63 };
64 if x_ij.nrows() != x.nrows() || x_ij.ncols() != p {
65 crate::bail_invalid_estim!(
66 "X_tau_tau[{idx}][{j}] must be {}x{}, got {}x{}",
67 x.nrows(),
68 p,
69 x_ij.nrows(),
70 x_ij.ncols()
71 );
72 }
73 }
74 }
75 if let Some(s2) = hyper_dir.penaltysecond_componentrows() {
76 if s2.len() != psi_dim {
77 crate::bail_invalid_estim!(
78 "S_tau_tau[{idx}] length mismatch: expected {}, got {}",
79 psi_dim,
80 s2.len()
81 );
82 }
83 for (j, components) in s2.iter().enumerate() {
84 let Some(components) = components.as_ref() else {
85 continue;
86 };
87 RemlState::validate_penalty_component_shapes(
88 components,
89 p,
90 &format!("S_tau_tau[{idx}][{j}]"),
91 )?;
92 }
93 }
94 }
95
96 Ok(())
97}
98
99pub struct ExternalJointHyperEvaluator<'a> {
100 pub(crate) conditioning: ParametricColumnConditioning,
101 pub(crate) penalty_shrinkage_floor: Option<f64>,
102 pub(crate) kronecker_penalty_system: Option<gam_terms::smooth::KroneckerPenaltySystem>,
103 pub(crate) kronecker_factored: Option<gam_terms::basis::KroneckerFactoredBasis>,
104 pub(crate) reml_state: RemlState<'a>,
105 /// Cached design revision counter from the upstream
106 /// `SingleBlockExactJointDesignCache` (or n-block analogue). When the
107 /// caller threads a revision through `evaluate_with_order` /
108 /// `evaluate_efs` / `evaluate_cost_only`, the evaluator can detect ψ-
109 /// invariant repeat calls (cost-only line-search probes, fall-through
110 /// memoization) and short-circuit `reset_surface`'s O(Σ pₖ³) canonical
111 /// rebuild plus the bundle/PIRLS cache wipes. `None` means "no
112 /// revision yet recorded" — every subsequent call is treated as a
113 /// fresh-canonical case and the slow path runs.
114 pub(crate) last_canonical_revision: Option<u64>,
115 /// The ψ at which the last full `reset_surface` (slow path) realized and
116 /// froze the reduced-basis reference surface (#1264). The design-revision
117 /// fast path keeps that surface frozen while re-keying the Gram/penalty to
118 /// the trial ψ; the skip is sound only when the realized reduced basis at
119 /// this pinning ψ is still valid at the trial ψ, certified n-free by
120 /// [`crate::psi_gram_tensor::PsiGramTensor::reduced_basis_equal`].
121 /// `None` until the first slow-path reset records a single-ψ trial.
122 pub(crate) last_reset_psi: Option<f64>,
123 /// Certified Chebyshev-in-ψ Gram tensor for the SINGLE design-moving
124 /// hyperparameter (#1033b, isotropic spatial κ): when present and the
125 /// trial ψ lies inside the certified window, `prepare_eval_state`
126 /// installs the n-free assembled `GaussianFixedCache` after
127 /// `reset_surface`, replacing the per-trial O(n·p²) Gram re-stream. Built
128 /// in the conditioned frame by `build_and_set_psi_gram_tensor` (the same
129 /// fixed column transform the streamed Gram uses), so the installed
130 /// statistics are frame-exact against the streamed ones.
131 pub(crate) psi_gram_tensor:
132 Option<std::sync::Arc<crate::psi_gram_tensor::PsiGramTensor>>,
133 /// Exact k-space correction from the slow-reset Gaussian cache at
134 /// `last_reset_psi`: `(psi_ref, G_exact(ref)-G_tensor(ref),
135 /// r_exact(ref)-r_tensor(ref))`. The initial slow path already paid the row
136 /// pass and built the exact Gaussian cache for its inner solve; the
137 /// design-revision fast path can reuse that anchor to remove the tensor's
138 /// residual without realizing rows again.
139 pub(crate) psi_gram_anchor_correction: Option<(f64, Array2<f64>, Array1<f64>)>,
140 /// EXACT n-free per-ψ canonical penalty surface `S(ψ)` staged for the
141 /// CURRENT ψ-trial (#1033, penalty lane). For a spatial smooth ψ (= log
142 /// length-scale) moves BOTH the design Gram AND the penalty `S(ψ)` (the
143 /// Duchon/Matérn Hilbert scale is built as a function of the length-scale
144 /// from the FROZEN basis centers — n-free). The design-revision fast path
145 /// skips `reset_surface` — the only place the canonical penalty surface is
146 /// rebuilt — so without re-keying `S(ψ)` the inner solve would pair
147 /// `XᵀWX(ψ_new)` with the STALE `S(ψ_old)` and converge to the wrong β̂ /
148 /// κ-optimum.
149 ///
150 /// The CALLER (`SpatialJointContext::eval_full` / `eval_cost`, which holds
151 /// the design `cache`) computes the exact rebuild via
152 /// `cache.canonical_penalties_at(theta)` and hands the owned
153 /// `(Vec<CanonicalPenalty>, active_nullspace_dims)` here through
154 /// `stage_fast_path_penalty` BEFORE the eval — avoiding a `&mut cache`
155 /// borrow alias with the evaluator. On the fast path `prepare_eval_state`
156 /// consumes the staged value (`refresh_psi_penalty_surface`) and re-installs
157 /// `S(ψ_new)` on the kept reference surface via
158 /// `refresh_canonical_penalty_surface`; the slow path takes the freshly
159 /// realized penalties as before and clears this slot. `None` (the default)
160 /// means no exact rebuild is staged — the fast path then refuses (hard
161 /// error) when `supports_nfree_penalty_rekey` is set, so a stale `S` can
162 /// never be silently paired.
163 pub(crate) pending_psi_penalty:
164 Option<std::sync::Arc<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>)>>,
165 /// True when the design `cache` can rebuild `S(ψ)` exactly and n-free for
166 /// the single spatial term (frozen-geometry Duchon/Matérn/ThinPlate). The
167 /// fast-path design-realization skip gates on this (replacing the old
168 /// certified `psi_penalty_tensor_covers` window check): when set, every
169 /// fast-path trial MUST have a staged exact penalty, and a missing stage is
170 /// a hard error rather than a stale-`S` solve.
171 pub(crate) supports_nfree_penalty_rekey: bool,
172 /// Frozen-weight GLM first-Fisher-step data-fit Gram `XᵀWX` staged for the
173 /// CURRENT ψ-trial (#1111 / #1033 mechanism (c)), in the conditioned
174 /// (`x_fit`) frame. Set per-trial by [`SpatialJointContext::eval_full`] when
175 /// the frozen-W tensor covers ψ and the working weight has not drifted, then
176 /// installed onto the inner REML surface inside `prepare_eval_state` (after
177 /// `reset_surface`, on both the slow and design-revision fast paths) and
178 /// cleared. `None` (the default) clears the surface slot so a stale
179 /// previous-ψ Gram is never consumed.
180 pub(crate) pending_glm_first_step_gram: Option<std::sync::Arc<Array2<f64>>>,
181 /// Conditioned-frame exact ψ-derivative pair `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)`
182 /// staged for the CURRENT ψ-trial in the GLM frozen-W lane (#1033 / #1111),
183 /// in the conditioned (`x_fit`) frame. Set per-trial by
184 /// [`SpatialJointContext::eval_full`] from
185 /// [`crate::glm_sufficient_lane::FrozenWeightGramTensor::gradient_pair_if_sound`]
186 /// when the frozen-W tensor covers ψ for the gradient and the working weight
187 /// has not drifted, then installed onto the inner REML surface inside
188 /// `prepare_eval_state` and cleared. Serves the GLM ψ-gradient `a_j` / `g_j`
189 /// n-free; `B_j` stays the exact slab. `None` (the default) clears the
190 /// surface slot so a stale previous-ψ derivative is never consumed.
191 pub(crate) pending_glm_psi_gram_deriv: Option<std::sync::Arc<(Array2<f64>, Array1<f64>)>>,
192 /// #1033 instrumentation: count of slow-path entries — i.e. trials for which
193 /// `prepare_eval_state` (or its cost-only twin) rebuilt the canonical
194 /// penalty and re-ran `reset_surface`, paying the per-trial O(n·p) design
195 /// reconditioning + O(Σ pₖ³) canonical rebuild. The design-revision fast
196 /// path does NOT increment this. A bit-tight test asserts that a cache-hit
197 /// trial (repeated `design_revision`) leaves this counter unchanged, proving
198 /// the n-row reconditioning lane was not re-entered. Pure observability; it
199 /// never gates control flow.
200 pub(crate) slow_path_reset_count: std::cell::Cell<u64>,
201}
202
203impl<'a> ExternalJointHyperEvaluator<'a> {
204 pub fn new(
205 y: ArrayView1<'a, f64>,
206 w: ArrayView1<'a, f64>,
207 x: &DesignMatrix,
208 offset: ArrayView1<'_, f64>,
209 s_list: &[BlockwisePenalty],
210 opts: &ExternalOptimOptions,
211 context: &str,
212 ) -> Result<Self, EstimationError> {
213 if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
214 crate::bail_invalid_estim!("{}", message);
215 }
216
217 let p = x.ncols();
218 let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
219 validate_penalty_specs(&specs, p, context)?;
220 let (canonical, active_nullspace_dims) = gam_terms::construction::canonicalize_penalty_specs(
221 &specs,
222 &opts.nullspace_dims,
223 p,
224 context,
225 )?;
226 let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(x, &specs);
227 let x_fit = conditioning.apply_to_design(x);
228 let fit_linear_constraints =
229 conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
230 let (config, _) = resolved_external_config(opts)?;
231 let config = Arc::new(config);
232
233 let mut reml_state = RemlState::newwith_offset_shared(
234 y,
235 x_fit,
236 w,
237 offset,
238 Arc::new(canonical),
239 p,
240 Arc::clone(&config),
241 Some(active_nullspace_dims.clone()),
242 None,
243 fit_linear_constraints.clone(),
244 )?;
245 reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
246 reml_state.set_rho_prior(opts.rho_prior.clone());
247 reml_state.set_link_states(
248 config.link_kind.mixture_state().cloned(),
249 config.link_kind.sas_state().copied(),
250 );
251 if let Some(kron) = opts.kronecker_penalty_system.clone() {
252 reml_state.set_kronecker_penalty_system(kron);
253 }
254 if let Some(kf) = opts.kronecker_factored.clone() {
255 reml_state.set_kronecker_factored(kf);
256 }
257 if opts.persist_warm_start_disk {
258 // Caller opted into cross-process resume (#1082): engage the
259 // on-disk warm-start layer. Default-false keeps replicate/CI loops
260 // disk-silent.
261 reml_state.enable_persistent_warm_start_disk();
262 }
263
264 Ok(Self {
265 conditioning,
266 penalty_shrinkage_floor: opts.penalty_shrinkage_floor,
267 kronecker_penalty_system: opts.kronecker_penalty_system.clone(),
268 kronecker_factored: opts.kronecker_factored.clone(),
269 reml_state,
270 last_canonical_revision: None,
271 last_reset_psi: None,
272 psi_gram_tensor: None,
273 psi_gram_anchor_correction: None,
274 pending_psi_penalty: None,
275 supports_nfree_penalty_rekey: false,
276 pending_glm_first_step_gram: None,
277 pending_glm_psi_gram_deriv: None,
278 slow_path_reset_count: std::cell::Cell::new(0),
279 })
280 }
281
282 /// #1033 instrumentation accessor: number of slow-path `reset_surface`
283 /// rebuilds the evaluator has performed since construction. A trial that
284 /// takes the design-revision fast path (cache hit) does not advance this, so
285 /// a test can assert the n-row reconditioning lane was not re-entered by
286 /// checking this counter is unchanged across a repeat-revision eval.
287 pub fn slow_path_reset_count(&self) -> u64 {
288 self.slow_path_reset_count.get()
289 }
290
291 /// Record the pinning ψ frozen by a slow-path `reset_surface` (#1264): the
292 /// single design-moving ψ when `theta` carries one (`theta.len() == rho_dim +
293 /// 1`), else `None` (multi-ψ / no-ψ fits have no single-ψ witness). The
294 /// #1033 n-free production lane gates on the certified value window.
295 fn record_reset_psi(&mut self, theta: &Array1<f64>, rho_dim: usize) {
296 self.last_reset_psi = if theta.len() == rho_dim + 1 {
297 Some(theta[rho_dim])
298 } else {
299 None
300 };
301 }
302
303 /// Stage (or clear) the frozen-weight GLM first-Fisher-step Gram for the
304 /// next trial eval (#1111 / #1033 mechanism (c)). The staged Gram is
305 /// installed onto the inner REML surface inside `prepare_eval_state` and
306 /// then cleared; passing `None` clears any previously staged Gram so a stale
307 /// previous-ψ Gram is never consumed.
308 pub fn stage_glm_first_step_gram(&mut self, gram: Option<Array2<f64>>) {
309 self.pending_glm_first_step_gram = gram.map(std::sync::Arc::new);
310 }
311
312 /// Stage (or clear) the GLM frozen-W conditioned-frame exact ψ-derivative
313 /// pair `(∂XᵀWX/∂ψ, ∂XᵀW(y−offset)/∂ψ)` for the next trial eval
314 /// (#1033 / #1111). Produced by `gradient_pair_if_sound` when the frozen-W
315 /// tensor covers ψ for the gradient and the working weight has not drifted;
316 /// installed onto the inner REML surface inside `prepare_eval_state` (after
317 /// `reset_surface`, on both the slow and design-revision fast paths) and
318 /// then cleared. Serves the GLM ψ-gradient `a_j` / `g_j` n-free; the
319 /// Hessian curvature `B_j` always stays the exact n-dependent slab. Passing
320 /// `None` clears any previously staged pair so a stale previous-ψ
321 /// derivative is never consumed.
322 pub fn stage_glm_psi_gram_deriv(&mut self, deriv: Option<(Array2<f64>, Array1<f64>)>) {
323 self.pending_glm_psi_gram_deriv = deriv.map(std::sync::Arc::new);
324 }
325
326 pub fn set_analytic_penalty_registry(
327 &mut self,
328 registry: Option<&gam_terms::AnalyticPenaltyRegistry>,
329 ) {
330 let fingerprint = registry
331 .map(crate::estimate::reml::outer_eval::analytic_penalty_registry_fingerprint)
332 .unwrap_or(0);
333 crate::estimate::reml::RemlState::set_analytic_penalty_registry_fingerprint(
334 &mut self.reml_state,
335 fingerprint,
336 );
337 }
338
339 pub fn set_persistent_latent_values_fingerprint(
340 &mut self,
341 id_mode: &gam_terms::latent::LatentIdMode,
342 ) {
343 let fingerprint =
344 crate::estimate::reml::outer_eval::latent_id_mode_cache_fingerprint(id_mode);
345 crate::estimate::reml::RemlState::set_persistent_latent_values_fingerprint(
346 &mut self.reml_state,
347 fingerprint,
348 );
349 }
350
351 pub fn load_persistent_latent_values(
352 &self,
353 n_obs: usize,
354 latent_dim: usize,
355 ) -> Option<Array2<f64>> {
356 crate::estimate::reml::RemlState::load_persistent_latent_values(
357 &self.reml_state,
358 n_obs,
359 latent_dim,
360 )
361 }
362
363 pub fn store_persistent_latent_values(&self, values: &Array2<f64>) {
364 crate::estimate::reml::RemlState::store_persistent_latent_values(
365 &self.reml_state,
366 values,
367 );
368 }
369
370 /// Build and attach a certified ψ-Gram tensor (#1033b) for the single
371 /// design-moving hyperparameter ψ over `[psi_lo, psi_hi]`.
372 ///
373 /// `eval_raw_design(psi)` returns the EXACT realized design at `psi` in the
374 /// raw (user) column frame — the same realizer the per-trial path uses.
375 /// This method threads it through THIS evaluator's parametric column
376 /// conditioning before the tensor sees it, so the tensor's assembled
377 /// `XᵀWX(ψ)` lives in the SAME conditioned frame as the streamed
378 /// `gaussian_fixed_cache_if_eligible` (which forms its Gram from
379 /// `x_fit = conditioning.apply_to_design(x)`). The conditioning is a fixed,
380 /// ψ-invariant column transform (means/scales frozen from the baseline
381 /// design at construction), so applying it inside the build keeps the
382 /// expansion analytic and the per-trial installed cache frame-exact —
383 /// without restricting to identity conditioning. Returns whether a
384 /// certified tensor was attached; `false` keeps the exact per-trial path.
385 pub fn build_and_set_psi_gram_tensor(
386 &mut self,
387 mut eval_raw_design: impl FnMut(f64) -> Result<DesignMatrix, String>,
388 weights: ArrayView1<'_, f64>,
389 z: ArrayView1<'_, f64>,
390 psi_lo: f64,
391 psi_hi: f64,
392 ) -> bool {
393 // Clone the (cheap) conditioning so the build closure borrows it
394 // without aliasing `self` while we set the field afterward.
395 let conditioning = self.conditioning.clone();
396 let tensor = crate::psi_gram_tensor::PsiGramTensor::build(
397 |psi| {
398 let raw = eval_raw_design(psi)?;
399 Ok(conditioning.apply_to_design(&raw).to_dense())
400 },
401 weights,
402 z,
403 psi_lo,
404 psi_hi,
405 );
406 match tensor {
407 Ok(tensor) => {
408 self.psi_gram_tensor = Some(std::sync::Arc::new(tensor));
409 self.psi_gram_anchor_correction = None;
410 true
411 }
412 Err(why) => {
413 // The n-free ψ-Gram tensor declined to attach; the caller falls
414 // back to the exact per-trial design path. Record WHY so the
415 // fast-path coverage (#1264/#1216) is diagnosable instead of a
416 // silent non-attachment.
417 log::debug!("ψ-Gram tensor not attached over [{psi_lo}, {psi_hi}]: {why}");
418 false
419 }
420 }
421 }
422
423 /// Declare whether the design `cache` can rebuild `S(ψ)` exactly and n-free
424 /// for the single spatial term (#1033, penalty lane). Set ONCE at setup from
425 /// `cache.supports_nfree_penalty_rekey()`. When `true`, the design-revision
426 /// fast path's design-realization skip is permitted (the penalty can be
427 /// re-keyed without `reset_surface`) and every fast-path trial MUST have a
428 /// staged exact penalty (`stage_fast_path_penalty`), else `prepare_eval_state`
429 /// hard-errors rather than pairing a stale `S`.
430 pub fn set_supports_nfree_penalty_rekey(&mut self, supported: bool) {
431 self.supports_nfree_penalty_rekey = supported;
432 }
433
434 /// True when the n-free penalty re-key lane is enabled for this fit.
435 pub fn supports_nfree_penalty_rekey(&self) -> bool {
436 self.supports_nfree_penalty_rekey
437 }
438
439 /// Stage (or clear) the EXACT n-free canonical penalty surface `S(ψ)` for the
440 /// NEXT trial eval (#1033, penalty lane). The CALLER (which holds the design
441 /// `cache`) computes `cache.canonical_penalties_at(theta)` and hands the
442 /// owned `(Vec<CanonicalPenalty>, active_nullspace_dims)` here BEFORE the
443 /// eval — sidestepping a `&mut cache` borrow alias with the evaluator. On the
444 /// design-revision fast path `prepare_eval_state` /
445 /// `prepare_eval_state_cost_only` consume the staged value via
446 /// `refresh_psi_penalty_surface` and re-install `S(ψ_new)` on the kept
447 /// reference surface; the slow path clears it (the freshly realized penalty
448 /// is used instead). Passing `None` clears any previously staged surface so a
449 /// stale previous-ψ `S` is never consumed.
450 pub fn stage_fast_path_penalty(
451 &mut self,
452 penalty: Option<(Vec<gam_terms::construction::CanonicalPenalty>, Vec<usize>)>,
453 ) {
454 self.pending_psi_penalty = penalty.map(std::sync::Arc::new);
455 }
456
457 /// Build a certified frozen-weight GLM ψ-Gram tensor (#1111 / #1033
458 /// mechanism (c)) for the single design-moving hyperparameter ψ.
459 ///
460 /// Mirrors [`Self::build_and_set_psi_gram_tensor`] but for the GLM
461 /// design-moving lane: the working weight `w` and working response `z` are
462 /// FROZEN at the warm working point, and the tensor wraps the weighted
463 /// design `A(ψ) = diag(√w)·X_fit(ψ)`. Crucially `eval_raw_design` is threaded
464 /// through THIS evaluator's parametric column conditioning before the tensor
465 /// sees it, so the assembled frozen-`W` Gram `XᵀWX(ψ)` lives in the SAME
466 /// conditioned `x_fit` frame the inner PIRLS solve forms its Gram in — the
467 /// same frame-correctness contract the Gaussian lane relies on. Without this
468 /// the tensor would be assembled in the raw user-column frame and silently
469 /// mismatch any inner consumer.
470 ///
471 /// Returns the certified tensor (caller owns it, e.g. to re-use the
472 /// per-trial weight-drift guard), or `None` when no Chebyshev rung certifies
473 /// — the caller then keeps the exact per-trial PIRLS rebuild.
474 pub fn build_frozen_glm_gram_tensor(
475 &self,
476 mut eval_raw_design: impl FnMut(f64) -> Result<DesignMatrix, String>,
477 frozen_w: ArrayView1<'_, f64>,
478 working_z: ArrayView1<'_, f64>,
479 psi_lo: f64,
480 psi_hi: f64,
481 ) -> Option<crate::glm_sufficient_lane::FrozenWeightGramTensor> {
482 let conditioning = self.conditioning.clone();
483 crate::glm_sufficient_lane::FrozenWeightGramTensor::build(
484 |psi| {
485 let raw = eval_raw_design(psi)?;
486 Ok(conditioning.apply_to_design(&raw).to_dense())
487 },
488 frozen_w,
489 working_z,
490 psi_lo,
491 psi_hi,
492 )
493 }
494
495 /// True when a certified ψ-Gram tensor is installed AND `psi` lies inside
496 /// its certified GRADIENT window — i.e. the n-free k-space ψ-derivatives
497 /// `(∂G/∂ψ, ∂b/∂ψ)` will serve the Gaussian gradient HyperCoord, so the
498 /// caller's per-trial n×k ∂X/∂ψ slab is redundant (#1033). For the
499 /// sufficient-statistic kappa search this covers the full optimizer window;
500 /// otherwise the caller does not arm the n-free outer loop.
501 pub fn psi_gram_tensor_covers_gradient(&self, psi: f64) -> bool {
502 self.psi_gram_tensor
503 .as_ref()
504 .is_some_and(|t| t.contains_for_gradient(psi))
505 }
506
507 /// True when a certified ψ-Gram tensor is installed AND `psi` lies inside its
508 /// full certified VALUE window — i.e. the n-free assembled Gaussian
509 /// sufficient statistics `XᵀWX(ψ)/XᵀWz(ψ)` reproduce the streamed Gram to the
510 /// certification tolerance. The caller uses this to skip the per-trial O(n·p)
511 /// design realization + conditioning entirely (#1033): when the value lane is
512 /// covered, `prepare_eval_state` installs the n-free `GaussianFixedCache`, so
513 /// the stale realized design is never read for its rows on the inner Gaussian
514 /// PLS fast path. Strictly narrower-or-equal callers also gate on
515 /// `psi_gram_tensor_covers_gradient` for the gradient channel.
516 pub fn psi_gram_tensor_covers(&self, psi: f64) -> bool {
517 self.psi_gram_tensor
518 .as_ref()
519 .is_some_and(|t| t.contains(psi))
520 }
521
522 /// True when the design-realization SKIP to `psi` is β̂-SOUND given the
523 /// reference surface pinned at the last slow-path reset (#1264). Restored
524 /// after the "stale-penalty-not-stale-basis" theory was empirically refuted:
525 /// cluster measured β̂rel≈1.7e-5 (17× the issue's 1e-6 bar) when the n-free κ skip
526 /// fires on production Duchon geometry — EVEN at a ψ the n-free VALUE window
527 /// admits — because the inner penalized solve `(QsᵀGQs+S)β=b` is run in the
528 /// CONDITIONED reduced basis, and that basis ROTATES with ψ on the near-
529 /// singular radial Gram (κ(G)≈9.5e14). The skip installs the Chebyshev-
530 /// interpolated `gram_at(ψ)` (≤1e-10 vs the streamed exact Gram), and when the
531 /// reduced basis at the trial ψ differs from the reference surface's basis the
532 /// κ-amplified round-off moves the shipped κ-optimum past 1e-6.
533 ///
534 /// So the skip is sound ONLY where the reduced basis is provably unchanged:
535 /// the gauge-invariant range-projector witness `reduced_basis_equal(psi_ref,
536 /// psi)` against the pinning ψ recorded at the last slow-path reset. Without a
537 /// recorded pinning ψ (no reset yet) the skip is refused. Value coverage alone
538 /// is NOT sufficient — this is the load-bearing #1264 soundness gate.
539 pub fn psi_gram_tensor_covers_skip(&self, psi: f64) -> bool {
540 let Some(tensor) = self.psi_gram_tensor.as_ref() else {
541 return false;
542 };
543 if !tensor.contains(psi) {
544 return false;
545 }
546 // The pinning ψ must itself be in-window for its reference projector to be
547 // a valid comparison point; otherwise refuse (forces the exact slow path).
548 let Some(psi_ref) = self.last_reset_psi.filter(|p| tensor.contains(*p)) else {
549 return false;
550 };
551 tensor.reduced_basis_equal(psi_ref, psi)
552 }
553
554 /// Revision of the canonical surface pinned by the last slow-path
555 /// `reset_surface`, if any. The spatial κ caller passes this revision back
556 /// on certified n-free value/gradient probes so [`Self::prepare_eval_state`]
557 /// and [`Self::prepare_eval_state_cost_only`] take their design-revision fast
558 /// paths even if the caller-side realizer revision has since advanced on an
559 /// unrelated miss. The fast path re-keys the Gaussian Gram and `S(ψ)` from
560 /// k-space statistics, so it intentionally reuses this pinned surface rather
561 /// than requiring equality with the current realizer revision (#1033).
562 pub fn nfree_fast_path_revision(&self) -> Option<u64> {
563 self.last_canonical_revision
564 }
565
566 pub fn has_psi_gram_tensor(&self) -> bool {
567 self.psi_gram_tensor.is_some()
568 }
569
570 /// Return the most-recently converged inner β from the last PIRLS solve, if
571 /// it is finite and the right dimension. Used by `SpatialJointContext` to
572 /// warm-start successive outer evaluations instead of cold-starting PIRLS
573 /// from zero every iteration — especially important for GLM families (Poisson,
574 /// NB, Binomial) that cannot use the Gaussian Gram tensor n-free shortcut.
575 pub fn current_beta(&self) -> Option<Array1<f64>> {
576 self.reml_state.current_original_basis_beta()
577 }
578
579 /// Install the n-free per-ψ Gaussian sufficient statistics from the certified
580 /// ψ-Gram tensor (#1033b), when one is present and `theta`'s single ψ lies
581 /// inside the certified window. Idempotent in ψ — must be called on EVERY
582 /// trial (fast-path or slow-path) because the installed `GaussianFixedCache`
583 /// (and the conditioned-frame ψ-derivatives) are keyed to the current ψ, not
584 /// just to the design revision: on the design-revision fast path the design
585 /// did not change but ψ still moved, so the previous ψ's Gram would be stale.
586 ///
587 /// Off-window, multi-ψ, ineligible family, or shape mismatch all return
588 /// without installing — the streamed exact path runs unchanged.
589 /// Returns `true` when the n-free Gaussian ψ-GRADIENT derivative pair was
590 /// installed for this trial — i.e. the certified tensor serves both the
591 /// value AND the gradient n-free, so the conditioned n×k `∂X/∂ψ` slab in the
592 /// hyper_dirs is provably DEAD (the gradient HyperCoord's `j==0` branch reads
593 /// the k-space derivatives and never the slab). The caller uses this to skip
594 /// the per-trial slab conditioning on the design-revision fast path (#1033).
595 fn install_psi_gram_statistics(&mut self, theta: &Array1<f64>, rho_dim: usize) -> bool {
596 let Some(tensor) = self.psi_gram_tensor.as_ref() else {
597 // No tensor installed for this fit → the surface never carries a
598 // ψ-keyed Gaussian Gram, so there is nothing stale to clear.
599 return false;
600 };
601 // #1033: every early return below is a trial for which we CANNOT serve
602 // the n-free per-ψ Gram (off-window, wrong shape, multi-ψ). On the
603 // design-revision fast path `reset_surface` is skipped, so a Gram keyed
604 // to the PREVIOUS in-window ψ would survive and be read stale by the
605 // inner Gaussian PLS. Clear it on every miss so the inner solver
606 // restreams the exact Gram for this trial's design.
607 if theta.len() != rho_dim + 1 {
608 self.reml_state.clear_gaussian_fixed_cache();
609 return false;
610 }
611 let psi = theta[rho_dim];
612 if !tensor.contains(psi) {
613 self.reml_state.clear_gaussian_fixed_cache();
614 return false;
615 }
616 // Clone the Arc handle so the immutable borrow of `self.psi_gram_tensor`
617 // is released before the `&mut self.reml_state` installs below.
618 let tensor = std::sync::Arc::clone(tensor);
619 let mut cache = tensor.gaussian_fixed_cache_at(psi);
620 if let Some(psi_ref) = self.last_reset_psi.filter(|p| tensor.contains(*p)) {
621 let correction_is_current = self
622 .psi_gram_anchor_correction
623 .as_ref()
624 .is_some_and(|(p, _, _)| *p == psi_ref);
625 if !correction_is_current
626 && let Some(anchor) = self.reml_state.installed_gaussian_fixed_cache()
627 && !anchor.row_prediction_is_stale
628 && anchor.xtwx_orig.dim() == cache.xtwx_orig.dim()
629 && anchor.xtwy_orig.len() == cache.xtwy_orig.len()
630 {
631 let tensor_at_ref = tensor.gaussian_fixed_cache_at(psi_ref);
632 if tensor_at_ref.xtwx_orig.dim() == anchor.xtwx_orig.dim()
633 && tensor_at_ref.xtwy_orig.len() == anchor.xtwy_orig.len()
634 {
635 self.psi_gram_anchor_correction = Some((
636 psi_ref,
637 &anchor.xtwx_orig - &tensor_at_ref.xtwx_orig,
638 &anchor.xtwy_orig - &tensor_at_ref.xtwy_orig,
639 ));
640 }
641 }
642 if let Some((p, gram_delta, rhs_delta)) = &self.psi_gram_anchor_correction
643 && *p == psi_ref
644 && gram_delta.dim() == cache.xtwx_orig.dim()
645 && rhs_delta.len() == cache.xtwy_orig.len()
646 {
647 cache.xtwx_orig += gram_delta;
648 cache.xtwy_orig += rhs_delta;
649 }
650 }
651 if !self
652 .reml_state
653 .install_gaussian_fixed_cache(Arc::new(cache))
654 {
655 self.reml_state.clear_gaussian_fixed_cache();
656 return false;
657 }
658 log::debug!(
659 "[psi-gram-tensor] installed n-free Gaussian sufficient statistics at psi={psi:.6}"
660 );
661 // Install the conditioned-frame EXACT ANALYTIC ψ-derivatives so the
662 // Gaussian ψ-gradient HyperCoord is assembled from these k×k objects
663 // instead of the n×k ∂X/∂ψ slab. The #1033 sufficient-statistic kappa
664 // search is armed only when this certified gradient window spans the
665 // full optimizer bounds, so measured trials cannot fall back to a
666 // streamed edge-gradient pass.
667 if tensor.contains_for_gradient(psi)
668 && self.reml_state.install_gaussian_psi_gram_deriv(Arc::new((
669 tensor.dgram_dpsi(psi),
670 tensor.drhs_dpsi(psi),
671 )))
672 {
673 log::debug!(
674 "[psi-gram-tensor] installed n-free analytic ψ-gradient derivatives at \
675 psi={psi:.6}"
676 );
677 true
678 } else {
679 // Outside the certified gradient window (or if the derivative shape
680 // refused). Clear any derivative pair left from a prior ψ so a
681 // non-armed exact path does not reuse a stale derivative.
682 self.reml_state.clear_gaussian_psi_gram_deriv();
683 false
684 }
685 }
686
687 /// #1033 penalty lane: on the design-revision fast path (`reset_surface`
688 /// skipped) re-install the per-ψ canonical penalty surface `S(ψ)` from the
689 /// EXACT n-free rebuild the caller staged via `stage_fast_path_penalty`, so
690 /// the kept reference surface pairs `XᵀWX(ψ_new)` (re-keyed by
691 /// `install_psi_gram_statistics`) with the CORRECT `S(ψ_new)` instead of the
692 /// stale `S(ψ_old)` left from the slow-path reset. The staged penalty is the
693 /// output of `cache.canonical_penalties_at(theta)` — the SAME
694 /// `canonicalize_penalty_specs` pipeline `reset_surface` runs, but built from
695 /// the frozen basis geometry at the trial length-scale (no data rows).
696 ///
697 /// Returns `true` when a staged penalty was consumed and re-keyed. Returns
698 /// `false` when NO penalty was staged — in which case the fast path MUST NOT
699 /// have been taken (the spatial caller only skips design realization when
700 /// `cache.supports_nfree_penalty_rekey()` and always stages the rebuild on
701 /// that lane), so a `false` here is a hard signal that the skip gate and the
702 /// staging have drifted out of sync; the caller treats it as an error rather
703 /// than silently solving with a stale penalty.
704 fn refresh_psi_penalty_surface(&mut self) -> Result<bool, EstimationError> {
705 // Take the staged penalty (consume it — it is keyed to THIS trial's ψ).
706 let Some(staged) = self.pending_psi_penalty.take() else {
707 return Ok(false);
708 };
709 let (canonical, nullspace_dims) =
710 std::sync::Arc::try_unwrap(staged).unwrap_or_else(|arc| (*arc).clone());
711 self.reml_state
712 .refresh_canonical_penalty_surface(Arc::new(canonical), nullspace_dims)?;
713 log::debug!(
714 "[nfree-psi-penalty] re-installed exact n-free canonical penalty surface S(psi) \
715 on the design-revision fast path"
716 );
717 Ok(true)
718 }
719
720 fn prepare_eval_state(
721 &mut self,
722 x: &DesignMatrix,
723 s_list: &[BlockwisePenalty],
724 nullspace_dims: &[usize],
725 linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
726 theta: &Array1<f64>,
727 rho_dim: usize,
728 mut hyper_dirs: Vec<DirectionalHyperParam>,
729 warm_start_beta: Option<ArrayView1<'_, f64>>,
730 context: &str,
731 design_revision: Option<u64>,
732 ) -> Result<Vec<DirectionalHyperParam>, EstimationError> {
733 let p = x.ncols();
734 // Design-revision fast path: when the caller asserts that the
735 // realizer-side design (X + s_list) has not changed since the last
736 // `reset_surface`, we skip the canonical-penalty rebuild and the
737 // `reset_surface` work entirely. Hyper-direction conditioning still
738 // runs (hyper_dirs are freshly constructed per call) and the
739 // warm-start beta / penalty-shrinkage floor still need refreshing.
740 //
741 // #1033 (reduced-basis rotation): the fast path keeps the realized
742 // `self.x` frozen at the pinning ψ but re-keys the Gaussian Gram cache
743 // (`install_psi_gram_statistics`) and the canonical penalty surface
744 // (`refresh_psi_penalty_surface`) to this trial's ψ. The inner
745 // Gaussian-identity solve reads its data statistics ONLY from those
746 // re-keyed k×k objects (and the penalty-derived reparametrization Qs),
747 // never from `self.x` rows, so the skip is sound across a basis ROTATION
748 // — gated on the n-free VALUE coverage (`psi_gram_tensor_covers`), the
749 // condition under which the cache re-key actually fires, rather than the
750 // stricter `reduced_basis_equal` witness (which refused sound rotated-
751 // basis skips and forced the O(n) `reset_surface` fallback).
752 let skip_window_allows_fast_path = match (self.psi_gram_tensor.is_some(), theta.len()) {
753 (true, len) if len == rho_dim + 1 => self.psi_gram_tensor_covers(theta[rho_dim]),
754 _ => true,
755 };
756 let fast_path = match (design_revision, self.last_canonical_revision) {
757 (Some(rev), Some(last)) => rev == last && skip_window_allows_fast_path,
758 _ => false,
759 };
760
761 if fast_path {
762 validate_joint_hyper_direction_shapes(x, s_list.len(), theta, rho_dim, &hyper_dirs)?;
763
764 self.reml_state
765 .set_penalty_shrinkage_floor(self.penalty_shrinkage_floor);
766 self.reml_state.setwarm_start_original_beta(warm_start_beta);
767 // #1033b: the design did not change (fast path) but ψ moved, so the
768 // GaussianFixedCache and conditioned ψ-derivatives are keyed to the
769 // PREVIOUS ψ and must be re-installed for this trial's ψ from the
770 // certified tensor — otherwise the inner PLS reads a stale Gram. The
771 // slow path below clears + reinstalls these; the fast path skips
772 // `reset_surface` (which clears them), so we re-install here directly.
773 // Install BEFORE conditioning so we learn whether the n-free
774 // ψ-gradient was served from the tensor: if so, the conditioned n×k
775 // `∂X/∂ψ` slab below is provably DEAD (the `j==0` gradient branch
776 // reads the k-space derivatives, never the slab), so we skip the
777 // per-trial slab conditioning — the LAST O(n·k²) pass in the κ loop.
778 let gaussian_gradient_is_n_free = self.install_psi_gram_statistics(theta, rho_dim);
779 // #1033 penalty lane: ψ moved BOTH the Gram (re-keyed above) AND the
780 // penalty `S(ψ)`. The skipped `reset_surface` is the only place the
781 // canonical penalty surface is rebuilt, so re-install `S(ψ_new)` from
782 // the EXACT n-free penalty the caller staged here — otherwise the
783 // inner solve would pair `XᵀWX(ψ_new)` with the stale `S(ψ_old)` and
784 // converge to the wrong β̂ / κ-optimum. Done AFTER the Gram install
785 // because `refresh_canonical_penalty_surface` deliberately does NOT
786 // clear the Gaussian Gram cache (it is re-keyed independently above).
787 // The spatial caller only takes the design-realization skip when
788 // `cache.supports_nfree_penalty_rekey()`, and always stages the exact
789 // rebuild on that lane, so the re-key must succeed; a `false` means
790 // the skip gate and the staging drifted apart, which would silently
791 // pair a stale `S` — surface it as a hard error instead.
792 if self.supports_nfree_penalty_rekey && !self.refresh_psi_penalty_surface()? {
793 crate::bail_invalid_estim!(
794 "design-revision fast path fired with n-free penalty re-key enabled but no \
795 exact S(psi) was staged for psi={:.6} (theta_len={}, rho_dim={}); the \
796 reset_surface skip would leave a stale S(psi). The caller must call \
797 stage_fast_path_penalty before every skip-path eval.",
798 if theta.len() > rho_dim {
799 theta[rho_dim]
800 } else {
801 f64::NAN
802 },
803 theta.len(),
804 rho_dim,
805 );
806 }
807 let glm_gradient_is_n_free = self.install_pending_glm_trial_statistics();
808 if !(gaussian_gradient_is_n_free || glm_gradient_is_n_free) {
809 // The slab gradient lane is live for this trial (off the certified
810 // gradient sub-window, non-Gaussian, multi-ψ, …) — condition the
811 // n×k `∂X/∂ψ` slab into the inner solver's frame as before.
812 for dir in &mut hyper_dirs {
813 let mut x_tau = dir.x_tau_dense();
814 self.conditioning
815 .transform_matrix_columnswith_a_inplace(&mut x_tau);
816 dir.x_tau_original =
817 crate::estimate::reml::HyperDesignDerivative::from(x_tau);
818 if let Some(rows) = dir.x_tau_tau_original.as_mut() {
819 for mat in rows.iter_mut().flatten() {
820 let mut dense = mat.materialize();
821 self.conditioning
822 .transform_matrix_columnswith_a_inplace(&mut dense);
823 *mat =
824 crate::estimate::reml::HyperDesignDerivative::from(dense);
825 }
826 }
827 }
828 }
829 return Ok(hyper_dirs);
830 }
831
832 let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
833 validate_penalty_specs(&specs, p, context)?;
834 let (canonical, active_nullspace_dims) =
835 gam_terms::construction::canonicalize_penalty_specs(&specs, nullspace_dims, p, context)?;
836 validate_joint_hyper_direction_shapes(x, canonical.len(), theta, rho_dim, &hyper_dirs)?;
837
838 let x_fit = self.conditioning.apply_to_design(x);
839 let fit_linear_constraints = self
840 .conditioning
841 .transform_linear_constraints_to_internal(linear_constraints);
842
843 for dir in &mut hyper_dirs {
844 let mut x_tau = dir.x_tau_dense();
845 self.conditioning
846 .transform_matrix_columnswith_a_inplace(&mut x_tau);
847 dir.x_tau_original = crate::estimate::reml::HyperDesignDerivative::from(x_tau);
848 if let Some(rows) = dir.x_tau_tau_original.as_mut() {
849 for mat in rows.iter_mut().flatten() {
850 let mut dense = mat.materialize();
851 self.conditioning
852 .transform_matrix_columnswith_a_inplace(&mut dense);
853 *mat = crate::estimate::reml::HyperDesignDerivative::from(dense);
854 }
855 }
856 }
857
858 crate::estimate::reml::RemlState::reset_surface(
859 &mut self.reml_state,
860 x_fit,
861 Arc::new(canonical),
862 p,
863 active_nullspace_dims,
864 None,
865 fit_linear_constraints,
866 self.kronecker_penalty_system.clone(),
867 self.kronecker_factored.clone(),
868 )?;
869 // #1033 instrumentation: this is the slow (n-row) reconditioning lane.
870 self.slow_path_reset_count
871 .set(self.slow_path_reset_count.get().wrapping_add(1));
872 self.reml_state
873 .set_penalty_shrinkage_floor(self.penalty_shrinkage_floor);
874 self.reml_state.setwarm_start_original_beta(warm_start_beta);
875 self.last_canonical_revision = design_revision;
876 // #1264: freeze the reduced-basis reference ψ this slow-path reset pins,
877 // so the next design-revision fast path can certify its skip against it.
878 self.record_reset_psi(theta, rho_dim);
879 self.psi_gram_anchor_correction = None;
880 // #1216 hybrid: on the SLOW path the design was just REALIZED (the n×k
881 // `x_fit` is live in `reset_surface` above), so the inner PLS forms the
882 // EXACT `XᵀWX(ψ)` from it. We deliberately do NOT install the certified
883 // tensor's n-free assembled Gram here: the tensor reconstruction is
884 // bit-tight for the COST / κ-SEARCH (~1e-8) but its ~1e-14 Chebyshev
885 // residual, amplified by the radial-kernel Gram conditioning at
886 // weakly-penalized high-ψ (cond ~1e8), drifts the RECONSTRUCTED β̂ by
887 // ~1e-6 from the exact solve. Since the slow path already paid the O(n·k)
888 // realization (and reset_surface's O(n·p²) reconditioning), forming the
889 // exact Gram is incremental and keeps the MATERIALIZED β̂ bit-exact. The
890 // n-free win is preserved where it matters: the per-trial SEARCH loop
891 // takes the design-revision FAST path (in-window ψ at a pinned revision),
892 // which DOES install the tensor Gram (`prepare_eval_state` fast branch)
893 // and never realizes n rows — the slow path runs only once per revision
894 // (the #1033 "single initial pass"). The Gram slot was cleared by
895 // `reset_surface`, so leaving it clear routes the inner solve to the
896 // exact realized Gram.
897 self.install_pending_glm_trial_statistics();
898 // #1033 penalty lane: the slow path just rebuilt `S` from the freshly
899 // realized design inside `reset_surface`, so a staged n-free penalty (if
900 // any) is superseded — drop it so a later fast-path eval at a DIFFERENT
901 // revision never consumes this trial's stale `S`.
902 self.pending_psi_penalty = None;
903 Ok(hyper_dirs)
904 }
905
906 /// Install the staged frozen-W GLM first-step Gram onto the inner REML
907 /// surface for the current trial, or clear the surface slot when nothing is
908 /// staged (#1111 / #1033 mechanism (c)). Called after `reset_surface` (slow
909 /// path) and on the design-revision fast path, mirroring
910 /// `install_psi_gram_statistics`: the Gram is ψ-keyed, so it must be
911 /// (re)installed per trial and never carried over from the previous ψ.
912 fn install_pending_glm_trial_statistics(&mut self) -> bool {
913 let mut gradient_is_n_free = false;
914 match self.pending_glm_first_step_gram.take() {
915 Some(gram) => {
916 if !self.reml_state.install_glm_first_step_gram(gram) {
917 // Shape mismatch against the current surface — fall back to
918 // the exact streamed first-iteration Gram.
919 self.reml_state.clear_glm_first_step_gram();
920 }
921 }
922 None => self.reml_state.clear_glm_first_step_gram(),
923 }
924 // #1033 / #1111: the GLM frozen-W conditioned-frame ψ-gradient
925 // derivative is keyed to the same per-trial ψ as the first-step Gram, so
926 // install/clear it on the same two sites. Serves the GLM ψ-gradient
927 // `a_j` / `g_j` n-free; `B_j` stays the exact slab.
928 match self.pending_glm_psi_gram_deriv.take() {
929 Some(deriv) => {
930 if self.reml_state.install_glm_psi_gram_deriv(deriv) {
931 gradient_is_n_free = true;
932 } else {
933 // Shape mismatch against the current surface — fall back to
934 // the exact streamed ∂X/∂ψ slab gradient.
935 self.reml_state.clear_glm_psi_gram_deriv();
936 }
937 }
938 None => self.reml_state.clear_glm_psi_gram_deriv(),
939 }
940 gradient_is_n_free
941 }
942
943 pub fn evaluate_with_order(
944 &mut self,
945 x: &DesignMatrix,
946 s_list: &[BlockwisePenalty],
947 nullspace_dims: &[usize],
948 linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
949 theta: &Array1<f64>,
950 rho_dim: usize,
951 hyper_dirs: Vec<DirectionalHyperParam>,
952 warm_start_beta: Option<ArrayView1<'_, f64>>,
953 context: &str,
954 order: crate::rho_optimizer::OuterEvalOrder,
955 design_revision: Option<u64>,
956 ) -> Result<(f64, Array1<f64>, gam_problem::HessianResult), EstimationError> {
957 let order = if matches!(
958 order,
959 crate::rho_optimizer::OuterEvalOrder::ValueGradientHessian
960 ) {
961 // Firth pair Hessian terms are now available via Primitive A +
962 // Primitive B in the reduced Firth dense operator; the tau-tau
963 // policy no longer needs the Firth+Logit gap downgrade.
964 let firth_pair_terms_unavailable = false;
965 let tau_tau_policy =
966 crate::estimate::reml::exact_tau_tau_hessian_policy_with_firth(
967 x.nrows(),
968 x.ncols(),
969 &hyper_dirs,
970 firth_pair_terms_unavailable,
971 );
972 if tau_tau_policy.prefer_gradient_only() {
973 log::warn!(
974 "[OUTER] disabling exact tau Hessian before conditioning; using gradient-only outer eval \
975 (n={}, p={}, psi_dim={}, implicit_tau={}, implicit_multidim_duchon={}, firth_pair_gap={}, dense_tau_cache={:.1} MiB, gradient_plan={:.1} MiB, exact_hessian_plan={:.1} MiB, budget={:.1} MiB)",
976 x.nrows(),
977 x.ncols(),
978 hyper_dirs.len(),
979 tau_tau_policy.any_has_implicit,
980 tau_tau_policy.implicit_multidim_duchon,
981 tau_tau_policy.firth_pair_terms_unavailable,
982 tau_tau_policy.estimated_dense_tau_cache_bytes as f64 / (1024.0 * 1024.0),
983 tau_tau_policy.gradient_plan.total_bytes() as f64 / (1024.0 * 1024.0),
984 tau_tau_policy.hessian_plan.total_bytes() as f64 / (1024.0 * 1024.0),
985 tau_tau_policy.budget_bytes as f64 / (1024.0 * 1024.0),
986 );
987 crate::rho_optimizer::OuterEvalOrder::ValueAndGradient
988 } else {
989 order
990 }
991 } else {
992 order
993 };
994 let hyper_dirs = self.prepare_eval_state(
995 x,
996 s_list,
997 nullspace_dims,
998 linear_constraints,
999 theta,
1000 rho_dim,
1001 hyper_dirs,
1002 warm_start_beta,
1003 context,
1004 design_revision,
1005 )?;
1006 crate::estimate::reml::RemlState::compute_joint_hyper_eval_with_order(
1007 &self.reml_state,
1008 theta,
1009 rho_dim,
1010 &hyper_dirs,
1011 order,
1012 )
1013 }
1014
1015 pub fn evaluate_efs(
1016 &mut self,
1017 x: &DesignMatrix,
1018 s_list: &[BlockwisePenalty],
1019 nullspace_dims: &[usize],
1020 linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
1021 theta: &Array1<f64>,
1022 rho_dim: usize,
1023 hyper_dirs: Vec<DirectionalHyperParam>,
1024 warm_start_beta: Option<ArrayView1<'_, f64>>,
1025 context: &str,
1026 design_revision: Option<u64>,
1027 ) -> Result<gam_problem::EfsEval, EstimationError> {
1028 let hyper_dirs = self.prepare_eval_state(
1029 x,
1030 s_list,
1031 nullspace_dims,
1032 linear_constraints,
1033 theta,
1034 rho_dim,
1035 hyper_dirs,
1036 warm_start_beta,
1037 context,
1038 design_revision,
1039 )?;
1040 let rho = theta.slice(s![..rho_dim]).to_owned();
1041 self.reml_state
1042 .compute_efs_steps_with_psi_ext(&rho, &hyper_dirs)
1043 }
1044
1045 /// Reset the inner surface for a value-only evaluation. This is the
1046 /// hyper-dir-free counterpart of [`prepare_eval_state`]: it accepts the
1047 /// fact that the spatial design has been re-realized at the current κ
1048 /// (the caller guarantees this via the realizer cache), so no directional
1049 /// hyper-derivatives are required to produce a correct cost. Skipping
1050 /// the hyper_dir validation and the per-direction conditioning loop is
1051 /// what makes line-search probes cheap in the iso/aniso joint paths.
1052 pub(crate) fn prepare_eval_state_cost_only(
1053 &mut self,
1054 x: &DesignMatrix,
1055 s_list: &[BlockwisePenalty],
1056 nullspace_dims: &[usize],
1057 linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
1058 theta: &Array1<f64>,
1059 rho_dim: usize,
1060 warm_start_beta: Option<ArrayView1<'_, f64>>,
1061 context: &str,
1062 design_revision: Option<u64>,
1063 ) -> Result<(), EstimationError> {
1064 // Design-revision fast path: when ψ hasn't moved since the last
1065 // full `reset_surface`, the cached surface's X, canonical penalties,
1066 // gaussian-fixed cache, and PIRLS cache are all still keyed to the
1067 // exact same (X, y, w, offset) — skip the eigendecomp + cache wipe.
1068 //
1069 // #1033 (reduced-basis rotation): a value-only probe's Gaussian-identity
1070 // inner solve reads its data statistics ONLY from the re-keyed
1071 // `GaussianFixedCache` + penalty-derived reparametrization Qs, never from
1072 // the frozen `self.x` rows, so the skip is sound across a basis ROTATION.
1073 // Gate on the n-free VALUE coverage (the condition under which the cache
1074 // re-key fires) rather than the stricter `reduced_basis_equal` witness,
1075 // which refused sound rotated-basis skips and forced the O(n) fallback.
1076 let skip_window_allows_fast_path = match (self.psi_gram_tensor.is_some(), theta.len()) {
1077 (true, len) if len == rho_dim + 1 => self.psi_gram_tensor_covers(theta[rho_dim]),
1078 _ => true,
1079 };
1080 let fast_path = match (design_revision, self.last_canonical_revision) {
1081 (Some(rev), Some(last)) => rev == last && skip_window_allows_fast_path,
1082 _ => false,
1083 };
1084 if fast_path {
1085 self.reml_state
1086 .set_penalty_shrinkage_floor(self.penalty_shrinkage_floor);
1087 self.reml_state.setwarm_start_original_beta(warm_start_beta);
1088 // #1111 / #1033 mechanism (c): a BFGS line-search VALUE probe can
1089 // carry its own ψ-keyed frozen-W first-step Gram staged by
1090 // `SpatialJointContext::eval_cost`. Install that staged Gram here on
1091 // the same fast path as full evals; when no current-probe value was
1092 // staged, this call clears any prior ψ's slot so stale GLM statistics
1093 // never leak into the probe. Cost-only probes do not consume the GLM
1094 // ψ-gradient derivative, so the stager passes `None` for that slot
1095 // and this call clears it as well.
1096 self.install_pending_glm_trial_statistics();
1097 // #1033: the Gaussian-identity `gaussian_fixed_cache` is ALSO keyed to
1098 // the trial's ψ (the certified ψ-Gram tensor's `XᵀWX(ψ)/XᵀWz(ψ)`), and
1099 // a VALUE probe runs at a different ψ than the eval that installed it.
1100 // On the fast path `reset_surface` is skipped, so without re-keying
1101 // here the inner Gaussian PLS would read the PREVIOUS ψ's Gram — a
1102 // stale-Gram correctness hazard. Re-install the n-free per-ψ Gram for
1103 // THIS probe's ψ (in-window) so the value probe is both correct AND
1104 // touches only k-dim sufficient statistics; off-window the installer
1105 // is a no-op and the surface keeps its streamed Gram. The conditioned
1106 // ψ-derivatives the installer also stages are gradient-channel objects
1107 // unused by `compute_cost`, but keying them to this ψ keeps a single
1108 // source of truth and avoids leaving a prior trial's pair installed.
1109 self.install_psi_gram_statistics(theta, rho_dim);
1110 // #1033 penalty lane (value-probe twin): the probe's ψ differs from the
1111 // eval that ran the last `reset_surface`, so the kept surface still
1112 // carries that ψ's `S`. Re-key `S(ψ_probe)` from the EXACT n-free
1113 // penalty the caller staged, for the same reason as the full-eval fast
1114 // path — otherwise the probe's inner Gaussian PLS pairs `XᵀWX(ψ_probe)`
1115 // (re-keyed above) with a stale `S` and reports the wrong cost,
1116 // mis-ranking the line search. The caller's `skip_value_realization`
1117 // gate requires `cache.supports_nfree_penalty_rekey()` and stages the
1118 // rebuild, so when enabled the re-key must succeed; treat a miss as a
1119 // hard error rather than a stale-`S` cost.
1120 if self.supports_nfree_penalty_rekey && !self.refresh_psi_penalty_surface()? {
1121 crate::bail_invalid_estim!(
1122 "value-probe design-revision fast path fired with n-free penalty re-key \
1123 enabled but no exact S(psi) was staged for psi={:.6} (theta_len={}, \
1124 rho_dim={}); the reset_surface skip would leave a stale S(psi). The caller \
1125 must call stage_fast_path_penalty before every skip-path value probe.",
1126 if theta.len() > rho_dim {
1127 theta[rho_dim]
1128 } else {
1129 f64::NAN
1130 },
1131 theta.len(),
1132 rho_dim,
1133 );
1134 }
1135 return Ok(());
1136 }
1137
1138 let p = x.ncols();
1139 let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
1140 validate_penalty_specs(&specs, p, context)?;
1141 let (canonical, active_nullspace_dims) =
1142 gam_terms::construction::canonicalize_penalty_specs(&specs, nullspace_dims, p, context)?;
1143
1144 let x_fit = self.conditioning.apply_to_design(x);
1145 let fit_linear_constraints = self
1146 .conditioning
1147 .transform_linear_constraints_to_internal(linear_constraints);
1148
1149 // Cost-only paths do not introduce design drift via hyper_dirs, so
1150 // the directional-hyper-support check is unnecessary here.
1151 crate::estimate::reml::RemlState::reset_surface(
1152 &mut self.reml_state,
1153 x_fit,
1154 Arc::new(canonical),
1155 p,
1156 active_nullspace_dims,
1157 None,
1158 fit_linear_constraints,
1159 self.kronecker_penalty_system.clone(),
1160 self.kronecker_factored.clone(),
1161 )?;
1162 // #1033 instrumentation: this is the slow (n-row) reconditioning lane.
1163 self.slow_path_reset_count
1164 .set(self.slow_path_reset_count.get().wrapping_add(1));
1165 self.reml_state
1166 .set_penalty_shrinkage_floor(self.penalty_shrinkage_floor);
1167 self.reml_state.setwarm_start_original_beta(warm_start_beta);
1168 self.last_canonical_revision = design_revision;
1169 // #1264: freeze the reduced-basis reference ψ this slow-path reset pins.
1170 self.record_reset_psi(theta, rho_dim);
1171 self.psi_gram_anchor_correction = None;
1172 self.install_pending_glm_trial_statistics();
1173 self.install_psi_gram_statistics(theta, rho_dim);
1174 // #1033 penalty lane: the slow cost-only path rebuilt `S` from the freshly
1175 // realized design — drop any staged n-free penalty so a later fast-path
1176 // probe never consumes this trial's stale `S`.
1177 self.pending_psi_penalty = None;
1178 Ok(())
1179 }
1180
1181 /// Cost-only evaluation at the current κ-realized design. Used by the
1182 /// joint [ρ, ψ] BFGS line-search cost callback so probes pay neither the
1183 /// `try_build_spatial_log_kappa_hyper_dirs` cost nor the gradient assembly
1184 /// cost. The gradient callback continues to use [`evaluate_with_order`].
1185 ///
1186 /// Contract: the caller MUST have already realized the design at the κ
1187 /// implied by `theta`'s ψ tail (typically via the
1188 /// `SingleBlockExactJointDesignCache::ensure_theta` path). The penalty
1189 /// gradients w.r.t. ρ are independent of κ for the spatial single-block
1190 /// path, but correction gates still need to know that the objective lives
1191 /// on a joint `[ρ, ψ]` surface. Pass the ψ-tail count into the shared cost
1192 /// bridge so value-only probes decline the same ext-coordinate-incomplete
1193 /// corrections as the analytic joint path.
1194 pub fn evaluate_cost_only(
1195 &mut self,
1196 x: &DesignMatrix,
1197 s_list: &[BlockwisePenalty],
1198 nullspace_dims: &[usize],
1199 linear_constraints: Option<crate::pirls::LinearInequalityConstraints>,
1200 theta: &Array1<f64>,
1201 rho_dim: usize,
1202 warm_start_beta: Option<ArrayView1<'_, f64>>,
1203 context: &str,
1204 design_revision: Option<u64>,
1205 ) -> Result<f64, EstimationError> {
1206 if rho_dim > theta.len() {
1207 crate::bail_invalid_estim!(
1208 "rho_dim {} exceeds theta dimension {}",
1209 rho_dim,
1210 theta.len()
1211 );
1212 }
1213 self.prepare_eval_state_cost_only(
1214 x,
1215 s_list,
1216 nullspace_dims,
1217 linear_constraints,
1218 theta,
1219 rho_dim,
1220 warm_start_beta,
1221 context,
1222 design_revision,
1223 )?;
1224 let rho = theta.slice(s![..rho_dim]).to_owned();
1225 self.reml_state
1226 .compute_cost_with_ext_count(&rho, theta.len() - rho_dim)
1227 }
1228}
1229
1230// canonicalize_active_penalties removed — replaced by
1231// gam_terms::construction::canonicalize_penalty_specs.