opensrdk_linear_algebra/matrix/ge/sy_he/
trd.rs1use crate::{
2 matrix::{ge::Vector, MatrixError},
3 st::SymmetricTridiagonalMatrix,
4 Matrix,
5};
6use lapack::dsytrd;
7use serde::{Deserialize, Serialize};
8use std::error::Error;
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct SYTRD(pub Matrix, pub Vec<f64>, pub SymmetricTridiagonalMatrix);
12
13impl Matrix {
14 pub fn sytrd(self) -> Result<SYTRD, MatrixError> {
17 if self.rows != self.cols {
18 return Err(MatrixError::DimensionMismatch);
19 }
20 let mut mat = self;
21 let n = mat.rows as i32;
22 let mut d = vec![0.0; mat.rows];
23 let mut e = vec![0.0; mat.rows.max(1) - 1];
24 let mut tau = vec![0.0; mat.rows.max(1) - 1];
25 let lwork = 2 * mat.rows;
26 let mut work = vec![0.0; lwork];
27 let mut info = 0;
28
29 unsafe {
30 dsytrd(
31 'L' as u8,
32 n,
33 &mut mat.elems,
34 n,
35 &mut d,
36 &mut e,
37 &mut tau,
38 &mut work,
39 lwork as i32,
40 &mut info,
41 );
42 if info != 0 {
43 return Err(MatrixError::LapackRoutineError {
44 routine: "dsytrd".to_owned(),
45 info,
46 });
47 }
48 }
49
50 let t = SymmetricTridiagonalMatrix::from(d, e)?;
51
52 Ok(SYTRD(mat, tau, t))
53 }
54
55 pub fn sytrd_k(
59 n: usize,
60 k: usize,
61 vec_mul: &dyn Fn(Vec<f64>) -> Result<Vec<f64>, Box<dyn Error + Send + Sync>>,
62 probe: Option<&[f64]>,
63 ) -> Result<(Matrix, SymmetricTridiagonalMatrix), MatrixError> {
64 let k = k.min(n);
65
66 let mut d = vec![0.0; k];
67 let mut e = vec![0.0; k.max(1) - 1];
68
69 let mut v = vec![vec![0.0; n]; k];
70
71 if 0 < k {
72 match probe {
73 Some(vec) => {
74 if vec.len() != n {
75 return Err(MatrixError::DimensionMismatch);
76 }
77 let norm = vec.iter().map(|wi| wi.powi(2)).sum::<f64>().sqrt();
78 v[0] = vec.iter().map(|vi| vi / norm).collect();
79 }
80 None => {
81 v[0][0] = 1.0;
82 }
83 }
84
85 let a_v = match vec_mul(v[0].clone()) {
86 Ok(v) => Ok(v),
87 Err(e) => Err(MatrixError::Others(e)),
88 }?
89 .col_mat();
90 let v_mat = v[0].clone().col_mat();
91
92 d[0] = a_v.t().dot(&v_mat)[0][0];
93 let mut w_prev = a_v - d[0] * v_mat;
94
95 for i in 1..k {
96 e[i - 1] = w_prev
97 .elems()
98 .iter()
99 .map(|wi| wi.powi(2))
100 .sum::<f64>()
101 .sqrt();
102
103 v[i].clone_from_slice((w_prev * (1.0 / e[i - 1])).elems());
104
105 let a_v = match vec_mul(v[i].clone()) {
106 Ok(v) => Ok(v),
107 Err(e) => Err(MatrixError::Others(e)),
108 }?
109 .col_mat();
110 let v_mat = v[i].clone().col_mat();
111
112 d[i] = a_v.t().dot(&v_mat)[0][0];
113 w_prev = a_v - d[i] * v_mat - e[i - 1] * v[i - 1].clone().col_mat();
114 }
115 }
116
117 let q = Matrix::from(n, v.concat()).unwrap();
118 let t = SymmetricTridiagonalMatrix::from(d, e)?;
119
120 Ok((q, t))
121 }
122}
123
#[cfg(test)]
mod tests {
    use crate::*;

    /// Asserts element-wise closeness of two equally long slices.
    fn assert_all_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len());
        for (i, (x, y)) in actual.iter().zip(expected).enumerate() {
            assert!((x - y).abs() <= tol, "element {}: {} vs {}", i, x, y);
        }
    }

    #[test]
    fn sytrd_k_full_run_reconstructs_symmetric_matrix() {
        // The Lanczos recurrence assumes a symmetric operator, so the fixture must be one.
        let a = mat![
            4.0, 1.0, 2.0, 0.5;
            1.0, 3.0, 0.0, 1.0;
            2.0, 0.0, 5.0, 1.0;
            0.5, 1.0, 1.0, 2.0
        ];
        let n = 4;

        // With k == n, Q is square and orthogonal, so Q T Q^T must reproduce A.
        let (q, t) =
            Matrix::sytrd_k(n, n, &|v: Vec<f64>| Ok(a.dot(&v.col_mat()).vec()), None).unwrap();

        let aback = q.dot(&t.mat()).dot(&q.t());
        assert_all_close(aback.elems(), a.elems(), 1e-8);

        // Diagonal entries of an n x n matrix sit at flat indices that are multiples of n + 1,
        // regardless of storage order.
        let identity: Vec<f64> = (0..n * n)
            .map(|idx| if idx % (n + 1) == 0 { 1.0 } else { 0.0 })
            .collect();
        assert_all_close(q.t().dot(&q).elems(), &identity, 1e-8);
    }
}