1pub use ff::Field;
5use group::{
6 ff::{BatchInvert, PrimeField},
7 Group as _, GroupOpsOwned, ScalarMulOwned,
8};
9use maybe_rayon::prelude::*;
10pub use pasta_curves::arithmetic::*;
11
12use crate::multicore::{self, TheBestReduce};
13
14pub trait FftGroup<Scalar: Field>:
18 Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
19{
20}
21
22impl<T, Scalar> FftGroup<Scalar> for T
23where
24 Scalar: Field,
25 T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
26{
27}
28
29#[derive(Clone, Copy)]
30enum Bucket<C: CurveAffine> {
31 None,
32 Affine(C),
33 Projective(C::Curve),
34}
35
36impl<C: CurveAffine> Bucket<C> {
37 fn add_assign(&mut self, other: &C) {
38 *self = match *self {
39 Bucket::None => Bucket::Affine(*other),
40 Bucket::Affine(a) => Bucket::Projective(a + *other),
41 Bucket::Projective(mut a) => {
42 a += *other;
43 Bucket::Projective(a)
44 }
45 }
46 }
47
48 fn add(self, mut other: C::Curve) -> C::Curve {
49 match self {
50 Bucket::None => other,
51 Bucket::Affine(a) => {
52 other += a;
53 other
54 }
55 Bucket::Projective(a) => other + &a,
56 }
57 }
58}
59
60#[derive(Clone)]
61struct Buckets<C: CurveAffine> {
62 c: usize,
63 coeffs: Vec<Bucket<C>>,
64}
65
66impl<C: CurveAffine> Buckets<C> {
67 fn new(c: usize) -> Self {
68 Self {
69 c,
70 coeffs: vec![Bucket::None; (1 << c) - 1],
71 }
72 }
73
74 fn sum(&mut self, coeffs: &[C::Scalar], bases: &[C], i: usize) -> C::Curve {
75 for (coeff, base) in coeffs.iter().zip(bases.iter()) {
77 let seg = self.get_at::<C::Scalar>(i, &coeff.to_repr());
78 if seg != 0 {
79 self.coeffs[seg - 1].add_assign(base);
80 }
81 }
82 let mut acc = C::Curve::identity();
87 let mut sum = C::Curve::identity();
88 self.coeffs.iter().rev().for_each(|b| {
89 sum = b.add(sum);
90 acc += sum;
91 });
92 acc
93 }
94
95 fn get_at<F: PrimeField>(&self, segment: usize, bytes: &F::Repr) -> usize {
96 let skip_bits = segment * self.c;
97 let skip_bytes = skip_bits / 8;
98
99 if skip_bytes >= 32 {
100 0
101 } else {
102 let mut v = [0; 8];
103 for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
104 *v = *o;
105 }
106
107 let mut tmp = u64::from_le_bytes(v);
108 tmp >>= skip_bits - (skip_bytes * 8);
109 (tmp % (1 << self.c)) as usize
110 }
111 }
112}
113
114pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
117 let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
118 let mut acc = C::Curve::identity();
119
120 for byte_idx in (0..32).rev() {
122 for bit_idx in (0..8).rev() {
124 acc = acc.double();
125 for coeff_idx in 0..coeffs.len() {
127 let byte = coeffs[coeff_idx].as_ref()[byte_idx];
128 if ((byte >> bit_idx) & 1) != 0 {
129 acc += bases[coeff_idx];
130 }
131 }
132 }
133 }
134
135 acc
136}
137
138pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
144 assert_eq!(coeffs.len(), bases.len());
145
146 let c = if bases.len() < 4 {
147 1
148 } else if bases.len() < 32 {
149 3
150 } else {
151 (f64::from(bases.len() as u32)).ln().ceil() as usize
152 };
153
154 let mut multi_buckets: Vec<Buckets<C>> = vec![Buckets::new(c); (256 / c) + 1];
155 let num_threads = multicore::current_num_threads();
156 if coeffs.len() > num_threads {
157 multi_buckets
158 .par_iter_mut()
159 .enumerate()
160 .rev()
161 .map(|(i, buckets)| {
162 let mut acc = buckets.sum(coeffs, bases, i);
163 (0..c * i).for_each(|_| acc = acc.double());
164 acc
165 })
166 .the_best_reduce(C::Curve::identity, |a, b| a + b)
167 .expect("multi_buckets always contains at least 1 bucket")
168 } else {
169 multi_buckets
170 .iter_mut()
171 .enumerate()
172 .rev()
173 .map(|(i, buckets)| buckets.sum(coeffs, bases, i))
174 .fold(C::Curve::identity(), |mut sum, bucket| {
175 (0..c).for_each(|_| sum = sum.double());
177 sum + bucket
178 })
179 }
180}
181
182pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
193 fn bitreverse(mut n: usize, l: usize) -> usize {
194 let mut r = 0;
195 for _ in 0..l {
196 r = (r << 1) | (n & 1);
197 n >>= 1;
198 }
199 r
200 }
201
202 let threads = multicore::current_num_threads();
203 let log_threads = log2_floor(threads);
204 let n = a.len();
205 assert_eq!(n, 1 << log_n);
206
207 for k in 0..n {
208 let rk = bitreverse(k, log_n as usize);
209 if k < rk {
210 a.swap(rk, k);
211 }
212 }
213
214 let twiddles: Vec<_> = (0..(n / 2))
216 .scan(Scalar::ONE, |w, _| {
217 let tw = *w;
218 *w *= ω
219 Some(tw)
220 })
221 .collect();
222
223 if log_n <= log_threads {
224 let mut chunk = 2_usize;
225 let mut twiddle_chunk = n / 2;
226 for _ in 0..log_n {
227 a.chunks_mut(chunk).for_each(|coeffs| {
228 let (left, right) = coeffs.split_at_mut(chunk / 2);
229
230 let (a, left) = left.split_at_mut(1);
232 let (b, right) = right.split_at_mut(1);
233 let t = b[0];
234 b[0] = a[0];
235 a[0] += &t;
236 b[0] -= &t;
237
238 left.iter_mut()
239 .zip(right.iter_mut())
240 .enumerate()
241 .for_each(|(i, (a, b))| {
242 let mut t = *b;
243 t *= &twiddles[(i + 1) * twiddle_chunk];
244 *b = *a;
245 *a += &t;
246 *b -= &t;
247 });
248 });
249 chunk *= 2;
250 twiddle_chunk /= 2;
251 }
252 } else {
253 recursive_butterfly_arithmetic(a, n, 1, &twiddles)
254 }
255}
256
257pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
259 a: &mut [G],
260 n: usize,
261 twiddle_chunk: usize,
262 twiddles: &[Scalar],
263) {
264 if n == 2 {
265 let t = a[1];
266 a[1] = a[0];
267 a[0] += &t;
268 a[1] -= &t;
269 } else {
270 let (left, right) = a.split_at_mut(n / 2);
271 multicore::join(
272 || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
273 || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
274 );
275
276 let (a, left) = left.split_at_mut(1);
278 let (b, right) = right.split_at_mut(1);
279 let t = b[0];
280 b[0] = a[0];
281 a[0] += &t;
282 b[0] -= &t;
283
284 left.iter_mut()
285 .zip(right.iter_mut())
286 .enumerate()
287 .for_each(|(i, (a, b))| {
288 let mut t = *b;
289 t *= &twiddles[(i + 1) * twiddle_chunk];
290 *b = *a;
291 *a += &t;
292 *b -= &t;
293 });
294 }
295}
296
297pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
299 poly.iter()
301 .rev()
302 .fold(F::ZERO, |acc, coeff| acc * point + coeff)
303}
304
305pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
309 assert_eq!(a.len(), b.len());
311
312 let mut acc = F::ZERO;
313 for (a, b) in a.iter().zip(b.iter()) {
314 acc += (*a) * (*b);
315 }
316
317 acc
318}
319
320pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
323where
324 I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
325{
326 b = -b;
327 let a = a.into_iter();
328
329 let mut q = vec![F::ZERO; a.len() - 1];
330
331 let mut tmp = F::ZERO;
332 for (q, r) in q.iter_mut().rev().zip(a.rev()) {
333 let mut lead_coeff = *r;
334 lead_coeff.sub_assign(&tmp);
335 *q = lead_coeff;
336 tmp = lead_coeff;
337 tmp.mul_assign(&b);
338 }
339
340 q
341}
342
343pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
346 let n = v.len();
347 let num_threads = multicore::current_num_threads();
348 let mut chunk = n / num_threads;
349 if chunk < num_threads {
350 chunk = n;
351 }
352
353 multicore::scope(|scope| {
354 for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
355 let f = f.clone();
356 scope.spawn(move |_| {
357 let start = chunk_num * chunk;
358 f(v, start);
359 });
360 }
361 });
362}
363
/// Returns `floor(log2(num))`, i.e. the index of the highest set bit of `num`.
///
/// Computing it from `leading_zeros` avoids the previous probe loop. That loop evaluated
/// `1 << (pow + 1)`, which overflows the shift for `num >= 2^(usize::BITS - 1)`: a panic in
/// debug builds and a non-terminating loop in release builds.
///
/// # Panics
///
/// Panics if `num` is zero.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    usize::BITS - 1 - num.leading_zeros()
}
375
376pub fn lagrange_interpolate<F: Field>(points: &[F], evals: &[F]) -> Vec<F> {
380 assert_eq!(points.len(), evals.len());
381 if points.len() == 1 {
382 vec![evals[0]]
384 } else {
385 let mut denoms = Vec::with_capacity(points.len());
386 for (j, x_j) in points.iter().enumerate() {
387 let mut denom = Vec::with_capacity(points.len() - 1);
388 for x_k in points
389 .iter()
390 .enumerate()
391 .filter(|&(k, _)| k != j)
392 .map(|a| a.1)
393 {
394 denom.push(*x_j - x_k);
395 }
396 denoms.push(denom);
397 }
398 denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();
400
401 let mut final_poly = vec![F::ZERO; points.len()];
402 for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
403 let mut tmp: Vec<F> = Vec::with_capacity(points.len());
404 let mut product = Vec::with_capacity(points.len() - 1);
405 tmp.push(F::ONE);
406 for (x_k, denom) in points
407 .iter()
408 .enumerate()
409 .filter(|&(k, _)| k != j)
410 .map(|a| a.1)
411 .zip(denoms.into_iter())
412 {
413 product.resize(tmp.len() + 1, F::ZERO);
414 for ((a, b), product) in tmp
415 .iter()
416 .chain(std::iter::once(&F::ZERO))
417 .zip(std::iter::once(&F::ZERO).chain(tmp.iter()))
418 .zip(product.iter_mut())
419 {
420 *product = *a * (-denom * x_k) + *b * denom;
421 }
422 std::mem::swap(&mut tmp, &mut product);
423 }
424 assert_eq!(tmp.len(), points.len());
425 assert_eq!(product.len(), points.len() - 1);
426 for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
427 *final_coeff += interpolation_coeff * eval;
428 }
429 }
430 final_poly
431 }
432}
433
434#[cfg(test)]
435use rand_core::OsRng;
436
437#[cfg(test)]
438use crate::pasta::{Eq, EqAffine, Fp};
439
/// Checks `best_multiexp` (and, for short inputs, `small_multiexp`) against a naive
/// sum of scalar multiplications.
///
/// The sizes cover an empty input, both fixed small window widths (c = 1 and c = 3) and the
/// `ln`-based width. Depending on the worker-thread count, they exercise both the serial and
/// the parallel branch.
#[test]
fn test_multiexp() {
    let rng = OsRng;

    for &len in &[0usize, 1, 3, 5, 40, 1 << 8] {
        let coeffs = (0..len).map(|_| Fp::random(rng)).collect::<Vec<_>>();
        let bases = (0..len)
            .map(|_| EqAffine::from(Eq::random(rng)))
            .collect::<Vec<_>>();

        let naive = coeffs
            .iter()
            .zip(bases.iter())
            .map(|(coeff, base)| *base * coeff)
            .fold(Eq::identity(), |acc, val| acc + val);

        assert_eq!(best_multiexp(&coeffs, &bases), naive);
        // small_multiexp is quadratic-ish in practice; keep it to the short inputs.
        if len <= 40 {
            assert_eq!(small_multiexp(&coeffs, &bases), naive);
        }
    }
}
459
/// Interpolates every prefix (lengths 0 through 4) of five random points. It checks that the
/// result has one coefficient per point and passes through each of them.
#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let all_points: Vec<Fp> = (0..5).map(|_| Fp::random(rng)).collect();
    let all_evals: Vec<Fp> = (0..5).map(|_| Fp::random(rng)).collect();

    for len in 0..5 {
        let (points, evals) = (&all_points[..len], &all_evals[..len]);

        let poly = lagrange_interpolate(points, evals);
        assert_eq!(poly.len(), points.len());

        points
            .iter()
            .zip(evals)
            .for_each(|(point, eval)| assert_eq!(eval_polynomial(&poly, *point), *eval));
    }
}