opensrdk_linear_algebra/matrix/ge/trs.rs

1use super::trf::GETRF;
2use crate::matrix::ge::Matrix;
3use crate::matrix::MatrixError;
4use crate::number::c64;
5use lapack::{dgetrs, zgetrs};
6
7impl GETRF {
8    /// # Solve equation
9    ///
10    /// with matrix decomposed by getrf
11    ///
12    /// $$
13    /// \mathbf{A} \mathbf{x} = \mathbf{b}
14    /// $$
15    ///
16    /// $$
17    /// \mathbf{x} = \mathbf{A}^{-1} \mathbf{b}
18    /// $$
19    pub fn getrs(&self, b: Matrix) -> Result<Matrix, MatrixError> {
20        let GETRF(mat, ipiv) = self;
21
22        let n = mat.rows();
23        if n != mat.cols() || n != b.rows {
24            return Err(MatrixError::DimensionMismatch);
25        }
26
27        let mut info = 0;
28
29        let n = n as i32;
30        let mut b = b;
31
32        unsafe {
33            dgetrs(
34                'N' as u8,
35                n,
36                b.cols as i32,
37                &mat.elems,
38                n,
39                ipiv,
40                &mut b.elems,
41                n,
42                &mut info,
43            );
44        }
45
46        match info {
47            0 => Ok(b),
48            _ => Err(MatrixError::LapackRoutineError {
49                routine: "dgetrs".to_owned(),
50                info,
51            }),
52        }
53    }
54}
55
56impl GETRF<c64> {
57    /// # Solve equation
58    ///
59    /// with matrix decomposed by getrf
60    ///
61    /// $$
62    /// \mathbf{A} \mathbf{x} = \mathbf{b}
63    /// $$
64    ///
65    /// $$
66    /// \mathbf{x} = \mathbf{A}^{-1} \mathbf{b}
67    /// $$
68    pub fn getrs(&self, bt: Matrix<c64>) -> Result<Matrix<c64>, MatrixError> {
69        let GETRF::<c64>(mat, ipiv) = self;
70
71        let n = mat.rows();
72        if n != mat.cols() || n != bt.cols {
73            return Err(MatrixError::DimensionMismatch);
74        }
75
76        let mut info = 0;
77
78        let n = n as i32;
79        let mut bt = bt;
80
81        unsafe {
82            zgetrs(
83                'T' as u8,
84                n,
85                bt.rows as i32,
86                &mat.elems,
87                n,
88                ipiv,
89                &mut bt.elems,
90                n,
91                &mut info,
92            );
93        }
94
95        match info {
96            0 => Ok(bt),
97            _ => Err(MatrixError::LapackRoutineError {
98                routine: "zgetrs".to_owned(),
99                info,
100            }),
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn it_works() {
        let a = mat!(
            2.0, 1.0;
            1.0, 1.0
        );
        let b = mat!(
            3.0;
            2.0
        );
        let result = a.clone().getrf().unwrap();
        let x = result.getrs(b).unwrap();
        let ans = mat!(
            1.0;
            1.0
        );
        assert_eq!(x, ans);
    }

    /// `A` is deliberately non-symmetric, so solving `A^T x = b` by mistake
    /// (wrong transpose flag or storage order) gives a different answer. All
    /// intermediate values are exactly representable, so exact equality holds.
    #[test]
    fn solves_non_symmetric_system() {
        let a = mat!(
            1.0, 2.0;
            0.0, 1.0
        );
        let b = mat!(
            5.0;
            2.0
        );
        let x = a.getrf().unwrap().getrs(b).unwrap();
        // Solving A^T x = b instead would yield (5, -8).
        let ans = mat!(
            1.0;
            2.0
        );
        assert_eq!(x, ans);
    }

    /// Each column of `b` is an independent right-hand side; this exercises
    /// nrhs > 1 and the leading dimension of `b`.
    #[test]
    fn solves_multiple_right_hand_sides() {
        let a = mat!(
            1.0, 2.0;
            0.0, 1.0
        );
        let b = mat!(
            5.0, 3.0;
            2.0, 1.0
        );
        let x = a.getrf().unwrap().getrs(b).unwrap();
        let ans = mat!(
            1.0, 1.0;
            2.0, 1.0
        );
        assert_eq!(x, ans);
    }
}