1#![cfg_attr(not(feature = "std"), no_std)]
2#![forbid(unsafe_code)]
3
4extern crate alloc;
28
29use alloc::vec::Vec;
30use ark_bn254::{Fr, G1Affine, G1Projective};
31use ark_ec::{AffineRepr, CurveGroup};
32use ark_ff::{AdditiveGroup, BigInteger, PrimeField, Zero};
33use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
34
/// Errors returned while decoding an MSM input.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MsmError {
    /// The input is shorter than the 4-byte count header, the declared count
    /// overflows the byte-length computation, or the body length is not
    /// exactly `96 * n` bytes for the declared count `n`.
    InvalidInputLayout,
    /// A 32-byte big-endian scalar failed to decode as a BN254 scalar-field
    /// element (e.g. its value is not below the field modulus).
    NonCanonicalScalar,
    /// A 64-byte point encoding is not all zeros and does not decode to a
    /// valid BN254 G1 point.
    NotOnCurve,
}

impl core::fmt::Display for MsmError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            MsmError::InvalidInputLayout => "input length does not match the declared MSM layout",
            MsmError::NonCanonicalScalar => "scalar is not a canonical BN254 scalar-field element",
            MsmError::NotOnCurve => "point is not a valid BN254 G1 point",
        };
        f.write_str(msg)
    }
}

// `std::error::Error` only exists when the crate is built with `std`.
#[cfg(feature = "std")]
impl std::error::Error for MsmError {}
49
50pub fn alt_bn128_g1_msm_be(input: &[u8]) -> Result<[u8; 64], MsmError> {
68 let (scalars, points) = parse_msm_input(input)?;
69 let result_proj: G1Projective = pippenger_msm(&scalars, &points);
70 Ok(serialise_g1_be(result_proj.into_affine()))
71}
72
73pub fn naive_msm_be(input: &[u8]) -> Result<[u8; 64], MsmError> {
80 let (scalars, points) = parse_msm_input(input)?;
81 let mut acc = G1Projective::zero();
82 for (s, p) in scalars.iter().zip(points.iter()) {
83 if s.is_zero() || p.is_zero() {
84 continue;
85 }
86 acc += *p * *s;
89 }
90 Ok(serialise_g1_be(acc.into_affine()))
91}
92
93fn pippenger_msm(scalars: &[Fr], points: &[G1Affine]) -> G1Projective {
113 let n = scalars.len();
114 debug_assert_eq!(points.len(), n);
115 if n == 0 {
116 return G1Projective::zero();
117 }
118
119 let c = ln_without_floats(n) + 2;
120
121 let scalars_bits: Vec<Vec<bool>> = scalars
124 .iter()
125 .map(|s| s.into_bigint().to_bits_le())
126 .collect();
127
128 let num_bits = Fr::MODULUS_BIT_SIZE as usize;
129 let num_windows = (num_bits + c - 1) / c;
130
131 let mut window_sums = Vec::with_capacity(num_windows);
132 for w in 0..num_windows {
133 let bit_start = w * c;
134 let bit_end = (bit_start + c).min(num_bits);
135
136 let bucket_count = 1usize << c;
137 let mut buckets = alloc::vec![G1Projective::zero(); bucket_count];
138
139 for (s_bits, p) in scalars_bits.iter().zip(points.iter()) {
140 let mut idx: usize = 0;
142 for b in (bit_start..bit_end).rev() {
143 idx <<= 1;
144 if *s_bits.get(b).unwrap_or(&false) {
145 idx |= 1;
146 }
147 }
148 if idx > 0 && !p.is_zero() {
149 buckets[idx] += p;
150 }
151 }
152
153 let mut running = G1Projective::zero();
159 let mut window_sum = G1Projective::zero();
160 for bucket in buckets[1..].iter().rev() {
161 running += bucket;
162 window_sum += running;
163 }
164 window_sums.push(window_sum);
165 }
166
167 let mut total = G1Projective::zero();
171 for &window_sum in window_sums.iter().rev() {
172 for _ in 0..c {
173 total.double_in_place();
174 }
175 total += window_sum;
176 }
177 total
178}
179
/// Approximates the natural logarithm `ln(n)` using only integer arithmetic.
/// The result is used to choose the Pippenger window width.
///
/// Computes `floor(log2(n)) * 69 / 100` (since `ln(2) ≈ 0.69`), clamped to a
/// minimum of 1. Returns 1 for `n <= 1`, where the logarithm is 0 or
/// undefined, so callers always get a non-zero value.
///
/// Using the true `log2` here would over-size the window: the caller's
/// `2^(result + 2)` buckets would then be about `4n`, and walking them would
/// cost more than accumulating the `n` points.
#[inline]
fn ln_without_floats(n: usize) -> usize {
    if n <= 1 {
        return 1;
    }
    // floor(log2(n)) for n >= 2, from the position of the highest set bit.
    let log2 = (usize::BITS - 1 - n.leading_zeros()) as usize;
    (log2 * 69 / 100).max(1)
}
194
195fn parse_msm_input(input: &[u8]) -> Result<(Vec<Fr>, Vec<G1Affine>), MsmError> {
200 if input.len() < 4 {
201 return Err(MsmError::InvalidInputLayout);
202 }
203 let n = u32::from_le_bytes([input[0], input[1], input[2], input[3]]) as usize;
204 let body = &input[4..];
205 let want = n.checked_mul(96).ok_or(MsmError::InvalidInputLayout)?;
206 if body.len() != want {
207 return Err(MsmError::InvalidInputLayout);
208 }
209
210 let mut scalars = Vec::with_capacity(n);
211 let mut points = Vec::with_capacity(n);
212
213 let scalars_end = n * 32;
214 let scalars_raw = &body[..scalars_end];
215 let points_raw = &body[scalars_end..];
216
217 for i in 0..n {
218 let mut be = [0u8; 32];
219 be.copy_from_slice(&scalars_raw[i * 32..(i + 1) * 32]);
220 scalars.push(parse_scalar_be(&be)?);
221 }
222 for i in 0..n {
223 let mut be = [0u8; 64];
224 be.copy_from_slice(&points_raw[i * 64..(i + 1) * 64]);
225 points.push(parse_g1_be(&be)?);
226 }
227 Ok((scalars, points))
228}
229
230fn parse_scalar_be(bytes: &[u8; 32]) -> Result<Fr, MsmError> {
231 let mut le = *bytes;
237 le.reverse();
238 Fr::deserialize_compressed(&le[..]).map_err(|_| MsmError::NonCanonicalScalar)
239}
240
241fn parse_g1_be(bytes: &[u8; 64]) -> Result<G1Affine, MsmError> {
242 if bytes == &[0u8; 64] {
243 return Ok(G1Affine::zero()); }
245 let mut le = [0u8; 64];
246 for i in 0..32 {
247 le[i] = bytes[31 - i];
248 le[32 + i] = bytes[63 - i];
249 }
250 G1Affine::deserialize_with_mode(&le[..], Compress::No, Validate::Yes)
251 .map_err(|_| MsmError::NotOnCurve)
252}
253
254fn serialise_g1_be(p: G1Affine) -> [u8; 64] {
255 if p.is_zero() {
256 return [0u8; 64];
257 }
258 let (x, y) = p.xy().expect("non-identity G1 point must have coordinates");
259 let mut out = [0u8; 64];
260 let mut x_le = [0u8; 32];
261 let mut y_le = [0u8; 32];
262 x.serialize_with_mode(&mut x_le[..], Compress::No).expect("Fq serialisation");
263 y.serialize_with_mode(&mut y_le[..], Compress::No).expect("Fq serialisation");
264 for i in 0..32 {
265 out[i] = x_le[31 - i];
266 out[32 + i] = y_le[31 - i];
267 }
268 out
269}
270
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use ark_std::UniformRand;

    /// Encodes `scalars`/`points` in the MSM input layout: a little-endian
    /// `u32` count, every scalar as 32 big-endian bytes, then every point as
    /// 64 big-endian bytes.
    fn build_input(scalars: &[Fr], points: &[G1Affine]) -> Vec<u8> {
        let n = scalars.len();
        assert_eq!(points.len(), n);
        let mut buf = Vec::with_capacity(4 + n * 96);
        buf.extend_from_slice(&(n as u32).to_le_bytes());
        for s in scalars {
            let mut le = [0u8; 32];
            s.serialize_with_mode(&mut le[..], Compress::No).unwrap();
            let mut be = le;
            be.reverse();
            buf.extend_from_slice(&be);
        }
        for p in points {
            buf.extend_from_slice(&serialise_g1_be(*p));
        }
        buf
    }

    /// Builds a one-pair input from raw big-endian scalar and point bytes.
    fn single_pair_input(scalar: &[u8; 32], point: &[u8; 64]) -> Vec<u8> {
        let mut input = (1u32).to_le_bytes().to_vec();
        input.extend_from_slice(scalar);
        input.extend_from_slice(point);
        input
    }

    /// Returns a random scalar and an independent random multiple of the
    /// generator.
    fn rand_scalar_point_pair(rng: &mut impl ark_std::rand::Rng) -> (Fr, G1Affine) {
        let g = G1Projective::generator();
        let r = Fr::rand(rng);
        let p: G1Affine = (g * r).into_affine();
        let s = Fr::rand(rng);
        (s, p)
    }

    /// Builds `n` seeded random pairs and asserts the naive and Pippenger
    /// implementations agree byte-for-byte.
    fn cross_check_n(n: usize, seed: u64) {
        use ark_std::rand::SeedableRng;
        let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(seed);
        let (mut scalars, mut points) = (Vec::new(), Vec::new());
        for _ in 0..n {
            let (s, p) = rand_scalar_point_pair(&mut rng);
            scalars.push(s);
            points.push(p);
        }
        let input = build_input(&scalars, &points);

        let naive = naive_msm_be(&input).unwrap();
        let pipp = alt_bn128_g1_msm_be(&input).unwrap();
        assert_eq!(naive, pipp,
            "naive vs pippenger disagree at n={n}\nnaive = 0x{}\npipp = 0x{}",
            hex::encode(naive), hex::encode(pipp));
    }

    #[test] fn n0_returns_identity() {
        let input = (0u32).to_le_bytes().to_vec();
        let r = alt_bn128_g1_msm_be(&input).unwrap();
        assert_eq!(r, [0u8; 64]);
    }

    #[test] fn n1_matches_scalar_mul() { cross_check_n(1, 1); }
    #[test] fn n2() { cross_check_n(2, 2); }
    #[test] fn n3() { cross_check_n(3, 3); }
    #[test] fn n4() { cross_check_n(4, 4); }
    #[test] fn n7() { cross_check_n(7, 7); }
    #[test] fn n8() { cross_check_n(8, 8); }
    #[test] fn n16() { cross_check_n(16, 16); }
    #[test] fn n32() { cross_check_n(32, 32); }
    #[test] fn n33() { cross_check_n(33, 33); }
    #[test] fn n64() { cross_check_n(64, 64); }

    #[test] fn rejects_invalid_layout() {
        let input = (1u32).to_le_bytes().to_vec();
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap_err(), MsmError::InvalidInputLayout);
    }

    #[test] fn rejects_short_header() {
        assert_eq!(alt_bn128_g1_msm_be(&[0u8; 3]).unwrap_err(), MsmError::InvalidInputLayout);
    }

    #[test] fn rejects_trailing_bytes() {
        let mut input = (0u32).to_le_bytes().to_vec();
        input.push(0xAA);
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap_err(), MsmError::InvalidInputLayout);
    }

    #[test] fn rejects_non_canonical_scalar() {
        // 2^256 - 1 is far above the scalar-field modulus; the point is the
        // identity so that only the scalar can fail.
        let input = single_pair_input(&[0xFFu8; 32], &[0u8; 64]);
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap_err(), MsmError::NonCanonicalScalar);
    }

    #[test] fn rejects_point_not_on_curve() {
        let mut scalar = [0u8; 32];
        scalar[31] = 1;
        // (1, 1) does not satisfy y^2 = x^3 + 3.
        let mut point = [0u8; 64];
        point[31] = 1;
        point[63] = 1;
        let input = single_pair_input(&scalar, &point);
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap_err(), MsmError::NotOnCurve);
    }

    #[test] fn skips_zero_scalar() {
        use ark_std::rand::SeedableRng;
        let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(99);
        let (_, p) = rand_scalar_point_pair(&mut rng);
        let scalars = vec![Fr::ZERO];
        let points = vec![p];
        let input = build_input(&scalars, &points);
        let r = alt_bn128_g1_msm_be(&input).unwrap();
        assert_eq!(r, [0u8; 64], "0·P should be identity");
    }

    #[test] fn skips_identity_point() {
        let scalars = vec![Fr::from(7u64)];
        let points = vec![G1Affine::zero()];
        let input = build_input(&scalars, &points);
        let r = alt_bn128_g1_msm_be(&input).unwrap();
        assert_eq!(r, [0u8; 64], "s·O should be identity");
    }
}