1use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::iter;
6
7use itertools::Itertools;
8use p3_field::{Field, TwoAdicField, scale_slice_in_place_single_core};
9use p3_matrix::Matrix;
10use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut};
11use p3_matrix::util::reverse_matrix_index_bits;
12use p3_maybe_rayon::prelude::*;
13use p3_util::{as_base_slice, log2_strict_usize, reverse_slice_index_bits};
14use spin::RwLock;
15
16use crate::{
17 Butterfly, DifButterfly, DifButterflyZeros, DitButterfly, TwiddleFreeButterfly,
18 TwoAdicSubgroupDft,
19};
20
/// Number of butterfly layers grouped into one whole-matrix parallel pass.
///
/// The callers process layers in triples via `dft_layer_par_triple` and reduce
/// the layer count modulo this value to find how many leftover layers must be
/// handled by `dft_layer_par_extra_layers`.
const LAYERS_PER_GROUP: usize = 3;
23
/// Forward and inverse twiddle tables that are cached and replaced together.
///
/// Both tables are produced by `Radix2DFTSmallBatch::update_twiddles`, which
/// stores every inner vector in bit-reversed order.
#[derive(Clone, Debug)]
struct TwiddlePair<F> {
    /// Table used by the forward butterflies; one vector per layer.
    twiddles: Arc<[Vec<F>]>,
    /// Table derived from `twiddles` (leading one kept, remaining entries
    /// reversed and negated); one vector per layer.
    inv_twiddles: Arc<[Vec<F>]>,
}
32
33impl<F> Default for TwiddlePair<F> {
34 fn default() -> Self {
35 Self {
36 twiddles: Arc::from(Vec::new()),
37 inv_twiddles: Arc::from(Vec::new()),
38 }
39 }
40}
41
/// A radix-2 DFT that memoizes the twiddle tables it needs.
///
/// The tables live behind an `Arc<RwLock<..>>`, so clones of this value share
/// a single cache.
#[derive(Default, Clone, Debug)]
pub struct Radix2DFTSmallBatch<F> {
    /// Shared cache of forward and inverse twiddle tables.
    ///
    /// `update_twiddles` replaces its contents whenever a transform longer
    /// than the cache currently supports is requested.
    cache: Arc<RwLock<TwiddlePair<F>>>,
}
64
65impl<F: TwoAdicField> Radix2DFTSmallBatch<F> {
66 pub fn new(n: usize) -> Self {
70 let res = Self::default();
71 res.update_twiddles(n);
72 res
73 }
74
75 fn roots_of_unity_table(&self, n: usize) -> Vec<Vec<F>> {
83 let lg_n = log2_strict_usize(n);
84 let generator = F::two_adic_generator(lg_n);
85 let half_n = 1 << (lg_n - 1);
86 let nth_roots = generator.powers().collect_n(half_n);
88
89 (0..lg_n)
90 .map(|i| nth_roots.iter().step_by(1 << i).copied().collect())
91 .collect()
92 }
93
94 fn update_twiddles(&self, fft_len: usize) {
96 let curr_max_fft_len = 1 << self.cache.read().twiddles.len();
102 if fft_len > curr_max_fft_len {
103 let mut new_twiddles = self.roots_of_unity_table(fft_len);
104 let mut new_inv_twiddles: Vec<Vec<F>> = new_twiddles
105 .iter()
106 .map(|ts| {
107 iter::once(F::ONE)
110 .chain(ts[1..].iter().rev().map(|&f| -f))
111 .collect()
112 })
113 .collect();
114
115 new_twiddles.iter_mut().for_each(|ts| {
116 reverse_slice_index_bits(ts);
117 });
118 new_inv_twiddles.iter_mut().for_each(|ts| {
119 reverse_slice_index_bits(ts);
120 });
121
122 let mut cache = self.cache.write();
125 let cur_have = 1usize << cache.twiddles.len();
126 if fft_len > cur_have {
127 cache.twiddles = Arc::from(new_twiddles);
128 cache.inv_twiddles = Arc::from(new_inv_twiddles);
129 }
130 }
131 }
132}
133
134impl<F> TwoAdicSubgroupDft<F> for Radix2DFTSmallBatch<F>
135where
136 F: TwoAdicField,
137{
138 type Evaluations = RowMajorMatrix<F>;
139
140 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
141 let h = mat.height();
142 let w = mat.width();
143 let log_h = log2_strict_usize(h);
144
145 self.update_twiddles(h);
146 let g = self.cache.read().twiddles.clone(); let root_table = &g[g.len() - log_h..];
148
149 let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
154 let log_num_par_rows = log2_strict_usize(num_par_rows);
155 let chunk_size = num_par_rows * w;
156
157 let multi_layer_dit = MultiLayerDitButterfly {};
161
162 for (dit_0, dit_1, dit_2) in root_table[log_num_par_rows..]
165 .iter()
166 .rev()
167 .map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) .tuples()
169 {
170 dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
171 }
172
173 let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
176 dft_layer_par_extra_layers(
177 &mut mat.as_view_mut(),
178 &root_table[log_num_par_rows..log_num_par_rows + corr],
179 multi_layer_dit,
180 );
181
182 par_remaining_layers(&mut mat.values, chunk_size, &root_table[..log_num_par_rows]);
186
187 reverse_matrix_index_bits(&mut mat);
189 mat
190 }
191
192 fn idft_batch(&self, mut mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
193 let h = mat.height();
194 let w = mat.width();
195 let log_h = log2_strict_usize(h);
196
197 self.update_twiddles(h);
198 let g = self.cache.read().inv_twiddles.clone(); let start = g
200 .len()
201 .checked_sub(log_h)
202 .expect("log_h exceeds inv_twiddles length");
203 let root_table = &g[start..];
204
205 let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
211 let log_num_par_rows = log2_strict_usize(num_par_rows);
212 let chunk_size = num_par_rows * w;
213
214 reverse_matrix_index_bits(&mut mat);
216
217 par_initial_layers(
223 &mut mat.values,
224 chunk_size,
225 &root_table[..log_num_par_rows],
226 log_h,
227 );
228
229 let multi_layer_dif = MultiLayerDifButterfly {};
233
234 let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP;
237 dft_layer_par_extra_layers(
238 &mut mat.as_view_mut(),
239 &root_table[log_num_par_rows..log_num_par_rows + corr],
240 multi_layer_dif,
241 );
242
243 for (dif_0, dif_1, dif_2) in root_table[(log_num_par_rows + corr)..]
246 .iter()
247 .map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) .tuples()
249 {
250 dft_layer_par_triple(&mut mat.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
251 }
252
253 mat
254 }
255
256 fn coset_lde_batch(
257 &self,
258 mut mat: RowMajorMatrix<F>,
259 added_bits: usize,
260 shift: F,
261 ) -> Self::Evaluations {
262 let h = mat.height();
263 let w = mat.width();
264 let log_h = log2_strict_usize(h);
265
266 self.update_twiddles(h << added_bits);
267 let cached = self.cache.read().clone();
268 let g = &cached.twiddles;
269 let start = g
270 .len()
271 .checked_sub(log_h + added_bits)
272 .expect("log_h exceeds twiddles length");
273 let root_table = &g[start..];
274 let ig = &cached.inv_twiddles;
275 let start = ig
276 .len()
277 .checked_sub(log_h)
278 .expect("log_h exceeds inv_twiddles length");
279 let inv_root_table = &ig[start..];
280 let output_height = h << added_bits;
281
282 let output_values = F::zero_vec(output_height * w);
284 let mut out = RowMajorMatrix::new(output_values, w);
285
286 let num_par_rows = estimate_num_rows_in_l1::<F>(h, w);
303 let num_inner_dit_layers = log2_strict_usize(num_par_rows);
304 let num_inner_dif_layers = num_inner_dit_layers + added_bits;
305
306 let multi_layer_dit = MultiLayerDitButterfly {};
309 for (dit_0, dit_1, dit_2) in inv_root_table[num_inner_dit_layers..]
310 .iter()
311 .rev()
312 .map(|slice| unsafe { as_base_slice::<DitButterfly<F>, F>(slice) }) .tuples()
314 {
315 dft_layer_par_triple(&mut mat.as_view_mut(), dit_0, dit_1, dit_2, multi_layer_dit);
316 }
317
318 let corr = (log_h - num_inner_dit_layers) % LAYERS_PER_GROUP;
321 dft_layer_par_extra_layers(
322 &mut mat.as_view_mut(),
323 &inv_root_table[num_inner_dit_layers..num_inner_dit_layers + corr],
324 multi_layer_dit,
325 );
326
327 par_middle_layers(
331 &mut mat.as_view_mut(),
332 &mut out.as_view_mut(),
333 num_par_rows,
334 &root_table[..(num_inner_dif_layers)],
335 &inv_root_table[..num_inner_dit_layers],
336 added_bits,
337 shift,
338 );
339
340 let multi_layer_dif = MultiLayerDifButterfly {};
342
343 dft_layer_par_extra_layers(
346 &mut out.as_view_mut(),
347 &root_table[num_inner_dif_layers..num_inner_dif_layers + corr],
348 multi_layer_dif,
349 );
350
351 for (dif_0, dif_1, dif_2) in root_table[(num_inner_dif_layers + corr)..]
354 .iter()
355 .map(|slice| unsafe { as_base_slice::<DifButterfly<F>, F>(slice) }) .tuples()
357 {
358 dft_layer_par_triple(&mut out.as_view_mut(), dif_2, dif_1, dif_0, multi_layer_dif);
359 }
360
361 out
362 }
363}
364
365#[inline]
378fn dft_layer_par<F: Field, B: Butterfly<F>>(
379 mat: &mut RowMajorMatrixViewMut<'_, F>,
380 twiddles: &[B],
381) {
382 debug_assert!(
383 mat.height().is_multiple_of(twiddles.len()),
384 "Matrix height must be divisible by the number of twiddles"
385 );
386 let size = mat.values.len();
387 let num_blocks = twiddles.len();
388
389 let outer_block_size = size / num_blocks;
390 let half_outer_block_size = outer_block_size / 2;
391
392 mat.values
393 .par_chunks_exact_mut(outer_block_size)
394 .enumerate()
395 .for_each(|(ind, block)| {
396 let (hi_chunk, lo_chunk) = block.split_at_mut(half_outer_block_size);
398
399 let num_threads = current_num_threads();
401 let inner_block_size = size / (2 * num_blocks).max(num_threads);
402
403 hi_chunk
404 .par_chunks_mut(inner_block_size)
405 .zip(lo_chunk.par_chunks_mut(inner_block_size))
406 .for_each(|(hi_chunk, lo_chunk)| {
407 if ind == 0 {
408 TwiddleFreeButterfly.apply_to_rows(hi_chunk, lo_chunk);
410 } else {
411 twiddles[ind].apply_to_rows(hi_chunk, lo_chunk);
413 }
414 });
415 });
416}
417
418#[inline]
423fn par_remaining_layers<F: Field>(mat: &mut [F], chunk_size: usize, root_table: &[Vec<F>]) {
424 mat.par_chunks_exact_mut(chunk_size)
425 .enumerate()
426 .for_each(|(index, chunk)| {
427 remaining_layers(chunk, root_table, index);
428 });
429}
430
431fn remaining_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
433 for (layer, twiddles) in root_table.iter().rev().enumerate() {
434 let num_twiddles_per_block = 1 << layer;
435 let start = index * num_twiddles_per_block;
436 let twiddle_range = start..(start + num_twiddles_per_block);
437 let dit_twiddles: &[DitButterfly<F>] = unsafe { as_base_slice(&twiddles[twiddle_range]) };
439 dft_layer(chunk, dit_twiddles);
440 }
441}
442
443#[inline]
451fn par_initial_layers<F: Field>(
452 mat: &mut [F],
453 chunk_size: usize,
454 root_table: &[Vec<F>],
455 log_height: usize,
456) {
457 let inv_height = F::ONE.div_2exp_u64(log_height as u64);
458 mat.par_chunks_exact_mut(chunk_size)
459 .enumerate()
460 .for_each(|(index, chunk)| {
461 scale_slice_in_place_single_core(chunk, inv_height);
463 initial_layers(chunk, root_table, index);
464 });
465}
466
467#[inline]
469fn initial_layers<F: Field>(chunk: &mut [F], root_table: &[Vec<F>], index: usize) {
470 let num_rounds = root_table.len();
471
472 for (layer, twiddles) in root_table.iter().enumerate() {
473 let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
474 let start = index * num_twiddles_per_block;
475 let twiddle_range = start..(start + num_twiddles_per_block);
476 let dif_twiddles: &[DifButterfly<F>] = unsafe { as_base_slice(&twiddles[twiddle_range]) };
478 dft_layer(chunk, dif_twiddles);
479 }
480}
481
/// Performs the chunk-local middle phase of `coset_lde_batch`.
///
/// For every chunk of `num_par_rows` input rows, in parallel:
/// 1. finish the inverse-table layers on the input chunk (`remaining_layers`);
/// 2. multiply each input row by its entry of `scaling` and write it to the
///    first `width` values of the matching group of `2^added_bits` output rows
///    (the rest of the group is left untouched);
/// 3. run the first `added_bits` forward layers with `dft_layer_zeros`;
/// 4. run the remaining tables of `root_table` with `initial_layers`.
///
/// # Arguments
/// * `in_mat` - input matrix, modified in place by step 1.
/// * `out_mat` - output matrix with `in_mat.height() << added_bits` rows and
///   the same width.
/// * `num_par_rows` - number of input rows per chunk.
/// * `root_table` - forward tables for the chunk-local layers.
/// * `inv_root_table` - inverse tables for the chunk-local layers.
/// * `added_bits` - log2 of the height expansion factor.
/// * `shift` - base of the per-row scaling factors.
fn par_middle_layers<F: Field>(
    in_mat: &mut RowMajorMatrixViewMut<'_, F>,
    out_mat: &mut RowMajorMatrixViewMut<'_, F>,
    num_par_rows: usize,
    root_table: &[Vec<F>],
    inv_root_table: &[Vec<F>],
    added_bits: usize,
    shift: F,
) {
    debug_assert_eq!(in_mat.width(), out_mat.width());
    debug_assert_eq!(in_mat.height() << added_bits, out_mat.height());

    let width = in_mat.width();
    let height = in_mat.height();
    let num_rounds = root_table.len();
    let in_chunk_size = num_par_rows * width;
    let out_chunk_size = in_chunk_size << added_bits;

    let log_height = log2_strict_usize(height);
    let inv_height = F::ONE.div_2exp_u64(log_height as u64);

    // One scaling factor per input row, reordered to bit-reversed row order.
    // NOTE(review): presumably entry `i` is `inv_height * shift^i` before the
    // reordering; confirm against `shifted_powers`.
    let mut scaling = shift.shifted_powers(inv_height).collect_n(height);
    reverse_slice_index_bits(&mut scaling);

    in_mat
        .values
        .par_chunks_exact_mut(in_chunk_size)
        .zip(out_mat.values.par_chunks_exact_mut(out_chunk_size))
        .zip(scaling.par_chunks_exact_mut(num_par_rows))
        .enumerate()
        .for_each(|(index, ((in_chunk, out_chunk), scaling))| {
            // Finish the inverse-table layers that stay inside this chunk.
            remaining_layers(in_chunk, inv_root_table, index);

            // Scale every input row and copy it to the start of its group of
            // `2^added_bits` output rows. The inner `zip` stops after `width`
            // values, so the rest of the group is not written.
            in_chunk
                .chunks_exact(width)
                .zip(scaling)
                .zip(out_chunk.chunks_exact_mut(width << added_bits))
                .for_each(|((in_row, scale), out_row)| {
                    out_row
                        .iter_mut()
                        .zip(in_row.iter())
                        .for_each(|(out_val, in_val)| {
                            *out_val = *in_val * *scale;
                        });
                });

            // First `added_bits` forward layers; `dft_layer_zeros` visits only
            // every `2^(added_bits - layer - 1)`-th block.
            for (layer, twiddles) in root_table[..added_bits].iter().enumerate() {
                let num_twiddles_per_block = 1 << (num_rounds - layer - 1);
                let start = index * num_twiddles_per_block;
                let twiddle_range = start..(start + num_twiddles_per_block);

                // SAFETY: relies on `DifButterflyZeros<F>` having the same
                // memory layout as `F`. Its definition lives outside this file;
                // confirm it remains a transparent wrapper.
                let dif_twiddles_zeros: &[DifButterflyZeros<F>] =
                    unsafe { as_base_slice(&twiddles[twiddle_range]) };
                dft_layer_zeros(out_chunk, dif_twiddles_zeros, added_bits - layer - 1);
            }

            // Remaining chunk-local forward layers.
            initial_layers(out_chunk, &root_table[added_bits..], index);
        });
}
550
551#[inline]
560fn dft_layer<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B]) {
561 debug_assert_eq!(
562 vec.len() % twiddles.len(),
563 0,
564 "Vector length must be divisible by the number of twiddles"
565 );
566 let size = vec.len();
567 let num_blocks = twiddles.len();
568
569 let block_size = size / num_blocks;
570 let half_block_size = block_size / 2;
571
572 vec.chunks_exact_mut(block_size)
573 .zip(twiddles)
574 .for_each(|(block, &twiddle)| {
575 let (hi_chunk, lo_chunk) = block.split_at_mut(half_block_size);
577
578 twiddle.apply_to_rows(hi_chunk, lo_chunk);
580 });
581}
582
583#[inline]
595fn dft_layer_par_double<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
596 mat: &mut RowMajorMatrixViewMut<'_, F>,
597 twiddles_small: &[B],
598 twiddles_large: &[B],
599 multi_butterfly: M,
600) {
601 debug_assert!(
602 mat.height().is_multiple_of(twiddles_small.len()),
603 "Matrix height must be divisible by the number of twiddles"
604 );
605 let size = mat.values.len();
606 let num_blocks = twiddles_small.len();
607
608 let outer_block_size = size / num_blocks;
609 let quarter_outer_block_size = outer_block_size / 4;
610
611 let inner_chunk_size =
614 (workload_size::<F>().next_power_of_two() / 4).min(quarter_outer_block_size);
615
616 mat.values
617 .par_chunks_exact_mut(outer_block_size)
618 .enumerate()
619 .for_each(|(ind, block)| {
620 let chunk_par_iters_0 = block
623 .chunks_exact_mut(quarter_outer_block_size)
624 .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
625 .collect::<Vec<_>>();
626 let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
627 chunk_par_iters_1.into_iter().tuples().for_each(|(hi, lo)| {
628 hi.zip(lo).for_each(|chunks| {
629 multi_butterfly.apply_2_layers(chunks, ind, twiddles_small, twiddles_large);
630 });
631 });
632 });
633}
634
635#[inline]
648fn dft_layer_par_triple<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
649 mat: &mut RowMajorMatrixViewMut<'_, F>,
650 twiddles_small: &[B],
651 twiddles_med: &[B],
652 twiddles_large: &[B],
653 multi_butterfly: M,
654) {
655 debug_assert!(
656 mat.height().is_multiple_of(twiddles_small.len()),
657 "Matrix height must be divisible by the number of twiddles"
658 );
659 let size = mat.values.len();
660 let num_blocks = twiddles_small.len();
661
662 let outer_block_size = size / num_blocks;
663 let eighth_outer_block_size = outer_block_size / 8;
664
665 let inner_chunk_size =
668 (workload_size::<F>().next_power_of_two() / 8).min(eighth_outer_block_size);
669
670 mat.values
671 .par_chunks_exact_mut(outer_block_size)
672 .enumerate()
673 .for_each(|(ind, block)| {
674 let chunk_par_iters_0 = block
677 .chunks_exact_mut(eighth_outer_block_size)
678 .map(|chunk| chunk.par_chunks_mut(inner_chunk_size))
679 .collect::<Vec<_>>();
680 let chunk_par_iters_1 = zip_par_iter_vec(chunk_par_iters_0);
681 let chunk_par_iters_2 = zip_par_iter_vec(chunk_par_iters_1);
682 chunk_par_iters_2.into_iter().tuples().for_each(|(hi, lo)| {
683 hi.zip(lo).for_each(|chunks| {
684 multi_butterfly.apply_3_layers(
685 chunks,
686 ind,
687 twiddles_small,
688 twiddles_med,
689 twiddles_large,
690 );
691 });
692 });
693 });
694}
695
696fn dft_layer_par_extra_layers<F: Field, B: Butterfly<F>, M: MultiLayerButterfly<F, B>>(
701 mat: &mut RowMajorMatrixViewMut<'_, F>,
702 root_table: &[Vec<F>],
703 multi_layer: M,
704) {
705 match root_table.len() {
706 1 => {
707 let fft_layer: &[B] = unsafe { as_base_slice(&root_table[0]) };
709 dft_layer_par(&mut mat.as_view_mut(), fft_layer);
710 }
711 2 => {
712 let fft_layer_0: &[B] = unsafe { as_base_slice(&root_table[0]) };
713 let fft_layer_1: &[B] = unsafe { as_base_slice(&root_table[1]) };
714 dft_layer_par_double(
715 &mut mat.as_view_mut(),
716 fft_layer_1,
717 fft_layer_0,
718 multi_layer,
719 );
720 }
721 0 => {}
722 _ => unreachable!("The number of layers must be 0, 1 or 2"),
723 }
724}
725
726#[inline]
749fn dft_layer_zeros<F: Field, B: Butterfly<F>>(vec: &mut [F], twiddles: &[B], skip: usize) {
750 debug_assert_eq!(
751 vec.len() % twiddles.len(),
752 0,
753 "Vector length must be divisible by the number of twiddles"
754 );
755 let size = vec.len();
756 let num_blocks = twiddles.len();
757
758 let block_size = size / num_blocks;
759 let half_block_size = block_size / 2;
760
761 vec.chunks_exact_mut(block_size)
762 .zip(twiddles)
763 .step_by(1 << skip) .for_each(|(block, &twiddle)| {
765 let (hi_chunk, lo_chunk) = block.split_at_mut(half_block_size);
767
768 twiddle.apply_to_rows(hi_chunk, lo_chunk);
770 });
771}
772
/// Four mutable slices arranged as a pair of pairs, as consumed by the
/// two-layer butterfly helpers.
///
/// `dft_layer_par_double` builds values of this shape by zipping sub-chunks of
/// the four quarters of a block: `.0` holds quarters 0 and 1, `.1` holds
/// quarters 2 and 3.
type DoubleLayerBlockDecomposition<'a, F> =
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F]));
776
777#[inline]
779fn fft_double_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
780 block: &mut DoubleLayerBlockDecomposition<'_, F>,
781 butterfly: Fly,
782) {
783 butterfly.apply_to_rows(block.0.0, block.1.0);
784 butterfly.apply_to_rows(block.0.1, block.1.1);
785}
786
787#[inline]
792fn fft_double_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
793 block: &mut DoubleLayerBlockDecomposition<'_, F>,
794 fly0: Fly0,
795 fly1: Fly1,
796) {
797 fly0.apply_to_rows(block.0.0, block.0.1);
798 fly1.apply_to_rows(block.1.0, block.1.1);
799}
800
/// Eight mutable slices arranged as a binary tree of pairs, as consumed by the
/// three-layer butterfly helpers.
///
/// `dft_layer_par_triple` builds values of this shape by zipping sub-chunks of
/// the eight parts of a block: `.0` holds parts 0 to 3 and `.1` holds parts
/// 4 to 7, each again nested as a pair of pairs.
type TripleLayerBlockDecomposition<'a, F> = (
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
    ((&'a mut [F], &'a mut [F]), (&'a mut [F], &'a mut [F])),
);
806
807#[inline]
809fn fft_triple_layer_single_twiddle<F: Field, Fly: Butterfly<F>>(
810 block: &mut TripleLayerBlockDecomposition<'_, F>,
811 butterfly: Fly,
812) {
813 butterfly.apply_to_rows(block.0.0.0, block.1.0.0);
814 butterfly.apply_to_rows(block.0.0.1, block.1.0.1);
815 butterfly.apply_to_rows(block.0.1.0, block.1.1.0);
816 butterfly.apply_to_rows(block.0.1.1, block.1.1.1);
817}
818
819#[inline]
824fn fft_triple_layer_double_twiddle<F: Field, Fly0: Butterfly<F>, Fly1: Butterfly<F>>(
825 block: &mut TripleLayerBlockDecomposition<'_, F>,
826 fly0: Fly0,
827 fly1: Fly1,
828) {
829 fly0.apply_to_rows(block.0.0.0, block.0.1.0);
830 fly0.apply_to_rows(block.0.0.1, block.0.1.1);
831 fly1.apply_to_rows(block.1.0.0, block.1.1.0);
832 fly1.apply_to_rows(block.1.0.1, block.1.1.1);
833}
834
835#[inline]
840fn fft_triple_layer_quad_twiddle<F: Field, Fly0: Butterfly<F>, Flies: Butterfly<F>>(
841 block: &mut TripleLayerBlockDecomposition<'_, F>,
842 fly0: Fly0,
843 butterflies: &[Flies],
844) {
845 debug_assert!(butterflies.len() == 3);
846 fly0.apply_to_rows(block.0.0.0, block.0.0.1);
847 butterflies[0].apply_to_rows(block.0.1.0, block.0.1.1);
848 butterflies[1].apply_to_rows(block.1.0.0, block.1.0.1);
849 butterflies[2].apply_to_rows(block.1.1.0, block.1.1.1);
850}
851
/// Returns how many values of type `T` fit into a fixed 32 KiB budget, which
/// the code treats as the L1 cache size.
#[must_use]
const fn workload_size<T: Sized>() -> usize {
    // 32 KiB.
    const L1_CACHE_BYTES: usize = 32 * 1024;
    L1_CACHE_BYTES / size_of::<T>()
}
861
862#[must_use]
868fn estimate_num_rows_in_l1<T: Sized>(height: usize, width: usize) -> usize {
869 (workload_size::<T>() / width)
870 .next_power_of_two()
871 .min(height) }
873
874#[inline]
881fn zip_par_iter_vec<I: IndexedParallelIterator>(
882 in_vec: Vec<I>,
883) -> Vec<impl IndexedParallelIterator<Item = (I::Item, I::Item)>> {
884 in_vec
885 .into_iter()
886 .tuples()
887 .map(|(hi, lo)| hi.zip(lo))
888 .collect::<Vec<_>>()
889}
890
/// A strategy for applying two or three consecutive butterfly layers to one
/// pre-split block.
///
/// `ind` is the index of the block inside the matrix and selects which entries
/// of each twiddle slice are used. The implementations in this file substitute
/// `TwiddleFreeButterfly` for the first twiddle of every layer when `ind == 0`.
trait MultiLayerButterfly<F: Field, B: Butterfly<F>>: Copy + Send + Sync {
    /// Applies two layers to the four slices in `chunk_decomposition`.
    ///
    /// The implementations in this file read `twiddles_small[ind]` and
    /// `twiddles_large[2 * ind]` / `twiddles_large[2 * ind + 1]`.
    fn apply_2_layers(
        &self,
        chunk_decomposition: DoubleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[B],
        twiddles_large: &[B],
    );

    /// Applies three layers to the eight slices in `chunk_decomposition`.
    ///
    /// The implementations in this file read `twiddles_small[ind]`,
    /// `twiddles_med[2 * ind]` / `twiddles_med[2 * ind + 1]` and
    /// `twiddles_large[4 * ind .. 4 * ind + 4]`.
    fn apply_3_layers(
        &self,
        chunk_decomposition: TripleLayerBlockDecomposition<'_, F>,
        ind: usize,
        twiddles_small: &[B],
        twiddles_med: &[B],
        twiddles_large: &[B],
    );
}
909
/// Marker type selecting the `DitButterfly` implementation of
/// `MultiLayerButterfly`.
#[derive(Debug, Clone, Copy)]
struct MultiLayerDitButterfly;
912
913impl<F: Field> MultiLayerButterfly<F, DitButterfly<F>> for MultiLayerDitButterfly {
914 #[inline]
915 fn apply_2_layers(
916 &self,
917 mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
918 ind: usize,
919 twiddles_small: &[DitButterfly<F>],
920 twiddles_large: &[DitButterfly<F>],
921 ) {
922 if ind == 0 {
923 fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
924 fft_double_layer_double_twiddle(
925 &mut blk_decomp,
926 TwiddleFreeButterfly,
927 twiddles_large[1],
928 );
929 } else {
930 fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
931 fft_double_layer_double_twiddle(
932 &mut blk_decomp,
933 twiddles_large[2 * ind],
934 twiddles_large[2 * ind + 1],
935 );
936 }
937 }
938
939 #[inline]
940 fn apply_3_layers(
941 &self,
942 mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
943 ind: usize,
944 twiddles_small: &[DitButterfly<F>],
945 twiddles_med: &[DitButterfly<F>],
946 twiddles_large: &[DitButterfly<F>],
947 ) {
948 if ind == 0 {
949 fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
950 fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
951 fft_triple_layer_quad_twiddle(
952 &mut blk_decomp,
953 TwiddleFreeButterfly,
954 &twiddles_large[1..4],
955 );
956 } else {
957 fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
958 fft_triple_layer_double_twiddle(
959 &mut blk_decomp,
960 twiddles_med[2 * ind],
961 twiddles_med[2 * ind + 1],
962 );
963 fft_triple_layer_quad_twiddle(
964 &mut blk_decomp,
965 twiddles_large[4 * ind],
966 &twiddles_large[4 * ind + 1..4 * (ind + 1)],
967 );
968 }
969 }
970}
971
/// Marker type selecting the `DifButterfly` implementation of
/// `MultiLayerButterfly`.
#[derive(Debug, Clone, Copy)]
struct MultiLayerDifButterfly;
974
975impl<F: Field> MultiLayerButterfly<F, DifButterfly<F>> for MultiLayerDifButterfly {
976 #[inline]
977 fn apply_2_layers(
978 &self,
979 mut blk_decomp: DoubleLayerBlockDecomposition<'_, F>,
980 ind: usize,
981 twiddles_small: &[DifButterfly<F>],
982 twiddles_large: &[DifButterfly<F>],
983 ) {
984 if ind == 0 {
985 fft_double_layer_double_twiddle(
986 &mut blk_decomp,
987 TwiddleFreeButterfly,
988 twiddles_large[1],
989 );
990 fft_double_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
991 } else {
992 fft_double_layer_double_twiddle(
993 &mut blk_decomp,
994 twiddles_large[2 * ind],
995 twiddles_large[2 * ind + 1],
996 );
997 fft_double_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
998 }
999 }
1000
1001 #[inline]
1002 fn apply_3_layers(
1003 &self,
1004 mut blk_decomp: TripleLayerBlockDecomposition<'_, F>,
1005 ind: usize,
1006 twiddles_small: &[DifButterfly<F>],
1007 twiddles_med: &[DifButterfly<F>],
1008 twiddles_large: &[DifButterfly<F>],
1009 ) {
1010 if ind == 0 {
1011 fft_triple_layer_quad_twiddle(
1012 &mut blk_decomp,
1013 TwiddleFreeButterfly,
1014 &twiddles_large[1..4],
1015 );
1016 fft_triple_layer_double_twiddle(&mut blk_decomp, TwiddleFreeButterfly, twiddles_med[1]);
1017 fft_triple_layer_single_twiddle(&mut blk_decomp, TwiddleFreeButterfly);
1018 } else {
1019 fft_triple_layer_quad_twiddle(
1020 &mut blk_decomp,
1021 twiddles_large[4 * ind],
1022 &twiddles_large[4 * ind + 1..4 * (ind + 1)],
1023 );
1024 fft_triple_layer_double_twiddle(
1025 &mut blk_decomp,
1026 twiddles_med[2 * ind],
1027 twiddles_med[2 * ind + 1],
1028 );
1029 fft_triple_layer_single_twiddle(&mut blk_decomp, twiddles_small[ind]);
1030 }
1031 }
1032}