Skip to main content

cdk_sql_common/mint/auth/
mod.rs

1//! SQL Mint Auth
2
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
12use cdk_common::{AuthRequired, ProtectedEndpoint};
13use migrations::MIGRATIONS;
14use tracing::instrument;
15
16use super::SQLTransaction;
17use crate::column_as_string;
18use crate::common::migrate;
19use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
20use crate::mint::keys::sql_row_to_keyset_info;
21use crate::mint::signatures::sql_row_to_blind_signature;
22use crate::mint::Error;
23use crate::pool::{DatabasePool, Pool, PooledResource};
24use crate::stmt::query;
25
26/// Mint SQL Database
27#[derive(Debug, Clone)]
28pub struct SQLMintAuthDatabase<RM>
29where
30    RM: DatabasePool + 'static,
31{
32    pool: Arc<Pool<RM>>,
33}
34
35impl<RM> SQLMintAuthDatabase<RM>
36where
37    RM: DatabasePool + 'static,
38{
39    /// Creates a new instance
40    pub async fn new<X>(db: X) -> Result<Self, Error>
41    where
42        X: Into<RM::Config>,
43    {
44        let pool = Pool::new(db.into());
45        Self::migrate(pool.get().map_err(|e| Error::Database(Box::new(e)))?).await?;
46        Ok(Self { pool })
47    }
48
49    /// Migrate
50    async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
51        let tx = ConnectionWithTransaction::new(conn).await?;
52        migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
53        tx.commit().await?;
54        Ok(())
55    }
56}
57
// Migration list for the mint auth schema (provides `MIGRATIONS`). The source
// is included from `OUT_DIR`, i.e. it is generated at build time (presumably
// by this crate's build script — confirm). Generated code is excluded from
// rustfmt.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint_auth.rs"));
}
62
63#[async_trait]
64impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
65where
66    RM: DatabasePool + 'static,
67{
68    #[instrument(skip(self))]
69    async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
70        tracing::info!("Setting auth keyset {id} active");
71        query(
72            r#"
73            UPDATE keyset
74            SET active = CASE
75                WHEN id = :id THEN TRUE
76                ELSE FALSE
77            END;
78            "#,
79        )?
80        .bind("id", id.to_string())
81        .execute(&self.inner)
82        .await?;
83
84        Ok(())
85    }
86
87    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
88        query(
89            r#"
90        INSERT INTO
91            keyset (
92                id, unit, active, valid_from, valid_to, derivation_path,
93                amounts, input_fee_ppk, derivation_path_index
94            )
95        VALUES (
96            :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
97            :amounts, :input_fee_ppk, :derivation_path_index
98        )
99        ON CONFLICT(id) DO UPDATE SET
100            unit = excluded.unit,
101            active = excluded.active,
102            valid_from = excluded.valid_from,
103            valid_to = excluded.valid_to,
104            derivation_path = excluded.derivation_path,
105            amounts = excluded.amounts,
106            input_fee_ppk = excluded.input_fee_ppk,
107            derivation_path_index = excluded.derivation_path_index
108        "#,
109        )?
110        .bind("id", keyset.id.to_string())
111        .bind("unit", keyset.unit.to_string())
112        .bind("active", keyset.active)
113        .bind("valid_from", keyset.valid_from as i64)
114        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
115        .bind("derivation_path", keyset.derivation_path.to_string())
116        .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
117        .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
118        .bind("derivation_path_index", keyset.derivation_path_index)
119        .execute(&self.inner)
120        .await?;
121
122        Ok(())
123    }
124
125    async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
126        let y = proof.y()?;
127        if let Err(err) = query(
128            r#"
129                INSERT INTO proof
130                (y, keyset_id, secret, c, state)
131                VALUES
132                (:y, :keyset_id, :secret, :c, :state)
133                "#,
134        )?
135        .bind("y", y.to_bytes().to_vec())
136        .bind("keyset_id", proof.keyset_id.to_string())
137        .bind("secret", proof.secret.to_string())
138        .bind("c", proof.c.to_bytes().to_vec())
139        .bind("state", "UNSPENT".to_string())
140        .execute(&self.inner)
141        .await
142        {
143            tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
144        }
145        Ok(())
146    }
147
148    async fn update_proof_state(
149        &mut self,
150        y: &PublicKey,
151        proofs_state: State,
152    ) -> Result<Option<State>, Self::Err> {
153        let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
154            .bind("y", y.to_bytes().to_vec())
155            .pluck(&self.inner)
156            .await?
157            .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
158            .transpose()?;
159
160        query(r#"UPDATE proof SET state = :new_state WHERE  y = :y"#)?
161            .bind("y", y.to_bytes().to_vec())
162            .bind("new_state", proofs_state.to_string())
163            .execute(&self.inner)
164            .await?;
165
166        Ok(current_state)
167    }
168
169    async fn add_blind_signatures(
170        &mut self,
171        blinded_messages: &[PublicKey],
172        blind_signatures: &[BlindSignature],
173    ) -> Result<(), database::Error> {
174        for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
175            query(
176                r#"
177                       INSERT
178                       INTO blind_signature
179                       (blinded_message, amount, keyset_id, c)
180                       VALUES
181                       (:blinded_message, :amount, :keyset_id, :c)
182                   "#,
183            )?
184            .bind("blinded_message", message.to_bytes().to_vec())
185            .bind("amount", u64::from(signature.amount) as i64)
186            .bind("keyset_id", signature.keyset_id.to_string())
187            .bind("c", signature.c.to_bytes().to_vec())
188            .execute(&self.inner)
189            .await?;
190        }
191
192        Ok(())
193    }
194
195    async fn add_protected_endpoints(
196        &mut self,
197        protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
198    ) -> Result<(), database::Error> {
199        for (endpoint, auth) in protected_endpoints.iter() {
200            if let Err(err) = query(
201                r#"
202                 INSERT INTO protected_endpoints
203                 (endpoint, auth)
204                 VALUES (:endpoint, :auth)
205                 ON CONFLICT (endpoint) DO UPDATE SET
206                 auth = EXCLUDED.auth;
207                 "#,
208            )?
209            .bind("endpoint", serde_json::to_string(endpoint)?)
210            .bind("auth", serde_json::to_string(auth)?)
211            .execute(&self.inner)
212            .await
213            {
214                tracing::debug!(
215                    "Attempting to add protected endpoint. Skipping.... {:?}",
216                    err
217                );
218            }
219        }
220
221        Ok(())
222    }
223    async fn remove_protected_endpoints(
224        &mut self,
225        protected_endpoints: Vec<ProtectedEndpoint>,
226    ) -> Result<(), database::Error> {
227        query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
228            .bind_vec(
229                "endpoints",
230                protected_endpoints
231                    .iter()
232                    .map(serde_json::to_string)
233                    .collect::<Result<_, _>>()?,
234            )
235            .execute(&self.inner)
236            .await?;
237        Ok(())
238    }
239}
240
241#[async_trait]
242impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
243where
244    RM: DatabasePool + 'static,
245{
246    type Err = database::Error;
247
248    async fn begin_transaction<'a>(
249        &'a self,
250    ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
251    {
252        Ok(Box::new(SQLTransaction {
253            inner: ConnectionWithTransaction::new(
254                self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
255            )
256            .await?,
257        }))
258    }
259
260    async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
261        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
262        Ok(query(
263            r#"
264            SELECT
265                id
266            FROM
267                keyset
268            WHERE
269                active = :active;
270            "#,
271        )?
272        .bind("active", true)
273        .pluck(&*conn)
274        .await?
275        .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
276        .transpose()?)
277    }
278
279    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
280        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
281        Ok(query(
282            r#"SELECT
283                id,
284                unit,
285                active,
286                valid_from,
287                valid_to,
288                derivation_path,
289                derivation_path_index,
290                amounts,
291                input_fee_ppk
292            FROM
293                keyset
294                WHERE id=:id"#,
295        )?
296        .bind("id", id.to_string())
297        .fetch_one(&*conn)
298        .await?
299        .map(sql_row_to_keyset_info)
300        .transpose()?)
301    }
302
303    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
304        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
305        Ok(query(
306            r#"SELECT
307                id,
308                unit,
309                active,
310                valid_from,
311                valid_to,
312                derivation_path,
313                derivation_path_index,
314                amounts,
315                input_fee_ppk
316            FROM
317                keyset
318                WHERE id=:id"#,
319        )?
320        .fetch_all(&*conn)
321        .await?
322        .into_iter()
323        .map(sql_row_to_keyset_info)
324        .collect::<Result<Vec<_>, _>>()?)
325    }
326
327    async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
328        if ys.is_empty() {
329            return Ok(vec![]);
330        }
331        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
332        let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
333            .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
334            .fetch_all(&*conn)
335            .await?
336            .into_iter()
337            .map(|row| {
338                Ok((
339                    column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
340                    column_as_string!(&row[1], State::from_str),
341                ))
342            })
343            .collect::<Result<HashMap<_, _>, Error>>()?;
344
345        Ok(ys.iter().map(|y| current_states.remove(y)).collect())
346    }
347
348    async fn get_blind_signatures(
349        &self,
350        blinded_messages: &[PublicKey],
351    ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
352        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
353        let mut blinded_signatures = query(
354            r#"SELECT
355                keyset_id,
356                amount,
357                c,
358                dleq_e,
359                dleq_s,
360                blinded_message,
361            FROM
362                blind_signature
363            WHERE blinded_message IN (:blinded_message)
364            "#,
365        )?
366        .bind_vec(
367            "blinded_message",
368            blinded_messages
369                .iter()
370                .map(|bm| bm.to_bytes().to_vec())
371                .collect(),
372        )
373        .fetch_all(&*conn)
374        .await?
375        .into_iter()
376        .map(|mut row| {
377            Ok((
378                column_as_string!(
379                    &row.pop().ok_or(Error::InvalidDbResponse)?,
380                    PublicKey::from_hex,
381                    PublicKey::from_slice
382                ),
383                sql_row_to_blind_signature(row)?,
384            ))
385        })
386        .collect::<Result<HashMap<_, _>, Error>>()?;
387        Ok(blinded_messages
388            .iter()
389            .map(|bm| blinded_signatures.remove(bm))
390            .collect())
391    }
392
393    async fn get_auth_for_endpoint(
394        &self,
395        protected_endpoint: ProtectedEndpoint,
396    ) -> Result<Option<AuthRequired>, Self::Err> {
397        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
398        Ok(
399            query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
400                .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
401                .pluck(&*conn)
402                .await?
403                .map(|auth| {
404                    Ok::<_, Error>(column_as_string!(
405                        auth,
406                        serde_json::from_str,
407                        serde_json::from_slice
408                    ))
409                })
410                .transpose()?,
411        )
412    }
413
414    async fn get_auth_for_endpoints(
415        &self,
416    ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
417        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
418        Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
419            .fetch_all(&*conn)
420            .await?
421            .into_iter()
422            .map(|row| {
423                let endpoint =
424                    column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
425                let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
426                Ok((endpoint, Some(auth)))
427            })
428            .collect::<Result<HashMap<_, _>, Error>>()?)
429    }
430}