use crate::hasher::NodeHasher;
use crate::trie::{self, InternalData, KeyPath, LeafData, Node, NodeKind, TERMINATOR};
use crate::trie_pos::TriePosition;
use bitvec::prelude::*;
use core::fmt;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// The node found at the bottom of a [`PathProof`].
///
/// Either the full data of a leaf, or the trie position of a terminator node.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "borsh", derive(borsh::BorshDeserialize, borsh::BorshSerialize))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PathProofTerminal {
    /// The path ends at a leaf; its data is carried in full.
    Leaf(LeafData),
    /// The path ends at a terminator node located at this trie position.
    Terminator(TriePosition),
}
impl PathProofTerminal {
    /// The bit path identifying this terminal.
    ///
    /// For a leaf this is the leaf's full key path (MSB-first). For a terminator it is
    /// the path of its trie position.
    pub fn path(&self) -> &BitSlice<u8, Msb0> {
        match self {
            // `view_bits` already yields a reference; borrowing it again only produced a
            // `&&BitSlice` that had to be deref-coerced back.
            Self::Leaf(leaf_data) => leaf_data.key_path.view_bits::<Msb0>(),
            Self::Terminator(position) => position.path(),
        }
    }

    /// The node value of this terminal under hasher `H`.
    ///
    /// For a leaf this is `H::hash_leaf` of its data. For a terminator it is [`TERMINATOR`].
    pub fn node<H: NodeHasher>(&self) -> Node {
        match self {
            Self::Leaf(leaf_data) => H::hash_leaf(leaf_data),
            Self::Terminator(_) => TERMINATOR,
        }
    }

    /// A clone of the leaf data, or `None` for a terminator.
    pub fn as_leaf_option(&self) -> Option<LeafData> {
        match self {
            Self::Leaf(leaf_data) => Some(leaf_data.clone()),
            Self::Terminator(_) => None,
        }
    }
}
/// A proof of the path from a trie root down to a single terminal node.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "borsh", derive(borsh::BorshDeserialize, borsh::BorshSerialize))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PathProof {
    /// The leaf or terminator at the bottom of the path.
    pub terminal: PathProofTerminal,
    /// Sibling nodes along the path, ordered from the root downward: `siblings[0]` is
    /// the sibling at the first layer below the root.
    pub siblings: Vec<Node>,
}
impl PathProof {
    /// Verify this path proof against `root`.
    ///
    /// Only the first `self.siblings.len()` bits of `key_path` are used. The terminal's
    /// node is hashed up along those bits, and the result is compared with `root`. On
    /// success the returned [`VerifiedPathProof`] records that prefix, the terminal leaf
    /// (if any), the siblings and the root.
    ///
    /// This does not compare `key_path` with `self.terminal.path()`. Passing a key path
    /// that differs from the proven one can still verify, but the resulting
    /// [`VerifiedPathProof`] will then be of no use for looking up values.
    ///
    /// # Errors
    ///
    /// - [`PathProofVerificationError::TooManySiblings`] if there are more siblings than
    ///   bits in `key_path`, or more than 256.
    /// - [`PathProofVerificationError::RootMismatch`] if the recomputed root differs from
    ///   `root`.
    pub fn verify<H: NodeHasher>(
        &self,
        key_path: &BitSlice<u8, Msb0>,
        root: Node,
    ) -> Result<VerifiedPathProof, PathProofVerificationError> {
        // This check also keeps the slice below in bounds.
        if self.siblings.len() > core::cmp::min(key_path.len(), 256) {
            return Err(PathProofVerificationError::TooManySiblings);
        }

        let relevant_path = &key_path[..self.siblings.len()];
        let terminal_node = self.terminal.node::<H>();

        // Siblings are stored root-first; `hash_path` consumes them deepest-first.
        let new_root = hash_path::<H>(
            terminal_node,
            relevant_path,
            self.siblings.iter().rev().copied(),
        );

        if new_root != root {
            return Err(PathProofVerificationError::RootMismatch);
        }

        Ok(VerifiedPathProof {
            key_path: relevant_path.into(),
            terminal: self.terminal.as_leaf_option(),
            siblings: self.siblings.clone(),
            root,
        })
    }
}
/// Hash `node` upward along `path` and return the resulting top node.
///
/// `path` bits are consumed from last to first, paired with `siblings` in iteration
/// order. The first sibling yielded therefore belongs to the deepest layer.
///
/// At each step:
/// - a `1` bit makes the current node the right child (sibling on the left);
/// - a `0` bit makes it the left child.
///
/// The pair is hashed with `H::hash_internal`. Hashing stops as soon as either `path`
/// or `siblings` runs out.
pub fn hash_path<H: NodeHasher>(
    node: Node,
    path: &BitSlice<u8, Msb0>,
    siblings: impl IntoIterator<Item = Node>,
) -> Node {
    // `Node` is `Copy`, so the children move straight into `InternalData` with no clones.
    path.iter()
        .by_vals()
        .rev()
        .zip(siblings)
        .fold(node, |child, (bit, sibling)| {
            let (left, right) = if bit {
                (sibling, child)
            } else {
                (child, sibling)
            };
            H::hash_internal(&InternalData { left, right })
        })
}
/// Error returned when a queried key does not start with the path covered by a
/// [`VerifiedPathProof`], so the proof says nothing about it.
#[derive(Clone, Copy, Debug)]
pub struct KeyOutOfScope;
/// Reasons a [`PathProof`] can fail verification.
#[derive(Clone, Copy, Debug)]
pub enum PathProofVerificationError {
    /// The proof has more siblings than the supplied key path has bits, or more than 256.
    TooManySiblings,
    /// Hashing the terminal up the path did not reproduce the expected root.
    RootMismatch,
}
/// A path proof that [`PathProof::verify`] has checked against a root.
///
/// As the `must_use` message says, this only establishes the trie path. To check what
/// a specific key maps to, use `confirm_value` or `confirm_nonexistence`.
#[derive(Clone)]
#[must_use = "VerifiedPathProof only checks the trie path, not whether it actually looks up to your expected value."]
pub struct VerifiedPathProof {
    // The key-path prefix that was proven; its length equals `siblings.len()`.
    key_path: BitVec<u8, Msb0>,
    // The leaf at the end of the path, or `None` if the path ended in a terminator.
    terminal: Option<LeafData>,
    // Sibling nodes along the path, ordered from the root downward.
    siblings: Vec<Node>,
    // The root the proof was verified against.
    root: Node,
}
impl VerifiedPathProof {
    /// The leaf at the end of the proven path, or `None` if the path ended in a
    /// terminator.
    pub fn terminal(&self) -> Option<&LeafData> {
        self.terminal.as_ref()
    }

    /// The key-path prefix covered by this proof (one bit per sibling).
    pub fn path(&self) -> &BitSlice<u8, Msb0> {
        &self.key_path[..]
    }

    /// The root this proof was verified against.
    pub fn root(&self) -> Node {
        self.root
    }

    /// Check whether the proof shows exactly `expected_leaf` at the end of the path.
    ///
    /// Returns `Ok(false)` when the terminal is a different leaf or a terminator.
    ///
    /// # Errors
    ///
    /// [`KeyOutOfScope`] if `expected_leaf.key_path` does not start with [`Self::path`].
    pub fn confirm_value(&self, expected_leaf: &LeafData) -> Result<bool, KeyOutOfScope> {
        self.in_scope(&expected_leaf.key_path)
            .map(|()| self.terminal() == Some(expected_leaf))
    }

    /// Check whether the proof shows that `key_path` is absent.
    ///
    /// Returns `Ok(true)` in two cases: the path ends in a terminator, or it ends in a
    /// leaf with a different key path.
    ///
    /// # Errors
    ///
    /// [`KeyOutOfScope`] if `key_path` does not start with [`Self::path`].
    pub fn confirm_nonexistence(&self, key_path: &KeyPath) -> Result<bool, KeyOutOfScope> {
        self.in_scope(key_path)
            .map(|()| self.terminal().map_or(true, |leaf| &leaf.key_path != key_path))
    }

    /// `Ok(())` if `key_path` lies under the proven path prefix.
    fn in_scope(&self, key_path: &KeyPath) -> Result<(), KeyOutOfScope> {
        // `starts_with` never panics, unlike slicing `[..self.key_path.len()]`. It is the
        // same check `verify_update` uses for ops.
        if key_path.view_bits::<Msb0>().starts_with(self.path()) {
            Ok(())
        } else {
            Err(KeyOutOfScope)
        }
    }
}
impl fmt::Debug for VerifiedPathProof {
    // Shows the proven path, the terminal and the root; the sibling list is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("VerifiedPathProof");
        out.field("path", &self.path());
        out.field("terminal", &self.terminal());
        out.field("root", &self.root());
        out.finish()
    }
}
/// Reasons `verify_update` can reject its input.
#[derive(Clone, Copy, Debug)]
pub enum VerifyUpdateError {
    /// Paths are not in strictly ascending order.
    PathsOutOfOrder,
    /// A path's operations are not in strictly ascending key order.
    OpsOutOfOrder,
    /// An operation's key does not start with its path.
    OpOutOfScope,
    /// A path was supplied with no operations.
    PathWithoutOps,
    /// A path's proof was verified against a root other than the previous root.
    RootMismatch,
}
/// A verified path plus the operations to apply to keys beneath it, for use with
/// `verify_update`.
pub struct PathUpdate {
    /// The verified proof of this path against the previous root.
    pub inner: VerifiedPathProof,
    /// Operations on keys that start with `inner.path()`. `verify_update` requires them
    /// to be non-empty and in strictly ascending key order. `None` presumably means
    /// deletion and `Some` a new value hash — TODO confirm against
    /// `update::leaf_ops_spliced`.
    pub ops: Vec<(KeyPath, Option<trie::ValueHash>)>,
}
impl fmt::Debug for PathUpdate {
    // Formats both fields under the struct name `PathUpdate`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { inner, ops } = self;
        let mut out = f.debug_struct("PathUpdate");
        out.field("inner", inner);
        out.field("ops", ops);
        out.finish()
    }
}
/// Verify a batch of path updates against `prev_root` and compute the new root.
///
/// Each [`PathUpdate`] pairs a [`VerifiedPathProof`] (checked against `prev_root`) with
/// the operations to apply beneath that proof's path. For each path, left to right:
/// - its sub-trie is rebuilt from the existing terminal leaf plus its ops;
/// - the result is hashed upward using the proof's siblings;
/// - at layers where an earlier path has already produced a recomputed node, that node
///   replaces the proof's sibling.
///
/// A leaf or terminator paired with a terminator is carried upward instead of being
/// hashed. An empty `paths` slice returns `prev_root` unchanged.
///
/// # Errors
///
/// Checks run per path, in input order, before any hashing:
/// - [`VerifyUpdateError::RootMismatch`] if a proof's root is not `prev_root`;
/// - [`VerifyUpdateError::PathsOutOfOrder`] if paths are not strictly ascending;
/// - [`VerifyUpdateError::PathWithoutOps`] if a path has no operations;
/// - [`VerifyUpdateError::OpsOutOfOrder`] if a path's ops are not strictly ascending;
/// - [`VerifyUpdateError::OpOutOfScope`] if an op's key does not start with its path.
pub fn verify_update<H: NodeHasher>(
    prev_root: Node,
    paths: &[PathUpdate],
) -> Result<Node, VerifyUpdateError> {
    if paths.is_empty() {
        return Ok(prev_root);
    }

    // Validate all input up front so the reconstruction below can assume it is well formed.
    for (i, path) in paths.iter().enumerate() {
        if path.inner.root() != prev_root {
            return Err(VerifyUpdateError::RootMismatch);
        }
        if i != 0 && paths[i - 1].inner.path() >= path.inner.path() {
            return Err(VerifyUpdateError::PathsOutOfOrder);
        }
        if path.ops.is_empty() {
            return Err(VerifyUpdateError::PathWithoutOps);
        }
        for (j, (key, _value)) in path.ops.iter().enumerate() {
            if j != 0 && &path.ops[j - 1].0 >= key {
                return Err(VerifyUpdateError::OpsOutOfOrder);
            }
            if !key.view_bits::<Msb0>().starts_with(path.inner.path()) {
                return Err(VerifyUpdateError::OpOutOfScope);
            }
        }
    }

    // Nodes recomputed by earlier paths. Each is tagged with the layer at which a later
    // path must use it as a sibling. Entries are consumed from the top of the stack.
    let mut pending_siblings: Vec<(Node, usize)> = Vec::new();
    for (i, path) in paths.iter().enumerate() {
        let leaf = path.inner.terminal().cloned();
        let skip = path.inner.path().len();

        // The last path climbs all the way up to layer 0. Any other path stops at layer
        // `n + 1`, where `n` is the number of leading bits it shares with the next path.
        // The next path picks up the result there as its sibling.
        let up_layers = match paths.get(i + 1) {
            None => skip,
            Some(next) => {
                let n = shared_bits(next.inner.path(), path.inner.path());
                // NOTE(review): relies on `n < skip`, i.e. this path is never a prefix of
                // the next. The checks above do not enforce it directly; it presumably
                // follows from both proofs verifying against the same root — confirm.
                skip - (n + 1)
            }
        };

        let ops = crate::update::leaf_ops_spliced(leaf, &path.ops);
        let sub_root = crate::update::build_trie::<H>(skip, ops, |_| {});

        let mut cur_node = sub_root;
        let mut cur_layer = skip;
        let end_layer = skip - up_layers;

        // Walk upward, pairing the deepest path bits with the deepest proof siblings.
        for (bit, sibling) in path
            .inner
            .path()
            .iter()
            .by_vals()
            .rev()
            .take(up_layers)
            .zip(path.inner.siblings.iter().rev())
        {
            // If an earlier path left a recomputed node at this layer, it supersedes the
            // proof's sibling, which reflects `prev_root`.
            let sibling = match pending_siblings.last() {
                Some(&(pending, layer)) if layer == cur_layer => {
                    pending_siblings.pop();
                    pending
                }
                _ => *sibling,
            };

            match (NodeKind::of::<H>(&cur_node), NodeKind::of::<H>(&sibling)) {
                // Two terminators: `cur_node` (a terminator) is carried up unchanged.
                (NodeKind::Terminator, NodeKind::Terminator) => {}
                // A leaf beside a terminator is carried up unchanged.
                (NodeKind::Leaf, NodeKind::Terminator) => {}
                // Mirror case: the sibling leaf is carried up in our place.
                (NodeKind::Terminator, NodeKind::Leaf) => {
                    cur_node = sibling;
                }
                // Anything else becomes an internal node. A set bit means `cur_node` is
                // the right child.
                _ => {
                    let node_data = if bit {
                        trie::InternalData {
                            left: sibling,
                            right: cur_node,
                        }
                    } else {
                        trie::InternalData {
                            left: cur_node,
                            right: sibling,
                        }
                    };
                    cur_node = H::hash_internal(&node_data);
                }
            }
            cur_layer -= 1;
        }
        pending_siblings.push((cur_node, end_layer));
    }

    // `paths` is non-empty and every iteration pushes one entry, so the stack is non-empty.
    let (new_root, _layer) = pending_siblings
        .pop()
        .expect("at least one path was processed");
    Ok(new_root)
}
/// Number of leading bits on which `a` and `b` agree.
///
/// If one slice is a prefix of the other, this is the length of the shorter slice.
pub fn shared_bits(a: &BitSlice<u8, Msb0>, b: &BitSlice<u8, Msb0>) -> usize {
    let shorter_len = a.len().min(b.len());
    a.iter()
        .by_vals()
        .zip(b.iter().by_vals())
        .position(|(x, y)| x != y)
        .unwrap_or(shorter_len)
}