diesel_async/
async_connection_wrapper.rs

//! This module contains a wrapper type
//! that provides a [`diesel::Connection`]
//! implementation for types that implement
//! [`crate::AsyncConnection`]. Using this type
//! might be useful for the following use cases:
//!
//! * Executing migrations on application startup
//! * Using a pure Rust diesel connection implementation
//!   as a replacement for the existing connection
//!   implementations provided by diesel
11
12use futures_core::Stream;
13use futures_util::StreamExt;
14use std::future::Future;
15use std::pin::Pin;
16
/// Abstraction over the async runtime that drives futures to completion
/// on behalf of the [`AsyncConnectionWrapper`] type.
///
/// Implement this trait to plug in a custom runtime. By default a
/// tokio runtime is used.
pub trait BlockOn {
    /// Constructs a new runtime instance.
    fn get_runtime() -> Self;

    /// Runs the given future `f` to completion and returns its output.
    fn block_on<F: Future>(&self, f: F) -> F::Output;
}
32
33/// A helper type that wraps an [`AsyncConnection`][crate::AsyncConnection] to
34/// provide a sync [`diesel::Connection`] implementation.
35///
/// Internally this wrapper type will use `block_on` to wait for
/// the execution of futures from the inner connection. This implies you
/// cannot use functions of this type in a scope with an already existing
/// tokio runtime. If you are in a situation where you want to use this
/// connection wrapper in the scope of an existing tokio runtime (for example
/// for running migrations via `diesel_migrations`) you need to wrap
/// the relevant code block into a `tokio::task::spawn_blocking` task.
43///
44/// # Examples
45///
46/// ```rust,no_run
47/// # include!("doctest_setup.rs");
48/// use schema::users;
49/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
50/// #
51/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
52/// use diesel::prelude::{RunQueryDsl, Connection};
53/// # let database_url = database_url();
54/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
55///
56/// let all_users = users::table.load::<(i32, String)>(&mut conn)?;
57/// # assert_eq!(all_users.len(), 0);
58/// # Ok(())
59/// # }
60/// ```
61///
/// If you are in the scope of an existing tokio runtime you need to use
/// `tokio::task::spawn_blocking` to encapsulate the blocking tasks:
64/// ```rust,no_run
65/// # include!("doctest_setup.rs");
66/// use schema::users;
67/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
68///
69/// async fn some_async_fn() {
70/// # let database_url = database_url();
71///      // need to use `spawn_blocking` to execute
72///      // a blocking task in the scope of an existing runtime
73///      let res = tokio::task::spawn_blocking(move || {
74///          use diesel::prelude::{RunQueryDsl, Connection};
75///          let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
76///
77///          let all_users = users::table.load::<(i32, String)>(&mut conn)?;
78/// #         assert_eq!(all_users.len(), 0);
79///          Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
80///      }).await;
81///
82/// # res.unwrap().unwrap();
83/// }
84///
85/// # #[tokio::main]
86/// # async fn main() {
87/// #    some_async_fn().await;
88/// # }
89/// ```
90#[cfg(feature = "tokio")]
91pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
92    self::implementation::AsyncConnectionWrapper<C, B>;
93
/// A helper type that wraps an [`AsyncConnection`][crate::AsyncConnection] to
/// provide a sync [`diesel::Connection`] implementation.
///
/// Internally this wrapper type will use `block_on` to wait for
/// the execution of futures from the inner connection.
99#[cfg(not(feature = "tokio"))]
100pub use self::implementation::AsyncConnectionWrapper;
101
102pub(crate) mod implementation {
103    use diesel::connection::{CacheSize, Instrumentation, SimpleConnection};
104    use std::ops::{Deref, DerefMut};
105
106    use super::*;
107
    /// Pairs a connection of type `C` with a runtime of type `B`.
    ///
    /// The trait impls in this module use `runtime` to block on the futures
    /// produced by `inner`.
    pub struct AsyncConnectionWrapper<C, B> {
        /// The wrapped connection.
        inner: C,
        /// The runtime used to block on futures created by `inner`.
        runtime: B,
    }
112
113    impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
114    where
115        C: crate::AsyncConnection,
116        B: BlockOn + Send,
117    {
118        fn from(inner: C) -> Self {
119            Self {
120                inner,
121                runtime: B::get_runtime(),
122            }
123        }
124    }
125
126    impl<C, B> AsyncConnectionWrapper<C, B>
127    where
128        C: crate::AsyncConnection,
129    {
130        /// Consumes the [`AsyncConnectionWrapper`] returning the wrapped inner
131        /// [`AsyncConnection`].
132        pub fn into_inner(self) -> C {
133            self.inner
134        }
135    }
136
137    impl<C, B> Deref for AsyncConnectionWrapper<C, B> {
138        type Target = C;
139
140        fn deref(&self) -> &Self::Target {
141            &self.inner
142        }
143    }
144
145    impl<C, B> DerefMut for AsyncConnectionWrapper<C, B> {
146        fn deref_mut(&mut self) -> &mut Self::Target {
147            &mut self.inner
148        }
149    }
150
151    impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
152    where
153        C: crate::SimpleAsyncConnection,
154        B: BlockOn,
155    {
156        fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
157            let f = self.inner.batch_execute(query);
158            self.runtime.block_on(f)
159        }
160    }
161
    // Empty marker impl with no bounds on `C` or `B`.
    // NOTE(review): presumably required so that the `diesel::connection::Connection`
    // impl for this wrapper is accepted — confirm against diesel.
    impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}
163
164    impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
165    where
166        C: crate::AsyncConnection,
167        B: BlockOn + Send,
168    {
169        type Backend = C::Backend;
170
171        type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;
172
173        fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
174            let runtime = B::get_runtime();
175            let f = C::establish(database_url);
176            let inner = runtime.block_on(f)?;
177            Ok(Self { inner, runtime })
178        }
179
180        fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
181        where
182            T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
183        {
184            let f = self.inner.execute_returning_count(source);
185            self.runtime.block_on(f)
186        }
187
188        fn transaction_state(
189            &mut self,
190        ) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
191            self.inner.transaction_state()
192        }
193
194        fn instrumentation(&mut self) -> &mut dyn Instrumentation {
195            self.inner.instrumentation()
196        }
197
198        fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
199            self.inner.set_instrumentation(instrumentation);
200        }
201
202        fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
203            self.inner.set_prepared_statement_cache_size(size)
204        }
205    }
206
207    impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
208    where
209        C: crate::AsyncConnection,
210        B: BlockOn + Send,
211    {
212        type Cursor<'conn, 'query>
213            = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
214        where
215            Self: 'conn;
216
217        type Row<'conn, 'query>
218            = C::Row<'conn, 'query>
219        where
220            Self: 'conn;
221
222        fn load<'conn, 'query, T>(
223            &'conn mut self,
224            source: T,
225        ) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
226        where
227            T: diesel::query_builder::Query
228                + diesel::query_builder::QueryFragment<Self::Backend>
229                + diesel::query_builder::QueryId
230                + 'query,
231            Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
232        {
233            let f = self.inner.load(source);
234            let stream = self.runtime.block_on(f)?;
235
236            Ok(AsyncCursorWrapper {
237                stream: Box::pin(stream),
238                runtime: &self.runtime,
239            })
240        }
241    }
242
    /// Blocking cursor returned by the `LoadConnection::load` impl of
    /// [`AsyncConnectionWrapper`].
    ///
    /// Holds the boxed, pinned stream `S` together with a borrow of the
    /// runtime `B` that is used to block on each item.
    pub struct AsyncCursorWrapper<'a, S, B> {
        /// The boxed, pinned stream that items are pulled from.
        stream: Pin<Box<S>>,
        /// Runtime borrowed for `'a`; used to block on the stream.
        runtime: &'a B,
    }
247
248    impl<S, B> Iterator for AsyncCursorWrapper<'_, S, B>
249    where
250        S: Stream,
251        B: BlockOn,
252    {
253        type Item = S::Item;
254
255        fn next(&mut self) -> Option<Self::Item> {
256            let f = self.stream.next();
257            self.runtime.block_on(f)
258        }
259    }
260
    /// Unit type used as the `TransactionManager` of [`AsyncConnectionWrapper`];
    /// its `diesel::connection::TransactionManager` impl forwards every call
    /// to the transaction manager of the wrapped connection.
    pub struct AsyncConnectionWrapperTransactionManagerWrapper;
262
263    impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
264        for AsyncConnectionWrapperTransactionManagerWrapper
265    where
266        C: crate::AsyncConnection,
267        B: BlockOn + Send,
268    {
269        type TransactionStateData =
270            <C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;
271
272        fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
273            let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
274                &mut conn.inner,
275            );
276            conn.runtime.block_on(f)
277        }
278
279        fn rollback_transaction(
280            conn: &mut AsyncConnectionWrapper<C, B>,
281        ) -> diesel::QueryResult<()> {
282            let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
283                &mut conn.inner,
284            );
285            conn.runtime.block_on(f)
286        }
287
288        fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
289            let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
290                &mut conn.inner,
291            );
292            conn.runtime.block_on(f)
293        }
294
295        fn transaction_manager_status_mut(
296            conn: &mut AsyncConnectionWrapper<C, B>,
297        ) -> &mut diesel::connection::TransactionManagerStatus {
298            <C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
299                &mut conn.inner,
300            )
301        }
302
303        fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
304            <C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
305                &mut conn.inner,
306            )
307        }
308    }
309
310    #[cfg(feature = "r2d2")]
311    impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
312    where
313        B: BlockOn,
314        Self: diesel::Connection,
315        C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
316            + crate::pooled_connection::PoolableConnection
317            + 'static,
318        diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
319            crate::methods::ExecuteDsl<C>,
320        diesel::query_builder::SqlQuery: crate::methods::ExecuteDsl<C>,
321    {
322        fn ping(&mut self) -> diesel::QueryResult<()> {
323            let fut = crate::pooled_connection::PoolableConnection::ping(
324                &mut self.inner,
325                &crate::pooled_connection::RecyclingMethod::Verified,
326            );
327            self.runtime.block_on(fut)
328        }
329
330        fn is_broken(&mut self) -> bool {
331            crate::pooled_connection::PoolableConnection::is_broken(&mut self.inner)
332        }
333    }
334
335    impl<C, B> diesel::migration::MigrationConnection for AsyncConnectionWrapper<C, B>
336    where
337        B: BlockOn,
338        Self: diesel::Connection,
339    {
340        fn setup(&mut self) -> diesel::QueryResult<usize> {
341            self.batch_execute(diesel::migration::CREATE_MIGRATIONS_TABLE)
342                .map(|()| 0)
343        }
344    }
345
346    #[cfg(feature = "tokio")]
347    pub struct Tokio {
348        handle: Option<tokio::runtime::Handle>,
349        runtime: Option<tokio::runtime::Runtime>,
350    }
351
352    #[cfg(feature = "tokio")]
353    impl BlockOn for Tokio {
354        fn block_on<F>(&self, f: F) -> F::Output
355        where
356            F: Future,
357        {
358            if let Some(handle) = &self.handle {
359                handle.block_on(f)
360            } else if let Some(runtime) = &self.runtime {
361                runtime.block_on(f)
362            } else {
363                unreachable!()
364            }
365        }
366
367        fn get_runtime() -> Self {
368            if let Ok(handle) = tokio::runtime::Handle::try_current() {
369                Self {
370                    handle: Some(handle),
371                    runtime: None,
372                }
373            } else {
374                let runtime = tokio::runtime::Builder::new_current_thread()
375                    .enable_io()
376                    .build()
377                    .unwrap();
378                Self {
379                    handle: None,
380                    runtime: Some(runtime),
381                }
382            }
383        }
384    }
385}