Skip to main content

abtc_adapters/wallet/
persistent.rs

//! Persistent Wallet Wrapper
//!
//! Wraps an `InMemoryWallet` with a `WalletStore` to provide automatic
//! persistence after mutating operations. The wrapper delegates all
//! `WalletPort` methods to the inner wallet and saves state to the store
//! after operations that modify wallet state (key generation, key
//! import, transaction sending, and adding or removing UTXOs).
8
9use abtc_domain::primitives::{Amount, OutPoint, Transaction};
10use abtc_ports::wallet::store::WalletStore;
11use abtc_ports::wallet::UnspentOutput;
12use abtc_ports::{Balance, WalletPort};
13use async_trait::async_trait;
14use std::sync::Arc;
15
16use super::InMemoryWallet;
17
18/// A wallet that automatically persists state to a WalletStore after mutations.
19///
20/// Wraps an `InMemoryWallet` and a `WalletStore` implementation. All read
21/// operations delegate directly to the inner wallet. Write operations (key
22/// generation, key import, transaction broadcast) trigger an automatic save
23/// after the operation completes.
24///
25/// # Construction
26///
27/// Use `PersistentWallet::new()` to create a persistent wallet. If the store
28/// contains previously saved state, it will be loaded into the inner wallet.
29pub struct PersistentWallet {
30    inner: Arc<InMemoryWallet>,
31    store: Arc<dyn WalletStore>,
32}
33
34impl PersistentWallet {
35    /// Create a persistent wallet, loading any existing state from the store.
36    ///
37    /// If the store contains a saved snapshot, it is restored into the inner
38    /// wallet. If the store is empty (first run), the wallet starts fresh.
39    pub async fn new(
40        inner: Arc<InMemoryWallet>,
41        store: Arc<dyn WalletStore>,
42    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
43        // Load existing state if available
44        if let Some(snapshot) = store.load().await? {
45            inner.restore_from_snapshot(&snapshot).await?;
46            tracing::info!("Persistent wallet: loaded existing state");
47        } else {
48            tracing::info!("Persistent wallet: no existing state, starting fresh");
49        }
50
51        Ok(PersistentWallet { inner, store })
52    }
53
54    /// Save the current wallet state to the store.
55    pub async fn save(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
56        let snapshot = self.inner.snapshot().await;
57        self.store.save(&snapshot).await
58    }
59
60    /// Add a UTXO and persist.
61    pub async fn add_utxo(
62        &self,
63        utxo: UnspentOutput,
64    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
65        self.inner.add_utxo(utxo).await;
66        self.save().await
67    }
68
69    /// Remove spent UTXOs and persist.
70    pub async fn remove_utxos(
71        &self,
72        spent: &[OutPoint],
73    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
74        self.inner.remove_utxos(spent).await;
75        self.save().await
76    }
77
78    /// Get the number of keys in the wallet.
79    pub async fn key_count(&self) -> usize {
80        self.inner.key_count().await
81    }
82
83    /// Get a reference to the inner InMemoryWallet.
84    pub fn inner(&self) -> &InMemoryWallet {
85        &self.inner
86    }
87}
88
89#[async_trait]
90impl WalletPort for PersistentWallet {
91    async fn get_balance(&self) -> Result<Balance, Box<dyn std::error::Error + Send + Sync>> {
92        self.inner.get_balance().await
93    }
94
95    async fn list_unspent(
96        &self,
97        min_confirmations: u32,
98        max_amount: Option<Amount>,
99    ) -> Result<Vec<UnspentOutput>, Box<dyn std::error::Error + Send + Sync>> {
100        self.inner.list_unspent(min_confirmations, max_amount).await
101    }
102
103    async fn create_transaction(
104        &self,
105        to: Vec<(String, Amount)>,
106        fee_rate: Option<f64>,
107    ) -> Result<Transaction, Box<dyn std::error::Error + Send + Sync>> {
108        self.inner.create_transaction(to, fee_rate).await
109    }
110
111    async fn sign_transaction(
112        &self,
113        tx: &Transaction,
114    ) -> Result<Transaction, Box<dyn std::error::Error + Send + Sync>> {
115        self.inner.sign_transaction(tx).await
116    }
117
118    async fn send_transaction(
119        &self,
120        tx: &Transaction,
121    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
122        let result = self.inner.send_transaction(tx).await?;
123        // Persist after spending UTXOs
124        self.save().await?;
125        Ok(result)
126    }
127
128    async fn get_new_address(
129        &self,
130        label: Option<&str>,
131    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
132        let addr = self.inner.get_new_address(label).await?;
133        // Persist after generating new key
134        self.save().await?;
135        Ok(addr)
136    }
137
138    async fn import_key(
139        &self,
140        privkey_wif: &str,
141        label: Option<&str>,
142        rescan: bool,
143    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
144        self.inner.import_key(privkey_wif, label, rescan).await?;
145        // Persist after importing key
146        self.save().await?;
147        Ok(())
148    }
149
150    async fn get_transaction_history(
151        &self,
152        count: u32,
153        skip: u32,
154    ) -> Result<Vec<Transaction>, Box<dyn std::error::Error + Send + Sync>> {
155        self.inner.get_transaction_history(count, skip).await
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use abtc_domain::primitives::{Amount, TxOut};
    use abtc_domain::wallet::address::AddressType;
    use abtc_ports::wallet::store::WalletSnapshot;
    use std::sync::Mutex;

    /// In-memory `WalletStore` double: keeps the most recent snapshot and
    /// counts how many times `save` has been called.
    struct MockWalletStore {
        data: Mutex<Option<WalletSnapshot>>,
        save_count: Mutex<u32>,
    }

    impl MockWalletStore {
        fn new() -> Self {
            Self {
                data: Mutex::new(None),
                save_count: Mutex::new(0),
            }
        }

        /// Number of `save` calls observed so far.
        fn save_count(&self) -> u32 {
            let saves = self.save_count.lock().unwrap();
            *saves
        }
    }

    #[async_trait]
    impl WalletStore for MockWalletStore {
        async fn save(
            &self,
            snapshot: &WalletSnapshot,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let copy = snapshot.clone();
            *self.data.lock().unwrap() = Some(copy);
            *self.save_count.lock().unwrap() += 1;
            Ok(())
        }

        async fn load(
            &self,
        ) -> Result<Option<WalletSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
            let stored = self.data.lock().unwrap().clone();
            Ok(stored)
        }

        async fn delete(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            *self.data.lock().unwrap() = None;
            Ok(())
        }
    }

    /// Build a default testnet wallet wired to a fresh, empty mock store.
    async fn testnet_wallet() -> (PersistentWallet, Arc<MockWalletStore>) {
        let inner = Arc::new(InMemoryWallet::default_testnet());
        let store = Arc::new(MockWalletStore::new());
        let wallet = PersistentWallet::new(inner, store.clone()).await.unwrap();
        (wallet, store)
    }

    #[tokio::test]
    async fn test_persistent_wallet_starts_fresh() {
        let (wallet, store) = testnet_wallet().await;

        // Nothing stored yet: zero balance, and construction alone never saves.
        let balance = wallet.get_balance().await.unwrap();
        assert_eq!(balance.confirmed.as_sat(), 0);
        assert_eq!(store.save_count(), 0);
    }

    #[tokio::test]
    async fn test_get_new_address_triggers_save() {
        let (wallet, store) = testnet_wallet().await;

        let address = wallet.get_new_address(Some("test")).await.unwrap();
        assert!(address.starts_with("tb1q"));
        assert_eq!(store.save_count(), 1);

        // The stored snapshot must contain the freshly generated key.
        let saved = store.data.lock().unwrap().clone().unwrap();
        assert_eq!(saved.keys.len(), 1);
        assert_eq!(saved.keys[0].label, Some("test".to_string()));
    }

    #[tokio::test]
    async fn test_import_key_triggers_save() {
        let (wallet, store) = testnet_wallet().await;

        // Produce a valid WIF string by generating a key and encoding it.
        let wif = abtc_domain::wallet::keys::PrivateKey::generate(true, false).to_wif();

        wallet
            .import_key(&wif, Some("imported"), false)
            .await
            .unwrap();
        assert_eq!(store.save_count(), 1);
        assert_eq!(wallet.key_count().await, 1);
    }

    #[tokio::test]
    async fn test_add_utxo_triggers_save() {
        let (wallet, store) = testnet_wallet().await;

        let outpoint = OutPoint::new(abtc_domain::Txid::zero(), 0);
        let output = TxOut::new(Amount::from_sat(50_000), abtc_domain::Script::new());
        let utxo = UnspentOutput {
            outpoint,
            output,
            confirmations: 3,
            is_coinbase: false,
        };

        wallet.add_utxo(utxo).await.unwrap();
        assert_eq!(store.save_count(), 1);

        let balance = wallet.get_balance().await.unwrap();
        assert_eq!(balance.confirmed.as_sat(), 50_000);
    }

    #[tokio::test]
    async fn test_load_restores_state() {
        let store = Arc::new(MockWalletStore::new());

        // Phase 1: a first wallet generates a key, which saves a snapshot.
        {
            let first_inner = Arc::new(InMemoryWallet::new(false, AddressType::P2WPKH));
            let first = PersistentWallet::new(first_inner, store.clone())
                .await
                .unwrap();

            first.get_new_address(Some("restored")).await.unwrap();
            assert_eq!(store.save_count(), 1);
        }

        // Phase 2: a brand-new inner wallet picks the key up from the store.
        {
            let second_inner = Arc::new(InMemoryWallet::new(false, AddressType::P2WPKH));
            let second = PersistentWallet::new(second_inner, store.clone())
                .await
                .unwrap();

            assert_eq!(second.key_count().await, 1);
        }
    }

    #[tokio::test]
    async fn test_multiple_saves_increment() {
        let (wallet, store) = testnet_wallet().await;

        // Each address request must produce exactly one save.
        for _ in 0..3 {
            wallet.get_new_address(None).await.unwrap();
        }

        assert_eq!(store.save_count(), 3);
        assert_eq!(wallet.key_count().await, 3);
    }
}
315}