lambdaworks_math/msm/
pippenger.rs

1use crate::{cyclic_group::IsGroup, unsigned_integer::element::UnsignedInteger};
2
3use super::naive::MSMError;
4
5use alloc::vec;
6
7/// This function computes the multiscalar multiplication (MSM).
8///
9/// Assume a group G of order r is given.
10/// Let `points = [g_1, ..., g_n]` be a tuple of group points in G and
11/// let `cs = [k_1, ..., k_n]` be a tuple of scalars in the Galois field GF(r).
12///
13/// Then, with additive notation, `msm(cs, points)` computes k_1 * g_1 + .... + k_n * g_n.
14///
15/// If `points` and `cs` are empty, then `msm` returns the zero element of the group.
16///
17/// Panics if `cs` and `points` have different lengths.
18pub fn msm<const NUM_LIMBS: usize, G>(
19    cs: &[UnsignedInteger<NUM_LIMBS>],
20    points: &[G],
21) -> Result<G, MSMError>
22where
23    G: IsGroup,
24{
25    if cs.len() != points.len() {
26        return Err(MSMError::LengthMismatch(cs.len(), points.len()));
27    }
28
29    let window_size = optimum_window_size(cs.len());
30
31    Ok(msm_with(cs, points, window_size))
32}
33
/// Heuristic window size for Pippenger's algorithm: `floor(log2(data_length) * 4 / 5)`.
///
/// Inputs of length 0 or 1 yield 0. The result is not clamped here.
fn optimum_window_size(data_length: usize) -> usize {
    // Scaling factor k = NUMERATOR / DENOMINATOR in f(n) = k * log2(n).
    const NUMERATOR: usize = 4;
    const DENOMINATOR: usize = 5;

    // `checked_ilog2` is `None` for 0; treat that like a length-1 input.
    let log2_len = match data_length.checked_ilog2() {
        Some(bits) => bits as usize,
        None => 0,
    };
    log2_len * NUMERATOR / DENOMINATOR
}
41
42pub fn msm_with<const NUM_LIMBS: usize, G>(
43    cs: &[UnsignedInteger<NUM_LIMBS>],
44    points: &[G],
45    window_size: usize,
46) -> G
47where
48    G: IsGroup,
49{
50    // When input is small enough, windows of length 2 seem faster than 1.
51    const MIN_WINDOW_SIZE: usize = 2;
52    const MAX_WINDOW_SIZE: usize = 32;
53
54    let window_size = window_size.clamp(MIN_WINDOW_SIZE, MAX_WINDOW_SIZE);
55
56    // The number of windows of size `s` is ceil(lambda/s).
57    let num_windows = (64 * NUM_LIMBS - 1) / window_size + 1;
58
59    // We define `buckets` outside of the loop so we only have to allocate once, and reuse it.
60    //
61    // This line forces a heap allocation which might be undesired. We can define this buckets
62    // variable in the Pippenger struct to only allocate once, but use a bit of extra memory.
63    // If we accept a const window_size, we could make it an array instaed of a vector
64    // avoiding the heap allocation. We should be aware if that might be too agressive for
65    // the stack and cause a potential stack overflow.
66    let n_buckets = (1 << window_size) - 1;
67    let mut buckets = vec![G::neutral_element(); n_buckets];
68
69    (0..num_windows)
70        .rev()
71        .map(|window_idx| {
72            // Put in the right bucket the corresponding ps[i] for the current window.
73            cs.iter().zip(points).for_each(|(k, p)| {
74                // We truncate the number to the least significative limb.
75                // This is ok because window_size < usize::BITS.
76                let window_unmasked = (k >> (window_idx * window_size)).limbs[NUM_LIMBS - 1];
77                let m_ij = window_unmasked & n_buckets as u64;
78                if m_ij != 0 {
79                    let idx = (m_ij - 1) as usize;
80                    buckets[idx] = buckets[idx].operate_with(p);
81                }
82            });
83
84            // Do the reduction step for the buckets.
85            buckets
86                .iter_mut()
87                // This first part iterates buckets in descending order, generating an iterator with the sum of
88                // each bucket and all that came before as its items; i.e: (b_n, b_n + b_n-1, ..., b_n + ... + b_0)
89                .rev()
90                .scan(G::neutral_element(), |m, b| {
91                    *m = m.operate_with(b); // Reduction step.
92                    *b = G::neutral_element(); // Cleanup bucket slot to reuse in the next window.
93                    Some(m.clone())
94                })
95                // This next part sums all elements of the iterator: (b_n) + (b_n + b_n-1) + ...
96                // This results in: (n + 1) * b_n + n * b_n-1 + ... + b_0
97                .reduce(|g, m| g.operate_with(&m))
98                .unwrap_or_else(G::neutral_element)
99        })
100        // NOTE: this operation is non-associative and strictly sequential
101        .reduce(|t, g| t.operate_with_self(1_u64 << window_size).operate_with(&g))
102        .unwrap_or_else(G::neutral_element)
103}
104
/// Parallel counterpart of [`msm_with`]: computes `k_1 * g_1 + ... + k_n * g_n` by
/// processing the scalar windows concurrently with rayon.
///
/// It differs from the sequential implementation in that:
///  1. every window allocates its own bucket vector instead of reusing a single one;
///  2. each window result is multiplied by `2^(window_idx * window_size)` and the products
///     are summed, instead of combining the windows with Horner's rule.
///
/// `window_size` is clamped to `[2, 32]`, the same range `msm_with` uses. It is additionally
/// capped below `usize::BITS` so that `1 << window_size` cannot overflow. Scalars and points are
/// paired with `zip`, so surplus elements of the longer slice are ignored.
#[cfg(feature = "parallel")]
pub fn parallel_msm_with<const NUM_LIMBS: usize, G>(
    cs: &[UnsignedInteger<NUM_LIMBS>],
    points: &[G],
    window_size: usize,
) -> G
where
    G: IsGroup + Send + Sync,
{
    use rayon::prelude::*;

    // Mirror the bounds of `msm_with` so both functions interpret `window_size` the same way.
    // Without a lower bound, `window_size == 0` would divide by zero in `num_windows` below.
    const MIN_WINDOW_SIZE: usize = 2;
    const MAX_WINDOW_SIZE: usize = 32;
    let max_window_size = MAX_WINDOW_SIZE.min(usize::BITS as usize - 1);
    let window_size = window_size.clamp(MIN_WINDOW_SIZE, max_window_size);

    // The number of windows of size `s` is ceil(lambda/s), lambda being the scalar bit length.
    let num_windows = (64 * NUM_LIMBS - 1) / window_size + 1;
    let n_buckets = (1 << window_size) - 1;

    // TODO: limit the number of threads, and reuse vecs
    (0..num_windows)
        .into_par_iter()
        .map(|window_idx| {
            let mut buckets = vec![G::neutral_element(); n_buckets];
            // Put in the right bucket the corresponding ps[i] for the current window.
            let shift = window_idx * window_size;
            cs.iter().zip(points).for_each(|(k, p)| {
                // Keep only the least significant limb of the shifted scalar. This is enough
                // because the mask is at most 32 bits wide (window_size <= 32 < 64).
                let window_unmasked = (k >> shift).limbs[NUM_LIMBS - 1];
                let m_ij = window_unmasked & n_buckets as u64;
                if m_ij != 0 {
                    let idx = (m_ij - 1) as usize;
                    buckets[idx] = buckets[idx].operate_with(p);
                }
            });

            let mut m = G::neutral_element();

            // Do the reduction step for the buckets (running-sum trick: bucket `i` ends up
            // weighted by `i + 1`).
            let window_item = buckets
                // NOTE: changing this into a parallel iter drops performance, because of the
                //  need to use multiplication in the `map` step
                .into_iter()
                .rev()
                .map(|b| {
                    m = m.operate_with(&b); // Reduction step.
                    m.clone()
                })
                .reduce(|g, m| g.operate_with(&m))
                .unwrap_or_else(G::neutral_element);

            // Scale the window result by 2^shift so windows can be summed in any order.
            window_item.operate_with_self(UnsignedInteger::<NUM_LIMBS>::from_u64(1) << shift)
        })
        .reduce(G::neutral_element, |a, b| a.operate_with(&b))
}
162
#[cfg(test)]
mod tests {
    use crate::cyclic_group::IsGroup;
    use crate::msm::{naive, pippenger};
    use crate::{
        elliptic_curve::{
            short_weierstrass::curves::bls12_381::curve::BLS12381Curve, traits::IsEllipticCurve,
        },
        unsigned_integer::element::UnsignedInteger,
    };
    use alloc::vec::Vec;
    use proptest::{collection, prelude::*, prop_assert_eq, prop_compose, proptest};

    /// Point type of the BLS12-381 curve used by the unit tests below.
    type Point = <BLS12381Curve as IsEllipticCurve>::PointRepresentation;

    /// Number of cases generated per property test.
    const _CASES: u32 = 20;
    /// Exclusive upper bound for generated window sizes.
    const _MAX_WSIZE: usize = 8;
    /// Exclusive upper bound for generated vector lengths.
    const _MAX_LEN: usize = 30;

    prop_compose! {
        fn unsigned_integer()(limbs: [u64; 6]) -> UnsignedInteger<6> {
            UnsignedInteger::from_limbs(limbs)
        }
    }

    prop_compose! {
        fn unsigned_integer_vec()(vec in collection::vec(unsigned_integer(), 0.._MAX_LEN)) -> Vec<UnsignedInteger<6>> {
            vec
        }
    }

    prop_compose! {
        fn point()(power: u128) -> <BLS12381Curve as IsEllipticCurve>::PointRepresentation {
            BLS12381Curve::generator().operate_with_self(power)
        }
    }

    prop_compose! {
        fn points_vec()(vec in collection::vec(point(), 0.._MAX_LEN)) -> Vec<<BLS12381Curve as IsEllipticCurve>::PointRepresentation> {
            vec
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: _CASES, .. ProptestConfig::default()
          })]
        // Property-based test that ensures `pippenger::msm_with` gives same result as `naive::msm`.
        #[test]
        fn test_pippenger_matches_naive_msm(window_size in 1.._MAX_WSIZE, cs in unsigned_integer_vec(), points in points_vec()) {
            let min_len = cs.len().min(points.len());
            let cs = cs[..min_len].to_vec();
            let points = points[..min_len].to_vec();

            let pippenger = pippenger::msm_with(&cs, &points, window_size);
            let naive = naive::msm(&cs, &points).unwrap();

            prop_assert_eq!(naive, pippenger);
        }

        // Property-based test that ensures `pippenger::msm` (automatic window size) gives same
        // result as `naive::msm`.
        #[test]
        fn test_pippenger_msm_matches_naive_msm(cs in unsigned_integer_vec(), points in points_vec()) {
            let min_len = cs.len().min(points.len());
            let cs = cs[..min_len].to_vec();
            let points = points[..min_len].to_vec();

            let pippenger = pippenger::msm(&cs, &points).unwrap();
            let naive = naive::msm(&cs, &points).unwrap();

            prop_assert_eq!(naive, pippenger);
        }

        // Property-based test that ensures `pippenger::msm_with` gives same result as `pippenger::parallel_msm_with`.
        #[test]
        #[cfg(feature = "parallel")]
        fn test_parallel_pippenger_matches_sequential(window_size in 1.._MAX_WSIZE, cs in unsigned_integer_vec(), points in points_vec()) {
            let min_len = cs.len().min(points.len());
            let cs = cs[..min_len].to_vec();
            let points = points[..min_len].to_vec();

            let sequential = pippenger::msm_with(&cs, &points, window_size);
            let parallel = pippenger::parallel_msm_with(&cs, &points, window_size);

            prop_assert_eq!(parallel, sequential);
        }
    }

    // `msm` must report mismatched input lengths as an error instead of computing a result.
    #[test]
    fn test_msm_rejects_length_mismatch() {
        let cs = [UnsignedInteger::<6>::from_u64(1)];
        let points: [Point; 0] = [];
        assert!(matches!(
            pippenger::msm(&cs, &points),
            Err(naive::MSMError::LengthMismatch(1, 0))
        ));
    }

    // The MSM of empty inputs is the neutral element of the group.
    #[test]
    fn test_msm_of_empty_input_is_neutral_element() {
        let cs: [UnsignedInteger<6>; 0] = [];
        let points: [Point; 0] = [];
        assert_eq!(
            pippenger::msm(&cs, &points).unwrap(),
            <Point as IsGroup>::neutral_element()
        );
    }
}