rlst/dense/linalg/lapack/interface/
geqrf.rs

1//! Implementation of ?geqrf - QR factorization.
2
3use lapack::{cgeqrf, dgeqrf, sgeqrf, zgeqrf};
4
5use crate::base_types::LapackResult;
6use crate::base_types::{LapackError, c32, c64};
7
8use crate::dense::linalg::lapack::interface::lapack_return;
9
10use num::{Zero, complex::ComplexFloat};
11
12/// ?geqrf - QR factorization.
13pub trait Geqrf: Sized {
14    /// Perform QR factorization of a matrix `a` with dimensions `m` x `n`.
15    ///
16    fn geqrf(m: usize, n: usize, a: &mut [Self], lda: usize, tau: &mut [Self]) -> LapackResult<()>;
17}
18
/// Implements [`Geqrf`] for one scalar type by forwarding to its LAPACK routine.
///
/// `$scalar` is the element type and `$geqrf` the matching `lapack` function
/// (`sgeqrf`, `dgeqrf`, `cgeqrf` or `zgeqrf`).
macro_rules! implement_geqrf {
    ($scalar:ty, $geqrf:expr) => {
        impl Geqrf for $scalar {
            fn geqrf(
                m: usize,
                n: usize,
                a: &mut [Self],
                lda: usize,
                tau: &mut [Self],
            ) -> LapackResult<()> {
                // Converts a size to LAPACK's 32-bit integer type, panicking instead of
                // silently truncating values that do not fit.
                fn to_lapack_int(value: usize, name: &str) -> i32 {
                    i32::try_from(value).unwrap_or_else(|_| {
                        panic!("Require `{}` {} to fit into a LAPACK `i32`.", name, value)
                    })
                }

                assert_eq!(
                    lda * n,
                    a.len(),
                    "Require `lda * n` {} == `a.len()` {}.",
                    lda * n,
                    a.len()
                );

                assert!(
                    lda >= std::cmp::max(1, m),
                    "Require `lda` {} >= `max(1, m)` {}.",
                    lda,
                    std::cmp::max(1, m)
                );
                let k = std::cmp::min(m, n);

                assert_eq!(
                    tau.len(),
                    k,
                    "Require `tau.len()` {} == `min(m, n)` {}.",
                    tau.len(),
                    k
                );

                let m_int = to_lapack_int(m, "m");
                let n_int = to_lapack_int(n, "n");
                let lda_int = to_lapack_int(lda, "lda");

                let mut info = 0;

                // Workspace query: with `lwork == -1` LAPACK only computes the optimal
                // workspace size and stores it in `work[0]`.
                let mut work = vec![<$scalar>::zero(); 1];

                // SAFETY: the assertions above guarantee that `a` holds `lda * n` elements
                // with `lda >= max(1, m)` and that `tau` holds `min(m, n)` elements; `work`
                // has length 1, which is all a workspace query writes.
                unsafe {
                    $geqrf(m_int, n_int, a, lda_int, tau, &mut work, -1, &mut info);
                }

                if info != 0 {
                    return Err(LapackError::LapackInfoCode(info));
                }

                // The optimal size is reported as a floating-point value. Some LAPACK
                // releases report 0 when `n == 0`, although the factorization call itself
                // requires `lwork >= max(1, n)`, so clamp to that minimum.
                let lwork = std::cmp::max(work[0].re() as usize, std::cmp::max(1, n));
                let lwork_int = to_lapack_int(lwork, "lwork");
                let mut work = vec![<$scalar>::zero(); lwork];

                // SAFETY: same `a`/`tau` invariants as above; `work` now has exactly `lwork`
                // elements, matching the length passed to LAPACK.
                unsafe {
                    $geqrf(m_int, n_int, a, lda_int, tau, &mut work, lwork_int, &mut info);
                }
                lapack_return(info, ())
            }
        }
    };
}
87
88implement_geqrf!(f32, sgeqrf);
89implement_geqrf!(f64, dgeqrf);
90implement_geqrf!(c32, cgeqrf);
91implement_geqrf!(c64, zgeqrf);