1use alloc::vec::Vec;
2
3use p3_field::{Field, Powers, TwoAdicField};
4use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
5use p3_matrix::util::reverse_matrix_index_bits;
6use p3_matrix::Matrix;
7use p3_maybe_rayon::prelude::*;
8use p3_util::{log2_strict_usize, reverse_bits, reverse_slice_index_bits};
9
10use crate::butterflies::{Butterfly, DifButterfly, DitButterfly, TwiddleFreeButterfly};
11use crate::util::{bit_reversed_zero_pad, divide_by_height};
12use crate::TwoAdicSubgroupDft;
13
/// A radix-2 DFT over two-adic subgroups, built on the Bowers G / G^T butterfly networks.
///
/// The forward transform runs the G network on bit-reversed input. The inverse transform runs the
/// transposed G^T network and produces bit-reversed output. This lets both directions index their
/// twiddle tables by block number alone.
///
/// The type carries no state, so it is freely `Copy`.
#[derive(Default, Clone, Copy, Debug)]
pub struct Radix2Bowers;
19impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Bowers {
20 type Evaluations = RowMajorMatrix<F>;
21
22 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
23 reverse_matrix_index_bits(&mut mat);
24 bowers_g(&mut mat.as_view_mut());
25 mat
26 }
27
28 fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
30 bowers_g_t(&mut mat.as_view_mut());
31 divide_by_height(&mut mat);
32 reverse_matrix_index_bits(&mut mat);
33 mat
34 }
35
36 fn lde_batch(&self, mut mat: RowMajorMatrix<F>, added_bits: usize) -> RowMajorMatrix<F> {
37 bowers_g_t(&mut mat.as_view_mut());
38 divide_by_height(&mut mat);
39 bit_reversed_zero_pad(&mut mat, added_bits);
40 bowers_g(&mut mat.as_view_mut());
41 mat
42 }
43
44 fn coset_lde_batch(
45 &self,
46 mut mat: RowMajorMatrix<F>,
47 added_bits: usize,
48 shift: F,
49 ) -> RowMajorMatrix<F> {
50 let h = mat.height();
51 let h_inv = F::from_canonical_usize(h).inverse();
52
53 bowers_g_t(&mut mat.as_view_mut());
54
55 let weights = Powers {
59 base: shift,
60 current: h_inv,
61 }
62 .take(h);
63 for (row, weight) in weights.enumerate() {
64 mat.scale_row(reverse_bits(row, h), weight);
66 }
67
68 bit_reversed_zero_pad(&mut mat, added_bits);
69
70 bowers_g(&mut mat.as_view_mut());
71
72 mat
73 }
74}
75
76fn bowers_g<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<F>) {
79 let h = mat.height();
80 let log_h = log2_strict_usize(h);
81
82 let root = F::two_adic_generator(log_h);
83 let mut twiddles: Vec<_> = root.powers().take(h / 2).map(DifButterfly).collect();
84 reverse_slice_index_bits(&mut twiddles);
85
86 let log_h = log2_strict_usize(mat.height());
87 for log_half_block_size in 0..log_h {
88 butterfly_layer(mat, 1 << log_half_block_size, &twiddles)
89 }
90}
91
92fn bowers_g_t<F: TwoAdicField>(mat: &mut RowMajorMatrixViewMut<F>) {
95 let h = mat.height();
96 let log_h = log2_strict_usize(h);
97
98 let root_inv = F::two_adic_generator(log_h).inverse();
99 let mut twiddles: Vec<_> = root_inv.powers().take(h / 2).map(DitButterfly).collect();
100 reverse_slice_index_bits(&mut twiddles);
101
102 let log_h = log2_strict_usize(mat.height());
103 for log_half_block_size in (0..log_h).rev() {
104 butterfly_layer(mat, 1 << log_half_block_size, &twiddles)
105 }
106}
107
108fn butterfly_layer<F: Field, B: Butterfly<F>>(
109 mat: &mut RowMajorMatrixViewMut<F>,
110 half_block_size: usize,
111 twiddles: &[B],
112) {
113 mat.par_row_chunks_exact_mut(2 * half_block_size)
114 .enumerate()
115 .for_each(|(block, mut chunks)| {
116 let (mut hi_chunks, mut lo_chunks) = chunks.split_rows_mut(half_block_size);
117 hi_chunks
118 .par_rows_mut()
119 .zip(lo_chunks.par_rows_mut())
120 .for_each(|(hi_chunk, lo_chunk)| {
121 if block == 0 {
122 TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk)
123 } else {
124 twiddles[block].apply_to_rows(hi_chunk, lo_chunk);
125 }
126 });
127 });
128}
129
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_goldilocks::Goldilocks;

    use crate::radix_2_bowers::Radix2Bowers;
    use crate::testing::*;

    // Every check runs over both BabyBear and Goldilocks, so each entry point is exercised on
    // two different fields rather than whichever one happened to be chosen per test.

    #[test]
    fn dft_matches_naive() {
        test_dft_matches_naive::<BabyBear, Radix2Bowers>();
        test_dft_matches_naive::<Goldilocks, Radix2Bowers>();
    }

    #[test]
    fn coset_dft_matches_naive() {
        test_coset_dft_matches_naive::<BabyBear, Radix2Bowers>();
        test_coset_dft_matches_naive::<Goldilocks, Radix2Bowers>();
    }

    #[test]
    fn idft_matches_naive() {
        test_idft_matches_naive::<BabyBear, Radix2Bowers>();
        test_idft_matches_naive::<Goldilocks, Radix2Bowers>();
    }

    #[test]
    fn coset_idft_matches_naive() {
        test_coset_idft_matches_naive::<BabyBear, Radix2Bowers>();
        test_coset_idft_matches_naive::<Goldilocks, Radix2Bowers>();
    }

    #[test]
    fn lde_matches_naive() {
        test_lde_matches_naive::<BabyBear, Radix2Bowers>();
        test_lde_matches_naive::<Goldilocks, Radix2Bowers>();
    }

    #[test]
    fn coset_lde_matches_naive() {
        test_coset_lde_matches_naive::<BabyBear, Radix2Bowers>();
        test_coset_lde_matches_naive::<Goldilocks, Radix2Bowers>();
    }

    #[test]
    fn dft_idft_consistency() {
        test_dft_idft_consistency::<BabyBear, Radix2Bowers>();
        test_dft_idft_consistency::<Goldilocks, Radix2Bowers>();
    }
}