diesel_async/sync_connection_wrapper/
mod.rs

//! This module contains a wrapper type
//! that provides a [`crate::AsyncConnection`]
//! implementation for types that implement
//! [`diesel::Connection`]. Using this type
//! might be useful for the following use cases:
//!
//! * using a sync Connection implementation in async context
//! * using the same code base for async crates needing multiple backends
9use futures_core::future::BoxFuture;
10use std::error::Error;
11
// Submodule compiled only when the `sqlite` feature is enabled.
// NOTE(review): presumably holds the SQLite-specific parts of the wrapper
// (see sqlite.rs) — confirm there.
#[cfg(feature = "sqlite")]
mod sqlite;
14
15/// This is a helper trait that allows to customize the
16/// spawning blocking tasks as part of the
17/// [`SyncConnectionWrapper`] type. By default a
18/// tokio runtime and its spawn_blocking function is used.
19pub trait SpawnBlocking {
20    /// This function should allow to execute a
21    /// given blocking task without blocking the caller
22    /// to get the result
23    fn spawn_blocking<'a, R>(
24        &mut self,
25        task: impl FnOnce() -> R + Send + 'static,
26    ) -> BoxFuture<'a, Result<R, Box<dyn Error + Send + Sync + 'static>>>
27    where
28        R: Send + 'static;
29
30    /// This function should be used to construct
31    /// a new runtime instance
32    fn get_runtime() -> Self;
33}
34
/// A wrapper of a [`diesel::connection::Connection`] usable in async context.
///
/// It implements [`crate::AsyncConnection`] if the wrapped
/// [`diesel::connection::Connection`] fulfills these requirements:
/// * it's a [`diesel::connection::LoadConnection`]
/// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`]
/// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`]
///
/// With the default runtime parameter, every request on the inner connection
/// is executed through tokio's `spawn_blocking`, which implies a dependency
/// on tokio. The handle of the current tokio runtime is used when one is
/// available; otherwise a new current-thread tokio runtime is created and
/// owned by the wrapper.
///
/// Note that only SQLite is supported at the moment.
///
/// # Examples
///
/// ```rust
/// # include!("../doctest_setup.rs");
/// use diesel_async::RunQueryDsl;
/// use schema::users;
///
/// async fn some_async_fn() {
/// # let database_url = database_url();
///     use diesel_async::AsyncConnection;
///     use diesel::sqlite::SqliteConnection;
///     let mut conn =
///         SyncConnectionWrapper::<SqliteConnection>::establish(&database_url).await.unwrap();
/// # create_tables(&mut conn).await;
///
///     let all_users = users::table.load::<(i32, String)>(&mut conn).await.unwrap();
/// #     assert_eq!(all_users.len(), 2);
/// }
///
/// # #[cfg(feature = "sqlite")]
/// # #[tokio::main]
/// # async fn main() {
/// #    some_async_fn().await;
/// # }
/// ```
#[cfg(feature = "tokio")]
pub type SyncConnectionWrapper<C, S = self::implementation::Tokio> =
    self::implementation::SyncConnectionWrapper<C, S>;
76
/// A wrapper of a [`diesel::connection::Connection`] usable in async context.
///
/// It implements [`crate::AsyncConnection`] if the wrapped
/// [`diesel::connection::Connection`] fulfills these requirements:
/// * it's a [`diesel::connection::LoadConnection`]
/// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`]
/// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`]
///
/// Internally this wrapper type runs each request on the inner connection
/// through the `spawn_blocking` function of the given type implementing the
/// [`SpawnBlocking`] trait.
#[cfg(not(feature = "tokio"))]
pub use self::implementation::SyncConnectionWrapper;

// Re-exported because it is the `AsyncConnection::TransactionManager` type of
// `SyncConnectionWrapper`, which makes it part of the public API.
pub use self::implementation::SyncTransactionManagerWrapper;
90
91mod implementation {
92    use crate::{AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection, TransactionManager};
93    use diesel::backend::{Backend, DieselReserveSpecialization};
94    use diesel::connection::{CacheSize, Instrumentation};
95    use diesel::connection::{
96        Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup,
97    };
98    use diesel::query_builder::{
99        AsQuery, CollectedQuery, MoveableBindCollector, QueryBuilder, QueryFragment, QueryId,
100    };
101    use diesel::row::IntoOwnedRow;
102    use diesel::{ConnectionResult, QueryResult};
103    use futures_core::stream::BoxStream;
104    use futures_util::{FutureExt, StreamExt, TryFutureExt};
105    use std::marker::PhantomData;
106    use std::sync::{Arc, Mutex};
107
108    use super::*;
109
110    fn from_spawn_blocking_error(
111        error: Box<dyn Error + Send + Sync + 'static>,
112    ) -> diesel::result::Error {
113        diesel::result::Error::DatabaseError(
114            diesel::result::DatabaseErrorKind::UnableToSendCommand,
115            Box::new(error.to_string()),
116        )
117    }
118
// Shared implementation behind the public `SyncConnectionWrapper` name
// (exposed either as a type alias or as a re-export, depending on the
// `tokio` feature). `C` is the synchronous connection, `S` the
// `SpawnBlocking` runtime used to run blocking work.
pub struct SyncConnectionWrapper<C, S> {
    // The connection sits behind `Arc<Mutex<_>>` so every spawned blocking
    // task can take its own handle (see `spawn_blocking`). While such a
    // handle is alive, `Arc::get_mut` fails and exclusive access panics.
    inner: Arc<Mutex<C>>,
    // Runtime on which the blocking tasks are spawned.
    runtime: S,
}
123
124    impl<C, S> SimpleAsyncConnection for SyncConnectionWrapper<C, S>
125    where
126        C: diesel::connection::Connection + 'static,
127        S: SpawnBlocking + Send,
128    {
129        async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
130            let query = query.to_string();
131            self.spawn_blocking(move |inner| inner.batch_execute(query.as_str()))
132                .await
133        }
134    }
135
136    impl<C, S, MD, O> AsyncConnectionCore for SyncConnectionWrapper<C, S>
137    where
138        // Backend bounds
139        <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
140        <C::Backend as Backend>::QueryBuilder: std::default::Default,
141        // Connection bounds
142        C: Connection + LoadConnection + WithMetadataLookup + 'static,
143        <C as Connection>::TransactionManager: Send,
144        // BindCollector bounds
145        MD: Send + 'static,
146        for<'a> <C::Backend as Backend>::BindCollector<'a>:
147            MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
148        // Row bounds
149        O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>,
150        for<'conn, 'query> <C as LoadConnection>::Row<'conn, 'query>:
151            IntoOwnedRow<'conn, <C as Connection>::Backend, OwnedRow = O>,
152        // SpawnBlocking bounds
153        S: SpawnBlocking + Send,
154    {
155        type LoadFuture<'conn, 'query> =
156            BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
157        type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
158        type Stream<'conn, 'query> = BoxStream<'static, QueryResult<Self::Row<'conn, 'query>>>;
159        type Row<'conn, 'query> = O;
160        type Backend = <C as Connection>::Backend;
161
162        fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
163        where
164            T: AsQuery + 'query,
165            T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
166        {
167            self.execute_with_prepared_query(source.as_query(), |conn, query| {
168                use diesel::row::IntoOwnedRow;
169                let mut cache = <<<C as LoadConnection>::Row<'_, '_> as IntoOwnedRow<
170                    <C as Connection>::Backend,
171                >>::Cache as Default>::default();
172                let cursor = conn.load(&query)?;
173
174                let size_hint = cursor.size_hint();
175                let mut out = Vec::with_capacity(size_hint.1.unwrap_or(size_hint.0));
176                // we use an explicit loop here to easily propagate possible errors
177                // as early as possible
178                for row in cursor {
179                    out.push(Ok(IntoOwnedRow::into_owned(row?, &mut cache)));
180                }
181
182                Ok(out)
183            })
184            .map_ok(|rows| futures_util::stream::iter(rows).boxed())
185            .boxed()
186        }
187
188        fn execute_returning_count<'query, T>(
189            &mut self,
190            source: T,
191        ) -> Self::ExecuteFuture<'_, 'query>
192        where
193            T: QueryFragment<Self::Backend> + QueryId,
194        {
195            self.execute_with_prepared_query(source, |conn, query| {
196                conn.execute_returning_count(&query)
197            })
198        }
199    }
200
201    impl<C, S, MD, O> AsyncConnection for SyncConnectionWrapper<C, S>
202    where
203        // Backend bounds
204        <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
205        <C::Backend as Backend>::QueryBuilder: std::default::Default,
206        // Connection bounds
207        C: Connection + LoadConnection + WithMetadataLookup + 'static,
208        <C as Connection>::TransactionManager: Send,
209        // BindCollector bounds
210        MD: Send + 'static,
211        for<'a> <C::Backend as Backend>::BindCollector<'a>:
212            MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
213        // Row bounds
214        O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>,
215        for<'conn, 'query> <C as LoadConnection>::Row<'conn, 'query>:
216            IntoOwnedRow<'conn, <C as Connection>::Backend, OwnedRow = O>,
217        // SpawnBlocking bounds
218        S: SpawnBlocking + Send,
219    {
220        type TransactionManager =
221            SyncTransactionManagerWrapper<<C as Connection>::TransactionManager>;
222
223        async fn establish(database_url: &str) -> ConnectionResult<Self> {
224            let database_url = database_url.to_string();
225            let mut runtime = S::get_runtime();
226
227            runtime
228                .spawn_blocking(move || C::establish(&database_url))
229                .await
230                .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string())))
231                .map(move |c| SyncConnectionWrapper::with_runtime(c, runtime))
232        }
233
234        fn transaction_state(
235            &mut self,
236        ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData
237        {
238            self.exclusive_connection().transaction_state()
239        }
240
241        fn instrumentation(&mut self) -> &mut dyn Instrumentation {
242            // there should be no other pending future when this is called
243            // that means there is only one instance of this arc and
244            // we can simply access the inner data
245            if let Some(inner) = Arc::get_mut(&mut self.inner) {
246                inner
247                    .get_mut()
248                    .unwrap_or_else(|p| p.into_inner())
249                    .instrumentation()
250            } else {
251                panic!("Cannot access shared instrumentation")
252            }
253        }
254
255        fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
256            // there should be no other pending future when this is called
257            // that means there is only one instance of this arc and
258            // we can simply access the inner data
259            if let Some(inner) = Arc::get_mut(&mut self.inner) {
260                inner
261                    .get_mut()
262                    .unwrap_or_else(|p| p.into_inner())
263                    .set_instrumentation(instrumentation)
264            } else {
265                panic!("Cannot access shared instrumentation")
266            }
267        }
268
269        fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
270            // there should be no other pending future when this is called
271            // that means there is only one instance of this arc and
272            // we can simply access the inner data
273            if let Some(inner) = Arc::get_mut(&mut self.inner) {
274                inner
275                    .get_mut()
276                    .unwrap_or_else(|p| p.into_inner())
277                    .set_prepared_statement_cache_size(size)
278            } else {
279                panic!("Cannot access shared cache")
280            }
281        }
282    }
283
/// A wrapper of a diesel transaction manager usable in async context.
///
/// `T` is the synchronous [`diesel::connection::TransactionManager`] of the
/// wrapped connection; every transaction operation is delegated to it.
/// The wrapper carries `T` only at the type level and stores no data.
pub struct SyncTransactionManagerWrapper<T>(PhantomData<T>);
286
287    impl<T, C, S> TransactionManager<SyncConnectionWrapper<C, S>> for SyncTransactionManagerWrapper<T>
288    where
289        SyncConnectionWrapper<C, S>: AsyncConnection,
290        C: Connection + 'static,
291        S: SpawnBlocking,
292        T: diesel::connection::TransactionManager<C> + Send,
293    {
294        type TransactionStateData = T::TransactionStateData;
295
296        async fn begin_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
297            conn.spawn_blocking(move |inner| T::begin_transaction(inner))
298                .await
299        }
300
301        async fn commit_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
302            conn.spawn_blocking(move |inner| T::commit_transaction(inner))
303                .await
304        }
305
306        async fn rollback_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
307            conn.spawn_blocking(move |inner| T::rollback_transaction(inner))
308                .await
309        }
310
311        fn transaction_manager_status_mut(
312            conn: &mut SyncConnectionWrapper<C, S>,
313        ) -> &mut TransactionManagerStatus {
314            T::transaction_manager_status_mut(conn.exclusive_connection())
315        }
316    }
317
318    impl<C, S> SyncConnectionWrapper<C, S> {
319        /// Builds a wrapper with this underlying sync connection
320        pub fn new(connection: C) -> Self
321        where
322            C: Connection,
323            S: SpawnBlocking,
324        {
325            SyncConnectionWrapper {
326                inner: Arc::new(Mutex::new(connection)),
327                runtime: S::get_runtime(),
328            }
329        }
330
331        /// Builds a wrapper with this underlying sync connection
332        /// and runtime for spawning blocking tasks
333        pub fn with_runtime(connection: C, runtime: S) -> Self
334        where
335            C: Connection,
336            S: SpawnBlocking,
337        {
338            SyncConnectionWrapper {
339                inner: Arc::new(Mutex::new(connection)),
340                runtime,
341            }
342        }
343
344        /// Run a operation directly with the inner connection
345        ///
346        /// This function is usful to register custom functions
347        /// and collection for Sqlite for example
348        ///
349        /// # Example
350        ///
351        /// ```rust
352        /// # include!("../doctest_setup.rs");
353        /// # #[tokio::main]
354        /// # async fn main() {
355        /// #     run_test().await.unwrap();
356        /// # }
357        /// #
358        /// # async fn run_test() -> QueryResult<()> {
359        /// #     let mut conn = establish_connection().await;
360        /// conn.spawn_blocking(|conn| {
361        ///    // sqlite.rs sqlite NOCASE only works for ASCII characters,
362        ///    // this collation allows handling UTF-8 (barring locale differences)
363        ///    conn.register_collation("RUSTNOCASE", |rhs, lhs| {
364        ///     rhs.to_lowercase().cmp(&lhs.to_lowercase())
365        ///   })
366        /// }).await
367        ///
368        /// # }
369        /// ```
370        pub fn spawn_blocking<'a, R>(
371            &mut self,
372            task: impl FnOnce(&mut C) -> QueryResult<R> + Send + 'static,
373        ) -> BoxFuture<'a, QueryResult<R>>
374        where
375            C: Connection + 'static,
376            R: Send + 'static,
377            S: SpawnBlocking,
378        {
379            let inner = self.inner.clone();
380            self.runtime
381                .spawn_blocking(move || {
382                    let mut inner = inner.lock().unwrap_or_else(|poison| {
383                        // try to be resilient by providing the guard
384                        inner.clear_poison();
385                        poison.into_inner()
386                    });
387                    task(&mut inner)
388                })
389                .unwrap_or_else(|err| QueryResult::Err(from_spawn_blocking_error(err)))
390                .boxed()
391        }
392
393        fn execute_with_prepared_query<'a, MD, Q, R>(
394            &mut self,
395            query: Q,
396            callback: impl FnOnce(&mut C, &CollectedQuery<MD>) -> QueryResult<R> + Send + 'static,
397        ) -> BoxFuture<'a, QueryResult<R>>
398        where
399            // Backend bounds
400            <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
401            <C::Backend as Backend>::QueryBuilder: std::default::Default,
402            // Connection bounds
403            C: Connection + LoadConnection + WithMetadataLookup + 'static,
404            <C as Connection>::TransactionManager: Send,
405            // BindCollector bounds
406            MD: Send + 'static,
407            for<'b> <C::Backend as Backend>::BindCollector<'b>:
408                MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
409            // Arguments/Return bounds
410            Q: QueryFragment<C::Backend> + QueryId,
411            R: Send + 'static,
412            // SpawnBlocking bounds
413            S: SpawnBlocking,
414        {
415            let backend = C::Backend::default();
416
417            let (collect_bind_result, collector_data) = {
418                let exclusive = self.inner.clone();
419                let mut inner = exclusive.lock().unwrap_or_else(|poison| {
420                    // try to be resilient by providing the guard
421                    exclusive.clear_poison();
422                    poison.into_inner()
423                });
424                let mut bind_collector =
425                    <<C::Backend as Backend>::BindCollector<'_> as Default>::default();
426                let metadata_lookup = inner.metadata_lookup();
427                let result = query.collect_binds(&mut bind_collector, metadata_lookup, &backend);
428                let collector_data = bind_collector.moveable();
429
430                (result, collector_data)
431            };
432
433            let mut query_builder = <<C::Backend as Backend>::QueryBuilder as Default>::default();
434            let sql = query
435                .to_sql(&mut query_builder, &backend)
436                .map(|_| query_builder.finish());
437            let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&backend);
438
439            self.spawn_blocking(|inner| {
440                collect_bind_result?;
441                let query = CollectedQuery::new(sql?, is_safe_to_cache_prepared?, collector_data);
442                callback(inner, &query)
443            })
444        }
445
446        /// Gets an exclusive access to the underlying diesel Connection
447        ///
448        /// It panics in case of shared access.
449        /// This is typically used only used during transaction.
450        pub(self) fn exclusive_connection(&mut self) -> &mut C
451        where
452            C: Connection,
453        {
454            // there should be no other pending future when this is called
455            // that means there is only one instance of this Arc and
456            // we can simply access the inner data
457            if let Some(conn_mutex) = Arc::get_mut(&mut self.inner) {
458                conn_mutex
459                    .get_mut()
460                    .expect("Mutex is poisoned, a thread must have panicked holding it.")
461            } else {
462                panic!("Cannot access shared transaction state")
463            }
464        }
465    }
466
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl<C, S> crate::pooled_connection::PoolableConnection for SyncConnectionWrapper<C, S>
where
    Self: AsyncConnection,
{
    /// Reports the connection as broken exactly when its async transaction
    /// manager reports itself as broken.
    fn is_broken(&mut self) -> bool {
        <Self::TransactionManager as TransactionManager<Self>>::is_broken_transaction_manager(self)
    }
}
481
/// [`SpawnBlocking`] implementation backed by tokio.
///
/// This is the default runtime parameter of the tokio-enabled
/// `SyncConnectionWrapper` alias.
#[cfg(feature = "tokio")]
pub enum Tokio {
    /// Handle to an existing tokio runtime. `get_runtime` picks this variant
    /// when a tokio runtime is current.
    Handle(tokio::runtime::Handle),
    /// Runtime owned by this value. `get_runtime` builds a current-thread
    /// runtime for this variant when no tokio runtime is current.
    Runtime(tokio::runtime::Runtime),
}
487
#[cfg(feature = "tokio")]
impl SpawnBlocking for Tokio {
    /// Hands `task` to tokio's `spawn_blocking`, via either the stored handle
    /// or the owned runtime; a join error is returned boxed.
    fn spawn_blocking<'a, R>(
        &mut self,
        task: impl FnOnce() -> R + Send + 'static,
    ) -> BoxFuture<'a, Result<R, Box<dyn Error + Send + Sync + 'static>>>
    where
        R: Send + 'static,
    {
        let join_handle = match self {
            Self::Runtime(owned) => owned.spawn_blocking(task),
            Self::Handle(handle) => handle.spawn_blocking(task),
        };
        join_handle.map_err(Box::from).boxed()
    }

    /// Reuses the tokio runtime of the current context when there is one and
    /// otherwise builds a fresh current-thread runtime.
    ///
    /// # Panics
    ///
    /// Panics if building the fallback runtime fails.
    fn get_runtime() -> Self {
        match tokio::runtime::Handle::try_current() {
            Ok(current) => Self::Handle(current),
            Err(_) => Self::Runtime(
                tokio::runtime::Builder::new_current_thread()
                    .build()
                    .unwrap(),
            ),
        }
    }
}
517}