1#![allow(missing_docs)]
2use std::convert::TryInto;
3use std::io;
4use std::iter;
5use std::ops::AddAssign;
6use std::sync::Arc;
7
8use bitvec::prelude::{BitVec, Lsb0};
9use ff::{Field, PrimeField};
10use group::{prime::PrimeCurveAffine, Group};
11use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
12
13use crate::error::EcError;
14use crate::threadpool::{Waiter, Worker};
15
16pub trait SourceBuilder<G: PrimeCurveAffine>: Send + Sync + 'static + Clone {
18 type Source: Source<G>;
19
20 #[allow(clippy::wrong_self_convention)]
21 fn new(self) -> Self::Source;
22 fn get(self) -> (Arc<Vec<G>>, usize);
23}
24
25pub trait Source<G: PrimeCurveAffine> {
27 fn add_assign_mixed(&mut self, to: &mut <G as PrimeCurveAffine>::Curve) -> Result<(), EcError>;
29
30 fn skip(&mut self, amt: usize) -> Result<(), EcError>;
32}
33
34impl<G: PrimeCurveAffine> SourceBuilder<G> for (Arc<Vec<G>>, usize) {
35 type Source = (Arc<Vec<G>>, usize);
36
37 fn new(self) -> (Arc<Vec<G>>, usize) {
38 (self.0.clone(), self.1)
39 }
40
41 fn get(self) -> (Arc<Vec<G>>, usize) {
42 (self.0.clone(), self.1)
43 }
44}
45
46impl<G: PrimeCurveAffine> Source<G> for (Arc<Vec<G>>, usize) {
47 fn add_assign_mixed(&mut self, to: &mut <G as PrimeCurveAffine>::Curve) -> Result<(), EcError> {
48 if self.0.len() <= self.1 {
49 return Err(io::Error::new(
50 io::ErrorKind::UnexpectedEof,
51 "Expected more bases from source.",
52 )
53 .into());
54 }
55
56 if self.0[self.1].is_identity().into() {
57 return Err(EcError::Simple(
58 "Encountered an identity element in the CRS.",
59 ));
60 }
61
62 to.add_assign(&self.0[self.1]);
63
64 self.1 += 1;
65
66 Ok(())
67 }
68
69 fn skip(&mut self, amt: usize) -> Result<(), EcError> {
70 if self.0.len() <= self.1 {
71 return Err(io::Error::new(
72 io::ErrorKind::UnexpectedEof,
73 "Expected more bases from source.",
74 )
75 .into());
76 }
77
78 self.1 += amt;
79
80 Ok(())
81 }
82}
83
84pub trait QueryDensity: Sized {
85 type Iter: Iterator<Item = bool>;
87
88 fn iter(self) -> Self::Iter;
89 fn get_query_size(self) -> Option<usize>;
90 fn generate_exps<F: PrimeField>(self, exponents: Arc<Vec<F::Repr>>) -> Arc<Vec<F::Repr>>;
91}
92
/// A density in which every exponent is used.
///
/// It is a zero-sized marker, so it is `Copy`. It derives the same diagnostic and
/// comparison traits as `DensityTracker` so the two density types can be handled uniformly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct FullDensity;

/// Lets `FullDensity` be passed wherever an `AsRef<Q>` density map is expected.
impl AsRef<FullDensity> for FullDensity {
    fn as_ref(&self) -> &FullDensity {
        self
    }
}
101
102impl<'a> QueryDensity for &'a FullDensity {
103 type Iter = iter::Repeat<bool>;
104
105 fn iter(self) -> Self::Iter {
106 iter::repeat(true)
107 }
108
109 fn get_query_size(self) -> Option<usize> {
110 None
111 }
112
113 fn generate_exps<F: PrimeField>(self, exponents: Arc<Vec<F::Repr>>) -> Arc<Vec<F::Repr>> {
114 exponents
115 }
116}
117
118#[derive(Clone, PartialEq, Eq, Debug, Default)]
119pub struct DensityTracker {
120 pub bv: BitVec,
121 pub total_density: usize,
122}
123
124impl<'a> QueryDensity for &'a DensityTracker {
125 type Iter = bitvec::slice::BitValIter<'a, usize, Lsb0>;
126
127 fn iter(self) -> Self::Iter {
128 self.bv.iter().by_vals()
129 }
130
131 fn get_query_size(self) -> Option<usize> {
132 Some(self.bv.len())
133 }
134
135 fn generate_exps<F: PrimeField>(self, exponents: Arc<Vec<F::Repr>>) -> Arc<Vec<F::Repr>> {
136 let exps: Vec<_> = exponents
137 .iter()
138 .zip(self.bv.iter())
139 .filter_map(|(&e, d)| if *d { Some(e) } else { None })
140 .collect();
141
142 Arc::new(exps)
143 }
144}
145
146impl DensityTracker {
147 pub fn new() -> DensityTracker {
148 DensityTracker {
149 bv: BitVec::new(),
150 total_density: 0,
151 }
152 }
153
154 pub fn add_element(&mut self) {
155 self.bv.push(false);
156 }
157
158 pub fn inc(&mut self, idx: usize) {
159 if !self.bv.get(idx).unwrap() {
160 self.bv.set(idx, true);
161 self.total_density += 1;
162 }
163 }
164
165 pub fn get_total_density(&self) -> usize {
166 self.total_density
167 }
168
169 pub fn extend(&mut self, other: &Self, is_input_density: bool) {
172 if other.bv.is_empty() {
173 return;
175 }
176
177 if self.bv.is_empty() {
178 self.total_density = other.total_density;
180 self.bv.resize(other.bv.len(), false);
181 self.bv.copy_from_bitslice(&*other.bv);
182 return;
183 }
184
185 if is_input_density {
186 if other.bv[0] {
189 if self.bv[0] {
191 self.total_density -= 1;
193 } else {
194 self.bv.set(0, true);
196 }
197 }
198 self.bv.extend(other.bv.iter().skip(1));
200 } else {
201 self.bv.extend(other.bv.iter());
203 }
204
205 self.total_density += other.total_density;
207 }
208}
209
/// Shifts the little-endian integer stored in `le_bytes` right by `n` bits, in place.
///
/// Bits shifted out of the least-significant end are discarded, and zeros are shifted in
/// at the most-significant end. Shifting by at least `8 * le_bytes.len()` bits clears
/// the buffer. An empty buffer is left untouched.
fn shr(le_bytes: &mut [u8], n: u32) {
    let len = le_bytes.len();
    // Work in byte units so the comparison below cannot overflow (unlike `8 * len as u32`).
    // `u32 -> usize` is lossless on every target with at least 32-bit pointers.
    let byte_shift = (n / 8) as usize;
    let bit_shift = n % 8;

    if byte_shift >= len {
        le_bytes.fill(0);
        return;
    }

    // Whole-byte part: move every byte down `byte_shift` positions in a single pass,
    // then zero the vacated most-significant bytes. The old code did one full pass
    // per 8 bits of shift.
    if byte_shift > 0 {
        le_bytes.copy_within(byte_shift.., 0);
        le_bytes[len - byte_shift..].fill(0);
    }

    // Sub-byte part: walk from the most-significant byte downwards, carrying the low
    // `bit_shift` bits of each byte into the top of the next lower byte.
    if bit_shift > 0 {
        let mut carry = 0u8;
        for byte in le_bytes.iter_mut().rev() {
            let shifted_out = *byte << (8 - bit_shift);
            *byte = (*byte >> bit_shift) | carry;
            carry = shifted_out;
        }
    }
}
241
/// Computes `Σ exponents[i] · base_i` over the entries whose density flag is `true`,
/// using a windowed bucket method (Pippenger-style) with `c`-bit windows.
///
/// Each window (one per `c` bits of the scalar, starting at bit offset `skip`) is
/// evaluated independently in parallel with rayon. The window results are then combined
/// from the most-significant window down.
///
/// The base source advances only for entries whose density flag is `true`. Every such
/// entry must therefore either add its base somewhere or `skip(1)`, so that the cursor
/// stays aligned with the exponents.
///
/// Errors from the base source (running out of bases, identity bases) are propagated.
// NOTE(review): assumes `1 <= c <= 63` (so `1 << c` fits in a u64 and the bucket vector
// is non-empty) and that scalar representations are at least 8 bytes long; confirm
// against callers other than `multiexp_cpu`.
fn multiexp_inner<Q, D, G, S>(
    bases: S,
    density_map: D,
    exponents: Arc<Vec<<G::Scalar as PrimeField>::Repr>>,
    c: u32,
) -> Result<<G as PrimeCurveAffine>::Curve, EcError>
where
    for<'a> &'a Q: QueryDensity,
    D: Send + Sync + 'static + Clone + AsRef<Q>,
    G: PrimeCurveAffine,
    S: SourceBuilder<G>,
{
    // Evaluates the single window whose lowest bit is at offset `skip`.
    let this = move |bases: S,
                     density_map: D,
                     exponents: Arc<Vec<<G::Scalar as PrimeField>::Repr>>,
                     skip: u32|
          -> Result<_, EcError> {
        // Accumulates this window's result.
        let mut acc = G::Curve::identity();

        // Fresh cursor for this window.
        let mut bases = bases.new();

        // Bucket `d - 1` collects the bases whose window digit is `d` (digit 0 needs no bucket).
        let mut buckets = vec![<G as PrimeCurveAffine>::Curve::identity(); (1 << c) - 1];

        let zero = G::Scalar::ZERO.to_repr();
        let one = G::Scalar::ONE.to_repr();

        // Exponent one only has a bit in the lowest window, so only window 0 adds it.
        let handle_trivial = skip == 0;

        for (&exp, density) in exponents.iter().zip(density_map.as_ref().iter()) {
            if density {
                if exp.as_ref() == zero.as_ref() {
                    // Contributes nothing; just keep the cursor aligned.
                    bases.skip(1)?;
                } else if exp.as_ref() == one.as_ref() {
                    if handle_trivial {
                        // Add the base directly, bypassing the buckets.
                        bases.add_assign_mixed(&mut acc)?;
                    } else {
                        bases.skip(1)?;
                    }
                } else {
                    // Extract this window's `c`-bit digit: shift down by `skip`, then keep
                    // the low `c` bits of the low 8 bytes.
                    let mut exp = exp;
                    shr(exp.as_mut(), skip);
                    let exp = u64::from_le_bytes(exp.as_ref()[..8].try_into().unwrap()) % (1 << c);

                    if exp != 0 {
                        bases.add_assign_mixed(&mut buckets[(exp - 1) as usize])?;
                    } else {
                        bases.skip(1)?;
                    }
                }
            }
        }

        // Summing the running sum from the highest bucket down adds the contents of
        // bucket `d - 1` exactly `d` times, i.e. computes Σ d · bucket[d - 1].
        let mut running_sum = G::Curve::identity();
        for exp in buckets.into_iter().rev() {
            running_sum.add_assign(&exp);
            acc.add_assign(&running_sum);
        }

        Ok(acc)
    };

    // One window every `c` bits across the scalar's bit width, evaluated in parallel.
    let parts = (0..<G::Scalar as PrimeField>::NUM_BITS)
        .into_par_iter()
        .step_by(c as usize)
        .map(|skip| this(bases.clone(), density_map.clone(), exponents.clone(), skip))
        .collect::<Vec<Result<_, _>>>();

    // Horner-style combination from the most-significant window: shift the accumulator
    // left by `c` bits (`c` doublings), then add the next window. The first error from
    // any window is returned.
    parts.into_iter().rev().try_fold(
        <G as PrimeCurveAffine>::Curve::identity(),
        |mut acc, part| {
            for _ in 0..c {
                acc = acc.double();
            }

            acc.add_assign(&part?);
            Ok(acc)
        },
    )
}
331
332pub fn multiexp_cpu<'b, Q, D, G, S>(
335 pool: &Worker,
336 bases: S,
337 density_map: D,
338 exponents: Arc<Vec<<G::Scalar as PrimeField>::Repr>>,
339) -> Waiter<Result<<G as PrimeCurveAffine>::Curve, EcError>>
340where
341 for<'a> &'a Q: QueryDensity,
342 D: Send + Sync + 'static + Clone + AsRef<Q>,
343 G: PrimeCurveAffine,
344 S: SourceBuilder<G>,
345{
346 let c = if exponents.len() < 32 {
347 3u32
348 } else {
349 (f64::from(exponents.len() as u32)).ln().ceil() as u32
350 };
351
352 if let Some(query_size) = density_map.as_ref().get_query_size() {
353 assert!(query_size == exponents.len());
356 }
357
358 pool.compute(move || multiexp_inner(bases, density_map, exponents, c))
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    use blstrs::Bls12;
    use group::Curve;
    use pairing::Engine;
    use rand::Rng;
    use rand_core::SeedableRng;
    use rand_xorshift::XorShiftRng;

    /// Checks `multiexp_cpu` against a naive sum of scalar multiplications on BLS12-381 G1.
    #[test]
    fn test_with_bls12() {
        fn naive_multiexp<G: PrimeCurveAffine>(
            bases: Arc<Vec<G>>,
            exponents: &[G::Scalar],
        ) -> G::Curve {
            assert_eq!(bases.len(), exponents.len());

            let mut acc = G::Curve::identity();

            for (base, exp) in bases.iter().zip(exponents.iter()) {
                acc.add_assign(&base.mul(*exp));
            }

            acc
        }

        const SAMPLES: usize = 1 << 14;

        let rng = &mut rand::thread_rng();
        let v: Vec<<Bls12 as Engine>::Fr> = (0..SAMPLES)
            .map(|_| <Bls12 as Engine>::Fr::random(&mut *rng))
            .collect();
        let g = Arc::new(
            (0..SAMPLES)
                .map(|_| <Bls12 as Engine>::G1::random(&mut *rng).to_affine())
                .collect::<Vec<_>>(),
        );

        let now = std::time::Instant::now();
        let naive = naive_multiexp(g.clone(), &v);
        println!("Naive: {}", now.elapsed().as_millis());

        let now = std::time::Instant::now();
        let pool = Worker::new();

        let v = Arc::new(v.into_iter().map(|fr| fr.to_repr()).collect());
        let fast = multiexp_cpu(&pool, (g, 0), FullDensity, v).wait().unwrap();

        println!("Fast: {}", now.elapsed().as_millis());

        assert_eq!(naive, fast);
    }

    /// Concatenating partial trackers with `extend(_, false)` must equal one tracker
    /// built over all elements directly.
    #[test]
    fn test_extend_density_regular() {
        let mut rng = XorShiftRng::from_seed([
            0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
            0xbc, 0xe5,
        ]);

        for k in &[2, 4, 8] {
            for j in &[10, 20, 50] {
                let count: usize = k * j;

                let mut tracker_full = DensityTracker::new();
                let mut partial_trackers: Vec<DensityTracker> = Vec::with_capacity(count / k);
                for i in 0..count {
                    if i % k == 0 {
                        partial_trackers.push(DensityTracker::new());
                    }

                    let index: usize = i / k;
                    if rng.gen() {
                        tracker_full.add_element();
                        partial_trackers[index].add_element();
                    }

                    if !partial_trackers[index].bv.is_empty() {
                        let idx = rng.gen_range(0..partial_trackers[index].bv.len());
                        let offset: usize = partial_trackers
                            .iter()
                            .take(index)
                            .map(|t| t.bv.len())
                            .sum();
                        tracker_full.inc(offset + idx);
                        partial_trackers[index].inc(idx);
                    }
                }

                let mut tracker_combined = DensityTracker::new();
                for tracker in partial_trackers.into_iter() {
                    tracker_combined.extend(&tracker, false);
                }
                assert_eq!(tracker_combined, tracker_full);
            }
        }
    }

    /// `extend(_, true)` must merge element 0 of both trackers and count it only once.
    #[test]
    fn test_extend_density_input() {
        let mut rng = XorShiftRng::from_seed([
            0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
            0xbc, 0xe5,
        ]);
        let trials = 10;
        let max_bits = 10;
        let max_density = max_bits;

        let empty = DensityTracker::new;

        // Random tracker with at least one element whose bit 0 is unset.
        let unset = |rng: &mut XorShiftRng| {
            let mut dt = DensityTracker::new();
            dt.add_element();
            let n = rng.gen_range(1..max_bits);
            let target_density = rng.gen_range(0..max_density);
            for _ in 1..n {
                dt.add_element();
            }

            for _ in 0..target_density {
                if n > 1 {
                    let to_inc = rng.gen_range(1..n);
                    dt.inc(to_inc);
                }
            }
            assert!(!dt.bv[0]);
            assert_eq!(n, dt.bv.len());

            dt
        };

        // Random tracker whose bit 0 is set.
        let set = |rng: &mut XorShiftRng| {
            let mut dt = unset(rng);
            dt.inc(0);
            dt
        };

        for _ in 0..trials {
            {
                let (mut e1, e2) = (empty(), empty());
                e1.extend(&e2, true);
                assert_eq!(empty(), e1);
            }
            {
                let (mut e1, u1) = (empty(), unset(&mut rng));
                e1.extend(&u1, true);
                assert_eq!(u1, e1);
            }
            {
                let (mut e1, s1) = (empty(), set(&mut rng));
                e1.extend(&s1, true);
                assert_eq!(s1, e1);
            }
            {
                let (mut s1, e1) = (set(&mut rng), empty());
                let s2 = s1.clone();
                s1.extend(&e1, true);
                assert_eq!(s1, s2);
            }
            {
                let (mut u1, e1) = (unset(&mut rng), empty());
                let u2 = u1.clone();
                u1.extend(&e1, true);
                assert_eq!(u1, u2);
            }
            {
                let (mut u1, u2) = (unset(&mut rng), unset(&mut rng));
                let expected_total = u1.total_density + u2.total_density;
                u1.extend(&u2, true);
                assert_eq!(expected_total, u1.total_density);
                assert!(!u1.bv[0]);
            }
            {
                let (mut u1, s1) = (unset(&mut rng), set(&mut rng));
                let expected_total = u1.total_density + s1.total_density;
                u1.extend(&s1, true);
                assert_eq!(expected_total, u1.total_density);
                assert!(u1.bv[0]);
            }
            {
                let (mut s1, u1) = (set(&mut rng), unset(&mut rng));
                let expected_total = s1.total_density + u1.total_density;
                s1.extend(&u1, true);
                assert_eq!(expected_total, s1.total_density);
                assert!(s1.bv[0]);
            }
            {
                let (mut s1, s2) = (set(&mut rng), set(&mut rng));
                let expected_total = s1.total_density + s2.total_density - 1;
                s1.extend(&s2, true);
                assert_eq!(expected_total, s1.total_density);
                assert!(s1.bv[0]);
            }
        }
    }
}