1use std::{
2 collections::HashMap,
3 env::current_dir,
4 fs,
5 path::{Path, PathBuf},
6};
7
8use crate::circom::reader::generate_witness_from_bin;
9use circom::circuit::{CircomCircuit, R1CS};
10use ff::Field;
11use nova_snark::{
12 traits::{circuit::TrivialTestCircuit, Group},
13 PublicParams, RecursiveSNARK,
14};
15use num_bigint::BigInt;
16use num_traits::Num;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19
20#[cfg(not(target_family = "wasm"))]
21use crate::circom::reader::generate_witness_from_wasm;
22
23#[cfg(target_family = "wasm")]
24use crate::circom::wasm::generate_witness_from_wasm;
25
26pub mod circom;
27
/// Scalar field of the group `G`.
pub type F<G> = <G as Group>::Scalar;
/// Evaluation engine from `nova_snark::provider::ipa_pc`, instantiated over `G`.
pub type EE<G> = nova_snark::provider::ipa_pc::EvaluationEngine<G>;
/// Relaxed-R1CS SNARK from `nova_snark::spartan::snark`, using [`EE`] as its evaluation engine.
pub type S<G> = nova_snark::spartan::snark::RelaxedR1CSSNARK<G, EE<G>>;
/// Primary step circuit type: a [`CircomCircuit`] over the scalar field of `G`.
pub type C1<G> = CircomCircuit<<G as Group>::Scalar>;
/// Secondary step circuit type: a [`TrivialTestCircuit`] over the scalar field of `G`.
pub type C2<G> = TrivialTestCircuit<<G as Group>::Scalar>;
33
/// Location a witness generator is loaded from.
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Clone` so values can be logged
/// and compared by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileLocation {
    /// A filesystem path.
    PathBuf(PathBuf),
    /// A URL, stored as an unparsed string.
    URL(String),
}
39
40pub fn create_public_params<G1, G2>(r1cs: R1CS<F<G1>>) -> PublicParams<G1, G2, C1<G1>, C2<G2>>
41where
42 G1: Group<Base = <G2 as Group>::Scalar>,
43 G2: Group<Base = <G1 as Group>::Scalar>,
44{
45 let circuit_primary = CircomCircuit {
46 r1cs,
47 witness: None,
48 };
49 let circuit_secondary = TrivialTestCircuit::default();
50
51 PublicParams::setup(circuit_primary.clone(), circuit_secondary.clone())
52}
53
/// JSON input document serialized for the witness generator.
#[derive(Serialize, Deserialize)]
struct CircomInput {
    /// Public step inputs; `compute_witness` fills this with base-10 number strings.
    step_in: Vec<String>,

    /// Remaining inputs. `#[serde(flatten)]` puts these keys at the top level of the
    /// JSON object, alongside `step_in`, instead of under an `extra` key.
    #[serde(flatten)]
    extra: HashMap<String, Value>,
}
61
62#[cfg(not(target_family = "wasm"))]
63fn compute_witness<G1, G2>(
64 current_public_input: Vec<String>,
65 private_input: HashMap<String, Value>,
66 witness_generator_file: FileLocation,
67 witness_generator_output: &Path,
68) -> Vec<<G1 as Group>::Scalar>
69where
70 G1: Group<Base = <G2 as Group>::Scalar>,
71 G2: Group<Base = <G1 as Group>::Scalar>,
72{
73 let decimal_stringified_input: Vec<String> = current_public_input
74 .iter()
75 .map(|x| BigInt::from_str_radix(x, 16).unwrap().to_str_radix(10))
76 .collect();
77
78 let input = CircomInput {
79 step_in: decimal_stringified_input.clone(),
80 extra: private_input.clone(),
81 };
82
83 let is_wasm = match &witness_generator_file {
84 FileLocation::PathBuf(path) => path.extension().unwrap_or_default() == "wasm",
85 FileLocation::URL(_) => true,
86 };
87 let input_json = serde_json::to_string(&input).unwrap();
88
89 if is_wasm {
90 generate_witness_from_wasm::<F<G1>>(
91 &witness_generator_file,
92 &input_json,
93 &witness_generator_output,
94 )
95 } else {
96 let witness_generator_file = match &witness_generator_file {
97 FileLocation::PathBuf(path) => path,
98 FileLocation::URL(_) => panic!("unreachable"),
99 };
100 generate_witness_from_bin::<F<G1>>(
101 &witness_generator_file,
102 &input_json,
103 &witness_generator_output,
104 )
105 }
106}
107
/// Runs the witness generator for a single step and returns the resulting witness
/// (async variant compiled for wasm targets).
///
/// `current_public_input` holds the step inputs as hexadecimal strings. They are
/// re-encoded in base 10 and serialized, together with `private_input`, into the JSON
/// document handed to the generator.
///
/// A generator given as a URL, or as a path whose extension is `wasm`, is run through
/// the awaited `generate_witness_from_wasm`; any other path is run through
/// `generate_witness_from_bin`. `witness_generator_output` is forwarded unchanged to
/// whichever generator is used.
///
/// # Panics
///
/// Panics if an entry of `current_public_input` is not valid hexadecimal, or if the
/// input document cannot be serialized to JSON.
#[cfg(target_family = "wasm")]
async fn compute_witness<G1, G2>(
    current_public_input: Vec<String>,
    private_input: HashMap<String, Value>,
    witness_generator_file: FileLocation,
    witness_generator_output: &Path,
) -> Vec<<G1 as Group>::Scalar>
where
    G1: Group<Base = <G2 as Group>::Scalar>,
    G2: Group<Base = <G1 as Group>::Scalar>,
{
    // Hex in, decimal out: the JSON document carries base-10 strings.
    let step_in = current_public_input
        .iter()
        .map(|hex| {
            BigInt::from_str_radix(hex, 16)
                .expect("public input must be a hexadecimal string")
                .to_str_radix(10)
        })
        .collect();

    // Both values are owned and not used afterwards, so they are moved, not cloned.
    let input = CircomInput {
        step_in,
        extra: private_input,
    };
    let input_json =
        serde_json::to_string(&input).expect("witness generator input must serialize to JSON");

    // A single match picks the backend, so there is no impossible "URL but not wasm"
    // branch left to guard with a panic.
    match &witness_generator_file {
        FileLocation::PathBuf(path) if path.extension().unwrap_or_default() != "wasm" => {
            generate_witness_from_bin::<F<G1>>(path, &input_json, witness_generator_output)
        }
        FileLocation::PathBuf(_) | FileLocation::URL(_) => {
            generate_witness_from_wasm::<F<G1>>(
                &witness_generator_file,
                &input_json,
                witness_generator_output,
            )
            .await
        }
    }
}
154
155#[cfg(not(target_family = "wasm"))]
156pub fn create_recursive_circuit<G1, G2>(
157 witness_generator_file: FileLocation,
158 r1cs: R1CS<F<G1>>,
159 private_inputs: Vec<HashMap<String, Value>>,
160 start_public_input: Vec<F<G1>>,
161 pp: &PublicParams<G1, G2, C1<G1>, C2<G2>>,
162) -> Result<RecursiveSNARK<G1, G2, C1<G1>, C2<G2>>, std::io::Error>
163where
164 G1: Group<Base = <G2 as Group>::Scalar>,
165 G2: Group<Base = <G1 as Group>::Scalar>,
166{
167 let root = current_dir().unwrap();
168 let witness_generator_output = root.join("circom_witness.wtns");
169
170 let iteration_count = private_inputs.len();
171
172 let start_public_input_hex = start_public_input
173 .iter()
174 .map(|&x| format!("{:?}", x).strip_prefix("0x").unwrap().to_string())
175 .collect::<Vec<String>>();
176 let mut current_public_input = start_public_input_hex.clone();
177
178 let witness_0 = compute_witness::<G1, G2>(
179 current_public_input.clone(),
180 private_inputs[0].clone(),
181 witness_generator_file.clone(),
182 &witness_generator_output,
183 );
184
185 let circuit_0 = CircomCircuit {
186 r1cs: r1cs.clone(),
187 witness: Some(witness_0),
188 };
189 let circuit_secondary = TrivialTestCircuit::default();
190 let z0_secondary = vec![G2::Scalar::ZERO];
191
192 let mut recursive_snark = RecursiveSNARK::<G1, G2, C1<G1>, C2<G2>>::new(
193 &pp,
194 &circuit_0,
195 &circuit_secondary,
196 start_public_input.clone(),
197 z0_secondary.clone(),
198 );
199
200 for i in 0..iteration_count {
201 let witness = compute_witness::<G1, G2>(
202 current_public_input.clone(),
203 private_inputs[i].clone(),
204 witness_generator_file.clone(),
205 &witness_generator_output,
206 );
207
208 let circuit = CircomCircuit {
209 r1cs: r1cs.clone(),
210 witness: Some(witness),
211 };
212
213 let current_public_output = circuit.get_public_outputs();
214 current_public_input = current_public_output
215 .iter()
216 .map(|&x| format!("{:?}", x).strip_prefix("0x").unwrap().to_string())
217 .collect();
218
219 let res = recursive_snark.prove_step(
220 &pp,
221 &circuit,
222 &circuit_secondary,
223 start_public_input.clone(),
224 z0_secondary.clone(),
225 );
226 assert!(res.is_ok());
227 }
228 fs::remove_file(witness_generator_output)?;
229
230 Ok(recursive_snark)
231}
232
/// Builds a [`RecursiveSNARK`] by proving one step per entry of `private_inputs`
/// (async variant compiled for wasm targets).
///
/// The step-0 witness is generated from `start_public_input` and the first private
/// input and used to construct the `RecursiveSNARK`. Then, for every private input in
/// order, a witness is generated, the circuit's public outputs become the public inputs
/// of the next step, and `prove_step` is called.
///
/// The witness output path is `circom_witness.wtns` under the current working
/// directory; that file is removed before returning.
///
/// # Errors
///
/// Returns an [`std::io::Error`] when
/// * `private_inputs` is empty (`InvalidInput`),
/// * the current working directory cannot be determined,
/// * a `prove_step` call fails (`Other`, with the underlying error in the message), or
/// * the scratch witness file cannot be removed.
///
/// # Panics
///
/// Panics if the `Debug` rendering of a field element does not start with `0x`.
#[cfg(target_family = "wasm")]
pub async fn create_recursive_circuit<G1, G2>(
    witness_generator_file: FileLocation,
    r1cs: R1CS<F<G1>>,
    private_inputs: Vec<HashMap<String, Value>>,
    start_public_input: Vec<F<G1>>,
    pp: &PublicParams<G1, G2, C1<G1>, C2<G2>>,
) -> Result<RecursiveSNARK<G1, G2, C1<G1>, C2<G2>>, std::io::Error>
where
    G1: Group<Base = <G2 as Group>::Scalar>,
    G2: Group<Base = <G1 as Group>::Scalar>,
{
    // Reject an empty step list up front instead of panicking on `private_inputs[0]`.
    let first_private_input = private_inputs.first().cloned().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "private_inputs must contain at least one step",
        )
    })?;

    let witness_generator_output = current_dir()?.join("circom_witness.wtns");

    // `compute_witness` takes hex strings without a prefix; the `Debug` rendering of a
    // field element is expected to be `0x…`.
    let to_hex = |x: &F<G1>| -> String {
        format!("{:?}", x)
            .strip_prefix("0x")
            .expect("Debug output of a field element must start with 0x")
            .to_string()
    };
    let mut current_public_input: Vec<String> = start_public_input.iter().map(to_hex).collect();

    let witness_0 = compute_witness::<G1, G2>(
        current_public_input.clone(),
        first_private_input,
        witness_generator_file.clone(),
        &witness_generator_output,
    )
    .await;

    let circuit_0 = CircomCircuit {
        r1cs: r1cs.clone(),
        witness: Some(witness_0),
    };
    let circuit_secondary = TrivialTestCircuit::default();
    let z0_secondary = vec![G2::Scalar::ZERO];

    let mut recursive_snark = RecursiveSNARK::<G1, G2, C1<G1>, C2<G2>>::new(
        pp,
        &circuit_0,
        &circuit_secondary,
        start_public_input.clone(),
        z0_secondary.clone(),
    );

    // NOTE(review): the first iteration regenerates the step-0 witness already computed
    // above. Kept as-is; reusing `circuit_0` would assume the generator is deterministic.
    for private_input in private_inputs {
        // `current_public_input` is moved here and reassigned below before its next use.
        let witness = compute_witness::<G1, G2>(
            current_public_input,
            private_input,
            witness_generator_file.clone(),
            &witness_generator_output,
        )
        .await;

        let circuit = CircomCircuit {
            r1cs: r1cs.clone(),
            witness: Some(witness),
        };

        // This step's public outputs feed the next step's witness generation.
        current_public_input = circuit.get_public_outputs().iter().map(to_hex).collect();

        if let Err(err) = recursive_snark.prove_step(
            pp,
            &circuit,
            &circuit_secondary,
            start_public_input.clone(),
            z0_secondary.clone(),
        ) {
            // Best-effort cleanup so a failed run does not leave the scratch file behind;
            // the proving error is the one worth reporting.
            let _ = fs::remove_file(&witness_generator_output);
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("prove_step failed: {:?}", err),
            ));
        }
    }
    fs::remove_file(witness_generator_output)?;

    Ok(recursive_snark)
}