1use crate::{Autosort, Fft, FftFloat, Transform};
2use core::cell::RefCell;
3use core::marker::PhantomData;
4use num_complex::Complex;
5
6#[cfg(not(feature = "std"))]
7use num_traits::Float as _; fn compute_half_twiddle<T: FftFloat>(index: f64, size: usize) -> Complex<T> {
10 let theta = index * core::f64::consts::PI / size as f64;
11 Complex::new(
12 T::from_f64(theta.cos()).unwrap(),
13 T::from_f64(-theta.sin()).unwrap(),
14 )
15}
16
17fn initialize_w_twiddles<
19 T: FftFloat,
20 E: Extend<Complex<T>> + AsMut<[Complex<T>]>,
21 F: Fft<Real = T>,
22>(
23 size: usize,
24 fft: &F,
25 forward_twiddles: &mut E,
26 inverse_twiddles: &mut E,
27) {
28 for i in 0..fft.size() {
29 if let Some(index) = {
30 if i < size {
31 Some((i as f64).powi(2))
32 } else if i > fft.size() - size {
33 Some(((i as f64) - (fft.size() as f64)).powi(2))
34 } else {
35 None
36 }
37 } {
38 let twiddle = compute_half_twiddle(index, size);
39 forward_twiddles.extend(core::iter::once(twiddle.conj()));
40 inverse_twiddles.extend(core::iter::once(twiddle));
41 } else {
42 forward_twiddles.extend(core::iter::once(Complex::default()));
43 inverse_twiddles.extend(core::iter::once(Complex::default()));
44 }
45 }
46 fft.fft_in_place(forward_twiddles.as_mut());
47 fft.fft_in_place(inverse_twiddles.as_mut());
48}
49
50fn initialize_x_twiddles<T: FftFloat, E: Extend<Complex<T>>>(
52 size: usize,
53 forward_twiddles: &mut E,
54 inverse_twiddles: &mut E,
55) {
56 for i in 0..size {
57 let twiddle = compute_half_twiddle(-(i as f64).powi(2), size);
58 forward_twiddles.extend(core::iter::once(twiddle.conj()));
59 inverse_twiddles.extend(core::iter::once(twiddle));
60 }
61}
62
/// An arbitrary-length FFT computed with Bluestein's (chirp-z) algorithm.
///
/// The `size`-point transform is rewritten as a chirp multiplication, a circular
/// convolution evaluated with `inner_fft`, and a second chirp multiplication.
///
/// * `T` — scalar type of the complex samples;
/// * `InnerFft` — FFT used for the convolution;
/// * `WTwiddles` — storage for the chirp-filter spectra;
/// * `XTwiddles` — storage for the pre/post chirp factors;
/// * `Work` — storage for the convolution scratch buffer.
pub struct Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work> {
    /// Transform length reported through `Fft::size`.
    size: usize,
    /// FFT that evaluates the chirp convolution.
    inner_fft: InnerFft,
    /// Chirp-filter spectrum used by forward transforms.
    w_forward: WTwiddles,
    /// Chirp-filter spectrum used by inverse transforms.
    w_inverse: WTwiddles,
    /// Chirp factors applied around the convolution for forward transforms.
    x_forward: XTwiddles,
    /// Chirp factors applied around the convolution for inverse transforms.
    x_inverse: XTwiddles,
    /// Convolution scratch space, borrowed mutably by each `&self` transform.
    work: RefCell<Work>,
    /// Ties the scalar type `T` to the struct; no field stores a `T` directly.
    real_type: PhantomData<T>,
}
74
75impl<T, InnerFft, WTwiddles, XTwiddles, Work> Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work> {
76 pub unsafe fn new_from_parts(
79 size: usize,
80 inner_fft: InnerFft,
81 w_forward: WTwiddles,
82 w_inverse: WTwiddles,
83 x_forward: XTwiddles,
84 x_inverse: XTwiddles,
85 work: Work,
86 ) -> Self {
87 Self {
88 size,
89 inner_fft,
90 w_forward,
91 w_inverse,
92 x_forward,
93 x_inverse,
94 work: RefCell::new(work),
95 real_type: PhantomData,
96 }
97 }
98}
99
100impl<
101 T: FftFloat,
102 InnerFft: Fft<Real = T>,
103 WTwiddles: Default + Extend<Complex<T>> + AsMut<[Complex<T>]>,
104 XTwiddles: Default + Extend<Complex<T>>,
105 Work: Default + Extend<Complex<T>>,
106 > Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work>
107{
108 pub fn new_with_fft<F: Fn(usize) -> InnerFft>(size: usize, inner_fft_maker: F) -> Self {
110 let inner_size = (2 * size - 1).checked_next_power_of_two().unwrap();
111 let inner_fft = inner_fft_maker(inner_size);
112 let mut w_forward = WTwiddles::default();
113 let mut w_inverse = WTwiddles::default();
114 let mut x_forward = XTwiddles::default();
115 let mut x_inverse = XTwiddles::default();
116 initialize_w_twiddles(size, &inner_fft, &mut w_forward, &mut w_inverse);
117 initialize_x_twiddles(size, &mut x_forward, &mut x_inverse);
118 let mut work = Work::default();
119 work.extend(core::iter::repeat(Complex::default()).take(inner_fft.size()));
120 Self {
121 size,
122 inner_fft,
123 w_forward,
124 w_inverse,
125 x_forward,
126 x_inverse,
127 work: RefCell::new(work),
128 real_type: PhantomData,
129 }
130 }
131}
132
133impl<
134 T: FftFloat,
135 InnerFft: Fft<Real = T>,
136 WTwiddles: AsRef<[Complex<T>]>,
137 XTwiddles: AsRef<[Complex<T>]>,
138 Work: AsRef<[Complex<T>]>,
139 > Bluesteins<T, InnerFft, WTwiddles, XTwiddles, Work>
140{
141 pub fn w_twiddles(&self) -> (&[Complex<T>], &[Complex<T>]) {
143 (self.w_forward.as_ref(), self.w_inverse.as_ref())
144 }
145
146 pub fn x_twiddles(&self) -> (&[Complex<T>], &[Complex<T>]) {
148 (self.x_forward.as_ref(), self.x_inverse.as_ref())
149 }
150
151 pub fn inner_fft_size(&self) -> usize {
153 self.inner_fft.size()
154 }
155
156 pub fn work_size(&self) -> usize {
158 self.work.borrow().as_ref().len()
159 }
160}
161
162macro_rules! implement {
163 {
164 $type:ty
165 } => {
166 impl<
167 AutosortTwiddles: Default + Extend<Complex<$type>> + AsRef<[Complex<$type>]>,
168 AutosortWork: Default + Extend<Complex<$type>> + AsMut<[Complex<$type>]>,
169 WTwiddles: Default + Extend<Complex<$type>> + AsMut<[Complex<$type>]>,
170 XTwiddles: Default + Extend<Complex<$type>>,
171 Work: Default + Extend<Complex<$type>>,
172 > Bluesteins<$type, Autosort<$type, AutosortTwiddles, AutosortWork>, WTwiddles, XTwiddles, Work>
173 {
174 pub fn new(size: usize) -> Self {
176 Self::new_with_fft(size, |size| Autosort::new(size).unwrap())
177 }
178 }
179
180 impl<
181 InnerFft: Fft<Real = $type>,
182 WTwiddles: AsRef<[Complex<$type>]>,
183 XTwiddles: AsRef<[Complex<$type>]>,
184 Work: AsMut<[Complex<$type>]>,
185 > Fft for Bluesteins<$type, InnerFft, WTwiddles, XTwiddles, Work>
186 {
187 type Real = $type;
188
189 fn size(&self) -> usize {
190 self.size
191 }
192
193 fn transform_in_place(&self, input: &mut [Complex<$type>], transform: Transform) {
194 let mut work = self.work.borrow_mut();
195 let (x, w) = if transform.is_forward() {
196 (&self.x_forward, &self.w_forward)
197 } else {
198 (&self.x_inverse, &self.w_inverse)
199 };
200 apply(
201 input,
202 work.as_mut(),
203 x.as_ref(),
204 w.as_ref(),
205 &self.inner_fft,
206 transform,
207 );
208 }
209 }
210 }
211}
212implement! { f32 }
213implement! { f64 }
214
215#[multiversion::multiversion]
216#[clone(target = "[x86|x86_64]+avx")]
217#[inline]
218fn apply<T: FftFloat, F: Fft<Real = T>>(
219 input: &mut [Complex<T>],
220 work: &mut [Complex<T>],
221 x: &[Complex<T>],
222 w: &[Complex<T>],
223 fft: &F,
224 transform: Transform,
225) {
226 assert_eq!(x.len(), input.len());
227
228 let size = input.len();
229 for (w, (x, i)) in work.iter_mut().zip(x.iter().zip(input.iter())) {
230 *w = x * i;
231 }
232 for w in work[size..].iter_mut() {
233 *w = Complex::default();
234 }
235 fft.fft_in_place(work);
236 for (w, wi) in work.iter_mut().zip(w.iter()) {
237 *w *= wi;
238 }
239 fft.ifft_in_place(work);
240 match transform {
241 Transform::Fft | Transform::UnscaledIfft => {
242 for (i, (w, xi)) in input.iter_mut().zip(work.iter().zip(x.iter())) {
243 *i = w * xi;
244 }
245 }
246 Transform::Ifft => {
247 let scale = T::one() / T::from_usize(size).unwrap();
248 for (i, (w, xi)) in input.iter_mut().zip(work.iter().zip(x.iter())) {
249 *i = w * xi * scale;
250 }
251 }
252 Transform::SqrtScaledFft | Transform::SqrtScaledIfft => {
253 let scale = T::one() / T::sqrt(T::from_usize(size).unwrap());
254 for (i, (w, xi)) in input.iter_mut().zip(work.iter().zip(x.iter())) {
255 *i = w * xi * scale;
256 }
257 }
258 }
259}