// athena_rs 3.3.0
//
// Database gateway API
// Documentation
#[cfg(feature = "deadpool_experimental")]
use deadpool_postgres::Pool;
#[cfg(feature = "deadpool_experimental")]
use serde_json::{Map, Value};
#[cfg(feature = "deadpool_experimental")]
use std::time::Duration;
#[cfg(feature = "deadpool_experimental")]
use tokio::time::timeout;
#[cfg(feature = "deadpool_experimental")]
use tokio_postgres::types::ToSql;

#[cfg(feature = "deadpool_experimental")]
use crate::parser::query_builder::{
    Condition, build_insert_placeholders_for_entries, build_where_clause, sanitize_identifier,
    sanitize_qualified_table_identifier,
};

#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_raw_sql::{DeadpoolError, DeadpoolFallbackReason};

/// Converts a JSON value into a boxed bind parameter for values being written
/// (the `SET` side of an update or the `VALUES` of an insert).
///
/// Mapping performed here:
/// - `null` becomes `Option::<String>::None`.
/// - booleans are bound as `bool`.
/// - numbers are bound as `i64` when they fit, otherwise as `f64`, and as a
///   last resort as their decimal string form.
/// - strings that parse as a UUID are bound as `uuid::Uuid`; all other
///   strings are bound as `String`.
/// - arrays and objects are bound as a cloned `serde_json::Value`
///   (presumably encoded as JSON by the driver; confirm the driver feature).
#[cfg(feature = "deadpool_experimental")]
fn to_sql_param_set(value: &Value) -> Box<dyn ToSql + Sync> {
    match value {
        Value::Null => Box::new(None::<String>),
        Value::Bool(flag) => Box::new(*flag),
        Value::Number(number) => match (number.as_i64(), number.as_f64()) {
            // Integers take priority so whole numbers are not widened to floats.
            (Some(int), _) => Box::new(int),
            (None, Some(float)) => Box::new(float),
            (None, None) => Box::new(number.to_string()),
        },
        Value::String(text) => match uuid::Uuid::parse_str(text) {
            // UUID-looking text is sent as a real UUID rather than as text.
            Ok(id) => Box::new(id),
            Err(_) => Box::new(text.clone()),
        },
        Value::Array(_) | Value::Object(_) => Box::new(value.clone()),
    }
}

/// Converts a JSON value into a boxed bind parameter for `WHERE` comparisons.
///
/// Mapping performed here:
/// - `null` becomes `Option::<String>::None`.
/// - booleans are bound as `bool`.
/// - numbers are bound as `i64` when they fit, otherwise as `f64`, and as a
///   last resort as their decimal string form.
/// - strings are always bound as `String`; no UUID detection is attempted.
/// - arrays and objects are bound as a cloned `serde_json::Value`.
#[cfg(feature = "deadpool_experimental")]
fn to_sql_param_where(value: &Value) -> Box<dyn ToSql + Sync> {
    match value {
        Value::Null => Box::new(None::<String>),
        Value::Bool(flag) => Box::new(*flag),
        Value::Number(number) => match (number.as_i64(), number.as_f64()) {
            // Integers take priority so whole numbers are not widened to floats.
            (Some(int), _) => Box::new(int),
            (None, Some(float)) => Box::new(float),
            (None, None) => Box::new(number.to_string()),
        },
        Value::String(text) => Box::new(text.clone()),
        Value::Array(_) | Value::Object(_) => Box::new(value.clone()),
    }
}

/// Renders an optional ` ORDER BY <column> ASC|DESC` fragment.
///
/// `sort` carries the column name and a flag that is `true` for ascending
/// order. An empty string is returned when no sort is requested or when
/// `sanitize_identifier` rejects the column name.
#[cfg(feature = "deadpool_experimental")]
fn order_by_clause(sort: Option<(&str, bool)>) -> String {
    let (col, ascending) = match sort {
        Some(requested) => requested,
        None => return String::new(),
    };
    let direction = if ascending { "ASC" } else { "DESC" };
    match sanitize_identifier(col) {
        Some(column) => format!(" ORDER BY {} {}", column, direction),
        None => String::new(),
    }
}

/// Fetches rows from `table_name` and returns each row as a JSON value.
///
/// # Parameters
/// - `pool`: deadpool connection pool to check a client out of.
/// - `table_name`: table to read; must pass `sanitize_identifier`.
/// - `columns`: columns to return. An empty slice, or any slice containing
///   `"*"`, selects the whole row via `row_to_json(t.*)`. Otherwise a
///   `jsonb_build_object` is built whose keys are the requested names.
///   Names rejected by `sanitize_identifier` are skipped.
/// - `conditions`: filters handed to `build_where_clause`; their values are
///   sent as bind parameters.
/// - `limit` / `offset`: interpolated into the statement as integers.
/// - `order_by`: optional `(column, ascending)` pair.
/// - `checkout_timeout`: maximum time to wait for a pooled connection.
///
/// # Errors
/// - `CheckoutTimeout` / `CheckoutFailed` when no client could be obtained.
/// - `ExecuteFailed` for an invalid table name, a `build_where_clause`
///   failure, no usable columns, or a query error (with `sql_state` filled in
///   when the database reported one).
/// - `DecodeFailed` when the `data` column cannot be read as JSON.
#[cfg(feature = "deadpool_experimental")]
pub async fn fetch_rows_with_columns_deadpool(
    pool: &Pool,
    table_name: &str,
    columns: &[&str],
    conditions: &[Condition],
    limit: i64,
    offset: i64,
    order_by: Option<(&str, bool)>,
    checkout_timeout: Duration,
) -> Result<Vec<Value>, DeadpoolError> {
    // Outer error: the timeout elapsed. Inner error: the pool itself failed.
    let client = timeout(checkout_timeout, pool.get())
        .await
        .map_err(|_| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutTimeout,
            is_db_error: false,
            message: "deadpool checkout timeout".to_string(),
            sql_state: None,
        })?
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;

    // NOTE(review): insert_row_deadpool accepts schema-qualified names via
    // sanitize_qualified_table_identifier; confirm whether reads should too.
    let table: String = sanitize_identifier(table_name).ok_or_else(|| DeadpoolError {
        reason: DeadpoolFallbackReason::ExecuteFailed,
        is_db_error: false,
        message: "invalid table name".to_string(),
        sql_state: None,
    })?;

    let (where_clause, where_values) =
        build_where_clause(conditions, 1).map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;

    let use_all_columns = columns.is_empty() || columns.contains(&"*");
    let order_clause = order_by_clause(order_by);

    let sql = if use_all_columns {
        format!(
            "SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause}{order_clause} LIMIT {limit} OFFSET {offset}",
            table = table,
            where_clause = where_clause,
            order_clause = order_clause,
            limit = limit,
            offset = offset
        )
    } else {
        // Build each key/column pair from the SAME requested name. Filtering
        // first and zipping against the unfiltered slice would misalign the
        // pairs whenever an invalid name precedes a valid one, and would put
        // the rejected raw name into the SQL text.
        let pairs: Vec<String> = columns
            .iter()
            .filter_map(|requested| {
                sanitize_identifier(requested).map(|sanitized| {
                    // The key is a SQL string literal: double any single
                    // quote so it cannot terminate the literal early.
                    format!("'{}', t.{}", requested.replace('\'', "''"), sanitized)
                })
            })
            .collect();
        if pairs.is_empty() {
            return Err(DeadpoolError {
                reason: DeadpoolFallbackReason::ExecuteFailed,
                is_db_error: false,
                message: "no valid columns".to_string(),
                sql_state: None,
            });
        }
        format!(
            "SELECT jsonb_build_object({columns}) AS data FROM {table} AS t{where_clause}{order_clause} LIMIT {limit} OFFSET {offset}",
            columns = pairs.join(", "),
            table = table,
            where_clause = where_clause,
            order_clause = order_clause,
            limit = limit,
            offset = offset
        )
    };

    // `params` owns the boxed values; `param_refs` is the borrowed view that
    // the driver's query API expects.
    let params: Vec<Box<dyn ToSql + Sync>> = where_values.iter().map(to_sql_param_where).collect();
    let param_refs: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p.as_ref()).collect();

    let rows = client
        .query(sql.as_str(), &param_refs)
        .await
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: err.as_db_error().is_some(),
            message: err.to_string(),
            sql_state: err.as_db_error().map(|db| db.code().code().to_string()),
        })?;

    let mut result = Vec::with_capacity(rows.len());
    for row in rows {
        let data: Value = row.try_get("data").map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::DecodeFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;
        result.push(data);
    }
    Ok(result)
}

/// Updates the rows of `table_name` matching `conditions` and returns the
/// updated rows as JSON values (`RETURNING to_jsonb(t.*)`).
///
/// # Parameters
/// - `pool`: deadpool connection pool to check a client out of.
/// - `table_name`: table to update; must pass `sanitize_identifier`.
/// - `conditions`: filters handed to `build_where_clause`. An empty resulting
///   clause is rejected, so an unfiltered update is never issued.
/// - `payload`: JSON object of column/value pairs. Keys rejected by
///   `sanitize_identifier` are skipped.
/// - `checkout_timeout`: maximum time to wait for a pooled connection.
///
/// # Errors
/// - `CheckoutTimeout` / `CheckoutFailed` when no client could be obtained.
/// - `ExecuteFailed` for an invalid table name, a non-object payload, no
///   usable columns, a `build_where_clause` failure, a missing where clause,
///   or a query error.
/// - `DecodeFailed` when a returned `data` column cannot be read as JSON.
#[cfg(feature = "deadpool_experimental")]
pub async fn update_rows_deadpool(
    pool: &Pool,
    table_name: &str,
    conditions: &[Condition],
    payload: &Value,
    checkout_timeout: Duration,
) -> Result<Vec<Value>, DeadpoolError> {
    let client = match timeout(checkout_timeout, pool.get()).await {
        Err(_elapsed) => {
            return Err(DeadpoolError {
                reason: DeadpoolFallbackReason::CheckoutTimeout,
                is_db_error: false,
                message: "deadpool checkout timeout".to_string(),
                sql_state: None,
            });
        }
        Ok(Err(pool_err)) => {
            return Err(DeadpoolError {
                reason: DeadpoolFallbackReason::CheckoutFailed,
                is_db_error: false,
                message: pool_err.to_string(),
                sql_state: None,
            });
        }
        Ok(Ok(checked_out)) => checked_out,
    };

    // Every pre-execution validation failure shares the same error shape.
    let rejected = |message: String| DeadpoolError {
        reason: DeadpoolFallbackReason::ExecuteFailed,
        is_db_error: false,
        message,
        sql_state: None,
    };

    let target: String = sanitize_identifier(table_name)
        .ok_or_else(|| rejected("invalid table name".to_string()))?;

    let fields = payload
        .as_object()
        .ok_or_else(|| rejected("payload must be object".to_string()))?;

    // Sanitized column name paired with a borrow of the value to assign.
    let assignments: Vec<(String, &Value)> = fields
        .iter()
        .filter_map(|(column, value)| sanitize_identifier(column).map(|safe| (safe, value)))
        .collect();
    if assignments.is_empty() {
        return Err(rejected("no valid columns".to_string()));
    }

    // SET placeholders take $1..$n; the where clause is numbered after them.
    let set_clause = assignments
        .iter()
        .zip(1usize..)
        .map(|((column, _), position)| format!("{} = ${}", column, position))
        .collect::<Vec<String>>()
        .join(", ");

    let (where_clause, where_values) = build_where_clause(conditions, assignments.len() + 1)
        .map_err(|err| rejected(err.to_string()))?;
    if where_clause.is_empty() {
        return Err(rejected("missing where clause".to_string()));
    }

    let sql = format!(
        "UPDATE {table} AS t SET {set_clause}{where_clause} RETURNING to_jsonb(t.*) AS data",
        table = target,
        set_clause = set_clause,
        where_clause = where_clause
    );

    // Bind order must mirror placeholder order: SET values, then WHERE values.
    let params: Vec<Box<dyn ToSql + Sync>> = assignments
        .iter()
        .map(|(_, value)| to_sql_param_set(value))
        .chain(where_values.iter().map(to_sql_param_where))
        .collect();
    let param_refs: Vec<&(dyn ToSql + Sync)> = params.iter().map(|boxed| &**boxed).collect();

    let rows = client
        .query(sql.as_str(), &param_refs)
        .await
        .map_err(|err| {
            let db_error = err.as_db_error();
            DeadpoolError {
                reason: DeadpoolFallbackReason::ExecuteFailed,
                is_db_error: db_error.is_some(),
                message: err.to_string(),
                sql_state: db_error.map(|db| db.code().code().to_string()),
            }
        })?;

    // Stops at the first row whose `data` column fails to decode.
    rows.iter()
        .map(|row| {
            row.try_get::<_, Value>("data").map_err(|err| DeadpoolError {
                reason: DeadpoolFallbackReason::DecodeFailed,
                is_db_error: false,
                message: err.to_string(),
                sql_state: None,
            })
        })
        .collect()
}

/// Inserts one row into `table_name` and returns the stored row as a JSON
/// value (`RETURNING to_jsonb(t.*)`).
///
/// # Parameters
/// - `pool`: deadpool connection pool to check a client out of.
/// - `table_name`: target table; must pass
///   `sanitize_qualified_table_identifier`.
/// - `payload`: JSON object of column/value pairs. Keys rejected by
///   `sanitize_identifier` are skipped.
/// - `checkout_timeout`: maximum time to wait for a pooled connection.
///
/// # Errors
/// - `CheckoutTimeout` / `CheckoutFailed` when no client could be obtained.
/// - `ExecuteFailed` for an invalid table name, a non-object payload, no
///   usable columns, or a failure of the `query_one` call.
/// - `DecodeFailed` when the returned `data` column cannot be read as JSON.
#[cfg(feature = "deadpool_experimental")]
pub async fn insert_row_deadpool(
    pool: &Pool,
    table_name: &str,
    payload: &Value,
    checkout_timeout: Duration,
) -> Result<Value, DeadpoolError> {
    let client = match timeout(checkout_timeout, pool.get()).await {
        Err(_elapsed) => {
            return Err(DeadpoolError {
                reason: DeadpoolFallbackReason::CheckoutTimeout,
                is_db_error: false,
                message: "deadpool checkout timeout".to_string(),
                sql_state: None,
            });
        }
        Ok(Err(pool_err)) => {
            return Err(DeadpoolError {
                reason: DeadpoolFallbackReason::CheckoutFailed,
                is_db_error: false,
                message: pool_err.to_string(),
                sql_state: None,
            });
        }
        Ok(Ok(checked_out)) => checked_out,
    };

    // Every pre-execution validation failure shares the same error shape.
    let rejected = |message: &str| DeadpoolError {
        reason: DeadpoolFallbackReason::ExecuteFailed,
        is_db_error: false,
        message: message.to_string(),
        sql_state: None,
    };

    let target: String = sanitize_qualified_table_identifier(table_name)
        .ok_or_else(|| rejected("invalid table name"))?;

    let fields: &Map<String, Value> = payload
        .as_object()
        .ok_or_else(|| rejected("payload must be object"))?;

    let entries: Vec<(String, Value)> = fields
        .iter()
        .filter_map(|(column, value)| sanitize_identifier(column).map(|safe| (safe, value.clone())))
        .collect();
    if entries.is_empty() {
        return Err(rejected("no valid columns"));
    }

    let column_list = entries
        .iter()
        .map(|(name, _)| name.as_str())
        .collect::<Vec<&str>>()
        .join(", ");
    let (placeholders, bind_values) = build_insert_placeholders_for_entries(&entries);

    let sql = format!(
        "INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) RETURNING to_jsonb(t.*) AS data",
        table = target,
        columns = column_list,
        placeholders = placeholders.join(", ")
    );

    // `params` owns the boxed values; `param_refs` is the borrowed view that
    // the driver's query API expects.
    let params: Vec<Box<dyn ToSql + Sync>> = bind_values
        .iter()
        .map(|bound| to_sql_param_set(*bound))
        .collect();
    let param_refs: Vec<&(dyn ToSql + Sync)> = params.iter().map(|boxed| boxed.as_ref()).collect();

    let inserted = client
        .query_one(sql.as_str(), &param_refs)
        .await
        .map_err(|err| {
            let db_error = err.as_db_error();
            DeadpoolError {
                reason: DeadpoolFallbackReason::ExecuteFailed,
                is_db_error: db_error.is_some(),
                message: err.to_string(),
                sql_state: db_error.map(|db| db.code().code().to_string()),
            }
        })?;

    let decoded: Result<Value, _> = inserted.try_get("data");
    decoded.map_err(|err| DeadpoolError {
        reason: DeadpoolFallbackReason::DecodeFailed,
        is_db_error: false,
        message: err.to_string(),
        sql_state: None,
    })
}