Skip to main content

g1_msm_ref/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![forbid(unsafe_code)]
3
4//! `alt_bn128_g1_msm` reference implementation
5//!
6//! This crate is the proposed reference implementation of the
//! `alt_bn128_g1_msm` SIMD: Σᵢ scalarsᵢ · pointsᵢ over BN254 G1, computed
//! by a Pippenger (bucket-method, unsigned fixed-window) multi-scalar
//! multiplication. It is `no_std`-
9//! compatible so the same code path can run on host (for off-chain
10//! verifiers) and inside agave's syscall bridge.
11//!
12//! Two entrypoints:
13//!
14//! * [`alt_bn128_g1_msm_be`] — proposed syscall surface. Takes the wire
//!   byte layout `[n: u32 LE | scalar₀ | … | scalarₙ₋₁ | point₀ | … | pointₙ₋₁]`
//!   (all scalars first, then all points) and returns a 64-byte BE G1Affine.
17//! * [`naive_msm_be`] — same surface, but implemented as `n` sequential
18//!   scalar multiplications + additions. Serves as the *baseline* for
19//!   benchmarks: it is what an on-chain verifier ends up doing today
20//!   when it can only call `alt_bn128_g1_multiplication_be` per point.
21//!
//! Both functions accept identity points and zero scalars, skipping their
//! contribution rather than erroring, to match the existing `alt_bn128_*`
//! syscalls' semantics on the empty/identity input edge cases.
26
27extern crate alloc;
28
29use alloc::vec::Vec;
30use ark_bn254::{Fr, G1Affine, G1Projective};
31use ark_ec::{AffineRepr, CurveGroup};
32use ark_ff::{AdditiveGroup, BigInteger, PrimeField, Zero};
33use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
34
35// ---------------------------------------------------------------------------
36// Public surface — the two implementations under comparison.
37// ---------------------------------------------------------------------------
38
/// Error returned by [`alt_bn128_g1_msm_be`] and [`naive_msm_be`].
///
/// The variants are intended to mirror the error classes of the existing
/// `alt_bn128_*` syscalls.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MsmError {
    /// The input is shorter than the 4-byte header, or its length is not
    /// exactly `4 + 96 · n` for the `n` stated in that header.
    InvalidInputLayout,
    /// A scalar encodes an integer that is not a canonical `Fr` element
    /// (rejected rather than reduced, to match groth16-solana).
    NonCanonicalScalar,
    /// A 64-byte G1 encoding was rejected: e.g. a coordinate is not a
    /// canonical `Fq` element, or `(x, y)` does not satisfy the curve equation.
    NotOnCurve,
}

impl core::fmt::Display for MsmError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            MsmError::InvalidInputLayout => "invalid MSM input layout",
            MsmError::NonCanonicalScalar => "non-canonical scalar",
            MsmError::NotOnCurve => "G1 point not on curve",
        };
        f.write_str(msg)
    }
}
49
50/// Wire format expected by the proposed `alt_bn128_g1_msm_be` syscall.
51///
52/// ```text
53/// [0..4]            : n (u32 little-endian) — number of (scalar, point) pairs
54/// [4..4+n*32]       : scalars      , each 32-byte BE Fr (one after another)
55/// [4+n*32..4+n*96]  : G1 points    , each 64-byte BE G1Affine (x ‖ y)
56/// ```
57///
58/// **Note on layout choice**: scalars and points are *grouped* (all scalars
59/// then all points) rather than *interleaved*. The grouped layout matches
60/// arkworks's `VariableBaseMSM::msm` API directly and avoids a reorder copy
61/// inside the syscall implementation. EIP-2537 took the interleaved approach;
62/// we deliberately diverge for ergonomics (and to follow the Solana SIMD-0284
63/// little-endian-friendly conventions that downstream programs already use).
64///
65/// Returns the 64-byte BE G1Affine encoding of `Σᵢ scalarsᵢ · pointsᵢ`, or
66/// the zero G1 element (64 zero bytes) if the result is the curve identity.
67pub fn alt_bn128_g1_msm_be(input: &[u8]) -> Result<[u8; 64], MsmError> {
68    let (scalars, points) = parse_msm_input(input)?;
69    let result_proj: G1Projective = pippenger_msm(&scalars, &points);
70    Ok(serialise_g1_be(result_proj.into_affine()))
71}
72
73/// Baseline: sequential `Σᵢ scalarᵢ · pointᵢ` via per-point scalar mul + add.
74///
75/// This is the verifier's *current* on-chain code path when it has only
76/// `alt_bn128_g1_multiplication_be` + `alt_bn128_g1_addition_be` available.
77/// We expose it here so the benchmark grid is apples-to-apples: same wire
78/// format on the same byte input, two different inner algorithms.
79pub fn naive_msm_be(input: &[u8]) -> Result<[u8; 64], MsmError> {
80    let (scalars, points) = parse_msm_input(input)?;
81    let mut acc = G1Projective::zero();
82    for (s, p) in scalars.iter().zip(points.iter()) {
83        if s.is_zero() || p.is_zero() {
84            continue;
85        }
86        // Same arithmetic as the existing alt_bn128_g1_multiplication_be syscall:
87        // produce s·P, then add into the running accumulator.
88        acc += *p * *s;
89    }
90    Ok(serialise_g1_be(acc.into_affine()))
91}
92
93// ---------------------------------------------------------------------------
94// Pippenger MSM — pure-Rust reference. Identical strategy to halo2curves's
95// `multiexp_serial`, but using arkworks types so the surface matches the
96// rest of the verifier crate.
97// ---------------------------------------------------------------------------
98
99/// Pippenger window-NAF MSM over BN254 G1.
100///
101/// **Algorithm** (Pippenger, with window size c chosen as a function of n):
102/// 1. For each window of `c` bits across the 254-bit Fr, build `2^c − 1`
103///    buckets indexed by the partial scalar value.
104/// 2. Add each `pointᵢ` into bucket `(scalarᵢ shifted >> window)` for that
105///    window's contribution.
106/// 3. Sum the buckets weighted by their index → window contribution.
107/// 4. Combine windows by `(c · k)`-bit doublings + additions.
108///
109/// Window `c` is picked from a heuristic that minimises the total operation
110/// count `n + (2^c − 1) + ⌈254/c⌉ · 2^c` for each bench-grid n. This matches
111/// the constant table inside arkworks `VariableBaseMSM::msm`.
112fn pippenger_msm(scalars: &[Fr], points: &[G1Affine]) -> G1Projective {
113    let n = scalars.len();
114    debug_assert_eq!(points.len(), n);
115    if n == 0 {
116        return G1Projective::zero();
117    }
118
119    let c = ln_without_floats(n) + 2;
120
121    // Repr is little-endian limbs; convert each scalar to a u64 array we can
122    // window-slice. BigInt::to_bits_le() returns the bits LSB-first.
123    let scalars_bits: Vec<Vec<bool>> = scalars
124        .iter()
125        .map(|s| s.into_bigint().to_bits_le())
126        .collect();
127
128    let num_bits = Fr::MODULUS_BIT_SIZE as usize;
129    let num_windows = (num_bits + c - 1) / c;
130
131    let mut window_sums = Vec::with_capacity(num_windows);
132    for w in 0..num_windows {
133        let bit_start = w * c;
134        let bit_end = (bit_start + c).min(num_bits);
135
136        let bucket_count = 1usize << c;
137        let mut buckets = alloc::vec![G1Projective::zero(); bucket_count];
138
139        for (s_bits, p) in scalars_bits.iter().zip(points.iter()) {
140            // Read this window's c bits (treated as an unsigned integer).
141            let mut idx: usize = 0;
142            for b in (bit_start..bit_end).rev() {
143                idx <<= 1;
144                if *s_bits.get(b).unwrap_or(&false) {
145                    idx |= 1;
146                }
147            }
148            if idx > 0 && !p.is_zero() {
149                buckets[idx] += p;
150            }
151        }
152
153        // Bucket-sum weighted by index: out = Σ_{k=1..bucket_count-1} k · buckets[k]
154        // Computed by a running prefix sum from k = N-1 down to k = 1.
155        // Iterating to bucket[0] would double-count the lower buckets, so we
156        // skip index 0 (which is always identity anyway since the input loop
157        // never writes to it).
158        let mut running = G1Projective::zero();
159        let mut window_sum = G1Projective::zero();
160        for bucket in buckets[1..].iter().rev() {
161            running += bucket;
162            window_sum += running;
163        }
164        window_sums.push(window_sum);
165    }
166
167    // Combine windows: lower windows contribute first, then we double `c`
168    // times and add the next window. Walking from highest to lowest is
169    // equivalent and matches the standard exposition.
170    let mut total = G1Projective::zero();
171    for &window_sum in window_sums.iter().rev() {
172        for _ in 0..c {
173            total.double_in_place();
174        }
175        total += window_sum;
176    }
177    total
178}
179
/// Integer approximation of `ln(n)`, matching arkworks'
/// `ark_ec::scalar_mul::ln_without_floats`: `⌈log₂ n⌉ · 69 / 100`
/// (since `ln 2 ≈ 0.69`), where `⌈log₂ n⌉` is taken as 0 for `n ≤ 1`.
///
/// It is only used to choose the Pippenger window width, so it affects cost
/// but never the MSM result.
#[inline]
fn ln_without_floats(n: usize) -> usize {
    // ⌈log₂ n⌉ for n ≥ 2: bit length of (n − 1).
    let ceil_log2 = if n <= 1 {
        0
    } else {
        (usize::BITS - (n - 1).leading_zeros()) as usize
    };
    ceil_log2 * 69 / 100
}
194
195// ---------------------------------------------------------------------------
196// Wire-format parsing (shared by both entry points).
197// ---------------------------------------------------------------------------
198
199fn parse_msm_input(input: &[u8]) -> Result<(Vec<Fr>, Vec<G1Affine>), MsmError> {
200    if input.len() < 4 {
201        return Err(MsmError::InvalidInputLayout);
202    }
203    let n = u32::from_le_bytes([input[0], input[1], input[2], input[3]]) as usize;
204    let body = &input[4..];
205    let want = n.checked_mul(96).ok_or(MsmError::InvalidInputLayout)?;
206    if body.len() != want {
207        return Err(MsmError::InvalidInputLayout);
208    }
209
210    let mut scalars = Vec::with_capacity(n);
211    let mut points = Vec::with_capacity(n);
212
213    let scalars_end = n * 32;
214    let scalars_raw = &body[..scalars_end];
215    let points_raw  = &body[scalars_end..];
216
217    for i in 0..n {
218        let mut be = [0u8; 32];
219        be.copy_from_slice(&scalars_raw[i * 32..(i + 1) * 32]);
220        scalars.push(parse_scalar_be(&be)?);
221    }
222    for i in 0..n {
223        let mut be = [0u8; 64];
224        be.copy_from_slice(&points_raw[i * 64..(i + 1) * 64]);
225        points.push(parse_g1_be(&be)?);
226    }
227    Ok((scalars, points))
228}
229
230fn parse_scalar_be(bytes: &[u8; 32]) -> Result<Fr, MsmError> {
231    // Big-endian → little-endian, then the canonical-bound check inside
232    // Fr::from_le_bytes_mod_order is reduce-mod-p (lossy). For strict
233    // canonical-form rejection (matching groth16-solana / our verifier),
234    // try CanonicalDeserialize::deserialize_compressed which fails on
235    // non-canonical encodings.
236    let mut le = *bytes;
237    le.reverse();
238    Fr::deserialize_compressed(&le[..]).map_err(|_| MsmError::NonCanonicalScalar)
239}
240
241fn parse_g1_be(bytes: &[u8; 64]) -> Result<G1Affine, MsmError> {
242    if bytes == &[0u8; 64] {
243        return Ok(G1Affine::zero()); // identity
244    }
245    let mut le = [0u8; 64];
246    for i in 0..32 {
247        le[i] = bytes[31 - i];
248        le[32 + i] = bytes[63 - i];
249    }
250    G1Affine::deserialize_with_mode(&le[..], Compress::No, Validate::Yes)
251        .map_err(|_| MsmError::NotOnCurve)
252}
253
254fn serialise_g1_be(p: G1Affine) -> [u8; 64] {
255    if p.is_zero() {
256        return [0u8; 64];
257    }
258    let (x, y) = p.xy().expect("non-identity G1 point must have coordinates");
259    let mut out = [0u8; 64];
260    let mut x_le = [0u8; 32];
261    let mut y_le = [0u8; 32];
262    x.serialize_with_mode(&mut x_le[..], Compress::No).expect("Fq serialisation");
263    y.serialize_with_mode(&mut y_le[..], Compress::No).expect("Fq serialisation");
264    for i in 0..32 {
265        out[i] = x_le[31 - i];
266        out[32 + i] = y_le[31 - i];
267    }
268    out
269}
270
271// ---------------------------------------------------------------------------
// Tests: naive == Pippenger cross-checks for n ∈ {1, 2, 4, 8, 16, 32, 64},
// plus n = 0, zero-scalar / identity-point, and input-rejection edge cases.
273// ---------------------------------------------------------------------------
274
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use ark_std::UniformRand;

    /// Encodes `(scalars, points)` in the grouped wire layout: `n` as u32 LE,
    /// then every scalar as 32 BE bytes, then every point as 64 BE bytes.
    fn build_input(scalars: &[Fr], points: &[G1Affine]) -> Vec<u8> {
        let n = scalars.len();
        assert_eq!(points.len(), n);
        let mut buf = Vec::with_capacity(4 + n * 96);
        buf.extend_from_slice(&(n as u32).to_le_bytes());
        for s in scalars {
            let mut be = [0u8; 32];
            s.serialize_with_mode(&mut be[..], Compress::No).unwrap();
            be.reverse(); // arkworks writes little-endian
            buf.extend_from_slice(&be);
        }
        for p in points {
            buf.extend_from_slice(&serialise_g1_be(*p));
        }
        buf
    }

    /// Wraps one already-encoded (scalar, point) pair into an n = 1 input.
    fn single_pair_input(scalar_be: &[u8; 32], point_be: &[u8; 64]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(4 + 96);
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(scalar_be);
        buf.extend_from_slice(point_be);
        buf
    }

    /// A random scalar together with a random multiple of the G1 generator.
    fn rand_scalar_point_pair(rng: &mut impl ark_std::rand::Rng) -> (Fr, G1Affine) {
        // NOTE(review): `G1Projective::generator()` is a `PrimeGroup` method,
        // and that trait is not imported here. `AffineRepr::generator()` is in
        // scope via `super::*` and gives the same base point.
        let r = Fr::rand(rng);
        let p: G1Affine = (G1Affine::generator() * r).into_affine();
        let s = Fr::rand(rng);
        (s, p)
    }

    fn cross_check_n(n: usize, seed: u64) {
        use ark_std::rand::SeedableRng;
        let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(seed);
        let (mut scalars, mut points) = (Vec::new(), Vec::new());
        for _ in 0..n {
            let (s, p) = rand_scalar_point_pair(&mut rng);
            scalars.push(s);
            points.push(p);
        }
        let input = build_input(&scalars, &points);

        let naive = naive_msm_be(&input).unwrap();
        let pipp  = alt_bn128_g1_msm_be(&input).unwrap();
        assert_eq!(naive, pipp,
            "naive vs pippenger disagree at n={n}\nnaive  = 0x{}\npipp   = 0x{}",
            hex::encode(naive), hex::encode(pipp));
    }

    #[test] fn n0_returns_identity() {
        let input = (0u32).to_le_bytes().to_vec();
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap(), [0u8; 64]);
        assert_eq!(naive_msm_be(&input).unwrap(), [0u8; 64]);
    }

    #[test] fn n1_matches_scalar_mul()    { cross_check_n(1,  1); }
    #[test] fn n2()                        { cross_check_n(2,  2); }
    #[test] fn n4()                        { cross_check_n(4,  4); }
    #[test] fn n8()                        { cross_check_n(8,  8); }
    #[test] fn n16()                       { cross_check_n(16, 16); }
    #[test] fn n32()                       { cross_check_n(32, 32); }
    #[test] fn n64()                       { cross_check_n(64, 64); }

    #[test] fn rejects_invalid_layout() {
        let input = (1u32).to_le_bytes().to_vec(); // claims n=1 but no body
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap_err(), MsmError::InvalidInputLayout);
    }

    #[test] fn rejects_short_header() {
        assert_eq!(alt_bn128_g1_msm_be(&[0u8; 3]).unwrap_err(), MsmError::InvalidInputLayout);
        assert_eq!(naive_msm_be(&[]).unwrap_err(), MsmError::InvalidInputLayout);
    }

    #[test] fn rejects_trailing_byte() {
        let mut input = (0u32).to_le_bytes().to_vec();
        input.push(0); // n = 0 but one extra body byte
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap_err(), MsmError::InvalidInputLayout);
    }

    #[test] fn rejects_non_canonical_scalar() {
        // 0xFF…FF is far above r: it must be rejected, not reduced.
        let input = single_pair_input(&[0xFF; 32], &serialise_g1_be(G1Affine::generator()));
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap_err(), MsmError::NonCanonicalScalar);
        assert_eq!(naive_msm_be(&input).unwrap_err(), MsmError::NonCanonicalScalar);
    }

    #[test] fn rejects_off_curve_point() {
        // (1, 3) does not satisfy y² = x³ + 3 (the generator is (1, 2)).
        let mut point = [0u8; 64];
        point[31] = 1;
        point[63] = 3;
        let mut scalar = [0u8; 32];
        scalar[31] = 1;
        let input = single_pair_input(&scalar, &point);
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap_err(), MsmError::NotOnCurve);
        assert_eq!(naive_msm_be(&input).unwrap_err(), MsmError::NotOnCurve);
    }

    #[test] fn skips_zero_scalar() {
        use ark_std::rand::SeedableRng;
        let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(99);
        let (_, p) = rand_scalar_point_pair(&mut rng);
        let scalars = vec![Fr::ZERO];
        let points  = vec![p];
        let input = build_input(&scalars, &points);
        let r = alt_bn128_g1_msm_be(&input).unwrap();
        assert_eq!(r, [0u8; 64], "0·P should be identity");
    }

    #[test] fn skips_identity_point() {
        let scalars = vec![Fr::from(7u64)];
        let points  = vec![G1Affine::zero()];
        let input = build_input(&scalars, &points);
        let r = alt_bn128_g1_msm_be(&input).unwrap();
        assert_eq!(r, [0u8; 64], "s·O should be identity");
    }

    #[test] fn zero_and_identity_terms_are_skipped_consistently() {
        use ark_std::rand::SeedableRng;
        let mut rng = ark_std::rand::rngs::StdRng::seed_from_u64(7);
        let (s0, p0) = rand_scalar_point_pair(&mut rng);
        let (s1, p1) = rand_scalar_point_pair(&mut rng);
        let scalars = vec![s0, Fr::ZERO, s1];
        let points  = vec![p0, p1, G1Affine::zero()];
        let input = build_input(&scalars, &points);
        // Only the first term survives: 0·p1 and s1·O both vanish.
        let expected = serialise_g1_be((p0 * s0).into_affine());
        assert_eq!(alt_bn128_g1_msm_be(&input).unwrap(), expected);
        assert_eq!(naive_msm_be(&input).unwrap(), expected);
    }
}