opensrdk_linear_algebra/matrix/ge/
tri.rs1use super::trf::GETRF;
2use crate::matrix::ge::Matrix;
3use crate::matrix::MatrixError;
4use crate::number::c64;
5use lapack::{dgetri, zgetri};
6
7impl GETRF {
8 pub fn getri(self) -> Result<Matrix, MatrixError> {
11 let GETRF(mut mat, ipiv) = self;
12
13 let n = mat.rows();
14 if n != mat.cols() {
15 return Err(MatrixError::DimensionMismatch);
16 }
17
18 let mut work = vec![f64::default(); n];
19 let mut info = 0;
20
21 let n = n as i32;
22
23 unsafe {
24 dgetri(n, &mut mat.elems, n, &ipiv, &mut work, n, &mut info);
25 }
26
27 match info {
28 0 => Ok(mat),
29 _ => Err(MatrixError::LapackRoutineError {
30 routine: "dgetri".to_owned(),
31 info,
32 }),
33 }
34 }
35}
36
37impl GETRF<c64> {
38 pub fn getri(self) -> Result<Matrix<c64>, MatrixError> {
41 let GETRF::<c64>(mut mat, ipiv) = self;
42
43 let n = mat.rows();
44 if n != mat.cols() {
45 return Err(MatrixError::DimensionMismatch);
46 }
47
48 let mut work = vec![c64::default(); n];
49 let mut info = 0;
50
51 let n = n as i32;
52
53 unsafe {
54 zgetri(n, &mut mat.elems, n, &ipiv, &mut work, n, &mut info);
55 }
56
57 match info {
58 0 => Ok(mat),
59 _ => Err(MatrixError::LapackRoutineError {
60 routine: "zgetri".to_owned(),
61 info,
62 }),
63 }
64 }
65}
66
#[cfg(test)]
mod tests {
    use crate::*;

    /// Absolute tolerance for comparing results of the LU-based inverse,
    /// which is subject to floating-point rounding.
    const TOL: f64 = 1e-10;

    #[test]
    fn it_works() {
        let a = mat!(
            1.0, 2.0;
            3.0, 4.0
        );
        let result = a.clone().getrf().unwrap();
        let a_inv = result.getri().unwrap();

        // The exact inverse of [[1, 2], [3, 4]] is [[-2, 1], [1.5, -0.5]].
        let expected_inv = [[-2.0, 1.0], [1.5, -0.5]];
        for row in 0..2 {
            for col in 0..2 {
                let diff: f64 = a_inv[(row, col)] - expected_inv[row][col];
                assert!(
                    diff.abs() < TOL,
                    "a_inv[({}, {})] differs from expected by {}",
                    row,
                    col,
                    diff
                );
            }
        }

        // A * A^{-1} must reproduce the identity in every entry, not just (0, 0).
        let i = a.dot(&a_inv);
        let i2 = DiagonalMatrix::identity(2);
        let i3 = i2.mat();
        for row in 0..2 {
            for col in 0..2 {
                let diff: f64 = i[(row, col)] - i3[(row, col)];
                assert!(
                    diff.abs() < TOL,
                    "(A * A^-1)[({}, {})] differs from identity by {}",
                    row,
                    col,
                    diff
                );
            }
        }
    }
}