// opensrdk_linear_algebra/matrix/ge/sy_he/trd.rs

1use crate::{
2    matrix::{ge::Vector, MatrixError},
3    st::SymmetricTridiagonalMatrix,
4    Matrix,
5};
6use lapack::dsytrd;
7use serde::{Deserialize, Serialize};
8use std::error::Error;
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct SYTRD(pub Matrix, pub Vec<f64>, pub SymmetricTridiagonalMatrix);
12
13impl Matrix {
14    /// # Tridiagonalize
15    /// for symmetric matrix
16    pub fn sytrd(self) -> Result<SYTRD, MatrixError> {
17        if self.rows != self.cols {
18            return Err(MatrixError::DimensionMismatch);
19        }
20        let mut mat = self;
21        let n = mat.rows as i32;
22        let mut d = vec![0.0; mat.rows];
23        let mut e = vec![0.0; mat.rows.max(1) - 1];
24        let mut tau = vec![0.0; mat.rows.max(1) - 1];
25        let lwork = 2 * mat.rows;
26        let mut work = vec![0.0; lwork];
27        let mut info = 0;
28
29        unsafe {
30            dsytrd(
31                'L' as u8,
32                n,
33                &mut mat.elems,
34                n,
35                &mut d,
36                &mut e,
37                &mut tau,
38                &mut work,
39                lwork as i32,
40                &mut info,
41            );
42            if info != 0 {
43                return Err(MatrixError::LapackRoutineError {
44                    routine: "dsytrd".to_owned(),
45                    info,
46                });
47            }
48        }
49
50        let t = SymmetricTridiagonalMatrix::from(d, e)?;
51
52        Ok(SYTRD(mat, tau, t))
53    }
54
55    /// # Lanczos algorithm
56    /// for symmetric matrix
57    /// only k iteration
58    pub fn sytrd_k(
59        n: usize,
60        k: usize,
61        vec_mul: &dyn Fn(Vec<f64>) -> Result<Vec<f64>, Box<dyn Error + Send + Sync>>,
62        probe: Option<&[f64]>,
63    ) -> Result<(Matrix, SymmetricTridiagonalMatrix), MatrixError> {
64        let k = k.min(n);
65
66        let mut d = vec![0.0; k];
67        let mut e = vec![0.0; k.max(1) - 1];
68
69        let mut v = vec![vec![0.0; n]; k];
70
71        if 0 < k {
72            match probe {
73                Some(vec) => {
74                    if vec.len() != n {
75                        return Err(MatrixError::DimensionMismatch);
76                    }
77                    let norm = vec.iter().map(|wi| wi.powi(2)).sum::<f64>().sqrt();
78                    v[0] = vec.iter().map(|vi| vi / norm).collect();
79                }
80                None => {
81                    v[0][0] = 1.0;
82                }
83            }
84
85            let a_v = match vec_mul(v[0].clone()) {
86                Ok(v) => Ok(v),
87                Err(e) => Err(MatrixError::Others(e)),
88            }?
89            .col_mat();
90            let v_mat = v[0].clone().col_mat();
91
92            d[0] = a_v.t().dot(&v_mat)[0][0];
93            let mut w_prev = a_v - d[0] * v_mat;
94
95            for i in 1..k {
96                e[i - 1] = w_prev
97                    .elems()
98                    .iter()
99                    .map(|wi| wi.powi(2))
100                    .sum::<f64>()
101                    .sqrt();
102
103                v[i].clone_from_slice((w_prev * (1.0 / e[i - 1])).elems());
104
105                let a_v = match vec_mul(v[i].clone()) {
106                    Ok(v) => Ok(v),
107                    Err(e) => Err(MatrixError::Others(e)),
108                }?
109                .col_mat();
110                let v_mat = v[i].clone().col_mat();
111
112                d[i] = a_v.t().dot(&v_mat)[0][0];
113                w_prev = a_v - d[i] * v_mat - e[i - 1] * v[i - 1].clone().col_mat();
114            }
115        }
116
117        let q = Matrix::from(n, v.concat()).unwrap();
118        let t = SymmetricTridiagonalMatrix::from(d, e)?;
119
120        Ok((q, t))
121    }
122}
123
#[cfg(test)]
mod tests {
    use crate::*;

    type MulResult = Result<Vec<f64>, Box<dyn std::error::Error + Send + Sync>>;

    /// Small symmetric test matrix (Lanczos and `dsytrd` both assume symmetry).
    fn symmetric() -> Matrix {
        mat![
            4.0, 1.0, 2.0, 0.5;
            1.0, 3.0, 0.5, 1.0;
            2.0, 0.5, 5.0, 1.0;
            0.5, 1.0, 1.0, 2.0
        ]
    }

    /// Asserts that `actual` and `expected` agree entry-wise within `tol`.
    fn assert_close(actual: &Matrix, expected: &Matrix, tol: f64) {
        let (a, b) = (actual.elems(), expected.elems());
        assert_eq!(a.len(), b.len());
        for (x, y) in a.iter().zip(b.iter()) {
            assert!((x - y).abs() < tol, "{} vs {}", x, y);
        }
    }

    /// Asserts that `m` is the `dim x dim` identity within `tol` (layout-independent).
    fn assert_identity(m: &Matrix, dim: usize, tol: f64) {
        let elems = m.elems();
        assert_eq!(elems.len(), dim * dim);
        for (idx, x) in elems.iter().enumerate() {
            let expected = if idx / dim == idx % dim { 1.0 } else { 0.0 };
            assert!((x - expected).abs() < tol, "entry {}: {}", idx, x);
        }
    }

    #[test]
    fn sytrd_k_full_reconstructs_symmetric_matrix() {
        let a = symmetric();
        let mul = |v: Vec<f64>| -> MulResult { Ok(a.dot(&v.col_mat()).vec()) };
        let (q, t) = Matrix::sytrd_k(4, 4, &mul, None).unwrap();

        assert_identity(&q.t().dot(&q), 4, 1e-8);
        let aback = q.dot(&t.mat()).dot(&q.t());
        assert_close(&aback, &a, 1e-8);
    }

    #[test]
    fn sytrd_k_partial_projects_onto_krylov_subspace() {
        let a = symmetric();
        let mul = |v: Vec<f64>| -> MulResult { Ok(a.dot(&v.col_mat()).vec()) };
        let (q, t) = Matrix::sytrd_k(4, 2, &mul, None).unwrap();

        assert_identity(&q.t().dot(&q), 2, 1e-10);
        assert_close(&q.t().dot(&a).dot(&q), &t.mat(), 1e-10);
    }

    #[test]
    fn sytrd_k_stops_on_breakdown() {
        // For the identity operator the first residual is exactly zero.
        let identity = |v: Vec<f64>| -> MulResult { Ok(v) };
        let (q, t) = Matrix::sytrd_k(3, 3, &identity, None).unwrap();

        assert_eq!(q.elems().to_vec(), vec![1.0, 0.0, 0.0]);
        assert_eq!(t.mat().elems().to_vec(), vec![1.0]);
    }

    #[test]
    fn sytrd_k_rejects_bad_probe() {
        let a = symmetric();
        let mul = |v: Vec<f64>| -> MulResult { Ok(a.dot(&v.col_mat()).vec()) };

        assert!(Matrix::sytrd_k(4, 2, &mul, Some(&[1.0, 2.0][..])).is_err());
        assert!(Matrix::sytrd_k(4, 2, &mul, Some(&[0.0; 4][..])).is_err());
    }

    #[test]
    fn sytrd_preserves_trace_and_frobenius_norm() {
        let a = symmetric();
        let t = a.clone().sytrd().unwrap().2.mat();

        // Orthogonal similarity preserves the trace and the Frobenius norm.
        let trace = |m: &Matrix| (0..4).map(|i| m[i][i]).sum::<f64>();
        let frob2 = |m: &Matrix| m.elems().iter().map(|x| x * x).sum::<f64>();

        assert!((trace(&t) - trace(&a)).abs() < 1e-10);
        assert!((frob2(&t) - frob2(&a)).abs() < 1e-9);
    }
}
143}