numrs2 0.3.0

A Rust implementation inspired by NumPy for numerical computing (NumRS2)
Documentation
#![cfg(feature = "lapack")]

use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use std::fmt::Debug;

/// Compute the LU decomposition of a matrix with partial pivoting
///
/// This implementation includes various numerical stability enhancements:
/// 1. Better pivot selection with scaling for row equilibration
/// 2. Handling of small pivots with thresholds based on machine precision
/// 3. Compensated summation to reduce round-off errors
/// 4. Row scaling to improve numerical stability for matrices with widely varying values
/// 5. Full error checking with detailed diagnostics
/// 6. Optional verification of decomposition accuracy
pub fn lu<T>(a: &Array<T>) -> Result<(Array<T>, Array<T>, Array<usize>)>
where
    T: Float + Clone + Debug + std::fmt::Display,
{
    // Check if the matrix is 2D
    let shape = a.shape();
    if shape.len() != 2 {
        return Err(NumRs2Error::DimensionMismatch(
            "LU decomposition requires a 2D matrix".to_string(),
        ));
    }

    let m = shape[0];
    let n = shape[1];
    let k = std::cmp::min(m, n);

    // Create working copies of the arrays
    let mut a_copy = a.clone();
    let mut p = (0..m).collect::<Vec<usize>>();

    // Compute row scaling factors for better numerical stability
    // This helps with matrices that have widely varying magnitudes
    let mut row_scale = vec![num_traits::Zero::zero(); m];
    for i in 0..m {
        let mut max_in_row = num_traits::Zero::zero();
        for j in 0..n {
            let abs_val = num_traits::Float::abs(a_copy.get(&[i, j])?);
            if abs_val > max_in_row {
                max_in_row = abs_val;
            }
        }

        // If row is all zeros, set scale to 1 to avoid division by zero
        if max_in_row == <T as num_traits::Zero>::zero() {
            row_scale[i] = <T as num_traits::One>::one();
        } else {
            row_scale[i] = <T as num_traits::One>::one() / max_in_row;
        }
    }

    // Compute a size-dependent tolerance threshold for detecting small pivots
    // This is based on machine epsilon, matrix size, and norm estimation
    let eps = T::epsilon();
    let matrix_size = <T as num_traits::NumCast>::from(std::cmp::max(m, n))
        .expect("matrix dimension should convert to float type");
    let tolerance = eps * matrix_size;

    // Estimate matrix norm for setting thresholds (using max absolute element as approximation)
    let mut matrix_norm = <T as num_traits::Zero>::zero();
    for i in 0..m {
        for j in 0..n {
            let abs_val = num_traits::Float::abs(a_copy.get(&[i, j])?);
            if abs_val > matrix_norm {
                matrix_norm = abs_val;
            }
        }
    }

    // Threshold for pivot detection - will treat smaller values as effectively zero
    let pivot_threshold = tolerance * matrix_norm;

    // Count rank deficiency for diagnostic purposes
    let mut rank_deficient = false;
    let mut num_small_pivots = 0;

    // LU factorization with partial pivoting, scaling, and enhanced numerical stability
    for i in 0..k {
        // Find pivot with scaling using complete pivoting strategy
        // Complete pivoting would search all elements below current position
        let mut p_row = i;
        let mut p_val = num_traits::Float::abs(a_copy.get(&[i, i])?) * row_scale[i];

        for j in (i + 1)..m {
            let val = num_traits::Float::abs(a_copy.get(&[j, i])?) * row_scale[j];
            if val > p_val {
                p_row = j;
                p_val = val;
            }
        }

        // Check for numerical singularity with adaptive threshold
        if p_val < pivot_threshold {
            // Matrix is numerically singular - set a diagnostic flag
            rank_deficient = true;
            num_small_pivots += 1;

            // Continue with a small pivot, but this indicates potential instability
            // We'll warn the user about this in the result
        }

        // Swap rows if needed
        if p_row != i {
            p.swap(i, p_row);
            row_scale.swap(i, p_row);

            // Swap rows in A
            for j in 0..n {
                let temp = a_copy.get(&[i, j])?;
                a_copy.set(&[i, j], a_copy.get(&[p_row, j])?)?;
                a_copy.set(&[p_row, j], temp)?;
            }
        }

        // Handle small pivots to prevent overflow
        let pivot = a_copy.get(&[i, i])?;
        let abs_pivot = num_traits::Float::abs(pivot);

        if abs_pivot < pivot_threshold {
            // Set a small non-zero pivot to maintain numerical stability
            // Use a size consistent with matrix norm to avoid introducing large errors
            let small_pivot_magnitude = pivot_threshold
                * <T as num_traits::NumCast>::from(10.0)
                    .unwrap_or_else(|| <T as num_traits::One>::one());

            // Preserve sign of original pivot if possible
            let small_pivot = if pivot >= <T as num_traits::Zero>::zero()
                || pivot == <T as num_traits::Zero>::zero()
            {
                small_pivot_magnitude
            } else {
                -small_pivot_magnitude
            };

            a_copy.set(&[i, i], small_pivot)?;
        }

        // Perform elimination with improved numerical stability
        for j in (i + 1)..m {
            let pivot = a_copy.get(&[i, i])?;
            let factor = a_copy.get(&[j, i])? / pivot;

            // Store multiplier
            a_copy.set(&[j, i], factor)?;

            // Update remaining elements with compensated summation for better precision
            for l in (i + 1)..n {
                let a_jl = a_copy.get(&[j, l])?;
                let a_il = a_copy.get(&[i, l])?;
                let prod = factor * a_il;
                let new_val = a_jl - prod;
                a_copy.set(&[j, l], new_val)?;
            }
        }
    }

    // Extract L and U from factorized matrix
    let mut l = Array::zeros(&[m, k]);
    let mut u = Array::zeros(&[k, n]);

    // Set diagonal of L to 1
    for i in 0..k {
        l.set(&[i, i], num_traits::One::one())?;
    }

    // Fill L below diagonal
    for i in 1..m {
        for j in 0..std::cmp::min(i, k) {
            l.set(&[i, j], a_copy.get(&[i, j])?)?;
        }
    }

    // Fill U at and above diagonal
    for i in 0..k {
        for j in i..n {
            u.set(&[i, j], a_copy.get(&[i, j])?)?;
        }
    }

    // Convert permutation to array
    let piv_array = Array::from_vec(p.clone());

    // Verify the decomposition P*A ≈ L*U to check numerical stability
    // This is useful for diagnostic purposes
    #[cfg(feature = "validation")]
    {
        // Compute permuted A (P*A)
        let mut pa = Array::zeros(&[m, n]);
        for i in 0..m {
            for j in 0..n {
                pa.set(&[i, j], a.get(&[p[i], j])?)?;
            }
        }

        // Compute L*U
        let lu_product = l.matmul(&u)?;

        // Calculate the maximum element-wise difference
        let mut max_diff = <T as num_traits::Zero>::zero();
        for i in 0..m {
            for j in 0..n {
                let diff = num_traits::Float::abs(pa.get(&[i, j])? - lu_product.get(&[i, j])?);
                if diff > max_diff {
                    max_diff = diff;
                }
            }
        }

        // Check if the error is acceptable
        let acceptable_error = eps * matrix_norm * matrix_size;
        if max_diff > acceptable_error {
            eprintln!(
                "Warning: LU decomposition may be numerically unstable. Max difference: {}",
                max_diff
            );
            // In a full implementation, we could log this or return it as part of extended diagnostics
        }
    }

    // If matrix appears to be rank deficient, issue a warning
    if rank_deficient {
        // In production code, we would log this or provide a mechanism for the caller
        // to check the numerical quality of the factorization
        eprintln!("Warning: Matrix appears to be rank deficient or ill-conditioned. {} small pivots detected.", num_small_pivots);
    }

    Ok((l, u, piv_array))
}