1#![allow(dead_code)]
20
21use oxicuda_blas::GpuFloat;
22use oxicuda_memory::DeviceBuffer;
23
24use crate::error::{SolverError, SolverResult};
25use crate::handle::SolverHandle;
26
27fn to_f64<T: GpuFloat>(val: T) -> f64 {
32 if T::SIZE == 4 {
33 f32::from_bits(val.to_bits_u64() as u32) as f64
34 } else {
35 f64::from_bits(val.to_bits_u64())
36 }
37}
38
39fn from_f64<T: GpuFloat>(val: f64) -> T {
40 if T::SIZE == 4 {
41 T::from_bits_u64(u64::from((val as f32).to_bits()))
42 } else {
43 T::from_bits_u64(val.to_bits())
44 }
45}
46
/// Problem size at or below which the divide-and-conquer recursion falls back
/// to the direct bidiagonal QR solver (see `dc_bidiagonal_svd`).
const DEFAULT_CROSSOVER: usize = 25;

/// Maximum safeguarded Newton/bisection iterations for the secular equation.
const SECULAR_MAX_ITER: usize = 80;

/// Convergence tolerance for the secular-equation root finder.
const SECULAR_TOL: f64 = 1e-14;

/// Maximum deflation sweeps of the implicit-shift bidiagonal QR iteration.
const BIDIAG_QR_MAX_ITER: usize = 200;
/// Tuning parameters for the divide-and-conquer SVD driver (`dc_svd`).
#[derive(Debug, Clone)]
pub struct DcSvdConfig {
    /// Recursion base-case size: subproblems at or below this size are solved
    /// with bidiagonal QR iteration instead of recursing further.
    pub crossover_size: usize,
    /// Whether the left singular vectors `U` are computed.
    pub compute_u: bool,
    /// Whether the right singular vectors `V^T` are computed.
    pub compute_vt: bool,
    /// Selects the divide-and-conquer path over plain QR iteration.
    /// NOTE(review): not consulted by `dc_svd` in this file — confirm where
    /// this flag is honored.
    pub use_divide_conquer: bool,
    /// Requests an explicit bidiagonalization stage.
    /// NOTE(review): not consulted by `dc_svd` in this file — confirm where
    /// this flag is honored.
    pub bidiagonalization: bool,
    /// Magnitude below which singular values may be deflated (treated as zero).
    pub deflation_tol: f64,
    /// Problem-size threshold recorded by `for_gpu`; not read by `dc_svd`.
    pub n_threshold: usize,
}
81
82impl Default for DcSvdConfig {
83 fn default() -> Self {
84 Self {
85 crossover_size: DEFAULT_CROSSOVER,
86 compute_u: true,
87 compute_vt: true,
88 use_divide_conquer: false,
89 bidiagonalization: false,
90 deflation_tol: 0.0,
91 n_threshold: 1024,
92 }
93 }
94}
95
96impl DcSvdConfig {
97 #[must_use]
102 pub fn for_gpu(n: usize) -> Self {
103 Self {
104 crossover_size: DEFAULT_CROSSOVER,
105 compute_u: true,
106 compute_vt: true,
107 use_divide_conquer: n >= 1024,
108 bidiagonalization: n >= 256,
109 deflation_tol: n as f64 * 2.22e-16,
110 n_threshold: 1024,
111 }
112 }
113}
114
/// Compute the singular value decomposition `A = U * Sigma * V^T` of an
/// `m x n` matrix stored in `a`, writing the `k = min(m, n)` singular values
/// (non-negative, descending) to `sigma` and, when requested via `config`,
/// the singular-vector factors to `u` / `vt`.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when any buffer is smaller than
/// required, and propagates failures from workspace allocation and the
/// bidiagonal solver.
#[allow(clippy::too_many_arguments)]
pub fn dc_svd<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: usize,
    n: usize,
    sigma: &mut DeviceBuffer<T>,
    u: Option<&mut DeviceBuffer<T>>,
    vt: Option<&mut DeviceBuffer<T>>,
    config: &DcSvdConfig,
) -> SolverResult<()> {
    // An empty matrix has a trivially valid (empty) SVD.
    if m == 0 || n == 0 {
        return Ok(());
    }
    let k = m.min(n);
    // Validate every buffer before doing any work.
    if a.len() < m * n {
        return Err(SolverError::DimensionMismatch(format!(
            "dc_svd: matrix buffer too small ({} < {})",
            a.len(),
            m * n
        )));
    }
    if sigma.len() < k {
        return Err(SolverError::DimensionMismatch(format!(
            "dc_svd: sigma buffer too small ({} < {k})",
            sigma.len()
        )));
    }
    if let Some(ref u_buf) = u {
        if u_buf.len() < m * k {
            return Err(SolverError::DimensionMismatch(format!(
                "dc_svd: U buffer too small ({} < {})",
                u_buf.len(),
                m * k
            )));
        }
    }
    if let Some(ref vt_buf) = vt {
        if vt_buf.len() < k * n {
            return Err(SolverError::DimensionMismatch(format!(
                "dc_svd: V^T buffer too small ({} < {})",
                vt_buf.len(),
                k * n
            )));
        }
    }

    // Workspace sized for a k x k dense factor plus four length-k vectors.
    let ws_needed = (k * k + 4 * k) * std::mem::size_of::<f64>();
    handle.ensure_workspace(ws_needed)?;

    // Reduce to bidiagonal form: diagonal d (len k), superdiagonal e (len k-1).
    let mut d = vec![0.0_f64; k];
    let mut e = vec![0.0_f64; k.saturating_sub(1)];
    bidiagonalize_extract(a, m, n, &mut d, &mut e)?;

    // Host-side k x k accumulators for the bidiagonal problem's vectors.
    let mut u_dc = if config.compute_u {
        Some(vec![0.0_f64; k * k])
    } else {
        None
    };
    let mut vt_dc = if config.compute_vt {
        Some(vec![0.0_f64; k * k])
    } else {
        None
    };

    dc_bidiagonal_svd(
        &mut d,
        &mut e,
        u_dc.as_deref_mut(),
        vt_dc.as_deref_mut(),
        k,
        config.crossover_size,
    )?;

    // Canonical ordering: non-negative values, largest first, vectors permuted
    // alongside.
    sort_singular_values_desc(&mut d, u_dc.as_deref_mut(), vt_dc.as_deref_mut(), k);

    let sigma_host: Vec<T> = d.iter().map(|&val| from_f64(val.abs())).collect();
    write_to_device_buffer(sigma, &sigma_host, k)?;

    // NOTE(review): config.use_divide_conquer / bidiagonalization /
    // deflation_tol / n_threshold are never read in this function — confirm
    // whether they are meant to steer the path selection here.
    if let Some(u_buf) = u {
        if config.compute_u {
            // NOTE(review): when u_dc is Some, u_host holds k*k values but the
            // declared count below is m*k; for m > k these disagree — confirm
            // against write_to_device_buffer's contract once it is implemented.
            // (The else branch is unreachable: u_dc is Some iff compute_u.)
            let u_host: Vec<T> = if let Some(ref u_mat) = u_dc {
                u_mat.iter().map(|&v| from_f64(v)).collect()
            } else {
                vec![T::gpu_zero(); m * k]
            };
            write_to_device_buffer(u_buf, &u_host, m * k)?;
        }
    }
    if let Some(vt_buf) = vt {
        if config.compute_vt {
            // NOTE(review): same k*k vs k*n count mismatch as for U when n > k.
            let vt_host: Vec<T> = if let Some(ref vt_mat) = vt_dc {
                vt_mat.iter().map(|&v| from_f64(v)).collect()
            } else {
                vec![T::gpu_zero(); k * n]
            };
            write_to_device_buffer(vt_buf, &vt_host, k * n)?;
        }
    }

    Ok(())
}
253
254fn bidiagonalize_extract<T: GpuFloat>(
264 _a: &DeviceBuffer<T>,
265 _m: usize,
266 _n: usize,
267 d: &mut [f64],
268 e: &mut [f64],
269) -> SolverResult<()> {
270 for val in d.iter_mut() {
273 *val = 1.0;
274 }
275 for val in e.iter_mut() {
276 *val = 0.0;
277 }
278 Ok(())
279}
280
/// Recursive divide-and-conquer SVD of a real upper-bidiagonal matrix given by
/// diagonal `d` (length `n`) and superdiagonal `e` (length `n - 1`).
///
/// Problems of size `<= crossover` are solved directly by QR iteration; larger
/// ones are split at `n / 2`, solved recursively on each half, and recombined
/// via `merge_svd` using the saved coupling entry `e[mid - 1]`.
///
/// # Errors
///
/// Propagates convergence failures from the base-case QR solver.
fn dc_bidiagonal_svd(
    d: &mut [f64],
    e: &mut [f64],
    u: Option<&mut [f64]>,
    vt: Option<&mut [f64]>,
    n: usize,
    crossover: usize,
) -> SolverResult<()> {
    if n == 0 {
        return Ok(());
    }

    // Base case: small problems go straight to bidiagonal QR iteration.
    if n <= crossover {
        return bidiagonal_svd_qr(d, e, u, vt, n);
    }

    let mid = n / 2;
    // Save the superdiagonal entry that couples the two halves...
    let alpha = if mid > 0 && mid - 1 < e.len() {
        e[mid - 1]
    } else {
        0.0
    };

    // ...then zero it so the halves become independent bidiagonal problems.
    if mid > 0 && mid - 1 < e.len() {
        e[mid - 1] = 0.0;
    }

    // Left subproblem: rows 0..mid, superdiagonal e[0..mid-1].
    let e_left_len = mid.saturating_sub(1);
    let mut u_left = if u.is_some() {
        Some(vec![0.0_f64; mid * mid])
    } else {
        None
    };
    let mut vt_left = if vt.is_some() {
        Some(vec![0.0_f64; mid * mid])
    } else {
        None
    };

    dc_bidiagonal_svd(
        &mut d[..mid],
        &mut e[..e_left_len],
        u_left.as_deref_mut(),
        vt_left.as_deref_mut(),
        mid,
        crossover,
    )?;

    // Right subproblem: rows mid..n, superdiagonal e[mid..n-1].
    let right_size = n - mid;
    let e_right_start = mid;
    let e_right_len = right_size.saturating_sub(1);
    let mut u_right = if u.is_some() {
        Some(vec![0.0_f64; right_size * right_size])
    } else {
        None
    };
    let mut vt_right = if vt.is_some() {
        Some(vec![0.0_f64; right_size * right_size])
    } else {
        None
    };

    dc_bidiagonal_svd(
        &mut d[mid..n],
        &mut e[e_right_start..e_right_start + e_right_len],
        u_right.as_deref_mut(),
        vt_right.as_deref_mut(),
        right_size,
        crossover,
    )?;

    // Recombine the two child factorizations across the saved coupling alpha.
    merge_svd(
        d,
        alpha,
        mid,
        n,
        u,
        vt,
        u_left.as_deref(),
        vt_left.as_deref(),
        u_right.as_deref(),
        vt_right.as_deref(),
    )?;

    Ok(())
}
382
/// Combine the SVDs of two adjacent bidiagonal subproblems (split at `mid`,
/// coupled by the superdiagonal value `alpha`) into values/vectors for the
/// parent problem of size `n`.
///
/// NOTE(review): simplified merge. The secular equation updates the singular
/// values in `d`, but the corresponding rank-one rotation is NOT applied to
/// the singular vectors — `u`/`vt` are assembled purely block-diagonally from
/// the child factors. Confirm whether this approximation is intended.
#[allow(clippy::too_many_arguments)]
fn merge_svd(
    d: &mut [f64],
    alpha: f64,
    mid: usize,
    n: usize,
    u: Option<&mut [f64]>,
    vt: Option<&mut [f64]>,
    u_left: Option<&[f64]>,
    vt_left: Option<&[f64]>,
    u_right: Option<&[f64]>,
    vt_right: Option<&[f64]>,
) -> SolverResult<()> {
    // Negligible coupling: the halves are already independent, so the merged
    // factors are just the block-diagonal composition of the children.
    if alpha.abs() < 1e-300 {
        merge_orthogonal_blocks(u, u_left, u_right, mid, n);
        merge_orthogonal_blocks_transpose(vt, vt_left, vt_right, mid, n);
        return Ok(());
    }

    // Build the rank-one update vector z from the V^T rows touched by alpha:
    // the last row of the left factor and the first row of the right factor.
    let mut z = vec![0.0_f64; n];
    if let Some(vt_l) = vt_left {
        for j in 0..mid {
            let row = mid.saturating_sub(1);
            z[j] = vt_l[row * mid + j] * alpha;
        }
    } else {
        // No left factor: treated as identity, so only one entry is nonzero.
        if mid > 0 {
            z[mid - 1] = alpha;
        }
    }
    if let Some(vt_r) = vt_right {
        let right_size = n - mid;
        for j in 0..right_size {
            z[mid + j] = vt_r[j] * alpha; }
    } else {
        // No right factor: identity, single nonzero entry.
        if n > mid {
            z[mid] = alpha;
        }
    }

    // Snapshot the child singular values; each merged value is a root of the
    // secular equation built from them and z.
    let old_d: Vec<f64> = d[..n].to_vec();

    for (i, d_elem) in d.iter_mut().enumerate().take(n) {
        let sigma_new = solve_secular_equation(&old_d, &z, i, n)?;
        *d_elem = sigma_new;
    }

    // See NOTE above: vectors are combined block-diagonally only.
    merge_orthogonal_blocks(u, u_left, u_right, mid, n);
    merge_orthogonal_blocks_transpose(vt, vt_left, vt_right, mid, n);

    Ok(())
}
455
/// Find the `idx`-th root of the secular equation
/// `f(sigma) = 1 + sum_i z_i^2 / (d_i^2 - sigma^2) = 0`
/// by safeguarded Newton iteration: a Newton step is taken when it stays
/// inside the current bracket, otherwise bisection, and the bracket is
/// tightened from the sign of `f` each iteration.
///
/// The root is bracketed between consecutive entries of `d` sorted ascending;
/// the topmost root's upper bound is padded by `|z|_1 + 1`.
///
/// NOTE(review): the bracket takes `.abs()` of *value-sorted* entries, which
/// is only a consistent ordering when the `d_i` are non-negative — confirm
/// for inputs that may carry signs.
///
/// # Errors
///
/// Currently infallible; after `SECULAR_MAX_ITER` iterations the best
/// iterate is returned even without convergence.
fn solve_secular_equation(d: &[f64], z: &[f64], idx: usize, n: usize) -> SolverResult<f64> {
    if n == 0 {
        return Ok(0.0);
    }

    let mut sorted_d: Vec<f64> = d[..n].to_vec();
    sorted_d.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    // Bracket [lo, hi] around the idx-th root.
    let lo = if idx < sorted_d.len() {
        sorted_d[idx].abs()
    } else {
        0.0
    };
    let hi = if idx + 1 < sorted_d.len() {
        sorted_d[idx + 1].abs()
    } else {
        // Last interval: any root lies within ||z||_1 of the largest pole.
        lo + z.iter().map(|zi| zi.abs()).sum::<f64>() + 1.0
    };

    let mut sigma = (lo + hi) * 0.5;
    let mut lo_b = lo;
    let mut hi_b = hi;

    for _iter in 0..SECULAR_MAX_ITER {
        let (f_val, f_deriv) = secular_function(d, z, sigma, n);

        if f_val.abs() < SECULAR_TOL {
            return Ok(sigma);
        }

        // Newton step when usable and inside the bracket; bisection otherwise.
        if f_deriv.abs() > 1e-300 {
            let newton_step = sigma - f_val / f_deriv;
            if newton_step > lo_b && newton_step < hi_b {
                sigma = newton_step;
            } else {
                sigma = (lo_b + hi_b) * 0.5;
            }
        } else {
            sigma = (lo_b + hi_b) * 0.5;
        }

        // f is increasing in sigma on the bracket, so the sign of f tells
        // which side of the root sigma landed on.
        let (f_new, _) = secular_function(d, z, sigma, n);
        if f_new > 0.0 {
            hi_b = sigma;
        } else {
            lo_b = sigma;
        }

        if (hi_b - lo_b) < SECULAR_TOL * sigma.abs().max(1.0) {
            return Ok(sigma);
        }
    }

    Ok(sigma)
}
520
/// Evaluate the secular function `f(sigma) = 1 + sum_i z_i^2 / (d_i^2 - sigma^2)`
/// and its derivative with respect to `sigma` over the first `n` entries.
///
/// Terms whose denominator is smaller than ~1e-300 in magnitude (sigma at a
/// pole) are skipped to avoid division blow-up.
fn secular_function(d: &[f64], z: &[f64], sigma: f64, n: usize) -> (f64, f64) {
    let sigma2 = sigma * sigma;
    d[..n]
        .iter()
        .zip(&z[..n])
        .fold((1.0, 0.0), |(f_val, f_deriv), (&di, &zi)| {
            let denom = di * di - sigma2;
            if denom.abs() < 1e-300 {
                // Skip near-pole terms entirely.
                (f_val, f_deriv)
            } else {
                let zi2 = zi * zi;
                (
                    f_val + zi2 / denom,
                    f_deriv + 2.0 * sigma * zi2 / (denom * denom),
                )
            }
        })
}
543
/// Compute the singular values of a bidiagonal matrix (diagonal `d`,
/// superdiagonal `e`) by implicit-shift QR iteration, deflating converged
/// trailing off-diagonal entries each sweep.
///
/// `u` / `vt`, when present, are reset to the identity.
/// NOTE(review): the Givens rotations applied inside `bidiagonal_qr_step` are
/// never accumulated into `u`/`vt`, so the returned vectors remain identity —
/// confirm whether vector accumulation is pending work.
///
/// # Errors
///
/// Returns [`SolverError::ConvergenceFailure`] if the off-diagonal norm has
/// not dropped below tolerance after `BIDIAG_QR_MAX_ITER` sweeps.
fn bidiagonal_svd_qr(
    d: &mut [f64],
    e: &mut [f64],
    mut u: Option<&mut [f64]>,
    mut vt: Option<&mut [f64]>,
    n: usize,
) -> SolverResult<()> {
    if n == 0 {
        return Ok(());
    }

    // Initialize the requested vector accumulators to the identity.
    if let Some(ref mut u_mat) = u {
        for val in u_mat.iter_mut() {
            *val = 0.0;
        }
        for i in 0..n {
            u_mat[i * n + i] = 1.0;
        }
    }
    if let Some(ref mut vt_mat) = vt {
        for val in vt_mat.iter_mut() {
            *val = 0.0;
        }
        for i in 0..n {
            vt_mat[i * n + i] = 1.0;
        }
    }

    let tol = 1e-14;

    for _iter in 0..BIDIAG_QR_MAX_ITER {
        // Deflate from the bottom: zero out negligible trailing e entries.
        let mut q = n.saturating_sub(1);
        while q > 0 && e[q - 1].abs() <= tol * (d[q - 1].abs() + d[q].abs()) {
            e[q - 1] = 0.0;
            q -= 1;
        }
        // Everything deflated: all off-diagonals are negligible — done.
        if q == 0 {
            return Ok(()); }

        // Find the start p of the trailing unreduced block [p, q].
        let mut p = q - 1;
        while p > 0 && e[p - 1].abs() > tol * (d[p - 1].abs() + d[p].abs()) {
            p -= 1;
        }

        bidiagonal_qr_step(d, e, p, q);
    }

    // Iteration budget exhausted: fail only if the off-diagonal is still large.
    let off_norm: f64 = e.iter().map(|v| v * v).sum::<f64>().sqrt();
    if off_norm > tol {
        return Err(SolverError::ConvergenceFailure {
            iterations: BIDIAG_QR_MAX_ITER as u32,
            residual: off_norm,
        });
    }

    Ok(())
}
610
/// One implicit-shift QR sweep on the bidiagonal block spanning `start..=end`.
///
/// The shift `mu` is the Wilkinson shift: the eigenvalue of the trailing 2x2
/// block of `B^T B` closest to its last diagonal entry. A sequence of Givens
/// rotations then chases the resulting bulge down the band.
///
/// NOTE(review): the rotation bookkeeping here is a compact variant of the
/// Golub–Kahan step and the rotations are not returned to the caller for
/// accumulation into U/V — verify against a reference implementation before
/// trusting high-accuracy results.
fn bidiagonal_qr_step(d: &mut [f64], e: &mut [f64], start: usize, end: usize) {
    // Trailing 2x2 of B^T B: [[t11, t12], [t12, t22]].
    let dm1 = d[end - 1];
    let dm = d[end];
    let em1 = e[end - 1];

    let t11 = dm1 * dm1
        + if end >= 2 {
            e[end - 2] * e[end - 2]
        } else {
            0.0
        };
    let t12 = dm1 * em1;
    let t22 = dm * dm + em1 * em1;

    // Wilkinson shift: eigenvalue of the 2x2 closest to t22, computed in the
    // numerically stable form that avoids cancellation in the quadratic.
    let delta = (t11 - t22) * 0.5;
    let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
    let denom = delta + sign_delta * (delta * delta + t12 * t12).sqrt();
    let mu = if denom.abs() > 1e-300 {
        t22 - t12 * t12 / denom
    } else {
        t22
    };

    // Seed the bulge from the shifted first column.
    let mut y = d[start] * d[start] - mu;
    let mut z = d[start] * e[start];

    for k in start..end {
        // Right rotation: annihilate z against y.
        let (cs, sn) = givens_rotation(y, z);
        if k > start {
            e[k - 1] = cs * e[k - 1] + sn * z;
        }
        let tmp_d = cs * d[k] + sn * e[k];
        e[k] = -sn * d[k] + cs * e[k];
        d[k] = tmp_d;
        let tmp_z = sn * d[k + 1];
        d[k + 1] *= cs;

        y = d[k];
        z = tmp_z;

        // Left rotation: push the bulge one position down the band.
        let (cs2, sn2) = givens_rotation(y, z);
        d[k] = cs2 * d[k] + sn2 * tmp_z;
        let tmp_e = cs2 * e[k] + sn2 * d[k + 1];
        d[k + 1] = -sn2 * e[k] + cs2 * d[k + 1];
        e[k] = tmp_e;

        if k + 1 < end {
            y = e[k];
            z = sn2 * e[k + 1];
            e[k + 1] *= cs2;
        }
    }
}
666
/// Compute a Givens rotation `(cs, sn)` that maps the 2-vector `(a, b)` onto
/// the first axis: `cs*a + sn*b = r` and `-sn*a + cs*b = 0`.
///
/// Inputs with magnitude below ~1e-300 short-circuit to an exact axis-aligned
/// rotation, avoiding overflow and division by (near) zero.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        // b already negligible: identity rotation.
        (1.0, 0.0)
    } else if a.abs() < 1e-300 {
        // a negligible: quarter turn, signed so the result r is non-negative.
        (0.0, if b >= 0.0 { 1.0 } else { -1.0 })
    } else {
        let r = (a * a + b * b).sqrt();
        (a / r, b / r)
    }
}
678
/// Assemble a block-diagonal `n x n` matrix (column-major) from two smaller
/// factors: `u_left` (`mid x mid`) fills the top-left block and `u_right`
/// (`(n-mid) x (n-mid)`) the bottom-right. A missing factor is replaced by an
/// identity block; `u == None` means the output was not requested.
fn merge_orthogonal_blocks(
    u: Option<&mut [f64]>,
    u_left: Option<&[f64]>,
    u_right: Option<&[f64]>,
    mid: usize,
    n: usize,
) {
    let u_mat = match u {
        Some(mat) => mat,
        None => return,
    };
    let right_size = n - mid;

    // Clear the destination before scattering the two blocks.
    for slot in u_mat.iter_mut().take(n * n) {
        *slot = 0.0;
    }

    match u_left {
        Some(block) => {
            for c in 0..mid {
                for r in 0..mid {
                    u_mat[c * n + r] = block[c * mid + r];
                }
            }
        }
        None => {
            // No left factor: identity on the top-left block.
            for i in 0..mid {
                u_mat[i * n + i] = 1.0;
            }
        }
    }

    match u_right {
        Some(block) => {
            for c in 0..right_size {
                for r in 0..right_size {
                    u_mat[(mid + c) * n + (mid + r)] = block[c * right_size + r];
                }
            }
        }
        None => {
            // No right factor: identity on the bottom-right block.
            for i in 0..right_size {
                u_mat[(mid + i) * n + (mid + i)] = 1.0;
            }
        }
    }
}
726
/// Merge right-singular-vector blocks into the full `V^T` factor.
///
/// NOTE(review): currently a straight delegation to
/// `merge_orthogonal_blocks` — no transposition is performed despite the
/// name. This is only consistent if `vt_left`/`vt_right` are already stored
/// transposed; confirm against the call sites in `merge_svd`.
fn merge_orthogonal_blocks_transpose(
    vt: Option<&mut [f64]>,
    vt_left: Option<&[f64]>,
    vt_right: Option<&[f64]>,
    mid: usize,
    n: usize,
) {
    merge_orthogonal_blocks(vt, vt_left, vt_right, mid, n);
}
738
/// Order singular values by decreasing magnitude (selection sort) and make
/// them non-negative, permuting the matching columns of `u` and rows of `vt`
/// (both stored column-major, `n x n`) so the factorization is preserved.
/// Any negative sign is folded into the corresponding column of `u`.
#[allow(clippy::needless_range_loop)]
fn sort_singular_values_desc(
    d: &mut [f64],
    mut u: Option<&mut [f64]>,
    mut vt: Option<&mut [f64]>,
    n: usize,
) {
    for i in 0..n {
        // Locate the remaining entry with the largest magnitude (first wins
        // on ties, matching a strict-greater scan).
        let mut pivot = i;
        for j in (i + 1)..n {
            if d[j].abs() > d[pivot].abs() {
                pivot = j;
            }
        }
        if pivot != i {
            d.swap(i, pivot);
            // Keep the singular vectors paired with their values.
            if let Some(ref mut u_mat) = u {
                (0..n).for_each(|r| u_mat.swap(i * n + r, pivot * n + r));
            }
            if let Some(ref mut vt_mat) = vt {
                (0..n).for_each(|c| vt_mat.swap(i * n + c, pivot * n + c));
            }
        }
        // Make the value non-negative, flipping the left vector's sign.
        if d[i] < 0.0 {
            d[i] = -d[i];
            if let Some(ref mut u_mat) = u {
                for entry in u_mat[i * n..i * n + n].iter_mut() {
                    *entry = -*entry;
                }
            }
        }
    }
}
787
/// Upload `count` host values into a device buffer.
///
/// NOTE(review): placeholder — no transfer is performed and `count` is not
/// validated against either buffer length. Wire this to the real
/// host-to-device copy before relying on `dc_svd` output.
fn write_to_device_buffer<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[T],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
802
#[cfg(test)]
mod tests {
    //! Unit tests for the divide-and-conquer SVD building blocks: config
    //! construction, the secular function, Givens rotations, bidiagonal QR,
    //! value sorting, block merging, and bit-level float conversions.
    use super::*;

    #[test]
    fn dc_svd_config_default() {
        let cfg = DcSvdConfig::default();
        assert_eq!(cfg.crossover_size, DEFAULT_CROSSOVER);
        assert!(cfg.compute_u);
        assert!(cfg.compute_vt);
    }

    #[test]
    fn dc_svd_config_custom() {
        let cfg = DcSvdConfig {
            crossover_size: 10,
            compute_u: false,
            compute_vt: true,
            ..DcSvdConfig::default()
        };
        assert_eq!(cfg.crossover_size, 10);
        assert!(!cfg.compute_u);
        assert!(cfg.compute_vt);
    }

    #[test]
    fn secular_function_identity() {
        // With z = 0 every coupling term vanishes: f == 1, f' == 0.
        let d = [1.0, 2.0, 3.0];
        let z = [0.0, 0.0, 0.0];
        let (f_val, f_deriv) = secular_function(&d, &z, 0.5, 3);
        assert!((f_val - 1.0).abs() < 1e-10);
        assert!(f_deriv.abs() < 1e-10);
    }

    #[test]
    fn secular_function_with_coupling() {
        let d = [1.0, 3.0];
        let z = [0.5, 0.5];
        let (f_val, _f_deriv) = secular_function(&d, &z, 2.0, 2);
        // f(2) = 1 + z0^2/(d0^2 - 4) + z1^2/(d1^2 - 4).
        let expected = 1.0 + 0.25 / (1.0 - 4.0) + 0.25 / (9.0 - 4.0);
        assert!((f_val - expected).abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        // (3, 4) should rotate onto (5, 0).
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
        let zero_val = -sn * 3.0 + cs * 4.0;
        assert!(zero_val.abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    #[test]
    fn givens_rotation_zero_a() {
        let (cs, sn) = givens_rotation(0.0, 3.0);
        assert!(cs.abs() < 1e-15);
        assert!((sn - 1.0).abs() < 1e-15);
    }

    #[test]
    fn bidiagonal_qr_trivial() {
        // Already diagonal: converges immediately.
        let mut d = vec![3.0, 2.0, 1.0];
        let mut e = vec![0.0, 0.0];
        let result = bidiagonal_svd_qr(&mut d, &mut e, None, None, 3);
        assert!(result.is_ok());
    }

    #[test]
    fn bidiagonal_qr_with_superdiag() {
        let mut d = vec![4.0, 3.0];
        let mut e = vec![1.0];
        let mut u = vec![0.0; 4];
        let mut vt = vec![0.0; 4];
        let result = bidiagonal_svd_qr(&mut d, &mut e, Some(&mut u), Some(&mut vt), 2);
        assert!(result.is_ok());
    }

    #[test]
    fn bidiagonal_qr_empty() {
        let mut d: Vec<f64> = Vec::new();
        let mut e: Vec<f64> = Vec::new();
        let result = bidiagonal_svd_qr(&mut d, &mut e, None, None, 0);
        assert!(result.is_ok());
    }

    #[test]
    fn sort_singular_values_descending() {
        let mut d = vec![1.0, 3.0, 2.0];
        sort_singular_values_desc(&mut d, None, None, 3);
        assert!((d[0] - 3.0).abs() < 1e-15);
        assert!((d[1] - 2.0).abs() < 1e-15);
        assert!((d[2] - 1.0).abs() < 1e-15);
    }

    #[test]
    fn sort_singular_values_with_negatives() {
        // Sorting is by magnitude and results are made non-negative.
        let mut d = vec![-2.0, 1.0, -3.0];
        sort_singular_values_desc(&mut d, None, None, 3);
        assert!((d[0] - 3.0).abs() < 1e-15);
        assert!((d[1] - 2.0).abs() < 1e-15);
        assert!((d[2] - 1.0).abs() < 1e-15);
    }

    #[test]
    fn dc_bidiagonal_base_case() {
        // n <= crossover routes straight to the QR base case.
        let mut d = vec![5.0, 3.0, 1.0];
        let mut e = vec![0.0, 0.0];
        let result = dc_bidiagonal_svd(&mut d, &mut e, None, None, 3, 25);
        assert!(result.is_ok());
    }

    #[test]
    fn merge_orthogonal_blocks_identity() {
        // Two 2x2 identity factors merge into a 4x4 identity.
        let mut u = vec![0.0_f64; 16];
        let u_left = vec![1.0, 0.0, 0.0, 1.0];
        let u_right = vec![1.0, 0.0, 0.0, 1.0];
        merge_orthogonal_blocks(Some(&mut u), Some(&u_left), Some(&u_right), 2, 4);
        assert!((u[0] - 1.0).abs() < 1e-15);
        assert!((u[5] - 1.0).abs() < 1e-15);
        assert!((u[10] - 1.0).abs() < 1e-15);
        assert!((u[15] - 1.0).abs() < 1e-15);
    }

    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::PI;
        let converted: f64 = from_f64(to_f64(val));
        assert!((converted - val).abs() < 1e-15);
    }

    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::PI;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }

    #[test]
    fn dc_svd_config_threshold_1024() {
        let cfg_large = DcSvdConfig::for_gpu(1024);
        assert!(
            cfg_large.use_divide_conquer,
            "D&C should be enabled for n=1024"
        );
        assert_eq!(cfg_large.n_threshold, 1024);

        let cfg_small = DcSvdConfig::for_gpu(512);
        assert!(
            !cfg_small.use_divide_conquer,
            "D&C should be disabled for n=512"
        );
    }

    #[test]
    fn dc_svd_uses_bidiagonalization() {
        let cfg_large = DcSvdConfig::for_gpu(256);
        assert!(
            cfg_large.bidiagonalization,
            "bidiagonalization should be enabled for n=256"
        );

        let cfg_small = DcSvdConfig::for_gpu(128);
        assert!(
            !cfg_small.bidiagonalization,
            "bidiagonalization should be disabled for n=128"
        );

        let cfg_very_large = DcSvdConfig::for_gpu(4096);
        assert!(cfg_very_large.bidiagonalization);
        assert!(cfg_very_large.use_divide_conquer);
    }

    #[test]
    fn bidiagonalization_cpu_2x2() {
        let mut d = vec![3.0_f64, 4.0];
        let mut e = vec![1.0_f64];
        let result = bidiagonal_svd_qr(&mut d, &mut e, None, None, 2);
        assert!(result.is_ok(), "bidiagonal QR for 2x2 must succeed");
        assert!(
            e[0].abs() < 1e-10,
            "off-diagonal e[0] = {} should be ~0",
            e[0]
        );
        assert!(
            d[0] >= 0.0 && d[1] >= 0.0,
            "singular values must be non-negative"
        );
    }

    #[test]
    fn dc_svd_deflation_threshold_small() {
        // for_gpu scales the deflation tolerance as n * machine-epsilon.
        let eps = 2.22e-16_f64;
        let n_vals: &[usize] = &[10, 100, 1000, 4096];
        for &n in n_vals {
            let cfg = DcSvdConfig::for_gpu(n);
            let expected_tol = n as f64 * eps;
            assert!(
                (cfg.deflation_tol - expected_tol).abs() < 1e-30,
                "deflation_tol for n={n}: got {}, expected {}",
                cfg.deflation_tol,
                expected_tol
            );
        }
    }
}