1use alloc::collections::BTreeMap;
2use alloc::slice;
3use alloc::sync::Arc;
4use alloc::vec::Vec;
5use core::mem::{
6 MaybeUninit,
7 transmute,
8};
9
10use itertools::{
11 Itertools,
12 izip,
13};
14use lib_q_stark_field::integers::QuotientMap;
15use lib_q_stark_field::{
16 Field,
17 Powers,
18 TwoAdicField,
19 assert_two_adic_bits,
20 assert_two_adic_fft_height,
21};
22use lib_q_stark_matrix::Matrix;
23use lib_q_stark_matrix::bitrev::{
24 BitReversalPerm,
25 BitReversedMatrixView,
26 BitReversibleMatrix,
27};
28use lib_q_stark_matrix::dense::{
29 RowMajorMatrix,
30 RowMajorMatrixView,
31 RowMajorMatrixViewMut,
32};
33use lib_q_stark_matrix::util::reverse_matrix_index_bits;
34use lib_q_stark_rayon::prelude::*;
35use lib_q_stark_util::{
36 log2_strict_usize,
37 reverse_bits_len,
38 reverse_slice_index_bits,
39};
40use spin::RwLock;
41use tracing::{
42 debug_span,
43 instrument,
44};
45
46use crate::TwoAdicSubgroupDft;
47use crate::butterflies::{
48 Butterfly,
49 DitButterfly,
50};
51
52#[derive(Default, Clone, Debug)]
60pub struct Radix2DitParallel<F> {
61 twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
63
64 #[allow(clippy::type_complexity)]
66 coset_twiddles: Arc<RwLock<BTreeMap<(usize, F), Arc<[Vec<F>]>>>>,
67
68 inverse_twiddles: Arc<RwLock<BTreeMap<usize, Arc<VectorPair<F>>>>>,
70}
71
/// One table of twiddle factors stored in two index orders.
#[derive(Clone, Debug, Default)]
struct VectorPair<F> {
    /// The twiddle factors in the order they were generated.
    twiddles: Vec<F>,
    /// The same twiddle factors with their indices bit-reversed.
    bitrev_twiddles: Vec<F>,
}
78
79impl<F> Radix2DitParallel<F>
80where
81 F: TwoAdicField + Ord,
82{
83 fn get_or_compute_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
84 if let Some(pair) = self.twiddles.read().get(&log_h) {
86 return pair.clone();
87 }
88
89 let mut w_lock = self.twiddles.write();
91
92 w_lock
94 .entry(log_h)
95 .or_insert_with(|| {
96 assert_two_adic_bits::<F>(log_h);
97 let half_h = (1 << log_h) >> 1;
98 let root = F::two_adic_generator(log_h);
99 let twiddles = root.powers().collect_n(half_h);
100 let mut bitrev_twiddles = twiddles.clone();
101 reverse_slice_index_bits(&mut bitrev_twiddles);
102
103 Arc::new(VectorPair {
104 twiddles,
105 bitrev_twiddles,
106 })
107 })
108 .clone()
109 }
110
111 fn get_or_compute_coset_twiddles(&self, (log_h, shift): (usize, F)) -> Arc<[Vec<F>]> {
112 let key = (log_h, shift);
113 if let Some(twiddles) = self.coset_twiddles.read().get(&key) {
115 return twiddles.clone();
116 }
117 let mut w_lock = self.coset_twiddles.write();
120 w_lock
123 .entry(key)
124 .or_insert_with(|| {
125 assert_two_adic_bits::<F>(log_h);
126 let mid = log_h.div_ceil(2);
127 let h = 1 << log_h;
128 let root = F::two_adic_generator(log_h);
129 (0..log_h)
130 .map(|layer| {
131 let shift_power = shift.exp_power_of_2(layer);
132 let powers = Powers {
133 base: root.exp_power_of_2(layer),
134 current: shift_power,
135 };
136 let mut twiddles = powers.collect_n(h >> (layer + 1));
137 let layer_rev = log_h - 1 - layer;
138 if layer_rev >= mid {
139 reverse_slice_index_bits(&mut twiddles);
140 }
141 twiddles
142 })
143 .collect::<Vec<_>>()
144 .into()
145 })
146 .clone()
147 }
148
149 fn get_or_compute_inverse_twiddles(&self, log_h: usize) -> Arc<VectorPair<F>> {
150 if let Some(pair) = self.inverse_twiddles.read().get(&log_h) {
152 return pair.clone();
153 }
154 let mut w_lock = self.inverse_twiddles.write();
156 w_lock
159 .entry(log_h)
160 .or_insert_with(|| {
161 assert_two_adic_bits::<F>(log_h);
162 let half_h = (1 << log_h) >> 1;
164 let root_inv = F::two_adic_generator(log_h).inverse();
165 let twiddles = root_inv.powers().collect_n(half_h);
166 let mut bitrev_twiddles = twiddles.clone();
167 reverse_slice_index_bits(&mut bitrev_twiddles);
168
169 Arc::new(VectorPair {
170 twiddles,
171 bitrev_twiddles,
172 })
173 })
174 .clone()
175 }
176}
177
178impl<F: TwoAdicField + Ord> TwoAdicSubgroupDft<F> for Radix2DitParallel<F> {
179 type Evaluations = BitReversedMatrixView<RowMajorMatrix<F>>;
180
181 fn dft_batch(&self, mut mat: RowMajorMatrix<F>) -> Self::Evaluations {
182 assert_two_adic_fft_height::<F>(mat.height());
183 let h = mat.height();
184 let log_h = log2_strict_usize(h);
185
186 let twiddles = self.get_or_compute_twiddles(log_h);
188
189 let mid = log_h.div_ceil(2);
190
191 reverse_matrix_index_bits(&mut mat);
193 first_half(&mut mat, mid, &twiddles.twiddles);
194
195 reverse_matrix_index_bits(&mut mat);
197 second_half(&mut mat, mid, &twiddles.bitrev_twiddles, None);
198
199 mat.bit_reverse_rows()
200 }
201
202 #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits = added_bits))]
203 fn coset_lde_batch(
204 &self,
205 mut mat: RowMajorMatrix<F>,
206 added_bits: usize,
207 shift: F,
208 ) -> Self::Evaluations {
209 let w = mat.width;
210 let h = mat.height();
211 let log_h = log2_strict_usize(h);
212 let mid = log_h.div_ceil(2);
213
214 let inverse_twiddles = self.get_or_compute_inverse_twiddles(log_h);
215
216 reverse_matrix_index_bits(&mut mat);
218 first_half(&mut mat, mid, &inverse_twiddles.twiddles);
219
220 reverse_matrix_index_bits(&mut mat);
222 let h_inv_subfield = F::PrimeSubfield::from_int(h).try_inverse();
226 let scale = h_inv_subfield.map(F::from_prime_subfield);
227 second_half(&mut mat, mid, &inverse_twiddles.bitrev_twiddles, scale);
228 let lde_elems = w * (h << added_bits);
231 let elems_to_add = lde_elems - w * h;
232 debug_span!("reserve_exact").in_scope(|| mat.values.reserve_exact(elems_to_add));
233
234 let g_big = F::two_adic_generator(log_h + added_bits);
235
236 let mat_ptr = mat.values.as_mut_ptr();
237 let rest_ptr = unsafe { (mat_ptr as *mut MaybeUninit<F>).add(w * h) };
238 let first_slice: &mut [F] = unsafe { slice::from_raw_parts_mut(mat_ptr, w * h) };
239 let rest_slice: &mut [MaybeUninit<F>] =
240 unsafe { slice::from_raw_parts_mut(rest_ptr, lde_elems - w * h) };
241 let mut first_coset_mat = RowMajorMatrixViewMut::new(first_slice, w);
242 let mut rest_cosets_mat = rest_slice
243 .chunks_exact_mut(w * h)
244 .map(|slice| RowMajorMatrixViewMut::new(slice, w))
245 .collect_vec();
246
247 for coset_idx in 1..(1 << added_bits) {
248 let total_shift = g_big.exp_u64(coset_idx as u64) * shift;
249 let coset_idx = reverse_bits_len(coset_idx, added_bits);
250 let dest = &mut rest_cosets_mat[coset_idx - 1]; coset_dft_oop(self, &first_coset_mat.as_view(), dest, total_shift);
252 }
253
254 coset_dft(self, &mut first_coset_mat.as_view_mut(), shift);
256
257 unsafe {
259 mat.values.set_len(lde_elems);
260 }
261 BitReversalPerm::new_view(mat)
262 }
263}
264
265#[instrument(level = "debug", skip_all)]
266fn coset_dft<F: TwoAdicField + Ord>(
267 dft: &Radix2DitParallel<F>,
268 mat: &mut RowMajorMatrixViewMut<'_, F>,
269 shift: F,
270) {
271 let log_h = log2_strict_usize(mat.height());
272 let mid = log_h.div_ceil(2);
273
274 let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
276
277 first_half_general(mat, mid, &twiddles);
279
280 reverse_matrix_index_bits(mat);
282
283 second_half_general(mat, mid, &twiddles);
284}
285
286#[instrument(level = "debug", skip_all)]
288fn coset_dft_oop<F: TwoAdicField + Ord>(
289 dft: &Radix2DitParallel<F>,
290 src: &RowMajorMatrixView<'_, F>,
291 dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
292 shift: F,
293) {
294 assert_eq!(src.dimensions(), dst_maybe.dimensions());
295
296 let log_h = log2_strict_usize(dst_maybe.height());
297
298 if log_h == 0 {
299 let src_maybe = unsafe {
302 transmute::<&RowMajorMatrixView<'_, F>, &RowMajorMatrixView<'_, MaybeUninit<F>>>(src)
303 };
304 dst_maybe.copy_from(src_maybe);
305 return;
306 }
307
308 let mid = log_h.div_ceil(2);
309
310 let twiddles = dft.get_or_compute_coset_twiddles((log_h, shift));
311
312 first_half_general_oop(src, dst_maybe, mid, &twiddles);
314
315 let dst = unsafe {
317 transmute::<&mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>, &mut RowMajorMatrixViewMut<'_, F>>(
318 dst_maybe,
319 )
320 };
321
322 reverse_matrix_index_bits(dst);
324
325 second_half_general(dst, mid, &twiddles);
326}
327
328#[instrument(level = "debug", skip_all)]
330fn first_half<F: Field>(mat: &mut RowMajorMatrix<F>, mid: usize, twiddles: &[F]) {
331 let log_h = log2_strict_usize(mat.height());
332
333 mat.par_row_chunks_exact_mut(1 << mid)
335 .for_each(|mut submat| {
336 let mut backwards = false;
337 for layer in 0..mid {
338 let layer_rev = log_h - 1 - layer;
339 let layer_pow = 1 << layer_rev;
340 dit_layer(
341 &mut submat,
342 layer,
343 twiddles.iter().step_by(layer_pow),
344 backwards,
345 );
346 backwards = !backwards;
347 }
348 });
349}
350
351#[instrument(level = "debug", skip_all)]
354fn first_half_general<F: Field>(
355 mat: &mut RowMajorMatrixViewMut<'_, F>,
356 mid: usize,
357 twiddles: &[Vec<F>],
358) {
359 let log_h = log2_strict_usize(mat.height());
360 mat.par_row_chunks_exact_mut(1 << mid)
361 .for_each(|mut submat| {
362 let mut backwards = false;
363 for layer in 0..mid {
364 let layer_rev = log_h - 1 - layer;
365 dit_layer(&mut submat, layer, twiddles[layer_rev].iter(), backwards);
366 backwards = !backwards;
367 }
368 });
369}
370
371#[instrument(level = "debug", skip_all)]
376fn first_half_general_oop<F: Field>(
377 src: &RowMajorMatrixView<'_, F>,
378 dst_maybe: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
379 mid: usize,
380 twiddles: &[Vec<F>],
381) {
382 let log_h = log2_strict_usize(src.height());
383 src.par_row_chunks_exact(1 << mid)
384 .zip(dst_maybe.par_row_chunks_exact_mut(1 << mid))
385 .for_each(|(src_submat, mut dst_submat_maybe)| {
386 debug_assert_eq!(src_submat.dimensions(), dst_submat_maybe.dimensions());
387
388 let layer_rev = log_h - 1;
391 dit_layer_oop(
392 &src_submat,
393 &mut dst_submat_maybe,
394 0,
395 twiddles[layer_rev].iter(),
396 );
397
398 let mut dst_submat = unsafe {
400 transmute::<RowMajorMatrixViewMut<'_, MaybeUninit<F>>, RowMajorMatrixViewMut<'_, F>>(
401 dst_submat_maybe,
402 )
403 };
404
405 let mut backwards = true;
407 for layer in 1..mid {
408 let layer_rev = log_h - 1 - layer;
409 dit_layer(
410 &mut dst_submat,
411 layer,
412 twiddles[layer_rev].iter(),
413 backwards,
414 );
415 backwards = !backwards;
416 }
417 });
418}
419
420#[instrument(level = "debug", skip_all)]
426#[inline(always)] fn second_half<F: Field>(
428 mat: &mut RowMajorMatrix<F>,
429 mid: usize,
430 twiddles_rev: &[F],
431 scale: Option<F>,
432) {
433 let log_h = log2_strict_usize(mat.height());
434
435 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
437 .enumerate()
438 .for_each(|(thread, mut submat)| {
439 let mut backwards = false;
440 if let Some(scale) = scale {
441 submat.scale(scale);
442 }
443 for layer in mid..log_h {
444 let first_block = thread << (layer - mid);
445 dit_layer_rev(
446 &mut submat,
447 log_h,
448 layer,
449 twiddles_rev[first_block..].iter().copied(),
450 backwards,
451 );
452 backwards = !backwards;
453 }
454 });
455}
456
457#[instrument(level = "debug", skip_all)]
460fn second_half_general<F: Field>(
461 mat: &mut RowMajorMatrixViewMut<'_, F>,
462 mid: usize,
463 twiddles_rev: &[Vec<F>],
464) {
465 let log_h = log2_strict_usize(mat.height());
466 mat.par_row_chunks_exact_mut(1 << (log_h - mid))
467 .enumerate()
468 .for_each(|(thread, mut submat)| {
469 let mut backwards = false;
470 for layer in mid..log_h {
471 let layer_rev = log_h - 1 - layer;
472 let first_block = thread << (layer - mid);
473 dit_layer_rev(
474 &mut submat,
475 log_h,
476 layer,
477 twiddles_rev[layer_rev][first_block..].iter().copied(),
478 backwards,
479 );
480 backwards = !backwards;
481 }
482 });
483}
484
485fn dit_layer<'a, F: Field>(
487 submat: &mut RowMajorMatrixViewMut<'_, F>,
488 layer: usize,
489 twiddles: impl Iterator<Item = &'a F> + Clone,
490 backwards: bool,
491) {
492 let half_block_size = 1 << layer;
493 let block_size = half_block_size * 2;
494 let width = submat.width();
495 debug_assert!(submat.height() >= block_size);
496
497 let process_block = move |block: &mut [F]| {
498 let (lows, highs) = block.split_at_mut(half_block_size * width);
499 for (lo, hi, twiddle) in izip!(
500 lows.chunks_mut(width),
501 highs.chunks_mut(width),
502 twiddles.clone()
503 ) {
504 DitButterfly(*twiddle).apply_to_rows(lo, hi);
505 }
506 };
507
508 let blocks = submat.values.chunks_mut(block_size * width);
509 if backwards {
510 for block in blocks.rev() {
511 process_block(block);
512 }
513 } else {
514 for block in blocks {
515 process_block(block);
516 }
517 }
518}
519
520fn dit_layer_oop<'a, F: Field>(
522 src: &RowMajorMatrixView<'_, F>,
523 dst: &mut RowMajorMatrixViewMut<'_, MaybeUninit<F>>,
524 layer: usize,
525 twiddles: impl Iterator<Item = &'a F> + Clone,
526) {
527 debug_assert_eq!(src.dimensions(), dst.dimensions());
528 let half_block_size = 1 << layer;
529 let block_size = half_block_size * 2;
530 let width = dst.width();
531 debug_assert!(dst.height() >= block_size);
532
533 let process_blocks = move |src_block: &[F], dst_block: &mut [MaybeUninit<F>]| {
534 let (src_lows, src_highs) = src_block.split_at(half_block_size * width);
535 let (dst_lows, dst_highs) = dst_block.split_at_mut(half_block_size * width);
536
537 for (src_lo, dst_lo, src_hi, dst_hi, twiddle) in izip!(
538 src_lows.chunks(width),
539 dst_lows.chunks_mut(width),
540 src_highs.chunks(width),
541 dst_highs.chunks_mut(width),
542 twiddles.clone()
543 ) {
544 DitButterfly(*twiddle).apply_to_rows_oop(src_lo, dst_lo, src_hi, dst_hi);
545 }
546 };
547
548 let src_chunks = src.values.chunks(block_size * width);
549 let dst_chunks = dst.values.chunks_mut(block_size * width);
550
551 for (src_block, dst_block) in src_chunks.zip(dst_chunks) {
552 process_blocks(src_block, dst_block);
553 }
554}
555
556fn dit_layer_rev<F: Field>(
559 submat: &mut RowMajorMatrixViewMut<'_, F>,
560 log_h: usize,
561 layer: usize,
562 twiddles_rev: impl DoubleEndedIterator<Item = F> + ExactSizeIterator,
563 backwards: bool,
564) {
565 let layer_rev = log_h - 1 - layer;
566
567 let half_block_size = 1 << layer_rev;
568 let block_size = half_block_size * 2;
569 let width = submat.width();
570 debug_assert!(submat.height() >= block_size);
571
572 let blocks_and_twiddles = submat
573 .values
574 .chunks_mut(block_size * width)
575 .zip(twiddles_rev);
576 if backwards {
577 for (block, twiddle) in blocks_and_twiddles.rev() {
578 let (lo, hi) = block.split_at_mut(half_block_size * width);
579 DitButterfly(twiddle).apply_to_rows(lo, hi);
580 }
581 } else {
582 for (block, twiddle) in blocks_and_twiddles {
583 let (lo, hi) = block.split_at_mut(half_block_size * width);
584 DitButterfly(twiddle).apply_to_rows(lo, hi);
585 }
586 }
587}