use num_complex::Complex;
use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::axes::resolve_axis;
use crate::float::FftFloat;
use crate::norm::FftNorm;
/// Compute the FFT of a signal that has Hermitian symmetry, producing a real result.
///
/// `a` holds the non-negative-frequency half of a Hermitian-symmetric signal
/// along `axis`. Mirrors NumPy's `numpy.fft.hfft`: the input is conjugated and
/// handed to [`crate::real::irfft`] with the normalization direction swapped
/// (`Backward` <-> `Forward`, `Ortho` unchanged).
///
/// # Arguments
/// * `a` - Complex input array.
/// * `n` - Output length along `axis`. Defaults to `2 * (m - 1)`, where `m` is
///   the length of `a` along `axis`.
/// * `axis` - Axis to transform, resolved via `resolve_axis`
///   (NOTE(review): the meaning of `None` is defined by `resolve_axis` — confirm).
/// * `norm` - Normalization mode, interpreted from the point of view of `hfft`.
///
/// # Errors
/// Returns an error if `axis` cannot be resolved, if the output length is zero
/// (including when `a` is empty or has length 1 along `axis` and `n` is `None`),
/// or if building the intermediate array or the `irfft` call fails.
pub fn hfft<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    n: Option<usize>,
    axis: Option<isize>,
    norm: FftNorm,
) -> FerrayResult<Array<T, IxDyn>>
where
    Complex<T>: Element,
{
    let shape = a.shape().to_vec();
    let ax = resolve_axis(shape.len(), axis)?;
    let half_len = shape[ax];
    // `saturating_sub` prevents a usize underflow when the input is empty along
    // `axis` (a debug panic, or a huge wrapped length in release); such input
    // now yields 0 here and is rejected by the check below.
    let output_len = n.unwrap_or_else(|| 2 * half_len.saturating_sub(1));
    if output_len == 0 {
        return Err(FerrayError::invalid_value("hfft output length must be > 0"));
    }
    // hfft(a) == irfft(conj(a)) with the scaling direction swapped.
    let conj_data: Vec<Complex<T>> = a.iter().map(|c| c.conj()).collect();
    let conj_arr = Array::<Complex<T>, IxDyn>::from_vec(IxDyn::new(&shape), conj_data)?;
    let hfft_norm = match norm {
        FftNorm::Backward => FftNorm::Forward,
        FftNorm::Forward => FftNorm::Backward,
        FftNorm::Ortho => FftNorm::Ortho,
    };
    crate::real::irfft::<T, IxDyn>(&conj_arr, Some(output_len), Some(ax as isize), hfft_norm)
}
/// Inverse of [`hfft`]: take a real signal and return its Hermitian half-spectrum.
///
/// Mirrors NumPy's `numpy.fft.ihfft`. It calls [`crate::real::rfft`] with the
/// normalization direction swapped (`Backward` <-> `Forward`, `Ortho`
/// unchanged) and then conjugates every element of that result.
///
/// # Arguments
/// * `a` - Real input array.
/// * `n` - Optional transform length along `axis`, forwarded unchanged to `rfft`.
/// * `axis` - Axis to transform, resolved via `resolve_axis`.
/// * `norm` - Normalization mode, interpreted from the point of view of `ihfft`.
///
/// # Errors
/// Propagates failures from axis resolution, from `rfft`, and from building
/// the output array.
pub fn ihfft<T: FftFloat, D: Dimension>(
    a: &Array<T, D>,
    n: Option<usize>,
    axis: Option<isize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: Element,
{
    let axis_idx = resolve_axis(a.shape().len(), axis)?;

    // The scaling convention is the mirror image of the underlying rfft.
    let swapped_norm = match norm {
        FftNorm::Ortho => FftNorm::Ortho,
        FftNorm::Forward => FftNorm::Backward,
        FftNorm::Backward => FftNorm::Forward,
    };

    let spectrum = crate::real::rfft::<T, D>(a, n, Some(axis_idx as isize), swapped_norm)?;
    let dims = spectrum.shape().to_vec();
    let conjugated = spectrum
        .iter()
        .map(|z| z.conj())
        .collect::<Vec<Complex<T>>>();
    Array::from_vec(IxDyn::new(&dims), conjugated)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Absolute tolerance for all floating-point comparisons in this module.
    const TOL: f64 = 1e-10;

    /// Shorthand constructor for a `Complex<f64>`.
    fn c(re: f64, im: f64) -> Complex<f64> {
        Complex::new(re, im)
    }

    #[test]
    fn hfft_ihfft_roundtrip() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let n = original.len();
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([n]), original.clone()).unwrap();
        let spectrum = ihfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[n / 2 + 1]);
        let recovered = hfft(&spectrum, Some(n), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[n]);
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < TOL, "mismatch: expected {}, got {}", o, r);
        }
    }

    #[test]
    fn ihfft_basic() {
        // An impulse has a flat spectrum; with Backward norm ihfft scales by 1/n.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        let result = ihfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[3]);
        for z in result.iter() {
            assert!(
                (z.re - 0.25).abs() < TOL && z.im.abs() < TOL,
                "unexpected bin {:?}, expected 0.25+0i",
                z
            );
        }
    }

    #[test]
    fn hfft_hermitian_input() {
        // The half-spectrum [10, -2+2i, -2] extends to the Hermitian signal
        // [10, -2+2i, -2, -2-2i], whose FFT is [4, 16, 12, 8].
        let input = vec![c(10.0, 0.0), c(-2.0, 2.0), c(-2.0, 0.0)];
        let a = Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([3]), input).unwrap();
        let result = hfft(&a, Some(4), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[4]);
        let expected = [4.0, 16.0, 12.0, 8.0];
        for (e, r) in expected.iter().zip(result.iter()) {
            assert!((e - r).abs() < TOL, "mismatch: expected {}, got {}", e, r);
        }
    }
}