Skip to main content

cdk_sql_common/mint/
keys.rs

1//! Keys database implementation
2
3use std::collections::HashMap;
4use std::str::FromStr;
5
6use async_trait::async_trait;
7use bitcoin::bip32::DerivationPath;
8use cdk_common::database::{Error, MintKeyDatabaseTransaction, MintKeysDatabase};
9use cdk_common::mint::MintKeySetInfo;
10use cdk_common::{CurrencyUnit, Id};
11
12use super::{SQLMintDatabase, SQLTransaction};
13use crate::database::ConnectionWithTransaction;
14use crate::pool::DatabasePool;
15use crate::stmt::{query, Column};
16use crate::{
17    column_as_nullable_number, column_as_nullable_string, column_as_number, column_as_string,
18    unpack_into,
19};
20
21pub(crate) fn sql_row_to_keyset_info(row: Vec<Column>) -> Result<MintKeySetInfo, Error> {
22    unpack_into!(
23        let (
24            id,
25            unit,
26            active,
27            valid_from,
28            valid_to,
29            derivation_path,
30            derivation_path_index,
31            amounts,
32            row_keyset_ppk
33        ) = row
34    );
35
36    let amounts = column_as_nullable_string!(amounts)
37        .and_then(|str| serde_json::from_str(&str).ok())
38        .ok_or_else(|| Error::Database("amounts field is required".to_string().into()))?;
39
40    Ok(MintKeySetInfo {
41        id: column_as_string!(id, Id::from_str, Id::from_bytes),
42        unit: column_as_string!(unit, CurrencyUnit::from_str),
43        active: matches!(active, Column::Integer(1)),
44        valid_from: column_as_number!(valid_from),
45        derivation_path: column_as_string!(derivation_path, DerivationPath::from_str),
46        derivation_path_index: column_as_nullable_number!(derivation_path_index),
47        amounts,
48        input_fee_ppk: column_as_nullable_number!(row_keyset_ppk).unwrap_or(0),
49        final_expiry: column_as_nullable_number!(valid_to),
50    })
51}
52
53#[async_trait]
54impl<RM> MintKeyDatabaseTransaction<'_, Error> for SQLTransaction<RM>
55where
56    RM: DatabasePool + 'static,
57{
58    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), Error> {
59        query(
60            r#"
61        INSERT INTO
62            keyset (
63                id, unit, active, valid_from, valid_to, derivation_path,
64                amounts, input_fee_ppk, derivation_path_index
65            )
66        VALUES (
67            :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
68            :amounts, :input_fee_ppk, :derivation_path_index
69        )
70        ON CONFLICT(id) DO UPDATE SET
71            unit = excluded.unit,
72            active = excluded.active,
73            valid_from = excluded.valid_from,
74            valid_to = excluded.valid_to,
75            derivation_path = excluded.derivation_path,
76            amounts = excluded.amounts,
77            input_fee_ppk = excluded.input_fee_ppk,
78            derivation_path_index = excluded.derivation_path_index
79        "#,
80        )?
81        .bind("id", keyset.id.to_string())
82        .bind("unit", keyset.unit.to_string())
83        .bind("active", keyset.active)
84        .bind("valid_from", keyset.valid_from as i64)
85        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
86        .bind("derivation_path", keyset.derivation_path.to_string())
87        .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
88        .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
89        .bind("derivation_path_index", keyset.derivation_path_index)
90        .execute(&self.inner)
91        .await?;
92
93        Ok(())
94    }
95
96    async fn set_active_keyset(&mut self, unit: CurrencyUnit, id: Id) -> Result<(), Error> {
97        query(r#"UPDATE keyset SET active=FALSE WHERE unit = :unit"#)?
98            .bind("unit", unit.to_string())
99            .execute(&self.inner)
100            .await?;
101
102        query(r#"UPDATE keyset SET active=TRUE WHERE unit = :unit AND id = :id"#)?
103            .bind("unit", unit.to_string())
104            .bind("id", id.to_string())
105            .execute(&self.inner)
106            .await?;
107
108        Ok(())
109    }
110}
111
112#[async_trait]
113impl<RM> MintKeysDatabase for SQLMintDatabase<RM>
114where
115    RM: DatabasePool + 'static,
116{
117    type Err = Error;
118
119    async fn begin_transaction<'a>(
120        &'a self,
121    ) -> Result<Box<dyn MintKeyDatabaseTransaction<'a, Error> + Send + Sync + 'a>, Error> {
122        let tx = SQLTransaction {
123            inner: ConnectionWithTransaction::new(
124                self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
125            )
126            .await?,
127        };
128
129        Ok(Box::new(tx))
130    }
131
132    async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
133        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
134        Ok(
135            query(r#" SELECT id FROM keyset WHERE active = :active AND unit = :unit"#)?
136                .bind("active", true)
137                .bind("unit", unit.to_string())
138                .pluck(&*conn)
139                .await?
140                .map(|id| match id {
141                    Column::Text(text) => Ok(Id::from_str(&text)?),
142                    Column::Blob(id) => Ok(Id::from_bytes(&id)?),
143                    _ => Err(Error::InvalidKeysetId),
144                })
145                .transpose()?,
146        )
147    }
148
149    async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
150        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
151        Ok(
152            query(r#"SELECT id, unit FROM keyset WHERE active = :active"#)?
153                .bind("active", true)
154                .fetch_all(&*conn)
155                .await?
156                .into_iter()
157                .map(|row| {
158                    Ok((
159                        column_as_string!(&row[1], CurrencyUnit::from_str),
160                        column_as_string!(&row[0], Id::from_str, Id::from_bytes),
161                    ))
162                })
163                .collect::<Result<HashMap<_, _>, Error>>()?,
164        )
165    }
166
167    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
168        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
169        Ok(query(
170            r#"SELECT
171                id,
172                unit,
173                active,
174                valid_from,
175                valid_to,
176                derivation_path,
177                derivation_path_index,
178                amounts,
179                input_fee_ppk
180            FROM
181                keyset
182                WHERE id=:id"#,
183        )?
184        .bind("id", id.to_string())
185        .fetch_one(&*conn)
186        .await?
187        .map(sql_row_to_keyset_info)
188        .transpose()?)
189    }
190
191    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
192        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
193        Ok(query(
194            r#"SELECT
195                id,
196                unit,
197                active,
198                valid_from,
199                valid_to,
200                derivation_path,
201                derivation_path_index,
202                amounts,
203                input_fee_ppk
204            FROM
205                keyset
206            "#,
207        )?
208        .fetch_all(&*conn)
209        .await?
210        .into_iter()
211        .map(sql_row_to_keyset_info)
212        .collect::<Result<Vec<_>, _>>()?)
213    }
214}
215
#[cfg(test)]
mod test {
    use super::*;

    mod keyset_amounts_tests {
        use super::*;

        /// Builds a `keyset` row in the column order `sql_row_to_keyset_info` expects,
        /// using `amounts` as the value of the `amounts` column.
        fn keyset_row(amounts: Column) -> Vec<Column> {
            vec![
                Column::Text("0083a60439303340".to_owned()),
                Column::Text("sat".to_owned()),
                Column::Integer(1),
                Column::Integer(1749844864),
                Column::Null,
                Column::Text("0'/0'/0'".to_owned()),
                Column::Integer(0),
                amounts,
                Column::Integer(0),
            ]
        }

        #[test]
        fn keyset_with_amounts() {
            let amounts = (0..32).map(|x| 2u64.pow(x)).collect::<Vec<_>>();
            let json = serde_json::to_string(&amounts).expect("valid json");

            let keyset = sql_row_to_keyset_info(keyset_row(Column::Text(json)))
                .expect("row with valid amounts must decode");

            assert_eq!(keyset.amounts, amounts);
            assert_eq!(keyset.unit, CurrencyUnit::Sat);
            assert!(keyset.active);
            assert_eq!(keyset.valid_from, 1749844864);
            assert_eq!(keyset.final_expiry, None);
            assert_eq!(keyset.derivation_path_index, Some(0));
            assert_eq!(keyset.input_fee_ppk, 0);
        }

        #[test]
        fn keyset_without_amounts_is_rejected() {
            assert!(sql_row_to_keyset_info(keyset_row(Column::Null)).is_err());
        }

        #[test]
        fn keyset_with_malformed_amounts_is_rejected() {
            let row = keyset_row(Column::Text("not json".to_owned()));
            assert!(sql_row_to_keyset_info(row).is_err());
        }
    }
}