1use anyhow::{bail, Result};
7use faer::Mat;
8use ndarray::{Array1, Array2};
9
10pub fn ndarray_to_faer(a: &Array2<f64>) -> Mat<f64> {
12 let (r, c) = a.dim();
13 Mat::<f64>::from_fn(r, c, |i, j| a[[i, j]])
14}
15
16pub fn faer_to_ndarray(m: &Mat<f64>) -> Array2<f64> {
18 let (r, c) = (m.nrows(), m.ncols());
19 Array2::from_shape_fn((r, c), |(i, j)| m[(i, j)])
20}
21
22pub fn svd_thin(a: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>)> {
31 let m = ndarray_to_faer(a);
32 let svd = m
33 .thin_svd()
34 .map_err(|e| anyhow::anyhow!("SVD failed: {e:?}"))?;
35
36 let u = svd.U();
37 let s_diag = svd.S();
38 let v = svd.V();
39 let k = u.ncols();
40
41 let s = Array1::from_iter((0..k).map(|i| s_diag[i]));
42 let u_nd = faer_to_ndarray(&u.to_owned());
43 let v_nd = faer_to_ndarray(&v.to_owned());
45 let vt_nd = v_nd.t().to_owned();
46
47 Ok((u_nd, s, vt_nd))
48}
49
50pub fn eigh_sorted(a: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
55 let (n, nc) = a.dim();
56 if n != nc {
57 bail!("eigh_sorted: matrix must be square, got [{n}, {nc}]");
58 }
59 let m = ndarray_to_faer(a);
60 let evd = m
61 .self_adjoint_eigen(faer::Side::Lower)
62 .map_err(|e| anyhow::anyhow!("Eigendecomposition failed: {e:?}"))?;
63
64 let s_diag = evd.S();
65 let u = evd.U();
66
67 let mut indices: Vec<usize> = (0..n).collect();
69 indices.sort_by(|&a, &b| s_diag[b].partial_cmp(&s_diag[a]).unwrap());
70
71 let vals = Array1::from_iter(indices.iter().map(|&i| s_diag[i]));
72 let mut vecs = Array2::zeros((n, n));
73 for (col_out, &col_in) in indices.iter().enumerate() {
74 for row in 0..n {
75 vecs[[row, col_out]] = u[(row, col_in)];
76 }
77 }
78
79 Ok((vals, vecs))
80}
81
82pub fn compute_whitener(noise_cov: &Array2<f64>) -> Result<(Array2<f64>, usize)> {
90 let (evals, evecs) = eigh_sorted(noise_cov)?;
91 let n = evals.len();
92
93 let tol = evals[0].abs() * 1e-12;
95 let n_nzero = evals.iter().filter(|&&v| v > tol).count();
96 if n_nzero == 0 {
97 bail!("Noise covariance has no positive eigenvalues");
98 }
99
100 let mut whitener = Array2::zeros((n_nzero, n));
102 for k in 0..n_nzero {
103 let inv_sqrt = 1.0 / evals[k].sqrt();
104 for j in 0..n {
105 whitener[[k, j]] = inv_sqrt * evecs[[j, k]];
106 }
107 }
108
109 Ok((whitener, n_nzero))
110}
111
112pub fn sqrtm_sym(a: &Array2<f64>) -> Result<Array2<f64>> {
116 let (evals, evecs) = eigh_sorted(a)?;
117 let n = evals.len();
118 let mut result = Array2::zeros((n, n));
119
120 for k in 0..n {
121 let s = if evals[k] > 0.0 { evals[k].sqrt() } else { 0.0 };
122 for i in 0..n {
123 for j in 0..n {
124 result[[i, j]] += s * evecs[[i, k]] * evecs[[j, k]];
125 }
126 }
127 }
128
129 Ok(result)
130}
131
132pub fn inv_sqrtm_sym(a: &Array2<f64>) -> Result<Array2<f64>> {
136 let (evals, evecs) = eigh_sorted(a)?;
137 let n = evals.len();
138 let tol = evals[0].abs() * 1e-12;
139 let mut result = Array2::zeros((n, n));
140
141 for k in 0..n {
142 let s = if evals[k] > tol {
143 1.0 / evals[k].sqrt()
144 } else {
145 0.0
146 };
147 for i in 0..n {
148 for j in 0..n {
149 result[[i, j]] += s * evecs[[i, k]] * evecs[[j, k]];
150 }
151 }
152 }
153
154 Ok(result)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Asserts that `m` equals the identity matrix element-wise within `1e-10`.
    fn assert_identity(m: &Array2<f64>) {
        for ((i, j), &v) in m.indexed_iter() {
            let expected = if i == j { 1.0 } else { 0.0 };
            approx::assert_abs_diff_eq!(v, expected, epsilon = 1e-10);
        }
    }

    /// Asserts that two arrays agree element-wise within `1e-10`.
    fn assert_all_close(got: &Array2<f64>, want: &Array2<f64>) {
        assert_eq!(got.dim(), want.dim());
        for (&g, &w) in got.iter().zip(want.iter()) {
            approx::assert_abs_diff_eq!(g, w, epsilon = 1e-10);
        }
    }

    /// Symmetric positive-definite 2x2 matrix with eigenvalues 3 and 1.
    fn spd_2x2() -> Array2<f64> {
        ndarray::array![[2.0, 1.0], [1.0, 2.0]]
    }

    #[test]
    fn test_svd_identity() {
        let eye = Array2::<f64>::eye(4);
        let (u, s, vt) = svd_thin(&eye).unwrap();
        for &sv in s.iter() {
            approx::assert_abs_diff_eq!(sv, 1.0, epsilon = 1e-10);
        }
        let reconstructed = u.dot(&Array2::from_diag(&s)).dot(&vt);
        assert_identity(&reconstructed);
    }

    #[test]
    fn test_svd_reconstructs_rectangular() {
        let a = ndarray::array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (u, s, vt) = svd_thin(&a).unwrap();
        assert_eq!(u.dim(), (3, 2));
        assert_eq!(s.len(), 2);
        assert_eq!(vt.dim(), (2, 2));
        let reconstructed = u.dot(&Array2::from_diag(&s)).dot(&vt);
        assert_all_close(&reconstructed, &a);
    }

    #[test]
    fn test_whitener_diagonal() {
        let mut cov = Array2::<f64>::zeros((3, 3));
        cov[[0, 0]] = 4.0;
        cov[[1, 1]] = 9.0;
        cov[[2, 2]] = 16.0;
        let (w, n_nz) = compute_whitener(&cov).unwrap();
        assert_eq!(n_nz, 3);
        assert_identity(&w.dot(&cov).dot(&w.t()));
    }

    #[test]
    fn test_eigh_sorted_descending() {
        let mut m = Array2::<f64>::zeros((3, 3));
        m[[0, 0]] = 1.0;
        m[[1, 1]] = 3.0;
        m[[2, 2]] = 2.0;
        let (evals, _) = eigh_sorted(&m).unwrap();
        assert!(evals[0] >= evals[1] && evals[1] >= evals[2]);
        approx::assert_abs_diff_eq!(evals[0], 3.0, epsilon = 1e-10);
        approx::assert_abs_diff_eq!(evals[1], 2.0, epsilon = 1e-10);
        approx::assert_abs_diff_eq!(evals[2], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sqrtm_sym_squares_back() {
        let a = spd_2x2();
        let root = sqrtm_sym(&a).unwrap();
        assert_all_close(&root.dot(&root), &a);
    }

    #[test]
    fn test_inv_sqrtm_sym_whitens() {
        let a = spd_2x2();
        let inv_root = inv_sqrtm_sym(&a).unwrap();
        assert_identity(&inv_root.dot(&a).dot(&inv_root));
    }
}