use crate::kernel::{Complex, Float};
#[allow(unused_imports)]
use crate::prelude::*;
/// Unpacks a packed real-spectrum ("half-complex") buffer of length `n` into
/// `n / 2 + 1` complex bins.
///
/// The packed layout consumed here is interleaved:
///
/// ```text
/// n odd : [r0, r1, i1, r2, i2, ..., r_m, i_m]           m = (n - 1) / 2
/// n even: [r0, r1, i1, r2, i2, ..., r_m, i_m, r_{n/2}]  m = (n - 1) / 2
/// ```
///
/// and is expanded to `[c0, c1, ..., c_{n/2}]` with `c_k = r_k + i*i_k`.
/// `c0` (and `c_{n/2}` when `n` is even) get a zero imaginary part.
///
/// NOTE(review): this interleaved order differs from FFTW's halfcomplex order
/// (`r0, r1, ..., r_{n/2}, i_{(n+1)/2-1}, ..., i1`) even though the solver
/// name mirrors FFTW's naming — confirm producers of the buffer use this order.
#[derive(Debug, Clone, Copy)]
pub struct Hc2cSolver<T: Float> {
    /// Length of the packed real buffer.
    n: usize,
    _marker: core::marker::PhantomData<T>,
}

impl<T: Float> Default for Hc2cSolver<T> {
    /// Creates a solver for the trivial length `n = 1`.
    fn default() -> Self {
        Self::new(1)
    }
}

impl<T: Float> Hc2cSolver<T> {
    /// Creates a solver for packed buffers of length `n`.
    #[must_use]
    pub fn new(n: usize) -> Self {
        Self {
            n,
            _marker: core::marker::PhantomData,
        }
    }

    /// Returns this solver's identifier.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "rdft-hc2c"
    }

    /// Returns the packed (real) length `n` this solver was built for.
    #[must_use]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the number of complex bins produced, `n / 2 + 1`.
    ///
    /// This is `1` even when `n == 0`; in that case [`execute`](Self::execute)
    /// still requires a one-element output slice but leaves it untouched.
    #[must_use]
    pub fn output_len(&self) -> usize {
        self.n / 2 + 1
    }

    /// Expands `halfcomplex` (length `n`) into `complex` (length `n / 2 + 1`).
    ///
    /// # Panics
    ///
    /// Panics if `halfcomplex.len() != n` or `complex.len() != n / 2 + 1`.
    pub fn execute(&self, halfcomplex: &[T], complex: &mut [Complex<T>]) {
        assert_eq!(
            halfcomplex.len(),
            self.n,
            "Half-complex input must have length n"
        );
        assert_eq!(
            complex.len(),
            self.output_len(),
            "Complex output must have length n/2+1"
        );

        // n == 0: there is no DC term to unpack, so the output is left as-is.
        let Some((&dc, rest)) = halfcomplex.split_first() else {
            return;
        };
        complex[0] = Complex::new(dc, T::ZERO);

        // After the DC term the buffer is a run of (re, im) pairs, followed by
        // one purely-real Nyquist term when n is even. `chunks_exact(2)` yields
        // exactly those pairs and leaves the Nyquist term (if any) in the
        // remainder, so no index arithmetic is needed.
        let pairs = rest.chunks_exact(2);
        let nyquist = pairs.remainder();
        for (bin, pair) in complex[1..].iter_mut().zip(pairs) {
            *bin = Complex::new(pair[0], pair[1]);
        }
        if let &[re] = nyquist {
            complex[self.n / 2] = Complex::new(re, T::ZERO);
        }
    }
}
/// Packs `n / 2 + 1` complex bins into a real ("half-complex") buffer of
/// length `n` — the inverse of the layout unpacked by `Hc2cSolver`.
///
/// The packed layout written here is interleaved:
///
/// ```text
/// n odd : [r0, r1, i1, r2, i2, ..., r_m, i_m]           m = (n - 1) / 2
/// n even: [r0, r1, i1, r2, i2, ..., r_m, i_m, r_{n/2}]  m = (n - 1) / 2
/// ```
///
/// where `c_k = r_k + i*i_k`. The imaginary part of `c0`, and of `c_{n/2}`
/// when `n` is even, has no slot in this layout and is discarded.
#[derive(Debug, Clone, Copy)]
pub struct C2hcSolver<T: Float> {
    /// Length of the packed real buffer.
    n: usize,
    _marker: core::marker::PhantomData<T>,
}

impl<T: Float> Default for C2hcSolver<T> {
    /// Creates a solver for the trivial length `n = 1`.
    fn default() -> Self {
        Self::new(1)
    }
}

impl<T: Float> C2hcSolver<T> {
    /// Creates a solver for packed buffers of length `n`.
    #[must_use]
    pub fn new(n: usize) -> Self {
        Self {
            n,
            _marker: core::marker::PhantomData,
        }
    }

    /// Returns this solver's identifier.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "rdft-c2hc"
    }

    /// Returns the packed (real) length `n` this solver was built for.
    #[must_use]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the number of complex bins consumed, `n / 2 + 1`.
    ///
    /// This is `1` even when `n == 0`; in that case [`execute`](Self::execute)
    /// still requires a one-element input slice but does not read it.
    #[must_use]
    pub fn input_len(&self) -> usize {
        self.n / 2 + 1
    }

    /// Packs `complex` (length `n / 2 + 1`) into `halfcomplex` (length `n`).
    ///
    /// # Panics
    ///
    /// Panics if `complex.len() != n / 2 + 1` or `halfcomplex.len() != n`.
    pub fn execute(&self, complex: &[Complex<T>], halfcomplex: &mut [T]) {
        assert_eq!(
            complex.len(),
            self.input_len(),
            "Complex input must have length n/2+1"
        );
        assert_eq!(
            halfcomplex.len(),
            self.n,
            "Half-complex output must have length n"
        );

        // n == 0: there is no DC slot to fill.
        let Some((dc, rest)) = halfcomplex.split_first_mut() else {
            return;
        };
        *dc = complex[0].re;

        // After the DC slot the buffer is a run of (re, im) pair slots,
        // followed by one real Nyquist slot when n is even; the Nyquist slot is
        // what `chunks_exact_mut(2)` leaves in its remainder.
        let mut pairs = rest.chunks_exact_mut(2);
        for (pair, bin) in pairs.by_ref().zip(&complex[1..]) {
            pair[0] = bin.re;
            pair[1] = bin.im;
        }
        if let [re] = pairs.into_remainder() {
            *re = complex[self.n / 2].re;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-10;

    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        approx_eq(a.re, b.re, eps) && approx_eq(a.im, b.im, eps)
    }

    /// Unpacks `1..=n` with `Hc2cSolver`, repacks with `C2hcSolver`, and
    /// asserts the original buffer comes back unchanged.
    fn assert_roundtrip(n: usize) {
        let hc2c = Hc2cSolver::<f64>::new(n);
        let c2hc = C2hcSolver::<f64>::new(n);
        let original: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        let mut complex = vec![Complex::zero(); hc2c.output_len()];
        let mut recovered = vec![0.0_f64; n];
        hc2c.execute(&original, &mut complex);
        c2hc.execute(&complex, &mut recovered);
        for (a, b) in original.iter().zip(recovered.iter()) {
            assert!(approx_eq(*a, *b, EPS), "n = {n}: {a} != {b}");
        }
    }

    #[test]
    fn test_hc2c_size_1_dc_only() {
        let solver = Hc2cSolver::<f64>::new(1);
        assert_eq!(solver.output_len(), 1);
        let hc = [7.0];
        let mut complex = vec![Complex::zero(); 1];
        solver.execute(&hc, &mut complex);
        assert!(complex_approx_eq(complex[0], Complex::new(7.0, 0.0), EPS));
    }

    #[test]
    fn test_hc2c_size_2_dc_and_nyquist_only() {
        let solver = Hc2cSolver::<f64>::new(2);
        assert_eq!(solver.output_len(), 2);
        let hc = [1.0, 2.0];
        let mut complex = vec![Complex::zero(); 2];
        solver.execute(&hc, &mut complex);
        assert!(complex_approx_eq(complex[0], Complex::new(1.0, 0.0), EPS));
        assert!(complex_approx_eq(complex[1], Complex::new(2.0, 0.0), EPS));
    }

    #[test]
    fn test_hc2c_size_4_even() {
        let solver = Hc2cSolver::<f64>::new(4);
        assert_eq!(solver.output_len(), 3);
        let hc = [1.0, 2.0, 3.0, 4.0];
        let mut complex = vec![Complex::zero(); 3];
        solver.execute(&hc, &mut complex);
        assert!(complex_approx_eq(complex[0], Complex::new(1.0, 0.0), EPS));
        assert!(complex_approx_eq(complex[1], Complex::new(2.0, 3.0), EPS));
        assert!(complex_approx_eq(complex[2], Complex::new(4.0, 0.0), EPS));
    }

    #[test]
    fn test_hc2c_size_5_odd() {
        let solver = Hc2cSolver::<f64>::new(5);
        assert_eq!(solver.output_len(), 3);
        let hc = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mut complex = vec![Complex::zero(); 3];
        solver.execute(&hc, &mut complex);
        assert!(complex_approx_eq(complex[0], Complex::new(1.0, 0.0), EPS));
        assert!(complex_approx_eq(complex[1], Complex::new(2.0, 3.0), EPS));
        assert!(complex_approx_eq(complex[2], Complex::new(4.0, 5.0), EPS));
    }

    #[test]
    #[should_panic(expected = "Half-complex input must have length n")]
    fn test_hc2c_rejects_wrong_input_len() {
        let solver = Hc2cSolver::<f64>::new(4);
        let mut complex = vec![Complex::zero(); 3];
        solver.execute(&[1.0, 2.0, 3.0], &mut complex);
    }

    #[test]
    #[should_panic(expected = "Complex output must have length n/2+1")]
    fn test_hc2c_rejects_wrong_output_len() {
        let solver = Hc2cSolver::<f64>::new(4);
        let mut complex = vec![Complex::zero(); 2];
        solver.execute(&[1.0, 2.0, 3.0, 4.0], &mut complex);
    }

    #[test]
    fn test_hc2c_c2hc_roundtrip() {
        assert_roundtrip(8);
    }

    #[test]
    fn test_hc2c_c2hc_roundtrip_small_sizes() {
        // Covers the empty, DC-only, and both odd and even lengths.
        for n in 0..=9 {
            assert_roundtrip(n);
        }
    }

    #[test]
    fn test_c2hc_size_4() {
        let solver = C2hcSolver::<f64>::new(4);
        assert_eq!(solver.input_len(), 3);
        let complex = [
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 3.0),
            Complex::new(4.0, 0.0),
        ];
        let mut hc = vec![0.0_f64; 4];
        solver.execute(&complex, &mut hc);
        assert!(approx_eq(hc[0], 1.0, EPS));
        assert!(approx_eq(hc[1], 2.0, EPS));
        assert!(approx_eq(hc[2], 3.0, EPS));
        assert!(approx_eq(hc[3], 4.0, EPS));
    }

    #[test]
    fn test_c2hc_size_5_odd() {
        let solver = C2hcSolver::<f64>::new(5);
        assert_eq!(solver.input_len(), 3);
        let complex = [
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 3.0),
            Complex::new(4.0, 5.0),
        ];
        let mut hc = vec![0.0_f64; 5];
        solver.execute(&complex, &mut hc);
        for (got, want) in hc.iter().zip([1.0, 2.0, 3.0, 4.0, 5.0]) {
            assert!(approx_eq(*got, want, EPS), "{got} != {want}");
        }
    }

    #[test]
    fn test_c2hc_discards_dc_and_nyquist_imaginary_parts() {
        let solver = C2hcSolver::<f64>::new(4);
        let complex = [
            Complex::new(1.0, 10.0),
            Complex::new(2.0, 3.0),
            Complex::new(4.0, 20.0),
        ];
        let mut hc = vec![0.0_f64; 4];
        solver.execute(&complex, &mut hc);
        for (got, want) in hc.iter().zip([1.0, 2.0, 3.0, 4.0]) {
            assert!(approx_eq(*got, want, EPS), "{got} != {want}");
        }
    }
}