1use std::collections::HashMap;
4use std::str::FromStr;
5
6use async_trait::async_trait;
7use bitcoin::bip32::DerivationPath;
8use cdk_common::database::{Error, MintKeyDatabaseTransaction, MintKeysDatabase};
9use cdk_common::mint::MintKeySetInfo;
10use cdk_common::{CurrencyUnit, Id};
11
12use super::{SQLMintDatabase, SQLTransaction};
13use crate::database::ConnectionWithTransaction;
14use crate::pool::DatabasePool;
15use crate::stmt::{query, Column};
16use crate::{
17 column_as_nullable_number, column_as_nullable_string, column_as_number, column_as_string,
18 unpack_into,
19};
20
21pub(crate) fn sql_row_to_keyset_info(row: Vec<Column>) -> Result<MintKeySetInfo, Error> {
22 unpack_into!(
23 let (
24 id,
25 unit,
26 active,
27 valid_from,
28 valid_to,
29 derivation_path,
30 derivation_path_index,
31 amounts,
32 row_keyset_ppk
33 ) = row
34 );
35
36 let amounts = column_as_nullable_string!(amounts)
37 .and_then(|str| serde_json::from_str(&str).ok())
38 .ok_or_else(|| Error::Database("amounts field is required".to_string().into()))?;
39
40 Ok(MintKeySetInfo {
41 id: column_as_string!(id, Id::from_str, Id::from_bytes),
42 unit: column_as_string!(unit, CurrencyUnit::from_str),
43 active: matches!(active, Column::Integer(1)),
44 valid_from: column_as_number!(valid_from),
45 derivation_path: column_as_string!(derivation_path, DerivationPath::from_str),
46 derivation_path_index: column_as_nullable_number!(derivation_path_index),
47 amounts,
48 input_fee_ppk: column_as_nullable_number!(row_keyset_ppk).unwrap_or(0),
49 final_expiry: column_as_nullable_number!(valid_to),
50 })
51}
52
53#[async_trait]
54impl<RM> MintKeyDatabaseTransaction<'_, Error> for SQLTransaction<RM>
55where
56 RM: DatabasePool + 'static,
57{
58 async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), Error> {
59 query(
60 r#"
61 INSERT INTO
62 keyset (
63 id, unit, active, valid_from, valid_to, derivation_path,
64 amounts, input_fee_ppk, derivation_path_index
65 )
66 VALUES (
67 :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
68 :amounts, :input_fee_ppk, :derivation_path_index
69 )
70 ON CONFLICT(id) DO UPDATE SET
71 unit = excluded.unit,
72 active = excluded.active,
73 valid_from = excluded.valid_from,
74 valid_to = excluded.valid_to,
75 derivation_path = excluded.derivation_path,
76 amounts = excluded.amounts,
77 input_fee_ppk = excluded.input_fee_ppk,
78 derivation_path_index = excluded.derivation_path_index
79 "#,
80 )?
81 .bind("id", keyset.id.to_string())
82 .bind("unit", keyset.unit.to_string())
83 .bind("active", keyset.active)
84 .bind("valid_from", keyset.valid_from as i64)
85 .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
86 .bind("derivation_path", keyset.derivation_path.to_string())
87 .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
88 .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
89 .bind("derivation_path_index", keyset.derivation_path_index)
90 .execute(&self.inner)
91 .await?;
92
93 Ok(())
94 }
95
96 async fn set_active_keyset(&mut self, unit: CurrencyUnit, id: Id) -> Result<(), Error> {
97 query(r#"UPDATE keyset SET active=FALSE WHERE unit = :unit"#)?
98 .bind("unit", unit.to_string())
99 .execute(&self.inner)
100 .await?;
101
102 query(r#"UPDATE keyset SET active=TRUE WHERE unit = :unit AND id = :id"#)?
103 .bind("unit", unit.to_string())
104 .bind("id", id.to_string())
105 .execute(&self.inner)
106 .await?;
107
108 Ok(())
109 }
110}
111
112#[async_trait]
113impl<RM> MintKeysDatabase for SQLMintDatabase<RM>
114where
115 RM: DatabasePool + 'static,
116{
117 type Err = Error;
118
119 async fn begin_transaction<'a>(
120 &'a self,
121 ) -> Result<Box<dyn MintKeyDatabaseTransaction<'a, Error> + Send + Sync + 'a>, Error> {
122 let tx = SQLTransaction {
123 inner: ConnectionWithTransaction::new(
124 self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
125 )
126 .await?,
127 };
128
129 Ok(Box::new(tx))
130 }
131
132 async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
133 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
134 Ok(
135 query(r#" SELECT id FROM keyset WHERE active = :active AND unit = :unit"#)?
136 .bind("active", true)
137 .bind("unit", unit.to_string())
138 .pluck(&*conn)
139 .await?
140 .map(|id| match id {
141 Column::Text(text) => Ok(Id::from_str(&text)?),
142 Column::Blob(id) => Ok(Id::from_bytes(&id)?),
143 _ => Err(Error::InvalidKeysetId),
144 })
145 .transpose()?,
146 )
147 }
148
149 async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
150 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
151 Ok(
152 query(r#"SELECT id, unit FROM keyset WHERE active = :active"#)?
153 .bind("active", true)
154 .fetch_all(&*conn)
155 .await?
156 .into_iter()
157 .map(|row| {
158 Ok((
159 column_as_string!(&row[1], CurrencyUnit::from_str),
160 column_as_string!(&row[0], Id::from_str, Id::from_bytes),
161 ))
162 })
163 .collect::<Result<HashMap<_, _>, Error>>()?,
164 )
165 }
166
167 async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
168 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
169 Ok(query(
170 r#"SELECT
171 id,
172 unit,
173 active,
174 valid_from,
175 valid_to,
176 derivation_path,
177 derivation_path_index,
178 amounts,
179 input_fee_ppk
180 FROM
181 keyset
182 WHERE id=:id"#,
183 )?
184 .bind("id", id.to_string())
185 .fetch_one(&*conn)
186 .await?
187 .map(sql_row_to_keyset_info)
188 .transpose()?)
189 }
190
191 async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
192 let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
193 Ok(query(
194 r#"SELECT
195 id,
196 unit,
197 active,
198 valid_from,
199 valid_to,
200 derivation_path,
201 derivation_path_index,
202 amounts,
203 input_fee_ppk
204 FROM
205 keyset
206 "#,
207 )?
208 .fetch_all(&*conn)
209 .await?
210 .into_iter()
211 .map(sql_row_to_keyset_info)
212 .collect::<Result<Vec<_>, _>>()?)
213 }
214}
215
#[cfg(test)]
mod test {
    use super::*;

    mod keyset_amounts_tests {
        use super::*;

        /// Builds an otherwise well-formed `keyset` row whose `amounts` column is `amounts`.
        fn keyset_row(amounts: Column) -> Vec<Column> {
            vec![
                Column::Text("0083a60439303340".to_owned()),
                Column::Text("sat".to_owned()),
                Column::Integer(1),
                Column::Integer(1749844864),
                Column::Null,
                Column::Text("0'/0'/0'".to_owned()),
                Column::Integer(0),
                amounts,
                Column::Integer(0),
            ]
        }

        #[test]
        fn keyset_with_amounts() {
            let amounts = (0..32).map(|x| 2u64.pow(x)).collect::<Vec<_>>();
            let json = serde_json::to_string(&amounts).expect("valid json");

            let keyset = sql_row_to_keyset_info(keyset_row(Column::Text(json)))
                .expect("row should decode into a keyset");

            assert_eq!(keyset.amounts.len(), 32);
            assert_eq!(keyset.amounts, amounts);
            assert_eq!(keyset.unit, CurrencyUnit::Sat);
            assert!(keyset.active);
            assert_eq!(keyset.valid_from, 1749844864);
            assert_eq!(keyset.final_expiry, None);
            assert_eq!(keyset.derivation_path_index, Some(0));
            assert_eq!(keyset.input_fee_ppk, 0);
        }

        #[test]
        fn keyset_without_amounts_is_rejected() {
            assert!(sql_row_to_keyset_info(keyset_row(Column::Null)).is_err());
        }

        #[test]
        fn keyset_with_malformed_amounts_is_rejected() {
            let row = keyset_row(Column::Text("not json".to_owned()));
            assert!(sql_row_to_keyset_info(row).is_err());
        }
    }
}