rings_snark/snark/
mod.rs

//! Implementation of Rings Snark
//! =============================
3#![allow(clippy::type_complexity)]
4mod impls;
5mod utils;
6use std::ops::Deref;
7use std::ops::DerefMut;
8
9use serde::Deserialize;
10use serde::Serialize;
11use utils::deserialize_forward;
12use utils::serialize_forward;
13
14use crate::circuit::Circuit;
15use crate::error::Result;
16use crate::prelude::nova;
17use crate::prelude::nova::traits::circuit::TrivialCircuit;
18use crate::prelude::nova::traits::snark::RelaxedR1CSSNARKTrait;
19use crate::prelude::nova::traits::Engine;
20use crate::prelude::nova::RecursiveSNARK;
21
/// Rings SNARK: a wrapper around nova's [`RecursiveSNARK`], specialised to a rings
/// [`Circuit`] as the primary step circuit and a [`TrivialCircuit`] as the secondary one.
///
/// `E1` and `E2` must form a cycle: each engine's base field is the other's scalar field.
/// The compressed form of a proof is represented separately by [`CompressedSNARK`].
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SNARK<E1, E2>
where
    E1: Engine<Base = <E2 as Engine>::Scalar>,
    E2: Engine<Base = <E1 as Engine>::Scalar>,
{
    /// The wrapped nova recursive SNARK.
    ///
    /// `#[serde(flatten)]` serializes its fields at the top level of this struct
    /// rather than under an `inner` key.
    #[serde(flatten)]
    pub inner: RecursiveSNARK<E1, E2, Circuit<<E1 as Engine>::Scalar>, TrivialCircuit<E2::Scalar>>,
}
33
/// Compressed SNARK: a wrapper around nova's `CompressedSNARK` for a rings [`Circuit`]
/// primary and a [`TrivialCircuit`] secondary.
///
/// `S1` and `S2` are the relaxed-R1CS SNARKs used for the primary (`E1`) and
/// secondary (`E2`) sides respectively.
#[derive(Serialize, Deserialize, Clone)]
pub struct CompressedSNARK<E1, E2, S1, S2>
where
    E1: Engine<Base = <E2 as Engine>::Scalar>,
    E2: Engine<Base = <E1 as Engine>::Scalar>,
    S1: RelaxedR1CSSNARKTrait<E1>,
    S2: RelaxedR1CSSNARKTrait<E2>,
{
    /// The wrapped nova compressed SNARK.
    ///
    /// Flattened into this struct's serialized form. (De)serialization goes through
    /// `utils::serialize_forward` / `utils::deserialize_forward`; those helpers are
    /// defined elsewhere. NOTE(review): presumably they forward to nova's own serde
    /// impls — confirm in `utils`.
    #[serde(flatten)]
    #[serde(
        serialize_with = "serialize_forward",
        deserialize_with = "deserialize_forward"
    )]
    inner: nova::CompressedSNARK<
        E1,
        E2,
        Circuit<<E1 as Engine>::Scalar>,
        TrivialCircuit<E2::Scalar>,
        S1,
        S2,
    >,
}
57
/// Wrapper around nova's `PublicParams` for a rings [`Circuit`] primary and a
/// [`TrivialCircuit`] secondary step circuit.
#[derive(Serialize, Deserialize)]
pub struct PublicParams<E1, E2>
where
    E1: Engine<Base = <E2 as Engine>::Scalar>,
    E2: Engine<Base = <E1 as Engine>::Scalar>,
{
    /// The wrapped nova public params.
    ///
    /// `#[serde(flatten)]` serializes its fields at the top level of this struct.
    #[serde(flatten)]
    pub inner:
        nova::PublicParams<E1, E2, Circuit<<E1 as Engine>::Scalar>, TrivialCircuit<E2::Scalar>>,
}
70
71impl<E1, E2> std::fmt::Debug for PublicParams<E1, E2>
72where
73    E1: Engine<Base = <E2 as Engine>::Scalar>,
74    E2: Engine<Base = <E1 as Engine>::Scalar>,
75{
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("PublicParams")
78            .field(
79                "inner",
80                &serde_json::to_string(&self.inner).map_err(|_| std::fmt::Error)?,
81            )
82            .finish()
83    }
84}
85
/// Wrapper around nova's `ProverKey` used to produce a [`CompressedSNARK`] with the
/// relaxed-R1CS SNARKs `S1` (primary, `E1`) and `S2` (secondary, `E2`).
#[derive(Serialize, Deserialize)]
pub struct ProverKey<E1, E2, S1, S2>
where
    E1: Engine<Base = <E2 as Engine>::Scalar>,
    E2: Engine<Base = <E1 as Engine>::Scalar>,
    S1: RelaxedR1CSSNARKTrait<E1>,
    S2: RelaxedR1CSSNARKTrait<E2>,
{
    /// The wrapped nova prover key.
    ///
    /// Flattened into this struct's serialized form; (de)serialization goes through
    /// `utils::serialize_forward` / `utils::deserialize_forward` (defined elsewhere).
    #[serde(flatten)]
    #[serde(
        serialize_with = "serialize_forward",
        deserialize_with = "deserialize_forward"
    )]
    pub pk: nova::ProverKey<
        E1,
        E2,
        Circuit<<E1 as Engine>::Scalar>,
        TrivialCircuit<<E2 as Engine>::Scalar>,
        S1,
        S2,
    >,
}
110
/// Wrapper around nova's `VerifierKey` used to check a [`CompressedSNARK`] produced
/// with the relaxed-R1CS SNARKs `S1` (primary, `E1`) and `S2` (secondary, `E2`).
#[derive(Serialize, Deserialize)]
pub struct VerifierKey<E1, E2, S1, S2>
where
    E1: Engine<Base = <E2 as Engine>::Scalar>,
    E2: Engine<Base = <E1 as Engine>::Scalar>,
    S1: RelaxedR1CSSNARKTrait<E1>,
    S2: RelaxedR1CSSNARKTrait<E2>,
{
    /// The wrapped nova verifier key.
    ///
    /// Flattened into this struct's serialized form; (de)serialization goes through
    /// `utils::serialize_forward` / `utils::deserialize_forward` (defined elsewhere).
    #[serde(flatten)]
    #[serde(
        serialize_with = "serialize_forward",
        deserialize_with = "deserialize_forward"
    )]
    pub vk: nova::VerifierKey<
        E1,
        E2,
        Circuit<<E1 as Engine>::Scalar>,
        TrivialCircuit<<E2 as Engine>::Scalar>,
        S1,
        S2,
    >,
}
135
136impl<E1, E2> SNARK<E1, E2>
137where
138    E1: Engine<Base = <E2 as Engine>::Scalar>,
139    E2: Engine<Base = <E1 as Engine>::Scalar>,
140{
141    /// Create public params
142    #[inline]
143    pub fn gen_pp<S1, S2>(circom: Circuit<E1::Scalar>) -> Result<PublicParams<E1, E2>>
144    where
145        S1: RelaxedR1CSSNARKTrait<E1>,
146        S2: RelaxedR1CSSNARKTrait<E2>,
147    {
148        let circuit_primary = circom.clone();
149        let circuit_secondary = TrivialCircuit::<E2::Scalar>::default();
150        let pp = nova::PublicParams::setup(
151            &circuit_primary,
152            &circuit_secondary,
153            S1::ck_floor().deref(),
154            S2::ck_floor().deref(),
155        )?;
156        Ok(pp.into())
157    }
158
159    /// Create public params with circom, and public input
160    pub fn new(
161        circom: impl AsRef<Circuit<E1::Scalar>>,
162        pp: impl AsRef<PublicParams<E1, E2>>,
163        public_inputs: impl AsRef<[E1::Scalar]>,
164        secondary_inputs: impl AsRef<[E2::Scalar]>,
165    ) -> Result<Self> {
166        // flat public input here
167        let circuit_secondary = TrivialCircuit::<E2::Scalar>::default();
168        // default input for secondary on initialize round is [0]
169        let inner = RecursiveSNARK::new(
170            pp.as_ref(),
171            circom.as_ref(),
172            &circuit_secondary,
173            public_inputs.as_ref(),
174            secondary_inputs.as_ref(),
175        )?;
176        Ok(Self { inner })
177    }
178
179    /// Fold next circuit
180    #[inline]
181    pub fn foldr(
182        &mut self,
183        pp: impl AsRef<PublicParams<E1, E2>>,
184        circom: impl AsRef<Circuit<E1::Scalar>>,
185    ) -> Result<()> {
186        let circuit_secondary = TrivialCircuit::<E2::Scalar>::default();
187        let snark = self.deref_mut();
188        snark.prove_step(pp.as_ref(), circom.as_ref(), &circuit_secondary)?;
189        Ok(())
190    }
191
192    /// Fold a set of circuit
193    pub fn fold_all(
194        &mut self,
195        pp: impl AsRef<PublicParams<E1, E2>>,
196        circom: impl AsRef<Vec<Circuit<E1::Scalar>>>,
197    ) -> Result<()> {
198        for c in circom.as_ref() {
199            self.foldr(pp.as_ref(), c)?;
200        }
201        Ok(())
202    }
203
204    /// Verify the correctness of the `RecursiveSNARK`
205    /// Gen compress snark
206    #[inline]
207    pub fn verify(
208        &self,
209        pp: impl AsRef<PublicParams<E1, E2>>,
210        num_steps: usize,
211        z0_primary: impl AsRef<[E1::Scalar]>,
212        z0_secondary: impl AsRef<[E2::Scalar]>,
213    ) -> Result<(Vec<E1::Scalar>, Vec<E2::Scalar>)> {
214        Ok(self.deref().verify(
215            pp.as_ref(),
216            num_steps,
217            z0_primary.as_ref(),
218            z0_secondary.as_ref(),
219        )?)
220    }
221
222    /// Gen compress snark
223    #[inline]
224    pub fn compress_setup<S1, S2>(
225        pp: impl AsRef<PublicParams<E1, E2>>,
226    ) -> Result<(ProverKey<E1, E2, S1, S2>, VerifierKey<E1, E2, S1, S2>)>
227    where
228        S1: RelaxedR1CSSNARKTrait<E1>,
229        S2: RelaxedR1CSSNARKTrait<E2>,
230    {
231        let (pk, vk) = nova::CompressedSNARK::setup(pp.as_ref())?;
232        Ok((ProverKey { pk }, VerifierKey { vk }))
233    }
234
235    /// gen compress_proof
236    #[inline]
237    pub fn compress_prove<S1, S2>(
238        &self,
239        pp: impl AsRef<PublicParams<E1, E2>>,
240        pk: impl AsRef<ProverKey<E1, E2, S1, S2>>,
241    ) -> Result<CompressedSNARK<E1, E2, S1, S2>>
242    where
243        S1: RelaxedR1CSSNARKTrait<E1>,
244        S2: RelaxedR1CSSNARKTrait<E2>,
245    {
246        Ok(nova::CompressedSNARK::<
247            E1,
248            E2,
249            Circuit<<E1 as Engine>::Scalar>,
250            TrivialCircuit<E2::Scalar>,
251            S1,
252            S2,
253        >::prove(pp.as_ref(), pk.as_ref(), self)?
254        .into())
255    }
256
257    /// gen compress_proof
258    #[inline]
259    pub fn compress_verify<S1, S2>(
260        proof: impl AsRef<CompressedSNARK<E1, E2, S1, S2>>,
261        vk: impl AsRef<VerifierKey<E1, E2, S1, S2>>,
262        num_steps: usize,
263        public_inputs: impl AsRef<[E1::Scalar]>,
264    ) -> Result<(Vec<E1::Scalar>, Vec<E2::Scalar>)>
265    where
266        S1: RelaxedR1CSSNARKTrait<E1>,
267        S2: RelaxedR1CSSNARKTrait<E2>,
268    {
269        let z1 = vec![E2::Scalar::from(0)];
270        Ok(proof
271            .as_ref()
272            .verify(vk.as_ref(), num_steps, public_inputs.as_ref(), &z1)?)
273    }
274}