1use std::collections::HashMap;
4use std::fmt::Debug;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
12use cdk_common::{AuthRequired, ProtectedEndpoint};
13use migrations::MIGRATIONS;
14use tracing::instrument;
15
16use super::{sql_row_to_blind_signature, sql_row_to_keyset_info, SQLTransaction};
17use crate::column_as_string;
18use crate::common::migrate;
19use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
20use crate::mint::Error;
21use crate::pool::{DatabasePool, Pool, PooledResource};
22use crate::stmt::query;
23
24#[derive(Debug, Clone)]
26pub struct SQLMintAuthDatabase<RM>
27where
28 RM: DatabasePool + 'static,
29{
30 pool: Arc<Pool<RM>>,
31}
32
33impl<RM> SQLMintAuthDatabase<RM>
34where
35 RM: DatabasePool + 'static,
36{
37 pub async fn new<X>(db: X) -> Result<Self, Error>
39 where
40 X: Into<RM::Config>,
41 {
42 let pool = Pool::new(db.into());
43 Self::migrate(pool.get().map_err(|e| Error::Database(Box::new(e)))?).await?;
44 Ok(Self { pool })
45 }
46
47 async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
49 let tx = ConnectionWithTransaction::new(conn).await?;
50 migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
51 tx.commit().await?;
52 Ok(())
53 }
54}
55
/// Migration definitions for the mint auth database.
///
/// The module body is the file `migrations_mint_auth.rs` found under the
/// build's `OUT_DIR` (presumably emitted by the crate's build script); this
/// file imports `MIGRATIONS` from it.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint_auth.rs"));
}
60
61#[async_trait]
62impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
63where
64 RM: DatabasePool + 'static,
65{
66 #[instrument(skip(self))]
67 async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
68 tracing::info!("Setting auth keyset {id} active");
69 query(
70 r#"
71 UPDATE keyset
72 SET active = CASE
73 WHEN id = :id THEN TRUE
74 ELSE FALSE
75 END;
76 "#,
77 )?
78 .bind("id", id.to_string())
79 .execute(&self.inner)
80 .await?;
81
82 Ok(())
83 }
84
85 async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
86 query(
87 r#"
88 INSERT INTO
89 keyset (
90 id, unit, active, valid_from, valid_to, derivation_path,
91 max_order, derivation_path_index
92 )
93 VALUES (
94 :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
95 :max_order, :derivation_path_index
96 )
97 ON CONFLICT(id) DO UPDATE SET
98 unit = excluded.unit,
99 active = excluded.active,
100 valid_from = excluded.valid_from,
101 valid_to = excluded.valid_to,
102 derivation_path = excluded.derivation_path,
103 max_order = excluded.max_order,
104 derivation_path_index = excluded.derivation_path_index
105 "#,
106 )?
107 .bind("id", keyset.id.to_string())
108 .bind("unit", keyset.unit.to_string())
109 .bind("active", keyset.active)
110 .bind("valid_from", keyset.valid_from as i64)
111 .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
112 .bind("derivation_path", keyset.derivation_path.to_string())
113 .bind("max_order", keyset.max_order)
114 .bind("derivation_path_index", keyset.derivation_path_index)
115 .execute(&self.inner)
116 .await?;
117
118 Ok(())
119 }
120
121 async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
122 if let Err(err) = query(
123 r#"
124 INSERT INTO proof
125 (y, keyset_id, secret, c, state)
126 VALUES
127 (:y, :keyset_id, :secret, :c, :state)
128 "#,
129 )?
130 .bind("y", proof.y()?.to_bytes().to_vec())
131 .bind("keyset_id", proof.keyset_id.to_string())
132 .bind("secret", proof.secret.to_string())
133 .bind("c", proof.c.to_bytes().to_vec())
134 .bind("state", "UNSPENT".to_string())
135 .execute(&self.inner)
136 .await
137 {
138 tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
139 }
140 Ok(())
141 }
142
143 async fn update_proof_state(
144 &mut self,
145 y: &PublicKey,
146 proofs_state: State,
147 ) -> Result<Option<State>, Self::Err> {
148 let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
149 .bind("y", y.to_bytes().to_vec())
150 .pluck(&self.inner)
151 .await?
152 .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
153 .transpose()?;
154
155 query(r#"UPDATE proof SET state = :new_state WHERE state = :state AND y = :y"#)?
156 .bind("y", y.to_bytes().to_vec())
157 .bind(
158 "state",
159 current_state.as_ref().map(|state| state.to_string()),
160 )
161 .bind("new_state", proofs_state.to_string())
162 .execute(&self.inner)
163 .await?;
164
165 Ok(current_state)
166 }
167
168 async fn add_blind_signatures(
169 &mut self,
170 blinded_messages: &[PublicKey],
171 blind_signatures: &[BlindSignature],
172 ) -> Result<(), database::Error> {
173 for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
174 query(
175 r#"
176 INSERT
177 INTO blind_signature
178 (blinded_message, amount, keyset_id, c)
179 VALUES
180 (:blinded_message, :amount, :keyset_id, :c)
181 "#,
182 )?
183 .bind("blinded_message", message.to_bytes().to_vec())
184 .bind("amount", u64::from(signature.amount) as i64)
185 .bind("keyset_id", signature.keyset_id.to_string())
186 .bind("c", signature.c.to_bytes().to_vec())
187 .execute(&self.inner)
188 .await?;
189 }
190
191 Ok(())
192 }
193
194 async fn add_protected_endpoints(
195 &mut self,
196 protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
197 ) -> Result<(), database::Error> {
198 for (endpoint, auth) in protected_endpoints.iter() {
199 if let Err(err) = query(
200 r#"
201 INSERT INTO protected_endpoints
202 (endpoint, auth)
203 VALUES (:endpoint, :auth)
204 ON CONFLICT (endpoint) DO UPDATE SET
205 auth = EXCLUDED.auth;
206 "#,
207 )?
208 .bind("endpoint", serde_json::to_string(endpoint)?)
209 .bind("auth", serde_json::to_string(auth)?)
210 .execute(&self.inner)
211 .await
212 {
213 tracing::debug!(
214 "Attempting to add protected endpoint. Skipping.... {:?}",
215 err
216 );
217 }
218 }
219
220 Ok(())
221 }
222 async fn remove_protected_endpoints(
223 &mut self,
224 protected_endpoints: Vec<ProtectedEndpoint>,
225 ) -> Result<(), database::Error> {
226 query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
227 .bind_vec(
228 "endpoints",
229 protected_endpoints
230 .iter()
231 .map(serde_json::to_string)
232 .collect::<Result<_, _>>()?,
233 )
234 .execute(&self.inner)
235 .await?;
236 Ok(())
237 }
238}
239
240#[async_trait]
241impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
242where
243 RM: DatabasePool + 'static,
244{
245 type Err = database::Error;
246
247 async fn begin_transaction<'a>(
248 &'a self,
249 ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
250 {
251 Ok(Box::new(SQLTransaction {
252 inner: ConnectionWithTransaction::new(
253 self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
254 )
255 .await?,
256 }))
257 }
258
259 async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
260 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
261 Ok(query(
262 r#"
263 SELECT
264 id
265 FROM
266 keyset
267 WHERE
268 active = :active;
269 "#,
270 )?
271 .bind("active", true)
272 .pluck(&*conn)
273 .await?
274 .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
275 .transpose()?)
276 }
277
278 async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
279 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
280 Ok(query(
281 r#"SELECT
282 id,
283 unit,
284 active,
285 valid_from,
286 valid_to,
287 derivation_path,
288 derivation_path_index,
289 max_order,
290 amounts,
291 input_fee_ppk
292 FROM
293 keyset
294 WHERE id=:id"#,
295 )?
296 .bind("id", id.to_string())
297 .fetch_one(&*conn)
298 .await?
299 .map(sql_row_to_keyset_info)
300 .transpose()?)
301 }
302
303 async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
304 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
305 Ok(query(
306 r#"SELECT
307 id,
308 unit,
309 active,
310 valid_from,
311 valid_to,
312 derivation_path,
313 derivation_path_index,
314 max_order,
315 amounts,
316 input_fee_ppk
317 FROM
318 keyset
319 WHERE id=:id"#,
320 )?
321 .fetch_all(&*conn)
322 .await?
323 .into_iter()
324 .map(sql_row_to_keyset_info)
325 .collect::<Result<Vec<_>, _>>()?)
326 }
327
328 async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
329 if ys.is_empty() {
330 return Ok(vec![]);
331 }
332 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
333 let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
334 .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
335 .fetch_all(&*conn)
336 .await?
337 .into_iter()
338 .map(|row| {
339 Ok((
340 column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
341 column_as_string!(&row[1], State::from_str),
342 ))
343 })
344 .collect::<Result<HashMap<_, _>, Error>>()?;
345
346 Ok(ys.iter().map(|y| current_states.remove(y)).collect())
347 }
348
349 async fn get_blind_signatures(
350 &self,
351 blinded_messages: &[PublicKey],
352 ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
353 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
354 let mut blinded_signatures = query(
355 r#"SELECT
356 keyset_id,
357 amount,
358 c,
359 dleq_e,
360 dleq_s,
361 blinded_message,
362 FROM
363 blind_signature
364 WHERE blinded_message IN (:blinded_message)
365 "#,
366 )?
367 .bind_vec(
368 "blinded_message",
369 blinded_messages
370 .iter()
371 .map(|bm| bm.to_bytes().to_vec())
372 .collect(),
373 )
374 .fetch_all(&*conn)
375 .await?
376 .into_iter()
377 .map(|mut row| {
378 Ok((
379 column_as_string!(
380 &row.pop().ok_or(Error::InvalidDbResponse)?,
381 PublicKey::from_hex,
382 PublicKey::from_slice
383 ),
384 sql_row_to_blind_signature(row)?,
385 ))
386 })
387 .collect::<Result<HashMap<_, _>, Error>>()?;
388 Ok(blinded_messages
389 .iter()
390 .map(|bm| blinded_signatures.remove(bm))
391 .collect())
392 }
393
394 async fn get_auth_for_endpoint(
395 &self,
396 protected_endpoint: ProtectedEndpoint,
397 ) -> Result<Option<AuthRequired>, Self::Err> {
398 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
399 Ok(
400 query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
401 .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
402 .pluck(&*conn)
403 .await?
404 .map(|auth| {
405 Ok::<_, Error>(column_as_string!(
406 auth,
407 serde_json::from_str,
408 serde_json::from_slice
409 ))
410 })
411 .transpose()?,
412 )
413 }
414
415 async fn get_auth_for_endpoints(
416 &self,
417 ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
418 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
419 Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
420 .fetch_all(&*conn)
421 .await?
422 .into_iter()
423 .map(|row| {
424 let endpoint =
425 column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
426 let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
427 Ok((endpoint, Some(auth)))
428 })
429 .collect::<Result<HashMap<_, _>, Error>>()?)
430 }
431}