use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
use crate::prelude::*;
/// Complex-to-real ("c2r") inverse DFT solver for real signals of a fixed length.
///
/// Takes the `n / 2 + 1` leading bins of a spectrum and writes `n` real samples.
pub struct C2rSolver<T>
where
    T: Float,
{
    /// Length of the real output signal.
    n: usize,
    /// `Complex::cis(2πk/n)` for `k in 0..n / 2`, filled in by the constructor.
    twiddles: Vec<Complex<T>>,
}
impl<T: Float> Default for C2rSolver<T> {
fn default() -> Self {
Self::new(0)
}
}
impl<T: Float> C2rSolver<T> {
    /// Creates a solver for real signals of length `n`.
    ///
    /// Precomputes `Complex::cis(2πk/n)` for `k in 0..n / 2`; these factors are
    /// consumed by the even-length fast path of [`execute`](Self::execute).
    /// `n == 0` yields a solver with no twiddles.
    #[must_use]
    pub fn new(n: usize) -> Self {
        let n_t = T::from_usize(n);
        let twiddles = (0..n / 2)
            .map(|k| Complex::cis(<T as Float>::TWO_PI * T::from_usize(k) / n_t))
            .collect();
        Self { n, twiddles }
    }

    /// Solver identifier.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "rdft-c2r"
    }

    /// Length `n` of the real output signal.
    #[must_use]
    pub fn size(&self) -> usize {
        self.n
    }

    /// Number of complex input bins expected by [`execute`](Self::execute): `n / 2 + 1`.
    #[must_use]
    pub fn input_size(&self) -> usize {
        self.n / 2 + 1
    }

    /// Computes the unnormalized inverse DFT of a Hermitian half-spectrum.
    ///
    /// `input` holds bins `0..=n/2`; the remaining bins are implied by
    /// `X[n-k] = conj(X[k])`. The result is `n` times the original signal; use
    /// [`execute_normalized`](Self::execute_normalized) to undo that scale.
    /// The imaginary parts of the DC bin, and of the Nyquist bin when `n` is
    /// even, do not contribute to the output.
    ///
    /// Any `n` is supported. Even `n` uses a half-size complex FFT. Odd `n`
    /// uses a full-size complex FFT over the reconstructed spectrum. If no FFT
    /// plan can be created, a direct O(n²) summation is used instead of
    /// producing garbage.
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != n / 2 + 1` or `output.len() != n`.
    pub fn execute(&self, input: &[Complex<T>], output: &mut [T]) {
        let n = self.n;
        assert_eq!(input.len(), n / 2 + 1, "Input size must be N/2+1");
        assert_eq!(output.len(), n, "Output size must be N");
        match n {
            0 => {}
            1 => output[0] = input[0].re,
            2 => {
                output[0] = input[0].re + input[1].re;
                output[1] = input[0].re - input[1].re;
            }
            _ if n % 2 == 0 => self.execute_even(input, output),
            _ => self.execute_odd(input, output),
        }
    }

    /// Like [`execute`](Self::execute), then divides every sample by `n`, so a
    /// forward real DFT followed by this call recovers the original signal.
    ///
    /// # Panics
    ///
    /// Same conditions as [`execute`](Self::execute).
    pub fn execute_normalized(&self, input: &[Complex<T>], output: &mut [T]) {
        self.execute(input, output);
        let n = T::from_usize(self.n);
        for x in output.iter_mut() {
            *x = *x / n;
        }
    }

    /// Even-length fast path (`n >= 4`, `n` even).
    ///
    /// Packs the half-spectrum into a length-`n/2` complex sequence whose
    /// inverse transform is `n · (x[2m] + i·x[2m+1])`, then de-interleaves.
    fn execute_even(&self, input: &[Complex<T>], output: &mut [T]) {
        let half_n = self.n / 2;
        let plan = match Plan::dft_1d(half_n, Direction::Backward, Flags::ESTIMATE) {
            Some(plan) => plan,
            None => {
                // Previously this case silently produced all zeros.
                self.execute_direct(input, output);
                return;
            }
        };
        // z[k] = (X[k] + conj X[n/2-k]) + i·w_k·(X[k] - conj X[n/2-k]), w_k = twiddles[k].
        // For k = 0, the DC and Nyquist bins are taken as purely real.
        let mut z = vec![Complex::zero(); half_n];
        z[0] = Complex::new(
            input[0].re + input[half_n].re,
            input[0].re - input[half_n].re,
        );
        for k in 1..half_n {
            let xk = input[k];
            let xn_k = input[half_n - k].conj();
            let sum = xk + xn_k;
            let diff = xk - xn_k;
            // i·diff, written out component-wise.
            let i_diff = Complex::new(-diff.im, diff.re);
            z[k] = sum + i_diff * self.twiddles[k];
        }
        let mut z_ifft = vec![Complex::zero(); half_n];
        plan.execute(&z, &mut z_ifft);
        // Real parts are the even samples and imaginary parts the odd samples.
        for (pair, value) in output.chunks_exact_mut(2).zip(&z_ifft) {
            pair[0] = value.re;
            pair[1] = value.im;
        }
    }

    /// Odd-length path (`n >= 3`, `n` odd).
    ///
    /// Rebuilds all `n` bins from Hermitian symmetry and runs one size-`n`
    /// complex inverse DFT, keeping the real parts.
    fn execute_odd(&self, input: &[Complex<T>], output: &mut [T]) {
        let n = self.n;
        let plan = match Plan::dft_1d(n, Direction::Backward, Flags::ESTIMATE) {
            Some(plan) => plan,
            None => {
                self.execute_direct(input, output);
                return;
            }
        };
        let mut spectrum = vec![Complex::zero(); n];
        spectrum[0] = input[0];
        // For odd n, k <= (n-1)/2 < n-k, so the two writes never collide.
        for (k, &bin) in input.iter().enumerate().skip(1) {
            spectrum[k] = bin;
            spectrum[n - k] = bin.conj();
        }
        let mut signal = vec![Complex::zero(); n];
        plan.execute(&spectrum, &mut signal);
        // Only the real parts are kept. An imaginary DC component can only
        // affect the imaginary outputs.
        for (out, value) in output.iter_mut().zip(&signal) {
            *out = value.re;
        }
    }

    /// Direct O(n²) evaluation of the unnormalized Hermitian inverse DFT.
    ///
    /// Used only when no FFT plan is available. The output matches the FFT
    /// paths' conventions, and valid for any `n >= 1`.
    fn execute_direct(&self, input: &[Complex<T>], output: &mut [T]) {
        let n = self.n;
        let n_t = T::from_usize(n);
        // Bins 1..=paired have a distinct mirror n-k and so contribute twice. For
        // even n, the Nyquist bin n/2 is its own mirror and is handled separately.
        let paired = n.saturating_sub(1) / 2;
        for (j, out) in output.iter_mut().enumerate() {
            let mut acc = input[0].re;
            // (k * j) mod n, advanced incrementally so the angle stays in [0, 2π).
            let mut phase = 0;
            for &bin in &input[1..=paired] {
                phase = (phase + j) % n;
                let angle = <T as Float>::TWO_PI * T::from_usize(phase) / n_t;
                let re = (bin * Complex::cis(angle)).re;
                acc = acc + re + re;
            }
            if n % 2 == 0 {
                // e^{iπj} = (-1)^j exactly, so no trig is needed.
                let nyquist = input[n / 2].re;
                acc = if j % 2 == 0 { acc + nyquist } else { acc - nyquist };
            }
            *out = acc;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rdft::solvers::R2cSolver;

    /// Absolute-tolerance comparison of two floats.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    /// Runs `original` through `R2cSolver` and back through
    /// `C2rSolver::execute_normalized`, returning the recovered samples.
    fn roundtrip(original: &[f64]) -> Vec<f64> {
        let len = original.len();
        let mut spectrum = vec![Complex::zero(); len / 2 + 1];
        let mut recovered = vec![0.0_f64; len];
        R2cSolver::new(len).execute(original, &mut spectrum);
        C2rSolver::new(len).execute_normalized(&spectrum, &mut recovered);
        recovered
    }

    #[test]
    fn test_c2r_size_2() {
        let original = [1.0_f64, 2.0];
        let recovered = roundtrip(&original);
        for (&a, &b) in original.iter().zip(&recovered) {
            assert!(approx_eq(a, b, 1e-10), "got {b}, expected {a}");
        }
    }

    #[test]
    fn test_c2r_size_4() {
        let original = [1.0_f64, 2.0, 3.0, 4.0];
        let recovered = roundtrip(&original);
        for (&a, &b) in original.iter().zip(&recovered) {
            assert!(approx_eq(a, b, 1e-10), "got {b}, expected {a}");
        }
    }

    #[test]
    fn test_c2r_size_8() {
        let original: Vec<f64> = (0..8_i32).map(f64::from).collect();
        let recovered = roundtrip(&original);
        for (i, (&a, &b)) in original.iter().zip(&recovered).enumerate() {
            assert!(
                approx_eq(a, b, 1e-9),
                "Mismatch at {i}: got {b}, expected {a}"
            );
        }
    }

    #[test]
    fn test_c2r_size_16() {
        let original: Vec<f64> = (0..16_i32).map(|i| f64::from(i).sin()).collect();
        let recovered = roundtrip(&original);
        for (i, (&a, &b)) in original.iter().zip(&recovered).enumerate() {
            assert!(
                approx_eq(a, b, 1e-9),
                "Mismatch at {i}: got {b}, expected {a}"
            );
        }
    }

    #[test]
    fn test_c2r_roundtrip_random() {
        let original: Vec<f64> = (0..32_i32)
            .map(|i| {
                let t = f64::from(i);
                (t * 0.7).sin() + (t * 1.3).cos()
            })
            .collect();
        let recovered = roundtrip(&original);
        for (i, (&a, &b)) in original.iter().zip(&recovered).enumerate() {
            assert!(
                approx_eq(a, b, 1e-9),
                "Mismatch at {i}: got {b}, expected {a}"
            );
        }
    }
}