sql-orm-tiberius 0.2.0-rc.1

Tiberius execution adapter for sql-orm.
Documentation
use crate::config::{MssqlSlowQueryOptions, MssqlTracingOptions};
use crate::connection::run_with_timeout;
use crate::error::{TiberiusErrorContext, map_tiberius_error};
use crate::executor::{
    ExecuteResult, execute_compiled, fetch_all_compiled, fetch_one_compiled, query_raw_compiled,
};
use crate::telemetry::trace_transaction_command;
use futures_io::{AsyncRead, AsyncWrite};
use sql_orm_core::{FromRow, OrmError};
use sql_orm_query::CompiledQuery;
use std::time::Duration;
use tiberius::{Client, QueryStream};

const BEGIN_TRANSACTION_SQL: &str = "BEGIN TRANSACTION";
const COMMIT_TRANSACTION_SQL: &str = "COMMIT TRANSACTION";
const ROLLBACK_TRANSACTION_SQL: &str = "ROLLBACK TRANSACTION";

/// An explicit SQL Server transaction running on an exclusively borrowed Tiberius [`Client`].
///
/// The transaction is opened with `BEGIN TRANSACTION` when the handle is created. It is ended
/// with `commit` (`COMMIT TRANSACTION`) or `rollback` (`ROLLBACK TRANSACTION`), and both of
/// these consume the handle. While the handle is alive it holds `&mut Client<S>`, so the
/// connection cannot be used for anything else.
///
/// Nothing in this module sends `ROLLBACK` when the handle is dropped, and a synchronous drop
/// cannot await the borrowed client. A handle dropped without `commit`/`rollback` therefore
/// leaves the transaction open on the connection. Always finish it explicitly.
#[must_use = "a transaction must be finished with `commit` or `rollback`; dropping it does not end it on the server"]
pub struct MssqlTransaction<'a, S: AsyncRead + AsyncWrite + Unpin + Send> {
    /// Connection the transaction runs on, exclusively borrowed for the handle's lifetime.
    client: &'a mut Client<S>,
    /// Timeout handed to `run_with_timeout` for every statement issued through the handle,
    /// including `COMMIT`/`ROLLBACK` (`None` presumably means unbounded — confirm in
    /// `run_with_timeout`).
    query_timeout: Option<Duration>,
    /// Tracing configuration forwarded to every statement issued through the handle.
    tracing_options: MssqlTracingOptions,
    /// Slow-query configuration forwarded to the execute/query/fetch helpers.
    slow_query_options: MssqlSlowQueryOptions,
    /// Server address forwarded to the executor and tracing helpers.
    server_addr: String,
    /// Set to `true` once `COMMIT` or `ROLLBACK` has succeeded.
    completed: bool,
}

impl<'a, S> MssqlTransaction<'a, S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    pub(crate) async fn begin(
        client: &'a mut Client<S>,
        query_timeout: Option<Duration>,
        tracing_options: MssqlTracingOptions,
        slow_query_options: MssqlSlowQueryOptions,
        server_addr: String,
    ) -> Result<Self, OrmError> {
        begin_transaction_scope(client, query_timeout, tracing_options, &server_addr).await?;

        Ok(Self {
            client,
            query_timeout,
            tracing_options,
            slow_query_options,
            server_addr,
            completed: false,
        })
    }

    pub fn is_completed(&self) -> bool {
        self.completed
    }

    pub async fn commit(mut self) -> Result<(), OrmError> {
        self.finish(COMMIT_TRANSACTION_SQL).await
    }

    pub async fn rollback(mut self) -> Result<(), OrmError> {
        self.finish(ROLLBACK_TRANSACTION_SQL).await
    }

    pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
            execute_compiled(
                self.client,
                query,
                self.tracing_options,
                self.slow_query_options,
                &self.server_addr,
                self.query_timeout,
            )
            .await
        })
        .await
    }

    pub async fn query_raw<'b>(
        &'b mut self,
        query: CompiledQuery,
    ) -> Result<QueryStream<'b>, OrmError> {
        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
            query_raw_compiled(
                self.client,
                query,
                self.tracing_options,
                self.slow_query_options,
                &self.server_addr,
                self.query_timeout,
            )
            .await
        })
        .await
    }

    pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
    where
        T: FromRow + Send,
    {
        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
            fetch_one_compiled(
                self.client,
                query,
                self.tracing_options,
                self.slow_query_options,
                crate::config::MssqlRetryOptions::disabled(),
                &self.server_addr,
                self.query_timeout,
            )
            .await
        })
        .await
    }

    pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
    where
        T: FromRow + Send,
    {
        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
            fetch_all_compiled(
                self.client,
                query,
                self.tracing_options,
                self.slow_query_options,
                crate::config::MssqlRetryOptions::disabled(),
                &self.server_addr,
                self.query_timeout,
            )
            .await
        })
        .await
    }

    async fn finish(&mut self, sql: &'static str) -> Result<(), OrmError> {
        if self.completed {
            return Err(OrmError::new("transaction has already been completed"));
        }

        run_transaction_command(
            self.client,
            sql,
            self.query_timeout,
            self.tracing_options,
            &self.server_addr,
        )
        .await?;
        self.completed = true;

        Ok(())
    }
}

/// Opens a server-side transaction on `client` by issuing `BEGIN TRANSACTION`.
///
/// `MssqlTransaction::begin` uses this free-function form. It builds no handle, so the caller
/// is responsible for ending the transaction later, for example with `commit_transaction_scope`
/// or `rollback_transaction_scope`.
///
/// # Errors
///
/// Propagates any error or timeout reported by `run_transaction_command`.
pub(crate) async fn begin_transaction_scope<S>(
    client: &mut Client<S>,
    query_timeout: Option<Duration>,
    tracing_options: MssqlTracingOptions,
    server_addr: &str,
) -> Result<(), OrmError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    let statement = BEGIN_TRANSACTION_SQL;
    run_transaction_command(client, statement, query_timeout, tracing_options, server_addr).await
}

/// Commits the transaction currently open on `client` by issuing `COMMIT TRANSACTION`.
///
/// This is the free-function counterpart of `MssqlTransaction::commit`, for callers that
/// manage the transaction scope themselves.
///
/// # Errors
///
/// Propagates any error or timeout reported by `run_transaction_command`.
pub(crate) async fn commit_transaction_scope<S>(
    client: &mut Client<S>,
    query_timeout: Option<Duration>,
    tracing_options: MssqlTracingOptions,
    server_addr: &str,
) -> Result<(), OrmError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    let statement = COMMIT_TRANSACTION_SQL;
    run_transaction_command(client, statement, query_timeout, tracing_options, server_addr).await
}

/// Rolls back the transaction currently open on `client` by issuing `ROLLBACK TRANSACTION`.
///
/// This is the free-function counterpart of `MssqlTransaction::rollback`, for callers that
/// manage the transaction scope themselves.
///
/// # Errors
///
/// Propagates any error or timeout reported by `run_transaction_command`.
pub(crate) async fn rollback_transaction_scope<S>(
    client: &mut Client<S>,
    query_timeout: Option<Duration>,
    tracing_options: MssqlTracingOptions,
    server_addr: &str,
) -> Result<(), OrmError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    let statement = ROLLBACK_TRANSACTION_SQL;
    run_transaction_command(client, statement, query_timeout, tracing_options, server_addr).await
}

pub(crate) async fn run_transaction_command<S>(
    client: &mut Client<S>,
    sql: &'static str,
    query_timeout: Option<Duration>,
    tracing_options: MssqlTracingOptions,
    server_addr: &str,
) -> Result<(), OrmError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    trace_transaction_command(tracing_options, server_addr, query_timeout, sql, async {
        run_with_timeout(query_timeout, "SQL Server query timed out", async {
            client
                .simple_query(sql)
                .await
                .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?
                .into_results()
                .await
                .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?;

            Ok(())
        })
        .await
    })
    .await
}

#[cfg(test)]
mod tests {
    use super::{
        BEGIN_TRANSACTION_SQL, COMMIT_TRANSACTION_SQL, MssqlTransaction, ROLLBACK_TRANSACTION_SQL,
        begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
    };
    use crate::config::{MssqlSlowQueryOptions, MssqlTracingOptions};
    use std::time::Duration;

    /// Stream type used for the type-level checks below.
    type TcpCompat = tokio_util::compat::Compat<tokio::net::TcpStream>;

    /// Compiles only when `T: Copy`.
    fn assert_copy<T: Copy>() {}

    #[test]
    fn transaction_command_constants_match_expected_sql() {
        assert_eq!(BEGIN_TRANSACTION_SQL, "BEGIN TRANSACTION");
        assert_eq!(COMMIT_TRANSACTION_SQL, "COMMIT TRANSACTION");
        assert_eq!(ROLLBACK_TRANSACTION_SQL, "ROLLBACK TRANSACTION");
    }

    #[test]
    fn transaction_wrapper_type_instantiates_for_tcp_streams() {
        // A live transaction needs a SQL Server connection. This test only checks that the
        // wrapper type is well-formed for a TCP-backed stream; it says nothing about
        // completion state.
        let wrapper = core::mem::size_of::<Option<MssqlTransaction<'static, TcpCompat>>>();

        assert!(wrapper > 0);
    }

    #[test]
    fn exposes_scope_level_transaction_helpers() {
        let begin = begin_transaction_scope::<TcpCompat>;
        let commit = commit_transaction_scope::<TcpCompat>;
        let rollback = rollback_transaction_scope::<TcpCompat>;

        let _ = (begin, commit, rollback);
    }

    #[test]
    fn transaction_timeout_shape_is_copyable_for_runtime_use() {
        // `MssqlTransaction` passes these settings by value out of `&mut self` on every
        // statement, which only compiles because they are `Copy`.
        assert_copy::<Option<Duration>>();
        assert_copy::<MssqlTracingOptions>();
        assert_copy::<MssqlSlowQueryOptions>();

        let timeout = Some(Duration::from_secs(1));
        let copied = timeout;
        assert_eq!(timeout, copied);
        assert_eq!(copied, Some(Duration::from_secs(1)));
    }
}