use std::sync::Arc;
use p2panda_core::cbor::EncodeError;
use sqlx::migrate::{MigrateDatabase, Migrator};
use sqlx::sqlite::SqlitePoolOptions;
use sqlx::{Sqlite, migrate};
use thiserror::Error;
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
/// Creates the SQLite database behind `url` unless it already exists.
///
/// # Errors
///
/// Any `sqlx::Error` raised while probing for or creating the database is
/// returned as [`SqliteError::Sqlite`].
pub async fn create_database(url: &str) -> Result<(), SqliteError> {
    let already_there = Sqlite::database_exists(url).await?;
    if already_there {
        return Ok(());
    }
    Sqlite::create_database(url).await?;
    Ok(())
}
pub async fn drop_database(url: &str) -> Result<(), SqliteError> {
if Sqlite::database_exists(url).await? {
Sqlite::drop_database(url).await?
}
Ok(())
}
pub async fn connection_pool(
url: &str,
max_connections: u32,
) -> Result<sqlx::SqlitePool, SqliteError> {
let pool: sqlx::SqlitePool = SqlitePoolOptions::new()
.max_connections(max_connections)
.connect(url)
.await?;
Ok(pool)
}
pub fn migrations() -> Migrator {
migrate!()
}
pub async fn run_pending_migrations(pool: &sqlx::SqlitePool) -> Result<(), SqliteError> {
migrations().run(pool).await?;
Ok(())
}
/// Configuration for opening a [`SqliteStore`].
///
/// Start from [`SqliteStoreBuilder::new`] (or `Default`), adjust it with the
/// builder methods and finish with `build`. The builder is `Clone`, so one
/// configuration can serve as a template for several stores. It is `Debug`, so
/// it can be logged.
#[derive(Clone, Debug)]
pub struct SqliteStoreBuilder {
    /// sqlx connection URL, e.g. `sqlite::memory:` (the default).
    url: String,
    /// Upper bound on the number of pooled connections.
    max_connections: u32,
    /// Whether `build` applies pending migrations after connecting.
    run_migrations: bool,
    /// Whether `build` first creates the database if it does not exist.
    create_database: bool,
}
impl Default for SqliteStoreBuilder {
    /// In-memory database (`sqlite::memory:`), up to 16 pooled connections, with
    /// both database creation and migrations enabled.
    fn default() -> Self {
        let url = String::from("sqlite::memory:");
        Self {
            url,
            max_connections: 16,
            run_migrations: true,
            create_database: true,
        }
    }
}
impl SqliteStoreBuilder {
pub fn new() -> Self {
Self::default()
}
#[cfg(any(test, feature = "test_utils"))]
pub fn random_memory_url(mut self) -> Self {
self.url = format!(
"sqlite://dbmem{}?mode=memory&cache=private",
rand::random::<u32>()
);
self
}
pub fn database_url(mut self, url: &str) -> Self {
self.url = url.to_string();
self
}
pub fn max_connections(mut self, max_connections: u32) -> Self {
self.max_connections = max_connections;
self
}
pub fn create_database(mut self, create_database: bool) -> Self {
self.create_database = create_database;
self
}
pub fn run_default_migrations(mut self, run_migrations: bool) -> Self {
self.run_migrations = run_migrations;
self
}
pub async fn build(self) -> Result<SqliteStore, SqliteError> {
if self.create_database {
create_database(&self.url).await?;
}
let pool: sqlx::SqlitePool = SqlitePoolOptions::new()
.max_connections(self.max_connections)
.connect(&self.url)
.await?;
if self.run_migrations {
run_pending_migrations(&pool).await?;
}
Ok(SqliteStore::new(pool))
}
}
/// sqlx transaction on a SQLite connection; this is what [`SqliteStore::tx`]
/// callbacks receive.
pub type Transaction<'a> = sqlx::Transaction<'a, Sqlite>;
/// Alias for sqlx's SQLite connection pool type.
pub type SqlitePool = sqlx::SqlitePool;
/// Handle to a SQLite connection pool. It also holds at most one transaction,
/// opened by `begin` and exposed through [`SqliteStore::tx`].
///
/// The transaction slot and the semaphore live behind `Arc`, so every clone of a
/// store shares them. All clones therefore take part in the same
/// one-transaction-at-a-time serialisation.
#[derive(Clone, Debug)]
pub struct SqliteStore {
// Transaction opened by `begin`, if any; taken out again by `commit`/`rollback`
// or by the rollback task of a dropped `TransactionPermit`.
tx: Arc<Mutex<Option<Transaction<'static>>>>,
// Pool used to open transactions and for non-transactional work via `execute`.
pub(crate) pool: sqlx::SqlitePool,
// Single-permit semaphore (see `SqliteStore::new`) that lets only one `begin`
// hold the transaction slot at a time.
semaphore: Arc<Semaphore>,
}
impl SqliteStore {
pub(crate) fn new(pool: sqlx::SqlitePool) -> Self {
Self {
tx: Arc::default(),
pool,
semaphore: Arc::new(Semaphore::new(1)),
}
}
pub fn from_pool(pool: sqlx::SqlitePool) -> Self {
Self::new(pool)
}
pub fn pool(&self) -> &sqlx::SqlitePool {
&self.pool
}
#[cfg(any(test, feature = "test_utils"))]
pub async fn temporary() -> Self {
SqliteStoreBuilder::new()
.random_memory_url()
.max_connections(1)
.build()
.await
.expect("migrations succeeded")
}
pub async fn tx<F, R>(&self, f: F) -> Result<R, SqliteError>
where
F: AsyncFnOnce(&mut Transaction) -> Result<R, SqliteError>,
{
let mut tx_ref = self.tx.lock().await;
let tx = tx_ref.as_mut().ok_or(SqliteError::TransactionMissing)?;
f(tx).await
}
pub async fn execute<F, R>(&self, f: F) -> Result<R, SqliteError>
where
F: AsyncFnOnce(&sqlx::SqlitePool) -> Result<R, SqliteError>,
{
f(&self.pool).await
}
}
impl crate::traits::Transaction for SqliteStore {
type Error = SqliteError;
type Permit = TransactionPermit;
async fn begin(&self) -> Result<TransactionPermit, SqliteError> {
let permit = self
.semaphore
.clone()
.acquire_owned()
.await
.expect("if semaphore is closed then the whole struct is gone as well");
let mut tx_ref = self.tx.lock().await;
assert!(
tx_ref.is_none(),
"can't have an already existing transaction after an just-acquired permit"
);
let tx = self.pool.begin().await?;
tx_ref.replace(tx);
Ok(TransactionPermit::new(permit, self.tx.clone()))
}
async fn rollback(&self, permit: TransactionPermit) -> Result<(), SqliteError> {
let Some(tx) = self.tx.lock().await.take() else {
panic!("can't have no transaction without dropping permit first")
};
let result = tx.rollback().await.map_err(SqliteError::Sqlite);
permit.mark_committed_and_drop();
result
}
async fn commit(&self, permit: TransactionPermit) -> Result<(), SqliteError> {
let Some(tx) = self.tx.lock().await.take() else {
panic!("can't have no transaction without dropping permit first")
};
let result = tx.commit().await.map_err(SqliteError::Sqlite);
permit.mark_committed_and_drop();
result
}
}
/// Proof that the holder owns the store's single transaction slot; returned by `begin`.
///
/// Hand it back to `commit` or `rollback`. If it is dropped without that, its
/// `Drop` impl rolls the open transaction back in a spawned task before the slot
/// is released.
pub struct TransactionPermit {
// Semaphore permit. It sits behind an `Arc` so `Drop` can move a clone into the
// spawned rollback task, which keeps the slot reserved until the rollback is done.
permit: Arc<OwnedSemaphorePermit>,
// Shared transaction slot of the `SqliteStore` that issued this permit.
tx: Arc<Mutex<Option<Transaction<'static>>>>,
// Set by `mark_committed_and_drop`; suppresses the rollback in `Drop`.
committed: bool,
}
impl TransactionPermit {
    /// Wraps the semaphore `permit` together with a handle to the store's
    /// transaction slot. A fresh permit counts as not committed, so dropping it
    /// schedules a rollback.
    pub(super) fn new(
        permit: OwnedSemaphorePermit,
        tx: Arc<Mutex<Option<Transaction<'static>>>>,
    ) -> Self {
        let permit = Arc::new(permit);
        let committed = false;
        Self {
            permit,
            tx,
            committed,
        }
    }

    /// Consumes the permit once the transaction has been committed or rolled back
    /// explicitly. Dropping it then only releases the semaphore slot.
    pub(super) fn mark_committed_and_drop(self) {
        let mut finished = self;
        finished.committed = true;
        drop(finished);
    }
}
impl Drop for TransactionPermit {
fn drop(&mut self) {
if !self.committed {
let permit = self.permit.clone();
let tx = self.tx.clone();
tokio::spawn(async move {
if let Some(tx) = tx.lock().await.take() {
let _ = tx.rollback().await;
}
drop(permit); });
}
}
}
/// Errors produced by the SQLite store.
///
/// NOTE(review): `Encode` and `Decode` embed the inner error only in their
/// message. The inner error is not marked `#[source]`, so `Error::source()`
/// returns `None` for those two variants.
#[derive(Debug, Error)]
pub enum SqliteError {
/// No transaction was open; returned e.g. by `SqliteStore::tx` when its slot is empty.
#[error("tried to interact with inexistant transaction")]
TransactionMissing,
/// Error reported by sqlx (connecting, querying, committing, ...).
#[error(transparent)]
Sqlite(#[from] sqlx::Error),
/// Error reported by sqlx while applying migrations.
#[error(transparent)]
Migrate(#[from] sqlx::migrate::MigrateError),
/// Encoding a value failed. The `String` is interpolated into the message as
/// the value's label; the second field is the encoder error.
#[error("failed encoding '{0}' value before storing to database: {1}")]
Encode(String, EncodeError),
/// Decoding a stored value failed. The `String` is interpolated into the
/// message as the value's label; the second field is the decoder error.
#[error("could not decode corrupted '{0}' value from database: {1}")]
Decode(String, DecodeError),
}
/// Underlying reason a value read from the database could not be decoded
/// (wrapped by `SqliteError::Decode`).
#[derive(Debug, Error)]
pub enum DecodeError {
/// Wraps a `p2panda_core::cbor::DecodeError`.
#[error(transparent)]
DecodeCbor(#[from] p2panda_core::cbor::DecodeError),
/// Wraps a `p2panda_core::hash::HashError`.
#[error(transparent)]
Hash(#[from] p2panda_core::hash::HashError),
/// Wraps a `p2panda_core::topic::TopicError`.
#[error(transparent)]
Topic(#[from] p2panda_core::topic::TopicError),
/// A stored string could not be parsed into the expected type.
#[error("parsing from string failed")]
FromStr,
}
#[cfg(test)]
mod tests {
use std::task::Poll;
use futures_test::task::noop_context;
use sqlx::{Executor, query, query_as, query_scalar};
use tokio::pin;
use crate::sqlite::{SqliteError, SqliteStore};
use crate::traits::Transaction;
// Walks through the permit lifecycle: `tx` fails before `begin`, and a second
// `begin` stays pending while the first permit is alive. `tx` works inside the
// transaction, and `tx` fails again after `commit`.
#[tokio::test]
async fn transaction_provider() {
let pool = SqliteStore::temporary().await;
assert!(matches!(
pool.tx(async |_| Ok(())).await,
Err(SqliteError::TransactionMissing)
));
let permit = pool.begin().await.expect("no error");
// Poll a second `begin` exactly once with a no-op waker. It must not
// complete while `permit` still holds the semaphore.
assert!(matches!(
{
let fut = pool.begin();
let mut cx = noop_context();
pin!(fut);
fut.poll(&mut cx)
},
Poll::Pending
));
assert!(pool.tx(async |_| Ok(())).await.is_ok());
assert!(pool.commit(permit).await.is_ok());
assert!(matches!(
pool.tx(async |_| Ok(())).await,
Err(SqliteError::TransactionMissing)
));
}
// Dropping a permit without `commit`/`rollback` must roll the transaction back,
// so the insert made inside it is not visible afterwards.
#[tokio::test]
async fn early_permit_drop_causing_rollback() {
let pool = SqliteStore::temporary().await;
pool.execute(async |pool| {
pool.execute("CREATE TABLE test(x INTEGER)").await?;
Ok(())
})
.await
.unwrap();
let permit = pool.begin().await.unwrap();
pool.tx(async |tx| {
query("INSERT INTO test (x) VALUES (10)")
.execute(&mut **tx)
.await?;
Ok(())
})
.await
.unwrap();
drop(permit);
// `begin` can only succeed once the drop-triggered rollback task has
// released the semaphore permit.
assert!(pool.begin().await.is_ok());
let count: i64 = pool
.execute(async |pool| {
query_scalar("SELECT COUNT(*) FROM test")
.fetch_one(pool)
.await
.map_err(SqliteError::Sqlite)
})
.await
.unwrap();
assert_eq!(count, 0);
}
// Two clones of one store: the spawned task's `begin` has to wait until
// `pool_1` commits. It therefore sees the committed value 5, and its own
// insert is undone by its explicit rollback.
#[tokio::test]
async fn serialized_transactions() {
let pool_1 = SqliteStore::temporary().await;
let pool_2 = pool_1.clone();
pool_1
.execute(async |pool| {
pool.execute("CREATE TABLE test(x INTEGER)").await?;
Ok(())
})
.await
.unwrap();
let permit_1 = pool_1.begin().await.unwrap();
let handle = tokio::spawn(async move {
// Blocks until `permit_1` is released by `pool_1.commit` below.
let permit_2 = pool_2.begin().await.unwrap();
let result = pool_2
.tx(async |tx| {
let row: (i64,) = query_as("SELECT x FROM test").fetch_one(&mut **tx).await?;
Ok(row.0)
})
.await
.unwrap();
assert_eq!(result, 5);
pool_2
.tx(async |tx| {
query("INSERT INTO test (x) VALUES (10)")
.execute(&mut **tx)
.await?;
Ok(())
})
.await
.unwrap();
pool_2.rollback(permit_2).await.unwrap();
let result = pool_2
.execute(async |pool| {
let row: (i64,) = query_as("SELECT x FROM test").fetch_one(pool).await?;
Ok(row.0)
})
.await
.unwrap();
assert_eq!(result, 5);
});
pool_1
.tx(async |tx| {
query("INSERT INTO test (x) VALUES (5)")
.execute(&mut **tx)
.await?;
Ok(())
})
.await
.unwrap();
let result = pool_1
.tx(async |tx| {
let row: (i64,) = query_as("SELECT x FROM test").fetch_one(&mut **tx).await?;
Ok(row.0)
})
.await
.unwrap();
assert_eq!(result, 5);
pool_1.commit(permit_1).await.unwrap();
let result = pool_1
.execute(async |pool| {
let row: (i64,) = query_as("SELECT x FROM test").fetch_one(pool).await?;
Ok(row.0)
})
.await
.unwrap();
assert_eq!(result, 5);
handle.await.unwrap();
}
}