athena_rs 3.4.7

Database driver
Documentation
//! PostgreSQL backend powered by sqlx.
use crate::client::backend::{
    BackendError, BackendResult, BackendType, DatabaseBackend, HealthStatus, QueryLanguage,
    QueryResult, TranslatedQuery,
};
use async_trait::async_trait;
use serde_json::{Map, Number, Value};
use sqlx::{Column, PgPool, Pool, Postgres, Row, postgres::PgRow};
use std::any::Any;

/// ## `PostgresBackend`
/// [`DatabaseBackend`] implementation that executes SQL against PostgreSQL
/// through a sqlx connection pool.
pub struct PostgresBackend {
    /// Pool used by both query execution and health checks.
    pool: PgPool,
}

impl PostgresBackend {
    /// ## `new`
    /// Create a new PostgreSQL backend.
    ///
    /// # Arguments
    ///
    /// * `pool` - The PostgreSQL pool.
    ///
    /// # Returns
    ///
    /// A `PostgresBackend` instance.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// ## `pool`
    /// Get the PostgreSQL pool.
    ///
    /// # Arguments
    ///
    /// * `self` - The PostgreSQL backend.
    ///
    /// # Returns
    ///
    /// A `&PgPool` containing the PostgreSQL pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// ## `from_connection_string`
    /// Create a new PostgreSQL backend from a connection string.
    ///
    /// # Arguments
    ///
    /// * `connection` - The connection string.
    ///
    /// # Returns
    ///
    /// A `BackendResult` containing the PostgreSQL backend.
    pub async fn from_connection_string(connection: &str) -> BackendResult<Self> {
        let pool: Pool<Postgres> = PgPool::connect(connection)
            .await
            .map_err(|err| BackendError::Generic(err.to_string()))?;
        Ok(Self { pool })
    }
}

/// ## `DatabaseBackend` implementation for PostgresBackend.
#[async_trait]
impl DatabaseBackend for PostgresBackend {
    /// ## `execute_query`
    /// Execute a query on the PostgreSQL backend.
    ///
    /// # Arguments
    ///
    /// * `query` - The translated query to execute.
    ///
    /// # Returns
    ///
    /// A `BackendResult` containing the query result.
    async fn execute_query(&self, query: TranslatedQuery) -> BackendResult<QueryResult> {
        if !matches!(query.language, QueryLanguage::Sql) {
            return Err(BackendError::Generic(
                "Postgres backend only supports SQL".to_string(),
            ));
        }

        let rows: Vec<PgRow> = sqlx::query(&query.sql)
            .fetch_all(&self.pool)
            .await
            .map_err(|err| BackendError::Generic(err.to_string()))?;

        let mut data: Vec<Value> = Vec::new();
        let column_names: Vec<String> = rows
            .first()
            .map(|row| {
                row.columns()
                    .iter()
                    .map(|col| col.name().to_string())
                    .collect()
            })
            .unwrap_or_default();

        for row in &rows {
            data.push(row_to_value(row)?);
        }

        Ok(QueryResult::new(data, column_names, None, None, None))
    }

    /// ## `health_check`
    /// Check the health of the PostgreSQL backend.
    ///
    /// # Arguments
    ///
    /// * `self` - The PostgreSQL backend.
    ///
    /// # Returns
    ///
    /// A `BackendResult` containing the health status.
    async fn health_check(&self) -> BackendResult<HealthStatus> {
        match self.pool.acquire().await {
            Ok(_) => Ok(HealthStatus::Healthy),
            Err(err) => Err(BackendError::Generic(err.to_string())),
        }
    }

    /// ## `backend_type`
    /// Get the backend type.
    ///
    /// # Arguments
    ///
    /// * `self` - The PostgreSQL backend.
    ///
    /// # Returns
    fn backend_type(&self) -> BackendType {
        BackendType::PostgreSQL
    }

    /// ## `supports_sql`
    /// Check if the backend supports SQL.
    ///
    /// # Arguments
    ///
    /// * `self` - The PostgreSQL backend.
    ///
    /// # Returns
    ///
    /// A `bool` indicating if the backend supports SQL.
    fn supports_sql(&self) -> bool {
        true
    }

    /// ## `supports_cql`
    /// Check if the backend supports CQL.
    ///
    /// # Arguments
    ///
    /// * `self` - The PostgreSQL backend.
    ///
    /// # Returns
    ///
    /// A `bool` indicating if the backend supports CQL.
    fn supports_cql(&self) -> bool {
        false
    }

    /// ## `as_any`
    /// Get the backend as any.
    ///
    /// # Arguments
    ///
    /// * `self` - The PostgreSQL backend.
    ///
    /// # Returns
    ///
    /// A `&dyn Any` containing the backend.
    fn as_any(&self) -> &dyn Any {
        self
    }
}

/// ## `row_to_value`
/// Convert one result row into a JSON object keyed by column name.
///
/// # Arguments
///
/// * `row` - The PostgreSQL row to convert.
///
/// # Returns
///
/// A `BackendResult` holding a `Value::Object` with one entry per column.
/// If two columns share a name, the later column's value overwrites the earlier one.
///
/// # Errors
///
/// `BackendError::Generic` naming the first column that fails to decode.
fn row_to_value(row: &PgRow) -> BackendResult<Value> {
    let mut fields: Map<String, Value> = Map::new();

    for col in row.columns() {
        let decoded: Value = read_column_value(row, col).map_err(|err| {
            BackendError::Generic(format!(
                "failed to decode column {}: {err}",
                col.name()
            ))
        })?;
        fields.insert(col.name().to_string(), decoded);
    }

    Ok(Value::Object(fields))
}

/// ## `read_column_value`
/// Decode a single column of a row into a JSON value.
///
/// The column is probed against a fixed sequence of Rust types. The first type
/// that sqlx accepts for the column's PostgreSQL type is used:
///
/// 1. SQL `NULL` (any column type) -> `Value::Null`
/// 2. `serde_json::Value` (JSON / JSONB) -> the decoded JSON
/// 3. `String` (text-like types) -> `Value::String`
/// 4. `i64` / `i32` / `i16` (INT8 / INT4 / INT2) -> `Value::Number`
/// 5. `f64` / `f32` (FLOAT8 / FLOAT4) -> `Value::Number`, or `Value::Null`
///    for NaN / ±Infinity, which JSON cannot represent
/// 6. `bool` -> `Value::Bool`
/// 7. `Vec<u8>` (BYTEA) -> `Value::String`, decoded as lossy UTF-8
///
/// Any other column type (e.g. NUMERIC, timestamps, UUID) yields the
/// placeholder string `"<binary>"`.
///
/// # Arguments
///
/// * `row` - The row to read the column value from.
/// * `column` - The column to read; located in `row` by its ordinal position.
///
/// # Returns
///
/// A `BackendResult` containing the value. The current implementation always
/// returns `Ok`; the `Result` keeps the caller's error path available.
fn read_column_value(row: &PgRow, column: &impl Column) -> BackendResult<Value> {
    // Look the value up by position, not by name. With duplicate column names
    // (e.g. `SELECT a.id, b.id`) a name lookup always resolves to the first match.
    let index: usize = column.ordinal();

    // sqlx skips the type-compatibility check for SQL NULL and a bare `T` then fails
    // to decode it, so NULL must be detected explicitly. `Option<T>` decodes NULL as
    // `None` for any column type, so probing `Option<Value>` first both maps NULL to
    // `Value::Null` and handles JSON/JSONB columns.
    match row.try_get::<Option<Value>, _>(index) {
        Ok(Some(json_value)) => return Ok(json_value),
        Ok(None) => return Ok(Value::Null),
        Err(_) => {}
    }

    if let Ok(text) = row.try_get::<String, _>(index) {
        return Ok(Value::String(text));
    }

    // sqlx's type check requires the exact integer width, so each width is probed.
    if let Ok(i) = row.try_get::<i64, _>(index) {
        return Ok(Value::Number(Number::from(i)));
    }
    if let Ok(i) = row.try_get::<i32, _>(index) {
        return Ok(Value::Number(Number::from(i)));
    }
    if let Ok(i) = row.try_get::<i16, _>(index) {
        return Ok(Value::Number(Number::from(i)));
    }

    // `Number::from_f64` rejects non-finite values; map those to `null`.
    if let Ok(f) = row.try_get::<f64, _>(index) {
        return Ok(Number::from_f64(f).map_or(Value::Null, Value::Number));
    }
    if let Ok(f) = row.try_get::<f32, _>(index) {
        return Ok(Number::from_f64(f64::from(f)).map_or(Value::Null, Value::Number));
    }

    if let Ok(b) = row.try_get::<bool, _>(index) {
        return Ok(Value::Bool(b));
    }

    if let Ok(bytes) = row.try_get::<Vec<u8>, _>(index) {
        return Ok(Value::String(String::from_utf8_lossy(&bytes).into_owned()));
    }

    Ok(Value::String("<binary>".to_string()))
}