1use anyhow::{bail, Result};
7use ndarray::{Array1, Array2};
8
9use super::eloreta::compute_eloreta;
10use super::linalg;
11use super::{
12 EloretaOptions, ForwardOperator, InverseMethod, InverseOperator, NoiseCov, PickOri,
13 SourceEstimate, SourceOrientation,
14};
15
16fn compute_depth_prior(
22 whitened_gain: &Array2<f64>,
23 n_sources: usize,
24 n_orient: usize,
25 exp: f64,
26) -> Array1<f64> {
27 let n_cols = n_sources * n_orient;
28 let mut col_norms = Array1::zeros(n_sources);
29
30 for s in 0..n_sources {
31 let mut norm_sq = 0.0;
32 for o in 0..n_orient {
33 let col_idx = s * n_orient + o;
34 if col_idx < n_cols {
35 for r in 0..whitened_gain.nrows() {
36 norm_sq += whitened_gain[[r, col_idx]].powi(2);
37 }
38 }
39 }
40 col_norms[s] = norm_sq.sqrt();
41 }
42
43 let max_norm = col_norms.iter().copied().fold(0.0_f64, f64::max);
44 if max_norm <= 0.0 {
45 return Array1::ones(n_cols);
46 }
47
48 let mut prior = Array1::zeros(n_cols);
49 for s in 0..n_sources {
50 let w = (col_norms[s] / max_norm).powf(exp);
51 for o in 0..n_orient {
52 prior[s * n_orient + o] = w;
53 }
54 }
55 prior
56}
57
58pub fn make_inverse_operator(
74 fwd: &ForwardOperator,
75 noise_cov: &NoiseCov,
76 depth_exp: Option<f64>,
77) -> Result<InverseOperator> {
78 let n_chan = fwd.gain.nrows();
79 let n_orient = fwd.n_orient();
80 let n_cols = fwd.n_sources * n_orient;
81
82 if fwd.gain.ncols() != n_cols {
83 bail!(
84 "Gain matrix has {} columns but expected {} (n_sources={} × n_orient={})",
85 fwd.gain.ncols(),
86 n_cols,
87 fwd.n_sources,
88 n_orient,
89 );
90 }
91 if noise_cov.n_channels() != n_chan {
92 bail!(
93 "Noise covariance has {} channels but gain has {} rows",
94 noise_cov.n_channels(),
95 n_chan,
96 );
97 }
98
99 let cov_full = noise_cov.to_full();
101 let (whitener, n_nzero) = linalg::compute_whitener(&cov_full)?;
102
103 let gain_w = whitener.dot(&fwd.gain);
105
106 let exp = depth_exp.or(fwd.depth_exp);
108 let source_std = if let Some(e) = exp {
109 let prior = compute_depth_prior(&gain_w, fwd.n_sources, n_orient, e);
110 let mut std = Array1::zeros(n_cols);
111 for i in 0..n_cols {
112 std[i] = prior[i].sqrt();
113 }
114 std
115 } else {
116 Array1::ones(n_cols)
117 };
118
119 let mut gain_ws = gain_w;
121 for j in 0..n_cols {
122 for i in 0..gain_ws.nrows() {
123 gain_ws[[i, j]] *= source_std[j];
124 }
125 }
126
127 let trace_grgt = gain_ws.iter().map(|v| v * v).sum::<f64>();
129 let scale = (n_nzero as f64 / trace_grgt).sqrt();
130 gain_ws.mapv_inplace(|v| v * scale);
131 let source_std = source_std.mapv(|v| v * scale);
132
133 let (u, sing, vt) = linalg::svd_thin(&gain_ws)?;
135
136 let eigen_fields = u.t().to_owned();
138 let eigen_leads = vt.t().to_owned();
140
141 let source_cov = source_std.mapv(|v| v * v);
142
143 Ok(InverseOperator {
144 eigen_fields,
145 sing,
146 eigen_leads,
147 source_cov,
148 eigen_leads_weighted: false,
149 n_sources: fwd.n_sources,
150 orientation: fwd.orientation,
151 source_nn: fwd.source_nn.clone(),
152 whitener,
153 n_nzero,
154 noise_cov: noise_cov.clone(),
155 })
156}
157
158pub struct PreparedInverse {
160 pub reginv: Array1<f64>,
162 pub noisenorm: Option<Array1<f64>>,
164 pub kernel: Array2<f64>,
166}
167
168pub fn prepare_inverse(
179 inv: &InverseOperator,
180 lambda2: f64,
181 method: InverseMethod,
182 eloreta_opts: Option<&EloretaOptions>,
183) -> Result<PreparedInverse> {
184 let n_orient = match inv.orientation {
185 SourceOrientation::Fixed => 1,
186 SourceOrientation::Free => 3,
187 };
188
189 if method == InverseMethod::ELORETA {
190 return prepare_eloreta(inv, lambda2, eloreta_opts);
191 }
192
193 let reginv = compute_reginv(&inv.sing, lambda2, inv.n_nzero);
195
196 let noisenorm = match method {
198 InverseMethod::MNE => None,
199 InverseMethod::DSPM => {
200 let noise_weight = reginv.clone();
201 Some(compute_noise_norm(inv, &noise_weight, n_orient))
202 }
203 InverseMethod::SLORETA => {
204 let noise_weight = Array1::from_iter(reginv.iter().zip(inv.sing.iter()).map(
205 |(&ri, &si)| ri * (1.0 + si * si / lambda2).sqrt(),
206 ));
207 Some(compute_noise_norm(inv, &noise_weight, n_orient))
208 }
209 InverseMethod::ELORETA => unreachable!(),
210 };
211
212 let n_k = inv.sing.len();
223 let n_chan = inv.whitener.ncols();
224
225 let trans = inv.eigen_fields.dot(&inv.whitener);
227 let mut trans_scaled = Array2::zeros((n_k, n_chan));
229 for i in 0..n_k {
230 for j in 0..n_chan {
231 trans_scaled[[i, j]] = trans[[i, j]] * reginv[i];
232 }
233 }
234
235 let mut kernel = inv.eigen_leads.dot(&trans_scaled);
237
238 if !inv.eigen_leads_weighted {
240 for i in 0..kernel.nrows() {
241 let w = inv.source_cov[i].sqrt();
242 for j in 0..kernel.ncols() {
243 kernel[[i, j]] *= w;
244 }
245 }
246 }
247
248 Ok(PreparedInverse {
249 reginv,
250 noisenorm,
251 kernel,
252 })
253}
254
255fn prepare_eloreta(
257 inv: &InverseOperator,
258 lambda2: f64,
259 opts: Option<&EloretaOptions>,
260) -> Result<PreparedInverse> {
261 let default_opts = EloretaOptions::default();
262 let opts = opts.unwrap_or(&default_opts);
263
264 let (kernel, reginv) = compute_eloreta(inv, lambda2, opts)?;
265
266 Ok(PreparedInverse {
267 reginv,
268 noisenorm: None, kernel,
270 })
271}
272
273fn compute_reginv(sing: &Array1<f64>, lambda2: f64, n_nzero: usize) -> Array1<f64> {
275 let n = sing.len();
276 let mut reginv = Array1::zeros(n);
277 for k in 0..n.min(n_nzero) {
278 let s = sing[k];
279 if s > 0.0 {
280 reginv[k] = s / (s * s + lambda2);
281 }
282 }
283 reginv
284}
285
286fn compute_noise_norm(
290 inv: &InverseOperator,
291 noise_weight: &Array1<f64>,
292 n_orient: usize,
293) -> Array1<f64> {
294 let n_rows = inv.eigen_leads.nrows();
295 let n_k = noise_weight.len();
296
297 let mut raw_norm = Array1::zeros(n_rows);
298 for k in 0..n_rows {
299 let mut sq_sum = 0.0;
300 for j in 0..n_k {
301 let lead = if inv.eigen_leads_weighted {
302 inv.eigen_leads[[k, j]]
303 } else {
304 inv.source_cov[k].sqrt() * inv.eigen_leads[[k, j]]
305 };
306 let val = lead * noise_weight[j];
307 sq_sum += val * val;
308 }
309 raw_norm[k] = sq_sum.sqrt();
310 }
311
312 if n_orient == 3 {
314 let n_src = n_rows / 3;
315 let mut combined = Array1::zeros(n_src);
316 for s in 0..n_src {
317 let mut sum_sq = 0.0;
318 for o in 0..3 {
319 sum_sq += raw_norm[s * 3 + o].powi(2);
320 }
321 combined[s] = sum_sq.sqrt();
322 }
323 combined.mapv(|v| if v.abs() > 0.0 { 1.0 / v } else { 0.0 })
324 } else {
325 raw_norm.mapv(|v| if v.abs() > 0.0 { 1.0 / v } else { 0.0 })
326 }
327}
328
329fn combine_xyz(sol: &Array2<f64>) -> Array2<f64> {
331 let (n_rows, n_times) = sol.dim();
332 assert!(n_rows % 3 == 0, "combine_xyz: rows must be divisible by 3");
333 let n_src = n_rows / 3;
334 let mut out = Array2::zeros((n_src, n_times));
335 for s in 0..n_src {
336 for t in 0..n_times {
337 let x = sol[[s * 3, t]];
338 let y = sol[[s * 3 + 1, t]];
339 let z = sol[[s * 3 + 2, t]];
340 out[[s, t]] = (x * x + y * y + z * z).sqrt();
341 }
342 }
343 out
344}
345
346pub fn apply_inverse(
381 data: &Array2<f64>,
382 inv: &InverseOperator,
383 lambda2: f64,
384 method: InverseMethod,
385) -> Result<SourceEstimate> {
386 apply_inverse_with_options(data, inv, lambda2, method, None)
387}
388
389pub fn apply_inverse_with_options(
391 data: &Array2<f64>,
392 inv: &InverseOperator,
393 lambda2: f64,
394 method: InverseMethod,
395 eloreta_opts: Option<&EloretaOptions>,
396) -> Result<SourceEstimate> {
397 apply_inverse_full(data, inv, lambda2, method, PickOri::None, eloreta_opts)
398}
399
400pub fn apply_inverse_full(
411 data: &Array2<f64>,
412 inv: &InverseOperator,
413 lambda2: f64,
414 method: InverseMethod,
415 pick_ori: PickOri,
416 eloreta_opts: Option<&EloretaOptions>,
417) -> Result<SourceEstimate> {
418 let n_chan = data.nrows();
419 if n_chan != inv.whitener.ncols() {
420 bail!(
421 "Data has {} channels but inverse expects {}",
422 n_chan,
423 inv.whitener.ncols()
424 );
425 }
426
427 let n_orient = match inv.orientation {
428 SourceOrientation::Fixed => 1,
429 SourceOrientation::Free => 3,
430 };
431
432 if pick_ori == PickOri::Normal && n_orient != 3 {
433 bail!("pick_ori=Normal requires free-orientation inverse");
434 }
435
436 let prepared = prepare_inverse(inv, lambda2, method, eloreta_opts)?;
437
438 let mut sol = prepared.kernel.dot(data);
440
441 let is_free = n_orient == 3;
442
443 match pick_ori {
444 PickOri::None => {
445 if is_free {
447 sol = combine_xyz(&sol);
448 }
449 apply_noisenorm(&mut sol, &prepared.noisenorm);
451 }
452 PickOri::Normal => {
453 let n_src = inv.n_sources;
455 let n_times = sol.ncols();
456 let mut normal_sol = Array2::zeros((n_src, n_times));
457 for s in 0..n_src {
458 for t in 0..n_times {
459 normal_sol[[s, t]] = sol[[s * 3 + 2, t]];
460 }
461 }
462 sol = normal_sol;
463 apply_noisenorm(&mut sol, &prepared.noisenorm);
465 }
466 PickOri::Vector => {
467 if let Some(ref nn) = prepared.noisenorm {
469 if is_free {
470 for s in 0..inv.n_sources {
472 let norm = nn[s];
473 for o in 0..3 {
474 for t in 0..sol.ncols() {
475 sol[[s * 3 + o, t]] *= norm;
476 }
477 }
478 }
479 } else {
480 apply_noisenorm(&mut sol, &prepared.noisenorm);
481 }
482 }
483 }
484 }
485
486 Ok(SourceEstimate {
487 data: sol,
488 n_sources: inv.n_sources,
489 orientation: inv.orientation,
490 })
491}
492
493fn apply_noisenorm(sol: &mut Array2<f64>, noisenorm: &Option<Array1<f64>>) {
495 if let Some(ref nn) = noisenorm {
496 let n_src_out = sol.nrows();
497 for s in 0..n_src_out {
498 let norm = nn[s];
499 for t in 0..sol.ncols() {
500 sol[[s, t]] *= norm;
501 }
502 }
503 }
504}
505
506pub fn apply_inverse_epochs(
521 epochs: &ndarray::Array3<f64>,
522 inv: &InverseOperator,
523 lambda2: f64,
524 method: InverseMethod,
525) -> Result<Vec<SourceEstimate>> {
526 apply_inverse_epochs_full(epochs, inv, lambda2, method, PickOri::None, None)
527}
528
529pub fn apply_inverse_epochs_full(
531 epochs: &ndarray::Array3<f64>,
532 inv: &InverseOperator,
533 lambda2: f64,
534 method: InverseMethod,
535 pick_ori: PickOri,
536 eloreta_opts: Option<&EloretaOptions>,
537) -> Result<Vec<SourceEstimate>> {
538 let (n_epochs, _n_ch, _n_t) = epochs.dim();
539 let mut results = Vec::with_capacity(n_epochs);
540 for e in 0..n_epochs {
541 let epoch = epochs.slice(ndarray::s![e, .., ..]).to_owned();
542 let stc = apply_inverse_full(&epoch, inv, lambda2, method, pick_ori, eloreta_opts)?;
543 results.push(stc);
544 }
545 Ok(results)
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Regularisation parameter shared by all tests.
    const LAMBDA2: f64 = 1.0 / 9.0;

    /// Synthetic gain matrix: each column's sensitivity falls off with the
    /// distance between the channel index and the source's position mapped
    /// onto the channel axis. Columns of one source share nearly the same
    /// position when `n_orient > 1`.
    fn synthetic_gain(n_chan: usize, n_src: usize, n_orient: usize) -> Array2<f64> {
        Array2::from_shape_fn((n_chan, n_src * n_orient), |(i, j)| {
            let position = j as f64 / n_orient as f64 * n_chan as f64 / n_src as f64;
            let dist = ((i as f64 - position).powi(2) + 1.0).sqrt();
            1e-8 / dist
        })
    }

    /// Diagonal noise covariance with identical variance on every channel.
    fn diagonal_noise_cov(n_chan: usize) -> NoiseCov {
        NoiseCov::diagonal(vec![1e-12; n_chan])
    }

    /// Fixed-orientation forward model plus noise covariance.
    fn make_test_setup(n_chan: usize, n_src: usize) -> (ForwardOperator, NoiseCov) {
        let fwd = ForwardOperator::new_fixed(synthetic_gain(n_chan, n_src, 1));
        (fwd, diagonal_noise_cov(n_chan))
    }

    /// Free-orientation (three columns per source) forward model plus noise
    /// covariance.
    fn make_free_test_setup(n_chan: usize, n_src: usize) -> (ForwardOperator, NoiseCov) {
        let fwd = ForwardOperator::new_free(synthetic_gain(n_chan, n_src, 3));
        (fwd, diagonal_noise_cov(n_chan))
    }

    #[test]
    fn test_make_inverse_operator() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();
        assert_eq!(inv.n_sources, 50);
        // Thin SVD of a 16 × 50 matrix yields 16 singular values.
        assert_eq!(inv.sing.len(), 16);
    }

    #[test]
    fn test_apply_inverse_mne() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        // Simulate a single constant source and project it to the sensors.
        let n_times = 10;
        let source_idx = 25;
        let mut source_signal = Array2::zeros((50, n_times));
        for t in 0..n_times {
            source_signal[[source_idx, t]] = 1e-9;
        }
        let data = fwd.gain.dot(&source_signal);

        let stc = apply_inverse(&data, &inv, LAMBDA2, InverseMethod::MNE).unwrap();
        assert_eq!(stc.data.dim(), (50, n_times));

        // The strongest estimate should be close to the simulated source.
        let mut peak_src = 0;
        let mut peak_val = 0.0_f64;
        for s in 0..50 {
            let val = stc.data[[s, 0]].abs();
            if val > peak_val {
                peak_val = val;
                peak_src = s;
            }
        }
        assert!(
            (peak_src as i32 - source_idx as i32).unsigned_abs() <= 5,
            "Peak at {peak_src}, expected near {source_idx}"
        );
    }

    #[test]
    fn test_apply_inverse_dspm() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();
        let data = Array2::from_elem((16, 5), 1e-6);
        let stc = apply_inverse(&data, &inv, LAMBDA2, InverseMethod::DSPM).unwrap();
        assert_eq!(stc.data.nrows(), 50);
        assert!(stc.data.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_apply_inverse_sloreta() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();
        let data = Array2::from_elem((16, 5), 1e-6);
        let stc = apply_inverse(&data, &inv, LAMBDA2, InverseMethod::SLORETA).unwrap();
        assert_eq!(stc.data.nrows(), 50);
        assert!(stc.data.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_free_orientation() {
        let n_chan = 16;
        let n_src = 20;
        let (fwd, cov) = make_free_test_setup(n_chan, n_src);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        let data = Array2::from_elem((n_chan, 5), 1e-6);
        let stc = apply_inverse(&data, &inv, LAMBDA2, InverseMethod::DSPM).unwrap();
        // Components are combined, so there is one row per source.
        assert_eq!(stc.data.nrows(), n_src);
        assert!(stc.data.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_apply_inverse_epochs() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        let epochs = ndarray::Array3::from_shape_fn((5, 16, 10), |(_, i, j)| {
            ((i * 10 + j) as f64).sin() * 1e-6
        });
        let stcs = apply_inverse_epochs(&epochs, &inv, LAMBDA2, InverseMethod::DSPM).unwrap();
        assert_eq!(stcs.len(), 5);
        for stc in &stcs {
            assert_eq!(stc.data.dim(), (50, 10));
            assert!(stc.data.iter().all(|v| v.is_finite()));
        }
    }

    #[test]
    fn test_pick_ori_vector() {
        let n_chan = 16;
        let n_src = 20;
        let (fwd, cov) = make_free_test_setup(n_chan, n_src);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        let data = Array2::from_elem((n_chan, 5), 1e-6);

        // Vector output keeps all three components of every source.
        let stc_vec = apply_inverse_full(
            &data,
            &inv,
            LAMBDA2,
            InverseMethod::MNE,
            PickOri::Vector,
            None,
        )
        .unwrap();
        assert_eq!(stc_vec.data.nrows(), n_src * 3);

        // Normal output keeps a single component per source.
        let stc_norm = apply_inverse_full(
            &data,
            &inv,
            LAMBDA2,
            InverseMethod::MNE,
            PickOri::Normal,
            None,
        )
        .unwrap();
        assert_eq!(stc_norm.data.nrows(), n_src);

        // Default output combines the components into one row per source.
        let stc_comb = apply_inverse_full(
            &data,
            &inv,
            LAMBDA2,
            InverseMethod::MNE,
            PickOri::None,
            None,
        )
        .unwrap();
        assert_eq!(stc_comb.data.nrows(), n_src);
    }

    #[test]
    fn test_depth_weighting() {
        let (fwd, cov) = make_test_setup(16, 50);
        let inv_depth = make_inverse_operator(&fwd, &cov, Some(0.8)).unwrap();
        let inv_nodepth = make_inverse_operator(&fwd, &cov, None).unwrap();

        let diff: f64 = (&inv_depth.source_cov - &inv_nodepth.source_cov)
            .mapv(f64::abs)
            .sum();
        assert!(diff > 1e-10, "Depth weighting should change source_cov");
    }
}