Skip to main content

p3_dft/
radix_2_small_batch.rs

1//! An FFT implementation optimized for small batch sizes.
2
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::iter;
6
7use itertools::Itertools;
8use p3_field::{Field, TwoAdicField, scale_slice_in_place_single_core};
9use p3_matrix::Matrix;
10use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
11use p3_matrix::util::reverse_matrix_index_bits;
12use p3_maybe_rayon::prelude::*;
13use p3_util::{as_base_slice, log2_strict_usize, reverse_slice_index_bits};
14use spin::RwLock;
15
16use crate::{
17    Butterfly, DifButterfly, DifButterflyZeros, DitButterfly, TwiddleFreeButterfly,
18    TwoAdicSubgroupDft,
19};
20
/// The number of butterfly layers fused into a single parallel pass.
///
/// Layers whose blocks span more than one per-thread chunk are processed
/// `LAYERS_PER_GROUP` at a time via `dft_layer_par_triple`; any leftover layers
/// (the count modulo `LAYERS_PER_GROUP`) go through `dft_layer_par_extra_layers`.
/// NOTE(review): both helpers are hard-wired for a group size of 3. The triple
/// helper takes exactly three twiddle tables, and the leftover handler only
/// accepts 0, 1 or 2 layers. This constant therefore cannot be changed on its own.
const LAYERS_PER_GROUP: usize = 3;
23
/// Paired twiddle and inverse-twiddle tables, always updated atomically
/// under a single lock to prevent concurrent observers from seeing a
/// half-updated state.
///
/// Both tables live behind `Arc`s. Readers can clone them out of the lock
/// cheaply (a reference-count bump) and release the lock before doing any work.
#[derive(Clone, Debug)]
struct TwiddlePair<F> {
    /// Forward twiddle tables, one `Vec` per layer, each stored in bit-reversed order.
    twiddles: Arc<[Vec<F>]>,
    /// Inverse twiddle tables, laid out exactly like `twiddles`.
    inv_twiddles: Arc<[Vec<F>]>,
}
32
33impl<F> Default for TwiddlePair<F> {
34    fn default() -> Self {
35        Self {
36            twiddles: Arc::from(Vec::new()),
37            inv_twiddles: Arc::from(Vec::new()),
38        }
39    }
40}
41
/// An FFT algorithm which divides a butterfly network's layers into two halves.
///
/// Unlike other FFT algorithms, this algorithm is optimized for small batch sizes.
/// It also stores its twiddle factors and only re-computes them if it is asked to do a
/// larger FFT.
///
/// Instead of parallelizing across rows, this algorithm parallelizes across groups of rows
/// with the same twiddle factors. This allows it to make use of field packings far more than
/// the standard methods, even for low-width matrices. Once the chunk size is small enough, it
/// computes a large set of layers fully on a single thread, which avoids the overhead of
/// passing data between threads.
#[derive(Default, Clone, Debug)]
pub struct Radix2DFTSmallBatch<F> {
    /// Memoized twiddle factors for each length log_n, paired with their inverses.
    ///
    /// Both tables are stored behind a single lock so they are always updated
    /// atomically. For each `i`, `twiddles[i]` contains a list of twiddles stored in
    /// bit-reversed order. The final set of twiddles `twiddles[-1]` is the one-element
    /// vector `[1]`. More generally, `twiddles[-i]` has length `2^(i - 1)`.
    ///
    /// The cache is held in an `Arc`, so clones of this struct share (and grow) the
    /// same tables.
    cache: Arc<RwLock<TwiddlePair<F>>>,
}
64
65impl<F: TwoAdicField> Radix2DFTSmallBatch<F> {
66    /// Create a new `Radix2DFTSmallBatch` instance with precomputed twiddles for the given size.
67    ///
68    /// The input `n` should be a power of two, representing the maximal FFT size you expect to handle.
69    pub fn new(n: usize) -> Self {
70        let res = Self::default();
71        res.update_twiddles(n);
72        res
73    }
74
75    /// Given a field element `gen` of order n where `n = 2^lg_n`,
76    /// return a vector of vectors `table` where table[i] is the
77    /// vector of twiddle factors for an fft of length n/2^i. The
78    /// values g_i^k for k >= i/2 are skipped as these are just the
79    /// negatives of the other roots (using g_i^{i/2} = -1). The
80    /// value gen^0 = 1 is included to aid consistency between the
81    /// packed and non-packed variants.
82    fn roots_of_unity_table(&self, n: usize) -> Vec<Vec<F>> {
83        let lg_n = log2_strict_usize(n);
84        let generator = F::two_adic_generator(lg_n);
85        let half_n = 1 << (lg_n - 1);
86        // nth_roots = [1, g, g^2, g^3, ..., g^{n/2 - 1}]
87        let nth_roots = generator.powers().collect_n(half_n);
88
89        (0..lg_n)
90            .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
91            .collect()
92    }
93
94    /// Compute twiddle and inv_twiddle factors, or take memoized ones if already available.
95    fn update_twiddles(&self, fft_len: usize) {
96        // TODO: This recomputes the entire table from scratch if we
97        // need it to be larger, which is wasteful.
98
99        // Fast path: read lock to check if we already have enough.
100        // roots_of_unity_table(fft_len) returns a vector of twiddles of length log_2(fft_len).
101        let curr_max_fft_len = 1 << self.cache.read().twiddles.len();
102        if fft_len > curr_max_fft_len {
103            let mut new_twiddles = self.roots_of_unity_table(fft_len);
104            let mut new_inv_twiddles: Vec<Vec<F>> = new_twiddles
105                .iter()
106                .map(|ts| {
107                    // The first twiddle is still one, instead of inverting, we can
108                    // just reverse and negate.
109                    iter::once(F::ONE)
110                        .chain(ts[1..].iter().rev().map(|&f| -f))
111                        .collect()
112                })
113                .collect();
114
115            new_twiddles.iter_mut().for_each(|ts| {
116                reverse_slice_index_bits(ts);
117            });
118            new_inv_twiddles.iter_mut().for_each(|ts| {
119                reverse_slice_index_bits(ts);
120            });
121
122            // Slow path: acquire write lock and double-check before updating
123            // both tables atomically.
124            let mut cache = self.cache.write();
125            let cur_have = 1usize << cache.twiddles.len();
126            if fft_len > cur_have {
127                cache.twiddles = Arc::from(new_twiddles);
128                cache.inv_twiddles = Arc::from(new_inv_twiddles);
129            }
130        }
131    }
132}
133
134impl<F> TwoAdicSubgroupDft<F> for Radix2DFTSmallBatch<F>
135where
136    F: TwoAdicField,
137{
138    type Evaluations = RowMajorMatrix<F>;
139
140    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
141        let h = mat.height();
142        let w = mat.width();
143        let log_h = log2_strict_usize(h);
144
145        self.update_twiddles(h);
146        let g = self.cache.read().twiddles.clone(); // Lock is dropped immediately
147        let root_table = &g[g.len() - log_h..];
148
149        // The strategy will be to do a standard round-by-round parallelization
150        // until the chunk size is smaller than `num_par_rows * mat.width()` after which we
151        // send `num_par_rows` chunks to each thread and do the remainder of the
152        // fft without transferring any more data between threads.
153        let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
154        let log_num_par_rows = log2_strict_usize(num_par_rows);
155        let chunk_size = num_par_rows * w;
156
157        // For the layers involving blocks larger than `num_par_rows`, we will
158        // parallelize across the blocks.
159
160        let multi_layer_dit = MultiLayerDitButterfly {};
161
162        // We do `LAYERS_PER_GROUP` layers of the DFT at once, to minimize how much data we need to transfer
163        // between threads.
164        for (dit_0, dit_1, dit_2) in root_table[log_num_par_rows..]
165            .iter()
166            .rev()
167            .map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) // Safe as DitButterfly is #[repr(transparent)]
168            .tuples()
169        {
170            dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
171        }
172
173        // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`,
174        // we need to handle the remaining layers separately.
175        let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
176        dft_layer_par_extra_layers(
177            &mut mat.as_view_mut(),
178            &root_table[log_num_par_rows..log_num_par_rows + corr],
179            multi_layer_dit,
180        );
181
182        // Once the blocks are small enough, we can split the matrix
183        // into chunks of size `chunk_size` and process them in parallel.
184        // This avoids passing data between threads, which can be expensive.
185        par_remaining_layers(&mut mat.values, chunk_size, &root_table[..log_num_par_rows]);
186
187        // Finally we bit-reverse the matrix to ensure the output is in the correct order.
188        reverse_matrix_index_bits(&mut mat);
189        mat
190    }
191
192    fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
193        let h = mat.height();
194        let w = mat.width();
195        let log_h = log2_strict_usize(h);
196
197        self.update_twiddles(h);
198        let g = self.cache.read().inv_twiddles.clone(); // Lock is dropped immediately
199        let start = g
200            .len()
201            .checked_sub(log_h)
202            .expect("log_h exceeds inv_twiddles length");
203        let root_table = &g[start..];
204
205        // Find the number of rows which can roughly fit in L1 cache.
206        // The strategy is the same as `dft_batch` but in reverse.
207        // We start by moving `num_par_rows` rows onto each thread and doing
208        // `num_par_rows` layers of the DFT. After this we recombine and do
209        // a standard round-by-round parallelization for the remaining layers.
210        let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
211        let log_num_par_rows = log2_strict_usize(num_par_rows);
212        let chunk_size = num_par_rows * w;
213
214        // Need to start by bit-reversing the matrix.
215        reverse_matrix_index_bits(&mut mat);
216
217        // For the initial blocks, they are small enough that we can split the matrix
218        // into chunks of size `chunk_size` and process them in parallel.
219        // This avoids passing data between threads, which can be expensive.
220        // We also divide by the height of the matrix while the data is nicely partitioned
221        // on each core.
222        par_initial_layers(
223            &mut mat.values,
224            chunk_size,
225            &root_table[..log_num_par_rows],
226            log_h,
227        );
228
229        // For the layers involving blocks larger than `num_par_rows`, we will
230        // parallelize across the blocks.
231
232        let multi_layer_dif = MultiLayerDifButterfly {};
233
234        // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`,
235        // we need to handle the initial layers separately.
236        let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
237        dft_layer_par_extra_layers(
238            &mut mat.as_view_mut(),
239            &root_table[log_num_par_rows..log_num_par_rows + corr],
240            multi_layer_dif,
241        );
242
243        // We do `LAYERS_PER_GROUP` layers of the DFT at once, to minimize how much data we need to transfer
244        // between threads.
245        for (dif_0, dif_1, dif_2) in root_table[(log_num_par_rows + corr)..]
246            .iter()
247            .map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) // Safe as DifButterfly is #[repr(transparent)]
248            .tuples()
249        {
250            dft_layer_par_triple(&mut mat.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
251        }
252
253        mat
254    }
255
256    fn coset_lde_batch(
257        &self,
258        mut mat: RowMajorMatrix<F>,
259        added_bits: usize,
260        shift: F,
261    ) -> Self::Evaluations {
262        let h = mat.height();
263        let w = mat.width();
264        let log_h = log2_strict_usize(h);
265
266        self.update_twiddles(h << added_bits);
267        let cached = self.cache.read().clone();
268        let g = &cached.twiddles;
269        let start = g
270            .len()
271            .checked_sub(log_h + added_bits)
272            .expect("log_h exceeds twiddles length");
273        let root_table = &g[start..];
274        let ig = &cached.inv_twiddles;
275        let start = ig
276            .len()
277            .checked_sub(log_h)
278            .expect("log_h exceeds inv_twiddles length");
279        let inv_root_table = &ig[start..];
280        let output_height = h << added_bits;
281
282        // The matrix which we will use to store the output.
283        let output_values = F::zero_vec(output_height * w);
284        let mut out = RowMajorMatrix::new(output_values, w);
285
286        // The strategy is reasonably straightforward.
287        // The rough idea is we want to squash together the dft and idft code.
288
289        // This lets us do all of the inner layers on a single thread reducing the amount
290        // of data we need to transfer.
291
292        // For technical reasons, we need to swap the twiddle factors, using the inverse
293        // twiddles for the initial layers and the normal twiddles for the final layers.
294        // This lets us interpret the initial transformation as the idft giving us coefficients
295        // and the final transformation as the dft giving us evaluations.
296
297        // Find the number of rows which can roughly fit in L1 cache.
298        // The strategy will be to do a standard round-by-round parallelization
299        // until the chunk size is smaller than `num_par_rows * mat.width()` after which we
300        // send `num_par_rows` chunks to each thread and do the remainder of the
301        // fft without transferring any more data between threads.
302        let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
303        let num_inner_dit_layers = log2_strict_usize(num_par_rows);
304        let num_inner_dif_layers = num_inner_dit_layers + added_bits;
305
306        // We will do large DFT/iDFT layers in batches of `LAYERS_PER_GROUP`. We start with
307        // the dit layers.
308        let multi_layer_dit = MultiLayerDitButterfly {};
309        for (dit_0, dit_1, dit_2) in inv_root_table[num_inner_dit_layers..]
310            .iter()
311            .rev()
312            .map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) // Safe as DitButterfly is #[repr(transparent)]
313            .tuples()
314        {
315            dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
316        }
317
318        // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`,
319        // we need to handle the remaining layers separately.
320        let corr = (log_h - num_inner_dit_layers) % LAYERS_PER_GROUP;
321        dft_layer_par_extra_layers(
322            &mut mat.as_view_mut(),
323            &inv_root_table[num_inner_dit_layers..num_inner_dit_layers + corr],
324            multi_layer_dit,
325        );
326
327        // Now do all the inner layers at once. This does the final `log_num_par_rows` of
328        // the initial transformation, then copies the values of mat to output, scales then
329        // and does the first `log_num_par_rows + added_bits` layers of the final transformation.
330        par_middle_layers(
331            &mut mat.as_view_mut(),
332            &mut out.as_view_mut(),
333            num_par_rows,
334            &root_table[..(num_inner_dif_layers)],
335            &inv_root_table[..num_inner_dit_layers],
336            added_bits,
337            shift,
338        );
339
340        // We are left with the final dif layers.
341        let multi_layer_dif = MultiLayerDifButterfly {};
342
343        // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`,
344        // we need to handle the remaining layers separately.
345        dft_layer_par_extra_layers(
346            &mut out.as_view_mut(),
347            &root_table[num_inner_dif_layers..num_inner_dif_layers + corr],
348            multi_layer_dif,
349        );
350
351        // We do `LAYERS_PER_GROUP` layers of the DFT at once, to minimize how much data we need to transfer
352        // between threads.
353        for (dif_0, dif_1, dif_2) in root_table[(num_inner_dif_layers + corr)..]
354            .iter()
355            .map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) // Safe as DifButterfly is #[repr(transparent)]
356            .tuples()
357        {
358            dft_layer_par_triple(&mut out.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
359        }
360
361        out
362    }
363}
364
365/// Applies one layer of the Radix-2 DIF FFT butterfly network making use of parallelization.
366///
367/// Splits the matrix into blocks of rows and performs in-place butterfly operations
368/// on each block. Uses a `TwiddleFreeButterfly` for the first pair and `DifButterfly`
369/// with precomputed twiddles for the rest.
370///
371/// Each block is processed in parallel, if the blocks are large enough they themselves
372/// are split into parallel sub-blocks.
373///
374/// # Arguments
375/// - `mat`: Mutable matrix whose height is a power of two.
376/// - `twiddles`: Precomputed twiddle factors for this layer.
377#[inline]
378fn dft_layer_par<F: Field, B: Butterfly<F>>(
379    mat: &mut RowMajorMatrixViewMut<'_, F>,
380    twiddles: &[B],
381) {
382    debug_assert!(
383        mat.height().is_multiple_of(twiddles.len()),
384        "Matrix height must be divisible by the number of twiddles"
385    );
386    let size = mat.values.len();
387    let num_blocks = twiddles.len();
388
389    let outer_block_size = size / num_blocks;
390    let half_outer_block_size = outer_block_size / 2;
391
392    mat.values
393        .par_chunks_exact_mut(outer_block_size)
394        .enumerate()
395        .for_each(|(ind, block)| {
396            // Split each block vertically into top (hi) and bottom (lo) halves
397            let (hi_chunk, lo_chunk) = block.split_at_mut(half_outer_block_size);
398
399            // If num_blocks is small, we probably are not using all available threads.
400            let num_threads = current_num_threads();
401            let inner_block_size = size / (2 * num_blocks).max(num_threads);
402
403            hi_chunk
404                .par_chunks_mut(inner_block_size)
405                .zip(lo_chunk.par_chunks_mut(inner_block_size))
406                .for_each(|(hi_chunk, lo_chunk)| {
407                    if ind == 0 {
408                        // The first pair doesn't require a twiddle factor
409                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
410                    } else {
411                        // Apply DIT butterfly using the twiddle factor at index `ind - 1`
412                        twiddles[ind].apply_to_rows(hi_chunk, lo_chunk);
413                    }
414                });
415        });
416}
417
418/// Splits the matrix into chunks of size `chunk_size` and performs
419/// the remaining layers of the FFT in parallel on each chunk.
420///
421/// This avoids passing data between threads, which can be expensive.
422#[inline]
423fn par_remaining_layers<F: Field>(mat: &mut [F], chunk_size: usize, root_table: &[Vec<F>]) {
424    mat.par_chunks_exact_mut(chunk_size)
425        .enumerate()
426        .for_each(|(index, chunk)| {
427            remaining_layers(chunk, root_table, index);
428        });
429}
430
431/// Performs a collection of DIT layers on a chunk of the matrix.
432fn remaining_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
433    for (layer, twiddles) in root_table.iter().rev().enumerate() {
434        let num_twiddles_per_block = 1 << layer;
435        let start = index * num_twiddles_per_block;
436        let twiddle_range = start..(start + num_twiddles_per_block);
437        // Safe as DitButterfly is #[repr(transparent)]
438        let dit_twiddles: &[DitButterfly<F>] = unsafe { as_base_slice(&twiddles[twiddle_range]) };
439        dft_layer(chunk, dit_twiddles);
440    }
441}
442
443/// Splits the matrix into chunks of size `chunk_size` and performs
444/// the initial layers of the iFFT in parallel on each chunk.
445///
446/// This avoids passing data between threads, which can be expensive.
447///
448/// Basically identical to [par_remaining_layers] but in reverse and we
449/// also divide by the height.
450#[inline]
451fn par_initial_layers<F: Field>(
452    mat: &mut [F],
453    chunk_size: usize,
454    root_table: &[Vec<F>],
455    log_height: usize,
456) {
457    let inv_height = F::ONE.div_2exp_u64(log_height as u64);
458    mat.par_chunks_exact_mut(chunk_size)
459        .enumerate()
460        .for_each(|(index, chunk)| {
461            // Divide all elements by the height of the matrix.
462            scale_slice_in_place_single_core(chunk, inv_height);
463            initial_layers(chunk, root_table, index);
464        });
465}
466
467/// Performs a collection of DIF layers on a chunk of the matrix.
468#[inline]
469fn initial_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
470    let num_rounds = root_table.len();
471
472    for (layer, twiddles) in root_table.iter().enumerate() {
473        let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
474        let start = index * num_twiddles_per_block;
475        let twiddle_range = start..(start + num_twiddles_per_block);
476        // Safe as DifButterfly is #[repr(transparent)]
477        let dif_twiddles: &[DifButterfly<F>] = unsafe { as_base_slice(&twiddles[twiddle_range]) };
478        dft_layer(chunk, dif_twiddles);
479    }
480}
481
482/// Splits the matrix into chunks of size `chunk_size` and performs
483/// the middle layers of a coset_lde in parallel on each chunk.
484///
485/// Similar to [par_remaining_layers] followed by [par_initial_layers]
486/// with a scaling and copying operation in between.
487fn par_middle_layers<F: Field>(
488    in_mat: &mut RowMajorMatrixViewMut<'_, F>,
489    out_mat: &mut RowMajorMatrixViewMut<'_, F>,
490    num_par_rows: usize,
491    root_table: &[Vec<F>],
492    inv_root_table: &[Vec<F>],
493    added_bits: usize,
494    shift: F,
495) {
496    debug_assert_eq!(in_mat.width(), out_mat.width());
497    debug_assert_eq!(in_mat.height() << added_bits, out_mat.height());
498
499    let width = in_mat.width();
500    let height = in_mat.height();
501    let num_rounds = root_table.len();
502    let in_chunk_size = num_par_rows * width;
503    let out_chunk_size = in_chunk_size << added_bits;
504
505    let log_height = log2_strict_usize(height);
506    let inv_height = F::ONE.div_2exp_u64(log_height as u64);
507
508    let mut scaling = shift.shifted_powers(inv_height).collect_n(height);
509    reverse_slice_index_bits(&mut scaling);
510
511    in_mat
512        .values
513        .par_chunks_exact_mut(in_chunk_size)
514        .zip(out_mat.values.par_chunks_exact_mut(out_chunk_size))
515        .zip(scaling.par_chunks_exact_mut(num_par_rows))
516        .enumerate()
517        .for_each(|(index, ((in_chunk, out_chunk), scaling))| {
518            remaining_layers(in_chunk, inv_root_table, index);
519
520            // Copy the values to the output matrix and scale appropriately.
521            in_chunk
522                .chunks_exact(width)
523                .zip(scaling)
524                .zip(out_chunk.chunks_exact_mut(width << added_bits))
525                .for_each(|((in_row, scale), out_row)| {
526                    out_row
527                        .iter_mut()
528                        .zip(in_row.iter())
529                        .for_each(|(out_val, in_val)| {
530                            *out_val = *in_val * *scale;
531                        });
532                });
533
534            // We can do something cheaper than standard DFT layers for the first `added_bits` layers.
535            // as there are a lot of zeroes in the out_chunk.
536            for (layer, twiddles) in root_table[..added_bits].iter().enumerate() {
537                let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
538                let start = index * num_twiddles_per_block;
539                let twiddle_range = start..(start + num_twiddles_per_block);
540
541                // Safe as DifButterflyZeros is #[repr(transparent)]
542                let dif_twiddles_zeros: &[DifButterflyZeros<F>] =
543                    unsafe { as_base_slice(&twiddles[twiddle_range]) };
544                dft_layer_zeros(out_chunk, dif_twiddles_zeros, added_bits - layer - 1);
545            }
546
547            initial_layers(out_chunk, &root_table[added_bits..], index);
548        });
549}
550
551/// Applies one layer of the Radix-2 FFT butterfly network on a single core.
552///
553/// Splits the matrix into blocks of rows and performs in-place butterfly operations
554/// on each block.
555///
556/// # Arguments
557/// - `vec`: Mutable vector whose height is a power of two.
558/// - `twiddles`: Precomputed twiddle factors for this layer.
559#[inline]
560fn dft_layer<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B]) {
561    debug_assert_eq!(
562        vec.len() % twiddles.len(),
563        0,
564        "Vector length must be divisible by the number of twiddles"
565    );
566    let size = vec.len();
567    let num_blocks = twiddles.len();
568
569    let block_size = size / num_blocks;
570    let half_block_size = block_size / 2;
571
572    vec.chunks_exact_mut(block_size)
573        .zip(twiddles)
574        .for_each(|(block, &twiddle)| {
575            // Split each block vertically into top (hi) and bottom (lo) halves
576            let (hi_chunk, lo_chunk) = block.split_at_mut(half_block_size);
577
578            // Apply DIT butterfly
579            twiddle.apply_to_rows(hi_chunk, lo_chunk);
580        });
581}
582
/// Applies two consecutive layers of a Radix-2 butterfly network making use of
/// parallelization.
///
/// Splits the matrix into `twiddles_small.len()` blocks of rows and performs in-place
/// butterfly operations on each block. Each block is divided into four equal quarters, and
/// each quarter into parallel sub-chunks of at most `inner_chunk_size` elements. The
/// advantage of doing two layers at once is that it reduces the amount of data transferred
/// between threads.
///
/// # Arguments
/// - `mat`: Mutable matrix whose height is a power of two.
/// - `twiddles_small`: The table with fewer twiddles, one per outer block. It sets the block
///   partition (`num_blocks = twiddles_small.len()`), so it belongs to the layer with the
///   *largest* block size.
/// - `twiddles_large`: The table with more twiddles (twice as many in the callers in this
///   file), i.e. the layer acting on half-blocks (the smaller block size).
/// - `multi_butterfly`: Multi-layer butterfly which applies the two layers in the correct order.
#[inline]
fn dft_layer_par_double<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
    mat: &mut RowMajorMatrixViewMut<'_, F>,
    twiddles_small: &[B],
    twiddles_large: &[B],
    multi_butterfly: M,
) {
    debug_assert!(
        mat.height().is_multiple_of(twiddles_small.len()),
        "Matrix height must be divisible by the number of twiddles"
    );
    let size = mat.values.len();
    let num_blocks = twiddles_small.len();

    let outer_block_size = size / num_blocks;
    let quarter_outer_block_size = outer_block_size / 4;

    // Estimate the optimal size of the inner chunks so that all data fits in L1 cache.
    // Note that 4 inner chunks are processed in each parallel thread so we divide by 4.
    // NOTE(review): if `workload_size::<F>().next_power_of_two() < 4` this is 0 and
    // `par_chunks_mut` below would panic; presumably `workload_size` is far larger - confirm.
    let inner_chunk_size =
        (workload_size::<F>().next_power_of_two() / 4).min(quarter_outer_block_size);

    mat.values
        .par_chunks_exact_mut(outer_block_size)
        .enumerate()
        .for_each(|(ind, block)| {
            // Split each block into four quarters. Each quarter will be further split into
            // sub-chunks processed in parallel.
            let chunk_par_iters_0 = block
                .chunks_exact_mut(quarter_outer_block_size)
                .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
                .collect::<Vec<_>>();
            // The four per-quarter iterators are combined by `zip_par_iter_vec` (defined
            // elsewhere) and then paired up. Presumably each `apply_2_layers` call receives
            // one matching sub-chunk from every quarter - TODO confirm against
            // `zip_par_iter_vec`.
            let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
            chunk_par_iters_1.into_iter().tuples().for_each(|(hi, lo)| {
                hi.zip(lo).for_each(|chunks| {
                    multi_butterfly.apply_2_layers(chunks, ind, twiddles_small, twiddles_large);
                });
            });
        });
}
634
/// Applies three consecutive layers of a Radix-2 butterfly network making use of
/// parallelization.
///
/// Splits the matrix into `twiddles_small.len()` blocks of rows and performs in-place
/// butterfly operations on each block. Each block is divided into eight equal parts, and
/// each part into parallel sub-chunks of at most `inner_chunk_size` elements. The advantage
/// of doing three layers at once is that it reduces the amount of data transferred between
/// threads.
///
/// # Arguments
/// - `mat`: Mutable matrix whose height is a power of two.
/// - `twiddles_small`: The table with the fewest twiddles, one per outer block. It sets the
///   block partition (`num_blocks = twiddles_small.len()`), so it belongs to the layer with
///   the *largest* block size.
/// - `twiddles_med`: The table for the middle layer.
/// - `twiddles_large`: The table with the most twiddles, i.e. the layer with the smallest
///   block size.
/// - `multi_butterfly`: Multi-layer butterfly which applies the three layers in the correct order.
#[inline]
fn dft_layer_par_triple<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
    mat: &mut RowMajorMatrixViewMut<'_, F>,
    twiddles_small: &[B],
    twiddles_med: &[B],
    twiddles_large: &[B],
    multi_butterfly: M,
) {
    debug_assert!(
        mat.height().is_multiple_of(twiddles_small.len()),
        "Matrix height must be divisible by the number of twiddles"
    );
    let size = mat.values.len();
    let num_blocks = twiddles_small.len();

    let outer_block_size = size / num_blocks;
    let eighth_outer_block_size = outer_block_size / 8;

    // Estimate the optimal size of the inner chunks so that all data fits in L1 cache.
    // Note that 8 inner chunks are processed in each parallel thread so we divide by 8.
    // NOTE(review): if `workload_size::<F>().next_power_of_two() < 8` this is 0 and
    // `par_chunks_mut` below would panic; presumably `workload_size` is far larger - confirm.
    let inner_chunk_size =
        (workload_size::<F>().next_power_of_two() / 8).min(eighth_outer_block_size);

    mat.values
        .par_chunks_exact_mut(outer_block_size)
        .enumerate()
        .for_each(|(ind, block)| {
            // Split each block into eight equal parts. Each part will be further split into
            // sub-chunks processed in parallel.
            let chunk_par_iters_0 = block
                .chunks_exact_mut(eighth_outer_block_size)
                .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
                .collect::<Vec<_>>();
            // Two rounds of `zip_par_iter_vec` (defined elsewhere) followed by pairing.
            // Presumably each `apply_3_layers` call receives one matching sub-chunk from
            // every eighth - TODO confirm against `zip_par_iter_vec`.
            let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
            let chunk_par_iters_2 = zip_par_iter_vec(chunk_par_iters_1);
            chunk_par_iters_2.into_iter().tuples().for_each(|(hi, lo)| {
                hi.zip(lo).for_each(|chunks| {
                    multi_butterfly.apply_3_layers(
                        chunks,
                        ind,
                        twiddles_small,
                        twiddles_med,
                        twiddles_large,
                    );
                });
            });
        });
}
695
696/// Applies the remaining layers of the Radix-2 FFT butterfly network in parallel.
697///
698/// This function is used to correct for the fact that the total number of layers
699/// may not be a multiple of `LAYERS_PER_GROUP`.
700fn dft_layer_par_extra_layers<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
701    mat: &mut RowMajorMatrixViewMut<'_, F>,
702    root_table: &[Vec<F>],
703    multi_layer: M,
704) {
705    match root_table.len() {
706        1 => {
707            // Safe as DitButterfly is #[repr(transparent)]
708            let fft_layer: &[B] = unsafe { as_base_slice(&root_table[0]) };
709            dft_layer_par(&mut mat.as_view_mut(), fft_layer);
710        }
711        2 => {
712            let fft_layer_0: &[B] = unsafe { as_base_slice(&root_table[0]) };
713            let fft_layer_1: &[B] = unsafe { as_base_slice(&root_table[1]) };
714            dft_layer_par_double(
715                &mut mat.as_view_mut(),
716                fft_layer_1,
717                fft_layer_0,
718                multi_layer,
719            );
720        }
721        0 => {}
722        _ => unreachable!("The number of layers must be 0, 1 or 2"),
723    }
724}
725
726/// Applies one layer of the Radix-2 FFT butterfly network on a single core to
727/// a recently zero-padded matrix.
728///
729/// Splits the matrix into blocks of rows and performs in-place butterfly operations
730/// on each block.
731///
732/// Assume `added_bits = 2` and we are doing a decimation in frequency approach.
733/// Then the rows of our matrix look like:
734/// ```text
735/// [R0, 0, 0, 0, R1, 0, 0, 0, ...]
736/// ```
737/// Thus the first two butterfly layers can be implemented more simply as they map the matrix to:
738/// ```text
739/// After Layer 0: [R0, T00 * R0, 0, 0, R1, T01 * R1, 0, 0, ...]
740/// After Layer 1: [R0, T00 * R0, T10 * R0, T10 * T00 * R0, R1, T01 * R1, T11 * R1, T11 * T01 * R1, ...].
741/// ```
742///
743/// # Arguments
744/// - `vec`: Mutable vector whose height is a power of two.
745/// - `twiddles`: Precomputed twiddle factors for this layer.
746/// - `skip`: `(1 << skip) - 1` is the number of entirely zero
747///   blocks between each non zero block.
748#[inline]
749fn dft_layer_zeros<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B], skip: usize) {
750    debug_assert_eq!(
751        vec.len() % twiddles.len(),
752        0,
753        "Vector length must be divisible by the number of twiddles"
754    );
755    let size = vec.len();
756    let num_blocks = twiddles.len();
757
758    let block_size = size / num_blocks;
759    let half_block_size = block_size / 2;
760
761    vec.chunks_exact_mut(block_size)
762        .zip(twiddles)
763        .step_by(1 << skip) // Skip the zero blocks
764        .for_each(|(block, &twiddle)| {
765            // Split each block vertically into top (hi) and bottom (lo) halves
766            let (hi_chunk, lo_chunk) = block.split_at_mut(half_block_size);
767
768            // Apply DIF butterfly making use of the fact that `lo_chunk` is zero.
769            twiddle.apply_to_rows(hi_chunk, lo_chunk);
770        });
771}
772
/// A type representing a decomposition of an FFT block into four sub-blocks.
///
/// The sub-blocks are nested in pairs: `(x.0.0, x.0.1)` and `(x.1.0, x.1.1)`.
/// The two-layer helpers operate on it in two ways:
/// - `fft_double_layer_single_twiddle` pairs `x.0.i` with `x.1.i` (across the halves).
/// - `fft_double_layer_double_twiddle` pairs `x.i.0` with `x.i.1` (within each half).
///
/// NOTE(review): presumably `x.0.0 .. x.1.1` are in memory order of the original
/// block, by analogy with the eight-way decomposition; confirm at the construction site.
type DoubleLayerBlockDecomposition<'a, F> =
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F]));
776
777/// Performs an FFT layer on the sub-blocks using a single twiddle factor.
778#[inline]
779fn fft_double_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
780    block: &mut DoubleLayerBlockDecomposition<'_, F>,
781    butterfly: Fly,
782) {
783    butterfly.apply_to_rows(block.0.0, block.1.0);
784    butterfly.apply_to_rows(block.0.1, block.1.1);
785}
786
787/// Performs an FFT layer on the sub-blocks using a pair of twiddle factors.
788///
789/// The inputs are differentiated in order to allow the first input to potentially
790/// be a `TwiddleFreeButterfly`, which does not require a twiddle factor.
791#[inline]
792fn fft_double_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
793    block: &mut DoubleLayerBlockDecomposition<'_, F>,
794    fly0: Fly0,
795    fly1: Fly1,
796) {
797    fly0.apply_to_rows(block.0.0, block.0.1);
798    fly1.apply_to_rows(block.1.0, block.1.1);
799}
800
/// A type representing a decomposition of an FFT block into eight sub-blocks.
///
/// The sub-blocks are nested in pairs of pairs. When the decomposition is built by
/// chunking an outer block into eighths and repeatedly zipping neighbouring chunks
/// (via `zip_par_iter_vec`), the sub-blocks come from the eighths as follows:
/// - `x.0.0.0` from the first eighth,
/// - `x.0.0.1` from the second,
/// - `x.0.1.0` from the third,
/// - and so on, up to `x.1.1.1` from the eighth.
///
/// The three-layer helpers pair sub-blocks that are four, two or one position(s)
/// apart, respectively.
type TripleLayerBlockDecomposition<'a, F> = (
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
);
806
807/// Performs an FFT layer on the sub-blocks using a single twiddle factor.
808#[inline]
809fn fft_triple_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
810    block: &mut TripleLayerBlockDecomposition<'_, F>,
811    butterfly: Fly,
812) {
813    butterfly.apply_to_rows(block.0.0.0, block.1.0.0);
814    butterfly.apply_to_rows(block.0.0.1, block.1.0.1);
815    butterfly.apply_to_rows(block.0.1.0, block.1.1.0);
816    butterfly.apply_to_rows(block.0.1.1, block.1.1.1);
817}
818
819/// Performs an FFT layer on the sub-blocks using a pair of twiddle factors.
820///
821/// The inputs are differentiated in order to allow the first input to potentially
822/// be a `TwiddleFreeButterfly`, which does not require a twiddle factor.
823#[inline]
824fn fft_triple_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
825    block: &mut TripleLayerBlockDecomposition<'_, F>,
826    fly0: Fly0,
827    fly1: Fly1,
828) {
829    fly0.apply_to_rows(block.0.0.0, block.0.1.0);
830    fly0.apply_to_rows(block.0.0.1, block.0.1.1);
831    fly1.apply_to_rows(block.1.0.0, block.1.1.0);
832    fly1.apply_to_rows(block.1.0.1, block.1.1.1);
833}
834
835/// Performs an FFT layer on the sub-blocks using a four twiddle factors.
836///
837/// The inputs are differentiated in order to allow the first input to potentially
838/// be a `TwiddleFreeButterfly`, which does not require a twiddle factor.
839#[inline]
840fn fft_triple_layer_quad_twiddle<F: Field, Fly0: Butterfly<F>, Flies: Butterfly<F>>(
841    block: &mut TripleLayerBlockDecomposition<'_, F>,
842    fly0: Fly0,
843    butterflies: &[Flies],
844) {
845    debug_assert!(butterflies.len() == 3);
846    fly0.apply_to_rows(block.0.0.0, block.0.0.1);
847    butterflies[0].apply_to_rows(block.0.1.0, block.0.1.1);
848    butterflies[1].apply_to_rows(block.1.0.0, block.1.0.1);
849    butterflies[2].apply_to_rows(block.1.1.0, block.1.1.1);
850}
851
/// Returns how many values of type `T` fit in a nominal 32 KB L1 cache.
///
/// The cache size is a fixed approximation, not a value queried from the
/// hardware. Used to determine the number of chunks to process in parallel.
#[must_use]
const fn workload_size<T: Sized>() -> usize {
    // Assumed L1 capacity, in bytes.
    const L1_BYTES: usize = 32 * 1024;
    let bytes_per_item = size_of::<T>();
    L1_BYTES / bytes_per_item
}
861
862/// Estimates the optimal number of rows of a `RowMajorMatrix<T>` to take in each parallel chunk.
863///
864/// Designed to ensure that `<T> * estimate_num_rows_par() * width` is roughly the size of the L1 cache.
865///
866/// Assumes that height is a power of two and always outputs a power of two.
867#[must_use]
868fn estimate_num_rows_in_l1<T: Sized>(height: usize, width: usize) -> usize {
869    (workload_size::<T>() / width)
870        .next_power_of_two()
871        .min(height) // Ensure we don't exceed the height of the matrix.
872}
873
874/// Given a vector of parallel iterators, zip all pairs together.
875///
876/// This lets us simulate the izip!() macro but for our possibly parallel iterators.
877///
878/// This function assumes that the input vector has an even number of elements. If
879/// it is given an odd number of elements, the last element will be ignored.
880#[inline]
881fn zip_par_iter_vec<I: IndexedParallelIterator>(
882    in_vec: Vec<I>,
883) -> Vec<impl IndexedParallelIterator<Item = (I::Item, I::Item)>> {
884    in_vec
885        .into_iter()
886        .tuples()
887        .map(|(hi, lo)| hi.zip(lo))
888        .collect::<Vec<_>>()
889}
890
/// A butterfly strategy that applies two or three consecutive layers of the
/// butterfly network to a decomposed block in a single call.
///
/// Implementors decide the order in which the layers run. The DIT and DIF
/// implementations in this file use opposite orders.
trait MultiLayerButterfly<F: Field, B: Butterfly<F>>: Copy + Send + Sync {
    /// Applies two butterfly layers to the four sub-blocks of `chunk_decomposition`.
    ///
    /// `ind` selects the twiddles. The implementations in this file use:
    /// - `twiddles_small[ind]` for the layer pairing the two halves;
    /// - `twiddles_large[2 * ind]` and `twiddles_large[2 * ind + 1]` for the layer
    ///   within each half.
    ///
    /// When `ind == 0`, `TwiddleFreeButterfly` replaces the index-0 entries.
    fn apply_2_layers(
        &self,
        chunk_decomposition: DoubleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[B],
        twiddles_large: &[B],
    );

    /// Applies three butterfly layers to the eight sub-blocks of `chunk_decomposition`.
    ///
    /// `ind` selects the twiddles. The implementations in this file use:
    /// - `twiddles_small[ind]` for the layer pairing sub-blocks four apart;
    /// - `twiddles_med[2 * ind]` and `twiddles_med[2 * ind + 1]` for sub-blocks two apart;
    /// - `twiddles_large[4 * ind .. 4 * ind + 4]` for adjacent sub-blocks.
    ///
    /// When `ind == 0`, `TwiddleFreeButterfly` replaces the index-0 entries.
    fn apply_3_layers(
        &self,
        chunk_decomposition: TripleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[B],
        twiddles_med: &[B],
        twiddles_large: &[B],
    );
}
909
910#[derive(Debug, Clone, Copy)]
911struct MultiLayerDitButterfly;
912
913impl<F: Field> MultiLayerButterfly<F, DitButterfly<F>> for MultiLayerDitButterfly {
914    #[inline]
915    fn apply_2_layers(
916        &self,
917        mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
918        ind: usize,
919        twiddles_small: &[DitButterfly<F>],
920        twiddles_large: &[DitButterfly<F>],
921    ) {
922        if ind == 0 {
923            fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
924            fft_double_layer_double_twiddle(
925                &mut blk_decomp,
926                TwiddleFreeButterfly,
927                twiddles_large[1],
928            );
929        } else {
930            fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
931            fft_double_layer_double_twiddle(
932                &mut blk_decomp,
933                twiddles_large[2 * ind],
934                twiddles_large[2 * ind + 1],
935            );
936        }
937    }
938
939    #[inline]
940    fn apply_3_layers(
941        &self,
942        mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
943        ind: usize,
944        twiddles_small: &[DitButterfly<F>],
945        twiddles_med: &[DitButterfly<F>],
946        twiddles_large: &[DitButterfly<F>],
947    ) {
948        if ind == 0 {
949            fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
950            fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
951            fft_triple_layer_quad_twiddle(
952                &mut blk_decomp,
953                TwiddleFreeButterfly,
954                &twiddles_large[1..4],
955            );
956        } else {
957            fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
958            fft_triple_layer_double_twiddle(
959                &mut blk_decomp,
960                twiddles_med[2 * ind],
961                twiddles_med[2 * ind + 1],
962            );
963            fft_triple_layer_quad_twiddle(
964                &mut blk_decomp,
965                twiddles_large[4 * ind],
966                &twiddles_large[4 * ind + 1..4 * (ind + 1)],
967            );
968        }
969    }
970}
971
972#[derive(Debug, Clone, Copy)]
973struct MultiLayerDifButterfly;
974
975impl<F: Field> MultiLayerButterfly<F, DifButterfly<F>> for MultiLayerDifButterfly {
976    #[inline]
977    fn apply_2_layers(
978        &self,
979        mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
980        ind: usize,
981        twiddles_small: &[DifButterfly<F>],
982        twiddles_large: &[DifButterfly<F>],
983    ) {
984        if ind == 0 {
985            fft_double_layer_double_twiddle(
986                &mut blk_decomp,
987                TwiddleFreeButterfly,
988                twiddles_large[1],
989            );
990            fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
991        } else {
992            fft_double_layer_double_twiddle(
993                &mut blk_decomp,
994                twiddles_large[2 * ind],
995                twiddles_large[2 * ind + 1],
996            );
997            fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
998        }
999    }
1000
1001    #[inline]
1002    fn apply_3_layers(
1003        &self,
1004        mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
1005        ind: usize,
1006        twiddles_small: &[DifButterfly<F>],
1007        twiddles_med: &[DifButterfly<F>],
1008        twiddles_large: &[DifButterfly<F>],
1009    ) {
1010        if ind == 0 {
1011            fft_triple_layer_quad_twiddle(
1012                &mut blk_decomp,
1013                TwiddleFreeButterfly,
1014                &twiddles_large[1..4],
1015            );
1016            fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
1017            fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
1018        } else {
1019            fft_triple_layer_quad_twiddle(
1020                &mut blk_decomp,
1021                twiddles_large[4 * ind],
1022                &twiddles_large[4 * ind + 1..4 * (ind + 1)],
1023            );
1024            fft_triple_layer_double_twiddle(
1025                &mut blk_decomp,
1026                twiddles_med[2 * ind],
1027                twiddles_med[2 * ind + 1],
1028            );
1029            fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
1030        }
1031    }
1032}