1use alloc::vec::Vec;
11
12use crate::hash::Hash;
13use crate::tagged_hash::csv_tagged_hash;
14
/// 32-byte identifier of a protocol whose commitment is placed in an [`MpcTree`].
pub type ProtocolId = [u8; 32];
17
/// A single leaf of an [`MpcTree`]: one protocol id paired with its commitment.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MpcLeaf {
    /// Identifier of the protocol this leaf belongs to.
    pub protocol_id: ProtocolId,
    /// The commitment recorded for that protocol.
    pub commitment: Hash,
}
26
27impl MpcLeaf {
28 pub fn new(protocol_id: ProtocolId, commitment: Hash) -> Self {
30 Self {
31 protocol_id,
32 commitment,
33 }
34 }
35
36 pub fn hash(&self) -> Hash {
38 let mut data = Vec::with_capacity(64);
39 data.extend_from_slice(&self.protocol_id);
40 data.extend_from_slice(self.commitment.as_bytes());
41 Hash::new(csv_tagged_hash("mpc-leaf", &data))
42 }
43}
44
/// Inclusion proof that a `(protocol_id, commitment)` leaf belongs to an
/// [`MpcTree`] with a given root. Produced by [`MpcTree::prove`] and checked
/// with [`MpcProof::verify`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MpcProof {
    /// Protocol id of the proven leaf.
    pub protocol_id: ProtocolId,
    /// Commitment of the proven leaf.
    pub commitment: Hash,
    /// Sibling hashes ordered from the leaf level upward. A level where the
    /// node had no sibling (it was carried up unpaired) contributes no entry.
    pub branch: Vec<MerkleBranchNode>,
    /// Position of the leaf in the tree's leaf list. Informational only:
    /// `verify` does not read this field.
    pub leaf_index: usize,
}
57
/// One step of an [`MpcProof`] branch: a sibling hash and the side it sits on.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MerkleBranchNode {
    /// Hash of the sibling node at this level.
    pub hash: Hash,
    /// `true` when the sibling is the left operand (hashed as `sibling || current`),
    /// `false` when it is the right operand (`current || sibling`).
    pub is_left: bool,
}
66
67impl MpcProof {
68 pub fn verify(&self, root: &Hash) -> bool {
70 let mut data = Vec::with_capacity(64);
71 data.extend_from_slice(&self.protocol_id);
72 data.extend_from_slice(self.commitment.as_bytes());
73 let mut current = Hash::new(csv_tagged_hash("mpc-leaf", &data));
74
75 for node in &self.branch {
76 let sibling_data: [u8; 64] = {
77 let mut d = [0u8; 64];
78 if node.is_left {
79 d[..32].copy_from_slice(node.hash.as_bytes());
80 d[32..].copy_from_slice(current.as_bytes());
81 } else {
82 d[..32].copy_from_slice(current.as_bytes());
83 d[32..].copy_from_slice(node.hash.as_bytes());
84 }
85 d
86 };
87 current = Hash::new(csv_tagged_hash("mpc-internal", &sibling_data));
88 }
89
90 current == *root
91 }
92}
93
/// Merkle tree committing to an ordered list of per-protocol leaves.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MpcTree {
    /// Leaves in insertion order; the root depends on this order.
    pub leaves: Vec<MpcLeaf>,
}
100
101impl MpcTree {
102 pub fn new(leaves: Vec<MpcLeaf>) -> Self {
104 Self { leaves }
105 }
106
107 pub fn from_pairs(pairs: &[(ProtocolId, Hash)]) -> Self {
109 let leaves = pairs
110 .iter()
111 .map(|(pid, comm)| MpcLeaf::new(*pid, *comm))
112 .collect();
113 Self { leaves }
114 }
115
116 pub fn root(&self) -> Hash {
122 if self.leaves.is_empty() {
123 return Hash::zero();
124 }
125
126 if self.leaves.len() == 1 {
127 return self.leaves[0].hash();
128 }
129
130 let mut hashes: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
132
133 while hashes.len() > 1 {
135 let mut next_level = Vec::new();
136 for chunk in hashes.chunks(2) {
137 let left = &chunk[0];
138 if chunk.len() == 1 {
139 next_level.push(*left);
141 } else {
142 let right = &chunk[1];
143 next_level.push(hash_pair(left, right));
144 }
145 }
146 hashes = next_level;
147 }
148
149 hashes[0]
150 }
151
152 pub fn prove(&self, protocol_id: ProtocolId) -> Option<MpcProof> {
156 let leaf_index = self
157 .leaves
158 .iter()
159 .position(|l| l.protocol_id == protocol_id)?;
160
161 let leaf = &self.leaves[leaf_index];
162
163 let mut levels: Vec<Vec<Hash>> = Vec::new();
165 let current_level: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
166 levels.push(current_level.clone());
167
168 let mut hashes = current_level;
169 while hashes.len() > 1 {
170 let mut next_level = Vec::new();
171 for chunk in hashes.chunks(2) {
172 let left = &chunk[0];
173 if chunk.len() == 1 {
174 next_level.push(*left);
176 } else {
177 next_level.push(hash_pair(left, &chunk[1]));
178 }
179 }
180 hashes = next_level;
181 levels.push(hashes.clone());
182 }
183
184 let mut branch = Vec::new();
186 let mut idx = leaf_index;
187 for level in levels.iter().take(levels.len() - 1) {
188 let (sibling_idx, is_left) = if idx % 2 == 0 {
189 (idx + 1, false) } else {
191 (idx - 1, true) };
193
194 if sibling_idx < level.len() {
195 branch.push(MerkleBranchNode {
196 hash: level[sibling_idx],
197 is_left,
198 });
199 }
200
201 idx /= 2;
202 }
203
204 Some(MpcProof {
205 protocol_id: leaf.protocol_id,
206 commitment: leaf.commitment,
207 branch,
208 leaf_index,
209 })
210 }
211
212 pub fn protocol_count(&self) -> usize {
214 self.leaves.len()
215 }
216
217 pub fn contains_protocol(&self, protocol_id: ProtocolId) -> bool {
219 self.leaves.iter().any(|l| l.protocol_id == protocol_id)
220 }
221
222 pub fn push(&mut self, protocol_id: ProtocolId, commitment: Hash) {
224 self.leaves.push(MpcLeaf::new(protocol_id, commitment));
225 }
226}
227
228fn hash_pair(left: &Hash, right: &Hash) -> Hash {
230 let mut data = [0u8; 64];
231 data[..32].copy_from_slice(left.as_bytes());
232 data[32..].copy_from_slice(right.as_bytes());
233 Hash::new(csv_tagged_hash("mpc-internal", &data))
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Protocol id whose first byte is `id` and the rest zero.
    fn test_protocol(id: u8) -> ProtocolId {
        let mut arr = [0u8; 32];
        arr[0] = id;
        arr
    }

    /// Commitment whose last byte is `id` and the rest zero.
    fn test_commitment(id: u8) -> Hash {
        let mut arr = [0u8; 32];
        arr[31] = id;
        Hash::new(arr)
    }

    // ---- Leaf construction and hashing ----

    #[test]
    fn test_leaf_creation() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        assert_eq!(leaf.protocol_id[0], 1);
        assert_eq!(leaf.commitment.as_bytes()[31], 42);
    }

    #[test]
    fn test_leaf_hash_deterministic() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        assert_eq!(leaf1.hash(), leaf2.hash());
    }

    #[test]
    fn test_leaf_hash_differs_by_protocol() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(2), test_commitment(42));
        assert_ne!(leaf1.hash(), leaf2.hash());
    }

    #[test]
    fn test_leaf_hash_differs_by_commitment() {
        let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let leaf2 = MpcLeaf::new(test_protocol(1), test_commitment(99));
        assert_ne!(leaf1.hash(), leaf2.hash());
    }

    // ---- Tree root computation ----

    #[test]
    fn test_empty_tree_root() {
        let tree = MpcTree::new(vec![]);
        assert_eq!(tree.root(), Hash::zero());
    }

    #[test]
    fn test_single_leaf_tree_root() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let tree = MpcTree::new(vec![leaf.clone()]);
        assert_eq!(tree.root(), leaf.hash());
    }

    #[test]
    fn test_two_leaf_tree_root() {
        let leaf_a = MpcLeaf::new(test_protocol(1), test_commitment(1));
        let leaf_b = MpcLeaf::new(test_protocol(2), test_commitment(2));
        let tree = MpcTree::new(vec![leaf_a.clone(), leaf_b.clone()]);
        let expected = hash_pair(&leaf_a.hash(), &leaf_b.hash());
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_three_leaf_tree_root() {
        let leaf_a = MpcLeaf::new(test_protocol(1), test_commitment(1));
        let leaf_b = MpcLeaf::new(test_protocol(2), test_commitment(2));
        let leaf_c = MpcLeaf::new(test_protocol(3), test_commitment(3));
        let tree = MpcTree::new(vec![leaf_a.clone(), leaf_b.clone(), leaf_c.clone()]);

        let ab = hash_pair(&leaf_a.hash(), &leaf_b.hash());
        // `c` has no partner on the first level and is promoted unchanged,
        // not hashed with a duplicate of itself.
        let expected = hash_pair(&ab, &leaf_c.hash());
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_four_leaf_tree_root() {
        let leaves: Vec<_> = (1..=4)
            .map(|i| MpcLeaf::new(test_protocol(i), test_commitment(i)))
            .collect();
        let tree = MpcTree::new(leaves.clone());

        let ab = hash_pair(&leaves[0].hash(), &leaves[1].hash());
        let cd = hash_pair(&leaves[2].hash(), &leaves[3].hash());
        let expected = hash_pair(&ab, &cd);
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_five_leaf_tree_root() {
        let leaves: Vec<_> = (1..=5)
            .map(|i| MpcLeaf::new(test_protocol(i), test_commitment(i)))
            .collect();
        let tree = MpcTree::new(leaves.clone());

        // The fifth leaf is promoted unpaired through two levels before it
        // finally meets the combined hash of the first four.
        let ab = hash_pair(&leaves[0].hash(), &leaves[1].hash());
        let cd = hash_pair(&leaves[2].hash(), &leaves[3].hash());
        let abcd = hash_pair(&ab, &cd);
        let expected = hash_pair(&abcd, &leaves[4].hash());
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_tree_root_deterministic() {
        let tree1 = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        let tree2 = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        assert_eq!(tree1.root(), tree2.root());
    }

    #[test]
    fn test_large_tree_root() {
        let pairs: Vec<_> = (1..=100u8)
            .map(|i| (test_protocol(i), test_commitment(i)))
            .collect();
        let tree = MpcTree::from_pairs(&pairs);
        let root = tree.root();
        assert_eq!(root.as_bytes().len(), 32);
        assert_ne!(root, Hash::zero());

        // Changing a single commitment deep inside the tree must change the root.
        let mut altered = pairs.clone();
        altered[57].1 = test_commitment(200);
        assert_ne!(MpcTree::from_pairs(&altered).root(), root);
    }

    // ---- Proof generation and verification ----

    #[test]
    fn test_proof_single_leaf() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let tree = MpcTree::new(vec![leaf.clone()]);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert!(proof.branch.is_empty());
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_two_leaves() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let proof_a = tree.prove(test_protocol(1)).unwrap();
        let proof_b = tree.prove(test_protocol(2)).unwrap();
        assert!(proof_a.verify(&tree.root()));
        assert!(proof_b.verify(&tree.root()));
    }

    #[test]
    fn test_proof_three_leaves() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        for i in 1..=3 {
            let proof = tree.prove(test_protocol(i)).unwrap();
            assert!(proof.verify(&tree.root()));
        }
    }

    #[test]
    fn test_proof_five_leaves_promoted_tail() {
        let pairs: Vec<_> = (1..=5u8)
            .map(|i| (test_protocol(i), test_commitment(i)))
            .collect();
        let tree = MpcTree::from_pairs(&pairs);
        // The last leaf has no sibling on the first two levels, so its branch
        // holds only the left-hand hash of the first four leaves.
        let proof = tree.prove(test_protocol(5)).unwrap();
        assert_eq!(proof.branch.len(), 1);
        assert!(proof.branch[0].is_left);
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_all_leaves_in_large_tree() {
        let pairs: Vec<_> = (1..=20)
            .map(|i| (test_protocol(i as u8), test_commitment(i as u8)))
            .collect();
        let tree = MpcTree::from_pairs(&pairs);
        for i in 1..=20 {
            let proof = tree.prove(test_protocol(i as u8)).unwrap();
            assert!(
                proof.verify(&tree.root()),
                "Proof for protocol {} failed",
                i
            );
        }
    }

    #[test]
    fn test_proof_missing_protocol() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        assert!(tree.prove(test_protocol(99)).is_none());
    }

    #[test]
    fn test_proof_empty_tree() {
        let tree = MpcTree::new(vec![]);
        assert!(tree.prove(test_protocol(1)).is_none());
    }

    #[test]
    fn test_proof_wrong_root() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert!(!proof.verify(&Hash::new([0xFF; 32])));
    }

    #[test]
    fn test_proof_wrong_commitment() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Substituting a different commitment must break verification.
        proof.commitment = test_commitment(99);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_wrong_protocol_id() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Re-labelling the proof for another protocol must break verification.
        proof.protocol_id = test_protocol(99);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_branch_tampering() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Corrupting a sibling hash must break verification.
        proof.branch[0].hash = Hash::new([0xFF; 32]);
        assert!(!proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_direction_tampering() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        // Pair hashing is order-sensitive, so flipping a side must fail.
        proof.branch[0].is_left = !proof.branch[0].is_left;
        assert!(!proof.verify(&tree.root()));
    }

    // ---- Tree accessors and mutation ----

    #[test]
    fn test_from_pairs() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
        ]);
        assert_eq!(tree.protocol_count(), 2);
        assert!(tree.contains_protocol(test_protocol(1)));
        assert!(tree.contains_protocol(test_protocol(2)));
        assert!(!tree.contains_protocol(test_protocol(3)));
    }

    #[test]
    fn test_push() {
        let mut tree = MpcTree::from_pairs(&[(test_protocol(1), test_commitment(1))]);
        assert_eq!(tree.protocol_count(), 1);
        tree.push(test_protocol(2), test_commitment(2));
        assert_eq!(tree.protocol_count(), 2);
        assert!(tree.contains_protocol(test_protocol(2)));
    }

    #[test]
    fn test_leaf_index_in_proof() {
        let tree = MpcTree::from_pairs(&[
            (test_protocol(1), test_commitment(1)),
            (test_protocol(2), test_commitment(2)),
            (test_protocol(3), test_commitment(3)),
        ]);
        let proof_0 = tree.prove(test_protocol(1)).unwrap();
        let proof_2 = tree.prove(test_protocol(3)).unwrap();
        assert_eq!(proof_0.leaf_index, 0);
        assert_eq!(proof_2.leaf_index, 2);
    }
}