1use std::collections::{BTreeMap, BTreeSet, HashMap};
4
5use tokio::sync::mpsc;
6use zip32::DiversifierIndex;
7
8use orchard::tree::MerkleHashOrchard;
9use shardtree::ShardTree;
10use shardtree::store::memory::MemoryShardStore;
11use shardtree::store::{Checkpoint, ShardStore, TreeState};
12use zcash_keys::keys::UnifiedFullViewingKey;
13use zcash_primitives::consensus::BlockHeight;
14use zcash_primitives::transaction::TxId;
15use zcash_primitives::zip32::AccountId;
16use zcash_protocol::{PoolType, ShieldedProtocol};
17
18use crate::error::{ServerError, SyncError};
19use crate::keys::transparent::TransparentAddressId;
20use crate::sync::{MAX_VERIFICATION_WINDOW, ScanRange};
21use crate::wallet::{
22 NullifierMap, OutputId, ShardTrees, SyncState, WalletBlock, WalletTransaction,
23};
24use crate::witness::LocatedTreeData;
25use crate::{Orchard, Sapling, SyncDomain, client, reset_spends};
26
27use super::{FetchRequest, ScanTarget, witness};
28
/// Core wallet interface required by the sync engine.
///
/// Exposes the wallet birthday, the sync state, the accounts' unified full viewing keys and the
/// wallet's shielded/transparent address bookkeeping. All accessors report failure through the
/// implementor's [`SyncWallet::Error`] type.
pub trait SyncWallet {
    /// Error type returned by every fallible wallet accessor.
    type Error: std::fmt::Debug + std::fmt::Display + std::error::Error;

    /// Returns the wallet's birthday height.
    fn get_birthday(&self) -> Result<BlockHeight, Self::Error>;

    /// Returns a shared reference to the wallet's [`SyncState`].
    fn get_sync_state(&self) -> Result<&SyncState, Self::Error>;

    /// Returns a mutable reference to the wallet's [`SyncState`].
    fn get_sync_state_mut(&mut self) -> Result<&mut SyncState, Self::Error>;

    /// Returns the unified full viewing key of each account, keyed by account ID.
    fn get_unified_full_viewing_keys(
        &self,
    ) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error>;

    /// Records an orchard `address` for `account_id` together with the `diversifier_index`
    /// it was derived at.
    fn add_orchard_address(
        &mut self,
        account_id: zip32::AccountId,
        address: orchard::Address,
        diversifier_index: DiversifierIndex,
    ) -> Result<(), Self::Error>;

    /// Records a sapling `address` for `account_id` together with the `diversifier_index`
    /// it was derived at.
    fn add_sapling_address(
        &mut self,
        account_id: zip32::AccountId,
        address: sapling_crypto::PaymentAddress,
        diversifier_index: DiversifierIndex,
    ) -> Result<(), Self::Error>;

    /// Returns the wallet's transparent addresses keyed by [`TransparentAddressId`].
    ///
    /// NOTE(review): the `String` values are presumably encoded addresses — confirm with
    /// implementors.
    fn get_transparent_addresses(
        &self,
    ) -> Result<&BTreeMap<TransparentAddressId, String>, Self::Error>;

    /// Mutable counterpart of [`SyncWallet::get_transparent_addresses`].
    fn get_transparent_addresses_mut(
        &mut self,
    ) -> Result<&mut BTreeMap<TransparentAddressId, String>, Self::Error>;

    /// Hook for flagging that wallet data changed and should be persisted (intent inferred
    /// from the name; implementors define the exact behavior).
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    fn set_save_flag(&mut self) -> Result<(), Self::Error> {
        Ok(())
    }
}
81
82pub trait SyncBlocks: SyncWallet {
84 fn get_wallet_block(&self, block_height: BlockHeight) -> Result<WalletBlock, Self::Error>;
88
89 fn get_wallet_blocks_mut(
91 &mut self,
92 ) -> Result<&mut BTreeMap<BlockHeight, WalletBlock>, Self::Error>;
93
94 fn append_wallet_blocks(
96 &mut self,
97 mut wallet_blocks: BTreeMap<BlockHeight, WalletBlock>,
98 ) -> Result<(), Self::Error> {
99 self.get_wallet_blocks_mut()?.append(&mut wallet_blocks);
100
101 Ok(())
102 }
103
104 fn truncate_wallet_blocks(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
106 self.get_wallet_blocks_mut()?
107 .retain(|block_height, _| *block_height <= truncate_height);
108
109 Ok(())
110 }
111}
112
113pub trait SyncTransactions: SyncWallet {
115 fn get_wallet_transactions(&self) -> Result<&HashMap<TxId, WalletTransaction>, Self::Error>;
117
118 fn get_wallet_transactions_mut(
120 &mut self,
121 ) -> Result<&mut HashMap<TxId, WalletTransaction>, Self::Error>;
122
123 fn insert_wallet_transaction(
125 &mut self,
126 wallet_transaction: WalletTransaction,
127 ) -> Result<(), Self::Error> {
128 self.get_wallet_transactions_mut()?
129 .insert(wallet_transaction.txid(), wallet_transaction);
130
131 Ok(())
132 }
133
134 fn extend_wallet_transactions(
136 &mut self,
137 wallet_transactions: HashMap<TxId, WalletTransaction>,
138 ) -> Result<(), Self::Error> {
139 self.get_wallet_transactions_mut()?
140 .extend(wallet_transactions);
141
142 Ok(())
143 }
144
145 fn truncate_wallet_transactions(
148 &mut self,
149 truncate_height: BlockHeight,
150 ) -> Result<(), Self::Error> {
151 let invalid_txids: Vec<TxId> = self
152 .get_wallet_transactions()?
153 .values()
154 .filter(|tx| tx.status().is_confirmed_after(&truncate_height))
155 .map(|tx| tx.transaction().txid())
156 .collect();
157
158 let wallet_transactions = self.get_wallet_transactions_mut()?;
159 reset_spends(wallet_transactions, invalid_txids.clone());
160 for invalid_txid in &invalid_txids {
161 wallet_transactions.remove(invalid_txid);
162 }
163
164 Ok(())
165 }
166}
167
168pub trait SyncNullifiers: SyncWallet {
170 fn get_nullifiers(&self) -> Result<&NullifierMap, Self::Error>;
172
173 fn get_nullifiers_mut(&mut self) -> Result<&mut NullifierMap, Self::Error>;
175
176 fn append_nullifiers(&mut self, nullifiers: &mut NullifierMap) -> Result<(), Self::Error> {
178 self.get_nullifiers_mut()?
179 .sapling
180 .append(&mut nullifiers.sapling);
181 self.get_nullifiers_mut()?
182 .orchard
183 .append(&mut nullifiers.orchard);
184
185 Ok(())
186 }
187
188 fn truncate_nullifiers(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
190 let nullifier_map = self.get_nullifiers_mut()?;
191 nullifier_map
192 .sapling
193 .retain(|_, scan_target| scan_target.block_height <= truncate_height);
194 nullifier_map
195 .orchard
196 .retain(|_, scan_target| scan_target.block_height <= truncate_height);
197
198 Ok(())
199 }
200}
201
202pub trait SyncOutPoints: SyncWallet {
204 fn get_outpoints(&self) -> Result<&BTreeMap<OutputId, ScanTarget>, Self::Error>;
206
207 fn get_outpoints_mut(&mut self) -> Result<&mut BTreeMap<OutputId, ScanTarget>, Self::Error>;
209
210 fn append_outpoints(
212 &mut self,
213 outpoints: &mut BTreeMap<OutputId, ScanTarget>,
214 ) -> Result<(), Self::Error> {
215 self.get_outpoints_mut()?.append(outpoints);
216
217 Ok(())
218 }
219
220 fn truncate_outpoints(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
222 self.get_outpoints_mut()?
223 .retain(|_, scan_target| scan_target.block_height <= truncate_height);
224
225 Ok(())
226 }
227}
228
229pub trait SyncShardTrees: SyncWallet {
231 fn get_shard_trees(&self) -> Result<&ShardTrees, Self::Error>;
233
234 fn get_shard_trees_mut(&mut self) -> Result<&mut ShardTrees, Self::Error>;
236
237 fn update_shard_trees(
241 &mut self,
242 fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
243 scan_range: &ScanRange,
244 highest_scanned_height: BlockHeight,
245 sapling_located_trees: Vec<LocatedTreeData<sapling_crypto::Node>>,
246 orchard_located_trees: Vec<LocatedTreeData<MerkleHashOrchard>>,
247 ) -> impl std::future::Future<Output = Result<(), SyncError<Self::Error>>> + Send
248 where
249 Self: std::marker::Send,
250 {
251 async move {
252 let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
253
254 let checkpoint_range = if scan_range.block_range().start > highest_scanned_height {
258 let verification_window_start = scan_range
259 .block_range()
260 .end
261 .saturating_sub(MAX_VERIFICATION_WINDOW);
262
263 std::cmp::max(scan_range.block_range().start, verification_window_start)
264 ..scan_range.block_range().end
265 } else if scan_range.block_range().end
266 > highest_scanned_height.saturating_sub(MAX_VERIFICATION_WINDOW) + 1
267 {
268 let verification_window_start =
269 highest_scanned_height.saturating_sub(MAX_VERIFICATION_WINDOW) + 1;
270
271 std::cmp::max(scan_range.block_range().start, verification_window_start)
272 ..scan_range.block_range().end
273 } else {
274 BlockHeight::from_u32(0)..BlockHeight::from_u32(0)
275 };
276
277 for checkpoint_height in
281 u32::from(checkpoint_range.start)..u32::from(checkpoint_range.end)
282 {
283 let checkpoint_height = BlockHeight::from_u32(checkpoint_height);
284
285 add_checkpoint::<
286 Sapling,
287 sapling_crypto::Node,
288 { sapling_crypto::NOTE_COMMITMENT_TREE_DEPTH },
289 { witness::SHARD_HEIGHT },
290 >(
291 fetch_request_sender.clone(),
292 checkpoint_height,
293 &sapling_located_trees,
294 &mut shard_trees.sapling,
295 )
296 .await?;
297 add_checkpoint::<
298 Orchard,
299 MerkleHashOrchard,
300 { orchard::NOTE_COMMITMENT_TREE_DEPTH as u8 },
301 { witness::SHARD_HEIGHT },
302 >(
303 fetch_request_sender.clone(),
304 checkpoint_height,
305 &orchard_located_trees,
306 &mut shard_trees.orchard,
307 )
308 .await?;
309 }
310
311 for tree in sapling_located_trees {
312 shard_trees
313 .sapling
314 .insert_tree(tree.subtree, tree.checkpoints)?;
315 }
316 for tree in orchard_located_trees {
317 shard_trees
318 .orchard
319 .insert_tree(tree.subtree, tree.checkpoints)?;
320 }
321
322 Ok(())
323 }
324 }
325
326 fn truncate_shard_trees(
330 &mut self,
331 truncate_height: BlockHeight,
332 ) -> Result<(), SyncError<Self::Error>> {
333 if truncate_height == zcash_protocol::consensus::H0 {
334 let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
335 shard_trees.sapling =
336 ShardTree::new(MemoryShardStore::empty(), MAX_VERIFICATION_WINDOW as usize);
337 shard_trees.orchard =
338 ShardTree::new(MemoryShardStore::empty(), MAX_VERIFICATION_WINDOW as usize);
339 } else {
340 if !self
341 .get_shard_trees_mut()
342 .map_err(SyncError::WalletError)?
343 .sapling
344 .truncate_to_checkpoint(&truncate_height)?
345 {
346 return Err(SyncError::TruncationError(
347 truncate_height,
348 PoolType::SAPLING,
349 ));
350 }
351 if !self
352 .get_shard_trees_mut()
353 .map_err(SyncError::WalletError)?
354 .orchard
355 .truncate_to_checkpoint(&truncate_height)?
356 {
357 return Err(SyncError::TruncationError(
358 truncate_height,
359 PoolType::ORCHARD,
360 ));
361 }
362 }
363
364 Ok(())
365 }
366}
367
368async fn add_checkpoint<D, L, const DEPTH: u8, const SHARD_HEIGHT: u8>(
370 fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
371 checkpoint_height: BlockHeight,
372 located_trees: &[LocatedTreeData<L>],
373 shard_tree: &mut shardtree::ShardTree<
374 shardtree::store::memory::MemoryShardStore<L, BlockHeight>,
375 DEPTH,
376 SHARD_HEIGHT,
377 >,
378) -> Result<(), ServerError>
379where
380 L: Clone + PartialEq + incrementalmerkletree::Hashable,
381 D: SyncDomain,
382{
383 let checkpoint = if let Some((_, position)) = located_trees
384 .iter()
385 .flat_map(|tree| tree.checkpoints.iter())
386 .find(|(height, _)| **height == checkpoint_height)
387 {
388 Checkpoint::at_position(*position)
389 } else {
390 let mut previous_checkpoint = None;
391 shard_tree
392 .store()
393 .for_each_checkpoint(1_000, |height, checkpoint| {
394 if *height == checkpoint_height - 1 {
395 previous_checkpoint = Some(checkpoint.clone());
396 }
397 Ok(())
398 })
399 .expect("infallible");
400
401 let tree_state = if let Some(checkpoint) = previous_checkpoint {
402 checkpoint.tree_state()
403 } else {
404 let frontiers =
405 client::get_frontiers(fetch_request_sender.clone(), checkpoint_height).await?;
406 let tree_size = match D::SHIELDED_PROTOCOL {
407 ShieldedProtocol::Sapling => frontiers.final_sapling_tree().tree_size(),
408 ShieldedProtocol::Orchard => frontiers.final_orchard_tree().tree_size(),
409 };
410 if tree_size == 0 {
411 TreeState::Empty
412 } else {
413 TreeState::AtPosition(incrementalmerkletree::Position::from(tree_size - 1))
414 }
415 };
416
417 Checkpoint::from_parts(tree_state, BTreeSet::new())
418 };
419
420 shard_tree
421 .store_mut()
422 .add_checkpoint(checkpoint_height, checkpoint)
423 .expect("infallible");
424
425 Ok(())
426}