use std::collections::BTreeSet;
use std::fmt;
use incrementalmerkletree::{Hashable, Level, Retention};
use pasta_curves::Fp;
use shardtree::{store::memory::MemoryShardStore, ShardTree};
use crate::hash::{MerkleHashVote, MAX_CHECKPOINTS, SHARD_HEIGHT, TREE_DEPTH};
use crate::path::MerklePath;
use crate::sync_api::TreeSyncApi;
/// Errors returned by [`TreeClient::sync`].
///
/// `E` is the error type produced by the [`TreeSyncApi`] implementation being synced from.
#[derive(Debug)]
pub enum SyncError<E: fmt::Debug> {
    /// A call to the sync API failed.
    Api(E),
    /// A non-empty block's `start_index` did not equal the local tree's next leaf position.
    StartIndexMismatch {
        /// Height of the offending block.
        height: u32,
        /// Next leaf position the local tree expected.
        expected: u64,
        /// `start_index` the server reported for the block.
        got: u64,
    },
    /// The locally computed root differs from the root reported by the server.
    RootMismatch {
        /// Height at which the two roots were compared.
        height: u32,
        /// Local root at that height, or `None` if the local tree could not produce one.
        local: Option<Fp>,
        /// Root reported by the server.
        server: Fp,
    },
    /// The local leaf count does not match the server's next leaf index.
    IncompleteSync {
        /// Local tree's next leaf position (number of leaves appended).
        local_next_index: u64,
        /// Next leaf index reported by the server.
        server_next_index: u64,
    },
    /// The server returned a pagination cursor that does not strictly advance, or that
    /// points past the requested height range.
    InvalidPagination {
        /// Height the rejected page was requested from.
        current: u32,
        /// Cursor returned by the server.
        next: u32,
    },
}

impl<E: fmt::Debug> fmt::Display for SyncError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Api(e) => write!(f, "sync API error: {:?}", e),
            Self::StartIndexMismatch {
                height,
                expected,
                got,
            } => write!(
                f,
                "start_index mismatch at height {}: expected {}, got {}",
                height, expected, got
            ),
            Self::RootMismatch {
                height,
                local,
                server,
            } => write!(
                f,
                "root mismatch at height {}: local={:?}, server={:?}",
                height, local, server
            ),
            Self::IncompleteSync {
                local_next_index,
                server_next_index,
            } => write!(
                f,
                "incomplete sync: local next_index={}, server next_index={}",
                local_next_index, server_next_index
            ),
            Self::InvalidPagination { current, next } => write!(
                f,
                "invalid pagination cursor: current={}, next={}",
                current, next
            ),
        }
    }
}

/// Lets `SyncError` be used wherever a `std::error::Error` is expected.
impl<E: fmt::Debug> std::error::Error for SyncError<E> {}

/// Wraps a raw API error so `?` can be used on sync API calls.
impl<E: fmt::Debug> From<E> for SyncError<E> {
    fn from(err: E) -> Self {
        SyncError::Api(err)
    }
}
pub struct TreeClient {
inner: ShardTree<MemoryShardStore<MerkleHashVote, u32>, { TREE_DEPTH as u8 }, { SHARD_HEIGHT }>,
next_position: u64,
last_synced_height: Option<u32>,
marked_positions: BTreeSet<u64>,
}
impl TreeClient {
pub fn empty() -> Self {
Self {
inner: ShardTree::new(MemoryShardStore::empty(), MAX_CHECKPOINTS),
next_position: 0,
last_synced_height: None,
marked_positions: BTreeSet::new(),
}
}
pub fn mark_position(&mut self, position: u64) {
self.marked_positions.insert(position);
}
pub fn sync<A: TreeSyncApi>(&mut self, api: &A) -> Result<(), SyncError<A::Error>> {
let state = api.get_tree_state()?;
if state.next_index == self.next_position {
if state.next_index > 0 {
let local = self.root();
if local != state.root {
return Err(SyncError::RootMismatch {
height: state.height,
local: Some(local),
server: state.root,
});
}
}
return Ok(());
}
let from_height = self.last_synced_height.map(|h| h + 1).unwrap_or(0);
let to_height = state.height;
if from_height > to_height {
return Ok(()); }
let mut page_from = from_height;
loop {
let page = api.get_block_commitments(page_from, to_height)?;
for block in &page.blocks {
if !block.leaves.is_empty() && block.start_index != self.next_position {
return Err(SyncError::StartIndexMismatch {
height: block.height,
expected: self.next_position,
got: block.start_index,
});
}
for leaf in &block.leaves {
let retention = if self.marked_positions.contains(&self.next_position) {
Retention::Marked
} else {
Retention::Ephemeral
};
self.inner
.append(*leaf, retention)
.expect("append must succeed (tree not full)");
self.next_position += 1;
}
self.inner
.checkpoint(block.height)
.expect("checkpoint must succeed");
self.last_synced_height = Some(block.height);
let local = self.root_at_height(block.height);
if local != Some(block.root) {
return Err(SyncError::RootMismatch {
height: block.height,
local,
server: block.root,
});
}
}
if page.next_from_height == 0 {
break;
}
if page.next_from_height <= page_from {
return Err(SyncError::InvalidPagination {
current: page_from,
next: page.next_from_height,
});
}
if page.next_from_height > to_height {
return Err(SyncError::InvalidPagination {
current: page_from,
next: page.next_from_height,
});
}
page_from = page.next_from_height;
}
if self.next_position != state.next_index {
return Err(SyncError::IncompleteSync {
local_next_index: self.next_position,
server_next_index: state.next_index,
});
}
if state.next_index > 0 {
let local = self.root();
if local != state.root {
return Err(SyncError::RootMismatch {
height: state.height,
local: Some(local),
server: state.root,
});
}
}
Ok(())
}
pub fn witness(&self, position: u64, anchor_height: u32) -> Option<MerklePath> {
let pos = incrementalmerkletree::Position::from(position);
self.inner
.witness_at_checkpoint_id(pos, &anchor_height)
.ok()
.flatten()
.map(MerklePath::from)
}
pub fn root_at_height(&self, height: u32) -> Option<Fp> {
self.inner
.root_at_checkpoint_id(&height)
.ok()
.flatten()
.map(|h| h.0)
}
pub fn root(&self) -> Fp {
if let Some(id) = self.last_synced_height {
self.inner
.root_at_checkpoint_id(&id)
.ok()
.flatten()
.map(|h| h.0)
.unwrap_or_else(|| MerkleHashVote::empty_root(Level::from(TREE_DEPTH as u8)).0)
} else {
MerkleHashVote::empty_root(Level::from(TREE_DEPTH as u8)).0
}
}
pub fn size(&self) -> u64 {
self.next_position
}
pub fn last_synced_height(&self) -> Option<u32> {
self.last_synced_height
}
}