1use abtc_domain::primitives::{Amount, OutPoint, Transaction};
10use abtc_ports::wallet::store::WalletStore;
11use abtc_ports::wallet::UnspentOutput;
12use abtc_ports::{Balance, WalletPort};
13use async_trait::async_trait;
14use std::sync::Arc;
15
16use super::InMemoryWallet;
17
18pub struct PersistentWallet {
30 inner: Arc<InMemoryWallet>,
31 store: Arc<dyn WalletStore>,
32}
33
34impl PersistentWallet {
35 pub async fn new(
40 inner: Arc<InMemoryWallet>,
41 store: Arc<dyn WalletStore>,
42 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
43 if let Some(snapshot) = store.load().await? {
45 inner.restore_from_snapshot(&snapshot).await?;
46 tracing::info!("Persistent wallet: loaded existing state");
47 } else {
48 tracing::info!("Persistent wallet: no existing state, starting fresh");
49 }
50
51 Ok(PersistentWallet { inner, store })
52 }
53
54 pub async fn save(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
56 let snapshot = self.inner.snapshot().await;
57 self.store.save(&snapshot).await
58 }
59
60 pub async fn add_utxo(
62 &self,
63 utxo: UnspentOutput,
64 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
65 self.inner.add_utxo(utxo).await;
66 self.save().await
67 }
68
69 pub async fn remove_utxos(
71 &self,
72 spent: &[OutPoint],
73 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
74 self.inner.remove_utxos(spent).await;
75 self.save().await
76 }
77
78 pub async fn key_count(&self) -> usize {
80 self.inner.key_count().await
81 }
82
83 pub fn inner(&self) -> &InMemoryWallet {
85 &self.inner
86 }
87}
88
89#[async_trait]
90impl WalletPort for PersistentWallet {
91 async fn get_balance(&self) -> Result<Balance, Box<dyn std::error::Error + Send + Sync>> {
92 self.inner.get_balance().await
93 }
94
95 async fn list_unspent(
96 &self,
97 min_confirmations: u32,
98 max_amount: Option<Amount>,
99 ) -> Result<Vec<UnspentOutput>, Box<dyn std::error::Error + Send + Sync>> {
100 self.inner.list_unspent(min_confirmations, max_amount).await
101 }
102
103 async fn create_transaction(
104 &self,
105 to: Vec<(String, Amount)>,
106 fee_rate: Option<f64>,
107 ) -> Result<Transaction, Box<dyn std::error::Error + Send + Sync>> {
108 self.inner.create_transaction(to, fee_rate).await
109 }
110
111 async fn sign_transaction(
112 &self,
113 tx: &Transaction,
114 ) -> Result<Transaction, Box<dyn std::error::Error + Send + Sync>> {
115 self.inner.sign_transaction(tx).await
116 }
117
118 async fn send_transaction(
119 &self,
120 tx: &Transaction,
121 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
122 let result = self.inner.send_transaction(tx).await?;
123 self.save().await?;
125 Ok(result)
126 }
127
128 async fn get_new_address(
129 &self,
130 label: Option<&str>,
131 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
132 let addr = self.inner.get_new_address(label).await?;
133 self.save().await?;
135 Ok(addr)
136 }
137
138 async fn import_key(
139 &self,
140 privkey_wif: &str,
141 label: Option<&str>,
142 rescan: bool,
143 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
144 self.inner.import_key(privkey_wif, label, rescan).await?;
145 self.save().await?;
147 Ok(())
148 }
149
150 async fn get_transaction_history(
151 &self,
152 count: u32,
153 skip: u32,
154 ) -> Result<Vec<Transaction>, Box<dyn std::error::Error + Send + Sync>> {
155 self.inner.get_transaction_history(count, skip).await
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use abtc_domain::primitives::{Amount, TxOut};
    use abtc_domain::wallet::address::AddressType;
    use abtc_ports::wallet::store::WalletSnapshot;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    /// In-memory `WalletStore` that keeps the last saved snapshot and counts saves.
    struct MockWalletStore {
        data: Mutex<Option<WalletSnapshot>>,
        // A plain atomic is enough for a counter; no lock is needed.
        save_count: AtomicU32,
    }

    impl MockWalletStore {
        fn new() -> Self {
            MockWalletStore {
                data: Mutex::new(None),
                save_count: AtomicU32::new(0),
            }
        }

        /// Number of successful `save` calls so far.
        fn save_count(&self) -> u32 {
            self.save_count.load(Ordering::SeqCst)
        }

        /// Clone of the most recently saved snapshot, if any.
        fn last_snapshot(&self) -> Option<WalletSnapshot> {
            self.data.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl WalletStore for MockWalletStore {
        async fn save(
            &self,
            snapshot: &WalletSnapshot,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            *self.data.lock().unwrap() = Some(snapshot.clone());
            self.save_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn load(
            &self,
        ) -> Result<Option<WalletSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
            Ok(self.data.lock().unwrap().clone())
        }

        async fn delete(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            *self.data.lock().unwrap() = None;
            Ok(())
        }
    }

    /// Store that starts empty but rejects every write, used to check error propagation.
    struct FailingSaveStore;

    #[async_trait]
    impl WalletStore for FailingSaveStore {
        async fn save(
            &self,
            _snapshot: &WalletSnapshot,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Err("save failed".into())
        }

        async fn load(
            &self,
        ) -> Result<Option<WalletSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
            Ok(None)
        }

        async fn delete(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }
    }

    /// Outpoint shared by the UTXO fixtures below.
    fn sample_outpoint() -> OutPoint {
        OutPoint::new(abtc_domain::Txid::zero(), 0)
    }

    /// A confirmed, non-coinbase 50_000-sat UTXO at `sample_outpoint()`.
    fn sample_utxo() -> UnspentOutput {
        UnspentOutput {
            outpoint: sample_outpoint(),
            output: TxOut::new(Amount::from_sat(50_000), abtc_domain::Script::new()),
            confirmations: 3,
            is_coinbase: false,
        }
    }

    #[tokio::test]
    async fn test_persistent_wallet_starts_fresh() {
        let inner = Arc::new(InMemoryWallet::default_testnet());
        let store = Arc::new(MockWalletStore::new());

        let wallet = PersistentWallet::new(inner, store.clone()).await.unwrap();

        let balance = wallet.get_balance().await.unwrap();
        assert_eq!(balance.confirmed.as_sat(), 0);
        assert_eq!(store.save_count(), 0);
    }

    #[tokio::test]
    async fn test_get_new_address_triggers_save() {
        let inner = Arc::new(InMemoryWallet::default_testnet());
        let store = Arc::new(MockWalletStore::new());

        let wallet = PersistentWallet::new(inner, store.clone()).await.unwrap();

        let addr = wallet.get_new_address(Some("test")).await.unwrap();
        assert!(addr.starts_with("tb1q"));
        assert_eq!(store.save_count(), 1);

        let snapshot = store.last_snapshot().expect("a snapshot should have been saved");
        assert_eq!(snapshot.keys.len(), 1);
        assert_eq!(snapshot.keys[0].label, Some("test".to_string()));
    }

    #[tokio::test]
    async fn test_import_key_triggers_save() {
        let inner = Arc::new(InMemoryWallet::default_testnet());
        let store = Arc::new(MockWalletStore::new());

        let wallet = PersistentWallet::new(inner, store.clone()).await.unwrap();

        let key = abtc_domain::wallet::keys::PrivateKey::generate(true, false);
        let wif = key.to_wif();

        wallet
            .import_key(&wif, Some("imported"), false)
            .await
            .unwrap();
        assert_eq!(store.save_count(), 1);
        assert_eq!(wallet.key_count().await, 1);
    }

    #[tokio::test]
    async fn test_add_utxo_triggers_save() {
        let inner = Arc::new(InMemoryWallet::default_testnet());
        let store = Arc::new(MockWalletStore::new());

        let wallet = PersistentWallet::new(inner, store.clone()).await.unwrap();

        wallet.add_utxo(sample_utxo()).await.unwrap();
        assert_eq!(store.save_count(), 1);

        let balance = wallet.get_balance().await.unwrap();
        assert_eq!(balance.confirmed.as_sat(), 50_000);
    }

    #[tokio::test]
    async fn test_remove_utxos_triggers_save() {
        let inner = Arc::new(InMemoryWallet::default_testnet());
        let store = Arc::new(MockWalletStore::new());

        let wallet = PersistentWallet::new(inner, store.clone()).await.unwrap();

        wallet.add_utxo(sample_utxo()).await.unwrap();
        wallet.remove_utxos(&[sample_outpoint()]).await.unwrap();

        // One save for the add and one for the removal.
        assert_eq!(store.save_count(), 2);
        let balance = wallet.get_balance().await.unwrap();
        assert_eq!(balance.confirmed.as_sat(), 0);
    }

    #[tokio::test]
    async fn test_save_failure_is_propagated() {
        let inner = Arc::new(InMemoryWallet::default_testnet());
        let wallet = PersistentWallet::new(inner, Arc::new(FailingSaveStore))
            .await
            .unwrap();

        assert!(wallet.add_utxo(sample_utxo()).await.is_err());
        assert!(wallet.get_new_address(None).await.is_err());
    }

    #[tokio::test]
    async fn test_load_restores_state() {
        let store = Arc::new(MockWalletStore::new());

        // First session: derive one address so a snapshot is written.
        {
            let inner = Arc::new(InMemoryWallet::new(false, AddressType::P2WPKH));
            let wallet = PersistentWallet::new(inner, store.clone()).await.unwrap();

            wallet.get_new_address(Some("restored")).await.unwrap();
            assert_eq!(store.save_count(), 1);
        }

        // Second session: a fresh in-memory wallet should pick up the saved key.
        {
            let inner2 = Arc::new(InMemoryWallet::new(false, AddressType::P2WPKH));
            let wallet2 = PersistentWallet::new(inner2, store.clone()).await.unwrap();

            assert_eq!(wallet2.key_count().await, 1);
        }
    }

    #[tokio::test]
    async fn test_multiple_saves_increment() {
        let inner = Arc::new(InMemoryWallet::default_testnet());
        let store = Arc::new(MockWalletStore::new());

        let wallet = PersistentWallet::new(inner, store.clone()).await.unwrap();

        for _ in 0..3 {
            wallet.get_new_address(None).await.unwrap();
        }

        assert_eq!(store.save_count(), 3);
        assert_eq!(wallet.key_count().await, 3);
    }
}