ec_gpu_gen/
fft_cpu.rs

1use ff::PrimeField;
2
3use crate::threadpool::Worker;
4
5/// Calculate the Fast Fourier Transform on the CPU (single-threaded).
6///
7/// The input `a` is mutated and contains the result when this function returns. The length of the
8/// input vector must be `2^log_n`.
9#[allow(clippy::many_single_char_names)]
10pub fn serial_fft<F: PrimeField>(a: &mut [F], omega: &F, log_n: u32) {
11    fn bitreverse(mut n: u32, l: u32) -> u32 {
12        let mut r = 0;
13        for _ in 0..l {
14            r = (r << 1) | (n & 1);
15            n >>= 1;
16        }
17        r
18    }
19
20    let n = a.len() as u32;
21    assert_eq!(n, 1 << log_n);
22
23    for k in 0..n {
24        let rk = bitreverse(k, log_n);
25        if k < rk {
26            a.swap(rk as usize, k as usize);
27        }
28    }
29
30    let mut m = 1;
31    for _ in 0..log_n {
32        let w_m = omega.pow_vartime([u64::from(n / (2 * m))]);
33
34        let mut k = 0;
35        while k < n {
36            let mut w = F::ONE;
37            for j in 0..m {
38                let mut t = a[(k + j + m) as usize];
39                t *= w;
40                let mut tmp = a[(k + j) as usize];
41                tmp -= t;
42                a[(k + j + m) as usize] = tmp;
43                a[(k + j) as usize] += t;
44                w *= w_m;
45            }
46
47            k += 2 * m;
48        }
49
50        m *= 2;
51    }
52}
53
54/// Calculate the Fast Fourier Transform on the CPU (multithreaded).
55///
56/// The result is is written to the input `a`.
57/// The number of threads used will be `2^log_threads`.
58/// There must be more items to process than threads.
59pub fn parallel_fft<F: PrimeField>(
60    a: &mut [F],
61    worker: &Worker,
62    omega: &F,
63    log_n: u32,
64    log_threads: u32,
65) {
66    assert!(log_n >= log_threads);
67
68    let num_threads = 1 << log_threads;
69    let log_new_n = log_n - log_threads;
70    let mut tmp = vec![vec![F::ZERO; 1 << log_new_n]; num_threads];
71    let new_omega = omega.pow_vartime([num_threads as u64]);
72
73    worker.scope(0, |scope, _| {
74        let a = &*a;
75
76        for (j, tmp) in tmp.iter_mut().enumerate() {
77            scope.execute(move || {
78                // Shuffle into a sub-FFT
79                let omega_j = omega.pow_vartime([j as u64]);
80                let omega_step = omega.pow_vartime([(j as u64) << log_new_n]);
81
82                let mut elt = F::ONE;
83                for (i, tmp) in tmp.iter_mut().enumerate() {
84                    for s in 0..num_threads {
85                        let idx = (i + (s << log_new_n)) % (1 << log_n);
86                        let mut t = a[idx];
87                        t *= elt;
88                        *tmp += t;
89                        elt *= omega_step;
90                    }
91                    elt *= omega_j;
92                }
93
94                // Perform sub-FFT
95                serial_fft::<F>(tmp, &new_omega, log_new_n);
96            });
97        }
98    });
99
100    // TODO: does this hurt or help?
101    worker.scope(a.len(), |scope, chunk| {
102        let tmp = &tmp;
103
104        for (idx, a) in a.chunks_mut(chunk).enumerate() {
105            scope.execute(move || {
106                let mut idx = idx * chunk;
107                let mask = (1 << log_threads) - 1;
108                for a in a {
109                    *a = tmp[idx & mask][idx >> log_threads];
110                    idx += 1;
111                }
112            });
113        }
114    });
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    use std::cmp::min;

    use blstrs::Scalar as Fr;
    use ff::PrimeField;
    use rand_core::RngCore;

    /// Returns the `2^floor(log2(num_coeffs))`-th primitive root of unity of `F`.
    fn omega<F: PrimeField>(num_coeffs: usize) -> F {
        // `F::ROOT_OF_UNITY` is a `2^F::S`-th root of unity; squaring it `S - exp` times yields
        // a `2^exp`-th one.
        let exp = (num_coeffs as f32).log2().floor() as u32;
        let mut omega = F::ROOT_OF_UNITY;
        for _ in exp..F::S {
            omega = omega.square();
        }
        omega
    }

    #[test]
    fn parallel_fft_consistency() {
        fn test_consistency<F: PrimeField, R: RngCore>(rng: &mut R) {
            let worker = Worker::new();

            for _ in 0..5 {
                for log_d in 0..10 {
                    let d = 1 << log_d;

                    let coeffs = (0..d).map(|_| F::random(&mut *rng)).collect::<Vec<_>>();
                    let root = omega::<F>(coeffs.len());

                    let mut expected = coeffs.clone();
                    serial_fft::<F>(&mut expected, &root, log_d);

                    // Try every task count from 1 up to 8 that the size allows. The previous
                    // range `log_d..min(log_d + 1, 3)` was empty for every `log_d >= 3`, so only
                    // sizes 1, 2 and 4 were ever checked.
                    for log_threads in 0..=min(log_d, 3) {
                        let mut actual = coeffs.clone();
                        parallel_fft::<F>(&mut actual, &worker, &root, log_d, log_threads);

                        assert_eq!(
                            actual, expected,
                            "log_d = {}, log_threads = {}",
                            log_d, log_threads
                        );
                    }
                }
            }
        }

        let rng = &mut rand::thread_rng();

        test_consistency::<Fr, _>(rng);
    }
}