use num_complex::Complex32;
use openkspace_io::ismrmrd::Acquisition;
use rustfft::{Fft, FftPlanner};
use std::sync::Arc;
use tracing::info;
#[derive(Clone)]
pub struct OversamplingRemover {
encoded_ns: usize,
recon_ns: usize,
ratio: usize,
ifft_plan: Arc<dyn Fft<f32>>,
fft_plan: Arc<dyn Fft<f32>>,
}
impl std::fmt::Debug for OversamplingRemover {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OversamplingRemover")
.field("encoded_ns", &self.encoded_ns)
.field("recon_ns", &self.recon_ns)
.field("ratio", &self.ratio)
.finish()
}
}
impl OversamplingRemover {
pub fn new(encoded_x: usize, recon_x: usize) -> Option<Self> {
if recon_x == 0 || encoded_x <= recon_x {
return None;
}
if !encoded_x.is_multiple_of(recon_x) {
return None;
}
let ratio = encoded_x / recon_x;
if ratio > 8 {
return None;
}
let mut planner = FftPlanner::<f32>::new();
Some(Self {
encoded_ns: encoded_x,
recon_ns: recon_x,
ratio,
ifft_plan: planner.plan_fft_inverse(encoded_x),
fft_plan: planner.plan_fft_forward(recon_x),
})
}
#[inline]
pub fn ratio(&self) -> usize {
self.ratio
}
#[inline]
pub fn output_samples(&self) -> usize {
self.recon_ns
}
#[inline]
pub fn input_samples(&self) -> usize {
self.encoded_ns
}
pub fn log_summary(&self) {
info!(
"Readout oversampling removal: {} -> {} samples (ratio {}x)",
self.encoded_ns, self.recon_ns, self.ratio
);
}
pub fn apply(&self, acq: &mut Acquisition) {
let ns = acq.num_samples();
if ns != self.encoded_ns {
return;
}
let nc = acq.num_channels();
let ifft_scratch_len = self.ifft_plan.get_inplace_scratch_len();
let fft_scratch_len = self.fft_plan.get_inplace_scratch_len();
let mut scratch: Vec<Complex32> =
vec![Complex32::new(0.0, 0.0); ifft_scratch_len.max(fft_scratch_len)];
let ifft_scale = 1.0f32 / (self.encoded_ns as f32);
let crop_start = (self.encoded_ns - self.recon_ns) / 2;
let mut new_data: Vec<Complex32> = Vec::with_capacity(nc * self.recon_ns);
let mut lane_buf: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); self.encoded_ns];
let mut cropped: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); self.recon_ns];
for ch in 0..nc {
let src = acq.channel(ch);
lane_buf.copy_from_slice(src);
rotate_left(&mut lane_buf, self.encoded_ns.div_ceil(2));
self.ifft_plan
.process_with_scratch(&mut lane_buf, &mut scratch[..ifft_scratch_len]);
for v in lane_buf.iter_mut() {
*v *= ifft_scale;
}
rotate_left(&mut lane_buf, self.encoded_ns / 2);
cropped.copy_from_slice(&lane_buf[crop_start..crop_start + self.recon_ns]);
rotate_left(&mut cropped, self.recon_ns.div_ceil(2));
self.fft_plan
.process_with_scratch(&mut cropped, &mut scratch[..fft_scratch_len]);
rotate_left(&mut cropped, self.recon_ns / 2);
new_data.extend_from_slice(&cropped);
}
acq.data = new_data;
acq.header.number_of_samples = self.recon_ns as u16;
acq.header.center_sample = (self.recon_ns / 2) as u16;
acq.header.discard_pre = 0;
acq.header.discard_post = 0;
}
}
fn rotate_left(buf: &mut [Complex32], k: usize) {
let n = buf.len();
if n < 2 || k == 0 || k >= n {
return;
}
let mut tmp: Vec<Complex32> = Vec::with_capacity(n);
for i in 0..n {
tmp.push(buf[(i + k) % n]);
}
buf.copy_from_slice(&tmp);
}
#[cfg(test)]
mod tests {
use super::*;
use openkspace_io::ismrmrd::AcquisitionHeader;
fn zeroed_header(ns: usize, nc: usize) -> AcquisitionHeader {
AcquisitionHeader {
version: 1,
flags: 0,
measurement_uid: 0,
scan_counter: 0,
acquisition_time_stamp: 0,
physiology_time_stamp: [0; 3],
number_of_samples: ns as u16,
available_channels: nc as u16,
active_channels: nc as u16,
channel_mask: [0; 16],
discard_pre: 0,
discard_post: 0,
center_sample: (ns / 2) as u16,
encoding_space_ref: 0,
trajectory_dimensions: 0,
sample_time_us: 0.0,
position: [0.0; 3],
read_dir: [0.0; 3],
phase_dir: [0.0; 3],
slice_dir: [0.0; 3],
patient_table_position: [0.0; 3],
idx: openkspace_io::ismrmrd::EncodingCounters {
kspace_encode_step_1: 0,
kspace_encode_step_2: 0,
average: 0,
slice: 0,
contrast: 0,
phase: 0,
repetition: 0,
set: 0,
segment: 0,
user: [0; 8],
},
user_int: [0; 8],
user_float: [0.0; 8],
}
}
#[test]
fn rejects_non_oversampled() {
assert!(OversamplingRemover::new(256, 256).is_none());
assert!(OversamplingRemover::new(256, 512).is_none());
assert!(OversamplingRemover::new(256, 0).is_none());
}
#[test]
fn rejects_non_integer_ratio() {
assert!(OversamplingRemover::new(300, 256).is_none());
}
#[test]
fn accepts_integer_ratios() {
let r = OversamplingRemover::new(512, 256).unwrap();
assert_eq!(r.ratio(), 2);
assert_eq!(r.input_samples(), 512);
assert_eq!(r.output_samples(), 256);
}
#[test]
fn preserves_dc_only_signal() {
let encoded = 16;
let recon = 8;
let nc = 1;
let header = zeroed_header(encoded, nc);
let mut data = vec![Complex32::new(0.0, 0.0); encoded];
data[encoded / 2] = Complex32::new(10.0, 0.0);
let mut acq = Acquisition { header, data };
let remover = OversamplingRemover::new(encoded, recon).unwrap();
remover.apply(&mut acq);
assert_eq!(acq.num_samples(), recon);
let expected = Complex32::new(10.0 * (recon as f32 / encoded as f32), 0.0);
for i in 0..recon {
let v = acq.data[i];
if i == recon / 2 {
assert!(
(v - expected).norm() < 1e-4,
"centre: got {:?}, expected {:?}",
v,
expected
);
} else {
assert!(
v.norm() < 1e-4,
"off-centre {}: expected ~0, got {:?}",
i,
v
);
}
}
}
#[test]
fn leaves_mismatched_acquisitions_untouched() {
let header = zeroed_header(64, 2);
let data = vec![Complex32::new(1.0, 2.0); 64 * 2];
let mut acq = Acquisition {
header,
data: data.clone(),
};
let remover = OversamplingRemover::new(512, 256).unwrap();
remover.apply(&mut acq);
assert_eq!(acq.num_samples(), 64);
assert_eq!(acq.data, data);
}
}