toasty 0.3.0

An async ORM for Rust supporting SQL and NoSQL databases
Documentation
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

use crate::{Result, db::ConnectionOperation, db::Executor};

use async_trait::async_trait;
use toasty_core::{
    Schema,
    driver::operation::{self, IsolationLevel},
};
use tokio::sync::oneshot;

/// Either a `&mut Db` or `&mut Connection`, used by [`TransactionBuilder`] to
/// defer connection acquisition until [`begin`](TransactionBuilder::begin).
pub(crate) enum TxSource<'a> {
    /// A connection is obtained from this `Db` when the transaction begins;
    /// the resulting transaction owns that connection.
    Db(&'a mut super::Db),
    /// The transaction runs on this existing connection, borrowed for `'a`.
    Connection(&'a mut super::Connection),
}

/// Builder for configuring a transaction before starting it.
///
/// Obtain one via [`Db::transaction_builder`](super::Db::transaction_builder)
/// or [`Connection::transaction_builder`](super::Connection::transaction_builder),
/// configure isolation level and read-only settings, then call
/// [`begin`](Self::begin) to start the transaction.
pub struct TransactionBuilder<'a> {
    // Where the transaction's connection comes from (pooled `Db` or an
    // existing `Connection`).
    source: TxSource<'a>,
    // Isolation level forwarded to the start operation; `None` unless
    // `isolation` was called.
    isolation: Option<IsolationLevel>,
    // Read-only flag forwarded to the start operation; `false` by default.
    read_only: bool,
}

impl<'a> TransactionBuilder<'a> {
    /// Create a builder with no explicit isolation level and `read_only`
    /// set to `false`.
    pub(crate) fn new(source: TxSource<'a>) -> Self {
        Self {
            source,
            isolation: None,
            read_only: false,
        }
    }

    /// Set the isolation level for this transaction.
    ///
    /// If never called, no isolation level (`None`) is passed to the driver.
    #[must_use = "the builder is consumed; use the returned builder to begin the transaction"]
    pub fn isolation(mut self, level: IsolationLevel) -> Self {
        self.isolation = Some(level);
        self
    }

    /// Set whether this transaction is read-only.
    ///
    /// Defaults to `false`.
    #[must_use = "the builder is consumed; use the returned builder to begin the transaction"]
    pub fn read_only(mut self, read_only: bool) -> Self {
        self.read_only = read_only;
        self
    }

    /// Begin the transaction.
    ///
    /// When built from a [`Db`](super::Db), this acquires a connection from
    /// the pool. The connection is owned by the returned [`Transaction`] and
    /// will be returned to the pool when the transaction is dropped. When
    /// built from a [`Connection`](super::Connection), the transaction
    /// borrows that connection for `'a`.
    ///
    /// # Errors
    ///
    /// Returns an error if a connection cannot be acquired from the `Db`, or
    /// if the driver fails to start the transaction.
    pub async fn begin(self) -> Result<Transaction<'a>> {
        let conn = match self.source {
            TxSource::Db(db) => ConnRef::owned(db.connection().await?),
            TxSource::Connection(conn) => ConnRef::Borrowed(conn),
        };
        Transaction::begin_with(conn, self.isolation, self.read_only).await
    }
}

/// An active database transaction.
///
/// All operations executed through a `Transaction` are guaranteed to use the
/// same physical connection.
///
/// If dropped without calling [`commit`](Self::commit) or
/// [`rollback`](Self::rollback), the transaction is automatically rolled back.
/// The rollback is queued to the connection without waiting for its result.
pub struct Transaction<'a> {
    /// The connection this transaction operates on.
    conn: ConnRef<'a>,

    /// Whether commit or rollback has been called.
    ///
    /// While `false`, dropping the transaction queues a rollback.
    finalized: bool,

    /// If this is a nested transaction (implemented through savepoints),
    /// this holds the savepoint stack depth to be used as an identifier.
    ///
    /// `None` for a top-level transaction; the first nested level is `1`.
    savepoint: Option<usize>,
}

/// Either a borrowed or owned reference to a [`Connection`](super::Connection).
///
/// Dereferences to `Connection` in both cases, so callers can use it
/// uniformly.
pub(crate) enum ConnRef<'a> {
    /// A connection borrowed mutably for `'a` (e.g. from a caller-supplied
    /// `Connection` or a parent transaction).
    Borrowed(&'a mut super::Connection),
    /// A connection owned by this value and dropped along with it.
    // NOTE(review): the `PhantomData<&'a ()>` looks redundant, since the
    // `Borrowed` variant already uses `'a`. It is kept because callers
    // construct and match this variant with two fields.
    Owned(super::Connection, PhantomData<&'a ()>),
}

impl<'a> ConnRef<'a> {
    /// Wrap an owned connection, filling in the lifetime marker.
    pub(crate) fn owned(conn: super::Connection) -> ConnRef<'a> {
        let marker = PhantomData;
        Self::Owned(conn, marker)
    }
}

impl Deref for ConnRef<'_> {
    type Target = super::Connection;

    /// Shared access to the underlying connection, whether borrowed or owned.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Owned(inner, _) => inner,
            Self::Borrowed(inner) => &**inner,
        }
    }
}

impl DerefMut for ConnRef<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            ConnRef::Borrowed(c) => c,
            ConnRef::Owned(c, _) => c,
        }
    }
}

impl<'a> Transaction<'a> {
    /// Start a top-level transaction on `conn` with no explicit isolation
    /// level and `read_only = false`.
    pub(crate) async fn begin(conn: ConnRef<'a>) -> Result<Transaction<'a>> {
        Self::begin_with(conn, None, false).await
    }

    /// Start a top-level transaction on `conn` with the given settings.
    ///
    /// # Errors
    ///
    /// Returns the driver's error if the start operation fails. The
    /// half-built `Transaction` is dropped in that case, which queues a
    /// rollback.
    pub(crate) async fn begin_with(
        conn: ConnRef<'a>,
        isolation: Option<IsolationLevel>,
        read_only: bool,
    ) -> Result<Transaction<'a>> {
        tracing::debug!(
            isolation = ?isolation,
            read_only = read_only,
            "beginning transaction"
        );

        // We're creating the Transaction struct before actually starting the transaction. If the
        // future is cancelled while waiting on the response of the start command, the transaction
        // is still rolled back.
        let tx = Transaction {
            conn,
            finalized: false,
            savepoint: None,
        };

        tx.conn
            .exec_operation(
                operation::Transaction::Start {
                    isolation,
                    read_only,
                }
                .into(),
            )
            .await?;
        Ok(tx)
    }

    /// Create a nested transaction (savepoint).
    pub async fn transaction(&mut self) -> Result<Transaction<'_>> {
        <Self as Executor>::transaction(self).await
    }

    /// Commit the transaction.
    ///
    /// For a nested transaction this releases its savepoint; otherwise it
    /// commits the top-level transaction.
    ///
    /// # Errors
    ///
    /// Returns the driver's error if the commit or savepoint release fails.
    /// In that case the transaction is treated as unfinished, and dropping
    /// it queues a best-effort rollback. This avoids leaving the connection
    /// inside an open transaction.
    pub async fn commit(mut self) -> Result<()> {
        tracing::debug!("committing transaction");
        // Because driver operations are done in a background task, all the operations aren't
        // cancelled and will continue even if this future is dropped. Setting the finalized flag
        // to true early here makes sure that if the future is dropped we don't queue a rollback
        // command.
        self.finalized = true;
        let op = match self.savepoint {
            Some(_) => operation::Transaction::ReleaseSavepoint(self.savepoint()),
            None => operation::Transaction::Commit,
        };
        let result = self.conn.exec_operation(op.into()).await;
        if result.is_err() {
            // The driver reported a failure, and the backend may still have the
            // transaction open. SQLite, for example, keeps it active when COMMIT
            // fails with SQLITE_BUSY. Clear the flag so that dropping `self` on
            // return queues a best-effort rollback.
            self.finalized = false;
        }
        result?;
        Ok(())
    }

    /// Roll back the transaction.
    ///
    /// For a nested transaction this rolls back to its savepoint; otherwise
    /// it rolls back the top-level transaction.
    ///
    /// # Errors
    ///
    /// Returns the driver's error if the rollback operation fails.
    pub async fn rollback(mut self) -> Result<()> {
        tracing::debug!("rolling back transaction");
        // See `commit` for why the finalized flag is set before awaiting.
        self.finalized = true;
        let op = match self.savepoint {
            Some(_) => operation::Transaction::RollbackToSavepoint(self.savepoint()),
            None => operation::Transaction::Rollback,
        };
        self.conn.exec_operation(op.into()).await?;
        Ok(())
    }

    /// Name of this nested transaction's savepoint (`tx_<depth>`).
    ///
    /// # Panics
    ///
    /// Panics if called on a top-level transaction (`savepoint` is `None`);
    /// every caller only uses it on nested transactions.
    fn savepoint(&self) -> String {
        let depth = self
            .savepoint
            .expect("savepoint name requested for a top-level transaction");
        format!("tx_{depth}")
    }
}

impl Drop for Transaction<'_> {
    /// Queue a best-effort rollback if neither `commit` nor `rollback` ran.
    ///
    /// `drop` cannot await. The rollback is therefore sent to the
    /// connection's background task without waiting for a reply, and a
    /// failed send is ignored.
    fn drop(&mut self) {
        if self.finalized {
            return;
        }

        let rollback = if self.savepoint.is_some() {
            operation::Transaction::RollbackToSavepoint(self.savepoint())
        } else {
            operation::Transaction::Rollback
        };

        // The reply is never read; `_reply_rx` is simply dropped at the end of
        // this scope.
        let (reply_tx, _reply_rx) = oneshot::channel();
        let message = ConnectionOperation::ExecOperation {
            operation: Box::new(rollback.into()),
            tx: reply_tx,
        };
        let _ = self.conn.handle().in_tx.send(message);
    }
}

#[async_trait]
impl<'a> Executor for Transaction<'a> {
    /// Open a nested transaction backed by a savepoint on the same connection.
    ///
    /// The nested depth is one more than this transaction's depth. A
    /// top-level transaction yields depth 1. The depth identifies the
    /// savepoint.
    async fn transaction(&mut self) -> Result<Transaction<'_>> {
        let depth = self.savepoint.map_or(1, |parent| parent + 1);
        tracing::debug!(depth = depth, "creating nested transaction (savepoint)");

        // The guard exists before SAVEPOINT is issued, so an error or
        // cancellation below drops it and queues a rollback to the savepoint.
        let nested = Transaction {
            conn: ConnRef::Borrowed(&mut self.conn),
            finalized: false,
            savepoint: Some(depth),
        };

        let name = nested.savepoint();
        nested
            .conn
            .exec_operation(operation::Transaction::Savepoint(name).into())
            .await?;

        Ok(nested)
    }

    /// Execute an untyped statement on this transaction's connection.
    // NOTE(review): the meaning of the `true` argument is defined by
    // `Connection::exec_stmt`; it presumably marks execution inside a
    // transaction. Confirm against that method.
    async fn exec_untyped(
        &mut self,
        stmt: toasty_core::stmt::Statement,
    ) -> Result<toasty_core::driver::ExecResponse> {
        self.conn.exec_stmt(stmt, true).await
    }

    /// The schema of the underlying connection.
    fn schema(&mut self) -> &Arc<Schema> {
        self.conn.schema()
    }
}