use std::borrow::Cow;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use futures::Stream;
use futures::stream::BoxStream;
use opentelemetry::trace::{SpanKind, Status, TraceContextExt, Tracer};
use opentelemetry::{Context as OtelContext, KeyValue};
use opentelemetry_semantic_conventions::attribute;
use crate::annotations::QueryAnnotations;
use crate::attributes::{self, ConnectionAttributes, QueryTextMode};
use crate::database::Database;
use crate::metrics::Metrics;
fn build_attributes(
attrs: &ConnectionAttributes,
sql: Option<&str>,
annotations: Option<&QueryAnnotations>,
) -> Vec<KeyValue> {
let mut kv = attrs.base_key_values();
if let Some(ann) = annotations {
if let Some(ref op) = ann.operation {
kv.push(KeyValue::new(attribute::DB_OPERATION_NAME, op.clone()));
}
if let Some(ref coll) = ann.collection {
kv.push(KeyValue::new(attribute::DB_COLLECTION_NAME, coll.clone()));
}
if let Some(ref summary) = ann.query_summary {
kv.push(KeyValue::new(attribute::DB_QUERY_SUMMARY, summary.clone()));
}
if let Some(ref sp) = ann.stored_procedure {
kv.push(KeyValue::new(
attribute::DB_STORED_PROCEDURE_NAME,
sp.clone(),
));
}
}
if let Some(sql) = sql {
match attrs.query_text_mode {
QueryTextMode::Full => {
kv.push(KeyValue::new(attribute::DB_QUERY_TEXT, sql.to_owned()));
}
QueryTextMode::Obfuscated => {
kv.push(KeyValue::new(
attribute::DB_QUERY_TEXT,
crate::obfuscate::obfuscate(sql),
));
}
QueryTextMode::Off => {}
}
}
kv
}
/// Starts a `Client`-kind span called `name` with `span_attrs` attached.
///
/// The span comes from the global tracer registered under `"sqlx-otel"`.
/// Returns a context derived from the current one that holds the new span,
/// together with the `Instant` taken right afterwards, which callers use as
/// the start of the duration measurement.
fn start_span(name: &str, span_attrs: Vec<KeyValue>) -> (OtelContext, Instant) {
    let tracer = opentelemetry::global::tracer("sqlx-otel");
    let builder = tracer
        .span_builder(name.to_owned())
        .with_kind(SpanKind::Client)
        .with_attributes(span_attrs);
    let context = OtelContext::current_with_span(builder.start(&tracer));
    let started_at = Instant::now();
    (context, started_at)
}
/// Opens the span for one query and prepares its metric attributes.
///
/// The span name is derived by `attributes::span_name` from the database
/// system plus any operation, collection, and query-summary annotations. Span
/// attributes come from `build_attributes`. Metric attributes are only the
/// connection-level base attributes.
///
/// Returns `(span context, start instant, metric attributes)`.
fn begin_query_span(
    attrs: &ConnectionAttributes,
    sql: Option<&str>,
    annotations: Option<&QueryAnnotations>,
) -> (OtelContext, Instant, Vec<KeyValue>) {
    let operation = annotations.and_then(|a| a.operation.as_deref());
    let collection = annotations.and_then(|a| a.collection.as_deref());
    let summary = annotations.and_then(|a| a.query_summary.as_deref());
    let span_name = attributes::span_name(attrs.system, operation, collection, summary);

    let span_kv = build_attributes(attrs, sql, annotations);
    let metric_kv = attrs.base_key_values();
    let (context, started_at) = start_span(&span_name, span_kv);
    (context, started_at, metric_kv)
}
fn error_type(err: &sqlx::Error) -> &'static str {
match err {
sqlx::Error::Configuration(_) => "Configuration",
sqlx::Error::Database(_) => "Database",
sqlx::Error::Io(_) => "Io",
sqlx::Error::Tls(_) => "Tls",
sqlx::Error::Protocol(_) => "Protocol",
sqlx::Error::RowNotFound => "RowNotFound",
sqlx::Error::TypeNotFound { .. } => "TypeNotFound",
sqlx::Error::ColumnIndexOutOfBounds { .. } => "ColumnIndexOutOfBounds",
sqlx::Error::ColumnNotFound(_) => "ColumnNotFound",
sqlx::Error::ColumnDecode { .. } => "ColumnDecode",
sqlx::Error::Decode(_) => "Decode",
sqlx::Error::AnyDriverError(_) => "AnyDriverError",
sqlx::Error::PoolTimedOut => "PoolTimedOut",
sqlx::Error::PoolClosed => "PoolClosed",
sqlx::Error::WorkerCrashed => "WorkerCrashed",
sqlx::Error::Migrate(_) => "Migrate",
_ => "Unknown",
}
}
/// Marks the span held by `cx` as failed because of `err`.
///
/// Sets the span status to `Error` with the error's `Display` text and sets
/// `error.type` to the name from `error_type`. When the error came from the
/// database and carries a code, it also sets `db.response.status_code`.
/// Finally it adds an `exception` event with `exception.type` and
/// `exception.message`.
fn record_error(cx: &OtelContext, err: &sqlx::Error) {
    // Format and classify once: each value is used twice below, and
    // `Display` on driver errors can be non-trivial.
    let message = err.to_string();
    let kind = error_type(err);

    let span = cx.span();
    span.set_status(Status::Error {
        description: Cow::Owned(message.clone()),
    });
    span.set_attribute(KeyValue::new(attribute::ERROR_TYPE, kind));
    if let sqlx::Error::Database(db_err) = err {
        if let Some(code) = db_err.code() {
            span.set_attribute(KeyValue::new(
                attribute::DB_RESPONSE_STATUS_CODE,
                code.into_owned(),
            ));
        }
    }
    span.add_event(
        "exception",
        vec![
            KeyValue::new("exception.type", kind),
            KeyValue::new("exception.message", message),
        ],
    );
}
/// Converts a row count into the `i64` used for span attributes, clamping
/// counts above `i64::MAX` to `i64::MAX`.
fn saturating_i64(count: u64) -> i64 {
    i64::try_from(count).unwrap_or(i64::MAX)
}

/// Sets `db.response.returned_rows` on the span held by `cx`.
fn record_rows(cx: &OtelContext, rows: u64) {
    let returned = KeyValue::new(attribute::DB_RESPONSE_RETURNED_ROWS, saturating_i64(rows));
    cx.span().set_attribute(returned);
}

/// Sets the custom `db.response.affected_rows` attribute on the span held by `cx`.
fn record_affected_rows(cx: &OtelContext, rows: u64) {
    let affected = KeyValue::new("db.response.affected_rows", saturating_i64(rows));
    cx.span().set_attribute(affected);
}
/// Ends the span held by `cx`, then passes the time elapsed since `start`,
/// the optional row count, and `attrs` to `Metrics::record`.
///
/// The elapsed time is taken after the span has been ended.
fn finish(
    cx: &OtelContext,
    start: Instant,
    rows: Option<u64>,
    metrics: &Metrics,
    attrs: &[KeyValue],
) {
    let span = cx.span();
    span.end();
    let elapsed = start.elapsed();
    metrics.record(elapsed, rows, attrs);
}
/// Awaits `fut` and finalizes the span held by `cx` with its outcome.
///
/// On error the failure is recorded on the span first. In both cases the span
/// is ended and metrics are recorded without a row count (`None`). The
/// result is returned unchanged.
///
/// NOTE(review): `finish` only runs after `fut` resolves. If this future is
/// dropped before then, no metric is recorded for the call.
async fn execute_instrumented<T>(
    fut: futures::future::BoxFuture<'_, Result<T, sqlx::Error>>,
    cx: OtelContext,
    start: Instant,
    metrics: std::sync::Arc<Metrics>,
    metric_attrs: Vec<KeyValue>,
) -> Result<T, sqlx::Error> {
    match fut.await {
        Ok(value) => {
            finish(&cx, start, None, &metrics, &metric_attrs);
            Ok(value)
        }
        Err(failure) => {
            record_error(&cx, &failure);
            finish(&cx, start, None, &metrics, &metric_attrs);
            Err(failure)
        }
    }
}
/// Strategy deciding how many result rows one successful stream item of
/// type `T` represents. Implementors are unit marker types chosen as the `C`
/// parameter of `InstrumentedStream`.
trait RowCounter<T> {
    /// Number of rows `item` adds to the running total.
    fn count(item: &T) -> u64;
}

/// Counting strategy in which every item is exactly one row.
struct CountAll;
impl<T> RowCounter<T> for CountAll {
    fn count(_row: &T) -> u64 {
        1
    }
}
/// Counting strategy for `Either` items that counts only the `Right` side.
/// In `fetch_many` the left side is a query result and the right side a row.
struct CountRight;
impl<L, R> RowCounter<sqlx::Either<L, R>> for CountRight {
    fn count(item: &sqlx::Either<L, R>) -> u64 {
        match item {
            sqlx::Either::Left(_) => 0,
            sqlx::Either::Right(_) => 1,
        }
    }
}
/// Counting strategy that never counts anything. It is used for
/// `execute_many`, whose items are query results rather than rows.
struct CountNone;
impl<T> RowCounter<T> for CountNone {
    fn count(_: &T) -> u64 {
        0
    }
}
/// Stream adapter that keeps a query span open while its result stream is
/// consumed.
///
/// Each `Ok` item adds `C::count(&item)` to a running row total. Each `Err`
/// item is recorded on the span and passed through unchanged; the adapter
/// does not stop on it. The span is ended and metrics are recorded exactly
/// once, either when the inner stream yields `None` or when the adapter is
/// dropped, whichever happens first.
struct InstrumentedStream<S, C> {
    // Wrapped stream of query results.
    inner: S,
    // Context holding the span for this query.
    cx: OtelContext,
    // Start instant handed to `finish` for the duration measurement.
    start: Instant,
    // Running total of rows counted via `C`.
    rows: u64,
    metrics: std::sync::Arc<Metrics>,
    // Attributes handed to `finish` for the metric recording.
    metric_attrs: Vec<KeyValue>,
    // Set by `complete`. It prevents double finalization and, while the
    // adapter is being polled, also means the inner stream returned `None`.
    finished: bool,
    // Selects the row-counting strategy at the type level; stores nothing.
    _counter: std::marker::PhantomData<C>,
}

impl<S, C> InstrumentedStream<S, C> {
    /// Wraps `inner`. `cx` and `start` are expected to come from the span
    /// already started for this query.
    fn new(
        inner: S,
        cx: OtelContext,
        start: Instant,
        metrics: std::sync::Arc<Metrics>,
        metric_attrs: Vec<KeyValue>,
    ) -> Self {
        Self {
            inner,
            cx,
            start,
            rows: 0,
            metrics,
            metric_attrs,
            finished: false,
            _counter: std::marker::PhantomData,
        }
    }

    /// Records the row total on the span, then ends the span and records
    /// metrics. Only the first call has any effect.
    fn complete(&mut self) {
        if !self.finished {
            self.finished = true;
            record_rows(&self.cx, self.rows);
            finish(
                &self.cx,
                self.start,
                Some(self.rows),
                &self.metrics,
                &self.metric_attrs,
            );
        }
    }
}

// Implemented by hand so that the marker `C`, which is only held in
// `PhantomData`, need not be `Unpin`. No field is pin-projected: `inner` is
// re-pinned with `Pin::new`, which the `Stream` impl permits by requiring
// `S: Unpin`.
impl<S: Unpin, C> Unpin for InstrumentedStream<S, C> {}

impl<S, T, C> Stream for InstrumentedStream<S, C>
where
    S: Stream<Item = Result<T, sqlx::Error>> + Unpin,
    C: RowCounter<T>,
{
    type Item = Result<T, sqlx::Error>;

    fn poll_next(mut self: Pin<&mut Self>, task_cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Once finalized, stay exhausted instead of polling `inner` again.
        // The `Stream` contract allows a stream polled after returning `None`
        // to panic or misbehave.
        if self.finished {
            return Poll::Ready(None);
        }
        match Pin::new(&mut self.inner).poll_next(task_cx) {
            Poll::Ready(Some(Ok(item))) => {
                self.rows += C::count(&item);
                Poll::Ready(Some(Ok(item)))
            }
            Poll::Ready(Some(Err(err))) => {
                record_error(&self.cx, &err);
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(None) => {
                self.complete();
                Poll::Ready(None)
            }
            Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The adapter yields exactly what `inner` yields until it is finished.
        if self.finished {
            (0, Some(0))
        } else {
            self.inner.size_hint()
        }
    }
}

impl<S, C> Drop for InstrumentedStream<S, C> {
    // Finalizes the span when the consumer drops the stream before it is
    // exhausted. This is a no-op if `complete` already ran.
    fn drop(&mut self) {
        self.complete();
    }
}
// Generates an instrumented `sqlx::Executor<'c>` impl for a wrapper type.
//
// Invocation forms:
//   impl_executor!(Type, self => inner_expr)
//   impl_executor!(Type, self => inner_expr, annotations: ann_expr)
// `$self_` names the receiver. `$inner` is evaluated to get the wrapped sqlx
// executor that each method delegates to. `$ann` is passed to
// `begin_query_span` as its `Option<&QueryAnnotations>` argument and is
// `None` in the first form. Every generated method reads
// `$self_.state.attrs` and `$self_.state.metrics`.
//
// Each generated method starts its span when the method is called, before
// the returned future or stream is polled. It then delegates to `$inner`,
// records rows or errors, and finalizes the span and metrics via `finish`
// (directly, through `execute_instrumented`, or through `InstrumentedStream`).
macro_rules! impl_executor {
// Form 1: no annotations.
($ty:ty, $self_:ident => $inner:expr) => {
impl_executor!(@impl $ty, $self_ => $inner, None);
};
// Form 2: explicit annotations expression.
($ty:ty, $self_:ident => $inner:expr, annotations: $ann:expr) => {
impl_executor!(@impl $ty, $self_ => $inner, $ann);
};
// Internal arm that emits the impl itself.
(@impl $ty:ty, $self_:ident => $inner:expr, $ann:expr) => {
impl<'c, DB> sqlx::Executor<'c> for $ty
where
DB: Database,
for<'a> &'a mut DB::Connection: sqlx::Executor<'a, Database = DB>,
{
type Database = DB;
// On success, records `DB::rows_affected` as `db.response.affected_rows`.
// Metrics get no row count.
// NOTE(review): `finish` runs only inside the returned future, after the
// inner future resolves. If the caller drops it early, no metric is recorded
// for this call; the stream-returning methods instead finalize on drop.
fn execute<'e, 'q: 'e, E>(
$self_,
query: E,
) -> futures::future::BoxFuture<
'e,
Result<<DB as sqlx::Database>::QueryResult, sqlx::Error>,
>
where
E: 'q + sqlx::Execute<'q, DB>,
'c: 'e,
{
// SQL is copied before `query` is moved into the inner call.
let sql = query.sql().to_owned();
let state = $self_.state.clone();
let (cx, start, metric_attrs) =
begin_query_span(&state.attrs, Some(&sql), $ann);
let fut = ($inner).execute(query);
Box::pin(async move {
let result = fut.await;
match &result {
Ok(qr) => {
record_affected_rows(&cx, DB::rows_affected(qr));
}
Err(err) => {
record_error(&cx, err);
}
}
finish(&cx, start, None, &state.metrics, &metric_attrs);
result
})
}
// Items are query results, so `CountNone` keeps the row total at 0.
// NOTE(review): unlike `execute`, this reports `Some(0)` rows to metrics and
// sets `db.response.returned_rows = 0` on the span. Confirm this is intended.
fn execute_many<'e, 'q: 'e, E>(
$self_,
query: E,
) -> BoxStream<'e, Result<<DB as sqlx::Database>::QueryResult, sqlx::Error>>
where
E: 'q + sqlx::Execute<'q, DB>,
'c: 'e,
{
let sql = query.sql().to_owned();
let state = $self_.state.clone();
let (cx, start, metric_attrs) =
begin_query_span(&state.attrs, Some(&sql), $ann);
let stream = ($inner).execute_many(query);
Box::pin(InstrumentedStream::<_, CountNone>::new(
stream,
cx,
start,
state.metrics,
metric_attrs,
))
}
// Every yielded item is a row (`CountAll`).
fn fetch<'e, 'q: 'e, E>(
$self_,
query: E,
) -> BoxStream<'e, Result<<DB as sqlx::Database>::Row, sqlx::Error>>
where
E: 'q + sqlx::Execute<'q, DB>,
'c: 'e,
{
let sql = query.sql().to_owned();
let state = $self_.state.clone();
let (cx, start, metric_attrs) =
begin_query_span(&state.attrs, Some(&sql), $ann);
let stream = ($inner).fetch(query);
Box::pin(InstrumentedStream::<_, CountAll>::new(
stream,
cx,
start,
state.metrics,
metric_attrs,
))
}
// Mixed stream of query results (`Left`) and rows (`Right`). Only rows are
// counted (`CountRight`).
fn fetch_many<'e, 'q: 'e, E>(
$self_,
query: E,
) -> BoxStream<
'e,
Result<
sqlx::Either<
<DB as sqlx::Database>::QueryResult,
<DB as sqlx::Database>::Row,
>,
sqlx::Error,
>,
>
where
E: 'q + sqlx::Execute<'q, DB>,
'c: 'e,
{
let sql = query.sql().to_owned();
let state = $self_.state.clone();
let (cx, start, metric_attrs) =
begin_query_span(&state.attrs, Some(&sql), $ann);
let stream = ($inner).fetch_many(query);
Box::pin(InstrumentedStream::<_, CountRight>::new(
stream,
cx,
start,
state.metrics,
metric_attrs,
))
}
// On success, the row count is the length of the returned `Vec`.
fn fetch_all<'e, 'q: 'e, E>(
$self_,
query: E,
) -> futures::future::BoxFuture<
'e,
Result<Vec<<DB as sqlx::Database>::Row>, sqlx::Error>,
>
where
E: 'q + sqlx::Execute<'q, DB>,
'c: 'e,
{
let sql = query.sql().to_owned();
let state = $self_.state.clone();
let (cx, start, metric_attrs) =
begin_query_span(&state.attrs, Some(&sql), $ann);
let fut = ($inner).fetch_all(query);
Box::pin(async move {
let result = fut.await;
match &result {
Ok(rows) => {
let count = rows.len() as u64;
record_rows(&cx, count);
finish(&cx, start, Some(count), &state.metrics, &metric_attrs);
}
Err(err) => {
record_error(&cx, err);
finish(&cx, start, None, &state.metrics, &metric_attrs);
}
}
result
})
}
// Success means exactly one row. Any error, including the one returned
// when no row matches, is recorded on the span.
fn fetch_one<'e, 'q: 'e, E>(
$self_,
query: E,
) -> futures::future::BoxFuture<
'e,
Result<<DB as sqlx::Database>::Row, sqlx::Error>,
>
where
E: 'q + sqlx::Execute<'q, DB>,
'c: 'e,
{
let sql = query.sql().to_owned();
let state = $self_.state.clone();
let (cx, start, metric_attrs) =
begin_query_span(&state.attrs, Some(&sql), $ann);
let fut = ($inner).fetch_one(query);
Box::pin(async move {
let result = fut.await;
match &result {
Ok(_) => {
record_rows(&cx, 1);
finish(&cx, start, Some(1), &state.metrics, &metric_attrs);
}
Err(err) => {
record_error(&cx, err);
finish(&cx, start, None, &state.metrics, &metric_attrs);
}
}
result
})
}
// On success, the row count is 0 or 1 depending on whether a row came back.
fn fetch_optional<'e, 'q: 'e, E>(
$self_,
query: E,
) -> futures::future::BoxFuture<
'e,
Result<Option<<DB as sqlx::Database>::Row>, sqlx::Error>,
>
where
E: 'q + sqlx::Execute<'q, DB>,
'c: 'e,
{
let sql = query.sql().to_owned();
let state = $self_.state.clone();
let (cx, start, metric_attrs) =
begin_query_span(&state.attrs, Some(&sql), $ann);
let fut = ($inner).fetch_optional(query);
Box::pin(async move {
let result = fut.await;
match &result {
Ok(maybe_row) => {
let count = u64::from(maybe_row.is_some());
record_rows(&cx, count);
finish(&cx, start, Some(count), &state.metrics, &metric_attrs);
}
Err(err) => {
record_error(&cx, err);
finish(&cx, start, None, &state.metrics, &metric_attrs);
}
}
result
})
}
// `prepare`, `prepare_with` and `describe` get a span like any query, but
// go through `execute_instrumented`, which records no row count.
fn prepare<'e, 'q: 'e>(
$self_,
query: &'q str,
) -> futures::future::BoxFuture<
'e,
Result<<DB as sqlx::Database>::Statement<'q>, sqlx::Error>,
>
where
'c: 'e,
{
let state = $self_.state.clone();
let (cx, start, metric_attrs) = begin_query_span(&state.attrs, Some(query), $ann);
let fut = ($inner).prepare(query);
Box::pin(execute_instrumented(
fut, cx, start, state.metrics, metric_attrs,
))
}
fn prepare_with<'e, 'q: 'e>(
$self_,
sql: &'q str,
parameters: &'e [<DB as sqlx::Database>::TypeInfo],
) -> futures::future::BoxFuture<
'e,
Result<<DB as sqlx::Database>::Statement<'q>, sqlx::Error>,
>
where
'c: 'e,
{
let state = $self_.state.clone();
let (cx, start, metric_attrs) = begin_query_span(&state.attrs, Some(sql), $ann);
let fut = ($inner).prepare_with(sql, parameters);
Box::pin(execute_instrumented(
fut, cx, start, state.metrics, metric_attrs,
))
}
#[doc(hidden)]
fn describe<'e, 'q: 'e>(
$self_,
sql: &'q str,
) -> futures::future::BoxFuture<
'e,
Result<sqlx::Describe<DB>, sqlx::Error>,
>
where
'c: 'e,
{
let state = $self_.state.clone();
let (cx, start, metric_attrs) = begin_query_span(&state.attrs, Some(sql), $ann);
let fut = ($inner).describe(sql);
Box::pin(execute_instrumented(
fut, cx, start, state.metrics, metric_attrs,
))
}
}
};
}
// Each wrapper type gets two executor impls. The plain form passes no
// annotations. The annotated form forwards `self.annotations`, so
// operation, collection, summary and stored-procedure attributes are added
// to the span.

// Pool: delegates to `self.inner`.
impl_executor!(&'_ crate::Pool<DB>, self => &self.inner);
impl_executor!(
    crate::annotations::Annotated<'c, crate::Pool<DB>>,
    self => &self.inner.inner,
    annotations: Some(&self.annotations)
);

// Pooled connection: delegates to the connection borrowed via `as_mut()`.
impl_executor!(
    &'c mut crate::PoolConnection<DB>,
    self => self.inner.as_mut()
);
impl_executor!(
    crate::annotations::AnnotatedMut<'c, crate::PoolConnection<DB>>,
    self => self.inner.inner.as_mut(),
    annotations: Some(&self.annotations)
);

// Transaction: delegates to the reborrowed inner transaction.
impl_executor!(
    &'c mut crate::Transaction<'_, DB>,
    self => &mut *self.inner
);
impl_executor!(
    crate::annotations::AnnotatedMut<'c, crate::Transaction<'_, DB>>,
    self => &mut *self.inner.inner,
    annotations: Some(&self.annotations)
);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attributes::ConnectionAttributes;

    #[test]
    fn error_type_classification() {
        assert_eq!(error_type(&sqlx::Error::RowNotFound), "RowNotFound");
        assert_eq!(error_type(&sqlx::Error::PoolTimedOut), "PoolTimedOut");
        assert_eq!(error_type(&sqlx::Error::PoolClosed), "PoolClosed");
        assert_eq!(error_type(&sqlx::Error::WorkerCrashed), "WorkerCrashed");
        assert_eq!(
            error_type(&sqlx::Error::Configuration("bad".into())),
            "Configuration"
        );
        assert_eq!(
            error_type(&sqlx::Error::Io(std::io::Error::other("test"))),
            "Io"
        );
        assert_eq!(error_type(&sqlx::Error::Tls("tls".into())), "Tls");
        assert_eq!(
            error_type(&sqlx::Error::Protocol("proto".into())),
            "Protocol"
        );
        assert_eq!(error_type(&sqlx::Error::Decode("dec".into())), "Decode");
        assert_eq!(
            error_type(&sqlx::Error::AnyDriverError("any".into())),
            "AnyDriverError"
        );
        assert_eq!(
            error_type(&sqlx::Error::ColumnNotFound("x".into())),
            "ColumnNotFound"
        );
        assert_eq!(
            error_type(&sqlx::Error::ColumnIndexOutOfBounds { index: 5, len: 3 }),
            "ColumnIndexOutOfBounds"
        );
        assert_eq!(
            error_type(&sqlx::Error::ColumnDecode {
                index: "0".into(),
                source: "bad".into(),
            }),
            "ColumnDecode"
        );
        assert_eq!(
            error_type(&sqlx::Error::TypeNotFound {
                type_name: "Foo".into(),
            }),
            "TypeNotFound"
        );
        assert_eq!(
            error_type(&sqlx::Error::Migrate(Box::new(
                sqlx::migrate::MigrateError::Execute(sqlx::Error::Protocol("test".into()))
            ))),
            "Migrate"
        );
    }

    /// Connection attributes shared by the `build_attributes` tests
    /// (PostgreSQL on localhost, full query text).
    fn test_attrs() -> ConnectionAttributes {
        ConnectionAttributes {
            system: "postgresql",
            host: Some("localhost".into()),
            port: Some(5432),
            namespace: Some("mydb".into()),
            network_peer_address: None,
            network_peer_port: None,
            query_text_mode: QueryTextMode::Full,
        }
    }

    #[test]
    fn build_attributes_with_full_query_text() {
        let attrs = test_attrs();
        let kv = build_attributes(&attrs, Some("SELECT 1"), None);
        let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
        assert!(keys.contains(&"db.query.text"));
    }

    #[test]
    fn build_attributes_with_off_query_text() {
        let mut attrs = test_attrs();
        attrs.query_text_mode = QueryTextMode::Off;
        let kv = build_attributes(&attrs, Some("SELECT 1"), None);
        let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
        assert!(!keys.contains(&"db.query.text"));
    }

    #[test]
    fn build_attributes_obfuscated_replaces_literals() {
        let mut attrs = test_attrs();
        attrs.query_text_mode = QueryTextMode::Obfuscated;
        let kv = build_attributes(
            &attrs,
            Some("INSERT INTO t (id, name) VALUES (1, 'alice')"),
            None,
        );
        let text = kv
            .iter()
            .find(|k| k.key.as_str() == "db.query.text")
            .map(|k| k.value.clone());
        assert_eq!(
            text,
            Some(opentelemetry::Value::String(
                "INSERT INTO t (id, name) VALUES (?, ?)".into()
            ))
        );
    }

    #[test]
    fn build_attributes_no_sql_no_annotations() {
        let attrs = test_attrs();
        let kv = build_attributes(&attrs, None, None);
        let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
        assert!(!keys.contains(&"db.query.text"));
        assert!(!keys.contains(&"db.operation.name"));
        assert!(!keys.contains(&"db.collection.name"));
        assert!(!keys.contains(&"db.query.summary"));
        assert!(!keys.contains(&"db.stored_procedure.name"));
        assert!(keys.contains(&"db.system.name"));
    }

    #[test]
    fn build_attributes_with_all_annotation_fields() {
        let attrs = test_attrs();
        let ann = QueryAnnotations::new()
            .operation("SELECT")
            .collection("users")
            .query_summary("SELECT users")
            .stored_procedure("sp_get");
        let kv = build_attributes(&attrs, Some("SELECT * FROM users"), Some(&ann));
        let find = |key: &str| {
            kv.iter()
                .find(|k| k.key.as_str() == key)
                .map(|k| k.value.clone())
        };
        assert_eq!(
            find("db.operation.name"),
            Some(opentelemetry::Value::String("SELECT".into()))
        );
        assert_eq!(
            find("db.collection.name"),
            Some(opentelemetry::Value::String("users".into()))
        );
        assert_eq!(
            find("db.query.summary"),
            Some(opentelemetry::Value::String("SELECT users".into()))
        );
        assert_eq!(
            find("db.stored_procedure.name"),
            Some(opentelemetry::Value::String("sp_get".into()))
        );
        assert_eq!(
            find("db.query.text"),
            Some(opentelemetry::Value::String("SELECT * FROM users".into()))
        );
    }

    #[test]
    fn build_attributes_annotation_field_permutations() {
        type Setter = fn(QueryAnnotations) -> QueryAnnotations;
        let attrs = test_attrs();
        let fields: &[(&str, Setter)] = &[
            ("db.operation.name", |a| a.operation("SELECT")),
            ("db.collection.name", |a| a.collection("users")),
            ("db.query.summary", |a| a.query_summary("SELECT users")),
            ("db.stored_procedure.name", |a| a.stored_procedure("sp")),
        ];
        // Exhaustively try every subset of the four annotation fields: bit `i`
        // of `mask` decides whether `fields[i]` is set.
        for mask in 0u8..16 {
            let mut ann = QueryAnnotations::new();
            for (i, &(_, setter)) in fields.iter().enumerate() {
                if mask & (1 << i) != 0 {
                    ann = setter(ann);
                }
            }
            let kv = build_attributes(&attrs, None, Some(&ann));
            let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
            for (i, &(key, _)) in fields.iter().enumerate() {
                if mask & (1 << i) != 0 {
                    assert!(
                        keys.contains(&key),
                        "{key} should be present for mask {mask:#06b}"
                    );
                } else {
                    assert!(
                        !keys.contains(&key),
                        "{key} should be absent for mask {mask:#06b}"
                    );
                }
            }
        }
    }
}