Skip to main content

exg_source/
linalg.rs

1//! Pure-Rust linear algebra helpers (SVD, eigendecomposition, whitener).
2//!
3//! Uses [`faer`](https://crates.io/crates/faer) under the hood and converts
4//! to/from `ndarray`.
5
6use anyhow::{bail, Result};
7use faer::Mat;
8use ndarray::{Array1, Array2};
9
10/// Convert `ndarray::Array2<f64>` → `faer::Mat<f64>`.
11pub fn ndarray_to_faer(a: &Array2<f64>) -> Mat<f64> {
12    let (r, c) = a.dim();
13    Mat::<f64>::from_fn(r, c, |i, j| a[[i, j]])
14}
15
16/// Convert `faer::Mat<f64>` → `ndarray::Array2<f64>`.
17pub fn faer_to_ndarray(m: &Mat<f64>) -> Array2<f64> {
18    let (r, c) = (m.nrows(), m.ncols());
19    Array2::from_shape_fn((r, c), |(i, j)| m[(i, j)])
20}
21
22/// Thin SVD: `A = U @ diag(s) @ V^T`.
23///
24/// Returns `(U, s, Vt)` where:
25/// - `U`  has shape `[m, k]`
26/// - `s`  has length `k`
27/// - `Vt` has shape `[k, n]`
28///
29/// with `k = min(m, n)`.
30pub fn svd_thin(a: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>)> {
31    let m = ndarray_to_faer(a);
32    let svd = m
33        .thin_svd()
34        .map_err(|e| anyhow::anyhow!("SVD failed: {e:?}"))?;
35
36    let u = svd.U();
37    let s_diag = svd.S();
38    let v = svd.V();
39    let k = u.ncols();
40
41    let s = Array1::from_iter((0..k).map(|i| s_diag[i]));
42    let u_nd = faer_to_ndarray(&u.to_owned());
43    // faer returns V, we need V^T
44    let v_nd = faer_to_ndarray(&v.to_owned());
45    let vt_nd = v_nd.t().to_owned();
46
47    Ok((u_nd, s, vt_nd))
48}
49
50/// Symmetric eigendecomposition of a real symmetric matrix.
51///
52/// Returns `(eigenvalues, eigenvectors)` sorted in **descending** order of
53/// eigenvalue. `eigenvectors` has shape `[n, n]` with eigenvectors as columns.
54pub fn eigh_sorted(a: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
55    let (n, nc) = a.dim();
56    if n != nc {
57        bail!("eigh_sorted: matrix must be square, got [{n}, {nc}]");
58    }
59    let m = ndarray_to_faer(a);
60    let evd = m
61        .self_adjoint_eigen(faer::Side::Lower)
62        .map_err(|e| anyhow::anyhow!("Eigendecomposition failed: {e:?}"))?;
63
64    let s_diag = evd.S();
65    let u = evd.U();
66
67    // faer returns eigenvalues in ascending order — build index for descending
68    let mut indices: Vec<usize> = (0..n).collect();
69    indices.sort_by(|&a, &b| s_diag[b].partial_cmp(&s_diag[a]).unwrap());
70
71    let vals = Array1::from_iter(indices.iter().map(|&i| s_diag[i]));
72    let mut vecs = Array2::zeros((n, n));
73    for (col_out, &col_in) in indices.iter().enumerate() {
74        for row in 0..n {
75            vecs[[row, col_out]] = u[(row, col_in)];
76        }
77    }
78
79    Ok((vals, vecs))
80}
81
82/// Compute a whitening matrix from a noise covariance.
83///
84/// Returns `(whitener, n_nzero)` where:
85/// - `whitener` has shape `[n_nzero, n_channels]`
86/// - `n_nzero` is the number of positive eigenvalues
87///
88/// Whitener satisfies `W @ C @ W^T ≈ I` (restricted to the non-zero subspace).
89pub fn compute_whitener(noise_cov: &Array2<f64>) -> Result<(Array2<f64>, usize)> {
90    let (evals, evecs) = eigh_sorted(noise_cov)?;
91    let n = evals.len();
92
93    // Count positive eigenvalues (numerical rank)
94    let tol = evals[0].abs() * 1e-12;
95    let n_nzero = evals.iter().filter(|&&v| v > tol).count();
96    if n_nzero == 0 {
97        bail!("Noise covariance has no positive eigenvalues");
98    }
99
100    // Whitener: W = diag(1/√λ) @ V^T  for the non-zero subspace
101    let mut whitener = Array2::zeros((n_nzero, n));
102    for k in 0..n_nzero {
103        let inv_sqrt = 1.0 / evals[k].sqrt();
104        for j in 0..n {
105            whitener[[k, j]] = inv_sqrt * evecs[[j, k]];
106        }
107    }
108
109    Ok((whitener, n_nzero))
110}
111
112/// Matrix square root of a symmetric positive (semi-)definite matrix.
113///
114/// Returns `M^{1/2}` such that `M^{1/2} @ M^{1/2} ≈ M`.
115pub fn sqrtm_sym(a: &Array2<f64>) -> Result<Array2<f64>> {
116    let (evals, evecs) = eigh_sorted(a)?;
117    let n = evals.len();
118    let mut result = Array2::zeros((n, n));
119
120    for k in 0..n {
121        let s = if evals[k] > 0.0 { evals[k].sqrt() } else { 0.0 };
122        for i in 0..n {
123            for j in 0..n {
124                result[[i, j]] += s * evecs[[i, k]] * evecs[[j, k]];
125            }
126        }
127    }
128
129    Ok(result)
130}
131
132/// Inverse square root of a symmetric positive definite matrix.
133///
134/// Returns `M^{-1/2}` such that `M^{-1/2} @ M @ M^{-1/2} ≈ I`.
135pub fn inv_sqrtm_sym(a: &Array2<f64>) -> Result<Array2<f64>> {
136    let (evals, evecs) = eigh_sorted(a)?;
137    let n = evals.len();
138    let tol = evals[0].abs() * 1e-12;
139    let mut result = Array2::zeros((n, n));
140
141    for k in 0..n {
142        let s = if evals[k] > tol {
143            1.0 / evals[k].sqrt()
144        } else {
145            0.0
146        };
147        for i in 0..n {
148            for j in 0..n {
149                result[[i, j]] += s * evecs[[i, k]] * evecs[[j, k]];
150            }
151        }
152    }
153
154    Ok(result)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Assert every entry of `m` matches the identity matrix to within 1e-10.
    fn assert_near_identity(m: &Array2<f64>) {
        for ((row, col), &value) in m.indexed_iter() {
            let target = if row == col { 1.0 } else { 0.0 };
            approx::assert_abs_diff_eq!(value, target, epsilon = 1e-10);
        }
    }

    /// Build a square diagonal matrix whose diagonal is `entries`.
    fn diag_matrix(entries: &[f64]) -> Array2<f64> {
        let size = entries.len();
        let mut m = Array2::<f64>::zeros((size, size));
        for (idx, &value) in entries.iter().enumerate() {
            m[[idx, idx]] = value;
        }
        m
    }

    #[test]
    fn test_svd_identity() {
        let eye = Array2::<f64>::eye(4);
        let (u, s, vt) = svd_thin(&eye).unwrap();
        s.iter()
            .for_each(|&sv| approx::assert_abs_diff_eq!(sv, 1.0, epsilon = 1e-10));

        // Rebuilding U @ diag(s) @ Vt must give back the identity.
        let rebuilt = u.dot(&Array2::from_diag(&s)).dot(&vt);
        assert_near_identity(&rebuilt);
    }

    #[test]
    fn test_whitener_diagonal() {
        let cov = diag_matrix(&[4.0, 9.0, 16.0]);
        let (w, n_nz) = compute_whitener(&cov).unwrap();
        assert_eq!(n_nz, 3);

        // A whitener maps the covariance to the identity: W @ C @ W^T ≈ I.
        let whitened = w.dot(&cov).dot(&w.t());
        assert_near_identity(&whitened);
    }

    #[test]
    fn test_eigh_sorted_descending() {
        let m = diag_matrix(&[1.0, 3.0, 2.0]);
        let (evals, _) = eigh_sorted(&m).unwrap();
        let non_increasing = evals
            .iter()
            .zip(evals.iter().skip(1))
            .all(|(hi, lo)| hi >= lo);
        assert!(non_increasing);
        for (got, want) in evals.iter().zip([3.0, 2.0, 1.0]) {
            approx::assert_abs_diff_eq!(*got, want, epsilon = 1e-10);
        }
    }
}