diesel_async/
async_connection_wrapper.rs

//! This module contains a wrapper type
//! that provides a [`diesel::Connection`]
//! implementation for types that implement
//! [`crate::AsyncConnection`]. Using this type
//! might be useful for the following use cases:
//!
//! * Executing migrations on application startup
//! * Using a pure Rust diesel connection implementation
//!   as a replacement for the existing connection
//!   implementations provided by diesel
11
12use futures_core::Stream;
13use futures_util::StreamExt;
14use std::future::Future;
15use std::pin::Pin;
16
/// Abstraction over the async runtime that the
/// [`AsyncConnectionWrapper`] type relies on to execute futures.
/// Implement it to plug in a custom runtime; by default a
/// tokio runtime is used.
pub trait BlockOn {
    /// Constructs a new runtime instance.
    fn get_runtime() -> Self;

    /// Executes the given future `f` and hands back its output.
    fn block_on<F: Future>(&self, f: F) -> F::Output;
}
32
/// A helper type that wraps an [`AsyncConnection`][crate::AsyncConnection] to
/// provide a sync [`diesel::Connection`] implementation.
///
/// Internally this wrapper type will use `block_on` to wait for
/// the execution of futures from the inner connection. This implies you
/// cannot use functions of this type in a scope with an already existing
/// tokio runtime. If you are in a situation where you want to use this
/// connection wrapper in the scope of an existing tokio runtime (for example
/// for running migrations via `diesel_migrations`) you need to wrap
/// the relevant code block into a `tokio::task::spawn_blocking` task.
///
/// # Examples
///
/// ```rust,no_run
/// # include!("doctest_setup.rs");
/// use schema::users;
/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
/// #
/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
/// use diesel::prelude::{RunQueryDsl, Connection};
/// # let database_url = database_url();
/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
///
/// let all_users = users::table.load::<(i32, String)>(&mut conn)?;
/// # assert_eq!(all_users.len(), 0);
/// # Ok(())
/// # }
/// ```
///
/// If you are in the scope of an existing tokio runtime you need to use
/// `tokio::task::spawn_blocking` to encapsulate the blocking tasks:
///
/// ```rust,no_run
/// # include!("doctest_setup.rs");
/// use schema::users;
/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
///
/// async fn some_async_fn() {
/// # let database_url = database_url();
///      // need to use `spawn_blocking` to execute
///      // a blocking task in the scope of an existing runtime
///      let res = tokio::task::spawn_blocking(move || {
///          use diesel::prelude::{RunQueryDsl, Connection};
///          let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
///
///          let all_users = users::table.load::<(i32, String)>(&mut conn)?;
/// #         assert_eq!(all_users.len(), 0);
///          Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
///      }).await;
///
/// # res.unwrap().unwrap();
/// }
///
/// # #[tokio::main]
/// # async fn main() {
/// #    some_async_fn().await;
/// # }
/// ```
#[cfg(feature = "tokio")]
pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
    self::implementation::AsyncConnectionWrapper<C, B>;
93
94/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to
95/// provide a sync [`diesel::Connection`] implementation.
96///
97/// Internally this wrapper type will use `block_on` to wait for
98/// the execution of futures from the inner connection.
99#[cfg(not(feature = "tokio"))]
100pub use self::implementation::AsyncConnectionWrapper;
101
102mod implementation {
103    use diesel::connection::{Instrumentation, SimpleConnection};
104    use std::ops::{Deref, DerefMut};
105
106    use super::*;
107
108    pub struct AsyncConnectionWrapper<C, B> {
109        inner: C,
110        runtime: B,
111    }
112
113    impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
114    where
115        C: crate::AsyncConnection,
116        B: BlockOn + Send,
117    {
118        fn from(inner: C) -> Self {
119            Self {
120                inner,
121                runtime: B::get_runtime(),
122            }
123        }
124    }
125
126    impl<C, B> Deref for AsyncConnectionWrapper<C, B> {
127        type Target = C;
128
129        fn deref(&self) -> &Self::Target {
130            &self.inner
131        }
132    }
133
134    impl<C, B> DerefMut for AsyncConnectionWrapper<C, B> {
135        fn deref_mut(&mut self) -> &mut Self::Target {
136            &mut self.inner
137        }
138    }
139
140    impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
141    where
142        C: crate::SimpleAsyncConnection,
143        B: BlockOn,
144    {
145        fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
146            let f = self.inner.batch_execute(query);
147            self.runtime.block_on(f)
148        }
149    }
150
151    impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}
152
153    impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
154    where
155        C: crate::AsyncConnection,
156        B: BlockOn + Send,
157    {
158        type Backend = C::Backend;
159
160        type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;
161
162        fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
163            let runtime = B::get_runtime();
164            let f = C::establish(database_url);
165            let inner = runtime.block_on(f)?;
166            Ok(Self { inner, runtime })
167        }
168
169        fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
170        where
171            T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
172        {
173            let f = self.inner.execute_returning_count(source);
174            self.runtime.block_on(f)
175        }
176
177        fn transaction_state(
178            &mut self,
179        ) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
180            self.inner.transaction_state()
181        }
182
183        fn instrumentation(&mut self) -> &mut dyn Instrumentation {
184            self.inner.instrumentation()
185        }
186
187        fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
188            self.inner.set_instrumentation(instrumentation);
189        }
190    }
191
192    impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
193    where
194        C: crate::AsyncConnection,
195        B: BlockOn + Send,
196    {
197        type Cursor<'conn, 'query>
198            = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
199        where
200            Self: 'conn;
201
202        type Row<'conn, 'query>
203            = C::Row<'conn, 'query>
204        where
205            Self: 'conn;
206
207        fn load<'conn, 'query, T>(
208            &'conn mut self,
209            source: T,
210        ) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
211        where
212            T: diesel::query_builder::Query
213                + diesel::query_builder::QueryFragment<Self::Backend>
214                + diesel::query_builder::QueryId
215                + 'query,
216            Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
217        {
218            let f = self.inner.load(source);
219            let stream = self.runtime.block_on(f)?;
220
221            Ok(AsyncCursorWrapper {
222                stream: Box::pin(stream),
223                runtime: &self.runtime,
224            })
225        }
226    }
227
228    pub struct AsyncCursorWrapper<'a, S, B> {
229        stream: Pin<Box<S>>,
230        runtime: &'a B,
231    }
232
233    impl<S, B> Iterator for AsyncCursorWrapper<'_, S, B>
234    where
235        S: Stream,
236        B: BlockOn,
237    {
238        type Item = S::Item;
239
240        fn next(&mut self) -> Option<Self::Item> {
241            let f = self.stream.next();
242            self.runtime.block_on(f)
243        }
244    }
245
246    pub struct AsyncConnectionWrapperTransactionManagerWrapper;
247
248    impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
249        for AsyncConnectionWrapperTransactionManagerWrapper
250    where
251        C: crate::AsyncConnection,
252        B: BlockOn + Send,
253    {
254        type TransactionStateData =
255            <C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;
256
257        fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
258            let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
259                &mut conn.inner,
260            );
261            conn.runtime.block_on(f)
262        }
263
264        fn rollback_transaction(
265            conn: &mut AsyncConnectionWrapper<C, B>,
266        ) -> diesel::QueryResult<()> {
267            let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
268                &mut conn.inner,
269            );
270            conn.runtime.block_on(f)
271        }
272
273        fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
274            let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
275                &mut conn.inner,
276            );
277            conn.runtime.block_on(f)
278        }
279
280        fn transaction_manager_status_mut(
281            conn: &mut AsyncConnectionWrapper<C, B>,
282        ) -> &mut diesel::connection::TransactionManagerStatus {
283            <C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
284                &mut conn.inner,
285            )
286        }
287
288        fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
289            <C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
290                &mut conn.inner,
291            )
292        }
293    }
294
295    #[cfg(feature = "r2d2")]
296    impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
297    where
298        B: BlockOn,
299        Self: diesel::Connection,
300        C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
301            + crate::pooled_connection::PoolableConnection
302            + 'static,
303        diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
304            crate::methods::ExecuteDsl<C>,
305        diesel::query_builder::SqlQuery: crate::methods::ExecuteDsl<C>,
306    {
307        fn ping(&mut self) -> diesel::QueryResult<()> {
308            let fut = crate::pooled_connection::PoolableConnection::ping(
309                &mut self.inner,
310                &crate::pooled_connection::RecyclingMethod::Verified,
311            );
312            self.runtime.block_on(fut)
313        }
314
315        fn is_broken(&mut self) -> bool {
316            crate::pooled_connection::PoolableConnection::is_broken(&mut self.inner)
317        }
318    }
319
320    impl<C, B> diesel::migration::MigrationConnection for AsyncConnectionWrapper<C, B>
321    where
322        B: BlockOn,
323        Self: diesel::Connection,
324    {
325        fn setup(&mut self) -> diesel::QueryResult<usize> {
326            self.batch_execute(diesel::migration::CREATE_MIGRATIONS_TABLE)
327                .map(|()| 0)
328        }
329    }
330
331    #[cfg(feature = "tokio")]
332    pub struct Tokio {
333        handle: Option<tokio::runtime::Handle>,
334        runtime: Option<tokio::runtime::Runtime>,
335    }
336
337    #[cfg(feature = "tokio")]
338    impl BlockOn for Tokio {
339        fn block_on<F>(&self, f: F) -> F::Output
340        where
341            F: Future,
342        {
343            if let Some(handle) = &self.handle {
344                handle.block_on(f)
345            } else if let Some(runtime) = &self.runtime {
346                runtime.block_on(f)
347            } else {
348                unreachable!()
349            }
350        }
351
352        fn get_runtime() -> Self {
353            if let Ok(handle) = tokio::runtime::Handle::try_current() {
354                Self {
355                    handle: Some(handle),
356                    runtime: None,
357                }
358            } else {
359                let runtime = tokio::runtime::Builder::new_current_thread()
360                    .enable_io()
361                    .build()
362                    .unwrap();
363                Self {
364                    handle: None,
365                    runtime: Some(runtime),
366                }
367            }
368        }
369    }
370}