mathhook_core/matrices/decomposition/lu.rs
//! LU decomposition algorithms
//!
//! This module provides LU decomposition (`PA = LU`) with row pivoting (rows are
//! swapped to avoid zero pivots) for solving linear systems and computing matrix
//! properties.
5
6use crate::core::Expression;
7use crate::matrices::types::*;
8use crate::matrices::unified::Matrix;
9use crate::simplify::Simplify;
10
11/// LU decomposition implementation
12impl Matrix {
13 /// Perform LU decomposition with partial pivoting
14 ///
15 /// Decomposes matrix A into PA = LU where:
16 /// - P is a permutation matrix
17 /// - L is lower triangular with 1s on diagonal
18 /// - U is upper triangular
19 ///
20 /// # Examples
21 ///
22 /// ```
23 /// use mathhook_core::matrices::Matrix;
24 ///
25 /// let matrix = Matrix::from_arrays([
26 /// [2, 1],
27 /// [4, 3]
28 /// ]);
29 ///
30 /// let lu = matrix.lu_decomposition().unwrap();
31 /// assert!(lu.p.is_some());
32 /// ```
33 pub fn lu_decomposition(&self) -> Option<LUDecomposition> {
34 let (rows, cols) = self.dimensions();
35 if rows != cols {
36 return None; // LU decomposition requires square matrix
37 }
38
39 // Handle special cases efficiently
40 match self {
41 Matrix::Identity(data) => Some(LUDecomposition {
42 l: Matrix::identity(data.size),
43 u: Matrix::identity(data.size),
44 p: Some(Matrix::identity(data.size)),
45 }),
46 Matrix::Zero(_) => Some(LUDecomposition {
47 l: Matrix::identity(rows),
48 u: Matrix::zero(rows, cols),
49 p: Some(Matrix::identity(rows)),
50 }),
51 Matrix::Diagonal(_) => Some(LUDecomposition {
52 l: Matrix::identity(rows),
53 u: self.clone(),
54 p: Some(Matrix::identity(rows)),
55 }),
56 Matrix::UpperTriangular(_) => Some(LUDecomposition {
57 l: Matrix::identity(rows),
58 u: self.clone(),
59 p: Some(Matrix::identity(rows)),
60 }),
61 _ => {
62 // General LU decomposition using Gaussian elimination with partial pivoting
63 self.general_lu_decomposition()
64 }
65 }
66 }
67
68 /// General LU decomposition implementation using Gaussian elimination
69 fn general_lu_decomposition(&self) -> Option<LUDecomposition> {
70 let (n, _) = self.dimensions();
71
72 // Convert to dense matrix for computation
73 let mut a = self.to_dense_matrix();
74 let mut p = Matrix::identity(n);
75
76 // Perform Gaussian elimination with partial pivoting
77 for k in 0..n {
78 // Find pivot
79 let mut pivot_row = k;
80 for i in (k + 1)..n {
81 let current_elem = a.get_element(i, k);
82 let pivot_elem = a.get_element(pivot_row, k);
83
84 // Simplified pivot selection (proper implementation would compare absolute values)
85 if !current_elem.is_zero() && pivot_elem.is_zero() {
86 pivot_row = i;
87 }
88 }
89
90 // Swap rows if needed
91 if pivot_row != k {
92 a = a.swap_rows(k, pivot_row);
93 p = p.swap_rows(k, pivot_row);
94 }
95
96 // Check for zero pivot
97 let pivot = a.get_element(k, k);
98 if pivot.is_zero() {
99 continue; // Skip if pivot is zero
100 }
101
102 // Eliminate below pivot
103 for i in (k + 1)..n {
104 // Use canonical form for division: a / b = a * b^(-1)
105 let factor = Expression::mul(vec![
106 a.get_element(i, k),
107 Expression::pow(pivot.clone(), Expression::integer(-1)),
108 ])
109 .simplify();
110
111 // Store multiplier in L (lower triangle of a)
112 a = a.set_element(i, k, &factor);
113
114 // Update row i: row_i = row_i - factor * row_k
115 for j in (k + 1)..n {
116 let old_val = a.get_element(i, j);
117 let pivot_val = a.get_element(k, j);
118 let new_val = Expression::add(vec![
119 old_val,
120 Expression::mul(vec![Expression::integer(-1), factor.clone(), pivot_val]),
121 ])
122 .simplify();
123
124 a = a.set_element(i, j, &new_val);
125 }
126 }
127 }
128
129 // Extract L and U matrices
130 let mut l_elements = Vec::new();
131 let mut u_elements = Vec::new();
132
133 for i in 0..n {
134 let mut l_row = Vec::new();
135 let mut u_row = Vec::new();
136
137 for j in 0..n {
138 if i > j {
139 // Lower triangular part
140 l_row.push(a.get_element(i, j));
141 u_row.push(Expression::integer(0));
142 } else if i == j {
143 // Diagonal
144 l_row.push(Expression::integer(1));
145 u_row.push(a.get_element(i, j));
146 } else {
147 // Upper triangular part
148 l_row.push(Expression::integer(0));
149 u_row.push(a.get_element(i, j));
150 }
151 }
152 l_elements.push(l_row);
153 u_elements.push(u_row);
154 }
155
156 Some(LUDecomposition {
157 l: Matrix::dense(l_elements),
158 u: Matrix::dense(u_elements),
159 p: Some(p),
160 })
161 }
162}