use crate::storage::TreeReader;
use crate::SimpleHasher;
use alloc::format;
use alloc::vec::Vec;
use alloc::{boxed::Box, vec};
use anyhow::Context;
use borsh::{BorshDeserialize, BorshSerialize};
use num_derive::{FromPrimitive, ToPrimitive};
#[cfg(any(test))]
use proptest::prelude::*;
#[cfg(any(test))]
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};
use crate::proof::SparseMerkleNode;
use crate::{
types::{
nibble::{nibble_path::NibblePath, Nibble},
proof::{SparseMerkleInternalNode, SparseMerkleLeafNode},
Version,
},
KeyHash, ValueHash, SPARSE_MERKLE_PLACEHOLDER_HASH,
};
/// Key locating a node in the tree: the version associated with the node plus the
/// nibble path from the root to it.
///
/// The derived ordering compares `version` first and `nibble_path` second
/// (declaration order).
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[derive(Serialize, Deserialize, borsh::BorshSerialize, borsh::BorshDeserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct NodeKey {
    /// Version associated with the node.
    version: Version,
    /// Nibbles leading from the root to the node; empty for the root position.
    nibble_path: NibblePath,
}
impl NodeKey {
pub fn new(version: Version, nibble_path: NibblePath) -> Self {
Self {
version,
nibble_path,
}
}
pub(crate) fn new_empty_path(version: Version) -> Self {
Self::new(version, NibblePath::new(vec![]))
}
pub fn version(&self) -> Version {
self.version
}
pub fn nibble_path(&self) -> &NibblePath {
&self.nibble_path
}
pub(crate) fn gen_child_node_key(&self, version: Version, n: Nibble) -> Self {
let mut node_nibble_path = self.nibble_path().clone();
node_nibble_path.push(n);
Self::new(version, node_nibble_path)
}
pub(crate) fn gen_parent_node_key(&self) -> Self {
let mut node_nibble_path = self.nibble_path().clone();
assert!(
node_nibble_path.pop().is_some(),
"Current node key is root.",
);
Self::new(self.version, node_nibble_path)
}
pub(crate) fn set_version(&mut self, version: Version) {
self.version = version;
}
}
/// Kind of a child node: a single leaf, or an internal node carrying the number of
/// leaves in its subtree.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(borsh::BorshSerialize, borsh::BorshDeserialize, Serialize, Deserialize)]
pub enum NodeType {
    /// A leaf node.
    Leaf,
    /// An internal node; `leaf_count` is the total number of leaves beneath it.
    Internal { leaf_count: usize },
}
#[cfg(test)]
impl Arbitrary for NodeType {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    /// Picks, with equal weight, a leaf or an internal node holding 2..100 leaves.
    fn arbitrary_with(_args: ()) -> Self::Strategy {
        let leaf = Just(NodeType::Leaf);
        let internal = (2..100usize).prop_map(|leaf_count| NodeType::Internal { leaf_count });
        prop_oneof![leaf, internal].boxed()
    }
}
/// Entry stored in one slot of an internal node: the child's hash, its version,
/// and whether it is a leaf or an internal node.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(borsh::BorshSerialize, borsh::BorshDeserialize, Serialize, Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct Child {
    /// Hash of the child node.
    pub hash: [u8; 32],
    /// Version of the child; combined with the parent key to build the child's key.
    pub version: Version,
    /// Leaf or internal (with its leaf count).
    pub node_type: NodeType,
}
impl Child {
pub fn new(hash: [u8; 32], version: Version, node_type: NodeType) -> Self {
Self {
hash,
version,
node_type,
}
}
pub fn is_leaf(&self) -> bool {
matches!(self.node_type, NodeType::Leaf)
}
pub fn leaf_count(&self) -> usize {
match self.node_type {
NodeType::Leaf => 1,
NodeType::Internal { leaf_count } => leaf_count,
}
}
}
/// Sixteen child slots indexed by nibble, plus a cached count of occupied slots.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[derive(borsh::BorshSerialize, borsh::BorshDeserialize, Serialize, Deserialize)]
pub struct Children {
    /// Slot `i` holds the child at nibble `i`, if any.
    children: Box<[Option<Child>; 16]>,
    /// Number of `Some` slots; kept in sync by `insert` and `remove`.
    num_children: usize,
}
#[cfg(test)]
impl Arbitrary for Children {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    /// Generates an arbitrary slot table and derives `num_children` from it, so the
    /// cached count always matches the number of occupied slots.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        any::<Box<[Option<Child>; 16]>>()
            .prop_map(|children| {
                let num_children = children.iter().flatten().count();
                Children {
                    children,
                    num_children,
                }
            })
            .boxed()
    }
}
impl Children {
    /// Creates a table with every slot empty.
    pub fn new() -> Self {
        Default::default()
    }

    /// Stores `child` at `nibble`, replacing any child already there.
    ///
    /// The cached count only grows when the slot was previously empty.
    pub fn insert(&mut self, nibble: Nibble, child: Child) {
        // `replace` returns the previous occupant, so one slot access covers both
        // the emptiness check and the write.
        if self.children[nibble.as_usize()].replace(child).is_none() {
            self.num_children += 1;
        }
    }

    /// The slot at `nibble` (`None` if empty).
    pub fn get(&self, nibble: Nibble) -> &Option<Child> {
        &self.children[nibble.as_usize()]
    }

    /// `true` when no slot is occupied.
    pub fn is_empty(&self) -> bool {
        self.num_children == 0
    }

    /// Empties the slot at `nibble`; a no-op if it is already empty.
    pub fn remove(&mut self, nibble: Nibble) {
        // `take` leaves `None` behind and reports whether something was removed.
        if self.children[nibble.as_usize()].take().is_some() {
            self.num_children -= 1;
        }
    }

    /// Present children in nibble order, without their nibbles.
    pub fn values(&self) -> impl Iterator<Item = &Child> {
        self.children.iter().flatten()
    }

    /// Present children with their nibbles, in ascending nibble order.
    pub fn iter(&self) -> impl Iterator<Item = (Nibble, &Child)> {
        self.iter_sorted()
    }

    /// Mutable access to present children with their nibbles, in ascending nibble order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (Nibble, &mut Child)> {
        self.children
            .iter_mut()
            .enumerate()
            .filter_map(|(index, slot)| {
                slot.as_mut()
                    .map(|child| (Nibble::from(index as u8), child))
            })
    }

    /// Number of occupied slots.
    pub fn num_children(&self) -> usize {
        self.num_children
    }

    /// Present children with their nibbles, in ascending nibble order.
    pub fn iter_sorted(&self) -> impl Iterator<Item = (Nibble, &Child)> {
        self.children
            .iter()
            .enumerate()
            .filter_map(|(index, slot)| {
                slot.as_ref()
                    .map(|child| (Nibble::from(index as u8), child))
            })
    }
}
/// Internal node: up to sixteen children keyed by nibble, plus the cached total
/// number of leaves in its subtree.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, borsh::BorshSerialize, borsh::BorshDeserialize)]
pub struct InternalNode {
    /// Child slots.
    children: Children,
    /// Sum of the children's leaf counts, computed in `InternalNode::new`.
    leaf_count: usize,
}
impl SparseMerkleInternalNode {
    /// Collapses a 16-slot internal node into a binary node.
    ///
    /// The left child is the merkle hash of slots 0..8 and the right child the
    /// merkle hash of slots 8..16.
    fn from<H: SimpleHasher>(internal_node: InternalNode) -> Self {
        let bitmaps = internal_node.generate_bitmaps();
        let (left_start, right_start, half_width) = (0, 8, 8);
        let left = internal_node.merkle_hash::<H>(left_start, half_width, bitmaps);
        let right = internal_node.merkle_hash::<H>(right_start, half_width, bitmaps);
        Self::new(left, right)
    }
}
#[cfg(test)]
impl Arbitrary for InternalNode {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    /// Generates internal nodes from arbitrary children.
    ///
    /// Skips the inputs `InternalNode::new` rejects: children with no occupied slot
    /// (each slot is independently optional, so an all-empty table can be generated)
    /// and children whose only occupant is a leaf.
    fn arbitrary_with(_args: ()) -> Self::Strategy {
        any::<Children>()
            .prop_filter(
                "InternalNode constructor panics when it has no children.",
                |children| !children.is_empty(),
            )
            .prop_filter(
                "InternalNode constructor panics when its only child is a leaf.",
                |children| {
                    !(children.num_children() == 1
                        && children.values().next().expect("Must exist.").is_leaf())
                },
            )
            .prop_map(InternalNode::new)
            .boxed()
    }
}
/// Decides whether a bitmap range collapses to a single node.
///
/// True when the range is a single slot (`width == 1`), or when it contains exactly
/// one child and the range's leaf bitmap is non-zero. Assuming leaf bits are a subset
/// of existence bits, the latter means the lone child is a leaf.
fn has_only_child(width: u8, range_existence_bitmap: u16, range_leaf_bitmap: u16) -> bool {
    if width == 1 {
        return true;
    }
    let single_child = range_existence_bitmap.count_ones() == 1;
    single_child && range_leaf_bitmap != 0
}
/// Decides whether a bitmap range resolves to the child at the queried nibble.
///
/// True when the range is a single slot (`width == 1`), or when the only occupied
/// slot in the range is the queried one (`range_existence_bitmap == n_bitmap`) and
/// the range's leaf bitmap is non-zero.
fn has_child(
    width: u8,
    range_existence_bitmap: u16,
    n_bitmap: u16,
    range_leaf_bitmap: u16,
) -> bool {
    match width {
        1 => true,
        _ => range_leaf_bitmap != 0 && range_existence_bitmap == n_bitmap,
    }
}
impl InternalNode {
pub fn new(children: Children) -> Self {
assert!(!children.is_empty(), "Children must not be empty");
if children.num_children() == 1 {
assert!(
!children
.values()
.next()
.expect("Must have 1 element")
.is_leaf(),
"If there's only one child, it must not be a leaf."
);
}
let leaf_count = Self::sum_leaf_count(&children);
Self {
children,
leaf_count,
}
}
fn sum_leaf_count(children: &Children) -> usize {
let mut leaf_count = 0;
for child in children.values() {
let n = child.leaf_count();
leaf_count += n;
}
leaf_count
}
pub fn leaf_count(&self) -> usize {
self.leaf_count
}
pub fn node_type(&self) -> NodeType {
NodeType::Internal {
leaf_count: self.leaf_count,
}
}
pub fn hash<H: SimpleHasher>(&self) -> [u8; 32] {
self.merkle_hash::<H>(
0,
16,
self.generate_bitmaps(),
)
}
pub fn children_sorted(&self) -> impl Iterator<Item = (Nibble, &Child)> {
self.children.iter_sorted()
}
pub fn children_unsorted(&self) -> impl Iterator<Item = (Nibble, &Child)> {
self.children.iter()
}
pub fn child(&self, n: Nibble) -> Option<&Child> {
self.children.get(n).as_ref()
}
pub fn generate_bitmaps(&self) -> (u16, u16) {
let mut existence_bitmap = 0;
let mut leaf_bitmap = 0;
for (nibble, child) in self.children.iter() {
let i = u8::from(nibble);
existence_bitmap |= 1u16 << i;
if child.is_leaf() {
leaf_bitmap |= 1u16 << i;
}
}
assert_eq!(existence_bitmap | leaf_bitmap, existence_bitmap);
(existence_bitmap, leaf_bitmap)
}
fn range_bitmaps(start: u8, width: u8, bitmaps: (u16, u16)) -> (u16, u16) {
assert!(start < 16 && width.count_ones() == 1 && start % width == 0);
assert!(width <= 16 && (start + width) <= 16);
let mask = (((1u32 << width) - 1) << start) as u16;
(bitmaps.0 & mask, bitmaps.1 & mask)
}
fn build_sibling<H: SimpleHasher>(
&self,
tree_reader: &impl TreeReader,
node_key: &NodeKey,
start: u8,
width: u8,
(existence_bitmap, leaf_bitmap): (u16, u16),
) -> SparseMerkleNode {
let (range_existence_bitmap, range_leaf_bitmap) =
Self::range_bitmaps(start, width, (existence_bitmap, leaf_bitmap));
if range_existence_bitmap == 0 {
SparseMerkleNode::Null
} else if has_only_child(width, range_existence_bitmap, range_leaf_bitmap) {
let only_child_index = Nibble::from(range_existence_bitmap.trailing_zeros() as u8);
let child = self
.child(only_child_index)
.with_context(|| {
format!(
"Corrupted internal node: existence_bitmap indicates \
the existence of a non-exist child at index {:x}",
only_child_index
)
})
.unwrap();
let child_node = tree_reader
.get_node(&node_key.gen_child_node_key(child.version, only_child_index))
.with_context(|| {
format!(
"Corruption error: the merkle tree reader supplied cannot find \
the child of version {:?} at index {:x}.",
child.version, only_child_index
)
})
.unwrap();
match child_node {
Node::Internal(node) => {
SparseMerkleNode::Internal(SparseMerkleInternalNode::from::<H>(node))
}
Node::Leaf(node) => SparseMerkleNode::Leaf(SparseMerkleLeafNode::from(node)),
Node::Null => unreachable!("Impossible to get a null node at this location"),
}
} else {
let left_child = self.merkle_hash::<H>(
start,
width / 2,
(range_existence_bitmap, range_leaf_bitmap),
);
let right_child = self.merkle_hash::<H>(
start + width / 2,
width / 2,
(range_existence_bitmap, range_leaf_bitmap),
);
SparseMerkleNode::Internal(SparseMerkleInternalNode::new(left_child, right_child))
}
}
fn merkle_hash<H: SimpleHasher>(
&self,
start: u8,
width: u8,
(existence_bitmap, leaf_bitmap): (u16, u16),
) -> [u8; 32] {
let (range_existence_bitmap, range_leaf_bitmap) =
Self::range_bitmaps(start, width, (existence_bitmap, leaf_bitmap));
if range_existence_bitmap == 0 {
SPARSE_MERKLE_PLACEHOLDER_HASH
} else if has_only_child(width, range_existence_bitmap, range_leaf_bitmap) {
let only_child_index = Nibble::from(range_existence_bitmap.trailing_zeros() as u8);
self.child(only_child_index)
.with_context(|| {
format!(
"Corrupted internal node: existence_bitmap indicates \
the existence of a non-exist child at index {:x}",
only_child_index
)
})
.unwrap()
.hash
} else {
let left_child = self.merkle_hash::<H>(
start,
width / 2,
(range_existence_bitmap, range_leaf_bitmap),
);
let right_child = self.merkle_hash::<H>(
start + width / 2,
width / 2,
(range_existence_bitmap, range_leaf_bitmap),
);
SparseMerkleInternalNode::new(left_child, right_child).hash::<H>()
}
}
pub fn get_only_child_without_siblings(
&self,
node_key: &NodeKey,
n: Nibble,
) -> Option<NodeKey> {
let (existence_bitmap, leaf_bitmap) = self.generate_bitmaps();
for h in (0..4).rev() {
let width = 1 << h;
let child_half_start = get_child_half_start(n, h);
let (range_existence_bitmap, range_leaf_bitmap) =
Self::range_bitmaps(child_half_start, width, (existence_bitmap, leaf_bitmap));
if range_existence_bitmap == 0 {
return None;
} else if has_only_child(width, range_existence_bitmap, range_leaf_bitmap) {
let only_child_index = Nibble::from(range_existence_bitmap.trailing_zeros() as u8);
let only_child_version = self
.child(only_child_index)
.with_context(|| {
format!(
"Corrupted internal node: child_bitmap indicates \
the existence of a non-exist child at index {:x}",
only_child_index
)
})
.unwrap()
.version;
return Some(node_key.gen_child_node_key(only_child_version, only_child_index));
}
}
unreachable!("Impossible to get here without returning even at the lowest level.")
}
fn get_child_with_siblings_helper<H: SimpleHasher>(
&self,
tree_reader: &impl TreeReader,
node_key: &NodeKey,
n: Nibble,
get_only_child: bool,
) -> (Option<NodeKey>, Vec<SparseMerkleNode>) {
let mut siblings: Vec<SparseMerkleNode> = vec![];
let (existence_bitmap, leaf_bitmap) = self.generate_bitmaps();
let n_bitmap = 1 << n.as_usize();
for h in (0..4).rev() {
let width = 1 << h;
let (child_half_start, sibling_half_start) = get_child_and_sibling_half_start(n, h);
siblings.push(self.build_sibling::<H>(
tree_reader,
node_key,
sibling_half_start,
width,
(existence_bitmap, leaf_bitmap),
));
let (range_existence_bitmap, range_leaf_bitmap) =
Self::range_bitmaps(child_half_start, width, (existence_bitmap, leaf_bitmap));
if range_existence_bitmap == 0 {
return (None, siblings);
} else if get_only_child
&& (has_only_child(width, range_existence_bitmap, range_leaf_bitmap))
{
let only_child_index = Nibble::from(range_existence_bitmap.trailing_zeros() as u8);
return (
{
let only_child_version = self
.child(only_child_index)
.with_context(|| {
format!(
"Corrupted internal node: child_bitmap indicates \
the existence of a non-exist child at index {:x}",
only_child_index
)
})
.unwrap()
.version;
Some(node_key.gen_child_node_key(only_child_version, only_child_index))
},
siblings,
);
} else if !get_only_child
&& (has_child(width, range_existence_bitmap, n_bitmap, range_leaf_bitmap))
{
return (
{
let only_child_version = self
.child(n)
.with_context(|| {
format!(
"Corrupted internal node: child_bitmap indicates \
the existence of a non-exist child at index {:x}",
n
)
})
.unwrap()
.version;
Some(node_key.gen_child_node_key(only_child_version, n))
},
siblings,
);
}
}
unreachable!("Impossible to get here without returning even at the lowest level.")
}
pub(crate) fn get_child_with_siblings<H: SimpleHasher>(
&self,
tree_cache: &impl TreeReader,
node_key: &NodeKey,
n: Nibble,
) -> (Option<NodeKey>, Vec<SparseMerkleNode>) {
self.get_child_with_siblings_helper::<H>(tree_cache, node_key, n, false)
}
pub(crate) fn get_only_child_with_siblings<H: SimpleHasher>(
&self,
tree_reader: &impl TreeReader,
node_key: &NodeKey,
n: Nibble,
) -> (Option<NodeKey>, Vec<SparseMerkleNode>) {
self.get_child_with_siblings_helper::<H>(tree_reader, node_key, n, true)
}
#[cfg(test)]
pub(crate) fn children(&self) -> &Children {
&self.children
}
}
pub(crate) fn get_child_and_sibling_half_start(n: Nibble, height: u8) -> (u8, u8) {
let child_half_start = (0xff << height) & u8::from(n);
let sibling_half_start = child_half_start ^ (1 << height);
(child_half_start, sibling_half_start)
}
/// Start offset of the aligned `2^height`-slot range that contains nibble `n`,
/// i.e. `n` with its low `height` bits cleared.
pub(crate) fn get_child_half_start(n: Nibble, height: u8) -> u8 {
    let high_bits_mask: u8 = 0xff << height;
    u8::from(n) & high_bits_mask
}
/// Leaf node: the hash of a key and the hash of its value.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, borsh::BorshSerialize, borsh::BorshDeserialize)]
pub struct LeafNode {
    /// Hash of the key stored at this leaf.
    key_hash: KeyHash,
    /// Hash of the value stored at this leaf.
    value_hash: ValueHash,
}
impl LeafNode {
pub fn new(key_hash: KeyHash, value_hash: ValueHash) -> Self {
Self {
key_hash,
value_hash,
}
}
pub fn key_hash(&self) -> KeyHash {
self.key_hash
}
pub(crate) fn value_hash(&self) -> ValueHash {
self.value_hash
}
pub fn hash<H: SimpleHasher>(&self) -> [u8; 32] {
SparseMerkleLeafNode::new(self.key_hash, self.value_hash).hash::<H>()
}
}
impl From<LeafNode> for SparseMerkleLeafNode {
fn from(leaf_node: LeafNode) -> Self {
Self::new(leaf_node.key_hash, leaf_node.value_hash)
}
}
/// Numeric tags named after the three `Node` variants.
///
/// NOTE(review): no use of `NodeTag` is visible in this chunk; presumably it tags
/// `Node` variants in some encoding elsewhere. Confirm before changing any value.
#[repr(u8)]
#[derive(FromPrimitive, ToPrimitive, BorshDeserialize, BorshSerialize)]
// NOTE(review): with `use_discriminant = false`, borsh encodes variants by declaration
// order rather than by the explicit discriminants. Here the two coincide (0, 1, 2).
#[borsh(use_discriminant = false)]
enum NodeTag {
Null = 0,
Leaf = 1,
Internal = 2,
}
/// A tree node: empty, internal, or leaf.
///
/// Variant order is significant for the derived borsh encoding; do not reorder.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(BorshSerialize, BorshDeserialize, Serialize, Deserialize)]
pub enum Node {
    /// Empty node; hashes to `SPARSE_MERKLE_PLACEHOLDER_HASH`.
    Null,
    /// Internal node with up to sixteen children.
    Internal(InternalNode),
    /// Leaf holding a key hash and a value hash.
    Leaf(LeafNode),
}
impl From<InternalNode> for Node {
fn from(node: InternalNode) -> Self {
Node::Internal(node)
}
}
impl From<InternalNode> for Children {
fn from(node: InternalNode) -> Self {
node.children
}
}
impl From<LeafNode> for Node {
fn from(node: LeafNode) -> Self {
Node::Leaf(node)
}
}
impl Node {
    /// The empty node.
    pub(crate) fn new_null() -> Self {
        Self::Null
    }

    /// Test-only: an internal node built from `children` (same panics as `InternalNode::new`).
    #[cfg(test)]
    pub(crate) fn new_internal(children: Children) -> Self {
        Self::Internal(InternalNode::new(children))
    }

    /// A leaf holding the given key hash and value hash.
    pub(crate) fn new_leaf(key_hash: KeyHash, value_hash: ValueHash) -> Self {
        Self::Leaf(LeafNode::new(key_hash, value_hash))
    }

    /// Test-only: a leaf whose value hash is computed from `value` with hasher `H`.
    #[cfg(test)]
    pub(crate) fn leaf_from_value<H: SimpleHasher>(
        key_hash: KeyHash,
        value: impl AsRef<[u8]>,
    ) -> Self {
        let value_hash = ValueHash::with::<H>(value);
        Self::Leaf(LeafNode::new(key_hash, value_hash))
    }

    /// `true` for `Node::Leaf`.
    pub(crate) fn is_leaf(&self) -> bool {
        matches!(*self, Self::Leaf(..))
    }

    /// Node type of a leaf or internal node.
    ///
    /// # Panics
    ///
    /// Panics on `Node::Null`.
    pub(crate) fn node_type(&self) -> NodeType {
        match self {
            Node::Leaf(_) => NodeType::Leaf,
            Node::Internal(internal) => internal.node_type(),
            Node::Null => unreachable!(),
        }
    }

    /// Leaves under this node: 0 for null, 1 for a leaf, the cached count otherwise.
    pub(crate) fn leaf_count(&self) -> usize {
        match self {
            Node::Internal(internal) => internal.leaf_count(),
            Node::Leaf(_) => 1,
            Node::Null => 0,
        }
    }

    /// Hash of this node; the null node hashes to `SPARSE_MERKLE_PLACEHOLDER_HASH`.
    pub(crate) fn hash<H: SimpleHasher>(&self) -> [u8; 32] {
        match self {
            Node::Internal(internal) => internal.hash::<H>(),
            Node::Leaf(leaf) => leaf.hash::<H>(),
            Node::Null => SPARSE_MERKLE_PLACEHOLDER_HASH,
        }
    }
}