p3_dft/
radix_2_dit_parallel.rs1use alloc::vec::Vec;
2
3use itertools::izip;
4use p3_field::{Field, Powers, TwoAdicField};
5use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView};
6use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
7use p3_matrix::util::reverse_matrix_index_bits;
8use p3_matrix::Matrix;
9use p3_maybe_rayon::prelude::*;
10use p3_util::{log2_strict_usize, reverse_bits, reverse_slice_index_bits};
11use tracing::instrument;
12
13use crate::butterflies::{Butterfly, DitButterfly};
14use crate::TwoAdicSubgroupDft;
15
/// Stateless radix-2 decimation-in-time DFT.
///
/// Its `TwoAdicSubgroupDft` implementation splits the butterfly layers of the
/// network into two halves and runs each half over independent row chunks in
/// parallel.
#[derive(Clone, Debug, Default)]
pub struct Radix2DitParallel;
25
26impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2DitParallel {
27 type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;
28
29 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
30 let h = mat.height();
31 let log_h = log2_strict_usize(h);
32
33 let root = F::two_adic_generator(log_h);
34 let mut twiddles: Vec<F> = root.powers().take(h / 2).collect();
35
36 let mid = log_h / 2;
37
38 reverse_matrix_index_bits(&mut mat);
40 par_dit_layer(&mut mat, mid, &twiddles);
41
42 reverse_matrix_index_bits(&mut mat);
44 reverse_slice_index_bits(&mut twiddles);
45 par_dit_layer_rev(&mut mat, mid, &twiddles);
46
47 mat.bit_reverse_rows()
48 }
49
50 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
51 fn coset_lde_batch(
52 &self,
53 mut mat: RowMajorMatrix<F>,
54 added_bits: usize,
55 shift: F,
56 ) -> Self::Evaluations {
57 let h = mat.height();
58 let log_h = log2_strict_usize(h);
59 let mid = log_h / 2;
60 let h_inv = F::from_canonical_usize(h).inverse();
61
62 let root = F::two_adic_generator(log_h);
63 let root_inv = root.inverse();
64
65 let mut twiddles_inv: Vec<F> = root_inv.powers().take(h / 2).collect();
66
67 reverse_matrix_index_bits(&mut mat);
69 par_dit_layer(&mut mat, mid, &twiddles_inv);
70
71 reverse_matrix_index_bits(&mut mat);
73 reverse_slice_index_bits(&mut twiddles_inv);
74 par_dit_layer_rev(&mut mat, mid, &twiddles_inv);
75 let weights = Powers {
81 base: shift,
82 current: h_inv,
83 }
84 .take(h);
85 for (row, weight) in weights.enumerate() {
86 mat.scale_row(reverse_bits(row, h), weight);
88 }
89
90 mat = mat.bit_reversed_zero_pad(added_bits);
91
92 let h = mat.height();
93 let log_h = log2_strict_usize(h);
94 let mid = log_h / 2;
95
96 let root = F::two_adic_generator(log_h);
97
98 let mut twiddles: Vec<F> = root.powers().take(h / 2).collect();
99
100 par_dit_layer(&mut mat, mid, &twiddles);
102
103 reverse_matrix_index_bits(&mut mat);
105 reverse_slice_index_bits(&mut twiddles);
106 par_dit_layer_rev(&mut mat, mid, &twiddles);
107
108 mat.bit_reverse_rows()
109 }
110}
111
112#[instrument(level = "debug", skip_all)]
114fn par_dit_layer<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
115 let log_h = log2_strict_usize(mat.height());
116
117 mat.par_row_chunks_exact_mut(1 << mid)
119 .for_each(|mut submat| {
120 for layer in 0..mid {
121 dit_layer(&mut submat, log_h, layer, twiddles);
122 }
123 });
124}
125
126#[instrument(level = "debug", skip_all)]
128fn par_dit_layer_rev<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles_rev: &[F]) {
129 let log_h = log2_strict_usize(mat.height());
130
131 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
133 .enumerate()
134 .for_each(|(thread, mut submat)| {
135 for layer in mid..log_h {
136 let first_block = thread << (layer - mid);
137 dit_layer_rev(&mut submat, log_h, layer, &twiddles_rev[first_block..]);
138 }
139 });
140}
141
142fn dit_layer<F: Field>(
144 submat: &mut RowMajorMatrixViewMut<'_, F>,
145 log_h: usize,
146 layer: usize,
147 twiddles: &[F],
148) {
149 let layer_rev = log_h - 1 - layer;
150 let layer_pow = 1 << layer_rev;
151
152 let half_block_size = 1 << layer;
153 let block_size = half_block_size * 2;
154 let width = submat.width();
155 debug_assert!(submat.height() >= block_size);
156
157 for block in submat.values.chunks_mut(block_size * width) {
158 let (lows, highs) = block.split_at_mut(half_block_size * width);
159
160 for (lo, hi, &twiddle) in izip!(
161 lows.chunks_mut(width),
162 highs.chunks_mut(width),
163 twiddles.iter().step_by(layer_pow)
164 ) {
165 DitButterfly(twiddle).apply_to_rows(lo, hi);
166 }
167 }
168}
169
170fn dit_layer_rev<F: Field>(
173 submat: &mut RowMajorMatrixViewMut<'_, F>,
174 log_h: usize,
175 layer: usize,
176 twiddles_rev: &[F],
177) {
178 let layer_rev = log_h - 1 - layer;
179
180 let half_block_size = 1 << layer_rev;
181 let block_size = half_block_size * 2;
182 let width = submat.width();
183 debug_assert!(submat.height() >= block_size);
184
185 for (block, &twiddle) in submat
186 .values
187 .chunks_mut(block_size * width)
188 .zip(twiddles_rev)
189 {
190 let (lo, hi) = block.split_at_mut(half_block_size * width);
191 DitButterfly(twiddle).apply_to_rows(lo, hi)
192 }
193}
194
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_goldilocks::Goldilocks;

    use crate::testing::*;
    use crate::Radix2DitParallel;

    /// The DFT implementation exercised by every test in this module.
    type Dft = Radix2DitParallel;

    /// Forward DFT, checked with the shared `crate::testing` helper.
    #[test]
    fn dft_matches_naive() {
        test_dft_matches_naive::<BabyBear, Dft>();
    }

    /// Forward DFT over a coset.
    #[test]
    fn coset_dft_matches_naive() {
        test_coset_dft_matches_naive::<BabyBear, Dft>();
    }

    /// Inverse DFT (Goldilocks only).
    #[test]
    fn idft_matches_naive() {
        test_idft_matches_naive::<Goldilocks, Dft>();
    }

    /// Inverse DFT over a coset, for both fields.
    #[test]
    fn coset_idft_matches_naive() {
        test_coset_idft_matches_naive::<BabyBear, Dft>();
        test_coset_idft_matches_naive::<Goldilocks, Dft>();
    }

    /// Low-degree extension.
    #[test]
    fn lde_matches_naive() {
        test_lde_matches_naive::<BabyBear, Dft>();
    }

    /// Low-degree extension onto a coset.
    #[test]
    fn coset_lde_matches_naive() {
        test_coset_lde_matches_naive::<BabyBear, Dft>();
    }

    /// DFT/inverse-DFT consistency helper from `crate::testing`.
    #[test]
    fn dft_idft_consistency() {
        test_dft_idft_consistency::<BabyBear, Dft>();
    }
}