p3_dft/
radix_2_bowers.rs

1use alloc::vec::Vec;
2
3use p3_field::{Field, Powers, TwoAdicField};
4use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
5use p3_matrix::util::reverse_matrix_index_bits;
6use p3_matrix::Matrix;
7use p3_maybe_rayon::prelude::*;
8use p3_util::{log2_strict_usize, reverse_bits, reverse_slice_index_bits};
9
10use crate::butterflies::{Butterfly, DifButterfly, DitButterfly, TwiddleFreeButterfly};
11use crate::util::{bit_reversed_zero_pad, divide_by_height};
12use crate::TwoAdicSubgroupDft;
13
/// The Bowers G FFT algorithm.
/// See: "Improved Twiddle Access for Fast Fourier Transforms"
///
/// This is a stateless unit struct: it carries no configuration or precomputed data, so
/// constructing, defaulting or cloning it is free. `Debug` is derived so the type can be
/// printed and embedded in other `#[derive(Debug)]` types.
#[derive(Default, Clone, Debug)]
pub struct Radix2Bowers;
18
19impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Bowers {
20    type Evaluations = RowMajorMatrix<F>;
21
22    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
23        reverse_matrix_index_bits(&mut mat);
24        bowers_g(&mut mat.as_view_mut());
25        mat
26    }
27
28    /// Compute the inverse DFT of each column in `mat`.
29    fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
30        bowers_g_t(&mut mat.as_view_mut());
31        divide_by_height(&mut mat);
32        reverse_matrix_index_bits(&mut mat);
33        mat
34    }
35
36    fn lde_batch(&self, mut mat: RowMajorMatrix<F>, added_bits: usize) -> RowMajorMatrix<F> {
37        bowers_g_t(&mut mat.as_view_mut());
38        divide_by_height(&mut mat);
39        bit_reversed_zero_pad(&mut mat, added_bits);
40        bowers_g(&mut mat.as_view_mut());
41        mat
42    }
43
44    fn coset_lde_batch(
45        &self,
46        mut mat: RowMajorMatrix<F>,
47        added_bits: usize,
48        shift: F,
49    ) -> RowMajorMatrix<F> {
50        let h = mat.height();
51        let h_inv = F::from_canonical_usize(h).inverse();
52
53        bowers_g_t(&mut mat.as_view_mut());
54
55        // Rescale coefficients in two ways:
56        // - divide by height (since we're doing an inverse DFT)
57        // - multiply by powers of the coset shift (see default coset LDE impl for an explanation)
58        let weights = Powers {
59            base: shift,
60            current: h_inv,
61        }
62        .take(h);
63        for (row, weight) in weights.enumerate() {
64            // reverse_bits because mat is encoded in bit-reversed order
65            mat.scale_row(reverse_bits(row, h), weight);
66        }
67
68        bit_reversed_zero_pad(&mut mat, added_bits);
69
70        bowers_g(&mut mat.as_view_mut());
71
72        mat
73    }
74}
75
76/// Executes the Bowers G network. This is like a DFT, except it assumes the input is in
77/// bit-reversed order.
78fn bowers_g<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<F>) {
79    let h = mat.height();
80    let log_h = log2_strict_usize(h);
81
82    let root = F::two_adic_generator(log_h);
83    let mut twiddles: Vec<_> = root.powers().take(h / 2).map(DifButterfly).collect();
84    reverse_slice_index_bits(&mut twiddles);
85
86    let log_h = log2_strict_usize(mat.height());
87    for log_half_block_size in 0..log_h {
88        butterfly_layer(mat, 1 << log_half_block_size, &twiddles)
89    }
90}
91
92/// Executes the Bowers G^T network. This is like an inverse DFT, except we skip rescaling by
93/// 1/height, and the output is bit-reversed.
94fn bowers_g_t<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<F>) {
95    let h = mat.height();
96    let log_h = log2_strict_usize(h);
97
98    let root_inv = F::two_adic_generator(log_h).inverse();
99    let mut twiddles: Vec<_> = root_inv.powers().take(h / 2).map(DitButterfly).collect();
100    reverse_slice_index_bits(&mut twiddles);
101
102    let log_h = log2_strict_usize(mat.height());
103    for log_half_block_size in (0..log_h).rev() {
104        butterfly_layer(mat, 1 << log_half_block_size, &twiddles)
105    }
106}
107
108fn butterfly_layer<F: Field, B: Butterfly<F>>(
109    mat: &mut RowMajorMatrixViewMut<F>,
110    half_block_size: usize,
111    twiddles: &[B],
112) {
113    mat.par_row_chunks_exact_mut(2 * half_block_size)
114        .enumerate()
115        .for_each(|(block, mut chunks)| {
116            let (mut hi_chunks, mut lo_chunks) = chunks.split_rows_mut(half_block_size);
117            hi_chunks
118                .par_rows_mut()
119                .zip(lo_chunks.par_rows_mut())
120                .for_each(|(hi_chunk, lo_chunk)| {
121                    if block == 0 {
122                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk)
123                    } else {
124                        twiddles[block].apply_to_rows(hi_chunk, lo_chunk);
125                    }
126                });
127        });
128}
129
/// Runs the shared DFT test suite from `crate::testing` against `Radix2Bowers`.
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_goldilocks::Goldilocks;

    use super::Radix2Bowers;
    use crate::testing::*;

    /// Forward DFT, checked over BabyBear.
    #[test]
    fn dft_matches_naive() {
        test_dft_matches_naive::<BabyBear, Radix2Bowers>();
    }

    /// Forward DFT over a coset, checked over BabyBear.
    #[test]
    fn coset_dft_matches_naive() {
        test_coset_dft_matches_naive::<BabyBear, Radix2Bowers>();
    }

    /// Inverse DFT, checked over Goldilocks.
    #[test]
    fn idft_matches_naive() {
        test_idft_matches_naive::<Goldilocks, Radix2Bowers>();
    }

    /// Inverse DFT over a coset, checked over both BabyBear and Goldilocks.
    #[test]
    fn coset_idft_matches_naive() {
        test_coset_idft_matches_naive::<BabyBear, Radix2Bowers>();
        test_coset_idft_matches_naive::<Goldilocks, Radix2Bowers>();
    }

    /// Low-degree extension, checked over BabyBear.
    #[test]
    fn lde_matches_naive() {
        test_lde_matches_naive::<BabyBear, Radix2Bowers>();
    }

    /// Low-degree extension onto a coset, checked over BabyBear.
    #[test]
    fn coset_lde_matches_naive() {
        test_coset_lde_matches_naive::<BabyBear, Radix2Bowers>();
    }

    /// DFT / inverse DFT consistency, checked over BabyBear.
    #[test]
    fn dft_idft_consistency() {
        test_dft_idft_consistency::<BabyBear, Radix2Bowers>();
    }
}