bellman/
domain.rs

1//! This module contains an [`EvaluationDomain`] abstraction for performing
2//! various kinds of polynomial arithmetic on top of the scalar field.
3//!
4//! In pairing-based SNARKs like [Groth16], we need to calculate a quotient
5//! polynomial over a target polynomial with roots at distinct points associated
6//! with each constraint of the constraint system. In order to be efficient, we
7//! choose these roots to be the powers of a 2<sup>n</sup> root of unity in the
8//! field. This allows us to perform polynomial operations in O(n) by performing
9//! an O(n log n) FFT over such a domain.
10//!
11//! [`EvaluationDomain`]: crate::domain::EvaluationDomain
12//! [Groth16]: https://eprint.iacr.org/2016/260
13
14use ff::PrimeField;
15use group::cofactor::CofactorCurve;
16
17use super::SynthesisError;
18
19use super::multicore::Worker;
20
21pub struct EvaluationDomain<S: PrimeField, G: Group<S>> {
22    coeffs: Vec<G>,
23    exp: u32,
24    omega: S,
25    omegainv: S,
26    geninv: S,
27    minv: S,
28}
29
30impl<S: PrimeField, G: Group<S>> AsRef<[G]> for EvaluationDomain<S, G> {
31    fn as_ref(&self) -> &[G] {
32        &self.coeffs
33    }
34}
35
36impl<S: PrimeField, G: Group<S>> AsMut<[G]> for EvaluationDomain<S, G> {
37    fn as_mut(&mut self) -> &mut [G] {
38        &mut self.coeffs
39    }
40}
41
42impl<S: PrimeField, G: Group<S>> EvaluationDomain<S, G> {
43    pub fn into_coeffs(self) -> Vec<G> {
44        self.coeffs
45    }
46
47    pub fn from_coeffs(mut coeffs: Vec<G>) -> Result<EvaluationDomain<S, G>, SynthesisError> {
48        // Compute the size of our evaluation domain
49        let mut m = 1;
50        let mut exp = 0;
51        while m < coeffs.len() {
52            m *= 2;
53            exp += 1;
54
55            // The pairing-friendly curve may not be able to support
56            // large enough (radix2) evaluation domains.
57            if exp >= S::S {
58                return Err(SynthesisError::PolynomialDegreeTooLarge);
59            }
60        }
61
62        // Compute omega, the 2^exp primitive root of unity
63        let mut omega = S::ROOT_OF_UNITY;
64        for _ in exp..S::S {
65            omega = omega.square();
66        }
67
68        // Extend the coeffs vector with zeroes if necessary
69        coeffs.resize(m, G::group_zero());
70
71        Ok(EvaluationDomain {
72            coeffs,
73            exp,
74            omega,
75            omegainv: omega.invert().unwrap(),
76            geninv: S::MULTIPLICATIVE_GENERATOR.invert().unwrap(),
77            minv: S::from(m as u64).invert().unwrap(),
78        })
79    }
80
81    pub fn fft(&mut self, worker: &Worker) {
82        best_fft(&mut self.coeffs, worker, &self.omega, self.exp);
83    }
84
85    pub fn ifft(&mut self, worker: &Worker) {
86        best_fft(&mut self.coeffs, worker, &self.omegainv, self.exp);
87
88        worker.scope(self.coeffs.len(), |scope, chunk| {
89            let minv = self.minv;
90
91            for v in self.coeffs.chunks_mut(chunk) {
92                scope.spawn(move |_scope| {
93                    for v in v {
94                        v.group_mul_assign(&minv);
95                    }
96                });
97            }
98        });
99    }
100
101    pub fn distribute_powers(&mut self, worker: &Worker, g: S) {
102        worker.scope(self.coeffs.len(), |scope, chunk| {
103            for (i, v) in self.coeffs.chunks_mut(chunk).enumerate() {
104                scope.spawn(move |_scope| {
105                    let mut u = g.pow_vartime(&[(i * chunk) as u64]);
106                    for v in v.iter_mut() {
107                        v.group_mul_assign(&u);
108                        u.mul_assign(&g);
109                    }
110                });
111            }
112        });
113    }
114
115    pub fn coset_fft(&mut self, worker: &Worker) {
116        self.distribute_powers(worker, S::MULTIPLICATIVE_GENERATOR);
117        self.fft(worker);
118    }
119
120    pub fn icoset_fft(&mut self, worker: &Worker) {
121        let geninv = self.geninv;
122
123        self.ifft(worker);
124        self.distribute_powers(worker, geninv);
125    }
126
127    /// This evaluates t(tau) for this domain, which is
128    /// tau^m - 1 for these radix-2 domains.
129    pub fn z(&self, tau: &S) -> S {
130        let mut tmp = tau.pow_vartime(&[self.coeffs.len() as u64]);
131        tmp.sub_assign(&S::ONE);
132
133        tmp
134    }
135
136    /// The target polynomial is the zero polynomial in our
137    /// evaluation domain, so we must perform division over
138    /// a coset.
139    pub fn divide_by_z_on_coset(&mut self, worker: &Worker) {
140        let i = self.z(&S::MULTIPLICATIVE_GENERATOR).invert().unwrap();
141
142        worker.scope(self.coeffs.len(), |scope, chunk| {
143            for v in self.coeffs.chunks_mut(chunk) {
144                scope.spawn(move |_scope| {
145                    for v in v {
146                        v.group_mul_assign(&i);
147                    }
148                });
149            }
150        });
151    }
152
153    /// Perform O(n) multiplication of two polynomials in the domain.
154    pub fn mul_assign(&mut self, worker: &Worker, other: &EvaluationDomain<S, Scalar<S>>) {
155        assert_eq!(self.coeffs.len(), other.coeffs.len());
156
157        worker.scope(self.coeffs.len(), |scope, chunk| {
158            for (a, b) in self
159                .coeffs
160                .chunks_mut(chunk)
161                .zip(other.coeffs.chunks(chunk))
162            {
163                scope.spawn(move |_scope| {
164                    for (a, b) in a.iter_mut().zip(b.iter()) {
165                        a.group_mul_assign(&b.0);
166                    }
167                });
168            }
169        });
170    }
171
172    /// Perform O(n) subtraction of one polynomial from another in the domain.
173    pub fn sub_assign(&mut self, worker: &Worker, other: &EvaluationDomain<S, G>) {
174        assert_eq!(self.coeffs.len(), other.coeffs.len());
175
176        worker.scope(self.coeffs.len(), |scope, chunk| {
177            for (a, b) in self
178                .coeffs
179                .chunks_mut(chunk)
180                .zip(other.coeffs.chunks(chunk))
181            {
182                scope.spawn(move |_scope| {
183                    for (a, b) in a.iter_mut().zip(b.iter()) {
184                        a.group_sub_assign(b);
185                    }
186                });
187            }
188        });
189    }
190}
191
192pub trait Group<Scalar: PrimeField>: Sized + Copy + Clone + Send + Sync {
193    fn group_zero() -> Self;
194    fn group_mul_assign(&mut self, by: &Scalar);
195    fn group_add_assign(&mut self, other: &Self);
196    fn group_sub_assign(&mut self, other: &Self);
197}
198
199pub struct Point<G: CofactorCurve>(pub G);
200
201impl<G: CofactorCurve> PartialEq for Point<G> {
202    fn eq(&self, other: &Point<G>) -> bool {
203        self.0 == other.0
204    }
205}
206
207impl<G: CofactorCurve> Copy for Point<G> {}
208
209impl<G: CofactorCurve> Clone for Point<G> {
210    fn clone(&self) -> Point<G> {
211        *self
212    }
213}
214
215impl<G: CofactorCurve> Group<G::Scalar> for Point<G> {
216    fn group_zero() -> Self {
217        Point(G::identity())
218    }
219    fn group_mul_assign(&mut self, by: &G::Scalar) {
220        self.0.mul_assign(by);
221    }
222    fn group_add_assign(&mut self, other: &Self) {
223        self.0.add_assign(&other.0);
224    }
225    fn group_sub_assign(&mut self, other: &Self) {
226        self.0.sub_assign(&other.0);
227    }
228}
229
230pub struct Scalar<S: PrimeField>(pub S);
231
232impl<S: PrimeField> PartialEq for Scalar<S> {
233    fn eq(&self, other: &Scalar<S>) -> bool {
234        self.0 == other.0
235    }
236}
237
238impl<S: PrimeField> Copy for Scalar<S> {}
239
240impl<S: PrimeField> Clone for Scalar<S> {
241    fn clone(&self) -> Scalar<S> {
242        *self
243    }
244}
245
246impl<S: PrimeField> Group<S> for Scalar<S> {
247    fn group_zero() -> Self {
248        Scalar(S::ZERO)
249    }
250    fn group_mul_assign(&mut self, by: &S) {
251        self.0.mul_assign(by);
252    }
253    fn group_add_assign(&mut self, other: &Self) {
254        self.0.add_assign(&other.0);
255    }
256    fn group_sub_assign(&mut self, other: &Self) {
257        self.0.sub_assign(&other.0);
258    }
259}
260
261fn best_fft<S: PrimeField, T: Group<S>>(a: &mut [T], worker: &Worker, omega: &S, log_n: u32) {
262    let log_cpus = worker.log_num_threads();
263
264    if log_n <= log_cpus {
265        serial_fft(a, omega, log_n);
266    } else {
267        parallel_fft(a, worker, omega, log_n, log_cpus);
268    }
269}
270
271#[allow(clippy::many_single_char_names)]
272fn serial_fft<S: PrimeField, T: Group<S>>(a: &mut [T], omega: &S, log_n: u32) {
273    fn bitreverse(mut n: u32, l: u32) -> u32 {
274        let mut r = 0;
275        for _ in 0..l {
276            r = (r << 1) | (n & 1);
277            n >>= 1;
278        }
279        r
280    }
281
282    let n = a.len() as u32;
283    assert_eq!(n, 1 << log_n);
284
285    for k in 0..n {
286        let rk = bitreverse(k, log_n);
287        if k < rk {
288            a.swap(rk as usize, k as usize);
289        }
290    }
291
292    let mut m = 1;
293    for _ in 0..log_n {
294        let w_m = omega.pow_vartime(&[u64::from(n / (2 * m))]);
295
296        let mut k = 0;
297        while k < n {
298            let mut w = S::ONE;
299            for j in 0..m {
300                let mut t = a[(k + j + m) as usize];
301                t.group_mul_assign(&w);
302                let mut tmp = a[(k + j) as usize];
303                tmp.group_sub_assign(&t);
304                a[(k + j + m) as usize] = tmp;
305                a[(k + j) as usize].group_add_assign(&t);
306                w.mul_assign(&w_m);
307            }
308
309            k += 2 * m;
310        }
311
312        m *= 2;
313    }
314}
315
316fn parallel_fft<S: PrimeField, T: Group<S>>(
317    a: &mut [T],
318    worker: &Worker,
319    omega: &S,
320    log_n: u32,
321    log_cpus: u32,
322) {
323    assert!(log_n >= log_cpus);
324
325    let num_cpus = 1 << log_cpus;
326    let log_new_n = log_n - log_cpus;
327    let mut tmp = vec![vec![T::group_zero(); 1 << log_new_n]; num_cpus];
328    let new_omega = omega.pow_vartime(&[num_cpus as u64]);
329
330    worker.scope(0, |scope, _| {
331        let a = &*a;
332
333        for (j, tmp) in tmp.iter_mut().enumerate() {
334            scope.spawn(move |_scope| {
335                // Shuffle into a sub-FFT
336                let omega_j = omega.pow_vartime(&[j as u64]);
337                let omega_step = omega.pow_vartime(&[(j as u64) << log_new_n]);
338
339                let mut elt = S::ONE;
340                for (i, tmp) in tmp.iter_mut().enumerate() {
341                    for s in 0..num_cpus {
342                        let idx = (i + (s << log_new_n)) % (1 << log_n);
343                        let mut t = a[idx];
344                        t.group_mul_assign(&elt);
345                        tmp.group_add_assign(&t);
346                        elt.mul_assign(&omega_step);
347                    }
348                    elt.mul_assign(&omega_j);
349                }
350
351                // Perform sub-FFT
352                serial_fft(tmp, &new_omega, log_new_n);
353            });
354        }
355    });
356
357    // TODO: does this hurt or help?
358    worker.scope(a.len(), |scope, chunk| {
359        let tmp = &tmp;
360
361        for (idx, a) in a.chunks_mut(chunk).enumerate() {
362            scope.spawn(move |_scope| {
363                let mut idx = idx * chunk;
364                let mask = (1 << log_cpus) - 1;
365                for a in a {
366                    *a = tmp[idx & mask][idx >> log_cpus];
367                    idx += 1;
368                }
369            });
370        }
371    });
372}
373
// Multiply pairs of random low-degree polynomials through the FFT and check
// the result against schoolbook multiplication.
#[cfg(feature = "pairing")]
#[test]
fn polynomial_arith() {
    use bls12_381::Scalar as Fr;
    use rand_core::RngCore;

    fn test_mul<S: PrimeField, R: RngCore>(mut rng: &mut R) {
        let worker = Worker::new();

        for coeffs_a in 0..70 {
            for coeffs_b in 0..70 {
                let product_len: usize = coeffs_a + coeffs_b;

                let mut lhs = Vec::new();
                for _ in 0..coeffs_a {
                    lhs.push(Scalar::<S>(S::random(&mut rng)));
                }
                let mut rhs = Vec::new();
                for _ in 0..coeffs_b {
                    rhs.push(Scalar::<S>(S::random(&mut rng)));
                }

                // Schoolbook product: accumulate every pairwise term.
                let mut naive = vec![Scalar(S::ZERO); product_len];
                for (i, x) in lhs.iter().enumerate() {
                    for (j, y) in rhs.iter().enumerate() {
                        let mut term = *x;
                        term.group_mul_assign(&y.0);
                        naive[i + j].group_add_assign(&term);
                    }
                }

                // Pad both operands so they land in the same domain.
                lhs.resize(product_len, Scalar(S::ZERO));
                rhs.resize(product_len, Scalar(S::ZERO));

                let mut lhs = EvaluationDomain::from_coeffs(lhs).unwrap();
                let mut rhs = EvaluationDomain::from_coeffs(rhs).unwrap();

                // Multiply in the evaluation representation, then go back
                // to coefficients.
                lhs.fft(&worker);
                rhs.fft(&worker);
                lhs.mul_assign(&worker, &rhs);
                lhs.ifft(&worker);

                for (expected, actual) in naive.iter().zip(lhs.coeffs.iter()) {
                    assert!(expected == actual);
                }
            }
        }
    }

    let rng = &mut rand::thread_rng();

    test_mul::<Fr, _>(rng);
}
426
// Each transform composed with its inverse, in either order, must give back
// the original values.
#[cfg(feature = "pairing")]
#[test]
fn fft_composition() {
    use bls12_381::Scalar as Fr;
    use rand_core::RngCore;

    fn test_comp<S: PrimeField, R: RngCore>(mut rng: &mut R) {
        let worker = Worker::new();

        for log_len in 0..10 {
            let len = 1 << log_len;

            let original: Vec<_> = (0..len)
                .map(|_| Scalar::<S>(S::random(&mut rng)))
                .collect();

            let mut domain = EvaluationDomain::from_coeffs(original.clone()).unwrap();

            domain.ifft(&worker);
            domain.fft(&worker);
            assert!(original == domain.coeffs);

            domain.fft(&worker);
            domain.ifft(&worker);
            assert!(original == domain.coeffs);

            domain.icoset_fft(&worker);
            domain.coset_fft(&worker);
            assert!(original == domain.coeffs);

            domain.coset_fft(&worker);
            domain.icoset_fft(&worker);
            assert!(original == domain.coeffs);
        }
    }

    let rng = &mut rand::thread_rng();

    test_comp::<Fr, _>(rng);
}
464
// The parallel FFT must agree with the serial FFT for every way of splitting
// the work that `parallel_fft` accepts (`log_cpus <= log_d`).
#[cfg(feature = "pairing")]
#[test]
fn parallel_fft_consistency() {
    use bls12_381::Scalar as Fr;
    use rand_core::RngCore;
    use std::cmp::min;

    fn test_consistency<S: PrimeField, R: RngCore>(mut rng: &mut R) {
        let worker = Worker::new();

        for _ in 0..5 {
            for log_d in 0..10 {
                let d = 1 << log_d;

                let input = (0..d)
                    .map(|_| Scalar::<S>(S::random(&mut rng)))
                    .collect::<Vec<_>>();
                let domain = EvaluationDomain::from_coeffs(input).unwrap();

                // Serial reference, computed once per input.
                let mut expected = domain.coeffs.clone();
                serial_fft(&mut expected, &domain.omega, log_d);

                // Try every split from a single sub-FFT up to eight, capped
                // at `log_d` because `parallel_fft` requires
                // `log_n >= log_cpus`. Each run starts from the untouched
                // input so a failure points at one specific `log_cpus`.
                for log_cpus in 0..=min(log_d, 3) {
                    let mut actual = domain.coeffs.clone();
                    parallel_fft(&mut actual, &worker, &domain.omega, log_d, log_cpus);

                    assert!(actual == expected);
                }
            }
        }
    }

    let rng = &mut rand::thread_rng();

    test_consistency::<Fr, _>(rng);
}