use core::fmt::{Debug, Formatter};
use sqlx::pool::PoolOptions;
use sqlx::sqlite::SqliteConnectOptions;
use std::ops::Deref;
use std::path::Path;
use ockam_core::errcode::{Kind, Origin};
use sqlx::{ConnectOptions, SqlitePool};
use tokio_retry::strategy::{jitter, FixedInterval};
use tokio_retry::Retry;
use tracing::debug;
use tracing::log::LevelFilter;
use crate::database::migrations::application_migration_set::ApplicationMigrationSet;
use crate::database::migrations::node_migration_set::NodeMigrationSet;
use crate::database::migrations::MigrationSet;
use ockam_core::compat::sync::Arc;
use ockam_core::{Error, Result};
/// A cheaply-cloneable handle to a SQLite connection pool.
///
/// The pool is wrapped in an `Arc`, so cloning this struct shares the same
/// underlying pool rather than opening new connections.
#[derive(Clone)]
pub struct SqlxDatabase {
/// Shared sqlx connection pool backing all queries issued through this handle.
pub pool: Arc<SqlitePool>,
}
impl Debug for SqlxDatabase {
    /// Display the connection options of the underlying pool.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // Write directly into the formatter instead of allocating an
        // intermediate String with `format!` + `write_str`.
        write!(f, "database options {:?}", self.pool.connect_options())
    }
}
impl Deref for SqlxDatabase {
    type Target = SqlitePool;

    /// Borrow the pooled connection handle held inside the `Arc`.
    fn deref(&self) -> &Self::Target {
        self.pool.as_ref()
    }
}
impl SqlxDatabase {
pub async fn create(path: impl AsRef<Path>) -> Result<Self> {
Self::create_impl(path, Some(NodeMigrationSet)).await
}
pub async fn create_with_migration(
path: impl AsRef<Path>,
migration_set: impl MigrationSet,
) -> Result<Self> {
Self::create_impl(path, Some(migration_set)).await
}
pub async fn create_no_migration(path: impl AsRef<Path>) -> Result<Self> {
Self::create_impl(path, None::<NodeMigrationSet>).await
}
async fn create_impl(
path: impl AsRef<Path>,
migration_set: Option<impl MigrationSet>,
) -> Result<Self> {
path.as_ref()
.parent()
.map(std::fs::create_dir_all)
.transpose()
.map_err(|e| Error::new(Origin::Api, Kind::Io, e.to_string()))?;
let retry_strategy = FixedInterval::from_millis(1000)
.map(jitter) .take(10); let db = Retry::spawn(retry_strategy, || async {
Self::create_at(path.as_ref()).await
})
.await?;
if let Some(migration_set) = migration_set {
let migrator = migration_set.create_migrator()?;
migrator.migrate(&db.pool).await?;
}
Ok(db)
}
pub async fn in_memory(usage: &str) -> Result<Self> {
Self::in_memory_with_migration(usage, NodeMigrationSet).await
}
pub async fn application_in_memory(usage: &str) -> Result<Self> {
Self::in_memory_with_migration(usage, ApplicationMigrationSet).await
}
pub async fn in_memory_with_migration(
usage: &str,
migration_set: impl MigrationSet,
) -> Result<Self> {
debug!("create an in memory database for {usage}");
let pool = Self::create_in_memory_connection_pool().await?;
let migrator = migration_set.create_migrator()?;
migrator.migrate(&pool).await?;
let db = SqlxDatabase {
pool: Arc::new(pool),
};
Ok(db)
}
async fn create_at(path: &Path) -> Result<Self> {
let pool = Self::create_connection_pool(path).await?;
Ok(SqlxDatabase {
pool: Arc::new(pool),
})
}
pub(crate) async fn create_connection_pool(path: &Path) -> Result<SqlitePool> {
let options = SqliteConnectOptions::new()
.filename(path)
.create_if_missing(true)
.log_statements(LevelFilter::Debug);
let pool = SqlitePool::connect_with(options)
.await
.map_err(Self::map_sql_err)?;
Ok(pool)
}
pub(crate) async fn create_in_memory_connection_pool() -> Result<SqlitePool> {
let pool_options = PoolOptions::new().idle_timeout(None).max_lifetime(None);
let pool = pool_options
.connect("sqlite::memory:")
.await
.map_err(Self::map_sql_err)?;
Ok(pool)
}
#[track_caller]
pub fn map_sql_err(err: sqlx::Error) -> Error {
Error::new(Origin::Application, Kind::Io, err)
}
#[track_caller]
pub fn map_decode_err(err: minicbor::decode::Error) -> Error {
Error::new(Origin::Application, Kind::Io, err)
}
}
/// Conversion from a sqlx result into an `ockam_core::Result`.
pub trait FromSqlxError<T> {
/// Convert the error side of the result into an `ockam_core::Error`.
fn into_core(self) -> Result<T>;
}
impl<T> FromSqlxError<T> for core::result::Result<T, sqlx::error::Error> {
#[track_caller]
fn into_core(self) -> Result<T> {
match self {
Ok(r) => Ok(r),
Err(err) => {
let err = Error::new(Origin::Api, Kind::Internal, err.to_string());
Err(err)
}
}
}
}
impl<T> FromSqlxError<T> for core::result::Result<T, sqlx::migrate::MigrateError> {
#[track_caller]
fn into_core(self) -> Result<T> {
match self {
Ok(r) => Ok(r),
Err(err) => Err(Error::new(
Origin::Application,
Kind::Io,
format!("migration error {err}"),
)),
}
}
}
/// Discard the success value of a result while converting its error.
pub trait ToVoid<T> {
/// Map `Ok(_)` to `Ok(())` and convert the error into an `ockam_core::Error`.
fn void(self) -> Result<()>;
}
impl<T> ToVoid<T> for core::result::Result<T, sqlx::error::Error> {
    /// Drop the success value, converting any sqlx error via `into_core`.
    #[track_caller]
    fn void(self) -> Result<()> {
        // Convert the error side first, then discard the success value.
        self.into_core().map(|_| ())
    }
}
#[cfg(test)]
mod tests {
    use sqlx::sqlite::SqliteQueryResult;
    use sqlx::FromRow;
    use tempfile::NamedTempFile;

    use crate::database::ToSqlxType;

    use super::*;

    /// The `identity` table must accept an insert after migrations have run.
    #[tokio::test]
    async fn test_create_identity_table() -> Result<()> {
        let file = NamedTempFile::new().unwrap();
        let database = SqlxDatabase::create(file.path()).await?;

        let outcome = insert_identity(&database).await.unwrap();
        assert_eq!(outcome.rows_affected(), 1);
        Ok(())
    }

    /// An inserted identity can be found by identifier, and a bogus
    /// identifier yields no row.
    #[tokio::test]
    async fn test_query() -> Result<()> {
        let file = NamedTempFile::new().unwrap();
        let database = SqlxDatabase::create(file.path()).await?;
        insert_identity(&database).await.unwrap();

        // Querying with the inserted identifier returns the row.
        let found: Option<IdentifierRow> =
            sqlx::query_as("SELECT identifier FROM identity WHERE identifier=?1")
                .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
                .fetch_optional(&*database.pool)
                .await
                .unwrap();
        assert_eq!(
            found,
            Some(IdentifierRow(
                "Ifa804b7fca12a19eed206ae180b5b576860ae651".into()
            ))
        );

        // Querying with an unknown identifier returns nothing.
        let missing: Option<IdentifierRow> =
            sqlx::query_as("SELECT identifier FROM identity WHERE identifier=?1")
                .bind("x")
                .fetch_optional(&*database.pool)
                .await
                .unwrap();
        assert_eq!(missing, None);
        Ok(())
    }

    /// Insert a fixed test identity row into the `identity` table.
    async fn insert_identity(db: &SqlxDatabase) -> Result<SqliteQueryResult> {
        sqlx::query("INSERT INTO identity VALUES (?1, ?2)")
            .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
            .bind("123".to_sql())
            .execute(&*db.pool)
            .await
            .into_core()
    }

    /// Row type for reading back just the identifier column.
    #[derive(FromRow, PartialEq, Eq, Debug)]
    struct IdentifierRow(String);
}