use std::sync::Arc;
use num_complex::Complex;
use transpose;
use common::{FFTnum, verify_length, verify_length_divisible};
use ::{Length, IsInverse, FFT};
use algorithm::butterflies::FFTButterfly;
use array_utils;
use twiddles;
/// FFT algorithm that computes a transform of size `width * height` by combining
/// an inner FFT of size `width` with an inner FFT of size `height`, using
/// transposes and a twiddle-factor multiplication between the two passes.
pub struct MixedRadix<T> {
// Length of `width_size_fft`.
width: usize,
// Inner FFT of size `width`, applied (via `process_multi`) in the second FFT pass.
width_size_fft: Arc<FFT<T>>,
// Length of `height_size_fft`.
height: usize,
// Inner FFT of size `height`, applied (via `process_multi`) in the first FFT pass.
height_size_fft: Arc<FFT<T>>,
// `width * height` precomputed twiddles; entry `x * height + y` is
// `twiddles::single_twiddle(x * y, width * height, inverse)`. Its length is
// also what `len()` reports.
twiddles: Box<[Complex<T>]>,
// Direction flag, taken from `width_fft` at construction (both inner FFTs must agree).
inverse: bool,
}
impl<T: FFTnum> MixedRadix<T> {
    /// Builds a mixed-radix FFT of size `width_fft.len() * height_fft.len()`.
    ///
    /// All `width * height` twiddle factors are precomputed here so that
    /// `perform_fft` only has to multiply them in.
    ///
    /// # Panics
    ///
    /// Panics if exactly one of `width_fft` / `height_fft` is an inverse FFT.
    pub fn new(width_fft: Arc<FFT<T>>, height_fft: Arc<FFT<T>>) -> Self {
        assert_eq!(
            width_fft.is_inverse(), height_fft.is_inverse(),
            "width_fft and height_fft must both be inverse, or neither. got width inverse={}, height inverse={}",
            width_fft.is_inverse(), height_fft.is_inverse());

        let inverse = width_fft.is_inverse();
        let (width, height) = (width_fft.len(), height_fft.len());
        let len = width * height;

        // Row-major table with `width` rows of `height` entries: the entry at
        // (row x, column y) is the twiddle for exponent x * y.
        let mut twiddle_table = Vec::with_capacity(len);
        twiddle_table.extend((0..width).flat_map(|x| {
            (0..height).map(move |y| twiddles::single_twiddle(x * y, len, inverse))
        }));

        MixedRadix {
            width: width,
            width_size_fft: width_fft,
            height: height,
            height_size_fft: height_fft,
            twiddles: twiddle_table.into_boxed_slice(),
            inverse: inverse,
        }
    }

    /// Computes one transform of size `self.len()`, leaving the result in `output`.
    ///
    /// `input` is used as scratch space and its contents are overwritten.
    /// Both buffers are expected to hold exactly `self.len()` elements; the
    /// `FFT` trait methods check this via `verify_length*` before calling here.
    fn perform_fft(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let (width, height) = (self.width, self.height);

        // 1. Treat `input` as `height` rows of `width` elements and transpose it,
        //    so that `output` holds `width` rows of `height` elements.
        transpose::transpose(input, output, width, height);

        // 2. Size-`height` FFT over each of those rows, written back into `input`.
        self.height_size_fft.process_multi(output, input);

        // 3. Apply the precomputed twiddle factors element-wise.
        for (value, factor) in input.iter_mut().zip(self.twiddles.iter()) {
            *value = *value * *factor;
        }

        // 4. Transpose back to `height` rows of `width` elements.
        transpose::transpose(input, output, height, width);

        // 5. Size-`width` FFT over each row, written into `input`.
        self.width_size_fft.process_multi(output, input);

        // 6. Final transpose puts the result, in order, into `output`.
        transpose::transpose(input, output, width, height);
    }
}
impl<T: FFTnum> FFT<T> for MixedRadix<T> {
    /// Transforms a single `self.len()`-element buffer; `input` is clobbered.
    fn process(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length(input, output, fft_len);
        self.perform_fft(input, output);
    }

    /// Transforms consecutive `self.len()`-element chunks of `input` into the
    /// matching chunks of `output`; `input` is clobbered.
    fn process_multi(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length_divisible(input, output, fft_len);

        let chunk_pairs = input.chunks_mut(fft_len).zip(output.chunks_mut(fft_len));
        for (src, dst) in chunk_pairs {
            self.perform_fft(src, dst);
        }
    }
}
impl<T> Length for MixedRadix<T> {
    /// Transform size: one twiddle factor is stored per element, so the twiddle
    /// table's length equals `width * height`.
    #[inline(always)]
    fn len(&self) -> usize {
        self.twiddles[..].len()
    }
}
impl<T> IsInverse for MixedRadix<T> {
#[inline(always)]
fn is_inverse(&self) -> bool {
self.inverse
}
}
/// Variant of `MixedRadix` whose two inner transforms are `FFTButterfly`
/// instances, applied in place via `process_multi_inplace` instead of
/// going through `FFT::process_multi`.
pub struct MixedRadixDoubleButterfly<T> {
// Length of `width_size_fft`.
width: usize,
// Inner butterfly of size `width`, used in the second FFT pass.
width_size_fft: Arc<FFTButterfly<T>>,
// Length of `height_size_fft`.
height: usize,
// Inner butterfly of size `height`, used in the first FFT pass.
height_size_fft: Arc<FFTButterfly<T>>,
// `width * height` precomputed twiddles; entry `x * height + y` is
// `twiddles::single_twiddle(x * y, width * height, inverse)`. Its length is
// also what `len()` reports.
twiddles: Box<[Complex<T>]>,
// Direction flag, taken from `width_fft` at construction (both butterflies must agree).
inverse: bool,
}
impl<T: FFTnum> MixedRadixDoubleButterfly<T> {
    /// Builds a mixed-radix FFT of size `width_fft.len() * height_fft.len()`
    /// out of two butterfly kernels, precomputing all twiddle factors.
    ///
    /// # Panics
    ///
    /// Panics if exactly one of `width_fft` / `height_fft` is an inverse butterfly.
    pub fn new(width_fft: Arc<FFTButterfly<T>>, height_fft: Arc<FFTButterfly<T>>) -> Self {
        assert_eq!(
            width_fft.is_inverse(), height_fft.is_inverse(),
            "width_fft and height_fft must both be inverse, or neither. got width inverse={}, height inverse={}",
            width_fft.is_inverse(), height_fft.is_inverse());

        let inverse = width_fft.is_inverse();
        let (width, height) = (width_fft.len(), height_fft.len());
        let len = width * height;

        // Row-major table with `width` rows of `height` entries: the entry at
        // (row x, column y) is the twiddle for exponent x * y.
        let mut twiddle_table = Vec::with_capacity(len);
        twiddle_table.extend((0..width).flat_map(|x| {
            (0..height).map(move |y| twiddles::single_twiddle(x * y, len, inverse))
        }));

        MixedRadixDoubleButterfly {
            width: width,
            width_size_fft: width_fft,
            height: height,
            height_size_fft: height_fft,
            twiddles: twiddle_table.into_boxed_slice(),
            inverse: inverse,
        }
    }

    /// Computes one transform of size `self.len()`, leaving the result in `output`.
    /// `input` is used as scratch space and its contents are overwritten.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is inferred, not visible here. Callers are
    /// presumably required to pass `input` and `output` of exactly `self.len()`
    /// elements, because `array_utils::transpose_small` and
    /// `FFTButterfly::process_multi_inplace` appear to skip bounds checks.
    /// TODO: confirm against those helpers.
    unsafe fn perform_fft(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let (width, height) = (self.width, self.height);

        // 1. Transpose into `output` (assumed to use the same (width, height)
        //    meaning as `transpose::transpose` — TODO confirm).
        array_utils::transpose_small(width, height, input, output);

        // 2. Size-`height` butterflies, in place, over `output`.
        self.height_size_fft.process_multi_inplace(output);

        // 3. Apply the precomputed twiddle factors element-wise.
        for (value, factor) in output.iter_mut().zip(self.twiddles.iter()) {
            *value = *value * *factor;
        }

        // 4. Transpose back into `input`.
        array_utils::transpose_small(height, width, output, input);

        // 5. Size-`width` butterflies, in place, over `input`.
        self.width_size_fft.process_multi_inplace(input);

        // 6. Final transpose puts the result into `output`.
        array_utils::transpose_small(width, height, input, output);
    }
}
impl<T: FFTnum> FFT<T> for MixedRadixDoubleButterfly<T> {
    /// Transforms a single `self.len()`-element buffer; `input` is clobbered.
    fn process(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length(input, output, fft_len);
        // SAFETY: relies on `verify_length` having rejected buffers whose length
        // is not `fft_len` (presumably by panicking — confirm in `common`), which
        // is the precondition `perform_fft` is written against.
        unsafe { self.perform_fft(input, output) };
    }

    /// Transforms consecutive `self.len()`-element chunks of `input` into the
    /// matching chunks of `output`; `input` is clobbered.
    fn process_multi(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length_divisible(input, output, fft_len);

        let chunk_pairs = input.chunks_mut(fft_len).zip(output.chunks_mut(fft_len));
        for (src, dst) in chunk_pairs {
            // SAFETY: assuming `verify_length_divisible` guarantees both buffer
            // lengths are multiples of `fft_len`, `chunks_mut` yields only full
            // chunks of exactly `fft_len` elements.
            unsafe { self.perform_fft(src, dst) };
        }
    }
}
impl<T> Length for MixedRadixDoubleButterfly<T> {
    /// Transform size: one twiddle factor is stored per element, so the twiddle
    /// table's length equals `width * height`.
    #[inline(always)]
    fn len(&self) -> usize {
        self.twiddles[..].len()
    }
}
impl<T> IsInverse for MixedRadixDoubleButterfly<T> {
#[inline(always)]
fn is_inverse(&self) -> bool {
self.inverse
}
}
#[cfg(test)]
mod unit_tests {
    use super::*;
    use std::sync::Arc;
    use test_utils::{check_fft_algorithm, make_butterfly};
    use algorithm::DFT;

    /// Checks `MixedRadix` (built from DFT inner FFTs) for every width/height
    /// pair in 1..7, forward first and then inverse.
    #[test]
    fn test_mixed_radix() {
        for width in 1..7 {
            for height in 1..7 {
                for &inverse in &[false, true] {
                    test_mixed_radix_with_lengths(width, height, inverse);
                }
            }
        }
    }

    /// Checks `MixedRadixDoubleButterfly` for every width/height pair in 2..7,
    /// forward first and then inverse.
    #[test]
    fn test_mixed_radix_double_butterfly() {
        for width in 2..7 {
            for height in 2..7 {
                for &inverse in &[false, true] {
                    test_mixed_radix_butterfly_with_lengths(width, height, inverse);
                }
            }
        }
    }

    /// Builds a `MixedRadix` from two DFTs and validates it with `check_fft_algorithm`.
    fn test_mixed_radix_with_lengths(width: usize, height: usize, inverse: bool) {
        let inner_width = Arc::new(DFT::new(width, inverse)) as Arc<FFT<f32>>;
        let inner_height = Arc::new(DFT::new(height, inverse)) as Arc<FFT<f32>>;
        let fft = MixedRadix::new(inner_width, inner_height);
        check_fft_algorithm(&fft, width * height, inverse);
    }

    /// Builds a `MixedRadixDoubleButterfly` from two butterflies and validates it.
    fn test_mixed_radix_butterfly_with_lengths(width: usize, height: usize, inverse: bool) {
        let inner_width = make_butterfly(width, inverse);
        let inner_height = make_butterfly(height, inverse);
        let fft = MixedRadixDoubleButterfly::new(inner_width, inner_height);
        check_fft_algorithm(&fft, width * height, inverse);
    }
}