// exg_source/covariance.rs
1//! Noise covariance estimation from sensor data.
2//!
3//! Ported from MNE-Python's `mne.compute_covariance` /
4//! `mne.compute_raw_covariance`.
5//!
6//! ## Methods
7//!
//! - **Empirical**: `C = 1/(N−1) Σ xᵢ xᵢᵀ` (unbiased sample covariance)
9//! - **Shrunk (Ledoit–Wolf)**: `C_shrunk = (1−α) C + α tr(C)/p · I`
10//! - **Diagonal**: keep only diagonal entries
11//!
12//! ## Example
13//!
14//! ```
15//! use exg_source::covariance::{compute_covariance, Regularization};
16//! use ndarray::Array2;
17//!
18//! // 3 channels, 1000 samples
19//! let data = Array2::<f64>::from_shape_fn((3, 1000), |(i, j)| {
20//!     ((i * 1000 + j) as f64).sin() * 1e-6
21//! });
22//! let cov = compute_covariance(&data, Regularization::Empirical);
23//! assert_eq!(cov.n_channels(), 3);
24//! ```
25
26use ndarray::Array2;
27
28use super::NoiseCov;
29
/// Regularisation strategy for covariance estimation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Regularization {
    /// Plain sample covariance — no regularisation applied.
    Empirical,
    /// Ledoit–Wolf shrinkage towards a scaled identity matrix.
    ///
    /// With `None`, the optimal shrinkage coefficient is estimated
    /// automatically (Ledoit & Wolf, 2004). A `Some(alpha)` value in
    /// `[0, 1]` fixes the coefficient explicitly.
    ShrunkIdentity(Option<f64>),
    /// Discard off-diagonal terms, keeping only per-channel variances.
    Diagonal,
}
43
44/// Compute noise covariance from continuous data `[n_channels, n_times]`.
45///
46/// The data is assumed to be baseline or empty-room recording.
47/// The mean is subtracted per channel before computing the covariance.
48///
49/// # Arguments
50///
51/// * `data` — Sensor data, shape `[n_channels, n_times]`.
52/// * `reg`  — Regularisation method.
53///
54/// # Returns
55///
56/// A [`NoiseCov`] suitable for use with [`make_inverse_operator`](super::make_inverse_operator).
57pub fn compute_covariance(data: &Array2<f64>, reg: Regularization) -> NoiseCov {
58    let (n_ch, n_t) = data.dim();
59    assert!(n_t > 1, "Need at least 2 time points for covariance");
60
61    // Subtract per-channel mean
62    let means = data.mean_axis(ndarray::Axis(1)).unwrap();
63    let mut centered = data.clone();
64    for i in 0..n_ch {
65        for j in 0..n_t {
66            centered[[i, j]] -= means[i];
67        }
68    }
69
70    match reg {
71        Regularization::Empirical => {
72            let cov = centered.dot(&centered.t()) / (n_t - 1) as f64;
73            NoiseCov::full(cov)
74        }
75        Regularization::ShrunkIdentity(alpha_opt) => {
76            let cov = centered.dot(&centered.t()) / (n_t - 1) as f64;
77            let alpha = alpha_opt.unwrap_or_else(|| ledoit_wolf_alpha(&centered, &cov));
78            let alpha = alpha.clamp(0.0, 1.0);
79            let trace = cov.diag().sum();
80            let mu = trace / n_ch as f64;
81            let shrunk = cov.mapv(|v| v * (1.0 - alpha)) + Array2::<f64>::eye(n_ch).mapv(|v: f64| v * alpha * mu);
82            NoiseCov::full(shrunk)
83        }
84        Regularization::Diagonal => {
85            let mut vars = Vec::with_capacity(n_ch);
86            for i in 0..n_ch {
87                let mut sum_sq = 0.0;
88                for j in 0..n_t {
89                    sum_sq += centered[[i, j]].powi(2);
90                }
91                vars.push(sum_sq / (n_t - 1) as f64);
92            }
93            NoiseCov::diagonal(vars)
94        }
95    }
96}
97
98/// Compute noise covariance from epoched data `[n_epochs, n_channels, n_times]`.
99///
100/// Concatenates all epochs before computing covariance, subtracting the
101/// per-epoch, per-channel mean (i.e., each epoch is baseline-corrected).
102pub fn compute_covariance_epochs(
103    epochs: &ndarray::Array3<f64>,
104    reg: Regularization,
105) -> NoiseCov {
106    let (n_epochs, n_ch, n_t) = epochs.dim();
107    let total_t = n_epochs * n_t;
108
109    // Concatenate all epochs into [n_ch, total_t], subtracting per-epoch mean
110    let mut concat = Array2::zeros((n_ch, total_t));
111    for e in 0..n_epochs {
112        let epoch = epochs.slice(ndarray::s![e, .., ..]);
113        let mean = epoch.mean_axis(ndarray::Axis(1)).unwrap();
114        for i in 0..n_ch {
115            for j in 0..n_t {
116                concat[[i, e * n_t + j]] = epoch[[i, j]] - mean[i];
117            }
118        }
119    }
120
121    // Now compute covariance on the already-centered data
122    let (_, total) = concat.dim();
123    match reg {
124        Regularization::Empirical => {
125            let cov = concat.dot(&concat.t()) / (total - 1) as f64;
126            NoiseCov::full(cov)
127        }
128        Regularization::ShrunkIdentity(alpha_opt) => {
129            let cov = concat.dot(&concat.t()) / (total - 1) as f64;
130            let alpha = alpha_opt.unwrap_or_else(|| ledoit_wolf_alpha(&concat, &cov));
131            let alpha = alpha.clamp(0.0, 1.0);
132            let trace = cov.diag().sum();
133            let mu = trace / n_ch as f64;
134            let shrunk = cov.mapv(|v| v * (1.0 - alpha)) + Array2::<f64>::eye(n_ch).mapv(|v: f64| v * alpha * mu);
135            NoiseCov::full(shrunk)
136        }
137        Regularization::Diagonal => {
138            let mut vars = Vec::with_capacity(n_ch);
139            for i in 0..n_ch {
140                let mut sum_sq = 0.0;
141                for j in 0..total {
142                    sum_sq += concat[[i, j]].powi(2);
143                }
144                vars.push(sum_sq / (total - 1) as f64);
145            }
146            NoiseCov::diagonal(vars)
147        }
148    }
149}
150
151/// Ledoit–Wolf optimal shrinkage coefficient towards scaled identity.
152///
153/// Implements the analytical formula from Ledoit & Wolf (2004),
154/// "A well-conditioned estimator for large-dimensional covariance matrices."
155fn ledoit_wolf_alpha(x: &Array2<f64>, sample_cov: &Array2<f64>) -> f64 {
156    let (p, n) = x.dim(); // p = channels, n = samples
157
158    if n < 2 {
159        return 1.0;
160    }
161
162    let trace_s = sample_cov.diag().sum();
163    let trace_s2 = sample_cov.iter().map(|v| v * v).sum::<f64>();
164    let mu = trace_s / p as f64;
165
166    // Compute sum of squared norms of x_i x_i^T - S
167    // β̂² = (1/n²) Σ_i ‖x_i x_i^T − S‖²_F
168    let mut beta_sum = 0.0;
169    for t in 0..n {
170        // x_t is column t of x
171        // ‖x_t x_t^T - S‖²_F = (x_t^T x_t)² - 2 x_t^T S x_t + ‖S‖²_F
172        let mut xtx = 0.0;
173        for i in 0..p {
174            xtx += x[[i, t]] * x[[i, t]];
175        }
176        let mut xt_s_xt = 0.0;
177        for i in 0..p {
178            let mut row_dot = 0.0;
179            for j in 0..p {
180                row_dot += sample_cov[[i, j]] * x[[j, t]];
181            }
182            xt_s_xt += x[[i, t]] * row_dot;
183        }
184        beta_sum += xtx * xtx - 2.0 * xt_s_xt + trace_s2;
185    }
186    let beta = beta_sum / (n * n) as f64;
187
188    // δ² = ‖S − μI‖²_F = ‖S‖²_F − p·μ²
189    let delta = trace_s2 - p as f64 * mu * mu;
190
191    if delta <= 0.0 {
192        return 1.0;
193    }
194
195    (beta / delta).clamp(0.0, 1.0)
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_empirical_covariance_shape() {
        // Deterministic sinusoid fixture, 4 channels × 100 samples.
        let x = Array2::<f64>::from_shape_fn((4, 100), |(c, t)| ((c * 100 + t) as f64 * 0.1).sin());
        let cov = compute_covariance(&x, Regularization::Empirical);
        assert_eq!(cov.n_channels(), 4);
        assert_eq!(cov.to_full().dim(), (4, 4));
    }

    #[test]
    fn test_empirical_covariance_symmetric() {
        let x = Array2::<f64>::from_shape_fn((5, 200), |(c, t)| {
            ((c * 200 + t) as f64 * 0.3).cos() * 1e-6
        });
        let m = compute_covariance(&x, Regularization::Empirical).to_full();
        // Every entry must equal its transpose partner.
        for r in 0..5 {
            for c in 0..5 {
                approx::assert_abs_diff_eq!(m[[r, c]], m[[c, r]], epsilon = 1e-15);
            }
        }
    }

    #[test]
    fn test_empirical_covariance_positive_diagonal() {
        let x = Array2::<f64>::from_shape_fn((3, 500), |(c, t)| {
            ((c * 500 + t) as f64 * 0.7).sin() * 1e-6
        });
        let cov = compute_covariance(&x, Regularization::Empirical);
        assert!(
            cov.diag_elements().iter().all(|&v| v > 0.0),
            "Diagonal should be positive"
        );
    }

    #[test]
    fn test_diagonal_covariance() {
        // Per-channel scaling gives each channel a distinct variance.
        let x = Array2::<f64>::from_shape_fn((3, 500), |(c, t)| {
            ((c * 500 + t) as f64 * 0.2).sin() * (c as f64 + 1.0) * 1e-6
        });
        let cov = compute_covariance(&x, Regularization::Diagonal);
        assert!(cov.diag);
        assert_eq!(cov.n_channels(), 3);
    }

    #[test]
    fn test_shrunk_covariance_between_empirical_and_identity() {
        let x = Array2::<f64>::from_shape_fn((4, 200), |(c, t)| {
            ((c * 200 + t) as f64 * 0.5).sin() * 1e-6
        });
        let emp = compute_covariance(&x, Regularization::Empirical).to_full();
        let shrunk = compute_covariance(&x, Regularization::ShrunkIdentity(None)).to_full();

        // Total |off-diagonal| mass; shrinkage must not increase it.
        let off_mass = |m: &Array2<f64>| {
            let mut acc = 0.0;
            for i in 0..4 {
                for j in 0..4 {
                    if i != j {
                        acc += m[[i, j]].abs();
                    }
                }
            }
            acc
        };
        let emp_offdiag = off_mass(&emp);
        let shrunk_offdiag = off_mass(&shrunk);
        assert!(
            shrunk_offdiag <= emp_offdiag + 1e-20,
            "Shrinkage should reduce off-diagonal: shrunk={shrunk_offdiag}, emp={emp_offdiag}"
        );
    }

    #[test]
    fn test_covariance_from_epochs() {
        let epochs = ndarray::Array3::<f64>::from_shape_fn((10, 3, 50), |(e, c, t)| {
            ((e * 150 + c * 50 + t) as f64 * 0.4).sin() * 1e-6
        });
        let cov = compute_covariance_epochs(&epochs, Regularization::Empirical);
        assert_eq!(cov.n_channels(), 3);
        let m = cov.to_full();
        for r in 0..3 {
            // Symmetric in every entry, strictly positive on the diagonal.
            for c in 0..3 {
                approx::assert_abs_diff_eq!(m[[r, c]], m[[c, r]], epsilon = 1e-15);
            }
            assert!(m[[r, r]] > 0.0);
        }
    }
}