1use crate::csr::CsrMatrix;
24use std::fmt;
25
/// Distribution from which random probe vectors are drawn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbeType {
    /// Entries drawn uniformly from {+1, -1}.
    Rademacher,
    /// Entries drawn from a standard normal distribution (Box-Muller).
    Gaussian,
    /// A Gaussian vector rescaled to unit Euclidean length.
    Spherical,
}
40
/// Configuration shared by all stochastic estimators.
#[derive(Debug, Clone)]
pub struct StochasticConfig {
    /// Number of random probe vectors (must be >= 1).
    pub num_probes: usize,
    /// Base seed for the deterministic LCG probe generator.
    pub seed: u64,
    /// Distribution of the probe vectors.
    pub probe_type: ProbeType,
    /// Nominal confidence level. NOTE(review): stored but never read by the
    /// estimators in this file — confirm intended use before relying on it.
    pub confidence: f64,
}
53
54impl Default for StochasticConfig {
55 fn default() -> Self {
56 Self {
57 num_probes: 30,
58 seed: 42,
59 probe_type: ProbeType::Rademacher,
60 confidence: 0.95,
61 }
62 }
63}
64
/// Result of a stochastic trace estimate.
#[derive(Debug, Clone)]
pub struct TraceEstimate {
    /// Estimated trace value.
    pub estimate: f64,
    /// Standard error of the mean over the probe samples.
    pub std_error: f64,
    /// Number of probes consumed to produce the estimate.
    pub n_probes_used: usize,
}
79
/// Result of the stochastic diagonal estimator.
#[derive(Debug, Clone)]
pub struct DiagEstimate<T> {
    /// Per-entry estimate of the matrix diagonal.
    pub diagonal: Vec<T>,
    /// Per-entry standard error of the corresponding diagonal estimate.
    pub std_error: Vec<T>,
}
88
/// Errors reported by the stochastic estimators.
#[derive(Debug, Clone)]
pub enum StochasticError {
    /// The configuration is invalid (e.g. `num_probes == 0`).
    InvalidConfig(String),
    /// The input matrix is unusable (non-square or zero-dimensional).
    MatrixError(String),
    /// Numerical breakdown (e.g. non-positive Ritz value in `log_det`).
    NumericalFailure(String),
}
103
104impl fmt::Display for StochasticError {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 Self::InvalidConfig(msg) => write!(f, "Invalid stochastic config: {msg}"),
108 Self::MatrixError(msg) => write!(f, "Matrix error in stochastic estimator: {msg}"),
109 Self::NumericalFailure(msg) => {
110 write!(f, "Numerical failure in stochastic estimator: {msg}")
111 }
112 }
113 }
114}
115
// Marker impl: the `Debug` derive and `Display` impl satisfy `Error`'s defaults.
impl std::error::Error for StochasticError {}
117
/// Advances a 64-bit linear congruential generator and returns the new state.
///
/// Uses Knuth's MMIX multiplier/increment; wrapping arithmetic gives the
/// intended mod-2^64 recurrence without overflow panics.
#[inline]
fn lcg_next(state: &mut u64) -> u64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    let advanced = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = advanced;
    advanced
}
132
133fn lcg_rademacher(seed: u64, n: usize) -> Vec<f64> {
135 let mut state = seed;
136 (0..n)
137 .map(|_| {
138 let v = lcg_next(&mut state);
139 if v >> 63 == 0 { 1.0 } else { -1.0 }
140 })
141 .collect()
142}
143
144fn lcg_gaussian(seed: u64, n: usize) -> Vec<f64> {
146 let mut state = seed;
147 let mut out = Vec::with_capacity(n);
148 let mut spare: Option<f64> = None;
149
150 for _ in 0..n {
151 if let Some(s) = spare.take() {
152 out.push(s);
153 } else {
154 let u1 = loop {
156 let v = lcg_next(&mut state);
157 let f = (v as f64) / (u64::MAX as f64);
158 if f > 0.0 {
159 break f;
160 }
161 };
162 let u2 = (lcg_next(&mut state) as f64) / (u64::MAX as f64);
163 let mag = (-2.0 * u1.ln()).sqrt();
164 let theta = 2.0 * std::f64::consts::PI * u2;
165 out.push(mag * theta.cos());
166 spare = Some(mag * theta.sin());
167 }
168 }
169 out
170}
171
172fn lcg_spherical(seed: u64, n: usize) -> Vec<f64> {
174 let mut v = lcg_gaussian(seed, n);
175 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
176 if norm > 0.0 {
177 for x in &mut v {
178 *x /= norm;
179 }
180 }
181 v
182}
183
/// Sample mean and standard error of the mean.
///
/// Returns `(0.0, 0.0)` for an empty slice and `(mean, 0.0)` for a single
/// sample (the unbiased variance needs at least two observations).
fn mean_and_stderr(samples: &[f64]) -> (f64, f64) {
    let count = samples.len();
    if count == 0 {
        return (0.0, 0.0);
    }
    let mean = samples.iter().sum::<f64>() / count as f64;
    if count == 1 {
        return (mean, 0.0);
    }
    // Unbiased (n-1) sample variance, then stderr = sqrt(var / n).
    let sum_sq: f64 = samples
        .iter()
        .map(|s| {
            let dev = s - mean;
            dev * dev
        })
        .sum();
    let variance = sum_sq / (count - 1) as f64;
    (mean, (variance / count as f64).sqrt())
}
204
/// Matrix-free stochastic estimator for trace, diagonal, Frobenius norm,
/// and log-determinant of sparse CSR matrices. All randomness is derived
/// deterministically from the configured seed.
pub struct StochasticEstimator {
    // Probe count / seed / distribution shared by every estimator method.
    config: StochasticConfig,
}
216
impl StochasticEstimator {
    /// Creates an estimator with the given configuration.
    pub fn new(config: StochasticConfig) -> Self {
        Self { config }
    }

    /// Creates an estimator using [`StochasticConfig::default`].
    pub fn with_default() -> Self {
        Self::new(StochasticConfig::default())
    }

    /// Hutchinson trace estimator: `tr(A) ≈ mean_i(z_iᵀ A z_i)`.
    ///
    /// Unbiased when `E[zzᵀ] = I` (Rademacher / Gaussian probes).
    /// NOTE(review): with `ProbeType::Spherical`, `E[zzᵀ] = I/n`, so the
    /// result would be scaled down by `n` — confirm whether that is intended.
    ///
    /// # Errors
    /// [`StochasticError::MatrixError`] if `csr` is not square or is empty;
    /// [`StochasticError::InvalidConfig`] if `num_probes == 0`.
    pub fn trace_hutchinson(&self, csr: &CsrMatrix<f64>) -> Result<TraceEstimate, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;

        let mut samples = Vec::with_capacity(m);
        let mut az = vec![0.0f64; n]; // reused A·z buffer

        for i in 0..m {
            let z = self.probe_vector(n, i);
            Self::matvec(csr, &z, &mut az);
            // Quadratic form zᵀ(Az); its expectation is tr(A) for unit-variance probes.
            let quad: f64 = z.iter().zip(az.iter()).map(|(zi, ai)| zi * ai).sum();
            samples.push(quad);
        }

        let (mean, stderr) = mean_and_stderr(&samples);
        Ok(TraceEstimate {
            estimate: mean,
            std_error: stderr,
            n_probes_used: m,
        })
    }

    /// Hutch++-style trace estimator: computes the trace of a low-rank
    /// sketch of `A` exactly, then runs Hutchinson on the deflated residual.
    ///
    /// Budget split: `k = max(m/3, 1)` probes build the sketch, the
    /// remaining `m - k` are stochastic probes.
    ///
    /// # Errors
    /// Same conditions as [`Self::trace_hutchinson`].
    pub fn trace_hutch_plus_plus(
        &self,
        csr: &CsrMatrix<f64>,
    ) -> Result<TraceEstimate, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;

        // Spend roughly a third of the probe budget on the sketch A·S.
        let k = (m / 3).max(1);
        let m_stoch = m - k;
        let mut sketch_cols: Vec<Vec<f64>> = Vec::with_capacity(k);
        for i in 0..k {
            let s = self.probe_vector(n, i);
            let mut col = vec![0.0f64; n];
            Self::matvec(csr, &s, &mut col);
            sketch_cols.push(col);
        }

        // Orthonormal basis Q for the sketch range (may have < k columns).
        let q_cols = Self::thin_qr(&sketch_cols, n, k);

        // Deflated part is exact: tr(QᵀAQ) = Σ_j q_jᵀ (A q_j).
        let mut det_trace = 0.0f64;
        let mut aqj = vec![0.0f64; n];
        for col in &q_cols {
            Self::matvec(csr, col, &mut aqj);
            det_trace += col
                .iter()
                .zip(aqj.iter())
                .map(|(qi, aqji)| qi * aqji)
                .sum::<f64>();
        }

        // Hutchinson on the residual: probes projected off the sketch range.
        let mut samples = Vec::with_capacity(m_stoch);
        let mut az = vec![0.0f64; n];

        for i in 0..m_stoch {
            // Probe indices k..m give seeds disjoint from the sketch probes.
            let z = self.probe_vector(n, k + i);
            let w = project_out(&q_cols, &z);
            Self::matvec(csr, &w, &mut az);
            let quad: f64 = w.iter().zip(az.iter()).map(|(wi, ai)| wi * ai).sum();
            samples.push(quad);
        }

        let (stoch_mean, stoch_stderr) = mean_and_stderr(&samples);
        let estimate = det_trace + stoch_mean;

        Ok(TraceEstimate {
            // Only the stochastic part carries sampling error; the deflated
            // trace is computed exactly.
            estimate,
            std_error: stoch_stderr,
            n_probes_used: m,
        })
    }

    /// XTrace-style trace estimator using spherical probes (regardless of
    /// `config.probe_type`).
    ///
    /// Orthonormalizes the probes, pairs each basis column with `A z_j`, and
    /// rescales by `n/m` (spherical vectors satisfy `E[zzᵀ] = I/n`).
    /// NOTE(review): the "leave-one-out" error below reuses the same
    /// per-column dot products rather than true leave-one-out resampling —
    /// confirm this simplification is acceptable.
    ///
    /// # Errors
    /// Same conditions as [`Self::trace_hutchinson`].
    pub fn trace_xtrace(&self, csr: &CsrMatrix<f64>) -> Result<TraceEstimate, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;

        let mut omega: Vec<Vec<f64>> = Vec::with_capacity(m); // probes z_i
        let mut y_cols: Vec<Vec<f64>> = Vec::with_capacity(m); // products A·z_i
        let mut az = vec![0.0f64; n];

        for i in 0..m {
            let z = self.probe_vector_spherical(n, i);
            Self::matvec(csr, &z, &mut az);
            y_cols.push(az.clone());
            omega.push(z);
        }

        let q_cols = Self::thin_qr(&omega, n, m);

        let mut estimate = 0.0f64;
        // thin_qr may drop near-dependent probes, so iterate over q_cols.len();
        // q_j is paired positionally with y_j = A·z_j.
        for j in 0..q_cols.len() {
            let dot: f64 = q_cols[j]
                .iter()
                .zip(y_cols[j].iter())
                .map(|(qi, yi)| qi * yi)
                .sum();
            estimate += dot;
        }

        // Undo the 1/n scale of spherical probes and average over m.
        estimate *= n as f64 / m as f64;

        // Crude spread estimate from the scaled per-column contributions.
        let mut leave_one_out = Vec::with_capacity(q_cols.len());
        for j in 0..q_cols.len() {
            let dot: f64 = q_cols[j]
                .iter()
                .zip(y_cols[j].iter())
                .map(|(qi, yi)| qi * yi)
                .sum();
            leave_one_out.push(dot * n as f64 / m as f64);
        }
        let (_, stderr) = mean_and_stderr(&leave_one_out);

        Ok(TraceEstimate {
            estimate,
            std_error: stderr,
            n_probes_used: m,
        })
    }

    /// Stochastic diagonal estimator: `diag(A) ≈ mean_i(z_i ∘ (A z_i))`
    /// elementwise, unbiased when probe entries have unit variance.
    ///
    /// # Errors
    /// Same conditions as [`Self::trace_hutchinson`].
    pub fn diagonal(&self, csr: &CsrMatrix<f64>) -> Result<DiagEstimate<f64>, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;

        let mut diag_sum = vec![0.0f64; n]; // Σ_i z_i ∘ (A z_i)
        let mut diag_sq_sum = vec![0.0f64; n]; // Σ_i (z_i ∘ (A z_i))², for variance
        let mut az = vec![0.0f64; n];

        for i in 0..m {
            let z = self.probe_vector(n, i);
            Self::matvec(csr, &z, &mut az);
            for j in 0..n {
                let contrib = z[j] * az[j];
                diag_sum[j] += contrib;
                diag_sq_sum[j] += contrib * contrib;
            }
        }

        let mf = m as f64;
        let diagonal: Vec<f64> = diag_sum.iter().map(|s| s / mf).collect();

        // Per-entry stderr via the shortcut formula Σx² − m·mean², clamped
        // at 0 to guard against round-off; needs at least two probes.
        let std_error: Vec<f64> = if m >= 2 {
            (0..n)
                .map(|j| {
                    let mean_j = diag_sum[j] / mf;
                    let var_j = (diag_sq_sum[j] - mf * mean_j * mean_j).max(0.0) / (mf - 1.0);
                    (var_j / mf).sqrt()
                })
                .collect()
        } else {
            vec![0.0f64; n]
        };

        Ok(DiagEstimate {
            diagonal,
            std_error,
        })
    }

    /// Stochastic Frobenius norm: `‖A‖_F² ≈ mean_i(‖A z_i‖²)`. Works for
    /// rectangular matrices; unbiased when `E[zzᵀ] = I`.
    /// NOTE(review): spherical probes would scale the squared norm by
    /// `1/ncols` — confirm only Rademacher/Gaussian configs are used here.
    ///
    /// # Errors
    /// [`StochasticError::MatrixError`] if either dimension is zero;
    /// [`StochasticError::InvalidConfig`] if `num_probes == 0`.
    pub fn frobenius_norm(&self, csr: &CsrMatrix<f64>) -> Result<f64, StochasticError> {
        let ncols = csr.ncols();
        let nrows = csr.nrows();
        let m = self.require_probes()?;

        if ncols == 0 || nrows == 0 {
            return Err(StochasticError::MatrixError(
                "matrix has zero dimension".to_string(),
            ));
        }

        let mut az = vec![0.0f64; nrows];
        let mut frob_sq_sum = 0.0f64;

        for i in 0..m {
            // Probes live in the column space: z has length ncols.
            let z = self.probe_vector(ncols, i);
            Self::matvec(csr, &z, &mut az);
            let sq: f64 = az.iter().map(|v| v * v).sum();
            frob_sq_sum += sq;
        }

        let frob_sq = frob_sq_sum / m as f64;
        Ok(frob_sq.sqrt())
    }

    /// Log-determinant via stochastic Lanczos quadrature:
    /// `log det(A) = tr(log A)` is estimated as
    /// `mean_i(‖z_i‖² · e₁ᵀ log(T_i) e₁)`, where `T_i` is the (≤ 20-step)
    /// Lanczos tridiagonalization started from `z_i / ‖z_i‖`.
    ///
    /// NOTE(review): the Lanczos three-term recurrence assumes `A` is
    /// symmetric positive definite — confirm callers guarantee symmetry;
    /// non-SPD inputs are only detected via non-positive Ritz values.
    ///
    /// # Errors
    /// [`StochasticError::NumericalFailure`] on a non-positive Ritz value or
    /// if every probe is degenerate; plus the usual shape/config errors.
    pub fn log_det(&self, csr: &CsrMatrix<f64>) -> Result<f64, StochasticError> {
        let n = self.require_square(csr)?;
        let m = self.require_probes()?;
        let lanczos_steps = 20_usize.min(n);

        let mut samples = Vec::with_capacity(m);

        for i in 0..m {
            let z = self.probe_vector(n, i);
            let z_norm_sq: f64 = z.iter().map(|v| v * v).sum();
            let z_norm = z_norm_sq.sqrt();

            // Skip (effectively) zero probes instead of dividing by ~0.
            if z_norm < 1e-300 {
                continue;
            }

            // q0 = z / ‖z‖ seeds the Lanczos recurrence.
            let mut q0: Vec<f64> = z.iter().map(|v| v / z_norm).collect();

            let mut alpha = Vec::with_capacity(lanczos_steps); // T diagonal
            let mut beta = Vec::with_capacity(lanczos_steps); // T off-diagonal
            let mut q_prev = vec![0.0f64; n];
            let mut r = vec![0.0f64; n];

            for _j in 0..lanczos_steps {
                Self::matvec(csr, &q0, &mut r);

                // alpha_j = q_jᵀ A q_j
                let a: f64 = q0.iter().zip(r.iter()).map(|(qi, ri)| qi * ri).sum();
                alpha.push(a);

                // Three-term recurrence: r = A q_j - alpha_j q_j - beta_{j-1} q_{j-1}
                for idx in 0..n {
                    r[idx] -= a * q0[idx];
                }
                if let Some(&b_prev) = beta.last() {
                    for idx in 0..n {
                        r[idx] -= b_prev * q_prev[idx];
                    }
                }

                let b: f64 = r.iter().map(|v| v * v).sum::<f64>().sqrt();
                beta.push(b);

                // Breakdown: the Krylov space is exhausted (exact invariant subspace).
                if b < 1e-14 {
                    break;
                }

                q_prev = q0.clone();
                q0 = r.iter().map(|v| v / b).collect();
            }

            let k = alpha.len();
            // Defensive: cannot normally happen since lanczos_steps >= 1.
            if k == 0 {
                continue;
            }

            // Quadrature: eigen-decompose T (k-1 off-diagonal entries).
            let (eigenvalues, evecs) = tridiagonal_evd(&alpha, &beta[..k.saturating_sub(1)]);

            if eigenvalues.iter().any(|&e| e <= 0.0) {
                return Err(StochasticError::NumericalFailure(
                    "matrix appears to not be positive definite (non-positive Ritz value encountered)"
                        .to_string(),
                ));
            }

            // e₁ᵀ log(T) e₁ = Σ_j (first component of eigvec_j)² · ln(λ_j)
            let e1_log_t_e1: f64 = eigenvalues
                .iter()
                .zip(evecs.iter())
                .map(|(&lam, evec)| evec[0] * evec[0] * lam.ln())
                .sum();

            samples.push(z_norm_sq * e1_log_t_e1);
        }

        if samples.is_empty() {
            return Err(StochasticError::NumericalFailure(
                "all probe vectors were degenerate".to_string(),
            ));
        }

        let (mean, _stderr) = mean_and_stderr(&samples);
        Ok(mean)
    }

    /// Deterministic probe #`probe_idx` of length `n`: each probe gets its
    /// own LCG seed derived from the configured base seed.
    fn probe_vector(&self, n: usize, probe_idx: usize) -> Vec<f64> {
        // NOTE(review): `probe_idx * 1_234_567` can only overflow (debug
        // panic) for astronomically large probe counts; harmless in practice.
        let seed = self.config.seed.wrapping_add(probe_idx as u64 * 1_234_567);
        match self.config.probe_type {
            ProbeType::Rademacher => lcg_rademacher(seed, n),
            ProbeType::Gaussian => lcg_gaussian(seed, n),
            ProbeType::Spherical => lcg_spherical(seed, n),
        }
    }

    /// Spherical probe with the same seeding scheme, ignoring
    /// `config.probe_type` (used by [`Self::trace_xtrace`]).
    fn probe_vector_spherical(&self, n: usize, probe_idx: usize) -> Vec<f64> {
        let seed = self.config.seed.wrapping_add(probe_idx as u64 * 1_234_567);
        lcg_spherical(seed, n)
    }

    /// Sparse CSR matrix-vector product `y = A x`; resizes `y` to `nrows`
    /// if needed and overwrites it.
    fn matvec(csr: &CsrMatrix<f64>, x: &[f64], y: &mut Vec<f64>) {
        let nrows = csr.nrows();
        if y.len() != nrows {
            y.resize(nrows, 0.0);
        }
        for i in 0..nrows {
            let start = csr.row_ptrs()[i];
            let end = csr.row_ptrs()[i + 1];
            let mut s = 0.0f64;
            for k in start..end {
                s += csr.values()[k] * x[csr.col_indices()[k]];
            }
            y[i] = s;
        }
    }

    /// Gram-Schmidt thin QR: orthonormalizes up to `k` columns of `a`
    /// (each of length `n`), dropping columns numerically dependent on the
    /// ones already accepted. May return fewer than `k` columns.
    fn thin_qr(a: &[Vec<f64>], n: usize, k: usize) -> Vec<Vec<f64>> {
        let mut q: Vec<Vec<f64>> = Vec::with_capacity(k);

        for col in a.iter().take(k) {
            let mut v = col.clone();

            // Project out each accepted basis vector in sequence (MGS pass).
            for qi in &q {
                let proj: f64 = v.iter().zip(qi.iter()).map(|(vi, qi_)| vi * qi_).sum();
                for (vi, qi_) in v.iter_mut().zip(qi.iter()) {
                    *vi -= proj * qi_;
                }
            }

            let nrm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
            // Dimension-scaled tolerance: discard near-zero residuals.
            if nrm > 1e-14 * (n as f64).sqrt() {
                q.push(v.into_iter().map(|x| x / nrm).collect());
            }
        }

        q
    }

    /// Validates that `csr` is square and non-empty; returns its order.
    fn require_square(&self, csr: &CsrMatrix<f64>) -> Result<usize, StochasticError> {
        if csr.nrows() != csr.ncols() {
            return Err(StochasticError::MatrixError(format!(
                "matrix must be square, got {}×{}",
                csr.nrows(),
                csr.ncols()
            )));
        }
        if csr.nrows() == 0 {
            return Err(StochasticError::MatrixError(
                "matrix has zero dimension".to_string(),
            ));
        }
        Ok(csr.nrows())
    }

    /// Validates `num_probes >= 1`; returns the probe count.
    fn require_probes(&self) -> Result<usize, StochasticError> {
        if self.config.num_probes == 0 {
            return Err(StochasticError::InvalidConfig(
                "num_probes must be >= 1".to_string(),
            ));
        }
        Ok(self.config.num_probes)
    }
}
683
/// Returns `(I - QQᵀ) z`: the component of `z` orthogonal to the span of
/// the (assumed orthonormal) columns in `q_cols`.
fn project_out(q_cols: &[Vec<f64>], z: &[f64]) -> Vec<f64> {
    let mut residual = z.to_vec();
    for basis in q_cols {
        // coeff = <residual, basis>; subtract that component.
        let coeff: f64 = residual
            .iter()
            .zip(basis.iter())
            .map(|(ri, bi)| ri * bi)
            .sum();
        for (ri, bi) in residual.iter_mut().zip(basis.iter()) {
            *ri -= coeff * bi;
        }
    }
    residual
}
699
/// Eigendecomposition of a symmetric tridiagonal matrix with diagonal
/// `alpha` (length k) and off-diagonal `beta` (length k-1), using
/// implicitly shifted Givens sweeps (QL/QR-style iteration with a
/// Wilkinson-like shift).
///
/// Returns `(eigenvalues, eigenvectors)` where `eigenvectors[j]` is the
/// unit eigenvector paired with `eigenvalues[j]`. Eigenvalues are NOT
/// sorted.
/// NOTE(review): deflation only shrinks from the bottom and each sweep runs
/// the full 0..m-1 range — adequate for the small k (≤ 20) used by
/// `log_det`, but confirm accuracy before reusing elsewhere.
fn tridiagonal_evd(alpha: &[f64], beta: &[f64]) -> (Vec<f64>, Vec<Vec<f64>>) {
    let k = alpha.len();
    if k == 0 {
        return (vec![], vec![]);
    }
    // 1x1 matrix: the single diagonal entry is the eigenvalue.
    if k == 1 {
        return (vec![alpha[0]], vec![vec![1.0]]);
    }

    // Working copies: d = diagonal, e = off-diagonal (padded to length k).
    let mut d = alpha.to_vec();
    let mut e = vec![0.0f64; k];
    for i in 0..beta.len().min(k - 1) {
        e[i] = beta[i];
    }

    // z accumulates the rotations (row-major k×k), starting from identity.
    let mut z = vec![0.0f64; k * k];
    for i in 0..k {
        z[i * k + i] = 1.0;
    }

    let max_iter = 30 * k;
    let mut m = k; // active (un-deflated) leading block size

    'outer: for _ in 0..max_iter {
        // Deflate converged trailing entries (relative tolerance test).
        while m > 1 && e[m - 2].abs() < 1e-14 * (d[m - 2].abs() + d[m - 1].abs()) {
            m -= 1;
            if m == 1 {
                break 'outer;
            }
        }
        if m <= 1 {
            break;
        }

        // Wilkinson-style shift from the trailing 2x2 block.
        let a = d[m - 2];
        let b = e[m - 2];
        let c = d[m - 1];
        let delta = (a - c) / 2.0;
        let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
        let shift = c - sign_delta * b * b / (delta.abs() + (delta * delta + b * b).sqrt());

        // Chase the bulge down with Givens rotations.
        let mut x = d[0] - shift;
        let mut z_val = e[0];

        for i in 0..m - 1 {
            let (c_rot, s_rot) = givens_cs(x, z_val);

            let w = c_rot * x + s_rot * z_val;
            let _ = w; // computed but unused; kept to mirror the derivation

            let d_i = d[i];
            let d_i1 = d[i + 1];
            let e_i = e[i];

            // Apply the rotation G d Gᵀ to the 2x2 subblock (i, i+1).
            d[i] = c_rot * c_rot * d_i + 2.0 * c_rot * s_rot * e_i + s_rot * s_rot * d_i1;
            d[i + 1] = s_rot * s_rot * d_i - 2.0 * c_rot * s_rot * e_i + c_rot * c_rot * d_i1;
            e[i] = c_rot * s_rot * (d_i1 - d_i) + (c_rot * c_rot - s_rot * s_rot) * e_i;

            if i > 0 {
                e[i - 1] = c_rot * e[i - 1] + s_rot * z_val;
            }

            x = e[i];
            if i + 1 < m - 1 {
                // New bulge element to be chased in the next iteration.
                z_val = s_rot * e[i + 1];
                e[i + 1] = c_rot * e[i + 1];
            }

            // Accumulate the rotation into the eigenvector matrix.
            for row in 0..k {
                let zi = z[row * k + i];
                let zi1 = z[row * k + i + 1];
                z[row * k + i] = c_rot * zi + s_rot * zi1;
                z[row * k + i + 1] = -s_rot * zi + c_rot * zi1;
            }
        }
    }

    // Extract eigenvectors as columns of z.
    let eigenvectors: Vec<Vec<f64>> = (0..k)
        .map(|j| (0..k).map(|i| z[i * k + j]).collect())
        .collect();

    (d, eigenvectors)
}
802
/// Givens rotation coefficients `(c, s)` that zero out `b` against `a`.
///
/// Branches on which magnitude dominates so the intermediate ratio stays
/// ≤ 1 in magnitude, avoiding overflow in `ratio * ratio`.
#[inline]
fn givens_cs(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        (1.0, 0.0)
    } else if a.abs() < b.abs() {
        let ratio = -a / b;
        let s = (1.0 + ratio * ratio).sqrt().recip();
        (s * ratio, s)
    } else {
        let ratio = -b / a;
        let c = (1.0 + ratio * ratio).sqrt().recip();
        (c, c * ratio)
    }
}
820
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;

    /// n×n identity matrix in CSR form.
    fn identity_csr(n: usize) -> CsrMatrix<f64> {
        let values = vec![1.0f64; n];
        let col_indices: Vec<usize> = (0..n).collect();
        let row_ptrs: Vec<usize> = (0..=n).collect();
        CsrMatrix::new(n, n, row_ptrs, col_indices, values).expect("valid identity CSR")
    }

    /// Diagonal matrix with the given entries, in CSR form.
    fn diag_csr(entries: &[f64]) -> CsrMatrix<f64> {
        let n = entries.len();
        let values = entries.to_vec();
        let col_indices: Vec<usize> = (0..n).collect();
        let row_ptrs: Vec<usize> = (0..=n).collect();
        CsrMatrix::new(n, n, row_ptrs, col_indices, values).expect("valid diagonal CSR")
    }

    /// 1-D Laplacian (tridiagonal 2 on the diagonal, -1 off-diagonal) in CSR.
    fn laplacian_1d_csr(n: usize) -> CsrMatrix<f64> {
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_ptrs = vec![0usize];

        for i in 0..n {
            if i > 0 {
                values.push(-1.0);
                col_indices.push(i - 1);
            }
            values.push(2.0);
            col_indices.push(i);
            if i + 1 < n {
                values.push(-1.0);
                col_indices.push(i + 1);
            }
            row_ptrs.push(values.len());
        }
        CsrMatrix::new(n, n, row_ptrs, col_indices, values).expect("valid 1-D Laplacian CSR")
    }

    /// Defaults should match the documented values.
    #[test]
    fn test_stochastic_config_default() {
        let cfg = StochasticConfig::default();
        assert_eq!(cfg.num_probes, 30);
        assert_eq!(cfg.seed, 42);
        assert!(matches!(cfg.probe_type, ProbeType::Rademacher));
        assert!((cfg.confidence - 0.95).abs() < 1e-12);
    }

    /// tr(I_3) = 3; stochastic estimate should land nearby.
    #[test]
    fn test_hutchinson_identity() {
        let eye = identity_csr(3);
        let est = StochasticEstimator::with_default();
        let result = est.trace_hutchinson(&eye).expect("trace_hutchinson failed");
        assert!(
            (result.estimate - 3.0).abs() < 0.5,
            "estimate {} far from 3.0",
            result.estimate
        );
        assert_eq!(result.n_probes_used, 30);
    }

    /// tr(diag(1,2,3)) = 6 with a looser tolerance at 100 probes.
    #[test]
    fn test_hutchinson_diagonal() {
        let d = diag_csr(&[1.0, 2.0, 3.0]);
        let cfg = StochasticConfig {
            num_probes: 100,
            ..Default::default()
        };
        let est = StochasticEstimator::new(cfg);
        let result = est.trace_hutchinson(&d).expect("trace_hutchinson failed");
        assert!(
            (result.estimate - 6.0).abs() < 1.0,
            "estimate {} far from 6.0",
            result.estimate
        );
    }

    /// tr of the 10×10 1-D Laplacian is 10·2 = 20.
    #[test]
    fn test_hutchinson_sparse_laplacian() {
        let lap = laplacian_1d_csr(10);
        let cfg = StochasticConfig {
            num_probes: 200,
            ..Default::default()
        };
        let est = StochasticEstimator::new(cfg);
        let result = est.trace_hutchinson(&lap).expect("trace_hutchinson failed");
        assert!(
            (result.estimate - 20.0).abs() < 3.0,
            "estimate {} far from 20.0",
            result.estimate
        );
    }

    /// Hutch++ on diag(1..20): true trace is Σ 1..20 = 210.
    #[test]
    fn test_hutch_plusplus_accuracy() {
        let entries: Vec<f64> = (1..=20).map(|x| x as f64).collect();
        let d = diag_csr(&entries);
        let true_trace = 210.0f64;

        let cfg = StochasticConfig {
            num_probes: 30,
            seed: 7,
            ..Default::default()
        };
        let est = StochasticEstimator::new(cfg.clone());

        let hh_result = est.trace_hutch_plus_plus(&d).expect("hutch++ failed");
        let hutch_result = est.trace_hutchinson(&d).expect("hutchinson failed");

        let hh_err = (hh_result.estimate - true_trace).abs();
        let hutch_err = (hutch_result.estimate - true_trace).abs();

        // Only the absolute accuracy of Hutch++ is asserted; the plain
        // Hutchinson error is computed but not compared (too noisy).
        assert!(
            hh_err < true_trace * 0.30,
            "Hutch++ error {hh_err} too large (true={true_trace})"
        );
        let _ = hutch_err;
    }

    /// diag(I_5) should estimate to ~1.0 in every position.
    #[test]
    fn test_diagonal_estimator() {
        let eye = identity_csr(5);
        let cfg = StochasticConfig {
            num_probes: 100,
            ..Default::default()
        };
        let est = StochasticEstimator::new(cfg);
        let result = est.diagonal(&eye).expect("diagonal failed");
        assert_eq!(result.diagonal.len(), 5);
        for (i, &d) in result.diagonal.iter().enumerate() {
            assert!(
                (d - 1.0).abs() < 0.3,
                "diagonal[{i}] = {d} not close to 1.0"
            );
        }
    }

    /// ‖I_n‖_F = sqrt(n).
    #[test]
    fn test_frobenius_norm() {
        let n = 9usize;
        let eye = identity_csr(n);
        let cfg = StochasticConfig {
            num_probes: 50,
            ..Default::default()
        };
        let est = StochasticEstimator::new(cfg);
        let frob = est.frobenius_norm(&eye).expect("frobenius_norm failed");
        let expected = (n as f64).sqrt();
        assert!(
            (frob - expected).abs() < expected * 0.10,
            "frobenius estimate {frob} not within 10% of {expected}"
        );
    }

    /// log det(diag(2,3,5)) = ln 2 + ln 3 + ln 5 (exact for a diagonal SPD matrix).
    #[test]
    fn test_log_det_spd() {
        let d = diag_csr(&[2.0, 3.0, 5.0]);
        let expected = 2.0f64.ln() + 3.0f64.ln() + 5.0f64.ln();
        let cfg = StochasticConfig {
            num_probes: 100,
            ..Default::default()
        };
        let est = StochasticEstimator::new(cfg);
        let result = est.log_det(&d).expect("log_det failed");
        assert!(result > 0.0, "log_det should be positive, got {result}");
        assert!(
            (result - expected).abs() < expected * 0.20,
            "log_det estimate {result} far from {expected}"
        );
    }
}