Skip to main content

p3_dft/
radix_2_dit_parallel.rs

1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::mem::{MaybeUninit, transmute};
6
7use itertools::{Itertools, izip};
8use p3_field::integers::QuotientMap;
9use p3_field::{Field, Powers, TwoAdicField};
10use p3_matrix::Matrix;
11use p3_matrix::bitrev::{BitReversalPerm, BitReversedMatrixView, BitReversibleMatrix};
12use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView, RowMajorMatrixViewMut};
13use p3_matrix::util::reverse_matrix_index_bits;
14use p3_maybe_rayon::prelude::*;
15use p3_util::{log2_strict_usize, reverse_bits_len, reverse_slice_index_bits};
16use spin::RwLock;
17use tracing::{debug_span, instrument};
18
19use crate::TwoAdicSubgroupDft;
20use crate::butterflies::{Butterfly, DitButterfly, ScaledDitButterfly, TwiddleFreeButterfly};
21
22/// A parallel FFT algorithm which divides a butterfly network's layers into two halves.
23///
24/// For the first half, we apply a butterfly network with smaller blocks in earlier layers,
25/// i.e. either DIT or Bowers G. Then we bit-reverse, and for the second half, we continue executing
26/// the same network but in bit-reversed order. This way we're always working with small blocks,
27/// so within each half, we can have a certain amount of parallelism with no cross-thread
28/// communication.
29#[derive(Default, Clone, Debug)]
30pub struct Radix2DitParallel<F> {
31    /// Twiddles based on roots of unity, used in the forward DFT.
32    twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
33
34    /// A map from `(log_h, shift)` to forward DFT twiddles with that coset shift baked in.
35    #[allow(clippy::type_complexity)]
36    coset_twiddles: Arc<RwLock<BTreeMap<(usize, F), Arc<[Vec<F>]>>>>,
37
38    /// Twiddles based on inverse roots of unity, used in the inverse DFT.
39    inverse_twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
40}
41
/// The same twiddle factors stored twice: once in natural order and once with their indices
/// bit-reversed.
#[derive(Clone, Debug, Default)]
struct VectorPair<F> {
    /// Twiddle factors in natural order.
    twiddles: Vec<F>,
    /// The same twiddle factors, permuted by bit-reversing their indices.
    bitrev_twiddles: Vec<F>,
}
48
49impl<F> Radix2DitParallel<F>
50where
51    F: TwoAdicField + Ord,
52{
53    fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
54        // Fast path: Check for the value with a cheap read lock.
55        if let Some(pair) = self.twiddles.read().get(&log_h) {
56            return pair.clone();
57        }
58
59        // Slow path: The value doesn't exist. Acquire a write lock.
60        let mut w_lock = self.twiddles.write();
61
62        // Double-check and compute if necessary.
63        w_lock
64            .entry(log_h)
65            .or_insert_with(|| {
66                let half_h = (1 << log_h) >> 1;
67                let root = F::two_adic_generator(log_h);
68                let twiddles = root.powers().collect_n(half_h);
69                let mut bitrev_twiddles = twiddles.clone();
70                reverse_slice_index_bits(&mut bitrev_twiddles);
71
72                Arc::new(VectorPair {
73                    twiddles,
74                    bitrev_twiddles,
75                })
76            })
77            .clone()
78    }
79
80    fn get_or_compute_coset_twiddles(&self, (log_h, shift): (usize, F)) -> Arc<[Vec<F>]> {
81        let key = (log_h, shift);
82        // Fast path: Try to get the value with a cheap read lock first.
83        if let Some(twiddles) = self.coset_twiddles.read().get(&key) {
84            return twiddles.clone();
85        }
86        // Slow path: The value isn't there, so we need to compute it.
87        // Acquire a write lock to ensure only one thread does the computation.
88        let mut w_lock = self.coset_twiddles.write();
89        // Double-check: Another thread might have inserted it while we waited for the lock.
90        // The `entry` API handles this check and insertion atomically.
91        w_lock
92            .entry(key)
93            .or_insert_with(|| {
94                let mid = log_h.div_ceil(2);
95                let h = 1 << log_h;
96                let root = F::two_adic_generator(log_h);
97                (0..log_h)
98                    .map(|layer| {
99                        let shift_power = shift.exp_power_of_2(layer);
100                        let powers = Powers {
101                            base: root.exp_power_of_2(layer),
102                            current: shift_power,
103                        };
104                        let mut twiddles = powers.collect_n(h >> (layer + 1));
105                        let layer_rev = log_h - 1 - layer;
106                        if layer_rev >= mid {
107                            reverse_slice_index_bits(&mut twiddles);
108                        }
109                        twiddles
110                    })
111                    .collect::<Vec<_>>()
112                    .into()
113            })
114            .clone()
115    }
116
117    fn get_or_compute_inverse_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
118        // Fast path: First, check for the value using a cheap read lock.
119        if let Some(pair) = self.inverse_twiddles.read().get(&log_h) {
120            return pair.clone();
121        }
122        // Slow path: The value doesn't exist. Acquire a write lock.
123        let mut w_lock = self.inverse_twiddles.write();
124        // Double-check: Another thread might have created the entry while we waited.
125        // The `entry` API handles this check and the insertion atomically.
126        w_lock
127            .entry(log_h)
128            .or_insert_with(|| {
129                // This computation only runs if the entry is truly vacant.
130                let half_h = (1 << log_h) >> 1;
131                let root_inv = F::two_adic_generator(log_h).inverse();
132                let twiddles = root_inv.powers().collect_n(half_h);
133                let mut bitrev_twiddles = twiddles.clone();
134                reverse_slice_index_bits(&mut bitrev_twiddles);
135
136                Arc::new(VectorPair {
137                    twiddles,
138                    bitrev_twiddles,
139                })
140            })
141            .clone()
142    }
143}
144
145impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
146    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;
147
148    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
149        let h = mat.height();
150        let log_h = log2_strict_usize(h);
151
152        // Compute twiddle factors, or take memoized ones if already available.
153        let twiddles = self.get_or_compute_twiddles(log_h);
154
155        let mid = log_h.div_ceil(2);
156
157        // The first half looks like a normal DIT.
158        reverse_matrix_index_bits(&mut mat);
159        first_half(&mut mat, mid, &twiddles.twiddles);
160
161        // For the second half, we flip the DIT, working in bit-reversed order.
162        reverse_matrix_index_bits(&mut mat);
163        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);
164
165        mat.bit_reverse_rows()
166    }
167
168    #[instrument(skip_all, level = "debug", fields(dims = %mat.dimensions(), added_bits = added_bits))]
169    fn coset_lde_batch(
170        &self,
171        mut mat: RowMajorMatrix<F>,
172        added_bits: usize,
173        shift: F,
174    ) -> Self::Evaluations {
175        let w = mat.width;
176        let h = mat.height();
177        let log_h = log2_strict_usize(h);
178        let mid = log_h.div_ceil(2);
179
180        let inverse_twiddles = self.get_or_compute_inverse_twiddles(log_h);
181
182        // The first half looks like a normal DIT.
183        reverse_matrix_index_bits(&mut mat);
184        first_half(&mut mat, mid, &inverse_twiddles.twiddles);
185
186        // For the second half, we flip the DIT, working in bit-reversed order.
187        reverse_matrix_index_bits(&mut mat);
188        // We'll also scale by 1/h, as per the usual inverse DFT algorithm.
189        // If F isn't a PrimeField, (and is thus an extension field) it's much cheaper to
190        // invert in F::PrimeSubfield.
191        let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
192        let scale = h_inv_subfield.map(F::from_prime_subfield);
193        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
194        // We skip the final bit-reversal, since the next FFT expects bit-reversed input.
195
196        let lde_elems = w * (h << added_bits);
197        let elems_to_add = lde_elems - w * h;
198        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));
199
200        let g_big = F::two_adic_generator(log_h + added_bits);
201
202        let mat_ptr = mat.values.as_mut_ptr();
203        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
204        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
205        let rest_slice: &mut [MaybeUninit<F>] =
206            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
207        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
208        let mut rest_cosets_mat = rest_slice
209            .chunks_exact_mut(w * h)
210            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
211            .collect_vec();
212
213        for coset_idx in 1..(1 << added_bits) {
214            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
215            let coset_idx = reverse_bits_len(coset_idx, added_bits);
216            let dest = &mut rest_cosets_mat[coset_idx - 1]; // - 1 because we removed the first matrix.
217            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
218        }
219
220        // Now run a forward DFT on the very first coset, this time in-place.
221        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);
222
223        // SAFETY: We wrote all values above.
224        unsafe {
225            mat.values.set_len(lde_elems);
226        }
227        BitReversalPerm::new_view(mat)
228    }
229}
230
231#[instrument(level = "debug", skip_all)]
232fn coset_dft<F: TwoAdicField + Ord>(
233    dft: &Radix2DitParallel<F>,
234    mat: &mut RowMajorMatrixViewMut<'_, F>,
235    shift: F,
236) {
237    let log_h = log2_strict_usize(mat.height());
238    let mid = log_h.div_ceil(2);
239
240    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
241
242    // The first half looks like a normal DIT.
243    first_half_general(mat, mid, &twiddles);
244
245    // For the second half, we flip the DIT, working in bit-reversed order.
246    reverse_matrix_index_bits(mat);
247
248    second_half_general(mat, mid, &twiddles);
249}
250
251/// Like `coset_dft`, except out-of-place.
252#[instrument(level = "debug", skip_all)]
253fn coset_dft_oop<F: TwoAdicField + Ord>(
254    dft: &Radix2DitParallel<F>,
255    src: &RowMajorMatrixView<'_, F>,
256    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
257    shift: F,
258) {
259    assert_eq!(src.dimensions(), dst_maybe.dimensions());
260
261    let log_h = log2_strict_usize(dst_maybe.height());
262
263    if log_h == 0 {
264        // This is an edge case where first_half_general_oop doesn't work, as it expects there to be
265        // at least one layer in the network, so we just copy instead.
266        let src_maybe = unsafe {
267            transmute::<&RowMajorMatrixView<'_, F>, &RowMajorMatrixView<'_, MaybeUninit<F>>>(src)
268        };
269        dst_maybe.copy_from(src_maybe);
270        return;
271    }
272
273    let mid = log_h.div_ceil(2);
274
275    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
276
277    // The first half looks like a normal DIT.
278    first_half_general_oop(src, dst_maybe, mid, &twiddles);
279
280    // dst is now initialized.
281    let dst = unsafe {
282        transmute::<&mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>, &mut RowMajorMatrixViewMut<'_, F>>(
283            dst_maybe,
284        )
285    };
286
287    // For the second half, we flip the DIT, working in bit-reversed order.
288    reverse_matrix_index_bits(dst);
289
290    second_half_general(dst, mid, &twiddles);
291}
292
293/// This can be used as the first half of a DIT butterfly network.
294///
295/// For layer 0, all twiddle factors are 1 (root^0 = 1), so we use `TwiddleFreeButterfly`
296/// to avoid a Montgomery multiply by 1 across the entire matrix.
297///
298/// For layers 1 to mid-1 included, the first twiddle in each block is also always 1 (twiddles[0] = 1),
299/// so we special-case the first row-pair of each block to use `TwiddleFreeButterfly` as well.
300#[instrument(level = "debug", skip_all)]
301fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
302    let log_h = log2_strict_usize(mat.height());
303
304    // max block size: 2^mid
305    mat.par_row_chunks_exact_mut(1 << mid)
306        .for_each(|mut submat| {
307            let mut backwards = false;
308            for layer in 0..mid {
309                if layer == 0 {
310                    // For layer 0, half_block_size=1 and each block clones the twiddle
311                    // iterator from the start, consuming only twiddles[0] = root^0 = 1.
312                    // Use TwiddleFreeButterfly to skip the multiply entirely.
313                    dit_layer_twiddle_free(&mut submat, backwards);
314                } else {
315                    let layer_rev = log_h - 1 - layer;
316                    let layer_pow = 1 << layer_rev;
317                    // For layers 1..mid-1, twiddles[0] = root^0 = 1 is always the first
318                    // twiddle consumed per block. Use the optimized version that applies
319                    // TwiddleFreeButterfly for the first row-pair of each block.
320                    dit_layer_first_one(
321                        &mut submat,
322                        layer,
323                        twiddles.iter().step_by(layer_pow),
324                        backwards,
325                    );
326                }
327                backwards = !backwards;
328            }
329        });
330}
331
332/// Like `first_half`, except supporting different twiddle factors per layer, enabling coset shifts
333/// to be baked into them.
334#[instrument(level = "debug", skip_all)]
335fn first_half_general<F: Field>(
336    mat: &mut RowMajorMatrixViewMut<'_, F>,
337    mid: usize,
338    twiddles: &[Vec<F>],
339) {
340    let log_h = log2_strict_usize(mat.height());
341    mat.par_row_chunks_exact_mut(1 << mid)
342        .for_each(|mut submat| {
343            let mut backwards = false;
344            for layer in 0..mid {
345                let layer_rev = log_h - 1 - layer;
346                dit_layer(&mut submat, layer, twiddles[layer_rev].iter(), backwards);
347                backwards = !backwards;
348            }
349        });
350}
351
352/// Like `first_half_general`, except out-of-place.
353///
354/// Assumes there's at least one layer in the network, i.e. `src.height() > 1`.
355/// Undefined behavior otherwise.
356#[instrument(level = "debug", skip_all)]
357fn first_half_general_oop<F: Field>(
358    src: &RowMajorMatrixView<'_, F>,
359    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
360    mid: usize,
361    twiddles: &[Vec<F>],
362) {
363    let log_h = log2_strict_usize(src.height());
364    src.par_row_chunks_exact(1 << mid)
365        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
366        .for_each(|(src_submat, mut dst_submat_maybe)| {
367            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());
368
369            // The first layer is special, done out-of-place.
370            // (Recall from the mid definition that there must be at least one layer here.)
371            let layer_rev = log_h - 1;
372            dit_layer_oop(
373                &src_submat,
374                &mut dst_submat_maybe,
375                0,
376                twiddles[layer_rev].iter(),
377            );
378
379            // submat is now initialized.
380            let mut dst_submat = unsafe {
381                transmute::<RowMajorMatrixViewMut<'_, MaybeUninit<F>>, RowMajorMatrixViewMut<'_, F>>(
382                    dst_submat_maybe,
383                )
384            };
385
386            // Subsequent layers.
387            let mut backwards = true;
388            for layer in 1..mid {
389                let layer_rev = log_h - 1 - layer;
390                dit_layer(
391                    &mut dst_submat,
392                    layer,
393                    twiddles[layer_rev].iter(),
394                    backwards,
395                );
396                backwards = !backwards;
397            }
398        });
399}
400
401/// This can be used as the second half of a DIT butterfly network. It works in bit-reversed order.
402///
403/// The optional `scale` parameter is used to scale the matrix by a constant factor. Rather than
404/// doing a separate pass over memory, we fold the scaling into the first butterfly layer to
405/// eliminate an extra memory pass.
406#[instrument(level = "debug", skip_all)]
407#[inline(always)] // To avoid branch on scale
408fn second_half<F: Field>(
409    mat: &mut RowMajorMatrix<F>,
410    mid: usize,
411    twiddles_rev: &[F],
412    scale: Option<F>,
413) {
414    let log_h = log2_strict_usize(mat.height());
415
416    // max block size: 2^(log_h - mid)
417    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
418        .enumerate()
419        .for_each(|(thread, mut submat)| {
420            let mut backwards = false;
421            if let Some(scale) = scale {
422                // Fold the scale into the first butterfly layer to avoid a separate
423                // memory pass. This merges the O(N) scaling step into the first O(N)
424                // butterfly pass.
425                let mut scale_applied = false;
426                for layer in mid..log_h {
427                    let first_block = thread << (layer - mid);
428                    if !scale_applied {
429                        scale_applied = true;
430                        dit_layer_rev_scaled(
431                            &mut submat,
432                            log_h,
433                            layer,
434                            twiddles_rev[first_block..].iter().copied(),
435                            backwards,
436                            Some(scale),
437                        );
438                    } else {
439                        dit_layer_rev(
440                            &mut submat,
441                            log_h,
442                            layer,
443                            twiddles_rev[first_block..].iter().copied(),
444                            backwards,
445                        );
446                    }
447                    backwards = !backwards;
448                }
449                // Handle case where there are no layers in the second half (mid == log_h).
450                if !scale_applied {
451                    submat.scale(scale);
452                }
453            } else {
454                for layer in mid..log_h {
455                    let first_block = thread << (layer - mid);
456                    dit_layer_rev(
457                        &mut submat,
458                        log_h,
459                        layer,
460                        twiddles_rev[first_block..].iter().copied(),
461                        backwards,
462                    );
463                    backwards = !backwards;
464                }
465            }
466        });
467}
468
469/// Like `second_half`, except supporting different twiddle factors per layer, enabling coset shifts
470/// to be baked into them.
471#[instrument(level = "debug", skip_all)]
472fn second_half_general<F: Field>(
473    mat: &mut RowMajorMatrixViewMut<'_, F>,
474    mid: usize,
475    twiddles_rev: &[Vec<F>],
476) {
477    let log_h = log2_strict_usize(mat.height());
478    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
479        .enumerate()
480        .for_each(|(thread, mut submat)| {
481            let mut backwards = false;
482            for layer in mid..log_h {
483                let layer_rev = log_h - 1 - layer;
484                let first_block = thread << (layer - mid);
485                dit_layer_rev(
486                    &mut submat,
487                    log_h,
488                    layer,
489                    twiddles_rev[layer_rev][first_block..].iter().copied(),
490                    backwards,
491                );
492                backwards = !backwards;
493            }
494        });
495}
496
497/// One layer of a DIT butterfly network where all twiddle factors are 1 (i.e., layer 0).
498///
499/// This is equivalent to `dit_layer` with `layer=0` and `twiddles[0]=1`, but uses
500/// `TwiddleFreeButterfly` to avoid a Montgomery multiplication by 1 in the hot loop.
501///
502/// Correctness: For layer=0, `half_block_size=1` and each block clones the twiddle
503/// iterator from position 0, consuming only `twiddles[0] = generator^0 = 1`.
504/// Since multiplying by 1 is a no-op, `TwiddleFreeButterfly` gives identical results.
505fn dit_layer_twiddle_free<F: Field>(submat: &mut RowMajorMatrixViewMut<'_, F>, backwards: bool) {
506    // layer=0 means half_block_size=1, block_size=2.
507    let width = submat.width();
508    debug_assert!(submat.height() >= 2);
509
510    let process_block = move |block: &mut [F]| {
511        // Each block is exactly 2 rows: lo = block[0..width], hi = block[width..2*width]
512        let (lo, hi) = block.split_at_mut(width);
513        TwiddleFreeButterfly.apply_to_rows(lo, hi);
514    };
515
516    let blocks = submat.values.chunks_mut(2 * width);
517    if backwards {
518        for block in blocks.rev() {
519            process_block(block);
520        }
521    } else {
522        for block in blocks {
523            process_block(block);
524        }
525    }
526}
527
528/// One layer of a DIT butterfly network where the first twiddle factor per block is always 1.
529///
530/// This is used in `first_half` for layers 1..mid-1 of the standard (non-coset) DFT/inverse DFT,
531/// where `twiddles[0] = root^0 = 1`. The first row-pair of each block uses `TwiddleFreeButterfly`
532/// to avoid one Montgomery multiplication per block, while subsequent row-pairs use `DitButterfly`.
533///
534/// Correctness: The twiddle iterator yields `twiddles[0], twiddles[step], twiddles[2*step], ...`
535/// where `twiddles[0] = root^0 = 1`. Only used when this property holds.
536fn dit_layer_first_one<'a, F: Field>(
537    submat: &mut RowMajorMatrixViewMut<'_, F>,
538    layer: usize,
539    twiddles: impl Iterator<Item = &'a F> + Clone,
540    backwards: bool,
541) {
542    let half_block_size = 1 << layer;
543    let block_size = half_block_size * 2;
544    let width = submat.width();
545    debug_assert!(submat.height() >= block_size);
546    debug_assert!(
547        half_block_size >= 2,
548        "layer must be >= 1 for dit_layer_first_one"
549    );
550
551    let process_block = move |block: &mut [F]| {
552        let (lows, highs) = block.split_at_mut(half_block_size * width);
553        let mut tw_iter = twiddles.clone();
554        // First row-pair: twiddle is always 1, use TwiddleFreeButterfly to skip the multiply.
555        let _ = tw_iter.next(); // consume twiddles[0] = 1
556        let (lo0, lo_rest) = lows.split_at_mut(width);
557        let (hi0, hi_rest) = highs.split_at_mut(width);
558        TwiddleFreeButterfly.apply_to_rows(lo0, hi0);
559        // Remaining row-pairs use DitButterfly with their respective twiddle factors.
560        for (lo, hi, twiddle) in izip!(
561            lo_rest.chunks_mut(width),
562            hi_rest.chunks_mut(width),
563            tw_iter
564        ) {
565            DitButterfly(*twiddle).apply_to_rows(lo, hi);
566        }
567    };
568
569    let blocks = submat.values.chunks_mut(block_size * width);
570    if backwards {
571        for block in blocks.rev() {
572            process_block(block);
573        }
574    } else {
575        for block in blocks {
576            process_block(block);
577        }
578    }
579}
580
581/// One layer of a DIT butterfly network.
582fn dit_layer<'a, F: Field>(
583    submat: &mut RowMajorMatrixViewMut<'_, F>,
584    layer: usize,
585    twiddles: impl Iterator<Item = &'a F> + Clone,
586    backwards: bool,
587) {
588    let half_block_size = 1 << layer;
589    let block_size = half_block_size * 2;
590    let width = submat.width();
591    debug_assert!(submat.height() >= block_size);
592
593    let process_block = move |block: &mut [F]| {
594        let (lows, highs) = block.split_at_mut(half_block_size * width);
595        for (lo, hi, twiddle) in izip!(
596            lows.chunks_mut(width),
597            highs.chunks_mut(width),
598            twiddles.clone()
599        ) {
600            DitButterfly(*twiddle).apply_to_rows(lo, hi);
601        }
602    };
603
604    let blocks = submat.values.chunks_mut(block_size * width);
605    if backwards {
606        for block in blocks.rev() {
607            process_block(block);
608        }
609    } else {
610        for block in blocks {
611            process_block(block);
612        }
613    }
614}
615
616/// One layer of a DIT butterfly network, out-of-place.
617fn dit_layer_oop<'a, F: Field>(
618    src: &RowMajorMatrixView<'_, F>,
619    dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
620    layer: usize,
621    twiddles: impl Iterator<Item = &'a F> + Clone,
622) {
623    debug_assert_eq!(src.dimensions(), dst.dimensions());
624    let half_block_size = 1 << layer;
625    let block_size = half_block_size * 2;
626    let width = dst.width();
627    debug_assert!(dst.height() >= block_size);
628
629    let process_blocks = move |src_block: &[F], dst_block: &mut [MaybeUninit<F>]| {
630        let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
631        let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
632
633        for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
634            src_lows.chunks(width),
635            dst_lows.chunks_mut(width),
636            src_highs.chunks(width),
637            dst_highs.chunks_mut(width),
638            twiddles.clone()
639        ) {
640            DitButterfly(*twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
641        }
642    };
643
644    let src_chunks = src.values.chunks(block_size * width);
645    let dst_chunks = dst.values.chunks_mut(block_size * width);
646
647    for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
648        process_blocks(src_block, dst_block);
649    }
650}
651
652/// Like `dit_layer_rev`, except with an optional scale factor folded into the butterfly.
653///
654/// This avoids an extra memory pass when scaling is required (e.g., 1/N in inverse DFT).
655/// When `scale` is `None`, this is identical to `dit_layer_rev`.
656///
657/// When `scale` is `Some(s)`, uses `ScaledDitButterfly::new(twiddle, s)` which precomputes
658/// `twiddle * scale` once per block, reducing multiplications in the hot loop from 3 to 2.
659fn dit_layer_rev_scaled<F: Field>(
660    submat: &mut RowMajorMatrixViewMut<'_, F>,
661    log_h: usize,
662    layer: usize,
663    twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
664    backwards: bool,
665    scale: Option<F>,
666) {
667    let layer_rev = log_h - 1 - layer;
668
669    let half_block_size = 1 << layer_rev;
670    let block_size = half_block_size * 2;
671    let width = submat.width();
672    debug_assert!(submat.height() >= block_size);
673
674    match scale {
675        None => {
676            // No scaling: same as regular dit_layer_rev
677            let blocks_and_twiddles = submat
678                .values
679                .chunks_mut(block_size * width)
680                .zip(twiddles_rev);
681            if backwards {
682                for (block, twiddle) in blocks_and_twiddles.rev() {
683                    let (lo, hi) = block.split_at_mut(half_block_size * width);
684                    DitButterfly(twiddle).apply_to_rows(lo, hi);
685                }
686            } else {
687                for (block, twiddle) in blocks_and_twiddles {
688                    let (lo, hi) = block.split_at_mut(half_block_size * width);
689                    DitButterfly(twiddle).apply_to_rows(lo, hi);
690                }
691            }
692        }
693        Some(s) => {
694            // Fold scaling into the butterfly to avoid a separate memory pass.
695            // ScaledDitButterfly::new precomputes twiddle * scale once per block,
696            // so the hot loop only needs 2 multiplications instead of 3.
697            let blocks_and_twiddles = submat
698                .values
699                .chunks_mut(block_size * width)
700                .zip(twiddles_rev);
701            if backwards {
702                for (block, twiddle) in blocks_and_twiddles.rev() {
703                    let (lo, hi) = block.split_at_mut(half_block_size * width);
704                    ScaledDitButterfly::new(twiddle, s).apply_to_rows(lo, hi);
705                }
706            } else {
707                for (block, twiddle) in blocks_and_twiddles {
708                    let (lo, hi) = block.split_at_mut(half_block_size * width);
709                    ScaledDitButterfly::new(twiddle, s).apply_to_rows(lo, hi);
710                }
711            }
712        }
713    }
714}
715
716/// Like `dit_layer`, except the matrix and twiddles are encoded in bit-reversed order.
717/// This can also be viewed as a layer of the Bowers G^T network.
718fn dit_layer_rev<F: Field>(
719    submat: &mut RowMajorMatrixViewMut<'_, F>,
720    log_h: usize,
721    layer: usize,
722    twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
723    backwards: bool,
724) {
725    let layer_rev = log_h - 1 - layer;
726
727    let half_block_size = 1 << layer_rev;
728    let block_size = half_block_size * 2;
729    let width = submat.width();
730    debug_assert!(submat.height() >= block_size);
731
732    let blocks_and_twiddles = submat
733        .values
734        .chunks_mut(block_size * width)
735        .zip(twiddles_rev);
736    if backwards {
737        for (block, twiddle) in blocks_and_twiddles.rev() {
738            let (lo, hi) = block.split_at_mut(half_block_size * width);
739            DitButterfly(twiddle).apply_to_rows(lo, hi);
740        }
741    } else {
742        for (block, twiddle) in blocks_and_twiddles {
743            let (lo, hi) = block.split_at_mut(half_block_size * width);
744            DitButterfly(twiddle).apply_to_rows(lo, hi);
745        }
746    }
747}