use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
use crate::prelude::*;
/// Selects which non-uniform FFT variant a [`Nufft`] plan performs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NufftType {
    /// Values at non-uniform points -> uniform Fourier modes.
    Type1,
    /// Uniform Fourier modes -> values at non-uniform points.
    Type2,
    /// Values at non-uniform points -> values at other (target) points.
    Type3,
}
/// Tuning parameters used when building a [`Nufft`] plan.
#[derive(Debug, Clone, Copy)]
pub struct NufftOptions {
    /// Factor by which the FFT grid is made longer than the number of uniform
    /// modes (the grid length is then rounded up to a 2·3·5-smooth size).
    pub oversampling: f64,
    /// Kernel-width hint in grid cells. The width actually used is derived
    /// from `tolerance`; this value only raises the upper cap when it is
    /// larger than 12.
    pub kernel_width: usize,
    /// Requested accuracy; plan construction rejects non-positive values.
    pub tolerance: f64,
    /// NOTE(review): not read anywhere in this module — confirm whether other
    /// code consults it.
    pub threaded: bool,
}

impl Default for NufftOptions {
    /// 2x oversampling, kernel-width hint 6, tolerance `1e-6`, threading enabled.
    fn default() -> Self {
        Self {
            threaded: true,
            tolerance: 1e-6,
            kernel_width: 6,
            oversampling: 2.0,
        }
    }
}
/// Errors reported while building or running a [`Nufft`] plan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NufftError {
    /// A requested transform size was invalid (carries the offending value).
    InvalidSize(usize),
    /// A non-uniform point lay outside `[-π, π]`.
    PointsOutOfRange,
    /// An FFT plan could not be created.
    PlanFailed,
    /// Execution was rejected or failed; the message gives details.
    ExecutionFailed(String),
    /// The requested tolerance was not a positive number.
    InvalidTolerance,
}

impl core::fmt::Display for NufftError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::InvalidSize(n) => write!(f, "Invalid NUFFT size: {n}"),
            Self::PointsOutOfRange => write!(f, "Non-uniform points must be in [-π, π]"),
            Self::PlanFailed => write!(f, "Failed to create FFT plan"),
            Self::ExecutionFailed(msg) => write!(f, "NUFFT execution failed: {msg}"),
            Self::InvalidTolerance => write!(f, "Tolerance must be positive"),
        }
    }
}

/// Lets callers box this as `dyn Error` or propagate it with `?` into generic
/// error types. `core::error::Error` is used so the impl does not require `std`.
impl core::error::Error for NufftError {}

/// Result alias used by every fallible NUFFT entry point.
pub type NufftResult<T> = Result<T, NufftError>;
/// A precomputed 1-D non-uniform FFT plan.
///
/// Built by [`Nufft::new`] or [`Nufft::with_options`]. It owns the shifted
/// points, the per-point spreading weights, the per-mode deconvolution factors
/// and the forward FFT plan for the oversampled grid.
// Presumably needed because `nufft_type` repeats the struct name, which
// `clippy::struct_field_names` flags.
#[allow(clippy::struct_field_names)]
pub struct Nufft<T: Float> {
    /// Transform that [`Nufft::execute`] dispatches to.
    nufft_type: NufftType,
    /// Number of uniform Fourier modes.
    n_uniform: usize,
    /// Number of non-uniform points.
    n_nonuniform: usize,
    /// Length of the oversampled FFT grid.
    n_oversampled: usize,
    /// Non-uniform points, stored shifted by `+π` into `[0, 2π]`.
    points: Vec<f64>,
    /// For each point, the `(grid index, weight)` pairs of its spreading kernel.
    spread_coeffs: Vec<Vec<(usize, T)>>,
    /// One deconvolution factor per uniform mode.
    deconv_factors: Vec<Complex<T>>,
    /// Forward FFT plan over the oversampled grid; `None` if creation failed.
    fft_plan: Option<Plan<T>>,
    /// Options used to build the plan, with `kernel_width` replaced by the
    /// width actually chosen.
    options: NufftOptions,
}
impl<T: Float> Nufft<T> {
    /// Creates a plan using [`NufftOptions::default`] for everything except
    /// `tolerance`.
    ///
    /// `points` are the non-uniform locations, each expected in `[-π, π]`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Nufft::with_options`].
    pub fn new(
        nufft_type: NufftType,
        n_uniform: usize,
        points: &[f64],
        tolerance: f64,
    ) -> NufftResult<Self> {
        let options = NufftOptions {
            tolerance,
            ..Default::default()
        };
        Self::with_options(nufft_type, n_uniform, points, &options)
    }

    /// Creates a plan for `n_uniform` Fourier modes and the given non-uniform
    /// `points` (each expected in `[-π, π]`).
    ///
    /// Construction proceeds as follows:
    /// * The kernel width is derived from `options.tolerance`.
    /// * The oversampled grid length is the next 2·3·5-smooth integer that is
    ///   at least `n_uniform * options.oversampling`.
    /// * The points are stored shifted by `+π` into `[0, 2π]`.
    /// * The spreading weights, deconvolution factors and a forward FFT plan
    ///   for the grid are precomputed.
    ///
    /// The stored options carry the kernel width actually chosen.
    ///
    /// # Errors
    ///
    /// * [`NufftError::InvalidSize`] if `n_uniform` is 0.
    /// * [`NufftError::InvalidTolerance`] if `options.tolerance` is not
    ///   positive; NaN is rejected as well.
    /// * [`NufftError::ExecutionFailed`] if `options.oversampling` is not
    ///   finite or is below `1.0`. A grid shorter than `n_uniform` cannot hold
    ///   every mode: the negative-frequency slots would overlap the positive
    ///   ones or underflow the slot arithmetic.
    /// * [`NufftError::PointsOutOfRange`] if any point lies outside `[-π, π]`
    ///   (NaN included).
    ///
    /// A failure to create the forward FFT plan is not reported here. It
    /// surfaces as [`NufftError::PlanFailed`] from [`Nufft::type1`].
    pub fn with_options(
        nufft_type: NufftType,
        n_uniform: usize,
        points: &[f64],
        options: &NufftOptions,
    ) -> NufftResult<Self> {
        use core::f64::consts::PI;

        if n_uniform == 0 {
            return Err(NufftError::InvalidSize(0));
        }
        // A bare `<= 0.0` test is false for NaN, which would let NaN through.
        if options.tolerance.is_nan() || options.tolerance <= 0.0 {
            return Err(NufftError::InvalidTolerance);
        }
        // NaN or infinite factors would otherwise be cast to a zero or
        // saturated grid length below.
        if !options.oversampling.is_finite() || options.oversampling < 1.0 {
            return Err(NufftError::ExecutionFailed(format!(
                "oversampling factor must be finite and at least 1.0, got {}",
                options.oversampling
            )));
        }

        let kernel_width = compute_kernel_width(options.tolerance, options.kernel_width);
        let n_oversampled = ((n_uniform as f64) * options.oversampling).ceil() as usize;
        // Never let float rounding produce a grid shorter than the mode count.
        let n_oversampled = next_smooth_number(n_oversampled.max(n_uniform));

        let shifted = points
            .iter()
            .map(|&p| {
                if (-PI..=PI).contains(&p) {
                    Ok(p + PI)
                } else {
                    Err(NufftError::PointsOutOfRange)
                }
            })
            .collect::<NufftResult<Vec<f64>>>()?;

        let spread_coeffs = precompute_spreading_coeffs(&shifted, n_oversampled, kernel_width);
        let deconv_factors = precompute_deconv_factors(n_uniform, n_oversampled, kernel_width);
        let fft_plan = Plan::dft_1d(n_oversampled, Direction::Forward, Flags::MEASURE);

        Ok(Self {
            nufft_type,
            n_uniform,
            n_nonuniform: points.len(),
            n_oversampled,
            points: shifted,
            spread_coeffs,
            deconv_factors,
            fft_plan,
            options: NufftOptions {
                kernel_width,
                ..*options
            },
        })
    }

    /// Type-1 transform: from `values` at the plan's non-uniform points to
    /// `n_uniform` Fourier modes.
    ///
    /// Output slot `k` holds frequency `k` for `k < n_uniform / 2` and
    /// `k - n_uniform` otherwise. The values are processed in four steps:
    /// 1. Spread onto the oversampled grid.
    /// 2. Transform with the forward FFT plan.
    /// 3. Multiply each retained mode by its deconvolution factor.
    /// 4. Apply a sign correction for the `+π` point shift.
    ///
    /// NOTE(review): no separate amplitude normalisation is applied here, so
    /// accuracy rests on the spreading weights and deconvolution factors —
    /// verify against a direct evaluation.
    ///
    /// # Errors
    ///
    /// * [`NufftError::ExecutionFailed`] if `values.len()` differs from the
    ///   number of points.
    /// * [`NufftError::PlanFailed`] if the forward FFT plan could not be
    ///   created.
    pub fn type1(&self, values: &[Complex<T>]) -> NufftResult<Vec<Complex<T>>> {
        if values.len() != self.n_nonuniform {
            return Err(NufftError::ExecutionFailed(format!(
                "Expected {} values, got {}",
                self.n_nonuniform,
                values.len()
            )));
        }
        let Some(plan) = self.fft_plan.as_ref() else {
            return Err(NufftError::PlanFailed);
        };

        let mut grid = vec![Complex::<T>::zero(); self.n_oversampled];
        self.spread_to_grid(values, &mut grid);
        let mut spectrum = vec![Complex::<T>::zero(); self.n_oversampled];
        plan.execute(&grid, &mut spectrum);

        Ok((0..self.n_uniform)
            .map(|k| {
                let (grid_idx, factor) = self.mode_slot(k);
                spectrum[grid_idx] * factor
            })
            .collect())
    }

    /// Type-2 transform: evaluates the `n_uniform` Fourier modes in `coeffs`
    /// at every non-uniform point of the plan.
    ///
    /// Input slot `k` holds frequency `k` for `k < n_uniform / 2` and
    /// `k - n_uniform` otherwise. The coefficients are processed in four steps:
    /// 1. Apply the deconvolution factors and the `+π` shift sign correction.
    /// 2. Place the modes on the oversampled grid.
    /// 3. Run a backward FFT, scaled by `1 / n_oversampled`.
    /// 4. Interpolate back to each point with the spreading weights.
    ///
    /// NOTE(review): no separate amplitude normalisation is applied here —
    /// verify against a direct evaluation.
    ///
    /// # Errors
    ///
    /// * [`NufftError::ExecutionFailed`] if `coeffs.len() != n_uniform`.
    /// * [`NufftError::PlanFailed`] if the backward FFT plan cannot be created.
    pub fn type2(&self, coeffs: &[Complex<T>]) -> NufftResult<Vec<Complex<T>>> {
        if coeffs.len() != self.n_uniform {
            return Err(NufftError::ExecutionFailed(format!(
                "Expected {} coefficients, got {}",
                self.n_uniform,
                coeffs.len()
            )));
        }

        let mut grid = vec![Complex::<T>::zero(); self.n_oversampled];
        for (k, &coeff) in coeffs.iter().enumerate() {
            let (grid_idx, factor) = self.mode_slot(k);
            grid[grid_idx] = coeff * factor;
        }

        // The backward plan is only needed here, so it is built per call.
        let Some(inv_plan) =
            Plan::dft_1d(self.n_oversampled, Direction::Backward, Flags::ESTIMATE)
        else {
            return Err(NufftError::PlanFailed);
        };
        let mut samples = vec![Complex::<T>::zero(); self.n_oversampled];
        inv_plan.execute(&grid, &mut samples);

        let scale = T::ONE / T::from_usize(self.n_oversampled);
        for c in &mut samples {
            *c = Complex::new(c.re * scale, c.im * scale);
        }
        Ok(self.interpolate_from_grid(&samples))
    }

    /// Runs the transform selected by this plan's [`NufftType`].
    ///
    /// # Errors
    ///
    /// * Type-3 plans always return [`NufftError::ExecutionFailed`]; use
    ///   [`Nufft::execute_type3`] for them.
    /// * Otherwise forwards the errors of [`Nufft::type1`] / [`Nufft::type2`].
    pub fn execute(&self, input: &[Complex<T>]) -> NufftResult<Vec<Complex<T>>> {
        match self.nufft_type {
            NufftType::Type1 => self.type1(input),
            NufftType::Type2 => self.type2(input),
            NufftType::Type3 => Err(NufftError::ExecutionFailed(
                "Type 3 requires separate execute_type3 call".into(),
            )),
        }
    }

    /// Evaluates `values` (given at this plan's points) at `target_points`.
    ///
    /// The work is done in two stages:
    /// 1. A type-1 transform onto this plan's `n_uniform` modes.
    /// 2. A type-2 evaluation of those modes at `target_points`.
    ///
    /// The result is therefore band-limited to those modes rather than a
    /// general type-3 transform. The second stage reuses this plan's options,
    /// including oversampling and the chosen kernel width.
    ///
    /// # Errors
    ///
    /// * Errors from [`Nufft::type1`].
    /// * Errors from building the second stage, e.g.
    ///   [`NufftError::PointsOutOfRange`] for a target outside `[-π, π]`.
    /// * Errors from [`Nufft::type2`].
    pub fn execute_type3(
        &self,
        values: &[Complex<T>],
        target_points: &[f64],
    ) -> NufftResult<Vec<Complex<T>>> {
        let modes = self.type1(values)?;
        let evaluator = Self::with_options(
            NufftType::Type2,
            self.n_uniform,
            target_points,
            &self.options,
        )?;
        evaluator.type2(&modes)
    }

    /// Maps mode slot `k` to its index on the oversampled grid and the factor
    /// applied to that mode.
    ///
    /// Slots below `n_uniform / 2` hold frequencies `0, 1, …` at the bottom of
    /// the grid. The rest hold the negative frequencies `k - n_uniform` at the
    /// top of the grid.
    ///
    /// The factor is the deconvolution factor with a `(-1)^freq` sign folded
    /// in. The points are stored shifted by `+π`, so every grid-based transform
    /// picks up an extra phase `e^{∓i·freq·π} = (-1)^freq`, which this undoes.
    fn mode_slot(&self, k: usize) -> (usize, Complex<T>) {
        let (grid_idx, abs_freq) = if k < self.n_uniform / 2 {
            (k, k)
        } else {
            let back = self.n_uniform - k;
            (self.n_oversampled - back, back)
        };
        let deconv = self.deconv_factors[k];
        let factor = if abs_freq.is_multiple_of(2) {
            deconv
        } else {
            let minus_one = T::from_f64(-1.0);
            Complex::new(deconv.re * minus_one, deconv.im * minus_one)
        };
        (grid_idx, factor)
    }

    /// Accumulates each value onto the grid cells of its point, weighted by
    /// the precomputed spreading weights.
    fn spread_to_grid(&self, values: &[Complex<T>], grid: &mut [Complex<T>]) {
        for (&val, taps) in values.iter().zip(&self.spread_coeffs) {
            for &(idx, weight) in taps {
                grid[idx] = grid[idx] + Complex::new(val.re * weight, val.im * weight);
            }
        }
    }

    /// Returns, for every point, the weighted sum of the grid cells it spreads
    /// to, which is the adjoint of [`Self::spread_to_grid`].
    fn interpolate_from_grid(&self, grid: &[Complex<T>]) -> Vec<Complex<T>> {
        self.spread_coeffs
            .iter()
            .map(|taps| {
                taps.iter().fold(Complex::<T>::zero(), |acc, &(idx, weight)| {
                    acc + Complex::new(grid[idx].re * weight, grid[idx].im * weight)
                })
            })
            .collect()
    }

    /// Number of uniform Fourier modes.
    pub fn n_uniform(&self) -> usize {
        self.n_uniform
    }

    /// Number of non-uniform points.
    pub fn n_nonuniform(&self) -> usize {
        self.n_nonuniform
    }

    /// Transform type this plan was created for.
    pub fn nufft_type(&self) -> NufftType {
        self.nufft_type
    }

    /// The plan's points as stored internally, i.e. shifted by `+π` into
    /// `[0, 2π]`. They are not the values originally passed in.
    pub fn points(&self) -> &[f64] {
        &self.points
    }
}
/// Chooses the spreading-kernel width (in grid cells) for `tolerance`.
///
/// The raw width is `ceil(2 - log10(tolerance))`. It is clamped to
/// `[4, max(default, 12)]`, so `default` only matters when it exceeds 12.
/// Negative or NaN raw widths saturate to 0 in the cast and therefore land on
/// the lower bound of 4.
fn compute_kernel_width(tolerance: f64, default: usize) -> usize {
    let upper = core::cmp::max(default, 12);
    let digits = 2.0 - tolerance.log10();
    let raw = digits.ceil() as usize;
    raw.clamp(4, upper)
}
/// Returns the smallest integer `>= n` whose only prime factors are 2, 3 and 5.
///
/// `0` is treated as `1`, which is itself 5-smooth. Stripping factors of 2
/// from 0 never reaches 1, so searching from 0 would never terminate.
fn next_smooth_number(n: usize) -> usize {
    let mut candidate = n.max(1);
    loop {
        let mut rest = candidate;
        for factor in [2, 3, 5] {
            while rest.is_multiple_of(factor) {
                rest /= factor;
            }
        }
        if rest == 1 {
            return candidate;
        }
        candidate += 1;
    }
}
/// Precomputes, for every point, the grid cells it touches and the Gaussian
/// spreading weight of each.
///
/// `points` are expected in `[0, 2π]`. Each point is centred on its nearest
/// grid cell and covers offsets `-kernel_width/2 ..= kernel_width/2`, wrapped
/// periodically onto `0..n_grid`.
///
/// The weight is `exp(-β (dx / (h · kernel_width/2))²)`, where:
/// * `β = 2.3 · kernel_width`;
/// * `h = 2π / n_grid`;
/// * `dx` is the point-to-cell distance folded into `[-π, π]`.
///
/// Weights at or below `1e-15` are dropped.
fn precompute_spreading_coeffs<T: Float>(
    points: &[f64],
    n_grid: usize,
    kernel_width: usize,
) -> Vec<Vec<(usize, T)>> {
    let two_pi = 2.0 * core::f64::consts::PI;
    let h = two_pi / (n_grid as f64);
    let reach = (kernel_width / 2) as isize;
    // Distance at which the normalised offset reaches 1.
    let radius = h * ((kernel_width / 2) as f64);
    let beta = 2.3 * (kernel_width as f64);
    let cells = n_grid as isize;

    let mut table = Vec::with_capacity(points.len());
    for &x in points {
        let nearest = (x / h).round() as isize;
        let row: Vec<(usize, T)> = (-reach..=reach)
            .filter_map(|off| {
                let cell = (nearest + off).rem_euclid(cells) as usize;
                let mut dist = x - (cell as f64) * h;
                if dist > core::f64::consts::PI {
                    dist -= two_pi;
                } else if dist < -core::f64::consts::PI {
                    dist += two_pi;
                }
                let u = dist / radius;
                let w = (-beta * u * u).exp();
                (w > 1e-15).then(|| (cell, T::from_f64(w)))
            })
            .collect();
        table.push(row);
    }
    table
}
/// Precomputes one deconvolution factor per uniform mode.
///
/// The factors undo the attenuation caused by the Gaussian spreading kernel
/// `φ(x) = exp(-β (x / (h·w))²)` used by `precompute_spreading_coeffs`, with:
/// * `β = 2.3 · kernel_width`;
/// * `w = kernel_width / 2` (integer division);
/// * `h = 2π / M` and `M = n_oversampled`.
///
/// The Fourier transform of `φ` is proportional to `exp(-π² k² w² / (β M²))`.
/// Each factor is the reciprocal of that shape, `exp(π² k² w² / (β M²))`, so
/// factor 0 is exactly 1.
///
/// Slot `i` corresponds to frequency `i` for `i < n_uniform / 2` and
/// `i - n_uniform` otherwise.
fn precompute_deconv_factors<T: Float>(
    n_uniform: usize,
    n_oversampled: usize,
    kernel_width: usize,
) -> Vec<Complex<T>> {
    use core::f64::consts::PI;

    let beta = 2.3 * (kernel_width as f64);
    let half_width = (kernel_width / 2) as f64;
    let grid_len = n_oversampled as f64;
    // The exponent scales with the kernel half-width, not with `n_uniform`.
    // At the highest retained mode it is roughly π² w² / (4 β r²) with
    // r = M / n_uniform, so the factors stay bounded instead of overflowing
    // to infinity for large transforms.
    let growth = if kernel_width == 0 || n_oversampled == 0 {
        // Degenerate sizes: there is no kernel attenuation to undo.
        0.0
    } else {
        PI * PI * half_width * half_width / (beta * grid_len * grid_len)
    };

    (0..n_uniform)
        .map(|k| {
            let freq = if k < n_uniform / 2 {
                k as f64
            } else {
                (k as f64) - (n_uniform as f64)
            };
            Complex::new(T::from_f64((growth * freq * freq).exp()), T::ZERO)
        })
        .collect()
}
/// One-shot type-1 NUFFT: builds a plan for `points` and transforms `values`
/// onto `n_output` uniform modes.
///
/// # Errors
///
/// Any error from [`Nufft::new`] or [`Nufft::type1`].
pub fn nufft_type1<T: Float>(
    points: &[f64],
    values: &[Complex<T>],
    n_output: usize,
    tolerance: f64,
) -> NufftResult<Vec<Complex<T>>> {
    Nufft::<T>::new(NufftType::Type1, n_output, points, tolerance)
        .and_then(|plan| plan.type1(values))
}
/// One-shot type-2 NUFFT: evaluates the uniform modes `coeffs` at `points`.
///
/// The number of modes is `coeffs.len()`.
///
/// # Errors
///
/// Any error from [`Nufft::new`] (e.g. empty `coeffs`) or [`Nufft::type2`].
pub fn nufft_type2<T: Float>(
    coeffs: &[Complex<T>],
    points: &[f64],
    tolerance: f64,
) -> NufftResult<Vec<Complex<T>>> {
    let n_modes = coeffs.len();
    Nufft::<T>::new(NufftType::Type2, n_modes, points, tolerance)?.type2(coeffs)
}
/// One-shot type-3 style NUFFT from `source_points` to `target_points`.
///
/// The intermediate uniform mode count is the next power of two at or above
/// the total number of source and target points. See [`Nufft::execute_type3`]
/// for how the two stages are chained.
///
/// # Errors
///
/// Any error from [`Nufft::new`] or [`Nufft::execute_type3`].
pub fn nufft_type3<T: Float>(
    source_points: &[f64],
    values: &[Complex<T>],
    target_points: &[f64],
    tolerance: f64,
) -> NufftResult<Vec<Complex<T>>> {
    let total_points = source_points.len() + target_points.len();
    let modes = total_points.next_power_of_two();
    Nufft::<T>::new(NufftType::Type1, modes, source_points, tolerance)?
        .execute_type3(values, target_points)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nufft_type1_uniform_points() {
        let n = 8;
        let step = 2.0 * core::f64::consts::PI / (n as f64);
        let points: Vec<f64> = (0..n)
            .map(|k| -core::f64::consts::PI + (k as f64) * step)
            .collect();
        let values: Vec<Complex<f64>> = (0..n)
            .map(|k| Complex::new((k as f64).cos(), (k as f64).sin()))
            .collect();
        let result = nufft_type1(&points, &values, n, 1e-6).expect("NUFFT failed");
        assert_eq!(result.len(), n);
    }

    #[test]
    fn test_nufft_type2_single_frequency() {
        let n = 16;
        let mut coeffs = vec![Complex::<f64>::zero(); n];
        coeffs[1] = Complex::new(1.0, 0.0);
        let points: Vec<f64> = (0..5)
            .map(|k| -core::f64::consts::PI + f64::from(k) * 0.5)
            .collect();
        let result = nufft_type2(&coeffs, &points, 1e-6).expect("NUFFT failed");
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_nufft_roundtrip() {
        let n = 32;
        let points: Vec<f64> = (0..10).map(|k| -2.5 + f64::from(k) * 0.5).collect();
        let values: Vec<Complex<f64>> = points
            .iter()
            .map(|&x| Complex::new(x.cos(), x.sin()))
            .collect();
        let uniform = nufft_type1(&points, &values, n, 1e-6).expect("Type1 failed");
        let recovered = nufft_type2(&uniform, &points, 1e-6).expect("Type2 failed");
        assert_eq!(recovered.len(), values.len());
    }

    #[test]
    fn test_nufft_error_handling() {
        let points = [0.0, 0.5, 1.0];
        assert!(matches!(
            Nufft::<f64>::new(NufftType::Type1, 0, &points, 1e-6),
            Err(NufftError::InvalidSize(0))
        ));

        let bad_points = [0.0, 5.0];
        assert!(matches!(
            Nufft::<f64>::new(NufftType::Type1, 16, &bad_points, 1e-6),
            Err(NufftError::PointsOutOfRange)
        ));

        assert!(matches!(
            Nufft::<f64>::new(NufftType::Type1, 16, &points, -1e-6),
            Err(NufftError::InvalidTolerance)
        ));
    }

    #[test]
    fn test_execute_rejects_bad_input() {
        let points = [0.0, 0.5, 1.0];
        let plan = Nufft::<f64>::new(NufftType::Type1, 8, &points, 1e-6).expect("plan failed");
        let too_few = vec![Complex::<f64>::zero(); 2];
        assert!(matches!(
            plan.type1(&too_few),
            Err(NufftError::ExecutionFailed(_))
        ));

        let type3 = Nufft::<f64>::new(NufftType::Type3, 8, &points, 1e-6).expect("plan failed");
        let values = vec![Complex::<f64>::zero(); points.len()];
        assert!(matches!(
            type3.execute(&values),
            Err(NufftError::ExecutionFailed(_))
        ));
    }

    #[test]
    fn test_smooth_number() {
        assert_eq!(next_smooth_number(100), 100);
        assert_eq!(next_smooth_number(101), 108);
        assert_eq!(next_smooth_number(7), 8);
    }

    #[test]
    fn test_compute_kernel_width() {
        assert_eq!(compute_kernel_width(2e-6, 6), 8);
        assert_eq!(compute_kernel_width(3e-11, 6), 12);
        assert_eq!(compute_kernel_width(3e-11, 20), 13);
        assert_eq!(compute_kernel_width(0.5, 6), 4);
    }
}