Skip to main content

gam_solve/reml/reml_outer_engine/
dense_spectral.rs

1use super::*;
// ═══════════════════════════════════════════════════════════════════════════
//  Dense spectral HessianOperator implementation
// ═══════════════════════════════════════════════════════════════════════════

7/// Dense spectral Hessian operator using eigendecomposition.
8///
9/// Computes logdet, trace, and solve from a single eigendecomposition,
10/// guaranteeing spectral consistency. Indefinite or near-singular eigenvalues
11/// are handled via smooth spectral regularization `r_ε(σ)` rather than hard
12/// clamping, ensuring that logdet and inverse use the same smooth mapping.
13pub struct DenseSpectralOperator {
14    /// Regularized eigenvalues: `r_ε(σ_i)` for each raw eigenvalue `σ_i`.
15    pub(crate) reg_eigenvalues: Vec<f64>,
16    /// Per-eigenvalue mask: `true` if the eigenpair participates in all
17    /// traces, solves, and logdet contributions.  Under
18    /// [`PseudoLogdetMode::Smooth`] every entry is `true`.  Under
19    /// [`PseudoLogdetMode::HardPseudo`] entries with `σ_j ≤ ε` are `false`,
20    /// so the numerical null space is excluded consistently from
21    /// `log|H|_+`, its gradient, its cross-traces, AND `H⁻¹` solves
22    /// (`H⁺` on the active subspace).
23    pub(crate) active_mask: Vec<bool>,
24    /// Eigenvectors of H (columns).
25    pub(crate) eigenvectors: Array2<f64>,
26    /// Precomputed: W = U diag(1/√r_ε(σ)) for efficient traces.
27    /// trace(H⁻¹ A) = Σ (AW ⊙ W)
28    pub(crate) w_factor: Array2<f64>,
29    /// Precomputed kernel K_ab = 1 / (r_a r_b) for exact H⁻¹ cross traces in
30    /// the eigenbasis.
31    pub(crate) hinv_cross_kernel: Array2<f64>,
32    /// Precomputed: G = U diag(1/√(√(σ² + 4ε²))) for logdet gradient traces.
33    /// trace(G_ε(H) A) = Σ (AG ⊙ G) where G_ε uses φ'(σ) = 1/√(σ² + 4ε²).
34    pub(crate) g_factor: Array2<f64>,
35    /// Precomputed divided-difference kernel Γ for exact logdet Hessian cross traces
36    /// in the eigenbasis.
37    pub(crate) logdet_hessian_kernel: Array2<f64>,
38    /// Precomputed log-determinant: Σ ln(r_ε(σ_i)).
39    pub(crate) cached_logdet: f64,
40    pub(crate) projected_factor_cache: ProjectedFactorCache,
41    /// Full dimension.
42    pub(crate) n_dim: usize,
43}
44
45impl DenseSpectralOperator {
46    /// Create from a symmetric matrix (may be indefinite or singular).
47    ///
48    /// The eigendecomposition is computed once. Eigenvalues are smoothly
49    /// regularized via `r_ε(σ)`. All subsequent operations (logdet, trace,
50    /// solve) use the regularized spectrum, ensuring mathematical consistency.
51    pub fn from_symmetric(h: &Array2<f64>) -> Result<Self, String> {
52        Self::from_symmetric_with_mode(h, PseudoLogdetMode::Smooth)
53    }
54
55    /// Variant of [`from_symmetric`](Self::from_symmetric) that selects the
56    /// log-determinant convention.
57    ///
58    /// See [`PseudoLogdetMode`] for the derivation and the exact set of
59    /// kernels that differ between the two modes.  At a high level:
60    /// `Smooth` keeps every eigenpair in play with a soft floor, whereas
61    /// `HardPseudo` masks out `σ_j ≤ ε` consistently across logdet,
62    /// gradient traces, cross-traces, and the H⁻¹ kernels.
63    pub fn from_symmetric_with_mode(
64        h: &Array2<f64>,
65        mode: PseudoLogdetMode,
66    ) -> Result<Self, String> {
67        use faer::Side;
68
69        let n = h.nrows();
70        if n != h.ncols() {
71            return Err(RemlError::DimensionMismatch {
72                reason: format!(
73                    "HessianOperator: expected square matrix, got {}×{}",
74                    n,
75                    h.ncols()
76                ),
77            }
78            .into());
79        }
80
81        let (eigenvalues, eigenvectors) = h
82            .eigh(Side::Lower)
83            .map_err(|e| format!("Eigendecomposition failed: {e}"))?;
84
85        let epsilon = spectral_epsilon(eigenvalues.as_slice().unwrap());
86
87        // `active[j]` selects which eigenpairs participate in every trace
88        // and in the cached logdet.
89        //
90        // `Smooth` is the regularized full-spectrum mode: every eigenpair stays
91        // active and singular directions are handled only through
92        // `r_ε(σ)`. This is the documented default semantics used by the
93        // unified REML/LAML objective.
94        //
95        // `HardPseudo` is the identified-subspace mode: eigenpairs with
96        // `σ_j ≤ ε` are excluded consistently from logdet, traces, and solves.
97        // Families that need exact pseudo-determinant behaviour opt into this
98        // mode explicitly through `pseudo_logdet_mode()`.
99        let active: Vec<bool> = match mode {
100            PseudoLogdetMode::Smooth => vec![true; n],
101            PseudoLogdetMode::HardPseudo => eigenvalues.iter().map(|&s| s > epsilon).collect(),
102        };
103
104        // Apply smooth regularization to all eigenvalues (even inactive ones:
105        // `reg_eigenvalues[j]` is still consulted by `trace_hinv_product`
106        // when using `w_factor[:, j]`, but we zero-out `w_factor[:, j]` for
107        // inactive eigenpairs so those entries never enter any sum).
108        let reg_eigenvalues: Vec<f64> = eigenvalues
109            .iter()
110            .map(|&sigma| spectral_regularize(sigma, epsilon))
111            .collect();
112
113        // Build W factor for traces: W[:, j] = u_j / sqrt(r_ε(σ_j)) on
114        // active eigenpairs, 0 otherwise.
115        let mut w_factor = Array2::zeros((n, n));
116        for j in 0..n {
117            if !active[j] {
118                continue;
119            }
120            let scale = 1.0 / reg_eigenvalues[j].sqrt();
121            for row in 0..n {
122                w_factor[[row, j]] = eigenvectors[[row, j]] * scale;
123            }
124        }
125
126        let mut hinv_cross_kernel = Array2::zeros((n, n));
127        for a in 0..n {
128            if !active[a] {
129                continue;
130            }
131            let inv_ra = 1.0 / reg_eigenvalues[a];
132            for b in 0..n {
133                if !active[b] {
134                    continue;
135                }
136                hinv_cross_kernel[[a, b]] = inv_ra / reg_eigenvalues[b];
137            }
138        }
139
140        // Build G factor for logdet gradient traces: G[:, j] = u_j / sqrt(√(σ_j² + 4ε²))
141        // φ'(σ) = 1/√(σ² + 4ε²), so we need 1/√(φ'(σ)) = (σ² + 4ε²)^{1/4}
142        // Actually: tr(G_ε A) = Σ_j φ'(σ_j) u_jᵀ A u_j = Σ (AG ⊙ G)
143        // where G[:, j] = u_j · √(φ'(σ_j)) = u_j / (σ_j² + 4ε²)^{1/4}
144        let four_eps_sq = 4.0 * epsilon * epsilon;
145        let mut g_factor = Array2::zeros((n, n));
146        for j in 0..n {
147            if !active[j] {
148                continue;
149            }
150            let sigma = eigenvalues[j];
151            let phi_prime = 1.0 / (sigma * sigma + four_eps_sq).sqrt();
152            let scale = phi_prime.sqrt();
153            for row in 0..n {
154                g_factor[[row, j]] = eigenvectors[[row, j]] * scale;
155            }
156        }
157
158        let mut logdet_hessian_kernel = Array2::zeros((n, n));
159        let sqrt_disc: Vec<f64> = eigenvalues
160            .iter()
161            .map(|&s| (s * s + four_eps_sq).sqrt())
162            .collect();
163        for a in 0..n {
164            if !active[a] {
165                continue;
166            }
167            let sigma_a = eigenvalues[a];
168            let sqrt_a = sqrt_disc[a];
169            for b in 0..n {
170                if !active[b] {
171                    continue;
172                }
173                logdet_hessian_kernel[[a, b]] = if a == b {
174                    -sigma_a / (sqrt_a * sqrt_a * sqrt_a)
175                } else {
176                    let sigma_b = eigenvalues[b];
177                    let sqrt_b = sqrt_disc[b];
178                    -(sigma_a + sigma_b) / (sqrt_a * sqrt_b * (sqrt_a + sqrt_b))
179                };
180            }
181        }
182
183        // Precompute logdet: Σ_{active} ln(r_ε(σ_i)).
184        let cached_logdet: f64 = reg_eigenvalues
185            .iter()
186            .zip(active.iter())
187            .filter_map(|(&v, &act)| if act { Some(v.ln()) } else { None })
188            .sum();
189
190        Ok(Self {
191            reg_eigenvalues,
192            active_mask: active,
193            eigenvectors,
194            w_factor,
195            hinv_cross_kernel,
196            g_factor,
197            logdet_hessian_kernel,
198            cached_logdet,
199            projected_factor_cache: ProjectedFactorCache::default(),
200            n_dim: n,
201        })
202    }
203
204    #[inline]
205    pub(crate) fn rotate_to_eigenbasis(&self, matrix: &Array2<f64>) -> Array2<f64> {
206        let left = gam_linalg::faer_ndarray::fast_atb(&self.eigenvectors, matrix);
207        gam_linalg::faer_ndarray::fast_ab(&left, &self.eigenvectors)
208    }
209
210    /// Factor `F` satisfying `trace(G_epsilon(H) A) = trace(F^T A F)`.
211    ///
212    /// Structured row-local operators use this to contract the logdet-gradient
213    /// trace directly in row space without forming `A F` in coefficient space.
214    pub fn logdet_gradient_factor(&self) -> &Array2<f64> {
215        &self.g_factor
216    }
217
218    #[inline]
219    pub(crate) fn trace_hinv_product_cross_rotated(
220        &self,
221        a_rot: &Array2<f64>,
222        b_rot: &Array2<f64>,
223    ) -> f64 {
224        let mut result = 0.0;
225        for ((kernel_row, a_row), b_col) in self
226            .hinv_cross_kernel
227            .rows()
228            .into_iter()
229            .zip(a_rot.rows().into_iter())
230            .zip(b_rot.columns().into_iter())
231        {
232            for ((kernel, a_value), b_value) in kernel_row
233                .iter()
234                .copied()
235                .zip(a_row.iter().copied())
236                .zip(b_col.iter().copied())
237            {
238                result += kernel * a_value * b_value;
239            }
240        }
241        result
242    }
243
244    #[inline]
245    pub(crate) fn trace_hinv_product_cross_dense(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
246        let a_rot = self.rotate_to_eigenbasis(a);
247        if std::ptr::eq(a, b) {
248            return self.trace_hinv_product_cross_rotated(&a_rot, &a_rot);
249        }
250        let b_rot = self.rotate_to_eigenbasis(b);
251        self.trace_hinv_product_cross_rotated(&a_rot, &b_rot)
252    }
253
254    #[inline]
255    pub(crate) fn projected_matrix(&self, matrix: &Array2<f64>) -> Array2<f64> {
256        let left = gam_linalg::faer_ndarray::fast_atb(&self.w_factor, matrix);
257        gam_linalg::faer_ndarray::fast_ab(&left, &self.w_factor)
258    }
259
260    #[inline]
261    pub(crate) fn projected_operator(
262        &self,
263        factor: &Array2<f64>,
264        op: &dyn HyperOperator,
265    ) -> Array2<f64> {
266        if log::log_enabled!(log::Level::Info) {
267            let start = std::time::Instant::now();
268            let result = op.projected_matrix_cached(factor, &self.projected_factor_cache);
269            let signature = format!(
270                "DenseSpectralOperator::projected_operator dim={} rank={} implicit={}",
271                self.n_dim,
272                factor.ncols(),
273                op.is_implicit(),
274            );
275            dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
276            result
277        } else {
278            op.projected_matrix_cached(factor, &self.projected_factor_cache)
279        }
280    }
281
282    #[inline]
283    pub(crate) fn trace_projected_cross(&self, left: &Array2<f64>, right: &Array2<f64>) -> f64 {
284        let mut result = 0.0;
285        for (left_row, right_col) in left.rows().into_iter().zip(right.columns().into_iter()) {
286            for (left_value, right_value) in left_row.iter().copied().zip(right_col.iter().copied())
287            {
288                result += left_value * right_value;
289            }
290        }
291        result
292    }
293
294    #[inline]
295    pub(crate) fn trace_logdet_hessian_cross_rotated(
296        &self,
297        h_i_rot: &Array2<f64>,
298        h_j_rot: &Array2<f64>,
299    ) -> f64 {
300        let mut result = 0.0;
301        for ((kernel_row, h_i_row), h_j_col) in self
302            .logdet_hessian_kernel
303            .rows()
304            .into_iter()
305            .zip(h_i_rot.rows().into_iter())
306            .zip(h_j_rot.columns().into_iter())
307        {
308            for ((kernel, h_i_value), h_j_value) in kernel_row
309                .iter()
310                .copied()
311                .zip(h_i_row.iter().copied())
312                .zip(h_j_col.iter().copied())
313            {
314                result += kernel * h_i_value * h_j_value;
315            }
316        }
317        result
318    }
319}
320
321/// Coalesce repeated identical `[STAGE]` log lines from `DenseSpectralOperator`
322/// methods. First occurrence of a (method, dims, implicit-flags) signature
323/// logs immediately; identical consecutive repeats are silenced and accrue
324/// into a counter, emitting heartbeat summaries at doubling cadence
325/// (2, 4, 8, 16, …) and a final summary when the signature changes.
326pub(crate) fn dense_spectral_stage_log(signature: &str, elapsed_s: f64) {
327    use std::sync::Mutex;
328    struct Repeat {
329        pub(crate) signature: String,
330        pub(crate) count: u64,
331        pub(crate) total: f64,
332        pub(crate) min: f64,
333        pub(crate) max: f64,
334        pub(crate) next_heartbeat: u64,
335    }
336    static REPEAT: Mutex<Option<Repeat>> = Mutex::new(None);
337
338    let mut guard = match REPEAT.lock() {
339        Ok(g) => g,
340        Err(poisoned) => poisoned.into_inner(),
341    };
342
343    if let Some(state) = guard.as_mut() {
344        if state.signature == signature {
345            state.count += 1;
346            state.total += elapsed_s;
347            if elapsed_s < state.min {
348                state.min = elapsed_s;
349            }
350            if elapsed_s > state.max {
351                state.max = elapsed_s;
352            }
353            if state.count >= state.next_heartbeat {
354                log::info!(
355                    "[STAGE] {} (×{} so far, total={:.3}s min={:.3}s max={:.3}s avg={:.3}s)",
356                    state.signature,
357                    state.count,
358                    state.total,
359                    state.min,
360                    state.max,
361                    state.total / state.count as f64,
362                );
363                state.next_heartbeat = state.next_heartbeat.saturating_mul(2);
364            }
365            return;
366        }
367        // Signature changed — flush a final summary for the previous one
368        // when it ran more than once (the first occurrence already logged
369        // its own line, so a count of 1 needs no follow-up).
370        if state.count > 1 {
371            log::info!(
372                "[STAGE] {} final ×{} total={:.3}s min={:.3}s max={:.3}s avg={:.3}s",
373                state.signature,
374                state.count,
375                state.total,
376                state.min,
377                state.max,
378                state.total / state.count as f64,
379            );
380        }
381    }
382
383    log::info!("[STAGE] {} elapsed={:.3}s", signature, elapsed_s);
384    *guard = Some(Repeat {
385        signature: signature.to_string(),
386        count: 1,
387        total: elapsed_s,
388        min: elapsed_s,
389        max: elapsed_s,
390        next_heartbeat: 2,
391    });
392}
393
394impl HessianOperator for DenseSpectralOperator {
395    fn logdet(&self) -> f64 {
396        self.cached_logdet
397    }
398
399    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
400        Some(self)
401    }
402
403    fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
404        Ok(assemble_h_raw_dense(self))
405    }
406
407    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
408        // tr(H_reg⁻¹ A) = Σ_j (1/r_ε(σ_j)) uⱼᵀAuⱼ
409        // Computed as Σ (AW ⊙ W) where W = U diag(1/√r_ε(σ)).
410        let aw = a.dot(&self.w_factor);
411        aw.iter()
412            .zip(self.w_factor.iter())
413            .map(|(&a, &w)| a * w)
414            .sum()
415    }
416
417    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
418        // H_reg⁻¹ v = Σ_j (1/r_ε(σ_j)) (uⱼᵀv) uⱼ.  Inactive eigenpairs
419        // (σ_j ≤ ε under `HardPseudo`) are skipped so the returned vector
420        // lives entirely in the active subspace — otherwise v_k picks up a
421        // huge spurious component along the numerical null space direction
422        // (coefficient ~ 1/r_ε(σ_j) for σ_j ≈ 0) that is not part of the
423        // IFT mode response `dβ̂/dρ` and would leak into the REML gradient.
424        let mut result = Array1::zeros(self.n_dim);
425        for j in 0..self.n_dim {
426            if !self.active_mask[j] {
427                continue;
428            }
429            let u = self.eigenvectors.column(j);
430            let coeff = u.dot(rhs) / self.reg_eigenvalues[j];
431            for row in 0..self.n_dim {
432                result[row] += coeff * u[row];
433            }
434        }
435        result
436    }
437
438    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
439        let mut projected = self.eigenvectors.t().dot(rhs);
440        for j in 0..self.n_dim {
441            if self.active_mask[j] {
442                let scale = 1.0 / self.reg_eigenvalues[j];
443                projected.row_mut(j).mapv_inplace(|value| value * scale);
444            } else {
445                // Zero out inactive eigendirections so `H⁺` acts on the
446                // active subspace only (mirroring the single-vector `solve`).
447                projected.row_mut(j).fill(0.0);
448            }
449        }
450        self.eigenvectors.dot(&projected)
451    }
452
453    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
454        self.trace_hinv_product_cross_dense(a, b)
455    }
456
457    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
458        if log::log_enabled!(log::Level::Info) {
459            let start = std::time::Instant::now();
460            let result =
461                op.trace_projected_factor_cached(&self.w_factor, &self.projected_factor_cache);
462            let signature = format!(
463                "DenseSpectralOperator::trace_hinv_operator dim={} rank={} implicit={}",
464                self.n_dim,
465                self.w_factor.ncols(),
466                op.is_implicit(),
467            );
468            dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
469            result
470        } else {
471            op.trace_projected_factor_cached(&self.w_factor, &self.projected_factor_cache)
472        }
473    }
474
475    fn trace_hinv_matrix_operator_cross(
476        &self,
477        matrix: &Array2<f64>,
478        op: &dyn HyperOperator,
479    ) -> f64 {
480        let left = self.w_factor.t().dot(matrix).dot(&self.w_factor);
481        let right = self.projected_operator(&self.w_factor, op);
482        self.trace_projected_cross(&left, &right)
483    }
484
485    fn trace_hinv_operator_cross(
486        &self,
487        left: &dyn HyperOperator,
488        right: &dyn HyperOperator,
489    ) -> f64 {
490        if log::log_enabled!(log::Level::Info) {
491            let start = std::time::Instant::now();
492            let left_proj = self.projected_operator(&self.w_factor, left);
493            let result = if std::ptr::addr_eq(left, right) {
494                self.trace_projected_cross(&left_proj, &left_proj)
495            } else {
496                let right_proj = self.projected_operator(&self.w_factor, right);
497                self.trace_projected_cross(&left_proj, &right_proj)
498            };
499            let signature = format!(
500                "DenseSpectralOperator::trace_hinv_operator_cross dim={} rank={} left_implicit={} right_implicit={}",
501                self.n_dim,
502                self.w_factor.ncols(),
503                left.is_implicit(),
504                right.is_implicit(),
505            );
506            dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
507            result
508        } else {
509            let left_proj = self.projected_operator(&self.w_factor, left);
510            if std::ptr::addr_eq(left, right) {
511                self.trace_projected_cross(&left_proj, &left_proj)
512            } else {
513                let right_proj = self.projected_operator(&self.w_factor, right);
514                self.trace_projected_cross(&left_proj, &right_proj)
515            }
516        }
517    }
518
519    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
520        // tr(G_ε(H) A) = Σ_j φ'(σ_j) uⱼᵀAuⱼ
521        // where φ'(σ) = 1/√(σ² + 4ε²).
522        // Computed as Σ (AG ⊙ G) where G = U diag(√φ'(σ)).
523        let ag = a.dot(&self.g_factor);
524        ag.iter()
525            .zip(self.g_factor.iter())
526            .map(|(&a, &g)| a * g)
527            .sum()
528    }
529
530    fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
531        // h^G_i = ‖(X G)_{i,:}‖² where G_ε = G Gᵀ and G = self.g_factor.
532        // The dominant cost at large scale is the (n × p)·(p × rank) matmul
533        // — for matern60 with n=320K, p=101 that's ~3.3 GFLOPs and the
534        // ndarray default `.dot()` runs single-threaded (no BLAS feature
535        // enabled in this crate's build), so we route through faer's parallel
536        // SIMD GEMM. For operator-backed (Lazy) designs we additionally
537        // stream by row chunk so we never materialize the full (n×p) block
538        // at large scale.
539        let n = x.nrows();
540        let p = x.ncols();
541        let rank = self.g_factor.ncols();
542        let mut h = Array1::<f64>::zeros(n);
543        if n == 0 || p == 0 || rank == 0 {
544            return h;
545        }
546        // Issue #922: offload this n-dependent pass to the device pool when a
547        // GPU was probed and n·p² clears the dispatch floor. The result is the
548        // same f64 arithmetic (X·G then row-wise ‖·‖²), just relocated across
549        // every device via `scatter_batched`; any failure falls through to the
550        // faer CPU stream below so the REML criterion is byte-for-byte
551        // unchanged on machines without a GPU.
552        if let Some(gpu) = gam_gpu::linalg_dispatch::try_fast_spectral_leverage_diagonal(
553            x,
554            self.g_factor.view(),
555        ) {
556            return gpu;
557        }
558        let chunk_rows = byte_balanced_row_chunk(p + rank, n);
559        let mut start = 0usize;
560        while start < n {
561            let end = (start + chunk_rows).min(n);
562            let rows = x.try_row_chunk(start..end).unwrap_or_else(|err| {
563                // SAFETY: `try_row_chunk` only fails on operator implementation
564                // bugs — the `start..end` range is constructed from
565                // `0..n = 0..x.nrows()` with `end = (start+block).min(n)`,
566                // so it is always a valid sub-range of `x`. A failure here
567                // means the operator violated its row-chunk contract.
568                // SAFETY: row range built from 0..x.nrows(); failure means operator broke its contract.
569                reml_contract_panic(format!(
570                    "xt_logdet_kernel_x_diagonal: row chunk failed: {err}"
571                ))
572            });
573            let xg = gam_linalg::faer_ndarray::fast_ab(&rows, &self.g_factor);
574            for (local, row) in xg.outer_iter().enumerate() {
575                h[start + local] = row.iter().map(|v| v * v).sum();
576            }
577            start = end;
578        }
579        h
580    }
581
582    fn trace_logdet_block_local(
583        &self,
584        block: &Array2<f64>,
585        scale: f64,
586        start: usize,
587        end: usize,
588    ) -> f64 {
589        // tr(G_ε A) = Σ (A·G ⊙ G) for block-local A.
590        // Only needs G[start..end, :] — O(block² × rank) instead of O(p² × rank).
591        let g_block = self.g_factor.slice(ndarray::s![start..end, ..]);
592        let ag = block.dot(&g_block);
593        scale
594            * ag.iter()
595                .zip(g_block.iter())
596                .map(|(&a, &g)| a * g)
597                .sum::<f64>()
598    }
599
600    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
601        if log::log_enabled!(log::Level::Info) {
602            let start = std::time::Instant::now();
603            let result =
604                op.trace_projected_factor_cached(&self.g_factor, &self.projected_factor_cache);
605            let signature = format!(
606                "DenseSpectralOperator::trace_logdet_operator dim={} rank={} implicit={}",
607                self.n_dim,
608                self.g_factor.ncols(),
609                op.is_implicit(),
610            );
611            dense_spectral_stage_log(&signature, start.elapsed().as_secs_f64());
612            result
613        } else {
614            op.trace_projected_factor_cached(&self.g_factor, &self.projected_factor_cache)
615        }
616    }
617
618    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
619        let hp_i = self.rotate_to_eigenbasis(h_i);
620        if std::ptr::eq(h_i, h_j) {
621            return self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_i);
622        }
623        let hp_j = self.rotate_to_eigenbasis(h_j);
624        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
625    }
626
627    fn trace_logdet_hessian_cross_matrix_operator(
628        &self,
629        h_i: &Array2<f64>,
630        h_j: &dyn HyperOperator,
631    ) -> f64 {
632        let hp_i = self.rotate_to_eigenbasis(h_i);
633        let hp_j = self.projected_operator(&self.eigenvectors, h_j);
634        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
635    }
636
637    fn trace_logdet_hessian_cross_operator(
638        &self,
639        h_i: &dyn HyperOperator,
640        h_j: &dyn HyperOperator,
641    ) -> f64 {
642        let hp_i = self.projected_operator(&self.eigenvectors, h_i);
643        if std::ptr::addr_eq(h_i, h_j) {
644            return self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_i);
645        }
646        let hp_j = self.projected_operator(&self.eigenvectors, h_j);
647        self.trace_logdet_hessian_cross_rotated(&hp_i, &hp_j)
648    }
649
650    fn active_rank(&self) -> usize {
651        self.active_mask.iter().filter(|&&active| active).count()
652    }
653
654    fn dim(&self) -> usize {
655        self.n_dim
656    }
657
658    fn is_dense(&self) -> bool {
659        true
660    }
661
662    fn prefers_stochastic_trace_estimation(&self) -> bool {
663        false
664    }
665
666    fn logdet_traces_match_hinv_kernel(&self) -> bool {
667        false
668    }
669
670    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
671        Some(self)
672    }
673}