use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use crate::{Result, db::ConnectionOperation, db::Executor};
use async_trait::async_trait;
use toasty_core::{
Schema,
driver::operation::{self, IsolationLevel},
};
use tokio::sync::oneshot;
/// Where a [`TransactionBuilder`] obtains the connection for its transaction.
pub(crate) enum TxSource<'a> {
    /// Acquire a fresh connection via `Db::connection`; the resulting
    /// transaction owns that connection.
    Db(&'a mut super::Db),
    /// Run the transaction on an existing connection, borrowed for `'a`.
    Connection(&'a mut super::Connection),
}
/// Builder for configuring and starting a top-level [`Transaction`].
///
/// Configure it with [`isolation`](TransactionBuilder::isolation) and
/// [`read_only`](TransactionBuilder::read_only), then call
/// [`begin`](TransactionBuilder::begin).
pub struct TransactionBuilder<'a> {
    /// Where the transaction's connection comes from.
    source: TxSource<'a>,
    /// Requested isolation level; `None` when none was requested.
    isolation: Option<IsolationLevel>,
    /// Whether the transaction is started read-only (`false` unless set).
    read_only: bool,
}
impl<'a> TransactionBuilder<'a> {
    /// Creates a builder over `source` with no isolation level requested and
    /// `read_only` set to `false`.
    pub(crate) fn new(source: TxSource<'a>) -> Self {
        Self {
            source,
            read_only: false,
            isolation: None,
        }
    }

    /// Requests a specific isolation level for the transaction.
    pub fn isolation(self, level: IsolationLevel) -> Self {
        Self {
            isolation: Some(level),
            ..self
        }
    }

    /// Sets whether the transaction is started as read-only.
    pub fn read_only(self, read_only: bool) -> Self {
        Self { read_only, ..self }
    }

    /// Starts the configured transaction.
    ///
    /// A builder created from a `Db` acquires a new connection, which the
    /// returned transaction then owns. A builder created from a `Connection`
    /// borrows that connection instead.
    ///
    /// # Errors
    ///
    /// Returns an error if acquiring a connection or starting the transaction
    /// fails.
    pub async fn begin(self) -> Result<Transaction<'a>> {
        let Self {
            source,
            isolation,
            read_only,
        } = self;

        let conn = match source {
            TxSource::Connection(borrowed) => ConnRef::Borrowed(borrowed),
            TxSource::Db(db) => ConnRef::owned(db.connection().await?),
        };

        Transaction::begin_with(conn, isolation, read_only).await
    }
}
/// An open database transaction.
///
/// This is either a top-level transaction (`savepoint` is `None`) or a nested
/// one backed by a savepoint. Finish it with [`commit`](Transaction::commit) or
/// [`rollback`](Transaction::rollback). If it is dropped while not finalized,
/// its `Drop` impl queues a rollback on the connection.
pub struct Transaction<'a> {
    /// The connection the transaction runs on, either owned or borrowed.
    conn: ConnRef<'a>,
    /// While `false`, `Drop` queues a rollback; `commit` and `rollback` set it
    /// to `true`.
    finalized: bool,
    /// Nesting depth of a savepoint-backed transaction, which also determines
    /// the savepoint name; `None` for a top-level transaction.
    savepoint: Option<usize>,
}
/// A connection held by a [`Transaction`]: either borrowed from the caller or
/// owned outright (for example, one acquired from a `Db`).
///
/// Dereferences to `super::Connection` in both cases.
pub(crate) enum ConnRef<'a> {
    /// A connection borrowed for `'a`.
    Borrowed(&'a mut super::Connection),
    /// A connection owned by this value.
    // NOTE(review): `'a` is already used by `Borrowed`, so this `PhantomData`
    // looks redundant; left in place because code that constructs or matches
    // `Owned` elsewhere may depend on the two-field shape.
    Owned(super::Connection, PhantomData<&'a ()>),
}
impl<'a> ConnRef<'a> {
    /// Wraps a connection that the resulting `ConnRef` owns outright.
    pub(crate) fn owned(conn: super::Connection) -> ConnRef<'a> {
        Self::Owned(conn, PhantomData)
    }
}
impl Deref for ConnRef<'_> {
    type Target = super::Connection;

    /// Borrows the underlying connection, however it is held.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Owned(owned, _) => owned,
            Self::Borrowed(borrowed) => &**borrowed,
        }
    }
}
impl DerefMut for ConnRef<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
ConnRef::Borrowed(c) => c,
ConnRef::Owned(c, _) => c,
}
}
}
impl<'a> Transaction<'a> {
    /// Starts a top-level transaction on `conn` with no explicit isolation level
    /// and `read_only = false`.
    pub(crate) async fn begin(conn: ConnRef<'a>) -> Result<Transaction<'a>> {
        Self::begin_with(conn, None, false).await
    }

    /// Starts a top-level transaction on `conn`.
    ///
    /// `isolation` and `read_only` are forwarded unchanged in the driver's
    /// `Transaction::Start` operation.
    ///
    /// The `Transaction` value is built *before* the start operation is awaited.
    /// If that await fails, or the future is cancelled, the value is dropped
    /// while still unfinalized, so `Drop` queues a rollback.
    ///
    /// # Errors
    ///
    /// Returns any error produced while executing the start operation.
    pub(crate) async fn begin_with(
        conn: ConnRef<'a>,
        isolation: Option<IsolationLevel>,
        read_only: bool,
    ) -> Result<Transaction<'a>> {
        tracing::debug!(
            isolation = ?isolation,
            read_only = read_only,
            "beginning transaction"
        );
        let tx = Transaction {
            conn,
            finalized: false,
            savepoint: None,
        };
        tx.conn
            .exec_operation(
                operation::Transaction::Start {
                    isolation,
                    read_only,
                }
                .into(),
            )
            .await?;
        Ok(tx)
    }

    /// Opens a nested transaction, backed by a savepoint, on the same
    /// connection. Delegates to this type's `Executor` implementation.
    pub async fn transaction(&mut self) -> Result<Transaction<'_>> {
        <Self as Executor>::transaction(self).await
    }

    /// Commits this transaction.
    ///
    /// A top-level transaction sends `Transaction::Commit`. A nested transaction
    /// releases its savepoint instead.
    ///
    /// # Errors
    ///
    /// Returns the error from the commit/release operation. In that case the
    /// transaction is treated as still open: dropping it on return queues a
    /// rollback (or rollback-to-savepoint) rather than leaving the connection
    /// inside a transaction that never finished.
    pub async fn commit(mut self) -> Result<()> {
        tracing::debug!("committing transaction");
        let op = match self.savepoint {
            Some(_) => operation::Transaction::ReleaseSavepoint(self.savepoint()),
            None => operation::Transaction::Commit,
        };

        // Mark finalized before awaiting. If this future is cancelled after the
        // operation has been handed to the connection, `Drop` must not queue a
        // rollback behind the in-flight commit.
        self.finalized = true;
        let result = self.conn.exec_operation(op.into()).await;
        if result.is_err() {
            // The commit/release failed, so the transaction or savepoint may
            // still be open. Clear the flag so `Drop` cleans it up when `self`
            // goes out of scope at the end of this function.
            self.finalized = false;
        }
        result?;
        Ok(())
    }

    /// Rolls back this transaction.
    ///
    /// A top-level transaction sends `Transaction::Rollback`. A nested
    /// transaction rolls back to its savepoint instead.
    ///
    /// # Errors
    ///
    /// Returns the error from the rollback operation.
    pub async fn rollback(mut self) -> Result<()> {
        tracing::debug!("rolling back transaction");
        let op = match self.savepoint {
            Some(_) => operation::Transaction::RollbackToSavepoint(self.savepoint()),
            None => operation::Transaction::Rollback,
        };

        // Finalized unconditionally: if the rollback fails, `Drop` would only
        // re-send the same operation.
        self.finalized = true;
        self.conn.exec_operation(op.into()).await?;
        Ok(())
    }

    /// Name of the savepoint backing this nested transaction (`tx_<depth>`).
    ///
    /// # Panics
    ///
    /// Panics if called on a top-level transaction (`savepoint` is `None`).
    fn savepoint(&self) -> String {
        let depth = self
            .savepoint
            .expect("savepoint name requested for a top-level transaction");
        format!("tx_{depth}")
    }
}
impl Drop for Transaction<'_> {
    /// Best-effort cleanup for a transaction that was never finalized.
    ///
    /// `drop` cannot await. It therefore pushes a rollback (or a
    /// rollback-to-savepoint for a nested transaction) onto the connection's
    /// operation channel and discards the reply. A failure to enqueue is
    /// ignored.
    fn drop(&mut self) {
        if self.finalized {
            return;
        }

        let rollback = if self.savepoint.is_some() {
            operation::Transaction::RollbackToSavepoint(self.savepoint())
        } else {
            operation::Transaction::Rollback
        };

        // Nobody waits for the outcome; the receiver is dropped when this
        // function returns.
        let (reply_tx, _reply_rx) = oneshot::channel();
        let message = ConnectionOperation::ExecOperation {
            operation: Box::new(rollback.into()),
            tx: reply_tx,
        };
        let _ = self.conn.handle().in_tx.send(message);
    }
}
#[async_trait]
impl<'a> Executor for Transaction<'a> {
    /// Opens a nested transaction implemented as a savepoint on the same
    /// connection.
    ///
    /// The savepoint is named after its nesting depth: `tx_1` for the first
    /// nested level, `tx_2` for one nested inside that, and so on.
    async fn transaction(&mut self) -> Result<Transaction<'_>> {
        let depth = self.savepoint.map_or(1, |parent| parent + 1);
        tracing::debug!(depth = depth, "creating nested transaction (savepoint)");

        let nested = Transaction {
            conn: ConnRef::Borrowed(&mut self.conn),
            finalized: false,
            savepoint: Some(depth),
        };
        let start = operation::Transaction::Savepoint(nested.savepoint());
        nested.conn.exec_operation(start.into()).await?;
        Ok(nested)
    }

    /// Executes an untyped statement on this transaction's connection.
    // NOTE(review): `exec_stmt` receives `true` as its second argument here;
    // presumably this marks the statement as running inside a transaction —
    // confirm against `Connection::exec_stmt`.
    async fn exec_untyped(
        &mut self,
        stmt: toasty_core::stmt::Statement,
    ) -> Result<toasty_core::driver::ExecResponse> {
        self.conn.exec_stmt(stmt, true).await
    }

    /// Returns the schema of the underlying connection.
    fn schema(&mut self) -> &Arc<Schema> {
        self.conn.schema()
    }
}