// mdarray_linalg_faer/solve/context.rs

// Linear system solver using LU decomposition:
//     A * X = B
// is solved by computing the LU decomposition with partial pivoting:
//     P * A = L * U
// then solving:
//     L * Y = P * B  (forward substitution)
//     U * X = Y      (backward substitution)
// where:
//     - A is m × m         (square coefficient matrix)
//     - B is m × n         (right-hand side matrix)
//     - X is m × n         (solution matrix)
//     - P is m × m         (permutation matrix)
//     - L is m × m         (lower triangular with ones on diagonal)
//     - U is m × m         (upper triangular)
//
// With the faer backend the L/U factors live in faer-owned storage, so A is
// left unchanged. The pivoting permutation is applied inside faer's solve step,
// so X already solves A * X = B and the P reported to callers is the identity.
15
16use faer_traits::ComplexField;
17use mdarray::{DSlice, Layout, tensor};
18use mdarray_linalg::identity;
19use mdarray_linalg::solve::{Solve, SolveError, SolveResult, SolveResultType};
20use num_complex::ComplexFloat;
21
22use faer::linalg::solvers::Solve as FaerSolve;
23
24use crate::{Faer, into_faer, into_faer_mut};
25
26impl<T> Solve<T> for Faer
27where
28    T: ComplexFloat
29        + ComplexField
30        + Default
31        + std::convert::From<<T as num_complex::ComplexFloat>::Real>
32        + 'static,
33{
34    /// Solves linear system AX = B with new allocated solution matrix
35    /// A is modified (overwritten with LU decomposition)
36    /// Returns the solution X and P the permutation matrix (identity in that case), or error
37    fn solve<La: Layout, Lb: Layout>(
38        &self,
39        a: &mut DSlice<T, 2, La>,
40        b: &DSlice<T, 2, Lb>,
41    ) -> SolveResultType<T> {
42        let (m, n) = *a.shape();
43        let (b_m, b_n) = *b.shape();
44
45        if m != n {
46            return Err(SolveError::InvalidDimensions);
47        }
48
49        if b_m != m {
50            return Err(SolveError::InvalidDimensions);
51        }
52
53        let a_faer = into_faer_mut(a);
54
55        let solver = a_faer.partial_piv_lu();
56
57        let b_faer = into_faer(b);
58        let x_faer = solver.solve(b_faer);
59
60        let mut x_mda = tensor![[T::default(); b_n]; m];
61        let mut x_faer_mut = into_faer_mut(&mut x_mda);
62        for i in 0..m {
63            for j in 0..b_n {
64                x_faer_mut[(i, j)] = x_faer[(i, j)];
65            }
66        }
67
68        let p_mda = identity(m); // No permutation with this routine
69
70        Ok(SolveResult { x: x_mda, p: p_mda })
71    }
72
73    /// Solves linear system AX = b overwriting existing matrices
74    /// A is overwritten with its LU decomposition
75    /// B is overwritten with the solution X
76    /// P is filled with the permutation matrix such that P*A = L*U (here P = identity)
77    /// Returns Ok(()) on success, Err(SolveError) on failure
78    fn solve_overwrite<La: Layout, Lb: Layout, Lp: Layout>(
79        &self,
80        a: &mut DSlice<T, 2, La>,
81        b: &mut DSlice<T, 2, Lb>,
82        p: &mut DSlice<T, 2, Lp>,
83    ) -> Result<(), SolveError> {
84        let (m, n) = *a.shape();
85        let (b_m, b_n) = *b.shape();
86
87        if m != n {
88            return Err(SolveError::InvalidDimensions);
89        }
90
91        if b_m != m {
92            return Err(SolveError::InvalidDimensions);
93        }
94
95        let _par = faer::get_global_parallelism();
96        let a_faer = into_faer(a);
97
98        let solver = a_faer.partial_piv_lu();
99
100        let b_faer = into_faer(b).to_owned();
101        let x_faer = solver.solve(b_faer);
102
103        let mut b_faer_mut = into_faer_mut(b);
104        for i in 0..m {
105            for j in 0..b_n {
106                b_faer_mut[(i, j)] = x_faer[(i, j)];
107            }
108        }
109
110        let mut p_faer = into_faer_mut(p);
111        for i in 0..m {
112            for j in 0..m {
113                if i != j {
114                    p_faer[(i, j)] = T::zero();
115                } else {
116                    p_faer[(i, j)] = T::one();
117                }
118            }
119        }
120
121        Ok(())
122    }
123}