1use ndarray::{s, Array1, Array2, Array3, Array4, ArrayView3, Axis};
21use num_complex::Complex32;
22use tracing::{debug, info, warn};
23
24use crate::prewhiten::{cholesky_lower, invert_lower_triangular};
25
/// Phase-encode (ky) sampling layout: uniform undersampling by `r` plus one fully
/// sampled autocalibration (ACS) block, as inferred by [`SamplingPattern::detect`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SamplingPattern {
    /// Acceleration factor: spacing between acquired ky lines outside the ACS block.
    pub r: usize,
    /// First ky line of the ACS block (inclusive).
    pub acs_start: usize,
    /// End of the ACS block (exclusive), so the block is `acs_start..acs_end`.
    pub acs_end: usize,
    /// Lowest acquired ky line.
    pub ky_lo: usize,
    /// Highest acquired ky line (inclusive).
    pub ky_hi: usize,
}
41
42impl SamplingPattern {
43 pub fn acs_len(&self) -> usize {
44 self.acs_end.saturating_sub(self.acs_start)
45 }
46
47 pub fn detect(ky_any: &[bool]) -> Option<Self> {
51 let _ = ky_any.len();
52 let sampled: Vec<usize> = ky_any
53 .iter()
54 .enumerate()
55 .filter_map(|(i, &b)| if b { Some(i) } else { None })
56 .collect();
57 if sampled.len() < 4 {
58 return None;
59 }
60
61 let ky_lo = *sampled.first()?;
62 let ky_hi = *sampled.last()?;
63
64 let mut best_start = 0usize;
66 let mut best_len = 0usize;
67 let mut cur_start = sampled[0];
68 let mut cur_len = 1usize;
69 for w in sampled.windows(2) {
70 if w[1] == w[0] + 1 {
71 cur_len += 1;
72 } else {
73 if cur_len > best_len {
74 best_len = cur_len;
75 best_start = cur_start;
76 }
77 cur_start = w[1];
78 cur_len = 1;
79 }
80 }
81 if cur_len > best_len {
82 best_len = cur_len;
83 best_start = cur_start;
84 }
85 let acs_start = best_start;
86 let acs_end = best_start + best_len;
87
88 if acs_start == ky_lo && acs_end == ky_hi + 1 {
90 return None;
91 }
92 if best_len < 8 {
94 debug!("GRAPPA: ACS too small ({} ky), refusing", best_len);
95 return None;
96 }
97
98 let mut outside: Vec<usize> = sampled
100 .iter()
101 .copied()
102 .filter(|&i| i < acs_start || i >= acs_end)
103 .collect();
104 outside.sort_unstable();
105 if outside.len() < 2 {
106 return None;
107 }
108 let mut diffs: Vec<usize> = outside.windows(2).map(|w| w[1] - w[0]).collect();
109 diffs.sort_unstable();
110 let r = diffs[diffs.len() / 2];
111 if !(2..=8).contains(&r) {
112 debug!("GRAPPA: unsupported acceleration R={}", r);
113 return None;
114 }
115 let agree = diffs.iter().filter(|&&d| d == r).count();
117 if agree * 2 < diffs.len() {
118 debug!("GRAPPA: irregular pattern, R spacing not consistent");
119 return None;
120 }
121
122 Some(Self {
123 r,
124 acs_start,
125 acs_end,
126 ky_lo,
127 ky_hi,
128 })
129 }
130}
131
132pub struct GrappaKernel {
137 pub r: usize,
139 pub kernel_ky: usize,
141 pub kernel_kx: usize,
144 pub nc: usize,
146 pub(crate) weights: Vec<Array2<Complex32>>,
148}
149
150impl GrappaKernel {
151 pub fn calibrate(
157 acs: ArrayView3<Complex32>,
158 r: usize,
159 kernel_ky: usize,
160 kernel_kx: usize,
161 ridge: f32,
162 ) -> Result<Self, GrappaError> {
163 let (nc, ny_acs, nx_acs) = (acs.shape()[0], acs.shape()[1], acs.shape()[2]);
164 if r < 2 {
165 return Err(GrappaError::BadConfig("acceleration must be >= 2"));
166 }
167 if kernel_ky < 2 || !kernel_ky.is_multiple_of(2) {
168 return Err(GrappaError::BadConfig("kernel_ky must be even >= 2"));
169 }
170 if kernel_kx == 0 || kernel_kx.is_multiple_of(2) {
171 return Err(GrappaError::BadConfig("kernel_kx must be odd >= 1"));
172 }
173
174 let ky_span = (kernel_ky - 1) * r + 1;
178 let kx_half = kernel_kx / 2;
179
180 if ny_acs < ky_span + r {
181 return Err(GrappaError::AcsTooSmall {
182 need: ky_span + r,
183 got: ny_acs,
184 });
185 }
186 if nx_acs < kernel_kx {
187 return Err(GrappaError::AcsTooSmall {
188 need: kernel_kx,
189 got: nx_acs,
190 });
191 }
192
193 let n_src = nc * kernel_ky * kernel_kx;
194 let kky_center_src_row = kernel_ky / 2 - 1; let max_target_off = kky_center_src_row * r + (r - 1);
200 let n_ky_pos = ((ny_acs as isize) - (ky_span as isize).max(0))
201 .max(0)
202 .min((ny_acs as isize) - (max_target_off as isize) - 1)
203 .max(0) as usize;
204 let n_kx_pos = nx_acs.saturating_sub(kernel_kx - 1);
205 let n_pos = n_ky_pos * n_kx_pos;
206 if n_pos < n_src {
207 warn!(
208 "GRAPPA calibration under-determined: {} positions < {} sources",
209 n_pos, n_src
210 );
211 }
212 if n_pos == 0 {
213 return Err(GrappaError::AcsTooSmall {
214 need: ky_span + kernel_kx,
215 got: ny_acs.min(nx_acs),
216 });
217 }
218
219 info!(
220 "GRAPPA calibrate: nc={}, R={}, kernel={}x{}, ACS={}x{}, positions={}, sources={}",
221 nc, r, kernel_ky, kernel_kx, ny_acs, nx_acs, n_pos, n_src
222 );
223
224 let n_tgt = nc * (r - 1);
227 let mut a = Array2::<Complex32>::zeros((n_pos, n_src));
228 let mut b = Array2::<Complex32>::zeros((n_pos, n_tgt));
229
230 let mut p = 0usize;
231 for ky0 in 0..n_ky_pos {
232 for kx0 in 0..n_kx_pos {
233 for ch in 0..nc {
237 for kky in 0..kernel_ky {
238 let src_y = ky0 + kky * r;
239 for kkx in 0..kernel_kx {
240 let src_x = kx0 + kkx;
241 let col = ch * (kernel_ky * kernel_kx) + kky * kernel_kx + kkx;
242 a[[p, col]] = acs[[ch, src_y, src_x]];
243 }
244 }
245 }
246 let tgt_x = kx0 + kx_half;
247 for d in 1..r {
248 let tgt_y = ky0 + kky_center_src_row * r + d;
249 for ch in 0..nc {
250 let col = (d - 1) * nc + ch;
251 b[[p, col]] = acs[[ch, tgt_y, tgt_x]];
252 }
253 }
254 p += 1;
255 }
256 }
257 debug_assert_eq!(p, n_pos);
258
259 let ata = hermitian_gram(&a); let atb = hermitian_mul(&a, &b); let mut ata_reg = ata;
267 let lam = {
268 let mean_diag =
270 (0..n_src).map(|i| ata_reg[[i, i]].re).sum::<f32>() / (n_src as f32).max(1.0);
271 ridge * mean_diag.max(f32::EPSILON)
272 };
273 for i in 0..n_src {
274 ata_reg[[i, i]] += Complex32::new(lam, 0.0);
275 }
276
277 let l = cholesky_lower(&ata_reg).ok_or(GrappaError::CholeskyFailed)?;
279 let l_inv = invert_lower_triangular(&l).ok_or(GrappaError::CholeskyFailed)?;
280 let tmp = matmul(&l_inv, &atb); let l_inv_h = conjugate_transpose(&l_inv);
284 let x = matmul(&l_inv_h, &tmp); let mut weights = Vec::with_capacity(r - 1);
287 for d in 1..r {
288 let col_start = (d - 1) * nc;
289 let col_end = d * nc;
290 let block = x.slice(s![.., col_start..col_end]);
292 let mut w = Array2::<Complex32>::zeros((nc, n_src));
293 for ch in 0..nc {
294 for j in 0..n_src {
295 w[[ch, j]] = block[[j, ch]];
296 }
297 }
298 weights.push(w);
299 }
300
301 Ok(Self {
302 r,
303 kernel_ky,
304 kernel_kx,
305 nc,
306 weights,
307 })
308 }
309
310 pub fn synthesize(&self, kspace: &mut Array4<Complex32>) -> Result<(), GrappaError> {
319 let (nc, nz, ny, nx) = (
320 kspace.shape()[0],
321 kspace.shape()[1],
322 kspace.shape()[2],
323 kspace.shape()[3],
324 );
325 if nc != self.nc {
326 return Err(GrappaError::BadConfig(
327 "coil count mismatch between kernel and kspace",
328 ));
329 }
330 let r = self.r;
331 let kx_half = self.kernel_kx / 2;
332 let kky_center = self.kernel_ky / 2 - 1;
333
334 for kz in 0..nz {
336 let ky_sampled: Vec<usize> = (0..ny)
339 .filter(|&ky| {
340 kspace[[0, kz, ky, nx / 2]] != Complex32::new(0.0, 0.0)
342 || any_nonzero(kspace.slice(s![0, kz, ky, ..]))
343 })
344 .collect();
345
346 let mut filled = 0usize;
347 if self.kernel_ky < 2 {
352 continue;
353 }
354 let mut sampled_set = vec![false; ny];
356 for &k in &ky_sampled {
357 sampled_set[k] = true;
358 }
359
360 for ky0 in 0..ny {
364 let ky_span = (self.kernel_ky - 1) * r + 1;
365 if ky0 + ky_span > ny {
366 break;
367 }
368 let mut ok = true;
370 for k in 0..self.kernel_ky {
371 if !sampled_set[ky0 + k * r] {
372 ok = false;
373 break;
374 }
375 }
376 if !ok {
377 continue;
378 }
379 let any_missing = (1..r).any(|d| !sampled_set[ky0 + kky_center * r + d]);
382 if !any_missing {
383 continue;
384 }
385
386 for kx0 in 0..=(nx - self.kernel_kx) {
388 let mut src =
390 Vec::<Complex32>::with_capacity(nc * self.kernel_ky * self.kernel_kx);
391 for ch in 0..nc {
392 for kky in 0..self.kernel_ky {
393 let sy = ky0 + kky * r;
394 for kkx in 0..self.kernel_kx {
395 src.push(kspace[[ch, kz, sy, kx0 + kkx]]);
396 }
397 }
398 }
399 let src = Array1::from(src);
400 for d in 1..r {
402 let ty = ky0 + kky_center * r + d;
403 if sampled_set[ty] {
404 continue;
405 }
406 let w = &self.weights[d - 1];
407 let tx = kx0 + kx_half;
408 for ch in 0..nc {
409 let mut acc = Complex32::new(0.0, 0.0);
411 let row = w.row(ch);
412 for (a, b) in row.iter().zip(src.iter()) {
413 acc += a * b;
414 }
415 kspace[[ch, kz, ty, tx]] = acc;
416 }
417 filled += 1;
418 }
419 }
420 }
421 debug!("GRAPPA synth: slice {} filled {} targets", kz, filled);
422 }
423 Ok(())
424 }
425}
426
427fn any_nonzero<'a>(row: ndarray::ArrayView1<'a, Complex32>) -> bool {
428 row.iter().any(|c| c.re != 0.0 || c.im != 0.0)
429}
430
431fn hermitian_gram(a: &Array2<Complex32>) -> Array2<Complex32> {
433 let (m, n) = (a.nrows(), a.ncols());
434 let mut out = Array2::<Complex32>::zeros((n, n));
435 for i in 0..n {
436 for j in i..n {
437 let mut s = Complex32::new(0.0, 0.0);
438 for k in 0..m {
439 s += a[[k, i]].conj() * a[[k, j]];
440 }
441 out[[i, j]] = s;
442 if i != j {
443 out[[j, i]] = s.conj();
444 }
445 }
446 }
447 out
448}
449
450fn hermitian_mul(a: &Array2<Complex32>, b: &Array2<Complex32>) -> Array2<Complex32> {
452 let (m, n) = (a.nrows(), a.ncols());
453 let p = b.ncols();
454 debug_assert_eq!(b.nrows(), m);
455 let mut out = Array2::<Complex32>::zeros((n, p));
456 for i in 0..n {
457 for j in 0..p {
458 let mut s = Complex32::new(0.0, 0.0);
459 for k in 0..m {
460 s += a[[k, i]].conj() * b[[k, j]];
461 }
462 out[[i, j]] = s;
463 }
464 }
465 out
466}
467
468fn matmul(a: &Array2<Complex32>, b: &Array2<Complex32>) -> Array2<Complex32> {
469 let (m, k) = (a.nrows(), a.ncols());
470 let n = b.ncols();
471 debug_assert_eq!(b.nrows(), k);
472 let mut out = Array2::<Complex32>::zeros((m, n));
473 for i in 0..m {
474 for j in 0..n {
475 let mut s = Complex32::new(0.0, 0.0);
476 for kk in 0..k {
477 s += a[[i, kk]] * b[[kk, j]];
478 }
479 out[[i, j]] = s;
480 }
481 }
482 out
483}
484
485fn conjugate_transpose(a: &Array2<Complex32>) -> Array2<Complex32> {
486 let (m, n) = (a.nrows(), a.ncols());
487 let mut out = Array2::<Complex32>::zeros((n, m));
488 for i in 0..m {
489 for j in 0..n {
490 out[[j, i]] = a[[i, j]].conj();
491 }
492 }
493 out
494}
495
496#[allow(clippy::needless_range_loop)]
499pub fn detect_pattern(mask: &Array3<bool>) -> Option<SamplingPattern> {
500 let ny = mask.shape()[1];
501 let mut ky_any = vec![false; ny];
502 for ky in 0..ny {
503 let slab = mask.slice(s![.., ky, ..]);
504 if slab.iter().any(|&b| b) {
505 ky_any[ky] = true;
506 }
507 }
508 SamplingPattern::detect(&ky_any)
509}
510
511pub fn extract_acs_slice(
514 kspace: &Array4<Complex32>,
515 kz: usize,
516 pattern: &SamplingPattern,
517) -> Array3<Complex32> {
518 let (nc, _, _, nx) = (
519 kspace.shape()[0],
520 kspace.shape()[1],
521 kspace.shape()[2],
522 kspace.shape()[3],
523 );
524 let acs = kspace.slice(s![.., kz, pattern.acs_start..pattern.acs_end, ..]);
525 let mut out = Array3::<Complex32>::zeros((nc, pattern.acs_len(), nx));
526 for c in 0..nc {
527 for y in 0..pattern.acs_len() {
528 for x in 0..nx {
529 out[[c, y, x]] = acs[[c, y, x]];
530 }
531 }
532 }
533 let _ = Axis(0); out
535}
536
537#[non_exhaustive]
539#[derive(Debug, thiserror::Error)]
540pub enum GrappaError {
541 #[error("bad GRAPPA config: {0}")]
542 BadConfig(&'static str),
543 #[error("ACS region too small (need {need}, got {got})")]
544 AcsTooSmall { need: usize, got: usize },
545 #[error("Cholesky factorization failed (matrix not positive definite)")]
546 CholeskyFailed,
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{s, Array4};
    use num_complex::Complex32;

    /// Smooth synthetic multi-coil k-space of shape `(nc, 1, ny, nx)`.
    ///
    /// A centred Gaussian whose amplitude ramps along ky and whose phase shifts with the
    /// coil index. Each coil therefore carries a different but correlated signal.
    fn make_phantom_kspace(nc: usize, ny: usize, nx: usize) -> Array4<Complex32> {
        Array4::from_shape_fn((nc, 1, ny, nx), |(c, _, y, x)| {
            let yy = y as f32 - ny as f32 / 2.0;
            let xx = x as f32 - nx as f32 / 2.0;
            let r2 = (yy * yy + xx * xx) / (ny * nx) as f32;
            let base = (-3.0 * r2).exp();
            let phase = (c as f32) * 0.4 + 0.02 * (yy + xx);
            let amp = 1.0 + 0.3 * (c as f32 - nc as f32 / 2.0) * yy / ny as f32;
            let px = base * amp;
            Complex32::new(px * phase.cos(), px * phase.sin())
        })
    }

    /// Per-ky mask: every `step`-th line starting at 0, plus the contiguous block `acs`.
    fn ky_mask(ny: usize, step: usize, acs: std::ops::Range<usize>) -> Vec<bool> {
        (0..ny).map(|ky| ky % step == 0 || acs.contains(&ky)).collect()
    }

    #[test]
    fn pattern_detect_r2_with_acs() {
        let p = SamplingPattern::detect(&ky_mask(64, 2, 24..40)).expect("pattern detected");
        assert_eq!(p.r, 2);
        assert!((23..=24).contains(&p.acs_start));
        assert!((40..=41).contains(&p.acs_end));
    }

    #[test]
    fn pattern_detect_r3() {
        let p = SamplingPattern::detect(&ky_mask(96, 3, 40..60)).expect("pattern detected");
        assert_eq!(p.r, 3);
        assert!(p.acs_start <= 40);
        assert!(p.acs_end >= 60);
        assert!(p.acs_end - p.acs_start <= 22);
    }

    #[test]
    fn pattern_detect_rejects_fully_sampled() {
        let mask = vec![true; 64];
        assert!(SamplingPattern::detect(&mask).is_none());
    }

    #[test]
    fn acs_len_saturates_on_inverted_bounds() {
        let p = SamplingPattern {
            r: 2,
            acs_start: 10,
            acs_end: 4,
            ky_lo: 0,
            ky_hi: 15,
        };
        assert_eq!(p.acs_len(), 0);
    }

    #[test]
    fn detect_pattern_collapses_other_axes() {
        let ny = 64;
        let mut mask = Array3::from_elem((2, ny, 4), false);
        // Only one partition / readout column carries the pattern.
        for (ky, hit) in ky_mask(ny, 2, 24..40).into_iter().enumerate() {
            mask[[1, ky, 3]] = hit;
        }
        let p = detect_pattern(&mask).expect("pattern detected");
        assert_eq!(p.r, 2);
        assert_eq!((p.acs_start, p.acs_end), (24, 41));
    }

    #[test]
    fn extract_acs_slice_copies_acs_rows() {
        let k = make_phantom_kspace(2, 16, 8);
        let p = SamplingPattern {
            r: 2,
            acs_start: 4,
            acs_end: 10,
            ky_lo: 0,
            ky_hi: 15,
        };
        let acs = extract_acs_slice(&k, 0, &p);
        assert_eq!(acs.dim(), (2, 6, 8));
        for ((c, y, x), &v) in acs.indexed_iter() {
            assert_eq!(v, k[[c, 0, y + 4, x]]);
        }
    }

    #[test]
    fn dense_helpers_agree_with_matmul() {
        let a = Array2::from_shape_fn((3, 2), |(i, j)| {
            Complex32::new(i as f32 + 1.0, j as f32 - 0.5)
        });
        let b = Array2::from_shape_fn((3, 4), |(i, j)| {
            Complex32::new((i * j) as f32, 1.0 - i as f32)
        });

        let ah = conjugate_transpose(&a);
        assert_eq!(ah.dim(), (2, 3));
        for ((i, j), &v) in a.indexed_iter() {
            assert_eq!(ah[[j, i]], v.conj());
        }

        let close = |x: &Array2<Complex32>, y: &Array2<Complex32>| {
            x.dim() == y.dim() && x.iter().zip(y.iter()).all(|(p, q)| (p - q).norm() < 1e-5)
        };
        assert!(close(&hermitian_gram(&a), &matmul(&ah, &a)));
        assert!(close(&hermitian_mul(&a, &b), &matmul(&ah, &b)));
    }

    #[test]
    fn calibrate_rejects_invalid_config() {
        let acs = Array3::<Complex32>::zeros((2, 24, 16));
        let bad_config = |r, kernel_ky, kernel_kx| {
            matches!(
                GrappaKernel::calibrate(acs.view(), r, kernel_ky, kernel_kx, 1e-3),
                Err(GrappaError::BadConfig(_))
            )
        };
        assert!(bad_config(1, 4, 5), "R < 2 must be rejected");
        assert!(bad_config(2, 3, 5), "odd kernel_ky must be rejected");
        assert!(bad_config(2, 4, 4), "even kernel_kx must be rejected");

        // (4 - 1) * 2 + 1 + 2 = 9 ky lines are required; only 8 are available.
        let short = Array3::<Complex32>::zeros((2, 8, 16));
        assert!(matches!(
            GrappaKernel::calibrate(short.view(), 2, 4, 5, 1e-3),
            Err(GrappaError::AcsTooSmall { need: 9, got: 8 })
        ));
    }

    #[test]
    fn grappa_reconstructs_r2_from_fully_sampled_phantom() {
        let (nc, ny, nx) = (4, 40, 32);
        let (r, kernel_ky, kernel_kx) = (2, 4, 5);
        let truth = make_phantom_kspace(nc, ny, nx);

        let (acs_start, acs_end) = (8, 32);
        let in_acs = |y: usize| (acs_start..acs_end).contains(&y);
        let acs = truth.slice(s![.., 0, acs_start..acs_end, ..]).to_owned();
        let kernel = GrappaKernel::calibrate(acs.view(), r, kernel_ky, kernel_kx, 1e-3)
            .expect("calibrate ok");

        // R=2 undersampling plus the ACS block; every other row stays zero.
        let mut us = Array4::<Complex32>::zeros((nc, 1, ny, nx));
        for y in (0..ny).filter(|&y| y % r == 0 || in_acs(y)) {
            us.slice_mut(s![.., 0, y, ..])
                .assign(&truth.slice(s![.., 0, y, ..]));
        }
        kernel.synthesize(&mut us).expect("synthesize failed");

        // Targets sit `(kernel_ky / 2 - 1) * r + 1` lines past a window start, and
        // window starts run over 0..=ny - span. Rows outside [fill_lo, fill_hi] can
        // therefore never be filled.
        let span = (kernel_ky - 1) * r + 1;
        let fill_lo = (kernel_ky / 2 - 1) * r + 1;
        let fill_hi = ny - span + fill_lo;
        let kx_half = kernel_kx / 2;

        let mut max_err: f32 = 0.0;
        let mut sum_truth: f32 = 0.0;
        let mut sum_err: f32 = 0.0;
        let mut n_checked = 0usize;
        for y in (fill_lo..=fill_hi).filter(|&y| y % r != 0 && !in_acs(y)) {
            for c in 0..nc {
                for x in kx_half..(nx - kx_half) {
                    let t = truth[[c, 0, y, x]];
                    let g = us[[c, 0, y, x]];
                    let e = (t - g).norm();
                    max_err = max_err.max(e);
                    sum_err += e * e;
                    sum_truth += t.norm_sqr();
                    n_checked += 1;
                }
            }
        }
        assert!(n_checked > 0, "no fillable rows were checked");
        let nrmse = (sum_err / sum_truth.max(1e-20)).sqrt();
        assert!(
            nrmse < 0.1,
            "GRAPPA NRMSE too high: {:.4} (max |err| = {:.4e}, n={})",
            nrmse,
            max_err,
            n_checked
        );
    }

    #[test]
    fn grappa_nc1_does_not_panic() {
        let (nc, ny, nx) = (1, 20, 16);
        let one = Complex32::new(1.0, 0.0);
        let acs_raw = Array3::from_elem((nc, 12, nx), one);
        // A constant single-coil ACS is degenerate. A calibration error is acceptable
        // here: the test only guards against panics and corruption of sampled rows.
        let Ok(kernel) = GrappaKernel::calibrate(acs_raw.view(), 2, 4, 5, 1e-3) else {
            return;
        };

        let mut us = Array4::<Complex32>::zeros((nc, 1, ny, nx));
        for y in (0..ny).step_by(2) {
            us.slice_mut(s![0, 0, y, ..]).fill(one);
        }
        let before = us.clone();
        // Only panics and sampled-row integrity matter, so the result is ignored.
        let _ = kernel.synthesize(&mut us);
        for y in (0..ny).step_by(2) {
            for x in 0..nx {
                let (want, got) = (before[[0, 0, y, x]], us[[0, 0, y, x]]);
                assert!(
                    (got - want).norm() < 1e-6,
                    "sampled row y={} modified by synthesize",
                    y
                );
            }
        }
    }
}