cdk_sql_common/mint/auth/mod.rs

//! SQL Mint Auth
2
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
12use cdk_common::{AuthRequired, ProtectedEndpoint};
13use migrations::MIGRATIONS;
14use tracing::instrument;
15
16use super::{sql_row_to_blind_signature, sql_row_to_keyset_info, SQLTransaction};
17use crate::column_as_string;
18use crate::common::migrate;
19use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
20use crate::mint::Error;
21use crate::pool::{DatabasePool, Pool, PooledResource};
22use crate::stmt::query;
23
24/// Mint SQL Database
25#[derive(Debug, Clone)]
26pub struct SQLMintAuthDatabase<RM>
27where
28    RM: DatabasePool + 'static,
29{
30    pool: Arc<Pool<RM>>,
31}
32
33impl<RM> SQLMintAuthDatabase<RM>
34where
35    RM: DatabasePool + 'static,
36{
37    /// Creates a new instance
38    pub async fn new<X>(db: X) -> Result<Self, Error>
39    where
40        X: Into<RM::Config>,
41    {
42        let pool = Pool::new(db.into());
43        Self::migrate(pool.get().map_err(|e| Error::Database(Box::new(e)))?).await?;
44        Ok(Self { pool })
45    }
46
47    /// Migrate
48    async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
49        let tx = ConnectionWithTransaction::new(conn).await?;
50        migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
51        tx.commit().await?;
52        Ok(())
53    }
54}
55
// Migration list generated at build time: the included file lives in Cargo's
// `OUT_DIR` and provides the `MIGRATIONS` item imported at the top of this file
// (`use migrations::MIGRATIONS`). `rustfmt::skip` keeps the formatter away from
// the generated content.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint_auth.rs"));
}
60
61#[async_trait]
62impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
63where
64    RM: DatabasePool + 'static,
65{
66    #[instrument(skip(self))]
67    async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
68        tracing::info!("Setting auth keyset {id} active");
69        query(
70            r#"
71            UPDATE keyset
72            SET active = CASE
73                WHEN id = :id THEN TRUE
74                ELSE FALSE
75            END;
76            "#,
77        )?
78        .bind("id", id.to_string())
79        .execute(&self.inner)
80        .await?;
81
82        Ok(())
83    }
84
85    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
86        query(
87            r#"
88        INSERT INTO
89            keyset (
90                id, unit, active, valid_from, valid_to, derivation_path,
91                max_order, derivation_path_index
92            )
93        VALUES (
94            :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
95            :max_order, :derivation_path_index
96        )
97        ON CONFLICT(id) DO UPDATE SET
98            unit = excluded.unit,
99            active = excluded.active,
100            valid_from = excluded.valid_from,
101            valid_to = excluded.valid_to,
102            derivation_path = excluded.derivation_path,
103            max_order = excluded.max_order,
104            derivation_path_index = excluded.derivation_path_index
105        "#,
106        )?
107        .bind("id", keyset.id.to_string())
108        .bind("unit", keyset.unit.to_string())
109        .bind("active", keyset.active)
110        .bind("valid_from", keyset.valid_from as i64)
111        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
112        .bind("derivation_path", keyset.derivation_path.to_string())
113        .bind("max_order", keyset.max_order)
114        .bind("derivation_path_index", keyset.derivation_path_index)
115        .execute(&self.inner)
116        .await?;
117
118        Ok(())
119    }
120
121    async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
122        if let Err(err) = query(
123            r#"
124                INSERT INTO proof
125                (y, keyset_id, secret, c, state)
126                VALUES
127                (:y, :keyset_id, :secret, :c, :state)
128                "#,
129        )?
130        .bind("y", proof.y()?.to_bytes().to_vec())
131        .bind("keyset_id", proof.keyset_id.to_string())
132        .bind("secret", proof.secret.to_string())
133        .bind("c", proof.c.to_bytes().to_vec())
134        .bind("state", "UNSPENT".to_string())
135        .execute(&self.inner)
136        .await
137        {
138            tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
139        }
140        Ok(())
141    }
142
143    async fn update_proof_state(
144        &mut self,
145        y: &PublicKey,
146        proofs_state: State,
147    ) -> Result<Option<State>, Self::Err> {
148        let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
149            .bind("y", y.to_bytes().to_vec())
150            .pluck(&self.inner)
151            .await?
152            .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
153            .transpose()?;
154
155        query(r#"UPDATE proof SET state = :new_state WHERE state = :state AND y = :y"#)?
156            .bind("y", y.to_bytes().to_vec())
157            .bind(
158                "state",
159                current_state.as_ref().map(|state| state.to_string()),
160            )
161            .bind("new_state", proofs_state.to_string())
162            .execute(&self.inner)
163            .await?;
164
165        Ok(current_state)
166    }
167
168    async fn add_blind_signatures(
169        &mut self,
170        blinded_messages: &[PublicKey],
171        blind_signatures: &[BlindSignature],
172    ) -> Result<(), database::Error> {
173        for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
174            query(
175                r#"
176                       INSERT
177                       INTO blind_signature
178                       (blinded_message, amount, keyset_id, c)
179                       VALUES
180                       (:blinded_message, :amount, :keyset_id, :c)
181                   "#,
182            )?
183            .bind("blinded_message", message.to_bytes().to_vec())
184            .bind("amount", u64::from(signature.amount) as i64)
185            .bind("keyset_id", signature.keyset_id.to_string())
186            .bind("c", signature.c.to_bytes().to_vec())
187            .execute(&self.inner)
188            .await?;
189        }
190
191        Ok(())
192    }
193
194    async fn add_protected_endpoints(
195        &mut self,
196        protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
197    ) -> Result<(), database::Error> {
198        for (endpoint, auth) in protected_endpoints.iter() {
199            if let Err(err) = query(
200                r#"
201                 INSERT INTO protected_endpoints
202                 (endpoint, auth)
203                 VALUES (:endpoint, :auth)
204                 ON CONFLICT (endpoint) DO UPDATE SET
205                 auth = EXCLUDED.auth;
206                 "#,
207            )?
208            .bind("endpoint", serde_json::to_string(endpoint)?)
209            .bind("auth", serde_json::to_string(auth)?)
210            .execute(&self.inner)
211            .await
212            {
213                tracing::debug!(
214                    "Attempting to add protected endpoint. Skipping.... {:?}",
215                    err
216                );
217            }
218        }
219
220        Ok(())
221    }
222    async fn remove_protected_endpoints(
223        &mut self,
224        protected_endpoints: Vec<ProtectedEndpoint>,
225    ) -> Result<(), database::Error> {
226        query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
227            .bind_vec(
228                "endpoints",
229                protected_endpoints
230                    .iter()
231                    .map(serde_json::to_string)
232                    .collect::<Result<_, _>>()?,
233            )
234            .execute(&self.inner)
235            .await?;
236        Ok(())
237    }
238}
239
240#[async_trait]
241impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
242where
243    RM: DatabasePool + 'static,
244{
245    type Err = database::Error;
246
247    async fn begin_transaction<'a>(
248        &'a self,
249    ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
250    {
251        Ok(Box::new(SQLTransaction {
252            inner: ConnectionWithTransaction::new(
253                self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
254            )
255            .await?,
256        }))
257    }
258
259    async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
260        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
261        Ok(query(
262            r#"
263            SELECT
264                id
265            FROM
266                keyset
267            WHERE
268                active = :active;
269            "#,
270        )?
271        .bind("active", true)
272        .pluck(&*conn)
273        .await?
274        .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
275        .transpose()?)
276    }
277
278    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
279        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
280        Ok(query(
281            r#"SELECT
282                id,
283                unit,
284                active,
285                valid_from,
286                valid_to,
287                derivation_path,
288                derivation_path_index,
289                max_order,
290                amounts,
291                input_fee_ppk
292            FROM
293                keyset
294                WHERE id=:id"#,
295        )?
296        .bind("id", id.to_string())
297        .fetch_one(&*conn)
298        .await?
299        .map(sql_row_to_keyset_info)
300        .transpose()?)
301    }
302
303    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
304        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
305        Ok(query(
306            r#"SELECT
307                id,
308                unit,
309                active,
310                valid_from,
311                valid_to,
312                derivation_path,
313                derivation_path_index,
314                max_order,
315                amounts,
316                input_fee_ppk
317            FROM
318                keyset
319                WHERE id=:id"#,
320        )?
321        .fetch_all(&*conn)
322        .await?
323        .into_iter()
324        .map(sql_row_to_keyset_info)
325        .collect::<Result<Vec<_>, _>>()?)
326    }
327
328    async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
329        if ys.is_empty() {
330            return Ok(vec![]);
331        }
332        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
333        let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
334            .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
335            .fetch_all(&*conn)
336            .await?
337            .into_iter()
338            .map(|row| {
339                Ok((
340                    column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
341                    column_as_string!(&row[1], State::from_str),
342                ))
343            })
344            .collect::<Result<HashMap<_, _>, Error>>()?;
345
346        Ok(ys.iter().map(|y| current_states.remove(y)).collect())
347    }
348
349    async fn get_blind_signatures(
350        &self,
351        blinded_messages: &[PublicKey],
352    ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
353        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
354        let mut blinded_signatures = query(
355            r#"SELECT
356                keyset_id,
357                amount,
358                c,
359                dleq_e,
360                dleq_s,
361                blinded_message,
362            FROM
363                blind_signature
364            WHERE blinded_message IN (:blinded_message)
365            "#,
366        )?
367        .bind_vec(
368            "blinded_message",
369            blinded_messages
370                .iter()
371                .map(|bm| bm.to_bytes().to_vec())
372                .collect(),
373        )
374        .fetch_all(&*conn)
375        .await?
376        .into_iter()
377        .map(|mut row| {
378            Ok((
379                column_as_string!(
380                    &row.pop().ok_or(Error::InvalidDbResponse)?,
381                    PublicKey::from_hex,
382                    PublicKey::from_slice
383                ),
384                sql_row_to_blind_signature(row)?,
385            ))
386        })
387        .collect::<Result<HashMap<_, _>, Error>>()?;
388        Ok(blinded_messages
389            .iter()
390            .map(|bm| blinded_signatures.remove(bm))
391            .collect())
392    }
393
394    async fn get_auth_for_endpoint(
395        &self,
396        protected_endpoint: ProtectedEndpoint,
397    ) -> Result<Option<AuthRequired>, Self::Err> {
398        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
399        Ok(
400            query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
401                .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
402                .pluck(&*conn)
403                .await?
404                .map(|auth| {
405                    Ok::<_, Error>(column_as_string!(
406                        auth,
407                        serde_json::from_str,
408                        serde_json::from_slice
409                    ))
410                })
411                .transpose()?,
412        )
413    }
414
415    async fn get_auth_for_endpoints(
416        &self,
417    ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
418        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
419        Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
420            .fetch_all(&*conn)
421            .await?
422            .into_iter()
423            .map(|row| {
424                let endpoint =
425                    column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
426                let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
427                Ok((endpoint, Some(auth)))
428            })
429            .collect::<Result<HashMap<_, _>, Error>>()?)
430    }
431}