1use std::any::TypeId;
2use crate::{
3 bucket_msm::BucketMSM,
4 glv::{decompose},
5 types::{G1BigInt, BigInt, G1_SCALAR_SIZE_GLV, GROUP_SIZE_IN_BITS},
6};
7use ark_bls12_381::{Fr, g1::Parameters as G1Parameters};
8use ark_ff::{BigInteger, PrimeField};
9use ark_std::log2;
10use ark_ec::{
11 short_weierstrass_jacobian::{GroupAffine, GroupProjective},
12 models::SWModelParameters as Parameters
13};
14
15pub struct VariableBaseMSM;
16
17impl VariableBaseMSM {
18 fn get_opt_window_size(k: u32) -> u32 {
23 if k < 10 {
24 return 8;
25 }
26 match k {
27 10 => 10,
28 11 => 10,
29 12 => 10,
30 13 => 12,
31 14 => 12,
32 15 => 13,
33 16 => 13,
34 17 => 13,
35 18 => 13,
36 19 => 13,
37 20 => 15,
38 21 => 15,
39 22 => 15,
40 _ => 16
41 }
42 }
43
44 fn msm_slice<P: Parameters>(scalar: BigInt<P>, slices: &mut Vec<u32>, window_bits: u32) {
45 assert!(window_bits <= 31); let mut temp = scalar;
47 for i in 0..slices.len() {
48 slices[i] = (temp.as_ref()[0] % (1 << window_bits)) as u32;
49 temp.divn(window_bits);
50 }
51
52 let mut carry = 0;
53 let total = 1 << window_bits;
54 let half = total >> 1;
55 for i in 0..slices.len() {
56 slices[i] += carry;
57 if slices[i] > half {
58 slices[i] = total - slices[i];
60 carry = 1;
61 slices[i] |= 1 << 31; } else {
63 carry = 0;
64 }
65 }
66 assert!(
67 carry == 0,
68 "msm_slice overflows when apply signed-bucket-index"
69 );
70 }
71
72 fn multi_scalar_mul_g1_glv<P: Parameters>(
73 points: &[GroupAffine<P>],
74 scalars: &[BigInt<P>],
75 window_bits: u32,
76 max_batch: u32,
77 max_collisions: u32
78 ) -> GroupProjective<P> {
79 let num_slices: u32 = (G1_SCALAR_SIZE_GLV + window_bits - 1) / window_bits;
80 let mut bucket_msm = BucketMSM::<P>::new(
81 G1_SCALAR_SIZE_GLV,
82 window_bits,
83 max_batch,
84 max_collisions,
85 );
86 let mut phi_slices: Vec<u32> = vec![0; num_slices as usize];
88 let mut normal_slices: Vec<u32> = vec![0; num_slices as usize];
89
90 let scalars_and_bases_iter = scalars
91 .iter()
92 .zip(points)
93 .filter(|(s, _)| !s.is_zero());
94 scalars_and_bases_iter.for_each(|(&scalar, point)| {
95 let g1_scalar: G1BigInt= unsafe { *(std::ptr::addr_of!(scalar) as *const G1BigInt) };
97 let (phi, normal, is_neg_scalar, is_neg_normal) = decompose(&Fr::from(g1_scalar), window_bits);
98 Self::msm_slice::<G1Parameters>(phi.into(), &mut phi_slices, window_bits);
99 Self::msm_slice::<G1Parameters>(normal.into(), &mut normal_slices, window_bits);
100 bucket_msm.process_point_and_slices_glv(&point, &normal_slices, &phi_slices, is_neg_scalar, is_neg_normal);
101 });
102
103 bucket_msm.process_complete();
104 return bucket_msm.batch_reduce();
105 }
106
107 fn multi_scalar_mul_general<P: Parameters>(
108 points: &[GroupAffine<P>],
109 scalars: &[BigInt<P>],
110 window_bits: u32,
111 max_batch: u32,
112 max_collisions: u32
113 ) -> GroupProjective<P> {
114 let scalar_size = <P::ScalarField as PrimeField>::size_in_bits() as u32;
115 let num_slices: u32 = (scalar_size + window_bits - 1) / window_bits;
116 let mut bucket_msm = BucketMSM::<P>::new(
117 scalar_size,
118 window_bits,
119 max_batch,
120 max_collisions,
121 );
122 let mut slices: Vec<u32> = vec![0; num_slices as usize];
123
124 let scalars_and_bases_iter = scalars
125 .iter()
126 .zip(points)
127 .filter(|(s, _)| !s.is_zero());
128 scalars_and_bases_iter.for_each(|(&scalar, point)| {
129 Self::msm_slice::<P>(scalar, &mut slices, window_bits);
130 bucket_msm.process_point_and_slices(&point, &slices);
131 });
132
133 bucket_msm.process_complete();
134 return bucket_msm.batch_reduce();
135 }
136
137 pub fn multi_scalar_mul_custom<P: Parameters>(
138 points: &[GroupAffine<P>],
139 scalars: &[BigInt<P>],
140 window_bits: u32,
141 max_batch: u32,
142 max_collisions: u32
143 ) -> GroupProjective<P> {
144 assert!(window_bits as usize > GROUP_SIZE_IN_BITS,
145 "Window_bits must be greater than the default log(group size)");
146 if TypeId::of::<P>() == TypeId::of::<G1Parameters>() {
147 Self::multi_scalar_mul_g1_glv(points, scalars, window_bits, max_batch, max_collisions)
148 } else {
149 Self::multi_scalar_mul_general(points, scalars, window_bits, max_batch, max_collisions)
150 }
151 }
152
153 pub fn multi_scalar_mul<P: Parameters>(points: &[GroupAffine<P>], scalars: &[BigInt<P>]) -> GroupProjective<P> {
154 let opt_window_size = Self::get_opt_window_size(log2(points.len()));
155 Self::multi_scalar_mul_custom(&points, &scalars, opt_window_size, 2048, 256)
156 }
157}
158
159
#[cfg(test)]
mod collision_method_pippenger_tests {
    use super::*;
    use ark_bls12_381::g1::Parameters;

    /// Runs `VariableBaseMSM::msm_slice` on `scalar` and returns its `num_slices`
    /// sign-magnitude digits.
    fn slice(scalar: u64, num_slices: usize, window_bits: u32) -> Vec<u32> {
        let mut slices = vec![0u32; num_slices];
        VariableBaseMSM::msm_slice::<Parameters>(G1BigInt::from(scalar), &mut slices, window_bits);
        slices
    }

    /// Decodes sign-magnitude slices (bit 31 = negative) back into the integer
    /// they represent.
    fn reconstruct(slices: &[u32], window_bits: u32) -> i128 {
        slices
            .iter()
            .enumerate()
            .map(|(i, &s)| {
                let magnitude = i128::from(s & 0x7FFF_FFFF);
                let digit = if s & 0x8000_0000 != 0 { -magnitude } else { magnitude };
                digit << (window_bits as usize * i)
            })
            .sum()
    }

    #[test]
    fn test_msm_slice_window_size_1() {
        assert_eq!(slice(0b101, 3, 1), vec![1, 0, 1]);
    }

    #[test]
    fn test_msm_slice_window_size_2() {
        assert_eq!(slice(0b000110, 3, 2), vec![2, 1, 0]);
    }

    #[test]
    fn test_msm_slice_window_size_3() {
        // Middle window 0b111 exceeds half (4), so it becomes -1 with a carry.
        assert_eq!(slice(0b010111000, 3, 3), vec![0, 0x80000001, 3]);
    }

    #[test]
    fn test_msm_slice_window_size_16() {
        assert_eq!(slice(0x123400007FFF, 3, 16), vec![0x7FFF, 0, 0x1234]);
    }

    #[test]
    fn test_msm_slice_round_trip_window_size_3() {
        // Four 3-bit windows leave headroom for the final carry of any 9-bit scalar.
        for scalar in 0..(1u64 << 9) {
            let slices = slice(scalar, 4, 3);
            assert_eq!(
                reconstruct(&slices, 3),
                i128::from(scalar),
                "scalar = {:#b}, slices = {:x?}",
                scalar,
                slices
            );
            for &s in &slices {
                assert!(
                    s & 0x7FFF_FFFF <= 4,
                    "digit magnitude exceeds half window: {:#x}",
                    s
                );
            }
        }
    }
}
196}