Skip to main content

exg_source/
snr.rs

1//! SNR estimation from sensor data and an inverse operator.
2//!
3//! Ported from MNE-Python's `mne.minimum_norm.estimate_snr`.
4//!
5//! ## Overview
6//!
7//! Two SNR measures are provided:
8//!
9//! - **Whitened GFP** (`snr`): Global field power of the whitened data,
10//!   normalised by the effective channel count. This is a data-driven SNR
11//!   that does not depend on the regularisation parameter.
12//!
//! - **Regularisation-based** (`snr_est`): Sweeps `λ²` downward from 10 and
//!   takes the largest value for which the residual (unregularised −
//!   regularised prediction) falls below a χ² critical value (p = 0.001).
//!   Returns `1 / √λ²`.
16//!
17//! ## Example
18//!
19//! ```no_run
20//! use exg_source::snr::estimate_snr;
21//! use exg_source::*;
22//! use ndarray::Array2;
23//!
24//! # let n_chan = 16; let n_src = 50;
25//! # let gain = Array2::<f64>::zeros((n_chan, n_src));
26//! # let fwd = ForwardOperator::new_fixed(gain);
27//! # let cov = NoiseCov::diagonal(vec![1e-12; n_chan]);
28//! # let inv = make_inverse_operator(&fwd, &cov, None).unwrap();
29//! let data = Array2::<f64>::zeros((n_chan, 100));
30//! let (snr, snr_est) = estimate_snr(&data, &inv);
31//! println!("SNR (whitened GFP): {:?}", &snr.as_slice().unwrap()[..5]);
32//! println!("SNR (estimated):    {:?}", &snr_est.as_slice().unwrap()[..5]);
33//! ```
34
35use ndarray::Array1;
36use ndarray::Array2;
37
38use super::InverseOperator;
39
40/// Estimate SNR as a function of time.
41///
42/// # Arguments
43///
44/// * `data` — Sensor data, shape `[n_channels, n_times]`.
45/// * `inv`  — Inverse operator.
46///
47/// # Returns
48///
49/// `(snr, snr_est)` — both `Array1<f64>` of length `n_times`.
50///
51/// - `snr`: whitened GFP — `√(‖W x(t)‖² / n_eff)`
52/// - `snr_est`: regularisation-based estimate — `1 / √λ²_opt(t)`
53pub fn estimate_snr(data: &Array2<f64>, inv: &InverseOperator) -> (Array1<f64>, Array1<f64>) {
54    let n_times = data.ncols();
55    let n_eff = inv.n_nzero;
56
57    // Whiten the data: w(t) = W @ x(t)
58    let data_white = inv.whitener.dot(data);
59
60    // Project onto eigen-field basis: w_ef(t) = U^T @ w(t)
61    let data_white_ef = inv.eigen_fields.dot(&data_white);
62
63    // ── SNR from whitened GFP ──────────────────────────────────────────
64    let mut snr = Array1::zeros(n_times);
65    for t in 0..n_times {
66        let mut sum_sq = 0.0;
67        for i in 0..data_white.nrows() {
68            sum_sq += data_white[[i, t]].powi(2);
69        }
70        snr[t] = (sum_sq / n_eff as f64).sqrt();
71    }
72
73    // ── SNR from regularisation mismatch ───────────────────────────────
74    //
75    // For each time point, find the largest λ² for which the residual
76    // between unregularised and regularised solutions exceeds a χ² threshold.
77    //
78    // Π_k(λ²) = s_k² / (s_k² + λ²)
79    // error(t) = Σ_k w_ef_k(t)² × (1 − Π_k(λ²))²
80    //
81    // We sweep λ² downward until error < χ²_{n_eff}(0.001).
82
83    let sing2: Vec<f64> = inv.sing.iter().map(|s| s * s).collect();
84    let n_k = sing2.len();
85
86    // χ² critical value approximation for p=0.001 (Wilson–Hilferty)
87    let chi2_val = chi2_isf(1e-3, n_eff);
88
89    let mut snr_est = Array1::zeros(n_times);
90    let lambda_mult = 0.99_f64;
91
92    for t in 0..n_times {
93        // Check if signal is too weak
94        let sig: f64 = (0..data_white.nrows())
95            .map(|i| data_white[[i, t]].powi(2))
96            .sum();
97        if sig / n_eff as f64 <= 1.0 {
98            snr_est[t] = 0.0;
99            continue;
100        }
101
102        let mut lambda2 = 10.0_f64;
103        let mut converged = false;
104        for _ in 0..1000 {
105            let mut err = 0.0;
106            for k in 0..n_k {
107                if sing2[k] > 0.0 {
108                    let pi_k = sing2[k] / (sing2[k] + lambda2);
109                    let residual = data_white_ef[[k, t]] * (1.0 - pi_k);
110                    err += residual * residual;
111                }
112            }
113            if err < chi2_val {
114                converged = true;
115                break;
116            }
117            lambda2 *= lambda_mult;
118        }
119
120        snr_est[t] = if converged {
121            1.0 / lambda2.sqrt()
122        } else {
123            1.0 / lambda2.sqrt() // best estimate even if not converged
124        };
125    }
126
127    (snr, snr_est)
128}
129
130/// Approximate the inverse survival function of χ²(k) at probability p.
131///
132/// Uses the Wilson–Hilferty normal approximation:
133/// `χ²_p ≈ k × (1 − 2/(9k) + z_p × √(2/(9k)))³`
134fn chi2_isf(p: f64, k: usize) -> f64 {
135    // z_p for the standard normal (approximate for small p)
136    // For p = 0.001, z ≈ 3.09
137    let z = normal_quantile(1.0 - p);
138    let kf = k as f64;
139    let term = 1.0 - 2.0 / (9.0 * kf) + z * (2.0 / (9.0 * kf)).sqrt();
140    kf * term.powi(3)
141}
142
/// Approximate quantile (inverse CDF) of the standard normal distribution.
///
/// Implements the rational approximation of Abramowitz & Stegun, eq. 26.2.23.
/// It is applied to the smaller tail probability and mirrored by symmetry.
/// Returns `-∞` for `p <= 0`, `+∞` for `p >= 1`, and exactly `0.0` when `p`
/// lies within `1e-15` of `0.5`.
fn normal_quantile(p: f64) -> f64 {
    // A&S 26.2.23 numerator (c₀..c₂) and denominator (d₁..d₃) coefficients.
    const C: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const D: [f64; 3] = [1.432788, 0.189269, 0.001308];

    // Pick the sign of the result and the tail probability q ≤ 0.5.
    let (sign, q) = match p {
        _ if p <= 0.0 => return f64::NEG_INFINITY,
        _ if p >= 1.0 => return f64::INFINITY,
        _ if (p - 0.5).abs() < 1e-15 => return 0.0,
        _ if p < 0.5 => (-1.0, p),
        _ => (1.0, 1.0 - p),
    };

    let t = (-2.0 * q.ln()).sqrt();
    let numerator = C[0] + C[1] * t + C[2] * t * t;
    let denominator = 1.0 + D[0] * t + D[1] * t * t + D[2] * t * t * t;
    sign * (t - numerator / denominator)
}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{make_inverse_operator, ForwardOperator, NoiseCov};
    use ndarray::Array2;

    /// Channel count of the synthetic forward model.
    const N_CHAN: usize = 16;
    /// Source count of the synthetic forward model.
    const N_SRC: usize = 30;

    /// Synthetic fixed-orientation forward model and noise covariance.
    ///
    /// Each gain entry decays as `1e-8 / √(d² + 1)`, where `d` is the
    /// channel–source offset. The noise covariance is diagonal with `1e-12`
    /// per channel.
    fn make_test_setup() -> (ForwardOperator, NoiseCov) {
        let gain = Array2::from_shape_fn((N_CHAN, N_SRC), |(ch, src)| {
            let offset = ch as f64 - src as f64 * N_CHAN as f64 / N_SRC as f64;
            1e-8 / (offset.powi(2) + 1.0).sqrt()
        });
        (
            ForwardOperator::new_fixed(gain),
            NoiseCov::diagonal(vec![1e-12; N_CHAN]),
        )
    }

    #[test]
    fn test_estimate_snr_shapes() {
        let (fwd, cov) = make_test_setup();
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();
        let n_times = 20;
        let data = Array2::from_elem((N_CHAN, n_times), 1e-6);
        let (snr, snr_est) = estimate_snr(&data, &inv);
        for series in [&snr, &snr_est] {
            assert_eq!(series.len(), n_times);
            assert!(series.iter().all(|v| v.is_finite()));
        }
    }

    #[test]
    fn test_snr_signal_vs_noise() {
        let (fwd, cov) = make_test_setup();
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        // Whitened GFP at the first sample of a constant-amplitude recording.
        let first_gfp = |level: f64| {
            let (snr, _) = estimate_snr(&Array2::from_elem((N_CHAN, 10), level), &inv);
            snr[0]
        };
        let strong = first_gfp(1e-3);
        let weak = first_gfp(1e-15);

        assert!(
            strong > weak,
            "Strong signal SNR ({}) should exceed weak ({})",
            strong,
            weak
        );
    }

    #[test]
    fn test_chi2_isf_sanity() {
        // The median of χ²(k) is close to k for large k.
        let val = chi2_isf(0.5, 100);
        let close_to_k = (val - 100.0).abs() < 5.0;
        assert!(close_to_k, "chi2_isf(0.5, 100) = {val}, expected ≈ 100");
    }

    #[test]
    fn test_normal_quantile() {
        // The 97.5th percentile of N(0, 1) is ≈ 1.96.
        let z = normal_quantile(0.975);
        let close = (z - 1.96).abs() < 0.01;
        assert!(close, "normal_quantile(0.975) = {z}, expected ≈ 1.96");
    }
}