use crate::error::{FFTError, FFTResult};
use crate::ndim_fft::mixed_radix::{fft_1d, ifft_1d};
use crate::ndim_fft::types::NormMode;
use std::f64::consts::PI;
/// Signature shared by the 1-D kernels (`fft_1d` and `ifft_1d_raw`), so either one
/// can be selected at runtime and invoked through a single function pointer.
type FftTransformFn = fn(&[(f64, f64)]) -> Vec<(f64, f64)>;
/// Returns the row-major (C-order) element strides for an array of `shape`.
///
/// The last axis has stride 1, and each earlier axis has a stride equal to the
/// product of all extents after it. An empty `shape` yields an empty vector.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return Vec::new();
    }
    // Build the strides from the innermost axis outwards, then flip them.
    let mut reversed = Vec::with_capacity(shape.len());
    let mut running = 1usize;
    reversed.push(running);
    for &extent in shape[1..].iter().rev() {
        running *= extent;
        reversed.push(running);
    }
    reversed.reverse();
    reversed
}
/// Converts a multi-index into a flat offset: the dot product of `idx` and `strides`.
///
/// If the two slices differ in length, the extra entries of the longer one are ignored.
#[inline]
fn flat_index(idx: &[usize], strides: &[usize]) -> usize {
    idx.iter()
        .zip(strides)
        .fold(0usize, |offset, (&coord, &stride)| offset + coord * stride)
}
/// Copies out the 1-D line of `data` that runs along `axis`.
///
/// `data` is read as a row-major array of `shape`. `pos` supplies the coordinates
/// of the *other* axes, in increasing axis order. Missing entries default to 0 and
/// surplus entries are ignored. The returned vector has `shape[axis]` elements.
///
/// # Panics
/// Panics if `axis >= shape.len()` or if the addressed line lies outside `data`.
pub fn extract_axis(
    data: &[(f64, f64)],
    shape: &[usize],
    axis: usize,
    pos: &[usize],
) -> Vec<(f64, f64)> {
    let line_len = shape[axis];
    let strides = compute_strides(shape);
    // Interleave a 0 for `axis` with the caller-supplied coordinates of the rest.
    let mut other_coords = pos.iter().copied();
    let coords: Vec<usize> = (0..shape.len())
        .map(|dim| {
            if dim == axis {
                0
            } else {
                other_coords.next().unwrap_or(0)
            }
        })
        .collect();
    let start = flat_index(&coords, &strides);
    let step = strides[axis];
    (0..line_len).map(|k| data[start + k * step]).collect()
}
/// Writes `values` into the 1-D line of `data` that runs along `axis`.
///
/// This is the counterpart of `extract_axis`. `data` is a row-major array of
/// `shape`, and `pos` gives the coordinates of the other axes in increasing axis
/// order (missing entries default to 0). Exactly `shape[axis]` values are written.
///
/// # Panics
/// Panics if `axis >= shape.len()`, if `values` holds fewer than `shape[axis]`
/// elements, or if the addressed line lies outside `data`.
pub fn insert_axis(
    data: &mut [(f64, f64)],
    shape: &[usize],
    axis: usize,
    pos: &[usize],
    values: &[(f64, f64)],
) {
    let line_len = shape[axis];
    let strides = compute_strides(shape);
    let mut other_coords = pos.iter().copied();
    let coords: Vec<usize> = (0..shape.len())
        .map(|dim| {
            if dim == axis {
                0
            } else {
                other_coords.next().unwrap_or(0)
            }
        })
        .collect();
    let start = flat_index(&coords, &strides);
    let step = strides[axis];
    (0..line_len).for_each(|k| data[start + k * step] = values[k]);
}
fn fft_along_axis(
data: &mut [(f64, f64)],
shape: &[usize],
axis: usize,
inverse: bool,
) -> FFTResult<()> {
let ndim = shape.len();
if axis >= ndim {
return Err(FFTError::DimensionError(format!(
"axis {axis} out of range for {ndim}-D array"
)));
}
let n_axis = shape[axis];
let total = data.len();
let n_slices = total / n_axis;
let strides = compute_strides(shape);
let axis_stride = strides[axis];
let mut slice_buf = vec![(0.0f64, 0.0f64); n_axis];
let mut processed = 0usize;
for f in 0..total {
let axis_coord = (f / axis_stride) % n_axis;
if axis_coord != 0 {
continue;
}
for k in 0..n_axis {
slice_buf[k] = data[f + k * axis_stride];
}
let out = if inverse {
crate::ndim_fft::mixed_radix::ifft_1d_raw(&slice_buf)
} else {
fft_1d(&slice_buf)
};
for k in 0..n_axis {
data[f + k * axis_stride] = out[k];
}
processed += 1;
if processed == n_slices {
break;
}
}
Ok(())
}
/// Transposes the row-major `rows x cols` matrix stored in `data[..rows * cols]`.
///
/// After the call, the same prefix holds the `cols x rows` transpose in row-major
/// order. Any elements past `rows * cols` are left untouched in both branches.
/// Square matrices are transposed by swapping in place. Non-square matrices go
/// through a scratch copy of `rows * cols` elements.
///
/// # Panics
/// Panics if `data.len() < rows * cols`.
pub fn in_place_transpose(data: &mut [(f64, f64)], rows: usize, cols: usize) {
    let total = rows * cols;
    // Restrict to the matrix itself. The previous non-square path copied a
    // full-length zeroed buffer back over `data`, which zero-filled any trailing
    // elements in release builds.
    let matrix = &mut data[..total];
    if rows == cols {
        for r in 0..rows {
            for c in (r + 1)..cols {
                matrix.swap(r * cols + c, c * rows + r);
            }
        }
    } else {
        let src = matrix.to_vec();
        for r in 0..rows {
            for c in 0..cols {
                matrix[c * rows + r] = src[r * cols + c];
            }
        }
    }
}
/// Computes the unnormalized 2-D transform of a row-major `rows x cols` grid in place.
///
/// The algorithm works in four passes:
/// 1. transform every row;
/// 2. transpose so that columns become contiguous;
/// 3. transform those lines;
/// 4. transpose back.
///
/// `fft_1d` is used for the forward direction, and `ifft_1d_raw` when `inverse`
/// is set.
///
/// `_tile_size` is accepted for interface compatibility but is currently unused.
/// Despite the name, no cache tiling is performed.
///
/// An empty grid (`rows == 0` or `cols == 0`) is left as-is, so the 1-D kernels
/// are never handed zero-length slices.
pub fn tiled_2d_fft(
    data: &mut [(f64, f64)],
    rows: usize,
    cols: usize,
    _tile_size: usize,
    inverse: bool,
) {
    /// Replaces each consecutive `line_len`-element line of `buf` with its transform.
    fn transform_lines(buf: &mut [(f64, f64)], line_len: usize, transform: FftTransformFn) {
        for line in buf.chunks_exact_mut(line_len) {
            let out = transform(line);
            line.copy_from_slice(&out);
        }
    }

    debug_assert_eq!(data.len(), rows * cols);
    if rows == 0 || cols == 0 {
        return;
    }
    let transform: FftTransformFn = if inverse {
        crate::ndim_fft::mixed_radix::ifft_1d_raw
    } else {
        fft_1d
    };
    let grid = &mut data[..rows * cols];
    transform_lines(grid, cols, transform);
    in_place_transpose(grid, rows, cols);
    // After the transpose, each original column is a contiguous line of `rows` elements.
    transform_lines(grid, rows, transform);
    in_place_transpose(grid, cols, rows);
}
/// Scales `data` in place according to `norm` and the transform direction.
///
/// `n` is the number of elements covered by the transform. The factor is:
/// - `NormMode::None`: `1 / n` on the inverse transform, `1` on the forward one;
/// - `NormMode::Ortho`: `1 / sqrt(n)` in both directions;
/// - `NormMode::Forward`: `1 / n` on the forward transform, `1` on the inverse one;
/// - any other combination: `1`.
///
/// A factor within `f64::EPSILON` of 1 leaves `data` untouched.
pub fn apply_normalization(data: &mut [(f64, f64)], n: usize, norm: NormMode, inverse: bool) {
    let count = n as f64;
    let factor = match (norm, inverse) {
        (NormMode::None, true) | (NormMode::Forward, false) => 1.0 / count,
        (NormMode::Ortho, _) => 1.0 / count.sqrt(),
        _ => 1.0,
    };
    let is_identity = (factor - 1.0).abs() < f64::EPSILON;
    if is_identity {
        return;
    }
    data.iter_mut().for_each(|(re, im)| {
        *re *= factor;
        *im *= factor;
    });
}
/// Forward N-dimensional transform with `NormMode::None` scaling.
///
/// With that mode the forward result is left unscaled. See `fftn_norm` for the
/// input layout and the error conditions.
pub fn fftn(input: &[(f64, f64)], shape: &[usize]) -> FFTResult<Vec<(f64, f64)>> {
    let default_norm = NormMode::None;
    fftn_norm(input, shape, default_norm)
}
/// Computes the forward N-dimensional transform of a row-major complex array.
///
/// `input` must hold `shape.iter().product()` elements in row-major order. The
/// transform is chosen by rank:
/// - 2-D shapes go through `tiled_2d_fft`;
/// - every other rank applies `fft_along_axis` once per axis.
///
/// The result is then scaled according to `norm` via `apply_normalization`. A 0-D
/// shape (a scalar) is returned unchanged. A shape containing a zero-length axis
/// yields an empty vector.
///
/// # Errors
/// Returns [`FFTError::DimensionError`] if `input.len()` does not match the
/// product of `shape`, or if a per-axis transform reports an error.
pub fn fftn_norm(
    input: &[(f64, f64)],
    shape: &[usize],
    norm: NormMode,
) -> FFTResult<Vec<(f64, f64)>> {
    let expected: usize = shape.iter().product();
    if input.len() != expected {
        return Err(FFTError::DimensionError(format!(
            "input length {} does not match shape {:?} (product = {})",
            input.len(),
            shape,
            expected
        )));
    }
    let ndim = shape.len();
    if ndim == 0 {
        return Ok(input.to_vec());
    }
    // An array with a zero-length axis has no elements. Return early rather than
    // handing zero-length axes to the 1-D kernels.
    if expected == 0 {
        return Ok(Vec::new());
    }
    let mut data = input.to_vec();
    if ndim == 2 {
        tiled_2d_fft(&mut data, shape[0], shape[1], 64, false);
    } else {
        for axis in 0..ndim {
            fft_along_axis(&mut data, shape, axis, false)?;
        }
    }
    apply_normalization(&mut data, expected, norm, false);
    Ok(data)
}
/// Inverse N-dimensional transform with `NormMode::None` scaling.
///
/// With that mode the inverse result is divided by the total element count. See
/// `ifftn_norm` for the input layout and the error conditions.
pub fn ifftn(input: &[(f64, f64)], shape: &[usize]) -> FFTResult<Vec<(f64, f64)>> {
    let default_norm = NormMode::None;
    ifftn_norm(input, shape, default_norm)
}
/// Computes the inverse N-dimensional transform of a row-major complex array.
///
/// `input` must hold `shape.iter().product()` elements in row-major order. The
/// transform is chosen by rank:
/// - 2-D shapes go through `tiled_2d_fft`;
/// - every other rank applies `fft_along_axis` once per axis.
///
/// In both cases the inverse kernel is used. The result is then scaled according
/// to `norm` via `apply_normalization`. A 0-D shape is returned unchanged. A shape
/// containing a zero-length axis yields an empty vector.
///
/// # Errors
/// Returns [`FFTError::DimensionError`] if `input.len()` does not match the
/// product of `shape`, or if a per-axis transform reports an error.
pub fn ifftn_norm(
    input: &[(f64, f64)],
    shape: &[usize],
    norm: NormMode,
) -> FFTResult<Vec<(f64, f64)>> {
    let expected: usize = shape.iter().product();
    if input.len() != expected {
        return Err(FFTError::DimensionError(format!(
            "input length {} does not match shape {:?}",
            input.len(),
            shape
        )));
    }
    let ndim = shape.len();
    if ndim == 0 {
        return Ok(input.to_vec());
    }
    // An array with a zero-length axis has no elements. Return early rather than
    // handing zero-length axes to the 1-D kernels.
    if expected == 0 {
        return Ok(Vec::new());
    }
    let mut data = input.to_vec();
    if ndim == 2 {
        tiled_2d_fft(&mut data, shape[0], shape[1], 64, true);
    } else {
        for axis in 0..ndim {
            fft_along_axis(&mut data, shape, axis, true)?;
        }
    }
    apply_normalization(&mut data, expected, norm, true);
    Ok(data)
}
/// Computes the forward N-dimensional transform of a real row-major array.
///
/// Only the non-redundant half of the last axis is kept: the output has
/// `prefix * (shape[last] / 2 + 1)` elements in row-major order, where `prefix` is
/// the product of all extents except the last. It is obtained by running the full
/// complex `fftn` and truncating each last-axis row.
///
/// A 0-D shape returns the input as complex values. A shape whose leading axes
/// contain a zero extent yields an empty vector.
///
/// # Errors
/// Returns [`FFTError::DimensionError`] in any of these cases:
/// - `input.len()` does not match the product of `shape`;
/// - the last axis has length 0;
/// - the underlying `fftn` fails.
pub fn rfftn(input: &[f64], shape: &[usize]) -> FFTResult<Vec<(f64, f64)>> {
    let expected: usize = shape.iter().product();
    if input.len() != expected {
        return Err(FFTError::DimensionError(format!(
            "rfftn: input length {} does not match shape {:?}",
            input.len(),
            shape
        )));
    }
    let ndim = shape.len();
    if ndim == 0 {
        return Ok(input.iter().map(|&r| (r, 0.0)).collect());
    }
    let last_n = shape[ndim - 1];
    // With an empty last axis there is no meaningful half-spectrum to return.
    // Previously this case panicked on an out-of-range slice.
    if last_n == 0 {
        return Err(FFTError::DimensionError(format!(
            "rfftn: last axis of shape {:?} has length 0",
            shape
        )));
    }
    let half_last = last_n / 2 + 1;
    let prefix_size: usize = shape[..ndim - 1].iter().product();
    // A zero extent among the leading axes means there are zero rows to emit.
    if expected == 0 {
        return Ok(Vec::new());
    }
    let complex_input: Vec<(f64, f64)> = input.iter().map(|&r| (r, 0.0)).collect();
    let full = fftn(&complex_input, shape)?;
    let mut result = Vec::with_capacity(prefix_size * half_last);
    for row in full.chunks_exact(last_n) {
        result.extend_from_slice(&row[..half_last]);
    }
    Ok(result)
}
/// Inverse of `rfftn`: rebuilds a real row-major array of `shape` from its half-spectrum.
///
/// `input` holds `prefix * (shape[last] / 2 + 1)` coefficients in row-major order,
/// where `prefix` is the product of all extents except the last. The discarded
/// coefficients are restored from the Hermitian symmetry of a real signal's
/// spectrum, `X[k] = conj(X[-k])`.
///
/// The index negation applies to every axis, not just the last. Mirroring only
/// the last axis reconstructs the wrong spectrum for inputs of two or more
/// dimensions. The full spectrum is then passed to `ifftn`, and the real parts are
/// returned.
///
/// A 0-D shape returns the real parts of `input`. A zero-size shape yields an
/// empty vector.
///
/// # Errors
/// Returns [`FFTError::DimensionError`] if `input.len()` is not the expected
/// half-spectrum length, or if `ifftn` fails.
pub fn irfftn(input: &[(f64, f64)], shape: &[usize]) -> FFTResult<Vec<f64>> {
    let ndim = shape.len();
    if ndim == 0 {
        return Ok(input.iter().map(|&(re, _)| re).collect());
    }
    let last_n = shape[ndim - 1];
    let half_last = last_n / 2 + 1;
    let prefix_shape = &shape[..ndim - 1];
    // The product of an empty slice is 1, so a 1-D shape has exactly one row.
    let prefix_size: usize = prefix_shape.iter().product();
    if input.len() != prefix_size * half_last {
        return Err(FFTError::DimensionError(format!(
            "irfftn: input length {} does not match expected {} (shape={:?})",
            input.len(),
            prefix_size * half_last,
            shape
        )));
    }
    let total: usize = shape.iter().product();
    if total == 0 {
        return Ok(Vec::new());
    }
    let mut full = vec![(0.0f64, 0.0f64); total];
    for (row, dst) in full.chunks_exact_mut(last_n).enumerate() {
        dst[..half_last].copy_from_slice(&input[row * half_last..(row + 1) * half_last]);
        // X[i, k] = conj(X[-i, N - k]): the missing entries of row `i` come from
        // the row whose leading coordinates are all negated.
        let mirror_start = negated_row_index(row, prefix_shape) * half_last;
        let mirror = &input[mirror_start..mirror_start + half_last];
        for k in half_last..last_n {
            // `last_n - k` lies in 1..half_last, so it always falls inside `mirror`.
            let (re, im) = mirror[last_n - k];
            dst[k] = (re, -im);
        }
    }
    let complex_out = ifftn(&full, shape)?;
    Ok(complex_out.into_iter().map(|(re, _)| re).collect())
}

/// Maps a row-major flat index over `shape` to the flat index of its negated
/// multi-index, in which each coordinate `c` becomes `(extent - c) % extent`.
///
/// Every extent in `shape` must be non-zero. An empty `shape` always maps to 0.
fn negated_row_index(flat: usize, shape: &[usize]) -> usize {
    let mut remaining = flat;
    let mut stride = 1usize;
    let mut mirrored = 0usize;
    for &extent in shape.iter().rev() {
        let coord = remaining % extent;
        remaining /= extent;
        mirrored += ((extent - coord) % extent) * stride;
        stride *= extent;
    }
    mirrored
}
// Unit tests for the N-D FFT helpers: output shapes, forward/inverse round trips,
// agreement between the 2-D fast path and the generic per-axis path, the real
// transforms, and the normalization modes.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// A 2-D forward transform keeps the element count of the input.
#[test]
fn test_fft2d_shape() {
let rows = 4usize;
let cols = 8usize;
let input: Vec<(f64, f64)> = (0..rows * cols).map(|i| (i as f64, 0.0)).collect();
let out = fftn(&input, &[rows, cols]).expect("fftn failed");
assert_eq!(out.len(), rows * cols);
}
// fftn followed by ifftn (default normalization) recovers a 2-D real input.
#[test]
fn test_fft2d_roundtrip() {
let rows = 4usize;
let cols = 4usize;
let input: Vec<(f64, f64)> = (0..rows * cols).map(|i| (i as f64 * 0.5, 0.0)).collect();
let freq = fftn(&input, &[rows, cols]).expect("fftn failed");
let recovered = ifftn(&freq, &[rows, cols]).expect("ifftn failed");
for (a, b) in input.iter().zip(recovered.iter()) {
assert_relative_eq!(a.0, b.0, epsilon = 1e-9);
assert_relative_eq!(a.1, b.1, epsilon = 1e-9);
}
}
// The 2-D path that fftn takes (tiled_2d_fft) must match applying
// fft_along_axis to axis 0 and then axis 1.
#[test]
fn test_fft2d_tiled_matches_direct() {
let rows = 8usize;
let cols = 8usize;
let input: Vec<(f64, f64)> = (0..rows * cols)
.map(|i| ((i as f64).sin(), (i as f64).cos()))
.collect();
let tiled = fftn(&input, &[rows, cols]).expect("tiled fftn failed");
let mut direct = input.clone();
fft_along_axis(&mut direct, &[rows, cols], 0, false).expect("axis 0 failed");
fft_along_axis(&mut direct, &[rows, cols], 1, false).expect("axis 1 failed");
for (a, b) in tiled.iter().zip(direct.iter()) {
assert_relative_eq!(a.0, b.0, epsilon = 1e-9);
assert_relative_eq!(a.1, b.1, epsilon = 1e-9);
}
}
// A 3-D forward transform (the generic per-axis path) keeps the element count.
#[test]
fn test_fftn_3d_shape() {
let shape = [2usize, 3, 4];
let n: usize = shape.iter().product();
let input: Vec<(f64, f64)> = (0..n).map(|i| (i as f64, 0.0)).collect();
let out = fftn(&input, &shape).expect("fftn 3d failed");
assert_eq!(out.len(), n);
}
// fftn followed by ifftn recovers a 3-D real input.
#[test]
fn test_fftn_3d_roundtrip() {
let shape = [2usize, 4, 4];
let n: usize = shape.iter().product();
let input: Vec<(f64, f64)> = (0..n).map(|i| (i as f64 * 0.1, 0.0)).collect();
let freq = fftn(&input, &shape).expect("fftn failed");
let recovered = ifftn(&freq, &shape).expect("ifftn failed");
for (a, b) in input.iter().zip(recovered.iter()) {
assert_relative_eq!(a.0, b.0, epsilon = 1e-9);
assert_relative_eq!(a.1, b.1, epsilon = 1e-9);
}
}
// rfftn keeps only n / 2 + 1 coefficients along the (single) last axis.
#[test]
fn test_rfftn_real_input() {
let n = 16usize;
let input: Vec<f64> = (0..n).map(|i| i as f64).collect();
let out = rfftn(&input, &[n]).expect("rfftn failed");
assert_eq!(out.len(), n / 2 + 1);
}
// rfftn followed by irfftn recovers a 1-D sine wave.
#[test]
fn test_irfftn_roundtrip() {
let n = 16usize;
let input: Vec<f64> = (0..n)
.map(|i| (i as f64 * 2.0 * PI / n as f64).sin())
.collect();
let spectrum = rfftn(&input, &[n]).expect("rfftn failed");
let recovered = irfftn(&spectrum, &[n]).expect("irfftn failed");
assert_eq!(recovered.len(), n);
for (a, b) in input.iter().zip(recovered.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-9);
}
}
// With Ortho scaling the forward transform preserves total energy (Parseval).
#[test]
fn test_normalization_ortho() {
let n = 8usize;
let input: Vec<(f64, f64)> = (0..n).map(|i| (i as f64, 0.0)).collect();
let out = fftn_norm(&input, &[n], NormMode::Ortho).expect("fftn failed");
let energy_in: f64 = input.iter().map(|&(re, im)| re * re + im * im).sum();
let energy_out: f64 = out.iter().map(|&(re, im)| re * re + im * im).sum();
assert_relative_eq!(energy_in, energy_out, epsilon = 1e-9);
}
// With Forward scaling (1/n on the forward transform), a constant signal of ones
// maps to a DC coefficient of exactly 1 and zeros elsewhere.
#[test]
fn test_normalization_forward() {
let n = 8usize;
let input: Vec<(f64, f64)> = vec![(1.0, 0.0); n];
let out = fftn_norm(&input, &[n], NormMode::Forward).expect("fftn failed");
assert_relative_eq!(out[0].0, 1.0, epsilon = 1e-12);
assert_relative_eq!(out[0].1, 0.0, epsilon = 1e-12);
for &(re, im) in &out[1..] {
assert_relative_eq!(re, 0.0, epsilon = 1e-12);
assert_relative_eq!(im, 0.0, epsilon = 1e-12);
}
}
}