use super::qr::{QrReal, QrScalar};
use crate::qr::QrDecomposition;
use crate::qr_util;
use crate::sealed::Sealed;
use na::{Const, IsContiguous, Matrix, OVector, RealField, Vector};
use nalgebra::storage::RawStorageMut;
use nalgebra::{DefaultAllocator, Dim, DimMin, DimMinimum, OMatrix, Scalar, allocator::Allocator};
use num::float::TotalOrder;
use num::{Float, Zero};
use rank::{RankDeterminationAlgorithm, calculate_rank};
pub use qr_util::Error;
mod permutation;
#[cfg(test)]
mod test;
pub use permutation::Permutation;
mod rank;
/// Column-pivoted QR decomposition `A P = Q R` of a matrix with at least as many
/// rows as columns, computed with LAPACK `?geqp3`.
///
/// Build one with [`ColPivQR::new`] or [`ColPivQR::with_rank_algo`]; the numerical
/// rank is determined once at construction time.
#[derive(Debug, Clone)]
pub struct ColPivQR<T, R, C>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
    T: Scalar,
    R: DimMin<C, Output = C>,
    C: Dim,
{
    // The input matrix as overwritten in place by `xgeqp3` (per the LAPACK `?geqp3`
    // convention: R on/above the diagonal, Householder vectors below it).
    qr: OMatrix<T, R, C>,
    // Householder scalar factors written by `xgeqp3`; one per min(rows, cols).
    tau: OVector<T, DimMinimum<R, C>>,
    // Column pivot indices written by `xgeqp3`; turned into a `Permutation` by `p()`.
    // NOTE(review): LAPACK reports these 1-based — presumably `Permutation` accounts for it; confirm.
    jpvt: OVector<i32, C>,
    // Numerical rank computed after factorization; exposed as `u16` through `rank()`.
    rank: i32,
}
// Implements the crate-internal `crate::sealed::Sealed` marker for `ColPivQR`.
// Presumably this is the sealed-trait pattern that lets `ColPivQR` implement
// crate-defined traits (such as `QrDecomposition`) that outside crates cannot — TODO confirm.
impl<T, R, C> Sealed for ColPivQR<T, R, C>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
    T: Scalar,
    R: DimMin<C, Output = C>,
    C: Dim,
{
}
impl<T, R, C> ColPivQR<T, R, C>
where
DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
T: QrScalar + Zero + RealField + TotalOrder + Float,
R: DimMin<C, Output = C>,
C: Dim,
{
pub fn new(m: OMatrix<T, R, C>) -> Result<Self, Error> {
Self::with_rank_algo(m, Default::default())
}
pub fn with_rank_algo(
mut m: OMatrix<T, R, C>,
rank_algo: RankDeterminationAlgorithm<T>,
) -> Result<Self, Error> {
let (nrows, ncols) = m.shape_generic();
if nrows.value() < ncols.value() {
return Err(Error::Underdetermined);
}
let mut tau: OVector<T, DimMinimum<R, C>> =
Vector::zeros_generic(nrows.min(ncols), Const::<1>);
let mut jpvt: OVector<i32, C> = Vector::zeros_generic(ncols, Const::<1>);
let lwork = unsafe {
T::xgeqp3_work_size(
nrows.value().try_into().expect("matrix dims out of bounds"),
ncols.value().try_into().expect("matrix dims out of bounds"),
m.as_mut_slice(),
nrows.value().try_into().expect("matrix dims out of bounds"),
jpvt.as_mut_slice(),
tau.as_mut_slice(),
)?
};
let mut work = vec![T::zero(); lwork as usize];
unsafe {
T::xgeqp3(
nrows.value() as i32,
ncols.value() as i32,
m.as_mut_slice(),
nrows.value() as i32,
jpvt.as_mut_slice(),
tau.as_mut_slice(),
&mut work,
lwork,
)?;
}
let rank: i32 = calculate_rank(&m, rank_algo)
.try_into()
.map_err(|_| Error::Dimensions)?;
Ok(Self {
qr: m,
rank,
tau,
jpvt,
})
}
}
impl<T, R, C> ColPivQR<T, R, C>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
    T: QrScalar + Zero + RealField,
    R: DimMin<C, Output = C>,
    C: Dim,
{
    /// Returns the numerical rank determined when the decomposition was built.
    ///
    /// The rank is stored as `i32` and narrowed to `u16` with an `as` cast.
    #[inline]
    pub fn rank(&self) -> u16 {
        let stored: i32 = self.rank;
        stored as u16
    }

    /// Returns the column permutation `P` chosen by the pivoted factorization.
    ///
    /// It is built from a fresh copy of the stored pivot indices, so `self` is
    /// left untouched.
    pub fn p(&self) -> Permutation<C> {
        let pivots = self.jpvt.clone();
        Permutation::new(pivots)
    }
}
impl<T, R, C> QrDecomposition<T, R, C> for ColPivQR<T, R, C>
where
    DefaultAllocator: Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C>,
    R: DimMin<C, Output = C>,
    C: Dim,
    T: Scalar + RealField + QrReal,
{
    /// Borrow of the packed matrix produced by the LAPACK factorization.
    fn __lapack_qr_ref(&self) -> &OMatrix<T, R, C> {
        &self.qr
    }

    /// Borrow of the Householder scalar factors produced alongside the packed matrix.
    fn __lapack_tau_ref(&self) -> &OVector<T, DimMinimum<R, C>> {
        &self.tau
    }

    /// Solves `A x = b` with the stored factorization, writing the result into `x`.
    ///
    /// The solve works in the pivoted column order using only `self.rank()`
    /// columns; the rows of `x` are then permuted by `self.p()` to undo the pivoting.
    /// NOTE(review): presumably this is a least-squares solve for tall `A` — confirm
    /// against `qr_util::qr_solve_mut_with_rank_unpermuted`.
    ///
    /// # Errors
    /// [`Error::Underdetermined`] if the factored matrix has fewer rows than
    /// columns; otherwise any error from the solve or the row permutation.
    fn solve_mut<C2: Dim, S, S2>(
        &self,
        x: &mut Matrix<T, C, C2, S2>,
        b: Matrix<T, R, C2, S>,
    ) -> Result<(), Error>
    where
        S: RawStorageMut<T, R, C2> + IsContiguous,
        S2: RawStorageMut<T, C, C2> + IsContiguous,
        T: Zero,
    {
        let has_too_few_rows = self.nrows() < self.ncols();
        if has_too_few_rows {
            return Err(Error::Underdetermined);
        }
        let column_order = self.p();
        qr_util::qr_solve_mut_with_rank_unpermuted(&self.qr, &self.tau, self.rank(), x, b)?;
        column_order.permute_rows_mut(x)?;
        Ok(())
    }
}