use alloc::sync::Arc;
use alloc::vec::Vec;
use core::iter;
use itertools::Itertools;
use p3_field::{Field, TwoAdicField, scale_slice_in_place_single_core};
use p3_matrix::Matrix;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
use p3_matrix::util::reverse_matrix_index_bits;
use p3_maybe_rayon::prelude::*;
use p3_util::{as_base_slice, log2_strict_usize, reverse_slice_index_bits};
use spin::RwLock;
use crate::{
Butterfly, DifButterfly, DifButterflyZeros, DitButterfly, TwiddleFreeButterfly,
TwoAdicSubgroupDft,
};
/// Number of butterfly layers fused into one parallel pass by `dft_layer_par_triple`.
///
/// When the number of parallel layers is not a multiple of this, the 1 or 2 leftover layers
/// are handled separately by `dft_layer_par_extra_layers`.
const LAYERS_PER_GROUP: usize = 3;
/// Forward and inverse twiddle tables, stored together so both can be swapped in atomically.
///
/// Each table is a list of layers; `Arc<[_]>` makes cloning a snapshot a pair of
/// reference-count bumps rather than a deep copy.
#[derive(Clone, Debug)]
struct TwiddlePair<F> {
    twiddles: Arc<[Vec<F>]>,
    inv_twiddles: Arc<[Vec<F>]>,
}

/// An empty pair (no layers), written by hand so that no `F: Default` bound is required.
impl<F> Default for TwiddlePair<F> {
    fn default() -> Self {
        let forward: Arc<[Vec<F>]> = Vec::<Vec<F>>::new().into();
        let inverse: Arc<[Vec<F>]> = Vec::<Vec<F>>::new().into();
        Self {
            twiddles: forward,
            inv_twiddles: inverse,
        }
    }
}
/// Radix-2 DFT over two-adic subgroups (see its `TwoAdicSubgroupDft` impl) that memoizes its
/// twiddle tables behind a shared lock.
///
/// Cloning an instance clones the outer `Arc`, so clones share (and grow) the same cache.
#[derive(Default, Clone, Debug)]
pub struct Radix2DFTSmallBatch<F> {
    // Forward/inverse twiddle tables, grown on demand by `update_twiddles`.
    cache: Arc<RwLock<TwiddlePair<F>>>,
}
impl<F: TwoAdicField> Radix2DFTSmallBatch<F> {
    /// Creates an instance whose twiddle tables already cover transforms of length up to `n`.
    ///
    /// Larger transforms still work; their tables are computed lazily on first use.
    ///
    /// # Panics
    ///
    /// Panics if `n > 1` and `n` is not a power of two. This may also panic if `n` exceeds the
    /// field's two-adicity (NOTE(review): depends on `F::two_adic_generator` — confirm).
    pub fn new(n: usize) -> Self {
        let res = Self::default();
        res.update_twiddles(n);
        res
    }

    /// Builds the natural-order twiddle tables for a transform of length `n = 2^lg_n`.
    ///
    /// Let `g = F::two_adic_generator(lg_n)`. Entry `i` (for `0 <= i < lg_n`) holds the powers of
    /// `g^(2^i)` whose exponent of `g` is below `n / 2`: `[1, g^(2^i), g^(2·2^i), ...]`. That is
    /// `n / 2^(i+1)` values.
    ///
    /// Only the first half of each set of roots is stored, because the second half consists of
    /// their negations (`g^(n/2) = -1` for a generator of order `n`). Entry 0 is the longest
    /// table and the last entry has a single element.
    ///
    /// Returns an empty table when `n == 1`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two.
    fn roots_of_unity_table(&self, n: usize) -> Vec<Vec<F>> {
        let lg_n = log2_strict_usize(n);
        let generator = F::two_adic_generator(lg_n);
        // Identical to `1 << (lg_n - 1)` for n >= 2, but does not underflow when n == 1
        // (lg_n == 0). In that case nothing is collected and the table is empty.
        let half_n = n / 2;
        // nth_roots = [1, g, g^2, ..., g^(n/2 - 1)]
        let nth_roots = generator.powers().collect_n(half_n);
        (0..lg_n)
            .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
            .collect()
    }

    /// Ensures the cache can serve transforms of length `fft_len`. If it currently only covers
    /// shorter transforms, both tables are rebuilt for `fft_len`.
    ///
    /// The cache stores `log2(max_len)` layers. Transforms of a smaller height `h` use the last
    /// `log2(h)` layers.
    ///
    /// Every layer is stored in bit-reversed order. Each inverse layer is derived from the
    /// corresponding forward layer `ts` (with `m = ts.len()`, powers of a root `ω` with
    /// `ω^m = -1`) as `[1, -ts[m-1], ..., -ts[1]]`. This equals `[ω^0, ω^-1, ..., ω^-(m-1)]`
    /// because `ω^-k = -ω^(m-k)`.
    ///
    /// The tables are built without holding the lock. The size is then re-checked under the
    /// write lock, so a larger table installed concurrently is never replaced by a smaller one.
    ///
    /// # Panics
    ///
    /// Panics if a rebuild is needed and `fft_len` is not a power of two.
    fn update_twiddles(&self, fft_len: usize) {
        // The read guard is dropped at the end of this statement.
        let curr_max_fft_len = 1 << self.cache.read().twiddles.len();
        if fft_len > curr_max_fft_len {
            let mut new_twiddles = self.roots_of_unity_table(fft_len);
            let mut new_inv_twiddles: Vec<Vec<F>> = new_twiddles
                .iter()
                .map(|ts| {
                    iter::once(F::ONE)
                        .chain(ts[1..].iter().rev().map(|&f| -f))
                        .collect()
                })
                .collect();
            new_twiddles.iter_mut().for_each(|ts| {
                reverse_slice_index_bits(ts);
            });
            new_inv_twiddles.iter_mut().for_each(|ts| {
                reverse_slice_index_bits(ts);
            });
            let mut cache = self.cache.write();
            // Another thread may have installed a table at least this large meanwhile.
            let cur_have = 1usize << cache.twiddles.len();
            if fft_len > cur_have {
                cache.twiddles = Arc::from(new_twiddles);
                cache.inv_twiddles = Arc::from(new_inv_twiddles);
            }
        }
    }
}
impl<F> TwoAdicSubgroupDft<F> for Radix2DFTSmallBatch<F>
where
F: TwoAdicField,
{
type Evaluations = RowMajorMatrix<F>;
fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
let h = mat.height();
let w = mat.width();
let log_h = log2_strict_usize(h);
self.update_twiddles(h);
let g = self.cache.read().twiddles.clone(); let root_table = &g[g.len() - log_h..];
let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
let log_num_par_rows = log2_strict_usize(num_par_rows);
let chunk_size = num_par_rows * w;
let multi_layer_dit = MultiLayerDitButterfly {};
for (dit_0, dit_1, dit_2) in root_table[log_num_par_rows..]
.iter()
.rev()
.map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) .tuples()
{
dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
}
let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
dft_layer_par_extra_layers(
&mut mat.as_view_mut(),
&root_table[log_num_par_rows..log_num_par_rows + corr],
multi_layer_dit,
);
par_remaining_layers(&mut mat.values, chunk_size, &root_table[..log_num_par_rows]);
reverse_matrix_index_bits(&mut mat);
mat
}
fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
let h = mat.height();
let w = mat.width();
let log_h = log2_strict_usize(h);
self.update_twiddles(h);
let g = self.cache.read().inv_twiddles.clone(); let start = g
.len()
.checked_sub(log_h)
.expect("log_h exceeds inv_twiddles length");
let root_table = &g[start..];
let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
let log_num_par_rows = log2_strict_usize(num_par_rows);
let chunk_size = num_par_rows * w;
reverse_matrix_index_bits(&mut mat);
par_initial_layers(
&mut mat.values,
chunk_size,
&root_table[..log_num_par_rows],
log_h,
);
let multi_layer_dif = MultiLayerDifButterfly {};
let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
dft_layer_par_extra_layers(
&mut mat.as_view_mut(),
&root_table[log_num_par_rows..log_num_par_rows + corr],
multi_layer_dif,
);
for (dif_0, dif_1, dif_2) in root_table[(log_num_par_rows + corr)..]
.iter()
.map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) .tuples()
{
dft_layer_par_triple(&mut mat.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
}
mat
}
fn coset_lde_batch(
&self,
mut mat: RowMajorMatrix<F>,
added_bits: usize,
shift: F,
) -> Self::Evaluations {
let h = mat.height();
let w = mat.width();
let log_h = log2_strict_usize(h);
self.update_twiddles(h << added_bits);
let cached = self.cache.read().clone();
let g = &cached.twiddles;
let start = g
.len()
.checked_sub(log_h + added_bits)
.expect("log_h exceeds twiddles length");
let root_table = &g[start..];
let ig = &cached.inv_twiddles;
let start = ig
.len()
.checked_sub(log_h)
.expect("log_h exceeds inv_twiddles length");
let inv_root_table = &ig[start..];
let output_height = h << added_bits;
let output_values = F::zero_vec(output_height * w);
let mut out = RowMajorMatrix::new(output_values, w);
let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
let num_inner_dit_layers = log2_strict_usize(num_par_rows);
let num_inner_dif_layers = num_inner_dit_layers + added_bits;
let multi_layer_dit = MultiLayerDitButterfly {};
for (dit_0, dit_1, dit_2) in inv_root_table[num_inner_dit_layers..]
.iter()
.rev()
.map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) .tuples()
{
dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
}
let corr = (log_h - num_inner_dit_layers) % LAYERS_PER_GROUP;
dft_layer_par_extra_layers(
&mut mat.as_view_mut(),
&inv_root_table[num_inner_dit_layers..num_inner_dit_layers + corr],
multi_layer_dit,
);
par_middle_layers(
&mut mat.as_view_mut(),
&mut out.as_view_mut(),
num_par_rows,
&root_table[..(num_inner_dif_layers)],
&inv_root_table[..num_inner_dit_layers],
added_bits,
shift,
);
let multi_layer_dif = MultiLayerDifButterfly {};
dft_layer_par_extra_layers(
&mut out.as_view_mut(),
&root_table[num_inner_dif_layers..num_inner_dif_layers + corr],
multi_layer_dif,
);
for (dif_0, dif_1, dif_2) in root_table[(num_inner_dif_layers + corr)..]
.iter()
.map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) .tuples()
{
dft_layer_par_triple(&mut out.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
}
out
}
}
/// Applies one butterfly layer to `mat`, parallelising across blocks and within each block.
///
/// `mat.values` is split into `twiddles.len()` equal blocks. In block `ind`, the top half is
/// butterflied against the bottom half with `twiddles[ind]`. Block 0 uses
/// `TwiddleFreeButterfly`, since the first entry of every twiddle table is one.
#[inline]
fn dft_layer_par<F: Field, B: Butterfly<F>>(
    mat: &mut RowMajorMatrixViewMut<'_, F>,
    twiddles: &[B],
) {
    debug_assert!(
        mat.height().is_multiple_of(twiddles.len()),
        "Matrix height must be divisible by the number of twiddles"
    );
    let size = mat.values.len();
    let num_blocks = twiddles.len();
    let outer_block_size = size / num_blocks;
    let half_outer_block_size = outer_block_size / 2;

    // With few blocks, not every thread would get work, so each half-block is also split into
    // sub-chunks. The sub-chunk size depends only on the layer shape, so compute it once here
    // rather than once per block. `max(1)` keeps it non-zero (a zero chunk size panics) when
    // there are more threads than elements.
    let inner_block_size = (size / (2 * num_blocks).max(current_num_threads())).max(1);

    mat.values
        .par_chunks_exact_mut(outer_block_size)
        .enumerate()
        .for_each(|(ind, block)| {
            let (hi_chunk, lo_chunk) = block.split_at_mut(half_outer_block_size);
            // Both halves are split identically, so sub-chunk pairs stay aligned.
            hi_chunk
                .par_chunks_mut(inner_block_size)
                .zip(lo_chunk.par_chunks_mut(inner_block_size))
                .for_each(|(hi_chunk, lo_chunk)| {
                    if ind == 0 {
                        TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
                    } else {
                        twiddles[ind].apply_to_rows(hi_chunk, lo_chunk);
                    }
                });
        });
}
/// Runs `remaining_layers` on every `chunk_size`-element chunk of `mat`, one chunk per task.
///
/// Each chunk is handled independently, so no data moves between threads during these layers.
#[inline]
fn par_remaining_layers<F: Field>(mat: &mut [F], chunk_size: usize, root_table: &[Vec<F>]) {
    mat.par_chunks_exact_mut(chunk_size)
        .enumerate()
        .for_each(|(chunk_idx, rows)| remaining_layers(rows, root_table, chunk_idx));
}
/// Applies the layers in `root_table` to one chunk using `DitButterfly`s.
///
/// Layers run from the last table entry (fewest twiddles) down to entry 0. `index` is the
/// chunk's position in the matrix. In the `k`-th layer applied (counting from 0), the chunk
/// owns the `2^k` consecutive twiddles starting at `index * 2^k`.
fn remaining_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
    let depth = root_table.len();
    for step in 0..depth {
        let table = &root_table[depth - 1 - step];
        let per_chunk = 1 << step;
        let offset = index * per_chunk;
        // SAFETY: assumes `DitButterfly<F>` is a `#[repr(transparent)]` wrapper around `F`
        // (defined outside this file — confirm there).
        let layer: &[DitButterfly<F>] =
            unsafe { as_base_slice(&table[offset..offset + per_chunk]) };
        dft_layer(chunk, layer);
    }
}
/// Runs the first layers of the inverse transform on each `chunk_size`-element chunk of `mat`,
/// in parallel over chunks.
///
/// Each chunk is first multiplied by `1 / 2^log_height` and then passed to `initial_layers`.
/// Doing the normalisation here avoids a separate sweep over the whole matrix.
#[inline]
fn par_initial_layers<F: Field>(
    mat: &mut [F],
    chunk_size: usize,
    root_table: &[Vec<F>],
    log_height: usize,
) {
    let normaliser = F::ONE.div_2exp_u64(log_height as u64);
    mat.par_chunks_exact_mut(chunk_size)
        .enumerate()
        .for_each(|(chunk_idx, rows)| {
            scale_slice_in_place_single_core(rows, normaliser);
            initial_layers(rows, root_table, chunk_idx);
        });
}
/// Applies the layers in `root_table` to one chunk using `DifButterfly`s.
///
/// Layers run from entry 0 (most twiddles) to the last entry. `index` is the chunk's position
/// in the matrix. For entry `layer`, the chunk owns `2^(len - 1 - layer)` consecutive twiddles
/// starting at `index` times that count.
#[inline]
fn initial_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
    let rounds = root_table.len();
    for (table, log_per_chunk) in root_table.iter().zip((0..rounds).rev()) {
        let per_chunk = 1 << log_per_chunk;
        let offset = index * per_chunk;
        // SAFETY: assumes `DifButterfly<F>` is a `#[repr(transparent)]` wrapper around `F`
        // (defined outside this file — confirm there).
        let layer: &[DifButterfly<F>] =
            unsafe { as_base_slice(&table[offset..offset + per_chunk]) };
        dft_layer(chunk, layer);
    }
}
/// Fused middle stage of `coset_lde_batch`. It runs independently on each chunk of
/// `num_par_rows` input rows and the matching `num_par_rows << added_bits` output rows.
///
/// Per chunk:
/// 1. Finish the inverse transform with `remaining_layers` and `inv_root_table`. As in
///    `dft_batch` before its final bit-reversal, the rows are then in bit-reversed order, which
///    is why the scaling factors below are bit-reversed too.
/// 2. Multiply input row `r` by `scaling[r]`, where `scaling` is `shift^i / height` in
///    bit-reversed order. Write the result into the first row of the `r`-th group of
///    `2^added_bits` output rows. No other output row is written, so `out_mat` must already be
///    zero-filled.
/// 3. Apply the first `added_bits` forward layers with `dft_layer_zeros`. In those layers the
///    lower half of each visited block is still zero, and blocks made entirely of zero rows
///    are skipped.
/// 4. Apply the remaining forward layers of `root_table` with `initial_layers`.
fn par_middle_layers<F: Field>(
    in_mat: &mut RowMajorMatrixViewMut<'_, F>,
    out_mat: &mut RowMajorMatrixViewMut<'_, F>,
    num_par_rows: usize,
    root_table: &[Vec<F>],
    inv_root_table: &[Vec<F>],
    added_bits: usize,
    shift: F,
) {
    debug_assert_eq!(in_mat.width(), out_mat.width());
    debug_assert_eq!(in_mat.height() << added_bits, out_mat.height());
    let width = in_mat.width();
    let height = in_mat.height();
    let num_rounds = root_table.len();
    let in_chunk_size = num_par_rows * width;
    let out_chunk_size = in_chunk_size << added_bits;
    let log_height = log2_strict_usize(height);
    let inv_height = F::ONE.div_2exp_u64(log_height as u64);
    // [1/h, shift/h, shift^2/h, ...], bit-reversed to match the row order after step 1.
    let mut scaling = shift.shifted_powers(inv_height).collect_n(height);
    reverse_slice_index_bits(&mut scaling);
    in_mat
        .values
        .par_chunks_exact_mut(in_chunk_size)
        .zip(out_mat.values.par_chunks_exact_mut(out_chunk_size))
        // The scaling factors are only read, so borrow them immutably.
        .zip(scaling.par_chunks_exact(num_par_rows))
        .enumerate()
        .for_each(|(index, ((in_chunk, out_chunk), scale_chunk))| {
            remaining_layers(in_chunk, inv_root_table, index);
            in_chunk
                .chunks_exact(width)
                .zip(scale_chunk)
                .zip(out_chunk.chunks_exact_mut(width << added_bits))
                .for_each(|((in_row, &scale), out_row)| {
                    // Only the first `width` values (the first row of the group) are written.
                    out_row
                        .iter_mut()
                        .zip(in_row)
                        .for_each(|(out_val, &in_val)| *out_val = in_val * scale);
                });
            for (layer, twiddles) in root_table[..added_bits].iter().enumerate() {
                let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
                let start = index * num_twiddles_per_block;
                let twiddle_range = start..(start + num_twiddles_per_block);
                // SAFETY: assumes `DifButterflyZeros<F>` is a `#[repr(transparent)]` wrapper
                // around `F` (defined outside this file — confirm there).
                let dif_twiddles_zeros: &[DifButterflyZeros<F>] =
                    unsafe { as_base_slice(&twiddles[twiddle_range]) };
                // Non-zero blocks are every 2^(added_bits - layer - 1)-th block in this layer.
                dft_layer_zeros(out_chunk, dif_twiddles_zeros, added_bits - layer - 1);
            }
            initial_layers(out_chunk, &root_table[added_bits..], index);
        });
}
/// Applies one butterfly layer to `vec` on the current thread.
///
/// `vec` is cut into `twiddles.len()` equal consecutive blocks. Block `k` pairs its first half
/// with its second half through `twiddles[k].apply_to_rows`.
///
/// # Panics
///
/// Panics if `twiddles` is empty. In debug builds it also panics if `vec.len()` is not a
/// multiple of `twiddles.len()`.
#[inline]
fn dft_layer<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B]) {
    debug_assert_eq!(
        vec.len() % twiddles.len(),
        0,
        "Vector length must be divisible by the number of twiddles"
    );
    let block_len = vec.len() / twiddles.len();
    let half = block_len / 2;
    for (block, twiddle) in vec
        .chunks_exact_mut(block_len)
        .zip(twiddles.iter().copied())
    {
        let (upper, lower) = block.split_at_mut(half);
        twiddle.apply_to_rows(upper, lower);
    }
}
/// Applies two consecutive butterfly layers to `mat` in a single parallel pass.
///
/// `twiddles_small` is the layer with larger blocks: one twiddle per block of
/// `outer_block_size` elements. `twiddles_large` is the adjacent layer, with twice as many
/// twiddles and half the block size. `multi_butterfly` decides which of the two layers is
/// applied first (see `MultiLayerButterfly::apply_2_layers`).
///
/// Each outer block is cut into four equal quarters. The quarters are split in lockstep into
/// sub-chunks of `inner_chunk_size` elements. Each group of four corresponding sub-chunks is
/// passed as `((q0, q1), (q2, q3))`, so both layers are applied to one group within a single
/// task.
#[inline]
fn dft_layer_par_double<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
    mat: &mut RowMajorMatrixViewMut<'_, F>,
    twiddles_small: &[B],
    twiddles_large: &[B],
    multi_butterfly: M,
) {
    debug_assert!(
        mat.height().is_multiple_of(twiddles_small.len()),
        "Matrix height must be divisible by the number of twiddles"
    );
    let size = mat.values.len();
    let num_blocks = twiddles_small.len();
    let outer_block_size = size / num_blocks;
    let quarter_outer_block_size = outer_block_size / 4;
    // A group of four sub-chunks is about one L1 workload, but never more than a whole block.
    let inner_chunk_size =
        (workload_size::<F>().next_power_of_two() / 4).min(quarter_outer_block_size);
    mat.values
        .par_chunks_exact_mut(outer_block_size)
        .enumerate()
        .for_each(|(ind, block)| {
            // Four parallel iterators, one per quarter, each yielding sub-chunks.
            let chunk_par_iters_0 = block
                .chunks_exact_mut(quarter_outer_block_size)
                .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
                .collect::<Vec<_>>();
            // [q0, q1, q2, q3] -> [(q0, q1), (q2, q3)]
            let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
            // Zip the two halves to get ((q0, q1), (q2, q3)) per sub-chunk position.
            chunk_par_iters_1.into_iter().tuples().for_each(|(hi, lo)| {
                hi.zip(lo).for_each(|chunks| {
                    multi_butterfly.apply_2_layers(chunks, ind, twiddles_small, twiddles_large);
                });
            });
        });
}
/// Applies three consecutive butterfly layers to `mat` in a single parallel pass.
///
/// Twiddle tables are passed from the layer with the largest blocks (`twiddles_small`, one
/// twiddle per outer block) to the one with the smallest (`twiddles_large`, four per outer
/// block). `multi_butterfly` decides the application order (see
/// `MultiLayerButterfly::apply_3_layers`).
///
/// Each outer block is cut into eight equal parts, which are split in lockstep into sub-chunks.
/// Each group of eight corresponding sub-chunks is passed as
/// `(((c0, c1), (c2, c3)), ((c4, c5), (c6, c7)))`.
#[inline]
fn dft_layer_par_triple<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
    mat: &mut RowMajorMatrixViewMut<'_, F>,
    twiddles_small: &[B],
    twiddles_med: &[B],
    twiddles_large: &[B],
    multi_butterfly: M,
) {
    debug_assert!(
        mat.height().is_multiple_of(twiddles_small.len()),
        "Matrix height must be divisible by the number of twiddles"
    );
    let size = mat.values.len();
    let num_blocks = twiddles_small.len();
    let outer_block_size = size / num_blocks;
    let eighth_outer_block_size = outer_block_size / 8;
    // A group of eight sub-chunks is about one L1 workload, but never more than a whole block.
    let inner_chunk_size =
        (workload_size::<F>().next_power_of_two() / 8).min(eighth_outer_block_size);
    mat.values
        .par_chunks_exact_mut(outer_block_size)
        .enumerate()
        .for_each(|(ind, block)| {
            // Eight parallel iterators, one per eighth of the block.
            let chunk_par_iters_0 = block
                .chunks_exact_mut(eighth_outer_block_size)
                .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
                .collect::<Vec<_>>();
            // [c0..c7] -> [(c0,c1), (c2,c3), (c4,c5), (c6,c7)]
            let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
            // -> [((c0,c1),(c2,c3)), ((c4,c5),(c6,c7))]
            let chunk_par_iters_2 = zip_par_iter_vec(chunk_par_iters_1);
            // Zip the two halves into the full triple-layer decomposition.
            chunk_par_iters_2.into_iter().tuples().for_each(|(hi, lo)| {
                hi.zip(lo).for_each(|chunks| {
                    multi_butterfly.apply_3_layers(
                        chunks,
                        ind,
                        twiddles_small,
                        twiddles_med,
                        twiddles_large,
                    );
                });
            });
        });
}
/// Applies the 0, 1 or 2 layers left over after the parallel layers are grouped into threes.
///
/// `root_table` holds the raw twiddle vectors of those layers in table order; index 0 has the
/// most twiddles. A single layer goes through `dft_layer_par`. Two layers are fused with
/// `dft_layer_par_double`, with the shorter table (index 1) passed as `twiddles_small`.
///
/// # Panics
///
/// Panics if `root_table` has more than two entries.
fn dft_layer_par_extra_layers<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
    mat: &mut RowMajorMatrixViewMut<'_, F>,
    root_table: &[Vec<F>],
    multi_layer: M,
) {
    match root_table {
        [] => {}
        [only] => {
            // SAFETY: assumes `B` is a `#[repr(transparent)]` wrapper around `F` for every
            // butterfly type this is instantiated with — confirm at the definitions.
            let layer: &[B] = unsafe { as_base_slice(only) };
            dft_layer_par(&mut mat.as_view_mut(), layer);
        }
        [longer, shorter] => {
            // SAFETY: as above.
            let larger_layer: &[B] = unsafe { as_base_slice(longer) };
            let smaller_layer: &[B] = unsafe { as_base_slice(shorter) };
            dft_layer_par_double(
                &mut mat.as_view_mut(),
                smaller_layer,
                larger_layer,
                multi_layer,
            );
        }
        _ => unreachable!("The number of layers must be 0, 1 or 2"),
    }
}
/// Like `dft_layer`, but only visits every `2^skip`-th block (together with its twiddle).
///
/// The skipped blocks are those the caller knows to be entirely zero. A butterfly on them
/// would leave them unchanged, so they can be ignored.
#[inline]
fn dft_layer_zeros<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B], skip: usize) {
    debug_assert_eq!(
        vec.len() % twiddles.len(),
        0,
        "Vector length must be divisible by the number of twiddles"
    );
    let block_len = vec.len() / twiddles.len();
    let half = block_len / 2;
    let stride = 1 << skip;
    for (block, twiddle) in vec
        .chunks_exact_mut(block_len)
        .zip(twiddles.iter().copied())
        .step_by(stride)
    {
        let (upper, lower) = block.split_at_mut(half);
        twiddle.apply_to_rows(upper, lower);
    }
}
/// Four equal, consecutive sub-chunks `((q0, q1), (q2, q3))` of one block, in memory order.
type DoubleLayerBlockDecomposition<'a, F> =
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F]));

/// One layer across the two halves of the block with a single twiddle: `butterfly` is applied
/// to `(q0, q2)` and to `(q1, q3)`.
#[inline]
fn fft_double_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
    block: &mut DoubleLayerBlockDecomposition<'_, F>,
    butterfly: Fly,
) {
    let ((q0, q1), (q2, q3)) = block;
    butterfly.apply_to_rows(q0, q2);
    butterfly.apply_to_rows(q1, q3);
}

/// One layer inside each half of the block, with a twiddle per half: `fly0` on `(q0, q1)` and
/// `fly1` on `(q2, q3)`.
#[inline]
fn fft_double_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
    block: &mut DoubleLayerBlockDecomposition<'_, F>,
    fly0: Fly0,
    fly1: Fly1,
) {
    let ((q0, q1), (q2, q3)) = block;
    fly0.apply_to_rows(q0, q1);
    fly1.apply_to_rows(q2, q3);
}
/// Eight equal, consecutive sub-chunks `(((c0, c1), (c2, c3)), ((c4, c5), (c6, c7)))` of one
/// block, in memory order.
type TripleLayerBlockDecomposition<'a, F> = (
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
);

/// One layer across the two halves of the block with a single twiddle: pairs `c_k` with
/// `c_(k+4)`.
#[inline]
fn fft_triple_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
    block: &mut TripleLayerBlockDecomposition<'_, F>,
    butterfly: Fly,
) {
    let (((c0, c1), (c2, c3)), ((c4, c5), (c6, c7))) = block;
    butterfly.apply_to_rows(c0, c4);
    butterfly.apply_to_rows(c1, c5);
    butterfly.apply_to_rows(c2, c6);
    butterfly.apply_to_rows(c3, c7);
}

/// One layer inside each half with a twiddle per half: `fly0` pairs `(c0, c2)` and `(c1, c3)`,
/// and `fly1` pairs `(c4, c6)` and `(c5, c7)`.
#[inline]
fn fft_triple_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
    block: &mut TripleLayerBlockDecomposition<'_, F>,
    fly0: Fly0,
    fly1: Fly1,
) {
    let (((c0, c1), (c2, c3)), ((c4, c5), (c6, c7))) = block;
    fly0.apply_to_rows(c0, c2);
    fly0.apply_to_rows(c1, c3);
    fly1.apply_to_rows(c4, c6);
    fly1.apply_to_rows(c5, c7);
}

/// One layer inside each quarter with a twiddle per quarter: `fly0` on `(c0, c1)`, and
/// `butterflies[0..3]` on `(c2, c3)`, `(c4, c5)` and `(c6, c7)` respectively.
#[inline]
fn fft_triple_layer_quad_twiddle<F: Field, Fly0: Butterfly<F>, Flies: Butterfly<F>>(
    block: &mut TripleLayerBlockDecomposition<'_, F>,
    fly0: Fly0,
    butterflies: &[Flies],
) {
    debug_assert!(butterflies.len() == 3);
    let (((c0, c1), (c2, c3)), ((c4, c5), (c6, c7))) = block;
    fly0.apply_to_rows(c0, c1);
    butterflies[0].apply_to_rows(c2, c3);
    butterflies[1].apply_to_rows(c4, c5);
    butterflies[2].apply_to_rows(c6, c7);
}
/// Number of `T` values that fit in an assumed 32 KiB L1 data cache.
#[must_use]
const fn workload_size<T: Sized>() -> usize {
    const L1_CACHE_BYTES: usize = 32 * 1024;
    let bytes_per_item = size_of::<T>();
    L1_CACHE_BYTES / bytes_per_item
}

/// Estimates how many rows of a `width`-column matrix of `T` fit in L1 cache.
///
/// The estimate is rounded up to a power of two and capped at `height`. It is therefore a
/// power of two whenever `height` is. Because it rounds up, the chosen rows may slightly exceed
/// the cache. A width larger than the whole workload yields one row.
///
/// # Panics
///
/// Panics if `width == 0` (division by zero).
#[must_use]
fn estimate_num_rows_in_l1<T: Sized>(height: usize, width: usize) -> usize {
    let rows_in_cache = workload_size::<T>() / width;
    let rounded = rows_in_cache.next_power_of_two();
    if rounded < height { rounded } else { height }
}
/// Zips consecutive pairs of parallel iterators: `[a, b, c, d]` becomes `[a.zip(b), c.zip(d)]`.
///
/// A trailing unpaired iterator (odd-length input) is dropped, because `Itertools::tuples`
/// only yields complete pairs.
#[inline]
fn zip_par_iter_vec<I: IndexedParallelIterator>(
    in_vec: Vec<I>,
) -> Vec<impl IndexedParallelIterator<Item = (I::Item, I::Item)>> {
    let mut zipped = Vec::with_capacity(in_vec.len() / 2);
    for (first, second) in in_vec.into_iter().tuples::<(I, I)>() {
        zipped.push(first.zip(second));
    }
    zipped
}
/// Strategy for applying two or three adjacent butterfly layers to one block decomposition in
/// a single pass.
///
/// `ind` is the index of the outer block, i.e. the index into `twiddles_small`. In both
/// implementations in this file, the finer layers use `twiddles_med[2*ind..2*ind+2]` and
/// `twiddles_large[4*ind..4*ind+4]` (or `twiddles_large[2*ind..2*ind+2]` for two layers).
/// The `Copy + Send + Sync` bounds let implementors be captured by the parallel closures in
/// `dft_layer_par_double`/`dft_layer_par_triple`.
trait MultiLayerButterfly<F: Field, B: Butterfly<F>>: Copy + Send + Sync {
    /// Applies the two layers described by `twiddles_small`/`twiddles_large` to one
    /// `((q0, q1), (q2, q3))` decomposition.
    fn apply_2_layers(
        &self,
        chunk_decomposition: DoubleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[B],
        twiddles_large: &[B],
    );
    /// Applies the three layers described by `twiddles_small`/`twiddles_med`/`twiddles_large`
    /// to one eight-way decomposition.
    fn apply_3_layers(
        &self,
        chunk_decomposition: TripleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[B],
        twiddles_med: &[B],
        twiddles_large: &[B],
    );
}
/// Multi-layer strategy using `DitButterfly` twiddles. Within a fused group it applies the
/// layer with the fewest twiddles (largest blocks) first and the finest layer last.
#[derive(Debug, Clone, Copy)]
struct MultiLayerDitButterfly;
impl<F: Field> MultiLayerButterfly<F, DitButterfly<F>> for MultiLayerDitButterfly {
    /// Coarse layer across the halves (`twiddles_small[ind]`), then one layer inside each half
    /// (`twiddles_large[2*ind]`, `twiddles_large[2*ind + 1]`).
    #[inline]
    fn apply_2_layers(
        &self,
        mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[DitButterfly<F>],
        twiddles_large: &[DitButterfly<F>],
    ) {
        if ind == 0 {
            // Index 0 of every (bit-reversed) table is 1, so skip the multiplication there.
            fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
            fft_double_layer_double_twiddle(
                &mut blk_decomp,
                TwiddleFreeButterfly,
                twiddles_large[1],
            );
        } else {
            fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
            fft_double_layer_double_twiddle(
                &mut blk_decomp,
                twiddles_large[2 * ind],
                twiddles_large[2 * ind + 1],
            );
        }
    }
    /// Applies, in order:
    /// - the coarse layer (`twiddles_small[ind]`);
    /// - the middle layer (`twiddles_med[2*ind..2*ind+2]`);
    /// - the finest layer (`twiddles_large[4*ind..4*ind+4]`).
    #[inline]
    fn apply_3_layers(
        &self,
        mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[DitButterfly<F>],
        twiddles_med: &[DitButterfly<F>],
        twiddles_large: &[DitButterfly<F>],
    ) {
        if ind == 0 {
            // The first twiddle of each layer in block 0 is 1; use the twiddle-free variant.
            fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
            fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
            fft_triple_layer_quad_twiddle(
                &mut blk_decomp,
                TwiddleFreeButterfly,
                &twiddles_large[1..4],
            );
        } else {
            fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
            fft_triple_layer_double_twiddle(
                &mut blk_decomp,
                twiddles_med[2 * ind],
                twiddles_med[2 * ind + 1],
            );
            fft_triple_layer_quad_twiddle(
                &mut blk_decomp,
                twiddles_large[4 * ind],
                &twiddles_large[4 * ind + 1..4 * (ind + 1)],
            );
        }
    }
}
/// Multi-layer strategy using `DifButterfly` twiddles. It applies the layers of a fused group
/// in the reverse order of `MultiLayerDitButterfly`: finest layer (most twiddles) first,
/// coarse layer last.
#[derive(Debug, Clone, Copy)]
struct MultiLayerDifButterfly;
impl<F: Field> MultiLayerButterfly<F, DifButterfly<F>> for MultiLayerDifButterfly {
    /// One layer inside each half (`twiddles_large[2*ind]`, `twiddles_large[2*ind + 1]`), then
    /// the coarse layer across the halves (`twiddles_small[ind]`).
    #[inline]
    fn apply_2_layers(
        &self,
        mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[DifButterfly<F>],
        twiddles_large: &[DifButterfly<F>],
    ) {
        if ind == 0 {
            // Index 0 of every (bit-reversed) table is 1, so skip the multiplication there.
            fft_double_layer_double_twiddle(
                &mut blk_decomp,
                TwiddleFreeButterfly,
                twiddles_large[1],
            );
            fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
        } else {
            fft_double_layer_double_twiddle(
                &mut blk_decomp,
                twiddles_large[2 * ind],
                twiddles_large[2 * ind + 1],
            );
            fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
        }
    }
    /// Applies, in order:
    /// - the finest layer (`twiddles_large[4*ind..4*ind+4]`);
    /// - the middle layer (`twiddles_med[2*ind..2*ind+2]`);
    /// - the coarse layer (`twiddles_small[ind]`).
    #[inline]
    fn apply_3_layers(
        &self,
        mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[DifButterfly<F>],
        twiddles_med: &[DifButterfly<F>],
        twiddles_large: &[DifButterfly<F>],
    ) {
        if ind == 0 {
            // The first twiddle of each layer in block 0 is 1; use the twiddle-free variant.
            fft_triple_layer_quad_twiddle(
                &mut blk_decomp,
                TwiddleFreeButterfly,
                &twiddles_large[1..4],
            );
            fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
            fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
        } else {
            fft_triple_layer_quad_twiddle(
                &mut blk_decomp,
                twiddles_large[4 * ind],
                &twiddles_large[4 * ind + 1..4 * (ind + 1)],
            );
            fft_triple_layer_double_twiddle(
                &mut blk_decomp,
                twiddles_med[2 * ind],
                twiddles_med[2 * ind + 1],
            );
            fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
        }
    }
}