opensrdk_linear_algebra/matrix/ge/
tri.rs

1use super::trf::GETRF;
2use crate::matrix::ge::Matrix;
3use crate::matrix::MatrixError;
4use crate::number::c64;
5use lapack::{dgetri, zgetri};
6
7impl GETRF {
8    /// # Inverse
9    /// with matrix decomposed by getrf
10    pub fn getri(self) -> Result<Matrix, MatrixError> {
11        let GETRF(mut mat, ipiv) = self;
12
13        let n = mat.rows();
14        if n != mat.cols() {
15            return Err(MatrixError::DimensionMismatch);
16        }
17
18        let mut work = vec![f64::default(); n];
19        let mut info = 0;
20
21        let n = n as i32;
22
23        unsafe {
24            dgetri(n, &mut mat.elems, n, &ipiv, &mut work, n, &mut info);
25        }
26
27        match info {
28            0 => Ok(mat),
29            _ => Err(MatrixError::LapackRoutineError {
30                routine: "dgetri".to_owned(),
31                info,
32            }),
33        }
34    }
35}
36
37impl GETRF<c64> {
38    /// # Inverse
39    /// with matrix decomposed by getrf
40    pub fn getri(self) -> Result<Matrix<c64>, MatrixError> {
41        let GETRF::<c64>(mut mat, ipiv) = self;
42
43        let n = mat.rows();
44        if n != mat.cols() {
45            return Err(MatrixError::DimensionMismatch);
46        }
47
48        let mut work = vec![c64::default(); n];
49        let mut info = 0;
50
51        let n = n as i32;
52
53        unsafe {
54            zgetri(n, &mut mat.elems, n, &ipiv, &mut work, n, &mut info);
55        }
56
57        match info {
58            0 => Ok(mat),
59            _ => Err(MatrixError::LapackRoutineError {
60                routine: "zgetri".to_owned(),
61                info,
62            }),
63        }
64    }
65}
66
#[cfg(test)]
mod tests {
    use crate::*;

    /// Inverting a 2x2 matrix through `getrf` + `getri` must satisfy `A * A^-1 = I`.
    #[test]
    fn it_works() {
        let a = mat!(
            1.0, 2.0;
            3.0, 4.0
        );
        let factorized = a.clone().getrf().unwrap();
        let a_inv = factorized.getri().unwrap();
        let product = a.dot(&a_inv);
        let identity_diag = DiagonalMatrix::identity(2);
        let identity = identity_diag.mat();

        // Check every entry, not just the top-left one, and compare with a tolerance because the
        // inverse is computed in floating point.
        for row in 0..2 {
            for col in 0..2 {
                let diff: f64 = product[(row, col)] - identity[(row, col)];
                assert!(
                    diff.abs() < 1e-12,
                    "A * A^-1 differs from identity at ({}, {}) by {}",
                    row,
                    col,
                    diff
                );
            }
        }
    }
}