use std::sync::Arc;
use num_complex::Complex;
use num_traits::Zero;
use strength_reduce::StrengthReducedUsize;
use common::{FFTnum, verify_length, verify_length_divisible};
use math_utils;
use twiddles;
use ::{Length, IsInverse, FFT};
/// Computes an FFT of prime size `len` using Rader's algorithm.
///
/// The transform is re-expressed as a cyclic convolution of size `len - 1`,
/// which is evaluated with an inner FFT of that size.
pub struct RadersAlgorithm<T> {
    /// The FFT size, wrapped in a strength-reduced form so the `% len`
    /// operations in the permutation loops are cheap.
    len: StrengthReducedUsize,
    /// A primitive root modulo `len`, used to permute the inputs.
    primitive_root: usize,
    /// The multiplicative inverse of `primitive_root` modulo `len`, used to
    /// permute the outputs.
    primitive_root_inverse: usize,
    /// FFT of size `len - 1`; its direction is also the direction of this FFT.
    inner_fft: Arc<FFT<T>>,
    /// Precomputed, pre-scaled inner-FFT output of the permuted twiddle factors
    /// (the convolution kernel).
    inner_fft_data: Box<[Complex<T>]>,
}
impl<T: FFTnum> RadersAlgorithm<T> {
    /// Creates an FFT instance which will process inputs of size `len`.
    ///
    /// Rader's algorithm reduces the transform to a cyclic convolution of size
    /// `len - 1`, which is evaluated with `inner_fft`.
    ///
    /// `len` must be prime. The algorithm walks the powers of a primitive root
    /// modulo `len` and relies on them visiting every index in `1..len` exactly
    /// once. `inner_fft` must have length `len - 1`. Its direction (forward or
    /// inverse) determines the direction of the returned FFT.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `inner_fft.len() != len - 1`;
    /// - `math_utils::primitive_root` returns `None` for `len`;
    /// - `T::from_f64` cannot represent `1 / (len - 1)`.
    pub fn new(len: usize, inner_fft: Arc<FFT<T>>) -> Self {
        let inner_fft_len = len - 1;
        assert_eq!(inner_fft_len, inner_fft.len(), "For raders algorithm, inner_fft.len() must be self.len() - 1. Expected {}, got {}", inner_fft_len, inner_fft.len());

        let reduced_len = StrengthReducedUsize::new(len);

        // A primitive root g modulo `len`, and its inverse g^-1, drive the input
        // and output permutations in `perform_fft`.
        let primitive_root = math_utils::primitive_root(len as u64).unwrap() as usize;
        let primitive_root_inverse = math_utils::multiplicative_inverse(primitive_root, len);

        // Precompute the convolution kernel. Take the twiddles for indices
        // g^0, g^-1, g^-2, ..., scale them by 1 / (len - 1), and transform them
        // with the inner FFT. The scale normalizes the FFT-based convolution done
        // in `perform_fft`, assuming the inner FFT is unnormalized.
        let unity_scale = T::from_f64(1f64 / inner_fft_len as f64).unwrap();
        let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
        let mut twiddle_input = 1;
        for input_cell in &mut inner_fft_input {
            let twiddle = twiddles::single_twiddle(twiddle_input, len, inner_fft.is_inverse());
            *input_cell = twiddle * unity_scale;

            twiddle_input = (twiddle_input * primitive_root_inverse) % reduced_len;
        }

        let mut inner_fft_output = vec![Zero::zero(); inner_fft_len];
        inner_fft.process(&mut inner_fft_input, &mut inner_fft_output);

        Self {
            inner_fft,
            inner_fft_data: inner_fft_output.into_boxed_slice(),
            primitive_root,
            primitive_root_inverse,
            len: reduced_len,
        }
    }

    /// Computes one FFT of size `self.len()` from `input` into `output`.
    ///
    /// Both slices are assumed to have length `self.len()`; the public `process`
    /// methods verify this before calling. Every element of `input` except
    /// `input[0]` is overwritten, because it is used as scratch space for the
    /// inner FFTs.
    // NOTE(review): the index products below multiply two values < len. On
    // 32-bit targets they can overflow `usize` once len > 65535.
    fn perform_fft(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        // The first output element is just the sum of all the input elements.
        output[0] = input.iter().sum();
        let first_input_val = input[0];

        // The first input and output are done; work on the remaining len - 1 elements.
        let (_, output) = output.split_first_mut().unwrap();
        let (_, input) = input.split_first_mut().unwrap();

        // Permute the remaining inputs into the order x[g^1], x[g^2], ...
        // The `- 1` converts a full-array index into an index into the tail.
        let mut input_index = 1;
        for output_element in output.iter_mut() {
            input_index = (input_index * self.primitive_root) % self.len;
            *output_element = input[input_index - 1];
        }

        // First inner FFT: transform the permuted sequence.
        self.inner_fft.process(output, input);

        // Pointwise-multiply by the precomputed kernel, and conjugate. Conjugating
        // both the input and the output of the next FFT makes it act as the
        // opposite-direction transform.
        for ((&input_cell, output_cell), &multiple) in input.iter().zip(output.iter_mut()).zip(self.inner_fft_data.iter()) {
            *output_cell = (input_cell * multiple).conj();
        }

        // Second inner FFT: completes the cyclic convolution.
        self.inner_fft.process(output, input);

        // Undo the conjugation and add back the DC input term. Scatter the results
        // to the output positions g^-1, g^-2, ...
        let mut output_index = 1;
        for input_element in input {
            output_index = (output_index * self.primitive_root_inverse) % self.len;
            output[output_index - 1] = input_element.conj() + first_input_val;
        }
    }
}
impl<T: FFTnum> FFT<T> for RadersAlgorithm<T> {
    /// Computes a single FFT of size `self.len()`.
    /// `input` is used as scratch space and is overwritten.
    fn process(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length(input, output, fft_len);
        self.perform_fft(input, output);
    }

    /// Computes one FFT for each consecutive `self.len()`-sized chunk of `input`,
    /// writing into the matching chunk of `output`.
    fn process_multi(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length_divisible(input, output, fft_len);

        let chunk_pairs = input.chunks_mut(fft_len).zip(output.chunks_mut(fft_len));
        for (src, dst) in chunk_pairs {
            self.perform_fft(src, dst);
        }
    }
}
impl<T> Length for RadersAlgorithm<T> {
#[inline(always)]
fn len(&self) -> usize {
self.len.get()
}
}
impl<T> IsInverse for RadersAlgorithm<T> {
#[inline(always)]
fn is_inverse(&self) -> bool {
self.inner_fft.is_inverse()
}
}
#[cfg(test)]
mod unit_tests {
    use super::*;
    use std::sync::Arc;
    use test_utils::check_fft_algorithm;
    use algorithm::DFT;

    /// Builds a Rader's FFT of size `len` on top of a naive DFT of size `len - 1`,
    /// then validates it with the shared `check_fft_algorithm` helper.
    fn test_raders_with_length(len: usize, inverse: bool) {
        let naive_inner = Arc::new(DFT::new(len - 1, inverse));
        let raders = RadersAlgorithm::new(len, naive_inner);
        check_fft_algorithm(&raders, len, inverse);
    }

    /// Exercises several small primes, each in the forward direction and then
    /// the inverse direction.
    #[test]
    fn test_raders() {
        let primes: [usize; 5] = [3, 5, 7, 11, 13];
        let directions = [false, true];
        for &len in primes.iter() {
            for &inverse in directions.iter() {
                test_raders_with_length(len, inverse);
            }
        }
    }
}