nalgebra_lapack/
qr_util.rs

1use crate::{DiagonalKind, LapackErrorCode, Side, Transposition, TriangularStructure, qr::QrReal};
2use na::{
3    Dim, DimMin, DimMinimum, IsContiguous, Matrix, RawStorage, RawStorageMut, RealField, Vector,
4};
5use num::{ConstOne, Zero};
6
/// Error type for QR decomposition operations.
///
/// Returned by the QR helper functions in this module when their inputs are
/// rejected or when the LAPACK backend reports a failure.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum Error {
    /// Incorrect matrix dimensions: the shapes of the supplied matrices (or
    /// the length of the `tau` vector) do not fit together for the requested
    /// operation.
    #[error("incorrect matrix dimensions")]
    Dimensions,
    /// LAPACK backend returned error. Thanks to `#[from]`, a
    /// [`LapackErrorCode`] is converted into this variant by the `?` operator.
    #[error("Lapack returned with error: {0}")]
    Lapack(#[from] LapackErrorCode),
    /// QR decomposition for underdetermined systems not supported, i.e. the
    /// decomposed matrix has fewer rows than columns. The least squares
    /// solver in this module also reports this for matrices with zero rows
    /// or zero columns.
    #[error("QR decomposition for underdetermined systems not supported")]
    Underdetermined,
    /// Matrix has rank zero.
    #[error("Matrix has rank zero")]
    ZeroRank,
}
23
24/// Thin wrapper around certain invocation of `multiply_q_mut`, where:
25/// * `qr`: contains the LAPACK-style QR decomposition of a matrix A
26/// * `tau`: scalar factors of the elementary reflectors
27/// * `b`: matrix B described below
28///
29/// Efficiently calculate the matrix product `Q B` of the factor `Q` with a
30/// given matrix `B`. `Q` acts as if it is a matrix of dimension `m ⨯ m`, so
31/// we require `B ∈ R^(m ⨯ k)`. The product is calculated in place and
32/// must only be considered valid when the function returns without error.
33pub(crate) fn q_mul_mut<T, R1, C1, S1, C2, S2, S3>(
34    qr: &Matrix<T, R1, C1, S1>,
35    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
36    b: &mut Matrix<T, R1, C2, S2>,
37) -> Result<(), Error>
38where
39    T: QrReal + Zero + RealField,
40    R1: DimMin<C1>,
41    C1: Dim,
42    S1: RawStorage<T, R1, C1> + IsContiguous,
43    C2: Dim,
44    S2: RawStorageMut<T, R1, C2> + IsContiguous,
45    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
46{
47    if b.nrows() != qr.nrows() {
48        return Err(Error::Dimensions);
49    }
50    if qr.ncols().min(qr.nrows()) != tau.len() {
51        return Err(Error::Dimensions);
52    }
53    // SAFETY: matrix has the correct dimensions for operation Q*B
54    unsafe { multiply_q_mut(qr, tau, b, Side::Left, Transposition::No)? };
55    Ok(())
56}
57
58/// Thin wrapper around certain invokation of `multiply_q_mut`, where:
59/// * `qr`: contains the lapack-style qr decomposition of a matrix A
60/// * `tau`: scalar factors of the elementary reflectors
61/// * `b`: matrix B described below
62///
63/// Efficiently calculate the matrix product `Q^T B` of the factor `Q` with a
64/// given matrix `B`. `Q` acts as if it is a matrix of dimension `m ⨯ m`, so
65/// we require `B ∈ R^(m ⨯ k)`. The product is calculated in place and
66/// must only be considered valid when the function returns without error.
67pub(crate) fn q_tr_mul_mut<T, R1, C1, S1, C2, S2, S3>(
68    qr: &Matrix<T, R1, C1, S1>,
69    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
70    b: &mut Matrix<T, R1, C2, S2>,
71) -> Result<(), Error>
72where
73    T: QrReal + Zero + RealField,
74    R1: DimMin<C1>,
75    C1: Dim,
76    S1: RawStorage<T, R1, C1> + IsContiguous,
77    C2: Dim,
78    C2: Dim,
79    S2: RawStorageMut<T, R1, C2> + IsContiguous,
80    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
81{
82    if b.nrows() != qr.nrows() {
83        return Err(Error::Dimensions);
84    }
85    if qr.ncols().min(qr.nrows()) != tau.len() {
86        return Err(Error::Dimensions);
87    }
88    // SAFETY: matrix has the correct dimensions for operation Q^T*B
89    unsafe { multiply_q_mut(qr, tau, b, Side::Left, Transposition::Transpose)? };
90    Ok(())
91}
92
93/// Thin wrapper around certain invokation of `multiply_q_mut`, where:
94/// * `qr`: contains the lapack-style qr decomposition of a matrix A
95/// * `tau`: scalar factors of the elementary reflectors
96/// * `b`: matrix B described below
97///
98/// Efficiently calculate the matrix product `B Q` of the factor `Q` with a
99/// given matrix `B`. `Q` acts as if it is a matrix of dimension `m ⨯ m`, so
100/// we require `B ∈ R^(k ⨯ m)`. The product is calculated in place and
101/// must only be considered valid when the function returns without error.
102pub(crate) fn mul_q_mut<T, R1, C1, S1, R2, S2, S3>(
103    qr: &Matrix<T, R1, C1, S1>,
104    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
105    b: &mut Matrix<T, R2, R1, S2>,
106) -> Result<(), Error>
107where
108    T: QrReal + Zero + RealField,
109    R1: DimMin<C1>,
110    C1: Dim,
111    S1: RawStorage<T, R1, C1> + IsContiguous,
112    R2: Dim,
113    S2: RawStorageMut<T, R2, R1> + IsContiguous,
114    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
115{
116    if b.ncols() != qr.nrows() {
117        return Err(Error::Dimensions);
118    }
119    if qr.ncols().min(qr.nrows()) != tau.len() {
120        return Err(Error::Dimensions);
121    }
122    // SAFETY: matrix has the correct dimensions for operation B*Q
123    unsafe { multiply_q_mut(qr, tau, b, Side::Right, Transposition::No)? };
124    Ok(())
125}
126
127/// Thin wrapper around certain invokation of `multiply_q_mut`, where:
128/// * `qr`: contains the lapack-style qr decomposition of a matrix A
129/// * `tau`: scalar factors of the elementary reflectors
130/// * `b`: matrix B described below
131///
132/// Efficiently calculate the matrix product `B Q^T` of the factor `Q` with a
133/// given matrix `B`. `Q` acts as if it is a matrix of dimension `m ⨯ m`, so
134/// we require `B ∈ R^(k ⨯ m)`. The product is calculated in place and
135/// must only be considered valid when the function returns without error.
136pub(crate) fn mul_q_tr_mut<T, R1, C1, S1, R2, S2, S3>(
137    qr: &Matrix<T, R1, C1, S1>,
138    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
139    b: &mut Matrix<T, R2, R1, S2>,
140) -> Result<(), Error>
141where
142    T: QrReal + Zero + RealField,
143    R1: DimMin<C1>,
144    C1: Dim,
145    S1: RawStorage<T, R1, C1> + IsContiguous,
146    R2: Dim,
147    S2: RawStorageMut<T, R2, R1> + IsContiguous,
148    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
149{
150    if b.ncols() != qr.nrows() {
151        return Err(Error::Dimensions);
152    }
153    if qr.ncols().min(qr.nrows()) != tau.len() {
154        return Err(Error::Dimensions);
155    }
156    // SAFETY: matrix has the correct dimensions for operation B Q^T
157    unsafe { multiply_q_mut(qr, tau, b, Side::Right, Transposition::Transpose)? }
158    Ok(())
159}
160
161/// this factors out solving a the A X = B in a least squares sense, given a
162/// lapack qr decomposition of matrix A (in qr, tau). This also needs an explicit
163/// rank for the matrix, which should be set to full rank for unpivoted QR.
164///
165/// This solver does not do the final row permutation necessary for col-pivoted
166/// qr. For unpivoted QR, no extra permutation is necessary anyways.
167pub(crate) fn qr_solve_mut_with_rank_unpermuted<T, R1, C1, S1, C2: Dim, S3, S2, S4>(
168    qr: &Matrix<T, R1, C1, S1>,
169    tau: &Vector<T, DimMinimum<R1, C1>, S4>,
170    rank: u16,
171    x: &mut Matrix<T, C1, C2, S2>,
172    mut b: Matrix<T, R1, C2, S3>,
173) -> Result<(), Error>
174where
175    T: QrReal + Zero + RealField,
176    R1: DimMin<C1>,
177    C1: Dim,
178    S1: RawStorage<T, R1, C1> + IsContiguous,
179    S3: RawStorageMut<T, R1, C2> + IsContiguous,
180    S2: RawStorageMut<T, C1, C2> + IsContiguous,
181    S4: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
182{
183    if b.nrows() != qr.nrows() {
184        return Err(Error::Dimensions);
185    }
186
187    if qr.nrows() < qr.ncols() || qr.nrows() == 0 || qr.ncols() == 0 {
188        return Err(Error::Underdetermined);
189    }
190
191    if x.ncols() != b.ncols() || x.nrows() != qr.ncols() {
192        return Err(Error::Dimensions);
193    }
194
195    q_tr_mul_mut(qr, tau, &mut b)?;
196
197    if rank == 0 {
198        return Err(Error::ZeroRank);
199    }
200
201    debug_assert!(rank as usize <= qr.ncols().min(qr.nrows()));
202
203    if (rank as usize) < qr.ncols() {
204        x.view_mut((rank as usize, 0), (x.nrows() - rank as usize, x.ncols()))
205            .iter_mut()
206            .for_each(|val| val.set_zero());
207    }
208
209    let x_cols = x.ncols();
210    x.view_mut((0, 0), (rank as usize, x_cols))
211        .copy_from(&b.view((0, 0), (rank as usize, x_cols)));
212
213    let ldb: i32 = x
214        .nrows()
215        .try_into()
216        .expect("integer dimensions out of bounds");
217
218    // SAFETY: input and dimensions according to lapack spec, see
219    // https://www.netlib.org/lapack/explore-html/d4/dc1/group__trtrs_gab0b6a7438a7eb98fe2ab28e6c4d84b21.html#gab0b6a7438a7eb98fe2ab28e6c4d84b21
220    unsafe {
221        T::xtrtrs(
222            TriangularStructure::Upper,
223            Transposition::No,
224            DiagonalKind::NonUnit,
225            rank.try_into().expect("rank out of bounds"),
226            x.ncols()
227                .try_into()
228                .expect("integer dimensions out of bounds"),
229            qr.as_slice(),
230            qr.nrows()
231                .try_into()
232                .expect("integer dimensions out of bounds"),
233            x.as_mut_slice(),
234            ldb,
235        )?;
236    }
237
238    Ok(())
239}
240
241/// Thin-ish wrapper around the LAPACK function
242/// [?ormqr](https://www.netlib.org/lapack/explore-html/d7/d50/group__unmqr.html),
243/// which allows us to calculate either Q*B, Q^T*B, B*Q, B*Q^T for appropriately
244/// shaped matrices B, without having to explicitly form Q. In this calculation
245/// Q is constructed as if it were a square matrix of appropriate dimension.
246///
247/// # Safety
248///
249/// The dimensions of the matrices must be correct such that the multiplication
250/// can be performed.
251#[inline]
252unsafe fn multiply_q_mut<T, R1, C1, S1, R2, C2, S2, S3>(
253    qr: &Matrix<T, R1, C1, S1>,
254    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
255    mat: &mut Matrix<T, R2, C2, S2>,
256    side: Side,
257    transpose: Transposition,
258) -> Result<(), Error>
259where
260    T: QrReal,
261    R1: DimMin<C1>,
262    C1: Dim,
263    S2: RawStorageMut<T, R2, C2> + IsContiguous,
264    R2: Dim,
265    C2: Dim,
266    S1: IsContiguous + RawStorage<T, R1, C1>,
267    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
268{
269    let a = qr.as_slice();
270    let lda = qr
271        .nrows()
272        .try_into()
273        .expect("integer dimension out of range");
274    let m = mat
275        .nrows()
276        .try_into()
277        .expect("integer dimension out of range");
278    let n = mat
279        .ncols()
280        .try_into()
281        .expect("integer dimension out of range");
282    let k = tau
283        .len()
284        .try_into()
285        .expect("integer dimension out of range");
286    let ldc = mat
287        .nrows()
288        .try_into()
289        .expect("integer dimension out of range");
290    let c = mat.as_mut_slice();
291    let trans = transpose;
292    let tau = tau.as_slice();
293
294    if k as usize != qr.ncols() {
295        return Err(Error::Dimensions);
296    }
297
298    // dimensions checks from the lapack documentation
299    // see e.g. https://www.netlib.org/lapack/explore-html/d7/d50/group__unmqr_ga768bd221f959be1b3d15bd177bb5c1b3.html#ga768bd221f959be1b3d15bd177bb5c1b3
300    match side {
301        Side::Left => {
302            if m < k {
303                return Err(Error::Dimensions);
304            }
305
306            if lda < m {
307                return Err(Error::Dimensions);
308            }
309        }
310        Side::Right => {
311            if n < k {
312                return Err(Error::Dimensions);
313            }
314
315            if lda < n {
316                return Err(Error::Dimensions);
317            }
318        }
319    }
320
321    if ldc < m {
322        return Err(Error::Dimensions);
323    }
324
325    // SAFETY: the dimensions are checked as above, but the user has to make
326    // sure that qr indeed contains the contents of a qr decomposition returned
327    // by lapack and tau must contain the scalar factors of the reflectors as
328    // returned by lapack.
329    let lwork = unsafe { T::xormqr_work_size(side, transpose, m, n, k, a, lda, tau, c, ldc)? };
330    let mut work = vec![T::zero(); lwork as usize];
331
332    // SAFETY: the containing function is unsafe and requires the correct
333    // matrix dimensions as input
334    unsafe {
335        T::xormqr(side, trans, m, n, k, a, lda, tau, c, ldc, &mut work, lwork)?;
336    }
337    Ok(())
338}
339
340/// multiply R*B or R^T *B and place the result in B, where R is the upper triangular matrix
341/// in a qr decomposition as computed by lapack.
342pub fn r_xx_mul_mut<T, R1, C1, S1, C2, S2>(
343    qr: &Matrix<T, R1, C1, S1>,
344    transpose: Transposition,
345    b: &mut Matrix<T, C1, C2, S2>,
346) -> Result<(), Error>
347where
348    T: QrReal + ConstOne,
349    R1: Dim,
350    C1: Dim,
351    C2: Dim,
352    S1: RawStorage<T, R1, C1> + IsContiguous,
353    S2: RawStorageMut<T, C1, C2> + IsContiguous,
354{
355    // looking carefully at the lapack docs, the xTRMM requires
356    // an overdetermined matrix (m>=n), because otherwise R will
357    // be upper trapezoidal and the logic will be different and it
358    // might not actually be useful to multiply the square part.
359    if qr.nrows() < qr.ncols() {
360        return Err(Error::Underdetermined);
361    }
362
363    if qr.ncols() != b.nrows() {
364        return Err(Error::Dimensions);
365    }
366
367    multiply_r_mut(qr, transpose, Side::Left, b)?;
368    Ok(())
369}
370
371/// multiply B*R or B * R^T and place the result in B, where R is the upper triangular matrix
372/// in a qr decomposition as computed by lapack.
373pub fn mul_r_xx_mut<T, R1, C1, S1, R2, S2>(
374    qr: &Matrix<T, R1, C1, S1>,
375    transpose: Transposition,
376    b: &mut Matrix<T, R2, C1, S2>,
377) -> Result<(), Error>
378where
379    T: QrReal + ConstOne,
380    R1: Dim,
381    C1: Dim,
382    R2: Dim,
383    S1: RawStorage<T, R1, C1> + IsContiguous,
384    S2: RawStorageMut<T, R2, C1> + IsContiguous,
385{
386    // looking carefully at the lapack docs, the xTRMM requires
387    // an overdetermined matrix (m>=n), because otherwise R will
388    // be upper trapezoidal and the logic will be different and it
389    // might not actually be useful to multiply the square part.
390    if qr.nrows() < qr.ncols() {
391        return Err(Error::Underdetermined);
392    }
393
394    if b.ncols() != qr.ncols() {
395        return Err(Error::Dimensions);
396    }
397
398    multiply_r_mut(qr, transpose, Side::Right, b)?;
399    Ok(())
400}
401/// thin-ish wrapper around the lapack function [?TRMM](https://www.netlib.org/lapack/explore-html/dd/dab/group__trmm.html)
402/// for multiplying the upper triangular part R or a QR decomposition with another
403/// matrix.
404///
405/// The way the ?TRMM logic works is that A is a kxk matrix, and B is m x n.
406/// When multiplying from the left, then k = m and when multiplying from the
407/// right, k = n. The matrix A can be stored in the QR decomposition as the
408/// upper triangular part, so LDA is the number of rows for the QR decomp.
409///
410/// The ?TRMM functions also allow scaling with a factor alpha, which
411/// we always set to 1 and they allow the matrix to be upper or lower triangular,
412/// we always use upper triangular. They also allow to multiply from right or
413/// left, but the dimension of R in a QR decomposition only allows multiplication
414/// from the left, I think.
415#[inline]
416fn multiply_r_mut<T, R1, C1, S1, R2, C2, S2>(
417    qr: &Matrix<T, R1, C1, S1>,
418    transpose: Transposition,
419    side: Side,
420    mat: &mut Matrix<T, R2, C2, S2>,
421) -> Result<(), Error>
422where
423    T: QrReal + ConstOne,
424    R1: Dim,
425    C1: Dim,
426    S2: RawStorageMut<T, R2, C2> + IsContiguous,
427    R2: Dim,
428    C2: Dim,
429    S1: IsContiguous + RawStorage<T, R1, C1>,
430{
431    let m: i32 = mat
432        .nrows()
433        .try_into()
434        .expect("integer dimensions out of bounds");
435    let n: i32 = mat
436        .ncols()
437        .try_into()
438        .expect("integer dimensions out of bounds");
439    let lda: i32 = qr
440        .nrows()
441        .try_into()
442        .expect("integer dimensions out of bounds");
443    let ldb: i32 = mat
444        .nrows()
445        .try_into()
446        .expect("integer dimensions out of bounds");
447
448    // these bounds are from the lapack documentation
449    // see e.g. https://www.netlib.org/lapack/explore-html/dd/dab/group__trmm_ga4d2f76d6726f53c69031a2fe7f999add.html#ga4d2f76d6726f53c69031a2fe7f999add
450    match side {
451        Side::Left => {
452            if lda == 0 || lda < m {
453                return Err(Error::Dimensions);
454            }
455            if qr.ncols() != m as usize {
456                return Err(Error::Dimensions);
457            }
458        }
459        Side::Right => {
460            if lda == 0 || lda < n {
461                return Err(Error::Dimensions);
462            }
463            if qr.ncols() != n as usize {
464                return Err(Error::Dimensions);
465            }
466        }
467    }
468
469    // SAFETY: we're using the correct types and we are giving the
470    // correct matrix dimensions as per lapack docs
471    unsafe {
472        T::xtrmm(
473            side,
474            TriangularStructure::Upper,
475            transpose,
476            DiagonalKind::NonUnit,
477            m,
478            n,
479            T::ONE,
480            qr.as_slice(),
481            lda,
482            mat.as_mut_slice(),
483            ldb,
484        );
485    }
486    Ok(())
487}