use crate::error::{FFTError, FFTResult};
use crate::ndim_fft::mixed_radix::{fft_1d, ifft_1d_raw};
use crate::ndim_fft::ndim::{apply_normalization, compute_strides, tiled_2d_fft};
use crate::ndim_fft::types::NormMode;
/// Signature shared by the forward (`fft_1d`) and raw inverse (`ifft_1d_raw`)
/// 1-D transforms, so either can be selected once and reused per row.
type FftTransformFn = fn(&[(f64, f64)]) -> Vec<(f64, f64)>;
/// Resolve a requested worker-thread count to an actual count (always >= 1).
///
/// `0` means "auto": use the OS-reported available parallelism, falling back
/// to 1 if the query fails. Any non-zero request is returned unchanged — it
/// is already >= 1, so the former `.max(1)` clamp was dead code.
fn effective_threads(n_threads: usize) -> usize {
    if n_threads == 0 {
        std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(1)
    } else {
        n_threads
    }
}
/// Apply a 1-D FFT (or raw inverse FFT) to every row of a row-major
/// `rows` x `cols` matrix in place, distributing contiguous blocks of whole
/// rows across up to `n_threads` OS threads (`0` = auto-detect).
///
/// * `inverse` selects `ifft_1d_raw` (no internal scaling) over `fft_1d`.
/// * Accepts any mutable slice; `&mut Vec` call sites coerce automatically.
///
/// Fix vs. previous version: an empty matrix (`rows == 0` or `cols == 0`)
/// returns immediately instead of hitting a divide-by-zero panic
/// (`chunk.len() / cols`) inside a worker thread.
pub fn parallel_fft_rows(
    data: &mut [(f64, f64)],
    rows: usize,
    cols: usize,
    n_threads: usize,
    inverse: bool,
) {
    debug_assert_eq!(data.len(), rows * cols);
    // Nothing to transform; also guards the zero-sized chunking below.
    if rows == 0 || cols == 0 {
        return;
    }
    let transform: FftTransformFn = if inverse { ifft_1d_raw } else { fft_1d };
    let threads = effective_threads(n_threads).min(rows);
    if threads <= 1 {
        // Serial path: transform each row and write the result back in place.
        for row in data.chunks_exact_mut(cols) {
            let out = transform(row);
            row.copy_from_slice(&out);
        }
        return;
    }
    // Each worker owns a disjoint block of `chunk_rows` whole rows; the last
    // block may be shorter. `chunks_mut` performs exactly the split the old
    // manual `split_at_mut` loop did.
    let chunk_rows = rows.div_ceil(threads);
    std::thread::scope(|s| {
        for chunk in data.chunks_mut(chunk_rows * cols) {
            s.spawn(move || {
                for row in chunk.chunks_exact_mut(cols) {
                    let out = transform(row);
                    row.copy_from_slice(&out);
                }
            });
        }
    });
}
/// In-place 2-D FFT/IFFT of a row-major `rows` x `cols` matrix.
///
/// Runs a row pass, transposes, runs a second row pass (i.e. the original
/// columns), then transposes back so the result keeps the input layout.
///
/// # Errors
/// Returns `FFTError::DimensionError` when `data.len() != rows * cols`.
pub fn parallel_fft_2d(
    data: &mut Vec<(f64, f64)>,
    rows: usize,
    cols: usize,
    n_threads: usize,
    inverse: bool,
) -> FFTResult<()> {
    use crate::ndim_fft::ndim::in_place_transpose;

    let expected = rows * cols;
    if data.len() != expected {
        return Err(FFTError::DimensionError(format!(
            "parallel_fft_2d: data length {} != {} × {} = {}",
            data.len(),
            rows,
            cols,
            expected
        )));
    }
    // Pass 1: transform along rows.
    parallel_fft_rows(data, rows, cols, n_threads, inverse);
    // Pass 2: transform along columns, expressed as a row pass on the transpose.
    in_place_transpose(data, rows, cols);
    parallel_fft_rows(data, cols, rows, n_threads, inverse);
    in_place_transpose(data, cols, rows);
    Ok(())
}
/// Forward N-dimensional FFT of `input` with dimensions `shape`, using up to
/// `n_threads` worker threads (`0` = auto-detect). No normalization is
/// applied (`NormMode::None`); delegates to [`parallel_fftn_norm`].
pub fn parallel_fftn(
    input: &[(f64, f64)],
    shape: &[usize],
    n_threads: usize,
) -> FFTResult<Vec<(f64, f64)>> {
    parallel_fftn_norm(input, shape, n_threads, NormMode::None, false)
}
/// Inverse N-dimensional FFT of `input` with dimensions `shape`, using up to
/// `n_threads` worker threads (`0` = auto-detect). Delegates to
/// [`parallel_fftn_norm`] with `NormMode::None`; the final scaling (if any)
/// is whatever `apply_normalization` does for `(NormMode::None, inverse=true)`
/// — presumably the conventional 1/N, but confirm against that function.
pub fn parallel_ifftn(
    input: &[(f64, f64)],
    shape: &[usize],
    n_threads: usize,
) -> FFTResult<Vec<(f64, f64)>> {
    parallel_fftn_norm(input, shape, n_threads, NormMode::None, true)
}
/// N-dimensional FFT/IFFT with selectable normalization, parallelized with up
/// to `n_threads` OS threads (`0` = auto-detect).
///
/// Dispatch: 0-D is the identity; 1-D and 2-D use the dedicated fast paths.
/// For `ndim >= 3`, each outer slab's trailing `shape[1..]` dimensions are
/// transformed in parallel via the serial `fftn`/`ifftn` (undoing `ifftn`'s
/// internal 1/N scaling so normalization is applied exactly once at the end),
/// then the leading axis is transformed serially with a gather/scatter pass.
///
/// Fixes vs. previous version:
/// * inner `fftn`/`ifftn` errors were silently swallowed (`if let Ok(..)`),
///   returning untransformed data as success — they are now propagated;
/// * a zero-sized shape (`product == 0`) panicked with a division by zero in
///   the chunking math — it now returns the (empty) data unchanged.
///
/// # Errors
/// Returns `FFTError::DimensionError` when `input.len()` does not match the
/// product of `shape`, and forwards any error from the inner transforms.
pub fn parallel_fftn_norm(
    input: &[(f64, f64)],
    shape: &[usize],
    n_threads: usize,
    norm: NormMode,
    inverse: bool,
) -> FFTResult<Vec<(f64, f64)>> {
    let expected: usize = shape.iter().product();
    if input.len() != expected {
        return Err(FFTError::DimensionError(format!(
            "parallel_fftn: input length {} != shape product {}",
            input.len(),
            expected
        )));
    }
    let ndim = shape.len();
    let threads = effective_threads(n_threads);
    match ndim {
        0 => return Ok(input.to_vec()),
        1 => {
            let mut data = input.to_vec();
            parallel_fft_rows(&mut data, 1, shape[0], threads, inverse);
            apply_normalization(&mut data, expected, norm, inverse);
            return Ok(data);
        }
        2 => {
            let mut data = input.to_vec();
            parallel_fft_2d(&mut data, shape[0], shape[1], threads, inverse)?;
            apply_normalization(&mut data, expected, norm, inverse);
            return Ok(data);
        }
        _ => {}
    }
    let outer = shape[0];
    let inner_shape = &shape[1..];
    let inner_size: usize = inner_shape.iter().product();
    let mut data = input.to_vec();
    // A zero-sized axis means there is nothing to transform; the old chunking
    // math (`div_ceil(0)`, `len / 0`) would panic here.
    if expected == 0 {
        apply_normalization(&mut data, expected, norm, inverse);
        return Ok(data);
    }
    // Phase 1: transform the trailing dimensions of each outer slab in
    // parallel. `ifftn` scales by 1/inner_size internally; undo it so the
    // single `apply_normalization` at the end is the only scaling applied.
    let undo_scale = if inverse { inner_size as f64 } else { 1.0 };
    let chunk_size = outer.div_ceil(threads.min(outer));
    std::thread::scope(|s| -> FFTResult<()> {
        let handles: Vec<_> = data
            .chunks_mut(chunk_size * inner_size)
            .map(|slab| {
                s.spawn(move || -> FFTResult<()> {
                    for slice in slab.chunks_exact_mut(inner_size) {
                        let out = if inverse {
                            crate::ndim_fft::ndim::ifftn(slice, inner_shape)?
                        } else {
                            crate::ndim_fft::ndim::fftn(slice, inner_shape)?
                        };
                        slice.copy_from_slice(&out);
                        if (undo_scale - 1.0).abs() > f64::EPSILON {
                            for x in slice.iter_mut() {
                                x.0 *= undo_scale;
                                x.1 *= undo_scale;
                            }
                        }
                    }
                    Ok(())
                })
            })
            .collect();
        // Join every worker: forward the first transform error; re-raise
        // panics instead of masking them behind a synthetic error.
        for handle in handles {
            match handle.join() {
                Ok(result) => result?,
                Err(payload) => std::panic::resume_unwind(payload),
            }
        }
        Ok(())
    })?;
    // Phase 2 (serial): transform the leading axis. Each "line" consists of
    // `outer` elements separated by the axis-0 stride; gather, transform,
    // scatter back.
    let strides = compute_strides(shape);
    let axis_stride = strides[0];
    for inner_idx in 0..inner_size {
        let line: Vec<(f64, f64)> = (0..outer)
            .map(|k| data[k * axis_stride + inner_idx])
            .collect();
        let out = if inverse { ifft_1d_raw(&line) } else { fft_1d(&line) };
        for (k, v) in out.into_iter().enumerate() {
            data[k * axis_stride + inner_idx] = v;
        }
    }
    apply_normalization(&mut data, expected, norm, inverse);
    Ok(data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndim_fft::ndim::fftn;
    use approx::assert_relative_eq;

    /// The threaded N-D entry point must agree with the serial `fftn`
    /// reference on a small 2-D grid.
    #[test]
    fn test_parallel_fft_matches_serial() {
        const ROWS: usize = 8;
        const COLS: usize = 8;
        let input: Vec<(f64, f64)> = (0..ROWS * COLS)
            .map(|i| ((i as f64 * 0.3).sin(), (i as f64 * 0.2).cos()))
            .collect();
        let expected = fftn(&input, &[ROWS, COLS]).expect("serial fftn failed");
        let actual = parallel_fftn(&input, &[ROWS, COLS], 4).expect("parallel fftn failed");
        assert_eq!(expected.len(), actual.len());
        for (e, a) in expected.iter().zip(&actual) {
            assert_relative_eq!(e.0, a.0, epsilon = 1e-9);
            assert_relative_eq!(e.1, a.1, epsilon = 1e-9);
        }
    }
}