nalgebra_lapack/colpiv_qr/permutation.rs

1use na::{
2    DefaultAllocator, Dim, IsContiguous, Matrix, OVector, RawStorageMut, allocator::Allocator,
3};
4
5use super::Error;
6use crate::qr::QrScalar;
7
8#[cfg(test)]
9mod test;
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12/// describes an orthogonal permutation matrix `P` that can be used to permute
13/// the rows or columns of an appropriately shaped matrix. Due to LAPACK
14/// internals, an instance must be mutably borrowed when applying the
15/// permutation, but the permutation itself will remain unchanged on success.
16///
17/// **Note**: the associated functions take `Self` by mutable reference due
18/// to the LAPACK internal implementation of the permutation logic.
19pub struct Permutation<D>
20where
21    D: Dim,
22    DefaultAllocator: Allocator<D>,
23{
24    jpvt: OVector<i32, D>,
25}
26
27impl<'a, D> Permutation<D>
28where
29    D: Dim,
30    DefaultAllocator: Allocator<D>,
31{
32    /// Apply the column permutation to a matrix `A`, equivalent to `P A`.
33    #[inline]
34    pub fn permute_cols_mut<T, R, S>(&mut self, mat: &mut Matrix<T, R, D, S>) -> Result<(), Error>
35    where
36        R: Dim,
37        S: RawStorageMut<T, R, D> + IsContiguous,
38        T: QrScalar,
39    {
40        //@note(geo-ant) due to the LAPACK internal logic in the jpvt vector, we have
41        // to invert the forward/backward argument here
42        self.apply_cols_mut(false, mat)
43    }
44
45    /// Apply the inverse column permutation to a matrix `A`, equivalent to `P^T A`.
46    #[inline]
47    pub fn inv_permute_cols_mut<T, R, S>(
48        &mut self,
49        mat: &mut Matrix<T, R, D, S>,
50    ) -> Result<(), Error>
51    where
52        R: Dim,
53        S: RawStorageMut<T, R, D> + IsContiguous,
54        T: QrScalar,
55    {
56        self.apply_cols_mut(true, mat)
57    }
58
59    /// Apply the row permutation to a matrix `A`, equivalent to `P A`.
60    #[inline]
61    pub fn permute_rows_mut<T, C, S>(&mut self, mat: &mut Matrix<T, D, C, S>) -> Result<(), Error>
62    where
63        C: Dim,
64        S: RawStorageMut<T, D, C> + IsContiguous,
65        T: QrScalar,
66    {
67        self.apply_rows_mut(false, mat)
68    }
69
70    /// Apply the inverse row permutation to a matrix `A`, equivalent to `P^T A`.
71    #[inline]
72    pub fn inv_permute_rows_mut<T, C, S>(
73        &mut self,
74        mat: &mut Matrix<T, D, C, S>,
75    ) -> Result<(), Error>
76    where
77        C: Dim,
78        S: RawStorageMut<T, D, C> + IsContiguous,
79        T: QrScalar,
80    {
81        self.apply_rows_mut(true, mat)
82    }
83
84    #[inline]
85    /// a thin wrapper around lapacks xLAPMR
86    fn apply_rows_mut<T, C, S>(
87        &mut self,
88        forward: bool,
89        mat: &mut Matrix<T, D, C, S>,
90    ) -> Result<(), Error>
91    where
92        C: Dim,
93        S: RawStorageMut<T, D, C> + IsContiguous,
94        T: QrScalar,
95    {
96        if mat.nrows() != self.jpvt.len() {
97            return Err(Error::Dimensions);
98        }
99
100        let m = mat
101            .nrows()
102            .try_into()
103            .expect("matrix dimensions out of bounds");
104        let n = mat
105            .ncols()
106            .try_into()
107            .expect("matrix dimensions out of bounds");
108
109        // SAFETY: inputs according to spec, see
110        // https://www.netlib.org/lapack/explore-html/d3/d10/group__lapmr.html
111        unsafe {
112            T::xlapmr(
113                forward,
114                m,
115                n,
116                mat.as_mut_slice(),
117                m,
118                self.jpvt.as_mut_slice(),
119            )?;
120        }
121
122        Ok(())
123    }
124
125    #[inline]
126    /// a thin wrapper around LAPACKS xLAMPT
127    fn apply_cols_mut<T, R, S>(
128        &mut self,
129        forward: bool,
130        mat: &mut Matrix<T, R, D, S>,
131    ) -> Result<(), Error>
132    where
133        R: Dim,
134        S: RawStorageMut<T, R, D> + IsContiguous,
135        T: QrScalar,
136    {
137        if mat.ncols() != self.jpvt.len() {
138            return Err(Error::Dimensions);
139        }
140        let m = mat
141            .nrows()
142            .try_into()
143            .expect("matrix dimensions out of bounds");
144        let n = mat
145            .ncols()
146            .try_into()
147            .expect("matrix dimensions out of bounds");
148
149        // SAFETY: arguments provided according to spec
150        // https://www.netlib.org/lapack/explore-html/d0/dcb/group__lapmt.html
151        unsafe {
152            T::xlapmt(
153                forward,
154                m,
155                n,
156                mat.as_mut_slice(),
157                m,
158                self.jpvt.as_mut_slice(),
159            )?;
160        }
161        Ok(())
162    }
163
164    pub(crate) fn new(jpvt: OVector<i32, D>) -> Self {
165        Self { jpvt }
166    }
167}