// NOTE(review): the stage kernels below wrap individual pointer operations in
// `unsafe { ... }` blocks, and some helper macros they invoke may already
// expand to unsafe code, leaving nested/redundant `unsafe` blocks — confirm
// this is the reason for the allow.
#![allow(unused_unsafe)]
// NOTE(review): helper macros pulled in via `#[macro_use]` (and selected per
// target with `#[target_cfg]`) are presumably not all used in every build
// configuration — confirm which ones trigger the lint.
#![allow(unused_macros)]
3
// `#[macro_use]` keeps the `macro_rules!` macros defined in these modules in
// textual scope for the rest of this file. The stage kernels generated by
// `make_radix_fns!` invoke `butterfly2!`..`butterfly8!` and several vector
// helper macros; presumably those are defined here — TODO confirm.
#[macro_use]
mod butterfly;
#[macro_use]
mod avx_optimization;
8
9use crate::fft::{Fft, Transform};
10use crate::float::FftFloat;
11use crate::twiddle::compute_twiddle;
12use core::cell::RefCell;
13use core::marker::PhantomData;
14use num_complex::Complex;
15use num_traits::One as _;
16
17#[cfg(not(feature = "std"))]
18use num_traits::Float as _; const NUM_RADICES: usize = 5;
21const RADICES: [usize; NUM_RADICES] = [4, 8, 4, 3, 2];
22
23fn initialize_twiddles<T: FftFloat, E: Extend<Complex<T>>>(
25 mut size: usize,
26 counts: [usize; NUM_RADICES],
27 forward_twiddles: &mut E,
28 inverse_twiddles: &mut E,
29) {
30 let mut stride = 1;
31 for (radix, count) in RADICES.iter().zip(&counts) {
32 for _ in 0..*count {
33 let m = size / radix;
34 for i in 0..m {
35 forward_twiddles.extend(core::iter::once(Complex::<T>::one()));
36 inverse_twiddles.extend(core::iter::once(Complex::<T>::one()));
37 for j in 1..*radix {
38 forward_twiddles.extend(core::iter::once(compute_twiddle(i * j, size, true)));
39 inverse_twiddles.extend(core::iter::once(compute_twiddle(i * j, size, false)));
40 }
41 }
42 size /= radix;
43 stride *= radix;
44 }
45 }
46}
47
/// A mixed-radix FFT of a fixed length whose size factors into the radices in
/// `RADICES` (i.e. sizes of the form 2^a * 3^b when built with `new`).
///
/// Each stage reads from one buffer and writes the other, alternating between the
/// caller's data and the internal `work` buffer. No separate reordering pass is
/// performed. This presumably is the Stockham-style "autosort" scheme the name
/// refers to.
pub struct Autosort<T, Twiddles, Work> {
    /// Transform length (number of complex elements).
    size: usize,
    /// Number of stages of each radix, indexed like `RADICES`.
    counts: [usize; NUM_RADICES],
    /// Twiddle factors for forward transforms, laid out stage by stage.
    forward_twiddles: Twiddles,
    /// Twiddle factors for inverse transforms, same layout as `forward_twiddles`.
    inverse_twiddles: Twiddles,
    /// Scratch buffer used as the second ping-pong buffer during a transform.
    /// Wrapped in a `RefCell` so that `transform_in_place(&self, ..)` can mutate it.
    work: RefCell<Work>,
    /// Records the real scalar type `T` without storing a value of it.
    real_type: PhantomData<T>,
}
57
58impl<T, Twiddles, Work> Autosort<T, Twiddles, Work> {
59 pub fn counts(&self) -> [usize; NUM_RADICES] {
61 self.counts
62 }
63
64 pub unsafe fn new_from_parts(
67 size: usize,
68 counts: [usize; NUM_RADICES],
69 forward_twiddles: Twiddles,
70 inverse_twiddles: Twiddles,
71 work: Work,
72 ) -> Self {
73 Self {
74 size,
75 counts,
76 forward_twiddles,
77 inverse_twiddles,
78 work: RefCell::new(work),
79 real_type: PhantomData,
80 }
81 }
82}
83
84impl<T, Twiddles: AsRef<[Complex<T>]>, Work: AsRef<[Complex<T>]>> Autosort<T, Twiddles, Work> {
85 pub fn twiddles(&self) -> (&[Complex<T>], &[Complex<T>]) {
87 (
88 self.forward_twiddles.as_ref(),
89 self.inverse_twiddles.as_ref(),
90 )
91 }
92
93 pub fn work_size(&self) -> usize {
95 self.work.borrow().as_ref().len()
96 }
97}
98
99impl<T: FftFloat, Twiddles: Default + Extend<Complex<T>>, Work: Default + Extend<Complex<T>>>
100 Autosort<T, Twiddles, Work>
101{
102 pub fn new(size: usize) -> Option<Self> {
105 let mut current_size = size;
106 let mut counts = [0usize; NUM_RADICES];
107 if current_size % RADICES[0] == 0 {
108 current_size /= RADICES[0];
109 counts[0] = 1;
110 }
111 for (count, radix) in counts.iter_mut().zip(&RADICES).skip(1) {
112 while current_size % radix == 0 {
113 current_size /= radix;
114 *count += 1;
115 }
116 }
117 if current_size == 1 {
118 let mut forward_twiddles = Twiddles::default();
119 let mut inverse_twiddles = Twiddles::default();
120 initialize_twiddles(size, counts, &mut forward_twiddles, &mut inverse_twiddles);
121 let mut work = Work::default();
122 work.extend(core::iter::repeat(Complex::default()).take(size));
123 Some(Self {
124 size,
125 counts,
126 forward_twiddles,
127 inverse_twiddles,
128 work: RefCell::new(work),
129 real_type: PhantomData,
130 })
131 } else {
132 None
133 }
134 }
135}
136
137macro_rules! implement {
138 {
139 $type:ty, $apply:ident
140 } => {
141 impl<Twiddles: AsRef<[Complex<$type>]>, Work: AsMut<[Complex<$type>]>> Fft
142 for Autosort<$type, Twiddles, Work>
143 {
144 type Real = $type;
145
146 fn size(&self) -> usize {
147 self.size
148 }
149
150 fn transform_in_place(&self, input: &mut [Complex<$type>], transform: Transform) {
151 let mut work = self.work.borrow_mut();
152 let twiddles = if transform.is_forward() {
153 &self.forward_twiddles
154 } else {
155 &self.inverse_twiddles
156 };
157 $apply(
158 input,
159 work.as_mut(),
160 &self.counts,
161 twiddles.as_ref(),
162 self.size,
163 transform,
164 );
165 }
166 }
167 }
168}
169implement! { f32, apply_stages_f32 }
170implement! { f64, apply_stages_f64 }
171
/// Generates the per-radix stage kernels.
///
/// The list arm takes `[radix, wide_name, narrow_name, butterfly_macro]` entries and
/// emits modules `radix_f32` and `radix_f64`. Each module holds a "wide" and a
/// "narrow" kernel per radix.
///
/// Each kernel performs one stage. For `i < m` (with `m = size / radix`),
/// `j < stride` and `k < radix`, it reads
/// `input[j + stride * i + stride * m * k]`, applies the radix butterfly, and
/// multiplies element `k` by twiddle row `i` (entries `i * radix + k`). It then
/// writes `output[j + radix * stride * i + stride * k]`.
///
/// The highest index touched is `stride * radix * m - 1`. Neither slice length nor
/// twiddle length is checked here.
macro_rules! make_radix_fns {
    {
        @impl $type:ident, $wide:literal, $radix:literal, $name:ident, $butterfly:ident
    } => {

        // `input`: stage source; `output`: stage destination (a distinct buffer).
        // `size`: length of the sub-transforms at this stage; `stride`: product of
        // the radices of all earlier stages.
        // `cached_twiddles`: this stage's slice of the twiddle table (`m` rows of
        // `$radix` entries).
        // `_forward`: forwarded to the butterfly macro (and the AVX fast path).
        #[multiversion::multiversion]
        #[clone(target = "[x86|x86_64]+avx")]
        #[inline]
        pub fn $name(
            input: &[num_complex::Complex<$type>],
            output: &mut [num_complex::Complex<$type>],
            _forward: bool,
            size: usize,
            stride: usize,
            cached_twiddles: &[num_complex::Complex<$type>],
        ) {
            // Select the vector helper macros (width!, zeroed!, load_*/store_*,
            // mul!, broadcast!) for the target this clone is compiled for.
            #[target_cfg(target = "[x86|x86_64]+avx")]
            crate::avx_vector! { $type };

            #[target_cfg(not(target = "[x86|x86_64]+avx"))]
            crate::generic_vector! { $type };

            // On AVX, narrow kernels first offer the stage to a specialized
            // implementation. NOTE(review): presumably `avx_optimization!` yields
            // true when it fully handled the stage — confirm.
            #[target_cfg(target = "[x86|x86_64]+avx")]
            {
                if !$wide && crate::avx_optimization!($type, $radix, input, output, _forward, size, stride, cached_twiddles) {
                    return
                }
            }

            let m = size / $radix;

            // Wide kernels process `width!()` consecutive `j` values per vector.
            // The vectors start at 0, width, ..., up to `full_count`, plus one final
            // vector at `stride - width!()`. That final vector may overlap the
            // previous one; the overlap is recomputed identically because input and
            // output are distinct. This requires `stride >= width!()`, which the
            // stage pipeline guarantees by using narrow kernels below that stride.
            let (full_count, final_offset) = if $wide {
                (Some(((stride - 1) / width!()) * width!()), Some(stride - width!()))
            } else {
                (None, None)
            };

            for i in 0..m {
                if $wide {
                    // Splat twiddle row `i` across vector lanes; entry 0 is left
                    // unused because element 0 is never multiplied.
                    let twiddles = {
                        let mut twiddles = [zeroed!(); $radix];
                        for k in 1..$radix {
                            twiddles[k] = unsafe {
                                broadcast!(cached_twiddles.as_ptr().add(i * $radix + k).read())
                            };
                        }
                        twiddles
                    };

                    for j in (0..full_count.unwrap())
                        .step_by(width!())
                        .chain(core::iter::once(final_offset.unwrap()))
                    {
                        let mut scratch = [zeroed!(); $radix];
                        // Gather the `$radix` inputs of butterfly (i, j..j+width).
                        let load = unsafe { input.as_ptr().add(j + stride * i) };
                        for k in 0..$radix {
                            scratch[k] = unsafe { load_wide!(load.add(stride * k * m)) };
                        }

                        scratch = $butterfly!($type, scratch, _forward);
                        // When size == radix (last stage) only row i = 0 exists.
                        // NOTE(review): its twiddles are presumably all one
                        // (compute_twiddle(0, ..)), so the multiply is skipped.
                        if size != $radix {
                            for k in 1..$radix {
                                scratch[k] = mul!(scratch[k], twiddles[k]);
                            }
                        }

                        // Scatter results to `$radix` consecutive stride-blocks.
                        let store = unsafe { output.as_mut_ptr().add(j + $radix * stride * i) };
                        for k in 0..$radix {
                            unsafe { store_wide!(scratch[k], store.add(stride * k)) };
                        }
                    }
                } else {
                    // Narrow path: one `j` at a time, same index scheme as above.
                    let twiddles = {
                        let mut twiddles = [zeroed!(); $radix];
                        for k in 1..$radix {
                            twiddles[k] = unsafe {
                                load_narrow!(cached_twiddles.as_ptr().add(i * $radix + k))
                            };
                        }
                        twiddles
                    };

                    let load = unsafe { input.as_ptr().add(stride * i) };
                    let store = unsafe { output.as_mut_ptr().add($radix * stride * i) };
                    for j in 0..stride {
                        let mut scratch = [zeroed!(); $radix];
                        for k in 0..$radix {
                            scratch[k] = unsafe { load_narrow!(load.add(stride * k * m + j)) };
                        }

                        scratch = $butterfly!($type, scratch, _forward);
                        if size != $radix {
                            for k in 1..$radix {
                                scratch[k] = mul!(scratch[k], twiddles[k]);
                            }
                        }

                        for k in 0..$radix {
                            unsafe { store_narrow!(scratch[k], store.add(stride * k + j)) };
                        }
                    }
                }
            }
        }
    };
    {
        $([$radix:literal, $wide_name:ident, $narrow_name:ident, $butterfly:ident]),*
    } => {
        mod radix_f32 {
            $(
                make_radix_fns! { @impl f32, true, $radix, $wide_name, $butterfly }
                make_radix_fns! { @impl f32, false, $radix, $narrow_name, $butterfly }
            )*
        }
        mod radix_f64 {
            $(
                make_radix_fns! { @impl f64, true, $radix, $wide_name, $butterfly }
                make_radix_fns! { @impl f64, false, $radix, $narrow_name, $butterfly }
            )*
        }
    };
}

// One kernel pair per radix appearing in `RADICES`.
make_radix_fns! {
    [2, radix_2_wide, radix_2_narrow, butterfly2],
    [3, radix_3_wide, radix_3_narrow, butterfly3],
    [4, radix_4_wide, radix_4_narrow, butterfly4],
    [8, radix_8_wide, radix_8_narrow, butterfly8]
}
311
/// Generates `$name`, the full stage pipeline for float type `$type`, built on the
/// kernels in module `$radix_mod`.
macro_rules! make_stage_fns {
    { $type:ident, $name:ident, $radix_mod:ident } => {
        // Runs every stage listed in `stages` (per-radix counts, in `RADICES`
        // order) and leaves the result in `input`.
        //
        // `output` is scratch of the same length. Stages alternate direction
        // between `input` and `output`, and `data_in_output` tracks which buffer
        // holds the latest result. `twiddles` is the full table for the chosen
        // direction. `size` is the transform length.
        //
        // Panics if the buffer lengths differ from each other or from `size`, or
        // if the twiddle table is too short to slice for a stage.
        #[multiversion::multiversion]
        #[clone(target = "[x86|x86_64]+avx")]
        #[inline]
        fn $name(
            input: &mut [Complex<$type>],
            output: &mut [Complex<$type>],
            stages: &[usize; NUM_RADICES],
            mut twiddles: &[Complex<$type>],
            mut size: usize,
            transform: Transform,
        ) {
            #[target_cfg(target = "[x86|x86_64]+avx")]
            crate::avx_vector! { $type };

            #[target_cfg(not(target = "[x86|x86_64]+avx"))]
            crate::generic_vector! { $type };

            assert_eq!(input.len(), output.len());
            assert_eq!(size, input.len());

            // Product of the radices of the stages executed so far.
            let mut stride = 1;

            let mut data_in_output = false;
            for (radix, iterations) in RADICES.iter().zip(stages) {
                let mut iteration = 0;

                // Narrow kernels handle stages whose stride is below the vector
                // width; wide kernels need `stride >= width`.
                while stride < width! {} && iteration < *iterations {
                    // The annotation makes the tuple reborrow the slices instead
                    // of moving them, so both stay usable in later stages.
                    let (from, to): (&mut _, &mut _) = if data_in_output {
                        (output, input)
                    } else {
                        (input, output)
                    };
                    match radix {
                        8 => dispatch!($radix_mod::radix_8_narrow(from, to, transform.is_forward(), size, stride, twiddles)),
                        4 => dispatch!($radix_mod::radix_4_narrow(from, to, transform.is_forward(), size, stride, twiddles)),
                        3 => dispatch!($radix_mod::radix_3_narrow(from, to, transform.is_forward(), size, stride, twiddles)),
                        2 => dispatch!($radix_mod::radix_2_narrow(from, to, transform.is_forward(), size, stride, twiddles)),
                        _ => unimplemented!("unsupported radix"),
                    }
                    size /= radix;
                    stride *= radix;
                    // Skip this stage's twiddles: `size * radix` is the pre-stage
                    // sub-transform length, i.e. the number of entries it used.
                    twiddles = &twiddles[size * radix..];
                    iteration += 1;
                    data_in_output = !data_in_output;
                }

                // Remaining stages of this radix have stride >= width.
                for _ in iteration..*iterations {
                    let (from, to): (&mut _, &mut _) = if data_in_output {
                        (output, input)
                    } else {
                        (input, output)
                    };
                    match radix {
                        8 => dispatch!($radix_mod::radix_8_wide(from, to, transform.is_forward(), size, stride, twiddles)),
                        4 => dispatch!($radix_mod::radix_4_wide(from, to, transform.is_forward(), size, stride, twiddles)),
                        3 => dispatch!($radix_mod::radix_3_wide(from, to, transform.is_forward(), size, stride, twiddles)),
                        2 => dispatch!($radix_mod::radix_2_wide(from, to, transform.is_forward(), size, stride, twiddles)),
                        _ => unimplemented!("unsupported radix"),
                    }
                    size /= radix;
                    stride *= radix;
                    twiddles = &twiddles[size * radix ..];
                    data_in_output = !data_in_output;
                }
            }
            // Normalization: `Ifft` scales by 1/n, the sqrt-scaled variants by
            // 1/sqrt(n), and `Fft`/`UnscaledIfft` are left unscaled. If the final
            // stage wrote to `output`, the result is copied (and scaled) back into
            // `input`.
            if let Some(scale) = match transform {
                Transform::Fft | Transform::UnscaledIfft => None,
                Transform::Ifft => Some(1. / (input.len() as $type)),
                Transform::SqrtScaledFft | Transform::SqrtScaledIfft => Some(1. / (input.len() as $type).sqrt()),
            } {
                if data_in_output {
                    for (x, y) in output.iter().zip(input.iter_mut()) {
                        *y = x * scale;
                    }
                } else {
                    for x in input.iter_mut() {
                        *x *= scale;
                    }
                }
            } else {
                if data_in_output {
                    input.copy_from_slice(output);
                }
            }
        }
    };
}
make_stage_fns! { f32, apply_stages_f32, radix_f32 }
make_stage_fns! { f64, apply_stages_f64, radix_f64 }