use std::ops::{Deref, DerefMut};
use futures_core::future::BoxFuture;
use crate::connection::Connection;
use crate::cursor::HasCursor;
use crate::database::Database;
use crate::describe::Describe;
use crate::executor::{Execute, Executor, RefExecutor};
use crate::runtime::spawn;
/// A database transaction, or a savepoint nested inside one, that owns the
/// connection it runs on.
///
/// The wrapped connection is reachable through `Deref`/`DerefMut` while the
/// transaction is open. `commit` and `rollback` consume the transaction and hand
/// the connection back.
#[must_use = "transaction rolls back if not explicitly `.commit()`ed"]
pub struct Transaction<C: Connection> {
    /// The wrapped connection; taken (left as `None`) once the transaction is
    /// committed, rolled back, or closed.
    inner: Option<C>,
    /// Nesting level: 1 for the outermost `BEGIN`, 2+ for savepoints opened on
    /// top of an existing transaction.
    depth: u32,
}
impl<C> Transaction<C>
where
    C: Connection,
{
    /// Name of the savepoint that backs the nested transaction opened at `level`.
    fn savepoint_name(level: u32) -> String {
        format!("_sqlx_savepoint_{}", level)
    }

    /// Opens a transaction on `inner`.
    ///
    /// With `depth == 0` a plain `BEGIN` is issued. Otherwise a savepoint named
    /// after `depth` is created, which nests this transaction inside the one
    /// already open on `inner`. The returned transaction records `depth + 1` as
    /// its own nesting level.
    ///
    /// # Errors
    /// Propagates any error from executing the `BEGIN`/`SAVEPOINT` statement.
    pub(crate) async fn new(depth: u32, mut inner: C) -> crate::Result<Self> {
        if depth > 0 {
            let stmt = format!("SAVEPOINT {}", Self::savepoint_name(depth));
            inner.execute(stmt.as_str()).await?;
        } else {
            inner.execute("BEGIN").await?;
        }

        Ok(Self {
            inner: Some(inner),
            depth: depth + 1,
        })
    }

    /// Starts a nested transaction (a savepoint) on top of this one.
    ///
    /// The current transaction is moved into the new one and is returned by the
    /// nested transaction's `commit`/`rollback`.
    pub async fn begin(self) -> crate::Result<Transaction<Transaction<C>>> {
        let level = self.depth;
        Transaction::new(level, self).await
    }

    /// Commits this transaction and returns the wrapped connection.
    ///
    /// The outermost transaction issues `COMMIT`; a nested one releases its
    /// savepoint.
    ///
    /// # Errors
    /// Propagates any error from executing the statement.
    pub async fn commit(mut self) -> crate::Result<C> {
        let mut inner = self.inner.take().expect(ERR_FINALIZED);

        match self.depth {
            1 => {
                inner.execute("COMMIT").await?;
            }
            level => {
                let stmt = format!("RELEASE SAVEPOINT {}", Self::savepoint_name(level - 1));
                inner.execute(stmt.as_str()).await?;
            }
        }

        Ok(inner)
    }

    /// Rolls back this transaction and returns the wrapped connection.
    ///
    /// The outermost transaction issues `ROLLBACK`; a nested one rolls back to
    /// its savepoint.
    ///
    /// # Errors
    /// Propagates any error from executing the statement.
    pub async fn rollback(mut self) -> crate::Result<C> {
        let mut inner = self.inner.take().expect(ERR_FINALIZED);

        match self.depth {
            1 => {
                inner.execute("ROLLBACK").await?;
            }
            level => {
                let stmt = format!(
                    "ROLLBACK TO SAVEPOINT {}",
                    Self::savepoint_name(level - 1)
                );
                inner.execute(stmt.as_str()).await?;
            }
        }

        Ok(inner)
    }
}
/// Panic message used when the wrapped connection is accessed after it has
/// already been taken out of the transaction.
const ERR_FINALIZED: &str =
    "(bug) transaction already finalized";
impl<C> Deref for Transaction<C>
where
    C: Connection,
{
    type Target = C;

    /// Borrows the wrapped connection.
    ///
    /// # Panics
    /// Panics if the connection has already been taken out of the transaction.
    fn deref(&self) -> &C {
        match self.inner {
            Some(ref conn) => conn,
            None => panic!("{}", ERR_FINALIZED),
        }
    }
}
impl<C> DerefMut for Transaction<C>
where
    C: Connection,
{
    /// Mutably borrows the wrapped connection.
    ///
    /// # Panics
    /// Panics if the connection has already been taken out of the transaction.
    fn deref_mut(&mut self) -> &mut C {
        match self.inner.as_mut() {
            Some(conn) => conn,
            None => panic!("{}", ERR_FINALIZED),
        }
    }
}
impl<C> Connection for Transaction<C>
where
    C: Connection,
{
    /// Closes the transaction together with the connection it wraps.
    ///
    /// For the outermost transaction (`depth == 1`), a `ROLLBACK` is issued first
    /// and the wrapped connection is then closed. The close is attempted even
    /// when the rollback fails.
    ///
    /// For a nested transaction, the wrapped value is closed directly. When the
    /// transaction was created through `begin`, that value is the parent
    /// transaction.
    ///
    /// # Errors
    /// If the rollback fails, that error is returned. Otherwise any error from
    /// closing the wrapped connection is returned. Previously a close failure
    /// was discarded at depth 1 but propagated at other depths.
    fn close(mut self) -> BoxFuture<'static, crate::Result<()>> {
        Box::pin(async move {
            let mut inner = self.inner.take().expect(ERR_FINALIZED);

            if self.depth == 1 {
                let rolled_back = inner.execute("ROLLBACK").await;
                // Always try to close, even if the rollback failed, so the
                // connection is not leaked in an unknown state.
                let closed = inner.close().await;
                rolled_back?;
                closed
            } else {
                inner.close().await
            }
        })
    }

    /// Pings the wrapped connection.
    #[inline]
    fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>> {
        self.deref_mut().ping()
    }
}
impl<DB, C> Executor for Transaction<C>
where
    DB: Database,
    C: Connection<Database = DB>,
{
    type Database = DB;

    /// Forwards `query` to the wrapped connection's `execute`.
    fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>(
        &'c mut self,
        query: E,
    ) -> BoxFuture<'e, crate::Result<u64>>
    where
        E: Execute<'q, Self::Database>,
    {
        self.deref_mut().execute(query)
    }

    /// Forwards `query` to the wrapped connection's `fetch`.
    fn fetch<'e, 'q, E>(&'e mut self, query: E) -> <Self::Database as HasCursor<'e, 'q>>::Cursor
    where
        E: Execute<'q, Self::Database>,
    {
        self.deref_mut().fetch(query)
    }

    /// Forwards `query` to the wrapped connection's `describe`.
    #[doc(hidden)]
    fn describe<'e, 'q, E: 'e>(
        &'e mut self,
        query: E,
    ) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>
    where
        E: Execute<'q, Self::Database>,
    {
        self.deref_mut().describe(query)
    }
}
impl<'e, DB, C> RefExecutor<'e> for &'e mut Transaction<C>
where
    DB: Database,
    C: Connection<Database = DB>,
{
    type Database = DB;

    /// Fetches `query` through the wrapped connection, borrowing it for `'e`.
    fn fetch_by_ref<'q, E>(self, query: E) -> <Self::Database as HasCursor<'e, 'q>>::Cursor
    where
        E: Execute<'q, Self::Database>,
    {
        let conn: &'e mut C = &mut **self;
        conn.fetch(query)
    }
}
impl<C> Drop for Transaction<C>
where
    C: Connection,
{
    /// Cleans up a transaction that was neither committed nor rolled back.
    ///
    /// `drop` cannot await, so the wrapped connection is moved into a future
    /// passed to `crate::runtime::spawn`. That future calls the connection's
    /// `close` and ignores its result.
    fn drop(&mut self) {
        if self.depth == 0 {
            return;
        }

        // After commit/rollback/close the connection has already been taken,
        // so there is nothing left to clean up.
        if let Some(conn) = self.inner.take() {
            spawn(async move {
                let _ = conn.close().await;
            });
        }
    }
}