pepper_sync/wallet/
traits.rs

1//! Traits for interfacing a wallet with the sync engine
2
3use std::collections::{BTreeMap, BTreeSet, HashMap};
4
5use tokio::sync::mpsc;
6use zip32::DiversifierIndex;
7
8use orchard::tree::MerkleHashOrchard;
9use shardtree::ShardTree;
10use shardtree::store::memory::MemoryShardStore;
11use shardtree::store::{Checkpoint, ShardStore, TreeState};
12use zcash_keys::keys::UnifiedFullViewingKey;
13use zcash_primitives::consensus::BlockHeight;
14use zcash_primitives::transaction::TxId;
15use zcash_primitives::zip32::AccountId;
16use zcash_protocol::{PoolType, ShieldedProtocol};
17
18use crate::error::{ServerError, SyncError};
19use crate::keys::transparent::TransparentAddressId;
20use crate::sync::{MAX_VERIFICATION_WINDOW, ScanRange};
21use crate::wallet::{
22    NullifierMap, OutputId, ShardTrees, SyncState, WalletBlock, WalletTransaction,
23};
24use crate::witness::LocatedTreeData;
25use crate::{Orchard, Sapling, SyncDomain, client, reset_spends};
26
27use super::{FetchRequest, ScanTarget, witness};
28
29/// Trait for interfacing wallet with the sync engine.
30pub trait SyncWallet {
31    /// Errors associated with interfacing the sync engine with wallet data
32    type Error: std::fmt::Debug + std::fmt::Display + std::error::Error;
33
34    /// Returns the block height wallet was created.
35    fn get_birthday(&self) -> Result<BlockHeight, Self::Error>;
36
37    /// Returns a reference to wallet sync state.
38    fn get_sync_state(&self) -> Result<&SyncState, Self::Error>;
39
40    /// Returns a mutable reference to wallet sync state.
41    fn get_sync_state_mut(&mut self) -> Result<&mut SyncState, Self::Error>;
42
43    /// Returns all unified full viewing keys known to this wallet.
44    fn get_unified_full_viewing_keys(
45        &self,
46    ) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error>;
47
48    /// Add orchard address to wallet's unified address list.
49    fn add_orchard_address(
50        &mut self,
51        account_id: zip32::AccountId,
52        address: orchard::Address,
53        diversifier_index: DiversifierIndex,
54    ) -> Result<(), Self::Error>;
55
56    /// Add sapling address to wallet's unified address list.
57    fn add_sapling_address(
58        &mut self,
59        account_id: zip32::AccountId,
60        address: sapling_crypto::PaymentAddress,
61        diversifier_index: DiversifierIndex,
62    ) -> Result<(), Self::Error>;
63
64    /// Returns a reference to all transparent addresses known to this wallet.
65    fn get_transparent_addresses(
66        &self,
67    ) -> Result<&BTreeMap<TransparentAddressId, String>, Self::Error>;
68
69    /// Returns a mutable reference to all transparent addresses known to this wallet.
70    fn get_transparent_addresses_mut(
71        &mut self,
72    ) -> Result<&mut BTreeMap<TransparentAddressId, String>, Self::Error>;
73
74    /// Aids in-memory wallets to only save when the wallet state has changed by setting a flag to mark that save is
75    /// required.
76    /// Persitance wallets may use the default implementation.
77    fn set_save_flag(&mut self) -> Result<(), Self::Error> {
78        Ok(())
79    }
80}
81
82/// Trait for interfacing [`crate::wallet::WalletBlock`]s with wallet data
83pub trait SyncBlocks: SyncWallet {
84    /// Get a stored wallet compact block from wallet data by block height
85    ///
86    /// Must return error if block is not found
87    fn get_wallet_block(&self, block_height: BlockHeight) -> Result<WalletBlock, Self::Error>;
88
89    /// Get mutable reference to wallet blocks
90    fn get_wallet_blocks_mut(
91        &mut self,
92    ) -> Result<&mut BTreeMap<BlockHeight, WalletBlock>, Self::Error>;
93
94    /// Append wallet compact blocks to wallet data
95    fn append_wallet_blocks(
96        &mut self,
97        mut wallet_blocks: BTreeMap<BlockHeight, WalletBlock>,
98    ) -> Result<(), Self::Error> {
99        self.get_wallet_blocks_mut()?.append(&mut wallet_blocks);
100
101        Ok(())
102    }
103
104    /// Removes all wallet blocks above the given `block_height`.
105    fn truncate_wallet_blocks(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
106        self.get_wallet_blocks_mut()?
107            .retain(|block_height, _| *block_height <= truncate_height);
108
109        Ok(())
110    }
111}
112
113/// Trait for interfacing [`crate::wallet::WalletTransaction`]s with wallet data
114pub trait SyncTransactions: SyncWallet {
115    /// Get reference to wallet transactions
116    fn get_wallet_transactions(&self) -> Result<&HashMap<TxId, WalletTransaction>, Self::Error>;
117
118    /// Get mutable reference to wallet transactions
119    fn get_wallet_transactions_mut(
120        &mut self,
121    ) -> Result<&mut HashMap<TxId, WalletTransaction>, Self::Error>;
122
123    /// Insert wallet transaction
124    fn insert_wallet_transaction(
125        &mut self,
126        wallet_transaction: WalletTransaction,
127    ) -> Result<(), Self::Error> {
128        self.get_wallet_transactions_mut()?
129            .insert(wallet_transaction.txid(), wallet_transaction);
130
131        Ok(())
132    }
133
134    /// Extend wallet transaction map with new wallet transactions
135    fn extend_wallet_transactions(
136        &mut self,
137        wallet_transactions: HashMap<TxId, WalletTransaction>,
138    ) -> Result<(), Self::Error> {
139        self.get_wallet_transactions_mut()?
140            .extend(wallet_transactions);
141
142        Ok(())
143    }
144
145    /// Removes all confirmed wallet transactions above the given `block_height`.
146    /// Also sets any output's `spending_transaction` field to `None` if it's spending transaction was removed.
147    fn truncate_wallet_transactions(
148        &mut self,
149        truncate_height: BlockHeight,
150    ) -> Result<(), Self::Error> {
151        let invalid_txids: Vec<TxId> = self
152            .get_wallet_transactions()?
153            .values()
154            .filter(|tx| tx.status().is_confirmed_after(&truncate_height))
155            .map(|tx| tx.transaction().txid())
156            .collect();
157
158        let wallet_transactions = self.get_wallet_transactions_mut()?;
159        reset_spends(wallet_transactions, invalid_txids.clone());
160        for invalid_txid in &invalid_txids {
161            wallet_transactions.remove(invalid_txid);
162        }
163
164        Ok(())
165    }
166}
167
168/// Trait for interfacing nullifiers with wallet data
169pub trait SyncNullifiers: SyncWallet {
170    /// Get wallet nullifier map
171    fn get_nullifiers(&self) -> Result<&NullifierMap, Self::Error>;
172
173    /// Get mutable reference to wallet nullifier map
174    fn get_nullifiers_mut(&mut self) -> Result<&mut NullifierMap, Self::Error>;
175
176    /// Append nullifiers to wallet nullifier map
177    fn append_nullifiers(&mut self, nullifiers: &mut NullifierMap) -> Result<(), Self::Error> {
178        self.get_nullifiers_mut()?
179            .sapling
180            .append(&mut nullifiers.sapling);
181        self.get_nullifiers_mut()?
182            .orchard
183            .append(&mut nullifiers.orchard);
184
185        Ok(())
186    }
187
188    /// Removes all mapped nullifiers above the given `block_height`.
189    fn truncate_nullifiers(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
190        let nullifier_map = self.get_nullifiers_mut()?;
191        nullifier_map
192            .sapling
193            .retain(|_, scan_target| scan_target.block_height <= truncate_height);
194        nullifier_map
195            .orchard
196            .retain(|_, scan_target| scan_target.block_height <= truncate_height);
197
198        Ok(())
199    }
200}
201
202/// Trait for interfacing outpoints with wallet data
203pub trait SyncOutPoints: SyncWallet {
204    /// Get wallet outpoint map
205    fn get_outpoints(&self) -> Result<&BTreeMap<OutputId, ScanTarget>, Self::Error>;
206
207    /// Get mutable reference to wallet outpoint map
208    fn get_outpoints_mut(&mut self) -> Result<&mut BTreeMap<OutputId, ScanTarget>, Self::Error>;
209
210    /// Append outpoints to wallet outpoint map
211    fn append_outpoints(
212        &mut self,
213        outpoints: &mut BTreeMap<OutputId, ScanTarget>,
214    ) -> Result<(), Self::Error> {
215        self.get_outpoints_mut()?.append(outpoints);
216
217        Ok(())
218    }
219
220    /// Removes all mapped outpoints above the given `block_height`.
221    fn truncate_outpoints(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
222        self.get_outpoints_mut()?
223            .retain(|_, scan_target| scan_target.block_height <= truncate_height);
224
225        Ok(())
226    }
227}
228
229/// Trait for interfacing shard tree data with wallet data
230pub trait SyncShardTrees: SyncWallet {
231    /// Get reference to shard trees
232    fn get_shard_trees(&self) -> Result<&ShardTrees, Self::Error>;
233
234    /// Get mutable reference to shard trees
235    fn get_shard_trees_mut(&mut self) -> Result<&mut ShardTrees, Self::Error>;
236
237    /// Update wallet shard trees with new shard tree data.
238    ///
239    /// `highest_scanned_height` is the height of the highest scanned block in the wallet not including the `scan_range` we are updating.
240    fn update_shard_trees(
241        &mut self,
242        fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
243        scan_range: &ScanRange,
244        highest_scanned_height: BlockHeight,
245        sapling_located_trees: Vec<LocatedTreeData<sapling_crypto::Node>>,
246        orchard_located_trees: Vec<LocatedTreeData<MerkleHashOrchard>>,
247    ) -> impl std::future::Future<Output = Result<(), SyncError<Self::Error>>> + Send
248    where
249        Self: std::marker::Send,
250    {
251        async move {
252            let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
253
254            // limit the range that checkpoints are manually added to the top MAX_VERIFICATION_WINDOW scanned blocks for efficiency.
255            // As we sync the chain tip first and have spend-before-sync, we will always choose anchors very close to chain
256            // height and we will also never need to truncate to checkpoints lower than this height.
257            let checkpoint_range = if scan_range.block_range().start > highest_scanned_height {
258                let verification_window_start = scan_range
259                    .block_range()
260                    .end
261                    .saturating_sub(MAX_VERIFICATION_WINDOW);
262
263                std::cmp::max(scan_range.block_range().start, verification_window_start)
264                    ..scan_range.block_range().end
265            } else if scan_range.block_range().end
266                > highest_scanned_height.saturating_sub(MAX_VERIFICATION_WINDOW) + 1
267            {
268                let verification_window_start =
269                    highest_scanned_height.saturating_sub(MAX_VERIFICATION_WINDOW) + 1;
270
271                std::cmp::max(scan_range.block_range().start, verification_window_start)
272                    ..scan_range.block_range().end
273            } else {
274                BlockHeight::from_u32(0)..BlockHeight::from_u32(0)
275            };
276
277            // in the case that sapling and/or orchard note commitments are not in an entire block there will be no retention
278            // at that height. Therefore, to prevent anchor and truncate errors, checkpoints are manually added first and
279            // copy the tree state from the previous checkpoint where the commitment tree has not changed as of that block.
280            for checkpoint_height in
281                u32::from(checkpoint_range.start)..u32::from(checkpoint_range.end)
282            {
283                let checkpoint_height = BlockHeight::from_u32(checkpoint_height);
284
285                add_checkpoint::<
286                    Sapling,
287                    sapling_crypto::Node,
288                    { sapling_crypto::NOTE_COMMITMENT_TREE_DEPTH },
289                    { witness::SHARD_HEIGHT },
290                >(
291                    fetch_request_sender.clone(),
292                    checkpoint_height,
293                    &sapling_located_trees,
294                    &mut shard_trees.sapling,
295                )
296                .await?;
297                add_checkpoint::<
298                    Orchard,
299                    MerkleHashOrchard,
300                    { orchard::NOTE_COMMITMENT_TREE_DEPTH as u8 },
301                    { witness::SHARD_HEIGHT },
302                >(
303                    fetch_request_sender.clone(),
304                    checkpoint_height,
305                    &orchard_located_trees,
306                    &mut shard_trees.orchard,
307                )
308                .await?;
309            }
310
311            for tree in sapling_located_trees {
312                shard_trees
313                    .sapling
314                    .insert_tree(tree.subtree, tree.checkpoints)?;
315            }
316            for tree in orchard_located_trees {
317                shard_trees
318                    .orchard
319                    .insert_tree(tree.subtree, tree.checkpoints)?;
320            }
321
322            Ok(())
323        }
324    }
325
326    /// Removes all shard tree data above the given `block_height`.
327    ///
328    /// A `truncate_height` of zero should replace the shard trees with empty trees.
329    fn truncate_shard_trees(
330        &mut self,
331        truncate_height: BlockHeight,
332    ) -> Result<(), SyncError<Self::Error>> {
333        if truncate_height == zcash_protocol::consensus::H0 {
334            let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
335            shard_trees.sapling =
336                ShardTree::new(MemoryShardStore::empty(), MAX_VERIFICATION_WINDOW as usize);
337            shard_trees.orchard =
338                ShardTree::new(MemoryShardStore::empty(), MAX_VERIFICATION_WINDOW as usize);
339        } else {
340            if !self
341                .get_shard_trees_mut()
342                .map_err(SyncError::WalletError)?
343                .sapling
344                .truncate_to_checkpoint(&truncate_height)?
345            {
346                return Err(SyncError::TruncationError(
347                    truncate_height,
348                    PoolType::SAPLING,
349                ));
350            }
351            if !self
352                .get_shard_trees_mut()
353                .map_err(SyncError::WalletError)?
354                .orchard
355                .truncate_to_checkpoint(&truncate_height)?
356            {
357                return Err(SyncError::TruncationError(
358                    truncate_height,
359                    PoolType::ORCHARD,
360                ));
361            }
362        }
363
364        Ok(())
365    }
366}
367
368// TODO: move into `update_shard_trees` trait method
369async fn add_checkpoint<D, L, const DEPTH: u8, const SHARD_HEIGHT: u8>(
370    fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
371    checkpoint_height: BlockHeight,
372    located_trees: &[LocatedTreeData<L>],
373    shard_tree: &mut shardtree::ShardTree<
374        shardtree::store::memory::MemoryShardStore<L, BlockHeight>,
375        DEPTH,
376        SHARD_HEIGHT,
377    >,
378) -> Result<(), ServerError>
379where
380    L: Clone + PartialEq + incrementalmerkletree::Hashable,
381    D: SyncDomain,
382{
383    let checkpoint = if let Some((_, position)) = located_trees
384        .iter()
385        .flat_map(|tree| tree.checkpoints.iter())
386        .find(|(height, _)| **height == checkpoint_height)
387    {
388        Checkpoint::at_position(*position)
389    } else {
390        let mut previous_checkpoint = None;
391        shard_tree
392            .store()
393            .for_each_checkpoint(1_000, |height, checkpoint| {
394                if *height == checkpoint_height - 1 {
395                    previous_checkpoint = Some(checkpoint.clone());
396                }
397                Ok(())
398            })
399            .expect("infallible");
400
401        let tree_state = if let Some(checkpoint) = previous_checkpoint {
402            checkpoint.tree_state()
403        } else {
404            let frontiers =
405                client::get_frontiers(fetch_request_sender.clone(), checkpoint_height).await?;
406            let tree_size = match D::SHIELDED_PROTOCOL {
407                ShieldedProtocol::Sapling => frontiers.final_sapling_tree().tree_size(),
408                ShieldedProtocol::Orchard => frontiers.final_orchard_tree().tree_size(),
409            };
410            if tree_size == 0 {
411                TreeState::Empty
412            } else {
413                TreeState::AtPosition(incrementalmerkletree::Position::from(tree_size - 1))
414            }
415        };
416
417        Checkpoint::from_parts(tree_state, BTreeSet::new())
418    };
419
420    shard_tree
421        .store_mut()
422        .add_checkpoint(checkpoint_height, checkpoint)
423        .expect("infallible");
424
425    Ok(())
426}