use ant_merkle::Hasher;
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use xor_name::XorName;
use super::merkle_payment::sha3_256;
pub use evmlib::merkle_batch_payment::MAX_MERKLE_DEPTH;
/// Minimum number of addresses accepted by `MerkleTree::from_xornames`.
pub const MIN_LEAVES: usize = 2;
/// Maximum number of addresses accepted: a full tree of depth `MAX_MERKLE_DEPTH`.
pub const MAX_LEAVES: usize = 1 << MAX_MERKLE_DEPTH;
/// How long a Merkle payment stays valid, in seconds (7 days).
pub const MERKLE_PAYMENT_EXPIRATION: u64 = 7 * 24 * 60 * 60;
/// Number of reward-candidate pools for a tree of the given `depth`.
///
/// This equals the number of nodes at the midpoint level, `2^ceil(depth / 2)`. It is also
/// the leaf count of the sub-tree that midpoint proofs are checked against.
pub fn expected_reward_pools(depth: u8) -> usize {
    let levels_above_midpoint = midpoint_proof_depth(depth);
    1usize << levels_above_midpoint
}
/// Errors raised while building a `MerkleTree` or generating proofs from it.
#[derive(Debug, Error)]
pub enum MerkleTreeError {
    /// Fewer than `MIN_LEAVES` addresses were supplied.
    #[error("Too few leaves: got {got}, minimum is {MIN_LEAVES}")]
    TooFewLeaves { got: usize },
    /// More than `MAX_LEAVES` addresses were supplied.
    #[error("Too many leaves: got {got}, maximum is {MAX_LEAVES}")]
    TooManyLeaves { got: usize },
    /// An address index at or beyond the number of real (non-padding) leaves was requested.
    #[error("Invalid leaf index: {index} (tree has {leaf_count} leaves)")]
    InvalidLeafIndex { index: usize, leaf_count: usize },
    /// A midpoint index at or beyond the number of midpoint nodes was requested.
    #[error("Invalid midpoint index: {index} (tree has {midpoint_count} midpoints)")]
    InvalidMidpointIndex { index: usize, midpoint_count: usize },
    /// Displays as "Invalid proof".
    /// NOTE(review): not constructed anywhere in the code visible here.
    #[error("Invalid proof")]
    InvalidProof,
    /// The underlying `ant_merkle` tree broke an expected invariant, such as having no
    /// root after construction or no nodes at the midpoint level.
    #[error("Internal error: {0}")]
    Internal(String),
}
/// Result alias for `MerkleTree` construction and proof generation.
pub type Result<T> = std::result::Result<T, MerkleTreeError>;
pub struct MerkleTree {
inner: ant_merkle::MerkleTree<Sha3Hasher>,
leaf_count: usize,
depth: u8,
root: XorName,
salts: Vec<[u8; 32]>,
}
impl MerkleTree {
pub fn from_xornames(leaves: Vec<XorName>) -> Result<Self> {
let leaf_count = leaves.len();
if leaf_count < MIN_LEAVES {
return Err(MerkleTreeError::TooFewLeaves { got: leaf_count });
}
if leaf_count > MAX_LEAVES {
return Err(MerkleTreeError::TooManyLeaves { got: leaf_count });
}
let mut rng = rand::thread_rng();
let salts: Vec<[u8; 32]> = (0..leaf_count)
.map(|_| {
let mut salt = [0u8; 32];
rand::Rng::fill(&mut rng, &mut salt);
salt
})
.collect();
let depth = tree_depth(leaf_count);
let padded_size = 1 << depth;
let mut salted_leaves: Vec<[u8; 32]> = leaves
.iter()
.zip(&salts)
.map(|(address, salt)| {
let mut data = Vec::with_capacity(64);
data.extend_from_slice(address.as_ref());
data.extend_from_slice(salt);
Sha3Hasher::hash(&data)
})
.collect();
if leaf_count < padded_size {
for _ in leaf_count..padded_size {
let mut dummy = [0u8; 32];
rand::Rng::fill(&mut rng, &mut dummy);
salted_leaves.push(dummy);
}
}
let inner = ant_merkle::MerkleTree::<Sha3Hasher>::from_leaves(&salted_leaves);
let root = inner.root().ok_or(MerkleTreeError::Internal(
"Tree must have root after construction".to_string(),
))?;
Ok(Self {
inner,
root: XorName(root),
leaf_count,
depth,
salts,
})
}
pub fn root(&self) -> XorName {
self.root
}
pub fn depth(&self) -> u8 {
self.depth
}
pub fn leaf_count(&self) -> usize {
self.leaf_count
}
fn midpoints(&self) -> Result<Vec<MerkleMidpoint>> {
let level = midpoint_level(self.depth);
let nodes = self
.inner
.get_nodes_at_level(level)
.ok_or(MerkleTreeError::Internal(
"Midpoint level must exist".to_string(),
))?;
let midpoints: Vec<MerkleMidpoint> = nodes
.into_iter()
.map(|(index, hash)| MerkleMidpoint {
hash: XorName(hash),
index,
})
.collect();
Ok(midpoints)
}
pub fn reward_candidates(&self, merkle_payment_timestamp: u64) -> Result<Vec<MidpointProof>> {
let midpoints = self.midpoints()?;
midpoints
.into_iter()
.map(|midpoint| {
let branch = self.generate_midpoint_proof(midpoint.index, midpoint.hash)?;
Ok(MidpointProof {
branch,
merkle_payment_timestamp,
})
})
.collect()
}
pub fn generate_address_proof(
&self,
address_index: usize,
address_hash: XorName,
) -> Result<MerkleBranch> {
if address_index >= self.leaf_count {
return Err(MerkleTreeError::InvalidLeafIndex {
index: address_index,
leaf_count: self.leaf_count,
});
}
let indices = vec![address_index];
let proof = self.inner.proof(&indices);
let padded_size = 1 << self.depth;
let root = self.root();
let salt = self.salts[address_index];
Ok(MerkleBranch::from_rs_merkle_proof(
proof,
address_index,
padded_size,
address_hash,
root,
Some(salt),
))
}
fn generate_midpoint_proof(
&self,
midpoint_index: usize,
midpoint_hash: XorName,
) -> Result<MerkleBranch> {
let level = midpoint_level(self.depth);
let midpoint_count = expected_reward_pools(self.depth);
if midpoint_index >= midpoint_count {
return Err(MerkleTreeError::InvalidMidpointIndex {
index: midpoint_index,
midpoint_count,
});
}
let proof = self
.inner
.proof_from_node(level, midpoint_index)
.ok_or_else(|| {
MerkleTreeError::Internal("Failed to generate midpoint proof".to_string())
})?;
let effective_leaf_count = midpoint_count;
let root = self.root();
Ok(MerkleBranch::from_rs_merkle_proof(
proof,
midpoint_index,
effective_leaf_count,
midpoint_hash,
root,
None, ))
}
}
/// A node at the midpoint level of a `MerkleTree`, as produced by `MerkleTree::midpoints`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
struct MerkleMidpoint {
    /// Hash of the node at the midpoint level.
    hash: XorName,
    /// Index of the node within the midpoint level, as reported by `get_nodes_at_level`.
    index: usize,
}
/// Proof from one midpoint node up to the tree root, tagged with the payment timestamp.
///
/// Each instance describes one reward-candidate pool (see `MerkleTree::reward_candidates`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct MidpointProof {
    /// Branch whose leaf is the midpoint node hash. Built without a salt by `MerkleTree`.
    pub branch: MerkleBranch,
    /// Payment timestamp; mixed into both `address()` and `hash()`.
    pub merkle_payment_timestamp: u64,
}

impl MidpointProof {
    /// Root this midpoint proof commits to.
    pub fn root(&self) -> &XorName {
        &self.branch.root
    }

    /// Pool address, computed as `XorName::from_content(midpoint || root || timestamp_le)`.
    pub fn address(&self) -> XorName {
        let timestamp = self.merkle_payment_timestamp.to_le_bytes();
        let parts: [&[u8]; 3] = [
            self.branch.leaf_hash().as_ref(),
            self.branch.root().as_ref(),
            &timestamp,
        ];
        XorName::from_content(&parts.concat())
    }

    /// `sha3_256` digest committing to every field of this proof.
    ///
    /// Layout of the hashed bytes:
    /// 1. all proof hashes, in order;
    /// 2. `leaf_index` and `total_leaves_count`, each widened to `u64` little-endian so the
    ///    encoding does not depend on the platform's `usize` width;
    /// 3. the unsalted leaf and the root;
    /// 4. a salt tag (`1` followed by the salt, or a lone `0`);
    /// 5. the timestamp as `u64` little-endian.
    ///
    /// Tests in this file pin this exact layout.
    pub fn hash(&self) -> [u8; 32] {
        let branch = &self.branch;
        let mut preimage = branch.proof_hashes.concat();
        // Fixed-width tail: 8 + 8 + 32 + 32 + 1 (tag) + 8, plus 32 when a salt is present.
        preimage.reserve(89 + if branch.salt.is_some() { 32 } else { 0 });
        preimage.extend_from_slice(&(branch.leaf_index as u64).to_le_bytes());
        preimage.extend_from_slice(&(branch.total_leaves_count as u64).to_le_bytes());
        preimage.extend_from_slice(branch.unsalted_leaf_hash.as_ref());
        preimage.extend_from_slice(branch.root.as_ref());
        match branch.salt {
            Some(salt) => {
                preimage.push(1);
                preimage.extend_from_slice(&salt);
            }
            None => preimage.push(0),
        }
        preimage.extend_from_slice(&self.merkle_payment_timestamp.to_le_bytes());
        sha3_256(&preimage)
    }
}
/// Self-contained Merkle inclusion proof that can be checked without the tree.
///
/// Field order is part of the serialized form and must not change.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct MerkleBranch {
    /// Sibling hashes from `ant_merkle`'s `proof_hashes()`. Their count is the proof depth.
    proof_hashes: Vec<[u8; 32]>,
    /// Position of the proven leaf among `total_leaves_count` leaves.
    leaf_index: usize,
    /// Leaf count the proof is checked against. This is the padded size for address
    /// proofs and the midpoint count for midpoint proofs.
    total_leaves_count: usize,
    /// Leaf value before salting: an address, or a midpoint node hash.
    unsalted_leaf_hash: XorName,
    /// Root the proof is expected to reach.
    root: XorName,
    /// Per-leaf salt for address proofs; `None` for midpoint proofs.
    salt: Option<[u8; 32]>,
}

impl MerkleBranch {
    /// Wraps an `ant_merkle` proof together with the metadata needed to verify it later.
    fn from_rs_merkle_proof(
        proof: ant_merkle::MerkleProof<Sha3Hasher>,
        leaf_index: usize,
        total_leaves_count: usize,
        unsalted_leaf_hash: XorName,
        root: XorName,
        salt: Option<[u8; 32]>,
    ) -> Self {
        Self {
            proof_hashes: proof.proof_hashes().to_vec(),
            leaf_index,
            total_leaves_count,
            unsalted_leaf_hash,
            root,
            salt,
        }
    }

    /// The unsalted leaf this branch proves.
    pub fn leaf_hash(&self) -> &XorName {
        &self.unsalted_leaf_hash
    }

    /// The root this branch claims to reach.
    pub fn root(&self) -> &XorName {
        &self.root
    }

    /// Checks the branch with `ant_merkle::MerkleProof::verify` against `self.root`.
    ///
    /// With a salt, the proven leaf is `Sha3Hasher::hash(unsalted_leaf || salt)`. Without
    /// one, the unsalted leaf bytes are used directly.
    pub fn verify(&self) -> bool {
        let leaf = match self.salt {
            Some(salt) => {
                let mut preimage = Vec::with_capacity(64);
                preimage.extend_from_slice(&self.unsalted_leaf_hash.0);
                preimage.extend_from_slice(&salt);
                Sha3Hasher::hash(&preimage)
            }
            None => self.unsalted_leaf_hash.0,
        };
        ant_merkle::MerkleProof::<Sha3Hasher>::new(self.proof_hashes.clone()).verify(
            self.root.0,
            &[self.leaf_index],
            &[leaf],
            self.total_leaves_count,
        )
    }

    /// Proof depth, i.e. the number of sibling hashes carried.
    pub fn depth(&self) -> usize {
        self.proof_hashes.len()
    }
}
/// Depth of the smallest full binary tree that holds `leaf_count` leaves.
///
/// This is `ceil(log2(leaf_count))`, with `0` returned for zero or one leaf.
pub fn tree_depth(leaf_count: usize) -> u8 {
    match leaf_count {
        0 | 1 => 0,
        // The bit length of `n - 1` equals ceil(log2(n)) for n >= 2. It is at most
        // `usize::BITS`, so the cast to `u8` cannot truncate.
        n => (usize::BITS - (n - 1).leading_zeros()) as u8,
    }
}
/// Number of sibling hashes in a midpoint-to-root proof: `ceil(depth / 2)`.
///
/// Together with `midpoint_level` (`floor(depth / 2)`), this sums to `depth`.
pub fn midpoint_proof_depth(depth: u8) -> u8 {
    // Halve and round up. Written this way it cannot overflow, even for `u8::MAX`.
    (depth >> 1) + (depth & 1)
}
/// Tree level (counting leaves as level 0) that holds the midpoint nodes: `floor(depth / 2)`.
fn midpoint_level(depth: u8) -> usize {
    usize::from(depth >> 1)
}
/// Reasons `verify_merkle_proof` rejects a Merkle payment proof.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum BadMerkleProof {
    /// The address branch is malformed or failed `MerkleBranch::verify`.
    #[error("Address branch proof failed Merkle verification")]
    InvalidAddressBranchProof,
    /// The winner (midpoint) branch is malformed or failed `MerkleBranch::verify`.
    #[error("Winner/intersection branch proof failed Merkle verification")]
    InvalidWinnerBranchProof,
    /// The address proof's hash count differs from the on-chain depth.
    #[error("Address proof depth mismatch: expected {expected}, got {got}")]
    AddressProofDepthMismatch { expected: usize, got: usize },
    /// The winner proof's hash count differs from `midpoint_proof_depth(on-chain depth)`.
    #[error("Winner proof depth mismatch: expected {expected}, got {got}")]
    WinnerProofDepthMismatch { expected: usize, got: usize },
    /// The address branch's root is not the root stored on-chain.
    #[error(
        "Address branch root doesn't match smart contract root: smart_contract={smart_contract_root}, branch={branch_root}"
    )]
    AddressBranchRootMismatch {
        smart_contract_root: XorName,
        branch_root: XorName,
    },
    /// The winner branch's root is not the root stored on-chain.
    #[error(
        "Winner branch root doesn't match smart contract root: smart_contract={smart_contract_root}, branch={branch_root}"
    )]
    WinnerBranchRootMismatch {
        smart_contract_root: XorName,
        branch_root: XorName,
    },
    /// The on-chain payment timestamp is later than the local clock (both in Unix seconds).
    #[error(
        "Payment timestamp {payment_timestamp} is in the future (current time: {current_time})"
    )]
    TimestampInFuture {
        payment_timestamp: u64,
        current_time: u64,
    },
    /// The payment is older than `MERKLE_PAYMENT_EXPIRATION` seconds.
    #[error(
        "Payment expired: timestamp {payment_timestamp} is {age_seconds}s old (max: {MERKLE_PAYMENT_EXPIRATION}s)"
    )]
    PaymentExpired {
        payment_timestamp: u64,
        current_time: u64,
        age_seconds: u64,
    },
    /// The system clock could not be read relative to the Unix epoch.
    #[error("Failed to get current system time: {0}")]
    SystemTimeError(String),
    /// The winner pool's timestamp differs from the on-chain payment timestamp.
    #[error(
        "Winner pool timestamp {pool_timestamp} doesn't match smart contract timestamp {contract_timestamp}"
    )]
    TimestampMismatch {
        pool_timestamp: u64,
        contract_timestamp: u64,
    },
    /// The address being checked is not the leaf the address branch proves.
    #[error("Address hash not matching branch leaf: leaf={leaf}, address={address}")]
    AddressHashNotBranchLeaf { leaf: XorName, address: XorName },
}
/// Checks the on-chain payment timestamp against the local clock and the winner pool.
///
/// All timestamps are Unix seconds.
///
/// # Errors
/// - `SystemTimeError` if the system clock reads earlier than the Unix epoch.
/// - `TimestampInFuture` if `payment_timestamp` is later than the current time.
/// - `PaymentExpired` if the payment is more than `MERKLE_PAYMENT_EXPIRATION` seconds old.
/// - `TimestampMismatch` if `pool_timestamp` differs from `payment_timestamp`.
fn validate_payment_timestamp(
    payment_timestamp: u64,
    pool_timestamp: u64,
) -> std::result::Result<(), BadMerkleProof> {
    let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(e) => return Err(BadMerkleProof::SystemTimeError(e.to_string())),
    };

    // `checked_sub` fails exactly when the payment claims to be newer than `now`.
    let age = now
        .checked_sub(payment_timestamp)
        .ok_or(BadMerkleProof::TimestampInFuture {
            payment_timestamp,
            current_time: now,
        })?;

    if age > MERKLE_PAYMENT_EXPIRATION {
        return Err(BadMerkleProof::PaymentExpired {
            payment_timestamp,
            current_time: now,
            age_seconds: age,
        });
    }

    if pool_timestamp == payment_timestamp {
        Ok(())
    } else {
        Err(BadMerkleProof::TimestampMismatch {
            pool_timestamp,
            contract_timestamp: payment_timestamp,
        })
    }
}
/// Verifies that `address_hash` is covered by an on-chain Merkle batch payment.
///
/// Both branch types derive `Deserialize`, so their fields may come from an untrusted
/// source. Every field that influences verification is therefore checked against the
/// on-chain commitment (`smart_contract_depth`, `smart_contract_root`,
/// `smart_contract_timestamp`).
///
/// Checks, in order:
/// 1. The payment timestamp is not in the future, not expired, and equals the winner
///    pool's timestamp (`validate_payment_timestamp`).
/// 2. The address proof has `smart_contract_depth` sibling hashes, and the winner proof has
///    `midpoint_proof_depth(smart_contract_depth)`.
/// 3. Each branch claims a full tree of the matching size (`2^depth` leaves, with an
///    in-range leaf index), and the address branch carries a salt. This is exactly the
///    shape `MerkleTree` produces.
/// 4. Both branches pass `MerkleBranch::verify`.
/// 5. `address_hash` is the address branch's leaf.
/// 6. Both branch roots equal `smart_contract_root`.
///
/// This function does not compare the winner pool's hash with any on-chain value.
///
/// # Errors
/// Returns the `BadMerkleProof` variant for the first failing check. A malformed branch
/// shape is reported as `InvalidAddressBranchProof` or `InvalidWinnerBranchProof`.
pub fn verify_merkle_proof(
    address_hash: &XorName,
    address_branch: &MerkleBranch,
    winner_pool_midpoint_proof: &MidpointProof,
    smart_contract_depth: u8,
    smart_contract_root: &XorName,
    smart_contract_timestamp: u64,
) -> std::result::Result<(), BadMerkleProof> {
    validate_payment_timestamp(
        smart_contract_timestamp,
        winner_pool_midpoint_proof.merkle_payment_timestamp,
    )?;

    let winner_branch = &winner_pool_midpoint_proof.branch;
    let winner_proof_depth = midpoint_proof_depth(smart_contract_depth);

    let address_depth = address_branch.depth();
    let expected_address_depth = smart_contract_depth as usize;
    if address_depth != expected_address_depth {
        return Err(BadMerkleProof::AddressProofDepthMismatch {
            expected: expected_address_depth,
            got: address_depth,
        });
    }

    let winner_depth = winner_branch.depth();
    let expected_winner_depth = winner_proof_depth as usize;
    if winner_depth != expected_winner_depth {
        return Err(BadMerkleProof::WinnerProofDepthMismatch {
            expected: expected_winner_depth,
            got: winner_depth,
        });
    }

    // The hash count alone does not pin the tree shape. `total_leaves_count` and
    // `leaf_index` are also fed to `MerkleProof::verify`, so bind them to the on-chain depth.
    // Every address branch built by `generate_address_proof` carries a salt. Without one,
    // `MerkleBranch::verify` would take the claimed leaf verbatim instead of re-deriving
    // `hash(address || salt)`, so an unsalted address branch is rejected.
    if address_branch.salt.is_none()
        || !has_full_tree_shape(address_branch, smart_contract_depth)
    {
        return Err(BadMerkleProof::InvalidAddressBranchProof);
    }
    if !has_full_tree_shape(winner_branch, winner_proof_depth) {
        return Err(BadMerkleProof::InvalidWinnerBranchProof);
    }

    if !address_branch.verify() {
        return Err(BadMerkleProof::InvalidAddressBranchProof);
    }
    if !winner_branch.verify() {
        return Err(BadMerkleProof::InvalidWinnerBranchProof);
    }

    if address_hash != address_branch.leaf_hash() {
        return Err(BadMerkleProof::AddressHashNotBranchLeaf {
            leaf: *address_branch.leaf_hash(),
            address: *address_hash,
        });
    }

    if address_branch.root() != smart_contract_root {
        return Err(BadMerkleProof::AddressBranchRootMismatch {
            smart_contract_root: *smart_contract_root,
            branch_root: *address_branch.root(),
        });
    }
    if winner_branch.root() != smart_contract_root {
        return Err(BadMerkleProof::WinnerBranchRootMismatch {
            smart_contract_root: *smart_contract_root,
            branch_root: *winner_branch.root(),
        });
    }

    Ok(())
}

/// Returns `true` when `branch` claims exactly `2^depth` leaves and a leaf index inside
/// that range.
///
/// `MerkleTree::generate_address_proof` and `generate_midpoint_proof` always emit branches
/// of this shape. A depth too large for `usize` yields `false` instead of overflowing.
fn has_full_tree_shape(branch: &MerkleBranch, depth: u8) -> bool {
    1usize
        .checked_shl(u32::from(depth))
        .is_some_and(|leaves| branch.total_leaves_count == leaves && branch.leaf_index < leaves)
}
/// `ant_merkle` hasher used for every tree and proof in this module. It delegates to
/// `sha3_256`.
#[derive(Clone)]
struct Sha3Hasher;

impl ant_merkle::Hasher for Sha3Hasher {
    type Hash = [u8; 32];

    fn hash(data: &[u8]) -> Self::Hash {
        sha3_256(data)
    }

    /// Hashes `left || right`. A node with no right sibling is re-hashed on its own
    /// (`sha3_256(left)`) rather than carried up unchanged.
    fn concat_and_hash(left: &Self::Hash, right: Option<&Self::Hash>) -> Self::Hash {
        let Some(right) = right else {
            return sha3_256(left);
        };
        // Both halves are fixed-size, so a stack buffer holds the exact same 64 bytes.
        let mut pair = [0u8; 64];
        pair[..32].copy_from_slice(left);
        pair[32..].copy_from_slice(right);
        sha3_256(&pair)
    }

    fn hash_size() -> usize {
        32
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn make_test_leaves(count: usize) -> Vec<XorName> {
(0..count)
.map(|i| XorName::from_content(&i.to_le_bytes()))
.collect()
}
#[test]
fn test_reward_candidate_pool_hash_fixed_width_encoding() {
let leaves = make_test_leaves(16);
let tree = MerkleTree::from_xornames(leaves).unwrap();
let timestamp = 1234567890u64;
let pools = tree.reward_candidates(timestamp).unwrap();
let pool = &pools[0];
let hash1 = pool.hash();
let mut bytes = Vec::new();
for proof_hash in &pool.branch.proof_hashes {
bytes.extend_from_slice(proof_hash);
}
bytes.extend_from_slice(&(pool.branch.leaf_index as u64).to_le_bytes());
bytes.extend_from_slice(&(pool.branch.total_leaves_count as u64).to_le_bytes());
bytes.extend_from_slice(pool.branch.unsalted_leaf_hash.as_ref());
bytes.extend_from_slice(pool.branch.root.as_ref());
if let Some(salt) = &pool.branch.salt {
bytes.push(1);
bytes.extend_from_slice(salt);
} else {
bytes.push(0);
}
bytes.extend_from_slice(&pool.merkle_payment_timestamp.to_le_bytes());
let hash2 = sha3_256(&bytes);
assert_eq!(
hash1, hash2,
"MidpointProof::hash should match manual u64-encoded hash"
);
}
#[test]
fn test_reward_candidate_pool_hash_architecture_independence() {
let leaves = make_test_leaves(4);
let tree = MerkleTree::from_xornames(leaves).unwrap();
let timestamp = u64::MAX;
let pools = tree.reward_candidates(timestamp).unwrap();
let hash1 = pools[0].hash();
let hash2 = pools[0].hash();
assert_eq!(hash1, hash2, "Same pool should produce identical hash");
let pool = &pools[0];
let mut bytes = Vec::new();
for proof_hash in &pool.branch.proof_hashes {
bytes.extend_from_slice(proof_hash);
}
let start_offset = bytes.len();
bytes.extend_from_slice(&(pool.branch.leaf_index as u64).to_le_bytes());
bytes.extend_from_slice(&(pool.branch.total_leaves_count as u64).to_le_bytes());
assert_eq!(
bytes.len() - start_offset,
16, "Should use 8 bytes per usize field regardless of platform"
);
let leaf_index_bytes = &bytes[start_offset..start_offset + 8];
let leaf_index = u64::from_le_bytes(leaf_index_bytes.try_into().unwrap());
assert_eq!(
leaf_index, pool.branch.leaf_index as u64,
"leaf_index should be preserved in u64 encoding"
);
}
#[test]
fn test_expected_reward_pools() {
assert_eq!(expected_reward_pools(1), 2); assert_eq!(expected_reward_pools(2), 2); assert_eq!(expected_reward_pools(3), 4); assert_eq!(expected_reward_pools(4), 4); assert_eq!(expected_reward_pools(5), 8); assert_eq!(expected_reward_pools(6), 8); assert_eq!(expected_reward_pools(7), 16); assert_eq!(expected_reward_pools(8), 16); assert_eq!(expected_reward_pools(16), 256); }
#[test]
fn test_blake2b_output_size() {
let hash1 = Sha3Hasher::hash(b"test data");
let hash2 = Sha3Hasher::concat_and_hash(&hash1, Some(&hash1));
assert_eq!(hash1.len(), 32, "Hash should be 32 bytes (256 bits)");
assert_eq!(
hash2.len(),
32,
"Concatenated hash should be 32 bytes (256 bits)"
);
let hash3 = Sha3Hasher::hash(b"different data");
assert_ne!(
hash1, hash3,
"Different inputs should produce different hashes"
);
println!("Blake2b hash size verified: 32 bytes (256 bits)");
println!("Sample hash: {:02x?}", &hash1[..8]);
}
#[test]
fn test_reward_candidate_pool_hash() {
let leaves = make_test_leaves(16);
let tree = MerkleTree::from_xornames(leaves).unwrap();
let candidates = tree.reward_candidates(12345).unwrap();
let mut seen = std::collections::HashSet::new();
for candidate in &candidates {
assert!(seen.insert(candidate));
}
assert_eq!(seen.len(), candidates.len());
let hash1 = candidates[0].hash();
let hash2 = candidates[0].hash();
assert_eq!(hash1, hash2, "Hash should be deterministic");
let hash3 = candidates[1].hash();
assert_ne!(
hash1, hash3,
"Different candidates should have different hashes"
);
}
#[test]
fn test_min_leaves_validation() {
let leaves = make_test_leaves(1);
let result = MerkleTree::from_xornames(leaves);
assert!(matches!(result, Err(MerkleTreeError::TooFewLeaves { .. })));
}
#[test]
fn test_max_leaves_validation() {
let leaves = make_test_leaves(MAX_LEAVES + 1);
let result = MerkleTree::from_xornames(leaves);
assert!(matches!(result, Err(MerkleTreeError::TooManyLeaves { .. })));
}
#[test]
fn test_basic_tree_construction() {
let leaves = make_test_leaves(100);
let tree = MerkleTree::from_xornames(leaves).unwrap();
assert_eq!(tree.leaf_count(), 100);
assert_eq!(tree.depth(), 7); }
#[test]
fn test_power_of_two_leaves() {
for power in 1..=MAX_MERKLE_DEPTH {
let count = 1 << power; let leaves = make_test_leaves(count);
let tree = MerkleTree::from_xornames(leaves).unwrap();
assert_eq!(tree.depth(), power);
assert_eq!(tree.leaf_count(), count);
}
}
#[test]
fn test_midpoints() {
let leaves = make_test_leaves(256);
let tree = MerkleTree::from_xornames(leaves).unwrap();
let midpoints = tree.midpoints().unwrap();
assert_eq!(midpoints.len(), 16);
for (i, midpoint) in midpoints.iter().enumerate() {
assert_eq!(midpoint.index, i);
}
}
#[test]
fn test_reward_candidates() {
let leaves = make_test_leaves(256);
let tree = MerkleTree::from_xornames(leaves).unwrap();
let merkle_payment_timestamp = 1234567890u64;
let candidates = tree.reward_candidates(merkle_payment_timestamp).unwrap();
assert_eq!(candidates.len(), 16);
let mut addresses = std::collections::HashSet::new();
for candidate in &candidates {
assert!(addresses.insert(candidate.address()));
}
for candidate in &candidates {
assert!(
candidate.branch.verify(),
"Candidate branch should be valid"
);
}
let candidates2 = tree.reward_candidates(merkle_payment_timestamp).unwrap();
assert_eq!(candidates, candidates2);
let candidates3 = tree
.reward_candidates(merkle_payment_timestamp + 1)
.unwrap();
assert_ne!(candidates[0].address(), candidates3[0].address());
for candidate in &candidates3 {
assert!(
candidate.branch.verify(),
"Candidate branch with different timestamp should still be valid"
);
}
let tree_root = tree.root();
let expected_address = candidates[0].address();
let mut data = Vec::with_capacity(32 + 32 + 8);
data.extend_from_slice(candidates[0].branch.leaf_hash().as_ref());
data.extend_from_slice(tree_root.as_ref());
data.extend_from_slice(&merkle_payment_timestamp.to_le_bytes());
let manually_computed = XorName::from_content(&data);
assert_eq!(expected_address, manually_computed);
assert!(candidates[0].branch.verify());
assert_eq!(
candidates[0].merkle_payment_timestamp,
merkle_payment_timestamp
);
assert_eq!(candidates[0].address(), candidates[0].address()); assert_eq!(candidates[0].branch.root(), &tree_root);
assert_eq!(
candidates[0].branch.leaf_hash(),
candidates[0].branch.leaf_hash()
);
}
#[test]
fn test_address_proof_generation_and_verification() {
let leaves = make_test_leaves(100);
let tree = MerkleTree::from_xornames(leaves.clone()).unwrap();
let proof = tree.generate_address_proof(0, leaves[0]).unwrap();
assert!(proof.verify());
let proof = tree.generate_address_proof(99, leaves[99]).unwrap();
assert!(proof.verify());
let proof = tree.generate_address_proof(50, leaves[50]).unwrap();
assert!(proof.verify());
}
#[test]
fn test_invalid_address_index() {
let leaves = make_test_leaves(100);
let tree = MerkleTree::from_xornames(leaves.clone()).unwrap();
let dummy_hash = leaves[0]; let result = tree.generate_address_proof(100, dummy_hash);
assert!(matches!(
result,
Err(MerkleTreeError::InvalidLeafIndex { .. })
));
}
#[test]
fn test_midpoint_proof_generation_and_verification() {
let leaves = make_test_leaves(256);
let tree = MerkleTree::from_xornames(leaves).unwrap();
let midpoints = tree.midpoints().unwrap();
let proof = tree.generate_midpoint_proof(0, midpoints[0].hash).unwrap();
assert!(proof.verify());
let proof = tree
.generate_midpoint_proof(15, midpoints[15].hash)
.unwrap();
assert!(proof.verify());
}
#[test]
fn test_proof_depth() {
let leaves = make_test_leaves(16);
let tree = MerkleTree::from_xornames(leaves.clone()).unwrap();
let address_proof = tree.generate_address_proof(0, leaves[0]).unwrap();
assert_eq!(address_proof.depth(), 4);
let midpoints = tree.midpoints().unwrap();
let midpoint_proof = tree.generate_midpoint_proof(0, midpoints[0].hash).unwrap();
assert_eq!(midpoint_proof.depth(), 2);
}
#[test]
fn test_non_deterministic_root_due_to_salts() {
let leaves = make_test_leaves(100);
let tree1 = MerkleTree::from_xornames(leaves.clone()).unwrap();
let tree2 = MerkleTree::from_xornames(leaves).unwrap();
assert_ne!(tree1.root(), tree2.root());
assert_eq!(tree1.depth(), tree2.depth());
assert_eq!(tree1.leaf_count(), tree2.leaf_count());
}
#[test]
fn test_invalid_proof_rejection() {
let leaves = make_test_leaves(10);
let tree = MerkleTree::from_xornames(leaves.clone()).unwrap();
let wrong_leaf = XorName::from_content(b"wrong");
let wrong_proof = tree.generate_address_proof(0, wrong_leaf).unwrap();
assert!(!wrong_proof.verify());
}
#[test]
fn test_proof_hashes_length_for_depth_4() {
let leaves = make_test_leaves(16); let tree = MerkleTree::from_xornames(leaves.clone()).unwrap();
println!("Tree depth: {}", tree.depth());
println!("Tree leaf count: {}", tree.leaf_count());
let address_proof = tree.generate_address_proof(0, leaves[0]).unwrap();
println!(
"Address proof depth (proof_hashes.len()): {}",
address_proof.depth()
);
let midpoints = tree.midpoints().unwrap();
let midpoint_proof = tree.generate_midpoint_proof(0, midpoints[0].hash).unwrap();
println!(
"Midpoint proof depth (proof_hashes.len()): {}",
midpoint_proof.depth()
);
assert_eq!(tree.depth(), 4);
assert_eq!(
address_proof.depth(),
4,
"Address proof should have 4 siblings (levels 0->1->2->3->4)"
);
assert_eq!(
midpoint_proof.depth(),
2,
"Midpoint proof should have 2 siblings (levels 2->3->4)"
);
}
#[test]
fn test_verify_works_correctly() {
let leaves = make_test_leaves(16); let tree = MerkleTree::from_xornames(leaves.clone()).unwrap();
println!("Testing address proof verification...");
let proof_0 = tree.generate_address_proof(0, leaves[0]).unwrap();
println!("Address 0 proof depth: {}", proof_0.depth());
let valid = proof_0.verify();
println!("Address 0 verification: {valid}");
assert!(valid, "Proof for address 0 should be valid");
let proof_15 = tree.generate_address_proof(15, leaves[15]).unwrap();
println!("Address 15 proof depth: {}", proof_15.depth());
let valid = proof_15.verify();
println!("Address 15 verification: {valid}");
assert!(valid, "Proof for address 15 should be valid");
let proof_7 = tree.generate_address_proof(7, leaves[7]).unwrap();
println!("Address 7 proof depth: {}", proof_7.depth());
let valid = proof_7.verify();
println!("Address 7 verification: {valid}");
assert!(valid, "Proof for address 7 should be valid");
println!("\nTesting midpoint proof verification...");
let midpoints = tree.midpoints().unwrap();
println!("Number of midpoints: {}", midpoints.len());
let int_proof_0 = tree.generate_midpoint_proof(0, midpoints[0].hash).unwrap();
println!("Midpoint 0 proof depth: {}", int_proof_0.depth());
let valid = int_proof_0.verify();
println!("Midpoint 0 verification: {valid}");
assert!(valid, "Proof for midpoint 0 should be valid");
let int_proof_3 = tree.generate_midpoint_proof(3, midpoints[3].hash).unwrap();
println!("Midpoint 3 proof depth: {}", int_proof_3.depth());
let valid = int_proof_3.verify();
println!("Midpoint 3 verification: {valid}");
assert!(valid, "Proof for midpoint 3 should be valid");
println!("\nTesting invalid proofs are rejected...");
let wrong_leaf = XorName::from_content(b"wrong_leaf");
let wrong_proof = tree.generate_address_proof(0, wrong_leaf).unwrap();
let valid = wrong_proof.verify();
println!("Wrong leaf verification: {valid}");
assert!(!valid, "Proof with wrong leaf should fail");
let wrong_index_proof = tree.generate_address_proof(0, leaves[1]).unwrap();
let valid = wrong_index_proof.verify();
println!("Wrong leaf index verification: {valid}");
assert!(!valid, "Proof for leaf 0 with hash from leaf 1 should fail");
println!("\nAll verification tests passed!");
}
#[test]
fn test_complete_batch_payment_flow() {
println!("\n=== SIMULATING COMPLETE MERKLE BATCH PAYMENT FLOW ===\n");
println!("PHASE 1: CLIENT PREPARES DATA");
println!("------------------------------");
let real_address_count = 100;
let addresses = make_test_leaves(real_address_count);
println!("✓ Generated {real_address_count} real addresses from self-encryption");
println!("\nPHASE 2: CLIENT BUILDS MERKLE TREE");
println!("----------------------------------");
let tree = MerkleTree::from_xornames(addresses.clone()).unwrap();
let depth = tree.depth();
let root = tree.root();
let leaf_count = tree.leaf_count();
println!("✓ Tree depth: {depth}");
println!("✓ Real addresses: {leaf_count}");
println!("✓ Padded size: {} (2^{})", 1 << depth, depth);
println!("✓ Dummy addresses added: {}", (1 << depth) - leaf_count);
println!("✓ Merkle root: {root:?}");
assert_eq!(depth, 7); assert_eq!(leaf_count, 100);
println!("\nPHASE 3: CLIENT GETS REWARD CANDIDATES");
println!("---------------------------------------");
let merkle_payment_timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Failed to get current time")
.as_secs();
println!("✓ Transaction timestamp: {merkle_payment_timestamp}");
let candidates = tree.reward_candidates(merkle_payment_timestamp).unwrap();
let midpoint_count = expected_reward_pools(depth);
let level = midpoint_level(depth);
let proof_depth = midpoint_proof_depth(depth);
println!("✓ Midpoint level: {level}");
println!("✓ Midpoint proof depth: {proof_depth}");
println!("✓ Number of midpoint nodes (candidate pools): {midpoint_count}");
println!("✓ Tree depth: {depth}");
println!(
"✓ Total nodes queried: {} × {} = {}",
candidates.len(),
depth,
candidates.len() * depth as usize
);
assert_eq!(candidates.len(), midpoint_count);
println!("✓ Generated {} candidate pools", candidates.len());
let first_candidate = &candidates[0];
let midpoint_hash = first_candidate.branch.leaf_hash();
let candidate_root = first_candidate.branch.root();
println!("\n Example candidate #0:");
println!(" Midpoint hash: {midpoint_hash:?}");
println!(" Root: {candidate_root:?}");
println!(
" Timestamp: {}",
first_candidate.merkle_payment_timestamp
);
println!(" Address: {:?}", first_candidate.address());
println!(" (Address = hash(midpoint || root || timestamp))");
println!("\nPHASE 4: SMART CONTRACT RECEIVES PAYMENT");
println!("-----------------------------------------");
let smart_contract_root = root;
let smart_contract_depth = depth;
let smart_contract_timestamp = merkle_payment_timestamp;
println!("✓ Smart contract received payment");
println!("✓ Stored root: {smart_contract_root:?}");
println!("✓ Stored depth: {smart_contract_depth}");
println!("✓ Stored timestamp: {smart_contract_timestamp}");
println!("✓ Stored {} candidate pools", candidates.len());
let winner_pool_midpoint_proof_index = 0; let winner_candidate = &candidates[winner_pool_midpoint_proof_index];
let smart_contract_winner_pool_midpoint_proof_hash = winner_candidate.hash();
println!("✓ Winner pool selected: index {winner_pool_midpoint_proof_index}");
println!("✓ Winner pool hash stored: {smart_contract_winner_pool_midpoint_proof_hash:?}");
println!("✓ Payment distributed to {depth} nodes (depth)");
println!("\nPHASE 5: CLIENT UPLOADS CHUNKS WITH PROOFS");
println!("-------------------------------------------");
let mut address_proofs = Vec::new();
for (i, address_hash) in addresses.iter().enumerate() {
let proof = tree.generate_address_proof(i, *address_hash).unwrap();
address_proofs.push(proof);
}
println!("✓ Generated {} address proofs", address_proofs.len());
println!("✓ Each proof includes:");
println!(" - Merkle proof (siblings from leaf to root)");
println!(" - Salt (for privacy)");
println!(" - Node hash (address being proven)");
println!(" - Root (expected Merkle root)");
println!("\nPHASE 6: NODES VERIFY AND STORE CHUNKS");
println!("---------------------------------------");
let mut verified_count = 0;
for (i, address_proof) in address_proofs.iter().enumerate() {
let address_hash = &addresses[i];
let result = verify_merkle_proof(
address_hash,
address_proof,
winner_candidate,
smart_contract_depth,
&smart_contract_root,
smart_contract_timestamp,
);
assert!(
result.is_ok(),
"Address {} verification failed: {:?}",
i,
result.err()
);
verified_count += 1;
}
println!("✓ All {verified_count} addresses verified using verify_merkle_proof()");
println!("✓ Core Merkle verification includes:");
println!(" 1. Timestamp not in future");
println!(" 2. Payment not expired (< {MERKLE_PAYMENT_EXPIRATION} seconds old)");
println!(" 3. Winner pool timestamp matches smart contract timestamp");
println!(" 4. Address Merkle proof valid (address ∈ tree)");
println!(" 5. Winner Merkle proof valid (midpoint ∈ tree)");
println!(" 6. Address proof depth matches on-chain depth");
println!(" 7. Winner proof depth matches expected for midpoint");
println!(" 8. Address proof root matches on-chain root");
println!(" 9. Winner proof root matches on-chain root");
println!(" Note: Winner pool hash verification happens in MerklePaymentProof::verify()");
println!("\nPHASE 7: VERIFY PROOF STRUCTURE");
println!("--------------------------------");
let first_proof = &address_proofs[0];
let claimed_depth = depth; let expected_address_depth = claimed_depth as usize;
println!("✓ Address proof depth: {}", first_proof.depth());
println!(
"✓ Expected address proof depth (from claimed depth {claimed_depth}): {expected_address_depth}"
);
println!("✓ Number of sibling hashes: {}", first_proof.depth());
println!("✓ Has salt: {}", first_proof.salt.is_some());
assert_eq!(
first_proof.depth(),
expected_address_depth,
"Proof depth should match expected"
);
let winner_branch = &winner_candidate.branch;
let expected_midpoint_depth = midpoint_proof_depth(claimed_depth) as usize;
let level = midpoint_level(claimed_depth);
println!("\n✓ Winner midpoint proof depth: {}", winner_branch.depth());
println!("✓ Expected midpoint proof depth: {expected_midpoint_depth}");
println!("✓ Midpoint level: {level}");
println!("✓ Tree depth: {claimed_depth}");
println!("✓ No salt (midpoints are intermediate hashes)");
assert_eq!(
winner_branch.depth(),
expected_midpoint_depth,
"Midpoint proof depth should match expected"
);
assert!(
winner_branch.salt.is_none(),
"Midpoint proofs should not have salt"
);
println!("\nPHASE 8: VERIFY PRIVACY PROPERTIES");
println!("-----------------------------------");
let salts: Vec<_> = address_proofs.iter().map(|p| p.salt.unwrap()).collect();
let unique_salts: std::collections::HashSet<_> = salts.iter().collect();
assert_eq!(
unique_salts.len(),
salts.len(),
"All salts should be unique"
);
println!("✓ All {} addresses have unique salts", salts.len());
let tree2 = MerkleTree::from_xornames(addresses.clone()).unwrap();
assert_ne!(tree.root(), tree2.root(), "Different salt → different root");
println!("✓ Random salts ensure non-deterministic roots");
println!("✓ Privacy: address content cannot be inferred from tree structure");
println!("\nPHASE 9: COST COMPARISON");
println!("-------------------------");
let old_payments = real_address_count * 3; let new_payments = depth as usize;
println!("Old system (per-address payment):");
println!(
" {real_address_count} addresses × 3 nodes = {old_payments} payment transactions"
);
println!("\nNew system (Merkle batch payment):");
println!(" 1 batch payment → {new_payments} winner nodes");
println!(
" Nodes queried: {} (only query phase, no storage payment)",
candidates.len() * depth as usize
);
let savings_pct = ((old_payments - new_payments) as f64 / old_payments as f64) * 100.0;
println!("\n✓ Gas savings: {savings_pct:.1}% reduction");
println!(
"✓ Network query overhead: {}% of old system",
(candidates.len() * depth as usize * 100) / old_payments
);
println!("\n=== FLOW COMPLETE ===");
println!("✓ {real_address_count} real addresses uploaded");
println!("✓ {} dummy addresses padded", (1 << depth) - leaf_count);
println!("✓ {} candidate pools formed", candidates.len());
println!("✓ 1 winner pool paid ({depth} nodes)");
println!("✓ All addresses verified and stored");
println!("✓ Privacy preserved with random salts");
println!("✓ {savings_pct:.1}% gas cost reduction achieved\n");
}
/// Diagnostic walk over a 100-leaf tree (padded to `2^depth` leaves).
///
/// For every level it prints how many nodes the inner `ant_merkle` tree
/// reports next to the `2^(depth - level)` a fully padded tree would have.
/// It also prints how many nodes the midpoint workaround keeps
/// (`min(actual, 2^midpoint_proof_depth(depth))`) and flags any mismatch.
///
/// This test makes no assertions. It only fails if tree construction (or one
/// of the calls it makes) panics.
#[test]
fn test_get_nodes_at_level_with_padding() {
    println!("\n=== TESTING OUR PADDED TREE STRUCTURE ===\n");
    let leaves = make_test_leaves(100);
    let tree = MerkleTree::from_xornames(leaves).unwrap();
    let depth = tree.depth();
    println!("Tree with 100 leaves:");
    println!(" Depth: {depth}");
    println!(" Original leaves: {}", tree.leaf_count());
    println!(" Padded size: {} (2^{})", 1 << depth, depth);
    for level in 0..=depth {
        // Level 0 holds the leaves; each level up halves the node count.
        let expected_count = 1 << (depth - level);
        // Levels for which the inner tree returns `None` are skipped without output.
        if let Some(nodes) = tree.inner.get_nodes_at_level(level as usize) {
            let actual_count = nodes.len();
            println!("\nLevel {level}:");
            println!(" Expected: {} nodes (2^{})", expected_count, depth - level);
            println!(" Actual: {actual_count} nodes");
            if level as usize == midpoint_level(depth) {
                println!(" >>> MIDPOINT LEVEL <<<");
                println!(
                    " Our workaround takes: {} nodes",
                    std::cmp::min(actual_count, 1 << midpoint_proof_depth(depth))
                );
            }
            if actual_count != expected_count {
                // Previously mojibake ("âš"): UTF-8 "⚠" mis-decoded as Windows-1252.
                println!(" ⚠ Mismatch! This is why we need .take() workaround");
            }
        }
    }
    println!("\n=== END TEST ===\n");
}
/// Checks proof lengths against tree depth.
///
/// An address proof spans the whole tree depth, while a midpoint proof is
/// shorter. The expected midpoint-proof depths below are ceil(depth / 2).
#[test]
fn test_proof_hashes_length_matches_depth() {
    // 16 leaves: an exact power of two, so no padding is needed.
    {
        let leaves = make_test_leaves(16);
        let first_leaf = leaves[0];
        let tree = MerkleTree::from_xornames(leaves).unwrap();
        assert_eq!(tree.depth(), 4);
        let address_proof = tree.generate_address_proof(0, first_leaf).unwrap();
        assert_eq!(address_proof.depth(), 4);
        let first_midpoint = tree.midpoints().unwrap()[0].hash;
        let midpoint_proof = tree.generate_midpoint_proof(0, first_midpoint).unwrap();
        assert_eq!(midpoint_proof.depth(), 2);
    }

    // 256 leaves: a larger exact power of two.
    {
        let leaves = make_test_leaves(256);
        let first_leaf = leaves[0];
        let tree = MerkleTree::from_xornames(leaves).unwrap();
        assert_eq!(tree.depth(), 8);
        let address_proof = tree.generate_address_proof(0, first_leaf).unwrap();
        assert_eq!(address_proof.depth(), 8);
        let first_midpoint = tree.midpoints().unwrap()[0].hash;
        let midpoint_proof = tree.generate_midpoint_proof(0, first_midpoint).unwrap();
        assert_eq!(midpoint_proof.depth(), 4);
    }

    // 100 leaves: padded up to 2^7 = 128, with an odd depth.
    {
        let leaves = make_test_leaves(100);
        let first_leaf = leaves[0];
        let tree = MerkleTree::from_xornames(leaves).unwrap();
        assert_eq!(tree.depth(), 7);
        let address_proof = tree.generate_address_proof(0, first_leaf).unwrap();
        assert_eq!(address_proof.depth(), 7);
        let first_midpoint = tree.midpoints().unwrap()[0].hash;
        let midpoint_proof = tree.generate_midpoint_proof(0, first_midpoint).unwrap();
        assert_eq!(midpoint_proof.depth(), 4);
    }
}
/// Drives `verify_merkle_proof` through each rejection path, in this order:
/// 1. an unrelated root;
/// 2. an off-by-one depth;
/// 3. a winner branch taken from a different tree;
/// 4. a timestamp in the future;
/// 5. an expired payment;
/// 6. a timestamp that differs from the one the winner proof was built for.
#[test]
fn test_verify_merkle_proof_errors() {
    use std::time::{SystemTime, UNIX_EPOCH};

    let leaves = make_test_leaves(16);
    let tree = MerkleTree::from_xornames(leaves.clone()).unwrap();
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let candidates = tree.reward_candidates(now).unwrap();
    let winner = &candidates[0];
    let address_proof = tree.generate_address_proof(0, leaves[0]).unwrap();
    let root = tree.root();
    let depth = tree.depth();

    // Root that does not belong to this tree.
    let bogus_root = XorName::from_content(b"wrong root");
    assert!(matches!(
        verify_merkle_proof(&leaves[0], &address_proof, winner, depth, &bogus_root, now),
        Err(BadMerkleProof::AddressBranchRootMismatch { .. })
    ));

    // Claimed depth one deeper than the real tree.
    assert!(matches!(
        verify_merkle_proof(&leaves[0], &address_proof, winner, depth + 1, &root, now),
        Err(BadMerkleProof::AddressProofDepthMismatch { .. })
    ));

    // Winner branch swapped for one from another tree. Random salts make that
    // tree's root differ even for identically generated leaves.
    let foreign_tree = MerkleTree::from_xornames(make_test_leaves(16)).unwrap();
    let foreign_candidates = foreign_tree.reward_candidates(now).unwrap();
    let mut forged_winner = winner.clone();
    forged_winner.branch = foreign_candidates[0].branch.clone();
    assert!(matches!(
        verify_merkle_proof(&leaves[0], &address_proof, &forged_winner, depth, &root, now),
        Err(BadMerkleProof::WinnerBranchRootMismatch { .. })
    ));

    // Payment timestamp 1000 seconds ahead of the current time.
    let ahead = now + 1000;
    assert!(matches!(
        verify_merkle_proof(&leaves[0], &address_proof, winner, depth, &root, ahead),
        Err(BadMerkleProof::TimestampInFuture { .. })
    ));

    // Candidates built for a timestamp just past the expiration window.
    let expired = now - MERKLE_PAYMENT_EXPIRATION - 1;
    let expired_candidates = tree.reward_candidates(expired).unwrap();
    assert!(matches!(
        verify_merkle_proof(
            &leaves[0],
            &address_proof,
            &expired_candidates[0],
            depth,
            &root,
            expired,
        ),
        Err(BadMerkleProof::PaymentExpired { .. })
    ));

    // Valid-looking timestamp that is not the one the winner proof was built for.
    let shifted = now - 100;
    assert!(matches!(
        verify_merkle_proof(&leaves[0], &address_proof, winner, depth, &root, shifted),
        Err(BadMerkleProof::TimestampMismatch { .. })
    ));
}
/// Asks for a midpoint proof at the index one past the last midpoint.
/// The request must be rejected with `InvalidMidpointIndex`.
#[test]
fn test_invalid_midpoint_index() {
    let tree = MerkleTree::from_xornames(make_test_leaves(16)).unwrap();
    let one_past_end = tree.midpoints().unwrap().len();
    let outcome = tree.generate_midpoint_proof(one_past_end, XorName::from_content(b"test"));
    assert!(matches!(
        outcome,
        Err(MerkleTreeError::InvalidMidpointIndex { .. })
    ));
}
/// Checks how a reward candidate's pool address is derived. The test
/// verifies three properties:
/// - the address changes with the payment timestamp;
/// - it is reproducible for the same tree and timestamp;
/// - it equals `XorName::from_content(leaf_hash || root || timestamp.to_le_bytes())`.
#[test]
fn test_reward_candidate_pool_address() {
    let leaves = make_test_leaves(16);
    let tree = MerkleTree::from_xornames(leaves).unwrap();
    let timestamp1 = 12345u64;
    let timestamp2 = 67890u64;
    let candidates1 = tree.reward_candidates(timestamp1).unwrap();
    let candidates2 = tree.reward_candidates(timestamp2).unwrap();

    // Different timestamps must give different pool addresses.
    assert_ne!(candidates1[0].address(), candidates2[0].address());

    // Re-deriving for the same tree and timestamp must give the same address.
    // (The previous check compared a value with itself, which proved nothing.)
    let candidates1_again = tree.reward_candidates(timestamp1).unwrap();
    assert_eq!(candidates1[0].address(), candidates1_again[0].address());

    // Recompute the address from its parts: leaf hash, root, LE timestamp.
    let addr = candidates1[0].address();
    let mut data = Vec::with_capacity(32 + 32 + 8);
    data.extend_from_slice(candidates1[0].branch.leaf_hash().as_ref());
    data.extend_from_slice(candidates1[0].branch.root().as_ref());
    // Was `×tamp1` (the `&times` HTML entity corrupted `&timestamp1`), which did not compile.
    data.extend_from_slice(&timestamp1.to_le_bytes());
    let expected = XorName::from_content(&data);
    assert_eq!(addr, expected);
}
/// Checks that tree depth equals ceil(log2(leaf_count)).
///
/// Cases cover exact powers of two and the counts just above them (which
/// bump the depth by one), plus a non-power-of-two count in between.
#[test]
fn test_calculate_depth_edge_cases() {
    // (leaf_count, expected_depth)
    let cases = [
        (2, 1),
        (3, 2),
        (4, 2),
        (5, 3),
        (8, 3),
        (9, 4),
        (16, 4),
        (17, 5),
        (100, 7),
        (256, 8),
    ];
    for (leaf_count, expected_depth) in cases {
        let tree = MerkleTree::from_xornames(make_test_leaves(leaf_count)).unwrap();
        assert_eq!(
            tree.depth(),
            expected_depth,
            "Depth mismatch for {leaf_count} leaves"
        );
    }
}
}