use crate::parser::query_builder::{
Condition, build_insert_placeholders, build_where_clause, sanitize_identifier,
sanitize_qualified_table_identifier,
};
use crate::utils::sqlx_postgres_connect_uri::sanitize_sqlx_postgres_connect_uri;
use ::sqlx::postgres::{PgArguments, PgPool, PgPoolOptions};
use ::sqlx::types::Json;
use ::sqlx::{Postgres, Row};
use anyhow::{Context, Result, anyhow};
use serde_json::Value;
use sqlx::FromRow;
use sqlx::postgres::PgRow;
use sqlx::query::Query;
use std::collections::HashMap;
use std::convert::TryFrom;
pub struct PostgresClientRegistry {
pools: HashMap<String, PgPool>,
}
impl PostgresClientRegistry {
pub fn empty() -> Self {
Self {
pools: HashMap::new(),
}
}
pub async fn from_entries(entries: Vec<(String, String)>) -> Result<Self> {
let mut pools: HashMap<String, PgPool> = HashMap::new();
for (client_name, uri) in entries {
let sanitized_uri = sanitize_sqlx_postgres_connect_uri(&uri);
let pool: sqlx::Pool<Postgres> = PgPoolOptions::new()
.max_connections(50) .min_connections(5) .acquire_timeout(std::time::Duration::from_secs(3))
.idle_timeout(std::time::Duration::from_secs(300))
.max_lifetime(std::time::Duration::from_secs(1800))
.test_before_acquire(false) .connect(sanitized_uri.as_ref())
.await
.with_context(|| format!("failed to connect to Athena client {}", client_name))?;
pools.insert(client_name, pool);
}
Ok(Self { pools })
}
pub fn get_pool(&self, key: &str) -> Option<PgPool> {
self.pools.get(key).cloned()
}
}
pub async fn insert_row(pool: &PgPool, table_name: &str, payload: &Value) -> Result<Value> {
let table: String = sanitize_qualified_table_identifier(table_name)
.ok_or_else(|| anyhow!("invalid table name"))?;
let entries: Vec<(String, Value)> = payload
.as_object()
.context("insert payload must be an object")?
.iter()
.filter_map(|(column, value)| {
sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
})
.collect::<Vec<_>>();
if entries.is_empty() {
return Err(anyhow!("no valid columns provided for insert"));
}
let columns: Vec<&str> = entries
.iter()
.map(|(column, _)| column.as_str())
.collect::<Vec<_>>();
let values: Vec<&Value> = entries.iter().map(|(_, value)| value).collect();
let (placeholders, bind_values) = build_insert_placeholders(&values);
let sql: String = format!(
"INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) RETURNING to_jsonb(t.*) AS data",
table = table,
columns = columns.join(", "),
placeholders = placeholders.join(", ")
);
let mut query: Query<'_, Postgres, PgArguments> = ::sqlx::query(&sql);
for value in bind_values {
query = bind_value(query, value);
}
let row: PgRow = query
.fetch_one(pool)
.await
.context("failed to execute insert row")?;
let data: Json<Value> = row
.try_get("data")
.context("missing data column after insert")?;
Ok(data.0)
}
pub async fn insert_row_with_schema_coercion(
pool: &PgPool,
table_name: &str,
payload: &Value,
) -> Result<Value> {
let table: String = sanitize_qualified_table_identifier(table_name)
.ok_or_else(|| anyhow!("invalid table name"))?;
let (schema_name, bare_table_name) = split_qualified_table_name(&table);
let column_catalog = load_table_insert_column_metadata(pool, schema_name, bare_table_name)
.await
.with_context(|| format!("failed to load insert metadata for table {table}"))?;
let entries = payload
.as_object()
.context("insert payload must be an object")?
.iter()
.filter_map(|(column, value)| {
let normalized_column = normalize_identifier_lookup_key(column);
let sanitized = sanitize_identifier(&normalized_column)?;
let metadata = column_catalog.get(&normalized_column);
if value.is_null() && metadata.is_some_and(TableInsertColumnMetadata::omit_null_value) {
return None;
}
Some((
sanitized,
value.clone(),
build_insert_placeholder(metadata, value),
))
})
.collect::<Vec<_>>();
if entries.is_empty() {
return Err(anyhow!("no valid columns provided for insert"));
}
let columns: Vec<&str> = entries
.iter()
.map(|(column, _, _)| column.as_str())
.collect::<Vec<_>>();
let placeholders: Vec<String> = entries
.iter()
.enumerate()
.map(|(index, (_, _, cast_type))| match cast_type {
Some(cast_type) => format!("${}::{cast_type}", index + 1),
None => format!("${}", index + 1),
})
.collect::<Vec<_>>();
let sql: String = format!(
"INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) RETURNING to_jsonb(t.*) AS data",
table = table,
columns = columns.join(", "),
placeholders = placeholders.join(", ")
);
let cast_plan = describe_insert_cast_plan(&entries);
let mut query: Query<'_, Postgres, PgArguments> = ::sqlx::query(&sql);
for (_, value, _) in &entries {
query = bind_value(query, value);
}
let row: PgRow = query.fetch_one(pool).await.with_context(|| {
format!("failed to execute insert row for table {table} with cast plan [{cast_plan}]")
})?;
let data: Json<Value> = row
.try_get("data")
.context("missing data column after insert")?;
Ok(data.0)
}
pub async fn update_row(
pool: &PgPool,
table_name: &str,
conditions: &[Condition],
payload: &Value,
) -> Result<Value> {
let table: String =
sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
let entries: Vec<(String, Value)> = payload
.as_object()
.context("update payload must be an object")?
.iter()
.filter_map(|(column, value)| {
sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
})
.collect::<Vec<_>>();
if entries.is_empty() {
return Err(anyhow!("no valid columns provided for update"));
}
let set_parts: Vec<String> = entries
.iter()
.enumerate()
.map(|(idx, (column, _))| format!("{} = ${}", column, idx + 1))
.collect::<Vec<_>>();
let (where_clause, where_values) = build_where_clause(conditions, entries.len() + 1)?;
if where_clause.is_empty() {
return Err(anyhow!("at least one valid condition is required"));
}
let sql: String = format!(
"UPDATE {table} AS t SET {set_clause}{where_clause} RETURNING to_jsonb(t.*) AS data",
table = table,
set_clause = set_parts.join(", "),
where_clause = where_clause
);
let mut query: Query<'_, Postgres, PgArguments> = ::sqlx::query(&sql);
for (_, value) in &entries {
query = bind_value(query, value);
}
for value in &where_values {
query = bind_value(query, value);
}
let row: PgRow = query
.fetch_one(pool)
.await
.context("failed to execute update row")?;
let data: Json<Value> = row
.try_get("data")
.context("missing data column after update")?;
Ok(data.0)
}
pub async fn fetch_rows(
pool: &PgPool,
table_name: &str,
conditions: &[Condition],
limit: i64,
offset: i64,
) -> Result<Vec<Value>> {
let table: String =
sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
let (where_clause, where_values) = build_where_clause(conditions, 1)?;
let sql: String = format!(
"SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause} LIMIT {limit} OFFSET {offset}",
table = table,
where_clause = where_clause,
limit = limit,
offset = offset
);
let mut query: Query<'_, Postgres, PgArguments> = ::sqlx::query(&sql);
for value in &where_values {
query = bind_value(query, value);
}
let rows: Vec<PgRow> = query
.fetch_all(pool)
.await
.with_context(|| format!("failed to execute select query: {}", sql))?;
let mut result: Vec<Value> = Vec::new();
for row in rows {
let data: Json<Value> = row
.try_get("data")
.context("missing data column in select result")?;
result.push(data.0);
}
Ok(result)
}
/// Binds one JSON value as the next positional parameter of `query`.
///
/// Mapping:
/// - `null` -> a NULL `String` parameter;
/// - booleans -> `bool`;
/// - integers that fit `i64` -> `i64`;
/// - larger unsigned integers -> their exact decimal text, because no Postgres
///   integer type holds them and rounding through `f64` would silently change
///   the value;
/// - other numbers -> `f64`, falling back to their text if `as_f64` yields nothing;
/// - strings -> `String`;
/// - arrays / objects -> `Json<Value>`.
fn bind_value<'q>(
    query: Query<'q, Postgres, PgArguments>,
    value: &Value,
) -> Query<'q, Postgres, PgArguments> {
    match value {
        Value::Null => query.bind(None::<String>),
        Value::Bool(b) => query.bind(*b),
        Value::Number(num) => {
            if let Some(i) = num.as_i64() {
                query.bind(i)
            } else if let Some(u) = num.as_u64() {
                // Must be tried before `as_f64`: `as_f64` succeeds (lossily) for every
                // u64, which previously made this branch unreachable. Because `as_i64`
                // already took every value <= i64::MAX, this is in practice the exact
                // text path; `try_from` stays as a defensive check.
                match i64::try_from(u) {
                    Ok(i) => query.bind(i),
                    Err(_) => query.bind(num.to_string()),
                }
            } else if let Some(f) = num.as_f64() {
                query.bind(f)
            } else {
                query.bind(num.to_string())
            }
        }
        Value::String(s) => query.bind(s.clone()),
        Value::Array(_) | Value::Object(_) => query.bind(Json(value.clone())),
    }
}
/// One row of `information_schema.columns`, used to decide how an insert
/// placeholder for that column should be cast.
#[derive(Debug, Clone, FromRow)]
struct TableInsertColumnMetadata {
    /// Column name as reported by the catalog.
    column_name: String,
    /// SQL-standard type name, e.g. `timestamp with time zone`.
    data_type: String,
    /// Underlying Postgres type name, e.g. `timestamptz` or `int4`.
    udt_name: String,
    /// Default expression, when the column has one.
    column_default: Option<String>,
    /// `"YES"` for identity columns.
    is_identity: String,
}

impl TableInsertColumnMetadata {
    /// Whether a JSON `null` for this column should be left out of the INSERT,
    /// letting the database fill it (the column has a default or is an identity
    /// column).
    fn omit_null_value(&self) -> bool {
        self.column_default.is_some() || self.is_identity == "YES"
    }

    /// Returns the explicit cast to attach to this column's placeholder when
    /// binding `value`, or `None` to bind the parameter uncast.
    ///
    /// `bind_value` sends JSON strings and nulls as `String` parameters, which
    /// Postgres will not implicitly store into e.g. uuid, integer or timestamptz
    /// columns. Both therefore get a cast to the column's type. Nulls are
    /// included so that a NULL into a nullable non-text column without a default
    /// is accepted. JSON arrays and objects only need a cast for json/jsonb
    /// columns. Booleans and numbers are never cast.
    ///
    /// The type is resolved from `data_type` first and `udt_name` second.
    /// `data_type` is authoritative because it spells out the time-zone variant,
    /// e.g. `timestamp with time zone`.
    fn cast_type(&self, value: &Value) -> Option<&'static str> {
        match value {
            Value::Null | Value::String(_) => self.resolve_cast(Self::scalar_cast_for),
            Value::Array(_) | Value::Object(_) => self.resolve_cast(Self::json_cast_for),
            Value::Bool(_) | Value::Number(_) => None,
        }
    }

    /// Applies `lookup` to the lower-cased `data_type`, falling back to `udt_name`.
    fn resolve_cast(&self, lookup: fn(&str) -> Option<&'static str>) -> Option<&'static str> {
        lookup(self.data_type.to_ascii_lowercase().as_str())
            .or_else(|| lookup(self.udt_name.as_str()))
    }

    /// Maps a Postgres type name (SQL-standard or internal spelling) to the
    /// cast used for text-bound parameters.
    fn scalar_cast_for(type_name: &str) -> Option<&'static str> {
        match type_name {
            "uuid" => Some("uuid"),
            "json" => Some("json"),
            "jsonb" => Some("jsonb"),
            "bool" | "boolean" => Some("boolean"),
            "date" => Some("date"),
            "time" | "time without time zone" => Some("time"),
            "timetz" | "time with time zone" => Some("timetz"),
            "timestamp" | "timestamp without time zone" => Some("timestamp"),
            "timestamptz" | "timestamp with time zone" => Some("timestamptz"),
            "int2" | "smallint" => Some("smallint"),
            "int4" | "integer" => Some("integer"),
            "int8" | "bigint" => Some("bigint"),
            "numeric" => Some("numeric"),
            "float4" | "real" => Some("real"),
            "float8" | "double precision" => Some("double precision"),
            _ => None,
        }
    }

    /// Maps a Postgres type name to the cast used for JSON-bound parameters.
    fn json_cast_for(type_name: &str) -> Option<&'static str> {
        match type_name {
            "json" => Some("json"),
            "jsonb" => Some("jsonb"),
            _ => None,
        }
    }
}
/// Renders a compact summary of how each insert column will be bound, as
/// comma-separated `column:kind->cast` items. `raw` means no cast is applied.
/// Used in the error context of `insert_row_with_schema_coercion`.
fn describe_insert_cast_plan(entries: &[(String, Value, Option<&'static str>)]) -> String {
    let mut rendered: Vec<String> = Vec::with_capacity(entries.len());
    for (column, value, cast_type) in entries {
        let value_kind = match value {
            Value::Null => "null",
            Value::Bool(_) => "bool",
            Value::Number(_) => "number",
            Value::String(_) => "string",
            Value::Array(_) => "array",
            Value::Object(_) => "object",
        };
        let target = cast_type.unwrap_or("raw");
        rendered.push(format!("{column}:{value_kind}->{target}"));
    }
    rendered.join(", ")
}
/// Splits a possibly schema-qualified table name at its first `.` into
/// `(schema, table)`.
///
/// Each part is trimmed and stripped of surrounding `"` characters. A name
/// without a `.` is reported under the `public` schema.
fn split_qualified_table_name(table_name: &str) -> (&str, &str) {
    fn unquote(segment: &str) -> &str {
        segment.trim().trim_matches('"')
    }

    let Some((schema_name, bare_table_name)) = table_name.split_once('.') else {
        return ("public", unquote(table_name));
    };
    (unquote(schema_name), unquote(bare_table_name))
}
/// Produces the key used to match payload column names against catalog
/// column names: surrounding whitespace is trimmed, then any leading or
/// trailing `"` characters are removed.
fn normalize_identifier_lookup_key(value: &str) -> String {
    let unquoted = value.trim().trim_matches('"');
    String::from(unquoted)
}
/// Chooses the cast for one insert placeholder. This is the column's
/// `cast_type` for `value` when catalog metadata exists, otherwise `None`
/// (bind uncast).
fn build_insert_placeholder(
    metadata: Option<&TableInsertColumnMetadata>,
    value: &Value,
) -> Option<&'static str> {
    metadata?.cast_type(value)
}
/// Loads the `information_schema.columns` rows for `schema_name`.`table_name`,
/// keyed by `normalize_identifier_lookup_key(column_name)`.
///
/// A table with no matching catalog rows yields an empty map, not an error.
///
/// # Errors
/// Propagates any error from running the catalog query.
async fn load_table_insert_column_metadata(
    pool: &PgPool,
    schema_name: &str,
    table_name: &str,
) -> Result<HashMap<String, TableInsertColumnMetadata>> {
    const COLUMN_METADATA_SQL: &str = r#"
SELECT
column_name,
data_type,
udt_name,
column_default,
is_identity
FROM information_schema.columns
WHERE table_schema = $1
AND table_name = $2
"#;
    let columns: Vec<TableInsertColumnMetadata> =
        sqlx::query_as::<_, TableInsertColumnMetadata>(COLUMN_METADATA_SQL)
            .bind(schema_name)
            .bind(table_name)
            .fetch_all(pool)
            .await?;
    let mut catalog = HashMap::with_capacity(columns.len());
    for column in columns {
        catalog.insert(normalize_identifier_lookup_key(&column.column_name), column);
    }
    Ok(catalog)
}
#[cfg(test)]
mod tests {
    use super::{
        TableInsertColumnMetadata, describe_insert_cast_plan, normalize_identifier_lookup_key,
        split_qualified_table_name,
    };
    use serde_json::json;

    /// Quoted and unquoted qualified names resolve to the same (schema, table) pair.
    #[test]
    fn split_qualified_table_name_unquotes_segments() {
        let expected = ("billing", "billing_projection_events");
        for qualified in [
            "\"billing\".\"billing_projection_events\"",
            "billing.billing_projection_events",
        ] {
            assert_eq!(split_qualified_table_name(qualified), expected);
        }
    }

    /// Quoted and bare column names normalize to the same lookup key.
    #[test]
    fn normalize_identifier_lookup_key_unquotes_column_names() {
        for raw in ["\"received_at\"", "received_at"] {
            assert_eq!(normalize_identifier_lookup_key(raw), "received_at");
        }
    }

    /// The SQL-standard `data_type` must decide the cast even when `udt_name`
    /// names the time-zone-less variant.
    #[test]
    fn cast_type_uses_data_type_fallback_for_timestamptz() {
        let metadata = TableInsertColumnMetadata {
            column_name: String::from("received_at"),
            data_type: String::from("timestamp with time zone"),
            udt_name: String::from("timestamp"),
            column_default: None,
            is_identity: String::from("NO"),
        };
        let value = json!("2026-07-02T10:57:02.000Z");
        assert_eq!(metadata.cast_type(&value), Some("timestamptz"));
    }

    /// Cast columns show their target type; uncast columns show `raw`.
    #[test]
    fn describe_insert_cast_plan_marks_raw_and_typed_columns() {
        let entries = [
            (
                String::from("received_at"),
                json!("2026-07-02T10:57:02.000Z"),
                Some("timestamptz"),
            ),
            (String::from("provider"), json!("mollie"), None),
        ];
        assert_eq!(
            describe_insert_cast_plan(&entries),
            "received_at:string->timestamptz, provider:string->raw"
        );
    }
}