1use super::multicore;
5pub use ff::Field;
6use group::{
7 ff::{BatchInvert, PrimeField},
8 Group as _, GroupOpsOwned, ScalarMulOwned,
9};
10
11pub use pasta_curves::arithmetic::*;
12
/// This represents an element of a group with basic operations that can be
/// performed. This allows an FFT implementation (for example) to operate
/// generically over either a field or elliptic curve group.
pub trait FftGroup<Scalar: Field>:
    Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
{
}
20
// Blanket impl: anything satisfying the group-ops and scalar-mul bounds is
// usable as an FFT group element, with no opt-in required.
impl<T, Scalar> FftGroup<Scalar> for T
where
    Scalar: Field,
    T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
{
}
27
/// Serial multi-exponentiation using Pippenger's bucket method: accumulates
/// `sum_i coeffs[i] * bases[i]` into `acc`.
fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
    // Work on the little-endian byte representation of each scalar.
    let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

    // Window size in bits: ~ln(n) for large inputs, smaller for tiny inputs.
    let c = if bases.len() < 4 {
        1
    } else if bases.len() < 32 {
        3
    } else {
        (f64::from(bases.len() as u32)).ln().ceil() as usize
    };

    // Extracts the `segment`-th window of `c` bits from a little-endian
    // scalar representation, as a usize in [0, 2^c).
    fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
        let skip_bits = segment * c;
        let skip_bytes = skip_bits / 8;

        // NOTE(review): assumes a 32-byte scalar repr — windows past the end
        // read as zero. TODO confirm for any non-256-bit scalar field.
        if skip_bytes >= 32 {
            return 0;
        }

        // Copy up to 8 bytes starting at the window into a u64 buffer.
        let mut v = [0; 8];
        for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
            *v = *o;
        }

        let mut tmp = u64::from_le_bytes(v);
        // Drop the bits below the window, then mask down to `c` bits.
        tmp >>= skip_bits - (skip_bytes * 8);
        tmp %= 1 << c;

        tmp as usize
    }

    // Number of `c`-bit windows needed to cover a 256-bit scalar.
    let segments = (256 / c) + 1;

    // Process windows from most significant to least significant.
    for current_segment in (0..segments).rev() {
        // Shift the accumulator left by `c` bits (one window).
        for _ in 0..c {
            *acc = acc.double();
        }

        // A bucket lazily upgrades empty -> affine -> projective so the first
        // addition into a bucket avoids full projective arithmetic.
        #[derive(Clone, Copy)]
        enum Bucket<C: CurveAffine> {
            None,
            Affine(C),
            Projective(C::Curve),
        }

        impl<C: CurveAffine> Bucket<C> {
            fn add_assign(&mut self, other: &C) {
                *self = match *self {
                    Bucket::None => Bucket::Affine(*other),
                    Bucket::Affine(a) => Bucket::Projective(a + *other),
                    Bucket::Projective(mut a) => {
                        a += *other;
                        Bucket::Projective(a)
                    }
                }
            }

            fn add(self, mut other: C::Curve) -> C::Curve {
                match self {
                    Bucket::None => other,
                    Bucket::Affine(a) => {
                        other += a;
                        other
                    }
                    Bucket::Projective(a) => other + &a,
                }
            }
        }

        // One bucket for each non-zero window value in [1, 2^c - 1].
        let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];

        for (coeff, base) in coeffs.iter().zip(bases.iter()) {
            let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
            if coeff != 0 {
                // Window value w contributes `base` to bucket w - 1.
                buckets[coeff - 1].add_assign(base);
            }
        }

        // Summation by parts: walking buckets from highest to lowest and
        // adding the running sum into `acc` each step yields
        // sum_w w * bucket[w] without any scalar multiplications.
        let mut running_sum = C::Curve::identity();
        for exp in buckets.into_iter().rev() {
            running_sum = exp.add(running_sum);
            *acc += &running_sum;
        }
    }
}
117
118pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
121 let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
122 let mut acc = C::Curve::identity();
123
124 for byte_idx in (0..32).rev() {
126 for bit_idx in (0..8).rev() {
128 acc = acc.double();
129 for coeff_idx in 0..coeffs.len() {
131 let byte = coeffs[coeff_idx].as_ref()[byte_idx];
132 if ((byte >> bit_idx) & 1) != 0 {
133 acc += bases[coeff_idx];
134 }
135 }
136 }
137 }
138
139 acc
140}
141
142pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
148 assert_eq!(coeffs.len(), bases.len());
149
150 let num_threads = multicore::current_num_threads();
151 if coeffs.len() > num_threads {
152 let chunk = coeffs.len() / num_threads;
153 let num_chunks = coeffs.chunks(chunk).len();
154 let mut results = vec![C::Curve::identity(); num_chunks];
155 multicore::scope(|scope| {
156 let chunk = coeffs.len() / num_threads;
157
158 for ((coeffs, bases), acc) in coeffs
159 .chunks(chunk)
160 .zip(bases.chunks(chunk))
161 .zip(results.iter_mut())
162 {
163 scope.spawn(move |_| {
164 multiexp_serial(coeffs, bases, acc);
165 });
166 }
167 });
168 results.iter().fold(C::Curve::identity(), |a, b| a + b)
169 } else {
170 let mut acc = C::Curve::identity();
171 multiexp_serial(coeffs, bases, &mut acc);
172 acc
173 }
174}
175
176pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
187 fn bitreverse(mut n: usize, l: usize) -> usize {
188 let mut r = 0;
189 for _ in 0..l {
190 r = (r << 1) | (n & 1);
191 n >>= 1;
192 }
193 r
194 }
195
196 let threads = multicore::current_num_threads();
197 let log_threads = log2_floor(threads);
198 let n = a.len();
199 assert_eq!(n, 1 << log_n);
200
201 for k in 0..n {
202 let rk = bitreverse(k, log_n as usize);
203 if k < rk {
204 a.swap(rk, k);
205 }
206 }
207
208 let twiddles: Vec<_> = (0..(n / 2))
210 .scan(Scalar::ONE, |w, _| {
211 let tw = *w;
212 *w *= ω
213 Some(tw)
214 })
215 .collect();
216
217 if log_n <= log_threads {
218 let mut chunk = 2_usize;
219 let mut twiddle_chunk = n / 2;
220 for _ in 0..log_n {
221 a.chunks_mut(chunk).for_each(|coeffs| {
222 let (left, right) = coeffs.split_at_mut(chunk / 2);
223
224 let (a, left) = left.split_at_mut(1);
226 let (b, right) = right.split_at_mut(1);
227 let t = b[0];
228 b[0] = a[0];
229 a[0] += &t;
230 b[0] -= &t;
231
232 left.iter_mut()
233 .zip(right.iter_mut())
234 .enumerate()
235 .for_each(|(i, (a, b))| {
236 let mut t = *b;
237 t *= &twiddles[(i + 1) * twiddle_chunk];
238 *b = *a;
239 *a += &t;
240 *b -= &t;
241 });
242 });
243 chunk *= 2;
244 twiddle_chunk /= 2;
245 }
246 } else {
247 recursive_butterfly_arithmetic(a, n, 1, &twiddles)
248 }
249}
250
/// This performs recursive butterfly arithmetic for the FFT: transforms both
/// halves of `a` in parallel, then combines them with one butterfly pass.
/// `twiddle_chunk` is the stride into `twiddles` at this recursion depth.
pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
    a: &mut [G],
    n: usize,
    twiddle_chunk: usize,
    twiddles: &[Scalar],
) {
    if n == 2 {
        // Base case: a single butterfly with twiddle factor one.
        let t = a[1];
        a[1] = a[0];
        a[0] += &t;
        a[1] -= &t;
    } else {
        let (left, right) = a.split_at_mut(n / 2);
        // Transform the two halves concurrently; they are disjoint slices.
        multicore::join(
            || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
            || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
        );

        // case when twiddle factor is one
        let (a, left) = left.split_at_mut(1);
        let (b, right) = right.split_at_mut(1);
        let t = b[0];
        b[0] = a[0];
        a[0] += &t;
        b[0] -= &t;

        // Remaining butterflies use twiddles spaced `twiddle_chunk` apart.
        left.iter_mut()
            .zip(right.iter_mut())
            .enumerate()
            .for_each(|(i, (a, b))| {
                let mut t = *b;
                t *= &twiddles[(i + 1) * twiddle_chunk];
                *b = *a;
                *a += &t;
                *b -= &t;
            });
    }
}
290
291pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
293 poly.iter()
295 .rev()
296 .fold(F::ZERO, |acc, coeff| acc * point + coeff)
297}
298
299pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
303 assert_eq!(a.len(), b.len());
305
306 let mut acc = F::ZERO;
307 for (a, b) in a.iter().zip(b.iter()) {
308 acc += (*a) * (*b);
309 }
310
311 acc
312}
313
314pub fn kate_division<'a, F: Field, I: IntoIterator<Item = &'a F>>(a: I, mut b: F) -> Vec<F>
317where
318 I::IntoIter: DoubleEndedIterator + ExactSizeIterator,
319{
320 b = -b;
321 let a = a.into_iter();
322
323 let mut q = vec![F::ZERO; a.len() - 1];
324
325 let mut tmp = F::ZERO;
326 for (q, r) in q.iter_mut().rev().zip(a.rev()) {
327 let mut lead_coeff = *r;
328 lead_coeff.sub_assign(&tmp);
329 *q = lead_coeff;
330 tmp = lead_coeff;
331 tmp.mul_assign(&b);
332 }
333
334 q
335}
336
337pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mut [T], f: F) {
340 let n = v.len();
341 let num_threads = multicore::current_num_threads();
342 let mut chunk = n / num_threads;
343 if chunk < num_threads {
344 chunk = n;
345 }
346
347 multicore::scope(|scope| {
348 for (chunk_num, v) in v.chunks_mut(chunk).enumerate() {
349 let f = f.clone();
350 scope.spawn(move |_| {
351 let start = chunk_num * chunk;
352 f(v, start);
353 });
354 }
355 });
356}
357
/// Returns `floor(log2(num))`.
///
/// Panics if `num` is zero.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    // Position of the highest set bit. This avoids the shift overflow the
    // previous `while (1 << (pow + 1)) <= num` loop hits when `num` has its
    // top bit set, and is O(1) instead of O(log n).
    usize::BITS - 1 - num.leading_zeros()
}
369
/// Returns the coefficients (lowest degree first) of the degree `n - 1`
/// polynomial interpolating the `n` pairs `(points[i], evals[i])`.
///
/// Panics if `points` and `evals` differ in length. The result is undefined
/// (division by zero via `batch_invert`) if two entries of `points` coincide.
pub fn lagrange_interpolate<F: Field>(points: &[F], evals: &[F]) -> Vec<F> {
    assert_eq!(points.len(), evals.len());
    if points.len() == 1 {
        // Constant polynomial.
        vec![evals[0]]
    } else {
        // For each j, collect the pairwise differences (x_j - x_k), k != j:
        // the denominators of the Lagrange basis polynomials.
        let mut denoms = Vec::with_capacity(points.len());
        for (j, x_j) in points.iter().enumerate() {
            let mut denom = Vec::with_capacity(points.len() - 1);
            for x_k in points
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != j)
                .map(|a| a.1)
            {
                denom.push(*x_j - x_k);
            }
            denoms.push(denom);
        }
        // Invert all denominators at once with a single batch inversion.
        denoms.iter_mut().flat_map(|v| v.iter_mut()).batch_invert();

        let mut final_poly = vec![F::ZERO; points.len()];
        for (j, (denoms, eval)) in denoms.into_iter().zip(evals.iter()).enumerate() {
            // Build the j-th Lagrange basis polynomial in coefficient form by
            // multiplying in the monomials (x - x_k) / (x_j - x_k) one at a time.
            let mut tmp: Vec<F> = Vec::with_capacity(points.len());
            let mut product = Vec::with_capacity(points.len() - 1);
            tmp.push(F::ONE);
            for (x_k, denom) in points
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != j)
                .map(|a| a.1)
                .zip(denoms.into_iter())
            {
                product.resize(tmp.len() + 1, F::ZERO);
                // Multiply `tmp` by (x * denom - x_k * denom), writing the
                // degree-raised result into `product`. The two shifted copies
                // of `tmp` supply the x^1 and x^0 contributions.
                for ((a, b), product) in tmp
                    .iter()
                    .chain(std::iter::once(&F::ZERO))
                    .zip(std::iter::once(&F::ZERO).chain(tmp.iter()))
                    .zip(product.iter_mut())
                {
                    *product = *a * (-denom * x_k) + *b * denom;
                }
                std::mem::swap(&mut tmp, &mut product);
            }
            assert_eq!(tmp.len(), points.len());
            assert_eq!(product.len(), points.len() - 1);
            // Scale the basis polynomial by its evaluation and accumulate.
            for (final_coeff, interpolation_coeff) in final_poly.iter_mut().zip(tmp.into_iter()) {
                *final_coeff += interpolation_coeff * eval;
            }
        }
        final_poly
    }
}
427
428#[cfg(test)]
429use rand_core::OsRng;
430
431#[cfg(test)]
432use crate::pasta::Fp;
433
// Interpolate through 0..=4 random points and verify that evaluating the
// resulting polynomial reproduces every sampled evaluation.
#[test]
fn test_lagrange_interpolate() {
    let rng = OsRng;

    let points = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();
    let evals = (0..5).map(|_| Fp::random(rng)).collect::<Vec<_>>();

    // Exercise every prefix size, including the empty and constant cases.
    for coeffs in 0..5 {
        let points = &points[0..coeffs];
        let evals = &evals[0..coeffs];

        let poly = lagrange_interpolate(points, evals);
        assert_eq!(poly.len(), points.len());

        // The interpolated polynomial must pass through each input point.
        for (point, eval) in points.iter().zip(evals) {
            assert_eq!(eval_polynomial(&poly, *point), *eval);
        }
    }
}