use crate::{Autosort, Fft, FftFloat, Transform};
use core::cell::RefCell;
use core::marker::PhantomData;
use num_complex::Complex;
#[cfg(not(feature = "std"))]
use num_traits::Float as _;
/// Returns the unit-magnitude "half" twiddle `exp(-i * pi * index / size)`, converted to `T`.
///
/// Bluestein's algorithm evaluates these at squared indices (`n^2`), which grow quadratically
/// with the transform size. Because the value has period `2 * size` in `index`, `index` is first
/// reduced modulo `2 * size`. This keeps the angle inside `(-2 * pi, 2 * pi)`, so `cos`/`sin` are
/// not evaluated on huge arguments, where they lose precision. Floating-point `%` is exact, so
/// for integral `index` values (as passed by the callers in this file) the reduction changes
/// nothing but the rounding error.
///
/// # Panics
/// Panics if the computed `f64` components cannot be converted to `T`.
fn compute_half_twiddle<T: FftFloat>(index: f64, size: usize) -> Complex<T> {
    let period = 2.0 * size as f64;
    let theta = (index % period) * core::f64::consts::PI / size as f64;
    Complex::new(
        T::from_f64(theta.cos()).unwrap(),
        T::from_f64(-theta.sin()).unwrap(),
    )
}
/// Builds the Bluestein convolution kernels and transforms them with `fft`.
///
/// Each of `forward_twiddles` and `inverse_twiddles` is extended with `fft.size()` values:
/// - position `i < size` holds the half twiddle for `i^2`;
/// - position `i > fft.size() - size` holds the half twiddle for `(i - fft.size())^2`, i.e. the
///   negative indices wrapped to the end of the buffer;
/// - every other position is zero.
///
/// The forward kernel receives the conjugated twiddle and the inverse kernel the plain one.
/// Finally both buffers are transformed in place with `fft.fft_in_place`.
///
/// NOTE(review): the in-place FFT runs over the whole `as_mut()` slice, so this assumes both
/// buffers start out empty — confirm at call sites.
fn initialize_w_twiddles<
    T: FftFloat,
    E: Extend<Complex<T>> + AsMut<[Complex<T>]>,
    F: Fft<Real = T>,
>(
    size: usize,
    fft: &F,
    forward_twiddles: &mut E,
    inverse_twiddles: &mut E,
) {
    let padded_len = fft.size();
    for position in 0..padded_len {
        // Squared (possibly wrapped-negative) index, or `None` for the zero-padded middle.
        let squared_index = if position < size {
            Some((position as f64).powi(2))
        } else if position > padded_len - size {
            Some(((position as f64) - (padded_len as f64)).powi(2))
        } else {
            None
        };

        let (forward_value, inverse_value) = match squared_index {
            Some(index) => {
                let twiddle = compute_half_twiddle(index, size);
                (twiddle.conj(), twiddle)
            }
            None => (Complex::default(), Complex::default()),
        };

        forward_twiddles.extend(core::iter::once(forward_value));
        inverse_twiddles.extend(core::iter::once(inverse_value));
    }

    fft.fft_in_place(forward_twiddles.as_mut());
    fft.fft_in_place(inverse_twiddles.as_mut());
}
/// Appends the `size` input/output chirp factors to the two buffers.
///
/// For each `n` in `0..size` the half twiddle at index `-(n^2)` is computed. Its conjugate is
/// appended to `forward_twiddles` and the twiddle itself to `inverse_twiddles`.
fn initialize_x_twiddles<T: FftFloat, E: Extend<Complex<T>>>(
    size: usize,
    forward_twiddles: &mut E,
    inverse_twiddles: &mut E,
) {
    (0..size)
        .map(|n| compute_half_twiddle::<T>(-(n as f64).powi(2), size))
        .for_each(|chirp| {
            forward_twiddles.extend(core::iter::once(chirp.conj()));
            inverse_twiddles.extend(core::iter::once(chirp));
        });
}
/// FFT of arbitrary length `size` using Bluestein's algorithm.
///
/// The input is multiplied by a chirp, circularly convolved with a chirp kernel using the
/// (padded) `InnerFft`, and multiplied by the chirp again.
pub struct Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work> {
    /// Length of the transforms this instance performs (reported by `Fft::size`).
    size: usize,
    /// FFT used to evaluate the circular convolution.
    inner_fft: InnerFft,
    /// Kernel multiplied into the spectrum for forward transforms; when built by
    /// `new_with_fft` this is the inner FFT of the conjugated chirp sequence.
    w_forward: WTwiddles,
    /// Kernel multiplied into the spectrum for inverse transforms (unconjugated chirp when
    /// built by `new_with_fft`).
    w_inverse: WTwiddles,
    /// Chirp applied to the input before and to the result after the convolution (forward).
    x_forward: XTwiddles,
    /// Chirp applied before and after the convolution for inverse transforms.
    x_inverse: XTwiddles,
    /// Scratch buffer for the convolution. The `RefCell` lets transforms run through `&self`;
    /// a nested transform on the same instance would panic on the second `borrow_mut`.
    work: RefCell<Work>,
    /// Records the real scalar type `T` without storing a value of it.
    real_type: PhantomData<T>,
}
impl<T, InnerFft, WTwiddles, XTwiddles, Work> Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work> {
    /// Assembles a `Bluesteins` directly from precomputed components, without any validation.
    ///
    /// # Safety
    /// NOTE(review): the exact contract is not stated anywhere visible. Presumably the caller
    /// must supply parts consistent with what `new_with_fft` would build for `size`:
    /// - an inner FFT of length at least `2 * size - 1`;
    /// - `w_*` buffers of that length holding the transformed kernels;
    /// - `x_*` buffers of length `size`;
    /// - a `work` buffer of the inner FFT's length.
    ///
    /// Confirm this before relying on it.
    pub unsafe fn new_from_parts(
        size: usize,
        inner_fft: InnerFft,
        w_forward: WTwiddles,
        w_inverse: WTwiddles,
        x_forward: XTwiddles,
        x_inverse: XTwiddles,
        work: Work,
    ) -> Self {
        let work = RefCell::new(work);
        Self {
            real_type: PhantomData,
            work,
            x_inverse,
            x_forward,
            w_inverse,
            w_forward,
            inner_fft,
            size,
        }
    }
}
impl<
        T: FftFloat,
        InnerFft: Fft<Real = T>,
        WTwiddles: Default + Extend<Complex<T>> + AsMut<[Complex<T>]>,
        XTwiddles: Default + Extend<Complex<T>>,
        Work: Default + Extend<Complex<T>>,
    > Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work>
{
    /// Creates a Bluestein FFT of length `size`, using `inner_fft_maker` to build the inner FFT.
    ///
    /// `inner_fft_maker` is called exactly once, with the padded convolution length. That length
    /// is the smallest power of two that is at least `2 * size - 1`, or `1` for the degenerate
    /// `size == 0`. The kernels and the scratch buffer are then sized from the returned FFT's
    /// `size()`, and the chirps from `size`.
    ///
    /// # Panics
    /// Panics if the padded length overflows `usize`, or if `inner_fft_maker` panics.
    pub fn new_with_fft<F: Fn(usize) -> InnerFft>(size: usize, inner_fft_maker: F) -> Self {
        // `2 * size - 1` points keep the circular convolution from wrapping onto itself.
        // `saturating_sub` maps `size == 0` to 0 (next power of two: 1) instead of underflowing.
        let inner_size = size
            .checked_mul(2)
            .map(|doubled| doubled.saturating_sub(1))
            .and_then(usize::checked_next_power_of_two)
            .expect("Bluestein inner FFT size overflows usize");
        let inner_fft = inner_fft_maker(inner_size);

        let mut w_forward = WTwiddles::default();
        let mut w_inverse = WTwiddles::default();
        let mut x_forward = XTwiddles::default();
        let mut x_inverse = XTwiddles::default();
        initialize_w_twiddles(size, &inner_fft, &mut w_forward, &mut w_inverse);
        initialize_x_twiddles(size, &mut x_forward, &mut x_inverse);

        // Scratch space for the convolution, one slot per inner FFT point.
        let mut work = Work::default();
        work.extend(core::iter::repeat(Complex::default()).take(inner_fft.size()));

        Self {
            size,
            inner_fft,
            w_forward,
            w_inverse,
            x_forward,
            x_inverse,
            work: RefCell::new(work),
            real_type: PhantomData,
        }
    }
}
impl<
        T: FftFloat,
        InnerFft: Fft<Real = T>,
        WTwiddles: AsRef<[Complex<T>]>,
        XTwiddles: AsRef<[Complex<T>]>,
        Work: AsRef<[Complex<T>]>,
    > Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work>
{
    /// Returns the `(forward, inverse)` convolution kernels as slices.
    pub fn w_twiddles(&self) -> (&[Complex<T>], &[Complex<T>]) {
        let Self {
            w_forward,
            w_inverse,
            ..
        } = self;
        (w_forward.as_ref(), w_inverse.as_ref())
    }

    /// Returns the `(forward, inverse)` input/output chirps as slices.
    pub fn x_twiddles(&self) -> (&[Complex<T>], &[Complex<T>]) {
        let Self {
            x_forward,
            x_inverse,
            ..
        } = self;
        (x_forward.as_ref(), x_inverse.as_ref())
    }

    /// Returns the length of the inner (convolution) FFT.
    pub fn inner_fft_size(&self) -> usize {
        let Self { inner_fft, .. } = self;
        inner_fft.size()
    }

    /// Returns the number of elements in the scratch buffer.
    ///
    /// # Panics
    /// Panics if called while the scratch buffer is mutably borrowed (i.e. from inside a transform).
    pub fn work_size(&self) -> usize {
        let scratch = self.work.borrow();
        scratch.as_ref().len()
    }
}
/// Implements the `Autosort`-backed constructor and the `Fft` trait for `Bluesteins` over the
/// given real scalar type.
macro_rules! implement {
    {
        $type:ty
    } => {
        impl<
            AutosortTwiddles: Default + Extend<Complex<$type>> + AsRef<[Complex<$type>]>,
            AutosortWork: Default + Extend<Complex<$type>> + AsMut<[Complex<$type>]>,
            WTwiddles: Default + Extend<Complex<$type>> + AsMut<[Complex<$type>]>,
            XTwiddles: Default + Extend<Complex<$type>>,
            Work: Default + Extend<Complex<$type>>,
        > Bluesteins<$type, Autosort<$type, AutosortTwiddles, AutosortWork>, WTwiddles, XTwiddles, Work>
        {
            /// Creates a Bluestein FFT of length `size` whose inner FFT is an `Autosort`.
            ///
            /// # Panics
            /// Panics if `Autosort::new` fails for the padded inner length.
            pub fn new(size: usize) -> Self {
                Self::new_with_fft(size, |inner_size| Autosort::new(inner_size).unwrap())
            }
        }

        impl<
            InnerFft: Fft<Real = $type>,
            WTwiddles: AsRef<[Complex<$type>]>,
            XTwiddles: AsRef<[Complex<$type>]>,
            Work: AsMut<[Complex<$type>]>,
        > Fft for Bluesteins<$type, InnerFft, WTwiddles, XTwiddles, Work>
        {
            type Real = $type;

            fn size(&self) -> usize {
                self.size
            }

            fn transform_in_place(&self, input: &mut [Complex<$type>], transform: Transform) {
                // Pick the chirp/kernel pair matching the transform direction.
                let (chirp, kernel) = if transform.is_forward() {
                    (self.x_forward.as_ref(), self.w_forward.as_ref())
                } else {
                    (self.x_inverse.as_ref(), self.w_inverse.as_ref())
                };
                let mut scratch = self.work.borrow_mut();
                apply(
                    input,
                    scratch.as_mut(),
                    chirp,
                    kernel,
                    &self.inner_fft,
                    transform,
                );
            }
        }
    }
}
implement! { f32 }
implement! { f64 }
/// Performs one Bluestein transform of `input` in place.
///
/// 1. Writes `x[n] * input[n]` into the front of `work` and zeroes the rest of `work`.
/// 2. Runs `fft.fft_in_place` on `work`, multiplies it pointwise by `w`, and runs
///    `fft.ifft_in_place`.
/// 3. Writes `work[n] * x[n]` back into `input`, scaled according to `transform`.
///
/// NOTE(review): this assumes `ifft_in_place` normalizes by the inner FFT length, so the
/// convolution comes out unscaled — confirm against the `Fft` implementation.
///
/// # Panics
/// Panics if `x.len() != input.len()` or if `work` is shorter than `input`.
#[multiversion::multiversion]
#[clone(target = "[x86|x86_64]+avx")]
#[inline]
fn apply<T: FftFloat, F: Fft<Real = T>>(
    input: &mut [Complex<T>],
    work: &mut [Complex<T>],
    x: &[Complex<T>],
    w: &[Complex<T>],
    fft: &F,
    transform: Transform,
) {
    assert_eq!(x.len(), input.len());
    let len = input.len();

    // Chirp the input into the head of the scratch buffer; zero-pad the tail.
    for (slot, (chirp, sample)) in work.iter_mut().zip(x.iter().zip(input.iter())) {
        *slot = chirp * sample;
    }
    for slot in &mut work[len..] {
        *slot = Complex::default();
    }

    // Circular convolution with the kernel, carried out in the frequency domain.
    fft.fft_in_place(work);
    for (slot, kernel) in work.iter_mut().zip(w) {
        *slot *= kernel;
    }
    fft.ifft_in_place(work);

    // Undo the chirp, applying whatever normalization the requested transform needs.
    let scale = match transform {
        Transform::Fft | Transform::UnscaledIfft => None,
        Transform::Ifft => Some(T::one() / T::from_usize(len).unwrap()),
        Transform::SqrtScaledFft | Transform::SqrtScaledIfft => {
            Some(T::one() / T::sqrt(T::from_usize(len).unwrap()))
        }
    };
    let outputs = input.iter_mut().zip(work.iter().zip(x.iter()));
    if let Some(factor) = scale {
        for (out, (acc, chirp)) in outputs {
            *out = acc * chirp * factor;
        }
    } else {
        for (out, (acc, chirp)) in outputs {
            *out = acc * chirp;
        }
    }
}