use num_complex::Complex;
use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::axes::{resolve_axes, resolve_axis};
use crate::float::FftFloat;
use crate::nd::{fft_along_axis, irfft_along_axis, rfft_along_axis};
use crate::norm::FftNorm;
/// One-dimensional FFT of real-valued input along a single axis.
///
/// Only the non-negative frequency bins are returned: a length-`n` transform
/// yields `n / 2 + 1` complex values along the transformed axis (e.g. 8 real
/// samples -> 5 bins), while every other axis keeps its length.
///
/// * `a` — real input array of any dimensionality.
/// * `n` — transform length along the axis. With `None` the input length is
///   used; other values are padded/truncated by `rfft_along_axis`.
/// * `axis` — axis to transform, resolved by `resolve_axis` (presumably
///   `None` selects the last axis — confirm in `crate::axes`).
/// * `norm` — normalization mode applied to the forward transform.
///
/// # Errors
/// Propagates errors from axis resolution, from the transform itself, and
/// from building the output array.
pub fn rfft<T: FftFloat, D: Dimension>(
    a: &Array<T, D>,
    n: Option<usize>,
    axis: Option<isize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: Element,
{
    let dims = a.shape().to_vec();
    let axis_idx = resolve_axis(dims.len(), axis)?;
    // Flatten to a contiguous buffer in the array's iteration order; the
    // per-axis kernel works on flat data plus an explicit shape.
    let samples = a.iter().copied().collect::<Vec<T>>();
    let (spectrum_shape, spectrum) =
        rfft_along_axis::<T>(&samples, &dims, axis_idx, n, norm)?;
    Array::from_vec(IxDyn::new(&spectrum_shape), spectrum)
}
/// Inverse of [`rfft`]: turns a one-sided (non-negative frequency) spectrum
/// along one axis back into real samples.
///
/// * `a` — complex half-spectrum, e.g. as produced by [`rfft`].
/// * `n` — length of the real output along the transformed axis. Defaults to
///   `2 * (m - 1)`, where `m` is the input length along that axis. This
///   assumes an even-length original signal, so pass `Some(len)` to recover
///   odd lengths.
/// * `axis` — axis to transform, resolved by `resolve_axis` (presumably
///   `None` selects the last axis — confirm in `crate::axes`).
/// * `norm` — normalization mode applied to the inverse transform.
///
/// # Errors
/// Returns an invalid-value error when the output length would be zero. That
/// includes `n = Some(0)`, and `n = None` with an input axis of length 0 or 1.
/// Also propagates errors from axis resolution, from the transform, and from
/// building the output array.
pub fn irfft<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    n: Option<usize>,
    axis: Option<isize>,
    norm: FftNorm,
) -> FerrayResult<Array<T, IxDyn>>
where
    Complex<T>: Element,
{
    let shape = a.shape().to_vec();
    let ndim = shape.len();
    let ax = resolve_axis(ndim, axis)?;
    let half_len = shape[ax];
    // `saturating_sub` keeps an empty input axis from underflowing. Plain
    // subtraction panics in debug builds and, in release builds, wraps to an
    // enormous length. The resulting 0 is rejected just below.
    let output_len = n.unwrap_or_else(|| 2 * half_len.saturating_sub(1));
    if output_len == 0 {
        return Err(FerrayError::invalid_value(
            "irfft output length must be > 0",
        ));
    }
    let complex_data: Vec<Complex<T>> = a.iter().copied().collect();
    let (out_shape, out_data) =
        irfft_along_axis::<T>(&complex_data, &shape, ax, output_len, norm)?;
    Array::from_vec(IxDyn::new(&out_shape), out_data)
}
/// Two-dimensional FFT of real-valued input.
///
/// When `axes` is `None`, the last two axes are transformed, and the array
/// must then have at least two dimensions. Otherwise the given axes are
/// resolved with `resolve_axes`. The real transform is applied along the
/// last of the chosen axes, which is halved to `n / 2 + 1` bins, and a full
/// complex FFT is applied along the other.
///
/// * `s` — optional output length for each transformed axis, one per axis.
/// * `norm` — normalization mode.
///
/// # Errors
/// Returns an invalid-value error when `axes` is `None` and the input has
/// fewer than two dimensions, or when `s` and the axes differ in length.
/// Transform and construction errors are propagated.
pub fn rfft2<T: FftFloat, D: Dimension>(
    a: &Array<T, D>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: Element,
{
    let ndim = a.shape().len();
    let resolved: Vec<usize> = match axes {
        Some(_) => resolve_axes(ndim, axes)?,
        None if ndim >= 2 => (ndim - 2..ndim).collect(),
        None => {
            return Err(FerrayError::invalid_value(
                "rfft2 requires at least 2 dimensions",
            ));
        }
    };
    rfftn_impl::<T, D>(a, s, &resolved, norm)
}
/// Inverse of [`rfft2`]: two-dimensional inverse FFT producing real output.
///
/// When `axes` is `None`, the last two axes are transformed, and the array
/// must then have at least two dimensions. Otherwise the given axes are
/// resolved with `resolve_axes`. An inverse complex FFT runs along the first
/// chosen axis, and the real inverse runs along the last.
///
/// * `s` — optional output length for each transformed axis. Without it, the
///   leading axis keeps its length and the last axis becomes `2 * (m - 1)`.
/// * `norm` — normalization mode.
///
/// # Errors
/// Returns an invalid-value error when `axes` is `None` and the input has
/// fewer than two dimensions, or when `s` and the axes differ in length.
/// Transform and construction errors are propagated.
pub fn irfft2<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrayResult<Array<T, IxDyn>>
where
    Complex<T>: Element,
{
    let rank = a.shape().len();
    let chosen_axes: Vec<usize> = if let Some(requested) = axes {
        resolve_axes(rank, Some(requested))?
    } else if rank >= 2 {
        vec![rank - 2, rank - 1]
    } else {
        return Err(FerrayError::invalid_value(
            "irfft2 requires at least 2 dimensions",
        ));
    };
    irfftn_impl::<T, D>(a, s, &chosen_axes, norm)
}
/// N-dimensional FFT of real-valued input over the axes selected by `axes`.
///
/// `axes` is resolved with `resolve_axes`, which decides the default when it
/// is `None`. The last resolved axis receives the real transform and is
/// halved to `n / 2 + 1` bins; the remaining axes receive full complex FFTs.
///
/// * `s` — optional output length for each resolved axis, one per axis.
/// * `norm` — normalization mode.
///
/// # Errors
/// Propagates axis-resolution errors. Returns an invalid-value error when `s`
/// and the resolved axes differ in length, and propagates transform errors.
pub fn rfftn<T: FftFloat, D: Dimension>(
    a: &Array<T, D>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: Element,
{
    let ndim = a.shape().len();
    resolve_axes(ndim, axes).and_then(|resolved| rfftn_impl::<T, D>(a, s, &resolved, norm))
}
/// Inverse of [`rfftn`]: N-dimensional inverse FFT producing real output.
///
/// `axes` is resolved with `resolve_axes`. Every resolved axis except the
/// last receives an inverse complex FFT, and the last receives the real
/// inverse transform.
///
/// * `s` — optional output length for each resolved axis. Without it, the
///   leading axes keep their lengths and the last axis becomes `2 * (m - 1)`.
/// * `norm` — normalization mode.
///
/// # Errors
/// Propagates axis-resolution errors. Returns an invalid-value error when `s`
/// and the resolved axes differ in length, and propagates transform errors.
pub fn irfftn<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrayResult<Array<T, IxDyn>>
where
    Complex<T>: Element,
{
    let rank = a.shape().len();
    resolve_axes(rank, axes).and_then(|chosen| irfftn_impl::<T, D>(a, s, &chosen, norm))
}
/// Shared implementation of [`rfft2`] and [`rfftn`] over already-resolved axes.
///
/// The real-input transform runs first, along the **last** entry of `axes`,
/// which shrinks that axis to `n / 2 + 1` bins. Each remaining axis then gets
/// a forward complex FFT, in the order listed.
///
/// `s`, when present, gives one output length per entry of `axes`; otherwise
/// every axis keeps its input length. With an empty `axes`, nothing is
/// transformed and the input is returned as complex values with zero
/// imaginary parts.
///
/// # Errors
/// Returns an invalid-value error if `s` and `axes` have different lengths.
/// Errors from the per-axis transforms and from array construction are
/// propagated.
fn rfftn_impl<T: FftFloat, D: Dimension>(
    a: &Array<T, D>,
    s: Option<&[usize]>,
    axes: &[usize],
    norm: FftNorm,
) -> FerrayResult<Array<Complex<T>, IxDyn>>
where
    Complex<T>: Element,
{
    if axes.is_empty() {
        // Nothing to transform: just promote the real input to complex.
        let data: Vec<Complex<T>> = a
            .iter()
            .map(|&v| Complex::new(v, <T as num_traits::Zero>::zero()))
            .collect();
        return Array::from_vec(IxDyn::new(a.shape()), data);
    }
    let input_shape = a.shape().to_vec();
    let sizes: Vec<usize> = match s {
        Some(sizes) if sizes.len() != axes.len() => {
            return Err(FerrayError::invalid_value(format!(
                "shape parameter length {} does not match axes length {}",
                sizes.len(),
                axes.len(),
            )));
        }
        Some(sizes) => sizes.to_vec(),
        None => axes.iter().map(|&ax| input_shape[ax]).collect(),
    };
    let last_idx = axes.len() - 1;
    let real_data: Vec<T> = a.iter().copied().collect();
    // Real transform first, on the last listed axis (matches NumPy's rfftn order).
    let (mut current_shape, mut current_data) = rfft_along_axis::<T>(
        &real_data,
        &input_shape,
        axes[last_idx],
        Some(sizes[last_idx]),
        norm,
    )?;
    // Then forward complex FFTs (`inverse = false`) over the leading axes.
    for (&ax, &n) in axes[..last_idx].iter().zip(&sizes[..last_idx]) {
        let (new_shape, new_data) =
            fft_along_axis::<T>(&current_data, &current_shape, ax, Some(n), false, norm)?;
        current_shape = new_shape;
        current_data = new_data;
    }
    Array::from_vec(IxDyn::new(&current_shape), current_data)
}
/// Shared implementation of [`irfft2`] and [`irfftn`] over already-resolved axes.
///
/// Every entry of `axes` except the last gets an inverse complex FFT, in the
/// order listed. The last entry then gets the real inverse transform, which
/// produces real output.
///
/// `s`, when present, gives one output length per entry of `axes`. Without
/// it, the leading axes keep their input lengths and the last axis becomes
/// `2 * (m - 1)`, where `m` is its input length. With an empty `axes`,
/// nothing is transformed and only the real parts of the input are returned.
///
/// # Errors
/// Returns an invalid-value error in two cases:
/// * `s` and `axes` have different lengths;
/// * the output length along the last axis is zero. That covers an explicit
///   0, and an input length of 0 or 1 when `s` is `None`.
///
/// Errors from the per-axis transforms and from array construction are
/// propagated.
fn irfftn_impl<T: FftFloat, D: Dimension>(
    a: &Array<Complex<T>, D>,
    s: Option<&[usize]>,
    axes: &[usize],
    norm: FftNorm,
) -> FerrayResult<Array<T, IxDyn>>
where
    Complex<T>: Element,
{
    if axes.is_empty() {
        let data: Vec<T> = a.iter().map(|c| c.re).collect();
        return Array::from_vec(IxDyn::new(a.shape()), data);
    }
    let input_shape = a.shape().to_vec();
    let last_idx = axes.len() - 1;
    let sizes: Vec<usize> = match s {
        Some(sizes) if sizes.len() != axes.len() => {
            return Err(FerrayError::invalid_value(format!(
                "shape parameter length {} does not match axes length {}",
                sizes.len(),
                axes.len(),
            )));
        }
        Some(sizes) => sizes.to_vec(),
        None => axes
            .iter()
            .enumerate()
            .map(|(i, &ax)| {
                if i < last_idx {
                    input_shape[ax]
                } else {
                    // `saturating_sub` avoids underflow on an empty last axis;
                    // the resulting 0 is rejected below.
                    2 * input_shape[ax].saturating_sub(1)
                }
            })
            .collect(),
    };
    let last_ax = axes[last_idx];
    let output_len = sizes[last_idx];
    // Validate before doing any transform work, mirroring the check in `irfft`.
    if output_len == 0 {
        return Err(FerrayError::invalid_value(
            "irfftn output length along the last axis must be > 0",
        ));
    }
    let mut current_data: Vec<Complex<T>> = a.iter().copied().collect();
    let mut current_shape = input_shape;
    // Inverse complex FFTs (`inverse = true`) over the leading axes first.
    for (&ax, &n) in axes[..last_idx].iter().zip(&sizes[..last_idx]) {
        let (new_shape, new_data) =
            fft_along_axis::<T>(&current_data, &current_shape, ax, Some(n), true, norm)?;
        current_shape = new_shape;
        current_data = new_data;
    }
    let (final_shape, final_data) =
        irfft_along_axis::<T>(&current_data, &current_shape, last_ax, output_len, norm)?;
    Array::from_vec(IxDyn::new(&final_shape), final_data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Builds a 1-D `f64` array from `data`.
    fn make_real_1d(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    #[test]
    fn rfft_basic() {
        let a = make_real_1d(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let result = rfft(&a, None, None, FftNorm::Backward).unwrap();
        // 8 real samples -> 8 / 2 + 1 = 5 one-sided bins.
        assert_eq!(result.shape(), &[5]);
        // The DC bin of an unnormalized forward transform is the plain sum.
        let dc = result.iter().next().copied().unwrap();
        assert!((dc.re - 36.0).abs() < 1e-12);
        assert!(dc.im.abs() < 1e-12);
    }

    #[test]
    fn rfft_impulse() {
        let a = make_real_1d(vec![1.0, 0.0, 0.0, 0.0]);
        let result = rfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[3]);
        for val in result.iter() {
            assert!((val.re - 1.0).abs() < 1e-12);
            assert!(val.im.abs() < 1e-12);
        }
    }

    #[test]
    fn rfft_irfft_roundtrip() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let a = make_real_1d(original.clone());
        let spectrum = rfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[5]);
        let recovered = irfft(&spectrum, Some(8), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[8]);
        let rec_data: Vec<f64> = recovered.iter().copied().collect();
        for (o, r) in original.iter().zip(rec_data.iter()) {
            assert!((o - r).abs() < 1e-10, "{} vs {}", o, r);
        }
    }

    #[test]
    fn rfft_with_n() {
        let a = make_real_1d(vec![1.0, 2.0]);
        // Zero-padded to 8 samples -> 5 bins.
        let result = rfft(&a, Some(8), None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[5]);
        let dc = result.iter().next().copied().unwrap();
        assert!((dc.re - 3.0).abs() < 1e-12);
    }

    #[test]
    fn rfft2_basic() {
        use ferray_core::dimension::Ix2;
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let a = Array::from_vec(Ix2::new([2, 2]), data).unwrap();
        let result = rfft2(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        // Hand-computed 2-D DFT of [[1, 2], [3, 4]]: [[10, -2], [-4, 0]].
        let expected = [10.0, -2.0, -4.0, 0.0];
        for (bin, e) in result.iter().zip(expected.iter()) {
            assert!((bin.re - e).abs() < 1e-12, "{:?} vs {}", bin, e);
            assert!(bin.im.abs() < 1e-12);
        }
    }

    #[test]
    fn rfft_irfft_roundtrip_odd() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let a = make_real_1d(original.clone());
        let spectrum = rfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[3]);
        let recovered = irfft(&spectrum, Some(5), None, FftNorm::Backward).unwrap();
        // Without this, a short result would make the zip below pass vacuously.
        assert_eq!(recovered.shape(), &[5]);
        let rec_data: Vec<f64> = recovered.iter().copied().collect();
        for (o, r) in original.iter().zip(rec_data.iter()) {
            assert!((o - r).abs() < 1e-10, "{} vs {}", o, r);
        }
    }

    #[test]
    fn rfft_along_axis0_2d() {
        use ferray_core::dimension::Ix2;
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 4]), data).unwrap();
        let result = rfft(&a, None, Some(0), FftNorm::Backward).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        // Length-2 transform down each column: row 0 = x0 + x1, row 1 = x0 - x1.
        let expected = [6.0, 8.0, 10.0, 12.0, -4.0, -4.0, -4.0, -4.0];
        for (bin, e) in result.iter().zip(expected.iter()) {
            assert!((bin.re - e).abs() < 1e-12, "{:?} vs {}", bin, e);
            assert!(bin.im.abs() < 1e-12);
        }
    }

    #[test]
    fn rfft_irfft_roundtrip_axis0() {
        use ferray_core::dimension::Ix2;
        let data: Vec<f64> = (0..12).map(|i| i as f64).collect();
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();
        let spectrum = rfft(&a, None, Some(0), FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[2, 4]);
        let recovered = irfft(&spectrum, Some(3), Some(0), FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[3, 4]);
        let rec_data: Vec<f64> = recovered.iter().copied().collect();
        for (o, r) in data.iter().zip(rec_data.iter()) {
            assert!((o - r).abs() < 1e-9, "axis0 roundtrip: {} vs {}", o, r);
        }
    }

    #[test]
    fn rfft_irfft_roundtrip_axis1() {
        use ferray_core::dimension::Ix2;
        let data: Vec<f64> = (0..12).map(|i| i as f64).collect();
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();
        let spectrum = rfft(&a, None, Some(1), FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape()[0], 3);
        assert_eq!(spectrum.shape()[1], 3);
        let recovered = irfft(&spectrum, Some(4), Some(1), FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[3, 4]);
        let rec_data: Vec<f64> = recovered.iter().copied().collect();
        for (o, r) in data.iter().zip(rec_data.iter()) {
            assert!((o - r).abs() < 1e-9, "axis1 roundtrip: {} vs {}", o, r);
        }
    }

    #[test]
    fn rfft_single_cosine_matches_analytical() {
        let n = 16;
        let k = 3;
        let data: Vec<f64> = (0..n)
            .map(|i| (2.0 * std::f64::consts::PI * k as f64 * i as f64 / n as f64).cos())
            .collect();
        let a = make_real_1d(data);
        let spectrum = rfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[n / 2 + 1]);
        let bins: Vec<Complex<f64>> = spectrum.iter().copied().collect();
        for (i, bin) in bins.iter().enumerate() {
            if i == k {
                // A unit cosine at bin k has amplitude n / 2 in the one-sided spectrum.
                assert!(
                    (bin.re - (n as f64 / 2.0)).abs() < 1e-10,
                    "bin {} real part = {}, expected {}",
                    i,
                    bin.re,
                    n as f64 / 2.0
                );
                assert!(bin.im.abs() < 1e-10);
            } else {
                assert!(bin.norm() < 1e-10, "bin {} should be ~0, got {:?}", i, bin);
            }
        }
    }

    #[test]
    fn rfft_parseval_holds_for_multi_lane() {
        use ferray_core::dimension::Ix2;
        let rows = 4usize;
        let cols = 16usize;
        let half_len = cols / 2 + 1;
        let data: Vec<f64> = (0..rows * cols)
            .map(|i| ((i as f64).sin() + (i as f64 * 0.3).cos()) * 2.0)
            .collect();
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data.clone()).unwrap();
        let spectrum = rfft(&a, None, Some(1), FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[rows, half_len]);
        // Collect once; calling `iter().nth(..)` per bin re-walks the array each time.
        let bins: Vec<Complex<f64>> = spectrum.iter().copied().collect();
        for (row, (lane, spec_lane)) in data
            .chunks_exact(cols)
            .zip(bins.chunks_exact(half_len))
            .enumerate()
        {
            let time_energy: f64 = lane.iter().map(|&v| v * v).sum();
            // In the one-sided spectrum, every bin except DC (and Nyquist when
            // `cols` is even) stands for a conjugate pair, so it counts twice.
            let freq_energy: f64 = spec_lane
                .iter()
                .enumerate()
                .map(|(k, bin)| {
                    let mag_sq = bin.norm_sqr();
                    if k == 0 || (cols % 2 == 0 && k == cols / 2) {
                        mag_sq
                    } else {
                        2.0 * mag_sq
                    }
                })
                .sum();
            let expected = cols as f64 * time_energy;
            assert!(
                ((freq_energy - expected) / expected).abs() < 1e-9,
                "lane {}: freq energy {} vs expected {} (time energy = {})",
                row,
                freq_energy,
                expected,
                time_energy
            );
        }
    }

    #[test]
    fn rfft_odd_length_impulse() {
        let a = make_real_1d(vec![1.0, 0.0, 0.0, 0.0, 0.0]);
        let spectrum = rfft(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[3]);
        for bin in spectrum.iter() {
            assert!((bin.re - 1.0).abs() < 1e-12);
            assert!(bin.im.abs() < 1e-12);
        }
    }

    #[test]
    fn rfft_multi_lane_with_zero_padding() {
        use ferray_core::dimension::Ix2;
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 2]), data).unwrap();
        let spectrum = rfft(&a, Some(8), Some(1), FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[3, 5]);
        let bins: Vec<Complex<f64>> = spectrum.iter().copied().collect();
        // The DC bin of each zero-padded lane is the sum of that lane's samples.
        assert!((bins[0].re - (1.0 + 2.0)).abs() < 1e-12);
        assert!((bins[5].re - (3.0 + 4.0)).abs() < 1e-12);
        assert!((bins[10].re - (5.0 + 6.0)).abs() < 1e-12);
    }

    #[test]
    fn irfft_multi_lane_axis0_roundtrip() {
        use ferray_core::dimension::Ix2;
        let rows = 6usize;
        let cols = 4usize;
        let data: Vec<f64> = (0..rows * cols).map(|i| (i as f64).sqrt()).collect();
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data.clone()).unwrap();
        let spectrum = rfft(&a, None, Some(0), FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[4, 4]);
        let recovered = irfft(&spectrum, Some(rows), Some(0), FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[rows, cols]);
        let rec_data: Vec<f64> = recovered.iter().copied().collect();
        for (i, (o, r)) in data.iter().zip(rec_data.iter()).enumerate() {
            assert!(
                (o - r).abs() < 1e-10,
                "index {}: expected {}, got {}",
                i,
                o,
                r
            );
        }
    }

    #[test]
    fn rfft_irfft_f32_roundtrip() {
        let original: Vec<f32> = (0..16).map(|i| (i as f32 * 0.1).cos()).collect();
        let a = Array::<f32, Ix1>::from_vec(Ix1::new([16]), original.clone()).unwrap();
        let spectrum = rfft::<f32, Ix1>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[9]);
        let recovered =
            irfft::<f32, IxDyn>(&spectrum, Some(16), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[16]);
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!(
                (o - r).abs() < 1e-5,
                "f32 rfft/irfft mismatch: {} vs {}",
                o,
                r
            );
        }
    }

    #[test]
    fn rfft2_f32_roundtrip() {
        use ferray_core::dimension::Ix2;
        let data: Vec<f32> = (0..16).map(|i| (i as f32) * 0.25).collect();
        let a = Array::<f32, Ix2>::from_vec(Ix2::new([4, 4]), data.clone()).unwrap();
        let spectrum = rfft2::<f32, Ix2>(&a, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[4, 3]);
        let recovered =
            irfft2::<f32, IxDyn>(&spectrum, Some(&[4, 4]), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[4, 4]);
        for (o, r) in data.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 1e-5);
        }
    }
}