rustfft/algorithm/
bluesteins_algorithm.rs1use std::sync::Arc;
2
3use num_complex::Complex;
4use num_traits::Zero;
5
6use crate::{common::FftNum, twiddles, FftDirection};
7use crate::{Direction, Fft, Length};
8
9pub struct BluesteinsAlgorithm<T> {
40 inner_fft: Arc<dyn Fft<T>>,
41
42 inner_fft_multiplier: Box<[Complex<T>]>,
43 twiddles: Box<[Complex<T>]>,
44
45 len: usize,
46 direction: FftDirection,
47}
48
49impl<T: FftNum> BluesteinsAlgorithm<T> {
50 pub fn new(len: usize, inner_fft: Arc<dyn Fft<T>>) -> Self {
59 let inner_fft_len = inner_fft.len();
60 assert!(len * 2 - 1 <= inner_fft_len, "Bluestein's algorithm requires inner_fft.len() >= self.len() * 2 - 1. Expected >= {}, got {}", len * 2 - 1, inner_fft_len);
61
62 let inner_fft_scale = T::one() / T::from_usize(inner_fft_len).unwrap();
64 let direction = inner_fft.fft_direction();
65
66 let mut inner_fft_input = vec![Complex::zero(); inner_fft_len];
68 twiddles::fill_bluesteins_twiddles(
69 &mut inner_fft_input[..len],
70 direction.opposite_direction(),
71 );
72
73 inner_fft_input[0] = inner_fft_input[0] * inner_fft_scale;
75 for i in 1..len {
76 let twiddle = inner_fft_input[i] * inner_fft_scale;
77 inner_fft_input[i] = twiddle;
78 inner_fft_input[inner_fft_len - i] = twiddle;
79 }
80
81 let mut inner_fft_scratch = vec![Complex::zero(); inner_fft.get_inplace_scratch_len()];
83 inner_fft.process_with_scratch(&mut inner_fft_input, &mut inner_fft_scratch);
84
85 let mut twiddles = vec![Complex::zero(); len];
87 twiddles::fill_bluesteins_twiddles(&mut twiddles, direction);
88
89 Self {
90 inner_fft: inner_fft,
91
92 inner_fft_multiplier: inner_fft_input.into_boxed_slice(),
93 twiddles: twiddles.into_boxed_slice(),
94
95 len,
96 direction,
97 }
98 }
99
100 fn perform_fft_inplace(&self, input: &mut [Complex<T>], scratch: &mut [Complex<T>]) {
101 let (inner_input, inner_scratch) = scratch.split_at_mut(self.inner_fft_multiplier.len());
102
103 for ((buffer_entry, inner_entry), twiddle) in input
105 .iter()
106 .zip(inner_input.iter_mut())
107 .zip(self.twiddles.iter())
108 {
109 *inner_entry = *buffer_entry * *twiddle;
110 }
111 for inner in (&mut inner_input[input.len()..]).iter_mut() {
112 *inner = Complex::zero();
113 }
114
115 self.inner_fft
117 .process_with_scratch(inner_input, inner_scratch);
118
119 for (inner, multiplier) in inner_input.iter_mut().zip(self.inner_fft_multiplier.iter()) {
121 *inner = (*inner * *multiplier).conj();
122 }
123
124 self.inner_fft
126 .process_with_scratch(inner_input, inner_scratch);
127
128 for ((buffer_entry, inner_entry), twiddle) in input
130 .iter_mut()
131 .zip(inner_input.iter())
132 .zip(self.twiddles.iter())
133 {
134 *buffer_entry = inner_entry.conj() * twiddle;
135 }
136 }
137
138 #[inline]
139 fn perform_fft_immut(
140 &self,
141 input: &[Complex<T>],
142 output: &mut [Complex<T>],
143 scratch: &mut [Complex<T>],
144 ) {
145 let (inner_input, inner_scratch) = scratch.split_at_mut(self.inner_fft_multiplier.len());
146
147 for ((buffer_entry, inner_entry), twiddle) in input
149 .iter()
150 .zip(inner_input.iter_mut())
151 .zip(self.twiddles.iter())
152 {
153 *inner_entry = *buffer_entry * *twiddle;
154 }
155 for inner in inner_input.iter_mut().skip(input.len()) {
156 *inner = Complex::zero();
157 }
158
159 self.inner_fft
161 .process_with_scratch(inner_input, inner_scratch);
162
163 for (inner, multiplier) in inner_input.iter_mut().zip(self.inner_fft_multiplier.iter()) {
165 *inner = (*inner * *multiplier).conj();
166 }
167
168 self.inner_fft
170 .process_with_scratch(inner_input, inner_scratch);
171
172 for ((buffer_entry, inner_entry), twiddle) in output
174 .iter_mut()
175 .zip(inner_input.iter())
176 .zip(self.twiddles.iter())
177 {
178 *buffer_entry = inner_entry.conj() * twiddle;
179 }
180 }
181
182 fn perform_fft_out_of_place(
183 &self,
184 input: &mut [Complex<T>],
185 output: &mut [Complex<T>],
186 scratch: &mut [Complex<T>],
187 ) {
188 self.perform_fft_immut(input, output, scratch);
189 }
190}
191boilerplate_fft!(
192 BluesteinsAlgorithm,
193 |this: &BluesteinsAlgorithm<_>| this.len, |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len()
195 + this.inner_fft.get_inplace_scratch_len(), |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len()
197 + this.inner_fft.get_inplace_scratch_len(), |this: &BluesteinsAlgorithm<_>| this.inner_fft_multiplier.len()
199 + this.inner_fft.get_inplace_scratch_len() );
201
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::algorithm::Dft;
    use crate::test_utils::check_fft_algorithm;
    use std::sync::Arc;

    /// Checks Bluestein's algorithm against the reference checker for several prime sizes,
    /// running the forward transform and then the inverse for each size.
    #[test]
    fn test_bluesteins_scalar() {
        let sizes = [3, 5, 7, 11, 13];
        let directions = [FftDirection::Forward, FftDirection::Inverse];
        for &len in sizes.iter() {
            for &direction in directions.iter() {
                test_bluesteins_with_length(len, direction);
            }
        }
    }

    /// Builds a size-`len` Bluestein FFT and checks it in `direction`. The inner FFT is a `Dft`
    /// whose length is the next power of two at or above `len * 2 - 1`.
    fn test_bluesteins_with_length(len: usize, direction: FftDirection) {
        let inner_len = (2 * len - 1).checked_next_power_of_two().unwrap();
        let dft = Arc::new(Dft::new(inner_len, direction));
        let bluesteins = BluesteinsAlgorithm::new(len, dft);

        check_fft_algorithm::<f32>(&bluesteins, len, direction);
    }
}