use subtle::ConstantTimeEq;
use thiserror::Error;
use crate::cbor::{decode_canonical_cbor, encode_canonical_cbor, CborValue};
use crate::hash::sha256;
/// Identifier of the Merkle tree construction implemented here (RFC 9162
/// hashing with SHA-256). Written as `tree_alg` by `encode_leaves_list`;
/// `decode_leaves_list` rejects any other `tree_alg` value.
pub const MERKLE_ALG_ID: &str = "rfc9162-sha256";
/// Value of the `format` key written by `encode_leaves_list`, and the only
/// format `decode_leaves_list` accepts.
pub const LEAVES_LIST_FORMAT_V1: &str = "cardano-poe-merkle-leaves-v1";
/// Length in bytes of every leaf value, node hash, and root handled by this module.
const DIGEST_LENGTH: usize = 32;
/// Domain-separation byte prepended to leaf data before hashing.
const LEAF_PREFIX: u8 = 0x00;
/// Domain-separation byte prepended to a (left || right) child pair before hashing.
const NODE_PREFIX: u8 = 0x01;
/// Errors returned by the tree-building functions [`merkle_root`] and
/// [`merkle_inclusion_proof`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum MerkleError {
    /// The leaf slice was empty; this module requires at least one leaf.
    #[error("empty Merkle tree forbidden (n >= 1)")]
    EmptyTree,
    /// The requested leaf index is not smaller than the number of leaves.
    #[error("index {index} out of range for tree_size {tree_size}")]
    IndexOutOfRange {
        /// The leaf index that was requested.
        index: usize,
        /// The number of leaves in the tree.
        tree_size: usize,
    },
}
/// Computes the RFC 9162 Merkle tree hash (root) over `leaves`.
///
/// Every element is leaf-hashed with the `0x00` prefix, and subtrees are
/// combined with the `0x01` prefix, splitting each range at the largest power
/// of two strictly below its size.
///
/// # Errors
///
/// Returns [`MerkleError::EmptyTree`] when `leaves` is empty.
pub fn merkle_root(leaves: &[[u8; DIGEST_LENGTH]]) -> Result<[u8; DIGEST_LENGTH], MerkleError> {
    match leaves {
        [] => Err(MerkleError::EmptyTree),
        non_empty => Ok(root_unchecked(non_empty)),
    }
}
/// Builds the RFC 9162 audit path proving that `leaves[index]` is in the tree.
///
/// The returned hashes are ordered from the sibling nearest the leaf up to
/// the sibling of the root's child, which is the order [`verify_inclusion`]
/// consumes them in.
///
/// # Errors
///
/// * [`MerkleError::EmptyTree`] if `leaves` is empty.
/// * [`MerkleError::IndexOutOfRange`] if `index >= leaves.len()`.
pub fn merkle_inclusion_proof(
    leaves: &[[u8; DIGEST_LENGTH]],
    index: usize,
) -> Result<Vec<[u8; DIGEST_LENGTH]>, MerkleError> {
    let tree_size = leaves.len();
    if tree_size == 0 {
        Err(MerkleError::EmptyTree)
    } else if index < tree_size {
        Ok(audit_path(leaves, index))
    } else {
        Err(MerkleError::IndexOutOfRange { index, tree_size })
    }
}
/// Verifies an RFC 9162 inclusion proof (section 2.1.3.2) for a single leaf.
///
/// * `leaf` – the raw leaf value; it is leaf-hashed here exactly as
///   [`merkle_root`] hashes each element.
/// * `index` – the leaf's position in the tree.
/// * `tree_size` – the number of leaves in the tree.
/// * `proof` – the audit path, nearest sibling first, as produced by
///   [`merkle_inclusion_proof`].
/// * `root` – the expected tree hash.
///
/// Returns `false` when `leaf` or `root` is not `DIGEST_LENGTH` bytes, when
/// `tree_size` is 0 or `index >= tree_size`, when `proof` is too short or too
/// long for `index`/`tree_size`, or when the recomputed root differs from
/// `root`. The final root comparison uses the constant-time `ct_eq`.
///
/// Caveat: `tree_size` only steers the left/right ordering and the expected
/// proof length; it is not otherwise bound to `root`. Callers must already
/// know that `root` is the hash of a tree with exactly `tree_size` leaves.
#[must_use]
pub fn verify_inclusion(
    leaf: &[u8],
    index: usize,
    tree_size: usize,
    proof: &[[u8; DIGEST_LENGTH]],
    root: &[u8],
) -> bool {
    if leaf.len() != DIGEST_LENGTH || root.len() != DIGEST_LENGTH {
        return false;
    }
    if tree_size < 1 || index >= tree_size {
        return false;
    }
    // Single-leaf tree: the root is the leaf hash and the path must be empty.
    // The general loop below would reach the same verdict; this is an explicit shortcut.
    if tree_size == 1 {
        if !proof.is_empty() || index != 0 {
            return false;
        }
        return ct_eq(&hash_leaf(leaf), root);
    }
    // RFC 9162 step 2: `m` is the RFC's `fn` (index of the current node at the
    // current level) and `last` is `sn` (index of the last node at that level).
    let mut m = index;
    let mut last = tree_size - 1;
    let mut h = hash_leaf(leaf);
    for sibling in proof {
        // Already at the root level but the proof has more entries: too long.
        if last == 0 {
            return false;
        }
        // Either the node is a right child (odd index), or it is the even, last
        // node at this level and so has no right sibling. In both cases this
        // proof entry is a left sibling.
        if (m & 1) == 1 || m == last {
            h = hash_node(sibling, &h);
            // Step 4.b.ii: skip the levels where the node was carried up
            // without a sibling, until it becomes a right child or reaches index 0.
            while (m & 1) == 0 && m != 0 {
                m >>= 1;
                last >>= 1;
            }
        } else {
            // Left child with a right sibling present.
            h = hash_node(&h, sibling);
        }
        // Step 4.c: move up one level.
        m >>= 1;
        last >>= 1;
    }
    // Step 5: a proof that is too short leaves us below the root level.
    if last != 0 {
        return false;
    }
    ct_eq(&h, root)
}
/// Returns the largest power of two strictly less than `n`.
///
/// This is the RFC 9162 split point `k` satisfying `k < n <= 2k`. Callers
/// only pass `n >= 2`; this is asserted in debug builds.
fn largest_pow2_lt(n: usize) -> usize {
    debug_assert!(n >= 2, "largest_pow2_lt requires n >= 2");
    // `next_power_of_two` is the smallest power of two >= n, so halving it
    // gives the largest power of two < n. `max(1)` keeps the historical
    // result of 1 for n < 2 when debug assertions are disabled.
    (n.next_power_of_two() >> 1).max(1)
}
/// Hashes leaf data as `sha256(0x00 || d)` (RFC 9162 leaf domain separation).
fn hash_leaf(d: &[u8]) -> [u8; DIGEST_LENGTH] {
    let preimage: Vec<u8> = std::iter::once(LEAF_PREFIX)
        .chain(d.iter().copied())
        .collect();
    sha256(&preimage)
}
/// Hashes an interior node as `sha256(0x01 || left || right)`.
fn hash_node(left: &[u8], right: &[u8]) -> [u8; DIGEST_LENGTH] {
    let preimage = [&[NODE_PREFIX][..], left, right].concat();
    sha256(&preimage)
}
/// Recursively computes the Merkle tree hash of `leaves` (RFC 9162 `MTH`).
///
/// Panics if `leaves` is empty; the public entry points reject empty input
/// before calling this.
fn root_unchecked(leaves: &[[u8; DIGEST_LENGTH]]) -> [u8; DIGEST_LENGTH] {
    if let [single] = leaves {
        return hash_leaf(single);
    }
    let (left, right) = leaves.split_at(largest_pow2_lt(leaves.len()));
    hash_node(&root_unchecked(left), &root_unchecked(right))
}
/// Collects the audit path for `leaves[index]`, nearest sibling first.
///
/// Follows the RFC 9162 `PATH(m, D[n])` recursion: descend into the half that
/// contains `index`, then append the hash of the other half.
fn audit_path(leaves: &[[u8; DIGEST_LENGTH]], index: usize) -> Vec<[u8; DIGEST_LENGTH]> {
    if leaves.len() == 1 {
        return Vec::new();
    }
    let split = largest_pow2_lt(leaves.len());
    let (left, right) = leaves.split_at(split);
    let (containing, other_half, local_index) = if index < split {
        (left, right, index)
    } else {
        (right, left, index - split)
    };
    let mut path = audit_path(containing, local_index);
    path.push(root_unchecked(other_half));
    path
}
/// Compares two byte slices without data-dependent early exit on their contents.
///
/// Slices of different lengths compare unequal.
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    bool::from(a.ct_eq(b))
}
/// Machine-readable category of a [`MerkleLeavesListError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MerkleLeavesListErrorCode {
    /// Structural problem: undecodable CBOR, wrong shape or field type,
    /// unsupported `tree_alg`, or an empty leaf list (also used by the encoder).
    Malformed,
    /// The `format` field names a format other than [`LEAVES_LIST_FORMAT_V1`].
    FormatUnsupported,
    /// `leaf_count` disagrees with the number of entries in `leaves`.
    LeafCountMismatch,
    /// The root recomputed from `leaves` differs from the declared `root`.
    RootMismatch,
}

impl MerkleLeavesListErrorCode {
    /// Returns the stable string identifier for this category.
    ///
    /// The strings are pinned by a unit test; presumably they are consumed
    /// outside this crate, so treat them as part of the external contract.
    #[must_use]
    pub const fn code(self) -> &'static str {
        match self {
            Self::Malformed => "SCHEMA_MERKLE_LEAVES_MALFORMED",
            Self::FormatUnsupported => "SCHEMA_MERKLE_LEAVES_FORMAT_UNSUPPORTED",
            Self::LeafCountMismatch => "SCHEMA_MERKLE_LEAF_COUNT_MISMATCH",
            Self::RootMismatch => "MERKLE_ROOT_MISMATCH",
        }
    }
}
/// Error produced while encoding or decoding a Merkle leaves list.
///
/// `Display` renders as `"<CODE>: <detail>"`, e.g. `MERKLE_ROOT_MISMATCH: boom`.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("{}: {detail}", code.code())]
pub struct MerkleLeavesListError {
    code: MerkleLeavesListErrorCode,
    detail: String,
}

impl MerkleLeavesListError {
    /// Returns the machine-readable error category.
    #[must_use]
    pub const fn code(&self) -> MerkleLeavesListErrorCode {
        self.code
    }

    /// Returns the stable string identifier of the error category.
    #[must_use]
    pub const fn code_str(&self) -> &'static str {
        self.code().code()
    }

    /// Builds an error from a category and a human-readable detail message.
    fn new(code: MerkleLeavesListErrorCode, detail: impl Into<String>) -> Self {
        let detail = detail.into();
        Self { code, detail }
    }
}
/// Leaves list returned by a successful [`decode_leaves_list`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedLeavesList {
    /// The `format` field; [`decode_leaves_list`] only accepts [`LEAVES_LIST_FORMAT_V1`].
    pub format: String,
    /// The `tree_alg` field; [`decode_leaves_list`] only accepts [`MERKLE_ALG_ID`].
    pub tree_alg: String,
    /// Declared root; [`decode_leaves_list`] checks it against the root recomputed from `leaves`.
    pub root: [u8; DIGEST_LENGTH],
    /// Leaf values in tree order, each exactly `DIGEST_LENGTH` bytes.
    pub leaves: Vec<[u8; DIGEST_LENGTH]>,
    /// Number of leaves; [`decode_leaves_list`] sets this to `leaves.len()`.
    pub leaf_count: usize,
    /// Optional `leaf_alg` annotation, passed through without interpretation.
    pub leaf_alg: Option<String>,
}
/// Serializes `leaves` and `root` as a canonical-CBOR leaves list.
///
/// The map contains `format` ([`LEAVES_LIST_FORMAT_V1`]), `tree_alg`
/// ([`MERKLE_ALG_ID`]), `root`, `leaves`, `leaf_count` and, only when
/// `leaf_alg` is `Some`, `leaf_alg`.
///
/// NOTE: `root` is written exactly as given and is not checked against
/// `leaves`. [`decode_leaves_list`] recomputes the root and rejects a mismatch
/// with [`MerkleLeavesListErrorCode::RootMismatch`], so callers should pass the
/// value returned by [`merkle_root`].
///
/// # Errors
///
/// [`MerkleLeavesListErrorCode::Malformed`] if `leaves` is empty or the
/// canonical CBOR encoder fails.
pub fn encode_leaves_list(
    leaves: &[[u8; DIGEST_LENGTH]],
    root: &[u8; DIGEST_LENGTH],
    leaf_alg: Option<&str>,
) -> Result<Vec<u8>, MerkleLeavesListError> {
    if leaves.is_empty() {
        return Err(MerkleLeavesListError::new(
            MerkleLeavesListErrorCode::Malformed,
            "leaves must be a non-empty list",
        ));
    }
    let leaves_value = CborValue::Array(
        leaves
            .iter()
            .map(|leaf| CborValue::bytes(leaf.to_vec()))
            .collect(),
    );
    let leaf_count_value = CborValue::Unsigned(leaves.len() as u64);

    // Keys are pushed in the same order as the original literal map.
    let mut pairs = Vec::with_capacity(6);
    pairs.push((CborValue::text("format"), CborValue::text(LEAVES_LIST_FORMAT_V1)));
    pairs.push((CborValue::text("tree_alg"), CborValue::text(MERKLE_ALG_ID)));
    pairs.push((CborValue::text("root"), CborValue::bytes(root.to_vec())));
    pairs.push((CborValue::text("leaves"), leaves_value));
    pairs.push((CborValue::text("leaf_count"), leaf_count_value));
    pairs.extend(leaf_alg.map(|alg| (CborValue::text("leaf_alg"), CborValue::text(alg))));

    let document = CborValue::Map(pairs);
    encode_canonical_cbor(&document).map_err(|err| {
        MerkleLeavesListError::new(
            MerkleLeavesListErrorCode::Malformed,
            format!("canonical CBOR encode failed: {err}"),
        )
    })
}
pub fn decode_leaves_list(bytes: &[u8]) -> Result<DecodedLeavesList, MerkleLeavesListError> {
let decoded = decode_canonical_cbor(bytes).map_err(|e| {
MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
format!("CBOR decode failed: {e}"),
)
})?;
let pairs = match &decoded {
CborValue::Map(pairs) => pairs,
_ => {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
"top-level must be a CBOR map",
));
}
};
let format = match map_get(pairs, "format") {
Some(CborValue::Text(s)) => s.clone(),
_ => {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
"`format` must be a text string",
));
}
};
if format != LEAVES_LIST_FORMAT_V1 {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::FormatUnsupported,
format!("unsupported leaves-list format: {format:?}"),
));
}
let tree_alg = match map_get(pairs, "tree_alg") {
Some(CborValue::Text(s)) => s.clone(),
_ => {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
"`tree_alg` must be a text string",
));
}
};
if tree_alg != MERKLE_ALG_ID {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
format!("unsupported leaves-list tree_alg: {tree_alg:?}"),
));
}
let root = match map_get(pairs, "root") {
Some(CborValue::Bytes(b)) if b.len() == DIGEST_LENGTH => {
let mut out = [0u8; DIGEST_LENGTH];
out.copy_from_slice(b);
out
}
_ => {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
"`root` must be a 32-byte byte string",
));
}
};
let leaves_raw = match map_get(pairs, "leaves") {
Some(CborValue::Array(items)) if !items.is_empty() => items,
_ => {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
"`leaves` must be a non-empty array",
));
}
};
let mut leaves: Vec<[u8; DIGEST_LENGTH]> = Vec::with_capacity(leaves_raw.len());
for (i, item) in leaves_raw.iter().enumerate() {
match item {
CborValue::Bytes(b) if b.len() == DIGEST_LENGTH => {
let mut out = [0u8; DIGEST_LENGTH];
out.copy_from_slice(b);
leaves.push(out);
}
_ => {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
format!("`leaves[{i}]` must be a 32-byte byte string"),
));
}
}
}
let leaf_count = match map_get(pairs, "leaf_count") {
Some(CborValue::Unsigned(n)) => *n,
_ => {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
"`leaf_count` must be a non-negative integer",
));
}
};
if leaf_count != leaves.len() as u64 {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::LeafCountMismatch,
format!(
"`leaf_count` ({leaf_count}) does not match number of leaves ({})",
leaves.len()
),
));
}
let leaf_alg = match map_get(pairs, "leaf_alg") {
None => None,
Some(CborValue::Text(s)) => Some(s.clone()),
Some(_) => {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::Malformed,
"`leaf_alg` (if present) must be a text string",
));
}
};
let recomputed = root_unchecked(&leaves);
if !ct_eq(&recomputed, &root) {
return Err(MerkleLeavesListError::new(
MerkleLeavesListErrorCode::RootMismatch,
"leaves recompute does not match declared root",
));
}
Ok(DecodedLeavesList {
format,
tree_alg,
root,
leaf_count: leaves.len(),
leaves,
leaf_alg,
})
}
/// Looks up the value stored under the text key `key`.
///
/// Entries whose key is not a text string are skipped. If several entries
/// share the key, the first one in `pairs` is returned.
fn map_get<'a>(pairs: &'a [(CborValue, CborValue)], key: &str) -> Option<&'a CborValue> {
    for (k, v) in pairs {
        if let CborValue::Text(t) = k {
            if t == key {
                return Some(v);
            }
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, pairwise-distinct leaf values for tree sizes used in tests.
    fn sample_leaves(n: usize) -> Vec<[u8; DIGEST_LENGTH]> {
        (0..n)
            .map(|i| {
                let mut leaf = [0u8; DIGEST_LENGTH];
                leaf[..8].copy_from_slice(&(i as u64).to_be_bytes());
                leaf[DIGEST_LENGTH - 1] = 0xa5;
                leaf
            })
            .collect()
    }

    /// Encodes a leaves-list map from explicit parts, so negative fixtures do
    /// not depend on what `encode_leaves_list` is willing to produce.
    fn leaves_list_bytes(
        format_id: &str,
        tree_alg: &str,
        root: &[u8; DIGEST_LENGTH],
        leaves: &[[u8; DIGEST_LENGTH]],
        leaf_count: u64,
    ) -> Vec<u8> {
        let pairs = vec![
            (CborValue::text("format"), CborValue::text(format_id)),
            (CborValue::text("tree_alg"), CborValue::text(tree_alg)),
            (CborValue::text("root"), CborValue::bytes(root.to_vec())),
            (
                CborValue::text("leaves"),
                CborValue::Array(
                    leaves
                        .iter()
                        .map(|l| CborValue::bytes(l.to_vec()))
                        .collect(),
                ),
            ),
            (CborValue::text("leaf_count"), CborValue::Unsigned(leaf_count)),
        ];
        encode_canonical_cbor(&CborValue::Map(pairs)).unwrap()
    }

    #[test]
    fn split_point_is_largest_power_of_two_strictly_less_than_n() {
        assert_eq!(largest_pow2_lt(2), 1);
        assert_eq!(largest_pow2_lt(3), 2);
        assert_eq!(largest_pow2_lt(4), 2);
        assert_eq!(largest_pow2_lt(5), 4);
        assert_eq!(largest_pow2_lt(7), 4);
        assert_eq!(largest_pow2_lt(8), 4);
        assert_eq!(largest_pow2_lt(9), 8);
        assert_eq!(largest_pow2_lt(16), 8);
        assert_eq!(largest_pow2_lt(17), 16);
    }

    #[test]
    fn leaf_and_node_prefixes_differ() {
        let d = [0x42u8; 32];
        assert_ne!(hash_leaf(&d), hash_node(&d, &d));
    }

    #[test]
    fn error_code_strings_are_stable() {
        assert_eq!(
            MerkleLeavesListErrorCode::Malformed.code(),
            "SCHEMA_MERKLE_LEAVES_MALFORMED"
        );
        assert_eq!(
            MerkleLeavesListErrorCode::FormatUnsupported.code(),
            "SCHEMA_MERKLE_LEAVES_FORMAT_UNSUPPORTED"
        );
        assert_eq!(
            MerkleLeavesListErrorCode::LeafCountMismatch.code(),
            "SCHEMA_MERKLE_LEAF_COUNT_MISMATCH"
        );
        assert_eq!(
            MerkleLeavesListErrorCode::RootMismatch.code(),
            "MERKLE_ROOT_MISMATCH"
        );
    }

    #[test]
    fn leaves_list_error_display_is_code_colon_detail() {
        let err =
            MerkleLeavesListError::new(MerkleLeavesListErrorCode::RootMismatch, "boom".to_string());
        assert_eq!(err.to_string(), "MERKLE_ROOT_MISMATCH: boom");
        assert_eq!(err.code_str(), "MERKLE_ROOT_MISMATCH");
    }

    #[test]
    fn decode_rejects_unsupported_tree_alg() {
        let leaves = [[0xa1u8; DIGEST_LENGTH], [0xa2u8; DIGEST_LENGTH]];
        let root = root_unchecked(&leaves);
        let pairs = vec![
            (
                CborValue::text("format"),
                CborValue::text(LEAVES_LIST_FORMAT_V1),
            ),
            (CborValue::text("root"), CborValue::bytes(root.to_vec())),
            (
                CborValue::text("leaves"),
                CborValue::Array(
                    leaves
                        .iter()
                        .map(|l| CborValue::bytes(l.to_vec()))
                        .collect(),
                ),
            ),
            (CborValue::text("tree_alg"), CborValue::text("not-rfc9162")),
            (
                CborValue::text("leaf_count"),
                CborValue::Unsigned(leaves.len() as u64),
            ),
        ];
        let bytes = encode_canonical_cbor(&CborValue::Map(pairs)).unwrap();
        let err = decode_leaves_list(&bytes).expect_err("wrong tree_alg must reject");
        assert_eq!(err.code(), MerkleLeavesListErrorCode::Malformed);
        assert_eq!(err.code_str(), "SCHEMA_MERKLE_LEAVES_MALFORMED");
    }

    #[test]
    fn empty_and_out_of_range_inputs_are_rejected() {
        assert_eq!(merkle_root(&[]), Err(MerkleError::EmptyTree));
        assert_eq!(merkle_inclusion_proof(&[], 0), Err(MerkleError::EmptyTree));
        let leaves = sample_leaves(3);
        assert_eq!(
            merkle_inclusion_proof(&leaves, 3),
            Err(MerkleError::IndexOutOfRange {
                index: 3,
                tree_size: 3
            })
        );
    }

    #[test]
    fn roots_and_paths_match_manual_construction() {
        let l = sample_leaves(5);
        let h: Vec<[u8; DIGEST_LENGTH]> = l.iter().map(|x| hash_leaf(x)).collect();

        // n = 1: the root is the leaf hash and the path is empty.
        assert_eq!(merkle_root(&l[..1]).unwrap(), h[0]);
        assert!(merkle_inclusion_proof(&l[..1], 0).unwrap().is_empty());

        // n = 3: split 2 | 1.
        let h01 = hash_node(&h[0], &h[1]);
        assert_eq!(merkle_root(&l[..3]).unwrap(), hash_node(&h01, &h[2]));
        assert_eq!(merkle_inclusion_proof(&l[..3], 0).unwrap(), vec![h[1], h[2]]);
        assert_eq!(merkle_inclusion_proof(&l[..3], 2).unwrap(), vec![h01]);

        // n = 5: split 4 | 1.
        let left4 = hash_node(&h01, &hash_node(&h[2], &h[3]));
        assert_eq!(merkle_root(&l).unwrap(), hash_node(&left4, &h[4]));
        assert_eq!(merkle_inclusion_proof(&l, 4).unwrap(), vec![left4]);
    }

    #[test]
    fn every_generated_proof_verifies() {
        for n in 1..=33 {
            let leaves = sample_leaves(n);
            let root = merkle_root(&leaves).unwrap();
            for (i, leaf) in leaves.iter().enumerate() {
                let proof = merkle_inclusion_proof(&leaves, i).unwrap();
                assert!(verify_inclusion(leaf, i, n, &proof, &root), "n={n} i={i}");
            }
        }
    }

    #[test]
    fn tampered_proofs_are_rejected() {
        for n in 1..=17 {
            let leaves = sample_leaves(n);
            let root = merkle_root(&leaves).unwrap();
            for (i, leaf) in leaves.iter().enumerate() {
                let proof = merkle_inclusion_proof(&leaves, i).unwrap();

                // One entry too many.
                let mut longer = proof.clone();
                longer.push([0u8; DIGEST_LENGTH]);
                assert!(!verify_inclusion(leaf, i, n, &longer, &root), "long n={n} i={i}");

                // One entry too few.
                if !proof.is_empty() {
                    let shorter = &proof[..proof.len() - 1];
                    assert!(!verify_inclusion(leaf, i, n, shorter, &root), "short n={n} i={i}");
                }

                // A flipped bit in any path element.
                for j in 0..proof.len() {
                    let mut bad = proof.clone();
                    bad[j][0] ^= 1;
                    assert!(!verify_inclusion(leaf, i, n, &bad, &root), "bit n={n} i={i} j={j}");
                }

                // A different leaf value at the same position.
                let mut other = *leaf;
                other[0] ^= 0xff;
                assert!(!verify_inclusion(&other, i, n, &proof, &root), "leaf n={n} i={i}");

                // Index outside the tree.
                assert!(!verify_inclusion(leaf, n, n, &proof, &root), "index n={n} i={i}");
            }
        }
    }

    #[test]
    fn verify_rejects_wrong_lengths_and_empty_tree() {
        let leaves = sample_leaves(4);
        let root = merkle_root(&leaves).unwrap();
        let proof = merkle_inclusion_proof(&leaves, 1).unwrap();
        assert!(verify_inclusion(&leaves[1], 1, 4, &proof, &root));
        assert!(!verify_inclusion(&leaves[1][..31], 1, 4, &proof, &root));
        assert!(!verify_inclusion(&leaves[1], 1, 4, &proof, &root[..31]));
        assert!(!verify_inclusion(&leaves[1], 1, 0, &proof, &root));
    }

    #[test]
    fn leaves_list_round_trips() {
        let leaves = sample_leaves(5);
        let root = merkle_root(&leaves).unwrap();

        let bytes = encode_leaves_list(&leaves, &root, Some("sha256")).unwrap();
        let decoded = decode_leaves_list(&bytes).unwrap();
        assert_eq!(decoded.format, LEAVES_LIST_FORMAT_V1);
        assert_eq!(decoded.tree_alg, MERKLE_ALG_ID);
        assert_eq!(decoded.root, root);
        assert_eq!(decoded.leaves, leaves);
        assert_eq!(decoded.leaf_count, 5);
        assert_eq!(decoded.leaf_alg.as_deref(), Some("sha256"));

        let bytes = encode_leaves_list(&leaves, &root, None).unwrap();
        assert_eq!(decode_leaves_list(&bytes).unwrap().leaf_alg, None);
    }

    #[test]
    fn encode_rejects_empty_leaves() {
        let err = encode_leaves_list(&[], &[0u8; DIGEST_LENGTH], None)
            .expect_err("empty leaves must reject");
        assert_eq!(err.code(), MerkleLeavesListErrorCode::Malformed);
    }

    #[test]
    fn decode_rejects_root_mismatch() {
        let leaves = sample_leaves(3);
        let mut root = merkle_root(&leaves).unwrap();
        root[0] ^= 1;
        let bytes = leaves_list_bytes(LEAVES_LIST_FORMAT_V1, MERKLE_ALG_ID, &root, &leaves, 3);
        let err = decode_leaves_list(&bytes).expect_err("root mismatch must reject");
        assert_eq!(err.code(), MerkleLeavesListErrorCode::RootMismatch);
    }

    #[test]
    fn decode_rejects_leaf_count_mismatch() {
        let leaves = sample_leaves(3);
        let root = merkle_root(&leaves).unwrap();
        let bytes = leaves_list_bytes(LEAVES_LIST_FORMAT_V1, MERKLE_ALG_ID, &root, &leaves, 4);
        let err = decode_leaves_list(&bytes).expect_err("leaf_count mismatch must reject");
        assert_eq!(err.code(), MerkleLeavesListErrorCode::LeafCountMismatch);
    }

    #[test]
    fn decode_rejects_unknown_format() {
        let leaves = sample_leaves(2);
        let root = merkle_root(&leaves).unwrap();
        let bytes =
            leaves_list_bytes("cardano-poe-merkle-leaves-v2", MERKLE_ALG_ID, &root, &leaves, 2);
        let err = decode_leaves_list(&bytes).expect_err("unknown format must reject");
        assert_eq!(err.code(), MerkleLeavesListErrorCode::FormatUnsupported);
    }

    #[test]
    fn decode_rejects_non_map_top_level() {
        let bytes = encode_canonical_cbor(&CborValue::Unsigned(7)).unwrap();
        let err = decode_leaves_list(&bytes).expect_err("non-map must reject");
        assert_eq!(err.code(), MerkleLeavesListErrorCode::Malformed);
    }
}