//! `agent_first_pay/store/postgres_store.rs` — PostgreSQL-backed implementation of the
//! `PayStore` trait (wallet metadata and transaction history).

1use crate::provider::PayError;
2use crate::store::wallet::{self, WalletMetadata};
3use crate::store::{MigrationLog, PayStore};
4use crate::types::{HistoryRecord, Network};
5use sqlx::PgPool;
6use std::collections::BTreeMap;
7use std::path::PathBuf;
8
/// Idempotent DDL executed (via `sqlx::raw_sql`) by `PostgresStore::connect` every time a
/// store is opened.
///
/// Every statement uses `IF NOT EXISTS`, so re-running it is harmless. It also means
/// existing tables are never altered: column or type changes need an explicit migration.
///
/// Tables:
/// - `wallets` / `transactions`: serialized `WalletMetadata` / `HistoryRecord` values stored
///   as JSONB, plus the plain columns this file filters or orders by (`network`, `wallet`,
///   `sequence`, `transaction_id`).
/// - `spend_rules`, `spend_reservations`, `spend_events`, `exchange_rate_cache`: not
///   referenced in this file; presumably owned by the spend / exchange-rate code — confirm
///   there.
///
/// NOTE(review): wallet lookups by label use `metadata->>'label'`, which has no index here,
/// so they scan `wallets`.
const SCHEMA_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS wallets (
    id TEXT PRIMARY KEY,
    network TEXT NOT NULL,
    metadata JSONB NOT NULL,
    created_at_epoch_s BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_wallets_network ON wallets(network);

CREATE TABLE IF NOT EXISTS transactions (
    sequence BIGSERIAL PRIMARY KEY,
    transaction_id TEXT NOT NULL UNIQUE,
    wallet TEXT NOT NULL,
    record JSONB NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_transactions_wallet ON transactions(wallet);

CREATE TABLE IF NOT EXISTS spend_rules (
    rule_id TEXT PRIMARY KEY,
    rule JSONB NOT NULL
);

CREATE TABLE IF NOT EXISTS spend_reservations (
    reservation_id BIGSERIAL PRIMARY KEY,
    op_id TEXT NOT NULL UNIQUE,
    reservation JSONB NOT NULL
);

CREATE TABLE IF NOT EXISTS spend_events (
    event_id BIGSERIAL PRIMARY KEY,
    reservation_id BIGINT NOT NULL,
    event JSONB NOT NULL
);

CREATE TABLE IF NOT EXISTS exchange_rate_cache (
    pair TEXT PRIMARY KEY,
    quote JSONB NOT NULL
);
"#;
48
/// PostgreSQL advisory-lock key for spend operations: it prevents two concurrent
/// spends from interleaving their check-then-write sequence.
///
/// The key is the ASCII bytes of `"afpay"` read as a big-endian integer
/// (`0x61_66_70_61_79`). The value must never change, because every process contending
/// for the lock has to use the same key.
pub const SPEND_ADVISORY_LOCK_KEY: i64 = i64::from_be_bytes(*b"\0\0\0afpay");
52
/// PostgreSQL-backed storage.
///
/// Wallet metadata and transaction history live in the database (tables created from
/// `SCHEMA_SQL`). Per-wallet directory paths are still derived from a local data directory.
#[derive(Clone)]
pub struct PostgresStore {
    /// Connection pool used by every query this store issues (also exposed via `pool()`).
    pool: PgPool,
    /// Root directory handed to the `wallet::wallet_*_directory_path*` helpers.
    data_dir: String,
}
59
60impl PostgresStore {
61    /// Get a reference to the connection pool (used by spend module).
62    pub fn pool(&self) -> &PgPool {
63        &self.pool
64    }
65
66    pub async fn connect(database_url: &str, data_dir: &str) -> Result<Self, PayError> {
67        let pool = PgPool::connect(database_url)
68            .await
69            .map_err(|e| PayError::InternalError(format!("postgres connect: {e}")))?;
70
71        sqlx::raw_sql(SCHEMA_SQL)
72            .execute(&pool)
73            .await
74            .map_err(|e| PayError::InternalError(format!("postgres schema init: {e}")))?;
75
76        Ok(Self {
77            pool,
78            data_dir: data_dir.to_string(),
79        })
80    }
81}
82
83impl PayStore for PostgresStore {
84    fn save_wallet_metadata(&self, meta: &WalletMetadata) -> Result<(), PayError> {
85        let pool = self.pool.clone();
86        let meta_json = serde_json::to_value(meta)
87            .map_err(|e| PayError::InternalError(format!("serialize wallet metadata: {e}")))?;
88        let network_str = meta.network.to_string();
89        let id = meta.id.clone();
90        let created = meta.created_at_epoch_s as i64;
91
92        tokio::task::block_in_place(|| {
93            tokio::runtime::Handle::current().block_on(async {
94                sqlx::query(
95                    "INSERT INTO wallets (id, network, metadata, created_at_epoch_s) \
96                     VALUES ($1, $2, $3, $4) \
97                     ON CONFLICT (id) DO UPDATE SET metadata = $3",
98                )
99                .bind(&id)
100                .bind(&network_str)
101                .bind(&meta_json)
102                .bind(created)
103                .execute(&pool)
104                .await
105                .map_err(|e| PayError::InternalError(format!("postgres save wallet: {e}")))?;
106                Ok(())
107            })
108        })
109    }
110
111    fn load_wallet_metadata(&self, wallet_id: &str) -> Result<WalletMetadata, PayError> {
112        let pool = self.pool.clone();
113        let wallet_id = wallet_id.to_string();
114
115        tokio::task::block_in_place(|| {
116            tokio::runtime::Handle::current().block_on(async {
117                let row: Option<(serde_json::Value,)> =
118                    sqlx::query_as("SELECT metadata FROM wallets WHERE id = $1")
119                        .bind(&wallet_id)
120                        .fetch_optional(&pool)
121                        .await
122                        .map_err(|e| {
123                            PayError::InternalError(format!("postgres load wallet: {e}"))
124                        })?;
125
126                match row {
127                    Some((meta_json,)) => serde_json::from_value(meta_json).map_err(|e| {
128                        PayError::InternalError(format!("postgres parse wallet metadata: {e}"))
129                    }),
130                    None => {
131                        // Label fallback
132                        if !wallet_id.starts_with("w_") {
133                            let row: Option<(serde_json::Value,)> = sqlx::query_as(
134                                "SELECT metadata FROM wallets WHERE metadata->>'label' = $1",
135                            )
136                            .bind(&wallet_id)
137                            .fetch_optional(&pool)
138                            .await
139                            .map_err(|e| {
140                                PayError::InternalError(format!(
141                                    "postgres load wallet by label: {e}"
142                                ))
143                            })?;
144                            if let Some((meta_json,)) = row {
145                                return serde_json::from_value(meta_json).map_err(|e| {
146                                    PayError::InternalError(format!(
147                                        "postgres parse wallet metadata: {e}"
148                                    ))
149                                });
150                            }
151                        }
152                        Err(PayError::WalletNotFound(format!(
153                            "wallet {wallet_id} not found"
154                        )))
155                    }
156                }
157            })
158        })
159    }
160
161    fn list_wallet_metadata(
162        &self,
163        network: Option<Network>,
164    ) -> Result<Vec<WalletMetadata>, PayError> {
165        let pool = self.pool.clone();
166
167        tokio::task::block_in_place(|| {
168            tokio::runtime::Handle::current().block_on(async {
169                let rows: Vec<(serde_json::Value,)> = match network {
170                    Some(n) => {
171                        sqlx::query_as(
172                            "SELECT metadata FROM wallets WHERE network = $1 ORDER BY id",
173                        )
174                        .bind(n.to_string())
175                        .fetch_all(&pool)
176                        .await
177                    }
178                    None => {
179                        sqlx::query_as("SELECT metadata FROM wallets ORDER BY id")
180                            .fetch_all(&pool)
181                            .await
182                    }
183                }
184                .map_err(|e| PayError::InternalError(format!("postgres list wallets: {e}")))?;
185
186                rows.into_iter()
187                    .map(|(meta_json,)| {
188                        serde_json::from_value(meta_json).map_err(|e| {
189                            PayError::InternalError(format!("postgres parse wallet metadata: {e}"))
190                        })
191                    })
192                    .collect()
193            })
194        })
195    }
196
197    fn delete_wallet_metadata(&self, wallet_id: &str) -> Result<(), PayError> {
198        let pool = self.pool.clone();
199        let wallet_id = wallet_id.to_string();
200
201        tokio::task::block_in_place(|| {
202            tokio::runtime::Handle::current().block_on(async {
203                let result = sqlx::query("DELETE FROM wallets WHERE id = $1")
204                    .bind(&wallet_id)
205                    .execute(&pool)
206                    .await
207                    .map_err(|e| PayError::InternalError(format!("postgres delete wallet: {e}")))?;
208
209                if result.rows_affected() == 0 {
210                    return Err(PayError::WalletNotFound(format!(
211                        "wallet {wallet_id} not found"
212                    )));
213                }
214                Ok(())
215            })
216        })
217    }
218
219    fn wallet_directory_path(&self, wallet_id: &str) -> Result<PathBuf, PayError> {
220        wallet::wallet_directory_path(&self.data_dir, wallet_id)
221    }
222
223    fn wallet_data_directory_path(&self, wallet_id: &str) -> Result<PathBuf, PayError> {
224        wallet::wallet_data_directory_path(&self.data_dir, wallet_id)
225    }
226
227    fn wallet_data_directory_path_for_meta(&self, meta: &WalletMetadata) -> PathBuf {
228        wallet::wallet_data_directory_path_for_wallet_metadata(&self.data_dir, meta)
229    }
230
231    fn resolve_wallet_id(&self, id_or_label: &str) -> Result<String, PayError> {
232        if id_or_label.starts_with("w_") {
233            return Ok(id_or_label.to_string());
234        }
235        let all = self.list_wallet_metadata(None)?;
236        let mut matches: Vec<&WalletMetadata> = all
237            .iter()
238            .filter(|w| w.label.as_deref() == Some(id_or_label))
239            .collect();
240        match matches.len() {
241            0 => Err(PayError::WalletNotFound(format!(
242                "no wallet found with ID or label '{id_or_label}'"
243            ))),
244            1 => Ok(matches.remove(0).id.clone()),
245            n => Err(PayError::InvalidAmount(format!(
246                "label '{id_or_label}' matches {n} wallets — use wallet ID instead"
247            ))),
248        }
249    }
250
251    fn append_transaction_record(&self, record: &HistoryRecord) -> Result<(), PayError> {
252        let pool = self.pool.clone();
253        let record_json = serde_json::to_value(record)
254            .map_err(|e| PayError::InternalError(format!("serialize transaction record: {e}")))?;
255        let tx_id = record.transaction_id.clone();
256        let wallet = record.wallet.clone();
257
258        tokio::task::block_in_place(|| {
259            tokio::runtime::Handle::current().block_on(async {
260                sqlx::query(
261                    "INSERT INTO transactions (transaction_id, wallet, record) \
262                     VALUES ($1, $2, $3) \
263                     ON CONFLICT (transaction_id) DO NOTHING",
264                )
265                .bind(&tx_id)
266                .bind(&wallet)
267                .bind(&record_json)
268                .execute(&pool)
269                .await
270                .map_err(|e| {
271                    PayError::InternalError(format!("postgres append transaction: {e}"))
272                })?;
273                Ok(())
274            })
275        })
276    }
277
278    fn load_wallet_transaction_records(
279        &self,
280        wallet_id: &str,
281    ) -> Result<Vec<HistoryRecord>, PayError> {
282        let pool = self.pool.clone();
283        let wallet_id = wallet_id.to_string();
284
285        tokio::task::block_in_place(|| {
286            tokio::runtime::Handle::current().block_on(async {
287                let rows: Vec<(serde_json::Value,)> = sqlx::query_as(
288                    "SELECT record FROM transactions WHERE wallet = $1 ORDER BY sequence",
289                )
290                .bind(&wallet_id)
291                .fetch_all(&pool)
292                .await
293                .map_err(|e| PayError::InternalError(format!("postgres load transactions: {e}")))?;
294
295                rows.into_iter()
296                    .map(|(record_json,)| {
297                        serde_json::from_value(record_json).map_err(|e| {
298                            PayError::InternalError(format!(
299                                "postgres parse transaction record: {e}"
300                            ))
301                        })
302                    })
303                    .collect()
304            })
305        })
306    }
307
308    fn find_transaction_record_by_id(
309        &self,
310        tx_id: &str,
311    ) -> Result<Option<HistoryRecord>, PayError> {
312        let pool = self.pool.clone();
313        let tx_id = tx_id.to_string();
314
315        tokio::task::block_in_place(|| {
316            tokio::runtime::Handle::current().block_on(async {
317                let row: Option<(serde_json::Value,)> =
318                    sqlx::query_as("SELECT record FROM transactions WHERE transaction_id = $1")
319                        .bind(&tx_id)
320                        .fetch_optional(&pool)
321                        .await
322                        .map_err(|e| {
323                            PayError::InternalError(format!("postgres find transaction: {e}"))
324                        })?;
325
326                match row {
327                    Some((record_json,)) => {
328                        let record: HistoryRecord =
329                            serde_json::from_value(record_json).map_err(|e| {
330                                PayError::InternalError(format!(
331                                    "postgres parse transaction record: {e}"
332                                ))
333                            })?;
334                        Ok(Some(record))
335                    }
336                    None => Ok(None),
337                }
338            })
339        })
340    }
341
342    fn update_transaction_record_memo(
343        &self,
344        tx_id: &str,
345        memo: Option<&BTreeMap<String, String>>,
346    ) -> Result<(), PayError> {
347        let pool = self.pool.clone();
348        let tx_id = tx_id.to_string();
349        let memo_json = serde_json::to_value(memo)
350            .map_err(|e| PayError::InternalError(format!("serialize memo: {e}")))?;
351
352        tokio::task::block_in_place(|| {
353            tokio::runtime::Handle::current().block_on(async {
354                // Read existing record, update memo, write back
355                let row: Option<(serde_json::Value,)> =
356                    sqlx::query_as("SELECT record FROM transactions WHERE transaction_id = $1")
357                        .bind(&tx_id)
358                        .fetch_optional(&pool)
359                        .await
360                        .map_err(|e| {
361                            PayError::InternalError(format!("postgres read transaction: {e}"))
362                        })?;
363
364                let Some((record_json,)) = row else {
365                    return Err(PayError::WalletNotFound(format!(
366                        "transaction {tx_id} not found"
367                    )));
368                };
369
370                let mut record: HistoryRecord = serde_json::from_value(record_json)
371                    .map_err(|e| PayError::InternalError(format!("postgres parse record: {e}")))?;
372                record.local_memo = serde_json::from_value(memo_json)
373                    .map_err(|e| PayError::InternalError(format!("postgres parse memo: {e}")))?;
374                let updated_json = serde_json::to_value(&record).map_err(|e| {
375                    PayError::InternalError(format!("serialize updated record: {e}"))
376                })?;
377
378                sqlx::query("UPDATE transactions SET record = $1 WHERE transaction_id = $2")
379                    .bind(&updated_json)
380                    .bind(&tx_id)
381                    .execute(&pool)
382                    .await
383                    .map_err(|e| {
384                        PayError::InternalError(format!("postgres update transaction memo: {e}"))
385                    })?;
386                Ok(())
387            })
388        })
389    }
390
391    fn update_transaction_record_fee(
392        &self,
393        tx_id: &str,
394        fee_value: u64,
395        fee_unit: &str,
396    ) -> Result<(), PayError> {
397        let pool = self.pool.clone();
398        let tx_id = tx_id.to_string();
399        let fee_unit = fee_unit.to_string();
400
401        tokio::task::block_in_place(|| {
402            tokio::runtime::Handle::current().block_on(async {
403                let row: Option<(serde_json::Value,)> =
404                    sqlx::query_as("SELECT record FROM transactions WHERE transaction_id = $1")
405                        .bind(&tx_id)
406                        .fetch_optional(&pool)
407                        .await
408                        .map_err(|e| {
409                            PayError::InternalError(format!("postgres read transaction: {e}"))
410                        })?;
411
412                let Some((record_json,)) = row else {
413                    return Err(PayError::WalletNotFound(format!(
414                        "transaction {tx_id} not found"
415                    )));
416                };
417
418                let mut record: HistoryRecord = serde_json::from_value(record_json)
419                    .map_err(|e| PayError::InternalError(format!("postgres parse record: {e}")))?;
420                record.fee = Some(crate::types::Amount {
421                    value: fee_value,
422                    token: fee_unit,
423                });
424                let updated_json = serde_json::to_value(&record).map_err(|e| {
425                    PayError::InternalError(format!("serialize updated record: {e}"))
426                })?;
427
428                sqlx::query("UPDATE transactions SET record = $1 WHERE transaction_id = $2")
429                    .bind(&updated_json)
430                    .bind(&tx_id)
431                    .execute(&pool)
432                    .await
433                    .map_err(|e| {
434                        PayError::InternalError(format!("postgres update transaction fee: {e}"))
435                    })?;
436                Ok(())
437            })
438        })
439    }
440
441    fn drain_migration_log(&self) -> Vec<MigrationLog> {
442        Vec::new()
443    }
444}