use crate::error::SvdLibError;
use ndarray::{Array, Array1, Array2};
use num_traits::{Float, FromPrimitive, One, Zero};
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator};
use std::fmt::Debug;
use std::iter::Sum;
use std::mem;
use std::ops::{AddAssign, MulAssign, Neg, SubAssign};

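/// Sparse-matrix abstraction used by the Lanczos routines.
///
/// Implementors only need to expose their dimensions, the number of
/// non-zero entries, and a matrix-vector product `svd_opa` that computes
/// `y = A * x` (or `y = A^T * x` when `transposed` is true).
///
/// A minimal sketch for a dense row-major matrix (illustrative only; the
/// `DenseMatrix` type below is hypothetical, not part of this crate):
///
/// ```ignore
/// struct DenseMatrix { rows: Vec<Vec<f64>> }
///
/// impl SMat<f64> for DenseMatrix {
///     fn nrows(&self) -> usize { self.rows.len() }
///     fn ncols(&self) -> usize { self.rows.first().map_or(0, |r| r.len()) }
///     fn nnz(&self) -> usize { self.nrows() * self.ncols() }
///
///     fn svd_opa(&self, x: &[f64], y: &mut [f64], transposed: bool) {
///         y.iter_mut().for_each(|v| *v = 0.0);
///         for (i, row) in self.rows.iter().enumerate() {
///             for (j, &a) in row.iter().enumerate() {
///                 if transposed { y[j] += a * x[i]; } else { y[i] += a * x[j]; }
///             }
///         }
///     }
/// }
/// ```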
pub trait SMat<T: Float>: Sync {
    fn nrows(&self) -> usize;
    fn ncols(&self) -> usize;
    fn nnz(&self) -> usize;
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool);
}

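/// Result of a truncated singular value decomposition.
///
/// Row `i` of `ut` is the i-th left singular vector, `s[i]` the matching
/// singular value, and row `i` of `vt` the i-th right singular vector, so
/// the input is approximated by `ut^T * diag(s) * vt`.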
#[derive(Debug, Clone, PartialEq)]
pub struct SvdRec<T: Float> {
    pub d: usize,
    pub ut: Array2<T>,
    pub s: Array1<T>,
    pub vt: Array2<T>,
    pub diagnostics: Diagnostics<T>,
}

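/// Diagnostic information captured while the decomposition ran, including
/// the effective Lanczos parameters and the random seed actually used.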
#[derive(Debug, Clone, PartialEq)]
pub struct Diagnostics<T: Float> {
    pub non_zero: usize,
    pub dimensions: usize,
    pub iterations: usize,
    pub transposed: bool,
    pub lanczos_steps: usize,
    pub ritz_values_stabilized: usize,
    pub significant_values: usize,
    pub singular_values: usize,
    pub end_interval: [T; 2],
    pub kappa: T,
    pub random_seed: u32,
}

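/// Floating-point scalar usable by the solver: `f32` and `f64` both provide
/// the machine-epsilon helpers and the approximate-equality test used by the
/// convergence checks.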
pub trait SvdFloat:
    Float
    + FromPrimitive
    + Debug
    + Send
    + Sync
    + Zero
    + One
    + AddAssign
    + SubAssign
    + MulAssign
    + Neg<Output = Self>
    + Sum
{
    fn eps() -> Self;
    fn eps34() -> Self;
    fn compare(a: Self, b: Self) -> bool;
}

impl SvdFloat for f32 {
    fn eps() -> Self {
        f32::EPSILON
    }

    fn eps34() -> Self {
        f32::EPSILON.powf(0.75)
    }

    fn compare(a: Self, b: Self) -> bool {
        (b - a).abs() < f32::EPSILON
    }
}

impl SvdFloat for f64 {
    fn eps() -> Self {
        f64::EPSILON
    }

    fn eps34() -> Self {
        f64::EPSILON.powf(0.75)
    }

    fn compare(a: Self, b: Self) -> bool {
        (b - a).abs() < f64::EPSILON
    }
}

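/// Computes the SVD with default settings: all reachable dimensions, the
/// default end-interval `[-1e-30, 1e-30]`, a kappa of `1e-6`, and a random
/// starting seed.
///
/// A hedged usage sketch (assumes a sparse matrix built elsewhere; marked
/// `ignore` since it depends on crate wiring):
///
/// ```ignore
/// use nalgebra_sparse::csr::CsrMatrix;
/// let a: CsrMatrix<f64> = /* build your sparse matrix */;
/// let rec = svd(&a)?;
/// println!("retained {} singular values", rec.d);
/// ```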
pub fn svd<T, M>(a: &M) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();
    svd_las2(a, 0, 0, &[eps_small, eps_large], kappa, 0)
}

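/// Computes the SVD truncated to `dimensions` singular triplets, with the
/// same default interval, kappa, and random seed as [`svd`].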
pub fn svd_dim<T, M>(a: &M, dimensions: usize) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();

    svd_las2(a, dimensions, 0, &[eps_small, eps_large], kappa, 0)
}

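/// Like [`svd_dim`], but with an explicit random seed so runs are
/// reproducible. A seed of `0` asks for a fresh random seed.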
pub fn svd_dim_seed<T, M>(
    a: &M,
    dimensions: usize,
    random_seed: u32,
) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let eps_small = T::from_f64(-1.0e-30).unwrap();
    let eps_large = T::from_f64(1.0e-30).unwrap();
    let kappa = T::from_f64(1.0e-6).unwrap();

    svd_las2(
        a,
        dimensions,
        0,
        &[eps_small, eps_large],
        kappa,
        random_seed,
    )
}

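/// Full-control entry point for the Lanczos SVD (las2).
///
/// * `dimensions` - number of singular triplets requested; `0` (or a value
///   larger than `min(nrows, ncols)`) means "as many as possible"
/// * `iterations` - cap on Lanczos steps; `0` or an oversized value falls
///   back to `min(nrows, ncols)`, and the cap is never below `dimensions`
/// * `end_interval` - interval `[endl, endr]` of eigenvalues treated as numerically zero
/// * `kappa` - relative accuracy of accepted Ritz values (floored at `eps^(3/4)`)
/// * `random_seed` - seed for the starting vector; `0` draws one from `thread_rng`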
pub fn svd_las2<T, M>(
    a: &M,
    dimensions: usize,
    iterations: usize,
    end_interval: &[T; 2],
    kappa: T,
    random_seed: u32,
) -> Result<SvdRec<T>, SvdLibError>
where
    T: SvdFloat,
    M: SMat<T>,
{
    let random_seed = match random_seed > 0 {
        true => random_seed,
        false => thread_rng().gen(),
    };

    let min_nrows_ncols = a.nrows().min(a.ncols());

    let dimensions = match dimensions {
        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
        _ => dimensions,
    };

    let iterations = match iterations {
        n if n == 0 || n > min_nrows_ncols => min_nrows_ncols,
        n if n < dimensions => dimensions,
        _ => iterations,
    };

    if dimensions < 2 {
        return Err(SvdLibError::Las2Error(format!(
            "svd_las2: insufficient dimensions: {dimensions}"
        )));
    }

    assert!(dimensions > 1 && dimensions <= min_nrows_ncols);
    assert!(iterations >= dimensions && iterations <= min_nrows_ncols);

    // Work on the transpose when the matrix is much wider than tall, so the
    // Lanczos recurrence runs on the smaller dimension.
    let transposed = (a.ncols() as f64) >= ((a.nrows() as f64) * 1.2);
    let nrows = if transposed { a.ncols() } else { a.nrows() };
    let ncols = if transposed { a.nrows() } else { a.ncols() };

    let mut wrk = WorkSpace::new(nrows, ncols, transposed, iterations)?;
    let mut store = Store::new(ncols)?;

    let mut neig = 0;
    let steps = lanso(
        a,
        dimensions,
        iterations,
        end_interval,
        &mut wrk,
        &mut neig,
        &mut store,
        random_seed,
    )?;

    let kappa = kappa.abs().max(T::eps34());
    let mut r = ritvec(a, dimensions, kappa, &mut wrk, steps, neig, &mut store)?;

    if transposed {
        mem::swap(&mut r.Ut, &mut r.Vt);
    }

    Ok(SvdRec {
        d: r.d,
        ut: Array2::from_shape_vec((r.d, r.Ut.cols), r.Ut.value)?,
        s: Array::from_shape_vec(r.d, r.S)?,
        vt: Array2::from_shape_vec((r.d, r.Vt.cols), r.Vt.value)?,
        diagnostics: Diagnostics {
            non_zero: a.nnz(),
            dimensions,
            iterations,
            transposed,
            lanczos_steps: steps + 1,
            ritz_values_stabilized: neig,
            significant_values: r.d,
            singular_values: r.nsig,
            end_interval: *end_interval,
            kappa,
            random_seed,
        },
    })
}

const MAXLL: usize = 2;

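/// Backing store for the Lanczos vectors: `storq`/`retrq` hold the
/// orthonormal q-vectors (offset by `MAXLL`), while `storp`/`retrp` hold the
/// first few p-vectors needed for reorthogonalization.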
#[derive(Debug, Clone, PartialEq)]
struct Store<T: Float> {
    n: usize,
    vecs: Vec<Vec<T>>,
}

impl<T: Float + Zero + Clone> Store<T> {
    fn new(n: usize) -> Result<Self, SvdLibError> {
        Ok(Self { n, vecs: vec![] })
    }

    fn storq(&mut self, idx: usize, v: &[T]) {
        while idx + MAXLL >= self.vecs.len() {
            self.vecs.push(vec![T::zero(); self.n]);
        }
        self.vecs[idx + MAXLL].copy_from_slice(v);
    }

    fn storp(&mut self, idx: usize, v: &[T]) {
        while idx >= self.vecs.len() {
            self.vecs.push(vec![T::zero(); self.n]);
        }
        self.vecs[idx].copy_from_slice(v);
    }

    fn retrq(&mut self, idx: usize) -> &[T] {
        &self.vecs[idx + MAXLL]
    }

    fn retrp(&mut self, idx: usize) -> &[T] {
        &self.vecs[idx]
    }
}

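/// Scratch buffers for the Lanczos recurrence: `w0`..`w5` are the working
/// vectors, `alf`/`bet` the tridiagonal diagonal and off-diagonal, `ritz`
/// and `bnd` the Ritz values with their error bounds, and `eta`/`oldeta`
/// the orthogonality estimates.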
#[derive(Debug, Clone, PartialEq)]
struct WorkSpace<T: Float> {
    nrows: usize,
    ncols: usize,
    transposed: bool,
    w0: Vec<T>,
    w1: Vec<T>,
    w2: Vec<T>,
    w3: Vec<T>,
    w4: Vec<T>,
    w5: Vec<T>,
    alf: Vec<T>,
    eta: Vec<T>,
    oldeta: Vec<T>,
    bet: Vec<T>,
    bnd: Vec<T>,
    ritz: Vec<T>,
    temp: Vec<T>,
}

impl<T: Float + Zero + FromPrimitive> WorkSpace<T> {
    fn new(
        nrows: usize,
        ncols: usize,
        transposed: bool,
        iterations: usize,
    ) -> Result<Self, SvdLibError> {
        Ok(Self {
            nrows,
            ncols,
            transposed,
            w0: vec![T::zero(); ncols],
            w1: vec![T::zero(); ncols],
            w2: vec![T::zero(); ncols],
            w3: vec![T::zero(); ncols],
            w4: vec![T::zero(); ncols],
            w5: vec![T::zero(); ncols],
            alf: vec![T::zero(); iterations],
            eta: vec![T::zero(); iterations],
            oldeta: vec![T::zero(); iterations],
            bet: vec![T::zero(); 1 + iterations],
            ritz: vec![T::zero(); 1 + iterations],
            bnd: vec![T::from_f64(f64::MAX).unwrap(); 1 + iterations],
            temp: vec![T::zero(); nrows],
        })
    }
}

#[derive(Debug, Clone, PartialEq)]
struct DMat<T: Float> {
    cols: usize,
    value: Vec<T>,
}

#[allow(non_snake_case)]
#[derive(Debug, Clone, PartialEq)]
struct SVDRawRec<T: Float> {
    d: usize,
    nsig: usize,
    Ut: DMat<T>,
    S: Vec<T>,
    Vt: DMat<T>,
}

fn compare<T: SvdFloat>(computed: T, expected: T) -> bool {
    T::compare(computed, expected)
}

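/// Insertion sort that keeps `array2` in lock-step with the ascending sort
/// of the first `n` entries of `array1` (used for Ritz values and their
/// error bounds).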
fn insert_sort<T: PartialOrd>(n: usize, array1: &mut [T], array2: &mut [T]) {
    for i in 1..n {
        for j in (1..=i).rev() {
            if array1[j - 1] <= array1[j] {
                break;
            }
            array1.swap(j - 1, j);
            array2.swap(j - 1, j);
        }
    }
}

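/// Applies the normal-matrix operator: `y = A^T * A * x` (with the roles of
/// `A` and `A^T` swapped when `transposed` is set), using `temp` as the
/// intermediate `A * x` buffer.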
#[allow(non_snake_case)]
#[rustfmt::skip]
fn svd_opb<T: Float>(A: &dyn SMat<T>, x: &[T], y: &mut [T], temp: &mut [T], transposed: bool) {
    let nrows = if transposed { A.ncols() } else { A.nrows() };
    let ncols = if transposed { A.nrows() } else { A.ncols() };
    assert_eq!(x.len(), ncols, "svd_opb: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
    assert_eq!(y.len(), ncols, "svd_opb: y must be A.ncols() in length, y = {}, A.ncols = {}", y.len(), ncols);
    assert_eq!(temp.len(), nrows, "svd_opb: temp must be A.nrows() in length, temp = {}, A.nrows = {}", temp.len(), nrows);
    A.svd_opa(x, temp, transposed);
    A.svd_opa(temp, y, !transposed);
}

fn svd_daxpy<T: Float + AddAssign + Send + Sync>(da: T, x: &[T], y: &mut [T]) {
    if x.len() < 1000 {
        for (xval, yval) in x.iter().zip(y.iter_mut()) {
            *yval += da * *xval
        }
    } else {
        y.par_iter_mut()
            .zip(x.par_iter())
            .for_each(|(yval, xval)| *yval += da * *xval);
    }
}

fn svd_idamax<T: Float>(n: usize, x: &[T]) -> usize {
    assert!(n > 0, "svd_idamax: unexpected inputs!");

    match n {
        1 => 0,
        _ => {
            let mut imax = 0;
            for (i, xval) in x.iter().enumerate().take(n).skip(1) {
                if xval.abs() > x[imax].abs() {
                    imax = i;
                }
            }
            imax
        }
    }
}

fn svd_fsign<T: Float>(a: T, b: T) -> T {
    match (a >= T::zero() && b >= T::zero()) || (a < T::zero() && b < T::zero()) {
        true => a,
        false => -a,
    }
}

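/// Computes `sqrt(a^2 + b^2)` without destructive overflow or underflow,
/// iterating on the scaled ratio instead of squaring the raw inputs (in the
/// style of the classic EISPACK `pythag` routine).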
fn svd_pythag<T: SvdFloat + FromPrimitive>(a: T, b: T) -> T {
    match a.abs().max(b.abs()) {
        n if n > T::zero() => {
            let mut p = n;
            let mut r = (a.abs().min(b.abs()) / p).powi(2);
            let four = T::from_f64(4.0).unwrap();
            let two = T::from_f64(2.0).unwrap();
            let mut t = four + r;
            while !compare(t, four) {
                let s = r / t;
                let u = T::one() + two * s;
                p = p * u;
                r = (s / u).powi(2);
                t = four + r;
            }
            p
        }
        _ => T::zero(),
    }
}

fn svd_ddot<T: Float + Sum<T> + Send + Sync>(x: &[T], y: &[T]) -> T {
    if x.len() < 1000 {
        x.iter().zip(y).map(|(a, b)| *a * *b).sum()
    } else {
        x.par_iter().zip(y.par_iter()).map(|(a, b)| *a * *b).sum()
    }
}

fn svd_norm<T: Float + Sum<T> + Send + Sync>(x: &[T]) -> T {
    svd_ddot(x, x).sqrt()
}

fn svd_datx<T: Float + Sum<T>>(d: T, x: &[T], y: &mut [T]) {
    for (i, xval) in x.iter().enumerate() {
        y[i] = d * *xval;
    }
}

fn svd_dscal<T: Float + MulAssign + Send + Sync>(d: T, x: &mut [T]) {
    if x.len() < 1000 {
        for elem in x.iter_mut() {
            *elem *= d;
        }
    } else {
        x.par_iter_mut().for_each(|elem| {
            *elem *= d;
        });
    }
}

// Copies `n` elements of `x` starting at `offset` into `y`, reversing their
// order (effectively a BLAS dcopy with a negative increment).
fn svd_dcopy<T: Float + Copy>(n: usize, offset: usize, x: &[T], y: &mut [T]) {
    if n > 0 {
        let start = n - 1;
        for i in 0..n {
            y[offset + start - i] = x[offset + i];
        }
    }
}

const MAX_IMTQLB_ITERATIONS: usize = 100;

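/// Implicit-shift QL eigensolver for a symmetric tridiagonal matrix
/// (eigenvalues only): `d` holds the diagonal, `e` the off-diagonal, and
/// `bnd` tracks the eigenvector components used later for error bounds.
/// Iterations are capped by `max_imtqlb`; on hitting the cap the routine
/// keeps its best estimates and inflates the affected bounds instead of
/// failing.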
fn imtqlb<T: SvdFloat>(
    n: usize,
    d: &mut [T],
    e: &mut [T],
    bnd: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    if n == 1 {
        return Ok(());
    }

    let matrix_size_factor = T::from_f64((n as f64).sqrt()).unwrap();

    bnd[0] = T::one();
    let last = n - 1;
    for i in 1..=last {
        bnd[i] = T::zero();
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let mut i = 0;

    let mut had_convergence_issues = false;

    for l in 0..=last {
        let mut iteration = 0;
        let mut p = d[l];
        let mut f = bnd[l];

        while iteration <= max_imtqlb {
            // Find the first negligible off-diagonal element at or after `l`.
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }

                let test = d[m].abs() + d[m + 1].abs();
                let tol = T::epsilon()
                    * T::from_f64(100.0).unwrap()
                    * test.max(T::one())
                    * matrix_size_factor;

                if e[m].abs() <= tol {
                    break;
                }
                m += 1;
            }

            if m == l {
                // Eigenvalue isolated: insert it into the sorted prefix.
                let mut exchange = true;
                if l > 0 {
                    i = l;
                    while i >= 1 && exchange {
                        if p < d[i - 1] {
                            d[i] = d[i - 1];
                            bnd[i] = bnd[i - 1];
                            i -= 1;
                        } else {
                            exchange = false;
                        }
                    }
                }
                if exchange {
                    i = 0;
                }
                d[i] = p;
                bnd[i] = f;
                iteration = max_imtqlb + 1; // exit the while loop
            } else {
                if iteration == max_imtqlb {
                    // Out of iterations: keep the current estimates, inflate
                    // the affected error bounds, and move on.
                    had_convergence_issues = true;

                    for idx in l..=m {
                        bnd[idx] = bnd[idx].max(T::from_f64(0.1).unwrap());
                    }

                    e[l] = T::zero();

                    break;
                }

                iteration += 1;
                // One implicit QL sweep with a Wilkinson-style shift.
                let two = T::from_f64(2.0).unwrap();
                let mut g = (d[l + 1] - p) / (two * e[l]);
                let mut r = svd_pythag(g, T::one());
                g = d[m] - p + e[l] / (g + svd_fsign(r, g));
                let mut s = T::one();
                let mut c = T::one();
                p = T::zero();

                assert!(m > 0, "imtqlb: expected 'm' to be non-zero");
                i = m - 1;
                let mut underflow = false;
                while !underflow && i >= l {
                    f = s * e[i];
                    let b = c * e[i];
                    r = svd_pythag(f, g);
                    e[i + 1] = r;

                    if r < T::epsilon() * T::from_f64(1000.0).unwrap() * (f.abs() + g.abs()) {
                        underflow = true;
                        break;
                    }

                    if r.abs() < T::epsilon() * T::from_f64(100.0).unwrap() {
                        r = T::epsilon() * T::from_f64(100.0).unwrap() * svd_fsign(T::one(), r);
                    }

                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + T::from_f64(2.0).unwrap() * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;
                    f = bnd[i + 1];
                    bnd[i + 1] = s * bnd[i] + c * f;
                    bnd[i] = c * bnd[i] - s * f;
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
                if underflow {
                    d[i + 1] -= p;
                } else {
                    d[l] -= p;
                    e[l] = g;
                }
                e[m] = T::zero();
            }
        }
    }
    if had_convergence_issues {
        eprintln!("Warning: imtqlb had some convergence issues but continued with best estimates. Results may have reduced accuracy.");
    }
    Ok(())
}

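/// Delivers a starting vector in `w0` that is orthogonal to the previous
/// Lanczos vectors, retrying with fresh random data (up to three attempts)
/// when the projection is numerically zero. Returns the norm of the result.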
#[allow(non_snake_case)]
fn startv<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    step: usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<T, SvdLibError> {
    // Build the RNG once, outside the retry loop, so each attempt actually
    // draws a different random vector.
    let mut bytes = [0; 32];
    for (i, b) in random_seed.to_le_bytes().iter().enumerate() {
        bytes[i] = *b;
    }
    let mut seeded_rng = StdRng::from_seed(bytes);

    let mut rnm2 = svd_ddot(&wrk.w0, &wrk.w0);
    for id in 0..3 {
        if id > 0 || step > 0 || compare(rnm2, T::zero()) {
            for val in wrk.w0.iter_mut() {
                *val = T::from_f64(seeded_rng.gen_range(-1.0..1.0)).unwrap();
            }
        }
        wrk.w3.copy_from_slice(&wrk.w0);

        // Apply the operator so the vector lies in the range of A^T A.
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        wrk.w3.copy_from_slice(&wrk.w0);
        rnm2 = svd_ddot(&wrk.w3, &wrk.w3);
        if rnm2 > T::zero() {
            break;
        }
    }

    if rnm2 <= T::zero() {
        return Err(SvdLibError::StartvError(format!(
            "rnm2 <= 0.0, rnm2 = {rnm2:?}"
        )));
    }

    if step > 0 {
        for i in 0..step {
            let v = store.retrq(i);
            svd_daxpy(-svd_ddot(&wrk.w3, v), v, &mut wrk.w0);
        }

        // Orthogonalize against the most recent Lanczos vector as well.
        svd_daxpy(-svd_ddot(&wrk.w4, &wrk.w0), &wrk.w2, &mut wrk.w0);
        wrk.w3.copy_from_slice(&wrk.w0);

        rnm2 = match svd_ddot(&wrk.w3, &wrk.w3) {
            dot if dot <= T::eps() * rnm2 => T::zero(),
            dot => dot,
        }
    }
    Ok(rnm2.sqrt())
}

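/// Performs the first Lanczos step: normalizes the starting vector, applies
/// the operator, and computes `alf[0]`. Returns the residual norm and the
/// initial tolerance `sqrt(eps) * anorm`.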
#[allow(non_snake_case)]
fn stpone<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<(T, T), SvdLibError> {
    let mut rnm = startv(A, wrk, 0, store, random_seed)?;
    if compare(rnm, T::zero()) {
        return Err(SvdLibError::StponeError("rnm == 0.0".to_string()));
    }

    svd_datx(rnm.recip(), &wrk.w0, &mut wrk.w1);
    svd_dscal(rnm.recip(), &mut wrk.w3);

    svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
    wrk.alf[0] = svd_ddot(&wrk.w0, &wrk.w3);
    svd_daxpy(-wrk.alf[0], &wrk.w1, &mut wrk.w0);
    let t = svd_ddot(&wrk.w0, &wrk.w3);
    wrk.alf[0] += t;
    svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
    wrk.w4.copy_from_slice(&wrk.w0);
    rnm = svd_norm(&wrk.w4);
    let anorm = rnm + wrk.alf[0].abs();
    Ok((rnm, T::eps().sqrt() * anorm))
}

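/// Runs Lanczos steps `first..last`, maintaining the three-term recurrence,
/// selective reorthogonalization (via `ortbnd`/`purge`), and the `alf`/`bet`
/// tridiagonal entries. Returns the index of the last step taken.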
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanczos_step<T: SvdFloat>(
    A: &dyn SMat<T>,
    wrk: &mut WorkSpace<T>,
    first: usize,
    last: usize,
    ll: &mut usize,
    enough: &mut bool,
    rnm: &mut T,
    tol: &mut T,
    store: &mut Store<T>,
) -> Result<usize, SvdLibError> {
    let eps1 = T::eps() * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    let mut j = first;
    let four = T::from_f64(4.0).unwrap();

    while j < last {
        mem::swap(&mut wrk.w1, &mut wrk.w2);
        mem::swap(&mut wrk.w3, &mut wrk.w4);

        store.storq(j - 1, &wrk.w2);
        if j - 1 < MAXLL {
            store.storp(j - 1, &wrk.w4);
        }
        wrk.bet[j] = *rnm;

        // Restart with a fresh vector if the recurrence has broken down.
        if compare(*rnm, T::zero()) {
            *rnm = startv(A, wrk, j, store, 0)?;
            if compare(*rnm, T::zero()) {
                *enough = true;
            }
        }

        if *enough {
            mem::swap(&mut wrk.w1, &mut wrk.w2);
            break;
        }

        svd_datx(rnm.recip(), &wrk.w0, &mut wrk.w1);
        svd_dscal(rnm.recip(), &mut wrk.w3);
        svd_opb(A, &wrk.w3, &mut wrk.w0, &mut wrk.temp, wrk.transposed);
        svd_daxpy(-*rnm, &wrk.w2, &mut wrk.w0);
        wrk.alf[j] = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-wrk.alf[j], &wrk.w1, &mut wrk.w0);

        if j <= MAXLL && wrk.alf[j - 1].abs() > four * wrk.alf[j].abs() {
            *ll = j;
        }
        for i in 0..(j - 1).min(*ll) {
            let v1 = store.retrp(i);
            let t = svd_ddot(v1, &wrk.w0);
            let v2 = store.retrq(i);
            svd_daxpy(-t, v2, &mut wrk.w0);
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }

        // Extended local reorthogonalization against the two previous vectors.
        let t = svd_ddot(&wrk.w0, &wrk.w4);
        svd_daxpy(-t, &wrk.w2, &mut wrk.w0);
        if wrk.bet[j] > T::zero() {
            wrk.bet[j] += t;
        }
        let t = svd_ddot(&wrk.w0, &wrk.w3);
        svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
        wrk.alf[j] += t;
        wrk.w4.copy_from_slice(&wrk.w0);
        *rnm = svd_norm(&wrk.w4);
        let anorm = wrk.bet[j] + wrk.alf[j].abs() + *rnm;
        *tol = T::eps().sqrt() * anorm;

        // Update the orthogonality bounds and purge if they have decayed.
        ortbnd(wrk, j, *rnm, eps1);

        purge(wrk.ncols, *ll, wrk, j, rnm, *tol, store);
        if *rnm <= *tol {
            *rnm = T::zero();
        }
        j += 1;
    }
    Ok(j)
}

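/// Re-orthogonalizes the two most recent Lanczos vectors against all stored
/// ones when the orthogonality estimates in `eta` have grown past
/// `sqrt(eps)` (Gram-Schmidt, at most two passes).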
fn purge<T: SvdFloat>(
    n: usize,
    ll: usize,
    wrk: &mut WorkSpace<T>,
    step: usize,
    rnm: &mut T,
    tol: T,
    store: &mut Store<T>,
) {
    if step < ll + 2 {
        return;
    }

    let reps = T::eps().sqrt();
    let eps1 = T::eps() * T::from_f64(n as f64).unwrap().sqrt();

    let k = svd_idamax(step - (ll + 1), &wrk.eta) + ll;
    if wrk.eta[k].abs() > reps {
        let reps1 = eps1 / reps;
        let mut iteration = 0;
        let mut flag = true;
        while iteration < 2 && flag {
            if *rnm > tol {
                // Bring in the stored vectors and orthogonalize both residuals.
                let mut tq = T::zero();
                let mut tr = T::zero();
                for i in ll..step {
                    let v = store.retrq(i);
                    let t = svd_ddot(v, &wrk.w3);
                    tq += t.abs();
                    svd_daxpy(-t, v, &mut wrk.w1);
                    let t = svd_ddot(v, &wrk.w4);
                    tr += t.abs();
                    svd_daxpy(-t, v, &mut wrk.w0);
                }
                wrk.w3.copy_from_slice(&wrk.w1);
                let t = svd_ddot(&wrk.w0, &wrk.w3);
                tr += t.abs();
                svd_daxpy(-t, &wrk.w1, &mut wrk.w0);
                wrk.w4.copy_from_slice(&wrk.w0);
                *rnm = svd_norm(&wrk.w4);
                if tq <= reps1 && tr <= *rnm * reps1 {
                    flag = false;
                }
            }
            iteration += 1;
        }
        for i in ll..=step {
            wrk.eta[i] = eps1;
            wrk.oldeta[i] = eps1;
        }
    }
}

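/// Updates the `eta`/`oldeta` recurrences that estimate how far the current
/// Lanczos vector has drifted from orthogonality against earlier ones.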
fn ortbnd<T: SvdFloat>(wrk: &mut WorkSpace<T>, step: usize, rnm: T, eps1: T) {
    if step < 1 {
        return;
    }
    if !compare(rnm, T::zero()) && step > 1 {
        wrk.oldeta[0] = (wrk.bet[1] * wrk.eta[1] + (wrk.alf[0] - wrk.alf[step]) * wrk.eta[0]
            - wrk.bet[step] * wrk.oldeta[0])
            / rnm
            + eps1;
        if step > 2 {
            for i in 1..=step - 2 {
                wrk.oldeta[i] = (wrk.bet[i + 1] * wrk.eta[i + 1]
                    + (wrk.alf[i] - wrk.alf[step]) * wrk.eta[i]
                    + wrk.bet[i] * wrk.eta[i - 1]
                    - wrk.bet[step] * wrk.oldeta[i])
                    / rnm
                    + eps1;
            }
        }
    }
    wrk.oldeta[step - 1] = eps1;
    mem::swap(&mut wrk.oldeta, &mut wrk.eta);
    wrk.eta[step] = eps1;
}

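/// Massages the Ritz error bounds: merges bounds of near-multiple Ritz
/// values, refines isolated bounds by their spectral gap, and counts how
/// many eigenpairs have converged (setting `enough` once one lies inside
/// `(endl, endr)`).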
fn error_bound<T: SvdFloat>(
    enough: &mut bool,
    endl: T,
    endr: T,
    ritz: &mut [T],
    bnd: &mut [T],
    step: usize,
    tol: T,
) -> usize {
    assert!(step > 0, "error_bound: expected 'step' to be non-zero");

    let mid = svd_idamax(step + 1, bnd);
    let sixteen = T::from_f64(16.0).unwrap();

    // Sweep down from the top, merging bounds of nearly equal Ritz values.
    let mut i = ((step + 1) + (step - 1)) / 2;
    while i > mid + 1 {
        if (ritz[i - 1] - ritz[i]).abs() < T::eps34() * ritz[i].abs()
            && bnd[i] > tol
            && bnd[i - 1] > tol
        {
            bnd[i - 1] = (bnd[i].powi(2) + bnd[i - 1].powi(2)).sqrt();
            bnd[i] = T::zero();
        }
        i -= 1;
    }

    // Sweep up from the bottom, performing the same merge.
    let mut i = ((step + 1) - (step - 1)) / 2;
    while i + 1 < mid {
        if (ritz[i + 1] - ritz[i]).abs() < T::eps34() * ritz[i].abs()
            && bnd[i] > tol
            && bnd[i + 1] > tol
        {
            bnd[i + 1] = (bnd[i].powi(2) + bnd[i + 1].powi(2)).sqrt();
            bnd[i] = T::zero();
        }
        i += 1;
    }

    // Refine each bound by the gap to its nearest neighbor and count the
    // converged Ritz values.
    let mut neig = 0;
    let mut gapl = ritz[step] - ritz[0];
    for i in 0..=step {
        let mut gap = gapl;
        if i < step {
            gapl = ritz[i + 1] - ritz[i];
        }
        gap = gap.min(gapl);
        if gap > bnd[i] {
            bnd[i] *= bnd[i] / gap;
        }
        if bnd[i] <= sixteen * T::eps() * ritz[i].abs() {
            neig += 1;
            if !*enough {
                *enough = endl < ritz[i] && ritz[i] < endr;
            }
        }
    }
    neig
}

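/// Implicit-shift QL eigensolver for a symmetric tridiagonal matrix that
/// also accumulates the rotations into `z` (eigenvectors), then sorts the
/// eigenpairs into ascending order. Unlike `imtqlb`, failure to converge is
/// a hard error.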
fn imtql2<T: SvdFloat>(
    nm: usize,
    n: usize,
    d: &mut [T],
    e: &mut [T],
    z: &mut [T],
    max_imtqlb: Option<usize>,
) -> Result<(), SvdLibError> {
    let max_imtqlb = max_imtqlb.unwrap_or(MAX_IMTQLB_ITERATIONS);
    if n == 1 {
        return Ok(());
    }
    assert!(n > 1, "imtql2: expected 'n' to be > 1");
    let two = T::from_f64(2.0).unwrap();

    let last = n - 1;

    for i in 1..n {
        e[i - 1] = e[i];
    }
    e[last] = T::zero();

    let nnm = n * nm;
    for l in 0..n {
        let mut iteration = 0;

        while iteration <= max_imtqlb {
            let mut m = l;
            while m < n {
                if m == last {
                    break;
                }
                let test = d[m].abs() + d[m + 1].abs();
                if compare(test, test + e[m].abs()) {
                    break;
                }
                m += 1;
            }
            if m == l {
                break;
            }

            if iteration == max_imtqlb {
                return Err(SvdLibError::Imtql2Error(format!(
                    "imtql2 no convergence to an eigenvalue after {} iterations",
                    max_imtqlb
                )));
            }
            iteration += 1;

            // Form the implicit shift.
            let mut g = (d[l + 1] - d[l]) / (two * e[l]);
            let mut r = svd_pythag(g, T::one());
            g = d[m] - d[l] + e[l] / (g + svd_fsign(r, g));

            let mut s = T::one();
            let mut c = T::one();
            let mut p = T::zero();

            assert!(m > 0, "imtql2: expected 'm' to be non-zero");
            let mut i = m - 1;
            let mut underflow = false;
            while !underflow && i >= l {
                let mut f = s * e[i];
                let b = c * e[i];
                r = svd_pythag(f, g);
                e[i + 1] = r;
                if compare(r, T::zero()) {
                    underflow = true;
                } else {
                    s = f / r;
                    c = g / r;
                    g = d[i + 1] - p;
                    r = (d[i] - g) * s + two * c * b;
                    p = s * r;
                    d[i + 1] = g + p;
                    g = c * r - b;

                    // Apply the rotation to the eigenvector matrix.
                    for k in (0..nnm).step_by(n) {
                        let index = k + i;
                        f = z[index + 1];
                        z[index + 1] = s * z[index] + c * f;
                        z[index] = c * z[index] - s * f;
                    }
                    if i == 0 {
                        break;
                    }
                    i -= 1;
                }
            }
            if underflow {
                d[i + 1] -= p;
            } else {
                d[l] -= p;
                e[l] = g;
            }
            e[m] = T::zero();
        }
    }

    // Sort eigenvalues (and their eigenvectors) into ascending order.
    for l in 1..n {
        let i = l - 1;
        let mut k = i;
        let mut p = d[i];
        for (j, item) in d.iter().enumerate().take(n).skip(l) {
            if *item < p {
                k = j;
                p = *item;
            }
        }

        if k != i {
            d[k] = d[i];
            d[i] = p;
            for j in (0..nnm).step_by(n) {
                z.swap(j + i, j + k);
            }
        }
    }

    Ok(())
}

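/// Rotates the slice left by `x` positions in place using the cycle-leader
/// technique, so only O(1) extra space is needed.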
fn rotate_array<T: Float + Copy>(a: &mut [T], x: usize) {
    let n = a.len();
    let mut j = 0;
    let mut start = 0;
    let mut t1 = a[0];

    for _ in 0..n {
        j = match j >= x {
            true => j - x,
            false => j + n - x,
        };

        let t2 = a[j];
        a[j] = t1;

        if j == start {
            j += 1;
            start = j;
            t1 = a[j];
        } else {
            t1 = t2;
        }
    }
}

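/// Recovers singular triplets from the Lanczos run: solves the tridiagonal
/// eigenproblem (`imtql2`), combines the stored Lanczos vectors into right
/// singular vectors for the Ritz values that pass the `kappa` test, and
/// derives the left vectors and singular values through the operator.
/// Tolerances and iteration caps are adapted to the matrix sparsity.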
#[allow(non_snake_case)]
fn ritvec<T: SvdFloat>(
    A: &dyn SMat<T>,
    dimensions: usize,
    kappa: T,
    wrk: &mut WorkSpace<T>,
    steps: usize,
    neig: usize,
    store: &mut Store<T>,
) -> Result<SVDRawRec<T>, SvdLibError> {
    let js = steps + 1;
    let jsq = js * js;

    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));

    // Looser tolerances and more QL iterations for very sparse matrices.
    let epsilon = T::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        epsilon
    };

    let max_iterations_imtql2 = if sparsity > T::from_f64(0.999).unwrap() {
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        Some(300)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        Some(200)
    } else {
        Some(50)
    };

    // `s` starts as the js x js identity and accumulates the eigenvectors.
    let mut s = vec![T::zero(); jsq];
    for i in (0..jsq).step_by(js + 1) {
        s[i] = T::one();
    }

    // This first `Vt` is workspace only: its `value` buffer holds the
    // tridiagonal eigenvalues for imtql2 before being shadowed below.
    let mut Vt = DMat {
        cols: wrk.ncols,
        value: vec![T::zero(); wrk.ncols * dimensions],
    };

    svd_dcopy(js, 0, &wrk.alf, &mut Vt.value);
    svd_dcopy(steps, 1, &wrk.bet, &mut wrk.w5);

    imtql2(
        js,
        js,
        &mut Vt.value,
        &mut wrk.w5,
        &mut s,
        max_iterations_imtql2,
    )?;

    let max_eigenvalue = Vt
        .value
        .iter()
        .fold(T::zero(), |max, &val| max.max(val.abs()));

    let adaptive_kappa = if sparsity > T::from_f64(0.99).unwrap() {
        kappa * T::from_f64(10.0).unwrap()
    } else {
        kappa
    };

    let store_vectors: Vec<Vec<T>> = (0..js).map(|i| store.retrq(i).to_vec()).collect();

    // Keep only Ritz values whose error bound passes the relative kappa test.
    let significant_indices: Vec<usize> = (0..js)
        .into_par_iter()
        .filter(|&k| {
            let relative_bound =
                adaptive_kappa * wrk.ritz[k].abs().max(max_eigenvalue * adaptive_eps);
            wrk.bnd[k] <= relative_bound && k + 1 > js - neig
        })
        .collect();

    let nsig = significant_indices.len();

    // Right singular vectors: linear combinations of the stored Lanczos
    // vectors weighted by the eigenvector entries.
    let mut vt_vectors: Vec<(usize, Vec<T>)> = significant_indices
        .into_par_iter()
        .map(|k| {
            let mut vec = vec![T::zero(); wrk.ncols];
            let mut idx = (jsq - js) + k + 1;

            for i in 0..js {
                idx -= js;
                if s[idx].abs() > adaptive_eps {
                    for (j, item) in store_vectors[i].iter().enumerate().take(wrk.ncols) {
                        vec[j] += s[idx] * *item;
                    }
                }
            }

            (k, vec)
        })
        .collect();

    vt_vectors.sort_by_key(|(k, _)| *k);

    let d = dimensions.min(nsig);
    let mut S = vec![T::zero(); d];
    let mut Ut = DMat {
        cols: wrk.nrows,
        value: vec![T::zero(); wrk.nrows * d],
    };

    let mut Vt = DMat {
        cols: wrk.ncols,
        value: vec![T::zero(); wrk.ncols * d],
    };

    for (i, (_, vec)) in vt_vectors.into_iter().take(d).enumerate() {
        let vt_offset = i * Vt.cols;
        Vt.value[vt_offset..vt_offset + Vt.cols].copy_from_slice(&vec);
    }

    let mut ab_products = Vec::with_capacity(d);
    let mut a_products = Vec::with_capacity(d);

    for i in 0..d {
        let vt_offset = i * Vt.cols;
        let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];

        let mut tmp_vec = vec![T::zero(); Vt.cols];
        let mut ut_vec = vec![T::zero(); wrk.nrows];

        svd_opb(A, vt_vec, &mut tmp_vec, &mut wrk.temp, wrk.transposed);
        A.svd_opa(vt_vec, &mut ut_vec, wrk.transposed);

        ab_products.push(tmp_vec);
        a_products.push(ut_vec);
    }

    // Singular values: sigma = sqrt(v^T A^T A v) for each right vector v.
    let results: Vec<(usize, T)> = (0..d)
        .into_par_iter()
        .map(|i| {
            let vt_offset = i * Vt.cols;
            let vt_vec = &Vt.value[vt_offset..vt_offset + Vt.cols];
            let tmp_vec = &ab_products[i];

            let t = svd_ddot(vt_vec, tmp_vec);
            let sval = t.max(T::zero()).sqrt();

            (i, sval)
        })
        .collect();

    // Left singular vectors: u = A v / sigma, flooring sigma at the adaptive
    // epsilon to avoid dividing by a numerically zero value.
    for (i, sval) in results {
        S[i] = sval;
        let ut_offset = i * Ut.cols;
        let mut ut_vec = a_products[i].clone();

        let dls = sval.max(adaptive_eps);
        svd_dscal(T::one() / dls, &mut ut_vec);

        Ut.value[ut_offset..ut_offset + Ut.cols].copy_from_slice(&ut_vec);
    }

    Ok(SVDRawRec {
        d,
        nsig,
        Ut,
        S,
        Vt,
    })
}

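/// Drives the Lanczos iteration: repeatedly extends the tridiagonalization,
/// extracts Ritz values with `imtqlb`, and stops once at least `dim`
/// eigenpairs have converged (or the iteration budget is exhausted).
/// Returns the index of the last Lanczos step.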
#[allow(non_snake_case)]
#[allow(clippy::too_many_arguments)]
fn lanso<T: SvdFloat>(
    A: &dyn SMat<T>,
    dim: usize,
    iterations: usize,
    end_interval: &[T; 2],
    wrk: &mut WorkSpace<T>,
    neig: &mut usize,
    store: &mut Store<T>,
    random_seed: u32,
) -> Result<usize, SvdLibError> {
    let sparsity = T::one()
        - (T::from_usize(A.nnz()).unwrap()
            / (T::from_usize(A.nrows()).unwrap() * T::from_usize(A.ncols()).unwrap()));
    let max_iterations_imtqlb = if sparsity > T::from_f64(0.999).unwrap() {
        Some(500)
    } else if sparsity > T::from_f64(0.99).unwrap() {
        Some(300)
    } else if sparsity > T::from_f64(0.9).unwrap() {
        Some(100)
    } else {
        Some(50)
    };

    let epsilon = T::epsilon();
    let adaptive_eps = if sparsity > T::from_f64(0.99).unwrap() {
        epsilon * T::from_f64(100.0).unwrap()
    } else if sparsity > T::from_f64(0.9).unwrap() {
        epsilon * T::from_f64(10.0).unwrap()
    } else {
        epsilon
    };

    let (endl, endr) = (end_interval[0], end_interval[1]);

    // The first Lanczos step gives the initial residual norm and tolerance.
    let (mut rnm, mut tol) = stpone(A, wrk, store, random_seed)?;

    let eps1 = adaptive_eps * T::from_f64(wrk.ncols as f64).unwrap().sqrt();
    wrk.eta[0] = eps1;
    wrk.oldeta[0] = eps1;
    let mut ll = 0;
    let mut first = 1;
    let mut last = iterations.min(dim.max(8) + dim);
    let mut enough = false;
    let mut j = 0;
    let mut intro = 0;

    while !enough {
        if rnm <= tol {
            rnm = T::zero();
        }

        let steps = lanczos_step(
            A,
            wrk,
            first,
            last,
            &mut ll,
            &mut enough,
            &mut rnm,
            &mut tol,
            store,
        )?;
        j = match enough {
            true => steps - 1,
            false => last - 1,
        };

        first = j + 1;
        wrk.bet[first] = rnm;

        // Analyze the tridiagonal matrix block by block, splitting at
        // negligible off-diagonal entries.
        let mut l = 0;
        for _ in 0..j {
            if l > j {
                break;
            }

            let mut i = l;
            while i <= j {
                if wrk.bet[i + 1].abs() <= adaptive_eps {
                    break;
                }
                i += 1;
            }
            i = i.min(j);

            // Compute the Ritz values of this block.
            let sz = i - l;
            svd_dcopy(sz + 1, l, &wrk.alf, &mut wrk.ritz);
            svd_dcopy(sz, l + 1, &wrk.bet, &mut wrk.w5);

            imtqlb(
                sz + 1,
                &mut wrk.ritz[l..],
                &mut wrk.w5[l..],
                &mut wrk.bnd[l..],
                max_iterations_imtqlb,
            )?;

            for m in l..=i {
                wrk.bnd[m] = rnm * wrk.bnd[m].abs();
            }
            l = i + 1;
        }

        // Sort the Ritz values and refine their error bounds.
        insert_sort(j + 1, &mut wrk.ritz, &mut wrk.bnd);

        *neig = error_bound(&mut enough, endl, endr, &mut wrk.ritz, &mut wrk.bnd, j, tol);

        // Decide how many more Lanczos steps to take.
        if *neig < dim {
            if *neig == 0 {
                last = first + 9;
                intro = first;
            } else {
                let extra_steps = if sparsity > T::from_f64(0.99).unwrap() {
                    5
                } else {
                    0
                };

                last = first + 3.max(1 + ((j - intro) * (dim - *neig)) / *neig) + extra_steps;
            }
            last = last.min(iterations);
        } else {
            enough = true
        }
        enough = enough || first >= iterations;
    }
    store.storq(j, &wrk.w1);
    Ok(j)
}

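/// Reconstructs the (possibly low-rank) approximation `ut^T * diag(s) * vt`.
///
/// A quick check of the decomposition, as a hedged sketch:
///
/// ```ignore
/// let rec = svd_dim(&a, 3)?;
/// let approx = rec.recompose(); // ndarray Array2 with the original shape
/// ```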
impl<T: SvdFloat + 'static> SvdRec<T> {
    pub fn recompose(&self) -> Array2<T> {
        let sdiag = Array2::from_diag(&self.s);
        self.ut.t().dot(&sdiag).dot(&self.vt)
    }
}

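// `SMat` implementations for the nalgebra-sparse formats. The CSC and COO
// versions are sequential; the CSR version parallelizes the product with
// rayon, chunking rows in the transposed case to avoid write contention.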
#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::csc::CscMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        let (major_offsets, minor_indices, values) = self.csc_data();

        y.fill(T::zero());

        if transposed {
            // y = A^T x: each output entry gathers one column.
            for (i, yval) in y.iter_mut().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    *yval += values[j] * x[minor_indices[j]];
                }
            }
        } else {
            // y = A x: scatter each column scaled by x[i].
            for (i, xval) in x.iter().enumerate() {
                for j in major_offsets[i]..major_offsets[i + 1] {
                    y[minor_indices[j]] += values[j] * *xval;
                }
            }
        }
    }
}

#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone + Sync + Send> SMat<T> for nalgebra_sparse::csr::CsrMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        let (major_offsets, minor_indices, values) = self.csr_data();

        y.fill(T::zero());

        if !transposed {
            // y = A x: rows are independent, so compute them in parallel.
            let nrows = self.nrows();

            let results: Vec<(usize, T)> = (0..nrows)
                .into_par_iter()
                .map(|i| {
                    let mut sum = T::zero();
                    for j in major_offsets[i]..major_offsets[i + 1] {
                        sum += values[j] * x[minor_indices[j]];
                    }
                    (i, sum)
                })
                .collect();

            for (i, val) in results {
                y[i] = val;
            }
        } else {
            // y = A^T x: rows scatter into shared columns, so accumulate into
            // per-chunk buffers and merge them afterwards.
            let nrows = self.nrows();
            let chunk_size = crate::utils::determine_chunk_size(nrows);

            let results: Vec<Vec<T>> = (0..((nrows + chunk_size - 1) / chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);

                    let mut local_y = vec![T::zero(); y.len()];
                    for i in start..end {
                        let row_val = x[i];
                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            local_y[col] += values[j] * row_val;
                        }
                    }
                    local_y
                })
                .collect();

            for local_y in results {
                for (idx, val) in local_y.iter().enumerate() {
                    if !val.is_zero() {
                        y[idx] += *val;
                    }
                }
            }
        }
    }
}

#[rustfmt::skip]
impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::coo::CooMatrix<T> {
    fn nrows(&self) -> usize { self.nrows() }
    fn ncols(&self) -> usize { self.ncols() }
    fn nnz(&self) -> usize { self.nnz() }

    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };
        assert_eq!(x.len(), ncols, "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}", x.len(), ncols);
        assert_eq!(y.len(), nrows, "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}", y.len(), nrows);

        y.fill(T::zero());

        if transposed {
            for (i, j, v) in self.triplet_iter() {
                y[j] += *v * x[i];
            }
        } else {
            for (i, j, v) in self.triplet_iter() {
                y[i] += *v * x[j];
            }
        }
    }
}