1use crate::domain::DomainCoeff;
2use ark_ff::{FftField, Field};
3use ark_std::vec::*;
4#[cfg(feature = "parallel")]
5use rayon::prelude::*;
6
/// Smallest number of powers that `compute_powers` hands to a single rayon task.
///
/// Requests shorter than this are computed serially, and every parallel chunk
/// holds at least this many elements (except possibly the last one).
// `unused` is allowed because the only visible user, `compute_powers`, is itself
// marked `#[allow(unused)]`.
#[allow(unused)]
#[cfg(feature = "parallel")]
const MIN_PARALLEL_CHUNK_SIZE: usize = 128;
11
/// Returns the lowest `l` bits of `n` in reversed order.
///
/// Bits of `n` at positions `>= l` are never read; `l == 0` yields `0`.
/// Intended for `l <= 32`. For larger `l`, the extra iterations only shift the
/// accumulated result further left, so high bits fall off.
#[inline]
pub(crate) fn bitreverse(n: u32, l: u32) -> u32 {
    // State is (reversed-so-far, remaining-input); each step moves the lowest
    // remaining bit of the input onto the bottom of the accumulator.
    let (reversed, _) = (0..l).fold((0u32, n), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
21
22#[inline]
23pub fn bitreverse_permutation_in_place<T>(a: &mut [T], width: u32) {
24 let n = a.len();
26 for k in 0..n {
27 let rk = bitreverse(k as u32, width) as usize;
28 if k < rk {
29 a.swap(k, rk);
30 }
31 }
32}
33
34pub(crate) fn compute_powers_serial<F: Field>(size: usize, root: F) -> Vec<F> {
35 compute_powers_and_mul_by_const_serial(size, root, F::one())
36}
37
38pub(crate) fn compute_powers_and_mul_by_const_serial<F: Field>(
39 size: usize,
40 root: F,
41 c: F,
42) -> Vec<F> {
43 let mut value = c;
44 (0..size)
45 .map(|_| {
46 let old_value = value;
47 value *= root;
48 old_value
49 })
50 .collect()
51}
52
/// Returns `[1, g, g^2, ..., g^(size - 1)]`, splitting the work across rayon tasks.
///
/// Requests smaller than `MIN_PARALLEL_CHUNK_SIZE` fall back to
/// `compute_powers_serial`. Otherwise the range is cut into chunks of
/// `max(size / threads, MIN_PARALLEL_CHUNK_SIZE)` elements. Each chunk starts from
/// `g^start` (one `pow`) and then multiplies sequentially. Results are collected
/// in chunk order.
// `unused` is allowed as in the original; this may not be called in every
// feature configuration — NOTE(review): confirm whether it is still needed.
#[allow(unused)]
#[cfg(feature = "parallel")]
pub(crate) fn compute_powers<F: Field>(size: usize, g: F) -> Vec<F> {
    if size < MIN_PARALLEL_CHUNK_SIZE {
        return compute_powers_serial(size, g);
    }
    use ark_std::cmp::{max, min};

    let num_cpus_available = rayon::current_num_threads();
    let num_elem_per_thread = max(size / num_cpus_available, MIN_PARALLEL_CHUNK_SIZE);
    // Round UP: a trailing partial chunk must still be scheduled. The previous
    // floor division dropped it, e.g. size = 300 with 4 threads (chunks of 128)
    // produced only 256 powers.
    let num_chunks = (size + num_elem_per_thread - 1) / num_elem_per_thread;

    (0..num_chunks)
        .into_par_iter()
        .flat_map(|i| {
            // `start < size` for every chunk index, so `size - start` cannot underflow.
            let start = i * num_elem_per_thread;
            let offset = g.pow([start as u64]);
            // Full chunk, or whatever remains for the last one.
            let len = min(size - start, num_elem_per_thread);
            compute_powers_and_mul_by_const_serial(len, g, offset)
        })
        .collect()
}
78
/// Returns `floor(log2(num))`, treating `log2(0)` as `0`.
#[cfg(feature = "parallel")]
const fn log2_floor(num: usize) -> u32 {
    match num {
        0 => 0,
        // `1usize.leading_zeros()` is the index of the top bit position, so the
        // difference is the index of `n`'s highest set bit.
        n => 1usize.leading_zeros() - n.leading_zeros(),
    }
}
87
/// Runs `serial_fft` directly, or dispatches to `parallel_fft`.
///
/// The choice depends on whether the transform is larger than the number of
/// rayon threads, where the thread count is rounded down to a power of two.
#[cfg(feature = "parallel")]
pub(crate) fn best_fft<T: DomainCoeff<F>, F: FftField>(
    a: &mut [T],
    omega: F,
    log_n: u32,
    serial_fft: fn(&mut [T], F, u32),
) {
    let log_cpus = log2_floor(rayon::current_num_threads());
    // Only split when there are more than 2^log_cpus points; otherwise each
    // task would get at most one point and splitting cannot help.
    if log_n > log_cpus {
        parallel_fft(a, omega, log_n, log_cpus, serial_fft);
    } else {
        serial_fft(a, omega, log_n);
    }
}
103
104#[cfg(not(feature = "parallel"))]
105#[inline]
106pub(crate) fn best_fft<T: DomainCoeff<F>, F: FftField>(
107 a: &mut [T],
108 omega: F,
109 log_n: u32,
110 serial_fft: fn(&mut [T], F, u32),
111) {
112 serial_fft(a, omega, log_n)
113}
114
/// Transforms `a` in place by splitting the work into `2^log_cpus` cosets that
/// are processed on rayon tasks.
///
/// For each coset `k`, a length-`coset_size` coefficient vector is built from
/// all of `a`. That vector is then passed to `serial_fft` with root
/// `omega^{num_cosets}`. Finally the per-coset outputs are interleaved back into `a`.
///
/// `log_n` is only checked against `log_cpus` here; the sub-transform size is
/// derived from `a.len()`.
///
/// NOTE(review): the running-twiddle bookkeeping below is only correct if
/// `omega^{a.len()} == 1` (presumably `omega` is a primitive `a.len()`-th root
/// of unity supplied by the domain) — confirm with callers.
///
/// # Panics
/// - if `log_n < log_cpus`;
/// - if `a.len()` is not a multiple of `2^log_cpus`.
#[cfg(feature = "parallel")]
pub(crate) fn parallel_fft<T: DomainCoeff<F>, F: FftField>(
    a: &mut [T],
    omega: F,
    log_n: u32,
    log_cpus: u32,
    serial_fft: fn(&mut [T], F, u32),
) {
    assert!(log_n >= log_cpus);
    let m = a.len();
    let num_threads = 1 << (log_cpus as usize);
    // One coset (and one rayon task) per thread.
    let num_cosets = num_threads;
    assert_eq!(m % num_threads, 0);
    let coset_size = m / num_threads;

    // Results are built out of place, one vector per coset, and copied back
    // into `a` at the end.
    let mut tmp = vec![vec![T::zero(); coset_size]; num_cosets];
    // Root handed to every length-`coset_size` sub-transform.
    let new_omega = omega.pow([num_cosets as u64]);
    // Passed to `serial_fft` as its log-size argument; presumably log2(coset_size)
    // when coset_size is a power of two — TODO confirm k_adicity semantics.
    let new_two_adicity = ark_ff::utils::k_adicity(2, coset_size as u64);

    tmp.par_iter_mut()
        .enumerate()
        .for_each(|(k, kth_poly_coeffs)| {
            // omega^k: moves the twiddle from row i to row i + 1 of coset k.
            let omega_k = omega.pow([k as u64]);
            // omega^{k * coset_size}: moves from chunk c to chunk c + 1 within a row.
            let omega_step = omega.pow([(k * coset_size) as u64]);

            // Running twiddle, intended to equal omega^{k * (i + c * coset_size)} at
            // each multiply below. After the inner loop it has picked up an extra
            // omega^{k * m}, which vanishes only when omega^m == 1.
            let mut elt = F::one();
            // Builds
            //   kth_poly_coeffs[i] = sum_c omega^{k * (i + c * coset_size)} * a[i + c * coset_size]
            // using 2 field multiplications per input element per task.
            kth_poly_coeffs
                .iter_mut()
                .enumerate()
                // No-op bound: each inner vector already has exactly `coset_size` entries.
                .take(coset_size)
                .for_each(|(i, coeff)| {
                    for c in 0..num_threads {
                        let idx = i + (c * coset_size);
                        let mut t = a[idx];
                        t *= elt;
                        *coeff += t;
                        elt *= &omega_step;
                    }
                    elt *= &omega_k;
                });

            // In place: afterwards this vector holds coset k's outputs rather than coefficients.
            serial_fft(kth_poly_coeffs, new_omega, new_two_adicity);
        });

    // Interleave the cosets: output index i is entry i / num_cosets of coset i % num_cosets.
    a.iter_mut()
        .enumerate()
        .for_each(|(i, a)| *a = tmp[i % num_cosets][i / num_cosets]);
}
202
203pub struct Elements<F: FftField> {
205 pub(crate) cur_elem: F,
206 pub(crate) cur_pow: u64,
207 pub(crate) size: u64,
208 pub(crate) group_gen: F,
209}
210
211impl<F: FftField> Iterator for Elements<F> {
212 type Item = F;
213 fn next(&mut self) -> Option<F> {
214 if self.cur_pow == self.size {
215 None
216 } else {
217 let cur_elem = self.cur_elem;
218 self.cur_elem *= &self.group_gen;
219 self.cur_pow += 1;
220 Some(cur_elem)
221 }
222 }
223}