mdarray_linalg_lapack/solve/
context.rs

1//! Linear System Solver using LU decomposition (GESV):
2//!     AX = B
3//! where:
4//!     - A is n × n (square coefficient matrix, overwritten with LU factorization)
5//!     - X is n × nrhs (solution matrix)
6//!     - B is n × nrhs (right-hand side matrix, overwritten with solution)
7//!     - P is n × n (permutation matrix from LU decomposition)
8//!
9//! The function `gesv` (LAPACK) solves a system of linear equations AX = B using LU decomposition with partial pivoting.
10//! It computes the LU factorization of A and then uses it to solve the linear system.
11//! The matrix A is overwritten by its LU factorization, and B is overwritten by the solution X.
12
13use super::simple::gesv;
14use mdarray_linalg::get_dims;
15
16use super::scalar::LapackScalar;
17use mdarray::{DSlice, Dense, Layout, tensor};
18use mdarray_linalg::into_i32;
19use mdarray_linalg::ipiv_to_perm_mat;
20use mdarray_linalg::solve::{Solve, SolveError, SolveResult, SolveResultType};
21use num_complex::ComplexFloat;
22
23use crate::Lapack;
24
25impl<T> Solve<T> for Lapack
26where
27    T: ComplexFloat + Default + LapackScalar,
28    T::Real: Into<T>,
29{
30    fn solve_overwrite<La: Layout, Lb: Layout, Lp: Layout>(
31        &self,
32        a: &mut DSlice<T, 2, La>,
33        b: &mut DSlice<T, 2, Lb>,
34        p: &mut DSlice<T, 2, Lp>,
35    ) -> Result<(), SolveError> {
36        let ipiv = gesv::<_, Lb, T>(a, b).unwrap();
37        let (n, _) = *a.shape();
38        let p_matrix = ipiv_to_perm_mat(&ipiv, n);
39        for i in 0..n {
40            for j in 0..n {
41                p[[i, j]] = p_matrix[[i, j]];
42            }
43        }
44        Ok(())
45    }
46
47    fn solve<La: Layout, Lb: Layout>(
48        &self,
49        a: &mut DSlice<T, 2, La>,
50        b: &DSlice<T, 2, Lb>,
51    ) -> SolveResultType<T> {
52        let ((n, _), (_, nrhs)) = get_dims!(a, b);
53
54        let mut b_copy = tensor![[T::default(); nrhs as usize]; n as usize];
55        for i in 0..(n as usize) {
56            for j in 0..(nrhs as usize) {
57                b_copy[[i, j]] = b[[i, j]];
58            }
59        }
60
61        match gesv::<_, Dense, T>(a, &mut b_copy) {
62            Ok(ipiv) => Ok(SolveResult {
63                x: b_copy,
64                p: ipiv_to_perm_mat(&ipiv, n as usize),
65            }),
66            Err(e) => Err(e),
67        }
68    }
69}