dinoco_engine 0.0.7

Database adapters, query execution, and migration engine components for Dinoco.
Documentation
use async_trait::async_trait;

use std::future::Future;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;

use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};

use tokio_postgres::NoTls;
use tokio_postgres::types::{IsNull, Json, ToSql, Type, private::BytesMut, to_sql_checked};

use crate::{
    ConstraintError, DinocoAdapter, DinocoClientConfig, DinocoError, DinocoQueryLog, DinocoQueryLogger, DinocoResult,
    DinocoRow, DinocoTransactionAdapter, DinocoValue, ExecutionResult,
};

mod dialect;
mod handler;
mod migration;
mod row;

pub use dialect::PostgresDialect;

// Single shared dialect value; `PostgresAdapter::dialect` hands out a reference to it.
static POSTGRES_DIALECT: PostgresDialect = PostgresDialect;
tokio::task_local! {
    // Connection of the transaction currently open on this task, if any.
    // `with_transaction` installs it with `scope`; `execute_result` and `query_as`
    // probe it with `try_with` and, when it is set, run on that connection instead
    // of checking a new one out of the pool. The mutex serialises access to the
    // one connection.
    static POSTGRES_TX_CONNECTION: Arc<tokio::sync::Mutex<deadpool_postgres::Object>>;
}

/// PostgreSQL implementation of the Dinoco adapter, backed by a `deadpool_postgres` pool.
///
/// Clones share the same underlying pool through the `Arc`.
#[derive(Clone)]
pub struct PostgresAdapter {
    /// Connection string the adapter was created from (the `url` given to `connect`).
    pub url: String,
    /// Connection pool used for every statement that is not part of a task-local transaction.
    pub client: Arc<Pool>,
    /// Receives one `DinocoQueryLog` entry per successfully executed statement.
    pub query_logger: DinocoQueryLogger,
}

#[async_trait]
impl DinocoAdapter for PostgresAdapter {
    type Dialect = PostgresDialect;

    fn dialect(&self) -> &Self::Dialect {
        &POSTGRES_DIALECT
    }

    async fn connect(url: String, config: DinocoClientConfig) -> DinocoResult<Self> {
        let pg_config = tokio_postgres::Config::from_str(&url).map_err(|e| DinocoError::from(e))?;

        let mgr = Manager::from_config(pg_config, NoTls, ManagerConfig { recycling_method: RecyclingMethod::Fast });

        let pool = Pool::builder(mgr).max_size(16).build().map_err(|e| DinocoError::from(e))?;

        Ok(Self { url, client: Arc::new(pool), query_logger: config.query_logger })
    }

    async fn execute_result(&self, query: &str, params: &[DinocoValue]) -> DinocoResult<ExecutionResult> {
        if let Ok(tx_connection) = POSTGRES_TX_CONNECTION.try_with(Clone::clone) {
            let connection = tx_connection.lock().await;

            return execute_result_with_connection(&connection, query, params, &self.query_logger).await;
        }

        let connection = self.client.get().await.map_err(|e| DinocoError::from(e))?;

        execute_result_with_connection(&connection, query, params, &self.query_logger).await
    }

    async fn execute_script(&self, sql_content: &str) -> DinocoResult<()> {
        let clean_sql = sql_content.trim();

        if clean_sql.is_empty() {
            return Ok(());
        }

        let client = self.client.get().await.map_err(|e| DinocoError::from(e))?;
        let started_at = Instant::now();

        client.batch_execute(clean_sql).await?;
        self.query_logger.log(DinocoQueryLog {
            adapter: "postgresql",
            duration: started_at.elapsed(),
            params: Vec::new(),
            query: clean_sql.to_string(),
        });

        Ok(())
    }

    async fn query_as<T: DinocoRow>(&self, query: &str, params: &[DinocoValue]) -> DinocoResult<Vec<T>> {
        if let Ok(tx_connection) = POSTGRES_TX_CONNECTION.try_with(Clone::clone) {
            let connection = tx_connection.lock().await;

            return query_as_with_connection::<T>(&connection, query, params, &self.query_logger).await;
        }

        let connection = self.client.get().await.map_err(|e| DinocoError::from(e))?;

        query_as_with_connection::<T>(&connection, query, params, &self.query_logger).await
    }
}

impl DinocoTransactionAdapter for PostgresAdapter {
    /// Runs `operation` inside a database transaction.
    ///
    /// When the current task already has a transaction connection set, `operation` is
    /// simply awaited and joins that transaction. Otherwise a connection is taken from
    /// the pool, `BEGIN` is issued, and the connection is exposed through
    /// `POSTGRES_TX_CONNECTION` while `operation` runs. `COMMIT` is sent when it returns
    /// `Ok`; on `Err` a `ROLLBACK` is attempted and the original error is returned.
    fn with_transaction<'a, T, F>(&'a self, operation: F) -> Pin<Box<dyn Future<Output = DinocoResult<T>> + Send + 'a>>
    where
        T: Send + 'a,
        F: FnOnce() -> Pin<Box<dyn Future<Output = DinocoResult<T>> + Send + 'a>> + Send + 'a,
    {
        Box::pin(async move {
            // Already inside a transaction on this task: take part in it.
            let already_in_transaction = POSTGRES_TX_CONNECTION.try_with(|_| ()).is_ok();
            if already_in_transaction {
                return operation().await;
            }

            let pooled = self.client.get().await.map_err(DinocoError::from)?;
            let shared = Arc::new(tokio::sync::Mutex::new(pooled));

            // The guard is a temporary, so the lock is released at the end of the statement.
            shared.lock().await.batch_execute("BEGIN").await?;

            // `operation` is invoked inside the scoped future so that it runs with the
            // task-local connection set.
            let outcome = POSTGRES_TX_CONNECTION.scope(Arc::clone(&shared), async move { operation().await }).await;

            let guard = shared.lock().await;
            match outcome {
                Ok(value) => {
                    guard.batch_execute("COMMIT").await?;

                    Ok(value)
                }
                Err(failure) => {
                    // Best effort: a failed rollback must not hide the original error.
                    let _ = guard.batch_execute("ROLLBACK").await;

                    Err(failure)
                }
            }
        })
    }
}

/// Executes `query` with `params` on `connection` and returns the affected-row count.
///
/// A log entry is emitted only after the statement has succeeded. `last_insert_id` is
/// always `None`.
async fn execute_result_with_connection(
    connection: &deadpool_postgres::Object,
    query: &str,
    params: &[DinocoValue],
    query_logger: &DinocoQueryLogger,
) -> DinocoResult<ExecutionResult> {
    // Borrow each value as the trait object the driver expects.
    let bound: Vec<&(dyn ToSql + Sync)> = params.iter().map(|value| value as &(dyn ToSql + Sync)).collect();

    let timer = Instant::now();
    let affected_rows = connection.execute(query, &bound).await?;
    let duration = timer.elapsed();

    let entry = DinocoQueryLog { adapter: "postgresql", duration, params: params.to_vec(), query: query.to_string() };
    query_logger.log(entry);

    Ok(ExecutionResult { affected_rows, last_insert_id: None })
}

/// Runs `query` with `params` on `connection` and converts each row with `T::from_row`.
///
/// The first row that fails to convert aborts the call with that error. The logged
/// duration covers both the round trip and the row conversion, and the log entry is
/// emitted only on success.
async fn query_as_with_connection<T: DinocoRow>(
    connection: &deadpool_postgres::Object,
    query: &str,
    params: &[DinocoValue],
    query_logger: &DinocoQueryLogger,
) -> DinocoResult<Vec<T>> {
    // Borrow each value as the trait object the driver expects.
    let bound: Vec<&(dyn ToSql + Sync)> = params.iter().map(|value| value as &(dyn ToSql + Sync)).collect();

    let timer = Instant::now();
    let rows = connection.query(query, &bound).await?;

    // Collecting into `Result` stops at the first conversion error.
    let decoded = rows.iter().map(|row| T::from_row(row)).collect::<Result<Vec<T>, _>>()?;

    let duration = timer.elapsed();
    let entry = DinocoQueryLog { adapter: "postgresql", duration, params: params.to_vec(), query: query.to_string() };
    query_logger.log(entry);

    Ok(decoded)
}

impl ToSql for DinocoValue {
    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
        match self {
            DinocoValue::Null => Ok(IsNull::Yes),
            DinocoValue::Integer(i) => i.to_sql(ty, out),
            DinocoValue::Float(f) => f.to_sql(ty, out),
            DinocoValue::Boolean(b) => b.to_sql(ty, out),
            DinocoValue::String(s) => s.as_str().to_sql(ty, out),
            DinocoValue::Enum(_, s) => s.as_str().to_sql(ty, out),
            DinocoValue::Json(v) => Json(v).to_sql(ty, out),
            DinocoValue::Bytes(v) => v.to_sql(ty, out),
            DinocoValue::DateTime(dt) => dt.to_sql(ty, out),
            DinocoValue::Date(date) => date.to_sql(ty, out),
        }
    }

    fn accepts(_ty: &Type) -> bool {
        true
    }

    to_sql_checked!();
}

impl From<tokio_postgres::Error> for DinocoError {
    fn from(e: tokio_postgres::Error) -> Self {
        if let Some(error) = map_postgres_constraint_error(&e) {
            return Self::Constraint(error);
        }

        Self::Postgres(e)
    }
}

impl From<deadpool_postgres::PoolError> for DinocoError {
    /// Reports a failed pool checkout as a `ConnectionError` carrying the pool's message.
    fn from(source: deadpool_postgres::PoolError) -> Self {
        let description = format!("Failed to get connection from pool: {}", source);

        Self::ConnectionError(description)
    }
}

impl From<deadpool_postgres::BuildError> for DinocoError {
    /// Reports a pool construction failure as a `ConnectionError` carrying the builder's message.
    fn from(source: deadpool_postgres::BuildError) -> Self {
        let description = format!("Failed to build connection pool: {}", source);

        Self::ConnectionError(description)
    }
}

/// Translates a PostgreSQL integrity-constraint error into a `ConstraintError`.
///
/// Returns `None` when `error` carries no database error or when its SQLSTATE is not one
/// of `23505` (unique), `23503` (foreign key), `23502` (not null) or `23514` (check).
fn map_postgres_constraint_error(error: &tokio_postgres::Error) -> Option<ConstraintError> {
    let details = error.as_db_error()?;

    let table = details.table().map(String::from);
    // The server reports at most one column name.
    let columns = match details.column() {
        Some(name) => vec![name.to_string()],
        None => Vec::new(),
    };
    let constraint = details.constraint().map(String::from);
    let message = details.message().to_owned();

    let violation = match details.code().code() {
        "23505" => ConstraintError::unique(table, columns, constraint, message),
        "23503" => ConstraintError::foreign_key(table, columns, constraint, message),
        "23502" => ConstraintError::not_null(table, columns, constraint, message),
        "23514" => ConstraintError::check(table, columns, constraint, message),
        _ => return None,
    };

    Some(violation)
}