opensrdk_linear_algebra/matrix/sp_hp/pp/
trs.rs

1use super::trf::PPTRF;
2use crate::matrix::ge::Matrix;
3use crate::matrix::MatrixError;
4use crate::number::c64;
5use lapack::{dpptrs, zpptrs};
6
7impl PPTRF {
8    /// # Solve equation
9    ///
10    /// with matrix decomposed by potrf
11    ///
12    /// $$
13    /// \mathbf{A} \mathbf{x} = \mathbf{b}
14    /// $$
15    ///
16    /// $$
17    /// \mathbf{x} = \mathbf{A}^{-1} \mathbf{b}
18    /// $$
19    pub fn pptrs(&self, b: Matrix) -> Result<Matrix, MatrixError> {
20        let PPTRF(mat) = self;
21        let n = mat.dim();
22
23        let mut info = 0;
24
25        let n = n as i32;
26        let mut b = b;
27
28        unsafe {
29            dpptrs(
30                'L' as u8,
31                n,
32                b.cols() as i32,
33                &mat.elems,
34                b.elems_mut(),
35                n,
36                &mut info,
37            );
38        }
39
40        match info {
41            0 => Ok(b),
42            _ => Err(MatrixError::LapackRoutineError {
43                routine: "dpptrs".to_owned(),
44                info,
45            }),
46        }
47    }
48}
49
50impl PPTRF<c64> {
51    /// # Solve equation
52    ///
53    /// with matrix decomposed by potrf
54    ///
55    /// $$
56    /// \mathbf{A} \mathbf{x} = \mathbf{b}
57    /// $$
58    ///
59    /// $$
60    /// \mathbf{x} = \mathbf{A}^{-1} \mathbf{b}
61    /// $$
62    pub fn pptrs(&self, b: Matrix<c64>) -> Result<Matrix<c64>, MatrixError> {
63        let PPTRF::<c64>(mat) = self;
64        let n = mat.dim();
65
66        let mut info = 0;
67
68        let n = n as i32;
69        let mut b = b;
70
71        unsafe {
72            zpptrs(
73                'L' as u8,
74                n,
75                b.cols() as i32,
76                &mat.elems,
77                b.elems_mut(),
78                n,
79                &mut info,
80            );
81        }
82
83        match info {
84            0 => Ok(b),
85            _ => Err(MatrixError::LapackRoutineError {
86                routine: "zpptrs".to_owned(),
87                info,
88            }),
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use crate::*;

    /// Smoke test: factorizes a packed 2x2 symmetric matrix with `pptrf`
    /// and solves it against a 2x2 right-hand side with `pptrs`.
    #[test]
    fn it_works() {
        // Packed storage of the symmetric matrix [[2, 1], [1, 2]].
        let packed_elems = vec![2.0, 1.0, 2.0];
        let packed = SymmetricPackedMatrix::from(2, packed_elems).unwrap();
        let rhs = mat![
            1.0, 3.0;
            2.0, 4.0
        ];

        let factor = packed.pptrf().unwrap();
        let solution = factor.pptrs(rhs).unwrap();

        println!("{:#?}", solution);
        // Previously disabled checks expected the values 0, 1, 2/3 and 5/3, i.e. the
        // solutions of A x = (1, 2) and A x = (3, 4) for A = [[2, 1], [1, 2]].
        // NOTE(review): re-enable them with a float tolerance once the index order
        // of the returned matrix is confirmed.
    }
}