Skip to main content

csv_adapter_core/
mpc.rs

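//! MPC (Multi-Protocol Commitment) Tree
//!
//! An MPC tree allows multiple protocols to share a single on-chain witness
//! transaction. Each leaf commits to `(protocol_id || commitment_hash)`, and the
//! root is what gets committed on-chain via Tapret/Opret.
//!
//! The construction follows the general pattern of RGB's MPC tree and Bitcoin's
//! BIP-341 merkle tree: tagged hashes, with leaves and internal nodes under
//! different tags. Unlike BIP-341, sibling hashes are not sorted. Leaves keep
//! their insertion order, and an unpaired node at the end of a level is promoted
//! unchanged to the next level.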
9
10use alloc::vec::Vec;
11
12use crate::hash::Hash;
13use crate::tagged_hash::csv_tagged_hash;
14
/// 32-byte identifier of a protocol taking part in an MPC tree.
pub type ProtocolId = [u8; 32];

18/// A leaf in the MPC tree
19#[derive(Clone, Debug, PartialEq, Eq, Hash)]
20pub struct MpcLeaf {
21    /// Protocol identifier
22    pub protocol_id: ProtocolId,
23    /// Protocol's commitment hash
24    pub commitment: Hash,
25}
26
27impl MpcLeaf {
28    /// Create a new MPC leaf
29    pub fn new(protocol_id: ProtocolId, commitment: Hash) -> Self {
30        Self {
31            protocol_id,
32            commitment,
33        }
34    }
35
36    /// Compute the leaf hash: tagged_hash("mpc-leaf", protocol_id || commitment)
37    pub fn hash(&self) -> Hash {
38        let mut data = Vec::with_capacity(64);
39        data.extend_from_slice(&self.protocol_id);
40        data.extend_from_slice(self.commitment.as_bytes());
41        Hash::new(csv_tagged_hash("mpc-leaf", &data))
42    }
43}
44
45/// Merkle branch proof for a specific protocol's inclusion
46#[derive(Clone, Debug, PartialEq, Eq, Hash)]
47pub struct MpcProof {
48    /// Protocol being proven
49    pub protocol_id: ProtocolId,
50    /// Commitment being proven
51    pub commitment: Hash,
52    /// Merkle branch (sibling hashes from leaf to root)
53    pub branch: Vec<MerkleBranchNode>,
54    /// Position of the leaf (0-indexed)
55    pub leaf_index: usize,
56}
57
58/// A single node in a merkle branch (sibling hash + direction)
59#[derive(Clone, Debug, PartialEq, Eq, Hash)]
60pub struct MerkleBranchNode {
61    /// Sibling hash
62    pub hash: Hash,
63    /// Whether this sibling is on the left (true) or right (false)
64    pub is_left: bool,
65}
66
67impl MpcProof {
68    /// Verify this proof against the claimed root
69    pub fn verify(&self, root: &Hash) -> bool {
70        let mut data = Vec::with_capacity(64);
71        data.extend_from_slice(&self.protocol_id);
72        data.extend_from_slice(self.commitment.as_bytes());
73        let mut current = Hash::new(csv_tagged_hash("mpc-leaf", &data));
74
75        for node in &self.branch {
76            let sibling_data: [u8; 64] = {
77                let mut d = [0u8; 64];
78                if node.is_left {
79                    d[..32].copy_from_slice(node.hash.as_bytes());
80                    d[32..].copy_from_slice(current.as_bytes());
81                } else {
82                    d[..32].copy_from_slice(current.as_bytes());
83                    d[32..].copy_from_slice(node.hash.as_bytes());
84                }
85                d
86            };
87            current = Hash::new(csv_tagged_hash("mpc-internal", &sibling_data));
88        }
89
90        current == *root
91    }
92}
93
94/// Multi-Protocol Commitment tree
95#[derive(Clone, Debug, PartialEq, Eq, Hash)]
96pub struct MpcTree {
97    /// Leaves in deterministic order
98    pub leaves: Vec<MpcLeaf>,
99}
100
101impl MpcTree {
102    /// Create a new MPC tree from leaves
103    pub fn new(leaves: Vec<MpcLeaf>) -> Self {
104        Self { leaves }
105    }
106
107    /// Create from (protocol_id, commitment) pairs
108    pub fn from_pairs(pairs: &[(ProtocolId, Hash)]) -> Self {
109        let leaves = pairs
110            .iter()
111            .map(|(pid, comm)| MpcLeaf::new(*pid, *comm))
112            .collect();
113        Self { leaves }
114    }
115
116    /// Compute the MPC root hash
117    ///
118    /// Uses a deterministic Merkle tree construction. For a single leaf,
119    /// the root is the leaf hash. For multiple leaves, pairs are hashed
120    /// together bottom-up.
121    pub fn root(&self) -> Hash {
122        if self.leaves.is_empty() {
123            return Hash::zero();
124        }
125
126        if self.leaves.len() == 1 {
127            return self.leaves[0].hash();
128        }
129
130        // Collect leaf hashes
131        let mut hashes: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
132
133        // Build tree bottom-up
134        while hashes.len() > 1 {
135            let mut next_level = Vec::new();
136            for chunk in hashes.chunks(2) {
137                let left = &chunk[0];
138                if chunk.len() == 1 {
139                    // Odd node: promote to next level
140                    next_level.push(*left);
141                } else {
142                    let right = &chunk[1];
143                    next_level.push(hash_pair(left, right));
144                }
145            }
146            hashes = next_level;
147        }
148
149        hashes[0]
150    }
151
152    /// Build a merkle proof for a specific protocol
153    ///
154    /// Returns None if the protocol_id is not in this tree.
155    pub fn prove(&self, protocol_id: ProtocolId) -> Option<MpcProof> {
156        let leaf_index = self
157            .leaves
158            .iter()
159            .position(|l| l.protocol_id == protocol_id)?;
160
161        let leaf = &self.leaves[leaf_index];
162
163        // Build merkle tree with branch tracking
164        let mut levels: Vec<Vec<Hash>> = Vec::new();
165        let current_level: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
166        levels.push(current_level.clone());
167
168        let mut hashes = current_level;
169        while hashes.len() > 1 {
170            let mut next_level = Vec::new();
171            for chunk in hashes.chunks(2) {
172                let left = &chunk[0];
173                if chunk.len() == 1 {
174                    // Odd node: promote to next level (standard merkle tree behavior)
175                    next_level.push(*left);
176                } else {
177                    next_level.push(hash_pair(left, &chunk[1]));
178                }
179            }
180            hashes = next_level;
181            levels.push(hashes.clone());
182        }
183
184        // Extract branch
185        let mut branch = Vec::new();
186        let mut idx = leaf_index;
187        for level_idx in 0..levels.len() - 1 {
188            let level = &levels[level_idx];
189            let (sibling_idx, is_left) = if idx % 2 == 0 {
190                (idx + 1, false) // Sibling is to the right
191            } else {
192                (idx - 1, true) // Sibling is to the left
193            };
194
195            if sibling_idx < level.len() {
196                branch.push(MerkleBranchNode {
197                    hash: level[sibling_idx],
198                    is_left,
199                });
200            }
201
202            idx /= 2;
203        }
204
205        Some(MpcProof {
206            protocol_id: leaf.protocol_id,
207            commitment: leaf.commitment,
208            branch,
209            leaf_index,
210        })
211    }
212
213    /// Get the number of protocols in this tree
214    pub fn protocol_count(&self) -> usize {
215        self.leaves.len()
216    }
217
218    /// Check if a protocol is present in this tree
219    pub fn contains_protocol(&self, protocol_id: ProtocolId) -> bool {
220        self.leaves.iter().any(|l| l.protocol_id == protocol_id)
221    }
222
223    /// Add a protocol to the tree
224    pub fn push(&mut self, protocol_id: ProtocolId, commitment: Hash) {
225        self.leaves.push(MpcLeaf::new(protocol_id, commitment));
226    }
227}
228
229/// Hash two nodes together (internal helper using tagged hashing)
230fn hash_pair(left: &Hash, right: &Hash) -> Hash {
231    let mut data = [0u8; 64];
232    data[..32].copy_from_slice(left.as_bytes());
233    data[32..].copy_from_slice(right.as_bytes());
234    Hash::new(csv_tagged_hash("mpc-internal", &data))
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    fn test_protocol(id: u8) -> ProtocolId {
        let mut arr = [0u8; 32];
        arr[0] = id;
        arr
    }

    fn test_commitment(id: u8) -> Hash {
        let mut arr = [0u8; 32];
        arr[31] = id;
        Hash::new(arr)
    }

    /// Tree whose leaves are `(test_protocol(i), test_commitment(i))` for `i` in `1..=n`.
    fn sequential_tree(n: u8) -> MpcTree {
        let pairs: Vec<_> = (1..=n)
            .map(|i| (test_protocol(i), test_commitment(i)))
            .collect();
        MpcTree::from_pairs(&pairs)
    }

    // ─────────────────────────────────────────────
    // MpcLeaf tests
    // ─────────────────────────────────────────────

    #[test]
    fn test_leaf_creation() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        assert_eq!(leaf.protocol_id[0], 1);
        assert_eq!(leaf.commitment.as_bytes()[31], 42);
    }

    #[test]
    fn test_leaf_hash_deterministic() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        assert_eq!(leaf1.hash(), leaf2.hash());
    }

    #[test]
    fn test_leaf_hash_differs_by_protocol() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(2), test_commitment(42));
        assert_ne!(leaf1.hash(), leaf2.hash());
    }

    #[test]
    fn test_leaf_hash_differs_by_commitment() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(1), test_commitment(99));
        assert_ne!(leaf1.hash(), leaf2.hash());
    }

    // ─────────────────────────────────────────────
    // MpcTree root tests
    // ─────────────────────────────────────────────

    #[test]
    fn test_empty_tree_root() {
        let tree = MpcTree::new(vec![]);
        assert_eq!(tree.root(), Hash::zero());
    }

    #[test]
    fn test_single_leaf_tree_root() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let tree = MpcTree::new(vec![leaf.clone()]);
        assert_eq!(tree.root(), leaf.hash());
    }

    #[test]
    fn test_two_leaf_tree_root() {
        let leaf_a = MpcLeaf::new(test_protocol(1), test_commitment(1));
        let leaf_b = MpcLeaf::new(test_protocol(2), test_commitment(2));
        let tree = MpcTree::new(vec![leaf_a.clone(), leaf_b.clone()]);
        let expected = hash_pair(&leaf_a.hash(), &leaf_b.hash());
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_three_leaf_tree_root() {
        let leaf_a = MpcLeaf::new(test_protocol(1), test_commitment(1));
        let leaf_b = MpcLeaf::new(test_protocol(2), test_commitment(2));
        let leaf_c = MpcLeaf::new(test_protocol(3), test_commitment(3));
        let tree = MpcTree::new(vec![leaf_a.clone(), leaf_b.clone(), leaf_c.clone()]);

        // Level 0: [A, B, C]
        // Level 1: [hash(A,B), C]   (C is promoted unchanged)
        // Level 2: [hash(hash(A,B), C)]
        let ab = hash_pair(&leaf_a.hash(), &leaf_b.hash());
        let expected = hash_pair(&ab, &leaf_c.hash());
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_four_leaf_tree_root() {
        let leaves: Vec<_> = (1..=4)
            .map(|i| MpcLeaf::new(test_protocol(i), test_commitment(i)))
            .collect();
        let tree = MpcTree::new(leaves.clone());

        let ab = hash_pair(&leaves[0].hash(), &leaves[1].hash());
        let cd = hash_pair(&leaves[2].hash(), &leaves[3].hash());
        let expected = hash_pair(&ab, &cd);
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_tree_root_deterministic() {
        assert_eq!(sequential_tree(3).root(), sequential_tree(3).root());
    }

    #[test]
    fn test_root_depends_on_leaf_order() {
        let a = (test_protocol(1), test_commitment(1));
        let b = (test_protocol(2), test_commitment(2));
        assert_ne!(
            MpcTree::from_pairs(&[a, b]).root(),
            MpcTree::from_pairs(&[b, a]).root()
        );
    }

    #[test]
    fn test_large_tree_root() {
        let tree = sequential_tree(100);
        let root = tree.root();
        assert_ne!(root, Hash::zero());
        assert_eq!(root, sequential_tree(100).root());

        // Changing any single commitment must change the root.
        let mut tampered = tree.clone();
        tampered.leaves[57].commitment = test_commitment(200);
        assert_ne!(tampered.root(), root);
    }

    // ─────────────────────────────────────────────
    // MpcProof tests
    // ─────────────────────────────────────────────

    #[test]
    fn test_proof_single_leaf() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let tree = MpcTree::new(vec![leaf.clone()]);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert!(proof.branch.is_empty());
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_two_leaves() {
        let tree = sequential_tree(2);
        let proof_a = tree.prove(test_protocol(1)).unwrap();
        let proof_b = tree.prove(test_protocol(2)).unwrap();
        assert!(proof_a.verify(&tree.root()));
        assert!(proof_b.verify(&tree.root()));
    }

    #[test]
    fn test_proof_three_leaves() {
        let tree = sequential_tree(3);
        for i in 1..=3 {
            let proof = tree.prove(test_protocol(i)).unwrap();
            assert!(proof.verify(&tree.root()));
        }
    }

    #[test]
    fn test_proof_promoted_leaf_has_shorter_branch() {
        let tree = sequential_tree(3);
        let proof = tree.prove(test_protocol(3)).unwrap();
        // Leaf C has no right neighbour at level 0, so it is promoted; its only
        // sibling is hash(A, B), sitting on the left.
        let ab = hash_pair(&tree.leaves[0].hash(), &tree.leaves[1].hash());
        assert_eq!(
            proof.branch,
            vec![MerkleBranchNode {
                hash: ab,
                is_left: true,
            }]
        );
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_all_leaves_in_large_tree() {
        let tree = sequential_tree(20);
        for i in 1..=20u8 {
            let proof = tree.prove(test_protocol(i)).unwrap();
            assert!(
                proof.verify(&tree.root()),
                "Proof for protocol {} failed",
                i
            );
        }
    }

    #[test]
    fn test_proof_verifies_for_every_tree_size() {
        for n in 1..=17u8 {
            let tree = sequential_tree(n);
            let root = tree.root();
            for i in 1..=n {
                let proof = tree.prove(test_protocol(i)).unwrap();
                assert!(proof.verify(&root), "tree size {} leaf {}", n, i);
            }
        }
    }

    #[test]
    fn test_proof_missing_protocol() {
        let tree = sequential_tree(2);
        assert!(tree.prove(test_protocol(99)).is_none());
    }

    #[test]
    fn test_prove_on_empty_tree() {
        let tree = MpcTree::new(vec![]);
        assert!(tree.prove(test_protocol(1)).is_none());
    }

    #[test]
    fn test_duplicate_protocol_proves_first_leaf() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(1), test_commitment(2)),
        ]);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert_eq!(proof.leaf_index, 0);
        assert_eq!(proof.commitment, test_commitment(1));
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_wrong_root() {
        let tree = sequential_tree(2);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert!(!proof.verify(&Hash::new([0xFF; 32])));
    }

    #[test]
    fn test_proof_wrong_commitment() {
        let tree = sequential_tree(2);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Tamper with the commitment
        proof.commitment = test_commitment(99);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_wrong_protocol_id() {
        let tree = sequential_tree(2);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Tamper with the protocol_id
        proof.protocol_id = test_protocol(99);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_branch_tampering() {
        let tree = sequential_tree(3);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Tamper with a branch node
        proof.branch[0].hash = Hash::new([0xFF; 32]);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_flipped_direction_fails() {
        let tree = sequential_tree(2);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        proof.branch[0].is_left = !proof.branch[0].is_left;
        assert!(!proof.verify(&tree.root()));
    }

    // ─────────────────────────────────────────────
    // MpcTree utility tests
    // ─────────────────────────────────────────────

    #[test]
    fn test_from_pairs() {
        let tree = sequential_tree(2);
        assert_eq!(tree.protocol_count(), 2);
        assert!(tree.contains_protocol(test_protocol(1)));
        assert!(tree.contains_protocol(test_protocol(2)));
        assert!(!tree.contains_protocol(test_protocol(3)));
    }

    #[test]
    fn test_push() {
        let mut tree = sequential_tree(1);
        assert_eq!(tree.protocol_count(), 1);
        let before = tree.root();
        tree.push(test_protocol(2), test_commitment(2));
        assert_eq!(tree.protocol_count(), 2);
        assert!(tree.contains_protocol(test_protocol(2)));
        assert_ne!(tree.root(), before);
    }

    #[test]
    fn test_leaf_index_in_proof() {
        let tree = sequential_tree(3);
        let proof_0 = tree.prove(test_protocol(1)).unwrap();
        let proof_2 = tree.prove(test_protocol(3)).unwrap();
        assert_eq!(proof_0.leaf_index, 0);
        assert_eq!(proof_2.leaf_index, 2);
    }
}