use std::f32::consts::FRAC_PI_4;
use std::ptr;
/// Real-input FFT of a fixed power-of-two length: a port of the `rdft`
/// routine from Takuya Ooura's `fft4g.c`.
///
/// Every table is computed once by [`Fft4g::new`]. [`Fft4g::rdft`] and
/// [`Fft4g::irdft`] then transform caller-owned buffers in place, without
/// allocating.
#[derive(Debug, Clone)]
pub struct Fft4g {
    /// Transform length in `f32` values; every buffer passed in must have it.
    n: usize,
    /// Length of the twiddle table stored at `w[..nw]` (`n / 4`).
    nw: usize,
    /// Length of the cosine table stored at `w[nw..]` (`n / 4`).
    nc: usize,
    /// Bit-reversal offset table produced by `build_bitrv_table`.
    bitrv_ip: Vec<usize>,
    /// Table size `m` reported by `build_bitrv_table`.
    bitrv_m: usize,
    /// Swap-schedule selector reported by `build_bitrv_table`.
    bitrv_long: bool,
    /// `n / 2` values: twiddles from `makewt`, then cosines from `makect`.
    w: Vec<f32>,
}

impl Fft4g {
    /// Precomputes the twiddle, cosine and bit-reversal tables for
    /// transforms of length `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n < 2` or if `n` is not a power of two.
    pub fn new(n: usize) -> Self {
        assert!(n >= 2, "FFT size must be >= 2, got {n}");
        assert!(
            n.is_power_of_two(),
            "FFT size must be a power of 2, got {n}"
        );
        let quarter = n >> 2;
        let (nw, nc) = (quarter, quarter);

        // Scratch index buffer for the table builders (Ooura's `ip`).
        // Its size is 2 + 2^(floor(log2(n/2)) / 2), but never below 4.
        // It is discarded once the tables exist.
        let root_bits = (n / 2).ilog2() / 2;
        let mut scratch = vec![0_usize; usize::max(4, 2 + (1_usize << root_bits))];

        let mut w = vec![0.0_f32; n / 2];
        if nw > 0 {
            makewt(nw, &mut scratch, &mut w);
        }
        if nc > 0 {
            makect(nc, &mut scratch, &mut w[nw..]);
        }

        let (bitrv_ip, bitrv_m, bitrv_long) = build_bitrv_table(n);
        Self { n, nw, nc, bitrv_ip, bitrv_m, bitrv_long, w }
    }

    /// Forward real FFT, computed in place.
    ///
    /// On return:
    /// - `a[0]` is the DC term and `a[1]` the Nyquist term.
    /// - For `1 <= k < n/2`, `a[2k]` and `a[2k+1]` hold bin `k`.
    ///
    /// Ooura's sign convention applies: `a[2k] = Σ x[j]·cos(2πjk/n)` and
    /// `a[2k+1] = Σ x[j]·sin(2πjk/n)`.
    ///
    /// # Panics
    ///
    /// Panics if `a.len()` differs from the size given to [`Fft4g::new`].
    pub fn rdft(&self, a: &mut [f32]) {
        assert_eq!(a.len(), self.n, "input length must be {}", self.n);
        match self.n {
            // A single 2-point complex butterfly; the bit reversal is a no-op.
            4 => cftfsub(4, a, &self.w),
            n if n > 4 => {
                apply_bitrv2(&self.bitrv_ip, self.bitrv_m, self.bitrv_long, a);
                cftfsub(n, a, &self.w);
                rftfsub(n, a, self.nc, &self.w[self.nw..]);
            }
            // n == 2: only the DC/Nyquist fold below is needed.
            _ => {}
        }
        let (dc, nyquist) = (a[0], a[1]);
        a[0] = dc + nyquist;
        a[1] = dc - nyquist;
    }

    /// Inverse of [`Fft4g::rdft`], computed in place, with the same packed
    /// layout on input.
    ///
    /// The result is **not** normalised: `irdft(rdft(x))` yields `x * n / 2`.
    /// Multiply by `2.0 / n` to recover `x`.
    ///
    /// # Panics
    ///
    /// Panics if `a.len()` differs from the size given to [`Fft4g::new`].
    pub fn irdft(&self, a: &mut [f32]) {
        assert_eq!(a.len(), self.n, "input length must be {}", self.n);
        // Unfold DC/Nyquist back into the first complex bin.
        let half_diff = 0.5 * (a[0] - a[1]);
        a[1] = half_diff;
        a[0] -= half_diff;
        match self.n {
            4 => cftfsub(4, a, &self.w),
            n if n > 4 => {
                rftbsub(n, a, self.nc, &self.w[self.nw..]);
                apply_bitrv2(&self.bitrv_ip, self.bitrv_m, self.bitrv_long, a);
                cftbsub(n, a, &self.w);
            }
            _ => {}
        }
    }
}
/// Builds the twiddle table `w[..nw]` (Ooura `makewt`) and records the table
/// size in `ip[0]` (and `1` in `ip[1]`).
///
/// For `nw > 2`, `w` receives `(cos θ, sin θ)` pairs for θ = j·δ, where
/// `δ = (π/4) / (nw/2)`, in its first half. The second half gets the mirrored
/// `(sin θ, cos θ)` pairs. When `nw / 2 > 2`, the table is then put in
/// bit-reversed order by `bitrv2`, which uses `ip[2..]` as scratch.
fn makewt(nw: usize, ip: &mut [usize], w: &mut [f32]) {
    ip[0] = nw;
    ip[1] = 1;
    if nw <= 2 {
        return;
    }

    let half = nw >> 1;
    let step = FRAC_PI_4 / half as f32;
    let diagonal = (step * half as f32).cos();
    w[0] = 1.0;
    w[1] = 0.0;
    w[half] = diagonal;
    w[half + 1] = diagonal;
    if half <= 2 {
        return;
    }

    for j in (2..half).step_by(2) {
        let angle = step * j as f32;
        let cos_a = angle.cos();
        let sin_a = angle.sin();
        w[j] = cos_a;
        w[j + 1] = sin_a;
        w[nw - j] = sin_a;
        w[nw - j + 1] = cos_a;
    }
    bitrv2(nw, &mut ip[2..], w);
}
/// Builds the cosine table `c[..nc]` used by `rftfsub`/`rftbsub`
/// (Ooura `makect`) and records `nc` in `ip[1]`.
///
/// With `h = nc / 2` and `δ = (π/4) / h`:
/// - `c[0] = cos(h·δ)` and `c[h] = c[0] / 2`;
/// - `c[j] = cos(j·δ) / 2` and `c[nc - j] = sin(j·δ) / 2` for `1 <= j < h`.
///
/// Does nothing beyond setting `ip[1]` when `nc <= 1`.
fn makect(nc: usize, ip: &mut [usize], c: &mut [f32]) {
    ip[1] = nc;
    if nc <= 1 {
        return;
    }
    let half = nc >> 1;
    let step = FRAC_PI_4 / half as f32;
    let edge = (step * half as f32).cos();
    c[0] = edge;
    c[half] = 0.5 * edge;
    for j in 1..half {
        let angle = step * j as f32;
        c[j] = 0.5 * angle.cos();
        c[nc - j] = 0.5 * angle.sin();
    }
}
/// Precomputes the offset table `apply_bitrv2` needs for a transform of `n`
/// floats (`n / 2` interleaved complex values). This is the table-building
/// half of `bitrv2`.
///
/// Returns `(ip, m, use_long_swap)`:
/// * `ip`: `m` starting offsets. `ip[0] = 0`, and each doubling round appends
///   `ip[j] + l` (for `j < m`) with the current, halved `l`.
/// * `m`: number of offsets (a power of two, `O(sqrt(n))`).
/// * `use_long_swap`: `true` when the loop stopped with `8 * m == l`. This
///   selects the four-swap schedule in `apply_bitrv2`.
///
/// The table is grown to exactly `m` entries rather than allocated with `n`
/// slots and truncated. `Vec::truncate` keeps the capacity, and this table
/// lives as long as the owning `Fft4g`. `apply_bitrv2` only reads `ip[..m]`.
fn build_bitrv_table(n: usize) -> (Vec<usize>, usize, bool) {
    let mut ip = vec![0_usize];
    let mut l = n;
    let mut m = 1_usize;
    while (m << 3) < l {
        l >>= 1;
        // Append ip[m + j] = ip[j] + l for j in 0..m.
        ip.extend_from_within(..m);
        for offset in &mut ip[m..] {
            *offset += l;
        }
        m <<= 1;
    }
    debug_assert_eq!(ip.len(), m);
    let use_long_swap = (m << 3) == l;
    (ip, m, use_long_swap)
}
/// Swaps `a[i]` and `a[j]` without bounds checking.
///
/// # Safety
///
/// Both `i` and `j` must be less than `a.len()`. `i == j` is allowed.
#[inline(always)]
unsafe fn swap_unchecked(a: &mut [f32], i: usize, j: usize) {
    debug_assert!(i < a.len() && j < a.len());
    // Take a single base pointer. Calling `as_mut_ptr` twice would reborrow
    // the slice twice, and under Stacked Borrows the second reborrow
    // invalidates the pointer derived from the first one.
    let base = a.as_mut_ptr();
    // SAFETY: the caller guarantees `i, j < a.len()`, so both pointers lie
    // inside the slice's allocation. `ptr::swap` permits equal pointers.
    unsafe {
        ptr::swap(base.add(i), base.add(j));
    }
}
/// Reads `a[i]` without a bounds check (checked only in debug builds).
///
/// # Safety
///
/// `i` must be less than `a.len()`.
#[inline(always)]
unsafe fn get(a: &[f32], i: usize) -> f32 {
    debug_assert!(i < a.len());
    // SAFETY: the caller promises `i < a.len()`, so the read is in bounds.
    unsafe { a.as_ptr().add(i).read() }
}
/// Writes `v` to `a[i]` without a bounds check (checked only in debug builds).
///
/// # Safety
///
/// `i` must be less than `a.len()`.
#[inline(always)]
unsafe fn set(a: &mut [f32], i: usize, v: f32) {
    debug_assert!(i < a.len());
    // SAFETY: the caller promises `i < a.len()`, so the write is in bounds.
    unsafe { a.as_mut_ptr().add(i).write(v) }
}
/// Applies the bit-reversal permutation described by a `build_bitrv_table`
/// result to `a`, viewed as interleaved (re, im) pairs.
///
/// This is the swap half of `bitrv2`, with the offset table precomputed and
/// the swaps unchecked. Only `ip[..m]` is read. Soundness relies on `a` having
/// the length `n` the table was built for. The callers in this file
/// (`Fft4g::rdft`/`irdft`) assert `a.len() == self.n` before calling.
fn apply_bitrv2(ip: &[usize], m: usize, use_long_swap: bool, a: &mut [f32]) {
let m2 = 2 * m;
// `use_long_swap` mirrors `(m << 3) == l` in the table builder. That case uses
// four complex-pair swaps per (j, k) plus one extra pair swap per k.
if use_long_swap {
for k in 0..m {
for j in 0..k {
let j1 = 2 * j + ip[k];
let k1 = 2 * k + ip[j];
// SAFETY: same index schedule as `bitrv2`, whose indices stay below the
// `n` the table was built for; callers pass `a.len() == n`.
unsafe {
swap_unchecked(a, j1, k1);
swap_unchecked(a, j1 + 1, k1 + 1);
let j1 = j1 + m2;
let k1 = k1 + 2 * m2;
swap_unchecked(a, j1, k1);
swap_unchecked(a, j1 + 1, k1 + 1);
let j1 = j1 + m2;
let k1 = k1 - m2;
swap_unchecked(a, j1, k1);
swap_unchecked(a, j1 + 1, k1 + 1);
let j1 = j1 + m2;
let k1 = k1 + 2 * m2;
swap_unchecked(a, j1, k1);
swap_unchecked(a, j1 + 1, k1 + 1);
}
}
let j1 = 2 * k + m2 + ip[k];
let k1 = j1 + m2;
// SAFETY: as above (identical to the trailing swap in `bitrv2`).
unsafe {
swap_unchecked(a, j1, k1);
swap_unchecked(a, j1 + 1, k1 + 1);
}
}
} else {
// Short schedule: two complex-pair swaps per (j, k), k >= 1.
for k in 1..m {
for j in 0..k {
let j1 = 2 * j + ip[k];
let k1 = 2 * k + ip[j];
// SAFETY: as above (identical to the short branch of `bitrv2`).
unsafe {
swap_unchecked(a, j1, k1);
swap_unchecked(a, j1 + 1, k1 + 1);
let j1 = j1 + m2;
let k1 = k1 + m2;
swap_unchecked(a, j1, k1);
swap_unchecked(a, j1 + 1, k1 + 1);
}
}
}
}
}
/// In-place bit-reversal permutation of the `n / 2` complex values stored as
/// interleaved (re, im) pairs in `a[..n]` (Ooura `bitrv2`).
///
/// `ip` receives the offset table (the same construction as
/// `build_bitrv_table`) and must be long enough to hold it. All indexing here
/// is bounds-checked, so a short `ip` or `a` panics rather than misbehaving.
fn bitrv2(n: usize, ip: &mut [usize], a: &mut [f32]) {
    // Offset table: each round halves `span` and appends shifted copies.
    ip[0] = 0;
    let mut span = n;
    let mut m = 1_usize;
    while (m << 3) < span {
        span >>= 1;
        for j in 0..m {
            ip[m + j] = ip[j] + span;
        }
        m <<= 1;
    }

    let m2 = m * 2;
    // Swap the complex value starting at float index `p` with the one at `q`.
    let swap_pair = |buf: &mut [f32], p: usize, q: usize| {
        buf.swap(p, q);
        buf.swap(p + 1, q + 1);
    };

    if (m << 3) == span {
        for k in 0..m {
            for j in 0..k {
                let mut p = 2 * j + ip[k];
                let mut q = 2 * k + ip[j];
                swap_pair(a, p, q);
                p += m2;
                q += 2 * m2;
                swap_pair(a, p, q);
                p += m2;
                q -= m2;
                swap_pair(a, p, q);
                p += m2;
                q += 2 * m2;
                swap_pair(a, p, q);
            }
            let p = 2 * k + m2 + ip[k];
            swap_pair(a, p, p + m2);
        }
    } else {
        for k in 1..m {
            for j in 0..k {
                let p = 2 * j + ip[k];
                let q = 2 * k + ip[j];
                swap_pair(a, p, q);
                swap_pair(a, p + m2, q + m2);
            }
        }
    }
}
/// Forward complex-FFT butterfly passes (Ooura `cftfsub`) over `a[..n]`,
/// viewed as `n / 2` interleaved (re, im) pairs.
///
/// The input is expected in bit-reversed order. `Fft4g::rdft` runs
/// `apply_bitrv2` first when `n > 4`; for `n == 4` that permutation is the
/// identity. For `n > 8`, this runs `cft1st`, then `cftmdl` with
/// `l = 8, 32, 128, ...` while `4 * l < n`. The last pass is radix-4 when
/// `4 * l == n`, otherwise radix-2 (in which case `l == n / 2`).
///
/// Element accesses are unchecked. They stay in bounds only when
/// `a.len() >= n` and `w` is the `makewt` table for this `n`.
fn cftfsub(n: usize, a: &mut [f32], w: &[f32]) {
let mut l = 2;
if n > 8 {
cft1st(n, a, w);
l = 8;
while (l << 2) < n {
cftmdl(n, l, a, w);
l <<= 2;
}
}
// SAFETY: every index below is `j + c*l` or `j + c*l + 1` with `j < l`.
// In the radix-4 pass c <= 3 and 4*l == n; in the radix-2 pass c <= 1 and
// 2*l == n. Either way the index is < n <= a.len().
unsafe {
if (l << 2) == n {
// Final radix-4 pass: no twiddles needed.
for j in (0..l).step_by(2) {
let j1 = j + l;
let j2 = j1 + l;
let j3 = j2 + l;
let x0r = get(a, j) + get(a, j1);
let x0i = get(a, j + 1) + get(a, j1 + 1);
let x1r = get(a, j) - get(a, j1);
let x1i = get(a, j + 1) - get(a, j1 + 1);
let x2r = get(a, j2) + get(a, j3);
let x2i = get(a, j2 + 1) + get(a, j3 + 1);
let x3r = get(a, j2) - get(a, j3);
let x3i = get(a, j2 + 1) - get(a, j3 + 1);
set(a, j, x0r + x2r);
set(a, j + 1, x0i + x2i);
set(a, j2, x0r - x2r);
set(a, j2 + 1, x0i - x2i);
set(a, j1, x1r - x3i);
set(a, j1 + 1, x1i + x3r);
set(a, j3, x1r + x3i);
set(a, j3 + 1, x1i - x3r);
}
} else {
// Final radix-2 pass (l == n / 2).
for j in (0..l).step_by(2) {
let j1 = j + l;
let x0r = get(a, j) - get(a, j1);
let x0i = get(a, j + 1) - get(a, j1 + 1);
set(a, j, get(a, j) + get(a, j1));
set(a, j + 1, get(a, j + 1) + get(a, j1 + 1));
set(a, j1, x0r);
set(a, j1 + 1, x0i);
}
}
}
}
/// Backward complex-FFT butterfly passes (Ooura `cftbsub`) over `a[..n]`,
/// viewed as `n / 2` interleaved (re, im) pairs.
///
/// The stage structure is identical to `cftfsub` (`cft1st`, then `cftmdl`
/// passes, then a final radix-4 or radix-2 pass). The only difference is the
/// final pass, where every imaginary output is the negation of what the
/// forward final pass would produce. `Fft4g::irdft` calls this after
/// `rftbsub` and `apply_bitrv2`.
///
/// Element accesses are unchecked. They stay in bounds only when
/// `a.len() >= n` and `w` is the `makewt` table for this `n`.
fn cftbsub(n: usize, a: &mut [f32], w: &[f32]) {
let mut l = 2;
if n > 8 {
cft1st(n, a, w);
l = 8;
while (l << 2) < n {
cftmdl(n, l, a, w);
l <<= 2;
}
}
// SAFETY: same index bounds as the final pass of `cftfsub`: every index
// is < n <= a.len().
unsafe {
if (l << 2) == n {
// Final radix-4 pass with conjugated (negated) imaginary outputs.
for j in (0..l).step_by(2) {
let j1 = j + l;
let j2 = j1 + l;
let j3 = j2 + l;
let x0r = get(a, j) + get(a, j1);
let x0i = -get(a, j + 1) - get(a, j1 + 1);
let x1r = get(a, j) - get(a, j1);
let x1i = -get(a, j + 1) + get(a, j1 + 1);
let x2r = get(a, j2) + get(a, j3);
let x2i = get(a, j2 + 1) + get(a, j3 + 1);
let x3r = get(a, j2) - get(a, j3);
let x3i = get(a, j2 + 1) - get(a, j3 + 1);
set(a, j, x0r + x2r);
set(a, j + 1, x0i - x2i);
set(a, j2, x0r - x2r);
set(a, j2 + 1, x0i + x2i);
set(a, j1, x1r - x3i);
set(a, j1 + 1, x1i - x3r);
set(a, j3, x1r + x3i);
set(a, j3 + 1, x1i + x3r);
}
} else {
// Final radix-2 pass (l == n / 2) with negated imaginary outputs.
for j in (0..l).step_by(2) {
let j1 = j + l;
let x0r = get(a, j) - get(a, j1);
let x0i = -get(a, j + 1) + get(a, j1 + 1);
set(a, j, get(a, j) + get(a, j1));
set(a, j + 1, -get(a, j + 1) - get(a, j1 + 1));
set(a, j1, x0r);
set(a, j1 + 1, x0i);
}
}
}
}
/// First radix-4 butterfly pass of the complex FFT (Ooura `cft1st`).
///
/// The pass walks `a[..n]` in blocks of 16 floats (8 interleaved complex
/// values). The first block (`a[0..16]`) is special-cased:
/// - its first half needs no twiddles;
/// - its second half only needs rotations by multiples of π/4, using
///   `wk1r = w[2]`. That is cos(π/4) in the table `makewt` builds, where
///   bit reversal moves `w[nw / 2]` to index 2.
///
/// Every later block at offset `j` reads `wk2` from `w[k1..k1 + 2]` and `wk1`
/// from `w[k2..k2 + 2]` (first half) and `w[k2 + 2..k2 + 4]` (second half). It
/// derives `wk3` from those. Complex multiplies use `mul_add`, so results can
/// differ in the last bits from a non-fused implementation.
///
/// Only called when `n > 8`. Accesses are unchecked, so the following must
/// hold:
/// - `a.len() >= n`;
/// - `n` is a multiple of 16 (the loop touches `a[j + 15]` for each `j < n`);
/// - `w` has at least `n / 4` entries (the highest twiddle index read is
///   `n / 4 - 1`).
fn cft1st(n: usize, a: &mut [f32], w: &[f32]) {
// SAFETY: see the preconditions above; `cftfsub`/`cftbsub` call this only
// for power-of-two n >= 16 with `a.len() == n` and the full `makewt` table.
unsafe {
// Block a[0..8]: radix-4 butterfly with trivial twiddles.
let x0r = get(a, 0) + get(a, 2);
let x0i = get(a, 1) + get(a, 3);
let x1r = get(a, 0) - get(a, 2);
let x1i = get(a, 1) - get(a, 3);
let x2r = get(a, 4) + get(a, 6);
let x2i = get(a, 5) + get(a, 7);
let x3r = get(a, 4) - get(a, 6);
let x3i = get(a, 5) - get(a, 7);
set(a, 0, x0r + x2r);
set(a, 1, x0i + x2i);
set(a, 4, x0r - x2r);
set(a, 5, x0i - x2i);
set(a, 2, x1r - x3i);
set(a, 3, x1i + x3r);
set(a, 6, x1r + x3i);
set(a, 7, x1i - x3r);
// Block a[8..16]: outputs rotated by i and by wk1r * (±1 + i).
let wk1r = get(w, 2);
let x0r = get(a, 8) + get(a, 10);
let x0i = get(a, 9) + get(a, 11);
let x1r = get(a, 8) - get(a, 10);
let x1i = get(a, 9) - get(a, 11);
let x2r = get(a, 12) + get(a, 14);
let x2i = get(a, 13) + get(a, 15);
let x3r = get(a, 12) - get(a, 14);
let x3i = get(a, 13) - get(a, 15);
set(a, 8, x0r + x2r);
set(a, 9, x0i + x2i);
set(a, 12, x2i - x0i);
set(a, 13, x0r - x2r);
let x0r = x1r - x3i;
let x0i = x1i + x3r;
set(a, 10, wk1r * (x0r - x0i));
set(a, 11, wk1r * (x0r + x0i));
let x0r = x3i + x1r;
let x0i = x3r - x1i;
set(a, 14, wk1r * (x0i - x0r));
set(a, 15, wk1r * (x0i + x0r));
// Remaining 16-float blocks, each with its own twiddles from `w`.
let mut k1 = 0_usize;
let mut j = 16;
while j < n {
k1 += 2;
let k2 = 2 * k1;
let wk2r = get(w, k1);
let wk2i = get(w, k1 + 1);
let wk1r = get(w, k2);
let wk1i = get(w, k2 + 1);
let wk3r = (-2.0 * wk2i).mul_add(wk1i, wk1r);
let wk3i = (2.0 * wk2i).mul_add(wk1r, -wk1i);
let x0r = get(a, j) + get(a, j + 2);
let x0i = get(a, j + 1) + get(a, j + 3);
let x1r = get(a, j) - get(a, j + 2);
let x1i = get(a, j + 1) - get(a, j + 3);
let x2r = get(a, j + 4) + get(a, j + 6);
let x2i = get(a, j + 5) + get(a, j + 7);
let x3r = get(a, j + 4) - get(a, j + 6);
let x3i = get(a, j + 5) - get(a, j + 7);
set(a, j, x0r + x2r);
set(a, j + 1, x0i + x2i);
let x0r = x0r - x2r;
let x0i = x0i - x2i;
set(a, j + 4, wk2r.mul_add(x0r, -wk2i * x0i));
set(a, j + 5, wk2r.mul_add(x0i, wk2i * x0r));
let x0r = x1r - x3i;
let x0i = x1i + x3r;
set(a, j + 2, wk1r.mul_add(x0r, -wk1i * x0i));
set(a, j + 3, wk1r.mul_add(x0i, wk1i * x0r));
let x0r = x1r + x3i;
let x0i = x1i - x3r;
set(a, j + 6, wk3r.mul_add(x0r, -wk3i * x0i));
set(a, j + 7, wk3r.mul_add(x0i, wk3i * x0r));
// Second half of the block: next wk1 pair; wk2 enters rotated (-wk2i, wk2r).
let wk1r = get(w, k2 + 2);
let wk1i = get(w, k2 + 3);
let wk3r = (-2.0 * wk2r).mul_add(wk1i, wk1r);
let wk3i = (2.0 * wk2r).mul_add(wk1r, -wk1i);
let x0r = get(a, j + 8) + get(a, j + 10);
let x0i = get(a, j + 9) + get(a, j + 11);
let x1r = get(a, j + 8) - get(a, j + 10);
let x1i = get(a, j + 9) - get(a, j + 11);
let x2r = get(a, j + 12) + get(a, j + 14);
let x2i = get(a, j + 13) + get(a, j + 15);
let x3r = get(a, j + 12) - get(a, j + 14);
let x3i = get(a, j + 13) - get(a, j + 15);
set(a, j + 8, x0r + x2r);
set(a, j + 9, x0i + x2i);
let x0r = x0r - x2r;
let x0i = x0i - x2i;
set(a, j + 12, (-wk2i).mul_add(x0r, -wk2r * x0i));
set(a, j + 13, (-wk2i).mul_add(x0i, wk2r * x0r));
let x0r = x1r - x3i;
let x0i = x1i + x3r;
set(a, j + 10, wk1r.mul_add(x0r, -wk1i * x0i));
set(a, j + 11, wk1r.mul_add(x0i, wk1i * x0r));
let x0r = x1r + x3i;
let x0i = x1i - x3r;
set(a, j + 14, wk3r.mul_add(x0r, -wk3i * x0i));
set(a, j + 15, wk3r.mul_add(x0i, wk3i * x0r));
j += 16;
}
}
}
/// One middle radix-4 pass of the complex FFT (Ooura `cftmdl`), with
/// butterfly distance `l` (in floats).
///
/// Work is organised in groups of `m = 4 * l` floats:
/// - group 0 needs no twiddles;
/// - group 1 (starting at `m`) only uses `w[2]`;
/// - each following pair of groups, starting at `k` and stepping by
///   `m2 = 2 * m`, reads twiddles from `w[k1..k1 + 2]` and `w[k2..k2 + 4]`.
///
/// Accesses are unchecked, so the following must hold:
/// - `a.len() >= n`;
/// - `n` is a multiple of `8 * l`. This holds for the power-of-two sizes
///   `cftfsub`/`cftbsub` pass, which only call this while `4 * l < n`;
/// - `w` covers index `n / (2 * l) - 1`.
fn cftmdl(n: usize, l: usize, a: &mut [f32], w: &[f32]) {
let m = l << 2;
// SAFETY: see the preconditions above. The highest `a` index touched by the
// group starting at `k` is below `k + m2 <= n`.
unsafe {
// Group 0: twiddle-free radix-4 butterflies.
for j in (0..l).step_by(2) {
let j1 = j + l;
let j2 = j1 + l;
let j3 = j2 + l;
let x0r = get(a, j) + get(a, j1);
let x0i = get(a, j + 1) + get(a, j1 + 1);
let x1r = get(a, j) - get(a, j1);
let x1i = get(a, j + 1) - get(a, j1 + 1);
let x2r = get(a, j2) + get(a, j3);
let x2i = get(a, j2 + 1) + get(a, j3 + 1);
let x3r = get(a, j2) - get(a, j3);
let x3i = get(a, j2 + 1) - get(a, j3 + 1);
set(a, j, x0r + x2r);
set(a, j + 1, x0i + x2i);
set(a, j2, x0r - x2r);
set(a, j2 + 1, x0i - x2i);
set(a, j1, x1r - x3i);
set(a, j1 + 1, x1i + x3r);
set(a, j3, x1r + x3i);
set(a, j3 + 1, x1i - x3r);
}
// Group 1 (starting at m): rotations by i and wk1r * (±1 + i) only.
let wk1r = get(w, 2);
for j in (m..l + m).step_by(2) {
let j1 = j + l;
let j2 = j1 + l;
let j3 = j2 + l;
let x0r = get(a, j) + get(a, j1);
let x0i = get(a, j + 1) + get(a, j1 + 1);
let x1r = get(a, j) - get(a, j1);
let x1i = get(a, j + 1) - get(a, j1 + 1);
let x2r = get(a, j2) + get(a, j3);
let x2i = get(a, j2 + 1) + get(a, j3 + 1);
let x3r = get(a, j2) - get(a, j3);
let x3i = get(a, j2 + 1) - get(a, j3 + 1);
set(a, j, x0r + x2r);
set(a, j + 1, x0i + x2i);
set(a, j2, x2i - x0i);
set(a, j2 + 1, x0r - x2r);
let x0r = x1r - x3i;
let x0i = x1i + x3r;
set(a, j1, wk1r * (x0r - x0i));
set(a, j1 + 1, wk1r * (x0r + x0i));
let x0r = x3i + x1r;
let x0i = x3r - x1i;
set(a, j3, wk1r * (x0i - x0r));
set(a, j3 + 1, wk1r * (x0i + x0r));
}
// Remaining groups, two at a time, each pair with its own twiddles.
let mut k1 = 0_usize;
let m2 = 2 * m;
let mut k = m2;
while k < n {
k1 += 2;
let k2 = 2 * k1;
let wk2r = get(w, k1);
let wk2i = get(w, k1 + 1);
let wk1r = get(w, k2);
let wk1i = get(w, k2 + 1);
let wk3r = (-2.0 * wk2i).mul_add(wk1i, wk1r);
let wk3i = (2.0 * wk2i).mul_add(wk1r, -wk1i);
for j in (k..l + k).step_by(2) {
let j1 = j + l;
let j2 = j1 + l;
let j3 = j2 + l;
let x0r = get(a, j) + get(a, j1);
let x0i = get(a, j + 1) + get(a, j1 + 1);
let x1r = get(a, j) - get(a, j1);
let x1i = get(a, j + 1) - get(a, j1 + 1);
let x2r = get(a, j2) + get(a, j3);
let x2i = get(a, j2 + 1) + get(a, j3 + 1);
let x3r = get(a, j2) - get(a, j3);
let x3i = get(a, j2 + 1) - get(a, j3 + 1);
set(a, j, x0r + x2r);
set(a, j + 1, x0i + x2i);
let x0r = x0r - x2r;
let x0i = x0i - x2i;
set(a, j2, wk2r.mul_add(x0r, -wk2i * x0i));
set(a, j2 + 1, wk2r.mul_add(x0i, wk2i * x0r));
let x0r = x1r - x3i;
let x0i = x1i + x3r;
set(a, j1, wk1r.mul_add(x0r, -wk1i * x0i));
set(a, j1 + 1, wk1r.mul_add(x0i, wk1i * x0r));
let x0r = x1r + x3i;
let x0i = x1i - x3r;
set(a, j3, wk3r.mul_add(x0r, -wk3i * x0i));
set(a, j3 + 1, wk3r.mul_add(x0i, wk3i * x0r));
}
// Second group of the pair (starting at k + m): next wk1, wk2 rotated.
let wk1r = get(w, k2 + 2);
let wk1i = get(w, k2 + 3);
let wk3r = (-2.0 * wk2r).mul_add(wk1i, wk1r);
let wk3i = (2.0 * wk2r).mul_add(wk1r, -wk1i);
for j in (k + m..l + (k + m)).step_by(2) {
let j1 = j + l;
let j2 = j1 + l;
let j3 = j2 + l;
let x0r = get(a, j) + get(a, j1);
let x0i = get(a, j + 1) + get(a, j1 + 1);
let x1r = get(a, j) - get(a, j1);
let x1i = get(a, j + 1) - get(a, j1 + 1);
let x2r = get(a, j2) + get(a, j3);
let x2i = get(a, j2 + 1) + get(a, j3 + 1);
let x3r = get(a, j2) - get(a, j3);
let x3i = get(a, j2 + 1) - get(a, j3 + 1);
set(a, j, x0r + x2r);
set(a, j + 1, x0i + x2i);
let x0r = x0r - x2r;
let x0i = x0i - x2i;
set(a, j2, (-wk2i).mul_add(x0r, -wk2r * x0i));
set(a, j2 + 1, (-wk2i).mul_add(x0i, wk2r * x0r));
let x0r = x1r - x3i;
let x0i = x1i + x3r;
set(a, j1, wk1r.mul_add(x0r, -wk1i * x0i));
set(a, j1 + 1, wk1r.mul_add(x0i, wk1i * x0r));
let x0r = x1r + x3i;
let x0i = x1i - x3r;
set(a, j3, wk3r.mul_add(x0r, -wk3i * x0i));
set(a, j3 + 1, wk3r.mul_add(x0i, wk3i * x0r));
}
k += m2;
}
}
}
/// Real-FFT forward post-processing (Ooura `rftfsub`).
///
/// After `cftfsub` has transformed `a[..n]` as `n / 2` complex values, this
/// combines each bin `lo` (even, `2 <= lo < n / 2`) with its mirror `n - lo`,
/// using the cosine table `c` from `makect`. Step `lo` uses the twiddles
/// `0.5 - c[nc - kk]` and `c[kk]`, where `kk` grows by `2 * nc / (n / 2)` per
/// step.
fn rftfsub(n: usize, a: &mut [f32], nc: usize, c: &[f32]) {
    let half = n >> 1;
    let stride = 2 * nc / half;
    let mut ci = 0;
    for lo in (2..half).step_by(2) {
        let hi = n - lo;
        ci += stride;
        // SAFETY: lo is in [2, half) and hi = n - lo is in (half, n - 2], so
        // all four `a` indices are < n <= a.len(). With the nc = n / 4 that
        // `Fft4g` uses, `ci` and `nc - ci` both stay in [1, nc).
        unsafe {
            let wkr = 0.5 - get(c, nc - ci);
            let wki = get(c, ci);
            let lo_re = get(a, lo);
            let lo_im = get(a, lo + 1);
            let hi_re = get(a, hi);
            let hi_im = get(a, hi + 1);
            let xr = lo_re - hi_re;
            let xi = lo_im + hi_im;
            let yr = wkr.mul_add(xr, -wki * xi);
            let yi = wkr.mul_add(xi, wki * xr);
            set(a, lo, lo_re - yr);
            set(a, lo + 1, lo_im - yi);
            set(a, hi, hi_re + yr);
            set(a, hi + 1, hi_im - yi);
        }
    }
}
/// Real-FFT inverse pre-processing (Ooura `rftbsub`), the counterpart of
/// `rftfsub`.
///
/// Negates `a[1]`, then recombines each bin `lo` (even, `2 <= lo < n / 2`)
/// with its mirror `n - lo` using the cosine table `c`. Finally negates
/// `a[n / 2 + 1]`. This leaves `a` in the form `cftbsub` expects (after the
/// bit-reversal pass `irdft` runs in between).
fn rftbsub(n: usize, a: &mut [f32], nc: usize, c: &[f32]) {
    let half = n >> 1;
    let stride = 2 * nc / half;
    // SAFETY: `irdft` only calls this with n > 4 and a.len() == n.
    unsafe { set(a, 1, -get(a, 1)) };
    let mut ci = 0;
    for lo in (2..half).step_by(2) {
        let hi = n - lo;
        ci += stride;
        // SAFETY: lo is in [2, half) and hi = n - lo is in (half, n - 2], so
        // all four `a` indices are < n <= a.len(). With the nc = n / 4 that
        // `Fft4g` uses, `ci` and `nc - ci` both stay in [1, nc).
        unsafe {
            let wkr = 0.5 - get(c, nc - ci);
            let wki = get(c, ci);
            let lo_re = get(a, lo);
            let lo_im = get(a, lo + 1);
            let hi_re = get(a, hi);
            let hi_im = get(a, hi + 1);
            let xr = lo_re - hi_re;
            let xi = lo_im + hi_im;
            let yr = wkr.mul_add(xr, wki * xi);
            let yi = wkr.mul_add(xi, -wki * xr);
            set(a, lo, lo_re - yr);
            set(a, lo + 1, yi - lo_im);
            set(a, hi, hi_re + yr);
            set(a, hi + 1, yi - hi_im);
        }
    }
    // SAFETY: half + 1 < n whenever n > 4.
    unsafe { set(a, half + 1, -get(a, half + 1)) };
}
// Unit and property tests for `Fft4g`. Round-trip checks rely on
// `irdft(rdft(x))` returning `x` scaled by `n / 2`, hence the `2.0 / n`
// factor applied before comparing.
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use test_strategy::proptest;
use super::*;
// Forward then inverse transform of a sampled sine at n = 256 recovers it.
#[test]
fn roundtrip_256() {
let fft = Fft4g::new(256);
let mut a: Vec<f32> = (0..256).map(|i| (i as f32 * 0.05).sin()).collect();
let original = a.clone();
fft.rdft(&mut a);
fft.irdft(&mut a);
let scale = 2.0 / 256.0;
for (i, (&o, &r)) in original.iter().zip(a.iter()).enumerate() {
let recovered = r * scale;
assert!(
(o - recovered).abs() < 1e-4,
"mismatch at {i}: original={o}, recovered={recovered}"
);
}
}
// Round trip for n = 4 (special-cased butterfly) through 512 (several
// `cftmdl` passes).
#[test]
fn roundtrip_multiple_sizes() {
for &n in &[4, 8, 16, 32, 64, 128, 256, 512] {
let fft = Fft4g::new(n);
let mut a: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1).cos()).collect();
let original = a.clone();
fft.rdft(&mut a);
fft.irdft(&mut a);
let scale = 2.0 / n as f32;
for (i, (&o, &r)) in original.iter().zip(a.iter()).enumerate() {
let recovered = r * scale;
assert!(
(o - recovered).abs() < 1e-3,
"size {n}, index {i}: original={o}, recovered={recovered}"
);
}
}
}
// A unit impulse at index 0 has a flat spectrum: DC, Nyquist and every
// bin's real part equal 1, and every imaginary part is 0.
#[test]
fn impulse_256() {
let fft = Fft4g::new(256);
let mut a = vec![0.0_f32; 256];
a[0] = 1.0;
fft.rdft(&mut a);
assert!((a[0] - 1.0).abs() < 1e-6, "DC = {}", a[0]);
assert!((a[1] - 1.0).abs() < 1e-6, "Nyquist = {}", a[1]);
for k in 1..128 {
assert!((a[2 * k] - 1.0).abs() < 1e-5, "bin {k} real = {}", a[2 * k]);
assert!(a[2 * k + 1].abs() < 1e-5, "bin {k} imag = {}", a[2 * k + 1]);
}
}
// Parseval for the packed layout (a[0] = DC, a[1] = Nyquist, a[2k]/a[2k+1] =
// bin k): DC^2 + Nyquist^2 + 2 * sum |bin k|^2 == n * sum x^2.
#[test]
fn parseval_energy() {
let n = 256;
let fft = Fft4g::new(n);
let mut a: Vec<f32> = (0..n).map(|i| (i as f32 * 0.2).sin()).collect();
let time_energy: f32 = a.iter().map(|x| x * x).sum();
fft.rdft(&mut a);
let dc_sq = a[0] * a[0];
let nyq_sq = a[1] * a[1];
let mut freq_energy = dc_sq + nyq_sq;
for k in 1..n / 2 {
freq_energy += 2.0 * (a[2 * k] * a[2 * k] + a[2 * k + 1] * a[2 * k + 1]);
}
let expected = n as f32 * time_energy;
assert!(
(freq_energy - expected).abs() / expected < 1e-4,
"Parseval: freq_energy={freq_energy}, expected={expected}"
);
}
// All-zero input must map to exactly zero (no spurious values from tables).
#[test]
fn zero_input() {
let fft = Fft4g::new(64);
let mut a = vec![0.0_f32; 64];
fft.rdft(&mut a);
for (i, &v) in a.iter().enumerate() {
assert_eq!(v, 0.0, "expected zero at {i}, got {v}");
}
}
// Constructor preconditions: the panic messages come from `Fft4g::new`.
#[test]
#[should_panic(expected = "power of 2")]
fn rejects_non_power_of_two() {
let _ = Fft4g::new(100);
}
#[test]
#[should_panic(expected = ">= 2")]
fn rejects_size_one() {
let _ = Fft4g::new(1);
}
// Maps a generated exponent to the transform size 2^exp (2..=9 -> 4..=512).
fn fft_size_from_exp(exp: u32) -> usize {
1 << exp
}
// Property tests. In the strategies, `#exp` makes each vector's length depend
// on the generated `exp` (test_strategy's dependent-argument syntax).
// Round trip recovers random signals in [-1, 1) for every size 4..=512.
#[proptest]
fn roundtrip_recovers_signal(
#[strategy(2..=9u32)] exp: u32,
#[strategy(prop::collection::vec(-1.0f32..1.0, 1 << #exp as usize))] signal: Vec<f32>,
) {
let n = fft_size_from_exp(exp);
let fft = Fft4g::new(n);
let mut a = signal.clone();
fft.rdft(&mut a);
fft.irdft(&mut a);
let scale = 2.0 / n as f32;
for (i, (&o, &r)) in signal.iter().zip(a.iter()).enumerate() {
prop_assert!(
(o - r * scale).abs() < 1e-3,
"size {n}, index {i}: original={o}, recovered={}",
r * scale
);
}
}
// rdft(a + b) == rdft(a) + rdft(b), element-wise within tolerance.
#[proptest]
fn linearity_holds(
#[strategy(2..=9u32)] exp: u32,
#[strategy(prop::collection::vec(-1.0f32..1.0, 1 << #exp as usize))] sig_a: Vec<f32>,
#[strategy(prop::collection::vec(-1.0f32..1.0, 1 << #exp as usize))] sig_b: Vec<f32>,
) {
let n = fft_size_from_exp(exp);
let fft = Fft4g::new(n);
let mut a = sig_a.clone();
let mut b = sig_b.clone();
let mut sum: Vec<f32> = sig_a.iter().zip(sig_b.iter()).map(|(x, y)| x + y).collect();
fft.rdft(&mut a);
fft.rdft(&mut b);
fft.rdft(&mut sum);
for (i, ((&fa, &fb), &fs)) in a.iter().zip(b.iter()).zip(sum.iter()).enumerate() {
let expected = fa + fb;
prop_assert!(
(fs - expected).abs() < 1e-3,
"size {n}, index {i}: FFT(a+b)={fs}, FFT(a)+FFT(b)={expected}"
);
}
}
// Parseval on random signals. The relative check is skipped for near-zero
// energy, where dividing by `expected` would be meaningless.
#[proptest]
fn parseval_energy_conservation(
#[strategy(2..=9u32)] exp: u32,
#[strategy(prop::collection::vec(-1.0f32..1.0, 1 << #exp as usize))] signal: Vec<f32>,
) {
let n = fft_size_from_exp(exp);
let fft = Fft4g::new(n);
let time_energy: f32 = signal.iter().map(|x| x * x).sum();
let mut a = signal;
fft.rdft(&mut a);
let dc_sq = a[0] * a[0];
let nyq_sq = a[1] * a[1];
let mut freq_energy = dc_sq + nyq_sq;
for k in 1..n / 2 {
freq_energy += 2.0 * (a[2 * k] * a[2 * k] + a[2 * k + 1] * a[2 * k + 1]);
}
let expected = n as f32 * time_energy;
if expected > 1e-6 {
prop_assert!(
(freq_energy - expected).abs() / expected < 1e-3,
"Parseval: freq_energy={freq_energy}, expected={expected}"
);
}
}
}