use super::config::{DangerousDmlPolicy, SelectWithoutLimitPolicy, handle_dangerous_dml};
use super::statement_cache::{StmtCacheProbe, is_retryable_prepared_error};
use crate::GenericClient;
use crate::error::{OrmError, OrmResult};
use crate::monitor::{HookAction, QueryContext, QueryMonitor, QueryResult, QueryType};
use crate::row::FromRow;
use std::time::{Duration, Instant};
use tokio_postgres::Row;
use tokio_postgres::types::ToSql;
impl<C: GenericClient> super::PgClient<C> {
#[cfg(not(feature = "tracing"))]
pub(super) fn emit_tracing_sql(&self, _ctx: &QueryContext) {}
#[cfg(feature = "tracing")]
pub(super) fn emit_tracing_sql(&self, ctx: &QueryContext) {
if let Some(hook) = &self.tracing_sql_hook {
let _ = hook.before_query(ctx);
}
}
pub(super) fn apply_sql_policy(&self, ctx: &mut QueryContext) -> OrmResult<()> {
use crate::check::StatementKind;
let policy = &self.config.sql_policy;
if policy.select_without_limit == SelectWithoutLimitPolicy::Allow
&& policy.delete_without_where == DangerousDmlPolicy::Allow
&& policy.update_without_where == DangerousDmlPolicy::Allow
&& policy.truncate == DangerousDmlPolicy::Allow
&& policy.drop_table == DangerousDmlPolicy::Allow
{
return Ok(());
}
let analysis = self.registry.analyze_sql(&ctx.canonical_sql);
if !analysis.parse_result.valid {
return Ok(());
}
match analysis.statement_kind {
Some(StatementKind::Select) => {
if analysis.select_has_limit == Some(false) {
match policy.select_without_limit {
SelectWithoutLimitPolicy::Allow => {}
SelectWithoutLimitPolicy::Warn => {
crate::error::pgorm_warn(&format!(
"[pgorm warn] SQL policy: SELECT without LIMIT/OFFSET: {}",
ctx.canonical_sql
));
}
SelectWithoutLimitPolicy::Error => {
return Err(OrmError::validation(format!(
"SQL policy violation: SELECT without LIMIT/OFFSET: {}",
ctx.canonical_sql
)));
}
SelectWithoutLimitPolicy::AutoLimit(limit) => {
let old_canonical = ctx.canonical_sql.clone();
match pgorm_check::ensure_select_limit(&old_canonical, limit) {
Ok(Some(new_sql)) => {
ctx.canonical_sql = new_sql.clone();
ctx.query_type = QueryType::from_sql(&ctx.canonical_sql);
if ctx.exec_sql == old_canonical {
ctx.exec_sql = new_sql;
} else if let Some(pos) = ctx.exec_sql.rfind(&old_canonical) {
let mut rewritten = String::with_capacity(
ctx.exec_sql.len() - old_canonical.len()
+ ctx.canonical_sql.len(),
);
rewritten.push_str(&ctx.exec_sql[..pos]);
rewritten.push_str(&ctx.canonical_sql);
rewritten
.push_str(&ctx.exec_sql[pos + old_canonical.len()..]);
ctx.exec_sql = rewritten;
} else {
ctx.exec_sql = ctx.canonical_sql.clone();
}
}
Ok(None) => {
return Err(OrmError::validation(format!(
"SQL policy rewrite failed: unable to add LIMIT to: {}",
ctx.canonical_sql
)));
}
Err(e) => return Err(OrmError::validation(e.to_string())),
}
}
}
}
}
Some(StatementKind::Delete) => {
if analysis.delete_has_where == Some(false) {
handle_dangerous_dml(
policy.delete_without_where,
"DELETE without WHERE",
&ctx.canonical_sql,
)?;
}
}
Some(StatementKind::Update) => {
if analysis.update_has_where == Some(false) {
handle_dangerous_dml(
policy.update_without_where,
"UPDATE without WHERE",
&ctx.canonical_sql,
)?;
}
}
Some(StatementKind::Truncate) => {
handle_dangerous_dml(policy.truncate, "TRUNCATE", &ctx.canonical_sql)?;
}
Some(StatementKind::DropTable) => {
handle_dangerous_dml(policy.drop_table, "DROP TABLE", &ctx.canonical_sql)?;
}
_ => {}
}
Ok(())
}
pub(super) fn check_sql(&self, sql: &str) -> OrmResult<()> {
let issues = self.registry.check_sql(sql);
crate::checked_client::handle_check_issues(self.config.check_mode, issues, "SQL check")
}
pub(super) fn apply_hook(&self, ctx: &mut QueryContext) -> Result<(), OrmError> {
if let Some(hook) = &self.hook {
match hook.before_query(ctx) {
HookAction::Continue => Ok(()),
HookAction::ModifySql {
exec_sql,
canonical_sql,
} => {
ctx.exec_sql = exec_sql;
if let Some(canonical_sql) = canonical_sql {
ctx.canonical_sql = canonical_sql;
}
ctx.query_type = QueryType::from_sql(&ctx.canonical_sql);
Ok(())
}
HookAction::Abort(reason) => Err(OrmError::validation(format!(
"Query aborted by hook: {reason}"
))),
}
} else {
Ok(())
}
}
pub(super) fn report_result(
&self,
ctx: &QueryContext,
duration: Duration,
result: &QueryResult,
) {
if self.config.stats_enabled {
self.stats.on_query_complete(ctx, duration, result);
}
if let Some(ref logging) = self.logging_monitor {
logging.on_query_complete(ctx, duration, result);
}
if let Some(ref monitor) = self.custom_monitor {
monitor.on_query_complete(ctx, duration, result);
}
if let Some(threshold) = self.config.slow_query_threshold {
if duration > threshold {
if let Some(ref logging) = self.logging_monitor {
logging.on_slow_query(ctx, duration);
}
if let Some(ref monitor) = self.custom_monitor {
monitor.on_slow_query(ctx, duration);
}
}
}
if let Some(ref hook) = self.hook {
hook.after_query(ctx, duration, result);
}
}
pub(super) async fn execute_with_timeout<T, F>(&self, future: F) -> OrmResult<T>
where
F: std::future::Future<Output = OrmResult<T>> + Send,
{
match self.config.query_timeout {
Some(timeout) => {
tokio::pin!(future);
tokio::select! {
result = &mut future => result,
_ = tokio::time::sleep(timeout) => {
if let Some(cancel_token) = self.client.cancel_token() {
tokio::spawn(async move {
let _ = cancel_token.cancel_query(tokio_postgres::NoTls).await;
});
}
Err(OrmError::Timeout(timeout))
}
}
}
None => future.await,
}
}
pub(super) fn probe_stmt_cache(&self, ctx: &QueryContext) -> StmtCacheProbe {
if !self.config.statement_cache.enabled {
return StmtCacheProbe::Disabled;
}
let Some(cache) = &self.statement_cache else {
return StmtCacheProbe::Disabled;
};
if !self.client.supports_prepared_statements() {
return StmtCacheProbe::Disabled;
}
if ctx.exec_sql != ctx.canonical_sql {
return StmtCacheProbe::Disabled;
}
match cache.get(&ctx.canonical_sql) {
Some(stmt) => StmtCacheProbe::Hit(stmt),
None => StmtCacheProbe::Miss,
}
}
}
impl<C: GenericClient> super::PgClient<C> {
    /// Runs `sql` with `params` and decodes every returned row into `T`.
    ///
    /// # Errors
    /// Propagates the query error, or the first row-decoding error.
    pub async fn sql_query_as<T: FromRow>(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Vec<T>> {
        self.query(sql, params)
            .await?
            .iter()
            .map(T::from_row)
            .collect()
    }

    /// Runs `sql` via `query_one` and decodes the resulting row into `T`.
    ///
    /// # Errors
    /// Propagates errors from `query_one` (see it for the row-count rules) and from decoding.
    pub async fn sql_query_one_as<T: FromRow>(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<T> {
        T::from_row(&self.query_one(sql, params).await?)
    }

    /// Runs `sql` via `query_one_strict` and decodes the resulting row into `T`.
    ///
    /// # Errors
    /// Propagates errors from `query_one_strict` (see it for the row-count rules) and from
    /// decoding.
    pub async fn sql_query_one_strict_as<T: FromRow>(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<T> {
        T::from_row(&self.query_one_strict(sql, params).await?)
    }

    /// Runs `sql` via `query_opt`, decoding the row into `T` when one is returned.
    ///
    /// # Errors
    /// Propagates query and decoding errors; "no row" is `Ok(None)`.
    pub async fn sql_query_opt_as<T: FromRow>(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Option<T>> {
        match self.query_opt(sql, params).await? {
            Some(row) => T::from_row(&row).map(Some),
            None => Ok(None),
        }
    }

    /// Executes `sql` and returns the number of affected rows.
    pub async fn sql_execute(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
        self.execute(sql, params).await
    }

    /// Runs `sql` and returns all raw rows.
    pub async fn sql_query(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Vec<Row>> {
        self.query(sql, params).await
    }

    /// Runs `sql` and returns a single raw row (delegates to `query_one`).
    pub async fn sql_query_one(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
        self.query_one(sql, params).await
    }

    /// Runs `sql` and returns a single raw row (delegates to `query_one_strict`).
    pub async fn sql_query_one_strict(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Row> {
        self.query_one_strict(sql, params).await
    }

    /// Runs `sql` and returns the raw row, if any (delegates to `query_opt`).
    pub async fn sql_query_opt(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Option<Row>> {
        self.query_opt(sql, params).await
    }
}
impl<C: GenericClient> GenericClient for super::PgClient<C> {
    /// Untagged multi-row query through the full hook/policy/cache/monitor pipeline.
    async fn query(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Vec<Row>> {
        Self::query_impl(self, None, sql, params).await
    }

    /// Multi-row query whose context carries `tag`.
    async fn query_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Vec<Row>> {
        Self::query_impl(self, Some(tag), sql, params).await
    }

    /// Untagged single-row query.
    async fn query_one(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
        Self::query_one_impl(self, None, sql, params).await
    }

    /// Single-row query whose context carries `tag`.
    async fn query_one_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Row> {
        Self::query_one_impl(self, Some(tag), sql, params).await
    }

    /// Untagged optional-row query.
    async fn query_opt(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Option<Row>> {
        Self::query_opt_impl(self, None, sql, params).await
    }

    /// Optional-row query whose context carries `tag`.
    async fn query_opt_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Option<Row>> {
        Self::query_opt_impl(self, Some(tag), sql, params).await
    }

    /// Untagged statement execution returning the affected-row count.
    async fn execute(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
        Self::execute_impl(self, None, sql, params).await
    }

    /// Statement execution whose context carries `tag`.
    async fn execute_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<u64> {
        Self::execute_impl(self, Some(tag), sql, params).await
    }

    /// Exposes the wrapped client's cancel token unchanged.
    fn cancel_token(&self) -> Option<tokio_postgres::CancelToken> {
        self.client.cancel_token()
    }
}
/// Runs one query for the `*_impl` methods, routed through the prepared-statement cache
/// according to the `StmtCacheProbe` produced by `probe_stmt_cache`.
///
/// Arguments:
/// - `$self`: the client;
/// - `$ctx`: the `QueryContext`;
/// - `$params`: the bind parameters;
/// - `$probe`: the probe;
/// - `$unprepared` / `$prepared`: the client methods to call (e.g. `query` /
///   `query_prepared`).
///
/// Behavior per probe:
/// - `Disabled`: runs `$ctx.exec_sql` unprepared.
/// - `Hit`: runs the cached statement. If that fails with an error accepted by
///   `is_retryable_prepared_error`, the entry is evicted, the canonical SQL is re-prepared
///   and cached, and the query is retried once.
/// - `Miss`: prepares `$ctx.canonical_sql`, caches it and runs it. With no cache
///   configured, it falls back to the unprepared path.
///
/// The expansion always evaluates to an `OrmResult<_>` and never returns early from the
/// enclosing function. Prepare failures therefore reach the caller's result reporting
/// (stats, monitors, `after_query` hook) just like execution failures. Each step
/// (prepare, execute) is bounded separately by `execute_with_timeout`.
macro_rules! stmt_cache_dispatch {
    // Internal rule: prepare `$ctx.canonical_sql`, store it in `$cache`, then run it.
    // Evaluates to the query's `OrmResult`; a prepare error becomes the result value.
    (@prepare_and_run $self:expr, $ctx:expr, $params:expr, $cache:expr, $prepared:ident) => {{
        let prep_start = Instant::now();
        let prepared = $self
            .execute_with_timeout($self.client.prepare_statement(&$ctx.canonical_sql))
            .await;
        let prep_dur = prep_start.elapsed();
        if $self.config.stats_enabled {
            $self.stats.on_stmt_prepare(prep_dur);
        }
        match prepared {
            Ok(stmt) => {
                // NOTE(review): `insert_if_absent` presumably returns the already-cached
                // statement when another caller inserted one first — confirm.
                let stmt = $cache.insert_if_absent($ctx.canonical_sql.clone(), stmt);
                $self
                    .execute_with_timeout($self.client.$prepared(&stmt, $params))
                    .await
            }
            Err(err) => Err(err),
        }
    }};
    ($self:expr, $ctx:expr, $params:expr, $probe:expr,
     $unprepared:ident, $prepared:ident) => {
        match $probe {
            StmtCacheProbe::Disabled => {
                $self
                    .execute_with_timeout($self.client.$unprepared(&$ctx.exec_sql, $params))
                    .await
            }
            StmtCacheProbe::Hit(stmt) => {
                if $self.config.stats_enabled {
                    $self.stats.on_stmt_cache_hit();
                }
                let result = $self
                    .execute_with_timeout($self.client.$prepared(&stmt, $params))
                    .await;
                let should_retry =
                    matches!(&result, Err(err) if is_retryable_prepared_error(err));
                match &$self.statement_cache {
                    Some(cache) if should_retry => {
                        // The cached statement was rejected; evict it and retry once with
                        // a freshly prepared one.
                        let _ = cache.remove(&$ctx.canonical_sql);
                        stmt_cache_dispatch!(@prepare_and_run $self, $ctx, $params, cache, $prepared)
                    }
                    _ => result,
                }
            }
            StmtCacheProbe::Miss => {
                if $self.config.stats_enabled {
                    $self.stats.on_stmt_cache_miss();
                }
                match &$self.statement_cache {
                    Some(cache) => {
                        stmt_cache_dispatch!(@prepare_and_run $self, $ctx, $params, cache, $prepared)
                    }
                    None => {
                        $self
                            .execute_with_timeout($self.client.$unprepared(&$ctx.exec_sql, $params))
                            .await
                    }
                }
            }
        }
    };
}
impl<C: GenericClient> super::PgClient<C> {
fn prepare_ctx(
&self,
tag: Option<&str>,
sql: &str,
param_count: usize,
) -> OrmResult<(QueryContext, StmtCacheProbe)> {
let mut ctx = QueryContext::new(sql, param_count);
if let Some(tag) = tag {
ctx.tag = Some(tag.to_string());
}
self.apply_hook(&mut ctx)?;
self.apply_sql_policy(&mut ctx)?;
self.check_sql(&ctx.canonical_sql)?;
let probe = self.probe_stmt_cache(&ctx);
probe.populate_context(&mut ctx);
self.emit_tracing_sql(&ctx);
Ok((ctx, probe))
}
pub(super) async fn query_impl(
&self,
tag: Option<&str>,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> OrmResult<Vec<Row>> {
let (ctx, probe) = self.prepare_ctx(tag, sql, params.len())?;
let start = Instant::now();
let result = stmt_cache_dispatch!(self, ctx, params, probe, query, query_prepared);
let duration = start.elapsed();
let query_result = match &result {
Ok(rows) => QueryResult::Rows(rows.len()),
Err(OrmError::Timeout(d)) => QueryResult::error(format!("timeout after {d:?}")),
Err(e) => QueryResult::error(e.to_string()),
};
self.report_result(&ctx, duration, &query_result);
result
}
pub(super) async fn query_one_impl(
&self,
tag: Option<&str>,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> OrmResult<Row> {
let (ctx, probe) = self.prepare_ctx(tag, sql, params.len())?;
let start = Instant::now();
let result = stmt_cache_dispatch!(self, ctx, params, probe, query_one, query_one_prepared);
let duration = start.elapsed();
let query_result = match &result {
Ok(_) => QueryResult::OptionalRow(true),
Err(OrmError::NotFound(_)) => QueryResult::OptionalRow(false),
Err(OrmError::Timeout(d)) => QueryResult::error(format!("timeout after {d:?}")),
Err(e) => QueryResult::error(e.to_string()),
};
self.report_result(&ctx, duration, &query_result);
result
}
pub(super) async fn query_opt_impl(
&self,
tag: Option<&str>,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> OrmResult<Option<Row>> {
let (ctx, probe) = self.prepare_ctx(tag, sql, params.len())?;
let start = Instant::now();
let result = stmt_cache_dispatch!(self, ctx, params, probe, query_opt, query_opt_prepared);
let duration = start.elapsed();
let query_result = match &result {
Ok(Some(_)) => QueryResult::OptionalRow(true),
Ok(None) => QueryResult::OptionalRow(false),
Err(OrmError::Timeout(d)) => QueryResult::error(format!("timeout after {d:?}")),
Err(e) => QueryResult::error(e.to_string()),
};
self.report_result(&ctx, duration, &query_result);
result
}
pub(super) async fn execute_impl(
&self,
tag: Option<&str>,
sql: &str,
params: &[&(dyn ToSql + Sync)],
) -> OrmResult<u64> {
let (ctx, probe) = self.prepare_ctx(tag, sql, params.len())?;
let start = Instant::now();
let result = stmt_cache_dispatch!(self, ctx, params, probe, execute, execute_prepared);
let duration = start.elapsed();
let query_result = match &result {
Ok(n) => QueryResult::Affected(*n),
Err(OrmError::Timeout(d)) => QueryResult::error(format!("timeout after {d:?}")),
Err(e) => QueryResult::error(e.to_string()),
};
self.report_result(&ctx, duration, &query_result);
result
}
}