fourier_algorithms/
bluesteins.rs

1use crate::{Autosort, Fft, FftFloat, Transform};
2use core::cell::RefCell;
3use core::marker::PhantomData;
4use num_complex::Complex;
5
6#[cfg(not(feature = "std"))]
7use num_traits::Float as _; // enable sqrt, powi without std
8
9fn compute_half_twiddle<T: FftFloat>(index: f64, size: usize) -> Complex<T> {
10    let theta = index * core::f64::consts::PI / size as f64;
11    Complex::new(
12        T::from_f64(theta.cos()).unwrap(),
13        T::from_f64(-theta.sin()).unwrap(),
14    )
15}
16
17/// Initialize the "w" twiddles.
18fn initialize_w_twiddles<
19    T: FftFloat,
20    E: Extend<Complex<T>> + AsMut<[Complex<T>]>,
21    F: Fft<Real = T>,
22>(
23    size: usize,
24    fft: &F,
25    forward_twiddles: &mut E,
26    inverse_twiddles: &mut E,
27) {
28    for i in 0..fft.size() {
29        if let Some(index) = {
30            if i < size {
31                Some((i as f64).powi(2))
32            } else if i > fft.size() - size {
33                Some(((i as f64) - (fft.size() as f64)).powi(2))
34            } else {
35                None
36            }
37        } {
38            let twiddle = compute_half_twiddle(index, size);
39            forward_twiddles.extend(core::iter::once(twiddle.conj()));
40            inverse_twiddles.extend(core::iter::once(twiddle));
41        } else {
42            forward_twiddles.extend(core::iter::once(Complex::default()));
43            inverse_twiddles.extend(core::iter::once(Complex::default()));
44        }
45    }
46    fft.fft_in_place(forward_twiddles.as_mut());
47    fft.fft_in_place(inverse_twiddles.as_mut());
48}
49
50/// Initialize the "x" twiddles.
51fn initialize_x_twiddles<T: FftFloat, E: Extend<Complex<T>>>(
52    size: usize,
53    forward_twiddles: &mut E,
54    inverse_twiddles: &mut E,
55) {
56    for i in 0..size {
57        let twiddle = compute_half_twiddle(-(i as f64).powi(2), size);
58        forward_twiddles.extend(core::iter::once(twiddle.conj()));
59        inverse_twiddles.extend(core::iter::once(twiddle));
60    }
61}
62
/// Implements Bluestein's algorithm for arbitrary FFT sizes.
///
/// A length-`size` transform is computed by chirp-modulating the input, convolving it with a
/// chirp filter using the (larger) `InnerFft`, and chirp-demodulating the result.
pub struct Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work> {
    /// Length of the transform this instance computes.
    size: usize,
    /// FFT used to evaluate the convolution. `new_with_fft` requests a power of two that is at
    /// least `2 * size - 1`.
    inner_fft: InnerFft,
    /// Chirp-filter spectrum used by forward transforms. When built by `new_with_fft`, it has
    /// `inner_fft.size()` elements.
    w_forward: WTwiddles,
    /// Chirp-filter spectrum used by inverse transforms, with the same length as `w_forward`.
    w_inverse: WTwiddles,
    /// Chirp applied to the input and output of forward transforms (`size` elements when built
    /// by `new_with_fft`).
    x_forward: XTwiddles,
    /// Chirp applied to the input and output of inverse transforms, with the same length as
    /// `x_forward`.
    x_inverse: XTwiddles,
    /// Scratch buffer reused by every transform. `RefCell` lets `transform_in_place(&self, ..)`
    /// mutate it.
    work: RefCell<Work>,
    /// Records the real scalar type `T` without storing a value of it.
    real_type: PhantomData<T>,
}
74
75impl<T, InnerFft, WTwiddles, XTwiddles, Work> Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work> {
76    /// Create a new transform generator from parts.  Twiddles factors and work must be the correct
77    /// size.
78    pub unsafe fn new_from_parts(
79        size: usize,
80        inner_fft: InnerFft,
81        w_forward: WTwiddles,
82        w_inverse: WTwiddles,
83        x_forward: XTwiddles,
84        x_inverse: XTwiddles,
85        work: Work,
86    ) -> Self {
87        Self {
88            size,
89            inner_fft,
90            w_forward,
91            w_inverse,
92            x_forward,
93            x_inverse,
94            work: RefCell::new(work),
95            real_type: PhantomData,
96        }
97    }
98}
99
100impl<
101        T: FftFloat,
102        InnerFft: Fft<Real = T>,
103        WTwiddles: Default + Extend<Complex<T>> + AsMut<[Complex<T>]>,
104        XTwiddles: Default + Extend<Complex<T>>,
105        Work: Default + Extend<Complex<T>>,
106    > Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work>
107{
108    /// Create a new Bluestein's algorithm generator.
109    pub fn new_with_fft<F: Fn(usize) -> InnerFft>(size: usize, inner_fft_maker: F) -> Self {
110        let inner_size = (2 * size - 1).checked_next_power_of_two().unwrap();
111        let inner_fft = inner_fft_maker(inner_size);
112        let mut w_forward = WTwiddles::default();
113        let mut w_inverse = WTwiddles::default();
114        let mut x_forward = XTwiddles::default();
115        let mut x_inverse = XTwiddles::default();
116        initialize_w_twiddles(size, &inner_fft, &mut w_forward, &mut w_inverse);
117        initialize_x_twiddles(size, &mut x_forward, &mut x_inverse);
118        let mut work = Work::default();
119        work.extend(core::iter::repeat(Complex::default()).take(inner_fft.size()));
120        Self {
121            size,
122            inner_fft,
123            w_forward,
124            w_inverse,
125            x_forward,
126            x_inverse,
127            work: RefCell::new(work),
128            real_type: PhantomData,
129        }
130    }
131}
132
133impl<
134        T: FftFloat,
135        InnerFft: Fft<Real = T>,
136        WTwiddles: AsRef<[Complex<T>]>,
137        XTwiddles: AsRef<[Complex<T>]>,
138        Work: AsRef<[Complex<T>]>,
139    > Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work>
140{
141    /// Return the w-twiddle factors.
142    pub fn w_twiddles(&self) -> (&[Complex<T>], &[Complex<T>]) {
143        (self.w_forward.as_ref(), self.w_inverse.as_ref())
144    }
145
146    /// Return the w-twiddle factors.
147    pub fn x_twiddles(&self) -> (&[Complex<T>], &[Complex<T>]) {
148        (self.x_forward.as_ref(), self.x_inverse.as_ref())
149    }
150
151    /// Return the inner FFT size.
152    pub fn inner_fft_size(&self) -> usize {
153        self.inner_fft.size()
154    }
155
156    /// Return the work buffer size.
157    pub fn work_size(&self) -> usize {
158        self.work.borrow().as_ref().len()
159    }
160}
161
162macro_rules! implement {
163    {
164        $type:ty
165    } => {
166        impl<
167                AutosortTwiddles: Default + Extend<Complex<$type>> + AsRef<[Complex<$type>]>,
168                AutosortWork: Default + Extend<Complex<$type>> + AsMut<[Complex<$type>]>,
169                WTwiddles: Default + Extend<Complex<$type>> + AsMut<[Complex<$type>]>,
170                XTwiddles: Default + Extend<Complex<$type>>,
171                Work: Default + Extend<Complex<$type>>,
172            > Bluesteins<$type, Autosort<$type, AutosortTwiddles, AutosortWork>, WTwiddles, XTwiddles, Work>
173        {
174            /// Create a new Bluestein's algorithm generator.
175            pub fn new(size: usize) -> Self {
176                Self::new_with_fft(size, |size| Autosort::new(size).unwrap())
177            }
178        }
179
180        impl<
181                InnerFft: Fft<Real = $type>,
182                WTwiddles: AsRef<[Complex<$type>]>,
183                XTwiddles: AsRef<[Complex<$type>]>,
184                Work: AsMut<[Complex<$type>]>,
185            > Fft for Bluesteins<$type, InnerFft, WTwiddles, XTwiddles, Work>
186        {
187            type Real = $type;
188
189            fn size(&self) -> usize {
190                self.size
191            }
192
193            fn transform_in_place(&self, input: &mut [Complex<$type>], transform: Transform) {
194                let mut work = self.work.borrow_mut();
195                let (x, w) = if transform.is_forward() {
196                    (&self.x_forward, &self.w_forward)
197                } else {
198                    (&self.x_inverse, &self.w_inverse)
199                };
200                apply(
201                    input,
202                    work.as_mut(),
203                    x.as_ref(),
204                    w.as_ref(),
205                    &self.inner_fft,
206                    transform,
207                );
208            }
209        }
210    }
211}
212implement! { f32 }
213implement! { f64 }
214
215#[multiversion::multiversion]
216#[clone(target = "[x86|x86_64]+avx")]
217#[inline]
218fn apply<T: FftFloat, F: Fft<Real = T>>(
219    input: &mut [Complex<T>],
220    work: &mut [Complex<T>],
221    x: &[Complex<T>],
222    w: &[Complex<T>],
223    fft: &F,
224    transform: Transform,
225) {
226    assert_eq!(x.len(), input.len());
227
228    let size = input.len();
229    for (w, (x, i)) in work.iter_mut().zip(x.iter().zip(input.iter())) {
230        *w = x * i;
231    }
232    for w in work[size..].iter_mut() {
233        *w = Complex::default();
234    }
235    fft.fft_in_place(work);
236    for (w, wi) in work.iter_mut().zip(w.iter()) {
237        *w *= wi;
238    }
239    fft.ifft_in_place(work);
240    match transform {
241        Transform::Fft | Transform::UnscaledIfft => {
242            for (i, (w, xi)) in input.iter_mut().zip(work.iter().zip(x.iter())) {
243                *i = w * xi;
244            }
245        }
246        Transform::Ifft => {
247            let scale = T::one() / T::from_usize(size).unwrap();
248            for (i, (w, xi)) in input.iter_mut().zip(work.iter().zip(x.iter())) {
249                *i = w * xi * scale;
250            }
251        }
252        Transform::SqrtScaledFft | Transform::SqrtScaledIfft => {
253            let scale = T::one() / T::sqrt(T::from_usize(size).unwrap());
254            for (i, (w, xi)) in input.iter_mut().zip(work.iter().zip(x.iter())) {
255                *i = w * xi * scale;
256            }
257        }
258    }
259}