1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::vec::Vec;
4use core::cell::RefCell;
5use core::mem::{MaybeUninit, transmute};
6
7use itertools::{Itertools, izip};
8use p3_field::integers::QuotientMap;
9use p3_field::{Field, Powers, TwoAdicField};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversalPerm, BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
13use p3_matrix::util::reverse_matrix_index_bits;
14use p3_maybe_rayon::prelude::*;
15use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
16use tracing::{debug_span, instrument};
17
18use crate::TwoAdicSubgroupDft;
19use crate::butterflies::{Butterfly, DitButterfly};
20
/// A parallel radix-2 decimation-in-time FFT over two-adic subgroups.
///
/// For a matrix of height `2^log_h`, the butterfly network is split at
/// `mid = ceil(log_h / 2)`. The first `mid` layers run on independent chunks of `2^mid` rows.
/// The rows are then bit-reversed, and the remaining layers run on independent chunks of
/// `2^(log_h - mid)` rows. In both halves the chunks are processed in parallel.
///
/// Every twiddle table is memoized in a `RefCell`, so one value can be reused across calls.
/// Because `RefCell` is not `Sync`, a value cannot be shared between threads by reference.
#[derive(Default, Clone, Debug)]
pub struct Radix2DitParallel<F> {
    /// Forward-DFT twiddles (natural and bit-reversed order), keyed by `log_h`.
    twiddles: RefCell<BTreeMap<usize, VectorPair<F>>>,

    /// Forward-DFT twiddles with a coset shift folded in, keyed by `(log_h, shift)`; the value
    /// holds one twiddle vector per layer.
    // The nested value type trips clippy's type_complexity lint; a type alias would not make it
    // clearer.
    #[allow(clippy::type_complexity)]
    coset_twiddles: RefCell<BTreeMap<(usize, F), Vec<Vec<F>>>>,

    /// Inverse-DFT twiddles (powers of the inverse root), keyed by `log_h`.
    inverse_twiddles: RefCell<BTreeMap<usize, VectorPair<F>>>,
}
40
/// A twiddle table stored twice: once in natural order and once in bit-reversed order.
#[derive(Default, Clone, Debug)]
struct VectorPair<F> {
    /// Twiddles in natural order (passed to `first_half`).
    twiddles: Vec<F>,
    /// The same twiddles permuted by `reverse_slice_index_bits` (passed to `second_half`).
    bitrev_twiddles: Vec<F>,
}
47
48#[instrument(level = "debug", skip_all)]
49fn compute_twiddles<F: TwoAdicField + Ord>(log_h: usize) -> VectorPair<F> {
50 let half_h = (1 << log_h) >> 1;
51 let root = F::two_adic_generator(log_h);
52 let twiddles: Vec<F> = root.powers().take(half_h).collect();
53 let mut bit_reversed_twiddles = twiddles.clone();
54 reverse_slice_index_bits(&mut bit_reversed_twiddles);
55 VectorPair {
56 twiddles,
57 bitrev_twiddles: bit_reversed_twiddles,
58 }
59}
60
61#[instrument(level = "debug", skip_all)]
62fn compute_coset_twiddles<F: TwoAdicField + Ord>(log_h: usize, shift: F) -> Vec<Vec<F>> {
63 let mid = log_h.div_ceil(2);
67 let h = 1 << log_h;
68 let root = F::two_adic_generator(log_h);
69
70 (0..log_h)
71 .map(|layer| {
72 let shift_power = shift.exp_power_of_2(layer);
73 let powers = Powers {
74 base: root.exp_power_of_2(layer),
75 current: shift_power,
76 };
77 let mut twiddles: Vec<_> = powers.take(h >> (layer + 1)).collect();
78 let layer_rev = log_h - 1 - layer;
79 if layer_rev >= mid {
80 reverse_slice_index_bits(&mut twiddles);
81 }
82 twiddles
83 })
84 .collect()
85}
86
87#[instrument(level = "debug", skip_all)]
88fn compute_inverse_twiddles<F: TwoAdicField + Ord>(log_h: usize) -> VectorPair<F> {
89 let half_h = (1 << log_h) >> 1;
90 let root_inv = F::two_adic_generator(log_h).inverse();
91 let twiddles: Vec<F> = root_inv.powers().take(half_h).collect();
92 let mut bit_reversed_twiddles = twiddles.clone();
93
94 reverse_slice_index_bits(&mut bit_reversed_twiddles);
96
97 VectorPair {
98 twiddles,
99 bitrev_twiddles: bit_reversed_twiddles,
100 }
101}
102
impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
    /// Results are returned wrapped in a bit-reversal view over the computed row-major matrix.
    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;

    /// Computes the forward DFT of every column of `mat`.
    ///
    /// The natural-order and bit-reversed twiddles for this height are computed on first use
    /// and then cached in `self.twiddles`.
    ///
    /// # Panics
    /// Panics if `self.twiddles` is already borrowed. Presumably also panics if the height is
    /// not a power of two (via `log2_strict_usize`); TODO confirm against `p3_util`.
    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
        let h = mat.height();
        let log_h = log2_strict_usize(h);

        // Compute the twiddle factors for this height, or reuse memoized ones.
        let mut twiddles_ref_mut = self.twiddles.borrow_mut();
        let twiddles = twiddles_ref_mut
            .entry(log_h)
            .or_insert_with(|| compute_twiddles(log_h));

        // Split point of the network. Rounding up gives the first half at least one layer
        // whenever log_h >= 1.
        let mid = log_h.div_ceil(2);

        // First `mid` layers: a standard DIT over bit-reversed input, using natural-order
        // twiddles.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &twiddles.twiddles);

        // Remaining layers: rows bit-reversed again, using bit-reversed twiddles.
        reverse_matrix_index_bits(&mut mat);
        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);

        // Hand the result back through `bit_reverse_rows`, which yields the
        // `BitReversedMatrixView` declared as `Evaluations`.
        mat.bit_reverse_rows()
    }

    /// Low-degree extension of every column of `mat` onto `2^added_bits` cosets.
    ///
    /// Steps:
    /// 1. Run an inverse DFT over the subgroup of size `h`, scaled by `1/h`.
    /// 2. Run a forward coset DFT of the result for each shift `g_big^i * shift`,
    ///    `i in 0..2^added_bits`. Here `g_big = F::two_adic_generator(log_h + added_bits)`.
    ///
    /// The output has `h << added_bits` rows. It reuses `mat`'s allocation, which is grown in
    /// place.
    ///
    /// # Panics
    /// Panics if `self.inverse_twiddles` or `self.coset_twiddles` is already borrowed.
    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits = added_bits))]
    fn coset_lde_batch(
        &self,
        mut mat: RowMajorMatrix<F>,
        added_bits: usize,
        shift: F,
    ) -> Self::Evaluations {
        let w = mat.width;
        let h = mat.height();
        let log_h = log2_strict_usize(h);
        let mid = log_h.div_ceil(2);

        let mut inverse_twiddles_ref_mut = self.inverse_twiddles.borrow_mut();
        let inverse_twiddles = inverse_twiddles_ref_mut
            .entry(log_h)
            .or_insert_with(|| compute_inverse_twiddles(log_h));

        // Inverse DFT, first half: a standard DIT over bit-reversed input.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &inverse_twiddles.twiddles);

        // Inverse DFT, second half: bit-reversed rows and bit-reversed twiddles.
        reverse_matrix_index_bits(&mut mat);
        // Fold the 1/h normalisation into the second half. The inverse is taken in the prime
        // subfield (presumably cheaper than inverting in an extension field). If `h` has no
        // inverse there, `scale` is `None` and no scaling is applied.
        let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
        let scale = h_inv_subfield.map(F::from_prime_subfield);
        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
        // Unlike `dft_batch`, no bit-reversal follows. `coset_dft` and `coset_dft_oop` start
        // directly with their first half and consume the rows in the order left here.

        let lde_elems = w * (h << added_bits);
        // Grow the capacity (not the length) so the extra cosets can be written directly after
        // the existing data.
        let elems_to_add = lde_elems - w * h;
        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));

        // Generator of the larger subgroup, whose order is `h << added_bits`.
        let g_big = F::two_adic_generator(log_h + added_bits);

        // Split the buffer into the initialized first `w * h` elements (the coefficients) and
        // the uninitialized spare capacity. The spare capacity is viewed as
        // `2^added_bits - 1` matrices of height `h`.
        let mat_ptr = mat.values.as_mut_ptr();
        // SAFETY: after `reserve_exact` the capacity is at least `lde_elems >= w * h`, so this
        // offset stays inside the allocation (assumes `mat.values.len() == w * h`).
        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
        // SAFETY: the first `w * h` elements are the initialized contents of `mat`.
        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
        // SAFETY: this range lies inside the reserved capacity and is disjoint from
        // `first_slice`. It is typed `MaybeUninit`, so it may legitimately be uninitialized.
        let rest_slice: &mut [MaybeUninit<F>] =
            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
        let mut rest_cosets_mat = rest_slice
            .chunks_exact_mut(w * h)
            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
            .collect_vec();

        // Fill every coset except the first, out of place, reading the coefficients held in
        // `first_coset_mat`.
        // - Coset `i` (shift `g_big^i * shift`) goes to slot `reverse_bits_len(i, added_bits)`.
        // - Bit-reversing a nonzero index gives a nonzero index.
        // - So `- 1` maps the slot into `rest_cosets_mat`; slot 0 is `first_coset_mat`.
        for coset_idx in 1..(1 << added_bits) {
            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
            let coset_idx = reverse_bits_len(coset_idx, added_bits);
            let dest = &mut rest_cosets_mat[coset_idx - 1];
            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
        }

        // Transform the first coset in place. This overwrites the coefficients, so it must run
        // after every out-of-place coset has read them.
        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);

        // SAFETY: the loop above hands each spare-capacity chunk to exactly one
        // `coset_dft_oop` call (bit reversal is a bijection on 1..2^added_bits), so all
        // `lde_elems` elements are initialized. This assumes `coset_dft_oop` writes all of
        // `dest`.
        unsafe {
            mat.values.set_len(lde_elems);
        }
        BitReversalPerm::new_view(mat)
    }
}
194
195#[instrument(level = "debug", skip_all)]
196fn coset_dft<F: TwoAdicField + Ord>(
197 dft: &Radix2DitParallel<F>,
198 mat: &mut RowMajorMatrixViewMut<F>,
199 shift: F,
200) {
201 let log_h = log2_strict_usize(mat.height());
202 let mid = log_h.div_ceil(2);
203
204 let mut twiddles_ref_mut = dft.coset_twiddles.borrow_mut();
205 let twiddles = twiddles_ref_mut
206 .entry((log_h, shift))
207 .or_insert_with(|| compute_coset_twiddles(log_h, shift));
208
209 first_half_general(mat, mid, twiddles);
211
212 reverse_matrix_index_bits(mat);
214
215 second_half_general(mat, mid, twiddles);
216}
217
/// Forward DFT of `src` over the coset shifted by `shift`, written out of place into
/// `dst_maybe`.
///
/// No bit-reversal is applied to `src` before the first half. Coset twiddles are memoized in
/// `dft.coset_twiddles`, keyed by `(log_h, shift)`.
///
/// On return, `dst_maybe` is fully initialized. This assumes `first_half_general_oop` writes
/// every destination element; TODO confirm.
///
/// # Panics
/// Panics if `src` and `dst_maybe` differ in dimensions, or if `dft.coset_twiddles` is already
/// borrowed.
#[instrument(level = "debug", skip_all)]
fn coset_dft_oop<F: TwoAdicField + Ord>(
    dft: &Radix2DitParallel<F>,
    src: &RowMajorMatrixView<F>,
    dst_maybe: &mut RowMajorMatrixViewMut<MaybeUninit<F>>,
    shift: F,
) {
    assert_eq!(src.dimensions(), dst_maybe.dimensions());

    let log_h = log2_strict_usize(dst_maybe.height());

    if log_h == 0 {
        // Single-row edge case. `first_half_general_oop` needs at least one layer, and a
        // one-point (coset) DFT leaves the value unchanged, so just copy.
        // SAFETY: `MaybeUninit<F>` has the same layout as `F`, every initialized `F` is a valid
        // `MaybeUninit<F>`, and the reference is only read through.
        // NOTE(review): this also relies on `RowMajorMatrixView<F>` and
        // `RowMajorMatrixView<MaybeUninit<F>>` having identical layouts. repr(Rust) does not
        // formally guarantee that for different generic instantiations.
        let src_maybe = unsafe {
            transmute::<&RowMajorMatrixView<F>, &RowMajorMatrixView<MaybeUninit<F>>>(src)
        };
        dst_maybe.copy_from(src_maybe);
        return;
    }

    // log_h >= 1 here, so mid >= 1, which `first_half_general_oop` requires.
    let mid = log_h.div_ceil(2);

    let mut twiddles_ref_mut = dft.coset_twiddles.borrow_mut();
    let twiddles = twiddles_ref_mut
        .entry((log_h, shift))
        .or_insert_with(|| compute_coset_twiddles(log_h, shift));

    // First `mid` layers. The first of these reads `src` and initializes `dst_maybe`.
    first_half_general_oop(src, dst_maybe, mid, twiddles);

    // SAFETY: `dst_maybe` was fully initialized above, and `MaybeUninit<F>` is
    // layout-identical to `F`. The same generic-layout caveat as in the `log_h == 0` branch
    // applies.
    let dst = unsafe {
        transmute::<&mut RowMajorMatrixViewMut<MaybeUninit<F>>, &mut RowMajorMatrixViewMut<F>>(
            dst_maybe,
        )
    };

    // The remaining layers run on bit-reversed rows.
    reverse_matrix_index_bits(dst);

    second_half_general(dst, mid, twiddles);
}
262
263#[instrument(level = "debug", skip_all)]
265fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
266 let log_h = log2_strict_usize(mat.height());
267
268 mat.par_row_chunks_exact_mut(1 << mid)
270 .for_each(|mut submat| {
271 let mut backwards = false;
272 for layer in 0..mid {
273 let layer_rev = log_h - 1 - layer;
274 let layer_pow = 1 << layer_rev;
275 dit_layer(
276 &mut submat,
277 layer,
278 twiddles.iter().copied().step_by(layer_pow),
279 backwards,
280 );
281 backwards = !backwards;
282 }
283 });
284}
285
286#[instrument(level = "debug", skip_all)]
289fn first_half_general<F: Field>(
290 mat: &mut RowMajorMatrixViewMut<F>,
291 mid: usize,
292 twiddles: &[Vec<F>],
293) {
294 let log_h = log2_strict_usize(mat.height());
295 mat.par_row_chunks_exact_mut(1 << mid)
296 .for_each(|mut submat| {
297 let mut backwards = false;
298 for layer in 0..mid {
299 let layer_rev = log_h - 1 - layer;
300 dit_layer(
301 &mut submat,
302 layer,
303 twiddles[layer_rev].iter().copied(),
304 backwards,
305 );
306 backwards = !backwards;
307 }
308 });
309}
310
/// Out-of-place version of `first_half_general`. Runs the first `mid` layers, reading `src`
/// and writing `dst_maybe`.
///
/// Layer 0 is applied out of place by `dit_layer_oop`, which is what initializes each chunk
/// of `dst_maybe`. Layers `1..mid` then run in place on the destination.
///
/// `mid` must be at least 1. `coset_dft_oop` guarantees this by handling `log_h == 0`
/// separately.
#[instrument(level = "debug", skip_all)]
fn first_half_general_oop<F: Field>(
    src: &RowMajorMatrixView<F>,
    dst_maybe: &mut RowMajorMatrixViewMut<MaybeUninit<F>>,
    mid: usize,
    twiddles: &[Vec<F>],
) {
    let log_h = log2_strict_usize(src.height());
    src.par_row_chunks_exact(1 << mid)
        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
        .for_each(|(src_submat, mut dst_submat_maybe)| {
            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());

            // Layer 0, applied out of place. Its twiddle table index is log_h - 1 - 0.
            let layer_rev = log_h - 1;
            dit_layer_oop(
                &src_submat,
                &mut dst_submat_maybe,
                0,
                twiddles[layer_rev].iter().copied(),
            );

            // SAFETY: `dit_layer_oop` has now written the rows of `dst_submat_maybe`, and
            // `MaybeUninit<F>` is layout-identical to `F`.
            // NOTE(review): this relies on `dit_layer_oop` writing every row of every block,
            // and on the two view types sharing a layout. Confirm both.
            let mut dst_submat = unsafe {
                transmute::<RowMajorMatrixViewMut<MaybeUninit<F>>, RowMajorMatrixViewMut<F>>(
                    dst_submat_maybe,
                )
            };

            // Layer 0 ran front to back, so layer 1 starts backwards. This keeps the same
            // alternation that `first_half_general` uses.
            let mut backwards = true;
            for layer in 1..mid {
                let layer_rev = log_h - 1 - layer;
                dit_layer(
                    &mut dst_submat,
                    layer,
                    twiddles[layer_rev].iter().copied(),
                    backwards,
                );
                backwards = !backwards;
            }
        });
}
359
360#[instrument(level = "debug", skip_all)]
366#[inline(always)] fn second_half<F: Field>(
368 mat: &mut RowMajorMatrix<F>,
369 mid: usize,
370 twiddles_rev: &[F],
371 scale: Option<F>,
372) {
373 let log_h = log2_strict_usize(mat.height());
374
375 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
377 .enumerate()
378 .for_each(|(thread, mut submat)| {
379 let mut backwards = false;
380 if let Some(scale) = scale {
381 submat.scale(scale);
382 }
383 for layer in mid..log_h {
384 let first_block = thread << (layer - mid);
385 dit_layer_rev(
386 &mut submat,
387 log_h,
388 layer,
389 twiddles_rev[first_block..].iter().copied(),
390 backwards,
391 );
392 backwards = !backwards;
393 }
394 });
395}
396
397#[instrument(level = "debug", skip_all)]
400fn second_half_general<F: Field>(
401 mat: &mut RowMajorMatrixViewMut<F>,
402 mid: usize,
403 twiddles_rev: &[Vec<F>],
404) {
405 let log_h = log2_strict_usize(mat.height());
406 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
407 .enumerate()
408 .for_each(|(thread, mut submat)| {
409 let mut backwards = false;
410 for layer in mid..log_h {
411 let layer_rev = log_h - 1 - layer;
412 let first_block = thread << (layer - mid);
413 dit_layer_rev(
414 &mut submat,
415 log_h,
416 layer,
417 twiddles_rev[layer_rev][first_block..].iter().copied(),
418 backwards,
419 );
420 backwards = !backwards;
421 }
422 });
423}
424
425fn dit_layer<F: Field>(
427 submat: &mut RowMajorMatrixViewMut<'_, F>,
428 layer: usize,
429 twiddles: impl Iterator<Item = F> + Clone,
430 backwards: bool,
431) {
432 let half_block_size = 1 << layer;
433 let block_size = half_block_size * 2;
434 let width = submat.width();
435 debug_assert!(submat.height() >= block_size);
436
437 let process_block = |block: &mut [F]| {
438 let (lows, highs) = block.split_at_mut(half_block_size * width);
439
440 for (lo, hi, twiddle) in izip!(
441 lows.chunks_mut(width),
442 highs.chunks_mut(width),
443 twiddles.clone()
444 ) {
445 DitButterfly(twiddle).apply_to_rows(lo, hi);
446 }
447 };
448
449 let blocks = submat.values.chunks_mut(block_size * width);
450 if backwards {
451 for block in blocks.rev() {
452 process_block(block);
453 }
454 } else {
455 for block in blocks {
456 process_block(block);
457 }
458 }
459}
460
461fn dit_layer_oop<F: Field>(
463 src: &RowMajorMatrixView<F>,
464 dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
465 layer: usize,
466 twiddles: impl Iterator<Item = F> + Clone,
467) {
468 debug_assert_eq!(src.dimensions(), dst.dimensions());
469 let half_block_size = 1 << layer;
470 let block_size = half_block_size * 2;
471 let width = dst.width();
472 debug_assert!(dst.height() >= block_size);
473
474 let src_chunks = src.values.chunks(block_size * width);
475 let dst_chunks = dst.values.chunks_mut(block_size * width);
476 for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
477 let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
478 let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
479
480 for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
481 src_lows.chunks(width),
482 dst_lows.chunks_mut(width),
483 src_highs.chunks(width),
484 dst_highs.chunks_mut(width),
485 twiddles.clone()
486 ) {
487 DitButterfly(twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
488 }
489 }
490}
491
492fn dit_layer_rev<F: Field>(
495 submat: &mut RowMajorMatrixViewMut<'_, F>,
496 log_h: usize,
497 layer: usize,
498 twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
499 backwards: bool,
500) {
501 let layer_rev = log_h - 1 - layer;
502
503 let half_block_size = 1 << layer_rev;
504 let block_size = half_block_size * 2;
505 let width = submat.width();
506 debug_assert!(submat.height() >= block_size);
507
508 let blocks_and_twiddles = submat
509 .values
510 .chunks_mut(block_size * width)
511 .zip(twiddles_rev);
512 if backwards {
513 for (block, twiddle) in blocks_and_twiddles.rev() {
514 let (lo, hi) = block.split_at_mut(half_block_size * width);
515 DitButterfly(twiddle).apply_to_rows(lo, hi)
516 }
517 } else {
518 for (block, twiddle) in blocks_and_twiddles {
519 let (lo, hi) = block.split_at_mut(half_block_size * width);
520 DitButterfly(twiddle).apply_to_rows(lo, hi)
521 }
522 }
523}