use crate::matrix::DenseMatrix;
use crate::Scalar;
use faer::linalg::solvers::Qr;
use faer::prelude::*;
use faer::{ComplexField, Conjugate, Entity, Mat, SimpleEntity};
use numra_core::LinalgError;
/// Cached QR decomposition of a dense `m x n` matrix.
///
/// The factors are computed once and reused for every subsequent solve, so several
/// right-hand sides can be handled without refactoring the matrix.
pub struct QRFactorization<S>
where
    S: Scalar + Entity,
{
    /// Number of rows of the factored matrix.
    m: usize,
    /// Number of columns of the factored matrix.
    n: usize,
    /// faer's QR factors of the matrix.
    qr: Qr<S>,
}
impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> QRFactorization<S> {
pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
let m = matrix.rows();
let n = matrix.cols();
if m < n {
return Err(LinalgError::DimensionMismatch {
expected: (n, n),
actual: (m, n),
});
}
let qr = Qr::new(matrix.as_faer());
Ok(Self { qr, m, n })
}
pub fn nrows(&self) -> usize {
self.m
}
pub fn ncols(&self) -> usize {
self.n
}
pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
if self.m != self.n {
return Err(LinalgError::NotSquare {
nrows: self.m,
ncols: self.n,
});
}
if b.len() != self.m {
return Err(LinalgError::DimensionMismatch {
expected: (self.m, 1),
actual: (b.len(), 1),
});
}
let mut b_mat = Mat::zeros(self.m, 1);
for (i, &val) in b.iter().enumerate() {
b_mat.write(i, 0, val);
}
let x_mat = self.qr.solve(&b_mat);
let mut x = Vec::with_capacity(self.n);
for i in 0..self.n {
x.push(x_mat.read(i, 0));
}
Ok(x)
}
pub fn solve_least_squares(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
if b.len() != self.m {
return Err(LinalgError::DimensionMismatch {
expected: (self.m, 1),
actual: (b.len(), 1),
});
}
let mut b_mat = Mat::zeros(self.m, 1);
for (i, &val) in b.iter().enumerate() {
b_mat.write(i, 0, val);
}
let x_mat = self.qr.solve_lstsq(&b_mat);
let mut x = Vec::with_capacity(self.n);
for i in 0..self.n {
x.push(x_mat.read(i, 0));
}
Ok(x)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    /// Builds a `DenseMatrix<f64>` from row slices. All rows must have the same length.
    fn dense(rows: &[&[f64]]) -> DenseMatrix<f64> {
        let ncols = rows.first().map_or(0, |r| r.len());
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(rows.len(), ncols);
        for (i, row) in rows.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() {
                m.set(i, j, v);
            }
        }
        m
    }

    #[test]
    fn test_qr_square() {
        let m = dense(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let qr = QRFactorization::new(&m).unwrap();
        assert_eq!(qr.nrows(), 2);
        assert_eq!(qr.ncols(), 2);
        let b = [5.0, 11.0];
        let x = qr.solve(&b).unwrap();
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_qr_overdetermined() {
        let m = dense(&[&[1.0, 1.0], &[1.0, 2.0], &[1.0, 3.0]]);
        let qr = QRFactorization::new(&m).unwrap();
        assert_eq!(qr.nrows(), 3);
        assert_eq!(qr.ncols(), 2);
        let b = [1.0, 2.0, 2.0];
        let x = qr.solve_least_squares(&b).unwrap();
        // Normal equations: [[3, 6], [6, 14]] x = [5, 11]  =>  x = (2/3, 1/2).
        assert!((x[0] - 2.0 / 3.0).abs() < 1e-10);
        assert!((x[1] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_qr_dimension_mismatch() {
        let m = dense(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let qr = QRFactorization::new(&m).unwrap();
        let b = [1.0, 2.0, 3.0];
        assert!(matches!(
            qr.solve(&b),
            Err(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1)
            })
        ));
        assert!(matches!(
            qr.solve_least_squares(&b),
            Err(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1)
            })
        ));
    }

    #[test]
    fn test_qr_underdetermined_rejected() {
        // More columns than rows: the constructor must refuse to factor.
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        assert!(matches!(
            QRFactorization::new(&m),
            Err(LinalgError::DimensionMismatch {
                expected: (3, 3),
                actual: (2, 3)
            })
        ));
    }

    #[test]
    fn test_qr_solve_requires_square() {
        // `solve` is only defined for square factorizations; least squares must be used here.
        let m = dense(&[&[1.0, 1.0], &[1.0, 2.0], &[1.0, 3.0]]);
        let qr = QRFactorization::new(&m).unwrap();
        let b = [1.0, 2.0, 2.0];
        assert!(matches!(
            qr.solve(&b),
            Err(LinalgError::NotSquare { nrows: 3, ncols: 2 })
        ));
    }
}