1use alloc::vec::Vec;
11
12use crate::hash::Hash;
13use crate::tagged_hash::csv_tagged_hash;
14
/// Identifier of a protocol: a fixed-size 32-byte array.
pub type ProtocolId = [u8; 32];
17
18#[derive(Clone, Debug, PartialEq, Eq, Hash)]
20pub struct MpcLeaf {
21 pub protocol_id: ProtocolId,
23 pub commitment: Hash,
25}
26
27impl MpcLeaf {
28 pub fn new(protocol_id: ProtocolId, commitment: Hash) -> Self {
30 Self {
31 protocol_id,
32 commitment,
33 }
34 }
35
36 pub fn hash(&self) -> Hash {
38 let mut data = Vec::with_capacity(64);
39 data.extend_from_slice(&self.protocol_id);
40 data.extend_from_slice(self.commitment.as_bytes());
41 Hash::new(csv_tagged_hash("mpc-leaf", &data))
42 }
43}
44
45#[derive(Clone, Debug, PartialEq, Eq, Hash)]
47pub struct MpcProof {
48 pub protocol_id: ProtocolId,
50 pub commitment: Hash,
52 pub branch: Vec<MerkleBranchNode>,
54 pub leaf_index: usize,
56}
57
58#[derive(Clone, Debug, PartialEq, Eq, Hash)]
60pub struct MerkleBranchNode {
61 pub hash: Hash,
63 pub is_left: bool,
65}
66
67impl MpcProof {
68 pub fn verify(&self, root: &Hash) -> bool {
70 let mut data = Vec::with_capacity(64);
71 data.extend_from_slice(&self.protocol_id);
72 data.extend_from_slice(self.commitment.as_bytes());
73 let mut current = Hash::new(csv_tagged_hash("mpc-leaf", &data));
74
75 for node in &self.branch {
76 let sibling_data: [u8; 64] = {
77 let mut d = [0u8; 64];
78 if node.is_left {
79 d[..32].copy_from_slice(node.hash.as_bytes());
80 d[32..].copy_from_slice(current.as_bytes());
81 } else {
82 d[..32].copy_from_slice(current.as_bytes());
83 d[32..].copy_from_slice(node.hash.as_bytes());
84 }
85 d
86 };
87 current = Hash::new(csv_tagged_hash("mpc-internal", &sibling_data));
88 }
89
90 current == *root
91 }
92}
93
94#[derive(Clone, Debug, PartialEq, Eq, Hash)]
96pub struct MpcTree {
97 pub leaves: Vec<MpcLeaf>,
99}
100
101impl MpcTree {
102 pub fn new(leaves: Vec<MpcLeaf>) -> Self {
104 Self { leaves }
105 }
106
107 pub fn from_pairs(pairs: &[(ProtocolId, Hash)]) -> Self {
109 let leaves = pairs
110 .iter()
111 .map(|(pid, comm)| MpcLeaf::new(*pid, *comm))
112 .collect();
113 Self { leaves }
114 }
115
116 pub fn root(&self) -> Hash {
122 if self.leaves.is_empty() {
123 return Hash::zero();
124 }
125
126 if self.leaves.len() == 1 {
127 return self.leaves[0].hash();
128 }
129
130 let mut hashes: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
132
133 while hashes.len() > 1 {
135 let mut next_level = Vec::new();
136 for chunk in hashes.chunks(2) {
137 let left = &chunk[0];
138 if chunk.len() == 1 {
139 next_level.push(*left);
141 } else {
142 let right = &chunk[1];
143 next_level.push(hash_pair(left, right));
144 }
145 }
146 hashes = next_level;
147 }
148
149 hashes[0]
150 }
151
152 pub fn prove(&self, protocol_id: ProtocolId) -> Option<MpcProof> {
156 let leaf_index = self
157 .leaves
158 .iter()
159 .position(|l| l.protocol_id == protocol_id)?;
160
161 let leaf = &self.leaves[leaf_index];
162
163 let mut levels: Vec<Vec<Hash>> = Vec::new();
165 let current_level: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
166 levels.push(current_level.clone());
167
168 let mut hashes = current_level;
169 while hashes.len() > 1 {
170 let mut next_level = Vec::new();
171 for chunk in hashes.chunks(2) {
172 let left = &chunk[0];
173 if chunk.len() == 1 {
174 next_level.push(*left);
176 } else {
177 next_level.push(hash_pair(left, &chunk[1]));
178 }
179 }
180 hashes = next_level;
181 levels.push(hashes.clone());
182 }
183
184 let mut branch = Vec::new();
186 let mut idx = leaf_index;
187 for level_idx in 0..levels.len() - 1 {
188 let level = &levels[level_idx];
189 let (sibling_idx, is_left) = if idx % 2 == 0 {
190 (idx + 1, false) } else {
192 (idx - 1, true) };
194
195 if sibling_idx < level.len() {
196 branch.push(MerkleBranchNode {
197 hash: level[sibling_idx],
198 is_left,
199 });
200 }
201
202 idx /= 2;
203 }
204
205 Some(MpcProof {
206 protocol_id: leaf.protocol_id,
207 commitment: leaf.commitment,
208 branch,
209 leaf_index,
210 })
211 }
212
213 pub fn protocol_count(&self) -> usize {
215 self.leaves.len()
216 }
217
218 pub fn contains_protocol(&self, protocol_id: ProtocolId) -> bool {
220 self.leaves.iter().any(|l| l.protocol_id == protocol_id)
221 }
222
223 pub fn push(&mut self, protocol_id: ProtocolId, commitment: Hash) {
225 self.leaves.push(MpcLeaf::new(protocol_id, commitment));
226 }
227}
228
229fn hash_pair(left: &Hash, right: &Hash) -> Hash {
231 let mut data = [0u8; 64];
232 data[..32].copy_from_slice(left.as_bytes());
233 data[32..].copy_from_slice(right.as_bytes());
234 Hash::new(csv_tagged_hash("mpc-internal", &data))
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Protocol id whose first byte is `id`; every other byte is zero.
    fn test_protocol(id: u8) -> ProtocolId {
        let mut bytes = [0u8; 32];
        bytes[0] = id;
        bytes
    }

    /// Commitment whose last byte is `id`; every other byte is zero.
    fn test_commitment(id: u8) -> Hash {
        let mut bytes = [0u8; 32];
        bytes[31] = id;
        Hash::new(bytes)
    }

    /// Leaf number `i`: protocol `i` paired with commitment `i`.
    fn numbered_leaf(i: u8) -> MpcLeaf {
        MpcLeaf::new(test_protocol(i), test_commitment(i))
    }

    /// Tree over protocols `1..=count`, built through `from_pairs`.
    fn numbered_tree(count: u8) -> MpcTree {
        let pairs: Vec<_> = (1..=count)
            .map(|i| (test_protocol(i), test_commitment(i)))
            .collect();
        MpcTree::from_pairs(&pairs)
    }

    // --- Leaf ---

    #[test]
    fn test_leaf_creation() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        assert_eq!(leaf.protocol_id[0], 1);
        assert_eq!(leaf.commitment.as_bytes()[31], 42);
    }

    #[test]
    fn test_leaf_hash_deterministic() {
        let build = || MpcLeaf::new(test_protocol(1), test_commitment(42));
        assert_eq!(build().hash(), build().hash());
    }

    #[test]
    fn test_leaf_hash_differs_by_protocol() {
        let hash_for = |protocol: u8| MpcLeaf::new(test_protocol(protocol), test_commitment(42)).hash();
        assert_ne!(hash_for(1), hash_for(2));
    }

    #[test]
    fn test_leaf_hash_differs_by_commitment() {
        let hash_for = |commitment: u8| MpcLeaf::new(test_protocol(1), test_commitment(commitment)).hash();
        assert_ne!(hash_for(42), hash_for(99));
    }

    // --- Root ---

    #[test]
    fn test_empty_tree_root() {
        let tree = MpcTree::new(Vec::new());
        assert_eq!(tree.root(), Hash::zero());
    }

    #[test]
    fn test_single_leaf_tree_root() {
        let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
        let expected = leaf.hash();
        let tree = MpcTree::new(vec![leaf]);
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_two_leaf_tree_root() {
        let (a, b) = (numbered_leaf(1), numbered_leaf(2));
        let expected = hash_pair(&a.hash(), &b.hash());
        let tree = MpcTree::new(vec![a, b]);
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_three_leaf_tree_root() {
        let (a, b, c) = (numbered_leaf(1), numbered_leaf(2), numbered_leaf(3));

        // The unpaired third leaf is carried up and joined with H(a, b).
        let ab = hash_pair(&a.hash(), &b.hash());
        let expected = hash_pair(&ab, &c.hash());

        let tree = MpcTree::new(vec![a, b, c]);
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_four_leaf_tree_root() {
        let leaves: Vec<MpcLeaf> = (1..=4u8).map(numbered_leaf).collect();
        let hashes: Vec<Hash> = leaves.iter().map(|leaf| leaf.hash()).collect();

        let left_half = hash_pair(&hashes[0], &hashes[1]);
        let right_half = hash_pair(&hashes[2], &hashes[3]);
        let expected = hash_pair(&left_half, &right_half);

        let tree = MpcTree::new(leaves);
        assert_eq!(tree.root(), expected);
    }

    #[test]
    fn test_tree_root_deterministic() {
        let first = numbered_tree(3);
        let second = numbered_tree(3);
        assert_eq!(first.root(), second.root());
    }

    #[test]
    fn test_large_tree_root() {
        let root = numbered_tree(100).root();
        assert_eq!(root.as_bytes().len(), 32);
    }

    // --- Proofs ---

    #[test]
    fn test_proof_single_leaf() {
        let tree = MpcTree::new(vec![MpcLeaf::new(test_protocol(1), test_commitment(42))]);
        let proof = tree.prove(test_protocol(1)).unwrap();
        assert!(proof.verify(&tree.root()));
    }

    #[test]
    fn test_proof_two_leaves() {
        let tree = numbered_tree(2);
        let root = tree.root();
        for id in 1..=2u8 {
            let proof = tree.prove(test_protocol(id)).unwrap();
            assert!(proof.verify(&root));
        }
    }

    #[test]
    fn test_proof_three_leaves() {
        let tree = numbered_tree(3);
        let root = tree.root();
        for id in 1..=3u8 {
            let proof = tree.prove(test_protocol(id)).unwrap();
            assert!(proof.verify(&root));
        }
    }

    #[test]
    fn test_proof_all_leaves_in_large_tree() {
        let tree = numbered_tree(20);
        let root = tree.root();
        for id in 1..=20u8 {
            let proof = tree.prove(test_protocol(id)).unwrap();
            assert!(proof.verify(&root), "Proof for protocol {} failed", id);
        }
    }

    #[test]
    fn test_proof_missing_protocol() {
        let tree = numbered_tree(2);
        assert!(tree.prove(test_protocol(99)).is_none());
    }

    #[test]
    fn test_proof_wrong_root() {
        let tree = numbered_tree(2);
        let proof = tree.prove(test_protocol(1)).unwrap();
        let bogus_root = Hash::new([0xFF; 32]);
        assert!(!proof.verify(&bogus_root));
    }

    #[test]
    fn test_proof_wrong_commitment() {
        let tree = numbered_tree(2);
        let forged = MpcProof {
            commitment: test_commitment(99),
            ..tree.prove(test_protocol(1)).unwrap()
        };
        assert!(!forged.verify(&tree.root()));
    }

    #[test]
    fn test_proof_wrong_protocol_id() {
        let tree = numbered_tree(2);
        let forged = MpcProof {
            protocol_id: test_protocol(99),
            ..tree.prove(test_protocol(1)).unwrap()
        };
        assert!(!forged.verify(&tree.root()));
    }

    #[test]
    fn test_proof_branch_tampering() {
        let tree = numbered_tree(3);
        let mut proof = tree.prove(test_protocol(1)).unwrap();
        let first_step = &mut proof.branch[0];
        first_step.hash = Hash::new([0xFF; 32]);
        assert!(!proof.verify(&tree.root()));
    }

    // --- Construction and accessors ---

    #[test]
    fn test_from_pairs() {
        let tree = numbered_tree(2);
        assert_eq!(tree.protocol_count(), 2);
        assert!(tree.contains_protocol(test_protocol(1)));
        assert!(tree.contains_protocol(test_protocol(2)));
        assert!(!tree.contains_protocol(test_protocol(3)));
    }

    #[test]
    fn test_push() {
        let mut tree = MpcTree::from_pairs(&[(test_protocol(1), test_commitment(1))]);
        assert_eq!(tree.protocol_count(), 1);

        tree.push(test_protocol(2), test_commitment(2));
        assert_eq!(tree.protocol_count(), 2);
        assert!(tree.contains_protocol(test_protocol(2)));
    }

    #[test]
    fn test_leaf_index_in_proof() {
        let tree = numbered_tree(3);
        let first = tree.prove(test_protocol(1)).unwrap();
        let last = tree.prove(test_protocol(3)).unwrap();
        assert_eq!(first.leaf_index, 0);
        assert_eq!(last.leaf_index, 2);
    }
}