cdk-sql-common 0.16.0-rc.0

Generic SQL storage backend for CDK
Documentation
//! Keys database implementation

use std::collections::HashMap;
use std::str::FromStr;

use async_trait::async_trait;
use bitcoin::bip32::DerivationPath;
use cdk_common::common::IssuerVersion;
use cdk_common::database::{Error, MintKeyDatabaseTransaction, MintKeysDatabase};
use cdk_common::mint::MintKeySetInfo;
use cdk_common::{CurrencyUnit, Id};

use super::{SQLMintDatabase, SQLTransaction};
use crate::database::ConnectionWithTransaction;
use crate::pool::DatabasePool;
use crate::stmt::{query, Column};
use crate::{
    column_as_nullable_number, column_as_nullable_string, column_as_number, column_as_string,
    unpack_into,
};

pub(crate) fn sql_row_to_keyset_info(row: Vec<Column>) -> Result<MintKeySetInfo, Error> {
    unpack_into!(
        let (
            id,
            unit,
            active,
            valid_from,
            valid_to,
            derivation_path,
            derivation_path_index,
            amounts,
            row_keyset_ppk,
            issuer_version
        ) = row
    );

    let amounts = column_as_nullable_string!(amounts)
        .and_then(|str| serde_json::from_str(&str).ok())
        .ok_or_else(|| Error::Database("amounts field is required".to_string().into()))?;

    Ok(MintKeySetInfo {
        id: column_as_string!(id, Id::from_str, Id::from_bytes),
        unit: column_as_string!(unit, CurrencyUnit::from_str),
        active: matches!(active, Column::Integer(1)),
        valid_from: column_as_number!(valid_from),
        derivation_path: column_as_string!(derivation_path, DerivationPath::from_str),
        derivation_path_index: column_as_nullable_number!(derivation_path_index),
        amounts,
        input_fee_ppk: column_as_nullable_number!(row_keyset_ppk).unwrap_or(0),
        final_expiry: column_as_nullable_number!(valid_to),
        issuer_version: column_as_nullable_string!(issuer_version).and_then(|v| {
            match IssuerVersion::from_str(&v) {
                Ok(ver) => Some(ver),
                Err(e) => {
                    tracing::warn!(
                        "Failed to parse issuer_version from database: {}. Error: {}",
                        v,
                        e
                    );
                    None
                }
            }
        }),
    })
}

#[async_trait]
impl<RM> MintKeyDatabaseTransaction<'_, Error> for SQLTransaction<RM>
where
    RM: DatabasePool + 'static,
{
    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), Error> {
        query(
            r#"
        INSERT INTO
            keyset (
                id, unit, active, valid_from, valid_to, derivation_path,
                amounts, input_fee_ppk, derivation_path_index, issuer_version
            )
        VALUES (
            :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
            :amounts, :input_fee_ppk, :derivation_path_index, :issuer_version
        )
        ON CONFLICT(id) DO UPDATE SET
            unit = excluded.unit,
            active = excluded.active,
            valid_from = excluded.valid_from,
            valid_to = excluded.valid_to,
            derivation_path = excluded.derivation_path,
            amounts = excluded.amounts,
            input_fee_ppk = excluded.input_fee_ppk,
            derivation_path_index = excluded.derivation_path_index,
            issuer_version = excluded.issuer_version
        "#,
        )?
        .bind("id", keyset.id.to_string())
        .bind("unit", keyset.unit.to_string())
        .bind("active", keyset.active)
        .bind("valid_from", keyset.valid_from as i64)
        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
        .bind("derivation_path", keyset.derivation_path.to_string())
        .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
        .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
        .bind("derivation_path_index", keyset.derivation_path_index)
        .bind(
            "issuer_version",
            keyset.issuer_version.map(|v| v.to_string()),
        )
        .execute(&self.inner)
        .await?;

        Ok(())
    }

    async fn set_active_keyset(&mut self, unit: CurrencyUnit, id: Id) -> Result<(), Error> {
        query(r#"UPDATE keyset SET active=FALSE WHERE unit = :unit"#)?
            .bind("unit", unit.to_string())
            .execute(&self.inner)
            .await?;

        query(r#"UPDATE keyset SET active=TRUE WHERE unit = :unit AND id = :id"#)?
            .bind("unit", unit.to_string())
            .bind("id", id.to_string())
            .execute(&self.inner)
            .await?;

        Ok(())
    }
}

#[async_trait]
impl<RM> MintKeysDatabase for SQLMintDatabase<RM>
where
    RM: DatabasePool + 'static,
{
    type Err = Error;

    async fn begin_transaction<'a>(
        &'a self,
    ) -> Result<Box<dyn MintKeyDatabaseTransaction<'a, Error> + Send + Sync + 'a>, Error> {
        let tx = SQLTransaction {
            inner: ConnectionWithTransaction::new(
                self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
            )
            .await?,
        };

        Ok(Box::new(tx))
    }

    async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
        Ok(
            query(r#" SELECT id FROM keyset WHERE active = :active AND unit = :unit"#)?
                .bind("active", true)
                .bind("unit", unit.to_string())
                .pluck(&*conn)
                .await?
                .map(|id| match id {
                    Column::Text(text) => Ok(Id::from_str(&text)?),
                    Column::Blob(id) => Ok(Id::from_bytes(&id)?),
                    _ => Err(Error::InvalidKeysetId),
                })
                .transpose()?,
        )
    }

    async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
        Ok(
            query(r#"SELECT id, unit FROM keyset WHERE active = :active"#)?
                .bind("active", true)
                .fetch_all(&*conn)
                .await?
                .into_iter()
                .map(|row| {
                    Ok((
                        column_as_string!(&row[1], CurrencyUnit::from_str),
                        column_as_string!(&row[0], Id::from_str, Id::from_bytes),
                    ))
                })
                .collect::<Result<HashMap<_, _>, Error>>()?,
        )
    }

    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
        Ok(query(
            r#"SELECT
                id,
                unit,
                active,
                valid_from,
                valid_to,
                derivation_path,
                derivation_path_index,
                amounts,
                input_fee_ppk,
                issuer_version
            FROM
                keyset
                WHERE id=:id"#,
        )?
        .bind("id", id.to_string())
        .fetch_one(&*conn)
        .await?
        .map(sql_row_to_keyset_info)
        .transpose()?)
    }

    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
        Ok(query(
            r#"SELECT
                id,
                unit,
                active,
                valid_from,
                valid_to,
                derivation_path,
                derivation_path_index,
                amounts,
                input_fee_ppk,
                issuer_version
            FROM
                keyset
            "#,
        )?
        .fetch_all(&*conn)
        .await?
        .into_iter()
        .map(sql_row_to_keyset_info)
        .collect::<Result<Vec<_>, _>>()?)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    mod keyset_amounts_tests {
        use super::*;

        #[test]
        fn keyset_with_amounts() {
            let amounts = (0..32).map(|x| 2u64.pow(x)).collect::<Vec<_>>();
            let result = sql_row_to_keyset_info(vec![
                Column::Text("0083a60439303340".to_owned()),
                Column::Text("sat".to_owned()),
                Column::Integer(1),
                Column::Integer(1749844864),
                Column::Null,
                Column::Text("0'/0'/0'".to_owned()),
                Column::Integer(0),
                Column::Text(serde_json::to_string(&amounts).expect("valid json")),
                Column::Integer(0),
                Column::Text("cdk/0.1.0".to_owned()),
            ]);
            assert!(result.is_ok());
            let keyset = result.unwrap();
            assert_eq!(keyset.amounts.len(), 32);
            assert_eq!(
                keyset.issuer_version,
                Some(IssuerVersion::from_str("cdk/0.1.0").unwrap())
            );
        }
    }
}