Skip to main content

ark_poly/domain/
utils.rs

1use crate::domain::DomainCoeff;
2use ark_ff::{FftField, Field};
3use ark_std::vec::*;
4#[cfg(feature = "parallel")]
5use rayon::prelude::*;
6
/// Minimum size of a parallelized chunk (`1 << 7` = 128 elements).
///
/// `compute_powers` falls back to `compute_powers_serial` for inputs smaller
/// than this, and uses it as the lower bound for the per-thread chunk length.
// NOTE(review): `allow(unused)` is presumably present because not every build
// configuration references this constant — confirm before removing.
#[allow(unused)]
#[cfg(feature = "parallel")]
const MIN_PARALLEL_CHUNK_SIZE: usize = 1 << 7;
11
/// Reverses the order of the lowest `l` bits of `n`.
///
/// Bits of `n` above position `l` are ignored. With `l == 0` the result is `0`.
#[inline]
pub(crate) fn bitreverse(mut n: u32, l: u32) -> u32 {
    // Peel bits off the low end of `n` and push them onto the low end of the
    // accumulator, so the first bit read ends up in the highest position.
    (0..l).fold(0, |acc, _| {
        let low_bit = n & 1;
        n >>= 1;
        (acc << 1) | low_bit
    })
}
21
22#[inline]
23pub fn bitreverse_permutation_in_place<T>(a: &mut [T], width: u32) {
24    // swapping in place (from Storer's book)
25    let n = a.len();
26    for k in 0..n {
27        let rk = bitreverse(k as u32, width) as usize;
28        if k < rk {
29            a.swap(k, rk);
30        }
31    }
32}
33
34pub(crate) fn compute_powers_serial<F: Field>(size: usize, root: F) -> Vec<F> {
35    compute_powers_and_mul_by_const_serial(size, root, F::one())
36}
37
38pub(crate) fn compute_powers_and_mul_by_const_serial<F: Field>(
39    size: usize,
40    root: F,
41    c: F,
42) -> Vec<F> {
43    let mut value = c;
44    (0..size)
45        .map(|_| {
46            let old_value = value;
47            value *= root;
48            old_value
49        })
50        .collect()
51}
52
/// Returns `[1, g, g^2, ..., g^(size - 1)]`, splitting the work across rayon
/// tasks when the input is large enough.
///
/// Inputs shorter than `MIN_PARALLEL_CHUNK_SIZE` are delegated to
/// `compute_powers_serial`. Otherwise the index range is cut into consecutive
/// chunks of `max(size / threads, MIN_PARALLEL_CHUNK_SIZE)` elements; each
/// chunk starts from `g^(chunk start)` and is filled serially. The last chunk
/// may be shorter than the others.
///
/// The returned vector always has exactly `size` elements.
#[allow(unused)]
#[cfg(feature = "parallel")]
pub(crate) fn compute_powers<F: Field>(size: usize, g: F) -> Vec<F> {
    use ark_std::cmp::{max, min};

    if size < MIN_PARALLEL_CHUNK_SIZE {
        return compute_powers_serial(size, g);
    }

    // Choose the chunk length from the number of available threads.
    let num_cpus_available = rayon::current_num_threads();
    let num_elem_per_thread = max(size / num_cpus_available, MIN_PARALLEL_CHUNK_SIZE);

    // Round the chunk count *up*: with a rounded-down count the trailing
    // `size % num_elem_per_thread` powers would never be produced whenever
    // `size` is not a multiple of the chunk length.
    let has_partial_chunk = size % num_elem_per_thread != 0;
    let num_chunks = size / num_elem_per_thread + usize::from(has_partial_chunk);

    // Each chunk is computed independently; `collect` keeps chunk order.
    (0..num_chunks)
        .into_par_iter()
        .flat_map(|i| {
            let start = i * num_elem_per_thread;
            let offset = g.pow([start as u64]);
            // `start < size` for every chunk, so this never underflows; only
            // the final chunk can be shorter than `num_elem_per_thread`.
            let num_elements_to_compute = min(size - start, num_elem_per_thread);
            compute_powers_and_mul_by_const_serial(num_elements_to_compute, g, offset)
        })
        .collect()
}
78
79#[cfg(feature = "parallel")]
/// Returns `floor(log2(num))`, i.e. the index of the highest set bit of `num`.
///
/// By convention `log2_floor(0)` is `0`.
const fn log2_floor(num: usize) -> u32 {
    match num {
        0 => 0,
        // `usize::BITS - 1` is the number of leading zeros of `1usize`.
        _ => (usize::BITS - 1) - num.leading_zeros(),
    }
}
87
/// Runs an FFT of size `2^log_n` over `a`, choosing between the serial and
/// the parallel code path.
///
/// The parallel path is taken only when `log_n` is strictly greater than
/// `floor(log2(rayon::current_num_threads()))`; otherwise `serial_fft` is
/// called directly on the whole slice.
#[cfg(feature = "parallel")]
pub(crate) fn best_fft<T: DomainCoeff<F>, F: FftField>(
    a: &mut [T],
    omega: F,
    log_n: u32,
    serial_fft: fn(&mut [T], F, u32),
) {
    let log_cpus = log2_floor(rayon::current_num_threads());
    if log_n > log_cpus {
        parallel_fft(a, omega, log_n, log_cpus, serial_fft);
    } else {
        serial_fft(a, omega, log_n);
    }
}
103
104#[cfg(not(feature = "parallel"))]
105#[inline]
106pub(crate) fn best_fft<T: DomainCoeff<F>, F: FftField>(
107    a: &mut [T],
108    omega: F,
109    log_n: u32,
110    serial_fft: fn(&mut [T], F, u32),
111) {
112    serial_fft(a, omega, log_n)
113}
114
/// Parallel FFT driver.
///
/// Splits the transform over `a` into `2^log_cpus` sub-transforms of
/// `a.len() / 2^log_cpus` elements each, runs every sub-transform through
/// `serial_fft` on a rayon parallel iterator, and interleaves the results back
/// into `a`.
///
/// Comments below describe `a` as the coefficients of a polynomial that is
/// being evaluated, with `g` standing for `omega`.
///
/// # Panics
///
/// Panics if `log_n < log_cpus`, or if `a.len()` is not a multiple of
/// `2^log_cpus`.
#[cfg(feature = "parallel")]
pub(crate) fn parallel_fft<T: DomainCoeff<F>, F: FftField>(
    a: &mut [T],
    omega: F,
    log_n: u32,
    log_cpus: u32,
    serial_fft: fn(&mut [T], F, u32),
) {
    assert!(log_n >= log_cpus);

    // One coset per worker; every coset receives an equal share of `a`.
    let total_len = a.len();
    let num_cosets = 1 << (log_cpus as usize);
    assert_eq!(total_len % num_cosets, 0);
    let coset_size = total_len / num_cosets;

    // The evaluations are grouped into cosets generated by `g^{num_cosets}`:
    //   coset 0 is (1, g^{num_cosets},     g^{2 * num_cosets},     ...)
    //   coset 1 is (g, g^{1 + num_cosets}, g^{1 + 2 * num_cosets}, ...)
    // and so on, so each sub-FFT uses `g^{num_cosets}` as its root and the
    // 2-adicity of `coset_size` as its log-size.
    let sub_omega = omega.pow([num_cosets as u64]);
    let sub_log_size = ark_ff::utils::k_adicity(2, coset_size as u64);

    // The FFT is computed out of place into scratch buffers (one per coset)
    // and copied back into `a` once every coset is done.
    let mut coset_bufs = vec![vec![T::zero(); coset_size]; num_cosets];

    coset_bufs.par_iter_mut().enumerate().for_each(|(k, buf)| {
        // Build the k-th folded polynomial, whose evaluations over coset k
        // agree with the evaluations of `a` over that coset:
        //   buf[i] = sum_{c < num_cosets} g^{k * (i + c * coset_size)} * a[i + c * coset_size]
        let shift = omega.pow([k as u64]);
        let stride = omega.pow([(k * coset_size) as u64]);

        // Running power of `g`; it tracks `g^{k * idx}` for the `idx` being read.
        // NOTE(review): it is never reset between values of `i`, which appears
        // to rely on `omega^{a.len()} == 1` — confirm against callers.
        let mut twiddle = F::one();

        // TODO: Come back and improve the speed, and make this a more 'normal'
        // Cooley-Tukey. This folding step appears to evaluate
        // `P(x) = sum_c a[i + c * coset_size] * x^c` in O(N) time per thread
        // rather than O(coset_size * log(coset_size)), i.e. roughly
        // `2N + (N / num_threads) * log(N / num_threads)` field muls per
        // thread versus `N * log(N)` for the serial FFT, so parallel speedup
        // is limited.
        for (i, coeff) in buf.iter_mut().enumerate().take(coset_size) {
            for c in 0..num_cosets {
                // Element `i` of the `c`-th contiguous slice of `a`,
                // weighted by `g^{k * idx}`.
                let mut term = a[i + c * coset_size];
                term *= twiddle;
                *coeff += term;
                twiddle *= &stride;
            }
            // Step to the next `i`: multiply in one more factor of `g^k`.
            twiddle *= &shift;
        }

        // The sub-FFT is in place: from here on `buf` holds the evaluations
        // over coset k instead of coefficients.
        serial_fft(buf, sub_omega, sub_log_size);
    });

    // Interleave the coset results so that `a` is ordered as the evaluations
    // at (1, g, g^2, ...).
    for (i, slot) in a.iter_mut().enumerate() {
        *slot = coset_bufs[i % num_cosets][i / num_cosets];
    }
}
202
/// An iterator over the elements of a domain.
///
/// Each call to `next` yields `cur_elem` and then multiplies it by
/// `group_gen`; iteration ends once `cur_pow` equals `size`.
pub struct Elements<F: FftField> {
    /// The element that the next call to `next` will yield.
    pub(crate) cur_elem: F,
    /// Counter incremented by one every time an element is yielded.
    pub(crate) cur_pow: u64,
    /// Value of `cur_pow` at which the iterator is exhausted.
    pub(crate) size: u64,
    /// Factor by which `cur_elem` is multiplied after each yielded element.
    pub(crate) group_gen: F,
}
210
211impl<F: FftField> Iterator for Elements<F> {
212    type Item = F;
213    fn next(&mut self) -> Option<F> {
214        if self.cur_pow == self.size {
215            None
216        } else {
217            let cur_elem = self.cur_elem;
218            self.cur_elem *= &self.group_gen;
219            self.cur_pow += 1;
220            Some(cur_elem)
221        }
222    }
223}