use std::sync::Arc;
use num_complex::Complex;
use num_traits::Zero;
use crate::{common::FftNum, twiddles, FftDirection};
use crate::{Direction, Fft, Length};
/// FFT of arbitrary length computed with Bluestein's algorithm: the transform is evaluated as
/// a circular convolution carried out by a larger inner FFT.
///
/// The inner FFT must be at least `len * 2 - 1` long (asserted when the instance is built).
pub struct BluesteinsAlgorithm<T> {
    /// Length of the FFT this instance computes.
    len: usize,
    /// Direction of the FFT this instance computes; taken from the inner FFT.
    direction: FftDirection,

    /// FFT used to evaluate the convolution (length >= `len * 2 - 1`).
    inner_fft: Arc<dyn Fft<T>>,
    /// Precomputed, pre-scaled convolution kernel in the inner FFT's frequency domain;
    /// one entry per inner FFT element.
    inner_fft_multiplier: Box<[Complex<T>]>,
    /// `len` twiddles applied to the input before, and to the output after, the convolution.
    twiddles: Box<[Complex<T>]>,
}
impl<T: FftNum> BluesteinsAlgorithm<T> {
    /// Creates an FFT instance that computes FFTs of size `len` using Bluestein's algorithm.
    ///
    /// The direction of the resulting FFT is taken from `inner_fft`. All per-instance data is
    /// precomputed here:
    /// - the convolution kernel, transformed once by `inner_fft`;
    /// - the twiddles applied before and after the convolution.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero, or if `inner_fft.len() < len * 2 - 1`.
    pub fn new(len: usize, inner_fft: Arc<dyn Fft<T>>) -> Self {
        // `len * 2 - 1` underflows for `len == 0`. Reject it with a clear message instead of
        // an arithmetic-overflow panic (debug) or a nonsensical bound of usize::MAX (release).
        assert!(len > 0, "Bluestein's algorithm requires a nonzero FFT length");
        let inner_fft_len = inner_fft.len();
        let min_inner_fft_len = len * 2 - 1;
        assert!(
            min_inner_fft_len <= inner_fft_len,
            "Bluestein's algorithm requires inner_fft.len() >= self.len() * 2 - 1. Expected >= {}, got {}",
            min_inner_fft_len,
            inner_fft_len
        );

        // The inner FFT is unnormalized. Fold the 1/N factor of the convolution's inverse
        // transform into the precomputed kernel so it costs nothing per call.
        let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
        let direction = inner_fft.fft_direction();

        // Build the kernel:
        // - opposite-direction twiddles in the first `len` slots;
        // - entries 1..len mirrored into the last `len - 1` slots, so entry i equals
        //   entry N - i;
        // - zeros in between.
        let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
        twiddles::fill_bluesteins_twiddles(
            &mut inner_fft_input[..len],
            direction.opposite_direction(),
        );
        inner_fft_input[0] = inner_fft_input[0] * inner_fft_scale;
        for i in 1..len {
            let twiddle = inner_fft_input[i] * inner_fft_scale;
            inner_fft_input[i] = twiddle;
            inner_fft_input[inner_fft_len - i] = twiddle;
        }

        // Transform the kernel once, so each FFT only needs a pointwise multiply.
        let mut inner_fft_scratch = vec![Complex::zero(); inner_fft.get_inplace_scratch_len()];
        inner_fft.process_with_scratch(&mut inner_fft_input, &mut inner_fft_scratch);

        // Twiddles applied to the input before, and to the output after, the convolution.
        let mut twiddles = vec![Complex::zero(); len];
        twiddles::fill_bluesteins_twiddles(&mut twiddles, direction);

        Self {
            inner_fft,
            inner_fft_multiplier: inner_fft_input.into_boxed_slice(),
            twiddles: twiddles.into_boxed_slice(),
            len,
            direction,
        }
    }

    /// Prepares the zero-padded convolution input.
    ///
    /// Writes `input[i] * twiddles[i]` into the first `input.len()` entries of `inner_buffer`
    /// and zeroes the rest.
    ///
    /// `input.len()` must not exceed `inner_buffer.len()`. This always holds for chunks of
    /// `self.len`, because the inner FFT is at least `len * 2 - 1` long.
    fn load_inner_buffer(&self, input: &[Complex<T>], inner_buffer: &mut [Complex<T>]) {
        let (head, padding) = inner_buffer.split_at_mut(input.len());
        for ((inner_entry, buffer_entry), twiddle) in head
            .iter_mut()
            .zip(input.iter())
            .zip(self.twiddles.iter())
        {
            *inner_entry = *buffer_entry * *twiddle;
        }
        for inner_entry in padding.iter_mut() {
            *inner_entry = Complex::zero();
        }
    }

    /// Convolves `inner_buffer`, in place, with the precomputed kernel.
    ///
    /// Steps:
    /// 1. Forward inner FFT.
    /// 2. Pointwise multiply by `inner_fft_multiplier`, conjugating the product.
    /// 3. A second inner FFT.
    ///
    /// Conjugating before and after an FFT yields the opposite-direction transform. This
    /// function performs only the "before" conjugation, so the result left in `inner_buffer`
    /// is still conjugated; the caller must conjugate it.
    fn convolve_inner_buffer(
        &self,
        inner_buffer: &mut [Complex<T>],
        inner_scratch: &mut [Complex<T>],
    ) {
        self.inner_fft
            .process_with_scratch(inner_buffer, inner_scratch);
        for (inner, multiplier) in inner_buffer.iter_mut().zip(self.inner_fft_multiplier.iter()) {
            *inner = (*inner * *multiplier).conj();
        }
        self.inner_fft
            .process_with_scratch(inner_buffer, inner_scratch);
    }

    /// Finishes the transform: `output[i] = conj(inner_buffer[i]) * twiddles[i]`.
    ///
    /// The conjugate completes the inverse transform begun in `convolve_inner_buffer`. Only
    /// `output.len()` entries of the buffer are read.
    fn store_output(&self, inner_buffer: &[Complex<T>], output: &mut [Complex<T>]) {
        for ((buffer_entry, inner_entry), twiddle) in output
            .iter_mut()
            .zip(inner_buffer.iter())
            .zip(self.twiddles.iter())
        {
            *buffer_entry = inner_entry.conj() * *twiddle;
        }
    }

    /// Computes one FFT of `input` in place.
    ///
    /// `scratch` must hold at least `inner_fft_multiplier.len()` elements plus the inner
    /// FFT's in-place scratch.
    fn perform_fft_inplace(&self, input: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
        let (inner_buffer, inner_scratch) = scratch.split_at_mut(self.inner_fft_multiplier.len());
        self.load_inner_buffer(input, inner_buffer);
        self.convolve_inner_buffer(inner_buffer, inner_scratch);
        self.store_output(inner_buffer, input);
    }

    /// Computes one FFT of `input`, writing the result to `output` and leaving `input`
    /// untouched.
    ///
    /// `scratch` has the same size requirement as for `perform_fft_inplace`.
    #[inline]
    fn perform_fft_immut(
        &self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) {
        let (inner_buffer, inner_scratch) = scratch.split_at_mut(self.inner_fft_multiplier.len());
        self.load_inner_buffer(input, inner_buffer);
        self.convolve_inner_buffer(inner_buffer, inner_scratch);
        self.store_output(inner_buffer, output);
    }

    /// Out-of-place FFT.
    ///
    /// `input` is only read, so this delegates to `perform_fft_immut`.
    fn perform_fft_out_of_place(
        &self,
        input: &mut [Complex<T>],
        output: &mut [Complex<T>],
        scratch: &mut [Complex<T>],
    ) {
        self.perform_fft_immut(input, output, scratch);
    }
}
// NOTE(review): this macro presumably generates the `Fft`/`Length`/`Direction` impls from the
// inherent `perform_fft_*` methods — confirm against the macro definition.
//
// The first closure reports the FFT length. The remaining three scratch-length closures are
// identical; each is the sum of:
// - one inner-FFT-sized buffer (`inner_fft_multiplier` has the inner FFT's length) for the
//   zero-padded convolution input;
// - the inner FFT's own in-place scratch.
boilerplate_fft!(
    BluesteinsAlgorithm,
    |algo: &BluesteinsAlgorithm<_>| algo.len,
    |algo: &BluesteinsAlgorithm<_>| {
        algo.inner_fft_multiplier.len() + algo.inner_fft.get_inplace_scratch_len()
    },
    |algo: &BluesteinsAlgorithm<_>| {
        algo.inner_fft_multiplier.len() + algo.inner_fft.get_inplace_scratch_len()
    },
    |algo: &BluesteinsAlgorithm<_>| {
        algo.inner_fft_multiplier.len() + algo.inner_fft.get_inplace_scratch_len()
    }
);
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    /// Checks Bluestein's algorithm against the reference checker. Coverage:
    /// - the trivial lengths 1 and 2, plus prime and composite lengths;
    /// - both FFT directions;
    /// - two inner FFT sizes: a power-of-two size, and the smallest size the constructor
    ///   accepts (`len * 2 - 1`). The smallest size is where the mirrored kernel entries in
    ///   `new` sit directly next to the leading ones.
    #[test]
    fn test_bluesteins_scalar() {
        for &len in &[1, 2, 3, 4, 5, 6, 7, 9, 11, 12, 13] {
            let min_inner_len = len * 2 - 1;
            let pow2_inner_len = min_inner_len.checked_next_power_of_two().unwrap();
            for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                test_bluesteins_with_length(len, pow2_inner_len, direction);
                test_bluesteins_with_length(len, min_inner_len, direction);
            }
        }
    }

    /// Builds a size-`len` Bluestein FFT on a naive `Dft` of size `inner_len` and checks it.
    fn test_bluesteins_with_length(len: usize, inner_len: usize, direction: FftDirection) {
        let inner_fft = Arc::new(Dft::new(inner_len, direction));
        let fft = BluesteinsAlgorithm::new(len, inner_fft);

        check_fft_algorithm::<f32>(&fft, len, direction);
    }
}