#![allow(clippy::let_and_return)]
#![allow(clippy::while_let_loop)]
mod metrics;
mod node;
mod updater;
mod utils;
#[cfg(test)]
mod sparse_merkle_test;
#[cfg(any(test, feature = "bench", feature = "fuzzing"))]
pub mod test_utils;
use crate::sparse_merkle::{
metrics::{LATEST_GENERATION, OLDEST_GENERATION, TIMER},
node::{NodeInner, SubTree},
updater::SubTreeUpdater,
};
use aptos_crypto::{
hash::{CryptoHash, SPARSE_MERKLE_PLACEHOLDER_HASH},
HashValue,
};
use aptos_infallible::Mutex;
use aptos_types::{nibble::nibble_path::NibblePath, proof::SparseMerkleProof};
use std::sync::MutexGuard;
use std::{
borrow::Borrow,
collections::{BTreeMap, HashMap},
sync::{Arc, Weak},
};
use thiserror::Error;
/// Position of a node in the binary tree, as the sequence of branch bits taken from the
/// root, most significant bit first (`false` = left child, `true` = right child).
type NodePosition = bitvec::vec::BitVec<bitvec::order::Msb0, u8>;
/// Per-branch bookkeeping for a family of trees spawned from a common root.
///
/// A branch is a chain of trees in which each tree is the first child spawned from the
/// previous one (`Inner::spawn` hands its own tracker to its first child). Methods take a
/// `&MutexGuard<()>` of the family lock as a token showing the caller holds that lock.
#[derive(Debug)]
struct BranchTracker<V> {
/// The tree at the start of this branch. Weak, so it does not keep that tree alive.
head: Weak<Inner<V>>,
/// A later tree on the same branch (the first child spawned along it); `head()` falls
/// back to it once `head` has been dropped.
next: Weak<Inner<V>>,
/// Tracker of the branch this branch forked from; `None` for the oldest branch.
parent: Option<Arc<Mutex<BranchTracker<V>>>>,
}
impl<V> BranchTracker<V> {
    /// Allocates a tracker whose `head` and `next` do not point at any tree yet.
    ///
    /// `parent` is the tracker of the branch this one forks from (`None` for a family
    /// root). `_locked_family` is only a witness that the family lock is held.
    fn new_head_unknown(
        parent: Option<Arc<Mutex<Self>>>,
        _locked_family: &MutexGuard<()>,
    ) -> Arc<Mutex<Self>> {
        let tracker = Self {
            head: Weak::new(),
            next: Weak::new(),
            parent,
        };
        Arc::new(Mutex::new(tracker))
    }

    /// Turns this branch into a root branch (no parent) headed by `head`, with `next`
    /// (the first child of `head`, when called from `Inner::become_oldest`) as fallback.
    fn become_oldest(
        &mut self,
        head: &Arc<Inner<V>>,
        next: Option<&Arc<Inner<V>>>,
        _locked_family: &MutexGuard<()>,
    ) {
        self.parent = None;
        self.head = Arc::downgrade(head);
        self.next = match next {
            Some(successor) => Arc::downgrade(successor),
            None => Weak::new(),
        };
    }

    /// Returns a strong handle to the tracker of the branch this one forked from, if any.
    fn parent(&self, _locked_family: &MutexGuard<()>) -> Option<Arc<Mutex<Self>>> {
        self.parent.as_ref().map(Arc::clone)
    }

    /// Returns the branch head if it is still alive, otherwise `next` if that is alive.
    fn head(&self, _locked_family: &MutexGuard<()>) -> Option<Arc<Inner<V>>> {
        match self.head.upgrade() {
            live @ Some(_) => live,
            None => self.next.upgrade(),
        }
    }
}
/// Mutable links of a tree: the trees spawned from it and the tracker of its branch.
#[derive(Debug)]
struct InnerLinks<V> {
/// Trees spawned from this one, in spawn order. Held strongly, so a tree keeps the trees
/// spawned from it alive.
children: Vec<Arc<Inner<V>>>,
/// Tracker of the branch this tree belongs to. The first child spawned from this tree
/// is given the same tracker.
branch_tracker: Arc<Mutex<BranchTracker<V>>>,
}
impl<V> InnerLinks<V> {
    /// Wraps a fresh, child-less set of links for `branch_tracker` in a mutex.
    fn new(branch_tracker: Arc<Mutex<BranchTracker<V>>>) -> Mutex<Self> {
        let links = Self {
            branch_tracker,
            children: Vec::new(),
        };
        Mutex::new(links)
    }
}
/// One version of the tree, together with its links to the versions spawned from it.
#[derive(Debug)]
struct Inner<V> {
/// Root of this version of the tree.
root: SubTree<V>,
/// Children and branch tracker of this version.
links: Mutex<InnerLinks<V>>,
/// 0 for a tree created by `Inner::new`; one more than the parent's for spawned trees.
generation: u64,
/// Lock shared by every tree in the family (spawned trees receive a clone of the
/// parent's `Arc`); it guards the branch-tracker bookkeeping.
family_lock: Arc<Mutex<()>>,
}
/// Detaches the trees spawned from a tree that is being released.
///
/// Every child of the dropped tree is promoted with `become_oldest`. Descendants for
/// which the work stack holds the only strong reference are drained in turn. The whole
/// subtree is therefore walked iteratively here instead of recursively through nested
/// `Drop` calls.
impl<V> Drop for Inner<V> {
fn drop(&mut self) {
// Declared before the family lock is taken, so it is dropped only after the guard in
// the inner block is released. Dropping a descendant can run this `drop` again, and
// that call locks the same shared `family_lock`.
let mut processed_decendents = Vec::new();
{
let locked_family = self.family_lock.lock();
let mut stack = self.drain_children_for_drop(&locked_family);
while let Some(descendant) = stack.pop() {
// If this is the only strong reference, the descendant is freed when
// `processed_decendents` is dropped. Detach its children now so that freeing it
// has nothing left to recurse into.
if Arc::strong_count(&descendant) == 1 {
stack.extend(descendant.drain_children_for_drop(&locked_family));
}
processed_decendents.push(descendant);
}
};
}
}
impl<V> Inner<V> {
/// Creates a tree at generation 0 that starts a new family. It gets a fresh family lock
/// and a branch tracker whose `head` is the new tree.
fn new(root: SubTree<V>) -> Arc<Self> {
let family_lock = Arc::new(Mutex::new(()));
// The family-lock guard here is a temporary, released at the end of this statement.
let branch_tracker = BranchTracker::new_head_unknown(None, &family_lock.lock());
let me = Arc::new(Self {
root,
links: InnerLinks::new(branch_tracker.clone()),
generation: 0,
family_lock,
});
// `me` has not been shared yet, so `head` is set without holding the family lock.
branch_tracker.lock().head = Arc::downgrade(&me);
me
}
/// Promotes `self` to the head of its branch with no parent branch, recording its first
/// child (if any) as the tracker's `next`. Returns `self` unchanged, which lets
/// `drain_children_for_drop` call this inside a `map`.
///
/// Locks `self.links`, then the branch tracker. `locked_family` is the caller's guard on
/// the family lock.
fn become_oldest(self: Arc<Self>, locked_family: &MutexGuard<()>) -> Arc<Self> {
// Scoped so both guards are released before `self` is returned.
{
let links_locked = self.links.lock();
let mut branch_tracker_locked = links_locked.branch_tracker.lock();
branch_tracker_locked.become_oldest(
&self,
links_locked.children.first(),
locked_family,
);
}
self
}
/// Builds a child tree one generation newer than `self`, using the given branch tracker
/// and family lock, and sets the `LATEST_GENERATION` metric. The child is NOT added to
/// `self`'s children here; `spawn` does that.
fn spawn_impl(
&self,
child_root: SubTree<V>,
branch_tracker: Arc<Mutex<BranchTracker<V>>>,
family_lock: Arc<Mutex<()>>,
) -> Arc<Self> {
LATEST_GENERATION.set(self.generation as i64 + 1);
Arc::new(Self {
root: child_root,
links: InnerLinks::new(branch_tracker),
generation: self.generation + 1,
family_lock,
})
}
/// Spawns a child tree rooted at `child_root` and appends it to `self`'s children.
///
/// The first child continues `self`'s branch: it shares `self`'s tracker, and becomes
/// that tracker's `next` if no live `next` exists. Every later child starts a new branch
/// whose tracker has `self`'s tracker as parent and the child as head.
///
/// Lock order: family lock, then `self.links`, then a branch tracker.
fn spawn(self: &Arc<Self>, child_root: SubTree<V>) -> Arc<Self> {
let locked_family = self.family_lock.lock();
let mut links_locked = self.links.lock();
let child = if links_locked.children.is_empty() {
let child = self.spawn_impl(
child_root,
links_locked.branch_tracker.clone(),
self.family_lock.clone(),
);
let mut branch_tracker_locked = links_locked.branch_tracker.lock();
// Only fill `next` if it does not already point at a live tree.
if branch_tracker_locked.next.upgrade().is_none() {
branch_tracker_locked.next = Arc::downgrade(&child);
}
child
} else {
let branch_tracker = BranchTracker::new_head_unknown(
Some(links_locked.branch_tracker.clone()),
&locked_family,
);
let child =
self.spawn_impl(child_root, branch_tracker.clone(), self.family_lock.clone());
branch_tracker.lock().head = Arc::downgrade(&child);
child
};
links_locked.children.push(child.clone());
child
}
/// Returns the oldest tree reachable from `self` through the branch trackers, and sets
/// the `OLDEST_GENERATION` metric to its generation.
///
/// Starts at the head of `self`'s branch (or its `next` if the head is gone). It then
/// walks up parent branches while their head (or `next`) is alive and has a smaller
/// generation than `self`.
///
/// # Panics
/// Panics if `self`'s own branch tracker has neither a live `head` nor a live `next`.
fn get_oldest_ancestor(self: &Arc<Self>) -> Arc<Self> {
let locked_family = self.family_lock.lock();
let (mut ret, mut parent) = {
// The `links` guard is a temporary, released at the end of this statement.
let branch_tracker = self.links.lock().branch_tracker.clone();
let branch_tracker_locked = branch_tracker.lock();
(
branch_tracker_locked
.head(&locked_family)
.expect("Leaf must have a head."),
branch_tracker_locked.parent(&locked_family),
)
};
while let Some(branch_tracker) = parent {
let branch_tracker_locked = branch_tracker.lock();
if let Some(head) = branch_tracker_locked.head(&locked_family) {
if head.generation < self.generation {
ret = head;
parent = branch_tracker_locked.parent(&locked_family);
continue;
}
}
break;
}
OLDEST_GENERATION.set(ret.generation as i64);
ret
}
/// Removes all children from `self`, promotes each one to the oldest tree of its branch
/// (`become_oldest`), and returns them. Holds `self.links` while doing so. The caller
/// supplies its family-lock guard.
fn drain_children_for_drop(&self, locked_family: &MutexGuard<()>) -> Vec<Arc<Self>> {
self.links
.lock()
.children
.drain(..)
.map(|child| child.become_oldest(locked_family))
.collect()
}
}
/// Handle to one version of a sparse Merkle tree. Cloning only clones the inner `Arc`, so
/// clones refer to the same version.
#[derive(Clone, Debug)]
pub struct SparseMerkleTree<V> {
inner: Arc<Inner<V>>,
}
impl<V> SparseMerkleTree<V>
where
    V: Clone + CryptoHash + Send + Sync,
{
    /// Creates a tree with the given root hash.
    ///
    /// The placeholder hash yields an empty root. Any other hash yields a root built by
    /// `SubTree::new_unknown`, which receives only the hash.
    pub fn new(root_hash: HashValue) -> Self {
        let is_placeholder = root_hash == *SPARSE_MERKLE_PLACEHOLDER_HASH;
        let root = if is_placeholder {
            SubTree::new_empty()
        } else {
            SubTree::new_unknown(root_hash)
        };
        Self {
            inner: Inner::new(root),
        }
    }

    /// Creates an empty tree.
    pub fn new_empty() -> Self {
        let root = SubTree::new_empty();
        Self {
            inner: Inner::new(root),
        }
    }

    /// Returns `true` when both trees have the same root hash. They need not be the same
    /// version object; `is_the_same` checks identity.
    pub fn has_same_root_hash(&self, other: &Self) -> bool {
        let (mine, theirs) = (self.root_hash(), other.root_hash());
        mine == theirs
    }

    /// Returns a handle to the oldest ancestor reachable through the branch trackers.
    fn get_oldest_ancestor(&self) -> Self {
        let inner = self.inner.get_oldest_ancestor();
        Self { inner }
    }

    /// Pins this tree together with its oldest ancestor, producing a
    /// `FrozenSparseMerkleTree`.
    pub fn freeze(self) -> FrozenSparseMerkleTree<V> {
        let base_smt = self.get_oldest_ancestor();
        FrozenSparseMerkleTree {
            base_generation: base_smt.generation(),
            base_smt,
            smt: self,
        }
    }

    #[cfg(test)]
    fn new_with_root(root: SubTree<V>) -> Self {
        let inner = Inner::new(root);
        Self { inner }
    }

    /// Returns this version's root subtree via `SubTree::weak`.
    fn root_weak(&self) -> SubTree<V> {
        let root = &self.inner.root;
        root.weak()
    }

    /// Returns the hash of this version's root.
    pub fn root_hash(&self) -> HashValue {
        let root = &self.inner.root;
        root.hash()
    }

    /// Returns the generation of this version (0 for a freshly created tree).
    fn generation(&self) -> u64 {
        let inner = &self.inner;
        inner.generation
    }

    /// Returns `true` if both handles refer to the very same version object.
    fn is_the_same(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.inner, &other.inner);
        Arc::ptr_eq(mine, theirs)
    }
}
#[cfg(any(feature = "fuzzing", feature = "bench", test))]
impl<V> SparseMerkleTree<V>
where
    V: Clone + CryptoHash + Send + Sync,
{
    /// Test/bench convenience: freezes a clone of `self`, applies `updates` with
    /// `FrozenSparseMerkleTree::batch_update`, and unfreezes the result.
    pub fn batch_update(
        &self,
        updates: Vec<(HashValue, &V)>,
        proof_reader: &impl ProofRead,
    ) -> Result<Self, UpdateError> {
        let frozen = self.clone().freeze();
        let updated = frozen.batch_update(updates, proof_reader)?;
        Ok(updated.unfreeze())
    }

    /// Test/bench convenience: looks up `key` in a frozen clone of `self`.
    pub fn get(&self, key: HashValue) -> StateStoreStatus<V> {
        let frozen = self.clone().freeze();
        frozen.get(key)
    }
}
impl<V> Default for SparseMerkleTree<V>
where
    V: Clone + CryptoHash + Send + Sync,
{
    /// The default tree is the one whose root hash is the placeholder hash, i.e. the
    /// empty tree.
    fn default() -> Self {
        let placeholder = *SPARSE_MERKLE_PLACEHOLDER_HASH;
        Self::new(placeholder)
    }
}
/// Result of looking up a key with `FrozenSparseMerkleTree::get`.
#[derive(Debug, Eq, PartialEq)]
pub enum StateStoreStatus<V> {
/// A leaf with the key was found and its value is held in memory.
ExistsInScratchPad(V),
/// A leaf with the key was found but its value is not held in memory.
ExistsInDB,
/// The key is absent: the search reached an empty subtree or a leaf with another key.
DoesNotExist,
/// The search reached a node that `get_node_if_in_mem` did not return for the frozen
/// tree's base generation, so the answer cannot be determined here.
Unknown,
}
/// A tree version pinned together with the oldest ancestor found when it was frozen.
#[derive(Clone, Debug)]
pub struct FrozenSparseMerkleTree<V> {
/// Oldest ancestor at freeze time; holding it keeps that tree alive.
base_smt: SparseMerkleTree<V>,
/// Generation of `base_smt`. Lookups pass it to `get_node_if_in_mem`.
base_generation: u64,
/// The frozen tree itself.
smt: SparseMerkleTree<V>,
}
impl<V> FrozenSparseMerkleTree<V>
where
V: Clone + CryptoHash + Send + Sync,
{
fn spawn(&self, child_root: SubTree<V>) -> Self {
Self {
base_smt: self.base_smt.clone(),
base_generation: self.base_generation,
smt: SparseMerkleTree {
inner: self.smt.inner.spawn(child_root),
},
}
}
pub fn unfreeze(self) -> SparseMerkleTree<V> {
self.smt
}
pub fn root_hash(&self) -> HashValue {
self.smt.root_hash()
}
pub fn new_node_hashes_since(&self, since_smt: &Self) -> HashMap<NibblePath, HashValue> {
let _timer = TIMER
.with_label_values(&["new_node_hashes_since"])
.start_timer();
assert!(self.base_smt.is_the_same(&since_smt.base_smt));
let mut node_hashes = HashMap::new();
Self::new_node_hashes_since_impl(
self.smt.root_weak(),
since_smt.smt.generation() + 1,
&mut NodePosition::with_capacity(HashValue::LENGTH_IN_BITS),
&mut node_hashes,
);
node_hashes
}
fn new_node_hashes_since_impl(
subtree: SubTree<V>,
since_generation: u64,
pos: &mut NodePosition,
node_hashes: &mut HashMap<NibblePath, HashValue>,
) {
if let Some(node) = subtree.get_node_if_in_mem(since_generation) {
let is_nibble = if let Some(path) = Self::maybe_to_nibble_path(pos) {
node_hashes.insert(path, subtree.hash());
true
} else {
false
};
match node.inner().borrow() {
NodeInner::Internal(internal_node) => {
let depth = pos.len();
pos.push(false);
Self::new_node_hashes_since_impl(
internal_node.left.weak(),
since_generation,
pos,
node_hashes,
);
*pos.get_mut(depth).unwrap() = true;
Self::new_node_hashes_since_impl(
internal_node.right.weak(),
since_generation,
pos,
node_hashes,
);
pos.pop();
}
NodeInner::Leaf(leaf_node) => {
let mut path = NibblePath::new_even(leaf_node.key.to_vec());
if !is_nibble {
path.truncate(pos.len() as usize / 4 + 1);
}
node_hashes.insert(path, subtree.hash());
}
}
}
}
fn maybe_to_nibble_path(pos: &NodePosition) -> Option<NibblePath> {
assert!(pos.len() <= HashValue::LENGTH_IN_BITS);
const BITS_IN_NIBBLE: usize = 4;
const BITS_IN_BYTE: usize = 8;
if pos.len() % BITS_IN_NIBBLE == 0 {
let mut bytes = pos.clone().into_vec();
if pos.len() % BITS_IN_BYTE == 0 {
Some(NibblePath::new_even(bytes))
} else {
if let Some(b) = bytes.last_mut() {
*b &= 0xf0
}
Some(NibblePath::new_odd(bytes))
}
} else {
None
}
}
pub fn batch_update(
&self,
updates: Vec<(HashValue, &V)>,
proof_reader: &impl ProofRead,
) -> Result<Self, UpdateError> {
let kvs = updates
.into_iter()
.collect::<BTreeMap<_, _>>()
.into_iter()
.collect::<Vec<_>>();
let current_root = self.smt.root_weak();
if kvs.is_empty() {
Ok(self.clone())
} else {
let root = SubTreeUpdater::update(
current_root,
&kvs[..],
proof_reader,
self.smt.inner.generation + 1,
)?;
Ok(self.spawn(root))
}
}
pub fn get(&self, key: HashValue) -> StateStoreStatus<V> {
let mut subtree = self.smt.root_weak();
let mut bits = key.iter_bits();
loop {
match subtree {
SubTree::Empty => return StateStoreStatus::DoesNotExist,
SubTree::NonEmpty { .. } => {
match subtree.get_node_if_in_mem(self.base_generation) {
None => return StateStoreStatus::Unknown,
Some(node) => match node.inner() {
NodeInner::Internal(internal_node) => {
subtree = if bits.next().expect("Tree is too deep.") {
internal_node.right.weak()
} else {
internal_node.left.weak()
};
continue;
} NodeInner::Leaf(leaf_node) => {
return if leaf_node.key == key {
match &leaf_node.value.data.get_if_in_mem() {
Some(value) => StateStoreStatus::ExistsInScratchPad(
value.as_ref().clone(),
),
None => StateStoreStatus::ExistsInDB,
}
} else {
StateStoreStatus::DoesNotExist
};
} }, }
} }
} }
}
/// Source of sparse Merkle proofs; `batch_update` hands it to `SubTreeUpdater::update`.
pub trait ProofRead: Sync {
/// Returns the proof for `key`, or `None` if this reader has no proof for it.
fn get_proof(&self, key: HashValue) -> Option<&SparseMerkleProof>;
}
#[derive(Debug, Error, Eq, PartialEq)]
pub enum UpdateError {
#[error("Missing Proof")]
MissingProof,
#[error(
"Short proof: key: {}, num_siblings: {}, depth: {}",
key,
num_siblings,
depth
)]
ShortProof {
key: HashValue,
num_siblings: usize,
depth: usize,
},
}