#![cfg(feature = "lapack")]
use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use std::fmt::Debug;
pub fn lu<T>(a: &Array<T>) -> Result<(Array<T>, Array<T>, Array<usize>)>
where
T: Float + Clone + Debug + std::fmt::Display,
{
let shape = a.shape();
if shape.len() != 2 {
return Err(NumRs2Error::DimensionMismatch(
"LU decomposition requires a 2D matrix".to_string(),
));
}
let m = shape[0];
let n = shape[1];
let k = std::cmp::min(m, n);
let mut a_copy = a.clone();
let mut p = (0..m).collect::<Vec<usize>>();
let mut row_scale = vec![num_traits::Zero::zero(); m];
for i in 0..m {
let mut max_in_row = num_traits::Zero::zero();
for j in 0..n {
let abs_val = num_traits::Float::abs(a_copy.get(&[i, j])?);
if abs_val > max_in_row {
max_in_row = abs_val;
}
}
if max_in_row == <T as num_traits::Zero>::zero() {
row_scale[i] = <T as num_traits::One>::one();
} else {
row_scale[i] = <T as num_traits::One>::one() / max_in_row;
}
}
let eps = T::epsilon();
let matrix_size = <T as num_traits::NumCast>::from(std::cmp::max(m, n))
.expect("matrix dimension should convert to float type");
let tolerance = eps * matrix_size;
let mut matrix_norm = <T as num_traits::Zero>::zero();
for i in 0..m {
for j in 0..n {
let abs_val = num_traits::Float::abs(a_copy.get(&[i, j])?);
if abs_val > matrix_norm {
matrix_norm = abs_val;
}
}
}
let pivot_threshold = tolerance * matrix_norm;
let mut rank_deficient = false;
let mut num_small_pivots = 0;
for i in 0..k {
let mut p_row = i;
let mut p_val = num_traits::Float::abs(a_copy.get(&[i, i])?) * row_scale[i];
for j in (i + 1)..m {
let val = num_traits::Float::abs(a_copy.get(&[j, i])?) * row_scale[j];
if val > p_val {
p_row = j;
p_val = val;
}
}
if p_val < pivot_threshold {
rank_deficient = true;
num_small_pivots += 1;
}
if p_row != i {
p.swap(i, p_row);
row_scale.swap(i, p_row);
for j in 0..n {
let temp = a_copy.get(&[i, j])?;
a_copy.set(&[i, j], a_copy.get(&[p_row, j])?)?;
a_copy.set(&[p_row, j], temp)?;
}
}
let pivot = a_copy.get(&[i, i])?;
let abs_pivot = num_traits::Float::abs(pivot);
if abs_pivot < pivot_threshold {
let small_pivot_magnitude = pivot_threshold
* <T as num_traits::NumCast>::from(10.0)
.unwrap_or_else(|| <T as num_traits::One>::one());
let small_pivot = if pivot >= <T as num_traits::Zero>::zero()
|| pivot == <T as num_traits::Zero>::zero()
{
small_pivot_magnitude
} else {
-small_pivot_magnitude
};
a_copy.set(&[i, i], small_pivot)?;
}
for j in (i + 1)..m {
let pivot = a_copy.get(&[i, i])?;
let factor = a_copy.get(&[j, i])? / pivot;
a_copy.set(&[j, i], factor)?;
for l in (i + 1)..n {
let a_jl = a_copy.get(&[j, l])?;
let a_il = a_copy.get(&[i, l])?;
let prod = factor * a_il;
let new_val = a_jl - prod;
a_copy.set(&[j, l], new_val)?;
}
}
}
let mut l = Array::zeros(&[m, k]);
let mut u = Array::zeros(&[k, n]);
for i in 0..k {
l.set(&[i, i], num_traits::One::one())?;
}
for i in 1..m {
for j in 0..std::cmp::min(i, k) {
l.set(&[i, j], a_copy.get(&[i, j])?)?;
}
}
for i in 0..k {
for j in i..n {
u.set(&[i, j], a_copy.get(&[i, j])?)?;
}
}
let piv_array = Array::from_vec(p.clone());
#[cfg(feature = "validation")]
{
let mut pa = Array::zeros(&[m, n]);
for i in 0..m {
for j in 0..n {
pa.set(&[i, j], a.get(&[p[i], j])?)?;
}
}
let lu_product = l.matmul(&u)?;
let mut max_diff = <T as num_traits::Zero>::zero();
for i in 0..m {
for j in 0..n {
let diff = num_traits::Float::abs(pa.get(&[i, j])? - lu_product.get(&[i, j])?);
if diff > max_diff {
max_diff = diff;
}
}
}
let acceptable_error = eps * matrix_norm * matrix_size;
if max_diff > acceptable_error {
eprintln!(
"Warning: LU decomposition may be numerically unstable. Max difference: {}",
max_diff
);
}
}
if rank_deficient {
eprintln!("Warning: Matrix appears to be rank deficient or ill-conditioned. {} small pivots detected.", num_small_pivots);
}
Ok((l, u, piv_array))
}