nova_scotia/
lib.rs

1use std::{
2    collections::HashMap,
3    env::current_dir,
4    fs,
5    path::{Path, PathBuf},
6};
7
8use crate::circom::reader::generate_witness_from_bin;
9use circom::circuit::{CircomCircuit, R1CS};
10use ff::Field;
11use nova_snark::{
12    traits::{circuit::TrivialTestCircuit, Group},
13    PublicParams, RecursiveSNARK,
14};
15use num_bigint::BigInt;
16use num_traits::Num;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19
20#[cfg(not(target_family = "wasm"))]
21use crate::circom::reader::generate_witness_from_wasm;
22
23#[cfg(target_family = "wasm")]
24use crate::circom::wasm::generate_witness_from_wasm;
25
26pub mod circom;
27
28pub type F<G> = <G as Group>::Scalar;
29pub type EE<G> = nova_snark::provider::ipa_pc::EvaluationEngine<G>;
30pub type S<G> = nova_snark::spartan::snark::RelaxedR1CSSNARK<G, EE<G>>;
31pub type C1<G> = CircomCircuit<<G as Group>::Scalar>;
32pub type C2<G> = TrivialTestCircuit<<G as Group>::Scalar>;
33
/// Location of a Circom witness generator.
///
/// A local path is run as a native binary unless its extension is `wasm`; a URL is always
/// treated as a WebAssembly generator.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FileLocation {
    /// A generator on the local filesystem.
    PathBuf(PathBuf),
    /// A generator fetched from a remote URL.
    URL(String),
}
39
40pub fn create_public_params<G1, G2>(r1cs: R1CS<F<G1>>) -> PublicParams<G1, G2, C1<G1>, C2<G2>>
41where
42    G1: Group<Base = <G2 as Group>::Scalar>,
43    G2: Group<Base = <G1 as Group>::Scalar>,
44{
45    let circuit_primary = CircomCircuit {
46        r1cs,
47        witness: None,
48    };
49    let circuit_secondary = TrivialTestCircuit::default();
50
51    PublicParams::setup(circuit_primary.clone(), circuit_secondary.clone())
52}
53
54#[derive(Serialize, Deserialize)]
55struct CircomInput {
56    step_in: Vec<String>,
57
58    #[serde(flatten)]
59    extra: HashMap<String, Value>,
60}
61
62#[cfg(not(target_family = "wasm"))]
63fn compute_witness<G1, G2>(
64    current_public_input: Vec<String>,
65    private_input: HashMap<String, Value>,
66    witness_generator_file: FileLocation,
67    witness_generator_output: &Path,
68) -> Vec<<G1 as Group>::Scalar>
69where
70    G1: Group<Base = <G2 as Group>::Scalar>,
71    G2: Group<Base = <G1 as Group>::Scalar>,
72{
73    let decimal_stringified_input: Vec<String> = current_public_input
74        .iter()
75        .map(|x| BigInt::from_str_radix(x, 16).unwrap().to_str_radix(10))
76        .collect();
77
78    let input = CircomInput {
79        step_in: decimal_stringified_input.clone(),
80        extra: private_input.clone(),
81    };
82
83    let is_wasm = match &witness_generator_file {
84        FileLocation::PathBuf(path) => path.extension().unwrap_or_default() == "wasm",
85        FileLocation::URL(_) => true,
86    };
87    let input_json = serde_json::to_string(&input).unwrap();
88
89    if is_wasm {
90        generate_witness_from_wasm::<F<G1>>(
91            &witness_generator_file,
92            &input_json,
93            &witness_generator_output,
94        )
95    } else {
96        let witness_generator_file = match &witness_generator_file {
97            FileLocation::PathBuf(path) => path,
98            FileLocation::URL(_) => panic!("unreachable"),
99        };
100        generate_witness_from_bin::<F<G1>>(
101            &witness_generator_file,
102            &input_json,
103            &witness_generator_output,
104        )
105    }
106}
107
/// Runs the Circom witness generator for one step and returns the resulting witness
/// (WebAssembly-target build; the `.wasm` generator is awaited).
///
/// `current_public_input` holds the step's public input as hexadecimal strings without a
/// `0x` prefix. They are converted to decimal and written to the generator's JSON input as
/// `step_in`, with every `private_input` entry flattened alongside.
///
/// A `FileLocation::PathBuf` whose extension is not `wasm` is passed to the native-binary
/// generator. `.wasm` paths and URLs go through the WebAssembly generator.
/// `witness_generator_output` is passed through to the generator as its output path.
///
/// # Panics
///
/// Panics if an entry of `current_public_input` is not valid hexadecimal.
#[cfg(target_family = "wasm")]
async fn compute_witness<G1, G2>(
    current_public_input: Vec<String>,
    private_input: HashMap<String, Value>,
    witness_generator_file: FileLocation,
    witness_generator_output: &Path,
) -> Vec<<G1 as Group>::Scalar>
where
    G1: Group<Base = <G2 as Group>::Scalar>,
    G2: Group<Base = <G1 as Group>::Scalar>,
{
    // The generator's JSON input uses decimal strings; callers supply hexadecimal.
    let step_in: Vec<String> = current_public_input
        .iter()
        .map(|x| {
            BigInt::from_str_radix(x, 16)
                .expect("public input must be a hexadecimal string")
                .to_str_radix(10)
        })
        .collect();

    // Both inputs are owned here, so move them into the document instead of cloning.
    let input = CircomInput {
        step_in,
        extra: private_input,
    };
    let input_json =
        serde_json::to_string(&input).expect("CircomInput always serializes to JSON");

    match &witness_generator_file {
        // A local file without a `.wasm` extension is a native generator binary.
        FileLocation::PathBuf(path) if path.extension().unwrap_or_default() != "wasm" => {
            generate_witness_from_bin::<F<G1>>(path, &input_json, witness_generator_output)
        }
        // `.wasm` files and remote URLs are both executed as WebAssembly.
        FileLocation::PathBuf(_) | FileLocation::URL(_) => {
            generate_witness_from_wasm::<F<G1>>(
                &witness_generator_file,
                &input_json,
                witness_generator_output,
            )
            .await
        }
    }
}
154
155#[cfg(not(target_family = "wasm"))]
156pub fn create_recursive_circuit<G1, G2>(
157    witness_generator_file: FileLocation,
158    r1cs: R1CS<F<G1>>,
159    private_inputs: Vec<HashMap<String, Value>>,
160    start_public_input: Vec<F<G1>>,
161    pp: &PublicParams<G1, G2, C1<G1>, C2<G2>>,
162) -> Result<RecursiveSNARK<G1, G2, C1<G1>, C2<G2>>, std::io::Error>
163where
164    G1: Group<Base = <G2 as Group>::Scalar>,
165    G2: Group<Base = <G1 as Group>::Scalar>,
166{
167    let root = current_dir().unwrap();
168    let witness_generator_output = root.join("circom_witness.wtns");
169
170    let iteration_count = private_inputs.len();
171
172    let start_public_input_hex = start_public_input
173        .iter()
174        .map(|&x| format!("{:?}", x).strip_prefix("0x").unwrap().to_string())
175        .collect::<Vec<String>>();
176    let mut current_public_input = start_public_input_hex.clone();
177
178    let witness_0 = compute_witness::<G1, G2>(
179        current_public_input.clone(),
180        private_inputs[0].clone(),
181        witness_generator_file.clone(),
182        &witness_generator_output,
183    );
184
185    let circuit_0 = CircomCircuit {
186        r1cs: r1cs.clone(),
187        witness: Some(witness_0),
188    };
189    let circuit_secondary = TrivialTestCircuit::default();
190    let z0_secondary = vec![G2::Scalar::ZERO];
191
192    let mut recursive_snark = RecursiveSNARK::<G1, G2, C1<G1>, C2<G2>>::new(
193        &pp,
194        &circuit_0,
195        &circuit_secondary,
196        start_public_input.clone(),
197        z0_secondary.clone(),
198    );
199
200    for i in 0..iteration_count {
201        let witness = compute_witness::<G1, G2>(
202            current_public_input.clone(),
203            private_inputs[i].clone(),
204            witness_generator_file.clone(),
205            &witness_generator_output,
206        );
207
208        let circuit = CircomCircuit {
209            r1cs: r1cs.clone(),
210            witness: Some(witness),
211        };
212
213        let current_public_output = circuit.get_public_outputs();
214        current_public_input = current_public_output
215            .iter()
216            .map(|&x| format!("{:?}", x).strip_prefix("0x").unwrap().to_string())
217            .collect();
218
219        let res = recursive_snark.prove_step(
220            &pp,
221            &circuit,
222            &circuit_secondary,
223            start_public_input.clone(),
224            z0_secondary.clone(),
225        );
226        assert!(res.is_ok());
227    }
228    fs::remove_file(witness_generator_output)?;
229
230    Ok(recursive_snark)
231}
232
/// Builds a Nova [`RecursiveSNARK`] with one Circom step per entry of `private_inputs`
/// (WebAssembly-target build; witness generation is awaited).
///
/// Step 0 runs the witness generator on `start_public_input`. Every later step runs it on
/// the previous step's public outputs. Each step also receives its own private input. The
/// witness file is written to `circom_witness.wtns` in the current working directory and
/// removed after the last step.
///
/// # Errors
///
/// Returns an [`std::io::Error`] in any of these cases:
/// - `private_inputs` is empty.
/// - The current directory cannot be determined.
/// - A Nova `prove_step` call fails. Its error is included in the message, and the witness
///   file is removed on a best-effort basis.
/// - The witness file cannot be removed at the end.
///
/// # Panics
///
/// Panics if a field element's `Debug` output does not start with `0x`. Panics raised
/// during witness generation are not caught.
#[cfg(target_family = "wasm")]
pub async fn create_recursive_circuit<G1, G2>(
    witness_generator_file: FileLocation,
    r1cs: R1CS<F<G1>>,
    private_inputs: Vec<HashMap<String, Value>>,
    start_public_input: Vec<F<G1>>,
    pp: &PublicParams<G1, G2, C1<G1>, C2<G2>>,
) -> Result<RecursiveSNARK<G1, G2, C1<G1>, C2<G2>>, std::io::Error>
where
    G1: Group<Base = <G2 as Group>::Scalar>,
    G2: Group<Base = <G1 as Group>::Scalar>,
{
    /// Formats field elements as bare hex strings (their `Debug` form minus the `0x`).
    fn to_hex<T: std::fmt::Debug>(values: &[T]) -> Vec<String> {
        values
            .iter()
            .map(|x| {
                format!("{:?}", x)
                    .strip_prefix("0x")
                    .expect("field element Debug output should start with 0x")
                    .to_string()
            })
            .collect()
    }

    // `RecursiveSNARK::new` needs a first step, so an empty input list is rejected up front
    // instead of panicking on `private_inputs[0]`.
    let first_private_input = private_inputs.first().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "private_inputs must contain at least one step",
        )
    })?;

    let witness_generator_output = current_dir()?.join("circom_witness.wtns");
    let circuit_secondary = TrivialTestCircuit::default();
    let z0_secondary = vec![G2::Scalar::ZERO];

    let first_witness = compute_witness::<G1, G2>(
        to_hex(&start_public_input),
        first_private_input.clone(),
        witness_generator_file.clone(),
        &witness_generator_output,
    )
    .await;
    let mut circuit = CircomCircuit {
        r1cs: r1cs.clone(),
        witness: Some(first_witness),
    };

    let mut recursive_snark = RecursiveSNARK::<G1, G2, C1<G1>, C2<G2>>::new(
        pp,
        &circuit,
        &circuit_secondary,
        start_public_input.clone(),
        z0_secondary.clone(),
    );

    // `prove_step` is called once per private input, step 0 included, with the same step-0
    // circuit that was passed to `new` (matching the previous implementation).
    // NOTE(review): this relies on nova_snark treating the first `prove_step` after `new` as
    // that same step. Confirm against the pinned nova_snark version.
    for (step, private_input) in private_inputs.iter().enumerate() {
        // Step 0 reuses the circuit built above. Regenerating it would rerun the witness
        // generator on identical inputs. Later steps chain the previous step's outputs.
        if step > 0 {
            let step_in = to_hex(&circuit.get_public_outputs());
            let witness = compute_witness::<G1, G2>(
                step_in,
                private_input.clone(),
                witness_generator_file.clone(),
                &witness_generator_output,
            )
            .await;
            circuit = CircomCircuit {
                r1cs: r1cs.clone(),
                witness: Some(witness),
            };
        }

        if let Err(err) = recursive_snark.prove_step(
            pp,
            &circuit,
            &circuit_secondary,
            start_public_input.clone(),
            z0_secondary.clone(),
        ) {
            // Best-effort cleanup: the proving failure is the error worth reporting.
            let _ = fs::remove_file(&witness_generator_output);
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Nova prove_step failed at step {}: {:?}", step, err),
            ));
        }
    }

    fs::remove_file(witness_generator_output)?;

    Ok(recursive_snark)
}