fourier_algorithms/autosort/mod.rs

1#![allow(unused_unsafe)]
2#![allow(unused_macros)]
3
4#[macro_use]
5mod butterfly;
6#[macro_use]
7mod avx_optimization;
8
9use crate::fft::{Fft, Transform};
10use crate::float::FftFloat;
11use crate::twiddle::compute_twiddle;
12use core::cell::RefCell;
13use core::marker::PhantomData;
14use num_complex::Complex;
15use num_traits::One as _;
16
17#[cfg(not(feature = "std"))]
18use num_traits::Float as _; // enable sqrt without std
19
/// Number of entries in [`RADICES`].
const NUM_RADICES: usize = 5;
/// Butterfly radices, in the order their stages are applied.
///
/// Radix-count arrays (`[usize; NUM_RADICES]`) are indexed in this same order.  The leading `4`
/// is used for at most one stage: `Autosort::new` divides out a single factor of 4 first and
/// then greedily extracts 8s, 4s, 3s and 2s from what remains.
const RADICES: [usize; NUM_RADICES] = [4, 8, 4, 3, 2];
22
23/// Initializes twiddles.
24fn initialize_twiddles<T: FftFloat, E: Extend<Complex<T>>>(
25    mut size: usize,
26    counts: [usize; NUM_RADICES],
27    forward_twiddles: &mut E,
28    inverse_twiddles: &mut E,
29) {
30    let mut stride = 1;
31    for (radix, count) in RADICES.iter().zip(&counts) {
32        for _ in 0..*count {
33            let m = size / radix;
34            for i in 0..m {
35                forward_twiddles.extend(core::iter::once(Complex::<T>::one()));
36                inverse_twiddles.extend(core::iter::once(Complex::<T>::one()));
37                for j in 1..*radix {
38                    forward_twiddles.extend(core::iter::once(compute_twiddle(i * j, size, true)));
39                    inverse_twiddles.extend(core::iter::once(compute_twiddle(i * j, size, false)));
40                }
41            }
42            size /= radix;
43            stride *= radix;
44        }
45    }
46}
47
/// Implements a mixed-radix Stockham autosort algorithm for sizes whose only prime factors are 2
/// and 3.
///
/// Each stage reads from one buffer and writes to the other, alternating between the caller's
/// buffer and an internal work buffer.
pub struct Autosort<T, Twiddles, Work> {
    /// Transform length.
    size: usize,
    /// Number of stages used for each entry of `RADICES` (same indexing).
    counts: [usize; NUM_RADICES],
    /// Concatenated per-stage twiddle factors used for forward transforms.
    forward_twiddles: Twiddles,
    /// Concatenated per-stage twiddle factors used for inverse transforms.
    inverse_twiddles: Twiddles,
    /// Scratch buffer the stages alternate with; `RefCell` lets `transform_in_place(&self)`
    /// mutate it.
    work: RefCell<Work>,
    /// Ties the real scalar type `T` to the struct without storing a value of it.
    real_type: PhantomData<T>,
}
57
58impl<T, Twiddles, Work> Autosort<T, Twiddles, Work> {
59    /// Return the radix counts.
60    pub fn counts(&self) -> [usize; NUM_RADICES] {
61        self.counts
62    }
63
64    /// Create a new transform generator from parts.  Twiddles factors and work must be the correct
65    /// size.
66    pub unsafe fn new_from_parts(
67        size: usize,
68        counts: [usize; NUM_RADICES],
69        forward_twiddles: Twiddles,
70        inverse_twiddles: Twiddles,
71        work: Work,
72    ) -> Self {
73        Self {
74            size,
75            counts,
76            forward_twiddles,
77            inverse_twiddles,
78            work: RefCell::new(work),
79            real_type: PhantomData,
80        }
81    }
82}
83
84impl<T, Twiddles: AsRef<[Complex<T>]>, Work: AsRef<[Complex<T>]>> Autosort<T, Twiddles, Work> {
85    /// Return the forward and inverse twiddle factors.
86    pub fn twiddles(&self) -> (&[Complex<T>], &[Complex<T>]) {
87        (
88            self.forward_twiddles.as_ref(),
89            self.inverse_twiddles.as_ref(),
90        )
91    }
92
93    /// Return the work buffer size.
94    pub fn work_size(&self) -> usize {
95        self.work.borrow().as_ref().len()
96    }
97}
98
99impl<T: FftFloat, Twiddles: Default + Extend<Complex<T>>, Work: Default + Extend<Complex<T>>>
100    Autosort<T, Twiddles, Work>
101{
102    /// Create a new Stockham autosort generator.  Returns `None` if the transform size cannot be
103    /// performed.
104    pub fn new(size: usize) -> Option<Self> {
105        let mut current_size = size;
106        let mut counts = [0usize; NUM_RADICES];
107        if current_size % RADICES[0] == 0 {
108            current_size /= RADICES[0];
109            counts[0] = 1;
110        }
111        for (count, radix) in counts.iter_mut().zip(&RADICES).skip(1) {
112            while current_size % radix == 0 {
113                current_size /= radix;
114                *count += 1;
115            }
116        }
117        if current_size == 1 {
118            let mut forward_twiddles = Twiddles::default();
119            let mut inverse_twiddles = Twiddles::default();
120            initialize_twiddles(size, counts, &mut forward_twiddles, &mut inverse_twiddles);
121            let mut work = Work::default();
122            work.extend(core::iter::repeat(Complex::default()).take(size));
123            Some(Self {
124                size,
125                counts,
126                forward_twiddles,
127                inverse_twiddles,
128                work: RefCell::new(work),
129                real_type: PhantomData,
130            })
131        } else {
132            None
133        }
134    }
135}
136
137macro_rules! implement {
138    {
139        $type:ty, $apply:ident
140    } => {
141        impl<Twiddles: AsRef<[Complex<$type>]>, Work: AsMut<[Complex<$type>]>> Fft
142            for Autosort<$type, Twiddles, Work>
143        {
144            type Real = $type;
145
146            fn size(&self) -> usize {
147                self.size
148            }
149
150            fn transform_in_place(&self, input: &mut [Complex<$type>], transform: Transform) {
151                let mut work = self.work.borrow_mut();
152                let twiddles = if transform.is_forward() {
153                    &self.forward_twiddles
154                } else {
155                    &self.inverse_twiddles
156                };
157                $apply(
158                    input,
159                    work.as_mut(),
160                    &self.counts,
161                    twiddles.as_ref(),
162                    self.size,
163                    transform,
164                );
165            }
166        }
167    }
168}
169implement! { f32, apply_stages_f32 }
170implement! { f64, apply_stages_f64 }
171
/// This macro creates two modules, `radix_f32` and `radix_f64`, containing the radix application
/// functions for each radix.
///
/// For each `[radix, wide_name, narrow_name, butterfly]` entry, both modules receive a "wide"
/// function (`wide_name`), which moves `width!()` consecutive elements along the stride per
/// load/store, and a "narrow" function (`narrow_name`), which moves one element at a time.  Each
/// generated function applies a single Stockham stage with the given radix:
///
/// * `input` / `output`: source and destination buffers for this stage.
/// * `_forward`: transform direction, forwarded to the butterfly macro.
/// * `size`: sub-transform length at this stage; there are `m = size / radix` butterfly
///   indices `i` for each stride position `j`.
/// * `stride`: product of the radices of the stages already applied.
/// * `cached_twiddles`: this stage's twiddle table, `radix` entries per butterfly index `i`.
///
/// Element `k` of butterfly `(i, j)` is read from `input[j + stride * (i + k * m)]` and written to
/// `output[j + stride * (radix * i + k)]`.  The generic code paths use unchecked pointer
/// arithmetic, so the caller must pass buffers of at least `stride * size` elements and a twiddle
/// table of at least `size` elements.  The wide variant additionally requires
/// `stride >= width!()`.
macro_rules! make_radix_fns {
    {
        @impl $type:ident, $wide:literal, $radix:literal, $name:ident, $butterfly:ident
    } => {

        #[multiversion::multiversion]
        #[clone(target = "[x86|x86_64]+avx")]
        #[inline]
        pub fn $name(
            input: &[num_complex::Complex<$type>],
            output: &mut [num_complex::Complex<$type>],
            _forward: bool,
            size: usize,
            stride: usize,
            cached_twiddles: &[num_complex::Complex<$type>],
        ) {
            // These invocations presumably define the vector helper macros used below
            // (`width!`, `zeroed!`, `load_wide!`, `mul!`, ...) for the selected target.
            #[target_cfg(target = "[x86|x86_64]+avx")]
            crate::avx_vector! { $type };

            #[target_cfg(not(target = "[x86|x86_64]+avx"))]
            crate::generic_vector! { $type };

            // On AVX targets, narrow stages first try a specialized kernel; presumably
            // `avx_optimization!` evaluates to `true` when it handled the whole stage
            // (NOTE(review): confirm against avx_optimization.rs).
            #[target_cfg(target = "[x86|x86_64]+avx")]
            {
                if !$wide && crate::avx_optimization!($type, $radix, input, output, _forward, size, stride, cached_twiddles) {
                    return
                }
            }

            // Number of butterflies per stride position.
            let m = size / $radix;

            // Wide path only: full vectors start at 0, w, ..., full_count - w, where `w` is
            // `width!()` and `full_count` is the largest multiple of `w` strictly below `stride`.
            // One final vector then starts at `stride - w`.  It may overlap the previous vector,
            // which only recomputes (from `input`) and rewrites the same outputs.  Underflows
            // if `stride < w`.
            let (full_count, final_offset) = if $wide {
                (Some(((stride - 1) / width!()) * width!()), Some(stride - width!()))
            } else {
                (None, None)
            };

            for i in 0..m {
                // Load twiddle factors
                if $wide {
                    // Broadcast each scalar twiddle across a whole vector; entry 0 is unused.
                    let twiddles = {
                        let mut twiddles = [zeroed!(); $radix];
                        for k in 1..$radix {
                            twiddles[k] = unsafe {
                                broadcast!(cached_twiddles.as_ptr().add(i * $radix + k).read())
                            };
                        }
                        twiddles
                    };

                    // Loop over full vectors, with a final overlapping vector
                    for j in (0..full_count.unwrap())
                        .step_by(width!())
                        .chain(core::iter::once(final_offset.unwrap()))
                    {
                        // Load full vectors
                        let mut scratch = [zeroed!(); $radix];
                        let load = unsafe { input.as_ptr().add(j + stride * i) };
                        for k in 0..$radix {
                            scratch[k] = unsafe { load_wide!(load.add(stride * k * m)) };
                        }

                        // Butterfly with optional twiddles
                        // (when `size == $radix`, `m == 1` so only i = 0 occurs; the twiddles
                        // for index 0 are presumably all 1, so the multiply is skipped)
                        scratch = $butterfly!($type, scratch, _forward);
                        if size != $radix {
                            for k in 1..$radix {
                                scratch[k] = mul!(scratch[k], twiddles[k]);
                            }
                        }

                        // Store full vectors
                        let store = unsafe { output.as_mut_ptr().add(j + $radix * stride * i) };
                        for k in 0..$radix {
                            unsafe { store_wide!(scratch[k], store.add(stride * k)) };
                        }
                    }
                } else {
                    // Load this butterfly's scalar twiddles; entry 0 is unused.
                    let twiddles = {
                        let mut twiddles = [zeroed!(); $radix];
                        for k in 1..$radix {
                            twiddles[k] = unsafe {
                                load_narrow!(cached_twiddles.as_ptr().add(i * $radix + k))
                            };
                        }
                        twiddles
                    };

                    // Base pointers for butterfly `i`; `j` walks the stride one element at a time.
                    let load = unsafe { input.as_ptr().add(stride * i) };
                    let store = unsafe { output.as_mut_ptr().add($radix * stride * i) };
                    for j in 0..stride {
                        // Load a single value
                        let mut scratch = [zeroed!(); $radix];
                        for k in 0..$radix {
                            scratch[k] = unsafe { load_narrow!(load.add(stride * k * m + j)) };
                        }

                        // Butterfly with optional twiddles
                        scratch = $butterfly!($type, scratch, _forward);
                        if size != $radix {
                            for k in 1..$radix {
                                scratch[k] = mul!(scratch[k], twiddles[k]);
                            }
                        }

                        // Store a single value
                        for k in 0..$radix {
                            unsafe { store_narrow!(scratch[k], store.add(stride * k + j)) };
                        }
                    }
                }
            }
        }
    };
    {
        $([$radix:literal, $wide_name:ident, $narrow_name:ident, $butterfly:ident]),*
    } => {
        mod radix_f32 {
        $(
            make_radix_fns! { @impl f32, true, $radix, $wide_name, $butterfly }
            make_radix_fns! { @impl f32, false, $radix, $narrow_name, $butterfly }
        )*
        }
        mod radix_f64 {
        $(
            make_radix_fns! { @impl f64, true, $radix, $wide_name, $butterfly }
            make_radix_fns! { @impl f64, false, $radix, $narrow_name, $butterfly }
        )*
        }
    };
}

// One entry per supported radix: [radix, wide fn name, narrow fn name, butterfly macro].
make_radix_fns! {
    [2, radix_2_wide, radix_2_narrow, butterfly2],
    [3, radix_3_wide, radix_3_narrow, butterfly3],
    [4, radix_4_wide, radix_4_narrow, butterfly4],
    [8, radix_8_wide, radix_8_narrow, butterfly8]
}
311
/// This macro creates the stage application function.
///
/// The generated function `$name` (e.g. `apply_stages_f32`) runs every Stockham stage described
/// by `stages` (stage counts, indexed like `RADICES`) using the kernels from `$radix_mod`.
/// Successive stages alternate between `input` and `output`.  After the last stage, the result is
/// scaled as requested by `transform` and left in `input`.  `twiddles` is the concatenated
/// per-stage twiddle table for the chosen direction.
///
/// # Panics
///
/// Panics if `input.len() != output.len()` or `size != input.len()`.
macro_rules! make_stage_fns {
    { $type:ident, $name:ident, $radix_mod:ident } => {
        #[multiversion::multiversion]
        #[clone(target = "[x86|x86_64]+avx")]
        #[inline]
        fn $name(
            input: &mut [Complex<$type>],
            output: &mut [Complex<$type>],
            stages: &[usize; NUM_RADICES],
            mut twiddles: &[Complex<$type>],
            mut size: usize,
            transform: Transform,
        ) {
            #[target_cfg(target = "[x86|x86_64]+avx")]
            crate::avx_vector! { $type };

            #[target_cfg(not(target = "[x86|x86_64]+avx"))]
            crate::generic_vector! { $type };

            assert_eq!(input.len(), output.len());
            assert_eq!(size, input.len());

            // `stride` is the product of the radices applied so far and `size` the sub-transform
            // length still to process; each stage divides `size` and multiplies `stride` by its
            // radix.
            let mut stride = 1;

            // Which buffer currently holds the data; every stage flips it.
            let mut data_in_output = false;
            for (radix, iterations) in RADICES.iter().zip(stages) {
                let mut iteration = 0;

                // Use partial loads until the stride is large enough
                // (the wide kernels need `stride >= width!()`).
                while stride < width! {} && iteration < *iterations {
                    // The `(&mut _, &mut _)` annotation makes the tuple fields coercion sites, so
                    // `input`/`output` are reborrowed rather than moved and remain usable on the
                    // next iteration.
                    let (from, to): (&mut _, &mut _) = if data_in_output {
                        (output, input)
                    } else {
                        (input, output)
                    };
                    match radix {
                        8 => dispatch!($radix_mod::radix_8_narrow(from, to, transform.is_forward(), size, stride, twiddles)),
                        4 => dispatch!($radix_mod::radix_4_narrow(from, to, transform.is_forward(), size, stride, twiddles)),
                        3 => dispatch!($radix_mod::radix_3_narrow(from, to, transform.is_forward(), size, stride, twiddles)),
                        2 => dispatch!($radix_mod::radix_2_narrow(from, to, transform.is_forward(), size, stride, twiddles)),
                        _ => unimplemented!("unsupported radix"),
                    }
                    size /= radix;
                    stride *= radix;
                    // Skip this stage's table: `size * radix` is the pre-division size, which is
                    // the number of twiddle entries the stage consumed.
                    twiddles = &twiddles[size * radix..];
                    iteration += 1;
                    data_in_output = !data_in_output;
                }

                // Remaining stages of this radix run with full-width vectors.
                for _ in iteration..*iterations {
                    let (from, to): (&mut _, &mut _) = if data_in_output {
                        (output, input)
                    } else {
                        (input, output)
                    };
                    match radix {
                        8 => dispatch!($radix_mod::radix_8_wide(from, to, transform.is_forward(), size, stride, twiddles)),
                        4 => dispatch!($radix_mod::radix_4_wide(from, to, transform.is_forward(), size, stride, twiddles)),
                        3 => dispatch!($radix_mod::radix_3_wide(from, to, transform.is_forward(), size, stride, twiddles)),
                        2 => dispatch!($radix_mod::radix_2_wide(from, to, transform.is_forward(), size, stride, twiddles)),
                        _ => unimplemented!("unsupported radix"),
                    }
                    size /= radix;
                    stride *= radix;
                    twiddles = &twiddles[size * radix ..];
                    data_in_output = !data_in_output;
                }
            }
            // `Ifft` scales by 1/N and the sqrt-scaled variants by 1/sqrt(N); `Fft` and
            // `UnscaledIfft` are left unscaled.
            if let Some(scale) = match transform {
                Transform::Fft | Transform::UnscaledIfft => None,
                Transform::Ifft => Some(1. / (input.len() as $type)),
                Transform::SqrtScaledFft | Transform::SqrtScaledIfft => Some(1. / (input.len() as $type).sqrt()),
            } {
                // If the result ended up in `output`, the copy back into `input` is fused with
                // the scaling.
                if data_in_output {
                    for (x, y) in output.iter().zip(input.iter_mut()) {
                        *y = x * scale;
                    }
                } else {
                    for x in input.iter_mut() {
                        *x *= scale;
                    }
                }
            } else {
                // Unscaled: only move the result back if it ended up in `output`.
                if data_in_output {
                    input.copy_from_slice(output);
                }
            }
        }
    };
}
make_stage_fns! { f32, apply_stages_f32, radix_f32 }
make_stage_fns! { f64, apply_stages_f64, radix_f64 }