use std::sync::Arc;
use num_complex::Complex;
use num_traits::Zero;
use strength_reduce::StrengthReducedUsize;
use crate::common::{verify_length, verify_length_divisible, FFTnum};
use crate::math_utils;
use crate::twiddles;
use crate::{IsInverse, Length, FFT};
/// Implementation of Rader's Algorithm.
///
/// Computes an FFT of (expected prime) size `len` by permuting the last `len - 1` elements
/// with powers of a primitive root modulo `len`, then evaluating the resulting cyclic
/// convolution with two passes of an inner FFT of size `len - 1`.
pub struct RadersAlgorithm<T> {
    /// Total FFT size, strength-reduced so the `%` in the index permutations is cheap.
    len: StrengthReducedUsize,
    /// A primitive root modulo `len`; its successive powers drive the input permutation.
    primitive_root: usize,
    /// Multiplicative inverse of `primitive_root` modulo `len`; its successive powers drive
    /// the output permutation.
    primitive_root_inverse: usize,
    /// FFT of size `len - 1`, run twice per transform to perform the cyclic convolution.
    inner_fft: Arc<dyn FFT<T>>,
    /// Precomputed inner-FFT output of the permuted twiddle factors, scaled by `1 / (len - 1)`.
    inner_fft_data: Box<[Complex<T>]>,
}
impl<T: FFTnum> RadersAlgorithm<T> {
pub fn new(len: usize, inner_fft: Arc<dyn FFT<T>>) -> Self {
assert_eq!(
len - 1,
inner_fft.len(),
"For raders algorithm, inner_fft.len() must be self.len() - 1. Expected {}, got {}",
len - 1,
inner_fft.len()
);
let inner_fft_len = len - 1;
let reduced_len = StrengthReducedUsize::new(len);
let primitive_root = math_utils::primitive_root(len as u64).unwrap() as usize;
let primitive_root_inverse =
math_utils::multiplicative_inverse(primitive_root as usize, len);
let unity_scale = T::from_f64(1f64 / inner_fft_len as f64).unwrap();
let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
let mut twiddle_input = 1;
for input_cell in &mut inner_fft_input {
let twiddle = twiddles::single_twiddle(twiddle_input, len, inner_fft.is_inverse());
*input_cell = twiddle * unity_scale;
twiddle_input = (twiddle_input * primitive_root_inverse) % reduced_len;
}
let mut inner_fft_output = vec![Zero::zero(); inner_fft_len];
inner_fft.process(&mut inner_fft_input, &mut inner_fft_output);
Self {
inner_fft,
inner_fft_data: inner_fft_output.into_boxed_slice(),
primitive_root,
primitive_root_inverse,
len: reduced_len,
}
}
fn perform_fft(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
output[0] = input.iter().sum();
let first_input_val = input[0];
let (_, output) = output.split_first_mut().unwrap();
let (_, input) = input.split_first_mut().unwrap();
let mut input_index = 1;
for output_element in output.iter_mut() {
input_index = (input_index * self.primitive_root) % self.len;
*output_element = input[input_index - 1];
}
self.inner_fft.process(output, input);
for ((&input_cell, output_cell), &multiple) in input
.iter()
.zip(output.iter_mut())
.zip(self.inner_fft_data.iter())
{
*output_cell = (input_cell * multiple).conj();
}
self.inner_fft.process(output, input);
let mut output_index = 1;
for input_element in input {
output_index = (output_index * self.primitive_root_inverse) % self.len;
output[output_index - 1] = input_element.conj() + first_input_val;
}
}
}
impl<T: FFTnum> FFT<T> for RadersAlgorithm<T> {
    /// Computes a single FFT of `self.len()` elements, reading `input` and writing `output`.
    ///
    /// The lengths are checked by `verify_length` first. `input` doubles as scratch space,
    /// so its contents are overwritten.
    fn process(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length(input, output, fft_len);
        self.perform_fft(input, output);
    }

    /// Computes one FFT per consecutive `self.len()`-sized chunk of `input`, writing each
    /// result into the matching chunk of `output`.
    ///
    /// The lengths are checked by `verify_length_divisible` first. `input` is overwritten.
    fn process_multi(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        let fft_len = self.len();
        verify_length_divisible(input, output, fft_len);

        let chunk_pairs = input
            .chunks_exact_mut(fft_len)
            .zip(output.chunks_exact_mut(fft_len));
        chunk_pairs.for_each(|(in_chunk, out_chunk)| self.perform_fft(in_chunk, out_chunk));
    }
}
impl<T> Length for RadersAlgorithm<T> {
#[inline(always)]
fn len(&self) -> usize {
self.len.get()
}
}
impl<T> IsInverse for RadersAlgorithm<T> {
#[inline(always)]
fn is_inverse(&self) -> bool {
self.inner_fft.is_inverse()
}
}
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::DFT;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    /// Runs `check_fft_algorithm` on Rader's algorithm for several small primes.
    /// Each length is checked forward first, then inverse.
    #[test]
    fn test_raders() {
        const PRIME_LENGTHS: [usize; 5] = [3, 5, 7, 11, 13];
        for &len in PRIME_LENGTHS.iter() {
            for &inverse in &[false, true] {
                test_raders_with_length(len, inverse);
            }
        }
    }

    /// Builds a `RadersAlgorithm` of size `len` on top of an inner `DFT` of size `len - 1`.
    /// It then checks the result in the requested direction.
    fn test_raders_with_length(len: usize, inverse: bool) {
        let inner_len = len - 1;
        let inner_fft = Arc::new(DFT::new(inner_len, inverse));
        let fft = RadersAlgorithm::new(len, inner_fft);
        check_fft_algorithm(&fft, len, inverse);
    }
}