use blake3::Hasher;
use saorsa_pqc::api::sig::{
ml_dsa_65, MlDsaPublicKey, MlDsaSecretKey, MlDsaSignature, MlDsaVariant,
};
use serde::{Deserialize, Serialize};
use crate::ant_protocol::XorName;
/// ML-DSA signing context string for commitment signatures
/// (passed to `sign_with_context` / `verify_with_context`).
pub const DOMAIN_COMMITMENT: &[u8] = b"autonomi.ant.replication.storage_commitment.v1";
/// Domain prefix hashed ahead of the postcard-serialized commitment in `commitment_hash`.
pub const DOMAIN_COMMITMENT_HASH: &[u8] = b"autonomi.ant.replication.commitment_hash.v1";
/// Domain prefix for Merkle leaf hashes (`leaf_hash`); differs from `DOMAIN_NODE`
/// so leaf and interior-node hashes are computed under distinct tags.
pub const DOMAIN_LEAF: &[u8] = b"autonomi.ant.replication.storage_leaf.v1";
/// Domain prefix for interior Merkle node hashes (`node_hash`).
pub const DOMAIN_NODE: &[u8] = b"autonomi.ant.replication.storage_node.v1";
/// Largest key set a single commitment may cover; enforced by `MerkleTree::build`
/// and by `verify_path`.
pub const MAX_COMMITMENT_KEY_COUNT: u32 = 1_000_000;
/// A signed commitment to the Merkle root of a key set (presumably the keys the
/// sender claims to store — confirm against callers).
///
/// NOTE: field order defines the serde/postcard encoding and therefore the
/// value of `commitment_hash`; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct StorageCommitment {
    /// Merkle root as produced by `MerkleTree::root`.
    pub root: [u8; 32],
    /// Number of leaves in the tree; `verify_path` uses it to bound the leaf
    /// index and to fix the expected proof length.
    pub key_count: u32,
    /// Sender identifier, covered by the signature.
    /// NOTE(review): this module never checks it against `sender_public_key`;
    /// presumably callers do — confirm.
    pub sender_peer_id: [u8; 32],
    /// ML-DSA-65 public key bytes; parsed by `verify_commitment_signature` and
    /// also included in the signed payload.
    pub sender_public_key: Vec<u8>,
    /// ML-DSA-65 signature over `root`, `key_count`, `sender_peer_id` and
    /// `sender_public_key` (see `commitment_signed_payload`), made under the
    /// `DOMAIN_COMMITMENT` context.
    pub signature: Vec<u8>,
}
/// Computes a Merkle leaf hash: BLAKE3 over `DOMAIN_LEAF || key || bytes_hash`.
///
/// The leaf tag differs from the one used by [`node_hash`], so a leaf can
/// never be hashed under the same prefix as an interior node.
#[must_use]
pub fn leaf_hash(key: &XorName, bytes_hash: &[u8; 32]) -> [u8; 32] {
    let digest = Hasher::new()
        .update(DOMAIN_LEAF)
        .update(key)
        .update(bytes_hash)
        .finalize();
    *digest.as_bytes()
}
/// Computes an interior Merkle node hash: BLAKE3 over `DOMAIN_NODE || left || right`.
///
/// Child order matters: `node_hash(a, b)` and `node_hash(b, a)` hash different inputs.
#[must_use]
pub fn node_hash(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
    let digest = Hasher::new()
        .update(DOMAIN_NODE)
        .update(left)
        .update(right)
        .finalize();
    *digest.as_bytes()
}
/// Hashes a whole commitment, including its signature and public key.
///
/// The commitment is postcard-serialized, and BLAKE3 is computed over
/// `DOMAIN_COMMITMENT_HASH || serialized`. Returns `None` only if postcard
/// serialization fails.
#[must_use]
pub fn commitment_hash(c: &StorageCommitment) -> Option<[u8; 32]> {
    postcard::to_allocvec(c).ok().map(|encoded| {
        let digest = Hasher::new()
            .update(DOMAIN_COMMITMENT_HASH)
            .update(&encoded)
            .finalize();
        *digest.as_bytes()
    })
}
/// Builds the exact byte string that commitment signatures cover.
///
/// Layout:
/// `root (32) || key_count (u32 LE) || sender_peer_id (32) || pk_len (u32 LE) || sender_public_key`.
/// The public key is length-prefixed, so its variable size cannot make the
/// encoding ambiguous. A key longer than `u32::MAX` bytes would have its
/// prefix saturate at `u32::MAX`.
fn commitment_signed_payload(
    root: &[u8; 32],
    key_count: u32,
    sender_peer_id: &[u8; 32],
    sender_public_key: &[u8],
) -> Vec<u8> {
    let count_le = key_count.to_le_bytes();
    let pk_len_le = u32::try_from(sender_public_key.len())
        .unwrap_or(u32::MAX)
        .to_le_bytes();
    let parts: [&[u8]; 5] = [root, &count_le, sender_peer_id, &pk_len_le, sender_public_key];
    parts.concat()
}
/// Binary Merkle tree over `(key, bytes_hash)` entries, ordered by key.
///
/// Leaves are [`leaf_hash`]es in ascending key order. Each parent is the
/// [`node_hash`] of its two children, and an unpaired last node at an
/// odd-sized level is paired with itself (see [`build_next_level`]).
pub struct MerkleTree {
    /// `(key, leaf_hash)` pairs in ascending key order; keys are unique.
    leaves: Vec<(XorName, [u8; 32])>,
    /// All levels: `levels[0]` are the leaf hashes, `levels.last()` holds the
    /// single root hash.
    levels: Vec<Vec<[u8; 32]>>,
}
impl MerkleTree {
    /// Builds a tree over `entries`, each given as `(key, bytes_hash)`.
    ///
    /// Input order is irrelevant: entries are sorted by key first, so the same
    /// set of entries always produces the same root.
    ///
    /// # Errors
    ///
    /// - [`CommitmentError::EmptyKeySet`] if `entries` is empty.
    /// - [`CommitmentError::TooManyKeys`] if there are more than
    ///   [`MAX_COMMITMENT_KEY_COUNT`] entries.
    /// - [`CommitmentError::DuplicateKey`] if any key occurs more than once.
    pub fn build(mut entries: Vec<(XorName, [u8; 32])>) -> Result<Self, CommitmentError> {
        if entries.is_empty() {
            return Err(CommitmentError::EmptyKeySet);
        }
        if entries.len() > MAX_COMMITMENT_KEY_COUNT as usize {
            return Err(CommitmentError::TooManyKeys(entries.len()));
        }
        // Keys must be unique, so sort stability is irrelevant; the unstable
        // sort avoids the stable sort's scratch allocation.
        entries.sort_unstable_by_key(|(key, _)| *key);
        let duplicate = entries.windows(2).find_map(|w| match w {
            [a, b] if a.0 == b.0 => Some(a.0),
            _ => None,
        });
        if let Some(key) = duplicate {
            return Err(CommitmentError::DuplicateKey(key));
        }
        let leaves: Vec<(XorName, [u8; 32])> = entries
            .into_iter()
            .map(|(key, bytes_hash)| (key, leaf_hash(&key, &bytes_hash)))
            .collect();
        // Derive each level from the last one already stored, instead of
        // cloning every level before pushing it.
        let mut levels: Vec<Vec<[u8; 32]>> = vec![leaves.iter().map(|(_, h)| *h).collect()];
        while let Some(top) = levels.last() {
            if top.len() <= 1 {
                break;
            }
            let next = build_next_level(top);
            levels.push(next);
        }
        Ok(Self { leaves, levels })
    }
    /// Returns the Merkle root.
    ///
    /// The `[0; 32]` fallback cannot occur for trees made by [`MerkleTree::build`],
    /// which rejects empty input.
    #[must_use]
    pub fn root(&self) -> [u8; 32] {
        self.levels
            .last()
            .and_then(|l| l.first())
            .copied()
            .unwrap_or([0u8; 32])
    }
    /// Number of leaves. It saturates at `u32::MAX`, although `build` already
    /// caps it at [`MAX_COMMITMENT_KEY_COUNT`].
    #[must_use]
    pub fn key_count(&self) -> u32 {
        u32::try_from(self.leaves.len()).unwrap_or(u32::MAX)
    }
    /// Returns the inclusion proof for `key`, or `None` if `key` is not in the tree.
    ///
    /// The proof has one sibling per level below the root, ordered leaf-upwards,
    /// and is checked with [`verify_path`] at the key's [`MerkleTree::key_index`].
    #[must_use]
    pub fn path_for(&self, key: &XorName) -> Option<Vec<[u8; 32]>> {
        let mut idx = self.key_index(key)?;
        let below_root = &self.levels[..self.levels.len().saturating_sub(1)];
        let mut path = Vec::with_capacity(below_root.len());
        for level in below_root {
            // The partner is `idx ^ 1`, except that an unpaired last node is
            // its own sibling, mirroring `build_next_level`.
            let sibling_idx = if idx % 2 == 1 || idx + 1 < level.len() {
                idx ^ 1
            } else {
                idx
            };
            path.push(level[sibling_idx]);
            idx /= 2;
        }
        Some(path)
    }
    /// Iterates over `(key, leaf_hash)` pairs in ascending key order (test-only).
    #[cfg(test)]
    pub(crate) fn iter_leaves(&self) -> impl Iterator<Item = &(XorName, [u8; 32])> {
        self.leaves.iter()
    }
    /// All keys in ascending order, which is also leaf order.
    #[must_use]
    pub fn sorted_keys(&self) -> Vec<XorName> {
        self.leaves.iter().map(|(k, _)| *k).collect()
    }
    /// Key of the leaf at `idx`, or `None` if `idx` is out of range.
    #[must_use]
    pub fn key_at(&self, idx: usize) -> Option<XorName> {
        self.leaves.get(idx).map(|(k, _)| *k)
    }
    /// Leaf index of `key` (binary search over the sorted leaves), or `None` if absent.
    #[must_use]
    pub fn key_index(&self, key: &XorName) -> Option<usize> {
        self.leaves.binary_search_by(|(k, _)| k.cmp(key)).ok()
    }
    /// Whether `key` is one of the committed keys.
    #[must_use]
    pub fn contains_key(&self, key: &XorName) -> bool {
        self.key_index(key).is_some()
    }
    /// Hash at position `index` of tree `level` (0 = leaves).
    ///
    /// Returns `None` if either is out of range or `index` does not fit in `usize`.
    #[must_use]
    pub fn node_at(&self, level: usize, index: u64) -> Option<[u8; 32]> {
        let index = usize::try_from(index).ok()?;
        self.levels.get(level).and_then(|l| l.get(index)).copied()
    }
    /// Number of stored levels, counting both the leaf level and the root level.
    #[must_use]
    pub fn levels_count(&self) -> usize {
        self.levels.len()
    }
}
/// Hashes adjacent pairs of `cur` into the parent level.
///
/// Pairs are `(cur[0], cur[1])`, `(cur[2], cur[3])`, and so on. When `cur` has
/// odd length, the final element is paired with itself. The output has
/// `cur.len().div_ceil(2)` entries, and empty input gives empty output.
pub(crate) fn build_next_level(cur: &[[u8; 32]]) -> Vec<[u8; 32]> {
    cur.chunks(2)
        .map(|pair| {
            let left = &pair[0];
            // A trailing unpaired node is duplicated rather than promoted as-is.
            let right = pair.get(1).unwrap_or(left);
            node_hash(left, right)
        })
        .collect()
}
/// Checks a Merkle inclusion proof against `expected_root`.
///
/// `leaf` is the [`leaf_hash`] of the entry, and `path` lists its siblings from
/// the leaf level upwards. `leaf_index` is the entry's position in key order,
/// and `key_count` is the committed number of leaves.
///
/// Returns `false` without hashing when:
/// - `key_count` is 0 or above [`MAX_COMMITMENT_KEY_COUNT`],
/// - `leaf_index` is out of range, or
/// - `path` does not have exactly `ceil(log2(key_count))` entries.
#[must_use]
pub fn verify_path(
    leaf: &[u8; 32],
    path: &[[u8; 32]],
    leaf_index: usize,
    key_count: u32,
    expected_root: &[u8; 32],
) -> bool {
    let count_ok = (1..=MAX_COMMITMENT_KEY_COUNT).contains(&key_count);
    let index_ok = u32::try_from(leaf_index).is_ok_and(|idx| idx < key_count);
    if !(count_ok && index_ok) {
        return false;
    }
    // Tree height is log2 of key_count rounded up to a power of two.
    let expected_len = match key_count.checked_next_power_of_two() {
        Some(pow) => pow.trailing_zeros() as usize,
        None => return false,
    };
    if path.len() != expected_len {
        return false;
    }
    // Walk upwards: the position's parity says which side the running hash is on.
    let (computed_root, _) = path.iter().fold((*leaf, leaf_index), |(acc, pos), sibling| {
        let parent = if pos % 2 == 1 {
            node_hash(sibling, &acc)
        } else {
            node_hash(&acc, sibling)
        };
        (parent, pos / 2)
    });
    computed_root == *expected_root
}
/// Signs a commitment with ML-DSA-65 under the [`DOMAIN_COMMITMENT`] context.
///
/// The signed message is `root`, `key_count`, `sender_peer_id` and
/// `sender_public_key`, encoded by `commitment_signed_payload`. Returns the raw
/// signature bytes.
///
/// # Errors
///
/// Returns [`CommitmentError::SignatureFailed`] carrying the underlying error
/// text if signing fails.
pub fn sign_commitment(
    secret_key: &MlDsaSecretKey,
    root: &[u8; 32],
    key_count: u32,
    sender_peer_id: &[u8; 32],
    sender_public_key: &[u8],
) -> Result<Vec<u8>, CommitmentError> {
    let message = commitment_signed_payload(root, key_count, sender_peer_id, sender_public_key);
    let signer = ml_dsa_65();
    signer
        .sign_with_context(secret_key, &message, DOMAIN_COMMITMENT)
        .map(|signature| signature.to_bytes())
        .map_err(|err| CommitmentError::SignatureFailed(err.to_string()))
}
/// Verifies `c.signature` against the public key embedded in the commitment.
///
/// Returns `false` if `c.sender_public_key` does not parse as an ML-DSA-65 key.
/// Otherwise the result is that of [`verify_commitment_signature_with_key`].
#[must_use]
pub fn verify_commitment_signature(c: &StorageCommitment) -> bool {
    MlDsaPublicKey::from_bytes(MlDsaVariant::MlDsa65, &c.sender_public_key)
        .is_ok_and(|embedded_key| verify_commitment_signature_with_key(c, &embedded_key))
}
/// Verifies `c.signature` with an explicitly supplied ML-DSA-65 `public_key`.
///
/// The message checked is the same payload `sign_commitment` signs, including
/// the `c.sender_public_key` bytes. Returns `false` if the signature bytes fail
/// to parse or if verification errors or fails.
///
/// NOTE(review): this does not check that `public_key` matches
/// `c.sender_public_key`; presumably callers rely on that — confirm.
#[must_use]
pub fn verify_commitment_signature_with_key(
    c: &StorageCommitment,
    public_key: &MlDsaPublicKey,
) -> bool {
    let Ok(parsed_sig) = MlDsaSignature::from_bytes(MlDsaVariant::MlDsa65, &c.signature) else {
        return false;
    };
    let message = commitment_signed_payload(
        &c.root,
        c.key_count,
        &c.sender_peer_id,
        &c.sender_public_key,
    );
    let verifier = ml_dsa_65();
    verifier
        .verify_with_context(public_key, &message, &parsed_sig, DOMAIN_COMMITMENT)
        .unwrap_or(false)
}
/// Errors from building a commitment tree or signing a commitment.
#[derive(Debug, Clone, thiserror::Error)]
pub enum CommitmentError {
    /// `MerkleTree::build` was given no entries.
    #[error("cannot build commitment over empty key set")]
    EmptyKeySet,
    /// `MerkleTree::build` was given more than `MAX_COMMITMENT_KEY_COUNT`
    /// entries; carries the number supplied.
    #[error("commitment key count {0} exceeds MAX_COMMITMENT_KEY_COUNT")]
    TooManyKeys(usize),
    /// A key appeared more than once in the input; carries that key.
    #[error("duplicate key in commitment: {}", hex::encode(.0))]
    DuplicateKey(XorName),
    /// ML-DSA signing failed; carries the underlying error's message.
    #[error("commitment signing failed: {0}")]
    SignatureFailed(String),
}
#[cfg(test)]
// Tests fail by panicking, so the crate's unwrap/expect/panic lints are relaxed here.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    fn xn(byte: u8) -> XorName {
        [byte; 32]
    }
    fn bh(byte: u8) -> [u8; 32] {
        [byte ^ 0x5A; 32]
    }
    #[test]
    fn empty_key_set_rejected() {
        let result = MerkleTree::build(vec![]);
        assert!(matches!(result, Err(CommitmentError::EmptyKeySet)));
    }
    #[test]
    fn duplicate_keys_rejected() {
        let result = MerkleTree::build(vec![(xn(1), bh(1)), (xn(1), bh(2))]);
        assert!(matches!(result, Err(CommitmentError::DuplicateKey(_))));
    }
    #[test]
    fn single_leaf_tree_root_is_leaf_hash() {
        let key = xn(1);
        let bytes_hash = bh(1);
        let tree = MerkleTree::build(vec![(key, bytes_hash)]).unwrap();
        assert_eq!(tree.root(), leaf_hash(&key, &bytes_hash));
        assert_eq!(tree.key_count(), 1);
        assert_eq!(tree.path_for(&key), Some(vec![]));
        assert!(verify_path(
            &leaf_hash(&key, &bytes_hash),
            &[],
            0,
            1,
            &tree.root()
        ));
    }
    #[test]
    fn two_leaf_tree_root_combines_both_leaves() {
        let entries = vec![(xn(1), bh(1)), (xn(2), bh(2))];
        let tree = MerkleTree::build(entries).unwrap();
        let l1 = leaf_hash(&xn(1), &bh(1));
        let l2 = leaf_hash(&xn(2), &bh(2));
        assert_eq!(tree.root(), node_hash(&l1, &l2));
    }
    #[test]
    fn root_is_deterministic_regardless_of_input_order() {
        let a = vec![(xn(3), bh(3)), (xn(1), bh(1)), (xn(2), bh(2))];
        let b = vec![(xn(2), bh(2)), (xn(3), bh(3)), (xn(1), bh(1))];
        let tree_a = MerkleTree::build(a).unwrap();
        let tree_b = MerkleTree::build(b).unwrap();
        assert_eq!(tree_a.root(), tree_b.root());
        // Both trees must also expose the same canonical (ascending) key order.
        assert_eq!(tree_a.sorted_keys(), vec![xn(1), xn(2), xn(3)]);
        assert_eq!(tree_a.sorted_keys(), tree_b.sorted_keys());
    }
    fn xn_u32(i: u32) -> XorName {
        let mut k = [0u8; 32];
        k[..4].copy_from_slice(&i.to_le_bytes());
        k
    }
    fn bh_u32(i: u32) -> [u8; 32] {
        let mut h = [0u8; 32];
        h[..4].copy_from_slice(&i.to_le_bytes());
        h[4] = 0x5A;
        h
    }
    #[test]
    fn paths_verify_for_every_key_at_various_sizes() {
        for n in [1u32, 2, 3, 4, 5, 7, 8, 16, 17, 100, 333] {
            let entries: Vec<_> = (0..n).map(|i| (xn_u32(i), bh_u32(i))).collect();
            let tree = MerkleTree::build(entries.clone()).unwrap();
            let root = tree.root();
            let key_count = tree.key_count();
            for (idx, (k, _)) in tree.iter_leaves().enumerate() {
                let path = tree.path_for(k).expect("path for present key");
                let bytes_hash = entries.iter().find(|(kk, _)| kk == k).unwrap().1;
                let lh = leaf_hash(k, &bytes_hash);
                assert!(
                    verify_path(&lh, &path, idx, key_count, &root),
                    "path verify failed at n={n} idx={idx}",
                );
            }
        }
    }
    #[test]
    fn levels_and_accessors_are_consistent() {
        let entries: Vec<_> = (0..5u32).map(|i| (xn_u32(i), bh_u32(i))).collect();
        let tree = MerkleTree::build(entries).unwrap();
        // Level sizes for 5 leaves: 5 -> 3 -> 2 -> 1.
        assert_eq!(tree.levels_count(), 4);
        assert_eq!(tree.node_at(3, 0), Some(tree.root()));
        assert_eq!(tree.node_at(3, 1), None);
        assert_eq!(tree.node_at(4, 0), None);
        for i in 0..5u32 {
            assert_eq!(
                tree.node_at(0, u64::from(i)),
                Some(leaf_hash(&xn_u32(i), &bh_u32(i)))
            );
        }
        // The unpaired fifth leaf is hashed with itself.
        let l4 = leaf_hash(&xn_u32(4), &bh_u32(4));
        assert_eq!(tree.node_at(1, 2), Some(node_hash(&l4, &l4)));
        let keys = tree.sorted_keys();
        assert_eq!(keys.len(), 5);
        for (idx, key) in keys.iter().enumerate() {
            assert_eq!(tree.key_index(key), Some(idx));
            assert_eq!(tree.key_at(idx), Some(*key));
            assert!(tree.contains_key(key));
        }
        assert_eq!(tree.key_at(5), None);
        assert!(!tree.contains_key(&xn(0xEE)));
    }
    #[test]
    fn path_for_absent_key_is_none() {
        let tree = MerkleTree::build(vec![(xn(1), bh(1)), (xn(2), bh(2))]).unwrap();
        assert!(tree.path_for(&xn(99)).is_none());
    }
    #[test]
    fn tampered_bytes_hash_breaks_path_verify() {
        let entries: Vec<_> = (1..=8u8).map(|i| (xn(i), bh(i))).collect();
        let tree = MerkleTree::build(entries.clone()).unwrap();
        let root = tree.root();
        let (k, _) = &entries[3];
        let path = tree.path_for(k).unwrap();
        let wrong_bytes_hash = [0xFFu8; 32];
        let lh = leaf_hash(k, &wrong_bytes_hash);
        assert!(!verify_path(&lh, &path, 3, 8, &root));
    }
    #[test]
    fn tampered_path_node_breaks_verify() {
        let entries: Vec<_> = (1..=8u8).map(|i| (xn(i), bh(i))).collect();
        let tree = MerkleTree::build(entries.clone()).unwrap();
        let root = tree.root();
        let (k, _) = &entries[3];
        let mut path = tree.path_for(k).unwrap();
        path[0][0] ^= 0x01;
        let lh = leaf_hash(k, &bh(4));
        assert!(!verify_path(&lh, &path, 3, 8, &root));
    }
    #[test]
    fn wrong_leaf_index_breaks_verify() {
        let entries: Vec<_> = (1..=8u8).map(|i| (xn(i), bh(i))).collect();
        let tree = MerkleTree::build(entries.clone()).unwrap();
        let root = tree.root();
        let (k, _) = &entries[3];
        let path = tree.path_for(k).unwrap();
        let lh = leaf_hash(k, &bh(4));
        assert!(!verify_path(&lh, &path, 2, 8, &root));
        assert!(verify_path(&lh, &path, 3, 8, &root));
    }
    #[test]
    fn out_of_range_leaf_index_rejected() {
        let entries: Vec<_> = (1..=8u8).map(|i| (xn(i), bh(i))).collect();
        let tree = MerkleTree::build(entries.clone()).unwrap();
        let root = tree.root();
        let (k, _) = &entries[3];
        let path = tree.path_for(k).unwrap();
        let lh = leaf_hash(k, &bh(4));
        assert!(!verify_path(&lh, &path, 8, 8, &root));
        assert!(!verify_path(&lh, &path, 99, 8, &root));
        assert!(verify_path(&lh, &path, 3, 8, &root));
    }
    #[test]
    fn wrong_path_length_rejected_pre_hashing() {
        let entries: Vec<_> = (1..=8u8).map(|i| (xn(i), bh(i))).collect();
        let tree = MerkleTree::build(entries.clone()).unwrap();
        let root = tree.root();
        let (k, _) = &entries[3];
        let path = tree.path_for(k).unwrap();
        let lh = leaf_hash(k, &bh(4));
        assert_eq!(path.len(), 3);
        let short: Vec<_> = path.iter().take(2).copied().collect();
        assert!(!verify_path(&lh, &short, 3, 8, &root));
        let mut long = path;
        long.push([0; 32]);
        assert!(!verify_path(&lh, &long, 3, 8, &root));
    }
    #[test]
    fn zero_key_count_rejected() {
        let lh = [0u8; 32];
        assert!(!verify_path(&lh, &[], 0, 0, &[0u8; 32]));
    }
    #[test]
    fn out_of_protocol_key_count_rejected() {
        let lh = [0u8; 32];
        assert!(!verify_path(
            &lh,
            &[],
            0,
            MAX_COMMITMENT_KEY_COUNT + 1,
            &[0u8; 32]
        ));
        assert!(!verify_path(&lh, &[], 0, u32::MAX, &[0u8; 32]));
    }
    fn pk_bytes(pk: &MlDsaPublicKey) -> Vec<u8> {
        pk.to_bytes()
    }
    #[test]
    fn sign_and_verify_roundtrip() {
        let dsa = ml_dsa_65();
        let (pk, sk) = dsa.generate_keypair().unwrap();
        let entries: Vec<_> = (0..5u8).map(|i| (xn(i), bh(i))).collect();
        let tree = MerkleTree::build(entries).unwrap();
        let root = tree.root();
        let key_count = tree.key_count();
        let peer_id = [0xAB; 32];
        let pk_b = pk_bytes(&pk);
        let signature = sign_commitment(&sk, &root, key_count, &peer_id, &pk_b).unwrap();
        let c = StorageCommitment {
            root,
            key_count,
            sender_peer_id: peer_id,
            sender_public_key: pk_b,
            signature,
        };
        assert!(verify_commitment_signature(&c));
    }
    #[test]
    fn signature_fails_when_root_tampered() {
        let dsa = ml_dsa_65();
        let (pk, sk) = dsa.generate_keypair().unwrap();
        let root = [0u8; 32];
        let pk_b = pk_bytes(&pk);
        let signature = sign_commitment(&sk, &root, 1, &[0; 32], &pk_b).unwrap();
        let c = StorageCommitment {
            // Differs from the signed root.
            root: [1u8; 32],
            key_count: 1,
            sender_peer_id: [0; 32],
            sender_public_key: pk_b,
            signature,
        };
        assert!(!verify_commitment_signature(&c));
    }
    #[test]
    fn signature_fails_under_swapped_public_key() {
        let dsa = ml_dsa_65();
        let (pk1, sk1) = dsa.generate_keypair().unwrap();
        let (pk2, _sk2) = dsa.generate_keypair().unwrap();
        let pk1_b = pk_bytes(&pk1);
        let pk2_b = pk_bytes(&pk2);
        let signature = sign_commitment(&sk1, &[0u8; 32], 1, &[0; 32], &pk1_b).unwrap();
        let c = StorageCommitment {
            root: [0u8; 32],
            key_count: 1,
            sender_peer_id: [0; 32],
            sender_public_key: pk2_b,
            signature,
        };
        assert!(!verify_commitment_signature(&c));
    }
    #[test]
    fn signature_fails_with_garbage_bytes() {
        let dsa = ml_dsa_65();
        let (pk, _sk) = dsa.generate_keypair().unwrap();
        let c = StorageCommitment {
            root: [0u8; 32],
            key_count: 1,
            sender_peer_id: [0; 32],
            sender_public_key: pk_bytes(&pk),
            // Not a parseable signature.
            signature: vec![0u8; 100],
        };
        assert!(!verify_commitment_signature(&c));
    }
    #[test]
    fn signature_fails_with_garbage_public_key() {
        let c = StorageCommitment {
            root: [0u8; 32],
            key_count: 1,
            sender_peer_id: [0; 32],
            // Not a parseable public key.
            sender_public_key: vec![0u8; 100],
            signature: vec![0u8; 3293],
        };
        assert!(!verify_commitment_signature(&c));
    }
    #[test]
    fn commitment_hash_differs_on_any_field_change() {
        let dsa = ml_dsa_65();
        let (pk, sk) = dsa.generate_keypair().unwrap();
        let pk_b = pk_bytes(&pk);
        let sig = sign_commitment(&sk, &[0; 32], 1, &[0; 32], &pk_b).unwrap();
        let c1 = StorageCommitment {
            root: [0; 32],
            key_count: 1,
            sender_peer_id: [0; 32],
            sender_public_key: pk_b,
            signature: sig,
        };
        let h1 = commitment_hash(&c1).unwrap();
        let mut c2 = c1.clone();
        c2.root = [1; 32];
        assert_ne!(h1, commitment_hash(&c2).unwrap());
        let mut c3 = c1.clone();
        c3.key_count = 2;
        assert_ne!(h1, commitment_hash(&c3).unwrap());
        let mut c4 = c1.clone();
        c4.sender_peer_id = [1; 32];
        assert_ne!(h1, commitment_hash(&c4).unwrap());
        let mut c5 = c1.clone();
        c5.signature[0] ^= 1;
        assert_ne!(h1, commitment_hash(&c5).unwrap());
        let (pk_other, _) = dsa.generate_keypair().unwrap();
        let mut c6 = c1;
        c6.sender_public_key = pk_bytes(&pk_other);
        assert_ne!(h1, commitment_hash(&c6).unwrap());
    }
    #[test]
    fn commitment_hash_stable_for_identical_input() {
        let dsa = ml_dsa_65();
        let (pk, sk) = dsa.generate_keypair().unwrap();
        let pk_b = pk_bytes(&pk);
        let sig = sign_commitment(&sk, &[7; 32], 42, &[3; 32], &pk_b).unwrap();
        let c = StorageCommitment {
            root: [7; 32],
            key_count: 42,
            sender_peer_id: [3; 32],
            sender_public_key: pk_b,
            signature: sig,
        };
        assert_eq!(commitment_hash(&c), commitment_hash(&c));
    }
    #[test]
    fn commitment_hash_signature_length_change_changes_hash() {
        let c1 = StorageCommitment {
            root: [0; 32],
            key_count: 1,
            sender_peer_id: [0; 32],
            sender_public_key: vec![0u8; 1952],
            signature: vec![0xAB],
        };
        let c2 = StorageCommitment {
            root: [0; 32],
            key_count: 1,
            sender_peer_id: [0; 32],
            sender_public_key: vec![0u8; 1952],
            signature: vec![0xAB, 0x00],
        };
        assert_ne!(commitment_hash(&c1).unwrap(), commitment_hash(&c2).unwrap());
    }
    #[test]
    fn too_many_keys_rejected() {
        let mut entries = Vec::with_capacity(MAX_COMMITMENT_KEY_COUNT as usize + 1);
        for i in 0..=MAX_COMMITMENT_KEY_COUNT {
            let mut k = [0u8; 32];
            k[..4].copy_from_slice(&i.to_le_bytes());
            entries.push((k, [0; 32]));
        }
        let result = MerkleTree::build(entries);
        assert!(matches!(result, Err(CommitmentError::TooManyKeys(_))));
    }
}