use alloy_primitives::B256;
use crate::{error::Result, Stem, StemNode, UbtError, STEM_LEN};
/// Wire-format tag byte for an empty node; the whole encoding is this single byte.
pub const EMPTY_NODE_TAG: u8 = 0x00;
/// Wire-format tag byte for an internal node (tag, left hash, right hash).
pub const INTERNAL_NODE_TAG: u8 = 0x01;
/// Wire-format tag byte for a stem node (tag, stem, presence bitmap, present values).
pub const STEM_NODE_TAG: u8 = 0x02;
/// Wire-format tag byte for a leaf node (tag, one 32-byte value).
pub const LEAF_NODE_TAG: u8 = 0x03;
/// Fixed prefix of an encoded stem node: tag byte, `STEM_LEN` stem bytes, and a
/// 32-byte (256-bit) presence bitmap. Each present value then adds 32 more bytes.
const STEM_HEADER_SIZE: usize = 1 + STEM_LEN + 32;
/// Exact size of an encoded internal node: tag byte plus two 32-byte hashes.
const INTERNAL_SIZE: usize = 1 + 32 + 32;
/// Exact size of an encoded leaf node: tag byte plus one 32-byte value.
const LEAF_SIZE: usize = 1 + 32;
/// Serializes a stem node into its wire format.
///
/// Layout: `STEM_NODE_TAG`, the stem bytes, a 32-byte presence bitmap (bit
/// `i % 8` of byte `i / 8` is set when subindex `i` holds a value), then every
/// present value as 32 raw bytes in ascending subindex order.
pub fn encode_stem_node(node: &StemNode) -> Vec<u8> {
    // One bit per possible subindex (0..=255).
    let mut presence = [0u8; 32];
    node.values
        .keys()
        .for_each(|&idx| presence[idx as usize / 8] |= 1 << (idx % 8));

    let mut out = Vec::with_capacity(STEM_HEADER_SIZE + node.values.len() * 32);
    out.push(STEM_NODE_TAG);
    out.extend_from_slice(node.stem.as_bytes());
    out.extend_from_slice(&presence);

    // Walking subindices in order keeps the value section aligned with the bitmap.
    for value in (0u8..=255).filter_map(|idx| node.values.get(&idx)) {
        out.extend_from_slice(value.as_slice());
    }
    out
}
/// Decodes a stem node produced by [`encode_stem_node`].
///
/// Expected layout: `STEM_NODE_TAG`, `STEM_LEN` stem bytes, a 32-byte presence
/// bitmap (bit `i % 8` of byte `i / 8` marks subindex `i`), then one 32-byte value
/// per set bit in ascending subindex order. The input must be exactly that long.
///
/// # Errors
///
/// Returns [`UbtError::InvalidEncoding`] if the buffer is shorter than the header,
/// carries the wrong tag, or does not match the length implied by the bitmap. The
/// same error is returned if a flagged value is not stored verbatim in
/// `StemNode::values` after `StemNode::set_value`. Such input is non-canonical:
/// it would not re-encode to the same bytes, so it is rejected rather than
/// silently losing data.
pub fn decode_stem_node(bytes: &[u8]) -> Result<StemNode> {
    if bytes.len() < STEM_HEADER_SIZE {
        return Err(UbtError::InvalidEncoding(format!(
            "stem node: expected at least {STEM_HEADER_SIZE} bytes, got {}",
            bytes.len()
        )));
    }
    if bytes[0] != STEM_NODE_TAG {
        return Err(UbtError::InvalidEncoding(format!(
            "expected stem node tag 0x{STEM_NODE_TAG:02x}, got 0x{:02x}",
            bytes[0]
        )));
    }

    let mut stem_bytes = [0u8; STEM_LEN];
    stem_bytes.copy_from_slice(&bytes[1..=STEM_LEN]);
    let stem = Stem::new(stem_bytes);

    let bitmap = &bytes[1 + STEM_LEN..STEM_HEADER_SIZE];
    let value_count: usize = bitmap.iter().map(|b| b.count_ones() as usize).sum();
    let expected_size = STEM_HEADER_SIZE + value_count * 32;
    if bytes.len() != expected_size {
        return Err(UbtError::InvalidEncoding(format!(
            "stem node: expected {expected_size} bytes, got {}",
            bytes.len()
        )));
    }

    let mut node = StemNode::new(stem);
    node.values.reserve(value_count);

    // Subindices whose presence bit is set, in ascending order. This is the same
    // order in which the encoder wrote the values. The length check above
    // guarantees exactly one 32-byte chunk per set bit.
    let present = (0u8..=255).filter(|&i| bitmap[usize::from(i) / 8] & (1 << (i % 8)) != 0);
    let chunks = bytes[STEM_HEADER_SIZE..].chunks_exact(32);
    for (subindex, chunk) in present.zip(chunks) {
        let value = B256::from_slice(chunk);
        node.set_value(subindex, value);
        // NOTE(review): `set_value` may not store every value. The round-trip
        // tests avoid all-zero values, suggesting zero means "unset"; confirm
        // against StemNode. If the flagged value was not stored verbatim,
        // re-encoding would not reproduce `bytes`, so reject the input instead of
        // dropping the value silently.
        if node.values.get(&subindex) != Some(&value) {
            return Err(UbtError::InvalidEncoding(format!(
                "stem node: value at subindex {subindex} is not canonical and was not retained"
            )));
        }
    }
    Ok(node)
}
/// Returns the number of bytes [`encode_stem_node`] produces for `node`: the fixed
/// header plus 32 bytes for every stored value.
#[must_use]
pub fn encoded_stem_size(node: &StemNode) -> usize {
    let value_section = node.values.len() * 32;
    STEM_HEADER_SIZE + value_section
}
/// Serializes an internal node as `INTERNAL_NODE_TAG`, then the 32-byte left hash,
/// then the 32-byte right hash (`INTERNAL_SIZE` bytes in total).
pub fn encode_internal_node(left_hash: &B256, right_hash: &B256) -> Vec<u8> {
    // `concat` sizes the output from the parts, so this allocates exactly once.
    [
        &[INTERNAL_NODE_TAG][..],
        left_hash.as_slice(),
        right_hash.as_slice(),
    ]
    .concat()
}
/// Decodes an internal node produced by [`encode_internal_node`] into its
/// `(left_hash, right_hash)` pair.
///
/// # Errors
///
/// Returns [`UbtError::InvalidEncoding`] if the buffer is not exactly
/// `INTERNAL_SIZE` bytes long or does not start with `INTERNAL_NODE_TAG`.
pub fn decode_internal_node(bytes: &[u8]) -> Result<(B256, B256)> {
    if bytes.len() != INTERNAL_SIZE {
        return Err(UbtError::InvalidEncoding(format!(
            "internal node: expected {INTERNAL_SIZE} bytes, got {}",
            bytes.len()
        )));
    }
    // The length is now known, so both the tag byte and the 64-byte body exist.
    let (tag, hashes) = (bytes[0], &bytes[1..]);
    if tag != INTERNAL_NODE_TAG {
        return Err(UbtError::InvalidEncoding(format!(
            "expected internal node tag 0x{INTERNAL_NODE_TAG:02x}, got 0x{tag:02x}"
        )));
    }
    let (left, right) = hashes.split_at(32);
    Ok((B256::from_slice(left), B256::from_slice(right)))
}
/// Serializes a leaf node as `LEAF_NODE_TAG` followed by the 32-byte value
/// (`LEAF_SIZE` bytes in total).
pub fn encode_leaf_node(value: &B256) -> Vec<u8> {
    [&[LEAF_NODE_TAG][..], value.as_slice()].concat()
}
/// Decodes a leaf node produced by [`encode_leaf_node`] back into its value.
///
/// # Errors
///
/// Returns [`UbtError::InvalidEncoding`] if the buffer is not exactly `LEAF_SIZE`
/// bytes long or does not start with `LEAF_NODE_TAG`.
pub fn decode_leaf_node(bytes: &[u8]) -> Result<B256> {
    if bytes.len() != LEAF_SIZE {
        return Err(UbtError::InvalidEncoding(format!(
            "leaf node: expected {LEAF_SIZE} bytes, got {}",
            bytes.len()
        )));
    }
    match bytes[0] {
        // The length check above makes `bytes[1..]` exactly 32 bytes.
        LEAF_NODE_TAG => Ok(B256::from_slice(&bytes[1..])),
        other => Err(UbtError::InvalidEncoding(format!(
            "expected leaf node tag 0x{LEAF_NODE_TAG:02x}, got 0x{other:02x}"
        ))),
    }
}
/// Serializes an empty node: the encoding is just the `EMPTY_NODE_TAG` byte.
#[must_use]
pub fn encode_empty_node() -> Vec<u8> {
    [EMPTY_NODE_TAG].to_vec()
}
/// A node decoded by [`decode_node`], with one variant per wire-format tag byte.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecodedNode {
/// Decoded from the single-byte `EMPTY_NODE_TAG` encoding.
Empty,
/// Decoded from an `INTERNAL_NODE_TAG` encoding: the left and right 32-byte
/// hashes, in the order they appear on the wire.
Internal {
left_hash: B256,
right_hash: B256,
},
/// Decoded from a `STEM_NODE_TAG` encoding via `decode_stem_node`.
Stem(StemNode),
/// Decoded from a `LEAF_NODE_TAG` encoding: the single 32-byte value.
Leaf {
value: B256,
},
}
/// Decodes any node encoding by dispatching on its leading tag byte.
///
/// # Errors
///
/// Returns [`UbtError::InvalidEncoding`] in any of these cases:
/// - the buffer is empty;
/// - the tag is unknown;
/// - an empty-node encoding is longer than one byte;
/// - the type-specific decoder rejects the bytes (its error is passed through).
pub fn decode_node(bytes: &[u8]) -> Result<DecodedNode> {
    let tag = match bytes.first() {
        Some(&tag) => tag,
        None => return Err(UbtError::InvalidEncoding("empty buffer".to_string())),
    };
    match tag {
        EMPTY_NODE_TAG => {
            // An empty node is the bare tag byte; anything after it is invalid.
            if bytes.len() == 1 {
                Ok(DecodedNode::Empty)
            } else {
                Err(UbtError::InvalidEncoding(format!(
                    "empty node: expected 1 byte, got {}",
                    bytes.len()
                )))
            }
        }
        INTERNAL_NODE_TAG => decode_internal_node(bytes)
            .map(|(left_hash, right_hash)| DecodedNode::Internal { left_hash, right_hash }),
        STEM_NODE_TAG => decode_stem_node(bytes).map(DecodedNode::Stem),
        LEAF_NODE_TAG => decode_leaf_node(bytes).map(|value| DecodedNode::Leaf { value }),
        unknown => Err(UbtError::InvalidEncoding(format!(
            "unknown node type tag: 0x{unknown:02x}"
        ))),
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the node wire-format encoders/decoders: round trips, exact byte
// layout, and rejection of malformed input.
use super::*;
// A stem node with no values encodes to the bare header and decodes back empty.
#[test]
fn test_roundtrip_empty() {
let node = StemNode::new(Stem::new([0xAA; 31]));
let bytes = encode_stem_node(&node);
assert_eq!(bytes.len(), STEM_HEADER_SIZE);
let decoded = decode_stem_node(&bytes).unwrap();
assert_eq!(decoded.stem, node.stem);
assert!(decoded.values.is_empty());
}
// One value adds exactly 32 bytes and survives the round trip at its subindex.
#[test]
fn test_roundtrip_single_value() {
let mut node = StemNode::new(Stem::new([0x01; 31]));
node.set_value(42, B256::repeat_byte(0xFF));
let bytes = encode_stem_node(&node);
assert_eq!(bytes.len(), STEM_HEADER_SIZE + 32);
let decoded = decode_stem_node(&bytes).unwrap();
assert_eq!(decoded.stem, node.stem);
assert_eq!(decoded.get_value(42), Some(B256::repeat_byte(0xFF)));
assert_eq!(decoded.values.len(), 1);
}
// Values at the first, middle and last subindex round-trip; an unset subindex
// stays unset.
#[test]
fn test_roundtrip_sparse() {
let mut node = StemNode::new(Stem::new([0x00; 31]));
node.set_value(0, B256::repeat_byte(0x01));
node.set_value(127, B256::repeat_byte(0x02));
node.set_value(255, B256::repeat_byte(0x03));
let bytes = encode_stem_node(&node);
assert_eq!(bytes.len(), STEM_HEADER_SIZE + 3 * 32);
let decoded = decode_stem_node(&bytes).unwrap();
assert_eq!(decoded.get_value(0), Some(B256::repeat_byte(0x01)));
assert_eq!(decoded.get_value(127), Some(B256::repeat_byte(0x02)));
assert_eq!(decoded.get_value(255), Some(B256::repeat_byte(0x03)));
assert_eq!(decoded.get_value(1), None);
assert_eq!(decoded.values.len(), 3);
}
// All 256 subindices populated. `.max(1)` keeps every value non-zero.
// NOTE(review): presumably this is because StemNode::set_value does not store
// an all-zero value; confirm.
#[test]
fn test_roundtrip_full() {
let mut node = StemNode::new(Stem::new([0xBB; 31]));
for i in 0u16..256 {
node.set_value(i as u8, B256::repeat_byte((i as u8).max(1)));
}
let bytes = encode_stem_node(&node);
assert_eq!(bytes.len(), STEM_HEADER_SIZE + 256 * 32);
let decoded = decode_stem_node(&bytes).unwrap();
assert_eq!(decoded.values.len(), 256);
for i in 0u16..256 {
assert_eq!(
decoded.get_value(i as u8),
Some(B256::repeat_byte((i as u8).max(1)))
);
}
}
// encoded_stem_size agrees with the length encode_stem_node actually produces.
#[test]
fn test_encoded_size() {
let mut node = StemNode::new(Stem::new([0; 31]));
assert_eq!(encoded_stem_size(&node), STEM_HEADER_SIZE);
node.set_value(0, B256::repeat_byte(1));
node.set_value(100, B256::repeat_byte(2));
assert_eq!(encoded_stem_size(&node), STEM_HEADER_SIZE + 2 * 32);
assert_eq!(encode_stem_node(&node).len(), encoded_stem_size(&node));
}
// A buffer shorter than the stem header is rejected.
#[test]
fn test_decode_too_short() {
let err = decode_stem_node(&[0x02; 10]).unwrap_err();
assert!(matches!(err, UbtError::InvalidEncoding(_)));
}
// A header-sized buffer with a non-stem tag is rejected.
#[test]
fn test_decode_wrong_tag() {
let mut bytes = vec![0x00; STEM_HEADER_SIZE];
bytes[0] = 0xFF;
let err = decode_stem_node(&bytes).unwrap_err();
assert!(matches!(err, UbtError::InvalidEncoding(_)));
}
// The bitmap promises one value, but only 16 of its 32 bytes are present.
#[test]
fn test_decode_truncated_values() {
let mut node = StemNode::new(Stem::new([0; 31]));
node.set_value(0, B256::repeat_byte(0x42));
let bytes = encode_stem_node(&node);
let truncated = &bytes[..STEM_HEADER_SIZE + 16];
let err = decode_stem_node(truncated).unwrap_err();
assert!(matches!(err, UbtError::InvalidEncoding(_)));
}
// Pins the exact byte layout; the offsets assume a 31-byte stem. The layout is:
// byte 0 = tag, 1..32 = stem, 32..64 = bitmap, 64.. = values. Subindex 0 sets bit
// 0 of bitmap byte 0 (offset 32), and subindex 8 sets bit 0 of bitmap byte 1
// (offset 33). The values follow in ascending subindex order.
#[test]
fn test_wire_format_layout() {
let mut stem_bytes = [0u8; 31];
stem_bytes[0] = 0xAB;
let mut node = StemNode::new(Stem::new(stem_bytes));
node.set_value(0, B256::repeat_byte(0x11));
node.set_value(8, B256::repeat_byte(0x22));
let bytes = encode_stem_node(&node);
assert_eq!(bytes[0], STEM_NODE_TAG);
assert_eq!(bytes[1], 0xAB);
assert_eq!(bytes[2..32], [0u8; 30]);
assert_eq!(bytes[32], 0x01);
assert_eq!(bytes[33], 0x01);
assert_eq!(bytes[34..64], [0u8; 30]);
assert_eq!(&bytes[64..96], B256::repeat_byte(0x11).as_slice());
assert_eq!(&bytes[96..128], B256::repeat_byte(0x22).as_slice());
}
// Internal node: fixed size and both hashes round-trip in order.
#[test]
fn test_internal_roundtrip() {
let left = B256::repeat_byte(0xAA);
let right = B256::repeat_byte(0xBB);
let bytes = encode_internal_node(&left, &right);
assert_eq!(bytes.len(), INTERNAL_SIZE);
let (l, r) = decode_internal_node(&bytes).unwrap();
assert_eq!(l, left);
assert_eq!(r, right);
}
// Leaf node: fixed size and the value round-trips.
#[test]
fn test_leaf_roundtrip() {
let value = B256::repeat_byte(0x42);
let bytes = encode_leaf_node(&value);
assert_eq!(bytes.len(), LEAF_SIZE);
let decoded = decode_leaf_node(&bytes).unwrap();
assert_eq!(decoded, value);
}
// Empty node is the single tag byte and decode_node maps it to DecodedNode::Empty.
#[test]
fn test_empty_roundtrip() {
let bytes = encode_empty_node();
assert_eq!(bytes.len(), 1);
assert_eq!(bytes[0], EMPTY_NODE_TAG);
let decoded = decode_node(&bytes).unwrap();
assert_eq!(decoded, DecodedNode::Empty);
}
// decode_node routes each tag to the matching DecodedNode variant.
#[test]
fn test_decode_node_dispatches() {
let stem_bytes = encode_stem_node(&StemNode::new(Stem::new([0; 31])));
assert!(matches!(
decode_node(&stem_bytes).unwrap(),
DecodedNode::Stem(_)
));
let internal_bytes = encode_internal_node(&B256::ZERO, &B256::ZERO);
assert!(matches!(
decode_node(&internal_bytes).unwrap(),
DecodedNode::Internal { .. }
));
let leaf_bytes = encode_leaf_node(&B256::repeat_byte(1));
assert!(matches!(
decode_node(&leaf_bytes).unwrap(),
DecodedNode::Leaf { .. }
));
}
// A tag outside 0x00..=0x03 is rejected.
#[test]
fn test_decode_node_unknown_tag() {
let err = decode_node(&[0xFF]).unwrap_err();
assert!(matches!(err, UbtError::InvalidEncoding(_)));
}
// A zero-length buffer is rejected rather than panicking on the tag read.
#[test]
fn test_decode_node_empty_buffer() {
let err = decode_node(&[]).unwrap_err();
assert!(matches!(err, UbtError::InvalidEncoding(_)));
}
// Every decoder is strict about length: one extra trailing byte is an error.
#[test]
fn test_reject_trailing_bytes() {
let mut leaf = encode_leaf_node(&B256::repeat_byte(0x42));
leaf.push(0xDE);
assert!(decode_leaf_node(&leaf).is_err());
let mut internal = encode_internal_node(&B256::ZERO, &B256::ZERO);
internal.push(0xDE);
assert!(decode_internal_node(&internal).is_err());
let mut stem = encode_stem_node(&StemNode::new(Stem::new([0; 31])));
stem.push(0xDE);
assert!(decode_stem_node(&stem).is_err());
assert!(decode_node(&[EMPTY_NODE_TAG, 0xFF]).is_err());
}
}