use crate::{cyclic_group::IsGroup, unsigned_integer::element::UnsignedInteger};
use super::naive::MSMError;
use alloc::vec;
/// Computes `sum_i cs[i] * points[i]` with Pippenger's bucket method, picking the
/// window size from the number of input pairs (see `optimum_window_size`).
///
/// # Errors
///
/// Returns [`MSMError::LengthMismatch`] carrying `(cs.len(), points.len())` when the
/// scalar and point slices have different lengths.
pub fn msm<const NUM_LIMBS: usize, G>(
    cs: &[UnsignedInteger<NUM_LIMBS>],
    points: &[G],
) -> Result<G, MSMError>
where
    G: IsGroup,
{
    match (cs.len(), points.len()) {
        (n_scalars, n_points) if n_scalars != n_points => {
            Err(MSMError::LengthMismatch(n_scalars, n_points))
        }
        (n_pairs, _) => Ok(msm_with(cs, points, optimum_window_size(n_pairs))),
    }
}
/// Heuristic Pippenger window size for `data_length` scalar/point pairs.
///
/// Returns `floor(log2(data_length)) * 4 / 5` (integer arithmetic), or `0` when
/// `data_length` is `0`. The result is not bounded here; `msm_with` clamps the window
/// size it receives.
fn optimum_window_size(data_length: usize) -> usize {
    // The window grows with log2 of the input size, scaled down by 4/5.
    const NUMERATOR: usize = 4;
    const DENOMINATOR: usize = 5;

    match data_length.checked_ilog2() {
        Some(log2_len) => log2_len as usize * NUMERATOR / DENOMINATOR,
        None => 0,
    }
}
/// Computes `sum_i cs[i] * points[i]` using Pippenger's bucket method with windows of
/// `window_size` bits.
///
/// `window_size` is clamped to `[2, 32]`. Windows are processed from the most
/// significant to the least significant one. Their partial results are combined with
/// Horner's rule (`acc = acc * 2^window_size + window_sum`), so that combination is
/// strictly sequential. A single bucket vector is allocated up front and reset during
/// each window's reduction so it can be reused for the next window.
///
/// The slices are paired with `zip`, so if their lengths differ the surplus elements of
/// the longer one are ignored; [`msm`] is the length-checked entry point.
pub fn msm_with<const NUM_LIMBS: usize, G>(
    cs: &[UnsignedInteger<NUM_LIMBS>],
    points: &[G],
    window_size: usize,
) -> G
where
    G: IsGroup,
{
    const MIN_WINDOW_SIZE: usize = 2;
    const MAX_WINDOW_SIZE: usize = 32;

    let width = window_size.clamp(MIN_WINDOW_SIZE, MAX_WINDOW_SIZE);
    // ceil(64 * NUM_LIMBS / width): enough windows to cover every bit of a scalar.
    let window_count = (64 * NUM_LIMBS - 1) / width + 1;
    // A window digit d in 1..=digit_mask is accumulated into bucket d - 1; d = 0 is skipped.
    let digit_mask: usize = (1 << width) - 1;
    let mut buckets = vec![G::neutral_element(); digit_mask];

    let mut total: Option<G> = None;
    for window_idx in (0..window_count).rev() {
        let shift = window_idx * width;

        // Bucket accumulation: add each point into the bucket selected by its digit.
        for (k, p) in cs.iter().zip(points) {
            // The extraction relies on the last limb holding the low 64 bits of the
            // shifted scalar; masking it yields this window's digit.
            let digit = (k >> shift).limbs[NUM_LIMBS - 1] & digit_mask as u64;
            if digit == 0 {
                continue;
            }
            let slot = (digit - 1) as usize;
            buckets[slot] = buckets[slot].operate_with(p);
        }

        // Bucket reduction: walking the buckets from highest to lowest, `running` holds
        // b_n + ... + b_j and `window_sum` accumulates those running sums, which yields
        // sum_j (j + 1) * b_j. Each bucket is reset to the neutral element once read so
        // the vector can be reused by the next window.
        let mut running = G::neutral_element();
        let mut window_sum: Option<G> = None;
        for bucket in buckets.iter_mut().rev() {
            running = running.operate_with(bucket);
            *bucket = G::neutral_element();
            window_sum = Some(match window_sum {
                None => running.clone(),
                Some(acc) => acc.operate_with(&running),
            });
        }
        let window_sum = window_sum.unwrap_or_else(G::neutral_element);

        // Horner step: shift the accumulated result up by one window, then add this one.
        total = Some(match total {
            None => window_sum,
            Some(acc) => acc
                .operate_with_self(1_u64 << width)
                .operate_with(&window_sum),
        });
    }

    total.unwrap_or_else(G::neutral_element)
}
/// Parallel variant of [`msm_with`]: computes `sum_i cs[i] * points[i]` with Pippenger's
/// bucket method, processing the windows on the rayon thread pool.
///
/// Differences from the sequential implementation:
/// - every window allocates its own bucket vector;
/// - instead of Horner's rule, each window's result is multiplied by
///   `2^(window_idx * window_size)` and all windows are summed with an associative
///   parallel reduction (identity: the neutral element).
///
/// A `window_size` of `0` is treated as `1`, so a zero width cannot cause a division by
/// zero when counting windows. As with [`msm_with`], the slices are paired with `zip`,
/// so surplus elements of the longer slice are ignored.
///
/// # Panics
///
/// Panics if `window_size >= usize::BITS`, because the bucket count
/// `2^window_size - 1` could not be represented.
#[cfg(feature = "parallel")]
pub fn parallel_msm_with<const NUM_LIMBS: usize, G>(
    cs: &[UnsignedInteger<NUM_LIMBS>],
    points: &[G],
    window_size: usize,
) -> G
where
    G: IsGroup + Send + Sync,
{
    use rayon::prelude::*;

    // A zero-width window would make the window count below divide by zero; one bit is
    // the narrowest window that still makes progress.
    const MIN_WINDOW_SIZE: usize = 1;
    let window_size = window_size.max(MIN_WINDOW_SIZE);
    assert!(
        window_size < usize::BITS as usize,
        "window_size must be smaller than usize::BITS"
    );

    // ceil(64 * NUM_LIMBS / window_size): enough windows to cover every bit of a scalar.
    let num_windows = (64 * NUM_LIMBS - 1) / window_size + 1;
    // A window digit d in 1..=n_buckets is accumulated into bucket d - 1; d = 0 is skipped.
    let n_buckets = (1 << window_size) - 1;
    let digit_mask = n_buckets as u64;

    (0..num_windows)
        .into_par_iter()
        .map(|window_idx| {
            let mut buckets = vec![G::neutral_element(); n_buckets];
            let shift = window_idx * window_size;

            // Bucket accumulation. Only the last limb of the shifted scalar is read; the
            // extraction relies on it holding the low 64 bits of the shifted value.
            cs.iter().zip(points).for_each(|(k, p)| {
                let m_ij = (k >> shift).limbs[NUM_LIMBS - 1] & digit_mask;
                if m_ij != 0 {
                    let idx = (m_ij - 1) as usize;
                    buckets[idx] = buckets[idx].operate_with(p);
                }
            });

            // Bucket reduction: running sums from the highest bucket down, then the sum
            // of those running sums, i.e. sum_j (j + 1) * b_j.
            let mut running = G::neutral_element();
            let window_sum = buckets
                .into_iter()
                .rev()
                .map(|bucket| {
                    running = running.operate_with(&bucket);
                    running.clone()
                })
                .reduce(|acc, partial| acc.operate_with(&partial))
                .unwrap_or_else(G::neutral_element);

            // Scale this window's contribution to its bit position.
            window_sum.operate_with_self(UnsignedInteger::<NUM_LIMBS>::from_u64(1) << shift)
        })
        .reduce(G::neutral_element, |a, b| a.operate_with(&b))
}
#[cfg(test)]
mod tests {
    use crate::cyclic_group::IsGroup;
    use crate::msm::{naive, pippenger};
    use crate::{
        elliptic_curve::{
            short_weierstrass::curves::bls12_381::curve::BLS12381Curve, traits::IsEllipticCurve,
        },
        unsigned_integer::element::UnsignedInteger,
    };
    use alloc::vec::Vec;
    use proptest::{collection, prelude::*, prop_assert_eq, prop_compose, proptest};

    // Property-test budget and input bounds.
    const CASES: u32 = 20;
    const MAX_WSIZE: usize = 8;
    const MAX_LEN: usize = 30;

    type Point = <BLS12381Curve as IsEllipticCurve>::PointRepresentation;

    prop_compose! {
        fn unsigned_integer()(limbs: [u64; 6]) -> UnsignedInteger<6> {
            UnsignedInteger::from_limbs(limbs)
        }
    }

    prop_compose! {
        fn unsigned_integer_vec()(vec in collection::vec(unsigned_integer(), 0..MAX_LEN)) -> Vec<UnsignedInteger<6>> {
            vec
        }
    }

    prop_compose! {
        fn point()(power: u128) -> Point {
            BLS12381Curve::generator().operate_with_self(power)
        }
    }

    prop_compose! {
        fn points_vec()(vec in collection::vec(point(), 0..MAX_LEN)) -> Vec<Point> {
            vec
        }
    }

    #[test]
    fn test_pippenger_msm_rejects_length_mismatch() {
        let cs = [UnsignedInteger::<6>::from_u64(1)];
        let points: [Point; 0] = [];
        let result = pippenger::msm(&cs, &points);
        assert!(matches!(result, Err(naive::MSMError::LengthMismatch(1, 0))));
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: CASES, .. ProptestConfig::default()
        })]

        // msm_with with an explicit window size must agree with the naive MSM.
        #[test]
        fn test_pippenger_matches_naive_msm(window_size in 1..MAX_WSIZE, cs in unsigned_integer_vec(), points in points_vec()) {
            let min_len = cs.len().min(points.len());
            let cs = cs[..min_len].to_vec();
            let points = points[..min_len].to_vec();

            let pippenger = pippenger::msm_with(&cs, &points, window_size);
            let naive = naive::msm(&cs, &points).unwrap();

            prop_assert_eq!(naive, pippenger);
        }

        // msm chooses its own window size; it must also agree with the naive MSM.
        #[test]
        fn test_pippenger_msm_matches_naive_msm(cs in unsigned_integer_vec(), points in points_vec()) {
            let min_len = cs.len().min(points.len());
            let cs = cs[..min_len].to_vec();
            let points = points[..min_len].to_vec();

            let pippenger = pippenger::msm(&cs, &points).unwrap();
            let naive = naive::msm(&cs, &points).unwrap();

            prop_assert_eq!(naive, pippenger);
        }

        // The rayon-based variant must agree with the sequential implementation.
        #[test]
        #[cfg(feature = "parallel")]
        fn test_parallel_pippenger_matches_sequential(window_size in 1..MAX_WSIZE, cs in unsigned_integer_vec(), points in points_vec()) {
            let min_len = cs.len().min(points.len());
            let cs = cs[..min_len].to_vec();
            let points = points[..min_len].to_vec();

            let sequential = pippenger::msm_with(&cs, &points, window_size);
            let parallel = pippenger::parallel_msm_with(&cs, &points, window_size);

            prop_assert_eq!(parallel, sequential);
        }
    }
}