use crate::matrix::{DenseMatrix, Matrix};
use crate::Scalar;
use faer::linalg::solvers::PartialPivLu;
use faer::prelude::*;
use faer::{ComplexField, Conjugate, Entity, Mat, SimpleEntity};
use numra_core::LinalgError;
/// A reusable LU factorization (partial pivoting) of a square dense matrix.
///
/// The factorization is computed once by [`LUFactorization::new`] and can then
/// be applied to any number of right-hand sides.
pub struct LUFactorization<S>
where
    S: Scalar + Entity,
{
    /// The partial-pivoting LU decomposition produced by faer.
    lu: PartialPivLu<S>,
    /// Number of rows (equivalently columns) of the factorized matrix.
    n: usize,
}
impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> LUFactorization<S> {
    /// Computes the partial-pivoting LU factorization of `matrix`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::NotSquare`] when `matrix` is not square.
    ///
    /// NOTE(review): no singularity check is performed here — confirm how a
    /// singular input surfaces in later `solve` calls.
    pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
        if !matrix.is_square() {
            return Err(LinalgError::NotSquare {
                nrows: matrix.rows(),
                ncols: matrix.cols(),
            });
        }
        let n = matrix.rows();
        let lu = matrix.as_faer().partial_piv_lu();
        Ok(Self { lu, n })
    }

    /// Returns the dimension `n` of the factorized `n x n` matrix.
    pub fn dim(&self) -> usize {
        self.n
    }

    /// Solves the system for a single right-hand side `b` using the stored
    /// factorization and returns the solution as a new vector of length `n`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::DimensionMismatch`] when `b.len() != self.dim()`.
    pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
        let n = self.n;
        if b.len() != n {
            return Err(LinalgError::DimensionMismatch {
                expected: (n, 1),
                actual: (b.len(), 1),
            });
        }
        // faer solves against matrices, so wrap `b` as an n x 1 column.
        let rhs = Mat::<S>::from_fn(n, 1, |i, _| b[i]);
        let solution = self.lu.solve(&rhs);
        Ok((0..n).map(|i| solution.read(i, 0)).collect())
    }

    /// Solves the system for `b` and overwrites `b` with the solution.
    ///
    /// This delegates to [`Self::solve`], so a temporary solution vector is
    /// still allocated.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::DimensionMismatch`] when `b.len() != self.dim()`;
    /// `b` is left untouched in that case.
    pub fn solve_inplace(&self, b: &mut [S]) -> Result<(), LinalgError> {
        let x = self.solve(b)?;
        b.copy_from_slice(&x);
        Ok(())
    }

    /// Solves the system for `nrhs` right-hand sides at once.
    ///
    /// `b` holds the right-hand sides back to back: entry `i` of right-hand
    /// side `j` lives at `b[j * n + i]`. The returned vector uses the same
    /// layout for the solutions.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::DimensionMismatch`] when `b.len()` is not
    /// `self.dim() * nrhs`, including when that product overflows `usize`.
    pub fn solve_multi(&self, b: &[S], nrhs: usize) -> Result<Vec<S>, LinalgError> {
        let n = self.n;
        // Use a checked product: a wrapped `n * nrhs` could accidentally equal
        // `b.len()` and let an impossible shape through to the allocation below.
        if n.checked_mul(nrhs) != Some(b.len()) {
            return Err(LinalgError::DimensionMismatch {
                expected: (n, nrhs),
                actual: (b.len(), 1),
            });
        }
        let rhs = Mat::<S>::from_fn(n, nrhs, |i, j| b[j * n + i]);
        let solution = self.lu.solve(&rhs);
        // Flatten the solution column by column, mirroring the input layout.
        let mut x = Vec::with_capacity(b.len());
        for j in 0..nrhs {
            x.extend((0..n).map(|i| solution.read(i, j)));
        }
        Ok(x)
    }
}
/// LU-based solving capabilities for matrix types over the scalar `S`.
pub trait LUSolver<S>
where
    S: Scalar,
{
    /// Solves the linear system defined by `self` for the right-hand side `b`
    /// and returns the solution vector.
    fn lu_solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError>;

    /// Builds an [`LUFactorization`] of `self` that can be reused for several
    /// right-hand sides.
    fn lu_factor(&self) -> Result<LUFactorization<S>, LinalgError>
    where
        S: Entity + SimpleEntity + Conjugate<Canonical = S> + ComplexField;
}
impl<S> LUSolver<S> for DenseMatrix<S>
where
    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
{
    /// One-shot solve: forwards to the matrix's own `solve` method.
    fn lu_solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
        self.solve(b)
    }

    /// Factorizes this matrix via [`LUFactorization::new`].
    fn lu_factor(&self) -> Result<LUFactorization<S>, LinalgError> {
        LUFactorization::<S>::new(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    /// Absolute tolerance used when comparing computed solutions.
    const TOL: f64 = 1e-10;

    /// Builds an `n x n` matrix from its entries listed row by row.
    fn square(n: usize, row_major: &[f64]) -> DenseMatrix<f64> {
        assert_eq!(row_major.len(), n * n, "expected n * n entries");
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(n, n);
        for (k, &value) in row_major.iter().enumerate() {
            m.set(k / n, k % n, value);
        }
        m
    }

    /// Asserts that `actual` matches `expected` element-wise within `TOL`.
    fn assert_all_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < TOL, "entry {idx}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_lu_factorization() {
        let m = square(3, &[2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0]);
        let lu = LUFactorization::new(&m).unwrap();
        assert_eq!(lu.dim(), 3);
    }

    #[test]
    fn test_lu_solve() {
        let m = square(2, &[1.0, 2.0, 3.0, 4.0]);
        let lu = LUFactorization::new(&m).unwrap();
        let x = lu.solve(&[5.0, 11.0]).unwrap();
        assert_all_close(&x, &[1.0, 2.0]);
    }

    #[test]
    fn test_lu_solve_inplace() {
        let m = square(2, &[2.0, 0.0, 0.0, 3.0]);
        let lu = LUFactorization::new(&m).unwrap();
        let mut b = vec![4.0, 9.0];
        lu.solve_inplace(&mut b).unwrap();
        assert_all_close(&b, &[2.0, 3.0]);
    }

    #[test]
    fn test_lu_solve_multi() {
        let m = square(2, &[2.0, 0.0, 0.0, 3.0]);
        let lu = LUFactorization::new(&m).unwrap();
        // Two right-hand sides stored back to back: (2, 6) and (4, 9).
        let x = lu.solve_multi(&[2.0, 6.0, 4.0, 9.0], 2).unwrap();
        assert_all_close(&x, &[1.0, 2.0, 2.0, 3.0]);
    }

    #[test]
    fn test_lu_solver_trait() {
        let m = square(2, &[1.0, 2.0, 3.0, 4.0]);
        let x = m.lu_solve(&[5.0, 11.0]).unwrap();
        assert_all_close(&x, &[1.0, 2.0]);
    }

    #[test]
    fn test_lu_repeated_solve() {
        let m = square(2, &[1.0, 2.0, 3.0, 4.0]);
        let lu = LUFactorization::new(&m).unwrap();
        // The same factorization must serve several right-hand sides.
        let x1 = lu.solve(&[5.0, 11.0]).unwrap();
        assert_all_close(&x1, &[1.0, 2.0]);
        let x2 = lu.solve(&[5.0, 13.0]).unwrap();
        assert_all_close(&x2, &[3.0, 1.0]);
    }

    #[test]
    fn test_lu_rejects_non_square() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        assert!(matches!(
            LUFactorization::new(&m),
            Err(LinalgError::NotSquare { .. })
        ));
    }

    #[test]
    fn test_lu_solve_dimension_mismatch() {
        let m = square(2, &[1.0, 2.0, 3.0, 4.0]);
        let lu = LUFactorization::new(&m).unwrap();
        assert!(matches!(
            lu.solve(&[1.0]),
            Err(LinalgError::DimensionMismatch { .. })
        ));
        let mut too_long = vec![1.0, 2.0, 3.0];
        assert!(matches!(
            lu.solve_inplace(&mut too_long),
            Err(LinalgError::DimensionMismatch { .. })
        ));
        assert!(matches!(
            lu.solve_multi(&[1.0, 2.0, 3.0], 2),
            Err(LinalgError::DimensionMismatch { .. })
        ));
    }
}