// athena_rs 2.8.0
//
// Database gateway API
// Documentation
//! PostgreSQL backend powered by sqlx.
use crate::client::backend::{
    BackendError, BackendResult, BackendType, DatabaseBackend, HealthStatus, QueryLanguage,
    QueryResult, TranslatedQuery,
};
use async_trait::async_trait;
use serde_json::{Map, Number, Value};
use sqlx::{Column, PgPool, Pool, Postgres, Row, postgres::PgRow};

/// Database backend that talks to PostgreSQL through a sqlx connection pool.
pub struct PostgresBackend {
    /// Pool that the backend's queries and health checks draw connections from.
    pool: PgPool,
}

impl PostgresBackend {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    pub async fn from_connection_string(connection: &str) -> BackendResult<Self> {
        let pool: Pool<Postgres> = PgPool::connect(connection)
            .await
            .map_err(|err| BackendError::Generic(err.to_string()))?;
        Ok(Self { pool })
    }
}

#[async_trait]
impl DatabaseBackend for PostgresBackend {
    async fn execute_query(&self, query: TranslatedQuery) -> BackendResult<QueryResult> {
        if !matches!(query.language, QueryLanguage::Sql) {
            return Err(BackendError::Generic(
                "Postgres backend only supports SQL".to_string(),
            ));
        }

        let rows: Vec<PgRow> = sqlx::query(&query.sql)
            .fetch_all(&self.pool)
            .await
            .map_err(|err| BackendError::Generic(err.to_string()))?;

        let mut data: Vec<Value> = Vec::new();
        let column_names: Vec<String> = rows
            .first()
            .map(|row| {
                row.columns()
                    .iter()
                    .map(|col| col.name().to_string())
                    .collect()
            })
            .unwrap_or_default();

        for row in &rows {
            data.push(row_to_value(row)?);
        }

        Ok(QueryResult::new(data, column_names))
    }

    async fn health_check(&self) -> BackendResult<HealthStatus> {
        match self.pool.acquire().await {
            Ok(_) => Ok(HealthStatus::Healthy),
            Err(err) => Err(BackendError::Generic(err.to_string())),
        }
    }

    fn backend_type(&self) -> BackendType {
        BackendType::PostgreSQL
    }

    fn supports_sql(&self) -> bool {
        true
    }

    fn supports_cql(&self) -> bool {
        false
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

fn row_to_value(row: &PgRow) -> BackendResult<Value> {
    let mut object = Map::new();

    for column in row.columns() {
        let value: Value = match read_column_value(row, column) {
            Ok(value) => value,
            Err(err) => {
                return Err(BackendError::Generic(format!(
                    "failed to decode column {}: {err}",
                    column.name()
                )));
            }
        };
        object.insert(column.name().to_string(), value);
    }

    Ok(Value::Object(object))
}

fn read_column_value(row: &PgRow, column: &impl Column) -> BackendResult<Value> {
    let name = column.name();

    if let Ok(json_value) = row.try_get::<serde_json::Value, _>(name) {
        return Ok(json_value);
    }

    if let Ok(text) = row.try_get::<String, _>(name) {
        return Ok(Value::String(text));
    }

    if let Ok(i) = row.try_get::<i64, _>(name) {
        return Ok(Value::Number(Number::from(i)));
    }

    if let Ok(f) = row.try_get::<f64, _>(name)
        && let Some(number) = Number::from_f64(f)
    {
        return Ok(Value::Number(number));
    }

    if let Ok(b) = row.try_get::<bool, _>(name) {
        return Ok(Value::Bool(b));
    }

    if let Ok(bytes) = row.try_get::<Vec<u8>, _>(name) {
        return Ok(Value::String(String::from_utf8_lossy(&bytes).to_string()));
    }

    Ok(Value::String("<binary>".to_string()))
}