use crate::api::{Direction, Flags, Plan2D};
use crate::kernel::{Complex, Float};
use super::{
compute_kernel_width, next_smooth_number, precompute_deconv_factors, NufftError, NufftOptions,
NufftResult,
};
/// Gaussian spreading/interpolation weights for one coordinate on a periodic 1D grid.
///
/// The grid has `n_grid` nodes spaced `2π / n_grid` apart. Callers in this module pass
/// coordinates already shifted into `[0, 2π]` by `normalize_coord`. Nodes are visited at
/// offsets `-(kernel_width / 2)..=kernel_width / 2` around the node nearest to `x`, and
/// node indices wrap around the grid ends.
///
/// Each weight is `exp(-β r²)`, where:
/// - `β = 2.3 * kernel_width`;
/// - `r` is the distance from `x` to the node, wrapped into `[-π, π]` and measured in
///   units of `kernel_width / 2` grid spacings.
///
/// Weights that are not greater than `1e-15` (including NaN) are dropped.
///
/// Returns `(grid_index, weight)` pairs in order of increasing offset.
///
/// # Panics
/// Panics if `n_grid` is zero (`rem_euclid` by zero).
///
/// NOTE(review): for `kernel_width < 2` the distance scale is zero, so every weight is
/// 0 or NaN and the result is empty. Confirm `compute_kernel_width` never returns < 2.
fn gaussian_weights_1d<T: Float>(x: f64, n_grid: usize, kernel_width: usize) -> Vec<(usize, T)> {
    use core::f64::consts::PI;

    let two_pi = 2.0 * PI;
    let spacing = two_pi / (n_grid as f64);
    let reach = (kernel_width / 2) as isize;
    let beta = 2.3 * (kernel_width as f64);
    let nearest = (x / spacing).round() as isize;
    // Distance normaliser: half the kernel width, expressed in radians.
    let scale = spacing * ((kernel_width / 2) as f64);

    (-reach..=reach)
        .filter_map(|offset| {
            let idx = (nearest + offset).rem_euclid(n_grid as isize) as usize;
            let raw = x - (idx as f64) * spacing;
            // Take the periodic image of the node closest to `x`.
            let dist = if raw > PI {
                raw - two_pi
            } else if raw < -PI {
                raw + two_pi
            } else {
                raw
            };
            let r = dist / scale;
            let weight = (-beta * r * r).exp();
            (weight > 1e-15).then(|| (idx, T::from_f64(weight)))
        })
        .collect()
}
/// Maps a coordinate from the accepted interval `[-π, π]` onto `[0, 2π]`.
///
/// # Errors
/// Returns [`NufftError::PointsOutOfRange`] when `p` lies outside `[-π, π]`.
/// NaN is rejected as well, because it is never contained in the range.
#[inline]
fn normalize_coord(p: f64) -> Result<f64, NufftError> {
    use core::f64::consts::PI;

    if (-PI..=PI).contains(&p) {
        Ok(p + PI)
    } else {
        Err(NufftError::PointsOutOfRange)
    }
}
/// Computes a 2D type-1 NUFFT (nonuniform points → uniform modes).
///
/// Algorithm:
/// 1. Each strength `c[j]` at `(x[j], y[j])` is spread onto an oversampled
///    `n_over1 x n_over2` grid with separable Gaussian weights.
/// 2. The grid is transformed with a forward 2D FFT; no normalization is applied.
/// 3. The `n1 x n2` retained modes are multiplied by the clamped deconvolution factors.
///
/// # Arguments
/// * `x`, `y` - point coordinates; every value must lie in `[-π, π]`.
/// * `c` - complex strengths, one per point (`x.len() == y.len() == c.len()`).
/// * `n1`, `n2` - number of output modes along each axis; both must be non-zero.
/// * `options` - tolerance, oversampling factor and kernel-width override.
///
/// # Returns
/// A row-major vector of length `n1 * n2`, where entry `k1 * n2 + k2` holds mode
/// `(k1, k2)`. Along each axis, indices below `n / 2` are read from the start of the
/// oversampled spectrum and the remaining ones from its end.
///
/// # Errors
/// * [`NufftError::InvalidSize`] if `n1` or `n2` is zero.
/// * [`NufftError::InvalidTolerance`] if `options.tolerance` is not a finite, strictly
///   positive number (NaN included).
/// * [`NufftError::ExecutionFailed`] in any of these cases:
///   - the input lengths differ;
///   - `options.oversampling` is not finite;
///   - the oversampled grid is smaller than the requested modes;
///   - the grid size overflows `usize`.
/// * [`NufftError::PointsOutOfRange`] if a coordinate lies outside `[-π, π]`.
/// * [`NufftError::PlanFailed`] if the 2D FFT plan cannot be created.
pub fn nufft2d_type1<T: Float>(
    x: &[f64],
    y: &[f64],
    c: &[Complex<T>],
    n1: usize,
    n2: usize,
    options: &NufftOptions,
) -> NufftResult<Vec<Complex<T>>> {
    if n1 == 0 {
        return Err(NufftError::InvalidSize(0));
    }
    if n2 == 0 {
        return Err(NufftError::InvalidSize(0));
    }
    // `tolerance <= 0.0` alone lets NaN through, because every NaN comparison is false.
    // A NaN tolerance silently disables the deconvolution clamp below, and +inf forces
    // the clamp to zero.
    if !(options.tolerance.is_finite() && options.tolerance > 0.0) {
        return Err(NufftError::InvalidTolerance);
    }
    let m = c.len();
    if x.len() != m || y.len() != m {
        return Err(NufftError::ExecutionFailed(format!(
            "x ({}) / y ({}) / c ({}) lengths must match",
            x.len(),
            y.len(),
            m
        )));
    }
    // A non-finite factor would turn the `as usize` casts below into 0 or usize::MAX.
    if !options.oversampling.is_finite() {
        return Err(NufftError::ExecutionFailed(format!(
            "oversampling factor must be finite, got {}",
            options.oversampling
        )));
    }
    let kernel_width = compute_kernel_width(options.tolerance, options.kernel_width);
    let n_over1 = next_smooth_number(((n1 as f64) * options.oversampling).ceil() as usize);
    let n_over2 = next_smooth_number(((n2 as f64) * options.oversampling).ceil() as usize);
    // The wrap-around index `n_over - (n - k)` used for the upper half of the modes
    // underflows, or aliases onto low modes, unless the grid can hold every mode.
    if n_over1 < n1 || n_over2 < n2 {
        return Err(NufftError::ExecutionFailed(format!(
            "oversampled grid {n_over1}x{n_over2} cannot hold {n1}x{n2} modes \
             (oversampling factor {})",
            options.oversampling
        )));
    }
    let grid_len = n_over1.checked_mul(n_over2).ok_or_else(|| {
        NufftError::ExecutionFailed(format!(
            "oversampled grid {n_over1}x{n_over2} overflows usize"
        ))
    })?;

    let mut xn = Vec::with_capacity(m);
    let mut yn = Vec::with_capacity(m);
    for (&xi, &yi) in x.iter().zip(y.iter()) {
        xn.push(normalize_coord(xi)?);
        yn.push(normalize_coord(yi)?);
    }
    let wx: Vec<Vec<(usize, T)>> = xn
        .iter()
        .map(|&xi| gaussian_weights_1d(xi, n_over1, kernel_width))
        .collect();
    let wy: Vec<Vec<(usize, T)>> = yn
        .iter()
        .map(|&yi| gaussian_weights_1d(yi, n_over2, kernel_width))
        .collect();

    // Spread: every point contributes over the tensor product of its two 1D stencils.
    let mut grid = vec![Complex::<T>::zero(); grid_len];
    for ((&val, wx_j), wy_j) in c.iter().zip(&wx).zip(&wy) {
        for &(ix, wx_val) in wx_j {
            let row = ix * n_over2;
            for &(iy, wy_val) in wy_j {
                let w = wx_val * wy_val;
                let cell = &mut grid[row + iy];
                *cell = *cell + Complex::new(val.re * w, val.im * w);
            }
        }
    }

    let plan = Plan2D::new(n_over1, n_over2, Direction::Forward, Flags::ESTIMATE)
        .ok_or(NufftError::PlanFailed)?;
    let mut fft_result = vec![Complex::<T>::zero(); grid_len];
    plan.execute(&grid, &mut fft_result);

    // Clamp each deconvolution factor once per axis instead of once per output mode.
    let max_deconv = T::from_f64(1.0 / options.tolerance);
    let clamp = |d: Complex<T>| {
        if d.re > max_deconv {
            Complex::new(max_deconv, T::ZERO)
        } else {
            d
        }
    };
    let raw1 = precompute_deconv_factors::<T>(n1, n_over1, kernel_width);
    let raw2 = precompute_deconv_factors::<T>(n2, n_over2, kernel_width);
    let deconv1: Vec<Complex<T>> = (0..n1).map(|k| clamp(raw1[k])).collect();
    let deconv2: Vec<Complex<T>> = (0..n2).map(|k| clamp(raw2[k])).collect();

    let half1 = n1 / 2;
    let half2 = n2 / 2;
    // n1 <= n_over1 and n2 <= n_over2 were checked above, so n1 * n2 cannot overflow.
    let mut result = Vec::with_capacity(n1 * n2);
    for (k1, &d1) in deconv1.iter().enumerate() {
        let grid_idx1 = if k1 < half1 { k1 } else { n_over1 - (n1 - k1) };
        let row = grid_idx1 * n_over2;
        for (k2, &d2) in deconv2.iter().enumerate() {
            let grid_idx2 = if k2 < half2 { k2 } else { n_over2 - (n2 - k2) };
            result.push(fft_result[row + grid_idx2] * d1 * d2);
        }
    }
    Ok(result)
}
pub fn nufft2d_type2<T: Float>(
f: &[Complex<T>],
x: &[f64],
y: &[f64],
n1: usize,
n2: usize,
options: &NufftOptions,
) -> NufftResult<Vec<Complex<T>>> {
if n1 == 0 {
return Err(NufftError::InvalidSize(0));
}
if n2 == 0 {
return Err(NufftError::InvalidSize(0));
}
if f.len() != n1 * n2 {
return Err(NufftError::ExecutionFailed(format!(
"f length {} must equal n1*n2 = {}",
f.len(),
n1 * n2
)));
}
if options.tolerance <= 0.0 {
return Err(NufftError::InvalidTolerance);
}
let m = x.len();
if y.len() != m {
return Err(NufftError::ExecutionFailed(format!(
"x ({}) and y ({}) lengths must match",
m,
y.len()
)));
}
let kernel_width = compute_kernel_width(options.tolerance, options.kernel_width);
let n_over1 = next_smooth_number(((n1 as f64) * options.oversampling).ceil() as usize);
let n_over2 = next_smooth_number(((n2 as f64) * options.oversampling).ceil() as usize);
let mut xn = Vec::with_capacity(m);
let mut yn = Vec::with_capacity(m);
for (&xi, &yi) in x.iter().zip(y.iter()) {
xn.push(normalize_coord(xi)?);
yn.push(normalize_coord(yi)?);
}
let deconv1 = precompute_deconv_factors::<T>(n1, n_over1, kernel_width);
let deconv2 = precompute_deconv_factors::<T>(n2, n_over2, kernel_width);
let half1 = n1 / 2;
let half2 = n2 / 2;
let max_deconv = T::from_f64(1.0 / options.tolerance);
let mut grid = vec![Complex::<T>::zero(); n_over1 * n_over2];
for k1 in 0..n1 {
let grid_idx1 = if k1 < half1 { k1 } else { n_over1 - (n1 - k1) };
let d1 = if deconv1[k1].re > max_deconv {
Complex::new(max_deconv, T::ZERO)
} else {
deconv1[k1]
};
for k2 in 0..n2 {
let grid_idx2 = if k2 < half2 { k2 } else { n_over2 - (n2 - k2) };
let flat_in = k1 * n2 + k2;
let flat_grid = grid_idx1 * n_over2 + grid_idx2;
let d2 = if deconv2[k2].re > max_deconv {
Complex::new(max_deconv, T::ZERO)
} else {
deconv2[k2]
};
grid[flat_grid] = f[flat_in] * d1 * d2;
}
}
let plan = Plan2D::new(n_over1, n_over2, Direction::Backward, Flags::ESTIMATE)
.ok_or(NufftError::PlanFailed)?;
let mut ifft_result = vec![Complex::<T>::zero(); n_over1 * n_over2];
plan.execute(&grid, &mut ifft_result);
let scale = T::ONE / T::from_usize(n_over1 * n_over2);
for c_val in &mut ifft_result {
*c_val = Complex::new(c_val.re * scale, c_val.im * scale);
}
let wx: Vec<Vec<(usize, T)>> = xn
.iter()
.map(|&xi| gaussian_weights_1d(xi, n_over1, kernel_width))
.collect();
let wy: Vec<Vec<(usize, T)>> = yn
.iter()
.map(|&yi| gaussian_weights_1d(yi, n_over2, kernel_width))
.collect();
let mut result = Vec::with_capacity(m);
for j in 0..m {
let mut sum = Complex::<T>::zero();
for &(ix, wx_val) in &wx[j] {
for &(iy, wy_val) in &wy[j] {
let flat = ix * n_over2 + iy;
let w = wx_val * wy_val;
let sample = ifft_result[flat];
sum = sum + Complex::new(sample.re * w, sample.im * w);
}
}
result.push(sum);
}
Ok(result)
}
pub fn nufft2d_type1_default<T: Float>(
x: &[f64],
y: &[f64],
c: &[Complex<T>],
n1: usize,
n2: usize,
tolerance: f64,
) -> NufftResult<Vec<Complex<T>>> {
let options = NufftOptions {
tolerance,
..Default::default()
};
nufft2d_type1(x, y, c, n1, n2, &options)
}
pub fn nufft2d_type2_default<T: Float>(
f: &[Complex<T>],
x: &[f64],
y: &[f64],
n1: usize,
n2: usize,
tolerance: f64,
) -> NufftResult<Vec<Complex<T>>> {
let options = NufftOptions {
tolerance,
..Default::default()
};
nufft2d_type2(f, x, y, n1, n2, &options)
}
#[cfg(test)]
mod tests {
    use super::*;

    fn opts() -> NufftOptions {
        NufftOptions::default()
    }

    /// Fails the test if any entry of `values` has a non-finite component.
    fn assert_all_finite(values: &[Complex<f64>], what: &str) {
        for (idx, v) in values.iter().enumerate() {
            assert!(
                v.re.is_finite() && v.im.is_finite(),
                "{what} {idx} is non-finite: {v:?}"
            );
        }
    }

    #[test]
    fn test_2d_type1_single_point_correctness() {
        let x = vec![0.0f64];
        let y = vec![0.0f64];
        let c = vec![Complex::new(1.0f64, 0.0)];
        let n1 = 16;
        let n2 = 16;
        let result = nufft2d_type1(&x, &y, &c, n1, n2, &opts()).expect("2D Type 1 failed");
        assert_eq!(result.len(), n1 * n2);
        assert!(result[0].norm() > 0.0, "DC bin must be non-zero");
        assert_all_finite(&result, "Bin");
    }

    #[test]
    fn test_2d_type2_dc_constant() {
        let (n1, n2) = (8, 8);
        let mut f = vec![Complex::<f64>::zero(); n1 * n2];
        f[0] = Complex::new(1.0, 0.0);
        let x = vec![-1.0, 0.0, 1.0, 2.0];
        let y = vec![-1.0, 0.0, 1.0, 2.0];
        let result = nufft2d_type2(&f, &x, &y, n1, n2, &opts()).expect("2D Type 2 failed");
        assert_eq!(result.len(), x.len());
        assert_all_finite(&result, "Sample");
        // Only the DC mode is set, so every sample must pick up a non-zero contribution.
        assert!(
            result.iter().all(|v| v.norm() > 0.0),
            "DC-only input produced a zero sample: {result:?}"
        );
    }

    #[test]
    fn test_2d_type1_type2_roundtrip() {
        let n1 = 16;
        let n2 = 16;
        let m = 5;
        let x: Vec<f64> = (0..m).map(|i| -1.5 + i as f64 * 0.7).collect();
        let y: Vec<f64> = (0..m).map(|i| -1.0 + i as f64 * 0.5).collect();
        let c: Vec<Complex<f64>> = (0..m)
            .map(|i| Complex::new((i as f64 * 0.5).cos(), (i as f64 * 0.5).sin()))
            .collect();
        let f = nufft2d_type1(&x, &y, &c, n1, n2, &opts()).expect("2D Type 1 failed");
        let recovered = nufft2d_type2(&f, &x, &y, n1, n2, &opts()).expect("2D Type 2 failed");
        assert_eq!(recovered.len(), m);
        assert_all_finite(&recovered, "Recovered value");
    }

    #[test]
    fn test_2d_type1_is_linear_in_strengths() {
        let x = vec![-2.0, 0.3, 1.7];
        let y = vec![0.5, -1.2, 2.9];
        let c = vec![
            Complex::new(1.0f64, -0.5),
            Complex::new(0.25, 2.0),
            Complex::new(-1.5, 0.75),
        ];
        let doubled: Vec<Complex<f64>> =
            c.iter().map(|v| Complex::new(2.0 * v.re, 2.0 * v.im)).collect();
        let base = nufft2d_type1(&x, &y, &c, 8, 8, &opts()).expect("2D Type 1 failed");
        let scaled = nufft2d_type1(&x, &y, &doubled, 8, 8, &opts()).expect("2D Type 1 failed");
        for (k, (b, s)) in base.iter().zip(&scaled).enumerate() {
            let err = ((s.re - 2.0 * b.re).powi(2) + (s.im - 2.0 * b.im).powi(2)).sqrt();
            assert!(err <= 1e-9 * (1.0 + b.norm()), "mode {k}: {s:?} != 2 * {b:?}");
        }
    }

    #[test]
    fn test_2d_empty_point_sets() {
        let no_pts: [f64; 0] = [];
        let no_strengths: [Complex<f64>; 0] = [];
        let modes = nufft2d_type1(&no_pts, &no_pts, &no_strengths, 8, 8, &opts())
            .expect("2D Type 1 failed on empty input");
        assert_eq!(modes.len(), 8 * 8);
        assert!(
            modes.iter().all(|v| v.norm() == 0.0),
            "no points must give all-zero modes"
        );
        let f = vec![Complex::new(1.0f64, 0.0); 8 * 8];
        let samples = nufft2d_type2(&f, &no_pts, &no_pts, 8, 8, &opts())
            .expect("2D Type 2 failed on empty input");
        assert!(samples.is_empty());
    }

    #[test]
    fn test_2d_type1_error_invalid_size() {
        let x = vec![0.0f64];
        let y = vec![0.0f64];
        let c = vec![Complex::new(1.0f64, 0.0)];
        let result = nufft2d_type1(&x, &y, &c, 0, 16, &opts());
        assert!(result.is_err());
        let result = nufft2d_type1(&x, &y, &c, 16, 0, &opts());
        assert!(result.is_err());
    }

    #[test]
    fn test_2d_type1_error_out_of_range() {
        let c = vec![Complex::new(1.0f64, 0.0)];
        let result = nufft2d_type1(&[5.0], &[0.0], &c, 8, 8, &opts());
        assert!(matches!(result, Err(NufftError::PointsOutOfRange)));
        let result = nufft2d_type1(&[0.0], &[-4.0], &c, 8, 8, &opts());
        assert!(matches!(result, Err(NufftError::PointsOutOfRange)));
    }

    #[test]
    fn test_2d_type1_error_mismatched_lengths() {
        let c = vec![Complex::new(1.0f64, 0.0)];
        assert!(nufft2d_type1(&[0.0, 0.5], &[0.0], &c, 8, 8, &opts()).is_err());
    }

    #[test]
    fn test_2d_type2_error_mismatched_grid() {
        let f = vec![Complex::<f64>::zero(); 15];
        let x = vec![0.0f64];
        let y = vec![0.0f64];
        let result = nufft2d_type2(&f, &x, &y, 4, 4, &opts());
        assert!(result.is_err());
    }

    #[test]
    fn test_2d_type2_error_mismatched_lengths() {
        let f = vec![Complex::<f64>::zero(); 4 * 4];
        assert!(nufft2d_type2(&f, &[0.0, 0.5], &[0.0], 4, 4, &opts()).is_err());
    }

    #[test]
    fn test_2d_error_non_positive_tolerance() {
        let c = vec![Complex::new(1.0f64, 0.0)];
        let f = vec![Complex::<f64>::zero(); 8 * 8];
        for tol in [0.0, -1e-6] {
            let r1 = nufft2d_type1_default(&[0.0], &[0.0], &c, 8, 8, tol);
            assert!(
                matches!(r1, Err(NufftError::InvalidTolerance)),
                "type 1 accepted tolerance {tol}"
            );
            let r2 = nufft2d_type2_default(&f, &[0.0], &[0.0], 8, 8, tol);
            assert!(
                matches!(r2, Err(NufftError::InvalidTolerance)),
                "type 2 accepted tolerance {tol}"
            );
        }
    }

    #[test]
    fn test_2d_type1_default_opts_wrapper() {
        let x = vec![0.0f64];
        let y = vec![0.0f64];
        let c = vec![Complex::new(1.0f64, 0.0)];
        let result = nufft2d_type1_default(&x, &y, &c, 8, 8, 1e-6);
        assert!(result.is_ok());
    }

    #[test]
    fn test_2d_type2_default_opts_wrapper() {
        let f = vec![Complex::<f64>::zero(); 8 * 8];
        let x = vec![0.0f64];
        let y = vec![0.0f64];
        let result = nufft2d_type2_default(&f, &x, &y, 8, 8, 1e-6);
        assert!(result.is_ok());
    }
}