use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
use crate::prelude::*;
/// Real-to-complex DFT solver for a fixed transform length.
///
/// Only the `n / 2 + 1` non-redundant spectrum bins are produced; for real
/// input the remaining bins follow from Hermitian symmetry.
pub struct R2cSolver<T>
where
    T: Float,
{
    /// Transform length N (number of real input samples).
    n: usize,
    /// Twiddle factors `exp(-2πi·k/N)` for `k in 0..N/2`, precomputed by `new`.
    twiddles: Vec<Complex<T>>,
}
impl<T> Default for R2cSolver<T>
where
    T: Float,
{
    /// A zero-length solver, equivalent to `R2cSolver::new(0)`.
    fn default() -> Self {
        Self {
            n: 0,
            twiddles: Vec::new(),
        }
    }
}
impl<T: Float> R2cSolver<T> {
    /// Creates a solver for real input of length `n`.
    ///
    /// Precomputes the `n / 2` twiddle factors `exp(-2πi·k/n)` used by the
    /// even-length split step. `n == 0` yields an empty solver.
    #[must_use]
    pub fn new(n: usize) -> Self {
        // For n == 0 the range is empty, so the division below never runs.
        let n_t = T::from_usize(n);
        let twiddles = (0..n / 2)
            .map(|k| Complex::cis(-<T as Float>::TWO_PI * T::from_usize(k) / n_t))
            .collect();
        Self { n, twiddles }
    }

    /// Identifier of this solver.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "rdft-r2c"
    }

    /// Transform length N (number of real input samples).
    #[must_use]
    pub fn size(&self) -> usize {
        self.n
    }

    /// Number of complex output bins, `N / 2 + 1`.
    #[must_use]
    pub fn output_size(&self) -> usize {
        self.n / 2 + 1
    }

    /// Computes the half spectrum `X[0..=N/2]` of the real signal `input`.
    ///
    /// Even N > 2 packs the signal into N/2 complex samples, runs one complex
    /// FFT of that size and separates the even/odd sub-spectra. Odd N, where
    /// that packing does not apply, uses a direct O(N²) DFT. Even N for which
    /// no complex plan can be built also uses the direct DFT.
    ///
    /// # Panics
    /// Panics if `input.len() != N` or `output.len() != N / 2 + 1`.
    pub fn execute(&self, input: &[T], output: &mut [Complex<T>]) {
        let n = self.n;
        assert_eq!(input.len(), n, "Input size must be N");
        assert_eq!(output.len(), n / 2 + 1, "Output size must be N/2+1");
        match n {
            0 => {}
            1 => output[0] = Complex::new(input[0], T::ZERO),
            2 => {
                output[0] = Complex::new(input[0] + input[1], T::ZERO);
                output[1] = Complex::new(input[0] - input[1], T::ZERO);
            }
            // Pairing x[2k], x[2k+1] would drop the last sample for odd N.
            _ if n % 2 == 1 => Self::execute_direct(input, output),
            _ => self.execute_even(input, output),
        }
    }

    /// Even-length path: one complex FFT of size N/2 plus an O(N) split step.
    fn execute_even(&self, input: &[T], output: &mut [Complex<T>]) {
        let half_n = self.n / 2;
        // z[k] = x[2k] + i·x[2k+1]
        let z: Vec<Complex<T>> = input
            .chunks_exact(2)
            .map(|pair| Complex::new(pair[0], pair[1]))
            .collect();
        let plan = match Plan::dft_1d(half_n, Direction::Forward, Flags::ESTIMATE) {
            Some(plan) => plan,
            None => {
                // Do not emit a zero spectrum when no plan is available.
                Self::execute_direct(input, output);
                return;
            }
        };
        let mut z_fft: Vec<Complex<T>> = vec![Complex::zero(); half_n];
        plan.execute(&z, &mut z_fft);

        // Z[0] = E[0] + i·O[0], where E[0] and O[0] are real: the sums of the even and odd samples.
        let z0 = z_fft[0];
        output[0] = Complex::new(z0.re + z0.im, T::ZERO);
        output[half_n] = Complex::new(z0.re - z0.im, T::ZERO);

        let half = T::from_usize(2).recip();
        for k in 1..half_n {
            let zk = z_fft[k];
            let zn_k = z_fft[half_n - k].conj();
            let sum = zk + zn_k; // 2·E[k]
            let diff = zk - zn_k; // 2i·O[k]
            // -i·diff = 2·O[k]; rotate it by W_N^k and combine: X[k] = E[k] + W^k·O[k].
            let minus_i_diff = Complex::new(diff.im, -diff.re);
            output[k] = (sum + minus_i_diff * self.twiddles[k]) * half;
        }
    }

    /// Direct O(N²) evaluation of `X[k] = Σ_j x[j]·exp(-2πi·jk/N)` for every
    /// `k < output.len()`. Requires `input` to be non-empty.
    fn execute_direct(input: &[T], output: &mut [Complex<T>]) {
        let n = input.len();
        let n_t = T::from_usize(n);
        for (k, bin) in output.iter_mut().enumerate() {
            let mut acc: Complex<T> = Complex::zero();
            for (j, &x) in input.iter().enumerate() {
                // Reduce j·k modulo N first so |angle| < 2π and precision is kept.
                let angle = -<T as Float>::TWO_PI * T::from_usize((j * k) % n) / n_t;
                acc = acc + Complex::cis(angle) * x;
            }
            *bin = acc;
        }
    }

    /// In-place transform on a packed real buffer.
    ///
    /// `data[..N]` holds the input signal. On return, `data[2k]` and
    /// `data[2k + 1]` hold the real and imaginary parts of `X[k]` for
    /// `k in 0..=N/2`. The layout is the same for every N > 0.
    ///
    /// # Panics
    /// Panics if `data.len() < N`, or, for N > 0, if
    /// `data.len() < 2 * (N / 2 + 1)`.
    pub fn execute_inplace(&self, data: &mut [T]) {
        let n = self.n;
        assert!(data.len() >= n, "Buffer too small for input");
        if n == 0 {
            return;
        }
        let bins = n / 2 + 1;
        assert!(
            data.len() >= 2 * bins,
            "Buffer too small for packed output: need 2*(N/2+1) reals"
        );
        let input: Vec<T> = data[..n].to_vec();
        let mut output = vec![Complex::zero(); bins];
        self.execute(&input, &mut output);
        for (slot, c) in data.chunks_exact_mut(2).zip(&output) {
            slot[0] = c.re;
            slot[1] = c.im;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::fft;

    /// Absolute-tolerance comparison of two scalars.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    /// Component-wise absolute-tolerance comparison of two complex numbers.
    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        let re_ok = approx_eq(a.re, b.re, eps);
        re_ok && approx_eq(a.im, b.im, eps)
    }

    /// Runs an `R2cSolver` sized to `signal` and returns its `len / 2 + 1` bins.
    fn half_spectrum(signal: &[f64]) -> Vec<Complex<f64>> {
        let mut bins = vec![Complex::zero(); signal.len() / 2 + 1];
        R2cSolver::new(signal.len()).execute(signal, &mut bins);
        bins
    }

    /// Asserts that every r2c bin of `signal` agrees with the full complex FFT.
    fn assert_matches_full_fft(signal: &[f64], eps: f64) {
        let bins = half_spectrum(signal);
        let promoted: Vec<Complex<f64>> =
            signal.iter().map(|&x| Complex::new(x, 0.0)).collect();
        let reference = fft(&promoted);
        for (k, &got) in bins.iter().enumerate() {
            let expected = reference[k];
            assert!(
                complex_approx_eq(got, expected, eps),
                "Mismatch at k={}: got {:?}, expected {:?}",
                k,
                got,
                expected
            );
        }
    }

    #[test]
    fn test_r2c_size_2() {
        let bins = half_spectrum(&[1.0, 2.0]);
        assert!(complex_approx_eq(bins[0], Complex::new(3.0, 0.0), 1e-10));
        assert!(complex_approx_eq(bins[1], Complex::new(-1.0, 0.0), 1e-10));
    }

    #[test]
    fn test_r2c_size_4() {
        assert_matches_full_fft(&[1.0, 2.0, 3.0, 4.0], 1e-10);
    }

    #[test]
    fn test_r2c_size_8() {
        let ramp: Vec<f64> = (0..8_i32).map(f64::from).collect();
        assert_matches_full_fft(&ramp, 1e-9);
    }

    #[test]
    fn test_r2c_size_16() {
        let wave: Vec<f64> = (0..16_i32).map(|i| f64::from(i).sin()).collect();
        assert_matches_full_fft(&wave, 1e-9);
    }

    #[test]
    fn test_r2c_dc_and_nyquist_are_real() {
        let signal: Vec<f64> = (0..8_i32).map(|i| f64::from(i).sin() + 0.5).collect();
        let bins = half_spectrum(&signal);
        assert!(
            approx_eq(bins[0].im, 0.0, 1e-10),
            "DC component should be real"
        );
        assert!(
            approx_eq(bins[4].im, 0.0, 1e-10),
            "Nyquist component should be real"
        );
    }
}