1use ff::PrimeField;
2
3use crate::threadpool::Worker;
4
5#[allow(clippy::many_single_char_names)]
10pub fn serial_fft<F: PrimeField>(a: &mut [F], omega: &F, log_n: u32) {
11 fn bitreverse(mut n: u32, l: u32) -> u32 {
12 let mut r = 0;
13 for _ in 0..l {
14 r = (r << 1) | (n & 1);
15 n >>= 1;
16 }
17 r
18 }
19
20 let n = a.len() as u32;
21 assert_eq!(n, 1 << log_n);
22
23 for k in 0..n {
24 let rk = bitreverse(k, log_n);
25 if k < rk {
26 a.swap(rk as usize, k as usize);
27 }
28 }
29
30 let mut m = 1;
31 for _ in 0..log_n {
32 let w_m = omega.pow_vartime([u64::from(n / (2 * m))]);
33
34 let mut k = 0;
35 while k < n {
36 let mut w = F::ONE;
37 for j in 0..m {
38 let mut t = a[(k + j + m) as usize];
39 t *= w;
40 let mut tmp = a[(k + j) as usize];
41 tmp -= t;
42 a[(k + j + m) as usize] = tmp;
43 a[(k + j) as usize] += t;
44 w *= w_m;
45 }
46
47 k += 2 * m;
48 }
49
50 m *= 2;
51 }
52}
53
54pub fn parallel_fft<F: PrimeField>(
60 a: &mut [F],
61 worker: &Worker,
62 omega: &F,
63 log_n: u32,
64 log_threads: u32,
65) {
66 assert!(log_n >= log_threads);
67
68 let num_threads = 1 << log_threads;
69 let log_new_n = log_n - log_threads;
70 let mut tmp = vec![vec![F::ZERO; 1 << log_new_n]; num_threads];
71 let new_omega = omega.pow_vartime([num_threads as u64]);
72
73 worker.scope(0, |scope, _| {
74 let a = &*a;
75
76 for (j, tmp) in tmp.iter_mut().enumerate() {
77 scope.execute(move || {
78 let omega_j = omega.pow_vartime([j as u64]);
80 let omega_step = omega.pow_vartime([(j as u64) << log_new_n]);
81
82 let mut elt = F::ONE;
83 for (i, tmp) in tmp.iter_mut().enumerate() {
84 for s in 0..num_threads {
85 let idx = (i + (s << log_new_n)) % (1 << log_n);
86 let mut t = a[idx];
87 t *= elt;
88 *tmp += t;
89 elt *= omega_step;
90 }
91 elt *= omega_j;
92 }
93
94 serial_fft::<F>(tmp, &new_omega, log_new_n);
96 });
97 }
98 });
99
100 worker.scope(a.len(), |scope, chunk| {
102 let tmp = &tmp;
103
104 for (idx, a) in a.chunks_mut(chunk).enumerate() {
105 scope.execute(move || {
106 let mut idx = idx * chunk;
107 let mask = (1 << log_threads) - 1;
108 for a in a {
109 *a = tmp[idx & mask][idx >> log_threads];
110 idx += 1;
111 }
112 });
113 }
114 });
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    use std::cmp::min;

    use blstrs::Scalar as Fr;
    use ff::PrimeField;
    use rand_core::RngCore;

    /// Returns `F::ROOT_OF_UNITY` squared `F::S - floor(log2(num_coeffs))` times.
    ///
    /// Assuming `ROOT_OF_UNITY` has order `2^S`, this is a root of unity of order
    /// `2^floor(log2(num_coeffs))`.
    fn omega<F: PrimeField>(num_coeffs: usize) -> F {
        let exp = (num_coeffs as f32).log2().floor() as u32;
        // Each squaring halves the multiplicative order of the root.
        let mut omega = F::ROOT_OF_UNITY;
        for _ in exp..F::S {
            omega = omega.square();
        }
        omega
    }

    #[test]
    fn parallel_fft_consistency() {
        /// Checks `parallel_fft` against `serial_fft` for sizes `2^0..2^9`.
        ///
        /// Every thread split from `2^0` up to `2^min(log_d, 3)` is tried, so real (size > 1)
        /// sub-FFTs are exercised as well.
        fn test_consistency<F: PrimeField, R: RngCore>(rng: &mut R) {
            let worker = Worker::new();

            for _ in 0..5 {
                for log_d in 0..10 {
                    let d = 1 << log_d;

                    let coeffs = (0..d).map(|_| F::random(&mut *rng)).collect::<Vec<_>>();
                    let root = omega::<F>(coeffs.len());

                    // Reference result from the single-threaded implementation.
                    let mut expected = coeffs.clone();
                    serial_fft::<F>(&mut expected, &root, log_d);

                    // Each thread split runs on a fresh copy of the same input.
                    for log_threads in 0..=min(log_d, 3) {
                        let mut actual = coeffs.clone();
                        parallel_fft::<F>(&mut actual, &worker, &root, log_d, log_threads);

                        assert!(
                            actual == expected,
                            "parallel_fft disagrees with serial_fft for log_d = {}, log_threads = {}",
                            log_d,
                            log_threads
                        );
                    }
                }
            }
        }

        let rng = &mut rand::thread_rng();

        test_consistency::<Fr, _>(rng);
    }
}