rusty_compression/two_sided_interp_decomp.rs
1//! Implementation of the two sided interpolative decomposition
2//!
//! The two sided interpolative decomposition of a matrix $A\in\mathbb{C}^{m\times n}$ is
4//! defined as
5//! $$
6//! A \approx CXR,
7//! $$
8//! where $C\in\mathbb{C}^{m\times k}$, $X\in\mathbb{C}^{k\times k}$, and $R\in\mathbb{C}^{k\times n}$.
9//! The matrix $X$ contains a subset of the entries of $A$, such that A\[row_ind\[:\], col_ind\[:\]\] = X, where
10//! row_ind and col_ind are index vectors.
11
12use crate::types::Apply;
13use ndarray::{
14 Array1, Array2, ArrayBase, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Data, Ix1, Ix2,
15};
16use crate::types::{c32, c64, Scalar};
17
18/// Store a two sided interpolative decomposition
19pub struct TwoSidedID<A: Scalar> {
20 /// The C matrix of the two sided interpolative decomposition
21 pub c: Array2<A>,
22 /// The X matrix of the two sided interpolative decomposition
23 pub x: Array2<A>,
24 /// The R matrix of the two sided interpolative decomposition
25 pub r: Array2<A>,
26 /// The row index vector
27 pub row_ind: Array1<usize>,
28 /// The column index vector
29 pub col_ind: Array1<usize>,
30}
31
32/// Traits defining a two sided interpolative decomposition
33///
34/// defined as
35/// The two sided interpolative decomposition of a matrix $A\in\mathbb{C}&{m\times n} is
36/// $$
37/// A \approx CXR,
38/// $$
39/// where $C\in\mathbb{C}^{m\times k}$, $X\in\mathbb{C}^{k\times k}$, and $R\in\mathbb{C}^{k\times n}$.
40/// The matrix $X$ contains a subset of the entries of $A$, such that A\[row_ind\[:\], col_ind\[:\]\] = X, where
41/// row_ind and col_ind are index vectors.
42
43pub trait TwoSidedIDTraits {
44 type A: Scalar;
45
46 /// Number of rows of the underlying operator
47 fn nrows(&self) -> usize {
48 self.get_c().nrows()
49 }
50
51 /// Number of columns of the underlying operator
52 fn ncols(&self) -> usize {
53 self.get_r().ncols()
54 }
55
56 /// Rank of the two sided interpolative decomposition
57 fn rank(&self) -> usize {
58 self.get_c().ncols()
59 }
60
61 /// Convert to a matrix
62 fn to_mat(&self) -> Array2<Self::A> {
63 self.get_c().dot(&self.get_x().dot(&self.get_r()))
64 }
65
66 /// Return the C matrix
67 fn get_c(&self) -> ArrayView2<Self::A>;
68
69 /// Return the X matrix
70 fn get_x(&self) -> ArrayView2<Self::A>;
71
72 /// Return the R matrix
73 fn get_r(&self) -> ArrayView2<Self::A>;
74
75 /// Return the column index vector
76 fn get_col_ind(&self) -> ArrayView1<usize>;
77
78 /// Return the row index vector
79 fn get_row_ind(&self) -> ArrayView1<usize>;
80
81 fn get_c_mut(&mut self) -> ArrayViewMut2<Self::A>;
82 fn get_x_mut(&mut self) -> ArrayViewMut2<Self::A>;
83 fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A>;
84 fn get_col_ind_mut(&mut self) -> ArrayViewMut1<usize>;
85 fn get_row_ind_mut(&mut self) -> ArrayViewMut1<usize>;
86
87 /// Return a two sided interpolative decomposition from the component matrices
88 /// X, R, C, and the column and row index vectors
89 fn new(
90 x: Array2<Self::A>,
91 r: Array2<Self::A>,
92 c: Array2<Self::A>,
93 col_ind: Array1<usize>,
94 row_ind: Array1<usize>,
95 ) -> Self;
96}
97
98macro_rules! impl_two_sided_id {
99 ($scalar:ty) => {
100 impl TwoSidedIDTraits for TwoSidedID<$scalar> {
101 type A = $scalar;
102
103 fn get_c(&self) -> ArrayView2<Self::A> {
104 self.c.view()
105 }
106
107 fn get_x(&self) -> ArrayView2<Self::A> {
108 self.x.view()
109 }
110
111 fn get_r(&self) -> ArrayView2<Self::A> {
112 self.r.view()
113 }
114 fn get_col_ind(&self) -> ArrayView1<usize> {
115 self.col_ind.view()
116 }
117 fn get_row_ind(&self) -> ArrayView1<usize> {
118 self.row_ind.view()
119 }
120
121 fn get_c_mut(&mut self) -> ArrayViewMut2<Self::A> {
122 self.c.view_mut()
123 }
124
125 fn get_x_mut(&mut self) -> ArrayViewMut2<Self::A> {
126 self.x.view_mut()
127 }
128
129 fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A> {
130 self.r.view_mut()
131 }
132 fn get_col_ind_mut(&mut self) -> ArrayViewMut1<usize> {
133 self.col_ind.view_mut()
134 }
135 fn get_row_ind_mut(&mut self) -> ArrayViewMut1<usize> {
136 self.row_ind.view_mut()
137 }
138 fn new(
139 x: Array2<Self::A>,
140 r: Array2<Self::A>,
141 c: Array2<Self::A>,
142 col_ind: Array1<usize>,
143 row_ind: Array1<usize>,
144 ) -> Self {
145 TwoSidedID::<$scalar> {
146 x,
147 r,
148 c,
149 col_ind,
150 row_ind,
151 }
152 }
153 }
154 impl<S> Apply<$scalar, ArrayBase<S, Ix1>> for TwoSidedID<$scalar>
155 where
156 S: Data<Elem = $scalar>,
157 {
158 type Output = Array1<$scalar>;
159 fn dot(&self, rhs: &ArrayBase<S, Ix1>) -> Self::Output {
160 self.c.dot(&self.x.dot(&self.r.dot(rhs)))
161 }
162 }
163 impl<S> Apply<$scalar, ArrayBase<S, Ix2>> for TwoSidedID<$scalar>
164 where
165 S: Data<Elem = $scalar>,
166 {
167 type Output = Array2<$scalar>;
168 fn dot(&self, rhs: &ArrayBase<S, Ix2>) -> Self::Output {
169 self.c.dot(&self.x.dot(&self.r.dot(rhs)))
170 }
171 }
172 };
173}
174
175impl_two_sided_id!(f32);
176impl_two_sided_id!(f64);
177impl_two_sided_id!(c32);
178impl_two_sided_id!(c64);
179
180// impl<A: ScalarType> TwoSidedIDResult<A> {
181// pub fn nrows(&self) -> usize {
182// self.c.nrows()
183// }
184
185// pub fn ncols(&self) -> usize {
186// self.r.ncols()
187// }
188
189// pub fn rank(&self) -> usize {
190// self.x.nrows()
191// }
192
193// pub fn to_mat(&self) -> Array2<A> {
194// self.c.dot(&self.x.dot(&self.r))
195// }
196
197// pub fn apply_matrix<S: Data<Elem = A>>(
198// &self,
199// other: &ArrayBase<S, Ix2>,
200// ) -> ArrayBase<OwnedRepr<A>, Ix2> {
201// self.c.dot(&self.x.dot(&self.r.dot(other)))
202// }
203
204// pub fn apply_vector<S: Data<Elem = A>>(
205// &self,
206// other: &ArrayBase<S, Ix1>,
207// ) -> ArrayBase<OwnedRepr<A>, Ix1> {
208// self.c.dot(&self.x.dot(&self.r.dot(other)))
209// }
210
211// //}
212// }
213
214// impl<A: ScalarType> QRContainer<A> {
215// pub fn two_sided_id(&self) -> Result<TwoSidedIDResult<A>> {
216// let col_id = self.column_id()?;
217// let row_id = col_id.c.pivoted_lq()?.row_id()?;
218
219// Ok(TwoSidedIDResult {
220// c: row_id.x,
221// x: row_id.r,
222// r: col_id.z,
223// row_ind: row_id.row_ind,
224// col_ind: col_id.col_ind,
225// })
226// }
227// }
228
229// #[cfg(test)]
230// mod tests {
231
232// use crate::prelude::ApplyPermutationToMatrix;
233// use crate::prelude::CompressionType;
234// use crate::prelude::MatrixPermutationMode;
235// use crate::prelude::PivotedQR;
236// use crate::prelude::RandomMatrix;
237// use crate::prelude::RelDiff;
238// use ndarray_linalg::Scalar;
239
240// macro_rules! id_compression_tests {
241
242// ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {
243
244// $(
245
246// #[test]
247// fn $name() {
248// let m = $dim.0;
249// let n = $dim.1;
250
251// let sigma_max = 1.0;
252// let sigma_min = 1E-10;
253// let mut rng = rand::thread_rng();
254// let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);
255
256// let qr = mat.pivoted_qr().unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();
257// let rank = qr.rank();
258// let two_sided_id = qr.two_sided_id().unwrap();
259
260// // Compare with original matrix
261
262// assert!(two_sided_id.to_mat().rel_diff(&mat) < 5.0 * $tol);
263
264// // Now compare the individual columns to make sure that the id basis columns
265// // agree with the corresponding matrix columns.
266
267// let mat_permuted = mat.apply_permutation(two_sided_id.row_ind.view(), MatrixPermutationMode::ROW).
268// apply_permutation(two_sided_id.col_ind.view(), MatrixPermutationMode::COL);
269
270// // Assert that the x matrix in the two sided id is squared with correct dimension.
271
272// assert!(two_sided_id.x.nrows() == two_sided_id.x.ncols());
273// assert!(two_sided_id.x.nrows() == rank);
274
275// // Now compare with the original matrix.
276
277// for row_index in 0..rank {
278// for col_index in 0..rank {
279// let tmp = (two_sided_id.x[[row_index, col_index]] - mat_permuted[[row_index, col_index]]).abs() / mat_permuted[[row_index, col_index]].abs();
280// println!("Rel Error {}", tmp);
281// //if tmp >= 5.0 * $tol {
282// //println!(" Rel Error {}", tmp);
283// //}
284
285// assert!((two_sided_id.x[[row_index, col_index]] - mat_permuted[[row_index, col_index]]).abs()
286// < 10.0 * $tol * mat_permuted[[row_index, col_index]].abs())
287// }
288// }
289// }
290
291// )*
292
293// }
294// }
295
296// id_compression_tests! {
297// test_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
298// test_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
299// test_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
300// test_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
301// test_id_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
302// test_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
303// test_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
304// test_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
305// }
306// }