Skip to main content

cdk_sql_common/mint/
mod.rs

//! SQL database implementation of the Mint
//!
//! This is a generic SQL implementation of the mint storage layer. Any database can be plugged
//! in as long as it understands standard ANSI SQL, as both Postgres and SQLite do.
//!
//! This implementation also has a rudimentary but standard migration and versioning system.
//!
//! The storage traits expect asynchronous interaction, but this crate also provides tools to
//! spawn blocking clients in a pool and expose them to an asynchronous environment, making them
//! compatible with the Mint.
11use std::fmt::Debug;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15use cdk_common::database::{self, DbTransactionFinalizer, Error, MintDatabase};
16
17use crate::common::migrate;
18use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
19use crate::pool::{DatabasePool, Pool, PooledResource};
20
21mod auth;
22mod completed_operations;
23mod keys;
24mod keyvalue;
25mod proofs;
26mod quotes;
27mod saga;
28mod signatures;
29
// Mint schema migrations, pulled in verbatim from `migrations_mint.rs` in `OUT_DIR`
// (presumably written by this crate's build script — confirm). The included file
// provides `MIGRATIONS`, which `SQLMintDatabase::migrate` applies. Because the contents
// are generated, rustfmt is told not to touch this module.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint.rs"));
}
34
35pub use auth::SQLMintAuthDatabase;
36#[cfg(feature = "prometheus")]
37use cdk_prometheus::MintMetricGuard;
38use migrations::MIGRATIONS;
39
40/// Mint SQL Database
41#[derive(Debug, Clone)]
42pub struct SQLMintDatabase<RM>
43where
44    RM: DatabasePool + 'static,
45{
46    pub(crate) pool: Arc<Pool<RM>>,
47}
48
/// SQL Transaction Writer
///
/// Owns a pooled connection on which a transaction has already been begun. It is finished
/// through [`DbTransactionFinalizer::commit`] or [`DbTransactionFinalizer::rollback`].
// NOTE(review): `Debug` is presumably omitted because `ConnectionWithTransaction` does not
// implement it — confirm before replacing this allow with a derive.
#[allow(missing_debug_implementations)]
pub struct SQLTransaction<RM>
where
    RM: DatabasePool + 'static,
{
    /// Connection checked out from the pool, wrapped together with its open transaction.
    pub(crate) inner: ConnectionWithTransaction<RM::Connection, PooledResource<RM>>,
}
57
58impl<RM> SQLMintDatabase<RM>
59where
60    RM: DatabasePool + 'static,
61{
62    /// Creates a new instance
63    pub async fn new<X>(db: X) -> Result<Self, Error>
64    where
65        X: Into<RM::Config>,
66    {
67        let pool = Pool::new(db.into());
68
69        Self::migrate(pool.get().await.map_err(|e| Error::Database(Box::new(e)))?).await?;
70
71        Ok(Self { pool })
72    }
73
74    /// Migrate
75    async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
76        let tx = ConnectionWithTransaction::new(conn).await?;
77        migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
78        tx.commit().await?;
79        Ok(())
80    }
81}
82
83#[async_trait]
84impl<RM> database::MintTransaction<Error> for SQLTransaction<RM> where RM: DatabasePool + 'static {}
85
86#[async_trait]
87impl<RM> DbTransactionFinalizer for SQLTransaction<RM>
88where
89    RM: DatabasePool + 'static,
90{
91    type Err = Error;
92
93    async fn commit(self: Box<Self>) -> Result<(), Error> {
94        #[cfg(feature = "prometheus")]
95        let metrics = MintMetricGuard::new("transaction_commit");
96
97        let result = self.inner.commit().await;
98
99        #[cfg(feature = "prometheus")]
100        {
101            metrics.record(result.is_ok());
102        }
103
104        Ok(result?)
105    }
106
107    async fn rollback(self: Box<Self>) -> Result<(), Error> {
108        #[cfg(feature = "prometheus")]
109        let metrics = MintMetricGuard::new("transaction_rollback");
110
111        let result = self.inner.rollback().await;
112
113        #[cfg(feature = "prometheus")]
114        {
115            metrics.record(result.is_ok());
116        }
117        Ok(result?)
118    }
119}
120
121#[async_trait]
122impl<RM> MintDatabase<Error> for SQLMintDatabase<RM>
123where
124    RM: DatabasePool + 'static,
125{
126    async fn begin_transaction(
127        &self,
128    ) -> Result<Box<dyn database::MintTransaction<Error> + Send + Sync>, Error> {
129        let tx = SQLTransaction {
130            inner: ConnectionWithTransaction::new(
131                self.pool
132                    .get()
133                    .await
134                    .map_err(|e| Error::Database(Box::new(e)))?,
135            )
136            .await?,
137        };
138
139        Ok(Box::new(tx))
140    }
141}
142
#[cfg(all(test, feature = "prometheus"))]
mod tests {
    use std::fmt;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;
    use std::time::Duration;

    use cdk_common::database::{DbTransactionFinalizer, Error as DatabaseError};
    use cdk_prometheus::prometheus::proto::Metric;
    use cdk_prometheus::METRICS;

    use super::SQLTransaction;
    use crate::database::{
        ConnectionWithTransaction, DatabaseConnector, DatabaseExecutor, DatabaseTransaction,
    };
    use crate::pool::{DatabaseConfig, DatabasePool, Error as PoolError, Pool};
    use crate::stmt::{Column, Statement};

    /// Counter family incremented once per finished operation, labelled by status.
    const OPERATIONS_TOTAL: &str = "cdk_mint_operations_total";
    /// Histogram family observing each finished operation's duration, labelled by status.
    const OPERATION_DURATION: &str = "cdk_mint_operation_duration_seconds";
    /// Gauge family tracking operations currently in flight, labelled by operation only.
    const IN_FLIGHT: &str = "cdk_mint_in_flight_requests";

    /// Pool configuration whose flags make every connection fail commit and/or rollback.
    #[derive(Debug, Clone)]
    struct TestConfig {
        fail_commit: bool,
        fail_rollback: bool,
    }

    impl DatabaseConfig for TestConfig {
        fn max_size(&self) -> usize {
            1
        }

        fn default_timeout(&self) -> Duration {
            Duration::from_millis(10)
        }
    }

    #[derive(Debug)]
    struct TestResourceError;

    impl fmt::Display for TestResourceError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("test resource error")
        }
    }

    impl std::error::Error for TestResourceError {}

    /// In-memory connection: every statement succeeds with an empty result.
    #[derive(Debug)]
    struct TestConnection {
        fail_commit: bool,
        fail_rollback: bool,
    }

    #[async_trait::async_trait]
    impl DatabaseExecutor for TestConnection {
        fn name() -> &'static str {
            "test"
        }

        async fn execute(&self, _statement: Statement) -> Result<usize, DatabaseError> {
            Ok(0)
        }

        async fn fetch_one(
            &self,
            _statement: Statement,
        ) -> Result<Option<Vec<Column>>, DatabaseError> {
            Ok(None)
        }

        async fn fetch_all(
            &self,
            _statement: Statement,
        ) -> Result<Vec<Vec<Column>>, DatabaseError> {
            Ok(Vec::new())
        }

        async fn pluck(&self, _statement: Statement) -> Result<Option<Column>, DatabaseError> {
            Ok(None)
        }

        async fn batch(&self, _statement: Statement) -> Result<(), DatabaseError> {
            Ok(())
        }
    }

    /// Transaction driver that fails commit/rollback according to the connection's flags.
    #[derive(Debug)]
    struct TestTransaction;

    #[async_trait::async_trait]
    impl DatabaseTransaction<TestConnection> for TestTransaction {
        async fn commit(conn: &mut TestConnection) -> Result<(), DatabaseError> {
            if conn.fail_commit {
                Err(DatabaseError::Internal("commit failed".to_owned()))
            } else {
                Ok(())
            }
        }

        async fn begin(_conn: &mut TestConnection) -> Result<(), DatabaseError> {
            Ok(())
        }

        async fn rollback(conn: &mut TestConnection) -> Result<(), DatabaseError> {
            if conn.fail_rollback {
                Err(DatabaseError::Internal("rollback failed".to_owned()))
            } else {
                Ok(())
            }
        }
    }

    impl DatabaseConnector for TestConnection {
        type Transaction = TestTransaction;
    }

    #[derive(Debug)]
    struct TestPool;

    impl DatabasePool for TestPool {
        type Connection = TestConnection;
        type Config = TestConfig;
        type Error = TestResourceError;

        fn new_resource(
            config: &Self::Config,
            _stale: Arc<AtomicBool>,
            _timeout: Duration,
        ) -> Result<Self::Connection, PoolError<Self::Error>> {
            Ok(TestConnection {
                fail_commit: config.fail_commit,
                fail_rollback: config.fail_rollback,
            })
        }
    }

    /// Opens a transaction on a fresh single-connection test pool.
    async fn new_transaction(fail_commit: bool, fail_rollback: bool) -> SQLTransaction<TestPool> {
        let pool = Pool::<TestPool>::new(TestConfig {
            fail_commit,
            fail_rollback,
        });
        let conn = pool
            .get()
            .await
            .expect("test resource should be checked out");
        let inner = ConnectionWithTransaction::new(conn)
            .await
            .expect("test transaction should begin");

        SQLTransaction { inner }
    }

    /// Returns `true` when `metric` carries every `(name, value)` pair in `labels`.
    fn labels_match(metric: &Metric, labels: &[(&str, &str)]) -> bool {
        labels.iter().all(|(name, value)| {
            metric
                .get_label()
                .iter()
                .any(|label| label.get_name() == *name && label.get_value() == *value)
        })
    }

    /// Finds the first metric in family `name` whose labels include `labels` and extracts a
    /// value from it with `read`. Returns `0.0` when no such metric has been gathered.
    fn metric_value(name: &str, labels: &[(&str, &str)], read: impl Fn(&Metric) -> f64) -> f64 {
        let families = METRICS.registry().gather();
        families
            .iter()
            .filter(|family| family.get_name() == name)
            .find_map(|family| {
                family
                    .get_metric()
                    .iter()
                    .find(|metric| labels_match(metric, labels))
                    .map(&read)
            })
            .unwrap_or(0.0)
    }

    fn counter_value(name: &str, labels: &[(&str, &str)]) -> f64 {
        metric_value(name, labels, |metric| metric.get_counter().get_value())
    }

    fn gauge_value(name: &str, labels: &[(&str, &str)]) -> f64 {
        metric_value(name, labels, |metric| metric.get_gauge().get_value())
    }

    fn histogram_count(name: &str, labels: &[(&str, &str)]) -> f64 {
        metric_value(name, labels, |metric| {
            metric.get_histogram().get_sample_count() as f64
        })
    }

    /// Metric readings for one `(operation, status)` pair, captured before an action so
    /// the changes that action causes can be asserted afterwards.
    struct MetricSnapshot {
        operation: &'static str,
        status: &'static str,
        operations_total: f64,
        duration_count: f64,
        in_flight: f64,
    }

    impl MetricSnapshot {
        /// Reads the current counter, histogram sample count and in-flight gauge.
        fn take(operation: &'static str, status: &'static str) -> Self {
            let labels = [("operation", operation), ("status", status)];
            let in_flight_labels = [("operation", operation)];

            Self {
                operation,
                status,
                operations_total: counter_value(OPERATIONS_TOTAL, &labels),
                duration_count: histogram_count(OPERATION_DURATION, &labels),
                in_flight: gauge_value(IN_FLIGHT, &in_flight_labels),
            }
        }

        /// Asserts that exactly one operation with this status was counted and timed since
        /// the snapshot, and that the in-flight gauge is back at its previous level.
        fn assert_recorded_once(&self) {
            let after = Self::take(self.operation, self.status);
            let (operation, status) = (self.operation, self.status);

            assert_eq!(
                after.operations_total,
                self.operations_total + 1.0,
                "{operation}/{status}: operations counter"
            );
            assert_eq!(
                after.duration_count,
                self.duration_count + 1.0,
                "{operation}/{status}: duration sample count"
            );
            assert_eq!(
                after.in_flight, self.in_flight,
                "{operation}: in-flight gauge should be balanced"
            );
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn transaction_commit_records_success_duration_and_balances_in_flight() {
        let _lock = crate::metrics_test_lock::lock().await;
        let before = MetricSnapshot::take("transaction_commit", "success");

        Box::new(new_transaction(false, false).await)
            .commit()
            .await
            .expect("transaction commit should succeed");

        before.assert_recorded_once();
    }

    #[tokio::test(flavor = "current_thread")]
    async fn transaction_commit_records_error_duration_and_balances_in_flight() {
        let _lock = crate::metrics_test_lock::lock().await;
        let before = MetricSnapshot::take("transaction_commit", "error");

        Box::new(new_transaction(true, false).await)
            .commit()
            .await
            .expect_err("transaction commit should fail");

        before.assert_recorded_once();
    }

    #[tokio::test(flavor = "current_thread")]
    async fn transaction_rollback_records_success_duration_and_balances_in_flight() {
        let _lock = crate::metrics_test_lock::lock().await;
        let before = MetricSnapshot::take("transaction_rollback", "success");

        Box::new(new_transaction(false, false).await)
            .rollback()
            .await
            .expect("transaction rollback should succeed");

        before.assert_recorded_once();
    }

    // Covers the `fail_rollback` path, which no other test exercises.
    #[tokio::test(flavor = "current_thread")]
    async fn transaction_rollback_records_error_duration_and_balances_in_flight() {
        let _lock = crate::metrics_test_lock::lock().await;
        let before = MetricSnapshot::take("transaction_rollback", "error");

        Box::new(new_transaction(false, true).await)
            .rollback()
            .await
            .expect_err("transaction rollback should fail");

        before.assert_recorded_once();
    }
}