poly_commit/
msm.rs

1#[cfg(feature = "std")]
2use rayon::prelude::*;
3use zkstd::common::{vec, CurveAffine, CurveGroup, FftField, Vec};
4
5/// Performs a Variable Base Multiscalar Multiplication.
6pub fn msm_curve_addition<C: CurveAffine>(bases: &[C], coeffs: &[C::Scalar]) -> C::Extended {
7    let c = if bases.len() < 4 {
8        1
9    } else if bases.len() < 32 {
10        3
11    } else {
12        let log2 = usize::BITS - bases.len().leading_zeros();
13        (log2 * 69 / 100) as usize + 2
14    };
15
16    let mut buckets = vec![vec![Bucket::None; (1 << c) - 1]; (256 / c) + 1];
17    #[cfg(feature = "std")]
18    let bucket_iteration = buckets.par_iter_mut();
19    #[cfg(not(feature = "std"))]
20    let bucket_iteration = buckets.iter_mut();
21    let filled_buckets = bucket_iteration
22        .enumerate()
23        .rev()
24        .map(|(i, bucket)| {
25            for (coeff, base) in coeffs.iter().zip(bases.iter()) {
26                let seg = get_at(i, c, coeff.to_raw_bytes());
27                if seg != 0 {
28                    bucket[seg - 1].add_assign(base);
29                }
30            }
31            // Summation by parts
32            // e.g. 3a + 2b + 1c = a +
33            //                    (a) + b +
34            //                    ((a) + b) + c
35            let mut acc = C::Extended::ADDITIVE_IDENTITY;
36            let mut sum = C::Extended::ADDITIVE_IDENTITY;
37            bucket.iter().rev().for_each(|b| {
38                sum = b.add(sum);
39                acc += sum;
40            });
41            (0..c * i).for_each(|_| acc = acc.double());
42            acc
43        })
44        .collect::<Vec<_>>();
45    filled_buckets
46        .iter()
47        .fold(C::Extended::ADDITIVE_IDENTITY, |a, b| a + b)
48}
49
50#[derive(Clone, Copy)]
51enum Bucket<C: CurveAffine> {
52    None,
53    Affine(C),
54    Projective(C::Extended),
55}
56
57impl<C: CurveAffine> Bucket<C> {
58    fn add_assign(&mut self, other: &C) {
59        *self = match *self {
60            Bucket::None => Bucket::Affine(*other),
61            Bucket::Affine(a) => Bucket::Projective(a + other),
62            Bucket::Projective(a) => Bucket::Projective(a + other),
63        }
64    }
65
66    fn add(&self, other: C::Extended) -> C::Extended {
67        match self {
68            Bucket::None => other,
69            Bucket::Affine(a) => other + a,
70            Bucket::Projective(a) => other + a,
71        }
72    }
73}
74
/// Returns window number `segment` of width `c` bits from a 32-byte scalar read
/// as a little-endian integer, i.e. bits `[segment * c, segment * c + c)`.
///
/// Windows starting at or past bit 256 yield 0, and bits beyond the end of
/// `bytes` read as zero. At most 8 bytes are loaded and shifted right by up to
/// 7 bits, so callers are expected to keep `c <= 57`.
fn get_at(segment: usize, c: usize, bytes: [u8; 32]) -> usize {
    let bit_offset = segment * c;
    let byte_offset = bit_offset / 8;
    if byte_offset >= 32 {
        return 0;
    }

    // Load up to 8 bytes starting at `byte_offset`, zero-padding past the end.
    let end = (byte_offset + 8).min(32);
    let mut word = [0u8; 8];
    word[..end - byte_offset].copy_from_slice(&bytes[byte_offset..end]);

    let shifted = u64::from_le_bytes(word) >> (bit_offset % 8);
    (shifted % (1 << c)) as usize
}
92
#[cfg(test)]
mod tests {
    use super::msm_curve_addition;
    use bls_12_381::{Fr, G1Affine, G1Projective};
    use rand_core::OsRng;
    use zkstd::common::{CurveAffine, CurveGroup};
    use zkstd::common::{Group, WeierstrassProjective};

    /// Reference scalar multiplication used to check the MSM.
    ///
    /// Walks the digits returned by `Fr::to_costomized_repr` in order; for each
    /// digit the accumulator is multiplied by 4 (two doublings) and then
    /// `digit * point` is added for digits 1, 2 and 3 (any other digit adds
    /// nothing).
    fn customized_scalar_point<P: WeierstrassProjective<Extended = P>>(point: P, scalar: &Fr) -> P {
        let mut res = P::ADDITIVE_IDENTITY;
        let one = point;
        let two = one + point;
        let three = two + point;
        for &bit in scalar.to_costomized_repr().iter() {
            res = res.double().double();
            if bit == 1 {
                res += one;
            } else if bit == 2 {
                res += two;
            } else if bit == 3 {
                res += three;
            }
        }
        res
    }

    /// Naive `Σ scalars[i] * points[i]`, pairing the slices with `zip`.
    fn naive_msm(points: &[G1Affine], scalars: &[Fr]) -> G1Projective {
        points
            .iter()
            .zip(scalars.iter())
            .fold(G1Projective::ADDITIVE_IDENTITY, |acc, (point, coeff)| {
                acc + customized_scalar_point(point.to_extended(), coeff)
            })
    }

    fn random_points(n: usize) -> Vec<G1Affine> {
        (0..n)
            .map(|_| G1Affine::from(G1Affine::random(OsRng)))
            .collect()
    }

    fn random_scalars(n: usize) -> Vec<Fr> {
        (0..n).map(|_| Fr::random(OsRng)).collect()
    }

    #[test]
    fn multi_scalar_multiplication_test() {
        // Sizes straddling each window-size branch of `msm_curve_addition`
        // (`< 4`, `< 32`, and the logarithmic rule for larger inputs).
        for n in [1, 3, 4, 31, 32, 1 << 6] {
            let points = random_points(n);
            let scalars = random_scalars(n);
            let msm = msm_curve_addition(&points[..], &scalars[..]);
            assert_eq!(msm, naive_msm(&points, &scalars), "n = {}", n);
        }
    }

    #[test]
    fn empty_input_is_identity() {
        let points = random_points(0);
        let scalars = random_scalars(0);
        let msm = msm_curve_addition(&points[..], &scalars[..]);
        assert_eq!(msm, G1Projective::ADDITIVE_IDENTITY);
    }

    #[test]
    fn extra_bases_are_ignored() {
        // Slices are paired with `zip`, so only the first `scalars.len()`
        // bases contribute.
        let points = random_points(8);
        let scalars = random_scalars(5);
        let msm = msm_curve_addition(&points[..], &scalars[..]);
        assert_eq!(msm, naive_msm(&points[..5], &scalars));
    }
}