use crate::error::{FFTError, FFTResult};
use crate::fft::fft;
use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView2, IxDyn};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use std::fmt::Debug;
use super::utility::try_as_complex;
/// Computes a 1-D transform of a Hermitian-symmetric signal and returns real values.
///
/// Every element of `x` is converted to `Complex64` and the sequence is handed to
/// `_hfft_complex`. When the first element is complex, its imaginary part is
/// replaced by zero before the transform.
///
/// # Arguments
///
/// * `x` - Input samples, either `Complex64`, something `try_as_complex` accepts,
///   or any value castable to `f64` through `NumCast`.
/// * `n` - Optional transform length, forwarded unchanged.
/// * `norm` - Optional normalization mode, forwarded unchanged.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when an element can be converted neither by
/// `try_as_complex` nor by `NumCast`, and propagates any error from the transform.
#[allow(dead_code)]
pub fn hfft<T>(x: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // Fast path: the caller already supplied `Complex64` values.
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<Complex64>() {
        // SAFETY: the `TypeId` comparison above proves that `T` is exactly
        // `Complex64`, so the slice can be reinterpreted with the same pointer,
        // length, alignment and lifetime.
        let as_complex: &[Complex64] =
            unsafe { std::slice::from_raw_parts(x.as_ptr().cast::<Complex64>(), x.len()) };

        let mut prepared = Vec::with_capacity(as_complex.len());
        if let Some((head, tail)) = as_complex.split_first() {
            // The leading term is stored with a zero imaginary part.
            prepared.push(Complex64::new(head.re, 0.0));
            prepared.extend_from_slice(tail);
        }
        return _hfft_complex(&prepared, n, norm);
    }

    // Generic path: convert element by element.
    let mut prepared = Vec::with_capacity(x.len());
    for (position, &val) in x.iter().enumerate() {
        let converted = match try_as_complex(val) {
            Some(z) if position == 0 => Complex64::new(z.re, 0.0),
            Some(z) => z,
            None => match <f64 as NumCast>::from(val) {
                Some(re) => Complex64::new(re, 0.0),
                None => {
                    return Err(FFTError::ValueError(format!(
                        "Could not convert {val:?} to Complex64"
                    )))
                }
            },
        };
        prepared.push(converted);
    }
    _hfft_complex(&prepared, n, norm)
}
/// Runs `fft` on an already-complex input and keeps only the real parts.
///
/// The transform length is `n` when given, otherwise `x.len()`. The `_norm`
/// argument is accepted for signature compatibility but is not used here.
///
/// # Errors
///
/// Propagates any error returned by `fft`.
#[allow(dead_code)]
fn _hfft_complex(x: &[Complex64], n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<f64>> {
    let transform_len = match n {
        Some(requested) => requested,
        None => x.len(),
    };
    let spectrum = fft(x, Some(transform_len))?;
    Ok(spectrum.into_iter().map(|bin| bin.re).collect())
}
#[allow(dead_code)]
pub fn hfft2<T>(
x: &ArrayView2<T>,
shape: Option<(usize, usize)>,
axes: Option<(usize, usize)>,
norm: Option<&str>,
) -> FFTResult<Array2<f64>>
where
T: NumCast + Copy + Debug + 'static,
{
#[cfg(test)]
{
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<Complex64>() {
let ptr = x.as_ptr() as *const Complex64;
let complex_view = unsafe { ArrayView2::from_shape_ptr(x.dim(), ptr) };
return _hfft2_complex(&complex_view, shape, axes, norm);
}
}
let (n_rows, n_cols) = x.dim();
let mut complex_input = Array2::zeros((n_rows, n_cols));
for r in 0..n_rows {
for c in 0..n_cols {
let val = x[[r, c]];
if let Some(complex) = try_as_complex(val) {
complex_input[[r, c]] = complex;
continue;
}
if let Some(val_f64) = NumCast::from(val) {
complex_input[[r, c]] = Complex64::new(val_f64, 0.0);
continue;
}
return Err(FFTError::ValueError(format!(
"Could not convert {val:?} to Complex64"
)));
}
}
_hfft2_complex(&complex_input.view(), shape, axes, norm)
}
/// Applies `fft` along the columns and then along the rows of a complex 2-D
/// array, returning the real part of the result.
///
/// * `shape` - Output `(rows, cols)`; defaults to the input dimensions. The
///   column pass uses `rows` as its transform length, the row pass uses `cols`.
/// * `axes` - Only validated (each entry must be `0` or `1`); the passes always
///   run over columns first and rows second.
/// * `_norm` - Accepted for signature compatibility; not used here.
///
/// # Errors
///
/// Returns `FFTError::ValueError` for an axis outside `0..2` and propagates
/// any error returned by `fft`.
#[allow(dead_code)]
fn _hfft2_complex(
    x: &ArrayView2<Complex64>,
    shape: Option<(usize, usize)>,
    axes: Option<(usize, usize)>,
    _norm: Option<&str>,
) -> FFTResult<Array2<f64>> {
    let (in_rows, in_cols) = x.dim();
    let (out_rows, out_cols) = match shape {
        Some(requested) => requested,
        None => (in_rows, in_cols),
    };

    let (first_axis, second_axis) = axes.unwrap_or((0, 1));
    if first_axis.max(second_axis) >= 2 {
        return Err(FFTError::ValueError(
            "Axes must be 0 or 1 for 2D arrays".to_string(),
        ));
    }

    // Pass 1: transform every input column to length `out_rows`.
    let mut column_pass = Array2::<Complex64>::zeros((out_rows, in_cols));
    for c in 0..in_cols {
        let column: Vec<Complex64> = (0..in_rows).map(|r| x[[r, c]]).collect();
        let transformed = fft(&column, Some(out_rows))?;
        for (r, slot) in column_pass.column_mut(c).iter_mut().enumerate() {
            *slot = transformed[r];
        }
    }

    // Pass 2: transform every intermediate row to length `out_cols`, keeping
    // only the real component.
    let mut output = Array2::<f64>::zeros((out_rows, out_cols));
    for r in 0..out_rows {
        let row: Vec<Complex64> = (0..in_cols).map(|c| column_pass[[r, c]]).collect();
        let transformed = fft(&row, Some(out_cols))?;
        for (c, slot) in output.row_mut(r).iter_mut().enumerate() {
            *slot = transformed[c].re;
        }
    }

    Ok(output)
}
#[allow(dead_code)]
pub fn hfftn<T>(
x: &ArrayView<T, IxDyn>,
shape: Option<Vec<usize>>,
axes: Option<Vec<usize>>,
norm: Option<&str>,
overwrite_x: Option<bool>,
workers: Option<usize>,
) -> FFTResult<Array<f64, IxDyn>>
where
T: NumCast + Copy + Debug + 'static,
{
#[cfg(test)]
{
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<Complex64>() {
let ptr = x.as_ptr() as *const Complex64;
let complex_view = unsafe { ArrayView::from_shape_ptr(IxDyn(x.shape()), ptr) };
return _hfftn_complex(&complex_view, shape, axes, norm, overwrite_x, workers);
}
}
let xshape = x.shape().to_vec();
let complex_input = Array::from_shape_fn(IxDyn(&xshape), |idx| {
let val = x[idx.clone()];
if let Some(c) = try_as_complex(val) {
return c;
}
if let Some(val_f64) = NumCast::from(val) {
return Complex64::new(val_f64, 0.0);
}
Complex64::new(0.0, 0.0) });
_hfftn_complex(
&complex_input.view(),
shape,
axes,
norm,
overwrite_x,
workers,
)
}
#[allow(dead_code)]
fn _hfftn_complex(
x: &ArrayView<Complex64, IxDyn>,
shape: Option<Vec<usize>>,
axes: Option<Vec<usize>>,
_norm: Option<&str>,
_overwrite_x: Option<bool>,
_workers: Option<usize>,
) -> FFTResult<Array<f64, IxDyn>> {
let xshape = x.shape().to_vec();
let ndim = xshape.len();
if ndim == 0 || xshape.contains(&0) {
return Ok(Array::zeros(IxDyn(&[])));
}
let outshape = match shape {
Some(s) => {
if s.len() != ndim {
return Err(FFTError::ValueError(format!(
"Shape must have the same number of dimensions as input, got {} != {}",
s.len(),
ndim
)));
}
s
}
None => xshape.clone(),
};
let transform_axes = match axes {
Some(a) => {
let mut sorted_axes = a.clone();
sorted_axes.sort_unstable();
sorted_axes.dedup();
for &ax in &sorted_axes {
if ax >= ndim {
return Err(FFTError::ValueError(format!(
"Axis {ax} is out of bounds for array of dimension {ndim}"
)));
}
}
sorted_axes
}
None => (0..ndim).collect(),
};
if ndim == 1 {
let mut complex_result = Vec::with_capacity(x.len());
for &val in x.iter() {
complex_result.push(val);
}
let fft_result = fft(&complex_result, Some(outshape[0]))?;
let mut real_result = Array::zeros(IxDyn(&[outshape[0]]));
for i in 0..outshape[0] {
real_result[i] = fft_result[i].re;
}
return Ok(real_result);
}
let mut array = Array::from_shape_fn(IxDyn(&xshape), |idx| x[idx.clone()]);
for &axis in &transform_axes {
let axis_dim = outshape[axis];
let _dim_permutation: Vec<_> = (0..ndim).collect();
let mut workingshape = xshape.clone();
workingshape[axis] = axis_dim;
let mut axis_result = Array::zeros(IxDyn(&workingshape));
let mut indices = vec![0; ndim];
let mut fiber = Vec::with_capacity(axis_dim);
for i in 0..axis_dim {
indices[axis] = i;
fiber.push(array[IxDyn(&indices)]);
}
let fft_result = fft(&fiber, Some(axis_dim))?;
for (i, val) in fft_result.iter().enumerate().take(axis_dim) {
indices[axis] = i;
axis_result[IxDyn(&indices)] = *val;
}
array = axis_result;
}
let mut real_result = Array::zeros(IxDyn(&outshape));
for (i, &val) in array.iter().enumerate() {
let mut idx = vec![0; ndim];
for (dim, idx_val) in idx.iter_mut().enumerate().take(ndim) {
let stride = array.strides()[dim] as usize;
if let Some(divided) = i.checked_div(stride) {
*idx_val = divided % array.shape()[dim];
}
}
real_result[IxDyn(&idx)] = val.re;
}
Ok(real_result)
}