Skip to main content

exg_source/
forward.rs

1//! Spherical forward model (EEG gain matrix computation).
2//!
3//! Computes the lead-field (gain) matrix that maps dipole source activations
4//! to EEG electrode potentials using a multi-shell spherical head model.
5//!
6//! ## Model
7//!
8//! The default 3-shell model (Berg & Scherg, 1994) uses:
9//!
10//! | Shell | Radius (m) | Conductivity (S/m) |
11//! |-------|------------|-------------------|
12//! | Brain | 0.067      | 0.33              |
13//! | Skull | 0.070      | 0.0042            |
14//! | Scalp | 0.075      | 0.33              |
15//!
16//! The Berg & Scherg approximation replaces the exact series expansion with
17//! a small number of fitted dipoles (typically 3), making the computation
18//! fast while retaining good accuracy.
19//!
20//! ## Example
21//!
22//! ```
23//! use exg_source::forward::{make_sphere_forward, SphereModel};
24//! use exg_source::source_space::ico_source_space;
25//! use ndarray::Array2;
26//!
27//! // Electrode positions (simplified, 4 electrodes on the scalp)
28//! let elec = Array2::from_shape_vec((4, 3), vec![
29//!     0.07, 0.0, 0.04,
30//!    -0.07, 0.0, 0.04,
31//!     0.0, 0.07, 0.04,
32//!     0.0,-0.07, 0.04,
33//! ]).unwrap();
34//!
35//! // Source space
36//! let (src_pos, src_nn) = ico_source_space(2, 0.06, [0.0, 0.0, 0.04]);
37//!
38//! // Build forward model
39//! let sphere = SphereModel::default();
40//! let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &sphere);
41//! assert_eq!(fwd.gain.nrows(), 4);
42//! assert_eq!(fwd.n_sources, src_pos.nrows());
43//! ```
44//!
45//! ## References
46//!
47//! - Berg, P., & Scherg, M. (1994). A fast method for forward computation of
48//!   multiple-shell spherical head models. *Electroencephalography and Clinical
49//!   Neurophysiology*, 90(1), 58-64.
50//! - de Munck, J. C. (1988). The potential distribution in a layered
51//!   anisotropic spheroidal volume conductor. *Journal of Applied Physics*, 64(2).
52
53use ndarray::Array2;
54
55use super::ForwardOperator;
56
/// Parameters of a multi-shell spherical head model.
///
/// `radii` and `conductivities` are parallel vectors, one entry per shell,
/// ordered from the innermost shell to the outermost one.
#[derive(Debug, Clone, PartialEq)]
pub struct SphereModel {
    /// Radii of the shells from innermost to outermost, in metres.
    pub radii: Vec<f64>,
    /// Conductivities of each shell, in S/m (same order as `radii`).
    pub conductivities: Vec<f64>,
    /// Centre of the sphere in metres `[x, y, z]`.
    pub center: [f64; 3],
}
67
68impl Default for SphereModel {
69    /// Standard 3-shell EEG model (brain / skull / scalp).
70    fn default() -> Self {
71        Self {
72            radii: vec![0.067, 0.070, 0.075],
73            conductivities: vec![0.33, 0.0042, 0.33],
74            center: [0.0, 0.0, 0.04],
75        }
76    }
77}
78
79impl SphereModel {
80    /// Create a single-shell model (homogeneous sphere).
81    pub fn single_shell(radius: f64, conductivity: f64, center: [f64; 3]) -> Self {
82        Self {
83            radii: vec![radius],
84            conductivities: vec![conductivity],
85            center,
86        }
87    }
88
89    /// Outermost shell radius.
90    pub fn outer_radius(&self) -> f64 {
91        *self.radii.last().unwrap_or(&0.075)
92    }
93}
94
95/// Compute a fixed-orientation EEG forward model using a spherical head.
96///
97/// Each source has a single orientation (given by `src_normals`), so the gain
98/// matrix has shape `[n_electrodes, n_sources]`.
99///
100/// # Arguments
101///
102/// * `electrodes`  — Electrode positions, shape `[n_elec, 3]`, in metres.
103/// * `src_pos`     — Source positions, shape `[n_src, 3]`, in metres.
104/// * `src_normals` — Source orientations (unit vectors), shape `[n_src, 3]`.
105/// * `sphere`      — Spherical head model parameters.
106///
107/// # Returns
108///
109/// A [`ForwardOperator`] with fixed orientation.
110pub fn make_sphere_forward(
111    electrodes: &Array2<f64>,
112    src_pos: &Array2<f64>,
113    src_normals: &Array2<f64>,
114    sphere: &SphereModel,
115) -> ForwardOperator {
116    let n_elec = electrodes.nrows();
117    let n_src = src_pos.nrows();
118    assert_eq!(src_normals.nrows(), n_src);
119    assert_eq!(electrodes.ncols(), 3);
120    assert_eq!(src_pos.ncols(), 3);
121    assert_eq!(src_normals.ncols(), 3);
122
123    // Compute Berg & Scherg parameters for this sphere model
124    let bs = berg_scherg_params(sphere);
125
126    let mut gain = Array2::zeros((n_elec, n_src));
127
128    for s in 0..n_src {
129        let rd = [
130            src_pos[[s, 0]] - sphere.center[0],
131            src_pos[[s, 1]] - sphere.center[1],
132            src_pos[[s, 2]] - sphere.center[2],
133        ];
134        let q = [src_normals[[s, 0]], src_normals[[s, 1]], src_normals[[s, 2]]];
135
136        for e in 0..n_elec {
137            let re = [
138                electrodes[[e, 0]] - sphere.center[0],
139                electrodes[[e, 1]] - sphere.center[1],
140                electrodes[[e, 2]] - sphere.center[2],
141            ];
142
143            gain[[e, s]] = sphere_potential(&rd, &q, &re, &bs, sphere.outer_radius());
144        }
145    }
146
147    // Apply average reference (subtract mean across electrodes per source)
148    for s in 0..n_src {
149        let mean: f64 = (0..n_elec).map(|e| gain[[e, s]]).sum::<f64>() / n_elec as f64;
150        for e in 0..n_elec {
151            gain[[e, s]] -= mean;
152        }
153    }
154
155    let mut fwd = ForwardOperator::new_fixed(gain);
156    fwd.source_nn = src_normals.clone();
157    fwd
158}
159
160/// Compute a free-orientation EEG forward model using a spherical head.
161///
162/// Each source has three orthogonal orientations (X, Y, Z), so the gain
163/// matrix has shape `[n_electrodes, n_sources × 3]`.
164///
165/// # Arguments
166///
167/// * `electrodes` — Electrode positions, shape `[n_elec, 3]`, in metres.
168/// * `src_pos`    — Source positions, shape `[n_src, 3]`, in metres.
169/// * `sphere`     — Spherical head model parameters.
170///
171/// # Returns
172///
173/// A [`ForwardOperator`] with free orientation.
174pub fn make_sphere_forward_free(
175    electrodes: &Array2<f64>,
176    src_pos: &Array2<f64>,
177    sphere: &SphereModel,
178) -> ForwardOperator {
179    let n_elec = electrodes.nrows();
180    let n_src = src_pos.nrows();
181
182    let bs = berg_scherg_params(sphere);
183
184    let mut gain = Array2::zeros((n_elec, n_src * 3));
185
186    let unit_dirs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
187
188    for s in 0..n_src {
189        let rd = [
190            src_pos[[s, 0]] - sphere.center[0],
191            src_pos[[s, 1]] - sphere.center[1],
192            src_pos[[s, 2]] - sphere.center[2],
193        ];
194
195        for (ori, q) in unit_dirs.iter().enumerate() {
196            for e in 0..n_elec {
197                let re = [
198                    electrodes[[e, 0]] - sphere.center[0],
199                    electrodes[[e, 1]] - sphere.center[1],
200                    electrodes[[e, 2]] - sphere.center[2],
201                ];
202
203                gain[[e, s * 3 + ori]] =
204                    sphere_potential(&rd, q, &re, &bs, sphere.outer_radius());
205            }
206        }
207    }
208
209    // Average reference per column
210    for col in 0..n_src * 3 {
211        let mean: f64 = (0..n_elec).map(|e| gain[[e, col]]).sum::<f64>() / n_elec as f64;
212        for e in 0..n_elec {
213            gain[[e, col]] -= mean;
214        }
215    }
216
217    let mut fwd = ForwardOperator::new_free(gain);
218    // Set proper source positions in source_nn
219    for s in 0..n_src {
220        fwd.source_nn[[s * 3, 0]] = 1.0;
221        fwd.source_nn[[s * 3 + 1, 1]] = 1.0;
222        fwd.source_nn[[s * 3 + 2, 2]] = 1.0;
223    }
224    fwd
225}
226
227// ── Berg & Scherg approximation ────────────────────────────────────────────
228
/// Equivalent-dipole parameters of the Berg & Scherg approximation.
///
/// The `k`-th equivalent dipole sits at `mu[k]` times the true source
/// position (relative to the sphere centre). Its homogeneous-sphere
/// potential is weighted by `lam[k]`. The two vectors are consumed pairwise.
struct BergSchergParams {
    /// Per-dipole scale factors applied to the source position vector.
    mu: Vec<f64>,
    /// Per-dipole weights applied to the resulting potentials.
    lam: Vec<f64>,
}
236
237/// Compute Berg & Scherg parameters for a given sphere model.
238///
239/// For a single shell, this returns the exact solution (1 term).
240/// For 3 shells, we use the classical 3-term fit from Berg & Scherg (1994).
241fn berg_scherg_params(sphere: &SphereModel) -> BergSchergParams {
242    let n_shells = sphere.radii.len();
243
244    if n_shells == 1 {
245        // Single shell: exact solution
246        // V = (1 / (4π σ)) * [standard dipole formula]
247        return BergSchergParams {
248            mu: vec![1.0],
249            lam: vec![1.0],
250        };
251    }
252
253    // 3-shell model: use pre-computed Berg & Scherg fits
254    // These are the classical values for the standard head model
255    // (brain/skull/scalp with conductivity ratio ~80:1:80)
256    if n_shells == 3 {
257        let ratio = sphere.conductivities[0] / sphere.conductivities[1];
258        let r1 = sphere.radii[0] / sphere.radii[2]; // brain/scalp ratio
259        let r2 = sphere.radii[1] / sphere.radii[2]; // skull/scalp ratio
260
261        // Compute the exact series coefficients for a 3-layer sphere,
262        // then fit with 3 Berg-Scherg dipoles.
263        let (mu, lam) = fit_berg_scherg_3shell(r1, r2, ratio);
264        return BergSchergParams { mu, lam };
265    }
266
267    // Fallback for other numbers of shells: use single equivalent shell
268    BergSchergParams {
269        mu: vec![1.0],
270        lam: vec![1.0],
271    }
272}
273
274/// Fit 3-term Berg & Scherg parameters for a 3-layer sphere.
275///
276/// Uses the approach of computing the exact Legendre series coefficients
277/// for several low-order terms and fitting an exponential model.
278fn fit_berg_scherg_3shell(r1: f64, r2: f64, ratio: f64) -> (Vec<f64>, Vec<f64>) {
279    // Compute exact expansion coefficients c_n for n = 1..N_max
280    // For a 3-layer sphere:
281    // c_n = (2n+1)^3 / [denom(n)]
282    // where denom involves the conductivity ratios and radii ratios.
283    let n_max = 50;
284    let mut cn = Vec::with_capacity(n_max);
285
286    for n in 1..=n_max {
287        let nf = n as f64;
288        let c = exact_series_coeff(nf, r1, r2, ratio);
289        cn.push(c);
290    }
291
292    // The Berg-Scherg approximation represents c_n as:
293    //   c_n ≈ Σ_k λ_k × μ_k^n
294    //
295    // For 3 terms, fit using a simple least-squares approach.
296    // We use the classical approach of fitting at specific n values.
297
298    // Use a robust 3-term fit via iterative refinement
299    let (mu, lam) = fit_exponential_sum(&cn, 3);
300    (mu, lam)
301}
302
/// Transfer coefficient of Legendre order `n` for a 3-layer sphere, relative
/// to a homogeneous sphere.
///
/// * `r1`    — brain radius / scalp radius.
/// * `r2`    — skull radius / scalp radius.
/// * `ratio` — brain conductivity / skull conductivity. The scalp is assumed
///   to have the brain's conductivity.
///
/// With `p = 2n + 1`, the brain/skull interface factors `f12`, `g12` and the
/// skull/scalp factors `f23`, `g23` are combined as
/// `c_n = f12·f23 / (f23·a + g23·b / r2^p)`, where
/// `a = f12 + g12·(r1^p / r2^p)` and `b = f12·r2^p + g12·r1^p`.
/// For `ratio == 1` (no skull contrast) this is exactly `1.0` for every `n`.
///
/// If the denominator's magnitude is below `1e-30`, `1.0` is returned.
///
/// NOTE(review): described as a simplification of de Munck (1988); not
/// checked here against the full layered-sphere expansion.
fn exact_series_coeff(n: f64, r1: f64, r2: f64, ratio: f64) -> f64 {
    let np1 = n + 1.0;
    let p = 2.0 * n + 1.0;
    let p_sq = p * p;
    let ratio_m1 = ratio - 1.0;
    let inv_ratio_m1 = 1.0 / ratio - 1.0;

    // Radius ratios raised to the power 2n + 1.
    let (r1_p, r2_p) = (r1.powf(p), r2.powf(p));

    // Brain/skull interface (conductivity contrast `ratio`).
    let f12 = (n * ratio + n + 1.0) * (n + np1 * ratio) / p_sq;
    let g12 = ratio_m1 * ratio_m1 * n * np1 / p_sq;

    // Skull/scalp interface (contrast `1 / ratio`).
    let f23 = (np1 / ratio + n) * (np1 + n / ratio) / p_sq;
    let g23 = inv_ratio_m1 * inv_ratio_m1 * n * np1 / p_sq;

    let a = f12 + g12 * (r1_p / r2_p);
    let b = f12 * r2_p + g12 * r1_p;
    let denom = f23 * a + g23 * b / r2_p;

    if denom.abs() < 1e-30 {
        1.0
    } else {
        (f12 * f23) / denom
    }
}
345
346/// Fit an M-term exponential sum to a sequence of coefficients.
347///
348/// Finds `(μ_k, λ_k)` such that `c[n] ≈ Σ_k λ_k × μ_k^(n+1)`.
349///
350/// Uses Prony's method: fit a linear recurrence, then extract roots.
351fn fit_exponential_sum(cn: &[f64], m: usize) -> (Vec<f64>, Vec<f64>) {
352    let n = cn.len();
353    if n < 2 * m {
354        // Not enough data; fall back to uniform
355        return (vec![1.0; m], vec![1.0 / m as f64; m]);
356    }
357
358    // Prony's method:
359    // Build Hankel matrix H from cn and solve for the linear prediction coefficients.
360    // Then find roots of the characteristic polynomial.
361
362    // Step 1: Build the system H @ a = -h
363    let mut h_mat = vec![vec![0.0; m]; n - m];
364    let mut h_rhs = vec![0.0; n - m];
365
366    for i in 0..(n - m) {
367        for j in 0..m {
368            h_mat[i][j] = cn[i + j];
369        }
370        h_rhs[i] = -cn[i + m];
371    }
372
373    // Solve via least squares (normal equations): (H^T H) a = H^T (-h)
374    let mut hth = vec![vec![0.0; m]; m];
375    let mut htb = vec![0.0; m];
376
377    for i in 0..m {
378        for j in 0..m {
379            for k in 0..(n - m) {
380                hth[i][j] += h_mat[k][i] * h_mat[k][j];
381            }
382        }
383        for k in 0..(n - m) {
384            htb[i] += h_mat[k][i] * h_rhs[k];
385        }
386    }
387
388    // Solve small m×m system by Gaussian elimination
389    let a = solve_small_system(&hth, &htb, m);
390
391    // Step 2: Find roots of polynomial p(x) = x^m + a[m-1]*x^(m-1) + ... + a[0]
392    // For m=3, use companion matrix eigenvalues
393    let mu = polynomial_roots(&a, m);
394
395    // Step 3: Find λ by solving Vandermonde system
396    // cn[i] = Σ_k λ_k * μ_k^(i+1)
397    let mut vand = vec![vec![0.0; m]; m.min(n)];
398    let rows = m.min(n);
399    for i in 0..rows {
400        for k in 0..m {
401            vand[i][k] = mu[k].powi(i as i32 + 1);
402        }
403    }
404
405    let cn_sub: Vec<f64> = cn[..rows].to_vec();
406    let lam = solve_small_system_rect(&vand, &cn_sub, rows, m);
407
408    (mu, lam)
409}
410
/// Solve the dense `m × m` system `a · x = b` by Gaussian elimination with
/// partial pivoting.
///
/// Pivots whose magnitude is below `1e-30` are treated as zero. The column
/// is skipped during elimination, and the corresponding unknown is left at
/// `0.0` during back substitution. A singular system therefore yields a
/// finite (non-unique) answer instead of a division by zero.
fn solve_small_system(a: &[Vec<f64>], b: &[f64], m: usize) -> Vec<f64> {
    const TINY: f64 = 1e-30;

    // Augmented matrix [A | b], one Vec per row.
    let mut rows: Vec<Vec<f64>> = (0..m)
        .map(|i| {
            (0..m)
                .map(|j| a[i][j])
                .chain(std::iter::once(b[i]))
                .collect()
        })
        .collect();

    for col in 0..m {
        // Pick the first row with the strictly largest magnitude in this column.
        let mut best = col;
        for r in (col + 1)..m {
            if rows[r][col].abs() > rows[best][col].abs() {
                best = r;
            }
        }
        rows.swap(col, best);

        let pivot = rows[col][col];
        if pivot.abs() < TINY {
            continue;
        }

        let (upper, lower) = rows.split_at_mut(col + 1);
        let pivot_row = &upper[col];
        for row in lower.iter_mut() {
            let factor = row[col] / pivot;
            for j in col..=m {
                row[j] -= factor * pivot_row[j];
            }
        }
    }

    let mut x = vec![0.0; m];
    for i in (0..m).rev() {
        let rhs = ((i + 1)..m).fold(rows[i][m], |acc, j| acc - rows[i][j] * x[j]);
        let diag = rows[i][i];
        if diag.abs() > TINY {
            x[i] = rhs / diag;
        }
    }
    x
}
459
460/// Solve a rectangular least-squares system.
461fn solve_small_system_rect(a: &[Vec<f64>], b: &[f64], rows: usize, cols: usize) -> Vec<f64> {
462    // Form normal equations A^T A x = A^T b
463    let mut ata = vec![vec![0.0; cols]; cols];
464    let mut atb = vec![0.0; cols];
465    for i in 0..cols {
466        for j in 0..cols {
467            for k in 0..rows {
468                ata[i][j] += a[k][i] * a[k][j];
469            }
470        }
471        for k in 0..rows {
472            atb[i] += a[k][i] * b[k];
473        }
474    }
475    solve_small_system(&ata, &atb, cols)
476}
477
478/// Find roots of polynomial x^m + a[m-1]*x^{m-1} + ... + a[0] = 0
479/// via companion matrix eigenvalue decomposition.
480///
481/// For small m (typically 3), uses the companion matrix approach
482/// with a simple QR-like iteration.
483fn polynomial_roots(a: &[f64], m: usize) -> Vec<f64> {
484    if m == 0 {
485        return vec![];
486    }
487    if m == 1 {
488        return vec![-a[0]];
489    }
490
491    // Build companion matrix
492    let mut comp = vec![vec![0.0; m]; m];
493    for i in 1..m {
494        comp[i][i - 1] = 1.0;
495    }
496    for i in 0..m {
497        comp[i][m - 1] = -a[i];
498    }
499
500    // Simple eigenvalue extraction via iterative QR
501    // (sufficient for m ≤ 5)
502    eigenvalues_qr(&comp, m)
503}
504
/// Approximate the eigenvalues of a small `m × m` matrix by unshifted QR
/// iteration.
///
/// Each sweep factors the current iterate as `Q·R` with classical
/// Gram–Schmidt and replaces it with `R·Q`. Iteration stops after 200 sweeps,
/// or earlier once the sum of absolute sub-diagonal entries falls below
/// `1e-12`. The diagonal of the final iterate is returned.
///
/// Only real eigenvalues are recovered. Complex-conjugate pairs keep a
/// non-zero sub-diagonal, so their diagonal entries are not eigenvalues.
/// A column whose Gram–Schmidt residual norm is not above `1e-30` leaves
/// the corresponding column of `Q` at zero.
fn eigenvalues_qr(mat: &[Vec<f64>], m: usize) -> Vec<f64> {
    const MAX_SWEEPS: usize = 200;
    const TOL: f64 = 1e-12;

    let mut iter_mat = mat.to_vec();

    for _ in 0..MAX_SWEEPS {
        // Q is stored column-major: q_cols[j] is the j-th orthonormal column.
        let mut q_cols = vec![vec![0.0; m]; m];
        let mut r = vec![vec![0.0; m]; m];

        for j in 0..m {
            let mut v: Vec<f64> = (0..m).map(|i| iter_mat[i][j]).collect();

            // Remove components along the already-built columns of Q.
            for k in 0..j {
                let proj = (0..m).fold(0.0, |acc, i| acc + q_cols[k][i] * iter_mat[i][j]);
                r[k][j] = proj;
                for (vi, &qi) in v.iter_mut().zip(&q_cols[k]) {
                    *vi -= proj * qi;
                }
            }

            let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
            r[j][j] = norm;
            if norm > 1e-30 {
                q_cols[j] = v.iter().map(|x| x / norm).collect();
            }
        }

        // Next iterate: R · Q.
        iter_mat = (0..m)
            .map(|i| {
                (0..m)
                    .map(|j| (0..m).fold(0.0, |acc, k| acc + r[i][k] * q_cols[j][k]))
                    .collect()
            })
            .collect();

        let sub_diag = (1..m).fold(0.0, |acc, i| acc + iter_mat[i][i - 1].abs());
        if sub_diag < TOL {
            break;
        }
    }

    (0..m).map(|i| iter_mat[i][i]).collect()
}
566
567// ── Single-dipole potential in a sphere ────────────────────────────────────
568
569/// Compute the potential at electrode position `re` due to a dipole at `rd`
570/// with moment `q`, using the Berg & Scherg approximation.
571///
572/// All positions are relative to the sphere centre.
573fn sphere_potential(
574    rd: &[f64; 3],
575    q: &[f64; 3],
576    re: &[f64; 3],
577    bs: &BergSchergParams,
578    outer_radius: f64,
579) -> f64 {
580    let mut total = 0.0;
581
582    for (&mu_k, &lam_k) in bs.mu.iter().zip(bs.lam.iter()) {
583        // Equivalent dipole position: rd' = mu_k * rd
584        let rd_k = [rd[0] * mu_k, rd[1] * mu_k, rd[2] * mu_k];
585
586        total += lam_k * homogeneous_sphere_potential(&rd_k, q, re, outer_radius);
587    }
588
589    total
590}
591
/// Potential at electrode `re` from a current dipole with moment `q` located
/// at `rd`, in a homogeneous sphere of unit conductivity.
///
/// Let `d = re − rd` and `F = |d| · (|re|·|d| + re·d)`. The evaluated
/// expression is
///
/// `V = (1/4π) · [ (2 (d·q)(re·d) − |d|² (re·q)) / (|d|³ |re|) + (d·q) / (|d| F) ]`
///
/// Returns `0.0` for degenerate geometry:
/// * `|d| < 1e-15` — the electrode coincides with the dipole;
/// * `|re| < 1e-15` — the electrode is at the centre;
/// * `|F| < 1e-30`.
///
/// The `_radius` argument is accepted but unused.
///
/// NOTE(review): this is a simplified closed form and has not been verified
/// here against the exact homogeneous-sphere surface-potential solution.
/// Confirm before relying on absolute values.
fn homogeneous_sphere_potential(
    rd: &[f64; 3],
    q: &[f64; 3],
    re: &[f64; 3],
    _radius: f64,
) -> f64 {
    let dot = |u: &[f64; 3], v: &[f64; 3]| u[0] * v[0] + u[1] * v[1] + u[2] * v[2];

    let d = [re[0] - rd[0], re[1] - rd[1], re[2] - rd[2]];
    let d_len = dot(&d, &d).sqrt();
    if d_len < 1e-15 {
        return 0.0;
    }

    let re_len = dot(re, re).sqrt();
    if re_len < 1e-15 {
        return 0.0;
    }

    let d_dot_q = dot(&d, q);
    let re_dot_d = dot(re, &d);
    let re_dot_q = dot(re, q);
    let d_sq = d_len * d_len;

    let f = d_len * (re_len * d_len + re_dot_d);
    if f.abs() < 1e-30 {
        return 0.0;
    }

    let d_cubed_re = d_len.powi(3) * re_len;
    let inv_4pi = 1.0 / (4.0 * std::f64::consts::PI);

    inv_4pi
        * (2.0 * d_dot_q * re_dot_d / d_cubed_re - d_sq * re_dot_q / d_cubed_re
            + d_dot_q / (d_len * f))
}
651
#[cfg(test)]
mod tests {
    use super::*;
    use crate::source_space::ico_source_space;

    /// Four electrodes on a ring at z = 0.04 m, 0.07 m from the z-axis,
    /// placed at +x, -x, +y and -y.
    fn ring_of_four() -> Array2<f64> {
        Array2::from_shape_vec(
            (4, 3),
            vec![
                0.07, 0.0, 0.04, //
                -0.07, 0.0, 0.04, //
                0.0, 0.07, 0.04, //
                0.0, -0.07, 0.04,
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_default_sphere_model() {
        let model = SphereModel::default();
        assert_eq!(model.radii.len(), 3);
        assert_eq!(model.conductivities.len(), 3);
        assert!((model.outer_radius() - 0.075).abs() < 1e-10);
    }

    #[test]
    fn test_make_sphere_forward_shape() {
        let elec = ring_of_four();
        let (src_pos, src_nn) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &SphereModel::default());

        assert_eq!(fwd.gain.nrows(), 4);
        assert_eq!(fwd.gain.ncols(), src_pos.nrows());
        assert_eq!(fwd.n_sources, src_pos.nrows());
        assert!(fwd.gain.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_forward_average_referenced() {
        let elec = ring_of_four();
        let (src_pos, src_nn) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &SphereModel::default());

        // Average referencing makes every column sum to (numerically) zero.
        for s in 0..fwd.n_sources {
            let col_sum: f64 = (0..4).map(|e| fwd.gain[[e, s]]).sum();
            assert!(
                col_sum.abs() < 1e-12,
                "Column {s} sum = {col_sum}, expected ≈ 0"
            );
        }
    }

    #[test]
    fn test_forward_not_all_zeros() {
        let elec = ring_of_four();
        let (src_pos, src_nn) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &SphereModel::default());

        let max_abs = fwd.gain.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(
            max_abs > 1e-20,
            "Gain matrix should not be all zeros, max = {max_abs}"
        );
    }

    #[test]
    fn test_forward_symmetry_opposite_dipoles() {
        // A tangential (X-oriented) dipole directly above the centre should
        // produce opposite potentials at two electrodes mirrored across x = 0.
        let elec = Array2::from_shape_vec(
            (3, 3),
            vec![
                0.075, 0.0, 0.04, // right
                -0.075, 0.0, 0.04, // left
                0.0, 0.0, 0.115, // top
            ],
        )
        .unwrap();
        let src_pos = Array2::from_shape_vec((1, 3), vec![0.0, 0.0, 0.09]).unwrap();
        let src_nn = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap();

        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &SphereModel::default());

        let v_right = fwd.gain[[0, 0]];
        let v_left = fwd.gain[[1, 0]];
        // Roughly antisymmetric, even after average referencing.
        assert!(
            (v_right + v_left).abs() < (v_right - v_left).abs() * 0.5 || v_right.abs() < 1e-20,
            "Symmetric electrodes should see opposite potentials: right={v_right}, left={v_left}"
        );
    }

    #[test]
    fn test_free_orientation_forward_shape() {
        let elec = ring_of_four();
        let (src_pos, _) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let fwd = make_sphere_forward_free(&elec, &src_pos, &SphereModel::default());

        assert_eq!(fwd.gain.nrows(), 4);
        assert_eq!(fwd.gain.ncols(), src_pos.nrows() * 3);
        assert_eq!(fwd.n_sources, src_pos.nrows());
        assert!(fwd.gain.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_single_shell_forward() {
        let elec = ring_of_four();
        let (src_pos, src_nn) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let sphere = SphereModel::single_shell(0.075, 0.33, [0.0, 0.0, 0.04]);
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &sphere);

        assert_eq!(fwd.gain.nrows(), 4);
        assert!(fwd.gain.iter().all(|v| v.is_finite()));
        let max_abs = fwd.gain.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(max_abs > 1e-20);
    }

    #[test]
    fn test_end_to_end_forward_to_inverse() {
        // Full pipeline: source space → forward → noise cov → inverse → apply
        use crate::{apply_inverse, make_inverse_operator, InverseMethod, NoiseCov};

        let n_elec = 8;
        // Electrodes evenly spaced on a horizontal ring of radius 0.075 m.
        let elec = Array2::from_shape_fn((n_elec, 3), |(i, j)| {
            let theta = 2.0 * std::f64::consts::PI * i as f64 / n_elec as f64;
            match j {
                0 => 0.075 * theta.cos(),
                1 => 0.075 * theta.sin(),
                _ => 0.04,
            }
        });
        let (src_pos, src_nn) = ico_source_space(2, 0.06, [0.0, 0.0, 0.04]);
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &SphereModel::default());

        let cov = NoiseCov::diagonal(vec![1e-12; n_elec]);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        let data = Array2::from_elem((n_elec, 10), 1e-6);
        let stc = apply_inverse(&data, &inv, 1.0 / 9.0, InverseMethod::DSPM).unwrap();
        assert_eq!(stc.data.nrows(), src_pos.nrows());
        assert!(stc.data.iter().all(|v| v.is_finite()));
    }
}