cdk_sqlite/mint/auth/
mod.rs

1//! SQLite Mint Auth
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::str::FromStr;
6use std::time::Duration;
7
8use async_trait::async_trait;
9use cdk_common::database::{self, MintAuthDatabase};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
12use cdk_common::{AuthRequired, ProtectedEndpoint};
13use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
14use sqlx::Row;
15use tracing::instrument;
16
17use super::{sqlite_row_to_blind_signature, sqlite_row_to_keyset_info};
18use crate::mint::Error;
19
20/// Mint SQLite Database
21#[derive(Debug, Clone)]
22pub struct MintSqliteAuthDatabase {
23    pool: SqlitePool,
24}
25
26impl MintSqliteAuthDatabase {
27    /// Create new [`MintSqliteAuthDatabase`]
28    pub async fn new(path: &Path) -> Result<Self, Error> {
29        let path = path.to_str().ok_or(Error::InvalidDbPath)?;
30        let db_options = SqliteConnectOptions::from_str(path)?
31            .busy_timeout(Duration::from_secs(5))
32            .read_only(false)
33            .create_if_missing(true)
34            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full);
35
36        let pool = SqlitePoolOptions::new()
37            .max_connections(1)
38            .connect_with(db_options)
39            .await?;
40
41        Ok(Self { pool })
42    }
43
44    /// Migrate [`MintSqliteAuthDatabase`]
45    pub async fn migrate(&self) {
46        sqlx::migrate!("./src/mint/auth/migrations")
47            .run(&self.pool)
48            .await
49            .expect("Could not run migrations");
50    }
51}
52
53#[async_trait]
54impl MintAuthDatabase for MintSqliteAuthDatabase {
55    type Err = database::Error;
56
57    #[instrument(skip(self))]
58    async fn set_active_keyset(&self, id: Id) -> Result<(), Self::Err> {
59        tracing::info!("Setting auth keyset {id} active");
60        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
61        let update_res = sqlx::query(
62            r#"
63    UPDATE keyset 
64    SET active = CASE 
65        WHEN id = ? THEN TRUE
66        ELSE FALSE
67    END;
68    "#,
69        )
70        .bind(id.to_string())
71        .execute(&mut *transaction)
72        .await;
73
74        match update_res {
75            Ok(_) => {
76                transaction.commit().await.map_err(Error::from)?;
77                Ok(())
78            }
79            Err(err) => {
80                tracing::error!("SQLite Could not update keyset");
81                if let Err(err) = transaction.rollback().await {
82                    tracing::error!("Could not rollback sql transaction: {}", err);
83                }
84                Err(Error::from(err).into())
85            }
86        }
87    }
88
89    async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
90        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
91
92        let rec = sqlx::query(
93            r#"
94SELECT id
95FROM keyset
96WHERE active = 1;
97        "#,
98        )
99        .fetch_one(&mut *transaction)
100        .await;
101
102        let rec = match rec {
103            Ok(rec) => {
104                transaction.commit().await.map_err(Error::from)?;
105                rec
106            }
107            Err(err) => match err {
108                sqlx::Error::RowNotFound => {
109                    transaction.commit().await.map_err(Error::from)?;
110                    return Ok(None);
111                }
112                _ => {
113                    return {
114                        if let Err(err) = transaction.rollback().await {
115                            tracing::error!("Could not rollback sql transaction: {}", err);
116                        }
117                        Err(Error::SQLX(err).into())
118                    }
119                }
120            },
121        };
122
123        Ok(Some(
124            Id::from_str(rec.try_get("id").map_err(Error::from)?).map_err(Error::from)?,
125        ))
126    }
127
128    async fn add_keyset_info(&self, keyset: MintKeySetInfo) -> Result<(), Self::Err> {
129        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
130        let res = sqlx::query(
131            r#"
132INSERT OR REPLACE INTO keyset
133(id, unit, active, valid_from, valid_to, derivation_path, max_order, derivation_path_index)
134VALUES (?, ?, ?, ?, ?, ?, ?, ?);
135        "#,
136        )
137        .bind(keyset.id.to_string())
138        .bind(keyset.unit.to_string())
139        .bind(keyset.active)
140        .bind(keyset.valid_from as i64)
141        .bind(keyset.valid_to.map(|v| v as i64))
142        .bind(keyset.derivation_path.to_string())
143        .bind(keyset.max_order)
144        .bind(keyset.derivation_path_index)
145        .execute(&mut *transaction)
146        .await;
147
148        match res {
149            Ok(_) => {
150                transaction.commit().await.map_err(Error::from)?;
151                Ok(())
152            }
153            Err(err) => {
154                tracing::error!("SQLite could not add keyset info");
155                if let Err(err) = transaction.rollback().await {
156                    tracing::error!("Could not rollback sql transaction: {}", err);
157                }
158
159                Err(Error::from(err).into())
160            }
161        }
162    }
163
164    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
165        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
166        let rec = sqlx::query(
167            r#"
168SELECT *
169FROM keyset
170WHERE id=?;
171        "#,
172        )
173        .bind(id.to_string())
174        .fetch_one(&mut *transaction)
175        .await;
176
177        match rec {
178            Ok(rec) => {
179                transaction.commit().await.map_err(Error::from)?;
180                Ok(Some(sqlite_row_to_keyset_info(rec)?))
181            }
182            Err(err) => match err {
183                sqlx::Error::RowNotFound => {
184                    transaction.commit().await.map_err(Error::from)?;
185                    return Ok(None);
186                }
187                _ => {
188                    tracing::error!("SQLite could not get keyset info");
189                    if let Err(err) = transaction.rollback().await {
190                        tracing::error!("Could not rollback sql transaction: {}", err);
191                    }
192                    return Err(Error::SQLX(err).into());
193                }
194            },
195        }
196    }
197
198    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
199        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
200        let recs = sqlx::query(
201            r#"
202SELECT *
203FROM keyset;
204        "#,
205        )
206        .fetch_all(&mut *transaction)
207        .await
208        .map_err(Error::from);
209
210        match recs {
211            Ok(recs) => {
212                transaction.commit().await.map_err(Error::from)?;
213                Ok(recs
214                    .into_iter()
215                    .map(sqlite_row_to_keyset_info)
216                    .collect::<Result<_, _>>()?)
217            }
218            Err(err) => {
219                tracing::error!("SQLite could not get keyset info");
220                if let Err(err) = transaction.rollback().await {
221                    tracing::error!("Could not rollback sql transaction: {}", err);
222                }
223                Err(err.into())
224            }
225        }
226    }
227
228    async fn add_proof(&self, proof: AuthProof) -> Result<(), Self::Err> {
229        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
230        if let Err(err) = sqlx::query(
231            r#"
232INSERT INTO proof
233(y, keyset_id, secret, c, state)
234VALUES (?, ?, ?, ?, ?);
235        "#,
236        )
237        .bind(proof.y()?.to_bytes().to_vec())
238        .bind(proof.keyset_id.to_string())
239        .bind(proof.secret.to_string())
240        .bind(proof.c.to_bytes().to_vec())
241        .bind("UNSPENT")
242        .execute(&mut *transaction)
243        .await
244        .map_err(Error::from)
245        {
246            tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
247        }
248        transaction.commit().await.map_err(Error::from)?;
249
250        Ok(())
251    }
252
253    async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
254        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
255
256        let sql = format!(
257            "SELECT y, state FROM proof WHERE y IN ({})",
258            "?,".repeat(ys.len()).trim_end_matches(',')
259        );
260
261        let mut current_states = ys
262            .iter()
263            .fold(sqlx::query(&sql), |query, y| {
264                query.bind(y.to_bytes().to_vec())
265            })
266            .fetch_all(&mut *transaction)
267            .await
268            .map_err(|err| {
269                tracing::error!("SQLite could not get state of proof: {err:?}");
270                Error::SQLX(err)
271            })?
272            .into_iter()
273            .map(|row| {
274                PublicKey::from_slice(row.get("y"))
275                    .map_err(Error::from)
276                    .and_then(|y| {
277                        let state: String = row.get("state");
278                        State::from_str(&state)
279                            .map_err(Error::from)
280                            .map(|state| (y, state))
281                    })
282            })
283            .collect::<Result<HashMap<_, _>, _>>()?;
284
285        Ok(ys.iter().map(|y| current_states.remove(y)).collect())
286    }
287
288    async fn update_proof_state(
289        &self,
290        y: &PublicKey,
291        proofs_state: State,
292    ) -> Result<Option<State>, Self::Err> {
293        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
294
295        // Get current state for single y
296        let current_state = sqlx::query("SELECT state FROM proof WHERE y = ?")
297            .bind(y.to_bytes().to_vec())
298            .fetch_optional(&mut *transaction)
299            .await
300            .map_err(|err| {
301                tracing::error!("SQLite could not get state of proof: {err:?}");
302                Error::SQLX(err)
303            })?
304            .map(|row| {
305                let state: String = row.get("state");
306                State::from_str(&state).map_err(Error::from)
307            })
308            .transpose()?;
309
310        // Update state for single y
311        sqlx::query("UPDATE proof SET state = ? WHERE state != ? AND y = ?")
312            .bind(proofs_state.to_string())
313            .bind(State::Spent.to_string())
314            .bind(y.to_bytes().to_vec())
315            .execute(&mut *transaction)
316            .await
317            .map_err(|err| {
318                tracing::error!("SQLite could not update proof state: {err:?}");
319                Error::SQLX(err)
320            })?;
321
322        transaction.commit().await.map_err(Error::from)?;
323        Ok(current_state)
324    }
325
326    async fn add_blind_signatures(
327        &self,
328        blinded_messages: &[PublicKey],
329        blind_signatures: &[BlindSignature],
330    ) -> Result<(), Self::Err> {
331        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
332        for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
333            let res = sqlx::query(
334                r#"
335INSERT INTO blind_signature
336(y, amount, keyset_id, c)
337VALUES (?, ?, ?, ?);
338        "#,
339            )
340            .bind(message.to_bytes().to_vec())
341            .bind(u64::from(signature.amount) as i64)
342            .bind(signature.keyset_id.to_string())
343            .bind(signature.c.to_bytes().to_vec())
344            .execute(&mut *transaction)
345            .await;
346
347            if let Err(err) = res {
348                tracing::error!("SQLite could not add blind signature");
349                if let Err(err) = transaction.rollback().await {
350                    tracing::error!("Could not rollback sql transaction: {}", err);
351                }
352                return Err(Error::SQLX(err).into());
353            }
354        }
355
356        transaction.commit().await.map_err(Error::from)?;
357
358        Ok(())
359    }
360
361    async fn get_blind_signatures(
362        &self,
363        blinded_messages: &[PublicKey],
364    ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
365        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
366
367        let sql = format!(
368            "SELECT * FROM blind_signature WHERE y IN ({})",
369            "?,".repeat(blinded_messages.len()).trim_end_matches(',')
370        );
371
372        let mut blinded_signatures = blinded_messages
373            .iter()
374            .fold(sqlx::query(&sql), |query, y| {
375                query.bind(y.to_bytes().to_vec())
376            })
377            .fetch_all(&mut *transaction)
378            .await
379            .map_err(|err| {
380                tracing::error!("SQLite could not get state of proof: {err:?}");
381                Error::SQLX(err)
382            })?
383            .into_iter()
384            .map(|row| {
385                PublicKey::from_slice(row.get("y"))
386                    .map_err(Error::from)
387                    .and_then(|y| sqlite_row_to_blind_signature(row).map(|blinded| (y, blinded)))
388            })
389            .collect::<Result<HashMap<_, _>, _>>()?;
390
391        Ok(blinded_messages
392            .iter()
393            .map(|y| blinded_signatures.remove(y))
394            .collect())
395    }
396
397    async fn add_protected_endpoints(
398        &self,
399        protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
400    ) -> Result<(), Self::Err> {
401        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
402
403        for (endpoint, auth) in protected_endpoints.iter() {
404            if let Err(err) = sqlx::query(
405                r#"
406INSERT OR REPLACE INTO protected_endpoints
407(endpoint, auth)
408VALUES (?, ?);
409        "#,
410            )
411            .bind(serde_json::to_string(endpoint)?)
412            .bind(serde_json::to_string(auth)?)
413            .execute(&mut *transaction)
414            .await
415            .map_err(Error::from)
416            {
417                tracing::debug!(
418                    "Attempting to add protected endpoint. Skipping.... {:?}",
419                    err
420                );
421            }
422        }
423
424        transaction.commit().await.map_err(Error::from)?;
425
426        Ok(())
427    }
428    async fn remove_protected_endpoints(
429        &self,
430        protected_endpoints: Vec<ProtectedEndpoint>,
431    ) -> Result<(), Self::Err> {
432        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
433
434        let sql = format!(
435            "DELETE FROM protected_endpoints WHERE endpoint IN ({})",
436            std::iter::repeat("?")
437                .take(protected_endpoints.len())
438                .collect::<Vec<_>>()
439                .join(",")
440        );
441
442        let endpoints = protected_endpoints
443            .iter()
444            .map(serde_json::to_string)
445            .collect::<Result<Vec<_>, _>>()?;
446
447        endpoints
448            .iter()
449            .fold(sqlx::query(&sql), |query, endpoint| query.bind(endpoint))
450            .execute(&mut *transaction)
451            .await
452            .map_err(Error::from)?;
453
454        transaction.commit().await.map_err(Error::from)?;
455        Ok(())
456    }
457    async fn get_auth_for_endpoint(
458        &self,
459        protected_endpoint: ProtectedEndpoint,
460    ) -> Result<Option<AuthRequired>, Self::Err> {
461        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
462
463        let rec = sqlx::query(
464            r#"
465SELECT *
466FROM protected_endpoints
467WHERE endpoint=?;
468        "#,
469        )
470        .bind(serde_json::to_string(&protected_endpoint)?)
471        .fetch_one(&mut *transaction)
472        .await;
473
474        match rec {
475            Ok(rec) => {
476                transaction.commit().await.map_err(Error::from)?;
477
478                let auth: String = rec.try_get("auth").map_err(Error::from)?;
479
480                Ok(Some(serde_json::from_str(&auth)?))
481            }
482            Err(err) => match err {
483                sqlx::Error::RowNotFound => {
484                    transaction.commit().await.map_err(Error::from)?;
485                    return Ok(None);
486                }
487                _ => {
488                    return {
489                        if let Err(err) = transaction.rollback().await {
490                            tracing::error!("Could not rollback sql transaction: {}", err);
491                        }
492                        Err(Error::SQLX(err).into())
493                    }
494                }
495            },
496        }
497    }
498    async fn get_auth_for_endpoints(
499        &self,
500    ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
501        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
502
503        let recs = sqlx::query(
504            r#"
505SELECT *
506FROM protected_endpoints
507        "#,
508        )
509        .fetch_all(&mut *transaction)
510        .await;
511
512        match recs {
513            Ok(recs) => {
514                transaction.commit().await.map_err(Error::from)?;
515
516                let mut endpoints = HashMap::new();
517
518                for rec in recs {
519                    let auth: String = rec.try_get("auth").map_err(Error::from)?;
520                    let endpoint: String = rec.try_get("endpoint").map_err(Error::from)?;
521
522                    let endpoint: ProtectedEndpoint = serde_json::from_str(&endpoint)?;
523                    let auth: AuthRequired = serde_json::from_str(&auth)?;
524
525                    endpoints.insert(endpoint, Some(auth));
526                }
527
528                Ok(endpoints)
529            }
530            Err(err) => {
531                tracing::error!("SQLite could not get protected endpoints");
532                if let Err(err) = transaction.rollback().await {
533                    tracing::error!("Could not rollback sql transaction: {}", err);
534                }
535                Err(Error::from(err).into())
536            }
537        }
538    }
539}