use ff::{Field, PrimeField};
use halo2curves::{group::Group, CurveAffine};
use num_integer::Integer;
use num_traits::{ToPrimitive, Zero};
use rayon::{current_num_threads, prelude::*};
/// Extended Jacobian ("XYZZ") point representation used for MSM buckets.
/// The affine point is (x / zz, y / zzz), with the invariant zzz^2 == zz^3
/// maintained by all arithmetic below. The identity is encoded as zz == 0
/// (see `BucketXYZZ::zero`); x and y are then placeholders.
#[derive(Copy, Clone)]
struct BucketXYZZ<F: Field> {
// X coordinate numerator: affine x = x / zz.
x: F,
// Y coordinate numerator: affine y = y / zzz.
y: F,
// Denominator for x; zero marks the identity point.
zz: F,
// Denominator for y; invariant: zzz^2 == zz^3.
zzz: F,
}
impl<F: Field> BucketXYZZ<F> {
/// Returns the identity point (zz == 0 marks the identity; x and y are
/// set to ONE as harmless placeholders).
#[inline]
fn zero() -> Self {
Self {
x: F::ONE,
y: F::ONE,
zz: F::ZERO,
zzz: F::ZERO,
}
}
/// True when this bucket encodes the identity point.
#[inline]
fn is_zero(&self) -> bool {
self.zz == F::ZERO
}
/// Doubles the point in place using the XYZZ doubling formulas
/// ("dbl-2008-s-1" in the Explicit-Formulas Database).
/// NOTE(review): `m = 3*x^2` omits the `a * zz^2` term, so this is only
/// correct for short Weierstrass curves with a == 0 — confirm this holds
/// for every curve the crate instantiates it with.
fn double_in_place(&mut self) {
if self.is_zero() {
return;
}
let u = self.y.double(); // U = 2*Y
let v = u.square(); // V = U^2
let w = u * v; // W = U*V
let s = self.x * v; // S = X*V
let x_sq = self.x.square();
let m = x_sq.double() + x_sq; // M = 3*X^2 (a*ZZ^2 omitted: assumes a == 0)
self.x = m.square() - s.double(); // X' = M^2 - 2*S
self.y = m * (s - self.x) - w * self.y; // Y' = M*(S - X') - W*Y
self.zz *= v; // ZZ' = V*ZZ
self.zzz *= w; // ZZZ' = W*ZZZ
}
/// Adds another XYZZ point into `self` using the "add-2008-s" formulas,
/// with explicit handling of the identity on either side and of the
/// equal-point / inverse-point special cases.
fn add_assign_bucket(&mut self, other: &Self) {
if other.is_zero() {
return;
}
if self.is_zero() {
*self = *other;
return;
}
// Cross-multiply to bring both points to a common denominator.
let u1 = self.x * other.zz;
let u2 = other.x * self.zz;
let s1 = self.y * other.zzz;
let s2 = other.y * self.zzz;
if u1 == u2 {
if s1 == s2 {
// Same point: the generic formulas would divide by zero.
self.double_in_place();
} else {
// P + (-P): result is the identity.
*self = Self::zero();
}
return;
}
let p = u2 - u1; // P = U2 - U1
let r = s2 - s1; // R = S2 - S1
let pp = p.square();
let ppp = p * pp;
let q = u1 * pp; // Q = U1*PP
self.x = r.square() - ppp - q.double(); // X3 = R^2 - PPP - 2*Q
self.y = r * (q - self.x) - s1 * ppp; // Y3 = R*(Q - X3) - S1*PPP
self.zz = self.zz * other.zz * pp;
self.zzz = self.zzz * other.zzz * ppp;
}
}
/// Adds an affine point into an XYZZ bucket using mixed addition
/// ("madd-2008-s": the affine operand has implicit ZZ = ZZZ = 1).
/// Identity inputs and the equal/inverse special cases are handled
/// explicitly before the generic formulas.
#[inline]
fn bucket_add_affine<C: CurveAffine>(bucket: &mut BucketXYZZ<C::Base>, p: &C) {
if bool::from(p.is_identity()) {
return;
}
// Safe: `p` is not the identity, so it has affine coordinates.
let coords = p.coordinates().unwrap();
let px = *coords.x();
let py = *coords.y();
if bucket.is_zero() {
// Empty bucket: store the affine point directly with ZZ = ZZZ = 1.
bucket.x = px;
bucket.y = py;
bucket.zz = C::Base::ONE;
bucket.zzz = C::Base::ONE;
return;
}
// U1 = bucket.x and S1 = bucket.y (affine point's denominators are 1).
let u2 = px * bucket.zz;
let s2 = py * bucket.zzz;
if bucket.x == u2 {
if bucket.y == s2 {
// Same point: generic formulas would divide by zero.
bucket.double_in_place();
} else {
// P + (-P): result is the identity.
*bucket = BucketXYZZ::zero();
}
return;
}
let p_val = u2 - bucket.x; // P = U2 - U1
let r = s2 - bucket.y; // R = S2 - S1
let pp = p_val.square();
let ppp = p_val * pp;
let q = bucket.x * pp; // Q = U1*PP
bucket.x = r.square() - ppp - q.double(); // X3 = R^2 - PPP - 2*Q
bucket.y = r * (q - bucket.x) - bucket.y * ppp; // Y3 = R*(Q - X3) - S1*PPP
bucket.zz *= pp;
bucket.zzz *= ppp;
}
/// Converts an XYZZ bucket back into a projective curve point by clearing
/// the denominators: the affine coordinates are (x / zz, y / zzz).
#[inline]
fn bucket_to_curve<C: CurveAffine>(bucket: &BucketXYZZ<C::Base>) -> C::CurveExt {
    if bucket.is_zero() {
        return C::CurveExt::identity();
    }
    // Non-identity buckets have zz != 0 (and, via zzz^2 == zz^3, zzz != 0),
    // so both inversions succeed.
    let affine_x = bucket.x * bucket.zz.invert().unwrap();
    let affine_y = bucket.y * bucket.zzz.invert().unwrap();
    C::from_xy(affine_x, affine_y)
        .expect("XYZZ bucket should produce a valid curve point")
        .into()
}
/// Returns the bit length of `s`: the position (1-based) of the highest set
/// bit in its canonical byte representation, or 0 for the zero scalar.
/// NOTE(review): assumes `to_repr` is little-endian (byte 0 is least
/// significant), as it is for the halo2curves fields — confirm for new fields.
#[inline]
fn scalar_num_bits<F: PrimeField>(s: &F) -> u32 {
    let repr = s.to_repr();
    repr.as_ref()
        .iter()
        .enumerate()
        .rev()
        .find(|&(_, &byte)| byte != 0)
        .map_or(0, |(i, &byte)| i as u32 * 8 + (8 - byte.leading_zeros()))
}
/// Returns the low 64 bits of the scalar's canonical little-endian
/// representation. Only meaningful when the scalar actually fits in 64 bits.
#[inline]
fn repr_low_u64<F: PrimeField>(s: &F) -> u64 {
    let repr = s.to_repr();
    // Fold the first (up to) eight little-endian bytes into a u64, walking
    // from the most significant of them down so each step shifts by 8.
    repr.as_ref()
        .iter()
        .take(8)
        .rev()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte))
}
/// Classifying MSM: each (coeff, base) pair is assigned to one of eleven
/// groups by the bit length of the scalar or of its negation (scalars close
/// to the group order are treated as small negatives), and each group is
/// dispatched to the cheapest kernel that can handle it. Partial results
/// are combined at the end; negated groups are subtracted to compensate.
pub fn msm<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());
let n = coeffs.len();
if n == 0 {
return C::Curve::identity();
}
// Tiny inputs: classification overhead dominates, use the naive sum.
if n <= 16 {
return msm_simple(coeffs, bases);
}
// Groups: 0/1 = ±1, 2/3 = ±(<=8 bits), 4/5 = ±(<=16), 6/7 = ±(<=32),
// 8/9 = ±(<=64), 10 = full-width scalars (fallback to msm_best).
const NUM_GROUPS: usize = 11;
// Pack each surviving pair as (group << 60) | index. Zero scalars and
// identity bases contribute nothing and are dropped here.
// NOTE(review): indices are masked to 60 bits, so n must stay below 2^60.
let classified: Vec<u64> = coeffs
.par_iter()
.enumerate()
.filter_map(|(i, s)| {
if bool::from(s.is_zero()) || bool::from(bases[i].is_identity()) {
return None;
}
let neg_s = -(*s);
let bits_s = scalar_num_bits(s);
let bits_neg = scalar_num_bits(&neg_s);
let group = if bits_s <= 1 {
0u8 } else if bits_neg <= 1 {
1u8 } else if bits_s <= 8 {
2u8
} else if bits_neg <= 8 {
3u8
} else if bits_s <= 16 {
4u8
} else if bits_neg <= 16 {
5u8
} else if bits_s <= 32 {
6u8
} else if bits_neg <= 32 {
7u8
} else if bits_s <= 64 {
8u8
} else if bits_neg <= 64 {
9u8
} else {
10u8 };
Some(((i as u64) & 0x0FFF_FFFF_FFFF_FFFF) | ((group as u64) << 60))
})
.collect();
if classified.is_empty() {
return C::Curve::identity();
}
// Sort by group tag so each group occupies one contiguous range.
let mut classified = classified;
classified.par_sort_unstable_by_key(|v| (v >> 60) as u8);
let extract_group = |v: u64| (v >> 60) as u8;
let extract_index = |v: u64| (v & 0x0FFF_FFFF_FFFF_FFFF) as usize;
// boundaries[g]..boundaries[g + 1] is group g's range in `classified`.
let mut boundaries = [0usize; NUM_GROUPS + 1];
{
let mut pos = 0;
for g in 0..NUM_GROUPS as u8 {
boundaries[g as usize] = pos;
pos += classified[pos..].partition_point(|v| extract_group(*v) <= g);
}
boundaries[NUM_GROUPS] = classified.len();
}
// Gathers the bases and low-64-bit scalar values for one group. `negate`
// maps a "small negative" scalar to its small positive counterpart; the
// caller subtracts that group's partial sum to compensate.
let extract_u64_group = |start: usize, end: usize, negate: bool| -> (Vec<C>, Vec<u64>) {
classified[start..end]
.iter()
.map(|&v| {
let idx = extract_index(v);
let b = bases[idx];
let s = if negate { -coeffs[idx] } else { coeffs[idx] };
(b, repr_low_u64(&s))
})
.unzip()
};
// For the ±1 groups only the bases are needed (coefficient is implicit).
let extract_binary_group = |start: usize, end: usize| -> Vec<C> {
classified[start..end]
.iter()
.map(|&v| bases[extract_index(v)])
.collect()
};
let (g0_start, g0_end) = (boundaries[0], boundaries[1]);
let (g1_start, g1_end) = (boundaries[1], boundaries[2]);
let (g2_start, g2_end) = (boundaries[2], boundaries[3]);
let (g3_start, g3_end) = (boundaries[3], boundaries[4]);
let (g4_start, g4_end) = (boundaries[4], boundaries[5]);
let (g5_start, g5_end) = (boundaries[5], boundaries[6]);
let (g6_start, g6_end) = (boundaries[6], boundaries[7]);
let (g7_start, g7_end) = (boundaries[7], boundaries[8]);
let (g8_start, g8_end) = (boundaries[8], boundaries[9]);
let (g9_start, g9_end) = (boundaries[9], boundaries[10]);
let (g10_start, g10_end) = (boundaries[10], boundaries[11]);
// Run the independent per-group computations as a tree of rayon joins.
let (binary_result, small_and_large_result) = rayon::join(
|| {
// Groups 0/1: coefficient is exactly +1 / -1 — plain accumulation.
let (pos, neg) = rayon::join(
|| {
let bases_pos = extract_binary_group(g0_start, g0_end);
accumulate_bases::<C>(&bases_pos)
},
|| {
let bases_neg = extract_binary_group(g1_start, g1_end);
accumulate_bases::<C>(&bases_neg)
},
);
pos - neg
},
|| {
let (small_result, large_result) = rayon::join(
|| {
// Groups 2..=9: paired positive/negative small-scalar MSMs at
// each width; the negated half is subtracted.
let ((r8, r16), (r32, r64)) = rayon::join(
|| {
rayon::join(
|| {
let (pos_b, pos_s) = extract_u64_group(g2_start, g2_end, false);
let (neg_b, neg_s) = extract_u64_group(g3_start, g3_end, true);
msm_small_with_max_num_bits(&pos_s, &pos_b, 8)
- msm_small_with_max_num_bits(&neg_s, &neg_b, 8)
},
|| {
let (pos_b, pos_s) = extract_u64_group(g4_start, g4_end, false);
let (neg_b, neg_s) = extract_u64_group(g5_start, g5_end, true);
msm_small_with_max_num_bits(&pos_s, &pos_b, 16)
- msm_small_with_max_num_bits(&neg_s, &neg_b, 16)
},
)
},
|| {
rayon::join(
|| {
let (pos_b, pos_s) = extract_u64_group(g6_start, g6_end, false);
let (neg_b, neg_s) = extract_u64_group(g7_start, g7_end, true);
msm_small_with_max_num_bits(&pos_s, &pos_b, 32)
- msm_small_with_max_num_bits(&neg_s, &neg_b, 32)
},
|| {
let (pos_b, pos_s) = extract_u64_group(g8_start, g8_end, false);
let (neg_b, neg_s) = extract_u64_group(g9_start, g9_end, true);
msm_small_with_max_num_bits(&pos_s, &pos_b, 64)
- msm_small_with_max_num_bits(&neg_s, &neg_b, 64)
},
)
},
);
r8 + r16 + r32 + r64
},
|| {
// Group 10: full-width scalars — delegate to the generic MSM.
if g10_start >= g10_end {
return C::Curve::identity();
}
let (large_bases, large_coeffs): (Vec<C>, Vec<C::Scalar>) = classified
[g10_start..g10_end]
.iter()
.map(|&v| {
let idx = extract_index(v);
(bases[idx], coeffs[idx])
})
.unzip();
halo2curves::msm::msm_best(&large_coeffs, &large_bases)
},
);
small_result + large_result
},
);
binary_result + small_and_large_result
}
/// Naive MSM: one scalar multiplication per pair, summed sequentially.
/// Used when the input is too small to amortize classification overhead.
fn msm_simple<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
    let mut acc = C::Curve::identity();
    for (coeff, base) in coeffs.iter().zip(bases.iter()) {
        acc = acc + *base * coeff;
    }
    acc
}
/// Sums a slice of affine points (each with an implicit coefficient of 1),
/// splitting the work across rayon threads when the slice is large enough.
fn accumulate_bases<C: CurveAffine>(bases: &[C]) -> C::Curve {
    if bases.is_empty() {
        return C::Curve::identity();
    }
    // Serial sum of one contiguous run of points.
    let serial_sum = |points: &[C]| {
        let mut acc = C::Curve::identity();
        for point in points {
            acc += *point;
        }
        acc
    };
    let threads = current_num_threads();
    if bases.len() > threads {
        let chunk_len = bases.len().div_ceil(threads);
        bases
            .par_chunks(chunk_len)
            .map(serial_sum)
            .reduce(C::Curve::identity, |a, b| a + b)
    } else {
        serial_sum(bases)
    }
}
/// Number of bits needed to represent `n` (0 for n == 0).
fn num_bits(n: usize) -> usize {
    match n.checked_ilog2() {
        Some(top_bit) => top_bit as usize + 1,
        None => 0,
    }
}
/// MSM over small unsigned-integer scalars.
///
/// Determines the bit width of the largest scalar and dispatches to
/// `msm_small_with_max_num_bits`. An empty input now yields the identity:
/// the previous implementation panicked on `max().unwrap()` of an empty
/// iterator before the callee's length assert could even run.
pub fn msm_small<C: CurveAffine, T: Integer + Into<u64> + Copy + Sync + ToPrimitive>(
    scalars: &[T],
    bases: &[C],
) -> C::Curve {
    // `max()` is None only for empty input; a width of 0 makes the callee
    // return the identity (its own length assert still fires on mismatch).
    let max_num_bits = scalars
        .iter()
        .max()
        .map_or(0, |m| num_bits(m.to_usize().unwrap()));
    msm_small_with_max_num_bits(scalars, bases, max_num_bits)
}
/// Dispatches a small-scalar MSM to the kernel matching `max_num_bits`, the
/// bit width of the largest scalar: 0 → identity, 1 → binary accumulation,
/// 2..=10 → single-window buckets, 11..=32 → windowed Pippenger, and
/// anything wider falls back to the generic `msm_best`.
pub fn msm_small_with_max_num_bits<
    C: CurveAffine,
    T: Integer + Into<u64> + Copy + Sync + ToPrimitive,
>(
    scalars: &[T],
    bases: &[C],
    max_num_bits: usize,
) -> C::Curve {
    assert_eq!(bases.len(), scalars.len());
    if max_num_bits == 0 {
        return C::identity().into();
    }
    if max_num_bits == 1 {
        return msm_binary(scalars, bases);
    }
    if max_num_bits <= 10 {
        return msm_10(scalars, bases, max_num_bits);
    }
    if max_num_bits <= 32 {
        return msm_small_rest(scalars, bases, max_num_bits);
    }
    // Too wide for the u64 bucket kernels: lift the scalars into the field
    // and use the general-purpose MSM.
    let lifted: Vec<C::ScalarExt> = scalars
        .iter()
        .map(|s| C::ScalarExt::from((*s).into()))
        .collect();
    halo2curves::msm::msm_best(&lifted, bases)
}
/// MSM where every scalar is 0 or 1: simply sums the bases whose scalar is
/// nonzero, in parallel chunks when the input is large enough.
fn msm_binary<C: CurveAffine, T: Integer + Sync>(scalars: &[T], bases: &[C]) -> C::Curve {
    assert_eq!(scalars.len(), bases.len());
    // Serial sum of the bases selected by nonzero scalars in one chunk.
    let sum_selected = |scalars: &[T], bases: &[C]| -> C::Curve {
        let mut total = C::Curve::identity();
        for (scalar, base) in scalars.iter().zip(bases.iter()) {
            if !scalar.is_zero() {
                total += *base;
            }
        }
        total
    };
    let threads = current_num_threads();
    if scalars.len() > threads {
        let chunk = scalars.len() / threads;
        scalars
            .par_chunks(chunk)
            .zip(bases.par_chunks(chunk))
            .map(|(s, b)| sum_selected(s, b))
            .reduce(C::Curve::identity, |a, b| a + b)
    } else {
        sum_selected(scalars, bases)
    }
}
/// Single-window bucket MSM for scalars of at most 10 bits: one bucket per
/// possible scalar value, filled with mixed XYZZ additions, then reduced
/// with the standard running-sum trick. Parallelized over input chunks.
fn msm_10<C: CurveAffine, T: Into<u64> + Zero + Copy + Sync>(
scalars: &[T],
bases: &[C],
max_num_bits: usize,
) -> C::Curve {
/// Serial kernel for one chunk.
/// NOTE(review): assumes every scalar is < 2^max_num_bits; a larger
/// scalar would index past the bucket array and panic.
fn msm_10_serial<C: CurveAffine, T: Into<u64> + Zero + Copy>(
scalars: &[T],
bases: &[C],
max_num_bits: usize,
) -> C::Curve {
let num_buckets: usize = 1 << max_num_bits;
let mut buckets: Vec<BucketXYZZ<C::Base>> = vec![BucketXYZZ::zero(); num_buckets];
// buckets[s] accumulates the sum of all bases whose scalar equals s.
scalars
.iter()
.zip(bases.iter())
.filter(|(scalar, _base)| !scalar.is_zero())
.for_each(|(scalar, base)| {
let bucket_index: u64 = (*scalar).into();
bucket_add_affine::<C>(&mut buckets[bucket_index as usize], base);
});
// Running-sum reduction from the highest bucket down: bucket[s] ends up
// added into `result` exactly s times, yielding sum_s s * bucket[s]
// without any explicit scalar multiplications. bucket[0] is skipped.
let mut result: BucketXYZZ<C::Base> = BucketXYZZ::zero();
let mut running_sum: BucketXYZZ<C::Base> = BucketXYZZ::zero();
for b in buckets.into_iter().skip(1).rev() {
running_sum.add_assign_bucket(&b);
result.add_assign_bucket(&running_sum);
}
bucket_to_curve::<C>(&result)
}
let num_threads = current_num_threads();
if scalars.len() > num_threads {
let chunk_size = scalars.len() / num_threads;
scalars
.par_chunks(chunk_size)
.zip(bases.par_chunks(chunk_size))
.map(|(scalars_chunk, bases_chunk)| msm_10_serial(scalars_chunk, bases_chunk, max_num_bits))
.reduce(C::Curve::identity, |sum, evl| sum + evl)
} else {
msm_10_serial(scalars, bases, max_num_bits)
}
}
/// Windowed (Pippenger-style) MSM for scalars of 11..=32 bits, with XYZZ
/// buckets per window and a final double-and-add combination of the window
/// sums. Parallelized over input chunks.
fn msm_small_rest<C: CurveAffine, T: Into<u64> + Zero + Copy + Sync>(
scalars: &[T],
bases: &[C],
max_num_bits: usize,
) -> C::Curve {
/// Serial kernel for one chunk.
fn msm_small_rest_serial<C: CurveAffine, T: Into<u64> + Zero + Copy>(
scalars: &[T],
bases: &[C],
max_num_bits: usize,
) -> C::Curve {
// Window width heuristic: roughly ln(n) + 2, clamped to 3 for tiny
// inputs; fixed at 8 for the widths used by the classifying `msm`.
let mut c = if bases.len() < 32 {
3
} else {
compute_ln(bases.len()) + 2
};
if max_num_bits == 32 || max_num_bits == 64 {
c = 8;
}
let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _base)| !s.is_zero());
// One window per c-bit slice of the scalar, starting at these bits.
let window_starts: Vec<usize> = (0..max_num_bits).step_by(c).collect();
let window_sums: Vec<C::CurveExt> = window_starts
.iter()
.map(|&w_start| {
let mut res: BucketXYZZ<C::Base> = BucketXYZZ::zero();
// (1 << c) - 1 buckets: slice value v lands in buckets[v - 1],
// v == 0 contributes nothing.
let mut buckets: Vec<BucketXYZZ<C::Base>> = vec![BucketXYZZ::zero(); (1 << c) - 1];
scalars_and_bases_iter.clone().for_each(|(&scalar, base)| {
let scalar: u64 = scalar.into();
if scalar == 1 {
// Unit scalars skip bucketing: they are added once,
// directly into the lowest window's result.
if w_start == 0 {
bucket_add_affine::<C>(&mut res, base);
}
} else {
// Extract this window's c-bit slice of the scalar.
let mut scalar = scalar;
scalar >>= w_start;
scalar %= 1 << c;
if scalar != 0 {
bucket_add_affine::<C>(&mut buckets[(scalar - 1) as usize], base);
}
}
});
// Running-sum reduction: buckets[v - 1] is counted v times.
let mut running_sum: BucketXYZZ<C::Base> = BucketXYZZ::zero();
for b in buckets.into_iter().rev() {
running_sum.add_assign_bucket(&b);
res.add_assign_bucket(&running_sum);
}
bucket_to_curve::<C>(&res)
})
.collect();
// Combine window sums from high to low, doubling c times per window to
// shift each sum into place; the lowest window needs no shift.
let lowest = *window_sums.first().unwrap();
lowest
+ window_sums[1..]
.iter()
.rev()
.fold(C::CurveExt::identity(), |mut total, sum_i| {
total += sum_i;
for _ in 0..c {
total = total.double();
}
total
})
}
let num_threads = current_num_threads();
if scalars.len() > num_threads {
let chunk_size = scalars.len() / num_threads;
scalars
.par_chunks(chunk_size)
.zip(bases.par_chunks(chunk_size))
.map(|(scalars_chunk, bases_chunk)| {
msm_small_rest_serial(scalars_chunk, bases_chunk, max_num_bits)
})
.reduce(C::Curve::identity, |sum, evl| sum + evl)
} else {
msm_small_rest_serial(scalars, bases, max_num_bits)
}
}
/// Integer approximation of the natural logarithm of `a`, via the identity
/// ln(a) = log2(a) * ln(2) with ln(2) ≈ 0.69. Returns 0 for a == 0.
fn compute_ln(a: usize) -> usize {
    match a.checked_ilog2() {
        None => 0,
        Some(log2) => log2 as usize * 69 / 100,
    }
}
/// Sums `bases[i]` for every `i` in `one_indices` (i.e. an MSM where the
/// selected coefficients are all 1), splitting the index list across rayon
/// threads.
///
/// Fixes over the previous version: `par_chunks` already yields a parallel
/// iterator, so the redundant `.into_par_iter()` is gone; and the chunk
/// size is now `len.div_ceil(num_chunks)` — the old
/// `(len + num_chunks).div_ceil(num_chunks)` over-sized every chunk by one,
/// leaving some threads idle.
#[inline(always)]
pub(crate) fn batch_add<C: CurveAffine>(bases: &[C], one_indices: &[usize]) -> C::Curve {
    // Serially sums the selected bases for one chunk of indices.
    fn add_chunk<C: CurveAffine>(bases: impl Iterator<Item = C>) -> C::Curve {
        let mut acc = C::Curve::identity();
        for base in bases {
            acc += base;
        }
        acc
    }
    let num_chunks = rayon::current_num_threads();
    // `max(1)`: a chunk size of 0 (empty index list) would make
    // `par_chunks` panic.
    let chunk_size = one_indices.len().div_ceil(num_chunks).max(1);
    one_indices
        .par_chunks(chunk_size)
        .map(|chunk| add_chunk(chunk.iter().map(|index| bases[*index])))
        .reduce(C::Curve::identity, |sum, evl| sum + evl)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::{
bn256_grumpkin::{bn256, grumpkin},
pasta::{pallas, vesta},
secp_secq::{secp256k1, secq256k1},
};
use ff::Field;
use halo2curves::{group::Group, CurveAffine};
use rand_core::OsRng;
/// Checks the classifying `msm` against a naive scalar-by-scalar sum on
/// random full-width inputs.
fn test_general_msm_with<F: Field, A: CurveAffine<ScalarExt = F>>() {
let n = 8;
let coeffs = (0..n).map(|_| F::random(OsRng)).collect::<Vec<_>>();
let bases = (0..n)
.map(|_| A::from(A::generator() * F::random(OsRng)))
.collect::<Vec<_>>();
assert_eq!(coeffs.len(), bases.len());
let naive = coeffs
.iter()
.zip(bases.iter())
.fold(A::CurveExt::identity(), |acc, (coeff, base)| {
acc + *base * coeff
});
let msm = msm(&coeffs, &bases);
assert_eq!(naive, msm)
}
#[test]
fn test_general_msm() {
test_general_msm_with::<pallas::Scalar, pallas::Affine>();
test_general_msm_with::<vesta::Scalar, vesta::Affine>();
test_general_msm_with::<bn256::Scalar, bn256::Affine>();
test_general_msm_with::<grumpkin::Scalar, grumpkin::Affine>();
test_general_msm_with::<secp256k1::Scalar, secp256k1::Affine>();
test_general_msm_with::<secq256k1::Scalar, secq256k1::Affine>();
}
/// Checks `msm_small` against the general `msm` for scalars restricted to
/// a range of bit widths, exercising each small-scalar kernel.
fn test_msm_ux_with<F: PrimeField, A: CurveAffine<ScalarExt = F>>() {
let n = 8;
let bases = (0..n)
.map(|_| A::from(A::generator() * F::random(OsRng)))
.collect::<Vec<_>>();
for bit_width in [1, 4, 8, 10, 16, 20, 32, 40, 64] {
println!("bit_width: {bit_width}");
// Mask random u64s down to the target bit width (guard against a
// shift overflow at exactly 64 bits).
assert!(bit_width <= 64); let mask = if bit_width == 64 {
u64::MAX
} else {
(1u64 << bit_width) - 1
};
let coeffs: Vec<u64> = (0..n)
.map(|_| rand::random::<u64>() & mask)
.collect::<Vec<_>>();
let coeffs_scalar: Vec<F> = coeffs.iter().map(|b| F::from(*b)).collect::<Vec<_>>();
let general = msm(&coeffs_scalar, &bases);
let integer = msm_small(&coeffs, &bases);
assert_eq!(general, integer);
}
}
#[test]
fn test_msm_ux() {
test_msm_ux_with::<pallas::Scalar, pallas::Affine>();
test_msm_ux_with::<vesta::Scalar, vesta::Affine>();
test_msm_ux_with::<bn256::Scalar, bn256::Affine>();
test_msm_ux_with::<grumpkin::Scalar, grumpkin::Affine>();
test_msm_ux_with::<secp256k1::Scalar, secp256k1::Affine>();
test_msm_ux_with::<secq256k1::Scalar, secq256k1::Affine>();
}
/// Ensures identity bases — paired with unit, random, and untouched random
/// coefficients — are filtered out by the classifying MSM, matching the
/// naive sum.
fn test_msm_identity_bases_with<F: Field, A: CurveAffine<ScalarExt = F>>() {
let n = 8;
let mut coeffs = (0..n).map(|_| F::random(OsRng)).collect::<Vec<_>>();
let mut bases = (0..n)
.map(|_| A::from(A::generator() * F::random(OsRng)))
.collect::<Vec<_>>();
bases[0] = A::identity();
bases[3] = A::identity();
bases[n - 1] = A::identity();
coeffs[0] = F::ONE;
coeffs[3] = F::random(OsRng);
let naive = coeffs
.iter()
.zip(bases.iter())
.fold(A::CurveExt::identity(), |acc, (coeff, base)| {
acc + *base * coeff
});
let result = msm(&coeffs, &bases);
assert_eq!(naive, result);
}
#[test]
fn test_msm_identity_bases() {
test_msm_identity_bases_with::<pallas::Scalar, pallas::Affine>();
test_msm_identity_bases_with::<vesta::Scalar, vesta::Affine>();
test_msm_identity_bases_with::<bn256::Scalar, bn256::Affine>();
test_msm_identity_bases_with::<grumpkin::Scalar, grumpkin::Affine>();
test_msm_identity_bases_with::<secp256k1::Scalar, secp256k1::Affine>();
test_msm_identity_bases_with::<secq256k1::Scalar, secq256k1::Affine>();
}
}