use std::collections::{BTreeMap, BTreeSet, HashMap};
use tokio::sync::mpsc;
use zip32::DiversifierIndex;
use orchard::tree::MerkleHashOrchard;
use shardtree::ShardTree;
use shardtree::store::memory::MemoryShardStore;
use shardtree::store::{Checkpoint, ShardStore, TreeState};
use zcash_keys::keys::UnifiedFullViewingKey;
use zcash_primitives::transaction::TxId;
use zcash_protocol::consensus::BlockHeight;
use zcash_protocol::{PoolType, ShieldedProtocol};
use zip32::AccountId;
use crate::error::{ServerError, SyncError};
use crate::keys::transparent::TransparentAddressId;
use crate::sync::{MAX_REORG_ALLOWANCE, ScanRange};
use crate::wallet::{
NullifierMap, OutputId, ShardTrees, SyncState, WalletBlock, WalletTransaction,
};
use crate::witness::LocatedTreeData;
use crate::{Orchard, Sapling, SyncDomain, client, set_transactions_failed};
use super::{FetchRequest, ScanTarget, witness};
/// Core wallet interface required by the sync engine.
///
/// Exposes the wallet birthday, sync state, per-account viewing keys and address book.
/// Every accessor returns a `Result` carrying the implementor's [`SyncWallet::Error`].
pub trait SyncWallet {
    /// Error type returned by every fallible wallet operation.
    type Error: std::fmt::Debug + std::fmt::Display + std::error::Error;

    /// Returns the wallet birthday height.
    fn get_birthday(&self) -> Result<BlockHeight, Self::Error>;

    /// Returns a shared reference to the wallet's [`SyncState`].
    fn get_sync_state(&self) -> Result<&SyncState, Self::Error>;

    /// Returns a mutable reference to the wallet's [`SyncState`].
    fn get_sync_state_mut(&mut self) -> Result<&mut SyncState, Self::Error>;

    /// Returns the unified full viewing key of each account, keyed by account id.
    fn get_unified_full_viewing_keys(
        &self,
    ) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error>;

    /// Adds an orchard `address`, together with its `diversifier_index`, to `account_id`.
    fn add_orchard_address(
        &mut self,
        account_id: AccountId,
        address: orchard::Address,
        diversifier_index: DiversifierIndex,
    ) -> Result<(), Self::Error>;

    /// Adds a sapling `address`, together with its `diversifier_index`, to `account_id`.
    fn add_sapling_address(
        &mut self,
        account_id: AccountId,
        address: sapling_crypto::PaymentAddress,
        diversifier_index: DiversifierIndex,
    ) -> Result<(), Self::Error>;

    /// Returns the wallet's transparent addresses, keyed by [`TransparentAddressId`].
    fn get_transparent_addresses(
        &self,
    ) -> Result<&BTreeMap<TransparentAddressId, String>, Self::Error>;

    /// Returns a mutable reference to the wallet's transparent address map.
    fn get_transparent_addresses_mut(
        &mut self,
    ) -> Result<&mut BTreeMap<TransparentAddressId, String>, Self::Error>;

    /// Sets the wallet's save flag.
    ///
    /// The default implementation is a no-op that always succeeds. Implementors may override
    /// it (presumably to mark the wallet as needing persistence; confirm against implementors).
    fn set_save_flag(&mut self) -> Result<(), Self::Error> {
        Ok(())
    }
}
/// Storage of scanned [`WalletBlock`]s, ordered by block height.
pub trait SyncBlocks: SyncWallet {
    /// Returns the wallet block at `block_height`.
    fn get_wallet_block(&self, block_height: BlockHeight) -> Result<WalletBlock, Self::Error>;

    /// Returns a mutable reference to the height-ordered map of wallet blocks.
    fn get_wallet_blocks_mut(
        &mut self,
    ) -> Result<&mut BTreeMap<BlockHeight, WalletBlock>, Self::Error>;

    /// Moves every block in `wallet_blocks` into the wallet's block map.
    ///
    /// A block at a height that is already stored replaces the stored one
    /// (`BTreeMap::append` semantics).
    fn append_wallet_blocks(
        &mut self,
        mut wallet_blocks: BTreeMap<BlockHeight, WalletBlock>,
    ) -> Result<(), Self::Error> {
        self.get_wallet_blocks_mut()
            .map(|stored_blocks| stored_blocks.append(&mut wallet_blocks))
    }

    /// Discards every block above `truncate_height`. Blocks at or below it are kept.
    fn truncate_wallet_blocks(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
        self.get_wallet_blocks_mut().map(|stored_blocks| {
            stored_blocks.retain(|&height, _| height <= truncate_height)
        })
    }
}
/// Storage of [`WalletTransaction`]s, keyed by [`TxId`].
pub trait SyncTransactions: SyncWallet {
    /// Returns a shared reference to the wallet's transaction map.
    fn get_wallet_transactions(&self) -> Result<&HashMap<TxId, WalletTransaction>, Self::Error>;

    /// Returns a mutable reference to the wallet's transaction map.
    fn get_wallet_transactions_mut(
        &mut self,
    ) -> Result<&mut HashMap<TxId, WalletTransaction>, Self::Error>;

    /// Stores `wallet_transaction` under its own txid, replacing any entry with the same txid.
    fn insert_wallet_transaction(
        &mut self,
        wallet_transaction: WalletTransaction,
    ) -> Result<(), Self::Error> {
        let transactions = self.get_wallet_transactions_mut()?;
        let txid = wallet_transaction.txid();
        transactions.insert(txid, wallet_transaction);
        Ok(())
    }

    /// Adds every entry of `wallet_transactions`, replacing entries with the same key.
    fn extend_wallet_transactions(
        &mut self,
        wallet_transactions: HashMap<TxId, WalletTransaction>,
    ) -> Result<(), Self::Error> {
        self.get_wallet_transactions_mut()
            .map(|transactions| transactions.extend(wallet_transactions))
    }

    /// Passes every transaction whose status is confirmed after `truncate_height` to
    /// `set_transactions_failed`.
    fn truncate_wallet_transactions(
        &mut self,
        truncate_height: BlockHeight,
    ) -> Result<(), Self::Error> {
        let mut invalid_txids: Vec<TxId> = Vec::new();
        for transaction in self.get_wallet_transactions()?.values() {
            if transaction.status().is_confirmed_after(&truncate_height) {
                invalid_txids.push(transaction.transaction().txid());
            }
        }

        set_transactions_failed(self.get_wallet_transactions_mut()?, invalid_txids);
        Ok(())
    }
}
/// Storage of the wallet's sapling and orchard nullifier maps.
pub trait SyncNullifiers: SyncWallet {
    /// Returns a shared reference to the wallet's [`NullifierMap`].
    fn get_nullifiers(&self) -> Result<&NullifierMap, Self::Error>;

    /// Returns a mutable reference to the wallet's [`NullifierMap`].
    fn get_nullifiers_mut(&mut self) -> Result<&mut NullifierMap, Self::Error>;

    /// Moves all sapling and orchard entries from `nullifiers` into the wallet's map.
    ///
    /// `nullifiers` is left with both pools empty.
    fn append_nullifiers(&mut self, nullifiers: &mut NullifierMap) -> Result<(), Self::Error> {
        let stored = self.get_nullifiers_mut()?;
        stored.sapling.append(&mut nullifiers.sapling);
        stored.orchard.append(&mut nullifiers.orchard);
        Ok(())
    }

    /// Drops every nullifier whose scan target lies above `truncate_height`, in both pools.
    fn truncate_nullifiers(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
        self.get_nullifiers_mut().map(|stored| {
            stored
                .sapling
                .retain(|_, target| target.block_height <= truncate_height);
            stored
                .orchard
                .retain(|_, target| target.block_height <= truncate_height);
        })
    }
}
/// Storage of transparent outpoints, each mapped to the [`ScanTarget`] that located it.
pub trait SyncOutPoints: SyncWallet {
    /// Returns a shared reference to the wallet's outpoint map.
    fn get_outpoints(&self) -> Result<&BTreeMap<OutputId, ScanTarget>, Self::Error>;

    /// Returns a mutable reference to the wallet's outpoint map.
    fn get_outpoints_mut(&mut self) -> Result<&mut BTreeMap<OutputId, ScanTarget>, Self::Error>;

    /// Moves every entry of `outpoints` into the wallet's map, leaving `outpoints` empty.
    fn append_outpoints(
        &mut self,
        outpoints: &mut BTreeMap<OutputId, ScanTarget>,
    ) -> Result<(), Self::Error> {
        self.get_outpoints_mut()
            .map(|stored| stored.append(outpoints))
    }

    /// Drops every outpoint whose scan target lies above `truncate_height`.
    fn truncate_outpoints(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
        self.get_outpoints_mut().map(|stored| {
            stored.retain(|_, target| target.block_height <= truncate_height)
        })
    }
}
pub trait SyncShardTrees: SyncWallet {
fn get_shard_trees(&self) -> Result<&ShardTrees, Self::Error>;
fn get_shard_trees_mut(&mut self) -> Result<&mut ShardTrees, Self::Error>;
fn update_shard_trees(
&mut self,
fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
scan_range: &ScanRange,
highest_scanned_height: BlockHeight,
sapling_located_trees: Vec<LocatedTreeData<sapling_crypto::Node>>,
orchard_located_trees: Vec<LocatedTreeData<MerkleHashOrchard>>,
) -> impl std::future::Future<Output = Result<(), SyncError<Self::Error>>> + Send
where
Self: std::marker::Send,
{
async move {
let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
let checkpoint_range = if scan_range.block_range().start > highest_scanned_height {
let verification_window_start = scan_range
.block_range()
.end
.saturating_sub(MAX_REORG_ALLOWANCE);
std::cmp::max(scan_range.block_range().start, verification_window_start)
..scan_range.block_range().end
} else if scan_range.block_range().end
> highest_scanned_height.saturating_sub(MAX_REORG_ALLOWANCE) + 1
{
let verification_window_start =
highest_scanned_height.saturating_sub(MAX_REORG_ALLOWANCE) + 1;
std::cmp::max(scan_range.block_range().start, verification_window_start)
..scan_range.block_range().end
} else {
BlockHeight::from_u32(0)..BlockHeight::from_u32(0)
};
for checkpoint_height in
u32::from(checkpoint_range.start)..u32::from(checkpoint_range.end)
{
let checkpoint_height = BlockHeight::from_u32(checkpoint_height);
add_checkpoint::<
Sapling,
sapling_crypto::Node,
{ sapling_crypto::NOTE_COMMITMENT_TREE_DEPTH },
{ witness::SHARD_HEIGHT },
>(
fetch_request_sender.clone(),
checkpoint_height,
&sapling_located_trees,
&mut shard_trees.sapling,
)
.await?;
add_checkpoint::<
Orchard,
MerkleHashOrchard,
{ orchard::NOTE_COMMITMENT_TREE_DEPTH as u8 },
{ witness::SHARD_HEIGHT },
>(
fetch_request_sender.clone(),
checkpoint_height,
&orchard_located_trees,
&mut shard_trees.orchard,
)
.await?;
}
for tree in sapling_located_trees {
shard_trees
.sapling
.insert_tree(tree.subtree, tree.checkpoints)?;
}
for tree in orchard_located_trees {
shard_trees
.orchard
.insert_tree(tree.subtree, tree.checkpoints)?;
}
Ok(())
}
}
fn truncate_shard_trees(
&mut self,
truncate_height: BlockHeight,
) -> Result<(), SyncError<Self::Error>> {
if truncate_height == zcash_protocol::consensus::H0 {
let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
tracing::info!("Clearing shard trees.");
shard_trees.sapling =
ShardTree::new(MemoryShardStore::empty(), MAX_REORG_ALLOWANCE as usize);
shard_trees.orchard =
ShardTree::new(MemoryShardStore::empty(), MAX_REORG_ALLOWANCE as usize);
} else {
if !self
.get_shard_trees_mut()
.map_err(SyncError::WalletError)?
.sapling
.truncate_to_checkpoint(&truncate_height)?
{
tracing::error!("Sapling shard tree is broken! Beginning rescan.");
return Err(SyncError::TruncationError(
truncate_height,
PoolType::SAPLING,
));
}
if !self
.get_shard_trees_mut()
.map_err(SyncError::WalletError)?
.orchard
.truncate_to_checkpoint(&truncate_height)?
{
tracing::error!("Sapling shard tree is broken! Beginning rescan.");
return Err(SyncError::TruncationError(
truncate_height,
PoolType::ORCHARD,
));
}
}
Ok(())
}
}
async fn add_checkpoint<D, L, const DEPTH: u8, const SHARD_HEIGHT: u8>(
fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
checkpoint_height: BlockHeight,
located_trees: &[LocatedTreeData<L>],
shard_tree: &mut shardtree::ShardTree<
shardtree::store::memory::MemoryShardStore<L, BlockHeight>,
DEPTH,
SHARD_HEIGHT,
>,
) -> Result<(), ServerError>
where
L: Clone + PartialEq + incrementalmerkletree::Hashable,
D: SyncDomain,
{
let checkpoint = if let Some((_, position)) = located_trees
.iter()
.flat_map(|tree| tree.checkpoints.iter())
.find(|(height, _)| **height == checkpoint_height)
{
Checkpoint::at_position(*position)
} else {
let mut previous_checkpoint = None;
shard_tree
.store()
.for_each_checkpoint(1_000, |height, checkpoint| {
if *height == checkpoint_height - 1 {
previous_checkpoint = Some(checkpoint.clone());
}
Ok(())
})
.expect("infallible");
let tree_state = if let Some(checkpoint) = previous_checkpoint {
checkpoint.tree_state()
} else {
let frontiers =
client::get_frontiers(fetch_request_sender.clone(), checkpoint_height).await?;
let tree_size = match D::SHIELDED_PROTOCOL {
ShieldedProtocol::Sapling => frontiers.final_sapling_tree().tree_size(),
ShieldedProtocol::Orchard => frontiers.final_orchard_tree().tree_size(),
};
if tree_size == 0 {
TreeState::Empty
} else {
TreeState::AtPosition(incrementalmerkletree::Position::from(tree_size - 1))
}
};
Checkpoint::from_parts(tree_state, BTreeSet::new())
};
shard_tree
.store_mut()
.add_checkpoint(checkpoint_height, checkpoint)
.expect("infallible");
Ok(())
}