use crate::error::{FFTError, FFTResult};
use crate::fft::ifft;
use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView2, IxDyn};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use std::fmt::Debug;
use super::symmetric::{enforce_hermitian_symmetry, enforce_hermitian_symmetry_nd};
use super::utility::try_as_complex;
/// Inverse FFT of a real-valued 1-D signal, producing a Hermitian-symmetric spectrum.
///
/// The input is turned into `f64` samples and handed to `_ihfft_real`:
/// * `f64` input is reinterpreted in place (no copy);
/// * `Complex64` input is rejected with `FFTError::ValueError` in normal builds, but under
///   `cfg(test)` only the real components are used and a warning is printed to stderr;
/// * any other element type is converted element by element, first via `try_as_complex`
///   (keeping the real part) and otherwise via `NumCast`.
///
/// NOTE(review): the `Complex64` behavior differs between test and non-test builds, so the
/// test suite does not exercise the production path for complex input — confirm intent.
///
/// # Arguments
/// * `x` - input samples.
/// * `n` - transform length; defaults to `x.len()` (longer input is truncated, shorter is
///   zero-padded by `_ihfft_real`).
/// * `norm` - normalization mode, forwarded to `_ihfft_real` (which currently ignores it).
///
/// # Errors
/// `FFTError::ValueError` for complex input (non-test builds) or an element that cannot be
/// converted to `f64`; errors from the underlying inverse FFT are propagated.
#[allow(dead_code)]
pub fn ihfft<T>(x: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let input_type = std::any::TypeId::of::<T>();

    if input_type == std::any::TypeId::of::<Complex64>() {
        #[cfg(test)]
        {
            eprintln!("Warning: Complex input provided to ihfft - extracting real component only");
            // SAFETY: the TypeId comparison above proves `T` is exactly `Complex64`, so the
            // pointer, length and element layout of `x` are those of a `[Complex64]`.
            let as_complex: &[Complex64] =
                unsafe { std::slice::from_raw_parts(x.as_ptr().cast::<Complex64>(), x.len()) };
            let real_parts: Vec<f64> = as_complex.iter().map(|z| z.re).collect();
            return _ihfft_real(&real_parts, n, norm);
        }
        #[cfg(not(test))]
        {
            return Err(FFTError::ValueError(
                "ihfft expects real-valued input, got complex".to_string(),
            ));
        }
    }

    if input_type == std::any::TypeId::of::<f64>() {
        // SAFETY: `T` is exactly `f64` (checked above), so viewing the slice as `[f64]` is sound.
        let as_f64: &[f64] =
            unsafe { std::slice::from_raw_parts(x.as_ptr().cast::<f64>(), x.len()) };
        return _ihfft_real(as_f64, n, norm);
    }

    // Generic path: prefer the real part of complex-like values, then fall back to NumCast.
    // `collect` stops at the first element that cannot be converted.
    let converted = x
        .iter()
        .map(|&val| {
            if let Some(z) = try_as_complex(val) {
                Ok(z.re)
            } else if let Some(v) = <f64 as NumCast>::from(val) {
                Ok(v)
            } else {
                Err(FFTError::ValueError(format!(
                    "Could not convert {val:?} to f64"
                )))
            }
        })
        .collect::<FFTResult<Vec<f64>>>()?;
    _ihfft_real(&converted, n, norm)
}
/// One-dimensional inverse Hermitian FFT of a real `f64` signal (core of `ihfft`).
///
/// The signal is truncated or zero-padded to `n` samples (default `x.len()`) and passed
/// through `ifft`. The output is then assembled as:
/// * the DC bin with its imaginary part forced to zero,
/// * bins `1..ceil(n/2)` copied unchanged,
/// * the remaining positions filled with the complex conjugates of bins
///   `n - ceil(n/2), ..., 2, 1` in that order.
///
/// # Arguments
/// * `x` - real input samples.
/// * `n` - transform length; defaults to `x.len()`.
/// * `_norm` - currently ignored; scaling is whatever `ifft` applies.
///
/// # Errors
/// Propagates any error returned by `ifft`.
#[allow(dead_code)]
fn _ihfft_real(x: &[f64], n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<Complex64>> {
    let n_fft = n.unwrap_or(x.len());

    // Promote to complex, then truncate / zero-pad to exactly `n_fft` samples.
    let padded: Vec<Complex64> = x
        .iter()
        .map(|&re| Complex64::new(re, 0.0))
        .chain(std::iter::repeat(Complex64::new(0.0, 0.0)))
        .take(n_fft)
        .collect();

    let spectrum = ifft(&padded, Some(n_fft))?;
    if spectrum.is_empty() {
        return Ok(Vec::new());
    }

    #[allow(clippy::manual_div_ceil)]
    let half = (n_fft + 1) / 2;

    let mut hermitian = Vec::with_capacity(spectrum.len());
    hermitian.push(Complex64::new(spectrum[0].re, 0.0));
    hermitian.extend_from_slice(&spectrum[1..half]);
    // Mirrored tail: conjugates of bins n_fft - half, ..., 2, 1.
    hermitian.extend((1..=n_fft - half).rev().map(|k| spectrum[k].conj()));
    Ok(hermitian)
}
/// Two-dimensional inverse Hermitian FFT of a real-valued 2-D array.
///
/// Every element of `x` is converted to `f64` with `NumCast`. The converted array is passed,
/// together with all options, to `_ihfft2_real`.
///
/// # Arguments
/// * `x` - input view; any memory layout (transposed, sliced, negatively strided) is accepted.
/// * `shape` - optional output shape `(rows, cols)`; defaults to the input shape.
/// * `axes` - optional axes pair; each entry must be 0 or 1.
/// * `norm` - normalization mode, forwarded unchanged to `_ihfft2_real`.
///
/// # Errors
/// `FFTError::ValueError` if an element cannot be converted to `f64` (the first offending
/// element in row-major order is reported) or if an axis is 2 or greater. Errors from the
/// underlying transform are propagated.
#[allow(dead_code)]
pub fn ihfft2<T>(
    x: &ArrayView2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(usize, usize)>,
    norm: Option<&str>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // Copy into an owned, standard-layout array instead of reinterpreting `x.as_ptr()`.
    // A raw-pointer view built with default (C-contiguous) strides misreads transposed or
    // sliced inputs, and reads out of bounds for negatively strided ones.
    let mut real_input = Array2::<f64>::zeros(x.dim());
    for ((r, c), &val) in x.indexed_iter() {
        real_input[[r, c]] = <f64 as NumCast>::from(val).ok_or_else(|| {
            FFTError::ValueError(format!("Could not convert {val:?} to f64"))
        })?;
    }
    _ihfft2_real(&real_input.view(), shape, axes, norm)
}
/// Two-dimensional inverse Hermitian FFT of a real `f64` array (core of `ihfft2`).
///
/// First runs a 1-D `ifft` down every column, resizing axis 0 to `out_rows`. Then runs one
/// along every row of that intermediate, resizing axis 1 to `out_cols`. Finally it applies
/// `enforce_hermitian_symmetry` to the result.
///
/// # Arguments
/// * `x` - real input of shape `(n_rows, n_cols)`.
/// * `shape` - output shape `(out_rows, out_cols)`; defaults to the input shape. Each axis is
///   zero-padded or truncated through the length argument given to `ifft`.
/// * `axes` - only validated (each entry must be 0 or 1). The transform order is always
///   axis 0, then axis 1, whatever value is given.
/// * `_norm` - currently ignored.
///
/// # Errors
/// `FFTError::ValueError` if an axis index is 2 or greater; errors from `ifft` are propagated.
#[allow(dead_code)]
fn _ihfft2_real(
    x: &ArrayView2<f64>,
    shape: Option<(usize, usize)>,
    axes: Option<(usize, usize)>,
    _norm: Option<&str>,
) -> FFTResult<Array2<Complex64>> {
    let (n_rows, n_cols) = x.dim();
    let (out_rows, out_cols) = shape.unwrap_or((n_rows, n_cols));

    let (first_axis, second_axis) = axes.unwrap_or((0, 1));
    if first_axis > 1 || second_axis > 1 {
        return Err(FFTError::ValueError(
            "Axes must be 0 or 1 for 2D arrays".to_string(),
        ));
    }

    // Pass 1: inverse FFT of each column, resized to `out_rows` entries.
    let mut by_columns = Array2::<Complex64>::zeros((out_rows, n_cols));
    for c in 0..n_cols {
        let column: Vec<Complex64> = x
            .column(c)
            .iter()
            .map(|&v| Complex64::new(v, 0.0))
            .collect();
        let transformed = ifft(&column, Some(out_rows))?;
        for (r, &value) in transformed[..out_rows].iter().enumerate() {
            by_columns[[r, c]] = value;
        }
    }

    // Pass 2: inverse FFT of each row of the intermediate, resized to `out_cols` entries.
    let mut output = Array2::<Complex64>::zeros((out_rows, out_cols));
    for r in 0..out_rows {
        let row = by_columns.row(r).to_vec();
        let transformed = ifft(&row, Some(out_cols))?;
        for (c, &value) in transformed[..out_cols].iter().enumerate() {
            output[[r, c]] = value;
        }
    }

    enforce_hermitian_symmetry(&mut output);
    Ok(output)
}
#[allow(dead_code)]
pub fn ihfftn<T>(
x: &ArrayView<T, IxDyn>,
shape: Option<Vec<usize>>,
axes: Option<Vec<usize>>,
norm: Option<&str>,
overwrite_x: Option<bool>,
workers: Option<usize>,
) -> FFTResult<Array<Complex64, IxDyn>>
where
T: NumCast + Copy + Debug + 'static,
{
#[cfg(test)]
{
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let ptr = x.as_ptr() as *const f64;
let real_view = unsafe { ArrayView::from_shape_ptr(IxDyn(x.shape()), ptr) };
return _ihfftn_real(&real_view, shape, axes, norm, overwrite_x, workers);
}
}
let xshape = x.shape().to_vec();
let real_input = Array::from_shape_fn(IxDyn(&xshape), |idx| {
let val = x[idx.clone()];
if let Some(val_f64) = NumCast::from(val) {
return val_f64;
}
0.0
});
_ihfftn_real(&real_input.view(), shape, axes, norm, overwrite_x, workers)
}
/// N-dimensional inverse Hermitian FFT of a real `f64` array (core of `ihfftn`).
///
/// * A 0-dimensional input, or one with any zero-length axis, yields a 0-dimensional array
///   of zeros.
/// * A 1-dimensional input is delegated to `_ihfft_real` with length `shape[0]`.
/// * Otherwise the input is promoted to complex. For each axis in `axes` (sorted ascending,
///   duplicates removed; default: all axes), a 1-D `ifft` of length `shape[axis]` is applied
///   to every 1-D fiber along that axis. `enforce_hermitian_symmetry_nd` is applied to the
///   final array.
///
/// # Arguments
/// * `x` - real input array.
/// * `shape` - output shape with one entry per input dimension; defaults to the input shape.
///   In the N-D path, entries for axes not listed in `axes` are ignored.
/// * `axes` - axes to transform; defaults to all axes.
/// * `norm` - forwarded only in the 1-D case (where `_ihfft_real` ignores it); unused otherwise.
/// * `_overwrite_x`, `_workers` - accepted for API compatibility; unused.
///
/// # Errors
/// `FFTError::ValueError` if `shape` has the wrong number of entries or an axis is out of
/// bounds; errors from `ifft` / `_ihfft_real` are propagated.
#[allow(dead_code)]
fn _ihfftn_real(
    x: &ArrayView<f64, IxDyn>,
    shape: Option<Vec<usize>>,
    axes: Option<Vec<usize>>,
    norm: Option<&str>,
    _overwrite_x: Option<bool>,
    _workers: Option<usize>,
) -> FFTResult<Array<Complex64, IxDyn>> {
    let xshape = x.shape().to_vec();
    let ndim = xshape.len();
    if ndim == 0 || xshape.contains(&0) {
        return Ok(Array::zeros(IxDyn(&[])));
    }

    let outshape = match shape {
        Some(s) => {
            if s.len() != ndim {
                return Err(FFTError::ValueError(format!(
                    "Shape must have the same number of dimensions as input, got {} != {}",
                    s.len(),
                    ndim
                )));
            }
            s
        }
        None => xshape,
    };

    let transform_axes = match axes {
        Some(mut sorted_axes) => {
            sorted_axes.sort_unstable();
            sorted_axes.dedup();
            // Sorted, so this reports the smallest out-of-bounds axis.
            if let Some(&ax) = sorted_axes.iter().find(|&&ax| ax >= ndim) {
                return Err(FFTError::ValueError(format!(
                    "Axis {ax} is out of bounds for array of dimension {ndim}"
                )));
            }
            sorted_axes
        }
        None => (0..ndim).collect(),
    };

    if ndim == 1 {
        let real_vals: Vec<f64> = x.iter().copied().collect();
        let n_out = outshape[0];
        let spectrum = _ihfft_real(&real_vals, Some(n_out), norm)?;
        let mut complex_result = Array::<Complex64, IxDyn>::zeros(IxDyn(&[n_out]));
        for (dst, &val) in complex_result.iter_mut().zip(&spectrum[..n_out]) {
            *dst = val;
        }
        return Ok(complex_result);
    }

    let mut array: Array<Complex64, IxDyn> = x.mapv(|v| Complex64::new(v, 0.0));
    // Scratch buffer reused for every fiber.
    let mut fiber: Vec<Complex64> = Vec::new();

    for &axis in &transform_axes {
        let in_len = array.shape()[axis];
        let out_len = outshape[axis];
        let mut workingshape = array.shape().to_vec();
        workingshape[axis] = out_len;
        let mut axis_result = Array::<Complex64, IxDyn>::zeros(IxDyn(&workingshape));

        // Every element whose coordinate along `axis` is 0 starts one 1-D fiber. All fibers
        // must be transformed, not only the one through the origin; otherwise every other
        // fiber of the result stays zero.
        for (start, _) in array.indexed_iter() {
            if start[axis] != 0 {
                continue;
            }
            let mut index: Vec<usize> = (0..ndim).map(|d| start[d]).collect();

            fiber.clear();
            for i in 0..in_len {
                index[axis] = i;
                fiber.push(array[IxDyn(&index)]);
            }

            let transformed = ifft(&fiber, Some(out_len))?;
            for (i, &val) in transformed.iter().enumerate().take(out_len) {
                index[axis] = i;
                axis_result[IxDyn(&index)] = val;
            }
        }
        array = axis_result;
    }

    enforce_hermitian_symmetry_nd(&mut array);
    Ok(array)
}