gam_solve/arrow_schur/slq_logdet.rs
1//! Matrix-free log-determinant via Stochastic Lanczos Quadrature (SLQ).
2//!
3//! BIBLIOGRAPHY
4//!
5//! * Ubaru, Chen, Saad, "Fast Estimation of tr(f(A)) via Stochastic Lanczos
6//! Quadrature", SIAM J. Matrix Anal. Appl. 38(4), 2017: the canonical SLQ
7//! estimator for `tr(f(A))` with `f = ln` giving `log det A = tr(ln A)`.
8//! * Bai, Fahey, Golub, "Some large-scale matrix computation problems", J.
9//! Comput. Appl. Math. 74, 1996: Gauss-quadrature view of `uᵀ f(A) u` as
10//! `Σ_i (e₁ᵀ y_i)² f(θ_i)` over the Lanczos tridiagonal eigenpairs `(θ_i,y_i)`.
11//! * Hutchinson, "A stochastic estimator of the trace of the influence matrix",
12//! Comm. Statist. Simulation Comput. 19, 1990: Rademacher probe vectors with
13//! `E[zᵀ M z] = tr(M)` and `‖z‖² = dim`.
14//! * Golub, Meurant, "Matrices, Moments and Quadrature with Applications", 2010:
15//! Lanczos quadrature, the need for reorthogonalization, and error analysis.
16//!
17//! ## What this provides
18//!
19//! [`slq_logdet`] estimates `log det A` for a symmetric positive-definite
20//! operator `A` available ONLY through matrix-vector products `v ↦ A v`. It
21//! never forms or factors `A`, so for the reduced-Schur Laplace normaliser it
22//! replaces the dense `O(k³/3)` Cholesky log-determinant with
23//! `O(num_probes · lanczos_steps · matvec)` work.
24//!
25//! The estimator is `tr(ln A) ≈ (dim / num_probes) Σ_p zₚᵀ ln(A) zₚ` with
26//! Rademacher probes `zₚ`, and each quadratic form `zᵀ ln(A) z` is evaluated by
27//! `m` steps of Lanczos against `A` started from `z/‖z‖`: building the symmetric
28//! tridiagonal `T_m` (with FULL reorthogonalization against the stored basis),
29//! eigendecomposing it, and reading the Gauss quadrature
30//! `‖z‖² Σ_i (τ_{i,0})² ln(θ_i)` where `θ_i` are `T_m`'s eigenvalues and
31//! `τ_{i,0}` is the first component of the `i`-th eigenvector.
32//!
33//! ## Reuse
34//!
35//! The numerically-critical Lanczos recurrence + full reorthogonalization +
36//! tridiagonal eigendecomposition is the workspace primitive
37//! [`gam_linalg::lanczos::symmetric_lanczos_eigenpairs`]; this module is the
38//! Hutchinson outer loop (Rademacher probes, averaging, standard error) on top
39//! of it. The clamped log-quadrature is computed here (rather than via
40//! [`gam_linalg::lanczos::symmetric_lanczos_log_quadrature`], which errors on a
41//! non-positive Ritz value) so a round-off-negative Ritz value floors to a tiny
42//! positive number instead of failing the whole evidence solve.
43//!
44//! ## Determinism
45//!
46//! The probe vectors are drawn from [`gam_linalg::utils::splitmix64`] seeded by
47//! `seed + probe_index`; there is NO system-RNG dependence, so a given
48//! `(dim, matvec, num_probes, lanczos_steps, seed)` always returns the same
49//! estimate. This is required by the evidence path, whose REML outer loop must
50//! be reproducible.
51
52use super::*;
53use gam_linalg::lanczos::{symmetric_lanczos_eigenpairs, SymmetricLanczosOptions};
54use gam_linalg::utils::splitmix64;
55use rayon::iter::{IntoParallelIterator, ParallelIterator};
56
/// Outcome of a Stochastic Lanczos Quadrature run: a point estimate of
/// `log det A` together with its Monte-Carlo uncertainty.
#[derive(Clone, Copy, Debug)]
pub struct SlqLogDet {
    /// Point estimate of `log det A` (the mean of the per-probe contributions).
    pub estimate: f64,
    /// Standard error of [`Self::estimate`]: the sample standard deviation of
    /// the per-probe contributions divided by the square root of the number
    /// of probes. A single contributing probe exposes no spread, so the value
    /// is `0.0` in that case.
    pub std_err: f64,
}
67
/// Lower clamp applied to every Ritz eigenvalue before its logarithm is taken.
///
/// For an SPD operator the Ritz values `θ_i` are strictly positive in exact
/// arithmetic, so the clamp only ever bites on a zero or slightly negative
/// value manufactured by round-off, keeping `ln` finite. The magnitude is
/// deliberately far smaller than any physically meaningful curvature scale.
const RITZ_LN_FLOOR: f64 = 1.0e-300;
73
74/// Draw a deterministic Rademacher (±1) vector of length `dim` into `z`,
75/// seeded reproducibly by `probe_seed`. Two bits per draw are wasteful but the
76/// per-element top-bit read keeps this trivially correct and stream-stable.
77fn rademacher_into(z: &mut Array1<f64>, probe_seed: u64) {
78 let mut state = probe_seed;
79 let mut bits: u64 = 0;
80 let mut remaining: u32 = 0;
81 for value in z.iter_mut() {
82 if remaining == 0 {
83 bits = splitmix64(&mut state);
84 remaining = 64;
85 }
86 *value = if bits & 1 == 1 { 1.0 } else { -1.0 };
87 bits >>= 1;
88 remaining -= 1;
89 }
90}
91
92/// Estimate `log det A` for an SPD operator given only its matrix-vector apply.
93///
94/// * `dim` — dimension of the operator (`A` is `dim × dim`).
95/// * `matvec` — applies `A`: `matvec(v) = A v`, for `v.len() == dim`.
96/// * `num_probes` — number of Rademacher probe vectors (Hutchinson samples).
97/// * `lanczos_steps` — Lanczos iterations per probe (Gauss-quadrature nodes).
98/// * `seed` — base seed; probe `p` uses `seed + p`, so results are reproducible.
99///
100/// Returns the averaged estimate and its standard error. For `dim == 0` the
101/// determinant of the empty operator is `1`, so the log-determinant is `0`.
102///
103/// `lanczos_steps` is internally capped at `dim` (a Krylov subspace cannot
104/// exceed the dimension) and `num_probes` is treated as at least `1`.
105pub fn slq_logdet(
106 dim: usize,
107 matvec: impl Fn(ArrayView1<f64>) -> Array1<f64> + Sync,
108 num_probes: usize,
109 lanczos_steps: usize,
110 seed: u64,
111) -> SlqLogDet {
112 if dim == 0 {
113 return SlqLogDet {
114 estimate: 0.0,
115 std_err: 0.0,
116 };
117 }
118 let num_probes = num_probes.max(1);
119 let steps = lanczos_steps.max(1).min(dim);
120 let norm_sq = dim as f64; // ‖z‖² for a ±1 Rademacher vector of length `dim`.
121
122 let lanczos_options = SymmetricLanczosOptions {
123 max_steps: steps,
124 // Pure SPD log-det quadrature: keep iterating until the Krylov space is
125 // genuinely exhausted (a true lucky breakdown), not at a slack residual.
126 residual_tol: 0.0,
127 local_reorthogonalize: false,
128 // Full reorthogonalization is numerically essential for the quadrature:
129 // without it Lanczos loses orthogonality and produces ghost Ritz values.
130 full_reorthogonalize: true,
131 };
132
133 // Each Hutchinson probe is a FULLY INDEPENDENT Lanczos run against the same
134 // read-only (`Sync`) operator, so at the K=32k evidence scale — where SLQ
135 // fires precisely because the operator is large (`num_probes`×`lanczos_steps`
136 // matvecs of an `O(k²)` apply) — the probes fan out across rayon workers for
137 // a near-`num_probes`× wall-clock cut on the dominant matvec work. Each probe
138 // carries its OWN Rademacher vector and matvec input scratch (no shared
139 // mutable state), and the contribution it computes depends only on
140 // `(dim, matvec, probe_seed, options)`, so it is bit-identical to the serial
141 // build. `into_par_iter().collect()` preserves probe order, and the
142 // mean/std-err reduction below runs SERIALLY over that ordered buffer, so the
143 // estimate and std-error are bit-for-bit reproducible for a fixed
144 // `(dim, matvec, num_probes, lanczos_steps, seed)` — the determinism the REML
145 // evidence outer loop requires (see the module `Determinism` note).
146 let matvec = &matvec;
147 let contributions: Vec<f64> = (0..num_probes)
148 .into_par_iter()
149 .map(|probe| {
150 let probe_seed = seed.wrapping_add(probe as u64);
151 let mut z = Array1::<f64>::zeros(dim);
152 rademacher_into(&mut z, probe_seed);
153 // The workspace Lanczos engine consumes `apply(&[f64], &mut [f64])`;
154 // wrap the ndarray `matvec` into that slice contract with a per-probe
155 // input buffer so probes never share mutable scratch.
156 let mut in_buf = Array1::<f64>::zeros(dim);
157 let mut apply = |x: &[f64], out: &mut [f64]| -> Result<(), String> {
158 in_buf
159 .as_slice_mut()
160 .expect("contiguous probe input buffer")
161 .copy_from_slice(x);
162 let y = matvec(in_buf.view());
163 if y.len() != dim {
164 return Err(format!(
165 "slq_logdet matvec returned length {}, expected {dim}",
166 y.len()
167 ));
168 }
169 out.copy_from_slice(y.as_slice().expect("contiguous matvec output"));
170 Ok(())
171 };
172 let start = z.as_slice().expect("contiguous probe vector");
173 match symmetric_lanczos_eigenpairs(dim, start, lanczos_options, &mut apply) {
174 Ok(pairs) => {
175 norm_sq * clamped_log_quadrature(&pairs.eigenvalues, &pairs.eigenvectors)
176 }
177 // A Lanczos failure (non-finite matvec / start) cannot be silently
178 // averaged in; the dense-Cholesky gate above this call should have
179 // caught a degenerate operator. Treat it as a zero contribution and
180 // let the std-error widen rather than poisoning the mean with NaN.
181 Err(_) => 0.0,
182 }
183 })
184 .collect();
185
186 let n = contributions.len() as f64;
187 let mean = contributions.iter().sum::<f64>() / n;
188 let std_err = if contributions.len() > 1 {
189 let var = contributions
190 .iter()
191 .map(|c| {
192 let d = c - mean;
193 d * d
194 })
195 .sum::<f64>()
196 / (n - 1.0);
197 (var / n).sqrt()
198 } else {
199 0.0
200 };
201
202 SlqLogDet {
203 estimate: mean,
204 std_err,
205 }
206}
207
208/// Gauss quadrature `e₁ᵀ ln(T) e₁ = Σ_i (τ_{i,0})² ln(θ_i)` over the Lanczos
209/// tridiagonal eigenpairs, with `θ_i` floored to [`RITZ_LN_FLOOR`] so a
210/// round-off-negative Ritz value (the SPD operator forbids genuine ones) cannot
211/// produce a `NaN`. `eigenvectors` columns are the Ritz vectors `y_i`; `τ_{i,0}`
212/// is their first component.
213fn clamped_log_quadrature(eigenvalues: &Array1<f64>, eigenvectors: &Array2<f64>) -> f64 {
214 let mut quad = 0.0_f64;
215 for i in 0..eigenvalues.len() {
216 let tau0 = eigenvectors[[0, i]];
217 let weight = tau0 * tau0;
218 let lambda = eigenvalues[i].max(RITZ_LN_FLOOR);
219 quad += weight * lambda.ln();
220 }
221 quad
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Map the next SplitMix64 output onto a uniform value in `[lo, hi)`, so
    /// the fixtures are reproducible without any external RNG crate.
    fn next_uniform(state: &mut u64, lo: f64, hi: f64) -> f64 {
        // Keep the top 53 bits (one f64 mantissa) and scale them into [0, 1).
        const MANTISSA_RANGE: f64 = (1u64 << 53) as f64;
        let mantissa = splitmix64(state) >> 11;
        let fraction = (mantissa as f64) / MANTISSA_RANGE;
        lo + (hi - lo) * fraction
    }

    /// Deterministic SPD fixture `A = MᵀM + δI` of size `dim × dim`, where `M`
    /// is `m_rows × dim` with entries uniform in `[-1, 1)` drawn from `seed`.
    /// `delta` lifts the whole spectrum, so a larger `delta` gives a better
    /// conditioned matrix.
    fn random_spd(dim: usize, m_rows: usize, delta: f64, seed: u64) -> Array2<f64> {
        let mut rng = seed;
        let mut factor = Array2::<f64>::zeros((m_rows, dim));
        factor
            .iter_mut()
            .for_each(|entry| *entry = next_uniform(&mut rng, -1.0, 1.0));

        let mut gram = factor.t().dot(&factor);
        for row in 0..dim {
            // Shift the diagonal, then average each off-diagonal pair so the
            // result is exactly symmetric despite round-off in the product.
            gram[[row, row]] += delta;
            for col in (row + 1)..dim {
                let mean = 0.5 * (gram[[row, col]] + gram[[col, row]]);
                gram[[row, col]] = mean;
                gram[[col, row]] = mean;
            }
        }
        gram
    }

    /// Reference `log det A = Σ ln λ_i` from a dense symmetric
    /// eigendecomposition, with the same eigenvalue floor the estimator uses.
    fn exact_logdet(a: &Array2<f64>) -> f64 {
        let (spectrum, _) = a.eigh(Side::Lower).expect("SPD eigendecomposition");
        spectrum
            .iter()
            .map(|&lambda| lambda.max(RITZ_LN_FLOOR).ln())
            .sum()
    }

    /// Ratio of the largest to the smallest eigenvalue of `a`.
    fn condition_number(a: &Array2<f64>) -> f64 {
        let (spectrum, _) = a.eigh(Side::Lower).expect("SPD eigendecomposition");
        let (smallest, largest) = spectrum
            .iter()
            .fold((f64::MAX, f64::MIN), |(lo, hi), &lambda| {
                (lo.min(lambda), hi.max(lambda))
            });
        largest / smallest
    }

    #[test]
    fn slq_matches_exact_logdet_well_conditioned() {
        // Three sizes between 60 and 200, each with a generous δ so the matrix
        // is well-conditioned, compared against the dense reference value.
        let cases: [(usize, u64); 3] = [(60, 1), (120, 2), (200, 3)];
        for (dim, seed) in cases {
            let a = random_spd(dim, dim + 40, 5.0, seed);
            let exact = exact_logdet(&a);
            let cond = condition_number(&a);

            let result = slq_logdet(dim, |v| a.dot(&v), 48, 70, 0xA5A5_0000 ^ seed);

            let abs_err = (result.estimate - exact).abs();
            let rel_err = abs_err / exact.abs();
            eprintln!(
                "well-conditioned dim={dim} cond={cond:.2e} exact={exact:.6} \
                 est={:.6} rel_err={rel_err:.4e} std_err={:.4e}",
                result.estimate, result.std_err
            );
            assert!(
                rel_err < 0.05,
                "dim={dim}: SLQ relative error {rel_err:.4e} exceeds 5% \
                 (exact={exact}, est={})",
                result.estimate
            );
            // Intended as a "within a few standard errors" check. Note that the
            // 5% slack term alone already covers the bound asserted above, so
            // this assertion cannot fail once the previous one has passed.
            assert!(
                abs_err < 3.0 * result.std_err + 0.05 * exact.abs(),
                "dim={dim}: estimate not within ~3 std_err of exact \
                 (|Δ|={:.4e}, std_err={:.4e})",
                abs_err,
                result.std_err
            );
        }
    }

    #[test]
    fn slq_handles_moderately_ill_conditioned() {
        // A small δ leaves the bottom of the spectrum close to zero, giving a
        // wider spectrum; more Lanczos steps are used to resolve it.
        let dim = 150usize;
        let a = random_spd(dim, dim + 5, 0.05, 7);
        let exact = exact_logdet(&a);
        let cond = condition_number(&a);
        assert!(
            cond > 1e3,
            "test fixture should be moderately ill-conditioned, got cond={cond:.2e}"
        );

        let result = slq_logdet(dim, |v| a.dot(&v), 40, 110, 0xC0FFEE);
        let rel_err = (result.estimate - exact).abs() / exact.abs();
        eprintln!(
            "ill-conditioned dim={dim} cond={cond:.2e} exact={exact:.6} \
             est={:.6} rel_err={rel_err:.4e} std_err={:.4e}",
            result.estimate, result.std_err
        );
        assert!(
            rel_err < 0.10,
            "ill-conditioned dim={dim}: SLQ relative error {rel_err:.4e} \
             exceeds 10% (cond={cond:.2e}, exact={exact}, est={})",
            result.estimate
        );
    }

    #[test]
    fn slq_is_deterministic_for_fixed_seed() {
        // Two runs with identical arguments must agree exactly, not merely to
        // within a tolerance.
        let dim = 80usize;
        let a = random_spd(dim, dim + 20, 2.0, 11);
        let run = || slq_logdet(dim, |v| a.dot(&v), 24, 50, 99);
        let r1 = run();
        let r2 = run();
        assert_eq!(
            r1.estimate, r2.estimate,
            "SLQ must be bit-reproducible for a fixed seed"
        );
        assert_eq!(r1.std_err, r2.std_err);
    }

    #[test]
    fn slq_diagonal_operator_matches_closed_form() {
        // For a diagonal operator `log det` is simply Σ ln d_i, and the matvec
        // is an element-wise scaling, so no matrix is ever assembled.
        let dim = 100usize;
        let mut state = 123u64;
        let diag: Vec<f64> = (0..dim)
            .map(|_| next_uniform(&mut state, 0.5, 4.0))
            .collect();
        let exact: f64 = diag.iter().map(|d| d.ln()).sum();

        let scale_by_diag = move |v: ArrayView1<f64>| {
            let mut out = v.to_owned();
            out.iter_mut().zip(diag.iter()).for_each(|(o, d)| *o *= d);
            out
        };
        let result = slq_logdet(dim, scale_by_diag, 32, 60, 7);

        let rel_err = (result.estimate - exact).abs() / exact.abs();
        eprintln!(
            "diagonal dim={dim} exact={exact:.6} est={:.6} rel_err={rel_err:.4e}",
            result.estimate
        );
        assert!(
            rel_err < 0.05,
            "diagonal operator: relative error {rel_err:.4e} exceeds 5%"
        );
    }

    #[test]
    fn slq_empty_operator_is_zero() {
        // The empty operator has determinant 1, hence log-determinant 0.
        let result = slq_logdet(0, |v| v.to_owned(), 8, 8, 1);
        assert_eq!(result.estimate, 0.0);
        assert_eq!(result.std_err, 0.0);
    }

    #[test]
    fn std_err_shrinks_with_more_probes() {
        // A Monte-Carlo mean's standard error decays like 1/sqrt(num_probes),
        // so 96 probes should report a tighter band than 6.
        let dim = 120usize;
        let a = random_spd(dim, dim + 30, 3.0, 21);
        let estimate_with = |probes: usize| slq_logdet(dim, |v| a.dot(&v), probes, 60, 5);
        let few = estimate_with(6);
        let many = estimate_with(96);
        eprintln!(
            "std_err few(6)={:.4e} many(96)={:.4e}",
            few.std_err, many.std_err
        );
        assert!(
            many.std_err < few.std_err,
            "more probes should reduce std_err (few={:.4e}, many={:.4e})",
            few.std_err,
            many.std_err
        );
    }
}