use crate::error::FFTResult;
use crate::fft::algorithms::{parse_norm_mode, NormMode};
#[cfg(feature = "oxifft")]
use crate::oxifft_plan_cache;
#[cfg(feature = "oxifft")]
use oxifft::{Complex as OxiComplex, Direction};
#[cfg(feature = "rustfft-backend")]
use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
use scirs2_core::ndarray::{Array2, Axis};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use scirs2_core::parallel_ops::*;
/// Computes the 2-D forward FFT of `input`, optionally processing rows and
/// columns in parallel.
///
/// Each element is converted to [`Complex64`] (via
/// `crate::fft::utility::try_as_complex`, falling back to a real value cast
/// to `f64`). The data is zero-padded or truncated to `shape`, keeping the
/// top-left corner. A 1-D forward FFT is then applied to every row and
/// afterwards to every column.
///
/// # Arguments
///
/// * `input` - Input samples.
/// * `shape` - Output shape `(rows, cols)`; defaults to the input shape.
/// * `axes` - Must be `(0, 1)` or `(1, 0)`; defaults to `(0, 1)`. Rows are
///   always transformed first; the order does not change the 2-D result.
/// * `norm` - Normalization mode, parsed by `parse_norm_mode(norm, false)`.
///   With `N = rows * cols`:
///   - `NormMode::Ortho` scales by `1/sqrt(N)`.
///   - `NormMode::Forward` scales by `1/N`.
///   - `NormMode::Backward` and `NormMode::None` leave the forward result
///     unscaled.
/// * `workers` - Lanes are processed on the Rayon pool when this is greater
///   than 1, and sequentially otherwise. It does not cap the pool size.
///   Defaults to `min(num_threads(), 8)`.
///
/// # Errors
///
/// * `FFTError::ValueError` for invalid `axes` or an element that cannot be
///   converted to `f64`.
/// * Any error reported by the oxifft backend is propagated, including from
///   parallel lanes.
/// * `FFTError::ComputationError` when neither the `oxifft` nor the
///   `rustfft-backend` feature is enabled.
#[cfg(feature = "parallel")]
#[allow(dead_code)]
pub fn fft2_parallel<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
    workers: Option<usize>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + std::fmt::Debug + 'static,
{
    // Converts every element of `input` to `Complex64` and writes it into a
    // zero-filled array of shape `outputshape`. Elements outside the output
    // are still converted, so an unconvertible value is always reported, and
    // then dropped. Positions not covered by the input stay zero.
    fn to_padded_complex<T>(
        input: &Array2<T>,
        outputshape: (usize, usize),
    ) -> FFTResult<Array2<Complex64>>
    where
        T: NumCast + Copy + std::fmt::Debug + 'static,
    {
        let mut out = Array2::<Complex64>::zeros(outputshape);
        for ((i, j), &val) in input.indexed_iter() {
            let c = match crate::fft::utility::try_as_complex(val) {
                Some(c) => c,
                None => {
                    let real: f64 = NumCast::from(val).ok_or_else(|| {
                        crate::FFTError::ValueError(format!("Could not convert {val:?} to f64"))
                    })?;
                    Complex64::new(real, 0.0)
                }
            };
            // `get_mut` is `None` outside the output shape, i.e. truncation.
            if let Some(slot) = out.get_mut([i, j]) {
                *slot = c;
            }
        }
        Ok(out)
    }

    // Applies `transform` to every 1-D lane produced by iterating over `axis`
    // (Axis(0) yields rows, Axis(1) yields columns). Lanes run on the Rayon
    // pool when `parallel` is set. Processing stops at an error, which is
    // returned. With several parallel failures, which one is returned is
    // unspecified.
    #[cfg(any(feature = "oxifft", feature = "rustfft-backend"))]
    fn transform_lanes<F>(
        data: &mut Array2<Complex64>,
        axis: Axis,
        parallel: bool,
        transform: F,
    ) -> FFTResult<()>
    where
        F: for<'a> Fn(scirs2_core::ndarray::ArrayViewMut1<'a, Complex64>) -> FFTResult<()> + Sync,
    {
        if parallel {
            data.axis_iter_mut(axis).into_par_iter().try_for_each(&transform)
        } else {
            data.axis_iter_mut(axis).try_for_each(&transform)
        }
    }

    // Forward 1-D FFT of one lane through the cached oxifft plans.
    #[cfg(feature = "oxifft")]
    fn oxifft_forward_lane(
        mut lane: scirs2_core::ndarray::ArrayViewMut1<'_, Complex64>,
    ) -> FFTResult<()> {
        let input_oxi: Vec<OxiComplex<f64>> =
            lane.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); input_oxi.len()];
        oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, Direction::Forward)?;
        for (dst, src) in lane.iter_mut().zip(&output) {
            *dst = Complex64::new(src.re, src.im);
        }
        Ok(())
    }

    // Row pass followed by column pass (oxifft backend).
    #[cfg(feature = "oxifft")]
    fn forward_2d(data: &mut Array2<Complex64>, parallel: bool) -> FFTResult<()> {
        transform_lanes(data, Axis(0), parallel, oxifft_forward_lane)?;
        transform_lanes(data, Axis(1), parallel, oxifft_forward_lane)
    }

    // Runs a planned rustfft transform over one lane in place.
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    fn rustfft_lane(
        mut lane: scirs2_core::ndarray::ArrayViewMut1<'_, Complex64>,
        fft: &dyn rustfft::Fft<f64>,
    ) -> FFTResult<()> {
        let mut buffer: Vec<RustComplex<f64>> =
            lane.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
        fft.process(&mut buffer);
        for (dst, src) in lane.iter_mut().zip(&buffer) {
            *dst = Complex64::new(src.re, src.im);
        }
        Ok(())
    }

    // Row pass followed by column pass (rustfft backend).
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    fn forward_2d(data: &mut Array2<Complex64>, parallel: bool) -> FFTResult<()> {
        let mut planner = FftPlanner::<f64>::new();
        let row_fft = planner.plan_fft_forward(data.ncols());
        let col_fft = planner.plan_fft_forward(data.nrows());
        transform_lanes(data, Axis(0), parallel, |lane| rustfft_lane(lane, &*row_fft))?;
        transform_lanes(data, Axis(1), parallel, |lane| rustfft_lane(lane, &*col_fft))
    }

    // No backend compiled in: report it instead of returning untransformed data.
    #[cfg(not(any(feature = "oxifft", feature = "rustfft-backend")))]
    fn forward_2d(_data: &mut Array2<Complex64>, _parallel: bool) -> FFTResult<()> {
        Err(crate::FFTError::ComputationError(
            "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature."
                .to_string(),
        ))
    }

    let outputshape = shape.unwrap_or_else(|| input.dim());
    let axes = axes.unwrap_or((0, 1));
    // A 2-D transform must use each of axes 0 and 1 exactly once.
    if !matches!(axes, (0, 1) | (1, 0)) {
        return Err(crate::FFTError::ValueError(
            "Invalid axes for 2D FFT".to_string(),
        ));
    }
    let norm_mode = parse_norm_mode(norm, false);
    let parallel = workers.unwrap_or_else(|| num_threads().min(8)) > 1;

    let mut data = to_padded_complex(input, outputshape)?;
    forward_2d(&mut data, parallel)?;

    let n = (outputshape.0 * outputshape.1) as f64;
    let scale = match norm_mode {
        // Under "backward" normalization the 1/N factor belongs to the inverse
        // transform, so the forward result must stay unscaled; scaling here as
        // well would make an fft/ifft round trip return x / N.
        NormMode::None | NormMode::Backward => None,
        NormMode::Ortho => Some(1.0 / n.sqrt()),
        NormMode::Forward => Some(1.0 / n),
    };
    if let Some(scale) = scale {
        data.mapv_inplace(|x| x * scale);
    }
    Ok(data)
}
/// Computes the 2-D forward FFT of `input`.
///
/// This is the fallback used when the `parallel` feature is disabled. It
/// delegates to `crate::fft::algorithms::fft2` with the same `shape`, `axes`
/// and `norm`. `workers` is accepted only for signature compatibility with the
/// parallel version and is ignored.
///
/// # Errors
///
/// Propagates any error returned by `crate::fft::algorithms::fft2`.
#[cfg(not(feature = "parallel"))]
#[allow(dead_code)]
pub fn fft2_parallel<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
    _workers: Option<usize>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + std::fmt::Debug + 'static,
{
    // Forward `axes` and `norm` rather than dropping them, so the caller's
    // normalization request and axis validation still apply when the
    // `parallel` feature is off.
    crate::fft::algorithms::fft2(input, shape, axes, norm)
}
/// Computes the 2-D inverse FFT of `input`, optionally processing rows and
/// columns in parallel.
///
/// Each element is converted to [`Complex64`] (via
/// `crate::fft::utility::try_as_complex`, falling back to a real value cast
/// to `f64`). The data is zero-padded or truncated to `shape`, keeping the
/// top-left corner. A 1-D inverse FFT is then applied to every row and
/// afterwards to every column.
///
/// # Arguments
///
/// * `input` - Input samples.
/// * `shape` - Output shape `(rows, cols)`; defaults to the input shape.
/// * `axes` - Must be `(0, 1)` or `(1, 0)`; defaults to `(0, 1)`. Rows are
///   always transformed first; the order does not change the 2-D result.
/// * `norm` - Normalization mode, parsed by `parse_norm_mode(norm, true)`.
///   With `N = rows * cols`:
///   - `NormMode::Backward` scales by `1/N`.
///   - `NormMode::Ortho` scales by `1/sqrt(N)`.
///   - `NormMode::Forward` and `NormMode::None` leave the result unscaled.
/// * `workers` - Lanes are processed on the Rayon pool when this is greater
///   than 1, and sequentially otherwise. It does not cap the pool size.
///   Defaults to `min(num_threads(), 8)`.
///
/// # Errors
///
/// * `FFTError::ValueError` for invalid `axes` or an element that cannot be
///   converted to `f64`.
/// * Any error reported by the oxifft backend is propagated, including from
///   parallel lanes.
/// * `FFTError::ComputationError` when neither the `oxifft` nor the
///   `rustfft-backend` feature is enabled.
#[cfg(feature = "parallel")]
#[allow(dead_code)]
pub fn ifft2_parallel<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
    workers: Option<usize>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + std::fmt::Debug + 'static,
{
    // Converts every element of `input` to `Complex64` and writes it into a
    // zero-filled array of shape `outputshape`. Elements outside the output
    // are still converted, so an unconvertible value is always reported, and
    // then dropped. Positions not covered by the input stay zero.
    fn to_padded_complex<T>(
        input: &Array2<T>,
        outputshape: (usize, usize),
    ) -> FFTResult<Array2<Complex64>>
    where
        T: NumCast + Copy + std::fmt::Debug + 'static,
    {
        let mut out = Array2::<Complex64>::zeros(outputshape);
        for ((i, j), &val) in input.indexed_iter() {
            let c = match crate::fft::utility::try_as_complex(val) {
                Some(c) => c,
                None => {
                    let real: f64 = NumCast::from(val).ok_or_else(|| {
                        crate::FFTError::ValueError(format!("Could not convert {val:?} to f64"))
                    })?;
                    Complex64::new(real, 0.0)
                }
            };
            // `get_mut` is `None` outside the output shape, i.e. truncation.
            if let Some(slot) = out.get_mut([i, j]) {
                *slot = c;
            }
        }
        Ok(out)
    }

    // Applies `transform` to every 1-D lane produced by iterating over `axis`
    // (Axis(0) yields rows, Axis(1) yields columns). Lanes run on the Rayon
    // pool when `parallel` is set. Processing stops at an error, which is
    // returned. With several parallel failures, which one is returned is
    // unspecified.
    #[cfg(any(feature = "oxifft", feature = "rustfft-backend"))]
    fn transform_lanes<F>(
        data: &mut Array2<Complex64>,
        axis: Axis,
        parallel: bool,
        transform: F,
    ) -> FFTResult<()>
    where
        F: for<'a> Fn(scirs2_core::ndarray::ArrayViewMut1<'a, Complex64>) -> FFTResult<()> + Sync,
    {
        if parallel {
            data.axis_iter_mut(axis).into_par_iter().try_for_each(&transform)
        } else {
            data.axis_iter_mut(axis).try_for_each(&transform)
        }
    }

    // Inverse (unnormalized) 1-D FFT of one lane through the cached oxifft plans.
    #[cfg(feature = "oxifft")]
    fn oxifft_inverse_lane(
        mut lane: scirs2_core::ndarray::ArrayViewMut1<'_, Complex64>,
    ) -> FFTResult<()> {
        let input_oxi: Vec<OxiComplex<f64>> =
            lane.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); input_oxi.len()];
        oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, Direction::Backward)?;
        for (dst, src) in lane.iter_mut().zip(&output) {
            *dst = Complex64::new(src.re, src.im);
        }
        Ok(())
    }

    // Row pass followed by column pass (oxifft backend).
    #[cfg(feature = "oxifft")]
    fn inverse_2d(data: &mut Array2<Complex64>, parallel: bool) -> FFTResult<()> {
        transform_lanes(data, Axis(0), parallel, oxifft_inverse_lane)?;
        transform_lanes(data, Axis(1), parallel, oxifft_inverse_lane)
    }

    // Runs a planned rustfft transform over one lane in place.
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    fn rustfft_lane(
        mut lane: scirs2_core::ndarray::ArrayViewMut1<'_, Complex64>,
        fft: &dyn rustfft::Fft<f64>,
    ) -> FFTResult<()> {
        let mut buffer: Vec<RustComplex<f64>> =
            lane.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
        fft.process(&mut buffer);
        for (dst, src) in lane.iter_mut().zip(&buffer) {
            *dst = Complex64::new(src.re, src.im);
        }
        Ok(())
    }

    // Row pass followed by column pass (rustfft backend).
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    fn inverse_2d(data: &mut Array2<Complex64>, parallel: bool) -> FFTResult<()> {
        let mut planner = FftPlanner::<f64>::new();
        let row_ifft = planner.plan_fft_inverse(data.ncols());
        let col_ifft = planner.plan_fft_inverse(data.nrows());
        transform_lanes(data, Axis(0), parallel, |lane| rustfft_lane(lane, &*row_ifft))?;
        transform_lanes(data, Axis(1), parallel, |lane| rustfft_lane(lane, &*col_ifft))
    }

    // No backend compiled in: report it instead of returning untransformed data.
    #[cfg(not(any(feature = "oxifft", feature = "rustfft-backend")))]
    fn inverse_2d(_data: &mut Array2<Complex64>, _parallel: bool) -> FFTResult<()> {
        Err(crate::FFTError::ComputationError(
            "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature."
                .to_string(),
        ))
    }

    let outputshape = shape.unwrap_or_else(|| input.dim());
    let axes = axes.unwrap_or((0, 1));
    // A 2-D transform must use each of axes 0 and 1 exactly once.
    if !matches!(axes, (0, 1) | (1, 0)) {
        return Err(crate::FFTError::ValueError(
            "Invalid axes for 2D IFFT".to_string(),
        ));
    }
    let norm_mode = parse_norm_mode(norm, true);
    let parallel = workers.unwrap_or_else(|| num_threads().min(8)) > 1;

    let mut data = to_padded_complex(input, outputshape)?;
    inverse_2d(&mut data, parallel)?;

    let n = (outputshape.0 * outputshape.1) as f64;
    let scale = match norm_mode {
        NormMode::Backward => Some(1.0 / n),
        NormMode::Ortho => Some(1.0 / n.sqrt()),
        // "forward" normalization already applied 1/N on the forward side.
        NormMode::Forward | NormMode::None => None,
    };
    if let Some(scale) = scale {
        data.mapv_inplace(|x| x * scale);
    }
    Ok(data)
}
/// Computes the 2-D inverse FFT of `input`.
///
/// This is the fallback used when the `parallel` feature is disabled. It
/// delegates to `crate::fft::algorithms::ifft2` with the same `shape`, `axes`
/// and `norm`. `workers` is accepted only for signature compatibility with the
/// parallel version and is ignored.
///
/// # Errors
///
/// Propagates any error returned by `crate::fft::algorithms::ifft2`.
#[cfg(not(feature = "parallel"))]
#[allow(dead_code)]
pub fn ifft2_parallel<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
    _workers: Option<usize>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + std::fmt::Debug + 'static,
{
    // Forward `axes` and `norm` rather than dropping them, so the caller's
    // normalization request and axis validation still apply when the
    // `parallel` feature is off.
    crate::fft::algorithms::ifft2(input, shape, axes, norm)
}