use crate::{
linalg::{Matrix, error::LinalgError},
traits::Real,
};
/// Computes the LU decomposition of the square matrix `a` in place using
/// Gaussian elimination with partial (row) pivoting.
///
/// On success `a` is overwritten with the factors:
/// - the upper triangle, including the diagonal, holds `U`;
/// - the strict lower triangle holds the *negated* elimination multipliers,
///   i.e. `a[(i, k)] = -a_ik / pivot_k` for `i > k`.
///
/// At step `k` the row interchange is applied only to columns `k..n`. The
/// multipliers already stored in earlier columns are therefore not permuted,
/// so the interchanges recorded in `ip` must be replayed in order when
/// solving.
///
/// `ip[k]` receives the 0-based index of the row swapped with row `k` at
/// elimination step `k` (`ip[k] == k` means no swap). `ip[n - 1]` is always
/// set to `n - 1`, because the last row never needs an interchange.
///
/// An empty (0x0) matrix is accepted and returns `Ok(())` unchanged.
///
/// # Errors
/// - [`LinalgError::BadInput`] if `a` is not square.
/// - [`LinalgError::PivotSizeMismatch`] if `ip.len() != a.nrows()`.
/// - [`LinalgError::Singular`] if an exactly zero pivot is found. `step` is
///   the 1-based elimination step. On this error `a` and `ip` are left
///   partially updated.
pub fn lu_decomp<T: Real>(a: &mut Matrix<T>, ip: &mut [usize]) -> Result<(), LinalgError> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(LinalgError::BadInput {
            message: format!("Matrix is not square: {}x{}", n, a.ncols()),
        });
    }
    if ip.len() != n {
        return Err(LinalgError::PivotSizeMismatch {
            expected: n,
            actual: ip.len(),
        });
    }
    // Nothing to factor. Returning here also avoids the `n - 1` underflow below.
    if n == 0 {
        return Ok(());
    }
    let nm1 = n - 1;
    for k in 0..nm1 {
        let kp1 = k + 1;
        // Partial pivoting: choose the row with the largest |a[i, k]|, i >= k.
        // The strict `>` keeps the earliest row on ties.
        let mut m = k;
        let mut max_val = a[(k, k)].abs();
        for i in kp1..n {
            let val = a[(i, k)].abs();
            if val > max_val {
                max_val = val;
                m = i;
            }
        }
        ip[k] = m;
        let pivot = a[(m, k)];
        if pivot == T::zero() {
            return Err(LinalgError::Singular { step: k + 1 });
        }
        if m != k {
            a[(m, k)] = a[(k, k)];
            a[(k, k)] = pivot;
        }
        // Store the negated multipliers -a[i, k] / pivot below the diagonal.
        let t = T::one() / pivot;
        for i in kp1..n {
            a[(i, k)] = -a[(i, k)] * t;
        }
        // For each remaining column: swap rows k and m, then eliminate.
        for j in kp1..n {
            let tj = a[(m, j)];
            if m != k {
                a[(m, j)] = a[(k, j)];
                a[(k, j)] = tj;
            }
            // A zero pivot-row entry leaves column j unchanged below row k.
            if tj != T::zero() {
                for i in kp1..n {
                    a[(i, j)] = a[(i, j)] + a[(i, k)] * tj;
                }
            }
        }
    }
    // The last row never needs an interchange. Recording it keeps every entry
    // of `ip` initialised; for n == 1 this yields ip[0] == 0.
    ip[nm1] = nm1;
    if a[(nm1, nm1)] == T::zero() {
        return Err(LinalgError::Singular { step: n });
    }
    Ok(())
}
/// Computes the LU decomposition of the complex square matrix `A = ar + i*ai`
/// in place using Gaussian elimination with partial (row) pivoting.
///
/// The real and imaginary parts are kept in two separate matrices of the same
/// size. On success they are overwritten with the factors:
/// - the upper triangle, including the diagonal, holds `U`;
/// - the strict lower triangle holds the *negated* complex elimination
///   multipliers `-A[i, k] / pivot_k` for `i > k`.
///
/// Pivots are chosen by the 1-norm `|re| + |im|` rather than the modulus.
///
/// At step `k` the row interchange is applied only to columns `k..n`, so the
/// multipliers already stored in earlier columns are not permuted.
///
/// `ip[k]` receives the 0-based index of the row swapped with row `k` at
/// elimination step `k` (`ip[k] == k` means no swap). `ip[n - 1]` is always
/// set to `n - 1`.
///
/// An empty (0x0) system is accepted and returns `Ok(())` unchanged.
///
/// # Errors
/// - [`LinalgError::BadInput`] if `ar` and `ai` are not square matrices of
///   the same size.
/// - [`LinalgError::PivotSizeMismatch`] if `ip.len() != ar.nrows()`.
/// - [`LinalgError::Singular`] if an exactly zero pivot is found. `step` is
///   the 1-based elimination step. On this error the inputs are left
///   partially updated.
pub fn lu_decomp_complex<T: Real>(
    ar: &mut Matrix<T>,
    ai: &mut Matrix<T>,
    ip: &mut [usize],
) -> Result<(), LinalgError> {
    let n = ar.nrows();
    if n != ar.ncols() || n != ai.nrows() || n != ai.ncols() {
        return Err(LinalgError::BadInput {
            message: format!(
                "Matrix dimensions inconsistent: {}x{}, {}x{}",
                ar.nrows(),
                ar.ncols(),
                ai.nrows(),
                ai.ncols()
            ),
        });
    }
    if ip.len() != n {
        return Err(LinalgError::PivotSizeMismatch {
            expected: n,
            actual: ip.len(),
        });
    }
    // Nothing to factor. Returning here also avoids the `n - 1` underflow below.
    if n == 0 {
        return Ok(());
    }
    let nm1 = n - 1;
    for k in 0..nm1 {
        let kp1 = k + 1;
        // Partial pivoting on |re| + |im|. The strict `>` keeps the earliest
        // row on ties.
        let mut m = k;
        let mut max_val = ar[(k, k)].abs() + ai[(k, k)].abs();
        for i in kp1..n {
            let val = ar[(i, k)].abs() + ai[(i, k)].abs();
            if val > max_val {
                max_val = val;
                m = i;
            }
        }
        ip[k] = m;
        let piv_r = ar[(m, k)];
        let piv_i = ai[(m, k)];
        if piv_r.abs() + piv_i.abs() == T::zero() {
            return Err(LinalgError::Singular { step: k + 1 });
        }
        if m != k {
            ar[(m, k)] = ar[(k, k)];
            ai[(m, k)] = ai[(k, k)];
            ar[(k, k)] = piv_r;
            ai[(k, k)] = piv_i;
        }
        // Reciprocal of the pivot: 1 / (r + i*s) = (r - i*s) / (r^2 + s^2).
        // NOTE: `den` can overflow or underflow for pivots of extreme magnitude.
        let den = piv_r * piv_r + piv_i * piv_i;
        let tr = piv_r / den;
        let ti = -piv_i / den;
        // Store the negated multipliers -A[i, k] / pivot below the diagonal.
        for i in kp1..n {
            let prod_r = ar[(i, k)] * tr - ai[(i, k)] * ti;
            let prod_i = ai[(i, k)] * tr + ar[(i, k)] * ti;
            ar[(i, k)] = -prod_r;
            ai[(i, k)] = -prod_i;
        }
        // For each remaining column: swap rows k and m, then eliminate.
        for j in kp1..n {
            let mr = ar[(m, j)];
            let mi = ai[(m, j)];
            if m != k {
                ar[(m, j)] = ar[(k, j)];
                ai[(m, j)] = ai[(k, j)];
                ar[(k, j)] = mr;
                ai[(k, j)] = mi;
            }
            // A zero pivot-row entry leaves column j unchanged below row k.
            if mr.abs() + mi.abs() == T::zero() {
                continue;
            }
            // Cheaper update paths when the pivot-row entry is purely real or
            // purely imaginary.
            if mi == T::zero() {
                for i in kp1..n {
                    let prod_r = ar[(i, k)] * mr;
                    let prod_i = ai[(i, k)] * mr;
                    ar[(i, j)] += prod_r;
                    ai[(i, j)] += prod_i;
                }
            } else if mr == T::zero() {
                for i in kp1..n {
                    let prod_r = -ai[(i, k)] * mi;
                    let prod_i = ar[(i, k)] * mi;
                    ar[(i, j)] += prod_r;
                    ai[(i, j)] += prod_i;
                }
            } else {
                for i in kp1..n {
                    let prod_r = ar[(i, k)] * mr - ai[(i, k)] * mi;
                    let prod_i = ai[(i, k)] * mr + ar[(i, k)] * mi;
                    ar[(i, j)] += prod_r;
                    ai[(i, j)] += prod_i;
                }
            }
        }
    }
    // The last row never needs an interchange. Recording it keeps every entry
    // of `ip` initialised; for n == 1 this yields ip[0] == 0.
    ip[nm1] = nm1;
    if ar[(nm1, nm1)].abs() + ai[(nm1, nm1)].abs() == T::zero() {
        return Err(LinalgError::Singular { step: n });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for `lu_decomp` and `lu_decomp_complex`. Where exact factor
    // values are asserted, the inputs are symmetric. The expected results
    // therefore do not depend on the storage order used by `Matrix::from_vec`.
    use super::*;

    #[test]
    fn test_dec_simple() {
        let mut a = Matrix::from_vec(2, 2, vec![2.0_f64, 1.0, 4.0, 3.0]);
        let mut ip = [0; 2];
        let result = lu_decomp(&mut a, &mut ip);
        assert!(result.is_ok());
        assert!(a[(0, 0)].abs() > 1e-10);
        assert!(a[(1, 1)].abs() > 1e-10);
    }

    #[test]
    fn test_dec_pivoting_values() {
        // A = [[2, 4], [4, 3]].
        let mut a = Matrix::from_vec(2, 2, vec![2.0_f64, 4.0, 4.0, 3.0]);
        let mut ip = [0; 2];
        assert!(lu_decomp(&mut a, &mut ip).is_ok());
        // Row 1 has the larger |a[i, 0]| and becomes the pivot row.
        assert_eq!(ip[0], 1);
        assert_eq!(a[(0, 0)], 4.0);
        assert_eq!(a[(0, 1)], 3.0);
        // The strict lower triangle stores the negated multiplier -(2 / 4).
        assert_eq!(a[(1, 0)], -0.5);
        // U[1, 1] = 4 - 0.5 * 3.
        assert_eq!(a[(1, 1)], 2.5);
    }

    #[test]
    fn test_dec_singular() {
        let mut a = Matrix::from_vec(2, 2, vec![1.0_f64, 0.0, 0.0, 0.0]);
        let mut ip = [0; 2];
        let result = lu_decomp(&mut a, &mut ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LinalgError::Singular { step: 2 });
    }

    #[test]
    fn test_dec_1x1() {
        let mut a = Matrix::from_vec(1, 1, vec![5.0_f64]);
        let mut ip = [0; 1];
        let result = lu_decomp(&mut a, &mut ip);
        assert!(result.is_ok());
        assert_eq!(ip[0], 0);
    }

    #[test]
    fn test_dec_1x1_singular() {
        let mut a = Matrix::from_vec(1, 1, vec![0.0_f64]);
        let mut ip = [0; 1];
        let result = lu_decomp(&mut a, &mut ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LinalgError::Singular { step: 1 });
    }

    #[test]
    fn test_dec_not_square() {
        let mut a = Matrix::from_vec(2, 3, vec![1.0_f64; 6]);
        let mut ip = [0; 2];
        let result = lu_decomp(&mut a, &mut ip);
        assert!(matches!(result, Err(LinalgError::BadInput { .. })));
    }

    #[test]
    fn test_dec_pivot_size_mismatch() {
        let mut a = Matrix::from_vec(2, 2, vec![1.0_f64, 0.0, 0.0, 1.0]);
        let mut ip = [0; 3];
        let result = lu_decomp(&mut a, &mut ip);
        assert_eq!(
            result.unwrap_err(),
            LinalgError::PivotSizeMismatch {
                expected: 2,
                actual: 3
            }
        );
    }

    #[test]
    fn test_decc_simple() {
        let mut ar = Matrix::from_vec(2, 2, vec![1.0_f64, 0.0, 0.0, 1.0]);
        let mut ai = Matrix::from_vec(2, 2, vec![0.0, 1.0, 1.0, 0.0]);
        let mut ip = [0; 2];
        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert!(result.is_ok());
        let diag0_mag = ar[(0, 0)].abs() + ai[(0, 0)].abs();
        let diag1_mag = ar[(1, 1)].abs() + ai[(1, 1)].abs();
        assert!(diag0_mag > 1e-10);
        assert!(diag1_mag > 1e-10);
    }

    #[test]
    fn test_decc_values() {
        // A = [[1, i], [i, 1]].
        let mut ar = Matrix::from_vec(2, 2, vec![1.0_f64, 0.0, 0.0, 1.0]);
        let mut ai = Matrix::from_vec(2, 2, vec![0.0_f64, 1.0, 1.0, 0.0]);
        let mut ip = [0; 2];
        assert!(lu_decomp_complex(&mut ar, &mut ai, &mut ip).is_ok());
        // |1| + |0| == |0| + |1|: the tie keeps row 0 as the pivot row.
        assert_eq!(ip[0], 0);
        // The stored multiplier is -(i / 1) = -i.
        assert_eq!(ar[(1, 0)], 0.0);
        assert_eq!(ai[(1, 0)], -1.0);
        // U[1, 1] = 1 - i * i = 2.
        assert_eq!(ar[(1, 1)], 2.0);
        assert_eq!(ai[(1, 1)], 0.0);
    }

    #[test]
    fn test_decc_singular() {
        let mut ar = Matrix::from_vec(2, 2, vec![1.0_f64, 1.0, 1.0, 1.0]);
        let mut ai = Matrix::from_vec(2, 2, vec![0.0_f64, 0.0, 0.0, 0.0]);
        let mut ip = [0; 2];
        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LinalgError::Singular { step: 2 });
    }

    #[test]
    fn test_decc_1x1() {
        let mut ar = Matrix::from_vec(1, 1, vec![3.0_f64]);
        let mut ai = Matrix::from_vec(1, 1, vec![4.0_f64]);
        let mut ip = [0; 1];
        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert!(result.is_ok());
        assert_eq!(ip[0], 0);
    }

    #[test]
    fn test_decc_1x1_singular() {
        let mut ar = Matrix::from_vec(1, 1, vec![0.0_f64]);
        let mut ai = Matrix::from_vec(1, 1, vec![0.0_f64]);
        let mut ip = [0; 1];
        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LinalgError::Singular { step: 1 });
    }

    #[test]
    fn test_decc_dimension_mismatch() {
        let mut ar = Matrix::from_vec(2, 2, vec![1.0_f64; 4]);
        let mut ai = Matrix::from_vec(1, 1, vec![0.0_f64]);
        let mut ip = [0; 2];
        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert_eq!(
            result.unwrap_err(),
            LinalgError::BadInput {
                message: "Matrix dimensions inconsistent: 2x2, 1x1".to_string()
            }
        );
    }
}