use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::{Complex64, NumCast, Zero};
use std::f64::consts::PI;
use std::fmt::Debug;
/// Converts every element of `x` to `f64`, preserving order.
///
/// # Errors
///
/// Returns `FFTError::ValueError` describing the first element that
/// `NumCast` cannot represent as an `f64`.
fn cast_to_f64<T: NumCast + Copy + Debug>(x: &[T]) -> FFTResult<Vec<f64>> {
    let mut converted = Vec::with_capacity(x.len());
    for &v in x {
        match <f64 as NumCast>::from(v) {
            Some(value) => converted.push(value),
            None => {
                return Err(FFTError::ValueError(format!("Cannot cast {v:?} to f64")));
            }
        }
    }
    Ok(converted)
}
/// Builds the forward twiddle table `exp(-2πi·k/n)` for every `k` in `0..=half`.
///
/// The returned vector therefore has `half + 1` entries.
fn twiddle_forward(n: usize, half: usize) -> Vec<Complex64> {
    let mut table = Vec::with_capacity(half.saturating_add(1));
    for idx in 0..=half {
        let angle = -2.0 * PI * idx as f64 / n as f64;
        let (sin_part, cos_part) = angle.sin_cos();
        table.push(Complex64::new(cos_part, sin_part));
    }
    table
}
/// Builds the inverse-direction twiddle table `exp(+2πi·k/n)` for every `k`
/// in `0..=half` (`half + 1` entries); the conjugate of `twiddle_forward`.
///
/// No caller of this private helper is visible in this file, so `dead_code`
/// is allowed explicitly to keep `-D warnings` builds clean.
#[allow(dead_code)]
fn twiddle_inverse(n: usize, half: usize) -> Vec<Complex64> {
    (0..=half)
        .map(|k| {
            let phase = 2.0 * PI * k as f64 / n as f64;
            Complex64::new(phase.cos(), phase.sin())
        })
        .collect()
}
/// Computes the one-sided DFT of a real-valued signal.
///
/// `x` is truncated or zero-padded to `n` samples (`n` defaults to `x.len()`),
/// and bins `0..=n/2` of its DFT are returned (length `n / 2 + 1`).
///
/// For even `n` the signal is packed into a half-length complex sequence
/// (`z[k] = x[2k] + i·x[2k+1]`), transformed with one `n/2`-point complex
/// FFT, and the even/odd sub-spectra are separated and recombined with
/// twiddle factors. That packing needs an even length, so for odd `n` the
/// real signal is instead transformed directly with an `n`-point complex FFT
/// and the redundant upper half is dropped.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if `x` is empty, if `n == Some(0)`, or if an
/// element cannot be cast to `f64`; propagates any error from [`fft`].
pub fn rfft_optimized<T>(x: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    if x.is_empty() {
        return Err(FFTError::ValueError("rfft_optimized: input is empty".into()));
    }
    let input_f64 = cast_to_f64(x)?;
    let n_val = n.unwrap_or(input_f64.len());
    if n_val == 0 {
        return Err(FFTError::ValueError("rfft_optimized: n must be > 0".into()));
    }

    // Truncate or zero-pad the signal to exactly `n_val` samples.
    let mut padded = vec![0.0_f64; n_val];
    let copy_len = input_f64.len().min(n_val);
    padded[..copy_len].copy_from_slice(&input_f64[..copy_len]);

    if n_val == 1 {
        return Ok(vec![Complex64::new(padded[0], 0.0)]);
    }

    let n_out = n_val / 2 + 1;

    if n_val % 2 == 1 {
        // The half-length packing below would drop the final sample of an
        // odd-length signal, so transform the real data directly instead.
        let as_complex: Vec<Complex64> =
            padded.iter().map(|&v| Complex64::new(v, 0.0)).collect();
        let mut spectrum = fft(&as_complex, None)?;
        spectrum.truncate(n_out);
        return Ok(spectrum);
    }

    let half = n_val / 2;
    // Even-indexed samples go to the real part, odd-indexed to the imaginary part.
    let z: Vec<Complex64> = padded
        .chunks_exact(2)
        .map(|pair| Complex64::new(pair[0], pair[1]))
        .collect();
    let z_fft = fft(&z, None)?;
    let twiddles = twiddle_forward(n_val, half);

    let result = (0..n_out)
        .map(|k| {
            // Z has period `half`, so bin `half` aliases bin 0 in both lookups.
            let zk = z_fft[k % half];
            let zm_conj = z_fft[(half - k) % half].conj();
            // Spectrum of the even samples: (Z[k] + conj(Z[half-k])) / 2.
            let even = (zk + zm_conj) * 0.5;
            // Spectrum of the odd samples: (Z[k] - conj(Z[half-k])) / (2i).
            let diff = zk - zm_conj;
            let odd = Complex64::new(diff.im * 0.5, -diff.re * 0.5);
            even + twiddles[k] * odd
        })
        .collect();
    Ok(result)
}
/// Reconstructs a real signal of length `n` from its one-sided spectrum.
///
/// This is the inverse of [`rfft_optimized`]. `n` defaults to
/// `2 * (x.len() - 1)`.
///
/// Only the first `n / 2 + 1` bins of `x` are used. Extra bins are ignored,
/// and missing bins are treated as zero. The negative-frequency half of the
/// full spectrum is rebuilt from Hermitian symmetry (`X[n-k] = conj(X[k])`).
/// The real part of the inverse complex FFT is returned, so the imaginary
/// parts of the DC bin and (for even `n`) the Nyquist bin do not contribute.
///
/// Returns an empty vector when the resolved `n` is 0.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if `x` is empty; propagates errors from [`ifft`].
pub fn irfft_optimized(x: &[Complex64], n: Option<usize>) -> FFTResult<Vec<f64>> {
    if x.is_empty() {
        return Err(FFTError::ValueError("irfft_optimized: input is empty".into()));
    }
    let n_out = n.unwrap_or(2 * (x.len() - 1));
    if n_out == 0 {
        return Ok(Vec::new());
    }

    // An n_out-point real signal has n_out/2 + 1 independent bins; use at
    // most that many so mirrored writes never clobber positive frequencies.
    let used = x.len().min(n_out / 2 + 1);
    let mut full = vec![Complex64::zero(); n_out];
    full[..used].copy_from_slice(&x[..used]);
    for k in 1..used {
        let mirror = n_out - k;
        // For even n_out, k == n_out/2 is the Nyquist bin, which is its own mirror.
        if mirror != k {
            full[mirror] = x[k].conj();
        }
    }

    let complex_result = ifft(&full, None)?;
    Ok(complex_result.iter().map(|c| c.re).collect())
}
/// 2-D forward real FFT returning a half spectrum of shape `(rows, cols / 2 + 1)`.
///
/// `s = Some((rows, cols))` selects the transform shape. The input is
/// truncated or zero-padded to that shape; the default is `x`'s own shape.
/// Each row is transformed with [`rfft_optimized`], then a complex FFT is
/// applied down every retained column.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when either output dimension is 0.
/// Returns `FFTError::DimensionError` if the result cannot be reshaped.
/// Propagates errors from the underlying 1-D transforms.
pub fn rfft2_optimized(
    x: &Array2<f64>,
    s: Option<(usize, usize)>,
) -> FFTResult<Array2<Complex64>> {
    let (in_rows, in_cols) = x.dim();
    let (out_rows, out_cols) = match s {
        Some(shape) => shape,
        None => (in_rows, in_cols),
    };
    if out_rows == 0 || out_cols == 0 {
        return Err(FFTError::ValueError("rfft2_optimized: output shape must be > 0".into()));
    }
    let half_cols = out_cols / 2 + 1;

    // Pass 1: real FFT of every (padded/truncated) row.
    let spectra = (0..out_rows)
        .map(|r| {
            let samples: Vec<f64> = (0..out_cols)
                .map(|c| if r < in_rows && c < in_cols { x[[r, c]] } else { 0.0 })
                .collect();
            rfft_optimized(&samples, Some(out_cols))
        })
        .collect::<FFTResult<Vec<Vec<Complex64>>>>()?;

    // Pass 2: complex FFT down every column of the row spectra, written row-major.
    let mut data = vec![Complex64::zero(); out_rows * half_cols];
    for c in 0..half_cols {
        let column: Vec<Complex64> = spectra.iter().map(|row| row[c]).collect();
        let transformed = fft(&column, Some(out_rows))?;
        for (r, &value) in transformed[..out_rows].iter().enumerate() {
            data[r * half_cols + c] = value;
        }
    }

    Array2::from_shape_vec((out_rows, half_cols), data)
        .map_err(|e| FFTError::DimensionError(format!("rfft2_optimized shape error: {e}")))
}
/// 2-D inverse real FFT; the inverse of [`rfft2_optimized`].
///
/// `x` holds a half spectrum. An inverse complex FFT, padded or truncated to
/// `out_rows`, is applied down every column. Each resulting row is then
/// turned back into `out_cols` real samples with [`irfft_optimized`].
///
/// `s = Some((out_rows, out_cols))` selects the output shape. By default,
/// `out_rows` is the number of input rows and `out_cols = 2 * (x.ncols() - 1)`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if `x` has no columns or if either output
/// dimension is 0. Returns `FFTError::DimensionError` if the result cannot be
/// reshaped. Propagates errors from [`ifft`] / [`irfft_optimized`].
pub fn irfft2_optimized(
    x: &Array2<Complex64>,
    s: Option<(usize, usize)>,
) -> FFTResult<Array2<f64>> {
    let (nrows_in, ncols_half) = x.dim();
    if ncols_half == 0 {
        // Nothing to invert; also, the default width `2 * (ncols_half - 1)`
        // would underflow (and `unwrap_or` evaluates it even when `s` is set).
        return Err(FFTError::ValueError("irfft2_optimized: input has no columns".into()));
    }
    let (out_rows, out_cols) = s.unwrap_or((nrows_in, 2 * (ncols_half - 1)));
    if out_rows == 0 || out_cols == 0 {
        return Err(FFTError::ValueError("irfft2_optimized: output shape must be > 0".into()));
    }

    // Pass 1: inverse complex FFT along each column.
    let mut col_ifft_mat: Vec<Vec<Complex64>> =
        vec![vec![Complex64::zero(); ncols_half]; out_rows];
    for j in 0..ncols_half {
        let col: Vec<Complex64> = (0..nrows_in).map(|i| x[[i, j]]).collect();
        let col_result = ifft(&col, Some(out_rows))?;
        for (row, &value) in col_ifft_mat.iter_mut().zip(&col_result[..out_rows]) {
            row[j] = value;
        }
    }

    // Pass 2: real inverse FFT along each row, appended in row-major order.
    let mut result_data = Vec::with_capacity(out_rows * out_cols);
    for row in &col_ifft_mat {
        let row_real = irfft_optimized(row, Some(out_cols))?;
        result_data.extend_from_slice(&row_real[..out_cols]);
    }

    Array2::from_shape_vec((out_rows, out_cols), result_data)
        .map_err(|e| FFTError::DimensionError(format!("irfft2_optimized shape error: {e}")))
}
/// Forward real FFT scaled by `1 / sqrt(n)`.
///
/// Same as [`rfft_optimized`], but every output bin is multiplied by
/// `1 / sqrt(n)`, with `n` defaulting to `x.len()`. It pairs with
/// [`irfft_norm`], which applies the matching `sqrt(n)` factor on the way back.
///
/// # Errors
///
/// Propagates every error from [`rfft_optimized`].
pub fn rfft_norm<T>(x: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let spectrum = rfft_optimized(x, n)?;
    let length = match n {
        Some(len) => len,
        None => x.len(),
    };
    let inv_sqrt_n = 1.0 / (length as f64).sqrt();
    Ok(spectrum
        .into_iter()
        .map(|bin| Complex64::new(bin.re * inv_sqrt_n, bin.im * inv_sqrt_n))
        .collect())
}
/// Inverse real FFT scaled by `sqrt(n)`; the counterpart of [`rfft_norm`].
///
/// Runs [`irfft_optimized`] and multiplies every sample by `sqrt(n)`, where
/// `n` defaults to `2 * (x.len() - 1)`.
///
/// # Errors
///
/// Propagates every error from [`irfft_optimized`], including the error for
/// an empty `x`.
pub fn irfft_norm(x: &[Complex64], n: Option<usize>) -> FFTResult<Vec<f64>> {
    // Invert first: irfft_optimized rejects an empty `x` with an error. Computing
    // the default length `2 * (x.len() - 1)` up front would underflow on empty
    // input, which panics in debug builds.
    let result = irfft_optimized(x, n)?;
    let n_val = n.unwrap_or_else(|| 2 * (x.len() - 1));
    let scale = (n_val as f64).sqrt();
    Ok(result.into_iter().map(|v| v * scale).collect())
}
/// Multiplies each row of a row-major half spectrum by the forward twiddles.
///
/// For every row, the first `n_cols / 2 + 1` entries are multiplied
/// element-wise by `exp(-2πi·j/n_cols)`. The result has one row per input row,
/// each `n_cols / 2 + 1` long.
///
/// # Panics
///
/// Panics if any row of `spec` is shorter than `n_cols / 2 + 1`.
// Currently unused helper; dead_code is allowed so it can be kept around.
#[allow(dead_code)]
fn apply_twiddle_correction(
    spec: &[Vec<Complex64>],
    n_cols: usize,
) -> Vec<Vec<Complex64>> {
    let width = n_cols / 2 + 1;
    let factors = twiddle_forward(n_cols, width);
    spec.iter()
        .map(|row| {
            row[..width]
                .iter()
                .zip(factors.iter())
                .map(|(&value, &w)| value * w)
                .collect()
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::{assert_relative_eq, relative_eq};

    /// Reference O(n²) DFT returning bins `0..=n/2` of a real signal.
    fn brute_dft(x: &[f64]) -> Vec<Complex64> {
        let n = x.len();
        (0..n / 2 + 1)
            .map(|k| {
                x.iter().enumerate().fold(Complex64::zero(), |acc, (m, &xm)| {
                    let phase = -2.0 * PI * k as f64 * m as f64 / n as f64;
                    acc + Complex64::new(xm * phase.cos(), xm * phase.sin())
                })
            })
            .collect()
    }

    /// Asserts two spectra agree bin-by-bin within `tol` (absolute or relative).
    ///
    /// `approx`'s assertion macros accept only tolerance options such as
    /// `epsilon` / `max_relative` and no custom message. The failing bin is
    /// therefore reported through `assert!` wrapped around `relative_eq!`.
    fn assert_complex_slice_eq(a: &[Complex64], b: &[Complex64], tol: f64) {
        assert_eq!(a.len(), b.len(), "length mismatch");
        for (i, (ai, bi)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                relative_eq!(ai.re, bi.re, epsilon = tol, max_relative = tol),
                "bin {i} re: {} vs {}",
                ai.re,
                bi.re
            );
            assert!(
                relative_eq!(ai.im, bi.im, epsilon = tol, max_relative = tol),
                "bin {i} im: {} vs {}",
                ai.im,
                bi.im
            );
        }
    }

    #[test]
    fn test_rfft_length_4_vs_brute() {
        let signal = vec![1.0_f64, 2.0, 3.0, 4.0];
        let got = rfft_optimized(&signal, None).expect("rfft failed");
        let expected = brute_dft(&signal);
        assert_complex_slice_eq(&got, &expected, 1e-9);
    }

    #[test]
    fn test_rfft_length_8_vs_brute() {
        let signal: Vec<f64> = (0..8).map(|k| (k as f64).cos()).collect();
        let got = rfft_optimized(&signal, None).expect("rfft failed");
        let expected = brute_dft(&signal);
        assert_complex_slice_eq(&got, &expected, 1e-9);
    }

    #[test]
    fn test_rfft_odd_length_5_vs_brute() {
        let signal = vec![1.0_f64, -1.0, 2.0, -2.0, 0.5];
        let got = rfft_optimized(&signal, None).expect("rfft failed");
        let expected = brute_dft(&signal);
        assert_complex_slice_eq(&got, &expected, 1e-9);
    }

    #[test]
    fn test_rfft_output_length() {
        for n in [2, 3, 4, 7, 8, 16, 32, 100] {
            let sig: Vec<f64> = (0..n).map(|k| k as f64).collect();
            let spec = rfft_optimized(&sig, None).expect("rfft");
            assert_eq!(spec.len(), n / 2 + 1, "output length for N={n}");
        }
    }

    #[test]
    fn test_rfft_irfft_roundtrip_even() {
        let signal: Vec<f64> = (0..8).map(|k| k as f64 * 0.5).collect();
        let spectrum = rfft_optimized(&signal, None).expect("rfft");
        let recovered = irfft_optimized(&spectrum, Some(signal.len())).expect("irfft");
        for (a, b) in signal.iter().zip(recovered.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_rfft_irfft_roundtrip_odd() {
        let signal = vec![3.0_f64, 1.0, 4.0, 1.0, 5.0];
        let spectrum = rfft_optimized(&signal, None).expect("rfft");
        let recovered = irfft_optimized(&spectrum, Some(signal.len())).expect("irfft");
        for (a, b) in signal.iter().zip(recovered.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_rfft_norm_roundtrip() {
        let signal = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let spec = rfft_norm(&signal, None).expect("rfft_norm");
        let rec = irfft_norm(&spec, Some(signal.len())).expect("irfft_norm");
        for (a, b) in signal.iter().zip(rec.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_rfft2_shape() {
        let data = Array2::from_shape_vec((4, 8), (0..32).map(|k| k as f64).collect())
            .expect("shape");
        let spec = rfft2_optimized(&data, None).expect("rfft2");
        assert_eq!(spec.shape(), &[4, 5]);
    }

    #[test]
    fn test_rfft2_irfft2_roundtrip() {
        let n = 4usize;
        let data = Array2::from_shape_vec((n, n), (0..n * n).map(|k| k as f64).collect())
            .expect("shape");
        let spectrum = rfft2_optimized(&data, None).expect("rfft2");
        let recovered = irfft2_optimized(&spectrum, Some((n, n))).expect("irfft2");
        for i in 0..n {
            for j in 0..n {
                assert!(
                    relative_eq!(data[[i, j]], recovered[[i, j]], epsilon = 1e-8),
                    "[{i},{j}]: {} vs {}",
                    data[[i, j]],
                    recovered[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_rfft_empty_error() {
        let empty: Vec<f64> = vec![];
        assert!(rfft_optimized(&empty, None).is_err());
    }

    #[test]
    fn test_irfft_empty_error() {
        let empty: Vec<Complex64> = vec![];
        assert!(irfft_optimized(&empty, None).is_err());
    }

    #[test]
    fn test_rfft_single_element() {
        let x = vec![42.0_f64];
        let spec = rfft_optimized(&x, None).expect("rfft single");
        assert_eq!(spec.len(), 1);
        assert_relative_eq!(spec[0].re, 42.0, epsilon = 1e-12);
        assert_relative_eq!(spec[0].im, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_rfft_dc_component() {
        let signal = vec![1.0_f64, 2.0, 3.0, 4.0];
        let spec = rfft_optimized(&signal, None).expect("rfft");
        let dc_expected = signal.iter().sum::<f64>();
        assert_relative_eq!(spec[0].re, dc_expected, epsilon = 1e-9);
        assert_relative_eq!(spec[0].im, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn test_rfft_with_n_truncate() {
        let signal = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let spec = rfft_optimized(&signal, Some(4)).expect("rfft truncate");
        assert_eq!(spec.len(), 3);
    }

    #[test]
    fn test_rfft_with_n_zeropad() {
        let signal = vec![1.0_f64, 2.0];
        let spec = rfft_optimized(&signal, Some(8)).expect("rfft pad");
        assert_eq!(spec.len(), 5);
    }
}