#![allow(clippy::items_after_statements)]
use crate::dft::codelets::{
notw_1024_dispatch, notw_128_dispatch, notw_16_dispatch, notw_256_dispatch, notw_2_dispatch,
notw_32_dispatch, notw_4096_dispatch, notw_4_dispatch, notw_512_dispatch, notw_64_dispatch,
notw_8_dispatch,
};
use crate::kernel::{Complex, Float};
use crate::prelude::*;
use super::super::problem::Sign;
use super::simd_butterfly::dit_butterflies_f64;
/// Selects which power-of-two FFT algorithm a [`CooleyTukeySolver`] runs.
///
/// Sizes that have a dedicated codelet are handled by that codelet in
/// `execute` / `execute_inplace` regardless of the variant chosen here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CtVariant {
/// Bit-reversal permutation followed by single butterfly stages.
Dit,
/// Single butterfly stages followed by a bit-reversal permutation.
Dif,
/// Bit-reversal permutation followed by butterfly stages fused two at a time.
DitRadix4,
/// Bit-reversal permutation followed by butterfly stages fused three at a time.
DitRadix8,
/// Recursive split-radix pass followed by a bit-reversal permutation.
SplitRadix,
}
/// FFT solver for power-of-two lengths, generic over the scalar type `T`.
pub struct CooleyTukeySolver<T: Float> {
/// Algorithm used for sizes that are not served by a dedicated codelet.
pub variant: CtVariant,
// Ties the solver to `T` without storing a value of that type.
_marker: core::marker::PhantomData<T>,
}
/// The default solver runs the [`CtVariant::Dit`] algorithm.
impl<T: Float> Default for CooleyTukeySolver<T> {
    fn default() -> Self {
        let variant = CtVariant::Dit;
        Self::new(variant)
    }
}
impl<T: Float> CooleyTukeySolver<T> {
/// Creates a solver that runs the given algorithm `variant`.
#[must_use]
pub fn new(variant: CtVariant) -> Self {
    Self {
        _marker: core::marker::PhantomData,
        variant,
    }
}
#[must_use]
pub fn name(&self) -> &'static str {
match self.variant {
CtVariant::Dit => "dft-ct-dit",
CtVariant::Dif => "dft-ct-dif",
CtVariant::DitRadix4 => "dft-ct-dit-radix4",
CtVariant::DitRadix8 => "dft-ct-dit-radix8",
CtVariant::SplitRadix => "dft-ct-split-radix",
}
}
/// Returns `true` when `n` is 8^k for some k >= 0 (1, 8, 64, 512, ...).
///
/// Zero is not a power of eight.
#[must_use]
pub fn is_power_of_8(n: usize) -> bool {
    // A power of two whose exponent is divisible by three.
    n.is_power_of_two() && n.trailing_zeros().is_multiple_of(3)
}
/// Returns `true` when `n` is 4^k for some k >= 0 (1, 4, 16, 64, ...).
///
/// Zero is not a power of four.
#[must_use]
pub fn is_power_of_4(n: usize) -> bool {
    // A power of two whose exponent is even.
    n.is_power_of_two() && n.trailing_zeros().is_multiple_of(2)
}
/// Returns `true` when the solver can handle length `n`, i.e. `n` is a
/// non-zero power of two.
#[must_use]
pub fn applicable(n: usize) -> bool {
    n.is_power_of_two()
}
/// Transforms `input` into `output` (out of place) in the direction `sign`.
///
/// * Length 0 is a no-op and length 1 is a plain copy.
/// * Lengths with a dedicated no-twiddle codelet (2 through 1024, and 4096)
///   are copied into `output` and handed to [`Self::execute_inplace`], which
///   owns the single codelet dispatch table.
/// * Every other length is processed by the algorithm selected in
///   `self.variant`.
///
/// `input` and `output` must have the same length, and lengths above 1 must be
/// a power of two; both conditions are checked with debug assertions only.
pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
    let n = input.len();
    debug_assert_eq!(n, output.len());
    match n {
        // Trivial sizes are settled before the power-of-two assertion so an
        // empty transform is accepted in debug builds as well.
        0 => {}
        1 => output[0] = input[0],
        // Keep this list aligned with the codelet arms in `execute_inplace`.
        // A size missing from either side still produces a correct result
        // through the general path, only without the codelet.
        2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 | 4096 => {
            output.copy_from_slice(input);
            self.execute_inplace(output, sign);
        }
        _ => {
            debug_assert!(Self::applicable(n), "Size must be power of 2");
            match self.variant {
                CtVariant::Dit => self.execute_dit(input, output, sign),
                CtVariant::Dif => self.execute_dif(input, output, sign),
                CtVariant::DitRadix4 => self.execute_dit_radix4(input, output, sign),
                CtVariant::DitRadix8 => self.execute_dit_radix8(input, output, sign),
                CtVariant::SplitRadix => self.execute_split_radix(input, output, sign),
            }
        }
    }
}

/// Transforms `data` in place in the direction `sign`.
///
/// Lengths 0 and 1 are left untouched. Lengths with a dedicated no-twiddle
/// codelet are dispatched to it; all other lengths run the in-place form of
/// the algorithm selected in `self.variant`.
///
/// Lengths above 1 must be a power of two (debug assertion only).
pub fn execute_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
    let n = data.len();
    if n <= 1 {
        return;
    }
    debug_assert!(Self::applicable(n), "Size must be power of 2");
    let sign_int = sign.value();
    match n {
        2 => {
            notw_2_dispatch(data);
        }
        4 => {
            notw_4_dispatch(data, sign_int);
        }
        8 => {
            notw_8_dispatch(data, sign_int);
        }
        16 => {
            notw_16_dispatch(data, sign_int);
        }
        32 => {
            notw_32_dispatch(data, sign_int);
        }
        64 => {
            notw_64_dispatch(data, sign_int);
        }
        128 => {
            notw_128_dispatch(data, sign_int);
        }
        256 => {
            notw_256_dispatch(data, sign_int);
        }
        512 => {
            notw_512_dispatch(data, sign_int);
        }
        1024 => {
            notw_1024_dispatch(data, sign_int);
        }
        4096 => {
            notw_4096_dispatch(data, sign_int);
        }
        // No codelet for this size (this includes 2048): use the variant.
        _ => match self.variant {
            CtVariant::Dit => self.execute_dit_inplace(data, sign),
            CtVariant::Dif => self.execute_dif_inplace(data, sign),
            CtVariant::DitRadix4 => self.execute_dit_radix4_inplace(data, sign),
            CtVariant::DitRadix8 => self.execute_dit_radix8_inplace(data, sign),
            CtVariant::SplitRadix => self.execute_split_radix_inplace(data, sign),
        },
    }
}
/// Out-of-place DIT transform: scatters `input` into `output` in bit-reversed
/// order, then runs the butterfly stages on `output`.
fn execute_dit(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
    let len = input.len();
    for (src, &value) in input.iter().enumerate() {
        output[bit_reverse(src, len)] = value;
    }
    self.dit_butterflies(output, sign);
}
/// In-place DIT transform: bit-reversal permutation followed by the
/// butterfly stages.
///
/// Slices of length 0 or 1 are returned unchanged.
pub fn execute_dit_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
    // This method is public and can be reached without the size guard in
    // `execute_inplace`. For an empty slice `len().trailing_zeros()` equals
    // the word width, so the scalar butterfly loop would run that many
    // stages on a doubling span, overflow it, and panic.
    if data.len() <= 1 {
        return;
    }
    bit_reverse_permute(data);
    self.dit_butterflies(data, sign);
}
/// Runs the DIT butterfly stages on bit-reversed `data`.
///
/// When `T` is `f64` the work is forwarded to `dit_butterflies_f64`;
/// every other scalar type uses the generic scalar loop.
fn dit_butterflies(&self, data: &mut [Complex<T>], sign: Sign) {
    if core::any::TypeId::of::<T>() == core::any::TypeId::of::<f64>() {
        // SAFETY: the `TypeId` comparison above proves that `T` is exactly
        // `f64`, so `[Complex<T>]` and `[Complex<f64>]` are one and the same
        // type. The cast only renames the pointee; length, alignment and the
        // exclusive borrow of `data` are carried over unchanged.
        let data_f64: &mut [Complex<f64>] = unsafe {
            &mut *(core::ptr::from_mut::<[Complex<T>]>(data) as *mut [Complex<f64>])
        };
        dit_butterflies_f64(data_f64, sign);
    } else {
        self.dit_butterflies_scalar(data, sign);
    }
}
/// Generic DIT butterfly stages for any scalar type.
///
/// Expects `data` already in bit-reversed order. Stage spans double from 2
/// up to the slice length; inside each span the twiddle factor is advanced
/// by repeated multiplication with the stage's unit step.
fn dit_butterflies_scalar(&self, data: &mut [Complex<T>], sign: Sign) {
    let stages = data.len().trailing_zeros() as usize;
    let direction = T::from_isize(sign.value() as isize);
    let mut span = 2;
    for _ in 0..stages {
        let half = span / 2;
        let twiddle_step = Complex::cis(direction * T::TWO_PI / T::from_usize(span));
        // `span` is a power of two no larger than the length's power-of-two
        // factor, so it divides the length and no block is left over.
        for block in data.chunks_exact_mut(span) {
            let (lower, upper) = block.split_at_mut(half);
            let mut twiddle = Complex::new(T::ONE, T::ZERO);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let u = *lo;
                let t = *hi * twiddle;
                *lo = u + t;
                *hi = u - t;
                twiddle = twiddle * twiddle_step;
            }
        }
        span *= 2;
    }
}
fn execute_dif(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
output.copy_from_slice(input);
self.dif_butterflies(output, sign);
bit_reverse_permute(output);
}
/// In-place DIF transform: butterfly stages first, then the bit-reversal
/// permutation that puts the result into natural order.
fn execute_dif_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
self.dif_butterflies(data, sign);
bit_reverse_permute(data);
}
/// DIF butterfly stages.
///
/// Stage spans halve from the full length down to 2. Within a span the
/// difference term is multiplied by a twiddle factor that is advanced by
/// repeated multiplication with the stage's unit step.
fn dif_butterflies(&self, data: &mut [Complex<T>], sign: Sign) {
    let stages = data.len().trailing_zeros() as usize;
    let direction = T::from_isize(sign.value() as isize);
    let mut span = data.len();
    for _ in 0..stages {
        let half = span / 2;
        let twiddle_step = Complex::cis(direction * T::TWO_PI / T::from_usize(span));
        // `span` is the length divided by a power of two that the length
        // contains, so it divides the length and no block is left over.
        for block in data.chunks_exact_mut(span) {
            let (lower, upper) = block.split_at_mut(half);
            let mut twiddle = Complex::new(T::ONE, T::ZERO);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let u = *lo;
                let v = *hi;
                *lo = u + v;
                *hi = (u - v) * twiddle;
                twiddle = twiddle * twiddle_step;
            }
        }
        span /= 2;
    }
}
/// Out-of-place radix-4 DIT transform: bit-reversed scatter of `input` into
/// `output`, followed by the pairwise-fused butterfly stages.
fn execute_dit_radix4(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
    let len = input.len();
    input
        .iter()
        .enumerate()
        .for_each(|(src, &value)| output[bit_reverse(src, len)] = value);
    self.dit_radix4_butterflies(output, sign);
}
/// In-place radix-4 DIT transform: bit-reversal permutation followed by the
/// pairwise-fused butterfly stages.
fn execute_dit_radix4_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
bit_reverse_permute(data);
self.dit_radix4_butterflies(data, sign);
}
/// Butterfly stages for the radix-4 DIT variant.
///
/// Expects `data` in bit-reversed order. With `log_n` the number of
/// butterfly levels:
///
/// * if `log_n` is odd, one plain twiddle-free level (add/subtract of
///   neighbouring pairs) is applied first;
/// * the remaining, now even, number of levels is processed two at a time by
///   [`Self::radix4_stage`].
fn dit_radix4_butterflies(&self, data: &mut [Complex<T>], sign: Sign) {
    let log_n = data.len().trailing_zeros() as usize;
    let sign_val = T::from_isize(sign.value() as isize);
    let mut s = 0;
    if log_n % 2 == 1 {
        // An odd `log_n` means the length is even, so every pair is complete.
        for pair in data.chunks_exact_mut(2) {
            let (u, v) = (pair[0], pair[1]);
            pair[0] = u + v;
            pair[1] = u - v;
        }
        s = 1;
    }
    // `s` keeps the parity of `log_n` and grows by two, so the loop ends with
    // `s == log_n` exactly: no single level can be left over afterwards.
    while s + 1 < log_n {
        self.radix4_stage(data, s, sign_val);
        s += 2;
    }
}
/// Out-of-place radix-8 DIT transform: bit-reversed scatter of `input` into
/// `output`, followed by the triple-fused butterfly stages.
fn execute_dit_radix8(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
    let len = input.len();
    for (src, value) in input.iter().copied().enumerate() {
        let dst = bit_reverse(src, len);
        output[dst] = value;
    }
    self.dit_radix8_butterflies(output, sign);
}
/// In-place radix-8 DIT transform: bit-reversal permutation followed by the
/// triple-fused butterfly stages.
fn execute_dit_radix8_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
bit_reverse_permute(data);
self.dit_radix8_butterflies(data, sign);
}
/// Butterfly stages for the radix-8 DIT variant.
///
/// Expects `data` in bit-reversed order. With `log_n` the number of
/// butterfly levels, the levels that do not fit a multiple of three are
/// consumed first:
///
/// * `log_n % 3 == 1`: one plain twiddle-free level (add/subtract of pairs);
/// * `log_n % 3 == 2`: one [`Self::radix4_stage`] starting at level 0.
///
/// Everything that remains is processed three levels at a time by
/// [`Self::radix8_stage`].
fn dit_radix8_butterflies(&self, data: &mut [Complex<T>], sign: Sign) {
    let log_n = data.len().trailing_zeros() as usize;
    let sign_val = T::from_isize(sign.value() as isize);
    let mut s = match log_n % 3 {
        1 => {
            // `log_n >= 1` here, so the length is even and every pair is complete.
            for pair in data.chunks_exact_mut(2) {
                let (u, v) = (pair[0], pair[1]);
                pair[0] = u + v;
                pair[1] = u - v;
            }
            1
        }
        2 => {
            self.radix4_stage(data, 0, sign_val);
            2
        }
        _ => 0,
    };
    // `s` is congruent to `log_n` modulo three and grows by three, so the
    // loop ends with `s == log_n` exactly: no clean-up stage is needed.
    while s + 2 < log_n {
        self.radix8_stage(data, s, sign_val);
        s += 3;
    }
}
/// Applies two consecutive butterfly levels in a single pass over `data`.
///
/// * `s` - index of the first level; the two levels combine sub-transforms
///   into spans of `2^(s+1)` and then `2^(s+2)` elements.
/// * `sign_val` - direction factor multiplied into every twiddle angle.
///
/// Twiddle factors are evaluated with `Complex::cis` for every butterfly.
#[inline]
fn radix4_stage(&self, data: &mut [Complex<T>], s: usize, sign_val: T) {
    let len = data.len();
    let inner_span = 1 << (s + 1);
    let outer_span = 1 << (s + 2);
    let inner_half = inner_span / 2;
    let outer_half = outer_span / 2;
    let inner_step = sign_val * T::TWO_PI / T::from_usize(inner_span);
    let outer_step = sign_val * T::TWO_PI / T::from_usize(outer_span);
    for base in (0..len).step_by(outer_span) {
        for j in 0..inner_half {
            // First level: one twiddle shared by both inner butterflies.
            let w_inner = Complex::cis(inner_step * T::from_usize(j));
            // Second level: one twiddle per output pair.
            let w_outer_lo = Complex::cis(outer_step * T::from_usize(j));
            let w_outer_hi = Complex::cis(outer_step * T::from_usize(j + inner_half));

            let p0 = base + j;
            let p1 = p0 + inner_half;
            let p2 = p0 + outer_half;
            let p3 = p2 + inner_half;

            let upper_term = data[p1] * w_inner;
            let lower_term = data[p3] * w_inner;
            let a0 = data[p0] + upper_term;
            let a1 = data[p0] - upper_term;
            let a2 = data[p2] + lower_term;
            let a3 = data[p2] - lower_term;

            let lo_term = a2 * w_outer_lo;
            let hi_term = a3 * w_outer_hi;
            data[p0] = a0 + lo_term;
            data[p2] = a0 - lo_term;
            data[p1] = a1 + hi_term;
            data[p3] = a1 - hi_term;
        }
    }
}
/// Applies three consecutive butterfly levels in a single pass over `data`.
///
/// * `s` - index of the first level; the three levels combine sub-transforms
///   into spans of `2^(s+1)`, `2^(s+2)` and `2^(s+3)` elements.
/// * `sign_val` - direction factor multiplied into every twiddle angle.
///
/// Seven twiddle tables of `2^s` entries each are allocated and filled on
/// every call, then shared by all blocks of the pass.
#[inline]
fn radix8_stage(&self, data: &mut [Complex<T>], s: usize, sign_val: T) {
let n = data.len();
// Span of each of the three levels.
let m1 = 1 << (s + 1); let m2 = 1 << (s + 2); let m3 = 1 << (s + 3);
// Index offsets between the eight elements of one butterfly group.
let d0 = 1 << s; let d1 = 1 << (s + 1); let d2 = 1 << (s + 2);
let angle_step_1 = sign_val * T::TWO_PI / T::from_usize(m1);
let angle_step_2 = sign_val * T::TWO_PI / T::from_usize(m2);
let angle_step_3 = sign_val * T::TWO_PI / T::from_usize(m3);
// Per-`j` increments of the three base twiddles.
let w1_step = Complex::cis(angle_step_1);
let w2_step = Complex::cis(angle_step_2);
let w3_step = Complex::cis(angle_step_3);
// Constant rotations that turn the twiddle for position `j` into the
// twiddle for position `j + d0`, `j + d1` or `j + d0 + d1`.
let w2_offset_d0 = Complex::cis(angle_step_2 * T::from_usize(d0));
let w3_offset_d0 = Complex::cis(angle_step_3 * T::from_usize(d0));
let w3_offset_d1 = Complex::cis(angle_step_3 * T::from_usize(d1));
let w3_offset_d0_d1 = Complex::cis(angle_step_3 * T::from_usize(d0 + d1));
// Twiddle tables indexed by `j`: one for the first level, two for the
// second and four for the third.
let mut tw1: Vec<Complex<T>> = Vec::with_capacity(d0);
let mut tw2_0: Vec<Complex<T>> = Vec::with_capacity(d0);
let mut tw2_1: Vec<Complex<T>> = Vec::with_capacity(d0);
let mut tw3_0: Vec<Complex<T>> = Vec::with_capacity(d0);
let mut tw3_1: Vec<Complex<T>> = Vec::with_capacity(d0);
let mut tw3_2: Vec<Complex<T>> = Vec::with_capacity(d0);
let mut tw3_3: Vec<Complex<T>> = Vec::with_capacity(d0);
let mut w1 = Complex::new(T::ONE, T::ZERO);
let mut w2 = Complex::new(T::ONE, T::ZERO);
let mut w3 = Complex::new(T::ONE, T::ZERO);
// Base twiddles advance by repeated multiplication with their step.
for _ in 0..d0 {
tw1.push(w1);
tw2_0.push(w2);
tw2_1.push(w2 * w2_offset_d0);
tw3_0.push(w3);
tw3_1.push(w3 * w3_offset_d0);
tw3_2.push(w3 * w3_offset_d1);
tw3_3.push(w3 * w3_offset_d0_d1);
w1 = w1 * w1_step;
w2 = w2 * w2_step;
w3 = w3 * w3_step;
}
for k in (0..n).step_by(m3) {
for j in 0..d0 {
// The eight elements of this group; bit b of the suffix selects offset d_b.
let i0 = k + j;
let i1 = k + j + d0;
let i2 = k + j + d1;
let i3 = k + j + d0 + d1;
let i4 = k + j + d2;
let i5 = k + j + d0 + d2;
let i6 = k + j + d1 + d2;
let i7 = k + j + d0 + d1 + d2;
let x0 = data[i0];
let x1 = data[i1];
let x2 = data[i2];
let x3 = data[i3];
let x4 = data[i4];
let x5 = data[i5];
let x6 = data[i6];
let x7 = data[i7];
// First level: pairs that are `d0` apart.
let a0 = x0 + x1 * tw1[j];
let a1 = x0 - x1 * tw1[j];
let a2 = x2 + x3 * tw1[j];
let a3 = x2 - x3 * tw1[j];
let a4 = x4 + x5 * tw1[j];
let a5 = x4 - x5 * tw1[j];
let a6 = x6 + x7 * tw1[j];
let a7 = x6 - x7 * tw1[j];
// Second level: pairs that are `d1` apart.
let b0 = a0 + a2 * tw2_0[j];
let b2 = a0 - a2 * tw2_0[j];
let b1 = a1 + a3 * tw2_1[j];
let b3 = a1 - a3 * tw2_1[j];
let b4 = a4 + a6 * tw2_0[j];
let b6 = a4 - a6 * tw2_0[j];
let b5 = a5 + a7 * tw2_1[j];
let b7 = a5 - a7 * tw2_1[j];
// Third level: pairs that are `d2` apart, written straight back.
data[i0] = b0 + b4 * tw3_0[j];
data[i4] = b0 - b4 * tw3_0[j];
data[i1] = b1 + b5 * tw3_1[j];
data[i5] = b1 - b5 * tw3_1[j];
data[i2] = b2 + b6 * tw3_2[j];
data[i6] = b2 - b6 * tw3_2[j];
data[i3] = b3 + b7 * tw3_3[j];
data[i7] = b3 - b7 * tw3_3[j];
}
}
}
fn execute_split_radix(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
output.copy_from_slice(input);
let sign_val = T::from_isize(sign.value() as isize);
split_radix_dif_recursive(output, sign_val);
bit_reverse_permute(output);
}
/// In-place split-radix transform: recursive split-radix pass followed by
/// the bit-reversal permutation that puts the result into natural order.
fn execute_split_radix_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
// The recursive kernel takes the direction as a scalar of type `T`.
let sign_val = T::from_isize(sign.value() as isize);
split_radix_dif_recursive(data, sign_val);
bit_reverse_permute(data);
}
}
/// Recursive split-radix pass operating in place on `data`.
///
/// One call splits the slice into a half-length sub-problem (stored in the
/// first half) and two quarter-length sub-problems (stored in the third and
/// fourth quarters), then recurses into each. The result is left in the
/// order that the caller's bit-reversal permutation undoes.
///
/// `sign_val` is the direction factor: it scales the twiddle angles and
/// selects the sense of the quarter-turn applied to the odd difference.
fn split_radix_dif_recursive<T: Float>(data: &mut [Complex<T>], sign_val: T) {
    let len = data.len();
    match len {
        0 | 1 => return,
        2 => {
            let (a, b) = (data[0], data[1]);
            data[0] = a + b;
            data[1] = a - b;
            return;
        }
        _ => {}
    }

    let half = len / 2;
    let quarter = len / 4;
    let base_angle = sign_val * T::TWO_PI / T::from_usize(len);

    for k in 0..quarter {
        let w1 = Complex::cis(base_angle * T::from_usize(k));
        let w3 = Complex::cis(base_angle * T::from_usize(3 * k));

        let (p0, p1, p2, p3) = (k, k + quarter, k + half, k + half + quarter);
        let (x0, x1, x2, x3) = (data[p0], data[p1], data[p2], data[p3]);

        let even_diff = x0 - x2;
        let odd_diff = x1 - x3;
        // `odd_diff` multiplied by `sign_val * i`.
        let rotated = Complex::new(-sign_val * odd_diff.im, sign_val * odd_diff.re);

        data[p0] = x0 + x2;
        data[p1] = x1 + x3;
        data[p2] = (even_diff + rotated) * w1;
        data[p3] = (even_diff - rotated) * w3;
    }

    let (first_half, rest) = data.split_at_mut(half);
    let (third_quarter, fourth_quarter) = rest.split_at_mut(quarter);
    split_radix_dif_recursive(first_half, sign_val);
    split_radix_dif_recursive(third_quarter, sign_val);
    split_radix_dif_recursive(fourth_quarter, sign_val);
}
/// Reverses the lowest `log2(n)` bits of `x`.
///
/// The bit count is `n.trailing_zeros()`; bits of `x` above that count are
/// ignored. Returns 0 when the bit count is zero (for example `n == 1`).
#[inline]
fn bit_reverse(x: usize, n: usize) -> usize {
    let bits = n.trailing_zeros();
    if bits == 0 {
        // Nothing to reverse, and shifting by the full word width below
        // would overflow.
        return 0;
    }
    // Reverse the whole word, then move the reversed low bits back down.
    x.reverse_bits() >> (usize::BITS - bits)
}
/// Reorders `data` in place so that the element at index `i` trades places
/// with the element at the bit-reversed index of `i`.
///
/// The number of reversed bits is `data.len().trailing_zeros()`. Slices of
/// length 0 or 1, and slices whose length has no factor of two, are left
/// unchanged.
fn bit_reverse_permute<T: Float>(data: &mut [Complex<T>]) {
    let n = data.len();
    let log_n = n.trailing_zeros();
    // With zero bits every index maps to 0 and nothing would be swapped; the
    // early return also keeps the shift below strictly under the word width.
    if n <= 1 || log_n == 0 {
        return;
    }
    let shift = usize::BITS - log_n;
    for i in 0..n {
        // Reverse the whole word, then move the reversed low bits back down.
        let j = i.reverse_bits() >> shift;
        // Swap each pair once only.
        if i < j {
            data.swap(i, j);
        }
    }
}
/// Out-of-place transform of `input` into `output` with `Sign::Forward`,
/// using a [`CtVariant::Dit`] solver.
///
/// Both slices must have the same power-of-two length.
pub fn fft_radix2<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::Dit);
    solver.execute(input, output, Sign::Forward);
}
/// Out-of-place transform of `input` into `output` with `Sign::Backward`,
/// using a [`CtVariant::Dit`] solver. The result is not rescaled.
///
/// Both slices must have the same power-of-two length.
pub fn ifft_radix2<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::Dit);
    solver.execute(input, output, Sign::Backward);
}
/// Out-of-place transform of `input` into `output` with `Sign::Backward`,
/// using a [`CtVariant::Dit`] solver, after which every output element is
/// divided by the transform length.
///
/// Both slices must have the same power-of-two length.
pub fn ifft_radix2_normalized<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::Dit);
    solver.execute(input, output, Sign::Backward);
    let length = T::from_usize(output.len());
    output.iter_mut().for_each(|value| *value = *value / length);
}
/// In-place transform of `data` with `Sign::Forward`, using a
/// [`CtVariant::Dit`] solver.
///
/// The length of `data` must be a power of two.
pub fn fft_radix2_inplace<T: Float>(data: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::Dit);
    solver.execute_inplace(data, Sign::Forward);
}
/// Out-of-place transform of `input` into `output` with `Sign::Forward`,
/// using a [`CtVariant::DitRadix4`] solver.
///
/// Both slices must have the same power-of-two length.
pub fn fft_radix4<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::DitRadix4);
    solver.execute(input, output, Sign::Forward);
}
/// In-place transform of `data` with `Sign::Forward`, using a
/// [`CtVariant::DitRadix4`] solver.
///
/// The length of `data` must be a power of two.
pub fn fft_radix4_inplace<T: Float>(data: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::DitRadix4);
    solver.execute_inplace(data, Sign::Forward);
}
/// In-place transform of `data` with `Sign::Backward`, using a
/// [`CtVariant::Dit`] solver. The result is not rescaled.
///
/// The length of `data` must be a power of two.
pub fn ifft_radix2_inplace<T: Float>(data: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::Dit);
    solver.execute_inplace(data, Sign::Backward);
}
/// Out-of-place transform of `input` into `output` with `Sign::Forward`,
/// using a [`CtVariant::DitRadix8`] solver.
///
/// Both slices must have the same power-of-two length.
pub fn fft_radix8<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::DitRadix8);
    solver.execute(input, output, Sign::Forward);
}
/// In-place transform of `data` with `Sign::Forward`, using a
/// [`CtVariant::DitRadix8`] solver.
///
/// The length of `data` must be a power of two.
pub fn fft_radix8_inplace<T: Float>(data: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::DitRadix8);
    solver.execute_inplace(data, Sign::Forward);
}
/// Out-of-place transform of `input` into `output` with `Sign::Forward`,
/// using a [`CtVariant::SplitRadix`] solver.
///
/// Both slices must have the same power-of-two length.
pub fn fft_split_radix<T: Float>(input: &[Complex<T>], output: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::SplitRadix);
    solver.execute(input, output, Sign::Forward);
}
/// In-place transform of `data` with `Sign::Forward`, using a
/// [`CtVariant::SplitRadix`] solver.
///
/// The length of `data` must be a power of two.
pub fn fft_split_radix_inplace<T: Float>(data: &mut [Complex<T>]) {
    let solver = CooleyTukeySolver::<T>::new(CtVariant::SplitRadix);
    solver.execute_inplace(data, Sign::Forward);
}
// NOTE(review): every transform length used below (2 through 512) has a
// dedicated codelet in `CooleyTukeySolver::execute` / `execute_inplace`, so
// these tests exercise codelet dispatch rather than the variant-specific
// butterfly code. Consider adding lengths without a codelet (2048, 8192).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dft::solvers::direct::dft_direct;

    /// Signature shared by the out-of-place entry points under test.
    type OutOfPlace = fn(&[Complex<f64>], &mut [Complex<f64>]);
    /// Signature shared by the in-place entry points under test.
    type InPlace = fn(&mut [Complex<f64>]);

    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        approx_eq(a.re, b.re, eps) && approx_eq(a.im, b.im, eps)
    }

    /// Asserts element-wise closeness of two sequences.
    fn assert_all_close(lhs: &[Complex<f64>], rhs: &[Complex<f64>], eps: f64) {
        for (a, b) in lhs.iter().zip(rhs.iter()) {
            assert!(complex_approx_eq(*a, *b, eps));
        }
    }

    /// The real sequence 1, 2, 3, 4.
    fn one_to_four() -> [Complex<f64>; 4] {
        [
            Complex::new(1.0_f64, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0),
        ]
    }

    /// Element `i` is `i + 0.5 i j`.
    fn ramp(len: u32) -> Vec<Complex<f64>> {
        (0..len)
            .map(|i| Complex::new(f64::from(i), f64::from(i) * 0.5))
            .collect()
    }

    /// Element `i` is the real number `i`.
    fn real_ramp(len: u32) -> Vec<Complex<f64>> {
        (0..len).map(|i| Complex::new(f64::from(i), 0.0)).collect()
    }

    /// Element `i` is `sin(i) + cos(i) j`.
    fn trig(len: u32) -> Vec<Complex<f64>> {
        (0..len)
            .map(|i| Complex::new(f64::from(i).sin(), f64::from(i).cos()))
            .collect()
    }

    /// Runs an out-of-place transform into a fresh zeroed buffer.
    fn spectrum(transform: OutOfPlace, input: &[Complex<f64>]) -> Vec<Complex<f64>> {
        let mut output = vec![Complex::zero(); input.len()];
        transform(input, &mut output);
        output
    }

    /// Compares `fft_radix2` with the direct DFT on `input`.
    fn assert_matches_direct(input: &[Complex<f64>], eps: f64) {
        let output_fft = spectrum(fft_radix2, input);
        let mut output_direct = vec![Complex::zero(); input.len()];
        dft_direct(input, &mut output_direct);
        assert_all_close(&output_fft, &output_direct, eps);
    }

    /// Compares `candidate` with `fft_radix2` on `input`.
    fn assert_matches_radix2(candidate: OutOfPlace, input: &[Complex<f64>], eps: f64) {
        let output_radix2 = spectrum(fft_radix2, input);
        let output_candidate = spectrum(candidate, input);
        assert_all_close(&output_radix2, &output_candidate, eps);
    }

    /// Forward then backward transform of `ramp(len)` with `variant`, scaled
    /// by `1 / len`, must reproduce the input.
    fn assert_round_trip(variant: CtVariant, len: u32, eps: f64) {
        let original = ramp(len);
        let mut transformed = vec![Complex::zero(); original.len()];
        let mut recovered = vec![Complex::zero(); original.len()];
        CooleyTukeySolver::new(variant).execute(&original, &mut transformed, Sign::Forward);
        CooleyTukeySolver::new(variant).execute(&transformed, &mut recovered, Sign::Backward);
        let n = f64::from(len);
        for x in &mut recovered {
            *x = *x / n;
        }
        assert_all_close(&original, &recovered, eps);
    }

    /// The in-place entry point must agree with its out-of-place sibling on
    /// `real_ramp(len)`.
    fn assert_inplace_agrees(out_of_place: OutOfPlace, in_place: InPlace, len: u32) {
        let input = real_ramp(len);
        let expected = spectrum(out_of_place, &input);
        let mut actual = input;
        in_place(&mut actual);
        assert_all_close(&expected, &actual, 1e-10);
    }

    #[test]
    fn test_bit_reverse() {
        let expected = [0, 4, 2, 6, 1, 5, 3, 7];
        for (index, reversed) in expected.into_iter().enumerate() {
            assert_eq!(bit_reverse(index, 8), reversed);
        }
    }

    #[test]
    fn test_fft_size_2() {
        let input = [Complex::new(1.0_f64, 0.0), Complex::new(2.0, 0.0)];
        assert_matches_direct(&input, 1e-10);
    }

    #[test]
    fn test_fft_size_4() {
        assert_matches_direct(&one_to_four(), 1e-10);
    }

    #[test]
    fn test_fft_size_8() {
        assert_matches_direct(&ramp(8), 1e-10);
    }

    #[test]
    fn test_fft_size_16() {
        assert_matches_direct(&trig(16), 1e-9);
    }

    #[test]
    fn test_fft_inverse_recovers_input() {
        let original = ramp(8);
        let transformed = spectrum(fft_radix2, &original);
        let recovered = spectrum(ifft_radix2_normalized, &transformed);
        assert_all_close(&original, &recovered, 1e-10);
    }

    #[test]
    fn test_fft_inplace_matches_outofplace() {
        assert_inplace_agrees(fft_radix2, fft_radix2_inplace, 8);
    }

    #[test]
    fn test_dif_matches_dit() {
        let input = ramp(8);
        let mut output_dit = vec![Complex::zero(); 8];
        let mut output_dif = vec![Complex::zero(); 8];
        CooleyTukeySolver::new(CtVariant::Dit).execute(&input, &mut output_dit, Sign::Forward);
        CooleyTukeySolver::new(CtVariant::Dif).execute(&input, &mut output_dif, Sign::Forward);
        assert_all_close(&output_dit, &output_dif, 1e-10);
    }

    #[test]
    fn test_applicable() {
        for n in [1, 2, 4, 8, 1024] {
            assert!(CooleyTukeySolver::<f64>::applicable(n));
        }
        for n in [0, 3, 5, 6, 7] {
            assert!(!CooleyTukeySolver::<f64>::applicable(n));
        }
    }

    #[test]
    fn test_is_power_of_4() {
        for n in [1, 4, 16, 64, 256, 1024] {
            assert!(CooleyTukeySolver::<f64>::is_power_of_4(n));
        }
        for n in [2, 8, 32, 0, 3, 5] {
            assert!(!CooleyTukeySolver::<f64>::is_power_of_4(n));
        }
    }

    #[test]
    fn test_radix4_matches_radix2_size_4() {
        assert_matches_radix2(fft_radix4, &one_to_four(), 1e-10);
    }

    #[test]
    fn test_radix4_matches_radix2_size_16() {
        assert_matches_radix2(fft_radix4, &trig(16), 1e-9);
    }

    #[test]
    fn test_radix4_matches_radix2_size_64() {
        assert_matches_radix2(fft_radix4, &trig(64), 1e-9);
    }

    #[test]
    fn test_radix4_matches_radix2_size_8() {
        assert_matches_radix2(fft_radix4, &ramp(8), 1e-10);
    }

    #[test]
    fn test_radix4_inverse_recovers_input() {
        assert_round_trip(CtVariant::DitRadix4, 16, 1e-9);
    }

    #[test]
    fn test_radix4_inplace() {
        assert_inplace_agrees(fft_radix4, fft_radix4_inplace, 16);
    }

    #[test]
    fn test_is_power_of_8() {
        for n in [1, 8, 64, 512] {
            assert!(CooleyTukeySolver::<f64>::is_power_of_8(n));
        }
        for n in [2, 4, 16, 32, 128, 256, 0, 3, 5] {
            assert!(!CooleyTukeySolver::<f64>::is_power_of_8(n));
        }
    }

    #[test]
    fn test_radix8_matches_radix2_size_8() {
        assert_matches_radix2(fft_radix8, &ramp(8), 1e-10);
    }

    #[test]
    fn test_radix8_matches_radix2_size_16() {
        assert_matches_radix2(fft_radix8, &trig(16), 1e-9);
    }

    #[test]
    fn test_radix8_matches_radix2_size_32() {
        assert_matches_radix2(fft_radix8, &trig(32), 1e-9);
    }

    #[test]
    fn test_radix8_matches_radix2_size_64() {
        assert_matches_radix2(fft_radix8, &trig(64), 1e-9);
    }

    #[test]
    fn test_radix8_matches_radix2_size_128() {
        assert_matches_radix2(fft_radix8, &trig(128), 1e-8);
    }

    #[test]
    fn test_radix8_matches_radix2_size_512() {
        assert_matches_radix2(fft_radix8, &trig(512), 1e-8);
    }

    #[test]
    fn test_radix8_inverse_recovers_input() {
        assert_round_trip(CtVariant::DitRadix8, 64, 1e-9);
    }

    #[test]
    fn test_radix8_inplace() {
        assert_inplace_agrees(fft_radix8, fft_radix8_inplace, 64);
    }

    #[test]
    fn test_split_radix_matches_radix2_size_4() {
        assert_matches_radix2(fft_split_radix, &one_to_four(), 1e-10);
    }

    #[test]
    fn test_split_radix_matches_radix2_size_8() {
        assert_matches_radix2(fft_split_radix, &ramp(8), 1e-10);
    }

    #[test]
    fn test_split_radix_matches_radix2_size_16() {
        assert_matches_radix2(fft_split_radix, &trig(16), 1e-9);
    }

    #[test]
    fn test_split_radix_matches_radix2_size_64() {
        assert_matches_radix2(fft_split_radix, &trig(64), 1e-9);
    }

    #[test]
    fn test_split_radix_matches_radix2_size_256() {
        assert_matches_radix2(fft_split_radix, &trig(256), 1e-8);
    }

    #[test]
    fn test_split_radix_inverse_recovers_input() {
        assert_round_trip(CtVariant::SplitRadix, 64, 1e-9);
    }

    #[test]
    fn test_split_radix_inplace() {
        assert_inplace_agrees(fft_split_radix, fft_split_radix_inplace, 64);
    }
}