use crate::linalg::{split_two_col_slices, LinalgError};
use crate::matrix::vector::Vector;
use crate::traits::{LinalgScalar, MatrixMut, MatrixRef};
use crate::Matrix;
use num_traits::Zero;
/// Runs an unpivoted Householder QR factorization on `a`, overwriting it.
///
/// After a successful call:
/// * the diagonal and everything above it hold the triangular factor,
/// * below the diagonal, column `c` holds the tail of the `c`-th reflector
///   vector, divided by its leading entry (so the leading entry is an
///   implicit one),
/// * `tau[c]` holds the scalar `v0 / sigma` that is multiplied into the
///   projection whenever that reflector is applied.
///
/// # Errors
///
/// Returns [`LinalgError::Singular`] when the squared norm of a column,
/// measured from the diagonal downwards, is below `T::lepsilon()`.
///
/// # Panics
///
/// Panics if `a` has fewer rows than columns or if `tau.len()` is not
/// `min(nrows, ncols)`.
pub fn qr_in_place<T: LinalgScalar>(
    a: &mut impl MatrixMut<T>,
    tau: &mut [T],
) -> Result<(), LinalgError> {
    let rows = a.nrows();
    let cols = a.ncols();
    let steps = rows.min(cols);
    assert!(rows >= cols, "QR decomposition requires M >= N");
    assert_eq!(tau.len(), steps, "tau length must equal min(M, N)");

    for pivot in 0..steps {
        // Squared 2-norm of the pivot column from the diagonal downwards.
        let tail_norm_sq = a
            .col_as_slice(pivot, pivot)
            .iter()
            .fold(<T::Real as Zero>::zero(), |acc, &entry| {
                acc + (entry * entry.conj()).re()
            });
        if tail_norm_sq < T::lepsilon() {
            return Err(LinalgError::Singular);
        }
        let tail_norm = tail_norm_sq.lsqrt();

        // `sigma` carries the phase of the diagonal entry so that forming
        // `diag + sigma` never cancels; a (near-)zero diagonal falls back to
        // the plain norm.
        let diag = *a.get(pivot, pivot);
        let diag_abs = diag.modulus();
        let sigma = if diag_abs < T::lepsilon() {
            T::from_real(tail_norm)
        } else {
            T::from_real(tail_norm) * (diag / T::from_real(diag_abs))
        };

        let lead = diag + sigma;
        *a.get_mut(pivot, pivot) = lead;
        let coeff = lead / sigma;
        tau[pivot] = coeff;

        // Normalise the stored reflector tail by its leading entry.
        for entry in a.col_as_mut_slice(pivot, pivot + 1).iter_mut() {
            *entry = *entry / lead;
        }

        // Apply the reflector to every column to the right of the pivot.
        for target in (pivot + 1)..cols {
            let head = *a.get(pivot, target);
            let inner = {
                let (reflector, column) = split_two_col_slices(a, pivot, target, pivot + 1);
                reflector
                    .iter()
                    .zip(column.iter())
                    .fold(head, |acc, (&v, &x)| acc + v.conj() * x)
            };
            let projection = inner * coeff;

            // The implicit leading one of the reflector handles the pivot row.
            *a.get_mut(pivot, target) = head - projection;
            let (reflector, column) = split_two_col_slices(a, pivot, target, pivot + 1);
            crate::simd::axpy_neg_dispatch(column, projection, reflector);
        }

        // The diagonal was only borrowed as scratch space for `lead`; it
        // finally receives the triangular factor's entry.
        *a.get_mut(pivot, pivot) = T::zero() - sigma;
    }
    Ok(())
}
/// Householder QR factorization of a fixed-size `M x N` matrix.
///
/// Built by `QrDecomposition::new`, which copies the input matrix and runs
/// `qr_in_place` on the copy.
#[derive(Debug)]
pub struct QrDecomposition<T, const M: usize, const N: usize> {
/// Packed factorization as left by `qr_in_place`: `r()` reads the upper
/// triangle (diagonal included), while `q()` and `solve()` read the entries
/// below the diagonal as reflector tails.
qr: Matrix<T, M, N>,
/// One reflector coefficient per column, filled in by `qr_in_place`.
tau: [T; N],
}
impl<T: LinalgScalar, const M: usize, const N: usize> QrDecomposition<T, M, N> {
    /// Factorizes a copy of `a` with [`qr_in_place`].
    ///
    /// # Errors
    ///
    /// Propagates [`LinalgError::Singular`] from [`qr_in_place`].
    ///
    /// # Panics
    ///
    /// Panics if `M < N`.
    pub fn new(a: &Matrix<T, M, N>) -> Result<Self, LinalgError> {
        assert!(M >= N, "QR decomposition requires M >= N");
        let mut qr = *a;
        let mut tau = [T::zero(); N];
        qr_in_place(&mut qr, &mut tau)?;
        Ok(Self { qr, tau })
    }

    /// Returns the `N x N` upper-triangular factor, copied out of the upper
    /// triangle of the packed storage; entries below the diagonal are zero.
    pub fn r(&self) -> Matrix<T, N, N> {
        let mut r = Matrix::<T, N, N>::zeros();
        for i in 0..N {
            for j in i..N {
                r[(i, j)] = self.qr[(i, j)];
            }
        }
        r
    }

    /// Builds the thin `M x N` orthogonal factor.
    ///
    /// Starts from the first `N` columns of the identity and applies the
    /// stored reflectors from the last to the first.
    pub fn q(&self) -> Matrix<T, M, N> {
        let mut q = Matrix::<T, M, N>::zeros();
        for i in 0..N {
            q[(i, i)] = T::one();
        }
        for col in (0..N).rev() {
            let tau_val = self.tau[col];
            // Reflector tail; its leading entry is an implicit one.
            let v_slice = self.qr.col_as_slice(col, col + 1);
            // Columns left of `col` are still unit vectors with zeros in rows
            // `col..`, so this reflector leaves them untouched.
            for j in col..N {
                let mut dot = q[(col, j)];
                let q_j_slice = q.col_as_slice(j, col + 1);
                for idx in 0..v_slice.len() {
                    dot = dot + v_slice[idx].conj() * q_j_slice[idx];
                }
                dot = dot * tau_val;
                q[(col, j)] = q[(col, j)] - dot;
                let q_j_slice = q.col_as_mut_slice(j, col + 1);
                // NOTE(review): presumably performs `q_j -= dot * v`; confirm
                // against `crate::simd`.
                crate::simd::axpy_neg_dispatch(q_j_slice, dot, v_slice);
            }
        }
        q
    }

    /// Solves `A x = b` using the stored factorization.
    ///
    /// The reflectors are applied to a copy of `b` in factorization order,
    /// then the leading `N` entries are back-substituted against the upper
    /// triangle. For `M > N` this is used as a least-squares solve.
    pub fn solve(&self, b: &Vector<T, M>) -> Vector<T, N> {
        let mut qtb = [T::zero(); M];
        for i in 0..M {
            qtb[i] = b[i];
        }
        for col in 0..N {
            let tau_val = self.tau[col];
            let v_slice = self.qr.col_as_slice(col, col + 1);
            let mut dot = qtb[col];
            for idx in 0..v_slice.len() {
                dot = dot + v_slice[idx].conj() * qtb[col + 1 + idx];
            }
            dot = dot * tau_val;
            qtb[col] = qtb[col] - dot;
            crate::simd::axpy_neg_dispatch(&mut qtb[col + 1..], dot, v_slice);
        }

        // Back substitution on the upper triangle.
        let mut x = [T::zero(); N];
        for i in (0..N).rev() {
            let mut sum = qtb[i];
            for j in (i + 1)..N {
                sum = sum - self.qr[(i, j)] * x[j];
            }
            x[i] = sum / self.qr[(i, i)];
        }
        Vector::from_array(x)
    }

    /// Determinant of the factorized matrix (square matrices only).
    ///
    /// `det(A) = det(Q) * det(R)`. `det(R)` is the product of the stored
    /// diagonal. `qr_in_place` applies exactly one Householder reflector per
    /// column, each with determinant `-1`, so `det(Q) = (-1)^N` and the
    /// product has to be negated when `N` is odd.
    ///
    /// # Panics
    ///
    /// Panics if `M != N`.
    pub fn det(&self) -> T {
        assert_eq!(M, N, "determinant requires a square matrix");
        let mut d = T::one();
        for i in 0..N {
            d = d * self.qr[(i, i)];
        }
        // Account for det(Q) = (-1)^N; without this the sign is wrong for
        // every odd-sized matrix (a 1x1 `[a]` would yield `-a`).
        if N % 2 == 1 {
            d = T::zero() - d;
        }
        d
    }
}
/// Runs a column-pivoted Householder QR factorization on `a`, overwriting it.
///
/// At each step the not-yet-processed column with the largest tracked squared
/// norm is swapped into the pivot position (ties keep the earliest column),
/// and the swap is mirrored in `perm`. On return `perm[k]` is the index of
/// the original column that ended up in position `k`.
///
/// The packed layout matches [`qr_in_place`]: the triangular factor sits on
/// and above the diagonal, reflector tails sit below it, and `tau` holds the
/// reflector coefficients. Unlike [`qr_in_place`], a pivot column whose
/// remaining squared norm is below `T::lepsilon()` is not an error: its
/// `tau` entry is set to zero and the column is left as is.
///
/// # Panics
///
/// Panics if `a` has fewer rows than columns, if `tau.len()` is not
/// `min(nrows, ncols)`, if `perm.len()` is not `ncols`, or if `ncols > 64`
/// (the norm table is a fixed 64-entry stack array).
pub fn qr_col_pivot_in_place<T: LinalgScalar>(
    a: &mut impl MatrixMut<T>,
    tau: &mut [T],
    perm: &mut [usize],
) {
    let m = a.nrows();
    let n = a.ncols();
    let k = m.min(n);
    assert!(m >= n, "QR decomposition requires M >= N");
    assert_eq!(tau.len(), k, "tau length must equal min(M, N)");
    assert_eq!(perm.len(), n, "perm length must equal N");
    assert!(n <= 64, "qr_col_pivot_in_place: N must be <= 64 for fixed-size stack storage");

    // Squared norms of the still-active part of every column.
    let mut col_norms = [<T::Real as Zero>::zero(); 64];
    for j in 0..n {
        perm[j] = j;
        let col = a.col_as_slice(j, 0);
        let mut s = <T::Real as Zero>::zero();
        for &v in col {
            s = s + (v * v.conj()).re();
        }
        col_norms[j] = s;
    }

    for col in 0..k {
        // Pick the remaining column with the largest tracked norm.
        let mut best_j = col;
        let mut best_norm = col_norms[col];
        for j in (col + 1)..n {
            if col_norms[j] > best_norm {
                best_norm = col_norms[j];
                best_j = j;
            }
        }
        if best_j != col {
            perm.swap(col, best_j);
            col_norms.swap(col, best_j);
            // Swap the two full columns, rows above the pivot included.
            for i in 0..m {
                let tmp = *a.get(i, col);
                *a.get_mut(i, col) = *a.get(i, best_j);
                *a.get_mut(i, best_j) = tmp;
            }
        }

        // Exact squared norm of the pivot column from the diagonal downwards.
        let sub_col = a.col_as_slice(col, col);
        let mut norm_sq = <T::Real as Zero>::zero();
        for &v in sub_col {
            norm_sq = norm_sq + (v * v.conj()).re();
        }
        if norm_sq < T::lepsilon() {
            // Nothing left to eliminate: record "no reflector" and move on.
            tau[col] = T::zero();
            continue;
        }
        let norm = norm_sq.lsqrt();

        // `sigma` carries the phase of the diagonal entry so that
        // `a_col_col + sigma` never cancels.
        let a_col_col = *a.get(col, col);
        let alpha = a_col_col.modulus();
        let sigma = if alpha < T::lepsilon() {
            T::from_real(norm)
        } else {
            T::from_real(norm) * (a_col_col / T::from_real(alpha))
        };
        let v0 = a_col_col + sigma;
        *a.get_mut(col, col) = v0;
        let tau_val = v0 / sigma;
        tau[col] = tau_val;

        // Normalise the stored reflector tail by its leading entry.
        {
            let sub_col = a.col_as_mut_slice(col, col + 1);
            for x in sub_col.iter_mut() {
                *x = *x / v0;
            }
        }

        // Apply the reflector to the trailing columns and refresh their norms.
        for j in (col + 1)..n {
            let mut dot = *a.get(col, j);
            let (v_slice, a_j_slice) = split_two_col_slices(a, col, j, col + 1);
            for idx in 0..v_slice.len() {
                dot = dot + v_slice[idx].conj() * a_j_slice[idx];
            }
            dot = dot * tau_val;
            *a.get_mut(col, j) = *a.get(col, j) - dot;
            let (v_slice, a_j_slice) = split_two_col_slices(a, col, j, col + 1);
            crate::simd::axpy_neg_dispatch(a_j_slice, dot, v_slice);

            // Recompute the trailing norm from scratch from the rows below the pivot.
            let sub = a.col_as_slice(j, col + 1);
            let mut s = <T::Real as Zero>::zero();
            for &v in sub {
                s = s + (v * v.conj()).re();
            }
            col_norms[j] = s;
        }

        // The diagonal was used as scratch for `v0`; store the triangular
        // factor's entry now.
        *a.get_mut(col, col) = T::zero() - sigma;
    }
}
/// Column-pivoted Householder QR factorization of a fixed-size `M x N` matrix.
///
/// Built by `QrPivotDecomposition::new`, which copies the input matrix and
/// runs `qr_col_pivot_in_place` on the copy.
#[derive(Debug)]
pub struct QrPivotDecomposition<T, const M: usize, const N: usize> {
/// Packed factorization as left by `qr_col_pivot_in_place`: `r()` reads the
/// upper triangle, `q()` reads the entries below the diagonal as reflector
/// tails, and `rank()` inspects the diagonal.
qr: Matrix<T, M, N>,
/// One reflector coefficient per column; `q()` skips columns whose entry is zero.
tau: [T; N],
/// Column permutation recorded during pivoting, exposed via `permutation()`.
perm: [usize; N],
}
impl<T: LinalgScalar, const M: usize, const N: usize> QrPivotDecomposition<T, M, N> {
    /// Factorizes a copy of `a` with [`qr_col_pivot_in_place`].
    ///
    /// # Panics
    ///
    /// Panics if `M < N`.
    pub fn new(a: &Matrix<T, M, N>) -> Self {
        assert!(M >= N, "QR decomposition requires M >= N");
        let mut factors = Self {
            qr: *a,
            tau: [T::zero(); N],
            perm: [0usize; N],
        };
        qr_col_pivot_in_place(&mut factors.qr, &mut factors.tau, &mut factors.perm);
        factors
    }

    /// Returns the `N x N` upper-triangular factor, copied from the upper
    /// triangle of the packed storage; everything below the diagonal is zero.
    pub fn r(&self) -> Matrix<T, N, N> {
        let mut upper = Matrix::<T, N, N>::zeros();
        for col in 0..N {
            for row in 0..=col {
                upper[(row, col)] = self.qr[(row, col)];
            }
        }
        upper
    }

    /// Builds the thin `M x N` orthogonal factor by applying the stored
    /// reflectors, last to first, to the leading columns of the identity.
    /// Steps whose `tau` entry is zero carry no reflector and are skipped.
    pub fn q(&self) -> Matrix<T, M, N> {
        let mut basis = Matrix::<T, M, N>::zeros();
        for d in 0..N {
            basis[(d, d)] = T::one();
        }

        for step in (0..N).rev() {
            let coeff = self.tau[step];
            if coeff == T::zero() {
                continue;
            }
            // Reflector tail; the leading entry is an implicit one.
            let reflector = self.qr.col_as_slice(step, step + 1);
            for target in step..N {
                let head = basis[(step, target)];
                let inner = reflector
                    .iter()
                    .zip(basis.col_as_slice(target, step + 1).iter())
                    .fold(head, |acc, (&v, &x)| acc + v.conj() * x);
                let projection = inner * coeff;

                basis[(step, target)] = head - projection;
                crate::simd::axpy_neg_dispatch(
                    basis.col_as_mut_slice(target, step + 1),
                    projection,
                    reflector,
                );
            }
        }
        basis
    }

    /// Borrow of the column permutation recorded during factorization.
    pub fn permutation(&self) -> &[usize; N] {
        &self.perm
    }

    /// Counts the leading diagonal entries of the triangular factor whose
    /// modulus is strictly greater than `tol`, stopping at the first one
    /// that is not.
    pub fn rank(&self, tol: T::Real) -> usize {
        let mut count = 0;
        while count < N && self.qr[(count, count)].modulus() > tol {
            count += 1;
        }
        count
    }
}
impl<T: LinalgScalar, const M: usize, const N: usize> Matrix<T, M, N> {
    /// Householder QR factorization of this matrix.
    ///
    /// Shorthand for [`QrDecomposition::new`]; see there for the error and
    /// panic conditions.
    pub fn qr(&self) -> Result<QrDecomposition<T, M, N>, LinalgError> {
        QrDecomposition::<T, M, N>::new(self)
    }

    /// Column-pivoted Householder QR factorization of this matrix.
    ///
    /// Shorthand for [`QrPivotDecomposition::new`].
    pub fn qr_col_pivot(&self) -> QrPivotDecomposition<T, M, N> {
        QrPivotDecomposition::<T, M, N>::new(self)
    }
}
impl<T: LinalgScalar, const N: usize> Matrix<T, N, N> {
    /// Solves the square system `self * x = b` through a QR factorization.
    ///
    /// # Errors
    ///
    /// Returns whatever error the factorization step (`qr`) reports.
    pub fn solve_qr(&self, b: &Vector<T, N>) -> Result<Vector<T, N>, LinalgError> {
        self.qr().map(|factors| factors.solve(b))
    }
}
// Unit tests for the unpivoted and column-pivoted QR factorizations.
#[cfg(test)]
mod tests {
use super::*;
// Absolute tolerance shared by all floating-point comparisons below.
const TOL: f64 = 1e-10;
// Asserts that `a` and `b` differ by less than `tol`; `msg` labels the failure.
fn assert_near(a: f64, b: f64, tol: f64, msg: &str) {
assert!((a - b).abs() < tol, "{}: {} vs {} (diff {})", msg, a, b, (a - b).abs());
}
// 3x3 input: Q * R must reproduce A, and Q^T * Q must be the identity.
#[test]
fn qr_square_3x3() {
let a = Matrix::new([
[12.0_f64, -51.0, 4.0],
[6.0, 167.0, -68.0],
[-4.0, 24.0, -41.0],
]);
let qr = a.qr().unwrap();
let q = qr.q();
let r = qr.r();
let qr_prod = q * r;
for i in 0..3 {
for j in 0..3 {
assert_near(qr_prod[(i, j)], a[(i, j)], TOL,
&format!("QR[({},{})]", i, j));
}
}
let qtq = q.transpose() * q;
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_near(qtq[(i, j)], expected, TOL,
&format!("QtQ[({},{})]", i, j));
}
}
}
// Tall 4x3 input: thin Q (4x3) times R (3x3) reproduces A; Q^T * Q is the 3x3 identity.
#[test]
fn qr_rectangular_4x3() {
let a = Matrix::new([
[1.0_f64, -1.0, 4.0],
[1.0, 4.0, -2.0],
[1.0, 4.0, 2.0],
[1.0, -1.0, 0.0],
]);
let qr = a.qr().unwrap();
let q = qr.q();
let r = qr.r();
let qr_prod: Matrix<f64, 4, 3> = q * r;
for i in 0..4 {
for j in 0..3 {
assert_near(qr_prod[(i, j)], a[(i, j)], TOL,
&format!("QR[({},{})]", i, j));
}
}
let qtq: Matrix<f64, 3, 3> = q.transpose() * q;
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_near(qtq[(i, j)], expected, TOL,
&format!("QtQ[({},{})]", i, j));
}
}
}
// Square solve: the QR-based solution must agree with the matrix's `solve` method.
#[test]
fn qr_solve_square() {
let a = Matrix::new([
[2.0_f64, 1.0, -1.0],
[-3.0, -1.0, 2.0],
[-2.0, 1.0, 2.0],
]);
let b = Vector::from_array([8.0, -11.0, -3.0]);
let x_qr = a.solve_qr(&b).unwrap();
let x_lu = a.solve(&b).unwrap();
for i in 0..3 {
assert_near(x_qr[i], x_lu[i], TOL, &format!("x[{}]", i));
}
}
// Overdetermined 3x2 system: checks the fitted coefficients and that the
// residual is orthogonal to the columns of A (A^T r = 0).
#[test]
fn qr_least_squares() {
let a = Matrix::new([
[1.0_f64, 0.0],
[1.0, 1.0],
[1.0, 2.0],
]);
let b = Vector::from_array([1.0, 2.0, 4.0]);
let qr = a.qr().unwrap();
let x = qr.solve(&b);
assert_near(x[0], 5.0 / 6.0, TOL, "c0");
assert_near(x[1], 3.0 / 2.0, TOL, "c1");
let ax = a * x;
let r = b - ax;
let at = a.transpose();
let atr = at * r;
for i in 0..2 {
assert_near(atr[i], 0.0, TOL, &format!("A^T r[{}]", i));
}
}
// Determinant: only the magnitude is compared against the matrix's own `det`.
#[test]
fn qr_det() {
let a = Matrix::new([
[6.0_f64, 1.0, 1.0],
[4.0, -2.0, 5.0],
[2.0, 8.0, 7.0],
]);
let qr = a.qr().unwrap();
let det_qr = qr.det();
let det_lu = a.det();
assert_near(det_qr.abs(), det_lu.abs(), TOL, "det magnitude");
}
// Identity input: Q * R must give back the identity.
#[test]
fn qr_identity() {
let id: Matrix<f64, 3, 3> = Matrix::eye();
let qr = id.qr().unwrap();
let q = qr.q();
let r = qr.r();
let prod = q * r;
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_near(prod[(i, j)], expected, TOL,
&format!("QR[({},{})]", i, j));
}
}
}
// Calls the free function directly on a 2x2 matrix; only success is checked.
#[test]
fn qr_in_place_generic() {
let mut a = Matrix::new([[2.0_f64, 1.0], [4.0, 3.0]]);
let mut tau = [0.0; 2];
let result = qr_in_place(&mut a, &mut tau);
assert!(result.is_ok());
}
// A matrix with a zero second column must be reported as singular.
#[test]
fn qr_rank_deficient() {
let a = Matrix::new([
[1.0_f64, 0.0],
[0.0, 0.0],
]);
assert_eq!(a.qr().unwrap_err(), LinalgError::Singular);
}
// Pivoted QR, full rank: Q * R equals A with its columns permuted by `perm`,
// Q has orthonormal columns, and the reported rank is 3.
#[test]
fn qr_pivot_full_rank_3x3() {
let a = Matrix::new([
[12.0_f64, -51.0, 4.0],
[6.0, 167.0, -68.0],
[-4.0, 24.0, -41.0],
]);
let qrp = a.qr_col_pivot();
let q = qrp.q();
let r = qrp.r();
let perm = qrp.permutation();
let qr_prod = q * r;
for i in 0..3 {
for j in 0..3 {
assert_near(qr_prod[(i, j)], a[(i, perm[j])], TOL,
&format!("QR[({},{})]", i, j));
}
}
let qtq = q.transpose() * q;
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_near(qtq[(i, j)], expected, TOL,
&format!("QtQ[({},{})]", i, j));
}
}
assert_eq!(qrp.rank(1e-10), 3);
}
// Pivoted QR where the third row is the sum of the first two: rank 2, and
// Q * R still reproduces the permuted input.
#[test]
fn qr_pivot_rank_deficient() {
let a = Matrix::new([
[1.0_f64, 2.0, 3.0],
[4.0, 5.0, 6.0],
[5.0, 7.0, 9.0],
]);
let qrp = a.qr_col_pivot();
assert_eq!(qrp.rank(1e-10), 2);
let q = qrp.q();
let r = qrp.r();
let perm = qrp.permutation();
let qr_prod = q * r;
for i in 0..3 {
for j in 0..3 {
assert_near(qr_prod[(i, j)], a[(i, perm[j])], TOL,
&format!("QR[({},{})]", i, j));
}
}
}
// Pivoted QR where every row is a multiple of the first: rank 1.
#[test]
fn qr_pivot_rank_1() {
let a = Matrix::new([
[1.0_f64, 2.0, 3.0],
[2.0, 4.0, 6.0],
[3.0, 6.0, 9.0],
]);
let qrp = a.qr_col_pivot();
assert_eq!(qrp.rank(1e-10), 1);
let q = qrp.q();
let r = qrp.r();
let perm = qrp.permutation();
let qr_prod = q * r;
for i in 0..3 {
for j in 0..3 {
assert_near(qr_prod[(i, j)], a[(i, perm[j])], TOL,
&format!("QR[({},{})]", i, j));
}
}
}
// The all-zero matrix must factor without error and report rank 0.
#[test]
fn qr_pivot_zero_matrix() {
let a = Matrix::<f64, 3, 3>::zeros();
let qrp = a.qr_col_pivot();
assert_eq!(qrp.rank(1e-10), 0);
}
// Pivoted QR of the identity: rank 3 and Q * R matches the permuted identity.
#[test]
fn qr_pivot_identity() {
let id: Matrix<f64, 3, 3> = Matrix::eye();
let qrp = id.qr_col_pivot();
assert_eq!(qrp.rank(1e-10), 3);
let q = qrp.q();
let r = qrp.r();
let perm = qrp.permutation();
let qr_prod = q * r;
for i in 0..3 {
for j in 0..3 {
assert_near(qr_prod[(i, j)], id[(i, perm[j])], TOL,
&format!("QR[({},{})]", i, j));
}
}
}
// Pivoted QR of a tall 4x3 matrix: reconstruction, orthonormal Q, rank 3.
#[test]
fn qr_pivot_rectangular_4x3() {
let a = Matrix::new([
[1.0_f64, -1.0, 4.0],
[1.0, 4.0, -2.0],
[1.0, 4.0, 2.0],
[1.0, -1.0, 0.0],
]);
let qrp = a.qr_col_pivot();
let q = qrp.q();
let r = qrp.r();
let perm = qrp.permutation();
let qr_prod: Matrix<f64, 4, 3> = q * r;
for i in 0..4 {
for j in 0..3 {
assert_near(qr_prod[(i, j)], a[(i, perm[j])], TOL,
&format!("QR[({},{})]", i, j));
}
}
let qtq: Matrix<f64, 3, 3> = q.transpose() * q;
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_near(qtq[(i, j)], expected, TOL,
&format!("QtQ[({},{})]", i, j));
}
}
assert_eq!(qrp.rank(1e-10), 3);
}
// With pivoting, the magnitudes on R's diagonal must be non-increasing (up to TOL).
#[test]
fn qr_pivot_r_diagonal_decreasing() {
let a = Matrix::new([
[1.0_f64, 100.0, 0.5],
[2.0, 200.0, 1.0],
[3.0, 300.0, 1.5],
[4.0, 400.0, 2.0],
]);
let qrp = a.qr_col_pivot();
let r = qrp.r();
for i in 0..(3 - 1) {
assert!(
r[(i, i)].abs() >= r[(i + 1, i + 1)].abs() - TOL,
"|R[{},{}]| = {} should >= |R[{},{}]| = {}",
i, i,
r[(i, i)].abs(),
i + 1, i + 1,
r[(i + 1, i + 1)].abs()
);
}
}
// The recorded permutation must contain each index in 0..3 exactly once.
#[test]
fn qr_pivot_permutation_is_valid() {
let a = Matrix::new([
[0.1_f64, 10.0, 5.0],
[0.2, 20.0, 10.0],
[0.3, 30.0, 15.0],
]);
let qrp = a.qr_col_pivot();
let perm = qrp.permutation();
let mut seen = [false; 3];
for &p in perm {
assert!(p < 3, "permutation index out of range");
assert!(!seen[p], "duplicate permutation index");
seen[p] = true;
}
}
// Third column is the sum of the first two: rank 2 and a valid reconstruction.
#[test]
fn qr_pivot_column_dependent() {
let a = Matrix::new([
[1.0_f64, 0.0, 1.0],
[0.0, 1.0, 1.0],
[1.0, 1.0, 2.0],
]);
let qrp = a.qr_col_pivot();
assert_eq!(qrp.rank(1e-10), 2);
let q = qrp.q();
let r = qrp.r();
let perm = qrp.permutation();
let qr_prod = q * r;
for i in 0..3 {
for j in 0..3 {
assert_near(qr_prod[(i, j)], a[(i, perm[j])], TOL,
&format!("QR[({},{})]", i, j));
}
}
}
// Smallest non-trivial case: 2x2 full-rank matrix.
#[test]
fn qr_pivot_2x2() {
let a = Matrix::new([
[3.0_f64, 1.0],
[4.0, 2.0],
]);
let qrp = a.qr_col_pivot();
assert_eq!(qrp.rank(1e-10), 2);
let q = qrp.q();
let r = qrp.r();
let perm = qrp.permutation();
let qr_prod = q * r;
for i in 0..2 {
for j in 0..2 {
assert_near(qr_prod[(i, j)], a[(i, perm[j])], TOL,
&format!("QR[({},{})]", i, j));
}
}
}
}