liminal_ark_relations/shielder/merge.rs

use liminal_ark_relation_macro::snark_relation;

/// The `merge` relation: spends two old notes and mints a single new note carrying their
/// combined token amount.
///
/// It expresses the facts that:
///  - `first_old_note` is the result of hashing together the `token_id`,
///    `first_old_token_amount`, `first_old_trapdoor` and `first_old_nullifier`,
///  - `second_old_note` is the result of hashing together the `token_id`,
///    `second_old_token_amount`, `second_old_trapdoor` and `second_old_nullifier`,
///  - `new_note` is the result of hashing together the `token_id`, `new_token_amount`,
///    `new_trapdoor` and `new_nullifier`,
///  - `new_token_amount = first_old_token_amount + second_old_token_amount`
///  - `first_merkle_path` is a valid Merkle proof for `first_old_note` being present
///    at `first_leaf_index` in some Merkle tree with `merkle_root` hash in the root
///  - `second_merkle_path` is a valid Merkle proof for `second_old_note` being present
///    at `second_leaf_index` in some Merkle tree with `merkle_root` hash in the root
/// Additionally, the relation has one constant input, `max_path_len` which specifies upper bound
/// for the length of the merkle path (which is ~the height of the tree, ±1).
#[snark_relation]
mod relation {
    // Gadget-level imports are only needed when the circuit itself is compiled in.
    #[cfg(feature = "circuit")]
    use {
        crate::shielder::{
            check_merkle_proof, note_var::NoteVarBuilder, path_shape_var::PathShapeVar,
        },
        ark_r1cs_std::{
            alloc::{
                AllocVar,
                AllocationMode::{Input, Witness},
            },
            eq::EqGadget,
            fields::fp::FpVar,
        },
        ark_relations::ns,
        core::ops::Add,
    };

    use crate::shielder::{
        convert_hash, convert_vec,
        types::{
            BackendLeafIndex, BackendMerklePath, BackendMerkleRoot, BackendNote, BackendNullifier,
            BackendTokenAmount, BackendTokenId, BackendTrapdoor, FrontendLeafIndex,
            FrontendMerklePath, FrontendMerkleRoot, FrontendNote, FrontendNullifier,
            FrontendTokenAmount, FrontendTokenId, FrontendTrapdoor,
        },
    };

    /// Declarative description of the merge relation's inputs.
    ///
    /// The `snark_relation` macro turns this struct into the relation objects
    /// (with/without input variants) and the accessors used in
    /// `generate_constraints` below. Fields annotated `frontend_type` are
    /// converted from their frontend representation via the given
    /// `parse_with` function (or a direct conversion when none is given).
    #[relation_object_definition]
    #[derive(Clone, Debug)]
    struct MergeRelation {
        /// Circuit-shaping constant: upper bound on the Merkle path length.
        /// Baked into the keys, not part of the public/private input split.
        #[constant]
        pub max_path_len: u8,

        // Public inputs
        #[public_input(frontend_type = "FrontendTokenId")]
        pub token_id: BackendTokenId,
        #[public_input(frontend_type = "FrontendNullifier", parse_with = "convert_hash")]
        pub first_old_nullifier: BackendNullifier,
        #[public_input(frontend_type = "FrontendNullifier", parse_with = "convert_hash")]
        pub second_old_nullifier: BackendNullifier,
        #[public_input(frontend_type = "FrontendNote", parse_with = "convert_hash")]
        pub new_note: BackendNote,
        #[public_input(frontend_type = "FrontendMerkleRoot", parse_with = "convert_hash")]
        pub merkle_root: BackendMerkleRoot,

        // Private inputs.
        #[private_input(frontend_type = "FrontendTrapdoor", parse_with = "convert_hash")]
        pub first_old_trapdoor: BackendTrapdoor,
        #[private_input(frontend_type = "FrontendTrapdoor", parse_with = "convert_hash")]
        pub second_old_trapdoor: BackendTrapdoor,
        #[private_input(frontend_type = "FrontendTrapdoor", parse_with = "convert_hash")]
        pub new_trapdoor: BackendTrapdoor,
        #[private_input(frontend_type = "FrontendNullifier", parse_with = "convert_hash")]
        pub new_nullifier: BackendNullifier,
        #[private_input(frontend_type = "FrontendMerklePath", parse_with = "convert_vec")]
        pub first_merkle_path: BackendMerklePath,
        #[private_input(frontend_type = "FrontendMerklePath", parse_with = "convert_vec")]
        pub second_merkle_path: BackendMerklePath,
        #[private_input(frontend_type = "FrontendLeafIndex")]
        pub first_leaf_index: BackendLeafIndex,
        #[private_input(frontend_type = "FrontendLeafIndex")]
        pub second_leaf_index: BackendLeafIndex,
        #[private_input(frontend_type = "FrontendNote", parse_with = "convert_hash")]
        pub first_old_note: BackendNote,
        #[private_input(frontend_type = "FrontendNote", parse_with = "convert_hash")]
        pub second_old_note: BackendNote,
        #[private_input(frontend_type = "FrontendTokenAmount")]
        pub first_old_token_amount: BackendTokenAmount,
        #[private_input(frontend_type = "FrontendTokenAmount")]
        pub second_old_token_amount: BackendTokenAmount,
        #[private_input(frontend_type = "FrontendTokenAmount")]
        pub new_token_amount: BackendTokenAmount,
    }

    /// Synthesizes the R1CS constraints for the merge relation.
    ///
    /// NOTE(review): `cs` and the `self.*()` accessors are supplied by the
    /// `#[circuit_definition]` macro expansion — confirm against the macro crate.
    #[cfg(feature = "circuit")]
    #[circuit_definition]
    fn generate_constraints() {
        //------------------------------
        // Check first old note arguments.
        //------------------------------
        // `Input` allocations (token id, nullifier) are public; the amount,
        // trapdoor and note itself stay private (`Witness`). The builder
        // enforces that the note is the hash of the other four values.
        let first_old_note = NoteVarBuilder::new(cs.clone())
            .with_token_id(self.token_id(), Input)?
            .with_token_amount(self.first_old_token_amount(), Witness)?
            .with_trapdoor(self.first_old_trapdoor(), Witness)?
            .with_nullifier(self.first_old_nullifier(), Input)?
            .with_note(self.first_old_note(), Witness)?
            .build()?;

        //------------------------------
        // Check second old note arguments.
        //------------------------------
        // Reuse the already-allocated token id variable: this ties the second
        // note to the same token id without allocating (or re-inputting) it.
        let second_old_note = NoteVarBuilder::new(cs.clone())
            .with_token_id_var(first_old_note.token_id.clone())
            .with_token_amount(self.second_old_token_amount(), Witness)?
            .with_trapdoor(self.second_old_trapdoor(), Witness)?
            .with_nullifier(self.second_old_nullifier(), Input)?
            .with_note(self.second_old_note(), Witness)?
            .build()?;

        //------------------------------
        // Check new note arguments.
        //------------------------------
        // Same token id sharing as above; here the note is the public input
        // while the nullifier and trapdoor remain private.
        let new_note = NoteVarBuilder::new(cs.clone())
            .with_token_id_var(first_old_note.token_id.clone())
            .with_token_amount(self.new_token_amount(), Witness)?
            .with_trapdoor(self.new_trapdoor(), Witness)?
            .with_nullifier(self.new_nullifier(), Witness)?
            .with_note(self.new_note(), Input)?
            .build()?;

        //----------------------------------
        // Check token value soundness.
        //----------------------------------
        // Value conservation: first + second == new.
        let token_sum = first_old_note
            .token_amount
            .add(second_old_note.token_amount)?;
        token_sum.enforce_equal(&new_note.token_amount)?;

        //------------------------
        // Check first merkle proof.
        //------------------------
        // The root is a public input shared by both membership proofs.
        let merkle_root = FpVar::new_input(ns!(cs, "merkle root"), || self.merkle_root())?;
        let first_path_shape = PathShapeVar::new_witness(ns!(cs, "first path shape"), || {
            Ok((*self.max_path_len(), self.first_leaf_index().cloned()))
        })?;

        // `unwrap_or_default()` falls back to an empty path when the witness is
        // absent — presumably during setup/keygen, where only the circuit shape
        // matters; NOTE(review): confirm against `check_merkle_proof`.
        check_merkle_proof(
            merkle_root.clone(),
            first_path_shape,
            first_old_note.note,
            self.first_merkle_path().cloned().unwrap_or_default(),
            *self.max_path_len(),
            cs.clone(),
        )?;

        //------------------------
        // Check second merkle proof.
        //------------------------
        let second_path_shape = PathShapeVar::new_witness(ns!(cs, "second path shape"), || {
            Ok((*self.max_path_len(), self.second_leaf_index().cloned()))
        })?;

        check_merkle_proof(
            merkle_root,
            second_path_shape,
            second_old_note.note,
            self.second_merkle_path().cloned().unwrap_or_default(),
            *self.max_path_len(),
            cs,
        )
    }
}
171
#[cfg(all(test, feature = "circuit"))]
mod tests {
    use ark_bls12_381::Bls12_381;
    use ark_groth16::Groth16;
    use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystem};
    use ark_snark::SNARK;

    use super::*;
    use crate::shielder::{
        convert_hash,
        note::{compute_note, compute_parent_hash},
        types::FrontendNote,
    };

    const MAX_PATH_LEN: u8 = 4;
    const TOKEN_ID: FrontendTokenId = 1;

    const FIRST_OLD_TRAPDOOR: FrontendTrapdoor = [17; 4];
    const FIRST_OLD_NULLIFIER: FrontendNullifier = [19; 4];
    const FIRST_OLD_TOKEN_AMOUNT: FrontendTokenAmount = 3;

    const SECOND_OLD_TRAPDOOR: FrontendTrapdoor = [23; 4];
    const SECOND_OLD_NULLIFIER: FrontendNullifier = [29; 4];
    const SECOND_OLD_TOKEN_AMOUNT: FrontendTokenAmount = 7;

    const NEW_TRAPDOOR: FrontendTrapdoor = [27; 4];
    const NEW_NULLIFIER: FrontendNullifier = [87; 4];
    // Satisfies the value-conservation constraint: 3 + 7 = 10.
    const NEW_TOKEN_AMOUNT: FrontendTokenAmount = 10;

    const FIRST_LEAF_INDEX: u64 = 5;
    const SECOND_LEAF_INDEX: u64 = 6;

    /// Builds a fully-populated, satisfiable merge circuit: two old notes
    /// sitting in adjacent subtrees of a small Merkle tree, and a new note
    /// carrying the sum of their amounts.
    fn get_circuit_with_full_input() -> MergeRelationWithFullInput {
        let first_old_note = compute_note(
            TOKEN_ID,
            FIRST_OLD_TOKEN_AMOUNT,
            FIRST_OLD_TRAPDOOR,
            FIRST_OLD_NULLIFIER,
        );
        let second_old_note = compute_note(
            TOKEN_ID,
            SECOND_OLD_TOKEN_AMOUNT,
            SECOND_OLD_TRAPDOOR,
            SECOND_OLD_NULLIFIER,
        );
        let new_note = compute_note(TOKEN_ID, NEW_TOKEN_AMOUNT, NEW_TRAPDOOR, NEW_NULLIFIER);

        // Tree layout (node numbers; `x` = untouched zero note):
        //
        //                                          merkle root
        //                placeholder                                        x
        //        1                       x                     x                       x
        //   2         3              x        x            x       x              x       x
        // 4  *5*  ^6^   7          x   x    x   x        x   x   x   x          x   x   x   x
        //
        // *first_old_note* | ^second_old_note^

        let zero_note = FrontendNote::default(); // x

        // First Merkle path setup.
        let first_sibling_note = compute_note(0, 1, [2; 4], [3; 4]); // 4
        let first_parent_note = compute_parent_hash(first_sibling_note, first_old_note); // 2

        // Second Merkle path setup.
        let second_sibling_note = compute_note(0, 1, [3; 4], [4; 4]); // 7
        let second_parent_note = compute_parent_hash(second_old_note, second_sibling_note); // 3

        // Merkle paths (sibling nodes from leaf level upwards).
        let first_merkle_path = vec![first_sibling_note, second_parent_note];
        let second_merkle_path = vec![second_sibling_note, first_parent_note];

        // Common roots.
        let grandpa_root = compute_parent_hash(first_parent_note, second_parent_note); // 1
        let placeholder = compute_parent_hash(grandpa_root, zero_note);
        let merkle_root = compute_parent_hash(placeholder, zero_note);

        MergeRelationWithFullInput::new(
            MAX_PATH_LEN,
            TOKEN_ID,
            FIRST_OLD_NULLIFIER,
            SECOND_OLD_NULLIFIER,
            new_note,
            merkle_root,
            FIRST_OLD_TRAPDOOR,
            SECOND_OLD_TRAPDOOR,
            NEW_TRAPDOOR,
            NEW_NULLIFIER,
            first_merkle_path,
            second_merkle_path,
            FIRST_LEAF_INDEX,
            SECOND_LEAF_INDEX,
            first_old_note,
            second_old_note,
            FIRST_OLD_TOKEN_AMOUNT,
            SECOND_OLD_TOKEN_AMOUNT,
            NEW_TOKEN_AMOUNT,
        )
    }

    /// Valid circuit with the first old note replaced by a hash over a wrong
    /// amount — the first note-opening check must fail.
    fn get_circuit_with_invalid_first_old_note() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();

        let first_old_note = compute_note(
            TOKEN_ID,
            FIRST_OLD_TOKEN_AMOUNT + 1,
            FIRST_OLD_TRAPDOOR,
            FIRST_OLD_NULLIFIER,
        );
        circuit.first_old_note = convert_hash(first_old_note);

        circuit
    }

    /// Valid circuit with the second old note replaced by a hash over a wrong
    /// amount — the second note-opening check must fail.
    fn get_circuit_with_invalid_second_old_note() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();

        let second_old_note = compute_note(
            TOKEN_ID,
            SECOND_OLD_TOKEN_AMOUNT + 1,
            SECOND_OLD_TRAPDOOR,
            SECOND_OLD_NULLIFIER,
        );
        circuit.second_old_note = convert_hash(second_old_note);

        circuit
    }

    /// Valid circuit whose public new note is computed over a perturbed
    /// trapdoor — the new-note-opening check must fail.
    fn get_circuit_with_invalid_new_note() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();
        let new_note = compute_note(
            TOKEN_ID,
            NEW_TOKEN_AMOUNT,
            NEW_TRAPDOOR.map(|t| t + 1),
            NEW_NULLIFIER,
        );
        circuit.new_note = convert_hash(new_note);

        circuit
    }

    /// Valid circuit whose new note claims one token more than the old notes
    /// held together — the value-conservation check must fail.
    fn get_circuit_with_unsound_value() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();

        let new_note = compute_note(TOKEN_ID, NEW_TOKEN_AMOUNT + 1, NEW_TRAPDOOR, NEW_NULLIFIER);
        circuit.new_note = convert_hash(new_note);

        circuit
    }

    /// Valid circuit with the first leaf index shifted by one — the first
    /// Merkle proof must fail.
    fn get_circuit_with_invalid_first_leaf_index() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();
        circuit.first_leaf_index = FIRST_LEAF_INDEX + 1;
        circuit
    }

    /// Valid circuit with the second leaf index shifted by one — the second
    /// Merkle proof must fail.
    fn get_circuit_with_invalid_second_leaf_index() -> MergeRelationWithFullInput {
        let mut circuit = get_circuit_with_full_input();
        circuit.second_leaf_index = SECOND_LEAF_INDEX + 1;
        circuit
    }

    /// Synthesizes the circuit's constraints and reports whether they are
    /// satisfied, printing the offending constraint on failure.
    fn merge_constraints_correctness(circuit: MergeRelationWithFullInput) -> bool {
        let cs = ConstraintSystem::new_ref();
        circuit.generate_constraints(cs.clone()).unwrap();

        let is_satisfied = cs.is_satisfied().unwrap();
        if !is_satisfied {
            println!("{:?}", cs.which_is_unsatisfied());
        }

        is_satisfied
    }

    /// End-to-end Groth16 round trip: setup on the input-less relation,
    /// prove with the generated circuit, verify against its serialized
    /// public input.
    fn merge_proving_procedure(circuit_generator: fn() -> MergeRelationWithFullInput) {
        let circuit_without_input = MergeRelationWithoutInput::new(MAX_PATH_LEN);

        let mut rng = ark_std::test_rng();
        let (pk, vk) =
            Groth16::<Bls12_381>::circuit_specific_setup(circuit_without_input, &mut rng).unwrap();

        let proof = Groth16::prove(&pk, circuit_generator(), &mut rng).unwrap();

        let circuit: MergeRelationWithPublicInput = circuit_generator().into();
        let input = circuit.serialize_public_input();

        let valid_proof = Groth16::verify(&vk, &input, &proof).unwrap();
        assert!(valid_proof);
    }

    #[test]
    fn merge_constraints_valid_circuit() {
        let circuit = get_circuit_with_full_input();

        let constraints_correctness = merge_constraints_correctness(circuit);
        assert!(constraints_correctness);
    }

    #[test]
    fn merge_proving_procedure_valid_circuit() {
        merge_proving_procedure(get_circuit_with_full_input);
    }

    #[test]
    fn merge_constraints_invalid_first_old_note() {
        let invalid_circuit = get_circuit_with_invalid_first_old_note();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_invalid_second_old_note() {
        let invalid_circuit = get_circuit_with_invalid_second_old_note();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_invalid_new_note() {
        let invalid_circuit = get_circuit_with_invalid_new_note();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_unsound_value() {
        let invalid_circuit = get_circuit_with_unsound_value();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_invalid_first_leaf_index() {
        let invalid_circuit = get_circuit_with_invalid_first_leaf_index();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }

    #[test]
    fn merge_constraints_invalid_second_leaf_index() {
        let invalid_circuit = get_circuit_with_invalid_second_leaf_index();

        let constraints_correctness = merge_constraints_correctness(invalid_circuit);
        assert!(!constraints_correctness);
    }
}