// File: oxiblas_sparse/stochastic.rs

1//! Stochastic trace and diagonal estimation algorithms for sparse matrices.
2//!
3//! This module implements randomised estimation methods for matrix functionals
4//! that are too expensive to compute exactly on large sparse matrices.
5//!
6//! # Algorithms
7//!
8//! - **Hutchinson**: Unbiased trace estimator using Rademacher probes, O(m·nnz).
9//! - **Hutch++**: Improved variance via a small deterministic sketch + stochastic correction.
10//! - **XTrace**: State-of-the-art minimum-variance trace estimator (Epperly et al. 2022).
11//! - **Bekas diagonal**: Per-element diagonal estimator using probe vectors.
12//! - **Frobenius norm**: Stochastic estimate of ‖A‖_F via ‖A z‖² probes.
13//! - **log-determinant**: `log det A` via stochastic Lanczos quadrature (SPD matrices).
14//!
15//! # References
16//!
17//! - Hutchinson (1989) "A stochastic estimator of the trace of the influence matrix".
18//! - Meyer, Musco, Musco, Woodruff (2021) "Hutch++: Optimal Stochastic Trace Estimation".
19//! - Epperly, Tropp, Webber (2022) "XTrace: Making the most of every sample in stochastic trace estimation".
20//! - Bekas, Kokiopoulou, Saad (2007) "An estimator for the diagonal of a matrix".
21//! - Bai, Fahey, Golub (1996) "Some large-scale matrix computation problems".
22
23use crate::csr::CsrMatrix;
24use std::fmt;
25
26// =============================================================================
27// Public configuration types
28// =============================================================================
29
/// Kind of random probe vector drawn by the stochastic estimators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbeType {
    /// Uniform over {-1, +1}; the classical choice for unbiased trace probes.
    Rademacher,
    /// Standard normal N(0, 1) entries.
    Gaussian,
    /// Uniform on the unit sphere (a normalised Gaussian draw).
    Spherical,
}

/// Tunable knobs shared by all stochastic estimators.
#[derive(Debug, Clone)]
pub struct StochasticConfig {
    /// How many probe vectors to draw (default 30).
    pub num_probes: usize,
    /// Base RNG seed for reproducible runs (default 42).
    pub seed: u64,
    /// Distribution the probe vectors are drawn from (default Rademacher).
    pub probe_type: ProbeType,
    /// Nominal confidence level attached to error estimates (default 0.95).
    /// Informational only: it does not change any computation.
    pub confidence: f64,
}

impl Default for StochasticConfig {
    fn default() -> Self {
        StochasticConfig {
            num_probes: 30,
            seed: 42,
            probe_type: ProbeType::Rademacher,
            confidence: 0.95,
        }
    }
}
64
65// =============================================================================
66// Result types
67// =============================================================================
68
/// Outcome of a stochastic trace estimation run.
#[derive(Debug, Clone)]
pub struct TraceEstimate {
    /// Estimated value of tr(A).
    pub estimate: f64,
    /// Standard error of the mean over the per-probe samples.
    pub std_error: f64,
    /// How many probe vectors contributed to the estimate.
    pub n_probes_used: usize,
}
79
/// Outcome of a stochastic diagonal estimation run.
#[derive(Debug, Clone)]
pub struct DiagEstimate<T> {
    /// Estimated diagonal entries, one per row/column of the matrix.
    pub diagonal: Vec<T>,
    /// Per-entry standard error of the mean.
    pub std_error: Vec<T>,
}
88
89// =============================================================================
90// Error type
91// =============================================================================
92
/// Failure modes of the stochastic estimation routines.
#[derive(Debug, Clone)]
pub enum StochasticError {
    /// A configuration parameter was out of range.
    InvalidConfig(String),
    /// The input matrix was unsuitable (wrong shape, empty, ...).
    MatrixError(String),
    /// The computation broke down numerically.
    NumericalFailure(String),
}

impl fmt::Display for StochasticError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StochasticError::InvalidConfig(msg) => {
                write!(f, "Invalid stochastic config: {msg}")
            }
            StochasticError::MatrixError(msg) => {
                write!(f, "Matrix error in stochastic estimator: {msg}")
            }
            StochasticError::NumericalFailure(msg) => {
                write!(f, "Numerical failure in stochastic estimator: {msg}")
            }
        }
    }
}

impl std::error::Error for StochasticError {}
117
118// =============================================================================
119// LCG random number helpers
120// =============================================================================
121
/// Advance a 64-bit linear congruential generator and return the new state.
///
/// Uses Knuth's MMIX multiplier with the Newlib increment.  Callers seed the
/// state with `base + probe_idx * 1_234_567` so that probes are decorrelated.
#[inline]
fn lcg_next(state: &mut u64) -> u64 {
    const MUL: u64 = 6_364_136_223_846_793_005;
    const ADD: u64 = 1_442_695_040_888_963_407;
    *state = state.wrapping_mul(MUL).wrapping_add(ADD);
    *state
}

/// Generate a Rademacher (±1) vector of length `n` from an LCG stream.
fn lcg_rademacher(seed: u64, n: usize) -> Vec<f64> {
    let mut state = seed;
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        let bits = lcg_next(&mut state);
        // Decide the sign from the high bit: +1 when clear, -1 when set.
        out.push(if bits & (1 << 63) == 0 { 1.0 } else { -1.0 });
    }
    out
}

/// Generate an N(0, 1) vector of length `n` via Box-Muller on an LCG stream.
fn lcg_gaussian(seed: u64, n: usize) -> Vec<f64> {
    let mut state = seed;
    let mut out = Vec::with_capacity(n);
    // Box-Muller produces samples in pairs; the second of each pair is cached
    // here and emitted on the next iteration.
    let mut cached: Option<f64> = None;

    while out.len() < n {
        match cached.take() {
            Some(second) => out.push(second),
            None => {
                // Draw u1 in (0, 1]; reject exact zeros so ln(u1) is finite.
                let u1 = loop {
                    let f = (lcg_next(&mut state) as f64) / (u64::MAX as f64);
                    if f > 0.0 {
                        break f;
                    }
                };
                let u2 = (lcg_next(&mut state) as f64) / (u64::MAX as f64);
                let radius = (-2.0 * u1.ln()).sqrt();
                let angle = 2.0 * std::f64::consts::PI * u2;
                out.push(radius * angle.cos());
                cached = Some(radius * angle.sin());
            }
        }
    }
    out
}

/// Generate a vector uniform on the unit sphere (a normalised Gaussian draw).
fn lcg_spherical(seed: u64, n: usize) -> Vec<f64> {
    let mut v = lcg_gaussian(seed, n);
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    // Guard against the (measure-zero) all-zero draw.
    if norm > 0.0 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
    v
}
183
184// =============================================================================
185// Statistics helpers
186// =============================================================================
187
/// Mean and standard error of the mean for a set of per-probe samples.
///
/// Returns `(mean, std_error)`.  An empty slice yields `(0.0, 0.0)`; with a
/// single sample the standard error is reported as 0.
fn mean_and_stderr(samples: &[f64]) -> (f64, f64) {
    let count = samples.len();
    if count == 0 {
        return (0.0, 0.0);
    }
    let mean = samples.iter().sum::<f64>() / count as f64;
    if count < 2 {
        return (mean, 0.0);
    }
    // Unbiased sample variance (Bessel's correction), then SE = sqrt(var / m).
    let sum_sq: f64 = samples
        .iter()
        .map(|x| {
            let d = x - mean;
            d * d
        })
        .sum();
    let variance = sum_sq / (count - 1) as f64;
    (mean, (variance / count as f64).sqrt())
}
204
205// =============================================================================
206// Main estimator struct
207// =============================================================================
208
/// Stochastic estimator for matrix functionals on sparse matrices.
///
/// All methods operate on `CsrMatrix<f64>`.  The matrix need not be square for
/// Frobenius-norm estimation, but must be square for the trace, diagonal and
/// log-determinant methods.  All probe generation is deterministic given the
/// configured seed, so repeated calls reproduce identical estimates.
pub struct StochasticEstimator {
    // Immutable configuration captured at construction time.
    config: StochasticConfig,
}
216
impl StochasticEstimator {
    /// Create an estimator with the given configuration.
    pub fn new(config: StochasticConfig) -> Self {
        Self { config }
    }

    /// Create an estimator with default configuration.
    pub fn with_default() -> Self {
        Self::new(StochasticConfig::default())
    }

    // =========================================================================
    // Trace estimators
    // =========================================================================

    /// Hutchinson unbiased trace estimator.
    ///
    /// Computes `tr(A) ≈ (1/m) Σᵢ zᵢᵀ A zᵢ` where `zᵢ` are probe vectors.
    ///
    /// Requires a square matrix.  Time complexity O(m · nnz).
    ///
    /// NOTE(review): unbiasedness requires E[z zᵀ] = I, which holds for
    /// Rademacher and Gaussian probes.  `ProbeType::Spherical` probes are
    /// unit-norm (E[z zᵀ] = I/n), so with that probe type the raw mean is
    /// smaller by a factor of n — confirm whether spherical probes are
    /// intended to be usable here.
    pub fn trace_hutchinson(&self, csr: &CsrMatrix<f64>) -> Result<TraceEstimate, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;

        let mut samples = Vec::with_capacity(m);
        // Reusable buffer for A·z, avoiding one allocation per probe.
        let mut az = vec![0.0f64; n];

        for i in 0..m {
            let z = self.probe_vector(n, i);
            Self::matvec(csr, &z, &mut az);
            // Quadratic form zᵀ(A z): one per-probe trace sample.
            let quad: f64 = z.iter().zip(az.iter()).map(|(zi, ai)| zi * ai).sum();
            samples.push(quad);
        }

        let (mean, stderr) = mean_and_stderr(&samples);
        Ok(TraceEstimate {
            estimate: mean,
            std_error: stderr,
            n_probes_used: m,
        })
    }

    /// Hutch++ improved trace estimator.
    ///
    /// Allocates `m/3` probes to build a sketch `Q` (thin QR of `A S`) and
    /// uses the remaining `2m/3` probes for a stochastic correction that
    /// estimates `tr((I - Q Qᵀ) A (I - Q Qᵀ))`.
    ///
    /// The total estimate is `tr(Qᵀ A Q) + stochastic_correction`.  The split
    /// is exact for symmetric A because the cross terms `tr(P A (I - P))`
    /// vanish for the orthogonal projector `P = Q Qᵀ`.
    pub fn trace_hutch_plus_plus(
        &self,
        csr: &CsrMatrix<f64>,
    ) -> Result<TraceEstimate, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;

        let k = (m / 3).max(1); // sketch size
        // With m == 1 this is 0: the estimate is then purely deterministic.
        let m_stoch = m - k; // stochastic probes

        // --- Build sketch columns S = A * s_i for i in 0..k ---
        let mut sketch_cols: Vec<Vec<f64>> = Vec::with_capacity(k);
        for i in 0..k {
            let s = self.probe_vector(n, i);
            let mut col = vec![0.0f64; n];
            Self::matvec(csr, &s, &mut col);
            sketch_cols.push(col);
        }

        // Thin QR of sketch to get orthonormal Q (n x k columns; numerically
        // dependent columns are dropped, so Q may have fewer than k columns).
        let q_cols = Self::thin_qr(&sketch_cols, n, k);

        // --- Deterministic component: tr(Qᵀ A Q) ---
        let mut det_trace = 0.0f64;
        let mut aqj = vec![0.0f64; n];
        for col in &q_cols {
            Self::matvec(csr, col, &mut aqj);
            det_trace += col
                .iter()
                .zip(aqj.iter())
                .map(|(qi, aqji)| qi * aqji)
                .sum::<f64>();
        }

        // --- Stochastic correction: tr((I - QQᵀ) A (I - QQᵀ)) ---
        let mut samples = Vec::with_capacity(m_stoch);
        let mut az = vec![0.0f64; n];

        for i in 0..m_stoch {
            // Probe index offset by k so the correction probes are
            // independent of the sketch probes.
            let z = self.probe_vector(n, k + i);
            // w = (I - Q Qᵀ) z
            let w = project_out(&q_cols, &z);
            // A w
            Self::matvec(csr, &w, &mut az);
            // zᵀ (I - QQᵀ) A (I - QQᵀ) z = wᵀ A w, so one matvec suffices.
            let quad: f64 = w.iter().zip(az.iter()).map(|(wi, ai)| wi * ai).sum();
            samples.push(quad);
        }

        let (stoch_mean, stoch_stderr) = mean_and_stderr(&samples);
        let estimate = det_trace + stoch_mean;

        // The deterministic component has zero sampling error, so only the
        // correction contributes to the reported standard error.
        Ok(TraceEstimate {
            estimate,
            std_error: stoch_stderr,
            n_probes_used: m,
        })
    }

    /// XTrace-style trace estimator (after Epperly, Tropp, Webber 2022).
    ///
    /// Builds a spherical probe matrix `Ω` (n×m), computes `Y = A Ω`, takes a
    /// thin QR of `Ω`, and forms `Σⱼ qⱼᵀ yⱼ` rescaled by `n/m`.
    ///
    /// NOTE(review): the published XTrace algorithm combines a low-rank
    /// deterministic term with leave-one-out ("exchange") residual estimates;
    /// the implementation below is a simplified sketch-and-rescale variant,
    /// and the `n/m` scaling is a heuristic based on E[ω ωᵀ] = I/n for
    /// spherical probes.  Validate its bias against a reference XTrace
    /// implementation before relying on the name.
    pub fn trace_xtrace(&self, csr: &CsrMatrix<f64>) -> Result<TraceEstimate, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;

        // Generate probe matrix Ω (columns ω_i) and compute Y = A Ω.
        // Spherical probes are always used here regardless of config.
        let mut omega: Vec<Vec<f64>> = Vec::with_capacity(m);
        let mut y_cols: Vec<Vec<f64>> = Vec::with_capacity(m);
        let mut az = vec![0.0f64; n];

        for i in 0..m {
            let z = self.probe_vector_spherical(n, i);
            Self::matvec(csr, &z, &mut az);
            y_cols.push(az.clone());
            omega.push(z);
        }

        // Thin QR of Ω → Q (orthonormal columns; may be fewer than m if the
        // probes are numerically dependent).
        let q_cols = Self::thin_qr(&omega, n, m);

        // Raw sketch term: Σ_j q_jᵀ (A ω_j), i.e. tr(Qᵀ Y) column by column.
        // NOTE(review): pairing q_j with y_j assumes thin_qr dropped no
        // columns; if it did, q_cols.len() < m and the pairing shifts.
        let mut estimate = 0.0f64;
        for j in 0..q_cols.len() {
            let dot: f64 = q_cols[j]
                .iter()
                .zip(y_cols[j].iter())
                .map(|(qi, yi)| qi * yi)
                .sum();
            estimate += dot;
        }

        // Heuristic rescale by n/m (see the NOTE in the doc comment above).
        estimate *= n as f64 / m as f64;

        // Standard error proxy: SE of the per-column scaled contributions.
        // NOTE(review): this is not a true leave-one-out jackknife — it is
        // the spread of the m column terms treated as i.i.d. samples.
        let mut leave_one_out = Vec::with_capacity(q_cols.len());
        for j in 0..q_cols.len() {
            let dot: f64 = q_cols[j]
                .iter()
                .zip(y_cols[j].iter())
                .map(|(qi, yi)| qi * yi)
                .sum();
            leave_one_out.push(dot * n as f64 / m as f64);
        }
        let (_, stderr) = mean_and_stderr(&leave_one_out);

        Ok(TraceEstimate {
            estimate,
            std_error: stderr,
            n_probes_used: m,
        })
    }

    // =========================================================================
    // Diagonal estimator
    // =========================================================================

    /// Bekas diagonal estimator.
    ///
    /// Estimates `diag(A) ≈ (1/m) Σᵢ zᵢ ⊙ (A zᵢ)` element-wise.
    ///
    /// Requires a square matrix.
    ///
    /// NOTE(review): as with Hutchinson, this is unbiased for Rademacher and
    /// Gaussian probes; spherical probes would shrink every entry by a factor
    /// of n — confirm intended probe types.
    pub fn diagonal(&self, csr: &CsrMatrix<f64>) -> Result<DiagEstimate<f64>, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;

        // Accumulate per-probe contribution and running sum of squares so
        // the per-entry standard error can be formed in one pass.
        let mut diag_sum = vec![0.0f64; n];
        let mut diag_sq_sum = vec![0.0f64; n];
        let mut az = vec![0.0f64; n];

        for i in 0..m {
            let z = self.probe_vector(n, i);
            Self::matvec(csr, &z, &mut az);
            for j in 0..n {
                let contrib = z[j] * az[j];
                diag_sum[j] += contrib;
                diag_sq_sum[j] += contrib * contrib;
            }
        }

        let mf = m as f64;
        let diagonal: Vec<f64> = diag_sum.iter().map(|s| s / mf).collect();

        let std_error: Vec<f64> = if m >= 2 {
            (0..n)
                .map(|j| {
                    let mean_j = diag_sum[j] / mf;
                    // sample variance = (Σ x² - m mean²) / (m-1); the max(0.0)
                    // guards against tiny negative values from cancellation.
                    let var_j = (diag_sq_sum[j] - mf * mean_j * mean_j).max(0.0) / (mf - 1.0);
                    (var_j / mf).sqrt()
                })
                .collect()
        } else {
            vec![0.0f64; n]
        };

        Ok(DiagEstimate {
            diagonal,
            std_error,
        })
    }

    // =========================================================================
    // Frobenius norm
    // =========================================================================

    /// Stochastic Frobenius norm estimate.
    ///
    /// Uses `‖A‖_F² ≈ (1/m) Σᵢ ‖A zᵢ‖²` (since E[zᵀAᵀAz] = tr(AᵀA) for
    /// probes with E[z zᵀ] = I).
    ///
    /// Works for non-square matrices.
    ///
    /// NOTE(review): the identity above fails for `ProbeType::Spherical`
    /// (unit-norm probes), which would underestimate ‖A‖_F² by a factor of
    /// ncols — confirm intended probe types.
    pub fn frobenius_norm(&self, csr: &CsrMatrix<f64>) -> Result<f64, StochasticError> {
        let ncols = csr.ncols();
        let nrows = csr.nrows();
        let m = self.require_probes()?;

        if ncols == 0 || nrows == 0 {
            return Err(StochasticError::MatrixError(
                "matrix has zero dimension".to_string(),
            ));
        }

        let mut az = vec![0.0f64; nrows];
        let mut frob_sq_sum = 0.0f64;

        for i in 0..m {
            // Probes live in the column space: length ncols, output length nrows.
            let z = self.probe_vector(ncols, i);
            Self::matvec(csr, &z, &mut az);
            let sq: f64 = az.iter().map(|v| v * v).sum();
            frob_sq_sum += sq;
        }

        let frob_sq = frob_sq_sum / m as f64;
        Ok(frob_sq.sqrt())
    }

    // =========================================================================
    // Log-determinant via stochastic Lanczos quadrature
    // =========================================================================

    /// Stochastic log-determinant estimate via Lanczos quadrature.
    ///
    /// Uses the identity `log det A = tr(log A)` and estimates `tr(log A)` by
    /// running a short Lanczos recurrence for each probe vector `z`:
    ///
    /// 1. Run `lanczos_steps` Lanczos steps starting from `z / ‖z‖`.
    /// 2. Obtain tridiagonal matrix `T` with diagonal α and off-diagonal β.
    /// 3. Compute `log(T)` via dense EVD of `T`.
    /// 4. Contribution is `‖z‖² · e₁ᵀ log(T) e₁`.
    ///
    /// Only valid for symmetric positive definite matrices.
    ///
    /// NOTE(review): the recurrence orthogonalises only against the previous
    /// Lanczos vector (no reorthogonalisation), so Ritz values can degrade
    /// from loss of orthogonality; 20 steps keeps this mild but unverified.
    pub fn log_det(&self, csr: &CsrMatrix<f64>) -> Result<f64, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;
        // Fixed Krylov depth, capped by the matrix dimension.
        let lanczos_steps = 20_usize.min(n);

        let mut samples = Vec::with_capacity(m);

        for i in 0..m {
            let z = self.probe_vector(n, i);
            let z_norm_sq: f64 = z.iter().map(|v| v * v).sum();
            let z_norm = z_norm_sq.sqrt();

            // Skip degenerate (numerically zero) probes entirely.
            if z_norm < 1e-300 {
                continue;
            }

            // Normalise starting vector q₀ = z / ‖z‖.
            let mut q0: Vec<f64> = z.iter().map(|v| v / z_norm).collect();

            // Lanczos recurrence: build tridiagonal T.
            let mut alpha = Vec::with_capacity(lanczos_steps);
            let mut beta = Vec::with_capacity(lanczos_steps); // β[j] is the off-diagonal below α[j]

            let mut q_prev = vec![0.0f64; n];
            let mut r = vec![0.0f64; n];

            for _j in 0..lanczos_steps {
                Self::matvec(csr, &q0, &mut r);

                // α_j = q_j^T r
                let a: f64 = q0.iter().zip(r.iter()).map(|(qi, ri)| qi * ri).sum();
                alpha.push(a);

                // Three-term recurrence: r = r - α_j q_j - β_{j-1} q_{j-1}
                for idx in 0..n {
                    r[idx] -= a * q0[idx];
                }
                if let Some(&b_prev) = beta.last() {
                    for idx in 0..n {
                        r[idx] -= b_prev * q_prev[idx];
                    }
                }

                let b: f64 = r.iter().map(|v| v * v).sum::<f64>().sqrt();
                beta.push(b);

                // Breakdown: the Krylov space is exhausted (exact invariant
                // subspace found), so T is complete for this probe.
                if b < 1e-14 {
                    break;
                }

                q_prev = q0.clone();
                q0 = r.iter().map(|v| v / b).collect();
            }

            let k = alpha.len();
            if k == 0 {
                continue;
            }

            // Dense EVD of the k×k tridiagonal T.
            // T has diagonal α[0..k] and off-diagonal β[0..k-1]; the final β
            // (the breakdown residual) is not part of T.
            let (eigenvalues, evecs) = tridiagonal_evd(&alpha, &beta[..k.saturating_sub(1)]);

            // Check that A is positive definite (all Ritz values > 0),
            // otherwise log is undefined and the SPD precondition is violated.
            if eigenvalues.iter().any(|&e| e <= 0.0) {
                return Err(StochasticError::NumericalFailure(
                    "matrix appears to not be positive definite (non-positive Ritz value encountered)"
                        .to_string(),
                ));
            }

            // Gauss quadrature: e1ᵀ log(T) e1 = Σ_i evec_i[0]² · log(λ_i).
            let e1_log_t_e1: f64 = eigenvalues
                .iter()
                .zip(evecs.iter())
                .map(|(&lam, evec)| evec[0] * evec[0] * lam.ln())
                .sum();

            // ‖z‖² restores the scale removed by normalising the start vector.
            samples.push(z_norm_sq * e1_log_t_e1);
        }

        if samples.is_empty() {
            return Err(StochasticError::NumericalFailure(
                "all probe vectors were degenerate".to_string(),
            ));
        }

        let (mean, _stderr) = mean_and_stderr(&samples);
        Ok(mean)
    }

    // =========================================================================
    // Internal helpers
    // =========================================================================

    /// Generate a probe vector according to the configured `ProbeType`.
    ///
    /// Each probe gets its own seed derived from the base seed and the probe
    /// index, so probes are mutually independent but fully reproducible.
    fn probe_vector(&self, n: usize, probe_idx: usize) -> Vec<f64> {
        let seed = self.config.seed.wrapping_add(probe_idx as u64 * 1_234_567);
        match self.config.probe_type {
            ProbeType::Rademacher => lcg_rademacher(seed, n),
            ProbeType::Gaussian => lcg_gaussian(seed, n),
            ProbeType::Spherical => lcg_spherical(seed, n),
        }
    }

    /// Always generate a spherical probe (used internally by XTrace),
    /// ignoring the configured probe type but sharing its seeding scheme.
    fn probe_vector_spherical(&self, n: usize, probe_idx: usize) -> Vec<f64> {
        let seed = self.config.seed.wrapping_add(probe_idx as u64 * 1_234_567);
        lcg_spherical(seed, n)
    }

    /// Sparse matrix-vector product: `y = A x` over the CSR structure.
    ///
    /// `y` is resized to the row count if needed, so callers can reuse one
    /// buffer across probes.
    fn matvec(csr: &CsrMatrix<f64>, x: &[f64], y: &mut Vec<f64>) {
        let nrows = csr.nrows();
        if y.len() != nrows {
            y.resize(nrows, 0.0);
        }
        for i in 0..nrows {
            let start = csr.row_ptrs()[i];
            let end = csr.row_ptrs()[i + 1];
            let mut s = 0.0f64;
            for k in start..end {
                s += csr.values()[k] * x[csr.col_indices()[k]];
            }
            y[i] = s;
        }
    }

    /// Thin QR decomposition of a set of `k` column vectors (each of length `n`).
    ///
    /// Uses modified Gram-Schmidt to produce an orthonormal basis Q.
    /// Columns that become numerically zero after projection are dropped, so
    /// the result may have fewer than `k` columns.
    fn thin_qr(a: &[Vec<f64>], n: usize, k: usize) -> Vec<Vec<f64>> {
        let mut q: Vec<Vec<f64>> = Vec::with_capacity(k);

        for col in a.iter().take(k) {
            let mut v = col.clone();

            // Modified Gram-Schmidt: subtract projections onto existing q columns.
            for qi in &q {
                let proj: f64 = v.iter().zip(qi.iter()).map(|(vi, qi_)| vi * qi_).sum();
                for (vi, qi_) in v.iter_mut().zip(qi.iter()) {
                    *vi -= proj * qi_;
                }
            }

            // Dimension-aware drop tolerance: components below ~1e-14·√n are
            // treated as linearly dependent and discarded.
            let nrm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
            if nrm > 1e-14 * (n as f64).sqrt() {
                q.push(v.into_iter().map(|x| x / nrm).collect());
            }
        }

        q
    }

    /// Validate that the matrix is square and non-empty; return its size.
    fn require_square(&self, csr: &CsrMatrix<f64>) -> Result<usize, StochasticError> {
        if csr.nrows() != csr.ncols() {
            return Err(StochasticError::MatrixError(format!(
                "matrix must be square, got {}×{}",
                csr.nrows(),
                csr.ncols()
            )));
        }
        if csr.nrows() == 0 {
            return Err(StochasticError::MatrixError(
                "matrix has zero dimension".to_string(),
            ));
        }
        Ok(csr.nrows())
    }

    /// Validate and return the configured probe count (must be at least 1).
    fn require_probes(&self) -> Result<usize, StochasticError> {
        if self.config.num_probes == 0 {
            return Err(StochasticError::InvalidConfig(
                "num_probes must be >= 1".to_string(),
            ));
        }
        Ok(self.config.num_probes)
    }
}
683
684// =============================================================================
685// Helper: project z onto the orthogonal complement of the column span of Q
686// =============================================================================
687
/// Project `z` onto the orthogonal complement of span(Q): `w = (I - Q Qᵀ) z`.
///
/// Assumes the columns in `q_cols` are orthonormal, so one Gram-Schmidt
/// sweep is an exact projection.  An empty `q_cols` returns `z` unchanged.
fn project_out(q_cols: &[Vec<f64>], z: &[f64]) -> Vec<f64> {
    let mut w: Vec<f64> = z.to_vec();
    for q in q_cols {
        // Coefficient of z along this basis column, then subtract it out.
        let coeff: f64 = w.iter().zip(q.iter()).map(|(a, b)| a * b).sum();
        for (wi, qi) in w.iter_mut().zip(q.iter()) {
            *wi -= coeff * qi;
        }
    }
    w
}
699
700// =============================================================================
701// Helper: dense symmetric tridiagonal EVD (QR iteration)
702// =============================================================================
703
/// Compute eigenvalues and eigenvectors of a real symmetric tridiagonal
/// matrix with diagonal `alpha` and off-diagonal `beta`.
///
/// Returns `(eigenvalues, eigenvectors)` where each eigenvector is stored as
/// its full column (length k), paired index-for-index with the eigenvalues.
/// Uses an implicit-shift QR iteration with Wilkinson shifts.
///
/// NOTE(review): the rotation bookkeeping below is a hand-rolled variant of
/// the classical bulge-chasing step, not a textbook transcription; verify
/// its results against LAPACK `dsteqr` on representative inputs.
fn tridiagonal_evd(alpha: &[f64], beta: &[f64]) -> (Vec<f64>, Vec<Vec<f64>>) {
    let k = alpha.len();
    // Trivial sizes: empty spectrum, or 1×1 whose eigenvector is e₁.
    if k == 0 {
        return (vec![], vec![]);
    }
    if k == 1 {
        return (vec![alpha[0]], vec![vec![1.0]]);
    }

    // Work arrays: diagonal d, off-diagonal e (zero-padded to length k; only
    // the first k-1 entries are meaningful), and eigenvector matrix Z stored
    // row-major in a flat k×k buffer.
    let mut d = alpha.to_vec();
    let mut e = vec![0.0f64; k];
    for i in 0..beta.len().min(k - 1) {
        e[i] = beta[i];
    }

    // Start from the identity so Z accumulates the product of all rotations.
    let mut z = vec![0.0f64; k * k];
    for i in 0..k {
        z[i * k + i] = 1.0;
    }

    // Symmetric QR with Wilkinson shift (implicit QR on tridiagonal).
    // `m` is the size of the still-active trailing window.
    let max_iter = 30 * k;
    let mut m = k;

    'outer: for _ in 0..max_iter {
        // Deflate: shrink the window while the trailing off-diagonal entry is
        // negligible relative to its neighbouring diagonal entries.
        while m > 1 && e[m - 2].abs() < 1e-14 * (d[m - 2].abs() + d[m - 1].abs()) {
            m -= 1;
            if m == 1 {
                break 'outer;
            }
        }
        if m <= 1 {
            break;
        }

        // Wilkinson shift: eigenvalue of the bottom 2×2 block closer to
        // d[m-1]; gives (near-)cubic convergence in practice.
        let a = d[m - 2];
        let b = e[m - 2];
        let c = d[m - 1];
        let delta = (a - c) / 2.0;
        let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
        let shift = c - sign_delta * b * b / (delta.abs() + (delta * delta + b * b).sqrt());

        // Implicit QR step: chase the bulge down the active window.
        let mut x = d[0] - shift;
        let mut z_val = e[0];

        for i in 0..m - 1 {
            let (c_rot, s_rot) = givens_cs(x, z_val);

            // Apply rotation on left and right to tridiagonal.
            let w = c_rot * x + s_rot * z_val;
            let _ = w; // w is just the new diagonal position before update

            // 2×2 symmetric similarity update of d[i], d[i+1] and e[i].
            let d_i = d[i];
            let d_i1 = d[i + 1];
            let e_i = e[i];

            d[i] = c_rot * c_rot * d_i + 2.0 * c_rot * s_rot * e_i + s_rot * s_rot * d_i1;
            d[i + 1] = s_rot * s_rot * d_i - 2.0 * c_rot * s_rot * e_i + c_rot * c_rot * d_i1;
            e[i] = c_rot * s_rot * (d_i1 - d_i) + (c_rot * c_rot - s_rot * s_rot) * e_i;

            // Fold the bulge back into the entry above the rotated rows.
            if i > 0 {
                e[i - 1] = c_rot * e[i - 1] + s_rot * z_val;
            }

            x = e[i];
            if i + 1 < m - 1 {
                // The rotation spills into the next off-diagonal: record the
                // new bulge and shrink e[i+1] accordingly.
                z_val = s_rot * e[i + 1];
                e[i + 1] = c_rot * e[i + 1];
            }

            // Accumulate the rotation into eigenvector matrix columns i and
            // i+1 (each row rotated independently).
            for row in 0..k {
                let zi = z[row * k + i];
                let zi1 = z[row * k + i + 1];
                z[row * k + i] = c_rot * zi + s_rot * zi1;
                z[row * k + i + 1] = -s_rot * zi + c_rot * zi1;
            }
        }
    }

    // Extract eigenvectors as column vectors (column j pairs with d[j]).
    let eigenvectors: Vec<Vec<f64>> = (0..k)
        .map(|j| (0..k).map(|i| z[i * k + j]).collect())
        .collect();

    (d, eigenvectors)
}
802
/// Compute a Givens rotation pair `(c, s)` for `(a, b)` such that
/// `s·a + c·b = 0`, i.e. `[c -s; s c] · [a; b] = [r; 0]` with `c² + s² = 1`.
///
/// The previous doc comment claimed the `[c s; -s c]` convention, which does
/// not match the values actually returned (e.g. `(a, b) = (3, 4)` yields
/// `(c, s) = (-0.6, 0.8)`, zeroing `b` only under the convention above).
/// The numeric behavior is unchanged; `tridiagonal_evd` depends on it.
#[inline]
fn givens_cs(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        // Nothing to annihilate: identity rotation.
        return (1.0, 0.0);
    }
    // Branch on magnitude so the tangent t stays in [-1, 1] and t·t cannot
    // overflow for extreme inputs.
    if a.abs() < b.abs() {
        let t = -a / b;
        let s = 1.0 / (1.0 + t * t).sqrt();
        (s * t, s)
    } else {
        let t = -b / a;
        let c = 1.0 / (1.0 + t * t).sqrt();
        (c, c * t)
    }
}
820
821// =============================================================================
822// Tests
823// =============================================================================
824
825#[cfg(test)]
826mod tests {
827    use super::*;
828    use crate::csr::CsrMatrix;

830    /// n×n identity in CSR layout: a single 1.0 on each row's diagonal.
831    fn identity_csr(n: usize) -> CsrMatrix<f64> {
832        CsrMatrix::new(
833            n,
834            n,
835            (0..=n).collect::<Vec<usize>>(),
836            (0..n).collect::<Vec<usize>>(),
837            vec![1.0f64; n],
838        )
839        .expect("valid identity CSR")
840    }

842    /// Diagonal CSR matrix whose diagonal holds `entries`.
843    fn diag_csr(entries: &[f64]) -> CsrMatrix<f64> {
844        let n = entries.len();
845        CsrMatrix::new(
846            n,
847            n,
848            (0..=n).collect::<Vec<usize>>(),
849            (0..n).collect::<Vec<usize>>(),
850            entries.to_vec(),
851        )
852        .expect("valid diagonal CSR")
853    }

855    /// Tridiagonal 1-D Laplacian: 2 on the diagonal, -1 on both off-diagonals.
856    fn laplacian_1d_csr(n: usize) -> CsrMatrix<f64> {
857        let mut vals: Vec<f64> = Vec::new();
858        let mut cols: Vec<usize> = Vec::new();
859        let mut ptrs = Vec::with_capacity(n + 1);
860        ptrs.push(0usize);

861        for row in 0..n {
862            if row > 0 {
863                vals.push(-1.0);
864                cols.push(row - 1);
865            }
866            vals.push(2.0);
867            cols.push(row);
868            if row + 1 < n {
869                vals.push(-1.0);
870                cols.push(row + 1);
871            }
872            ptrs.push(vals.len());
873        }
874        CsrMatrix::new(n, n, ptrs, cols, vals).expect("valid 1-D Laplacian CSR")
875    }

877    // -------------------------------------------------------------------------

879    #[test]
880    fn test_stochastic_config_default() {
881        // Documented defaults: 30 Rademacher probes, seed 42, 95% confidence.
882        let cfg = StochasticConfig::default();
883        assert_eq!(cfg.num_probes, 30);
884        assert_eq!(cfg.seed, 42);
885        assert!(matches!(cfg.probe_type, ProbeType::Rademacher));
886        assert!((cfg.confidence - 0.95).abs() < 1e-12);
887    }

889    #[test]
890    fn test_hutchinson_identity() {
891        // For the identity each Rademacher probe gives z^T I z = n exactly,
892        // so the estimate of tr(I_3) = 3 should land very close.
893        let eye = identity_csr(3);
894        let est = StochasticEstimator::with_default();
895        let out = est.trace_hutchinson(&eye).expect("trace_hutchinson failed");
896        assert!(
897            (out.estimate - 3.0).abs() < 0.5,
898            "estimate {} far from 3.0",
899            out.estimate
900        );
901        assert_eq!(out.n_probes_used, 30);
902    }

904    #[test]
905    fn test_hutchinson_diagonal() {
906        // tr(diag(1,2,3)) = 6; 100 probes to tighten the variance.
907        let d = diag_csr(&[1.0, 2.0, 3.0]);
908        let est = StochasticEstimator::new(StochasticConfig {
909            num_probes: 100,
910            ..Default::default()
911        });
912        let out = est.trace_hutchinson(&d).expect("trace_hutchinson failed");
913        assert!(
914            (out.estimate - 6.0).abs() < 1.0,
915            "estimate {} far from 6.0",
916            out.estimate
917        );
918    }

920    #[test]
921    fn test_hutchinson_sparse_laplacian() {
922        // The n=10 1-D Laplacian has trace 10 * 2.0 = 20.
923        let lap = laplacian_1d_csr(10);
924        let est = StochasticEstimator::new(StochasticConfig {
925            num_probes: 200,
926            ..Default::default()
927        });
928        let out = est.trace_hutchinson(&lap).expect("trace_hutchinson failed");
929        assert!(
930            (out.estimate - 20.0).abs() < 3.0,
931            "estimate {} far from 20.0",
932            out.estimate
933        );
934    }

936    #[test]
937    fn test_hutch_plusplus_accuracy() {
938        // On a matrix with decaying spectrum, Hutch++ should match or beat
939        // plain Hutchinson at equal probe budget.  diag(1..=20) has tr = 210.
940        let entries: Vec<f64> = (1..=20).map(|x| x as f64).collect();
941        let d = diag_csr(&entries);
942        let true_trace = 210.0f64;

944        let cfg = StochasticConfig {
945            num_probes: 30,
946            seed: 7,
947            ..Default::default()
948        };
949        let est = StochasticEstimator::new(cfg.clone());

951        let hh = est.trace_hutch_plus_plus(&d).expect("hutch++ failed");
952        let plain = est.trace_hutchinson(&d).expect("hutchinson failed");

954        let hh_err = (hh.estimate - true_trace).abs();
955        let plain_err = (plain.estimate - true_trace).abs();

957        // Allow slack for the finite probe count, but require the Hutch++
958        // estimate to land within 30% of the true trace.
959        assert!(
960            hh_err < true_trace * 0.30,
961            "Hutch++ error {hh_err} too large (true={true_trace})"
962        );
963        // Kept only for debugging comparisons against plain Hutchinson.
964        let _ = plain_err;
965    }

967    #[test]
968    fn test_diagonal_estimator() {
969        // Every diagonal entry of I_5 should be recovered near 1.0.
970        let eye = identity_csr(5);
971        let est = StochasticEstimator::new(StochasticConfig {
972            num_probes: 100,
973            ..Default::default()
974        });
975        let out = est.diagonal(&eye).expect("diagonal failed");
976        assert_eq!(out.diagonal.len(), 5);
977        for (i, &d) in out.diagonal.iter().enumerate() {
978            assert!(
979                (d - 1.0).abs() < 0.3,
980                "diagonal[{i}] = {d} not close to 1.0"
981            );
982        }
983    }

985    #[test]
986    fn test_frobenius_norm() {
987        // ‖I_9‖_F = sqrt(9) = 3.
988        let n = 9usize;
989        let eye = identity_csr(n);
990        let est = StochasticEstimator::new(StochasticConfig {
991            num_probes: 50,
992            ..Default::default()
993        });
994        let frob = est.frobenius_norm(&eye).expect("frobenius_norm failed");
995        let expected = (n as f64).sqrt();
996        assert!(
997            (frob - expected).abs() < expected * 0.10,
998            "frobenius estimate {frob} not within 10% of {expected}"
999        );
1000    }

1002    #[test]
1003    fn test_log_det_spd() {
1004        // log det diag(2, 3, 5) = ln 2 + ln 3 + ln 5 ≈ 3.401.
1005        let d = diag_csr(&[2.0, 3.0, 5.0]);
1006        let expected = 2.0f64.ln() + 3.0f64.ln() + 5.0f64.ln();
1007        let est = StochasticEstimator::new(StochasticConfig {
1008            num_probes: 100,
1009            ..Default::default()
1010        });
1011        let result = est.log_det(&d).expect("log_det failed");
1012        // Every eigenvalue exceeds 1, so the log-determinant must be positive.
1013        assert!(result > 0.0, "log_det should be positive, got {result}");
1014        assert!(
1015            (result - expected).abs() < expected * 0.20,
1016            "log_det estimate {result} far from {expected}"
1017        );
1018    }
1019}