1use std::collections::HashMap;
4use std::fmt::Debug;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
12use cdk_common::{AuthRequired, ProtectedEndpoint};
13use migrations::MIGRATIONS;
14use tracing::instrument;
15
16use super::SQLTransaction;
17use crate::column_as_string;
18use crate::common::migrate;
19use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
20use crate::mint::keys::sql_row_to_keyset_info;
21use crate::mint::signatures::sql_row_to_blind_signature;
22use crate::mint::Error;
23use crate::pool::{DatabasePool, Pool, PooledResource};
24use crate::stmt::query;
25
26#[derive(Debug, Clone)]
28pub struct SQLMintAuthDatabase<RM>
29where
30 RM: DatabasePool + 'static,
31{
32 pool: Arc<Pool<RM>>,
33}
34
35impl<RM> SQLMintAuthDatabase<RM>
36where
37 RM: DatabasePool + 'static,
38{
39 pub async fn new<X>(db: X) -> Result<Self, Error>
41 where
42 X: Into<RM::Config>,
43 {
44 let pool = Pool::new(db.into());
45 Self::migrate(pool.get().await.map_err(|e| Error::Database(Box::new(e)))?).await?;
46 Ok(Self { pool })
47 }
48
49 async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
51 let tx = ConnectionWithTransaction::new(conn).await?;
52 migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
53 tx.commit().await?;
54 Ok(())
55 }
56}
57
// Migration scripts for the mint auth schema. The module body is pulled in from
// `$OUT_DIR/migrations_mint_auth.rs` (presumably written by the crate's build script —
// confirm) and must define `MIGRATIONS`, which is imported at the top of this file and
// passed to `crate::common::migrate` by `SQLMintAuthDatabase::migrate`.
// `rustfmt::skip` keeps the formatter away from the included file.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint_auth.rs"));
}
62
63#[async_trait]
64impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
65where
66 RM: DatabasePool + 'static,
67{
68 #[instrument(skip(self))]
69 async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
70 tracing::info!("Setting auth keyset {id} active");
71 query(
72 r#"
73 UPDATE keyset
74 SET active = CASE
75 WHEN id = :id THEN TRUE
76 ELSE FALSE
77 END;
78 "#,
79 )?
80 .bind("id", id.to_string())
81 .execute(&self.inner)
82 .await?;
83
84 Ok(())
85 }
86
87 async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
88 query(
89 r#"
90 INSERT INTO
91 keyset (
92 id, unit, active, valid_from, valid_to, derivation_path,
93 amounts, input_fee_ppk, derivation_path_index
94 )
95 VALUES (
96 :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
97 :amounts, :input_fee_ppk, :derivation_path_index
98 )
99 ON CONFLICT(id) DO UPDATE SET
100 unit = excluded.unit,
101 active = excluded.active,
102 valid_from = excluded.valid_from,
103 valid_to = excluded.valid_to,
104 derivation_path = excluded.derivation_path,
105 amounts = excluded.amounts,
106 input_fee_ppk = excluded.input_fee_ppk,
107 derivation_path_index = excluded.derivation_path_index
108 "#,
109 )?
110 .bind("id", keyset.id.to_string())
111 .bind("unit", keyset.unit.to_string())
112 .bind("active", keyset.active)
113 .bind("valid_from", keyset.valid_from as i64)
114 .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
115 .bind("derivation_path", keyset.derivation_path.to_string())
116 .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
117 .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
118 .bind("derivation_path_index", keyset.derivation_path_index)
119 .execute(&self.inner)
120 .await?;
121
122 Ok(())
123 }
124
125 async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
126 let y = proof.y()?;
127 if let Err(err) = query(
128 r#"
129 INSERT INTO proof
130 (y, keyset_id, secret, c, state)
131 VALUES
132 (:y, :keyset_id, :secret, :c, :state)
133 "#,
134 )?
135 .bind("y", y.to_bytes().to_vec())
136 .bind("keyset_id", proof.keyset_id.to_string())
137 .bind("secret", proof.secret.to_string())
138 .bind("c", proof.c.to_bytes().to_vec())
139 .bind("state", "UNSPENT".to_string())
140 .execute(&self.inner)
141 .await
142 {
143 tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
144 }
145 Ok(())
146 }
147
148 async fn update_proof_state(
149 &mut self,
150 y: &PublicKey,
151 proofs_state: State,
152 ) -> Result<Option<State>, Self::Err> {
153 let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
154 .bind("y", y.to_bytes().to_vec())
155 .pluck(&self.inner)
156 .await?
157 .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
158 .transpose()?;
159
160 query(r#"UPDATE proof SET state = :new_state WHERE y = :y"#)?
161 .bind("y", y.to_bytes().to_vec())
162 .bind("new_state", proofs_state.to_string())
163 .execute(&self.inner)
164 .await?;
165
166 Ok(current_state)
167 }
168
169 async fn add_blind_signatures(
170 &mut self,
171 blinded_messages: &[PublicKey],
172 blind_signatures: &[BlindSignature],
173 ) -> Result<(), database::Error> {
174 for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
175 query(
176 r#"
177 INSERT
178 INTO blind_signature
179 (blinded_message, amount, keyset_id, c)
180 VALUES
181 (:blinded_message, :amount, :keyset_id, :c)
182 "#,
183 )?
184 .bind("blinded_message", message.to_bytes().to_vec())
185 .bind("amount", u64::from(signature.amount) as i64)
186 .bind("keyset_id", signature.keyset_id.to_string())
187 .bind("c", signature.c.to_bytes().to_vec())
188 .execute(&self.inner)
189 .await?;
190 }
191
192 Ok(())
193 }
194
195 async fn add_protected_endpoints(
196 &mut self,
197 protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
198 ) -> Result<(), database::Error> {
199 for (endpoint, auth) in protected_endpoints.iter() {
200 if let Err(err) = query(
201 r#"
202 INSERT INTO protected_endpoints
203 (endpoint, auth)
204 VALUES (:endpoint, :auth)
205 ON CONFLICT (endpoint) DO UPDATE SET
206 auth = EXCLUDED.auth;
207 "#,
208 )?
209 .bind("endpoint", serde_json::to_string(endpoint)?)
210 .bind("auth", serde_json::to_string(auth)?)
211 .execute(&self.inner)
212 .await
213 {
214 tracing::debug!(
215 "Attempting to add protected endpoint. Skipping.... {:?}",
216 err
217 );
218 }
219 }
220
221 Ok(())
222 }
223 async fn remove_protected_endpoints(
224 &mut self,
225 protected_endpoints: Vec<ProtectedEndpoint>,
226 ) -> Result<(), database::Error> {
227 query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
228 .bind_vec(
229 "endpoints",
230 protected_endpoints
231 .iter()
232 .map(serde_json::to_string)
233 .collect::<Result<_, _>>()?,
234 )?
235 .execute(&self.inner)
236 .await?;
237 Ok(())
238 }
239}
240
241#[async_trait]
242impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
243where
244 RM: DatabasePool + 'static,
245{
246 type Err = database::Error;
247
248 async fn begin_transaction<'a>(
249 &'a self,
250 ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
251 {
252 Ok(Box::new(SQLTransaction {
253 inner: ConnectionWithTransaction::new(
254 self.pool
255 .get()
256 .await
257 .map_err(|e| Error::Database(Box::new(e)))?,
258 )
259 .await?,
260 }))
261 }
262
263 async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
264 let conn = self
265 .pool
266 .get()
267 .await
268 .map_err(|e| Error::Database(Box::new(e)))?;
269 Ok(query(
270 r#"
271 SELECT
272 id
273 FROM
274 keyset
275 WHERE
276 active = :active;
277 "#,
278 )?
279 .bind("active", true)
280 .pluck(&*conn)
281 .await?
282 .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
283 .transpose()?)
284 }
285
286 async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
287 let conn = self
288 .pool
289 .get()
290 .await
291 .map_err(|e| Error::Database(Box::new(e)))?;
292 Ok(query(
293 r#"SELECT
294 id,
295 unit,
296 active,
297 valid_from,
298 valid_to,
299 derivation_path,
300 derivation_path_index,
301 amounts,
302 input_fee_ppk
303 FROM
304 keyset
305 WHERE id=:id"#,
306 )?
307 .bind("id", id.to_string())
308 .fetch_one(&*conn)
309 .await?
310 .map(sql_row_to_keyset_info)
311 .transpose()?)
312 }
313
314 async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
315 let conn = self
316 .pool
317 .get()
318 .await
319 .map_err(|e| Error::Database(Box::new(e)))?;
320 Ok(query(
321 r#"SELECT
322 id,
323 unit,
324 active,
325 valid_from,
326 valid_to,
327 derivation_path,
328 derivation_path_index,
329 amounts,
330 input_fee_ppk
331 FROM
332 keyset
333 WHERE id=:id"#,
334 )?
335 .fetch_all(&*conn)
336 .await?
337 .into_iter()
338 .map(sql_row_to_keyset_info)
339 .collect::<Result<Vec<_>, _>>()?)
340 }
341
342 async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
343 let conn = self
344 .pool
345 .get()
346 .await
347 .map_err(|e| Error::Database(Box::new(e)))?;
348 let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
349 .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())?
350 .fetch_all(&*conn)
351 .await?
352 .into_iter()
353 .map(|row| {
354 Ok((
355 column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
356 column_as_string!(&row[1], State::from_str),
357 ))
358 })
359 .collect::<Result<HashMap<_, _>, Error>>()?;
360
361 Ok(ys.iter().map(|y| current_states.remove(y)).collect())
362 }
363
364 async fn get_blind_signatures(
365 &self,
366 blinded_messages: &[PublicKey],
367 ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
368 let conn = self
369 .pool
370 .get()
371 .await
372 .map_err(|e| Error::Database(Box::new(e)))?;
373 let mut blinded_signatures = query(
374 r#"SELECT
375 keyset_id,
376 amount,
377 c,
378 dleq_e,
379 dleq_s,
380 blinded_message,
381 FROM
382 blind_signature
383 WHERE blinded_message IN (:blinded_message)
384 "#,
385 )?
386 .bind_vec(
387 "blinded_message",
388 blinded_messages
389 .iter()
390 .map(|bm| bm.to_bytes().to_vec())
391 .collect(),
392 )?
393 .fetch_all(&*conn)
394 .await?
395 .into_iter()
396 .map(|mut row| {
397 Ok((
398 column_as_string!(
399 &row.pop().ok_or(Error::InvalidDbResponse)?,
400 PublicKey::from_hex,
401 PublicKey::from_slice
402 ),
403 sql_row_to_blind_signature(row)?,
404 ))
405 })
406 .collect::<Result<HashMap<_, _>, Error>>()?;
407 Ok(blinded_messages
408 .iter()
409 .map(|bm| blinded_signatures.remove(bm))
410 .collect())
411 }
412
413 async fn get_auth_for_endpoint(
414 &self,
415 protected_endpoint: ProtectedEndpoint,
416 ) -> Result<Option<AuthRequired>, Self::Err> {
417 let conn = self
418 .pool
419 .get()
420 .await
421 .map_err(|e| Error::Database(Box::new(e)))?;
422 Ok(
423 query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
424 .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
425 .pluck(&*conn)
426 .await?
427 .map(|auth| {
428 Ok::<_, Error>(column_as_string!(
429 auth,
430 serde_json::from_str,
431 serde_json::from_slice
432 ))
433 })
434 .transpose()?,
435 )
436 }
437
438 async fn get_auth_for_endpoints(
439 &self,
440 ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
441 let conn = self
442 .pool
443 .get()
444 .await
445 .map_err(|e| Error::Database(Box::new(e)))?;
446 Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
447 .fetch_all(&*conn)
448 .await?
449 .into_iter()
450 .map(|row| {
451 let endpoint =
452 column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
453 let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
454 Ok((endpoint, Some(auth)))
455 })
456 .collect::<Result<HashMap<_, _>, Error>>()?)
457 }
458}