athena_rs 3.3.0

Database gateway API
Documentation
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::raw_sql::{
    PostgresSqlExecutionMode, PostgresSqlExecutionResult, PostgresSqlExecutionSummary,
    classify_sql_query, normalize_sql_query,
};

#[cfg(feature = "deadpool_experimental")]
use deadpool_postgres::Pool;
#[cfg(feature = "deadpool_experimental")]
use serde_json::{Value, json};
#[cfg(feature = "deadpool_experimental")]
use tokio::time::{Duration, timeout};
#[cfg(feature = "deadpool_experimental")]
use tokio_postgres::SimpleQueryMessage;

/// Stage at which a deadpool-backed SQL execution attempt failed.
///
/// Stored in `DeadpoolError::reason`; `deadpool_fallback_reason_label` maps each variant
/// to a fixed snake_case string. The "fallback" in the name suggests callers switch to
/// another execution path on these failures — NOTE(review): confirm against callers.
#[cfg(feature = "deadpool_experimental")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeadpoolFallbackReason {
    /// No pooled connection was obtained before the checkout timeout elapsed.
    CheckoutTimeout,
    /// The pool returned an error while checking out a connection.
    CheckoutFailed,
    /// Running the statement on the checked-out connection returned an error.
    ExecuteFailed,
    /// Decoding failure. Not constructed by any function in this file.
    DecodeFailed,
}

/// Returns the fixed snake_case label for a [`DeadpoolFallbackReason`].
///
/// Each variant maps to exactly one `&'static str`, so the label can be used wherever a
/// stable string form of the reason is needed.
#[cfg(feature = "deadpool_experimental")]
pub fn deadpool_fallback_reason_label(reason: DeadpoolFallbackReason) -> &'static str {
    use DeadpoolFallbackReason as Reason;

    match reason {
        Reason::CheckoutTimeout => "checkout_timeout",
        Reason::CheckoutFailed => "checkout_failed",
        Reason::ExecuteFailed => "execute_failed",
        Reason::DecodeFailed => "decode_failed",
    }
}

/// Error returned by `execute_postgres_sql_deadpool`, describing which stage failed and,
/// for server-side failures, the Postgres error details.
#[cfg(feature = "deadpool_experimental")]
#[derive(Debug)]
pub struct DeadpoolError {
    /// Stage that failed (connection checkout vs. statement execution).
    pub reason: DeadpoolFallbackReason,
    /// `true` when the underlying driver error carried a Postgres server error
    /// (`as_db_error()` was `Some`); always `false` for checkout failures.
    pub is_db_error: bool,
    /// Human-readable description: the underlying error's `Display` text, or a fixed
    /// message for a checkout timeout.
    pub message: String,
    /// SQLSTATE code reported by the server, when there was a server error; else `None`.
    pub sql_state: Option<String>,
}

/// Executes `query` on a connection checked out from the deadpool `pool`.
///
/// The query is normalized with `normalize_sql_query` and classified with
/// `classify_sql_query`; the resulting [`PostgresSqlExecutionMode`] selects the executor:
/// JSON-wrapped rows, direct row conversion, or a simple-query command run.
///
/// # Errors
///
/// Returns a [`DeadpoolError`] whose `reason` is:
/// - `CheckoutTimeout` if no connection is obtained within `checkout_timeout`;
/// - `CheckoutFailed` if the pool reports an error during checkout;
/// - `ExecuteFailed` if the selected executor returns a driver error. In that case
///   `is_db_error` and `sql_state` are populated from the Postgres server error, if any.
#[cfg(feature = "deadpool_experimental")]
pub async fn execute_postgres_sql_deadpool(
    pool: &Pool,
    query: &str,
    checkout_timeout: Duration,
) -> Result<PostgresSqlExecutionResult, DeadpoolError> {
    let normalized_query = normalize_sql_query(query);
    let mode = classify_sql_query(&normalized_query);

    // Outer error: the timeout elapsed. Inner error: the pool itself failed.
    let client = timeout(checkout_timeout, pool.get())
        .await
        .map_err(|_| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutTimeout,
            is_db_error: false,
            message: "deadpool checkout timeout".to_string(),
            sql_state: None,
        })?
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;

    // Each executor is a distinct future type, so await inside the arm and unify on the
    // shared `Result` type afterwards.
    let outcome = match mode {
        PostgresSqlExecutionMode::JsonRows => {
            execute_json_row_query(&client, &normalized_query).await
        }
        PostgresSqlExecutionMode::DirectRows => {
            execute_direct_row_query(&client, &normalized_query).await
        }
        PostgresSqlExecutionMode::Command => {
            execute_command_query(&client, &normalized_query).await
        }
    };

    // One error mapping for every execution mode; the server error is looked up once.
    outcome.map_err(|err| {
        let db_error = err.as_db_error();
        DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: db_error.is_some(),
            message: err.to_string(),
            sql_state: db_error.map(|db| db.code().code().to_string()),
        }
    })
}

/// Runs `query` inside a CTE and has Postgres serialise each result row with `to_jsonb`,
/// yielding one JSON value per row.
///
/// The summary always reports one statement and zero affected rows.
///
/// # Errors
///
/// Returns the driver error if the wrapped query fails, or if a `row` cell cannot be
/// decoded into a [`Value`].
#[cfg(feature = "deadpool_experimental")]
async fn execute_json_row_query(
    client: &tokio_postgres::Client,
    query: &str,
) -> Result<PostgresSqlExecutionResult, tokio_postgres::Error> {
    // The newline before `)` keeps the closing parenthesis from being swallowed by a
    // trailing `--` line comment in the caller's SQL; it is otherwise just whitespace.
    let wrapped_query = format!(
        "WITH athena_query_result AS ({query}\n) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
    );
    let rows = client.query(wrapped_query.as_str(), &[]).await?;
    // `try_get` rather than `get`: `Row::get` panics on a decode failure, which would
    // abort the request task instead of returning an error to the caller.
    let data = rows
        .iter()
        .map(|row| row.try_get::<_, Value>("row"))
        .collect::<Result<Vec<_>, _>>()?;

    Ok(PostgresSqlExecutionResult {
        summary: PostgresSqlExecutionSummary {
            statement_count: 1,
            rows_affected: 0,
            returned_row_count: data.len(),
        },
        rows: data,
    })
}

/// Runs `query` unchanged (no bind parameters) and turns every returned row into a JSON
/// object via [`row_to_json`].
///
/// The summary always reports one statement and zero affected rows.
#[cfg(feature = "deadpool_experimental")]
async fn execute_direct_row_query(
    client: &tokio_postgres::Client,
    query: &str,
) -> Result<PostgresSqlExecutionResult, tokio_postgres::Error> {
    let rows: Vec<Value> = client
        .query(query, &[])
        .await?
        .into_iter()
        .map(row_to_json)
        .collect();
    let returned_row_count = rows.len();

    Ok(PostgresSqlExecutionResult {
        rows,
        summary: PostgresSqlExecutionSummary {
            statement_count: 1,
            rows_affected: 0,
            returned_row_count,
        },
    })
}

/// Runs `query` with `simple_query` and tallies the `CommandComplete` messages.
///
/// Every `CommandComplete` counts as one statement, and its row count is added to
/// `rows_affected`. All other messages (including row data) are discarded, so the
/// result never carries rows.
#[cfg(feature = "deadpool_experimental")]
async fn execute_command_query(
    client: &tokio_postgres::Client,
    query: &str,
) -> Result<PostgresSqlExecutionResult, tokio_postgres::Error> {
    let (statement_count, rows_affected) = client
        .simple_query(query)
        .await?
        .into_iter()
        .filter_map(|message| match message {
            SimpleQueryMessage::CommandComplete(affected) => Some(affected),
            _ => None,
        })
        .fold((0usize, 0u64), |(statements, total), affected| {
            (statements + 1, total + affected)
        });

    Ok(PostgresSqlExecutionResult {
        rows: Vec::new(),
        summary: PostgresSqlExecutionSummary {
            statement_count,
            rows_affected,
            returned_row_count: 0,
        },
    })
}

/// Converts one result row into a JSON object keyed by column name.
///
/// Each cell is decoded with [`read_column_value`]. If two columns share a name, the
/// later column's value is the one kept.
#[cfg(feature = "deadpool_experimental")]
fn row_to_json(row: tokio_postgres::Row) -> Value {
    let object: serde_json::Map<String, Value> = row
        .columns()
        .iter()
        .enumerate()
        .map(|(position, column)| {
            (
                column.name().to_owned(),
                read_column_value(&row, position),
            )
        })
        .collect();
    Value::Object(object)
}

/// Decodes column `idx` of `row` into a JSON value.
///
/// Rust types are tried in a fixed order:
/// 1. JSON/JSONB;
/// 2. strings;
/// 3. `bool`;
/// 4. `i16`/`i32`/`i64`;
/// 5. `f32`/`f64`;
/// 6. UUID;
/// 7. `chrono` date/time types;
/// 8. `bytea`.
///
/// The first one that decodes successfully wins. UUIDs, dates and times are rendered as
/// strings, and `DateTime<Utc>` values as RFC 3339. SQL `NULL` yields [`Value::Null`]
/// whatever the column type. Any other non-NULL value (e.g. `numeric`, arrays) becomes
/// the string `"<unsupported>"`.
#[cfg(feature = "deadpool_experimental")]
fn read_column_value(row: &tokio_postgres::Row, idx: usize) -> Value {
    /// Accepts every Postgres type without decoding it. Wrapped in `Option`, it reports
    /// only whether the cell is SQL NULL.
    struct AnyType;

    impl<'a> tokio_postgres::types::FromSql<'a> for AnyType {
        fn from_sql(
            _ty: &tokio_postgres::types::Type,
            _raw: &'a [u8],
        ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
            Ok(AnyType)
        }

        fn accepts(_ty: &tokio_postgres::types::Type) -> bool {
            true
        }
    }

    if let Ok(value) = row.try_get::<usize, Option<Value>>(idx) {
        return value.unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<String>>(idx) {
        return value.map(Value::String).unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<bool>>(idx) {
        return value.map(Value::Bool).unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<i16>>(idx) {
        return value.map(|v| json!(v)).unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<i32>>(idx) {
        return value.map(|v| json!(v)).unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<i64>>(idx) {
        return value.map(|v| json!(v)).unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<f32>>(idx) {
        return value.map(|v| json!(v)).unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<f64>>(idx) {
        return value.map(|v| json!(v)).unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<uuid::Uuid>>(idx) {
        return value
            .map(|v| Value::String(v.to_string()))
            .unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<chrono::NaiveDate>>(idx) {
        return value
            .map(|v| Value::String(v.to_string()))
            .unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<chrono::NaiveTime>>(idx) {
        return value
            .map(|v| Value::String(v.to_string()))
            .unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<chrono::NaiveDateTime>>(idx) {
        return value
            .map(|v| Value::String(v.to_string()))
            .unwrap_or(Value::Null);
    }
    if let Ok(value) = row.try_get::<usize, Option<chrono::DateTime<chrono::Utc>>>(idx) {
        return value
            .map(|v| Value::String(v.to_rfc3339()))
            .unwrap_or(Value::Null);
    }
    // NOTE(review): bytes are decoded as lossy UTF-8, so non-text bytea content is
    // irreversibly replaced with U+FFFD. Consider a hex/base64 encoding instead.
    if let Ok(value) = row.try_get::<usize, Option<Vec<u8>>>(idx) {
        return value
            .map(|v| Value::String(String::from_utf8_lossy(&v).into_owned()))
            .unwrap_or(Value::Null);
    }

    // Every typed probe above rejects an unsupported column type before it checks for
    // NULL. Without this check, a NULL `numeric` (for example) would be reported as
    // "<unsupported>" instead of null.
    if matches!(row.try_get::<usize, Option<AnyType>>(idx), Ok(None)) {
        return Value::Null;
    }

    Value::String("<unsupported>".to_string())
}