//! athena_rs 0.75.4
//!
//! WIP database API gateway — Postgres client registry plus generic
//! JSON-in/JSON-out row helpers (insert, update, select).
use crate::parser::query_builder::{
    Condition, build_insert_placeholders, build_where_clause, sanitize_identifier,
};
use anyhow::{Context, Result, anyhow};
use futures::future::join_all;
use serde_json::{Map, Value};

use sqlx::Error as SqlxError;
use sqlx::Row;
use sqlx::postgres::{PgArguments, PgPool, PgPoolOptions, PgRow};
use sqlx::query::Query;
use sqlx::types::Json;
use std::collections::HashMap;
use std::convert::TryFrom;
use tracing::info;
use uuid::Uuid;

/// Registry of named Postgres connection pools, one per configured client.
pub struct PostgresClientRegistry {
    // Keyed by client name; populated by `from_entries`.
    pools: HashMap<String, PgPool>,
}

impl PostgresClientRegistry {
    pub fn empty() -> Self {
        Self {
            pools: HashMap::new(),
        }
    }

    pub fn is_empty(&self) -> bool {
        self.pools.is_empty()
    }

    pub async fn from_entries(
        entries: Vec<(String, String)>,
    ) -> Result<(Self, Vec<(String, anyhow::Error)>)> {
        let connect_tasks = entries.into_iter().map(|(client_name, uri)| async move {
            tracing::info!(client = %client_name, uri = %uri, "connecting to Postgres client");
            match PgPoolOptions::new()
                .max_connections(50) // Increased from 10 to 50
                .min_connections(5) // Maintain minimum pool
                .acquire_timeout(std::time::Duration::from_secs(3))
                .idle_timeout(std::time::Duration::from_secs(300))
                .max_lifetime(std::time::Duration::from_secs(1800))
                .test_before_acquire(false) // Skip test for better performance
                .connect(&uri)
                .await
            {
                Ok(pool) => {
                    tracing::info!(client = %client_name, "connected to Postgres client");
                    Ok((client_name, pool))
                }
                Err(err) => {
                    let context_error = anyhow!(
                        "failed to connect to postgres client {}: {}",
                        client_name,
                        err
                    );
                    tracing::error!(
                        client = %client_name,
                        uri = %uri,
                        error = %err,
                        "failed to connect to Postgres client"
                    );
                    Err((client_name, context_error))
                }
            }
        });

        let mut pools: HashMap<String, PgPool> = HashMap::new();
        let mut errors: Vec<(String, anyhow::Error)> = Vec::new();

        for result in join_all(connect_tasks).await {
            match result {
                Ok((client_name, pool)) => {
                    pools.insert(client_name, pool);
                }
                Err((client_name, err)) => {
                    errors.push((client_name, err));
                }
            }
        }

        Ok((Self { pools }, errors))
    }

    pub fn get_pool(&self, key: &str) -> Option<PgPool> {
        self.pools.get(key).cloned()
    }

    pub fn list_clients(&self) -> Vec<String> {
        let mut keys: Vec<String> = self.pools.keys().cloned().collect();
        keys.sort();
        keys
    }
}

/// Binds a `serde_json::Value` to a sqlx query as the closest Postgres type.
///
/// Mapping:
/// - `Null`           -> SQL NULL (typed as nullable text)
/// - `Bool`           -> BOOL
/// - integer `Number` -> BIGINT; u64 values above `i64::MAX` fall back to text
/// - other `Number`   -> DOUBLE PRECISION
/// - `String`         -> UUID when parseable as one, otherwise TEXT
/// - `Array`/`Object` -> JSONB
macro_rules! bind_value {
    ($query:expr, $value:expr) => {
        match $value {
            Value::Null => $query.bind(None::<String>),
            Value::Bool(b) => $query.bind(*b),
            Value::Number(num) => {
                if let Some(i) = num.as_i64() {
                    $query.bind(i)
                } else if let Some(u) = num.as_u64() {
                    // Checked BEFORE as_f64: `as_f64` returns Some for every
                    // standard JSON number, so testing it first made this
                    // branch unreachable and silently bound u64 values above
                    // i64::MAX as lossy doubles instead of the text fallback.
                    if let Ok(i) = i64::try_from(u) {
                        $query.bind(i)
                    } else {
                        $query.bind(num.to_string())
                    }
                } else if let Some(f) = num.as_f64() {
                    $query.bind(f)
                } else {
                    $query.bind(num.to_string())
                }
            }
            Value::String(s) => {
                // Heuristic: strings that parse as UUIDs are bound as UUID so
                // comparisons against uuid columns work; everything else is text.
                if let Ok(parsed) = Uuid::parse_str(s) {
                    $query.bind(parsed)
                } else {
                    $query.bind(s.clone())
                }
            }
            Value::Array(_) | Value::Object(_) => $query.bind(Json($value.clone())),
        }
    };
}

/// Failure modes for `insert_row`, kept as a structured enum so callers can
/// map each case to a distinct response instead of parsing error strings.
/// (The doc block previously here described `insert_row` itself but was
/// attached to this enum; the function is documented at its own definition.)
#[derive(Debug)]
pub enum PostgresInsertError {
    /// The table name failed identifier sanitization.
    InvalidTableName,
    /// The payload was not a JSON object; carries a human-readable reason.
    InvalidPayload(String),
    /// No payload key survived identifier sanitization.
    NoValidColumns,
    /// The `RETURNING` clause did not yield the expected `data` column.
    MissingReturnColumn,
    /// The database rejected or failed the statement.
    SqlExecution {
        message: String,
        /// SQLSTATE code when the driver surfaced one (e.g. "23505").
        sql_state: Option<String>,
    },
}

pub async fn insert_row(
    pool: &PgPool,
    table_name: &str,
    payload: &Value,
) -> Result<Value, PostgresInsertError> {
    let table: String =
        sanitize_identifier(table_name).ok_or(PostgresInsertError::InvalidTableName)?;
    let object: &Map<String, Value> = payload.as_object().ok_or_else(|| {
        PostgresInsertError::InvalidPayload("insert payload must be an object".to_string())
    })?;
    let entries: Vec<(String, Value)> = object
        .iter()
        .filter_map(|(column, value)| {
            sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
        })
        .collect::<Vec<_>>();
    if entries.is_empty() {
        return Err(PostgresInsertError::NoValidColumns);
    }

    let columns: Vec<&str> = entries
        .iter()
        .map(|(column, _)| column.as_str())
        .collect::<Vec<_>>();
    let value_refs: Vec<&Value> = entries.iter().map(|(_, value)| value).collect();
    let (placeholders, bind_values) = build_insert_placeholders(&value_refs);

    let sql: String = format!(
        "INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) RETURNING to_jsonb(t.*) AS data",
        table = table,
        columns = columns.join(", "),
        placeholders = placeholders.join(", ")
    );

    let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
    for value in bind_values {
        query = bind_value!(query, value);
    }

    let row: PgRow = query.fetch_one(pool).await.map_err(|err| match err {
        SqlxError::Database(db_err) => PostgresInsertError::SqlExecution {
            message: db_err.message().to_string(),
            sql_state: db_err.code().map(|code| code.to_string()),
        },
        other => PostgresInsertError::SqlExecution {
            message: other.to_string(),
            sql_state: None,
        },
    })?;
    let data: Json<Value> = row
        .try_get("data")
        .map_err(|_| PostgresInsertError::MissingReturnColumn)?;
    Ok(data.0)
}

/// ### `update_row`
pub async fn update_row(
    pool: &PgPool,
    table_name: &str,
    conditions: &[Condition],
    payload: &Value,
) -> Result<Value> {
    let table: String =
        sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
    let entries: Vec<(String, Value)> = payload
        .as_object()
        .context("update payload must be an object")?
        .iter()
        .filter_map(|(column, value)| {
            sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
        })
        .collect::<Vec<_>>();
    if entries.is_empty() {
        return Err(anyhow!("no valid columns provided for update"));
    }

    let set_parts: Vec<String> = entries
        .iter()
        .enumerate()
        .map(|(idx, (column, _))| format!("{} = ${}", column, idx + 1))
        .collect::<Vec<_>>();

    let (where_clause, where_values) = build_where_clause(conditions, entries.len() + 1)?;
    if where_clause.is_empty() {
        return Err(anyhow!("at least one valid condition is required"));
    }

    let sql: String = format!(
        "UPDATE {table} AS t SET {set_clause}{where_clause} RETURNING to_jsonb(t.*) AS data",
        table = table,
        set_clause = set_parts.join(", "),
        where_clause = where_clause
    );

    let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
    for (_, value) in &entries {
        query = bind_value!(query, value);
    }
    for value in &where_values {
        query = bind_value!(query, value);
    }

    let row: PgRow = query
        .fetch_one(pool)
        .await
        .context("failed to execute update row")?;
    let data: Json<Value> = row
        .try_get("data")
        .context("missing data column after update")?;
    Ok(data.0)
}

/// Selects up to `limit` rows (skipping `offset`) from `table_name`, with
/// each row serialized to JSON via `row_to_json`.
///
/// `limit`/`offset` are interpolated directly but are typed `i64`, so no
/// injection is possible; condition values are bound as parameters.
pub async fn fetch_rows(
    pool: &PgPool,
    table_name: &str,
    conditions: &[Condition],
    limit: i64,
    offset: i64,
) -> Result<Vec<Value>> {
    let table = sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
    let (where_clause, where_values) = build_where_clause(conditions, 1)?;
    let sql = format!(
        "SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}"
    );

    let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
    for value in &where_values {
        query = bind_value!(query, value);
    }

    let rows = query
        .fetch_all(pool)
        .await
        .with_context(|| format!("failed to execute select query: {}", sql))?;

    // Unwrap the `data` column of every row, short-circuiting on the first miss.
    rows.into_iter()
        .map(|row| {
            let data: Json<Value> = row
                .try_get("data")
                .context("missing data column in select result")?;
            Ok(data.0)
        })
        .collect()
}

pub async fn fetch_rows_with_columns(
    pool: &PgPool,
    table_name: &str,
    columns: &[&str],
    conditions: &[Condition],
    limit: i64,
    offset: i64,
) -> Result<Vec<Value>> {
    let table: String =
        sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;

    // If columns contains "*" or is empty, select all columns
    let use_all_columns: bool = columns.is_empty() || columns.contains(&"*");

    let (where_clause, where_values) = build_where_clause(conditions, 1)?;

    let sql: String = if use_all_columns {
        format!(
            "SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}",
            table = table,
            where_clause = where_clause,
            limit = limit,
            offset = offset
        )
    } else {
        // Build jsonb_build_object with column names and values
        let column_pairs: Vec<String> = columns
            .iter()
            .filter_map(|col| {
                sanitize_identifier(col)
                    .map(|sanitized| format!("'{}', t.{}", sanitized, sanitized))
            })
            .collect();

        if column_pairs.is_empty() {
            return Err(anyhow!("no valid columns specified"));
        }

        format!(
            "SELECT jsonb_build_object({columns}) AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}",
            columns = column_pairs.join(", "),
            table = table,
            where_clause = where_clause,
            limit = limit,
            offset = offset
        )
    };

    let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
    info!("query SQL: {:#?}", sql);
    for value in &where_values {
        query = bind_value!(query, value);
    }

    let rows: Vec<PgRow> = query
        .fetch_all(pool)
        .await
        .with_context(|| format!("failed to execute select query: {}", sql))?;
    info!("rows: {:#?}", rows);

    let mut result: Vec<Value> = Vec::new();
    for row in rows {
        let data: Json<Value> = row
            .try_get("data")
            .context("missing data column in select result")?;
        result.push(data.0);
    }
    Ok(result)
}