use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use crate::schema::{ColumnInfo, ForeignKeyInfo, IndexInfo, TableInfo};
use crate::{Connection, OxiSqlError, PreparedStatement, Row, ToSqlValue, Transaction};
/// A [`Connection`] decorator that reports calls through the `log` crate.
///
/// Every call is forwarded to the wrapped connection; the wrapper only adds
/// log output. An optional prefix labels each log line so output from several
/// wrapped connections can be told apart.
pub struct LoggingConnection<C> {
    /// The wrapped connection that does the real work.
    inner: C,
    /// Label inserted at the start of each log line; empty means "no label".
    prefix: String,
}

impl<C: Connection> LoggingConnection<C> {
    /// Wraps `inner` with no log-line prefix.
    pub fn new(inner: C) -> Self {
        Self::with_prefix(inner, String::new())
    }

    /// Wraps `inner`, tagging every log line with `prefix`.
    pub fn with_prefix(inner: C, prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        LoggingConnection { inner, prefix }
    }

    /// Consumes the wrapper and hands back the underlying connection.
    pub fn into_inner(self) -> C {
        let LoggingConnection { inner, .. } = self;
        inner
    }

    /// The prefix configured for this wrapper (empty if none).
    pub fn prefix(&self) -> &str {
        self.prefix.as_str()
    }

    /// Renders the prefix for embedding in a log line: `"<prefix> "` when a
    /// prefix is set, otherwise an empty string.
    fn fmt_prefix(&self) -> String {
        match self.prefix.as_str() {
            "" => String::new(),
            label => format!("{label} "),
        }
    }
}
#[async_trait]
impl<C: Connection + Send + Sync> Connection for LoggingConnection<C> {
    /// Forwards to the inner connection and logs the affected-row count at
    /// `debug` (or the error at `warn`) with elapsed time and truncated SQL.
    async fn execute(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
        let t = Instant::now();
        let result = self.inner.execute(sql, params).await;
        let elapsed = t.elapsed();
        match &result {
            Ok(n) => log::debug!(
                "[{}execute] {} row(s) affected — {:.3}ms{}",
                self.fmt_prefix(),
                n,
                elapsed.as_secs_f64() * 1000.0,
                truncate_sql(sql),
            ),
            Err(e) => log::warn!(
                "[{}execute] ERROR {} — {:.3}ms{}",
                self.fmt_prefix(),
                e,
                elapsed.as_secs_f64() * 1000.0,
                truncate_sql(sql),
            ),
        }
        result
    }

    /// Forwards to the inner connection and logs the returned row count at
    /// `debug` (or the error at `warn`) with elapsed time and truncated SQL.
    async fn query(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<Vec<Row>, OxiSqlError> {
        let t = Instant::now();
        let result = self.inner.query(sql, params).await;
        let elapsed = t.elapsed();
        match &result {
            Ok(rows) => log::debug!(
                "[{}query] {} row(s) — {:.3}ms{}",
                self.fmt_prefix(),
                rows.len(),
                elapsed.as_secs_f64() * 1000.0,
                truncate_sql(sql),
            ),
            Err(e) => log::warn!(
                "[{}query] ERROR {} — {:.3}ms{}",
                self.fmt_prefix(),
                e,
                elapsed.as_secs_f64() * 1000.0,
                truncate_sql(sql),
            ),
        }
        result
    }

    /// Logs the start of a transaction, then forwards.
    async fn transaction(&self) -> Result<Box<dyn Transaction + '_>, OxiSqlError> {
        log::debug!("[{}transaction] BEGIN", self.fmt_prefix());
        self.inner.transaction().await
    }

    /// Forwards a batch and logs its timing. Failures are logged at `warn`,
    /// matching `execute`/`query`, so batch errors are not buried at `debug`.
    async fn execute_batch(&self, sql: &str) -> Result<u64, OxiSqlError> {
        let t = Instant::now();
        let result = self.inner.execute_batch(sql).await;
        let elapsed_ms = t.elapsed().as_secs_f64() * 1000.0;
        match &result {
            Ok(_) => log::debug!(
                "[{}execute_batch] {:.3}ms{}",
                self.fmt_prefix(),
                elapsed_ms,
                truncate_sql(sql),
            ),
            Err(e) => log::warn!(
                "[{}execute_batch] ERROR {} — {:.3}ms{}",
                self.fmt_prefix(),
                e,
                elapsed_ms,
                truncate_sql(sql),
            ),
        }
        result
    }

    /// Logs the ping, then forwards.
    async fn ping(&self) -> Result<(), OxiSqlError> {
        log::debug!("[{}ping]", self.fmt_prefix());
        self.inner.ping().await
    }

    /// Logs the (truncated) statement being prepared, then forwards.
    async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement + '_>, OxiSqlError> {
        log::debug!("[{}prepare]{}", self.fmt_prefix(), truncate_sql(sql));
        self.inner.prepare(sql).await
    }

    // Schema introspection is forwarded without logging.

    async fn tables(&self) -> Result<Vec<TableInfo>, OxiSqlError> {
        self.inner.tables().await
    }

    async fn columns(&self, table: &str) -> Result<Vec<ColumnInfo>, OxiSqlError> {
        self.inner.columns(table).await
    }

    async fn indexes(&self, table: &str) -> Result<Vec<IndexInfo>, OxiSqlError> {
        self.inner.indexes(table).await
    }

    async fn foreign_keys(&self, table: &str) -> Result<Vec<ForeignKeyInfo>, OxiSqlError> {
        self.inner.foreign_keys(table).await
    }
}
/// A [`Connection`] decorator that reports calls as `tracing` events.
///
/// Every call is forwarded to the wrapped connection; the wrapper only adds
/// events. An optional prefix labels each event message so output from
/// several wrapped connections can be told apart.
#[cfg(feature = "tracing")]
pub struct TracingConnection<C> {
    /// The wrapped connection that does the real work.
    inner: C,
    /// Label inserted at the start of each event message; empty means none.
    prefix: String,
}

#[cfg(feature = "tracing")]
impl<C: Connection> TracingConnection<C> {
    /// Wraps `inner` with no message prefix.
    pub fn new(inner: C) -> Self {
        Self::with_prefix(inner, String::new())
    }

    /// Wraps `inner`, tagging every event message with `prefix`.
    pub fn with_prefix(inner: C, prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        TracingConnection { inner, prefix }
    }

    /// Consumes the wrapper and hands back the underlying connection.
    pub fn into_inner(self) -> C {
        let TracingConnection { inner, .. } = self;
        inner
    }

    /// The prefix configured for this wrapper (empty if none).
    pub fn prefix(&self) -> &str {
        self.prefix.as_str()
    }

    /// Renders the prefix for embedding in a message: `"<prefix> "` when a
    /// prefix is set, otherwise an empty string.
    fn fmt_prefix(&self) -> String {
        match self.prefix.as_str() {
            "" => String::new(),
            label => format!("{label} "),
        }
    }
}
#[cfg(feature = "tracing")]
#[async_trait]
impl<C: Connection + Send + Sync> Connection for TracingConnection<C> {
    /// Forwards to the inner connection and emits the affected-row count at
    /// `debug` (or the error at `warn`) with elapsed time and truncated SQL.
    async fn execute(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
        let t = Instant::now();
        let result = self.inner.execute(sql, params).await;
        let elapsed = t.elapsed();
        match &result {
            Ok(n) => tracing::debug!(
                "[{}execute] {} row(s) affected — {:.3}ms{}",
                self.fmt_prefix(),
                n,
                elapsed.as_secs_f64() * 1000.0,
                truncate_sql(sql),
            ),
            Err(e) => tracing::warn!(
                "[{}execute] ERROR {} — {:.3}ms{}",
                self.fmt_prefix(),
                e,
                elapsed.as_secs_f64() * 1000.0,
                truncate_sql(sql),
            ),
        }
        result
    }

    /// Forwards to the inner connection and emits the returned row count at
    /// `debug` (or the error at `warn`) with elapsed time and truncated SQL.
    async fn query(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<Vec<Row>, OxiSqlError> {
        let t = Instant::now();
        let result = self.inner.query(sql, params).await;
        let elapsed = t.elapsed();
        match &result {
            Ok(rows) => tracing::debug!(
                "[{}query] {} row(s) — {:.3}ms{}",
                self.fmt_prefix(),
                rows.len(),
                elapsed.as_secs_f64() * 1000.0,
                truncate_sql(sql),
            ),
            Err(e) => tracing::warn!(
                "[{}query] ERROR {} — {:.3}ms{}",
                self.fmt_prefix(),
                e,
                elapsed.as_secs_f64() * 1000.0,
                truncate_sql(sql),
            ),
        }
        result
    }

    /// Emits a transaction-start event, then forwards.
    async fn transaction(&self) -> Result<Box<dyn Transaction + '_>, OxiSqlError> {
        tracing::debug!("[{}transaction] BEGIN", self.fmt_prefix());
        self.inner.transaction().await
    }

    /// Forwards a batch and emits its timing. Failures are emitted at `warn`,
    /// matching `execute`/`query`, so batch errors are not buried at `debug`.
    async fn execute_batch(&self, sql: &str) -> Result<u64, OxiSqlError> {
        let t = Instant::now();
        let result = self.inner.execute_batch(sql).await;
        let elapsed_ms = t.elapsed().as_secs_f64() * 1000.0;
        match &result {
            Ok(_) => tracing::debug!(
                "[{}execute_batch] {:.3}ms{}",
                self.fmt_prefix(),
                elapsed_ms,
                truncate_sql(sql),
            ),
            Err(e) => tracing::warn!(
                "[{}execute_batch] ERROR {} — {:.3}ms{}",
                self.fmt_prefix(),
                e,
                elapsed_ms,
                truncate_sql(sql),
            ),
        }
        result
    }

    /// Emits a ping event, then forwards.
    async fn ping(&self) -> Result<(), OxiSqlError> {
        tracing::debug!("[{}ping]", self.fmt_prefix());
        self.inner.ping().await
    }

    /// Emits the (truncated) statement being prepared, then forwards.
    async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement + '_>, OxiSqlError> {
        tracing::debug!("[{}prepare]{}", self.fmt_prefix(), truncate_sql(sql));
        self.inner.prepare(sql).await
    }

    // Schema introspection is forwarded without emitting events.

    async fn tables(&self) -> Result<Vec<TableInfo>, OxiSqlError> {
        self.inner.tables().await
    }

    async fn columns(&self, table: &str) -> Result<Vec<ColumnInfo>, OxiSqlError> {
        self.inner.columns(table).await
    }

    async fn indexes(&self, table: &str) -> Result<Vec<IndexInfo>, OxiSqlError> {
        self.inner.indexes(table).await
    }

    async fn foreign_keys(&self, table: &str) -> Result<Vec<ForeignKeyInfo>, OxiSqlError> {
        self.inner.foreign_keys(table).await
    }
}
#[cfg(all(test, feature = "tracing"))]
mod tracing_tests {
    use std::sync::{Arc, Mutex};

    use tracing_subscriber::fmt::MakeWriter;

    use super::TracingConnection;
    use crate::{Connection, OxiSqlError, Row, ToSqlValue, Transaction, Value};

    /// Cloneable in-memory byte sink; every clone appends to the same buffer.
    #[derive(Clone, Default)]
    struct MemWriter(Arc<Mutex<Vec<u8>>>);

    impl MemWriter {
        /// Everything written so far, decoded lossily as UTF-8.
        fn contents(&self) -> String {
            let bytes = self.0.lock().expect("lock");
            String::from_utf8_lossy(&bytes).into_owned()
        }
    }

    impl std::io::Write for MemWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut sink = self.0.lock().expect("lock");
            sink.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    impl<'a> MakeWriter<'a> for MemWriter {
        type Writer = MemWriter;

        /// Hands the subscriber another handle onto the shared buffer.
        fn make_writer(&'a self) -> Self::Writer {
            MemWriter(Arc::clone(&self.0))
        }
    }

    /// Installs a DEBUG-level, ANSI-free `fmt` subscriber as the scoped
    /// default, writing into a fresh [`MemWriter`]. Keep the returned guard
    /// bound for the whole test body so the subscriber stays installed.
    fn capture_debug_output() -> (MemWriter, tracing::subscriber::DefaultGuard) {
        let sink = MemWriter::default();
        let subscriber = tracing_subscriber::fmt()
            .with_writer(sink.clone())
            .with_ansi(false)
            .with_max_level(tracing::Level::DEBUG)
            .finish();
        let guard = tracing::subscriber::set_default(subscriber);
        (sink, guard)
    }

    /// Connection stub: `execute` reports one affected row, `query` returns a
    /// single row `n = 42`, and `transaction` always fails.
    struct MockConn;

    #[async_trait::async_trait]
    impl Connection for MockConn {
        async fn execute(
            &self,
            _sql: &str,
            _params: &[&dyn ToSqlValue],
        ) -> Result<u64, OxiSqlError> {
            Ok(1)
        }

        async fn query(
            &self,
            _sql: &str,
            _params: &[&dyn ToSqlValue],
        ) -> Result<Vec<Row>, OxiSqlError> {
            let row = Row::new(vec!["n".into()], vec![Value::I64(42)]);
            Ok(vec![row])
        }

        async fn transaction(&self) -> Result<Box<dyn Transaction + '_>, OxiSqlError> {
            Err(OxiSqlError::Other("no txn".into()))
        }
    }

    #[tokio::test]
    async fn tracing_execute_emits_event() {
        let (sink, _guard) = capture_debug_output();
        let conn = TracingConnection::new(MockConn);

        let affected = conn
            .execute("INSERT INTO t VALUES ($1)", &[&42i64])
            .await
            .expect("execute ok");
        assert_eq!(affected, 1);

        let output = sink.contents();
        assert!(
            output.contains("execute"),
            "expected 'execute' in: {output}"
        );
    }

    #[tokio::test]
    async fn tracing_query_emits_event() {
        let (sink, _guard) = capture_debug_output();
        let conn = TracingConnection::new(MockConn);

        let fetched = conn.query("SELECT n FROM t", &[]).await.expect("query ok");
        assert_eq!(fetched.len(), 1);

        let output = sink.contents();
        assert!(output.contains("query"), "expected 'query' in: {output}");
    }

    #[test]
    fn tracing_conn_prefix_and_accessors() {
        let conn = TracingConnection::with_prefix(MockConn, "mydb");
        assert_eq!(conn.prefix(), "mydb");
        let _inner: MockConn = conn.into_inner();
    }
}
/// Counters shared by one or more [`MetricsConnection`]s.
///
/// Each field is an independent [`AtomicU64`] accessed with
/// [`Ordering::Relaxed`], so a [`snapshot`](Self::snapshot) reads the fields
/// one at a time and is not an atomic view across all of them.
#[derive(Debug, Default)]
pub struct ConnectionMetrics {
    /// Number of `execute` calls recorded.
    pub executes: AtomicU64,
    /// Number of `query` calls recorded.
    pub queries: AtomicU64,
    /// Number of recorded calls that returned an error.
    pub errors: AtomicU64,
    /// Total time spent in recorded `execute` calls, in microseconds.
    pub execute_us: AtomicU64,
    /// Total time spent in recorded `query` calls, in microseconds.
    pub query_us: AtomicU64,
}

impl ConnectionMetrics {
    /// Copies the current counter values into a plain [`MetricsSnapshot`].
    pub fn snapshot(&self) -> MetricsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        MetricsSnapshot {
            executes: read(&self.executes),
            queries: read(&self.queries),
            errors: read(&self.errors),
            execute_us: read(&self.execute_us),
            query_us: read(&self.query_us),
        }
    }
}

/// Point-in-time copy of [`ConnectionMetrics`]; field meanings are identical.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MetricsSnapshot {
    /// Number of `execute` calls recorded.
    pub executes: u64,
    /// Number of `query` calls recorded.
    pub queries: u64,
    /// Number of recorded calls that returned an error.
    pub errors: u64,
    /// Total `execute` time in microseconds.
    pub execute_us: u64,
    /// Total `query` time in microseconds.
    pub query_us: u64,
}
/// A [`Connection`] decorator that records call counts and latencies for
/// `execute` and `query` into a shared [`ConnectionMetrics`].
pub struct MetricsConnection<C> {
    /// The wrapped connection.
    inner: C,
    /// Counters updated by this wrapper; the `Arc` may be shared elsewhere.
    metrics: Arc<ConnectionMetrics>,
}

impl<C: Connection> MetricsConnection<C> {
    /// Wraps `inner`, reporting into `metrics`. Handing clones of one `Arc`
    /// to several wrappers aggregates their counters.
    pub fn new(inner: C, metrics: Arc<ConnectionMetrics>) -> Self {
        MetricsConnection { inner, metrics }
    }

    /// Shared handle to the counters this wrapper updates.
    pub fn metrics(&self) -> &Arc<ConnectionMetrics> {
        let handle = &self.metrics;
        handle
    }

    /// Consumes the wrapper, dropping its metrics handle, and returns the
    /// inner connection.
    pub fn into_inner(self) -> C {
        let MetricsConnection { inner, .. } = self;
        inner
    }
}
#[async_trait]
impl<C: Connection + Send + Sync> Connection for MetricsConnection<C> {
    /// Forwards to the inner connection, then counts the call, adds its
    /// latency (whole microseconds) to `execute_us`, and bumps `errors` on
    /// failure.
    async fn execute(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
        let t = Instant::now();
        let result = self.inner.execute(sql, params).await;
        // `as_micros` is a u128: clamp to u64::MAX instead of letting `as`
        // silently drop the high bits.
        let us = t.elapsed().as_micros().min(u128::from(u64::MAX)) as u64;
        self.metrics.executes.fetch_add(1, Ordering::Relaxed);
        self.metrics.execute_us.fetch_add(us, Ordering::Relaxed);
        if result.is_err() {
            self.metrics.errors.fetch_add(1, Ordering::Relaxed);
        }
        result
    }

    /// Forwards to the inner connection, then counts the call, adds its
    /// latency (whole microseconds) to `query_us`, and bumps `errors` on
    /// failure.
    async fn query(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<Vec<Row>, OxiSqlError> {
        let t = Instant::now();
        let result = self.inner.query(sql, params).await;
        // Same saturating u128 -> u64 conversion as in `execute`.
        let us = t.elapsed().as_micros().min(u128::from(u64::MAX)) as u64;
        self.metrics.queries.fetch_add(1, Ordering::Relaxed);
        self.metrics.query_us.fetch_add(us, Ordering::Relaxed);
        if result.is_err() {
            self.metrics.errors.fetch_add(1, Ordering::Relaxed);
        }
        result
    }

    // The remaining methods are forwarded unchanged and are not recorded.

    async fn transaction(&self) -> Result<Box<dyn Transaction + '_>, OxiSqlError> {
        self.inner.transaction().await
    }

    async fn execute_batch(&self, sql: &str) -> Result<u64, OxiSqlError> {
        self.inner.execute_batch(sql).await
    }

    async fn ping(&self) -> Result<(), OxiSqlError> {
        self.inner.ping().await
    }

    async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement + '_>, OxiSqlError> {
        self.inner.prepare(sql).await
    }

    async fn tables(&self) -> Result<Vec<TableInfo>, OxiSqlError> {
        self.inner.tables().await
    }

    async fn columns(&self, table: &str) -> Result<Vec<ColumnInfo>, OxiSqlError> {
        self.inner.columns(table).await
    }

    async fn indexes(&self, table: &str) -> Result<Vec<IndexInfo>, OxiSqlError> {
        self.inner.indexes(table).await
    }

    async fn foreign_keys(&self, table: &str) -> Result<Vec<ForeignKeyInfo>, OxiSqlError> {
        self.inner.foreign_keys(table).await
    }
}
/// Decides whether a failed operation should be retried (`true` = retry).
///
/// Stored behind an `Arc` so that [`RetryPolicy`] can derive `Clone`.
pub type RetryPredicate = Arc<dyn Fn(&OxiSqlError) -> bool + Send + Sync>;

/// Exponential-backoff configuration used by [`RetryConnection`].
#[derive(Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries after the first failed attempt.
    pub max_retries: u32,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Multiplier applied to the delay for each successive retry.
    pub backoff_factor: f64,
    /// Upper bound on any single delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Classifies errors; only errors for which it returns `true` are retried.
    pub predicate: RetryPredicate,
}
/// Builds the default [`RetryPredicate`].
///
/// Retries `Timeout` errors, and `Execution` errors whose message contains a
/// marker of a transient network/server condition. The message check is
/// ASCII case-insensitive: OS-level error text (e.g. what `std::io::Error`
/// displays, such as "Connection reset by peer") is typically capitalised
/// and would otherwise never match. All other errors are not retried.
fn default_retry_predicate() -> RetryPredicate {
    // Lower-case substrings that mark an execution error as transient.
    const TRANSIENT_MARKERS: [&str; 5] = [
        "connection reset",
        "broken pipe",
        "connection refused",
        "timed out",
        "temporarily unavailable",
    ];
    Arc::new(|e: &OxiSqlError| match e {
        OxiSqlError::Timeout(_) => true,
        OxiSqlError::Execution(msg) => {
            let msg = msg.to_ascii_lowercase();
            TRANSIENT_MARKERS.iter().any(|&marker| msg.contains(marker))
        }
        _ => false,
    })
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay_ms: 100,
backoff_factor: 2.0,
max_delay_ms: 5_000,
predicate: default_retry_predicate(),
}
}
}
impl std::fmt::Debug for RetryPolicy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RetryPolicy")
.field("max_retries", &self.max_retries)
.field("initial_delay_ms", &self.initial_delay_ms)
.field("backoff_factor", &self.backoff_factor)
.field("max_delay_ms", &self.max_delay_ms)
.finish()
}
}
/// A [`Connection`] decorator that retries failed `execute`, `query` and
/// `ping` calls according to a [`RetryPolicy`], sleeping with exponential
/// backoff between attempts.
pub struct RetryConnection<C> {
    /// The wrapped connection.
    inner: C,
    /// Retry limits, backoff parameters and error classifier.
    policy: RetryPolicy,
}

impl<C: Connection> RetryConnection<C> {
    /// Wraps `inner` with the given retry `policy`.
    pub fn new(inner: C, policy: RetryPolicy) -> Self {
        Self { inner, policy }
    }

    /// Borrows the wrapped connection.
    pub fn inner(&self) -> &C {
        &self.inner
    }

    /// Consumes the wrapper and returns the wrapped connection.
    pub fn into_inner(self) -> C {
        self.inner
    }

    /// Delay in milliseconds to wait before retrying after failed attempt
    /// `attempt` (zero-based): `initial_delay_ms * backoff_factor^attempt`,
    /// capped at `max_delay_ms`.
    pub(crate) fn delay_ms(&self, attempt: u32) -> u64 {
        // `powi` takes an i32. Clamp rather than `as`-cast, which would wrap
        // very large attempt numbers into negative exponents (tiny delays).
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let delay = self.policy.initial_delay_ms as f64 * self.policy.backoff_factor.powi(exponent);
        // Float-to-int `as` saturates (NaN -> 0, +inf -> u64::MAX), so this
        // is well-defined before the cap is applied.
        (delay as u64).min(self.policy.max_delay_ms)
    }
}
#[async_trait]
impl<C: Connection + Send + Sync> Connection for RetryConnection<C> {
async fn execute(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
let mut last_err: Option<OxiSqlError> = None;
for attempt in 0..=self.policy.max_retries {
match self.inner.execute(sql, params).await {
Ok(n) => return Ok(n),
Err(e) => {
if attempt < self.policy.max_retries && (self.policy.predicate)(&e) {
let delay = self.delay_ms(attempt);
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
last_err = Some(e);
} else {
return Err(e);
}
}
}
}
Err(last_err.unwrap_or_else(|| OxiSqlError::Other("retry exhausted".into())))
}
async fn query(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<Vec<Row>, OxiSqlError> {
let mut last_err: Option<OxiSqlError> = None;
for attempt in 0..=self.policy.max_retries {
match self.inner.query(sql, params).await {
Ok(rows) => return Ok(rows),
Err(e) => {
if attempt < self.policy.max_retries && (self.policy.predicate)(&e) {
let delay = self.delay_ms(attempt);
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
last_err = Some(e);
} else {
return Err(e);
}
}
}
}
Err(last_err.unwrap_or_else(|| OxiSqlError::Other("retry exhausted".into())))
}
async fn transaction(&self) -> Result<Box<dyn crate::traits::Transaction + '_>, OxiSqlError> {
self.inner.transaction().await
}
async fn execute_batch(&self, sql: &str) -> Result<u64, OxiSqlError> {
self.inner.execute_batch(sql).await
}
async fn ping(&self) -> Result<(), OxiSqlError> {
let mut last_err: Option<OxiSqlError> = None;
for attempt in 0..=self.policy.max_retries {
match self.inner.ping().await {
Ok(()) => return Ok(()),
Err(e) => {
if attempt < self.policy.max_retries && (self.policy.predicate)(&e) {
let delay = self.delay_ms(attempt);
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
last_err = Some(e);
} else {
return Err(e);
}
}
}
}
Err(last_err.unwrap_or_else(|| OxiSqlError::Other("retry exhausted".into())))
}
async fn prepare(
&self,
sql: &str,
) -> Result<Box<dyn crate::PreparedStatement + '_>, OxiSqlError> {
self.inner.prepare(sql).await
}
async fn tables(&self) -> Result<Vec<TableInfo>, OxiSqlError> {
self.inner.tables().await
}
async fn columns(&self, table: &str) -> Result<Vec<ColumnInfo>, OxiSqlError> {
self.inner.columns(table).await
}
async fn indexes(&self, table: &str) -> Result<Vec<IndexInfo>, OxiSqlError> {
self.inner.indexes(table).await
}
async fn foreign_keys(&self, table: &str) -> Result<Vec<ForeignKeyInfo>, OxiSqlError> {
self.inner.foreign_keys(table).await
}
}
/// Formats `sql` for a log line as `" | <sql>"`.
///
/// Surrounding whitespace is trimmed. SQL longer than 80 characters (Unicode
/// scalar values, not bytes) is cut to its first 80 characters and suffixed
/// with `…`. Shorter SQL is shown whole with no ellipsis.
fn truncate_sql(sql: &str) -> String {
    const MAX_CHARS: usize = 80;
    let trimmed = sql.trim();
    // Decide on character count, not `len()` (bytes). Otherwise a short but
    // multi-byte string is printed whole yet still marked as truncated.
    match trimmed.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!(" | {}…", &trimmed[..cut]),
        None => format!(" | {trimmed}"),
    }
}