p3_dft/
radix_2_dit_parallel.rs

1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::vec::Vec;
4use core::cell::RefCell;
5use core::mem::{MaybeUninit, transmute};
6
7use itertools::{Itertools, izip};
8use p3_field::integers::QuotientMap;
9use p3_field::{Field, Powers, TwoAdicField};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversalPerm, BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
13use p3_matrix::util::reverse_matrix_index_bits;
14use p3_maybe_rayon::prelude::*;
15use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
16use tracing::{debug_span, instrument};
17
18use crate::TwoAdicSubgroupDft;
19use crate::butterflies::{Butterfly, DitButterfly};
20
21/// A parallel FFT algorithm which divides a butterfly network's layers into two halves.
22///
23/// For the first half, we apply a butterfly network with smaller blocks in earlier layers,
24/// i.e. either DIT or Bowers G. Then we bit-reverse, and for the second half, we continue executing
25/// the same network but in bit-reversed order. This way we're always working with small blocks,
26/// so within each half, we can have a certain amount of parallelism with no cross-thread
27/// communication.
28#[derive(Default, Clone, Debug)]
29pub struct Radix2DitParallel<F> {
30    /// Twiddles based on roots of unity, used in the forward DFT.
31    twiddles: RefCell<BTreeMap<usize, VectorPair<F>>>,
32
33    /// A map from `(log_h, shift)` to forward DFT twiddles with that coset shift baked in.
34    #[allow(clippy::type_complexity)]
35    coset_twiddles: RefCell<BTreeMap<(usize, F), Vec<Vec<F>>>>,
36
37    /// Twiddles based on inverse roots of unity, used in the inverse DFT.
38    inverse_twiddles: RefCell<BTreeMap<usize, VectorPair<F>>>,
39}
40
/// The same twiddle factors stored twice: once in natural order and once with their indices
/// bit-reversed.
#[derive(Clone, Debug, Default)]
struct VectorPair<F> {
    /// Twiddles in natural (increasing-exponent) order.
    twiddles: Vec<F>,
    /// The same twiddles, permuted into bit-reversed index order.
    bitrev_twiddles: Vec<F>,
}
47
48#[instrument(level = "debug", skip_all)]
49fn compute_twiddles<F: TwoAdicField + Ord>(log_h: usize) -> VectorPair<F> {
50    let half_h = (1 << log_h) >> 1;
51    let root = F::two_adic_generator(log_h);
52    let twiddles: Vec<F> = root.powers().take(half_h).collect();
53    let mut bit_reversed_twiddles = twiddles.clone();
54    reverse_slice_index_bits(&mut bit_reversed_twiddles);
55    VectorPair {
56        twiddles,
57        bitrev_twiddles: bit_reversed_twiddles,
58    }
59}
60
61#[instrument(level = "debug", skip_all)]
62fn compute_coset_twiddles<F: TwoAdicField + Ord>(log_h: usize, shift: F) -> Vec<Vec<F>> {
63    // In general either div_floor or div_ceil would work, but here we prefer div_ceil because it
64    // lets us assume below that the "first half" of the network has at least one layer of
65    // butterflies, even in the case of log_h = 1.
66    let mid = log_h.div_ceil(2);
67    let h = 1 << log_h;
68    let root = F::two_adic_generator(log_h);
69
70    (0..log_h)
71        .map(|layer| {
72            let shift_power = shift.exp_power_of_2(layer);
73            let powers = Powers {
74                base: root.exp_power_of_2(layer),
75                current: shift_power,
76            };
77            let mut twiddles: Vec<_> = powers.take(h >> (layer + 1)).collect();
78            let layer_rev = log_h - 1 - layer;
79            if layer_rev >= mid {
80                reverse_slice_index_bits(&mut twiddles);
81            }
82            twiddles
83        })
84        .collect()
85}
86
87#[instrument(level = "debug", skip_all)]
88fn compute_inverse_twiddles<F: TwoAdicField + Ord>(log_h: usize) -> VectorPair<F> {
89    let half_h = (1 << log_h) >> 1;
90    let root_inv = F::two_adic_generator(log_h).inverse();
91    let twiddles: Vec<F> = root_inv.powers().take(half_h).collect();
92    let mut bit_reversed_twiddles = twiddles.clone();
93
94    // In the middle of the coset LDE, we're in bit-reversed order.
95    reverse_slice_index_bits(&mut bit_reversed_twiddles);
96
97    VectorPair {
98        twiddles,
99        bitrev_twiddles: bit_reversed_twiddles,
100    }
101}
102
103impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
104    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;
105
106    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
107        let h = mat.height();
108        let log_h = log2_strict_usize(h);
109
110        // Compute twiddle factors, or take memoized ones if already available.
111        let mut twiddles_ref_mut = self.twiddles.borrow_mut();
112        let twiddles = twiddles_ref_mut
113            .entry(log_h)
114            .or_insert_with(|| compute_twiddles(log_h));
115
116        let mid = log_h.div_ceil(2);
117
118        // The first half looks like a normal DIT.
119        reverse_matrix_index_bits(&mut mat);
120        first_half(&mut mat, mid, &twiddles.twiddles);
121
122        // For the second half, we flip the DIT, working in bit-reversed order.
123        reverse_matrix_index_bits(&mut mat);
124        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);
125
126        mat.bit_reverse_rows()
127    }
128
129    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits = added_bits))]
130    fn coset_lde_batch(
131        &self,
132        mut mat: RowMajorMatrix<F>,
133        added_bits: usize,
134        shift: F,
135    ) -> Self::Evaluations {
136        let w = mat.width;
137        let h = mat.height();
138        let log_h = log2_strict_usize(h);
139        let mid = log_h.div_ceil(2);
140
141        let mut inverse_twiddles_ref_mut = self.inverse_twiddles.borrow_mut();
142        let inverse_twiddles = inverse_twiddles_ref_mut
143            .entry(log_h)
144            .or_insert_with(|| compute_inverse_twiddles(log_h));
145
146        // The first half looks like a normal DIT.
147        reverse_matrix_index_bits(&mut mat);
148        first_half(&mut mat, mid, &inverse_twiddles.twiddles);
149
150        // For the second half, we flip the DIT, working in bit-reversed order.
151        reverse_matrix_index_bits(&mut mat);
152        // We'll also scale by 1/h, as per the usual inverse DFT algorithm.
153        // If F isn't a PrimeField, (and is thus an extension field) it's much cheaper to
154        // invert in F::PrimeSubfield.
155        let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
156        let scale = h_inv_subfield.map(F::from_prime_subfield);
157        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
158        // We skip the final bit-reversal, since the next FFT expects bit-reversed input.
159
160        let lde_elems = w * (h << added_bits);
161        let elems_to_add = lde_elems - w * h;
162        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));
163
164        let g_big = F::two_adic_generator(log_h + added_bits);
165
166        let mat_ptr = mat.values.as_mut_ptr();
167        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
168        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
169        let rest_slice: &mut [MaybeUninit<F>] =
170            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
171        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
172        let mut rest_cosets_mat = rest_slice
173            .chunks_exact_mut(w * h)
174            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
175            .collect_vec();
176
177        for coset_idx in 1..(1 << added_bits) {
178            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
179            let coset_idx = reverse_bits_len(coset_idx, added_bits);
180            let dest = &mut rest_cosets_mat[coset_idx - 1]; // - 1 because we removed the first matrix.
181            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
182        }
183
184        // Now run a forward DFT on the very first coset, this time in-place.
185        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);
186
187        // SAFETY: We wrote all values above.
188        unsafe {
189            mat.values.set_len(lde_elems);
190        }
191        BitReversalPerm::new_view(mat)
192    }
193}
194
195#[instrument(level = "debug", skip_all)]
196fn coset_dft<F: TwoAdicField + Ord>(
197    dft: &Radix2DitParallel<F>,
198    mat: &mut RowMajorMatrixViewMut<F>,
199    shift: F,
200) {
201    let log_h = log2_strict_usize(mat.height());
202    let mid = log_h.div_ceil(2);
203
204    let mut twiddles_ref_mut = dft.coset_twiddles.borrow_mut();
205    let twiddles = twiddles_ref_mut
206        .entry((log_h, shift))
207        .or_insert_with(|| compute_coset_twiddles(log_h, shift));
208
209    // The first half looks like a normal DIT.
210    first_half_general(mat, mid, twiddles);
211
212    // For the second half, we flip the DIT, working in bit-reversed order.
213    reverse_matrix_index_bits(mat);
214
215    second_half_general(mat, mid, twiddles);
216}
217
218/// Like `coset_dft`, except out-of-place.
219#[instrument(level = "debug", skip_all)]
220fn coset_dft_oop<F: TwoAdicField + Ord>(
221    dft: &Radix2DitParallel<F>,
222    src: &RowMajorMatrixView<F>,
223    dst_maybe: &mut RowMajorMatrixViewMut<MaybeUninit<F>>,
224    shift: F,
225) {
226    assert_eq!(src.dimensions(), dst_maybe.dimensions());
227
228    let log_h = log2_strict_usize(dst_maybe.height());
229
230    if log_h == 0 {
231        // This is an edge case where first_half_general_oop doesn't work, as it expects there to be
232        // at least one layer in the network, so we just copy instead.
233        let src_maybe = unsafe {
234            transmute::<&RowMajorMatrixView<F>, &RowMajorMatrixView<MaybeUninit<F>>>(src)
235        };
236        dst_maybe.copy_from(src_maybe);
237        return;
238    }
239
240    let mid = log_h.div_ceil(2);
241
242    let mut twiddles_ref_mut = dft.coset_twiddles.borrow_mut();
243    let twiddles = twiddles_ref_mut
244        .entry((log_h, shift))
245        .or_insert_with(|| compute_coset_twiddles(log_h, shift));
246
247    // The first half looks like a normal DIT.
248    first_half_general_oop(src, dst_maybe, mid, twiddles);
249
250    // dst is now initialized.
251    let dst = unsafe {
252        transmute::<&mut RowMajorMatrixViewMut<MaybeUninit<F>>, &mut RowMajorMatrixViewMut<F>>(
253            dst_maybe,
254        )
255    };
256
257    // For the second half, we flip the DIT, working in bit-reversed order.
258    reverse_matrix_index_bits(dst);
259
260    second_half_general(dst, mid, twiddles);
261}
262
263/// This can be used as the first half of a DIT butterfly network.
264#[instrument(level = "debug", skip_all)]
265fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
266    let log_h = log2_strict_usize(mat.height());
267
268    // max block size: 2^mid
269    mat.par_row_chunks_exact_mut(1 << mid)
270        .for_each(|mut submat| {
271            let mut backwards = false;
272            for layer in 0..mid {
273                let layer_rev = log_h - 1 - layer;
274                let layer_pow = 1 << layer_rev;
275                dit_layer(
276                    &mut submat,
277                    layer,
278                    twiddles.iter().copied().step_by(layer_pow),
279                    backwards,
280                );
281                backwards = !backwards;
282            }
283        });
284}
285
286/// Like `first_half`, except supporting different twiddle factors per layer, enabling coset shifts
287/// to be baked into them.
288#[instrument(level = "debug", skip_all)]
289fn first_half_general<F: Field>(
290    mat: &mut RowMajorMatrixViewMut<F>,
291    mid: usize,
292    twiddles: &[Vec<F>],
293) {
294    let log_h = log2_strict_usize(mat.height());
295    mat.par_row_chunks_exact_mut(1 << mid)
296        .for_each(|mut submat| {
297            let mut backwards = false;
298            for layer in 0..mid {
299                let layer_rev = log_h - 1 - layer;
300                dit_layer(
301                    &mut submat,
302                    layer,
303                    twiddles[layer_rev].iter().copied(),
304                    backwards,
305                );
306                backwards = !backwards;
307            }
308        });
309}
310
311/// Like `first_half_general`, except out-of-place.
312///
313/// Assumes there's at least one layer in the network, i.e. `src.height() > 1`.
314/// Undefined behavior otherwise.
315#[instrument(level = "debug", skip_all)]
316fn first_half_general_oop<F: Field>(
317    src: &RowMajorMatrixView<F>,
318    dst_maybe: &mut RowMajorMatrixViewMut<MaybeUninit<F>>,
319    mid: usize,
320    twiddles: &[Vec<F>],
321) {
322    let log_h = log2_strict_usize(src.height());
323    src.par_row_chunks_exact(1 << mid)
324        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
325        .for_each(|(src_submat, mut dst_submat_maybe)| {
326            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());
327
328            // The first layer is special, done out-of-place.
329            // (Recall from the mid definition that there must be at least one layer here.)
330            let layer_rev = log_h - 1;
331            dit_layer_oop(
332                &src_submat,
333                &mut dst_submat_maybe,
334                0,
335                twiddles[layer_rev].iter().copied(),
336            );
337
338            // submat is now initialized.
339            let mut dst_submat = unsafe {
340                transmute::<RowMajorMatrixViewMut<MaybeUninit<F>>, RowMajorMatrixViewMut<F>>(
341                    dst_submat_maybe,
342                )
343            };
344
345            // Subsequent layers.
346            let mut backwards = true;
347            for layer in 1..mid {
348                let layer_rev = log_h - 1 - layer;
349                dit_layer(
350                    &mut dst_submat,
351                    layer,
352                    twiddles[layer_rev].iter().copied(),
353                    backwards,
354                );
355                backwards = !backwards;
356            }
357        });
358}
359
360/// This can be used as the second half of a DIT butterfly network. It works in bit-reversed order.
361///
362/// The optional `scale` parameter is used to scale the matrix by a constant factor. Normally that
363/// would be a separate step, but it's best to merge it into a butterfly network to avoid a
364/// separate pass through main memory.
365#[instrument(level = "debug", skip_all)]
366#[inline(always)] // To avoid branch on scale
367fn second_half<F: Field>(
368    mat: &mut RowMajorMatrix<F>,
369    mid: usize,
370    twiddles_rev: &[F],
371    scale: Option<F>,
372) {
373    let log_h = log2_strict_usize(mat.height());
374
375    // max block size: 2^(log_h - mid)
376    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
377        .enumerate()
378        .for_each(|(thread, mut submat)| {
379            let mut backwards = false;
380            if let Some(scale) = scale {
381                submat.scale(scale);
382            }
383            for layer in mid..log_h {
384                let first_block = thread << (layer - mid);
385                dit_layer_rev(
386                    &mut submat,
387                    log_h,
388                    layer,
389                    twiddles_rev[first_block..].iter().copied(),
390                    backwards,
391                );
392                backwards = !backwards;
393            }
394        });
395}
396
397/// Like `second_half`, except supporting different twiddle factors per layer, enabling coset shifts
398/// to be baked into them.
399#[instrument(level = "debug", skip_all)]
400fn second_half_general<F: Field>(
401    mat: &mut RowMajorMatrixViewMut<F>,
402    mid: usize,
403    twiddles_rev: &[Vec<F>],
404) {
405    let log_h = log2_strict_usize(mat.height());
406    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
407        .enumerate()
408        .for_each(|(thread, mut submat)| {
409            let mut backwards = false;
410            for layer in mid..log_h {
411                let layer_rev = log_h - 1 - layer;
412                let first_block = thread << (layer - mid);
413                dit_layer_rev(
414                    &mut submat,
415                    log_h,
416                    layer,
417                    twiddles_rev[layer_rev][first_block..].iter().copied(),
418                    backwards,
419                );
420                backwards = !backwards;
421            }
422        });
423}
424
425/// One layer of a DIT butterfly network.
426fn dit_layer<F: Field>(
427    submat: &mut RowMajorMatrixViewMut<'_, F>,
428    layer: usize,
429    twiddles: impl Iterator<Item = F> + Clone,
430    backwards: bool,
431) {
432    let half_block_size = 1 << layer;
433    let block_size = half_block_size * 2;
434    let width = submat.width();
435    debug_assert!(submat.height() >= block_size);
436
437    let process_block = |block: &mut [F]| {
438        let (lows, highs) = block.split_at_mut(half_block_size * width);
439
440        for (lo, hi, twiddle) in izip!(
441            lows.chunks_mut(width),
442            highs.chunks_mut(width),
443            twiddles.clone()
444        ) {
445            DitButterfly(twiddle).apply_to_rows(lo, hi);
446        }
447    };
448
449    let blocks = submat.values.chunks_mut(block_size * width);
450    if backwards {
451        for block in blocks.rev() {
452            process_block(block);
453        }
454    } else {
455        for block in blocks {
456            process_block(block);
457        }
458    }
459}
460
461/// One layer of a DIT butterfly network.
462fn dit_layer_oop<F: Field>(
463    src: &RowMajorMatrixView<F>,
464    dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
465    layer: usize,
466    twiddles: impl Iterator<Item = F> + Clone,
467) {
468    debug_assert_eq!(src.dimensions(), dst.dimensions());
469    let half_block_size = 1 << layer;
470    let block_size = half_block_size * 2;
471    let width = dst.width();
472    debug_assert!(dst.height() >= block_size);
473
474    let src_chunks = src.values.chunks(block_size * width);
475    let dst_chunks = dst.values.chunks_mut(block_size * width);
476    for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
477        let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
478        let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
479
480        for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
481            src_lows.chunks(width),
482            dst_lows.chunks_mut(width),
483            src_highs.chunks(width),
484            dst_highs.chunks_mut(width),
485            twiddles.clone()
486        ) {
487            DitButterfly(twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
488        }
489    }
490}
491
492/// Like `dit_layer`, except the matrix and twiddles are encoded in bit-reversed order.
493/// This can also be viewed as a layer of the Bowers G^T network.
494fn dit_layer_rev<F: Field>(
495    submat: &mut RowMajorMatrixViewMut<'_, F>,
496    log_h: usize,
497    layer: usize,
498    twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
499    backwards: bool,
500) {
501    let layer_rev = log_h - 1 - layer;
502
503    let half_block_size = 1 << layer_rev;
504    let block_size = half_block_size * 2;
505    let width = submat.width();
506    debug_assert!(submat.height() >= block_size);
507
508    let blocks_and_twiddles = submat
509        .values
510        .chunks_mut(block_size * width)
511        .zip(twiddles_rev);
512    if backwards {
513        for (block, twiddle) in blocks_and_twiddles.rev() {
514            let (lo, hi) = block.split_at_mut(half_block_size * width);
515            DitButterfly(twiddle).apply_to_rows(lo, hi)
516        }
517    } else {
518        for (block, twiddle) in blocks_and_twiddles {
519            let (lo, hi) = block.split_at_mut(half_block_size * width);
520            DitButterfly(twiddle).apply_to_rows(lo, hi)
521        }
522    }
523}