1use crate::{
5 lobpcg::{lobpcg, random, Lobpcg},
6 Order, Result,
7};
8use ndarray::prelude::*;
9use num_traits::NumCast;
10use std::iter::Sum;
11
12use rand::Rng;
13
14#[derive(Debug, Clone)]
19pub struct TruncatedSvdResult<A> {
20 eigvals: Array1<A>,
21 eigvecs: Array2<A>,
22 problem: Array2<A>,
23 order: Order,
24 ngm: bool,
25}
26
27impl<A: NdFloat + 'static + MagnitudeCorrection> TruncatedSvdResult<A> {
28 fn singular_values_with_indices(&self) -> (Array1<A>, Vec<usize>) {
30 let mut a = self.eigvals.iter().enumerate().collect::<Vec<_>>();
32
33 let (values, indices) = if self.order == Order::Largest {
34 a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap().reverse());
36
37 let cutoff = A::epsilon() * A::correction() * *a[0].1; let (values, indices): (Vec<A>, Vec<usize>) = a
44 .into_iter()
45 .filter(|(_, x)| *x > &cutoff)
46 .map(|(a, b)| (b.sqrt(), a))
47 .unzip();
48
49 (values, indices)
50 } else {
51 a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap());
52
53 let (values, indices) = a.into_iter().map(|(a, b)| (b.sqrt(), a)).unzip();
54
55 (values, indices)
56 };
57
58 (Array1::from(values), indices)
59 }
60
61 pub fn values(&self) -> Array1<A> {
63 let (values, _) = self.singular_values_with_indices();
64
65 values
66 }
67
68 pub fn values_vectors(&self) -> (Array2<A>, Array1<A>, Array2<A>) {
70 let (values, indices) = self.singular_values_with_indices();
71
72 #[allow(clippy::branches_sharing_code)]
74 let (u, v) = if self.ngm {
75 let vlarge = self.eigvecs.select(Axis(1), &indices);
76 let mut ularge = self.problem.dot(&vlarge);
77
78 ularge
79 .columns_mut()
80 .into_iter()
81 .zip(values.iter())
82 .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b));
83
84 (ularge, vlarge)
85 } else {
86 let ularge = self.eigvecs.select(Axis(1), &indices);
87
88 let mut vlarge = self.problem.t().dot(&ularge);
89 vlarge
90 .columns_mut()
91 .into_iter()
92 .zip(values.iter())
93 .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b));
94
95 (ularge, vlarge)
96 };
97
98 (u, values, v.reversed_axes())
99 }
100}
101
102#[derive(Debug, Clone)]
103pub struct TruncatedSvd<A: NdFloat, R: Rng> {
108 order: Order,
109 problem: Array2<A>,
110 precision: f32,
111 maxiter: usize,
112 rng: R,
113}
114
115impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
116 pub fn new_with_rng(problem: Array2<A>, order: Order, rng: R) -> TruncatedSvd<A, R> {
123 TruncatedSvd {
124 precision: 1e-5,
125 maxiter: problem.len_of(Axis(0)) * 2,
126 order,
127 problem,
128 rng,
129 }
130 }
131}
132
133impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
134 pub fn precision(mut self, precision: f32) -> Self {
140 self.precision = precision;
141
142 self
143 }
144
145 pub fn maxiter(mut self, maxiter: usize) -> Self {
150 self.maxiter = maxiter;
151
152 self
153 }
154
155 pub fn decompose(mut self, num: usize) -> Result<TruncatedSvdResult<A>> {
179 if num == 0 {
180 return Ok(TruncatedSvdResult {
182 eigvals: Array1::zeros(0),
183 eigvecs: Array2::zeros((0, 0)),
184 problem: Array2::zeros((0, 0)),
185 order: self.order,
186 ngm: false,
187 });
188 }
189
190 let (n, m) = (self.problem.nrows(), self.problem.ncols());
191
192 let x: Array2<f32> = random((usize::min(n, m), num), &mut self.rng);
194 let x = x.mapv(|x| NumCast::from(x).unwrap());
195
196 let precision = self.precision * self.precision;
198
199 let res = if n > m {
201 lobpcg(
202 |y| self.problem.t().dot(&self.problem.dot(&y)),
203 x,
204 |_| {},
205 None,
206 precision,
207 self.maxiter,
208 self.order,
209 )
210 } else {
211 lobpcg(
212 |y| self.problem.dot(&self.problem.t().dot(&y)),
213 x,
214 |_| {},
215 None,
216 precision,
217 self.maxiter,
218 self.order,
219 )
220 };
221
222 match res {
224 Ok(Lobpcg {
225 eigvals, eigvecs, ..
226 })
227 | Err((
228 _,
229 Some(Lobpcg {
230 eigvals, eigvecs, ..
231 }),
232 )) => Ok(TruncatedSvdResult {
233 problem: self.problem,
234 eigvals,
235 eigvecs,
236 order: self.order,
237 ngm: n > m,
238 }),
239 Err((err, None)) => Err(err),
240 }
241 }
242}
243
/// Per-float-type factor for the null-space cut-off.
///
/// The factor is multiplied by the type's machine epsilon and the largest
/// eigenvalue; eigenvalues that do not exceed that product are discarded when
/// the largest singular values are requested.
pub trait MagnitudeCorrection {
    /// Returns the correction factor for this floating point type.
    fn correction() -> Self;
}

impl MagnitudeCorrection for f32 {
    /// Single precision uses a factor of one thousand.
    fn correction() -> f32 {
        1_000.0
    }
}

impl MagnitudeCorrection for f64 {
    /// Double precision uses a factor of one million.
    fn correction() -> f64 {
        1_000_000.0
    }
}
264
#[cfg(test)]
mod tests {
    use super::Order;
    use super::TruncatedSvd;

    use approx::assert_abs_diff_eq;
    use ndarray::{arr1, arr2, s, Array1, Array2, NdFloat};
    use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
    use rand::distributions::{Distribution, Standard};
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256Plus;

    /// Generates a matrix of shape `sh` from a generator seeded with a fixed
    /// value, so every call yields the same data.
    fn random<A>(sh: (usize, usize)) -> Array2<A>
    where
        A: NdFloat,
        Standard: Distribution<A>,
    {
        super::random(sh, Xoshiro256Plus::seed_from_u64(3))
    }

    /// A 2x3 matrix whose singular values are known to be 5 and 3.
    #[test]
    fn test_truncated_svd() {
        let matrix = arr2(&[[3., 2., 2.], [2., 3., -2.]]);
        let rng = Xoshiro256Plus::seed_from_u64(42);

        let decomposition = TruncatedSvd::new_with_rng(matrix, Order::Largest, rng)
            .precision(1e-5)
            .maxiter(10)
            .decompose(2)
            .unwrap();

        let (_, sigma, _) = decomposition.values_vectors();
        let expected = arr1(&[5.0, 3.0]);

        assert_abs_diff_eq!(&sigma, &expected, epsilon = 1e-5);
    }

    /// A full-rank decomposition must reproduce the input matrix.
    #[test]
    fn test_truncated_svd_random() {
        let a: Array2<f64> = random((50, 10));
        let rng = Xoshiro256Plus::seed_from_u64(42);

        let decomposition = TruncatedSvd::new_with_rng(a.clone(), Order::Largest, rng)
            .precision(1e-5)
            .maxiter(10)
            .decompose(10)
            .unwrap();

        let (u, sigma, v_t) = decomposition.values_vectors();
        let scaled_v_t = Array2::from_diag(&sigma).dot(&v_t);
        let reconstructed = u.dot(&scaled_v_t);

        assert_abs_diff_eq!(&a, &reconstructed, epsilon = 1e-5);
    }

    /// Squared singular values of a scaled Gaussian matrix are compared, bin
    /// by bin, against the Marchenko-Pastur density.
    #[test]
    fn test_marchenko_pastur() {
        let mut rng = Xoshiro256Plus::seed_from_u64(3);

        let data = Array2::random_using((1000, 500), StandardNormal, &mut rng) / 1000f64.sqrt();

        let decomposition =
            TruncatedSvd::new_with_rng(data, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
                .precision(1e-3)
                .decompose(500)
                .unwrap();

        let sv = decomposition.values().mapv(|x: f64| x * x);

        // Support edges of the distribution for an aspect ratio of 0.5.
        let ratio_sqrt = 0.5f64.sqrt();
        let a = 1. * (1. - ratio_sqrt).powf(2.0);
        let b = 1. * (1. + ratio_sqrt).powf(2.0);

        assert_abs_diff_eq!(b, sv[0], epsilon = 0.1);
        assert_abs_diff_eq!(a, sv[sv.len() - 1], epsilon = 0.1);

        let thresholds = Array1::linspace(0.1f64, 2.8, 28);
        let bin_width = (2.8 - 0.1) / 28.;

        // `cursor` walks the descending values once across all bins.
        let mut cursor = 0;
        'outer: for th in thresholds.slice(s![..;-1]) {
            let mut count = 0;
            while sv[cursor] >= *th {
                count += 1;
                cursor += 1;

                // All values consumed: stop before indexing past the end.
                if cursor == sv.len() {
                    break 'outer;
                }
            }

            let x = th + 0.05;
            let mp_law = ((b - x) * (x - a)).sqrt() / std::f64::consts::PI / x;
            let empirical = count as f64 / 500. / bin_width;

            assert_abs_diff_eq!(mp_law, empirical, epsilon = 0.05);
        }
    }
}