1use crate::col_interp_decomp::{ColumnID, ColumnIDTraits};
21use crate::permutation::{ApplyPermutationToMatrix, MatrixPermutationMode};
22use crate::pivoted_qr::PivotedQR;
23use crate::row_interp_decomp::{RowID, RowIDTraits};
24use crate::CompressionType;
25use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
26use ndarray_linalg::{Diag, SolveTriangular, UPLO};
27use num::ToPrimitive;
28use crate::types::{c32, c64, Result, Scalar};
29use crate::types::{ConjMatMat, RustyCompressionError};
30
31pub struct QR<A: Scalar> {
32 pub q: Array2<A>,
34 pub r: Array2<A>,
36 pub ind: Array1<usize>,
40}
41
42pub struct LQ<A: Scalar> {
43 pub l: Array2<A>,
45 pub q: Array2<A>,
47 pub ind: Array1<usize>,
51}
52
53pub trait LQTraits {
55 type A: Scalar;
56
57 fn nrows(&self) -> usize {
59 self.get_l().nrows()
60 }
61
62 fn ncols(&self) -> usize {
64 self.get_q().ncols()
65 }
66
67 fn rank(&self) -> usize {
69 self.get_q().nrows()
70 }
71
72 fn to_mat(&self) -> Array2<Self::A> {
74 self.get_l()
75 .apply_permutation(self.get_ind(), MatrixPermutationMode::ROWINV)
76 .dot(&self.get_q())
77 }
78
79 fn compress_lq_rank(&self, mut max_rank: usize) -> Result<LQ<Self::A>> {
81 let (l, q, ind) = (self.get_l(), self.get_q(), self.get_ind());
82
83 if max_rank > q.nrows() {
84 max_rank = q.nrows()
85 }
86
87 let q = q.slice(s![0..max_rank, ..]);
88 let l = l.slice(s![.., 0..max_rank]);
89
90 Ok(LQ {
91 l: l.into_owned(),
92 q: q.into_owned(),
93 ind: ind.into_owned(),
94 })
95 }
96
97 fn compress_lq_tolerance(&self, tol: f64) -> Result<LQ<Self::A>> {
99 assert!((tol < 1.0) && (0.0 <= tol), "Require 0 <= tol < 1.0");
100
101 let pos = self
102 .get_l()
103 .diag()
104 .iter()
105 .position(|&item| ((item / self.get_l()[[0, 0]]).abs()).to_f64().unwrap() < tol);
106
107 match pos {
108 Some(index) => self.compress_lq_rank(index),
109 None => Err(RustyCompressionError::CompressionError),
110 }
111 }
112
113 fn compress(&self, compression_type: CompressionType) -> Result<LQ<Self::A>> {
115 match compression_type {
116 CompressionType::ADAPTIVE(tol) => self.compress_lq_tolerance(tol),
117 CompressionType::RANK(rank) => self.compress_lq_rank(rank),
118 }
119 }
120
121 fn get_q(&self) -> ArrayView2<Self::A>;
123
124 fn get_l(&self) -> ArrayView2<Self::A>;
126
127 fn get_ind(&self) -> ArrayView1<usize>;
129
130 fn get_q_mut(&mut self) -> ArrayViewMut2<Self::A>;
131 fn get_l_mut(&mut self) -> ArrayViewMut2<Self::A>;
132 fn get_ind_mut(&mut self) -> ArrayViewMut1<usize>;
133
134 fn compute_from(arr: ArrayView2<Self::A>) -> Result<LQ<Self::A>>;
136
137 fn row_id(&self) -> Result<RowID<Self::A>>;
139}
140
141pub trait QRTraits {
142 type A: Scalar;
143
144 fn nrows(&self) -> usize {
146 self.get_q().nrows()
147 }
148
149 fn ncols(&self) -> usize {
151 self.get_r().ncols()
152 }
153
154 fn rank(&self) -> usize {
156 self.get_q().ncols()
157 }
158
159 fn to_mat(&self) -> Array2<Self::A> {
161 self.get_q().dot(
162 &self
163 .get_r()
164 .apply_permutation(self.get_ind(), MatrixPermutationMode::COLINV),
165 )
166 }
167
168 fn compress_qr_rank(&self, mut max_rank: usize) -> Result<QR<Self::A>> {
170 let (q, r, ind) = (self.get_q(), self.get_r(), self.get_ind());
171
172 if max_rank > q.ncols() {
173 max_rank = q.ncols()
174 }
175
176 let q = q.slice(s![.., 0..max_rank]);
177 let r = r.slice(s![0..max_rank, ..]);
178
179 Ok(QR {
180 q: q.into_owned(),
181 r: r.into_owned(),
182 ind: ind.into_owned(),
183 })
184 }
185
186 fn compress_qr_tolerance(&self, tol: f64) -> Result<QR<Self::A>> {
188 assert!((tol < 1.0) && (0.0 <= tol), "Require 0 <= tol < 1.0");
189
190 let pos = self
191 .get_r()
192 .diag()
193 .iter()
194 .position(|&item| ((item / self.get_r()[[0, 0]]).abs()).to_f64().unwrap() < tol);
195
196 match pos {
197 Some(index) => self.compress_qr_rank(index),
198 None => Err(RustyCompressionError::CompressionError),
199 }
200 }
201
202 fn compress(&self, compression_type: CompressionType) -> Result<QR<Self::A>> {
204 match compression_type {
205 CompressionType::ADAPTIVE(tol) => self.compress_qr_tolerance(tol),
206 CompressionType::RANK(rank) => self.compress_qr_rank(rank),
207 }
208 }
209
210 fn column_id(&self) -> Result<ColumnID<Self::A>>;
212
213 fn compute_from(arr: ArrayView2<Self::A>) -> Result<QR<Self::A>>;
215
216 fn compute_from_range_estimate<Op: ConjMatMat<A = Self::A>>(
222 range: ArrayView2<Self::A>,
223 op: &Op,
224 ) -> Result<QR<Self::A>>;
225
226 fn get_q(&self) -> ArrayView2<Self::A>;
228
229 fn get_r(&self) -> ArrayView2<Self::A>;
231
232 fn get_ind(&self) -> ArrayView1<usize>;
234
235 fn get_q_mut(&mut self) -> ArrayViewMut2<Self::A>;
236 fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A>;
237 fn get_ind_mut(&mut self) -> ArrayViewMut1<usize>;
238}
239
/// Implement [`QRTraits`] for `QR<$scalar>`.
macro_rules! qr_data_impl {
    ($scalar:ty) => {
        impl QRTraits for QR<$scalar> {
            type A = $scalar;

            fn get_q(&self) -> ArrayView2<Self::A> {
                self.q.view()
            }

            fn get_r(&self) -> ArrayView2<Self::A> {
                self.r.view()
            }

            fn compute_from(arr: ArrayView2<Self::A>) -> Result<QR<Self::A>> {
                <$scalar>::pivoted_qr(arr)
            }

            fn get_ind(&self) -> ArrayView1<usize> {
                self.ind.view()
            }

            fn get_q_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.q.view_mut()
            }

            fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.r.view_mut()
            }

            fn get_ind_mut(&mut self) -> ArrayViewMut1<usize> {
                self.ind.view_mut()
            }

            /// Build a column ID from the (possibly truncated) pivoted QR.
            ///
            /// The leading `rank × rank` block of `R` is treated as upper triangular.
            ///
            /// # Errors
            ///
            /// Returns `RustyCompressionError::CompressionError` if a triangular solve
            /// against that block fails, e.g. when it is exactly singular.
            fn column_id(&self) -> Result<ColumnID<Self::A>> {
                let rank = self.rank();
                let ncols = self.ncols();

                if rank == ncols {
                    // Nothing was truncated: C = Q R, and Z is the identity with
                    // the inverse column permutation applied.
                    return Ok(ColumnID::<Self::A>::new(
                        self.get_q().dot(&self.get_r()),
                        Array2::<Self::A>::eye(rank)
                            .apply_permutation(self.get_ind(), MatrixPermutationMode::COLINV),
                        self.get_ind().into_owned(),
                    ));
                }

                let r = self.get_r();
                let first_part = r.slice(s![.., 0..rank]).to_owned();
                let c = self.get_q().dot(&first_part);

                // Z = [I | R1^{-1} R2]: identity on the leading block, solved
                // coefficients for each trailing column.
                let mut z = Array2::<Self::A>::zeros((rank, ncols));
                z.slice_mut(s![.., 0..rank]).diag_mut().fill(num::one());

                for (index, col) in r.slice(s![.., rank..ncols]).axis_iter(Axis(1)).enumerate() {
                    // A singular leading block makes the solve fail. Report it
                    // instead of panicking.
                    let coeffs = first_part
                        .solve_triangular(UPLO::Upper, Diag::NonUnit, &col.to_owned())
                        .map_err(|_| RustyCompressionError::CompressionError)?;
                    z.index_axis_mut(Axis(1), rank + index).assign(&coeffs);
                }

                Ok(ColumnID::<Self::A>::new(
                    c,
                    z.apply_permutation(self.get_ind(), MatrixPermutationMode::COLINV),
                    self.get_ind().into_owned(),
                ))
            }

            /// Steps:
            /// 1. Form `b = (op.conj_matmat(range))^H`.
            /// 2. Compute a pivoted QR of `b`.
            /// 3. Left-multiply its `Q` factor by `range`.
            fn compute_from_range_estimate<Op: ConjMatMat<A = Self::A>>(
                range: ArrayView2<Self::A>,
                op: &Op,
            ) -> Result<QR<Self::A>> {
                let b = op.conj_matmat(range).t().map(|item| item.conj());
                // Move the factors out of the owned result instead of copying them.
                let QR { q, r, ind } = QR::<$scalar>::compute_from(b.view())?;

                Ok(QR {
                    q: range.dot(&q),
                    r,
                    ind,
                })
            }
        }
    };
}
327
/// Implement [`LQTraits`] for `LQ<$scalar>`.
macro_rules! lq_data_impl {
    ($scalar:ty) => {
        impl LQTraits for LQ<$scalar> {
            type A = $scalar;

            fn get_q(&self) -> ArrayView2<Self::A> {
                self.q.view()
            }

            fn get_l(&self) -> ArrayView2<Self::A> {
                self.l.view()
            }

            fn get_ind(&self) -> ArrayView1<usize> {
                self.ind.view()
            }

            fn get_q_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.q.view_mut()
            }

            fn get_l_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.l.view_mut()
            }

            fn get_ind_mut(&mut self) -> ArrayViewMut1<usize> {
                self.ind.view_mut()
            }

            /// Obtain the LQ from a pivoted QR of the conjugate transpose of `arr`.
            ///
            /// `L` and `Q` are the conjugate transposes of the QR's `R` and `Q`.
            fn compute_from(arr: ArrayView2<Self::A>) -> Result<LQ<Self::A>> {
                let arr_trans = arr.t().map(|val| val.conj());
                let qr = QR::<$scalar>::compute_from(arr_trans.view())?;
                Ok(LQ {
                    l: qr.r.t().map(|item| item.conj()),
                    q: qr.q.t().map(|item| item.conj()),
                    ind: qr.ind,
                })
            }

            /// Build a row ID from the (possibly truncated) pivoted LQ.
            ///
            /// # Errors
            ///
            /// Returns `RustyCompressionError::CompressionError` if a triangular solve
            /// against the leading `rank × rank` block of `L` fails, e.g. when that
            /// block is exactly singular.
            fn row_id(&self) -> Result<RowID<Self::A>> {
                let rank = self.rank();
                let nrows = self.nrows();

                if rank == nrows {
                    // Nothing was truncated: X is the identity with the inverse row
                    // permutation applied, and R = L Q.
                    return Ok(RowID::<Self::A>::new(
                        Array2::<Self::A>::eye(rank)
                            .apply_permutation(self.ind.view(), MatrixPermutationMode::ROWINV),
                        self.l.dot(&self.q),
                        self.ind.clone(),
                    ));
                }

                let first_part = self.l.slice(s![0..rank, ..]).to_owned();
                let r = first_part.dot(&self.q);
                // Solving x L1 = l_row is equivalent to L1^T x^T = l_row^T. This is a
                // plain (non-conjugate) transpose, and L1^T is upper triangular.
                let first_part_transposed = first_part.t().to_owned();

                let mut x = Array2::<Self::A>::zeros((nrows, rank));
                x.slice_mut(s![0..rank, ..]).diag_mut().fill(num::one());

                for (index, row) in self.l.slice(s![rank..nrows, ..]).axis_iter(Axis(0)).enumerate() {
                    // A singular leading block makes the solve fail. Report it
                    // instead of panicking.
                    let coeffs = first_part_transposed
                        .solve_triangular(UPLO::Upper, Diag::NonUnit, &row.to_owned())
                        .map_err(|_| RustyCompressionError::CompressionError)?;
                    x.index_axis_mut(Axis(0), rank + index).assign(&coeffs);
                }

                Ok(RowID::<Self::A>::new(
                    x.apply_permutation(self.ind.view(), MatrixPermutationMode::ROWINV),
                    r,
                    self.ind.clone(),
                ))
            }
        }
    };
}
407
408qr_data_impl!(f32);
409qr_data_impl!(f64);
410qr_data_impl!(c32);
411qr_data_impl!(c64);
412
413lq_data_impl!(f32);
414lq_data_impl!(f64);
415lq_data_impl!(c32);
416lq_data_impl!(c64);
417
#[cfg(test)]
mod tests {

    use super::*;
    use crate::pivoted_qr::PivotedQR;
    use crate::random_matrix::RandomMatrix;
    use crate::types::RelDiff;
    use ndarray::Axis;

    // Fixed-rank QR compression: check the truncated shapes and the reconstruction error.
    macro_rules! qr_compression_by_rank_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let m = $dim.0;
                    let n = $dim.1;
                    let rank: usize = 30;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                    let qr = <$scalar>::pivoted_qr(mat.view()).unwrap().compress(CompressionType::RANK(rank)).unwrap();

                    assert!(qr.q.len_of(Axis(1)) == rank);
                    assert!(qr.r.len_of(Axis(0)) == rank);
                    assert!(<$scalar>::rel_diff_fro(qr.to_mat().view(), mat.view()) < $tol);

                }

            )*

        }
    }

    // Adaptive QR compression: check the reconstruction error and that the rank dropped.
    macro_rules! qr_compression_by_tol_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let m = $dim.0;
                    let n = $dim.1;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                    let qr = <$scalar>::pivoted_qr(mat.view()).unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();

                    assert!(<$scalar>::rel_diff_fro(qr.to_mat().view(), mat.view()) < 5.0 * $tol);

                    assert!(qr.q.ncols() < m.min(n));
                }

            )*

        }
    }

    // Column ID: check the reconstruction and that the skeleton columns match the permuted input.
    macro_rules! col_id_compression_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let m = $dim.0;
                    let n = $dim.1;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                    let qr = QR::<$scalar>::compute_from(mat.view()).unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();
                    let rank = qr.rank();
                    let column_id = qr.column_id().unwrap();

                    assert!(<$scalar>::rel_diff_fro(column_id.to_mat().view(), mat.view()) < 5.0 * $tol);

                    let mat_permuted = mat.apply_permutation(column_id.get_col_ind(), MatrixPermutationMode::COL);

                    for index in 0..rank {
                        assert!(
                            <$scalar>::rel_diff_l2(mat_permuted.index_axis(Axis(1), index), column_id.get_c().index_axis(Axis(1), index)) < $tol);

                    }

                }

            )*

        }
    }

    // Row ID: check the reconstruction and that the skeleton rows match the permuted input.
    macro_rules! row_id_compression_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let m = $dim.0;
                    let n = $dim.1;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

                    let lq = LQ::<$scalar>::compute_from(mat.view()).unwrap().compress(CompressionType::ADAPTIVE($tol)).unwrap();
                    let rank = lq.rank();
                    let row_id = lq.row_id().unwrap();

                    assert!(<$scalar>::rel_diff_fro(row_id.to_mat().view(), mat.view()) < 5.0 * $tol);

                    let mat_permuted = mat.apply_permutation(row_id.get_row_ind(), MatrixPermutationMode::ROW);

                    for index in 0..rank {
                        assert!(<$scalar>::rel_diff_l2(mat_permuted.index_axis(Axis(0), index), row_id.get_r().index_axis(Axis(0), index)) < $tol);

                    }

                }

            )*

        }
    }

    row_id_compression_tests! {
        test_row_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_row_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_row_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_row_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_row_id_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_row_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_row_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_row_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }

    col_id_compression_tests! {
        test_col_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_col_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_col_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_col_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_col_id_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_col_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_col_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_col_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }

    qr_compression_by_rank_tests! {
        test_qr_compression_by_rank_f32_thin: f32, (100, 50), 1E-4,
        test_qr_compression_by_rank_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_qr_compression_by_rank_f64_thin: f64, (100, 50), 1E-4,
        test_qr_compression_by_rank_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_qr_compression_by_rank_f32_thick: f32, (50, 100), 1E-4,
        test_qr_compression_by_rank_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_qr_compression_by_rank_f64_thick: f64, (50, 100), 1E-4,
        test_qr_compression_by_rank_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }

    qr_compression_by_tol_tests! {
        test_qr_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_qr_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_qr_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_qr_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_qr_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_qr_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_qr_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_qr_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }
}