1use std::collections::HashMap;
4use std::path::Path;
5use std::str::FromStr;
6use std::time::Duration;
7
8use async_trait::async_trait;
9use cdk_common::database::{self, MintAuthDatabase};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
12use cdk_common::{AuthRequired, ProtectedEndpoint};
13use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
14use sqlx::Row;
15use tracing::instrument;
16
17use super::{sqlite_row_to_blind_signature, sqlite_row_to_keyset_info};
18use crate::mint::Error;
19
20#[derive(Debug, Clone)]
22pub struct MintSqliteAuthDatabase {
23 pool: SqlitePool,
24}
25
26impl MintSqliteAuthDatabase {
27 pub async fn new(path: &Path) -> Result<Self, Error> {
29 let path = path.to_str().ok_or(Error::InvalidDbPath)?;
30 let db_options = SqliteConnectOptions::from_str(path)?
31 .busy_timeout(Duration::from_secs(5))
32 .read_only(false)
33 .create_if_missing(true)
34 .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full);
35
36 let pool = SqlitePoolOptions::new()
37 .max_connections(1)
38 .connect_with(db_options)
39 .await?;
40
41 Ok(Self { pool })
42 }
43
44 pub async fn migrate(&self) {
46 sqlx::migrate!("./src/mint/auth/migrations")
47 .run(&self.pool)
48 .await
49 .expect("Could not run migrations");
50 }
51}
52
53#[async_trait]
54impl MintAuthDatabase for MintSqliteAuthDatabase {
55 type Err = database::Error;
56
57 #[instrument(skip(self))]
58 async fn set_active_keyset(&self, id: Id) -> Result<(), Self::Err> {
59 tracing::info!("Setting auth keyset {id} active");
60 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
61 let update_res = sqlx::query(
62 r#"
63 UPDATE keyset
64 SET active = CASE
65 WHEN id = ? THEN TRUE
66 ELSE FALSE
67 END;
68 "#,
69 )
70 .bind(id.to_string())
71 .execute(&mut *transaction)
72 .await;
73
74 match update_res {
75 Ok(_) => {
76 transaction.commit().await.map_err(Error::from)?;
77 Ok(())
78 }
79 Err(err) => {
80 tracing::error!("SQLite Could not update keyset");
81 if let Err(err) = transaction.rollback().await {
82 tracing::error!("Could not rollback sql transaction: {}", err);
83 }
84 Err(Error::from(err).into())
85 }
86 }
87 }
88
89 async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
90 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
91
92 let rec = sqlx::query(
93 r#"
94SELECT id
95FROM keyset
96WHERE active = 1;
97 "#,
98 )
99 .fetch_one(&mut *transaction)
100 .await;
101
102 let rec = match rec {
103 Ok(rec) => {
104 transaction.commit().await.map_err(Error::from)?;
105 rec
106 }
107 Err(err) => match err {
108 sqlx::Error::RowNotFound => {
109 transaction.commit().await.map_err(Error::from)?;
110 return Ok(None);
111 }
112 _ => {
113 return {
114 if let Err(err) = transaction.rollback().await {
115 tracing::error!("Could not rollback sql transaction: {}", err);
116 }
117 Err(Error::SQLX(err).into())
118 }
119 }
120 },
121 };
122
123 Ok(Some(
124 Id::from_str(rec.try_get("id").map_err(Error::from)?).map_err(Error::from)?,
125 ))
126 }
127
128 async fn add_keyset_info(&self, keyset: MintKeySetInfo) -> Result<(), Self::Err> {
129 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
130 let res = sqlx::query(
131 r#"
132INSERT OR REPLACE INTO keyset
133(id, unit, active, valid_from, valid_to, derivation_path, max_order, derivation_path_index)
134VALUES (?, ?, ?, ?, ?, ?, ?, ?);
135 "#,
136 )
137 .bind(keyset.id.to_string())
138 .bind(keyset.unit.to_string())
139 .bind(keyset.active)
140 .bind(keyset.valid_from as i64)
141 .bind(keyset.valid_to.map(|v| v as i64))
142 .bind(keyset.derivation_path.to_string())
143 .bind(keyset.max_order)
144 .bind(keyset.derivation_path_index)
145 .execute(&mut *transaction)
146 .await;
147
148 match res {
149 Ok(_) => {
150 transaction.commit().await.map_err(Error::from)?;
151 Ok(())
152 }
153 Err(err) => {
154 tracing::error!("SQLite could not add keyset info");
155 if let Err(err) = transaction.rollback().await {
156 tracing::error!("Could not rollback sql transaction: {}", err);
157 }
158
159 Err(Error::from(err).into())
160 }
161 }
162 }
163
164 async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
165 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
166 let rec = sqlx::query(
167 r#"
168SELECT *
169FROM keyset
170WHERE id=?;
171 "#,
172 )
173 .bind(id.to_string())
174 .fetch_one(&mut *transaction)
175 .await;
176
177 match rec {
178 Ok(rec) => {
179 transaction.commit().await.map_err(Error::from)?;
180 Ok(Some(sqlite_row_to_keyset_info(rec)?))
181 }
182 Err(err) => match err {
183 sqlx::Error::RowNotFound => {
184 transaction.commit().await.map_err(Error::from)?;
185 return Ok(None);
186 }
187 _ => {
188 tracing::error!("SQLite could not get keyset info");
189 if let Err(err) = transaction.rollback().await {
190 tracing::error!("Could not rollback sql transaction: {}", err);
191 }
192 return Err(Error::SQLX(err).into());
193 }
194 },
195 }
196 }
197
198 async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
199 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
200 let recs = sqlx::query(
201 r#"
202SELECT *
203FROM keyset;
204 "#,
205 )
206 .fetch_all(&mut *transaction)
207 .await
208 .map_err(Error::from);
209
210 match recs {
211 Ok(recs) => {
212 transaction.commit().await.map_err(Error::from)?;
213 Ok(recs
214 .into_iter()
215 .map(sqlite_row_to_keyset_info)
216 .collect::<Result<_, _>>()?)
217 }
218 Err(err) => {
219 tracing::error!("SQLite could not get keyset info");
220 if let Err(err) = transaction.rollback().await {
221 tracing::error!("Could not rollback sql transaction: {}", err);
222 }
223 Err(err.into())
224 }
225 }
226 }
227
228 async fn add_proof(&self, proof: AuthProof) -> Result<(), Self::Err> {
229 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
230 if let Err(err) = sqlx::query(
231 r#"
232INSERT INTO proof
233(y, keyset_id, secret, c, state)
234VALUES (?, ?, ?, ?, ?);
235 "#,
236 )
237 .bind(proof.y()?.to_bytes().to_vec())
238 .bind(proof.keyset_id.to_string())
239 .bind(proof.secret.to_string())
240 .bind(proof.c.to_bytes().to_vec())
241 .bind("UNSPENT")
242 .execute(&mut *transaction)
243 .await
244 .map_err(Error::from)
245 {
246 tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
247 }
248 transaction.commit().await.map_err(Error::from)?;
249
250 Ok(())
251 }
252
253 async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
254 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
255
256 let sql = format!(
257 "SELECT y, state FROM proof WHERE y IN ({})",
258 "?,".repeat(ys.len()).trim_end_matches(',')
259 );
260
261 let mut current_states = ys
262 .iter()
263 .fold(sqlx::query(&sql), |query, y| {
264 query.bind(y.to_bytes().to_vec())
265 })
266 .fetch_all(&mut *transaction)
267 .await
268 .map_err(|err| {
269 tracing::error!("SQLite could not get state of proof: {err:?}");
270 Error::SQLX(err)
271 })?
272 .into_iter()
273 .map(|row| {
274 PublicKey::from_slice(row.get("y"))
275 .map_err(Error::from)
276 .and_then(|y| {
277 let state: String = row.get("state");
278 State::from_str(&state)
279 .map_err(Error::from)
280 .map(|state| (y, state))
281 })
282 })
283 .collect::<Result<HashMap<_, _>, _>>()?;
284
285 Ok(ys.iter().map(|y| current_states.remove(y)).collect())
286 }
287
288 async fn update_proof_state(
289 &self,
290 y: &PublicKey,
291 proofs_state: State,
292 ) -> Result<Option<State>, Self::Err> {
293 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
294
295 let current_state = sqlx::query("SELECT state FROM proof WHERE y = ?")
297 .bind(y.to_bytes().to_vec())
298 .fetch_optional(&mut *transaction)
299 .await
300 .map_err(|err| {
301 tracing::error!("SQLite could not get state of proof: {err:?}");
302 Error::SQLX(err)
303 })?
304 .map(|row| {
305 let state: String = row.get("state");
306 State::from_str(&state).map_err(Error::from)
307 })
308 .transpose()?;
309
310 sqlx::query("UPDATE proof SET state = ? WHERE state != ? AND y = ?")
312 .bind(proofs_state.to_string())
313 .bind(State::Spent.to_string())
314 .bind(y.to_bytes().to_vec())
315 .execute(&mut *transaction)
316 .await
317 .map_err(|err| {
318 tracing::error!("SQLite could not update proof state: {err:?}");
319 Error::SQLX(err)
320 })?;
321
322 transaction.commit().await.map_err(Error::from)?;
323 Ok(current_state)
324 }
325
326 async fn add_blind_signatures(
327 &self,
328 blinded_messages: &[PublicKey],
329 blind_signatures: &[BlindSignature],
330 ) -> Result<(), Self::Err> {
331 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
332 for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
333 let res = sqlx::query(
334 r#"
335INSERT INTO blind_signature
336(y, amount, keyset_id, c)
337VALUES (?, ?, ?, ?);
338 "#,
339 )
340 .bind(message.to_bytes().to_vec())
341 .bind(u64::from(signature.amount) as i64)
342 .bind(signature.keyset_id.to_string())
343 .bind(signature.c.to_bytes().to_vec())
344 .execute(&mut *transaction)
345 .await;
346
347 if let Err(err) = res {
348 tracing::error!("SQLite could not add blind signature");
349 if let Err(err) = transaction.rollback().await {
350 tracing::error!("Could not rollback sql transaction: {}", err);
351 }
352 return Err(Error::SQLX(err).into());
353 }
354 }
355
356 transaction.commit().await.map_err(Error::from)?;
357
358 Ok(())
359 }
360
361 async fn get_blind_signatures(
362 &self,
363 blinded_messages: &[PublicKey],
364 ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
365 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
366
367 let sql = format!(
368 "SELECT * FROM blind_signature WHERE y IN ({})",
369 "?,".repeat(blinded_messages.len()).trim_end_matches(',')
370 );
371
372 let mut blinded_signatures = blinded_messages
373 .iter()
374 .fold(sqlx::query(&sql), |query, y| {
375 query.bind(y.to_bytes().to_vec())
376 })
377 .fetch_all(&mut *transaction)
378 .await
379 .map_err(|err| {
380 tracing::error!("SQLite could not get state of proof: {err:?}");
381 Error::SQLX(err)
382 })?
383 .into_iter()
384 .map(|row| {
385 PublicKey::from_slice(row.get("y"))
386 .map_err(Error::from)
387 .and_then(|y| sqlite_row_to_blind_signature(row).map(|blinded| (y, blinded)))
388 })
389 .collect::<Result<HashMap<_, _>, _>>()?;
390
391 Ok(blinded_messages
392 .iter()
393 .map(|y| blinded_signatures.remove(y))
394 .collect())
395 }
396
397 async fn add_protected_endpoints(
398 &self,
399 protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
400 ) -> Result<(), Self::Err> {
401 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
402
403 for (endpoint, auth) in protected_endpoints.iter() {
404 if let Err(err) = sqlx::query(
405 r#"
406INSERT OR REPLACE INTO protected_endpoints
407(endpoint, auth)
408VALUES (?, ?);
409 "#,
410 )
411 .bind(serde_json::to_string(endpoint)?)
412 .bind(serde_json::to_string(auth)?)
413 .execute(&mut *transaction)
414 .await
415 .map_err(Error::from)
416 {
417 tracing::debug!(
418 "Attempting to add protected endpoint. Skipping.... {:?}",
419 err
420 );
421 }
422 }
423
424 transaction.commit().await.map_err(Error::from)?;
425
426 Ok(())
427 }
428 async fn remove_protected_endpoints(
429 &self,
430 protected_endpoints: Vec<ProtectedEndpoint>,
431 ) -> Result<(), Self::Err> {
432 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
433
434 let sql = format!(
435 "DELETE FROM protected_endpoints WHERE endpoint IN ({})",
436 std::iter::repeat("?")
437 .take(protected_endpoints.len())
438 .collect::<Vec<_>>()
439 .join(",")
440 );
441
442 let endpoints = protected_endpoints
443 .iter()
444 .map(serde_json::to_string)
445 .collect::<Result<Vec<_>, _>>()?;
446
447 endpoints
448 .iter()
449 .fold(sqlx::query(&sql), |query, endpoint| query.bind(endpoint))
450 .execute(&mut *transaction)
451 .await
452 .map_err(Error::from)?;
453
454 transaction.commit().await.map_err(Error::from)?;
455 Ok(())
456 }
457 async fn get_auth_for_endpoint(
458 &self,
459 protected_endpoint: ProtectedEndpoint,
460 ) -> Result<Option<AuthRequired>, Self::Err> {
461 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
462
463 let rec = sqlx::query(
464 r#"
465SELECT *
466FROM protected_endpoints
467WHERE endpoint=?;
468 "#,
469 )
470 .bind(serde_json::to_string(&protected_endpoint)?)
471 .fetch_one(&mut *transaction)
472 .await;
473
474 match rec {
475 Ok(rec) => {
476 transaction.commit().await.map_err(Error::from)?;
477
478 let auth: String = rec.try_get("auth").map_err(Error::from)?;
479
480 Ok(Some(serde_json::from_str(&auth)?))
481 }
482 Err(err) => match err {
483 sqlx::Error::RowNotFound => {
484 transaction.commit().await.map_err(Error::from)?;
485 return Ok(None);
486 }
487 _ => {
488 return {
489 if let Err(err) = transaction.rollback().await {
490 tracing::error!("Could not rollback sql transaction: {}", err);
491 }
492 Err(Error::SQLX(err).into())
493 }
494 }
495 },
496 }
497 }
498 async fn get_auth_for_endpoints(
499 &self,
500 ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
501 let mut transaction = self.pool.begin().await.map_err(Error::from)?;
502
503 let recs = sqlx::query(
504 r#"
505SELECT *
506FROM protected_endpoints
507 "#,
508 )
509 .fetch_all(&mut *transaction)
510 .await;
511
512 match recs {
513 Ok(recs) => {
514 transaction.commit().await.map_err(Error::from)?;
515
516 let mut endpoints = HashMap::new();
517
518 for rec in recs {
519 let auth: String = rec.try_get("auth").map_err(Error::from)?;
520 let endpoint: String = rec.try_get("endpoint").map_err(Error::from)?;
521
522 let endpoint: ProtectedEndpoint = serde_json::from_str(&endpoint)?;
523 let auth: AuthRequired = serde_json::from_str(&auth)?;
524
525 endpoints.insert(endpoint, Some(auth));
526 }
527
528 Ok(endpoints)
529 }
530 Err(err) => {
531 tracing::error!("SQLite could not get protected endpoints");
532 if let Err(err) = transaction.rollback().await {
533 tracing::error!("Could not rollback sql transaction: {}", err);
534 }
535 Err(Error::from(err).into())
536 }
537 }
538 }
539}