use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::ndarray::{s, Array, Array2, ArrayView, ArrayView2, IxDyn};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::{NumCast, Zero};
use std::f64::consts::PI;
use std::fmt::Debug;
/// Computes the one-dimensional FFT of real-valued input, returning only the
/// non-negative frequency bins.
///
/// # Arguments
/// * `x` - Input samples; any element type accepted by [`fft`].
/// * `n` - FFT length. Defaults to `x.len()`. The value is forwarded to [`fft`],
///   which decides how the input is padded or truncated to that length.
///
/// # Returns
/// The first `n / 2 + 1` bins of the full FFT.
///
/// # Errors
/// Propagates any error returned by [`fft`].
#[allow(dead_code)]
pub fn rfft<T>(x: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let fft_len = n.unwrap_or(x.len());
    let spectrum = fft(x, Some(fft_len))?;
    // For real input the bins above n/2 are conjugate mirrors of the lower ones,
    // so only the non-redundant half is returned.
    let kept_bins = fft_len / 2 + 1;
    Ok(spectrum.iter().take(kept_bins).copied().collect())
}
/// Computes the inverse of [`rfft`]: reconstructs a real signal of length `n` from
/// its non-negative frequency bins.
///
/// The input is treated as the first `n / 2 + 1` bins of a Hermitian spectrum. It is
/// truncated if longer and zero-padded if shorter (the same convention as NumPy's
/// `irfft`). Bin `k > n / 2` of the full spectrum is taken as the complex conjugate
/// of bin `n - k`.
///
/// # Arguments
/// * `x` - Half spectrum. `Complex64` elements are used directly; other element
///   types are converted to `f64` and treated as purely real coefficients.
/// * `n` - Output length. Defaults to `2 * (x.len() - 1)`.
///
/// # Returns
/// The real part of the inverse FFT of the reconstructed full spectrum (`n` values).
///
/// # Errors
/// * `FFTError::ValueError` if an element cannot be converted to `f64`, or if the
///   output length is zero. This includes a length inferred from fewer than two
///   input bins.
/// * Any error propagated from [`ifft`].
#[allow(dead_code)]
pub fn irfft<T>(x: &[T], n: Option<usize>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let half_spectrum: Vec<Complex64> = x
        .iter()
        .map(|&val| -> FFTResult<Complex64> {
            if let Some(c) = try_as_complex(val) {
                return Ok(c);
            }
            let val_f64 = NumCast::from(val)
                .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))?;
            Ok(Complex64::new(val_f64, 0.0))
        })
        .collect::<FFTResult<Vec<_>>>()?;
    let input_len = half_spectrum.len();
    // `saturating_sub` avoids the usize underflow panic on empty input; the
    // resulting zero length is rejected just below.
    let n_output = n.unwrap_or_else(|| 2 * input_len.saturating_sub(1));
    if n_output == 0 {
        return Err(FFTError::ValueError(format!(
            "Invalid number of output points (0) for irfft with {input_len} input coefficients"
        )));
    }

    // Bins missing from the input (zero-padding) read as zero.
    let coefficient = |k: usize| half_spectrum.get(k).copied().unwrap_or_default();
    let half = n_output / 2;
    let full_spectrum: Vec<Complex64> = (0..n_output)
        .map(|k| {
            if k <= half {
                coefficient(k)
            } else {
                // Spectrum of a real signal: X[k] = conj(X[n - k]). Placing the mirror
                // by index (not by appending) keeps it correct when the input length
                // differs from n / 2 + 1.
                coefficient(n_output - k).conj()
            }
        })
        .collect();

    let complex_output = ifft(&full_spectrum, Some(n_output))?;
    Ok(complex_output.iter().map(|c| c.re).collect())
}
/// Computes the 2-D FFT of real-valued input, keeping only the non-negative
/// frequencies along the column (last) axis.
///
/// # Arguments
/// * `x` - Input array.
/// * `shape` - Optional transform shape `(rows, cols)`, forwarded to
///   `crate::fft::fft2`. Defaults to the input's shape.
/// * `_axes`, `_norm` - Accepted for API compatibility; currently ignored.
///
/// # Returns
/// An array with all rows of the full 2-D FFT and its first `cols / 2 + 1` columns.
///
/// # Errors
/// Propagates any error returned by `crate::fft::fft2`.
#[allow(dead_code)]
pub fn rfft2<T>(
    x: &ArrayView2<T>,
    shape: Option<(usize, usize)>,
    _axes: Option<(usize, usize)>,
    _norm: Option<&str>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let cols_out = match shape {
        Some((_, cols)) => cols,
        None => x.dim().1,
    };
    let spectrum = crate::fft::fft2(&x.to_owned(), shape, None, None)?;
    let kept_cols = cols_out / 2 + 1;
    Ok(spectrum.slice(s![.., ..kept_cols]).to_owned())
}
#[allow(dead_code)]
pub fn irfft2<T>(
x: &ArrayView2<T>,
shape: Option<(usize, usize)>,
_axes: Option<(usize, usize)>,
_norm: Option<&str>,
) -> FFTResult<Array2<f64>>
where
T: NumCast + Copy + Debug + 'static,
{
let (n_rows, n_cols) = x.dim();
let (n_rows_out, n_cols_out) = shape.unwrap_or_else(|| (n_rows, 2 * (n_cols - 1)));
let mut full_spectrum = Array2::zeros((n_rows_out, n_cols_out));
for i in 0..n_rows.min(n_rows_out) {
for j in 0..n_cols.min(n_cols_out) {
let val = if let Some(c) = try_as_complex(x[[i, j]]) {
c
} else {
let element = x[[i, j]];
let val_f64: f64 = NumCast::from(element).ok_or_else(|| {
FFTError::ValueError(format!("Could not convert {element:?} to f64"))
})?;
Complex64::new(val_f64, 0.0)
};
full_spectrum[[i, j]] = val;
}
}
for i in 0..n_rows_out {
for j in n_cols..n_cols_out {
let sym_j = n_cols_out - j;
if sym_j < n_cols {
full_spectrum[[i, j]] = full_spectrum[[i, sym_j]].conj();
}
}
}
let complex_output = crate::fft::ifft2(
&full_spectrum.to_owned(),
Some((n_rows_out, n_cols_out)),
None,
None,
)?;
let result =
Array2::from_shape_fn((n_rows_out, n_cols_out), |(i, j)| complex_output[[i, j]].re);
Ok(result)
}
#[allow(dead_code)]
pub fn rfftn<T>(
x: &ArrayView<T, IxDyn>,
shape: Option<Vec<usize>>,
axes: Option<Vec<usize>>,
norm: Option<&str>,
overwrite_x: Option<bool>,
workers: Option<usize>,
) -> FFTResult<Array<Complex64, IxDyn>>
where
T: NumCast + Copy + Debug + 'static,
{
let full_result = crate::fft::fftn(
&x.to_owned(),
shape.clone(),
axes.clone(),
norm,
overwrite_x,
workers,
)?;
let n_dims = x.ndim();
let axes_to_transform = axes.unwrap_or_else(|| (0..n_dims).collect());
let last_axis = if let Some(last) = axes_to_transform.last() {
*last
} else {
n_dims - 1
};
let mut outshape = full_result.shape().to_vec();
if shape.is_none() {
outshape[last_axis] = outshape[last_axis] / 2 + 1;
}
let result = full_result
.slice_each_axis(|ax| {
if ax.axis.index() == last_axis {
scirs2_core::ndarray::Slice::new(0, Some(outshape[last_axis] as isize), 1)
} else {
scirs2_core::ndarray::Slice::new(0, None, 1)
}
})
.to_owned();
Ok(result)
}
/// Computes the N-dimensional inverse of [`rfftn`], returning a real array.
///
/// Steps:
/// 1. Expand the input to a full spectrum of the output shape via
///    `reconstruct_hermitian_symmetry`.
/// 2. Pass it to `crate::fft::ifftn`.
/// 3. Return the real part of the result.
///
/// # Arguments
/// * `x` - Half spectrum (the last transformed axis holds the non-negative frequencies).
/// * `shape` - Output shape, given either as one length per input dimension or as one
///   length per entry of `axes`. When `None`, the input shape is used, with the last
///   transformed axis inferred as `2 * (len - 1)`.
/// * `axes` - Axes to transform; defaults to all axes.
/// * `norm`, `overwrite_x`, `workers` - Forwarded to `crate::fft::ifftn`
///   (`overwrite_x` defaults to `false`).
///
/// # Errors
/// * `FFTError::DimensionError` in any of these cases:
///   - an axis is out of range;
///   - `shape`'s length matches neither the input rank nor `axes`;
///   - the input is zero-dimensional and `shape` is not given.
/// * `FFTError::ValueError` when the inferred output length is zero, or when an element
///   cannot be converted to `f64`.
/// * Any error propagated from `crate::fft::ifftn`.
#[allow(dead_code)]
pub fn irfftn<T>(
    x: &ArrayView<T, IxDyn>,
    shape: Option<Vec<usize>>,
    axes: Option<Vec<usize>>,
    norm: Option<&str>,
    overwrite_x: Option<bool>,
    workers: Option<usize>,
) -> FFTResult<Array<f64, IxDyn>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let overwrite_x = overwrite_x.unwrap_or(false);
    let n_dims = x.ndim();

    let axes_to_transform: Vec<usize> = match axes {
        Some(ax) => {
            if let Some(&axis) = ax.iter().find(|&&axis| axis >= n_dims) {
                return Err(FFTError::DimensionError(format!(
                    "Axis {axis} is out of bounds for array of dimension {n_dims}"
                )));
            }
            ax
        }
        None => (0..n_dims).collect(),
    };

    let outshape = match shape {
        // One length per input dimension: used verbatim.
        Some(sh) if sh.len() == n_dims => sh,
        // One length per transformed axis: overrides the input length on those axes.
        Some(sh) if sh.len() == axes_to_transform.len() => {
            let mut newshape = x.shape().to_vec();
            for (&axis, &len) in axes_to_transform.iter().zip(&sh) {
                newshape[axis] = len;
            }
            newshape
        }
        Some(_) if axes_to_transform.is_empty() => {
            return Err(FFTError::DimensionError(
                "Shape has invalid dimensions".to_string(),
            ));
        }
        Some(sh) => {
            return Err(FFTError::DimensionError(format!(
                "Shape must have the same number of dimensions as input or match the length of axes, got {} expected {} or {}",
                sh.len(),
                n_dims,
                axes_to_transform.len()
            )));
        }
        None => {
            // `checked_sub` avoids the usize underflow of `n_dims - 1` on 0-d input.
            let last_axis = axes_to_transform
                .last()
                .copied()
                .or_else(|| n_dims.checked_sub(1))
                .ok_or_else(|| {
                    FFTError::DimensionError(
                        "Cannot infer the output shape of a zero-dimensional input".to_string(),
                    )
                })?;
            let mut inferredshape = x.shape().to_vec();
            // A half spectrum of m bins is taken to describe 2 * (m - 1) real samples;
            // `saturating_sub` prevents an underflow panic when m == 0.
            let inferred_len = 2 * inferredshape[last_axis].saturating_sub(1);
            if inferred_len == 0 {
                return Err(FFTError::ValueError(format!(
                    "Cannot infer a positive output length for axis {last_axis} from {} input points",
                    inferredshape[last_axis]
                )));
            }
            inferredshape[last_axis] = inferred_len;
            inferredshape
        }
    };

    let full_spectrum = reconstruct_hermitian_symmetry(x, &outshape, &axes_to_transform)?;
    let complex_output = crate::fft::ifftn(
        &full_spectrum,
        Some(outshape.clone()),
        Some(axes_to_transform),
        norm,
        Some(overwrite_x),
        workers,
    )?;
    // irfftn returns a real array, so only the real parts of the inverse are kept.
    let result = Array::from_shape_fn(IxDyn(&outshape), |idx| complex_output[idx].re);
    Ok(result)
}
/// Expands a half spectrum `x` into a full spectrum of shape `outshape` using
/// Hermitian symmetry.
///
/// Each output index is filled as follows:
/// - If it lies inside `x`'s shape, it is copied from `x`.
/// - Otherwise, its mirror position is computed: `(len - k) % len` on every axis listed
///   in `axes`, with other axes unchanged. If the mirror lies inside `x`, the index gets
///   the complex conjugate of that element.
/// - Otherwise it stays zero.
///
/// Input elements beyond `outshape` are ignored, which truncates the input. `Complex64`
/// elements are used as-is; other element types are converted to `f64` and placed in
/// the real part.
///
/// # Errors
/// * `FFTError::DimensionError` if `x` and `outshape` have different ranks, or if an
///   entry of `axes` is not a valid dimension index.
/// * `FFTError::ValueError` if an element cannot be converted to `f64`.
#[allow(dead_code)]
fn reconstruct_hermitian_symmetry<T>(
    x: &ArrayView<T, IxDyn>,
    outshape: &[usize],
    axes: &[usize],
) -> FFTResult<Array<Complex64, IxDyn>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // Converts one input element: complex values pass through, others become `re`.
    fn to_complex<T>(value: T) -> FFTResult<Complex64>
    where
        T: NumCast + Copy + Debug + 'static,
    {
        if let Some(c) = try_as_complex(value) {
            return Ok(c);
        }
        let val_f64 = NumCast::from(value)
            .ok_or_else(|| FFTError::ValueError(format!("Could not convert {value:?} to f64")))?;
        Ok(Complex64::new(val_f64, 0.0))
    }

    // True when every coordinate of `idx` is below the matching extent of `shape`.
    fn within_bounds(idx: &[usize], shape: &[usize]) -> bool {
        idx.iter().zip(shape).all(|(&i, &len)| i < len)
    }

    let xshape = x.shape();
    if xshape.len() != outshape.len() {
        return Err(FFTError::DimensionError(format!(
            "Input has {} dimensions but the output shape has {}",
            xshape.len(),
            outshape.len()
        )));
    }
    if let Some(&axis) = axes.iter().find(|&&axis| axis >= outshape.len()) {
        return Err(FFTError::DimensionError(format!(
            "Axis {axis} is out of bounds for array of dimension {}",
            outshape.len()
        )));
    }

    let mut result = Array::<Complex64, IxDyn>::zeros(IxDyn(outshape));
    // Row-major walk over every output index. `product()` of an empty shape is 1,
    // so a 0-d output is visited exactly once; a zero-length axis means no visits.
    let total: usize = outshape.iter().product();
    let mut idx = vec![0usize; outshape.len()];
    let mut mirror = vec![0usize; outshape.len()];
    for _ in 0..total {
        if within_bounds(&idx, xshape) {
            result[IxDyn(&idx)] = to_complex(x[IxDyn(&idx)])?;
        } else {
            mirror.copy_from_slice(&idx);
            for &axis in axes {
                // idx[axis] < outshape[axis], so outshape[axis] > 0 here.
                mirror[axis] = (outshape[axis] - idx[axis]) % outshape[axis];
            }
            if within_bounds(&mirror, xshape) {
                result[IxDyn(&idx)] = to_complex(x[IxDyn(&mirror)])?.conj();
            }
        }

        // Advance the multi-index, last axis fastest.
        for d in (0..idx.len()).rev() {
            idx[d] += 1;
            if idx[d] < outshape[d] {
                break;
            }
            idx[d] = 0;
        }
    }
    Ok(result)
}
/// Returns `Some(val)` when `T` is exactly `Complex64`, and `None` for any other type.
///
/// The generic transforms in this module call this first, so complex input is used
/// unchanged. For any other element type they fall back to a `NumCast` conversion
/// to `f64`.
#[allow(dead_code)]
fn try_as_complex<T: Copy + Debug + 'static>(val: T) -> Option<Complex64> {
    // `Any::downcast_ref` performs the same `TypeId` comparison as a manual check,
    // without needing an unsafe pointer cast.
    (&val as &dyn std::any::Any)
        .downcast_ref::<Complex64>()
        .copied()
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr2;

    /// Asserts `actual[i] ≈ expected[i]` for every index of `expected`. `actual` is
    /// indexed directly, so a result that is too short fails loudly.
    fn assert_prefix_close(actual: &[f64], expected: &[f64], epsilon: f64) {
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[i], want, epsilon = epsilon);
        }
    }

    #[test]
    fn test_rfft_and_irfft() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let spectrum =
            rfft(&signal, None).expect("RFFT computation should succeed for test data");
        assert_eq!(spectrum.len(), signal.len() / 2 + 1);
        // The DC bin is the plain sum of the samples.
        assert_relative_eq!(spectrum[0].re, 10.0, epsilon = 1e-10);

        let recovered =
            irfft(&spectrum, Some(signal.len())).expect("IRFFT computation should succeed");
        assert_prefix_close(&recovered, &signal, 1e-10);
    }

    #[test]
    fn test_rfft_with_zero_padding() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let padded_spectrum = rfft(&signal, Some(8)).expect("RFFT with padding should succeed");
        assert_eq!(padded_spectrum.len(), 8 / 2 + 1);
        assert_relative_eq!(padded_spectrum[0].re, 10.0, epsilon = 1e-10);

        let recovered_padded =
            irfft(&padded_spectrum, Some(8)).expect("IRFFT recovery should succeed");
        assert_prefix_close(&recovered_padded, &signal, 1e-10);
        for &tail in &recovered_padded[4..8] {
            assert_relative_eq!(tail, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_rfft2_and_irfft2() {
        let arr = arr2(&[
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0],
        ]);
        let spectrum_2d = rfft2(&arr.view(), None, None, None).expect("2D RFFT should succeed");
        assert_eq!(spectrum_2d.dim(), (4, 4 / 2 + 1));

        let total_sum: f64 = (1..=16).map(|v| v as f64).sum();
        assert_relative_eq!(spectrum_2d[[0, 0]].re, total_sum, epsilon = 1e-10);

        let recovered_2d =
            irfft2(&spectrum_2d.view(), Some((4, 4)), None, None).expect("2D IRFFT should succeed");
        for ((i, j), &expected) in arr.indexed_iter() {
            assert_relative_eq!(recovered_2d[[i, j]], expected, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_rfft2_small() {
        let arr = arr2(&[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
        let spectrum = rfft2(&arr.view(), None, None, None).expect("Small 2D RFFT should succeed");
        assert_eq!(spectrum.dim(), (2, 3));

        let sum: f64 = (1..=8).map(|v| v as f64).sum();
        assert_relative_eq!(spectrum[[0, 0]].re, sum, epsilon = 1e-10);

        let recovered = irfft2(&spectrum.view(), Some((2, 4)), None, None)
            .expect("Small 2D IRFFT should succeed");
        for ((i, j), &expected) in arr.indexed_iter() {
            assert_relative_eq!(recovered[[i, j]], expected, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_sine_wave_rfft() {
        let n = 16;
        let freq = 2.0;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * freq * i as f64 / n as f64).sin())
            .collect();

        let spectrum = rfft(&signal, None).expect("RFFT for sine wave should succeed");
        // A unit sine puts n/2 of magnitude into the imaginary part of its bin.
        let expected_peak = n as f64 / 2.0;
        assert_relative_eq!(
            spectrum[freq as usize].im.abs(),
            expected_peak,
            epsilon = 1e-10
        );

        let recovered = irfft(&spectrum, Some(n)).expect("IRFFT for sine wave should succeed");
        assert_prefix_close(&recovered, &signal, 1e-8);
    }

    #[test]
    fn test_rfft_hermitian_symmetry() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let n = signal.len();
        let spectrum = rfft(&signal, None).expect("RFFT should succeed");
        assert_eq!(spectrum.len(), n / 2 + 1);
        // DC and Nyquist bins of a real signal are purely real.
        assert_relative_eq!(spectrum[0].im, 0.0, epsilon = 1e-10);
        assert_relative_eq!(spectrum[n / 2].im, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rfft_cosine_wave() {
        let n = 32;
        let freq = 4;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * freq as f64 * i as f64 / n as f64).cos())
            .collect();

        let spectrum = rfft(&signal, None).expect("RFFT cosine should succeed");
        for (bin, value) in spectrum.iter().enumerate() {
            let magnitude = value.norm();
            if bin == freq {
                assert!(magnitude > 1.0, "Should have energy at frequency {freq}");
            } else {
                assert!(
                    magnitude < 1e-10,
                    "Should have no energy at frequency {bin}, got {magnitude}"
                );
            }
        }
    }

    #[test]
    fn test_rfft_energy_conservation() {
        let signal = vec![1.0, 3.0, -2.0, 4.0, 0.5, -1.5, 2.5, 3.5];
        let n = signal.len();
        let spectrum = rfft(&signal, None).expect("RFFT should succeed");
        let time_energy: f64 = signal.iter().map(|x| x * x).sum();

        // Parseval for a half spectrum: DC and Nyquist count once, every interior bin
        // twice (it stands in for its conjugate mirror).
        let half = n / 2;
        let edge_energy = spectrum[0].norm_sqr() + spectrum[half].norm_sqr();
        let freq_energy = spectrum[1..half]
            .iter()
            .fold(edge_energy, |acc, bin| acc + 2.0 * bin.norm_sqr())
            / n as f64;
        assert_relative_eq!(time_energy, freq_energy, epsilon = 1e-8);
    }

    #[test]
    fn test_irfft_length_inference() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let spectrum = rfft(&signal, None).expect("RFFT should succeed");
        let recovered = irfft(&spectrum, None).expect("IRFFT inference should succeed");
        assert_eq!(recovered.len(), 6);
        assert_prefix_close(&recovered, &signal, 1e-8);
    }
}