mathhook-core 0.2.0

Core mathematical engine for MathHook - expressions, algebra, and solving
Documentation
//! LU decomposition algorithms
//!
//! This module provides LU decomposition with partial pivoting for solving
//! linear systems and computing matrix properties.

use crate::core::Expression;
use crate::matrices::types::*;
use crate::matrices::unified::Matrix;
use crate::simplify::Simplify;

/// LU decomposition implementation
impl Matrix {
    /// Perform LU decomposition with partial pivoting
    ///
    /// Decomposes matrix A into PA = LU where:
    /// - P is a permutation matrix
    /// - L is lower triangular with 1s on diagonal
    /// - U is upper triangular
    ///
    /// # Examples
    ///
    /// ```
    /// use mathhook_core::matrices::Matrix;
    ///
    /// let matrix = Matrix::from_arrays([
    ///     [2, 1],
    ///     [4, 3]
    /// ]);
    ///
    /// let lu = matrix.lu_decomposition().unwrap();
    /// assert!(lu.p.is_some());
    /// ```
    pub fn lu_decomposition(&self) -> Option<LUDecomposition> {
        let (rows, cols) = self.dimensions();
        if rows != cols {
            return None; // LU decomposition requires square matrix
        }

        // Handle special cases efficiently
        match self {
            Matrix::Identity(data) => Some(LUDecomposition {
                l: Matrix::identity(data.size),
                u: Matrix::identity(data.size),
                p: Some(Matrix::identity(data.size)),
            }),
            Matrix::Zero(_) => Some(LUDecomposition {
                l: Matrix::identity(rows),
                u: Matrix::zero(rows, cols),
                p: Some(Matrix::identity(rows)),
            }),
            Matrix::Diagonal(_) => Some(LUDecomposition {
                l: Matrix::identity(rows),
                u: self.clone(),
                p: Some(Matrix::identity(rows)),
            }),
            Matrix::UpperTriangular(_) => Some(LUDecomposition {
                l: Matrix::identity(rows),
                u: self.clone(),
                p: Some(Matrix::identity(rows)),
            }),
            _ => {
                // General LU decomposition using Gaussian elimination with partial pivoting
                self.general_lu_decomposition()
            }
        }
    }

    /// General LU decomposition implementation using Gaussian elimination
    fn general_lu_decomposition(&self) -> Option<LUDecomposition> {
        let (n, _) = self.dimensions();

        // Convert to dense matrix for computation
        let mut a = self.to_dense_matrix();
        let mut p = Matrix::identity(n);

        // Perform Gaussian elimination with partial pivoting
        for k in 0..n {
            // Find pivot
            let mut pivot_row = k;
            for i in (k + 1)..n {
                let current_elem = a.get_element(i, k);
                let pivot_elem = a.get_element(pivot_row, k);

                // Simplified pivot selection (proper implementation would compare absolute values)
                if !current_elem.is_zero() && pivot_elem.is_zero() {
                    pivot_row = i;
                }
            }

            // Swap rows if needed
            if pivot_row != k {
                a = a.swap_rows(k, pivot_row);
                p = p.swap_rows(k, pivot_row);
            }

            // Check for zero pivot
            let pivot = a.get_element(k, k);
            if pivot.is_zero() {
                continue; // Skip if pivot is zero
            }

            // Eliminate below pivot
            for i in (k + 1)..n {
                // Use canonical form for division: a / b = a * b^(-1)
                let factor = Expression::mul(vec![
                    a.get_element(i, k),
                    Expression::pow(pivot.clone(), Expression::integer(-1)),
                ])
                .simplify();

                // Store multiplier in L (lower triangle of a)
                a = a.set_element(i, k, &factor);

                // Update row i: row_i = row_i - factor * row_k
                for j in (k + 1)..n {
                    let old_val = a.get_element(i, j);
                    let pivot_val = a.get_element(k, j);
                    let new_val = Expression::add(vec![
                        old_val,
                        Expression::mul(vec![Expression::integer(-1), factor.clone(), pivot_val]),
                    ])
                    .simplify();

                    a = a.set_element(i, j, &new_val);
                }
            }
        }

        // Extract L and U matrices
        let mut l_elements = Vec::new();
        let mut u_elements = Vec::new();

        for i in 0..n {
            let mut l_row = Vec::new();
            let mut u_row = Vec::new();

            for j in 0..n {
                if i > j {
                    // Lower triangular part
                    l_row.push(a.get_element(i, j));
                    u_row.push(Expression::integer(0));
                } else if i == j {
                    // Diagonal
                    l_row.push(Expression::integer(1));
                    u_row.push(a.get_element(i, j));
                } else {
                    // Upper triangular part
                    l_row.push(Expression::integer(0));
                    u_row.push(a.get_element(i, j));
                }
            }
            l_elements.push(l_row);
            u_elements.push(u_row);
        }

        Some(LUDecomposition {
            l: Matrix::dense(l_elements),
            u: Matrix::dense(u_elements),
            p: Some(p),
        })
    }
}