1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::mem::{MaybeUninit, transmute};
6
7use itertools::{Itertools, izip};
8use p3_field::integers::QuotientMap;
9use p3_field::{Field, Powers, TwoAdicField};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversalPerm, BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
13use p3_matrix::util::reverse_matrix_index_bits;
14use p3_maybe_rayon::prelude::*;
15use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
16use spin::RwLock;
17use tracing::{debug_span, instrument};
18
19use crate::TwoAdicSubgroupDft;
20use crate::butterflies::{Butterfly, DitButterfly, ScaledDitButterfly, TwiddleFreeButterfly};
21
/// A radix-2 DFT engine that memoizes the twiddle factors it computes.
///
/// Every cache is a `BTreeMap` behind an `Arc<RwLock<..>>`, so clones of this struct (via the
/// derived `Clone`) share the same caches rather than recomputing them.
#[derive(Default, Clone, Debug)]
pub struct Radix2DitParallel<F> {
    /// Cache keyed by `log_h`, filled by `get_or_compute_twiddles`: powers of
    /// `F::two_adic_generator(log_h)`, kept in natural and bit-reversed order.
    twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,

    /// Cache keyed by `(log_h, shift)`, filled by `get_or_compute_coset_twiddles`: one twiddle
    /// vector per butterfly layer, built from both the generator and `shift`.
    #[allow(clippy::type_complexity)]
    coset_twiddles: Arc<RwLock<BTreeMap<(usize, F), Arc<[Vec<F>]>>>>,

    /// Cache keyed by `log_h`, filled by `get_or_compute_inverse_twiddles`: same layout as
    /// `twiddles`, but built from the inverse of `F::two_adic_generator(log_h)`.
    inverse_twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
}
41
/// The same list of twiddle factors stored twice: in generation order and index-bit-reversed.
#[derive(Default, Clone, Debug)]
struct VectorPair<F> {
    /// Twiddles in the order they were generated (successive powers of a root).
    twiddles: Vec<F>,
    /// A copy of `twiddles` permuted with `reverse_slice_index_bits`.
    bitrev_twiddles: Vec<F>,
}
48
49impl<F> Radix2DitParallel<F>
50where
51 F: TwoAdicField + Ord,
52{
53 fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
54 if let Some(pair) = self.twiddles.read().get(&log_h) {
56 return pair.clone();
57 }
58
59 let mut w_lock = self.twiddles.write();
61
62 w_lock
64 .entry(log_h)
65 .or_insert_with(|| {
66 let half_h = (1 << log_h) >> 1;
67 let root = F::two_adic_generator(log_h);
68 let twiddles = root.powers().collect_n(half_h);
69 let mut bitrev_twiddles = twiddles.clone();
70 reverse_slice_index_bits(&mut bitrev_twiddles);
71
72 Arc::new(VectorPair {
73 twiddles,
74 bitrev_twiddles,
75 })
76 })
77 .clone()
78 }
79
80 fn get_or_compute_coset_twiddles(&self, (log_h, shift): (usize, F)) -> Arc<[Vec<F>]> {
81 let key = (log_h, shift);
82 if let Some(twiddles) = self.coset_twiddles.read().get(&key) {
84 return twiddles.clone();
85 }
86 let mut w_lock = self.coset_twiddles.write();
89 w_lock
92 .entry(key)
93 .or_insert_with(|| {
94 let mid = log_h.div_ceil(2);
95 let h = 1 << log_h;
96 let root = F::two_adic_generator(log_h);
97 (0..log_h)
98 .map(|layer| {
99 let shift_power = shift.exp_power_of_2(layer);
100 let powers = Powers {
101 base: root.exp_power_of_2(layer),
102 current: shift_power,
103 };
104 let mut twiddles = powers.collect_n(h >> (layer + 1));
105 let layer_rev = log_h - 1 - layer;
106 if layer_rev >= mid {
107 reverse_slice_index_bits(&mut twiddles);
108 }
109 twiddles
110 })
111 .collect::<Vec<_>>()
112 .into()
113 })
114 .clone()
115 }
116
117 fn get_or_compute_inverse_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
118 if let Some(pair) = self.inverse_twiddles.read().get(&log_h) {
120 return pair.clone();
121 }
122 let mut w_lock = self.inverse_twiddles.write();
124 w_lock
127 .entry(log_h)
128 .or_insert_with(|| {
129 let half_h = (1 << log_h) >> 1;
131 let root_inv = F::two_adic_generator(log_h).inverse();
132 let twiddles = root_inv.powers().collect_n(half_h);
133 let mut bitrev_twiddles = twiddles.clone();
134 reverse_slice_index_bits(&mut bitrev_twiddles);
135
136 Arc::new(VectorPair {
137 twiddles,
138 bitrev_twiddles,
139 })
140 })
141 .clone()
142 }
143}
144
impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;

    /// Transforms `mat` in place in two halves and returns it wrapped by `bit_reverse_rows()`.
    ///
    /// The butterfly layers are split at `mid = log_h.div_ceil(2)`: layers `0..mid` run in
    /// `first_half`, layers `mid..log_h` in `second_half`, with a row bit-reversal before each.
    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
        let h = mat.height();
        let log_h = log2_strict_usize(h);

        // Memoized per `log_h`; computed on first use.
        let twiddles = self.get_or_compute_twiddles(log_h);

        let mid = log_h.div_ceil(2);

        // First half: bit-reverse the rows, then use the natural-order twiddles.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &twiddles.twiddles);

        // Second half: bit-reverse again, then use the bit-reversed twiddles. No scaling.
        reverse_matrix_index_bits(&mut mat);
        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);

        mat.bit_reverse_rows()
    }

    /// Extends `mat` from `h` rows to `h << added_bits` rows using coset shift `shift`.
    ///
    /// Steps, as performed below:
    /// 1. Run the two-half network with the inverse twiddles, scaling by the inverse of `h`.
    /// 2. Grow the backing `Vec` so it can hold `w * (h << added_bits)` elements.
    /// 3. For each `coset_idx` in `1..2^added_bits`, run `coset_dft_oop` from the first `w * h`
    ///    elements into the chunk at index `reverse_bits_len(coset_idx, added_bits) - 1` of the
    ///    new region, with shift `g_big^coset_idx * shift`.
    /// 4. Run `coset_dft` in place on the first `w * h` elements with shift `shift`.
    ///
    /// The result is returned through `BitReversalPerm::new_view`.
    #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits = added_bits))]
    fn coset_lde_batch(
        &self,
        mut mat: RowMajorMatrix<F>,
        added_bits: usize,
        shift: F,
    ) -> Self::Evaluations {
        let w = mat.width;
        let h = mat.height();
        let log_h = log2_strict_usize(h);
        let mid = log_h.div_ceil(2);

        let inverse_twiddles = self.get_or_compute_inverse_twiddles(log_h);

        // Same two-half structure as `dft_batch`, but with the inverse twiddles.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &inverse_twiddles.twiddles);

        reverse_matrix_index_bits(&mut mat);
        // The inverse of `h` is taken in the prime subfield and then lifted into `F`.
        // `try_inverse` yields an `Option`: if there is no inverse, `scale` is `None` and
        // `second_half` applies no scaling at all.
        let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
        let scale = h_inv_subfield.map(F::from_prime_subfield);
        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
        // No bit-reversal is applied here before the coset DFTs below.
        let lde_elems = w * (h << added_bits);
        let elems_to_add = lde_elems - w * h;
        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));

        let g_big = F::two_adic_generator(log_h + added_bits);

        // The pointer is taken after `reserve_exact`, so a reallocation there cannot invalidate it.
        let mat_ptr = mat.values.as_mut_ptr();
        // SAFETY: `reserve_exact(elems_to_add)` above guarantees capacity for
        // `mat.values.len() + elems_to_add` elements, so offset `w * h` stays inside the
        // allocation. NOTE(review): this assumes `mat.values.len() == w * h` — confirm against
        // the `RowMajorMatrix` invariants.
        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
        // SAFETY: covers the first `w * h` elements of the buffer (assumed to be exactly the
        // initialized elements, see the note above).
        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
        // SAFETY: starts at offset `w * h`, so it does not overlap `first_slice`, and spans the
        // `elems_to_add` elements reserved above. It is typed as `MaybeUninit<F>` because that
        // spare capacity has not been written yet.
        let rest_slice: &mut [MaybeUninit<F>] =
            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
        // One `w * h`-element destination per additional coset.
        let mut rest_cosets_mat = rest_slice
            .chunks_exact_mut(w * h)
            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
            .collect_vec();

        for coset_idx in 1..(1 << added_bits) {
            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
            // Cosets are stored at their bit-reversed index.
            let coset_idx = reverse_bits_len(coset_idx, added_bits);
            // `- 1` because slot 0 is `first_coset_mat`, which is not part of `rest_cosets_mat`.
            let dest = &mut rest_cosets_mat[coset_idx - 1];
            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
        }

        // The first coset is transformed last and in place: the loop above reads it as the source.
        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);

        // SAFETY: capacity for `lde_elems` was reserved above. The new region is expected to have
        // been written by the `coset_dft_oop` calls in the loop.
        // NOTE(review): this relies on `coset_dft_oop` initializing every element of `dest` and
        // on the loop reaching every chunk of `rest_cosets_mat` — confirm.
        unsafe {
            mat.values.set_len(lde_elems);
        }
        BitReversalPerm::new_view(mat)
    }
}
230
231#[instrument(level = "debug", skip_all)]
232fn coset_dft<F: TwoAdicField + Ord>(
233 dft: &Radix2DitParallel<F>,
234 mat: &mut RowMajorMatrixViewMut<'_, F>,
235 shift: F,
236) {
237 let log_h = log2_strict_usize(mat.height());
238 let mid = log_h.div_ceil(2);
239
240 let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
241
242 first_half_general(mat, mid, &twiddles);
244
245 reverse_matrix_index_bits(mat);
247
248 second_half_general(mat, mid, &twiddles);
249}
250
/// Out-of-place variant of `coset_dft`: reads `src` and writes the result into `dst_maybe`.
///
/// `dst_maybe` is typed as `MaybeUninit<F>`. Layer 0 is done out of place by
/// `first_half_general_oop`; everything after that operates on `dst_maybe` in place.
///
/// # Panics
///
/// Panics if `src` and `dst_maybe` have different dimensions.
#[instrument(level = "debug", skip_all)]
fn coset_dft_oop<F: TwoAdicField + Ord>(
    dft: &Radix2DitParallel<F>,
    src: &RowMajorMatrixView<'_, F>,
    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
    shift: F,
) {
    assert_eq!(src.dimensions(), dst_maybe.dimensions());

    let log_h = log2_strict_usize(dst_maybe.height());

    if log_h == 0 {
        // With a single row there are no butterfly layers. `first_half_general_oop` always runs
        // layer 0 and indexes `twiddles[log_h - 1]`, which cannot work here, so just copy.
        // SAFETY: `MaybeUninit<F>` has the same layout as `F`, and viewing initialized data as
        // possibly-uninitialized through a shared reference is harmless.
        // NOTE(review): this also relies on `RowMajorMatrixView<'_, T>` having the same layout
        // for `T = F` and `T = MaybeUninit<F>` — confirm.
        let src_maybe = unsafe {
            transmute::<&RowMajorMatrixView<'_, F>, &RowMajorMatrixView<'_, MaybeUninit<F>>>(src)
        };
        dst_maybe.copy_from(src_maybe);
        return;
    }

    let mid = log_h.div_ceil(2);

    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));

    // Layers `0..mid`; layer 0 is what moves the data from `src` into `dst_maybe`.
    first_half_general_oop(src, dst_maybe, mid, &twiddles);

    // SAFETY: from here on `dst_maybe` is treated as fully initialized.
    // NOTE(review): this relies on `first_half_general_oop` having written every element of
    // `dst_maybe`, and on the same layout assumption as the transmute above — confirm.
    let dst = unsafe {
        transmute::<&mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>, &mut RowMajorMatrixViewMut<'_, F>>(
            dst_maybe,
        )
    };

    // Second half: bit-reverse the rows, then run layers `mid..log_h`.
    reverse_matrix_index_bits(dst);

    second_half_general(dst, mid, &twiddles);
}
292
293#[instrument(level = "debug", skip_all)]
301fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
302 let log_h = log2_strict_usize(mat.height());
303
304 mat.par_row_chunks_exact_mut(1 << mid)
306 .for_each(|mut submat| {
307 let mut backwards = false;
308 for layer in 0..mid {
309 if layer == 0 {
310 dit_layer_twiddle_free(&mut submat, backwards);
314 } else {
315 let layer_rev = log_h - 1 - layer;
316 let layer_pow = 1 << layer_rev;
317 dit_layer_first_one(
321 &mut submat,
322 layer,
323 twiddles.iter().step_by(layer_pow),
324 backwards,
325 );
326 }
327 backwards = !backwards;
328 }
329 });
330}
331
332#[instrument(level = "debug", skip_all)]
335fn first_half_general<F: Field>(
336 mat: &mut RowMajorMatrixViewMut<'_, F>,
337 mid: usize,
338 twiddles: &[Vec<F>],
339) {
340 let log_h = log2_strict_usize(mat.height());
341 mat.par_row_chunks_exact_mut(1 << mid)
342 .for_each(|mut submat| {
343 let mut backwards = false;
344 for layer in 0..mid {
345 let layer_rev = log_h - 1 - layer;
346 dit_layer(&mut submat, layer, twiddles[layer_rev].iter(), backwards);
347 backwards = !backwards;
348 }
349 });
350}
351
/// Like `first_half_general`, but layer 0 reads from `src` and writes into `dst_maybe`.
///
/// Layers `1..mid` then run in place on the destination. Layer 0 is run unconditionally and
/// indexes `twiddles[log_h - 1]`, so the matrix needs at least two rows (`log_h >= 1`).
#[instrument(level = "debug", skip_all)]
fn first_half_general_oop<F: Field>(
    src: &RowMajorMatrixView<'_, F>,
    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
    mid: usize,
    twiddles: &[Vec<F>],
) {
    let log_h = log2_strict_usize(src.height());
    // Source and destination are chunked identically, `2^mid` rows per chunk.
    src.par_row_chunks_exact(1 << mid)
        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
        .for_each(|(src_submat, mut dst_submat_maybe)| {
            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());

            // Layer 0 is the out-of-place one.
            let layer_rev = log_h - 1;
            dit_layer_oop(
                &src_submat,
                &mut dst_submat_maybe,
                0,
                twiddles[layer_rev].iter(),
            );

            // SAFETY: the destination chunk is treated as initialized from here on.
            // NOTE(review): this relies on `dit_layer_oop` (via `apply_to_rows_oop`) having
            // written every element of the chunk, and on `RowMajorMatrixViewMut<'_, T>` having
            // the same layout for `T = F` and `T = MaybeUninit<F>` — confirm.
            let mut dst_submat = unsafe {
                transmute::<RowMajorMatrixViewMut<'_, MaybeUninit<F>>, RowMajorMatrixViewMut<'_, F>>(
                    dst_submat_maybe,
                )
            };

            // Layer 0 ran forwards, so layer 1 starts backwards and the direction keeps
            // alternating from there.
            let mut backwards = true;
            for layer in 1..mid {
                let layer_rev = log_h - 1 - layer;
                dit_layer(
                    &mut dst_submat,
                    layer,
                    twiddles[layer_rev].iter(),
                    backwards,
                );
                backwards = !backwards;
            }
        });
}
400
401#[instrument(level = "debug", skip_all)]
407#[inline(always)] fn second_half<F: Field>(
409 mat: &mut RowMajorMatrix<F>,
410 mid: usize,
411 twiddles_rev: &[F],
412 scale: Option<F>,
413) {
414 let log_h = log2_strict_usize(mat.height());
415
416 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
418 .enumerate()
419 .for_each(|(thread, mut submat)| {
420 let mut backwards = false;
421 if let Some(scale) = scale {
422 let mut scale_applied = false;
426 for layer in mid..log_h {
427 let first_block = thread << (layer - mid);
428 if !scale_applied {
429 scale_applied = true;
430 dit_layer_rev_scaled(
431 &mut submat,
432 log_h,
433 layer,
434 twiddles_rev[first_block..].iter().copied(),
435 backwards,
436 Some(scale),
437 );
438 } else {
439 dit_layer_rev(
440 &mut submat,
441 log_h,
442 layer,
443 twiddles_rev[first_block..].iter().copied(),
444 backwards,
445 );
446 }
447 backwards = !backwards;
448 }
449 if !scale_applied {
451 submat.scale(scale);
452 }
453 } else {
454 for layer in mid..log_h {
455 let first_block = thread << (layer - mid);
456 dit_layer_rev(
457 &mut submat,
458 log_h,
459 layer,
460 twiddles_rev[first_block..].iter().copied(),
461 backwards,
462 );
463 backwards = !backwards;
464 }
465 }
466 });
467}
468
469#[instrument(level = "debug", skip_all)]
472fn second_half_general<F: Field>(
473 mat: &mut RowMajorMatrixViewMut<'_, F>,
474 mid: usize,
475 twiddles_rev: &[Vec<F>],
476) {
477 let log_h = log2_strict_usize(mat.height());
478 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
479 .enumerate()
480 .for_each(|(thread, mut submat)| {
481 let mut backwards = false;
482 for layer in mid..log_h {
483 let layer_rev = log_h - 1 - layer;
484 let first_block = thread << (layer - mid);
485 dit_layer_rev(
486 &mut submat,
487 log_h,
488 layer,
489 twiddles_rev[layer_rev][first_block..].iter().copied(),
490 backwards,
491 );
492 backwards = !backwards;
493 }
494 });
495}
496
497fn dit_layer_twiddle_free<F: Field>(submat: &mut RowMajorMatrixViewMut<'_, F>, backwards: bool) {
506 let width = submat.width();
508 debug_assert!(submat.height() >= 2);
509
510 let process_block = move |block: &mut [F]| {
511 let (lo, hi) = block.split_at_mut(width);
513 TwiddleFreeButterfly.apply_to_rows(lo, hi);
514 };
515
516 let blocks = submat.values.chunks_mut(2 * width);
517 if backwards {
518 for block in blocks.rev() {
519 process_block(block);
520 }
521 } else {
522 for block in blocks {
523 process_block(block);
524 }
525 }
526}
527
528fn dit_layer_first_one<'a, F: Field>(
537 submat: &mut RowMajorMatrixViewMut<'_, F>,
538 layer: usize,
539 twiddles: impl Iterator<Item = &'a F> + Clone,
540 backwards: bool,
541) {
542 let half_block_size = 1 << layer;
543 let block_size = half_block_size * 2;
544 let width = submat.width();
545 debug_assert!(submat.height() >= block_size);
546 debug_assert!(
547 half_block_size >= 2,
548 "layer must be >= 1 for dit_layer_first_one"
549 );
550
551 let process_block = move |block: &mut [F]| {
552 let (lows, highs) = block.split_at_mut(half_block_size * width);
553 let mut tw_iter = twiddles.clone();
554 let _ = tw_iter.next(); let (lo0, lo_rest) = lows.split_at_mut(width);
557 let (hi0, hi_rest) = highs.split_at_mut(width);
558 TwiddleFreeButterfly.apply_to_rows(lo0, hi0);
559 for (lo, hi, twiddle) in izip!(
561 lo_rest.chunks_mut(width),
562 hi_rest.chunks_mut(width),
563 tw_iter
564 ) {
565 DitButterfly(*twiddle).apply_to_rows(lo, hi);
566 }
567 };
568
569 let blocks = submat.values.chunks_mut(block_size * width);
570 if backwards {
571 for block in blocks.rev() {
572 process_block(block);
573 }
574 } else {
575 for block in blocks {
576 process_block(block);
577 }
578 }
579}
580
581fn dit_layer<'a, F: Field>(
583 submat: &mut RowMajorMatrixViewMut<'_, F>,
584 layer: usize,
585 twiddles: impl Iterator<Item = &'a F> + Clone,
586 backwards: bool,
587) {
588 let half_block_size = 1 << layer;
589 let block_size = half_block_size * 2;
590 let width = submat.width();
591 debug_assert!(submat.height() >= block_size);
592
593 let process_block = move |block: &mut [F]| {
594 let (lows, highs) = block.split_at_mut(half_block_size * width);
595 for (lo, hi, twiddle) in izip!(
596 lows.chunks_mut(width),
597 highs.chunks_mut(width),
598 twiddles.clone()
599 ) {
600 DitButterfly(*twiddle).apply_to_rows(lo, hi);
601 }
602 };
603
604 let blocks = submat.values.chunks_mut(block_size * width);
605 if backwards {
606 for block in blocks.rev() {
607 process_block(block);
608 }
609 } else {
610 for block in blocks {
611 process_block(block);
612 }
613 }
614}
615
616fn dit_layer_oop<'a, F: Field>(
618 src: &RowMajorMatrixView<'_, F>,
619 dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
620 layer: usize,
621 twiddles: impl Iterator<Item = &'a F> + Clone,
622) {
623 debug_assert_eq!(src.dimensions(), dst.dimensions());
624 let half_block_size = 1 << layer;
625 let block_size = half_block_size * 2;
626 let width = dst.width();
627 debug_assert!(dst.height() >= block_size);
628
629 let process_blocks = move |src_block: &[F], dst_block: &mut [MaybeUninit<F>]| {
630 let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
631 let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
632
633 for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
634 src_lows.chunks(width),
635 dst_lows.chunks_mut(width),
636 src_highs.chunks(width),
637 dst_highs.chunks_mut(width),
638 twiddles.clone()
639 ) {
640 DitButterfly(*twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
641 }
642 };
643
644 let src_chunks = src.values.chunks(block_size * width);
645 let dst_chunks = dst.values.chunks_mut(block_size * width);
646
647 for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
648 process_blocks(src_block, dst_block);
649 }
650}
651
652fn dit_layer_rev_scaled<F: Field>(
660 submat: &mut RowMajorMatrixViewMut<'_, F>,
661 log_h: usize,
662 layer: usize,
663 twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
664 backwards: bool,
665 scale: Option<F>,
666) {
667 let layer_rev = log_h - 1 - layer;
668
669 let half_block_size = 1 << layer_rev;
670 let block_size = half_block_size * 2;
671 let width = submat.width();
672 debug_assert!(submat.height() >= block_size);
673
674 match scale {
675 None => {
676 let blocks_and_twiddles = submat
678 .values
679 .chunks_mut(block_size * width)
680 .zip(twiddles_rev);
681 if backwards {
682 for (block, twiddle) in blocks_and_twiddles.rev() {
683 let (lo, hi) = block.split_at_mut(half_block_size * width);
684 DitButterfly(twiddle).apply_to_rows(lo, hi);
685 }
686 } else {
687 for (block, twiddle) in blocks_and_twiddles {
688 let (lo, hi) = block.split_at_mut(half_block_size * width);
689 DitButterfly(twiddle).apply_to_rows(lo, hi);
690 }
691 }
692 }
693 Some(s) => {
694 let blocks_and_twiddles = submat
698 .values
699 .chunks_mut(block_size * width)
700 .zip(twiddles_rev);
701 if backwards {
702 for (block, twiddle) in blocks_and_twiddles.rev() {
703 let (lo, hi) = block.split_at_mut(half_block_size * width);
704 ScaledDitButterfly::new(twiddle, s).apply_to_rows(lo, hi);
705 }
706 } else {
707 for (block, twiddle) in blocks_and_twiddles {
708 let (lo, hi) = block.split_at_mut(half_block_size * width);
709 ScaledDitButterfly::new(twiddle, s).apply_to_rows(lo, hi);
710 }
711 }
712 }
713 }
714}
715
716fn dit_layer_rev<F: Field>(
719 submat: &mut RowMajorMatrixViewMut<'_, F>,
720 log_h: usize,
721 layer: usize,
722 twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
723 backwards: bool,
724) {
725 let layer_rev = log_h - 1 - layer;
726
727 let half_block_size = 1 << layer_rev;
728 let block_size = half_block_size * 2;
729 let width = submat.width();
730 debug_assert!(submat.height() >= block_size);
731
732 let blocks_and_twiddles = submat
733 .values
734 .chunks_mut(block_size * width)
735 .zip(twiddles_rev);
736 if backwards {
737 for (block, twiddle) in blocks_and_twiddles.rev() {
738 let (lo, hi) = block.split_at_mut(half_block_size * width);
739 DitButterfly(twiddle).apply_to_rows(lo, hi);
740 }
741 } else {
742 for (block, twiddle) in blocks_and_twiddles {
743 let (lo, hi) = block.split_at_mut(half_block_size * width);
744 DitButterfly(twiddle).apply_to_rows(lo, hi);
745 }
746 }
747}