1use std::collections::{BTreeMap, BTreeSet, HashMap};
4
5use tokio::sync::mpsc;
6use zip32::DiversifierIndex;
7
8use orchard::tree::MerkleHashOrchard;
9use shardtree::ShardTree;
10use shardtree::store::memory::MemoryShardStore;
11use shardtree::store::{Checkpoint, ShardStore, TreeState};
12use zcash_keys::keys::UnifiedFullViewingKey;
13use zcash_primitives::transaction::TxId;
14use zcash_protocol::consensus::BlockHeight;
15use zcash_protocol::{PoolType, ShieldedProtocol};
16use zip32::AccountId;
17
18use crate::error::{ServerError, SyncError};
19use crate::keys::transparent::TransparentAddressId;
20use crate::sync::{MAX_REORG_ALLOWANCE, ScanRange};
21use crate::wallet::{
22 NullifierMap, OutputId, ShardTrees, SyncState, WalletBlock, WalletTransaction,
23};
24use crate::witness::LocatedTreeData;
25use crate::{Orchard, Sapling, SyncDomain, client, set_transactions_failed};
26
27use super::{FetchRequest, ScanTarget, witness};
28
29pub trait SyncWallet {
31 type Error: std::fmt::Debug + std::fmt::Display + std::error::Error;
33
34 fn get_birthday(&self) -> Result<BlockHeight, Self::Error>;
36
37 fn get_sync_state(&self) -> Result<&SyncState, Self::Error>;
39
40 fn get_sync_state_mut(&mut self) -> Result<&mut SyncState, Self::Error>;
42
43 fn get_unified_full_viewing_keys(
45 &self,
46 ) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error>;
47
48 fn add_orchard_address(
50 &mut self,
51 account_id: zip32::AccountId,
52 address: orchard::Address,
53 diversifier_index: DiversifierIndex,
54 ) -> Result<(), Self::Error>;
55
56 fn add_sapling_address(
58 &mut self,
59 account_id: zip32::AccountId,
60 address: sapling_crypto::PaymentAddress,
61 diversifier_index: DiversifierIndex,
62 ) -> Result<(), Self::Error>;
63
64 fn get_transparent_addresses(
66 &self,
67 ) -> Result<&BTreeMap<TransparentAddressId, String>, Self::Error>;
68
69 fn get_transparent_addresses_mut(
71 &mut self,
72 ) -> Result<&mut BTreeMap<TransparentAddressId, String>, Self::Error>;
73
74 fn set_save_flag(&mut self) -> Result<(), Self::Error> {
78 Ok(())
79 }
80}
81
82pub trait SyncBlocks: SyncWallet {
84 fn get_wallet_block(&self, block_height: BlockHeight) -> Result<WalletBlock, Self::Error>;
88
89 fn get_wallet_blocks_mut(
91 &mut self,
92 ) -> Result<&mut BTreeMap<BlockHeight, WalletBlock>, Self::Error>;
93
94 fn append_wallet_blocks(
96 &mut self,
97 mut wallet_blocks: BTreeMap<BlockHeight, WalletBlock>,
98 ) -> Result<(), Self::Error> {
99 self.get_wallet_blocks_mut()?.append(&mut wallet_blocks);
100
101 Ok(())
102 }
103
104 fn truncate_wallet_blocks(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
106 self.get_wallet_blocks_mut()?
107 .retain(|block_height, _| *block_height <= truncate_height);
108
109 Ok(())
110 }
111}
112
113pub trait SyncTransactions: SyncWallet {
115 fn get_wallet_transactions(&self) -> Result<&HashMap<TxId, WalletTransaction>, Self::Error>;
117
118 fn get_wallet_transactions_mut(
120 &mut self,
121 ) -> Result<&mut HashMap<TxId, WalletTransaction>, Self::Error>;
122
123 fn insert_wallet_transaction(
125 &mut self,
126 wallet_transaction: WalletTransaction,
127 ) -> Result<(), Self::Error> {
128 self.get_wallet_transactions_mut()?
129 .insert(wallet_transaction.txid(), wallet_transaction);
130
131 Ok(())
132 }
133
134 fn extend_wallet_transactions(
136 &mut self,
137 wallet_transactions: HashMap<TxId, WalletTransaction>,
138 ) -> Result<(), Self::Error> {
139 self.get_wallet_transactions_mut()?
140 .extend(wallet_transactions);
141
142 Ok(())
143 }
144
145 fn truncate_wallet_transactions(
149 &mut self,
150 truncate_height: BlockHeight,
151 ) -> Result<(), Self::Error> {
152 let invalid_txids: Vec<TxId> = self
153 .get_wallet_transactions()?
154 .values()
155 .filter(|tx| tx.status().is_confirmed_after(&truncate_height))
156 .map(|tx| tx.transaction().txid())
157 .collect();
158
159 set_transactions_failed(self.get_wallet_transactions_mut()?, invalid_txids);
160
161 Ok(())
162 }
163}
164
165pub trait SyncNullifiers: SyncWallet {
167 fn get_nullifiers(&self) -> Result<&NullifierMap, Self::Error>;
169
170 fn get_nullifiers_mut(&mut self) -> Result<&mut NullifierMap, Self::Error>;
172
173 fn append_nullifiers(&mut self, nullifiers: &mut NullifierMap) -> Result<(), Self::Error> {
175 self.get_nullifiers_mut()?
176 .sapling
177 .append(&mut nullifiers.sapling);
178 self.get_nullifiers_mut()?
179 .orchard
180 .append(&mut nullifiers.orchard);
181
182 Ok(())
183 }
184
185 fn truncate_nullifiers(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
187 let nullifier_map = self.get_nullifiers_mut()?;
188 nullifier_map
189 .sapling
190 .retain(|_, scan_target| scan_target.block_height <= truncate_height);
191 nullifier_map
192 .orchard
193 .retain(|_, scan_target| scan_target.block_height <= truncate_height);
194
195 Ok(())
196 }
197}
198
199pub trait SyncOutPoints: SyncWallet {
201 fn get_outpoints(&self) -> Result<&BTreeMap<OutputId, ScanTarget>, Self::Error>;
203
204 fn get_outpoints_mut(&mut self) -> Result<&mut BTreeMap<OutputId, ScanTarget>, Self::Error>;
206
207 fn append_outpoints(
209 &mut self,
210 outpoints: &mut BTreeMap<OutputId, ScanTarget>,
211 ) -> Result<(), Self::Error> {
212 self.get_outpoints_mut()?.append(outpoints);
213
214 Ok(())
215 }
216
217 fn truncate_outpoints(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
219 self.get_outpoints_mut()?
220 .retain(|_, scan_target| scan_target.block_height <= truncate_height);
221
222 Ok(())
223 }
224}
225
226pub trait SyncShardTrees: SyncWallet {
228 fn get_shard_trees(&self) -> Result<&ShardTrees, Self::Error>;
230
231 fn get_shard_trees_mut(&mut self) -> Result<&mut ShardTrees, Self::Error>;
233
234 fn update_shard_trees(
238 &mut self,
239 fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
240 scan_range: &ScanRange,
241 highest_scanned_height: BlockHeight,
242 sapling_located_trees: Vec<LocatedTreeData<sapling_crypto::Node>>,
243 orchard_located_trees: Vec<LocatedTreeData<MerkleHashOrchard>>,
244 ) -> impl std::future::Future<Output = Result<(), SyncError<Self::Error>>> + Send
245 where
246 Self: std::marker::Send,
247 {
248 async move {
249 let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
250
251 let checkpoint_range = if scan_range.block_range().start > highest_scanned_height {
255 let verification_window_start = scan_range
256 .block_range()
257 .end
258 .saturating_sub(MAX_REORG_ALLOWANCE);
259
260 std::cmp::max(scan_range.block_range().start, verification_window_start)
261 ..scan_range.block_range().end
262 } else if scan_range.block_range().end
263 > highest_scanned_height.saturating_sub(MAX_REORG_ALLOWANCE) + 1
264 {
265 let verification_window_start =
266 highest_scanned_height.saturating_sub(MAX_REORG_ALLOWANCE) + 1;
267
268 std::cmp::max(scan_range.block_range().start, verification_window_start)
269 ..scan_range.block_range().end
270 } else {
271 BlockHeight::from_u32(0)..BlockHeight::from_u32(0)
272 };
273
274 for checkpoint_height in
278 u32::from(checkpoint_range.start)..u32::from(checkpoint_range.end)
279 {
280 let checkpoint_height = BlockHeight::from_u32(checkpoint_height);
281
282 add_checkpoint::<
283 Sapling,
284 sapling_crypto::Node,
285 { sapling_crypto::NOTE_COMMITMENT_TREE_DEPTH },
286 { witness::SHARD_HEIGHT },
287 >(
288 fetch_request_sender.clone(),
289 checkpoint_height,
290 &sapling_located_trees,
291 &mut shard_trees.sapling,
292 )
293 .await?;
294 add_checkpoint::<
295 Orchard,
296 MerkleHashOrchard,
297 { orchard::NOTE_COMMITMENT_TREE_DEPTH as u8 },
298 { witness::SHARD_HEIGHT },
299 >(
300 fetch_request_sender.clone(),
301 checkpoint_height,
302 &orchard_located_trees,
303 &mut shard_trees.orchard,
304 )
305 .await?;
306 }
307
308 for tree in sapling_located_trees {
309 shard_trees
310 .sapling
311 .insert_tree(tree.subtree, tree.checkpoints)?;
312 }
313 for tree in orchard_located_trees {
314 shard_trees
315 .orchard
316 .insert_tree(tree.subtree, tree.checkpoints)?;
317 }
318
319 Ok(())
320 }
321 }
322
323 fn truncate_shard_trees(
327 &mut self,
328 truncate_height: BlockHeight,
329 ) -> Result<(), SyncError<Self::Error>> {
330 if truncate_height == zcash_protocol::consensus::H0 {
331 let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
332 tracing::info!("Clearing shard trees.");
333 shard_trees.sapling =
334 ShardTree::new(MemoryShardStore::empty(), MAX_REORG_ALLOWANCE as usize);
335 shard_trees.orchard =
336 ShardTree::new(MemoryShardStore::empty(), MAX_REORG_ALLOWANCE as usize);
337 } else {
338 if !self
339 .get_shard_trees_mut()
340 .map_err(SyncError::WalletError)?
341 .sapling
342 .truncate_to_checkpoint(&truncate_height)?
343 {
344 tracing::error!("Sapling shard tree is broken! Beginning rescan.");
345 return Err(SyncError::TruncationError(
346 truncate_height,
347 PoolType::SAPLING,
348 ));
349 }
350 if !self
351 .get_shard_trees_mut()
352 .map_err(SyncError::WalletError)?
353 .orchard
354 .truncate_to_checkpoint(&truncate_height)?
355 {
356 tracing::error!("Sapling shard tree is broken! Beginning rescan.");
357 return Err(SyncError::TruncationError(
358 truncate_height,
359 PoolType::ORCHARD,
360 ));
361 }
362 }
363
364 Ok(())
365 }
366}
367
368async fn add_checkpoint<D, L, const DEPTH: u8, const SHARD_HEIGHT: u8>(
370 fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
371 checkpoint_height: BlockHeight,
372 located_trees: &[LocatedTreeData<L>],
373 shard_tree: &mut shardtree::ShardTree<
374 shardtree::store::memory::MemoryShardStore<L, BlockHeight>,
375 DEPTH,
376 SHARD_HEIGHT,
377 >,
378) -> Result<(), ServerError>
379where
380 L: Clone + PartialEq + incrementalmerkletree::Hashable,
381 D: SyncDomain,
382{
383 let checkpoint = if let Some((_, position)) = located_trees
384 .iter()
385 .flat_map(|tree| tree.checkpoints.iter())
386 .find(|(height, _)| **height == checkpoint_height)
387 {
388 Checkpoint::at_position(*position)
389 } else {
390 let mut previous_checkpoint = None;
391 shard_tree
392 .store()
393 .for_each_checkpoint(1_000, |height, checkpoint| {
394 if *height == checkpoint_height - 1 {
395 previous_checkpoint = Some(checkpoint.clone());
396 }
397 Ok(())
398 })
399 .expect("infallible");
400
401 let tree_state = if let Some(checkpoint) = previous_checkpoint {
402 checkpoint.tree_state()
403 } else {
404 let frontiers =
405 client::get_frontiers(fetch_request_sender.clone(), checkpoint_height).await?;
406 let tree_size = match D::SHIELDED_PROTOCOL {
407 ShieldedProtocol::Sapling => frontiers.final_sapling_tree().tree_size(),
408 ShieldedProtocol::Orchard => frontiers.final_orchard_tree().tree_size(),
409 };
410 if tree_size == 0 {
411 TreeState::Empty
412 } else {
413 TreeState::AtPosition(incrementalmerkletree::Position::from(tree_size - 1))
414 }
415 };
416
417 Checkpoint::from_parts(tree_state, BTreeSet::new())
418 };
419
420 shard_tree
421 .store_mut()
422 .add_checkpoint(checkpoint_height, checkpoint)
423 .expect("infallible");
424
425 Ok(())
426}