1use std::collections::HashMap;
4use std::fmt::Debug;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
12use cdk_common::{AuthRequired, ProtectedEndpoint};
13use migrations::MIGRATIONS;
14use tracing::instrument;
15
16use super::SQLTransaction;
17use crate::column_as_string;
18use crate::common::migrate;
19use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
20use crate::mint::keys::sql_row_to_keyset_info;
21use crate::mint::signatures::sql_row_to_blind_signature;
22use crate::mint::Error;
23use crate::pool::{DatabasePool, Pool, PooledResource};
24use crate::stmt::query;
25
26#[derive(Debug, Clone)]
28pub struct SQLMintAuthDatabase<RM>
29where
30 RM: DatabasePool + 'static,
31{
32 pool: Arc<Pool<RM>>,
33}
34
35impl<RM> SQLMintAuthDatabase<RM>
36where
37 RM: DatabasePool + 'static,
38{
39 pub async fn new<X>(db: X) -> Result<Self, Error>
41 where
42 X: Into<RM::Config>,
43 {
44 let pool = Pool::new(db.into());
45 Self::migrate(pool.get().map_err(|e| Error::Database(Box::new(e)))?).await?;
46 Ok(Self { pool })
47 }
48
49 async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
51 let tx = ConnectionWithTransaction::new(conn).await?;
52 migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
53 tx.commit().await?;
54 Ok(())
55 }
56}
57
/// Auth-database migrations generated at build time into `OUT_DIR`.
///
/// Supplies the `MIGRATIONS` list consumed by `SQLMintAuthDatabase::migrate`.
// The included file is generated, so rustfmt is told not to reformat it.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint_auth.rs"));
}
62
63#[async_trait]
64impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
65where
66 RM: DatabasePool + 'static,
67{
68 #[instrument(skip(self))]
69 async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
70 tracing::info!("Setting auth keyset {id} active");
71 query(
72 r#"
73 UPDATE keyset
74 SET active = CASE
75 WHEN id = :id THEN TRUE
76 ELSE FALSE
77 END;
78 "#,
79 )?
80 .bind("id", id.to_string())
81 .execute(&self.inner)
82 .await?;
83
84 Ok(())
85 }
86
87 async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
88 query(
89 r#"
90 INSERT INTO
91 keyset (
92 id, unit, active, valid_from, valid_to, derivation_path,
93 amounts, input_fee_ppk, derivation_path_index
94 )
95 VALUES (
96 :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
97 :amounts, :input_fee_ppk, :derivation_path_index
98 )
99 ON CONFLICT(id) DO UPDATE SET
100 unit = excluded.unit,
101 active = excluded.active,
102 valid_from = excluded.valid_from,
103 valid_to = excluded.valid_to,
104 derivation_path = excluded.derivation_path,
105 amounts = excluded.amounts,
106 input_fee_ppk = excluded.input_fee_ppk,
107 derivation_path_index = excluded.derivation_path_index
108 "#,
109 )?
110 .bind("id", keyset.id.to_string())
111 .bind("unit", keyset.unit.to_string())
112 .bind("active", keyset.active)
113 .bind("valid_from", keyset.valid_from as i64)
114 .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
115 .bind("derivation_path", keyset.derivation_path.to_string())
116 .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
117 .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
118 .bind("derivation_path_index", keyset.derivation_path_index)
119 .execute(&self.inner)
120 .await?;
121
122 Ok(())
123 }
124
125 async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
126 let y = proof.y()?;
127 if let Err(err) = query(
128 r#"
129 INSERT INTO proof
130 (y, keyset_id, secret, c, state)
131 VALUES
132 (:y, :keyset_id, :secret, :c, :state)
133 "#,
134 )?
135 .bind("y", y.to_bytes().to_vec())
136 .bind("keyset_id", proof.keyset_id.to_string())
137 .bind("secret", proof.secret.to_string())
138 .bind("c", proof.c.to_bytes().to_vec())
139 .bind("state", "UNSPENT".to_string())
140 .execute(&self.inner)
141 .await
142 {
143 tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
144 }
145 Ok(())
146 }
147
148 async fn update_proof_state(
149 &mut self,
150 y: &PublicKey,
151 proofs_state: State,
152 ) -> Result<Option<State>, Self::Err> {
153 let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
154 .bind("y", y.to_bytes().to_vec())
155 .pluck(&self.inner)
156 .await?
157 .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
158 .transpose()?;
159
160 query(r#"UPDATE proof SET state = :new_state WHERE y = :y"#)?
161 .bind("y", y.to_bytes().to_vec())
162 .bind("new_state", proofs_state.to_string())
163 .execute(&self.inner)
164 .await?;
165
166 Ok(current_state)
167 }
168
169 async fn add_blind_signatures(
170 &mut self,
171 blinded_messages: &[PublicKey],
172 blind_signatures: &[BlindSignature],
173 ) -> Result<(), database::Error> {
174 for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
175 query(
176 r#"
177 INSERT
178 INTO blind_signature
179 (blinded_message, amount, keyset_id, c)
180 VALUES
181 (:blinded_message, :amount, :keyset_id, :c)
182 "#,
183 )?
184 .bind("blinded_message", message.to_bytes().to_vec())
185 .bind("amount", u64::from(signature.amount) as i64)
186 .bind("keyset_id", signature.keyset_id.to_string())
187 .bind("c", signature.c.to_bytes().to_vec())
188 .execute(&self.inner)
189 .await?;
190 }
191
192 Ok(())
193 }
194
195 async fn add_protected_endpoints(
196 &mut self,
197 protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
198 ) -> Result<(), database::Error> {
199 for (endpoint, auth) in protected_endpoints.iter() {
200 if let Err(err) = query(
201 r#"
202 INSERT INTO protected_endpoints
203 (endpoint, auth)
204 VALUES (:endpoint, :auth)
205 ON CONFLICT (endpoint) DO UPDATE SET
206 auth = EXCLUDED.auth;
207 "#,
208 )?
209 .bind("endpoint", serde_json::to_string(endpoint)?)
210 .bind("auth", serde_json::to_string(auth)?)
211 .execute(&self.inner)
212 .await
213 {
214 tracing::debug!(
215 "Attempting to add protected endpoint. Skipping.... {:?}",
216 err
217 );
218 }
219 }
220
221 Ok(())
222 }
223 async fn remove_protected_endpoints(
224 &mut self,
225 protected_endpoints: Vec<ProtectedEndpoint>,
226 ) -> Result<(), database::Error> {
227 query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
228 .bind_vec(
229 "endpoints",
230 protected_endpoints
231 .iter()
232 .map(serde_json::to_string)
233 .collect::<Result<_, _>>()?,
234 )
235 .execute(&self.inner)
236 .await?;
237 Ok(())
238 }
239}
240
241#[async_trait]
242impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
243where
244 RM: DatabasePool + 'static,
245{
246 type Err = database::Error;
247
248 async fn begin_transaction<'a>(
249 &'a self,
250 ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
251 {
252 Ok(Box::new(SQLTransaction {
253 inner: ConnectionWithTransaction::new(
254 self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
255 )
256 .await?,
257 }))
258 }
259
260 async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
261 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
262 Ok(query(
263 r#"
264 SELECT
265 id
266 FROM
267 keyset
268 WHERE
269 active = :active;
270 "#,
271 )?
272 .bind("active", true)
273 .pluck(&*conn)
274 .await?
275 .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
276 .transpose()?)
277 }
278
279 async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
280 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
281 Ok(query(
282 r#"SELECT
283 id,
284 unit,
285 active,
286 valid_from,
287 valid_to,
288 derivation_path,
289 derivation_path_index,
290 amounts,
291 input_fee_ppk
292 FROM
293 keyset
294 WHERE id=:id"#,
295 )?
296 .bind("id", id.to_string())
297 .fetch_one(&*conn)
298 .await?
299 .map(sql_row_to_keyset_info)
300 .transpose()?)
301 }
302
303 async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
304 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
305 Ok(query(
306 r#"SELECT
307 id,
308 unit,
309 active,
310 valid_from,
311 valid_to,
312 derivation_path,
313 derivation_path_index,
314 amounts,
315 input_fee_ppk
316 FROM
317 keyset
318 WHERE id=:id"#,
319 )?
320 .fetch_all(&*conn)
321 .await?
322 .into_iter()
323 .map(sql_row_to_keyset_info)
324 .collect::<Result<Vec<_>, _>>()?)
325 }
326
327 async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
328 if ys.is_empty() {
329 return Ok(vec![]);
330 }
331 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
332 let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
333 .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
334 .fetch_all(&*conn)
335 .await?
336 .into_iter()
337 .map(|row| {
338 Ok((
339 column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
340 column_as_string!(&row[1], State::from_str),
341 ))
342 })
343 .collect::<Result<HashMap<_, _>, Error>>()?;
344
345 Ok(ys.iter().map(|y| current_states.remove(y)).collect())
346 }
347
348 async fn get_blind_signatures(
349 &self,
350 blinded_messages: &[PublicKey],
351 ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
352 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
353 let mut blinded_signatures = query(
354 r#"SELECT
355 keyset_id,
356 amount,
357 c,
358 dleq_e,
359 dleq_s,
360 blinded_message,
361 FROM
362 blind_signature
363 WHERE blinded_message IN (:blinded_message)
364 "#,
365 )?
366 .bind_vec(
367 "blinded_message",
368 blinded_messages
369 .iter()
370 .map(|bm| bm.to_bytes().to_vec())
371 .collect(),
372 )
373 .fetch_all(&*conn)
374 .await?
375 .into_iter()
376 .map(|mut row| {
377 Ok((
378 column_as_string!(
379 &row.pop().ok_or(Error::InvalidDbResponse)?,
380 PublicKey::from_hex,
381 PublicKey::from_slice
382 ),
383 sql_row_to_blind_signature(row)?,
384 ))
385 })
386 .collect::<Result<HashMap<_, _>, Error>>()?;
387 Ok(blinded_messages
388 .iter()
389 .map(|bm| blinded_signatures.remove(bm))
390 .collect())
391 }
392
393 async fn get_auth_for_endpoint(
394 &self,
395 protected_endpoint: ProtectedEndpoint,
396 ) -> Result<Option<AuthRequired>, Self::Err> {
397 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
398 Ok(
399 query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
400 .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
401 .pluck(&*conn)
402 .await?
403 .map(|auth| {
404 Ok::<_, Error>(column_as_string!(
405 auth,
406 serde_json::from_str,
407 serde_json::from_slice
408 ))
409 })
410 .transpose()?,
411 )
412 }
413
414 async fn get_auth_for_endpoints(
415 &self,
416 ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
417 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
418 Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
419 .fetch_all(&*conn)
420 .await?
421 .into_iter()
422 .map(|row| {
423 let endpoint =
424 column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
425 let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
426 Ok((endpoint, Some(auth)))
427 })
428 .collect::<Result<HashMap<_, _>, Error>>()?)
429 }
430}