#[cfg(feature = "deadpool_experimental")]
use deadpool_postgres::Pool;
#[cfg(feature = "deadpool_experimental")]
use serde_json::{Map, Value};
#[cfg(feature = "deadpool_experimental")]
use std::time::Duration;
#[cfg(feature = "deadpool_experimental")]
use tokio::time::timeout;
#[cfg(feature = "deadpool_experimental")]
use tokio_postgres::types::ToSql;
#[cfg(feature = "deadpool_experimental")]
use crate::parser::query_builder::{
Condition, build_insert_placeholders_for_entries, build_where_clause, sanitize_identifier,
sanitize_qualified_table_identifier,
};
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_raw_sql::{DeadpoolError, DeadpoolFallbackReason};
#[cfg(feature = "deadpool_experimental")]
/// Convert a JSON value into a boxed postgres parameter for SET/INSERT
/// positions. Strings that parse as UUIDs are bound as `uuid` so they can
/// target uuid-typed columns; numbers prefer i64, then f64, then fall back
/// to their textual form; arrays/objects are bound as json values.
fn to_sql_param_set(value: &Value) -> Box<dyn ToSql + Sync> {
    // Strings get special treatment (UUID coercion), so handle them first.
    if let Value::String(text) = value {
        return match uuid::Uuid::parse_str(text) {
            Ok(parsed) => Box::new(parsed),
            Err(_) => Box::new(text.clone()),
        };
    }
    match value {
        Value::Null => Box::new(None::<String>),
        Value::Bool(flag) => Box::new(*flag),
        Value::Number(num) => match (num.as_i64(), num.as_f64()) {
            (Some(int), _) => Box::new(int),
            (None, Some(float)) => Box::new(float),
            // Out-of-range numerics (e.g. huge u64) are bound as text.
            (None, None) => Box::new(num.to_string()),
        },
        // Strings were handled above; only Array/Object reach this arm.
        compound => Box::new(compound.clone()),
    }
}
#[cfg(feature = "deadpool_experimental")]
/// Convert a JSON value into a boxed postgres parameter for WHERE
/// comparisons. Unlike the SET variant, strings are always bound as text
/// (no UUID coercion). Numbers prefer i64, then f64, then their textual
/// form; arrays/objects are bound as json values.
fn to_sql_param_where(value: &Value) -> Box<dyn ToSql + Sync> {
    match value {
        Value::Null => Box::new(None::<String>),
        Value::Bool(flag) => Box::new(*flag),
        Value::String(text) => Box::new(text.clone()),
        Value::Number(num) => {
            if let Some(int) = num.as_i64() {
                return Box::new(int);
            }
            match num.as_f64() {
                Some(float) => Box::new(float),
                // Out-of-range numerics are bound as text.
                None => Box::new(num.to_string()),
            }
        }
        compound => Box::new(compound.clone()),
    }
}
#[cfg(feature = "deadpool_experimental")]
/// Render an optional `(column, ascending)` pair as a ` ORDER BY …` SQL
/// fragment (note the leading space). Returns an empty string when no sort
/// was requested or the column name fails sanitizing, so callers can splice
/// the result directly into a SQL string.
fn order_by_clause(sort: Option<(&str, bool)>) -> String {
    match sort {
        None => String::new(),
        Some((column, ascending)) => match sanitize_identifier(column) {
            // Unsanitizable column: silently drop the sort rather than error.
            None => String::new(),
            Some(safe) => {
                let direction = if ascending { "ASC" } else { "DESC" };
                format!(" ORDER BY {} {}", safe, direction)
            }
        },
    }
}
#[cfg(feature = "deadpool_experimental")]
/// Fetch up to `limit` rows (after `offset`) from `table_name`, each
/// serialized to JSON by Postgres itself (`row_to_json` for `SELECT *`,
/// `jsonb_build_object` for an explicit column list).
///
/// Columns that fail identifier sanitizing are dropped; the query is
/// rejected only when none survive. An empty `columns` slice or a `"*"`
/// entry selects all columns.
///
/// # Errors
/// `DeadpoolError` on checkout timeout/failure, invalid table name, no
/// valid columns, an invalid WHERE clause, query failure, or row decode
/// failure.
pub async fn fetch_rows_with_columns_deadpool(
    pool: &Pool,
    table_name: &str,
    columns: &[&str],
    conditions: &[Condition],
    limit: i64,
    offset: i64,
    order_by: Option<(&str, bool)>,
    checkout_timeout: Duration,
) -> Result<Vec<Value>, DeadpoolError> {
    // Bound the wait for a pooled connection so a saturated pool surfaces
    // as CheckoutTimeout instead of hanging the caller.
    let client = timeout(checkout_timeout, pool.get())
        .await
        .map_err(|_| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutTimeout,
            is_db_error: false,
            message: "deadpool checkout timeout".to_string(),
            sql_state: None,
        })?
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;
    let table: String = sanitize_identifier(table_name).ok_or_else(|| DeadpoolError {
        reason: DeadpoolFallbackReason::ExecuteFailed,
        is_db_error: false,
        message: "invalid table name".to_string(),
        sql_state: None,
    })?;
    // WHERE placeholders start at $1 — there are no SET parameters here.
    let (where_clause, where_values) =
        build_where_clause(conditions, 1).map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;
    let use_all_columns = columns.is_empty() || columns.contains(&"*");
    let order_clause = order_by_clause(order_by);
    let sql = if use_all_columns {
        format!(
            "SELECT row_to_json(t.*) AS data FROM {table} AS t{where_clause}{order_clause} LIMIT {limit} OFFSET {offset}",
            table = table,
            where_clause = where_clause,
            order_clause = order_clause,
            limit = limit,
            offset = offset
        )
    } else {
        // Build each `'key', t.col` pair in a single pass so a column that
        // fails sanitizing is dropped together with its key. (A separate
        // filter + zip would misalign keys against column expressions as
        // soon as one column is rejected.) The key is a SQL string literal,
        // so embedded single quotes must be doubled to avoid breaking out
        // of the literal.
        let pairs: Vec<String> = columns
            .iter()
            .filter_map(|requested| {
                sanitize_identifier(requested).map(|sanitized| {
                    format!("'{}', t.{}", requested.replace('\'', "''"), sanitized)
                })
            })
            .collect();
        if pairs.is_empty() {
            return Err(DeadpoolError {
                reason: DeadpoolFallbackReason::ExecuteFailed,
                is_db_error: false,
                message: "no valid columns".to_string(),
                sql_state: None,
            });
        }
        format!(
            "SELECT jsonb_build_object({columns}) AS data FROM {table} AS t{where_clause}{order_clause} LIMIT {limit} OFFSET {offset}",
            columns = pairs.join(", "),
            table = table,
            where_clause = where_clause,
            order_clause = order_clause,
            limit = limit,
            offset = offset
        )
    };
    let params: Vec<Box<dyn ToSql + Sync>> = where_values.iter().map(to_sql_param_where).collect();
    let param_refs: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p.as_ref()).collect();
    let rows = client
        .query(sql.as_str(), &param_refs)
        .await
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: err.as_db_error().is_some(),
            message: err.to_string(),
            sql_state: err.as_db_error().map(|db| db.code().code().to_string()),
        })?;
    let mut result = Vec::with_capacity(rows.len());
    for row in rows {
        // Every projection aliases its output as `data`; decode it per row.
        let data: Value = row.try_get("data").map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::DecodeFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;
        result.push(data);
    }
    Ok(result)
}
#[cfg(feature = "deadpool_experimental")]
/// Update all rows in `table_name` matched by `conditions` with the columns
/// in `payload`, returning each updated row serialized to JSON.
///
/// `payload` must be a JSON object; keys that fail identifier sanitizing
/// are silently skipped, and the update is rejected only when none survive.
/// A non-empty WHERE clause is mandatory — a blanket UPDATE is refused.
///
/// # Errors
/// `DeadpoolError` on checkout timeout/failure, invalid table name, a
/// non-object payload, no valid columns, a missing/invalid WHERE clause,
/// query failure, or row decode failure.
pub async fn update_rows_deadpool(
    pool: &Pool,
    table_name: &str,
    conditions: &[Condition],
    payload: &Value,
    checkout_timeout: Duration,
) -> Result<Vec<Value>, DeadpoolError> {
    // Bound the wait for a pooled connection so a saturated pool surfaces
    // as CheckoutTimeout instead of hanging the caller.
    let client = timeout(checkout_timeout, pool.get())
        .await
        .map_err(|_| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutTimeout,
            is_db_error: false,
            message: "deadpool checkout timeout".to_string(),
            sql_state: None,
        })?
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;
    let table: String = sanitize_identifier(table_name).ok_or_else(|| DeadpoolError {
        reason: DeadpoolFallbackReason::ExecuteFailed,
        is_db_error: false,
        message: "invalid table name".to_string(),
        sql_state: None,
    })?;
    // Keep only (sanitized column, value) pairs; unsanitizable keys are dropped.
    let entries: Vec<(String, Value)> = payload
        .as_object()
        .ok_or(DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: false,
            message: "payload must be object".to_string(),
            sql_state: None,
        })?
        .iter()
        .filter_map(|(column, value)| sanitize_identifier(column).map(|s| (s, value.clone())))
        .collect();
    if entries.is_empty() {
        return Err(DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: false,
            message: "no valid columns".to_string(),
            sql_state: None,
        });
    }
    // SET placeholders occupy $1..$n; WHERE placeholders continue at $n+1,
    // matching the parameter vector order built below (SET first, WHERE after).
    let set_parts: Vec<String> = entries
        .iter()
        .enumerate()
        .map(|(idx, (column, _))| format!("{} = ${}", column, idx + 1))
        .collect();
    let (where_clause, where_values) =
        build_where_clause(conditions, entries.len() + 1).map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;
    // Refuse a blanket UPDATE: an empty WHERE clause would touch every row.
    if where_clause.is_empty() {
        return Err(DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: false,
            message: "missing where clause".to_string(),
            sql_state: None,
        });
    }
    let sql = format!(
        "UPDATE {table} AS t SET {set_clause}{where_clause} RETURNING to_jsonb(t.*) AS data",
        table = table,
        set_clause = set_parts.join(", "),
        where_clause = where_clause
    );
    // SET params first, then WHERE params — must match placeholder numbering.
    let mut params: Vec<Box<dyn ToSql + Sync>> = Vec::new();
    for (_, value) in &entries {
        params.push(to_sql_param_set(value));
    }
    for value in &where_values {
        params.push(to_sql_param_where(value));
    }
    let param_refs: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p.as_ref()).collect();
    let rows = client
        .query(sql.as_str(), &param_refs)
        .await
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: err.as_db_error().is_some(),
            message: err.to_string(),
            sql_state: err.as_db_error().map(|db| db.code().code().to_string()),
        })?;
    let mut result = Vec::with_capacity(rows.len());
    for row in rows {
        let data: Value = row.try_get("data").map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::DecodeFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;
        result.push(data);
    }
    Ok(result)
}
#[cfg(feature = "deadpool_experimental")]
/// Insert a single row built from `payload` into `table_name` and return
/// the inserted row serialized to JSON (`to_jsonb` of the full row).
///
/// The table name may be schema-qualified (it goes through
/// `sanitize_qualified_table_identifier`, unlike the fetch/update helpers).
/// Payload keys that fail identifier sanitizing are silently skipped; the
/// insert is rejected only when none survive.
///
/// # Errors
/// `DeadpoolError` on checkout timeout/failure, invalid table name, a
/// non-object payload, no valid columns, query failure, or result decode
/// failure.
pub async fn insert_row_deadpool(
    pool: &Pool,
    table_name: &str,
    payload: &Value,
    checkout_timeout: Duration,
) -> Result<Value, DeadpoolError> {
    // Bound the wait for a pooled connection so a saturated pool surfaces
    // as CheckoutTimeout instead of hanging the caller.
    let client = timeout(checkout_timeout, pool.get())
        .await
        .map_err(|_| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutTimeout,
            is_db_error: false,
            message: "deadpool checkout timeout".to_string(),
            sql_state: None,
        })?
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::CheckoutFailed,
            is_db_error: false,
            message: err.to_string(),
            sql_state: None,
        })?;
    let table: String =
        sanitize_qualified_table_identifier(table_name).ok_or_else(|| DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: false,
            message: "invalid table name".to_string(),
            sql_state: None,
        })?;
    let object: &Map<String, Value> = payload.as_object().ok_or(DeadpoolError {
        reason: DeadpoolFallbackReason::ExecuteFailed,
        is_db_error: false,
        message: "payload must be object".to_string(),
        sql_state: None,
    })?;
    // Keep only (sanitized column, value) pairs; unsanitizable keys are dropped.
    let entries: Vec<(String, Value)> = object
        .iter()
        .filter_map(|(column, value)| sanitize_identifier(column).map(|s| (s, value.clone())))
        .collect();
    if entries.is_empty() {
        return Err(DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: false,
            message: "no valid columns".to_string(),
            sql_state: None,
        });
    }
    let columns: Vec<&str> = entries.iter().map(|(c, _)| c.as_str()).collect();
    // Placeholders and bind values come pre-aligned from the shared builder,
    // so the parameter vector below follows the same order as `columns`.
    let (placeholders, bind_values) = build_insert_placeholders_for_entries(&entries);
    let sql = format!(
        "INSERT INTO {table} AS t ({columns}) VALUES ({placeholders}) RETURNING to_jsonb(t.*) AS data",
        table = table,
        columns = columns.join(", "),
        placeholders = placeholders.join(", ")
    );
    let params: Vec<Box<dyn ToSql + Sync>> = bind_values
        .iter()
        .map(|value| to_sql_param_set(*value))
        .collect();
    let param_refs: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p.as_ref()).collect();
    let row = client
        .query_one(sql.as_str(), &param_refs)
        .await
        .map_err(|err| DeadpoolError {
            reason: DeadpoolFallbackReason::ExecuteFailed,
            is_db_error: err.as_db_error().is_some(),
            message: err.to_string(),
            sql_state: err.as_db_error().map(|db| db.code().code().to_string()),
        })?;
    row.try_get("data").map_err(|err| DeadpoolError {
        reason: DeadpoolFallbackReason::DecodeFailed,
        is_db_error: false,
        message: err.to_string(),
        sql_state: None,
    })
}