Skip to main content

lib_q_stark_dft/
radix_2_dit_parallel.rs

1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::mem::{
6    MaybeUninit,
7    transmute,
8};
9
10use itertools::{
11    Itertools,
12    izip,
13};
14use lib_q_stark_field::integers::QuotientMap;
15use lib_q_stark_field::{
16    Field,
17    Powers,
18    TwoAdicField,
19    assert_two_adic_bits,
20    assert_two_adic_fft_height,
21};
22use lib_q_stark_matrix::Matrix;
23use lib_q_stark_matrix::bitrev::{
24    BitReversalPerm,
25    BitReversedMatrixView,
26    BitReversibleMatrix,
27};
28use lib_q_stark_matrix::dense::{
29    RowMajorMatrix,
30    RowMajorMatrixView,
31    RowMajorMatrixViewMut,
32};
33use lib_q_stark_matrix::util::reverse_matrix_index_bits;
34use lib_q_stark_rayon::prelude::*;
35use lib_q_stark_util::{
36    log2_strict_usize,
37    reverse_bits_len,
38    reverse_slice_index_bits,
39};
40use spin::RwLock;
41use tracing::{
42    debug_span,
43    instrument,
44};
45
46use crate::TwoAdicSubgroupDft;
47use crate::butterflies::{
48    Butterfly,
49    DitButterfly,
50};
51
/// A parallel FFT algorithm which divides a butterfly network's layers into two halves.
///
/// For the first half, we apply a butterfly network with smaller blocks in earlier layers,
/// i.e. either DIT or Bowers G. Then we bit-reverse, and for the second half, we continue executing
/// the same network but in bit-reversed order. This way we're always working with small blocks,
/// so within each half, we can have a certain amount of parallelism with no cross-thread
/// communication.
///
/// Twiddle factors are memoized in the three maps below. Each map sits behind an
/// `Arc<RwLock<..>>`, and `Clone` is derived, so cloning this struct clones the `Arc`s and the
/// clones share the same caches.
#[derive(Default, Clone, Debug)]
pub struct Radix2DitParallel<F> {
    /// Twiddles based on roots of unity, used in the forward DFT.
    ///
    /// Keyed by `log_h`, the base-2 logarithm of the transform height.
    twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,

    /// A map from `(log_h, shift)` to forward DFT twiddles with that coset shift baked in.
    ///
    /// Each value holds one twiddle vector per butterfly layer.
    #[allow(clippy::type_complexity)]
    coset_twiddles: Arc<RwLock<BTreeMap<(usize, F), Arc<[Vec<F>]>>>>,

    /// Twiddles based on inverse roots of unity, used in the inverse DFT.
    ///
    /// Keyed by `log_h`, the base-2 logarithm of the transform height.
    inverse_twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
}
71
/// A pair of vectors, one with twiddle factors in their natural order, the other bit-reversed.
///
/// Both vectors hold the same elements; only the ordering differs.
#[derive(Default, Clone, Debug)]
struct VectorPair<F> {
    /// Successive powers of the root, in natural order.
    twiddles: Vec<F>,
    /// A copy of `twiddles` after `reverse_slice_index_bits` has been applied to it.
    bitrev_twiddles: Vec<F>,
}
78
79impl<F> Radix2DitParallel<F>
80where
81    F: TwoAdicField + Ord,
82{
83    fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
84        // Fast path: Check for the value with a cheap read lock.
85        if let Some(pair) = self.twiddles.read().get(&log_h) {
86            return pair.clone();
87        }
88
89        // Slow path: The value doesn't exist. Acquire a write lock.
90        let mut w_lock = self.twiddles.write();
91
92        // Double-check and compute if necessary.
93        w_lock
94            .entry(log_h)
95            .or_insert_with(|| {
96                assert_two_adic_bits::<F>(log_h);
97                let half_h = (1 << log_h) >> 1;
98                let root = F::two_adic_generator(log_h);
99                let twiddles = root.powers().collect_n(half_h);
100                let mut bitrev_twiddles = twiddles.clone();
101                reverse_slice_index_bits(&mut bitrev_twiddles);
102
103                Arc::new(VectorPair {
104                    twiddles,
105                    bitrev_twiddles,
106                })
107            })
108            .clone()
109    }
110
111    fn get_or_compute_coset_twiddles(&self, (log_h, shift): (usize, F)) -> Arc<[Vec<F>]> {
112        let key = (log_h, shift);
113        // Fast path: Try to get the value with a cheap read lock first.
114        if let Some(twiddles) = self.coset_twiddles.read().get(&key) {
115            return twiddles.clone();
116        }
117        // Slow path: The value isn't there, so we need to compute it.
118        // Acquire a write lock to ensure only one thread does the computation.
119        let mut w_lock = self.coset_twiddles.write();
120        // Double-check: Another thread might have inserted it while we waited for the lock.
121        // The `entry` API handles this check and insertion atomically.
122        w_lock
123            .entry(key)
124            .or_insert_with(|| {
125                assert_two_adic_bits::<F>(log_h);
126                let mid = log_h.div_ceil(2);
127                let h = 1 << log_h;
128                let root = F::two_adic_generator(log_h);
129                (0..log_h)
130                    .map(|layer| {
131                        let shift_power = shift.exp_power_of_2(layer);
132                        let powers = Powers {
133                            base: root.exp_power_of_2(layer),
134                            current: shift_power,
135                        };
136                        let mut twiddles = powers.collect_n(h >> (layer + 1));
137                        let layer_rev = log_h - 1 - layer;
138                        if layer_rev >= mid {
139                            reverse_slice_index_bits(&mut twiddles);
140                        }
141                        twiddles
142                    })
143                    .collect::<Vec<_>>()
144                    .into()
145            })
146            .clone()
147    }
148
149    fn get_or_compute_inverse_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
150        // Fast path: First, check for the value using a cheap read lock.
151        if let Some(pair) = self.inverse_twiddles.read().get(&log_h) {
152            return pair.clone();
153        }
154        // Slow path: The value doesn't exist. Acquire a write lock.
155        let mut w_lock = self.inverse_twiddles.write();
156        // Double-check: Another thread might have created the entry while we waited.
157        // The `entry` API handles this check and the insertion atomically.
158        w_lock
159            .entry(log_h)
160            .or_insert_with(|| {
161                assert_two_adic_bits::<F>(log_h);
162                // This computation only runs if the entry is truly vacant.
163                let half_h = (1 << log_h) >> 1;
164                let root_inv = F::two_adic_generator(log_h).inverse();
165                let twiddles = root_inv.powers().collect_n(half_h);
166                let mut bitrev_twiddles = twiddles.clone();
167                reverse_slice_index_bits(&mut bitrev_twiddles);
168
169                Arc::new(VectorPair {
170                    twiddles,
171                    bitrev_twiddles,
172                })
173            })
174            .clone()
175    }
176}
177
impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
    /// Results are handed back as a bit-reversed view over a row-major matrix.
    type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;

    /// Runs a forward DFT over the rows of `mat`, consuming the matrix and transforming it
    /// in place.
    ///
    /// The network is split at `mid = ceil(log_h / 2)`: `first_half` handles layers `0..mid`
    /// and `second_half` handles layers `mid..log_h`, with a row bit-reversal before each half.
    fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
        assert_two_adic_fft_height::<F>(mat.height());
        let h = mat.height();
        let log_h = log2_strict_usize(h);

        // Compute twiddle factors, or take memoized ones if already available.
        let twiddles = self.get_or_compute_twiddles(log_h);

        let mid = log_h.div_ceil(2);

        // The first half looks like a normal DIT.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &twiddles.twiddles);

        // For the second half, we flip the DIT, working in bit-reversed order.
        reverse_matrix_index_bits(&mut mat);
        second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);

        mat.bit_reverse_rows()
    }

    /// Computes a low-degree extension of `mat` onto `2^added_bits` cosets.
    ///
    /// Steps, as visible below:
    /// 1. An inverse DFT is run in place using the inverse twiddles, with the `1/h` scaling
    ///    folded into the second half.
    /// 2. The backing vector is grown to hold `w * (h << added_bits)` elements.
    /// 3. For each `coset_idx` in `1..2^added_bits`, a forward coset DFT with shift
    ///    `g_big^coset_idx * shift` is written out-of-place into the chunk at index
    ///    `reverse_bits_len(coset_idx, added_bits)`.
    /// 4. The original `w * h` elements are transformed in place with shift `shift`.
    ///
    /// NOTE(review): unlike `dft_batch`, there is no up-front `assert_two_adic_fft_height`
    /// call here; the height is only checked via `log2_strict_usize` and the twiddle getters.
    #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits = added_bits))]
    fn coset_lde_batch(
        &self,
        mut mat: RowMajorMatrix<F>,
        added_bits: usize,
        shift: F,
    ) -> Self::Evaluations {
        let w = mat.width;
        let h = mat.height();
        let log_h = log2_strict_usize(h);
        let mid = log_h.div_ceil(2);

        let inverse_twiddles = self.get_or_compute_inverse_twiddles(log_h);

        // The first half looks like a normal DIT.
        reverse_matrix_index_bits(&mut mat);
        first_half(&mut mat, mid, &inverse_twiddles.twiddles);

        // For the second half, we flip the DIT, working in bit-reversed order.
        reverse_matrix_index_bits(&mut mat);
        // We'll also scale by 1/h, as per the usual inverse DFT algorithm.
        // If F isn't a PrimeField, (and is thus an extension field) it's much cheaper to
        // invert in F::PrimeSubfield.
        // NOTE(review): if `try_inverse` yields `None`, `scale` is `None` and `second_half`
        // applies no scaling at all.
        let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
        let scale = h_inv_subfield.map(F::from_prime_subfield);
        second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
        // We skip the final bit-reversal, since the next FFT expects bit-reversed input.

        // Grow the allocation up front so the pointers taken below stay valid; nothing after
        // this point reallocates `mat.values`.
        let lde_elems = w * (h << added_bits);
        let elems_to_add = lde_elems - w * h;
        debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));

        let g_big = F::two_adic_generator(log_h + added_bits);

        let mat_ptr = mat.values.as_mut_ptr();
        // SAFETY: `reserve_exact(elems_to_add)` guarantees capacity for at least
        // `len + elems_to_add` elements. Assuming `mat.values.len() == w * h` (presumably a
        // `RowMajorMatrix` invariant; TODO confirm), offset `w * h` is within the allocation.
        let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
        // SAFETY: covers `[0, w * h)`, the already-initialized prefix (same length assumption
        // as above).
        let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
        // SAFETY: covers `[w * h, lde_elems)`, i.e. the reserved spare capacity. It does not
        // overlap `first_slice`, and it is typed as `MaybeUninit<F>` because it is not yet
        // initialized.
        let rest_slice: &mut [MaybeUninit<F>] =
            unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
        let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
        // One `h`-row destination matrix per additional coset.
        let mut rest_cosets_mat = rest_slice
            .chunks_exact_mut(w * h)
            .map(|slice| RowMajorMatrixViewMut::new(slice, w))
            .collect_vec();

        for coset_idx in 1..(1 << added_bits) {
            let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
            // Cosets are laid out in bit-reversed order of their index.
            let coset_idx = reverse_bits_len(coset_idx, added_bits);
            let dest = &mut rest_cosets_mat[coset_idx - 1]; // - 1 because we removed the first matrix.
            coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
        }

        // Now run a forward DFT on the very first coset, this time in-place.
        coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);

        // SAFETY: We wrote all values above.
        // This relies on the loop above targeting every chunk of `rest_cosets_mat` exactly once
        // (presumably `reverse_bits_len` permutes `1..2^added_bits`; confirm) and on
        // `coset_dft_oop` initializing its whole destination.
        unsafe {
            mat.values.set_len(lde_elems);
        }
        BitReversalPerm::new_view(mat)
    }
}
264
265#[instrument(level = "debug", skip_all)]
266fn coset_dft<F: TwoAdicField + Ord>(
267    dft: &Radix2DitParallel<F>,
268    mat: &mut RowMajorMatrixViewMut<'_, F>,
269    shift: F,
270) {
271    let log_h = log2_strict_usize(mat.height());
272    let mid = log_h.div_ceil(2);
273
274    // let twiddles = compute_factors((log_h, shift), &dft.coset_twiddles, compute_coset_twiddles);
275    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
276
277    // The first half looks like a normal DIT.
278    first_half_general(mat, mid, &twiddles);
279
280    // For the second half, we flip the DIT, working in bit-reversed order.
281    reverse_matrix_index_bits(mat);
282
283    second_half_general(mat, mid, &twiddles);
284}
285
/// Like `coset_dft`, except out-of-place.
///
/// Reads from `src` and writes the transformed result into `dst_maybe`, which may be
/// uninitialized on entry. Only the first butterfly layer is out-of-place; every later step
/// operates in place on the destination.
///
/// # Panics
///
/// Panics if `src` and `dst_maybe` have different dimensions.
#[instrument(level = "debug", skip_all)]
fn coset_dft_oop<F: TwoAdicField + Ord>(
    dft: &Radix2DitParallel<F>,
    src: &RowMajorMatrixView<'_, F>,
    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
    shift: F,
) {
    assert_eq!(src.dimensions(), dst_maybe.dimensions());

    let log_h = log2_strict_usize(dst_maybe.height());

    if log_h == 0 {
        // This is an edge case where first_half_general_oop doesn't work, as it expects there to be
        // at least one layer in the network, so we just copy instead.
        // SAFETY: viewing initialized `F` values as `MaybeUninit<F>` is always valid for the
        // elements themselves. NOTE(review): this additionally assumes the view type has the
        // same layout for `F` and `MaybeUninit<F>` — confirm against its definition.
        let src_maybe = unsafe {
            transmute::<&RowMajorMatrixView<'_, F>, &RowMajorMatrixView<'_, MaybeUninit<F>>>(src)
        };
        dst_maybe.copy_from(src_maybe);
        return;
    }

    let mid = log_h.div_ceil(2);

    let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));

    // The first half looks like a normal DIT.
    first_half_general_oop(src, dst_maybe, mid, &twiddles);

    // dst is now initialized.
    // SAFETY: relies on `first_half_general_oop` having written every element of `dst_maybe`
    // (presumably via `apply_to_rows_oop` — confirm), and on the same view-layout assumption as
    // the transmute above.
    let dst = unsafe {
        transmute::<&mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>, &mut RowMajorMatrixViewMut<'_, F>>(
            dst_maybe,
        )
    };

    // For the second half, we flip the DIT, working in bit-reversed order.
    reverse_matrix_index_bits(dst);

    second_half_general(dst, mid, &twiddles);
}
327
328/// This can be used as the first half of a DIT butterfly network.
329#[instrument(level = "debug", skip_all)]
330fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
331    let log_h = log2_strict_usize(mat.height());
332
333    // max block size: 2^mid
334    mat.par_row_chunks_exact_mut(1 << mid)
335        .for_each(|mut submat| {
336            let mut backwards = false;
337            for layer in 0..mid {
338                let layer_rev = log_h - 1 - layer;
339                let layer_pow = 1 << layer_rev;
340                dit_layer(
341                    &mut submat,
342                    layer,
343                    twiddles.iter().step_by(layer_pow),
344                    backwards,
345                );
346                backwards = !backwards;
347            }
348        });
349}
350
351/// Like `first_half`, except supporting different twiddle factors per layer, enabling coset shifts
352/// to be baked into them.
353#[instrument(level = "debug", skip_all)]
354fn first_half_general<F: Field>(
355    mat: &mut RowMajorMatrixViewMut<'_, F>,
356    mid: usize,
357    twiddles: &[Vec<F>],
358) {
359    let log_h = log2_strict_usize(mat.height());
360    mat.par_row_chunks_exact_mut(1 << mid)
361        .for_each(|mut submat| {
362            let mut backwards = false;
363            for layer in 0..mid {
364                let layer_rev = log_h - 1 - layer;
365                dit_layer(&mut submat, layer, twiddles[layer_rev].iter(), backwards);
366                backwards = !backwards;
367            }
368        });
369}
370
/// Like `first_half_general`, except out-of-place.
///
/// Layer 0 reads from `src` and writes into `dst_maybe`; layers `1..mid` then run in place on
/// the destination.
///
/// Assumes there's at least one layer in the network, i.e. `src.height() > 1`.
/// Undefined behavior otherwise.
///
/// NOTE(review): with `src.height() == 1`, `log_h` is 0 and the `log_h - 1` below underflows.
#[instrument(level = "debug", skip_all)]
fn first_half_general_oop<F: Field>(
    src: &RowMajorMatrixView<'_, F>,
    dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
    mid: usize,
    twiddles: &[Vec<F>],
) {
    let log_h = log2_strict_usize(src.height());
    // Pair up matching `2^mid`-row chunks of source and destination and process them in parallel.
    src.par_row_chunks_exact(1 << mid)
        .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
        .for_each(|(src_submat, mut dst_submat_maybe)| {
            debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());

            // The first layer is special, done out-of-place.
            // (Recall from the mid definition that there must be at least one layer here.)
            let layer_rev = log_h - 1;
            dit_layer_oop(
                &src_submat,
                &mut dst_submat_maybe,
                0,
                twiddles[layer_rev].iter(),
            );

            // submat is now initialized.
            // SAFETY: relies on `dit_layer_oop` having written every element of this chunk
            // (presumably true when enough twiddles are supplied — confirm).
            // NOTE(review): also assumes the view type has the same layout for `F` and
            // `MaybeUninit<F>`.
            let mut dst_submat = unsafe {
                transmute::<RowMajorMatrixViewMut<'_, MaybeUninit<F>>, RowMajorMatrixViewMut<'_, F>>(
                    dst_submat_maybe,
                )
            };

            // Subsequent layers.
            // Layer 0 ran forwards, so layer 1 starts backwards and the direction alternates.
            let mut backwards = true;
            for layer in 1..mid {
                let layer_rev = log_h - 1 - layer;
                dit_layer(
                    &mut dst_submat,
                    layer,
                    twiddles[layer_rev].iter(),
                    backwards,
                );
                backwards = !backwards;
            }
        });
}
419
420/// This can be used as the second half of a DIT butterfly network. It works in bit-reversed order.
421///
422/// The optional `scale` parameter is used to scale the matrix by a constant factor. Normally that
423/// would be a separate step, but it's best to merge it into a butterfly network to avoid a
424/// separate pass through main memory.
425#[instrument(level = "debug", skip_all)]
426#[inline(always)] // To avoid branch on scale
427fn second_half<F: Field>(
428    mat: &mut RowMajorMatrix<F>,
429    mid: usize,
430    twiddles_rev: &[F],
431    scale: Option<F>,
432) {
433    let log_h = log2_strict_usize(mat.height());
434
435    // max block size: 2^(log_h - mid)
436    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
437        .enumerate()
438        .for_each(|(thread, mut submat)| {
439            let mut backwards = false;
440            if let Some(scale) = scale {
441                submat.scale(scale);
442            }
443            for layer in mid..log_h {
444                let first_block = thread << (layer - mid);
445                dit_layer_rev(
446                    &mut submat,
447                    log_h,
448                    layer,
449                    twiddles_rev[first_block..].iter().copied(),
450                    backwards,
451                );
452                backwards = !backwards;
453            }
454        });
455}
456
457/// Like `second_half`, except supporting different twiddle factors per layer, enabling coset shifts
458/// to be baked into them.
459#[instrument(level = "debug", skip_all)]
460fn second_half_general<F: Field>(
461    mat: &mut RowMajorMatrixViewMut<'_, F>,
462    mid: usize,
463    twiddles_rev: &[Vec<F>],
464) {
465    let log_h = log2_strict_usize(mat.height());
466    mat.par_row_chunks_exact_mut(1 << (log_h - mid))
467        .enumerate()
468        .for_each(|(thread, mut submat)| {
469            let mut backwards = false;
470            for layer in mid..log_h {
471                let layer_rev = log_h - 1 - layer;
472                let first_block = thread << (layer - mid);
473                dit_layer_rev(
474                    &mut submat,
475                    log_h,
476                    layer,
477                    twiddles_rev[layer_rev][first_block..].iter().copied(),
478                    backwards,
479                );
480                backwards = !backwards;
481            }
482        });
483}
484
485/// One layer of a DIT butterfly network.
486fn dit_layer<'a, F: Field>(
487    submat: &mut RowMajorMatrixViewMut<'_, F>,
488    layer: usize,
489    twiddles: impl Iterator<Item = &'a F> + Clone,
490    backwards: bool,
491) {
492    let half_block_size = 1 << layer;
493    let block_size = half_block_size * 2;
494    let width = submat.width();
495    debug_assert!(submat.height() >= block_size);
496
497    let process_block = move |block: &mut [F]| {
498        let (lows, highs) = block.split_at_mut(half_block_size * width);
499        for (lo, hi, twiddle) in izip!(
500            lows.chunks_mut(width),
501            highs.chunks_mut(width),
502            twiddles.clone()
503        ) {
504            DitButterfly(*twiddle).apply_to_rows(lo, hi);
505        }
506    };
507
508    let blocks = submat.values.chunks_mut(block_size * width);
509    if backwards {
510        for block in blocks.rev() {
511            process_block(block);
512        }
513    } else {
514        for block in blocks {
515            process_block(block);
516        }
517    }
518}
519
520/// One layer of a DIT butterfly network.
521fn dit_layer_oop<'a, F: Field>(
522    src: &RowMajorMatrixView<'_, F>,
523    dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
524    layer: usize,
525    twiddles: impl Iterator<Item = &'a F> + Clone,
526) {
527    debug_assert_eq!(src.dimensions(), dst.dimensions());
528    let half_block_size = 1 << layer;
529    let block_size = half_block_size * 2;
530    let width = dst.width();
531    debug_assert!(dst.height() >= block_size);
532
533    let process_blocks = move |src_block: &[F], dst_block: &mut [MaybeUninit<F>]| {
534        let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
535        let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
536
537        for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
538            src_lows.chunks(width),
539            dst_lows.chunks_mut(width),
540            src_highs.chunks(width),
541            dst_highs.chunks_mut(width),
542            twiddles.clone()
543        ) {
544            DitButterfly(*twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
545        }
546    };
547
548    let src_chunks = src.values.chunks(block_size * width);
549    let dst_chunks = dst.values.chunks_mut(block_size * width);
550
551    for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
552        process_blocks(src_block, dst_block);
553    }
554}
555
556/// Like `dit_layer`, except the matrix and twiddles are encoded in bit-reversed order.
557/// This can also be viewed as a layer of the Bowers G^T network.
558fn dit_layer_rev<F: Field>(
559    submat: &mut RowMajorMatrixViewMut<'_, F>,
560    log_h: usize,
561    layer: usize,
562    twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
563    backwards: bool,
564) {
565    let layer_rev = log_h - 1 - layer;
566
567    let half_block_size = 1 << layer_rev;
568    let block_size = half_block_size * 2;
569    let width = submat.width();
570    debug_assert!(submat.height() >= block_size);
571
572    let blocks_and_twiddles = submat
573        .values
574        .chunks_mut(block_size * width)
575        .zip(twiddles_rev);
576    if backwards {
577        for (block, twiddle) in blocks_and_twiddles.rev() {
578            let (lo, hi) = block.split_at_mut(half_block_size * width);
579            DitButterfly(twiddle).apply_to_rows(lo, hi);
580        }
581    } else {
582        for (block, twiddle) in blocks_and_twiddles {
583            let (lo, hi) = block.split_at_mut(half_block_size * width);
584            DitButterfly(twiddle).apply_to_rows(lo, hi);
585        }
586    }
587}