1use std::collections::HashMap;
4use std::str::FromStr;
5
6use async_trait::async_trait;
7use bitcoin::bip32::DerivationPath;
8use cdk_common::common::IssuerVersion;
9use cdk_common::database::{Error, MintKeyDatabaseTransaction, MintKeysDatabase};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::{CurrencyUnit, Id};
12
13use super::{SQLMintDatabase, SQLTransaction};
14use crate::database::ConnectionWithTransaction;
15use crate::pool::DatabasePool;
16use crate::stmt::{query, Column};
17use crate::{
18 column_as_nullable_number, column_as_nullable_string, column_as_number, column_as_string,
19 unpack_into,
20};
21
22pub(crate) fn sql_row_to_keyset_info(row: Vec<Column>) -> Result<MintKeySetInfo, Error> {
23 unpack_into!(
24 let (
25 id,
26 unit,
27 active,
28 valid_from,
29 valid_to,
30 derivation_path,
31 derivation_path_index,
32 amounts,
33 row_keyset_ppk,
34 issuer_version
35 ) = row
36 );
37
38 let amounts = column_as_nullable_string!(amounts)
39 .and_then(|str| serde_json::from_str(&str).ok())
40 .ok_or_else(|| Error::Database("amounts field is required".to_string().into()))?;
41
42 Ok(MintKeySetInfo {
43 id: column_as_string!(id, Id::from_str, Id::from_bytes),
44 unit: column_as_string!(unit, CurrencyUnit::from_str),
45 active: matches!(active, Column::Integer(1)),
46 valid_from: column_as_number!(valid_from),
47 derivation_path: column_as_string!(derivation_path, DerivationPath::from_str),
48 derivation_path_index: column_as_nullable_number!(derivation_path_index),
49 amounts,
50 input_fee_ppk: column_as_nullable_number!(row_keyset_ppk).unwrap_or(0),
51 final_expiry: column_as_nullable_number!(valid_to),
52 issuer_version: column_as_nullable_string!(issuer_version).and_then(|v| {
53 match IssuerVersion::from_str(&v) {
54 Ok(ver) => Some(ver),
55 Err(e) => {
56 tracing::warn!(
57 "Failed to parse issuer_version from database: {}. Error: {}",
58 v,
59 e
60 );
61 None
62 }
63 }
64 }),
65 })
66}
67
68#[async_trait]
69impl<RM> MintKeyDatabaseTransaction<'_, Error> for SQLTransaction<RM>
70where
71 RM: DatabasePool + 'static,
72{
73 async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), Error> {
74 query(
75 r#"
76 INSERT INTO
77 keyset (
78 id, unit, active, valid_from, valid_to, derivation_path,
79 amounts, input_fee_ppk, derivation_path_index, issuer_version
80 )
81 VALUES (
82 :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
83 :amounts, :input_fee_ppk, :derivation_path_index, :issuer_version
84 )
85 ON CONFLICT(id) DO UPDATE SET
86 unit = excluded.unit,
87 active = excluded.active,
88 valid_from = excluded.valid_from,
89 valid_to = excluded.valid_to,
90 derivation_path = excluded.derivation_path,
91 amounts = excluded.amounts,
92 input_fee_ppk = excluded.input_fee_ppk,
93 derivation_path_index = excluded.derivation_path_index,
94 issuer_version = excluded.issuer_version
95 "#,
96 )?
97 .bind("id", keyset.id.to_string())
98 .bind("unit", keyset.unit.to_string())
99 .bind("active", keyset.active)
100 .bind("valid_from", keyset.valid_from as i64)
101 .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
102 .bind("derivation_path", keyset.derivation_path.to_string())
103 .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
104 .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
105 .bind("derivation_path_index", keyset.derivation_path_index)
106 .bind(
107 "issuer_version",
108 keyset.issuer_version.map(|v| v.to_string()),
109 )
110 .execute(&self.inner)
111 .await?;
112
113 Ok(())
114 }
115
116 async fn set_active_keyset(&mut self, unit: CurrencyUnit, id: Id) -> Result<(), Error> {
117 query(r#"UPDATE keyset SET active=FALSE WHERE unit = :unit"#)?
118 .bind("unit", unit.to_string())
119 .execute(&self.inner)
120 .await?;
121
122 query(r#"UPDATE keyset SET active=TRUE WHERE unit = :unit AND id = :id"#)?
123 .bind("unit", unit.to_string())
124 .bind("id", id.to_string())
125 .execute(&self.inner)
126 .await?;
127
128 Ok(())
129 }
130}
131
132#[async_trait]
133impl<RM> MintKeysDatabase for SQLMintDatabase<RM>
134where
135 RM: DatabasePool + 'static,
136{
137 type Err = Error;
138
139 async fn begin_transaction<'a>(
140 &'a self,
141 ) -> Result<Box<dyn MintKeyDatabaseTransaction<'a, Error> + Send + Sync + 'a>, Error> {
142 let tx = SQLTransaction {
143 inner: ConnectionWithTransaction::new(
144 self.pool
145 .get()
146 .await
147 .map_err(|e| Error::Database(Box::new(e)))?,
148 )
149 .await?,
150 };
151
152 Ok(Box::new(tx))
153 }
154
155 async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
156 let conn = self
157 .pool
158 .get()
159 .await
160 .map_err(|e| Error::Database(Box::new(e)))?;
161 Ok(
162 query(r#" SELECT id FROM keyset WHERE active = :active AND unit = :unit"#)?
163 .bind("active", true)
164 .bind("unit", unit.to_string())
165 .pluck(&*conn)
166 .await?
167 .map(|id| match id {
168 Column::Text(text) => Ok(Id::from_str(&text)?),
169 Column::Blob(id) => Ok(Id::from_bytes(&id)?),
170 _ => Err(Error::InvalidKeysetId),
171 })
172 .transpose()?,
173 )
174 }
175
176 async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
177 let conn = self
178 .pool
179 .get()
180 .await
181 .map_err(|e| Error::Database(Box::new(e)))?;
182 Ok(
183 query(r#"SELECT id, unit FROM keyset WHERE active = :active"#)?
184 .bind("active", true)
185 .fetch_all(&*conn)
186 .await?
187 .into_iter()
188 .map(|row| {
189 Ok((
190 column_as_string!(&row[1], CurrencyUnit::from_str),
191 column_as_string!(&row[0], Id::from_str, Id::from_bytes),
192 ))
193 })
194 .collect::<Result<HashMap<_, _>, Error>>()?,
195 )
196 }
197
198 async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
199 let conn = self
200 .pool
201 .get()
202 .await
203 .map_err(|e| Error::Database(Box::new(e)))?;
204 Ok(query(
205 r#"SELECT
206 id,
207 unit,
208 active,
209 valid_from,
210 valid_to,
211 derivation_path,
212 derivation_path_index,
213 amounts,
214 input_fee_ppk,
215 issuer_version
216 FROM
217 keyset
218 WHERE id=:id"#,
219 )?
220 .bind("id", id.to_string())
221 .fetch_one(&*conn)
222 .await?
223 .map(sql_row_to_keyset_info)
224 .transpose()?)
225 }
226
227 async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
228 let conn = self
229 .pool
230 .get()
231 .await
232 .map_err(|e| Error::Database(Box::new(e)))?;
233 Ok(query(
234 r#"SELECT
235 id,
236 unit,
237 active,
238 valid_from,
239 valid_to,
240 derivation_path,
241 derivation_path_index,
242 amounts,
243 input_fee_ppk,
244 issuer_version
245 FROM
246 keyset
247 "#,
248 )?
249 .fetch_all(&*conn)
250 .await?
251 .into_iter()
252 .map(sql_row_to_keyset_info)
253 .collect::<Result<Vec<_>, _>>()?)
254 }
255}
256
#[cfg(test)]
mod test {
    use super::*;

    mod keyset_amounts_tests {
        use super::*;

        /// Builds a `keyset` row in the column order `sql_row_to_keyset_info` expects.
        /// The columns a test does not vary are fixed to known-good values.
        fn keyset_row(
            valid_to: Column,
            amounts: Column,
            input_fee_ppk: Column,
            issuer_version: Column,
        ) -> Vec<Column> {
            vec![
                Column::Text("0083a60439303340".to_owned()),
                Column::Text("sat".to_owned()),
                Column::Integer(1),
                Column::Integer(1749844864),
                valid_to,
                Column::Text("0'/0'/0'".to_owned()),
                Column::Integer(0),
                amounts,
                input_fee_ppk,
                issuer_version,
            ]
        }

        #[test]
        fn keyset_with_amounts() {
            let amounts = (0..32).map(|x| 2u64.pow(x)).collect::<Vec<_>>();
            let keyset = sql_row_to_keyset_info(keyset_row(
                Column::Null,
                Column::Text(serde_json::to_string(&amounts).expect("valid json")),
                Column::Integer(0),
                Column::Text("cdk/0.1.0".to_owned()),
            ))
            .expect("well-formed row should decode");

            assert_eq!(keyset.amounts.len(), 32);
            assert!(keyset.active);
            assert_eq!(keyset.valid_from, 1749844864);
            assert!(keyset.final_expiry.is_none());
            assert_eq!(keyset.input_fee_ppk, 0);
            assert_eq!(
                keyset.issuer_version,
                Some(IssuerVersion::from_str("cdk/0.1.0").unwrap())
            );
        }

        #[test]
        fn keyset_without_amounts_is_rejected() {
            let result = sql_row_to_keyset_info(keyset_row(
                Column::Null,
                Column::Null,
                Column::Integer(0),
                Column::Null,
            ));
            assert!(result.is_err());
        }

        #[test]
        fn keyset_with_malformed_amounts_is_rejected() {
            let result = sql_row_to_keyset_info(keyset_row(
                Column::Null,
                Column::Text("not json".to_owned()),
                Column::Integer(0),
                Column::Null,
            ));
            assert!(result.is_err());
        }

        #[test]
        fn keyset_nullable_columns() {
            let keyset = sql_row_to_keyset_info(keyset_row(
                Column::Integer(1_800_000_000),
                Column::Text("[1,2,4]".to_owned()),
                Column::Null,
                Column::Null,
            ))
            .expect("row with NULL optional columns should decode");

            // `valid_to` maps onto `final_expiry`, a NULL fee defaults to 0 and a NULL
            // issuer version stays `None`.
            assert_eq!(keyset.final_expiry, Some(1_800_000_000));
            assert_eq!(keyset.input_fee_ppk, 0);
            assert!(keyset.issuer_version.is_none());
            assert_eq!(keyset.amounts.len(), 3);
        }
    }
}