circom_prover/
prover.rs

1use anyhow::Result;
2use circom::Proof;
3use num::BigUint;
4use serde::{Deserialize, Serialize};
5use std::{str::FromStr, thread::JoinHandle};
6
7pub mod ark_circom;
8pub mod circom;
9
10#[cfg(feature = "arkworks")]
11pub mod arkworks;
12#[cfg(feature = "rapidsnark")]
13pub mod rapidsnark;
14
15#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
16pub struct PublicInputs(pub Vec<BigUint>);
17
18#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
19pub struct CircomProof {
20    pub proof: Proof,
21    pub pub_inputs: PublicInputs,
22}
23
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prover::circom::{Proof, G1, G2};
    use num::BigUint;

    /// Shorthand for building small field elements in test fixtures.
    fn big(value: u32) -> BigUint {
        BigUint::from(value)
    }

    /// A `CircomProof` must survive a JSON serialize/deserialize round trip
    /// unchanged.
    #[test]
    fn serde_roundtrip_circom_proof() {
        let proof = Proof {
            a: G1 {
                x: big(1),
                y: big(2),
                z: big(1),
            },
            b: G2 {
                x: [big(3), big(4)],
                y: [big(5), big(6)],
                z: [big(1), big(0)],
            },
            c: G1 {
                x: big(7),
                y: big(8),
                z: big(1),
            },
            protocol: "groth16".to_string(),
            curve: "bn128".to_string(),
        };
        let original = CircomProof {
            proof,
            pub_inputs: PublicInputs(vec![big(9), big(10)]),
        };

        let json = serde_json::to_string(&original).unwrap();
        let decoded: CircomProof = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }
}
62
/// Proving backend selected at runtime by [`prove`] and [`verify`].
///
/// A variant can only be used when its cargo feature is enabled in this
/// build; otherwise `prove`/`verify` cannot dispatch to it.
#[derive(Clone, Copy, Debug)]
pub enum ProofLib {
    /// Backend in the `arkworks` module (cargo feature `arkworks`).
    Arkworks,
    /// Backend in the `rapidsnark` module (cargo feature `rapidsnark`).
    Rapidsnark,
}
68
69pub fn prove(
70    lib: ProofLib,
71    zkey_path: String,
72    witnesses: JoinHandle<Vec<BigUint>>,
73) -> Result<CircomProof> {
74    match lib {
75        #[cfg(feature = "arkworks")]
76        ProofLib::Arkworks => arkworks::generate_circom_proof(zkey_path, witnesses),
77        #[cfg(feature = "rapidsnark")]
78        ProofLib::Rapidsnark => rapidsnark::generate_circom_proof(zkey_path, witnesses),
79        #[allow(unreachable_patterns)]
80        _ => panic!("Unsupported proof library"),
81    }
82}
83
84pub fn verify(lib: ProofLib, zkey_path: String, proof: CircomProof) -> Result<bool> {
85    match lib {
86        #[cfg(feature = "arkworks")]
87        ProofLib::Arkworks => arkworks::verify_circom_proof(zkey_path, proof),
88        #[cfg(feature = "rapidsnark")]
89        ProofLib::Rapidsnark => rapidsnark::verify_circom_proof(zkey_path, proof),
90        #[allow(unreachable_patterns)]
91        _ => panic!("Unsupported proof library"),
92    }
93}
94
//
// Conversions between `PublicInputs` and vectors of decimal strings.
//
98impl From<Vec<String>> for PublicInputs {
99    fn from(src: Vec<String>) -> Self {
100        let pi = src
101            .iter()
102            .map(|str| BigUint::from_str(str).unwrap())
103            .collect();
104        PublicInputs(pi)
105    }
106}
107
108impl From<PublicInputs> for Vec<String> {
109    fn from(src: PublicInputs) -> Self {
110        src.0.iter().map(|p| p.to_string()).collect()
111    }
112}