use crate::client::GenericClient;
use crate::error::{OrmError, OrmResult};
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_postgres::Row;
use tokio_postgres::types::ToSql;
/// Coarse category of a SQL statement, derived from its leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// `SELECT`, or a `WITH` query whose main statement is a read (or cannot be located).
    Select,
    /// `INSERT`, including `WITH ... INSERT`.
    Insert,
    /// `UPDATE`, including `WITH ... UPDATE`.
    Update,
    /// `DELETE`, including `WITH ... DELETE`.
    Delete,
    /// Anything else (DDL, `MERGE`, utility statements, empty or comment-only input).
    Other,
}

impl QueryType {
    /// Classifies `sql` by its first keyword.
    ///
    /// Leading whitespace, `--` line comments, `/* ... */` block comments (nestable, as in
    /// PostgreSQL) and opening parentheses are skipped first; an unterminated comment yields
    /// [`QueryType::Other`]. Keywords match case-insensitively and must be whole words.
    ///
    /// For `WITH` queries the statement that follows the CTE list decides the category, so a
    /// data-modifying query such as `WITH x AS (...) DELETE FROM t` is reported as
    /// [`QueryType::Delete`] rather than `Select`. When that statement cannot be located the
    /// result falls back to [`QueryType::Select`].
    pub fn from_sql(sql: &str) -> Self {
        // Bytes that may continue an unquoted identifier/keyword. Non-ASCII bytes count, as
        // PostgreSQL allows them in identifiers.
        fn is_ident_byte(b: u8) -> bool {
            b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || !b.is_ascii()
        }

        // For `bytes` starting with `/*`, returns the length of the block comment including
        // its matching `*/` (nested comments are balanced), or `None` if it never closes.
        fn block_comment_len(bytes: &[u8]) -> Option<usize> {
            debug_assert!(bytes.starts_with(b"/*"));
            let mut depth = 1usize;
            let mut i = 2;
            while i + 1 < bytes.len() {
                match (bytes[i], bytes[i + 1]) {
                    (b'/', b'*') => {
                        depth += 1;
                        i += 2;
                    }
                    (b'*', b'/') => {
                        depth -= 1;
                        i += 2;
                        if depth == 0 {
                            return Some(i);
                        }
                    }
                    _ => i += 1,
                }
            }
            None
        }

        // Skips whitespace, comments and opening parentheses in front of the first keyword.
        // Returns "" when an unterminated comment swallows the rest of the input.
        fn strip_sql_prefix(sql: &str) -> &str {
            let mut s = sql;
            loop {
                s = s.trim_start();
                if s.starts_with("--") {
                    match s.find('\n') {
                        Some(pos) => s = &s[pos + 1..],
                        None => return "",
                    }
                } else if s.starts_with("/*") {
                    match block_comment_len(s.as_bytes()) {
                        // `len` ends right after an ASCII `*/`, so it is a char boundary.
                        Some(len) => s = &s[len..],
                        None => return "",
                    }
                } else if let Some(rest) = s.strip_prefix('(') {
                    s = rest;
                } else {
                    return s;
                }
            }
        }

        // Case-insensitive keyword match that also requires a word boundary after it, so
        // e.g. "SELECTED" does not count as SELECT.
        fn starts_with_keyword(s: &str, keyword: &str) -> bool {
            let bytes = s.as_bytes();
            bytes.len() >= keyword.len()
                && bytes[..keyword.len()].eq_ignore_ascii_case(keyword.as_bytes())
                && !bytes.get(keyword.len()).is_some_and(|&b| is_ident_byte(b))
        }

        // Statement keywords that can start the main statement of a WITH query.
        fn keyword_type(word: &[u8]) -> Option<QueryType> {
            match word.to_ascii_lowercase().as_slice() {
                b"select" | b"values" | b"table" => Some(QueryType::Select),
                b"insert" => Some(QueryType::Insert),
                b"update" => Some(QueryType::Update),
                b"delete" => Some(QueryType::Delete),
                b"merge" => Some(QueryType::Other),
                _ => None,
            }
        }

        // Classifies the text following `WITH` by the first statement keyword that directly
        // follows a top-level `)` (the end of a CTE body). Quoted text and comments are
        // skipped so parentheses inside them do not disturb the nesting count. A CTE *name*
        // never directly follows a top-level `)` (a `,` does), so names like `update` are
        // not mistaken for the main statement. Falls back to `Select`.
        fn classify_with_query(body: &[u8]) -> QueryType {
            let mut depth = 0usize;
            let mut after_close = false;
            let mut i = 0;
            while i < body.len() {
                let b = body[i];
                let rest = &body[i..];
                if b.is_ascii_whitespace() {
                    i += 1;
                } else if rest.starts_with(b"--") {
                    match rest.iter().position(|&c| c == b'\n') {
                        Some(nl) => i += nl + 1,
                        None => break,
                    }
                } else if rest.starts_with(b"/*") {
                    match block_comment_len(rest) {
                        Some(len) => i += len,
                        None => break,
                    }
                } else if b == b'\'' || b == b'"' {
                    // Jump past the closing quote; a doubled quote ('') simply reads as two
                    // adjacent quoted runs.
                    match rest[1..].iter().position(|&c| c == b) {
                        Some(end) => i += end + 2,
                        None => break,
                    }
                    after_close = false;
                } else if b == b'(' {
                    depth += 1;
                    after_close = false;
                    i += 1;
                } else if b == b')' {
                    depth = depth.saturating_sub(1);
                    after_close = depth == 0;
                    i += 1;
                } else if is_ident_byte(b) && !b.is_ascii_digit() {
                    let len = rest
                        .iter()
                        .position(|&c| !is_ident_byte(c))
                        .unwrap_or(rest.len());
                    if depth == 0 && after_close {
                        if let Some(kind) = keyword_type(&rest[..len]) {
                            return kind;
                        }
                    }
                    after_close = false;
                    i += len;
                } else {
                    after_close = false;
                    i += 1;
                }
            }
            QueryType::Select
        }

        let head = strip_sql_prefix(sql);
        if starts_with_keyword(head, "SELECT") {
            QueryType::Select
        } else if starts_with_keyword(head, "WITH") {
            classify_with_query(&head.as_bytes()["WITH".len()..])
        } else if starts_with_keyword(head, "INSERT") {
            QueryType::Insert
        } else if starts_with_keyword(head, "UPDATE") {
            QueryType::Update
        } else if starts_with_keyword(head, "DELETE") {
            QueryType::Delete
        } else {
            QueryType::Other
        }
    }
}
/// Per-query metadata handed to hooks and monitors.
#[derive(Debug, Clone)]
pub struct QueryContext {
    /// SQL used for reporting and for deriving `query_type`.
    pub canonical_sql: String,
    /// SQL text actually sent to the database; hooks may rewrite it.
    pub exec_sql: String,
    /// Number of bind parameters supplied with the query.
    pub param_count: usize,
    /// Statement category, computed from the SQL with [`QueryType::from_sql`].
    pub query_type: QueryType,
    /// Optional caller-supplied label for the query.
    pub tag: Option<String>,
    /// Free-form key/value annotations (kept sorted by key).
    pub fields: BTreeMap<String, String>,
}

impl QueryContext {
    /// Creates a context where both the canonical and the executed SQL are `sql`.
    pub fn new(sql: &str, param_count: usize) -> Self {
        let query_type = QueryType::from_sql(sql);
        let owned = sql.to_owned();
        Self {
            canonical_sql: owned.clone(),
            exec_sql: owned,
            param_count,
            query_type,
            tag: None,
            fields: BTreeMap::default(),
        }
    }

    /// Builder: sets the tag, replacing any previous one.
    pub fn with_tag(self, tag: impl Into<String>) -> Self {
        Self {
            tag: Some(tag.into()),
            ..self
        }
    }

    /// Builder: inserts (or overwrites) one annotation field.
    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.fields.insert(key, value);
        self
    }
}
/// Outcome summary of an executed query, as reported to monitors and hooks.
#[derive(Debug, Clone)]
pub enum QueryResult {
    /// A multi-row query returned this many rows.
    Rows(usize),
    /// A statement affected this many rows.
    Affected(u64),
    /// A single-row query found (`true`) or did not find (`false`) a row.
    OptionalRow(bool),
    /// The query failed; the payload is a human-readable description.
    Error(String),
}

impl fmt::Display for QueryResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Rows(count) => write!(f, "{count} rows"),
            Self::Affected(count) => write!(f, "{count} affected"),
            Self::OptionalRow(true) => f.write_str("1 row"),
            Self::OptionalRow(false) => f.write_str("0 rows"),
            Self::Error(message) => write!(f, "error: {message}"),
        }
    }
}
/// Observer of query execution events.
///
/// The trait requires `Send + Sync` so one monitor instance can be shared across tasks.
pub trait QueryMonitor: Send + Sync {
    /// Called just before a query is executed (after hooks have run). Default: no-op.
    fn on_query_start(&self, ctx: &QueryContext) {
        let _ = ctx;
    }

    /// Called once per executed query with its elapsed time and an outcome summary.
    /// This is the only method implementors must provide.
    fn on_query_complete(&self, ctx: &QueryContext, duration: Duration, result: &QueryResult);

    /// Called, in addition to `on_query_complete`, for queries slower than the configured
    /// slow-query threshold. Default: no-op.
    fn on_slow_query(&self, ctx: &QueryContext, duration: Duration) {
        let _ = (ctx, duration);
    }
}
/// Decision returned by [`QueryHook::before_query`].
#[derive(Debug, Clone)]
pub enum HookAction {
    /// Run the query unchanged.
    Continue,
    /// Rewrite the query before it runs.
    ModifySql {
        /// SQL text that will actually be sent to the database.
        exec_sql: String,
        /// Replacement canonical SQL for reporting; `None` keeps the current canonical SQL.
        canonical_sql: Option<String>,
    },
    /// Refuse to run the query; the payload is the reason.
    Abort(String),
}

/// Interception point around query execution.
pub trait QueryHook: Send + Sync {
    /// Inspects a query before it runs and may rewrite or veto it. Default: `Continue`.
    fn before_query(&self, _ctx: &QueryContext) -> HookAction {
        HookAction::Continue
    }

    /// Observes a query after it finished. Default: no-op.
    fn after_query(&self, ctx: &QueryContext, duration: Duration, result: &QueryResult) {
        let _ = (ctx, duration, result);
    }
}
/// A [`QueryMonitor`] that discards every event.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopMonitor;

impl QueryMonitor for NoopMonitor {
    fn on_query_complete(&self, ctx: &QueryContext, duration: Duration, result: &QueryResult) {
        // Intentionally ignore the event.
        let _ = (ctx, duration, result);
    }
}
/// A monitor that writes completed and slow queries to stderr.
#[derive(Debug, Clone)]
pub struct LoggingMonitor {
    /// Completed queries faster than this are not logged (`None` = log all of them).
    pub min_duration: Option<Duration>,
    /// Logged SQL is cut to at most this many bytes plus `...` (`None` = never cut).
    pub max_sql_length: Option<usize>,
    /// Text prepended to every log line.
    pub prefix: String,
}

impl Default for LoggingMonitor {
    /// No duration floor, SQL truncated at 200 bytes, prefix `[pgorm]`.
    fn default() -> Self {
        Self {
            min_duration: None,
            max_sql_length: Some(200),
            prefix: "[pgorm]".to_string(),
        }
    }
}

impl LoggingMonitor {
    /// Same as [`LoggingMonitor::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: only log completed queries that take at least `duration`.
    pub fn min_duration(mut self, duration: Duration) -> Self {
        self.min_duration = Some(duration);
        self
    }

    /// Builder: truncate logged SQL to at most `len` bytes.
    pub fn max_sql_length(mut self, len: usize) -> Self {
        self.max_sql_length = Some(len);
        self
    }

    /// Builder: replace the log-line prefix.
    pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
        self.prefix = prefix.into();
        self
    }

    /// Shortens `sql` to the configured byte budget, appending `...` when cut.
    ///
    /// The cut point backs off to the nearest UTF-8 character boundary at or below the
    /// budget, so SQL containing multi-byte text (e.g. non-ASCII literals) never panics.
    fn truncate_sql(&self, sql: &str) -> String {
        match self.max_sql_length {
            Some(max) if sql.len() > max => {
                // `max < sql.len()` here, and index 0 is always a boundary, so `find`
                // always succeeds; `unwrap_or(0)` is only a formality.
                let cut = (0..=max)
                    .rev()
                    .find(|&i| sql.is_char_boundary(i))
                    .unwrap_or(0);
                format!("{}...", &sql[..cut])
            }
            _ => sql.to_string(),
        }
    }
}
impl QueryMonitor for LoggingMonitor {
    /// Prints one line per completed query unless it ran faster than `min_duration`.
    ///
    /// When a hook rewrote the executed SQL, both the canonical and the executed text are
    /// shown.
    fn on_query_complete(&self, ctx: &QueryContext, duration: Duration, result: &QueryResult) {
        if self.min_duration.is_some_and(|floor| duration < floor) {
            return;
        }
        let rendered = if ctx.exec_sql == ctx.canonical_sql {
            self.truncate_sql(&ctx.canonical_sql)
        } else {
            let canonical_part = self.truncate_sql(&ctx.canonical_sql);
            let exec_part = self.truncate_sql(&ctx.exec_sql);
            format!("canonical: {} | exec: {}", canonical_part, exec_part)
        };
        let label = match &ctx.tag {
            Some(tag) => tag.as_str(),
            None => "-",
        };
        eprintln!(
            "{} [{:?}] [{}] {:?} | {} | {}",
            self.prefix, ctx.query_type, label, duration, result, rendered
        );
    }

    /// Prints a `SLOW QUERY` line; not subject to `min_duration`.
    fn on_slow_query(&self, ctx: &QueryContext, duration: Duration) {
        let rendered = if ctx.exec_sql == ctx.canonical_sql {
            self.truncate_sql(&ctx.canonical_sql)
        } else {
            let canonical_part = self.truncate_sql(&ctx.canonical_sql);
            let exec_part = self.truncate_sql(&ctx.exec_sql);
            format!("canonical: {} | exec: {}", canonical_part, exec_part)
        };
        eprintln!(
            "{} SLOW QUERY [{:?}]: {:?} | {}",
            self.prefix, ctx.query_type, duration, rendered
        );
    }
}
/// A monitor that aggregates query counts and timings in atomic counters.
///
/// A zeroed instance is obtained from `Default`/`new`; read a snapshot with
/// [`StatsMonitor::stats`].
#[derive(Debug, Default)]
pub struct StatsMonitor {
    total_queries: std::sync::atomic::AtomicU64,
    failed_queries: std::sync::atomic::AtomicU64,
    total_duration_nanos: std::sync::atomic::AtomicU64,
    select_count: std::sync::atomic::AtomicU64,
    insert_count: std::sync::atomic::AtomicU64,
    update_count: std::sync::atomic::AtomicU64,
    delete_count: std::sync::atomic::AtomicU64,
    max_duration_nanos: std::sync::atomic::AtomicU64,
    slowest_query: std::sync::Mutex<Option<String>>,
}

/// Point-in-time copy of the counters held by a [`StatsMonitor`].
#[derive(Debug, Clone, Default)]
pub struct QueryStats {
    /// Number of completed queries.
    pub total_queries: u64,
    /// Completed queries whose result was an error.
    pub failed_queries: u64,
    /// Sum of all query durations.
    pub total_duration: Duration,
    /// Queries classified as `Select`.
    pub select_count: u64,
    /// Queries classified as `Insert`.
    pub insert_count: u64,
    /// Queries classified as `Update`.
    pub update_count: u64,
    /// Queries classified as `Delete`.
    pub delete_count: u64,
    /// Longest single query duration seen.
    pub max_duration: Duration,
    /// Canonical SQL of the query that set `max_duration`, if any.
    pub slowest_query: Option<String>,
}

impl StatsMonitor {
    /// Creates a monitor with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Reads every counter (relaxed loads, so not a single atomic snapshot).
    pub fn stats(&self) -> QueryStats {
        use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
        let read = |counter: &AtomicU64| counter.load(Relaxed);
        QueryStats {
            total_queries: read(&self.total_queries),
            failed_queries: read(&self.failed_queries),
            total_duration: Duration::from_nanos(read(&self.total_duration_nanos)),
            select_count: read(&self.select_count),
            insert_count: read(&self.insert_count),
            update_count: read(&self.update_count),
            delete_count: read(&self.delete_count),
            max_duration: Duration::from_nanos(read(&self.max_duration_nanos)),
            slowest_query: self.slowest_query.lock().unwrap().clone(),
        }
    }

    /// Sets every counter back to zero and forgets the slowest query.
    pub fn reset(&self) {
        use std::sync::atomic::Ordering::Relaxed;
        let counters = [
            &self.total_queries,
            &self.failed_queries,
            &self.total_duration_nanos,
            &self.select_count,
            &self.insert_count,
            &self.update_count,
            &self.delete_count,
            &self.max_duration_nanos,
        ];
        for counter in counters {
            counter.store(0, Relaxed);
        }
        *self.slowest_query.lock().unwrap() = None;
    }
}
impl QueryMonitor for StatsMonitor {
    /// Folds one completed query into the counters.
    ///
    /// Plain counters use relaxed atomics. The `(max_duration, slowest_query)` pair is only
    /// written while holding the `slowest_query` lock, so the recorded SQL always belongs to
    /// the query that set the maximum, even when several record-setting queries finish
    /// concurrently.
    fn on_query_complete(&self, ctx: &QueryContext, duration: Duration, result: &QueryResult) {
        use std::sync::atomic::Ordering;

        // Durations beyond u64 nanoseconds (~584 years) saturate rather than truncate.
        let duration_nanos = u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX);

        self.total_queries.fetch_add(1, Ordering::Relaxed);
        // Atomic saturating add: the closure is retried on contention, so the running
        // total never wraps and no concurrent add can slip in between a wrap and a fix-up.
        let _ = self.total_duration_nanos.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |total| Some(total.saturating_add(duration_nanos)),
        );

        let per_type = match ctx.query_type {
            QueryType::Select => Some(&self.select_count),
            QueryType::Insert => Some(&self.insert_count),
            QueryType::Update => Some(&self.update_count),
            QueryType::Delete => Some(&self.delete_count),
            QueryType::Other => None,
        };
        if let Some(counter) = per_type {
            counter.fetch_add(1, Ordering::Relaxed);
        }

        if matches!(result, QueryResult::Error(_)) {
            self.failed_queries.fetch_add(1, Ordering::Relaxed);
        }

        // Fast path: most queries do not set a new record, so skip the lock entirely.
        if duration_nanos <= self.max_duration_nanos.load(Ordering::Relaxed) {
            return;
        }
        // Re-check under the lock: every writer of a new maximum is serialized here, so the
        // stored SQL and the stored maximum cannot come from different queries. A poisoned
        // lock is tolerated because the guarded value is a plain `Option<String>`.
        let mut slowest = self
            .slowest_query
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if duration_nanos > self.max_duration_nanos.load(Ordering::Relaxed) {
            self.max_duration_nanos
                .store(duration_nanos, Ordering::Relaxed);
            *slowest = Some(ctx.canonical_sql.clone());
        }
    }
}
/// A [`QueryMonitor`] that forwards every callback to each child monitor, in the order the
/// children were added.
pub struct CompositeMonitor {
    monitors: Vec<Arc<dyn QueryMonitor>>,
}

impl CompositeMonitor {
    /// Creates a composite with no children (every callback is a no-op).
    pub fn new() -> Self {
        let monitors = Vec::new();
        Self { monitors }
    }

    /// Builder: appends an owned monitor.
    // Consuming builder method named `add`; it is intentionally not `std::ops::Add`.
    #[allow(clippy::should_implement_trait)]
    pub fn add<M: QueryMonitor + 'static>(self, monitor: M) -> Self {
        self.add_arc(Arc::new(monitor))
    }

    /// Builder: appends an already shared monitor.
    pub fn add_arc(mut self, monitor: Arc<dyn QueryMonitor>) -> Self {
        self.monitors.push(monitor);
        self
    }
}

impl Default for CompositeMonitor {
    fn default() -> Self {
        Self::new()
    }
}

impl QueryMonitor for CompositeMonitor {
    fn on_query_start(&self, ctx: &QueryContext) {
        self.monitors
            .iter()
            .for_each(|child| child.on_query_start(ctx));
    }

    fn on_query_complete(&self, ctx: &QueryContext, duration: Duration, result: &QueryResult) {
        self.monitors
            .iter()
            .for_each(|child| child.on_query_complete(ctx, duration, result));
    }

    fn on_slow_query(&self, ctx: &QueryContext, duration: Duration) {
        self.monitors
            .iter()
            .for_each(|child| child.on_slow_query(ctx, duration));
    }
}
/// A [`QueryHook`] that runs several hooks in insertion order.
///
/// `before_query` threads one context through the hooks. Each hook sees the SQL produced by
/// the hooks before it, with `query_type` re-derived from the canonical SQL. The first
/// `Abort` is returned immediately. `after_query` is forwarded to every hook.
pub struct CompositeHook {
    hooks: Vec<Arc<dyn QueryHook>>,
}

impl CompositeHook {
    /// Creates a composite with no hooks (always answers `Continue`).
    pub fn new() -> Self {
        Self { hooks: Vec::new() }
    }

    /// Builder: appends an owned hook.
    // Consuming builder method named `add`; it is intentionally not `std::ops::Add`.
    #[allow(clippy::should_implement_trait)]
    pub fn add<H: QueryHook + 'static>(mut self, hook: H) -> Self {
        self.hooks.push(Arc::new(hook));
        self
    }

    /// Builder: appends an already shared hook.
    pub fn add_arc(mut self, hook: Arc<dyn QueryHook>) -> Self {
        self.hooks.push(hook);
        self
    }
}

impl Default for CompositeHook {
    fn default() -> Self {
        Self::new()
    }
}

impl QueryHook for CompositeHook {
    /// Runs all hooks and merges their decisions into one [`HookAction`].
    ///
    /// Returns `ModifySql` only if the final SQL differs from `ctx`. Its `canonical_sql` is
    /// `Some` only when the canonical text actually changed.
    fn before_query(&self, ctx: &QueryContext) -> HookAction {
        let mut current_ctx = ctx.clone();
        for hook in &self.hooks {
            match hook.before_query(&current_ctx) {
                HookAction::Continue => {}
                HookAction::ModifySql {
                    exec_sql,
                    canonical_sql,
                } => {
                    current_ctx.exec_sql = exec_sql;
                    if let Some(canonical_sql) = canonical_sql {
                        current_ctx.canonical_sql = canonical_sql;
                    }
                    // Later hooks must see the category of the rewritten statement.
                    current_ctx.query_type = QueryType::from_sql(&current_ctx.canonical_sql);
                }
                action @ HookAction::Abort(_) => return action,
            }
        }

        let canonical_changed = current_ctx.canonical_sql != ctx.canonical_sql;
        if canonical_changed || current_ctx.exec_sql != ctx.exec_sql {
            HookAction::ModifySql {
                exec_sql: current_ctx.exec_sql,
                canonical_sql: canonical_changed.then_some(current_ctx.canonical_sql),
            }
        } else {
            HookAction::Continue
        }
    }

    fn after_query(&self, ctx: &QueryContext, duration: Duration, result: &QueryResult) {
        for hook in &self.hooks {
            hook.after_query(ctx, duration, result);
        }
    }
}
/// Runtime switches for query instrumentation.
///
/// By default there is no timeout, no slow-query threshold, and monitoring is disabled.
#[derive(Debug, Clone, Default)]
pub struct MonitorConfig {
    /// Maximum time a single query may run before it is abandoned.
    pub query_timeout: Option<Duration>,
    /// Queries strictly slower than this are additionally reported as slow.
    pub slow_query_threshold: Option<Duration>,
    /// Whether monitors (and post-query hooks) are notified at all.
    pub monitoring_enabled: bool,
}

impl MonitorConfig {
    /// Same as [`MonitorConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: sets the per-query timeout.
    pub fn with_query_timeout(self, timeout: Duration) -> Self {
        Self {
            query_timeout: Some(timeout),
            ..self
        }
    }

    /// Builder: sets the slow-query threshold.
    pub fn with_slow_query_threshold(self, threshold: Duration) -> Self {
        Self {
            slow_query_threshold: Some(threshold),
            ..self
        }
    }

    /// Builder: turns monitoring on.
    pub fn enable_monitoring(self) -> Self {
        Self {
            monitoring_enabled: true,
            ..self
        }
    }

    /// Builder: turns monitoring off.
    pub fn disable_monitoring(self) -> Self {
        Self {
            monitoring_enabled: false,
            ..self
        }
    }
}
/// A [`GenericClient`] wrapper that instruments every query.
///
/// For each call it:
/// 1. builds a [`QueryContext`];
/// 2. lets the optional [`QueryHook`] rewrite or abort the query;
/// 3. applies the optional timeout from [`MonitorConfig`];
/// 4. reports the outcome to the [`QueryMonitor`].
///
/// The hook's `before_query` always runs. Monitor callbacks and the hook's `after_query`
/// only fire while monitoring is enabled, which is off in the default config.
pub struct InstrumentedClient<C> {
    client: C,
    monitor: Arc<dyn QueryMonitor>,
    hook: Option<Arc<dyn QueryHook>>,
    config: MonitorConfig,
}

impl<C: GenericClient> InstrumentedClient<C> {
    /// Wraps `client` with a [`NoopMonitor`], no hook, and `MonitorConfig::default()`.
    pub fn new(client: C) -> Self {
        Self {
            client,
            monitor: Arc::new(NoopMonitor),
            hook: None,
            config: MonitorConfig::default(),
        }
    }

    /// Replaces the whole monitoring configuration.
    pub fn with_config(mut self, config: MonitorConfig) -> Self {
        self.config = config;
        self
    }

    /// Replaces the monitor. It is only notified while monitoring is enabled.
    pub fn with_monitor<M: QueryMonitor + 'static>(mut self, monitor: M) -> Self {
        self.monitor = Arc::new(monitor);
        self
    }

    /// Like [`Self::with_monitor`], for an already shared monitor.
    pub fn with_monitor_arc(mut self, monitor: Arc<dyn QueryMonitor>) -> Self {
        self.monitor = monitor;
        self
    }

    /// Replaces any installed hook with `hook`.
    pub fn with_hook<H: QueryHook + 'static>(mut self, hook: H) -> Self {
        self.hook = Some(Arc::new(hook));
        self
    }

    /// Like [`Self::with_hook`], for an already shared hook.
    pub fn with_hook_arc(mut self, hook: Arc<dyn QueryHook>) -> Self {
        self.hook = Some(hook);
        self
    }

    /// Adds `hook` after any installed hook; both run, chained through a [`CompositeHook`].
    pub fn add_hook<H: QueryHook + 'static>(self, hook: H) -> Self {
        self.add_hook_arc(Arc::new(hook))
    }

    /// Like [`Self::add_hook`], for an already shared hook.
    pub fn add_hook_arc(mut self, hook: Arc<dyn QueryHook>) -> Self {
        self.hook = Some(match self.hook.take() {
            None => hook,
            Some(existing) => Arc::new(CompositeHook::new().add_arc(existing).add_arc(hook)),
        });
        self
    }

    /// Sets the slow-query threshold.
    #[deprecated(
        since = "0.2.0",
        note = "Use `with_config(MonitorConfig::new().with_slow_query_threshold(...))` instead"
    )]
    pub fn with_slow_query_threshold(mut self, threshold: Duration) -> Self {
        self.config.slow_query_threshold = Some(threshold);
        self
    }

    /// Sets the per-query timeout.
    pub fn with_query_timeout(mut self, timeout: Duration) -> Self {
        self.config.query_timeout = Some(timeout);
        self
    }

    /// Turns on monitor and `after_query` reporting.
    pub fn enable_monitoring(mut self) -> Self {
        self.config.monitoring_enabled = true;
        self
    }

    /// Turns off monitor and `after_query` reporting; `before_query` hooks still run.
    pub fn disable_monitoring(mut self) -> Self {
        self.config.monitoring_enabled = false;
        self
    }

    /// Whether monitors are currently notified.
    pub fn is_monitoring_enabled(&self) -> bool {
        self.config.monitoring_enabled
    }

    /// The active configuration.
    pub fn config(&self) -> &MonitorConfig {
        &self.config
    }

    /// Mutable access to the active configuration.
    pub fn config_mut(&mut self) -> &mut MonitorConfig {
        &mut self.config
    }

    /// The wrapped client.
    pub fn inner(&self) -> &C {
        &self.client
    }

    /// Unwraps and returns the wrapped client.
    pub fn into_inner(self) -> C {
        self.client
    }

    /// Runs the hook's `before_query` and applies its decision to `ctx`.
    ///
    /// On `ModifySql`, `query_type` is re-derived from the (possibly new) canonical SQL.
    ///
    /// # Errors
    /// Returns a validation error when the hook aborts the query.
    fn apply_hook(&self, ctx: &mut QueryContext) -> Result<(), OrmError> {
        let Some(hook) = &self.hook else {
            return Ok(());
        };
        match hook.before_query(ctx) {
            HookAction::Continue => Ok(()),
            HookAction::ModifySql {
                exec_sql,
                canonical_sql,
            } => {
                ctx.exec_sql = exec_sql;
                if let Some(canonical_sql) = canonical_sql {
                    ctx.canonical_sql = canonical_sql;
                }
                ctx.query_type = QueryType::from_sql(&ctx.canonical_sql);
                Ok(())
            }
            HookAction::Abort(reason) => Err(OrmError::validation(format!(
                "Query aborted by hook: {reason}"
            ))),
        }
    }

    /// Notifies the hook (`after_query`) and the monitor about a finished query.
    ///
    /// If the query exceeded the slow-query threshold, `on_slow_query` is also called.
    /// Does nothing while monitoring is disabled.
    fn report_result(&self, ctx: &QueryContext, duration: Duration, result: &QueryResult) {
        if !self.config.monitoring_enabled {
            return;
        }
        if let Some(hook) = &self.hook {
            hook.after_query(ctx, duration, result);
        }
        self.monitor.on_query_complete(ctx, duration, result);
        if let Some(threshold) = self.config.slow_query_threshold {
            if duration > threshold {
                self.monitor.on_slow_query(ctx, duration);
            }
        }
    }

    /// Shared prologue of every query path.
    ///
    /// Builds the context (with `tag`), applies the hook, then emits `on_query_start` when
    /// monitoring is enabled.
    ///
    /// # Errors
    /// Propagates a hook abort; nothing is reported to the monitor in that case.
    fn begin_query(
        &self,
        sql: &str,
        param_count: usize,
        tag: Option<&str>,
    ) -> OrmResult<QueryContext> {
        let mut ctx = QueryContext::new(sql, param_count);
        ctx.tag = tag.map(str::to_string);
        self.apply_hook(&mut ctx)?;
        if self.config.monitoring_enabled {
            self.monitor.on_query_start(&ctx);
        }
        Ok(ctx)
    }

    /// Summarizes a failed query for monitors; timeouts get a dedicated message.
    fn error_summary(err: &OrmError) -> QueryResult {
        match err {
            OrmError::Timeout(limit) => QueryResult::Error(format!("timeout after {limit:?}")),
            other => QueryResult::Error(other.to_string()),
        }
    }

    /// Awaits `future`, bounded by the configured query timeout, if one is set.
    ///
    /// On timeout the future is dropped. If the client exposes a cancel token, a
    /// fire-and-forget server-side cancel request is spawned.
    ///
    /// # Errors
    /// Returns `OrmError::Timeout` when the limit elapses, else the future's own result.
    async fn execute_with_timeout<T, F>(&self, future: F) -> OrmResult<T>
    where
        F: std::future::Future<Output = OrmResult<T>> + Send,
    {
        let Some(timeout) = self.config.query_timeout else {
            return future.await;
        };
        match tokio::time::timeout(timeout, future).await {
            Ok(result) => result,
            Err(_elapsed) => {
                // Dropping the future only abandons the client side; ask the server to stop
                // the statement as well. Best-effort: failures are ignored.
                // NOTE(review): the cancel request always uses `NoTls`; servers that only
                // accept TLS connections may reject it — confirm against deployments.
                if let Some(cancel_token) = self.client.cancel_token() {
                    tokio::spawn(async move {
                        let _ = cancel_token.cancel_query(tokio_postgres::NoTls).await;
                    });
                }
                Err(OrmError::Timeout(timeout))
            }
        }
    }

    /// Instrumented multi-row query.
    async fn query_inner(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
        tag: Option<&str>,
    ) -> OrmResult<Vec<Row>> {
        let ctx = self.begin_query(sql, params.len(), tag)?;
        let start = Instant::now();
        let result = self
            .execute_with_timeout(self.client.query(&ctx.exec_sql, params))
            .await;
        let duration = start.elapsed();
        let summary = match &result {
            Ok(rows) => QueryResult::Rows(rows.len()),
            Err(e) => Self::error_summary(e),
        };
        self.report_result(&ctx, duration, &summary);
        result
    }

    /// Instrumented exactly-one-row query. A `NotFound` error is reported to monitors as
    /// "0 rows" rather than as a failure.
    async fn query_one_inner(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
        tag: Option<&str>,
    ) -> OrmResult<Row> {
        let ctx = self.begin_query(sql, params.len(), tag)?;
        let start = Instant::now();
        let result = self
            .execute_with_timeout(self.client.query_one(&ctx.exec_sql, params))
            .await;
        let duration = start.elapsed();
        let summary = match &result {
            Ok(_) => QueryResult::OptionalRow(true),
            Err(OrmError::NotFound { .. }) => QueryResult::OptionalRow(false),
            Err(e) => Self::error_summary(e),
        };
        self.report_result(&ctx, duration, &summary);
        result
    }

    /// Instrumented zero-or-one-row query.
    async fn query_opt_inner(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
        tag: Option<&str>,
    ) -> OrmResult<Option<Row>> {
        let ctx = self.begin_query(sql, params.len(), tag)?;
        let start = Instant::now();
        let result = self
            .execute_with_timeout(self.client.query_opt(&ctx.exec_sql, params))
            .await;
        let duration = start.elapsed();
        let summary = match &result {
            Ok(row) => QueryResult::OptionalRow(row.is_some()),
            Err(e) => Self::error_summary(e),
        };
        self.report_result(&ctx, duration, &summary);
        result
    }

    /// Instrumented statement execution returning the affected-row count.
    async fn execute_inner(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
        tag: Option<&str>,
    ) -> OrmResult<u64> {
        let ctx = self.begin_query(sql, params.len(), tag)?;
        let start = Instant::now();
        let result = self
            .execute_with_timeout(self.client.execute(&ctx.exec_sql, params))
            .await;
        let duration = start.elapsed();
        let summary = match &result {
            Ok(n) => QueryResult::Affected(*n),
            Err(e) => Self::error_summary(e),
        };
        self.report_result(&ctx, duration, &summary);
        result
    }
}
/// Routes every [`GenericClient`] call through the instrumented `*_inner` helpers.
///
/// The `_tagged` variants attach `tag` to the [`QueryContext`] seen by hooks and monitors.
impl<C: GenericClient> GenericClient for InstrumentedClient<C> {
    async fn query(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Vec<Row>> {
        Self::query_inner(self, sql, params, None).await
    }

    async fn query_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Vec<Row>> {
        Self::query_inner(self, sql, params, Some(tag)).await
    }

    async fn query_one(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
        Self::query_one_inner(self, sql, params, None).await
    }

    async fn query_one_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Row> {
        Self::query_one_inner(self, sql, params, Some(tag)).await
    }

    async fn query_opt(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Option<Row>> {
        Self::query_opt_inner(self, sql, params, None).await
    }

    async fn query_opt_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Option<Row>> {
        Self::query_opt_inner(self, sql, params, Some(tag)).await
    }

    async fn execute(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
        Self::execute_inner(self, sql, params, None).await
    }

    async fn execute_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<u64> {
        Self::execute_inner(self, sql, params, Some(tag)).await
    }

    /// Delegates to the wrapped client's cancel token.
    fn cancel_token(&self) -> Option<tokio_postgres::CancelToken> {
        C::cancel_token(&self.client)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_type_detection() {
        assert_eq!(
            QueryType::from_sql("SELECT * FROM users"),
            QueryType::Select
        );
        assert_eq!(
            QueryType::from_sql("  select * FROM users"),
            QueryType::Select
        );
        assert_eq!(
            QueryType::from_sql("WITH cte AS (SELECT 1) SELECT * FROM cte"),
            QueryType::Select
        );
        assert_eq!(
            QueryType::from_sql("INSERT INTO users (name) VALUES ($1)"),
            QueryType::Insert
        );
        assert_eq!(
            QueryType::from_sql("UPDATE users SET name = $1"),
            QueryType::Update
        );
        assert_eq!(
            QueryType::from_sql("DELETE FROM users WHERE id = $1"),
            QueryType::Delete
        );
        assert_eq!(
            QueryType::from_sql("CREATE TABLE users (id INT)"),
            QueryType::Other
        );
    }

    /// Leading comments and parentheses are skipped; comment-only or empty SQL is `Other`.
    #[test]
    fn test_query_type_skips_comments_and_parens() {
        assert_eq!(
            QueryType::from_sql("-- note\nDELETE FROM t"),
            QueryType::Delete
        );
        assert_eq!(
            QueryType::from_sql("/* hint */ update t set x = 1"),
            QueryType::Update
        );
        assert_eq!(
            QueryType::from_sql("(SELECT 1) UNION (SELECT 2)"),
            QueryType::Select
        );
        assert_eq!(QueryType::from_sql("-- only a comment"), QueryType::Other);
        assert_eq!(QueryType::from_sql(""), QueryType::Other);
    }

    #[test]
    fn test_query_result_display() {
        assert_eq!(QueryResult::Rows(5).to_string(), "5 rows");
        assert_eq!(QueryResult::Affected(2).to_string(), "2 affected");
        assert_eq!(QueryResult::OptionalRow(true).to_string(), "1 row");
        assert_eq!(QueryResult::OptionalRow(false).to_string(), "0 rows");
        assert_eq!(
            QueryResult::Error("boom".to_string()).to_string(),
            "error: boom"
        );
    }

    #[test]
    fn test_logging_monitor_truncation() {
        let monitor = LoggingMonitor::new().max_sql_length(10);
        assert_eq!(monitor.truncate_sql("SELECT * FROM users"), "SELECT * F...");
        assert_eq!(monitor.truncate_sql("SELECT 1"), "SELECT 1");
    }

    #[test]
    fn test_stats_monitor() {
        let monitor = StatsMonitor::new();
        let ctx = QueryContext::new("SELECT * FROM users", 0);
        monitor.on_query_complete(&ctx, Duration::from_millis(10), &QueryResult::Rows(5));
        monitor.on_query_complete(&ctx, Duration::from_millis(20), &QueryResult::Rows(3));
        let stats = monitor.stats();
        assert_eq!(stats.total_queries, 2);
        assert_eq!(stats.select_count, 2);
        assert_eq!(stats.total_duration, Duration::from_millis(30));
    }

    /// Failures, per-type counts and the slowest query are tracked, and `reset` clears them.
    #[test]
    fn test_stats_monitor_tracks_failures_slowest_and_reset() {
        let monitor = StatsMonitor::new();
        let select = QueryContext::new("SELECT * FROM users", 0);
        let delete = QueryContext::new("DELETE FROM users", 0);
        monitor.on_query_complete(&select, Duration::from_millis(10), &QueryResult::Rows(1));
        monitor.on_query_complete(
            &delete,
            Duration::from_millis(30),
            &QueryResult::Error("boom".to_string()),
        );
        let stats = monitor.stats();
        assert_eq!(stats.total_queries, 2);
        assert_eq!(stats.failed_queries, 1);
        assert_eq!(stats.select_count, 1);
        assert_eq!(stats.delete_count, 1);
        assert_eq!(stats.max_duration, Duration::from_millis(30));
        assert_eq!(stats.slowest_query.as_deref(), Some("DELETE FROM users"));

        monitor.reset();
        let stats = monitor.stats();
        assert_eq!(stats.total_queries, 0);
        assert_eq!(stats.max_duration, Duration::ZERO);
        assert!(stats.slowest_query.is_none());
    }

    #[test]
    fn test_monitor_config_builders() {
        let cfg = MonitorConfig::new();
        assert!(cfg.query_timeout.is_none());
        assert!(cfg.slow_query_threshold.is_none());
        assert!(!cfg.monitoring_enabled);

        let cfg = MonitorConfig::new()
            .with_query_timeout(Duration::from_secs(1))
            .with_slow_query_threshold(Duration::from_millis(5))
            .enable_monitoring();
        assert_eq!(cfg.query_timeout, Some(Duration::from_secs(1)));
        assert_eq!(cfg.slow_query_threshold, Some(Duration::from_millis(5)));
        assert!(cfg.monitoring_enabled);
        assert!(!cfg.disable_monitoring().monitoring_enabled);
    }

    #[test]
    fn test_composite_hook_modify() {
        struct AddCommentHook;
        impl QueryHook for AddCommentHook {
            fn before_query(&self, ctx: &QueryContext) -> HookAction {
                HookAction::ModifySql {
                    exec_sql: format!("/* instrumented */ {}", ctx.exec_sql),
                    canonical_sql: None,
                }
            }
        }
        let hook = CompositeHook::new().add(AddCommentHook);
        let ctx = QueryContext::new("SELECT 1", 0);
        match hook.before_query(&ctx) {
            HookAction::ModifySql {
                exec_sql,
                canonical_sql,
            } => {
                assert_eq!(exec_sql, "/* instrumented */ SELECT 1");
                assert!(canonical_sql.is_none());
            }
            _ => panic!("Expected ModifySql"),
        }
    }

    #[test]
    fn test_composite_hook_abort() {
        struct BlockDeleteHook;
        impl QueryHook for BlockDeleteHook {
            fn before_query(&self, ctx: &QueryContext) -> HookAction {
                if ctx.query_type == QueryType::Delete {
                    HookAction::Abort("DELETE not allowed".to_string())
                } else {
                    HookAction::Continue
                }
            }
        }
        let hook = CompositeHook::new().add(BlockDeleteHook);
        let ctx = QueryContext::new("DELETE FROM users", 0);
        match hook.before_query(&ctx) {
            HookAction::Abort(reason) => assert_eq!(reason, "DELETE not allowed"),
            _ => panic!("Expected Abort"),
        }
    }

    /// A later hook must see the query type of SQL rewritten by an earlier hook.
    #[test]
    fn test_composite_hook_reclassifies_after_rewrite() {
        struct RewriteToDeleteHook;
        impl QueryHook for RewriteToDeleteHook {
            fn before_query(&self, _ctx: &QueryContext) -> HookAction {
                HookAction::ModifySql {
                    exec_sql: "DELETE FROM users".to_string(),
                    canonical_sql: Some("DELETE FROM users".to_string()),
                }
            }
        }
        struct BlockDeleteHook;
        impl QueryHook for BlockDeleteHook {
            fn before_query(&self, ctx: &QueryContext) -> HookAction {
                if ctx.query_type == QueryType::Delete {
                    HookAction::Abort("DELETE not allowed".to_string())
                } else {
                    HookAction::Continue
                }
            }
        }
        let hook = CompositeHook::new()
            .add(RewriteToDeleteHook)
            .add(BlockDeleteHook);
        let ctx = QueryContext::new("SELECT 1", 0);
        match hook.before_query(&ctx) {
            HookAction::Abort(reason) => assert_eq!(reason, "DELETE not allowed"),
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn tagged_queries_propagate_to_monitor() {
        #[derive(Default)]
        struct TagCapture(std::sync::Mutex<Option<String>>);
        impl QueryMonitor for TagCapture {
            fn on_query_complete(&self, ctx: &QueryContext, _: Duration, _: &QueryResult) {
                *self.0.lock().unwrap() = ctx.tag.clone();
            }
        }
        struct DummyClient;
        impl GenericClient for DummyClient {
            async fn query(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Vec<Row>> {
                Ok(vec![])
            }
            async fn query_one(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
                Err(OrmError::not_found("no rows"))
            }
            async fn query_opt(
                &self,
                _: &str,
                _: &[&(dyn ToSql + Sync)],
            ) -> OrmResult<Option<Row>> {
                Ok(None)
            }
            async fn execute(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
                Ok(0)
            }
        }
        let capture = Arc::new(TagCapture::default());
        let client = InstrumentedClient::new(DummyClient)
            .with_config(MonitorConfig::new().enable_monitoring())
            .with_monitor_arc(capture.clone());
        client
            .query_tagged("test-tag", "SELECT 1", &[])
            .await
            .unwrap();
        assert_eq!(capture.0.lock().unwrap().as_deref(), Some("test-tag"));
    }

    #[tokio::test]
    async fn timeout_returns_error_and_attempts_cancellation() {
        struct HangingClient;
        impl GenericClient for HangingClient {
            async fn query(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Vec<Row>> {
                tokio::time::sleep(Duration::from_secs(60)).await;
                Ok(vec![])
            }
            async fn query_one(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
                Err(OrmError::not_found("unused"))
            }
            async fn query_opt(
                &self,
                _: &str,
                _: &[&(dyn ToSql + Sync)],
            ) -> OrmResult<Option<Row>> {
                Ok(None)
            }
            async fn execute(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
                Ok(0)
            }
        }
        let client = InstrumentedClient::new(HangingClient).with_config(
            MonitorConfig::new()
                .with_query_timeout(Duration::from_millis(10))
                .enable_monitoring(),
        );
        let err = client.query("SELECT pg_sleep(60)", &[]).await.unwrap_err();
        assert!(matches!(err, OrmError::Timeout(_)));
    }
}