nalgebra_lapack/
colpiv_qr.rs

1use super::qr::{QrReal, QrScalar};
2use crate::qr::QrDecomposition;
3use crate::qr_util;
4use crate::sealed::Sealed;
5use na::{Const, IsContiguous, Matrix, OVector, RealField, Vector};
6use nalgebra::storage::RawStorageMut;
7use nalgebra::{DefaultAllocator, Dim, DimMin, DimMinimum, OMatrix, Scalar, allocator::Allocator};
8use num::float::TotalOrder;
9use num::{Float, Zero};
10use rank::{RankDeterminationAlgorithm, calculate_rank};
11
12pub use qr_util::Error;
13mod permutation;
14#[cfg(test)]
15mod test;
16pub use permutation::Permutation;
17/// Utility functionality to calculate the rank of matrices.
18mod rank;
19
20/// The column-pivoted QR decomposition of a rectangular matrix `A ∈ R^(m × n)`
21/// with `m >= n`.
22///
23/// The columns of the matrix `A` are permuted such that `A P = Q R`, meaning
24/// the column-permuted `A` is the product of `Q` and `R`, where `Q` is an orthonormal
25/// matrix `Q^T Q = I` and `R` is upper triangular.
26///
27/// Note that most of the functionality is provided via the [`QrDecomposition`]
28/// trait, which must be in scope for its functions to be used.
29#[derive(Debug, Clone)]
30pub struct ColPivQR<T, R, C>
31where
32    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
33    T: Scalar,
34    R: DimMin<C, Output = C>,
35    C: Dim,
36{
37    // QR decomposition, see https://www.netlib.org/lapack/explore-html/d0/dea/group__geqp3.html
38    qr: OMatrix<T, R, C>,
39    // Householder coefficients, see https://www.netlib.org/lapack/explore-html/d0/dea/group__geqp3.html
40    tau: OVector<T, DimMinimum<R, C>>,
41    // Permutation vector, see https://www.netlib.org/lapack/explore-html/d0/dea/group__geqp3.html
42    // Note that permutation indices are 1-based in LAPACK
43    jpvt: OVector<i32, C>,
44    // Rank of the matrix
45    rank: i32,
46}
47
48impl<T, R, C> Sealed for ColPivQR<T, R, C>
49where
50    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
51    T: Scalar,
52    R: DimMin<C, Output = C>,
53    C: Dim,
54{
55}
56
57/// Constructors
58impl<T, R, C> ColPivQR<T, R, C>
59where
60    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
61    T: QrScalar + Zero + RealField + TotalOrder + Float,
62    R: DimMin<C, Output = C>,
63    C: Dim,
64{
65    /// Try to create a new decomposition from the given matrix using the default
66    /// strategy for rank determination of a matrix from its QR decomposition.
67    pub fn new(m: OMatrix<T, R, C>) -> Result<Self, Error> {
68        Self::with_rank_algo(m, Default::default())
69    }
70
71    /// Try to create a new decomposition from the given matrix and specify the
72    /// strategy for rank determination. When in doubt, use the default strategy
73    /// via the [`ColPivQR::new`] constructor.
74    pub fn with_rank_algo(
75        mut m: OMatrix<T, R, C>,
76        rank_algo: RankDeterminationAlgorithm<T>,
77    ) -> Result<Self, Error> {
78        let (nrows, ncols) = m.shape_generic();
79
80        if nrows.value() < ncols.value() {
81            return Err(Error::Underdetermined);
82        }
83
84        let mut tau: OVector<T, DimMinimum<R, C>> =
85            Vector::zeros_generic(nrows.min(ncols), Const::<1>);
86        let mut jpvt: OVector<i32, C> = Vector::zeros_generic(ncols, Const::<1>);
87
88        // SAFETY: matrix dimensions are slice dimensions, other inputs are according
89        // to spec, see https://www.netlib.org/lapack/explore-html/d0/dea/group__geqp3.html
90        let lwork = unsafe {
91            T::xgeqp3_work_size(
92                nrows.value().try_into().expect("matrix dims out of bounds"),
93                ncols.value().try_into().expect("matrix dims out of bounds"),
94                m.as_mut_slice(),
95                nrows.value().try_into().expect("matrix dims out of bounds"),
96                jpvt.as_mut_slice(),
97                tau.as_mut_slice(),
98            )?
99        };
100
101        let mut work = vec![T::zero(); lwork as usize];
102
103        // SAFETY: matrix dimensions are slice dimensions, other inputs are according
104        // to spec, see https://www.netlib.org/lapack/explore-html/d0/dea/group__geqp3.html
105        unsafe {
106            T::xgeqp3(
107                nrows.value() as i32,
108                ncols.value() as i32,
109                m.as_mut_slice(),
110                nrows.value() as i32,
111                jpvt.as_mut_slice(),
112                tau.as_mut_slice(),
113                &mut work,
114                lwork,
115            )?;
116        }
117
118        let rank: i32 = calculate_rank(&m, rank_algo)
119            .try_into()
120            .map_err(|_| Error::Dimensions)?;
121
122        Ok(Self {
123            qr: m,
124            rank,
125            tau,
126            jpvt,
127        })
128    }
129}
130
131impl<T, R, C> ColPivQR<T, R, C>
132where
133    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
134    T: QrScalar + Zero + RealField,
135    R: DimMin<C, Output = C>,
136    C: Dim,
137{
138    /// get the effective rank of the matrix computed using the stratey
139    /// chosen at construction.
140    #[inline]
141    pub fn rank(&self) -> u16 {
142        self.rank as u16
143    }
144    /// obtain the permutation `P` such that the `A P = Q R` ,
145    /// meaning the column-permuted original matrix `A` is identical to
146    /// `Q R`. This function performs a small allocation.
147    pub fn p(&self) -> Permutation<C> {
148        Permutation::new(self.jpvt.clone())
149    }
150}
151
152impl<T, R, C> QrDecomposition<T, R, C> for ColPivQR<T, R, C>
153where
154    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
155    R: DimMin<C, Output = C>,
156    C: Dim,
157    T: Scalar + RealField + QrReal,
158{
159    fn __lapack_qr_ref(&self) -> &OMatrix<T, R, C> {
160        &self.qr
161    }
162
163    fn __lapack_tau_ref(&self) -> &OVector<T, DimMinimum<R, C>> {
164        &self.tau
165    }
166
167    fn solve_mut<C2: Dim, S, S2>(
168        &self,
169        x: &mut Matrix<T, C, C2, S2>,
170        b: Matrix<T, R, C2, S>,
171    ) -> Result<(), Error>
172    where
173        S: RawStorageMut<T, R, C2> + IsContiguous,
174        S2: RawStorageMut<T, C, C2> + IsContiguous,
175        T: Zero,
176    {
177        if self.nrows() < self.ncols() {
178            return Err(Error::Underdetermined);
179        }
180        let rank = self.rank();
181        qr_util::qr_solve_mut_with_rank_unpermuted(&self.qr, &self.tau, rank, x, b)?;
182        self.p().permute_rows_mut(x)?;
183        Ok(())
184    }
185}