p3_dft/
radix_2_dit_parallel.rs

1use alloc::vec::Vec;
2
3use itertools::izip;
4use p3_field::{Field, Powers, TwoAdicField};
5use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView};
6use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
7use p3_matrix::util::reverse_matrix_index_bits;
8use p3_matrix::Matrix;
9use p3_maybe_rayon::prelude::*;
10use p3_util::{log2_strict_usize, reverse_bits, reverse_slice_index_bits};
11use tracing::instrument;
12
13use crate::butterflies::{Butterfly, DitButterfly};
14use crate::TwoAdicSubgroupDft;
15
/// A parallel radix-2 FFT that splits the layers of a butterfly network into two halves.
///
/// The first half is run as a network whose earlier layers use smaller blocks
/// (i.e. DIT or Bowers G). The data is then bit-reversed and the remaining layers of the
/// same network are executed in bit-reversed order. Both halves therefore only ever touch
/// small blocks, which lets each half be processed in parallel without any cross-thread
/// communication.
#[derive(Clone, Debug, Default)]
pub struct Radix2DitParallel;
25
26impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2DitParallel {
27    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;
28
29    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
30        let h = mat.height();
31        let log_h = log2_strict_usize(h);
32
33        let root = F::two_adic_generator(log_h);
34        let mut twiddles: Vec<F> = root.powers().take(h / 2).collect();
35
36        let mid = log_h / 2;
37
38        // The first half looks like a normal DIT.
39        reverse_matrix_index_bits(&mut mat);
40        par_dit_layer(&mut mat, mid, &twiddles);
41
42        // For the second half, we flip the DIT, working in bit-reversed order.
43        reverse_matrix_index_bits(&mut mat);
44        reverse_slice_index_bits(&mut twiddles);
45        par_dit_layer_rev(&mut mat, mid, &twiddles);
46
47        mat.bit_reverse_rows()
48    }
49
50    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
51    fn coset_lde_batch(
52        &self,
53        mut mat: RowMajorMatrix<F>,
54        added_bits: usize,
55        shift: F,
56    ) -> Self::Evaluations {
57        let h = mat.height();
58        let log_h = log2_strict_usize(h);
59        let mid = log_h / 2;
60        let h_inv = F::from_canonical_usize(h).inverse();
61
62        let root = F::two_adic_generator(log_h);
63        let root_inv = root.inverse();
64
65        let mut twiddles_inv: Vec<F> = root_inv.powers().take(h / 2).collect();
66
67        // The first half looks like a normal DIT.
68        reverse_matrix_index_bits(&mut mat);
69        par_dit_layer(&mut mat, mid, &twiddles_inv);
70
71        // For the second half, we flip the DIT, working in bit-reversed order.
72        reverse_matrix_index_bits(&mut mat);
73        reverse_slice_index_bits(&mut twiddles_inv);
74        par_dit_layer_rev(&mut mat, mid, &twiddles_inv);
75        // We skip the final bit-reversal, since the next FFT expects bit-reversed input.
76
77        // Rescale coefficients in two ways:
78        // - divide by height (since we're doing an inverse DFT)
79        // - multiply by powers of the coset shift (see default coset LDE impl for an explanation)
80        let weights = Powers {
81            base: shift,
82            current: h_inv,
83        }
84        .take(h);
85        for (row, weight) in weights.enumerate() {
86            // reverse_bits because mat is encoded in bit-reversed order
87            mat.scale_row(reverse_bits(row, h), weight);
88        }
89
90        mat = mat.bit_reversed_zero_pad(added_bits);
91
92        let h = mat.height();
93        let log_h = log2_strict_usize(h);
94        let mid = log_h / 2;
95
96        let root = F::two_adic_generator(log_h);
97
98        let mut twiddles: Vec<F> = root.powers().take(h / 2).collect();
99
100        // The first half looks like a normal DIT.
101        par_dit_layer(&mut mat, mid, &twiddles);
102
103        // For the second half, we flip the DIT, working in bit-reversed order.
104        reverse_matrix_index_bits(&mut mat);
105        reverse_slice_index_bits(&mut twiddles);
106        par_dit_layer_rev(&mut mat, mid, &twiddles);
107
108        mat.bit_reverse_rows()
109    }
110}
111
112/// This can be used as the first half of a parallelized butterfly network.
113#[instrument(level = "debug", skip_all)]
114fn par_dit_layer<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
115    let log_h = log2_strict_usize(mat.height());
116
117    // max block size: 2^mid
118    mat.par_row_chunks_exact_mut(1 << mid)
119        .for_each(|mut submat| {
120            for layer in 0..mid {
121                dit_layer(&mut submat, log_h, layer, twiddles);
122            }
123        });
124}
125
126/// This can be used as the second half of a parallelized butterfly network.
127#[instrument(level = "debug", skip_all)]
128fn par_dit_layer_rev<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles_rev: &[F]) {
129    let log_h = log2_strict_usize(mat.height());
130
131    // max block size: 2^(log_h - mid)
132    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
133        .enumerate()
134        .for_each(|(thread, mut submat)| {
135            for layer in mid..log_h {
136                let first_block = thread << (layer - mid);
137                dit_layer_rev(&mut submat, log_h, layer, &twiddles_rev[first_block..]);
138            }
139        });
140}
141
142/// One layer of a DIT butterfly network.
143fn dit_layer<F: Field>(
144    submat: &mut RowMajorMatrixViewMut<'_, F>,
145    log_h: usize,
146    layer: usize,
147    twiddles: &[F],
148) {
149    let layer_rev = log_h - 1 - layer;
150    let layer_pow = 1 << layer_rev;
151
152    let half_block_size = 1 << layer;
153    let block_size = half_block_size * 2;
154    let width = submat.width();
155    debug_assert!(submat.height() >= block_size);
156
157    for block in submat.values.chunks_mut(block_size * width) {
158        let (lows, highs) = block.split_at_mut(half_block_size * width);
159
160        for (lo, hi, &twiddle) in izip!(
161            lows.chunks_mut(width),
162            highs.chunks_mut(width),
163            twiddles.iter().step_by(layer_pow)
164        ) {
165            DitButterfly(twiddle).apply_to_rows(lo, hi);
166        }
167    }
168}
169
170/// Like `dit_layer`, except the matrix and twiddles are encoded in bit-reversed order.
171/// This can also be viewed as a layer of the Bowers G^T network.
172fn dit_layer_rev<F: Field>(
173    submat: &mut RowMajorMatrixViewMut<'_, F>,
174    log_h: usize,
175    layer: usize,
176    twiddles_rev: &[F],
177) {
178    let layer_rev = log_h - 1 - layer;
179
180    let half_block_size = 1 << layer_rev;
181    let block_size = half_block_size * 2;
182    let width = submat.width();
183    debug_assert!(submat.height() >= block_size);
184
185    for (block, &twiddle) in submat
186        .values
187        .chunks_mut(block_size * width)
188        .zip(twiddles_rev)
189    {
190        let (lo, hi) = block.split_at_mut(half_block_size * width);
191        DitButterfly(twiddle).apply_to_rows(lo, hi)
192    }
193}
194
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_goldilocks::Goldilocks;

    use crate::testing::*;
    use crate::Radix2DitParallel;

    /// The implementation exercised by every test in this module.
    type Dft = Radix2DitParallel;

    /// Runs the shared `test_dft_matches_naive` check over BabyBear.
    #[test]
    fn dft_matches_naive() {
        test_dft_matches_naive::<BabyBear, Dft>();
    }

    /// Runs the shared `test_coset_dft_matches_naive` check over BabyBear.
    #[test]
    fn coset_dft_matches_naive() {
        test_coset_dft_matches_naive::<BabyBear, Dft>();
    }

    /// Runs the shared `test_idft_matches_naive` check over Goldilocks.
    #[test]
    fn idft_matches_naive() {
        test_idft_matches_naive::<Goldilocks, Dft>();
    }

    /// Runs the shared `test_coset_idft_matches_naive` check over BabyBear, then Goldilocks.
    #[test]
    fn coset_idft_matches_naive() {
        test_coset_idft_matches_naive::<BabyBear, Dft>();
        test_coset_idft_matches_naive::<Goldilocks, Dft>();
    }

    /// Runs the shared `test_lde_matches_naive` check over BabyBear.
    #[test]
    fn lde_matches_naive() {
        test_lde_matches_naive::<BabyBear, Dft>();
    }

    /// Runs the shared `test_coset_lde_matches_naive` check over BabyBear.
    #[test]
    fn coset_lde_matches_naive() {
        test_coset_lde_matches_naive::<BabyBear, Dft>();
    }

    /// Runs the shared `test_dft_idft_consistency` check over BabyBear.
    #[test]
    fn dft_idft_consistency() {
        test_dft_idft_consistency::<BabyBear, Dft>();
    }
}
238}