use alloc::vec::Vec;
use p3_field::{Field, Powers, TwoAdicField};
use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView};
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
use p3_matrix::util::reverse_matrix_index_bits;
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::*;
use p3_util::{log2_strict_usize, reverse_bits, reverse_slice_index_bits};
use tracing::instrument;
use crate::butterflies::{Butterfly, DitButterfly};
use crate::TwoAdicSubgroupDft;
/// A radix-2 decimation-in-time DFT that processes the butterfly network in two halves.
///
/// The first `log_h / 2` layers are applied by `par_dit_layer` on independent chunks of
/// consecutive rows. The rows and the twiddle table are then bit-reversed, and the remaining
/// layers are applied by `par_dit_layer_rev`, again on independent chunks of consecutive rows.
/// The type carries no state.
#[derive(Default, Clone, Debug)]
pub struct Radix2DitParallel;
impl<F: TwoAdicField> TwoAdicSubgroupDft<F> for Radix2DitParallel {
    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;

    /// Runs the two-half butterfly network over the rows of `mat`.
    ///
    /// The height is passed through `log2_strict_usize`, so it is expected to be a power of
    /// two. The twiddle table holds the first `height / 2` powers of
    /// `F::two_adic_generator(log_height)`.
    ///
    /// Returns the matrix wrapped by `bit_reverse_rows()`.
    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
        let height = mat.height();
        let log_height = log2_strict_usize(height);

        let generator = F::two_adic_generator(log_height);
        let mut twiddles = generator.powers().take(height / 2).collect::<Vec<F>>();

        // Layers below `split` go to the first pass, the rest to the second.
        let split = log_height / 2;

        // First half: bit-reverse the rows, then run layers `0..split`.
        reverse_matrix_index_bits(&mut mat);
        par_dit_layer(&mut mat, split, &twiddles);

        // Second half: bit-reverse rows and twiddles, then run layers `split..log_height`.
        reverse_matrix_index_bits(&mut mat);
        reverse_slice_index_bits(&mut twiddles);
        par_dit_layer_rev(&mut mat, split, &twiddles);

        mat.bit_reverse_rows()
    }

    /// Runs an inverse-twiddle network, rescales the rows, zero-pads, and runs a forward network.
    ///
    /// Steps, as performed on `mat`:
    /// 1. The two-half network with powers of the inverse generator as twiddles.
    /// 2. Row scaling by the successive values of `Powers { base: shift, current: 1/h }`
    ///    (presumably `shift^row / h` — confirm against `Powers`), addressing rows through
    ///    `reverse_bits`.
    /// 3. `bit_reversed_zero_pad(added_bits)`.
    /// 4. The two-half network on the padded matrix with forward twiddles. No initial row
    ///    bit-reversal is applied in this pass.
    ///
    /// Returns the matrix wrapped by `bit_reverse_rows()`.
    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))]
    fn coset_lde_batch(
        &self,
        mut mat: RowMajorMatrix<F>,
        added_bits: usize,
        shift: F,
    ) -> Self::Evaluations {
        // Inverse pass over the original height.
        let coeff_height = mat.height();
        let coeff_log_height = log2_strict_usize(coeff_height);
        let coeff_split = coeff_log_height / 2;
        let height_inv = F::from_canonical_usize(coeff_height).inverse();

        let inverse_generator = F::two_adic_generator(coeff_log_height).inverse();
        let mut inverse_twiddles = inverse_generator
            .powers()
            .take(coeff_height / 2)
            .collect::<Vec<F>>();

        reverse_matrix_index_bits(&mut mat);
        par_dit_layer(&mut mat, coeff_split, &inverse_twiddles);

        reverse_matrix_index_bits(&mut mat);
        reverse_slice_index_bits(&mut inverse_twiddles);
        par_dit_layer_rev(&mut mat, coeff_split, &inverse_twiddles);

        // Rescale rows. The weight sequence starts at `1 / coeff_height` with base `shift`.
        // Rows are addressed through `reverse_bits` because no row bit-reversal follows the
        // second half above.
        let row_weights = Powers {
            base: shift,
            current: height_inv,
        }
        .take(coeff_height);
        row_weights.enumerate().for_each(|(row, weight)| {
            mat.scale_row(reverse_bits(row, coeff_height), weight);
        });

        mat = mat.bit_reversed_zero_pad(added_bits);

        // Forward pass over the padded height.
        let lde_height = mat.height();
        let lde_log_height = log2_strict_usize(lde_height);
        let lde_split = lde_log_height / 2;

        let lde_generator = F::two_adic_generator(lde_log_height);
        let mut lde_twiddles = lde_generator
            .powers()
            .take(lde_height / 2)
            .collect::<Vec<F>>();

        // No row bit-reversal here, unlike the first pass of `dft_batch`.
        par_dit_layer(&mut mat, lde_split, &lde_twiddles);

        reverse_matrix_index_bits(&mut mat);
        reverse_slice_index_bits(&mut lde_twiddles);
        par_dit_layer_rev(&mut mat, lde_split, &lde_twiddles);

        mat.bit_reverse_rows()
    }
}
/// Applies layers `0..mid` of the DIT network to `mat`.
///
/// The matrix is split into chunks of `2^mid` consecutive rows by
/// `par_row_chunks_exact_mut`, and every chunk receives all `mid` layers. The log-height of
/// the whole matrix is forwarded so that `dit_layer` can stride through the full twiddle table.
fn par_dit_layer<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
    let log_height = log2_strict_usize(mat.height());
    let rows_per_chunk = 1usize << mid;

    mat.par_row_chunks_exact_mut(rows_per_chunk)
        .for_each(|mut chunk| {
            (0..mid).for_each(|layer| dit_layer(&mut chunk, log_height, layer, twiddles));
        });
}
/// Applies layers `mid..log_h` of the network to `mat`, using bit-reversed twiddles.
///
/// The matrix is split into chunks of `2^(log_h - mid)` consecutive rows. For each layer, a
/// chunk is handed the tail of `twiddles_rev` that starts at the index of the chunk's first
/// block in that layer, so `dit_layer_rev` can index twiddles by chunk-local block number.
fn par_dit_layer_rev<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles_rev: &[F]) {
    let log_height = log2_strict_usize(mat.height());
    let remaining_layers = log_height - mid;
    let rows_per_chunk = 1usize << remaining_layers;

    mat.par_row_chunks_exact_mut(rows_per_chunk)
        .enumerate()
        .for_each(|(chunk_index, mut chunk)| {
            for depth in 0..remaining_layers {
                // At `depth` layers past `mid`, each chunk contains `2^depth` blocks.
                let layer = mid + depth;
                let first_block = chunk_index << depth;
                let chunk_twiddles = &twiddles_rev[first_block..];
                dit_layer_rev(&mut chunk, log_height, layer, chunk_twiddles);
            }
        });
}
/// Applies one DIT butterfly layer to `submat`.
///
/// Blocks are `2^(layer + 1)` rows tall. Within a block, row `i` of the upper half is paired
/// with row `i` of the lower half, using twiddle `twiddles[i << (log_h - 1 - layer)]`.
/// `log_h` is the log-height used for the twiddle stride; it is passed in rather than derived
/// from `submat`.
fn dit_layer<F: Field>(
    submat: &mut RowMajorMatrixViewMut<'_, F>,
    log_h: usize,
    layer: usize,
    twiddles: &[F],
) {
    let twiddle_shift = log_h - 1 - layer;
    let half = 1usize << layer;
    let span = 2 * half;
    debug_assert!(submat.height() >= span);

    let height = submat.height();
    let mut base = 0;
    while base < height {
        for offset in 0..half {
            let upper = base + offset;
            let lower = upper + half;
            // Read the twiddle before borrowing the two rows mutably.
            let butterfly = DitButterfly(twiddles[offset << twiddle_shift]);

            let (upper_row, lower_row) = submat.row_pair_mut(upper, lower);
            butterfly.apply_to_rows(upper_row, lower_row);
        }
        base += span;
    }
}
/// Applies one butterfly layer to `submat`, with a single twiddle per block.
///
/// Blocks are `2^(log_h - layer)` rows tall. Every row pair in block number `b` (counted from
/// the top of `submat`) uses `twiddles_rev[b]`, so the caller must pass a slice that starts at
/// the twiddle of this view's first block.
fn dit_layer_rev<F: Field>(
    submat: &mut RowMajorMatrixViewMut<'_, F>,
    log_h: usize,
    layer: usize,
    twiddles_rev: &[F],
) {
    let half = 1usize << (log_h - 1 - layer);
    let span = 2 * half;
    debug_assert!(submat.height() >= span);

    let block_starts = (0..submat.height()).step_by(span);
    for (block_index, start) in block_starts.enumerate() {
        // One butterfly per block; every row pair in the block shares it.
        let butterfly = DitButterfly(twiddles_rev[block_index]);
        for upper in start..start + half {
            let (upper_row, lower_row) = submat.row_pair_mut(upper, upper + half);
            butterfly.apply_to_rows(upper_row, lower_row);
        }
    }
}
// Unit tests for `Radix2DitParallel`.
//
// Every test delegates to a generic helper from `crate::testing`, instantiated with a field
// type and `Radix2DitParallel`. The helpers are not visible here; judging by their names they
// presumably compare against a naive reference implementation — confirm in `crate::testing`.
#[cfg(test)]
mod tests {
use p3_baby_bear::BabyBear;
use p3_goldilocks::Goldilocks;
use crate::testing::*;
use crate::Radix2DitParallel;
// DFT helper over BabyBear.
#[test]
fn dft_matches_naive() {
test_dft_matches_naive::<BabyBear, Radix2DitParallel>();
}
// Coset DFT helper over BabyBear.
#[test]
fn coset_dft_matches_naive() {
test_coset_dft_matches_naive::<BabyBear, Radix2DitParallel>();
}
// Inverse DFT helper; this one is instantiated with Goldilocks only.
#[test]
fn idft_matches_naive() {
test_idft_matches_naive::<Goldilocks, Radix2DitParallel>();
}
// Coset inverse DFT helper, run for both BabyBear and Goldilocks.
#[test]
fn coset_idft_matches_naive() {
test_coset_idft_matches_naive::<BabyBear, Radix2DitParallel>();
test_coset_idft_matches_naive::<Goldilocks, Radix2DitParallel>();
}
// LDE helper over BabyBear.
#[test]
fn lde_matches_naive() {
test_lde_matches_naive::<BabyBear, Radix2DitParallel>();
}
// Coset LDE helper over BabyBear.
#[test]
fn coset_lde_matches_naive() {
test_coset_lde_matches_naive::<BabyBear, Radix2DitParallel>();
}
// DFT/IDFT consistency helper over BabyBear (presumably a round-trip check — confirm).
#[test]
fn dft_idft_consistency() {
test_dft_idft_consistency::<BabyBear, Radix2DitParallel>();
}
}