rlst/dense/linalg/lapack/interface/
geqrf.rs1use lapack::{cgeqrf, dgeqrf, sgeqrf, zgeqrf};
4
5use crate::base_types::LapackResult;
6use crate::base_types::{LapackError, c32, c64};
7
8use crate::dense::linalg::lapack::interface::lapack_return;
9
10use num::{Zero, complex::ComplexFloat};
11
12pub trait Geqrf: Sized {
14 fn geqrf(m: usize, n: usize, a: &mut [Self], lda: usize, tau: &mut [Self]) -> LapackResult<()>;
17}
18
19macro_rules! implement_geqrf {
20 ($scalar:ty, $geqrf:expr) => {
21 impl Geqrf for $scalar {
22 fn geqrf(
23 m: usize,
24 n: usize,
25 a: &mut [Self],
26 lda: usize,
27 tau: &mut [Self],
28 ) -> LapackResult<()> {
29 assert_eq!(
30 lda * n,
31 a.len(),
32 "Require `lda * n` {} == `a.len()` {}.",
33 lda * n,
34 a.len()
35 );
36
37 assert!(
38 lda >= std::cmp::max(1, m),
39 "Require `lda` {} >= `max(1, m)` {}.",
40 lda,
41 std::cmp::max(1, m)
42 );
43 let k = std::cmp::min(m, n);
44
45 assert_eq!(
46 tau.len(),
47 k,
48 "Require `tau.len()` {} == `min(m, n)` {}.",
49 tau.len(),
50 k
51 );
52
53 let mut info = 0;
54
55 let mut work = vec![<$scalar>::zero(); 1];
56
57 unsafe {
58 $geqrf(
59 m as i32, n as i32, a, lda as i32, tau, &mut work, -1, &mut info,
60 );
61 }
62
63 if info != 0 {
64 Err(LapackError::LapackInfoCode(info))
65 } else {
66 let lwork = work[0].re() as usize;
67 let mut work = vec![<$scalar>::zero(); lwork];
68
69 unsafe {
70 $geqrf(
71 m as i32,
72 n as i32,
73 a,
74 lda as i32,
75 tau,
76 &mut work,
77 lwork as i32,
78 &mut info,
79 );
80 }
81 lapack_return(info, ())
82 }
83 }
84 }
85 };
86}
87
88implement_geqrf!(f32, sgeqrf);
89implement_geqrf!(f64, dgeqrf);
90implement_geqrf!(c32, cgeqrf);
91implement_geqrf!(c64, zgeqrf);