1use super::qr::{QrReal, QrScalar};
2use crate::qr::QrDecomposition;
3use crate::qr_util;
4use crate::sealed::Sealed;
5use na::{Const, IsContiguous, Matrix, OVector, RealField, Vector};
6use nalgebra::storage::RawStorageMut;
7use nalgebra::{DefaultAllocator, Dim, DimMin, DimMinimum, OMatrix, Scalar, allocator::Allocator};
8use num::float::TotalOrder;
9use num::{Float, Zero};
10use rank::{RankDeterminationAlgorithm, calculate_rank};
11
12pub use qr_util::Error;
13mod permutation;
14#[cfg(test)]
15mod test;
16pub use permutation::Permutation;
17mod rank;
19
/// Column-pivoted QR decomposition `A * P = Q * R` of a matrix with at least
/// as many rows as columns, computed with LAPACK's `xgeqp3`
/// (see [`ColPivQR::with_rank_algo`]).
///
/// The factors are kept exactly as LAPACK wrote them, together with the
/// numerical rank chosen when the decomposition was constructed.
#[derive(Debug, Clone)]
pub struct ColPivQR<T, R, C>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
    T: Scalar,
    R: DimMin<C, Output = C>,
    C: Dim,
{
    /// The input matrix after being overwritten in place by `xgeqp3`. Per the
    /// LAPACK `?geqp3` documentation, `R` sits in the upper triangle and the
    /// Householder vectors describing `Q` sit below the diagonal.
    qr: OMatrix<T, R, C>,
    /// Scalar factors of the elementary reflectors, as written by `xgeqp3`.
    tau: OVector<T, DimMinimum<R, C>>,
    /// Column pivot indices as written by `xgeqp3`. LAPACK reports these
    /// 1-based; NOTE(review): confirm `Permutation::new` expects that.
    jpvt: OVector<i32, C>,
    /// Numerical rank as computed by `calculate_rank` during construction.
    rank: i32,
}
47
48impl<T, R, C> Sealed for ColPivQR<T, R, C>
49where
50 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
51 T: Scalar,
52 R: DimMin<C, Output = C>,
53 C: Dim,
54{
55}
56
57impl<T, R, C> ColPivQR<T, R, C>
59where
60 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
61 T: QrScalar + Zero + RealField + TotalOrder + Float,
62 R: DimMin<C, Output = C>,
63 C: Dim,
64{
65 pub fn new(m: OMatrix<T, R, C>) -> Result<Self, Error> {
68 Self::with_rank_algo(m, Default::default())
69 }
70
71 pub fn with_rank_algo(
75 mut m: OMatrix<T, R, C>,
76 rank_algo: RankDeterminationAlgorithm<T>,
77 ) -> Result<Self, Error> {
78 let (nrows, ncols) = m.shape_generic();
79
80 if nrows.value() < ncols.value() {
81 return Err(Error::Underdetermined);
82 }
83
84 let mut tau: OVector<T, DimMinimum<R, C>> =
85 Vector::zeros_generic(nrows.min(ncols), Const::<1>);
86 let mut jpvt: OVector<i32, C> = Vector::zeros_generic(ncols, Const::<1>);
87
88 let lwork = unsafe {
91 T::xgeqp3_work_size(
92 nrows.value().try_into().expect("matrix dims out of bounds"),
93 ncols.value().try_into().expect("matrix dims out of bounds"),
94 m.as_mut_slice(),
95 nrows.value().try_into().expect("matrix dims out of bounds"),
96 jpvt.as_mut_slice(),
97 tau.as_mut_slice(),
98 )?
99 };
100
101 let mut work = vec![T::zero(); lwork as usize];
102
103 unsafe {
106 T::xgeqp3(
107 nrows.value() as i32,
108 ncols.value() as i32,
109 m.as_mut_slice(),
110 nrows.value() as i32,
111 jpvt.as_mut_slice(),
112 tau.as_mut_slice(),
113 &mut work,
114 lwork,
115 )?;
116 }
117
118 let rank: i32 = calculate_rank(&m, rank_algo)
119 .try_into()
120 .map_err(|_| Error::Dimensions)?;
121
122 Ok(Self {
123 qr: m,
124 rank,
125 tau,
126 jpvt,
127 })
128 }
129}
130
131impl<T, R, C> ColPivQR<T, R, C>
132where
133 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
134 T: QrScalar + Zero + RealField,
135 R: DimMin<C, Output = C>,
136 C: Dim,
137{
138 #[inline]
141 pub fn rank(&self) -> u16 {
142 self.rank as u16
143 }
144 pub fn p(&self) -> Permutation<C> {
148 Permutation::new(self.jpvt.clone())
149 }
150}
151
152impl<T, R, C> QrDecomposition<T, R, C> for ColPivQR<T, R, C>
153where
154 DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
155 R: DimMin<C, Output = C>,
156 C: Dim,
157 T: Scalar + RealField + QrReal,
158{
159 fn __lapack_qr_ref(&self) -> &OMatrix<T, R, C> {
160 &self.qr
161 }
162
163 fn __lapack_tau_ref(&self) -> &OVector<T, DimMinimum<R, C>> {
164 &self.tau
165 }
166
167 fn solve_mut<C2: Dim, S, S2>(
168 &self,
169 x: &mut Matrix<T, C, C2, S2>,
170 b: Matrix<T, R, C2, S>,
171 ) -> Result<(), Error>
172 where
173 S: RawStorageMut<T, R, C2> + IsContiguous,
174 S2: RawStorageMut<T, C, C2> + IsContiguous,
175 T: Zero,
176 {
177 if self.nrows() < self.ncols() {
178 return Err(Error::Underdetermined);
179 }
180 let rank = self.rank();
181 qr_util::qr_solve_mut_with_rank_unpermuted(&self.qr, &self.tau, rank, x, b)?;
182 self.p().permute_rows_mut(x)?;
183 Ok(())
184 }
185}