1use core::{
6 future::Future,
7 pin::Pin,
8 task::{ready, Poll},
9};
10use std::{
11 sync::{Arc, Mutex},
12 time::Instant,
13};
14
15use diesel::{
16 connection::{Instrumentation, InstrumentationEvent},
17 query_builder::{AsQuery, QueryFragment, QueryId},
18 result::{ConnectionResult, QueryResult},
19};
20use diesel_async::{
21 pooled_connection::deadpool::Object, AnsiTransactionManager, AsyncConnection,
22 AsyncPgConnection, SimpleAsyncConnection, TransactionManager,
23};
24use futures_core::{future::BoxFuture, stream::BoxStream};
25use opentelemetry::{
26 metrics::{Counter, Histogram},
27 Key,
28};
29
/// Pooled (deadpool) Postgres connection type; `AsyncConnection` is implemented
/// for `MetricsConnection<Parent>`.
type Parent = Object<AsyncPgConnection>;

/// Attribute key under which the error label is attached to `sql_error` increments.
const ERROR_KEY: Key = Key::from_static_str("error");
33
/// OpenTelemetry instruments used to report database activity.
pub struct DatabaseMetrics {
    /// Wall-clock duration, in seconds, of SQL calls that finish with `Ok` or
    /// `Err(NotFound)`; measured from the first poll of the query future.
    pub sql_execution_time: Histogram<f64>,
    /// Number of failed SQL calls (excluding `NotFound`), tagged with an `error` label.
    pub sql_error: Counter<u64>,
    /// NOTE(review): not written in this file — presumably the pool's total
    /// connection count; confirm against whatever records it.
    pub dbpool_connections: Histogram<u64>,
    /// NOTE(review): not written in this file — presumably the pool's idle
    /// connection count; confirm against whatever records it.
    pub dbpool_connections_idle: Histogram<u64>,
}
40
/// Connection wrapper that delegates every call to `conn` and, when a metrics
/// sink is present, records query timing and error counts.
pub struct MetricsConnection<Conn> {
    /// Metrics sink; `None` turns the wrapper into a plain pass-through.
    pub(crate) metrics: Option<Arc<DatabaseMetrics>>,
    /// The wrapped connection that actually executes queries.
    pub(crate) conn: Conn,
    /// Diesel instrumentation hook. `AsyncConnection::instrumentation` requires
    /// this `Arc` to be unshared when it hands out a mutable reference.
    pub(crate) instrumentation: Arc<Mutex<Option<Box<dyn Instrumentation>>>>,
}
46
47fn get_metrics_label_for_error(error: &diesel::result::Error) -> &'static str {
48 match error {
49 diesel::result::Error::InvalidCString(_) => "invalid_c_string",
50 diesel::result::Error::DatabaseError(e, _) => match e {
51 diesel::result::DatabaseErrorKind::UniqueViolation => "unique_violation",
52 diesel::result::DatabaseErrorKind::ForeignKeyViolation => "foreign_key_violation",
53 diesel::result::DatabaseErrorKind::UnableToSendCommand => "unable_to_send_command",
54 diesel::result::DatabaseErrorKind::SerializationFailure => "serialization_failure",
55 _ => "unknown",
56 },
57 diesel::result::Error::NotFound => unreachable!(),
58 diesel::result::Error::QueryBuilderError(_) => "query_builder_error",
59 diesel::result::Error::DeserializationError(_) => "deserialization_error",
60 diesel::result::Error::SerializationError(_) => "serialization_error",
61 diesel::result::Error::RollbackTransaction => "rollback_transaction",
62 diesel::result::Error::AlreadyInTransaction => "already_in_transaction",
63 _ => "unknown",
64 }
65}
66
67#[async_trait::async_trait]
68impl<Conn> SimpleAsyncConnection for MetricsConnection<Conn>
69where
70 Conn: SimpleAsyncConnection + Send,
71{
72 async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
73 Instrument {
74 metrics: self.metrics.clone(),
75 future: self.conn.batch_execute(query),
76 start: None,
77 }
78 .await
79 }
80}
81
82#[async_trait::async_trait]
83impl AsyncConnection for MetricsConnection<Parent> {
84 type LoadFuture<'conn, 'query> =
85 Instrument<BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>>;
86 type ExecuteFuture<'conn, 'query> = Instrument<BoxFuture<'query, QueryResult<usize>>>;
87 type Stream<'conn, 'query> = BoxStream<'static, QueryResult<Self::Row<'conn, 'query>>>;
88 type Row<'conn, 'query> = <Parent as AsyncConnection>::Row<'conn, 'query>;
89 type Backend = <Parent as AsyncConnection>::Backend;
90 type TransactionManager = AnsiTransactionManager;
91
92 async fn establish(database_url: &str) -> ConnectionResult<Self> {
93 let mut instrumentation = diesel::connection::get_default_instrumentation();
94 instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
95 database_url,
96 ));
97
98 Parent::establish(database_url).await.map(|conn| Self {
99 metrics: None,
100 conn,
101 instrumentation: Arc::new(Mutex::new(instrumentation)),
102 })
103 }
104
105 #[doc(hidden)]
106 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
107 where
108 T: AsQuery + 'query,
109 T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
110 {
111 Instrument {
112 metrics: self.metrics.clone(),
113 future: self.conn.load(source),
114 start: None,
115 }
116 }
117
118 fn execute_returning_count<'conn, 'query, T>(
119 &'conn mut self,
120 source: T,
121 ) -> Self::ExecuteFuture<'conn, 'query>
122 where
123 T: QueryFragment<Self::Backend> + QueryId + 'query,
124 {
125 Instrument {
126 metrics: self.metrics.clone(),
127 future: self.conn.execute_returning_count(source),
128 start: None,
129 }
130 }
131
132 fn transaction_state(
138 &mut self,
139 ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
140 self.conn.transaction_state()
141 }
142
143 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
144 let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) else {
145 panic!("Cannot access shared instrumentation")
146 };
147
148 instrumentation.get_mut().unwrap_or_else(|p| p.into_inner())
149 }
150
151 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
152 self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation))));
153 }
154}
155
pin_project_lite::pin_project! {
    /// Future adapter around a diesel query future. When `metrics` is set, it
    /// records execution time on success / `NotFound` and increments the error
    /// counter on any other error.
    pub struct Instrument<F> {
        // Metrics sink; when `None` the inner future is polled unchanged.
        metrics: Option<Arc<DatabaseMetrics>>,
        // The wrapped query future (structurally pinned).
        #[pin]
        future: F,
        // Instant of the first poll. It is filled in lazily, and only when a
        // metrics sink is present.
        start: Option<Instant>,
    }
}
164
165impl<F, T> Future for Instrument<F>
166where
167 F: Future<Output = diesel::result::QueryResult<T>>,
168{
169 type Output = F::Output;
170
171 fn poll(
172 self: Pin<&mut Self>,
173 cx: &mut std::task::Context<'_>,
174 ) -> std::task::Poll<Self::Output> {
175 let this = self.project();
176
177 if let Some(metrics) = &this.metrics {
178 let start = this.start.get_or_insert_with(Instant::now);
179
180 match ready!(this.future.poll(cx)) {
181 res @ (Ok(_) | Err(diesel::result::Error::NotFound)) => {
182 metrics
183 .sql_execution_time
184 .record(start.elapsed().as_secs_f64(), &[]);
185
186 Poll::Ready(res)
187 }
188 Err(e) => {
189 let labels = &[ERROR_KEY.string(get_metrics_label_for_error(&e))];
190 metrics.sql_error.add(1, labels);
191
192 Poll::Ready(Err(e))
193 }
194 }
195 } else {
196 this.future.poll(cx)
197 }
198 }
199}