mathhook_core/matrices/decomposition/
lu.rs

//! LU decomposition algorithms
//!
//! This module provides LU decomposition with partial pivoting for solving
//! linear systems and computing matrix properties.
5
6use crate::core::Expression;
7use crate::matrices::types::*;
8use crate::matrices::unified::Matrix;
9use crate::simplify::Simplify;
10
11/// LU decomposition implementation
12impl Matrix {
13    /// Perform LU decomposition with partial pivoting
14    ///
15    /// Decomposes matrix A into PA = LU where:
16    /// - P is a permutation matrix
17    /// - L is lower triangular with 1s on diagonal
18    /// - U is upper triangular
19    ///
20    /// # Examples
21    ///
22    /// ```
23    /// use mathhook_core::matrices::Matrix;
24    ///
25    /// let matrix = Matrix::from_arrays([
26    ///     [2, 1],
27    ///     [4, 3]
28    /// ]);
29    ///
30    /// let lu = matrix.lu_decomposition().unwrap();
31    /// assert!(lu.p.is_some());
32    /// ```
33    pub fn lu_decomposition(&self) -> Option<LUDecomposition> {
34        let (rows, cols) = self.dimensions();
35        if rows != cols {
36            return None; // LU decomposition requires square matrix
37        }
38
39        // Handle special cases efficiently
40        match self {
41            Matrix::Identity(data) => Some(LUDecomposition {
42                l: Matrix::identity(data.size),
43                u: Matrix::identity(data.size),
44                p: Some(Matrix::identity(data.size)),
45            }),
46            Matrix::Zero(_) => Some(LUDecomposition {
47                l: Matrix::identity(rows),
48                u: Matrix::zero(rows, cols),
49                p: Some(Matrix::identity(rows)),
50            }),
51            Matrix::Diagonal(_) => Some(LUDecomposition {
52                l: Matrix::identity(rows),
53                u: self.clone(),
54                p: Some(Matrix::identity(rows)),
55            }),
56            Matrix::UpperTriangular(_) => Some(LUDecomposition {
57                l: Matrix::identity(rows),
58                u: self.clone(),
59                p: Some(Matrix::identity(rows)),
60            }),
61            _ => {
62                // General LU decomposition using Gaussian elimination with partial pivoting
63                self.general_lu_decomposition()
64            }
65        }
66    }
67
68    /// General LU decomposition implementation using Gaussian elimination
69    fn general_lu_decomposition(&self) -> Option<LUDecomposition> {
70        let (n, _) = self.dimensions();
71
72        // Convert to dense matrix for computation
73        let mut a = self.to_dense_matrix();
74        let mut p = Matrix::identity(n);
75
76        // Perform Gaussian elimination with partial pivoting
77        for k in 0..n {
78            // Find pivot
79            let mut pivot_row = k;
80            for i in (k + 1)..n {
81                let current_elem = a.get_element(i, k);
82                let pivot_elem = a.get_element(pivot_row, k);
83
84                // Simplified pivot selection (proper implementation would compare absolute values)
85                if !current_elem.is_zero() && pivot_elem.is_zero() {
86                    pivot_row = i;
87                }
88            }
89
90            // Swap rows if needed
91            if pivot_row != k {
92                a = a.swap_rows(k, pivot_row);
93                p = p.swap_rows(k, pivot_row);
94            }
95
96            // Check for zero pivot
97            let pivot = a.get_element(k, k);
98            if pivot.is_zero() {
99                continue; // Skip if pivot is zero
100            }
101
102            // Eliminate below pivot
103            for i in (k + 1)..n {
104                // Use canonical form for division: a / b = a * b^(-1)
105                let factor = Expression::mul(vec![
106                    a.get_element(i, k),
107                    Expression::pow(pivot.clone(), Expression::integer(-1)),
108                ])
109                .simplify();
110
111                // Store multiplier in L (lower triangle of a)
112                a = a.set_element(i, k, &factor);
113
114                // Update row i: row_i = row_i - factor * row_k
115                for j in (k + 1)..n {
116                    let old_val = a.get_element(i, j);
117                    let pivot_val = a.get_element(k, j);
118                    let new_val = Expression::add(vec![
119                        old_val,
120                        Expression::mul(vec![Expression::integer(-1), factor.clone(), pivot_val]),
121                    ])
122                    .simplify();
123
124                    a = a.set_element(i, j, &new_val);
125                }
126            }
127        }
128
129        // Extract L and U matrices
130        let mut l_elements = Vec::new();
131        let mut u_elements = Vec::new();
132
133        for i in 0..n {
134            let mut l_row = Vec::new();
135            let mut u_row = Vec::new();
136
137            for j in 0..n {
138                if i > j {
139                    // Lower triangular part
140                    l_row.push(a.get_element(i, j));
141                    u_row.push(Expression::integer(0));
142                } else if i == j {
143                    // Diagonal
144                    l_row.push(Expression::integer(1));
145                    u_row.push(a.get_element(i, j));
146                } else {
147                    // Upper triangular part
148                    l_row.push(Expression::integer(0));
149                    u_row.push(a.get_element(i, j));
150                }
151            }
152            l_elements.push(l_row);
153            u_elements.push(u_row);
154        }
155
156        Some(LUDecomposition {
157            l: Matrix::dense(l_elements),
158            u: Matrix::dense(u_elements),
159            p: Some(p),
160        })
161    }
162}