sql-orm-tiberius 0.1.0

Tiberius execution adapter for sql-orm.
Documentation
use crate::{MssqlConnection, MssqlConnectionConfig, MssqlPoolOptions, TokioConnectionStream};
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection};
use core::ops::{Deref, DerefMut};
use sql_orm_core::OrmError;
use std::time::Duration;

/// Handle to a bb8-backed pool of SQL Server connections.
///
/// Alongside the pool itself, the handle keeps the connection configuration
/// and the pool options it was built with so they can be read back through
/// the accessor methods.
#[derive(Debug, Clone)]
pub struct MssqlPool {
    // Underlying bb8 pool driven by `MssqlConnectionManager`.
    inner: Pool<MssqlConnectionManager>,
    // Connection configuration exposed through `config()`.
    config: MssqlConnectionConfig,
    // Pool options exposed through `options()` and `max_size()`.
    options: MssqlPoolOptions,
}

impl MssqlPool {
    /// Returns a [`MssqlPoolBuilder`] holding the builder's default options.
    pub fn builder() -> MssqlPoolBuilder {
        Default::default()
    }

    /// Checks a connection out of the pool.
    ///
    /// # Errors
    ///
    /// Any failure reported by the underlying pool is replaced by a fixed
    /// [`OrmError`] message; the original cause is not forwarded.
    pub async fn acquire(&self) -> Result<MssqlPooledConnection<'_>, OrmError> {
        match self.inner.get().await {
            Ok(inner) => Ok(MssqlPooledConnection { inner }),
            Err(_) => Err(OrmError::new("failed to acquire SQL Server pooled connection")),
        }
    }

    /// Maximum pool size recorded in the options this pool was built with.
    pub fn max_size(&self) -> u32 {
        self.options().max_size
    }

    /// Pool options this pool was built with.
    pub fn options(&self) -> &MssqlPoolOptions {
        let Self { options, .. } = self;
        options
    }

    /// Connection configuration this pool was built with.
    pub fn config(&self) -> &MssqlConnectionConfig {
        let Self { config, .. } = self;
        config
    }
}

/// Connection checked out of an [`MssqlPool`].
///
/// Wraps the bb8 `PooledConnection` guard and dereferences to
/// `MssqlConnection<TokioConnectionStream>`, so it can be used wherever a
/// plain connection reference is expected.
pub struct MssqlPooledConnection<'a> {
    // bb8 guard borrowed from the pool for lifetime `'a`.
    inner: PooledConnection<'a, MssqlConnectionManager>,
}

/// Shared access to the pooled connection behind the guard.
impl Deref for MssqlPooledConnection<'_> {
    type Target = MssqlConnection<TokioConnectionStream>;

    fn deref(&self) -> &Self::Target {
        // Explicit reborrow through the bb8 guard's own `Deref` impl.
        &*self.inner
    }
}

/// Exclusive access to the pooled connection behind the guard.
impl DerefMut for MssqlPooledConnection<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Explicit reborrow through the bb8 guard's own `DerefMut` impl.
        &mut *self.inner
    }
}

/// Builder that accumulates [`MssqlPoolOptions`] before an [`MssqlPool`] is
/// created with `connect` or `connect_with_config`.
#[derive(Debug, Clone)]
pub struct MssqlPoolBuilder {
    // Options collected so far; every setter also flips `enabled` to true.
    options: MssqlPoolOptions,
}

impl Default for MssqlPoolBuilder {
    /// Starts from the options produced by `MssqlPoolOptions::bb8(10)`.
    fn default() -> Self {
        let options = MssqlPoolOptions::bb8(10);
        Self { options }
    }
}

impl MssqlPoolBuilder {
    /// Applies `change` to the pending options, then marks pooling as enabled.
    ///
    /// Every public setter funnels through here so the "enabled" flag is set
    /// after the requested change, exactly once, in a single place.
    fn enabled_with(mut self, change: impl FnOnce(&mut MssqlPoolOptions)) -> Self {
        change(&mut self.options);
        self.options.enabled = true;
        self
    }

    /// Sets the maximum pool size and enables pooling.
    pub fn max_size(self, max_size: u32) -> Self {
        self.enabled_with(|pending| pending.max_size = max_size)
    }

    /// Sets the minimum number of idle connections and enables pooling.
    pub fn min_idle(self, min_idle: u32) -> Self {
        self.enabled_with(|pending| pending.min_idle = Some(min_idle))
    }

    /// Sets the acquire timeout and enables pooling.
    pub fn acquire_timeout(self, timeout: Duration) -> Self {
        self.enabled_with(|pending| pending.acquire_timeout = Some(timeout))
    }

    /// Sets the idle timeout and enables pooling.
    pub fn idle_timeout(self, timeout: Duration) -> Self {
        self.enabled_with(|pending| pending.idle_timeout = Some(timeout))
    }

    /// Sets the maximum connection lifetime and enables pooling.
    pub fn max_lifetime(self, timeout: Duration) -> Self {
        self.enabled_with(|pending| pending.max_lifetime = Some(timeout))
    }

    /// Replaces all pending options with `options`, forcing `enabled` to true.
    pub fn with_pool_options(self, options: MssqlPoolOptions) -> Self {
        self.enabled_with(|pending| *pending = options)
    }

    /// Returns a copy of the options collected so far.
    pub fn options(&self) -> MssqlPoolOptions {
        self.options
    }

    /// Parses `connection_string` into a configuration and builds the pool.
    ///
    /// # Errors
    ///
    /// Returns the error from parsing the connection string, or the error
    /// produced while building the pool.
    pub async fn connect(self, connection_string: &str) -> Result<MssqlPool, OrmError> {
        let parsed = MssqlConnectionConfig::from_connection_string(connection_string)?;
        build_pool(parsed, self.options).await
    }

    /// Builds the pool from an already prepared connection configuration.
    ///
    /// # Errors
    ///
    /// Returns the error produced while building the pool.
    pub async fn connect_with_config(
        self,
        config: MssqlConnectionConfig,
    ) -> Result<MssqlPool, OrmError> {
        let Self { options } = self;
        build_pool(config, options).await
    }
}

/// bb8 connection manager that opens SQL Server connections from a stored
/// [`MssqlConnectionConfig`].
#[derive(Debug, Clone)]
pub struct MssqlConnectionManager {
    // Configuration cloned for every new connection the pool requests.
    config: MssqlConnectionConfig,
}

impl MssqlConnectionManager {
    fn new(config: MssqlConnectionConfig) -> Self {
        Self { config }
    }
}

#[async_trait]
impl ManageConnection for MssqlConnectionManager {
    type Connection = MssqlConnection<TokioConnectionStream>;
    type Error = OrmError;

    /// Opens a new connection from a clone of the stored configuration.
    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
        let config = self.config.clone();
        MssqlConnection::connect_with_config(config).await
    }

    /// Validates a connection by running its `health_check`.
    async fn is_valid(&self, connection: &mut Self::Connection) -> Result<(), Self::Error> {
        connection.health_check().await
    }

    /// Always answers `false`: this hook never flags a connection as broken,
    /// leaving `is_valid` as the only check this manager performs.
    fn has_broken(&self, _: &mut Self::Connection) -> bool {
        false
    }
}

async fn build_pool(
    config: MssqlConnectionConfig,
    options: MssqlPoolOptions,
) -> Result<MssqlPool, OrmError> {
    let operational_options = config.options().clone().with_pool(options);
    let config = config.with_options(operational_options);
    let manager = MssqlConnectionManager::new(config.clone());
    let mut builder = Pool::builder().max_size(options.max_size);

    if let Some(min_idle) = options.min_idle {
        builder = builder.min_idle(Some(min_idle));
    }

    if let Some(timeout) = options.acquire_timeout {
        builder = builder.connection_timeout(timeout);
    }

    if let Some(timeout) = options.idle_timeout {
        builder = builder.idle_timeout(Some(timeout));
    }

    if let Some(timeout) = options.max_lifetime {
        builder = builder.max_lifetime(Some(timeout));
    }

    let inner = builder
        .build(manager)
        .await
        .map_err(|_| OrmError::new("failed to create SQL Server connection pool"))?;

    Ok(MssqlPool {
        inner,
        config,
        options,
    })
}

#[cfg(test)]
mod tests {
    use super::MssqlPool;
    use crate::{MssqlPoolBackend, MssqlPoolOptions};
    use std::time::Duration;

    /// A fresh builder reports the bb8 backend, a max size of 10, and is enabled.
    #[test]
    fn builder_starts_with_bb8_defaults() {
        let defaults = MssqlPool::builder().options();

        assert_eq!(defaults.backend, MssqlPoolBackend::Bb8);
        assert_eq!(defaults.max_size, 10);
        assert!(defaults.enabled);
    }

    /// Each individual setter is reflected in the options read back afterwards.
    #[test]
    fn builder_preserves_explicit_pool_options() {
        let configured = MssqlPool::builder()
            .max_size(16)
            .min_idle(4)
            .acquire_timeout(Duration::from_secs(2))
            .idle_timeout(Duration::from_secs(30))
            .max_lifetime(Duration::from_secs(300))
            .options();

        assert_eq!(configured.backend, MssqlPoolBackend::Bb8);
        assert_eq!(configured.max_size, 16);
        assert_eq!(configured.min_idle, Some(4));
        assert_eq!(configured.acquire_timeout, Some(Duration::from_secs(2)));
        assert_eq!(configured.idle_timeout, Some(Duration::from_secs(30)));
        assert_eq!(configured.max_lifetime, Some(Duration::from_secs(300)));
    }

    /// Replacing the options wholesale yields exactly the supplied options.
    #[test]
    fn builder_can_replace_options_wholesale() {
        let replacement = MssqlPoolOptions::bb8(12).with_min_idle(3);
        let applied = MssqlPool::builder()
            .with_pool_options(replacement)
            .options();

        assert_eq!(applied, replacement);
    }
}