// exg_source/inverse.rs

//! Inverse operator construction and application.
//!
//! Implements `make_inverse_operator` and `apply_inverse`, ported from
//! MNE-Python's `mne.minimum_norm.inverse`.

6use anyhow::{bail, Result};
7use ndarray::{Array1, Array2};
8
9use super::eloreta::compute_eloreta;
10use super::linalg;
11use super::{
12    EloretaOptions, ForwardOperator, InverseMethod, InverseOperator, NoiseCov, PickOri,
13    SourceEstimate, SourceOrientation,
14};
15
16/// Compute the depth-weighting prior from the gain matrix.
17///
18/// Following MNE-Python's `compute_depth_prior`, the depth weight for source
19/// `k` is `(‖G_k‖ / max ‖G_k‖)^exp` where `G_k` is the column (or 3-column
20/// block for free orientation) of the whitened gain matrix.
21fn compute_depth_prior(
22    whitened_gain: &Array2<f64>,
23    n_sources: usize,
24    n_orient: usize,
25    exp: f64,
26) -> Array1<f64> {
27    let n_cols = n_sources * n_orient;
28    let mut col_norms = Array1::zeros(n_sources);
29
30    for s in 0..n_sources {
31        let mut norm_sq = 0.0;
32        for o in 0..n_orient {
33            let col_idx = s * n_orient + o;
34            if col_idx < n_cols {
35                for r in 0..whitened_gain.nrows() {
36                    norm_sq += whitened_gain[[r, col_idx]].powi(2);
37                }
38            }
39        }
40        col_norms[s] = norm_sq.sqrt();
41    }
42
43    let max_norm = col_norms.iter().copied().fold(0.0_f64, f64::max);
44    if max_norm <= 0.0 {
45        return Array1::ones(n_cols);
46    }
47
48    let mut prior = Array1::zeros(n_cols);
49    for s in 0..n_sources {
50        let w = (col_norms[s] / max_norm).powf(exp);
51        for o in 0..n_orient {
52            prior[s * n_orient + o] = w;
53        }
54    }
55    prior
56}
57
58/// Build an inverse operator from a forward model and noise covariance.
59///
60/// This is the Rust equivalent of `mne.minimum_norm.make_inverse_operator`.
61///
62/// # Arguments
63///
64/// * `fwd` — Forward operator (gain matrix + source info).
65/// * `noise_cov` — Sensor noise covariance.
66/// * `depth_exp` — Optional depth-weighting exponent. If `None`, uses
67///   `fwd.depth_exp` or falls back to no depth weighting.
68///
69/// # Returns
70///
71/// An [`InverseOperator`] containing the SVD decomposition of the whitened
72/// and weighted gain matrix, ready for use with [`apply_inverse`].
73pub fn make_inverse_operator(
74    fwd: &ForwardOperator,
75    noise_cov: &NoiseCov,
76    depth_exp: Option<f64>,
77) -> Result<InverseOperator> {
78    let n_chan = fwd.gain.nrows();
79    let n_orient = fwd.n_orient();
80    let n_cols = fwd.n_sources * n_orient;
81
82    if fwd.gain.ncols() != n_cols {
83        bail!(
84            "Gain matrix has {} columns but expected {} (n_sources={} × n_orient={})",
85            fwd.gain.ncols(),
86            n_cols,
87            fwd.n_sources,
88            n_orient,
89        );
90    }
91    if noise_cov.n_channels() != n_chan {
92        bail!(
93            "Noise covariance has {} channels but gain has {} rows",
94            noise_cov.n_channels(),
95            n_chan,
96        );
97    }
98
99    // 1. Compute whitener from noise covariance
100    let cov_full = noise_cov.to_full();
101    let (whitener, n_nzero) = linalg::compute_whitener(&cov_full)?;
102
103    // 2. Whiten the gain matrix: G_w = W @ G
104    let gain_w = whitener.dot(&fwd.gain);
105
106    // 3. Compute depth prior (optional)
107    let exp = depth_exp.or(fwd.depth_exp);
108    let source_std = if let Some(e) = exp {
109        let prior = compute_depth_prior(&gain_w, fwd.n_sources, n_orient, e);
110        let mut std = Array1::zeros(n_cols);
111        for i in 0..n_cols {
112            std[i] = prior[i].sqrt();
113        }
114        std
115    } else {
116        Array1::ones(n_cols)
117    };
118
119    // 4. Apply source weighting to whitened gain: G_w *= source_std
120    let mut gain_ws = gain_w;
121    for j in 0..n_cols {
122        for i in 0..gain_ws.nrows() {
123            gain_ws[[i, j]] *= source_std[j];
124        }
125    }
126
127    // 5. Scale so that trace(G_ws @ G_ws^T) = n_nzero
128    let trace_grgt = gain_ws.iter().map(|v| v * v).sum::<f64>();
129    let scale = (n_nzero as f64 / trace_grgt).sqrt();
130    gain_ws.mapv_inplace(|v| v * scale);
131    let source_std = source_std.mapv(|v| v * scale);
132
133    // 6. SVD of whitened, weighted gain
134    let (u, sing, vt) = linalg::svd_thin(&gain_ws)?;
135
136    // eigen_fields = U^T  [k, n_chan]
137    let eigen_fields = u.t().to_owned();
138    // eigen_leads = V  [n_cols, k]  (from Vt -> V = Vt^T)
139    let eigen_leads = vt.t().to_owned();
140
141    let source_cov = source_std.mapv(|v| v * v);
142
143    Ok(InverseOperator {
144        eigen_fields,
145        sing,
146        eigen_leads,
147        source_cov,
148        eigen_leads_weighted: false,
149        n_sources: fwd.n_sources,
150        orientation: fwd.orientation,
151        source_nn: fwd.source_nn.clone(),
152        whitener,
153        n_nzero,
154        noise_cov: noise_cov.clone(),
155    })
156}
157
158/// Intermediate prepared state for an inverse operator.
159pub struct PreparedInverse {
160    /// Regularised inverse of singular values: `s / (s² + λ²)`.
161    pub reginv: Array1<f64>,
162    /// Noise-normalisation factors (one per source), or `None` for MNE.
163    pub noisenorm: Option<Array1<f64>>,
164    /// Imaging kernel `K` [n_sources_out, n_channels].
165    pub kernel: Array2<f64>,
166}
167
168/// Prepare an inverse operator for a specific method and regularisation.
169///
170/// Computes the imaging kernel and noise normalisation.
171///
172/// # Arguments
173///
174/// * `inv` — Inverse operator from [`make_inverse_operator`].
175/// * `lambda2` — Regularisation parameter (recommended: 1/SNR²).
176/// * `method` — Inverse method to use.
177/// * `eloreta_opts` — Options for eLORETA (ignored for other methods).
178pub fn prepare_inverse(
179    inv: &InverseOperator,
180    lambda2: f64,
181    method: InverseMethod,
182    eloreta_opts: Option<&EloretaOptions>,
183) -> Result<PreparedInverse> {
184    let n_orient = match inv.orientation {
185        SourceOrientation::Fixed => 1,
186        SourceOrientation::Free => 3,
187    };
188
189    if method == InverseMethod::ELORETA {
190        return prepare_eloreta(inv, lambda2, eloreta_opts);
191    }
192
193    // Compute regularised inverse: reginv_k = s_k / (s_k² + λ²)
194    let reginv = compute_reginv(&inv.sing, lambda2, inv.n_nzero);
195
196    // Noise normalisation
197    let noisenorm = match method {
198        InverseMethod::MNE => None,
199        InverseMethod::DSPM => {
200            let noise_weight = reginv.clone();
201            Some(compute_noise_norm(inv, &noise_weight, n_orient))
202        }
203        InverseMethod::SLORETA => {
204            let noise_weight = Array1::from_iter(reginv.iter().zip(inv.sing.iter()).map(
205                |(&ri, &si)| ri * (1.0 + si * si / lambda2).sqrt(),
206            ));
207            Some(compute_noise_norm(inv, &noise_weight, n_orient))
208        }
209        InverseMethod::ELORETA => unreachable!(),
210    };
211
212    // Assemble kernel: K = sqrt(source_cov) @ V @ diag(reginv) @ U^T @ W
213    // trans = U^T @ W  has shape [k, n_chan] (but eigen_fields is already U^T)
214    // So trans = eigen_fields @ whitener... no, eigen_fields IS U^T [k, n_chan_whitened]
215    // We need: trans = diag(reginv) @ eigen_fields @ whitener
216    // Wait, let me re-read MNE:
217    // trans = eigen_fields @ whitener @ proj
218    // trans *= reginv[:, None]
219    // K = eigen_leads @ trans
220    // K *= sqrt(source_cov)[:, None]
221
222    let n_k = inv.sing.len();
223    let n_chan = inv.whitener.ncols();
224
225    // trans = eigen_fields @ whitener  [k, n_chan]
226    let trans = inv.eigen_fields.dot(&inv.whitener);
227    // trans *= reginv (row-wise)
228    let mut trans_scaled = Array2::zeros((n_k, n_chan));
229    for i in 0..n_k {
230        for j in 0..n_chan {
231            trans_scaled[[i, j]] = trans[[i, j]] * reginv[i];
232        }
233    }
234
235    // K = eigen_leads @ trans_scaled  [n_src*n_orient, n_chan]
236    let mut kernel = inv.eigen_leads.dot(&trans_scaled);
237
238    // K *= sqrt(source_cov)
239    if !inv.eigen_leads_weighted {
240        for i in 0..kernel.nrows() {
241            let w = inv.source_cov[i].sqrt();
242            for j in 0..kernel.ncols() {
243                kernel[[i, j]] *= w;
244            }
245        }
246    }
247
248    Ok(PreparedInverse {
249        reginv,
250        noisenorm,
251        kernel,
252    })
253}
254
255/// Prepare the eLORETA inverse.
256fn prepare_eloreta(
257    inv: &InverseOperator,
258    lambda2: f64,
259    opts: Option<&EloretaOptions>,
260) -> Result<PreparedInverse> {
261    let default_opts = EloretaOptions::default();
262    let opts = opts.unwrap_or(&default_opts);
263
264    let (kernel, reginv) = compute_eloreta(inv, lambda2, opts)?;
265
266    Ok(PreparedInverse {
267        reginv,
268        noisenorm: None, // eLORETA embeds normalisation in the kernel
269        kernel,
270    })
271}
272
273/// Compute `reginv[k] = s[k] / (s[k]² + λ²)` for the first `n_nzero` values.
274fn compute_reginv(sing: &Array1<f64>, lambda2: f64, n_nzero: usize) -> Array1<f64> {
275    let n = sing.len();
276    let mut reginv = Array1::zeros(n);
277    for k in 0..n.min(n_nzero) {
278        let s = sing[k];
279        if s > 0.0 {
280            reginv[k] = s / (s * s + lambda2);
281        }
282    }
283    reginv
284}
285
286/// Compute noise-normalisation factors (dSPM / sLORETA).
287///
288/// For each source, compute `1 / ‖row_k @ diag(noise_weight)‖₂`.
289fn compute_noise_norm(
290    inv: &InverseOperator,
291    noise_weight: &Array1<f64>,
292    n_orient: usize,
293) -> Array1<f64> {
294    let n_rows = inv.eigen_leads.nrows();
295    let n_k = noise_weight.len();
296
297    let mut raw_norm = Array1::zeros(n_rows);
298    for k in 0..n_rows {
299        let mut sq_sum = 0.0;
300        for j in 0..n_k {
301            let lead = if inv.eigen_leads_weighted {
302                inv.eigen_leads[[k, j]]
303            } else {
304                inv.source_cov[k].sqrt() * inv.eigen_leads[[k, j]]
305            };
306            let val = lead * noise_weight[j];
307            sq_sum += val * val;
308        }
309        raw_norm[k] = sq_sum.sqrt();
310    }
311
312    // For free orientation: combine XYZ triplets
313    if n_orient == 3 {
314        let n_src = n_rows / 3;
315        let mut combined = Array1::zeros(n_src);
316        for s in 0..n_src {
317            let mut sum_sq = 0.0;
318            for o in 0..3 {
319                sum_sq += raw_norm[s * 3 + o].powi(2);
320            }
321            combined[s] = sum_sq.sqrt();
322        }
323        combined.mapv(|v| if v.abs() > 0.0 { 1.0 / v } else { 0.0 })
324    } else {
325        raw_norm.mapv(|v| if v.abs() > 0.0 { 1.0 / v } else { 0.0 })
326    }
327}
328
329/// Combine free-orientation XYZ triplets: `√(x² + y² + z²)` per source.
330fn combine_xyz(sol: &Array2<f64>) -> Array2<f64> {
331    let (n_rows, n_times) = sol.dim();
332    assert!(n_rows % 3 == 0, "combine_xyz: rows must be divisible by 3");
333    let n_src = n_rows / 3;
334    let mut out = Array2::zeros((n_src, n_times));
335    for s in 0..n_src {
336        for t in 0..n_times {
337            let x = sol[[s * 3, t]];
338            let y = sol[[s * 3 + 1, t]];
339            let z = sol[[s * 3 + 2, t]];
340            out[[s, t]] = (x * x + y * y + z * z).sqrt();
341        }
342    }
343    out
344}
345
346/// Apply an inverse operator to sensor-space data.
347///
348/// This is the Rust equivalent of `mne.minimum_norm.apply_inverse`.
349///
350/// # Arguments
351///
352/// * `data` — Sensor data, shape `[n_channels, n_times]`.
353/// * `inv` — Inverse operator from [`make_inverse_operator`].
354/// * `lambda2` — Regularisation parameter (recommended: `1.0 / SNR.powi(2)`).
355/// * `method` — Which method to use.
356///
357/// # Returns
358///
359/// A [`SourceEstimate`] with shape `[n_sources, n_times]` (magnitudes for
360/// free orientation are combined across XYZ).
361///
362/// # Example
363///
364/// ```no_run
365/// use exg_source::*;
366/// use ndarray::Array2;
367///
368/// let n_chan = 32;
369/// let n_src  = 500;
370/// let gain = Array2::<f64>::from_elem((n_chan, n_src), 1e-8);
371/// let fwd  = ForwardOperator::new_fixed(gain);
372/// let cov  = NoiseCov::diagonal(vec![1e-12; n_chan]);
373/// let inv  = make_inverse_operator(&fwd, &cov, None).unwrap();
374///
375/// let data = Array2::<f64>::zeros((n_chan, 100));
376/// let stc  = apply_inverse(&data, &inv, 1.0 / 9.0, InverseMethod::SLORETA).unwrap();
377/// assert_eq!(stc.data.nrows(), n_src);
378/// assert_eq!(stc.data.ncols(), 100);
379/// ```
380pub fn apply_inverse(
381    data: &Array2<f64>,
382    inv: &InverseOperator,
383    lambda2: f64,
384    method: InverseMethod,
385) -> Result<SourceEstimate> {
386    apply_inverse_with_options(data, inv, lambda2, method, None)
387}
388
389/// Apply inverse with optional eLORETA parameters.
390pub fn apply_inverse_with_options(
391    data: &Array2<f64>,
392    inv: &InverseOperator,
393    lambda2: f64,
394    method: InverseMethod,
395    eloreta_opts: Option<&EloretaOptions>,
396) -> Result<SourceEstimate> {
397    apply_inverse_full(data, inv, lambda2, method, PickOri::None, eloreta_opts)
398}
399
400/// Apply inverse with full control over orientation picking and eLORETA options.
401///
402/// # Arguments
403///
404/// * `data`          — Sensor data, shape `[n_channels, n_times]`.
405/// * `inv`           — Inverse operator.
406/// * `lambda2`       — Regularisation parameter.
407/// * `method`        — Inverse method.
408/// * `pick_ori`      — How to handle source orientations (see [`PickOri`]).
409/// * `eloreta_opts`  — Options for eLORETA (ignored for other methods).
410pub fn apply_inverse_full(
411    data: &Array2<f64>,
412    inv: &InverseOperator,
413    lambda2: f64,
414    method: InverseMethod,
415    pick_ori: PickOri,
416    eloreta_opts: Option<&EloretaOptions>,
417) -> Result<SourceEstimate> {
418    let n_chan = data.nrows();
419    if n_chan != inv.whitener.ncols() {
420        bail!(
421            "Data has {} channels but inverse expects {}",
422            n_chan,
423            inv.whitener.ncols()
424        );
425    }
426
427    let n_orient = match inv.orientation {
428        SourceOrientation::Fixed => 1,
429        SourceOrientation::Free => 3,
430    };
431
432    if pick_ori == PickOri::Normal && n_orient != 3 {
433        bail!("pick_ori=Normal requires free-orientation inverse");
434    }
435
436    let prepared = prepare_inverse(inv, lambda2, method, eloreta_opts)?;
437
438    // Apply imaging kernel: sol = K @ data
439    let mut sol = prepared.kernel.dot(data);
440
441    let is_free = n_orient == 3;
442
443    match pick_ori {
444        PickOri::None => {
445            // Default: combine XYZ for free orientation
446            if is_free {
447                sol = combine_xyz(&sol);
448            }
449            // Apply noise normalisation
450            apply_noisenorm(&mut sol, &prepared.noisenorm);
451        }
452        PickOri::Normal => {
453            // Pick only the Z (normal) component: every 3rd row starting at index 2
454            let n_src = inv.n_sources;
455            let n_times = sol.ncols();
456            let mut normal_sol = Array2::zeros((n_src, n_times));
457            for s in 0..n_src {
458                for t in 0..n_times {
459                    normal_sol[[s, t]] = sol[[s * 3 + 2, t]];
460                }
461            }
462            sol = normal_sol;
463            // Apply noise normalisation
464            apply_noisenorm(&mut sol, &prepared.noisenorm);
465        }
466        PickOri::Vector => {
467            // Return all 3 components — noise norm must be expanded
468            if let Some(ref nn) = prepared.noisenorm {
469                if is_free {
470                    // noisenorm has n_src entries; repeat for each orientation
471                    for s in 0..inv.n_sources {
472                        let norm = nn[s];
473                        for o in 0..3 {
474                            for t in 0..sol.ncols() {
475                                sol[[s * 3 + o, t]] *= norm;
476                            }
477                        }
478                    }
479                } else {
480                    apply_noisenorm(&mut sol, &prepared.noisenorm);
481                }
482            }
483        }
484    }
485
486    Ok(SourceEstimate {
487        data: sol,
488        n_sources: inv.n_sources,
489        orientation: inv.orientation,
490    })
491}
492
493/// Apply noise normalisation in-place.
494fn apply_noisenorm(sol: &mut Array2<f64>, noisenorm: &Option<Array1<f64>>) {
495    if let Some(ref nn) = noisenorm {
496        let n_src_out = sol.nrows();
497        for s in 0..n_src_out {
498            let norm = nn[s];
499            for t in 0..sol.ncols() {
500                sol[[s, t]] *= norm;
501            }
502        }
503    }
504}
505
506/// Apply inverse operator to each epoch in a batch.
507///
508/// This is the Rust equivalent of `mne.minimum_norm.apply_inverse_epochs`.
509///
510/// # Arguments
511///
512/// * `epochs`  — Epoched data, shape `[n_epochs, n_channels, n_times]`.
513/// * `inv`     — Inverse operator.
514/// * `lambda2` — Regularisation parameter.
515/// * `method`  — Inverse method.
516///
517/// # Returns
518///
519/// A `Vec<SourceEstimate>`, one per epoch.
520pub fn apply_inverse_epochs(
521    epochs: &ndarray::Array3<f64>,
522    inv: &InverseOperator,
523    lambda2: f64,
524    method: InverseMethod,
525) -> Result<Vec<SourceEstimate>> {
526    apply_inverse_epochs_full(epochs, inv, lambda2, method, PickOri::None, None)
527}
528
529/// Apply inverse to epochs with full options.
530pub fn apply_inverse_epochs_full(
531    epochs: &ndarray::Array3<f64>,
532    inv: &InverseOperator,
533    lambda2: f64,
534    method: InverseMethod,
535    pick_ori: PickOri,
536    eloreta_opts: Option<&EloretaOptions>,
537) -> Result<Vec<SourceEstimate>> {
538    let (n_epochs, _n_ch, _n_t) = epochs.dim();
539    let mut results = Vec::with_capacity(n_epochs);
540    for e in 0..n_epochs {
541        let epoch = epochs.slice(ndarray::s![e, .., ..]).to_owned();
542        let stc = apply_inverse_full(&epoch, inv, lambda2, method, pick_ori, eloreta_opts)?;
543        results.push(stc);
544    }
545    Ok(results)
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Gain matrix with entry `(i, j) = 1e-8 / sqrt((i - pos(j))² + 1)`.
    /// This is a distance-like falloff between channel `i` and the position of
    /// column `j`.
    fn falloff_gain(n_chan: usize, n_cols: usize, pos: impl Fn(usize) -> f64) -> Array2<f64> {
        Array2::from_shape_fn((n_chan, n_cols), |(i, j)| {
            let dist = ((i as f64 - pos(j)).powi(2) + 1.0).sqrt();
            1e-8 / dist
        })
    }

    /// Fixed-orientation forward model (every source reaches every channel)
    /// together with a white diagonal noise covariance.
    fn make_test_setup(n_chan: usize, n_src: usize) -> (ForwardOperator, NoiseCov) {
        let gain = falloff_gain(n_chan, n_src, |j| j as f64 * n_chan as f64 / n_src as f64);
        let fwd = ForwardOperator::new_fixed(gain);
        let cov = NoiseCov::diagonal(vec![1e-12; n_chan]);
        (fwd, cov)
    }

    /// Free-orientation inverse operator; source `s` owns gain columns `3s..3s+3`.
    fn make_free_inverse(n_chan: usize, n_src: usize) -> InverseOperator {
        let gain = falloff_gain(n_chan, n_src * 3, |j| {
            j as f64 / 3.0 * n_chan as f64 / n_src as f64
        });
        let fwd = ForwardOperator::new_free(gain);
        let cov = NoiseCov::diagonal(vec![1e-12; n_chan]);
        make_inverse_operator(&fwd, &cov, None).unwrap()
    }

    /// Run `method` on a constant 16-channel recording of a 50-source fixed
    /// model, and check that there is one finite output row per source.
    fn assert_fixed_output_finite(method: InverseMethod) {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();
        let data = Array2::from_elem((16, 5), 1e-6);
        let stc = apply_inverse(&data, &inv, 1.0 / 9.0, method).unwrap();
        assert_eq!(stc.data.nrows(), 50);
        assert!(stc.data.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_make_inverse_operator() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();
        assert_eq!(inv.n_sources, 50);
        // Thin SVD keeps min(n_chan, n_src) singular values.
        assert_eq!(inv.sing.len(), 16);
    }

    #[test]
    fn test_apply_inverse_mne() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        // Drive a single source with a constant amplitude and project it to sensors.
        let n_times = 10;
        let source_idx = 25;
        let source_signal = Array2::from_shape_fn((50, n_times), |(s, _)| {
            if s == source_idx {
                1e-9
            } else {
                0.0
            }
        });
        let data = fwd.gain.dot(&source_signal);

        let stc = apply_inverse(&data, &inv, 1.0 / 9.0, InverseMethod::MNE).unwrap();
        assert_eq!(stc.data.dim(), (50, n_times));

        // Strongest |value| at t = 0; the first index wins ties (strict `>`).
        let (peak_src, _) = stc
            .data
            .column(0)
            .iter()
            .map(|v| v.abs())
            .enumerate()
            .fold((0_usize, 0.0_f64), |(best_i, best_v), (i, v)| {
                if v > best_v {
                    (i, v)
                } else {
                    (best_i, best_v)
                }
            });
        // Should be within a few sources of the true location.
        assert!(
            (peak_src as i32 - source_idx as i32).unsigned_abs() <= 5,
            "Peak at {peak_src}, expected near {source_idx}"
        );
    }

    #[test]
    fn test_apply_inverse_dspm() {
        assert_fixed_output_finite(InverseMethod::DSPM);
    }

    #[test]
    fn test_apply_inverse_sloreta() {
        assert_fixed_output_finite(InverseMethod::SLORETA);
    }

    #[test]
    fn test_free_orientation() {
        let n_chan = 16;
        let n_src = 20;
        let inv = make_free_inverse(n_chan, n_src);

        let data = Array2::from_elem((n_chan, 5), 1e-6);
        let stc = apply_inverse(&data, &inv, 1.0 / 9.0, InverseMethod::DSPM).unwrap();
        // XYZ components are combined, leaving one row per source.
        assert_eq!(stc.data.nrows(), n_src);
        assert!(stc.data.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_apply_inverse_epochs() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        let epochs = ndarray::Array3::from_shape_fn((5, 16, 10), |(_, i, j)| {
            ((i * 10 + j) as f64).sin() * 1e-6
        });
        let stcs = apply_inverse_epochs(&epochs, &inv, 1.0 / 9.0, InverseMethod::DSPM).unwrap();
        assert_eq!(stcs.len(), 5);
        for stc in &stcs {
            assert_eq!(stc.data.dim(), (50, 10));
            assert!(stc.data.iter().all(|v| v.is_finite()));
        }
    }

    #[test]
    fn test_pick_ori_vector() {
        let n_chan = 16;
        let n_src = 20;
        let inv = make_free_inverse(n_chan, n_src);
        let data = Array2::from_elem((n_chan, 5), 1e-6);

        // Vector keeps all three components; Normal and the default keep one row per source.
        let cases = [
            (PickOri::Vector, n_src * 3),
            (PickOri::Normal, n_src),
            (PickOri::None, n_src),
        ];
        for &(pick_ori, expected_rows) in &cases {
            let stc = apply_inverse_full(
                &data, &inv, 1.0 / 9.0, InverseMethod::MNE, pick_ori, None,
            )
            .unwrap();
            assert_eq!(stc.data.nrows(), expected_rows);
        }
    }

    #[test]
    fn test_depth_weighting() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv_depth = make_inverse_operator(&fwd, &cov, Some(0.8)).unwrap();
        let inv_nodepth = make_inverse_operator(&fwd, &cov, None).unwrap();

        // Depth weighting must change the source covariance.
        let diff: f64 = (&inv_depth.source_cov - &inv_nodepth.source_cov)
            .mapv(f64::abs)
            .sum();
        assert!(diff > 1e-10, "Depth weighting should change source_cov");
    }
}