nalgebra_lapack/
qr.rs

1use crate::sealed::Sealed;
2use crate::{ComplexHelper, DiagonalKind, Side, Transposition, TriangularStructure, qr_util};
3use crate::{LapackErrorCode, lapack_error::check_lapack_info};
4use lapack;
5use na::allocator::Allocator;
6use na::dimension::{Const, Dim, DimMin, DimMinimum};
7use na::{
8    ComplexField, DefaultAllocator, IsContiguous, Matrix, OMatrix, OVector, RawStorageMut,
9    RealField, Scalar,
10};
11use num::Zero;
12#[cfg(feature = "serde-serialize")]
13use serde::{Deserialize, Serialize};
14
15pub use crate::qr_util::Error;
16pub(crate) mod abstraction;
17pub use abstraction::QrDecomposition;
18
19/// The QR decomposition of a rectangular matrix `A ∈ R^(m × n)` with `m >= n`.
20#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
21#[cfg_attr(
22    feature = "serde-serialize",
23    serde(bound(serialize = "DefaultAllocator: Allocator<R, C> +
24                           Allocator<DimMinimum<R, C>>,
25         OMatrix<T, R, C>: Serialize,
26         OVector<T, DimMinimum<R, C>>: Serialize"))
27)]
28#[cfg_attr(
29    feature = "serde-serialize",
30    serde(bound(deserialize = "DefaultAllocator: Allocator<R, C> +
31                           Allocator<DimMinimum<R, C>>,
32         OMatrix<T, R, C>: Deserialize<'de>,
33         OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
34)]
35#[derive(Clone, Debug)]
36pub struct QR<T, R, C>
37where
38    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
39    T: Scalar,
40    R: DimMin<C>,
41    C: Dim,
42{
43    qr: OMatrix<T, R, C>,
44    tau: OVector<T, DimMinimum<R, C>>,
45}
46
47impl<T: Scalar + Copy, R: DimMin<C>, C: Dim> Copy for QR<T, R, C>
48where
49    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
50    OMatrix<T, R, C>: Copy,
51    OVector<T, DimMinimum<R, C>>: Copy,
52{
53}
54
55impl<T, R, C> QR<T, R, C>
56where
57    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>>,
58    T: Scalar,
59    R: DimMin<C>,
60    C: Dim,
61{
62    /// Computes the QR decomposition of the matrix `m`.
63    pub fn new(mut m: OMatrix<T, R, C>) -> Result<Self, Error>
64    where
65        T: QrScalar + Zero,
66    {
67        let (nrows, ncols) = m.shape_generic();
68
69        let mut tau = Matrix::zeros_generic(nrows.min(ncols), Const::<1>);
70
71        if nrows.value() < ncols.value() {
72            return Err(Error::Underdetermined);
73        }
74
75        if nrows.value() == 0 || ncols.value() == 0 {
76            return Ok(Self { qr: m, tau });
77        }
78
79        let lwork = unsafe {
80            T::xgeqrf_work_size(
81                nrows.value() as i32,
82                ncols.value() as i32,
83                m.as_mut_slice(),
84                nrows.value() as i32,
85                tau.as_mut_slice(),
86            )?
87        };
88
89        let mut work = vec![T::zero(); lwork as usize];
90
91        unsafe {
92            T::xgeqrf(
93                nrows.value() as i32,
94                ncols.value() as i32,
95                m.as_mut_slice(),
96                nrows.value() as i32,
97                tau.as_mut_slice(),
98                &mut work,
99                lwork,
100            )?;
101        }
102
103        Ok(Self { qr: m, tau })
104    }
105}
106
107impl<T, R, C> Sealed for QR<T, R, C>
108where
109    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
110    T: Scalar,
111    R: DimMin<C>,
112    C: Dim,
113{
114}
115
116impl<T, R, C> QrDecomposition<T, R, C> for QR<T, R, C>
117where
118    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
119    R: DimMin<C, Output = C>,
120    C: Dim,
121    T: Scalar + RealField + QrReal,
122{
123    fn __lapack_qr_ref(&self) -> &OMatrix<T, R, C> {
124        &self.qr
125    }
126
127    fn __lapack_tau_ref(&self) -> &OVector<T, DimMinimum<R, C>> {
128        &self.tau
129    }
130
131    fn solve_mut<C2: Dim, S, S2>(
132        &self,
133        x: &mut Matrix<T, C, C2, S2>,
134        b: Matrix<T, R, C2, S>,
135    ) -> Result<(), Error>
136    where
137        S: RawStorageMut<T, R, C2> + IsContiguous,
138        S2: RawStorageMut<T, C, C2> + IsContiguous,
139        T: Zero,
140    {
141        // this is important because a lot of assumptions rest on this
142        if self.nrows() < self.ncols() {
143            return Err(Error::Underdetermined);
144        }
145
146        // since we use QR decomposition without column pivoting, we assume
147        // full rank.
148        let rank = self
149            .nrows()
150            .min(self.ncols())
151            .try_into()
152            .expect("integer dimensions out of bounds");
153        qr_util::qr_solve_mut_with_rank_unpermuted(&self.qr, &self.tau, rank, x, b)?;
154        Ok(())
155    }
156}
157
158/*
159 *
160 * Lapack functions dispatch.
161 *
162 */
163/// Trait implemented by scalar types for which Lapack function exist to compute the
164/// QR decomposition.
165#[allow(missing_docs)]
166pub trait QrScalar: ComplexField + Scalar + Copy + Sealed {
167    unsafe fn xgeqrf(
168        m: i32,
169        n: i32,
170        a: &mut [Self],
171        lda: i32,
172        tau: &mut [Self],
173        work: &mut [Self],
174        lwork: i32,
175    ) -> Result<(), LapackErrorCode>;
176
177    unsafe fn xgeqrf_work_size(
178        m: i32,
179        n: i32,
180        a: &mut [Self],
181        lda: i32,
182        tau: &mut [Self],
183    ) -> Result<i32, LapackErrorCode>;
184
185    /// routine for column pivoting QR decomposition using level 3 BLAS,
186    /// see <https://www.netlib.org/lapack/lug/node42.html>
187    /// or <https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-0/geqp3.html>
188    unsafe fn xgeqp3(
189        m: i32,
190        n: i32,
191        a: &mut [Self],
192        lda: i32,
193        jpvt: &mut [i32],
194        tau: &mut [Self],
195        work: &mut [Self],
196        lwork: i32,
197    ) -> Result<(), LapackErrorCode>;
198
199    unsafe fn xgeqp3_work_size(
200        m: i32,
201        n: i32,
202        a: &mut [Self],
203        lda: i32,
204        jpvt: &mut [i32],
205        tau: &mut [Self],
206    ) -> Result<i32, LapackErrorCode>;
207
208    unsafe fn xtrtrs(
209        uplo: TriangularStructure,
210        trans: Transposition,
211        diag: DiagonalKind,
212        n: i32,
213        nrhs: i32,
214        a: &[Self],
215        lda: i32,
216        b: &mut [Self],
217        ldb: i32,
218    ) -> Result<(), LapackErrorCode>;
219
220    unsafe fn xlapmt(
221        forwrd: bool,
222        m: i32,
223        n: i32,
224        x: &mut [Self],
225        ldx: i32,
226        k: &mut [i32],
227    ) -> Result<(), LapackErrorCode>;
228
229    unsafe fn xlapmr(
230        forwrd: bool,
231        m: i32,
232        n: i32,
233        x: &mut [Self],
234        ldx: i32,
235        k: &mut [i32],
236    ) -> Result<(), LapackErrorCode>;
237}
238
// Implements `QrScalar` for `$type` by forwarding every method to the LAPACK
// routine given under the matching key (e.g. `xgeqrf = lapack::dgeqrf`).
// A trailing comma after the last routine is accepted.
macro_rules! qr_scalar_impl(
    ($type:ty,
        xgeqrf = $xgeqrf:path,
        xgeqp3 = $xgeqp3:path,
        xtrtrs = $xtrtrs:path,
        xlapmt = $xlapmt:path,
        xlapmr = $xlapmr:path $(,)?) => (
        impl QrScalar for $type {
            #[inline]
            unsafe fn xgeqrf(
                m: i32,
                n: i32,
                a: &mut [Self],
                lda: i32,
                tau: &mut [Self],
                work: &mut [Self],
                lwork: i32,
            ) -> Result<(), LapackErrorCode> {
                let mut info = 0;
                // SAFETY: the caller of this `unsafe fn` guarantees that the
                // slices are large enough for the given dimensions.
                unsafe { $xgeqrf(m, n, a, lda, tau, work, lwork, &mut info) };
                check_lapack_info(info)
            }

            #[inline]
            unsafe fn xgeqrf_work_size(
                m: i32,
                n: i32,
                a: &mut [Self],
                lda: i32,
                tau: &mut [Self],
            ) -> Result<i32, LapackErrorCode> {
                // `lwork == -1` turns the call into a workspace query whose
                // answer is written to `work[0]`.
                let mut work = [Zero::zero()];
                let mut info = 0;

                // SAFETY: see `xgeqrf`; `work` has the single entry the
                // workspace query writes to.
                unsafe { $xgeqrf(m, n, a, lda, tau, &mut work, -1, &mut info) };
                check_lapack_info(info)?;
                Ok(ComplexHelper::real_part(work[0]) as i32)
            }

            unsafe fn xgeqp3(
                m: i32,
                n: i32,
                a: &mut [Self],
                lda: i32,
                jpvt: &mut [i32],
                tau: &mut [Self],
                work: &mut [Self],
                lwork: i32,
            ) -> Result<(), LapackErrorCode> {
                let mut info = 0;
                // SAFETY: the caller of this `unsafe fn` guarantees that the
                // slices are large enough for the given dimensions.
                unsafe { $xgeqp3(m, n, a, lda, jpvt, tau, work, lwork, &mut info) };
                check_lapack_info(info)
            }

            unsafe fn xgeqp3_work_size(
                m: i32,
                n: i32,
                a: &mut [Self],
                lda: i32,
                jpvt: &mut [i32],
                tau: &mut [Self],
            ) -> Result<i32, LapackErrorCode> {
                // Workspace query, see `xgeqrf_work_size`.
                let mut work = [Zero::zero()];
                let mut info = 0;

                // SAFETY: see `xgeqp3`; `work` has the single entry the
                // workspace query writes to.
                unsafe { $xgeqp3(m, n, a, lda, jpvt, tau, &mut work, -1, &mut info) };
                check_lapack_info(info)?;
                // Go through `ComplexHelper` like every other work-size query,
                // so the conversion does not depend on `Self` being a
                // primitive float.
                Ok(ComplexHelper::real_part(work[0]) as i32)
            }

            unsafe fn xtrtrs(
                uplo: TriangularStructure,
                trans: Transposition,
                diag: DiagonalKind,
                n: i32,
                nrhs: i32,
                a: &[Self],
                lda: i32,
                b: &mut [Self],
                ldb: i32,
            ) -> Result<(), LapackErrorCode> {
                let trans_flag = match trans {
                    Transposition::No => b'N',
                    Transposition::Transpose => b'T',
                };
                let mut info = 0;

                // SAFETY: the caller of this `unsafe fn` guarantees that the
                // slices are large enough for the given dimensions.
                unsafe {
                    $xtrtrs(
                        uplo.into_lapack_uplo_character(),
                        trans_flag,
                        diag.into_lapack_diag_character(),
                        n,
                        nrhs,
                        a,
                        lda,
                        b,
                        ldb,
                        &mut info,
                    );
                }

                check_lapack_info(info)
            }

            unsafe fn xlapmt(
                forwrd: bool,
                m: i32,
                n: i32,
                x: &mut [Self],
                ldx: i32,
                k: &mut [i32],
            ) -> Result<(), LapackErrorCode> {
                // Column permutation: one entry of `k` per column.
                debug_assert_eq!(k.len(), n as usize);

                // The binding takes the logical flag as a one-element integer
                // slice holding 1 for `true` and 0 for `false`.
                let forward: [i32; 1] = [i32::from(forwrd)];
                // SAFETY: the caller of this `unsafe fn` guarantees that the
                // slices are large enough for the given dimensions.
                unsafe { $xlapmt(&forward, m, n, x, ldx, k) };
                Ok(())
            }

            unsafe fn xlapmr(
                forwrd: bool,
                m: i32,
                n: i32,
                x: &mut [Self],
                ldx: i32,
                k: &mut [i32],
            ) -> Result<(), LapackErrorCode> {
                // Row permutation: one entry of `k` per row.
                debug_assert_eq!(k.len(), m as usize);

                // See `xlapmt` for the encoding of the flag.
                let forward: [i32; 1] = [i32::from(forwrd)];
                // SAFETY: the caller of this `unsafe fn` guarantees that the
                // slices are large enough for the given dimensions.
                unsafe { $xlapmr(&forward, m, n, x, ldx, k) };
                Ok(())
            }
        }
    )
);
366
367/// Trait implemented by reals for which Lapack function exist to compute the
368/// QR decomposition.
369pub trait QrReal: QrScalar {
370    #[allow(missing_docs)]
371    unsafe fn xorgqr(
372        m: i32,
373        n: i32,
374        k: i32,
375        a: &mut [Self],
376        lda: i32,
377        tau: &[Self],
378        work: &mut [Self],
379        lwork: i32,
380    ) -> Result<(), LapackErrorCode>;
381
382    #[allow(missing_docs)]
383    unsafe fn xorgqr_work_size(
384        m: i32,
385        n: i32,
386        k: i32,
387        a: &mut [Self],
388        lda: i32,
389        tau: &[Self],
390    ) -> Result<i32, LapackErrorCode>;
391
392    #[allow(missing_docs)]
393    unsafe fn xormqr(
394        side: Side,
395        trans: Transposition,
396        m: i32,
397        n: i32,
398        k: i32,
399        a: &[Self],
400        lda: i32,
401        tau: &[Self],
402        c: &mut [Self],
403        ldc: i32,
404        work: &mut [Self],
405        lwork: i32,
406    ) -> Result<(), LapackErrorCode>;
407
408    #[allow(missing_docs)]
409    unsafe fn xormqr_work_size(
410        side: Side,
411        trans: Transposition,
412        m: i32,
413        n: i32,
414        k: i32,
415        a: &[Self],
416        lda: i32,
417        tau: &[Self],
418        c: &mut [Self],
419        ldc: i32,
420    ) -> Result<i32, LapackErrorCode>;
421
422    /// wraps BLAS function [?TRMM](https://www.netlib.org/lapack/explore-html/dd/dab/group__trmm_ga4d2f76d6726f53c69031a2fe7f999add.html#ga4d2f76d6726f53c69031a2fe7f999add)
423    unsafe fn xtrmm(
424        side: Side,
425        uplo: TriangularStructure,
426        transa: Transposition,
427        diag: DiagonalKind,
428        m: i32,
429        n: i32,
430        alpha: Self,
431        a: &[Self],
432        lda: i32,
433        b: &mut [Self],
434        ldb: i32,
435    );
436}
437
// Implements `QrReal` for `$type` by forwarding every method to the
// LAPACK/BLAS routine given under the matching key
// (e.g. `xorgqr = lapack::dorgqr`). Like `qr_scalar_impl!`, a trailing comma
// after the last routine is accepted.
macro_rules! qr_real_impl(
    ($type:ty,
        xorgqr = $xorgqr:path,
        xormqr = $xormqr:path,
        xtrmm = $xtrmm:path $(,)?) => (
        impl QrReal for $type {
            #[inline]
            unsafe fn xorgqr(
                m: i32,
                n: i32,
                k: i32,
                a: &mut [Self],
                lda: i32,
                tau: &[Self],
                work: &mut [Self],
                lwork: i32,
            ) -> Result<(), LapackErrorCode> {
                let mut info = 0;
                // SAFETY: the caller of this `unsafe fn` guarantees that the
                // slices are large enough for the given dimensions.
                unsafe { $xorgqr(m, n, k, a, lda, tau, work, lwork, &mut info) };
                check_lapack_info(info)
            }

            #[inline]
            unsafe fn xorgqr_work_size(
                m: i32,
                n: i32,
                k: i32,
                a: &mut [Self],
                lda: i32,
                tau: &[Self],
            ) -> Result<i32, LapackErrorCode> {
                // `lwork == -1` turns the call into a workspace query whose
                // answer is written to `work[0]`.
                let mut work = [Zero::zero()];
                let mut info = 0;

                // SAFETY: see `xorgqr`; `work` has the single entry the
                // workspace query writes to.
                unsafe { $xorgqr(m, n, k, a, lda, tau, &mut work, -1, &mut info) };
                check_lapack_info(info)?;
                Ok(ComplexHelper::real_part(work[0]) as i32)
            }

            unsafe fn xormqr(
                side: Side,
                trans: Transposition,
                m: i32,
                n: i32,
                k: i32,
                a: &[Self],
                lda: i32,
                tau: &[Self],
                c: &mut [Self],
                ldc: i32,
                work: &mut [Self],
                lwork: i32,
            ) -> Result<(), LapackErrorCode> {
                let side_flag = side.into_lapack_side_character();
                // this would be different for complex numbers!
                let trans_flag = match trans {
                    Transposition::No => b'N',
                    Transposition::Transpose => b'T',
                };
                let mut info = 0;

                // SAFETY: the caller of this `unsafe fn` guarantees that the
                // slices are large enough for the given dimensions.
                unsafe {
                    $xormqr(
                        side_flag, trans_flag, m, n, k, a, lda, tau, c, ldc, work, lwork,
                        &mut info,
                    );
                }
                check_lapack_info(info)
            }

            unsafe fn xormqr_work_size(
                side: Side,
                trans: Transposition,
                m: i32,
                n: i32,
                k: i32,
                a: &[Self],
                lda: i32,
                tau: &[Self],
                c: &mut [Self],
                ldc: i32,
            ) -> Result<i32, LapackErrorCode> {
                let side_flag = side.into_lapack_side_character();
                // this would be different for complex numbers!
                let trans_flag = match trans {
                    Transposition::No => b'N',
                    Transposition::Transpose => b'T',
                };

                // Workspace query, see `xorgqr_work_size`.
                let mut work = [Zero::zero()];
                let mut info = 0;

                // SAFETY: see `xormqr`; `work` has the single entry the
                // workspace query writes to.
                unsafe {
                    $xormqr(
                        side_flag, trans_flag, m, n, k, a, lda, tau, c, ldc, &mut work, -1,
                        &mut info,
                    );
                }
                check_lapack_info(info)?;
                // for complex numbers: real part
                Ok(ComplexHelper::real_part(work[0]) as i32)
            }

            unsafe fn xtrmm(
                side: Side,
                uplo: TriangularStructure,
                transa: Transposition,
                diag: DiagonalKind,
                m: i32,
                n: i32,
                alpha: Self,
                a: &[Self],
                lda: i32,
                b: &mut [Self],
                ldb: i32,
            ) {
                // this would be different for complex numbers!
                let transa_flag = match transa {
                    Transposition::No => b'N',
                    Transposition::Transpose => b'T',
                };

                // SAFETY: the caller of this `unsafe fn` guarantees that the
                // slices are large enough for the given dimensions.
                unsafe {
                    $xtrmm(
                        side.into_lapack_side_character(),
                        uplo.into_lapack_uplo_character(),
                        transa_flag,
                        diag.into_lapack_diag_character(),
                        m,
                        n,
                        alpha,
                        a,
                        lda,
                        b,
                        ldb,
                    );
                }
            }
        }
    )
);
560
561qr_scalar_impl!(
562    f32,
563    xgeqrf = lapack::sgeqrf,
564    xgeqp3 = lapack::sgeqp3,
565    xtrtrs = lapack::strtrs,
566    xlapmt = lapack::slapmt,
567    xlapmr = lapack::slapmr
568);
569
570qr_scalar_impl!(
571    f64,
572    xgeqrf = lapack::dgeqrf,
573    xgeqp3 = lapack::dgeqp3,
574    xtrtrs = lapack::dtrtrs,
575    xlapmt = lapack::dlapmt,
576    xlapmr = lapack::dlapmr
577);
578
579qr_real_impl!(
580    f32,
581    xorgqr = lapack::sorgqr,
582    xormqr = lapack::sormqr,
583    xtrmm = blas::strmm
584);
585qr_real_impl!(
586    f64,
587    xorgqr = lapack::dorgqr,
588    xormqr = lapack::dormqr,
589    xtrmm = blas::dtrmm
590);