1use anyhow::{bail, Result};
15use ndarray::{Array1, Array2};
16
17use super::linalg;
18use super::{EloretaOptions, InverseOperator, SourceOrientation};
19
20pub fn compute_eloreta(
26 inv: &InverseOperator,
27 lambda2: f64,
28 opts: &EloretaOptions,
29) -> Result<(Array2<f64>, Array1<f64>)> {
30 if inv.eigen_leads_weighted {
31 bail!("eLORETA cannot be computed with weighted eigen leads");
32 }
33
34 let n_orient = match inv.orientation {
35 SourceOrientation::Fixed => 1,
36 SourceOrientation::Free => 3,
37 };
38 let n_src = inv.n_sources;
39 let n_nzero = inv.n_nzero;
40
41 let n_k = inv.sing.len();
46 let n_cols = n_src * n_orient;
47
48 let mut g = Array2::zeros((inv.eigen_fields.ncols(), n_cols));
51 for c in 0..n_cols {
52 for ch in 0..g.nrows() {
53 let mut val = 0.0;
54 for k in 0..n_k {
55 val += inv.eigen_fields[[k, ch]] * inv.sing[k] * inv.eigen_leads[[c, k]];
56 }
57 g[[ch, c]] = val;
58 }
59 }
60
61 for c in 0..n_cols {
63 let w = inv.source_cov[c].sqrt();
64 if w > 0.0 {
65 for r in 0..g.nrows() {
66 g[[r, c]] /= w;
67 }
68 }
69 }
70
71 let force_equal = opts.force_equal.unwrap_or(n_orient == 1);
72
73 let mut r_diag = Array1::ones(n_cols); for _iter in 0..opts.max_iter {
78 let mut g_r = g.clone();
80 for c in 0..n_cols {
81 let w = r_diag[c];
82 for r in 0..g_r.nrows() {
83 g_r[[r, c]] *= w;
84 }
85 }
86
87 let g_r_gt = g_r.dot(&g.t());
89
90 let trace = g_r_gt.diag().sum();
92 let norm = trace / n_nzero as f64;
93 let g_r_gt_normed = g_r_gt.mapv(|v| v / norm);
94 let _r_norm = norm;
95
96 let (evals, evecs) = linalg::eigh_sorted(&g_r_gt_normed)?;
98
99 let mut n_mat = Array2::zeros((g.nrows(), g.nrows()));
101 for k in 0..n_nzero {
102 if evals[k].abs() > 0.0 {
103 let inv_val = 1.0 / (evals[k] + lambda2);
104 for i in 0..n_mat.nrows() {
105 for j in 0..n_mat.ncols() {
106 n_mat[[i, j]] += inv_val * evecs[[i, k]] * evecs[[j, k]];
107 }
108 }
109 }
110 }
111
112 let r_diag_old = r_diag.clone();
114
115 if n_orient == 1 || force_equal {
116 for s in 0..n_src {
118 let mut val = 0.0;
119 for o in 0..n_orient {
120 let c = s * n_orient + o;
121 let ng = n_mat.dot(&g.column(c).to_owned());
123 val += g.column(c).dot(&ng);
124 }
125 let w = if val > 0.0 {
126 1.0 / (val / n_orient as f64).sqrt()
127 } else {
128 1.0
129 };
130 for o in 0..n_orient {
131 r_diag[s * n_orient + o] = w;
132 }
133 }
134 } else {
135 for c in 0..n_cols {
137 let ng = n_mat.dot(&g.column(c).to_owned());
138 let val = g.column(c).dot(&ng);
139 r_diag[c] = if val > 0.0 { 1.0 / val.sqrt() } else { 1.0 };
140 }
141 }
142
143 let r_trace: f64 = {
145 let mut gr = g.clone();
146 for c in 0..n_cols {
147 for r in 0..gr.nrows() {
148 gr[[r, c]] *= r_diag[c];
149 }
150 }
151 let grgt = gr.dot(&g.t());
152 grgt.diag().sum() / n_nzero as f64
153 };
154 if r_trace > 0.0 {
155 r_diag.mapv_inplace(|v| v / r_trace.sqrt());
156 }
157
158 let delta_num: f64 = r_diag
160 .iter()
161 .zip(r_diag_old.iter())
162 .map(|(&a, &b)| (a - b).powi(2))
163 .sum::<f64>()
164 .sqrt();
165 let delta_den: f64 = r_diag_old.iter().map(|v| v.powi(2)).sum::<f64>().sqrt();
166 let delta = if delta_den > 0.0 {
167 delta_num / delta_den
168 } else {
169 0.0
170 };
171
172 if delta < opts.eps {
173 break;
174 }
175 }
176
177 let mut g_weighted = g.clone();
180 for c in 0..n_cols {
181 for r in 0..g_weighted.nrows() {
182 g_weighted[[r, c]] *= r_diag[c];
183 }
184 }
185
186 let (u, sing, vt) = linalg::svd_thin(&g_weighted)?;
188
189 let mut reginv = Array1::zeros(sing.len());
191 for k in 0..sing.len().min(n_nzero) {
192 let s = sing[k];
193 if s > 0.0 {
194 reginv[k] = s / (s * s + lambda2);
195 }
196 }
197
198 let ut = u.t().to_owned();
200 let trans = {
201 let ut_w = ut.dot(&inv.whitener);
202 let mut t = Array2::zeros(ut_w.dim());
203 for i in 0..t.nrows() {
204 for j in 0..t.ncols() {
205 t[[i, j]] = ut_w[[i, j]] * reginv[i];
206 }
207 }
208 t
209 };
210
211 let v = vt.t().to_owned();
213 let mut kernel = Array2::zeros((n_cols, trans.ncols()));
214 let v_trans = v.dot(&trans);
215 for c in 0..n_cols {
216 let w = r_diag[c];
217 for j in 0..kernel.ncols() {
218 kernel[[c, j]] = w * v_trans[[c, j]];
219 }
220 }
221
222 Ok((kernel, reginv))
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        inverse::{apply_inverse_with_options, make_inverse_operator},
        ForwardOperator, InverseMethod, NoiseCov,
    };
    use ndarray::Array2;

    /// Smoke test: a small fixed-orientation problem must yield one finite
    /// source time course per source when eLORETA is applied.
    #[test]
    fn test_eloreta_basic() {
        const N_CHAN: usize = 16;
        const N_SRC: usize = 30;
        const N_TIMES: usize = 5;

        // Synthetic gain: each source peaks at the channel its index maps onto and
        // falls off with distance along the channel axis.
        let gain = Array2::from_shape_fn((N_CHAN, N_SRC), |(chan, src)| {
            let offset = chan as f64 - src as f64 * N_CHAN as f64 / N_SRC as f64;
            1e-8 / (offset.powi(2) + 1.0).sqrt()
        });

        let forward = ForwardOperator::new_fixed(gain);
        let noise_cov = NoiseCov::diagonal(vec![1e-12; N_CHAN]);
        let inverse = make_inverse_operator(&forward, &noise_cov, None).unwrap();

        let measurements = Array2::from_elem((N_CHAN, N_TIMES), 1e-6);
        let options = EloretaOptions {
            max_iter: 10,
            eps: 1e-4,
            force_equal: Some(true),
        };

        let lambda2 = 1.0 / 9.0;
        let estimate = apply_inverse_with_options(
            &measurements,
            &inverse,
            lambda2,
            InverseMethod::ELORETA,
            Some(&options),
        )
        .unwrap();

        assert_eq!(estimate.data.nrows(), N_SRC);
        assert!(estimate.data.iter().all(|value: &f64| value.is_finite()));
    }
}