1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3use core::cell::RefCell;
4
5use p3_field::{Field, TwoAdicField};
6use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
7use p3_matrix::util::reverse_matrix_index_bits;
8use p3_matrix::Matrix;
9use p3_maybe_rayon::prelude::*;
10use p3_util::log2_strict_usize;
11
12use crate::butterflies::{Butterfly, DitButterfly, TwiddleFreeButterfly};
13use crate::TwoAdicSubgroupDft;
14
15#[derive(Default, Clone, Debug)]
17pub struct Radix2Dit<F: TwoAdicField> {
18 twiddles: RefCell<BTreeMap<usize, Vec<F>>>,
20}
21
22impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2Dit<F> {
23 type Evaluations = RowMajorMatrix<F>;
24
25 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
26 let h = mat.height();
27 let log_h = log2_strict_usize(h);
28
29 let mut twiddles_ref_mut = self.twiddles.borrow_mut();
31 let twiddles = twiddles_ref_mut.entry(log_h).or_insert_with(|| {
32 let root = F::two_adic_generator(log_h);
33 root.powers().take(1 << log_h).collect()
34 });
35
36 reverse_matrix_index_bits(&mut mat);
38 for layer in 0..log_h {
39 dit_layer(&mut mat.as_view_mut(), layer, twiddles);
40 }
41 mat
42 }
43}
44
45fn dit_layer<F: Field>(mat: &mut RowMajorMatrixViewMut<'_, F>, layer: usize, twiddles: &[F]) {
47 let h = mat.height();
48 let log_h = log2_strict_usize(h);
49 let layer_rev = log_h - 1 - layer;
50
51 let half_block_size = 1 << layer;
52 let block_size = half_block_size * 2;
53
54 mat.par_row_chunks_exact_mut(block_size)
55 .for_each(|mut block_chunks| {
56 let (mut hi_chunks, mut lo_chunks) = block_chunks.split_rows_mut(half_block_size);
57 hi_chunks
58 .par_rows_mut()
59 .zip(lo_chunks.par_rows_mut())
60 .enumerate()
61 .for_each(|(ind, (hi_chunk, lo_chunk))| {
62 if ind == 0 {
63 TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk)
64 } else {
65 DitButterfly(twiddles[ind << layer_rev]).apply_to_rows(hi_chunk, lo_chunk)
66 }
67 });
68 });
69}
70
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_goldilocks::Goldilocks;

    use crate::testing::*;
    use crate::Radix2Dit;

    // Each test instantiates one of the shared generic checks from
    // `crate::testing` with `Radix2Dit` over a concrete field.

    /// Shared `test_dft_matches_naive` check over `BabyBear`.
    #[test]
    fn dft_matches_naive() {
        test_dft_matches_naive::<BabyBear, Radix2Dit<BabyBear>>();
    }

    /// Shared `test_coset_dft_matches_naive` check over `BabyBear`.
    #[test]
    fn coset_dft_matches_naive() {
        test_coset_dft_matches_naive::<BabyBear, Radix2Dit<BabyBear>>();
    }

    /// Shared `test_idft_matches_naive` check over `Goldilocks`.
    #[test]
    fn idft_matches_naive() {
        test_idft_matches_naive::<Goldilocks, Radix2Dit<Goldilocks>>();
    }

    /// Shared `test_coset_idft_matches_naive` check over both fields.
    #[test]
    fn coset_idft_matches_naive() {
        test_coset_idft_matches_naive::<BabyBear, Radix2Dit<BabyBear>>();
        test_coset_idft_matches_naive::<Goldilocks, Radix2Dit<Goldilocks>>();
    }

    /// Shared `test_lde_matches_naive` check over `BabyBear`.
    #[test]
    fn lde_matches_naive() {
        test_lde_matches_naive::<BabyBear, Radix2Dit<BabyBear>>();
    }

    /// Shared `test_coset_lde_matches_naive` check over `BabyBear`.
    #[test]
    fn coset_lde_matches_naive() {
        test_coset_lde_matches_naive::<BabyBear, Radix2Dit<BabyBear>>();
    }

    /// Shared `test_dft_idft_consistency` check over `BabyBear`.
    #[test]
    fn dft_idft_consistency() {
        test_dft_idft_consistency::<BabyBear, Radix2Dit<BabyBear>>();
    }
}