use num_complex::Complex;
use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::error::{FerrayError, FerrayResult};
use crate::norm::FftNorm;
/// Computes the FFT of a signal that has Hermitian symmetry, producing a real
/// result.
///
/// `a` holds the non-redundant half of the Hermitian input along `axis`. The
/// construction mirrors NumPy's `numpy.fft.hfft`:
/// 1. conjugate every element of the input;
/// 2. run the inverse real FFT ([`crate::real::irfft`]) with the normalization
///    direction swapped (`Backward` <-> `Forward`, `Ortho` unchanged).
///
/// # Parameters
/// - `a`: half-spectrum input of any dimensionality.
/// - `n`: length of the real output along `axis`. Defaults to `2 * (m - 1)`,
///   where `m` is the length of `a` along `axis`.
/// - `axis`: axis to transform; `None` selects the last axis.
/// - `norm`: normalization mode as requested by the caller.
///
/// # Errors
/// - `axis` is out of bounds, or `a` is 0-dimensional and `axis` is `None`.
/// - The resolved output length is 0. This happens when `n == Some(0)`, or
///   when `n` is `None` and `a` has length 0 or 1 along `axis`.
/// - Any error propagated from building the conjugated array or from `irfft`.
pub fn hfft<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    n: Option<usize>,
    axis: Option<usize>,
    norm: FftNorm,
) -> FerrayResult<Array<f64, IxDyn>> {
    let shape = a.shape().to_vec();
    let ax = resolve_axis(shape.len(), axis)?;
    let half_len = shape[ax];
    // `saturating_sub` stops `0 - 1` from underflowing when the axis is empty.
    // Plain subtraction panics in debug builds and wraps to a huge length in
    // release builds. With this guard an empty axis yields 0 and is rejected
    // just below, like any other empty output.
    let output_len = n.unwrap_or_else(|| 2 * half_len.saturating_sub(1));
    if output_len == 0 {
        return Err(FerrayError::invalid_value("hfft output length must be > 0"));
    }
    // NOTE(review): this assumes `a.iter()` yields elements in the logical
    // order that `from_vec` expects for `shape`. Confirm for non-contiguous
    // inputs.
    let conj_data: Vec<Complex<f64>> = a.iter().map(|z| z.conj()).collect();
    let conj_arr = Array::<Complex<f64>, IxDyn>::from_vec(IxDyn::new(&shape), conj_data)?;
    // hfft is an inverse real transform of the conjugate, so the scaling
    // direction requested by the caller is swapped before delegating.
    let hfft_norm = match norm {
        FftNorm::Backward => FftNorm::Forward,
        FftNorm::Forward => FftNorm::Backward,
        FftNorm::Ortho => FftNorm::Ortho,
    };
    crate::real::irfft(&conj_arr, Some(output_len), Some(ax), hfft_norm)
}
/// Inverse of [`hfft`]: transforms a real signal into the conjugated
/// half-spectrum of a Hermitian sequence.
///
/// The construction mirrors NumPy's `numpy.fft.ihfft`:
/// 1. run the forward real FFT ([`crate::real::rfft`]) with the normalization
///    direction swapped (`Backward` <-> `Forward`, `Ortho` unchanged);
/// 2. conjugate every output bin.
///
/// # Parameters
/// - `a`: real input of any dimensionality.
/// - `n`: transform length along `axis`, forwarded unchanged to `rfft`.
/// - `axis`: axis to transform; `None` selects the last axis.
/// - `norm`: normalization mode as requested by the caller.
///
/// # Errors
/// - `axis` is out of bounds, or `a` is 0-dimensional and `axis` is `None`.
/// - Any error propagated from `rfft` or from building the output array.
pub fn ihfft<D: Dimension>(
    a: &Array<f64, D>,
    n: Option<usize>,
    axis: Option<usize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    let axis_index = resolve_axis(a.shape().len(), axis)?;

    // Swap the scaling direction: ihfft is a forward real transform whose
    // result is conjugated.
    let swapped_norm = match norm {
        FftNorm::Forward => FftNorm::Backward,
        FftNorm::Backward => FftNorm::Forward,
        FftNorm::Ortho => FftNorm::Ortho,
    };

    let spectrum = crate::real::rfft(a, n, Some(axis_index), swapped_norm)?;
    let spectrum_shape = spectrum.shape().to_vec();
    let conjugated = spectrum
        .iter()
        .map(|bin| bin.conj())
        .collect::<Vec<Complex<f64>>>();
    Array::from_vec(IxDyn::new(&spectrum_shape), conjugated)
}
/// Turns the optional `axis` argument into a concrete axis index.
///
/// `Some(ax)` is accepted as-is when `ax < ndim`. `None` selects the last
/// axis (`ndim - 1`), which requires at least one dimension.
///
/// # Errors
/// - `axis_out_of_bounds` when `Some(ax)` has `ax >= ndim`.
/// - `invalid_value` when `axis` is `None` and `ndim == 0`.
fn resolve_axis(ndim: usize, axis: Option<usize>) -> FerrayResult<usize> {
    match axis {
        Some(requested) if requested < ndim => Ok(requested),
        Some(requested) => Err(FerrayError::axis_out_of_bounds(requested, ndim)),
        // `checked_sub` is `None` exactly when there is no last axis.
        None => ndim.checked_sub(1).ok_or_else(|| {
            FerrayError::invalid_value("cannot compute FFT on a 0-dimensional array")
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-10;

    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    #[test]
    fn hfft_ihfft_roundtrip() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let n = original.len();
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([n]), original.clone()).unwrap();
        let spectrum = ihfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[n / 2 + 1]);
        let recovered = hfft(&spectrum, Some(n), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[n]);
        let rec_data: Vec<f64> = recovered.iter().copied().collect();
        for (o, r) in original.iter().zip(rec_data.iter()) {
            assert!((o - r).abs() < TOL, "mismatch: expected {}, got {}", o, r);
        }
    }

    #[test]
    fn ihfft_basic() {
        // A unit impulse has a flat spectrum. With `Backward` norm ihfft
        // scales by 1/n, so every one of the n/2 + 1 bins is 1/4 + 0i.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        let result = ihfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[3]);
        for bin in result.iter() {
            assert!(
                (bin.re - 0.25).abs() < TOL && bin.im.abs() < TOL,
                "unexpected bin {:?}",
                bin
            );
        }
    }

    #[test]
    fn hfft_hermitian_input() {
        // conj(input) = [10, -2-2i, -2] is the rfft of [1, 4, 3, 2]. With
        // `Backward` norm hfft omits the 1/n factor, so the output is
        // 4 * [1, 4, 3, 2].
        let input = vec![c(10.0, 0.0), c(-2.0, 2.0), c(-2.0, 0.0)];
        let a = Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([3]), input).unwrap();
        let result = hfft(&a, Some(4), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let expected = [4.0, 16.0, 12.0, 8.0];
        for (e, r) in expected.iter().zip(result.iter()) {
            assert!((e - r).abs() < TOL, "mismatch: expected {}, got {}", e, r);
        }
    }

    #[test]
    fn hfft_default_length_from_single_bin_is_error() {
        // With n = None the output length is 2 * (1 - 1) = 0, which is rejected.
        let a = Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([1]), vec![c(1.0, 0.0)]).unwrap();
        assert!(hfft(&a, None, None, FftNorm::Backward).is_err());
    }

    #[test]
    fn resolve_axis_defaults_and_bounds() {
        assert_eq!(resolve_axis(3, None).unwrap(), 2);
        assert_eq!(resolve_axis(3, Some(0)).unwrap(), 0);
        assert!(resolve_axis(3, Some(3)).is_err());
        assert!(resolve_axis(0, None).is_err());
    }
}