use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use crate::{Result, db::ConnectionOperation, db::Executor};
use async_trait::async_trait;
use toasty_core::{
Schema,
driver::operation::{self, IsolationLevel},
stmt::Value,
};
use tokio::sync::oneshot;
/// Configures and starts a database transaction.
///
/// Create one with [`TransactionBuilder::new`] (or `Default`), optionally
/// request an isolation level and read-only mode, then call one of the
/// `begin*` methods to actually start the transaction.
#[derive(Debug)]
#[must_use = "a `TransactionBuilder` does nothing until one of its `begin` methods is called"]
pub struct TransactionBuilder {
    /// Isolation level to request when starting the transaction; `None`
    /// means no explicit level is requested.
    isolation: Option<IsolationLevel>,
    /// Whether the transaction should be started in read-only mode.
    read_only: bool,
}
impl TransactionBuilder {
    /// Creates a builder that requests no isolation level and leaves
    /// read-only mode disabled.
    pub fn new() -> Self {
        Self {
            read_only: false,
            isolation: None,
        }
    }

    /// Requests `level` as the isolation level of the transaction.
    pub fn isolation(self, level: IsolationLevel) -> Self {
        Self {
            isolation: Some(level),
            ..self
        }
    }

    /// Sets whether the transaction is started in read-only mode.
    pub fn read_only(self, read_only: bool) -> Self {
        Self { read_only, ..self }
    }

    /// Starts the transaction on `conn`, which stays mutably borrowed for
    /// the lifetime of the returned [`Transaction`].
    ///
    /// # Errors
    ///
    /// Returns an error if the start operation fails.
    pub async fn begin(self, conn: &mut super::Connection) -> Result<Transaction<'_>> {
        let Self {
            isolation,
            read_only,
        } = self;
        Transaction::begin_with(ConnRef::Borrowed(conn), isolation, read_only).await
    }

    /// Obtains a connection from `db` and starts the transaction on it; the
    /// returned [`Transaction`] owns that connection.
    ///
    /// # Errors
    ///
    /// Returns an error if obtaining the connection or starting the
    /// transaction fails.
    pub async fn begin_on_db(self, db: &mut super::Db) -> Result<Transaction<'_>> {
        let Self {
            isolation,
            read_only,
        } = self;
        let owned = ConnRef::owned(db.connection().await?);
        Transaction::begin_with(owned, isolation, read_only).await
    }
}
impl Default for TransactionBuilder {
fn default() -> Self {
Self::new()
}
}
/// An open database transaction or, when nested, a savepoint-backed
/// sub-transaction.
///
/// Finish it with `commit` or `rollback`. If it is dropped without either,
/// a rollback is sent to the connection on a best-effort basis, so discarding
/// this value silently undoes the work done in it.
#[must_use = "dropping a `Transaction` without calling `commit` rolls it back"]
pub struct Transaction<'a> {
    /// Connection the transaction runs on, either borrowed or owned.
    conn: ConnRef<'a>,
    /// Set once `commit` or `rollback` has been called so that `Drop` does
    /// not send another rollback.
    finalized: bool,
    /// Nesting depth of a savepoint-backed transaction (`Some(n)` maps to the
    /// savepoint name `tx_n`); `None` for a top-level transaction.
    savepoint: Option<usize>,
}
/// A connection that a [`Transaction`] either owns outright or borrows from
/// its caller for the lifetime `'a`.
pub(crate) enum ConnRef<'a> {
    /// A connection owned by the transaction (e.g. one obtained via
    /// `Db::connection`).
    ///
    /// NOTE(review): `'a` is already used by `Borrowed`, so this marker
    /// appears redundant; it is kept because constructors pattern on it.
    Owned(super::Connection, PhantomData<&'a ()>),
    /// A connection mutably borrowed from the caller.
    Borrowed(&'a mut super::Connection),
}
impl<'a> ConnRef<'a> {
    /// Wraps `conn` so that the resulting `ConnRef` takes ownership of it.
    pub(crate) fn owned(conn: super::Connection) -> Self {
        Self::Owned(conn, PhantomData)
    }
}
impl Deref for ConnRef<'_> {
    type Target = super::Connection;

    /// Exposes the underlying connection, whether it is owned or borrowed.
    fn deref(&self) -> &Self::Target {
        let conn: &super::Connection = match self {
            Self::Owned(owned, _) => owned,
            Self::Borrowed(borrowed) => borrowed,
        };
        conn
    }
}
impl DerefMut for ConnRef<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
ConnRef::Borrowed(c) => c,
ConnRef::Owned(c, _) => c,
}
}
}
impl<'a> Transaction<'a> {
    /// Starts a top-level transaction with no explicit isolation level and
    /// with read-only mode disabled.
    pub(crate) async fn begin(conn: ConnRef<'a>) -> Result<Transaction<'a>> {
        Self::begin_with(conn, None, false).await
    }

    /// Starts a top-level transaction on `conn` with the given options.
    ///
    /// The `Transaction` value is built before the start operation is sent,
    /// so if that operation fails the value is dropped unfinalized and its
    /// `Drop` impl sends a best-effort rollback.
    pub(crate) async fn begin_with(
        conn: ConnRef<'a>,
        isolation: Option<IsolationLevel>,
        read_only: bool,
    ) -> Result<Transaction<'a>> {
        tracing::debug!(
            isolation = ?isolation,
            read_only = read_only,
            "beginning transaction"
        );

        let start = operation::Transaction::Start {
            isolation,
            read_only,
        };
        let tx = Transaction {
            conn,
            finalized: false,
            savepoint: None,
        };
        tx.conn.exec_operation(start.into()).await?;
        Ok(tx)
    }

    /// Commits the transaction, or releases its savepoint when nested.
    ///
    /// The transaction is marked finalized before the operation is sent, so
    /// a failure here does not lead to a rollback from `Drop`.
    pub async fn commit(mut self) -> Result<()> {
        tracing::debug!("committing transaction");
        self.finalized = true;

        let op = if self.savepoint.is_none() {
            operation::Transaction::Commit
        } else {
            operation::Transaction::ReleaseSavepoint(self.savepoint())
        };
        self.conn.exec_operation(op.into()).await?;
        Ok(())
    }

    /// Rolls back the transaction, or rolls back to its savepoint when
    /// nested.
    ///
    /// The transaction is marked finalized before the operation is sent, so
    /// `Drop` will not send a second rollback.
    pub async fn rollback(mut self) -> Result<()> {
        tracing::debug!("rolling back transaction");
        self.finalized = true;

        let op = if self.savepoint.is_none() {
            operation::Transaction::Rollback
        } else {
            operation::Transaction::RollbackToSavepoint(self.savepoint())
        };
        self.conn.exec_operation(op.into()).await?;
        Ok(())
    }

    /// Returns the savepoint name (`tx_<depth>`) of a nested transaction.
    ///
    /// # Panics
    ///
    /// Panics if called on a top-level transaction (`savepoint` is `None`).
    fn savepoint(&self) -> String {
        let depth = self.savepoint.unwrap();
        format!("tx_{}", depth)
    }
}
impl Drop for Transaction<'_> {
    /// Sends a best-effort rollback if neither `commit` nor `rollback` ran.
    ///
    /// Nothing waits for a reply: the receiving half of the reply channel is
    /// dropped when this function returns, and a send failure is ignored.
    fn drop(&mut self) {
        if self.finalized {
            return;
        }

        let op = if self.savepoint.is_none() {
            operation::Transaction::Rollback
        } else {
            operation::Transaction::RollbackToSavepoint(self.savepoint())
        };
        let (reply_tx, _reply_rx) = oneshot::channel();
        let message = ConnectionOperation::ExecOperation {
            operation: Box::new(op.into()),
            tx: reply_tx,
        };
        let _ = self.conn.handle().in_tx.send(message);
    }
}
#[async_trait]
impl<'a> Executor for Transaction<'a> {
    /// Opens a nested transaction backed by a savepoint one level deeper than
    /// this one (depth 1 when called on a top-level transaction).
    ///
    /// If creating the savepoint fails, the unfinalized nested value is
    /// dropped, and its `Drop` impl sends a rollback to that savepoint name.
    async fn transaction(&mut self) -> Result<Transaction<'_>> {
        let depth = self.savepoint.map_or(1, |parent| parent + 1);
        tracing::debug!(depth = depth, "creating nested transaction (savepoint)");

        let nested = Transaction {
            conn: ConnRef::Borrowed(&mut self.conn),
            finalized: false,
            savepoint: Some(depth),
        };
        let create = operation::Transaction::Savepoint(nested.savepoint());
        nested.conn.exec_operation(create.into()).await?;
        Ok(nested)
    }

    /// Executes `stmt` on this transaction's connection.
    async fn exec_untyped(&mut self, stmt: toasty_core::stmt::Statement) -> Result<Value> {
        self.conn.exec_stmt(stmt, true).await
    }

    /// Returns the schema reported by the underlying connection.
    fn schema(&mut self) -> &Arc<Schema> {
        self.conn.schema()
    }
}