use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use cdk_common::database::Error;
use crate::stmt::{query, Column, Statement};
/// Runs SQL [`Statement`]s against some database backend.
///
/// Every query method borrows the executor immutably (`&self`) and is asynchronous; implementors
/// must also be `Debug`, `Send` and `Sync`.
#[async_trait::async_trait]
pub trait DatabaseExecutor: Send + Sync + Debug {
    /// Static, human-readable name for this kind of executor.
    fn name() -> &'static str;

    /// Runs `statement` and returns at most one row; `None` when there is no row.
    async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error>;

    /// Runs `statement` and returns every resulting row.
    async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error>;

    /// Runs `statement` and returns a single column value, if any.
    // NOTE(review): presumably the first column of the first row — confirm per backend.
    async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error>;

    /// Runs `statement` and returns a count.
    // NOTE(review): presumably the number of affected rows — confirm per backend.
    async fn execute(&self, statement: Statement) -> Result<usize, Error>;

    /// Runs `statement` for its side effects only; no rows or counts are returned.
    async fn batch(&self, statement: Statement) -> Result<(), Error>;
}
/// Transaction-control operations for connections of type `DB`.
///
/// None of the methods take `self`: the connection is passed explicitly, so an implementor acts as
/// a stateless strategy chosen at the type level.
#[async_trait::async_trait]
pub trait DatabaseTransaction<DB: DatabaseExecutor> {
    /// Opens a transaction on `conn`.
    async fn begin(conn: &mut DB) -> Result<(), Error>;

    /// Makes the changes of the transaction open on `conn` permanent.
    async fn commit(conn: &mut DB) -> Result<(), Error>;

    /// Discards the changes of the transaction open on `conn`.
    async fn rollback(conn: &mut DB) -> Result<(), Error>;
}
/// A wrapped connection `W` (dereferencing to a `DB`) with a transaction open on it.
///
/// The transaction is started by `new` and finished by `commit` or `rollback`. If the value is
/// dropped while it still holds the connection, the `Drop` impl attempts a rollback.
#[derive(Debug)]
pub struct ConnectionWithTransaction<DB, W>
where
    DB: DatabaseConnector + 'static,
    W: Debug + Send + Sync + Deref<Target = DB> + DerefMut<Target = DB> + 'static,
{
    /// The wrapped connection; taken out (leaving `None`) when the transaction is finished.
    inner: Option<W>,
}
impl<DB, W> ConnectionWithTransaction<DB, W>
where
DB: DatabaseConnector,
W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
{
pub async fn new(mut inner: W) -> Result<Self, Error> {
DB::Transaction::begin(inner.deref_mut()).await?;
Ok(Self { inner: Some(inner) })
}
pub async fn commit(mut self) -> Result<(), Error> {
let mut conn = self
.inner
.take()
.ok_or(Error::Internal("Missing connection".to_owned()))?;
DB::Transaction::commit(&mut conn).await?;
Ok(())
}
pub async fn rollback(mut self) -> Result<(), Error> {
let mut conn = self
.inner
.take()
.ok_or(Error::Internal("Missing connection".to_owned()))?;
DB::Transaction::rollback(&mut conn).await?;
Ok(())
}
}
impl<DB, W> Drop for ConnectionWithTransaction<DB, W>
where
    DB: DatabaseConnector,
    W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
{
    /// Rolls back a transaction that was neither committed nor rolled back explicitly.
    ///
    /// `drop` cannot `.await`, so the rollback runs as a detached task on the current Tokio
    /// runtime and its result is ignored. The connection is moved into that task and released only
    /// after the rollback attempt finishes.
    ///
    /// `tokio::spawn` panics when no Tokio runtime is available, and a panic inside `drop` during
    /// unwinding aborts the process. When no runtime is available, the connection is therefore
    /// dropped without a rollback.
    fn drop(&mut self) {
        if let Some(mut conn) = self.inner.take() {
            match tokio::runtime::Handle::try_current() {
                Ok(runtime) => {
                    // The JoinHandle is dropped on purpose: the task is fire-and-forget.
                    runtime.spawn(async move {
                        let _ = DB::Transaction::rollback(conn.deref_mut()).await;
                    });
                }
                Err(_) => {
                    // No runtime to drive the async rollback; `conn` is dropped at the end of
                    // this scope.
                }
            }
        }
    }
}
#[async_trait::async_trait]
impl<DB, W> DatabaseExecutor for ConnectionWithTransaction<DB, W>
where
DB: DatabaseConnector,
W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
{
fn name() -> &'static str {
"Transaction"
}
async fn execute(&self, statement: Statement) -> Result<usize, Error> {
self.inner
.as_ref()
.ok_or(Error::Internal("Missing internal connection".to_owned()))?
.execute(statement)
.await
}
async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
self.inner
.as_ref()
.ok_or(Error::Internal("Missing internal connection".to_owned()))?
.fetch_one(statement)
.await
}
async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
self.inner
.as_ref()
.ok_or(Error::Internal("Missing internal connection".to_owned()))?
.fetch_all(statement)
.await
}
async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
self.inner
.as_ref()
.ok_or(Error::Internal("Missing internal connection".to_owned()))?
.pluck(statement)
.await
}
async fn batch(&self, statement: Statement) -> Result<(), Error> {
self.inner
.as_ref()
.ok_or(Error::Internal("Missing internal connection".to_owned()))?
.batch(statement)
.await
}
}
/// Transaction strategy that issues `START TRANSACTION`, `COMMIT` and `ROLLBACK` as plain SQL
/// through any [`DatabaseExecutor`] `W`.
///
/// Zero-sized (`PhantomData` only): its transaction methods take no `self`, so the type is
/// presumably only named (e.g. as a `DatabaseConnector::Transaction`) rather than instantiated.
// `Debug` is not derived because that would add a `W: Debug` bound for a marker type; the lint is
// silenced instead.
#[allow(missing_debug_implementations)]
pub struct GenericTransactionHandler<W>(PhantomData<W>);

#[async_trait::async_trait]
impl<W: DatabaseExecutor> DatabaseTransaction<W> for GenericTransactionHandler<W> {
    /// Sends `START TRANSACTION` on `conn`.
    // NOTE(review): SQLite's grammar spells this `BEGIN`; confirm every backend that uses this
    // handler accepts `START TRANSACTION`.
    async fn begin(conn: &mut W) -> Result<(), Error> {
        let _affected = query("START TRANSACTION")?.execute(conn).await?;
        Ok(())
    }

    /// Sends `COMMIT` on `conn`.
    async fn commit(conn: &mut W) -> Result<(), Error> {
        let _affected = query("COMMIT")?.execute(conn).await?;
        Ok(())
    }

    /// Sends `ROLLBACK` on `conn`.
    async fn rollback(conn: &mut W) -> Result<(), Error> {
        let _affected = query("ROLLBACK")?.execute(conn).await?;
        Ok(())
    }
}
/// A database connection that can execute statements and names the strategy used to manage its
/// transactions.
#[async_trait::async_trait]
pub trait DatabaseConnector: DatabaseExecutor + Send + Sync + Debug {
    /// Transaction strategy (begin/commit/rollback) for this connector.
    type Transaction: DatabaseTransaction<Self>
    where
        Self: Sized;
}