use num_integer::gcd;
use std::collections::HashMap;
use std::sync::Arc;
use crate::common::FFTnum;
use crate::algorithm::butterflies::*;
use crate::algorithm::*;
use crate::FFT;
use crate::math_utils;
const MIN_RADIX4_BITS: u32 = 5; const MAX_RADIX4_BITS: u32 = 16; const BUTTERFLIES: [usize; 16] = [2, 3, 4, 5, 6, 7, 8, 11, 13, 16, 17, 19, 23, 29, 31, 32];
const COMPOSITE_BUTTERFLIES: [usize; 16] =
[2, 3, 4, 5, 6, 7, 8, 11, 13, 16, 17, 19, 23, 29, 31, 32];
const MAX_RADER_PRIME_FACTOR: usize = 23; const MIN_BLUESTEIN_MIXED_RADIX_LEN: usize = 90;
pub struct FFTplanner<T> {
inverse: bool,
algorithm_cache: HashMap<usize, Arc<dyn FFT<T>>>,
butterfly_cache: HashMap<usize, Arc<dyn FFTButterfly<T>>>,
}
impl<T: FFTnum> FFTplanner<T> {
pub fn new(inverse: bool) -> Self {
Self {
inverse,
algorithm_cache: HashMap::new(),
butterfly_cache: HashMap::new(),
}
}
pub fn plan_fft(&mut self, len: usize) -> Arc<dyn FFT<T>> {
if len < 2 {
Arc::new(DFT::new(len, self.inverse)) as Arc<dyn FFT<T>>
} else {
let factors = math_utils::prime_factors(len);
self.plan_fft_with_factors(len, &factors)
}
}
fn plan_butterfly(&mut self, len: usize) -> Arc<dyn FFTButterfly<T>> {
let inverse = self.inverse;
let instance = self
.butterfly_cache
.entry(len)
.or_insert_with(|| match len {
2 => Arc::new(Butterfly2::new(inverse)),
3 => Arc::new(Butterfly3::new(inverse)),
4 => Arc::new(Butterfly4::new(inverse)),
5 => Arc::new(Butterfly5::new(inverse)),
6 => Arc::new(Butterfly6::new(inverse)),
7 => Arc::new(Butterfly7::new(inverse)),
8 => Arc::new(Butterfly8::new(inverse)),
11 => Arc::new(Butterfly11::new(inverse)),
13 => Arc::new(Butterfly13::new(inverse)),
16 => Arc::new(Butterfly16::new(inverse)),
17 => Arc::new(Butterfly17::new(inverse)),
19 => Arc::new(Butterfly19::new(inverse)),
23 => Arc::new(Butterfly23::new(inverse)),
29 => Arc::new(Butterfly29::new(inverse)),
31 => Arc::new(Butterfly31::new(inverse)),
32 => Arc::new(Butterfly32::new(inverse)),
_ => panic!("Invalid butterfly size: {}", len),
});
Arc::clone(instance)
}
fn plan_fft_with_factors(&mut self, len: usize, factors: &[usize]) -> Arc<dyn FFT<T>> {
if self.algorithm_cache.contains_key(&len) {
Arc::clone(self.algorithm_cache.get(&len).unwrap())
} else {
let result = if factors.len() == 1 || COMPOSITE_BUTTERFLIES.contains(&len) {
self.plan_fft_single_factor(len)
} else if len.trailing_zeros() <= MAX_RADIX4_BITS
&& len.trailing_zeros() >= MIN_RADIX4_BITS
{
if len.is_power_of_two() {
Arc::new(Radix4::new(len, self.inverse))
} else {
let left_len = 1 << len.trailing_zeros();
let right_len = len / left_len;
let (left_factors, right_factors) =
factors.split_at(len.trailing_zeros() as usize);
self.plan_mixed_radix(left_len, left_factors, right_len, right_factors)
}
} else {
let sqrt = (len as f32).sqrt() as usize;
if sqrt * sqrt == len {
let mut sqrt_factors = Vec::with_capacity(factors.len() / 2);
for chunk in factors.chunks_exact(2) {
sqrt_factors.push(chunk[0]);
}
self.plan_mixed_radix(sqrt, &sqrt_factors, sqrt, &sqrt_factors)
} else {
let mut product = 1;
let mut second_half_index = 1;
for (i, factor) in factors.iter().enumerate() {
if product * *factor > sqrt {
second_half_index = i;
break;
} else {
product = product * *factor;
}
}
if product * factors[second_half_index] < len / product {
product = product * factors[second_half_index];
second_half_index = second_half_index + 1;
}
let (left_factors, right_factors) = factors.split_at(second_half_index);
self.plan_mixed_radix(product, left_factors, len / product, right_factors)
}
};
self.algorithm_cache.insert(len, Arc::clone(&result));
result
}
}
fn plan_mixed_radix(
&mut self,
left_len: usize,
left_factors: &[usize],
right_len: usize,
right_factors: &[usize],
) -> Arc<dyn FFT<T>> {
let left_is_butterfly = BUTTERFLIES.contains(&left_len);
let right_is_butterfly = BUTTERFLIES.contains(&right_len);
if left_is_butterfly && right_is_butterfly {
let left_fft = self.plan_butterfly(left_len);
let right_fft = self.plan_butterfly(right_len);
if gcd(left_len, right_len) == 1 {
Arc::new(GoodThomasAlgorithmDoubleButterfly::new(left_fft, right_fft))
as Arc<dyn FFT<T>>
} else {
Arc::new(MixedRadixDoubleButterfly::new(left_fft, right_fft)) as Arc<dyn FFT<T>>
}
} else {
let left_fft = self.plan_fft_with_factors(left_len, left_factors);
let right_fft = self.plan_fft_with_factors(right_len, right_factors);
Arc::new(MixedRadix::new(left_fft, right_fft)) as Arc<dyn FFT<T>>
}
}
fn plan_fft_single_factor(&mut self, len: usize) -> Arc<dyn FFT<T>> {
match len {
0 | 1 => Arc::new(DFT::new(len, self.inverse)) as Arc<dyn FFT<T>>,
2 => Arc::new(butterflies::Butterfly2::new(self.inverse)) as Arc<dyn FFT<T>>,
3 => Arc::new(butterflies::Butterfly3::new(self.inverse)) as Arc<dyn FFT<T>>,
4 => Arc::new(butterflies::Butterfly4::new(self.inverse)) as Arc<dyn FFT<T>>,
5 => Arc::new(butterflies::Butterfly5::new(self.inverse)) as Arc<dyn FFT<T>>,
6 => Arc::new(butterflies::Butterfly6::new(self.inverse)) as Arc<dyn FFT<T>>,
7 => Arc::new(butterflies::Butterfly7::new(self.inverse)) as Arc<dyn FFT<T>>,
8 => Arc::new(butterflies::Butterfly8::new(self.inverse)) as Arc<dyn FFT<T>>,
11 => Arc::new(butterflies::Butterfly11::new(self.inverse)) as Arc<dyn FFT<T>>,
13 => Arc::new(butterflies::Butterfly13::new(self.inverse)) as Arc<dyn FFT<T>>,
16 => Arc::new(butterflies::Butterfly16::new(self.inverse)) as Arc<dyn FFT<T>>,
17 => Arc::new(butterflies::Butterfly17::new(self.inverse)) as Arc<dyn FFT<T>>,
19 => Arc::new(butterflies::Butterfly19::new(self.inverse)) as Arc<dyn FFT<T>>,
23 => Arc::new(butterflies::Butterfly23::new(self.inverse)) as Arc<dyn FFT<T>>,
29 => Arc::new(butterflies::Butterfly29::new(self.inverse)) as Arc<dyn FFT<T>>,
31 => Arc::new(butterflies::Butterfly31::new(self.inverse)) as Arc<dyn FFT<T>>,
32 => Arc::new(butterflies::Butterfly32::new(self.inverse)) as Arc<dyn FFT<T>>,
_ => self.plan_prime(len),
}
}
fn plan_prime(&mut self, len: usize) -> Arc<dyn FFT<T>> {
let inner_fft_len_rader = len - 1;
let factors = math_utils::prime_factors(inner_fft_len_rader);
if factors.iter().any(|val| *val > MAX_RADER_PRIME_FACTOR) {
let inner_fft_len_pow2 = (2 * len - 1).checked_next_power_of_two().unwrap();
let min_inner_len = 2 * len - 1;
let mixed_radix_len = 3 * inner_fft_len_pow2 / 4;
let inner_fft =
if mixed_radix_len >= min_inner_len && len >= MIN_BLUESTEIN_MIXED_RADIX_LEN {
let inner_factors = math_utils::prime_factors(mixed_radix_len);
self.plan_fft_with_factors(mixed_radix_len, &inner_factors)
} else {
Arc::new(Radix4::new(inner_fft_len_pow2, self.inverse))
};
Arc::new(Bluesteins::new(len, inner_fft)) as Arc<dyn FFT<T>>
} else {
let inner_fft = self.plan_fft_with_factors(inner_fft_len_rader, &factors);
Arc::new(RadersAlgorithm::new(len, inner_fft)) as Arc<dyn FFT<T>>
}
}
}