sql-orm-tiberius 0.1.0

Tiberius execution adapter for sql-orm.
Documentation
use crate::config::MssqlConnectionConfig;
use crate::error::{TiberiusErrorContext, map_tiberius_error};
use crate::executor::fetch_one_compiled;
use crate::telemetry::trace_connection;
use crate::transaction::{
    MssqlTransaction, begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
};
use futures_io::{AsyncRead, AsyncWrite};
use sql_orm_core::OrmError;
use sql_orm_query::CompiledQuery;
use std::time::Duration;
use tiberius::Client;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};

/// Default stream type for [`MssqlConnection`]: a Tokio [`TcpStream`] wrapped in
/// the `tokio_util` [`Compat`] adapter, which lets it satisfy the `futures_io`
/// `AsyncRead`/`AsyncWrite` bounds that `MssqlConnection` places on its stream.
pub type TokioConnectionStream = Compat<TcpStream>;

/// A SQL Server connection: a Tiberius [`Client`] paired with the
/// [`MssqlConnectionConfig`] it was built with.
///
/// `S` is the underlying byte stream and defaults to [`TokioConnectionStream`].
pub struct MssqlConnection<S = TokioConnectionStream>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Tiberius client that owns the stream.
    client: Client<S>,
    /// Configuration supplied at construction; read for timeouts, tracing,
    /// slow-query, retry and health-check settings.
    config: MssqlConnectionConfig,
}

impl MssqlConnection<TokioConnectionStream> {
    pub async fn connect(connection_string: &str) -> Result<Self, OrmError> {
        let config = MssqlConnectionConfig::from_connection_string(connection_string)?;
        Self::connect_with_config(config).await
    }

    pub async fn connect_with_config(config: MssqlConnectionConfig) -> Result<Self, OrmError> {
        let tracing_options = config.options().tracing;
        let connect_timeout = config.options().timeouts.connect_timeout;
        let addr = config.addr();
        let trace_addr = addr.clone();
        let tiberius_config = config.tiberius_config().clone();

        let client = trace_connection(tracing_options, &trace_addr, connect_timeout, async {
            run_with_timeout(connect_timeout, "SQL Server connection timed out", async {
                let tcp = TcpStream::connect(addr).await.map_err(|error| {
                    map_tiberius_error(&error.into(), TiberiusErrorContext::ConnectTcp)
                })?;
                tcp.set_nodelay(true).map_err(|error| {
                    map_tiberius_error(&error.into(), TiberiusErrorContext::ConfigureTcp)
                })?;

                Client::connect(tiberius_config, tcp.compat_write())
                    .await
                    .map_err(|error| {
                        map_tiberius_error(&error, TiberiusErrorContext::InitializeClient)
                    })
            })
            .await
        })
        .await?;

        Ok(Self { client, config })
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> MssqlConnection<S> {
    pub fn new(client: Client<S>, config: MssqlConnectionConfig) -> Self {
        Self { client, config }
    }

    pub fn config(&self) -> &MssqlConnectionConfig {
        &self.config
    }

    pub fn client(&self) -> &Client<S> {
        &self.client
    }

    pub fn client_mut(&mut self) -> &mut Client<S> {
        &mut self.client
    }

    pub(crate) fn query_timeout(&self) -> Option<Duration> {
        self.config.options().timeouts.query_timeout
    }

    pub(crate) fn tracing_options(&self) -> crate::config::MssqlTracingOptions {
        self.config.options().tracing
    }

    pub(crate) fn slow_query_options(&self) -> crate::config::MssqlSlowQueryOptions {
        self.config.options().slow_query
    }

    pub(crate) fn retry_options(&self) -> crate::config::MssqlRetryOptions {
        self.config.options().retry
    }

    pub(crate) fn health_options(&self) -> crate::config::MssqlHealthCheckOptions {
        self.config.options().health
    }

    pub(crate) fn server_addr(&self) -> String {
        self.config.addr()
    }

    pub async fn begin_transaction<'a>(&'a mut self) -> Result<MssqlTransaction<'a, S>, OrmError> {
        let query_timeout = self.query_timeout();
        let tracing_options = self.tracing_options();
        let slow_query_options = self.slow_query_options();
        let server_addr = self.server_addr();
        MssqlTransaction::begin(
            self.client_mut(),
            query_timeout,
            tracing_options,
            slow_query_options,
            server_addr,
        )
        .await
    }

    pub async fn begin_transaction_scope(&mut self) -> Result<(), OrmError> {
        let query_timeout = self.query_timeout();
        let tracing_options = self.tracing_options();
        let server_addr = self.server_addr();
        begin_transaction_scope(
            self.client_mut(),
            query_timeout,
            tracing_options,
            &server_addr,
        )
        .await
    }

    pub async fn commit_transaction(&mut self) -> Result<(), OrmError> {
        let query_timeout = self.query_timeout();
        let tracing_options = self.tracing_options();
        let server_addr = self.server_addr();
        commit_transaction_scope(
            self.client_mut(),
            query_timeout,
            tracing_options,
            &server_addr,
        )
        .await
    }

    pub async fn rollback_transaction(&mut self) -> Result<(), OrmError> {
        let query_timeout = self.query_timeout();
        let tracing_options = self.tracing_options();
        let server_addr = self.server_addr();
        rollback_transaction_scope(
            self.client_mut(),
            query_timeout,
            tracing_options,
            &server_addr,
        )
        .await
    }

    pub async fn health_check(&mut self) -> Result<(), OrmError> {
        let tracing_options = self.tracing_options();
        let slow_query_options = self.slow_query_options();
        let retry_options = self.retry_options();
        let server_addr = self.server_addr();
        let health_options = self.health_options();
        let health_timeout = resolve_health_timeout(health_options, self.query_timeout());
        let query = build_health_check_query(health_options);

        let row = run_with_timeout(health_timeout, "SQL Server health check timed out", async {
            fetch_one_compiled::<_, HealthCheckRow>(
                self.client_mut(),
                query,
                tracing_options,
                slow_query_options,
                retry_options,
                &server_addr,
                health_timeout,
            )
            .await
        })
        .await?;

        match row {
            Some(HealthCheckRow { value: 1 }) => Ok(()),
            Some(_) => Err(OrmError::new(
                "SQL Server health check returned an unexpected value",
            )),
            None => Err(OrmError::new(
                "SQL Server health check did not return a row",
            )),
        }
    }

    pub fn into_inner(self) -> Client<S> {
        self.client
    }
}

/// Row decoded from the result of the health-check query.
struct HealthCheckRow {
    /// Integer the health-check query returned; `health_check` treats `1` as healthy.
    value: i32,
}

impl sql_orm_core::FromRow for HealthCheckRow {
    fn from_row<R: sql_orm_core::Row>(row: &R) -> Result<Self, OrmError> {
        Ok(Self {
            value: row.get_required_typed::<i32>("health_check")?,
        })
    }
}

/// Picks the deadline for a health check.
///
/// Returns the health-check options' own `timeout` when it is set, otherwise
/// `query_timeout` (which may itself be `None`, meaning no deadline).
fn resolve_health_timeout(
    health_options: crate::config::MssqlHealthCheckOptions,
    query_timeout: Option<Duration>,
) -> Option<Duration> {
    let dedicated = health_options.timeout;
    dedicated.or(query_timeout)
}

/// Builds the parameterless [`CompiledQuery`] for the configured health-check
/// query, using the SQL text returned by `health_options.query.sql()`.
fn build_health_check_query(
    health_options: crate::config::MssqlHealthCheckOptions,
) -> CompiledQuery {
    let sql = health_options.query.sql().to_string();
    CompiledQuery::new(sql, Vec::new())
}

/// Awaits `future`, optionally bounded by `duration`.
///
/// With `None` the future is simply awaited. With `Some(limit)` it is wrapped
/// in Tokio's `timeout`.
///
/// # Errors
///
/// Returns an [`OrmError`] built from `timeout_message` when the limit elapses
/// before the future completes; otherwise returns the future's own result.
pub(crate) async fn run_with_timeout<F, T>(
    duration: Option<Duration>,
    timeout_message: &'static str,
    future: F,
) -> Result<T, OrmError>
where
    F: core::future::Future<Output = Result<T, OrmError>>,
{
    match duration {
        None => future.await,
        Some(limit) => match timeout(limit, future).await {
            // Finished in time: hand back the inner result untouched.
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(OrmError::new(timeout_message)),
        },
    }
}

#[cfg(test)]
mod tests {
    use super::{build_health_check_query, resolve_health_timeout, run_with_timeout};
    use crate::config::{MssqlHealthCheckOptions, MssqlHealthCheckQuery};
    use std::time::Duration;

    /// A dedicated health-check timeout takes precedence over the query timeout.
    #[test]
    fn health_check_prefers_explicit_health_timeout_over_query_timeout() {
        let dedicated = Duration::from_secs(3);
        let general = Duration::from_secs(30);
        let options = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne)
            .with_timeout(dedicated);

        let resolved = resolve_health_timeout(options, Some(general));

        assert_eq!(resolved, Some(dedicated));
    }

    /// Without a dedicated timeout the query timeout (or its absence) is used.
    #[test]
    fn health_check_falls_back_to_query_timeout_when_no_dedicated_timeout_exists() {
        let options = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);
        let general = Some(Duration::from_secs(30));

        assert_eq!(resolve_health_timeout(options, general), general);
        assert_eq!(resolve_health_timeout(options, None), None);
    }

    /// The `SelectOne` health query compiles to the expected SQL with no parameters.
    #[test]
    fn health_check_builds_expected_compiled_query() {
        let options = MssqlHealthCheckOptions::enabled(MssqlHealthCheckQuery::SelectOne);

        let compiled = build_health_check_query(options);

        assert_eq!(compiled.sql, "SELECT 1 AS [health_check]");
        assert!(compiled.params.is_empty());
    }

    /// With no deadline the wrapped future's result is returned as-is.
    #[tokio::test]
    async fn run_with_timeout_returns_future_result_without_timeout() {
        let ready = async { Ok::<_, sql_orm_core::OrmError>(7) };

        let outcome = run_with_timeout(None, "timeout", ready).await;

        assert_eq!(outcome.unwrap(), 7);
    }

    /// A future that outlives the deadline yields the supplied timeout message.
    #[tokio::test]
    async fn run_with_timeout_fails_when_future_exceeds_deadline() {
        let slow = async {
            tokio::time::sleep(Duration::from_millis(25)).await;
            Ok::<_, sql_orm_core::OrmError>(())
        };

        let outcome = run_with_timeout(
            Some(Duration::from_millis(5)),
            "SQL Server connection timed out",
            slow,
        )
        .await;
        let error = outcome.unwrap_err();

        assert_eq!(error.message(), "SQL Server connection timed out");
    }
}