use num_complex::Complex;
use num_traits::{One, Zero};
use rayon::prelude::*;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::axes::compute_strides;
use crate::float::FftFloat;
use crate::norm::FftNorm;
/// Complex FFT of every 1-D lane of an N-D array that runs along `axis`.
///
/// `data` is addressed through the strides returned by `compute_strides(shape)`
/// (the tests in this file use row-major layout). Every lane is cut or
/// zero-padded to `n` samples (defaulting to `shape[axis]`), transformed
/// forward — or backward when `inverse` is true — and multiplied by the factor
/// selected by `norm`.
///
/// Returns `(new_shape, values)` where `new_shape` equals `shape` except that
/// `new_shape[axis]` is the transform length.
///
/// # Errors
/// * `axis` is not a valid dimension of `shape`.
/// * The requested transform length is zero.
pub(crate) fn fft_along_axis<T: FftFloat>(
    data: &[Complex<T>],
    shape: &[usize],
    axis: usize,
    n: Option<usize>,
    inverse: bool,
    norm: FftNorm,
) -> FerrayResult<(Vec<usize>, Vec<Complex<T>>)>
where
    Complex<T>: ferray_core::Element,
{
    let rank = shape.len();
    if axis >= rank {
        return Err(FerrayError::axis_out_of_bounds(axis, rank));
    }
    let in_len = shape[axis];
    let out_len = match n {
        Some(len) => len,
        None => in_len,
    };
    if out_len == 0 {
        return Err(FerrayError::invalid_value("FFT length must be > 0"));
    }

    let mut out_shape = shape.to_vec();
    out_shape[axis] = out_len;
    let out_count: usize = out_shape.iter().product();

    let in_count: usize = shape.iter().product();
    if in_count == 0 {
        // Nothing to transform, but the caller still gets the resized shape.
        return Ok((out_shape, vec![Complex::zero(); out_count]));
    }
    if rank == 1 {
        // A 1-D array is a single contiguous lane; skip the lane bookkeeping.
        return fft_1d_fast(data, out_len, in_len, inverse, norm);
    }

    // `in_len > 0` here because `in_count > 0`.
    let lane_count = in_count / in_len;
    let in_strides = compute_strides(shape);
    let out_strides = compute_strides(&out_shape);
    let lane_starts = compute_lane_starts(shape, &in_strides, axis, lane_count);

    let plan = T::cached_plan(out_len, inverse);
    let direction = if inverse {
        crate::norm::FftDirection::Inverse
    } else {
        crate::norm::FftDirection::Forward
    };
    let factor = T::scale_factor(norm, out_len, direction);
    let needs_scaling = factor != <T as One>::one();
    let scratch_len = plan.get_inplace_scratch_len();
    let in_step = in_strides[axis] as usize;
    let kept = in_len.min(out_len);

    // Gather each strided lane into its own contiguous slot (the zero-filled
    // tail provides the padding), then transform it in place in parallel.
    let mut lanes: Vec<Complex<T>> = vec![Complex::zero(); lane_count * out_len];
    lanes
        .par_chunks_mut(out_len)
        .zip(lane_starts.par_iter())
        .for_each_init(
            || vec![Complex::zero(); scratch_len],
            |scratch, (lane, &base)| {
                lane.iter_mut()
                    .zip((0..kept).map(|k| data[base + k * in_step]))
                    .for_each(|(dst, src)| *dst = src);
                plan.process_with_scratch(lane, scratch);
                if needs_scaling {
                    lane.iter_mut().for_each(|c| *c = *c * factor);
                }
            },
        );

    // Scatter the contiguous lanes back into the strided output layout.
    let mut output: Vec<Complex<T>> = vec![Complex::zero(); out_count];
    let out_step = out_strides[axis] as usize;
    for (lane, &base) in lanes.chunks(out_len).zip(lane_starts.iter()) {
        let dest = compute_lane_output_start(&out_shape, &out_strides, axis, base, &in_strides);
        for (k, &value) in lane.iter().enumerate() {
            output[dest + k * out_step] = value;
        }
    }
    Ok((out_shape, output))
}
/// Complex FFT of a single contiguous 1-D signal.
///
/// The first `min(input_len, fft_len)` samples of `data` are used. The signal
/// is zero-padded to `fft_len`, transformed (inverse when `inverse` is set) and
/// scaled according to `norm`. Returns shape `[fft_len]` and the spectrum.
fn fft_1d_fast<T: FftFloat>(
    data: &[Complex<T>],
    fft_len: usize,
    input_len: usize,
    inverse: bool,
    norm: FftNorm,
) -> FerrayResult<(Vec<usize>, Vec<Complex<T>>)>
where
    Complex<T>: ferray_core::Element,
{
    // Start from an all-zero buffer so any tail beyond the input is padding.
    let mut signal: Vec<Complex<T>> = vec![Complex::zero(); fft_len];
    let kept = fft_len.min(input_len);
    signal[..kept].copy_from_slice(&data[..kept]);

    let plan = T::cached_plan(fft_len, inverse);
    let mut work: Vec<Complex<T>> = vec![Complex::zero(); plan.get_inplace_scratch_len()];
    plan.process_with_scratch(&mut signal, &mut work);

    let factor = T::scale_factor(
        norm,
        fft_len,
        if inverse {
            crate::norm::FftDirection::Inverse
        } else {
            crate::norm::FftDirection::Forward
        },
    );
    if factor != <T as One>::one() {
        signal.iter_mut().for_each(|c| *c = *c * factor);
    }
    Ok((vec![fft_len], signal))
}
/// Applies [`fft_along_axis`] once per `(axis, n)` pair in `axes_and_sizes`,
/// in order, feeding each pass's output and resized shape into the next.
///
/// With an empty `axes_and_sizes`, the input is returned unchanged (as a copy)
/// together with `shape`.
///
/// # Errors
/// Propagates the first error returned by [`fft_along_axis`].
pub(crate) fn fft_along_axes<T: FftFloat>(
    data: &[Complex<T>],
    shape: &[usize],
    axes_and_sizes: &[(usize, Option<usize>)],
    inverse: bool,
    norm: FftNorm,
) -> FerrayResult<(Vec<usize>, Vec<Complex<T>>)>
where
    Complex<T>: ferray_core::Element,
{
    let mut current_shape = shape.to_vec();
    // `None` until the first pass has run: that pass reads `data` directly, so
    // the input is never copied just to be transformed and discarded.
    let mut current_data: Option<Vec<Complex<T>>> = None;
    for &(axis, n) in axes_and_sizes {
        let input = current_data.as_deref().unwrap_or(data);
        let (new_shape, new_data) =
            fft_along_axis(input, &current_shape, axis, n, inverse, norm)?;
        current_shape = new_shape;
        current_data = Some(new_data);
    }
    let output = current_data.unwrap_or_else(|| data.to_vec());
    Ok((current_shape, output))
}
/// Real-input forward FFT of every 1-D lane of an N-D array along `axis`.
///
/// Each lane is cut or zero-padded to `n` samples (defaulting to
/// `shape[axis]`) and transformed with a real-to-complex plan. Only the
/// `n / 2 + 1` non-negative-frequency bins are kept. Results are scaled
/// according to `norm`.
///
/// Returns `(new_shape, spectrum)` where `new_shape[axis] == n / 2 + 1`.
///
/// # Errors
/// * `axis` is not a valid dimension of `shape`.
/// * The requested transform length is zero.
/// * The real-FFT plan reports a failure. Both the 1-D path and the
///   multi-lane path surface this as an `Err` rather than a panic.
pub(crate) fn rfft_along_axis<T: FftFloat>(
    data: &[T],
    shape: &[usize],
    axis: usize,
    n: Option<usize>,
    norm: FftNorm,
) -> FerrayResult<(Vec<usize>, Vec<Complex<T>>)>
where
    Complex<T>: ferray_core::Element,
{
    let ndim = shape.len();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }
    let axis_len = shape[axis];
    let fft_len = n.unwrap_or(axis_len);
    if fft_len == 0 {
        return Err(FerrayError::invalid_value("FFT length must be > 0"));
    }
    // Bins 0..=fft_len/2 of a real signal's spectrum; the rest are conjugates.
    let half_len = fft_len / 2 + 1;
    let mut new_shape = shape.to_vec();
    new_shape[axis] = half_len;
    let new_total: usize = new_shape.iter().product();
    let total = shape.iter().product::<usize>();
    if total == 0 {
        return Ok((new_shape, vec![Complex::zero(); new_total]));
    }
    let plan = T::cached_real_forward(fft_len);
    let scratch_len = plan.get_scratch_len();
    let scale = T::scale_factor(norm, fft_len, crate::norm::FftDirection::Forward);
    let one = <T as One>::one();
    let t_zero = <T as Zero>::zero();
    let strides = compute_strides(shape);
    let new_strides = compute_strides(&new_shape);
    let stride = strides[axis] as usize;
    let copy_len = axis_len.min(fft_len);
    if ndim == 1 {
        let mut input_buf: Vec<T> = vec![t_zero; fft_len];
        input_buf[..copy_len].copy_from_slice(&data[..copy_len]);
        let mut output_buf: Vec<Complex<T>> = vec![Complex::zero(); half_len];
        let mut scratch = plan.make_scratch_vec();
        plan.process_with_scratch(&mut input_buf, &mut output_buf, &mut scratch)
            .map_err(|e| FerrayError::invalid_value(format!("real FFT process failed: {e}")))?;
        if scale != one {
            for c in &mut output_buf {
                *c = *c * scale;
            }
        }
        return Ok((new_shape, output_buf));
    }
    // `axis_len > 0` here because `total > 0`.
    let num_lanes = total / axis_len;
    let lane_starts = compute_lane_starts(shape, &strides, axis, num_lanes);
    let mut lane_outputs: Vec<Complex<T>> = vec![Complex::zero(); num_lanes * half_len];
    lane_outputs
        .par_chunks_mut(half_len)
        .zip(lane_starts.par_iter())
        .try_for_each_init(
            || (vec![t_zero; fft_len], vec![Complex::zero(); scratch_len]),
            |(input_buf, scratch), (out_chunk, &start_offset)| -> Result<(), String> {
                // The per-thread input buffer is reused across lanes and may be
                // overwritten by the plan, so refill it completely every time.
                for (i, slot) in input_buf.iter_mut().take(copy_len).enumerate() {
                    *slot = data[start_offset + i * stride];
                }
                for slot in input_buf.iter_mut().skip(copy_len) {
                    *slot = t_zero;
                }
                plan.process_with_scratch(input_buf, out_chunk, scratch)
                    .map_err(|e| e.to_string())?;
                if scale != one {
                    for c in out_chunk.iter_mut() {
                        *c = *c * scale;
                    }
                }
                Ok(())
            },
        )
        .map_err(|e| FerrayError::invalid_value(format!("real FFT process failed: {e}")))?;
    // Scatter the contiguous per-lane spectra into the strided output layout.
    let mut output: Vec<Complex<T>> = vec![Complex::zero(); new_total];
    let out_stride = new_strides[axis] as usize;
    for (lane_idx, lane_chunk) in lane_outputs.chunks(half_len).enumerate() {
        let out_start = compute_lane_output_start(
            &new_shape,
            &new_strides,
            axis,
            lane_starts[lane_idx],
            &strides,
        );
        for (i, &val) in lane_chunk.iter().enumerate() {
            output[out_start + i * out_stride] = val;
        }
    }
    Ok((new_shape, output))
}
/// Inverse real FFT of every 1-D lane of an N-D array along `axis`.
///
/// Each lane holds a half spectrum. The first `output_len / 2 + 1` bins are
/// used; the lane is truncated or zero-padded to that length. A complex-to-real
/// plan turns the bins into `output_len` real samples, which are scaled
/// according to `norm`.
///
/// For a real-valued output, the DC bin (and, for even `output_len`, the
/// Nyquist bin) must be real. Their imaginary parts are discarded before
/// transforming, as NumPy's `irfft` does. Otherwise the real-inverse plan
/// reports them as an error, and round-off alone (e.g. after a complex ifft
/// over the other axes) can trigger that.
///
/// Returns `(new_shape, samples)` where `new_shape[axis] == output_len`.
///
/// # Errors
/// * `axis` is not a valid dimension of `shape`.
/// * `output_len` is zero.
/// * The real-inverse plan reports a failure (in both the 1-D and the
///   multi-lane path).
pub(crate) fn irfft_along_axis<T: FftFloat>(
    data: &[Complex<T>],
    shape: &[usize],
    axis: usize,
    output_len: usize,
    norm: FftNorm,
) -> FerrayResult<(Vec<usize>, Vec<T>)>
where
    Complex<T>: ferray_core::Element,
{
    let ndim = shape.len();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }
    if output_len == 0 {
        return Err(FerrayError::invalid_value(
            "irfft output length must be > 0",
        ));
    }
    let half_len = output_len / 2 + 1;
    // For even lengths the last kept bin (index half_len - 1) is the Nyquist bin.
    let even_len = output_len % 2 == 0;
    let input_axis_len = shape[axis];
    let mut new_shape = shape.to_vec();
    new_shape[axis] = output_len;
    let new_total: usize = new_shape.iter().product();
    let total = shape.iter().product::<usize>();
    let t_zero = <T as Zero>::zero();
    if total == 0 {
        return Ok((new_shape, vec![t_zero; new_total]));
    }
    let plan = T::cached_real_inverse(output_len);
    let scratch_len = plan.get_scratch_len();
    let scale = T::scale_factor(norm, output_len, crate::norm::FftDirection::Inverse);
    let one = <T as One>::one();
    let strides = compute_strides(shape);
    let new_strides = compute_strides(&new_shape);
    let stride = strides[axis] as usize;
    let copy_len = input_axis_len.min(half_len);
    if ndim == 1 {
        let mut input_buf: Vec<Complex<T>> = vec![Complex::zero(); half_len];
        input_buf[..copy_len].copy_from_slice(&data[..copy_len]);
        // Discard the imaginary parts that cannot contribute to a real output.
        input_buf[0].im = t_zero;
        if even_len {
            input_buf[half_len - 1].im = t_zero;
        }
        let mut output_buf: Vec<T> = vec![t_zero; output_len];
        let mut scratch = plan.make_scratch_vec();
        plan.process_with_scratch(&mut input_buf, &mut output_buf, &mut scratch)
            .map_err(|e| {
                FerrayError::invalid_value(format!("inverse real FFT process failed: {e}"))
            })?;
        if scale != one {
            for v in &mut output_buf {
                *v = *v * scale;
            }
        }
        return Ok((new_shape, output_buf));
    }
    // `input_axis_len > 0` here because `total > 0`.
    let num_lanes = total / input_axis_len;
    let lane_starts = compute_lane_starts(shape, &strides, axis, num_lanes);
    let mut lane_outputs: Vec<T> = vec![t_zero; num_lanes * output_len];
    lane_outputs
        .par_chunks_mut(output_len)
        .zip(lane_starts.par_iter())
        .try_for_each_init(
            || {
                (
                    vec![Complex::zero(); half_len],
                    vec![Complex::zero(); scratch_len],
                )
            },
            |(input_buf, scratch), (out_chunk, &start_offset)| -> Result<(), String> {
                // The per-thread input buffer is reused across lanes and may be
                // overwritten by the plan, so refill it completely every time.
                for (i, slot) in input_buf.iter_mut().take(copy_len).enumerate() {
                    *slot = data[start_offset + i * stride];
                }
                for slot in input_buf.iter_mut().skip(copy_len) {
                    *slot = Complex::zero();
                }
                // Discard the imaginary parts that cannot contribute to a real output.
                input_buf[0].im = t_zero;
                if even_len {
                    input_buf[half_len - 1].im = t_zero;
                }
                plan.process_with_scratch(input_buf, out_chunk, scratch)
                    .map_err(|e| e.to_string())?;
                if scale != one {
                    for v in out_chunk.iter_mut() {
                        *v = *v * scale;
                    }
                }
                Ok(())
            },
        )
        .map_err(|e| FerrayError::invalid_value(format!("inverse real FFT process failed: {e}")))?;
    // Scatter the contiguous per-lane signals into the strided output layout.
    let mut output: Vec<T> = vec![t_zero; new_total];
    let out_stride = new_strides[axis] as usize;
    for (lane_idx, lane_chunk) in lane_outputs.chunks(output_len).enumerate() {
        let out_start = compute_lane_output_start(
            &new_shape,
            &new_strides,
            axis,
            lane_starts[lane_idx],
            &strides,
        );
        for (i, &val) in lane_chunk.iter().enumerate() {
            output[out_start + i * out_stride] = val;
        }
    }
    Ok((new_shape, output))
}
/// Flat offset of the first element of every lane along `axis`.
///
/// The lanes are enumerated by a mixed-radix counter over the remaining
/// dimensions, with the last non-`axis` dimension varying fastest. Each
/// coordinate is multiplied by the matching entry of `strides`. `num_lanes`
/// must equal the product of the non-`axis` extents (checked in debug builds).
fn compute_lane_starts(
    shape: &[usize],
    strides: &[isize],
    axis: usize,
    num_lanes: usize,
) -> Vec<usize> {
    // (extent, stride) for every dimension except `axis`, in dimension order.
    let mut outer: Vec<(usize, isize)> = Vec::with_capacity(shape.len() - 1);
    outer.extend(
        shape
            .iter()
            .copied()
            .zip(strides.iter().copied())
            .enumerate()
            .filter(|&(dim, _)| dim != axis)
            .map(|(_, pair)| pair),
    );
    debug_assert_eq!(
        outer.iter().map(|&(extent, _)| extent).product::<usize>(),
        num_lanes
    );

    (0..num_lanes)
        .map(|lane| {
            // Peel digits off `lane` starting from the fastest-varying dimension.
            outer
                .iter()
                .rev()
                .fold((lane, 0usize), |(rest, start), &(extent, step)| {
                    (rest / extent, start + (rest % extent) * step as usize)
                })
                .1
        })
        .collect()
}
/// Maps the flat start offset of an input lane to the start offset of the
/// corresponding lane in the resized output array.
///
/// `input_start` is decomposed into per-dimension coordinates by greedy
/// division with `input_strides` (dimensions visited in order, `axis` skipped).
/// This assumes strides decrease with the dimension index, as the row-major
/// layout used in this file's tests does. Those coordinates are then
/// re-encoded with `new_strides`. Dimensions whose input stride is 0
/// contribute nothing.
///
/// Decoding and re-encoding happen in a single pass, so no per-call heap
/// allocation is needed. This function runs once per lane.
fn compute_lane_output_start(
    new_shape: &[usize],
    new_strides: &[isize],
    axis: usize,
    input_start: usize,
    input_strides: &[isize],
) -> usize {
    let mut remaining = input_start as isize;
    let mut offset = 0usize;
    for d in (0..new_shape.len()).filter(|&d| d != axis) {
        let in_stride = input_strides[d];
        if in_stride == 0 {
            continue;
        }
        let idx = remaining / in_stride;
        remaining -= idx * in_stride;
        offset += (idx as usize) * new_strides[d] as usize;
    }
    offset
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fft_1d_simple() {
        let data = vec![
            Complex::<f64>::new(1.0, 0.0),
            Complex::<f64>::new(0.0, 0.0),
            Complex::<f64>::new(0.0, 0.0),
            Complex::<f64>::new(0.0, 0.0),
        ];
        let (shape, result) =
            fft_along_axis(&data, &[4], 0, None, false, FftNorm::Backward).unwrap();
        assert_eq!(shape, vec![4]);
        for c in &result {
            assert!((c.re - 1.0).abs() < 1e-12);
            assert!(c.im.abs() < 1e-12);
        }
    }

    #[test]
    fn fft_2d_along_axis0() {
        let data = vec![
            Complex::<f64>::new(1.0, 0.0),
            Complex::<f64>::new(0.0, 0.0),
            Complex::<f64>::new(0.0, 0.0),
            Complex::<f64>::new(1.0, 0.0),
        ];
        let (shape, result) =
            fft_along_axis(&data, &[2, 2], 0, None, false, FftNorm::Backward).unwrap();
        assert_eq!(shape, vec![2, 2]);
        // Columns [1, 0] and [0, 1] transform to [1, 1] and [1, -1].
        assert!((result[0].re - 1.0).abs() < 1e-12);
        assert!((result[1].re - 1.0).abs() < 1e-12);
        assert!((result[2].re - 1.0).abs() < 1e-12);
        assert!((result[3].re - (-1.0)).abs() < 1e-12);
    }

    #[test]
    fn fft_axis_out_of_bounds() {
        let data = vec![Complex::<f64>::new(1.0, 0.0)];
        assert!(fft_along_axis(&data, &[1], 1, None, false, FftNorm::Backward).is_err());
    }

    #[test]
    fn fft_with_zero_padding() {
        let data = vec![Complex::<f64>::new(1.0, 0.0), Complex::<f64>::new(1.0, 0.0)];
        let (shape, result) =
            fft_along_axis(&data, &[2], 0, Some(4), false, FftNorm::Backward).unwrap();
        assert_eq!(shape, vec![4]);
        assert_eq!(result.len(), 4);
        assert!((result[0].re - 2.0).abs() < 1e-12);
    }

    #[test]
    fn fft_2d_many_lanes_along_axis1() {
        let rows = 4usize;
        let cols = 8usize;
        let mut data = vec![Complex::<f64>::new(0.0, 0.0); rows * cols];
        for r in 0..rows {
            data[r * cols + r] = Complex::<f64>::new(1.0, 0.0);
        }
        let (shape, result) =
            fft_along_axis(&data, &[rows, cols], 1, None, false, FftNorm::Backward).unwrap();
        assert_eq!(shape, vec![rows, cols]);
        for c in result.iter().take(cols) {
            assert!((c.re - 1.0).abs() < 1e-12);
            assert!(c.im.abs() < 1e-12);
        }
        for r in 0..rows {
            assert!((result[r * cols].re - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn fft_2d_many_lanes_along_axis0() {
        let rows = 8usize;
        let cols = 4usize;
        let mut data = vec![Complex::<f64>::new(0.0, 0.0); rows * cols];
        for slot in data.iter_mut().take(cols) {
            *slot = Complex::<f64>::new(1.0, 0.0);
        }
        let (shape, result) =
            fft_along_axis(&data, &[rows, cols], 0, None, false, FftNorm::Backward).unwrap();
        assert_eq!(shape, vec![rows, cols]);
        for r in 0..rows {
            for c in 0..cols {
                let idx = r * cols + c;
                assert!(
                    (result[idx].re - 1.0).abs() < 1e-12,
                    "result[{r},{c}].re = {}",
                    result[idx].re
                );
                assert!(result[idx].im.abs() < 1e-12);
            }
        }
    }

    #[test]
    fn fft_2d_zero_padding_multi_lane() {
        let data = vec![
            Complex::<f64>::new(1.0, 0.0),
            Complex::<f64>::new(1.0, 0.0),
            Complex::<f64>::new(2.0, 0.0),
            Complex::<f64>::new(0.0, 0.0),
            Complex::<f64>::new(0.0, 0.0),
            Complex::<f64>::new(3.0, 0.0),
        ];
        let (shape, result) =
            fft_along_axis(&data, &[3, 2], 1, Some(4), false, FftNorm::Backward).unwrap();
        assert_eq!(shape, vec![3, 4]);
        assert_eq!(result.len(), 12);
        assert!((result[0].re - 2.0).abs() < 1e-12);
        assert!((result[4].re - 2.0).abs() < 1e-12);
        assert!((result[5].re - 2.0).abs() < 1e-12);
        assert!((result[6].re - 2.0).abs() < 1e-12);
        assert!((result[7].re - 2.0).abs() < 1e-12);
        assert!((result[8].re - 3.0).abs() < 1e-12);
    }

    #[test]
    fn fft_3d_along_middle_axis() {
        // Shape [2, 3, 2]: lane (i, :, k) holds an impulse of height
        // 1 + 2*i + k at j = 0, whose FFT along axis 1 is that constant.
        let shape = [2usize, 3, 2];
        let mut data = vec![Complex::<f64>::new(0.0, 0.0); 12];
        for i in 0..2 {
            for k in 0..2 {
                data[i * 6 + k] = Complex::<f64>::new((1 + 2 * i + k) as f64, 0.0);
            }
        }
        let (out_shape, result) =
            fft_along_axis(&data, &shape, 1, None, false, FftNorm::Backward).unwrap();
        assert_eq!(out_shape, vec![2, 3, 2]);
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..2 {
                    let expected = (1 + 2 * i + k) as f64;
                    let c = result[i * 6 + j * 2 + k];
                    assert!((c.re - expected).abs() < 1e-12, "[{i},{j},{k}] = {c}");
                    assert!(c.im.abs() < 1e-12);
                }
            }
        }
    }

    #[test]
    fn lane_starts_3d_middle_axis() {
        let starts = compute_lane_starts(&[2, 3, 4], &[12, 4, 1], 1, 8);
        assert_eq!(starts, vec![0, 1, 2, 3, 12, 13, 14, 15]);
    }

    #[test]
    fn lane_output_start_maps_into_resized_shape() {
        // Input [2, 3, 4] -> output [2, 5, 4] along axis 1.
        assert_eq!(
            compute_lane_output_start(&[2, 5, 4], &[20, 4, 1], 1, 15, &[12, 4, 1]),
            23
        );
    }

    #[test]
    fn rfft_2d_along_last_axis() {
        let data = vec![1.0f64, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0];
        let (shape, result) =
            rfft_along_axis(&data, &[2, 4], 1, None, FftNorm::Backward).unwrap();
        assert_eq!(shape, vec![2, 3]);
        let expected = [4.0, 0.0, 0.0, 8.0, 0.0, 0.0];
        assert_eq!(result.len(), expected.len());
        for (c, &e) in result.iter().zip(expected.iter()) {
            assert!((c.re - e).abs() < 1e-12);
            assert!(c.im.abs() < 1e-12);
        }
    }

    #[test]
    fn irfft_2d_along_last_axis() {
        let zero = Complex::<f64>::new(0.0, 0.0);
        let data = vec![
            Complex::<f64>::new(4.0, 0.0),
            zero,
            zero,
            Complex::<f64>::new(8.0, 0.0),
            zero,
            zero,
        ];
        let (shape, result) =
            irfft_along_axis(&data, &[2, 3], 1, 4, FftNorm::Backward).unwrap();
        assert_eq!(shape, vec![2, 4]);
        let expected = [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0];
        assert_eq!(result.len(), expected.len());
        for (&v, &e) in result.iter().zip(expected.iter()) {
            assert!((v - e).abs() < 1e-12);
        }
    }
}