use std::collections::BTreeMap;
use incrementalmerkletree::{Hashable, Level, Retention};
use pasta_curves::{group::ff::PrimeField, Fp};
use shardtree::{
error::ShardTreeError,
store::{memory::MemoryShardStore, ShardStore},
ShardTree,
};
use crate::hash::{MerkleHashVote, MAX_CHECKPOINTS, SHARD_HEIGHT, TREE_DEPTH};
use crate::kv_shard_store::{KvCallbacks, KvError, KvShardStore};
use crate::path::MerklePath;
use crate::sync_api::BlockCommitments;
/// Merkle tree server that is generic over the shard store `S`.
///
/// Bundles a [`ShardTree`] with two pieces of bookkeeping: the id of the most
/// recent checkpoint and the position that the next appended leaf will get.
pub struct GenericTreeServer<S>
where
    S: shardtree::store::ShardStore<H = MerkleHashVote, CheckpointId = u32>,
{
    /// Underlying shard tree, parameterised by `TREE_DEPTH` and `SHARD_HEIGHT`.
    pub(crate) inner: ShardTree<S, { TREE_DEPTH as u8 }, { SHARD_HEIGHT }>,
    /// Id of the latest known checkpoint, or `None` if there is none yet.
    pub(crate) latest_checkpoint: Option<u32>,
    /// Index that will be handed out to the next appended leaf.
    pub(crate) next_position: u64,
}
/// A [`GenericTreeServer`] whose shard store is the KV-backed [`KvShardStore`].
pub type TreeServer = GenericTreeServer<KvShardStore>;
/// Wrapper around a [`GenericTreeServer`] that additionally records, per
/// checkpoint height, the leaves appended since the previous checkpoint.
pub struct SyncableServer<S>
where
    S: ShardStore<H = MerkleHashVote, CheckpointId = u32>,
{
    /// The wrapped tree server that performs the actual tree operations.
    pub(crate) tree: GenericTreeServer<S>,
    /// Recorded block commitments, keyed by the checkpoint height they were
    /// created at.
    pub(crate) blocks: BTreeMap<u32, BlockCommitments>,
    /// Leaves appended through this wrapper since the last checkpoint.
    pending_leaves: Vec<MerkleHashVote>,
    /// Start index that will be stored in the next recorded block.
    pending_start: u64,
}
/// A [`SyncableServer`] backed by an in-memory [`MemoryShardStore`] with `u32`
/// checkpoint ids.
pub type MemoryTreeServer = SyncableServer<MemoryShardStore<MerkleHashVote, u32>>;
/// Errors that `TreeServer::append_from_kv` can return while loading leaves
/// from the KV store and appending them to the tree.
#[derive(Debug)]
pub enum AppendFromKvError {
/// The KV read itself failed.
Kv(KvError),
/// No value was stored for the leaf at the given index.
MissingLeaf(u64),
/// The stored value for the leaf at the given index was not 32 bytes long
/// or was not a canonical `Fp` encoding.
MalformedLeaf(u64),
/// Appending the leaf to the shard tree failed.
Tree(ShardTreeError<KvError>),
}
impl From<KvError> for AppendFromKvError {
fn from(e: KvError) -> Self {
AppendFromKvError::Kv(e)
}
}
impl std::fmt::Display for AppendFromKvError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AppendFromKvError::Kv(e) => write!(f, "KV error reading leaf: {}", e),
AppendFromKvError::MissingLeaf(i) => {
write!(f, "leaf at index {} is missing from KV", i)
}
AppendFromKvError::MalformedLeaf(i) => {
write!(
f,
"leaf at index {} is malformed (wrong length or non-canonical Fp)",
i
)
}
AppendFromKvError::Tree(e) => write!(f, "ShardTree error: {:?}", e),
}
}
}
/// Marks [`AppendFromKvError`] as a standard error type. The trait's default
/// methods are used, so no `source()` is reported.
impl std::error::Error for AppendFromKvError {}
/// Errors returned when creating a checkpoint, generic over the shard store's
/// error type `E`.
#[derive(Debug)]
pub enum CheckpointError<E> {
/// The requested checkpoint height was not strictly greater than the
/// previous checkpoint height `prev`.
NotMonotonic { prev: u32, requested: u32 },
/// The underlying shard tree reported an error.
Tree(ShardTreeError<E>),
}
impl<E: std::fmt::Debug> std::fmt::Display for CheckpointError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CheckpointError::NotMonotonic { prev, requested } => write!(
f,
"checkpoint height must be strictly increasing: {} <= {}",
requested, prev
),
CheckpointError::Tree(e) => write!(f, "ShardTree error: {:?}", e),
}
}
}
/// Marks [`CheckpointError`] as a standard error type for any `'static`,
/// `Debug` store error `E`. The trait's default methods are used.
impl<E: std::fmt::Debug + 'static> std::error::Error for CheckpointError<E> {}
impl<E> From<ShardTreeError<E>> for CheckpointError<E> {
fn from(e: ShardTreeError<E>) -> Self {
CheckpointError::Tree(e)
}
}
impl TreeServer {
    /// Builds a KV-backed tree server.
    ///
    /// `cb` supplies the KV callbacks used to construct the [`KvShardStore`];
    /// `next_position` is the index the next appended leaf will receive.
    ///
    /// The latest checkpoint is seeded from the store's maximum checkpoint id.
    /// If that lookup fails, the error is discarded and the server starts with
    /// no known checkpoint.
    pub fn new(cb: KvCallbacks, next_position: u64) -> Self {
        let store = KvShardStore::new(cb);
        // Must be read before `store` is moved into the tree below.
        let latest_checkpoint = store.max_checkpoint_id().ok().flatten();
        let inner = ShardTree::new(store, MAX_CHECKPOINTS);
        Self {
            inner,
            latest_checkpoint,
            next_position,
        }
    }
}
impl TreeServer {
    /// Reads the leaves with indices `cursor .. cursor + count` from the KV
    /// store and appends each one to the tree, in index order.
    ///
    /// The end of the range saturates at `u64::MAX` instead of overflowing, so
    /// an oversized `count` can neither panic nor silently wrap around into an
    /// empty range that would report success without appending anything.
    ///
    /// # Errors
    ///
    /// * [`AppendFromKvError::Kv`] if a KV read fails.
    /// * [`AppendFromKvError::MissingLeaf`] if no value is stored for an index.
    /// * [`AppendFromKvError::MalformedLeaf`] if a stored value is not exactly
    ///   32 bytes or is not a canonical `Fp` representation.
    /// * [`AppendFromKvError::Tree`] if appending to the shard tree fails.
    ///
    /// Processing stops at the first error; leaves appended before it remain
    /// in the tree.
    pub fn append_from_kv(&mut self, cursor: u64, count: u64) -> Result<(), AppendFromKvError> {
        // `cursor + count` could overflow u64; clamp the exclusive upper bound.
        let end = cursor.saturating_add(count);
        for i in cursor..end {
            let key = Self::leaf_key(i);
            let blob = self
                .inner
                .store()
                .cb
                .get(&key)?
                .ok_or(AppendFromKvError::MissingLeaf(i))?;
            // Guard the length first: `copy_from_slice` panics on a mismatch.
            if blob.len() != 32 {
                return Err(AppendFromKvError::MalformedLeaf(i));
            }
            let mut repr = [0u8; 32];
            repr.copy_from_slice(&blob);
            // `from_repr` yields no value for non-canonical encodings.
            let fp: Option<Fp> = Fp::from_repr(repr).into();
            let fp = fp.ok_or(AppendFromKvError::MalformedLeaf(i))?;
            self.append(fp).map_err(AppendFromKvError::Tree)?;
        }
        Ok(())
    }

    /// Builds the 9-byte KV key for the leaf at `index`: the prefix byte
    /// `0x02` followed by the index as 8 big-endian bytes.
    fn leaf_key(index: u64) -> [u8; 9] {
        let mut k = [0u8; 9];
        k[0] = 0x02;
        k[1..].copy_from_slice(&index.to_be_bytes());
        k
    }
}
impl SyncableServer<MemoryShardStore<MerkleHashVote, u32>> {
    /// Creates an in-memory server with an empty tree, no checkpoints, no
    /// recorded blocks, and no pending leaves.
    pub fn empty() -> Self {
        let inner = ShardTree::new(MemoryShardStore::empty(), MAX_CHECKPOINTS);
        let tree = GenericTreeServer {
            inner,
            latest_checkpoint: None,
            next_position: 0,
        };
        Self {
            tree,
            blocks: BTreeMap::new(),
            pending_leaves: Vec::new(),
            pending_start: 0,
        }
    }
}
impl<S> SyncableServer<S>
where
    S: ShardStore<H = MerkleHashVote, CheckpointId = u32>,
    S::Error: std::fmt::Debug,
{
    /// Wraps an existing tree server, starting with no recorded blocks and no
    /// pending leaves.
    ///
    /// The start index for the first recorded block is taken from the tree's
    /// current `next_position`, so wrapping a tree that already contains
    /// leaves yields a correct `start_index` for leaves appended afterwards.
    pub fn new(tree: GenericTreeServer<S>) -> Self {
        // Same invariant `checkpoint` maintains: `pending_start` is the
        // position of the first leaf that will land in the next block.
        let pending_start = tree.next_position;
        Self {
            tree,
            blocks: BTreeMap::new(),
            pending_leaves: Vec::new(),
            pending_start,
        }
    }

    /// Appends `leaf` to the tree and remembers its hash as pending for the
    /// next checkpoint. Returns the index assigned to the leaf.
    ///
    /// # Errors
    ///
    /// Propagates the shard tree error; in that case nothing is recorded as
    /// pending.
    pub fn append(&mut self, leaf: Fp) -> Result<u64, ShardTreeError<S::Error>> {
        let hash = MerkleHashVote::from_fp(leaf);
        let idx = self.tree.append(leaf)?;
        // Only record the leaf once the tree has accepted it.
        self.pending_leaves.push(hash);
        Ok(idx)
    }

    /// Appends `first` and then `second`, returning the index of `first`.
    ///
    /// # Errors
    ///
    /// Stops at the first failing append; if `second` fails, `first` has
    /// already been appended.
    pub fn append_two(&mut self, first: Fp, second: Fp) -> Result<u64, ShardTreeError<S::Error>> {
        let idx = self.append(first)?;
        self.append(second)?;
        Ok(idx)
    }

    /// Checkpoints the tree at `height` and records a [`BlockCommitments`]
    /// entry for that height holding the pending leaves, their start index,
    /// and the root.
    ///
    /// The root is looked up at `height`; if that lookup yields nothing, the
    /// tree's current root is used instead. After recording, the pending
    /// leaves are cleared and the start index advances to the tree's next
    /// position.
    ///
    /// # Errors
    ///
    /// Returns the tree's [`CheckpointError`] without touching the pending
    /// state or the recorded blocks.
    pub fn checkpoint(&mut self, height: u32) -> Result<(), CheckpointError<S::Error>> {
        self.tree.checkpoint(height)?;
        let root = self
            .root_at_height(height)
            .unwrap_or_else(|| self.tree.root());
        let commitments = BlockCommitments {
            height,
            start_index: self.pending_start,
            // Moves the pending leaves out, leaving an empty vector behind.
            leaves: std::mem::take(&mut self.pending_leaves),
            root,
        };
        self.blocks.insert(height, commitments);
        self.pending_start = self.tree.next_position;
        Ok(())
    }

    /// Returns the wrapped tree's current root.
    pub fn root(&self) -> Fp {
        self.tree.root()
    }

    /// Returns the wrapped tree's root at checkpoint `height`, if available.
    pub fn root_at_height(&self, height: u32) -> Option<Fp> {
        self.tree.root_at_height(height)
    }

    /// Returns the wrapped tree's size (its next leaf position).
    pub fn size(&self) -> u64 {
        self.tree.size()
    }

    /// Returns the Merkle path for `position` anchored at checkpoint
    /// `anchor_height`, if the wrapped tree can produce one.
    pub fn path(&self, position: u64, anchor_height: u32) -> Option<MerklePath> {
        self.tree.path(position, anchor_height)
    }

    /// Returns the id of the wrapped tree's latest checkpoint, if any.
    pub fn latest_checkpoint(&self) -> Option<u32> {
        self.tree.latest_checkpoint
    }
}
impl<S> GenericTreeServer<S>
where
    S: shardtree::store::ShardStore<H = MerkleHashVote, CheckpointId = u32>,
    S::Error: std::fmt::Debug,
{
    /// Appends `leaf` to the tree as a marked leaf and returns the index it
    /// was assigned. The position counter only advances on success.
    pub fn append(&mut self, leaf: Fp) -> Result<u64, ShardTreeError<S::Error>> {
        let assigned = self.next_position;
        self.inner
            .append(MerkleHashVote::from_fp(leaf), Retention::Marked)?;
        self.next_position = assigned + 1;
        Ok(assigned)
    }

    /// Appends `first` followed by `second` and returns the index of `first`.
    /// If the second append fails, the first one has already taken effect.
    pub fn append_two(&mut self, first: Fp, second: Fp) -> Result<u64, ShardTreeError<S::Error>> {
        let first_index = self.append(first)?;
        self.append(second).map(|_| first_index)
    }

    /// Creates a checkpoint with id `height`.
    ///
    /// # Errors
    ///
    /// * [`CheckpointError::NotMonotonic`] if a previous checkpoint exists and
    ///   `height` is not strictly greater than it.
    /// * [`CheckpointError::Tree`] if the shard tree fails to checkpoint.
    pub fn checkpoint(&mut self, height: u32) -> Result<(), CheckpointError<S::Error>> {
        // Reject heights that do not advance past the previous checkpoint.
        if let Some(prev) = self.latest_checkpoint.filter(|&prev| height <= prev) {
            return Err(CheckpointError::NotMonotonic {
                prev,
                requested: height,
            });
        }
        self.inner.checkpoint(height)?;
        self.latest_checkpoint = Some(height);
        Ok(())
    }

    /// Returns the root at the latest checkpoint.
    ///
    /// Falls back to the empty root at `TREE_DEPTH` when there is no
    /// checkpoint yet, or when the root lookup fails or yields nothing.
    pub fn root(&self) -> Fp {
        self.latest_checkpoint
            .and_then(|id| self.inner.root_at_checkpoint_id(&id).ok().flatten())
            .map(|node| node.0)
            .unwrap_or_else(|| MerkleHashVote::empty_root(Level::from(TREE_DEPTH as u8)).0)
    }

    /// Returns the root at checkpoint `height`, or `None` if the lookup fails
    /// or yields nothing.
    pub fn root_at_height(&self, height: u32) -> Option<Fp> {
        match self.inner.root_at_checkpoint_id(&height) {
            Ok(Some(node)) => Some(node.0),
            _ => None,
        }
    }

    /// Returns the next leaf position, which serves as the tree size.
    pub fn size(&self) -> u64 {
        self.next_position
    }

    /// Overrides the position that the next appended leaf will be assigned.
    pub fn set_next_position(&mut self, pos: u64) {
        self.next_position = pos;
    }

    /// Returns the Merkle path for the leaf at `position` as of checkpoint
    /// `anchor_height`, or `None` if the witness lookup fails or yields
    /// nothing.
    pub fn path(&self, position: u64, anchor_height: u32) -> Option<MerklePath> {
        let leaf_position = incrementalmerkletree::Position::from(position);
        match self
            .inner
            .witness_at_checkpoint_id(leaf_position, &anchor_height)
        {
            Ok(Some(witness)) => Some(MerklePath::from(witness)),
            _ => None,
        }
    }
}