use num_complex::Complex;
use std::any::TypeId;
use std::sync::Arc;
use crate::array_utils::{bitreversed_transpose, workaround_transmute_mut};
use crate::{common::FftNum, FftDirection};
use crate::{Direction, Fft, Length};
use super::SseNum;
use super::sse_vector::{Rotation90, SseArray, SseArrayMut, SseVector};
/// Radix-4 FFT built on SSE4.1 vector operations.
///
/// An instance computes FFTs of length `base_len * 4^k`: the data is first processed by
/// `base_fft`, and then one layer of twiddled radix-4 butterflies is applied for each
/// factor of four, using the twiddle factors precomputed at construction time.
///
/// `S` selects the SSE vector type and `T` is the scalar type of the complex samples.
/// The constructor asserts that `S` and `T` are the same type.
pub struct SseRadix4<S: SseNum, T> {
// Twiddle-factor vectors for every butterfly layer, stored back to back starting with the
// smallest layer; each layer holds three vectors per vector-wide column.
twiddles: Box<[S::VectorType]>,
// Value produced by `SseVector::make_rotate90(direction)`, handed to every column butterfly.
rotation: Rotation90<S::VectorType>,
// FFT that is run over the output buffer before the radix-4 layers.
base_fft: Arc<dyn Fft<T>>,
// Cached `base_fft.len()`.
base_len: usize,
// Total FFT length: `base_len * 4^k`.
len: usize,
// Direction taken from `base_fft.fft_direction()` at construction time.
direction: FftDirection,
}
impl<S: SseNum, T: FftNum> SseRadix4<S, T> {
    /// Creates an FFT of length `4^k * base_fft.len()` that uses `base_fft` for the
    /// innermost transforms.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when the running CPU does not report SSE4.1 support.
    ///
    /// # Panics
    ///
    /// Panics if `S` and `T` are different types, if `base_fft.len()` is zero or not a
    /// multiple of `2 * S::VectorType::COMPLEX_PER_VECTOR`, or if `4^k * base_fft.len()`
    /// does not fit in a `usize`.
    #[inline]
    pub fn new(k: u32, base_fft: Arc<dyn Fft<T>>) -> Result<Self, ()> {
        // `S` and `T` are separate generic parameters, but this type only works when they
        // name the same scalar type (the butterfly code reinterprets `Complex<T>` as
        // `Complex<S>`), so verify that at runtime.
        let id_a = TypeId::of::<S>();
        let id_t = TypeId::of::<T>();
        assert_eq!(id_a, id_t);

        if is_x86_feature_detected!("sse4.1") {
            // SAFETY: `new_with_sse` requires SSE4.1, which was just detected.
            Ok(unsafe { Self::new_with_sse(k, base_fft) })
        } else {
            Err(())
        }
    }

    /// Precomputes the twiddle factors and assembles the instance.
    ///
    /// The twiddles of all butterfly layers are packed into one array, smallest layer
    /// first. Each layer contributes three vectors (rows 1, 2 and 3) per vector-wide column.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the CPU supports SSE4.1.
    ///
    /// # Panics
    ///
    /// Panics if `base_fft.len()` is zero or not a multiple of
    /// `2 * S::VectorType::COMPLEX_PER_VECTOR`, or if the total length overflows `usize`.
    #[target_feature(enable = "sse4.1")]
    unsafe fn new_with_sse(k: u32, base_fft: Arc<dyn Fft<T>>) -> Self {
        const ROW_COUNT: usize = 4;

        let direction = base_fft.fft_direction();
        let base_len = base_fft.len();
        // The butterfly loop consumes two vectors' worth of columns per iteration and has
        // no remainder handling, hence the divisibility requirement.
        assert!(base_len % (2 * S::VectorType::COMPLEX_PER_VECTOR) == 0 && base_len > 0);

        // Checked arithmetic: a plain shift/multiply would wrap silently in release builds
        // for large `k` and yield an FFT of the wrong length.
        let len = 4usize
            .checked_pow(k)
            .and_then(|scale| base_len.checked_mul(scale))
            .expect("SseRadix4: 4^k * base_fft.len() overflows usize");

        // Layer j (1..=k) stores 3 * base_len * 4^(j-1) / COMPLEX_PER_VECTOR vectors; summed
        // over all layers that is exactly (len - base_len) / COMPLEX_PER_VECTOR, so reserve
        // precisely that and avoid a shrink when converting to a boxed slice.
        let mut twiddle_factors =
            Vec::with_capacity((len - base_len) / S::VectorType::COMPLEX_PER_VECTOR);

        let mut cross_fft_len = base_len * ROW_COUNT;
        while cross_fft_len <= len {
            let num_scalar_columns = cross_fft_len / ROW_COUNT;
            let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR;

            for column in 0..num_vector_columns {
                // Row 0 needs no twiddle, so only rows 1..ROW_COUNT are stored.
                for row in 1..ROW_COUNT {
                    twiddle_factors.push(SseVector::make_mixedradix_twiddle_chunk(
                        column * S::VectorType::COMPLEX_PER_VECTOR,
                        row,
                        cross_fft_len,
                        direction,
                    ));
                }
            }
            cross_fft_len *= ROW_COUNT;
        }

        Self {
            twiddles: twiddle_factors.into_boxed_slice(),
            rotation: SseVector::make_rotate90(direction),
            base_fft,
            base_len,
            len,
            direction,
        }
    }

    /// Computes the transform of `input` into `output` without modifying `input`.
    ///
    /// Steps: reorder `input` into `output` (a plain copy when there are no radix-4
    /// layers, otherwise `bitreversed_transpose`), run `base_fft` over `output`, then apply
    /// the radix-4 butterfly layers from the smallest cross-FFT size up to the full length.
    ///
    /// `_scratch` is not used. `input` and `output` are expected to have the same length.
    /// NOTE(review): presumably the macro-generated callers guarantee both equal
    /// `self.len` — confirm against `boilerplate_fft_sse_oop!`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the CPU supports SSE4.1.
    #[target_feature(enable = "sse4.1")]
    unsafe fn perform_fft_immut(
        &self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        // Move the data into the output buffer in the order the later stages expect.
        if self.len() == self.base_len {
            output.copy_from_slice(input);
        } else {
            bitreversed_transpose::<Complex<T>, 4>(self.base_len, input, output);
        }

        // Base-level FFTs. The base FFT is handed an empty scratch buffer.
        self.base_fft.process_with_scratch(output, &mut []);

        // Cross-FFT layers, each four times the size of the previous one.
        const ROW_COUNT: usize = 4;
        let mut cross_fft_len = self.base_len * ROW_COUNT;
        let mut layer_twiddles: &[S::VectorType] = &self.twiddles;
        while cross_fft_len <= input.len() {
            let num_rows = input.len() / cross_fft_len;
            let num_scalar_columns = cross_fft_len / ROW_COUNT;
            let num_vector_columns = num_scalar_columns / S::VectorType::COMPLEX_PER_VECTOR;

            for i in 0..num_rows {
                butterfly_4::<S, T>(
                    &mut output[i * cross_fft_len..],
                    layer_twiddles,
                    num_scalar_columns,
                    &self.rotation,
                )
            }

            // Skip past the twiddle factors consumed by this layer.
            let twiddle_offset = num_vector_columns * (ROW_COUNT - 1);
            layer_twiddles = &layer_twiddles[twiddle_offset..];
            cross_fft_len *= ROW_COUNT;
        }
    }

    /// Out-of-place entry point; forwards directly to [`Self::perform_fft_immut`].
    ///
    /// # Safety
    ///
    /// The caller must ensure that the CPU supports SSE4.1.
    #[target_feature(enable = "sse4.1")]
    unsafe fn perform_fft_out_of_place(
        &self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        self.perform_fft_immut(input, output, _scratch);
    }
}
// Invokes the crate's `boilerplate_fft_sse_oop!` macro for `SseRadix4`, passing a closure
// that reads the instance's `len` field.
// NOTE(review): presumably this generates the `Fft`/`Length`/`Direction` trait impls that
// dispatch to `perform_fft_immut` and `perform_fft_out_of_place` — confirm against the
// macro definition.
boilerplate_fft_sse_oop!(SseRadix4, |this: &SseRadix4<_, _>| this.len);
/// Applies one layer of twiddled radix-4 butterflies, in place, to a single cross-FFT.
///
/// `data` is treated as four consecutive rows of `num_ffts` complex values each. Every loop
/// iteration handles two adjacent vectors' worth of columns: rows 1 to 3 are multiplied by
/// their twiddle factors, a 4-point column butterfly is applied, and the results are
/// stored back at the positions they were loaded from.
///
/// `twiddles` is consumed six vectors at a time (three per vector column), and at most
/// `num_ffts / (2 * COMPLEX_PER_VECTOR)` iterations are executed.
///
/// # Safety
///
/// The caller must ensure that the CPU supports SSE4.1. `data` is reinterpreted as
/// `Complex<S>` via `workaround_transmute_mut`, so `S` and `T` are expected to be the same
/// type. NOTE(review): `data` presumably must hold at least `4 * num_ffts` elements;
/// whether the vector loads/stores are bounds-checked is decided in `sse_vector` — confirm.
#[target_feature(enable = "sse4.1")]
unsafe fn butterfly_4<S: SseNum, T: FftNum>(
    data: &mut [Complex<T>],
    twiddles: &[S::VectorType],
    num_ffts: usize,
    rotation: &Rotation90<S::VectorType>,
) {
    // Distance (in complex values) between the two vectors handled per iteration.
    let second_vector = S::VectorType::COMPLEX_PER_VECTOR;
    // Number of columns consumed by one iteration.
    let columns_per_iter = S::VectorType::COMPLEX_PER_VECTOR * 2;
    let iterations = num_ffts / columns_per_iter;

    let mut buffer: &mut [Complex<S>] = workaround_transmute_mut(data);

    for (chunk_index, tw) in twiddles.chunks_exact(6).take(iterations).enumerate() {
        let col = chunk_index * columns_per_iter;

        // Load the four rows of each column vector; rows 1..=3 get their twiddle applied.
        let first = [
            buffer.load_complex(col + 0 * num_ffts),
            SseVector::mul_complex(buffer.load_complex(col + 1 * num_ffts), tw[0]),
            SseVector::mul_complex(buffer.load_complex(col + 2 * num_ffts), tw[1]),
            SseVector::mul_complex(buffer.load_complex(col + 3 * num_ffts), tw[2]),
        ];
        let second = [
            buffer.load_complex(col + 0 * num_ffts + second_vector),
            SseVector::mul_complex(
                buffer.load_complex(col + 1 * num_ffts + second_vector),
                tw[3],
            ),
            SseVector::mul_complex(
                buffer.load_complex(col + 2 * num_ffts + second_vector),
                tw[4],
            ),
            SseVector::mul_complex(
                buffer.load_complex(col + 3 * num_ffts + second_vector),
                tw[5],
            ),
        ];

        // 4-point butterflies down each column.
        let first = SseVector::column_butterfly4(first, *rotation);
        let second = SseVector::column_butterfly4(second, *rotation);

        // Write the results back, row by row.
        buffer.store_complex(first[0], col + 0 * num_ffts);
        buffer.store_complex(second[0], col + 0 * num_ffts + second_vector);
        buffer.store_complex(first[1], col + 1 * num_ffts);
        buffer.store_complex(second[1], col + 1 * num_ffts + second_vector);
        buffer.store_complex(first[2], col + 2 * num_ffts);
        buffer.store_complex(second[2], col + 2 * num_ffts + second_vector);
        buffer.store_complex(first[3], col + 3 * num_ffts);
        buffer.store_complex(second[3], col + 3 * num_ffts + second_vector);
    }
}
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    /// Exclusive upper bound on the exponent `k` exercised by the tests.
    const K_LIMIT: u32 = 4;

    /// Checks the `f64` algorithm for each base size, both directions, and every `k` below
    /// `K_LIMIT`.
    #[test]
    fn test_sse_radix4_64() {
        for base_size in [2, 4, 6, 8, 12, 16] {
            let forward = construct_base(base_size, FftDirection::Forward);
            let inverse = construct_base(base_size, FftDirection::Inverse);
            for exponent in 0..K_LIMIT {
                test_sse_radix4_64_with_base(exponent, Arc::clone(&forward));
                test_sse_radix4_64_with_base(exponent, Arc::clone(&inverse));
            }
        }
    }

    /// Builds an `f64` `SseRadix4` of length `base_fft.len() * 4^k` and validates it.
    fn test_sse_radix4_64_with_base(k: u32, base_fft: Arc<dyn Fft<f64>>) {
        let direction = base_fft.fft_direction();
        let expected_len = base_fft.len() * 4usize.pow(k);
        let fft = SseRadix4::<f64, f64>::new(k, base_fft).unwrap();
        check_fft_algorithm::<f64>(&fft, expected_len, direction);
    }

    /// Checks the `f32` algorithm for each base size, both directions, and every `k` below
    /// `K_LIMIT`.
    #[test]
    fn test_sse_radix4_32() {
        for base_size in [4, 8, 12, 16] {
            let forward = construct_base(base_size, FftDirection::Forward);
            let inverse = construct_base(base_size, FftDirection::Inverse);
            for exponent in 0..K_LIMIT {
                test_sse_radix4_32_with_base(exponent, Arc::clone(&forward));
                test_sse_radix4_32_with_base(exponent, Arc::clone(&inverse));
            }
        }
    }

    /// Builds an `f32` `SseRadix4` of length `base_fft.len() * 4^k` and validates it.
    fn test_sse_radix4_32_with_base(k: u32, base_fft: Arc<dyn Fft<f32>>) {
        let direction = base_fft.fft_direction();
        let expected_len = base_fft.len() * 4usize.pow(k);
        let fft = SseRadix4::<f32, f32>::new(k, base_fft).unwrap();
        check_fft_algorithm::<f32>(&fft, expected_len, direction);
    }
}