Skip to main content

gam_solve/arrow_schur/
slq_logdet.rs

1//! Matrix-free log-determinant via Stochastic Lanczos Quadrature (SLQ).
2//!
3//! BIBLIOGRAPHY
4//!
5//! * Ubaru, Chen, Saad, "Fast Estimation of tr(f(A)) via Stochastic Lanczos
6//!   Quadrature", SIAM J. Matrix Anal. Appl. 38(4), 2017: the canonical SLQ
7//!   estimator for `tr(f(A))` with `f = ln` giving `log det A = tr(ln A)`.
8//! * Bai, Fahey, Golub, "Some large-scale matrix computation problems", J.
9//!   Comput. Appl. Math. 74, 1996: Gauss-quadrature view of `uᵀ f(A) u` as
10//!   `Σ_i (e₁ᵀ y_i)² f(θ_i)` over the Lanczos tridiagonal eigenpairs `(θ_i,y_i)`.
11//! * Hutchinson, "A stochastic estimator of the trace of the influence matrix",
12//!   Comm. Statist. Simulation Comput. 19, 1990: Rademacher probe vectors with
13//!   `E[zᵀ M z] = tr(M)` and `‖z‖² = dim`.
14//! * Golub, Meurant, "Matrices, Moments and Quadrature with Applications", 2010:
15//!   Lanczos quadrature, the need for reorthogonalization, and error analysis.
16//!
17//! ## What this provides
18//!
19//! [`slq_logdet`] estimates `log det A` for a symmetric positive-definite
20//! operator `A` available ONLY through matrix-vector products `v ↦ A v`. It
21//! never forms or factors `A`, so for the reduced-Schur Laplace normaliser it
//! replaces the dense Cholesky log-determinant (`≈ k³/3` flops, i.e. `O(k³)`) with
23//! `O(num_probes · lanczos_steps · matvec)` work.
24//!
//! The estimator is `tr(ln A) ≈ (1 / num_probes) Σ_p zₚᵀ ln(A) zₚ` with
26//! Rademacher probes `zₚ`, and each quadratic form `zᵀ ln(A) z` is evaluated by
27//! `m` steps of Lanczos against `A` started from `z/‖z‖`: building the symmetric
28//! tridiagonal `T_m` (with FULL reorthogonalization against the stored basis),
29//! eigendecomposing it, and reading the Gauss quadrature
30//! `‖z‖² Σ_i (τ_{i,0})² ln(θ_i)` where `θ_i` are `T_m`'s eigenvalues and
31//! `τ_{i,0}` is the first component of the `i`-th eigenvector.
32//!
33//! ## Reuse
34//!
35//! The numerically-critical Lanczos recurrence + full reorthogonalization +
36//! tridiagonal eigendecomposition is the workspace primitive
37//! [`gam_linalg::lanczos::symmetric_lanczos_eigenpairs`]; this module is the
38//! Hutchinson outer loop (Rademacher probes, averaging, standard error) on top
39//! of it. The clamped log-quadrature is computed here (rather than via
40//! [`gam_linalg::lanczos::symmetric_lanczos_log_quadrature`], which errors on a
41//! non-positive Ritz value) so a round-off-negative Ritz value floors to a tiny
42//! positive number instead of failing the whole evidence solve.
43//!
44//! ## Determinism
45//!
46//! The probe vectors are drawn from [`gam_linalg::utils::splitmix64`] seeded by
47//! `seed + probe_index`; there is NO system-RNG dependence, so a given
48//! `(dim, matvec, num_probes, lanczos_steps, seed)` always returns the same
49//! estimate. This is required by the evidence path, whose REML outer loop must
50//! be reproducible.
51
52use super::*;
53use gam_linalg::lanczos::{symmetric_lanczos_eigenpairs, SymmetricLanczosOptions};
54use gam_linalg::utils::splitmix64;
55use rayon::iter::{IntoParallelIterator, ParallelIterator};
56
/// Result of a Stochastic Lanczos Quadrature log-determinant estimate.
///
/// Returned by value from [`slq_logdet`]; both fields are plain `f64`s, so the
/// struct is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct SlqLogDet {
    /// Estimate of `log det A`: the mean of the per-probe contributions.
    pub estimate: f64,
    /// Standard error of the estimate: the sample standard deviation of the
    /// per-probe contributions (computed with the `n - 1` denominator) divided
    /// by `sqrt(num_probes)`. With a single probe this is `0.0` (no spread is
    /// observable). It is derived from the probe-to-probe spread only, so it
    /// does not account for any error from truncating the Lanczos quadrature.
    pub std_err: f64,
}
67
/// Floor on Ritz eigenvalues before taking `ln`. The operator is SPD so the
/// Ritz values `θ_i` are positive in exact arithmetic; this clamps any tiny
/// negative/zero value produced by round-off so `ln` stays finite. Chosen far
/// below any physically meaningful curvature scale.
///
/// At the floor, `ln(1e-300) ≈ -690.8`, so a clamped node still contributes a
/// finite term scaled by its quadrature weight.
const RITZ_LN_FLOOR: f64 = 1e-300;
73
74/// Draw a deterministic Rademacher (±1) vector of length `dim` into `z`,
75/// seeded reproducibly by `probe_seed`. Two bits per draw are wasteful but the
76/// per-element top-bit read keeps this trivially correct and stream-stable.
77fn rademacher_into(z: &mut Array1<f64>, probe_seed: u64) {
78    let mut state = probe_seed;
79    let mut bits: u64 = 0;
80    let mut remaining: u32 = 0;
81    for value in z.iter_mut() {
82        if remaining == 0 {
83            bits = splitmix64(&mut state);
84            remaining = 64;
85        }
86        *value = if bits & 1 == 1 { 1.0 } else { -1.0 };
87        bits >>= 1;
88        remaining -= 1;
89    }
90}
91
92/// Estimate `log det A` for an SPD operator given only its matrix-vector apply.
93///
94/// * `dim` — dimension of the operator (`A` is `dim × dim`).
95/// * `matvec` — applies `A`: `matvec(v) = A v`, for `v.len() == dim`.
96/// * `num_probes` — number of Rademacher probe vectors (Hutchinson samples).
97/// * `lanczos_steps` — Lanczos iterations per probe (Gauss-quadrature nodes).
98/// * `seed` — base seed; probe `p` uses `seed + p`, so results are reproducible.
99///
100/// Returns the averaged estimate and its standard error. For `dim == 0` the
101/// determinant of the empty operator is `1`, so the log-determinant is `0`.
102///
103/// `lanczos_steps` is internally capped at `dim` (a Krylov subspace cannot
104/// exceed the dimension) and `num_probes` is treated as at least `1`.
105pub fn slq_logdet(
106    dim: usize,
107    matvec: impl Fn(ArrayView1<f64>) -> Array1<f64> + Sync,
108    num_probes: usize,
109    lanczos_steps: usize,
110    seed: u64,
111) -> SlqLogDet {
112    if dim == 0 {
113        return SlqLogDet {
114            estimate: 0.0,
115            std_err: 0.0,
116        };
117    }
118    let num_probes = num_probes.max(1);
119    let steps = lanczos_steps.max(1).min(dim);
120    let norm_sq = dim as f64; // ‖z‖² for a ±1 Rademacher vector of length `dim`.
121
122    let lanczos_options = SymmetricLanczosOptions {
123        max_steps: steps,
124        // Pure SPD log-det quadrature: keep iterating until the Krylov space is
125        // genuinely exhausted (a true lucky breakdown), not at a slack residual.
126        residual_tol: 0.0,
127        local_reorthogonalize: false,
128        // Full reorthogonalization is numerically essential for the quadrature:
129        // without it Lanczos loses orthogonality and produces ghost Ritz values.
130        full_reorthogonalize: true,
131    };
132
133    // Each Hutchinson probe is a FULLY INDEPENDENT Lanczos run against the same
134    // read-only (`Sync`) operator, so at the K=32k evidence scale — where SLQ
135    // fires precisely because the operator is large (`num_probes`×`lanczos_steps`
136    // matvecs of an `O(k²)` apply) — the probes fan out across rayon workers for
137    // a near-`num_probes`× wall-clock cut on the dominant matvec work. Each probe
138    // carries its OWN Rademacher vector and matvec input scratch (no shared
139    // mutable state), and the contribution it computes depends only on
140    // `(dim, matvec, probe_seed, options)`, so it is bit-identical to the serial
141    // build. `into_par_iter().collect()` preserves probe order, and the
142    // mean/std-err reduction below runs SERIALLY over that ordered buffer, so the
143    // estimate and std-error are bit-for-bit reproducible for a fixed
144    // `(dim, matvec, num_probes, lanczos_steps, seed)` — the determinism the REML
145    // evidence outer loop requires (see the module `Determinism` note).
146    let matvec = &matvec;
147    let contributions: Vec<f64> = (0..num_probes)
148        .into_par_iter()
149        .map(|probe| {
150            let probe_seed = seed.wrapping_add(probe as u64);
151            let mut z = Array1::<f64>::zeros(dim);
152            rademacher_into(&mut z, probe_seed);
153            // The workspace Lanczos engine consumes `apply(&[f64], &mut [f64])`;
154            // wrap the ndarray `matvec` into that slice contract with a per-probe
155            // input buffer so probes never share mutable scratch.
156            let mut in_buf = Array1::<f64>::zeros(dim);
157            let mut apply = |x: &[f64], out: &mut [f64]| -> Result<(), String> {
158                in_buf
159                    .as_slice_mut()
160                    .expect("contiguous probe input buffer")
161                    .copy_from_slice(x);
162                let y = matvec(in_buf.view());
163                if y.len() != dim {
164                    return Err(format!(
165                        "slq_logdet matvec returned length {}, expected {dim}",
166                        y.len()
167                    ));
168                }
169                out.copy_from_slice(y.as_slice().expect("contiguous matvec output"));
170                Ok(())
171            };
172            let start = z.as_slice().expect("contiguous probe vector");
173            match symmetric_lanczos_eigenpairs(dim, start, lanczos_options, &mut apply) {
174                Ok(pairs) => {
175                    norm_sq * clamped_log_quadrature(&pairs.eigenvalues, &pairs.eigenvectors)
176                }
177                // A Lanczos failure (non-finite matvec / start) cannot be silently
178                // averaged in; the dense-Cholesky gate above this call should have
179                // caught a degenerate operator. Treat it as a zero contribution and
180                // let the std-error widen rather than poisoning the mean with NaN.
181                Err(_) => 0.0,
182            }
183        })
184        .collect();
185
186    let n = contributions.len() as f64;
187    let mean = contributions.iter().sum::<f64>() / n;
188    let std_err = if contributions.len() > 1 {
189        let var = contributions
190            .iter()
191            .map(|c| {
192                let d = c - mean;
193                d * d
194            })
195            .sum::<f64>()
196            / (n - 1.0);
197        (var / n).sqrt()
198    } else {
199        0.0
200    };
201
202    SlqLogDet {
203        estimate: mean,
204        std_err,
205    }
206}
207
208/// Gauss quadrature `e₁ᵀ ln(T) e₁ = Σ_i (τ_{i,0})² ln(θ_i)` over the Lanczos
209/// tridiagonal eigenpairs, with `θ_i` floored to [`RITZ_LN_FLOOR`] so a
210/// round-off-negative Ritz value (the SPD operator forbids genuine ones) cannot
211/// produce a `NaN`. `eigenvectors` columns are the Ritz vectors `y_i`; `τ_{i,0}`
212/// is their first component.
213fn clamped_log_quadrature(eigenvalues: &Array1<f64>, eigenvectors: &Array2<f64>) -> f64 {
214    let mut quad = 0.0_f64;
215    for i in 0..eigenvalues.len() {
216        let tau0 = eigenvectors[[0, i]];
217        let weight = tau0 * tau0;
218        let lambda = eigenvalues[i].max(RITZ_LN_FLOOR);
219        quad += weight * lambda.ln();
220    }
221    quad
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic uniform draw in `[lo, hi)` from a SplitMix64 state — keeps
    /// the test fixtures reproducible with no external RNG dependency.
    fn next_uniform(state: &mut u64, lo: f64, hi: f64) -> f64 {
        // 53-bit mantissa fraction in [0, 1).
        let bits = splitmix64(state) >> 11;
        let unit = (bits as f64) / ((1u64 << 53) as f64);
        lo + (hi - lo) * unit
    }

    /// Build a random SPD matrix `A = MᵀM + δI` (`dim × dim`) from a fixed seed.
    /// `m_rows ≥ dim` keeps `MᵀM` well-conditioned; `delta` sets the floor on the
    /// spectrum (larger `delta` ⇒ better conditioned).
    fn random_spd(dim: usize, m_rows: usize, delta: f64, seed: u64) -> Array2<f64> {
        let mut state = seed;
        let mut m = Array2::<f64>::zeros((m_rows, dim));
        for value in m.iter_mut() {
            *value = next_uniform(&mut state, -1.0, 1.0);
        }
        let mut a = m.t().dot(&m);
        for i in 0..dim {
            a[[i, i]] += delta;
        }
        // Symmetrize defensively against round-off.
        for i in 0..dim {
            for j in (i + 1)..dim {
                let avg = 0.5 * (a[[i, j]] + a[[j, i]]);
                a[[i, j]] = avg;
                a[[j, i]] = avg;
            }
        }
        a
    }

    /// Exact `log det A` via the workspace symmetric eigensolver (`Σ ln λ_i`).
    /// Eigenvalues are floored like the estimator's Ritz values so the
    /// reference stays finite.
    fn exact_logdet(a: &Array2<f64>) -> f64 {
        let (evals, _) = a.eigh(Side::Lower).expect("SPD eigendecomposition");
        evals.iter().map(|&l| l.max(RITZ_LN_FLOOR).ln()).sum()
    }

    /// Spectral condition number `λ_max / λ_min` of a symmetric matrix.
    fn condition_number(a: &Array2<f64>) -> f64 {
        let (evals, _) = a.eigh(Side::Lower).expect("SPD eigendecomposition");
        let max = evals.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let min = evals.iter().copied().fold(f64::INFINITY, f64::min);
        max / min
    }

    #[test]
    fn slq_matches_exact_logdet_well_conditioned() {
        // A spread of dimensions in the 60–200 range, all well-conditioned
        // (generous δ), checked against the exact eigenvalue log-determinant.
        for (dim, seed) in [(60usize, 1u64), (120, 2), (200, 3)] {
            let a = random_spd(dim, dim + 40, 5.0, seed);
            let exact = exact_logdet(&a);
            let cond = condition_number(&a);

            let result = slq_logdet(dim, |v| a.dot(&v), 48, 70, 0xA5A5_0000 ^ seed);

            let abs_err = (result.estimate - exact).abs();
            let rel_err = abs_err / exact.abs();
            eprintln!(
                "well-conditioned dim={dim} cond={cond:.2e} exact={exact:.6} \
                 est={:.6} rel_err={rel_err:.4e} std_err={:.4e}",
                result.estimate, result.std_err
            );
            assert!(
                rel_err < 0.05,
                "dim={dim}: SLQ relative error {rel_err:.4e} exceeds 5% \
                 (exact={exact}, est={})",
                result.estimate
            );
            // The std_err must be a meaningful uncertainty band: strictly
            // positive for a non-diagonal operator, and wide enough that the
            // exact value sits within a few standard errors of the estimate.
            // No relative slack is added here — a `0.05 * |exact|` allowance
            // would already be implied by the 5% check above and would leave
            // std_err untested.
            assert!(
                result.std_err.is_finite() && result.std_err > 0.0,
                "dim={dim}: std_err should be finite and positive, got {:.4e}",
                result.std_err
            );
            assert!(
                abs_err < 5.0 * result.std_err,
                "dim={dim}: estimate not within 5 std_err of exact \
                 (|Δ|={:.4e}, std_err={:.4e})",
                abs_err,
                result.std_err
            );
        }
    }

    #[test]
    fn slq_handles_moderately_ill_conditioned() {
        // Smaller δ ⇒ a tighter spectral floor ⇒ a more ill-conditioned A.
        // More Lanczos steps resolve the wider spectrum.
        let dim = 150usize;
        let a = random_spd(dim, dim + 5, 0.05, 7);
        let exact = exact_logdet(&a);
        let cond = condition_number(&a);
        assert!(
            cond > 1e3,
            "test fixture should be moderately ill-conditioned, got cond={cond:.2e}"
        );

        let result = slq_logdet(dim, |v| a.dot(&v), 40, 110, 0xC0FFEE);
        let rel_err = (result.estimate - exact).abs() / exact.abs();
        eprintln!(
            "ill-conditioned dim={dim} cond={cond:.2e} exact={exact:.6} \
             est={:.6} rel_err={rel_err:.4e} std_err={:.4e}",
            result.estimate, result.std_err
        );
        assert!(
            rel_err < 0.10,
            "ill-conditioned dim={dim}: SLQ relative error {rel_err:.4e} \
             exceeds 10% (cond={cond:.2e}, exact={exact}, est={})",
            result.estimate
        );
    }

    #[test]
    fn slq_is_deterministic_for_fixed_seed() {
        let dim = 80usize;
        let a = random_spd(dim, dim + 20, 2.0, 11);
        let r1 = slq_logdet(dim, |v| a.dot(&v), 24, 50, 99);
        let r2 = slq_logdet(dim, |v| a.dot(&v), 24, 50, 99);
        assert_eq!(
            r1.estimate, r2.estimate,
            "SLQ must be bit-reproducible for a fixed seed"
        );
        assert_eq!(r1.std_err, r2.std_err);
    }

    #[test]
    fn slq_diagonal_operator_matches_closed_form() {
        // A diagonal operator has a closed-form log-determinant Σ ln d_i; this
        // exercises the matvec closure path without any matrix assembly.
        let dim = 100usize;
        let mut state = 123u64;
        let diag: Vec<f64> = (0..dim).map(|_| next_uniform(&mut state, 0.5, 4.0)).collect();
        let exact: f64 = diag.iter().map(|d| d.ln()).sum();

        // `diag` is not needed after `exact`, so it is moved into the operator
        // closure directly instead of being cloned first.
        let result = slq_logdet(
            dim,
            move |v| {
                let mut out = v.to_owned();
                for (o, d) in out.iter_mut().zip(diag.iter()) {
                    *o *= d;
                }
                out
            },
            32,
            60,
            7,
        );
        let rel_err = (result.estimate - exact).abs() / exact.abs();
        eprintln!(
            "diagonal dim={dim} exact={exact:.6} est={:.6} rel_err={rel_err:.4e}",
            result.estimate
        );
        assert!(
            rel_err < 0.05,
            "diagonal operator: relative error {rel_err:.4e} exceeds 5%"
        );
    }

    #[test]
    fn slq_empty_operator_is_zero() {
        let result = slq_logdet(0, |v| v.to_owned(), 8, 8, 1);
        assert_eq!(result.estimate, 0.0);
        assert_eq!(result.std_err, 0.0);
    }

    #[test]
    fn std_err_shrinks_with_more_probes() {
        // The standard error of a Monte-Carlo mean falls ~1/sqrt(num_probes);
        // many probes should give a tighter band than few.
        let dim = 120usize;
        let a = random_spd(dim, dim + 30, 3.0, 21);
        let few = slq_logdet(dim, |v| a.dot(&v), 6, 60, 5);
        let many = slq_logdet(dim, |v| a.dot(&v), 96, 60, 5);
        eprintln!(
            "std_err few(6)={:.4e} many(96)={:.4e}",
            few.std_err, many.std_err
        );
        assert!(
            many.std_err < few.std_err,
            "more probes should reduce std_err (few={:.4e}, many={:.4e})",
            few.std_err,
            many.std_err
        );
    }
}