use std::sync::Arc;
use num_complex::Complex;
use num_traits::Zero;
use crate::common::{verify_length, verify_length_divisible, FFTnum};
use crate::{IsInverse, Length, FFT};
/// An FFT of arbitrary length computed with Bluestein's (chirp-z) algorithm.
///
/// The length-`len` transform is evaluated as a circular convolution carried out with
/// `inner_fft`, whose length must be at least `2 * len - 1`.
pub struct Bluesteins<T> {
    /// Length of the transform this instance computes.
    len: usize,
    /// FFT used to evaluate the convolution; its direction decides this FFT's direction.
    inner_fft: Arc<dyn FFT<T>>,
    /// Transformed chirp filter, one entry per inner-FFT element.
    w_twiddles: Box<[Complex<T>]>,
    /// Chirp multiplied onto the input and output, one entry per element of the transform.
    x_twiddles: Box<[Complex<T>]>,
}
/// Returns the chirp factor `exp(-i * pi * index / len)`.
///
/// The factor is periodic in `index` with period `2 * len`, so `index` is first reduced
/// modulo `2 * len`. Floating-point `%` (fmod) is exact, so this reduction loses nothing.
/// It keeps the angle inside `(-2pi, 2pi)`. Callers pass squared sample indices, which grow
/// quadratically. Without the reduction, the rounding error of `index * PI` grows with them
/// and degrades the twiddles for long transforms.
///
/// `index` is expected to be an exactly representable integer value (callers pass `n^2`
/// or `-(n^2)` as `f64`).
fn calculate_twiddle<T: FFTnum>(index: f64, len: usize) -> Complex<T> {
    let period = 2.0 * len as f64;
    let reduced_index = index % period;
    let theta = reduced_index * core::f64::consts::PI / len as f64;
    Complex::new(
        T::from_f64(theta.cos()).unwrap(),
        T::from_f64(-theta.sin()).unwrap(),
    )
}
/// Builds the transformed chirp filter used by Bluestein's algorithm and writes it into
/// `twiddles`.
///
/// The time-domain filter has `fft.len()` entries. Entry `i` is
/// `conj(calculate_twiddle(i^2, len)) / fft.len()` for `i < len`. Its mirror image is stored
/// at the back of the buffer (entries `i > fft.len() - len`, using `(i - fft.len())^2`).
/// Everything in between is zero. The buffer is transformed with `fft`. If `fft` is an
/// inverse transform, the result is conjugated afterwards.
fn calculate_w_twiddles<T: FFTnum>(len: usize, fft: &Arc<dyn FFT<T>>, twiddles: &mut [Complex<T>]) {
    let fft_len = fft.len();
    let scale = T::one() / T::from_usize(fft_len).unwrap();

    let mut filter: Vec<Complex<T>> = (0..fft_len)
        .map(|i| {
            let chirp_index = if i < len {
                Some((i as f64).powi(2))
            } else if i > fft_len - len {
                Some(((i as f64) - (fft_len as f64)).powi(2))
            } else {
                None
            };
            match chirp_index {
                Some(index) => calculate_twiddle::<T>(index, len).conj() * scale,
                None => Complex::zero(),
            }
        })
        .collect();

    fft.process(&mut filter, twiddles);

    if fft.is_inverse() {
        twiddles.iter_mut().for_each(|tw| *tw = tw.conj());
    }
}
/// Fills `twiddles[i]` with the input/output chirp `calculate_twiddle(-(i^2), len)`.
///
/// For forward transforms (`inverse == false`), each value is conjugated.
fn calculate_x_twiddles<T: FFTnum>(len: usize, twiddles: &mut [Complex<T>], inverse: bool) {
    for (i, slot) in twiddles.iter_mut().enumerate() {
        let chirp = calculate_twiddle::<T>(-(i as f64).powi(2), len);
        *slot = if inverse { chirp } else { chirp.conj() };
    }
}
impl<T: FFTnum> Bluesteins<T> {
    /// Creates an FFT of size `len` that runs Bluestein's algorithm on top of `inner_fft`.
    ///
    /// The direction (forward or inverse) is taken from `inner_fft.is_inverse()`.
    ///
    /// # Panics
    ///
    /// Panics if `inner_fft.len()` is smaller than `2 * len - 1`. `len == 0` is not
    /// supported, because the bound `2 * len - 1` underflows.
    pub fn new(len: usize, inner_fft: Arc<dyn FFT<T>>) -> Self {
        let min_inner_fft_len = 2 * len - 1;
        assert!(inner_fft.len() >= min_inner_fft_len, "For Bluesteins algorithm, inner_fft.len() must be equal to or larger than 2*self.len() - 1. Expected at least {}, got {}", min_inner_fft_len, inner_fft.len());
        let mut w_twiddles = vec![Complex::zero(); inner_fft.len()];
        let mut x_twiddles = vec![Complex::zero(); len];
        calculate_w_twiddles(len, &inner_fft, &mut w_twiddles);
        calculate_x_twiddles(len, &mut x_twiddles, inner_fft.is_inverse());
        Self {
            len,
            inner_fft,
            w_twiddles: w_twiddles.into_boxed_slice(),
            x_twiddles: x_twiddles.into_boxed_slice(),
        }
    }

    /// Number of scratch elements `perform_fft` needs: two buffers of the inner FFT's length.
    fn scratch_len(&self) -> usize {
        2 * self.inner_fft.len()
    }

    /// Computes one FFT of length `self.len()` from `input` into `output`.
    ///
    /// `scratch` must hold exactly `self.scratch_len()` elements. Its contents on entry are
    /// ignored and are unspecified on return, so one buffer can be reused across calls.
    fn perform_fft(
        &self,
        input: &mut [Complex<T>],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) {
        assert_eq!(self.len(), input.len());
        assert_eq!(self.scratch_len(), scratch.len());

        let inverse = self.inner_fft.is_inverse();
        let (scratch_a, scratch_b) = scratch.split_at_mut(self.inner_fft.len());

        // Multiply the input by the chirp and zero-pad up to the inner FFT length. The padding
        // must be cleared explicitly because `scratch` may still hold a previous call's data.
        {
            let (chirped, padding) = scratch_a.split_at_mut(self.len());
            if inverse {
                for (w, (x, i)) in chirped
                    .iter_mut()
                    .zip(self.x_twiddles.iter().zip(input.iter()))
                {
                    *w = (x * i).conj();
                }
            } else {
                for (w, (x, i)) in chirped
                    .iter_mut()
                    .zip(self.x_twiddles.iter().zip(input.iter()))
                {
                    *w = x * i;
                }
            }
            for w in padding.iter_mut() {
                *w = Complex::zero();
            }
        }

        // Transform and multiply by the precomputed filter spectrum. The conjugations here and
        // in the input/output stages let the single inner FFT act as both the forward and
        // the backward transform of the circular convolution.
        self.inner_fft.process(scratch_a, scratch_b);
        if inverse {
            for (w, wi) in scratch_b.iter_mut().zip(self.w_twiddles.iter()) {
                *w = w.conj() * wi;
            }
        } else {
            for (w, wi) in scratch_b.iter_mut().zip(self.w_twiddles.iter()) {
                *w = (*w * wi).conj();
            }
        }

        // Transform back, then multiply by the chirp again. Only the first `len` outputs are used.
        self.inner_fft.process(scratch_b, scratch_a);
        if inverse {
            for (i, (w, xi)) in output
                .iter_mut()
                .zip(scratch_a.iter().zip(self.x_twiddles.iter()))
            {
                *i = w * xi;
            }
        } else {
            for (i, (w, xi)) in output
                .iter_mut()
                .zip(scratch_a.iter().zip(self.x_twiddles.iter()))
            {
                *i = w.conj() * xi;
            }
        }
    }
}
impl<T: FFTnum> FFT<T> for Bluesteins<T> {
    fn process(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        verify_length(input, output, self.len());
        let mut scratch: Vec<Complex<T>> = vec![Complex::zero(); self.scratch_len()];
        self.perform_fft(input, output, &mut scratch);
    }
    fn process_multi(&self, input: &mut [Complex<T>], output: &mut [Complex<T>]) {
        verify_length_divisible(input, output, self.len());
        // Allocate the scratch space once and reuse it for every chunk instead of per chunk.
        let mut scratch: Vec<Complex<T>> = vec![Complex::zero(); self.scratch_len()];
        for (in_chunk, out_chunk) in input
            .chunks_exact_mut(self.len())
            .zip(output.chunks_exact_mut(self.len()))
        {
            self.perform_fft(in_chunk, out_chunk, &mut scratch);
        }
    }
}
impl<T> Length for Bluesteins<T> {
    /// Number of points in this transform. This is not the padded inner FFT length.
    #[inline(always)]
    fn len(&self) -> usize {
        let Self { len, .. } = self;
        *len
    }
}
impl<T> IsInverse for Bluesteins<T> {
    /// Reports the transform direction, which is inherited from the inner FFT.
    #[inline(always)]
    fn is_inverse(&self) -> bool {
        let inner = &self.inner_fft;
        inner.is_inverse()
    }
}
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::DFT;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    #[test]
    fn test_bluestein() {
        for &len in &[3, 5, 7, 11, 13, 123] {
            for &inverse in &[false, true] {
                // Usual configuration: a power-of-two inner FFT, which leaves a zero gap between
                // the front chirp and its mirrored copy in the filter.
                let padded_len = (2 * len - 1).checked_next_power_of_two().unwrap();
                test_bluestein_with_length(len, padded_len, inverse);
                // Smallest allowed inner FFT: the front and mirrored chirp regions abut exactly.
                test_bluestein_with_length(len, 2 * len - 1, inverse);
            }
        }
    }

    /// Checks a `len`-point Bluestein FFT built on a naive DFT of length `inner_fft_len`.
    fn test_bluestein_with_length(len: usize, inner_fft_len: usize, inverse: bool) {
        let inner_fft = Arc::new(DFT::new(inner_fft_len, inverse));
        let fft = Bluesteins::new(len, inner_fft);
        check_fft_algorithm(&fft, len, inverse);
    }
}