// mathhook_core/matrices/unified/solvers.rs

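//! Matrix linear system solvers
//!
//! Provides methods for solving Ax = b using LU and Cholesky decompositions,
//! least-squares problems via QR decomposition, and triangular systems via
//! forward/backward substitution.
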
5use crate::core::Expression;
6use crate::error::MathError;
7use crate::matrices::types::MatrixData;
8use crate::matrices::unified::operations::CoreMatrixOps;
9use crate::matrices::unified::Matrix;
10
11impl Matrix {
12    /// Solve Lx = b for lower triangular L using forward substitution
13    ///
14    /// # Arguments
15    /// * `b` - Right-hand side vector
16    ///
17    /// # Returns
18    /// Solution vector x
19    ///
20    /// # Errors
21    /// * `DivisionByZero` if any diagonal element is zero
22    /// * `DomainError` if dimensions don't match
23    ///
24    /// # Algorithm
25    /// For i = 0 to n-1:
26    ///   `x[i]` = (`b[i]` - Σ(`L[i][j]` * `x[j]`) for `j` < i) / `L[i][i]`
27    pub fn forward_substitution(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
28        let (rows, cols) = self.dimensions();
29
30        if rows != cols {
31            return Err(MathError::DomainError {
32                operation: "forward_substitution".to_string(),
33                value: Expression::function("matrix", vec![]),
34                reason: format!(
35                    "Forward substitution requires square matrix, got {}x{}",
36                    rows, cols
37                ),
38            });
39        }
40
41        if b.len() != rows {
42            return Err(MathError::DomainError {
43                operation: "forward_substitution".to_string(),
44                value: Expression::function("vector", vec![]),
45                reason: format!(
46                    "Dimension mismatch: matrix is {}x{} but b has {} elements",
47                    rows,
48                    cols,
49                    b.len()
50                ),
51            });
52        }
53
54        let mut x = vec![Expression::integer(0); rows];
55
56        for i in 0..rows {
57            // Accumulate sum = Σ L[i][j] * x[j] for j < i
58            let mut terms: Vec<Expression> = Vec::new();
59            for (j, xj) in x.iter().enumerate().take(i) {
60                let lij = self.get_element(i, j);
61                // Use is_zero_fast() - avoids simplify() in hot loop
62                if !lij.is_zero_fast() && !xj.is_zero_fast() {
63                    terms.push(Expression::mul(vec![lij, xj.clone()]));
64                }
65            }
66
67            let lii = self.get_element(i, i);
68            // Use is_zero_fast() - pivot elements should already be simplified
69            if lii.is_zero_fast() {
70                return Err(MathError::DivisionByZero);
71            }
72
73            // x[i] = (b[i] - sum) / L[i][i]
74            // Note: Expression::add() and operators already simplify, no need for .simplify()
75            let numerator = if terms.is_empty() {
76                b[i].clone()
77            } else {
78                let sum = Expression::add(terms);
79                b[i].clone() - sum // Operator already simplifies
80            };
81
82            // Compute x[i] = numerator / L[i][i]
83            // Directly compute integer/integer to produce clean results
84            x[i] = if lii == Expression::integer(1) {
85                numerator
86            } else {
87                // Try to compute integer division directly for clean results
88                match (&numerator, &lii) {
89                    (
90                        Expression::Number(crate::core::Number::Integer(num)),
91                        Expression::Number(crate::core::Number::Integer(den)),
92                    ) => {
93                        if *den != 0 && num % den == 0 {
94                            Expression::integer(num / den)
95                        } else if *den != 0 {
96                            // Create rational for non-exact division
97                            use num_bigint::BigInt;
98                            use num_rational::BigRational;
99                            Expression::Number(crate::core::Number::rational(BigRational::new(
100                                BigInt::from(*num),
101                                BigInt::from(*den),
102                            )))
103                        } else {
104                            Expression::mul(vec![
105                                numerator,
106                                Expression::pow(lii, Expression::integer(-1)),
107                            ])
108                        }
109                    }
110                    _ => Expression::mul(vec![
111                        numerator,
112                        Expression::pow(lii, Expression::integer(-1)),
113                    ]),
114                }
115            };
116        }
117
118        Ok(x)
119    }
120
121    /// Solve Ux = b for upper triangular U using backward substitution
122    ///
123    /// # Arguments
124    /// * `b` - Right-hand side vector
125    ///
126    /// # Returns
127    /// Solution vector x
128    ///
129    /// # Errors
130    /// * `DivisionByZero` if any diagonal element is zero
131    /// * `DomainError` if dimensions don't match
132    ///
133    /// For i = n-1 down to 0:
134    ///   `x[i]` = (`b[i]` - Σ(`U[i][j]` * `x[j]`) for j > i) / `U[i][i]`
135    pub fn backward_substitution(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
136        let (rows, cols) = self.dimensions();
137
138        if rows != cols {
139            return Err(MathError::DomainError {
140                operation: "backward_substitution".to_string(),
141                value: Expression::function("matrix", vec![]),
142                reason: format!(
143                    "Backward substitution requires square matrix, got {}x{}",
144                    rows, cols
145                ),
146            });
147        }
148
149        if b.len() != rows {
150            return Err(MathError::DomainError {
151                operation: "backward_substitution".to_string(),
152                value: Expression::function("vector", vec![]),
153                reason: format!(
154                    "Dimension mismatch: matrix is {}x{} but b has {} elements",
155                    rows,
156                    cols,
157                    b.len()
158                ),
159            });
160        }
161        let mut x = vec![Expression::integer(0); rows];
162
163        for i in (0..rows).rev() {
164            // Accumulate sum = Σ U[i][j] * x[j] for j > i
165            let mut terms: Vec<Expression> = Vec::new();
166            for (j, xj) in x.iter().enumerate().skip(i + 1) {
167                let uij = self.get_element(i, j);
168                // Use is_zero_fast() - avoids simplify() in hot loop
169                if !uij.is_zero_fast() && !xj.is_zero_fast() {
170                    terms.push(Expression::mul(vec![uij, xj.clone()]));
171                }
172            }
173
174            let uii = self.get_element(i, i);
175            // Use is_zero_fast() - pivot elements should already be simplified
176            if uii.is_zero_fast() {
177                return Err(MathError::DivisionByZero);
178            }
179
180            // x[i] = (b[i] - sum) / U[i][i]
181            // Note: Expression::add() and operators already simplify, no need for .simplify()
182            let numerator = if terms.is_empty() {
183                b[i].clone()
184            } else {
185                let sum = Expression::add(terms);
186                b[i].clone() - sum // Operator already simplifies
187            };
188
189            // Compute x[i] = numerator / U[i][i]
190            // Directly compute integer/integer to produce clean results
191            x[i] = if uii == Expression::integer(1) {
192                numerator
193            } else {
194                // Try to compute integer division directly for clean results
195                match (&numerator, &uii) {
196                    (
197                        Expression::Number(crate::core::Number::Integer(num)),
198                        Expression::Number(crate::core::Number::Integer(den)),
199                    ) => {
200                        if *den != 0 && num % den == 0 {
201                            Expression::integer(num / den)
202                        } else if *den != 0 {
203                            // Create rational for non-exact division
204                            use num_bigint::BigInt;
205                            use num_rational::BigRational;
206                            Expression::Number(crate::core::Number::rational(BigRational::new(
207                                BigInt::from(*num),
208                                BigInt::from(*den),
209                            )))
210                        } else {
211                            Expression::mul(vec![
212                                numerator,
213                                Expression::pow(uii, Expression::integer(-1)),
214                            ])
215                        }
216                    }
217                    _ => Expression::mul(vec![
218                        numerator,
219                        Expression::pow(uii, Expression::integer(-1)),
220                    ]),
221                }
222            };
223        }
224
225        Ok(x)
226    }
227
228    /// Solve Ax = b using optimal decomposition
229    ///
230    /// # Arguments
231    /// * `b` - Right-hand side vector
232    ///
233    /// # Returns
234    /// Solution vector x
235    ///
236    /// # Errors
237    /// * `DomainError` if matrix is not square or dimensions don't match
238    /// * `DivisionByZero` if matrix is singular
239    ///
240    /// # Algorithm Selection
241    /// - Symmetric positive definite matrices: Cholesky (LL^T), ~2x faster
242    /// - General square matrices: LU decomposition with partial pivoting
243    ///
244    /// # Examples
245    /// ```
246    /// use mathhook_core::matrices::Matrix;
247    /// use mathhook_core::expr;
248    ///
249    /// let a = Matrix::from_arrays([[2, 1], [1, 3]]);
250    /// let b = vec![expr!(5), expr!(7)];
251    /// let x = a.solve(&b).unwrap();
252    /// ```
253    pub fn solve(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
254        let (rows, cols) = self.dimensions();
255
256        if rows != cols {
257            return Err(MathError::DomainError {
258                operation: "solve".to_string(),
259                value: Expression::function("matrix", vec![]),
260                reason: format!("Solve requires square matrix, got {}x{}", rows, cols),
261            });
262        }
263
264        if b.len() != rows {
265            return Err(MathError::DomainError {
266                operation: "solve".to_string(),
267                value: Expression::function("vector", vec![]),
268                reason: format!(
269                    "Dimension mismatch: matrix is {}x{} but b has {} elements",
270                    rows,
271                    cols,
272                    b.len()
273                ),
274            });
275        }
276
277        // Try Cholesky for symmetric matrices (2x faster for SPD)
278        if self.is_symmetric() {
279            if let Some(chol) = self.cholesky_decomposition() {
280                // Solve LL^T x = b
281                // Step 1: Ly = b (forward substitution)
282                let y = chol.l.forward_substitution(b)?;
283                // Step 2: L^T x = y (backward substitution on L transpose)
284                let lt = chol.l.transpose();
285                return lt.backward_substitution(&y);
286            }
287            // Fall through to LU if Cholesky fails (not positive definite)
288        }
289
290        // General case: LU decomposition with partial pivoting
291        self.solve_via_lu(b)
292    }
293
294    /// Solve Ax = b using LU decomposition
295    ///
296    /// This is the fallback solver for non-SPD matrices.
297    fn solve_via_lu(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
298        let lu = self.lu_decomposition().ok_or(MathError::DivisionByZero)?;
299
300        let pb = apply_permutation(&lu.p, b);
301
302        let y = lu.l.forward_substitution(&pb)?;
303
304        let x = lu.u.backward_substitution(&y)?;
305
306        Ok(x)
307    }
308
309    /// Solve least squares problem: min ||Ax - b||₂ using QR decomposition
310    ///
311    /// # Arguments
312    /// * `b` - Right-hand side vector
313    ///
314    /// # Returns
315    /// Solution vector x that minimizes ||Ax - b||₂
316    ///
317    /// # Errors
318    /// * `DomainError` if dimensions don't match or m < n
319    /// * `DivisionByZero` if R has zero diagonal elements
320    ///
321    /// # Algorithm
322    /// For m×n matrix A (m >= n):
323    /// 1. Compute A = QR (Q is m×n, R is n×n upper triangular)
324    /// 2. Compute c = Q^T * b
325    /// 3. Solve Rx = c`[0:n]` using backward substitution
326    ///
327    /// # Examples
328    /// ```
329    /// use mathhook_core::matrices::Matrix;
330    /// use mathhook_core::expr;
331    ///
332    /// // Overdetermined system: 3 equations, 2 unknowns
333    /// let a = Matrix::from_arrays([[1, 0], [0, 1], [1, 1]]);
334    /// let b = vec![expr!(1), expr!(2), expr!(2)];
335    /// let x = a.solve_least_squares(&b).unwrap();
336    /// ```
337    pub fn solve_least_squares(&self, b: &[Expression]) -> Result<Vec<Expression>, MathError> {
338        let (rows, cols) = self.dimensions();
339
340        if rows < cols {
341            return Err(MathError::DomainError {
342                operation: "solve_least_squares".to_string(),
343                value: Expression::function("matrix", vec![]),
344                reason: format!(
345                    "Least squares requires m >= n (overdetermined), got {}x{}",
346                    rows, cols
347                ),
348            });
349        }
350
351        if b.len() != rows {
352            return Err(MathError::DomainError {
353                operation: "solve_least_squares".to_string(),
354                value: Expression::function("vector", vec![]),
355                reason: format!(
356                    "Dimension mismatch: matrix is {}x{} but b has {} elements",
357                    rows,
358                    cols,
359                    b.len()
360                ),
361            });
362        }
363
364        // For square matrices, use standard solve
365        if rows == cols {
366            return self.solve(b);
367        }
368
369        // QR decomposition: A = QR
370        let qr = self.qr_decomposition().ok_or(MathError::DomainError {
371            operation: "solve_least_squares".to_string(),
372            value: Expression::function("matrix", vec![]),
373            reason: "QR decomposition failed (linearly dependent columns)".to_string(),
374        })?;
375
376        // Compute c = Q^T * b
377        let qt = qr.q.transpose();
378        let c = matrix_vector_multiply(&qt, b);
379
380        // Take first n elements for Rx = c[0:n]
381        let c_truncated: Vec<Expression> = c.into_iter().take(cols).collect();
382
383        // Solve Rx = c using backward substitution
384        qr.r.backward_substitution(&c_truncated)
385    }
386
387    /// Compute inverse using LU decomposition: A^(-1) = solve(A, I) column by column
388    ///
389    /// For each column j of identity matrix I, solve A*x_j = e_j
390    /// The solution vectors x_j form the columns of A^(-1)
391    pub(crate) fn inverse_via_lu(&self) -> Option<Matrix> {
392        let (n, _) = self.dimensions();
393        if n == 0 {
394            return None;
395        }
396
397        // Compute LU decomposition once
398        let lu = self.lu_decomposition()?;
399
400        // Solve for each column of the inverse
401        let mut inv_columns: Vec<Vec<Expression>> = Vec::with_capacity(n);
402
403        for j in 0..n {
404            // Create unit vector e_j
405            let e_j: Vec<Expression> = (0..n)
406                .map(|i| {
407                    if i == j {
408                        Expression::integer(1)
409                    } else {
410                        Expression::integer(0)
411                    }
412                })
413                .collect();
414
415            // Solve A * x_j = e_j using precomputed LU
416            let pb = apply_permutation(&lu.p, &e_j);
417            let y = match lu.l.forward_substitution(&pb) {
418                Ok(y) => y,
419                Err(_) => return None,
420            };
421            let x_j = match lu.u.backward_substitution(&y) {
422                Ok(x) => x,
423                Err(_) => return None,
424            };
425
426            inv_columns.push(x_j);
427        }
428
429        // Transpose columns to rows for Matrix::Dense
430        let mut result_rows: Vec<Vec<Expression>> = Vec::with_capacity(n);
431        for i in 0..n {
432            let row: Vec<Expression> = inv_columns.iter().map(|col| col[i].clone()).collect();
433            result_rows.push(row);
434        }
435
436        Some(Matrix::Dense(MatrixData { rows: result_rows }).optimize())
437    }
438}
439
440/// Multiply matrix M by vector v: result = M * v
441fn matrix_vector_multiply(m: &Matrix, v: &[Expression]) -> Vec<Expression> {
442    let (rows, cols) = m.dimensions();
443    let mut result = Vec::with_capacity(rows);
444
445    for i in 0..rows {
446        let mut terms: Vec<Expression> = Vec::new();
447        for (j, vj) in v.iter().enumerate().take(cols) {
448            let mij = m.get_element(i, j);
449            // Use is_zero_fast() - avoids simplify() in hot loop
450            if !mij.is_zero_fast() && !vj.is_zero_fast() {
451                terms.push(Expression::mul(vec![mij, vj.clone()]));
452            }
453        }
454        // Note: Expression::add() already simplifies internally, no need for .simplify()
455        let row_sum = if terms.is_empty() {
456            Expression::integer(0)
457        } else {
458            Expression::add(terms)
459        };
460        result.push(row_sum);
461    }
462
463    result
464}
465
466/// Apply permutation matrix P to vector b: result = P * b
467///
468/// Optimized for permutation matrices: O(n) instead of O(n²)
469/// since each row of P has exactly one non-zero element (which is 1).
470pub(crate) fn apply_permutation(p: &Option<Matrix>, b: &[Expression]) -> Vec<Expression> {
471    match p {
472        None => b.to_vec(),
473        Some(p_matrix) => {
474            let n = b.len();
475            let mut result = Vec::with_capacity(n);
476
477            for i in 0..n {
478                // Find the column j where P[i][j] = 1
479                // For a permutation matrix, there's exactly one such j per row
480                for (j, bj) in b.iter().enumerate() {
481                    let pij = p_matrix.get_element(i, j);
482                    // Use is_zero_fast() - permutation elements are 0 or 1 literals
483                    if !pij.is_zero_fast() {
484                        // P[i][j] = 1, so result[i] = b[j]
485                        result.push(bj.clone());
486                        break;
487                    }
488                }
489            }
490
491            result
492        }
493    }
494}