use tracing::field::{Field, Visit};
use tracing::subscriber::Interest;
use tracing::{Event, Metadata, Subscriber};
use tracing_subscriber::layer::{Context, Filter};
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::Layer;
use crate::instant::Instant;
use crate::lib_on::sql::{init_sql_state, send_sql_event, SqlEvent};
/// `tracing` target string that sqlx's query-logging events are matched against.
///
/// `SqlxQueryFilter` admits only metadata whose target is exactly this string, so it
/// decides which events ever reach `HotpathSqlLayer`.
const SQLX_QUERY_TARGET: &str = "sqlx::query";
/// `tracing_subscriber` layer that converts sqlx query events into hotpath
/// `SqlEvent::Executed` messages.
pub(crate) struct HotpathSqlLayer;

impl<S> Layer<S> for HotpathSqlLayer
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    /// Extracts the SQL text and query duration from `event` and forwards them via
    /// `send_sql_event`.
    ///
    /// The full `db.statement` text is preferred; `summary` is the fallback. Events
    /// carrying neither are ignored. A missing duration is reported as `0`.
    fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
        let mut fields = QueryVisitor::default();
        event.record(&mut fields);

        if let Some(query_text) = fields.statement.or(fields.summary) {
            // Timestamp taken at the moment the event is observed by this layer.
            let observed_at = Instant::now();
            let executed = SqlEvent::Executed {
                sql: query_text.into(),
                duration_nanos: fields.elapsed_ns.unwrap_or(0),
                elapsed_ns: crate::lib_on::elapsed_since_start_ns(observed_at),
            };
            send_sql_event(executed);
        }
    }
}
/// Field visitor that collects the parts of an sqlx query event the hotpath layer needs.
#[derive(Default)]
struct QueryVisitor {
    /// Trimmed `db.statement` field; stays `None` if absent, empty, or whitespace-only.
    statement: Option<String>,
    /// Trimmed `summary` field; stays `None` if absent, empty, or whitespace-only.
    summary: Option<String>,
    /// Query duration in nanoseconds.
    ///
    /// `elapsed_secs` always overwrites this value, while `elapsed` only fills it when
    /// it is still empty. The numeric field therefore wins regardless of visit order.
    elapsed_ns: Option<u64>,
}

impl Visit for QueryVisitor {
    fn record_str(&mut self, field: &Field, value: &str) {
        let slot = match field.name() {
            "db.statement" => &mut self.statement,
            "summary" => &mut self.summary,
            _ => return,
        };
        // Both text fields follow the same rule: a blank value never replaces
        // (or creates) an entry, so an empty string is never forwarded as SQL.
        let trimmed = value.trim();
        if !trimmed.is_empty() {
            *slot = Some(trimmed.to_owned());
        }
    }

    fn record_f64(&mut self, field: &Field, value: f64) {
        if field.name() == "elapsed_secs" {
            // Round rather than truncate: `secs * 1e9` frequently lands just below the
            // true integer (e.g. 999_999.999…), which a bare `as` cast would floor.
            // The float→int `as` cast saturates, so negative values and NaN become 0.
            self.elapsed_ns = Some((value * 1e9).round() as u64);
        }
    }

    fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
        // `elapsed` arrives as a `Debug`-formatted duration (e.g. "1.5ms"). It is only
        // a fallback for when `elapsed_secs` has not been recorded.
        if field.name() == "elapsed" && self.elapsed_ns.is_none() {
            self.elapsed_ns = parse_duration_debug(&format!("{value:?}"));
        }
    }
}
/// Parses a `std::time::Duration` `Debug` rendering (e.g. `"1.5ms"`, `"250µs"`, `"3s"`)
/// into whole nanoseconds.
///
/// Recognised unit suffixes are `ns`, `µs` (or ASCII `us`), `ms` and `s`. Surrounding
/// whitespace, and whitespace between the number and the unit, are ignored.
///
/// Returns `None` when the suffix is not recognised or the numeric part does not parse
/// as `f64`. Fractional nanoseconds are rounded to the nearest integer. Negative and NaN
/// values saturate to `0`; infinite values saturate to `u64::MAX`.
fn parse_duration_debug(s: &str) -> Option<u64> {
    // Every unit ends in 's', so the bare-seconds suffix must be tried last.
    const UNITS: [(&str, f64); 5] = [
        ("ns", 1.0),
        ("µs", 1_000.0),
        ("us", 1_000.0),
        ("ms", 1_000_000.0),
        ("s", 1_000_000_000.0),
    ];

    let s = s.trim();
    let (num, scale) = UNITS
        .iter()
        .find_map(|&(suffix, scale)| s.strip_suffix(suffix).map(|rest| (rest, scale)))?;
    let nanos = num.trim().parse::<f64>().ok()? * scale;
    // Round rather than truncate: e.g. "1.005ms" multiplies out to 1_004_999.99999…,
    // which a bare `as` cast would floor to 1_004_999.
    Some(nanos.round() as u64)
}
/// Per-layer filter that admits only metadata whose target is exactly
/// `SQLX_QUERY_TARGET`.
struct SqlxQueryFilter;

impl<S> Filter<S> for SqlxQueryFilter {
    fn enabled(&self, meta: &Metadata<'_>, _ctx: &Context<'_, S>) -> bool {
        SQLX_QUERY_TARGET == meta.target()
    }

    /// The verdict depends only on the callsite's target, so it is returned as a
    /// fixed `always` / `never` interest.
    fn callsite_enabled(&self, meta: &Metadata<'_>) -> Interest {
        (SQLX_QUERY_TARGET == meta.target())
            .then(Interest::always)
            .unwrap_or_else(Interest::never)
    }
}
/// Builds the `tracing` layer that forwards sqlx query events to hotpath.
///
/// Calls `init_sql_state()` before returning. It then returns a `HotpathSqlLayer`
/// wrapped with `SqlxQueryFilter`, so only events whose target is `sqlx::query`
/// reach the layer.
///
/// NOTE(review): every call runs `init_sql_state()` again; confirm it is idempotent
/// if this constructor may be invoked more than once.
pub fn sqlx_tracing_layer<S>() -> impl Layer<S>
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    init_sql_state();
    let target_filter = SqlxQueryFilter;
    HotpathSqlLayer.with_filter(target_filter)
}