use std::{convert::TryInto, ops::Neg};
use ff::{Field, PrimeField};
use group::Group;
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
};
use crate::CurveAffine;
/// Number of slots in `Schedule::set`; `Schedule::add` flushes the pending
/// additions through `batch_add` as soon as this many are queued.
const BATCH_SIZE: usize = 64;
/// Returns the signed Booth digit of window `window_index` for the
/// little-endian byte string `el`.
///
/// Each window is `window_size` bits wide and is read together with the bit
/// just below it (`window_size + 1` bits in total); for the lowest window a
/// zero bit is appended below the least significant bit instead. The result
/// lies in `-(2^(window_size - 1))..=2^(window_size - 1)`, and summing
/// `digit_i * 2^(i * window_size)` over all windows reconstructs the value
/// encoded by `el`.
///
/// Bytes past the end of `el` are treated as zero.
fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // The slice starts one bit below the window itself (except for window 0,
    // where `saturating_sub` clamps the offset at zero).
    let bit_offset = (window_index * window_size).saturating_sub(1);
    let byte_offset = bit_offset / 8;

    // Load up to four bytes starting at `byte_offset` into a `u32`.
    let mut raw = [0u8; 4];
    for (src, dst) in el.iter().skip(byte_offset).zip(raw.iter_mut()) {
        *dst = *src;
    }
    let loaded = u32::from_le_bytes(raw);

    // The lowest window has no real bit below it: pad with a zero bit.
    let aligned = if window_index == 0 { loaded << 1 } else { loaded };

    // Drop the bits below the slice, then keep `window_size + 1` bits.
    let shifted = aligned >> (bit_offset - byte_offset * 8);
    let slice = shifted & ((1 << (window_size + 1)) - 1);

    // The top bit of the slice decides the sign of the digit.
    let top_bit: u32 = 1 << window_size;
    let is_non_negative = slice & top_bit == 0;

    // Ceiling division by two.
    let half = (slice + 1) >> 1;

    if is_non_negative {
        half as i32
    } else {
        let magnitude = !(half - 1) & (top_bit - 1);
        (magnitude as i32).neg()
    }
}
/// Applies the first `size` entries of `points` to `buckets` using affine
/// formulas that share a single field inversion (Montgomery's batch-inversion
/// trick).
///
/// For every entry, `bases[base_idx]` is added to `buckets[buck_idx]` when
/// `sign` is `true` and subtracted when it is `false`.
///
/// * Entries whose bucket is empty when visited are skipped.
/// * When the bucket and the (signed) base are the same point, the doubling
///   formula is used.
/// * When they are opposite points, the bucket is emptied.
///
/// # Panics
///
/// Panics if the accumulated product of slope denominators is not invertible.
fn batch_add<C: CurveAffine>(
    size: usize,
    buckets: &mut [BucketAffine<C>],
    points: &[SchedulePoint],
    bases: &[Affine<C>],
) {
    // Per entry: slope numerator pre-multiplied by the running product of all
    // earlier denominators, and the entry's own slope denominator.
    let mut numerators = vec![C::Base::ZERO; size];
    let mut denominators = vec![C::Base::ZERO; size];
    let mut acc = C::Base::ONE;

    // Forward pass: collect numerators/denominators and their running product.
    for ((point, num), den) in points
        .iter()
        .zip(numerators.iter_mut())
        .zip(denominators.iter_mut())
    {
        let bucket = &mut buckets[point.buck_idx];
        if bucket.is_inf() {
            continue;
        }
        let base = &bases[point.base_idx];
        let bucket_x = bucket.x();
        let bucket_y = bucket.y();

        if bucket_x == base.x {
            // Same x-coordinate: the operands are either equal or opposite.
            let is_doubling = (bucket_y == base.y) ^ !point.sign;
            if !is_doubling {
                // P + (-P)
                bucket.set_inf();
                continue;
            }
            // Tangent slope: 3x^2 / 2y.
            let x_squared = base.x.square();
            *den = bucket_y + bucket_y;
            *num = acc * (x_squared + x_squared + x_squared);
            acc *= *den;
            continue;
        }

        // Chord slope: (y_bucket -/+ y_base) / (x_bucket - x_base).
        *den = bucket_x - base.x;
        let y_diff = if point.sign {
            bucket_y - base.y
        } else {
            bucket_y + base.y
        };
        *num = acc * y_diff;
        acc *= *den;
    }

    acc = acc
        .invert()
        .expect("Some edge case has not been handled properly");

    // Backward pass: peel one denominator off the inverse per entry to recover
    // each slope, then write the resulting point into the bucket.
    for ((point, num), den) in points
        .iter()
        .zip(numerators.iter())
        .zip(denominators.iter())
        .rev()
    {
        let bucket = &mut buckets[point.buck_idx];
        if bucket.is_inf() {
            continue;
        }
        let base = &bases[point.base_idx];

        let lambda = acc * num;
        acc *= den;

        let x_result = lambda.square() - (bucket.x() + base.x);
        let chord = lambda * (base.x - x_result);
        let y_result = if point.sign {
            chord - base.y
        } else {
            chord + base.y
        };
        bucket.set_y(&y_result);
        bucket.set_x(&x_result);
    }
}
/// Bare `(x, y)` coordinates of a curve point.
///
/// Values are created from a `C` with `Affine::from` and converted back with
/// `Affine::eval`.
#[derive(Debug, Clone, Copy)]
struct Affine<C: CurveAffine> {
    x: C::Base,
    y: C::Base,
}

impl<C: CurveAffine> Affine<C> {
    /// Copies the coordinates out of `point`.
    ///
    /// # Panics
    ///
    /// Panics if `point.coordinates()` yields no value (presumably the point
    /// at infinity; confirm against the `CurveAffine` implementation).
    fn from(point: &C) -> Self {
        let xy = point.coordinates().unwrap();
        let (x, y) = (*xy.x(), *xy.y());
        Self { x, y }
    }

    /// Returns the point with the same `x` and negated `y`.
    fn neg(&self) -> Self {
        let Self { x, y } = self;
        Self { x: *x, y: -*y }
    }

    /// Rebuilds a `C` from the stored coordinates.
    ///
    /// # Panics
    ///
    /// Panics if `C::from_xy` rejects the coordinates.
    fn eval(&self) -> C {
        let rebuilt = C::from_xy(self.x, self.y);
        rebuilt.unwrap()
    }
}
/// A bucket kept in affine coordinates: either empty or holding one point.
/// The `BucketAffine` methods treat the empty state as the point at infinity
/// (`is_inf` returns `true` for it).
#[derive(Debug, Clone)]
enum BucketAffine<C: CurveAffine> {
/// Empty bucket.
None,
/// Bucket currently holding an affine point.
Point(Affine<C>),
}
/// A bucket accumulated in `C::Curve` coordinates: either empty or holding a
/// running sum.
#[derive(Debug, Clone)]
enum Bucket<C: CurveAffine> {
    None,
    Point(C::Curve),
}

impl<C: CurveAffine> Bucket<C> {
    /// Adds `point` to the bucket when `sign` is `true`, subtracts it when
    /// `sign` is `false`. An empty bucket becomes `point` (or its negation).
    fn add_assign(&mut self, point: &C, sign: bool) {
        let updated = match (&*self, sign) {
            (Self::None, true) => point.to_curve(),
            (Self::None, false) => point.to_curve().neg(),
            (Self::Point(sum), true) => *sum + point,
            (Self::Point(sum), false) => *sum - point,
        };
        *self = Self::Point(updated);
    }

    /// Returns the sum of this bucket and the affine bucket `other`, treating
    /// an empty bucket on either side as the identity.
    fn add(&self, other: &BucketAffine<C>) -> C::Curve {
        // Only evaluate the affine side when it actually holds a point.
        let affine_part = match other {
            BucketAffine::Point(coords) => Some(coords.eval()),
            BucketAffine::None => None,
        };
        match (self, affine_part) {
            (Self::Point(sum), Some(point)) => *sum + point,
            (Self::Point(sum), None) => *sum,
            (Self::None, Some(point)) => point.to_curve(),
            (Self::None, None) => C::Curve::identity(),
        }
    }
}
impl<C: CurveAffine> BucketAffine<C> {
    /// Fills an empty bucket with `point` (negated when `sign` is `false`)
    /// and returns `true`. Leaves an occupied bucket untouched and returns
    /// `false`.
    fn assign(&mut self, point: &Affine<C>, sign: bool) -> bool {
        if !self.is_inf() {
            return false;
        }
        let signed = if sign { *point } else { point.neg() };
        *self = Self::Point(signed);
        true
    }

    /// Returns the x-coordinate of the stored point.
    ///
    /// # Panics
    ///
    /// Panics if the bucket is empty.
    fn x(&self) -> C::Base {
        let Self::Point(coords) = self else {
            panic!("::x None")
        };
        coords.x
    }

    /// Returns the y-coordinate of the stored point.
    ///
    /// # Panics
    ///
    /// Panics if the bucket is empty.
    fn y(&self) -> C::Base {
        let Self::Point(coords) = self else {
            panic!("::y None")
        };
        coords.y
    }

    /// Returns `true` when the bucket is empty.
    fn is_inf(&self) -> bool {
        matches!(self, Self::None)
    }

    /// Overwrites the x-coordinate of the stored point.
    ///
    /// # Panics
    ///
    /// Panics if the bucket is empty.
    fn set_x(&mut self, x: &C::Base) {
        let Self::Point(coords) = self else {
            panic!("::set_x None")
        };
        coords.x = *x;
    }

    /// Overwrites the y-coordinate of the stored point.
    ///
    /// # Panics
    ///
    /// Panics if the bucket is empty.
    fn set_y(&mut self, y: &C::Base) {
        let Self::Point(coords) = self else {
            panic!("::set_y None")
        };
        coords.y = *y;
    }

    /// Empties the bucket; a no-op when it is already empty.
    fn set_inf(&mut self) {
        *self = Self::None;
    }
}
/// Queue of pending affine additions together with the affine buckets they
/// target. Queued entries are applied in one go by `batch_add`.
struct Schedule<C: CurveAffine> {
// Affine buckets; `Schedule::new(c)` allocates `1 << (c - 1)` of them.
buckets: Vec<BucketAffine<C>>,
// Fixed-size storage for queued additions.
set: [SchedulePoint; BATCH_SIZE],
// Number of queued entries in `set`; also the index of the next free slot.
ptr: usize,
}
/// One queued addition: which base goes into which bucket, and with what sign.
#[derive(Debug, Clone, Default)]
struct SchedulePoint {
    /// Index of the base point to add.
    base_idx: usize,
    /// Index of the bucket receiving the addition.
    buck_idx: usize,
    /// `true` to add the base, `false` to subtract it.
    sign: bool,
}

impl SchedulePoint {
    /// Builds an entry from its three components.
    fn new(base_idx: usize, buck_idx: usize, sign: bool) -> Self {
        SchedulePoint {
            sign,
            buck_idx,
            base_idx,
        }
    }
}
impl<C: CurveAffine> Schedule<C> {
    /// Creates an empty schedule with `1 << (c - 1)` empty affine buckets.
    fn new(c: usize) -> Self {
        let set = (0..BATCH_SIZE)
            .map(|_| SchedulePoint::default())
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();

        Self {
            buckets: vec![BucketAffine::None; 1 << (c - 1)],
            set,
            ptr: 0,
        }
    }

    /// Returns `true` when a queued (not yet executed) entry targets
    /// `buck_idx`.
    ///
    /// Only the live prefix `set[..ptr]` is inspected. Slots at or beyond
    /// `ptr` hold `SchedulePoint::default()`, whose `buck_idx` is `0`;
    /// scanning them would report bucket 0 as permanently busy and keep every
    /// addition into it off the batched affine path.
    fn contains(&self, buck_idx: usize) -> bool {
        self.set[..self.ptr]
            .iter()
            .any(|sch| sch.buck_idx == buck_idx)
    }

    /// Applies all queued additions with `batch_add` and empties the queue.
    /// Does nothing when the queue is empty.
    fn execute(&mut self, bases: &[Affine<C>]) {
        if self.ptr == 0 {
            return;
        }
        batch_add(self.ptr, &mut self.buckets, &self.set, bases);
        self.ptr = 0;
        // Not required by `contains` (it only reads `set[..ptr]`), but keeps
        // the unused slots in a well-defined state.
        self.set
            .iter_mut()
            .for_each(|sch| *sch = SchedulePoint::default());
    }

    /// Adds (`sign == true`) or subtracts (`sign == false`) `bases[base_idx]`
    /// into bucket `buck_idx`.
    ///
    /// An empty bucket is filled immediately; otherwise the addition is
    /// queued. The queue is executed as soon as it is full. Callers are
    /// expected to check `contains(buck_idx)` first, so that a bucket is
    /// queued at most once per batch.
    fn add(&mut self, bases: &[Affine<C>], base_idx: usize, buck_idx: usize, sign: bool) {
        if !self.buckets[buck_idx].assign(&bases[base_idx], sign) {
            self.set[self.ptr] = SchedulePoint::new(base_idx, buck_idx, sign);
            self.ptr += 1;
        }

        if self.ptr == self.set.len() {
            self.execute(bases);
        }
    }
}
/// Single-threaded multi-scalar multiplication: accumulates
/// `sum(coeffs[i] * bases[i])` into `acc` using Booth-encoded windows.
///
/// `coeffs` and `bases` are paired with `zip`, so surplus elements on either
/// side are ignored. If every coefficient is zero (or the input is empty),
/// `acc` is left untouched.
///
/// Note that `acc` is doubled `c` times per window before that window's
/// contribution is added, so any value already in `acc` is scaled as well.
/// The callers in this file always pass the identity.
pub fn msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
    /// Bucket that stays affine for as long as it holds a single point.
    #[derive(Clone, Copy)]
    enum SerialBucket<C: CurveAffine> {
        Empty,
        Affine(C),
        Projective(C::Curve),
    }

    impl<C: CurveAffine> SerialBucket<C> {
        /// Adds `other` to the bucket.
        fn add_assign(&mut self, other: &C) {
            *self = match *self {
                SerialBucket::Empty => SerialBucket::Affine(*other),
                SerialBucket::Affine(first) => SerialBucket::Projective(first + *other),
                SerialBucket::Projective(mut sum) => {
                    sum += *other;
                    SerialBucket::Projective(sum)
                }
            }
        }

        /// Returns `other` plus the content of the bucket.
        fn add(self, mut other: C::Curve) -> C::Curve {
            match self {
                SerialBucket::Empty => other,
                SerialBucket::Affine(point) => {
                    other += point;
                    other
                }
                SerialBucket::Projective(sum) => other + sum,
            }
        }
    }

    let reprs: Vec<_> = coeffs.iter().map(|scalar| scalar.to_repr()).collect();

    // Window width, chosen from the number of bases.
    let c = match bases.len() {
        0..=3 => 1,
        4..=31 => 3,
        n => f64::from(n as u32).ln().ceil() as usize,
    };

    // OR all coefficients together to find the highest byte that is non-zero
    // in any of them.
    let field_byte_size = C::Scalar::NUM_BITS.div_ceil(8u32) as usize;
    let mut used_bytes = vec![0u8; field_byte_size];
    for repr in &reprs {
        for (mask_byte, byte) in used_bytes.iter_mut().zip(repr.as_ref().iter()) {
            *mask_byte |= *byte;
        }
    }
    let unused_high_bytes = used_bytes
        .iter()
        .rev()
        .position(|byte| *byte != 0)
        .unwrap_or(field_byte_size);
    let max_byte_size = field_byte_size - unused_high_bytes;
    if max_byte_size == 0 {
        return;
    }

    let number_of_windows = max_byte_size * 8_usize / c + 1;

    for current_window in (0..number_of_windows).rev() {
        // Shift the accumulator up by one window.
        for _ in 0..c {
            *acc = acc.double();
        }

        let mut buckets: Vec<SerialBucket<C>> = vec![SerialBucket::Empty; 1 << (c - 1)];

        for (repr, base) in reprs.iter().zip(bases.iter()) {
            let digit = get_booth_index(current_window, c, repr.as_ref());
            if digit == 0 {
                continue;
            }
            let bucket = &mut buckets[digit.unsigned_abs() as usize - 1];
            if digit.is_positive() {
                bucket.add_assign(base);
            } else {
                bucket.add_assign(&base.neg());
            }
        }

        // Summation by parts, e.g. 3a + 2b + 1c = a + (a + b) + (a + b + c).
        let mut running_sum = C::Curve::identity();
        for bucket in buckets.into_iter().rev() {
            running_sum = bucket.add(running_sum);
            *acc += &running_sum;
        }
    }
}
/// Multi-scalar multiplication that splits the input into chunks, runs
/// `msm_serial` on each chunk inside a rayon scope, and sums the partial
/// results.
///
/// Inputs with no more elements than rayon threads are handled by a single
/// `msm_serial` call.
///
/// # Panics
///
/// Panics if `coeffs` and `bases` have different lengths.
pub fn msm_parallel<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
    assert_eq!(coeffs.len(), bases.len());

    let num_threads = rayon::current_num_threads();
    if coeffs.len() <= num_threads {
        let mut acc = C::Curve::identity();
        msm_serial(coeffs, bases, &mut acc);
        return acc;
    }

    // `coeffs.len() > num_threads`, so the chunk length is at least one.
    let chunk_len = coeffs.len() / num_threads;
    let num_chunks = coeffs.chunks(chunk_len).len();
    let mut partials = vec![C::Curve::identity(); num_chunks];

    rayon::scope(|scope| {
        let jobs = coeffs
            .chunks(chunk_len)
            .zip(bases.chunks(chunk_len))
            .zip(partials.iter_mut());
        for ((coeff_chunk, base_chunk), partial) in jobs {
            scope.spawn(move |_| {
                msm_serial(coeff_chunk, base_chunk, partial);
            });
        }
    });

    partials
        .iter()
        .fold(C::Curve::identity(), |total, partial| total + partial)
}
/// Multi-scalar multiplication that processes the windows in parallel and
/// batches bucket additions in affine coordinates.
///
/// When the window width derived from `bases.len()` is below 10, the work is
/// delegated to `msm_parallel`.
///
/// # Panics
///
/// Panics if `coeffs` and `bases` have different lengths.
pub fn msm_best<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
    assert_eq!(coeffs.len(), bases.len());

    // Window width, chosen from the number of bases.
    let c = match bases.len() {
        0..=3 => 1,
        4..=31 => 3,
        n => f64::from(n as u32).ln().ceil() as usize,
    };
    if c < 10 {
        return msm_parallel(coeffs, bases);
    }

    // Byte representation of every coefficient.
    let reprs: Vec<_> = coeffs.par_iter().map(|scalar| scalar.to_repr()).collect();
    // Raw coordinates of every base, used by the affine batch additions.
    let bases_local: Vec<_> = bases.par_iter().map(Affine::from).collect();

    let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
    // One partial result per window.
    let mut window_sums = vec![C::Curve::identity(); number_of_windows];

    window_sums
        .par_iter_mut()
        .enumerate()
        .rev()
        .for_each(|(w, window_sum)| {
            // Buckets for additions that cannot join the current affine batch.
            let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];
            // Scheduler for the batched affine additions.
            let mut sched = Schedule::new(c);

            for (base_idx, repr) in reprs.iter().enumerate() {
                let digit = get_booth_index(w, c, repr.as_ref());
                if digit == 0 {
                    continue;
                }
                let sign = digit.is_positive();
                let buck_idx = digit.unsigned_abs() as usize - 1;

                if sched.contains(buck_idx) {
                    // The bucket is reported as busy by the schedule:
                    // accumulate into `j_bucks` using the original bases.
                    j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
                } else {
                    // Also flushes the schedule once it is full.
                    sched.add(&bases_local, base_idx, buck_idx, sign);
                }
            }

            // Flush whatever is still queued.
            sched.execute(&bases_local);

            // Summation by parts, e.g. 3a + 2b + 1c = a + (a + b) + (a + b + c).
            let mut running_sum = C::Curve::identity();
            for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
                running_sum += j_buck.add(a_buck);
                *window_sum += running_sum;
            }

            // Shift the partial result to the position of its window.
            for _ in 0..c * w {
                *window_sum = window_sum.double();
            }
        });

    window_sums.into_iter().sum::<_>()
}
#[cfg(test)]
mod test {
    use std::ops::Neg;

    use ff::{Field, PrimeField};
    use group::{Curve, Group};
    use rand_core::OsRng;

    use crate::{
        bn256::{Fr, G1Affine, G1},
        CurveAffine,
    };

    /// Checks that a windowed scalar multiplication driven by
    /// `get_booth_index` agrees with the regular `point * scalar`.
    #[test]
    fn test_booth_encoding() {
        /// Reference scalar multiplication built on Booth digits and a table
        /// of small multiples of `point`.
        fn mul(scalar: &Fr, point: &G1Affine, window: usize) -> G1Affine {
            let repr = scalar.to_repr();
            let window_count = Fr::NUM_BITS as usize / window + 1;

            // multiples[k] = k * point for k in 0..=2^(window - 1)
            let multiples: Vec<_> = (0..=1 << (window - 1))
                .map(|k| point * Fr::from(k as u64))
                .collect();

            let mut acc = G1::identity();
            for w in (0..window_count).rev() {
                for _ in 0..window {
                    acc = acc.double();
                }
                let digit = super::get_booth_index(w, window, repr.as_ref());
                if digit.is_negative() {
                    acc += multiples[digit.unsigned_abs() as usize].neg();
                } else if digit.is_positive() {
                    acc += multiples[digit.unsigned_abs() as usize];
                }
            }
            acc.to_affine()
        }

        let (scalars, points): (Vec<_>, Vec<_>) = (0..10)
            .map(|_| {
                let scalar = Fr::random(OsRng);
                let point = G1Affine::random(OsRng);
                (scalar, point)
            })
            .unzip();

        for window in 1..10 {
            for (scalar, point) in scalars.iter().zip(points.iter()) {
                let via_booth = mul(scalar, point, window);
                let expected = point * scalar;
                assert_eq!(via_booth, expected.to_affine());
            }
        }
    }

    /// Compares `msm_best` against `msm_parallel` on random inputs of size
    /// `2^k` for every `k` in `min_k..=max_k`.
    fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};

        let random_points = (0..1 << max_k)
            .into_par_iter()
            .map(|_| C::Curve::random(OsRng))
            .collect::<Vec<_>>();
        let mut all_points = vec![C::identity(); 1 << max_k];
        C::Curve::batch_normalize(&random_points[..], &mut all_points[..]);

        let all_scalars = (0..1 << max_k)
            .into_par_iter()
            .map(|_| C::Scalar::random(OsRng))
            .collect::<Vec<_>>();

        for k in min_k..=max_k {
            let points = &all_points[..1 << k];
            let scalars = &all_scalars[..1 << k];

            let best = super::msm_best(scalars, points);
            let parallel = super::msm_parallel(scalars, points);
            assert_eq!(best, parallel);
        }
    }

    #[test]
    fn test_msm_cross() {
        run_msm_cross::<G1Affine>(14, 18);
    }
}