gam_solve/reml/reml_outer_engine/hessian_operator_trait.rs
1use super::*;
2
3// ═══════════════════════════════════════════════════════════════════════════
4// Core traits
5// ═══════════════════════════════════════════════════════════════════════════
6
7/// Fit-level stochastic trace state shared by all adaptive Hutchinson batches.
8///
9/// `monotone_probe_floor` pins the CRN prefix length across batches. The
10/// `cg_warm_starts` map stores the previous H⁻¹ solve for the same deterministic
11/// probe id so the next outer evaluation can initialize matrix-free trace CG
12/// from the matching probe only.
13#[derive(Debug, Default)]
14pub struct StochasticTraceState {
15 pub monotone_probe_floor: usize,
16 pub cg_warm_starts: HashMap<u64, Array1<f64>>,
17 pub solve_rel_tol_override: Option<f64>,
18 pub last_linear_residual_norm: Option<f64>,
19 pub last_probe_sigma_sq: Option<f64>,
20 pub last_probe_count: usize,
21}
22
// NOTE: the overview below describes the `HessianOperator` trait declared
// later in this file, not the constant that follows. It is kept as plain `//`
// comments so that rustdoc does not merge it into the documentation of
// `HUTCHPP_TRACE_MIN_DIM`.
//
// Abstract interface for Hessian linear algebra operations.
//
// All operations use the SAME internal decomposition, ensuring spectral
// consistency between logdet (used in cost) and trace/solve (used in gradient).
//
// Implementors:
// - `DenseSpectralOperator`: eigendecomposition of dense H
// - Sparse Cholesky operators (external implementations)
// - `BlockCoupledOperator`: eigendecomposition of joint multi-block H

/// Minimum operator dimension at which the Hutch++ stochastic trace estimator
/// is preferred over materializing an implicit operator densely.
///
/// Below this threshold the `2·m_s + m_h` Hutch++ matvecs do not beat `dim`
/// dense H⁻¹ HVPs, so the dense fallback is cheaper.
pub(crate) const HUTCHPP_TRACE_MIN_DIM: usize = 128;
37
38/// Build the Hutch++ stochastic-trace configuration for an operator of the given
39/// dimension. The sketch dimension grows with `dim` (one column per 32 of
40/// dimension, bounded to `[4, 16]`), and the probe budget tracks the sketch so
41/// the estimator's variance and cost stay balanced across problem sizes. Shared
42/// by every implicit-operator trace path so they cannot drift apart.
43pub(crate) fn hutchpp_config_for_dim(dim: usize) -> StochasticTraceConfig {
44 const SKETCH_DIM_PER: usize = 32;
45 const SKETCH_DIM_MIN: usize = 4;
46 const SKETCH_DIM_MAX: usize = 16;
47 const PROBES_PER_SKETCH: usize = 4;
48 const PROBES_MAX_FLOOR: usize = 32;
49 const PROBES_MIN_FLOOR: usize = 8;
50 let sketch = (dim / SKETCH_DIM_PER).clamp(SKETCH_DIM_MIN, SKETCH_DIM_MAX);
51 let mut config = StochasticTraceConfig::default();
52 config.hutchpp_sketch_dim = Some(sketch);
53 config.n_probes_max = (sketch * PROBES_PER_SKETCH).max(PROBES_MAX_FLOOR);
54 config.n_probes_min = sketch.max(PROBES_MIN_FLOOR);
55 config
56}
57
58pub trait HessianOperator: Send + Sync {
59 /// log|H|₊ — pseudo-logdet using only active eigenvalues/pivots.
60 fn logdet(&self) -> f64;
61
62 /// tr(H₊⁻¹ A) — trace of pseudo-inverse times a symmetric matrix.
63 /// Uses the SAME decomposition as `logdet`.
64 fn trace_hinv_product(&self, a: &Array2<f64>) -> f64;
65
66 /// Exact dense spectral representation, when this backend has one.
67 ///
68 /// Outer-Hessian assembly uses this to batch all logdet-Hessian cross
69 /// traces in the eigenbasis. For CTN scale-dimension fits this avoids
70 /// projecting the same implicit ψ drift once per upper-triangular pair.
71 fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
72 None
73 }
74
75 /// Assemble the raw dense Hessian represented by this backend for
76 /// active-constraint tangent projection.
77 ///
78 /// Backends that do not store either a dense spectral decomposition or an
79 /// explicit factorization should keep the default error.
80 fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
81 Err("backend does not support tangent projection".to_string())
82 }
83
84 /// tr(H₊⁻¹ B) for an operator-backed Hessian drift.
85 ///
86 /// Default implementation materializes `B` densely. Backends with
87 /// native operator traces (notably sparse Cholesky) should override it.
88 ///
89 /// For HVP-only (implicit) operators on large problems we route
90 /// through Hutch++ — the Meyer–Musco split estimator achieves O(1/ε)
91 /// matvecs vs O(1/ε²) for plain Hutchinson, and avoids the O(p²)
92 /// memory + O(p) HVP cost of materializing the operator densely.
93 fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
94 // Hutch++ fast path for the warn-and-materialize default. Only
95 // backends that fall through to this default reach here;
96 // backends with native operator traces override it. We require
97 // an implicit operator (so materialization is expensive) and a
98 // moderately-large dim (so 2 m_s + m_h matvecs beats `dim`
99 // dense HVPs).
100 if op.is_implicit() && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
101 let config = hutchpp_config_for_dim(self.dim());
102 return hutchpp_estimate_trace_hinv_operator(self, op, &config);
103 }
104 if op.is_implicit() {
105 log::warn!(
106 "trace_hinv_operator: materializing implicit HyperOperator — \
107 backend should provide a matrix-free override"
108 );
109 }
110 self.trace_hinv_product(&op.to_dense())
111 }
112
113 /// H⁻¹ v — linear solve using the active decomposition.
114 fn solve(&self, rhs: &Array1<f64>) -> Array1<f64>;
115
116 /// H⁻¹ M — multi-column solve.
117 fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64>;
118
119 /// H⁻¹ v for stochastic trace probes.
120 ///
121 /// Exact backends use the normal solve. Matrix-free backends may override
122 /// this to use a looser PCG tolerance when the caller's Monte Carlo error
123 /// dominates the linear-solve error.
124 fn stochastic_trace_solve(&self, rhs: &Array1<f64>, rel_tol: f64) -> Array1<f64> {
125 assert!(
126 rel_tol.is_finite() && rel_tol > 0.0,
127 "stochastic trace solve tolerance must be positive and finite"
128 );
129 self.solve(rhs)
130 }
131
132 /// H⁻¹ v for a deterministic stochastic trace probe id.
133 ///
134 /// Backends with matrix-free CG may use `probe_id` to warm-start from the
135 /// previous solve of the same CRN probe. The default exact backend ignores
136 /// the id and uses the normal stochastic trace solve.
137 fn stochastic_trace_solve_for_probe(
138 &self,
139 rhs: &Array1<f64>,
140 rel_tol: f64,
141 probe_id: u64,
142 state: Option<&Arc<Mutex<StochasticTraceState>>>,
143 ) -> Array1<f64> {
144 // Default exact backend has no matrix-free CG, so per-probe warm
145 // starts are inapplicable. If a previous matrix-free backend left
146 // a warm-start vector for this `probe_id` in the shared state,
147 // drop it so a later matrix-free run does not consume a vector
148 // that was generated against a different operator factorization.
149 if let Some(state_arc) = state
150 && let Ok(mut guard) = state_arc.lock()
151 {
152 guard.cg_warm_starts.remove(&probe_id);
153 }
154 self.stochastic_trace_solve(rhs, rel_tol)
155 }
156
157 /// H⁻¹ M for stochastic trace probes.
158 fn stochastic_trace_solve_multi(&self, rhs: &Array2<f64>, rel_tol: f64) -> Array2<f64> {
159 assert!(
160 rel_tol.is_finite() && rel_tol > 0.0,
161 "stochastic trace multi-solve tolerance must be positive and finite"
162 );
163 self.solve_multi(rhs)
164 }
165
166 /// Whether this backend exposes a matrix-free operator usable by trace CG.
167 fn has_matrix_free_trace_cg_operator(&self) -> bool {
168 false
169 }
170
171 /// tr(H⁻¹ A H⁻¹ B) for dense symmetric Hessian drifts.
172 ///
173 /// This is the second-order trace object used by EFS denominators and the
174 /// ψ-block trace Gram preconditioner. The default implementation computes
175 /// both solved column stacks exactly and contracts them as
176 /// `tr((H⁻¹A)(H⁻¹B))`.
177 fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
178 let solved_a = self.solve_multi(a);
179 if std::ptr::eq(a, b) {
180 return trace_matrix_product(&solved_a, &solved_a);
181 }
182 let solved_b = self.solve_multi(b);
183 trace_matrix_product(&solved_a, &solved_b)
184 }
185
186 /// tr(H⁻¹ A H⁻¹ B) for a dense drift `A` and an operator-backed drift `B`.
187 ///
188 /// Default implementation materializes the operator and dispatches to the
189 /// dense cross-trace path. Matrix-free and sparse backends should override
190 /// this to avoid dense operator materialization.
191 fn trace_hinv_matrix_operator_cross(
192 &self,
193 matrix: &Array2<f64>,
194 op: &dyn HyperOperator,
195 ) -> f64 {
196 if op.is_implicit() && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
197 let config = hutchpp_config_for_dim(self.dim());
198 // Wrap the dense LHS in a matrix-backed HyperOperator so the
199 // shared cross routine can call mul_vec_into on it.
200 let lhs = DenseMatrixHyperOperator {
201 matrix: matrix.clone(),
202 };
203 return hutchpp_estimate_trace_hinv_operator_cross(self, &lhs, op, &config);
204 }
205 if op.is_implicit() {
206 log::warn!(
207 "trace_hinv_matrix_operator_cross: materializing implicit HyperOperator — \
208 backend should provide a matrix-free override"
209 );
210 }
211 self.trace_hinv_product_cross(matrix, &op.to_dense())
212 }
213
214 /// tr(H⁻¹ A H⁻¹ B) for operator-backed Hessian drifts.
215 ///
216 /// Default implementation materializes both operators densely. Backends
217 /// with native operator-aware cross traces should override this.
218 fn trace_hinv_operator_cross(
219 &self,
220 left: &dyn HyperOperator,
221 right: &dyn HyperOperator,
222 ) -> f64 {
223 let l_implicit = left.is_implicit();
224 let r_implicit = right.is_implicit();
225 if (l_implicit || r_implicit) && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
226 let config = hutchpp_config_for_dim(self.dim());
227 // Same-operator self-cross is PSD; the squared form is the
228 // exact algorithm for that case (lower variance, no sign).
229 if std::ptr::eq(
230 left as *const dyn HyperOperator as *const (),
231 right as *const dyn HyperOperator as *const (),
232 ) {
233 return hutchpp_estimate_trace_hinv_op_squared(self, left, &config);
234 }
235 return hutchpp_estimate_trace_hinv_operator_cross(self, left, right, &config);
236 }
237 if l_implicit || r_implicit {
238 log::warn!(
239 "trace_hinv_operator_cross: materializing implicit HyperOperator(s) — \
240 backend should provide a matrix-free override"
241 );
242 }
243 self.trace_hinv_product_cross(&left.to_dense(), &right.to_dense())
244 }
245
246 /// tr(G_ε(H) A) — trace for the logdet gradient ∂_i log|R_ε(H)|.
247 ///
248 /// For non-spectral backends (Cholesky), G_ε = H⁻¹ and this reduces to
249 /// `trace_hinv_product`. For spectral regularization, G_ε uses eigenvalues
250 /// `φ'(σ_a) = 1/√(σ_a² + 4ε²)` instead of `1/r_ε(σ_a)`.
251 fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
252 self.trace_hinv_product(a)
253 }
254
255 /// diag(X · G_ε(H) · Xᵀ) — the leverage corresponding to `trace_logdet_gradient`.
256 /// `trace_logdet_gradient(Xᵀ diag(w) X) = Σᵢ wᵢ · h^G[i]`.
257 ///
258 /// Streams the rows of `X` through the design's `try_row_chunk` so
259 /// operator-backed (Lazy) designs never materialize the full (n×p)
260 /// block at large scale.
261 fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
262 assert!(self.logdet_traces_match_hinv_kernel());
263 let n = x.nrows();
264 let p = x.ncols();
265
266 let block = {
267 const TARGET_CHUNK_FLOATS: usize = 1 << 16;
268 (TARGET_CHUNK_FLOATS / p.max(1)).clamp(1, n.max(1))
269 };
270
271 let mut h = Array1::<f64>::zeros(n);
272 let mut start = 0usize;
273 while start < n {
274 let end = (start + block).min(n);
275 let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
276 // SAFETY: `try_row_chunk` only fails on operator implementation
277 // bugs — the `start..end` range is constructed from
278 // `0..n = 0..x.nrows()` with `end = (start+block).min(n)`,
279 // so it is always a valid sub-range of `x`. A failure here
280 // means the operator violated its row-chunk contract.
281 // SAFETY: row range built from 0..x.nrows(); failure means operator broke its contract.
282 reml_contract_panic(format!(
283 "xt_logdet_kernel_x_diagonal: row chunk failed: {err}"
284 ))
285 });
286 let chunk_t = rows.t().to_owned();
287 let z_chunk = self.solve_multi(&chunk_t);
288 for (i, (row, z_col)) in rows
289 .outer_iter()
290 .zip(z_chunk.columns().into_iter())
291 .enumerate()
292 {
293 let mut acc = 0.0;
294 for (row_value, z_value) in row.iter().copied().zip(z_col.iter().copied()) {
295 acc += row_value * z_value;
296 }
297 h[start + i] = acc;
298 }
299 start = end;
300 }
301 h
302 }
303
304 /// tr(G_ε(H) B) for an operator-backed Hessian drift.
305 ///
306 /// Default implementation materializes `B` densely. For Cholesky-based
307 /// backends this equals `trace_hinv_operator`.
308 ///
309 /// When `logdet_traces_match_hinv_kernel()` is true (Cholesky-style
310 /// backends where `trace_logdet_gradient(A) = trace_hinv_product(A)`)
311 /// and the operator is implicit on a moderate-or-large problem, route
312 /// through Hutch++ to avoid the dense materialization. Spectral
313 /// backends override this to false (their logdet trace uses
314 /// regularized eigenvalue weights, not `H⁻¹`), so they keep the
315 /// materialize path or provide their own override.
316 fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
317 if op.is_implicit()
318 && self.dim() >= HUTCHPP_TRACE_MIN_DIM
319 && self.logdet_traces_match_hinv_kernel()
320 {
321 let config = hutchpp_config_for_dim(self.dim());
322 return hutchpp_estimate_trace_hinv_operator(self, op, &config);
323 }
324 if op.is_implicit() {
325 log::warn!(
326 "trace_logdet_operator: materializing implicit HyperOperator — \
327 backend should provide a matrix-free override"
328 );
329 }
330 self.trace_logdet_gradient(&op.to_dense())
331 }
332
333 /// Efficient computation of tr(G_ε(H) Hₖ) for the logdet gradient.
334 ///
335 /// Default implementation: forms the correction and calls `trace_logdet_gradient`.
336 fn trace_logdet_h_k(
337 &self,
338 a_k: &Array2<f64>,
339 third_deriv_correction: Option<&Array2<f64>>,
340 ) -> f64 {
341 let base = self.trace_logdet_gradient(a_k);
342 match third_deriv_correction {
343 Some(c) => base + self.trace_logdet_gradient(c),
344 None => base,
345 }
346 }
347
348 /// tr(G_ε(H) · A_block) where A_block is a p_block × p_block matrix
349 /// embedded at rows/columns [start..end].
350 ///
351 /// This avoids materializing the full p×p matrix for block-structured
352 /// penalties. The default implementation builds the full matrix and
353 /// delegates to `trace_logdet_gradient`; spectral backends override
354 /// this with O(p_block × active_rank) work.
355 fn trace_logdet_block_local(
356 &self,
357 block: &Array2<f64>,
358 scale: f64,
359 start: usize,
360 end: usize,
361 ) -> f64 {
362 let p = self.dim();
363 let mut full = Array2::<f64>::zeros((p, p));
364 let bs = end - start;
365 for i in 0..bs {
366 for j in 0..bs {
367 full[[start + i, start + j]] = scale * block[[i, j]];
368 }
369 }
370 self.trace_logdet_gradient(&full)
371 }
372
373 /// Cross-trace for the logdet Hessian:
374 /// `∂²_{ij} log|R_ε(H)| = tr(G_ε Ḧ_{ij}) + spectral_cross(Ḣ_i, Ḣ_j)`.
375 ///
376 /// This method computes the `spectral_cross(Ḣ_i, Ḣ_j)` part, which for
377 /// non-spectral backends equals `-tr(H⁻¹ Ḣ_j H⁻¹ Ḣ_i)`.
378 ///
379 /// For spectral regularization, the divided-difference kernel Γ_{ab} replaces
380 /// the simple product of inverses.
381 fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
382 // Default: standard formula -tr(H⁻¹ Ḣ_j H⁻¹ Ḣ_i) = -⟨Y_j^T, Y_i⟩_F
383 // where Y_i = H⁻¹ Ḣ_i.
384 let y_i = self.solve_multi(h_i);
385 if std::ptr::eq(h_i, h_j) {
386 return -trace_matrix_product(&y_i, &y_i);
387 }
388 let y_j = self.solve_multi(h_j);
389 -trace_matrix_product(&y_j, &y_i)
390 }
391
392 /// Operator-backed mixed form of [`trace_logdet_hessian_cross`].
393 ///
394 /// The default materializes the operator; spectral and sparse backends
395 /// override this to keep the exact analytic cross trace matrix-free.
396 fn trace_logdet_hessian_cross_matrix_operator(
397 &self,
398 h_i: &Array2<f64>,
399 h_j: &dyn HyperOperator,
400 ) -> f64 {
401 self.trace_logdet_hessian_cross(h_i, &h_j.to_dense())
402 }
403
404 /// Operator-backed form of [`trace_logdet_hessian_cross`].
405 ///
406 /// The default materializes both operators; exact backends override this
407 /// when they can contract the logdet-Hessian kernel against operator
408 /// projections directly.
409 fn trace_logdet_hessian_cross_operator(
410 &self,
411 h_i: &dyn HyperOperator,
412 h_j: &dyn HyperOperator,
413 ) -> f64 {
414 self.trace_logdet_hessian_cross(&h_i.to_dense(), &h_j.to_dense())
415 }
416
417 /// Number of active dimensions (rank of pseudo-inverse).
418 fn active_rank(&self) -> usize;
419
420 /// Full dimension of H.
421 fn dim(&self) -> usize;
422
423 /// Whether this operator is backed by a dense factorization.
424 ///
425 /// Dense operators (eigendecomposition) have O(p²) trace cost per matrix,
426 /// making stochastic trace estimation worthwhile for large p. Sparse
427 /// operators (Cholesky) have O(nnz) solve cost, so exact column-by-column
428 /// traces are already cheap and stochastic estimation is not needed.
429 fn is_dense(&self) -> bool {
430 false
431 }
432
433 /// Whether the unified evaluator should batch large trace computations
434 /// through the stochastic Hutchinson path for this operator.
435 ///
436 /// Dense eigendecomposition backends prefer this once `p` is large because
437 /// exact per-coordinate traces are O(p²). Matrix-free iterative backends
438 /// have the same preference even though they do not store a dense factor.
439 fn prefers_stochastic_trace_estimation(&self) -> bool {
440 self.is_dense()
441 }
442
443 /// Whether stochastic Hutchinson estimates based on `H⁻¹` are valid for
444 /// logdet-gradient / logdet-Hessian trace terms on this backend.
445 ///
446 /// This is true for plain SPD-logdet operators where
447 /// `trace_logdet_gradient(A) = tr(H⁻¹ A)` and
448 /// `trace_logdet_hessian_cross(A, B) = -tr(H⁻¹ A H⁻¹ B)`.
449 ///
450 /// Smooth spectral regularization does not satisfy those identities, so
451 /// dense spectral backends must override this to `false`.
452 fn logdet_traces_match_hinv_kernel(&self) -> bool {
453 true
454 }
455
456 /// Access the dense spectral backend when this operator is powered by a
457 /// single eigendecomposition.
458 fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
459 None
460 }
461}
462
463/// Representative curvature scale for a Hessian operator.
464///
465/// Returns the geometric mean of the active Hessian eigenvalues,
466/// `exp(log|H|_+ / rank(H))`. This has the same physical units as a Hessian
467/// diagonal entry but is basis-invariant, cheap after the operator has computed
468/// its log-determinant, and well-defined for both dense spectral and
469/// matrix-free operator paths.
470pub fn hessian_operator_geometric_scale(op: &dyn HessianOperator) -> Option<f64> {
471 let rank = op.active_rank();
472 if rank == 0 {
473 return None;
474 }
475 let logdet = op.logdet();
476 if !logdet.is_finite() {
477 return None;
478 }
479 let scale = (logdet / rank as f64).exp();
480 if scale.is_finite() && scale > 0.0 {
481 Some(scale)
482 } else {
483 None
484 }
485}