use crate::error::{FFTError, FFTResult};
#[cfg(feature = "oxifft")]
use crate::oxifft_plan_cache;
#[cfg(feature = "oxifft")]
use oxifft::{Complex as OxiComplex, Direction};
#[cfg(feature = "rustfft-backend")]
use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
use scirs2_core::ndarray::{Array2, ArrayD, Axis, IxDyn};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use scirs2_core::safe_ops::{safe_divide, safe_sqrt};
use std::fmt::Debug;
/// Normalization convention applied by the FFT routines in this module.
///
/// Values are usually produced from a `norm` keyword string through `From<&str>`;
/// `parse_norm_mode` supplies the default when no string is given.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormMode {
    /// No scaling is applied. Unrecognized `norm` strings also map to this variant.
    None,
    /// `"backward"`: inverse transforms are scaled by `1/n`; the default for inverse
    /// transforms.
    Backward,
    /// `"ortho"`: transforms are scaled by `1/sqrt(n)`.
    Ortho,
    /// `"forward"`: forward transforms are scaled by `1/n`, inverse transforms are left
    /// unscaled.
    Forward,
}
impl From<&str> for NormMode {
    /// Maps a `norm` keyword string to a [`NormMode`].
    ///
    /// Matching is exact and case-sensitive. Any string other than `"backward"`,
    /// `"ortho"` or `"forward"` yields `NormMode::None`.
    fn from(name: &str) -> Self {
        const KNOWN: [(&str, NormMode); 3] = [
            ("backward", NormMode::Backward),
            ("ortho", NormMode::Ortho),
            ("forward", NormMode::Forward),
        ];
        KNOWN
            .iter()
            .find(|(keyword, _)| *keyword == name)
            .map_or(NormMode::None, |&(_, mode)| mode)
    }
}
/// Resolves an optional `norm` keyword into a [`NormMode`].
///
/// An explicit string is converted with `NormMode::from`, so unrecognized strings become
/// `NormMode::None`. Without a string, inverse transforms default to `NormMode::Backward`
/// and forward transforms to `NormMode::None`.
#[allow(dead_code)]
pub fn parse_norm_mode(norm: Option<&str>, isinverse: bool) -> NormMode {
    if let Some(keyword) = norm {
        return NormMode::from(keyword);
    }
    if isinverse {
        NormMode::Backward
    } else {
        NormMode::None
    }
}
/// Scales `data` in place for a transform of length `n` according to `mode`.
///
/// `Backward` and `Forward` both multiply by `1/n` and `Ortho` by `1/sqrt(n)`. `None`
/// leaves `data` untouched. The caller decides which mode applies to its direction.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when computing the reciprocal fails, which per the
/// messages is a zero `n`. Returns `FFTError::ComputationError` when the square root for
/// `Ortho` fails.
#[allow(dead_code)]
fn apply_normalization(data: &mut [Complex64], n: usize, mode: NormMode) -> FFTResult<()> {
    let length = n as f64;
    let factor = match mode {
        NormMode::None => return Ok(()),
        NormMode::Ortho => {
            let root = safe_sqrt(length).map_err(|_| {
                FFTError::ComputationError(
                    "Invalid square root in orthogonal normalization".to_string(),
                )
            })?;
            safe_divide(1.0, root).map_err(|_| {
                FFTError::ValueError("Division by zero in orthogonal normalization".to_string())
            })?
        }
        NormMode::Backward | NormMode::Forward => {
            // Both modes share the 1/n factor; only the wording of the error differs.
            let label = if mode == NormMode::Backward {
                "backward"
            } else {
                "forward"
            };
            safe_divide(1.0, length).map_err(|_| {
                FFTError::ValueError(format!(
                    "Division by zero in {label} normalization: FFT size is zero"
                ))
            })?
        }
    };
    for value in data.iter_mut() {
        *value *= factor;
    }
    Ok(())
}
/// Converts one scalar of generic type `T` into a `Complex64`.
///
/// First tries a numeric cast to `f64` (imaginary part set to zero). If that fails, it
/// checks whether the value is already a `Complex64` or a `Complex<f32>` at runtime.
/// The cast is deliberately attempted before the downcasts; keep that order.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if none of the conversions apply.
#[allow(dead_code)]
fn convert_to_complex<T>(val: T) -> FFTResult<Complex64>
where
    T: NumCast + Copy + Debug + 'static,
{
    use std::any::Any;

    let as_real: Option<f64> = NumCast::from(val);
    if let Some(re) = as_real {
        return Ok(Complex64::new(re, 0.0));
    }

    let erased: &dyn Any = &val;
    if let Some(&z) = erased.downcast_ref::<Complex64>() {
        Ok(z)
    } else if let Some(z) = erased.downcast_ref::<scirs2_core::numeric::Complex<f32>>() {
        Ok(Complex64::new(f64::from(z.re), f64::from(z.im)))
    } else {
        Err(FFTError::ValueError(format!(
            "Could not convert {val:?} to numeric type"
        )))
    }
}
/// Converts every element of `input` to `Complex64` using `convert_to_complex`.
///
/// # Errors
///
/// Stops at the first element that cannot be converted and returns its error.
#[allow(dead_code)]
fn to_complex<T>(input: &[T]) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let mut converted = Vec::with_capacity(input.len());
    for &sample in input {
        converted.push(convert_to_complex(sample)?);
    }
    Ok(converted)
}
/// Computes the one-dimensional forward discrete Fourier transform of `input`.
///
/// * `input` - samples to transform; each element is converted with `convert_to_complex`.
/// * `n` - transform length. When `None`, the input length rounded up to the next power
///   of two is used. The input is zero-padded or truncated to this length.
///
/// The output is not normalized.
///
/// # Errors
///
/// Returns `FFTError::ValueError` in three cases: `input` is empty, the transform length
/// is zero, or an element cannot be converted. Returns `FFTError::ComputationError` when
/// no FFT backend feature is enabled. Backend errors are propagated.
#[allow(dead_code)]
pub fn fft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    if input.is_empty() {
        return Err(FFTError::ValueError("Input cannot be empty".to_string()));
    }
    let input_len = input.len();
    let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
    // A zero-length transform is meaningless. Reject it explicitly instead of handing an
    // empty buffer to whichever backend happens to be compiled in.
    if fft_size == 0 {
        return Err(FFTError::ValueError(
            "FFT size must be positive, got n = 0".to_string(),
        ));
    }
    let mut data = to_complex(input)?;
    // `resize` zero-pads when growing and truncates when shrinking.
    data.resize(fft_size, Complex64::new(0.0, 0.0));
    #[cfg(feature = "oxifft")]
    {
        let input_oxi: Vec<OxiComplex<f64>> =
            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); fft_size];
        oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, Direction::Forward)?;
        let result: Vec<Complex64> = output
            .into_iter()
            .map(|c| Complex64::new(c.re, c.im))
            .collect();
        Ok(result)
    }
    #[cfg(not(feature = "oxifft"))]
    {
        #[cfg(feature = "rustfft-backend")]
        {
            let mut planner = FftPlanner::new();
            let fft_plan = planner.plan_fft_forward(fft_size);
            let mut buffer: Vec<RustComplex<f64>> =
                data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
            fft_plan.process(&mut buffer);
            let result: Vec<Complex64> = buffer
                .into_iter()
                .map(|c| Complex64::new(c.re, c.im))
                .collect();
            Ok(result)
        }
        #[cfg(not(feature = "rustfft-backend"))]
        {
            Err(FFTError::ComputationError(
                "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature."
                    .to_string(),
            ))
        }
    }
}
/// Computes the one-dimensional inverse discrete Fourier transform of `input`.
///
/// * `input` - spectrum to invert; each element is converted with `convert_to_complex`.
/// * `n` - transform length. When `None`, the input length rounded up to the next power
///   of two is used, and the result is truncated back to the input length afterwards.
///   The input is zero-padded or truncated to the transform length.
///
/// The result is scaled by `1/transform_length` (`NormMode::Backward`).
///
/// # Errors
///
/// Returns `FFTError::ValueError` in three cases: `input` is empty, an element cannot be
/// converted, or normalization fails (e.g. a zero transform length). Returns
/// `FFTError::ComputationError` when no FFT backend feature is enabled. Backend errors
/// are propagated.
#[allow(dead_code)]
#[allow(unreachable_code)]
pub fn ifft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    if input.is_empty() {
        return Err(FFTError::ValueError("Input cannot be empty".to_string()));
    }
    let original_len = input.len();
    let transform_len = n.unwrap_or_else(|| original_len.next_power_of_two());

    // `resize` zero-pads a short buffer and truncates a long one.
    let mut samples = to_complex(input)?;
    samples.resize(transform_len, Complex64::new(0.0, 0.0));

    #[cfg(feature = "oxifft")]
    let mut signal: Vec<Complex64> = {
        let src: Vec<OxiComplex<f64>> = samples
            .iter()
            .map(|z| OxiComplex::new(z.re, z.im))
            .collect();
        let mut dst: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); transform_len];
        oxifft_plan_cache::execute_c2c(&src, &mut dst, Direction::Backward)?;
        dst.into_iter().map(|z| Complex64::new(z.re, z.im)).collect()
    };

    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    let mut signal: Vec<Complex64> = {
        let mut planner = FftPlanner::new();
        let plan = planner.plan_fft_inverse(transform_len);
        let mut scratch: Vec<RustComplex<f64>> = samples
            .iter()
            .map(|z| RustComplex::new(z.re, z.im))
            .collect();
        plan.process(&mut scratch);
        scratch
            .into_iter()
            .map(|z| Complex64::new(z.re, z.im))
            .collect()
    };

    #[cfg(any(feature = "oxifft", feature = "rustfft-backend"))]
    {
        apply_normalization(&mut signal, transform_len, NormMode::Backward)?;
        // Implicit padding is dropped again so the output matches the input length.
        if n.is_none() && transform_len > original_len {
            signal.truncate(original_len);
        }
        Ok(signal)
    }

    #[cfg(all(not(feature = "oxifft"), not(feature = "rustfft-backend")))]
    {
        Err(FFTError::ComputationError(
            "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature."
                .to_string(),
        ))
    }
}
/// Computes the two-dimensional forward FFT of `input`.
///
/// * `input` - 2-D array whose elements are converted with `convert_to_complex`.
/// * `shape` - output shape `(rows, cols)`. The input is zero-padded or truncated (keeping
///   the top-left corner) to this shape. Defaults to the input shape.
/// * `axes` - must be `(0, 1)` or `(1, 0)`. Both axes are always transformed (rows first,
///   then columns), so this argument is only validated.
/// * `norm` - `"backward"`, `None` or an unrecognized string: no scaling. `"ortho"`: scale
///   by `1/sqrt(rows * cols)`. `"forward"`: scale by `1/(rows * cols)`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` for invalid `axes` or unconvertible elements. Returns
/// `FFTError::ComputationError` when no FFT backend feature is enabled. Backend errors
/// are propagated.
#[allow(dead_code)]
#[allow(unreachable_code)]
pub fn fft2<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    #[cfg(all(not(feature = "oxifft"), not(feature = "rustfft-backend")))]
    return Err(FFTError::ComputationError(
        "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature."
            .to_string(),
    ));
    let inputshape = input.shape();
    let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
    let axes = axes.unwrap_or((0, 1));
    if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
        return Err(FFTError::ValueError("Invalid axes for 2D FFT".to_string()));
    }
    let norm_mode = parse_norm_mode(norm, false);
    let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
    for i in 0..inputshape[0] {
        for j in 0..inputshape[1] {
            let val = input[[i, j]];
            complex_input[[i, j]] = convert_to_complex(val)?;
        }
    }
    let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
        let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
        let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
        let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
        for i in 0..copy_rows {
            for j in 0..copy_cols {
                padded[[i, j]] = complex_input[[i, j]];
            }
        }
        padded
    } else {
        complex_input
    };
    #[cfg(feature = "oxifft")]
    {
        for mut row in padded_input.rows_mut() {
            let input_oxi: Vec<OxiComplex<f64>> =
                row.iter().map(|&c| OxiComplex::new(c.re, c.im)).collect();
            let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); outputshape.1];
            oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, Direction::Forward)?;
            for (i, val) in output.iter().enumerate() {
                row[i] = Complex64::new(val.re, val.im);
            }
        }
        for mut col in padded_input.columns_mut() {
            let input_oxi: Vec<OxiComplex<f64>> =
                col.iter().map(|&c| OxiComplex::new(c.re, c.im)).collect();
            let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); outputshape.0];
            oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, Direction::Forward)?;
            for (i, val) in output.iter().enumerate() {
                col[i] = Complex64::new(val.re, val.im);
            }
        }
    }
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    {
        let mut planner = FftPlanner::new();
        let row_fft = planner.plan_fft_forward(outputshape.1);
        for mut row in padded_input.rows_mut() {
            let mut buffer: Vec<RustComplex<f64>> =
                row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
            row_fft.process(&mut buffer);
            for (i, val) in buffer.iter().enumerate() {
                row[i] = Complex64::new(val.re, val.im);
            }
        }
        let col_fft = planner.plan_fft_forward(outputshape.0);
        for mut col in padded_input.columns_mut() {
            let mut buffer: Vec<RustComplex<f64>> =
                col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
            col_fft.process(&mut buffer);
            for (i, val) in buffer.iter().enumerate() {
                col[i] = Complex64::new(val.re, val.im);
            }
        }
    }
    // Forward-direction scaling. "backward" means the forward transform is unscaled and
    // the inverse carries the 1/N factor (as ifft2 applies it). Scaling here as well
    // would apply 1/N twice on a round trip.
    let total_elements = outputshape.0 * outputshape.1;
    let scale = match norm_mode {
        NormMode::None | NormMode::Backward => 1.0,
        NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
        NormMode::Forward => 1.0 / (total_elements as f64),
    };
    if scale != 1.0 {
        padded_input.mapv_inplace(|x| x * scale);
    }
    // `padded_input` already has shape `outputshape`; no extra copy is needed.
    Ok(padded_input)
}
/// Computes the two-dimensional inverse FFT of `input`.
///
/// * `input` - 2-D array whose elements are converted with `convert_to_complex`.
/// * `shape` - output shape `(rows, cols)`. The input is zero-padded or truncated (keeping
///   the top-left corner) to this shape. Defaults to the input shape.
/// * `axes` - must be `(0, 1)` or `(1, 0)`. Both axes are always transformed (rows first,
///   then columns), so this argument is only validated.
/// * `norm` - `None` or `"backward"`: scale by `1/(rows * cols)`. `"ortho"`: scale by
///   `1/sqrt(rows * cols)`. `"forward"` or an unrecognized string: no scaling.
///
/// # Errors
///
/// Returns `FFTError::ValueError` for invalid `axes` or unconvertible elements. Returns
/// `FFTError::ComputationError` when no FFT backend feature is enabled. Backend errors
/// are propagated.
#[allow(dead_code)]
#[allow(unreachable_code)]
pub fn ifft2<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    #[cfg(all(not(feature = "oxifft"), not(feature = "rustfft-backend")))]
    return Err(FFTError::ComputationError(
        "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature."
            .to_string(),
    ));

    let (in_rows, in_cols) = input.dim();
    let (out_rows, out_cols) = shape.unwrap_or((in_rows, in_cols));
    let (first_axis, second_axis) = axes.unwrap_or((0, 1));
    let axis_ok = |a: i32| (0..=1).contains(&a);
    if !axis_ok(first_axis) || !axis_ok(second_axis) || first_axis == second_axis {
        return Err(FFTError::ValueError("Invalid axes for 2D IFFT".to_string()));
    }
    let norm_mode = parse_norm_mode(norm, true);

    // Convert every input element in logical (row-major) order before resizing.
    let mut converted = Array2::<Complex64>::zeros((in_rows, in_cols));
    for (dst, &src) in converted.iter_mut().zip(input.iter()) {
        *dst = convert_to_complex(src)?;
    }

    let mut grid = if (in_rows, in_cols) == (out_rows, out_cols) {
        converted
    } else {
        let mut resized = Array2::<Complex64>::zeros((out_rows, out_cols));
        for ((r, c), &z) in converted.indexed_iter() {
            if r < out_rows && c < out_cols {
                resized[[r, c]] = z;
            }
        }
        resized
    };

    #[cfg(feature = "oxifft")]
    {
        // Rows (lanes along axis 1) first, then columns (lanes along axis 0).
        for (axis, len) in [(Axis(1), out_cols), (Axis(0), out_rows)] {
            for mut lane in grid.lanes_mut(axis) {
                let src: Vec<OxiComplex<f64>> =
                    lane.iter().map(|&z| OxiComplex::new(z.re, z.im)).collect();
                let mut dst: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); len];
                oxifft_plan_cache::execute_c2c(&src, &mut dst, Direction::Backward)?;
                for (slot, v) in lane.iter_mut().zip(&dst) {
                    *slot = Complex64::new(v.re, v.im);
                }
            }
        }
    }
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    {
        let mut planner = FftPlanner::new();
        // Rows (lanes along axis 1) first, then columns (lanes along axis 0).
        for (axis, len) in [(Axis(1), out_cols), (Axis(0), out_rows)] {
            let plan = planner.plan_fft_inverse(len);
            for mut lane in grid.lanes_mut(axis) {
                let mut scratch: Vec<RustComplex<f64>> =
                    lane.iter().map(|&z| RustComplex::new(z.re, z.im)).collect();
                plan.process(&mut scratch);
                for (slot, v) in lane.iter_mut().zip(&scratch) {
                    *slot = Complex64::new(v.re, v.im);
                }
            }
        }
    }

    let count = (out_rows * out_cols) as f64;
    let factor = match norm_mode {
        NormMode::Backward => 1.0 / count,
        NormMode::Ortho => 1.0 / count.sqrt(),
        NormMode::Forward | NormMode::None => 1.0,
    };
    if factor != 1.0 {
        grid.mapv_inplace(|z| z * factor);
    }
    Ok(grid)
}
/// Computes the N-dimensional forward FFT of `input` over the given axes.
///
/// * `input` - N-D array whose elements are converted with `convert_to_complex`.
/// * `shape` - full output shape, one entry per input dimension (not per transformed
///   axis). The input is zero-padded or truncated to it. Defaults to the input shape.
/// * `axes` - axes to transform, in order. Defaults to all axes.
/// * `norm` - `"backward"`, `None` or an unrecognized string: no scaling. `"ortho"`: scale
///   by `1/sqrt(M)`. `"forward"`: scale by `1/M`. `M` is the product of the output lengths
///   along `axes`.
/// * `_overwrite_x`, `_workers` - accepted for API compatibility and currently ignored.
///
/// # Errors
///
/// Returns `FFTError::ValueError` for a rank mismatch between `shape` and `input`, an
/// out-of-range axis, or unconvertible elements. Returns `FFTError::ComputationError`
/// when no FFT backend feature is enabled. Backend errors are propagated.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
#[allow(unreachable_code)]
pub fn fftn<T>(
    input: &ArrayD<T>,
    shape: Option<Vec<usize>>,
    axes: Option<Vec<usize>>,
    norm: Option<&str>,
    _overwrite_x: Option<bool>,
    _workers: Option<usize>,
) -> FFTResult<ArrayD<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let inputshape = input.shape().to_vec();
    let input_ndim = inputshape.len();
    let outputshape = shape.unwrap_or_else(|| inputshape.clone());
    if outputshape.len() != input_ndim {
        return Err(FFTError::ValueError(
            "Output shape must have the same number of dimensions as input".to_string(),
        ));
    }
    let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
    for &axis in &axes {
        if axis >= input_ndim {
            return Err(FFTError::ValueError(format!(
                "Axis {axis} out of bounds for array of dimension {input_ndim}"
            )));
        }
    }
    let norm_mode = parse_norm_mode(norm, false);
    // Both iterators visit elements in logical (row-major) order over the same shape.
    // Pairing them avoids rebuilding a multi-index Vec for every element.
    let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
    for (dst, &src) in complex_input.iter_mut().zip(input.iter()) {
        *dst = convert_to_complex(src)?;
    }
    let mut result = if inputshape != outputshape {
        let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
        for (idx, &val) in complex_input.indexed_iter() {
            if (0..input_ndim).all(|d| idx[d] < outputshape[d]) {
                padded[idx] = val;
            }
        }
        padded
    } else {
        complex_input
    };
    #[cfg(feature = "oxifft")]
    {
        for &axis in &axes {
            let axis_len = outputshape[axis];
            let axis_obj = Axis(axis);
            for mut lane in result.lanes_mut(axis_obj) {
                let input_oxi: Vec<OxiComplex<f64>> =
                    lane.iter().map(|&c| OxiComplex::new(c.re, c.im)).collect();
                let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); axis_len];
                oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, Direction::Forward)?;
                for (i, val) in output.iter().enumerate() {
                    lane[i] = Complex64::new(val.re, val.im);
                }
            }
        }
    }
    #[cfg(not(feature = "oxifft"))]
    {
        #[cfg(feature = "rustfft-backend")]
        {
            let mut planner = FftPlanner::new();
            for &axis in &axes {
                let axis_len = outputshape[axis];
                let fft = planner.plan_fft_forward(axis_len);
                let axis_obj = Axis(axis);
                for mut lane in result.lanes_mut(axis_obj) {
                    let mut buffer: Vec<RustComplex<f64>> =
                        lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
                    fft.process(&mut buffer);
                    for (i, val) in buffer.iter().enumerate() {
                        lane[i] = Complex64::new(val.re, val.im);
                    }
                }
            }
        }
        #[cfg(not(feature = "rustfft-backend"))]
        {
            return Err(FFTError::ComputationError(
                "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature."
                    .to_string(),
            ));
        }
    }
    // Normalize over the transformed axes only, matching ifftn. "backward" leaves the
    // forward transform unscaled; the inverse applies the 1/M factor.
    let transform_len: usize = axes.iter().map(|&a| outputshape[a]).product();
    let scale = match norm_mode {
        NormMode::None | NormMode::Backward => 1.0,
        NormMode::Ortho => 1.0 / (transform_len as f64).sqrt(),
        NormMode::Forward => 1.0 / (transform_len as f64),
    };
    if scale != 1.0 {
        result.mapv_inplace(|x| x * scale);
    }
    Ok(result)
}
/// Computes the N-dimensional inverse FFT of `input` over the given axes.
///
/// * `input` - N-D array whose elements are converted with `convert_to_complex`.
/// * `shape` - full output shape, one entry per input dimension. The input is zero-padded
///   or truncated to it. Defaults to the input shape.
/// * `axes` - axes to transform, in order. Defaults to all axes.
/// * `norm` - `None` or `"backward"`: scale by `1/M`. `"ortho"`: scale by `1/sqrt(M)`.
///   `"forward"` or an unrecognized string: no scaling. `M` is the product of the output
///   lengths along `axes`.
/// * `_overwrite_x`, `_workers` - accepted for API compatibility and currently ignored.
///
/// # Errors
///
/// Returns `FFTError::ValueError` for a rank mismatch between `shape` and `input`, an
/// out-of-range axis, or unconvertible elements. Returns `FFTError::ComputationError`
/// when no FFT backend feature is enabled. Backend errors are propagated.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
#[allow(unreachable_code)]
pub fn ifftn<T>(
    input: &ArrayD<T>,
    shape: Option<Vec<usize>>,
    axes: Option<Vec<usize>>,
    norm: Option<&str>,
    _overwrite_x: Option<bool>,
    _workers: Option<usize>,
) -> FFTResult<ArrayD<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let in_shape = input.shape().to_vec();
    let input_ndim = in_shape.len();
    let out_shape = shape.unwrap_or_else(|| in_shape.clone());
    if out_shape.len() != input_ndim {
        return Err(FFTError::ValueError(
            "Output shape must have the same number of dimensions as input".to_string(),
        ));
    }
    let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
    if let Some(&axis) = axes.iter().find(|&&a| a >= input_ndim) {
        return Err(FFTError::ValueError(format!(
            "Axis {axis} out of bounds for array of dimension {input_ndim}"
        )));
    }
    let norm_mode = parse_norm_mode(norm, true);

    // Both iterators walk the same shape in logical (row-major) order.
    let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&in_shape));
    for (dst, &src) in complex_input.iter_mut().zip(input.iter()) {
        *dst = convert_to_complex(src)?;
    }

    let mut result = if in_shape == out_shape {
        complex_input
    } else {
        let mut resized = ArrayD::<Complex64>::zeros(IxDyn(&out_shape));
        for (idx, &z) in complex_input.indexed_iter() {
            if (0..input_ndim).all(|d| idx[d] < out_shape[d]) {
                resized[idx] = z;
            }
        }
        resized
    };

    #[cfg(feature = "oxifft")]
    {
        for &axis in &axes {
            let len = out_shape[axis];
            for mut lane in result.lanes_mut(Axis(axis)) {
                let src: Vec<OxiComplex<f64>> =
                    lane.iter().map(|&z| OxiComplex::new(z.re, z.im)).collect();
                let mut dst: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); len];
                oxifft_plan_cache::execute_c2c(&src, &mut dst, Direction::Backward)?;
                for (slot, v) in lane.iter_mut().zip(&dst) {
                    *slot = Complex64::new(v.re, v.im);
                }
            }
        }
    }
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    {
        let mut planner = FftPlanner::new();
        for &axis in &axes {
            let plan = planner.plan_fft_inverse(out_shape[axis]);
            for mut lane in result.lanes_mut(Axis(axis)) {
                let mut scratch: Vec<RustComplex<f64>> =
                    lane.iter().map(|&z| RustComplex::new(z.re, z.im)).collect();
                plan.process(&mut scratch);
                for (slot, v) in lane.iter_mut().zip(&scratch) {
                    *slot = Complex64::new(v.re, v.im);
                }
            }
        }
    }
    #[cfg(all(not(feature = "oxifft"), not(feature = "rustfft-backend")))]
    {
        return Err(FFTError::ComputationError(
            "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature."
                .to_string(),
        ));
    }

    // Only the transformed axes contribute to the normalization length.
    let transformed_len: usize = axes.iter().map(|&a| out_shape[a]).product();
    let factor = match norm_mode {
        NormMode::Backward => 1.0 / (transformed_len as f64),
        NormMode::Ortho => 1.0 / (transformed_len as f64).sqrt(),
        NormMode::Forward | NormMode::None => 1.0,
    };
    if factor != 1.0 {
        result.mapv_inplace(|z| z * factor);
    }
    Ok(result)
}