Skip to main content

gam_solve/reml/reml_outer_engine/
sparse_cholesky_backends.rs

1use super::*;
2
3// ═══════════════════════════════════════════════════════════════════════════
4//  Sparse Cholesky HessianOperator implementation
5// ═══════════════════════════════════════════════════════════════════════════
6
/// Sparse Cholesky Hessian operator.
///
/// Wraps an existing `SparseExactFactor` and provides logdet, trace, and solve
/// from the same Cholesky factorization.
///
/// An optional Takahashi selected inverse can be attached with
/// `with_takahashi`; when it is present, some trace paths read `H^{-1}`
/// entries from it instead of running column solves against `factor`.
pub struct SparseCholeskyOperator {
    /// The sparse Cholesky factorization, held behind an `Arc` so the handle
    /// can be shared with the code that produced it.
    pub(crate) factor: std::sync::Arc<gam_linalg::sparse_exact::SparseExactFactor>,
    /// Takahashi selected inverse (precomputed H^{-1} entries on the filled pattern of L).
    /// When available, trace computations use direct lookups instead of column solves.
    /// `None` until `with_takahashi` is called.
    pub(crate) takahashi: Option<std::sync::Arc<gam_linalg::sparse_exact::TakahashiInverse>>,
    /// Precomputed log-determinant from the Cholesky diagonal.
    /// Supplied by the caller of `new` and returned unchanged by `logdet()`.
    pub(crate) cached_logdet: f64,
    /// Dimension of H. Also used as the row count of every right-hand-side
    /// scratch buffer allocated by the chunked solves.
    pub(crate) n_dim: usize,
}
22
23impl SparseCholeskyOperator {
    /// Create from an existing sparse factorization and its precomputed logdet.
    ///
    /// * `factor` — shared handle to the sparse factorization used by every solve.
    /// * `logdet_h` — log-determinant value; stored as-is in `cached_logdet`.
    /// * `dim` — dimension of `H`; stored as `n_dim`.
    ///
    /// The Takahashi selected inverse starts out absent (`None`); attach one
    /// with `with_takahashi`.
    pub fn new(
        factor: std::sync::Arc<gam_linalg::sparse_exact::SparseExactFactor>,
        logdet_h: f64,
        dim: usize,
    ) -> Self {
        Self {
            factor,
            takahashi: None,
            cached_logdet: logdet_h,
            n_dim: dim,
        }
    }
37
    /// Builder-style setter: attach a Takahashi selected inverse and return
    /// the operator by value.
    ///
    /// Any previously attached selected inverse is replaced.
    pub fn with_takahashi(
        mut self,
        taka: std::sync::Arc<gam_linalg::sparse_exact::TakahashiInverse>,
    ) -> Self {
        self.takahashi = Some(taka);
        self
    }
45
    /// Upper bound on the number of right-hand-side columns solved per batch
    /// by the chunked operator solves in this impl. Each user additionally
    /// caps the batch width at `n_dim` (and at least 1).
    pub(crate) const OPERATOR_SOLVE_CHUNK: usize = 64;
47
48    pub(crate) fn takahashi_block_trace(
49        taka: &gam_linalg::sparse_exact::TakahashiInverse,
50        block: &Array2<f64>,
51        start: usize,
52    ) -> f64 {
53        assert_eq!(block.nrows(), block.ncols());
54        let mut trace = 0.0;
55        for i in 0..block.nrows() {
56            let diag = block[[i, i]];
57            if diag.abs() > 1e-30 {
58                trace += taka.get(start + i, start + i) * diag;
59            }
60            for j in (i + 1)..block.ncols() {
61                let pair = block[[i, j]] + block[[j, i]];
62                if pair.abs() > 1e-30 {
63                    trace += taka.get(start + i, start + j) * pair;
64                }
65            }
66        }
67        trace
68    }
69
70    pub(crate) fn takahashi_left_multiply_block(
71        taka: &gam_linalg::sparse_exact::TakahashiInverse,
72        block: &Array2<f64>,
73        start: usize,
74    ) -> Array2<f64> {
75        let dim = block.nrows();
76        let mut out = Array2::<f64>::zeros((dim, dim));
77        for i in 0..dim {
78            let z_diag = taka.get(start + i, start + i);
79            if z_diag.abs() > 1e-30 {
80                for k in 0..dim {
81                    out[[i, k]] += z_diag * block[[i, k]];
82                }
83            }
84            for j in (i + 1)..dim {
85                let z = taka.get(start + i, start + j);
86                if z.abs() <= 1e-30 {
87                    continue;
88                }
89                for k in 0..dim {
90                    out[[i, k]] += z * block[[j, k]];
91                    out[[j, k]] += z * block[[i, k]];
92                }
93            }
94        }
95        out
96    }
97
98    pub(crate) fn trace_hinv_operator_exact(&self, op: &dyn HyperOperator) -> f64 {
99        let (range_start, range_end) = op
100            .block_local_data()
101            .map(|(_, start, end)| (start, end))
102            .unwrap_or((0, self.n_dim));
103        let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
104        let mut trace = 0.0_f64;
105        let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
106        let mut start = range_start;
107
108        while start < range_end {
109            let end = (start + chunk).min(range_end);
110            let cols = end - start;
111            op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
112
113            let diagonal_sum = if cols == chunk {
114                gam_linalg::sparse_exact::solve_sparse_spdmulti_diagonal_sum(
115                    &self.factor,
116                    &rhs_block,
117                    start,
118                )
119            } else {
120                let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
121                gam_linalg::sparse_exact::solve_sparse_spdmulti_diagonal_sum(
122                    &self.factor,
123                    &rhs_view,
124                    start,
125                )
126            };
127            trace += diagonal_sum.unwrap_or_else(|e| {
128                // SAFETY: `SparseCholeskyOperator` is constructed only with a
129                // successfully-factorized SPD `self.factor`. The sparse SPD
130                // multi-RHS solve only fails on factor corruption or RHS
131                // shape mismatch; the RHS comes from `mul_basis_columns_into`
132                // matching the factor's dimension, so failure here means
133                // the cached factor was corrupted after construction —
134                // a hard invariant violation.
135                // SAFETY: self.factor is validated SPD; sparse-SPD solve only fails on factor corruption.
136                reml_contract_panic(format!(
137                    "SparseCholeskyOperator exact trace_hinv_operator solve failed: {e}"
138                ))
139            });
140            start = end;
141        }
142
143        trace
144    }
145
146    pub(crate) fn solve_operator_column_range_rows_exact(
147        &self,
148        op: &dyn HyperOperator,
149        col_start: usize,
150        col_end: usize,
151        row_start: usize,
152        row_end: usize,
153    ) -> Result<Array2<f64>, String> {
154        let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
155        let cols_total = col_end - col_start;
156        let rows_total = row_end - row_start;
157        let mut solved = Array2::<f64>::zeros((rows_total, cols_total));
158        let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
159        let mut start = col_start;
160
161        while start < col_end {
162            let end = (start + chunk).min(col_end);
163            let cols = end - start;
164            op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
165
166            let solved_block = if cols == chunk {
167                gam_linalg::sparse_exact::solve_sparse_spdmulti_rows(
168                    &self.factor,
169                    &rhs_block,
170                    row_start,
171                    row_end,
172                )
173            } else {
174                let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
175                gam_linalg::sparse_exact::solve_sparse_spdmulti_rows(
176                    &self.factor,
177                    &rhs_view,
178                    row_start,
179                    row_end,
180                )
181            }
182            .map_err(|e| {
183                format!(
184                    "SparseCholeskyOperator::solve_operator_column_range_rows_exact multi-solve failed: {e}"
185                )
186            })?;
187            solved
188                .slice_mut(ndarray::s![.., start - col_start..end - col_start])
189                .assign(&solved_block);
190            start = end;
191        }
192
193        Ok(solved)
194    }
195
196    pub(crate) fn trace_hinv_matrix_operator_cross_exact(
197        &self,
198        matrix: &Array2<f64>,
199        op: &dyn HyperOperator,
200    ) -> f64 {
201        if let Some((_, range_start, range_end)) = op.block_local_data()
202            && range_end - range_start < self.n_dim
203        {
204            return self.trace_hinv_matrix_block_operator_cross_exact(
205                matrix,
206                op,
207                range_start,
208                range_end,
209            );
210        }
211
212        let solved_matrix = self.solve_multi(matrix);
213        let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
214        let mut rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
215        let mut trace = 0.0_f64;
216        let (range_start, range_end) = op
217            .block_local_data()
218            .map(|(_, start, end)| (start, end))
219            .unwrap_or((0, self.n_dim));
220        let mut start = range_start;
221
222        while start < range_end {
223            let end = (start + chunk).min(range_end);
224            let cols = end - start;
225            op.mul_basis_columns_into(start, rhs_block.slice_mut(ndarray::s![.., ..cols]));
226
227            let solved_op = if cols == chunk {
228                gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_block)
229            } else {
230                let rhs_view = rhs_block.slice(ndarray::s![.., ..cols]);
231                gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
232            };
233
234            let solved_op = solved_op.unwrap_or_else(|e| {
235                // SAFETY: `self.factor` is the validated SPD Cholesky factor
236                // (set only after successful factorization); the RHS shape
237                // is `n_dim × cols` by construction. A sparse-SPD multi-RHS
238                // failure here would mean factor corruption, which the
239                // construction invariant forbids.
240                // SAFETY: self.factor is validated SPD; matrix/operator multi-solve only fails on corruption.
241                panic!("SparseCholeskyOperator exact matrix/operator cross solve failed: {e}")
242            });
243
244            for local_col in 0..cols {
245                let matrix_row = start + local_col;
246                for row in 0..self.n_dim {
247                    trace += solved_matrix[[matrix_row, row]] * solved_op[[row, local_col]];
248                }
249            }
250            start = end;
251        }
252
253        trace
254    }
255
256    pub(crate) fn trace_hinv_matrix_block_operator_cross_exact(
257        &self,
258        matrix: &Array2<f64>,
259        op: &dyn HyperOperator,
260        range_start: usize,
261        range_end: usize,
262    ) -> f64 {
263        let t_start = std::time::Instant::now();
264        let chunk = Self::OPERATOR_SOLVE_CHUNK.min(self.n_dim.max(1));
265        let mut op_rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
266        let mut eye_rhs_block = Array2::<f64>::zeros((self.n_dim, chunk));
267        let mut trace = 0.0_f64;
268        let mut start = range_start;
269
270        while start < range_end {
271            let end = (start + chunk).min(range_end);
272            let cols = end - start;
273            op.mul_basis_columns_into(start, op_rhs_block.slice_mut(ndarray::s![.., ..cols]));
274
275            eye_rhs_block.fill(0.0);
276            for local_col in 0..cols {
277                eye_rhs_block[[start + local_col, local_col]] = 1.0;
278            }
279
280            let solved_op = if cols == chunk {
281                gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &op_rhs_block)
282            } else {
283                let rhs_view = op_rhs_block.slice(ndarray::s![.., ..cols]);
284                gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
285            };
286            let solved_op = solved_op.unwrap_or_else(|e| {
287                // SAFETY: same invariant — `self.factor` is the validated
288                // SPD factor and `op_rhs_block` is allocated as
289                // `n_dim × chunk`, so dimensions are compatible by
290                // construction. Any failure indicates factor corruption.
291                // SAFETY: self.factor is validated SPD; block-operator multi-solve only fails on corruption.
292                panic!(
293                    "SparseCholeskyOperator exact matrix/block-operator cross operator solve failed: {e}"
294                )
295            });
296
297            let solved_eye = if cols == chunk {
298                gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &eye_rhs_block)
299            } else {
300                let rhs_view = eye_rhs_block.slice(ndarray::s![.., ..cols]);
301                gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, &rhs_view)
302            };
303            let solved_eye = solved_eye.unwrap_or_else(|e| {
304                // SAFETY: same invariant — `self.factor` is validated SPD
305                // and `eye_rhs_block` was just filled as an identity-block
306                // RHS sized `n_dim × chunk`. Failure indicates factor
307                // corruption, forbidden by the construction invariant.
308                // SAFETY: self.factor is validated SPD; identity-RHS multi-solve only fails on corruption.
309                panic!(
310                    "SparseCholeskyOperator exact matrix/block-operator cross identity solve failed: {e}"
311                )
312            });
313
314            let selected_rows_t = matrix.t().dot(&solved_eye);
315            for local_col in 0..cols {
316                for row in 0..self.n_dim {
317                    trace += selected_rows_t[[row, local_col]] * solved_op[[row, local_col]];
318                }
319            }
320            start = end;
321        }
322
323        let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
324        if elapsed_ms > REML_TRACE_SLOW_LOG_MS {
325            log::info!(
326                "[REML-trace] matrix_block_op_cross_exact | n_dim={} | block={} | {:.1}ms",
327                self.n_dim,
328                range_end - range_start,
329                elapsed_ms
330            );
331        }
332        trace
333    }
334
335    pub(crate) fn trace_hinv_operator_cross_exact(
336        &self,
337        left: &dyn HyperOperator,
338        right: &dyn HyperOperator,
339    ) -> f64 {
340        let (left_start, left_end) = left
341            .block_local_data()
342            .map(|(_, start, end)| (start, end))
343            .unwrap_or((0, self.n_dim));
344        let (right_start, right_end) = right
345            .block_local_data()
346            .map(|(_, start, end)| (start, end))
347            .unwrap_or((0, self.n_dim));
348
349        let solved_left = self
350            .solve_operator_column_range_rows_exact(
351                left,
352                left_start,
353                left_end,
354                right_start,
355                right_end,
356            )
357            .unwrap_or_else(|e| {
358                // SAFETY: `solve_operator_column_range_rows_exact` only
359                // forwards `solve_sparse_spdmulti` errors. `self.factor` is
360                // the validated SPD Cholesky factor; column ranges come
361                // from the operator's own `block_local_data` (or fall back
362                // to `0..n_dim`), so failure indicates factor corruption.
363                // SAFETY: self.factor is validated SPD; operator cross-left solve only fails on corruption.
364                panic!("SparseCholeskyOperator exact operator cross left solve failed: {e}")
365            });
366        let same_operator =
367            std::ptr::addr_eq(left, right) && left_start == right_start && left_end == right_end;
368        let solved_right = if same_operator {
369            None
370        } else {
371            Some(
372                self.solve_operator_column_range_rows_exact(
373                    right,
374                    right_start,
375                    right_end,
376                    left_start,
377                    left_end,
378                )
379                .unwrap_or_else(|e| {
380                    // SAFETY: mirrors the left-solve invariant above —
381                    // `self.factor` is validated SPD and the column range
382                    // is taken from `right`'s own `block_local_data`,
383                    // so failure indicates factor corruption.
384                    // SAFETY: self.factor is validated SPD; operator cross-right solve only fails on corruption.
385                    panic!("SparseCholeskyOperator exact operator cross right solve failed: {e}")
386                }),
387            )
388        };
389
390        let right_cols = right_end - right_start;
391        let mut trace = 0.0;
392        for left_col in 0..(left_end - left_start) {
393            for right_col in 0..right_cols {
394                let right_value = match solved_right.as_ref() {
395                    Some(solved) => solved[[left_col, right_col]],
396                    None => solved_left[[left_col, right_col]],
397                };
398                trace += solved_left[[right_col, left_col]] * right_value;
399            }
400        }
401        trace
402    }
403}
404
405impl HessianOperator for SparseCholeskyOperator {
    /// Return the log-determinant cached at construction; nothing is recomputed.
    fn logdet(&self) -> f64 {
        self.cached_logdet
    }
409
410    fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
411        let h = gam_linalg::sparse_exact::assemble_sparse_factor_h_dense(&self.factor)
412            .map_err(|e| e.to_string())?;
413        if h.nrows() != self.n_dim || h.ncols() != self.n_dim {
414            return Err(format!(
415                "sparse Cholesky tangent projection dense H has shape {}x{}, expected {}x{}",
416                h.nrows(),
417                h.ncols(),
418                self.n_dim,
419                self.n_dim
420            ));
421        }
422        Ok(h)
423    }
424
425    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
426        // When Takahashi is available, use direct entry lookup for tr(H^{-1} A).
427        // This is O(p^2) via dense A iteration but avoids p column solves.
428        if let Some(ref taka) = self.takahashi {
429            let mut trace = 0.0;
430            for i in 0..a.nrows() {
431                let a_ii = a[[i, i]];
432                if a_ii.abs() > 1e-30 {
433                    trace += taka.get(i, i) * a_ii;
434                }
435                for j in (i + 1)..a.ncols() {
436                    let pair = a[[i, j]] + a[[j, i]];
437                    if pair.abs() > 1e-30 {
438                        trace += taka.get(i, j) * pair;
439                    }
440                }
441            }
442            return trace;
443        }
444        gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, a)
445            .unwrap_or_else(|e| {
446                // SAFETY: `self.factor` is the validated SPD Cholesky factor
447                // (created by `SparseCholeskyOperator::new` only after a
448                // successful factorization); a single-square multi-RHS solve
449                // here can only fail on factor corruption, which the
450                // construction invariant forbids.
451                // SAFETY: self.factor is validated SPD; single-square multi-solve only fails on corruption.
452                panic!("SparseCholeskyOperator exact trace_hinv_product solve failed: {e}")
453            })
454            .diag()
455            .sum()
456    }
457
458    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
459        if let Some(ref taka) = self.takahashi {
460            if let Some((local, start, end)) = op.block_local_data() {
461                assert_eq!(local.nrows(), end - start);
462                return Self::takahashi_block_trace(taka, local, start);
463            }
464            // For other non-implicit operators: materialize and use Takahashi lookups
465            if !op.is_implicit() {
466                let dense = op.to_dense();
467                return self.trace_hinv_product(&dense);
468            }
469        }
470        self.trace_hinv_operator_exact(op)
471    }
472
    /// Log-determinant trace for an operator; for this backend it is exactly
    /// `trace_hinv_operator(op)`.
    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
        self.trace_hinv_operator(op)
    }
476
477    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
478        // SAFETY: `self.factor` is the validated SPD Cholesky factor stored
479        // at construction time; a triangular solve against an already-built
480        // factor can only fail on factor corruption, which the
481        // `SparseCholeskyOperator` construction invariant forbids.
482        gam_linalg::sparse_exact::solve_sparse_spd(&self.factor, rhs)
483            // SAFETY: self.factor is validated SPD; triangular solve only fails on corruption.
484            .unwrap_or_else(|e| panic!("SparseCholeskyOperator exact solve failed: {e}"))
485    }
486
487    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
488        // SAFETY: same SPD-factor invariant as `solve` above — `self.factor`
489        // was created from a successful Cholesky factorization, so a
490        // multi-RHS solve can only fail on factor corruption.
491        gam_linalg::sparse_exact::solve_sparse_spdmulti(&self.factor, rhs)
492            // SAFETY: self.factor is validated SPD; multi-RHS solve only fails on corruption.
493            .unwrap_or_else(|e| panic!("SparseCholeskyOperator exact multi-solve failed: {e}"))
494    }
495
496    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
497        // For general dense matrices, column solves are better than materializing
498        // full Z from Takahashi (O(p * nnz) vs O(p³)). Takahashi cross-traces
499        // are only used for block-local operators via trace_hinv_operator_cross.
500        let solved_a = self.solve_multi(a);
501        if std::ptr::eq(a, b) {
502            return trace_matrix_product(&solved_a, &solved_a);
503        }
504        let solved_b = self.solve_multi(b);
505        trace_matrix_product(&solved_a, &solved_b)
506    }
507
    /// Cross trace of a dense matrix with an operator; always forwards to the
    /// exact solve-based implementation, regardless of whether a Takahashi
    /// inverse is attached.
    fn trace_hinv_matrix_operator_cross(
        &self,
        matrix: &Array2<f64>,
        op: &dyn HyperOperator,
    ) -> f64 {
        // For mixed dense-matrix × block-local-operator, column solves are
        // still better than materializing full Z. Only use Takahashi when both
        // sides are block-local (handled in trace_hinv_operator_cross).
        self.trace_hinv_matrix_operator_cross_exact(matrix, op)
    }
518
519    fn trace_hinv_operator_cross(
520        &self,
521        left: &dyn HyperOperator,
522        right: &dyn HyperOperator,
523    ) -> f64 {
524        // Takahashi fast path: when both operators are block-local to the same
525        // block, compute tr(Z A Z B) using only the block of Z = H⁻¹.
526        if let Some(ref taka) = self.takahashi
527            && let (Some((a_local, a_start, a_end)), Some((b_local, b_start, b_end))) =
528                (left.block_local_data(), right.block_local_data())
529            && a_start == b_start
530            && a_end == b_end
531        {
532            // Same block: tr(Z_block * A_local * Z_block * B_local)
533            let za = Self::takahashi_left_multiply_block(taka, a_local, a_start);
534            if std::ptr::addr_eq(left, right) {
535                return trace_matrix_product(&za, &za);
536            }
537            let zb = Self::takahashi_left_multiply_block(taka, b_local, b_start);
538            // tr(ZA * ZB) = sum_ij (ZA)_ij * (ZB^T)_ij
539            return (&za * &zb.t()).sum();
540        }
541        // Different blocks: column solves are better than materializing
542        // full p×p Z. Fall through to exact path.
543        self.trace_hinv_operator_cross_exact(left, right)
544    }
545
    /// Second-order log-determinant cross term for a dense matrix and an
    /// operator: the negated `trace_hinv_matrix_operator_cross(h_i, h_j)`.
    fn trace_logdet_hessian_cross_matrix_operator(
        &self,
        h_i: &Array2<f64>,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        -self.trace_hinv_matrix_operator_cross(h_i, h_j)
    }
553
    /// Second-order log-determinant cross term for two operators: the negated
    /// `trace_hinv_operator_cross(h_i, h_j)`.
    fn trace_logdet_hessian_cross_operator(
        &self,
        h_i: &dyn HyperOperator,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        -self.trace_hinv_operator_cross(h_i, h_j)
    }
561
    /// Active rank reported for this backend: always the full dimension `n_dim`.
    fn active_rank(&self) -> usize {
        self.n_dim
    }
565
    /// Dimension of `H` as supplied at construction.
    fn dim(&self) -> usize {
        self.n_dim
    }
569}
570
571// BlockCoupledDerivativeProvider was removed — its functionality is now handled
572// by the `deriv_provider` trait (HessianDerivativeProvider), with concrete
573// implementations like JointModelDerivProvider and SurvivalDerivProvider
574// capturing the full correction including Jacobian sensitivity, weight
575// sensitivity, and basis sensitivity.
576
577// ═══════════════════════════════════════════════════════════════════════════
578//  Cholesky-backed value-only HessianOperator (logdet + solve, no traces)
579// ═══════════════════════════════════════════════════════════════════════════
580
/// Dense Cholesky-backed [`HessianOperator`] for `EvalMode::ValueOnly` paths.
///
/// When the penalized Hessian is known to be SPD (no Firth bias reduction, no
/// hard linear constraints, no `HardPseudo` mode), the REML/LAML cost needs
/// only two Hessian services:
///
/// - `logdet()` — used directly in the `½ log|H|` cost term.
/// - `solve(rhs)` / `solve_multi(rhs)` — used for the optional IFT
///   cost correction `−½ rᵀ H⁻¹ r`.
///
/// An LLT Cholesky factorization delivers both in `O(p³/3)` flops versus
/// the `O(9·p³)` full eigendecomposition of [`DenseSpectralOperator`], giving
/// a multi-× speedup per outer REML line-search probe.
///
/// Gradient traces (`trace_hinv_product`) are satisfied via a multi-column
/// solve against the factor followed by a diagonal sum, so that the operator
/// remains valid if the evaluator ever reaches a gradient path unexpectedly.
/// Under normal use `EvalMode::ValueOnly` returns before any trace call.
pub struct DenseCholeskyValueOnlyOperator {
    /// LLT Cholesky factor.
    pub(crate) chol: gam_linalg::faer_ndarray::FaerCholeskyFactor,
    /// `2 · Σ ln(diag L)` — cached at construction time.
    pub(crate) cached_logdet: f64,
    /// Full parameter dimension (row count of the factorized matrix).
    pub(crate) n_dim: usize,
}
607
608impl DenseCholeskyValueOnlyOperator {
609    /// Factorize `h` (assumed SPD) via LLT and cache the log-determinant.
610    ///
611    /// Returns `Err` if `h` is not square, not SPD, or contains non-finite
612    /// entries. Callers should fall back to [`DenseSpectralOperator`] on
613    /// failure (e.g. near-singular Hessians that need soft regularization).
614    pub fn from_spd(h: &Array2<f64>) -> Result<Self, String> {
615        use gam_linalg::faer_ndarray::FaerCholesky;
616        use faer::Side;
617
618        let n = h.nrows();
619        if n != h.ncols() {
620            return Err(format!(
621                "DenseCholeskyValueOnlyOperator: expected square matrix, got {}×{}",
622                n,
623                h.ncols()
624            ));
625        }
626        let chol = h
627            .cholesky(Side::Lower)
628            .map_err(|e| format!("DenseCholeskyValueOnlyOperator LLT failed: {e}"))?;
629        let diag = chol.diag();
630        let cached_logdet = 2.0 * diag.iter().map(|&d| d.ln()).sum::<f64>();
631        Ok(Self {
632            chol,
633            cached_logdet,
634            n_dim: n,
635        })
636    }
637}
638
// Only the methods below are overridden; any other `HessianOperator` method
// falls back to whatever the trait itself provides.
impl HessianOperator for DenseCholeskyValueOnlyOperator {
    /// Return the log-determinant cached by `from_spd`.
    fn logdet(&self) -> f64 {
        self.cached_logdet
    }

    /// Compute `tr(H⁻¹ A)` from one multi-column solve against the factor.
    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
        // tr(H⁻¹ A) = Σ_j [H⁻¹ A]_jj.
        // Compute H⁻¹ A via multi-column solve and sum the diagonal.
        let hinv_a = self.chol.solve_mat(a);
        hinv_a.diag().iter().sum()
    }

    /// Solve `H x = rhs` for a single right-hand side via the LLT factor.
    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
        self.chol.solvevec(rhs)
    }

    /// Solve `H X = rhs` for several right-hand sides via the LLT factor.
    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
        self.chol.solve_mat(rhs)
    }

    /// Active rank: always the full dimension.
    fn active_rank(&self) -> usize {
        // LLT succeeded ⟹ all pivots are positive ⟹ full rank.
        self.n_dim
    }

    /// Dimension of the factorized matrix.
    fn dim(&self) -> usize {
        self.n_dim
    }
}
668
669// ═══════════════════════════════════════════════════════════════════════════
670//  Block-coupled HessianOperator for joint multi-block models
671// ═══════════════════════════════════════════════════════════════════════════
672
/// Block-coupled Hessian operator for joint multi-block models (GAMLSS, survival).
///
/// Wraps a [`DenseSpectralOperator`] over the full assembled joint Hessian.
/// All [`HessianOperator`] trait methods delegate to the inner spectral
/// decomposition, ensuring a single eigendecomposition governs logdet, trace,
/// and solve.
///
/// # Block structure
///
/// A joint model with B parameter blocks has a joint Hessian of dimension
/// `p_total = sum_b p_b`, with each block occupying its own rows/columns of
/// that matrix.
///
/// NOTE(review): the only field is `inner`; no per-block ranges are stored on
/// this type, so any block metadata must be carried elsewhere — confirm the
/// statements about block tracking against callers.
///
/// # When to use
///
/// Use `BlockCoupledOperator` whenever building an [`InnerSolution`] for a joint
/// multi-block model. It replaces the pattern of constructing a raw
/// `DenseSpectralOperator` and manually tracking block ranges separately.
pub struct BlockCoupledOperator {
    /// Inner spectral operator over the full joint Hessian.
    pub(crate) inner: DenseSpectralOperator,
}
693
694impl BlockCoupledOperator {
695    /// Construct from an assembled joint Hessian using the supplied
696    /// [`PseudoLogdetMode`].  Internally performs a single
697    /// eigendecomposition of `joint_hessian`.
698    pub fn from_joint_hessian_with_mode(
699        joint_hessian: &Array2<f64>,
700        mode: PseudoLogdetMode,
701    ) -> Result<Self, String> {
702        let inner = DenseSpectralOperator::from_symmetric_with_mode(joint_hessian, mode)
703            .map_err(|e| format!("BlockCoupledOperator eigendecomposition: {e}"))?;
704
705        Ok(Self { inner })
706    }
707}
708
// Every numerical method forwards to `self.inner`, so logdet, traces and
// solves all come from the same inner spectral operator. Only the capability
// flags at the bottom are answered directly.
impl HessianOperator for BlockCoupledOperator {
    /// Forwarded to the inner spectral operator.
    fn logdet(&self) -> f64 {
        self.inner.logdet()
    }

    /// Forwarded to the inner operator's own `as_exact_dense_spectral`, so the
    /// result is whatever the inner operator reports.
    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        self.inner.as_exact_dense_spectral()
    }

    /// Forwarded to the inner spectral operator.
    fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
        self.inner.assemble_h_dense_for_tangent_projection()
    }

    /// Forwarded to the inner spectral operator.
    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
        self.inner.trace_hinv_product(a)
    }

    /// Forwarded to the inner spectral operator.
    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
        self.inner.trace_logdet_gradient(a)
    }

    /// Forwarded to the inner spectral operator.
    fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
        self.inner.xt_logdet_kernel_x_diagonal(x)
    }

    /// Forwarded to the inner spectral operator, including the optional
    /// third-derivative correction.
    fn trace_logdet_h_k(
        &self,
        a_k: &Array2<f64>,
        third_deriv_correction: Option<&Array2<f64>>,
    ) -> f64 {
        self.inner.trace_logdet_h_k(a_k, third_deriv_correction)
    }

    /// Forwarded to the inner spectral operator.
    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
        self.inner.trace_logdet_operator(op)
    }

    /// Forwarded to the inner spectral operator.
    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
        self.inner.trace_logdet_hessian_cross(h_i, h_j)
    }

    /// Forwarded to the inner spectral operator.
    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
        self.inner.solve(rhs)
    }

    /// Forwarded to the inner spectral operator.
    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
        self.inner.solve_multi(rhs)
    }

    /// Forwarded to the inner spectral operator.
    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        self.inner.trace_hinv_product_cross(a, b)
    }

    /// Forwarded to the inner spectral operator.
    fn trace_hinv_matrix_operator_cross(
        &self,
        matrix: &Array2<f64>,
        op: &dyn HyperOperator,
    ) -> f64 {
        self.inner.trace_hinv_matrix_operator_cross(matrix, op)
    }

    /// Forwarded to the inner spectral operator.
    fn trace_hinv_operator_cross(
        &self,
        left: &dyn HyperOperator,
        right: &dyn HyperOperator,
    ) -> f64 {
        self.inner.trace_hinv_operator_cross(left, right)
    }

    /// Forwarded to the inner spectral operator.
    fn active_rank(&self) -> usize {
        self.inner.active_rank()
    }

    /// Forwarded to the inner spectral operator.
    fn dim(&self) -> usize {
        self.inner.dim()
    }

    /// Always `true` for this operator.
    fn is_dense(&self) -> bool {
        true
    }

    /// Always `false` for this operator.
    fn prefers_stochastic_trace_estimation(&self) -> bool {
        false
    }

    /// Always `false` for this operator.
    fn logdet_traces_match_hinv_kernel(&self) -> bool {
        false
    }

    /// Always exposes the inner spectral operator. Unlike
    /// `as_exact_dense_spectral`, this does not consult the inner operator.
    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        Some(&self.inner)
    }
}
802
803// ═══════════════════════════════════════════════════════════════════════════
804//  Matrix-free SPD HessianOperator implementation
805// ═══════════════════════════════════════════════════════════════════════════
806
807/// Operator-backed SPD Hessian with exact spectral REML algebra.
808///
809/// The operator closure is still useful for construction paths that naturally
810/// expose HVPs, but REML cost/gradient/Hessian terms must all come from one
811/// exact decomposition so `∂ log|H| = tr(H⁻¹ ∂H)` holds.  We therefore
812/// materialize the coefficient Hessian by canonical-basis HVPs under an
813/// explicit memory cap and delegate logdet, traces, and solves to
814/// `DenseSpectralOperator`.
815pub struct MatrixFreeSpdOperator {
816    pub(crate) apply: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
817    // Optional single-pass dense assembly of the SAME penalized operator that
818    // `apply` realizes matrix-free, i.e. `H_unpen + S_λ + scale·H_Φ`. When the
819    // operator source can structurally build its full dense matrix in one
820    // chunked BLAS-3 `XᵀWX` row pass (BMS's `hessian_dense_forced` +
821    // construction-site penalty/Jeffreys assembly), `materialize_dense_operator`
822    // calls THIS instead of `dim` canonical-basis matvecs — each of which is a
823    // full n-row pass through the matrix-free operator. One n-pass replaces
824    // `dim` n-passes for the LAML logdet factorization. The closure must return
825    // a matrix numerically identical (up to symmetrization) to the matvec
826    // reconstruction `H·I`; `None` means no direct build is available and the
827    // matvec path is used (the result is bit-for-bit the prior behavior).
828    pub(crate) dense_assemble: Option<Arc<dyn Fn() -> Option<Array2<f64>> + Send + Sync>>,
829    pub(crate) cached_logdet: gam_runtime::resource::RayonSafeOnce<f64>,
830    pub(crate) n_dim: usize,
831    // `RayonSafeOnce`, not `OnceLock`: `materialize_dense_operator` invokes
832    // `apply`, which for operator-source joint Hessians dispatches a nested
833    // `into_par_iter` (e.g. `exact_newton_joint_hessian_matvec_from_cache`).
834    // With a plain `OnceLock`, concurrent rayon workers entering
835    // `solve`/`logdet` from inside an outer par_iter would park on the
836    // OnceLock's OS condvar; the leader's nested par_iter would then starve
837    // for workers. `RayonSafeOnce` keeps init lock-free — racers may
838    // duplicate the dim²-matvec build, but the first to publish wins and
839    // steady-state matches `OnceLock`.
840    pub(crate) dense_spectral: gam_runtime::resource::RayonSafeOnce<Option<DenseSpectralOperator>>,
841    // Pseudo-logdet convention threaded from the family. The dense outer path
842    // already plumbs `PseudoLogdetMode` into `BlockCoupledOperator`; the
843    // matrix-free path materializes a `DenseSpectralOperator` lazily and must
844    // use the same convention so that `logdet`, `trace_hinv_product`, the
845    // IFT response `H⁻¹ g`, and every cross-trace agree with the dense path.
846    // Without this, families that declare `HardPseudo` (BMS, GAMLSS) silently
847    // get Smooth full-spectrum semantics on the matrix-free path, and outer
848    // gradients are inflated by `1/σ_j` over numerical null directions.
849    pub(crate) mode: PseudoLogdetMode,
850}
851
852impl MatrixFreeSpdOperator {
853    pub(crate) const EXACT_DENSE_SPECTRAL_MAX_BYTES: usize = 512 * 1024 * 1024;
854    pub(crate) const EXACT_DENSE_SPECTRAL_ARRAYS: usize = 6;
855
856    pub fn new_with_mode<F>(dim: usize, apply: F, mode: PseudoLogdetMode) -> Self
857    where
858        F: Fn(&Array1<f64>) -> Array1<f64> + Send + Sync + 'static,
859    {
860        Self::new_with_mode_and_dense_assemble(dim, apply, mode, None)
861    }
862
863    /// Like [`new_with_mode`], but additionally accepts an optional single-pass
864    /// dense assembly of the same penalized operator. When present and it yields
865    /// a matrix, `materialize_dense_operator` uses it instead of the `dim`
866    /// canonical-basis matvecs. See the field doc on `dense_assemble`.
867    pub fn new_with_mode_and_dense_assemble<F>(
868        dim: usize,
869        apply: F,
870        mode: PseudoLogdetMode,
871        dense_assemble: Option<Arc<dyn Fn() -> Option<Array2<f64>> + Send + Sync>>,
872    ) -> Self
873    where
874        F: Fn(&Array1<f64>) -> Array1<f64> + Send + Sync + 'static,
875    {
876        let apply = Arc::new(apply);
877
878        Self {
879            apply,
880            dense_assemble,
881            cached_logdet: gam_runtime::resource::RayonSafeOnce::new(),
882            n_dim: dim,
883            dense_spectral: gam_runtime::resource::RayonSafeOnce::new(),
884            mode,
885        }
886    }
887
888    pub(crate) fn exact_dense_spectral_bytes(&self) -> Option<usize> {
889        self.n_dim
890            .checked_mul(self.n_dim)?
891            .checked_mul(std::mem::size_of::<f64>())?
892            .checked_mul(Self::EXACT_DENSE_SPECTRAL_ARRAYS)
893    }
894
895    pub(crate) fn exact_dense_spectral_budget_ok(&self) -> bool {
896        match self.exact_dense_spectral_bytes() {
897            Some(bytes) if bytes <= Self::EXACT_DENSE_SPECTRAL_MAX_BYTES => true,
898            Some(bytes) => {
899                log::error!(
900                    "MatrixFreeSpdOperator exact dense spectral materialization requires {:.2} GiB \
901                     for dim={}, exceeding the {:.2} GiB cap",
902                    bytes as f64 / (1024.0 * 1024.0 * 1024.0),
903                    self.n_dim,
904                    Self::EXACT_DENSE_SPECTRAL_MAX_BYTES as f64 / (1024.0 * 1024.0 * 1024.0),
905                );
906                false
907            }
908            None => {
909                log::error!(
910                    "MatrixFreeSpdOperator exact dense spectral byte count overflow for dim={}",
911                    self.n_dim
912                );
913                false
914            }
915        }
916    }
917
918    pub(crate) fn materialize_dense_operator(&self) -> Option<DenseSpectralOperator> {
919        if !self.exact_dense_spectral_budget_ok() {
920            return None;
921        }
922        let materialize_start = std::time::Instant::now();
923        // Fast path: structural single-pass dense assembly of the SAME penalized
924        // operator (`H_unpen + S_λ + scale·H_Φ`). One chunked BLAS-3 `XᵀWX`
925        // row pass replaces `n_dim` canonical-basis matvecs, each a full n-row
926        // pass through the matrix-free operator. The matvec fallback below is the
927        // exact same algebra column-for-column, so the spectrum/logdet match.
928        let (matrix, matvec_count) =
929            match self.dense_assemble.as_ref().and_then(|assemble| assemble()) {
930                Some(mut direct)
931                    if direct.nrows() == self.n_dim
932                        && direct.ncols() == self.n_dim
933                        && direct.iter().all(|v| v.is_finite()) =>
934                {
935                    // Symmetrize defensively; the direct build is structurally
936                    // symmetric but reduction-order f.p. noise can desync mirror
937                    // entries, exactly as the matvec path symmetrizes below.
938                    for i in 0..self.n_dim {
939                        for j in (i + 1)..self.n_dim {
940                            let avg = 0.5 * (direct[[i, j]] + direct[[j, i]]);
941                            direct[[i, j]] = avg;
942                            direct[[j, i]] = avg;
943                        }
944                    }
945                    (direct, 0usize)
946                }
947                _ => {
948                    let mut matrix = Array2::<f64>::zeros((self.n_dim, self.n_dim));
949                    let mut basis = Array1::<f64>::zeros(self.n_dim);
950                    for j in 0..self.n_dim {
951                        basis[j] = 1.0;
952                        let col = (self.apply)(&basis);
953                        basis[j] = 0.0;
954                        if col.len() != self.n_dim || !col.iter().all(|v| v.is_finite()) {
955                            return None;
956                        }
957                        matrix.column_mut(j).assign(&col);
958                    }
959                    for i in 0..self.n_dim {
960                        for j in (i + 1)..self.n_dim {
961                            let avg = 0.5 * (matrix[[i, j]] + matrix[[j, i]]);
962                            matrix[[i, j]] = avg;
963                            matrix[[j, i]] = avg;
964                        }
965                    }
966                    (matrix, self.n_dim)
967                }
968            };
969        let result = DenseSpectralOperator::from_symmetric_with_mode(&matrix, self.mode).ok();
970        log::info!(
971            "[STAGE] matrix_free_spd materialize n_dim={} matvec_count={} elapsed={:.3}s",
972            self.n_dim,
973            matvec_count,
974            materialize_start.elapsed().as_secs_f64(),
975        );
976        result
977    }
978
979    pub(crate) fn dense_spectral(&self) -> Option<&DenseSpectralOperator> {
980        self.dense_spectral
981            .get_or_compute(|| self.materialize_dense_operator())
982            .as_ref()
983    }
984
985    pub(crate) fn exact_dense_spectral(&self) -> &DenseSpectralOperator {
986        self.dense_spectral().expect(
987            "MatrixFreeSpdOperator exact REML algebra requires dense spectral materialization within the configured budget",
988        )
989    }
990
991    pub(crate) fn use_trace_cg(&self, rel_tol: f64) -> bool {
992        rel_tol.is_finite()
993            && rel_tol > 0.0
994            && self.prefers_stochastic_trace_estimation()
995            && self.has_matrix_free_trace_cg_operator()
996    }
997
998    pub(crate) fn cg_trace_solve(
999        &self,
1000        rhs: &Array1<f64>,
1001        rel_tol: f64,
1002        probe_id: Option<u64>,
1003        trace_state: Option<&Arc<Mutex<StochasticTraceState>>>,
1004    ) -> Array1<f64> {
1005        let dim = rhs.len();
1006        if dim != self.n_dim {
1007            return self.solve(rhs);
1008        }
1009
1010        let (initial, warm_start_used) = match (probe_id, trace_state) {
1011            (Some(id), Some(state)) => {
1012                let cached = match state.lock() {
1013                    Ok(guard) => guard.cg_warm_starts.get(&id).cloned(),
1014                    Err(poisoned) => poisoned.into_inner().cg_warm_starts.get(&id).cloned(),
1015                };
1016                match cached {
1017                    Some(x) if x.len() == dim => (x, true),
1018                    _ => (Array1::<f64>::zeros(dim), false),
1019                }
1020            }
1021            _ => (Array1::<f64>::zeros(dim), false),
1022        };
1023
1024        let Some((solution, iters, residual_norm)) =
1025            conjugate_gradient_trace_solve(rhs, rel_tol, initial, |v| (self.apply)(v))
1026        else {
1027            return self.solve(rhs);
1028        };
1029
1030        if let Some(state) = trace_state {
1031            let mut guard = match state.lock() {
1032                Ok(guard) => guard,
1033                Err(poisoned) => poisoned.into_inner(),
1034            };
1035            guard.last_linear_residual_norm = Some(
1036                guard
1037                    .last_linear_residual_norm
1038                    .unwrap_or(0.0)
1039                    .max(residual_norm),
1040            );
1041            if let Some(id) = probe_id {
1042                guard.cg_warm_starts.insert(id, solution.clone());
1043            }
1044        }
1045
1046        let probe_label = probe_id
1047            .map(|id| id.to_string())
1048            .unwrap_or_else(|| "untracked".to_string());
1049        log::info!(
1050            "[CG-TRACE] probe_id={} iters={} rel_tol={} warm_start_used={}",
1051            probe_label,
1052            iters,
1053            rel_tol,
1054            warm_start_used
1055        );
1056
1057        solution
1058    }
1059}
1060
1061pub(crate) fn conjugate_gradient_trace_solve<F>(
1062    rhs: &Array1<f64>,
1063    rel_tol: f64,
1064    mut x: Array1<f64>,
1065    apply: F,
1066) -> Option<(Array1<f64>, usize, f64)>
1067where
1068    F: Fn(&Array1<f64>) -> Array1<f64>,
1069{
1070    let dim = rhs.len();
1071    if x.len() != dim {
1072        return None;
1073    }
1074
1075    let rhs_norm_sq = rhs.dot(rhs);
1076    if !rhs_norm_sq.is_finite() {
1077        return None;
1078    }
1079    if rhs_norm_sq <= f64::MIN_POSITIVE {
1080        return Some((Array1::<f64>::zeros(dim), 0, 0.0));
1081    }
1082
1083    let target_sq = (rel_tol * rel_tol * rhs_norm_sq).max(f64::MIN_POSITIVE);
1084    let mut r = rhs.clone();
1085    if x.iter().any(|value| *value != 0.0) {
1086        let ax = apply(&x);
1087        if ax.len() != dim || !ax.iter().all(|value| value.is_finite()) {
1088            return None;
1089        }
1090        r.scaled_add(-1.0, &ax);
1091    }
1092
1093    let mut rs_old = r.dot(&r);
1094    if !rs_old.is_finite() {
1095        return None;
1096    }
1097    if rs_old <= target_sq {
1098        return Some((x, 0, rs_old.max(0.0).sqrt()));
1099    }
1100
1101    let mut p = r.clone();
1102    let mut iters = 0usize;
1103    let mut residual_norm = rs_old.max(0.0).sqrt();
1104    for k in 0..dim.max(1) {
1105        let ap = apply(&p);
1106        if ap.len() != dim || !ap.iter().all(|value| value.is_finite()) {
1107            return None;
1108        }
1109        let denom = p.dot(&ap);
1110        if !denom.is_finite() || denom <= 0.0 {
1111            log::warn!(
1112                "[CG-TRACE] non-positive curvature in trace CG at iter={} denom={}",
1113                k + 1,
1114                denom
1115            );
1116            break;
1117        }
1118        let alpha = rs_old / denom;
1119        if !alpha.is_finite() {
1120            return None;
1121        }
1122        x.scaled_add(alpha, &p);
1123        r.scaled_add(-alpha, &ap);
1124        let rs_new = r.dot(&r);
1125        if !rs_new.is_finite() {
1126            return None;
1127        }
1128        iters = k + 1;
1129        residual_norm = rs_new.max(0.0).sqrt();
1130        if rs_new <= target_sq {
1131            break;
1132        }
1133        let beta = rs_new / rs_old;
1134        if !beta.is_finite() {
1135            return None;
1136        }
1137        p.mapv_inplace(|value| beta * value);
1138        p += &r;
1139        rs_old = rs_new;
1140    }
1141
1142    Some((x, iters, residual_norm))
1143}
1144
1145impl HessianOperator for MatrixFreeSpdOperator {
1146    fn logdet(&self) -> f64 {
1147        *self
1148            .cached_logdet
1149            .get_or_compute(|| self.exact_dense_spectral().logdet())
1150    }
1151
1152    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
1153        Some(self.exact_dense_spectral())
1154    }
1155
1156    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
1157        self.exact_dense_spectral().trace_hinv_product(a)
1158    }
1159
1160    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
1161        self.exact_dense_spectral().trace_hinv_operator(op)
1162    }
1163
1164    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
1165        self.exact_dense_spectral().trace_hinv_product_cross(a, b)
1166    }
1167
1168    fn trace_hinv_matrix_operator_cross(
1169        &self,
1170        matrix: &Array2<f64>,
1171        op: &dyn HyperOperator,
1172    ) -> f64 {
1173        self.exact_dense_spectral()
1174            .trace_hinv_matrix_operator_cross(matrix, op)
1175    }
1176
1177    fn trace_hinv_operator_cross(
1178        &self,
1179        left: &dyn HyperOperator,
1180        right: &dyn HyperOperator,
1181    ) -> f64 {
1182        self.exact_dense_spectral()
1183            .trace_hinv_operator_cross(left, right)
1184    }
1185
1186    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
1187        let trace_start = std::time::Instant::now();
1188        let result = self.exact_dense_spectral().trace_logdet_operator(op);
1189        log::info!(
1190            "[STAGE] matrix_free_spd trace_logdet_operator implicit={} dim={} elapsed={:.3}s",
1191            op.is_implicit(),
1192            op.dim(),
1193            trace_start.elapsed().as_secs_f64(),
1194        );
1195        result
1196    }
1197
1198    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
1199        self.exact_dense_spectral().solve(rhs)
1200    }
1201
1202    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
1203        self.exact_dense_spectral().solve_multi(rhs)
1204    }
1205
1206    fn stochastic_trace_solve(&self, rhs: &Array1<f64>, rel_tol: f64) -> Array1<f64> {
1207        if self.use_trace_cg(rel_tol) {
1208            return self.cg_trace_solve(rhs, rel_tol, None, None);
1209        }
1210        self.solve(rhs)
1211    }
1212
1213    fn stochastic_trace_solve_for_probe(
1214        &self,
1215        rhs: &Array1<f64>,
1216        rel_tol: f64,
1217        probe_id: u64,
1218        trace_state: Option<&Arc<Mutex<StochasticTraceState>>>,
1219    ) -> Array1<f64> {
1220        if self.use_trace_cg(rel_tol) {
1221            return self.cg_trace_solve(rhs, rel_tol, Some(probe_id), trace_state);
1222        }
1223        self.solve(rhs)
1224    }
1225
1226    fn stochastic_trace_solve_multi(&self, rhs: &Array2<f64>, rel_tol: f64) -> Array2<f64> {
1227        if self.use_trace_cg(rel_tol) {
1228            let mut out = Array2::<f64>::zeros(rhs.raw_dim());
1229            for j in 0..rhs.ncols() {
1230                let solved = self.cg_trace_solve(&rhs.column(j).to_owned(), rel_tol, None, None);
1231                out.column_mut(j).assign(&solved);
1232            }
1233            return out;
1234        }
1235        self.solve_multi(rhs)
1236    }
1237
1238    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
1239        self.exact_dense_spectral()
1240            .trace_logdet_hessian_cross(h_i, h_j)
1241    }
1242
1243    fn trace_logdet_hessian_cross_matrix_operator(
1244        &self,
1245        h_i: &Array2<f64>,
1246        h_j: &dyn HyperOperator,
1247    ) -> f64 {
1248        self.exact_dense_spectral()
1249            .trace_logdet_hessian_cross_matrix_operator(h_i, h_j)
1250    }
1251
1252    fn trace_logdet_hessian_cross_operator(
1253        &self,
1254        h_i: &dyn HyperOperator,
1255        h_j: &dyn HyperOperator,
1256    ) -> f64 {
1257        self.exact_dense_spectral()
1258            .trace_logdet_hessian_cross_operator(h_i, h_j)
1259    }
1260
1261    fn active_rank(&self) -> usize {
1262        self.n_dim
1263    }
1264
1265    fn dim(&self) -> usize {
1266        self.n_dim
1267    }
1268
1269    fn is_dense(&self) -> bool {
1270        true
1271    }
1272
1273    /// The operator delegates `logdet`, `trace_hinv_*`, `trace_logdet_*`,
1274    /// `solve`, and `solve_multi` to a lazily-built `DenseSpectralOperator`
1275    /// whenever the exact-dense materialization fits the configured byte cap
1276    /// (see `exact_dense_spectral_budget_ok` / `EXACT_DENSE_SPECTRAL_MAX_BYTES`).
1277    /// In that regime the algebra is exact spectral — there is no stochastic
1278    /// preference to advertise, and forcing the caller to take the Hutchinson
1279    /// path would replace an O(p²) exact reduction with O(k·apply) noisy probes.
1280    ///
1281    /// When the budget is exceeded the dense factor cannot be built and the
1282    /// CG trace-solve path added in 2bd6af68 is the only feasible route; the
1283    /// flag flips to `true` so `stochastic_trace_solve*` callers route through
1284    /// `cg_trace_solve` instead of crashing in `exact_dense_spectral().expect`.
1285    fn prefers_stochastic_trace_estimation(&self) -> bool {
1286        !self.exact_dense_spectral_budget_ok()
1287    }
1288
1289    /// Mirror the `prefers_stochastic_trace_estimation` gate: when the dense
1290    /// factor is reachable the operator's logdet / trace_hinv reductions all
1291    /// resolve through `DenseSpectralOperator`, whose
1292    /// `logdet_traces_match_hinv_kernel` is `false` for the smooth-spectral
1293    /// regularization variants we run. Reporting `true` here would let the
1294    /// outer evaluator route logdet-gradient/Hessian traces through the
1295    /// Hutchinson `H⁻¹` kernel which does not satisfy
1296    /// `∂ log|H| = tr(H⁻¹ ∂H)` under smooth-spectral. The CG-only regime
1297    /// (budget exceeded) lacks a dense reference so falling back to the
1298    /// stochastic kernel is acceptable as a best-effort estimate.
1299    fn logdet_traces_match_hinv_kernel(&self) -> bool {
1300        !self.exact_dense_spectral_budget_ok()
1301    }
1302
1303    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
1304        self.dense_spectral()
1305    }
1306
1307    fn has_matrix_free_trace_cg_operator(&self) -> bool {
1308        true
1309    }
1310}
1311
1312// ═══════════════════════════════════════════════════════════════════════════
1313//  Helpers for custom family → InnerSolution conversion
1314// ═══════════════════════════════════════════════════════════════════════════
1315
1316/// Compute the square root of a symmetric positive semidefinite penalty matrix.
1317///
1318/// Returns R such that S = RᵀR, with R having `rank(S)` rows.
1319/// Uses eigendecomposition: S = U Λ U^T → R = Λ_+^{1/2} U_+^T.
1320pub fn penalty_matrix_root(s: &Array2<f64>) -> Result<Array2<f64>, String> {
1321    use faer::Side;
1322    let n = s.nrows();
1323    if n != s.ncols() {
1324        return Err(RemlError::DimensionMismatch {
1325            reason: format!(
1326                "penalty_matrix_root: expected square matrix, got {}×{}",
1327                n,
1328                s.ncols()
1329            ),
1330        }
1331        .into());
1332    }
1333    if n == 0 {
1334        return Ok(Array2::zeros((0, 0)));
1335    }
1336
1337    let (eigenvalues, eigenvectors) = s
1338        .eigh(Side::Lower)
1339        .map_err(|e| format!("penalty_matrix_root eigendecomposition failed: {e}"))?;
1340
1341    let max_ev = eigenvalues.iter().copied().fold(0.0_f64, f64::max);
1342    let tol = (n.max(1) as f64) * f64::EPSILON * max_ev.max(1e-12);
1343
1344    let active: Vec<usize> = eigenvalues
1345        .iter()
1346        .enumerate()
1347        .filter(|(_, v)| **v > tol)
1348        .map(|(i, _)| i)
1349        .collect();
1350    let rank = active.len();
1351
1352    let mut r = Array2::zeros((rank, n));
1353    for (out_row, &idx) in active.iter().enumerate() {
1354        let scale = eigenvalues[idx].sqrt();
1355        for col in 0..n {
1356            r[[out_row, col]] = scale * eigenvectors[[col, idx]];
1357        }
1358    }
1359    Ok(r)
1360}
1361
1362/// Compute the exact pseudo-logdet log|S|₊ and its ρ-derivatives for a
1363/// blockwise penalty structure.
1364///
1365/// For each block, eigendecomposes S_b = Σ λ_k S_k, identifies the positive
1366/// eigenspace (structural nullspace detected from the eigenspectrum), and
1367/// computes exact derivatives on that subspace:
1368///
1369/// - L(S) = Σ_{σ_i > ε} log σ_i
1370/// - ∂/∂ρₖ L = tr(S⁺ Aₖ)
1371/// - ∂²/(∂ρₖ∂ρₗ) L = δ_{kl} ∂_k L − tr(S⁺ Aₗ S⁺ Aₖ)
1372///
1373/// For S(ρ) = Σ exp(ρ_k) S_k with S_k ⪰ 0, the nullspace N(S) = ∩_k N(S_k)
1374/// is structurally fixed (independent of ρ), so L is C∞ in ρ and these are
1375/// its exact derivatives.
1376///
1377/// `per_block_rho[b]` contains the log-lambdas for block b.
1378/// `per_block_penalties[b]` contains the penalty matrices for block b.
1379/// `ridge` is an additional ridge for logdet stability (0 if not applicable).
1380pub fn compute_block_penalty_logdet_derivs(
1381    per_block_rho: &[Array1<f64>],
1382    per_block_penalties: &[&[Array2<f64>]],
1383    ridge: f64,
1384) -> Result<PenaltyLogdetDerivs, String> {
1385    use super::super::penalty_logdet::PenaltyPseudologdet;
1386
1387    let total_k: usize = per_block_rho.iter().map(|r| r.len()).sum();
1388    let block_offsets: Vec<usize> = per_block_rho
1389        .iter()
1390        .scan(0usize, |at, rho| {
1391            let current = *at;
1392            *at += rho.len();
1393            Some(current)
1394        })
1395        .collect();
1396
1397    struct BlockPenaltyLogdetResult {
1398        pub(crate) offset: usize,
1399        pub(crate) value: f64,
1400        pub(crate) first: Array1<f64>,
1401        pub(crate) second: Array2<f64>,
1402    }
1403
1404    let compute_block = |(b, block_rho): (usize, &Array1<f64>)| {
1405        let penalties = per_block_penalties[b];
1406        let kb = block_rho.len();
1407        if penalties.is_empty() || kb == 0 {
1408            return Ok(BlockPenaltyLogdetResult {
1409                offset: block_offsets[b],
1410                value: 0.0,
1411                first: Array1::zeros(kb),
1412                second: Array2::zeros((kb, kb)),
1413            });
1414        }
1415        let lambdas: Vec<f64> = block_rho.iter().map(|&r| r.exp()).collect();
1416
1417        // Single eigendecomposition via canonical PenaltyPseudologdet.
1418        //
1419        // No metadata-based structural-nullity hint: the classifier derives
1420        // the positive eigenspace from the assembled spectrum alone (issues
1421        // #192/#318).
1422        let pld = PenaltyPseudologdet::from_components(penalties, &lambdas, ridge)
1423            .map_err(|e| format!("penalty logdet failed for block {b}: {e}"))?;
1424
1425        let value = pld.value();
1426        let (first, second) = pld.rho_derivatives(penalties, &lambdas);
1427        Ok(BlockPenaltyLogdetResult {
1428            offset: block_offsets[b],
1429            value,
1430            first,
1431            second,
1432        })
1433    };
1434
1435    let block_results: Vec<BlockPenaltyLogdetResult> = if rayon::current_thread_index().is_some() {
1436        per_block_rho
1437            .iter()
1438            .enumerate()
1439            .map(compute_block)
1440            .collect::<Result<Vec<_>, String>>()?
1441    } else {
1442        per_block_rho
1443            .par_iter()
1444            .enumerate()
1445            .map(compute_block)
1446            .collect::<Result<Vec<_>, String>>()?
1447    };
1448
1449    let mut log_det_total = 0.0;
1450    let mut first = Array1::zeros(total_k);
1451    let mut second = Array2::zeros((total_k, total_k));
1452    for block in block_results {
1453        log_det_total += block.value;
1454        let kb = block.first.len();
1455        for k in 0..kb {
1456            first[block.offset + k] = block.first[k];
1457        }
1458        for k in 0..kb {
1459            for l in 0..kb {
1460                second[[block.offset + k, block.offset + l]] = block.second[[k, l]];
1461            }
1462        }
1463    }
1464
1465    Ok(PenaltyLogdetDerivs {
1466        value: log_det_total,
1467        first,
1468        second: Some(second),
1469    })
1470}
1471
1472// ═══════════════════════════════════════════════════════════════════════════
1473//  Stochastic trace estimation via Rademacher probes
1474// ═══════════════════════════════════════════════════════════════════════════
1475//
1476// For large-scale models, computing tr(H⁻¹ A_k) exactly via the full p×p
1477// eigendecomposition or column-by-column sparse solves costs O(p²) per
1478// coordinate k.  Stochastic trace estimation gives an unbiased estimate
1479// using only matrix–vector products (solves), at cost O(M·p) where M is the
1480// number of random probe vectors (typically 10–200).
1481//
1482// The Girard–Hutchinson estimator:
1483//
1484//   tr(H⁻¹ A_k) ≈ (1/M) Σ_m  z_mᵀ H⁻¹ A_k z_m
1485//
1486// where z_m are i.i.d. random vectors with E[zzᵀ] = I.
1487//
// Rademacher probes (entries ±1 with equal probability) never have higher
// variance than Gaussian probes, and have strictly lower variance whenever
// S has a nonzero diagonal entry:
//   Var_Rad = 2(‖S‖²_F − Σ_i S²_{ii})
//   Var_Gau = 2‖S‖²_F
// where S = sym(H⁻¹ A_k).  The Rademacher estimator removes the diagonal
// contribution 2 Σ_i S²_{ii} from the variance.
1493//
1494// Key efficiency: ONE H⁻¹ solve per probe, shared across ALL k
1495// coordinates.  For each probe z we compute w = H⁻¹z once, then for each k
1496// we get q_k = zᵀ(A_k w) with a cheap matrix–vector multiply.