Skip to main content

cdk_sql_common/mint/
keys.rs

1//! Keys database implementation
2
3use std::collections::HashMap;
4use std::str::FromStr;
5
6use async_trait::async_trait;
7use bitcoin::bip32::DerivationPath;
8use cdk_common::common::IssuerVersion;
9use cdk_common::database::{Error, MintKeyDatabaseTransaction, MintKeysDatabase};
10use cdk_common::mint::MintKeySetInfo;
11use cdk_common::{CurrencyUnit, Id};
12
13use super::{SQLMintDatabase, SQLTransaction};
14use crate::database::ConnectionWithTransaction;
15use crate::pool::DatabasePool;
16use crate::stmt::{query, Column};
17use crate::{
18    column_as_nullable_number, column_as_nullable_string, column_as_number, column_as_string,
19    unpack_into,
20};
21
22pub(crate) fn sql_row_to_keyset_info(row: Vec<Column>) -> Result<MintKeySetInfo, Error> {
23    unpack_into!(
24        let (
25            id,
26            unit,
27            active,
28            valid_from,
29            valid_to,
30            derivation_path,
31            derivation_path_index,
32            amounts,
33            row_keyset_ppk,
34            issuer_version
35        ) = row
36    );
37
38    let amounts = column_as_nullable_string!(amounts)
39        .and_then(|str| serde_json::from_str(&str).ok())
40        .ok_or_else(|| Error::Database("amounts field is required".to_string().into()))?;
41
42    Ok(MintKeySetInfo {
43        id: column_as_string!(id, Id::from_str, Id::from_bytes),
44        unit: column_as_string!(unit, CurrencyUnit::from_str),
45        active: matches!(active, Column::Integer(1)),
46        valid_from: column_as_number!(valid_from),
47        derivation_path: column_as_string!(derivation_path, DerivationPath::from_str),
48        derivation_path_index: column_as_nullable_number!(derivation_path_index),
49        amounts,
50        input_fee_ppk: column_as_nullable_number!(row_keyset_ppk).unwrap_or(0),
51        final_expiry: column_as_nullable_number!(valid_to),
52        issuer_version: column_as_nullable_string!(issuer_version).and_then(|v| {
53            match IssuerVersion::from_str(&v) {
54                Ok(ver) => Some(ver),
55                Err(e) => {
56                    tracing::warn!(
57                        "Failed to parse issuer_version from database: {}. Error: {}",
58                        v,
59                        e
60                    );
61                    None
62                }
63            }
64        }),
65    })
66}
67
68#[async_trait]
69impl<RM> MintKeyDatabaseTransaction<'_, Error> for SQLTransaction<RM>
70where
71    RM: DatabasePool + 'static,
72{
73    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), Error> {
74        query(
75            r#"
76        INSERT INTO
77            keyset (
78                id, unit, active, valid_from, valid_to, derivation_path,
79                amounts, input_fee_ppk, derivation_path_index, issuer_version
80            )
81        VALUES (
82            :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
83            :amounts, :input_fee_ppk, :derivation_path_index, :issuer_version
84        )
85        ON CONFLICT(id) DO UPDATE SET
86            unit = excluded.unit,
87            active = excluded.active,
88            valid_from = excluded.valid_from,
89            valid_to = excluded.valid_to,
90            derivation_path = excluded.derivation_path,
91            amounts = excluded.amounts,
92            input_fee_ppk = excluded.input_fee_ppk,
93            derivation_path_index = excluded.derivation_path_index,
94            issuer_version = excluded.issuer_version
95        "#,
96        )?
97        .bind("id", keyset.id.to_string())
98        .bind("unit", keyset.unit.to_string())
99        .bind("active", keyset.active)
100        .bind("valid_from", keyset.valid_from as i64)
101        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
102        .bind("derivation_path", keyset.derivation_path.to_string())
103        .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
104        .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
105        .bind("derivation_path_index", keyset.derivation_path_index)
106        .bind(
107            "issuer_version",
108            keyset.issuer_version.map(|v| v.to_string()),
109        )
110        .execute(&self.inner)
111        .await?;
112
113        Ok(())
114    }
115
116    async fn set_active_keyset(&mut self, unit: CurrencyUnit, id: Id) -> Result<(), Error> {
117        query(r#"UPDATE keyset SET active=FALSE WHERE unit = :unit"#)?
118            .bind("unit", unit.to_string())
119            .execute(&self.inner)
120            .await?;
121
122        query(r#"UPDATE keyset SET active=TRUE WHERE unit = :unit AND id = :id"#)?
123            .bind("unit", unit.to_string())
124            .bind("id", id.to_string())
125            .execute(&self.inner)
126            .await?;
127
128        Ok(())
129    }
130}
131
132#[async_trait]
133impl<RM> MintKeysDatabase for SQLMintDatabase<RM>
134where
135    RM: DatabasePool + 'static,
136{
137    type Err = Error;
138
139    async fn begin_transaction<'a>(
140        &'a self,
141    ) -> Result<Box<dyn MintKeyDatabaseTransaction<'a, Error> + Send + Sync + 'a>, Error> {
142        let tx = SQLTransaction {
143            inner: ConnectionWithTransaction::new(
144                self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
145            )
146            .await?,
147        };
148
149        Ok(Box::new(tx))
150    }
151
152    async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
153        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
154        Ok(
155            query(r#" SELECT id FROM keyset WHERE active = :active AND unit = :unit"#)?
156                .bind("active", true)
157                .bind("unit", unit.to_string())
158                .pluck(&*conn)
159                .await?
160                .map(|id| match id {
161                    Column::Text(text) => Ok(Id::from_str(&text)?),
162                    Column::Blob(id) => Ok(Id::from_bytes(&id)?),
163                    _ => Err(Error::InvalidKeysetId),
164                })
165                .transpose()?,
166        )
167    }
168
169    async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
170        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
171        Ok(
172            query(r#"SELECT id, unit FROM keyset WHERE active = :active"#)?
173                .bind("active", true)
174                .fetch_all(&*conn)
175                .await?
176                .into_iter()
177                .map(|row| {
178                    Ok((
179                        column_as_string!(&row[1], CurrencyUnit::from_str),
180                        column_as_string!(&row[0], Id::from_str, Id::from_bytes),
181                    ))
182                })
183                .collect::<Result<HashMap<_, _>, Error>>()?,
184        )
185    }
186
187    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
188        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
189        Ok(query(
190            r#"SELECT
191                id,
192                unit,
193                active,
194                valid_from,
195                valid_to,
196                derivation_path,
197                derivation_path_index,
198                amounts,
199                input_fee_ppk,
200                issuer_version
201            FROM
202                keyset
203                WHERE id=:id"#,
204        )?
205        .bind("id", id.to_string())
206        .fetch_one(&*conn)
207        .await?
208        .map(sql_row_to_keyset_info)
209        .transpose()?)
210    }
211
212    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
213        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
214        Ok(query(
215            r#"SELECT
216                id,
217                unit,
218                active,
219                valid_from,
220                valid_to,
221                derivation_path,
222                derivation_path_index,
223                amounts,
224                input_fee_ppk,
225                issuer_version
226            FROM
227                keyset
228            "#,
229        )?
230        .fetch_all(&*conn)
231        .await?
232        .into_iter()
233        .map(sql_row_to_keyset_info)
234        .collect::<Result<Vec<_>, _>>()?)
235    }
236}
237
#[cfg(test)]
mod test {
    use super::*;

    mod keyset_amounts_tests {
        use super::*;

        /// Builds a `keyset` row in the column order expected by `sql_row_to_keyset_info`.
        /// Only `amounts`, `input_fee_ppk`, and `issuer_version` vary between tests.
        fn keyset_row(
            amounts: Column,
            input_fee_ppk: Column,
            issuer_version: Column,
        ) -> Vec<Column> {
            vec![
                Column::Text("0083a60439303340".to_owned()),
                Column::Text("sat".to_owned()),
                Column::Integer(1),
                Column::Integer(1749844864),
                Column::Null,
                Column::Text("0'/0'/0'".to_owned()),
                Column::Integer(0),
                amounts,
                input_fee_ppk,
                issuer_version,
            ]
        }

        /// JSON-encoded powers of two, 2^0 through 2^31 (32 amounts).
        fn power_of_two_amounts() -> Column {
            let amounts = (0..32).map(|x| 2u64.pow(x)).collect::<Vec<_>>();
            Column::Text(serde_json::to_string(&amounts).expect("valid json"))
        }

        #[test]
        fn keyset_with_amounts() {
            let keyset = sql_row_to_keyset_info(keyset_row(
                power_of_two_amounts(),
                Column::Integer(0),
                Column::Text("cdk/0.1.0".to_owned()),
            ))
            .expect("row with amounts should convert");
            assert_eq!(keyset.amounts.len(), 32);
            assert_eq!(
                keyset.issuer_version,
                Some(IssuerVersion::from_str("cdk/0.1.0").unwrap())
            );
        }

        #[test]
        fn nullable_columns_fall_back_to_defaults() {
            let keyset = sql_row_to_keyset_info(keyset_row(
                power_of_two_amounts(),
                Column::Null,
                Column::Null,
            ))
            .expect("row with NULL optional columns should convert");
            assert!(keyset.active);
            assert_eq!(keyset.input_fee_ppk, 0);
            assert_eq!(keyset.final_expiry, None);
            assert!(keyset.issuer_version.is_none());
        }

        #[test]
        fn missing_amounts_is_an_error() {
            let result =
                sql_row_to_keyset_info(keyset_row(Column::Null, Column::Integer(0), Column::Null));
            assert!(result.is_err());
        }

        #[test]
        fn malformed_amounts_is_an_error() {
            let result = sql_row_to_keyset_info(keyset_row(
                Column::Text("not json".to_owned()),
                Column::Integer(0),
                Column::Null,
            ));
            assert!(result.is_err());
        }
    }
}