mathhook_core/matrices/unified/solvers.rs
1//! Matrix linear system solvers
2//!
3//! Provides methods for solving Ax = b using LU, Cholesky, and QR decompositions.
4
5use crate::core::Expression;
6use crate::error::MathError;
7use crate::matrices::types::MatrixData;
8use crate::matrices::unified::operations::CoreMatrixOps;
9use crate::matrices::unified::Matrix;
10
11impl Matrix {
12 /// Solve Lx = b for lower triangular L using forward substitution
13 ///
14 /// # Arguments
15 /// * `b` - Right-hand side vector
16 ///
17 /// # Returns
18 /// Solution vector x
19 ///
20 /// # Errors
21 /// * `DivisionByZero` if any diagonal element is zero
22 /// * `DomainError` if dimensions don't match
23 ///
24 /// # Algorithm
25 /// For i = 0 to n-1:
26 /// `x[i]` = (`b[i]` - Σ(`L[i][j]` * `x[j]`) for `j` < i) / `L[i][i]`
27 pub fn forward_substitution(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
28 let (rows, cols) = self.dimensions();
29
30 if rows != cols {
31 return Err(MathError::DomainError {
32 operation: "forward_substitution".to_string(),
33 value: Expression::function("matrix", vec![]),
34 reason: format!(
35 "Forward substitution requires square matrix, got {}x{}",
36 rows, cols
37 ),
38 });
39 }
40
41 if b.len() != rows {
42 return Err(MathError::DomainError {
43 operation: "forward_substitution".to_string(),
44 value: Expression::function("vector", vec![]),
45 reason: format!(
46 "Dimension mismatch: matrix is {}x{} but b has {} elements",
47 rows,
48 cols,
49 b.len()
50 ),
51 });
52 }
53
54 let mut x = vec![Expression::integer(0); rows];
55
56 for i in 0..rows {
57 // Accumulate sum = Σ L[i][j] * x[j] for j < i
58 let mut terms: Vec<Expression> = Vec::new();
59 for (j, xj) in x.iter().enumerate().take(i) {
60 let lij = self.get_element(i, j);
61 // Use is_zero_fast() - avoids simplify() in hot loop
62 if !lij.is_zero_fast() && !xj.is_zero_fast() {
63 terms.push(Expression::mul(vec![lij, xj.clone()]));
64 }
65 }
66
67 let lii = self.get_element(i, i);
68 // Use is_zero_fast() - pivot elements should already be simplified
69 if lii.is_zero_fast() {
70 return Err(MathError::DivisionByZero);
71 }
72
73 // x[i] = (b[i] - sum) / L[i][i]
74 // Note: Expression::add() and operators already simplify, no need for .simplify()
75 let numerator = if terms.is_empty() {
76 b[i].clone()
77 } else {
78 let sum = Expression::add(terms);
79 b[i].clone() - sum // Operator already simplifies
80 };
81
82 // Compute x[i] = numerator / L[i][i]
83 // Directly compute integer/integer to produce clean results
84 x[i] = if lii == Expression::integer(1) {
85 numerator
86 } else {
87 // Try to compute integer division directly for clean results
88 match (&numerator, &lii) {
89 (
90 Expression::Number(crate::core::Number::Integer(num)),
91 Expression::Number(crate::core::Number::Integer(den)),
92 ) => {
93 if *den != 0 && num % den == 0 {
94 Expression::integer(num / den)
95 } else if *den != 0 {
96 // Create rational for non-exact division
97 use num_bigint::BigInt;
98 use num_rational::BigRational;
99 Expression::Number(crate::core::Number::rational(BigRational::new(
100 BigInt::from(*num),
101 BigInt::from(*den),
102 )))
103 } else {
104 Expression::mul(vec![
105 numerator,
106 Expression::pow(lii, Expression::integer(-1)),
107 ])
108 }
109 }
110 _ => Expression::mul(vec![
111 numerator,
112 Expression::pow(lii, Expression::integer(-1)),
113 ]),
114 }
115 };
116 }
117
118 Ok(x)
119 }
120
121 /// Solve Ux = b for upper triangular U using backward substitution
122 ///
123 /// # Arguments
124 /// * `b` - Right-hand side vector
125 ///
126 /// # Returns
127 /// Solution vector x
128 ///
129 /// # Errors
130 /// * `DivisionByZero` if any diagonal element is zero
131 /// * `DomainError` if dimensions don't match
132 ///
133 /// For i = n-1 down to 0:
134 /// `x[i]` = (`b[i]` - Σ(`U[i][j]` * `x[j]`) for j > i) / `U[i][i]`
135 pub fn backward_substitution(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
136 let (rows, cols) = self.dimensions();
137
138 if rows != cols {
139 return Err(MathError::DomainError {
140 operation: "backward_substitution".to_string(),
141 value: Expression::function("matrix", vec![]),
142 reason: format!(
143 "Backward substitution requires square matrix, got {}x{}",
144 rows, cols
145 ),
146 });
147 }
148
149 if b.len() != rows {
150 return Err(MathError::DomainError {
151 operation: "backward_substitution".to_string(),
152 value: Expression::function("vector", vec![]),
153 reason: format!(
154 "Dimension mismatch: matrix is {}x{} but b has {} elements",
155 rows,
156 cols,
157 b.len()
158 ),
159 });
160 }
161 let mut x = vec![Expression::integer(0); rows];
162
163 for i in (0..rows).rev() {
164 // Accumulate sum = Σ U[i][j] * x[j] for j > i
165 let mut terms: Vec<Expression> = Vec::new();
166 for (j, xj) in x.iter().enumerate().skip(i + 1) {
167 let uij = self.get_element(i, j);
168 // Use is_zero_fast() - avoids simplify() in hot loop
169 if !uij.is_zero_fast() && !xj.is_zero_fast() {
170 terms.push(Expression::mul(vec![uij, xj.clone()]));
171 }
172 }
173
174 let uii = self.get_element(i, i);
175 // Use is_zero_fast() - pivot elements should already be simplified
176 if uii.is_zero_fast() {
177 return Err(MathError::DivisionByZero);
178 }
179
180 // x[i] = (b[i] - sum) / U[i][i]
181 // Note: Expression::add() and operators already simplify, no need for .simplify()
182 let numerator = if terms.is_empty() {
183 b[i].clone()
184 } else {
185 let sum = Expression::add(terms);
186 b[i].clone() - sum // Operator already simplifies
187 };
188
189 // Compute x[i] = numerator / U[i][i]
190 // Directly compute integer/integer to produce clean results
191 x[i] = if uii == Expression::integer(1) {
192 numerator
193 } else {
194 // Try to compute integer division directly for clean results
195 match (&numerator, &uii) {
196 (
197 Expression::Number(crate::core::Number::Integer(num)),
198 Expression::Number(crate::core::Number::Integer(den)),
199 ) => {
200 if *den != 0 && num % den == 0 {
201 Expression::integer(num / den)
202 } else if *den != 0 {
203 // Create rational for non-exact division
204 use num_bigint::BigInt;
205 use num_rational::BigRational;
206 Expression::Number(crate::core::Number::rational(BigRational::new(
207 BigInt::from(*num),
208 BigInt::from(*den),
209 )))
210 } else {
211 Expression::mul(vec![
212 numerator,
213 Expression::pow(uii, Expression::integer(-1)),
214 ])
215 }
216 }
217 _ => Expression::mul(vec![
218 numerator,
219 Expression::pow(uii, Expression::integer(-1)),
220 ]),
221 }
222 };
223 }
224
225 Ok(x)
226 }
227
228 /// Solve Ax = b using optimal decomposition
229 ///
230 /// # Arguments
231 /// * `b` - Right-hand side vector
232 ///
233 /// # Returns
234 /// Solution vector x
235 ///
236 /// # Errors
237 /// * `DomainError` if matrix is not square or dimensions don't match
238 /// * `DivisionByZero` if matrix is singular
239 ///
240 /// # Algorithm Selection
241 /// - Symmetric positive definite matrices: Cholesky (LL^T), ~2x faster
242 /// - General square matrices: LU decomposition with partial pivoting
243 ///
244 /// # Examples
245 /// ```
246 /// use mathhook_core::matrices::Matrix;
247 /// use mathhook_core::expr;
248 ///
249 /// let a = Matrix::from_arrays([[2, 1], [1, 3]]);
250 /// let b = vec![expr!(5), expr!(7)];
251 /// let x = a.solve(&b).unwrap();
252 /// ```
253 pub fn solve(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
254 let (rows, cols) = self.dimensions();
255
256 if rows != cols {
257 return Err(MathError::DomainError {
258 operation: "solve".to_string(),
259 value: Expression::function("matrix", vec![]),
260 reason: format!("Solve requires square matrix, got {}x{}", rows, cols),
261 });
262 }
263
264 if b.len() != rows {
265 return Err(MathError::DomainError {
266 operation: "solve".to_string(),
267 value: Expression::function("vector", vec![]),
268 reason: format!(
269 "Dimension mismatch: matrix is {}x{} but b has {} elements",
270 rows,
271 cols,
272 b.len()
273 ),
274 });
275 }
276
277 // Try Cholesky for symmetric matrices (2x faster for SPD)
278 if self.is_symmetric() {
279 if let Some(chol) = self.cholesky_decomposition() {
280 // Solve LL^T x = b
281 // Step 1: Ly = b (forward substitution)
282 let y = chol.l.forward_substitution(b)?;
283 // Step 2: L^T x = y (backward substitution on L transpose)
284 let lt = chol.l.transpose();
285 return lt.backward_substitution(&y);
286 }
287 // Fall through to LU if Cholesky fails (not positive definite)
288 }
289
290 // General case: LU decomposition with partial pivoting
291 self.solve_via_lu(b)
292 }
293
294 /// Solve Ax = b using LU decomposition
295 ///
296 /// This is the fallback solver for non-SPD matrices.
297 fn solve_via_lu(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
298 let lu = self.lu_decomposition().ok_or(MathError::DivisionByZero)?;
299
300 let pb = apply_permutation(&lu.p, b);
301
302 let y = lu.l.forward_substitution(&pb)?;
303
304 let x = lu.u.backward_substitution(&y)?;
305
306 Ok(x)
307 }
308
309 /// Solve least squares problem: min ||Ax - b||₂ using QR decomposition
310 ///
311 /// # Arguments
312 /// * `b` - Right-hand side vector
313 ///
314 /// # Returns
315 /// Solution vector x that minimizes ||Ax - b||₂
316 ///
317 /// # Errors
318 /// * `DomainError` if dimensions don't match or m < n
319 /// * `DivisionByZero` if R has zero diagonal elements
320 ///
321 /// # Algorithm
322 /// For m×n matrix A (m >= n):
323 /// 1. Compute A = QR (Q is m×n, R is n×n upper triangular)
324 /// 2. Compute c = Q^T * b
325 /// 3. Solve Rx = c`[0:n]` using backward substitution
326 ///
327 /// # Examples
328 /// ```
329 /// use mathhook_core::matrices::Matrix;
330 /// use mathhook_core::expr;
331 ///
332 /// // Overdetermined system: 3 equations, 2 unknowns
333 /// let a = Matrix::from_arrays([[1, 0], [0, 1], [1, 1]]);
334 /// let b = vec![expr!(1), expr!(2), expr!(2)];
335 /// let x = a.solve_least_squares(&b).unwrap();
336 /// ```
337 pub fn solve_least_squares(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
338 let (rows, cols) = self.dimensions();
339
340 if rows < cols {
341 return Err(MathError::DomainError {
342 operation: "solve_least_squares".to_string(),
343 value: Expression::function("matrix", vec![]),
344 reason: format!(
345 "Least squares requires m >= n (overdetermined), got {}x{}",
346 rows, cols
347 ),
348 });
349 }
350
351 if b.len() != rows {
352 return Err(MathError::DomainError {
353 operation: "solve_least_squares".to_string(),
354 value: Expression::function("vector", vec![]),
355 reason: format!(
356 "Dimension mismatch: matrix is {}x{} but b has {} elements",
357 rows,
358 cols,
359 b.len()
360 ),
361 });
362 }
363
364 // For square matrices, use standard solve
365 if rows == cols {
366 return self.solve(b);
367 }
368
369 // QR decomposition: A = QR
370 let qr = self.qr_decomposition().ok_or(MathError::DomainError {
371 operation: "solve_least_squares".to_string(),
372 value: Expression::function("matrix", vec![]),
373 reason: "QR decomposition failed (linearly dependent columns)".to_string(),
374 })?;
375
376 // Compute c = Q^T * b
377 let qt = qr.q.transpose();
378 let c = matrix_vector_multiply(&qt, b);
379
380 // Take first n elements for Rx = c[0:n]
381 let c_truncated: Vec<Expression> = c.into_iter().take(cols).collect();
382
383 // Solve Rx = c using backward substitution
384 qr.r.backward_substitution(&c_truncated)
385 }
386
387 /// Compute inverse using LU decomposition: A^(-1) = solve(A, I) column by column
388 ///
389 /// For each column j of identity matrix I, solve A*x_j = e_j
390 /// The solution vectors x_j form the columns of A^(-1)
391 pub(crate) fn inverse_via_lu(&self) -> Option<Matrix> {
392 let (n, _) = self.dimensions();
393 if n == 0 {
394 return None;
395 }
396
397 // Compute LU decomposition once
398 let lu = self.lu_decomposition()?;
399
400 // Solve for each column of the inverse
401 let mut inv_columns: Vec<Vec<Expression>> = Vec::with_capacity(n);
402
403 for j in 0..n {
404 // Create unit vector e_j
405 let e_j: Vec<Expression> = (0..n)
406 .map(|i| {
407 if i == j {
408 Expression::integer(1)
409 } else {
410 Expression::integer(0)
411 }
412 })
413 .collect();
414
415 // Solve A * x_j = e_j using precomputed LU
416 let pb = apply_permutation(&lu.p, &e_j);
417 let y = match lu.l.forward_substitution(&pb) {
418 Ok(y) => y,
419 Err(_) => return None,
420 };
421 let x_j = match lu.u.backward_substitution(&y) {
422 Ok(x) => x,
423 Err(_) => return None,
424 };
425
426 inv_columns.push(x_j);
427 }
428
429 // Transpose columns to rows for Matrix::Dense
430 let mut result_rows: Vec<Vec<Expression>> = Vec::with_capacity(n);
431 for i in 0..n {
432 let row: Vec<Expression> = inv_columns.iter().map(|col| col[i].clone()).collect();
433 result_rows.push(row);
434 }
435
436 Some(Matrix::Dense(MatrixData { rows: result_rows }).optimize())
437 }
438}
439
440/// Multiply matrix M by vector v: result = M * v
441fn matrix_vector_multiply(m: &Matrix, v: &[Expression]) -> Vec<Expression> {
442 let (rows, cols) = m.dimensions();
443 let mut result = Vec::with_capacity(rows);
444
445 for i in 0..rows {
446 let mut terms: Vec<Expression> = Vec::new();
447 for (j, vj) in v.iter().enumerate().take(cols) {
448 let mij = m.get_element(i, j);
449 // Use is_zero_fast() - avoids simplify() in hot loop
450 if !mij.is_zero_fast() && !vj.is_zero_fast() {
451 terms.push(Expression::mul(vec![mij, vj.clone()]));
452 }
453 }
454 // Note: Expression::add() already simplifies internally, no need for .simplify()
455 let row_sum = if terms.is_empty() {
456 Expression::integer(0)
457 } else {
458 Expression::add(terms)
459 };
460 result.push(row_sum);
461 }
462
463 result
464}
465
466/// Apply permutation matrix P to vector b: result = P * b
467///
468/// Optimized for permutation matrices: O(n) instead of O(n²)
469/// since each row of P has exactly one non-zero element (which is 1).
470pub(crate) fn apply_permutation(p: &Option<Matrix>, b: &[Expression]) -> Vec<Expression> {
471 match p {
472 None => b.to_vec(),
473 Some(p_matrix) => {
474 let n = b.len();
475 let mut result = Vec::with_capacity(n);
476
477 for i in 0..n {
478 // Find the column j where P[i][j] = 1
479 // For a permutation matrix, there's exactly one such j per row
480 for (j, bj) in b.iter().enumerate() {
481 let pij = p_matrix.get_element(i, j);
482 // Use is_zero_fast() - permutation elements are 0 or 1 literals
483 if !pij.is_zero_fast() {
484 // P[i][j] = 1, so result[i] = b[j]
485 result.push(bj.clone());
486 break;
487 }
488 }
489 }
490
491 result
492 }
493 }
494}