use crate::error::{FFTError, FFTResult};
use crate::nufft::{fft_internal, gaussian_correction, gaussian_kernel, ifft_internal, OVERSAMPLE};
use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::Zero;
use std::f64::consts::PI;
/// Two-dimensional type-1 NUFFT: spreads the nonuniform samples `c`, located at
/// the points `xy`, onto an oversampled periodic grid, transforms that grid, and
/// returns the central `n_modes.0 × n_modes.1` Fourier modes.
///
/// Output element `(r, s)` holds the mode with integer frequency
/// `(r - n1/2, s - n2/2)` (integer division), so the zero frequency sits at
/// index `(n1/2, n2/2)`. Every mode is rescaled by `gaussian_correction` along
/// both axes to undo the spreading kernel. `eps` only influences the kernel
/// half-width (via `kernel_half_width_2d`).
///
/// # Errors
///
/// Returns whatever `validate_inputs_2d` rejects (mismatched `xy`/`c` lengths,
/// a zero mode count, an invalid `eps`, coordinates outside `[-π, π)`), or an
/// error propagated from the 2-D FFT.
///
/// NOTE(review): grid cell 0 corresponds to x = -π (see `spread_type1_2d`);
/// whether the resulting per-mode phase factor is accounted for elsewhere
/// (e.g. inside `gaussian_correction`) is not visible here — confirm.
pub fn nufft2_type1(
    xy: &[(f64, f64)],
    c: &[Complex64],
    n_modes: (usize, usize),
    eps: f64,
) -> FFTResult<Array2<Complex64>> {
    validate_inputs_2d(xy, c, n_modes, eps)?;

    let (n1, n2) = n_modes;
    let sigma = OVERSAMPLE;
    let (ng1, ng2) = (
        oversample_grid_size(sigma, n1),
        oversample_grid_size(sigma, n2),
    );
    let half_w = kernel_half_width_2d(sigma, eps);

    // Spread onto the fine grid, then move to the frequency domain.
    let mut fine = vec![Complex64::zero(); ng1 * ng2];
    spread_type1_2d(xy, c, &mut fine, ng1, ng2, sigma, half_w);
    let spectrum = fft2d_row_col(&fine, ng1, ng2)?;

    let (shift1, shift2) = ((n1 / 2) as i64, (n2 / 2) as i64);
    let modes = Array2::from_shape_fn((n1, n2), |(r, s)| {
        let freq1 = r as i64 - shift1;
        let freq2 = s as i64 - shift2;
        // Negative frequencies are stored at the top of each FFT axis.
        let row = freq1.rem_euclid(ng1 as i64) as usize;
        let col = freq2.rem_euclid(ng2 as i64) as usize;
        spectrum[row * ng2 + col]
            * gaussian_correction(freq1, ng1, sigma)
            * gaussian_correction(freq2, ng2, sigma)
    });
    Ok(modes)
}
/// Two-dimensional type-2 NUFFT: evaluates the Fourier series whose
/// coefficients are `f_hat` at each nonuniform point in `xy`.
///
/// `f_hat[(r, s)]` is the coefficient of integer frequency
/// `(r - n1/2, s - n2/2)` (integer division), the same layout produced by
/// [`nufft2_type1`]. The coefficients are pre-scaled by `gaussian_correction`,
/// placed on an oversampled frequency grid, inverse transformed, and finally
/// interpolated to every point with the Gaussian kernel. The result holds one
/// value per point, in the order of `xy`.
///
/// # Errors
///
/// * `FFTError::DimensionError` if either dimension of `f_hat` is zero.
/// * `FFTError::ValueError` if `eps` is NaN or not strictly positive.
/// * `FFTError::ValueError` if any coordinate lies outside `[-π, π)`
///   (NaN coordinates are rejected as well).
/// * Any error propagated from the inverse 2-D FFT.
pub fn nufft2_type2(
    f_hat: &Array2<Complex64>,
    xy: &[(f64, f64)],
    eps: f64,
) -> FFTResult<Vec<Complex64>> {
    let (n1, n2) = f_hat.dim();
    if n1 == 0 || n2 == 0 {
        return Err(FFTError::DimensionError(
            "f_hat must have non-zero dimensions".to_string(),
        ));
    }
    // `eps <= 0.0` alone is false for NaN, which would otherwise slip through
    // and yield a meaningless kernel width.
    if eps.is_nan() || eps <= 0.0 {
        return Err(FFTError::ValueError("eps must be positive".to_string()));
    }
    // `Range::contains` is false for NaN, so NaN coordinates fail this check.
    let in_range = |v: f64| (-PI..PI).contains(&v);
    if !xy.iter().all(|&(xj, yj)| in_range(xj) && in_range(yj)) {
        return Err(FFTError::ValueError(
            "all xy coordinates must lie in [-π, π)".to_string(),
        ));
    }

    let sigma = OVERSAMPLE;
    let ng1 = oversample_grid_size(sigma, n1);
    let ng2 = oversample_grid_size(sigma, n2);
    let half_w = kernel_half_width_2d(sigma, eps);
    let half1 = (n1 / 2) as i64;
    let half2 = (n2 / 2) as i64;

    // Load the deconvolved coefficients into the oversampled frequency grid.
    // Negative frequencies wrap to the top of each axis; all other bins stay zero.
    let mut grid_freq = vec![Complex64::zero(); ng1 * ng2];
    for ((r, s), &coef) in f_hat.indexed_iter() {
        let k1 = r as i64 - half1;
        let k2 = s as i64 - half2;
        let corr1 = gaussian_correction(k1, ng1, sigma);
        let corr2 = gaussian_correction(k2, ng2, sigma);
        let bin1 = k1.rem_euclid(ng1 as i64) as usize;
        let bin2 = k2.rem_euclid(ng2 as i64) as usize;
        grid_freq[bin1 * ng2 + bin2] = coef * corr1 * corr2;
    }

    let grid_time = ifft2d_row_col(&grid_freq, ng1, ng2)?;
    Ok(interpolate_type2_2d(xy, &grid_time, ng1, ng2, sigma, half_w))
}
/// Length of the oversampled grid for `n` modes: `ceil(sigma * n)`, rounded up
/// to the next even number when that value is odd.
fn oversample_grid_size(sigma: f64, n: usize) -> usize {
    let scaled = (sigma * n as f64).ceil() as usize;
    // `scaled % 2` is 1 exactly when `scaled` is odd.
    scaled + scaled % 2
}
/// Number of grid cells visited on each side of a point when spreading or
/// interpolating: `ceil(sigma * sqrt(-ln(eps) / π²))`, never less than 2.
///
/// For `eps >= 1` or NaN `eps` the intermediate value is non-positive or NaN.
/// The saturating float-to-`usize` cast then yields 0, so the floor of 2 applies.
fn kernel_half_width_2d(sigma: f64, eps: f64) -> usize {
    const MIN_HALF_WIDTH: usize = 2;
    let log_inv_eps = -eps.ln();
    let width = sigma * (log_inv_eps / (PI * PI)).sqrt();
    MIN_HALF_WIDTH.max(width.ceil() as usize)
}
/// Validates the arguments of [`nufft2_type1`].
///
/// # Errors
///
/// Checks run in this order, and the first failure is returned:
/// * `FFTError::DimensionError` if `xy` and `c` have different lengths.
/// * `FFTError::ValueError` if either entry of `n_modes` is zero.
/// * `FFTError::ValueError` if `eps` is NaN or not strictly positive.
/// * `FFTError::ValueError` if any coordinate lies outside `[-π, π)`
///   (NaN coordinates fail this check as well).
fn validate_inputs_2d(
    xy: &[(f64, f64)],
    c: &[Complex64],
    n_modes: (usize, usize),
    eps: f64,
) -> FFTResult<()> {
    if xy.len() != c.len() {
        return Err(FFTError::DimensionError(
            "xy and c must have the same length".to_string(),
        ));
    }
    if n_modes.0 == 0 || n_modes.1 == 0 {
        return Err(FFTError::ValueError(
            "n_modes dimensions must be > 0".to_string(),
        ));
    }
    // A plain `eps <= 0.0` is false for NaN, so NaN must be rejected explicitly.
    if eps.is_nan() || eps <= 0.0 {
        return Err(FFTError::ValueError("eps must be positive".to_string()));
    }
    // `Range::contains` is false for NaN, so non-finite coordinates are caught here.
    let in_range = |v: f64| (-PI..PI).contains(&v);
    if xy.iter().any(|&(xj, yj)| !in_range(xj) || !in_range(yj)) {
        return Err(FFTError::ValueError(
            "all xy coordinates must lie in [-π, π)".to_string(),
        ));
    }
    Ok(())
}
/// Spreads each nonuniform sample onto the oversampled `ng1 × ng2` grid, which
/// is stored row-major as `grid[row * ng2 + col]`.
///
/// A point `(x, y)` is mapped to fractional grid units `(x + π) / h` with
/// `h = 2π / ng`, so grid cell 0 corresponds to `-π`. Starting from the floor of
/// that position, the offsets `-half_w..=half_w` are visited on each axis with
/// periodic wrap-around. `c_j` is added to each visited cell, weighted by the
/// product of the `gaussian_kernel` weights in x and y.
///
/// `xy` and `c` are zipped, so any excess entries in the longer slice are ignored.
fn spread_type1_2d(
    xy: &[(f64, f64)],
    c: &[Complex64],
    grid: &mut [Complex64],
    ng1: usize,
    ng2: usize,
    sigma: f64,
    half_w: usize,
) {
    let h1 = 2.0 * PI / ng1 as f64;
    let h2 = 2.0 * PI / ng2 as f64;
    let half_w_i = half_w as isize;
    let taps = 2 * half_w + 1;
    // Weight buffers are reused across points rather than allocating two Vecs per sample.
    let mut wx: Vec<f64> = Vec::with_capacity(taps);
    let mut wy: Vec<f64> = Vec::with_capacity(taps);

    for (&(xj, yj), &cj) in xy.iter().zip(c) {
        let xg = (xj + PI) / h1;
        let yg = (yj + PI) / h2;
        let ix0 = xg.floor() as isize;
        let iy0 = yg.floor() as isize;

        wx.clear();
        wx.extend(
            (-half_w_i..=half_w_i).map(|di| gaussian_kernel(xg - (ix0 + di) as f64, sigma)),
        );
        wy.clear();
        wy.extend(
            (-half_w_i..=half_w_i).map(|dj| gaussian_kernel(yg - (iy0 + dj) as f64, sigma)),
        );

        for (di, &wxd) in (-half_w_i..=half_w_i).zip(&wx) {
            // Wrap periodically: a kernel footprint crossing ±π continues on the other side.
            let row_base = (ix0 + di).rem_euclid(ng1 as isize) as usize * ng2;
            for (dj, &wyd) in (-half_w_i..=half_w_i).zip(&wy) {
                let col = (iy0 + dj).rem_euclid(ng2 as isize) as usize;
                grid[row_base + col] += cj * (wxd * wyd);
            }
        }
    }
}
/// Interpolates the row-major `ng1 × ng2` grid (`grid[row * ng2 + col]`) to
/// each point in `xy`, returning one value per point in input order.
///
/// It uses the same coordinate mapping as spreading: `(x + π) / h` with
/// `h = 2π / ng`, so cell 0 corresponds to `-π`. Each output is the sum over the
/// `(2 * half_w + 1)²` neighbouring cells around the floored position, with
/// periodic wrap-around. Each cell is weighted by the product of the
/// `gaussian_kernel` weights in x and y.
fn interpolate_type2_2d(
    xy: &[(f64, f64)],
    grid: &[Complex64],
    ng1: usize,
    ng2: usize,
    sigma: f64,
    half_w: usize,
) -> Vec<Complex64> {
    let h1 = 2.0 * PI / ng1 as f64;
    let h2 = 2.0 * PI / ng2 as f64;
    let half_w_i = half_w as isize;
    let taps = 2 * half_w + 1;
    // Weight buffers are reused across points rather than allocating two Vecs per point.
    let mut wx: Vec<f64> = Vec::with_capacity(taps);
    let mut wy: Vec<f64> = Vec::with_capacity(taps);
    let mut out = Vec::with_capacity(xy.len());

    for &(xj, yj) in xy {
        let xg = (xj + PI) / h1;
        let yg = (yj + PI) / h2;
        let ix0 = xg.floor() as isize;
        let iy0 = yg.floor() as isize;

        wx.clear();
        wx.extend(
            (-half_w_i..=half_w_i).map(|di| gaussian_kernel(xg - (ix0 + di) as f64, sigma)),
        );
        wy.clear();
        wy.extend(
            (-half_w_i..=half_w_i).map(|dj| gaussian_kernel(yg - (iy0 + dj) as f64, sigma)),
        );

        // Accumulate in the same row-then-column order to keep results bit-identical.
        let mut acc = Complex64::zero();
        for (di, &wxd) in (-half_w_i..=half_w_i).zip(&wx) {
            let row_base = (ix0 + di).rem_euclid(ng1 as isize) as usize * ng2;
            for (dj, &wyd) in (-half_w_i..=half_w_i).zip(&wy) {
                let col = (iy0 + dj).rem_euclid(ng2 as isize) as usize;
                acc += grid[row_base + col] * (wxd * wyd);
            }
        }
        out.push(acc);
    }
    out
}
/// Forward 2-D DFT of a row-major `ng1 × ng2` buffer (`data[row * ng2 + col]`).
///
/// The transform is separable: every row is transformed with `fft_internal`,
/// then every column of that intermediate result. Errors from `fft_internal`
/// are propagated.
fn fft2d_row_col(data: &[Complex64], ng1: usize, ng2: usize) -> FFTResult<Vec<Complex64>> {
    let mut out = data.to_vec();

    // Pass 1: each row is contiguous, so transform it and write it straight back.
    for start in (0..ng1).map(|r| r * ng2) {
        let end = start + ng2;
        let row: Vec<Complex64> = out[start..end].to_vec();
        let transformed = fft_internal(&row)?;
        out[start..end].copy_from_slice(&transformed);
    }

    // Pass 2: columns are strided, so gather into a reused buffer, transform, scatter back.
    let mut column: Vec<Complex64> = Vec::with_capacity(ng1);
    for s in 0..ng2 {
        column.clear();
        column.extend((0..ng1).map(|r| out[r * ng2 + s]));
        let transformed = fft_internal(&column)?;
        for (r, value) in transformed.into_iter().enumerate() {
            out[r * ng2 + s] = value;
        }
    }

    Ok(out)
}
fn ifft2d_row_col(data: &[Complex64], ng1: usize, ng2: usize) -> FFTResult<Vec<Complex64>> {
let mut buf = data.to_vec();
let scale = 1.0 / (ng1 * ng2) as f64;
for r in 0..ng1 {
let row_start = r * ng2;
let row: Vec<Complex64> = buf[row_start..row_start + ng2]
.iter()
.map(|c| c.conj())
.collect();
let row_fft = fft_internal(&row)?;
for (s, val) in row_fft.into_iter().enumerate() {
buf[row_start + s] = val.conj();
}
}
for s in 0..ng2 {
let col: Vec<Complex64> = (0..ng1).map(|r| buf[r * ng2 + s].conj()).collect();
let col_fft = fft_internal(&col)?;
for (r, val) in col_fft.into_iter().enumerate() {
buf[r * ng2 + s] = val.conj() * scale;
}
}
Ok(buf)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds an `n1 × n2` tensor grid of points equally spaced over `[-π, π)`
    /// in each coordinate, ordered row-major (x outer, y inner).
    fn uniform_grid(n1: usize, n2: usize) -> Vec<(f64, f64)> {
        let mut pts = Vec::with_capacity(n1 * n2);
        for i in 0..n1 {
            for j in 0..n2 {
                let x = -PI + 2.0 * PI * i as f64 / n1 as f64;
                let y = -PI + 2.0 * PI * j as f64 / n2 as f64;
                pts.push((x, y));
            }
        }
        pts
    }

    #[test]
    fn test_nufft2_type1_output_shape() {
        let pts = uniform_grid(8, 8);
        let c: Vec<Complex64> = vec![Complex64::new(1.0, 0.0); pts.len()];
        let f_hat = nufft2_type1(&pts, &c, (8, 8), 1e-6).expect("type1");
        assert_eq!(f_hat.shape(), &[8, 8]);
    }

    #[test]
    fn test_nufft2_type2_output_length() {
        let n1 = 8usize;
        let n2 = 8usize;
        let f_hat = Array2::from_elem((n1, n2), Complex64::new(1.0, 0.0));
        let pts = uniform_grid(4, 4);
        let vals = nufft2_type2(&f_hat, &pts, 1e-6).expect("type2");
        assert_eq!(vals.len(), pts.len());
    }

    /// Constant coefficients on a uniform grid should concentrate their energy
    /// in the zero-frequency mode at index `(n1/2, n2/2)`.
    #[test]
    fn test_nufft2_type1_dc_impulse() {
        let n1 = 8usize;
        let n2 = 8usize;
        let pts = uniform_grid(n1, n2);
        let c: Vec<Complex64> = vec![Complex64::new(1.0, 0.0); pts.len()];
        let f_hat = nufft2_type1(&pts, &c, (n1, n2), 1e-8).expect("type1");
        let dc_mag = f_hat[(n1 / 2, n2 / 2)].norm();
        assert!(
            dc_mag > 0.5 * pts.len() as f64,
            "DC={:.3} expected ~{}",
            dc_mag,
            pts.len()
        );
        for r in 0..n1 {
            for s in 0..n2 {
                if r != n1 / 2 || s != n2 / 2 {
                    assert!(
                        f_hat[(r, s)].norm() < 0.25 * dc_mag,
                        "Off-DC mode ({},{}) too large: {:.3}",
                        r,
                        s,
                        f_hat[(r, s)].norm()
                    );
                }
            }
        }
    }

    /// Two points but only one coefficient: lengths must match.
    #[test]
    fn test_nufft2_dimension_error() {
        let pts = vec![(-PI + 0.1, 0.0), (0.0, 0.0)];
        let c = vec![Complex64::new(1.0, 0.0)];
        let res = nufft2_type1(&pts, &c, (4, 4), 1e-6);
        assert!(res.is_err());
    }

    #[test]
    fn test_nufft2_epsilon_error() {
        let pts = vec![(-PI + 0.1, 0.0)];
        let c = vec![Complex64::new(1.0, 0.0)];
        let res = nufft2_type1(&pts, &c, (4, 4), 0.0);
        assert!(res.is_err());
    }

    /// `π + 0.5` lies outside the accepted interval `[-π, π)`.
    #[test]
    fn test_nufft2_range_error() {
        let pts = vec![(PI + 0.5, 0.0)];
        let c = vec![Complex64::new(1.0, 0.0)];
        let res = nufft2_type1(&pts, &c, (4, 4), 1e-6);
        assert!(res.is_err());
    }

    /// A pure tone `exp(i(2x + 3y))` sampled on a uniform grid should produce a
    /// single dominant mode at frequency (2, 3).
    #[test]
    fn test_nufft2_type1_single_tone() {
        let n1 = 16usize;
        let n2 = 16usize;
        let k1_target: i64 = 2;
        let k2_target: i64 = 3;
        let pts = uniform_grid(n1, n2);
        let c: Vec<Complex64> = pts
            .iter()
            .map(|&(xj, yj)| {
                let phase = k1_target as f64 * xj + k2_target as f64 * yj;
                Complex64::new(phase.cos(), phase.sin())
            })
            .collect();
        let f_hat = nufft2_type1(&pts, &c, (n1, n2), 1e-8).expect("type1");
        let r_peak = (n1 / 2) as i64 + k1_target;
        let s_peak = (n2 / 2) as i64 + k2_target;
        let peak_mag = f_hat[(r_peak as usize, s_peak as usize)].norm();
        let max_other = f_hat
            .indexed_iter()
            .filter(|&((r, s), _)| r as i64 != r_peak || s as i64 != s_peak)
            .map(|(_, v)| v.norm())
            .fold(0.0f64, f64::max);
        assert!(
            peak_mag > 5.0 * max_other,
            "peak={:.3} max_other={:.3}",
            peak_mag,
            max_other
        );
    }

    /// With all coefficients equal to 1, the value at the origin is expected to
    /// be close to the number of modes.
    #[test]
    fn test_nufft2_type2_constant_spectrum() {
        let n1 = 8usize;
        let n2 = 8usize;
        let f_hat = Array2::from_elem((n1, n2), Complex64::new(1.0, 0.0));
        let xy = vec![(0.0, 0.0)];
        let vals = nufft2_type2(&f_hat, &xy, 1e-6).expect("type2");
        assert_eq!(vals.len(), 1);
        let expected = (n1 * n2) as f64;
        assert_relative_eq!(vals[0].re, expected, epsilon = 0.2 * expected);
    }

    #[test]
    fn test_nufft2_type2_empty_spectrum_error() {
        let f_hat: Array2<Complex64> = Array2::zeros((0, 4));
        let xy = vec![(0.0, 0.0)];
        let res = nufft2_type2(&f_hat, &xy, 1e-6);
        assert!(res.is_err());
    }

    /// The forward and inverse 2-D transforms must be exact inverses.
    #[test]
    fn test_fft2d_ifft2d_roundtrip() {
        let ng1 = 4usize;
        let ng2 = 4usize;
        let data: Vec<Complex64> = (0..ng1 * ng2)
            .map(|i| Complex64::new(i as f64, -(i as f64) * 0.5))
            .collect();
        let fft_out = fft2d_row_col(&data, ng1, ng2).expect("fft");
        let recovered = ifft2d_row_col(&fft_out, ng1, ng2).expect("ifft");
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert_relative_eq!(orig.re, rec.re, epsilon = 1e-10);
            assert_relative_eq!(orig.im, rec.im, epsilon = 1e-10);
        }
    }
}