use std::sync::Arc;
use num_complex::Complex;
use num_integer::Integer;
use num_traits::Zero;
use primal_check::miller_rabin;
use strength_reduce::StrengthReducedU64;
use crate::math_utils;
use crate::{common::FftNum, twiddles, FftDirection};
use crate::{Direction, Fft, Length};
/// FFT implementation based on Rader's algorithm.
///
/// An instance of length `len` wraps an inner FFT of length `len - 1` and runs it twice per
/// transform. The constructor requires `len` to be prime.
pub struct RadersAlgorithm<T> {
/// Inner FFT of length `len - 1`; its direction is also the direction of this instance.
inner_fft: Arc<dyn Fft<T>>,
/// Precomputed multiplier table of length `len - 1`: twiddle factors taken at successive
/// powers of `primitive_root_inverse` (mod `len`), scaled by `1 / (len - 1)` and then
/// transformed once by `inner_fft`.
inner_fft_data: Box<[Complex<T>]>,
/// Value returned by `math_utils::primitive_root(len)`; its successive powers (mod `len`)
/// give the order in which input elements are gathered.
primitive_root: u64,
/// Multiplicative inverse of `primitive_root` modulo `len`, normalized to be non-negative;
/// its successive powers (mod `len`) give the order in which results are scattered.
primitive_root_inverse: u64,
/// FFT length, stored as a `StrengthReducedU64` because it is the divisor of every `%`
/// performed while walking the permutations.
len: StrengthReducedU64,
/// Scratch length for in-place processing: `len - 1` working elements plus
/// `outofplace_scratch_len`.
inplace_scratch_len: usize,
/// Scratch length for out-of-place processing: zero when the inner FFT's in-place scratch
/// requirement is at most `len - 1`, otherwise that requirement.
outofplace_scratch_len: usize,
/// Scratch length for immutable-input processing: `len - 1` working elements plus the inner
/// FFT's full in-place scratch requirement.
immut_scratch_len: usize,
/// Transform direction, copied from `inner_fft.fft_direction()` at construction.
direction: FftDirection,
}
impl<T: FftNum> RadersAlgorithm<T> {
pub fn new(inner_fft: Arc<dyn Fft<T>>) -> Self {
let inner_fft_len = inner_fft.len();
let len = inner_fft_len + 1;
assert!(miller_rabin(len as u64), "For raders algorithm, inner_fft.len() + 1 must be prime. Expected prime number, got {} + 1 = {}", inner_fft_len, len);
let direction = inner_fft.fft_direction();
let reduced_len = StrengthReducedU64::new(len as u64);
let primitive_root = math_utils::primitive_root(len as u64).unwrap();
let gcd_data = i64::extended_gcd(&(primitive_root as i64), &(len as i64));
let primitive_root_inverse = if gcd_data.x >= 0 {
gcd_data.x
} else {
gcd_data.x + len as i64
} as u64;
let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
let mut twiddle_input = 1;
for input_cell in &mut inner_fft_input {
let twiddle = twiddles::compute_twiddle(twiddle_input, len, direction);
*input_cell = twiddle * inner_fft_scale;
twiddle_input =
((twiddle_input as u64 * primitive_root_inverse) % reduced_len) as usize;
}
let required_inner_scratch = inner_fft.get_inplace_scratch_len();
let extra_inner_scratch = if required_inner_scratch <= inner_fft_len {
0
} else {
required_inner_scratch
};
let inplace_scratch_len = inner_fft_len + extra_inner_scratch;
let immut_scratch_len = inner_fft_len + required_inner_scratch;
let mut inner_fft_scratch = vec![Zero::zero(); required_inner_scratch];
inner_fft.process_with_scratch(&mut inner_fft_input, &mut inner_fft_scratch);
Self {
inner_fft,
inner_fft_data: inner_fft_input.into_boxed_slice(),
primitive_root,
primitive_root_inverse,
len: reduced_len,
inplace_scratch_len,
outofplace_scratch_len: extra_inner_scratch,
immut_scratch_len,
direction,
}
}
/// Computes the transform of `input` into `output`, leaving `input` untouched.
///
/// `scratch` is split into a working buffer of `self.len() - 1` elements followed by the
/// scratch handed to the inner FFT.
///
/// Panics if `input` or `output` is empty, or if `scratch` is shorter than `self.len() - 1`.
fn perform_fft_immut(
    &self,
    input: &[Complex<T>],
    output: &mut [Complex<T>],
    scratch: &mut [Complex<T>],
) {
    let (output_first, output_rest) = output.split_first_mut().unwrap();
    let (input_first, input_rest) = input.split_first().unwrap();
    let (work, inner_scratch) = scratch.split_at_mut(self.len() - 1);

    // Gather input[1..] into the working buffer, visiting indices in the order of successive
    // powers of the primitive root.
    let mut gather_index: u64 = 1;
    for slot in work.iter_mut() {
        gather_index = (gather_index * self.primitive_root) % self.len;
        *slot = input_rest[(gather_index as usize) - 1];
    }

    // First inner FFT.
    self.inner_fft.process_with_scratch(work, inner_scratch);

    // Element 0 of the inner FFT result plus the first input gives the first output.
    *output_first = *input_first + work[0];

    // Multiply by the precomputed table and conjugate each product, ahead of the second
    // inner FFT pass.
    work.iter_mut()
        .zip(self.inner_fft_data.iter().copied())
        .for_each(|(cell, factor)| *cell = (*cell * factor).conj());

    // Fold the (conjugated) first input into element 0 before the second pass.
    work[0] = work[0] + input_first.conj();

    // Second inner FFT.
    self.inner_fft.process_with_scratch(work, inner_scratch);

    // Conjugate the results and scatter them into output[1..], visiting indices in the order
    // of successive powers of the inverse primitive root.
    let mut scatter_index: u64 = 1;
    for value in work.iter() {
        scatter_index = (scatter_index * self.primitive_root_inverse) % self.len;
        output_rest[(scatter_index as usize) - 1] = value.conj();
    }
}
/// Computes the transform of `input` into `output`, overwriting `input` in the process.
///
/// When `scratch` is empty, whichever of the two data buffers is idle at the time is lent to
/// the inner FFT as its scratch space; otherwise `scratch` is used for both inner passes.
///
/// Panics if `input` or `output` is empty.
fn perform_fft_out_of_place(
    &self,
    input: &mut [Complex<T>],
    output: &mut [Complex<T>],
    scratch: &mut [Complex<T>],
) {
    let (output_first, output_rest) = output.split_first_mut().unwrap();
    let (input_first, input_rest) = input.split_first_mut().unwrap();

    // Gather input[1..] into output[1..], visiting indices in the order of successive powers
    // of the primitive root.
    let mut gather_index: u64 = 1;
    for slot in output_rest.iter_mut() {
        gather_index = (gather_index * self.primitive_root) % self.len;
        *slot = input_rest[(gather_index as usize) - 1];
    }

    // First inner FFT, on the output buffer. The tail of the input has already been copied
    // out, so it can serve as scratch when no dedicated scratch was supplied.
    if scratch.is_empty() {
        self.inner_fft.process_with_scratch(output_rest, input_rest);
    } else {
        self.inner_fft.process_with_scratch(output_rest, scratch);
    }

    // Element 0 of the inner FFT result plus the first input gives the first output.
    *output_first = *input_first + output_rest[0];

    // Multiply by the precomputed table and conjugate, writing the products back into the
    // input buffer, which the second inner FFT pass operates on.
    for (cell, (&spectrum, &factor)) in input_rest
        .iter_mut()
        .zip(output_rest.iter().zip(self.inner_fft_data.iter()))
    {
        *cell = (spectrum * factor).conj();
    }

    // Fold the (conjugated) first input into element 0 before the second pass.
    input_rest[0] = input_rest[0] + input_first.conj();

    // Second inner FFT, on the input buffer. The output tail is fully rewritten below, so it
    // can serve as scratch when no dedicated scratch was supplied.
    if scratch.is_empty() {
        self.inner_fft.process_with_scratch(input_rest, output_rest);
    } else {
        self.inner_fft.process_with_scratch(input_rest, scratch);
    }

    // Conjugate the results and scatter them into output[1..], visiting indices in the order
    // of successive powers of the inverse primitive root.
    let mut scatter_index: u64 = 1;
    for value in input_rest.iter() {
        scatter_index = (scatter_index * self.primitive_root_inverse) % self.len;
        output_rest[(scatter_index as usize) - 1] = value.conj();
    }
}
/// Computes the transform of `buffer` in place.
///
/// The first `self.len() - 1` elements of `scratch` hold the working copy of the data. Any
/// remaining elements are handed to the inner FFT as its scratch; when there are none, the
/// tail of `buffer` is lent to the inner FFT instead.
///
/// Panics if `buffer` is empty or if `scratch` is shorter than `self.len() - 1`.
fn perform_fft_inplace(&self, buffer: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
    let (buffer_first, buffer_rest) = buffer.split_first_mut().unwrap();
    // Keep a copy of the first element; it is needed again after `*buffer_first` is overwritten.
    let first_input = *buffer_first;

    let (work, extra_scratch) = scratch.split_at_mut(self.len() - 1);

    // Gather buffer[1..] into the working buffer, visiting indices in the order of successive
    // powers of the primitive root.
    let mut gather_index: u64 = 1;
    for slot in work.iter_mut() {
        gather_index = (gather_index * self.primitive_root) % self.len;
        *slot = buffer_rest[(gather_index as usize) - 1];
    }

    // The buffer tail has been copied out and is rewritten at the end, so it may be clobbered
    // as inner scratch whenever no extra scratch was provided.
    let lend_buffer = extra_scratch.is_empty();

    // First inner FFT.
    if lend_buffer {
        self.inner_fft.process_with_scratch(work, buffer_rest);
    } else {
        self.inner_fft.process_with_scratch(work, extra_scratch);
    }

    // Element 0 of the inner FFT result plus the first input gives the first output.
    *buffer_first = first_input + work[0];

    // Multiply by the precomputed table and conjugate each product, ahead of the second
    // inner FFT pass.
    work.iter_mut()
        .zip(self.inner_fft_data.iter().copied())
        .for_each(|(cell, factor)| *cell = (*cell * factor).conj());

    // Fold the (conjugated) original first element into element 0 before the second pass.
    work[0] = work[0] + first_input.conj();

    // Second inner FFT.
    if lend_buffer {
        self.inner_fft.process_with_scratch(work, buffer_rest);
    } else {
        self.inner_fft.process_with_scratch(work, extra_scratch);
    }

    // Conjugate the results and scatter them back into buffer[1..], visiting indices in the
    // order of successive powers of the inverse primitive root.
    let mut scatter_index: u64 = 1;
    for value in work.iter() {
        scatter_index = (scatter_index * self.primitive_root_inverse) % self.len;
        buffer_rest[(scatter_index as usize) - 1] = value.conj();
    }
}
}
// Registers `RadersAlgorithm` with the crate's `boilerplate_fft!` macro (defined elsewhere).
// NOTE(review): the macro presumably generates the `Fft` / `Length` / `Direction` impls that
// forward to the `perform_fft_*` methods above — confirm against the macro definition.
// The four closures return, in order: the FFT length (unwrapped from the strength-reduced
// divisor), then the `inplace_scratch_len`, `outofplace_scratch_len` and `immut_scratch_len`
// fields computed in `new`.
boilerplate_fft!(
RadersAlgorithm,
|this: &RadersAlgorithm<_>| this.len.get() as usize,
|this: &RadersAlgorithm<_>| this.inplace_scratch_len,
|this: &RadersAlgorithm<_>| this.outofplace_scratch_len,
|this: &RadersAlgorithm<_>| this.immut_scratch_len
);
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use crate::FftPlanner;
    use std::sync::Arc;

    /// Builds a `RadersAlgorithm` of size `len` on top of a `Dft` of size `len - 1` and hands
    /// it to `check_fft_algorithm`.
    fn test_raders_with_length(len: usize, direction: FftDirection) {
        let inner_fft = Arc::new(Dft::new(len - 1, direction));
        let fft = RadersAlgorithm::new(inner_fft);
        check_fft_algorithm::<f32>(&fft, len, direction);
    }

    /// Checks every prime length in `3..100`, in both directions.
    #[test]
    fn test_raders() {
        let directions = [FftDirection::Forward, FftDirection::Inverse];
        for len in (3_usize..100).filter(|&candidate| miller_rabin(candidate as u64)) {
            for &direction in directions.iter() {
                test_raders_with_length(len, direction);
            }
        }
    }

    /// Constructs and runs instances for a few large prime lengths.
    /// NOTE(review): going by the test name, this guards against arithmetic overflow on
    /// 32-bit targets — confirm.
    #[test]
    fn test_raders_32bit_overflow() {
        const LARGE_PRIMES: [usize; 3] = [112501, 216569, 417623];

        let mut planner = FftPlanner::<f32>::new();
        for len in LARGE_PRIMES {
            let inner_fft = planner.plan_fft_forward(len - 1);
            let fft = RadersAlgorithm::<f32>::new(inner_fft);
            let mut data = vec![Complex::new(0.0_f32, 0.0_f32); len];
            fft.process(&mut data);
        }
    }
}