pepper_sync/wallet/
traits.rs

1//! Traits for interfacing a wallet with the sync engine
2
3use std::collections::{BTreeMap, BTreeSet, HashMap};
4
5use orchard::tree::MerkleHashOrchard;
6use shardtree::store::{Checkpoint, ShardStore, TreeState};
7use tokio::sync::mpsc;
8use zcash_client_backend::data_api::scanning::ScanRange;
9use zcash_client_backend::keys::UnifiedFullViewingKey;
10use zcash_primitives::consensus::BlockHeight;
11use zcash_primitives::transaction::TxId;
12use zcash_primitives::zip32::AccountId;
13use zcash_protocol::ShieldedProtocol;
14
15use crate::error::{ServerError, SyncError};
16use crate::keys::transparent::TransparentAddressId;
17use crate::sync::MAX_VERIFICATION_WINDOW;
18use crate::wallet::{
19    Locator, NullifierMap, OutputId, ShardTrees, SyncState, WalletBlock, WalletTransaction,
20};
21use crate::witness::LocatedTreeData;
22use crate::{Orchard, Sapling, SyncDomain, client};
23
24use super::{FetchRequest, witness};
25
26/// Trait for interfacing wallet with the sync engine.
27pub trait SyncWallet {
28    /// Errors associated with interfacing the sync engine with wallet data
29    type Error: std::fmt::Debug + std::fmt::Display + std::error::Error;
30
31    /// Returns the block height wallet was created.
32    fn get_birthday(&self) -> Result<BlockHeight, Self::Error>;
33
34    /// Returns a reference to wallet sync state.
35    fn get_sync_state(&self) -> Result<&SyncState, Self::Error>;
36
37    /// Returns a mutable reference to wallet sync state.
38    fn get_sync_state_mut(&mut self) -> Result<&mut SyncState, Self::Error>;
39
40    /// Returns all unified full viewing keys known to this wallet.
41    fn get_unified_full_viewing_keys(
42        &self,
43    ) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error>;
44
45    /// Returns a reference to all the transparent addresses known to this wallet.
46    fn get_transparent_addresses(
47        &self,
48    ) -> Result<&BTreeMap<TransparentAddressId, String>, Self::Error>;
49
50    /// Returns a mutable reference to all the transparent addresses known to this wallet.
51    fn get_transparent_addresses_mut(
52        &mut self,
53    ) -> Result<&mut BTreeMap<TransparentAddressId, String>, Self::Error>;
54
55    /// Aids in-memory wallets to only save when the wallet state has changed by setting a flag to mark that save is
56    /// required.
57    /// Persitance wallets may use the default implementation.
58    fn set_save_flag(&mut self) -> Result<(), Self::Error> {
59        Ok(())
60    }
61}
62
63/// Trait for interfacing [`crate::wallet::WalletBlock`]s with wallet data
64pub trait SyncBlocks: SyncWallet {
65    /// Get a stored wallet compact block from wallet data by block height
66    ///
67    /// Must return error if block is not found
68    fn get_wallet_block(&self, block_height: BlockHeight) -> Result<WalletBlock, Self::Error>;
69
70    /// Get mutable reference to wallet blocks
71    fn get_wallet_blocks_mut(
72        &mut self,
73    ) -> Result<&mut BTreeMap<BlockHeight, WalletBlock>, Self::Error>;
74
75    /// Append wallet compact blocks to wallet data
76    fn append_wallet_blocks(
77        &mut self,
78        mut wallet_blocks: BTreeMap<BlockHeight, WalletBlock>,
79    ) -> Result<(), Self::Error> {
80        self.get_wallet_blocks_mut()?.append(&mut wallet_blocks);
81
82        Ok(())
83    }
84
85    /// Removes all wallet blocks above the given `block_height`.
86    fn truncate_wallet_blocks(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
87        self.get_wallet_blocks_mut()?
88            .retain(|block_height, _| *block_height <= truncate_height);
89
90        Ok(())
91    }
92}
93
94/// Trait for interfacing [`crate::wallet::WalletTransaction`]s with wallet data
95pub trait SyncTransactions: SyncWallet {
96    /// Get reference to wallet transactions
97    fn get_wallet_transactions(&self) -> Result<&HashMap<TxId, WalletTransaction>, Self::Error>;
98
99    /// Get mutable reference to wallet transactions
100    fn get_wallet_transactions_mut(
101        &mut self,
102    ) -> Result<&mut HashMap<TxId, WalletTransaction>, Self::Error>;
103
104    /// Insert wallet transaction
105    fn insert_wallet_transaction(
106        &mut self,
107        wallet_transaction: WalletTransaction,
108    ) -> Result<(), Self::Error> {
109        self.get_wallet_transactions_mut()?
110            .insert(wallet_transaction.txid(), wallet_transaction);
111
112        Ok(())
113    }
114
115    /// Extend wallet transaction map with new wallet transactions
116    fn extend_wallet_transactions(
117        &mut self,
118        wallet_transactions: HashMap<TxId, WalletTransaction>,
119    ) -> Result<(), Self::Error> {
120        self.get_wallet_transactions_mut()?
121            .extend(wallet_transactions);
122
123        Ok(())
124    }
125
126    /// Removes all confirmed wallet transactions above the given `block_height`.
127    /// Also sets any output's spending_transaction field to `None` if it's spending transaction was removed.
128    fn truncate_wallet_transactions(
129        &mut self,
130        truncate_height: BlockHeight,
131    ) -> Result<(), Self::Error> {
132        // TODO: Replace with `extract_if()` when it's in stable rust
133        let invalid_txids: Vec<TxId> = self
134            .get_wallet_transactions()?
135            .values()
136            .filter(|tx| tx.status().is_confirmed_after(&truncate_height))
137            .map(|tx| tx.transaction().txid())
138            .collect();
139
140        let wallet_transactions = self.get_wallet_transactions_mut()?;
141        wallet_transactions
142            .values_mut()
143            .flat_map(|tx| tx.sapling_notes_mut())
144            .filter(|note| {
145                note.spending_transaction.map_or_else(
146                    || false,
147                    |spending_txid| invalid_txids.contains(&spending_txid),
148                )
149            })
150            .for_each(|note| {
151                note.spending_transaction = None;
152            });
153        wallet_transactions
154            .values_mut()
155            .flat_map(|tx| tx.orchard_notes_mut())
156            .filter(|note| {
157                note.spending_transaction.map_or_else(
158                    || false,
159                    |spending_txid| invalid_txids.contains(&spending_txid),
160                )
161            })
162            .for_each(|note| {
163                note.spending_transaction = None;
164            });
165
166        invalid_txids.iter().for_each(|invalid_txid| {
167            wallet_transactions.remove(invalid_txid);
168        });
169
170        Ok(())
171    }
172}
173
174/// Trait for interfacing nullifiers with wallet data
175pub trait SyncNullifiers: SyncWallet {
176    /// Get wallet nullifier map
177    fn get_nullifiers(&self) -> Result<&NullifierMap, Self::Error>;
178
179    /// Get mutable reference to wallet nullifier map
180    fn get_nullifiers_mut(&mut self) -> Result<&mut NullifierMap, Self::Error>;
181
182    /// Append nullifiers to wallet nullifier map
183    fn append_nullifiers(&mut self, mut nullifiers: NullifierMap) -> Result<(), Self::Error> {
184        self.get_nullifiers_mut()?
185            .sapling
186            .append(&mut nullifiers.sapling);
187        self.get_nullifiers_mut()?
188            .orchard
189            .append(&mut nullifiers.orchard);
190
191        Ok(())
192    }
193
194    /// Removes all mapped nullifiers above the given `block_height`.
195    fn truncate_nullifiers(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
196        let nullifier_map = self.get_nullifiers_mut()?;
197        nullifier_map
198            .sapling
199            .retain(|_, (block_height, _)| *block_height <= truncate_height);
200        nullifier_map
201            .orchard
202            .retain(|_, (block_height, _)| *block_height <= truncate_height);
203
204        Ok(())
205    }
206}
207
208/// Trait for interfacing outpoints with wallet data
209pub trait SyncOutPoints: SyncWallet {
210    /// Get wallet outpoint map
211    fn get_outpoints(&self) -> Result<&BTreeMap<OutputId, Locator>, Self::Error>;
212
213    /// Get mutable reference to wallet outpoint map
214    fn get_outpoints_mut(&mut self) -> Result<&mut BTreeMap<OutputId, Locator>, Self::Error>;
215
216    /// Append outpoints to wallet outpoint map
217    fn append_outpoints(
218        &mut self,
219        outpoints: &mut BTreeMap<OutputId, Locator>,
220    ) -> Result<(), Self::Error> {
221        self.get_outpoints_mut()?.append(outpoints);
222
223        Ok(())
224    }
225
226    /// Removes all mapped outpoints above the given `block_height`.
227    fn truncate_outpoints(&mut self, truncate_height: BlockHeight) -> Result<(), Self::Error> {
228        self.get_outpoints_mut()?
229            .retain(|_, (block_height, _)| *block_height <= truncate_height);
230
231        Ok(())
232    }
233}
234
235/// Trait for interfacing shard tree data with wallet data
236pub trait SyncShardTrees: SyncWallet {
237    /// Get reference to shard trees
238    fn get_shard_trees(&self) -> Result<&ShardTrees, Self::Error>;
239
240    /// Get mutable reference to shard trees
241    fn get_shard_trees_mut(&mut self) -> Result<&mut ShardTrees, Self::Error>;
242
243    /// Update wallet shard trees with new shard tree data
244    fn update_shard_trees(
245        &mut self,
246        fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
247        scan_range: &ScanRange,
248        wallet_height: BlockHeight,
249        sapling_located_trees: Vec<LocatedTreeData<sapling_crypto::Node>>,
250        orchard_located_trees: Vec<LocatedTreeData<MerkleHashOrchard>>,
251    ) -> impl std::future::Future<Output = Result<(), SyncError<Self::Error>>> + Send
252    where
253        Self: std::marker::Send,
254    {
255        async move {
256            let shard_trees = self.get_shard_trees_mut().map_err(SyncError::WalletError)?;
257
258            // limit the range that checkpoints are manually added to the top MAX_VERIFICATION_WINDOW blocks for efficiency.
259            // As we sync the chain tip first and have spend-before-sync, we will always choose anchors very close to chain
260            // height and we will also never need to truncate to checkpoints lower than this height.
261            let checkpoint_range = match (
262                scan_range.block_range().start
263                    > wallet_height.saturating_sub(MAX_VERIFICATION_WINDOW),
264                scan_range.block_range().end - 1
265                    > wallet_height.saturating_sub(MAX_VERIFICATION_WINDOW),
266            ) {
267                (true, _) => scan_range.block_range().clone(),
268                (false, true) => {
269                    (wallet_height - MAX_VERIFICATION_WINDOW)..scan_range.block_range().end
270                }
271                (false, false) => BlockHeight::from_u32(0)..BlockHeight::from_u32(0),
272            };
273
274            // in the case that sapling and/or orchard note commitments are not in an entire block there will be no retention
275            // at that height. Therefore, to prevent anchor and truncate errors, checkpoints are manually added first and
276            // copy the tree state from the previous checkpoint where the commitment tree has not changed as of that block.
277            for checkpoint_height in
278                u32::from(checkpoint_range.start)..u32::from(checkpoint_range.end)
279            {
280                let checkpoint_height = BlockHeight::from_u32(checkpoint_height);
281
282                add_checkpoint::<
283                    Sapling,
284                    sapling_crypto::Node,
285                    { sapling_crypto::NOTE_COMMITMENT_TREE_DEPTH },
286                    { witness::SHARD_HEIGHT },
287                >(
288                    fetch_request_sender.clone(),
289                    checkpoint_height,
290                    &sapling_located_trees,
291                    &mut shard_trees.sapling,
292                )
293                .await?;
294                add_checkpoint::<
295                    Orchard,
296                    MerkleHashOrchard,
297                    { orchard::NOTE_COMMITMENT_TREE_DEPTH as u8 },
298                    { witness::SHARD_HEIGHT },
299                >(
300                    fetch_request_sender.clone(),
301                    checkpoint_height,
302                    &orchard_located_trees,
303                    &mut shard_trees.orchard,
304                )
305                .await?;
306            }
307
308            for tree in sapling_located_trees.into_iter() {
309                shard_trees
310                    .sapling
311                    .insert_tree(tree.subtree, tree.checkpoints)?;
312            }
313            for tree in orchard_located_trees.into_iter() {
314                shard_trees
315                    .orchard
316                    .insert_tree(tree.subtree, tree.checkpoints)?;
317            }
318
319            Ok(())
320        }
321    }
322
323    /// Removes all shard tree data above the given `block_height`.
324    fn truncate_shard_trees(
325        &mut self,
326        truncate_height: BlockHeight,
327    ) -> Result<(), SyncError<Self::Error>> {
328        if !self
329            .get_shard_trees_mut()
330            .map_err(SyncError::WalletError)?
331            .sapling
332            .truncate_to_checkpoint(&truncate_height)?
333        {
334            panic!("max checkpoints should always be higher or equal to max verification window!");
335        }
336        if !self
337            .get_shard_trees_mut()
338            .map_err(SyncError::WalletError)?
339            .orchard
340            .truncate_to_checkpoint(&truncate_height)?
341        {
342            panic!("max checkpoints should always be higher or equal to max verification window!");
343        }
344
345        Ok(())
346    }
347}
348
349async fn add_checkpoint<D, L, const DEPTH: u8, const SHARD_HEIGHT: u8>(
350    fetch_request_sender: mpsc::UnboundedSender<FetchRequest>,
351    checkpoint_height: BlockHeight,
352    located_trees: &[LocatedTreeData<L>],
353    shard_tree: &mut shardtree::ShardTree<
354        shardtree::store::memory::MemoryShardStore<L, BlockHeight>,
355        DEPTH,
356        SHARD_HEIGHT,
357    >,
358) -> Result<(), ServerError>
359where
360    L: Clone + PartialEq + incrementalmerkletree::Hashable,
361    D: SyncDomain,
362{
363    let checkpoint = if let Some((_, position)) = located_trees
364        .iter()
365        .flat_map(|tree| tree.checkpoints.iter())
366        .find(|(height, _)| **height == checkpoint_height)
367    {
368        Checkpoint::at_position(*position)
369    } else {
370        let mut previous_checkpoint = None;
371        shard_tree
372            .store()
373            .for_each_checkpoint(1_000, |height, checkpoint| {
374                if *height == checkpoint_height - 1 {
375                    previous_checkpoint = Some(checkpoint.clone());
376                }
377                Ok(())
378            })
379            .expect("infallible");
380
381        let tree_state = if let Some(checkpoint) = previous_checkpoint {
382            checkpoint.tree_state()
383        } else {
384            let frontiers =
385                client::get_frontiers(fetch_request_sender.clone(), checkpoint_height - 1).await?;
386            let tree_size = match D::SHIELDED_PROTOCOL {
387                ShieldedProtocol::Sapling => frontiers.final_sapling_tree().tree_size(),
388                ShieldedProtocol::Orchard => frontiers.final_orchard_tree().tree_size(),
389            };
390            if tree_size == 0 {
391                TreeState::Empty
392            } else {
393                TreeState::AtPosition(incrementalmerkletree::Position::from(tree_size - 1))
394            }
395        };
396
397        Checkpoint::from_parts(tree_state, BTreeSet::new())
398    };
399
400    shard_tree
401        .store_mut()
402        .add_checkpoint(checkpoint_height, checkpoint)
403        .expect("infallible");
404
405    Ok(())
406}