use sha2::{Digest, Sha256};
/// A 256-bit value stored as 32 raw bytes; used for keys, value hashes and node hashes.
pub type Hash256 = [u8; 32];

/// One sibling subtree recorded while a lookup walks down the tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofNode {
    /// Number of meaningful bits in `label_path`.
    pub label_bit_len: u32,
    /// Bit-packed label of the sibling (parent prefix plus the sibling's branch bit),
    /// most significant bit first.
    pub label_path: Vec<u8>,
    /// Hash of the sibling subtree.
    pub hash: Hash256,
}

/// Outcome of a key lookup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LookupProof {
    /// `true` when a leaf holding exactly the requested key was reached.
    pub found: bool,
    /// Non-empty siblings passed on the way down, in root-to-leaf order.
    pub proof: Vec<ProofNode>,
}
/// A bit string of `bit_len` bits packed into `bytes`, most significant bit first.
///
/// Equality and hashing are derived, so they compare the raw bytes. The methods
/// below keep any unused low bits of the last byte at zero (`prefix` masks them,
/// `append_bit` only ever sets the bit it appends), which is what makes two
/// labels with the same bits compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Label {
    bytes: Vec<u8>,
    bit_len: u32,
}

impl Label {
    /// Returns the zero-length label.
    pub fn empty() -> Self {
        Self {
            bytes: Vec::new(),
            bit_len: 0,
        }
    }

    /// Builds the full 256-bit label for `key`.
    pub fn from_key(key: &Hash256) -> Self {
        Self {
            bytes: key.to_vec(),
            bit_len: 256,
        }
    }

    /// Number of bits in the label.
    pub fn bit_len(&self) -> u32 {
        self.bit_len
    }

    /// The packed bytes backing the label.
    pub fn bytes(&self) -> &[u8] {
        &self.bytes
    }

    /// Returns bit `idx` (0 is the most significant bit of the first byte) as `0` or `1`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.bit_len()`.
    pub fn bit(&self, idx: u32) -> u8 {
        assert!(idx < self.bit_len);
        let byte_idx = (idx / 8) as usize;
        let bit_idx = 7 - (idx % 8);
        (self.bytes[byte_idx] >> bit_idx) & 1
    }

    /// Returns the first `len` bits of the label.
    ///
    /// A `len` at or beyond the label's length yields a copy of the whole label.
    pub fn prefix(&self, len: u32) -> Self {
        if len >= self.bit_len {
            return self.clone();
        }
        if len == 0 {
            return Self::empty();
        }
        let full_bytes = (len / 8) as usize;
        let remaining_bits = len % 8;
        let mut bytes = self.bytes[..full_bytes].to_vec();
        if remaining_bits > 0 && full_bytes < self.bytes.len() {
            // Keep only the top `remaining_bits` bits so the unused tail stays zero.
            let mask = 0xFFu8 << (8 - remaining_bits);
            bytes.push(self.bytes[full_bytes] & mask);
        }
        Self {
            bytes,
            bit_len: len,
        }
    }

    /// Returns the bits from position `start` to the end, re-packed from bit 0.
    ///
    /// A `start` at or beyond the label's length yields the empty label.
    // No caller in this module; the allow keeps builds that flag it as unused warning-free.
    #[allow(dead_code)]
    pub fn suffix(&self, start: u32) -> Self {
        if start >= self.bit_len {
            return Self::empty();
        }
        let new_len = self.bit_len - start;
        // Allocate the destination once and set bits in place, rather than
        // cloning the whole buffer for every appended bit.
        let mut bytes = vec![0u8; (new_len as usize + 7) / 8];
        for i in 0..new_len {
            if self.bit(start + i) == 1 {
                bytes[(i / 8) as usize] |= 1 << (7 - (i % 8));
            }
        }
        Self {
            bytes,
            bit_len: new_len,
        }
    }

    /// Returns a copy of the label extended by one bit.
    ///
    /// Only `bit == 1` sets the new bit; every other value appends a zero bit.
    pub fn append_bit(&self, bit: u8) -> Self {
        let new_bit_len = self.bit_len + 1;
        let byte_idx = (self.bit_len / 8) as usize;
        let bit_idx = 7 - (self.bit_len % 8);
        let mut bytes = self.bytes.clone();
        if byte_idx >= bytes.len() {
            bytes.push(0);
        }
        if bit == 1 {
            bytes[byte_idx] |= 1 << bit_idx;
        }
        Self {
            bytes,
            bit_len: new_bit_len,
        }
    }

    /// Returns `true` when `other` starts with all the bits of `self`.
    ///
    /// The empty label is a prefix of every label.
    pub fn is_prefix_of(&self, other: &Label) -> bool {
        if self.bit_len > other.bit_len {
            return false;
        }
        if self.bit_len == 0 {
            return true;
        }
        let full_bytes = (self.bit_len / 8) as usize;
        if self.bytes[..full_bytes] != other.bytes[..full_bytes] {
            return false;
        }
        let remaining_bits = self.bit_len % 8;
        if remaining_bits > 0 {
            // Compare only the bits of the trailing partial byte that belong to `self`.
            let mask = 0xFFu8 << (8 - remaining_bits);
            let self_last = self.bytes.get(full_bytes).copied().unwrap_or(0) & mask;
            let other_last = other.bytes.get(full_bytes).copied().unwrap_or(0) & mask;
            if self_last != other_last {
                return false;
            }
        }
        true
    }

    /// Returns the number of leading bits shared by `self` and `other`.
    ///
    /// The result never exceeds the shorter of the two lengths.
    pub fn common_prefix_len(&self, other: &Label) -> u32 {
        let max_len = self.bit_len.min(other.bit_len);
        if max_len == 0 {
            return 0;
        }
        let full_bytes = (max_len / 8) as usize;
        let whole_bytes = self.bytes.iter().zip(&other.bytes).take(full_bytes);
        for (i, (a, b)) in whole_bytes.enumerate() {
            let diff = a ^ b;
            if diff != 0 {
                // The first set bit of the XOR marks the first differing position.
                return (i as u32) * 8 + diff.leading_zeros();
            }
        }
        let remaining_bits = max_len % 8;
        if remaining_bits > 0 && full_bytes < self.bytes.len() && full_bytes < other.bytes.len() {
            let mask = 0xFFu8 << (8 - remaining_bits);
            let diff = (self.bytes[full_bytes] & mask) ^ (other.bytes[full_bytes] & mask);
            if diff != 0 {
                return (full_bytes as u32) * 8 + diff.leading_zeros();
            }
        }
        max_len
    }
}
/// A node of the binary prefix tree.
#[derive(Debug, Clone, Default)]
enum Node {
    /// An absent subtree; its hash is 32 zero bytes.
    #[default]
    Empty,
    /// A stored entry: the key's label and the hash of its value.
    Leaf { key: Label, value_hash: Hash256 },
    /// A branch point with two children and a cached hash of the subtree.
    Internal {
        prefix: Label,
        left: Box<Node>,
        right: Box<Node>,
        hash: Hash256,
    },
}

impl Node {
    /// Returns the hash of this node.
    ///
    /// Empty nodes yield all zeroes, internal nodes return their cached hash,
    /// and leaves are hashed on demand as SHA-256 over the byte `0x00`, the
    /// key bytes and the value hash.
    fn hash(&self) -> Hash256 {
        let (label, value) = match self {
            Node::Empty => return [0u8; 32],
            Node::Internal { hash, .. } => return *hash,
            Node::Leaf { key, value_hash } => (key, value_hash),
        };
        let mut digest = Sha256::new();
        digest.update([0x00]);
        digest.update(&label.bytes);
        digest.update(value);
        digest.finalize().into()
    }

    /// Reports whether this node is the `Empty` variant.
    fn is_empty(&self) -> bool {
        match self {
            Node::Empty => true,
            Node::Leaf { .. } | Node::Internal { .. } => false,
        }
    }
}
pub struct PrefixTree {
root: Node,
}
impl Default for PrefixTree {
fn default() -> Self {
Self::new()
}
}
impl PrefixTree {
pub fn new() -> Self {
Self { root: Node::Empty }
}
pub fn root_hash(&self) -> Hash256 {
self.root.hash()
}
pub fn insert(&mut self, key: &Hash256, value_hash: Hash256) {
let key_label = Label::from_key(key);
self.root = Self::insert_rec(std::mem::take(&mut self.root), &key_label, value_hash);
}
fn insert_rec(node: Node, key: &Label, value_hash: Hash256) -> Node {
match node {
Node::Empty => {
Node::Leaf {
key: key.clone(),
value_hash,
}
}
Node::Leaf {
key: existing_key,
value_hash: existing_value,
} => {
if *key == existing_key {
Node::Leaf {
key: key.clone(),
value_hash,
}
} else {
let common_len = key.common_prefix_len(&existing_key);
let prefix = key.prefix(common_len);
let key_bit = key.bit(common_len);
let _existing_bit = existing_key.bit(common_len);
let new_leaf = Node::Leaf {
key: key.clone(),
value_hash,
};
let existing_leaf = Node::Leaf {
key: existing_key,
value_hash: existing_value,
};
let (left, right) = if key_bit == 0 {
(Box::new(new_leaf), Box::new(existing_leaf))
} else {
(Box::new(existing_leaf), Box::new(new_leaf))
};
let hash = Self::compute_internal_hash(&prefix, &left, &right);
Node::Internal {
prefix,
left,
right,
hash,
}
}
}
Node::Internal {
prefix,
left,
right,
..
} => {
let common_len = key.common_prefix_len(&prefix);
if common_len < prefix.bit_len() {
let new_prefix = prefix.prefix(common_len);
let key_bit = key.bit(common_len);
let new_leaf = Node::Leaf {
key: key.clone(),
value_hash,
};
let old_internal_hash = Self::compute_internal_hash(&prefix, &left, &right);
let old_internal = Node::Internal {
prefix: prefix.clone(),
left,
right,
hash: old_internal_hash,
};
let (new_left, new_right) = if key_bit == 0 {
(Box::new(new_leaf), Box::new(old_internal))
} else {
(Box::new(old_internal), Box::new(new_leaf))
};
let hash = Self::compute_internal_hash(&new_prefix, &new_left, &new_right);
Node::Internal {
prefix: new_prefix,
left: new_left,
right: new_right,
hash,
}
} else {
let next_bit = key.bit(prefix.bit_len());
let (new_left, new_right) = if next_bit == 0 {
let new_left = Self::insert_rec(*left, key, value_hash);
(Box::new(new_left), right)
} else {
let new_right = Self::insert_rec(*right, key, value_hash);
(left, Box::new(new_right))
};
let hash = Self::compute_internal_hash(&prefix, &new_left, &new_right);
Node::Internal {
prefix,
left: new_left,
right: new_right,
hash,
}
}
}
}
}
fn compute_internal_hash(prefix: &Label, left: &Node, right: &Node) -> Hash256 {
let mut hasher = Sha256::new();
hasher.update([0x01]); hasher.update(prefix.bit_len().to_be_bytes());
hasher.update(&prefix.bytes);
hasher.update(left.hash());
hasher.update(right.hash());
hasher.finalize().into()
}
pub fn lookup(&self, key: &Hash256) -> LookupProof {
let key_label = Label::from_key(key);
let mut proof = Vec::new();
let found = Self::lookup_rec(&self.root, &key_label, &mut proof);
LookupProof { found, proof }
}
fn lookup_rec(node: &Node, key: &Label, proof: &mut Vec<ProofNode>) -> bool {
match node {
Node::Empty => false,
Node::Leaf {
key: leaf_key,
value_hash: _,
} => *key == *leaf_key,
Node::Internal {
prefix,
left,
right,
..
} => {
if !prefix.is_prefix_of(key) {
return false;
}
let next_bit = key.bit(prefix.bit_len());
if next_bit == 0 {
if !right.is_empty() {
let right_label = prefix.append_bit(1);
proof.push(ProofNode {
label_bit_len: right_label.bit_len(),
label_path: right_label.bytes().to_vec(),
hash: right.hash(),
});
}
Self::lookup_rec(left, key, proof)
} else {
if !left.is_empty() {
let left_label = prefix.append_bit(0);
proof.push(ProofNode {
label_bit_len: left_label.bit_len(),
label_path: left_label.bytes().to_vec(),
hash: left.hash(),
});
}
Self::lookup_rec(right, key, proof)
}
}
}
}
pub fn len(&self) -> usize {
Self::count_keys(&self.root)
}
pub fn is_empty(&self) -> bool {
matches!(self.root, Node::Empty)
}
fn count_keys(node: &Node) -> usize {
match node {
Node::Empty => 0,
Node::Leaf { .. } => 1,
Node::Internal { left, right, .. } => Self::count_keys(left) + Self::count_keys(right),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key whose first byte is `n` and whose remaining bytes are zero.
    fn test_key(n: u8) -> Hash256 {
        let mut key = [0u8; 32];
        key[0] = n;
        key
    }

    #[test]
    fn test_label_basics() {
        let label = Label::empty();
        assert_eq!(label.bit_len(), 0);

        let key = test_key(0b10110000);
        let label = Label::from_key(&key);
        assert_eq!(label.bit_len(), 256);
        assert_eq!(label.bit(0), 1);
        assert_eq!(label.bit(1), 0);
        assert_eq!(label.bit(2), 1);
        assert_eq!(label.bit(3), 1);
    }

    #[test]
    fn test_label_prefix() {
        let key = test_key(0b11001010);
        let label = Label::from_key(&key);
        let prefix = label.prefix(4);
        assert_eq!(prefix.bit_len(), 4);
        assert_eq!(prefix.bit(0), 1);
        assert_eq!(prefix.bit(1), 1);
        assert_eq!(prefix.bit(2), 0);
        assert_eq!(prefix.bit(3), 0);
    }

    #[test]
    fn test_label_is_prefix_of() {
        let key1 = test_key(0b11001010);
        let label1 = Label::from_key(&key1);
        let prefix = label1.prefix(4);
        assert!(prefix.is_prefix_of(&label1));
        assert!(!label1.is_prefix_of(&prefix));

        let empty = Label::empty();
        assert!(empty.is_prefix_of(&label1));
        assert!(empty.is_prefix_of(&prefix));
    }

    #[test]
    fn test_empty_tree() {
        let tree = PrefixTree::new();
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert_eq!(tree.root_hash(), [0u8; 32]);
    }

    #[test]
    fn test_single_insert() {
        let mut tree = PrefixTree::new();
        let key = test_key(42);
        let value = [1u8; 32];
        tree.insert(&key, value);
        assert!(!tree.is_empty());
        assert_eq!(tree.len(), 1);
        assert_ne!(tree.root_hash(), [0u8; 32]);

        let result = tree.lookup(&key);
        assert!(result.found);
    }

    #[test]
    fn test_multiple_inserts() {
        let mut tree = PrefixTree::new();
        for i in 0..10u8 {
            let key = test_key(i);
            let value = [i; 32];
            tree.insert(&key, value);
        }
        assert_eq!(tree.len(), 10);

        for i in 0..10u8 {
            let key = test_key(i);
            let result = tree.lookup(&key);
            assert!(result.found, "Key {} not found", i);
        }

        let result = tree.lookup(&test_key(255));
        assert!(!result.found);
    }

    #[test]
    fn test_update_existing_key() {
        let mut tree = PrefixTree::new();
        let key = test_key(42);
        tree.insert(&key, [1u8; 32]);
        let hash1 = tree.root_hash();
        tree.insert(&key, [2u8; 32]);
        let hash2 = tree.root_hash();
        assert_ne!(hash1, hash2);
        assert_eq!(tree.len(), 1);
    }

    #[test]
    fn test_lookup_proof_structure() {
        let mut tree = PrefixTree::new();
        // The keys differ in their very first bit, so the root branches on an empty prefix.
        let key1 = test_key(0b00000000);
        let key2 = test_key(0b10000000);
        tree.insert(&key1, [1u8; 32]);
        tree.insert(&key2, [2u8; 32]);

        let result1 = tree.lookup(&key1);
        assert!(result1.found);
        assert!(!result1.proof.is_empty());
        // key1 sits on the 0-branch; its only sibling is the 1-branch.
        assert_eq!(result1.proof.len(), 1);
        assert_eq!(result1.proof[0].label_bit_len, 1);
        assert_eq!(result1.proof[0].label_path, vec![0b1000_0000u8]);

        let result2 = tree.lookup(&key2);
        assert!(result2.found);
        assert!(!result2.proof.is_empty());
        // key2 sits on the 1-branch; its only sibling is the 0-branch.
        assert_eq!(result2.proof.len(), 1);
        assert_eq!(result2.proof[0].label_bit_len, 1);
        assert_eq!(result2.proof[0].label_path, vec![0u8]);
    }

    #[test]
    fn test_deterministic_root() {
        let mut tree1 = PrefixTree::new();
        let mut tree2 = PrefixTree::new();
        for i in 0..5u8 {
            let key = test_key(i);
            let value = [i; 32];
            tree1.insert(&key, value);
            tree2.insert(&key, value);
        }
        assert_eq!(tree1.root_hash(), tree2.root_hash());
    }

    #[test]
    fn test_insert_performance() {
        use sha2::{Digest, Sha256};
        use std::time::Instant;

        let mut tree = PrefixTree::new();
        let count = 1000;
        let keys: Vec<Hash256> = (0..count)
            .map(|i| {
                let mut hasher = Sha256::new();
                hasher.update(format!("package-{}", i).as_bytes());
                hasher.finalize().into()
            })
            .collect();

        let start = Instant::now();
        for (i, key) in keys.iter().enumerate() {
            let value = [i as u8; 32];
            tree.insert(key, value);
        }
        let elapsed = start.elapsed();

        println!(
            "Inserted {} keys in {:?} ({:.2} µs/key)",
            count,
            elapsed,
            elapsed.as_micros() as f64 / count as f64
        );
        // A wall-clock limit is only meaningful for optimised code, so enforce
        // it only when debug assertions are off (e.g. `cargo test --release`).
        // Builds with debug assertions still run the workload and print the timing.
        if !cfg!(debug_assertions) {
            assert!(
                elapsed.as_millis() < 200,
                "Insert performance too slow: {:?}",
                elapsed
            );
        }
        assert_eq!(tree.len(), count);
    }
}