rusty_compression/
qr.rs

//! Data structures and traits for QR and LQ decompositions
//!
//! The pivoted QR decomposition of a matrix $A\in\mathbb{C}^{m\times n}$ is
//! defined as $AP = QR$, where $P$ is a permutation matrix, $Q\in\mathbb{C}^{m\times k}$
//! is a matrix with orthonormal columns, satisfying $Q^HQ = I$, and $R\in\mathbb{C}^{k\times n}$
//! is an upper triangular matrix with diagonal elements $r_{ii}$ satisfying $|r_{11}|\geq |r_{22}|\geq \dots$.
//! Here $k=\min(m, n)$. The matrix $P$ is defined by an index vector `ind` in such a way that if ind\[j\] = i then
//! the jth column of $P$ has a 1 at the position P\[i, j\] and zeros elsewhere. In other words, the matrix $P$ moves the
//! $i$th column of $A$ to the $j$th column.
//!
//! This module also defines the LQ decomposition $PA = LQ$, where $L$ is a lower triangular matrix. If
//! $A^H\tilde{P}=\tilde{Q}R$ is the QR decomposition as defined above, then $P = \tilde{P}^T$, $L=R^H$ and $Q=\tilde{Q}^H$.
//!
//! Both the QR and the LQ decomposition of a matrix can be compressed further, either by specifying a rank or
//! by specifying a relative tolerance. Let $AP=QR$. We can compress the QR decomposition by only keeping the first
//! $\ell$ columns ($\ell \leq k$) of $Q$ and correspondingly only keeping the first $\ell$ rows of $R$.
//! Alternatively, $\ell$ can be determined from a relative tolerance tol: only the leading rows of $R$
//! whose diagonal entries satisfy $|r_{\ell\ell}| / |r_{11}| \geq tol$ are kept.

20use crate::col_interp_decomp::{ColumnID, ColumnIDTraits};
21use crate::permutation::{ApplyPermutationToMatrix, MatrixPermutationMode};
22use crate::pivoted_qr::PivotedQR;
23use crate::row_interp_decomp::{RowID, RowIDTraits};
24use crate::CompressionType;
25use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
26use ndarray_linalg::{Diag, SolveTriangular, UPLO};
27use num::ToPrimitive;
28use crate::types::{c32, c64, Result, Scalar};
29use crate::types::{ConjMatMat, RustyCompressionError};
30
31pub struct QR<A: Scalar> {
32    /// The Q matrix from the QR Decomposition
33    pub q: Array2<A>,
34    /// The R matrix from the QR Decomposition
35    pub r: Array2<A>,
36    /// An index array. If ind\[j\] = k then the
37    /// jth column of Q * R is identical to the
38    /// kth column of the original matrix A.
39    pub ind: Array1<usize>,
40}
41
42pub struct LQ<A: Scalar> {
43    /// The Q matrix from the LQ Decomposition
44    pub l: Array2<A>,
45    /// The Q matrix from the LQ Decomposition
46    pub q: Array2<A>,
47    /// An index array. If ind\[j\] = k then the
48    /// jth row of L * Q is identical to the
49    /// kth row of the original matrix A.
50    pub ind: Array1<usize>,
51}
52
53/// Traits for the LQ Decomposition
54pub trait LQTraits {
55    type A: Scalar;
56
57    /// Number of rows
58    fn nrows(&self) -> usize {
59        self.get_l().nrows()
60    }
61
62    /// Number of columns
63    fn ncols(&self) -> usize {
64        self.get_q().ncols()
65    }
66
67    /// Rank of the LQ decomposition
68    fn rank(&self) -> usize {
69        self.get_q().nrows()
70    }
71
72    /// Convert the LQ decomposition to a matrix
73    fn to_mat(&self) -> Array2<Self::A> {
74        self.get_l()
75            .apply_permutation(self.get_ind(), MatrixPermutationMode::ROWINV)
76            .dot(&self.get_q())
77    }
78
79    /// Compress by giving a target rank
80    fn compress_lq_rank(&self, mut max_rank: usize) -> Result<LQ<Self::A>> {
81        let (l, q, ind) = (self.get_l(), self.get_q(), self.get_ind());
82
83        if max_rank > q.nrows() {
84            max_rank = q.nrows()
85        }
86
87        let q = q.slice(s![0..max_rank, ..]);
88        let l = l.slice(s![.., 0..max_rank]);
89
90        Ok(LQ {
91            l: l.into_owned(),
92            q: q.into_owned(),
93            ind: ind.into_owned(),
94        })
95    }
96
97    /// Compress by specifying a relative tolerance
98    fn compress_lq_tolerance(&self, tol: f64) -> Result<LQ<Self::A>> {
99        assert!((tol < 1.0) && (0.0 <= tol), "Require 0 <= tol < 1.0");
100
101        let pos = self
102            .get_l()
103            .diag()
104            .iter()
105            .position(|&item| ((item / self.get_l()[[0, 0]]).abs()).to_f64().unwrap() < tol);
106
107        match pos {
108            Some(index) => self.compress_lq_rank(index),
109            None => Err(RustyCompressionError::CompressionError),
110        }
111    }
112
113    /// Compress the LQ Decomposition by rank or tolerance
114    fn compress(&self, compression_type: CompressionType) -> Result<LQ<Self::A>> {
115        match compression_type {
116            CompressionType::ADAPTIVE(tol) => self.compress_lq_tolerance(tol),
117            CompressionType::RANK(rank) => self.compress_lq_rank(rank),
118        }
119    }
120
121    /// Return the Q matrix
122    fn get_q(&self) -> ArrayView2<Self::A>;
123
124    /// Return the L matrix
125    fn get_l(&self) -> ArrayView2<Self::A>;
126
127    /// Return the index vector
128    fn get_ind(&self) -> ArrayView1<usize>;
129
130    fn get_q_mut(&mut self) -> ArrayViewMut2<Self::A>;
131    fn get_l_mut(&mut self) -> ArrayViewMut2<Self::A>;
132    fn get_ind_mut(&mut self) -> ArrayViewMut1<usize>;
133
134    /// Compute the LQ decomposition from a given array
135    fn compute_from(arr: ArrayView2<Self::A>) -> Result<LQ<Self::A>>;
136
137    /// Compute a row interpolative decomposition from the LQ decomposition
138    fn row_id(&self) -> Result<RowID<Self::A>>;
139}
140
141pub trait QRTraits {
142    type A: Scalar;
143
144    /// Number of rows
145    fn nrows(&self) -> usize {
146        self.get_q().nrows()
147    }
148
149    /// Number of columns
150    fn ncols(&self) -> usize {
151        self.get_r().ncols()
152    }
153
154    /// Rank of the QR Decomposition
155    fn rank(&self) -> usize {
156        self.get_q().ncols()
157    }
158
159    /// Convert the QR decomposition to a matrix
160    fn to_mat(&self) -> Array2<Self::A> {
161        self.get_q().dot(
162            &self
163                .get_r()
164                .apply_permutation(self.get_ind(), MatrixPermutationMode::COLINV),
165        )
166    }
167
168    /// Compress by giving a target rank
169    fn compress_qr_rank(&self, mut max_rank: usize) -> Result<QR<Self::A>> {
170        let (q, r, ind) = (self.get_q(), self.get_r(), self.get_ind());
171
172        if max_rank > q.ncols() {
173            max_rank = q.ncols()
174        }
175
176        let q = q.slice(s![.., 0..max_rank]);
177        let r = r.slice(s![0..max_rank, ..]);
178
179        Ok(QR {
180            q: q.into_owned(),
181            r: r.into_owned(),
182            ind: ind.into_owned(),
183        })
184    }
185
186    /// Compress by specifying a relative tolerance
187    fn compress_qr_tolerance(&self, tol: f64) -> Result<QR<Self::A>> {
188        assert!((tol < 1.0) && (0.0 <= tol), "Require 0 <= tol < 1.0");
189
190        let pos = self
191            .get_r()
192            .diag()
193            .iter()
194            .position(|&item| ((item / self.get_r()[[0, 0]]).abs()).to_f64().unwrap() < tol);
195
196        match pos {
197            Some(index) => self.compress_qr_rank(index),
198            None => Err(RustyCompressionError::CompressionError),
199        }
200    }
201
202    /// Compress the QR decomposition by rank or tolerance
203    fn compress(&self, compression_type: CompressionType) -> Result<QR<Self::A>> {
204        match compression_type {
205            CompressionType::ADAPTIVE(tol) => self.compress_qr_tolerance(tol),
206            CompressionType::RANK(rank) => self.compress_qr_rank(rank),
207        }
208    }
209
210    /// Compute a column interpolative decomposition from the QR decomposition
211    fn column_id(&self) -> Result<ColumnID<Self::A>>;
212
213    /// Compute the QR decomposition from a given array
214    fn compute_from(arr: ArrayView2<Self::A>) -> Result<QR<Self::A>>;
215
216    /// Compute a QR decomposition from a range estimate
217    /// # Arguments
218    /// * `range`: A matrix with orthogonal columns that approximates the range
219    ///            of the operator.
220    /// * `op`: The underlying operator.
221    fn compute_from_range_estimate<Op: ConjMatMat<A = Self::A>>(
222        range: ArrayView2<Self::A>,
223        op: &Op,
224    ) -> Result<QR<Self::A>>;
225
226    /// Return the Q matrix
227    fn get_q(&self) -> ArrayView2<Self::A>;
228
229    /// Return the R matrix
230    fn get_r(&self) -> ArrayView2<Self::A>;
231
232    /// Return the index vector
233    fn get_ind(&self) -> ArrayView1<usize>;
234
235    fn get_q_mut(&mut self) -> ArrayViewMut2<Self::A>;
236    fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A>;
237    fn get_ind_mut(&mut self) -> ArrayViewMut1<usize>;
238}
239
/// Implements `QRTraits` for `QR<$scalar>`.
macro_rules! qr_data_impl {
    ($scalar:ty) => {
        impl QRTraits for QR<$scalar> {
            type A = $scalar;

            fn get_q(&self) -> ArrayView2<Self::A> {
                self.q.view()
            }

            fn get_r(&self) -> ArrayView2<Self::A> {
                self.r.view()
            }

            fn compute_from(arr: ArrayView2<Self::A>) -> Result<QR<Self::A>> {
                <$scalar>::pivoted_qr(arr)
            }

            fn get_ind(&self) -> ArrayView1<usize> {
                self.ind.view()
            }

            fn get_q_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.q.view_mut()
            }

            fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.r.view_mut()
            }

            fn get_ind_mut(&mut self) -> ArrayViewMut1<usize> {
                self.ind.view_mut()
            }

            // Builds the column ID A = C Z from A P = Q R.
            // Returns `CompressionError` if a triangular solve fails.
            fn column_id(&self) -> Result<ColumnID<Self::A>> {
                let rank = self.rank();
                let nrcols = self.ncols();
                let q = self.get_q();
                let r = self.get_r();

                if rank == nrcols {
                    // Matrix not rank deficient: C = Q R and Z is the identity
                    // with the column permutation undone.
                    Ok(ColumnID::<Self::A>::new(
                        q.dot(&r),
                        Array2::<Self::A>::eye(rank)
                            .apply_permutation(self.get_ind(), MatrixPermutationMode::COLINV),
                        self.get_ind().into_owned(),
                    ))
                } else {
                    // Matrix is rank deficient. Split R = [R1 R2] with the
                    // upper triangular R1 of size rank x rank. Then
                    // A P = (Q R1) [I, R1^{-1} R2], so C = Q R1 and
                    // Z = [I, R1^{-1} R2] (before undoing the permutation).
                    let first_part = r.slice(s![.., 0..rank]).to_owned();
                    let c = q.dot(&first_part);

                    let mut z = Array2::<Self::A>::zeros((rank, nrcols));
                    z.slice_mut(s![.., 0..rank]).diag_mut().fill(num::one());

                    for (index, col) in r.slice(s![.., rank..nrcols]).axis_iter(Axis(1)).enumerate() {
                        // A singular R1 is reported as an error instead of panicking.
                        let solution = first_part
                            .solve_triangular(UPLO::Upper, Diag::NonUnit, &col.to_owned())
                            .map_err(|_| RustyCompressionError::CompressionError)?;
                        z.index_axis_mut(Axis(1), rank + index).assign(&solution);
                    }

                    Ok(ColumnID::<Self::A>::new(
                        c,
                        z.apply_permutation(self.get_ind(), MatrixPermutationMode::COLINV),
                        self.get_ind().into_owned(),
                    ))
                }
            }

            // With B = (op^H range)^H = range^H A and B P = Q_b R, we get
            // A P ~ range Q_b R, so the Q factor is range * Q_b.
            fn compute_from_range_estimate<Op: ConjMatMat<A = Self::A>>(
                range: ArrayView2<Self::A>,
                op: &Op,
            ) -> Result<QR<Self::A>> {
                let b = op.conj_matmat(range).t().map(|item| item.conj());
                let qr = QR::<$scalar>::compute_from(b.view())?;

                Ok(QR {
                    q: range.dot(&qr.get_q()),
                    r: qr.get_r().into_owned(),
                    ind: qr.get_ind().into_owned(),
                })
            }
        }
    };
}
327
/// Implements `LQTraits` for `LQ<$scalar>`.
macro_rules! lq_data_impl {
    ($scalar:ty) => {
        impl LQTraits for LQ<$scalar> {
            type A = $scalar;

            fn get_q(&self) -> ArrayView2<Self::A> {
                self.q.view()
            }

            fn get_l(&self) -> ArrayView2<Self::A> {
                self.l.view()
            }

            fn get_ind(&self) -> ArrayView1<usize> {
                self.ind.view()
            }

            fn get_q_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.q.view_mut()
            }

            fn get_l_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.l.view_mut()
            }

            fn get_ind_mut(&mut self) -> ArrayViewMut1<usize> {
                self.ind.view_mut()
            }

            // Computes the pivoted QR decomposition of A^H and conjugate-transposes
            // its factors: L = R^H, Q = Q_qr^H, with the same index vector.
            fn compute_from(arr: ArrayView2<Self::A>) -> Result<LQ<Self::A>> {
                let arr_trans = arr.t().map(|val| val.conj());
                let qr = QR::<$scalar>::compute_from(arr_trans.view())?;
                Ok(LQ {
                    l: qr.r.t().map(|item| item.conj()),
                    q: qr.q.t().map(|item| item.conj()),
                    ind: qr.ind,
                })
            }

            // Builds the row ID A = X R from P A = L Q.
            // Returns `CompressionError` if a triangular solve fails.
            fn row_id(&self) -> Result<RowID<Self::A>> {
                let rank = self.rank();
                let nrows = self.nrows();

                if rank == nrows {
                    // Matrix not rank deficient: R = L Q and X is the identity
                    // with the row permutation undone.
                    Ok(RowID::<Self::A>::new(
                        Array2::<Self::A>::eye(rank)
                            .apply_permutation(self.ind.view(), MatrixPermutationMode::ROWINV),
                        self.l.dot(&self.q),
                        self.ind.clone(),
                    ))
                } else {
                    // Matrix is rank deficient. Split L = [L1; L2] with the lower
                    // triangular L1 of size rank x rank. Each remaining row l_i of
                    // L2 satisfies l_i = x_i^T L1, i.e. L1^T x_i = l_i, which is an
                    // upper triangular system.
                    let mut x = Array2::<Self::A>::zeros((nrows, rank));
                    x.slice_mut(s![0..rank, ..]).diag_mut().fill(num::one());
                    let first_part = self.l.slice(s![0..rank, ..]).to_owned();
                    let r = first_part.dot(&self.q);
                    let first_part_transposed = first_part.t().to_owned();

                    for (index, row) in self.l.slice(s![rank..nrows, ..]).axis_iter(Axis(0)).enumerate() {
                        // A singular L1 is reported as an error instead of panicking.
                        let solution = first_part_transposed
                            .solve_triangular(UPLO::Upper, Diag::NonUnit, &row.to_owned())
                            .map_err(|_| RustyCompressionError::CompressionError)?;
                        x.index_axis_mut(Axis(0), rank + index).assign(&solution);
                    }

                    Ok(RowID::<Self::A>::new(
                        x.apply_permutation(self.ind.view(), MatrixPermutationMode::ROWINV),
                        r,
                        self.ind.clone(),
                    ))
                }
            }
        }
    };
}
407
408qr_data_impl!(f32);
409qr_data_impl!(f64);
410qr_data_impl!(c32);
411qr_data_impl!(c64);
412
413lq_data_impl!(f32);
414lq_data_impl!(f64);
415lq_data_impl!(c32);
416lq_data_impl!(c64);
417
#[cfg(test)]
mod tests {

    use super::*;
    use crate::types::RelDiff;
    use crate::pivoted_qr::PivotedQR;
    use crate::random_matrix::RandomMatrix;
    use ndarray::Axis;

    // Generates tests that compress a pivoted QR decomposition to a fixed rank
    // and compare the result against the original matrix.
    macro_rules! qr_compression_by_rank_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let m = $dim.0;
                    let n = $dim.1;
                    let rank: usize = 30;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                    let qr = <$scalar>::pivoted_qr(mat.view()).unwrap().compress(CompressionType::RANK(rank)).unwrap();

                    // The compressed factors must have exactly the requested rank
                    // and still approximate the original matrix.

                    assert!(qr.q.len_of(Axis(1)) == rank);
                    assert!(qr.r.len_of(Axis(0)) == rank);
                    assert!(<$scalar>::rel_diff_fro(qr.to_mat().view(), mat.view()) < $tol);
                }

            )*

        }
    }

    // Generates tests that compress a pivoted QR decomposition with a relative
    // tolerance and compare the result against the original matrix.
    macro_rules! qr_compression_by_tol_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let m = $dim.0;
                    let n = $dim.1;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                    let qr = <$scalar>::pivoted_qr(mat.view()).unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();

                    // Compare with original matrix

                    assert!(<$scalar>::rel_diff_fro(qr.to_mat().view(), mat.view()) < 5.0 * $tol);

                    // Make sure new rank is smaller than original rank

                    assert!(qr.q.ncols() < m.min(n));
                }

            )*

        }
    }

    // Generates tests for the column interpolative decomposition obtained from a
    // tolerance-compressed QR decomposition.
    macro_rules! col_id_compression_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let m = $dim.0;
                    let n = $dim.1;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                    let qr = QR::<$scalar>::compute_from(mat.view()).unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();
                    let rank = qr.rank();
                    let column_id = qr.column_id().unwrap();

                    // Compare with original matrix

                    assert!(<$scalar>::rel_diff_fro(column_id.to_mat().view(), mat.view()) < 5.0 * $tol);

                    // Now compare the individual columns to make sure that the id basis columns
                    // agree with the corresponding matrix columns.

                    let mat_permuted = mat.apply_permutation(column_id.get_col_ind(), MatrixPermutationMode::COL);

                    for index in 0..rank {
                        assert!(
                            <$scalar>::rel_diff_l2(mat_permuted.index_axis(Axis(1), index), column_id.get_c().index_axis(Axis(1), index)) < $tol);
                    }
                }

            )*

        }
    }

    // Generates tests for the row interpolative decomposition obtained from a
    // tolerance-compressed LQ decomposition.
    macro_rules! row_id_compression_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let m = $dim.0;
                    let n = $dim.1;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                    let lq = LQ::<$scalar>::compute_from(mat.view()).unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();
                    let rank = lq.rank();
                    let row_id = lq.row_id().unwrap();

                    // Compare with original matrix

                    assert!(<$scalar>::rel_diff_fro(row_id.to_mat().view(), mat.view()) < 5.0 * $tol);

                    // Now compare the individual rows to make sure that the id basis rows
                    // agree with the corresponding matrix rows.

                    let mat_permuted = mat.apply_permutation(row_id.get_row_ind(), MatrixPermutationMode::ROW);

                    for index in 0..rank {
                        assert!(<$scalar>::rel_diff_l2(mat_permuted.index_axis(Axis(0), index), row_id.get_r().index_axis(Axis(0), index)) < $tol);
                    }
                }

            )*

        }
    }

    row_id_compression_tests! {
        test_row_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_row_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_row_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_row_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_row_id_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_row_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_row_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_row_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }

    col_id_compression_tests! {
        test_col_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_col_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_col_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_col_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_col_id_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_col_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_col_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_col_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }

    qr_compression_by_rank_tests! {
        test_qr_compression_by_rank_f32_thin: f32, (100, 50), 1E-4,
        test_qr_compression_by_rank_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_qr_compression_by_rank_f64_thin: f64, (100, 50), 1E-4,
        test_qr_compression_by_rank_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_qr_compression_by_rank_f32_thick: f32, (50, 100), 1E-4,
        test_qr_compression_by_rank_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_qr_compression_by_rank_f64_thick: f64, (50, 100), 1E-4,
        test_qr_compression_by_rank_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }

    qr_compression_by_tol_tests! {
        test_qr_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_qr_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_qr_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_qr_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_qr_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_qr_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_qr_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_qr_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }
}