1#[cfg(feature = "std")]
2use rayon::prelude::*;
3use zkstd::common::{vec, CurveAffine, CurveGroup, FftField, Vec};
4
5pub fn msm_curve_addition<C: CurveAffine>(bases: &[C], coeffs: &[C::Scalar]) -> C::Extended {
7 let c = if bases.len() < 4 {
8 1
9 } else if bases.len() < 32 {
10 3
11 } else {
12 let log2 = usize::BITS - bases.len().leading_zeros();
13 (log2 * 69 / 100) as usize + 2
14 };
15
16 let mut buckets = vec![vec![Bucket::None; (1 << c) - 1]; (256 / c) + 1];
17 #[cfg(feature = "std")]
18 let bucket_iteration = buckets.par_iter_mut();
19 #[cfg(not(feature = "std"))]
20 let bucket_iteration = buckets.iter_mut();
21 let filled_buckets = bucket_iteration
22 .enumerate()
23 .rev()
24 .map(|(i, bucket)| {
25 for (coeff, base) in coeffs.iter().zip(bases.iter()) {
26 let seg = get_at(i, c, coeff.to_raw_bytes());
27 if seg != 0 {
28 bucket[seg - 1].add_assign(base);
29 }
30 }
31 let mut acc = C::Extended::ADDITIVE_IDENTITY;
36 let mut sum = C::Extended::ADDITIVE_IDENTITY;
37 bucket.iter().rev().for_each(|b| {
38 sum = b.add(sum);
39 acc += sum;
40 });
41 (0..c * i).for_each(|_| acc = acc.double());
42 acc
43 })
44 .collect::<Vec<_>>();
45 filled_buckets
46 .iter()
47 .fold(C::Extended::ADDITIVE_IDENTITY, |a, b| a + b)
48}
49
50#[derive(Clone, Copy)]
51enum Bucket<C: CurveAffine> {
52 None,
53 Affine(C),
54 Projective(C::Extended),
55}
56
57impl<C: CurveAffine> Bucket<C> {
58 fn add_assign(&mut self, other: &C) {
59 *self = match *self {
60 Bucket::None => Bucket::Affine(*other),
61 Bucket::Affine(a) => Bucket::Projective(a + other),
62 Bucket::Projective(a) => Bucket::Projective(a + other),
63 }
64 }
65
66 fn add(&self, other: C::Extended) -> C::Extended {
67 match self {
68 Bucket::None => other,
69 Bucket::Affine(a) => other + a,
70 Bucket::Projective(a) => other + a,
71 }
72 }
73}
74
/// Extracts window number `segment` of width `c` bits from a 32-byte scalar encoding.
///
/// `bytes` is read as a little-endian integer. The function returns bits
/// `segment * c .. segment * c + c` as a value in `0..2^c`. A window starting at or
/// beyond byte 32 yields 0, and bits past the end of `bytes` are treated as zero.
///
/// At most eight bytes are loaded starting at the window's first byte. `c` therefore
/// needs to satisfy `c + 7 <= 64` for the whole window to be captured.
fn get_at(segment: usize, c: usize, bytes: [u8; 32]) -> usize {
    let start_bit = segment * c;
    let start_byte = start_bit / 8;
    if start_byte >= bytes.len() {
        return 0;
    }

    // Load up to eight bytes beginning at `start_byte`, zero-padding past the end.
    let tail = &bytes[start_byte..];
    let take = tail.len().min(8);
    let mut word = [0u8; 8];
    word[..take].copy_from_slice(&tail[..take]);

    let shifted = u64::from_le_bytes(word) >> (start_bit % 8);
    (shifted % (1u64 << c)) as usize
}
92
#[cfg(test)]
mod tests {
    use super::{get_at, msm_curve_addition};
    use bls_12_381::{Fr, G1Affine, G1Projective};
    use rand_core::OsRng;
    use zkstd::common::{CurveAffine, CurveGroup};
    use zkstd::common::{Group, WeierstrassProjective};

    /// Reference scalar multiplication.
    ///
    /// Consumes the digits of `scalar.to_costomized_repr()` in order, doubling twice per
    /// digit and then adding 1·P, 2·P or 3·P for digits 1, 2 or 3. This assumes the repr
    /// yields base-4 digits, most significant first — TODO confirm against `Fr`.
    fn customized_scalar_point<P: WeierstrassProjective<Extended = P>>(point: P, scalar: &Fr) -> P {
        let mut res = P::ADDITIVE_IDENTITY;
        let one = point;
        let two = one + point;
        let three = two + point;
        for &bit in scalar.to_costomized_repr().iter() {
            res = res.double().double();
            if bit == 1 {
                res += one;
            } else if bit == 2 {
                res += two;
            } else if bit == 3 {
                res += three;
            }
        }
        res
    }

    /// Oracle MSM: sums `customized_scalar_point` over the zipped pairs.
    fn naive_msm(points: &[G1Affine], scalars: &[Fr]) -> G1Projective {
        points
            .iter()
            .zip(scalars.iter())
            .fold(G1Projective::ADDITIVE_IDENTITY, |acc, (point, coeff)| {
                acc + customized_scalar_point(point.to_extended(), coeff)
            })
    }

    /// Generates `n` random points and `n` random scalars.
    fn random_inputs(n: usize) -> (Vec<G1Affine>, Vec<Fr>) {
        let points = (0..n)
            .map(|_| G1Affine::from(G1Affine::random(OsRng)))
            .collect::<Vec<_>>();
        let scalars = (0..n).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
        (points, scalars)
    }

    #[test]
    fn multi_scalar_multiplication_test() {
        let (points, scalars) = random_inputs(1 << 5);
        let msm = msm_curve_addition(&points[..], &scalars[..]);
        assert_eq!(msm, naive_msm(&points[..], &scalars[..]));
    }

    /// Covers every window-size branch: n < 4, 4 <= n < 32, and the logarithmic rule.
    #[test]
    fn msm_matches_naive_across_window_sizes() {
        for n in [1, 3, 4, 31, 33] {
            let (points, scalars) = random_inputs(n);
            let msm = msm_curve_addition(&points[..], &scalars[..]);
            assert_eq!(msm, naive_msm(&points[..], &scalars[..]), "n = {}", n);
        }
    }

    #[test]
    fn msm_of_empty_input_is_identity() {
        let msm = msm_curve_addition::<G1Affine>(&[], &[]);
        assert_eq!(msm, G1Projective::ADDITIVE_IDENTITY);
    }

    /// Surplus bases beyond the number of scalars must not contribute.
    #[test]
    fn msm_ignores_surplus_bases() {
        let (points, scalars) = random_inputs(40);
        let msm = msm_curve_addition(&points[..], &scalars[..10]);
        assert_eq!(msm, naive_msm(&points[..10], &scalars[..10]));
    }

    #[test]
    fn get_at_extracts_little_endian_windows() {
        let mut bytes = [0u8; 32];
        bytes[0] = 0xA5;
        bytes[1] = 0x3C;
        bytes[31] = 0xFF;
        assert_eq!(get_at(0, 4, bytes), 0x5);
        assert_eq!(get_at(1, 4, bytes), 0xA);
        assert_eq!(get_at(2, 3, bytes), 2);
        assert_eq!(get_at(31, 8, bytes), 0xFF);
        // Window starting at bit 255 sees only the top bit; the rest is zero-padded.
        assert_eq!(get_at(85, 3, bytes), 1);
        assert_eq!(get_at(32, 8, bytes), 0);
    }
}