halo2_proofs/
arithmetic.rs

1//! This module provides common utilities, traits and structures for group,
2//! field and polynomial arithmetic.
3
4use super::multicore;
5pub use ff::Field;
6use group::{
7    ff::{BatchInvert, PrimeField},
8    Group as _, GroupOpsOwned, ScalarMulOwned,
9};
10
11pub use pasta_curves::arithmetic::*;
12
13/// This represents an element of a group with basic operations that can be
14/// performed. This allows an FFT implementation (for example) to operate
15/// generically over either a field or elliptic curve group.
16pub trait FftGroup<Scalar: Field>:
17    Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
18{
19}
20
21impl<T, Scalar> FftGroup<Scalar> for T
22where
23    Scalar: Field,
24    T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
25{
26}
27
28fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
29    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
30
31    let c = if bases.len() < 4 {
32        1
33    } else if bases.len() < 32 {
34        3
35    } else {
36        (f64::from(bases.len() as u32)).ln().ceil() as usize
37    };
38
39    fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
40        let skip_bits = segment * c;
41        let skip_bytes = skip_bits / 8;
42
43        if skip_bytes >= 32 {
44            return 0;
45        }
46
47        let mut v = [0; 8];
48        for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
49            *v = *o;
50        }
51
52        let mut tmp = u64::from_le_bytes(v);
53        tmp >>= skip_bits - (skip_bytes * 8);
54        tmp %= 1 << c;
55
56        tmp as usize
57    }
58
59    let segments = (256 / c) + 1;
60
61    for current_segment in (0..segments).rev() {
62        for _ in 0..c {
63            *acc = acc.double();
64        }
65
66        #[derive(Clone, Copy)]
67        enum Bucket<C: CurveAffine> {
68            None,
69            Affine(C),
70            Projective(C::Curve),
71        }
72
73        impl<C: CurveAffine> Bucket<C> {
74            fn add_assign(&mut self, other: &C) {
75                *self = match *self {
76                    Bucket::None => Bucket::Affine(*other),
77                    Bucket::Affine(a) => Bucket::Projective(a + *other),
78                    Bucket::Projective(mut a) => {
79                        a += *other;
80                        Bucket::Projective(a)
81                    }
82                }
83            }
84
85            fn add(self, mut other: C::Curve) -> C::Curve {
86                match self {
87                    Bucket::None => other,
88                    Bucket::Affine(a) => {
89                        other += a;
90                        other
91                    }
92                    Bucket::Projective(a) => other + &a,
93                }
94            }
95        }
96
97        let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
98
99        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
100            let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
101            if coeff != 0 {
102                buckets[coeff - 1].add_assign(base);
103            }
104        }
105
106        // Summation by parts
107        // e.g. 3a + 2b + 1c = a +
108        //                    (a) + b +
109        //                    ((a) + b) + c
110        let mut running_sum = C::Curve::identity();
111        for exp in buckets.into_iter().rev() {
112            running_sum = exp.add(running_sum);
113            *acc += &running_sum;
114        }
115    }
116}
117
118/// Performs a small multi-exponentiation operation.
119/// Uses the double-and-add algorithm with doublings shared across points.
120pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
121    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
122    let mut acc = C::Curve::identity();
123
124    // for byte idx
125    for byte_idx in (0..32).rev() {
126        // for bit idx
127        for bit_idx in (0..8).rev() {
128            acc = acc.double();
129            // for each coeff
130            for coeff_idx in 0..coeffs.len() {
131                let byte = coeffs[coeff_idx].as_ref()[byte_idx];
132                if ((byte >> bit_idx) & 1) != 0 {
133                    acc += bases[coeff_idx];
134                }
135            }
136        }
137    }
138
139    acc
140}
141
142/// Performs a multi-exponentiation operation.
143///
144/// This function will panic if coeffs and bases have a different length.
145///
146/// This will use multithreading if beneficial.
147pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
148    assert_eq!(coeffs.len(), bases.len());
149
150    let num_threads = multicore::current_num_threads();
151    if coeffs.len() > num_threads {
152        let chunk = coeffs.len() / num_threads;
153        let num_chunks = coeffs.chunks(chunk).len();
154        let mut results = vec![C::Curve::identity(); num_chunks];
155        multicore::scope(|scope| {
156            let chunk = coeffs.len() / num_threads;
157
158            for ((coeffs, bases), acc) in coeffs
159                .chunks(chunk)
160                .zip(bases.chunks(chunk))
161                .zip(results.iter_mut())
162            {
163                scope.spawn(move |_| {
164                    multiexp_serial(coeffs, bases, acc);
165                });
166            }
167        });
168        results.iter().fold(C::Curve::identity(), |a, b| a + b)
169    } else {
170        let mut acc = C::Curve::identity();
171        multiexp_serial(coeffs, bases, &mut acc);
172        acc
173    }
174}
175
176/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
177/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
178/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
179/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
180/// transformed into the evaluations of this polynomial at each of the $n$
181/// distinct powers of $\omega$. This transformation is invertible by providing
182/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
183/// by $n$.
184///
185/// This will use multithreading if beneficial.
186pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
187    fn bitreverse(mut n: usize, l: usize) -> usize {
188        let mut r = 0;
189        for _ in 0..l {
190            r = (r << 1) | (n & 1);
191            n >>= 1;
192        }
193        r
194    }
195
196    let threads = multicore::current_num_threads();
197    let log_threads = log2_floor(threads);
198    let n = a.len();
199    assert_eq!(n, 1 << log_n);
200
201    for k in 0..n {
202        let rk = bitreverse(k, log_n as usize);
203        if k < rk {
204            a.swap(rk, k);
205        }
206    }
207
208    // precompute twiddle factors
209    let twiddles: Vec<_> = (0..(n / 2))
210        .scan(Scalar::ONE, |w, _| {
211            let tw = *w;
212            *w *= &omega;
213            Some(tw)
214        })
215        .collect();
216
217    if log_n <= log_threads {
218        let mut chunk = 2_usize;
219        let mut twiddle_chunk = n / 2;
220        for _ in 0..log_n {
221            a.chunks_mut(chunk).for_each(|coeffs| {
222                let (left, right) = coeffs.split_at_mut(chunk / 2);
223
224                // case when twiddle factor is one
225                let (a, left) = left.split_at_mut(1);
226                let (b, right) = right.split_at_mut(1);
227                let t = b[0];
228                b[0] = a[0];
229                a[0] += &t;
230                b[0] -= &t;
231
232                left.iter_mut()
233                    .zip(right.iter_mut())
234                    .enumerate()
235                    .for_each(|(i, (a, b))| {
236                        let mut t = *b;
237                        t *= &twiddles[(i + 1) * twiddle_chunk];
238                        *b = *a;
239                        *a += &t;
240                        *b -= &t;
241                    });
242            });
243            chunk *= 2;
244            twiddle_chunk /= 2;
245        }
246    } else {
247        recursive_butterfly_arithmetic(a, n, 1, &twiddles)
248    }
249}
250
251/// This perform recursive butterfly arithmetic
252pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
253    a: &mut [G],
254    n: usize,
255    twiddle_chunk: usize,
256    twiddles: &[Scalar],
257) {
258    if n == 2 {
259        let t = a[1];
260        a[1] = a[0];
261        a[0] += &t;
262        a[1] -= &t;
263    } else {
264        let (left, right) = a.split_at_mut(n / 2);
265        multicore::join(
266            || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
267            || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
268        );
269
270        // case when twiddle factor is one
271        let (a, left) = left.split_at_mut(1);
272        let (b, right) = right.split_at_mut(1);
273        let t = b[0];
274        b[0] = a[0];
275        a[0] += &t;
276        b[0] -= &t;
277
278        left.iter_mut()
279            .zip(right.iter_mut())
280            .enumerate()
281            .for_each(|(i, (a, b))| {
282                let mut t = *b;
283                t *= &twiddles[(i + 1) * twiddle_chunk];
284                *b = *a;
285                *a += &t;
286                *b -= &t;
287            });
288    }
289}
290
291/// This evaluates a provided polynomial (in coefficient form) at `point`.
292pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
293    // TODO: parallelize?
294    poly.iter()
295        .rev()
296        .fold(F::ZERO, |acc, coeff| acc * point + coeff)
297}
298
299/// This computes the inner product of two vectors `a` and `b`.
300///
301/// This function will panic if the two vectors are not the same size.
302pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
303    // TODO: parallelize?
304    assert_eq!(a.len(), b.len());
305
306    let mut acc = F::ZERO;
307    for (a, b) in a.iter().zip(b.iter()) {
308        acc += (*a) * (*b);
309    }
310
311    acc
312}
313
314/// Divides polynomial `a` in `X` by `X - b` with
315/// no remainder.
316pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
317where
318    I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
319{
320    b = -b;
321    let a = a.into_iter();
322
323    let mut q = vec![F::ZERO; a.len() - 1];
324
325    let mut tmp = F::ZERO;
326    for (q, r) in q.iter_mut().rev().zip(a.rev()) {
327        let mut lead_coeff = *r;
328        lead_coeff.sub_assign(&tmp);
329        *q = lead_coeff;
330        tmp = lead_coeff;
331        tmp.mul_assign(&b);
332    }
333
334    q
335}
336
337/// This simple utility function will parallelize an operation that is to be
338/// performed over a mutable slice.
339pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
340    let n = v.len();
341    let num_threads = multicore::current_num_threads();
342    let mut chunk = n / num_threads;
343    if chunk < num_threads {
344        chunk = n;
345    }
346
347    multicore::scope(|scope| {
348        for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
349            let f = f.clone();
350            scope.spawn(move |_| {
351                let start = chunk_num * chunk;
352                f(v, start);
353            });
354        }
355    });
356}
357
/// Returns the base-2 logarithm of `num`, rounded down.
///
/// # Panics
///
/// Panics if `num` is zero.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    // The position of the highest set bit is the floored logarithm. Unlike
    // probing with `1 << (pow + 1)`, this cannot overflow the shift when the
    // top bit of `num` is set.
    usize::BITS - 1 - num.leading_zeros()
}
369
370/// Returns coefficients of an n - 1 degree polynomial given a set of n points
371/// and their evaluations. This function will panic if two values in `points`
372/// are the same.
373pub fn lagrange_interpolate<F: Field>(points: &[F], evals: &[F]) -> Vec<F> {
374    assert_eq!(points.len(), evals.len());
375    if points.len() == 1 {
376        // Constant polynomial
377        vec![evals[0]]
378    } else {
379        let mut denoms = Vec::with_capacity(points.len());
380        for (j, x_j) in points.iter().enumerate() {
381            let mut denom = Vec::with_capacity(points.len() - 1);
382            for x_k in points
383                .iter()
384                .enumerate()
385                .filter(|&(k, _)| k != j)
386                .map(|a| a.1)
387            {
388                denom.push(*x_j - x_k);
389            }
390            denoms.push(denom);
391        }
392        // Compute (x_j - x_k)^(-1) for each j != i
393        denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();
394
395        let mut final_poly = vec![F::ZERO; points.len()];
396        for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
397            let mut tmp: Vec<F> = Vec::with_capacity(points.len());
398            let mut product = Vec::with_capacity(points.len() - 1);
399            tmp.push(F::ONE);
400            for (x_k, denom) in points
401                .iter()
402                .enumerate()
403                .filter(|&(k, _)| k != j)
404                .map(|a| a.1)
405                .zip(denoms.into_iter())
406            {
407                product.resize(tmp.len() + 1, F::ZERO);
408                for ((a, b), product) in tmp
409                    .iter()
410                    .chain(std::iter::once(&F::ZERO))
411                    .zip(std::iter::once(&F::ZERO).chain(tmp.iter()))
412                    .zip(product.iter_mut())
413                {
414                    *product = *a * (-denom * x_k) + *b * denom;
415                }
416                std::mem::swap(&mut tmp, &mut product);
417            }
418            assert_eq!(tmp.len(), points.len());
419            assert_eq!(product.len(), points.len() - 1);
420            for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
421                *final_coeff += interpolation_coeff * eval;
422            }
423        }
424        final_poly
425    }
426}
427
428#[cfg(test)]
429use rand_core::OsRng;
430
431#[cfg(test)]
432use crate::pasta::Fp;
433
// Interpolates through the first 0, 1, ..., 4 of five random points and
// checks that the resulting polynomial has one coefficient per point and
// takes the expected value at every point.
#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let points: Vec<Fp> = (0..5).map(|_| Fp::random(rng)).collect();
    let evals: Vec<Fp> = (0..5).map(|_| Fp::random(rng)).collect();

    for num_points in 0..5 {
        let xs = &points[0..num_points];
        let ys = &evals[0..num_points];

        let poly = lagrange_interpolate(xs, ys);
        assert_eq!(poly.len(), xs.len());

        for (x, y) in xs.iter().zip(ys) {
            assert_eq!(eval_polynomial(&poly, *x), *y);
        }
    }
}