use scirs2_core::ndarray::{Array, ArrayView, Axis, Dimension};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use scirs2_core::parallel_ops::*;
use std::cmp::min;
use crate::error::{FFTError, FFTResult};
use crate::fft::fft;
use crate::rfft::rfft;
/// Computes the N-dimensional FFT of `x` along the requested axes.
///
/// Axes are transformed in order of increasing row-major stride (innermost axis
/// first), as decided by `optimize_axis_order` from the array shape.
///
/// # Arguments
///
/// * `x` - Input array; every element is converted to `f64` and used as the real part.
/// * `_shape` - Currently ignored: no padding or truncation is performed.
/// * `axes` - Axes to transform; `None` transforms every axis.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if an axis is out of bounds or an element cannot
/// be converted to `f64`. Errors from the 1-D `fft` are propagated.
#[allow(dead_code)]
pub fn fftn_optimized<T, D>(
    x: &ArrayView<T, D>,
    _shape: Option<Vec<usize>>,
    axes: Option<Vec<usize>>,
) -> FFTResult<Array<Complex64, D>>
where
    T: NumCast + Copy + Send + Sync,
    D: Dimension,
{
    let ndim = x.ndim();
    // Validate before the O(n) conversion so bad arguments fail fast.
    let axes_to_transform = match axes {
        Some(a) => {
            validate_axes(&a, ndim)?;
            a
        }
        None => (0..ndim).collect(),
    };

    // Convert into the complex working array. A failed conversion is reported as
    // an error instead of panicking inside the element loop.
    let mut result = Array::zeros(x.raw_dim());
    let mut conversion_failed = false;
    scirs2_core::ndarray::Zip::from(&mut result)
        .and(x)
        .for_each(|dst, &src| match <f64 as NumCast>::from(src) {
            Some(re) => *dst = Complex64::new(re, 0.0),
            None => conversion_failed = true,
        });
    if conversion_failed {
        return Err(FFTError::ValueError(
            "Failed to convert input to complex".to_string(),
        ));
    }

    for axis in optimize_axis_order(&axes_to_transform, result.shape()) {
        apply_fft_along_axis(&mut result, axis)?;
    }
    Ok(result)
}
/// Replaces every 1-D lane of `data` along `axis` with its FFT, in place.
///
/// One scratch buffer of length `data.shape()[axis]` is reused for all lanes.
///
/// # Errors
///
/// Propagates the first error returned by `fft`.
///
/// # Panics
///
/// Panics if `axis` is not a valid axis of `data`.
#[allow(dead_code)]
fn apply_fft_along_axis<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
where
    D: Dimension,
{
    let mut scratch = vec![Complex64::new(0.0, 0.0); data.shape()[axis]];
    for mut lane in data.lanes_mut(Axis(axis)) {
        for (slot, &value) in scratch.iter_mut().zip(lane.iter()) {
            *slot = value;
        }
        let spectrum = fft(&scratch, None)?;
        for (out, value) in lane.iter_mut().zip(spectrum) {
            *out = value;
        }
    }
    Ok(())
}
/// Orders `axes` by increasing row-major stride, so the innermost axis comes first.
///
/// The stride of an axis is the product of all dimensions after it in `shape`,
/// computed from `shape` alone; the array's actual memory strides are not
/// consulted. Axes with equal strides keep their input order (the sort is stable).
///
/// # Panics
///
/// Panics if any entry of `axes` is not a valid index into `shape`.
#[allow(dead_code)]
fn optimize_axis_order(axes: &[usize], shape: &[usize]) -> Vec<usize> {
    let stride_of = |axis: usize| -> usize {
        // `shape[axis..]` is empty when `axis == shape.len()` and panics beyond it,
        // so the `[1..]` slice fails for every out-of-bounds axis.
        shape[axis..][1..].iter().product()
    };
    let mut ordered = axes.to_vec();
    ordered.sort_by_cached_key(|&axis| stride_of(axis));
    ordered
}
/// Checks that every entry of `axes` is a valid axis index for an array with
/// `ndim` dimensions.
///
/// # Errors
///
/// Returns `FFTError::ValueError` naming the first out-of-bounds axis.
#[allow(dead_code)]
fn validate_axes(axes: &[usize], ndim: usize) -> FFTResult<()> {
    match axes.iter().copied().find(|&candidate| candidate >= ndim) {
        None => Ok(()),
        Some(axis) => Err(FFTError::ValueError(format!(
            "Axis {axis} is out of bounds for array with {ndim} dimensions"
        ))),
    }
}
/// Decides whether the lanes of an array should be transformed in parallel.
///
/// Returns `true` only when the array holds more than `MIN_PARALLEL_SIZE`
/// elements *and* the transformed axis is longer than `MIN_PARALLEL_AXIS_LEN`.
/// Anything smaller stays on the serial path.
///
/// # Arguments
///
/// * `data_size` - Total number of elements in the array.
/// * `axis_len` - Length of the axis being transformed.
#[allow(dead_code)]
fn should_parallelize(data_size: usize, axis_len: usize) -> bool {
    const MIN_PARALLEL_SIZE: usize = 10000;
    const MIN_PARALLEL_AXIS_LEN: usize = 64;
    data_size > MIN_PARALLEL_SIZE && axis_len > MIN_PARALLEL_AXIS_LEN
}
/// Replaces every lane of `data` along `axis` with its FFT, using rayon-style
/// parallel iteration when `should_parallelize` approves the problem size.
///
/// Smaller problems fall back to the serial `apply_fft_along_axis`.
///
/// # Errors
///
/// Propagates an error returned by `fft` for any lane.
#[cfg(feature = "parallel")]
#[allow(dead_code)]
fn apply_fft_parallel<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
where
    D: Dimension,
{
    let lane_len = data.shape()[axis];
    let element_count: usize = data.shape().iter().product();
    if !should_parallelize(element_count, lane_len) {
        return apply_fft_along_axis(data, axis);
    }

    // Collect the mutable lane views up front so they can be handed to the
    // parallel iterator; each view owns a disjoint slice of `data`.
    let mut views: Vec<_> = data.lanes_mut(Axis(axis)).into_iter().collect();
    views.par_iter_mut().try_for_each(|view| {
        let samples: Vec<Complex64> = view.to_vec();
        let spectrum = fft(&samples, None)?;
        for (slot, value) in view.iter_mut().zip(spectrum) {
            *slot = value;
        }
        Ok(())
    })
}
/// Computes the N-dimensional FFT of `x`, routing very long axes through
/// `apply_fft_chunked`.
///
/// Axes are transformed in the order given (or `0..ndim` when `axes` is `None`).
///
/// # Arguments
///
/// * `x` - Input array; every element is converted to `f64` and used as the real part.
/// * `axes` - Axes to transform; `None` transforms every axis.
/// * `_max_memory_gb` - Currently ignored: no memory budget is enforced.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if an axis is out of bounds or an element cannot
/// be converted to `f64`. Errors from the 1-D transforms are propagated.
#[allow(dead_code)]
pub fn fftn_memory_efficient<T, D>(
    x: &ArrayView<T, D>,
    axes: Option<Vec<usize>>,
    _max_memory_gb: f64,
) -> FFTResult<Array<Complex64, D>>
where
    T: NumCast + Copy + Send + Sync,
    D: Dimension,
{
    // Axes strictly longer than this (2^20 elements) take the long-axis path.
    const LONG_AXIS_THRESHOLD: usize = 1 << 20;

    let ndim = x.ndim();
    let axes_to_transform = match axes {
        Some(a) => {
            validate_axes(&a, ndim)?;
            a
        }
        None => (0..ndim).collect(),
    };

    // Convert into the complex working array. A failed conversion is reported as
    // an error instead of panicking inside the element loop.
    let mut result = Array::zeros(x.raw_dim());
    let mut conversion_failed = false;
    scirs2_core::ndarray::Zip::from(&mut result)
        .and(x)
        .for_each(|dst, &src| match <f64 as NumCast>::from(src) {
            Some(re) => *dst = Complex64::new(re, 0.0),
            None => conversion_failed = true,
        });
    if conversion_failed {
        return Err(FFTError::ValueError(
            "Failed to convert input to complex".to_string(),
        ));
    }

    for &axis in &axes_to_transform {
        if result.shape()[axis] > LONG_AXIS_THRESHOLD {
            apply_fft_chunked(&mut result, axis)?;
        } else {
            apply_fft_along_axis(&mut result, axis)?;
        }
    }
    Ok(result)
}
#[allow(dead_code)]
fn apply_fft_chunked<D>(data: &mut Array<Complex64, D>, axis: usize) -> FFTResult<()>
where
D: Dimension,
{
let axis_len = data.shape()[axis];
const CHUNK_SIZE: usize = 65536;
let n_chunks = axis_len.div_ceil(CHUNK_SIZE);
for chunk_idx in 0..n_chunks {
let start = chunk_idx * CHUNK_SIZE;
let end = min(start + CHUNK_SIZE, axis_len);
let chunk_len = end - start;
let mut buffer = vec![Complex64::new(0.0, 0.0); chunk_len];
for mut lane in data.lanes_mut(Axis(axis)) {
buffer
.iter_mut()
.zip(lane.slice_axis(Axis(0), (start..end).into()).iter())
.for_each(|(b, &x)| *b = x);
let transformed = fft(&buffer, None)?;
lane.slice_axis_mut(Axis(0), (start..end).into())
.iter_mut()
.zip(transformed.iter())
.for_each(|(dst, &src)| *dst = src);
}
}
Ok(())
}
/// Computes the N-dimensional FFT of real-valued input, keeping only the
/// non-negative frequencies along the real-transform axis.
///
/// The last entry of `axes` (or the array's last axis when `axes` is `None` or
/// empty) is transformed first with `rfft`. The remaining axes are then
/// transformed with the complex FFT. Along the real-transform axis the output
/// has length `n / 2 + 1`, where `n` is that axis' input length; every other
/// axis keeps its input length.
///
/// # Arguments
///
/// * `x` - Input array; every element is converted to `f64`.
/// * `_shape` - Currently ignored: no padding or truncation is performed.
/// * `axes` - Axes to transform (validated against `x.ndim()`).
///
/// # Errors
///
/// Returns `FFTError::ValueError` if an axis is out of bounds, the input is
/// 0-dimensional, or an element cannot be converted to `f64`. Errors from the
/// 1-D transforms are propagated.
#[allow(dead_code)]
pub fn rfftn_optimized<T, D>(
    x: &ArrayView<T, D>,
    _shape: Option<Vec<usize>>,
    axes: Option<Vec<usize>>,
) -> FFTResult<Array<Complex64, D>>
where
    T: NumCast + Copy + Send + Sync,
    D: Dimension,
{
    let ndim = x.ndim();
    let mut axes_to_transform = match axes {
        Some(a) => {
            validate_axes(&a, ndim)?;
            a
        }
        None => (0..ndim).collect(),
    };
    let last_axis = match axes_to_transform.pop() {
        Some(axis) => axis,
        None if ndim > 0 => ndim - 1,
        // A 0-d array has no axis to transform (and `ndim - 1` would underflow).
        None => {
            return Err(FFTError::ValueError(
                "rfftn requires an input with at least one dimension".to_string(),
            ))
        }
    };

    let mut real_data: Array<f64, D> = Array::zeros(x.raw_dim());
    let mut conversion_failed = false;
    scirs2_core::ndarray::Zip::from(&mut real_data)
        .and(x)
        .for_each(|dst, &src| match <f64 as NumCast>::from(src) {
            Some(value) => *dst = value,
            None => conversion_failed = true,
        });
    if conversion_failed {
        return Err(FFTError::ValueError(
            "Failed to convert input to float".to_string(),
        ));
    }

    // The real-input transform keeps only the non-negative frequencies.
    let input_len = x.shape()[last_axis];
    let mut out_dim = x.raw_dim();
    out_dim[last_axis] = input_len / 2 + 1;
    let mut result: Array<Complex64, D> = Array::zeros(out_dim);

    // The two arrays differ only along `last_axis`, so both lane iterators walk
    // the remaining axes in the same order and the i-th lanes correspond.
    for (src_lane, mut dst_lane) in real_data
        .lanes(Axis(last_axis))
        .into_iter()
        .zip(result.lanes_mut(Axis(last_axis)))
    {
        let real_vec: Vec<f64> = src_lane.to_vec();
        let spectrum = rfft(&real_vec, None)?;
        // `zip` stops at the shorter side, so at most `n / 2 + 1` bins are written.
        for (dst, value) in dst_lane.iter_mut().zip(spectrum) {
            *dst = value;
        }
    }

    for &axis in &axes_to_transform {
        apply_fft_along_axis(&mut result, axis)?;
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_axis_optimization() {
        let axes = vec![0, 1, 2];
        let shape = vec![10, 100, 1000];
        let optimized = optimize_axis_order(&axes, &shape);
        assert_eq!(optimized[0], 2);
        assert_eq!(optimized[1], 1);
        assert_eq!(optimized[2], 0);
    }

    #[test]
    fn test_axis_optimization_edge_cases() {
        // Equal strides (from a trailing size-1 dimension) keep their input order.
        assert_eq!(optimize_axis_order(&[0, 1], &[3, 1]), vec![0, 1]);
        // Only the requested axes are returned.
        assert_eq!(optimize_axis_order(&[0, 2], &[4, 5, 6]), vec![2, 0]);
        assert!(optimize_axis_order(&[], &[4, 5]).is_empty());
    }

    #[test]
    fn test_parallelize_decision() {
        assert!(should_parallelize(10001, 100));
        assert!(!should_parallelize(10001, 50));
        assert!(!should_parallelize(100, 10));
        // Both thresholds are strict.
        assert!(!should_parallelize(10000, 100));
        assert!(!should_parallelize(10001, 64));
        assert!(should_parallelize(10001, 65));
    }

    #[test]
    fn test_validate_axes() {
        assert!(validate_axes(&[0, 1, 2], 3).is_ok());
        assert!(validate_axes(&[0, 1, 3], 3).is_err());
        assert!(validate_axes(&[], 0).is_ok());
        assert!(validate_axes(&[1], 2).is_ok());
        assert!(validate_axes(&[2], 2).is_err());
    }

    #[test]
    fn test_fftn_optimized_impulse_is_flat() {
        // The DFT of a unit impulse at the origin is 1 at every frequency.
        let mut input = Array2::<f64>::zeros((2, 3));
        input[[0, 0]] = 1.0;
        let spectrum = fftn_optimized(&input.view(), None, None).expect("fftn of impulse");
        assert_eq!(spectrum.shape(), &[2, 3]);
        for value in spectrum.iter() {
            assert!((value.re - 1.0).abs() < 1e-10);
            assert!(value.im.abs() < 1e-10);
        }
    }

    #[test]
    fn test_fftn_optimized_rejects_bad_axis() {
        let input = Array2::<f64>::zeros((2, 2));
        assert!(fftn_optimized(&input.view(), None, Some(vec![2])).is_err());
    }
}