Skip to main content

cdk_sql_common/mint/
keys.rs

1//! Keys database implementation
2
3use std::collections::HashMap;
4use std::str::FromStr;
5
6use async_trait::async_trait;
7use bitcoin::bip32::DerivationPath;
8use cdk_common::common::IssuerVersion;
9use cdk_common::database::{Error, MintKeyDatabaseTransaction, MintKeysDatabase};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::{CurrencyUnit, Id};
12
13use super::{SQLMintDatabase, SQLTransaction};
14use crate::database::ConnectionWithTransaction;
15use crate::pool::DatabasePool;
16use crate::stmt::{query, Column};
17use crate::{
18    column_as_nullable_number, column_as_nullable_string, column_as_number, column_as_string,
19    unpack_into,
20};
21
22pub(crate) fn sql_row_to_keyset_info(row: Vec<Column>) -> Result<MintKeySetInfo, Error> {
23    unpack_into!(
24        let (
25            id,
26            unit,
27            active,
28            valid_from,
29            valid_to,
30            derivation_path,
31            derivation_path_index,
32            amounts,
33            row_keyset_ppk,
34            issuer_version
35        ) = row
36    );
37
38    let amounts = column_as_nullable_string!(amounts)
39        .and_then(|str| serde_json::from_str(&str).ok())
40        .ok_or_else(|| Error::Database("amounts field is required".to_string().into()))?;
41
42    Ok(MintKeySetInfo {
43        id: column_as_string!(id, Id::from_str, Id::from_bytes),
44        unit: column_as_string!(unit, CurrencyUnit::from_str),
45        active: matches!(active, Column::Integer(1)),
46        valid_from: column_as_number!(valid_from),
47        derivation_path: column_as_string!(derivation_path, DerivationPath::from_str),
48        derivation_path_index: column_as_nullable_number!(derivation_path_index),
49        amounts,
50        input_fee_ppk: column_as_nullable_number!(row_keyset_ppk).unwrap_or(0),
51        final_expiry: column_as_nullable_number!(valid_to),
52        issuer_version: column_as_nullable_string!(issuer_version).and_then(|v| {
53            match IssuerVersion::from_str(&v) {
54                Ok(ver) => Some(ver),
55                Err(e) => {
56                    tracing::warn!(
57                        "Failed to parse issuer_version from database: {}. Error: {}",
58                        v,
59                        e
60                    );
61                    None
62                }
63            }
64        }),
65    })
66}
67
68#[async_trait]
69impl<RM> MintKeyDatabaseTransaction<'_, Error> for SQLTransaction<RM>
70where
71    RM: DatabasePool + 'static,
72{
73    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), Error> {
74        query(
75            r#"
76        INSERT INTO
77            keyset (
78                id, unit, active, valid_from, valid_to, derivation_path,
79                amounts, input_fee_ppk, derivation_path_index, issuer_version
80            )
81        VALUES (
82            :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
83            :amounts, :input_fee_ppk, :derivation_path_index, :issuer_version
84        )
85        ON CONFLICT(id) DO UPDATE SET
86            unit = excluded.unit,
87            active = excluded.active,
88            valid_from = excluded.valid_from,
89            valid_to = excluded.valid_to,
90            derivation_path = excluded.derivation_path,
91            amounts = excluded.amounts,
92            input_fee_ppk = excluded.input_fee_ppk,
93            derivation_path_index = excluded.derivation_path_index,
94            issuer_version = excluded.issuer_version
95        "#,
96        )?
97        .bind("id", keyset.id.to_string())
98        .bind("unit", keyset.unit.to_string())
99        .bind("active", keyset.active)
100        .bind("valid_from", keyset.valid_from as i64)
101        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
102        .bind("derivation_path", keyset.derivation_path.to_string())
103        .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
104        .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
105        .bind("derivation_path_index", keyset.derivation_path_index)
106        .bind(
107            "issuer_version",
108            keyset.issuer_version.map(|v| v.to_string()),
109        )
110        .execute(&self.inner)
111        .await?;
112
113        Ok(())
114    }
115
116    async fn set_active_keyset(&mut self, unit: CurrencyUnit, id: Id) -> Result<(), Error> {
117        query(r#"UPDATE keyset SET active=FALSE WHERE unit = :unit"#)?
118            .bind("unit", unit.to_string())
119            .execute(&self.inner)
120            .await?;
121
122        query(r#"UPDATE keyset SET active=TRUE WHERE unit = :unit AND id = :id"#)?
123            .bind("unit", unit.to_string())
124            .bind("id", id.to_string())
125            .execute(&self.inner)
126            .await?;
127
128        Ok(())
129    }
130}
131
132#[async_trait]
133impl<RM> MintKeysDatabase for SQLMintDatabase<RM>
134where
135    RM: DatabasePool + 'static,
136{
137    type Err = Error;
138
139    async fn begin_transaction<'a>(
140        &'a self,
141    ) -> Result<Box<dyn MintKeyDatabaseTransaction<'a, Error> + Send + Sync + 'a>, Error> {
142        let tx = SQLTransaction {
143            inner: ConnectionWithTransaction::new(
144                self.pool
145                    .get()
146                    .await
147                    .map_err(|e| Error::Database(Box::new(e)))?,
148            )
149            .await?,
150        };
151
152        Ok(Box::new(tx))
153    }
154
155    async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
156        let conn = self
157            .pool
158            .get()
159            .await
160            .map_err(|e| Error::Database(Box::new(e)))?;
161        Ok(
162            query(r#" SELECT id FROM keyset WHERE active = :active AND unit = :unit"#)?
163                .bind("active", true)
164                .bind("unit", unit.to_string())
165                .pluck(&*conn)
166                .await?
167                .map(|id| match id {
168                    Column::Text(text) => Ok(Id::from_str(&text)?),
169                    Column::Blob(id) => Ok(Id::from_bytes(&id)?),
170                    _ => Err(Error::InvalidKeysetId),
171                })
172                .transpose()?,
173        )
174    }
175
176    async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
177        let conn = self
178            .pool
179            .get()
180            .await
181            .map_err(|e| Error::Database(Box::new(e)))?;
182        Ok(
183            query(r#"SELECT id, unit FROM keyset WHERE active = :active"#)?
184                .bind("active", true)
185                .fetch_all(&*conn)
186                .await?
187                .into_iter()
188                .map(|row| {
189                    Ok((
190                        column_as_string!(&row[1], CurrencyUnit::from_str),
191                        column_as_string!(&row[0], Id::from_str, Id::from_bytes),
192                    ))
193                })
194                .collect::<Result<HashMap<_, _>, Error>>()?,
195        )
196    }
197
198    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
199        let conn = self
200            .pool
201            .get()
202            .await
203            .map_err(|e| Error::Database(Box::new(e)))?;
204        Ok(query(
205            r#"SELECT
206                id,
207                unit,
208                active,
209                valid_from,
210                valid_to,
211                derivation_path,
212                derivation_path_index,
213                amounts,
214                input_fee_ppk,
215                issuer_version
216            FROM
217                keyset
218                WHERE id=:id"#,
219        )?
220        .bind("id", id.to_string())
221        .fetch_one(&*conn)
222        .await?
223        .map(sql_row_to_keyset_info)
224        .transpose()?)
225    }
226
227    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
228        let conn = self
229            .pool
230            .get()
231            .await
232            .map_err(|e| Error::Database(Box::new(e)))?;
233        Ok(query(
234            r#"SELECT
235                id,
236                unit,
237                active,
238                valid_from,
239                valid_to,
240                derivation_path,
241                derivation_path_index,
242                amounts,
243                input_fee_ppk,
244                issuer_version
245            FROM
246                keyset
247            "#,
248        )?
249        .fetch_all(&*conn)
250        .await?
251        .into_iter()
252        .map(sql_row_to_keyset_info)
253        .collect::<Result<Vec<_>, _>>()?)
254    }
255}
256
#[cfg(test)]
mod test {
    use super::*;

    mod keyset_amounts_tests {
        use super::*;

        /// A fully populated row, with `amounts` stored as a JSON array of the
        /// first 32 powers of two, converts into a keyset that carries all 32
        /// amounts and the parsed issuer version.
        #[test]
        fn keyset_with_amounts() {
            // 1, 2, 4, ..., 2^31
            let powers_of_two: Vec<u64> = (0..32u32).map(|exp| 1u64 << exp).collect();
            let amounts_json = serde_json::to_string(&powers_of_two).expect("valid json");

            // Column order mirrors the SELECT statements used by this module.
            let row = vec![
                Column::Text("0083a60439303340".to_owned()),
                Column::Text("sat".to_owned()),
                Column::Integer(1),
                Column::Integer(1749844864),
                Column::Null,
                Column::Text("0'/0'/0'".to_owned()),
                Column::Integer(0),
                Column::Text(amounts_json),
                Column::Integer(0),
                Column::Text("cdk/0.1.0".to_owned()),
            ];

            let parsed = sql_row_to_keyset_info(row);
            assert!(parsed.is_ok());

            let keyset = parsed.unwrap();
            let expected_version = IssuerVersion::from_str("cdk/0.1.0").unwrap();
            assert_eq!(keyset.amounts.len(), 32);
            assert_eq!(keyset.issuer_version, Some(expected_version));
        }
    }
}