// sqlw-backend 0.1.0
//
// Database executor implementations for sqlw.
// Documentation
//! PostgreSQL executor implementation for sqlw using bb8 connection pool.

use std::future::Future;
use std::sync::Arc;

use crate::DBError;
use bb8_postgres::PostgresConnectionManager;
use sqlw::{FromRow, Query, RowCell, RowError, RowLike, Value};
use tokio_postgres::NoTls;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::types::ToSql;

/// A borrowed row from a PostgreSQL query result.
///
/// This is a thin wrapper around a `&tokio_postgres::Row`; copying it only
/// copies the reference.
#[derive(Debug, Clone, Copy)]
pub struct PostgresRowRef<'a> {
    row: &'a tokio_postgres::Row,
}

impl<'a> PostgresRowRef<'a> {
    /// Creates a new row reference from a `tokio_postgres` row.
    pub fn new(row: &'a tokio_postgres::Row) -> Self {
        Self { row }
    }

    fn column_index(&self, name: &str) -> Result<usize, RowError> {
        self.row
            .columns()
            .iter()
            .position(|col| col.name() == name)
            .ok_or_else(|| RowError::ColumnNotFound {
                name: name.to_string(),
            })
    }

    fn get_value_by_index(&self, index: usize) -> Result<Value, RowError> {
        let col = &self.row.columns()[index];
        let ty = col.type_();

        let null_check: Option<i64> = self
            .row
            .try_get(index)
            .map_err(|e| RowError::Any(e.to_string()))?;

        if null_check.is_none() {
            return Ok(Value::Null);
        }

        match ty {
            &tokio_postgres::types::Type::TEXT
            | &tokio_postgres::types::Type::VARCHAR
            | &tokio_postgres::types::Type::BPCHAR
            | &tokio_postgres::types::Type::NAME => {
                let val: String = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Text(val))
            }
            &tokio_postgres::types::Type::INT2 => {
                let val: i16 = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Int(val as i64))
            }
            &tokio_postgres::types::Type::INT4 => {
                let val: i32 = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Int(val as i64))
            }
            &tokio_postgres::types::Type::INT8 => {
                let val: i64 = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Int(val))
            }
            &tokio_postgres::types::Type::FLOAT4 => {
                let val: f32 = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Float(val as f64))
            }
            &tokio_postgres::types::Type::FLOAT8 => {
                let val: f64 = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Float(val))
            }
            &tokio_postgres::types::Type::BOOL => {
                let val: bool = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Bool(val))
            }
            &tokio_postgres::types::Type::BYTEA => {
                let val: Vec<u8> = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Blob(val))
            }
            _ => {
                let val: String = self
                    .row
                    .try_get(index)
                    .map_err(|e| RowError::Any(e.to_string()))?;
                Ok(Value::Text(val))
            }
        }
    }
}

impl<'a> RowLike for PostgresRowRef<'a> {
    /// Resolves the column called `name` and hands back its decoded value as an
    /// owned cell.
    ///
    /// Fails with the lookup error if the column is missing, or with the
    /// decoding error if its value cannot be converted.
    fn cell<'b>(&'b self, name: &str) -> Result<RowCell<'b>, RowError> {
        self.column_index(name)
            .and_then(|position| self.get_value_by_index(position))
            .map(RowCell::Owned)
    }
}

/// An async PostgreSQL executor using `bb8` connection pool with `tokio-postgres`.
///
/// Provides connection-pooled async database operations via the [`QueryExecutor`](sqlw::QueryExecutor) trait.
/// This is the recommended PostgreSQL implementation for production applications.
///
/// Cloning an executor is cheap: clones share the same underlying pool.
pub struct PostgresExecutor<Tls = NoTls>
where
    Tls: MakeTlsConnect<tokio_postgres::Socket> + Send + Sync + Clone + 'static,
    Tls::Stream: Send + Sync,
    Tls::TlsConnect: Send,
    <Tls::TlsConnect as tokio_postgres::tls::TlsConnect<tokio_postgres::Socket>>::Future: Send,
{
    /// Connection pool for PostgreSQL, shared between clones of the executor.
    pool: Arc<bb8::Pool<PostgresConnectionManager<Tls>>>,
}

impl<Tls> Clone for PostgresExecutor<Tls>
where
    Tls: MakeTlsConnect<tokio_postgres::Socket> + Send + Sync + Clone + 'static,
    Tls::Stream: Send + Sync,
    Tls::TlsConnect: Send,
    <Tls::TlsConnect as tokio_postgres::tls::TlsConnect<tokio_postgres::Socket>>::Future: Send,
{
    /// Returns another handle to the same pool; only the `Arc` reference count
    /// is incremented.
    fn clone(&self) -> Self {
        Self {
            pool: Arc::clone(&self.pool),
        }
    }
}

impl PostgresExecutor<NoTls> {
    /// Creates a new `PostgresExecutor` with the given connection string.
    ///
    /// Uses [`NoTls`] (unencrypted connections) and bb8's default pool settings.
    /// `connection_string` may be a key/value string
    /// (`host=localhost user=postgres ...`) or a `postgresql://` URL; both are
    /// parsed by `tokio_postgres::Config`.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Connection`] if the connection string cannot be parsed
    /// or the pool cannot be built.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use sqlw_backend::postgres::PostgresExecutor;
    ///
    /// let executor = PostgresExecutor::new(
    ///     "host=localhost user=postgres password=secret dbname=mydb"
    /// ).await?;
    /// ```
    pub async fn new(connection_string: &str) -> Result<Self, DBError> {
        Self::with_tls(connection_string, NoTls).await
    }

    /// Creates a new executor from a connection URL with default pool settings.
    ///
    /// Equivalent to [`PostgresExecutor::new`]: connections are unencrypted
    /// ([`NoTls`]). The key/value connection-string format is accepted as well.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Connection`] if the URL cannot be parsed or the pool
    /// cannot be built.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use sqlw_backend::postgres::PostgresExecutor;
    ///
    /// let executor = PostgresExecutor::from_url(
    ///     "postgresql://postgres:secret@localhost/mydb"
    /// ).await?;
    /// ```
    pub async fn from_url(connection_string: &str) -> Result<Self, DBError> {
        Self::new(connection_string).await
    }
}

impl<Tls> PostgresExecutor<Tls>
where
    Tls: MakeTlsConnect<tokio_postgres::Socket> + Send + Sync + Clone + 'static,
    Tls::Stream: Send + Sync,
    Tls::TlsConnect: Send,
    <Tls::TlsConnect as tokio_postgres::tls::TlsConnect<tokio_postgres::Socket>>::Future: Send,
{
    /// Builds an executor that connects through the supplied TLS connector,
    /// leaving every pool setting at bb8's defaults.
    ///
    /// # Errors
    ///
    /// Fails with [`DBError::Connection`] when the connection string is invalid
    /// or the pool cannot be created.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use sqlw_backend::postgres::PostgresExecutor;
    /// use tokio_postgres::NoTls;
    ///
    /// let executor = PostgresExecutor::with_tls(
    ///     "host=localhost user=postgres password=secret dbname=mydb",
    ///     NoTls
    /// ).await?;
    /// ```
    pub async fn with_tls(connection_string: &str, tls: Tls) -> Result<Self, DBError> {
        // An identity configurator keeps the builder untouched.
        Self::with_config(connection_string, tls, std::convert::identity).await
    }

    /// Builds an executor with a TLS connector and a caller-tuned pool.
    ///
    /// `config_fn` receives a fresh `bb8::Builder` and returns the builder to
    /// use; it runs only after the connection string has been parsed
    /// successfully.
    ///
    /// # Errors
    ///
    /// Fails with [`DBError::Connection`] when the connection string is invalid
    /// or the configured pool cannot be created.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use sqlw_backend::postgres::PostgresExecutor;
    /// use tokio_postgres::NoTls;
    ///
    /// let executor = PostgresExecutor::with_config(
    ///     "host=localhost user=postgres password=secret dbname=mydb",
    ///     NoTls,
    ///     |builder| {
    ///         builder
    ///             .max_size(15)
    ///             .min_idle(Some(5))
    ///             .connection_timeout(std::time::Duration::from_secs(30))
    ///             .idle_timeout(Some(std::time::Duration::from_secs(600)))
    ///             .max_lifetime(Some(std::time::Duration::from_secs(1800)))
    ///     }
    /// ).await?;
    /// ```
    pub async fn with_config<F>(
        connection_string: &str,
        tls: Tls,
        config_fn: F,
    ) -> Result<Self, DBError>
    where
        F: FnOnce(
            bb8::Builder<PostgresConnectionManager<Tls>>,
        ) -> bb8::Builder<PostgresConnectionManager<Tls>>,
    {
        let connection_manager =
            PostgresConnectionManager::new_from_stringlike(connection_string, tls)
                .map_err(|err| DBError::Connection(err.into()))?;

        let shared_pool = config_fn(bb8::Pool::builder())
            .build(connection_manager)
            .await
            .map(Arc::new)
            .map_err(|err| DBError::Connection(err.into()))?;

        Ok(Self { pool: shared_pool })
    }

    /// Borrows the bb8 pool backing this executor.
    pub fn pool(&self) -> &bb8::Pool<PostgresConnectionManager<Tls>> {
        Arc::as_ref(&self.pool)
    }
}

impl<Tls> sqlw::QueryExecutor for PostgresExecutor<Tls>
where
    Tls: MakeTlsConnect<tokio_postgres::Socket> + Send + Sync + Clone + 'static,
    Tls::Stream: Send + Sync,
    Tls::TlsConnect: Send,
    <Tls::TlsConnect as tokio_postgres::tls::TlsConnect<tokio_postgres::Socket>>::Future: Send,
{
    type Error = DBError;

    /// Executes `query`, discarding the affected-row count.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Execution`] if no pooled connection is available or
    /// preparing/executing the statement fails.
    fn query_void(&self, query: Query) -> impl Future<Output = Result<(), DBError>> {
        let pool = Arc::clone(&self.pool);
        async move {
            let (sql, args) = query.split();

            let conn = pool.get().await.map_err(|e| DBError::Execution(e.into()))?;

            // Preparing explicitly exposes the placeholder types the server
            // inferred, so each value can be bound with a Rust type it accepts.
            let stmt = conn
                .prepare(&sql)
                .await
                .map_err(|e| DBError::Execution(e.into()))?;
            let params_owned = to_postgres_params(args, stmt.params());
            let params = param_refs(&params_owned);

            conn.execute(&stmt, &params)
                .await
                .map(|_| ())
                .map_err(|e| DBError::Execution(e.into()))
        }
    }

    /// Runs `query` and decodes the single result row, if any, into `T`.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Execution`] if no pooled connection is available,
    /// preparing/executing fails, or the query yields more than one row.
    /// It also propagates any error from `T::from_row`.
    fn query_one<T: FromRow + Send + 'static>(
        &self,
        query: Query,
    ) -> impl Future<Output = Result<Option<T>, DBError>> {
        let pool = Arc::clone(&self.pool);
        async move {
            let (sql, args) = query.split();

            let conn = pool.get().await.map_err(|e| DBError::Execution(e.into()))?;

            let stmt = conn
                .prepare(&sql)
                .await
                .map_err(|e| DBError::Execution(e.into()))?;
            let params_owned = to_postgres_params(args, stmt.params());
            let params = param_refs(&params_owned);

            let row = conn
                .query_opt(&stmt, &params)
                .await
                .map_err(|e| DBError::Execution(e.into()))?;

            row.as_ref()
                .map(|row| T::from_row(&PostgresRowRef::new(row)))
                .transpose()
        }
    }

    /// Runs `query` and decodes every result row into `T`, preserving order.
    ///
    /// # Errors
    ///
    /// Returns [`DBError::Execution`] if no pooled connection is available or
    /// preparing/executing fails. It stops at, and propagates, the first
    /// `T::from_row` error.
    fn query_list<T: FromRow + Send + 'static>(
        &self,
        query: Query,
    ) -> impl Future<Output = Result<Vec<T>, DBError>> {
        let pool = Arc::clone(&self.pool);
        async move {
            let (sql, args) = query.split();

            let conn = pool.get().await.map_err(|e| DBError::Execution(e.into()))?;

            let stmt = conn
                .prepare(&sql)
                .await
                .map_err(|e| DBError::Execution(e.into()))?;
            let params_owned = to_postgres_params(args, stmt.params());
            let params = param_refs(&params_owned);

            let rows = conn
                .query(&stmt, &params)
                .await
                .map_err(|e| DBError::Execution(e.into()))?;

            let results = rows
                .iter()
                .map(|row| T::from_row(&PostgresRowRef::new(row)))
                .collect::<Result<Vec<T>, _>>()?;

            Ok(results)
        }
    }
}

/// An owned query parameter.
///
/// It is `Send` so that futures holding it across `.await` stay `Send`, which
/// lets them run on a multi-threaded runtime.
type Param = Box<dyn ToSql + Sync + Send>;

/// Borrows owned parameters as the `&(dyn ToSql + Sync)` element type that
/// `tokio_postgres` query methods take.
fn param_refs(owned: &[Param]) -> Vec<&(dyn ToSql + Sync)> {
    owned.iter().map(|p| &**p as &(dyn ToSql + Sync)).collect()
}

/// Converts sqlw [`Value`]s into owned `tokio_postgres` parameters.
///
/// `types` are the placeholder types reported by the prepared statement.
/// tokio-postgres rejects a parameter before sending it if its Rust type does
/// not accept the placeholder's Postgres type. So values are adapted:
/// - integers are narrowed to `i16`/`i32`/`u32` for `int2`/`int4`/`oid`;
/// - floats are narrowed to `f32` for `float4`;
/// - NULL is sent as `None` of a Rust type the placeholder accepts.
///
/// Some values have no better match: unhandled types, integers outside the
/// target range, or values beyond `types.len()`. Those keep their natural Rust
/// type, and the driver reports the mismatch, including parameter-count errors.
fn to_postgres_params(args: Vec<Value>, types: &[tokio_postgres::types::Type]) -> Vec<Param> {
    args.into_iter()
        .enumerate()
        .map(|(i, value)| to_postgres_param(value, types.get(i)))
        .collect()
}

/// Converts one value for a placeholder of type `ty` (`None` when unknown).
fn to_postgres_param(value: Value, ty: Option<&tokio_postgres::types::Type>) -> Param {
    use tokio_postgres::types::Type;

    match value {
        Value::Text(s) => Box::new(s) as Param,
        Value::Int(i) => int_param(i, ty),
        Value::Float(f) => match ty {
            // A `float4` placeholder only accepts `f32`; the column cannot hold
            // more precision than that anyway.
            Some(&Type::FLOAT4) => Box::new(f as f32) as Param,
            _ => Box::new(f) as Param,
        },
        Value::Bool(b) => Box::new(b) as Param,
        Value::Blob(b) => Box::new(b) as Param,
        Value::Null => null_param(ty),
    }
}

/// Binds an integer, narrowing it when the placeholder is a smaller integer type
/// and the value fits. Otherwise it is bound as `i64`.
fn int_param(i: i64, ty: Option<&tokio_postgres::types::Type>) -> Param {
    use tokio_postgres::types::Type;

    let narrowed = match ty {
        Some(&Type::INT2) => i16::try_from(i).ok().map(|v| Box::new(v) as Param),
        Some(&Type::INT4) => i32::try_from(i).ok().map(|v| Box::new(v) as Param),
        Some(&Type::OID) => u32::try_from(i).ok().map(|v| Box::new(v) as Param),
        _ => None,
    };
    narrowed.unwrap_or_else(|| Box::new(i) as Param)
}

/// Binds SQL NULL as `None` of a Rust type that the placeholder's type accepts.
///
/// `Option<T>` only accepts the Postgres types `T` accepts, so a single fixed
/// NULL type would be rejected for most placeholders.
fn null_param(ty: Option<&tokio_postgres::types::Type>) -> Param {
    use tokio_postgres::types::Type;

    match ty {
        Some(&Type::BOOL) => Box::new(None::<bool>) as Param,
        Some(&Type::INT2) => Box::new(None::<i16>) as Param,
        Some(&Type::INT4) => Box::new(None::<i32>) as Param,
        Some(&Type::OID) => Box::new(None::<u32>) as Param,
        Some(&Type::FLOAT4) => Box::new(None::<f32>) as Param,
        Some(&Type::FLOAT8) => Box::new(None::<f64>) as Param,
        Some(&Type::BYTEA) => Box::new(None::<Vec<u8>>) as Param,
        Some(&Type::TEXT | &Type::VARCHAR | &Type::BPCHAR | &Type::NAME | &Type::UNKNOWN) => {
            Box::new(None::<String>) as Param
        }
        // `int8`, plus anything not listed above, where no better choice is known.
        _ => Box::new(None::<i64>) as Param,
    }
}