use async_trait::async_trait;
use base64::Engine as _;
use deadpool_postgres::{Config, Pool, Runtime};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_postgres::types::{FromSql, ToSql, Type};
use tokio_postgres::NoTls;
use crate::context::ExecutionContext;
use crate::error::ToolError;
use crate::registry::{Tool, ToolConfig};
use crate::result::ToolResult;
use crate::template::TemplateEngine;
/// Configuration for one `postgres` tool invocation, deserialized from the (template-rendered)
/// tool config.
///
/// Optional connection fields are omitted from serialized output when unset.
/// NOTE(review): the derived `Debug` prints `password` and `connection_string` verbatim — avoid
/// logging this struct with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostgresConfig {
/// SQL to execute; the key `command` is accepted as an alias. When `params` is empty the text
/// may contain several `;`-separated statements.
#[serde(alias = "command")]
pub query: String,
/// Positional values bound to `$1`, `$2`, ...; defaults to an empty list.
#[serde(default)]
pub params: Vec<serde_json::Value>,
/// Complete connection URL. When present, `host`/`port`/`database`/`user`/`password` are ignored.
#[serde(skip_serializing_if = "Option::is_none")]
pub connection_string: Option<String>,
/// Server host; `localhost` when omitted.
#[serde(skip_serializing_if = "Option::is_none")]
pub host: Option<String>,
/// Server port; `5432` when omitted.
#[serde(skip_serializing_if = "Option::is_none")]
pub port: Option<u16>,
/// Database name; `postgres` when omitted.
#[serde(skip_serializing_if = "Option::is_none")]
pub database: Option<String>,
/// Login role; `postgres` when omitted.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// Password, or the name of a secret: it is first looked up through
/// `ExecutionContext::get_secret` and used literally when no secret matches.
/// When omitted, no password is placed in the connection URL.
#[serde(skip_serializing_if = "Option::is_none")]
pub password: Option<String>,
/// Schema applied as the `search_path` before the query is executed.
#[serde(skip_serializing_if = "Option::is_none")]
pub schema: Option<String>,
/// When true (the default), result rows are JSON objects keyed by column name; when false,
/// each row is an array of values in column order.
#[serde(default = "default_as_objects")]
pub as_objects: bool,
}
/// Serde default for `PostgresConfig::as_objects`: rows are returned as objects.
fn default_as_objects() -> bool {
true
}
/// Recognises a dollar-quote delimiter (`$$` or `$tag$`) starting at `chars[i]`.
///
/// Returns the full delimiter text, both `$` signs included. Returns `None` in three cases:
/// `chars[i]` is not `$`; the tag starts with a digit, so positional parameters such as `$1`
/// are not mistaken for quotes; or no closing `$` follows the tag characters.
fn match_dollar_tag(chars: &[char], i: usize) -> Option<String> {
    if chars.get(i) != Some(&'$') {
        return None;
    }
    let mut j = i + 1;
    if let Some(&first) = chars.get(j) {
        if first.is_ascii_digit() {
            return None;
        }
    }
    while let Some(&c) = chars.get(j) {
        if c.is_alphanumeric() || c == '_' {
            j += 1;
        } else {
            break;
        }
    }
    if chars.get(j) == Some(&'$') {
        Some(chars[i..=j].iter().collect())
    } else {
        None
    }
}

/// Characters that can continue an SQL word. Used to decide whether an `E` directly before a
/// quote is an escape-string prefix (`E'...'`) or the tail of a longer word such as `ELSE`.
fn is_sql_word_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '$'
}

/// Returns the index one past the closing `quote` of a literal that opens at `chars[start]`.
///
/// A doubled quote (`''` / `""`) is an escaped quote. With `backslash_escapes`, which applies to
/// PostgreSQL `E'...'` strings, a backslash also escapes the following character. An
/// unterminated literal runs to the end of input.
fn scan_quoted(chars: &[char], start: usize, quote: char, backslash_escapes: bool) -> usize {
    let mut j = start + 1;
    while j < chars.len() {
        match chars[j] {
            '\\' if backslash_escapes => j += 2,
            c if c == quote => {
                if chars.get(j + 1) == Some(&quote) {
                    j += 2;
                } else {
                    return j + 1;
                }
            }
            _ => j += 1,
        }
    }
    chars.len()
}

/// Returns the index one past the `*/` that closes the block comment opened at `chars[start]`.
///
/// PostgreSQL block comments nest, so inner `/* ... */` pairs are tracked. An unterminated
/// comment runs to the end of input.
fn scan_block_comment(chars: &[char], start: usize) -> usize {
    let mut depth = 0usize;
    let mut j = start;
    while j < chars.len() {
        match (chars[j], chars.get(j + 1).copied()) {
            ('/', Some('*')) => {
                depth += 1;
                j += 2;
            }
            ('*', Some('/')) => {
                // depth >= 1 here: scanning starts on the opening `/*` and returns at depth 0.
                depth -= 1;
                j += 2;
                if depth == 0 {
                    return j;
                }
            }
            _ => j += 1,
        }
    }
    chars.len()
}

/// Returns the index one past the closing `tag` of a dollar-quoted body whose opening delimiter
/// starts at `chars[start]`. An unterminated body runs to the end of input.
fn scan_dollar_quoted(chars: &[char], start: usize, tag: &str) -> usize {
    let mut j = start + tag.chars().count();
    while j < chars.len() {
        if chars[j] == '$' {
            if let Some(close) = match_dollar_tag(chars, j) {
                if close == tag {
                    return j + close.chars().count();
                }
            }
        }
        j += 1;
    }
    chars.len()
}

/// Splits `sql` into individual statements on top-level `;`.
///
/// Semicolons do not split inside any of these constructs:
/// - single-quoted strings, including `''` escapes and `E'...'` backslash escapes;
/// - double-quoted identifiers;
/// - dollar-quoted bodies (`$$ ... $$`, `$tag$ ... $tag$`);
/// - `--` line comments;
/// - nested `/* ... */` block comments.
///
/// Comments stay part of the statement text. Each statement is trimmed, and empty statements
/// are dropped.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let chars: Vec<char> = sql.chars().collect();
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut i = 0;
    while i < chars.len() {
        let c = chars[i];
        let next = chars.get(i + 1).copied();
        // `end` is the exclusive end of the lexical unit starting at `i`; the whole unit is copied
        // verbatim, so a `;` inside it can never act as a separator.
        let end = match c {
            '\'' => {
                let escape_string = i >= 1
                    && matches!(chars[i - 1], 'E' | 'e')
                    && (i < 2 || !is_sql_word_char(chars[i - 2]));
                scan_quoted(&chars, i, '\'', escape_string)
            }
            '"' => scan_quoted(&chars, i, '"', false),
            '$' => match match_dollar_tag(&chars, i) {
                Some(tag) => scan_dollar_quoted(&chars, i, &tag),
                None => i + 1,
            },
            '-' if next == Some('-') => chars[i..]
                .iter()
                .position(|&ch| ch == '\n')
                .map_or(chars.len(), |offset| i + offset),
            '/' if next == Some('*') => scan_block_comment(&chars, i),
            ';' => {
                let statement = current.trim();
                if !statement.is_empty() {
                    statements.push(statement.to_string());
                }
                current.clear();
                i += 1;
                continue;
            }
            _ => i + 1,
        };
        current.extend(&chars[i..end]);
        i = end;
    }
    let statement = current.trim();
    if !statement.is_empty() {
        statements.push(statement.to_string());
    }
    statements
}
/// Tool that runs SQL against PostgreSQL, caching one connection pool per connection string.
pub struct PostgresTool {
/// Pools keyed by the exact connection string they were created from.
pools: Arc<RwLock<HashMap<String, Pool>>>,
/// Renders template expressions in the raw tool config before it is deserialized.
template_engine: TemplateEngine,
}
impl PostgresTool {
pub fn new() -> Self {
Self {
pools: Arc::new(RwLock::new(HashMap::new())),
template_engine: TemplateEngine::new(),
}
}
async fn get_pool(&self, connection_string: &str) -> Result<Pool, ToolError> {
{
let pools = self.pools.read().await;
if let Some(pool) = pools.get(connection_string) {
return Ok(pool.clone());
}
}
let mut config = Config::new();
config.url = Some(connection_string.to_string());
let pool = config
.create_pool(Some(Runtime::Tokio1), NoTls)
.map_err(|e| ToolError::Database(format!("Failed to create pool: {}", e)))?;
{
let mut pools = self.pools.write().await;
pools.insert(connection_string.to_string(), pool.clone());
}
Ok(pool)
}
fn build_connection_string(
&self,
config: &PostgresConfig,
ctx: &ExecutionContext,
) -> Result<String, ToolError> {
if let Some(ref conn_str) = config.connection_string {
return Ok(conn_str.clone());
}
let host = config.host.as_deref().unwrap_or("localhost");
let port = config.port.unwrap_or(5432);
let database = config.database.as_deref().unwrap_or("postgres");
let user = config.user.as_deref().unwrap_or("postgres");
let password = if let Some(ref pw) = config.password {
ctx.get_secret(pw)
.map(|s| s.to_string())
.unwrap_or_else(|| pw.clone())
} else {
String::new()
};
let conn_str = if password.is_empty() {
format!("postgresql://{}@{}:{}/{}", user, host, port, database)
} else {
format!(
"postgresql://{}:{}@{}:{}/{}",
user, password, host, port, database
)
};
Ok(conn_str)
}
pub async fn execute_query(
&self,
query: &str,
params: &[serde_json::Value],
connection_string: &str,
schema: Option<&str>,
as_objects: bool,
) -> Result<ToolResult, ToolError> {
let start = std::time::Instant::now();
let pool = self.get_pool(connection_string).await?;
let client = pool
.get()
.await
.map_err(|e| ToolError::Database(format!("Failed to get connection: {e}")))?;
if let Some(schema) = schema {
client
.execute(&format!("SET search_path TO {}", schema), &[])
.await
.map_err(|e| ToolError::Database(format!("Failed to set schema: {}", e)))?;
}
let pg_params: Vec<Box<dyn ToSql + Sync + Send>> =
params.iter().map(|v| json_to_pg_param(v)).collect();
let param_refs: Vec<&(dyn ToSql + Sync)> = pg_params
.iter()
.map(|p| p.as_ref() as &(dyn ToSql + Sync))
.collect();
let statements = if params.is_empty() {
split_sql_statements(query)
} else {
vec![query.to_string()]
};
let effective_query: String = if statements.len() > 1 {
let (last, leading) = statements.split_last().unwrap();
let leading_sql = format!("{};", leading.join(";\n"));
client
.batch_execute(&leading_sql)
.await
.map_err(|e| ToolError::Database(format_pg_error("Batch execute failed", &e)))?;
last.clone()
} else {
query.to_string()
};
let query = effective_query.as_str();
let is_select = query.trim().to_uppercase().starts_with("SELECT")
|| query.trim().to_uppercase().starts_with("WITH");
let result = if is_select {
let rows = client
.query(query, ¶m_refs)
.await
.map_err(|e| ToolError::Database(format_pg_error("Query failed", &e)))?;
if rows.is_empty() {
serde_json::json!({
"columns": [],
"rows": [],
"row_count": 0
})
} else {
let columns: Vec<String> = rows[0]
.columns()
.iter()
.map(|c| c.name().to_string())
.collect();
let json_rows: Vec<serde_json::Value> = rows
.iter()
.map(|row| {
if as_objects {
let mut obj = serde_json::Map::new();
for (i, col) in row.columns().iter().enumerate() {
let value = pg_value_to_json(row, i);
obj.insert(col.name().to_string(), value);
}
serde_json::Value::Object(obj)
} else {
let values: Vec<serde_json::Value> = (0..row.columns().len())
.map(|i| pg_value_to_json(row, i))
.collect();
serde_json::Value::Array(values)
}
})
.collect();
serde_json::json!({
"columns": columns,
"rows": json_rows,
"row_count": json_rows.len()
})
}
} else {
let affected = client
.execute(query, ¶m_refs)
.await
.map_err(|e| ToolError::Database(format_pg_error("Execute failed", &e)))?;
serde_json::json!({
"affected_rows": affected
})
};
let duration_ms = start.elapsed().as_millis() as u64;
Ok(ToolResult::success(result).with_duration(duration_ms))
}
fn parse_config(
&self,
config: &ToolConfig,
ctx: &ExecutionContext,
) -> Result<PostgresConfig, ToolError> {
let template_ctx = ctx.to_template_context();
let rendered_config = self
.template_engine
.render_value(&config.config, &template_ctx)?;
serde_json::from_value(rendered_config)
.map_err(|e| ToolError::Configuration(format!("Invalid postgres config: {}", e)))
}
}
impl Default for PostgresTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for PostgresTool {
fn name(&self) -> &'static str {
"postgres"
}
async fn execute(
&self,
config: &ToolConfig,
ctx: &ExecutionContext,
) -> Result<ToolResult, ToolError> {
let pg_config = self.parse_config(config, ctx)?;
let connection_string = self.build_connection_string(&pg_config, ctx)?;
tracing::debug!(
query = %pg_config.query,
params_count = pg_config.params.len(),
schema = ?pg_config.schema,
"Executing PostgreSQL query"
);
self.execute_query(
&pg_config.query,
&pg_config.params,
&connection_string,
pg_config.schema.as_deref(),
pg_config.as_objects,
)
.await
}
}
/// Formats a `tokio_postgres` error as a single line prefixed with `context`.
///
/// Server-side (database) errors report severity, message and SQLSTATE code, plus DETAIL and
/// HINT when present. Other errors report their own message followed by each cause in the
/// `source()` chain that is not already part of the text.
fn format_pg_error(context: &str, e: &tokio_postgres::Error) -> String {
    if let Some(db) = e.as_db_error() {
        let mut msg = format!(
            "{}: {}: {} (SQLSTATE {})",
            context,
            db.severity(),
            db.message(),
            db.code().code()
        );
        if let Some(detail) = db.detail() {
            msg.push_str(&format!(" | DETAIL: {detail}"));
        }
        if let Some(hint) = db.hint() {
            msg.push_str(&format!(" | HINT: {hint}"));
        }
        msg
    } else {
        let mut msg = format!("{context}: {e}");
        let mut src = std::error::Error::source(e);
        while let Some(inner) = src {
            let text = inner.to_string();
            // tokio_postgres::Error's Display already appends its direct cause. Appending every
            // source unconditionally repeated it ("...: Connection refused: Connection refused"),
            // so only add causes whose text is not already present.
            if !msg.contains(&text) {
                msg.push_str(": ");
                msg.push_str(&text);
            }
            src = inner.source();
        }
        msg
    }
}
fn json_to_pg_param(value: &serde_json::Value) -> Box<dyn ToSql + Sync + Send> {
match value {
serde_json::Value::Null => Box::new(Option::<String>::None),
serde_json::Value::Bool(b) => Box::new(*b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
Box::new(i)
} else if let Some(f) = n.as_f64() {
Box::new(f)
} else {
Box::new(n.to_string())
}
}
serde_json::Value::String(s) => Box::new(s.clone()),
_ => Box::new(value.to_string()),
}
}
/// Renders a timestamp without time zone as `YYYY-MM-DDTHH:MM:SS`, with fractional seconds
/// appended only when they are non-zero.
fn naive_datetime_to_json(dt: chrono::NaiveDateTime) -> serde_json::Value {
    let text = dt.format("%Y-%m-%dT%H:%M:%S%.f").to_string();
    serde_json::Value::String(text)
}
/// Renders a date as its `Display` text (`YYYY-MM-DD`).
fn naive_date_to_json(d: chrono::NaiveDate) -> serde_json::Value {
    serde_json::Value::String(d.to_string())
}
/// Renders a time of day as its `Display` text (e.g. `23:59:01`).
fn naive_time_to_json(t: chrono::NaiveTime) -> serde_json::Value {
    serde_json::Value::String(t.to_string())
}
/// Decodes PostgreSQL's binary NUMERIC wire format into exact decimal text.
///
/// Layout: `ndigits: i16`, `weight: i16`, `sign: u16`, `dscale: u16` (big-endian), followed by
/// `ndigits` base-10000 digit groups as big-endian `i16`.
/// - `weight` is the power of 10000 of the first group.
/// - `dscale` is the number of decimal places to print.
/// - Sign `0x4000` means negative. The special sign values `0xC000` / `0xD000` / `0xF000` yield
///   `NaN` / `Infinity` / `-Infinity`.
///
/// Returns `None` for a truncated header, a truncated digit array, or a negative digit count.
fn decode_pg_numeric(buf: &[u8]) -> Option<String> {
    let header = buf.get(..8)?;
    let ndigits = i16::from_be_bytes([header[0], header[1]]);
    // Widen before any arithmetic: negating i16::MIN would overflow.
    let weight = i32::from(i16::from_be_bytes([header[2], header[3]]));
    let sign = u16::from_be_bytes([header[4], header[5]]);
    let dscale = usize::from(u16::from_be_bytes([header[6], header[7]]));
    match sign {
        0xC000 => return Some("NaN".to_string()),
        0xD000 => return Some("Infinity".to_string()),
        0xF000 => return Some("-Infinity".to_string()),
        _ => {}
    }
    // A negative count can only come from a corrupt buffer. A plain `as usize` cast would wrap
    // it to a huge length and overflow the bounds computation below.
    let ndigits = usize::try_from(ndigits).ok()?;
    let digit_bytes = buf.get(8..8 + ndigits * 2)?;
    let digits: Vec<i16> = digit_bytes
        .chunks_exact(2)
        .map(|pair| i16::from_be_bytes([pair[0], pair[1]]))
        .collect();

    let mut int_str = String::new();
    if weight < 0 {
        int_str.push('0');
    } else {
        for i in 0..=weight {
            let g = digits.get(i as usize).copied().unwrap_or(0);
            if i == 0 {
                int_str.push_str(&g.to_string());
            } else {
                int_str.push_str(&format!("{g:04}"));
            }
        }
    }

    let mut frac_str = String::new();
    let start_idx = if weight >= 0 {
        (weight + 1) as usize
    } else {
        // The first stored group lies `-weight - 1` all-zero groups right of the decimal point.
        frac_str.push_str(&"0000".repeat((-weight - 1) as usize));
        0
    };
    for d in digits.iter().skip(start_idx) {
        frac_str.push_str(&format!("{d:04}"));
    }
    if frac_str.len() < dscale {
        frac_str.push_str(&"0".repeat(dscale - frac_str.len()));
    }
    frac_str.truncate(dscale);

    let sign_str = if sign == 0x4000 { "-" } else { "" };
    if dscale == 0 {
        Some(format!("{sign_str}{int_str}"))
    } else {
        Some(format!("{sign_str}{int_str}.{frac_str}"))
    }
}
/// Newtype that lets `Row::try_get` decode a NUMERIC column into exact decimal text instead of
/// a lossy float.
struct PgNumericString(String);
impl<'a> FromSql<'a> for PgNumericString {
    /// Decodes the binary NUMERIC payload via `decode_pg_numeric`.
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        match decode_pg_numeric(raw) {
            Some(text) => Ok(PgNumericString(text)),
            None => Err("failed to decode postgres numeric wire format".into()),
        }
    }

    /// Only NUMERIC columns are accepted.
    fn accepts(ty: &Type) -> bool {
        ty == &Type::NUMERIC
    }
}
fn pg_value_to_json(row: &tokio_postgres::Row, idx: usize) -> serde_json::Value {
if let Ok(v) = row.try_get::<_, Option<i64>>(idx) {
return v
.map(|n| serde_json::json!(n))
.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<i32>>(idx) {
return v
.map(|n| serde_json::json!(n))
.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<f64>>(idx) {
return v
.map(|n| serde_json::json!(n))
.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<bool>>(idx) {
return v
.map(|b| serde_json::json!(b))
.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<String>>(idx) {
return v
.map(|s| serde_json::json!(s))
.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<serde_json::Value>>(idx) {
return v.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx) {
return v
.map(|dt| serde_json::json!(dt.to_rfc3339()))
.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
return v.map(naive_datetime_to_json).unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<chrono::NaiveDate>>(idx) {
return v.map(naive_date_to_json).unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<chrono::NaiveTime>>(idx) {
return v.map(naive_time_to_json).unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<uuid::Uuid>>(idx) {
return v
.map(|u| serde_json::json!(u.to_string()))
.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<PgNumericString>>(idx) {
return v
.map(|n| serde_json::json!(n.0))
.unwrap_or(serde_json::Value::Null);
}
if let Ok(v) = row.try_get::<_, Option<Vec<u8>>>(idx) {
return v
.map(|b| serde_json::json!(base64::engine::general_purpose::STANDARD.encode(b)))
.unwrap_or(serde_json::Value::Null);
}
serde_json::Value::Null
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_postgres_config_deserialization() {
        let json = serde_json::json!({
            "query": "SELECT * FROM users WHERE id = $1",
            "params": [42],
            "connection_string": "postgresql://user:pass@localhost/db"
        });
        let config: PostgresConfig = serde_json::from_value(json).unwrap();
        assert_eq!(config.query, "SELECT * FROM users WHERE id = $1");
        assert_eq!(config.params, vec![serde_json::json!(42)]);
        assert_eq!(
            config.connection_string.as_deref(),
            Some("postgresql://user:pass@localhost/db")
        );
    }

    #[test]
    fn test_temporal_value_to_json() {
        use chrono::{NaiveDate, NaiveTime};
        let dt = NaiveDate::from_ymd_opt(2026, 6, 14)
            .unwrap()
            .and_hms_opt(12, 30, 45)
            .unwrap();
        assert_eq!(
            naive_datetime_to_json(dt),
            serde_json::json!("2026-06-14T12:30:45")
        );
        let dt_frac = NaiveDate::from_ymd_opt(2026, 6, 14)
            .unwrap()
            .and_hms_milli_opt(12, 30, 45, 500)
            .unwrap();
        assert_eq!(
            naive_datetime_to_json(dt_frac),
            serde_json::json!("2026-06-14T12:30:45.500")
        );
        assert_eq!(
            naive_date_to_json(NaiveDate::from_ymd_opt(2026, 6, 14).unwrap()),
            serde_json::json!("2026-06-14")
        );
        assert_eq!(
            naive_time_to_json(NaiveTime::from_hms_opt(23, 59, 1).unwrap()),
            serde_json::json!("23:59:01")
        );
    }

    /// Builds a binary NUMERIC payload from its header fields and base-10000 digit groups.
    fn pg_numeric_bytes(weight: i16, sign: u16, dscale: u16, digits: &[i16]) -> Vec<u8> {
        let mut b = Vec::new();
        b.extend_from_slice(&(digits.len() as i16).to_be_bytes());
        b.extend_from_slice(&weight.to_be_bytes());
        b.extend_from_slice(&sign.to_be_bytes());
        b.extend_from_slice(&dscale.to_be_bytes());
        for d in digits {
            b.extend_from_slice(&d.to_be_bytes());
        }
        b
    }

    #[test]
    fn test_decode_pg_numeric() {
        assert_eq!(decode_pg_numeric(&pg_numeric_bytes(0, 0, 0, &[])).unwrap(), "0");
        assert_eq!(decode_pg_numeric(&pg_numeric_bytes(0, 0, 0, &[1])).unwrap(), "1");
        assert_eq!(
            decode_pg_numeric(&pg_numeric_bytes(0, 0x4000, 0, &[1])).unwrap(),
            "-1"
        );
        assert_eq!(
            decode_pg_numeric(&pg_numeric_bytes(0, 0, 4, &[1234, 5678])).unwrap(),
            "1234.5678"
        );
        assert_eq!(
            decode_pg_numeric(&pg_numeric_bytes(1, 0, 0, &[10])).unwrap(),
            "100000"
        );
        assert_eq!(
            decode_pg_numeric(&pg_numeric_bytes(-1, 0, 1, &[5000])).unwrap(),
            "0.5"
        );
        assert_eq!(
            decode_pg_numeric(&pg_numeric_bytes(-2, 0, 5, &[5000])).unwrap(),
            "0.00005"
        );
        assert_eq!(
            decode_pg_numeric(&pg_numeric_bytes(0, 0, 2, &[100])).unwrap(),
            "100.00"
        );
        assert_eq!(
            decode_pg_numeric(&pg_numeric_bytes(0, 0x4000, 2, &[1, 5000])).unwrap(),
            "-1.50"
        );
        assert_eq!(
            decode_pg_numeric(&pg_numeric_bytes(0, 0xC000, 0, &[])).unwrap(),
            "NaN"
        );
        assert!(decode_pg_numeric(&[0, 1, 0, 0]).is_none());
    }

    #[test]
    fn test_postgres_config_command_alias() {
        let json = serde_json::json!({
            "command": "DROP TABLE IF EXISTS t; CREATE TABLE t (id INT);"
        });
        let config: PostgresConfig = serde_json::from_value(json).unwrap();
        assert_eq!(
            config.query,
            "DROP TABLE IF EXISTS t; CREATE TABLE t (id INT);"
        );
    }

    #[test]
    fn test_split_sql_statements() {
        assert_eq!(split_sql_statements("SELECT 1"), vec!["SELECT 1"]);
        assert_eq!(split_sql_statements("SELECT 1;"), vec!["SELECT 1"]);
        assert_eq!(
            split_sql_statements(
                "DROP TABLE IF EXISTS t; CREATE TABLE t (id INT); INSERT INTO t VALUES (1);",
            ),
            vec![
                "DROP TABLE IF EXISTS t",
                "CREATE TABLE t (id INT)",
                "INSERT INTO t VALUES (1)"
            ]
        );
        assert_eq!(
            split_sql_statements("INSERT INTO t VALUES ('a;b'); SELECT 1"),
            vec!["INSERT INTO t VALUES ('a;b')", "SELECT 1"]
        );
        let s = split_sql_statements(
            "CREATE FUNCTION f() RETURNS void AS $$ BEGIN PERFORM 1; PERFORM 2; END; $$ LANGUAGE plpgsql; SELECT f();",
        );
        assert_eq!(s.len(), 2);
        assert!(s[0].contains("$$ BEGIN PERFORM 1; PERFORM 2; END; $$"));
        assert_eq!(s[1], "SELECT f()");
        assert_eq!(
            split_sql_statements("DO $do$ BEGIN; END $do$; SELECT 1"),
            vec!["DO $do$ BEGIN; END $do$", "SELECT 1"]
        );
        // `$1` is a positional parameter, not a dollar-quote opener.
        assert_eq!(
            split_sql_statements("UPDATE t SET a = $1 WHERE id = 2; SELECT 1"),
            vec!["UPDATE t SET a = $1 WHERE id = 2", "SELECT 1"]
        );
        // The apostrophe inside the line comment must not open a string literal.
        let s = split_sql_statements(
            "INSERT INTO t VALUES (1);\n-- reset this facility's rows\nDELETE FROM t;\nSELECT count(*) FROM t;",
        );
        assert_eq!(s.len(), 3, "{s:?}");
        assert_eq!(s[0], "INSERT INTO t VALUES (1)");
        assert_eq!(s[1], "-- reset this facility's rows\nDELETE FROM t");
        assert_eq!(s[2], "SELECT count(*) FROM t");
    }

    #[test]
    fn test_postgres_config_with_components() {
        let json = serde_json::json!({
            "query": "SELECT 1",
            "host": "db.example.com",
            "port": 5433,
            "database": "mydb",
            "user": "admin",
            "schema": "public"
        });
        let config: PostgresConfig = serde_json::from_value(json).unwrap();
        assert_eq!(config.host, Some("db.example.com".to_string()));
        assert_eq!(config.port, Some(5433));
        assert_eq!(config.database, Some("mydb".to_string()));
        assert_eq!(config.user, Some("admin".to_string()));
        assert_eq!(config.schema, Some("public".to_string()));
    }

    #[test]
    fn test_postgres_config_defaults() {
        let json = serde_json::json!({
            "query": "SELECT 1"
        });
        let config: PostgresConfig = serde_json::from_value(json).unwrap();
        assert!(config.params.is_empty());
        assert!(config.connection_string.is_none());
        assert!(config.schema.is_none());
        assert!(config.as_objects);
    }

    #[test]
    fn test_build_connection_string() {
        let tool = PostgresTool::new();
        let ctx = ExecutionContext::default();
        let config = PostgresConfig {
            query: "SELECT 1".to_string(),
            params: vec![],
            connection_string: None,
            host: Some("localhost".to_string()),
            port: Some(5432),
            database: Some("testdb".to_string()),
            user: Some("testuser".to_string()),
            password: Some("testpass".to_string()),
            schema: None,
            as_objects: true,
        };
        let conn_str = tool.build_connection_string(&config, &ctx).unwrap();
        // The password segment depends on secret resolution, so only the fixed parts are pinned.
        assert!(conn_str.starts_with("postgresql://testuser"), "{conn_str}");
        assert!(conn_str.ends_with("@localhost:5432/testdb"), "{conn_str}");
    }

    #[test]
    fn test_build_connection_string_passthrough() {
        let tool = PostgresTool::new();
        let ctx = ExecutionContext::default();
        let config = PostgresConfig {
            query: "SELECT 1".to_string(),
            params: vec![],
            connection_string: Some("postgresql://u:p@h:1/d".to_string()),
            host: Some("ignored".to_string()),
            port: None,
            database: None,
            user: None,
            password: None,
            schema: None,
            as_objects: true,
        };
        assert_eq!(
            tool.build_connection_string(&config, &ctx).unwrap(),
            "postgresql://u:p@h:1/d"
        );
    }

    #[tokio::test]
    async fn test_postgres_tool_interface() {
        let tool = PostgresTool::new();
        assert_eq!(tool.name(), "postgres");
    }
}