use ndarray::{s, Array2, Array3, Array4};
use num_complex::Complex32;
use rustfft::FftPlanner;
use std::sync::Arc;
use tracing::{debug, info};
#[derive(Debug, Clone, Copy)]
pub struct PartialFourierPlan {
pub ny: usize,
pub ky_dc: usize,
pub ky_lo: usize,
pub ky_hi: usize,
pub sym_half: usize,
pub ramp: usize,
}
impl PartialFourierPlan {
#[allow(clippy::needless_range_loop)]
pub fn detect(mask: &Array3<bool>, ky_dc: usize) -> Option<Self> {
let ny = mask.shape()[1];
if ny < 8 {
return None;
}
let mut ky_any = vec![false; ny];
for ky in 0..ny {
if mask.slice(s![.., ky, ..]).iter().any(|&b| b) {
ky_any[ky] = true;
}
}
let ky_lo = ky_any.iter().position(|&b| b)?;
let ky_hi = ny - 1 - ky_any.iter().rev().position(|&b| b)?;
let gaps = (ky_lo..=ky_hi).filter(|&k| !ky_any[k]).count();
if gaps > 0 {
debug!("PF: {} gaps in sampled range -- not partial Fourier", gaps);
return None;
}
if ky_dc < ky_lo || ky_dc > ky_hi {
debug!("PF: DC outside sampled range");
return None;
}
let below = ky_dc - ky_lo;
let above = ky_hi - ky_dc;
let (shorter, longer) = if below <= above {
(below, above)
} else {
(above, below)
};
if longer == 0 {
return None;
}
let asymmetry = longer as f32 / shorter.max(1) as f32;
if asymmetry < 1.10 {
return None; }
if shorter < 4 {
debug!(
"PF: symmetric region too small ({} < 4) for reliable phase",
shorter
);
return None;
}
let sym_half = shorter;
let ramp = (sym_half / 4).max(2).min(sym_half - 1);
Some(Self {
ny,
ky_dc,
ky_lo,
ky_hi,
sym_half,
ramp,
})
}
}
fn hann_window(n: usize) -> Vec<f32> {
if n == 0 {
return vec![];
}
if n == 1 {
return vec![1.0];
}
(0..n)
.map(|i| {
let x = (i as f32) / ((n - 1) as f32);
0.5 - 0.5 * (std::f32::consts::TAU * x).cos()
})
.collect()
}
fn low_pass_weights(plan: &PartialFourierPlan) -> Vec<f32> {
let mut w = vec![0.0f32; plan.ny];
let lo = plan.ky_dc.saturating_sub(plan.sym_half);
let hi = (plan.ky_dc + plan.sym_half).min(plan.ny - 1);
let width = hi - lo + 1;
let hann = hann_window(width);
for (idx, ky) in (lo..=hi).enumerate() {
w[ky] = hann[idx];
}
w
}
#[allow(clippy::needless_range_loop)]
fn homodyne_weights(plan: &PartialFourierPlan) -> Vec<f32> {
let mut w = vec![0.0f32; plan.ny];
let c = plan.ky_dc as isize;
let sym = plan.sym_half as isize;
let r = plan.ramp as isize;
let below_len = c - plan.ky_lo as isize; let above_len = plan.ky_hi as isize - c; let above_is_long = above_len >= below_len;
for ky in plan.ky_lo..=plan.ky_hi {
let d = ky as isize - c; let ad = d.abs();
let on_long_side = (d >= 0 && above_is_long) || (d <= 0 && !above_is_long);
let weight = if ad <= sym - r {
1.0
} else if ad <= sym + r {
let t = ((ad - (sym - r)) as f32) / ((2 * r) as f32);
let hann_val = 0.5 - 0.5 * (std::f32::consts::TAU * t).cos();
if on_long_side {
1.0 + hann_val
} else {
1.0 - hann_val
}
} else if on_long_side {
2.0
} else {
0.0
};
w[ky] = weight;
}
w
}
fn ifft2_plane(
plane: &mut Array2<Complex32>,
ifft_x: &Arc<dyn rustfft::Fft<f32>>,
ifft_y: &Arc<dyn rustfft::Fft<f32>>,
) {
let (ny, nx) = plane.dim();
let half_y = ny / 2;
for x in 0..nx {
let mut col: Vec<Complex32> = (0..ny).map(|y| plane[[y, x]]).collect();
col.rotate_left(half_y);
for y in 0..ny {
plane[[y, x]] = col[y];
}
}
let half_x = nx / 2;
for y in 0..ny {
let mut row: Vec<Complex32> = (0..nx).map(|x| plane[[y, x]]).collect();
row.rotate_left(half_x);
for x in 0..nx {
plane[[y, x]] = row[x];
}
}
let mut row_buf = vec![Complex32::new(0.0, 0.0); nx];
for y in 0..ny {
for x in 0..nx {
row_buf[x] = plane[[y, x]];
}
ifft_x.process(&mut row_buf);
let scale = 1.0 / (nx as f32);
for x in 0..nx {
plane[[y, x]] = row_buf[x] * scale;
}
}
let mut col_buf = vec![Complex32::new(0.0, 0.0); ny];
for x in 0..nx {
for y in 0..ny {
col_buf[y] = plane[[y, x]];
}
ifft_y.process(&mut col_buf);
let scale = 1.0 / (ny as f32);
for y in 0..ny {
plane[[y, x]] = col_buf[y] * scale;
}
}
for x in 0..nx {
let mut col: Vec<Complex32> = (0..ny).map(|y| plane[[y, x]]).collect();
col.rotate_left(ny - half_y);
for y in 0..ny {
plane[[y, x]] = col[y];
}
}
for y in 0..ny {
let mut row: Vec<Complex32> = (0..nx).map(|x| plane[[y, x]]).collect();
row.rotate_left(nx - half_x);
for x in 0..nx {
plane[[y, x]] = row[x];
}
}
}
pub fn homodyne_reconstruct(kspace: &Array4<Complex32>, plan: &PartialFourierPlan) -> Array3<f32> {
let (nc, nz, ny, nx) = (
kspace.shape()[0],
kspace.shape()[1],
kspace.shape()[2],
kspace.shape()[3],
);
assert_eq!(ny, plan.ny, "plan ny does not match kspace shape");
info!(
"Partial Fourier (homodyne): ny={}, ky=[{}, {}] (dc={}), sym_half={}, ramp={}",
ny, plan.ky_lo, plan.ky_hi, plan.ky_dc, plan.sym_half, plan.ramp
);
let h_weights = homodyne_weights(plan);
let l_weights = low_pass_weights(plan);
let mut planner = FftPlanner::<f32>::new();
let ifft_x = planner.plan_fft_inverse(nx);
let ifft_y = planner.plan_fft_inverse(ny);
let mut out = Array3::<f32>::zeros((nz, ny, nx));
for kz in 0..nz {
let mut rss_sq = Array2::<f32>::zeros((ny, nx));
for ch in 0..nc {
let mut homo = Array2::<Complex32>::zeros((ny, nx));
let mut lowp = Array2::<Complex32>::zeros((ny, nx));
for y in 0..ny {
let hw = h_weights[y];
let lw = l_weights[y];
for x in 0..nx {
let k = kspace[[ch, kz, y, x]];
homo[[y, x]] = k * hw;
lowp[[y, x]] = k * lw;
}
}
ifft2_plane(&mut homo, &ifft_x, &ifft_y);
ifft2_plane(&mut lowp, &ifft_x, &ifft_y);
for y in 0..ny {
for x in 0..nx {
let phase = lowp[[y, x]];
let norm = phase.norm();
let demod = if norm > 1e-12 {
Complex32::new(phase.re / norm, -phase.im / norm)
} else {
Complex32::new(1.0, 0.0)
};
let img = homo[[y, x]] * demod;
let m = img.re;
rss_sq[[y, x]] += m * m;
}
}
}
for y in 0..ny {
for x in 0..nx {
out[[kz, y, x]] = rss_sq[[y, x]].sqrt();
}
}
}
out
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array3;
#[test]
fn detect_rejects_symmetric_mask() {
let mask = Array3::<bool>::from_elem((1, 64, 4), true);
assert!(PartialFourierPlan::detect(&mask, 32).is_none());
}
#[test]
fn detect_rejects_regular_undersampling() {
let mut mask = Array3::<bool>::from_elem((1, 64, 4), false);
for ky in (0..64).step_by(2) {
for x in 0..4 {
mask[[0, ky, x]] = true;
}
}
assert!(PartialFourierPlan::detect(&mask, 32).is_none());
}
#[test]
fn detect_flags_6_8_partial_fourier() {
let mut mask = Array3::<bool>::from_elem((1, 64, 4), false);
for ky in 8..64 {
for x in 0..4 {
mask[[0, ky, x]] = true;
}
}
let plan = PartialFourierPlan::detect(&mask, 32).expect("PF detected");
assert_eq!(plan.ky_lo, 8);
assert_eq!(plan.ky_hi, 63);
assert_eq!(plan.ky_dc, 32);
assert_eq!(plan.sym_half, 24);
}
#[test]
fn homodyne_weights_sum_plausible() {
let mut mask = Array3::<bool>::from_elem((1, 128, 2), false);
for ky in 16..128 {
mask[[0, ky, 0]] = true;
mask[[0, ky, 1]] = true;
}
let plan = PartialFourierPlan::detect(&mask, 64).expect("PF");
let w = homodyne_weights(&plan);
assert!((0..16).all(|k| w[k] == 0.0));
assert!((w[64] - 1.0).abs() < 1e-4);
assert!((w[127] - 2.0).abs() < 1e-4);
}
#[test]
fn homodyne_beats_zero_fill_on_real_phantom() {
use rustfft::num_complex::Complex32 as C32;
let ny = 64;
let nx = 32;
let mut truth = Array2::<f32>::zeros((ny, nx));
let mut complex_img = Array2::<C32>::zeros((ny, nx));
for y in 0..ny {
for x in 0..nx {
let yy = y as f32 - ny as f32 / 2.0;
let xx = x as f32 - nx as f32 / 2.0;
let r2 = (yy * yy) / 90.0 + (xx * xx) / 40.0;
let mag = (-r2).exp();
let phase = 0.35 * (yy / ny as f32) + 0.25 * (xx / nx as f32);
truth[[y, x]] = mag;
complex_img[[y, x]] = C32::new(mag * phase.cos(), mag * phase.sin());
}
}
let mut planner = FftPlanner::<f32>::new();
let fft_x = planner.plan_fft_forward(nx);
let fft_y = planner.plan_fft_forward(ny);
let mut full_k = Array2::<C32>::zeros((ny, nx));
for y in 0..ny {
for x in 0..nx {
full_k[[y, x]] = complex_img[[y, x]];
}
}
let half_y = ny / 2;
let half_x = nx / 2;
for x in 0..nx {
let mut col: Vec<C32> = (0..ny).map(|y| full_k[[y, x]]).collect();
col.rotate_left(half_y);
for y in 0..ny {
full_k[[y, x]] = col[y];
}
}
for y in 0..ny {
let mut row: Vec<C32> = (0..nx).map(|x| full_k[[y, x]]).collect();
row.rotate_left(half_x);
for x in 0..nx {
full_k[[y, x]] = row[x];
}
}
let mut row_buf = vec![C32::new(0.0, 0.0); nx];
for y in 0..ny {
for x in 0..nx {
row_buf[x] = full_k[[y, x]];
}
fft_x.process(&mut row_buf);
for x in 0..nx {
full_k[[y, x]] = row_buf[x];
}
}
let mut col_buf = vec![C32::new(0.0, 0.0); ny];
for x in 0..nx {
for y in 0..ny {
col_buf[y] = full_k[[y, x]];
}
fft_y.process(&mut col_buf);
for y in 0..ny {
full_k[[y, x]] = col_buf[y];
}
}
for x in 0..nx {
let mut col: Vec<C32> = (0..ny).map(|y| full_k[[y, x]]).collect();
col.rotate_left(ny - half_y);
for y in 0..ny {
full_k[[y, x]] = col[y];
}
}
for y in 0..ny {
let mut row: Vec<C32> = (0..nx).map(|x| full_k[[y, x]]).collect();
row.rotate_left(nx - half_x);
for x in 0..nx {
full_k[[y, x]] = row[x];
}
}
let ky_lo = 16;
let mut partial_k = Array4::<C32>::zeros((1, 1, ny, nx));
let mut mask = Array3::<bool>::from_elem((1, ny, nx), false);
for y in 0..ny {
for x in 0..nx {
if y >= ky_lo {
partial_k[[0, 0, y, x]] = full_k[[y, x]];
mask[[0, y, x]] = true;
}
}
}
let plan = PartialFourierPlan::detect(&mask, ny / 2).expect("PF plan");
let recon = homodyne_reconstruct(&partial_k, &plan);
let mut zf = Array2::<C32>::zeros((ny, nx));
for y in 0..ny {
for x in 0..nx {
zf[[y, x]] = partial_k[[0, 0, y, x]];
}
}
let ifft_x = planner.plan_fft_inverse(nx);
let ifft_y = planner.plan_fft_inverse(ny);
ifft2_plane(&mut zf, &ifft_x, &ifft_y);
let y0 = ny / 4;
let y1 = 3 * ny / 4;
let x0 = nx / 4;
let x1 = 3 * nx / 4;
let mut t_energy = 0.0f32;
let mut h_energy = 0.0f32;
let mut z_energy = 0.0f32;
for y in y0..y1 {
for x in x0..x1 {
t_energy += truth[[y, x]] * truth[[y, x]];
let h = recon[[0, y, x]];
h_energy += h * h;
let z = zf[[y, x]].norm();
z_energy += z * z;
}
}
let h_scale = (t_energy / h_energy.max(1e-20)).sqrt();
let z_scale = (t_energy / z_energy.max(1e-20)).sqrt();
let mut h_err = 0.0f32;
let mut z_err = 0.0f32;
let mut t_norm = 0.0f32;
for y in y0..y1 {
for x in x0..x1 {
let t = truth[[y, x]];
let h = recon[[0, y, x]] * h_scale;
let z = zf[[y, x]].norm() * z_scale;
h_err += (t - h) * (t - h);
z_err += (t - z) * (t - z);
t_norm += t * t;
}
}
let h_nrmse = (h_err / t_norm.max(1e-20)).sqrt();
let z_nrmse = (z_err / t_norm.max(1e-20)).sqrt();
assert!(
h_nrmse < 0.2,
"homodyne NRMSE {:.4} should be small (zf NRMSE {:.4})",
h_nrmse,
z_nrmse
);
}
}