use std::convert::TryInto;
use std::ops::AddAssign;
use ff::PrimeField;
use group::{Curve, Group, prime::PrimeCurveAffine};
use rayon::prelude::*;
/// Window-size constant exported by this module.
///
/// NOTE(review): nothing in this file reads it; presumably it is the default
/// `window_size` (in bits) callers pass to `precompute_fixed_window` — confirm
/// against the call sites.
pub const WINDOW_SIZE: usize = 8;
/// A list of scalar representations, either borrowed from memory or produced on demand.
///
/// `par_multiscalar` takes this type as its scalar input and `len` reports how many
/// scalars it describes.
pub enum ScalarList<'a, G: PrimeCurveAffine, F: Fn(usize) -> <G::Scalar as PrimeField>::Repr> {
/// Scalar representations that already live in a slice.
Slice(&'a [<G::Scalar as PrimeField>::Repr]),
/// A callback returning the scalar representation at a given index, together with
/// the number of scalars (the value `len` returns for this variant).
Getter(F, usize),
}
impl<G: PrimeCurveAffine, F: Fn(usize) -> <G::Scalar as PrimeField>::Repr> ScalarList<'_, G, F> {
    /// Returns the number of scalars in the list.
    ///
    /// For `Slice` this is the slice length; for `Getter` it is the count stored next
    /// to the callback.
    pub fn len(&self) -> usize {
        match self {
            Self::Slice(scalars) => scalars.len(),
            Self::Getter(_, count) => *count,
        }
    }

    /// Returns `true` when the list describes no scalars at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Trait-object form of a scalar getter: a `Sync + Send` callable that maps an index
/// to the scalar representation of `G`'s scalar field.
pub type Getter<G> =
dyn Fn(usize) -> <<G as PrimeCurveAffine>::Scalar as PrimeField>::Repr + Sync + Send;
/// Read access to precomputed per-point lookup tables used by `multiscalar`.
pub trait MultiscalarPrecomp<G: PrimeCurveAffine>: Send + Sync {
/// Width of one scalar window, in bits (`multiscalar` shifts limbs and doubles the
/// accumulator by this amount).
fn window_size(&self) -> usize;
/// Bit mask `multiscalar` applies to a shifted limb to isolate one window digit.
fn window_mask(&self) -> u64;
/// One table per point; `multiscalar` reads `tables()[m][digit - 1]` for the `m`-th
/// scalar whenever its current window digit is non-zero.
fn tables(&self) -> &[Vec<G>];
/// Returns a borrowed view of the tables starting at point index `idx`.
fn at_point(&self, idx: usize) -> MultiscalarPrecompRef<'_, G>;
}
/// Owned precomputation produced by `precompute_fixed_window`.
#[derive(Clone, Debug)]
pub struct MultiscalarPrecompOwned<G: PrimeCurveAffine> {
// Number of input points (`points.len()` at construction time).
num_points: usize,
// Window width in bits that the tables were built for.
window_size: usize,
// `(1 << window_size) - 1`, used to extract one window digit from a limb.
window_mask: u64,
// `(1 << window_size) - 1`: the number of multiples the builder aims to store per point.
table_entries: usize,
// One table per point; entry `j` (0-based) holds `(j + 1)` times that point.
tables: Vec<Vec<G>>,
}
impl<G: PrimeCurveAffine> PartialEq for MultiscalarPrecompOwned<G> {
    /// Two precomputations are equal when their scalar metadata matches and their
    /// per-point tables compare equal pairwise; the table comparison runs on rayon.
    fn eq(&self, other: &Self) -> bool {
        // Compare the cheap scalar fields first so a mismatch never starts the table scan.
        let same_parameters = self.num_points == other.num_points
            && self.window_size == other.window_size
            && self.window_mask == other.window_mask
            && self.table_entries == other.table_entries;
        if !same_parameters {
            return false;
        }

        self.tables
            .par_iter()
            .zip(other.tables.par_iter())
            .all(|(mine, theirs)| mine == theirs)
    }
}
impl<G: PrimeCurveAffine> MultiscalarPrecomp<G> for MultiscalarPrecompOwned<G> {
    /// Window width, in bits, that the tables were built for.
    fn window_size(&self) -> usize {
        self.window_size
    }

    /// Mask that isolates a single window digit.
    fn window_mask(&self) -> u64 {
        self.window_mask
    }

    /// All per-point tables, in point order.
    fn tables(&self) -> &[Vec<G>] {
        self.tables.as_slice()
    }

    /// Borrows the tables from point `idx` onwards.
    ///
    /// Panics when `idx` exceeds the number of stored tables.
    fn at_point(&self, idx: usize) -> MultiscalarPrecompRef<'_, G> {
        let num_points = self.num_points - idx;
        let tables = &self.tables[idx..];
        MultiscalarPrecompRef {
            num_points,
            window_size: self.window_size,
            window_mask: self.window_mask,
            table_entries: self.table_entries,
            tables,
        }
    }
}
/// Borrowed view over a suffix of a precomputation's tables, as returned by `at_point`.
#[derive(Debug)]
pub struct MultiscalarPrecompRef<'a, G: PrimeCurveAffine> {
// Number of points covered by this view (the parent's count minus the start index).
num_points: usize,
// Window width in bits, copied from the parent.
window_size: usize,
// Window digit mask, copied from the parent.
window_mask: u64,
// Table-entry count, copied from the parent.
table_entries: usize,
// Per-point tables borrowed from the parent, starting at the view's first point.
tables: &'a [Vec<G>],
}
impl<G: PrimeCurveAffine> MultiscalarPrecomp<G> for MultiscalarPrecompRef<'_, G> {
    /// Window width, in bits, inherited from the parent precomputation.
    fn window_size(&self) -> usize {
        self.window_size
    }

    /// Mask that isolates a single window digit.
    fn window_mask(&self) -> u64 {
        self.window_mask
    }

    /// The borrowed per-point tables covered by this view.
    fn tables(&self) -> &[Vec<G>] {
        self.tables
    }

    /// Narrows the view further so it starts `idx` points later.
    ///
    /// Panics when `idx` exceeds the number of tables in this view.
    fn at_point(&self, idx: usize) -> MultiscalarPrecompRef<'_, G> {
        let num_points = self.num_points - idx;
        let tables = &self.tables[idx..];
        MultiscalarPrecompRef {
            num_points,
            window_size: self.window_size,
            window_mask: self.window_mask,
            table_entries: self.table_entries,
            tables,
        }
    }
}
/// Builds the fixed-window lookup tables for `points`.
///
/// For every point `P` the resulting table holds `P, 2·P, …` up to
/// `((1 << window_size) - 1)·P` in affine form, so entry `j` (0-based) is `(j + 1)·P`.
/// Points are processed in parallel with rayon.
pub fn precompute_fixed_window<G: PrimeCurveAffine>(
    points: &[G],
    window_size: usize,
) -> MultiscalarPrecompOwned<G> {
    let table_entries = (1usize << window_size) - 1;
    let window_mask = (1u64 << window_size) - 1;

    let tables: Vec<Vec<G>> = points
        .par_iter()
        .map(|base| {
            let mut multiples = Vec::with_capacity(table_entries);
            // The first multiple is the point itself; no conversion needed.
            multiples.push(*base);
            // Keep a running projective sum and normalise each new multiple.
            let mut running = base.to_curve();
            multiples.extend((1..table_entries).map(|_| {
                running.add_assign(base);
                running.to_affine()
            }));
            multiples
        })
        .collect();

    MultiscalarPrecompOwned {
        num_points: points.len(),
        window_size,
        window_mask,
        table_entries,
        tables,
    }
}
/// Computes the multiscalar product `Σ k[m] · P_m` using precomputed fixed-window tables.
///
/// * `k` – little-endian scalar representations; `k[m]` is paired with
///   `precomp_table.tables()[m]`.
/// * `precomp_table` – tables whose entry `d - 1` for a point holds `d` times that point.
/// * `nbits` – number of low-order scalar bits to process.
///
/// Returns the identity when `k` is empty, even if the table set is empty as well.
///
/// # Panics
///
/// * with `"Unsupported multiscalar window size!"` when the window size is zero or does
///   not divide both `nbits` and 64;
/// * when `nbits` exceeds the bit length of the scalar representation;
/// * when a scalar with a non-zero window digit has no matching table.
pub fn multiscalar<G: PrimeCurveAffine>(
    k: &[<G::Scalar as ff::PrimeField>::Repr],
    precomp_table: &dyn MultiscalarPrecomp<G>,
    nbits: usize,
) -> G::Curve {
    // Scalars are consumed one 64-bit limb at a time, so the repr must be whole limbs.
    debug_assert_eq!(
        std::mem::size_of::<<G::Scalar as ff::PrimeField>::Repr>() % 8,
        0
    );
    const BITS_PER_LIMB: usize = std::mem::size_of::<u64>() * 8;

    // Loop-invariant table parameters: fetch them once instead of through the trait
    // object on every use.
    let window_size = precomp_table.window_size();
    let window_mask = precomp_table.window_mask();
    let tables = precomp_table.tables();

    if !nbits.is_multiple_of(window_size) || !BITS_PER_LIMB.is_multiple_of(window_size) {
        panic!("Unsupported multiscalar window size!");
    }
    // The check above rejects a zero window size, so these divisions are safe.
    let windows_per_limb = BITS_PER_LIMB / window_size;
    let num_windows = nbits.div_ceil(window_size);

    let mut result = G::Curve::identity();
    // Walk the windows from most to least significant.
    for window in (0..num_windows).rev() {
        let limb = (window * window_size) / BITS_PER_LIMB;
        let shift = (window % windows_per_limb) * window_size;

        for _ in 0..window_size {
            result = result.double();
        }

        // The addition of each table entry is delayed by one scalar so that the
        // prefetch issued for it has time to complete. `pending` is the entry selected
        // by the previous scalar, if its digit was non-zero.
        let mut pending: Option<&G> = None;
        for (m, scalar) in k.iter().enumerate() {
            let limb_bytes: [u8; 8] = scalar.as_ref()[limb * 8..(limb + 1) * 8]
                .try_into()
                .expect("an 8-byte slice always converts to [u8; 8]");
            let digit = (u64::from_le_bytes(limb_bytes) >> shift) & window_mask;

            let selected = if digit > 0 {
                // Entry `digit - 1` holds `digit` times the m-th point.
                let entry = &tables[m][digit as usize - 1];
                prefetch(entry);
                Some(entry)
            } else {
                None
            };

            if let Some(entry) = pending {
                result.add_assign(entry);
            }
            pending = selected;
        }
        // Flush the entry selected by the last scalar.
        if let Some(entry) = pending {
            result.add_assign(entry);
        }
    }
    result
}
/// Parallel multiscalar product: splits the scalars into chunks, runs `multiscalar`
/// on each chunk via rayon and sums the partial results.
///
/// * `points` – the scalars, either as a slice or as an index-based getter.
/// * `precomp_table` – precomputed tables; scalar `i` is paired with table `i`.
/// * `nbits` – number of low-order scalar bits to process, forwarded to `multiscalar`.
///
/// Returns the identity when `points` is empty. Panics under the same conditions as
/// `multiscalar`, and when `precomp_table` holds fewer tables than a chunk's start index.
pub fn par_multiscalar<F, G: PrimeCurveAffine>(
    points: &ScalarList<'_, G, F>,
    precomp_table: &dyn MultiscalarPrecomp<G>,
    nbits: usize,
) -> G::Curve
where
    F: Fn(usize) -> <G::Scalar as PrimeField>::Repr + Sync,
{
    let num_points = points.len();
    // Chunks of 256 for large inputs, 16 for medium ones; below 16 points every point
    // becomes its own task.
    let chunk_size = if num_points > 1024 {
        256
    } else if num_points >= 16 {
        16
    } else {
        1
    };
    let num_parts = num_points.div_ceil(chunk_size);

    (0..num_parts)
        .into_par_iter()
        .map(|id| {
            let start_idx = id * chunk_size;
            debug_assert!(start_idx < num_points);
            let end_idx = (start_idx + chunk_size).min(num_points);
            let subset = precomp_table.at_point(start_idx);

            // Only the getter variant needs owned storage, sized to exactly the scalars
            // of this chunk; slices are borrowed as they are.
            let fetched: Vec<<G::Scalar as PrimeField>::Repr>;
            let scalars = match points {
                ScalarList::Slice(s) => &s[start_idx..end_idx],
                ScalarList::Getter(getter, _) => {
                    fetched = (start_idx..end_idx).map(getter).collect();
                    fetched.as_slice()
                }
            };
            multiscalar::<G>(scalars, &subset, nbits)
        })
        .reduce(G::Curve::identity, |mut acc, part| {
            acc.add_assign(&part);
            acc
        })
}
/// Asks the CPU to pull the cache line containing `p` into cache (x86_64 `T0` hint).
#[cfg(target_arch = "x86_64")]
fn prefetch<T>(p: *const T) {
    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

    // SAFETY: `_mm_prefetch` only issues a cache hint for the address and writes nothing
    // through it; the caller in this file passes a pointer derived from a live reference.
    // NOTE(review): relies on the prefetch instruction never faulting — confirm against
    // the intrinsic's documentation if this is ever called with arbitrary pointers.
    unsafe {
        _mm_prefetch(p as *const i8, _MM_HINT_T0);
    }
}

/// Asks the CPU to pull the cache line containing `p` into cache (aarch64, nightly only).
#[cfg(all(nightly, target_arch = "aarch64"))]
fn prefetch<T>(p: *const T) {
    use std::arch::aarch64::{_prefetch, _PREFETCH_LOCALITY3, _PREFETCH_READ};

    // SAFETY: same reasoning as the x86_64 variant — the intrinsic is used purely as a
    // read-prefetch hint on a pointer derived from a live reference.
    unsafe {
        _prefetch(p as *const i8, _PREFETCH_READ, _PREFETCH_LOCALITY3);
    }
}

/// Fallback for targets without a prefetch intrinsic wired up here: does nothing.
#[cfg(not(any(target_arch = "x86_64", all(target_arch = "aarch64", nightly))))]
fn prefetch<T>(_: *const T) {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::Mul;

    use blstrs::{G1Affine, G1Projective, Scalar as Fr};
    use ff::Field;
    use rand_core::SeedableRng;
    use rand_xorshift::XorShiftRng;

    type FrRepr = <Fr as PrimeField>::Repr;

    /// Number of scalar bits handed to the multiscalar routines: the full repr width.
    const NBITS: usize = std::mem::size_of::<FrRepr>() * 8;
    /// How many random instances each `(num_inputs, window_size)` case is run with.
    const ROUNDS: usize = 50;
    /// `(num_inputs, window_size)` pairs exercised by every test.
    const CASES: [(usize, usize); 4] = [(8, 4), (12, 1), (10, 1), (20, 2)];

    /// Deterministic RNG shared by all tests so failures are reproducible.
    fn seeded_rng() -> XorShiftRng {
        XorShiftRng::from_seed([
            0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
            0xbc, 0xe5,
        ])
    }

    /// Draws `num_inputs` random points followed by `num_inputs` random scalars.
    fn random_instance(rng: &mut XorShiftRng, num_inputs: usize) -> (Vec<G1Affine>, Vec<FrRepr>) {
        let points: Vec<G1Affine> = (0..num_inputs)
            .map(|_| G1Projective::random(&mut *rng).to_affine())
            .collect();
        let scalars: Vec<FrRepr> = (0..num_inputs)
            .map(|_| Fr::random(&mut *rng).to_repr())
            .collect();
        (points, scalars)
    }

    /// Reference implementation: one full scalar multiplication per point.
    fn multiscalar_naive(points: &[G1Affine], scalars: &[FrRepr]) -> G1Projective {
        let mut acc = G1Projective::identity();
        for (scalar, point) in scalars.iter().zip(points.iter()) {
            let scalar = <Fr as PrimeField>::from_repr(*scalar).unwrap();
            acc.add_assign(&point.mul(scalar));
        }
        acc
    }

    #[test]
    fn test_multiscalar_single() {
        let mut rng = seeded_rng();
        for _ in 0..ROUNDS {
            for &(num_inputs, window_size) in &CASES {
                let (points, scalars) = random_instance(&mut rng, num_inputs);
                let table = precompute_fixed_window::<G1Affine>(&points, window_size);

                let naive_result = multiscalar_naive(&points, &scalars);
                let fast_result = multiscalar::<G1Affine>(&scalars, &table, NBITS);

                assert_eq!(naive_result, fast_result);
            }
        }
    }

    #[test]
    fn test_multiscalar_par() {
        let mut rng = seeded_rng();
        for _ in 0..ROUNDS {
            for &(num_inputs, window_size) in &CASES {
                let (points, scalars) = random_instance(&mut rng, num_inputs);
                let table = precompute_fixed_window::<G1Affine>(&points, window_size);

                let naive_result = multiscalar_naive(&points, &scalars);
                let fast_result = par_multiscalar::<&Getter<G1Affine>, G1Affine>(
                    &ScalarList::Slice(&scalars),
                    &table,
                    NBITS,
                );

                assert_eq!(naive_result, fast_result);
            }
        }
    }

    /// Same as `test_multiscalar_par`, but feeds the scalars through the `Getter`
    /// variant so the on-demand path (including the partial last chunk) is covered.
    #[test]
    fn test_multiscalar_par_getter() {
        let mut rng = seeded_rng();
        for _ in 0..ROUNDS {
            for &(num_inputs, window_size) in &CASES {
                let (points, scalars) = random_instance(&mut rng, num_inputs);
                let table = precompute_fixed_window::<G1Affine>(&points, window_size);

                let naive_result = multiscalar_naive(&points, &scalars);
                let by_index = |i: usize| scalars[i];
                let fast_result = par_multiscalar(
                    &ScalarList::<G1Affine, _>::Getter(by_index, scalars.len()),
                    &table,
                    NBITS,
                );

                assert_eq!(naive_result, fast_result);
            }
        }
    }
}