linfa_linalg/lobpcg/svd.rs

1//! Truncated singular value decomposition
2//!
3//! This module computes the k largest/smallest singular values/vectors for a dense matrix.
4use crate::{
5    lobpcg::{lobpcg, random, Lobpcg},
6    Order, Result,
7};
8use ndarray::prelude::*;
9use num_traits::NumCast;
10use std::iter::Sum;
11
12use rand::Rng;
13
14/// The result of a eigenvalue decomposition, not yet transformed into singular values/vectors
15///
16/// Provides methods for either calculating just the singular values with reduced cost or the
17/// vectors with additional cost of matrix multiplication.
18#[derive(Debug, Clone)]
19pub struct TruncatedSvdResult<A> {
20    eigvals: Array1<A>,
21    eigvecs: Array2<A>,
22    problem: Array2<A>,
23    order: Order,
24    ngm: bool,
25}
26
27impl<A: NdFloat + 'static + MagnitudeCorrection> TruncatedSvdResult<A> {
28    /// Returns singular values ordered by magnitude with indices.
29    fn singular_values_with_indices(&self) -> (Array1<A>, Vec<usize>) {
30        // numerate eigenvalues
31        let mut a = self.eigvals.iter().enumerate().collect::<Vec<_>>();
32
33        let (values, indices) = if self.order == Order::Largest {
34            // sort by magnitude
35            a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap().reverse());
36
37            // calculate cut-off magnitude (borrowed from scipy)
38            let cutoff = A::epsilon() * // float precision
39                         A::correction() * // correction term (see trait below)
40                         *a[0].1; // max eigenvalue
41
42            // filter low singular values away
43            let (values, indices): (Vec<A>, Vec<usize>) = a
44                .into_iter()
45                .filter(|(_, x)| *x > &cutoff)
46                .map(|(a, b)| (b.sqrt(), a))
47                .unzip();
48
49            (values, indices)
50        } else {
51            a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap());
52
53            let (values, indices) = a.into_iter().map(|(a, b)| (b.sqrt(), a)).unzip();
54
55            (values, indices)
56        };
57
58        (Array1::from(values), indices)
59    }
60
61    /// Returns singular values ordered by magnitude
62    pub fn values(&self) -> Array1<A> {
63        let (values, _) = self.singular_values_with_indices();
64
65        values
66    }
67
68    /// Returns singular values, left-singular vectors and right-singular vectors
69    pub fn values_vectors(&self) -> (Array2<A>, Array1<A>, Array2<A>) {
70        let (values, indices) = self.singular_values_with_indices();
71
72        // branch n > m (for A is [n x m])
73        #[allow(clippy::branches_sharing_code)]
74        let (u, v) = if self.ngm {
75            let vlarge = self.eigvecs.select(Axis(1), &indices);
76            let mut ularge = self.problem.dot(&vlarge);
77
78            ularge
79                .columns_mut()
80                .into_iter()
81                .zip(values.iter())
82                .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b));
83
84            (ularge, vlarge)
85        } else {
86            let ularge = self.eigvecs.select(Axis(1), &indices);
87
88            let mut vlarge = self.problem.t().dot(&ularge);
89            vlarge
90                .columns_mut()
91                .into_iter()
92                .zip(values.iter())
93                .for_each(|(mut a, b)| a.mapv_inplace(|x| x / *b));
94
95            (ularge, vlarge)
96        };
97
98        (u, values, v.reversed_axes())
99    }
100}
101
102#[derive(Debug, Clone)]
103/// Truncated singular value decomposition
104///
105/// Wraps the LOBPCG algorithm and provides convenient builder-pattern access to
106/// parameter like maximal iteration, precision and constrain matrix.
107pub struct TruncatedSvd<A: NdFloat, R: Rng> {
108    order: Order,
109    problem: Array2<A>,
110    precision: f32,
111    maxiter: usize,
112    rng: R,
113}
114
115impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
116    /// Create a new truncated SVD problem
117    ///
118    /// # Parameters
119    ///  * `problem`: rectangular matrix which is decomposed
120    ///  * `order`: whether to return large or small (close to zero) singular values
121    ///  * `rng`: random number generator
122    pub fn new_with_rng(problem: Array2<A>, order: Order, rng: R) -> TruncatedSvd<A, R> {
123        TruncatedSvd {
124            precision: 1e-5,
125            maxiter: problem.len_of(Axis(0)) * 2,
126            order,
127            problem,
128            rng,
129        }
130    }
131}
132
133impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
134    /// Set the required precision of the solution
135    ///
136    /// The precision is, in the context of SVD, the square-root precision of the underlying
137    /// eigenproblem solution. The eigenproblem-precision is used to check the L2 error of each
138    /// eigenvector and stops its optimization when the required precision is reached.
139    pub fn precision(mut self, precision: f32) -> Self {
140        self.precision = precision;
141
142        self
143    }
144
145    /// Set the maximal number of iterations
146    ///
147    /// The LOBPCG is an iterative approach to eigenproblems and stops when this maximum
148    /// number of iterations are reached
149    pub fn maxiter(mut self, maxiter: usize) -> Self {
150        self.maxiter = maxiter;
151
152        self
153    }
154
155    /// Calculate the singular value decomposition
156    ///
157    /// # Parameters
158    ///
159    ///  * `num`: number of singular-value/vector pairs, ordered by magnitude
160    ///
161    /// # Example
162    ///
163    /// ```rust
164    /// use ndarray::{arr1, Array2};
165    /// use linfa_linalg::{Order, lobpcg::TruncatedSvd};
166    /// use rand::SeedableRng;
167    /// use rand_xoshiro::Xoshiro256Plus;
168    ///
169    /// let diag = arr1(&[1., 2., 3., 4., 5.]);
170    /// let a = Array2::from_diag(&diag);
171    ///
172    /// let eig = TruncatedSvd::new_with_rng(a, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
173    ///    .precision(1e-4)
174    ///    .maxiter(500);
175    ///
176    /// let res = eig.decompose(3);
177    /// ```
178    pub fn decompose(mut self, num: usize) -> Result<TruncatedSvdResult<A>> {
179        if num == 0 {
180            // return empty solution if requested eigenvalue number is zero
181            return Ok(TruncatedSvdResult {
182                eigvals: Array1::zeros(0),
183                eigvecs: Array2::zeros((0, 0)),
184                problem: Array2::zeros((0, 0)),
185                order: self.order,
186                ngm: false,
187            });
188        }
189
190        let (n, m) = (self.problem.nrows(), self.problem.ncols());
191
192        // generate initial matrix
193        let x: Array2<f32> = random((usize::min(n, m), num), &mut self.rng);
194        let x = x.mapv(|x| NumCast::from(x).unwrap());
195
196        // square precision because the SVD squares the eigenvalue as well
197        let precision = self.precision * self.precision;
198
199        // use problem definition with less operations required
200        let res = if n > m {
201            lobpcg(
202                |y| self.problem.t().dot(&self.problem.dot(&y)),
203                x,
204                |_| {},
205                None,
206                precision,
207                self.maxiter,
208                self.order,
209            )
210        } else {
211            lobpcg(
212                |y| self.problem.dot(&self.problem.t().dot(&y)),
213                x,
214                |_| {},
215                None,
216                precision,
217                self.maxiter,
218                self.order,
219            )
220        };
221
222        // convert into TruncatedSvdResult
223        match res {
224            Ok(Lobpcg {
225                eigvals, eigvecs, ..
226            })
227            | Err((
228                _,
229                Some(Lobpcg {
230                    eigvals, eigvecs, ..
231                }),
232            )) => Ok(TruncatedSvdResult {
233                problem: self.problem,
234                eigvals,
235                eigvecs,
236                order: self.order,
237                ngm: n > m,
238            }),
239            Err((err, None)) => Err(err),
240        }
241    }
242}
243
/// Magnitude correction
///
/// Scales the cut-off below which an eigenvalue is treated as zero, i.e. below which its
/// eigenvector is considered part of the null-space. The correction is multiplied with the
/// machine epsilon of the floating point type, so every float type supplies its own factor.
pub trait MagnitudeCorrection {
    /// Factor that is multiplied with the machine epsilon to form the relative cut-off.
    fn correction() -> Self;
}

impl MagnitudeCorrection for f32 {
    fn correction() -> Self {
        1_000.0
    }
}

impl MagnitudeCorrection for f64 {
    fn correction() -> Self {
        1_000_000.0
    }
}
264
#[cfg(test)]
mod tests {
    use super::Order;
    use super::TruncatedSvd;

    use approx::assert_abs_diff_eq;
    use ndarray::{arr1, arr2, s, Array1, Array2, NdFloat};
    use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
    use rand::distributions::{Distribution, Standard};
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256Plus;

    /// Generate a random array of shape `sh` using a fixed-seed generator (seed 3), so the test
    /// data is reproducible
    fn random<A>(sh: (usize, usize)) -> Array2<A>
    where
        A: NdFloat,
        Standard: Distribution<A>,
    {
        let rng = Xoshiro256Plus::seed_from_u64(3);
        super::random(sh, rng)
    }

    /// Checks the singular values of a small 2x3 matrix against their exact values.
    #[test]
    fn test_truncated_svd() {
        // A A^T = [[17, 8], [8, 17]] has eigenvalues 25 and 9, so the singular values are 5 and 3
        let a = arr2(&[[3., 2., 2.], [2., 3., -2.]]);

        let res = TruncatedSvd::new_with_rng(a, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
            .precision(1e-5)
            .maxiter(10)
            .decompose(2)
            .unwrap();

        let (_, sigma, _) = res.values_vectors();

        assert_abs_diff_eq!(&sigma, &arr1(&[5.0, 3.0]), epsilon = 1e-5);
    }

    /// Computes all 10 singular triplets of a random 50x10 matrix and checks that U * diag(sigma)
    /// * V^T reconstructs the input.
    #[test]
    fn test_truncated_svd_random() {
        let a: Array2<f64> = random((50, 10));

        let res = TruncatedSvd::new_with_rng(
            a.clone(),
            Order::Largest,
            Xoshiro256Plus::seed_from_u64(42),
        )
        .precision(1e-5)
        .maxiter(10)
        .decompose(10)
        .unwrap();

        let (u, sigma, v_t) = res.values_vectors();
        let reconstructed = u.dot(&Array2::from_diag(&sigma).dot(&v_t));

        assert_abs_diff_eq!(&a, &reconstructed, epsilon = 1e-5);
    }

    /// Eigenvalue structure in high dimensions
    ///
    /// This test checks that the eigenvalues follow the Marchenko-Pastur law. The data is drawn
    /// from a standard normal distribution (i.e. E(x) = 0, E(x^2) = 1), and there are twice as
    /// many samples as features (n = 1000, p = 500). After scaling the data X by 1/sqrt(n), its
    /// squared singular values are the eigenvalues of (1/n) X^T X, whose empirical density should
    /// follow the Marchenko-Pastur law with variance 1 and ratio p/n = 0.5.
    ///
    /// See also https://en.wikipedia.org/wiki/Marchenko%E2%80%93Pastur_distribution
    #[test]
    fn test_marchenko_pastur() {
        // create random number generator
        let mut rng = Xoshiro256Plus::seed_from_u64(3);

        // 1000 samples x 500 features of standard normal data (n = 2p), scaled by 1/sqrt(n)
        let data = Array2::random_using((1000, 500), StandardNormal, &mut rng) / 1000f64.sqrt();

        // request all 500 singular values
        let res =
            TruncatedSvd::new_with_rng(data, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
                .precision(1e-3)
                .decompose(500)
                .unwrap();

        // squared singular values = eigenvalues of (1/n) X^T X, largest first
        let sv = res.values().mapv(|x: f64| x * x);

        // support [a, b] of the Marchenko-Pastur law with variance 1 and p/n = 0.5:
        // a = (1 - sqrt(0.5))^2, b = (1 + sqrt(0.5))^2
        let (a, b) = (
            1. * (1. - 0.5f64.sqrt()).powf(2.0),
            1. * (1. + 0.5f64.sqrt()).powf(2.0),
        );

        // check that the spectrum has correct boundaries
        assert_abs_diff_eq!(b, sv[0], epsilon = 0.1);
        assert_abs_diff_eq!(a, sv[sv.len() - 1], epsilon = 0.1);

        // Estimate the density empirically and compare it with the Marchenko-Pastur law. The
        // thresholds are walked from 2.8 down to 0.1. `count` is the number of eigenvalues
        // between `th` and the previous threshold (all values >= 2.8 for the first one), and the
        // law is evaluated at the bin center `th + 0.05`. The loop stops as soon as every
        // eigenvalue has been consumed, so the bin holding the smallest eigenvalue is not compared.
        // NOTE(review): consecutive `linspace(0.1, 2.8, 28)` thresholds are 0.1 apart, but the
        // empirical density divides by (2.8 - 0.1) / 28 ≈ 0.0964; confirm the intended bin width.
        let mut i = 0;
        'outer: for th in Array1::linspace(0.1f64, 2.8, 28).slice(s![..;-1]) {
            let mut count = 0;
            while sv[i] >= *th {
                count += 1;
                i += 1;

                if i == sv.len() {
                    break 'outer;
                }
            }

            // density sqrt((b - x)(x - a)) / (2 * pi * (p/n) * x) with p/n = 0.5
            let x = th + 0.05;
            let mp_law = ((b - x) * (x - a)).sqrt() / std::f64::consts::PI / x;
            let empirical = count as f64 / 500. / ((2.8 - 0.1) / 28.);

            assert_abs_diff_eq!(mp_law, empirical, epsilon = 0.05);
        }
    }
}