use crate as storage;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::hash::Hash;
use std::io::{Read, Write};
use std::iter::{empty, once};
use std::ops::Deref;
use crate::DefaultDB;
use crate::Storable;
use crate::arena::{ArenaKey, Sp};
use crate::db::DB;
use crate::storable::{Loader, SizeAnn};
use derive_where::derive_where;
use serialize::{self, Deserializable, Serializable, Tagged, tag_enforcement_test};
/// A Merkle Patricia trie mapping nibble paths (one nibble per `u8`) to values of type `V`.
///
/// Updates (`insert`, `remove`, `prune`) take `&self` and return a new trie, leaving the
/// original untouched. Every node carries an annotation `A` summarising its subtree; the
/// default, [`SizeAnn`], counts the values stored below the node. When loaded from storage,
/// the annotations are checked by [`MerklePatriciaTrie::invariant`].
#[derive_where(Debug, Eq, Clone, PartialEq; V, A)]
#[derive(Storable)]
// A single `storable` attribute: a second, partial copy only risks the derive ignoring the
// invariant hook.
#[storable(db = D, invariant = MerklePatriciaTrie::invariant)]
pub struct MerklePatriciaTrie<
    V: Storable<D>,
    D: DB = DefaultDB,
    A: Storable<D> + Annotation<V> = SizeAnn,
>(
    // Root node of the trie; visibility depends on the `public-internal-structure` feature.
    #[cfg(feature = "public-internal-structure")] pub Sp<Node<V, D, A>, D>,
    #[cfg(not(feature = "public-internal-structure"))] pub(crate) Sp<Node<V, D, A>, D>,
);
impl<V: Storable<D> + Tagged, D: DB, A: Storable<D> + Annotation<V> + Tagged> Tagged
    for MerklePatriciaTrie<V, D, A>
{
    /// Serialization tag of the form `mpt(<value tag>,<annotation tag>)`.
    fn tag() -> std::borrow::Cow<'static, str> {
        let value_tag = V::tag();
        let ann_tag = A::tag();
        std::borrow::Cow::Owned(format!("mpt({value_tag},{ann_tag})"))
    }

    /// The trie only wraps its root node, so it reuses the node's unique factor.
    fn tag_unique_factor() -> String {
        <Node<V, D, A> as Tagged>::tag_unique_factor()
    }
}
tag_enforcement_test!(MerklePatriciaTrie<()>);
impl<V: Storable<D>, D: DB, A: Storable<D> + Annotation<V>> Default
for MerklePatriciaTrie<V, D, A>
{
fn default() -> Self {
Self::new()
}
}
impl<V: Storable<D>, D: DB, A: Storable<D> + Annotation<V>> MerklePatriciaTrie<V, D, A> {
    /// Creates an empty trie.
    pub fn new() -> Self {
        MerklePatriciaTrie(Sp::new(Node::Empty))
    }

    /// Storage invariant: every node's recorded annotation must equal the annotation
    /// recomputed bottom-up from its subtree.
    ///
    /// A leaf contributes `A::from_value(value)`. A branch contributes the `append` of its
    /// children in index order. An extension contributes its child's annotation. A mid-branch
    /// leaf contributes `child.append(from_value(value))`.
    ///
    /// # Errors
    /// Returns `InvalidData` describing the first mismatching annotation found.
    fn invariant(&self) -> Result<(), std::io::Error> {
        // Recomputes the annotation of `node`, checking every recorded annotation on the way.
        // Nodes are visited by reference; nothing is cloned.
        fn sum_ann<V: Storable<D>, D: DB, A: Storable<D> + Annotation<V>>(
            node: &Node<V, D, A>,
        ) -> Result<A, std::io::Error> {
            let (recorded, computed) = match node {
                Node::Empty => return Ok(A::empty()),
                Node::Leaf { ann, value } => (ann, A::from_value(value)),
                Node::Branch { ann, children } => {
                    let mut total = A::empty();
                    for child in children.iter() {
                        total = total.append(&sum_ann(child.deref())?);
                    }
                    (ann, total)
                }
                Node::Extension { ann, child, .. } => (ann, sum_ann(child.deref())?),
                Node::MidBranchLeaf { ann, value, child } => {
                    (ann, sum_ann(child.deref())?.append(&A::from_value(value)))
                }
            };
            if recorded != &computed {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!(
                        "MPT annotation isn't correctly calculated: Annotation {:?}, correct calculation: {:?}",
                        recorded, computed
                    ),
                ));
            }
            Ok(computed)
        }
        sum_ann(self.0.deref()).map(|_| ())
    }

    /// Returns a new trie with `value` stored at `path`, replacing any previous value there.
    pub fn insert(&self, path: &[u8], value: V) -> Self {
        MerklePatriciaTrie(Node::<V, D, A>::insert(&self.0, path, value).0)
    }

    /// Returns a reference to the value stored at exactly `path`, if any.
    pub fn lookup(&self, path: &[u8]) -> Option<&V> {
        Node::<V, D, A>::lookup(&self.0, path)
    }

    /// Returns a trie without any entry whose key is strictly less than `target_path`,
    /// together with the pointers of the removed values.
    pub(crate) fn prune(&self, target_path: &[u8]) -> (Self, Vec<Sp<V, D>>) {
        let (node, pruned) = Node::<V, D, A>::prune(&self.0, target_path);
        (MerklePatriciaTrie(node), pruned)
    }

    /// Like [`MerklePatriciaTrie::lookup`], but returns a clone of the value's storage pointer.
    pub fn lookup_sp(&self, path: &[u8]) -> Option<Sp<V, D>> {
        Node::<V, D, A>::lookup_sp(&self.0, path)
    }

    /// Finds the entry with the greatest key strictly smaller than `path`, if any.
    pub(crate) fn find_predecessor<'a>(&'a self, path: &[u8]) -> Option<(Vec<u8>, &'a V)> {
        let mut best_predecessor = None;
        Node::<V, D, A>::find_predecessor_recursive(
            &self.0,
            path,
            &mut std::vec::Vec::new(),
            &mut best_predecessor,
        );
        best_predecessor
    }

    /// Returns a new trie without the value stored at exactly `path`. A missing path leaves
    /// the contents unchanged.
    pub fn remove(&self, path: &[u8]) -> Self {
        MerklePatriciaTrie(Node::<V, D, A>::remove(&self.0, path).0)
    }

    /// Consumes the trie, yielding the values that `Sp::into_inner` hands back while
    /// unwrapping the node pointers.
    pub fn into_inner_for_drop(self) -> impl Iterator<Item = V> {
        Sp::into_inner(self.0)
            .into_iter()
            .flat_map(Node::into_inner_for_drop)
    }

    /// Iterates over all `(path, value pointer)` pairs in the trie.
    pub fn iter(&self) -> MPTIter<'_, V, D> {
        MPTIter(Node::<V, D, A>::leaves(&self.0, vec![]))
    }

    /// Number of stored values, as recorded by the root annotation's size.
    pub fn size(&self) -> usize {
        Node::<V, D, A>::size(&self.0)
    }

    /// Returns `true` if the trie stores no values.
    pub fn is_empty(&self) -> bool {
        matches!(self.0.deref(), Node::Empty)
    }

    /// Annotation summarising the whole trie (`A::empty()` for an empty trie).
    pub fn ann(&self) -> A {
        Node::<V, D, A>::ann(&self.0)
    }
}
impl<V: Storable<D>, D: DB, A: Storable<D> + Annotation<V>> Hash for MerklePatriciaTrie<V, D, A> {
    /// Hashes the trie exactly as its root pointer hashes.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Fully qualified: `Sp` also has an inherent `hash()` that would shadow the trait method.
        Hash::hash(&self.0, state);
    }

    /// Hashes a slice of tries as the slice of their root pointers.
    fn hash_slice<H: std::hash::Hasher>(data: &[Self], state: &mut H)
    where
        Self: Sized,
    {
        let roots: Vec<Sp<Node<V, D, A>, D>> = data.iter().map(|trie| trie.0.clone()).collect();
        <Sp<Node<V, D, A>, D> as Hash>::hash_slice(&roots[..], state);
    }
}
impl<V: Storable<D> + Ord, D: DB, A: Storable<D> + Ord + Annotation<V>> Ord
    for MerklePatriciaTrie<V, D, A>
{
    /// Tries are ordered by their root pointers.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.0, &other.0)
    }
}
impl<V: Storable<D> + PartialOrd, D: DB, A: Storable<D> + PartialOrd + Annotation<V>> PartialOrd
    for MerklePatriciaTrie<V, D, A>
{
    /// Partial ordering delegated to the root pointers.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.0, &other.0)
    }
}
/// Iterator over the `(path, value pointer)` pairs of a trie, as produced by
/// `MerklePatriciaTrie::iter`.
pub struct MPTIter<'a, T: Storable<D> + 'static, D: DB>(
    Box<dyn Iterator<Item = (std::vec::Vec<u8>, Sp<T, D>)> + 'a>,
);

impl<T: Storable<D>, D: DB> Iterator for MPTIter<'_, T, D> {
    type Item = (std::vec::Vec<u8>, Sp<T, D>);

    /// Advances the boxed inner iterator.
    fn next(&mut self) -> Option<Self::Item> {
        let MPTIter(inner) = self;
        inner.next()
    }
}
/// A node of a [`MerklePatriciaTrie`] (variant compiled with `public-internal-structure`).
///
/// Paths are sequences of nibbles, one nibble per `u8`. Every non-empty variant carries an
/// annotation `ann` summarising the values stored in its subtree.
#[derive(Debug, Default)]
#[derive_where(Clone, Hash, PartialEq, Eq; T, A)]
#[derive_where(PartialOrd; T: PartialOrd, A: PartialOrd)]
#[derive_where(Ord; T: Ord, A: Ord)]
#[allow(clippy::type_complexity)]
#[cfg(feature = "public-internal-structure")]
pub enum Node<T: Storable<D> + 'static, D: DB = DefaultDB, A: Storable<D> + Annotation<T> = SizeAnn>
{
    /// Subtree holding no values.
    #[default]
    Empty,
    /// A single value stored at the path leading to this node.
    Leaf {
        ann: A,
        value: Sp<T, D>,
    },
    /// 16-way fan-out: `children[n]` holds the keys whose next nibble is `n`.
    /// Stored branches are required to have at least two non-empty children.
    Branch {
        ann: A,
        children: Box<[Sp<Node<T, D, A>, D>; 16]>,
    },
    /// A run of nibbles shared by every key below `child`. Despite the name,
    /// `compressed_path` holds one nibble per byte in memory (at most 255 of them); it is
    /// only packed two-per-byte when serialized.
    Extension {
        ann: A,
        compressed_path: std::vec::Vec<u8>, child: Sp<Node<T, D, A>, D>,
    },
    /// A value stored at this node's path plus a `child` (a Branch or Extension) that holds
    /// the longer keys continuing past it.
    MidBranchLeaf {
        ann: A,
        value: Sp<T, D>,
        child: Sp<Node<T, D, A>, D>, },
}
/// A node of a [`MerklePatriciaTrie`] (crate-private variant, compiled without
/// `public-internal-structure`).
///
/// Paths are sequences of nibbles, one nibble per `u8`. Every non-empty variant carries an
/// annotation `ann` summarising the values stored in its subtree.
#[derive(Debug, Default)]
#[derive_where(Clone, Hash, PartialEq, Eq; T, A)]
#[derive_where(PartialOrd; T: PartialOrd, A: PartialOrd)]
#[derive_where(Ord; T: Ord, A: Ord)]
#[allow(clippy::type_complexity)]
#[cfg(not(feature = "public-internal-structure"))]
pub(crate) enum Node<
    T: Storable<D> + 'static,
    D: DB = DefaultDB,
    A: Storable<D> + Annotation<T> = SizeAnn,
> {
    /// Subtree holding no values.
    #[default]
    Empty,
    /// A single value stored at the path leading to this node.
    Leaf {
        ann: A,
        value: Sp<T, D>,
    },
    /// 16-way fan-out: `children[n]` holds the keys whose next nibble is `n`.
    /// Stored branches are required to have at least two non-empty children.
    Branch {
        ann: A,
        children: Box<[Sp<Node<T, D, A>, D>; 16]>,
    },
    /// A run of nibbles shared by every key below `child`. Despite the name,
    /// `compressed_path` holds one nibble per byte in memory (at most 255 of them); it is
    /// only packed two-per-byte when serialized.
    Extension {
        ann: A,
        compressed_path: std::vec::Vec<u8>, child: Sp<Node<T, D, A>, D>,
    },
    /// A value stored at this node's path plus a `child` (a Branch or Extension) that holds
    /// the longer keys continuing past it.
    MidBranchLeaf {
        ann: A,
        value: Sp<T, D>,
        child: Sp<Node<T, D, A>, D>, },
}
impl<T: Storable<D> + Tagged + 'static, D: DB, A: Storable<D> + Annotation<T> + Tagged> Tagged
    for Node<T, D, A>
{
    /// Serialization tag of the form `mpt-node(<value tag>,<annotation tag>)`.
    fn tag() -> std::borrow::Cow<'static, str> {
        let params = format!("{},{}", T::tag(), A::tag());
        std::borrow::Cow::Owned(format!("mpt-node({params})"))
    }

    /// Structural description of the five variants, in discriminant order.
    ///
    /// NOTE(review): the recursive child reference is spelled `mpt-node(<ann>,<value>)`,
    /// while `tag()` puts the value tag first. This string is part of the tag-enforcement
    /// contract, so it is kept byte-for-byte; confirm whether the ordering is intentional.
    fn tag_unique_factor() -> String {
        let a = A::tag();
        let t = T::tag();
        let child_ref = format!("mpt-node({a},{t})");
        let variants = [
            "()".to_string(),
            format!("({a},{t})"),
            format!("({a},array({child_ref},16))"),
            format!("({a},vec(u8),{child_ref})"),
            format!("({a},{t},{child_ref})"),
        ];
        format!("[{}]", variants.join(","))
    }
}
tag_enforcement_test!(Node<(), DefaultDB, SizeAnn>);
/// A type with a binary combining operation, expected to be associative.
pub trait Semigroup {
    /// Combines `self` with `other`.
    fn append(&self, other: &Self) -> Self;
}

/// A [`Semigroup`] with an identity element for `append`.
pub trait Monoid: Semigroup {
    /// The identity element.
    fn empty() -> Self;
}

impl Semigroup for () {
    fn append(&self, _other: &Self) -> Self {}
}

impl Monoid for () {
    fn empty() -> Self {}
}

/// Implements the additive monoid (identity `0`, saturating `+`) for each listed integer type.
macro_rules! impl_saturating_sum_monoid {
    ($($int:ty),* $(,)?) => {
        $(
            impl Semigroup for $int {
                fn append(&self, other: &Self) -> Self {
                    <$int>::saturating_add(*self, *other)
                }
            }

            impl Monoid for $int {
                fn empty() -> Self {
                    0
                }
            }
        )*
    };
}

impl_saturating_sum_monoid!(u64, u128, i64, i128);
/// Packs nibbles two per byte, high nibble first.
///
/// An odd trailing nibble fills the high half of the last byte and leaves the low half zero.
/// Only the low four bits of each input are meaningful; higher bits are shifted out (for even
/// positions) or OR-ed in (for odd positions), exactly as a bitwise pack would.
fn compress_nibbles(nibbles: &[u8]) -> std::vec::Vec<u8> {
    nibbles
        .chunks(2)
        .map(|pair| {
            let high = pair[0] << 4;
            let low = pair.get(1).copied().unwrap_or(0);
            high | low
        })
        .collect()
}
/// Unpacks the first `len` nibbles from `compressed` (two per byte, high nibble first).
///
/// # Panics
/// Panics (index out of bounds) if `compressed` holds fewer than `len` nibbles.
fn expand_nibbles(compressed: &[u8], len: usize) -> std::vec::Vec<u8> {
    (0..len)
        .map(|idx| {
            let byte = compressed[idx / 2];
            if idx % 2 == 0 { byte >> 4 } else { byte & 0x0f }
        })
        .collect()
}
impl<T: Storable<D>, D: DB, A: Storable<D> + Annotation<T>> Node<T, D, A> {
    /// Consumes the node, yielding the values `Sp::into_inner` hands back for this node and,
    /// recursively, for its descendants. A mid-branch leaf yields its own value before those
    /// of its child.
    fn into_inner_for_drop(self) -> impl Iterator<Item = T> {
        let values: Box<dyn Iterator<Item = T>> = match self {
            Node::Empty => Box::new(empty()),
            Node::Leaf { value, .. } => Box::new(Sp::into_inner(value).into_iter()),
            Node::Extension { child, .. } => {
                let below = Sp::into_inner(child)
                    .into_iter()
                    .flat_map(Node::into_inner_for_drop);
                Box::new(below)
            }
            Node::MidBranchLeaf { value, child, .. } => {
                let own = Sp::into_inner(value).into_iter();
                let below = Sp::into_inner(child)
                    .into_iter()
                    .flat_map(Node::into_inner_for_drop);
                Box::new(own.chain(below))
            }
            Node::Branch { children, .. } => {
                let below = children
                    .into_iter()
                    .flat_map(Sp::into_inner)
                    .flat_map(Node::into_inner_for_drop);
                Box::new(below)
            }
        };
        values
    }
}
/// Builds the canonical chain of `Node::Extension`s leading through `path` to `child`.
///
/// While the accumulated path length is not a multiple of 255, any extension at the tail is
/// absorbed into `path`. The combined path is then split into 255-nibble chunks from the
/// front, each becoming one extension node, with the first chunk outermost. An empty `path`
/// returns `child` unchanged.
fn extension<T: Storable<D>, D: DB, A: Storable<D> + Annotation<T>>(
    mut path: Vec<u8>,
    child: Sp<Node<T, D, A>, D>,
) -> Sp<Node<T, D, A>, D> {
    let mut tail = child;
    loop {
        let Node::Extension {
            compressed_path,
            child: next,
            ..
        } = &*tail
        else {
            break;
        };
        if path.len().is_multiple_of(255) {
            break;
        }
        path.extend(compressed_path);
        tail = next.clone();
    }
    path.chunks(255).rev().fold(tail, |inner, segment| {
        Sp::new(Node::Extension {
            ann: Node::<T, D, A>::ann(&inner),
            compressed_path: segment.to_vec(),
            child: inner,
        })
    })
}
pub trait HasSize
where
Self: Sized + Clone,
{
fn get_size(&self) -> u64;
fn set_size(&self, x: u64) -> Self;
}
impl HasSize for SizeAnn {
fn get_size(&self) -> u64 {
self.0
}
fn set_size(&self, x: u64) -> Self {
SizeAnn(x)
}
}
impl Semigroup for SizeAnn {
fn append(&self, other: &Self) -> Self {
SizeAnn(self.0 + other.0)
}
}
impl Monoid for SizeAnn {
fn empty() -> Self {
SizeAnn(0)
}
}
/// Summary data attached to each trie node, combined bottom-up with [`Semigroup::append`].
pub trait Annotation<T>
where
    Self: Monoid + Serializable + Deserializable + HasSize + Debug + PartialEq,
{
    /// Annotation of a single stored value.
    fn from_value(value: &T) -> Self;
}

/// Counting annotation: every value contributes exactly one.
impl<T> Annotation<T> for SizeAnn {
    fn from_value(_: &T) -> Self {
        SizeAnn(1)
    }
}
impl<T: Storable<D>, D: DB, A: Storable<D> + Annotation<T>> Node<T, D, A> {
    /// Returns a clone of the value pointer stored at exactly `path` below `sp`, if any.
    fn lookup_sp(sp: &Sp<Node<T, D, A>, D>, path: &[u8]) -> Option<Sp<T, D>> {
        Self::lookup_with(sp, path, Clone::clone)
    }

    /// Returns a reference to the value stored at exactly `path` below `sp`, if any.
    fn lookup<'a>(sp: &'a Sp<Node<T, D, A>, D>, path: &[u8]) -> Option<&'a T> {
        Self::lookup_with::<&T>(sp, path, |sp| sp.deref())
    }

    /// Follows `path` (one nibble per element) down from `sp`. If a value is stored at
    /// exactly that path, returns `f` applied to its pointer; otherwise returns `None`.
    fn lookup_with<'a, S>(
        sp: &'a Sp<Node<T, D, A>, D>,
        path: &[u8],
        f: impl FnOnce(&'a Sp<T, D>) -> S,
    ) -> Option<S> {
        match sp.deref() {
            Node::Empty => None,
            Node::Leaf { value, .. } if path.is_empty() => Some(f(value)),
            Node::Leaf { .. } => None,
            Node::Branch { children, .. } => {
                // Branches store no value of their own, so an exhausted path misses.
                let (&head, rest) = path.split_first()?;
                Self::lookup_with(&children[usize::from(head)], rest, f)
            }
            Node::Extension {
                compressed_path,
                child,
                ..
            } => {
                let rest = path.strip_prefix(compressed_path.as_slice())?;
                Self::lookup_with(child, rest, f)
            }
            Node::MidBranchLeaf { value, child, .. } => {
                if path.is_empty() {
                    Some(f(value))
                } else {
                    Self::lookup_with(child, path, f)
                }
            }
        }
    }

    /// Runs `f` with `nibble` temporarily pushed onto `path`, popping it afterwards.
    pub(crate) fn with_pushed_nibble<R>(
        path: &mut Vec<u8>,
        nibble: u8,
        f: impl FnOnce(&mut Vec<u8>) -> R,
    ) -> R {
        path.push(nibble);
        let res = f(path);
        path.pop();
        res
    }

    /// Runs `f` with `temporary_suffix` temporarily appended to `path`, truncating afterwards.
    pub(crate) fn with_extended_suffix<R>(
        path: &mut Vec<u8>,
        temporary_suffix: &[u8],
        f: impl FnOnce(&mut Vec<u8>) -> R,
    ) -> R {
        let old_len = path.len();
        path.extend_from_slice(temporary_suffix);
        let res = f(path);
        path.truncate(old_len);
        res
    }

    /// Removes every entry below `sp` whose key (relative to `sp`) is strictly less than
    /// `target_path`, where a proper prefix orders before its extensions.
    ///
    /// Returns the pruned subtree and the pointers of the removed values.
    ///
    /// # Panics
    /// Panics if a nibble of `target_path` consumed at a branch is 16 or greater.
    pub fn prune(
        sp: &Sp<Node<T, D, A>, D>,
        target_path: &[u8],
    ) -> (Sp<Self, D>, Vec<Sp<T, D>>) {
        if target_path.is_empty() {
            // Every key below `sp` is >= the exhausted target: keep the whole subtree.
            return (sp.clone(), vec![]);
        }
        match &**sp {
            Node::Empty => (Sp::new(Node::Empty), vec![]),
            // Target nibbles remain, so this value's key is a proper prefix of the target.
            Node::Leaf { value, .. } => (Sp::new(Node::Empty), vec![value.clone()]),
            Node::Branch {
                children, ann: an, ..
            } => {
                let path_head = target_path[0];
                if path_head >= 16 {
                    panic!("Invalid path nibble: {}", path_head);
                }
                // Children left of the target nibble only hold smaller keys.
                let mut pruned = (0..path_head as usize)
                    .flat_map(|i| Self::iter(&children[i]))
                    .collect::<Vec<_>>();
                let mut children = children.clone();
                for child in children.iter_mut().take(path_head as usize) {
                    *child = Sp::new(Node::Empty);
                }
                let (child, pruned_2) =
                    Self::prune(&children[path_head as usize], &target_path[1..]);
                children[path_head as usize] = child;
                pruned.extend(pruned_2);
                let no_filled = children.iter().filter(|c| !Self::is_empty(c)).count();
                let an = if pruned.is_empty() {
                    an.clone()
                } else {
                    children
                        .iter()
                        .fold(A::empty(), |acc, x| acc.append(&Self::ann(x)))
                };
                match no_filled {
                    0 => (Sp::new(Node::Empty), pruned),
                    1 => {
                        // A branch needs two children; collapse the survivor into an extension.
                        let (path_head, child) = children
                            .into_iter()
                            .enumerate()
                            .find(|(_, child)| !Self::is_empty(child))
                            .expect("Exactly one non-empty child must exist in branch");
                        match &*child {
                            Node::Extension {
                                compressed_path,
                                child,
                                ..
                            } => (
                                extension(
                                    once(path_head as u8)
                                        .chain(compressed_path.iter().copied())
                                        .collect(),
                                    child.clone(),
                                ),
                                pruned,
                            ),
                            _ => (extension(vec![path_head as u8], child), pruned),
                        }
                    }
                    _ => (Sp::new(Node::Branch { children, ann: an }), pruned),
                }
            }
            Node::Extension {
                compressed_path,
                child,
                ann: an,
            } => {
                let relevant_target_path =
                    &target_path[..usize::min(target_path.len(), compressed_path.len())];
                match compressed_path[..].cmp(relevant_target_path) {
                    // The whole subtree sorts before the target.
                    Ordering::Less => (Sp::new(Node::Empty), Self::iter(child).collect()),
                    Ordering::Equal => {
                        let (child, pruned) =
                            Self::prune(child, &target_path[compressed_path.len()..]);
                        match &*child {
                            Node::Empty => (Sp::new(Node::Empty), pruned),
                            Node::Extension {
                                compressed_path: cpath2,
                                child,
                                ..
                            } => (
                                extension(
                                    compressed_path
                                        .iter()
                                        .chain(cpath2.iter())
                                        .copied()
                                        .collect(),
                                    child.clone(),
                                ),
                                pruned,
                            ),
                            _ => (
                                Sp::new(Node::Extension {
                                    ann: if pruned.is_empty() {
                                        an.clone()
                                    } else {
                                        Self::ann(&child)
                                    },
                                    compressed_path: compressed_path.clone(),
                                    child: child.clone(),
                                }),
                                pruned,
                            ),
                        }
                    }
                    // The whole subtree sorts after the target.
                    Ordering::Greater => (sp.clone(), vec![]),
                }
            }
            Node::MidBranchLeaf { value, child, .. } => {
                // The own value's key is a proper prefix of the target, hence smaller.
                let (child, pruned) = Self::prune(child, target_path);
                (child, once(value.clone()).chain(pruned).collect())
            }
        }
    }

    /// Iterates over every value pointer stored below `sp`.
    pub(crate) fn iter(sp: &Sp<Node<T, D, A>, D>) -> impl Iterator<Item = Sp<T, D>> + '_ {
        let res: Box<dyn Iterator<Item = Sp<T, D>>> = match sp.deref() {
            Node::Empty => Box::new(empty()),
            Node::Leaf { value, .. } => Box::new(once(value.clone())),
            Node::Branch { children, .. } => Box::new(children.iter().flat_map(|c| Self::iter(c))),
            Node::Extension { child, .. } => Box::new(Self::iter(child)),
            Node::MidBranchLeaf { value, child, .. } => {
                Box::new(once(value.clone()).chain(Self::iter(child)))
            }
        };
        res
    }

    /// Returns `true` if `sp` points at `Node::Empty`.
    pub(crate) fn is_empty(sp: &Sp<Node<T, D, A>, D>) -> bool {
        matches!(sp.deref(), Node::Empty)
    }

    /// Searches below `sp` for the entry whose full key is the greatest one strictly smaller
    /// than `original_target_path`, and records it in `best_predecessor` (see
    /// [`Self::update_best_pred`]).
    ///
    /// `explored_path` is the full key of `sp`. It is restored before returning.
    fn find_predecessor_recursive<'a>(
        sp: &'a Sp<Node<T, D, A>, D>,
        original_target_path: &[u8],
        explored_path: &mut Vec<u8>,
        best_predecessor: &mut Option<(Vec<u8>, &'a T)>,
    ) {
        let current_depth = explored_path.len();
        let path_remaining = if current_depth <= original_target_path.len() {
            &original_target_path[current_depth..]
        } else {
            &[]
        };
        if path_remaining.is_empty() {
            // Everything at or below this node has a key >= the target.
            return;
        }
        match sp.deref() {
            Node::Empty => (),
            Node::Leaf { value, .. } => {
                Self::update_best_pred(best_predecessor, explored_path.to_vec(), value.deref())
            }
            Node::Branch { children, .. } => {
                let path_head = path_remaining[0];
                let matching_child = &children[path_head as usize];
                if !Self::is_empty(matching_child) {
                    let mut new_best_pred = None;
                    Self::with_pushed_nibble(explored_path, path_head, |temp_path| {
                        Self::find_predecessor_recursive(
                            matching_child,
                            original_target_path,
                            temp_path,
                            &mut new_best_pred,
                        )
                    });
                    if new_best_pred.is_some() {
                        *best_predecessor = new_best_pred;
                        return;
                    }
                }
                // Otherwise the predecessor is the largest key among the smaller siblings.
                for i in (0..path_head).rev() {
                    let largest = Self::with_pushed_nibble(explored_path, i, |temp_path| {
                        Self::find_largest_key_in_subtree(&children[i as usize], temp_path)
                    });
                    if let Some((key, val)) = largest {
                        Self::update_best_pred(best_predecessor, key, val);
                        break;
                    }
                }
            }
            Node::Extension {
                compressed_path,
                child,
                ..
            } => {
                let match_len = compressed_path
                    .iter()
                    .zip(path_remaining.iter())
                    .take_while(|(a, b)| a == b)
                    .count();
                if match_len == compressed_path.len() {
                    Self::with_extended_suffix(explored_path, compressed_path, |temp_path| {
                        Self::find_predecessor_recursive(
                            child,
                            original_target_path,
                            temp_path,
                            best_predecessor,
                        )
                    });
                } else if let Some(&diverging_remaining_nibble) = path_remaining.get(match_len)
                    && compressed_path[match_len] < diverging_remaining_nibble
                    && let Some((key, val)) = Self::find_largest_key_in_subtree(sp, explored_path)
                {
                    // The extension diverges below the target, so every key under it is
                    // smaller. When the target ends inside the extension (no diverging
                    // nibble), every key under it is larger and nothing is recorded.
                    Self::update_best_pred(best_predecessor, key, val)
                }
            }
            Node::MidBranchLeaf { value, child, .. } => {
                let mut new_best_pred = None;
                Self::find_predecessor_recursive(
                    child,
                    original_target_path,
                    explored_path,
                    &mut new_best_pred,
                );
                if new_best_pred.is_some() {
                    *best_predecessor = new_best_pred;
                    return;
                }
                Self::update_best_pred(best_predecessor, explored_path.to_vec(), value.deref())
            }
        }
    }

    /// Replaces `best_predecessor` with the candidate when none is recorded yet or when
    /// `candidate_path` is lexicographically greater than the recorded path.
    pub(crate) fn update_best_pred<'a>(
        best_predecessor: &mut Option<(Vec<u8>, &'a T)>,
        candidate_path: Vec<u8>,
        candidate_value: &'a T,
    ) {
        if best_predecessor
            .as_ref()
            .is_none_or(|(bp_path, _)| candidate_path > *bp_path)
        {
            *best_predecessor = Some((candidate_path, candidate_value));
        }
    }

    /// Returns the entry with the lexicographically greatest key below `sp`, where
    /// `current_path_to_node` is the key of `sp` itself. The path buffer is restored
    /// before returning.
    pub(crate) fn find_largest_key_in_subtree<'a>(
        sp: &'a Sp<Node<T, D, A>, D>,
        current_path_to_node: &mut Vec<u8>,
    ) -> Option<(Vec<u8>, &'a T)> {
        match sp.deref() {
            Node::Empty => None,
            Node::Leaf { value, .. } => Some((current_path_to_node.to_vec(), value.deref())),
            Node::Branch { children, .. } => (0..16).rev().find_map(|i| {
                Self::with_pushed_nibble(current_path_to_node, i as u8, |p| {
                    Self::find_largest_key_in_subtree(&children[i], p)
                })
            }),
            Node::Extension {
                compressed_path,
                child,
                ..
            } => Self::with_extended_suffix(current_path_to_node, compressed_path, |temp_path| {
                Self::find_largest_key_in_subtree(child, temp_path)
            }),
            Node::MidBranchLeaf { value, child, .. } => {
                // Keys in the child extend this node's key, so they are all larger.
                let largest_in_child =
                    Self::find_largest_key_in_subtree(child, current_path_to_node);
                if largest_in_child.is_some() {
                    return largest_in_child;
                }
                Some((current_path_to_node.to_vec(), value.deref()))
            }
        }
    }

    /// Returns the annotation of `sp`, or `A::empty()` for an empty node.
    pub fn ann(sp: &Sp<Node<T, D, A>, D>) -> A {
        match &**sp {
            Node::Empty => A::empty(),
            Node::Leaf { ann, .. }
            | Node::Branch { ann, .. }
            | Node::Extension { ann, .. }
            | Node::MidBranchLeaf { ann, .. } => (*ann).clone(),
        }
    }

    /// Inserts `value` at `path` below `sp`.
    ///
    /// Returns the new subtree and the pointer of the value previously stored at exactly
    /// `path`, if any. Mid-branch-leaf annotations are built as `child.append(value)`, the
    /// same order `invariant` checks.
    fn insert(sp: &Sp<Self, D>, path: &[u8], value: T) -> (Sp<Self, D>, Option<Sp<T, D>>) {
        if path.is_empty() {
            let value_sp = sp.arena.alloc(value.clone());
            let (node, existing_val) = match sp.deref() {
                Node::Empty => (
                    Node::Leaf {
                        ann: Annotation::<T>::from_value(&value),
                        value: value_sp,
                    },
                    None,
                ),
                Node::Leaf { value: old_val, .. } => (
                    Node::Leaf {
                        ann: Annotation::<T>::from_value(&value),
                        value: value_sp,
                    },
                    Some(old_val.clone()),
                ),
                Node::Branch { .. } | Node::Extension { .. } => {
                    let new_ann = Self::ann(sp).append(&Annotation::<T>::from_value(&value));
                    (
                        Node::MidBranchLeaf {
                            ann: new_ann,
                            value: value_sp,
                            child: sp.clone(),
                        },
                        None,
                    )
                }
                Node::MidBranchLeaf {
                    value: old_val,
                    child,
                    ..
                } => (
                    Node::MidBranchLeaf {
                        ann: Self::ann(child).append(&Annotation::<T>::from_value(&value)),
                        value: value_sp,
                        child: child.clone(),
                    },
                    Some(old_val.clone()),
                ),
            };
            return (sp.arena.alloc(node), existing_val);
        }
        match sp.deref() {
            Node::Empty => {
                let value_sp = sp.arena.alloc(value.clone());
                let child = Sp::new(Node::Leaf {
                    ann: Annotation::<T>::from_value(&value),
                    value: value_sp,
                });
                let res = extension(path.to_vec(), child);
                (res, None)
            }
            Node::Leaf {
                ann: existing_ann,
                value: self_value,
                ..
            } => {
                // The existing value stays here; the new, longer key hangs below it.
                let child = Sp::new(Node::Leaf {
                    ann: A::from_value(&value),
                    value: sp.arena.alloc(value.clone()),
                });
                let ext = extension(path.to_vec(), child);
                let branch_ann = A::from_value(&value).append(existing_ann);
                (
                    sp.arena.alloc(Node::MidBranchLeaf {
                        ann: branch_ann,
                        value: self_value.clone(),
                        child: ext,
                    }),
                    None,
                )
            }
            Node::Branch { children, .. } => {
                let index: usize = path[0].into();
                let mut new_children = children.clone();
                let (new_child, existing) = Self::insert(&new_children[index], &path[1..], value);
                new_children[index] = new_child;
                let new_ann = new_children
                    .iter()
                    .fold(A::empty(), |an, c| an.append(&Self::ann(c)));
                (
                    sp.arena.alloc(Node::Branch {
                        ann: new_ann,
                        children: new_children,
                    }),
                    existing,
                )
            }
            Node::Extension {
                compressed_path,
                child,
                ..
            } => {
                // Length of the common prefix of the extension and the new path.
                let index = compressed_path
                    .iter()
                    .zip(path)
                    .take_while(|(a, b)| a == b)
                    .count();
                if index == compressed_path.len() {
                    let (new_child, existing) = Self::insert(child, &path[index..], value);
                    return (
                        sp.arena.alloc(Node::Extension {
                            ann: Self::ann(&new_child),
                            compressed_path: compressed_path.clone(),
                            child: new_child,
                        }),
                        existing,
                    );
                }
                let (inner, existing) = if index == path.len() {
                    // The new key ends strictly inside this extension: store the value in a
                    // mid-branch leaf above the rest of the extension. A branch here would
                    // have a single child, which `check_invariant` rejects.
                    let below = extension(compressed_path[index..].to_vec(), child.clone());
                    let value_ann = A::from_value(&value);
                    let node = Node::MidBranchLeaf {
                        ann: Self::ann(&below).append(&value_ann),
                        value: sp.arena.alloc(value),
                        child: below,
                    };
                    (sp.arena.alloc(node), None)
                } else {
                    // The paths diverge at `index`: split into a branch. `extension` merges the
                    // remaining nibbles with any extension child, keeping the chain canonical
                    // (only 255-nibble extensions may have extension children).
                    let remaining =
                        extension(compressed_path[(index + 1)..].to_vec(), child.clone());
                    let compressed_path_index: usize = compressed_path[index].into();
                    let mut children: [Sp<Node<T, D, A>, D>; 16] =
                        core::array::from_fn(|_| sp.arena.alloc(Node::Empty));
                    children[compressed_path_index] = remaining;
                    let initial_ann = children
                        .iter()
                        .fold(A::empty(), |acc, c| acc.append(&Self::ann(c)));
                    let branch = sp.arena.alloc(Node::Branch {
                        ann: initial_ann,
                        children: Box::new(children),
                    });
                    Self::insert(&branch, &path[index..], value)
                };
                if index == 0 {
                    (inner, existing)
                } else {
                    (
                        sp.arena.alloc(Node::Extension {
                            ann: Self::ann(&inner),
                            compressed_path: compressed_path[..index].to_vec(),
                            child: inner,
                        }),
                        existing,
                    )
                }
            }
            Node::MidBranchLeaf {
                child,
                value: leaf_value,
                ..
            } => {
                let (new_child, existing) = Self::insert(child, path, value);
                // Child annotation first, then the node's own value, as `invariant` requires.
                let new_ann = Self::ann(&new_child).append(&A::from_value(leaf_value));
                (
                    sp.arena.alloc(Node::MidBranchLeaf {
                        ann: new_ann,
                        value: leaf_value.clone(),
                        child: new_child,
                    }),
                    existing,
                )
            }
        }
    }

    /// Number of values below `sp`: 0 for empty, 1 for a leaf, otherwise the size recorded
    /// in the node's annotation.
    fn size(sp: &Sp<Node<T, D, A>, D>) -> usize {
        match sp.deref() {
            Node::Empty => 0,
            Node::Leaf { .. } => 1,
            Node::Extension { ann, .. }
            | Node::Branch { ann, .. }
            | Node::MidBranchLeaf { ann, .. } => ann.clone().get_size() as usize,
        }
    }

    /// Iterates over `(key, value pointer)` pairs below `sp`; keys are prefixed with
    /// `current_path`. A mid-branch leaf yields its child's entries before its own value.
    fn leaves<'a>(
        sp: &'a Sp<Node<T, D, A>, D>,
        current_path: Vec<u8>,
    ) -> Box<dyn Iterator<Item = (std::vec::Vec<u8>, Sp<T, D>)> + 'a> {
        match sp.deref() {
            Node::Empty => Box::new(empty()),
            Node::Leaf { value, .. } => Box::new([(current_path, value.clone())].into_iter()),
            Node::Extension {
                compressed_path,
                child,
                ..
            } => {
                let mut new_path = current_path;
                new_path.extend_from_slice(compressed_path);
                Self::leaves(child, new_path)
            }
            Node::Branch { children, .. } => {
                Box::new(children.iter().enumerate().flat_map(move |(i, child)| {
                    let mut new_path = current_path.clone();
                    new_path.push(i as u8);
                    Self::leaves(child, new_path)
                }))
            }
            Node::MidBranchLeaf { value, child, .. } => Box::new(
                Self::leaves(child, current_path.clone())
                    .chain(once((current_path, value.clone()))),
            ),
        }
    }

    /// Removes the value stored at exactly `path` below `sp`.
    ///
    /// Returns the new subtree and the removed value pointer. A path that stores no value,
    /// including one ending inside an extension or at a branch, leaves the subtree
    /// unchanged and yields `None`.
    pub fn remove(sp: &Sp<Self, D>, path: &[u8]) -> (Sp<Self, D>, Option<Sp<T, D>>) {
        match sp.deref() {
            Node::Empty => (sp.arena.alloc(Node::Empty), None),
            Node::Leaf { value, ann } => {
                if path.is_empty() {
                    return (sp.arena.alloc(Node::Empty), Some(value.clone()));
                }
                (
                    sp.arena.alloc(Node::Leaf {
                        ann: ann.clone(),
                        value: value.clone(),
                    }),
                    None,
                )
            }
            Node::Branch { children, .. } => {
                let Some((&head, rest)) = path.split_first() else {
                    // Branches hold no value of their own.
                    return (sp.clone(), None);
                };
                let index: usize = head.into();
                let mut new_children = children.clone();
                let (new_child, removed) = Self::remove(&new_children[index], rest);
                new_children[index] = new_child;
                let non_empty = new_children.iter().filter(|c| !Self::is_empty(c)).count();
                match non_empty {
                    0 => (sp.arena.alloc(Node::Empty), removed),
                    1 => {
                        // A branch needs two children; collapse the survivor into an extension.
                        let (only_child_index, only_child) = new_children
                            .iter()
                            .enumerate()
                            .find(|(_, c)| !Self::is_empty(c))
                            .expect("exactly one non-empty child was counted above");
                        let collapsed = match &**only_child {
                            Node::Extension {
                                compressed_path,
                                child,
                                ..
                            } => extension(
                                once(only_child_index as u8)
                                    .chain(compressed_path.iter().copied())
                                    .collect(),
                                child.clone(),
                            ),
                            _ => extension(vec![only_child_index as u8], only_child.clone()),
                        };
                        (collapsed, removed)
                    }
                    _ => {
                        let ann = new_children
                            .iter()
                            .fold(A::empty(), |acc, x| acc.append(&Self::ann(x)));
                        (
                            sp.arena.alloc(Node::Branch {
                                ann,
                                children: new_children,
                            }),
                            removed,
                        )
                    }
                }
            }
            Node::Extension {
                compressed_path,
                child,
                ..
            } => {
                let Some(rest) = path.strip_prefix(compressed_path.as_slice()) else {
                    // `path` leaves or ends inside this extension: nothing is stored there.
                    return (sp.clone(), None);
                };
                let (new_child, removed) = Self::remove(child, rest);
                let new_ann = Self::ann(&new_child);
                match new_child.deref() {
                    Node::Empty => (sp.arena.alloc(Node::Empty), removed),
                    Node::Extension {
                        compressed_path: p,
                        child: c,
                        ..
                    } => {
                        let merged: Vec<u8> =
                            compressed_path.iter().chain(p.iter()).copied().collect();
                        (extension(merged, c.clone()), removed)
                    }
                    _ => (
                        sp.arena.alloc(Node::Extension {
                            ann: new_ann,
                            compressed_path: compressed_path.clone(),
                            child: new_child,
                        }),
                        removed,
                    ),
                }
            }
            Node::MidBranchLeaf { child, value, .. } => {
                if path.is_empty() {
                    (child.clone(), Some(value.clone()))
                } else {
                    let (child, removed) = Self::remove(child, path);
                    let new_ann = Self::ann(&child).append(&Annotation::<T>::from_value(value));
                    match child.deref() {
                        Node::Empty => (
                            sp.arena.alloc(Node::Leaf {
                                ann: new_ann,
                                value: value.clone(),
                            }),
                            removed,
                        ),
                        _ => (
                            sp.arena.alloc(Node::MidBranchLeaf {
                                ann: new_ann,
                                value: value.clone(),
                                child,
                            }),
                            removed,
                        ),
                    }
                }
            }
        }
    }
}
impl<T: Storable<D> + 'static, D: DB, A: Storable<D> + Annotation<T>> Storable<D>
    for Node<T, D, A>
{
    /// Child pointers in serialization order. A mid-branch leaf lists its value before its
    /// child, matching `from_binary_repr`.
    fn children(&self) -> std::vec::Vec<ArenaKey<D::Hasher>> {
        match self {
            Node::Empty => std::vec::Vec::new(),
            Node::Leaf { value, .. } => vec![value.as_child()],
            Node::Branch { children, .. } => children.iter().map(Sp::as_child).collect(),
            Node::Extension { child, .. } => vec![child.as_child()],
            Node::MidBranchLeaf { child, value, .. } => {
                vec![value.as_child(), child.as_child()]
            }
        }
    }

    /// Writes a discriminant byte (0..=4) followed by the annotation. Extensions then write
    /// their nibble count as one byte and the nibbles packed two per byte.
    ///
    /// # Errors
    /// Propagates writer errors. Returns `InvalidData` instead of silently truncating when an
    /// extension path is longer than the 255 nibbles its one-byte length can encode.
    fn to_binary_repr<W: Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
        match self {
            Node::Empty => {
                u8::serialize(&0, writer)?;
            }
            Node::Leaf { ann, .. } => {
                u8::serialize(&1, writer)?;
                A::serialize(ann, writer)?;
            }
            Node::Branch { ann, .. } => {
                u8::serialize(&2, writer)?;
                A::serialize(ann, writer)?;
            }
            Node::Extension {
                ann,
                compressed_path,
                ..
            } => {
                // Checked before anything is written, so a failure leaves the writer untouched.
                let nibble_count = u8::try_from(compressed_path.len()).map_err(|_| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Node::Extension path may not be longer than 255",
                    )
                })?;
                u8::serialize(&3, writer)?;
                A::serialize(ann, writer)?;
                let compressed = compress_nibbles(compressed_path);
                u8::serialize(&nibble_count, writer)?;
                std::vec::Vec::<u8>::serialize(&compressed, writer)?;
            }
            Node::MidBranchLeaf { ann, .. } => {
                u8::serialize(&4, writer)?;
                A::serialize(ann, writer)?;
            }
        }
        Ok(())
    }

    /// Per-node structural checks: branch fan-out, extension path shape, mid-branch-leaf
    /// child kind, and annotation sizes against child sizes.
    fn check_invariant(&self) -> Result<(), std::io::Error> {
        match self {
            Node::Empty | Node::Leaf { .. } => {}
            Node::Branch { ann, children } => {
                let non_empty_children = children
                    .iter()
                    .filter(|child| !matches!(***child, Node::Empty))
                    .count();
                if non_empty_children < 2 {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Fewer than 2 non-empty children in Node::Branch".to_string(),
                    ));
                }
                if ann.get_size()
                    != children
                        .iter()
                        .map(|child| Self::size(child) as u64)
                        .sum::<u64>()
                {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Recorded branch size doesn't match sum of children",
                    ));
                }
            }
            Node::Extension {
                ann,
                compressed_path,
                child,
            } => {
                if matches!(child.deref(), Node::Extension { .. }) && compressed_path.len() != 255 {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Node::Extension path must be of length 255 when having another Node::Extension child",
                    ));
                }
                if compressed_path.len() > 255 {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Node::Extension path may not be longer than 255",
                    ));
                }
                if ann.get_size() != Self::size(child) as u64 {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Recorded extension size doesn't match child size",
                    ));
                }
                if compressed_path.iter().any(|b| *b > 0x0f) {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Node::Extension path must consist of nibbles",
                    ));
                }
            }
            Node::MidBranchLeaf { ann, child, .. } => {
                match child.deref() {
                    Node::Branch { .. } | Node::Extension { .. } => {}
                    _ => {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "Node::MidBranchLeaf may only have Node::Branch or Node::Extension children",
                        ));
                    }
                }
                if ann.get_size() != Self::size(child) as u64 + 1 {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Recorded mid-branch-leaf size isn't one greater than child size",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Inverse of `to_binary_repr`: reads the discriminant and annotation, then pulls child
    /// pointers from `child_hashes` through `loader`.
    ///
    /// # Errors
    /// Returns `InvalidData` for an unknown discriminant, or for an extension whose declared
    /// nibble count exceeds the encoded bytes (previously an index-out-of-bounds panic on
    /// malformed input). Reader and loader errors are propagated.
    #[inline(always)]
    fn from_binary_repr<R: Read>(
        reader: &mut R,
        child_hashes: &mut impl Iterator<Item = ArenaKey<D::Hasher>>,
        loader: &impl Loader<D>,
    ) -> Result<Node<T, D, A>, std::io::Error> {
        let disc = u8::deserialize(reader, 0)?;
        match disc {
            0 => Ok(Node::Empty),
            1 => {
                let ann = A::deserialize(reader, 0)?;
                Ok(Node::Leaf {
                    ann,
                    value: loader.get_next(child_hashes)?,
                })
            }
            2 => {
                let ann = A::deserialize(reader, 0)?;
                let Ok(children) = (0..16)
                    .map(|_| loader.get_next(child_hashes))
                    .collect::<Result<Vec<_>, _>>()?
                    .try_into()
                else {
                    unreachable!("iterator must be of expected length")
                };
                Ok(Node::Branch {
                    ann,
                    children: Box::new(children),
                })
            }
            3 => {
                let ann = A::deserialize(reader, 0)?;
                let len = usize::from(u8::deserialize(reader, 0)?);
                let compressed = std::vec::Vec::<u8>::deserialize(reader, 0)?;
                // Untrusted input: make sure the packed bytes actually hold `len` nibbles.
                if compressed.len() < len.div_ceil(2) {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Node::Extension nibble count exceeds the encoded path bytes",
                    ));
                }
                let path = expand_nibbles(&compressed, len);
                let child: Sp<Node<T, D, A>, D> = loader.get_next(child_hashes)?;
                Ok(Node::Extension {
                    ann,
                    compressed_path: path,
                    child,
                })
            }
            4 => {
                let ann = A::deserialize(reader, 0)?;
                let value: Sp<T, D> = loader.get_next(child_hashes)?;
                let child: Sp<Node<T, D, A>, D> = loader.get_next(child_hashes)?;
                Ok(Node::MidBranchLeaf { ann, value, child })
            }
            _ => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Unrecognised discriminant",
            )),
        }
    }
}
// Unit tests for the trie: lookup/insert/remove behaviour, arena deduplication and
// reference counting, serialization round trips, and structural canonicity.
#[cfg(test)]
mod tests {
    use sha2::Sha256;
    use crate::{
        Storage,
        db::InMemoryDB,
        storable::SMALL_OBJECT_LIMIT,
        storage::{WrappedDB, default_storage, set_default_storage},
    };
    use super::*;
    use serialize::{Deserializable, Serializable};

    // Values inserted at distinct paths (sharing prefixes) can be looked up again.
    #[test]
    fn insert_lookup() {
        dbg!("start");
        let mut mpt = MerklePatriciaTrie::<u64>::new();
        dbg!("new tree");
        mpt = mpt.insert(&([1, 2, 3]), 100);
        dbg!("inserted 100 at [1, 2, 3]");
        mpt = mpt.insert(&([1, 2, 4]), 104);
        dbg!("inserted 104 at [1, 2, 4]");
        mpt = mpt.insert(&([2, 2, 4]), 105);
        dbg!("inserted 105 at [2, 2, 4]");
        assert_eq!(mpt.lookup(&([1, 2, 3])), Some(&100));
        assert_eq!(mpt.lookup(&([1, 2, 4])), Some(&104));
        assert_eq!(mpt.lookup(&([2, 2, 4])), Some(&105));
    }

    // Removing one of two keys keeps the other and updates the size.
    #[test]
    fn remove() {
        let mut mpt = MerklePatriciaTrie::<u64>::new();
        mpt = mpt.insert(&([1, 2, 3]), 100);
        mpt = mpt.insert(&([1, 3, 3]), 102);
        mpt = mpt.remove(&([1, 2, 3]));
        assert_eq!(mpt.size(), 1);
        assert_eq!(mpt.lookup(&([1, 3, 3])), Some(&102));
        assert_eq!(mpt.lookup(&([1, 2, 3])), None);
    }

    // Two equal values at different paths occupy a single arena entry.
    // NOTE(review): the expected arena size of 1 presumably reflects which objects exceed
    // SMALL_OBJECT_LIMIT and are stored out-of-line — confirm against the arena rules.
    #[test]
    fn deduplicate() {
        struct Tag;
        type D = WrappedDB<DefaultDB, Tag>;
        let _ = set_default_storage::<D>(Storage::default);
        let mut mpt = MerklePatriciaTrie::<[u8; SMALL_OBJECT_LIMIT], D>::new();
        mpt = mpt.insert(&([1, 2, 3]), [100; SMALL_OBJECT_LIMIT]);
        mpt = mpt.insert(&([1, 2, 2]), [100; SMALL_OBJECT_LIMIT]);
        assert_eq!(mpt.lookup(&([1, 2, 3])), Some(&[100; SMALL_OBJECT_LIMIT]));
        assert_eq!(mpt.lookup(&([1, 2, 2])), Some(&[100; SMALL_OBJECT_LIMIT]));
        assert_eq!(mpt.size(), 2);
        dbg!(&mpt.0.arena);
        assert_eq!(mpt.0.arena.size(), 1);
    }

    // Serializing reports the exact byte length, and deserializing preserves lookups.
    #[test]
    fn mpt_arena_serialization() {
        let mut mpt = MerklePatriciaTrie::<u8>::new();
        mpt = mpt.insert(&([1, 2, 3]), 100);
        mpt = mpt.insert(&([1, 2, 4]), 104);
        mpt = mpt.insert(&([2, 2, 4]), 105);
        let mut bytes = std::vec::Vec::new();
        MerklePatriciaTrie::serialize(&mpt, &mut bytes).unwrap();
        assert_eq!(bytes.len(), MerklePatriciaTrie::<u8>::serialized_size(&mpt));
        let mpt: MerklePatriciaTrie<u8> =
            MerklePatriciaTrie::deserialize(&mut bytes.as_slice(), 0).unwrap();
        assert_eq!(mpt.lookup(&([1, 2, 3])), Some(&100));
        assert_eq!(mpt.lookup(&([1, 2, 4])), Some(&104));
        assert_eq!(mpt.lookup(&([2, 2, 4])), Some(&105));
    }

    // Arena entries exist while the trie is alive and are released once it is dropped.
    #[test]
    fn nodes_stored() {
        struct Tag;
        type D = WrappedDB<DefaultDB, Tag>;
        let _ = set_default_storage::<D>(Storage::default);
        let arena = &default_storage::<D>().arena;
        {
            let mut mpt: MerklePatriciaTrie<[u8; SMALL_OBJECT_LIMIT], D> =
                MerklePatriciaTrie::new();
            mpt = mpt.insert(&([1, 2, 3]), [100; SMALL_OBJECT_LIMIT]);
            mpt = mpt.insert(&([1, 2, 4]), [104; SMALL_OBJECT_LIMIT]);
            mpt = mpt.insert(&([2, 2, 4]), [105; SMALL_OBJECT_LIMIT]);
            assert_eq!(arena.size(), 3);
            assert_eq!(mpt.lookup(&([2, 2, 4])), Some(&[105; SMALL_OBJECT_LIMIT]))
        }
        assert_eq!(arena.size(), 0);
    }

    // A 300-nibble key (needing two chained extensions) survives a serialization round
    // trip and is reported by iteration on both copies.
    #[test]
    fn long_extension_paths_serialization() {
        let mut mpt: MerklePatriciaTrie<u8, InMemoryDB<Sha256>> = MerklePatriciaTrie::new();
        mpt = mpt.insert(&(vec![2; 300]), 100);
        let mut bytes = std::vec::Vec::new();
        Serializable::serialize(&mpt, &mut bytes).unwrap();
        let deserialized_mpt: MerklePatriciaTrie<u8> =
            Deserializable::deserialize(&mut bytes.as_slice(), 0).unwrap();
        assert_eq!(deserialized_mpt, mpt);
        assert_eq!(
            mpt.iter()
                .map(|(k, _)| k)
                .collect::<std::vec::Vec<std::vec::Vec<u8>>>(),
            vec![vec![2; 300]]
        );
        assert_eq!(
            deserialized_mpt
                .iter()
                .map(|(k, _)| k)
                .collect::<std::vec::Vec<std::vec::Vec<u8>>>(),
            vec![vec![2; 300]]
        );
    }

    // Checks the concrete shape of a long key: Extension(255) -> Extension(rest) -> Leaf,
    // before and after a serialization round trip.
    #[test]
    fn mpt_structure() {
        fn validate_long_path(
            mpt: &MerklePatriciaTrie<u8, InMemoryDB<Sha256>>,
            path_length: u64,
            validate_value: u8,
        ) {
            match mpt.0.deref() {
                Node::Extension {
                    compressed_path,
                    child,
                    ..
                } => {
                    assert_eq!(compressed_path.len() as u64, 255);
                    match child.deref() {
                        Node::Extension {
                            compressed_path,
                            child,
                            ..
                        } => {
                            assert_eq!(compressed_path.len() as u64, path_length - 255);
                            assert!(
                                matches!(child.deref(), Node::Leaf { ann: SizeAnn(1), value } if value.deref() == &validate_value)
                            );
                        }
                        _ => unreachable!(),
                    }
                }
                _ => unreachable!(),
            };
        }
        let mut mpt = MerklePatriciaTrie::<u8>::new();
        mpt = mpt.insert(&(vec![2; 300]), 100);
        let mut bytes = std::vec::Vec::new();
        Serializable::serialize(&mpt, &mut bytes).unwrap();
        let deserialized_mpt: MerklePatriciaTrie<u8> =
            Deserializable::deserialize(&mut bytes.as_slice(), 0).unwrap();
        assert_eq!(deserialized_mpt, mpt);
        validate_long_path(&mpt, 300, 100);
        validate_long_path(&deserialized_mpt, 300, 100);
    }

    // Keys that are prefixes of other keys (mid-branch leaves), including a prefix inserted
    // last, are all retrievable; prefixes without values are not.
    #[test]
    fn extended_path_insertion() {
        let mut mpt = MerklePatriciaTrie::<u32>::new();
        mpt = mpt.insert(&([1, 2]), 12);
        mpt = mpt.insert(&([1, 2, 3, 4, 5]), 12345);
        mpt = mpt.insert(&([1, 2, 3, 4, 6]), 12346);
        mpt = mpt.insert(&([1, 2, 3, 5, 6]), 12356);
        mpt = mpt.insert(&([1]), 1);
        assert_eq!(mpt.lookup(&([1, 2])), Some(&12));
        assert_eq!(mpt.lookup(&([1, 2, 3, 4, 5])), Some(&12345));
        assert_eq!(mpt.lookup(&([1, 2, 3, 5, 6])), Some(&12356));
        assert_eq!(mpt.lookup(&([1])), Some(&1));
        assert_eq!(mpt.lookup(&([4])), None);
        assert_eq!(mpt.lookup(&([1, 2, 4])), None);
        assert_eq!(mpt.lookup(&([1, 2, 3])), None);
        assert_eq!(mpt.lookup(&([1, 2, 3, 6])), None);
        assert_eq!(mpt.lookup(&([])), None);
        println!("{:?}", mpt);
    }

    // The same key set must hash identically regardless of the insert/remove history that
    // produced it (long paths exercise extension splitting and re-merging).
    #[test]
    fn test_canonicity() {
        let segment1 = [0u8; 200];
        let segment2 = [1u8; 200];
        let path1 = segment1
            .iter()
            .chain(segment1.iter())
            .chain(segment1.iter())
            .copied()
            .collect::<Vec<_>>();
        let path2 = segment1
            .iter()
            .chain(segment2.iter())
            .chain(segment2.iter())
            .copied()
            .collect::<Vec<_>>();
        let mpt1 = MerklePatriciaTrie::<()>::new().insert(&path1, ());
        let mpt2 = MerklePatriciaTrie::<()>::new()
            .insert(&path2, ())
            .insert(&path1, ())
            .remove(&path2);
        dbg!(&mpt1);
        dbg!(&mpt2);
        assert_eq!(mpt1.0.hash(), mpt2.0.hash());
    }
}