// File: gam_solve/reml/reml_outer_engine/hessian_operator_trait.rs

1use super::*;
2
3// ═══════════════════════════════════════════════════════════════════════════
4//  Core traits
5// ═══════════════════════════════════════════════════════════════════════════
6
/// Fit-level stochastic trace state shared by all adaptive Hutchinson batches.
///
/// `monotone_probe_floor` pins the CRN prefix length across batches. The
/// `cg_warm_starts` map stores the previous H⁻¹ solve for the same deterministic
/// probe id so the next outer evaluation can initialize matrix-free trace CG
/// from the matching probe only.
///
/// The derived `Default` yields a zero probe floor, an empty warm-start map,
/// `None` for every optional diagnostic/override, and a zero probe count.
#[derive(Debug, Default)]
pub struct StochasticTraceState {
    /// CRN probe-prefix length pinned across batches so later batches never
    /// use fewer common-random-number probes than earlier ones.
    pub monotone_probe_floor: usize,
    /// Previous H⁻¹ solve keyed by deterministic probe id; used only as the
    /// CG initial guess for the same probe id.
    pub cg_warm_starts: HashMap<u64, Array1<f64>>,
    /// Optional override of the relative tolerance used for stochastic-trace
    /// solves. NOTE(review): the reader of this field is outside this file —
    /// confirm how it interacts with the caller-supplied `rel_tol`.
    pub solve_rel_tol_override: Option<f64>,
    /// Presumably the residual norm reported by the most recent linear solve
    /// (inferred from the name; the writer is not visible here — confirm).
    pub last_linear_residual_norm: Option<f64>,
    /// Presumably the probe-sample variance (σ²) observed in the last batch
    /// (inferred from the name — confirm against the estimator).
    pub last_probe_sigma_sq: Option<f64>,
    /// Presumably the number of probes used in the last batch (inferred from
    /// the name — confirm against the estimator).
    pub last_probe_count: usize,
}
22
23/// Abstract interface for Hessian linear algebra operations.
24///
25/// All operations use the SAME internal decomposition, ensuring spectral
26/// consistency between logdet (used in cost) and trace/solve (used in gradient).
27///
28/// Implementors:
29/// - `DenseSpectralOperator`: eigendecomposition of dense H
30/// - Sparse Cholesky operators (external implementations)
31/// - `BlockCoupledOperator`: eigendecomposition of joint multi-block H
32/// Minimum operator dimension at which the Hutch++ stochastic trace estimator is
33/// preferred over materializing an implicit operator densely. Below this, the
34/// `2·m_s + m_h` Hutch++ matvecs do not beat `dim` dense H⁻¹ HVPs, so the dense
35/// fallback is cheaper.
36pub(crate) const HUTCHPP_TRACE_MIN_DIM: usize = 128;
37
38/// Build the Hutch++ stochastic-trace configuration for an operator of the given
39/// dimension. The sketch dimension grows with `dim` (one column per 32 of
40/// dimension, bounded to `[4, 16]`), and the probe budget tracks the sketch so
41/// the estimator's variance and cost stay balanced across problem sizes. Shared
42/// by every implicit-operator trace path so they cannot drift apart.
43pub(crate) fn hutchpp_config_for_dim(dim: usize) -> StochasticTraceConfig {
44    const SKETCH_DIM_PER: usize = 32;
45    const SKETCH_DIM_MIN: usize = 4;
46    const SKETCH_DIM_MAX: usize = 16;
47    const PROBES_PER_SKETCH: usize = 4;
48    const PROBES_MAX_FLOOR: usize = 32;
49    const PROBES_MIN_FLOOR: usize = 8;
50    let sketch = (dim / SKETCH_DIM_PER).clamp(SKETCH_DIM_MIN, SKETCH_DIM_MAX);
51    let mut config = StochasticTraceConfig::default();
52    config.hutchpp_sketch_dim = Some(sketch);
53    config.n_probes_max = (sketch * PROBES_PER_SKETCH).max(PROBES_MAX_FLOOR);
54    config.n_probes_min = sketch.max(PROBES_MIN_FLOOR);
55    config
56}
57
58pub trait HessianOperator: Send + Sync {
59    /// log|H|₊ — pseudo-logdet using only active eigenvalues/pivots.
60    fn logdet(&self) -> f64;
61
62    /// tr(H₊⁻¹ A) — trace of pseudo-inverse times a symmetric matrix.
63    /// Uses the SAME decomposition as `logdet`.
64    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64;
65
66    /// Exact dense spectral representation, when this backend has one.
67    ///
68    /// Outer-Hessian assembly uses this to batch all logdet-Hessian cross
69    /// traces in the eigenbasis. For CTN scale-dimension fits this avoids
70    /// projecting the same implicit ψ drift once per upper-triangular pair.
71    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
72        None
73    }
74
75    /// Assemble the raw dense Hessian represented by this backend for
76    /// active-constraint tangent projection.
77    ///
78    /// Backends that do not store either a dense spectral decomposition or an
79    /// explicit factorization should keep the default error.
80    fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
81        Err("backend does not support tangent projection".to_string())
82    }
83
84    /// tr(H₊⁻¹ B) for an operator-backed Hessian drift.
85    ///
86    /// Default implementation materializes `B` densely. Backends with
87    /// native operator traces (notably sparse Cholesky) should override it.
88    ///
89    /// For HVP-only (implicit) operators on large problems we route
90    /// through Hutch++ — the Meyer–Musco split estimator achieves O(1/ε)
91    /// matvecs vs O(1/ε²) for plain Hutchinson, and avoids the O(p²)
92    /// memory + O(p) HVP cost of materializing the operator densely.
93    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
94        // Hutch++ fast path for the warn-and-materialize default. Only
95        // backends that fall through to this default reach here;
96        // backends with native operator traces override it. We require
97        // an implicit operator (so materialization is expensive) and a
98        // moderately-large dim (so 2 m_s + m_h matvecs beats `dim`
99        // dense HVPs).
100        if op.is_implicit() && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
101            let config = hutchpp_config_for_dim(self.dim());
102            return hutchpp_estimate_trace_hinv_operator(self, op, &config);
103        }
104        if op.is_implicit() {
105            log::warn!(
106                "trace_hinv_operator: materializing implicit HyperOperator — \
107                 backend should provide a matrix-free override"
108            );
109        }
110        self.trace_hinv_product(&op.to_dense())
111    }
112
113    /// H⁻¹ v — linear solve using the active decomposition.
114    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64>;
115
116    /// H⁻¹ M — multi-column solve.
117    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64>;
118
119    /// H⁻¹ v for stochastic trace probes.
120    ///
121    /// Exact backends use the normal solve. Matrix-free backends may override
122    /// this to use a looser PCG tolerance when the caller's Monte Carlo error
123    /// dominates the linear-solve error.
124    fn stochastic_trace_solve(&self, rhs: &Array1<f64>, rel_tol: f64) -> Array1<f64> {
125        assert!(
126            rel_tol.is_finite() && rel_tol > 0.0,
127            "stochastic trace solve tolerance must be positive and finite"
128        );
129        self.solve(rhs)
130    }
131
132    /// H⁻¹ v for a deterministic stochastic trace probe id.
133    ///
134    /// Backends with matrix-free CG may use `probe_id` to warm-start from the
135    /// previous solve of the same CRN probe. The default exact backend ignores
136    /// the id and uses the normal stochastic trace solve.
137    fn stochastic_trace_solve_for_probe(
138        &self,
139        rhs: &Array1<f64>,
140        rel_tol: f64,
141        probe_id: u64,
142        state: Option<&Arc<Mutex<StochasticTraceState>>>,
143    ) -> Array1<f64> {
144        // Default exact backend has no matrix-free CG, so per-probe warm
145        // starts are inapplicable. If a previous matrix-free backend left
146        // a warm-start vector for this `probe_id` in the shared state,
147        // drop it so a later matrix-free run does not consume a vector
148        // that was generated against a different operator factorization.
149        if let Some(state_arc) = state
150            && let Ok(mut guard) = state_arc.lock()
151        {
152            guard.cg_warm_starts.remove(&probe_id);
153        }
154        self.stochastic_trace_solve(rhs, rel_tol)
155    }
156
157    /// H⁻¹ M for stochastic trace probes.
158    fn stochastic_trace_solve_multi(&self, rhs: &Array2<f64>, rel_tol: f64) -> Array2<f64> {
159        assert!(
160            rel_tol.is_finite() && rel_tol > 0.0,
161            "stochastic trace multi-solve tolerance must be positive and finite"
162        );
163        self.solve_multi(rhs)
164    }
165
166    /// Whether this backend exposes a matrix-free operator usable by trace CG.
167    fn has_matrix_free_trace_cg_operator(&self) -> bool {
168        false
169    }
170
171    /// tr(H⁻¹ A H⁻¹ B) for dense symmetric Hessian drifts.
172    ///
173    /// This is the second-order trace object used by EFS denominators and the
174    /// ψ-block trace Gram preconditioner. The default implementation computes
175    /// both solved column stacks exactly and contracts them as
176    /// `tr((H⁻¹A)(H⁻¹B))`.
177    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
178        let solved_a = self.solve_multi(a);
179        if std::ptr::eq(a, b) {
180            return trace_matrix_product(&solved_a, &solved_a);
181        }
182        let solved_b = self.solve_multi(b);
183        trace_matrix_product(&solved_a, &solved_b)
184    }
185
186    /// tr(H⁻¹ A H⁻¹ B) for a dense drift `A` and an operator-backed drift `B`.
187    ///
188    /// Default implementation materializes the operator and dispatches to the
189    /// dense cross-trace path. Matrix-free and sparse backends should override
190    /// this to avoid dense operator materialization.
191    fn trace_hinv_matrix_operator_cross(
192        &self,
193        matrix: &Array2<f64>,
194        op: &dyn HyperOperator,
195    ) -> f64 {
196        if op.is_implicit() && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
197            let config = hutchpp_config_for_dim(self.dim());
198            // Wrap the dense LHS in a matrix-backed HyperOperator so the
199            // shared cross routine can call mul_vec_into on it.
200            let lhs = DenseMatrixHyperOperator {
201                matrix: matrix.clone(),
202            };
203            return hutchpp_estimate_trace_hinv_operator_cross(self, &lhs, op, &config);
204        }
205        if op.is_implicit() {
206            log::warn!(
207                "trace_hinv_matrix_operator_cross: materializing implicit HyperOperator — \
208                 backend should provide a matrix-free override"
209            );
210        }
211        self.trace_hinv_product_cross(matrix, &op.to_dense())
212    }
213
214    /// tr(H⁻¹ A H⁻¹ B) for operator-backed Hessian drifts.
215    ///
216    /// Default implementation materializes both operators densely. Backends
217    /// with native operator-aware cross traces should override this.
218    fn trace_hinv_operator_cross(
219        &self,
220        left: &dyn HyperOperator,
221        right: &dyn HyperOperator,
222    ) -> f64 {
223        let l_implicit = left.is_implicit();
224        let r_implicit = right.is_implicit();
225        if (l_implicit || r_implicit) && self.dim() >= HUTCHPP_TRACE_MIN_DIM {
226            let config = hutchpp_config_for_dim(self.dim());
227            // Same-operator self-cross is PSD; the squared form is the
228            // exact algorithm for that case (lower variance, no sign).
229            if std::ptr::eq(
230                left as *const dyn HyperOperator as *const (),
231                right as *const dyn HyperOperator as *const (),
232            ) {
233                return hutchpp_estimate_trace_hinv_op_squared(self, left, &config);
234            }
235            return hutchpp_estimate_trace_hinv_operator_cross(self, left, right, &config);
236        }
237        if l_implicit || r_implicit {
238            log::warn!(
239                "trace_hinv_operator_cross: materializing implicit HyperOperator(s) — \
240                 backend should provide a matrix-free override"
241            );
242        }
243        self.trace_hinv_product_cross(&left.to_dense(), &right.to_dense())
244    }
245
246    /// tr(G_ε(H) A) — trace for the logdet gradient ∂_i log|R_ε(H)|.
247    ///
248    /// For non-spectral backends (Cholesky), G_ε = H⁻¹ and this reduces to
249    /// `trace_hinv_product`. For spectral regularization, G_ε uses eigenvalues
250    /// `φ'(σ_a) = 1/√(σ_a² + 4ε²)` instead of `1/r_ε(σ_a)`.
251    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
252        self.trace_hinv_product(a)
253    }
254
255    /// diag(X · G_ε(H) · Xᵀ) — the leverage corresponding to `trace_logdet_gradient`.
256    /// `trace_logdet_gradient(Xᵀ diag(w) X) = Σᵢ wᵢ · h^G[i]`.
257    ///
258    /// Streams the rows of `X` through the design's `try_row_chunk` so
259    /// operator-backed (Lazy) designs never materialize the full (n×p)
260    /// block at large scale.
261    fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
262        assert!(self.logdet_traces_match_hinv_kernel());
263        let n = x.nrows();
264        let p = x.ncols();
265
266        let block = {
267            const TARGET_CHUNK_FLOATS: usize = 1 << 16;
268            (TARGET_CHUNK_FLOATS / p.max(1)).clamp(1, n.max(1))
269        };
270
271        let mut h = Array1::<f64>::zeros(n);
272        let mut start = 0usize;
273        while start < n {
274            let end = (start + block).min(n);
275            let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
276                // SAFETY: `try_row_chunk` only fails on operator implementation
277                // bugs — the `start..end` range is constructed from
278                // `0..n = 0..x.nrows()` with `end = (start+block).min(n)`,
279                // so it is always a valid sub-range of `x`. A failure here
280                // means the operator violated its row-chunk contract.
281                // SAFETY: row range built from 0..x.nrows(); failure means operator broke its contract.
282                reml_contract_panic(format!(
283                    "xt_logdet_kernel_x_diagonal: row chunk failed: {err}"
284                ))
285            });
286            let chunk_t = rows.t().to_owned();
287            let z_chunk = self.solve_multi(&chunk_t);
288            for (i, (row, z_col)) in rows
289                .outer_iter()
290                .zip(z_chunk.columns().into_iter())
291                .enumerate()
292            {
293                let mut acc = 0.0;
294                for (row_value, z_value) in row.iter().copied().zip(z_col.iter().copied()) {
295                    acc += row_value * z_value;
296                }
297                h[start + i] = acc;
298            }
299            start = end;
300        }
301        h
302    }
303
304    /// tr(G_ε(H) B) for an operator-backed Hessian drift.
305    ///
306    /// Default implementation materializes `B` densely. For Cholesky-based
307    /// backends this equals `trace_hinv_operator`.
308    ///
309    /// When `logdet_traces_match_hinv_kernel()` is true (Cholesky-style
310    /// backends where `trace_logdet_gradient(A) = trace_hinv_product(A)`)
311    /// and the operator is implicit on a moderate-or-large problem, route
312    /// through Hutch++ to avoid the dense materialization. Spectral
313    /// backends override this to false (their logdet trace uses
314    /// regularized eigenvalue weights, not `H⁻¹`), so they keep the
315    /// materialize path or provide their own override.
316    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
317        if op.is_implicit()
318            && self.dim() >= HUTCHPP_TRACE_MIN_DIM
319            && self.logdet_traces_match_hinv_kernel()
320        {
321            let config = hutchpp_config_for_dim(self.dim());
322            return hutchpp_estimate_trace_hinv_operator(self, op, &config);
323        }
324        if op.is_implicit() {
325            log::warn!(
326                "trace_logdet_operator: materializing implicit HyperOperator — \
327                 backend should provide a matrix-free override"
328            );
329        }
330        self.trace_logdet_gradient(&op.to_dense())
331    }
332
333    /// Efficient computation of tr(G_ε(H) Hₖ) for the logdet gradient.
334    ///
335    /// Default implementation: forms the correction and calls `trace_logdet_gradient`.
336    fn trace_logdet_h_k(
337        &self,
338        a_k: &Array2<f64>,
339        third_deriv_correction: Option<&Array2<f64>>,
340    ) -> f64 {
341        let base = self.trace_logdet_gradient(a_k);
342        match third_deriv_correction {
343            Some(c) => base + self.trace_logdet_gradient(c),
344            None => base,
345        }
346    }
347
348    /// tr(G_ε(H) · A_block) where A_block is a p_block × p_block matrix
349    /// embedded at rows/columns [start..end].
350    ///
351    /// This avoids materializing the full p×p matrix for block-structured
352    /// penalties. The default implementation builds the full matrix and
353    /// delegates to `trace_logdet_gradient`; spectral backends override
354    /// this with O(p_block × active_rank) work.
355    fn trace_logdet_block_local(
356        &self,
357        block: &Array2<f64>,
358        scale: f64,
359        start: usize,
360        end: usize,
361    ) -> f64 {
362        let p = self.dim();
363        let mut full = Array2::<f64>::zeros((p, p));
364        let bs = end - start;
365        for i in 0..bs {
366            for j in 0..bs {
367                full[[start + i, start + j]] = scale * block[[i, j]];
368            }
369        }
370        self.trace_logdet_gradient(&full)
371    }
372
373    /// Cross-trace for the logdet Hessian:
374    /// `∂²_{ij} log|R_ε(H)| = tr(G_ε Ḧ_{ij}) + spectral_cross(Ḣ_i, Ḣ_j)`.
375    ///
376    /// This method computes the `spectral_cross(Ḣ_i, Ḣ_j)` part, which for
377    /// non-spectral backends equals `-tr(H⁻¹ Ḣ_j H⁻¹ Ḣ_i)`.
378    ///
379    /// For spectral regularization, the divided-difference kernel Γ_{ab} replaces
380    /// the simple product of inverses.
381    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
382        // Default: standard formula -tr(H⁻¹ Ḣ_j H⁻¹ Ḣ_i) = -⟨Y_j^T, Y_i⟩_F
383        // where Y_i = H⁻¹ Ḣ_i.
384        let y_i = self.solve_multi(h_i);
385        if std::ptr::eq(h_i, h_j) {
386            return -trace_matrix_product(&y_i, &y_i);
387        }
388        let y_j = self.solve_multi(h_j);
389        -trace_matrix_product(&y_j, &y_i)
390    }
391
392    /// Operator-backed mixed form of [`trace_logdet_hessian_cross`].
393    ///
394    /// The default materializes the operator; spectral and sparse backends
395    /// override this to keep the exact analytic cross trace matrix-free.
396    fn trace_logdet_hessian_cross_matrix_operator(
397        &self,
398        h_i: &Array2<f64>,
399        h_j: &dyn HyperOperator,
400    ) -> f64 {
401        self.trace_logdet_hessian_cross(h_i, &h_j.to_dense())
402    }
403
404    /// Operator-backed form of [`trace_logdet_hessian_cross`].
405    ///
406    /// The default materializes both operators; exact backends override this
407    /// when they can contract the logdet-Hessian kernel against operator
408    /// projections directly.
409    fn trace_logdet_hessian_cross_operator(
410        &self,
411        h_i: &dyn HyperOperator,
412        h_j: &dyn HyperOperator,
413    ) -> f64 {
414        self.trace_logdet_hessian_cross(&h_i.to_dense(), &h_j.to_dense())
415    }
416
417    /// Number of active dimensions (rank of pseudo-inverse).
418    fn active_rank(&self) -> usize;
419
420    /// Full dimension of H.
421    fn dim(&self) -> usize;
422
423    /// Whether this operator is backed by a dense factorization.
424    ///
425    /// Dense operators (eigendecomposition) have O(p²) trace cost per matrix,
426    /// making stochastic trace estimation worthwhile for large p.  Sparse
427    /// operators (Cholesky) have O(nnz) solve cost, so exact column-by-column
428    /// traces are already cheap and stochastic estimation is not needed.
429    fn is_dense(&self) -> bool {
430        false
431    }
432
433    /// Whether the unified evaluator should batch large trace computations
434    /// through the stochastic Hutchinson path for this operator.
435    ///
436    /// Dense eigendecomposition backends prefer this once `p` is large because
437    /// exact per-coordinate traces are O(p²). Matrix-free iterative backends
438    /// have the same preference even though they do not store a dense factor.
439    fn prefers_stochastic_trace_estimation(&self) -> bool {
440        self.is_dense()
441    }
442
443    /// Whether stochastic Hutchinson estimates based on `H⁻¹` are valid for
444    /// logdet-gradient / logdet-Hessian trace terms on this backend.
445    ///
446    /// This is true for plain SPD-logdet operators where
447    /// `trace_logdet_gradient(A) = tr(H⁻¹ A)` and
448    /// `trace_logdet_hessian_cross(A, B) = -tr(H⁻¹ A H⁻¹ B)`.
449    ///
450    /// Smooth spectral regularization does not satisfy those identities, so
451    /// dense spectral backends must override this to `false`.
452    fn logdet_traces_match_hinv_kernel(&self) -> bool {
453        true
454    }
455
456    /// Access the dense spectral backend when this operator is powered by a
457    /// single eigendecomposition.
458    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
459        None
460    }
461}
462
463/// Representative curvature scale for a Hessian operator.
464///
465/// Returns the geometric mean of the active Hessian eigenvalues,
466/// `exp(log|H|_+ / rank(H))`. This has the same physical units as a Hessian
467/// diagonal entry but is basis-invariant, cheap after the operator has computed
468/// its log-determinant, and well-defined for both dense spectral and
469/// matrix-free operator paths.
470pub fn hessian_operator_geometric_scale(op: &dyn HessianOperator) -> Option<f64> {
471    let rank = op.active_rank();
472    if rank == 0 {
473        return None;
474    }
475    let logdet = op.logdet();
476    if !logdet.is_finite() {
477        return None;
478    }
479    let scale = (logdet / rank as f64).exp();
480    if scale.is_finite() && scale > 0.0 {
481        Some(scale)
482    } else {
483        None
484    }
485}