opensrdk_linear_algebra/matrix/ge/
trs.rs1use super::trf::GETRF;
2use crate::matrix::ge::Matrix;
3use crate::matrix::MatrixError;
4use crate::number::c64;
5use lapack::{dgetrs, zgetrs};
6
7impl GETRF {
8 pub fn getrs(&self, b: Matrix) -> Result<Matrix, MatrixError> {
20 let GETRF(mat, ipiv) = self;
21
22 let n = mat.rows();
23 if n != mat.cols() || n != b.rows {
24 return Err(MatrixError::DimensionMismatch);
25 }
26
27 let mut info = 0;
28
29 let n = n as i32;
30 let mut b = b;
31
32 unsafe {
33 dgetrs(
34 'N' as u8,
35 n,
36 b.cols as i32,
37 &mat.elems,
38 n,
39 ipiv,
40 &mut b.elems,
41 n,
42 &mut info,
43 );
44 }
45
46 match info {
47 0 => Ok(b),
48 _ => Err(MatrixError::LapackRoutineError {
49 routine: "dgetrs".to_owned(),
50 info,
51 }),
52 }
53 }
54}
55
56impl GETRF<c64> {
57 pub fn getrs(&self, bt: Matrix<c64>) -> Result<Matrix<c64>, MatrixError> {
69 let GETRF::<c64>(mat, ipiv) = self;
70
71 let n = mat.rows();
72 if n != mat.cols() || n != bt.cols {
73 return Err(MatrixError::DimensionMismatch);
74 }
75
76 let mut info = 0;
77
78 let n = n as i32;
79 let mut bt = bt;
80
81 unsafe {
82 zgetrs(
83 'T' as u8,
84 n,
85 bt.rows as i32,
86 &mat.elems,
87 n,
88 ipiv,
89 &mut bt.elems,
90 n,
91 &mut info,
92 );
93 }
94
95 match info {
96 0 => Ok(bt),
97 _ => Err(MatrixError::LapackRoutineError {
98 routine: "zgetrs".to_owned(),
99 info,
100 }),
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use crate::*;

    /// Factorizes a small matrix with `getrf`, solves against a right-hand
    /// side with `getrs`, and compares the result with the expected matrix.
    #[test]
    fn it_works() {
        // Inputs for the system
        //   2x + y = 3
        //    x + y = 2
        let coefficients = mat!(
            2.0, 1.0;
            1.0, 1.0
        );
        let rhs = mat!(
            3.0;
            2.0
        );
        // Expected solution: x = 1, y = 1.
        let expected = mat!(
            1.0;
            1.0
        );

        let factorization = coefficients.clone().getrf().unwrap();
        let solution = factorization.getrs(rhs).unwrap();

        assert_eq!(solution, expected);
    }
}