use crate::parser::query_builder::{
Condition, build_insert_placeholders, build_where_clause, sanitize_identifier,
};
use ::sqlx::postgres::{PgArguments, PgPool, PgPoolOptions};
use ::sqlx::types::Json;
use ::sqlx::{Postgres, Row};
use anyhow::{Context, Result, anyhow};
use serde_json::Value;
use sqlx::postgres::PgRow;
use sqlx::query::Query;
use std::collections::HashMap;
use std::convert::TryFrom;
pub struct PostgresClientRegistry {
pools: HashMap<String, PgPool>,
}
impl PostgresClientRegistry {
pub fn empty() -> Self {
Self {
pools: HashMap::new(),
}
}
pub async fn from_entries(entries: Vec<(String, String)>) -> Result<Self> {
let mut pools: HashMap<String, PgPool> = HashMap::new();
for (client_name, uri) in entries {
let pool = PgPoolOptions::new()
.max_connections(50) .min_connections(5) .acquire_timeout(std::time::Duration::from_secs(3))
.idle_timeout(std::time::Duration::from_secs(300))
.max_lifetime(std::time::Duration::from_secs(1800))
.test_before_acquire(false) .connect(&uri)
.await
.with_context(|| format!("failed to connect to postgres client {}", client_name))?;
pools.insert(client_name, pool);
}
Ok(Self { pools })
}
pub fn get_pool(&self, key: &str) -> Option<PgPool> {
self.pools.get(key).cloned()
}
}
pub async fn insert_row(pool: &PgPool, table_name: &str, payload: &Value) -> Result<Value> {
let table = sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
let entries = payload
.as_object()
.context("insert payload must be an object")?
.iter()
.filter_map(|(column, value)| {
sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
})
.collect::<Vec<_>>();
if entries.is_empty() {
return Err(anyhow!("no valid columns provided for insert"));
}
let columns: Vec<&str> = entries
.iter()
.map(|(column, _)| column.as_str())
.collect::<Vec<_>>();
let values: Vec<&Value> = entries.iter().map(|(_, value)| value).collect();
let (placeholders, bind_values) = build_insert_placeholders(&values);
let sql: String = format!(
"INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) RETURNING to_jsonb(t.*) AS data",
table = table,
columns = columns.join(", "),
placeholders = placeholders.join(", ")
);
let mut query: Query<'_, Postgres, PgArguments> = ::sqlx::query(&sql);
for value in bind_values {
query = bind_value(query, value);
}
let row: PgRow = query
.fetch_one(pool)
.await
.context("failed to execute insert row")?;
let data: Json<Value> = row
.try_get("data")
.context("missing data column after insert")?;
Ok(data.0)
}
pub async fn update_row(
pool: &PgPool,
table_name: &str,
conditions: &[Condition],
payload: &Value,
) -> Result<Value> {
let table: String =
sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
let entries: Vec<(String, Value)> = payload
.as_object()
.context("update payload must be an object")?
.iter()
.filter_map(|(column, value)| {
sanitize_identifier(column).map(|sanitized| (sanitized, value.clone()))
})
.collect::<Vec<_>>();
if entries.is_empty() {
return Err(anyhow!("no valid columns provided for update"));
}
let set_parts: Vec<String> = entries
.iter()
.enumerate()
.map(|(idx, (column, _))| format!("{} = ${}", column, idx + 1))
.collect::<Vec<_>>();
let (where_clause, where_values) = build_where_clause(conditions, entries.len() + 1)?;
if where_clause.is_empty() {
return Err(anyhow!("at least one valid condition is required"));
}
let sql: String = format!(
"UPDATE {table} AS t SET {set_clause}{where_clause} RETURNING to_jsonb(t.*) AS data",
table = table,
set_clause = set_parts.join(", "),
where_clause = where_clause
);
let mut query: Query<'_, Postgres, PgArguments> = ::sqlx::query(&sql);
for (_, value) in &entries {
query = bind_value(query, value);
}
for value in &where_values {
query = bind_value(query, value);
}
let row: PgRow = query
.fetch_one(pool)
.await
.context("failed to execute update row")?;
let data: Json<Value> = row
.try_get("data")
.context("missing data column after update")?;
Ok(data.0)
}
/// Selects rows from `table_name` that satisfy `conditions` and returns one JSON object per row
/// (each row is converted in Postgres with `row_to_json`).
///
/// `limit` and `offset` page through the result. No `ORDER BY` is emitted, so Postgres does not
/// guarantee a consistent row order between pages.
///
/// # Errors
///
/// Fails if the table name is rejected by `sanitize_identifier`, if `limit` or `offset` is
/// negative, if `build_where_clause` fails, if the query fails, or if a row's `data` column
/// cannot be decoded.
pub async fn fetch_rows(
    pool: &PgPool,
    table_name: &str,
    conditions: &[Condition],
    limit: i64,
    offset: i64,
) -> Result<Vec<Value>> {
    let table: String =
        sanitize_identifier(table_name).ok_or_else(|| anyhow!("invalid table name"))?;
    if limit < 0 {
        return Err(anyhow!("limit must be non-negative, got {}", limit));
    }
    if offset < 0 {
        return Err(anyhow!("offset must be non-negative, got {}", offset));
    }
    let (where_clause, where_values) = build_where_clause(conditions, 1)?;

    // LIMIT/OFFSET are bound as parameters numbered after the WHERE placeholders. Formatting
    // them into the SQL text made every page a distinct statement, which defeats sqlx's
    // per-connection prepared-statement cache. This assumes build_where_clause numbers its
    // placeholders $1..=$len, one per returned value (update_row relies on the same).
    let limit_param = where_values.len() + 1;
    let offset_param = limit_param + 1;
    let sql: String = format!(
        "SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause} LIMIT ${limit_param} OFFSET ${offset_param}",
        table = table,
        where_clause = where_clause,
        limit_param = limit_param,
        offset_param = offset_param
    );

    let mut query: Query<'_, Postgres, PgArguments> = ::sqlx::query(&sql);
    for value in &where_values {
        query = bind_value(query, value);
    }
    let query = query.bind(limit).bind(offset);

    let rows: Vec<PgRow> = query
        .fetch_all(pool)
        .await
        .with_context(|| format!("failed to execute select query: {}", sql))?;
    rows.iter()
        .map(|row| {
            row.try_get::<Json<Value>, _>("data")
                .map(|data| data.0)
                .context("missing data column in select result")
        })
        .collect()
}
/// Binds one JSON value to the next positional parameter of `query` and returns the query.
///
/// Mapping:
/// - `null` → NULL, sent with sqlx's `String` (text) type
/// - booleans → `bool`
/// - numbers → `i64` when the value fits, otherwise `f64`; the textual form is used only if
///   the number has no `f64` representation
/// - strings → `String`
/// - arrays and objects → `Json` wrapper around the value
fn bind_value<'q>(
    query: Query<'q, Postgres, PgArguments>,
    value: &Value,
) -> Query<'q, Postgres, PgArguments> {
    match value {
        // NOTE(review): this NULL is typed as text; Postgres may reject it for non-text target
        // columns in INSERT/UPDATE — confirm against real callers.
        Value::Null => query.bind(None::<String>),
        Value::Bool(b) => query.bind(*b),
        Value::Number(num) => {
            // No separate u64 branch: for integers serde_json's `as_f64` returns `Some`
            // (possibly rounded), so integers outside the i64 range are bound as f64 and lose
            // precision beyond 2^53. A Postgres bigint could not hold them exactly either.
            if let Some(i) = num.as_i64() {
                query.bind(i)
            } else if let Some(f) = num.as_f64() {
                query.bind(f)
            } else {
                query.bind(num.to_string())
            }
        }
        Value::String(s) => query.bind(s.clone()),
        Value::Array(_) | Value::Object(_) => query.bind(Json(value.clone())),
    }
}