use crate::parser::query_builder::{
Condition, build_insert_placeholders, build_where_clause, sanitize_identifier,
};
use anyhow::{Context, Result, anyhow};
use futures::future::join_all;
use serde_json::{Map, Value};
use sqlx::Error as SqlxError;
use sqlx::Row;
use sqlx::postgres::{PgArguments, PgPool, PgPoolOptions, PgRow};
use sqlx::query::Query;
use sqlx::types::Json;
use std::collections::HashMap;
use std::convert::TryFrom;
use tracing::info;
use uuid::Uuid;
/// Registry of named Postgres connection pools.
///
/// Pools are created by [`PostgresClientRegistry::from_entries`] and looked up
/// by client name; lookups hand out clones of the stored [`PgPool`] handle.
pub struct PostgresClientRegistry {
    pools: HashMap<String, PgPool>,
}

impl PostgresClientRegistry {
    /// Creates a registry that holds no pools.
    pub fn empty() -> Self {
        Self {
            pools: HashMap::new(),
        }
    }

    /// Returns `true` if no pool is registered.
    pub fn is_empty(&self) -> bool {
        self.pools.is_empty()
    }

    /// Connects to every `(client_name, uri)` pair concurrently and builds a
    /// registry from the pools that connected.
    ///
    /// Each pool is configured with at most 50 and at least 5 connections, a
    /// 3 s acquire timeout, a 300 s idle timeout, a 1800 s max lifetime, and no
    /// connection test before acquire.
    ///
    /// A failed connection does not abort the others: it is reported in the
    /// returned list as `(client_name, error)`. If a client name occurs more
    /// than once, the pool from the later entry replaces the earlier one and a
    /// warning is logged.
    ///
    /// Connection URIs are logged only in redacted form (password and query
    /// string masked) because they typically embed credentials.
    ///
    /// # Errors
    ///
    /// The outer `Result` is currently never `Err`; per-client failures are
    /// returned in the second tuple element instead.
    pub async fn from_entries(
        entries: Vec<(String, String)>,
    ) -> Result<(Self, Vec<(String, anyhow::Error)>)> {
        let connect_tasks = entries.into_iter().map(|(client_name, uri)| async move {
            // Never log the raw URI: it usually contains the database password.
            let log_uri = redact_connection_uri(&uri);
            tracing::info!(
                client = %client_name,
                uri = %log_uri,
                "connecting to Postgres client"
            );
            let connect_result = PgPoolOptions::new()
                .max_connections(50)
                .min_connections(5)
                .acquire_timeout(std::time::Duration::from_secs(3))
                .idle_timeout(std::time::Duration::from_secs(300))
                .max_lifetime(std::time::Duration::from_secs(1800))
                .test_before_acquire(false)
                .connect(&uri)
                .await;
            match connect_result {
                Ok(pool) => {
                    tracing::info!(client = %client_name, "connected to Postgres client");
                    Ok((client_name, pool))
                }
                Err(err) => {
                    tracing::error!(
                        client = %client_name,
                        uri = %log_uri,
                        error = %err,
                        "failed to connect to Postgres client"
                    );
                    let context_error = anyhow!(
                        "failed to connect to postgres client {}: {}",
                        client_name,
                        err
                    );
                    Err((client_name, context_error))
                }
            }
        });

        let mut pools: HashMap<String, PgPool> = HashMap::new();
        let mut errors: Vec<(String, anyhow::Error)> = Vec::new();
        // `join_all` yields results in input order, so a later duplicate wins.
        for result in join_all(connect_tasks).await {
            match result {
                Ok((client_name, pool)) => {
                    if pools.insert(client_name.clone(), pool).is_some() {
                        tracing::warn!(
                            client = %client_name,
                            "duplicate Postgres client name; replacing previously connected pool"
                        );
                    }
                }
                Err(failure) => errors.push(failure),
            }
        }
        Ok((Self { pools }, errors))
    }

    /// Returns a handle to the pool registered under `key`, if any.
    ///
    /// The returned [`PgPool`] is a clone of the stored handle.
    pub fn get_pool(&self, key: &str) -> Option<PgPool> {
        self.pools.get(key).cloned()
    }

    /// Returns the registered client names in ascending order.
    pub fn list_clients(&self) -> Vec<String> {
        let mut keys: Vec<String> = self.pools.keys().cloned().collect();
        // Map keys are unique, so an unstable sort gives the same order.
        keys.sort_unstable();
        keys
    }
}

/// Returns a log-safe rendering of a connection URI.
///
/// Any password in the `user:password@` section is replaced with `***`, and a
/// query string (which may carry options such as `password=`) is replaced with
/// `?<redacted>`. Input without `://` is fully redacted because its layout
/// cannot be inspected reliably.
fn redact_connection_uri(uri: &str) -> String {
    let Some((scheme, rest)) = uri.split_once("://") else {
        return "<redacted>".to_string();
    };
    // Split on the LAST '@' so an unencoded '@', '/' or '?' inside the password
    // still ends up in the masked part rather than in the logged host/path.
    let (userinfo, remainder) = match rest.rsplit_once('@') {
        Some((userinfo, remainder)) => (Some(userinfo), remainder),
        None => (None, rest),
    };

    let mut redacted = format!("{scheme}://");
    if let Some(userinfo) = userinfo {
        match userinfo.split_once(':') {
            Some((user, _password)) => {
                redacted.push_str(user);
                redacted.push_str(":***");
            }
            None => redacted.push_str(userinfo),
        }
        redacted.push('@');
    }
    match remainder.split_once('?') {
        Some((host_and_path, _query)) => {
            redacted.push_str(host_and_path);
            redacted.push_str("?<redacted>");
        }
        None => redacted.push_str(remainder),
    }
    redacted
}
/// Binds one JSON value to a `sqlx` Postgres query and evaluates to the new query.
///
/// `$value` is expected to be a reference to a [`Value`] (the `Bool` arm
/// dereferences its binding). The SQL parameter type is chosen from the JSON
/// variant:
/// - `null` → `NULL` bound as `Option<String>` (text-typed);
/// - booleans → `bool`;
/// - integers that fit in `i64` → `i64`;
/// - unsigned integers above `i64::MAX` → their exact decimal text (checked
///   before `f64` so large ids are not silently rounded);
/// - other numbers → `f64`, or their decimal text if no `f64` form exists;
/// - strings that parse as a UUID → `Uuid`, any other string → text;
/// - arrays and objects → `Json<Value>`.
///
/// NOTE(review): the macro cannot see the target column type, so a UUID-shaped
/// string compared against a text column, or a text-typed `NULL` bound to a
/// non-text column, may be rejected by Postgres — confirm against callers.
macro_rules! bind_value {
    ($query:expr, $value:expr) => {
        match $value {
            Value::Null => $query.bind(None::<String>),
            Value::Bool(b) => $query.bind(*b),
            Value::Number(num) => {
                if let Some(i) = num.as_i64() {
                    $query.bind(i)
                } else if let Some(u) = num.as_u64() {
                    // `as_f64` succeeds for every u64, so this check must come
                    // before it; otherwise values above i64::MAX lose precision.
                    if let Ok(i) = i64::try_from(u) {
                        $query.bind(i)
                    } else {
                        $query.bind(num.to_string())
                    }
                } else if let Some(f) = num.as_f64() {
                    $query.bind(f)
                } else {
                    $query.bind(num.to_string())
                }
            }
            Value::String(s) => {
                if let Ok(parsed) = Uuid::parse_str(s) {
                    $query.bind(parsed)
                } else {
                    $query.bind(s.clone())
                }
            }
            Value::Array(_) | Value::Object(_) => $query.bind(Json($value.clone())),
        }
    };
}
/// Errors returned by `insert_row`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PostgresInsertError {
    /// The table name was rejected by identifier sanitization.
    InvalidTableName,
    /// The payload had the wrong shape; the string describes the problem.
    InvalidPayload(String),
    /// No payload key survived identifier sanitization.
    NoValidColumns,
    /// The row returned by the insert had no decodable `data` column.
    MissingReturnColumn,
    /// Executing the statement failed.
    SqlExecution {
        /// Error text from the database or the driver.
        message: String,
        /// SQLSTATE code, present only when the database reported the error.
        sql_state: Option<String>,
    },
}

impl std::fmt::Display for PostgresInsertError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidTableName => f.write_str("invalid table name"),
            Self::InvalidPayload(reason) => f.write_str(reason),
            Self::NoValidColumns => f.write_str("no valid columns provided for insert"),
            Self::MissingReturnColumn => f.write_str("missing data column after insert"),
            Self::SqlExecution {
                message,
                sql_state: Some(code),
            } => write!(f, "failed to execute insert row (SQLSTATE {code}): {message}"),
            Self::SqlExecution {
                message,
                sql_state: None,
            } => write!(f, "failed to execute insert row: {message}"),
        }
    }
}

// Lets the error be boxed as `dyn Error` or wrapped by `anyhow` with `?`.
impl std::error::Error for PostgresInsertError {}
pub async fn insert_row(
pool: &PgPool,
table_name: &str,
payload: &Value,
) -> Result<Value, PostgresInsertError> {
let table: String =
sanitize_identifier(table_name).ok_or(PostgresInsertError::InvalidTableName)?;
let object: &Map<String, Value> = payload.as_object().ok_or_else(|| {
PostgresInsertError::InvalidPayload("insert payload must be an object".to_string())
})?;
let entries: Vec<(String, Value)> = object
.iter()
.filter_map(|(column, value)| {
sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
})
.collect::<Vec<_>>();
if entries.is_empty() {
return Err(PostgresInsertError::NoValidColumns);
}
let columns: Vec<&str> = entries
.iter()
.map(|(column, _)| column.as_str())
.collect::<Vec<_>>();
let value_refs: Vec<&Value> = entries.iter().map(|(_, value)| value).collect();
let (placeholders, bind_values) = build_insert_placeholders(&value_refs);
let sql: String = format!(
"INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) RETURNING to_jsonb(t.*) AS data",
table = table,
columns = columns.join(", "),
placeholders = placeholders.join(", ")
);
let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for value in bind_values {
query = bind_value!(query, value);
}
let row: PgRow = query.fetch_one(pool).await.map_err(|err| match err {
SqlxError::Database(db_err) => PostgresInsertError::SqlExecution {
message: db_err.message().to_string(),
sql_state: db_err.code().map(|code| code.to_string()),
},
other => PostgresInsertError::SqlExecution {
message: other.to_string(),
sql_state: None,
},
})?;
let data: Json<Value> = row
.try_get("data")
.map_err(|_| PostgresInsertError::MissingReturnColumn)?;
Ok(data.0)
}
pub async fn update_row(
pool: &PgPool,
table_name: &str,
conditions: &[Condition],
payload: &Value,
) -> Result<Value> {
let table: String =
sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
let entries: Vec<(String, Value)> = payload
.as_object()
.context("update payload must be an object")?
.iter()
.filter_map(|(column, value)| {
sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
})
.collect::<Vec<_>>();
if entries.is_empty() {
return Err(anyhow!("no valid columns provided for update"));
}
let set_parts: Vec<String> = entries
.iter()
.enumerate()
.map(|(idx, (column, _))| format!("{} = ${}", column, idx + 1))
.collect::<Vec<_>>();
let (where_clause, where_values) = build_where_clause(conditions, entries.len() + 1)?;
if where_clause.is_empty() {
return Err(anyhow!("at least one valid condition is required"));
}
let sql: String = format!(
"UPDATE {table} AS t SET {set_clause}{where_clause} RETURNING to_jsonb(t.*) AS data",
table = table,
set_clause = set_parts.join(", "),
where_clause = where_clause
);
let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for (_, value) in &entries {
query = bind_value!(query, value);
}
for value in &where_values {
query = bind_value!(query, value);
}
let row: PgRow = query
.fetch_one(pool)
.await
.context("failed to execute update row")?;
let data: Json<Value> = row
.try_get("data")
.context("missing data column after update")?;
Ok(data.0)
}
pub async fn fetch_rows(
pool: &PgPool,
table_name: &str,
conditions: &[Condition],
limit: i64,
offset: i64,
) -> Result<Vec<Value>> {
let table: String =
sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
let (where_clause, where_values) = build_where_clause(conditions, 1)?;
let sql: String = format!(
"SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}",
table = table,
where_clause = where_clause,
limit = limit,
offset = offset
);
let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for value in &where_values {
query = bind_value!(query, value);
}
let rows: Vec<PgRow> = query
.fetch_all(pool)
.await
.with_context(|| format!("failed to execute select query: {}", sql))?;
let mut result = Vec::new();
for row in rows {
let data: Json<Value> = row
.try_get("data")
.context("missing data column in select result")?;
result.push(data.0);
}
Ok(result)
}
pub async fn fetch_rows_with_columns(
pool: &PgPool,
table_name: &str,
columns: &[&str],
conditions: &[Condition],
limit: i64,
offset: i64,
) -> Result<Vec<Value>> {
let table: String =
sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
let use_all_columns: bool = columns.is_empty() || columns.contains(&"*");
let (where_clause, where_values) = build_where_clause(conditions, 1)?;
let sql: String = if use_all_columns {
format!(
"SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}",
table = table,
where_clause = where_clause,
limit = limit,
offset = offset
)
} else {
let column_pairs: Vec<String> = columns
.iter()
.filter_map(|col| {
sanitize_identifier(col)
.map(|sanitized| format!("'{}', t.{}", sanitized, sanitized))
})
.collect();
if column_pairs.is_empty() {
return Err(anyhow!("no valid columns specified"));
}
format!(
"SELECT jsonb_build_object({columns}) AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}",
columns = column_pairs.join(", "),
table = table,
where_clause = where_clause,
limit = limit,
offset = offset
)
};
let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
info!("query SQL: {:#?}", sql);
for value in &where_values {
query = bind_value!(query, value);
}
let rows: Vec<PgRow> = query
.fetch_all(pool)
.await
.with_context(|| format!("failed to execute select query: {}", sql))?;
info!("rows: {:#?}", rows);
let mut result: Vec<Value> = Vec::new();
for row in rows {
let data: Json<Value> = row
.try_get("data")
.context("missing data column in select result")?;
result.push(data.0);
}
Ok(result)
}