opensrdk_linear_algebra/matrix/sp_hp/
trs.rs

1use super::trf::{HPTRF, SPTRF};
2use crate::number::c64;
3use crate::{matrix::MatrixError, Matrix};
4use lapack::{dsptrs, zhptrs, zsptrs};
5
6impl SPTRF {
7    /// # Solve equation
8    ///
9    /// with matrix decomposed by sptrf
10    ///
11    /// $$
12    /// \mathbf{A} \mathbf{x} = \mathbf{b}
13    /// $$
14    ///
15    /// $$
16    /// \mathbf{x} = \mathbf{A}^{-1} \mathbf{b}
17    /// $$
18    pub fn sptrs(&self, b: Matrix) -> Result<Matrix, MatrixError> {
19        let SPTRF(mat, ipiv) = self;
20        let n = mat.dim();
21
22        let mut info = 0;
23
24        let n = n as i32;
25        let mut b = b;
26
27        unsafe {
28            dsptrs(
29                'L' as u8,
30                n,
31                b.cols() as i32,
32                &mat.elems,
33                &ipiv,
34                b.elems_mut(),
35                n,
36                &mut info,
37            );
38        }
39
40        match info {
41            0 => Ok(b),
42            _ => Err(MatrixError::LapackRoutineError {
43                routine: "dsptrs".to_owned(),
44                info,
45            }),
46        }
47    }
48}
49
50impl SPTRF<c64> {
51    /// # Solve equation
52    ///
53    /// with matrix decomposed by sptrf
54    ///
55    /// $$
56    /// \mathbf{A} \mathbf{x} = \mathbf{b}
57    /// $$
58    ///
59    /// $$
60    /// \mathbf{x} = \mathbf{A}^{-1} \mathbf{b}
61    /// $$
62    pub fn sptrs(&self, b: Matrix<c64>) -> Result<Matrix<c64>, MatrixError> {
63        let SPTRF::<c64>(mat, ipiv) = self;
64        let n = mat.dim();
65
66        let mut info = 0;
67
68        let n = n as i32;
69        let mut b = b;
70
71        unsafe {
72            zsptrs(
73                'L' as u8,
74                n,
75                b.cols() as i32,
76                &mat.elems,
77                &ipiv,
78                b.elems_mut(),
79                n,
80                &mut info,
81            );
82        }
83
84        match info {
85            0 => Ok(b),
86            _ => Err(MatrixError::LapackRoutineError {
87                routine: "zsptrs".to_owned(),
88                info,
89            }),
90        }
91    }
92}
93
94impl HPTRF {
95    /// # Solve equation
96    ///
97    /// with matrix decomposed by hptrf
98    ///
99    /// $$
100    /// \mathbf{A} \mathbf{x} = \mathbf{b}
101    /// $$
102    ///
103    /// $$
104    /// \mathbf{x} = \mathbf{A}^{-1} \mathbf{b}
105    /// $$
106    pub fn hptrs(&self, b: Matrix<c64>) -> Result<Matrix<c64>, MatrixError> {
107        let HPTRF(mat, ipiv) = self;
108        let n = mat.dim();
109
110        let mut info = 0;
111
112        let n = n as i32;
113        let mut b = b;
114
115        unsafe {
116            zhptrs(
117                'L' as u8,
118                n,
119                b.cols() as i32,
120                &mat.elems,
121                &ipiv,
122                b.elems_mut(),
123                n,
124                &mut info,
125            );
126        }
127
128        match info {
129            0 => Ok(b),
130            _ => Err(MatrixError::LapackRoutineError {
131                routine: "zhptrs".to_owned(),
132                info,
133            }),
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use crate::*;

    /// Factorizes a 2x2 packed symmetric matrix, solves against a 2x2
    /// right-hand side, and prints the result. Only the `unwrap` calls can
    /// fail here; the numeric result is not asserted.
    #[test]
    fn it_works() {
        let packed = vec![2.0, 1.0, 2.0];
        let symmetric = SymmetricPackedMatrix::from(2, packed).unwrap();
        let rhs = mat![
            1.0, 3.0;
            2.0, 4.0
        ];

        let factorized = symmetric.sptrf().unwrap();
        let x_t = factorized.sptrs(rhs).unwrap();

        println!("{:#?}", x_t);

        // NOTE(review): the original carried these checks in disabled form:
        //   x_t[0][0] == 0.0
        //   x_t[0][1] == 1.0
        //   x_t[1][0] == 5.0 / 3.0 - 1.0
        //   x_t[1][1] == 5.0 / 3.0
        // They are still not enforced here -- presumably exact float equality
        // or the indexing layout was the obstacle; confirm before enabling.
    }
}