1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
//! Accumulate all MSMs into a giant MSM and verify them all at the end to [amortize costs](crate::notes::optimizations#msm-accumulator).
//!
//! <center><img width="70%" src="https://github.com/asn-d6/curdleproofs/raw/backup/doc/images/accumulator.png"></img></center>
//!
//! Code adapted from [jellyfish](https://github.com/EspressoSystems/jellyfish/blob/main/plonk/src/proof_system/structs.rs#L865).

#![allow(non_snake_case)]

use ark_bls12_381::{Fr, G1Affine, G1Projective};
use ark_ec::group::Group;
use ark_std::rand::RngCore;
use ark_std::{UniformRand, Zero};

use hashbrown::HashMap;

use crate::errors::ProofError;
use crate::util::msm;

/// An MSM accumulator object.
///
/// Many checks of the form `C = <x, V>` are folded into one: each check is
/// scaled by a random factor, the scaled left-hand sides are summed into a
/// single point, and the scaled scalars are summed per base so that one MSM
/// over all distinct bases can be compared against that point at the end.
#[derive(Clone, Default)]
pub struct MsmAccumulator {
    /// For each distinct base point, the sum of the randomly-scaled scalars
    /// that multiply it across all accumulated checks.
    base_scalar_map: HashMap<G1Affine, Fr>,
    /// Sum of the randomly-scaled left-hand-side points `a * C` of all
    /// accumulated checks.
    A_c: G1Projective,
}

impl MsmAccumulator {
    /// Build an empty accumulator.
    ///
    /// The running commitment starts at the group identity and no base has an
    /// associated scalar yet.
    pub(crate) fn new() -> Self {
        let A_c = G1Projective::zero();
        let base_scalar_map = HashMap::new();
        Self { A_c, base_scalar_map }
    }

    /// Accumulate the check $C = \bm{x} \times \bm{V}$
    ///
    /// A fresh random factor `a` is drawn from `rng`. `a * C` is added to the
    /// running commitment. For every pair `(x_i, V_i)`, `a * x_i` is added to
    /// the scalar stored for base `V_i`, which starts at zero for a base not
    /// seen before. Bases shared by several checks therefore collapse into a
    /// single MSM term.
    ///
    /// Pairs are formed with `zip`, so if `vec_x` and `vec_V` differ in length
    /// the surplus elements of the longer slice are ignored.
    pub fn accumulate_check<T: RngCore>(
        &mut self,
        C: &G1Projective,
        vec_x: &[Fr],
        vec_V: &[G1Affine],
        rng: &mut T,
    ) {
        // `a` in the paper: the weight of this check in the random linear combination.
        let a = Fr::rand(rng);
        self.A_c += C.mul(&a);

        vec_x.iter().zip(vec_V).for_each(|(x_i, V_i)| {
            *self.base_scalar_map.entry(*V_i).or_insert_with(Fr::zero) += a * x_i;
        });
    }

    /// Verify all checks accumulated on this MSM accumulator with a single MSM.
    ///
    /// Consumes the accumulator. It computes the MSM of every stored base with
    /// its merged scalar and compares the result with the running commitment.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::VerificationError`] when the MSM differs from the
    /// accumulated commitment. This implies at least one accumulated check
    /// does not hold.
    pub fn verify(self) -> Result<(), ProofError> {
        let (bases, scalars): (Vec<G1Affine>, Vec<Fr>) = self
            .base_scalar_map
            .iter()
            .map(|(base, scalar)| (*base, *scalar))
            .unzip();

        let combined = msm(&bases, &scalars);
        if !(combined - self.A_c).is_zero() {
            return Err(ProofError::VerificationError);
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ark_ec::ProjectiveCurve;
    use ark_std::rand::{rngs::StdRng, SeedableRng};
    use ark_std::UniformRand;
    use core::iter;

    use crate::util::generate_blinders;

    /// Sample an honest instance of `C = <scalars, bases>`: `n` random bases,
    /// `n` random scalars, and their MSM `C`. Returned as `(C, scalars, bases)`.
    fn honest_instance(rng: &mut StdRng, n: usize) -> (G1Projective, Vec<Fr>, Vec<G1Affine>) {
        let bases: Vec<G1Affine> =
            iter::repeat_with(|| G1Projective::rand(&mut *rng).into_affine())
                .take(n)
                .collect();
        let scalars = generate_blinders(rng, n);
        let commitment = msm(&bases, &scalars);
        (commitment, scalars, bases)
    }

    #[test]
    fn test_msm_accumulator() {
        let mut rng = StdRng::seed_from_u64(0u64);
        let n = 4;

        // Check that both $B = <vec_B, vec_a>$ and $D = <vec_D, vec_c>$ hold.
        let (B, vec_a, vec_B) = honest_instance(&mut rng, n);
        let (D, vec_c, vec_D) = honest_instance(&mut rng, n);

        let mut msm_accumulator = MsmAccumulator::new();

        msm_accumulator.accumulate_check(&B, &vec_a, &vec_B, &mut rng);
        msm_accumulator.accumulate_check(&D, &vec_c, &vec_D, &mut rng);

        assert!(msm_accumulator.verify().is_ok());
    }

    #[test]
    fn test_msm_accumulator_rejects_false_check() {
        let mut rng = StdRng::seed_from_u64(1u64);
        let n = 4;

        let (B, vec_a, vec_B) = honest_instance(&mut rng, n);
        let (_, vec_c, vec_D) = honest_instance(&mut rng, n);

        // The first check is honest; the second wrongly claims `B = <vec_D, vec_c>`.
        // A single false check must make the whole batch fail.
        let mut msm_accumulator = MsmAccumulator::new();

        msm_accumulator.accumulate_check(&B, &vec_a, &vec_B, &mut rng);
        msm_accumulator.accumulate_check(&B, &vec_c, &vec_D, &mut rng);

        assert!(msm_accumulator.verify().is_err());
    }

    #[test]
    fn test_msm_accumulator_merges_shared_bases() {
        let mut rng = StdRng::seed_from_u64(2u64);
        let n = 4;

        // Two different honest checks over the *same* bases, so the second
        // accumulation updates existing map entries instead of inserting new ones.
        let (B, vec_a, vec_B) = honest_instance(&mut rng, n);
        let vec_a2 = generate_blinders(&mut rng, n);
        let B2 = msm(&vec_B, &vec_a2);

        let mut msm_accumulator = MsmAccumulator::new();

        msm_accumulator.accumulate_check(&B, &vec_a, &vec_B, &mut rng);
        msm_accumulator.accumulate_check(&B2, &vec_a2, &vec_B, &mut rng);

        assert!(msm_accumulator.verify().is_ok());
    }
}