rusty_compression/svd.rs

//! Define an SVD container and conversion tools.
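//!
//! A minimal usage sketch (assuming `SVD`, `SVDTraits` and `CompressionType`
//! are reachable from the crate root; adjust the import paths to the actual
//! re-exports of this crate):
//!
//! ```ignore
//! use rusty_compression::{CompressionType, SVD, SVDTraits};
//! use ndarray::Array2;
//!
//! // Any owned 2-dimensional array works here; the identity is a placeholder.
//! let mat = Array2::<f64>::eye(50);
//!
//! // Compute a full SVD and truncate it to rank 10.
//! let svd = SVD::<f64>::compute_from(mat.view()).unwrap();
//! let compressed = svd.compress(CompressionType::RANK(10)).unwrap();
//!
//! // Reassemble the (approximate) matrix from the truncated factors.
//! let approx = compressed.to_mat();
//! ```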
use crate::qr::{QRTraits, QR};
use crate::CompressionType;
use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis, Zip};
use num::ToPrimitive;
use crate::types::Result;
use crate::types::RustyCompressionError;
use crate::types::{c32, c64, ConjMatMat, Scalar};

/// This structure stores the Singular Value Decomposition
/// of a matrix
pub struct SVD<A: Scalar> {
    /// The U matrix
    pub u: Array2<A>,
    /// The array of singular values
    pub s: Array1<A::Real>,
    /// The vt matrix
    pub vt: Array2<A>,
}

/// SVD Traits
pub trait SVDTraits {
    type A: Scalar;

    /// Return the number of rows of the underlying operator
    fn nrows(&self) -> usize {
        self.get_u().nrows()
    }

    /// Return the number of columns of the underlying operator
    fn ncols(&self) -> usize {
        self.get_vt().ncols()
    }

    /// Return the rank of the underlying operator
    fn rank(&self) -> usize {
        self.get_u().ncols()
    }

    /// Convert to a matrix.
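    ///
    /// The result is `U * diag(s) * Vt`, formed by scaling the rows of `Vt`
    /// with the singular values and multiplying by `U` from the left.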
    fn to_mat(&self) -> Array2<Self::A> {
        let mut scaled_vt =
            Array2::<Self::A>::zeros((self.get_vt().nrows(), self.get_vt().ncols()));
        scaled_vt.assign(&self.get_vt());

        // Scale each row of Vt by the corresponding singular value.
        Zip::from(scaled_vt.axis_iter_mut(Axis(0)))
            .and(self.get_s().view())
            .for_each(|mut row, &s_elem| {
                row.map_inplace(|item| *item *= <Self::A as Scalar>::from_real(s_elem))
            });

        // A = U * (diag(s) * Vt)
        self.get_u().dot(&scaled_vt)
    }

    /// Convert to a QR Decomposition.
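    ///
    /// The rows of `Vt` are scaled by the singular values, a QR decomposition
    /// of the scaled `Vt` is computed, and `U` is absorbed into its `Q`
    /// factor, so that `Q * R` reproduces `U * diag(s) * Vt`.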
    fn to_qr(self) -> Result<QR<Self::A>>;

    /// Compress the SVD.
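    ///
    /// A minimal sketch of the two compression modes:
    ///
    /// ```ignore
    /// // Keep at most the 10 largest singular values.
    /// let by_rank = svd.compress(CompressionType::RANK(10))?;
    /// // Discard singular values below a relative tolerance of 1e-6.
    /// let by_tol = svd.compress(CompressionType::ADAPTIVE(1e-6))?;
    /// ```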
    fn compress(&self, compression_type: CompressionType) -> Result<SVD<Self::A>> {
        match compression_type {
            CompressionType::ADAPTIVE(tol) => self.compress_svd_tolerance(tol),
            CompressionType::RANK(rank) => self.compress_svd_rank(rank),
        }
    }

    /// Compress the SVD by specifying a target rank.
    ///
    /// If `max_rank` exceeds the number of singular values it is reduced to
    /// the number of singular values.
    fn compress_svd_rank(&self, mut max_rank: usize) -> Result<SVD<Self::A>> {
        let (u, s, vt) = (self.get_u(), self.get_s(), self.get_vt());

        if max_rank > s.len() {
            max_rank = s.len()
        }

        let u = u.slice(s![.., 0..max_rank]);
        let s = s.slice(s![0..max_rank]);
        let vt = vt.slice(s![0..max_rank, ..]);

        Ok(SVD {
            u: u.into_owned(),
            s: s.into_owned(),
            vt: vt.into_owned(),
        })
    }

    /// Compress the SVD by specifying a relative tolerance.
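    ///
    /// Singular values `s[i]` with `s[i] / s[0] < tol` are discarded. If no
    /// singular value falls below the relative tolerance, the compression
    /// fails with a `CompressionError`.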
    fn compress_svd_tolerance(&self, tol: f64) -> Result<SVD<Self::A>> {
        assert!((tol < 1.0) && (0.0 <= tol), "Require 0 <= tol < 1.0");

        let first_val = self.get_s()[0];

        // Index of the first singular value below the relative tolerance.
        let pos = self
            .get_s()
            .iter()
            .position(|&item| (item / first_val).to_f64().unwrap() < tol);

        match pos {
            Some(index) => self.compress_svd_rank(index),
            None => Err(RustyCompressionError::CompressionError),
        }
    }

    /// Compute an SVD from a given matrix.
    fn compute_from(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>>;

    /// Compute a singular value decomposition from a range estimate.
    ///
    /// # Arguments
    /// * `range`: A matrix with orthogonal columns that approximates the range
    ///            of the operator.
    /// * `op`: The underlying operator.
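    ///
    /// With `Q` the range estimate, `B` is formed as the conjugate transpose
    /// of `op.conj_matmat(range)` (i.e. `B = Q^H * A`), its SVD
    /// `B = U_b * diag(s) * Vt` is computed, and `U = Q * U_b` is returned
    /// together with `s` and `Vt` as an approximate SVD of the operator.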
    fn compute_from_range_estimate<Op: ConjMatMat<A = Self::A>>(
        range: ArrayView2<Self::A>,
        op: &Op,
    ) -> Result<SVD<Self::A>>;

    /// Return a view of the matrix U.
    fn get_u(&self) -> ArrayView2<Self::A>;
    /// Return a view of the singular values.
    fn get_s(&self) -> ArrayView1<<Self::A as Scalar>::Real>;
    /// Return a view of the matrix Vt.
    fn get_vt(&self) -> ArrayView2<Self::A>;

    /// Return a mutable view of the matrix U.
    fn get_u_mut(&mut self) -> ArrayViewMut2<Self::A>;
    /// Return a mutable view of the singular values.
    fn get_s_mut(&mut self) -> ArrayViewMut1<<Self::A as Scalar>::Real>;
    /// Return a mutable view of the matrix Vt.
    fn get_vt_mut(&mut self) -> ArrayViewMut2<Self::A>;
}

macro_rules! svd_impl {
    ($scalar:ty) => {
        impl SVDTraits for SVD<$scalar> {
            type A = $scalar;

            fn get_u(&self) -> ArrayView2<Self::A> {
                self.u.view()
            }

            fn get_s(&self) -> ArrayView1<<Self::A as Scalar>::Real> {
                self.s.view()
            }
            fn get_vt(&self) -> ArrayView2<Self::A> {
                self.vt.view()
            }

            fn get_u_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.u.view_mut()
            }
            fn get_s_mut(&mut self) -> ArrayViewMut1<<Self::A as Scalar>::Real> {
                self.s.view_mut()
            }
            fn get_vt_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.vt.view_mut()
            }

            fn to_qr(self) -> Result<QR<Self::A>> {
                let (u, s, mut vt) = (self.u, self.s, self.vt);

                // Scale the rows of Vt by the singular values.
                Zip::from(vt.axis_iter_mut(Axis(0)))
                    .and(s.view())
                    .for_each(|mut row, &s_elem| {
                        row.map_inplace(|item| *item *= <Self::A as Scalar>::from_real(s_elem))
                    });

                // QR factorize the scaled Vt and absorb U into the Q factor.
                let mut qr = QR::<$scalar>::compute_from(vt.view())?;
                qr.q = u.dot(&qr.q);

                Ok(qr)
            }

            fn compute_from(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>> {
                use crate::compute_svd::ComputeSVD;

                <$scalar>::compute_svd(arr)
            }

            fn compute_from_range_estimate<Op: ConjMatMat<A = Self::A>>(
                range: ArrayView2<Self::A>,
                op: &Op,
            ) -> Result<SVD<Self::A>> {
                // B is the conjugate transpose of op.conj_matmat(range),
                // i.e. B = Q^H * A for the range estimate Q.
                let b = op.conj_matmat(range).t().map(|item| item.conj());
                let svd = SVD::<$scalar>::compute_from(b.view())?;

                Ok(SVD {
                    u: range.dot(&svd.u),
                    s: svd.get_s().into_owned(),
                    vt: svd.get_vt().into_owned(),
                })
            }
        }
    };
}

svd_impl!(f32);
svd_impl!(f64);
svd_impl!(c32);
svd_impl!(c64);

#[cfg(test)]
mod tests {

    use super::*;
    use crate::types::RelDiff;
    use crate::random_matrix::RandomMatrix;
    use crate::CompressionType;
    use ndarray::Axis;
    use ndarray_linalg::OperationNorm;

    macro_rules! svd_to_qr_tests {
        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {
            $(
            #[test]
            fn $name() {
                let m = $dim.0;
                let n = $dim.1;

                let mut rng = rand::thread_rng();
                let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), 1.0, 1E-10, &mut rng);

                let svd = SVD::<$scalar>::compute_from(mat.view()).unwrap();

                // Perform a QR decomposition and recover the original matrix.
                let actual = svd.to_qr().unwrap().to_mat();

                assert!(<$scalar>::rel_diff_fro(actual.view(), mat.view()) < $tol);

                assert!(
                    (actual - mat.view()).opnorm_fro().unwrap() / mat.opnorm_fro().unwrap() < $tol
                );
            }
            )*
        };
    }

    macro_rules! svd_compression_by_rank_tests {
        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {
            $(
            #[test]
            fn $name() {
                let m = $dim.0;
                let n = $dim.1;
                let rank: usize = 20;

                let sigma_max = 1.0;
                let sigma_min = 1E-10;
                let mut rng = rand::thread_rng();
                let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                let svd = SVD::<$scalar>::compute_from(mat.view()).unwrap().compress(CompressionType::RANK(rank)).unwrap();

                // Compare with original matrix

                assert!(svd.u.len_of(Axis(1)) == rank);
                assert!(svd.vt.len_of(Axis(0)) == rank);

                assert!(<$scalar>::rel_diff_fro(svd.to_mat().view(), mat.view()) < $tol);
            }
            )*
        };
    }

    macro_rules! svd_compression_by_tol_tests {
        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {
            $(
            #[test]
            fn $name() {
                let m = $dim.0;
                let n = $dim.1;

                let sigma_max = 1.0;
                let sigma_min = 1E-10;
                let mut rng = rand::thread_rng();
                let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                let svd = SVD::<$scalar>::compute_from(mat.view()).unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();

                // Compare with original matrix

                assert!(<$scalar>::rel_diff_fro(svd.to_mat().view(), mat.view()) < $tol);
            }
            )*
        };
    }

    svd_to_qr_tests! {
        test_svd_to_qr_f32_thin: f32, (100, 50), 1E-5,
        test_svd_to_qr_c32_thin: ndarray_linalg::c32, (100, 50), 1E-5,
        test_svd_to_qr_f64_thin: f64, (100, 50), 1E-12,
        test_svd_to_qr_c64_thin: ndarray_linalg::c64, (100, 50), 1E-12,
        test_svd_to_qr_f32_thick: f32, (50, 100), 1E-5,
        test_svd_to_qr_c32_thick: ndarray_linalg::c32, (50, 100), 1E-5,
        test_svd_to_qr_f64_thick: f64, (50, 100), 1E-12,
        test_svd_to_qr_c64_thick: ndarray_linalg::c64, (50, 100), 1E-12,
    }

    svd_compression_by_rank_tests! {
        test_svd_compression_by_rank_f32_thin: f32, (100, 50), 1E-4,
        test_svd_compression_by_rank_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_svd_compression_by_rank_f64_thin: f64, (100, 50), 1E-4,
        test_svd_compression_by_rank_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_svd_compression_by_rank_f32_thick: f32, (50, 100), 1E-4,
        test_svd_compression_by_rank_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_svd_compression_by_rank_f64_thick: f64, (50, 100), 1E-4,
        test_svd_compression_by_rank_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }

    svd_compression_by_tol_tests! {
        test_svd_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_svd_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_svd_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_svd_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_svd_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_svd_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_svd_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_svd_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }
}