1use std::fmt;
31
32use crate::error::PcError;
33use crate::matrix::Matrix;
34
/// Errors reported by [`GolubKahanSvd::compute`].
#[derive(Debug)]
pub enum SvdError {
    /// The implicit QR iteration exhausted its iteration budget.
    Convergence {
        /// Number of singular values being computed (`min(rows, cols)`).
        size: usize,
        /// The iteration budget that was exhausted.
        iterations: usize,
    },
    /// The input matrix was rejected before any work was done.
    InvalidInput {
        /// Human-readable explanation of why the input was rejected.
        reason: String,
    },
}

impl fmt::Display for SvdError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidInput { reason } => write!(f, "SVD invalid input: {reason}"),
            Self::Convergence { size, iterations } => write!(
                f,
                "SVD failed to converge for matrix of size {size} \
                 after {iterations} iterations"
            ),
        }
    }
}

impl std::error::Error for SvdError {}
82
83impl From<SvdError> for PcError {
84 fn from(e: SvdError) -> Self {
85 PcError::ConfigValidation(e.to_string())
86 }
87}
88
/// Configuration of the Golub–Kahan SVD solver: Householder
/// bidiagonalization followed by implicit-shift QR on the bidiagonal.
#[derive(Clone, Debug)]
pub struct GolubKahanSvd {
    /// Tolerance forwarded to the QR iteration to decide when bidiagonal
    /// entries are negligible.
    pub tol: f64,
    /// Iteration budget multiplier: `compute` allows
    /// `max_iter_factor * k * k` iterations, with `k = min(rows, cols)`.
    pub max_iter_factor: usize,
}
115
116impl GolubKahanSvd {
117 pub fn new() -> Self {
130 Self {
131 tol: 1e-14,
132 max_iter_factor: 30,
133 }
134 }
135
136 pub fn with_tolerance(mut self, tol: f64) -> Self {
151 self.tol = tol;
152 self
153 }
154
155 pub fn with_max_iter_factor(mut self, factor: usize) -> Self {
172 self.max_iter_factor = factor;
173 self
174 }
175
176 pub fn compute(&self, a: &Matrix) -> Result<(Matrix, Vec<f64>, Matrix), SvdError> {
205 for &val in &a.data {
207 if val.is_nan() || val.is_infinite() {
208 return Err(SvdError::InvalidInput {
209 reason: "matrix contains NaN or Inf".to_string(),
210 });
211 }
212 }
213
214 let m = a.rows;
215 let n = a.cols;
216
217 if m == 0 || n == 0 {
219 return Ok((Matrix::zeros(m, 0), Vec::new(), Matrix::zeros(n, 0)));
220 }
221
222 let transposed = m < n;
224 let (work_m, work_n, work_data) = if transposed {
225 let mut t = vec![0.0; m * n];
226 for r in 0..m {
227 for c in 0..n {
228 t[c * m + r] = a.data[r * n + c];
229 }
230 }
231 (n, m, t)
232 } else {
233 (m, n, a.data.clone())
234 };
235
236 let k = work_n;
238
239 if k == 1 {
241 let norm: f64 = work_data.iter().map(|&x| x * x).sum::<f64>().sqrt();
243 if norm < self.tol {
244 let u_mat = make_identity(work_m, 1);
245 let v_mat = make_identity(1, 1);
246 return Self::finalize(u_mat, vec![0.0], v_mat, transposed);
247 }
248 let sign = if work_data[0] >= 0.0 { 1.0 } else { -1.0 };
249 let mut u_data = vec![0.0; work_m];
250 for i in 0..work_m {
251 u_data[i] = work_data[i] * sign / norm;
252 }
253 let u_mat = Matrix {
254 data: u_data,
255 rows: work_m,
256 cols: 1,
257 };
258 let v_mat = Matrix {
259 data: vec![sign],
260 rows: 1,
261 cols: 1,
262 };
263 return Self::finalize(u_mat, vec![norm], v_mat, transposed);
264 }
265
266 let mut w = work_data;
269 let mut u_acc = make_identity(work_m, work_m);
270 let mut v_acc = make_identity(work_n, work_n);
271
272 let mut diag = vec![0.0; k];
273 let mut superdiag = vec![0.0; k.saturating_sub(1)];
274
275 householder_bidiag(
276 &mut w,
277 work_m,
278 work_n,
279 &mut u_acc,
280 &mut v_acc,
281 &mut diag,
282 &mut superdiag,
283 );
284
285 let max_iter = self.max_iter_factor * k * k;
287 implicit_qr_svd(
288 &mut diag,
289 &mut superdiag,
290 &mut u_acc,
291 &mut v_acc,
292 work_m,
293 work_n,
294 k,
295 self.tol,
296 max_iter,
297 )?;
298
299 for (i, d) in diag.iter_mut().enumerate().take(k) {
301 if *d < 0.0 {
302 *d = -*d;
303 for r in 0..work_m {
305 u_acc.data[r * work_m + i] = -u_acc.data[r * work_m + i];
306 }
307 }
308 }
309
310 let mut indices: Vec<usize> = (0..k).collect();
312 indices.sort_by(|&a_idx, &b_idx| {
313 diag[b_idx]
314 .partial_cmp(&diag[a_idx])
315 .unwrap_or(std::cmp::Ordering::Equal)
316 });
317
318 let sorted_s: Vec<f64> = indices.iter().map(|&i| diag[i]).collect();
319
320 let mut u_thin = Matrix::zeros(work_m, k);
322 for (new_col, &old_col) in indices.iter().enumerate() {
323 for r in 0..work_m {
324 u_thin.data[r * k + new_col] = u_acc.data[r * work_m + old_col];
325 }
326 }
327
328 let mut v_thin = Matrix::zeros(work_n, k);
330 for (new_col, &old_col) in indices.iter().enumerate() {
331 for r in 0..work_n {
332 v_thin.data[r * k + new_col] = v_acc.data[r * work_n + old_col];
333 }
334 }
335
336 Self::finalize(u_thin, sorted_s, v_thin, transposed)
337 }
338
339 fn finalize(
341 u: Matrix,
342 s: Vec<f64>,
343 v: Matrix,
344 transposed: bool,
345 ) -> Result<(Matrix, Vec<f64>, Matrix), SvdError> {
346 if transposed {
347 Ok((v, s, u))
348 } else {
349 Ok((u, s, v))
350 }
351 }
352}
353
354impl Default for GolubKahanSvd {
355 fn default() -> Self {
356 Self::new()
357 }
358}
359
360fn make_identity(rows: usize, cols: usize) -> Matrix {
364 let mut data = vec![0.0; rows * cols];
365 let k = rows.min(cols);
366 for i in 0..k {
367 data[i * cols + i] = 1.0;
368 }
369 Matrix { data, rows, cols }
370}
371
372fn householder_bidiag(
388 w: &mut [f64],
389 m: usize,
390 n: usize,
391 u_acc: &mut Matrix,
392 v_acc: &mut Matrix,
393 diag: &mut [f64],
394 superdiag: &mut [f64],
395) {
396 for j in 0..n {
397 {
399 let mut col = vec![0.0; m - j];
400 for i in j..m {
401 col[i - j] = w[i * n + j];
402 }
403 let (v_house, beta) = householder_vector(&col);
404 if beta != 0.0 {
405 apply_householder_left(w, m, n, j, j, &v_house, beta);
407 apply_householder_right_to_matrix(u_acc, m, m, j, &v_house, beta);
409 }
410 }
411 diag[j] = w[j * n + j];
412
413 if j + 2 <= n {
415 let start = j + 1;
416 let mut row = vec![0.0; n - start];
417 for c in start..n {
418 row[c - start] = w[j * n + c];
419 }
420 let (v_house, beta) = householder_vector(&row);
421 if beta != 0.0 {
422 apply_householder_right(w, m, n, j, start, &v_house, beta);
424 apply_householder_right_to_matrix(v_acc, n, n, start, &v_house, beta);
426 }
427 if j < n - 1 {
428 superdiag[j] = w[j * n + j + 1];
429 }
430 } else if j < n - 1 {
431 superdiag[j] = w[j * n + j + 1];
432 }
433 }
434}
435
/// Computes a Householder vector `v` (normalized so `v[0] == 1`) and
/// coefficient `beta` such that `(I - beta * v * v^T) * x = ||x|| * e_1`
/// (Golub & Van Loan's `house` formulation).
///
/// `beta == 0` means no reflection is needed: `x` has length <= 1, or
/// `x[1..]` is entirely zero. Callers skip the update in that case, and `v`
/// is then not meaningful beyond `v[0] == 1`.
///
/// The computation is scaled by `max |x_i|`. `v` and `beta` are invariant
/// under scaling of `x`, so this changes nothing mathematically, but the
/// squares can no longer overflow or underflow. Only an exactly-zero tail
/// is treated as "no reflection".
fn householder_vector(x: &[f64]) -> (Vec<f64>, f64) {
    let len = x.len();
    if len == 0 {
        return (Vec::new(), 0.0);
    }
    let mut v = vec![0.0; len];
    v[0] = 1.0;
    if len == 1 {
        return (v, 0.0);
    }
    v[1..].copy_from_slice(&x[1..]);

    let scale = x.iter().fold(0.0_f64, |acc, &xi| acc.max(xi.abs()));
    if scale == 0.0 {
        return (v, 0.0);
    }

    // Work with the scaled vector x / scale, whose entries lie in [-1, 1].
    let x0 = x[0] / scale;
    let sigma: f64 = x[1..]
        .iter()
        .map(|&xi| {
            let t = xi / scale;
            t * t
        })
        .sum();
    if sigma == 0.0 {
        return (v, 0.0);
    }

    let norm_x = (x0 * x0 + sigma).sqrt();
    // For x0 > 0 use the cancellation-free form of x0 - ||x||.
    let v0 = if x0 <= 0.0 {
        x0 - norm_x
    } else {
        -sigma / (x0 + norm_x)
    };
    let v0_sq = v0 * v0;
    let beta = 2.0 * v0_sq / (sigma + v0_sq);
    if beta == 0.0 {
        // v0 underflowed: the tail is negligible relative to x0.
        return (v, 0.0);
    }

    // Normalize so that v[0] == 1: v_i = (x_i / scale) / v0.
    for vi in v[1..].iter_mut() {
        *vi = (*vi / scale) / v0;
    }
    (v, beta)
}
483
/// Applies the reflector `H = I - beta * v * v^T` from the left to the
/// sub-block of the row-major matrix `w` (row stride `n`) that starts at
/// (`row_start`, `col_start`) and spans `v.len()` rows and all columns
/// from `col_start` to `n`.
///
/// This computes `p = v^T * W_sub` first and then `W_sub -= beta * v * p`.
fn apply_householder_left(
    w: &mut [f64],
    _m: usize,
    n: usize,
    row_start: usize,
    col_start: usize,
    v: &[f64],
    beta: f64,
) {
    let width = n - col_start;

    let mut proj = vec![0.0; width];
    for (offset, &vi) in v.iter().enumerate() {
        let base = (row_start + offset) * n + col_start;
        for (acc, &entry) in proj.iter_mut().zip(&w[base..base + width]) {
            *acc += vi * entry;
        }
    }

    for (offset, &vi) in v.iter().enumerate() {
        let base = (row_start + offset) * n + col_start;
        for (entry, &pc) in w[base..base + width].iter_mut().zip(&proj) {
            *entry -= beta * vi * pc;
        }
    }
}
513
/// Applies the reflector `H = I - beta * v * v^T` from the right to rows
/// `row_start..m` of the row-major matrix `w` (row stride `n`), touching
/// columns `col_start .. col_start + v.len()`.
///
/// For each row, this computes `dot = row_segment . v` and then
/// `row_segment -= beta * dot * v`.
fn apply_householder_right(
    w: &mut [f64],
    m: usize,
    n: usize,
    row_start: usize,
    col_start: usize,
    v: &[f64],
    beta: f64,
) {
    let width = v.len();
    let row_count = m - row_start;
    for row in (0..row_count).map(|ri| row_start + ri) {
        let start = row * n + col_start;
        let segment = &mut w[start..start + width];
        let dot = segment
            .iter()
            .zip(v)
            .fold(0.0, |acc, (&entry, &vi)| acc + entry * vi);
        for (entry, &vi) in segment.iter_mut().zip(v) {
            *entry -= beta * dot * vi;
        }
    }
}
539
540fn apply_householder_right_to_matrix(
545 acc: &mut Matrix,
546 rows: usize,
547 cols: usize,
548 col_start: usize,
549 v: &[f64],
550 beta: f64,
551) {
552 let v_len = v.len();
553 for r in 0..rows {
554 let mut dot = 0.0;
555 for (vi_idx, &vi) in v.iter().enumerate().take(v_len) {
556 dot += acc.data[r * cols + col_start + vi_idx] * vi;
557 }
558 for (vi_idx, &vi) in v.iter().enumerate().take(v_len) {
559 acc.data[r * cols + col_start + vi_idx] -= beta * dot * vi;
560 }
561 }
562}
563
/// Computes a plane rotation `(c, s)` that annihilates `b` against `a`:
/// `-s * a + c * b == 0`, and `c * a + s * b` has magnitude
/// `sqrt(a^2 + b^2)`.
///
/// When `b != 0` the result is `c = a / r`, `s = b / r` with
/// `r = hypot(a, b) > 0`. When `b == 0` it is the identity `(1, 0)`.
/// Only ratios of `a` and `b` are formed, so arbitrarily tiny or large
/// inputs are handled without over/underflow, and only exact zeros take the
/// shortcut branches. Previously an absolute `1e-300` cutoff left tiny
/// nonzero entries un-annihilated.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        return (1.0, 0.0);
    }
    if a == 0.0 {
        return (0.0, b.signum());
    }
    if b.abs() > a.abs() {
        let tau = a / b;
        let s = (1.0 + tau * tau).sqrt().recip() * b.signum();
        let c = s * tau;
        (c, s)
    } else {
        let tau = b / a;
        let c = (1.0 + tau * tau).sqrt().recip() * a.signum();
        let s = c * tau;
        (c, s)
    }
}
594
595fn apply_givens_cols(
598 mat: &mut Matrix,
599 rows: usize,
600 stride: usize,
601 i: usize,
602 j: usize,
603 c: f64,
604 s: f64,
605) {
606 for r in 0..rows {
607 let a = mat.data[r * stride + i];
608 let b = mat.data[r * stride + j];
609 mat.data[r * stride + i] = c * a + s * b;
610 mat.data[r * stride + j] = -s * a + c * b;
611 }
612}
613
614#[allow(clippy::too_many_arguments)]
624fn implicit_qr_svd(
625 diag: &mut [f64],
626 superdiag: &mut [f64],
627 u_acc: &mut Matrix,
628 v_acc: &mut Matrix,
629 u_rows: usize,
630 v_rows: usize,
631 k: usize,
632 tol: f64,
633 max_iter: usize,
634) -> Result<(), SvdError> {
635 if k <= 1 {
636 return Ok(());
637 }
638
639 let mut iter_count = 0usize;
640
641 loop {
642 let mut q = 0usize;
645 while q < k - 1 {
646 let idx = k - 2 - q;
647 let thresh = tol * (diag[idx].abs() + diag[idx + 1].abs());
648 if superdiag[idx].abs() <= thresh.max(tol * 1e-2) {
649 superdiag[idx] = 0.0;
650 q += 1;
651 } else {
652 break;
653 }
654 }
655
656 if q >= k - 1 {
657 break;
659 }
660
661 let block_end = k - q; let mut p = block_end - 1;
665 while p > 0 {
666 let idx = p - 1;
667 let thresh = tol * (diag[idx].abs() + diag[idx + 1].abs());
668 if superdiag[idx].abs() <= thresh.max(tol * 1e-2) {
669 superdiag[idx] = 0.0;
670 break;
671 }
672 p -= 1;
673 }
674
675 let block_size = block_end - p;
676 if block_size <= 1 {
677 continue;
678 }
679
680 let mut found_zero_diag = false;
683 for i in p..block_end {
684 if diag[i].abs() < tol * 1e-2 {
685 if i < block_end - 1 && superdiag[i].abs() > 0.0 {
687 zero_superdiag_row(diag, superdiag, u_acc, u_rows, i, block_end);
688 } else if i > p && superdiag[i - 1].abs() > 0.0 {
689 zero_superdiag_col(diag, superdiag, v_acc, v_rows, i, p);
690 }
691 found_zero_diag = true;
692 break;
693 }
694 }
695 if found_zero_diag {
696 iter_count += 1;
697 if iter_count > max_iter {
698 return Err(SvdError::Convergence {
699 size: k,
700 iterations: max_iter,
701 });
702 }
703 continue;
704 }
705
706 let n1 = block_end - 1;
708 let n2 = block_end - 2;
709 let d_n1 = diag[n1];
710 let d_n2 = diag[n2];
711 let e_n2 = superdiag[n2];
712 let e_n3_sq = if n2 > p {
716 superdiag[n2 - 1] * superdiag[n2 - 1]
717 } else {
718 0.0
719 };
720 let t11 = d_n2 * d_n2 + e_n3_sq;
721 let t12 = d_n2 * e_n2;
722 let t22 = d_n1 * d_n1 + e_n2 * e_n2;
723
724 let shift = wilkinson_shift(t11, t12, t22);
725
726 golub_kahan_step(
728 diag, superdiag, u_acc, v_acc, u_rows, v_rows, p, block_end, shift,
729 );
730
731 iter_count += 1;
732 if iter_count > max_iter {
733 return Err(SvdError::Convergence {
734 size: k,
735 iterations: max_iter,
736 });
737 }
738 }
739
740 Ok(())
741}
742
/// Returns the eigenvalue of the symmetric 2x2 matrix `[[a, b], [b, d]]`
/// that is closest to `d` (the Wilkinson shift).
///
/// The eigenvalues are `d + delta +/- hypot(delta, b)` with
/// `delta = (a - d) / 2`. This form writes the closer one as
/// `d - b * (b / (delta + sign(delta) * hypot(delta, b)))`. The ratio has
/// magnitude <= 1, so nothing is squared and the shift stays finite even when
/// `b * b` would overflow. Previously such inputs produced a NaN shift.
fn wilkinson_shift(a: f64, b: f64, d: f64) -> f64 {
    if b == 0.0 {
        // Already diagonal: d itself is an eigenvalue.
        return d;
    }
    let delta = (a - d) * 0.5;
    let r = delta.hypot(b);
    // Same-signed sum avoids cancellation; |denom| >= r > 0 since b != 0.
    let denom = if delta >= 0.0 { delta + r } else { delta - r };
    d - b * (b / denom)
}
755
/// Performs one implicit-shift Golub–Kahan QR sweep on the unreduced block
/// `diag[p..block_end]` of the upper-bidiagonal matrix `B`.
///
/// The sweep starts from the first column of `B^T B - shift * I`, restricted
/// to `(y, z)`. It then chases the resulting bulge down the block by
/// alternating two rotations per index `i`:
/// - a right (column) rotation on columns `i, i + 1`, accumulated into `v_acc`;
/// - a left (row) rotation on rows `i, i + 1`, accumulated into `u_acc`.
///
/// The caller (`implicit_qr_svd`) passes a Wilkinson shift of the trailing
/// 2x2 of `B^T B`.
#[allow(clippy::too_many_arguments)]
fn golub_kahan_step(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    v_acc: &mut Matrix,
    u_rows: usize,
    v_rows: usize,
    p: usize,
    block_end: usize,
    shift: f64,
) {
    // First column of (B^T B - shift * I) restricted to rows p, p + 1.
    let mut y = diag[p] * diag[p] - shift;
    let mut z = diag[p] * superdiag[p];

    for i in p..block_end - 1 {
        // Right rotation on columns i, i + 1 chosen to zero z against y.
        // For i > p this removes the bulge at (i - 1, i + 1).
        let (c, s) = givens_rotation(y, z);
        if i > p {
            superdiag[i - 1] = c * superdiag[i - 1] + s * z;
        }
        let old_d_i = diag[i];
        let old_e_i = superdiag[i];
        diag[i] = c * old_d_i + s * old_e_i;
        superdiag[i] = -s * old_d_i + c * old_e_i;
        // The rotation creates a new bulge at (i + 1, i).
        let old_d_i1 = diag[i + 1];
        z = s * old_d_i1;
        diag[i + 1] = c * old_d_i1;

        apply_givens_cols(v_acc, v_rows, v_rows, i, i + 1, c, s);

        // Left rotation on rows i, i + 1 removes the bulge at (i + 1, i).
        let (c, s) = givens_rotation(diag[i], z);
        diag[i] = c * diag[i] + s * z;
        let old_e_i = superdiag[i];
        let old_d_i1 = diag[i + 1];
        superdiag[i] = c * old_e_i + s * old_d_i1;
        diag[i + 1] = -s * old_e_i + c * old_d_i1;
        if i + 1 < block_end - 1 {
            // Not at the end yet: the rotation creates a bulge at (i, i + 2).
            let old_e_i1 = superdiag[i + 1];
            z = s * old_e_i1;
            superdiag[i + 1] = c * old_e_i1;
        }
        // Next right rotation works on (superdiag[i], bulge z).
        y = superdiag[i];

        apply_givens_cols(u_acc, u_rows, u_rows, i, i + 1, c, s);
    }
}
811
812fn zero_superdiag_row(
820 diag: &mut [f64],
821 superdiag: &mut [f64],
822 u_acc: &mut Matrix,
823 u_rows: usize,
824 zero_idx: usize,
825 block_end: usize,
826) {
827 let mut bulge = superdiag[zero_idx];
828 superdiag[zero_idx] = 0.0;
829
830 for j in zero_idx..block_end - 1 {
831 let (c, s) = givens_rotation(diag[j + 1], bulge);
832 diag[j + 1] = c * diag[j + 1] + s * bulge;
833 if j + 1 < block_end - 1 {
835 let old_e = superdiag[j + 1];
836 superdiag[j + 1] = c * old_e;
837 bulge = -s * old_e;
838 apply_givens_cols(u_acc, u_rows, u_rows, j + 1, zero_idx, c, s);
839 if bulge.abs() < 1e-300 {
840 break;
841 }
842 } else {
843 apply_givens_cols(u_acc, u_rows, u_rows, j + 1, zero_idx, c, s);
844 }
845 }
846}
847
848fn zero_superdiag_col(
856 diag: &mut [f64],
857 superdiag: &mut [f64],
858 v_acc: &mut Matrix,
859 v_rows: usize,
860 zero_idx: usize,
861 block_start: usize,
862) {
863 let mut bulge = superdiag[zero_idx - 1];
864 superdiag[zero_idx - 1] = 0.0;
865
866 for j in (block_start..zero_idx).rev() {
867 let (c, s) = givens_rotation(diag[j], bulge);
868 diag[j] = c * diag[j] + s * bulge;
869 if j > block_start {
871 let old_e = superdiag[j - 1];
872 superdiag[j - 1] = c * old_e;
873 bulge = -s * old_e;
874 apply_givens_cols(v_acc, v_rows, v_rows, j, zero_idx, c, s);
875 if bulge.abs() < 1e-300 {
876 break;
877 }
878 } else {
879 apply_givens_cols(v_acc, v_rows, v_rows, j, zero_idx, c, s);
880 }
881 }
882}
883
// Unit tests for the Golub–Kahan SVD: builder/config, error type, input
// validation, and numerical checks (reconstruction, orthonormality, ordering).
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Builder and configuration ----

    #[test]
    fn test_new_returns_default_parameters() {
        let svd = GolubKahanSvd::new();
        assert!((svd.tol - 1e-14).abs() < f64::EPSILON);
        assert_eq!(svd.max_iter_factor, 30);
    }

    #[test]
    fn test_with_tolerance_sets_custom_tol() {
        let svd = GolubKahanSvd::new().with_tolerance(1e-8);
        assert!((svd.tol - 1e-8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_with_max_iter_factor_sets_custom_factor() {
        let svd = GolubKahanSvd::new().with_max_iter_factor(50);
        assert_eq!(svd.max_iter_factor, 50);
    }

    #[test]
    fn test_default_trait_matches_new() {
        let a = GolubKahanSvd::new();
        let b = GolubKahanSvd::default();
        assert!((a.tol - b.tol).abs() < f64::EPSILON);
        assert_eq!(a.max_iter_factor, b.max_iter_factor);
    }

    // ---- Error type ----

    #[test]
    fn test_svd_error_display() {
        let err = SvdError::Convergence {
            size: 10,
            iterations: 300,
        };
        let msg = format!("{err}");
        assert!(msg.contains("10"));
        assert!(msg.contains("300"));
    }

    #[test]
    fn test_svd_error_converts_to_pc_error() {
        let err = SvdError::Convergence {
            size: 5,
            iterations: 150,
        };
        let pc_err: crate::error::PcError = err.into();
        assert!(matches!(pc_err, crate::error::PcError::ConfigValidation(_)));
    }

    // ---- Input validation and empty input ----

    #[test]
    fn test_empty_matrix() {
        let a = Matrix::zeros(0, 0);
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(u.rows, 0);
        assert_eq!(u.cols, 0);
        assert!(s.is_empty());
        assert_eq!(v.rows, 0);
        assert_eq!(v.cols, 0);
    }

    #[test]
    fn test_nan_input_returns_error() {
        let a = Matrix {
            data: vec![1.0, f64::NAN, 3.0, 4.0],
            rows: 2,
            cols: 2,
        };
        let result = GolubKahanSvd::new().compute(&a);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SvdError::InvalidInput { .. }));
    }

    #[test]
    fn test_inf_input_returns_error() {
        let a = Matrix {
            data: vec![1.0, f64::INFINITY, 3.0, 4.0],
            rows: 2,
            cols: 2,
        };
        let result = GolubKahanSvd::new().compute(&a);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SvdError::InvalidInput { .. }));
    }

    #[test]
    fn test_neg_inf_input_returns_error() {
        let a = Matrix {
            data: vec![f64::NEG_INFINITY, 2.0, 3.0, 4.0],
            rows: 2,
            cols: 2,
        };
        let result = GolubKahanSvd::new().compute(&a);
        assert!(result.is_err());
    }

    // ---- Reference helpers: naive dense routines independent of the solver ----

    // Naive row-major product `a * b`.
    fn mat_mul_raw(a: &Matrix, b: &Matrix) -> Matrix {
        assert_eq!(a.cols, b.rows);
        let mut c = Matrix::zeros(a.rows, b.cols);
        for i in 0..a.rows {
            for k in 0..a.cols {
                let aik = a.data[i * a.cols + k];
                for j in 0..b.cols {
                    c.data[i * c.cols + j] += aik * b.data[k * b.cols + j];
                }
            }
        }
        c
    }

    // Naive transpose.
    fn transpose_raw(a: &Matrix) -> Matrix {
        let mut t = Matrix::zeros(a.cols, a.rows);
        for r in 0..a.rows {
            for c in 0..a.cols {
                t.data[c * t.cols + r] = a.data[r * a.cols + c];
            }
        }
        t
    }

    // Asserts that U * diag(s) * V^T matches `a` entry-wise within `tol`.
    fn assert_reconstruction(a: &Matrix, u: &Matrix, s: &[f64], v: &Matrix, tol: f64) {
        let k = s.len();
        let mut diag_s = Matrix::zeros(k, k);
        for (i, &si) in s.iter().enumerate() {
            diag_s.data[i * k + i] = si;
        }
        let us = mat_mul_raw(u, &diag_s);
        let recon = mat_mul_raw(&us, &transpose_raw(v));
        for r in 0..a.rows {
            for c in 0..a.cols {
                let diff = (recon.data[r * recon.cols + c] - a.data[r * a.cols + c]).abs();
                assert!(
                    diff < tol,
                    "reconstruction mismatch at ({r},{c}): got {} expected {}, diff {diff}",
                    recon.data[r * recon.cols + c],
                    a.data[r * a.cols + c]
                );
            }
        }
    }

    // Asserts that M^T * M is the identity within `tol`.
    fn assert_orthonormal_columns(m: &Matrix, tol: f64) {
        let mtm = mat_mul_raw(&transpose_raw(m), m);
        let k = mtm.rows;
        for i in 0..k {
            for j in 0..k {
                let expected = if i == j { 1.0 } else { 0.0 };
                let diff = (mtm.data[i * k + j] - expected).abs();
                assert!(
                    diff < tol,
                    "orthonormality violated at ({i},{j}): got {}, expected {expected}",
                    mtm.data[i * k + j]
                );
            }
        }
    }

    // Asserts non-negative (up to -1e-14) and descending (up to 1e-12) order.
    fn assert_singular_values_sorted(s: &[f64]) {
        for (i, &si) in s.iter().enumerate() {
            assert!(si >= -1e-14, "singular value s[{i}] = {si} is negative");
        }
        for i in 1..s.len() {
            assert!(
                s[i - 1] >= s[i] - 1e-12,
                "not descending: s[{}]={} < s[{}]={}",
                i - 1,
                s[i - 1],
                i,
                s[i]
            );
        }
    }

    // ---- Small matrices with known structure ----

    #[test]
    // 3x3 identity: every singular value is 1.
    fn test_identity_3x3() {
        let mut data = vec![0.0; 9];
        for i in 0..3 {
            data[i * 3 + i] = 1.0;
        }
        let a = Matrix {
            data,
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        for &si in &s {
            assert!((si - 1.0).abs() < 1e-10, "expected 1.0, got {si}");
        }
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_diagonal_matrix() {
        let a = Matrix {
            data: vec![5.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 1.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10);
        assert!((s[1] - 3.0).abs() < 1e-10);
        assert!((s[2] - 1.0).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    #[test]
    fn test_known_2x2() {
        // Symmetric [[3, 2], [2, 3]] has eigenvalues 5 and 1, hence singular values 5 and 1.
        let a = Matrix {
            data: vec![3.0, 2.0, 2.0, 3.0],
            rows: 2,
            cols: 2,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10, "expected s[0]=5, got {}", s[0]);
        assert!((s[1] - 1.0).abs() < 1e-10, "expected s[1]=1, got {}", s[1]);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    #[test]
    fn test_known_3x3() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_known_4x4() {
        // Tridiagonal [-1, 2, -1] matrix.
        let a = Matrix {
            data: vec![
                2.0, -1.0, 0.0, 0.0, -1.0, 2.0, -1.0, 0.0, 0.0, -1.0, 2.0, -1.0, 0.0, 0.0, -1.0,
                2.0,
            ],
            rows: 4,
            cols: 4,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    // ---- Rectangular and single-row/column shapes (thin-SVD dimensions) ----

    #[test]
    fn test_tall_rectangular() {
        let a = Matrix {
            data: vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            ],
            rows: 5,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(u.rows, 5);
        assert_eq!(u.cols, 3);
        assert_eq!(s.len(), 3);
        assert_eq!(v.rows, 3);
        assert_eq!(v.cols, 3);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    #[test]
    fn test_wide_rectangular() {
        // Wide input exercises the transpose path in `compute`.
        let a = Matrix {
            data: vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            ],
            rows: 3,
            cols: 5,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(u.rows, 3);
        assert_eq!(u.cols, 3);
        assert_eq!(s.len(), 3);
        assert_eq!(v.rows, 5);
        assert_eq!(v.cols, 3);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    #[test]
    fn test_single_element() {
        let a = Matrix {
            data: vec![7.0],
            rows: 1,
            cols: 1,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 7.0).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    #[test]
    fn test_single_element_negative() {
        let a = Matrix {
            data: vec![-5.0],
            rows: 1,
            cols: 1,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    #[test]
    fn test_single_row() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            rows: 1,
            cols: 4,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(s.len(), 1);
        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
        assert!((s[0] - expected).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    #[test]
    fn test_single_column() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            rows: 4,
            cols: 1,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(s.len(), 1);
        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
        assert!((s[0] - expected).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    // ---- Rank deficiency and zero/repeated singular values ----

    #[test]
    fn test_zero_matrix() {
        let a = Matrix::zeros(3, 3);
        let (_u, s, _v) = GolubKahanSvd::new().compute(&a).unwrap();
        for &si in &s {
            assert!(si.abs() < 1e-12);
        }
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_rank_deficient() {
        // Third row is the sum of the first two, so the rank is 2.
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0, 7.0, 9.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!(
            s[2] < 1e-10,
            "third singular value should be ~0, got {}",
            s[2]
        );
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_rank_one() {
        // Outer product [1, 2, 3]^T * [4, 5]: s[0] = ||[1,2,3]|| * ||[4,5]||, s[1] = 0.
        let a = Matrix {
            data: vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0],
            rows: 3,
            cols: 2,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        let norm_u = (1.0f64 + 4.0 + 9.0).sqrt();
        let norm_v = (16.0f64 + 25.0).sqrt();
        let expected_s0 = norm_u * norm_v;
        assert!(
            (s[0] - expected_s0).abs() < 1e-8,
            "expected s[0]={expected_s0}, got {}",
            s[0]
        );
        assert!(s[1] < 1e-10, "expected s[1]~0, got {}", s[1]);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    #[test]
    fn test_repeated_singular_values() {
        // diag(4, 4, 2): a repeated singular value.
        let a = Matrix {
            data: vec![4.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 2.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 4.0).abs() < 1e-10);
        assert!((s[1] - 4.0).abs() < 1e-10);
        assert!((s[2] - 2.0).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    #[test]
    fn test_diagonal_with_zeros() {
        // diag(5, 0, 3): a zero in the middle of the diagonal must sort to the end.
        let a = Matrix {
            data: vec![5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10);
        assert!((s[1] - 3.0).abs() < 1e-10);
        assert!(s[2] < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    // ---- Numerical range ----

    #[test]
    fn test_ill_conditioned() {
        // diag(1, 1e-12, 1e-6): condition number 1e12.
        let a = Matrix {
            data: vec![1.0, 0.0, 0.0, 0.0, 1e-12, 0.0, 0.0, 0.0, 1e-6],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 1.0).abs() < 1e-8);
        assert_reconstruction(&a, &u, &s, &v, 1e-6);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_extreme_small_values() {
        // Entries near the bottom of the normal f64 range.
        let a = Matrix {
            data: vec![1e-300, 0.0, 0.0, 2e-300],
            rows: 2,
            cols: 2,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!(s[0].is_finite());
        assert!(s[1].is_finite());
        assert_singular_values_sorted(&s);
        assert_reconstruction(&a, &u, &s, &v, 1e-290);
    }

    #[test]
    fn test_extreme_large_values() {
        // Entries whose squares (1e300, 4e300) are still finite.
        let a = Matrix {
            data: vec![1e+150, 0.0, 0.0, 2e+150],
            rows: 2,
            cols: 2,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!(s[0].is_finite());
        assert!(s[1].is_finite());
        for &val in &u.data {
            assert!(val.is_finite());
        }
        for &val in &v.data {
            assert!(val.is_finite());
        }
        assert_singular_values_sorted(&s);
    }

    // ---- Larger random matrices (seeded, so deterministic) ----

    #[test]
    fn test_convergence_64x64() {
        use rand::Rng;
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let data: Vec<f64> = (0..64 * 64).map(|_| rng.gen_range(-1.0..1.0)).collect();
        let a = Matrix {
            data,
            rows: 64,
            cols: 64,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-8);
        assert_orthonormal_columns(&u, 1e-8);
        assert_orthonormal_columns(&v, 1e-8);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_convergence_128x128() {
        use rand::Rng;
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
        let data: Vec<f64> = (0..128 * 128).map(|_| rng.gen_range(-1.0..1.0)).collect();
        let a = Matrix {
            data,
            rows: 128,
            cols: 128,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-8);
        assert_orthonormal_columns(&u, 1e-8);
        assert_orthonormal_columns(&v, 1e-8);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_almost_bidiagonal() {
        // Input that is already upper bidiagonal.
        let a = Matrix {
            data: vec![
                5.0, 2.0, 0.0, 0.0, 0.0, 4.0, 1.0, 0.0, 0.0, 0.0, 3.0, 0.5, 0.0, 0.0, 0.0, 1.0,
            ],
            rows: 4,
            cols: 4,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    // ---- Solver parameters and determinism ----

    #[test]
    fn test_custom_tolerance() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new()
            .with_tolerance(1e-15)
            .compute(&a)
            .unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-12);
    }

    #[test]
    fn test_low_max_iter_triggers_error() {
        // With factor 0 the iteration budget is 0, so any non-diagonal input must fail.
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            rows: 3,
            cols: 3,
        };
        let result = GolubKahanSvd::new().with_max_iter_factor(0).compute(&a);
        assert!(result.is_err(), "expected convergence error with factor=0");
        let err = result.unwrap_err();
        assert!(matches!(err, SvdError::Convergence { .. }));
    }

    #[test]
    fn test_determinism() {
        // Repeated runs on the same input must be bit-identical.
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            rows: 3,
            cols: 3,
        };
        let svd = GolubKahanSvd::new();
        let (u1, s1, v1) = svd.compute(&a).unwrap();
        let (u2, s2, v2) = svd.compute(&a).unwrap();
        assert_eq!(s1, s2, "singular values differ");
        assert_eq!(u1.data, u2.data, "U differs");
        assert_eq!(v1.data, v2.data, "V differs");
    }
}