use crate::parser::query_builder::{
Condition, build_insert_placeholders, build_where_clause, sanitize_identifier,
};
use anyhow::{Context, Result, anyhow};
use futures::future::join_all;
use serde_json::{Map, Value};
use uuid::Uuid;
use sqlx::Error as SqlxError;
use sqlx::Row;
use sqlx::postgres::{PgArguments, PgPool, PgPoolOptions, PgRow};
use sqlx::query::Query;
use sqlx::types::Json;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use tracing::{error, info};
/// Registry of Postgres connection pools, looked up by client name.
pub struct PostgresClientRegistry {
// One connection pool per client name.
pools: HashMap<String, PgPool>,
}
impl PostgresClientRegistry {
    /// Creates a registry that holds no connection pools.
    pub fn empty() -> Self {
        Self {
            pools: HashMap::new(),
        }
    }

    /// Returns `true` when no pool is registered.
    pub fn is_empty(&self) -> bool {
        self.pools.is_empty()
    }

    /// Opens one pool per `(client_name, uri)` entry and collects the outcomes.
    ///
    /// All connection attempts are awaited together through `join_all`.
    /// Successful pools are stored under their client name; when two entries
    /// share a name, the one inserted last replaces the earlier one. Failed
    /// attempts are returned as `(client_name, error)` pairs instead of
    /// aborting the whole call.
    ///
    /// # Errors
    ///
    /// This function currently always returns `Ok`; per-client failures are
    /// reported through the second element of the tuple.
    pub async fn from_entries(
        entries: Vec<(String, String)>,
    ) -> Result<(Self, Vec<(String, anyhow::Error)>)> {
        let connect_tasks = entries.into_iter().map(|(client_name, uri)| async move {
            // The connection URI usually embeds the database password, so it is
            // deliberately kept out of every log event; the client name is
            // enough to find the matching configuration entry.
            tracing::info!(client = %client_name, "connecting to Postgres client");
            match PgPoolOptions::new()
                .max_connections(50)
                .min_connections(5)
                .acquire_timeout(std::time::Duration::from_secs(3))
                .idle_timeout(std::time::Duration::from_secs(300))
                .max_lifetime(std::time::Duration::from_secs(1800))
                .test_before_acquire(false)
                .connect(&uri)
                .await
            {
                Ok(pool) => {
                    tracing::info!(client = %client_name, "connected to Postgres client");
                    Ok((client_name, pool))
                }
                Err(err) => {
                    let context_error = anyhow!(
                        "failed to connect to postgres client {}: {}",
                        client_name,
                        err
                    );
                    tracing::error!(
                        client = %client_name,
                        error = %err,
                        "failed to connect to Postgres client"
                    );
                    Err((client_name, context_error))
                }
            }
        });

        let mut pools: HashMap<String, PgPool> = HashMap::new();
        let mut errors: Vec<(String, anyhow::Error)> = Vec::new();
        for result in join_all(connect_tasks).await {
            match result {
                Ok((client_name, pool)) => {
                    pools.insert(client_name, pool);
                }
                Err((client_name, err)) => {
                    errors.push((client_name, err));
                }
            }
        }
        Ok((Self { pools }, errors))
    }

    /// Returns a clone of the pool registered under `key`, if there is one.
    pub fn get_pool(&self, key: &str) -> Option<PgPool> {
        self.pools.get(key).cloned()
    }

    /// Returns the registered client names in ascending sorted order.
    pub fn list_clients(&self) -> Vec<String> {
        let mut keys: Vec<String> = self.pools.keys().cloned().collect();
        keys.sort();
        keys
    }
}
/// Binds one JSON value to a query and evaluates to the extended query.
///
/// Mapping: `null` binds `None::<String>`, booleans bind as `bool`, numbers
/// bind as `i64` or `f64` (falling back to their decimal string), strings bind
/// as `String`, and arrays/objects bind wrapped in `Json`.
///
/// `$value` is matched through a reference and is evaluated more than once,
/// so it should be a plain variable.
macro_rules! bind_value {
    ($query:expr, $value:expr) => {
        match $value {
            Value::Array(_) | Value::Object(_) => $query.bind(Json($value.clone())),
            Value::String(text) => $query.bind(text.clone()),
            // Preference order: i64, then f64, then u64.
            // NOTE(review): if `as_f64` succeeds for every number that fails
            // `as_i64`, the u64 and fallback arms are never reached - confirm
            // against serde_json's `Number` documentation.
            Value::Number(number) => {
                match (number.as_i64(), number.as_f64(), number.as_u64()) {
                    (Some(signed), _, _) => $query.bind(signed),
                    (None, Some(float), _) => $query.bind(float),
                    (None, None, Some(unsigned)) => match i64::try_from(unsigned) {
                        Ok(signed) => $query.bind(signed),
                        Err(_) => $query.bind(number.to_string()),
                    },
                    (None, None, None) => $query.bind(number.to_string()),
                }
            }
            Value::Bool(flag) => $query.bind(*flag),
            Value::Null => $query.bind(None::<String>),
        }
    };
}
/// Binds one JSON value that is being written to a column and evaluates to
/// the extended query.
///
/// Same mapping as a plain bind (`null` -> `None::<String>`, `bool`, `i64` /
/// `f64` / decimal string for numbers, `Json` for arrays and objects), except
/// that a string which parses as a UUID is bound as `Uuid` instead of text.
///
/// `$value` is matched through a reference and is evaluated more than once,
/// so it should be a plain variable.
macro_rules! bind_value_set {
    ($query:expr, $value:expr) => {
        match $value {
            Value::Array(_) | Value::Object(_) => $query.bind(Json($value.clone())),
            // Strings that look like UUIDs are sent as UUIDs; everything else as text.
            Value::String(text) => match Uuid::parse_str(text) {
                Ok(uuid) => $query.bind(uuid),
                Err(_) => $query.bind(text.clone()),
            },
            // Preference order: i64, then f64, then u64.
            // NOTE(review): if `as_f64` succeeds for every number that fails
            // `as_i64`, the u64 and fallback arms are never reached - confirm
            // against serde_json's `Number` documentation.
            Value::Number(number) => {
                match (number.as_i64(), number.as_f64(), number.as_u64()) {
                    (Some(signed), _, _) => $query.bind(signed),
                    (None, Some(float), _) => $query.bind(float),
                    (None, None, Some(unsigned)) => match i64::try_from(unsigned) {
                        Ok(signed) => $query.bind(signed),
                        Err(_) => $query.bind(number.to_string()),
                    },
                    (None, None, None) => $query.bind(number.to_string()),
                }
            }
            Value::Bool(flag) => $query.bind(*flag),
            Value::Null => $query.bind(None::<String>),
        }
    };
}
/// Failure modes reported by the insert and upsert helpers in this file.
#[derive(Debug)]
pub enum PostgresInsertError {
/// An identifier was rejected by `sanitize_identifier`. Used for the table
/// name, and by `upsert_row` for the conflict column as well.
InvalidTableName,
/// The payload has the wrong shape (for example it is not a JSON object, or
/// a bulk payload is empty). The string describes the problem.
InvalidPayload(String),
/// None of the payload's keys passed `sanitize_identifier`.
NoValidColumns,
/// The `data` column could not be read from a returned row.
MissingReturnColumn,
/// Executing the statement failed.
SqlExecution {
/// Error text taken from the underlying sqlx error.
message: String,
/// The database error code, when one was captured.
sql_state: Option<String>,
},
}
/// Inserts a single row built from a JSON object and returns the stored row.
///
/// Keys of `payload` that `sanitize_identifier` rejects are skipped. The
/// inserted row is read back through `RETURNING to_jsonb(t.*)`.
///
/// # Errors
///
/// * `InvalidTableName` - `table_name` is rejected by `sanitize_identifier`.
/// * `InvalidPayload` - `payload` is not a JSON object.
/// * `NoValidColumns` - no key of `payload` is a valid identifier.
/// * `SqlExecution` - the statement failed; `sql_state` carries the database
///   error code when the failure is a database error.
/// * `MissingReturnColumn` - the `data` column could not be read.
pub async fn insert_row(
    pool: &PgPool,
    table_name: &str,
    payload: &Value,
) -> Result<Value, PostgresInsertError> {
    let table_ident: String = match sanitize_identifier(table_name) {
        Some(ident) => ident,
        None => return Err(PostgresInsertError::InvalidTableName),
    };
    let fields: &Map<String, Value> = match payload.as_object() {
        Some(fields) => fields,
        None => {
            return Err(PostgresInsertError::InvalidPayload(
                "insert payload must be an object".to_string(),
            ));
        }
    };

    // Keep column names and their values in two parallel vectors.
    let mut column_names: Vec<String> = Vec::new();
    let mut column_values: Vec<Value> = Vec::new();
    for (raw_name, raw_value) in fields.iter() {
        if let Some(safe_name) = sanitize_identifier(raw_name) {
            column_names.push(safe_name);
            column_values.push(raw_value.clone());
        }
    }
    if column_names.is_empty() {
        return Err(PostgresInsertError::NoValidColumns);
    }

    let value_refs: Vec<&Value> = column_values.iter().collect();
    let (placeholders, bind_values) = build_insert_placeholders(&value_refs);
    let statement_text: String = format!(
        "INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) RETURNING to_jsonb(t.*) AS data",
        table = table_ident,
        columns = column_names.join(", "),
        placeholders = placeholders.join(", ")
    );

    let mut statement: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&statement_text);
    for bound in bind_values {
        statement = bind_value_set!(statement, bound);
    }

    let inserted: PgRow = match statement.fetch_one(pool).await {
        Ok(row) => row,
        // Database errors expose a code that callers may want to inspect.
        Err(SqlxError::Database(db_err)) => {
            return Err(PostgresInsertError::SqlExecution {
                message: db_err.message().to_string(),
                sql_state: db_err.code().map(|code| code.to_string()),
            });
        }
        Err(other) => {
            return Err(PostgresInsertError::SqlExecution {
                message: other.to_string(),
                sql_state: None,
            });
        }
    };

    let returned: Json<Value> = match inserted.try_get("data") {
        Ok(json) => json,
        Err(_) => return Err(PostgresInsertError::MissingReturnColumn),
    };
    Ok(returned.0)
}
pub async fn insert_rows_bulk(
pool: &PgPool,
table_name: &str,
payloads: &[Value],
) -> Result<Vec<Value>, PostgresInsertError> {
if payloads.is_empty() {
return Err(PostgresInsertError::InvalidPayload(
"insert payload array must not be empty".to_string(),
));
}
let mut column_order: Vec<(String, String)> = Vec::new();
let mut seen_columns: HashSet<String> = HashSet::new();
for payload in payloads {
let obj = payload.as_object().ok_or_else(|| {
PostgresInsertError::InvalidPayload(
"each insert payload must be a JSON object".to_string(),
)
})?;
for column in obj.keys() {
if seen_columns.contains(column) {
continue;
}
if let Some(sanitized) = sanitize_identifier(column) {
seen_columns.insert(column.clone());
column_order.push((column.clone(), sanitized));
}
}
}
if column_order.is_empty() {
return Err(PostgresInsertError::NoValidColumns);
}
let sanitized_columns: Vec<String> = column_order
.iter()
.map(|(_, sanitized)| sanitized.clone())
.collect();
let column_names: Vec<String> = column_order.iter().map(|(raw, _)| raw.clone()).collect();
let mut placeholders: Vec<String> = Vec::new();
let mut bind_values: Vec<Value> = Vec::new();
let mut param_index: i32 = 1;
for payload in payloads {
let row_obj = payload.as_object().unwrap();
let mut row_placeholders: Vec<String> = Vec::new();
for column in &column_names {
let value = row_obj.get(column).cloned().unwrap_or(Value::Null);
bind_values.push(value);
row_placeholders.push(format!("${}", param_index));
param_index += 1;
}
placeholders.push(format!("({})", row_placeholders.join(", ")));
}
let table = sanitize_identifier(table_name).ok_or(PostgresInsertError::InvalidTableName)?;
let sql = format!(
"INSERT INTO {table} AS t ({columns}) VALUES {placeholders} RETURNING to_jsonb(t.*) AS data",
table = table,
columns = sanitized_columns.join(", "),
placeholders = placeholders.join(", ")
);
let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for value in &bind_values {
query = bind_value_set!(query, value);
}
let rows: Vec<PgRow> =
query
.fetch_all(pool)
.await
.map_err(|err| PostgresInsertError::SqlExecution {
message: err.to_string(),
sql_state: None,
})?;
let mut result: Vec<Value> = Vec::new();
for row in rows {
let data: Json<Value> = row
.try_get("data")
.map_err(|_| PostgresInsertError::MissingReturnColumn)?;
result.push(data.0);
}
Ok(result)
}
pub async fn upsert_row(
pool: &PgPool,
table_name: &str,
payload: &Value,
conflict_column: &str,
) -> Result<Value, PostgresInsertError> {
let table: String =
sanitize_identifier(table_name).ok_or(PostgresInsertError::InvalidTableName)?;
let conflict: String =
sanitize_identifier(conflict_column).ok_or(PostgresInsertError::InvalidTableName)?;
let object: &Map<String, Value> = payload.as_object().ok_or_else(|| {
PostgresInsertError::InvalidPayload("upsert payload must be an object".to_string())
})?;
let entries: Vec<(String, Value)> = object
.iter()
.filter_map(|(column, value)| {
sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
})
.collect::<Vec<_>>();
if entries.is_empty() {
return Err(PostgresInsertError::NoValidColumns);
}
let columns: Vec<&str> = entries.iter().map(|(column, _)| column.as_str()).collect();
let values: Vec<&Value> = entries.iter().map(|(_, value)| value).collect();
let (placeholders, bind_values) = build_insert_placeholders(&values);
let set_clause: Vec<String> = entries
.iter()
.map(|(column, _)| format!("{} = EXCLUDED.{}", column, column.trim_matches('"')))
.collect::<Vec<_>>();
let sql: String = format!(
"INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) ON CONFLICT ({conflict}) DO UPDATE SET {set_clause} RETURNING to_jsonb(t.*) AS data",
table = table,
columns = columns.join(", "),
placeholders = placeholders.join(", "),
conflict = conflict,
set_clause = set_clause.join(", ")
);
let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for value in bind_values {
query = bind_value_set!(query, value);
}
let row: PgRow =
query
.fetch_one(pool)
.await
.map_err(|err| PostgresInsertError::SqlExecution {
message: err.to_string(),
sql_state: None,
})?;
let data: Json<Value> = row
.try_get("data")
.map_err(|_| PostgresInsertError::MissingReturnColumn)?;
Ok(data.0)
}
/// Updates the rows matching `conditions` and returns one updated row.
///
/// The `SET` list is built from the keys of `payload` that pass
/// `sanitize_identifier`; its placeholders are numbered from `$1`, and the
/// `WHERE` placeholders continue after them. The statement is executed with
/// `fetch_one`, so exactly one returned row is read back.
///
/// # Errors
///
/// Fails when the table name is invalid, `payload` is not an object, no valid
/// column is supplied, `build_where_clause` fails or yields an empty clause,
/// the statement cannot be executed, or the `data` column cannot be read.
pub async fn update_row(
    pool: &PgPool,
    table_name: &str,
    conditions: &[Condition],
    payload: &Value,
) -> Result<Value> {
    let table_ident: String = match sanitize_identifier(table_name) {
        Some(ident) => ident,
        None => return Err(anyhow!("invalid table name")),
    };
    let fields = payload
        .as_object()
        .context("update payload must be an object")?;

    // Build the SET assignments and remember the values in the same order.
    let mut assignments: Vec<String> = Vec::new();
    let mut assigned_values: Vec<&Value> = Vec::new();
    for (raw_name, raw_value) in fields.iter() {
        if let Some(column) = sanitize_identifier(raw_name) {
            let position = assignments.len() + 1;
            assignments.push(format!("{} = ${}", column, position));
            assigned_values.push(raw_value);
        }
    }
    if assignments.is_empty() {
        return Err(anyhow!("no valid columns provided for update"));
    }

    // WHERE placeholders start right after the SET placeholders.
    let (where_clause, where_values) = build_where_clause(conditions, assignments.len() + 1)?;
    if where_clause.is_empty() {
        return Err(anyhow!("at least one valid condition is required"));
    }

    let statement_text: String = format!(
        "UPDATE {table} AS t SET {set_clause}{where_clause} RETURNING to_jsonb(t.*) AS data",
        table = table_ident,
        set_clause = assignments.join(", "),
        where_clause = where_clause
    );
    let mut statement: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&statement_text);
    for new_value in assigned_values {
        statement = bind_value_set!(statement, new_value);
    }
    for filter_value in &where_values {
        statement = bind_value!(statement, filter_value);
    }

    let updated: PgRow = statement
        .fetch_one(pool)
        .await
        .context("failed to execute update row")?;
    let returned: Json<Value> = updated
        .try_get("data")
        .context("missing data column after update")?;
    Ok(returned.0)
}
/// Updates every row matching `conditions` and returns all updated rows.
///
/// The `SET` list is built from the keys of `payload` that pass
/// `sanitize_identifier`; its placeholders are numbered from `$1`, and the
/// `WHERE` placeholders continue after them.
///
/// # Errors
///
/// Fails when the table name is invalid, `payload` is not an object, no valid
/// column is supplied, `build_where_clause` fails or yields an empty clause,
/// the statement cannot be executed, or a row's `data` column cannot be read.
pub async fn update_rows(
    pool: &PgPool,
    table_name: &str,
    conditions: &[Condition],
    payload: &Value,
) -> Result<Vec<Value>> {
    let table_ident: String = match sanitize_identifier(table_name) {
        Some(ident) => ident,
        None => return Err(anyhow!("invalid table name")),
    };
    let fields = payload
        .as_object()
        .context("update payload must be an object")?;

    // Split the valid fields into column names and the values to bind.
    let (columns, new_values): (Vec<String>, Vec<&Value>) = fields
        .iter()
        .filter_map(|(raw_name, raw_value)| {
            sanitize_identifier(raw_name).map(|column| (column, raw_value))
        })
        .unzip();
    if columns.is_empty() {
        return Err(anyhow!("no valid columns provided for update"));
    }

    let assignments: Vec<String> = columns
        .iter()
        .zip(1..)
        .map(|(column, position): (&String, usize)| format!("{} = ${}", column, position))
        .collect();

    // WHERE placeholders start right after the SET placeholders.
    let (where_clause, where_values) = build_where_clause(conditions, columns.len() + 1)?;
    if where_clause.is_empty() {
        return Err(anyhow!("at least one valid condition is required"));
    }

    let statement_text: String = format!(
        "UPDATE {table} AS t SET {set_clause}{where_clause} RETURNING to_jsonb(t.*) AS data",
        table = table_ident,
        set_clause = assignments.join(", "),
        where_clause = where_clause
    );
    let mut statement: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&statement_text);
    for new_value in new_values {
        statement = bind_value_set!(statement, new_value);
    }
    for filter_value in &where_values {
        statement = bind_value!(statement, filter_value);
    }

    let updated_rows: Vec<PgRow> = match statement.fetch_all(pool).await {
        Ok(rows) => rows,
        Err(e) => return Err(anyhow!("failed to execute update rows: {}", e)),
    };

    let mut documents: Vec<Value> = Vec::new();
    for updated in updated_rows {
        let returned: Json<Value> = updated
            .try_get("data")
            .context("missing data column after bulk update")?;
        documents.push(returned.0);
    }
    Ok(documents)
}
/// Selects whole rows from a table as JSON values.
///
/// Each row is returned through `row_to_json(t.*)`. `conditions` are turned
/// into the `WHERE` part by `build_where_clause` (placeholders start at `$1`);
/// `limit` and `offset` are written into the statement as integers.
///
/// # Errors
///
/// Fails when the table name is invalid, `build_where_clause` fails, the
/// query cannot be executed, or a row's `data` column cannot be read.
pub async fn fetch_rows(
    pool: &PgPool,
    table_name: &str,
    conditions: &[Condition],
    limit: i64,
    offset: i64,
) -> Result<Vec<Value>> {
    let table_ident: String = match sanitize_identifier(table_name) {
        Some(ident) => ident,
        None => return Err(anyhow!("invalid table name")),
    };
    let (where_clause, where_values) = build_where_clause(conditions, 1)?;

    let statement_text: String = format!(
        "SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}",
        table = table_ident,
        where_clause = where_clause,
        limit = limit,
        offset = offset
    );
    let mut statement: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&statement_text);
    for filter_value in &where_values {
        statement = bind_value!(statement, filter_value);
    }

    let fetched: Vec<PgRow> = statement
        .fetch_all(pool)
        .await
        .context("failed to execute select query")?;

    // Stops at the first row whose `data` column cannot be decoded.
    fetched
        .iter()
        .map(|row| -> Result<Value> {
            let returned: Json<Value> = row
                .try_get("data")
                .context("missing data column in select result")?;
            Ok(returned.0)
        })
        .collect()
}
/// Deletes the rows matching `conditions` and returns the deleted rows.
///
/// An empty `WHERE` clause is refused, so this function never issues an
/// unconditional delete. Deleted rows are read back through
/// `RETURNING to_jsonb(t.*)`.
///
/// # Errors
///
/// Fails when the table name is invalid, `build_where_clause` fails or yields
/// an empty clause, the statement cannot be executed, or a row's `data`
/// column cannot be read.
pub async fn delete_rows(
    pool: &PgPool,
    table_name: &str,
    conditions: &[Condition],
) -> Result<Vec<Value>> {
    let table_ident: String = match sanitize_identifier(table_name) {
        Some(ident) => ident,
        None => return Err(anyhow!("invalid table name")),
    };

    let (where_clause, where_values) = build_where_clause(conditions, 1)?;
    if where_clause.is_empty() {
        return Err(anyhow!("at least one valid condition is required"));
    }

    let statement_text: String = format!(
        "DELETE FROM {table} AS t{where_clause} RETURNING to_jsonb(t.*) AS data",
        table = table_ident,
        where_clause = where_clause
    );
    let mut statement: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&statement_text);
    for filter_value in &where_values {
        statement = bind_value!(statement, filter_value);
    }

    let deleted: Vec<PgRow> = statement
        .fetch_all(pool)
        .await
        .context("failed to execute delete rows")?;

    // Stops at the first row whose `data` column cannot be decoded.
    deleted
        .iter()
        .map(|row| -> Result<Value> {
            let returned: Json<Value> = row
                .try_get("data")
                .context("missing data column after delete")?;
            Ok(returned.0)
        })
        .collect()
}
/// Selects rows from a table, optionally restricted to a list of columns.
///
/// When `columns` is empty or contains `"*"`, whole rows are returned through
/// `row_to_json(t.*)`. Otherwise the requested names are passed to
/// `resolve_columns`, and each row is returned as a JSON object built with
/// `jsonb_build_object`, keyed by the name the caller asked for.
///
/// `conditions` are turned into the `WHERE` part by `build_where_clause`
/// (placeholders start at `$1`); `limit` and `offset` are written into the
/// statement as integers.
///
/// # Errors
///
/// Fails when the table name is invalid, `build_where_clause` or
/// `resolve_columns` fails, no resolved column is a valid identifier, the
/// query cannot be executed, or a row's `data` column cannot be read.
pub async fn fetch_rows_with_columns(
    pool: &PgPool,
    table_name: &str,
    columns: &[&str],
    conditions: &[Condition],
    limit: i64,
    offset: i64,
) -> Result<Vec<Value>> {
    let table: String =
        sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
    let use_all_columns: bool = columns.is_empty() || columns.contains(&"*");
    let (where_clause, where_values) = build_where_clause(conditions, 1)?;

    // Only the projection differs between the two modes; the rest of the
    // statement is shared.
    let projection: String = if use_all_columns {
        "row_to_json(t.*)".to_string()
    } else {
        let resolved_columns =
            crate::drivers::postgresql::column_resolver::resolve_columns(pool, table_name, columns)
                .await?;
        let column_pairs: Vec<String> = columns
            .iter()
            .zip(resolved_columns.iter())
            .filter_map(|(requested, resolved)| {
                sanitize_identifier(resolved).map(|sanitized| {
                    // The requested name becomes a SQL string literal (the JSON
                    // key), so embedded single quotes must be doubled to keep
                    // caller input from terminating the literal.
                    let json_key = requested.replace('\'', "''");
                    format!("'{}', t.{}", json_key, sanitized)
                })
            })
            .collect();
        if column_pairs.is_empty() {
            return Err(anyhow!("no valid columns specified"));
        }
        format!("jsonb_build_object({})", column_pairs.join(", "))
    };

    let sql: String = format!(
        "SELECT {projection} AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}",
        projection = projection,
        table = table,
        where_clause = where_clause,
        limit = limit,
        offset = offset
    );

    let mut query: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
    let binding_descriptions: Vec<String> = describe_bind_values(&where_values);
    info!(
        sql = %sql,
        bindings = ?binding_descriptions,
        "executing select query"
    );
    for value in &where_values {
        query = bind_value!(query, value);
    }

    let rows: Vec<PgRow> = query
        .fetch_all(pool)
        .await
        .map_err(|err| {
            error!(
                sql = %sql,
                bindings = ?binding_descriptions,
                error = ?err,
                "failed to execute select query"
            );
            err
        })
        .context("failed to execute select query")?;

    // Log only how many rows came back; dumping the rows themselves would
    // copy table contents into the logs on every query.
    info!(row_count = rows.len(), "select query returned rows");

    let mut result: Vec<Value> = Vec::with_capacity(rows.len());
    for row in rows {
        let data: Json<Value> = row
            .try_get("data")
            .context("missing data column in select result")?;
        result.push(data.0);
    }
    Ok(result)
}
#[doc(hidden)]
pub fn describe_bind_values(values: &[Value]) -> Vec<String> {
values.iter().map(describe_bind_value).collect()
}
/// Describes one bind value as text, tagged with its JSON kind.
///
/// Scalars render as `<value> (<kind>)`; numbers are tagged `i64`, `u64`,
/// `f64`, or `number`, tried in that order. Arrays and objects render only
/// their length, as `array(len=N)` / `object(len=N)`.
#[doc(hidden)]
pub fn describe_bind_value(value: &Value) -> String {
    match value {
        Value::Object(members) => format!("object(len={})", members.len()),
        Value::Array(items) => format!("array(len={})", items.len()),
        Value::String(text) => format!("{} (string)", text),
        Value::Number(number) => match (number.as_i64(), number.as_u64(), number.as_f64()) {
            (Some(signed), _, _) => format!("{} (i64)", signed),
            (None, Some(unsigned), _) => format!("{} (u64)", unsigned),
            (None, None, Some(float)) => format!("{} (f64)", float),
            (None, None, None) => format!("{} (number)", number),
        },
        Value::Bool(flag) => format!("{} (bool)", flag),
        Value::Null => "null (null)".to_string(),
    }
}