Skip to main content

cdk_sql_common/mint/auth/
mod.rs

1//! SQL Mint Auth
2
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
12use cdk_common::{AuthRequired, ProtectedEndpoint};
13use migrations::MIGRATIONS;
14use tracing::instrument;
15
16use super::SQLTransaction;
17use crate::column_as_string;
18use crate::common::migrate;
19use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
20use crate::mint::keys::sql_row_to_keyset_info;
21use crate::mint::signatures::sql_row_to_blind_signature;
22use crate::mint::Error;
23use crate::pool::{DatabasePool, Pool, PooledResource};
24use crate::stmt::query;
25
26/// Mint SQL Database
27#[derive(Debug, Clone)]
28pub struct SQLMintAuthDatabase<RM>
29where
30    RM: DatabasePool + 'static,
31{
32    pool: Arc<Pool<RM>>,
33}
34
35impl<RM> SQLMintAuthDatabase<RM>
36where
37    RM: DatabasePool + 'static,
38{
39    /// Creates a new instance
40    pub async fn new<X>(db: X) -> Result<Self, Error>
41    where
42        X: Into<RM::Config>,
43    {
44        let pool = Pool::new(db.into());
45        Self::migrate(pool.get().await.map_err(|e| Error::Database(Box::new(e)))?).await?;
46        Ok(Self { pool })
47    }
48
49    /// Migrate
50    async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
51        let tx = ConnectionWithTransaction::new(conn).await?;
52        migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
53        tx.commit().await?;
54        Ok(())
55    }
56}
57
// Migration definitions for the mint auth database (exports `MIGRATIONS`).
// The included file is produced at build time into `OUT_DIR` (presumably by
// this crate's build script — confirm), so it is excluded from rustfmt.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint_auth.rs"));
}
62
63#[async_trait]
64impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
65where
66    RM: DatabasePool + 'static,
67{
68    #[instrument(skip(self))]
69    async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
70        tracing::info!("Setting auth keyset {id} active");
71        query(
72            r#"
73            UPDATE keyset
74            SET active = CASE
75                WHEN id = :id THEN TRUE
76                ELSE FALSE
77            END;
78            "#,
79        )?
80        .bind("id", id.to_string())
81        .execute(&self.inner)
82        .await?;
83
84        Ok(())
85    }
86
87    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
88        query(
89            r#"
90        INSERT INTO
91            keyset (
92                id, unit, active, valid_from, valid_to, derivation_path,
93                amounts, input_fee_ppk, derivation_path_index
94            )
95        VALUES (
96            :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
97            :amounts, :input_fee_ppk, :derivation_path_index
98        )
99        ON CONFLICT(id) DO UPDATE SET
100            unit = excluded.unit,
101            active = excluded.active,
102            valid_from = excluded.valid_from,
103            valid_to = excluded.valid_to,
104            derivation_path = excluded.derivation_path,
105            amounts = excluded.amounts,
106            input_fee_ppk = excluded.input_fee_ppk,
107            derivation_path_index = excluded.derivation_path_index
108        "#,
109        )?
110        .bind("id", keyset.id.to_string())
111        .bind("unit", keyset.unit.to_string())
112        .bind("active", keyset.active)
113        .bind("valid_from", keyset.valid_from as i64)
114        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
115        .bind("derivation_path", keyset.derivation_path.to_string())
116        .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
117        .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
118        .bind("derivation_path_index", keyset.derivation_path_index)
119        .execute(&self.inner)
120        .await?;
121
122        Ok(())
123    }
124
125    async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
126        let y = proof.y()?;
127        if let Err(err) = query(
128            r#"
129                INSERT INTO proof
130                (y, keyset_id, secret, c, state)
131                VALUES
132                (:y, :keyset_id, :secret, :c, :state)
133                "#,
134        )?
135        .bind("y", y.to_bytes().to_vec())
136        .bind("keyset_id", proof.keyset_id.to_string())
137        .bind("secret", proof.secret.to_string())
138        .bind("c", proof.c.to_bytes().to_vec())
139        .bind("state", "UNSPENT".to_string())
140        .execute(&self.inner)
141        .await
142        {
143            tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
144        }
145        Ok(())
146    }
147
148    async fn update_proof_state(
149        &mut self,
150        y: &PublicKey,
151        proofs_state: State,
152    ) -> Result<Option<State>, Self::Err> {
153        let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
154            .bind("y", y.to_bytes().to_vec())
155            .pluck(&self.inner)
156            .await?
157            .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
158            .transpose()?;
159
160        query(r#"UPDATE proof SET state = :new_state WHERE  y = :y"#)?
161            .bind("y", y.to_bytes().to_vec())
162            .bind("new_state", proofs_state.to_string())
163            .execute(&self.inner)
164            .await?;
165
166        Ok(current_state)
167    }
168
169    async fn add_blind_signatures(
170        &mut self,
171        blinded_messages: &[PublicKey],
172        blind_signatures: &[BlindSignature],
173    ) -> Result<(), database::Error> {
174        for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
175            query(
176                r#"
177                       INSERT
178                       INTO blind_signature
179                       (blinded_message, amount, keyset_id, c)
180                       VALUES
181                       (:blinded_message, :amount, :keyset_id, :c)
182                   "#,
183            )?
184            .bind("blinded_message", message.to_bytes().to_vec())
185            .bind("amount", u64::from(signature.amount) as i64)
186            .bind("keyset_id", signature.keyset_id.to_string())
187            .bind("c", signature.c.to_bytes().to_vec())
188            .execute(&self.inner)
189            .await?;
190        }
191
192        Ok(())
193    }
194
195    async fn add_protected_endpoints(
196        &mut self,
197        protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
198    ) -> Result<(), database::Error> {
199        for (endpoint, auth) in protected_endpoints.iter() {
200            if let Err(err) = query(
201                r#"
202                 INSERT INTO protected_endpoints
203                 (endpoint, auth)
204                 VALUES (:endpoint, :auth)
205                 ON CONFLICT (endpoint) DO UPDATE SET
206                 auth = EXCLUDED.auth;
207                 "#,
208            )?
209            .bind("endpoint", serde_json::to_string(endpoint)?)
210            .bind("auth", serde_json::to_string(auth)?)
211            .execute(&self.inner)
212            .await
213            {
214                tracing::debug!(
215                    "Attempting to add protected endpoint. Skipping.... {:?}",
216                    err
217                );
218            }
219        }
220
221        Ok(())
222    }
223    async fn remove_protected_endpoints(
224        &mut self,
225        protected_endpoints: Vec<ProtectedEndpoint>,
226    ) -> Result<(), database::Error> {
227        query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
228            .bind_vec(
229                "endpoints",
230                protected_endpoints
231                    .iter()
232                    .map(serde_json::to_string)
233                    .collect::<Result<_, _>>()?,
234            )?
235            .execute(&self.inner)
236            .await?;
237        Ok(())
238    }
239}
240
241#[async_trait]
242impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
243where
244    RM: DatabasePool + 'static,
245{
246    type Err = database::Error;
247
248    async fn begin_transaction<'a>(
249        &'a self,
250    ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
251    {
252        Ok(Box::new(SQLTransaction {
253            inner: ConnectionWithTransaction::new(
254                self.pool
255                    .get()
256                    .await
257                    .map_err(|e| Error::Database(Box::new(e)))?,
258            )
259            .await?,
260        }))
261    }
262
263    async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
264        let conn = self
265            .pool
266            .get()
267            .await
268            .map_err(|e| Error::Database(Box::new(e)))?;
269        Ok(query(
270            r#"
271            SELECT
272                id
273            FROM
274                keyset
275            WHERE
276                active = :active;
277            "#,
278        )?
279        .bind("active", true)
280        .pluck(&*conn)
281        .await?
282        .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
283        .transpose()?)
284    }
285
286    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
287        let conn = self
288            .pool
289            .get()
290            .await
291            .map_err(|e| Error::Database(Box::new(e)))?;
292        Ok(query(
293            r#"SELECT
294                id,
295                unit,
296                active,
297                valid_from,
298                valid_to,
299                derivation_path,
300                derivation_path_index,
301                amounts,
302                input_fee_ppk
303            FROM
304                keyset
305                WHERE id=:id"#,
306        )?
307        .bind("id", id.to_string())
308        .fetch_one(&*conn)
309        .await?
310        .map(sql_row_to_keyset_info)
311        .transpose()?)
312    }
313
314    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
315        let conn = self
316            .pool
317            .get()
318            .await
319            .map_err(|e| Error::Database(Box::new(e)))?;
320        Ok(query(
321            r#"SELECT
322                id,
323                unit,
324                active,
325                valid_from,
326                valid_to,
327                derivation_path,
328                derivation_path_index,
329                amounts,
330                input_fee_ppk
331            FROM
332                keyset
333                WHERE id=:id"#,
334        )?
335        .fetch_all(&*conn)
336        .await?
337        .into_iter()
338        .map(sql_row_to_keyset_info)
339        .collect::<Result<Vec<_>, _>>()?)
340    }
341
342    async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
343        let conn = self
344            .pool
345            .get()
346            .await
347            .map_err(|e| Error::Database(Box::new(e)))?;
348        let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
349            .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())?
350            .fetch_all(&*conn)
351            .await?
352            .into_iter()
353            .map(|row| {
354                Ok((
355                    column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
356                    column_as_string!(&row[1], State::from_str),
357                ))
358            })
359            .collect::<Result<HashMap<_, _>, Error>>()?;
360
361        Ok(ys.iter().map(|y| current_states.remove(y)).collect())
362    }
363
364    async fn get_blind_signatures(
365        &self,
366        blinded_messages: &[PublicKey],
367    ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
368        let conn = self
369            .pool
370            .get()
371            .await
372            .map_err(|e| Error::Database(Box::new(e)))?;
373        let mut blinded_signatures = query(
374            r#"SELECT
375                keyset_id,
376                amount,
377                c,
378                dleq_e,
379                dleq_s,
380                blinded_message,
381            FROM
382                blind_signature
383            WHERE blinded_message IN (:blinded_message)
384            "#,
385        )?
386        .bind_vec(
387            "blinded_message",
388            blinded_messages
389                .iter()
390                .map(|bm| bm.to_bytes().to_vec())
391                .collect(),
392        )?
393        .fetch_all(&*conn)
394        .await?
395        .into_iter()
396        .map(|mut row| {
397            Ok((
398                column_as_string!(
399                    &row.pop().ok_or(Error::InvalidDbResponse)?,
400                    PublicKey::from_hex,
401                    PublicKey::from_slice
402                ),
403                sql_row_to_blind_signature(row)?,
404            ))
405        })
406        .collect::<Result<HashMap<_, _>, Error>>()?;
407        Ok(blinded_messages
408            .iter()
409            .map(|bm| blinded_signatures.remove(bm))
410            .collect())
411    }
412
413    async fn get_auth_for_endpoint(
414        &self,
415        protected_endpoint: ProtectedEndpoint,
416    ) -> Result<Option<AuthRequired>, Self::Err> {
417        let conn = self
418            .pool
419            .get()
420            .await
421            .map_err(|e| Error::Database(Box::new(e)))?;
422        Ok(
423            query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
424                .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
425                .pluck(&*conn)
426                .await?
427                .map(|auth| {
428                    Ok::<_, Error>(column_as_string!(
429                        auth,
430                        serde_json::from_str,
431                        serde_json::from_slice
432                    ))
433                })
434                .transpose()?,
435        )
436    }
437
438    async fn get_auth_for_endpoints(
439        &self,
440    ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
441        let conn = self
442            .pool
443            .get()
444            .await
445            .map_err(|e| Error::Database(Box::new(e)))?;
446        Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
447            .fetch_all(&*conn)
448            .await?
449            .into_iter()
450            .map(|row| {
451                let endpoint =
452                    column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
453                let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
454                Ok((endpoint, Some(auth)))
455            })
456            .collect::<Result<HashMap<_, _>, Error>>()?)
457    }
458}