mdarray_linalg_lapack/solve/
context.rs1use super::simple::gesv;
14use mdarray_linalg::get_dims;
15
16use super::scalar::LapackScalar;
17use mdarray::{DSlice, Dense, Layout, tensor};
18use mdarray_linalg::into_i32;
19use mdarray_linalg::ipiv_to_perm_mat;
20use mdarray_linalg::solve::{Solve, SolveError, SolveResult, SolveResultType};
21use num_complex::ComplexFloat;
22
23use crate::Lapack;
24
25impl<T> Solve<T> for Lapack
26where
27 T: ComplexFloat + Default + LapackScalar,
28 T::Real: Into<T>,
29{
30 fn solve_overwrite<La: Layout, Lb: Layout, Lp: Layout>(
31 &self,
32 a: &mut DSlice<T, 2, La>,
33 b: &mut DSlice<T, 2, Lb>,
34 p: &mut DSlice<T, 2, Lp>,
35 ) -> Result<(), SolveError> {
36 let ipiv = gesv::<_, Lb, T>(a, b).unwrap();
37 let (n, _) = *a.shape();
38 let p_matrix = ipiv_to_perm_mat(&ipiv, n);
39 for i in 0..n {
40 for j in 0..n {
41 p[[i, j]] = p_matrix[[i, j]];
42 }
43 }
44 Ok(())
45 }
46
47 fn solve<La: Layout, Lb: Layout>(
48 &self,
49 a: &mut DSlice<T, 2, La>,
50 b: &DSlice<T, 2, Lb>,
51 ) -> SolveResultType<T> {
52 let ((n, _), (_, nrhs)) = get_dims!(a, b);
53
54 let mut b_copy = tensor![[T::default(); nrhs as usize]; n as usize];
55 for i in 0..(n as usize) {
56 for j in 0..(nrhs as usize) {
57 b_copy[[i, j]] = b[[i, j]];
58 }
59 }
60
61 match gesv::<_, Dense, T>(a, &mut b_copy) {
62 Ok(ipiv) => Ok(SolveResult {
63 x: b_copy,
64 p: ipiv_to_perm_mat(&ipiv, n as usize),
65 }),
66 Err(e) => Err(e),
67 }
68 }
69}