use num_complex::Complex;
use std::any::TypeId;
use std::sync::Arc;
use crate::array_utils::{bitreversed_transpose, workaround_transmute_mut};
use crate::{common::FftNum, FftDirection};
use crate::{Direction, Fft, Length};
use super::WasmNum;
use super::wasm_simd_vector::{Rotation90, WasmSimdArray, WasmSimdArrayMut, WasmVector};
/// Radix-4 FFT layered on top of an inner "base" FFT, vectorised with WebAssembly `simd128`.
///
/// The transform length is `base_len * 4^k`. `base_fft` handles each `base_len`-sized chunk,
/// and `k` layers of radix-4 butterflies then combine those chunks. `S` selects the SIMD
/// vector type (`S::VectorType`); `new` asserts that `S` and `T` are the same type.
pub struct WasmSimdRadix4<S: WasmNum, T> {
    /// Total transform length, `base_len * 4^k`.
    len: usize,
    /// Transform direction, copied from `base_fft` at construction.
    direction: FftDirection,
    /// Inner FFT applied to every `base_len`-sized chunk before the radix-4 layers run.
    base_fft: Arc<dyn Fft<T>>,
    /// Length of `base_fft`.
    base_len: usize,
    /// Precomputed twiddle vectors for all radix-4 layers, stored layer after layer with
    /// three vectors per vector column.
    twiddles: Box<[S::VectorType]>,
    /// Value handed to `column_butterfly4`; built from `direction` via `make_rotate90`.
    rotation: Rotation90<S::VectorType>,
}
impl<S: WasmNum, T: FftNum> WasmSimdRadix4<S, T> {
    /// Creates a radix-4 FFT of length `base_fft.len() * 4^k`.
    ///
    /// The transform reorders its input, runs `base_fft` on every `base_len`-sized chunk,
    /// and then applies `k` layers of SIMD radix-4 butterflies. The direction is inherited
    /// from `base_fft`.
    ///
    /// # Panics
    /// - if `S` and `T` are not the same type;
    /// - if `base_fft.len()` is zero or not a multiple of `2 * COMPLEX_PER_VECTOR`
    ///   (the butterfly kernel consumes two vectors per step);
    /// - if `base_fft.len() * 4^k` overflows `usize`.
    pub fn new(k: u32, base_fft: Arc<dyn Fft<T>>) -> Self {
        // S and T are separate generic parameters but must name the same scalar type;
        // `butterfly_4` reinterprets `Complex<T>` slices as `Complex<S>`.
        let id_a = TypeId::of::<S>();
        let id_t = TypeId::of::<T>();
        assert_eq!(id_a, id_t);

        let direction = base_fft.fft_direction();
        let base_len = base_fft.len();
        // Every layer's column count is base_len * 4^j, so divisibility of base_len is
        // enough for `butterfly_4`, which handles two vectors of columns per step.
        assert!(base_len % (2 * S::VectorType::COMPLEX_PER_VECTOR) == 0 && base_len > 0);

        // `1 << (k * 2)` wraps silently when overflow checks are off (already at k = 16
        // on 32-bit targets such as wasm32), which would yield a wrong length.
        let len = 4usize
            .checked_pow(k)
            .and_then(|factor| base_len.checked_mul(factor))
            .expect("WasmSimdRadix4: FFT length base_len * 4^k overflows usize");

        const ROW_COUNT: usize = 4;

        // Layer j (1..=k) has base_len * 4^(j-1) scalar columns, i.e. that many divided by
        // COMPLEX_PER_VECTOR vector columns, each needing ROW_COUNT - 1 = 3 twiddle vectors.
        // Summed over all layers this is exactly (len - base_len) / COMPLEX_PER_VECTOR, so
        // reserve precisely that instead of over-allocating and shrinking later.
        let twiddle_count = (len - base_len) / S::VectorType::COMPLEX_PER_VECTOR;
        let mut twiddle_factors: Vec<S::VectorType> = Vec::with_capacity(twiddle_count);

        let mut cross_fft_len = base_len * ROW_COUNT;
        while cross_fft_len <= len {
            let num_scalar_columns = cross_fft_len / ROW_COUNT;
            let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR;
            for column in 0..num_vector_columns {
                // Row 0's twiddle is always 1, so only rows 1..ROW_COUNT are stored.
                // (Named `row`, not `k`, to avoid shadowing the constructor parameter.)
                for row in 1..ROW_COUNT {
                    // SAFETY: NOTE(review): `make_mixedradix_twiddle_chunk` is declared
                    // `unsafe`, presumably because it executes simd128 intrinsics; this
                    // assumes the type is only used where simd128 is available — confirm.
                    let twiddle = unsafe {
                        WasmVector::make_mixedradix_twiddle_chunk(
                            column * S::VectorType::COMPLEX_PER_VECTOR,
                            row,
                            cross_fft_len,
                            direction,
                        )
                    };
                    twiddle_factors.push(twiddle);
                }
            }
            cross_fft_len *= ROW_COUNT;
        }
        debug_assert_eq!(twiddle_factors.len(), twiddle_count);

        Self {
            twiddles: twiddle_factors.into_boxed_slice(),
            // SAFETY: NOTE(review): same simd128 assumption as the twiddle constructor.
            rotation: unsafe { WasmVector::make_rotate90(direction) },
            base_fft,
            base_len,
            len,
            direction,
        }
    }

    /// Computes the FFT of `input` into `output` without modifying `input`.
    ///
    /// The work happens in three phases:
    /// 1. Reorder `input` into `output`. When `len == base_len` this is a plain copy;
    ///    otherwise it is `bitreversed_transpose::<_, 4>` parameterised by `base_len`.
    /// 2. Run `base_fft` in place over `output`, one `base_len` chunk at a time.
    /// 3. Apply the radix-4 cross-FFT layers of length `base_len * 4`, `base_len * 16`,
    ///    ..., consuming `self.twiddles` in the order `new` produced them.
    ///
    /// `_scratch` is unused.
    ///
    /// # Safety
    /// Requires simd128 support. NOTE(review): assumes `input.len() == output.len() ==
    /// self.len`, which is presumably validated by the macro-generated caller — confirm.
    #[target_feature(enable = "simd128")]
    unsafe fn perform_fft_immut(
        &self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        if self.len() == self.base_len {
            output.copy_from_slice(input);
        } else {
            bitreversed_transpose::<Complex<T>, 4>(self.base_len, input, output);
        }

        // An empty scratch buffer is passed. NOTE(review): this assumes base_fft needs no
        // in-place scratch — confirm for any base FFT other than the butterflies.
        self.base_fft.process_with_scratch(output, &mut []);

        const ROW_COUNT: usize = 4;
        let mut cross_fft_len = self.base_len * ROW_COUNT;
        let mut layer_twiddles: &[S::VectorType] = &self.twiddles;
        while cross_fft_len <= input.len() {
            let num_rows = input.len() / cross_fft_len;
            let num_scalar_columns = cross_fft_len / ROW_COUNT;
            let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR;
            for i in 0..num_rows {
                butterfly_4::<S, T>(
                    &mut output[i * cross_fft_len..],
                    layer_twiddles,
                    num_scalar_columns,
                    &self.rotation,
                )
            }
            // Skip past the twiddles this layer used: 3 per vector column.
            let twiddle_offset = num_vector_columns * (ROW_COUNT - 1);
            layer_twiddles = &layer_twiddles[twiddle_offset..];
            cross_fft_len *= ROW_COUNT;
        }
    }

    /// Out-of-place entry point; delegates to [`Self::perform_fft_immut`], which never
    /// writes to `input`.
    ///
    /// # Safety
    /// Same requirements as `perform_fft_immut`.
    #[target_feature(enable = "simd128")]
    unsafe fn perform_fft_out_of_place(
        &self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        self.perform_fft_immut(input, output, _scratch);
    }
}
// NOTE(review): presumably expands to the `Fft`-related trait impls that dispatch to
// `perform_fft_out_of_place`; confirm against the macro definition. The closure supplies
// the transform length. It reads the `len` field directly rather than calling `len()`,
// presumably because the macro builds `len()` from this closure.
boilerplate_fft_wasm_simd_oop!(WasmSimdRadix4, |this: &WasmSimdRadix4<_, _>| this.len);
/// Applies one radix-4 cross-FFT layer in place to `data`.
///
/// `data` is treated as four rows of `num_ffts` complex values, where row `r` starts at
/// offset `r * num_ffts`. For every column:
/// - rows 1..=3 are multiplied by their twiddle factors;
/// - `column_butterfly4` then combines the four rows.
///
/// Columns are processed two SIMD vectors at a time. Each step consumes six twiddle vectors:
/// three for the left vector column and three for the right one.
///
/// # Safety
/// Requires simd128 support. NOTE(review): the loads and stores go through
/// `load_complex`/`store_complex`. This assumes `data` holds at least `4 * num_ffts`
/// elements and that `num_ffts` is a multiple of `2 * COMPLEX_PER_VECTOR`. The callers in
/// this file arrange both, but neither is checked here.
#[target_feature(enable = "simd128")]
unsafe fn butterfly_4<S: WasmNum, T: FftNum>(
    data: &mut [Complex<T>],
    twiddles: &[S::VectorType],
    num_ffts: usize,
    rotation: &Rotation90<S::VectorType>,
) {
    // Complex values per SIMD vector, and the width of one two-vector step.
    let lanes = S::VectorType::COMPLEX_PER_VECTOR;
    let step_width = lanes * 2;
    let step_count = num_ffts / step_width;

    // View the data through the SIMD scalar type (S and T are asserted equal in `new`).
    let mut buffer: &mut [Complex<S>] = workaround_transmute_mut(data);

    for (step, tw) in twiddles.chunks_exact(6).take(step_count).enumerate() {
        // Row-0 offsets of the left and right vector columns handled in this step.
        let left = step * step_width;
        let right = left + lanes;

        // Gather both columns, twiddling rows 1..=3 as they are loaded.
        let left_col = [
            buffer.load_complex(left),
            WasmVector::mul_complex(buffer.load_complex(left + num_ffts), tw[0]),
            WasmVector::mul_complex(buffer.load_complex(left + 2 * num_ffts), tw[1]),
            WasmVector::mul_complex(buffer.load_complex(left + 3 * num_ffts), tw[2]),
        ];
        let right_col = [
            buffer.load_complex(right),
            WasmVector::mul_complex(buffer.load_complex(right + num_ffts), tw[3]),
            WasmVector::mul_complex(buffer.load_complex(right + 2 * num_ffts), tw[4]),
            WasmVector::mul_complex(buffer.load_complex(right + 3 * num_ffts), tw[5]),
        ];

        let left_out = WasmVector::column_butterfly4(left_col, *rotation);
        let right_out = WasmVector::column_butterfly4(right_col, *rotation);

        // Write back row by row, left column before right column.
        for row in 0..4 {
            buffer.store_complex(left_out[row], left + row * num_ffts);
            buffer.store_complex(right_out[row], right + row * num_ffts);
        }
    }
}
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Every base length is checked in both directions.
    const DIRECTIONS: [FftDirection; 2] = [FftDirection::Forward, FftDirection::Inverse];

    /// f64 base lengths only need to be even (assuming one complex per 128-bit vector);
    /// each base is extended by k = 0..=3 radix-4 layers.
    #[wasm_bindgen_test]
    fn test_wasm_simd_radix4_64() {
        for base in [2, 4, 6, 8, 12, 16] {
            for direction in DIRECTIONS {
                let base_fft = construct_base(base, direction);
                for k in 0..4 {
                    test_wasm_simd_radix4_64_with_base(k, Arc::clone(&base_fft));
                }
            }
        }
    }

    /// Builds a `WasmSimdRadix4<f64, f64>` over `base_fft` and checks it against the reference.
    fn test_wasm_simd_radix4_64_with_base(k: u32, base_fft: Arc<dyn Fft<f64>>) {
        let expected_len = base_fft.len() * 4usize.pow(k);
        let expected_direction = base_fft.fft_direction();
        let radix4 = WasmSimdRadix4::<f64, f64>::new(k, base_fft);
        check_fft_algorithm::<f64>(&radix4, expected_len, expected_direction);
    }

    /// f32 base lengths must be multiples of 4 (assuming two complexes per 128-bit vector);
    /// each base is extended by k = 0..=3 radix-4 layers.
    #[wasm_bindgen_test]
    fn test_wasm_simd_radix4_32() {
        for base in [4, 8, 12, 16] {
            for direction in DIRECTIONS {
                let base_fft = construct_base(base, direction);
                for k in 0..4 {
                    test_wasm_simd_radix4_32_with_base(k, Arc::clone(&base_fft));
                }
            }
        }
    }

    /// Builds a `WasmSimdRadix4<f32, f32>` over `base_fft` and checks it against the reference.
    fn test_wasm_simd_radix4_32_with_base(k: u32, base_fft: Arc<dyn Fft<f32>>) {
        let expected_len = base_fft.len() * 4usize.pow(k);
        let expected_direction = base_fft.fft_direction();
        let radix4 = WasmSimdRadix4::<f32, f32>::new(k, base_fft);
        check_fft_algorithm::<f32>(&radix4, expected_len, expected_direction);
    }
}