use crate::{DiagonalKind, LapackErrorCode, Side, Transposition, TriangularStructure, qr::QrReal};
use na::{
Dim, DimMin, DimMinimum, IsContiguous, Matrix, RawStorage, RawStorageMut, RealField, Vector,
};
use num::{ConstOne, Zero};
/// Errors reported by the QR helper routines in this module.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum Error {
/// A matrix or vector argument has a shape that is incompatible with the
/// other arguments of the call.
#[error("incorrect matrix dimensions")]
Dimensions,
/// The underlying LAPACK binding reported a failure. The `#[from]`
/// attribute lets `?` convert a `LapackErrorCode` into this variant.
#[error("Lapack returned with error: {0}")]
Lapack(#[from] LapackErrorCode),
/// The factorized matrix has fewer rows than columns. The solver in this
/// module also returns this variant when either dimension is zero.
#[error("QR decomposition for underdetermined systems not supported")]
Underdetermined,
/// The solver was asked to solve with a rank of zero.
#[error("Matrix has rank zero")]
ZeroRank,
}
/// Multiplies `b` in place from the left by the `Q` factor that is stored
/// implicitly in `qr` and `tau`.
///
/// The arguments are handed to `multiply_q_mut` together with `Side::Left`
/// and `Transposition::No`; under LAPACK's `?ormqr` conventions this
/// overwrites `b` with `Q * b` without ever forming `Q` explicitly.
///
/// # Errors
///
/// Returns [`Error::Dimensions`] when `b` does not have as many rows as `qr`,
/// or when `tau` does not contain `min(qr.nrows(), qr.ncols())` entries.
/// Errors raised by `multiply_q_mut` are passed through unchanged.
pub(crate) fn q_mul_mut<T, R1, C1, S1, C2, S2, S3>(
    qr: &Matrix<T, R1, C1, S1>,
    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
    b: &mut Matrix<T, R1, C2, S2>,
) -> Result<(), Error>
where
    T: QrReal + Zero + RealField,
    R1: DimMin<C1>,
    C1: Dim,
    S1: RawStorage<T, R1, C1> + IsContiguous,
    C2: Dim,
    S2: RawStorageMut<T, R1, C2> + IsContiguous,
    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
{
    let expected_tau_len = qr.nrows().min(qr.ncols());
    if b.nrows() != qr.nrows() || expected_tau_len != tau.len() {
        return Err(Error::Dimensions);
    }

    // SAFETY: the shapes of `qr`, `tau` and `b` were checked above, and
    // `multiply_q_mut` validates the LAPACK dimension arguments once more
    // before calling into the binding.
    unsafe { multiply_q_mut(qr, tau, b, Side::Left, Transposition::No) }
}
/// Multiplies `b` in place from the left by the transpose of the `Q` factor
/// that is stored implicitly in `qr` and `tau`.
///
/// The arguments are handed to `multiply_q_mut` together with `Side::Left`
/// and `Transposition::Transpose`; under LAPACK's `?ormqr` conventions this
/// overwrites `b` with `Q^T * b` without ever forming `Q` explicitly.
///
/// # Errors
///
/// Returns [`Error::Dimensions`] when `b` does not have as many rows as `qr`,
/// or when `tau` does not contain `min(qr.nrows(), qr.ncols())` entries.
/// Errors raised by `multiply_q_mut` are passed through unchanged.
pub(crate) fn q_tr_mul_mut<T, R1, C1, S1, C2, S2, S3>(
    qr: &Matrix<T, R1, C1, S1>,
    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
    b: &mut Matrix<T, R1, C2, S2>,
) -> Result<(), Error>
where
    T: QrReal + Zero + RealField,
    R1: DimMin<C1>,
    C1: Dim,
    S1: RawStorage<T, R1, C1> + IsContiguous,
    C2: Dim,
    S2: RawStorageMut<T, R1, C2> + IsContiguous,
    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
{
    if b.nrows() != qr.nrows() {
        return Err(Error::Dimensions);
    }
    if qr.ncols().min(qr.nrows()) != tau.len() {
        return Err(Error::Dimensions);
    }

    // SAFETY: the shapes of `qr`, `tau` and `b` were checked above, and
    // `multiply_q_mut` validates the LAPACK dimension arguments once more
    // before calling into the binding.
    unsafe { multiply_q_mut(qr, tau, b, Side::Left, Transposition::Transpose)? };
    Ok(())
}
/// Multiplies `b` in place from the right by the `Q` factor that is stored
/// implicitly in `qr` and `tau`.
///
/// The arguments are handed to `multiply_q_mut` together with `Side::Right`
/// and `Transposition::No`; under LAPACK's `?ormqr` conventions this
/// overwrites `b` with `b * Q` without ever forming `Q` explicitly.
///
/// # Errors
///
/// Returns [`Error::Dimensions`] when the column count of `b` differs from
/// the row count of `qr`, or when `tau` does not contain
/// `min(qr.nrows(), qr.ncols())` entries. Errors raised by `multiply_q_mut`
/// are passed through unchanged.
pub(crate) fn mul_q_mut<T, R1, C1, S1, R2, S2, S3>(
    qr: &Matrix<T, R1, C1, S1>,
    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
    b: &mut Matrix<T, R2, R1, S2>,
) -> Result<(), Error>
where
    T: QrReal + Zero + RealField,
    R1: DimMin<C1>,
    C1: Dim,
    S1: RawStorage<T, R1, C1> + IsContiguous,
    R2: Dim,
    S2: RawStorageMut<T, R2, R1> + IsContiguous,
    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
{
    let expected_tau_len = qr.nrows().min(qr.ncols());
    if b.ncols() != qr.nrows() || expected_tau_len != tau.len() {
        return Err(Error::Dimensions);
    }

    // SAFETY: the shapes of `qr`, `tau` and `b` were checked above, and
    // `multiply_q_mut` validates the LAPACK dimension arguments once more
    // before calling into the binding.
    unsafe { multiply_q_mut(qr, tau, b, Side::Right, Transposition::No) }
}
/// Multiplies `b` in place from the right by the transpose of the `Q` factor
/// that is stored implicitly in `qr` and `tau`.
///
/// The arguments are handed to `multiply_q_mut` together with `Side::Right`
/// and `Transposition::Transpose`; under LAPACK's `?ormqr` conventions this
/// overwrites `b` with `b * Q^T` without ever forming `Q` explicitly.
///
/// # Errors
///
/// Returns [`Error::Dimensions`] when the column count of `b` differs from
/// the row count of `qr`, or when `tau` does not contain
/// `min(qr.nrows(), qr.ncols())` entries. Errors raised by `multiply_q_mut`
/// are passed through unchanged.
pub(crate) fn mul_q_tr_mut<T, R1, C1, S1, R2, S2, S3>(
    qr: &Matrix<T, R1, C1, S1>,
    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
    b: &mut Matrix<T, R2, R1, S2>,
) -> Result<(), Error>
where
    T: QrReal + Zero + RealField,
    R1: DimMin<C1>,
    C1: Dim,
    S1: RawStorage<T, R1, C1> + IsContiguous,
    R2: Dim,
    S2: RawStorageMut<T, R2, R1> + IsContiguous,
    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
{
    let expected_tau_len = qr.nrows().min(qr.ncols());
    if b.ncols() != qr.nrows() || expected_tau_len != tau.len() {
        return Err(Error::Dimensions);
    }

    // SAFETY: the shapes of `qr`, `tau` and `b` were checked above, and
    // `multiply_q_mut` validates the LAPACK dimension arguments once more
    // before calling into the binding.
    unsafe { multiply_q_mut(qr, tau, b, Side::Right, Transposition::Transpose) }
}
/// Solves a linear system from an existing QR factorization, using only the
/// leading `rank` rows of the transformed right-hand side, and writes the
/// result into `x`.
///
/// The steps are:
/// 1. `b` is overwritten with `Q^T * b` through `q_tr_mul_mut`.
/// 2. Rows `rank..` of `x` are set to zero and the first `rank` rows of the
///    transformed `b` are copied into `x`.
/// 3. `xtrtrs` is called on `x` with an upper, non-transposed, non-unit
///    triangular operand of order `rank` taken from `qr`; under LAPACK's
///    `?trtrs` conventions this solves the triangular system in place.
///
/// NOTE(review): the `_unpermuted` suffix suggests that callers working with a
/// column-pivoted factorization still have to apply the permutation to `x`
/// themselves — this is not visible in this function, confirm at call sites.
///
/// # Errors
///
/// * [`Error::Dimensions`] if `b` and `qr` have different row counts, if `x`
///   and `b` have different column counts, or if `x` does not have
///   `qr.ncols()` rows.
/// * [`Error::Underdetermined`] if `qr` has fewer rows than columns or if
///   either dimension of `qr` is zero.
/// * [`Error::ZeroRank`] if `rank` is zero (checked after `b` was transformed).
/// * Any error from `q_tr_mul_mut`, and any `LapackErrorCode` from `xtrtrs`
///   wrapped in [`Error::Lapack`].
///
/// # Panics
///
/// Panics if a dimension does not fit the integer type expected by the LAPACK
/// binding. In debug builds an assertion additionally checks that `rank` does
/// not exceed `min(qr.nrows(), qr.ncols())`.
pub(crate) fn qr_solve_mut_with_rank_unpermuted<T, R1, C1, S1, C2: Dim, S3, S2, S4>(
    qr: &Matrix<T, R1, C1, S1>,
    tau: &Vector<T, DimMinimum<R1, C1>, S4>,
    rank: u16,
    x: &mut Matrix<T, C1, C2, S2>,
    mut b: Matrix<T, R1, C2, S3>,
) -> Result<(), Error>
where
    T: QrReal + Zero + RealField,
    R1: DimMin<C1>,
    C1: Dim,
    S1: RawStorage<T, R1, C1> + IsContiguous,
    S3: RawStorageMut<T, R1, C2> + IsContiguous,
    S2: RawStorageMut<T, C1, C2> + IsContiguous,
    S4: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
{
    let (qr_rows, qr_cols) = (qr.nrows(), qr.ncols());

    if b.nrows() != qr_rows {
        return Err(Error::Dimensions);
    }
    if qr_rows < qr_cols || qr_rows == 0 || qr_cols == 0 {
        return Err(Error::Underdetermined);
    }
    if x.ncols() != b.ncols() || x.nrows() != qr_cols {
        return Err(Error::Dimensions);
    }

    // Transform the right-hand side: b <- Q^T * b.
    q_tr_mul_mut(qr, tau, &mut b)?;

    if rank == 0 {
        return Err(Error::ZeroRank);
    }
    debug_assert!(rank as usize <= qr.ncols().min(qr.nrows()));

    let rank_rows = usize::from(rank);
    let (x_rows, x_cols) = (x.nrows(), x.ncols());

    // Rows beyond the rank are not touched by the triangular solve below, so
    // they are cleared explicitly.
    if rank_rows < qr_cols {
        for entry in x
            .view_mut((rank_rows, 0), (x_rows - rank_rows, x_cols))
            .iter_mut()
        {
            entry.set_zero();
        }
    }

    // Seed the solve with the leading `rank` rows of the transformed `b`.
    x.view_mut((0, 0), (rank_rows, x_cols))
        .copy_from(&b.view((0, 0), (rank_rows, x_cols)));

    let ldb: i32 = x_rows
        .try_into()
        .expect("integer dimensions out of bounds");
    let order = rank.try_into().expect("rank out of bounds");
    let nrhs = x_cols
        .try_into()
        .expect("integer dimensions out of bounds");
    let lda = qr_rows
        .try_into()
        .expect("integer dimensions out of bounds");

    // SAFETY: `qr` and `x` are backed by contiguous storage (`IsContiguous`),
    // their shapes were validated above, and the leading dimensions passed
    // are the actual row counts of the two matrices.
    unsafe {
        T::xtrtrs(
            TriangularStructure::Upper,
            Transposition::No,
            DiagonalKind::NonUnit,
            order,
            nrhs,
            qr.as_slice(),
            lda,
            x.as_mut_slice(),
            ldb,
        )?;
    }
    Ok(())
}
/// Shared implementation behind the `Q`-multiplication wrappers: validates the
/// dimension arguments, queries the workspace size and then calls `xormqr`.
///
/// `side` and `transpose` are forwarded as given; under LAPACK's `?ormqr`
/// conventions they select between `Q * mat`, `Q^T * mat`, `mat * Q` and
/// `mat * Q^T`, with the result overwriting `mat`.
///
/// # Errors
///
/// Returns [`Error::Dimensions`] when
/// * `tau.len()` differs from `qr.ncols()`,
/// * the side-relevant dimension of `mat` (rows for `Side::Left`, columns for
///   `Side::Right`) is smaller than `tau.len()` or larger than `qr.nrows()`.
///
/// A failing workspace query or `xormqr` call is returned as
/// [`Error::Lapack`].
///
/// # Panics
///
/// Panics if any dimension does not fit the integer type expected by the
/// LAPACK binding.
///
/// # Safety
///
/// NOTE(review): this function checks the dimension arguments itself; the
/// remaining obligations are presumably those of the underlying `xormqr`
/// binding, which are not documented in this file — confirm against the
/// `QrReal` trait.
#[inline]
unsafe fn multiply_q_mut<T, R1, C1, S1, R2, C2, S2, S3>(
    qr: &Matrix<T, R1, C1, S1>,
    tau: &Vector<T, DimMinimum<R1, C1>, S3>,
    mat: &mut Matrix<T, R2, C2, S2>,
    side: Side,
    transpose: Transposition,
) -> Result<(), Error>
where
    T: QrReal,
    R1: DimMin<C1>,
    C1: Dim,
    S2: RawStorageMut<T, R2, C2> + IsContiguous,
    R2: Dim,
    C2: Dim,
    S1: IsContiguous + RawStorage<T, R1, C1>,
    S3: RawStorage<T, <R1 as DimMin<C1>>::Output> + IsContiguous,
{
    // Integer arguments for the LAPACK call; all of them must be read before
    // `mat` is mutably borrowed as a slice further down.
    let lda = qr
        .nrows()
        .try_into()
        .expect("integer dimension out of range");
    let m = mat
        .nrows()
        .try_into()
        .expect("integer dimension out of range");
    let n = mat
        .ncols()
        .try_into()
        .expect("integer dimension out of range");
    let k = tau
        .len()
        .try_into()
        .expect("integer dimension out of range");
    let ldc = mat
        .nrows()
        .try_into()
        .expect("integer dimension out of range");

    if k as usize != qr.ncols() {
        return Err(Error::Dimensions);
    }

    // The side decides which dimension of `mat` has to be compatible with Q.
    let q_side_fits = match side {
        Side::Left => m >= k && lda >= m,
        Side::Right => n >= k && lda >= n,
    };
    if !q_side_fits || ldc < m {
        return Err(Error::Dimensions);
    }

    let a = qr.as_slice();
    let tau = tau.as_slice();
    let c = mat.as_mut_slice();

    // First ask for the required workspace size, then run the actual update.
    let lwork = unsafe { T::xormqr_work_size(side, transpose, m, n, k, a, lda, tau, c, ldc)? };
    let mut work = vec![T::zero(); lwork as usize];
    unsafe {
        T::xormqr(side, transpose, m, n, k, a, lda, tau, c, ldc, &mut work, lwork)?;
    }
    Ok(())
}
/// Multiplies `b` in place from the left by the upper-triangular factor stored
/// in `qr`, optionally transposed.
///
/// The call is forwarded to `multiply_r_mut` with `Side::Left`; `transpose`
/// is passed on unchanged, so under BLAS `?trmm` conventions `b` becomes
/// `R * b` or `R^T * b`.
///
/// # Errors
///
/// * [`Error::Underdetermined`] if `qr` has fewer rows than columns.
/// * [`Error::Dimensions`] if `b` does not have `qr.ncols()` rows, or if
///   `multiply_r_mut` rejects the shapes.
pub fn r_xx_mul_mut<T, R1, C1, S1, C2, S2>(
    qr: &Matrix<T, R1, C1, S1>,
    transpose: Transposition,
    b: &mut Matrix<T, C1, C2, S2>,
) -> Result<(), Error>
where
    T: QrReal + ConstOne,
    R1: Dim,
    C1: Dim,
    C2: Dim,
    S1: RawStorage<T, R1, C1> + IsContiguous,
    S2: RawStorageMut<T, C1, C2> + IsContiguous,
{
    let (qr_rows, qr_cols) = (qr.nrows(), qr.ncols());

    if qr_rows < qr_cols {
        return Err(Error::Underdetermined);
    }
    if qr_cols != b.nrows() {
        return Err(Error::Dimensions);
    }

    multiply_r_mut(qr, transpose, Side::Left, b)
}
/// Multiplies `b` in place from the right by the upper-triangular factor
/// stored in `qr`, optionally transposed.
///
/// The call is forwarded to `multiply_r_mut` with `Side::Right`; `transpose`
/// is passed on unchanged, so under BLAS `?trmm` conventions `b` becomes
/// `b * R` or `b * R^T`.
///
/// # Errors
///
/// * [`Error::Underdetermined`] if `qr` has fewer rows than columns.
/// * [`Error::Dimensions`] if `b` does not have `qr.ncols()` columns, or if
///   `multiply_r_mut` rejects the shapes.
pub fn mul_r_xx_mut<T, R1, C1, S1, R2, S2>(
    qr: &Matrix<T, R1, C1, S1>,
    transpose: Transposition,
    b: &mut Matrix<T, R2, C1, S2>,
) -> Result<(), Error>
where
    T: QrReal + ConstOne,
    R1: Dim,
    C1: Dim,
    R2: Dim,
    S1: RawStorage<T, R1, C1> + IsContiguous,
    S2: RawStorageMut<T, R2, C1> + IsContiguous,
{
    let (qr_rows, qr_cols) = (qr.nrows(), qr.ncols());

    if qr_rows < qr_cols {
        return Err(Error::Underdetermined);
    }
    if b.ncols() != qr_cols {
        return Err(Error::Dimensions);
    }

    multiply_r_mut(qr, transpose, Side::Right, b)
}
/// Shared implementation behind `r_xx_mul_mut` and `mul_r_xx_mut`: validates
/// the shapes and calls `xtrmm` with an upper, non-unit triangular operand
/// taken from `qr` and a scaling factor of `T::ONE`.
///
/// `side` and `transpose` are forwarded as given; under BLAS `?trmm`
/// conventions `mat` is overwritten with the triangular product.
///
/// # Errors
///
/// Returns [`Error::Dimensions`] when `qr` has zero rows, or when the
/// side-relevant dimension of `mat` (rows for `Side::Left`, columns for
/// `Side::Right`) exceeds `qr.nrows()` or differs from `qr.ncols()`.
///
/// # Panics
///
/// Panics if a dimension does not fit into an `i32`.
#[inline]
fn multiply_r_mut<T, R1, C1, S1, R2, C2, S2>(
    qr: &Matrix<T, R1, C1, S1>,
    transpose: Transposition,
    side: Side,
    mat: &mut Matrix<T, R2, C2, S2>,
) -> Result<(), Error>
where
    T: QrReal + ConstOne,
    R1: Dim,
    C1: Dim,
    S2: RawStorageMut<T, R2, C2> + IsContiguous,
    R2: Dim,
    C2: Dim,
    S1: IsContiguous + RawStorage<T, R1, C1>,
{
    let m: i32 = mat
        .nrows()
        .try_into()
        .expect("integer dimensions out of bounds");
    let n: i32 = mat
        .ncols()
        .try_into()
        .expect("integer dimensions out of bounds");
    let lda: i32 = qr
        .nrows()
        .try_into()
        .expect("integer dimensions out of bounds");
    // The leading dimension of `mat` is its row count, which is `m`.
    let ldb = m;

    // Order of the triangular operand as seen from the chosen side.
    let triangle_order = match side {
        Side::Left => m,
        Side::Right => n,
    };
    if lda == 0 || lda < triangle_order || qr.ncols() != triangle_order as usize {
        return Err(Error::Dimensions);
    }

    // SAFETY: both matrices are backed by contiguous storage (`IsContiguous`),
    // the dimension arguments were validated above, and the leading
    // dimensions passed are the actual row counts of `qr` and `mat`.
    unsafe {
        T::xtrmm(
            side,
            TriangularStructure::Upper,
            transpose,
            DiagonalKind::NonUnit,
            m,
            n,
            T::ONE,
            qr.as_slice(),
            lda,
            mat.as_mut_slice(),
            ldb,
        );
    }
    Ok(())
}