Skip to main content

csv_adapter_core/
mpc.rs

1//! MPC (Multi-Protocol Commitment) Tree
2//!
3//! An MPC tree allows multiple protocols to share a single on-chain witness
4//! transaction. Each leaf is `(protocol_id || commitment_hash)`, and the root
5//! is what gets committed on-chain via Tapret/Opret.
6//!
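//! It is a binary Merkle tree with domain-separated tagged hashes for leaves
//! (`mpc-leaf`) and internal nodes (`mpc-internal`), in the spirit of RGB's MPC
//! tree and BIP-341's tagged hashing. Leaves keep their insertion order, and an
//! unpaired node at the end of a level is promoted unchanged to the next level
//! (it is not duplicated as in Bitcoin's transaction Merkle tree).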
9
10use alloc::vec::Vec;
11
12use crate::hash::Hash;
13use crate::tagged_hash::csv_tagged_hash;
14
15/// Protocol identifier (32 bytes)
16pub type ProtocolId = [u8; 32];
17
18/// A leaf in the MPC tree
19#[derive(Clone, Debug, PartialEq, Eq, Hash)]
20pub struct MpcLeaf {
21    /// Protocol identifier
22    pub protocol_id: ProtocolId,
23    /// Protocol's commitment hash
24    pub commitment: Hash,
25}
26
27impl MpcLeaf {
28    /// Create a new MPC leaf
29    pub fn new(protocol_id: ProtocolId, commitment: Hash) -> Self {
30        Self {
31            protocol_id,
32            commitment,
33        }
34    }
35
36    /// Compute the leaf hash: tagged_hash("mpc-leaf", protocol_id || commitment)
37    pub fn hash(&self) -> Hash {
38        let mut data = Vec::with_capacity(64);
39        data.extend_from_slice(&self.protocol_id);
40        data.extend_from_slice(self.commitment.as_bytes());
41        Hash::new(csv_tagged_hash("mpc-leaf", &data))
42    }
43}
44
45/// Merkle branch proof for a specific protocol's inclusion
46#[derive(Clone, Debug, PartialEq, Eq, Hash)]
47pub struct MpcProof {
48    /// Protocol being proven
49    pub protocol_id: ProtocolId,
50    /// Commitment being proven
51    pub commitment: Hash,
52    /// Merkle branch (sibling hashes from leaf to root)
53    pub branch: Vec<MerkleBranchNode>,
54    /// Position of the leaf (0-indexed)
55    pub leaf_index: usize,
56}
57
58/// A single node in a merkle branch (sibling hash + direction)
59#[derive(Clone, Debug, PartialEq, Eq, Hash)]
60pub struct MerkleBranchNode {
61    /// Sibling hash
62    pub hash: Hash,
63    /// Whether this sibling is on the left (true) or right (false)
64    pub is_left: bool,
65}
66
67impl MpcProof {
68    /// Verify this proof against the claimed root
69    pub fn verify(&self, root: &Hash) -> bool {
70        let mut data = Vec::with_capacity(64);
71        data.extend_from_slice(&self.protocol_id);
72        data.extend_from_slice(self.commitment.as_bytes());
73        let mut current = Hash::new(csv_tagged_hash("mpc-leaf", &data));
74
75        for node in &self.branch {
76            let sibling_data: [u8; 64] = {
77                let mut d = [0u8; 64];
78                if node.is_left {
79                    d[..32].copy_from_slice(node.hash.as_bytes());
80                    d[32..].copy_from_slice(current.as_bytes());
81                } else {
82                    d[..32].copy_from_slice(current.as_bytes());
83                    d[32..].copy_from_slice(node.hash.as_bytes());
84                }
85                d
86            };
87            current = Hash::new(csv_tagged_hash("mpc-internal", &sibling_data));
88        }
89
90        current == *root
91    }
92}
93
94/// Multi-Protocol Commitment tree
95#[derive(Clone, Debug, PartialEq, Eq, Hash)]
96pub struct MpcTree {
97    /// Leaves in deterministic order
98    pub leaves: Vec<MpcLeaf>,
99}
100
101impl MpcTree {
102    /// Create a new MPC tree from leaves
103    pub fn new(leaves: Vec<MpcLeaf>) -> Self {
104        Self { leaves }
105    }
106
107    /// Create from (protocol_id, commitment) pairs
108    pub fn from_pairs(pairs: &[(ProtocolId, Hash)]) -> Self {
109        let leaves = pairs
110            .iter()
111            .map(|(pid, comm)| MpcLeaf::new(*pid, *comm))
112            .collect();
113        Self { leaves }
114    }
115
116    /// Compute the MPC root hash
117    ///
118    /// Uses a deterministic Merkle tree construction. For a single leaf,
119    /// the root is the leaf hash. For multiple leaves, pairs are hashed
120    /// together bottom-up.
121    pub fn root(&self) -> Hash {
122        if self.leaves.is_empty() {
123            return Hash::zero();
124        }
125
126        if self.leaves.len() == 1 {
127            return self.leaves[0].hash();
128        }
129
130        // Collect leaf hashes
131        let mut hashes: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
132
133        // Build tree bottom-up
134        while hashes.len() > 1 {
135            let mut next_level = Vec::new();
136            for chunk in hashes.chunks(2) {
137                let left = &chunk[0];
138                if chunk.len() == 1 {
139                    // Odd node: promote to next level
140                    next_level.push(*left);
141                } else {
142                    let right = &chunk[1];
143                    next_level.push(hash_pair(left, right));
144                }
145            }
146            hashes = next_level;
147        }
148
149        hashes[0]
150    }
151
152    /// Build a merkle proof for a specific protocol
153    ///
154    /// Returns None if the protocol_id is not in this tree.
155    pub fn prove(&self, protocol_id: ProtocolId) -> Option<MpcProof> {
156        let leaf_index = self
157            .leaves
158            .iter()
159            .position(|l| l.protocol_id == protocol_id)?;
160
161        let leaf = &self.leaves[leaf_index];
162
163        // Build merkle tree with branch tracking
164        let mut levels: Vec<Vec<Hash>> = Vec::new();
165        let current_level: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
166        levels.push(current_level.clone());
167
168        let mut hashes = current_level;
169        while hashes.len() > 1 {
170            let mut next_level = Vec::new();
171            for chunk in hashes.chunks(2) {
172                let left = &chunk[0];
173                if chunk.len() == 1 {
174                    // Odd node: promote to next level (standard merkle tree behavior)
175                    next_level.push(*left);
176                } else {
177                    next_level.push(hash_pair(left, &chunk[1]));
178                }
179            }
180            hashes = next_level;
181            levels.push(hashes.clone());
182        }
183
184        // Extract branch
185        let mut branch = Vec::new();
186        let mut idx = leaf_index;
187        for level in levels.iter().take(levels.len() - 1) {
188            let (sibling_idx, is_left) = if idx % 2 == 0 {
189                (idx + 1, false) // Sibling is to the right
190            } else {
191                (idx - 1, true) // Sibling is to the left
192            };
193
194            if sibling_idx < level.len() {
195                branch.push(MerkleBranchNode {
196                    hash: level[sibling_idx],
197                    is_left,
198                });
199            }
200
201            idx /= 2;
202        }
203
204        Some(MpcProof {
205            protocol_id: leaf.protocol_id,
206            commitment: leaf.commitment,
207            branch,
208            leaf_index,
209        })
210    }
211
212    /// Get the number of protocols in this tree
213    pub fn protocol_count(&self) -> usize {
214        self.leaves.len()
215    }
216
217    /// Check if a protocol is present in this tree
218    pub fn contains_protocol(&self, protocol_id: ProtocolId) -> bool {
219        self.leaves.iter().any(|l| l.protocol_id == protocol_id)
220    }
221
222    /// Add a protocol to the tree
223    pub fn push(&mut self, protocol_id: ProtocolId, commitment: Hash) {
224        self.leaves.push(MpcLeaf::new(protocol_id, commitment));
225    }
226}
227
228/// Hash two nodes together (internal helper using tagged hashing)
229fn hash_pair(left: &Hash, right: &Hash) -> Hash {
230    let mut data = [0u8; 64];
231    data[..32].copy_from_slice(left.as_bytes());
232    data[32..].copy_from_slice(right.as_bytes());
233    Hash::new(csv_tagged_hash("mpc-internal", &data))
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    fn test_protocol(id: u8) -> ProtocolId {
        let mut arr = [0u8; 32];
        arr[0] = id;
        arr
    }

    fn test_commitment(id: u8) -> Hash {
        let mut arr = [0u8; 32];
        arr[31] = id;
        Hash::new(arr)
    }

    /// Leaves `(test_protocol(i), test_commitment(i))` for `i` in `1..=n`.
    fn numbered_leaves(n: u8) -> Vec<MpcLeaf> {
        (1..=n)
            .map(|i| MpcLeaf::new(test_protocol(i), test_commitment(i)))
            .collect()
    }

    // ─────────────────────────────────────────────
    // MpcLeaf tests
    // ─────────────────────────────────────────────

    #[test]
    fn test_leaf_creation() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        assert_eq!(leaf.protocol_id[0], 1);
        assert_eq!(leaf.commitment.as_bytes()[31], 42);
    }

    #[test]
    fn test_leaf_hash_deterministic() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        assert_eq!(leaf1.hash(), leaf2.hash());
    }

    #[test]
    fn test_leaf_hash_differs_by_protocol() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(2), test_commitment(42));
        assert_ne!(leaf1.hash(), leaf2.hash());
    }

    #[test]
    fn test_leaf_hash_differs_by_commitment() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(1), test_commitment(99));
        assert_ne!(leaf1.hash(), leaf2.hash());
    }

    // ─────────────────────────────────────────────
    // MpcTree root tests
    // ─────────────────────────────────────────────

    #[test]
    fn test_empty_tree_root() {
        let tree = MpcTree::new(vec![]);
        assert_eq!(tree.root(), Hash::zero());
    }

    #[test]
    fn test_single_leaf_tree_root() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let tree = MpcTree::new(vec![leaf.clone()]);
        assert_eq!(tree.root(), leaf.hash());
    }

    #[test]
    fn test_two_leaf_tree_root() {
        let leaf_a = MpcLeaf::new(test_protocol(1), test_commitment(1));
        let leaf_b = MpcLeaf::new(test_protocol(2), test_commitment(2));
        let tree = MpcTree::new(vec![leaf_a.clone(), leaf_b.clone()]);
        let expected = hash_pair(&leaf_a.hash(), &leaf_b.hash());
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_three_leaf_tree_root() {
        let leaf_a = MpcLeaf::new(test_protocol(1), test_commitment(1));
        let leaf_b = MpcLeaf::new(test_protocol(2), test_commitment(2));
        let leaf_c = MpcLeaf::new(test_protocol(3), test_commitment(3));
        let tree = MpcTree::new(vec![leaf_a.clone(), leaf_b.clone(), leaf_c.clone()]);

        // Level 0: [A, B, C]
        // Level 1: [hash(A,B), C]
        // Level 2: [hash(hash(A,B), C)]
        let ab = hash_pair(&leaf_a.hash(), &leaf_b.hash());
        let expected = hash_pair(&ab, &leaf_c.hash());
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_four_leaf_tree_root() {
        let leaves = numbered_leaves(4);
        let tree = MpcTree::new(leaves.clone());

        let ab = hash_pair(&leaves[0].hash(), &leaves[1].hash());
        let cd = hash_pair(&leaves[2].hash(), &leaves[3].hash());
        let expected = hash_pair(&ab, &cd);
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_five_leaf_tree_root() {
        let leaves = numbered_leaves(5);
        let tree = MpcTree::new(leaves.clone());

        // Level 0: [A, B, C, D, E]
        // Level 1: [hash(A,B), hash(C,D), E]   (E promoted)
        // Level 2: [hash(AB,CD), E]
        // Level 3: [hash(ABCD, E)]
        let ab = hash_pair(&leaves[0].hash(), &leaves[1].hash());
        let cd = hash_pair(&leaves[2].hash(), &leaves[3].hash());
        let expected = hash_pair(&hash_pair(&ab, &cd), &leaves[4].hash());
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_tree_root_deterministic() {
        let tree1 = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        let tree2 = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        assert_eq!(tree1.root(), tree2.root());
    }

    #[test]
    fn test_root_depends_on_leaf_order() {
        let forward = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let reversed = MpcTree::from_pairs(&[
            (test_protocol(2), test_commitment(2)),
            (test_protocol(1), test_commitment(1)),
        ]);
        assert_ne!(forward.root(), reversed.root());
    }

    #[test]
    fn test_large_tree_root() {
        let pairs: Vec<_> = (1..=100)
            .map(|i| (test_protocol(i as u8), test_commitment(i as u8)))
            .collect();
        let tree = MpcTree::from_pairs(&pairs);
        let root = tree.root();
        assert_eq!(root.as_bytes().len(), 32);
    }

    #[test]
    fn test_large_tree_root_commits_to_every_leaf() {
        let pairs: Vec<_> = (1..=100)
            .map(|i| (test_protocol(i as u8), test_commitment(i as u8)))
            .collect();
        let root = MpcTree::from_pairs(&pairs).root();
        // First leaf, an interior leaf, and the last (promoted) leaf.
        for &i in &[0usize, 57, 99] {
            let mut tampered = pairs.clone();
            tampered[i].1 = test_commitment(200);
            assert_ne!(
                MpcTree::from_pairs(&tampered).root(),
                root,
                "leaf {} is not committed to by the root",
                i
            );
        }
    }

    // ─────────────────────────────────────────────
    // MpcProof tests
    // ─────────────────────────────────────────────

    #[test]
    fn test_proof_single_leaf() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let tree = MpcTree::new(vec![leaf.clone()]);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_two_leaves() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let proof_a = tree.prove(test_protocol(1)).unwrap();
        let proof_b = tree.prove(test_protocol(2)).unwrap();
        assert!(proof_a.verify(&tree.root()));
        assert!(proof_b.verify(&tree.root()));
    }

    #[test]
    fn test_proof_three_leaves() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        for i in 1..=3 {
            let proof = tree.prove(test_protocol(i)).unwrap();
            assert!(proof.verify(&tree.root()));
        }
    }

    #[test]
    fn test_proof_all_leaves_in_large_tree() {
        let pairs: Vec<_> = (1..=20)
            .map(|i| (test_protocol(i as u8), test_commitment(i as u8)))
            .collect();
        let tree = MpcTree::from_pairs(&pairs);
        for i in 1..=20 {
            let proof = tree.prove(test_protocol(i as u8)).unwrap();
            assert!(
                proof.verify(&tree.root()),
                "Proof for protocol {} failed",
                i
            );
        }
    }

    #[test]
    fn test_proof_for_promoted_leaf() {
        let leaves = numbered_leaves(5);
        let tree = MpcTree::new(leaves.clone());
        let proof = tree.prove(test_protocol(5)).unwrap();

        // E has no sibling at levels 0 and 1 (it is promoted), so its only
        // branch node is hash(AB, CD) on the left at level 2.
        let ab = hash_pair(&leaves[0].hash(), &leaves[1].hash());
        let cd = hash_pair(&leaves[2].hash(), &leaves[3].hash());
        assert_eq!(proof.leaf_index, 4);
        assert_eq!(proof.branch.len(), 1);
        assert!(proof.branch[0].is_left);
        assert_eq!(proof.branch[0].hash, hash_pair(&ab, &cd));
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_missing_protocol() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        assert!(tree.prove(test_protocol(99)).is_none());
    }

    #[test]
    fn test_proof_wrong_root() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert!(!proof.verify(&Hash::new([0xFF; 32])));
    }

    #[test]
    fn test_proof_does_not_verify_against_other_tree() {
        let tree_a = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let tree_b = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(3)),
        ]);
        let proof = tree_a.prove(test_protocol(1)).unwrap();
        assert!(proof.verify(&tree_a.root()));
        assert!(!proof.verify(&tree_b.root()));
    }

    #[test]
    fn test_proof_wrong_commitment() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Tamper with the commitment
        proof.commitment = test_commitment(99);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_wrong_protocol_id() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Tamper with the protocol_id
        proof.protocol_id = test_protocol(99);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_branch_tampering() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Tamper with a branch node
        proof.branch[0].hash = Hash::new([0xFF; 32]);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_flipped_direction_fails() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Internal hashing is order-sensitive, so the direction bit matters.
        proof.branch[0].is_left = !proof.branch[0].is_left;
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_duplicate_protocol_proves_first_occurrence() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(1), test_commitment(2)),
        ]);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert_eq!(proof.leaf_index, 0);
        assert_eq!(proof.commitment, test_commitment(1));
        assert!(proof.verify(&tree.root()));
    }

    // ─────────────────────────────────────────────
    // MpcTree utility tests
    // ─────────────────────────────────────────────

    #[test]
    fn test_from_pairs() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        assert_eq!(tree.protocol_count(), 2);
        assert!(tree.contains_protocol(test_protocol(1)));
        assert!(tree.contains_protocol(test_protocol(2)));
        assert!(!tree.contains_protocol(test_protocol(3)));
    }

    #[test]
    fn test_push() {
        let mut tree = MpcTree::from_pairs(&[(test_protocol(1), test_commitment(1))]);
        assert_eq!(tree.protocol_count(), 1);
        tree.push(test_protocol(2), test_commitment(2));
        assert_eq!(tree.protocol_count(), 2);
        assert!(tree.contains_protocol(test_protocol(2)));
    }

    #[test]
    fn test_push_changes_root() {
        let mut tree = MpcTree::from_pairs(&[(test_protocol(1), test_commitment(1))]);
        let before = tree.root();
        tree.push(test_protocol(2), test_commitment(2));
        assert_ne!(tree.root(), before);
        let proof = tree.prove(test_protocol(2)).unwrap();
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_leaf_index_in_proof() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        let proof_0 = tree.prove(test_protocol(1)).unwrap();
        let proof_2 = tree.prove(test_protocol(3)).unwrap();
        assert_eq!(proof_0.leaf_index, 0);
        assert_eq!(proof_2.leaf_index, 2);
    }
}
509}