Skip to main content

exg_source/
eloreta.rs

1//! eLORETA (exact Low Resolution Electromagnetic Tomography) solver.
2//!
3//! Ported from MNE-Python's `mne.minimum_norm._eloreta._compute_eloreta`.
4//!
5//! eLORETA iteratively computes optimal source weights that yield exact
6//! localisation (zero localisation bias) for single dipole sources.
7//!
8//! ## References
9//!
10//! Pascual-Marqui, R. D. (2011). Discrete, 3D distributed, linear imaging
11//! methods of electric neuronal activity. Part 1: exact, zero error
12//! localization. *arXiv:0710.3341*.
13
14use anyhow::{bail, Result};
15use ndarray::{Array1, Array2};
16
17use super::linalg;
18use super::{EloretaOptions, InverseOperator, SourceOrientation};
19
20/// Compute the eLORETA inverse kernel.
21///
22/// Returns `(kernel, reginv)` where:
23/// - `kernel` has shape `[n_sources × n_orient, n_channels]`
24/// - `reginv` has length `n_nzero`
25pub fn compute_eloreta(
26    inv: &InverseOperator,
27    lambda2: f64,
28    opts: &EloretaOptions,
29) -> Result<(Array2<f64>, Array1<f64>)> {
30    if inv.eigen_leads_weighted {
31        bail!("eLORETA cannot be computed with weighted eigen leads");
32    }
33
34    let n_orient = match inv.orientation {
35        SourceOrientation::Fixed => 1,
36        SourceOrientation::Free => 3,
37    };
38    let n_src = inv.n_sources;
39    let n_nzero = inv.n_nzero;
40
41    // Reassemble the gain matrix: G = U @ diag(s) @ V^T (in whitened space)
42    // eigen_fields = U^T [k, n_chan_w], sing = s [k], eigen_leads = V [n_cols, k]
43    // G_whitened = (eigen_fields^T @ diag(s) @ eigen_leads^T) ... but this is
44    // whitened gain. Let's reconstruct: G_w = U @ S @ V^T
45    let n_k = inv.sing.len();
46    let n_cols = n_src * n_orient;
47
48    // G = eigen_fields^T * diag(sing) * eigen_leads^T  [n_chan_w, n_cols]
49    // But eigen_fields = U^T [k, n_chan_w], so U = eigen_fields^T [n_chan_w, k]
50    let mut g = Array2::zeros((inv.eigen_fields.ncols(), n_cols));
51    for c in 0..n_cols {
52        for ch in 0..g.nrows() {
53            let mut val = 0.0;
54            for k in 0..n_k {
55                val += inv.eigen_fields[[k, ch]] * inv.sing[k] * inv.eigen_leads[[c, k]];
56            }
57            g[[ch, c]] = val;
58        }
59    }
60
61    // Remove source_cov weighting to get the "raw" whitened gain
62    for c in 0..n_cols {
63        let w = inv.source_cov[c].sqrt();
64        if w > 0.0 {
65            for r in 0..g.nrows() {
66                g[[r, c]] /= w;
67            }
68        }
69    }
70
71    let force_equal = opts.force_equal.unwrap_or(n_orient == 1);
72
73    // Initialise weights R
74    let mut r_diag = Array1::ones(n_cols); // used when force_equal or n_orient==1
75
76    // Main iteration
77    for _iter in 0..opts.max_iter {
78        // Apply R to G: G_R = G @ diag(R) for diagonal R
79        let mut g_r = g.clone();
80        for c in 0..n_cols {
81            let w = r_diag[c];
82            for r in 0..g_r.nrows() {
83                g_r[[r, c]] *= w;
84            }
85        }
86
87        // G_R_Gt = G_R @ G^T = G @ diag(R) @ G^T
88        let g_r_gt = g_r.dot(&g.t());
89
90        // Normalise so trace = n_nzero
91        let trace = g_r_gt.diag().sum();
92        let norm = trace / n_nzero as f64;
93        let g_r_gt_normed = g_r_gt.mapv(|v| v / norm);
94        let _r_norm = norm;
95
96        // Eigendecompose G_R_Gt
97        let (evals, evecs) = linalg::eigh_sorted(&g_r_gt_normed)?;
98
99        // Compute N = (G_R_Gt + lambda2 I)^{-1} using eigendecomposition
100        let mut n_mat = Array2::zeros((g.nrows(), g.nrows()));
101        for k in 0..n_nzero {
102            if evals[k].abs() > 0.0 {
103                let inv_val = 1.0 / (evals[k] + lambda2);
104                for i in 0..n_mat.nrows() {
105                    for j in 0..n_mat.ncols() {
106                        n_mat[[i, j]] += inv_val * evecs[[i, k]] * evecs[[j, k]];
107                    }
108                }
109            }
110        }
111
112        // Update weights
113        let r_diag_old = r_diag.clone();
114
115        if n_orient == 1 || force_equal {
116            // R_k = 1 / sqrt(G_k^T @ N @ G_k)
117            for s in 0..n_src {
118                let mut val = 0.0;
119                for o in 0..n_orient {
120                    let c = s * n_orient + o;
121                    // G_k^T @ N @ G_k for this column
122                    let ng = n_mat.dot(&g.column(c).to_owned());
123                    val += g.column(c).dot(&ng);
124                }
125                let w = if val > 0.0 {
126                    1.0 / (val / n_orient as f64).sqrt()
127                } else {
128                    1.0
129                };
130                for o in 0..n_orient {
131                    r_diag[s * n_orient + o] = w;
132                }
133            }
134        } else {
135            // Free orientation, not force_equal: use per-component weights
136            for c in 0..n_cols {
137                let ng = n_mat.dot(&g.column(c).to_owned());
138                let val = g.column(c).dot(&ng);
139                r_diag[c] = if val > 0.0 { 1.0 / val.sqrt() } else { 1.0 };
140            }
141        }
142
143        // Normalise R to keep things stable
144        let r_trace: f64 = {
145            let mut gr = g.clone();
146            for c in 0..n_cols {
147                for r in 0..gr.nrows() {
148                    gr[[r, c]] *= r_diag[c];
149                }
150            }
151            let grgt = gr.dot(&g.t());
152            grgt.diag().sum() / n_nzero as f64
153        };
154        if r_trace > 0.0 {
155            r_diag.mapv_inplace(|v| v / r_trace.sqrt());
156        }
157
158        // Check convergence
159        let delta_num: f64 = r_diag
160            .iter()
161            .zip(r_diag_old.iter())
162            .map(|(&a, &b)| (a - b).powi(2))
163            .sum::<f64>()
164            .sqrt();
165        let delta_den: f64 = r_diag_old.iter().map(|v| v.powi(2)).sum::<f64>().sqrt();
166        let delta = if delta_den > 0.0 {
167            delta_num / delta_den
168        } else {
169            0.0
170        };
171
172        if delta < opts.eps {
173            break;
174        }
175    }
176
177    // Build final kernel with eLORETA weights
178    // G_weighted = G @ diag(R)
179    let mut g_weighted = g.clone();
180    for c in 0..n_cols {
181        for r in 0..g_weighted.nrows() {
182            g_weighted[[r, c]] *= r_diag[c];
183        }
184    }
185
186    // SVD of weighted gain
187    let (u, sing, vt) = linalg::svd_thin(&g_weighted)?;
188
189    // Compute reginv
190    let mut reginv = Array1::zeros(sing.len());
191    for k in 0..sing.len().min(n_nzero) {
192        let s = sing[k];
193        if s > 0.0 {
194            reginv[k] = s / (s * s + lambda2);
195        }
196    }
197
198    // trans = diag(reginv) @ U^T @ whitener  [k, n_chan]
199    let ut = u.t().to_owned();
200    let trans = {
201        let ut_w = ut.dot(&inv.whitener);
202        let mut t = Array2::zeros(ut_w.dim());
203        for i in 0..t.nrows() {
204            for j in 0..t.ncols() {
205                t[[i, j]] = ut_w[[i, j]] * reginv[i];
206            }
207        }
208        t
209    };
210
211    // kernel = diag(R) @ V @ trans  =  diag(R) @ Vt^T @ trans
212    let v = vt.t().to_owned();
213    let mut kernel = Array2::zeros((n_cols, trans.ncols()));
214    let v_trans = v.dot(&trans);
215    for c in 0..n_cols {
216        let w = r_diag[c];
217        for j in 0..kernel.ncols() {
218            kernel[[c, j]] = w * v_trans[[c, j]];
219        }
220    }
221
222    Ok((kernel, reginv))
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        inverse::{apply_inverse_with_options, make_inverse_operator},
        ForwardOperator, InverseMethod, NoiseCov,
    };
    use ndarray::Array2;

    /// Smooth synthetic fixed-orientation gain.
    ///
    /// Sensor `i` sees source `j` with amplitude `1e-8 / sqrt(d² + 1)`, where
    /// `d` is their distance after scaling source indices onto the sensor axis.
    fn synthetic_gain(n_chan: usize, n_src: usize) -> Array2<f64> {
        Array2::from_shape_fn((n_chan, n_src), |(i, j)| {
            let dist = ((i as f64 - j as f64 * n_chan as f64 / n_src as f64).powi(2) + 1.0)
                .sqrt();
            1e-8 / dist
        })
    }

    /// Short, equal-weight eLORETA options shared by the tests below.
    fn default_opts() -> EloretaOptions {
        EloretaOptions {
            max_iter: 10,
            eps: 1e-4,
            force_equal: Some(true),
        }
    }

    #[test]
    fn test_eloreta_basic() {
        let (n_chan, n_src) = (16, 30);
        let fwd = ForwardOperator::new_fixed(synthetic_gain(n_chan, n_src));
        let cov = NoiseCov::diagonal(vec![1e-12; n_chan]);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        let data = Array2::from_elem((n_chan, 5), 1e-6);
        let opts = default_opts();
        let stc = apply_inverse_with_options(
            &data, &inv, 1.0 / 9.0, InverseMethod::ELORETA, Some(&opts),
        )
        .unwrap();
        assert_eq!(stc.data.nrows(), n_src);
        assert!(stc.data.iter().all(|v: &f64| v.is_finite()));
    }

    #[test]
    fn test_eloreta_kernel_shape_and_reginv() {
        let (n_chan, n_src) = (16, 30);
        let fwd = ForwardOperator::new_fixed(synthetic_gain(n_chan, n_src));
        let cov = NoiseCov::diagonal(vec![1e-12; n_chan]);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        let (kernel, reginv) = compute_eloreta(&inv, 1.0 / 9.0, &default_opts()).unwrap();
        assert_eq!(kernel.nrows(), n_src);
        assert_eq!(kernel.ncols(), inv.whitener.ncols());
        assert!(kernel.iter().all(|v| v.is_finite()));
        assert!(reginv.iter().all(|v| v.is_finite() && *v >= 0.0));
        // Only the first `n_nzero` regularised singular values may be non-zero.
        assert!(reginv.iter().skip(inv.n_nzero).all(|&v| v == 0.0));
    }

    #[test]
    fn test_eloreta_rejects_weighted_eigen_leads() {
        let n_chan = 8;
        let fwd = ForwardOperator::new_fixed(synthetic_gain(n_chan, 12));
        let cov = NoiseCov::diagonal(vec![1e-12; n_chan]);
        let mut inv = make_inverse_operator(&fwd, &cov, None).unwrap();
        inv.eigen_leads_weighted = true;
        assert!(compute_eloreta(&inv, 1.0 / 9.0, &default_opts()).is_err());
    }
}