use alloc::vec::Vec;
use crate::hash::Hash;
use crate::tagged_hash::csv_tagged_hash;
pub type ProtocolId = [u8; 32];
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MpcLeaf {
pub protocol_id: ProtocolId,
pub commitment: Hash,
}
impl MpcLeaf {
pub fn new(protocol_id: ProtocolId, commitment: Hash) -> Self {
Self {
protocol_id,
commitment,
}
}
pub fn hash(&self) -> Hash {
let mut data = Vec::with_capacity(64);
data.extend_from_slice(&self.protocol_id);
data.extend_from_slice(self.commitment.as_bytes());
Hash::new(csv_tagged_hash("mpc-leaf", &data))
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MpcProof {
pub protocol_id: ProtocolId,
pub commitment: Hash,
pub branch: Vec<MerkleBranchNode>,
pub leaf_index: usize,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MerkleBranchNode {
pub hash: Hash,
pub is_left: bool,
}
impl MpcProof {
pub fn verify(&self, root: &Hash) -> bool {
let mut data = Vec::with_capacity(64);
data.extend_from_slice(&self.protocol_id);
data.extend_from_slice(self.commitment.as_bytes());
let mut current = Hash::new(csv_tagged_hash("mpc-leaf", &data));
for node in &self.branch {
let sibling_data: [u8; 64] = {
let mut d = [0u8; 64];
if node.is_left {
d[..32].copy_from_slice(node.hash.as_bytes());
d[32..].copy_from_slice(current.as_bytes());
} else {
d[..32].copy_from_slice(current.as_bytes());
d[32..].copy_from_slice(node.hash.as_bytes());
}
d
};
current = Hash::new(csv_tagged_hash("mpc-internal", &sibling_data));
}
current == *root
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MpcTree {
pub leaves: Vec<MpcLeaf>,
}
impl MpcTree {
pub fn new(leaves: Vec<MpcLeaf>) -> Self {
Self { leaves }
}
pub fn from_pairs(pairs: &[(ProtocolId, Hash)]) -> Self {
let leaves = pairs
.iter()
.map(|(pid, comm)| MpcLeaf::new(*pid, *comm))
.collect();
Self { leaves }
}
pub fn root(&self) -> Hash {
if self.leaves.is_empty() {
return Hash::zero();
}
if self.leaves.len() == 1 {
return self.leaves[0].hash();
}
let mut hashes: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
while hashes.len() > 1 {
let mut next_level = Vec::new();
for chunk in hashes.chunks(2) {
let left = &chunk[0];
if chunk.len() == 1 {
next_level.push(*left);
} else {
let right = &chunk[1];
next_level.push(hash_pair(left, right));
}
}
hashes = next_level;
}
hashes[0]
}
pub fn prove(&self, protocol_id: ProtocolId) -> Option<MpcProof> {
let leaf_index = self
.leaves
.iter()
.position(|l| l.protocol_id == protocol_id)?;
let leaf = &self.leaves[leaf_index];
let mut levels: Vec<Vec<Hash>> = Vec::new();
let current_level: Vec<Hash> = self.leaves.iter().map(|l| l.hash()).collect();
levels.push(current_level.clone());
let mut hashes = current_level;
while hashes.len() > 1 {
let mut next_level = Vec::new();
for chunk in hashes.chunks(2) {
let left = &chunk[0];
if chunk.len() == 1 {
next_level.push(*left);
} else {
next_level.push(hash_pair(left, &chunk[1]));
}
}
hashes = next_level;
levels.push(hashes.clone());
}
let mut branch = Vec::new();
let mut idx = leaf_index;
for level_idx in 0..levels.len() - 1 {
let level = &levels[level_idx];
let (sibling_idx, is_left) = if idx % 2 == 0 {
(idx + 1, false) } else {
(idx - 1, true) };
if sibling_idx < level.len() {
branch.push(MerkleBranchNode {
hash: level[sibling_idx],
is_left,
});
}
idx /= 2;
}
Some(MpcProof {
protocol_id: leaf.protocol_id,
commitment: leaf.commitment,
branch,
leaf_index,
})
}
pub fn protocol_count(&self) -> usize {
self.leaves.len()
}
pub fn contains_protocol(&self, protocol_id: ProtocolId) -> bool {
self.leaves.iter().any(|l| l.protocol_id == protocol_id)
}
pub fn push(&mut self, protocol_id: ProtocolId, commitment: Hash) {
self.leaves.push(MpcLeaf::new(protocol_id, commitment));
}
}
fn hash_pair(left: &Hash, right: &Hash) -> Hash {
let mut data = [0u8; 64];
data[..32].copy_from_slice(left.as_bytes());
data[32..].copy_from_slice(right.as_bytes());
Hash::new(csv_tagged_hash("mpc-internal", &data))
}
#[cfg(test)]
mod tests {
use super::*;
fn test_protocol(id: u8) -> ProtocolId {
let mut arr = [0u8; 32];
arr[0] = id;
arr
}
fn test_commitment(id: u8) -> Hash {
let mut arr = [0u8; 32];
arr[31] = id;
Hash::new(arr)
}
#[test]
fn test_leaf_creation() {
let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
assert_eq!(leaf.protocol_id[0], 1);
assert_eq!(leaf.commitment.as_bytes()[31], 42);
}
#[test]
fn test_leaf_hash_deterministic() {
let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
let leaf2 = MpcLeaf::new(test_protocol(1), test_commitment(42));
assert_eq!(leaf1.hash(), leaf2.hash());
}
#[test]
fn test_leaf_hash_differs_by_protocol() {
let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
let leaf2 = MpcLeaf::new(test_protocol(2), test_commitment(42));
assert_ne!(leaf1.hash(), leaf2.hash());
}
#[test]
fn test_leaf_hash_differs_by_commitment() {
let leaf1 = MpcLeaf::new(test_protocol(1), test_commitment(42));
let leaf2 = MpcLeaf::new(test_protocol(1), test_commitment(99));
assert_ne!(leaf1.hash(), leaf2.hash());
}
#[test]
fn test_empty_tree_root() {
let tree = MpcTree::new(vec![]);
assert_eq!(tree.root(), Hash::zero());
}
#[test]
fn test_single_leaf_tree_root() {
let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
let tree = MpcTree::new(vec![leaf.clone()]);
assert_eq!(tree.root(), leaf.hash());
}
#[test]
fn test_two_leaf_tree_root() {
let leaf_a = MpcLeaf::new(test_protocol(1), test_commitment(1));
let leaf_b = MpcLeaf::new(test_protocol(2), test_commitment(2));
let tree = MpcTree::new(vec![leaf_a.clone(), leaf_b.clone()]);
let expected = hash_pair(&leaf_a.hash(), &leaf_b.hash());
assert_eq!(tree.root(), expected);
}
#[test]
fn test_three_leaf_tree_root() {
let leaf_a = MpcLeaf::new(test_protocol(1), test_commitment(1));
let leaf_b = MpcLeaf::new(test_protocol(2), test_commitment(2));
let leaf_c = MpcLeaf::new(test_protocol(3), test_commitment(3));
let tree = MpcTree::new(vec![leaf_a.clone(), leaf_b.clone(), leaf_c.clone()]);
let ab = hash_pair(&leaf_a.hash(), &leaf_b.hash());
let expected = hash_pair(&ab, &leaf_c.hash());
assert_eq!(tree.root(), expected);
}
#[test]
fn test_four_leaf_tree_root() {
let leaves: Vec<_> = (1..=4)
.map(|i| MpcLeaf::new(test_protocol(i), test_commitment(i)))
.collect();
let tree = MpcTree::new(leaves.clone());
let ab = hash_pair(&leaves[0].hash(), &leaves[1].hash());
let cd = hash_pair(&leaves[2].hash(), &leaves[3].hash());
let expected = hash_pair(&ab, &cd);
assert_eq!(tree.root(), expected);
}
#[test]
fn test_tree_root_deterministic() {
let tree1 = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
(test_protocol(3), test_commitment(3)),
]);
let tree2 = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
(test_protocol(3), test_commitment(3)),
]);
assert_eq!(tree1.root(), tree2.root());
}
#[test]
fn test_large_tree_root() {
let pairs: Vec<_> = (1..=100)
.map(|i| (test_protocol(i as u8), test_commitment(i as u8)))
.collect();
let tree = MpcTree::from_pairs(&pairs);
let root = tree.root();
assert_eq!(root.as_bytes().len(), 32);
}
#[test]
fn test_proof_single_leaf() {
let leaf = MpcLeaf::new(test_protocol(1), test_commitment(42));
let tree = MpcTree::new(vec![leaf.clone()]);
let proof = tree.prove(test_protocol(1)).unwrap();
assert!(proof.verify(&tree.root()));
}
#[test]
fn test_proof_two_leaves() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
]);
let proof_a = tree.prove(test_protocol(1)).unwrap();
let proof_b = tree.prove(test_protocol(2)).unwrap();
assert!(proof_a.verify(&tree.root()));
assert!(proof_b.verify(&tree.root()));
}
#[test]
fn test_proof_three_leaves() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
(test_protocol(3), test_commitment(3)),
]);
for i in 1..=3 {
let proof = tree.prove(test_protocol(i)).unwrap();
assert!(proof.verify(&tree.root()));
}
}
#[test]
fn test_proof_all_leaves_in_large_tree() {
let pairs: Vec<_> = (1..=20)
.map(|i| (test_protocol(i as u8), test_commitment(i as u8)))
.collect();
let tree = MpcTree::from_pairs(&pairs);
for i in 1..=20 {
let proof = tree.prove(test_protocol(i as u8)).unwrap();
assert!(
proof.verify(&tree.root()),
"Proof for protocol {} failed",
i
);
}
}
#[test]
fn test_proof_missing_protocol() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
]);
assert!(tree.prove(test_protocol(99)).is_none());
}
#[test]
fn test_proof_wrong_root() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
]);
let proof = tree.prove(test_protocol(1)).unwrap();
assert!(!proof.verify(&Hash::new([0xFF; 32])));
}
#[test]
fn test_proof_wrong_commitment() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
]);
let mut proof = tree.prove(test_protocol(1)).unwrap();
proof.commitment = test_commitment(99);
assert!(!proof.verify(&tree.root()));
}
#[test]
fn test_proof_wrong_protocol_id() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
]);
let mut proof = tree.prove(test_protocol(1)).unwrap();
proof.protocol_id = test_protocol(99);
assert!(!proof.verify(&tree.root()));
}
#[test]
fn test_proof_branch_tampering() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
(test_protocol(3), test_commitment(3)),
]);
let mut proof = tree.prove(test_protocol(1)).unwrap();
proof.branch[0].hash = Hash::new([0xFF; 32]);
assert!(!proof.verify(&tree.root()));
}
#[test]
fn test_from_pairs() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
]);
assert_eq!(tree.protocol_count(), 2);
assert!(tree.contains_protocol(test_protocol(1)));
assert!(tree.contains_protocol(test_protocol(2)));
assert!(!tree.contains_protocol(test_protocol(3)));
}
#[test]
fn test_push() {
let mut tree = MpcTree::from_pairs(&[(test_protocol(1), test_commitment(1))]);
assert_eq!(tree.protocol_count(), 1);
tree.push(test_protocol(2), test_commitment(2));
assert_eq!(tree.protocol_count(), 2);
assert!(tree.contains_protocol(test_protocol(2)));
}
#[test]
fn test_leaf_index_in_proof() {
let tree = MpcTree::from_pairs(&[
(test_protocol(1), test_commitment(1)),
(test_protocol(2), test_commitment(2)),
(test_protocol(3), test_commitment(3)),
]);
let proof_0 = tree.prove(test_protocol(1)).unwrap();
let proof_2 = tree.prove(test_protocol(3)).unwrap();
assert_eq!(proof_0.leaf_index, 0);
assert_eq!(proof_2.leaf_index, 2);
}
}