1use ff::PrimeField;
15use group::cofactor::CofactorCurve;
16
17use super::SynthesisError;
18
19use super::multicore::Worker;
20
/// A polynomial held by its coefficients (or, after an FFT, its evaluations)
/// over a power-of-two multiplicative subgroup of the scalar field `S`.
pub struct EvaluationDomain<S: PrimeField, G: Group<S>> {
    coeffs: Vec<G>, // always padded to length 2^exp
    exp: u32,       // log2 of the domain size
    omega: S,       // primitive 2^exp-th root of unity generating the domain
    omegainv: S,    // omega^-1, used by the inverse FFT
    geninv: S,      // inverse of the field's multiplicative generator (coset FFTs)
    minv: S,        // (2^exp)^-1, the inverse-FFT scaling factor
}
29
30impl<S: PrimeField, G: Group<S>> AsRef<[G]> for EvaluationDomain<S, G> {
31 fn as_ref(&self) -> &[G] {
32 &self.coeffs
33 }
34}
35
36impl<S: PrimeField, G: Group<S>> AsMut<[G]> for EvaluationDomain<S, G> {
37 fn as_mut(&mut self) -> &mut [G] {
38 &mut self.coeffs
39 }
40}
41
impl<S: PrimeField, G: Group<S>> EvaluationDomain<S, G> {
    /// Consumes the domain and returns the underlying coefficient vector.
    pub fn into_coeffs(self) -> Vec<G> {
        self.coeffs
    }

    /// Builds a domain from `coeffs`, zero-padding up to the next power of
    /// two.
    ///
    /// Returns `SynthesisError::PolynomialDegreeTooLarge` when the padded
    /// size would require a 2^exp-th root of unity with `exp >= S::S`, which
    /// the field does not contain.
    pub fn from_coeffs(mut coeffs: Vec<G>) -> Result<EvaluationDomain<S, G>, SynthesisError> {
        // Find the smallest power of two m = 2^exp >= coeffs.len().
        let mut m = 1;
        let mut exp = 0;
        while m < coeffs.len() {
            m *= 2;
            exp += 1;

            // The field only contains roots of unity of 2-adic order up to
            // S::S, so larger domains are impossible.
            if exp >= S::S {
                return Err(SynthesisError::PolynomialDegreeTooLarge);
            }
        }

        // ROOT_OF_UNITY has order 2^S::S; squaring it (S::S - exp) times
        // yields a primitive 2^exp-th root of unity for this domain.
        let mut omega = S::ROOT_OF_UNITY;
        for _ in exp..S::S {
            omega = omega.square();
        }

        // Pad the polynomial with zero coefficients to the full domain size.
        coeffs.resize(m, G::group_zero());

        Ok(EvaluationDomain {
            coeffs,
            exp,
            omega,
            omegainv: omega.invert().unwrap(),
            geninv: S::MULTIPLICATIVE_GENERATOR.invert().unwrap(),
            minv: S::from(m as u64).invert().unwrap(),
        })
    }

    /// In-place forward FFT: coefficients -> evaluations over the domain.
    pub fn fft(&mut self, worker: &Worker) {
        best_fft(&mut self.coeffs, worker, &self.omega, self.exp);
    }

    /// In-place inverse FFT: evaluations -> coefficients.
    pub fn ifft(&mut self, worker: &Worker) {
        best_fft(&mut self.coeffs, worker, &self.omegainv, self.exp);

        // The inverse transform must be scaled by 1/m; apply it in parallel
        // chunks.
        worker.scope(self.coeffs.len(), |scope, chunk| {
            let minv = self.minv;

            for v in self.coeffs.chunks_mut(chunk) {
                scope.spawn(move |_scope| {
                    for v in v {
                        v.group_mul_assign(&minv);
                    }
                });
            }
        });
    }

    /// Multiplies the i-th element by g^i, mapping evaluations over the
    /// domain to evaluations over the coset g * domain (or back, for g^-1).
    pub fn distribute_powers(&mut self, worker: &Worker, g: S) {
        worker.scope(self.coeffs.len(), |scope, chunk| {
            for (i, v) in self.coeffs.chunks_mut(chunk).enumerate() {
                scope.spawn(move |_scope| {
                    // Each chunk starts at power g^(i * chunk).
                    let mut u = g.pow_vartime(&[(i * chunk) as u64]);
                    for v in v.iter_mut() {
                        v.group_mul_assign(&u);
                        u.mul_assign(&g);
                    }
                });
            }
        });
    }

    /// Forward FFT over the coset generated by the field's multiplicative
    /// generator rather than over the subgroup itself.
    pub fn coset_fft(&mut self, worker: &Worker) {
        self.distribute_powers(worker, S::MULTIPLICATIVE_GENERATOR);
        self.fft(worker);
    }

    /// Inverse of `coset_fft`.
    pub fn icoset_fft(&mut self, worker: &Worker) {
        let geninv = self.geninv;

        self.ifft(worker);
        self.distribute_powers(worker, geninv);
    }

    /// Evaluates the vanishing polynomial Z(tau) = tau^m - 1 of this domain
    /// at `tau`, where m is the domain size.
    pub fn z(&self, tau: &S) -> S {
        let mut tmp = tau.pow_vartime(&[self.coeffs.len() as u64]);
        tmp.sub_assign(&S::ONE);

        tmp
    }

    /// Divides every evaluation by Z, assuming the values are evaluations
    /// over the generator's coset: Z is constant there (Z(g*h) = g^m - 1
    /// since h^m = 1), so the division is one scalar multiplication.
    pub fn divide_by_z_on_coset(&mut self, worker: &Worker) {
        let i = self.z(&S::MULTIPLICATIVE_GENERATOR).invert().unwrap();

        worker.scope(self.coeffs.len(), |scope, chunk| {
            for v in self.coeffs.chunks_mut(chunk) {
                scope.spawn(move |_scope| {
                    for v in v {
                        v.group_mul_assign(&i);
                    }
                });
            }
        });
    }

    /// Pointwise multiplication by a same-size domain of scalars.
    ///
    /// Panics if the two domains differ in size.
    pub fn mul_assign(&mut self, worker: &Worker, other: &EvaluationDomain<S, Scalar<S>>) {
        assert_eq!(self.coeffs.len(), other.coeffs.len());

        worker.scope(self.coeffs.len(), |scope, chunk| {
            for (a, b) in self
                .coeffs
                .chunks_mut(chunk)
                .zip(other.coeffs.chunks(chunk))
            {
                scope.spawn(move |_scope| {
                    for (a, b) in a.iter_mut().zip(b.iter()) {
                        a.group_mul_assign(&b.0);
                    }
                });
            }
        });
    }

    /// Pointwise subtraction of a same-size domain.
    ///
    /// Panics if the two domains differ in size.
    pub fn sub_assign(&mut self, worker: &Worker, other: &EvaluationDomain<S, G>) {
        assert_eq!(self.coeffs.len(), other.coeffs.len());

        worker.scope(self.coeffs.len(), |scope, chunk| {
            for (a, b) in self
                .coeffs
                .chunks_mut(chunk)
                .zip(other.coeffs.chunks(chunk))
            {
                scope.spawn(move |_scope| {
                    for (a, b) in a.iter_mut().zip(b.iter()) {
                        a.group_sub_assign(b);
                    }
                });
            }
        });
    }
}
191
/// An element of a group with a scalar action — either a curve point or a
/// field element — providing exactly the operations the FFT code needs.
pub trait Group<Scalar: PrimeField>: Sized + Copy + Clone + Send + Sync {
    /// Returns the group's identity (zero) element.
    fn group_zero() -> Self;
    /// Multiplies `self` in place by the scalar `by`.
    fn group_mul_assign(&mut self, by: &Scalar);
    /// Adds `other` to `self` in place.
    fn group_add_assign(&mut self, other: &Self);
    /// Subtracts `other` from `self` in place.
    fn group_sub_assign(&mut self, other: &Self);
}
198
199pub struct Point<G: CofactorCurve>(pub G);
200
201impl<G: CofactorCurve> PartialEq for Point<G> {
202 fn eq(&self, other: &Point<G>) -> bool {
203 self.0 == other.0
204 }
205}
206
207impl<G: CofactorCurve> Copy for Point<G> {}
208
209impl<G: CofactorCurve> Clone for Point<G> {
210 fn clone(&self) -> Point<G> {
211 *self
212 }
213}
214
/// Curve points form a [`Group`] under point addition with scalar
/// multiplication as the scalar action.
impl<G: CofactorCurve> Group<G::Scalar> for Point<G> {
    fn group_zero() -> Self {
        Point(G::identity())
    }
    fn group_mul_assign(&mut self, by: &G::Scalar) {
        self.0.mul_assign(by);
    }
    fn group_add_assign(&mut self, other: &Self) {
        self.0.add_assign(&other.0);
    }
    fn group_sub_assign(&mut self, other: &Self) {
        self.0.sub_assign(&other.0);
    }
}
229
230pub struct Scalar<S: PrimeField>(pub S);
231
232impl<S: PrimeField> PartialEq for Scalar<S> {
233 fn eq(&self, other: &Scalar<S>) -> bool {
234 self.0 == other.0
235 }
236}
237
238impl<S: PrimeField> Copy for Scalar<S> {}
239
240impl<S: PrimeField> Clone for Scalar<S> {
241 fn clone(&self) -> Scalar<S> {
242 *self
243 }
244}
245
246impl<S: PrimeField> Group<S> for Scalar<S> {
247 fn group_zero() -> Self {
248 Scalar(S::ZERO)
249 }
250 fn group_mul_assign(&mut self, by: &S) {
251 self.0.mul_assign(by);
252 }
253 fn group_add_assign(&mut self, other: &Self) {
254 self.0.add_assign(&other.0);
255 }
256 fn group_sub_assign(&mut self, other: &Self) {
257 self.0.sub_assign(&other.0);
258 }
259}
260
261fn best_fft<S: PrimeField, T: Group<S>>(a: &mut [T], worker: &Worker, omega: &S, log_n: u32) {
262 let log_cpus = worker.log_num_threads();
263
264 if log_n <= log_cpus {
265 serial_fft(a, omega, log_n);
266 } else {
267 parallel_fft(a, worker, omega, log_n, log_cpus);
268 }
269}
270
#[allow(clippy::many_single_char_names)]
/// In-place iterative radix-2 Cooley-Tukey FFT.
///
/// `a.len()` must equal `2^log_n` and `omega` must be a primitive
/// `2^log_n`-th root of unity; panics if the length does not match.
fn serial_fft<S: PrimeField, T: Group<S>>(a: &mut [T], omega: &S, log_n: u32) {
    // Reverses the lowest `l` bits of `n`.
    fn bitreverse(mut n: u32, l: u32) -> u32 {
        let mut r = 0;
        for _ in 0..l {
            r = (r << 1) | (n & 1);
            n >>= 1;
        }
        r
    }

    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);

    // Permute input into bit-reversed order so the butterflies below can
    // merge contiguous sub-transforms of doubling size.
    for k in 0..n {
        let rk = bitreverse(k, log_n);
        if k < rk {
            a.swap(rk as usize, k as usize);
        }
    }

    // Merge sub-transforms of size m into size 2m, log_n times.
    let mut m = 1;
    for _ in 0..log_n {
        // w_m is a primitive (2m)-th root of unity.
        let w_m = omega.pow_vartime(&[u64::from(n / (2 * m))]);

        let mut k = 0;
        while k < n {
            let mut w = S::ONE;
            for j in 0..m {
                // Butterfly:
                //   a[k+j]   <- a[k+j] + w * a[k+j+m]
                //   a[k+j+m] <- a[k+j] - w * a[k+j+m]
                let mut t = a[(k + j + m) as usize];
                t.group_mul_assign(&w);
                let mut tmp = a[(k + j) as usize];
                tmp.group_sub_assign(&t);
                a[(k + j + m) as usize] = tmp;
                a[(k + j) as usize].group_add_assign(&t);
                w.mul_assign(&w_m);
            }

            k += 2 * m;
        }

        m *= 2;
    }
}
315
/// Parallel FFT: decomposes a size-2^log_n transform into 2^log_cpus
/// sub-transforms of size 2^(log_n - log_cpus), runs one per thread, then
/// scatters the results back into `a`.
///
/// Panics if `log_n < log_cpus`.
fn parallel_fft<S: PrimeField, T: Group<S>>(
    a: &mut [T],
    worker: &Worker,
    omega: &S,
    log_n: u32,
    log_cpus: u32,
) {
    assert!(log_n >= log_cpus);

    let num_cpus = 1 << log_cpus;
    let log_new_n = log_n - log_cpus;
    // One scratch buffer per thread, each holding one sub-transform.
    let mut tmp = vec![vec![T::group_zero(); 1 << log_new_n]; num_cpus];
    // omega^num_cpus generates the subgroup used by the sub-transforms.
    let new_omega = omega.pow_vartime(&[num_cpus as u64]);

    worker.scope(0, |scope, _| {
        let a = &*a;

        for (j, tmp) in tmp.iter_mut().enumerate() {
            scope.spawn(move |_scope| {
                // Thread j folds the input into a twisted sub-sum so that the
                // serial FFT of `tmp` produces the output indices congruent
                // to j modulo num_cpus.
                let omega_j = omega.pow_vartime(&[j as u64]);
                let omega_step = omega.pow_vartime(&[(j as u64) << log_new_n]);

                let mut elt = S::ONE;
                for (i, tmp) in tmp.iter_mut().enumerate() {
                    for s in 0..num_cpus {
                        let idx = (i + (s << log_new_n)) % (1 << log_n);
                        let mut t = a[idx];
                        t.group_mul_assign(&elt);
                        tmp.group_add_assign(&t);
                        elt.mul_assign(&omega_step);
                    }
                    elt.mul_assign(&omega_j);
                }

                serial_fft(tmp, &new_omega, log_new_n);
            });
        }
    });

    // Recombine: output index idx lives at tmp[idx % num_cpus][idx / num_cpus].
    worker.scope(a.len(), |scope, chunk| {
        let tmp = &tmp;

        for (idx, a) in a.chunks_mut(chunk).enumerate() {
            scope.spawn(move |_scope| {
                let mut idx = idx * chunk;
                let mask = (1 << log_cpus) - 1;
                for a in a {
                    *a = tmp[idx & mask][idx >> log_cpus];
                    idx += 1;
                }
            });
        }
    });
}
373
#[cfg(feature = "pairing")]
#[test]
fn polynomial_arith() {
    use bls12_381::Scalar as Fr;
    use rand_core::RngCore;

    // FFT-based polynomial multiplication must match the naive O(n^2)
    // schoolbook product for every degree pair below 70.
    fn test_mul<S: PrimeField, R: RngCore>(mut rng: &mut R) {
        let worker = Worker::new();

        for coeffs_a in 0..70 {
            for coeffs_b in 0..70 {
                let mut a: Vec<_> = (0..coeffs_a)
                    .map(|_| Scalar::<S>(S::random(&mut rng)))
                    .collect();
                let mut b: Vec<_> = (0..coeffs_b)
                    .map(|_| Scalar::<S>(S::random(&mut rng)))
                    .collect();

                // Reference result: naive convolution of the coefficients.
                let mut naive = vec![Scalar(S::ZERO); coeffs_a + coeffs_b];
                for (i1, a) in a.iter().enumerate() {
                    for (i2, b) in b.iter().enumerate() {
                        let mut prod = *a;
                        prod.group_mul_assign(&b.0);
                        naive[i1 + i2].group_add_assign(&prod);
                    }
                }

                // Pad both inputs so the product fits, then multiply via
                // FFT -> pointwise product -> inverse FFT.
                a.resize(coeffs_a + coeffs_b, Scalar(S::ZERO));
                b.resize(coeffs_a + coeffs_b, Scalar(S::ZERO));

                let mut a = EvaluationDomain::from_coeffs(a).unwrap();
                let mut b = EvaluationDomain::from_coeffs(b).unwrap();

                a.fft(&worker);
                b.fft(&worker);
                a.mul_assign(&worker, &b);
                a.ifft(&worker);

                for (naive, fft) in naive.iter().zip(a.coeffs.iter()) {
                    assert!(naive == fft);
                }
            }
        }
    }

    let rng = &mut rand::thread_rng();

    test_mul::<Fr, _>(rng);
}
426
#[cfg(feature = "pairing")]
#[test]
fn fft_composition() {
    use bls12_381::Scalar as Fr;
    use rand_core::RngCore;

    // fft/ifft (and their coset variants) must be mutual inverses, in either
    // order, for every power-of-two size up to 2^9.
    fn test_comp<S: PrimeField, R: RngCore>(mut rng: &mut R) {
        let worker = Worker::new();

        for coeffs in 0..10 {
            let coeffs = 1 << coeffs;

            let mut v = vec![];
            for _ in 0..coeffs {
                v.push(Scalar::<S>(S::random(&mut rng)));
            }

            let mut domain = EvaluationDomain::from_coeffs(v.clone()).unwrap();
            domain.ifft(&worker);
            domain.fft(&worker);
            assert!(v == domain.coeffs);
            domain.fft(&worker);
            domain.ifft(&worker);
            assert!(v == domain.coeffs);
            domain.icoset_fft(&worker);
            domain.coset_fft(&worker);
            assert!(v == domain.coeffs);
            domain.coset_fft(&worker);
            domain.icoset_fft(&worker);
            assert!(v == domain.coeffs);
        }
    }

    let rng = &mut rand::thread_rng();

    test_comp::<Fr, _>(rng);
}
464
#[cfg(feature = "pairing")]
#[test]
fn parallel_fft_consistency() {
    use bls12_381::Scalar as Fr;
    use rand_core::RngCore;
    use std::cmp::min;

    // parallel_fft must agree with serial_fft on identical random inputs.
    fn test_consistency<S: PrimeField, R: RngCore>(mut rng: &mut R) {
        let worker = Worker::new();

        for _ in 0..5 {
            for log_d in 0..10 {
                let d = 1 << log_d;

                let v1 = (0..d)
                    .map(|_| Scalar::<S>(S::random(&mut rng)))
                    .collect::<Vec<_>>();
                let mut v1 = EvaluationDomain::from_coeffs(v1).unwrap();
                let mut v2 = EvaluationDomain::from_coeffs(v1.coeffs.clone()).unwrap();

                // NOTE(review): for log_d >= 3 this range is empty, so sizes
                // 2^3..2^9 are never actually compared — confirm the intended
                // range wasn't `0..min(log_d + 1, 3)`.
                for log_cpus in log_d..min(log_d + 1, 3) {
                    parallel_fft(&mut v1.coeffs, &worker, &v1.omega, log_d, log_cpus);
                    serial_fft(&mut v2.coeffs, &v2.omega, log_d);

                    assert!(v1.coeffs == v2.coeffs);
                }
            }
        }
    }

    let rng = &mut rand::thread_rng();

    test_consistency::<Fr, _>(rng);
}