1use ndarray::Array2;
54
55use super::ForwardOperator;
56
/// Concentric-shell spherical head model consumed by the forward solvers in this file.
///
/// The solvers subtract `center` from every source and electrode position and use the
/// last entry of `radii` as the outer radius (see [`SphereModel::outer_radius`]).
#[derive(Clone, Debug)]
pub struct SphereModel {
    /// Shell radii; the last entry is treated as the outermost surface.
    /// NOTE(review): presumably ordered innermost-first and in metres — confirm with callers.
    pub radii: Vec<f64>,
    /// Shell conductivities.
    /// NOTE(review): presumably one per entry of `radii`, in the same order — confirm.
    pub conductivities: Vec<f64>,
    /// Sphere centre, in the same coordinate frame as the source and electrode positions.
    pub center: [f64; 3],
}
67
68impl Default for SphereModel {
69 fn default() -> Self {
71 Self {
72 radii: vec![0.067, 0.070, 0.075],
73 conductivities: vec![0.33, 0.0042, 0.33],
74 center: [0.0, 0.0, 0.04],
75 }
76 }
77}
78
79impl SphereModel {
80 pub fn single_shell(radius: f64, conductivity: f64, center: [f64; 3]) -> Self {
82 Self {
83 radii: vec![radius],
84 conductivities: vec![conductivity],
85 center,
86 }
87 }
88
89 pub fn outer_radius(&self) -> f64 {
91 *self.radii.last().unwrap_or(&0.075)
92 }
93}
94
95pub fn make_sphere_forward(
111 electrodes: &Array2<f64>,
112 src_pos: &Array2<f64>,
113 src_normals: &Array2<f64>,
114 sphere: &SphereModel,
115) -> ForwardOperator {
116 let n_elec = electrodes.nrows();
117 let n_src = src_pos.nrows();
118 assert_eq!(src_normals.nrows(), n_src);
119 assert_eq!(electrodes.ncols(), 3);
120 assert_eq!(src_pos.ncols(), 3);
121 assert_eq!(src_normals.ncols(), 3);
122
123 let bs = berg_scherg_params(sphere);
125
126 let mut gain = Array2::zeros((n_elec, n_src));
127
128 for s in 0..n_src {
129 let rd = [
130 src_pos[[s, 0]] - sphere.center[0],
131 src_pos[[s, 1]] - sphere.center[1],
132 src_pos[[s, 2]] - sphere.center[2],
133 ];
134 let q = [src_normals[[s, 0]], src_normals[[s, 1]], src_normals[[s, 2]]];
135
136 for e in 0..n_elec {
137 let re = [
138 electrodes[[e, 0]] - sphere.center[0],
139 electrodes[[e, 1]] - sphere.center[1],
140 electrodes[[e, 2]] - sphere.center[2],
141 ];
142
143 gain[[e, s]] = sphere_potential(&rd, &q, &re, &bs, sphere.outer_radius());
144 }
145 }
146
147 for s in 0..n_src {
149 let mean: f64 = (0..n_elec).map(|e| gain[[e, s]]).sum::<f64>() / n_elec as f64;
150 for e in 0..n_elec {
151 gain[[e, s]] -= mean;
152 }
153 }
154
155 let mut fwd = ForwardOperator::new_fixed(gain);
156 fwd.source_nn = src_normals.clone();
157 fwd
158}
159
160pub fn make_sphere_forward_free(
175 electrodes: &Array2<f64>,
176 src_pos: &Array2<f64>,
177 sphere: &SphereModel,
178) -> ForwardOperator {
179 let n_elec = electrodes.nrows();
180 let n_src = src_pos.nrows();
181
182 let bs = berg_scherg_params(sphere);
183
184 let mut gain = Array2::zeros((n_elec, n_src * 3));
185
186 let unit_dirs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
187
188 for s in 0..n_src {
189 let rd = [
190 src_pos[[s, 0]] - sphere.center[0],
191 src_pos[[s, 1]] - sphere.center[1],
192 src_pos[[s, 2]] - sphere.center[2],
193 ];
194
195 for (ori, q) in unit_dirs.iter().enumerate() {
196 for e in 0..n_elec {
197 let re = [
198 electrodes[[e, 0]] - sphere.center[0],
199 electrodes[[e, 1]] - sphere.center[1],
200 electrodes[[e, 2]] - sphere.center[2],
201 ];
202
203 gain[[e, s * 3 + ori]] =
204 sphere_potential(&rd, q, &re, &bs, sphere.outer_radius());
205 }
206 }
207 }
208
209 for col in 0..n_src * 3 {
211 let mean: f64 = (0..n_elec).map(|e| gain[[e, col]]).sum::<f64>() / n_elec as f64;
212 for e in 0..n_elec {
213 gain[[e, col]] -= mean;
214 }
215 }
216
217 let mut fwd = ForwardOperator::new_free(gain);
218 for s in 0..n_src {
220 fwd.source_nn[[s * 3, 0]] = 1.0;
221 fwd.source_nn[[s * 3 + 1, 1]] = 1.0;
222 fwd.source_nn[[s * 3 + 2, 2]] = 1.0;
223 }
224 fwd
225}
226
/// Parameters of the equivalent-dipole approximation used by `sphere_potential`.
///
/// The two vectors are consumed pairwise: term `k` evaluates the homogeneous-sphere
/// potential with the dipole position scaled by `mu[k]` and weights it by `lam[k]`.
struct BergSchergParams {
    /// Per-term scale factors applied to the dipole position.
    mu: Vec<f64>,
    /// Per-term weights applied to the resulting potentials.
    lam: Vec<f64>,
}
236
237fn berg_scherg_params(sphere: &SphereModel) -> BergSchergParams {
242 let n_shells = sphere.radii.len();
243
244 if n_shells == 1 {
245 return BergSchergParams {
248 mu: vec![1.0],
249 lam: vec![1.0],
250 };
251 }
252
253 if n_shells == 3 {
257 let ratio = sphere.conductivities[0] / sphere.conductivities[1];
258 let r1 = sphere.radii[0] / sphere.radii[2]; let r2 = sphere.radii[1] / sphere.radii[2]; let (mu, lam) = fit_berg_scherg_3shell(r1, r2, ratio);
264 return BergSchergParams { mu, lam };
265 }
266
267 BergSchergParams {
269 mu: vec![1.0],
270 lam: vec![1.0],
271 }
272}
273
274fn fit_berg_scherg_3shell(r1: f64, r2: f64, ratio: f64) -> (Vec<f64>, Vec<f64>) {
279 let n_max = 50;
284 let mut cn = Vec::with_capacity(n_max);
285
286 for n in 1..=n_max {
287 let nf = n as f64;
288 let c = exact_series_coeff(nf, r1, r2, ratio);
289 cn.push(c);
290 }
291
292 let (mu, lam) = fit_exponential_sum(&cn, 3);
300 (mu, lam)
301}
302
/// Evaluates the order-`n` series coefficient for a three-shell configuration.
///
/// `r1` and `r2` are raised to the power `2n + 1`, and `ratio` enters both
/// directly and as its reciprocal.
/// NOTE(review): `berg_scherg_params` passes the inner/outer and middle/outer
/// radius ratios and the first/second conductivity ratio — confirm the formula
/// against the reference it was taken from.
///
/// Returns `1.0` when the denominator's magnitude is below `1e-30`.
fn exact_series_coeff(n: f64, r1: f64, r2: f64, ratio: f64) -> f64 {
    let np1 = n + 1.0;
    let p = 2.0 * n + 1.0;
    let p_sq = p * p;

    let r1_p = r1.powf(p);
    let r2_p = r2.powf(p);

    // Terms built from the conductivity ratio itself.
    let delta = ratio - 1.0;
    let f12 = (n * ratio + n + 1.0) * (n + np1 * ratio) / p_sq;
    let g12 = delta * delta * n * np1 / p_sq;

    let a = f12 + g12 * (r1_p / r2_p);
    let b = f12 * r2_p + g12 * r1_p;

    // Terms built from the reciprocal conductivity ratio.
    let delta_inv = 1.0 / ratio - 1.0;
    let f23 = (np1 / ratio + n) * (np1 + n / ratio) / p_sq;
    let g23 = delta_inv * delta_inv * n * np1 / p_sq;

    let denom = f23 * a + g23 * b / r2_p;

    // Guard against a vanishing denominator.
    if denom.abs() < 1e-30 {
        return 1.0;
    }
    (f12 * f23) / denom
}
345
346fn fit_exponential_sum(cn: &[f64], m: usize) -> (Vec<f64>, Vec<f64>) {
352 let n = cn.len();
353 if n < 2 * m {
354 return (vec![1.0; m], vec![1.0 / m as f64; m]);
356 }
357
358 let mut h_mat = vec![vec![0.0; m]; n - m];
364 let mut h_rhs = vec![0.0; n - m];
365
366 for i in 0..(n - m) {
367 for j in 0..m {
368 h_mat[i][j] = cn[i + j];
369 }
370 h_rhs[i] = -cn[i + m];
371 }
372
373 let mut hth = vec![vec![0.0; m]; m];
375 let mut htb = vec![0.0; m];
376
377 for i in 0..m {
378 for j in 0..m {
379 for k in 0..(n - m) {
380 hth[i][j] += h_mat[k][i] * h_mat[k][j];
381 }
382 }
383 for k in 0..(n - m) {
384 htb[i] += h_mat[k][i] * h_rhs[k];
385 }
386 }
387
388 let a = solve_small_system(&hth, &htb, m);
390
391 let mu = polynomial_roots(&a, m);
394
395 let mut vand = vec![vec![0.0; m]; m.min(n)];
398 let rows = m.min(n);
399 for i in 0..rows {
400 for k in 0..m {
401 vand[i][k] = mu[k].powi(i as i32 + 1);
402 }
403 }
404
405 let cn_sub: Vec<f64> = cn[..rows].to_vec();
406 let lam = solve_small_system_rect(&vand, &cn_sub, rows, m);
407
408 (mu, lam)
409}
410
/// Solves the `m × m` linear system `a · x = b` by Gaussian elimination with
/// partial pivoting.
///
/// Only the leading `m × m` part of `a` and the first `m` entries of `b` are read.
/// Columns whose pivot magnitude is below `1e-30` are skipped during elimination,
/// and unknowns whose diagonal magnitude is not above `1e-30` are left at `0.0`,
/// so a singular system yields a (partially) zero solution instead of NaN/inf.
fn solve_small_system(a: &[Vec<f64>], b: &[f64], m: usize) -> Vec<f64> {
    // Augmented matrix [a | b].
    let mut aug: Vec<Vec<f64>> = (0..m)
        .map(|i| {
            let mut row: Vec<f64> = (0..m).map(|j| a[i][j]).collect();
            row.push(b[i]);
            row
        })
        .collect();

    for col in 0..m {
        // Partial pivoting: the first row with the strictly largest magnitude wins.
        let pivot_row = ((col + 1)..m).fold(col, |best, row| {
            if aug[row][col].abs() > aug[best][col].abs() {
                row
            } else {
                best
            }
        });
        aug.swap(col, pivot_row);

        let pivot = aug[col][col];
        if pivot.abs() < 1e-30 {
            continue;
        }

        // Eliminate the column below the pivot.
        let (upper, lower) = aug.split_at_mut(col + 1);
        let pivot_line = &upper[col];
        for row in lower.iter_mut() {
            let factor = row[col] / pivot;
            for j in col..=m {
                row[j] -= factor * pivot_line[j];
            }
        }
    }

    // Back substitution.
    let mut x = vec![0.0_f64; m];
    for i in (0..m).rev() {
        let residual = ((i + 1)..m).fold(aug[i][m], |acc, j| acc - aug[i][j] * x[j]);
        let diag = aug[i][i];
        if diag.abs() > 1e-30 {
            x[i] = residual / diag;
        }
    }
    x
}
459
460fn solve_small_system_rect(a: &[Vec<f64>], b: &[f64], rows: usize, cols: usize) -> Vec<f64> {
462 let mut ata = vec![vec![0.0; cols]; cols];
464 let mut atb = vec![0.0; cols];
465 for i in 0..cols {
466 for j in 0..cols {
467 for k in 0..rows {
468 ata[i][j] += a[k][i] * a[k][j];
469 }
470 }
471 for k in 0..rows {
472 atb[i] += a[k][i] * b[k];
473 }
474 }
475 solve_small_system(&ata, &atb, cols)
476}
477
478fn polynomial_roots(a: &[f64], m: usize) -> Vec<f64> {
484 if m == 0 {
485 return vec![];
486 }
487 if m == 1 {
488 return vec![-a[0]];
489 }
490
491 let mut comp = vec![vec![0.0; m]; m];
493 for i in 1..m {
494 comp[i][i - 1] = 1.0;
495 }
496 for i in 0..m {
497 comp[i][m - 1] = -a[i];
498 }
499
500 eigenvalues_qr(&comp, m)
503}
504
/// Estimates the eigenvalues of the `m × m` matrix `mat` with the unshifted QR
/// iteration and returns the diagonal of the final iterate.
///
/// Each sweep factors the current matrix as `Q · R` with Gram-Schmidt
/// (projections are taken against the original column) and replaces it with
/// `R · Q`. Iteration stops after 200 sweeps or once the sum of absolute
/// subdiagonal entries drops below `1e-12`. Columns whose residual norm is not
/// above `1e-30` leave the corresponding column of `Q` at zero.
fn eigenvalues_qr(mat: &[Vec<f64>], m: usize) -> Vec<f64> {
    const MAX_SWEEPS: usize = 200;
    const SUBDIAG_TOL: f64 = 1e-12;

    let mut current: Vec<Vec<f64>> = mat.to_vec();

    for _ in 0..MAX_SWEEPS {
        let mut q: Vec<Vec<f64>> = vec![vec![0.0; m]; m];
        let mut r: Vec<Vec<f64>> = vec![vec![0.0; m]; m];

        // Gram-Schmidt, one column at a time.
        for j in 0..m {
            let mut v: Vec<f64> = (0..m).map(|i| current[i][j]).collect();

            for k in 0..j {
                let proj = (0..m).fold(0.0_f64, |acc, i| acc + q[i][k] * current[i][j]);
                r[k][j] = proj;
                for (i, vi) in v.iter_mut().enumerate() {
                    *vi -= proj * q[i][k];
                }
            }

            let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
            r[j][j] = norm;
            if norm > 1e-30 {
                for (i, vi) in v.iter().enumerate() {
                    q[i][j] = vi / norm;
                }
            }
        }

        // Next iterate: R · Q.
        current = (0..m)
            .map(|i| {
                (0..m)
                    .map(|j| (0..m).fold(0.0_f64, |acc, k| acc + r[i][k] * q[k][j]))
                    .collect()
            })
            .collect();

        let subdiag = (1..m).fold(0.0_f64, |acc, i| acc + current[i][i - 1].abs());
        if subdiag < SUBDIAG_TOL {
            break;
        }
    }

    (0..m).map(|i| current[i][i]).collect()
}
566
567fn sphere_potential(
574 rd: &[f64; 3],
575 q: &[f64; 3],
576 re: &[f64; 3],
577 bs: &BergSchergParams,
578 outer_radius: f64,
579) -> f64 {
580 let mut total = 0.0;
581
582 for (&mu_k, &lam_k) in bs.mu.iter().zip(bs.lam.iter()) {
583 let rd_k = [rd[0] * mu_k, rd[1] * mu_k, rd[2] * mu_k];
585
586 total += lam_k * homogeneous_sphere_potential(&rd_k, q, re, outer_radius);
587 }
588
589 total
590}
591
/// Closed-form potential term for a dipole `q` at `rd`, observed at `re`
/// (both offsets are relative to the sphere centre).
///
/// With `d = re - rd`, the value is
/// `1/(4π) · [ 2 (d·q)(re·d) / (|d|³|re|) − |d|² (re·q) / (|d|³|re|) + (d·q) / (|d| F) ]`,
/// where `F = |d| (|re||d| + re·d)`.
///
/// Returns `0.0` for the degenerate cases `|d| < 1e-15`, `|re| < 1e-15` or
/// `|F| < 1e-30`. The `_radius` argument is currently unused.
fn homogeneous_sphere_potential(
    rd: &[f64; 3],
    q: &[f64; 3],
    re: &[f64; 3],
    _radius: f64,
) -> f64 {
    let dot = |a: &[f64; 3], b: &[f64; 3]| a[0] * b[0] + a[1] * b[1] + a[2] * b[2];

    // Vector from the dipole to the electrode.
    let d = [re[0] - rd[0], re[1] - rd[1], re[2] - rd[2]];
    let d_len = dot(&d, &d).sqrt();
    if d_len < 1e-15 {
        return 0.0;
    }

    let re_len = dot(re, re).sqrt();
    if re_len < 1e-15 {
        return 0.0;
    }

    let d_dot_q = dot(&d, q);
    let re_dot_d = dot(re, &d);
    let re_dot_q = dot(re, q);
    let d_sq = d_len * d_len;

    let f = d_len * (re_len * d_len + re_dot_d);
    if f.abs() < 1e-30 {
        return 0.0;
    }

    let inv_4pi = 1.0 / (4.0 * std::f64::consts::PI);
    let cube_times_re = d_len.powi(3) * re_len;

    let term_a = 2.0 * d_dot_q * re_dot_d / cube_times_re;
    let term_b = d_sq * re_dot_q / cube_times_re;
    let term_c = d_dot_q / (d_len * f);

    inv_4pi * (term_a - term_b + term_c)
}
651
#[cfg(test)]
mod tests {
    use super::*;
    use crate::source_space::ico_source_space;

    /// Four electrodes on the ±x and ±y axes, 0.07 away from the z axis, in the
    /// plane `z = 0.04` (the default sphere centre height).
    fn four_electrodes() -> Array2<f64> {
        let coords = vec![
            0.07, 0.0, 0.04, -0.07, 0.0, 0.04, 0.0, 0.07, 0.04, 0.0, -0.07, 0.04,
        ];
        Array2::from_shape_vec((4, 3), coords).unwrap()
    }

    /// The default model has three shells and an outer radius of 0.075.
    #[test]
    fn test_default_sphere_model() {
        let s = SphereModel::default();
        assert_eq!(s.radii.len(), 3);
        assert_eq!(s.conductivities.len(), 3);
        assert!((s.outer_radius() - 0.075).abs() < 1e-10);
    }

    /// The fixed-orientation gain is `n_elec × n_src` and finite everywhere.
    #[test]
    fn test_make_sphere_forward_shape() {
        let elec = four_electrodes();
        let (src_pos, src_nn) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let sphere = SphereModel::default();
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &sphere);

        assert_eq!(fwd.gain.nrows(), 4);
        assert_eq!(fwd.gain.ncols(), src_pos.nrows());
        assert_eq!(fwd.n_sources, src_pos.nrows());
        assert!(fwd.gain.iter().all(|v| v.is_finite()));
    }

    /// Every gain column sums to (numerically) zero after average referencing.
    #[test]
    fn test_forward_average_referenced() {
        let elec = four_electrodes();
        let (src_pos, src_nn) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let sphere = SphereModel::default();
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &sphere);

        for s in 0..fwd.n_sources {
            let col_sum: f64 = (0..4).map(|e| fwd.gain[[e, s]]).sum();
            assert!(
                col_sum.abs() < 1e-12,
                "Column {s} sum = {col_sum}, expected ≈ 0"
            );
        }
    }

    /// The gain matrix contains at least one non-negligible entry.
    #[test]
    fn test_forward_not_all_zeros() {
        let elec = four_electrodes();
        let (src_pos, src_nn) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let sphere = SphereModel::default();
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &sphere);

        let max_abs = fwd.gain.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(
            max_abs > 1e-20,
            "Gain matrix should not be all zeros, max = {max_abs}"
        );
    }

    /// An x-oriented dipole on the z axis should produce potentials of opposite
    /// sign at electrodes mirrored across the y-z plane.
    #[test]
    fn test_forward_symmetry_opposite_dipoles() {
        // Right, left and top electrodes on the outer sphere.
        let elec = Array2::from_shape_vec(
            (3, 3),
            vec![0.075, 0.0, 0.04, -0.075, 0.0, 0.04, 0.0, 0.0, 0.115],
        )
        .unwrap();

        let src_pos = Array2::from_shape_vec((1, 3), vec![0.0, 0.0, 0.09]).unwrap();
        let src_nn = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap();
        let sphere = SphereModel::default();
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &sphere);

        let v_right = fwd.gain[[0, 0]];
        let v_left = fwd.gain[[1, 0]];
        assert!(
            (v_right + v_left).abs() < (v_right - v_left).abs() * 0.5 || v_right.abs() < 1e-20,
            "Symmetric electrodes should see opposite potentials: right={v_right}, left={v_left}"
        );
    }

    /// The free-orientation gain has three columns per source and is finite.
    #[test]
    fn test_free_orientation_forward_shape() {
        let elec = four_electrodes();
        let (src_pos, _) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let sphere = SphereModel::default();
        let fwd = make_sphere_forward_free(&elec, &src_pos, &sphere);

        assert_eq!(fwd.gain.nrows(), 4);
        assert_eq!(fwd.gain.ncols(), src_pos.nrows() * 3);
        assert_eq!(fwd.n_sources, src_pos.nrows());
        assert!(fwd.gain.iter().all(|v| v.is_finite()));
    }

    /// A single-shell model produces a finite, non-zero gain matrix.
    #[test]
    fn test_single_shell_forward() {
        let elec = four_electrodes();
        let (src_pos, src_nn) = ico_source_space(1, 0.06, [0.0, 0.0, 0.04]);
        let sphere = SphereModel::single_shell(0.075, 0.33, [0.0, 0.0, 0.04]);
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &sphere);

        assert_eq!(fwd.gain.nrows(), 4);
        assert!(fwd.gain.iter().all(|v| v.is_finite()));
        let max_abs = fwd.gain.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(max_abs > 1e-20);
    }

    /// Smoke test: a sphere forward model can be fed through the inverse pipeline
    /// and yields finite source estimates of the expected height.
    #[test]
    fn test_end_to_end_forward_to_inverse() {
        use crate::{apply_inverse, make_inverse_operator, InverseMethod, NoiseCov};

        // Eight electrodes evenly spaced on a ring of radius 0.075 at z = 0.04.
        let n_elec = 8;
        let elec = Array2::from_shape_fn((n_elec, 3), |(i, j)| {
            let theta = 2.0 * std::f64::consts::PI * i as f64 / n_elec as f64;
            match j {
                0 => 0.075 * theta.cos(),
                1 => 0.075 * theta.sin(),
                _ => 0.04,
            }
        });
        let (src_pos, src_nn) = ico_source_space(2, 0.06, [0.0, 0.0, 0.04]);
        let sphere = SphereModel::default();
        let fwd = make_sphere_forward(&elec, &src_pos, &src_nn, &sphere);

        let cov = NoiseCov::diagonal(vec![1e-12; n_elec]);
        let inv = make_inverse_operator(&fwd, &cov, None).unwrap();

        let data = Array2::from_elem((n_elec, 10), 1e-6);
        let stc = apply_inverse(&data, &inv, 1.0 / 9.0, InverseMethod::DSPM).unwrap();
        assert_eq!(stc.data.nrows(), src_pos.nrows());
        assert!(stc.data.iter().all(|v| v.is_finite()));
    }
}