use crate::error::{FFTError, FFTResult};
use crate::fft::NormMode;
#[cfg(feature = "oxifft")]
use crate::oxifft_plan_cache;
#[cfg(feature = "oxifft")]
use oxifft::{Complex as OxiComplex, Direction};
use scirs2_core::ndarray::{Array, Array1, Array2, ArrayD, Dimension, IxDyn, ShapeBuilder};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
#[cfg(feature = "rustfft-backend")]
use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
use std::fmt::Debug;
use std::sync::Arc;
// Per-thread scratch buffer reused by the 1-D transforms when the `oxifft` backend is
// enabled. `get_or_create_buffer` takes it out of the slot and `return_buffer_to_cache`
// puts it back, so each thread keeps at most one cached buffer.
#[cfg(feature = "oxifft")]
thread_local! {
static BUFFER_CACHE: std::cell::RefCell<Option<Vec<OxiComplex<f64>>>> = std::cell::RefCell::new(None);
}
// Per-thread scratch buffer reused by the 1-D transforms when `oxifft` is disabled and
// the `rustfft-backend` feature is enabled. `get_or_create_buffer` takes it out of the
// slot and `return_buffer_to_cache` puts it back, so each thread keeps at most one
// cached buffer.
#[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
thread_local! {
static BUFFER_CACHE: std::cell::RefCell<Option<Vec<RustComplex<f64>>>> = std::cell::RefCell::new(None);
}
#[cfg(feature = "oxifft")]
#[allow(dead_code)]
/// Returns a scratch buffer of exactly `size` zeroed complex samples.
///
/// The thread-local cached buffer is reused when its capacity is large enough;
/// otherwise it is dropped and a new allocation is made. The buffer is always cleared
/// before being zero-filled. A recycled buffer still holds the previous transform's
/// output, and leaving it in place would corrupt the zero padding of the next transform.
fn get_or_create_buffer(size: usize) -> Vec<OxiComplex<f64>> {
    let mut buffer = BUFFER_CACHE
        .with(|cache| cache.borrow_mut().take())
        .filter(|cached| cached.capacity() >= size)
        .unwrap_or_else(|| Vec::with_capacity(size));
    buffer.clear();
    buffer.resize(size, OxiComplex::zero());
    buffer
}
#[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
#[allow(dead_code)]
/// Returns a scratch buffer of exactly `size` zeroed complex samples.
///
/// The thread-local cached buffer is reused when its capacity is large enough;
/// otherwise it is dropped and a new allocation is made. The buffer is always cleared
/// before being zero-filled. A recycled buffer still holds the previous transform's
/// output, and leaving it in place would corrupt the zero padding of the next transform.
fn get_or_create_buffer(size: usize) -> Vec<RustComplex<f64>> {
    let mut buffer = BUFFER_CACHE
        .with(|cache| cache.borrow_mut().take())
        .filter(|cached| cached.capacity() >= size)
        .unwrap_or_else(|| Vec::with_capacity(size));
    buffer.clear();
    buffer.resize(size, RustComplex::new(0.0, 0.0));
    buffer
}
#[cfg(feature = "oxifft")]
#[allow(dead_code)]
/// Parks `buffer` in this thread's cache slot so the next transform can reuse its
/// allocation. Whatever was parked there before is dropped; normally the slot is
/// empty, because `get_or_create_buffer` takes its contents.
fn return_buffer_to_cache(buffer: Vec<OxiComplex<f64>>) {
    BUFFER_CACHE.with(move |slot| drop(slot.replace(Some(buffer))));
}
#[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
#[allow(dead_code)]
/// Parks `buffer` in this thread's cache slot so the next transform can reuse its
/// allocation. Whatever was parked there before is dropped; normally the slot is
/// empty, because `get_or_create_buffer` takes its contents.
fn return_buffer_to_cache(buffer: Vec<RustComplex<f64>>) {
    BUFFER_CACHE.with(move |slot| drop(slot.replace(Some(buffer))));
}
#[allow(dead_code)]
/// Converts one sample of any supported element type into a `Complex64`.
///
/// Values that `try_as_complex` recognises as complex numbers are passed through.
/// Every other value is cast to `f64` via `NumCast` and becomes the real part, with a
/// zero imaginary part.
///
/// # Errors
/// Returns `FFTError::ValueError` when the value is neither a recognised complex type
/// nor castable to `f64`.
fn to_complex_value<T>(val: T) -> FFTResult<Complex64>
where
    T: NumCast + Copy + Debug + 'static,
{
    match try_as_complex(&val) {
        Some(complex) => Ok(complex),
        None => {
            let cast: Option<f64> = NumCast::from(val);
            cast.map(|re| Complex64::new(re, 0.0)).ok_or_else(|| {
                FFTError::ValueError(format!("Could not convert {:?} to f64", val))
            })
        }
    }
}
#[allow(dead_code)]
/// Attempts to read `val` as a complex number by inspecting its concrete type at runtime.
///
/// Recognises `Complex64` and `Complex<f32>` from `scirs2_core::numeric`, plus the
/// `f64`/`f32` complex types of whichever FFT backends are compiled in. `f32`
/// components are widened to `f64`. Returns `None` for every other type, including
/// plain real scalars.
fn try_as_complex<T: 'static>(val: &T) -> Option<Complex64> {
    use std::any::Any;
    let any: &dyn Any = val;

    if let Some(&native) = any.downcast_ref::<Complex64>() {
        return Some(native);
    }
    if let Some(single) = any.downcast_ref::<scirs2_core::numeric::Complex<f32>>() {
        return Some(Complex64::new(f64::from(single.re), f64::from(single.im)));
    }

    #[cfg(feature = "oxifft")]
    {
        if let Some(oxi) = any.downcast_ref::<OxiComplex<f64>>() {
            return Some(Complex64::new(oxi.re, oxi.im));
        }
        if let Some(oxi) = any.downcast_ref::<OxiComplex<f32>>() {
            return Some(Complex64::new(f64::from(oxi.re), f64::from(oxi.im)));
        }
    }

    #[cfg(feature = "rustfft-backend")]
    {
        if let Some(rc) = any.downcast_ref::<RustComplex<f64>>() {
            return Some(Complex64::new(rc.re, rc.im));
        }
        if let Some(rc) = any.downcast_ref::<RustComplex<f32>>() {
            return Some(Complex64::new(f64::from(rc.re), f64::from(rc.im)));
        }
    }

    None
}
#[allow(dead_code)]
/// Computes the 1-D forward FFT of `input`, reusing a thread-local scratch buffer.
///
/// # Arguments
/// * `input` - samples to transform. Real values are cast to `f64`; complex values
///   recognised by `to_complex_value` are used as-is.
/// * `n` - transform length. The input is truncated or zero-padded to this length.
///   Defaults to `input.len().next_power_of_two()`, not `input.len()`.
/// * `norm` - scaling applied to the result. `Forward` multiplies by `1/n`, `Ortho`
///   multiplies by `1/sqrt(n)`, and `Backward` and `None` leave it unscaled. Defaults
///   to `NormMode::None`.
/// * `out` - optional caller-provided vector. It is filled with the result and then
///   moved into the return value, which leaves the caller's vector empty.
///
/// # Errors
/// * `FFTError::ValueError` if `input` is empty or an element cannot be converted.
/// * `FFTError::ComputationError` if neither the `oxifft` nor the `rustfft-backend`
///   feature is enabled.
/// * Errors reported by the `oxifft` plan cache are propagated.
pub fn fft_optimized<T>(
    input: &[T],
    n: Option<usize>,
    norm: Option<NormMode>,
    out: Option<&mut Vec<Complex64>>,
) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    if input.is_empty() {
        return Err(FFTError::ValueError("Input cannot be empty".to_string()));
    }
    let input_len = input.len();
    let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
    let norm_mode = norm.unwrap_or(NormMode::None);

    let mut buffer = get_or_create_buffer(fft_size);
    // A recycled buffer may still hold the previous transform's output. Clearing it
    // before `resize` guarantees that every slot past `input.len()` is genuine zero
    // padding rather than stale spectrum data.
    buffer.clear();

    #[cfg(feature = "oxifft")]
    {
        buffer.resize(fft_size, OxiComplex::zero());
        // `zip` stops at the shorter side, which truncates inputs longer than `fft_size`.
        for (slot, val) in buffer.iter_mut().zip(input) {
            let complex = to_complex_value(*val)?;
            *slot = OxiComplex::new(complex.re, complex.im);
        }
        let mut output_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); fft_size];
        oxifft_plan_cache::execute_c2c(&buffer, &mut output_oxi, Direction::Forward)?;
        buffer.clear();
        buffer.extend(output_oxi);
    }
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    {
        buffer.resize(fft_size, RustComplex::new(0.0, 0.0));
        // `zip` stops at the shorter side, which truncates inputs longer than `fft_size`.
        for (slot, val) in buffer.iter_mut().zip(input) {
            let complex = to_complex_value(*val)?;
            *slot = RustComplex::new(complex.re, complex.im);
        }
        static PLANNER_CACHE: std::sync::OnceLock<std::sync::Mutex<FftPlanner<f64>>> =
            std::sync::OnceLock::new();
        let planner = PLANNER_CACHE.get_or_init(|| std::sync::Mutex::new(FftPlanner::new()));
        let fft_plan = {
            // The planner is only a plan cache. A panic in another thread while it held
            // the lock must not make every later FFT call panic, so recover the guard.
            let mut planner = planner
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            planner.plan_fft_forward(fft_size)
        };
        fft_plan.process(&mut buffer);
    }
    #[cfg(all(not(feature = "oxifft"), not(feature = "rustfft-backend")))]
    {
        return Err(FFTError::ComputationError(
            "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature.".to_string()
        ));
    }

    let scale = match norm_mode {
        NormMode::Forward => 1.0 / (fft_size as f64),
        NormMode::Ortho => 1.0 / (fft_size as f64).sqrt(),
        NormMode::Backward | NormMode::None => 1.0,
    };
    if scale != 1.0 {
        for c in buffer.iter_mut() {
            c.re *= scale;
            c.im *= scale;
        }
    }

    // Safe conversion: the previous `set_len` on uninitialised memory was UB.
    let result = match out {
        Some(output_vec) => {
            output_vec.clear();
            output_vec.extend(buffer.iter().map(|c| Complex64::new(c.re, c.im)));
            std::mem::take(output_vec)
        }
        None => buffer.iter().map(|c| Complex64::new(c.re, c.im)).collect(),
    };
    return_buffer_to_cache(buffer);
    Ok(result)
}
#[allow(dead_code)]
/// Computes the 1-D inverse FFT of `input`, reusing a thread-local scratch buffer.
///
/// # Arguments
/// * `input` - spectrum to invert. Real values are cast to `f64`; complex values
///   recognised by `to_complex_value` are used as-is.
/// * `n` - transform length. The input is truncated or zero-padded to this length.
///   Defaults to `input.len().next_power_of_two()`, not `input.len()`.
/// * `norm` - scaling applied to the result. `Backward` multiplies by `1/n`, `Ortho`
///   multiplies by `1/sqrt(n)`, and `Forward` and `None` leave it unscaled. Defaults
///   to `NormMode::Backward`.
/// * `out` - optional caller-provided vector. It is filled with the result and then
///   moved into the return value, which leaves the caller's vector empty.
///
/// # Errors
/// * `FFTError::ValueError` if `input` is empty or an element cannot be converted.
/// * `FFTError::ComputationError` if neither the `oxifft` nor the `rustfft-backend`
///   feature is enabled.
/// * Errors reported by the `oxifft` plan cache are propagated.
pub fn ifft_optimized<T>(
    input: &[T],
    n: Option<usize>,
    norm: Option<NormMode>,
    out: Option<&mut Vec<Complex64>>,
) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    if input.is_empty() {
        return Err(FFTError::ValueError("Input cannot be empty".to_string()));
    }
    let input_len = input.len();
    let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
    let norm_mode = norm.unwrap_or(NormMode::Backward);

    let mut buffer = get_or_create_buffer(fft_size);
    // A recycled buffer may still hold the previous transform's output. Clearing it
    // before `resize` guarantees that every slot past `input.len()` is genuine zero
    // padding rather than stale data.
    buffer.clear();

    #[cfg(feature = "oxifft")]
    {
        buffer.resize(fft_size, OxiComplex::zero());
        // `zip` stops at the shorter side, which truncates inputs longer than `fft_size`.
        for (slot, val) in buffer.iter_mut().zip(input) {
            let complex = to_complex_value(*val)?;
            *slot = OxiComplex::new(complex.re, complex.im);
        }
        let mut output_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); fft_size];
        oxifft_plan_cache::execute_c2c(&buffer, &mut output_oxi, Direction::Backward)?;
        buffer.clear();
        buffer.extend(output_oxi);
    }
    #[cfg(all(not(feature = "oxifft"), feature = "rustfft-backend"))]
    {
        buffer.resize(fft_size, RustComplex::new(0.0, 0.0));
        // `zip` stops at the shorter side, which truncates inputs longer than `fft_size`.
        for (slot, val) in buffer.iter_mut().zip(input) {
            let complex = to_complex_value(*val)?;
            *slot = RustComplex::new(complex.re, complex.im);
        }
        static PLANNER_CACHE: std::sync::OnceLock<std::sync::Mutex<FftPlanner<f64>>> =
            std::sync::OnceLock::new();
        let planner = PLANNER_CACHE.get_or_init(|| std::sync::Mutex::new(FftPlanner::new()));
        let ifft_plan = {
            // The planner is only a plan cache. A panic in another thread while it held
            // the lock must not make every later IFFT call panic, so recover the guard.
            let mut planner = planner
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            planner.plan_fft_inverse(fft_size)
        };
        ifft_plan.process(&mut buffer);
    }
    #[cfg(all(not(feature = "oxifft"), not(feature = "rustfft-backend")))]
    {
        return Err(FFTError::ComputationError(
            "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature.".to_string()
        ));
    }

    let scale = match norm_mode {
        NormMode::Backward => 1.0 / (fft_size as f64),
        NormMode::Ortho => 1.0 / (fft_size as f64).sqrt(),
        NormMode::Forward | NormMode::None => 1.0,
    };
    if scale != 1.0 {
        for c in buffer.iter_mut() {
            c.re *= scale;
            c.im *= scale;
        }
    }

    // Safe conversion: the previous `set_len` on uninitialised memory was UB.
    let result = match out {
        Some(output_vec) => {
            output_vec.clear();
            output_vec.extend(buffer.iter().map(|c| Complex64::new(c.re, c.im)));
            std::mem::take(output_vec)
        }
        None => buffer.iter().map(|c| Complex64::new(c.re, c.im)).collect(),
    };
    return_buffer_to_cache(buffer);
    Ok(result)
}
#[allow(dead_code)]
/// Computes the 2-D forward FFT of `input` as row FFTs followed by column FFTs.
///
/// # Arguments
/// * `input` - samples to transform; any element type accepted by `to_complex_value`.
/// * `shape` - output shape `(rows, cols)`. The input is cropped or zero-padded to
///   this shape. Defaults to the input shape.
/// * `axes` - must be `(0, 1)` or `(1, 0)`. Both axes are always transformed; the
///   2-D DFT is separable, so the order does not change the result.
/// * `norm` - `"forward"` scales by `1/(rows*cols)`, `"ortho"` scales by
///   `1/sqrt(rows*cols)`, and `"backward"` applies no scaling. `None` and any other
///   string fall back to `"backward"`.
///
/// # Errors
/// Returns `FFTError::ValueError` for invalid axes, a zero-sized output shape, or an
/// element that cannot be converted. Errors from the 1-D transform are propagated.
pub fn fft2_optimized<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let (rows_in, cols_in) = input.dim();
    let (n_rows_out, n_cols_out) = shape.unwrap_or((rows_in, cols_in));
    let (axis1, axis2) = axes.unwrap_or((0, 1));
    if !(0..=1).contains(&axis1) || !(0..=1).contains(&axis2) || axis1 == axis2 {
        return Err(FFTError::ValueError("Invalid axes for 2D FFT".to_string()));
    }
    let norm_mode = match norm {
        Some("forward") => NormMode::Forward,
        Some("ortho") => NormMode::Ortho,
        // "backward", None, and unrecognised strings all use the default.
        _ => NormMode::Backward,
    };
    if n_rows_out == 0 || n_cols_out == 0 {
        return Err(FFTError::ValueError(
            "Output shape must be non-zero in both dimensions".to_string(),
        ));
    }

    let zero = Complex64::new(0.0, 0.0);
    let mut output = Array2::<Complex64>::zeros((n_rows_out, n_cols_out));
    let mut line: Vec<Complex64> = Vec::with_capacity(n_rows_out.max(n_cols_out));

    // Row pass. Output rows beyond the input stay zero, since the FFT of zeros is zero.
    // Each row is cropped or zero-padded here to exactly `n_cols_out` samples, so the
    // 1-D routine never has to pad on our behalf.
    for (i, row) in input.outer_iter().take(n_rows_out).enumerate() {
        line.clear();
        for &val in row.iter().take(n_cols_out) {
            line.push(to_complex_value(val)?);
        }
        line.resize(n_cols_out, zero);
        let spectrum = fft_optimized(&line, Some(n_cols_out), Some(NormMode::None), None)?;
        for (dst, val) in output.row_mut(i).iter_mut().zip(spectrum) {
            *dst = val;
        }
    }

    // Column pass over the row spectra. Every column already has exactly `n_rows_out` samples.
    for j in 0..n_cols_out {
        line.clear();
        line.extend(output.column(j).iter().copied());
        let spectrum = fft_optimized(&line, Some(n_rows_out), Some(NormMode::None), None)?;
        for (dst, val) in output.column_mut(j).iter_mut().zip(spectrum) {
            *dst = val;
        }
    }

    // Computed in f64 so a huge shape cannot overflow `usize` multiplication.
    let total = n_rows_out as f64 * n_cols_out as f64;
    let scale = match norm_mode {
        NormMode::Forward => 1.0 / total,
        NormMode::Ortho => 1.0 / total.sqrt(),
        NormMode::Backward | NormMode::None => 1.0,
    };
    if scale != 1.0 {
        output.iter_mut().for_each(|c| *c *= scale);
    }
    Ok(output)
}
#[allow(dead_code)]
/// Computes the 2-D inverse FFT of `input` as row IFFTs followed by column IFFTs.
///
/// # Arguments
/// * `input` - spectrum to invert; any element type accepted by `to_complex_value`.
/// * `shape` - output shape `(rows, cols)`. The input is cropped or zero-padded to
///   this shape. Defaults to the input shape.
/// * `axes` - must be `(0, 1)` or `(1, 0)`. Both axes are always transformed; the
///   2-D DFT is separable, so the order does not change the result.
/// * `norm` - `"backward"` scales by `1/(rows*cols)`, `"ortho"` scales by
///   `1/sqrt(rows*cols)`, and `"forward"` applies no scaling. `None` and any other
///   string fall back to `"backward"`.
///
/// # Errors
/// Returns `FFTError::ValueError` for invalid axes, a zero-sized output shape, or an
/// element that cannot be converted. Errors from the 1-D transform are propagated.
pub fn ifft2_optimized<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let (rows_in, cols_in) = input.dim();
    let (n_rows_out, n_cols_out) = shape.unwrap_or((rows_in, cols_in));
    let (axis1, axis2) = axes.unwrap_or((0, 1));
    if !(0..=1).contains(&axis1) || !(0..=1).contains(&axis2) || axis1 == axis2 {
        return Err(FFTError::ValueError("Invalid axes for 2D IFFT".to_string()));
    }
    let norm_mode = match norm {
        Some("forward") => NormMode::Forward,
        Some("ortho") => NormMode::Ortho,
        // "backward", None, and unrecognised strings all use the default.
        _ => NormMode::Backward,
    };
    if n_rows_out == 0 || n_cols_out == 0 {
        return Err(FFTError::ValueError(
            "Output shape must be non-zero in both dimensions".to_string(),
        ));
    }

    let zero = Complex64::new(0.0, 0.0);
    let mut output = Array2::<Complex64>::zeros((n_rows_out, n_cols_out));
    let mut line: Vec<Complex64> = Vec::with_capacity(n_rows_out.max(n_cols_out));

    // Row pass. Output rows beyond the input stay zero, since the IFFT of zeros is zero.
    // Each row is cropped or zero-padded here to exactly `n_cols_out` samples, so the
    // 1-D routine never has to pad on our behalf.
    for (i, row) in input.outer_iter().take(n_rows_out).enumerate() {
        line.clear();
        for &val in row.iter().take(n_cols_out) {
            line.push(to_complex_value(val)?);
        }
        line.resize(n_cols_out, zero);
        let signal = ifft_optimized(&line, Some(n_cols_out), Some(NormMode::None), None)?;
        for (dst, val) in output.row_mut(i).iter_mut().zip(signal) {
            *dst = val;
        }
    }

    // Column pass over the row results. Every column already has exactly `n_rows_out` samples.
    for j in 0..n_cols_out {
        line.clear();
        line.extend(output.column(j).iter().copied());
        let signal = ifft_optimized(&line, Some(n_rows_out), Some(NormMode::None), None)?;
        for (dst, val) in output.column_mut(j).iter_mut().zip(signal) {
            *dst = val;
        }
    }

    // Computed in f64 so a huge shape cannot overflow `usize` multiplication.
    let total = n_rows_out as f64 * n_cols_out as f64;
    let scale = match norm_mode {
        NormMode::Backward => 1.0 / total,
        NormMode::Ortho => 1.0 / total.sqrt(),
        NormMode::Forward | NormMode::None => 1.0,
    };
    if scale != 1.0 {
        output.iter_mut().for_each(|c| *c *= scale);
    }
    Ok(output)
}