1use std::fmt;
31
32use crate::error::PcError;
33use crate::matrix::Matrix;
34
/// Errors returned by [`GolubKahanSvd::compute`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SvdError {
    /// The implicit QR phase did not converge within its iteration budget.
    Convergence {
        /// Number of singular values being computed (`min(rows, cols)`).
        size: usize,
        /// The iteration budget that was exhausted.
        iterations: usize,
    },
    /// The input matrix was rejected before any factorisation work was done.
    InvalidInput {
        /// Human-readable description of why the input was rejected.
        reason: String,
    },
}

impl fmt::Display for SvdError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SvdError::Convergence { size, iterations } => {
                write!(
                    f,
                    "SVD failed to converge for matrix of size {size} \
                     after {iterations} iterations"
                )
            }
            SvdError::InvalidInput { reason } => {
                write!(f, "SVD invalid input: {reason}")
            }
        }
    }
}

/// Lets `SvdError` be used wherever a `std::error::Error` is expected (no source error).
impl std::error::Error for SvdError {}
82
83impl From<SvdError> for PcError {
84 fn from(e: SvdError) -> Self {
85 PcError::ConfigValidation(e.to_string())
86 }
87}
88
/// Thin SVD solver: Householder bidiagonalisation followed by implicit-shift
/// (Golub–Kahan) QR iterations on the resulting bidiagonal matrix.
#[derive(Debug, Clone, PartialEq)]
pub struct GolubKahanSvd {
    /// Convergence tolerance of the QR phase. A superdiagonal entry `e_i` of the
    /// bidiagonal matrix is treated as zero once
    /// `|e_i| <= max(tol * (|d_i| + |d_{i+1}|), tol * 1e-2)`.
    pub tol: f64,
    /// Iteration-budget factor: the QR phase may run at most
    /// `max_iter_factor * k * k` iterations, where `k = min(rows, cols)`.
    pub max_iter_factor: usize,
}
115
116impl GolubKahanSvd {
117 pub fn new() -> Self {
130 Self {
131 tol: 1e-14,
132 max_iter_factor: 30,
133 }
134 }
135
136 pub fn with_tolerance(mut self, tol: f64) -> Self {
151 self.tol = tol;
152 self
153 }
154
155 pub fn with_max_iter_factor(mut self, factor: usize) -> Self {
172 self.max_iter_factor = factor;
173 self
174 }
175
176 pub fn compute(&self, a: &Matrix) -> Result<(Matrix, Vec<f64>, Matrix), SvdError> {
205 for &val in &a.data {
207 if val.is_nan() || val.is_infinite() {
208 return Err(SvdError::InvalidInput {
209 reason: "matrix contains NaN or Inf".to_string(),
210 });
211 }
212 }
213
214 let m = a.rows;
215 let n = a.cols;
216
217 if m == 0 || n == 0 {
219 return Ok((Matrix::zeros(m, 0), Vec::new(), Matrix::zeros(n, 0)));
220 }
221
222 let transposed = m < n;
224 let (work_m, work_n, work_data) = if transposed {
225 let mut t = vec![0.0; m * n];
226 for r in 0..m {
227 for c in 0..n {
228 t[c * m + r] = a.data[r * n + c];
229 }
230 }
231 (n, m, t)
232 } else {
233 (m, n, a.data.clone())
234 };
235
236 let k = work_n;
238
239 if k == 1 {
241 let norm: f64 = work_data.iter().map(|&x| x * x).sum::<f64>().sqrt();
243 if norm < self.tol {
244 let u_mat = make_identity(work_m, 1);
245 let v_mat = make_identity(1, 1);
246 return Self::finalize(u_mat, vec![0.0], v_mat, transposed);
247 }
248 let sign = if work_data[0] >= 0.0 { 1.0 } else { -1.0 };
249 let mut u_data = vec![0.0; work_m];
250 for i in 0..work_m {
251 u_data[i] = work_data[i] * sign / norm;
252 }
253 let u_mat = Matrix {
254 data: u_data,
255 rows: work_m,
256 cols: 1,
257 };
258 let v_mat = Matrix {
259 data: vec![sign],
260 rows: 1,
261 cols: 1,
262 };
263 return Self::finalize(u_mat, vec![norm], v_mat, transposed);
264 }
265
266 let mut w = work_data;
269 let mut u_acc = make_identity(work_m, work_m);
270 let mut v_acc = make_identity(work_n, work_n);
271
272 let mut diag = vec![0.0; k];
273 let mut superdiag = vec![0.0; k.saturating_sub(1)];
274
275 householder_bidiag(
276 &mut w,
277 work_m,
278 work_n,
279 &mut u_acc,
280 &mut v_acc,
281 &mut diag,
282 &mut superdiag,
283 );
284
285 let max_iter = self.max_iter_factor * k * k;
287 implicit_qr_svd(
288 &mut diag,
289 &mut superdiag,
290 &mut u_acc,
291 &mut v_acc,
292 work_m,
293 work_n,
294 k,
295 self.tol,
296 max_iter,
297 )?;
298
299 for (i, d) in diag.iter_mut().enumerate().take(k) {
301 if *d < 0.0 {
302 *d = -*d;
303 for r in 0..work_m {
305 u_acc.data[r * work_m + i] = -u_acc.data[r * work_m + i];
306 }
307 }
308 }
309
310 let mut indices: Vec<usize> = (0..k).collect();
312 indices.sort_by(|&a_idx, &b_idx| diag[b_idx].total_cmp(&diag[a_idx]));
313
314 let sorted_s: Vec<f64> = indices.iter().map(|&i| diag[i]).collect();
315
316 let mut u_thin = Matrix::zeros(work_m, k);
318 for (new_col, &old_col) in indices.iter().enumerate() {
319 for r in 0..work_m {
320 u_thin.data[r * k + new_col] = u_acc.data[r * work_m + old_col];
321 }
322 }
323
324 let mut v_thin = Matrix::zeros(work_n, k);
326 for (new_col, &old_col) in indices.iter().enumerate() {
327 for r in 0..work_n {
328 v_thin.data[r * k + new_col] = v_acc.data[r * work_n + old_col];
329 }
330 }
331
332 Self::finalize(u_thin, sorted_s, v_thin, transposed)
333 }
334
335 fn finalize(
337 u: Matrix,
338 s: Vec<f64>,
339 v: Matrix,
340 transposed: bool,
341 ) -> Result<(Matrix, Vec<f64>, Matrix), SvdError> {
342 if transposed {
343 Ok((v, s, u))
344 } else {
345 Ok((u, s, v))
346 }
347 }
348}
349
350impl Default for GolubKahanSvd {
351 fn default() -> Self {
352 Self::new()
353 }
354}
355
356fn make_identity(rows: usize, cols: usize) -> Matrix {
360 let mut data = vec![0.0; rows * cols];
361 let k = rows.min(cols);
362 for i in 0..k {
363 data[i * cols + i] = 1.0;
364 }
365 Matrix { data, rows, cols }
366}
367
368fn householder_bidiag(
384 w: &mut [f64],
385 m: usize,
386 n: usize,
387 u_acc: &mut Matrix,
388 v_acc: &mut Matrix,
389 diag: &mut [f64],
390 superdiag: &mut [f64],
391) {
392 for j in 0..n {
393 {
395 let mut col = vec![0.0; m - j];
396 for i in j..m {
397 col[i - j] = w[i * n + j];
398 }
399 let (v_house, beta) = householder_vector(&col);
400 if beta != 0.0 {
401 apply_householder_left(w, m, n, j, j, &v_house, beta);
403 apply_householder_right_to_matrix(u_acc, m, m, j, &v_house, beta);
405 }
406 }
407 diag[j] = w[j * n + j];
408
409 if j + 2 <= n {
411 let start = j + 1;
412 let mut row = vec![0.0; n - start];
413 for c in start..n {
414 row[c - start] = w[j * n + c];
415 }
416 let (v_house, beta) = householder_vector(&row);
417 if beta != 0.0 {
418 apply_householder_right(w, m, n, j, start, &v_house, beta);
420 apply_householder_right_to_matrix(v_acc, n, n, start, &v_house, beta);
422 }
423 if j < n - 1 {
424 superdiag[j] = w[j * n + j + 1];
425 }
426 } else if j < n - 1 {
427 superdiag[j] = w[j * n + j + 1];
428 }
429 }
430}
431
/// Computes a Householder vector `v` (normalised so `v[0] == 1`) and coefficient `beta`
/// such that `(I - beta · v · vᵀ) · x = ‖x‖ · e₁`.
///
/// Returns `beta == 0.0` when no reflection is needed: for `x.len() <= 1`, for an
/// all-zero `x`, or when the tail `x[1..]` is negligible relative to `max |x_i|`.
/// In those cases the contents of `v` are not meaningful for a reflection.
///
/// The computation is carried out on `x / max |x_i|`. Both the normalised `v` and
/// `beta` are invariant under scaling of `x`. Scaling keeps the sum of squares from
/// overflowing for huge entries or underflowing for tiny ones.
fn householder_vector(x: &[f64]) -> (Vec<f64>, f64) {
    let len = x.len();
    if len == 0 {
        return (Vec::new(), 0.0);
    }
    if len == 1 {
        return (vec![1.0], 0.0);
    }

    let scale = x.iter().fold(0.0_f64, |acc, &xi| acc.max(xi.abs()));

    let mut v = vec![0.0; len];
    v[0] = 1.0;
    v[1..].copy_from_slice(&x[1..]);

    if scale == 0.0 {
        return (v, 0.0);
    }

    let x0 = x[0] / scale;
    let sigma: f64 = x[1..]
        .iter()
        .map(|&xi| {
            let t = xi / scale;
            t * t
        })
        .sum();

    // Tail negligible relative to the largest entry: no reflection needed.
    if sigma < 1e-300 {
        return (v, 0.0);
    }

    let norm_x = (x0 * x0 + sigma).sqrt();
    let v0 = if x0 <= 0.0 {
        x0 - norm_x
    } else {
        // Algebraically equal to x0 - norm_x, rewritten to avoid cancellation.
        -sigma / (x0 + norm_x)
    };
    let beta = 2.0 * v0 * v0 / (sigma + v0 * v0);

    // Normalise so v[0] == 1 (v[0] is already 1.0).
    for (vi, &xi) in v[1..].iter_mut().zip(&x[1..]) {
        *vi = (xi / scale) / v0;
    }

    (v, beta)
}
479
/// Applies `H = I - beta · v · vᵀ` from the left (`W ← H · W`) to the sub-block of the
/// row-major matrix `w` (row length `n`) covering rows `row_start..row_start + v.len()`
/// and columns `col_start..n`. `_m` is unused.
fn apply_householder_left(
    w: &mut [f64],
    _m: usize,
    n: usize,
    row_start: usize,
    col_start: usize,
    v: &[f64],
    beta: f64,
) {
    let width = n - col_start;
    if width == 0 {
        return;
    }
    // Offset of the first affected entry in the k-th affected row.
    let row_offset = |k: usize| (row_start + k) * n + col_start;

    // proj = vᵀ · W_block, one entry per affected column.
    let mut proj = vec![0.0; width];
    for (k, &vk) in v.iter().enumerate() {
        let start = row_offset(k);
        for (acc, &x) in proj.iter_mut().zip(&w[start..start + width]) {
            *acc += vk * x;
        }
    }

    // W_block -= beta · v · projᵀ.
    for (k, &vk) in v.iter().enumerate() {
        let start = row_offset(k);
        for (x, &pc) in w[start..start + width].iter_mut().zip(&proj) {
            *x -= beta * vk * pc;
        }
    }
}
509
/// Applies `H = I - beta · v · vᵀ` from the right (`W ← W · H`) to the sub-block of the
/// row-major matrix `w` (row length `n`) covering rows `row_start..m` and columns
/// `col_start..col_start + v.len()`.
fn apply_householder_right(
    w: &mut [f64],
    m: usize,
    n: usize,
    row_start: usize,
    col_start: usize,
    v: &[f64],
    beta: f64,
) {
    let num_rows = m - row_start;
    let width = v.len();
    if width == 0 {
        return;
    }
    for row in (row_start..).take(num_rows) {
        let base = row * n + col_start;
        let segment = &mut w[base..base + width];
        // dot = row_segment · v
        let dot = segment
            .iter()
            .zip(v)
            .fold(0.0, |acc, (&x, &vi)| acc + x * vi);
        for (x, &vi) in segment.iter_mut().zip(v) {
            *x -= beta * dot * vi;
        }
    }
}
535
536fn apply_householder_right_to_matrix(
541 acc: &mut Matrix,
542 rows: usize,
543 cols: usize,
544 col_start: usize,
545 v: &[f64],
546 beta: f64,
547) {
548 let v_len = v.len();
549 for r in 0..rows {
550 let mut dot = 0.0;
551 for (vi_idx, &vi) in v.iter().enumerate().take(v_len) {
552 dot += acc.data[r * cols + col_start + vi_idx] * vi;
553 }
554 for (vi_idx, &vi) in v.iter().enumerate().take(v_len) {
555 acc.data[r * cols + col_start + vi_idx] -= beta * dot * vi;
556 }
557 }
558}
559
/// Computes a Givens rotation `(c, s)` with `-s·a + c·b = 0`.
///
/// For `b != 0` the result is `c = a / r`, `s = b / r` with `r = hypot(a, b) > 0`, so
/// `c·a + s·b = r`. For `b == 0` the identity `(1, 0)` is returned. The larger of
/// `|a|` and `|b|` is used as the divisor, so the intermediate ratio lies in `[-1, 1]`.
/// Only exact zeros take the shortcut branches, so any non-zero `b`, however small,
/// is annihilated.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        return (1.0, 0.0);
    }
    if a == 0.0 {
        return (0.0, b.signum());
    }
    if b.abs() > a.abs() {
        let tau = a / b;
        let s = (1.0 + tau * tau).sqrt().recip() * b.signum();
        (s * tau, s)
    } else {
        let tau = b / a;
        let c = (1.0 + tau * tau).sqrt().recip() * a.signum();
        (c, c * tau)
    }
}
590
591fn apply_givens_cols(
594 mat: &mut Matrix,
595 rows: usize,
596 stride: usize,
597 i: usize,
598 j: usize,
599 c: f64,
600 s: f64,
601) {
602 for r in 0..rows {
603 let a = mat.data[r * stride + i];
604 let b = mat.data[r * stride + j];
605 mat.data[r * stride + i] = c * a + s * b;
606 mat.data[r * stride + j] = -s * a + c * b;
607 }
608}
609
/// Drives the upper-bidiagonal matrix `B` (main diagonal `diag`, superdiagonal
/// `superdiag`, order `k`) to diagonal form in place using implicit-shift QR sweeps.
///
/// Each outer iteration:
/// 1. zeroes negligible superdiagonal entries from the bottom up, and stops once all
///    `k - 1` of them are zero (B is then diagonal);
/// 2. locates the trailing unreduced block `[p, block_end)`;
/// 3. if that block has a (near-)zero diagonal entry, removes an adjoining superdiagonal
///    entry with a rotation chase (`zero_superdiag_row` / `zero_superdiag_col`);
/// 4. otherwise runs one `golub_kahan_step` with a Wilkinson shift taken from the
///    trailing 2×2 block of `BᵀB` restricted to the unreduced block.
///
/// A superdiagonal entry is negligible when
/// `|e_i| <= max(tol * (|d_i| + |d_{i+1}|), tol * 1e-2)`. A diagonal entry is treated
/// as zero when `|d_i| < tol * 1e-2`.
///
/// Left rotations are accumulated into `u_acc` and right rotations into `v_acc`.
/// `u_rows` / `v_rows` are passed on as both the row count and the row stride of those
/// matrices.
///
/// # Errors
///
/// Returns [`SvdError::Convergence`] once steps 3–4 have run more than `max_iter` times.
// All solver state is passed explicitly instead of being bundled into a struct.
#[allow(clippy::too_many_arguments)]
fn implicit_qr_svd(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    v_acc: &mut Matrix,
    u_rows: usize,
    v_rows: usize,
    k: usize,
    tol: f64,
    max_iter: usize,
) -> Result<(), SvdError> {
    if k <= 1 {
        return Ok(());
    }

    let mut iter_count = 0usize;

    loop {
        // Count (and clear) the trailing run of negligible superdiagonal entries.
        let mut q = 0usize;
        while q < k - 1 {
            let idx = k - 2 - q;
            let thresh = tol * (diag[idx].abs() + diag[idx + 1].abs());
            if superdiag[idx].abs() <= thresh.max(tol * 1e-2) {
                superdiag[idx] = 0.0;
                q += 1;
            } else {
                break;
            }
        }

        // Every superdiagonal entry is zero: B is diagonal.
        if q >= k - 1 {
            break;
        }

        // The unreduced block ends at block_end; walk upward to find its start p.
        let block_end = k - q;
        let mut p = block_end - 1;
        while p > 0 {
            let idx = p - 1;
            let thresh = tol * (diag[idx].abs() + diag[idx + 1].abs());
            if superdiag[idx].abs() <= thresh.max(tol * 1e-2) {
                superdiag[idx] = 0.0;
                break;
            }
            p -= 1;
        }

        // NOTE(review): this looks unreachable. superdiag[block_end - 2] already failed
        // the identical negligibility test in the q loop, so p <= block_end - 2. If it
        // were reached, the loop would repeat without advancing iter_count.
        let block_size = block_end - p;
        if block_size <= 1 {
            continue;
        }

        // A (near-)zero diagonal entry makes B singular. Instead of a QR step, chase out
        // the adjoining superdiagonal entry: along the row when one exists to its right,
        // otherwise along the column above it.
        let mut found_zero_diag = false;
        for i in p..block_end {
            if diag[i].abs() < tol * 1e-2 {
                if i < block_end - 1 && superdiag[i].abs() > 0.0 {
                    zero_superdiag_row(diag, superdiag, u_acc, u_rows, i, block_end);
                } else if i > p && superdiag[i - 1].abs() > 0.0 {
                    zero_superdiag_col(diag, superdiag, v_acc, v_rows, i, p);
                }
                found_zero_diag = true;
                break;
            }
        }
        if found_zero_diag {
            iter_count += 1;
            if iter_count > max_iter {
                return Err(SvdError::Convergence {
                    size: k,
                    iterations: max_iter,
                });
            }
            continue;
        }

        // Trailing 2x2 of BᵀB within the block: [[t11, t12], [t12, t22]]. The e_{n2-1}
        // term only contributes when it lies inside the block.
        let n1 = block_end - 1;
        let n2 = block_end - 2;
        let d_n1 = diag[n1];
        let d_n2 = diag[n2];
        let e_n2 = superdiag[n2];
        let e_n3_sq = if n2 > p {
            superdiag[n2 - 1] * superdiag[n2 - 1]
        } else {
            0.0
        };
        let t11 = d_n2 * d_n2 + e_n3_sq;
        let t12 = d_n2 * e_n2;
        let t22 = d_n1 * d_n1 + e_n2 * e_n2;

        let shift = wilkinson_shift(t11, t12, t22);

        golub_kahan_step(
            diag, superdiag, u_acc, v_acc, u_rows, v_rows, p, block_end, shift,
        );

        iter_count += 1;
        if iter_count > max_iter {
            return Err(SvdError::Convergence {
                size: k,
                iterations: max_iter,
            });
        }
    }

    Ok(())
}
738
/// Returns the eigenvalue of the symmetric 2×2 matrix `[[a, b], [b, d]]` that is closer
/// to `d` (the Wilkinson shift used by the implicit QR sweep).
///
/// The closed form is `d - b² / (δ + sign(δ)·√(δ² + b²))` with `δ = (a - d) / 2`. It is
/// evaluated with `hypot` and as `d - b · (b / denom)`. Because `|denom| >= |b|`, the
/// quotient lies in `[-1, 1]`, so no intermediate square can overflow or underflow.
fn wilkinson_shift(a: f64, b: f64, d: f64) -> f64 {
    if b == 0.0 {
        // Already diagonal: the eigenvalue nearest `d` is `d` itself.
        return d;
    }
    let delta = (a - d) * 0.5;
    let sign = if delta >= 0.0 { 1.0 } else { -1.0 };
    // |denom| = |delta| + hypot(delta, b) >= |b| > 0, so the division is safe.
    let denom = delta + sign * delta.hypot(b);
    d - b * (b / denom)
}
751
/// Performs one implicit-shift QR sweep (Golub–Kahan step) on the unreduced block
/// `[p, block_end)` of the upper-bidiagonal matrix with diagonal `diag` and
/// superdiagonal `superdiag`.
///
/// The first right rotation is determined by `(d_p² - shift, d_p · e_p)`, the leading
/// entries of the first column of `BᵀB - shift·I`. The bulge it creates is then chased
/// down the block by alternating right (column) and left (row) Givens rotations.
/// Right rotations are accumulated into `v_acc` and left rotations into `u_acc`.
/// `v_rows` / `u_rows` are used as their row counts and row strides.
// All solver state is passed explicitly instead of being bundled into a struct.
#[allow(clippy::too_many_arguments)]
fn golub_kahan_step(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    v_acc: &mut Matrix,
    u_rows: usize,
    v_rows: usize,
    p: usize,
    block_end: usize,
    shift: f64,
) {
    let mut y = diag[p] * diag[p] - shift;
    let mut z = diag[p] * superdiag[p];

    for i in p..block_end - 1 {
        // Right rotation on columns i, i+1. For i > p it zeroes the bulge z at
        // (i-1, i+1) and folds it into superdiag[i-1].
        let (c, s) = givens_rotation(y, z);
        if i > p {
            superdiag[i - 1] = c * superdiag[i - 1] + s * z;
        }
        let old_d_i = diag[i];
        let old_e_i = superdiag[i];
        diag[i] = c * old_d_i + s * old_e_i;
        superdiag[i] = -s * old_d_i + c * old_e_i;
        // Fill-in created below the diagonal at (i+1, i).
        let old_d_i1 = diag[i + 1];
        z = s * old_d_i1;
        diag[i + 1] = c * old_d_i1;

        apply_givens_cols(v_acc, v_rows, v_rows, i, i + 1, c, s);

        // Left rotation on rows i, i+1: zeroes the fill-in at (i+1, i).
        let (c, s) = givens_rotation(diag[i], z);
        diag[i] = c * diag[i] + s * z;
        let old_e_i = superdiag[i];
        let old_d_i1 = diag[i + 1];
        superdiag[i] = c * old_e_i + s * old_d_i1;
        diag[i + 1] = -s * old_e_i + c * old_d_i1;
        // New bulge at (i, i+2), unless this was the last row pair of the block.
        if i + 1 < block_end - 1 {
            let old_e_i1 = superdiag[i + 1];
            z = s * old_e_i1;
            superdiag[i + 1] = c * old_e_i1;
        }
        y = superdiag[i];

        apply_givens_cols(u_acc, u_rows, u_rows, i, i + 1, c, s);
    }
}
807
808fn zero_superdiag_row(
816 diag: &mut [f64],
817 superdiag: &mut [f64],
818 u_acc: &mut Matrix,
819 u_rows: usize,
820 zero_idx: usize,
821 block_end: usize,
822) {
823 let mut bulge = superdiag[zero_idx];
824 superdiag[zero_idx] = 0.0;
825
826 for j in zero_idx..block_end - 1 {
827 let (c, s) = givens_rotation(diag[j + 1], bulge);
828 diag[j + 1] = c * diag[j + 1] + s * bulge;
829 if j + 1 < block_end - 1 {
831 let old_e = superdiag[j + 1];
832 superdiag[j + 1] = c * old_e;
833 bulge = -s * old_e;
834 apply_givens_cols(u_acc, u_rows, u_rows, j + 1, zero_idx, c, s);
835 if bulge.abs() < 1e-300 {
836 break;
837 }
838 } else {
839 apply_givens_cols(u_acc, u_rows, u_rows, j + 1, zero_idx, c, s);
840 }
841 }
842}
843
844fn zero_superdiag_col(
852 diag: &mut [f64],
853 superdiag: &mut [f64],
854 v_acc: &mut Matrix,
855 v_rows: usize,
856 zero_idx: usize,
857 block_start: usize,
858) {
859 let mut bulge = superdiag[zero_idx - 1];
860 superdiag[zero_idx - 1] = 0.0;
861
862 for j in (block_start..zero_idx).rev() {
863 let (c, s) = givens_rotation(diag[j], bulge);
864 diag[j] = c * diag[j] + s * bulge;
865 if j > block_start {
867 let old_e = superdiag[j - 1];
868 superdiag[j - 1] = c * old_e;
869 bulge = -s * old_e;
870 apply_givens_cols(v_acc, v_rows, v_rows, j, zero_idx, c, s);
871 if bulge.abs() < 1e-300 {
872 break;
873 }
874 } else {
875 apply_givens_cols(v_acc, v_rows, v_rows, j, zero_idx, c, s);
876 }
877 }
878}
879
// Unit tests for the Golub–Kahan SVD: constructor/builder parameters, error handling,
// and numerical checks (reconstruction, orthonormality, ordering) on hand-picked and
// seeded random matrices.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_returns_default_parameters() {
        let svd = GolubKahanSvd::new();
        assert!((svd.tol - 1e-14).abs() < f64::EPSILON);
        assert_eq!(svd.max_iter_factor, 30);
    }

    #[test]
    fn test_with_tolerance_sets_custom_tol() {
        let svd = GolubKahanSvd::new().with_tolerance(1e-8);
        assert!((svd.tol - 1e-8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_with_max_iter_factor_sets_custom_factor() {
        let svd = GolubKahanSvd::new().with_max_iter_factor(50);
        assert_eq!(svd.max_iter_factor, 50);
    }

    #[test]
    fn test_default_trait_matches_new() {
        let a = GolubKahanSvd::new();
        let b = GolubKahanSvd::default();
        assert!((a.tol - b.tol).abs() < f64::EPSILON);
        assert_eq!(a.max_iter_factor, b.max_iter_factor);
    }

    #[test]
    fn test_svd_error_display() {
        let err = SvdError::Convergence {
            size: 10,
            iterations: 300,
        };
        let msg = format!("{err}");
        assert!(msg.contains("10"));
        assert!(msg.contains("300"));
    }

    #[test]
    fn test_svd_error_converts_to_pc_error() {
        let err = SvdError::Convergence {
            size: 5,
            iterations: 150,
        };
        let pc_err: crate::error::PcError = err.into();
        assert!(matches!(pc_err, crate::error::PcError::ConfigValidation(_)));
    }

    #[test]
    fn test_empty_matrix() {
        let a = Matrix::zeros(0, 0);
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(u.rows, 0);
        assert_eq!(u.cols, 0);
        assert!(s.is_empty());
        assert_eq!(v.rows, 0);
        assert_eq!(v.cols, 0);
    }

    #[test]
    fn test_nan_input_returns_error() {
        let a = Matrix {
            data: vec![1.0, f64::NAN, 3.0, 4.0],
            rows: 2,
            cols: 2,
        };
        let result = GolubKahanSvd::new().compute(&a);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SvdError::InvalidInput { .. }));
    }

    #[test]
    fn test_inf_input_returns_error() {
        let a = Matrix {
            data: vec![1.0, f64::INFINITY, 3.0, 4.0],
            rows: 2,
            cols: 2,
        };
        let result = GolubKahanSvd::new().compute(&a);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, SvdError::InvalidInput { .. }));
    }

    #[test]
    fn test_neg_inf_input_returns_error() {
        let a = Matrix {
            data: vec![f64::NEG_INFINITY, 2.0, 3.0, 4.0],
            rows: 2,
            cols: 2,
        };
        let result = GolubKahanSvd::new().compute(&a);
        assert!(result.is_err());
    }

    /// Naive row-major matrix product `a * b`; panics if the inner dimensions differ.
    fn mat_mul_raw(a: &Matrix, b: &Matrix) -> Matrix {
        assert_eq!(a.cols, b.rows);
        let mut c = Matrix::zeros(a.rows, b.cols);
        for i in 0..a.rows {
            for k in 0..a.cols {
                let aik = a.data[i * a.cols + k];
                for j in 0..b.cols {
                    c.data[i * c.cols + j] += aik * b.data[k * b.cols + j];
                }
            }
        }
        c
    }

    /// Returns the transpose of the row-major matrix `a`.
    fn transpose_raw(a: &Matrix) -> Matrix {
        let mut t = Matrix::zeros(a.cols, a.rows);
        for r in 0..a.rows {
            for c in 0..a.cols {
                t.data[c * t.cols + r] = a.data[r * a.cols + c];
            }
        }
        t
    }

    /// Asserts that `u · diag(s) · vᵀ` matches `a` entry-wise to within `tol`.
    fn assert_reconstruction(a: &Matrix, u: &Matrix, s: &[f64], v: &Matrix, tol: f64) {
        let k = s.len();
        let mut diag_s = Matrix::zeros(k, k);
        for (i, &si) in s.iter().enumerate() {
            diag_s.data[i * k + i] = si;
        }
        let us = mat_mul_raw(u, &diag_s);
        let recon = mat_mul_raw(&us, &transpose_raw(v));
        for r in 0..a.rows {
            for c in 0..a.cols {
                let diff = (recon.data[r * recon.cols + c] - a.data[r * a.cols + c]).abs();
                assert!(
                    diff < tol,
                    "reconstruction mismatch at ({r},{c}): got {} expected {}, diff {diff}",
                    recon.data[r * recon.cols + c],
                    a.data[r * a.cols + c]
                );
            }
        }
    }

    /// Asserts that `mᵀ · m` equals the identity to within `tol`.
    fn assert_orthonormal_columns(m: &Matrix, tol: f64) {
        let mtm = mat_mul_raw(&transpose_raw(m), m);
        let k = mtm.rows;
        for i in 0..k {
            for j in 0..k {
                let expected = if i == j { 1.0 } else { 0.0 };
                let diff = (mtm.data[i * k + j] - expected).abs();
                assert!(
                    diff < tol,
                    "orthonormality violated at ({i},{j}): got {}, expected {expected}",
                    mtm.data[i * k + j]
                );
            }
        }
    }

    /// Asserts that every value is (numerically) non-negative and the sequence is
    /// non-increasing, allowing small round-off slack.
    fn assert_singular_values_sorted(s: &[f64]) {
        for (i, &si) in s.iter().enumerate() {
            assert!(si >= -1e-14, "singular value s[{i}] = {si} is negative");
        }
        for i in 1..s.len() {
            assert!(
                s[i - 1] >= s[i] - 1e-12,
                "not descending: s[{}]={} < s[{}]={}",
                i - 1,
                s[i - 1],
                i,
                s[i]
            );
        }
    }

    #[test]
    fn test_identity_3x3() {
        let mut data = vec![0.0; 9];
        for i in 0..3 {
            data[i * 3 + i] = 1.0;
        }
        let a = Matrix {
            data,
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        for &si in &s {
            assert!((si - 1.0).abs() < 1e-10, "expected 1.0, got {si}");
        }
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_diagonal_matrix() {
        let a = Matrix {
            data: vec![5.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 1.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10);
        assert!((s[1] - 3.0).abs() < 1e-10);
        assert!((s[2] - 1.0).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    /// [[3, 2], [2, 3]] is symmetric positive definite with eigenvalues 5 and 1,
    /// which are therefore also its singular values.
    #[test]
    fn test_known_2x2() {
        let a = Matrix {
            data: vec![3.0, 2.0, 2.0, 3.0],
            rows: 2,
            cols: 2,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10, "expected s[0]=5, got {}", s[0]);
        assert!((s[1] - 1.0).abs() < 1e-10, "expected s[1]=1, got {}", s[1]);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    #[test]
    fn test_known_3x3() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    /// Tridiagonal matrix with 2 on the diagonal and -1 off the diagonal.
    #[test]
    fn test_known_4x4() {
        let a = Matrix {
            data: vec![
                2.0, -1.0, 0.0, 0.0, -1.0, 2.0, -1.0, 0.0, 0.0, -1.0, 2.0, -1.0, 0.0, 0.0, -1.0,
                2.0,
            ],
            rows: 4,
            cols: 4,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_tall_rectangular() {
        let a = Matrix {
            data: vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            ],
            rows: 5,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(u.rows, 5);
        assert_eq!(u.cols, 3);
        assert_eq!(s.len(), 3);
        assert_eq!(v.rows, 3);
        assert_eq!(v.cols, 3);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    /// Wide input exercises the transpose path: U is rows × k and V is cols × k.
    #[test]
    fn test_wide_rectangular() {
        let a = Matrix {
            data: vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            ],
            rows: 3,
            cols: 5,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(u.rows, 3);
        assert_eq!(u.cols, 3);
        assert_eq!(s.len(), 3);
        assert_eq!(v.rows, 5);
        assert_eq!(v.cols, 3);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    #[test]
    fn test_single_element() {
        let a = Matrix {
            data: vec![7.0],
            rows: 1,
            cols: 1,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 7.0).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    /// The singular value is |−5| = 5; the sign is carried by the singular vectors.
    #[test]
    fn test_single_element_negative() {
        let a = Matrix {
            data: vec![-5.0],
            rows: 1,
            cols: 1,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    #[test]
    fn test_single_row() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            rows: 1,
            cols: 4,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(s.len(), 1);
        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
        assert!((s[0] - expected).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    #[test]
    fn test_single_column() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            rows: 4,
            cols: 1,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_eq!(s.len(), 1);
        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
        assert!((s[0] - expected).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    #[test]
    fn test_zero_matrix() {
        let a = Matrix::zeros(3, 3);
        let (_u, s, _v) = GolubKahanSvd::new().compute(&a).unwrap();
        for &si in &s {
            assert!(si.abs() < 1e-12);
        }
        assert_singular_values_sorted(&s);
    }

    /// The third row is the sum of the first two, so the rank is 2.
    #[test]
    fn test_rank_deficient() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0, 7.0, 9.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!(
            s[2] < 1e-10,
            "third singular value should be ~0, got {}",
            s[2]
        );
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    /// A = [1, 2, 3]ᵀ · [4, 5], so s[0] = ‖[1, 2, 3]‖ · ‖[4, 5]‖ and s[1] = 0.
    #[test]
    fn test_rank_one() {
        let a = Matrix {
            data: vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0],
            rows: 3,
            cols: 2,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        let norm_u = (1.0f64 + 4.0 + 9.0).sqrt();
        let norm_v = (16.0f64 + 25.0).sqrt();
        let expected_s0 = norm_u * norm_v;
        assert!(
            (s[0] - expected_s0).abs() < 1e-8,
            "expected s[0]={expected_s0}, got {}",
            s[0]
        );
        assert!(s[1] < 1e-10, "expected s[1]~0, got {}", s[1]);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
    }

    #[test]
    fn test_repeated_singular_values() {
        let a = Matrix {
            data: vec![4.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 2.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 4.0).abs() < 1e-10);
        assert!((s[1] - 4.0).abs() < 1e-10);
        assert!((s[2] - 2.0).abs() < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_orthonormal_columns(&u, 1e-10);
        assert_orthonormal_columns(&v, 1e-10);
    }

    #[test]
    fn test_diagonal_with_zeros() {
        let a = Matrix {
            data: vec![5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10);
        assert!((s[1] - 3.0).abs() < 1e-10);
        assert!(s[2] < 1e-10);
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_ill_conditioned() {
        let a = Matrix {
            data: vec![1.0, 0.0, 0.0, 0.0, 1e-12, 0.0, 0.0, 0.0, 1e-6],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!((s[0] - 1.0).abs() < 1e-8);
        assert_reconstruction(&a, &u, &s, &v, 1e-6);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_extreme_small_values() {
        let a = Matrix {
            data: vec![1e-300, 0.0, 0.0, 2e-300],
            rows: 2,
            cols: 2,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!(s[0].is_finite());
        assert!(s[1].is_finite());
        assert_singular_values_sorted(&s);
        assert_reconstruction(&a, &u, &s, &v, 1e-290);
    }

    #[test]
    fn test_extreme_large_values() {
        let a = Matrix {
            data: vec![1e+150, 0.0, 0.0, 2e+150],
            rows: 2,
            cols: 2,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert!(s[0].is_finite());
        assert!(s[1].is_finite());
        for &val in &u.data {
            assert!(val.is_finite());
        }
        for &val in &v.data {
            assert!(val.is_finite());
        }
        assert_singular_values_sorted(&s);
    }

    /// Seeded (reproducible) random 64×64 matrix with entries in [-1, 1).
    #[test]
    fn test_convergence_64x64() {
        use rand::Rng;
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let data: Vec<f64> = (0..64 * 64).map(|_| rng.gen_range(-1.0..1.0)).collect();
        let a = Matrix {
            data,
            rows: 64,
            cols: 64,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-8);
        assert_orthonormal_columns(&u, 1e-8);
        assert_orthonormal_columns(&v, 1e-8);
        assert_singular_values_sorted(&s);
    }

    /// Seeded (reproducible) random 128×128 matrix with entries in [-1, 1).
    #[test]
    fn test_convergence_128x128() {
        use rand::Rng;
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
        let data: Vec<f64> = (0..128 * 128).map(|_| rng.gen_range(-1.0..1.0)).collect();
        let a = Matrix {
            data,
            rows: 128,
            cols: 128,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-8);
        assert_orthonormal_columns(&u, 1e-8);
        assert_orthonormal_columns(&v, 1e-8);
        assert_singular_values_sorted(&s);
    }

    /// Input that is already upper bidiagonal.
    #[test]
    fn test_almost_bidiagonal() {
        let a = Matrix {
            data: vec![
                5.0, 2.0, 0.0, 0.0, 0.0, 4.0, 1.0, 0.0, 0.0, 0.0, 3.0, 0.5, 0.0, 0.0, 0.0, 1.0,
            ],
            rows: 4,
            cols: 4,
        };
        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-10);
        assert_singular_values_sorted(&s);
    }

    #[test]
    fn test_custom_tolerance() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            rows: 3,
            cols: 3,
        };
        let (u, s, v) = GolubKahanSvd::new()
            .with_tolerance(1e-15)
            .compute(&a)
            .unwrap();
        assert_reconstruction(&a, &u, &s, &v, 1e-12);
    }

    /// A factor of 0 gives a zero iteration budget, so the first QR step already
    /// exceeds it.
    #[test]
    fn test_low_max_iter_triggers_error() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            rows: 3,
            cols: 3,
        };
        let result = GolubKahanSvd::new().with_max_iter_factor(0).compute(&a);
        assert!(result.is_err(), "expected convergence error with factor=0");
        let err = result.unwrap_err();
        assert!(matches!(err, SvdError::Convergence { .. }));
    }

    /// Two runs on the same input must produce bit-identical results.
    #[test]
    fn test_determinism() {
        let a = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            rows: 3,
            cols: 3,
        };
        let svd = GolubKahanSvd::new();
        let (u1, s1, v1) = svd.compute(&a).unwrap();
        let (u2, s2, v2) = svd.compute(&a).unwrap();
        assert_eq!(s1, s2, "singular values differ");
        assert_eq!(u1.data, u2.data, "U differs");
        assert_eq!(v1.data, v2.data, "V differs");
    }
}