use super::multicore::{Waiter, Worker};
use bitvec::vec::BitVec;
use ff::{FieldBits, PrimeField, PrimeFieldBits};
use group::prime::{PrimeCurve, PrimeCurveAffine};
use std::io;
use std::iter;
use std::ops::AddAssign;
use std::sync::Arc;
#[cfg(feature = "multicore")]
use rayon::prelude::*;
#[cfg(not(feature = "multicore"))]
use crate::multicore::FakeParallelIterator;
use super::SynthesisError;
/// Produces a [`Source`] of bases.
///
/// Builders must be shareable across threads (`Send + Sync + 'static`) and
/// cloneable: every parallel multiexp task clones the builder and calls
/// [`SourceBuilder::build`] to obtain its own cursor.
pub trait SourceBuilder<G>: Clone + Send + Sync + 'static
where
    G: PrimeCurveAffine,
{
    /// The cursor type handed out by [`SourceBuilder::build`].
    type Source: Source<G>;

    /// Consumes the builder and returns the source it describes.
    fn build(self) -> Self::Source;
}
/// A cursor over a sequence of bases that can only move forward.
pub trait Source<G>
where
    G: PrimeCurveAffine,
{
    /// Returns the base at the cursor and moves past it.
    fn next(&mut self) -> Result<&G, SynthesisError>;

    /// Moves the cursor past `amt` bases without reading them.
    fn skip(&mut self, amt: usize) -> Result<(), SynthesisError>;
}
/// Curve points that can accumulate the next base read from a [`Source`].
pub trait AddAssignFromSource: PrimeCurve {
    /// Reads one base from `source` and adds it into `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Source::next`]; `self` is not
    /// modified in that case.
    fn add_assign_from_source<S>(&mut self, source: &mut S) -> Result<(), SynthesisError>
    where
        S: Source<<Self as PrimeCurve>::Affine>,
    {
        let base = source.next()?;
        AddAssign::<&<Self as PrimeCurve>::Affine>::add_assign(self, base);
        Ok(())
    }
}

// Every `PrimeCurve` type uses the default method above.
impl<G: PrimeCurve> AddAssignFromSource for G {}
/// A shared vector of bases paired with a cursor index is its own builder.
impl<G: PrimeCurveAffine> SourceBuilder<G> for (Arc<Vec<G>>, usize) {
    type Source = (Arc<Vec<G>>, usize);

    /// Returns the pair unchanged.
    ///
    /// The builder is consumed, so its `Arc` is moved into the source instead
    /// of being cloned (which would only bump and later drop a refcount).
    fn build(self) -> (Arc<Vec<G>>, usize) {
        self
    }
}
/// A shared vector of bases plus a cursor: `.1` is the index of the next base
/// that [`Source::next`] will return.
impl<G: PrimeCurveAffine> Source<G> for (Arc<Vec<G>>, usize) {
    /// Returns the base at the cursor and advances the cursor by one.
    ///
    /// # Errors
    ///
    /// * An `UnexpectedEof` I/O error if no base is left at the cursor.
    /// * [`SynthesisError::UnexpectedIdentity`] if the base at the cursor is the
    ///   identity point; the cursor is left where it was.
    fn next(&mut self) -> Result<&G, SynthesisError> {
        let base = match self.0.get(self.1) {
            Some(base) => base,
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "expected more bases from source",
                )
                .into());
            }
        };
        if base.is_identity().into() {
            return Err(SynthesisError::UnexpectedIdentity);
        }
        self.1 += 1;
        Ok(base)
    }

    /// Advances the cursor by `amt` bases without reading them.
    ///
    /// # Errors
    ///
    /// Returns an `UnexpectedEof` I/O error, leaving the cursor unchanged, if
    /// the source is already exhausted or fewer than `amt` bases remain.
    fn skip(&mut self, amt: usize) -> Result<(), SynthesisError> {
        // `saturating_sub` tolerates a cursor that was constructed past the end.
        let remaining = self.0.len().saturating_sub(self.1);
        // Refuse to move the cursor beyond the last base. This also rules out
        // overflowing `self.1` in the addition below.
        if remaining == 0 || amt > remaining {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "expected more bases from source",
            )
            .into());
        }
        self.1 += amt;
        Ok(())
    }
}
/// Describes which entries of a query are present ("dense").
pub trait QueryDensity {
    /// Iterator yielding one flag per entry, `true` meaning the entry is present.
    type Iter: Iterator<Item = bool>;

    /// Returns the per-entry presence flags, in query order.
    fn iter(self) -> Self::Iter;

    /// Returns the number of tracked entries, or `None` when the implementation
    /// does not track a size (as with [`FullDensity`]).
    fn get_query_size(self) -> Option<usize>;
}
/// Density marker meaning "every entry is present".
#[derive(Clone)]
pub struct FullDensity;

impl AsRef<FullDensity> for FullDensity {
    fn as_ref(&self) -> &Self {
        self
    }
}

impl QueryDensity for &FullDensity {
    type Iter = iter::Repeat<bool>;

    /// An endless stream of `true`; consumers bound it themselves, e.g. by
    /// zipping it with the exponents.
    fn iter(self) -> Self::Iter {
        iter::repeat(true)
    }

    /// A full density has no fixed size.
    fn get_query_size(self) -> Option<usize> {
        None
    }
}
/// Records which entries of a query are present, one bit per registered entry.
pub struct DensityTracker {
    bv: BitVec,
}

impl<'a> QueryDensity for &'a DensityTracker {
    type Iter = Box<dyn 'a + Iterator<Item = bool>>;

    /// Yields the presence bit of every registered entry, in registration order.
    fn iter(self) -> Self::Iter {
        Box::new(self.bv.iter().by_vals())
    }

    /// The number of registered entries.
    fn get_query_size(self) -> Option<usize> {
        Some(self.bv.len())
    }
}

impl Default for DensityTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl DensityTracker {
    /// Creates a tracker with no registered entries.
    pub fn new() -> DensityTracker {
        DensityTracker { bv: BitVec::new() }
    }

    /// Registers one more entry, initially marked absent.
    pub fn add_element(&mut self) {
        self.bv.push(false);
    }

    /// Marks entry `idx` as present.
    ///
    /// # Panics
    ///
    /// Panics if `idx` has not been registered with
    /// [`DensityTracker::add_element`].
    pub fn inc(&mut self, idx: usize) {
        if !self
            .bv
            .get(idx)
            .expect("DensityTracker::inc: index not registered via add_element")
        {
            self.bv.set(idx, true);
        }
    }

    /// Number of entries currently marked present.
    pub fn get_total_density(&self) -> usize {
        self.bv.count_ones()
    }
}
/// An exponent split into fixed-width windows, least-significant window first.
///
/// Zero and one keep their own variants, so the multiexp loop can skip or
/// directly add their bases without touching any window bucket.
enum ChunkedExponent {
    /// The exponent 0.
    Zero,
    /// The exponent 1.
    One,
    /// Window values, lowest window first; each holds at most `c` bits.
    Chunks(Vec<u64>),
}
/// A scalar prepared for multi-exponentiation.
///
/// The `From` conversions detect zero and one up front, so those scalars never
/// carry a bit decomposition.
pub enum Exponent<F: PrimeFieldBits> {
    /// The scalar is zero; its base contributes nothing.
    Zero,
    /// The scalar is one; its base is added as-is.
    One,
    /// Any other scalar, as its little-endian bit representation.
    Bits(FieldBits<F::ReprBits>),
}
impl<F: PrimeFieldBits> From<&F> for Exponent<F> {
    /// Classifies `exp`, keeping its bit decomposition only when it is neither
    /// zero nor one.
    fn from(exp: &F) -> Self {
        match exp {
            zero if zero.is_zero_vartime() => Self::Zero,
            one if *one == F::one() => Self::One,
            other => Self::Bits(other.to_le_bits()),
        }
    }
}

impl<F: PrimeFieldBits> From<F> for Exponent<F> {
    /// By-value convenience that delegates to the by-reference conversion.
    fn from(exp: F) -> Self {
        Self::from(&exp)
    }
}
impl<F: PrimeFieldBits> Exponent<F> {
    /// Splits the exponent into consecutive `c`-bit windows.
    ///
    /// Window `j` is the integer formed by bits `j*c .. (j+1)*c` of the
    /// little-endian bit representation; the last window may be shorter.
    /// Zero and one pass through as their dedicated variants.
    fn chunks(&self, c: usize) -> ChunkedExponent {
        let bits = match self {
            Self::Zero => return ChunkedExponent::Zero,
            Self::One => return ChunkedExponent::One,
            Self::Bits(bits) => bits,
        };
        let windows: Vec<u64> = bits
            .chunks(c)
            .map(|window| {
                let mut value = 0u64;
                for (shift, bit) in window.iter().by_vals().enumerate() {
                    value += u64::from(bit) << shift;
                }
                value
            })
            .collect();
        ChunkedExponent::Chunks(windows)
    }
}
/// Computes `sum_i exponents[i] * base_i` with a windowed bucket
/// (Pippenger-style) method, using windows of `c` bits.
///
/// Every window builds its own source from `bases`. A base is consumed only for
/// entries whose `density_map` flag is `true`, so the source is expected to hold
/// exactly the dense bases, in order.
///
/// # Errors
///
/// Returns an error produced by the bases source (for example, running out of
/// bases) if any window fails.
fn multiexp_inner<Q, D, G, S>(
bases: S,
density_map: D,
exponents: Arc<Vec<Exponent<G::Scalar>>>,
c: u32,
) -> Result<G, SynthesisError>
where
for<'a> &'a Q: QueryDensity,
D: Send + Sync + 'static + Clone + AsRef<Q>,
G: PrimeCurve,
G::Scalar: PrimeFieldBits,
S: SourceBuilder<<G as PrimeCurve>::Affine>,
{
// Computes the contribution of window number `chunk` (bits `chunk*c .. (chunk+1)*c`
// of every exponent), not yet shifted into its final bit position.
let this = move |bases: S,
density_map: D,
exponents: Arc<Vec<ChunkedExponent>>,
chunk: usize|
-> Result<_, SynthesisError> {
let mut acc = G::identity();
let mut bases = bases.build();
// buckets[k - 1] accumulates every base whose window value is k, for k in 1 .. 2^c - 1.
// Window value 0 contributes nothing and needs no bucket.
let mut buckets = vec![G::identity(); (1 << c) - 1];
// Exponent 1 only has a bit in the lowest window. Its base is therefore added directly
// in window 0 and merely skipped in every other window.
let handle_trivial = chunk == 0;
// With FullDensity the density iterator is endless; `zip` stops at the last exponent.
for (exp, density) in exponents.iter().zip(density_map.as_ref().iter()) {
if density {
match exp {
// Zero contributes nothing, but its base must still be consumed to keep the
// source aligned with the dense exponents.
ChunkedExponent::Zero => bases.skip(1)?,
ChunkedExponent::One => {
if handle_trivial {
acc.add_assign_from_source(&mut bases)?;
} else {
bases.skip(1)?;
}
}
ChunkedExponent::Chunks(chunks) => {
// NOTE(review): assumes every Chunks vector has at least as many windows as the
// parallel loop below creates (ceil(NUM_BITS / c)); indexing panics otherwise.
let exp = chunks[chunk];
if exp != 0 {
(&mut buckets[(exp - 1) as usize])
.add_assign_from_source(&mut bases)?;
} else {
bases.skip(1)?;
}
}
}
}
}
// Running-sum trick: walking from the highest bucket down, `running_sum` holds the sum
// of the current bucket and all higher ones. Adding it to `acc` once per step yields
// sum_k k * buckets[k - 1] without any scalar multiplication.
let mut running_sum = G::identity();
for exp in buckets.into_iter().rev() {
running_sum.add_assign(&exp);
acc.add_assign(&running_sum);
}
Ok(acc)
};
// Split every exponent into c-bit windows once; all window tasks share the result.
let exponents = Arc::new(
exponents
.iter()
.map(|exp| exp.chunks(c as usize))
.collect::<Vec<_>>(),
);
// One task per window: step_by(c) over 0..NUM_BITS yields ceil(NUM_BITS / c) items, and
// only the enumeration index is used as the window number. These run in parallel when the
// `multicore` feature enables rayon; otherwise `FakeParallelIterator` stands in
// (presumably sequential; confirm in crate::multicore).
let parts = (0..G::Scalar::NUM_BITS)
.into_par_iter()
.step_by(c as usize)
.enumerate()
.map(|(chunk, _)| this(bases.clone(), density_map.clone(), exponents.clone(), chunk))
.collect::<Vec<Result<_, _>>>();
// Horner-style combination from the most significant window down: shift the accumulator
// left by c bits (c doublings), then add the next window. The first error short-circuits.
parts
.into_iter()
.rev()
.try_fold(G::identity(), |acc, part| {
part.map(|part| (0..c).fold(acc, |acc, _| acc.double()) + part)
})
}
/// Computes `sum_i exponents[i] * base_i` on the worker `pool`.
///
/// Entries whose `density_map` flag is `false` are skipped. For those entries no
/// base is read, so `bases` must yield only the dense bases, in order.
///
/// The window width is 3 bits for fewer than 32 exponents and
/// `ceil(ln(exponents.len()))` bits otherwise.
///
/// # Panics
///
/// Panics if `density_map` reports a query size different from
/// `exponents.len()`.
///
/// # Errors
///
/// The returned [`Waiter`] resolves to any error produced by the bases source.
/// For the `(Arc<Vec<_>>, usize)` source, that means running out of bases or
/// hitting an identity point.
pub fn multiexp<Q, D, G, S>(
    pool: &Worker,
    bases: S,
    density_map: D,
    exponents: Arc<Vec<Exponent<G::Scalar>>>,
) -> Waiter<Result<G, SynthesisError>>
where
    for<'a> &'a Q: QueryDensity,
    D: Send + Sync + 'static + Clone + AsRef<Q>,
    G: PrimeCurve,
    G::Scalar: PrimeFieldBits,
    S: SourceBuilder<<G as PrimeCurve>::Affine>,
{
    let c = if exponents.len() < 32 {
        3u32
    } else {
        // Convert `usize -> f64` directly. Going through `u32` would silently
        // truncate lengths above `u32::MAX` and pick a far-too-small window.
        (exponents.len() as f64).ln().ceil() as u32
    };

    if let Some(query_size) = density_map.as_ref().get_query_size() {
        assert_eq!(
            query_size,
            exponents.len(),
            "density map size must match the number of exponents"
        );
    }

    pool.compute(move || multiexp_inner(bases, density_map, exponents, c))
}
#[cfg(feature = "pairing")]
#[test]
fn test_with_bls12() {
    use bls12_381::{Bls12, Scalar};
    use ff::Field;
    use group::{Curve, Group};
    use pairing::Engine;

    /// Reference result: adds up `base * exp` term by term, with no windowing.
    fn naive_multiexp<G: PrimeCurve>(
        bases: Arc<Vec<<G as PrimeCurve>::Affine>>,
        exponents: Arc<Vec<G::Scalar>>,
    ) -> G {
        assert_eq!(bases.len(), exponents.len());
        let mut sum = G::identity();
        for (point, scalar) in bases.iter().zip(exponents.iter()) {
            let term: G = *point * *scalar;
            AddAssign::<&G>::add_assign(&mut sum, &term);
        }
        sum
    }

    const SAMPLES: usize = 1 << 14;
    let mut rng = rand::thread_rng();

    // Random scalars first, then random points (same RNG draw order as before).
    let scalars = Arc::new(
        (0..SAMPLES)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>(),
    );
    let exponents: Arc<Vec<Exponent<Scalar>>> =
        Arc::new(scalars.iter().map(|s| s.into()).collect());
    let points = Arc::new(
        (0..SAMPLES)
            .map(|_| <Bls12 as Engine>::G1::random(&mut rng).to_affine())
            .collect::<Vec<_>>(),
    );

    let expected: <Bls12 as Engine>::G1 = naive_multiexp(points.clone(), scalars);

    let pool = Worker::new();
    let actual: <Bls12 as Engine>::G1 = multiexp(&pool, (points, 0), FullDensity, exponents)
        .wait()
        .unwrap();

    assert_eq!(expected, actual);
}