use num_complex::Complex;
use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::error::{FerrayError, FerrayResult};
use crate::nd::fft_along_axis;
use crate::norm::FftNorm;
/// Promote every element of a real array, in the array's iteration order,
/// to a complex number whose imaginary part is zero.
fn real_to_complex_flat<D: Dimension>(a: &Array<f64, D>) -> Vec<Complex<f64>> {
    a.iter()
        .copied()
        .map(|re| Complex::new(re, 0.0))
        .collect()
}
/// Resolve an optional axis argument against an array of rank `ndim`.
///
/// An explicit axis is returned unchanged if it is `< ndim`. `None` selects the
/// last axis (`ndim - 1`).
///
/// # Errors
/// * axis-out-of-bounds when an explicit axis is `>= ndim`;
/// * invalid-value when `axis` is `None` and the array is 0-dimensional.
fn resolve_axis(ndim: usize, axis: Option<usize>) -> FerrayResult<usize> {
    if let Some(ax) = axis {
        return if ax < ndim {
            Ok(ax)
        } else {
            Err(FerrayError::axis_out_of_bounds(ax, ndim))
        };
    }
    // Default to the last axis; a rank-0 array has none.
    ndim.checked_sub(1).ok_or_else(|| {
        FerrayError::invalid_value("cannot compute FFT on a 0-dimensional array")
    })
}
/// Resolve an optional list of axes against an array of rank `ndim`.
///
/// `None` expands to every axis `0..ndim`. An explicit list is returned as-is
/// (order and any repeats preserved) after checking that each entry is `< ndim`.
///
/// # Errors
/// Axis-out-of-bounds for the first entry (in list order) that is `>= ndim`.
fn resolve_axes(ndim: usize, axes: Option<&[usize]>) -> FerrayResult<Vec<usize>> {
    match axes {
        None => Ok((0..ndim).collect()),
        Some(requested) => match requested.iter().find(|&&candidate| candidate >= ndim) {
            Some(&bad) => Err(FerrayError::axis_out_of_bounds(bad, ndim)),
            None => Ok(requested.to_vec()),
        },
    }
}
/// Keep only the first `shape[axis] / 2 + 1` entries along `axis` of a full
/// complex spectrum.
///
/// `data` is a row-major buffer whose layout is described by `shape`. Returns
/// the reduced shape together with the matching row-major buffer.
fn truncate_hermitian(
    data: &[Complex<f64>],
    shape: &[usize],
    axis: usize,
) -> (Vec<usize>, Vec<Complex<f64>>) {
    let mut kept_shape = shape.to_vec();
    kept_shape[axis] = shape[axis] / 2 + 1;

    let src_strides = compute_strides(shape);
    let dst_strides = compute_strides(&kept_shape);
    let kept_len: usize = kept_shape.iter().product();

    let mut kept = Vec::with_capacity(kept_len);
    for dst_flat in 0..kept_len {
        let coords = flat_to_multi(dst_flat, &kept_shape, &dst_strides);
        // The kept region is a prefix along `axis`, so each destination
        // coordinate is also a valid source coordinate.
        let mut src_flat = 0usize;
        for (coord, stride) in coords.into_iter().zip(src_strides.iter().copied()) {
            src_flat += coord * stride as usize;
        }
        kept.push(data[src_flat]);
    }
    (kept_shape, kept)
}
/// Rebuild a full length-`n` spectrum along `axis` from its
/// non-negative-frequency half, filling the upper bins by Hermitian symmetry.
///
/// `data` is a row-major buffer described by `shape`. `shape[axis]` is the
/// number of supplied half-spectrum coefficients. Returns the new shape (with
/// `shape[axis]` replaced by `n`) and the matching row-major buffer.
///
/// As in NumPy's `irfft`, only the first `n / 2 + 1` supplied coefficients
/// are used:
/// * extra coefficients are ignored;
/// * missing coefficients are treated as zero.
///
/// Bin `k` at or beyond that usable range takes the complex conjugate of bin
/// `n - k` when that bin is usable, and zero otherwise.
fn extend_hermitian(
    data: &[Complex<f64>],
    shape: &[usize],
    axis: usize,
    n: usize,
) -> (Vec<usize>, Vec<Complex<f64>>) {
    // Clamp to n/2 + 1. Otherwise, when more coefficients are supplied than a
    // length-n real signal has, bins above n/2 would be copied verbatim
    // instead of mirrored. That breaks Hermitian symmetry and changes the
    // inverse transform's result.
    let usable = shape[axis].min(n / 2 + 1);
    let mut new_shape = shape.to_vec();
    new_shape[axis] = n;
    let strides = compute_strides(shape);
    let new_strides = compute_strides(&new_shape);
    let new_total: usize = new_shape.iter().product();
    // Flat offset into `data` of a multi-index that is valid in `shape`.
    let src_offset = |multi: &[usize]| -> usize {
        multi
            .iter()
            .zip(strides.iter())
            .map(|(&m, &s)| m * s as usize)
            .sum()
    };
    let output: Vec<Complex<f64>> = (0..new_total)
        .map(|flat_idx| {
            let mut multi = flat_to_multi(flat_idx, &new_shape, &new_strides);
            let k = multi[axis];
            if k < usable {
                data[src_offset(&multi)]
            } else {
                // k < n here, so `n - k` cannot underflow.
                let mirror = n - k;
                if mirror < usable {
                    multi[axis] = mirror;
                    data[src_offset(&multi)].conj()
                } else {
                    Complex::new(0.0, 0.0)
                }
            }
        })
        .collect();
    (new_shape, output)
}
/// Row-major (C-order) element strides for `shape`.
///
/// The last axis has stride 1. Each earlier axis's stride is the next axis's
/// stride times the next axis's extent. A rank-0 shape yields an empty vector.
fn compute_strides(shape: &[usize]) -> Vec<isize> {
    let mut strides = vec![0isize; shape.len()];
    if let Some(innermost) = strides.last_mut() {
        *innermost = 1;
    }
    for dim in (1..shape.len()).rev() {
        strides[dim - 1] = strides[dim] * shape[dim] as isize;
    }
    strides
}
/// Convert a row-major flat offset into a multi-index, one coordinate per
/// dimension of `shape`.
///
/// `strides` supplies the divisor for each dimension. A stride of 0 yields
/// coordinate 0 and leaves the remaining offset untouched.
fn flat_to_multi(flat_idx: usize, shape: &[usize], strides: &[isize]) -> Vec<usize> {
    let mut rest = flat_idx;
    (0..shape.len())
        .map(|dim| {
            let stride = strides[dim];
            if stride == 0 {
                return 0;
            }
            let step = stride as usize;
            let coord = rest / step;
            rest %= step;
            coord
        })
        .collect()
}
/// One-dimensional FFT of real input along `axis`, keeping only the
/// non-negative-frequency terms.
///
/// * `n` — transform length along `axis`; defaults to the input's extent there.
/// * `axis` — defaults to the last axis.
/// * `norm` — normalization mode forwarded to the underlying FFT.
///
/// The output has `n / 2 + 1` entries along `axis`; all other axes are unchanged.
///
/// # Errors
/// * an invalid or missing axis (see `resolve_axis`);
/// * a transform length of 0;
/// * any error from the underlying FFT or from array construction.
pub fn rfft<D: Dimension>(
    a: &Array<f64, D>,
    n: Option<usize>,
    axis: Option<usize>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    let in_shape: Vec<usize> = a.shape().to_vec();
    let axis_idx = resolve_axis(in_shape.len(), axis)?;
    let length = match n {
        Some(requested) => requested,
        None => in_shape[axis_idx],
    };
    if length == 0 {
        return Err(FerrayError::invalid_value("FFT length must be > 0"));
    }
    let promoted = real_to_complex_flat(a);
    let (spec_shape, spectrum) =
        fft_along_axis(&promoted, &in_shape, axis_idx, Some(length), false, norm)?;
    // Real input gives a Hermitian spectrum; the upper half is redundant.
    let (half_shape, half_spectrum) = truncate_hermitian(&spectrum, &spec_shape, axis_idx);
    Array::from_vec(IxDyn::new(&half_shape), half_spectrum)
}
/// Inverse of [`rfft`]: reconstruct a real signal along `axis` from its
/// non-negative-frequency half spectrum.
///
/// * `n` — output length along `axis`. Defaults to `2 * (m - 1)`, where `m` is
///   the input's extent along `axis`.
/// * `axis` — defaults to the last axis.
/// * `norm` — normalization mode forwarded to the underlying FFT.
///
/// # Errors
/// * an invalid or missing axis (see `resolve_axis`);
/// * an output length of 0, including the default when `m <= 1`;
/// * any error from the underlying FFT or from array construction.
pub fn irfft<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    n: Option<usize>,
    axis: Option<usize>,
    norm: FftNorm,
) -> FerrayResult<Array<f64, IxDyn>> {
    let shape = a.shape().to_vec();
    let ax = resolve_axis(shape.len(), axis)?;
    // `saturating_sub` stops an empty axis from underflowing (a debug panic,
    // or a huge length in release). It yields 0, which is rejected below.
    let output_len = n.unwrap_or_else(|| 2 * shape[ax].saturating_sub(1));
    if output_len == 0 {
        return Err(FerrayError::invalid_value(
            "irfft output length must be > 0",
        ));
    }
    let complex_data: Vec<Complex<f64>> = a.iter().copied().collect();
    let (extended_shape, extended_data) = extend_hermitian(&complex_data, &shape, ax, output_len);
    let (result_shape, result_data) =
        fft_along_axis(&extended_data, &extended_shape, ax, None, true, norm)?;
    // The Hermitian-extended spectrum inverts to a real signal; keep the real parts.
    let real_data: Vec<f64> = result_data.iter().map(|c| c.re).collect();
    Array::from_vec(IxDyn::new(&result_shape), real_data)
}
/// Two-dimensional real-input FFT.
///
/// `axes` defaults to the last two axes. An explicit `axes` slice is forwarded
/// as given. `s` optionally gives one transform length per axis.
///
/// # Errors
/// * invalid-value when `axes` is `None` and the input has fewer than 2
///   dimensions;
/// * any error reported by the shared n-dimensional implementation.
pub fn rfft2<D: Dimension>(
    a: &Array<f64, D>,
    s: Option<&[usize]>,
    axes: Option<&[usize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    let chosen: Vec<usize> = if let Some(explicit) = axes {
        explicit.to_vec()
    } else {
        let rank = a.shape().len();
        if rank < 2 {
            return Err(FerrayError::invalid_value(
                "rfft2 requires at least 2 dimensions",
            ));
        }
        vec![rank - 2, rank - 1]
    };
    rfftn_impl(a, s, &chosen, norm)
}
/// Inverse of [`rfft2`]: two-dimensional inverse real FFT.
///
/// `axes` defaults to the last two axes. An explicit `axes` slice is forwarded
/// as given. `s` optionally gives one output length per axis.
///
/// # Errors
/// * invalid-value when `axes` is `None` and the input has fewer than 2
///   dimensions;
/// * any error reported by the shared n-dimensional implementation.
pub fn irfft2<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    s: Option<&[usize]>,
    axes: Option<&[usize]>,
    norm: FftNorm,
) -> FerrayResult<Array<f64, IxDyn>> {
    let chosen: Vec<usize> = if let Some(explicit) = axes {
        explicit.to_vec()
    } else {
        let rank = a.shape().len();
        if rank < 2 {
            return Err(FerrayError::invalid_value(
                "irfft2 requires at least 2 dimensions",
            ));
        }
        vec![rank - 2, rank - 1]
    };
    irfftn_impl(a, s, &chosen, norm)
}
/// N-dimensional real-input FFT over `axes`, which defaults to every axis.
///
/// # Errors
/// * an out-of-bounds axis;
/// * any error reported by the shared n-dimensional implementation.
pub fn rfftn<D: Dimension>(
    a: &Array<f64, D>,
    s: Option<&[usize]>,
    axes: Option<&[usize]>,
    norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
    resolve_axes(a.shape().len(), axes).and_then(|resolved| rfftn_impl(a, s, &resolved, norm))
}
/// Inverse of [`rfftn`]: N-dimensional inverse real FFT over `axes`, which
/// defaults to every axis.
///
/// # Errors
/// * an out-of-bounds axis;
/// * any error reported by the shared n-dimensional implementation.
pub fn irfftn<D: Dimension>(
    a: &Array<Complex<f64>, D>,
    s: Option<&[usize]>,
    axes: Option<&[usize]>,
    norm: FftNorm,
) -> FerrayResult<Array<f64, IxDyn>> {
    resolve_axes(a.shape().len(), axes).and_then(|resolved| irfftn_impl(a, s, &resolved, norm))
}
fn rfftn_impl<D: Dimension>(
a: &Array<f64, D>,
s: Option<&[usize]>,
axes: &[usize],
norm: FftNorm,
) -> FerrayResult<Array<Complex<f64>, IxDyn>> {
if axes.is_empty() {
let data: Vec<Complex<f64>> = a.iter().map(|&v| Complex::new(v, 0.0)).collect();
return Array::from_vec(IxDyn::new(a.shape()), data);
}
let input_shape = a.shape().to_vec();
let sizes: Vec<Option<usize>> = match s {
Some(sizes) => {
if sizes.len() != axes.len() {
return Err(FerrayError::invalid_value(format!(
"shape parameter length {} does not match axes length {}",
sizes.len(),
axes.len(),
)));
}
sizes.iter().map(|&sz| Some(sz)).collect()
}
None => axes.iter().map(|&ax| Some(input_shape[ax])).collect(),
};
let complex_data = real_to_complex_flat(a);
let mut current_data = complex_data;
let mut current_shape = input_shape;
for (i, &ax) in axes.iter().enumerate() {
let n = sizes[i];
if i < axes.len() - 1 {
let (new_shape, new_data) =
fft_along_axis(¤t_data, ¤t_shape, ax, n, false, norm)?;
current_shape = new_shape;
current_data = new_data;
} else {
let fft_len = n.unwrap_or(current_shape[ax]);
let (full_shape, full_data) = fft_along_axis(
¤t_data,
¤t_shape,
ax,
Some(fft_len),
false,
norm,
)?;
let (out_shape, out_data) = truncate_hermitian(&full_data, &full_shape, ax);
current_shape = out_shape;
current_data = out_data;
}
}
Array::from_vec(IxDyn::new(¤t_shape), current_data)
}
fn irfftn_impl<D: Dimension>(
a: &Array<Complex<f64>, D>,
s: Option<&[usize]>,
axes: &[usize],
norm: FftNorm,
) -> FerrayResult<Array<f64, IxDyn>> {
if axes.is_empty() {
let data: Vec<f64> = a.iter().map(|c| c.re).collect();
return Array::from_vec(IxDyn::new(a.shape()), data);
}
let input_shape = a.shape().to_vec();
let sizes: Vec<Option<usize>> = match s {
Some(sizes) => {
if sizes.len() != axes.len() {
return Err(FerrayError::invalid_value(format!(
"shape parameter length {} does not match axes length {}",
sizes.len(),
axes.len(),
)));
}
sizes.iter().map(|&sz| Some(sz)).collect()
}
None => {
let mut result = Vec::with_capacity(axes.len());
for (i, &ax) in axes.iter().enumerate() {
if i < axes.len() - 1 {
result.push(Some(input_shape[ax]));
} else {
result.push(Some(2 * (input_shape[ax] - 1)));
}
}
result
}
};
let complex_data: Vec<Complex<f64>> = a.iter().copied().collect();
let mut current_data = complex_data;
let mut current_shape = input_shape;
for (i, &ax) in axes.iter().enumerate() {
let n = sizes[i];
if i < axes.len() - 1 {
let (new_shape, new_data) =
fft_along_axis(¤t_data, ¤t_shape, ax, n, true, norm)?;
current_shape = new_shape;
current_data = new_data;
} else {
let output_len = n.unwrap_or(2 * (current_shape[ax] - 1));
let (ext_shape, ext_data) =
extend_hermitian(¤t_data, ¤t_shape, ax, output_len);
let (result_shape, result_data) =
fft_along_axis(&ext_data, &ext_shape, ax, None, true, norm)?;
current_shape = result_shape;
current_data = result_data;
}
}
let real_data: Vec<f64> = current_data.iter().map(|c| c.re).collect();
Array::from_vec(IxDyn::new(¤t_shape), real_data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};

    /// Wrap `values` in a 1-D real array of matching length.
    fn make_real_1d(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Assert `|expected - actual| < 1e-10` element-wise over the shared prefix.
    fn assert_close(expected: &[f64], actual: &Array<f64, IxDyn>) {
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert!((want - got).abs() < 1e-10, "{} vs {}", want, got);
        }
    }

    #[test]
    fn rfft_basic() {
        let input = make_real_1d(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let spectrum = rfft(&input, None, None, FftNorm::Backward).unwrap();
        // 8 samples -> 8 / 2 + 1 = 5 non-negative-frequency bins.
        assert_eq!(spectrum.shape(), &[5]);
    }

    #[test]
    fn rfft_impulse() {
        let input = make_real_1d(vec![1.0, 0.0, 0.0, 0.0]);
        let spectrum = rfft(&input, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[3]);
        // A unit impulse has a flat spectrum: every bin is 1 + 0i.
        assert!(spectrum
            .iter()
            .all(|bin| (bin.re - 1.0).abs() < 1e-12 && bin.im.abs() < 1e-12));
    }

    #[test]
    fn rfft_irfft_roundtrip() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let spectrum =
            rfft(&make_real_1d(original.clone()), None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[5]);
        let recovered = irfft(&spectrum, Some(8), None, FftNorm::Backward).unwrap();
        assert_eq!(recovered.shape(), &[8]);
        assert_close(&original, &recovered);
    }

    #[test]
    fn rfft_with_n() {
        let input = make_real_1d(vec![1.0, 2.0]);
        // Zero-padded to 8 samples -> 5 bins.
        let spectrum = rfft(&input, Some(8), None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[5]);
    }

    #[test]
    fn rfft2_basic() {
        let input = Array::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let spectrum = rfft2(&input, None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[2, 2]);
    }

    #[test]
    fn rfft_irfft_roundtrip_odd() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let spectrum =
            rfft(&make_real_1d(original.clone()), None, None, FftNorm::Backward).unwrap();
        assert_eq!(spectrum.shape(), &[3]);
        let recovered = irfft(&spectrum, Some(5), None, FftNorm::Backward).unwrap();
        assert_close(&original, &recovered);
    }
}