opentalk_database/
metrics.rs

1// SPDX-FileCopyrightText: OpenTalk GmbH <mail@opentalk.eu>
2//
3// SPDX-License-Identifier: EUPL-1.2
4
5use core::{
6    future::Future,
7    pin::Pin,
8    task::{ready, Poll},
9};
10use std::{
11    sync::{Arc, Mutex},
12    time::Instant,
13};
14
15use diesel::{
16    connection::{Instrumentation, InstrumentationEvent},
17    query_builder::{AsQuery, QueryFragment, QueryId},
18    result::{ConnectionResult, QueryResult},
19};
20use diesel_async::{
21    pooled_connection::deadpool::Object, AnsiTransactionManager, AsyncConnection,
22    AsyncPgConnection, SimpleAsyncConnection, TransactionManager,
23};
24use futures_core::{future::BoxFuture, stream::BoxStream};
25use opentelemetry::{
26    metrics::{Counter, Histogram},
27    Key,
28};
29
/// The pooled PostgreSQL connection wrapped by the `AsyncConnection`
/// implementation of `MetricsConnection`.
type Parent = Object<AsyncPgConnection>;

/// Attribute key carrying the error category when the `sql_error` counter is incremented.
const ERROR_KEY: Key = Key::from_static_str("error");
33
/// OpenTelemetry instruments used to report database metrics.
pub struct DatabaseMetrics {
    /// Query duration in seconds, measured by the `Instrument` future from its
    /// first poll until the query resolves successfully (a `NotFound` result
    /// also counts as success here).
    pub sql_execution_time: Histogram<f64>,
    /// Number of failed queries, labelled by error category under the `error` key.
    pub sql_error: Counter<u64>,
    /// Not recorded in this file; presumably the number of pooled connections.
    /// NOTE(review): confirm against the code that records it.
    pub dbpool_connections: Histogram<u64>,
    /// Not recorded in this file; presumably the number of idle pooled connections.
    /// NOTE(review): confirm against the code that records it.
    pub dbpool_connections_idle: Histogram<u64>,
}
40
/// A connection wrapper that forwards queries to `conn` and, when a metrics
/// sink is configured, reports query durations and errors.
pub struct MetricsConnection<Conn> {
    /// Metrics sink; `None` disables recording. `establish` creates connections
    /// with this set to `None`.
    pub(crate) metrics: Option<Arc<DatabaseMetrics>>,
    /// The wrapped connection that actually executes the queries.
    pub(crate) conn: Conn,
    /// Diesel instrumentation hook. Kept behind `Arc<Mutex<_>>`; accessing it
    /// mutably through `AsyncConnection::instrumentation` panics if the `Arc`
    /// is shared at that moment.
    pub(crate) instrumentation: Arc<Mutex<Option<Box<dyn Instrumentation>>>>,
}
46
47fn get_metrics_label_for_error(error: &diesel::result::Error) -> &'static str {
48    match error {
49        diesel::result::Error::InvalidCString(_) => "invalid_c_string",
50        diesel::result::Error::DatabaseError(e, _) => match e {
51            diesel::result::DatabaseErrorKind::UniqueViolation => "unique_violation",
52            diesel::result::DatabaseErrorKind::ForeignKeyViolation => "foreign_key_violation",
53            diesel::result::DatabaseErrorKind::UnableToSendCommand => "unable_to_send_command",
54            diesel::result::DatabaseErrorKind::SerializationFailure => "serialization_failure",
55            _ => "unknown",
56        },
57        diesel::result::Error::NotFound => unreachable!(),
58        diesel::result::Error::QueryBuilderError(_) => "query_builder_error",
59        diesel::result::Error::DeserializationError(_) => "deserialization_error",
60        diesel::result::Error::SerializationError(_) => "serialization_error",
61        diesel::result::Error::RollbackTransaction => "rollback_transaction",
62        diesel::result::Error::AlreadyInTransaction => "already_in_transaction",
63        _ => "unknown",
64    }
65}
66
67#[async_trait::async_trait]
68impl<Conn> SimpleAsyncConnection for MetricsConnection<Conn>
69where
70    Conn: SimpleAsyncConnection + Send,
71{
72    async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
73        Instrument {
74            metrics: self.metrics.clone(),
75            future: self.conn.batch_execute(query),
76            start: None,
77        }
78        .await
79    }
80}
81
82#[async_trait::async_trait]
83impl AsyncConnection for MetricsConnection<Parent> {
84    type LoadFuture<'conn, 'query> =
85        Instrument<BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>>;
86    type ExecuteFuture<'conn, 'query> = Instrument<BoxFuture<'query, QueryResult<usize>>>;
87    type Stream<'conn, 'query> = BoxStream<'static, QueryResult<Self::Row<'conn, 'query>>>;
88    type Row<'conn, 'query> = <Parent as AsyncConnection>::Row<'conn, 'query>;
89    type Backend = <Parent as AsyncConnection>::Backend;
90    type TransactionManager = AnsiTransactionManager;
91
92    async fn establish(database_url: &str) -> ConnectionResult<Self> {
93        let mut instrumentation = diesel::connection::get_default_instrumentation();
94        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
95            database_url,
96        ));
97
98        Parent::establish(database_url).await.map(|conn| Self {
99            metrics: None,
100            conn,
101            instrumentation: Arc::new(Mutex::new(instrumentation)),
102        })
103    }
104
105    #[doc(hidden)]
106    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
107    where
108        T: AsQuery + 'query,
109        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
110    {
111        Instrument {
112            metrics: self.metrics.clone(),
113            future: self.conn.load(source),
114            start: None,
115        }
116    }
117
118    fn execute_returning_count<'conn, 'query, T>(
119        &'conn mut self,
120        source: T,
121    ) -> Self::ExecuteFuture<'conn, 'query>
122    where
123        T: QueryFragment<Self::Backend> + QueryId + 'query,
124    {
125        Instrument {
126            metrics: self.metrics.clone(),
127            future: self.conn.execute_returning_count(source),
128            start: None,
129        }
130    }
131
132    /// Get access to the current transaction state of this connection
133    ///
134    /// Hidden in `diesel` behind the
135    /// `i-implement-a-third-party-backend-and-opt-into-breaking-changes` feature flag,
136    /// therefore not generally visible in the `diesel` generated docs.
137    fn transaction_state(
138        &mut self,
139    ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
140        self.conn.transaction_state()
141    }
142
143    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
144        let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) else {
145            panic!("Cannot access shared instrumentation")
146        };
147
148        instrumentation.get_mut().unwrap_or_else(|p| p.into_inner())
149    }
150
151    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
152        self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation))));
153    }
154}
155
pin_project_lite::pin_project! {
    /// Future adapter that records query metrics for the wrapped future `F`.
    ///
    /// When `metrics` is `None`, polling is forwarded to `future` unchanged.
    pub struct Instrument<F> {
        // Metrics sink; `None` turns this wrapper into a pass-through.
        metrics: Option<Arc<DatabaseMetrics>>,
        // The wrapped query future (structurally pinned).
        #[pin]
        future: F,
        // Instant of the first poll; only set while `metrics` is `Some`.
        start: Option<Instant>,
    }
}
164
165impl<F, T> Future for Instrument<F>
166where
167    F: Future<Output = diesel::result::QueryResult<T>>,
168{
169    type Output = F::Output;
170
171    fn poll(
172        self: Pin<&mut Self>,
173        cx: &mut std::task::Context<'_>,
174    ) -> std::task::Poll<Self::Output> {
175        let this = self.project();
176
177        if let Some(metrics) = &this.metrics {
178            let start = this.start.get_or_insert_with(Instant::now);
179
180            match ready!(this.future.poll(cx)) {
181                res @ (Ok(_) | Err(diesel::result::Error::NotFound)) => {
182                    metrics
183                        .sql_execution_time
184                        .record(start.elapsed().as_secs_f64(), &[]);
185
186                    Poll::Ready(res)
187                }
188                Err(e) => {
189                    let labels = &[ERROR_KEY.string(get_metrics_label_for_error(&e))];
190                    metrics.sql_error.add(1, labels);
191
192                    Poll::Ready(Err(e))
193                }
194            }
195        } else {
196            this.future.poll(cx)
197        }
198    }
199}