1use std::collections::HashMap;
4use std::str::FromStr;
5
6use async_trait::async_trait;
7use bitcoin::bip32::DerivationPath;
8use cdk_common::common::IssuerVersion;
9use cdk_common::database::{Error, MintKeyDatabaseTransaction, MintKeysDatabase};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::{CurrencyUnit, Id};
12
13use super::{SQLMintDatabase, SQLTransaction};
14use crate::database::ConnectionWithTransaction;
15use crate::pool::DatabasePool;
16use crate::stmt::{query, Column};
17use crate::{
18 column_as_nullable_number, column_as_nullable_string, column_as_number, column_as_string,
19 unpack_into,
20};
21
22pub(crate) fn sql_row_to_keyset_info(row: Vec<Column>) -> Result<MintKeySetInfo, Error> {
23 unpack_into!(
24 let (
25 id,
26 unit,
27 active,
28 valid_from,
29 valid_to,
30 derivation_path,
31 derivation_path_index,
32 amounts,
33 row_keyset_ppk,
34 issuer_version
35 ) = row
36 );
37
38 let amounts = column_as_nullable_string!(amounts)
39 .and_then(|str| serde_json::from_str(&str).ok())
40 .ok_or_else(|| Error::Database("amounts field is required".to_string().into()))?;
41
42 Ok(MintKeySetInfo {
43 id: column_as_string!(id, Id::from_str, Id::from_bytes),
44 unit: column_as_string!(unit, CurrencyUnit::from_str),
45 active: matches!(active, Column::Integer(1)),
46 valid_from: column_as_number!(valid_from),
47 derivation_path: column_as_string!(derivation_path, DerivationPath::from_str),
48 derivation_path_index: column_as_nullable_number!(derivation_path_index),
49 amounts,
50 input_fee_ppk: column_as_nullable_number!(row_keyset_ppk).unwrap_or(0),
51 final_expiry: column_as_nullable_number!(valid_to),
52 issuer_version: column_as_nullable_string!(issuer_version).and_then(|v| {
53 match IssuerVersion::from_str(&v) {
54 Ok(ver) => Some(ver),
55 Err(e) => {
56 tracing::warn!(
57 "Failed to parse issuer_version from database: {}. Error: {}",
58 v,
59 e
60 );
61 None
62 }
63 }
64 }),
65 })
66}
67
68#[async_trait]
69impl<RM> MintKeyDatabaseTransaction<'_, Error> for SQLTransaction<RM>
70where
71 RM: DatabasePool + 'static,
72{
73 async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), Error> {
74 query(
75 r#"
76 INSERT INTO
77 keyset (
78 id, unit, active, valid_from, valid_to, derivation_path,
79 amounts, input_fee_ppk, derivation_path_index, issuer_version
80 )
81 VALUES (
82 :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
83 :amounts, :input_fee_ppk, :derivation_path_index, :issuer_version
84 )
85 ON CONFLICT(id) DO UPDATE SET
86 unit = excluded.unit,
87 active = excluded.active,
88 valid_from = excluded.valid_from,
89 valid_to = excluded.valid_to,
90 derivation_path = excluded.derivation_path,
91 amounts = excluded.amounts,
92 input_fee_ppk = excluded.input_fee_ppk,
93 derivation_path_index = excluded.derivation_path_index,
94 issuer_version = excluded.issuer_version
95 "#,
96 )?
97 .bind("id", keyset.id.to_string())
98 .bind("unit", keyset.unit.to_string())
99 .bind("active", keyset.active)
100 .bind("valid_from", keyset.valid_from as i64)
101 .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
102 .bind("derivation_path", keyset.derivation_path.to_string())
103 .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
104 .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
105 .bind("derivation_path_index", keyset.derivation_path_index)
106 .bind(
107 "issuer_version",
108 keyset.issuer_version.map(|v| v.to_string()),
109 )
110 .execute(&self.inner)
111 .await?;
112
113 Ok(())
114 }
115
116 async fn set_active_keyset(&mut self, unit: CurrencyUnit, id: Id) -> Result<(), Error> {
117 query(r#"UPDATE keyset SET active=FALSE WHERE unit = :unit"#)?
118 .bind("unit", unit.to_string())
119 .execute(&self.inner)
120 .await?;
121
122 query(r#"UPDATE keyset SET active=TRUE WHERE unit = :unit AND id = :id"#)?
123 .bind("unit", unit.to_string())
124 .bind("id", id.to_string())
125 .execute(&self.inner)
126 .await?;
127
128 Ok(())
129 }
130}
131
132#[async_trait]
133impl<RM> MintKeysDatabase for SQLMintDatabase<RM>
134where
135 RM: DatabasePool + 'static,
136{
137 type Err = Error;
138
139 async fn begin_transaction<'a>(
140 &'a self,
141 ) -> Result<Box<dyn MintKeyDatabaseTransaction<'a, Error> + Send + Sync + 'a>, Error> {
142 let tx = SQLTransaction {
143 inner: ConnectionWithTransaction::new(
144 self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
145 )
146 .await?,
147 };
148
149 Ok(Box::new(tx))
150 }
151
152 async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
153 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
154 Ok(
155 query(r#" SELECT id FROM keyset WHERE active = :active AND unit = :unit"#)?
156 .bind("active", true)
157 .bind("unit", unit.to_string())
158 .pluck(&*conn)
159 .await?
160 .map(|id| match id {
161 Column::Text(text) => Ok(Id::from_str(&text)?),
162 Column::Blob(id) => Ok(Id::from_bytes(&id)?),
163 _ => Err(Error::InvalidKeysetId),
164 })
165 .transpose()?,
166 )
167 }
168
169 async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
170 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
171 Ok(
172 query(r#"SELECT id, unit FROM keyset WHERE active = :active"#)?
173 .bind("active", true)
174 .fetch_all(&*conn)
175 .await?
176 .into_iter()
177 .map(|row| {
178 Ok((
179 column_as_string!(&row[1], CurrencyUnit::from_str),
180 column_as_string!(&row[0], Id::from_str, Id::from_bytes),
181 ))
182 })
183 .collect::<Result<HashMap<_, _>, Error>>()?,
184 )
185 }
186
187 async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
188 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
189 Ok(query(
190 r#"SELECT
191 id,
192 unit,
193 active,
194 valid_from,
195 valid_to,
196 derivation_path,
197 derivation_path_index,
198 amounts,
199 input_fee_ppk,
200 issuer_version
201 FROM
202 keyset
203 WHERE id=:id"#,
204 )?
205 .bind("id", id.to_string())
206 .fetch_one(&*conn)
207 .await?
208 .map(sql_row_to_keyset_info)
209 .transpose()?)
210 }
211
212 async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
213 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
214 Ok(query(
215 r#"SELECT
216 id,
217 unit,
218 active,
219 valid_from,
220 valid_to,
221 derivation_path,
222 derivation_path_index,
223 amounts,
224 input_fee_ppk,
225 issuer_version
226 FROM
227 keyset
228 "#,
229 )?
230 .fetch_all(&*conn)
231 .await?
232 .into_iter()
233 .map(sql_row_to_keyset_info)
234 .collect::<Result<Vec<_>, _>>()?)
235 }
236}
237
#[cfg(test)]
mod test {
    use super::*;

    mod keyset_amounts_tests {
        use super::*;

        /// Builds a `keyset` row in the column order `sql_row_to_keyset_info` unpacks.
        ///
        /// Only the columns the tests vary are parameters; the rest are fixed.
        fn keyset_row(
            amounts: Column,
            input_fee_ppk: Column,
            issuer_version: Column,
        ) -> Vec<Column> {
            vec![
                Column::Text("0083a60439303340".to_owned()),
                Column::Text("sat".to_owned()),
                Column::Integer(1),
                Column::Integer(1749844864),
                Column::Null,
                Column::Text("0'/0'/0'".to_owned()),
                Column::Integer(0),
                amounts,
                input_fee_ppk,
                issuer_version,
            ]
        }

        /// JSON-encoded `amounts` column holding the 32 powers of two `2^0..2^31`.
        fn power_of_two_amounts() -> Column {
            let amounts = (0..32).map(|x| 2u64.pow(x)).collect::<Vec<_>>();
            Column::Text(serde_json::to_string(&amounts).expect("valid json"))
        }

        #[test]
        fn keyset_with_amounts() {
            let keyset = sql_row_to_keyset_info(keyset_row(
                power_of_two_amounts(),
                Column::Integer(0),
                Column::Text("cdk/0.1.0".to_owned()),
            ))
            .expect("well-formed row should parse");

            assert_eq!(keyset.amounts.len(), 32);
            assert_eq!(
                keyset.issuer_version,
                Some(IssuerVersion::from_str("cdk/0.1.0").unwrap())
            );
        }

        #[test]
        fn keyset_without_amounts_is_rejected() {
            let result =
                sql_row_to_keyset_info(keyset_row(Column::Null, Column::Integer(0), Column::Null));
            assert!(result.is_err());
        }

        #[test]
        fn keyset_with_malformed_amounts_is_rejected() {
            let result = sql_row_to_keyset_info(keyset_row(
                Column::Text("not json".to_owned()),
                Column::Integer(0),
                Column::Null,
            ));
            assert!(result.is_err());
        }

        #[test]
        fn nullable_columns_fall_back_to_defaults() {
            let keyset =
                sql_row_to_keyset_info(keyset_row(power_of_two_amounts(), Column::Null, Column::Null))
                    .expect("row with NULL optional columns should parse");

            assert_eq!(keyset.input_fee_ppk, 0);
            assert_eq!(keyset.final_expiry, None);
            assert_eq!(keyset.issuer_version, None);
        }
    }
}