1use ndarray::Array1;
36use ndarray::Array2;
37
38use super::InverseOperator;
39
40pub fn estimate_snr(data: &Array2<f64>, inv: &InverseOperator) -> (Array1<f64>, Array1<f64>) {
54 let n_times = data.ncols();
55 let n_eff = inv.n_nzero;
56
57 let data_white = inv.whitener.dot(data);
59
60 let data_white_ef = inv.eigen_fields.dot(&data_white);
62
63 let mut snr = Array1::zeros(n_times);
65 for t in 0..n_times {
66 let mut sum_sq = 0.0;
67 for i in 0..data_white.nrows() {
68 sum_sq += data_white[[i, t]].powi(2);
69 }
70 snr[t] = (sum_sq / n_eff as f64).sqrt();
71 }
72
73 let sing2: Vec<f64> = inv.sing.iter().map(|s| s * s).collect();
84 let n_k = sing2.len();
85
86 let chi2_val = chi2_isf(1e-3, n_eff);
88
89 let mut snr_est = Array1::zeros(n_times);
90 let lambda_mult = 0.99_f64;
91
92 for t in 0..n_times {
93 let sig: f64 = (0..data_white.nrows())
95 .map(|i| data_white[[i, t]].powi(2))
96 .sum();
97 if sig / n_eff as f64 <= 1.0 {
98 snr_est[t] = 0.0;
99 continue;
100 }
101
102 let mut lambda2 = 10.0_f64;
103 let mut converged = false;
104 for _ in 0..1000 {
105 let mut err = 0.0;
106 for k in 0..n_k {
107 if sing2[k] > 0.0 {
108 let pi_k = sing2[k] / (sing2[k] + lambda2);
109 let residual = data_white_ef[[k, t]] * (1.0 - pi_k);
110 err += residual * residual;
111 }
112 }
113 if err < chi2_val {
114 converged = true;
115 break;
116 }
117 lambda2 *= lambda_mult;
118 }
119
120 snr_est[t] = if converged {
121 1.0 / lambda2.sqrt()
122 } else {
123 1.0 / lambda2.sqrt() };
125 }
126
127 (snr, snr_est)
128}
129
130fn chi2_isf(p: f64, k: usize) -> f64 {
135 let z = normal_quantile(1.0 - p);
138 let kf = k as f64;
139 let term = 1.0 - 2.0 / (9.0 * kf) + z * (2.0 / (9.0 * kf)).sqrt();
140 kf * term.powi(3)
141}
142
/// Approximate quantile function (inverse CDF) of the standard normal distribution.
///
/// Uses the rational approximation of Abramowitz & Stegun, formula 26.2.23. The
/// approximation is evaluated on the smaller tail probability, and the sign is then
/// restored by symmetry.
///
/// Special values:
/// * `p <= 0` gives `-∞`.
/// * `p >= 1` gives `+∞`.
/// * `p` within `1e-15` of `0.5` gives exactly `0.0`.
/// * A NaN `p` yields NaN.
fn normal_quantile(p: f64) -> f64 {
    // Coefficients of A&S 26.2.23.
    const C0: f64 = 2.515517;
    const C1: f64 = 0.802853;
    const C2: f64 = 0.010328;
    const D1: f64 = 1.432788;
    const D2: f64 = 0.189269;
    const D3: f64 = 0.001308;

    if p <= 0.0 {
        f64::NEG_INFINITY
    } else if p >= 1.0 {
        f64::INFINITY
    } else if (p - 0.5).abs() < 1e-15 {
        0.0
    } else {
        // Reflect into the tail with the smaller probability `q`. The lower tail gives a
        // negative deviate.
        let (sign, q) = if p < 0.5 { (-1.0, p) } else { (1.0, 1.0 - p) };

        let w = (-2.0 * q.ln()).sqrt();
        let numerator = C0 + C1 * w + C2 * w * w;
        let denominator = 1.0 + D1 * w + D2 * w * w + D3 * w * w * w;
        sign * (w - numerator / denominator)
    }
}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{make_inverse_operator, ForwardOperator, NoiseCov};
    use ndarray::Array2;

    /// Builds a 16-channel / 30-source fixed-orientation forward model whose gain is
    /// `1e-8 / sqrt(d² + 1)`. Here `d` is the distance between the channel index and the
    /// source position, with sources spread evenly over the channel index range. The
    /// model is paired with a diagonal noise covariance of `1e-12` on every channel.
    fn make_test_setup() -> (ForwardOperator, NoiseCov) {
        let n_chan = 16;
        let n_src = 30;
        let mut gain = Array2::zeros((n_chan, n_src));
        for i in 0..n_chan {
            for j in 0..n_src {
                let dist =
                    ((i as f64 - j as f64 * n_chan as f64 / n_src as f64).powi(2) + 1.0).sqrt();
                gain[[i, j]] = 1e-8 / dist;
            }
        }
        (
            ForwardOperator::new_fixed(gain),
            NoiseCov::diagonal(vec![1e-12; n_chan]),
        )
    }

    #[test]
    fn test_estimate_snr_shapes() {
        let (fwd, cov) = make_test_setup();
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();
        let data = Array2::from_elem((16, 20), 1e-6);
        let (snr, snr_est) = estimate_snr(&data, &inv);
        assert_eq!(snr.len(), 20);
        assert_eq!(snr_est.len(), 20);
        assert!(snr.iter().all(|v| v.is_finite()));
        assert!(snr_est.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_snr_signal_vs_noise() {
        let (fwd, cov) = make_test_setup();
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        // With a noise variance of 1e-12 (std 1e-6), 1e-3 is far above the noise and
        // 1e-15 is far below it.
        let strong = Array2::from_elem((16, 10), 1e-3);
        let (snr_strong, snr_est_strong) = estimate_snr(&strong, &inv);

        let weak = Array2::from_elem((16, 10), 1e-15);
        let (snr_weak, snr_est_weak) = estimate_snr(&weak, &inv);

        assert!(
            snr_strong[0] > snr_weak[0],
            "Strong signal SNR ({}) should exceed weak ({})",
            snr_strong[0],
            snr_weak[0]
        );
        // Whitened power per channel <= 1: the data cannot be told apart from noise.
        assert!(
            snr_est_weak.iter().all(|&v| v == 0.0),
            "Sub-noise data should have zero estimated SNR, got {snr_est_weak:?}"
        );
        assert!(
            snr_est_strong.iter().all(|&v| v > 0.0),
            "Supra-noise data should have positive estimated SNR, got {snr_est_strong:?}"
        );
    }

    #[test]
    fn test_chi2_isf_sanity() {
        // The median of χ²(100) lies close to its mean, 100.
        let val = chi2_isf(0.5, 100);
        assert!(
            (val - 100.0).abs() < 5.0,
            "chi2_isf(0.5, 100) = {val}, expected ≈ 100"
        );

        // The upper 0.1% point of χ²(10) is ≈ 29.59 (standard tables).
        let tail = chi2_isf(1e-3, 10);
        assert!(
            (tail - 29.59).abs() < 1.0,
            "chi2_isf(1e-3, 10) = {tail}, expected ≈ 29.59"
        );
    }

    #[test]
    fn test_normal_quantile() {
        let z = normal_quantile(0.975);
        assert!(
            (z - 1.96).abs() < 0.01,
            "normal_quantile(0.975) = {z}, expected ≈ 1.96"
        );

        // Symmetry around the median.
        let z_lo = normal_quantile(0.025);
        assert!(
            (z_lo + z).abs() < 1e-9,
            "normal_quantile(0.025) = {z_lo}, expected ≈ -{z}"
        );

        // Boundary values.
        assert_eq!(normal_quantile(0.5), 0.0);
        assert_eq!(normal_quantile(0.0), f64::NEG_INFINITY);
        assert_eq!(normal_quantile(1.0), f64::INFINITY);
    }
}