ec_gpu_gen/
multiexp_cpu.rs

1#![allow(missing_docs)]
2use std::convert::TryInto;
3use std::io;
4use std::iter;
5use std::ops::AddAssign;
6use std::sync::Arc;
7
8use bitvec::prelude::{BitVec, Lsb0};
9use ff::{Field, PrimeField};
10use group::{prime::PrimeCurveAffine, Group};
11use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
12
13use crate::error::EcError;
14use crate::threadpool::{Waiter, Worker};
15
/// An object that builds a source of bases.
///
/// `multiexp_inner` clones the builder once per window it evaluates and calls `new` on each
/// clone, so every window reads the bases through its own independent `Source`.
pub trait SourceBuilder<G: PrimeCurveAffine>: Send + Sync + 'static + Clone {
    /// The source type produced by `new`.
    type Source: Source<G>;

    /// Consumes the builder and produces a source to read bases from.
    #[allow(clippy::wrong_self_convention)]
    fn new(self) -> Self::Source;
    /// Consumes the builder and returns the shared bases together with an index into them
    /// (for the `(Arc<Vec<G>>, usize)` implementation, the position of the first base to read).
    fn get(self) -> (Arc<Vec<G>>, usize);
}
24
/// A source of bases, like an iterator.
pub trait Source<G: PrimeCurveAffine> {
    /// Parses the element from the source. Fails if the point is at infinity.
    ///
    /// The next base is added into `to` (an affine base added to a curve point) and the source
    /// advances past it. The `(Arc<Vec<G>>, usize)` implementation also fails with
    /// `UnexpectedEof` when no bases remain.
    fn add_assign_mixed(&mut self, to: &mut <G as PrimeCurveAffine>::Curve) -> Result<(), EcError>;

    /// Skips `amt` elements from the source, avoiding deserialization.
    ///
    /// The `(Arc<Vec<G>>, usize)` implementation fails with `UnexpectedEof` when the source is
    /// already exhausted.
    fn skip(&mut self, amt: usize) -> Result<(), EcError>;
}
33
34impl<G: PrimeCurveAffine> SourceBuilder<G> for (Arc<Vec<G>>, usize) {
35    type Source = (Arc<Vec<G>>, usize);
36
37    fn new(self) -> (Arc<Vec<G>>, usize) {
38        (self.0.clone(), self.1)
39    }
40
41    fn get(self) -> (Arc<Vec<G>>, usize) {
42        (self.0.clone(), self.1)
43    }
44}
45
46impl<G: PrimeCurveAffine> Source<G> for (Arc<Vec<G>>, usize) {
47    fn add_assign_mixed(&mut self, to: &mut <G as PrimeCurveAffine>::Curve) -> Result<(), EcError> {
48        if self.0.len() <= self.1 {
49            return Err(io::Error::new(
50                io::ErrorKind::UnexpectedEof,
51                "Expected more bases from source.",
52            )
53            .into());
54        }
55
56        if self.0[self.1].is_identity().into() {
57            return Err(EcError::Simple(
58                "Encountered an identity element in the CRS.",
59            ));
60        }
61
62        to.add_assign(&self.0[self.1]);
63
64        self.1 += 1;
65
66        Ok(())
67    }
68
69    fn skip(&mut self, amt: usize) -> Result<(), EcError> {
70        if self.0.len() <= self.1 {
71            return Err(io::Error::new(
72                io::ErrorKind::UnexpectedEof,
73                "Expected more bases from source.",
74            )
75            .into());
76        }
77
78        self.1 += amt;
79
80        Ok(())
81    }
82}
83
/// Describes which positions of a multiexp query have a base.
///
/// `multiexp_cpu` asserts that a known query size (see `get_query_size`) equals the number of
/// exponents.
pub trait QueryDensity: Sized {
    /// Iterator over the positions of the query, yielding whether the base at each position
    /// exists.
    type Iter: Iterator<Item = bool>;

    /// Returns an iterator over the per-position existence flags.
    fn iter(self) -> Self::Iter;
    /// Returns the number of tracked positions, or `None` when the density has no fixed size
    /// (as with `FullDensity`, whose iterator never ends).
    fn get_query_size(self) -> Option<usize>;
    /// Returns only those exponents whose position is marked as present.
    fn generate_exps<F: PrimeField>(self, exponents: Arc<Vec<F::Repr>>) -> Arc<Vec<F::Repr>>;
}
92
93#[derive(Clone)]
94pub struct FullDensity;
95
96impl AsRef<FullDensity> for FullDensity {
97    fn as_ref(&self) -> &FullDensity {
98        self
99    }
100}
101
102impl<'a> QueryDensity for &'a FullDensity {
103    type Iter = iter::Repeat<bool>;
104
105    fn iter(self) -> Self::Iter {
106        iter::repeat(true)
107    }
108
109    fn get_query_size(self) -> Option<usize> {
110        None
111    }
112
113    fn generate_exps<F: PrimeField>(self, exponents: Arc<Vec<F::Repr>>) -> Arc<Vec<F::Repr>> {
114        exponents
115    }
116}
117
/// Tracks which positions of a query are in use ("dense").
#[derive(Clone, PartialEq, Eq, Debug, Default)]
pub struct DensityTracker {
    /// One bit per position; a set bit marks that position as dense.
    pub bv: BitVec,
    /// Number of set bits in `bv`, kept in sync by `inc` and `extend`.
    pub total_density: usize,
}
123
124impl<'a> QueryDensity for &'a DensityTracker {
125    type Iter = bitvec::slice::BitValIter<'a, usize, Lsb0>;
126
127    fn iter(self) -> Self::Iter {
128        self.bv.iter().by_vals()
129    }
130
131    fn get_query_size(self) -> Option<usize> {
132        Some(self.bv.len())
133    }
134
135    fn generate_exps<F: PrimeField>(self, exponents: Arc<Vec<F::Repr>>) -> Arc<Vec<F::Repr>> {
136        let exps: Vec<_> = exponents
137            .iter()
138            .zip(self.bv.iter())
139            .filter_map(|(&e, d)| if *d { Some(e) } else { None })
140            .collect();
141
142        Arc::new(exps)
143    }
144}
145
146impl DensityTracker {
147    pub fn new() -> DensityTracker {
148        DensityTracker {
149            bv: BitVec::new(),
150            total_density: 0,
151        }
152    }
153
154    pub fn add_element(&mut self) {
155        self.bv.push(false);
156    }
157
158    pub fn inc(&mut self, idx: usize) {
159        if !self.bv.get(idx).unwrap() {
160            self.bv.set(idx, true);
161            self.total_density += 1;
162        }
163    }
164
165    pub fn get_total_density(&self) -> usize {
166        self.total_density
167    }
168
169    /// Extend by concatenating `other`. If `is_input_density` is true, then we are tracking an input density,
170    /// and other may contain a redundant input for the `One` element. Coalesce those as needed and track the result.
171    pub fn extend(&mut self, other: &Self, is_input_density: bool) {
172        if other.bv.is_empty() {
173            // Nothing to do if other is empty.
174            return;
175        }
176
177        if self.bv.is_empty() {
178            // If self is empty, assume other's density.
179            self.total_density = other.total_density;
180            self.bv.resize(other.bv.len(), false);
181            self.bv.copy_from_bitslice(&*other.bv);
182            return;
183        }
184
185        if is_input_density {
186            // Input densities need special handling to coalesce their first inputs.
187
188            if other.bv[0] {
189                // If other's first bit is set,
190                if self.bv[0] {
191                    // And own first bit is set, then decrement total density so the final sum doesn't overcount.
192                    self.total_density -= 1;
193                } else {
194                    // Otherwise, set own first bit.
195                    self.bv.set(0, true);
196                }
197            }
198            // Now discard other's first bit, having accounted for it above, and extend self by remaining bits.
199            self.bv.extend(other.bv.iter().skip(1));
200        } else {
201            // Not an input density, just extend straightforwardly.
202            self.bv.extend(other.bv.iter());
203        }
204
205        // Since any needed adjustments to total densities have been made, just sum the totals and keep the sum.
206        self.total_density += other.total_density;
207    }
208}
209
/// Right shift the repr of a field element by `n` bits.
///
/// `le_bytes` is treated as one little-endian integer and shifted in place. Zeros are shifted
/// in at the most significant end. Shifting by the full width or more clears every byte.
fn shr(le_bytes: &mut [u8], n: u32) {
    let width_bits = 8 * le_bytes.len() as u32;
    if n >= width_bits {
        le_bytes.fill(0);
        return;
    }

    let byte_shift = (n / 8) as usize;
    let bit_shift = n % 8;

    // Whole bytes: move every byte `byte_shift` places towards index 0 (the least significant
    // end) and zero the vacated most significant bytes.
    if byte_shift > 0 {
        let len = le_bytes.len();
        le_bytes.copy_within(byte_shift.., 0);
        le_bytes[len - byte_shift..].fill(0);
    }

    // Remaining bits: walk from the most significant byte down, carrying each byte's low
    // `bit_shift` bits into the top of the next less significant byte.
    if bit_shift != 0 {
        let mut carry = 0u8;
        for byte in le_bytes.iter_mut().rev() {
            let low_bits = *byte << (8 - bit_shift);
            *byte = (*byte >> bit_shift) | carry;
            carry = low_bits;
        }
    }
}
241
/// Computes `sum_i exponents[i] * base_i` over the positions that `density_map` marks as
/// present, using a windowed bucket method with `c`-bit windows.
///
/// Bases are read from `bases` only at positions where `density_map` yields `true`, in order,
/// so the source must provide a base for every present position (extra bases are never read).
///
/// The exponent bits `[0, NUM_BITS)` are split into windows starting every `c` bits. Each
/// window is evaluated independently, in parallel via rayon, with its own source built from a
/// clone of `bases`. The window results are then recombined by repeated doubling.
///
/// # Errors
///
/// Every window is evaluated. The first error encountered when combining (most significant
/// window first) is returned, e.g. a source running out of bases or an identity base.
///
/// # Panics
///
/// Panics if an exponent repr is shorter than 8 bytes (the `[..8]` slice below).
/// NOTE(review): `c` is also assumed to be small enough (< 64) for `1 << c` as a `u64`.
fn multiexp_inner<Q, D, G, S>(
    bases: S,
    density_map: D,
    exponents: Arc<Vec<<G::Scalar as PrimeField>::Repr>>,
    c: u32,
) -> Result<<G as PrimeCurveAffine>::Curve, EcError>
where
    for<'a> &'a Q: QueryDensity,
    D: Send + Sync + 'static + Clone + AsRef<Q>,
    G: PrimeCurveAffine,
    S: SourceBuilder<G>,
{
    // Perform this region of the multiexp
    // (the `c`-bit window starting at bit `skip` of every exponent).
    let this = move |bases: S,
                     density_map: D,
                     exponents: Arc<Vec<<G::Scalar as PrimeField>::Repr>>,
                     skip: u32|
          -> Result<_, EcError> {
        // Accumulate the result
        let mut acc = G::Curve::identity();

        // Build a source for the bases
        let mut bases = bases.new();

        // Create space for the buckets
        // (bucket `i` collects the bases whose window value is `i + 1`; window value 0
        // contributes nothing, so it has no bucket).
        let mut buckets = vec![<G as PrimeCurveAffine>::Curve::identity(); (1 << c) - 1];

        let zero = G::Scalar::ZERO.to_repr();
        let one = G::Scalar::ONE.to_repr();

        // only the first round uses this
        // (an exponent of one has no bits above bit 0, so its base is added directly to `acc`
        // in the lowest window and skipped in every other window).
        let handle_trivial = skip == 0;

        // Sort the bases into buckets
        // Every present position consumes exactly one base (added or skipped) to keep the
        // source aligned with the exponents.
        for (&exp, density) in exponents.iter().zip(density_map.as_ref().iter()) {
            if density {
                if exp.as_ref() == zero.as_ref() {
                    bases.skip(1)?;
                } else if exp.as_ref() == one.as_ref() {
                    if handle_trivial {
                        bases.add_assign_mixed(&mut acc)?;
                    } else {
                        bases.skip(1)?;
                    }
                } else {
                    // Extract this window's `c` bits: shift the window down to bit 0, read the
                    // low 64 bits and mask to `c` bits.
                    let mut exp = exp;
                    shr(exp.as_mut(), skip);
                    let exp = u64::from_le_bytes(exp.as_ref()[..8].try_into().unwrap()) % (1 << c);

                    if exp != 0 {
                        bases.add_assign_mixed(&mut buckets[(exp - 1) as usize])?;
                    } else {
                        bases.skip(1)?;
                    }
                }
            }
        }

        // Summation by parts
        // e.g. 3a + 2b + 1c = a +
        //                    (a) + b +
        //                    ((a) + b) + c
        let mut running_sum = G::Curve::identity();
        for exp in buckets.into_iter().rev() {
            running_sum.add_assign(&exp);
            acc.add_assign(&running_sum);
        }

        Ok(acc)
    };

    // Evaluate the windows starting at bits 0, c, 2c, ... below NUM_BITS in parallel.
    let parts = (0..<G::Scalar as PrimeField>::NUM_BITS)
        .into_par_iter()
        .step_by(c as usize)
        .map(|skip| this(bases.clone(), density_map.clone(), exponents.clone(), skip))
        .collect::<Vec<Result<_, _>>>();

    // Combine from the most significant window down: shift the accumulator up by `c` bits
    // (`c` doublings), then add the next window's result.
    parts.into_iter().rev().try_fold(
        <G as PrimeCurveAffine>::Curve::identity(),
        |mut acc, part| {
            for _ in 0..c {
                acc = acc.double();
            }

            acc.add_assign(&part?);
            Ok(acc)
        },
    )
}
331
332/// Perform multi-exponentiation. The caller is responsible for ensuring the
333/// query size is the same as the number of exponents.
334pub fn multiexp_cpu<'b, Q, D, G, S>(
335    pool: &Worker,
336    bases: S,
337    density_map: D,
338    exponents: Arc<Vec<<G::Scalar as PrimeField>::Repr>>,
339) -> Waiter<Result<<G as PrimeCurveAffine>::Curve, EcError>>
340where
341    for<'a> &'a Q: QueryDensity,
342    D: Send + Sync + 'static + Clone + AsRef<Q>,
343    G: PrimeCurveAffine,
344    S: SourceBuilder<G>,
345{
346    let c = if exponents.len() < 32 {
347        3u32
348    } else {
349        (f64::from(exponents.len() as u32)).ln().ceil() as u32
350    };
351
352    if let Some(query_size) = density_map.as_ref().get_query_size() {
353        // If the density map has a known query size, it should not be
354        // inconsistent with the number of exponents.
355        assert!(query_size == exponents.len());
356    }
357
358    pool.compute(move || multiexp_inner(bases, density_map, exponents, c))
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    use blstrs::Bls12;
    use group::Curve;
    use pairing::Engine;
    use rand::Rng;
    use rand_core::SeedableRng;
    use rand_xorshift::XorShiftRng;

    /// Fixed RNG seed shared by all tests so that any failure can be reproduced.
    const SEED: [u8; 16] = [
        0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc,
        0xe5,
    ];

    /// Checks `multiexp_cpu` against a naive sum of scalar multiplications on BLS12 G1.
    #[test]
    fn test_with_bls12() {
        /// Reference implementation: one full scalar multiplication per base.
        fn naive_multiexp<G: PrimeCurveAffine>(
            bases: Arc<Vec<G>>,
            exponents: &[G::Scalar],
        ) -> G::Curve {
            assert_eq!(bases.len(), exponents.len());

            let mut acc = G::Curve::identity();

            for (base, exp) in bases.iter().zip(exponents.iter()) {
                acc.add_assign(&base.mul(*exp));
            }

            acc
        }

        const SAMPLES: usize = 1 << 14;

        // A seeded RNG (instead of `thread_rng`) makes a failing input reproducible.
        let mut rng = XorShiftRng::from_seed(SEED);
        let v: Vec<<Bls12 as Engine>::Fr> = (0..SAMPLES)
            .map(|_| <Bls12 as Engine>::Fr::random(&mut rng))
            .collect();
        let g = Arc::new(
            (0..SAMPLES)
                .map(|_| <Bls12 as Engine>::G1::random(&mut rng).to_affine())
                .collect::<Vec<_>>(),
        );

        let now = std::time::Instant::now();
        let naive = naive_multiexp(g.clone(), &v);
        println!("Naive: {}", now.elapsed().as_millis());

        let now = std::time::Instant::now();
        let pool = Worker::new();

        let v = Arc::new(v.into_iter().map(|fr| fr.to_repr()).collect());
        let fast = multiexp_cpu(&pool, (g, 0), FullDensity, v).wait().unwrap();

        println!("Fast: {}", now.elapsed().as_millis());

        assert_eq!(naive, fast);
    }

    /// Concatenating partial trackers with `extend(.., false)` must equal tracking everything
    /// in one tracker.
    #[test]
    fn test_extend_density_regular() {
        let mut rng = XorShiftRng::from_seed(SEED);

        for k in &[2, 4, 8] {
            for j in &[10, 20, 50] {
                let count: usize = k * j;

                let mut tracker_full = DensityTracker::new();
                let mut partial_trackers: Vec<DensityTracker> = Vec::with_capacity(count / k);
                for i in 0..count {
                    if i % k == 0 {
                        partial_trackers.push(DensityTracker::new());
                    }

                    let index: usize = i / k;
                    if rng.gen() {
                        tracker_full.add_element();
                        partial_trackers[index].add_element();
                    }

                    if !partial_trackers[index].bv.is_empty() {
                        let idx = rng.gen_range(0..partial_trackers[index].bv.len());
                        let offset: usize = partial_trackers
                            .iter()
                            .take(index)
                            .map(|t| t.bv.len())
                            .sum();
                        tracker_full.inc(offset + idx);
                        partial_trackers[index].inc(idx);
                    }
                }

                let mut tracker_combined = DensityTracker::new();
                for tracker in partial_trackers.into_iter() {
                    tracker_combined.extend(&tracker, false);
                }
                assert_eq!(tracker_combined, tracker_full);
            }
        }
    }

    /// `extend(.., true)` must coalesce the shared first (`One`) input bit and keep the total
    /// density consistent for every combination of empty / first-bit-unset / first-bit-set.
    #[test]
    fn test_extend_density_input() {
        let mut rng = XorShiftRng::from_seed(SEED);
        let trials = 10;
        let max_bits = 10;
        let max_density = max_bits;

        // Create an empty DensityTracker.
        let empty = DensityTracker::new;

        // Create a random DensityTracker with first bit unset.
        let unset = |rng: &mut XorShiftRng| {
            let mut dt = DensityTracker::new();
            dt.add_element();
            let n = rng.gen_range(1..max_bits);
            let target_density = rng.gen_range(0..max_density);
            for _ in 1..n {
                dt.add_element();
            }

            for _ in 0..target_density {
                if n > 1 {
                    let to_inc = rng.gen_range(1..n);
                    dt.inc(to_inc);
                }
            }
            assert!(!dt.bv[0]);
            assert_eq!(n, dt.bv.len());

            dt
        };

        // Create a random DensityTracker with first bit set.
        let set = |rng: &mut XorShiftRng| {
            let mut dt = unset(rng);
            dt.inc(0);
            dt
        };

        for _ in 0..trials {
            {
                // Both empty.
                let (mut e1, e2) = (empty(), empty());
                e1.extend(&e2, true);
                assert_eq!(empty(), e1);
            }
            {
                // First empty, second unset.
                let (mut e1, u1) = (empty(), unset(&mut rng));
                e1.extend(&u1.clone(), true);
                assert_eq!(u1, e1);
            }
            {
                // First empty, second set.
                let (mut e1, s1) = (empty(), set(&mut rng));
                e1.extend(&s1.clone(), true);
                assert_eq!(s1, e1);
            }
            {
                // First set, second empty.
                let (mut s1, e1) = (set(&mut rng), empty());
                let s2 = s1.clone();
                s1.extend(&e1, true);
                assert_eq!(s1, s2);
            }
            {
                // First unset, second empty.
                let (mut u1, e1) = (unset(&mut rng), empty());
                let u2 = u1.clone();
                u1.extend(&e1, true);
                assert_eq!(u1, u2);
            }
            {
                // First unset, second unset.
                let (mut u1, u2) = (unset(&mut rng), unset(&mut rng));
                let expected_total = u1.total_density + u2.total_density;
                u1.extend(&u2, true);
                assert_eq!(expected_total, u1.total_density);
                assert!(!u1.bv[0]);
            }
            {
                // First unset, second set.
                let (mut u1, s1) = (unset(&mut rng), set(&mut rng));
                let expected_total = u1.total_density + s1.total_density;
                u1.extend(&s1, true);
                assert_eq!(expected_total, u1.total_density);
                assert!(u1.bv[0]);
            }
            {
                // First set, second unset.
                let (mut s1, u1) = (set(&mut rng), unset(&mut rng));
                let expected_total = s1.total_density + u1.total_density;
                s1.extend(&u1, true);
                assert_eq!(expected_total, s1.total_density);
                assert!(s1.bv[0]);
            }
            {
                // First set, second set: the shared first bit is counted only once.
                let (mut s1, s2) = (set(&mut rng), set(&mut rng));
                let expected_total = s1.total_density + s2.total_density - 1;
                s1.extend(&s2, true);
                assert_eq!(expected_total, s1.total_density);
                assert!(s1.bv[0]);
            }
        }
    }
}