use core::{
future::Future,
pin::Pin,
task::{ready, Poll},
};
use std::{
sync::{Arc, Mutex},
time::Instant,
};
use diesel::{
connection::{Instrumentation, InstrumentationEvent},
query_builder::{AsQuery, QueryFragment, QueryId},
result::{ConnectionResult, QueryResult},
};
use diesel_async::{
pooled_connection::deadpool::Object, AnsiTransactionManager, AsyncConnection,
AsyncPgConnection, SimpleAsyncConnection, TransactionManager,
};
use futures_core::{future::BoxFuture, stream::BoxStream};
use opentelemetry::{
metrics::{Counter, Histogram},
Key,
};
/// Connection type wrapped by the `AsyncConnection` impl in this file:
/// a deadpool pool `Object` holding an `AsyncPgConnection`.
type Parent = Object<AsyncPgConnection>;
/// Attribute key (`"error"`) attached to the `sql_error` counter; its value is
/// the label produced by `get_metrics_label_for_error`.
const ERROR_KEY: Key = Key::from_static_str("error");
/// OpenTelemetry instruments used to report database activity.
pub struct DatabaseMetrics {
/// Query duration in seconds, recorded by `Instrument` when a query
/// finishes with `Ok(_)` or `Err(NotFound)`.
pub sql_execution_time: Histogram<f64>,
/// Incremented by one for every query error other than `NotFound`,
/// tagged with an `"error"` attribute naming the error kind.
pub sql_error: Counter<u64>,
/// Not written anywhere in this file; presumably the pool's connection
/// count, recorded by the pool owner (TODO confirm).
pub dbpool_connections: Histogram<u64>,
/// Not written anywhere in this file; presumably the pool's idle connection
/// count, recorded by the pool owner (TODO confirm).
pub dbpool_connections_idle: Histogram<u64>,
}
/// Connection wrapper that forwards queries to `conn` and wraps the resulting
/// futures in `Instrument` so that timings and errors are reported.
pub struct MetricsConnection<Conn> {
// `None` disables reporting: `Instrument` then just polls the inner future.
pub(crate) metrics: Option<Arc<DatabaseMetrics>>,
// The wrapped connection that actually executes the queries.
pub(crate) conn: Conn,
// Diesel instrumentation hook returned by `instrumentation()` and replaced
// by `set_instrumentation()`.
pub(crate) instrumentation: Arc<Mutex<Option<Box<dyn Instrumentation>>>>,
}
fn get_metrics_label_for_error(error: &diesel::result::Error) -> &'static str {
match error {
diesel::result::Error::InvalidCString(_) => "invalid_c_string",
diesel::result::Error::DatabaseError(e, _) => match e {
diesel::result::DatabaseErrorKind::UniqueViolation => "unique_violation",
diesel::result::DatabaseErrorKind::ForeignKeyViolation => "foreign_key_violation",
diesel::result::DatabaseErrorKind::UnableToSendCommand => "unable_to_send_command",
diesel::result::DatabaseErrorKind::SerializationFailure => "serialization_failure",
_ => "unknown",
},
diesel::result::Error::NotFound => unreachable!(),
diesel::result::Error::QueryBuilderError(_) => "query_builder_error",
diesel::result::Error::DeserializationError(_) => "deserialization_error",
diesel::result::Error::SerializationError(_) => "serialization_error",
diesel::result::Error::RollbackTransaction => "rollback_transaction",
diesel::result::Error::AlreadyInTransaction => "already_in_transaction",
_ => "unknown",
}
}
#[async_trait::async_trait]
impl<Conn> SimpleAsyncConnection for MetricsConnection<Conn>
where
    Conn: SimpleAsyncConnection + Send,
{
    /// Runs `query` on the wrapped connection, reporting its duration or
    /// error through `Instrument` when metrics are configured.
    async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
        let metrics = self.metrics.clone();
        let future = self.conn.batch_execute(query);
        // The clock starts on the first poll of the wrapper, not here.
        let instrumented = Instrument {
            metrics,
            future,
            start: None,
        };
        instrumented.await
    }
}
#[async_trait::async_trait]
impl AsyncConnection for MetricsConnection<Parent> {
    type LoadFuture<'conn, 'query> =
        Instrument<BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>>;
    type ExecuteFuture<'conn, 'query> = Instrument<BoxFuture<'query, QueryResult<usize>>>;
    type Stream<'conn, 'query> = BoxStream<'static, QueryResult<Self::Row<'conn, 'query>>>;
    type Row<'conn, 'query> = <Parent as AsyncConnection>::Row<'conn, 'query>;
    type Backend = <Parent as AsyncConnection>::Backend;
    type TransactionManager = AnsiTransactionManager;

    /// Establishes the wrapped connection and returns a wrapper with no
    /// metrics attached (`metrics` is `None`), so its queries are not reported.
    ///
    /// The default diesel instrumentation is notified that the attempt
    /// started and is stored in the returned connection.
    async fn establish(database_url: &str) -> ConnectionResult<Self> {
        let mut default_instrumentation = diesel::connection::get_default_instrumentation();
        default_instrumentation.on_connection_event(
            InstrumentationEvent::start_establish_connection(database_url),
        );

        let conn = Parent::establish(database_url).await?;
        Ok(Self {
            metrics: None,
            conn,
            instrumentation: Arc::new(Mutex::new(default_instrumentation)),
        })
    }

    /// Forwards to the wrapped connection's `load`, wrapped in `Instrument`.
    #[doc(hidden)]
    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
    where
        T: AsQuery + 'query,
        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
    {
        let metrics = self.metrics.clone();
        let future = self.conn.load(source);
        Instrument {
            metrics,
            future,
            start: None,
        }
    }

    /// Forwards to the wrapped connection's `execute_returning_count`,
    /// wrapped in `Instrument`.
    fn execute_returning_count<'conn, 'query, T>(
        &'conn mut self,
        source: T,
    ) -> Self::ExecuteFuture<'conn, 'query>
    where
        T: QueryFragment<Self::Backend> + QueryId + 'query,
    {
        let metrics = self.metrics.clone();
        let future = self.conn.execute_returning_count(source);
        Instrument {
            metrics,
            future,
            start: None,
        }
    }

    /// Exposes the transaction state of the wrapped connection.
    fn transaction_state(
        &mut self,
    ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
        self.conn.transaction_state()
    }

    /// Returns the instrumentation stored in this wrapper.
    ///
    /// # Panics
    ///
    /// Panics if the `Arc` holding the instrumentation has other owners.
    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
        if let Some(exclusive) = Arc::get_mut(&mut self.instrumentation) {
            // A poisoned lock is tolerated: the inner value is used regardless.
            exclusive
                .get_mut()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
        } else {
            panic!("Cannot access shared instrumentation")
        }
    }

    /// Replaces the stored instrumentation with a fresh, unshared one.
    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
        let boxed: Box<dyn Instrumentation> = Box::new(instrumentation);
        self.instrumentation = Arc::new(Mutex::new(Some(boxed)));
    }
}
pin_project_lite::pin_project! {
/// Future adapter that reports the outcome of a query future to
/// `DatabaseMetrics` once the wrapped future completes.
pub struct Instrument<F> {
// Instruments to report to; `None` makes the wrapper a plain pass-through.
metrics: Option<Arc<DatabaseMetrics>>,
// The wrapped query future (structurally pinned).
#[pin]
future: F,
// Time of the first poll with metrics enabled; `None` until then.
start: Option<Instant>,
}
}
impl<F, T> Future for Instrument<F>
where
F: Future<Output = diesel::result::QueryResult<T>>,
{
type Output = F::Output;
fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.project();
if let Some(metrics) = &this.metrics {
let start = this.start.get_or_insert_with(Instant::now);
match ready!(this.future.poll(cx)) {
res @ (Ok(_) | Err(diesel::result::Error::NotFound)) => {
metrics
.sql_execution_time
.record(start.elapsed().as_secs_f64(), &[]);
Poll::Ready(res)
}
Err(e) => {
let labels = &[ERROR_KEY.string(get_metrics_label_for_error(&e))];
metrics.sql_error.add(1, labels);
Poll::Ready(Err(e))
}
}
} else {
this.future.poll(cx)
}
}
}