Skip to main content

pepper_sync/wallet/
traits.rs

1//! Traits for interfacing a wallet with the sync engine
2
3use std::collections::{BTreeMap, BTreeSet, HashMap};
4
5use tokio::sync::mpsc;
6use zip32::DiversifierIndex;
7
8use orchard::tree::MerkleHashOrchard;
9use shardtree::ShardTree;
10use shardtree::store::memory::MemoryShardStore;
11use shardtree::store::{Checkpoint, ShardStore, TreeState};
12use zcash_keys::keys::UnifiedFullViewingKey;
13use zcash_primitives::transaction::TxId;
14use zcash_protocol::consensus::BlockHeight;
15use zcash_protocol::{PoolType, ShieldedProtocol};
16use zip32::AccountId;
17
18use crate::error::{ServerError, SyncError};
19use crate::keys::transparent::TransparentAddressId;
20use crate::sync::{MAX_REORG_ALLOWANCE, ScanRange};
21use crate::wallet::{
22    NullifierMap, OutputId, ShardTrees, SyncState, WalletBlock, WalletTransaction,
23};
24use crate::witness::LocatedTreeData;
25use crate::{Orchard, Sapling, SyncDomain, client, set_transactions_failed};
26
27use super::{FetchRequest, ScanTarget, witness};
28
29/// Trait for interfacing wallet with the sync engine.
30pub trait SyncWallet {
31    /// Errors associated with interfacing the sync engine with wallet data
32    type Error: std::fmt::Debug + std::fmt::Display + std::error::Error;
33
34    /// Returns the block height wallet was created.
35    fn get_birthday(&self) -> Result<BlockHeight, Self::Error>;
36
37    /// Returns a reference to wallet sync state.
38    fn get_sync_state(&self) -> Result<&SyncState, Self::Error>;
39
40    /// Returns a mutable reference to wallet sync state.
41    fn get_sync_state_mut(&mut self) -> Result<&mut SyncState, Self::Error>;
42
43    /// Returns all unified full viewing keys known to this wallet.
44    fn get_unified_full_viewing_keys(
45        &self,
46    ) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error>;
47
48    /// Add orchard address to wallet's unified address list.
49    fn add_orchard_address(
50        &mut self,
51        account_id: zip32::AccountId,
52        address: orchard::Address,
53        diversifier_index: DiversifierIndex,
54    ) -> Result<(), Self::Error>;
55
56    /// Add sapling address to wallet's unified address list.
57    fn add_sapling_address(
58        &mut self,
59        account_id: zip32::AccountId,
60        address: sapling_crypto::PaymentAddress,
61        diversifier_index: DiversifierIndex,
62    ) -> Result<(), Self::Error>;
63
64    /// Returns a reference to all transparent addresses known to this wallet.
65    fn get_transparent_addresses(
66        &self,
67    ) -> Result<&BTreeMap<TransparentAddressId, String>, Self::Error>;
68
69    /// Returns a mutable reference to all transparent addresses known to this wallet.
70    fn get_transparent_addresses_mut(
71        &mut self,
72    ) -> Result<&mut BTreeMap<TransparentAddressId, String>, Self::Error>;
73
74    /// Aids in-memory wallets to only save when the wallet state has changed by setting a flag to mark that save is
75    /// required.
76    /// Persitance wallets may use the default implementation.
77    fn set_save_flag(&mut self) -> Result<(), Self::Error> {
78        Ok(())
79    }
80}
81
82/// Trait for interfacing [`crate::wallet::WalletBlock`]s with wallet data
83pub trait SyncBlocks: SyncWallet {
84    /// Get a stored wallet compact block from wallet data by block height
85    ///
86    /// Must return error if block is not found
87    fn get_wallet_block(&self, block_height: BlockHeight) -> Result<WalletBlock, Self::Error>;
88
89    /// Get mutable reference to wallet blocks
90    fn get_wallet_blocks_mut(
91        &mut self,
92    ) -> Result<&mut BTreeMap<BlockHeight, WalletBlock>, Self::Error>;
93
94    /// Append wallet compact blocks to wallet data
95    fn append_wallet_blocks(
96        &mut self,
97        mut wallet_blocks: BTreeMap<BlockHeight, WalletBlock>,
98    ) -> Result<(), Self::Error> {
99        self.get_wallet_blocks_mut()?.append(&mut wallet_blocks);
100
101        Ok(())
102    }
103
104    /// Removes all wallet blocks above the given `block_height`.
105    fn truncate_wallet_blocks(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
106        self.get_wallet_blocks_mut()?
107            .retain(|block_height, _| *block_height <= truncate_height);
108
109        Ok(())
110    }
111}
112
113/// Trait for interfacing [`crate::wallet::WalletTransaction`]s with wallet data
114pub trait SyncTransactions: SyncWallet {
115    /// Get reference to wallet transactions
116    fn get_wallet_transactions(&self) -> Result<&HashMap<TxId, WalletTransaction>, Self::Error>;
117
118    /// Get mutable reference to wallet transactions
119    fn get_wallet_transactions_mut(
120        &mut self,
121    ) -> Result<&mut HashMap<TxId, WalletTransaction>, Self::Error>;
122
123    /// Insert wallet transaction
124    fn insert_wallet_transaction(
125        &mut self,
126        wallet_transaction: WalletTransaction,
127    ) -> Result<(), Self::Error> {
128        self.get_wallet_transactions_mut()?
129            .insert(wallet_transaction.txid(), wallet_transaction);
130
131        Ok(())
132    }
133
134    /// Extend wallet transaction map with new wallet transactions
135    fn extend_wallet_transactions(
136        &mut self,
137        wallet_transactions: HashMap<TxId, WalletTransaction>,
138    ) -> Result<(), Self::Error> {
139        self.get_wallet_transactions_mut()?
140            .extend(wallet_transactions);
141
142        Ok(())
143    }
144
145    /// Sets all confirmed wallet transactions above the given `block_height` to `Failed` status.
146    /// Also sets any output's `spending_transaction` field to `None` if it's spending transaction was set to `Failed`
147    /// status.
148    fn truncate_wallet_transactions(
149        &mut self,
150        truncate_height: BlockHeight,
151    ) -> Result<(), Self::Error> {
152        let invalid_txids: Vec<TxId> = self
153            .get_wallet_transactions()?
154            .values()
155            .filter(|tx| tx.status().is_confirmed_after(&truncate_height))
156            .map(|tx| tx.transaction().txid())
157            .collect();
158
159        set_transactions_failed(self.get_wallet_transactions_mut()?, invalid_txids);
160
161        Ok(())
162    }
163}
164
165/// Trait for interfacing nullifiers with wallet data
166pub trait SyncNullifiers: SyncWallet {
167    /// Get wallet nullifier map
168    fn get_nullifiers(&self) -> Result<&NullifierMap, Self::Error>;
169
170    /// Get mutable reference to wallet nullifier map
171    fn get_nullifiers_mut(&mut self) -> Result<&mut NullifierMap, Self::Error>;
172
173    /// Append nullifiers to wallet nullifier map
174    fn append_nullifiers(&mut self, nullifiers: &mut NullifierMap) -> Result<(), Self::Error> {
175        self.get_nullifiers_mut()?
176            .sapling
177            .append(&mut nullifiers.sapling);
178        self.get_nullifiers_mut()?
179            .orchard
180            .append(&mut nullifiers.orchard);
181
182        Ok(())
183    }
184
185    /// Removes all mapped nullifiers above the given `block_height`.
186    fn truncate_nullifiers(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
187        let nullifier_map = self.get_nullifiers_mut()?;
188        nullifier_map
189            .sapling
190            .retain(|_, scan_target| scan_target.block_height <= truncate_height);
191        nullifier_map
192            .orchard
193            .retain(|_, scan_target| scan_target.block_height <= truncate_height);
194
195        Ok(())
196    }
197}
198
199/// Trait for interfacing outpoints with wallet data
200pub trait SyncOutPoints: SyncWallet {
201    /// Get wallet outpoint map
202    fn get_outpoints(&self) -> Result<&BTreeMap<OutputId, ScanTarget>, Self::Error>;
203
204    /// Get mutable reference to wallet outpoint map
205    fn get_outpoints_mut(&mut self) -> Result<&mut BTreeMap<OutputId, ScanTarget>, Self::Error>;
206
207    /// Append outpoints to wallet outpoint map
208    fn append_outpoints(
209        &mut self,
210        outpoints: &mut BTreeMap<OutputId, ScanTarget>,
211    ) -> Result<(), Self::Error> {
212        self.get_outpoints_mut()?.append(outpoints);
213
214        Ok(())
215    }
216
217    /// Removes all mapped outpoints above the given `block_height`.
218    fn truncate_outpoints(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
219        self.get_outpoints_mut()?
220            .retain(|_, scan_target| scan_target.block_height <= truncate_height);
221
222        Ok(())
223    }
224}
225
226/// Trait for interfacing shard tree data with wallet data
227pub trait SyncShardTrees: SyncWallet {
228    /// Get reference to shard trees
229    fn get_shard_trees(&self) -> Result<&ShardTrees, Self::Error>;
230
231    /// Get mutable reference to shard trees
232    fn get_shard_trees_mut(&mut self) -> Result<&mut ShardTrees, Self::Error>;
233
234    /// Update wallet shard trees with new shard tree data.
235    ///
236    /// `highest_scanned_height` is the height of the highest scanned block in the wallet not including the `scan_range` we are updating.
237    fn update_shard_trees(
238        &mut self,
239        fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
240        scan_range: &ScanRange,
241        highest_scanned_height: BlockHeight,
242        sapling_located_trees: Vec<LocatedTreeData<sapling_crypto::Node>>,
243        orchard_located_trees: Vec<LocatedTreeData<MerkleHashOrchard>>,
244    ) -> impl std::future::Future<Output = Result<(), SyncError<Self::Error>>> + Send
245    where
246        Self: std::marker::Send,
247    {
248        async move {
249            let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
250
251            // limit the range that checkpoints are manually added to the top MAX_REORG_ALLOWANCE scanned blocks for efficiency.
252            // As we sync the chain tip first and have spend-before-sync, we will always choose anchors very close to chain
253            // height and we will also never need to truncate to checkpoints lower than this height.
254            let checkpoint_range = if scan_range.block_range().start > highest_scanned_height {
255                let verification_window_start = scan_range
256                    .block_range()
257                    .end
258                    .saturating_sub(MAX_REORG_ALLOWANCE);
259
260                std::cmp::max(scan_range.block_range().start, verification_window_start)
261                    ..scan_range.block_range().end
262            } else if scan_range.block_range().end
263                > highest_scanned_height.saturating_sub(MAX_REORG_ALLOWANCE) + 1
264            {
265                let verification_window_start =
266                    highest_scanned_height.saturating_sub(MAX_REORG_ALLOWANCE) + 1;
267
268                std::cmp::max(scan_range.block_range().start, verification_window_start)
269                    ..scan_range.block_range().end
270            } else {
271                BlockHeight::from_u32(0)..BlockHeight::from_u32(0)
272            };
273
274            // in the case that sapling and/or orchard note commitments are not in an entire block there will be no retention
275            // at that height. Therefore, to prevent anchor and truncate errors, checkpoints are manually added first and
276            // copy the tree state from the previous checkpoint where the commitment tree has not changed as of that block.
277            for checkpoint_height in
278                u32::from(checkpoint_range.start)..u32::from(checkpoint_range.end)
279            {
280                let checkpoint_height = BlockHeight::from_u32(checkpoint_height);
281
282                add_checkpoint::<
283                    Sapling,
284                    sapling_crypto::Node,
285                    { sapling_crypto::NOTE_COMMITMENT_TREE_DEPTH },
286                    { witness::SHARD_HEIGHT },
287                >(
288                    fetch_request_sender.clone(),
289                    checkpoint_height,
290                    &sapling_located_trees,
291                    &mut shard_trees.sapling,
292                )
293                .await?;
294                add_checkpoint::<
295                    Orchard,
296                    MerkleHashOrchard,
297                    { orchard::NOTE_COMMITMENT_TREE_DEPTH as u8 },
298                    { witness::SHARD_HEIGHT },
299                >(
300                    fetch_request_sender.clone(),
301                    checkpoint_height,
302                    &orchard_located_trees,
303                    &mut shard_trees.orchard,
304                )
305                .await?;
306            }
307
308            for tree in sapling_located_trees {
309                shard_trees
310                    .sapling
311                    .insert_tree(tree.subtree, tree.checkpoints)?;
312            }
313            for tree in orchard_located_trees {
314                shard_trees
315                    .orchard
316                    .insert_tree(tree.subtree, tree.checkpoints)?;
317            }
318
319            Ok(())
320        }
321    }
322
323    /// Removes all shard tree data above the given `block_height`.
324    ///
325    /// A `truncate_height` of zero should replace the shard trees with empty trees.
326    fn truncate_shard_trees(
327        &mut self,
328        truncate_height: BlockHeight,
329    ) -> Result<(), SyncError<Self::Error>> {
330        if truncate_height == zcash_protocol::consensus::H0 {
331            let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
332            tracing::info!("Clearing shard trees.");
333            shard_trees.sapling =
334                ShardTree::new(MemoryShardStore::empty(), MAX_REORG_ALLOWANCE as usize);
335            shard_trees.orchard =
336                ShardTree::new(MemoryShardStore::empty(), MAX_REORG_ALLOWANCE as usize);
337        } else {
338            if !self
339                .get_shard_trees_mut()
340                .map_err(SyncError::WalletError)?
341                .sapling
342                .truncate_to_checkpoint(&truncate_height)?
343            {
344                tracing::error!("Sapling shard tree is broken! Beginning rescan.");
345                return Err(SyncError::TruncationError(
346                    truncate_height,
347                    PoolType::SAPLING,
348                ));
349            }
350            if !self
351                .get_shard_trees_mut()
352                .map_err(SyncError::WalletError)?
353                .orchard
354                .truncate_to_checkpoint(&truncate_height)?
355            {
356                tracing::error!("Sapling shard tree is broken! Beginning rescan.");
357                return Err(SyncError::TruncationError(
358                    truncate_height,
359                    PoolType::ORCHARD,
360                ));
361            }
362        }
363
364        Ok(())
365    }
366}
367
368// TODO: move into `update_shard_trees` trait method
369async fn add_checkpoint<D, L, const DEPTH: u8, const SHARD_HEIGHT: u8>(
370    fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
371    checkpoint_height: BlockHeight,
372    located_trees: &[LocatedTreeData<L>],
373    shard_tree: &mut shardtree::ShardTree<
374        shardtree::store::memory::MemoryShardStore<L, BlockHeight>,
375        DEPTH,
376        SHARD_HEIGHT,
377    >,
378) -> Result<(), ServerError>
379where
380    L: Clone + PartialEq + incrementalmerkletree::Hashable,
381    D: SyncDomain,
382{
383    let checkpoint = if let Some((_, position)) = located_trees
384        .iter()
385        .flat_map(|tree| tree.checkpoints.iter())
386        .find(|(height, _)| **height == checkpoint_height)
387    {
388        Checkpoint::at_position(*position)
389    } else {
390        let mut previous_checkpoint = None;
391        shard_tree
392            .store()
393            .for_each_checkpoint(1_000, |height, checkpoint| {
394                if *height == checkpoint_height - 1 {
395                    previous_checkpoint = Some(checkpoint.clone());
396                }
397                Ok(())
398            })
399            .expect("infallible");
400
401        let tree_state = if let Some(checkpoint) = previous_checkpoint {
402            checkpoint.tree_state()
403        } else {
404            let frontiers =
405                client::get_frontiers(fetch_request_sender.clone(), checkpoint_height).await?;
406            let tree_size = match D::SHIELDED_PROTOCOL {
407                ShieldedProtocol::Sapling => frontiers.final_sapling_tree().tree_size(),
408                ShieldedProtocol::Orchard => frontiers.final_orchard_tree().tree_size(),
409            };
410            if tree_size == 0 {
411                TreeState::Empty
412            } else {
413                TreeState::AtPosition(incrementalmerkletree::Position::from(tree_size - 1))
414            }
415        };
416
417        Checkpoint::from_parts(tree_state, BTreeSet::new())
418    };
419
420    shard_tree
421        .store_mut()
422        .add_checkpoint(checkpoint_height, checkpoint)
423        .expect("infallible");
424
425    Ok(())
426}