use crate::Float;
use crate::error::{CoreError, Result};
use crate::tensor::Tensor;
/// LU factorization of a square matrix, computed with partial (row) pivoting.
///
/// Produced by [`LuDecomposition::decompose`]. The stored factors satisfy
/// `P * A = L * U`, where `L` is unit lower-triangular, `U` is upper-triangular
/// and `P` is the row permutation recorded in `pivots`.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct LuDecomposition<T: Float> {
/// Packed row-major `n * n` factors: entries strictly below the diagonal hold
/// the multipliers of `L` (its unit diagonal is implicit); the diagonal and
/// everything above it hold `U`.
lu: Vec<T>,
/// Row permutation: `pivots[i]` is the index of the original row that ended up
/// in row `i` after pivoting.
pivots: Vec<usize>,
/// Dimension of the factored `n x n` matrix.
n: usize,
/// Starts at one and is multiplied by `-1` on every row swap; it is the sign
/// factor applied by `det`.
sign: T,
}
impl<T: Float> LuDecomposition<T> {
/// Factors the square matrix `a` as `P * A = L * U` using Gaussian elimination
/// with partial pivoting.
///
/// The input is not modified; elimination runs on a copy of `a`'s data, which
/// ends up holding the packed `L` (below the diagonal) and `U` (diagonal and
/// above) factors.
///
/// # Errors
///
/// * `CoreError::InvalidArgument` if `a` is not 2-D or is not square.
/// * `CoreError::SingularMatrix` if, after pivoting, the magnitude of a pivot is
///   below `1e3 * T::epsilon()`. Note that this threshold is absolute: it is not
///   scaled by the magnitude of the matrix entries.
pub fn decompose(a: &Tensor<T>) -> Result<Self> {
    if a.ndim() != 2 {
        return Err(CoreError::InvalidArgument {
            reason: "LU decomposition requires a 2-D tensor (matrix)",
        });
    }
    let dims = a.shape();
    let n = dims[0];
    if dims[1] != n {
        return Err(CoreError::InvalidArgument {
            reason: "LU decomposition requires a square matrix",
        });
    }

    let mut lu: Vec<T> = a.as_slice().to_vec();
    let mut pivots: Vec<usize> = (0..n).collect();
    let mut sign = T::one();
    // Pivots smaller than this (in magnitude) are treated as zero.
    let singular_tol = T::epsilon() * T::from_f64(1e3);

    for col in 0..n {
        let diag = col * n + col;

        // Partial pivoting: pick the row at or below the diagonal with the
        // largest magnitude in this column. The strict `>` keeps the first
        // such row when there are ties.
        let (pivot_row_idx, _) = ((col + 1)..n).fold(
            (col, lu[diag].abs()),
            |(best_row, best_mag), row| {
                let mag = lu[row * n + col].abs();
                if mag > best_mag {
                    (row, mag)
                } else {
                    (best_row, best_mag)
                }
            },
        );

        if pivot_row_idx != col {
            // `pivot_row_idx > col`, so the two rows live on opposite sides of
            // the split and can be exchanged as whole slices.
            let (upper, lower) = lu.split_at_mut(pivot_row_idx * n);
            upper[col * n..(col + 1) * n].swap_with_slice(&mut lower[..n]);
            pivots.swap(col, pivot_row_idx);
            // Each row exchange flips the sign of the determinant.
            sign *= T::from_f64(-1.0);
        }

        let pivot = lu[diag];
        if pivot.abs() < singular_tol {
            return Err(CoreError::SingularMatrix);
        }

        // Eliminate the column below the pivot. The multiplier is stored in the
        // slot it zeroes out, which is how `L` gets packed into `lu`.
        let (head, below) = lu.split_at_mut((col + 1) * n);
        let pivot_row = &head[col * n..];
        for target in below.chunks_exact_mut(n) {
            let factor = target[col] / pivot;
            target[col] = factor;
            for (entry, &u_kj) in target[col + 1..].iter_mut().zip(&pivot_row[col + 1..]) {
                *entry -= factor * u_kj;
            }
        }
    }

    Ok(Self {
        lu,
        pivots,
        n,
        sign,
    })
}
/// Returns the unit lower-triangular factor `L` as a new `n x n` tensor.
///
/// Entries strictly below the diagonal are copied from the packed storage,
/// the diagonal is set to one and everything above it is zero.
///
/// # Panics
///
/// Panics if `Tensor::from_vec` rejects the `n * n` buffer for shape `[n, n]`.
pub fn l(&self) -> Tensor<T> {
    let n = self.n;
    let mut data = vec![T::zero(); n * n];
    for row in 0..n {
        let start = row * n;
        let diag = start + row;
        // Strictly-lower part of this row holds the elimination multipliers.
        data[start..diag].copy_from_slice(&self.lu[start..diag]);
        data[diag] = T::one();
    }
    Tensor::from_vec(data, vec![n, n]).expect("L data length equals n*n by construction")
}
/// Returns the upper-triangular factor `U` as a new `n x n` tensor.
///
/// The diagonal and everything above it are copied from the packed storage;
/// entries below the diagonal are zero.
///
/// # Panics
///
/// Panics if `Tensor::from_vec` rejects the `n * n` buffer for shape `[n, n]`.
pub fn u(&self) -> Tensor<T> {
    let n = self.n;
    let mut data = vec![T::zero(); n * n];
    for row in 0..n {
        // From the diagonal element to the end of the row.
        let from = row * n + row;
        let to = (row + 1) * n;
        data[from..to].copy_from_slice(&self.lu[from..to]);
    }
    Tensor::from_vec(data, vec![n, n]).expect("U data length equals n*n by construction")
}
/// Returns the permutation matrix `P` as a new `n x n` tensor.
///
/// Row `i` contains a single one, in column `pivots[i]`; every other entry is
/// zero.
///
/// # Panics
///
/// Panics if `Tensor::from_vec` rejects the `n * n` buffer for shape `[n, n]`.
pub fn p(&self) -> Tensor<T> {
    let n = self.n;
    let mut data = vec![T::zero(); n * n];
    let ones = self
        .pivots
        .iter()
        .enumerate()
        .map(|(row, &source_row)| row * n + source_row);
    for flat_index in ones {
        data[flat_index] = T::one();
    }
    Tensor::from_vec(data, vec![n, n]).expect("P data length equals n*n by construction")
}
/// Returns the row permutation recorded during factorization: element `i` is
/// the index of the original row that now sits in row `i`.
pub fn pivots(&self) -> &[usize] {
    self.pivots.as_slice()
}
/// Returns the determinant of the factored matrix: the stored permutation
/// sign multiplied by each diagonal entry of `U`, from first to last.
pub fn det(&self) -> T {
    let n = self.n;
    (0..n)
        .map(|i| self.lu[i * n + i])
        .fold(self.sign, |acc, diag| acc * diag)
}
/// Solves `A x = b` for `x` using the stored factorization.
///
/// `b` is first permuted according to `pivots`, then forward substitution
/// with the unit lower-triangular `L` and back substitution with `U` are
/// applied in place.
///
/// # Errors
///
/// * `CoreError::InvalidArgument` if `b` is not a 1-D tensor.
/// * `CoreError::DimensionMismatch` if `b` does not have exactly `n` elements.
/// * Any error returned by `Tensor::from_vec` when wrapping the result.
pub fn solve(&self, b: &Tensor<T>) -> Result<Tensor<T>> {
    let n = self.n;
    if b.ndim() != 1 {
        return Err(CoreError::InvalidArgument {
            reason: "solve: `b` must be a 1-D tensor",
        });
    }
    if b.numel() != n {
        return Err(CoreError::DimensionMismatch {
            expected: vec![n],
            got: b.shape().to_vec(),
        });
    }

    // x = P * b
    let rhs = b.as_slice();
    let mut x: Vec<T> = self.pivots.iter().map(|&source_row| rhs[source_row]).collect();

    // Forward substitution, L y = P b. L has an implicit unit diagonal, so no
    // division is needed and row 0 is already solved.
    for i in 1..n {
        let (solved, pending) = x.split_at_mut(i);
        let xi = &mut pending[0];
        let l_row = &self.lu[i * n..i * n + i];
        for (&l_ij, &x_j) in l_row.iter().zip(solved.iter()) {
            *xi -= l_ij * x_j;
        }
    }

    // Back substitution, U x = y, from the last row upwards.
    for i in (0..n).rev() {
        let (head, solved) = x.split_at_mut(i + 1);
        let xi = &mut head[i];
        let u_row = &self.lu[i * n..(i + 1) * n];
        for (&u_ij, &x_j) in u_row[i + 1..].iter().zip(solved.iter()) {
            *xi -= u_ij * x_j;
        }
        *xi /= u_row[i];
    }

    Tensor::from_vec(x, vec![n])
}
/// Computes the inverse of the factored matrix as a new `n x n` tensor.
///
/// Column `c` of the inverse is obtained by solving `A x = e_c`, where `e_c`
/// is the `c`-th unit vector.
///
/// # Errors
///
/// Propagates any error from `Tensor::from_vec` or from `solve`.
pub fn inverse(&self) -> Result<Tensor<T>> {
    let n = self.n;
    let mut inv = vec![T::zero(); n * n];
    for col in 0..n {
        let mut unit = vec![T::zero(); n];
        unit[col] = T::one();
        let rhs = Tensor::from_vec(unit, vec![n])?;
        let column = self.solve(&rhs)?;
        // Scatter the solution down column `col` of the row-major output:
        // flat indices col, col + n, col + 2n, ...
        let slots = inv.iter_mut().skip(col).step_by(n);
        for (slot, &value) in slots.zip(column.as_slice().iter()) {
            *slot = value;
        }
    }
    Tensor::from_vec(inv, vec![n, n])
}
}
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// 4x4 fixture shared by the reconstruction, determinant and solve tests.
    /// Its determinant is 72 and `A * [1, 1, 1, 1] = [10, 26, 20, 7]`.
    const FOUR_BY_FOUR: [f64; 16] = [
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 2.0, 6.0, 4.0, 8.0, 3.0, 1.0, 1.0, 2.0,
    ];

    /// Builds an `n x n` tensor from row-major data.
    fn mat(data: &[f64], n: usize) -> Tensor<f64> {
        Tensor::from_vec(Vec::from(data), vec![n, n]).unwrap()
    }

    /// Builds a 1-D tensor from the given values.
    fn vector(data: &[f64]) -> Tensor<f64> {
        Tensor::from_vec(data.to_vec(), vec![data.len()]).unwrap()
    }

    /// Element-wise comparison with a strict absolute tolerance.
    fn approx_eq(a: &[f64], b: &[f64], tol: f64) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).all(|(x, y)| (x - y).abs() < tol)
    }

    /// Factors the `n x n` matrix given as row-major data, panicking on failure.
    fn factor(data: &[f64], n: usize) -> LuDecomposition<f64> {
        LuDecomposition::decompose(&mat(data, n)).unwrap()
    }

    /// Asserts that `P * A` equals `L * U` within `tol`.
    #[track_caller]
    fn assert_reconstructs(data: &[f64], n: usize, tol: f64) {
        let a = mat(data, n);
        let lu = LuDecomposition::decompose(&a).unwrap();
        let pa = lu.p().matmul(&a).unwrap();
        let product = lu.l().matmul(&lu.u()).unwrap();
        assert!(approx_eq(pa.as_slice(), product.as_slice(), tol));
    }

    /// Asserts that `A * inverse(A)` equals the identity within `tol`.
    #[track_caller]
    fn assert_inverse_roundtrip(data: &[f64], n: usize, tol: f64) {
        let a = mat(data, n);
        let inv = LuDecomposition::decompose(&a).unwrap().inverse().unwrap();
        let product = a.matmul(&inv).unwrap();
        let identity = Tensor::<f64>::eye(n);
        assert!(approx_eq(product.as_slice(), identity.as_slice(), tol));
    }

    // --- P * A = L * U -----------------------------------------------------

    #[test]
    fn test_lu_2x2() {
        assert_reconstructs(&[2.0, 1.0, 1.0, 4.0], 2, 1e-12);
    }

    #[test]
    fn test_lu_3x3() {
        assert_reconstructs(&[2.0, 1.0, 1.0, 4.0, 3.0, 3.0, 8.0, 7.0, 9.0], 3, 1e-12);
    }

    #[test]
    fn test_lu_4x4() {
        assert_reconstructs(&FOUR_BY_FOUR, 4, 1e-10);
    }

    // --- determinant -------------------------------------------------------

    #[test]
    fn test_det_2x2() {
        let det = factor(&[2.0, 1.0, 1.0, 4.0], 2).det();
        assert!((det - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_det_3x3() {
        let det = factor(&[6.0, 1.0, 1.0, 4.0, -2.0, 5.0, 2.0, 8.0, 7.0], 3).det();
        assert!((det - (-306.0)).abs() < 1e-10);
    }

    #[test]
    fn test_det_4x4_numpy() {
        let det = factor(&FOUR_BY_FOUR, 4).det();
        assert!((det - 72.0).abs() < 1e-10);
    }

    #[test]
    fn test_det_identity() {
        let identity = Tensor::<f64>::eye(5);
        let det = LuDecomposition::decompose(&identity).unwrap().det();
        assert!((det - 1.0).abs() < 1e-14);
    }

    #[test]
    fn test_singular_matrix() {
        // Second row is twice the first, so the matrix has rank 2.
        let a = mat(&[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 7.0, 8.0, 9.0], 3);
        assert!(LuDecomposition::decompose(&a).is_err());
    }

    // --- solve -------------------------------------------------------------

    #[test]
    fn test_solve_2x2() {
        let x = factor(&[2.0, 1.0, 1.0, 4.0], 2)
            .solve(&vector(&[5.0, 6.0]))
            .unwrap();
        assert!(approx_eq(x.as_slice(), &[2.0, 1.0], 1e-12));
    }

    #[test]
    fn test_solve_3x3() {
        let x = factor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0], 3)
            .solve(&vector(&[1.0, 2.0, 3.0]))
            .unwrap();
        let expected = [-1.0 / 3.0, 2.0 / 3.0, 0.0];
        assert!(approx_eq(x.as_slice(), &expected, 1e-12));
    }

    #[test]
    fn test_solve_4x4_numpy() {
        let x = factor(&FOUR_BY_FOUR, 4)
            .solve(&vector(&[10.0, 26.0, 20.0, 7.0]))
            .unwrap();
        assert!(approx_eq(x.as_slice(), &[1.0, 1.0, 1.0, 1.0], 1e-10));
    }

    // --- inverse -----------------------------------------------------------

    #[test]
    fn test_inverse_2x2() {
        assert_inverse_roundtrip(&[2.0, 1.0, 1.0, 4.0], 2, 1e-12);
    }

    #[test]
    fn test_inverse_3x3() {
        assert_inverse_roundtrip(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0], 3, 1e-10);
    }

    #[test]
    fn test_inverse_identity() {
        let identity = Tensor::<f64>::eye(4);
        let inv = LuDecomposition::decompose(&identity)
            .unwrap()
            .inverse()
            .unwrap();
        assert!(approx_eq(inv.as_slice(), identity.as_slice(), 1e-14));
    }

    // --- argument validation -----------------------------------------------

    #[test]
    fn test_not_square() {
        let a: Tensor<f64> =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        assert!(LuDecomposition::decompose(&a).is_err());
    }

    #[test]
    fn test_not_2d() {
        let a = vector(&[1.0, 2.0, 3.0]);
        assert!(LuDecomposition::decompose(&a).is_err());
    }

    #[test]
    fn test_solve_dimension_mismatch() {
        let lu = factor(&[1.0, 0.0, 0.0, 1.0], 2);
        let b = vector(&[1.0, 2.0, 3.0]);
        assert!(lu.solve(&b).is_err());
    }
}