#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::raw_sql::{
PostgresSqlExecutionMode, PostgresSqlExecutionResult, PostgresSqlExecutionSummary,
classify_sql_query, normalize_sql_query,
};
#[cfg(feature = "deadpool_experimental")]
use deadpool_postgres::Pool;
#[cfg(feature = "deadpool_experimental")]
use serde_json::{Value, json};
#[cfg(feature = "deadpool_experimental")]
use tokio::time::{Duration, timeout};
#[cfg(feature = "deadpool_experimental")]
use tokio_postgres::SimpleQueryMessage;
#[cfg(feature = "deadpool_experimental")]
/// Why a deadpool-backed execution attempt did not produce a result.
///
/// NOTE(review): the name suggests callers fall back to another execution
/// path when they see one of these — confirm against the call sites.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeadpoolFallbackReason {
    /// Acquiring a pooled connection did not finish within the caller's
    /// checkout timeout.
    CheckoutTimeout,
    /// The pool returned an error while handing out a connection.
    CheckoutFailed,
    /// The SQL statement itself failed after a connection was obtained.
    ExecuteFailed,
    /// Decoding a result failed. Not produced by the execution path in this
    /// module as written; presumably reserved for row-decoding errors — TODO
    /// confirm.
    DecodeFailed,
}
#[cfg(feature = "deadpool_experimental")]
pub fn deadpool_fallback_reason_label(reason: DeadpoolFallbackReason) -> &'static str {
match reason {
DeadpoolFallbackReason::CheckoutTimeout => "checkout_timeout",
DeadpoolFallbackReason::CheckoutFailed => "checkout_failed",
DeadpoolFallbackReason::ExecuteFailed => "execute_failed",
DeadpoolFallbackReason::DecodeFailed => "decode_failed",
}
}
#[cfg(feature = "deadpool_experimental")]
/// Error returned by [`execute_postgres_sql_deadpool`].
///
/// Records which stage failed ([`DeadpoolFallbackReason`]) and, for database
/// errors, the SQLSTATE code reported by the server.
#[derive(Debug)]
pub struct DeadpoolError {
    /// Stage at which execution failed.
    pub reason: DeadpoolFallbackReason,
    /// `true` when the underlying `tokio_postgres::Error` carried a server-side
    /// database error (as opposed to a connection/protocol/client error).
    pub is_db_error: bool,
    /// Human-readable description of the failure.
    pub message: String,
    /// Five-character SQLSTATE code, present only when `is_db_error` is `true`.
    pub sql_state: Option<String>,
}

// The struct is feature-gated by the attribute preceding it, so the impls
// must carry the same gate or the crate fails to build without the feature.
#[cfg(feature = "deadpool_experimental")]
impl std::fmt::Display for DeadpoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "deadpool {}: {}",
            deadpool_fallback_reason_label(self.reason),
            self.message
        )?;
        if let Some(state) = &self.sql_state {
            write!(f, " (SQLSTATE {state})")?;
        }
        Ok(())
    }
}

// Lets `DeadpoolError` flow through `?` into `Box<dyn Error>` / `anyhow`.
#[cfg(feature = "deadpool_experimental")]
impl std::error::Error for DeadpoolError {}
#[cfg(feature = "deadpool_experimental")]
pub async fn execute_postgres_sql_deadpool(
pool: &Pool,
query: &str,
checkout_timeout: Duration,
) -> Result<PostgresSqlExecutionResult, DeadpoolError> {
let normalized_query = normalize_sql_query(query);
let mode = classify_sql_query(&normalized_query);
let client = timeout(checkout_timeout, pool.get())
.await
.map_err(|_| DeadpoolError {
reason: DeadpoolFallbackReason::CheckoutTimeout,
is_db_error: false,
message: "deadpool checkout timeout".to_string(),
sql_state: None,
})?
.map_err(|err| DeadpoolError {
reason: DeadpoolFallbackReason::CheckoutFailed,
is_db_error: false,
message: err.to_string(),
sql_state: None,
})?;
match mode {
PostgresSqlExecutionMode::JsonRows => execute_json_row_query(&client, &normalized_query)
.await
.map_err(|err| DeadpoolError {
reason: DeadpoolFallbackReason::ExecuteFailed,
is_db_error: err.as_db_error().is_some(),
message: err.to_string(),
sql_state: err.as_db_error().map(|db| db.code().code().to_string()),
}),
PostgresSqlExecutionMode::DirectRows => {
execute_direct_row_query(&client, &normalized_query)
.await
.map_err(|err| DeadpoolError {
reason: DeadpoolFallbackReason::ExecuteFailed,
is_db_error: err.as_db_error().is_some(),
message: err.to_string(),
sql_state: err.as_db_error().map(|db| db.code().code().to_string()),
})
}
PostgresSqlExecutionMode::Command => execute_command_query(&client, &normalized_query)
.await
.map_err(|err| DeadpoolError {
reason: DeadpoolFallbackReason::ExecuteFailed,
is_db_error: err.as_db_error().is_some(),
message: err.to_string(),
sql_state: err.as_db_error().map(|db| db.code().code().to_string()),
}),
}
}
#[cfg(feature = "deadpool_experimental")]
async fn execute_json_row_query(
client: &tokio_postgres::Client,
query: &str,
) -> Result<PostgresSqlExecutionResult, tokio_postgres::Error> {
let wrapped_query = format!(
"WITH athena_query_result AS ({query}) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
);
let rows = client.query(wrapped_query.as_str(), &[]).await?;
let data = rows
.into_iter()
.map(|row| row.get::<_, Value>("row"))
.collect::<Vec<_>>();
Ok(PostgresSqlExecutionResult {
summary: PostgresSqlExecutionSummary {
statement_count: 1,
rows_affected: 0,
returned_row_count: data.len(),
},
rows: data,
})
}
#[cfg(feature = "deadpool_experimental")]
async fn execute_direct_row_query(
client: &tokio_postgres::Client,
query: &str,
) -> Result<PostgresSqlExecutionResult, tokio_postgres::Error> {
let rows = client.query(query, &[]).await?;
let data = rows.into_iter().map(row_to_json).collect::<Vec<_>>();
Ok(PostgresSqlExecutionResult {
summary: PostgresSqlExecutionSummary {
statement_count: 1,
rows_affected: 0,
returned_row_count: data.len(),
},
rows: data,
})
}
#[cfg(feature = "deadpool_experimental")]
/// Runs `query` over the simple-query protocol and tallies command results.
///
/// Every `SimpleQueryMessage::CommandComplete(n)` counts as one statement and
/// adds `n` to `rows_affected`. All other messages, including any data rows,
/// are ignored, so the result never carries rows.
///
/// # Errors
/// Propagates any `tokio_postgres::Error` from `Client::simple_query`.
async fn execute_command_query(
    client: &tokio_postgres::Client,
    query: &str,
) -> Result<PostgresSqlExecutionResult, tokio_postgres::Error> {
    let (statement_count, rows_affected) = client
        .simple_query(query)
        .await?
        .into_iter()
        .filter_map(|message| match message {
            SimpleQueryMessage::CommandComplete(affected) => Some(affected),
            _ => None,
        })
        .fold((0usize, 0u64), |(statements, total), affected| {
            (statements + 1, total + affected)
        });

    Ok(PostgresSqlExecutionResult {
        rows: Vec::new(),
        summary: PostgresSqlExecutionSummary {
            statement_count,
            rows_affected,
            returned_row_count: 0,
        },
    })
}
#[cfg(feature = "deadpool_experimental")]
/// Converts a result row into a JSON object keyed by column name.
///
/// Values come from `read_column_value`. If two columns share a name, the
/// later column's value wins.
fn row_to_json(row: tokio_postgres::Row) -> Value {
    let fields: serde_json::Map<String, Value> = row
        .columns()
        .iter()
        .enumerate()
        .map(|(position, column)| (column.name().to_string(), read_column_value(&row, position)))
        .collect();
    Value::Object(fields)
}
#[cfg(feature = "deadpool_experimental")]
fn read_column_value(row: &tokio_postgres::Row, idx: usize) -> Value {
if let Ok(value) = row.try_get::<usize, Option<Value>>(idx) {
return value.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<String>>(idx) {
return value.map(Value::String).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<bool>>(idx) {
return value.map(Value::Bool).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<i16>>(idx) {
return value.map(|v| json!(v)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<i32>>(idx) {
return value.map(|v| json!(v)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<i64>>(idx) {
return value.map(|v| json!(v)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<f32>>(idx) {
return value.map(|v| json!(v)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<f64>>(idx) {
return value.map(|v| json!(v)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<uuid::Uuid>>(idx) {
return value
.map(|v| Value::String(v.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<chrono::NaiveDate>>(idx) {
return value
.map(|v| Value::String(v.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<chrono::NaiveTime>>(idx) {
return value
.map(|v| Value::String(v.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<chrono::NaiveDateTime>>(idx) {
return value
.map(|v| Value::String(v.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<chrono::DateTime<chrono::Utc>>>(idx) {
return value
.map(|v| Value::String(v.to_rfc3339()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<usize, Option<Vec<u8>>>(idx) {
return value
.map(|v| Value::String(String::from_utf8_lossy(&v).to_string()))
.unwrap_or(Value::Null);
}
Value::String("<unsupported>".to_string())
}