1use std::collections::{BTreeMap, BTreeSet, HashMap};
4
5use orchard::tree::MerkleHashOrchard;
6use shardtree::store::{Checkpoint, ShardStore, TreeState};
7use tokio::sync::mpsc;
8use zcash_client_backend::data_api::scanning::ScanRange;
9use zcash_client_backend::keys::UnifiedFullViewingKey;
10use zcash_primitives::consensus::BlockHeight;
11use zcash_primitives::transaction::TxId;
12use zcash_primitives::zip32::AccountId;
13use zcash_protocol::ShieldedProtocol;
14
15use crate::error::{ServerError, SyncError};
16use crate::keys::transparent::TransparentAddressId;
17use crate::sync::MAX_VERIFICATION_WINDOW;
18use crate::wallet::{
19 Locator, NullifierMap, OutputId, ShardTrees, SyncState, WalletBlock, WalletTransaction,
20};
21use crate::witness::LocatedTreeData;
22use crate::{Orchard, Sapling, SyncDomain, client};
23
24use super::{FetchRequest, witness};
25
26pub trait SyncWallet {
28 type Error: std::fmt::Debug + std::fmt::Display + std::error::Error;
30
31 fn get_birthday(&self) -> Result<BlockHeight, Self::Error>;
33
34 fn get_sync_state(&self) -> Result<&SyncState, Self::Error>;
36
37 fn get_sync_state_mut(&mut self) -> Result<&mut SyncState, Self::Error>;
39
40 fn get_unified_full_viewing_keys(
42 &self,
43 ) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error>;
44
45 fn get_transparent_addresses(
47 &self,
48 ) -> Result<&BTreeMap<TransparentAddressId, String>, Self::Error>;
49
50 fn get_transparent_addresses_mut(
52 &mut self,
53 ) -> Result<&mut BTreeMap<TransparentAddressId, String>, Self::Error>;
54
55 fn set_save_flag(&mut self) -> Result<(), Self::Error> {
59 Ok(())
60 }
61}
62
63pub trait SyncBlocks: SyncWallet {
65 fn get_wallet_block(&self, block_height: BlockHeight) -> Result<WalletBlock, Self::Error>;
69
70 fn get_wallet_blocks_mut(
72 &mut self,
73 ) -> Result<&mut BTreeMap<BlockHeight, WalletBlock>, Self::Error>;
74
75 fn append_wallet_blocks(
77 &mut self,
78 mut wallet_blocks: BTreeMap<BlockHeight, WalletBlock>,
79 ) -> Result<(), Self::Error> {
80 self.get_wallet_blocks_mut()?.append(&mut wallet_blocks);
81
82 Ok(())
83 }
84
85 fn truncate_wallet_blocks(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
87 self.get_wallet_blocks_mut()?
88 .retain(|block_height, _| *block_height <= truncate_height);
89
90 Ok(())
91 }
92}
93
94pub trait SyncTransactions: SyncWallet {
96 fn get_wallet_transactions(&self) -> Result<&HashMap<TxId, WalletTransaction>, Self::Error>;
98
99 fn get_wallet_transactions_mut(
101 &mut self,
102 ) -> Result<&mut HashMap<TxId, WalletTransaction>, Self::Error>;
103
104 fn insert_wallet_transaction(
106 &mut self,
107 wallet_transaction: WalletTransaction,
108 ) -> Result<(), Self::Error> {
109 self.get_wallet_transactions_mut()?
110 .insert(wallet_transaction.txid(), wallet_transaction);
111
112 Ok(())
113 }
114
115 fn extend_wallet_transactions(
117 &mut self,
118 wallet_transactions: HashMap<TxId, WalletTransaction>,
119 ) -> Result<(), Self::Error> {
120 self.get_wallet_transactions_mut()?
121 .extend(wallet_transactions);
122
123 Ok(())
124 }
125
126 fn truncate_wallet_transactions(
129 &mut self,
130 truncate_height: BlockHeight,
131 ) -> Result<(), Self::Error> {
132 let invalid_txids: Vec<TxId> = self
134 .get_wallet_transactions()?
135 .values()
136 .filter(|tx| tx.status().is_confirmed_after(&truncate_height))
137 .map(|tx| tx.transaction().txid())
138 .collect();
139
140 let wallet_transactions = self.get_wallet_transactions_mut()?;
141 wallet_transactions
142 .values_mut()
143 .flat_map(|tx| tx.sapling_notes_mut())
144 .filter(|note| {
145 note.spending_transaction.map_or_else(
146 || false,
147 |spending_txid| invalid_txids.contains(&spending_txid),
148 )
149 })
150 .for_each(|note| {
151 note.spending_transaction = None;
152 });
153 wallet_transactions
154 .values_mut()
155 .flat_map(|tx| tx.orchard_notes_mut())
156 .filter(|note| {
157 note.spending_transaction.map_or_else(
158 || false,
159 |spending_txid| invalid_txids.contains(&spending_txid),
160 )
161 })
162 .for_each(|note| {
163 note.spending_transaction = None;
164 });
165
166 invalid_txids.iter().for_each(|invalid_txid| {
167 wallet_transactions.remove(invalid_txid);
168 });
169
170 Ok(())
171 }
172}
173
174pub trait SyncNullifiers: SyncWallet {
176 fn get_nullifiers(&self) -> Result<&NullifierMap, Self::Error>;
178
179 fn get_nullifiers_mut(&mut self) -> Result<&mut NullifierMap, Self::Error>;
181
182 fn append_nullifiers(&mut self, mut nullifiers: NullifierMap) -> Result<(), Self::Error> {
184 self.get_nullifiers_mut()?
185 .sapling
186 .append(&mut nullifiers.sapling);
187 self.get_nullifiers_mut()?
188 .orchard
189 .append(&mut nullifiers.orchard);
190
191 Ok(())
192 }
193
194 fn truncate_nullifiers(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
196 let nullifier_map = self.get_nullifiers_mut()?;
197 nullifier_map
198 .sapling
199 .retain(|_, (block_height, _)| *block_height <= truncate_height);
200 nullifier_map
201 .orchard
202 .retain(|_, (block_height, _)| *block_height <= truncate_height);
203
204 Ok(())
205 }
206}
207
208pub trait SyncOutPoints: SyncWallet {
210 fn get_outpoints(&self) -> Result<&BTreeMap<OutputId, Locator>, Self::Error>;
212
213 fn get_outpoints_mut(&mut self) -> Result<&mut BTreeMap<OutputId, Locator>, Self::Error>;
215
216 fn append_outpoints(
218 &mut self,
219 outpoints: &mut BTreeMap<OutputId, Locator>,
220 ) -> Result<(), Self::Error> {
221 self.get_outpoints_mut()?.append(outpoints);
222
223 Ok(())
224 }
225
226 fn truncate_outpoints(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
228 self.get_outpoints_mut()?
229 .retain(|_, (block_height, _)| *block_height <= truncate_height);
230
231 Ok(())
232 }
233}
234
235pub trait SyncShardTrees: SyncWallet {
237 fn get_shard_trees(&self) -> Result<&ShardTrees, Self::Error>;
239
240 fn get_shard_trees_mut(&mut self) -> Result<&mut ShardTrees, Self::Error>;
242
243 fn update_shard_trees(
245 &mut self,
246 fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
247 scan_range: &ScanRange,
248 wallet_height: BlockHeight,
249 sapling_located_trees: Vec<LocatedTreeData<sapling_crypto::Node>>,
250 orchard_located_trees: Vec<LocatedTreeData<MerkleHashOrchard>>,
251 ) -> impl std::future::Future<Output = Result<(), SyncError<Self::Error>>> + Send
252 where
253 Self: std::marker::Send,
254 {
255 async move {
256 let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
257
258 let checkpoint_range = match (
262 scan_range.block_range().start
263 > wallet_height.saturating_sub(MAX_VERIFICATION_WINDOW),
264 scan_range.block_range().end - 1
265 > wallet_height.saturating_sub(MAX_VERIFICATION_WINDOW),
266 ) {
267 (true, _) => scan_range.block_range().clone(),
268 (false, true) => {
269 (wallet_height - MAX_VERIFICATION_WINDOW)..scan_range.block_range().end
270 }
271 (false, false) => BlockHeight::from_u32(0)..BlockHeight::from_u32(0),
272 };
273
274 for checkpoint_height in
278 u32::from(checkpoint_range.start)..u32::from(checkpoint_range.end)
279 {
280 let checkpoint_height = BlockHeight::from_u32(checkpoint_height);
281
282 add_checkpoint::<
283 Sapling,
284 sapling_crypto::Node,
285 { sapling_crypto::NOTE_COMMITMENT_TREE_DEPTH },
286 { witness::SHARD_HEIGHT },
287 >(
288 fetch_request_sender.clone(),
289 checkpoint_height,
290 &sapling_located_trees,
291 &mut shard_trees.sapling,
292 )
293 .await?;
294 add_checkpoint::<
295 Orchard,
296 MerkleHashOrchard,
297 { orchard::NOTE_COMMITMENT_TREE_DEPTH as u8 },
298 { witness::SHARD_HEIGHT },
299 >(
300 fetch_request_sender.clone(),
301 checkpoint_height,
302 &orchard_located_trees,
303 &mut shard_trees.orchard,
304 )
305 .await?;
306 }
307
308 for tree in sapling_located_trees.into_iter() {
309 shard_trees
310 .sapling
311 .insert_tree(tree.subtree, tree.checkpoints)?;
312 }
313 for tree in orchard_located_trees.into_iter() {
314 shard_trees
315 .orchard
316 .insert_tree(tree.subtree, tree.checkpoints)?;
317 }
318
319 Ok(())
320 }
321 }
322
323 fn truncate_shard_trees(
325 &mut self,
326 truncate_height: BlockHeight,
327 ) -> Result<(), SyncError<Self::Error>> {
328 if !self
329 .get_shard_trees_mut()
330 .map_err(SyncError::WalletError)?
331 .sapling
332 .truncate_to_checkpoint(&truncate_height)?
333 {
334 panic!("max checkpoints should always be higher or equal to max verification window!");
335 }
336 if !self
337 .get_shard_trees_mut()
338 .map_err(SyncError::WalletError)?
339 .orchard
340 .truncate_to_checkpoint(&truncate_height)?
341 {
342 panic!("max checkpoints should always be higher or equal to max verification window!");
343 }
344
345 Ok(())
346 }
347}
348
349async fn add_checkpoint<D, L, const DEPTH: u8, const SHARD_HEIGHT: u8>(
350 fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
351 checkpoint_height: BlockHeight,
352 located_trees: &[LocatedTreeData<L>],
353 shard_tree: &mut shardtree::ShardTree<
354 shardtree::store::memory::MemoryShardStore<L, BlockHeight>,
355 DEPTH,
356 SHARD_HEIGHT,
357 >,
358) -> Result<(), ServerError>
359where
360 L: Clone + PartialEq + incrementalmerkletree::Hashable,
361 D: SyncDomain,
362{
363 let checkpoint = if let Some((_, position)) = located_trees
364 .iter()
365 .flat_map(|tree| tree.checkpoints.iter())
366 .find(|(height, _)| **height == checkpoint_height)
367 {
368 Checkpoint::at_position(*position)
369 } else {
370 let mut previous_checkpoint = None;
371 shard_tree
372 .store()
373 .for_each_checkpoint(1_000, |height, checkpoint| {
374 if *height == checkpoint_height - 1 {
375 previous_checkpoint = Some(checkpoint.clone());
376 }
377 Ok(())
378 })
379 .expect("infallible");
380
381 let tree_state = if let Some(checkpoint) = previous_checkpoint {
382 checkpoint.tree_state()
383 } else {
384 let frontiers =
385 client::get_frontiers(fetch_request_sender.clone(), checkpoint_height - 1).await?;
386 let tree_size = match D::SHIELDED_PROTOCOL {
387 ShieldedProtocol::Sapling => frontiers.final_sapling_tree().tree_size(),
388 ShieldedProtocol::Orchard => frontiers.final_orchard_tree().tree_size(),
389 };
390 if tree_size == 0 {
391 TreeState::Empty
392 } else {
393 TreeState::AtPosition(incrementalmerkletree::Position::from(tree_size - 1))
394 }
395 };
396
397 Checkpoint::from_parts(tree_state, BTreeSet::new())
398 };
399
400 shard_tree
401 .store_mut()
402 .add_checkpoint(checkpoint_height, checkpoint)
403 .expect("infallible");
404
405 Ok(())
406}