use ndarray::Array1;
use num_complex::Complex32;
use openkspace_io::ismrmrd::Acquisition;
use std::collections::HashMap;
use tracing::{debug, info};
/// Navigator-derived phase-correction vectors, one per `(slice, segment)` key.
///
/// Built by `PhaseCorrector::from_phasecorr_acqs`; the `Default` value holds no
/// vectors.
#[derive(Clone, Debug, Default)]
pub struct PhaseCorrector {
    /// Keyed by `(idx.slice, idx.segment)`; each value holds one complex factor
    /// per sample.
    corrections: HashMap<(u16, u16), Array1<Complex32>>,
}
impl PhaseCorrector {
    /// Averaged navigator samples whose magnitude is below this threshold carry
    /// no usable phase and receive the identity factor instead.
    const MIN_NAV_MAGNITUDE: f32 = 1e-20;

    /// Builds one phase-correction vector per `(idx.slice, idx.segment)` key
    /// from phase-correction (navigator) acquisitions.
    ///
    /// For each acquisition the data is summed across its channels sample by
    /// sample. These coil sums are then averaged over all acquisitions sharing
    /// the same key. Each averaged sample `v` becomes the unit-magnitude factor
    /// `conj(v) / |v|`. Multiplying a sample of the same phase by this factor
    /// (see [`PhaseCorrector::apply`]) removes that phase.
    ///
    /// * Acquisitions with zero samples or zero channels are skipped.
    /// * The first acquisition seen for a key fixes that key's vector length.
    ///   Later ones with a different sample count are dropped and logged at
    ///   `debug` level.
    /// * Samples whose averaged magnitude is below `1e-20`, or is not finite,
    ///   get the identity factor `1 + 0i`.
    ///
    /// An empty slice yields an empty corrector, for which `apply` is a no-op.
    pub fn from_phasecorr_acqs(pc: &[Acquisition]) -> Self {
        if pc.is_empty() {
            return Self::default();
        }

        // Per key: running sum of coil sums and the number of scans added.
        let mut bins: HashMap<(u16, u16), (Array1<Complex32>, u32)> = HashMap::new();
        for a in pc {
            let ns = a.num_samples();
            let nc = a.num_channels();
            if ns == 0 || nc == 0 {
                continue;
            }

            let key = (a.header.idx.slice, a.header.idx.segment);
            let (sum, count) = bins
                .entry(key)
                .or_insert_with(|| (Array1::<Complex32>::zeros(ns), 0));
            // Check the length before any per-sample work, so scans that are
            // about to be dropped do not pay for a coil sum.
            if sum.len() != ns {
                debug!(
                    "phasecorr: dropping scan with ns={ns} (bin expects {})",
                    sum.len()
                );
                continue;
            }

            // Sum the channels into a scratch vector first, then add that to
            // the bin. This keeps the floating-point summation order
            // "per-scan coil sum, then accumulate".
            let view = a.as_array_view();
            let mut coil_sum = Array1::<Complex32>::zeros(ns);
            for c in 0..nc {
                let row = view.row(c);
                for (acc, &x) in coil_sum.iter_mut().zip(row.iter()) {
                    *acc += x;
                }
            }
            for (acc, &x) in sum.iter_mut().zip(coil_sum.iter()) {
                *acc += x;
            }
            *count += 1;
        }

        // Every bin has `count >= 1`: a bin is only created by a scan whose
        // length matches it, and that scan is always counted.
        let corrections: HashMap<(u16, u16), Array1<Complex32>> = bins
            .into_iter()
            .map(|(key, (mut sum, count))| {
                let inv_n = 1.0 / count as f32;
                sum.mapv_inplace(|v| Self::phase_cancelling_factor(v * inv_n));
                (key, sum)
            })
            .collect();

        info!(
            "Phase correction: built {} navigator vectors",
            corrections.len()
        );
        Self { corrections }
    }

    /// Returns `true` when no correction vectors were built. In that case
    /// `apply` does nothing.
    pub fn is_empty(&self) -> bool {
        self.corrections.is_empty()
    }

    /// Multiplies `acq`'s samples in place by the correction vector for its
    /// `(idx.slice, idx.segment)` key.
    ///
    /// If no vector exists for that key, the vector keyed
    /// `(u16::MAX, idx.segment)` is used instead, if present. Otherwise `acq`
    /// is left unchanged. Only the first `min(num_samples, vector length)`
    /// samples of each channel are corrected; any remaining samples are left
    /// untouched.
    pub fn apply(&self, acq: &mut Acquisition) {
        if self.corrections.is_empty() {
            return;
        }
        let segment = acq.header.idx.segment;
        // NOTE(review): `from_phasecorr_acqs` only creates a `u16::MAX` slice
        // key if a navigator acquisition carries that slice index. This is
        // presumably a "valid for every slice" wildcard; confirm with the
        // producer of the navigator data.
        let corr = match self
            .corrections
            .get(&(acq.header.idx.slice, segment))
            .or_else(|| self.corrections.get(&(u16::MAX, segment)))
        {
            Some(c) => c,
            None => return,
        };

        let len = acq.num_samples().min(corr.len());
        if len == 0 {
            return;
        }
        let nc = acq.num_channels();
        let mut view = acq.as_array_view_mut();
        for c in 0..nc {
            let mut row = view.row_mut(c);
            for (x, &f) in row.iter_mut().zip(corr.iter()).take(len) {
                *x *= f;
            }
        }
    }

    /// Unit-magnitude factor that cancels the phase of `avg`.
    ///
    /// Returns the identity `1 + 0i` when `|avg|` is below
    /// `MIN_NAV_MAGNITUDE` or is not finite (for example NaN). Otherwise a NaN
    /// factor would silently poison every corrected sample at that position.
    fn phase_cancelling_factor(avg: Complex32) -> Complex32 {
        let mag = avg.norm();
        if mag.is_finite() && mag >= Self::MIN_NAV_MAGNITUDE {
            // Divide by the real scalar, not by `Complex32::new(mag, 0.0)`.
            // Complex division divides by `mag * mag`. In f32 that is
            // subnormal near the threshold (losing precision) and overflows
            // to infinity above ~1.8e19 (turning the factor into zero).
            avg.conj() / mag
        } else {
            Complex32::new(1.0, 0.0)
        }
    }
}
#[cfg(test)]
// `mk_acq` fills the nested `idx` fields of a default header one at a time,
// which reads more clearly than nested struct literals.
#[allow(clippy::field_reassign_with_default)]
mod tests {
    use super::*;
    use openkspace_io::ismrmrd::AcquisitionHeader;

    /// Builds an acquisition at `(slice, segment)` holding `data`, declared as
    /// `nc` channels of `data.len() / nc` samples each.
    fn mk_acq(slice: u16, segment: u16, data: &[Complex32], nc: usize) -> Acquisition {
        let ns = data.len() / nc;
        let mut h = AcquisitionHeader::default();
        h.number_of_samples = ns as u16;
        h.active_channels = nc as u16;
        h.idx.slice = slice;
        h.idx.segment = segment;
        let flat: Vec<f32> = data.iter().flat_map(|c| [c.re, c.im]).collect();
        Acquisition::from_raw_f32(h, &flat)
    }

    /// Unit-magnitude complex number `cos(phi) + i*sin(phi)`.
    fn phasor(phi: f32) -> Complex32 {
        let (sin, cos) = phi.sin_cos();
        Complex32::new(cos, sin)
    }

    /// Asserts `z` is approximately the real number `re`. `k` is the sample
    /// index, used only in the failure message.
    fn assert_real(z: Complex32, re: f32, k: usize) {
        assert!(
            (z.re - re).abs() < 1e-4 && z.im.abs() < 1e-4,
            "k={k}: got {z:?}, expected real {re}"
        );
    }

    #[test]
    fn empty_corrector_is_noop() {
        let corr = PhaseCorrector::default();
        assert!(corr.is_empty());
        let mut a = mk_acq(0, 0, &[Complex32::new(1.0, 2.0)], 1);
        corr.apply(&mut a);
        assert_eq!(a.data[0], Complex32::new(1.0, 2.0));
    }

    #[test]
    fn removes_constant_phase_offset() {
        let ns = 8;
        let nc = 1;
        let phi = std::f32::consts::PI / 3.0;
        let nav = vec![phasor(phi); ns];
        let corr = PhaseCorrector::from_phasecorr_acqs(&[mk_acq(0, 0, &nav, nc)]);
        assert!(!corr.is_empty());
        let img: Vec<Complex32> = (1..=ns).map(|k| phasor(phi) * k as f32).collect();
        let mut img_acq = mk_acq(0, 0, &img, nc);
        corr.apply(&mut img_acq);
        for k in 0..ns {
            assert_real(img_acq.data[k], (k + 1) as f32, k);
        }
    }

    #[test]
    fn zero_navigator_gives_identity_correction() {
        let nav = vec![Complex32::new(0.0, 0.0); 4];
        let corr = PhaseCorrector::from_phasecorr_acqs(&[mk_acq(0, 0, &nav, 1)]);
        assert!(!corr.is_empty());
        let img = [
            Complex32::new(1.0, 2.0),
            Complex32::new(-3.0, 0.5),
            Complex32::new(0.0, -1.0),
            Complex32::new(4.0, 4.0),
        ];
        let mut img_acq = mk_acq(0, 0, &img, 1);
        corr.apply(&mut img_acq);
        for (k, &expected) in img.iter().enumerate() {
            assert_eq!(img_acq.data[k], expected, "k={k}");
        }
    }

    #[test]
    fn unmatched_slice_is_left_untouched() {
        let nav = vec![phasor(0.7); 4];
        let corr = PhaseCorrector::from_phasecorr_acqs(&[mk_acq(0, 0, &nav, 1)]);
        let img = vec![phasor(0.7); 4];
        let mut img_acq = mk_acq(1, 0, &img, 1);
        corr.apply(&mut img_acq);
        for (k, &expected) in img.iter().enumerate() {
            assert_eq!(img_acq.data[k], expected, "k={k}");
        }
    }

    #[test]
    fn later_scan_with_different_length_is_dropped() {
        let phi = 0.4;
        let first = vec![phasor(phi); 4];
        let second = vec![phasor(phi + 1.0); 8];
        let corr = PhaseCorrector::from_phasecorr_acqs(&[
            mk_acq(0, 0, &first, 1),
            mk_acq(0, 0, &second, 1),
        ]);
        let mut img_acq = mk_acq(0, 0, &[phasor(phi); 4], 1);
        corr.apply(&mut img_acq);
        for k in 0..4 {
            assert_real(img_acq.data[k], 1.0, k);
        }
    }

    #[test]
    fn samples_beyond_correction_length_are_unchanged() {
        let phi = 1.1;
        let corr = PhaseCorrector::from_phasecorr_acqs(&[mk_acq(0, 0, &[phasor(phi); 2], 1)]);
        let img = [phasor(phi); 4];
        let mut img_acq = mk_acq(0, 0, &img, 1);
        corr.apply(&mut img_acq);
        for k in 0..2 {
            assert_real(img_acq.data[k], 1.0, k);
        }
        for k in 2..4 {
            assert_eq!(img_acq.data[k], img[k], "k={k}");
        }
    }
}