1use liminal_ark_relation_macro::snark_relation;
2
#[snark_relation]
// Merge relation: proves knowledge of two old notes that share one token id and are both
// members of the Merkle tree whose root is `merkle_root`, together with a new note of the
// same token whose amount equals the sum of the two old amounts. The token id, the two old
// nullifiers, the new note and the Merkle root are public inputs; everything else is a
// private witness.
mod relation {
    #[cfg(feature = "circuit")]
    use {
        crate::shielder::{
            check_merkle_proof, note_var::NoteVarBuilder, path_shape_var::PathShapeVar,
        },
        ark_r1cs_std::{
            alloc::{
                AllocVar,
                AllocationMode::{Input, Witness},
            },
            eq::EqGadget,
            fields::fp::FpVar,
        },
        ark_relations::ns,
        core::ops::Add,
    };

    use crate::shielder::{
        convert_hash, convert_vec,
        types::{
            BackendLeafIndex, BackendMerklePath, BackendMerkleRoot, BackendNote, BackendNullifier,
            BackendTokenAmount, BackendTokenId, BackendTrapdoor, FrontendLeafIndex,
            FrontendMerklePath, FrontendMerkleRoot, FrontendNote, FrontendNullifier,
            FrontendTokenAmount, FrontendTokenId, FrontendTrapdoor,
        },
    };

    // Relation object consumed by `#[snark_relation]`. Each field is tagged `#[constant]`,
    // `#[public_input]` or `#[private_input]`; judging by the tests, the macro derives the
    // `MergeRelationWithoutInput` / `...WithPublicInput` / `...WithFullInput` variants from
    // these tags. `frontend_type` is the type callers supply, and `parse_with` (when present)
    // presumably names the function converting it into the backend field type.
    #[relation_object_definition]
    #[derive(Clone, Debug)]
    struct MergeRelation {
        // Upper bound on Merkle path length; fed to `PathShapeVar` and `check_merkle_proof`.
        #[constant]
        pub max_path_len: u8,

        // Token id shared by both old notes and the new note.
        #[public_input(frontend_type = "FrontendTokenId")]
        pub token_id: BackendTokenId,
        // Nullifiers of the two old notes, exposed publicly.
        #[public_input(frontend_type = "FrontendNullifier", parse_with = "convert_hash")]
        pub first_old_nullifier: BackendNullifier,
        #[public_input(frontend_type = "FrontendNullifier", parse_with = "convert_hash")]
        pub second_old_nullifier: BackendNullifier,
        // The note produced by the merge.
        #[public_input(frontend_type = "FrontendNote", parse_with = "convert_hash")]
        pub new_note: BackendNote,
        // Root of the tree that must contain both old notes.
        #[public_input(frontend_type = "FrontendMerkleRoot", parse_with = "convert_hash")]
        pub merkle_root: BackendMerkleRoot,

        // Trapdoors of the two old notes and of the new note.
        #[private_input(frontend_type = "FrontendTrapdoor", parse_with = "convert_hash")]
        pub first_old_trapdoor: BackendTrapdoor,
        #[private_input(frontend_type = "FrontendTrapdoor", parse_with = "convert_hash")]
        pub second_old_trapdoor: BackendTrapdoor,
        #[private_input(frontend_type = "FrontendTrapdoor", parse_with = "convert_hash")]
        pub new_trapdoor: BackendTrapdoor,
        // Nullifier of the new note; kept private in this relation.
        #[private_input(frontend_type = "FrontendNullifier", parse_with = "convert_hash")]
        pub new_nullifier: BackendNullifier,
        // Sibling hashes proving membership of each old note under `merkle_root`.
        #[private_input(frontend_type = "FrontendMerklePath", parse_with = "convert_vec")]
        pub first_merkle_path: BackendMerklePath,
        #[private_input(frontend_type = "FrontendMerklePath", parse_with = "convert_vec")]
        pub second_merkle_path: BackendMerklePath,
        // Leaf positions of the old notes; used to build each `PathShapeVar`.
        #[private_input(frontend_type = "FrontendLeafIndex")]
        pub first_leaf_index: BackendLeafIndex,
        #[private_input(frontend_type = "FrontendLeafIndex")]
        pub second_leaf_index: BackendLeafIndex,
        // The old notes themselves (private; only their nullifiers are public).
        #[private_input(frontend_type = "FrontendNote", parse_with = "convert_hash")]
        pub first_old_note: BackendNote,
        #[private_input(frontend_type = "FrontendNote", parse_with = "convert_hash")]
        pub second_old_note: BackendNote,
        // Token amounts; the circuit enforces first + second == new.
        #[private_input(frontend_type = "FrontendTokenAmount")]
        pub first_old_token_amount: BackendTokenAmount,
        #[private_input(frontend_type = "FrontendTokenAmount")]
        pub second_old_token_amount: BackendTokenAmount,
        #[private_input(frontend_type = "FrontendTokenAmount")]
        pub new_token_amount: BackendTokenAmount,
    }

    #[cfg(feature = "circuit")]
    #[circuit_definition]
    fn generate_constraints() {
        // `cs` and `self` are not declared in this function. They are presumably injected by
        // the `#[circuit_definition]` expansion (constraint-system ref and relation object).
        //
        // Statement order matters here. Assuming each `with_*` call allocates immediately,
        // public inputs are allocated in this order:
        //   token id, first old nullifier, second old nullifier, new note, Merkle root.
        // That is the same order as the `#[public_input]` fields of `MergeRelation`. The
        // generated public-input serialization presumably follows field order, so keep the
        // two in sync when editing.
        //
        // NOTE(review): nothing below constrains the two old nullifiers (or the two leaf
        // indices) to differ. The circuit alone therefore accepts merging a note with
        // itself. Presumably the on-chain side rejects a repeated nullifier; confirm.

        // First old note: token id and nullifier are public inputs; the rest are witnesses.
        // `build()` presumably enforces that `note` hashes the other components (see `note_var`).
        let first_old_note = NoteVarBuilder::new(cs.clone())
            .with_token_id(self.token_id(), Input)?
            .with_token_amount(self.first_old_token_amount(), Witness)?
            .with_trapdoor(self.first_old_trapdoor(), Witness)?
            .with_nullifier(self.first_old_nullifier(), Input)?
            .with_note(self.first_old_note(), Witness)?
            .build()?;

        // Second old note reuses the first note's token-id variable, binding both old notes
        // to the same token.
        let second_old_note = NoteVarBuilder::new(cs.clone())
            .with_token_id_var(first_old_note.token_id.clone())
            .with_token_amount(self.second_old_token_amount(), Witness)?
            .with_trapdoor(self.second_old_trapdoor(), Witness)?
            .with_nullifier(self.second_old_nullifier(), Input)?
            .with_note(self.second_old_note(), Witness)?
            .build()?;

        // New note: same token-id variable. Its nullifier is a witness, while the note
        // itself is a public input.
        let new_note = NoteVarBuilder::new(cs.clone())
            .with_token_id_var(first_old_note.token_id.clone())
            .with_token_amount(self.new_token_amount(), Witness)?
            .with_trapdoor(self.new_trapdoor(), Witness)?
            .with_nullifier(self.new_nullifier(), Witness)?
            .with_note(self.new_note(), Input)?
            .build()?;

        // Balance check: first + second old amounts must equal the new amount.
        // `add` is fallible here (note the `?`), so it is defined on the token-amount
        // variable type, presumably as a range-aware addition; confirm in `note_var`.
        let token_sum = first_old_note
            .token_amount
            .add(second_old_note.token_amount)?;
        token_sum.enforce_equal(&new_note.token_amount)?;

        // Single public Merkle root shared by both membership proofs.
        let merkle_root = FpVar::new_input(ns!(cs, "merkle root"), || self.merkle_root())?;
        // Path shape for the first old note, derived from the constant max path length and
        // its (witness) leaf index.
        let first_path_shape = PathShapeVar::new_witness(ns!(cs, "first path shape"), || {
            Ok((*self.max_path_len(), self.first_leaf_index().cloned()))
        })?;

        // `unwrap_or_default()` substitutes the default (presumably empty) path when none is
        // available, e.g. presumably during key generation from `MergeRelationWithoutInput`.
        check_merkle_proof(
            merkle_root.clone(),
            first_path_shape,
            first_old_note.note,
            self.first_merkle_path().cloned().unwrap_or_default(),
            *self.max_path_len(),
            cs.clone(),
        )?;

        let second_path_shape = PathShapeVar::new_witness(ns!(cs, "second path shape"), || {
            Ok((*self.max_path_len(), self.second_leaf_index().cloned()))
        })?;

        // Membership of the second old note under the same root; its result is the
        // function's result.
        check_merkle_proof(
            merkle_root,
            second_path_shape,
            second_old_note.note,
            self.second_merkle_path().cloned().unwrap_or_default(),
            *self.max_path_len(),
            cs,
        )
    }
}
171
#[cfg(all(test, feature = "circuit"))]
mod tests {
    use ark_bls12_381::Bls12_381;
    use ark_groth16::Groth16;
    use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystem};
    use ark_snark::SNARK;

    use super::*;
    use crate::shielder::{
        convert_hash,
        note::{compute_note, compute_parent_hash},
        types::FrontendNote,
    };

    const MAX_PATH_LEN: u8 = 4;
    const TOKEN_ID: FrontendTokenId = 1;

    const FIRST_OLD_TRAPDOOR: FrontendTrapdoor = [17; 4];
    const FIRST_OLD_NULLIFIER: FrontendNullifier = [19; 4];
    const FIRST_OLD_TOKEN_AMOUNT: FrontendTokenAmount = 3;

    const SECOND_OLD_TRAPDOOR: FrontendTrapdoor = [23; 4];
    const SECOND_OLD_NULLIFIER: FrontendNullifier = [29; 4];
    const SECOND_OLD_TOKEN_AMOUNT: FrontendTokenAmount = 7;

    const NEW_TRAPDOOR: FrontendTrapdoor = [27; 4];
    const NEW_NULLIFIER: FrontendNullifier = [87; 4];
    // A valid merge must preserve the total amount, so derive it rather than hard-code it.
    const NEW_TOKEN_AMOUNT: FrontendTokenAmount = FIRST_OLD_TOKEN_AMOUNT + SECOND_OLD_TOKEN_AMOUNT;

    const FIRST_LEAF_INDEX: u64 = 5;
    const SECOND_LEAF_INDEX: u64 = 6;

    /// Builds a merge circuit in which every input is consistent, i.e. one that should
    /// satisfy all constraints.
    fn get_circuit_with_full_input() -> MergeRelationWithFullInput {
        circuit_with_new_token_amount(NEW_TOKEN_AMOUNT)
    }

    /// Builds a merge circuit whose new note carries `new_token_amount`.
    ///
    /// Both the new note's hash and its amount witness agree with `new_token_amount`, and all
    /// other inputs are valid. Any constraint failure can therefore only come from the
    /// balance check `first + second == new`.
    ///
    /// Tree construction:
    /// - Level 1: the first old note is passed as the second argument of
    ///   `compute_parent_hash` (next to `first_sibling_note`). The second old note is passed
    ///   as the first argument (next to `second_sibling_note`). This matches leaf indices 5
    ///   (odd) and 6 (even).
    /// - Level 2: the two resulting parents are siblings and hash to `grandpa_root`.
    /// - Two further levels hash with `zero_note` to reach `merkle_root`. This presumably
    ///   mirrors how `check_merkle_proof` completes a path shorter than `MAX_PATH_LEN`.
    fn circuit_with_new_token_amount(
        new_token_amount: FrontendTokenAmount,
    ) -> MergeRelationWithFullInput {
        let first_old_note = compute_note(
            TOKEN_ID,
            FIRST_OLD_TOKEN_AMOUNT,
            FIRST_OLD_TRAPDOOR,
            FIRST_OLD_NULLIFIER,
        );
        let second_old_note = compute_note(
            TOKEN_ID,
            SECOND_OLD_TOKEN_AMOUNT,
            SECOND_OLD_TRAPDOOR,
            SECOND_OLD_NULLIFIER,
        );
        let new_note = compute_note(TOKEN_ID, new_token_amount, NEW_TRAPDOOR, NEW_NULLIFIER);

        let zero_note = FrontendNote::default();

        // Level 1: each old note paired with its sibling leaf.
        let first_sibling_note = compute_note(0, 1, [2; 4], [3; 4]);
        let first_parent_note = compute_parent_hash(first_sibling_note, first_old_note);
        let second_sibling_note = compute_note(0, 1, [3; 4], [4; 4]);
        let second_parent_note = compute_parent_hash(second_old_note, second_sibling_note);

        // Paths list sibling hashes from the leaf upwards.
        let first_merkle_path = vec![first_sibling_note, second_parent_note];
        let second_merkle_path = vec![second_sibling_note, first_parent_note];

        // Level 2 joins the two subtrees; the remaining levels are padded with `zero_note`.
        let grandpa_root = compute_parent_hash(first_parent_note, second_parent_note);
        let placeholder = compute_parent_hash(grandpa_root, zero_note);
        let merkle_root = compute_parent_hash(placeholder, zero_note);

        MergeRelationWithFullInput::new(
            MAX_PATH_LEN,
            TOKEN_ID,
            FIRST_OLD_NULLIFIER,
            SECOND_OLD_NULLIFIER,
            new_note,
            merkle_root,
            FIRST_OLD_TRAPDOOR,
            SECOND_OLD_TRAPDOOR,
            NEW_TRAPDOOR,
            NEW_NULLIFIER,
            first_merkle_path,
            second_merkle_path,
            FIRST_LEAF_INDEX,
            SECOND_LEAF_INDEX,
            first_old_note,
            second_old_note,
            FIRST_OLD_TOKEN_AMOUNT,
            SECOND_OLD_TOKEN_AMOUNT,
            new_token_amount,
        )
    }

    /// Valid circuit except that `first_old_note` no longer hashes its witness components
    /// (it was computed with an amount off by one).
    fn get_circuit_with_invalid_first_old_note() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();

        let first_old_note = compute_note(
            TOKEN_ID,
            FIRST_OLD_TOKEN_AMOUNT + 1,
            FIRST_OLD_TRAPDOOR,
            FIRST_OLD_NULLIFIER,
        );
        circuit.first_old_note = convert_hash(first_old_note);

        circuit
    }

    /// Valid circuit except that `second_old_note` no longer hashes its witness components.
    fn get_circuit_with_invalid_second_old_note() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();

        let second_old_note = compute_note(
            TOKEN_ID,
            SECOND_OLD_TOKEN_AMOUNT + 1,
            SECOND_OLD_TRAPDOOR,
            SECOND_OLD_NULLIFIER,
        );
        circuit.second_old_note = convert_hash(second_old_note);

        circuit
    }

    /// Valid circuit except that the public `new_note` was computed with a different trapdoor
    /// than the one supplied as witness.
    fn get_circuit_with_invalid_new_note() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();
        let new_note = compute_note(
            TOKEN_ID,
            NEW_TOKEN_AMOUNT,
            NEW_TRAPDOOR.map(|t| t + 1),
            NEW_NULLIFIER,
        );
        circuit.new_note = convert_hash(new_note);

        circuit
    }

    /// Circuit whose new note is internally consistent but worth one unit more than the two
    /// old notes combined. Only the balance constraint should fail.
    ///
    /// Previously only the note hash was changed, which duplicated the invalid-new-note case
    /// and never exercised the balance constraint.
    fn get_circuit_with_unsound_value() -> MergeRelationWithFullInput {
        circuit_with_new_token_amount(NEW_TOKEN_AMOUNT + 1)
    }

    /// Valid circuit except that the first leaf index points at the wrong tree position.
    fn get_circuit_with_invalid_first_leaf_index() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();
        circuit.first_leaf_index = FIRST_LEAF_INDEX + 1;
        circuit
    }

    /// Valid circuit except that the second leaf index points at the wrong tree position.
    fn get_circuit_with_invalid_second_leaf_index() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();
        circuit.second_leaf_index = SECOND_LEAF_INDEX + 1;
        circuit
    }

    /// Synthesizes `circuit` into a fresh constraint system and reports whether it is
    /// satisfied. On failure, prints the first unsatisfied constraint to help debugging.
    fn merge_constraints_correctness(circuit: MergeRelationWithFullInput) -> bool {
        let cs = ConstraintSystem::new_ref();
        circuit.generate_constraints(cs.clone()).unwrap();

        let is_satisfied = cs.is_satisfied().unwrap();
        if !is_satisfied {
            println!("{:?}", cs.which_is_unsatisfied());
        }

        is_satisfied
    }

    /// Runs a full Groth16 round trip: setup from the input-less circuit, prove with the
    /// circuit from `circuit_generator`, then verify against its public inputs.
    fn merge_proving_procedure(circuit_generator: fn() -> MergeRelationWithFullInput) {
        let circuit_without_input = MergeRelationWithoutInput::new(MAX_PATH_LEN);

        let mut rng = ark_std::test_rng();
        let (pk, vk) =
            Groth16::<Bls12_381>::circuit_specific_setup(circuit_without_input, &mut rng).unwrap();

        let proof = Groth16::prove(&pk, circuit_generator(), &mut rng).unwrap();

        let circuit: MergeRelationWithPublicInput = circuit_generator().into();
        let input = circuit.serialize_public_input();

        let valid_proof = Groth16::verify(&vk, &input, &proof).unwrap();
        assert!(valid_proof);
    }

    #[test]
    fn merge_constraints_valid_circuit() {
        let circuit = get_circuit_with_full_input();

        let constraints_correctness = merge_constraints_correctness(circuit);
        assert!(constraints_correctness);
    }

    #[test]
    fn merge_proving_procedure_valid_circuit() {
        merge_proving_procedure(get_circuit_with_full_input);
    }

    #[test]
    fn merge_constraints_invalid_first_old_note() {
        let invalid_circuit = get_circuit_with_invalid_first_old_note();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_invalid_second_old_note() {
        let invalid_circuit = get_circuit_with_invalid_second_old_note();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_invalid_new_note() {
        let invalid_circuit = get_circuit_with_invalid_new_note();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_unsound_value() {
        let invalid_circuit = get_circuit_with_unsound_value();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_invalid_first_leaf_index() {
        let invalid_circuit = get_circuit_with_invalid_first_leaf_index();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_invalid_second_leaf_index() {
        let invalid_circuit = get_circuit_with_invalid_second_leaf_index();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }
}