// ark_msm/msm.rs

1use std::any::TypeId;
2use crate::{
3    bucket_msm::BucketMSM,
4    glv::{decompose},
5    types::{G1BigInt, BigInt, G1_SCALAR_SIZE_GLV, GROUP_SIZE_IN_BITS},
6};
7use ark_bls12_381::{Fr, g1::Parameters as G1Parameters};
8use ark_ff::{BigInteger, PrimeField};
9use ark_std::log2;
10use ark_ec::{
11    short_weierstrass_jacobian::{GroupAffine, GroupProjective},
12    models::SWModelParameters as Parameters
13};
14
/// Namespace struct exposing the multi-scalar-multiplication entry points
/// (`multi_scalar_mul`, `multi_scalar_mul_custom`); carries no state.
pub struct VariableBaseMSM;
16
17impl VariableBaseMSM {
18    /// WARNING: this function is derived from benchmark results running
19    /// on a Ubuntu 20.04.2 LTS server with AMD EPYC 7282 16-Core CPU
20    /// and 128G memory, the optimal performance may vary on a different
21    /// configuration.
22    fn get_opt_window_size(k: u32) -> u32 {
23        if k < 10 {
24            return 8;
25        }
26        match k {
27            10 => 10,
28            11 => 10,
29            12 => 10,
30            13 => 12,
31            14 => 12,
32            15 => 13,
33            16 => 13,
34            17 => 13,
35            18 => 13,
36            19 => 13,
37            20 => 15,
38            21 => 15,
39            22 => 15,
40            _ => 16
41        }
42    }
43
44    fn msm_slice<P: Parameters>(scalar: BigInt<P>, slices: &mut Vec<u32>, window_bits: u32) {
45        assert!(window_bits <= 31); // reserve one bit for marking signed slices
46        let mut temp = scalar;
47        for i in 0..slices.len() {
48            slices[i] = (temp.as_ref()[0] % (1 << window_bits)) as u32;
49            temp.divn(window_bits);
50        }
51
52        let mut carry = 0;
53        let total = 1 << window_bits;
54        let half = total >> 1;
55        for i in 0..slices.len() {
56            slices[i] += carry;
57            if slices[i] > half {
58                // slices[i] == half is okay, since (slice[i]-1) will be used for bucket_id
59                slices[i] = total - slices[i];
60                carry = 1;
61                slices[i] |= 1 << 31; // mark the highest bit for later
62            } else {
63                carry = 0;
64            }
65        }
66        assert!(
67            carry == 0,
68            "msm_slice overflows when apply signed-bucket-index"
69        );
70    }
71
72    fn multi_scalar_mul_g1_glv<P: Parameters>(
73        points: &[GroupAffine<P>],
74        scalars: &[BigInt<P>],
75        window_bits: u32,
76        max_batch: u32,
77        max_collisions: u32
78    ) -> GroupProjective<P> {
79        let num_slices: u32 = (G1_SCALAR_SIZE_GLV + window_bits - 1) / window_bits;
80        let mut bucket_msm = BucketMSM::<P>::new(
81            G1_SCALAR_SIZE_GLV,
82            window_bits,
83            max_batch,
84            max_collisions,
85        );
86        // scalar = phi * lambda + normal
87        let mut phi_slices: Vec<u32> = vec![0; num_slices as usize];
88        let mut normal_slices: Vec<u32> = vec![0; num_slices as usize];
89
90        let scalars_and_bases_iter = scalars
91            .iter()
92            .zip(points)
93            .filter(|(s, _)| !s.is_zero());
94        scalars_and_bases_iter.for_each(|(&scalar, point)| {
95            // use unsafe cast for type conversion until we have a better approach
96            let g1_scalar: G1BigInt= unsafe { *(std::ptr::addr_of!(scalar) as *const G1BigInt) };
97            let (phi, normal, is_neg_scalar, is_neg_normal) = decompose(&Fr::from(g1_scalar), window_bits);
98            Self::msm_slice::<G1Parameters>(phi.into(), &mut phi_slices, window_bits);
99            Self::msm_slice::<G1Parameters>(normal.into(), &mut normal_slices, window_bits);
100            bucket_msm.process_point_and_slices_glv(&point, &normal_slices, &phi_slices, is_neg_scalar, is_neg_normal);
101        });
102
103        bucket_msm.process_complete();
104        return bucket_msm.batch_reduce();
105    }
106
107    fn multi_scalar_mul_general<P: Parameters>(
108        points: &[GroupAffine<P>],
109        scalars: &[BigInt<P>],
110        window_bits: u32,
111        max_batch: u32,
112        max_collisions: u32
113    ) -> GroupProjective<P> {
114        let scalar_size = <P::ScalarField as PrimeField>::size_in_bits() as u32;
115        let num_slices: u32 = (scalar_size + window_bits - 1) / window_bits;
116        let mut bucket_msm = BucketMSM::<P>::new(
117            scalar_size,
118            window_bits,
119            max_batch,
120            max_collisions,
121        );
122        let mut slices: Vec<u32> = vec![0; num_slices as usize];
123
124        let scalars_and_bases_iter = scalars
125            .iter()
126            .zip(points)
127            .filter(|(s, _)| !s.is_zero());
128        scalars_and_bases_iter.for_each(|(&scalar, point)| {
129            Self::msm_slice::<P>(scalar, &mut slices, window_bits);
130            bucket_msm.process_point_and_slices(&point, &slices);
131        });
132
133        bucket_msm.process_complete();
134        return bucket_msm.batch_reduce();
135    }
136
137    pub fn multi_scalar_mul_custom<P: Parameters>(
138        points: &[GroupAffine<P>],
139        scalars: &[BigInt<P>],
140        window_bits: u32,
141        max_batch: u32,
142        max_collisions: u32
143    ) -> GroupProjective<P> {
144        assert!(window_bits as usize > GROUP_SIZE_IN_BITS,
145                "Window_bits must be greater than the default log(group size)");
146        if TypeId::of::<P>() == TypeId::of::<G1Parameters>() {
147            Self::multi_scalar_mul_g1_glv(points, scalars, window_bits, max_batch, max_collisions)
148        } else {
149            Self::multi_scalar_mul_general(points, scalars, window_bits, max_batch, max_collisions)
150        }
151    }
152
153    pub fn multi_scalar_mul<P: Parameters>(points: &[GroupAffine<P>], scalars: &[BigInt<P>]) -> GroupProjective<P> {
154        let opt_window_size = Self::get_opt_window_size(log2(points.len()));
155        Self::multi_scalar_mul_custom(&points, &scalars, opt_window_size, 2048, 256)
156    }
157}
158
159
#[cfg(test)]
mod collision_method_pippenger_tests {
    use super::*;
    use ark_bls12_381::g1::Parameters;

    #[test]
    fn test_msm_slice_window_size_1() {
        // 0b101, windows LSB-first: [1, 0, 1]; all below half, no sign marks.
        let scalar = G1BigInt::from(0b101);
        let mut slices: Vec<u32> = vec![0; 3];
        VariableBaseMSM::msm_slice::<Parameters>(scalar, &mut slices, 1);
        assert_eq!(slices, vec![1, 0, 1]);
    }

    #[test]
    fn test_msm_slice_window_size_2() {
        // 0b00_01_10, windows LSB-first: [2, 1, 0]; 2 == half is kept as-is.
        let scalar = G1BigInt::from(0b000110);
        let mut slices: Vec<u32> = vec![0; 3];
        VariableBaseMSM::msm_slice::<Parameters>(scalar, &mut slices, 2);
        assert_eq!(slices, vec![2, 1, 0]);
    }

    #[test]
    fn test_msm_slice_window_size_3() {
        // 0b010_111_000, windows LSB-first: [0, 7, 2]; 7 > 4 is negated to
        // 8 - 7 = 1 (sign-marked: 0x80000001) with a carry into the next
        // window, turning 2 into 3.
        let scalar = G1BigInt::from(0b010111000);
        let mut slices: Vec<u32> = vec![0; 3];
        VariableBaseMSM::msm_slice::<Parameters>(scalar, &mut slices, 3);
        assert_eq!(slices, vec![0, 0x80000001, 3]);
    }

    #[test]
    fn test_msm_slice_window_size_16() {
        // 0x1234_0000_7FFF, windows LSB-first: [0x7FFF, 0, 0x1234]; 0x7FFF is
        // just below half (0x8000), so nothing is negated.
        let scalar = G1BigInt::from(0x123400007FFF);
        let mut slices: Vec<u32> = vec![0; 3];
        VariableBaseMSM::msm_slice::<Parameters>(scalar, &mut slices, 16);
        assert_eq!(slices, vec![0x7FFF, 0, 0x1234]);
    }
}