rusty_compression/
two_sided_interp_decomp.rs

1//! Implementation of the two sided interpolative decomposition
2//! 
//! The two sided interpolative decomposition of a matrix $A\in\mathbb{C}^{m\times n}$ is
4//! defined as
5//! $$
6//! A \approx CXR,
7//! $$
8//! where $C\in\mathbb{C}^{m\times k}$, $X\in\mathbb{C}^{k\times k}$, and $R\in\mathbb{C}^{k\times n}$.
//! The matrix $X$ contains a subset of the entries of $A$, namely
//! `X[i, j] = A[row_ind[i], col_ind[j]]`, where `row_ind` and `col_ind` are index vectors.
11
12use crate::types::Apply;
13use ndarray::{
14    Array1, Array2, ArrayBase, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Data, Ix1, Ix2,
15};
16use crate::types::{c32, c64, Scalar};
17
18/// Store a two sided interpolative decomposition
19pub struct TwoSidedID<A: Scalar> {
20    /// The C matrix of the two sided interpolative decomposition
21    pub c: Array2<A>,
22    /// The X matrix of the two sided interpolative decomposition
23    pub x: Array2<A>,
24    /// The R matrix of the two sided interpolative decomposition
25    pub r: Array2<A>,
26    /// The row index vector
27    pub row_ind: Array1<usize>,
28    /// The column index vector
29    pub col_ind: Array1<usize>,
30}
31
32/// Traits defining a two sided interpolative decomposition
33/// 
34/// defined as
35/// The two sided interpolative decomposition of a matrix $A\in\mathbb{C}&{m\times n} is
36/// $$
37/// A \approx CXR,
38/// $$
39/// where $C\in\mathbb{C}^{m\times k}$, $X\in\mathbb{C}^{k\times k}$, and $R\in\mathbb{C}^{k\times n}$.
40/// The matrix $X$ contains a subset of the entries of $A$, such that A\[row_ind\[:\], col_ind\[:\]\] = X, where
41/// row_ind and col_ind are index vectors.
42
43pub trait TwoSidedIDTraits {
44    type A: Scalar;
45
46    /// Number of rows of the underlying operator
47    fn nrows(&self) -> usize {
48        self.get_c().nrows()
49    }
50
51    /// Number of columns of the underlying operator
52    fn ncols(&self) -> usize {
53        self.get_r().ncols()
54    }
55
56    /// Rank of the two sided interpolative decomposition
57    fn rank(&self) -> usize {
58        self.get_c().ncols()
59    }
60
61    /// Convert to a matrix
62    fn to_mat(&self) -> Array2<Self::A> {
63        self.get_c().dot(&self.get_x().dot(&self.get_r()))
64    }
65
66    /// Return the C matrix
67    fn get_c(&self) -> ArrayView2<Self::A>;
68
69    /// Return the X matrix
70    fn get_x(&self) -> ArrayView2<Self::A>;
71
72    /// Return the R matrix
73    fn get_r(&self) -> ArrayView2<Self::A>;
74
75    /// Return the column index vector
76    fn get_col_ind(&self) -> ArrayView1<usize>;
77
78    /// Return the row index vector
79    fn get_row_ind(&self) -> ArrayView1<usize>;
80
81    fn get_c_mut(&mut self) -> ArrayViewMut2<Self::A>;
82    fn get_x_mut(&mut self) -> ArrayViewMut2<Self::A>;
83    fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A>;
84    fn get_col_ind_mut(&mut self) -> ArrayViewMut1<usize>;
85    fn get_row_ind_mut(&mut self) -> ArrayViewMut1<usize>;
86
87    /// Return a two sided interpolative decomposition from the component matrices
88    /// X, R, C, and the column and row index vectors
89    fn new(
90        x: Array2<Self::A>,
91        r: Array2<Self::A>,
92        c: Array2<Self::A>,
93        col_ind: Array1<usize>,
94        row_ind: Array1<usize>,
95    ) -> Self;
96}
97
98macro_rules! impl_two_sided_id {
99    ($scalar:ty) => {
100        impl TwoSidedIDTraits for TwoSidedID<$scalar> {
101            type A = $scalar;
102
103            fn get_c(&self) -> ArrayView2<Self::A> {
104                self.c.view()
105            }
106
107            fn get_x(&self) -> ArrayView2<Self::A> {
108                self.x.view()
109            }
110
111            fn get_r(&self) -> ArrayView2<Self::A> {
112                self.r.view()
113            }
114            fn get_col_ind(&self) -> ArrayView1<usize> {
115                self.col_ind.view()
116            }
117            fn get_row_ind(&self) -> ArrayView1<usize> {
118                self.row_ind.view()
119            }
120
121            fn get_c_mut(&mut self) -> ArrayViewMut2<Self::A> {
122                self.c.view_mut()
123            }
124
125            fn get_x_mut(&mut self) -> ArrayViewMut2<Self::A> {
126                self.x.view_mut()
127            }
128
129            fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A> {
130                self.r.view_mut()
131            }
132            fn get_col_ind_mut(&mut self) -> ArrayViewMut1<usize> {
133                self.col_ind.view_mut()
134            }
135            fn get_row_ind_mut(&mut self) -> ArrayViewMut1<usize> {
136                self.row_ind.view_mut()
137            }
138            fn new(
139                x: Array2<Self::A>,
140                r: Array2<Self::A>,
141                c: Array2<Self::A>,
142                col_ind: Array1<usize>,
143                row_ind: Array1<usize>,
144            ) -> Self {
145                TwoSidedID::<$scalar> {
146                    x,
147                    r,
148                    c,
149                    col_ind,
150                    row_ind,
151                }
152            }
153        }
154        impl<S> Apply<$scalar, ArrayBase<S, Ix1>> for TwoSidedID<$scalar>
155        where
156            S: Data<Elem = $scalar>,
157        {
158            type Output = Array1<$scalar>;
159            fn dot(&self, rhs: &ArrayBase<S, Ix1>) -> Self::Output {
160                self.c.dot(&self.x.dot(&self.r.dot(rhs)))
161            }
162        }
163        impl<S> Apply<$scalar, ArrayBase<S, Ix2>> for TwoSidedID<$scalar>
164        where
165            S: Data<Elem = $scalar>,
166        {
167            type Output = Array2<$scalar>;
168            fn dot(&self, rhs: &ArrayBase<S, Ix2>) -> Self::Output {
169                self.c.dot(&self.x.dot(&self.r.dot(rhs)))
170            }
171        }
172    };
173}
174
175impl_two_sided_id!(f32);
176impl_two_sided_id!(f64);
177impl_two_sided_id!(c32);
178impl_two_sided_id!(c64);
179
180// impl<A: ScalarType> TwoSidedIDResult<A> {
181//     pub fn nrows(&self) -> usize {
182//         self.c.nrows()
183//     }
184
185//     pub fn ncols(&self) -> usize {
186//         self.r.ncols()
187//     }
188
189//     pub fn rank(&self) -> usize {
190//         self.x.nrows()
191//     }
192
193//     pub fn to_mat(&self) -> Array2<A> {
194//         self.c.dot(&self.x.dot(&self.r))
195//     }
196
197//     pub fn apply_matrix<S: Data<Elem = A>>(
198//         &self,
199//         other: &ArrayBase<S, Ix2>,
200//     ) -> ArrayBase<OwnedRepr<A>, Ix2> {
201//         self.c.dot(&self.x.dot(&self.r.dot(other)))
202//     }
203
204//     pub fn apply_vector<S: Data<Elem = A>>(
205//         &self,
206//         other: &ArrayBase<S, Ix1>,
207//     ) -> ArrayBase<OwnedRepr<A>, Ix1> {
208//         self.c.dot(&self.x.dot(&self.r.dot(other)))
209//     }
210
211//     //}
212// }
213
214// impl<A: ScalarType> QRContainer<A> {
215//     pub fn two_sided_id(&self) -> Result<TwoSidedIDResult<A>> {
216//         let col_id = self.column_id()?;
217//         let row_id = col_id.c.pivoted_lq()?.row_id()?;
218
219//         Ok(TwoSidedIDResult {
220//             c: row_id.x,
221//             x: row_id.r,
222//             r: col_id.z,
223//             row_ind: row_id.row_ind,
224//             col_ind: col_id.col_ind,
225//         })
226//     }
227// }
228
229// #[cfg(test)]
230// mod tests {
231
232//     use crate::prelude::ApplyPermutationToMatrix;
233//     use crate::prelude::CompressionType;
234//     use crate::prelude::MatrixPermutationMode;
235//     use crate::prelude::PivotedQR;
236//     use crate::prelude::RandomMatrix;
237//     use crate::prelude::RelDiff;
238//     use ndarray_linalg::Scalar;
239
240//     macro_rules! id_compression_tests {
241
242//         ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {
243
244//             $(
245
246//         #[test]
247//         fn $name() {
248//             let m = $dim.0;
249//             let n = $dim.1;
250
251//             let sigma_max = 1.0;
252//             let sigma_min = 1E-10;
253//             let mut rng = rand::thread_rng();
254//             let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);
255
256//             let qr = mat.pivoted_qr().unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();
257//             let rank = qr.rank();
258//             let two_sided_id = qr.two_sided_id().unwrap();
259
260//             // Compare with original matrix
261
262//             assert!(two_sided_id.to_mat().rel_diff(&mat) < 5.0 * $tol);
263
264//             // Now compare the individual columns to make sure that the id basis columns
265//             // agree with the corresponding matrix columns.
266
267//             let mat_permuted = mat.apply_permutation(two_sided_id.row_ind.view(), MatrixPermutationMode::ROW).
268//                 apply_permutation(two_sided_id.col_ind.view(), MatrixPermutationMode::COL);
269
270//             // Assert that the x matrix in the two sided id is squared with correct dimension.
271
272//             assert!(two_sided_id.x.nrows() == two_sided_id.x.ncols());
273//             assert!(two_sided_id.x.nrows() == rank);
274
275//             // Now compare with the original matrix.
276
277//             for row_index in 0..rank {
278//                 for col_index in 0..rank {
279//                     let tmp = (two_sided_id.x[[row_index, col_index]] - mat_permuted[[row_index, col_index]]).abs() / mat_permuted[[row_index, col_index]].abs();
280//                     println!("Rel Error {}", tmp);
281//                     //if tmp >= 5.0 * $tol {
282//                         //println!(" Rel Error {}", tmp);
283//                     //}
284
285//                     assert!((two_sided_id.x[[row_index, col_index]] - mat_permuted[[row_index, col_index]]).abs()
286//                             < 10.0 * $tol * mat_permuted[[row_index, col_index]].abs())
287//                 }
288//             }
289//         }
290
291//             )*
292
293//         }
294//     }
295
296//     id_compression_tests! {
297//         test_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
298//         test_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
299//         test_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
300//         test_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
301//         test_id_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
302//         test_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
303//         test_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
304//         test_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
305//     }
306// }