halo2_proofs/
arithmetic.rs

1//! This module provides common utilities, traits and structures for group,
2//! field and polynomial arithmetic.
3
4pub use ff::Field;
5use group::{
6    ff::{BatchInvert, PrimeField},
7    Group as _, GroupOpsOwned, ScalarMulOwned,
8};
9use maybe_rayon::prelude::*;
10pub use pasta_curves::arithmetic::*;
11
12use crate::multicore::{self, TheBestReduce};
13
14/// This represents an element of a group with basic operations that can be
15/// performed. This allows an FFT implementation (for example) to operate
16/// generically over either a field or elliptic curve group.
17pub trait FftGroup<Scalar: Field>:
18    Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
19{
20}
21
22impl<T, Scalar> FftGroup<Scalar> for T
23where
24    Scalar: Field,
25    T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
26{
27}
28
29#[derive(Clone, Copy)]
30enum Bucket<C: CurveAffine> {
31    None,
32    Affine(C),
33    Projective(C::Curve),
34}
35
36impl<C: CurveAffine> Bucket<C> {
37    fn add_assign(&mut self, other: &C) {
38        *self = match *self {
39            Bucket::None => Bucket::Affine(*other),
40            Bucket::Affine(a) => Bucket::Projective(a + *other),
41            Bucket::Projective(mut a) => {
42                a += *other;
43                Bucket::Projective(a)
44            }
45        }
46    }
47
48    fn add(self, mut other: C::Curve) -> C::Curve {
49        match self {
50            Bucket::None => other,
51            Bucket::Affine(a) => {
52                other += a;
53                other
54            }
55            Bucket::Projective(a) => other + &a,
56        }
57    }
58}
59
60#[derive(Clone)]
61struct Buckets<C: CurveAffine> {
62    c: usize,
63    coeffs: Vec<Bucket<C>>,
64}
65
66impl<C: CurveAffine> Buckets<C> {
67    fn new(c: usize) -> Self {
68        Self {
69            c,
70            coeffs: vec![Bucket::None; (1 << c) - 1],
71        }
72    }
73
74    fn sum(&mut self, coeffs: &[C::Scalar], bases: &[C], i: usize) -> C::Curve {
75        // get segmentation and add coeff to buckets content
76        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
77            let seg = self.get_at::<C::Scalar>(i, &coeff.to_repr());
78            if seg != 0 {
79                self.coeffs[seg - 1].add_assign(base);
80            }
81        }
82        // Summation by parts
83        // e.g. 3a + 2b + 1c = a +
84        //                    (a) + b +
85        //                    ((a) + b) + c
86        let mut acc = C::Curve::identity();
87        let mut sum = C::Curve::identity();
88        self.coeffs.iter().rev().for_each(|b| {
89            sum = b.add(sum);
90            acc += sum;
91        });
92        acc
93    }
94
95    fn get_at<F: PrimeField>(&self, segment: usize, bytes: &F::Repr) -> usize {
96        let skip_bits = segment * self.c;
97        let skip_bytes = skip_bits / 8;
98
99        if skip_bytes >= 32 {
100            0
101        } else {
102            let mut v = [0; 8];
103            for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
104                *v = *o;
105            }
106
107            let mut tmp = u64::from_le_bytes(v);
108            tmp >>= skip_bits - (skip_bytes * 8);
109            (tmp % (1 << self.c)) as usize
110        }
111    }
112}
113
114/// Performs a small multi-exponentiation operation.
115/// Uses the double-and-add algorithm with doublings shared across points.
116pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
117    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
118    let mut acc = C::Curve::identity();
119
120    // for byte idx
121    for byte_idx in (0..32).rev() {
122        // for bit idx
123        for bit_idx in (0..8).rev() {
124            acc = acc.double();
125            // for each coeff
126            for coeff_idx in 0..coeffs.len() {
127                let byte = coeffs[coeff_idx].as_ref()[byte_idx];
128                if ((byte >> bit_idx) & 1) != 0 {
129                    acc += bases[coeff_idx];
130                }
131            }
132        }
133    }
134
135    acc
136}
137
138/// Performs a multi-exponentiation operation.
139///
140/// This function will panic if coeffs and bases have a different length.
141///
142/// This will use multithreading if beneficial.
143pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
144    assert_eq!(coeffs.len(), bases.len());
145
146    let c = if bases.len() < 4 {
147        1
148    } else if bases.len() < 32 {
149        3
150    } else {
151        (f64::from(bases.len() as u32)).ln().ceil() as usize
152    };
153
154    let mut multi_buckets: Vec<Buckets<C>> = vec![Buckets::new(c); (256 / c) + 1];
155    let num_threads = multicore::current_num_threads();
156    if coeffs.len() > num_threads {
157        multi_buckets
158            .par_iter_mut()
159            .enumerate()
160            .rev()
161            .map(|(i, buckets)| {
162                let mut acc = buckets.sum(coeffs, bases, i);
163                (0..c * i).for_each(|_| acc = acc.double());
164                acc
165            })
166            .the_best_reduce(C::Curve::identity, |a, b| a + b)
167            .expect("multi_buckets always contains at least 1 bucket")
168    } else {
169        multi_buckets
170            .iter_mut()
171            .enumerate()
172            .rev()
173            .map(|(i, buckets)| buckets.sum(coeffs, bases, i))
174            .fold(C::Curve::identity(), |mut sum, bucket| {
175                // restore original evaluation point
176                (0..c).for_each(|_| sum = sum.double());
177                sum + bucket
178            })
179    }
180}
181
182/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
183/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
184/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
185/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
186/// transformed into the evaluations of this polynomial at each of the $n$
187/// distinct powers of $\omega$. This transformation is invertible by providing
188/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
189/// by $n$.
190///
191/// This will use multithreading if beneficial.
192pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
193    fn bitreverse(mut n: usize, l: usize) -> usize {
194        let mut r = 0;
195        for _ in 0..l {
196            r = (r << 1) | (n & 1);
197            n >>= 1;
198        }
199        r
200    }
201
202    let threads = multicore::current_num_threads();
203    let log_threads = log2_floor(threads);
204    let n = a.len();
205    assert_eq!(n, 1 << log_n);
206
207    for k in 0..n {
208        let rk = bitreverse(k, log_n as usize);
209        if k < rk {
210            a.swap(rk, k);
211        }
212    }
213
214    // precompute twiddle factors
215    let twiddles: Vec<_> = (0..(n / 2))
216        .scan(Scalar::ONE, |w, _| {
217            let tw = *w;
218            *w *= &omega;
219            Some(tw)
220        })
221        .collect();
222
223    if log_n <= log_threads {
224        let mut chunk = 2_usize;
225        let mut twiddle_chunk = n / 2;
226        for _ in 0..log_n {
227            a.chunks_mut(chunk).for_each(|coeffs| {
228                let (left, right) = coeffs.split_at_mut(chunk / 2);
229
230                // case when twiddle factor is one
231                let (a, left) = left.split_at_mut(1);
232                let (b, right) = right.split_at_mut(1);
233                let t = b[0];
234                b[0] = a[0];
235                a[0] += &t;
236                b[0] -= &t;
237
238                left.iter_mut()
239                    .zip(right.iter_mut())
240                    .enumerate()
241                    .for_each(|(i, (a, b))| {
242                        let mut t = *b;
243                        t *= &twiddles[(i + 1) * twiddle_chunk];
244                        *b = *a;
245                        *a += &t;
246                        *b -= &t;
247                    });
248            });
249            chunk *= 2;
250            twiddle_chunk /= 2;
251        }
252    } else {
253        recursive_butterfly_arithmetic(a, n, 1, &twiddles)
254    }
255}
256
257/// This perform recursive butterfly arithmetic
258pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
259    a: &mut [G],
260    n: usize,
261    twiddle_chunk: usize,
262    twiddles: &[Scalar],
263) {
264    if n == 2 {
265        let t = a[1];
266        a[1] = a[0];
267        a[0] += &t;
268        a[1] -= &t;
269    } else {
270        let (left, right) = a.split_at_mut(n / 2);
271        multicore::join(
272            || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
273            || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
274        );
275
276        // case when twiddle factor is one
277        let (a, left) = left.split_at_mut(1);
278        let (b, right) = right.split_at_mut(1);
279        let t = b[0];
280        b[0] = a[0];
281        a[0] += &t;
282        b[0] -= &t;
283
284        left.iter_mut()
285            .zip(right.iter_mut())
286            .enumerate()
287            .for_each(|(i, (a, b))| {
288                let mut t = *b;
289                t *= &twiddles[(i + 1) * twiddle_chunk];
290                *b = *a;
291                *a += &t;
292                *b -= &t;
293            });
294    }
295}
296
297/// This evaluates a provided polynomial (in coefficient form) at `point`.
298pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
299    // TODO: parallelize?
300    poly.iter()
301        .rev()
302        .fold(F::ZERO, |acc, coeff| acc * point + coeff)
303}
304
305/// This computes the inner product of two vectors `a` and `b`.
306///
307/// This function will panic if the two vectors are not the same size.
308pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
309    // TODO: parallelize?
310    assert_eq!(a.len(), b.len());
311
312    let mut acc = F::ZERO;
313    for (a, b) in a.iter().zip(b.iter()) {
314        acc += (*a) * (*b);
315    }
316
317    acc
318}
319
320/// Divides polynomial `a` in `X` by `X - b` with
321/// no remainder.
322pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
323where
324    I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
325{
326    b = -b;
327    let a = a.into_iter();
328
329    let mut q = vec![F::ZERO; a.len() - 1];
330
331    let mut tmp = F::ZERO;
332    for (q, r) in q.iter_mut().rev().zip(a.rev()) {
333        let mut lead_coeff = *r;
334        lead_coeff.sub_assign(&tmp);
335        *q = lead_coeff;
336        tmp = lead_coeff;
337        tmp.mul_assign(&b);
338    }
339
340    q
341}
342
343/// This simple utility function will parallelize an operation that is to be
344/// performed over a mutable slice.
345pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
346    let n = v.len();
347    let num_threads = multicore::current_num_threads();
348    let mut chunk = n / num_threads;
349    if chunk < num_threads {
350        chunk = n;
351    }
352
353    multicore::scope(|scope| {
354        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
355            let f = f.clone();
356            scope.spawn(move |_| {
357                let start = chunk_num * chunk;
358                f(v, start);
359            });
360        }
361    });
362}
363
/// Returns the floor of the base-2 logarithm of `num`, i.e. the index of its
/// highest set bit.
///
/// # Panics
///
/// Panics if `num` is zero.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    // Computed from the leading zero count so that values with the top bit
    // set do not require shifting by the full bit width.
    usize::BITS - 1 - num.leading_zeros()
}
375
376/// Returns coefficients of an n - 1 degree polynomial given a set of n points
377/// and their evaluations. This function will panic if two values in `points`
378/// are the same.
379pub fn lagrange_interpolate<F: Field>(points: &[F], evals: &[F]) -> Vec<F> {
380    assert_eq!(points.len(), evals.len());
381    if points.len() == 1 {
382        // Constant polynomial
383        vec![evals[0]]
384    } else {
385        let mut denoms = Vec::with_capacity(points.len());
386        for (j, x_j) in points.iter().enumerate() {
387            let mut denom = Vec::with_capacity(points.len() - 1);
388            for x_k in points
389                .iter()
390                .enumerate()
391                .filter(|&(k, _)| k != j)
392                .map(|a| a.1)
393            {
394                denom.push(*x_j - x_k);
395            }
396            denoms.push(denom);
397        }
398        // Compute (x_j - x_k)^(-1) for each j != i
399        denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();
400
401        let mut final_poly = vec![F::ZERO; points.len()];
402        for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
403            let mut tmp: Vec<F> = Vec::with_capacity(points.len());
404            let mut product = Vec::with_capacity(points.len() - 1);
405            tmp.push(F::ONE);
406            for (x_k, denom) in points
407                .iter()
408                .enumerate()
409                .filter(|&(k, _)| k != j)
410                .map(|a| a.1)
411                .zip(denoms.into_iter())
412            {
413                product.resize(tmp.len() + 1, F::ZERO);
414                for ((a, b), product) in tmp
415                    .iter()
416                    .chain(std::iter::once(&F::ZERO))
417                    .zip(std::iter::once(&F::ZERO).chain(tmp.iter()))
418                    .zip(product.iter_mut())
419                {
420                    *product = *a * (-denom * x_k) + *b * denom;
421                }
422                std::mem::swap(&mut tmp, &mut product);
423            }
424            assert_eq!(tmp.len(), points.len());
425            assert_eq!(product.len(), points.len() - 1);
426            for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
427                *final_coeff += interpolation_coeff * eval;
428            }
429        }
430        final_poly
431    }
432}
433
434#[cfg(test)]
435use rand_core::OsRng;
436
437#[cfg(test)]
438use crate::pasta::{Eq, EqAffine, Fp};
439
// Checks `best_multiexp` against a naive sum of individual scalar
// multiplications on 2^8 random terms.
#[test]
fn test_multiexp() {
    let rng = OsRng;
    let n = 1 << 8;

    let coeffs: Vec<_> = (0..n).map(|_| Fp::random(rng)).collect();
    let bases: Vec<_> = (0..n).map(|_| EqAffine::from(Eq::random(rng))).collect();

    let fast = best_multiexp(&coeffs, &bases);

    // Reference result: multiply each base by its coefficient and add up.
    let mut naive = Eq::identity();
    for (coeff, base) in coeffs.iter().zip(bases) {
        naive = naive + base * coeff;
    }

    assert_eq!(fast, naive);
}
459
// Checks that interpolating through the first `coeffs` random points yields a
// polynomial with `coeffs` coefficients that reproduces every evaluation.
#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let points = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();
    let evals = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();

    // Inclusive upper bound so that the full set of five points is covered
    // as well as every shorter prefix (including the empty one).
    for coeffs in 0..=5 {
        let points = &points[0..coeffs];
        let evals = &evals[0..coeffs];

        let poly = lagrange_interpolate(points, evals);
        assert_eq!(poly.len(), points.len());

        for (point, eval) in points.iter().zip(evals) {
            assert_eq!(eval_polynomial(&poly, *point), *eval);
        }
    }
}