1use std::fmt::Debug;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15use cdk_common::database::{self, DbTransactionFinalizer, Error, MintDatabase};
16
17use crate::common::migrate;
18use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
19use crate::pool::{DatabasePool, Pool, PooledResource};
20
21mod auth;
22mod completed_operations;
23mod keys;
24mod keyvalue;
25mod proofs;
26mod quotes;
27mod saga;
28mod signatures;
29
/// Mint schema migrations, pulled in from `$OUT_DIR/migrations_mint.rs`.
///
/// The file is produced at build time (presumably by this crate's build script —
/// confirm) and provides at least `MIGRATIONS`, which `SQLMintDatabase::migrate`
/// applies. Formatting is skipped because the included content is generated.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint.rs"));
}
34
35pub use auth::SQLMintAuthDatabase;
36#[cfg(feature = "prometheus")]
37use cdk_prometheus::MintMetricGuard;
38use migrations::MIGRATIONS;
39
40#[derive(Debug, Clone)]
42pub struct SQLMintDatabase<RM>
43where
44 RM: DatabasePool + 'static,
45{
46 pub(crate) pool: Arc<Pool<RM>>,
47}
48
49#[allow(missing_debug_implementations)]
51pub struct SQLTransaction<RM>
52where
53 RM: DatabasePool + 'static,
54{
55 pub(crate) inner: ConnectionWithTransaction<RM::Connection, PooledResource<RM>>,
56}
57
58impl<RM> SQLMintDatabase<RM>
59where
60 RM: DatabasePool + 'static,
61{
62 pub async fn new<X>(db: X) -> Result<Self, Error>
64 where
65 X: Into<RM::Config>,
66 {
67 let pool = Pool::new(db.into());
68
69 Self::migrate(pool.get().await.map_err(|e| Error::Database(Box::new(e)))?).await?;
70
71 Ok(Self { pool })
72 }
73
74 async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
76 let tx = ConnectionWithTransaction::new(conn).await?;
77 migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
78 tx.commit().await?;
79 Ok(())
80 }
81}
82
83#[async_trait]
84impl<RM> database::MintTransaction<Error> for SQLTransaction<RM> where RM: DatabasePool + 'static {}
85
86#[async_trait]
87impl<RM> DbTransactionFinalizer for SQLTransaction<RM>
88where
89 RM: DatabasePool + 'static,
90{
91 type Err = Error;
92
93 async fn commit(self: Box<Self>) -> Result<(), Error> {
94 #[cfg(feature = "prometheus")]
95 let metrics = MintMetricGuard::new("transaction_commit");
96
97 let result = self.inner.commit().await;
98
99 #[cfg(feature = "prometheus")]
100 {
101 metrics.record(result.is_ok());
102 }
103
104 Ok(result?)
105 }
106
107 async fn rollback(self: Box<Self>) -> Result<(), Error> {
108 #[cfg(feature = "prometheus")]
109 let metrics = MintMetricGuard::new("transaction_rollback");
110
111 let result = self.inner.rollback().await;
112
113 #[cfg(feature = "prometheus")]
114 {
115 metrics.record(result.is_ok());
116 }
117 Ok(result?)
118 }
119}
120
121#[async_trait]
122impl<RM> MintDatabase<Error> for SQLMintDatabase<RM>
123where
124 RM: DatabasePool + 'static,
125{
126 async fn begin_transaction(
127 &self,
128 ) -> Result<Box<dyn database::MintTransaction<Error> + Send + Sync>, Error> {
129 let tx = SQLTransaction {
130 inner: ConnectionWithTransaction::new(
131 self.pool
132 .get()
133 .await
134 .map_err(|e| Error::Database(Box::new(e)))?,
135 )
136 .await?,
137 };
138
139 Ok(Box::new(tx))
140 }
141}
142
143#[cfg(all(test, feature = "prometheus"))]
144mod tests {
145 use std::fmt;
146 use std::sync::atomic::AtomicBool;
147 use std::sync::Arc;
148 use std::time::Duration;
149
150 use cdk_common::database::{DbTransactionFinalizer, Error as DatabaseError};
151 use cdk_prometheus::METRICS;
152
153 use super::SQLTransaction;
154 use crate::database::{
155 ConnectionWithTransaction, DatabaseConnector, DatabaseExecutor, DatabaseTransaction,
156 };
157 use crate::pool::{DatabaseConfig, DatabasePool, Error as PoolError, Pool};
158 use crate::stmt::{Column, Statement};
159
160 #[derive(Debug, Clone)]
161 struct TestConfig {
162 fail_commit: bool,
163 fail_rollback: bool,
164 }
165
166 impl DatabaseConfig for TestConfig {
167 fn max_size(&self) -> usize {
168 1
169 }
170
171 fn default_timeout(&self) -> Duration {
172 Duration::from_millis(10)
173 }
174 }
175
176 #[derive(Debug)]
177 struct TestResourceError;
178
179 impl fmt::Display for TestResourceError {
180 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 f.write_str("test resource error")
182 }
183 }
184
185 impl std::error::Error for TestResourceError {}
186
187 #[derive(Debug)]
188 struct TestConnection {
189 fail_commit: bool,
190 fail_rollback: bool,
191 }
192
193 #[async_trait::async_trait]
194 impl DatabaseExecutor for TestConnection {
195 fn name() -> &'static str {
196 "test"
197 }
198
199 async fn execute(&self, _statement: Statement) -> Result<usize, DatabaseError> {
200 Ok(0)
201 }
202
203 async fn fetch_one(
204 &self,
205 _statement: Statement,
206 ) -> Result<Option<Vec<Column>>, DatabaseError> {
207 Ok(None)
208 }
209
210 async fn fetch_all(
211 &self,
212 _statement: Statement,
213 ) -> Result<Vec<Vec<Column>>, DatabaseError> {
214 Ok(Vec::new())
215 }
216
217 async fn pluck(&self, _statement: Statement) -> Result<Option<Column>, DatabaseError> {
218 Ok(None)
219 }
220
221 async fn batch(&self, _statement: Statement) -> Result<(), DatabaseError> {
222 Ok(())
223 }
224 }
225
226 #[derive(Debug)]
227 struct TestTransaction;
228
229 #[async_trait::async_trait]
230 impl DatabaseTransaction<TestConnection> for TestTransaction {
231 async fn commit(conn: &mut TestConnection) -> Result<(), DatabaseError> {
232 if conn.fail_commit {
233 Err(DatabaseError::Internal("commit failed".to_owned()))
234 } else {
235 Ok(())
236 }
237 }
238
239 async fn begin(_conn: &mut TestConnection) -> Result<(), DatabaseError> {
240 Ok(())
241 }
242
243 async fn rollback(conn: &mut TestConnection) -> Result<(), DatabaseError> {
244 if conn.fail_rollback {
245 Err(DatabaseError::Internal("rollback failed".to_owned()))
246 } else {
247 Ok(())
248 }
249 }
250 }
251
252 impl DatabaseConnector for TestConnection {
253 type Transaction = TestTransaction;
254 }
255
256 #[derive(Debug)]
257 struct TestPool;
258
259 impl DatabasePool for TestPool {
260 type Connection = TestConnection;
261 type Config = TestConfig;
262 type Error = TestResourceError;
263
264 fn new_resource(
265 config: &Self::Config,
266 _stale: Arc<AtomicBool>,
267 _timeout: Duration,
268 ) -> Result<Self::Connection, PoolError<Self::Error>> {
269 Ok(TestConnection {
270 fail_commit: config.fail_commit,
271 fail_rollback: config.fail_rollback,
272 })
273 }
274 }
275
276 async fn new_transaction(fail_commit: bool, fail_rollback: bool) -> SQLTransaction<TestPool> {
277 let pool = Pool::<TestPool>::new(TestConfig {
278 fail_commit,
279 fail_rollback,
280 });
281 let conn = pool
282 .get()
283 .await
284 .expect("test resource should be checked out");
285 let inner = ConnectionWithTransaction::new(conn)
286 .await
287 .expect("test transaction should begin");
288
289 SQLTransaction { inner }
290 }
291
292 fn labels_match(
293 metric: &cdk_prometheus::prometheus::proto::Metric,
294 labels: &[(&str, &str)],
295 ) -> bool {
296 labels.iter().all(|(name, value)| {
297 metric
298 .get_label()
299 .iter()
300 .any(|label| label.get_name() == *name && label.get_value() == *value)
301 })
302 }
303
304 fn counter_value(name: &str, labels: &[(&str, &str)]) -> f64 {
305 for family in METRICS.registry().gather() {
306 if family.get_name() != name {
307 continue;
308 }
309
310 for metric in family.get_metric() {
311 if labels_match(metric, labels) {
312 return metric.get_counter().get_value();
313 }
314 }
315 }
316
317 0.0
318 }
319
320 fn gauge_value(name: &str, labels: &[(&str, &str)]) -> f64 {
321 for family in METRICS.registry().gather() {
322 if family.get_name() != name {
323 continue;
324 }
325
326 for metric in family.get_metric() {
327 if labels_match(metric, labels) {
328 return metric.get_gauge().get_value();
329 }
330 }
331 }
332
333 0.0
334 }
335
336 fn histogram_count(name: &str, labels: &[(&str, &str)]) -> f64 {
337 for family in METRICS.registry().gather() {
338 if family.get_name() != name {
339 continue;
340 }
341
342 for metric in family.get_metric() {
343 if labels_match(metric, labels) {
344 return metric.get_histogram().get_sample_count() as f64;
345 }
346 }
347 }
348
349 0.0
350 }
351
352 #[tokio::test(flavor = "current_thread")]
353 async fn transaction_commit_records_success_duration_and_balances_in_flight() {
354 let _lock = crate::metrics_test_lock::lock().await;
355 let operation = "transaction_commit";
356 let labels = [("operation", operation), ("status", "success")];
357 let in_flight_labels = [("operation", operation)];
358
359 let success_before = counter_value("cdk_mint_operations_total", &labels);
360 let duration_count_before = histogram_count("cdk_mint_operation_duration_seconds", &labels);
361 let in_flight_before = gauge_value("cdk_mint_in_flight_requests", &in_flight_labels);
362
363 let tx = new_transaction(false, false).await;
364 Box::new(tx)
365 .commit()
366 .await
367 .expect("transaction commit should succeed");
368
369 assert_eq!(
370 counter_value("cdk_mint_operations_total", &labels),
371 success_before + 1.0
372 );
373 assert_eq!(
374 histogram_count("cdk_mint_operation_duration_seconds", &labels),
375 duration_count_before + 1.0
376 );
377 assert_eq!(
378 gauge_value("cdk_mint_in_flight_requests", &in_flight_labels),
379 in_flight_before
380 );
381 }
382
383 #[tokio::test(flavor = "current_thread")]
384 async fn transaction_commit_records_error_duration_and_balances_in_flight() {
385 let _lock = crate::metrics_test_lock::lock().await;
386 let operation = "transaction_commit";
387 let labels = [("operation", operation), ("status", "error")];
388 let in_flight_labels = [("operation", operation)];
389
390 let error_before = counter_value("cdk_mint_operations_total", &labels);
391 let duration_count_before = histogram_count("cdk_mint_operation_duration_seconds", &labels);
392 let in_flight_before = gauge_value("cdk_mint_in_flight_requests", &in_flight_labels);
393
394 let tx = new_transaction(true, false).await;
395 Box::new(tx)
396 .commit()
397 .await
398 .expect_err("transaction commit should fail");
399
400 assert_eq!(
401 counter_value("cdk_mint_operations_total", &labels),
402 error_before + 1.0
403 );
404 assert_eq!(
405 histogram_count("cdk_mint_operation_duration_seconds", &labels),
406 duration_count_before + 1.0
407 );
408 assert_eq!(
409 gauge_value("cdk_mint_in_flight_requests", &in_flight_labels),
410 in_flight_before
411 );
412 }
413
414 #[tokio::test(flavor = "current_thread")]
415 async fn transaction_rollback_records_success_duration_and_balances_in_flight() {
416 let _lock = crate::metrics_test_lock::lock().await;
417 let operation = "transaction_rollback";
418 let labels = [("operation", operation), ("status", "success")];
419 let in_flight_labels = [("operation", operation)];
420
421 let success_before = counter_value("cdk_mint_operations_total", &labels);
422 let duration_count_before = histogram_count("cdk_mint_operation_duration_seconds", &labels);
423 let in_flight_before = gauge_value("cdk_mint_in_flight_requests", &in_flight_labels);
424
425 let tx = new_transaction(false, false).await;
426 Box::new(tx)
427 .rollback()
428 .await
429 .expect("transaction rollback should succeed");
430
431 assert_eq!(
432 counter_value("cdk_mint_operations_total", &labels),
433 success_before + 1.0
434 );
435 assert_eq!(
436 histogram_count("cdk_mint_operation_duration_seconds", &labels),
437 duration_count_before + 1.0
438 );
439 assert_eq!(
440 gauge_value("cdk_mint_in_flight_requests", &in_flight_labels),
441 in_flight_before
442 );
443 }
444}