diesel_async/sync_connection_wrapper/
mod.rs1use futures_core::future::BoxFuture;
10use std::error::Error;
11
12#[cfg(feature = "sqlite")]
13mod sqlite;
14
15pub trait SpawnBlocking {
20 fn spawn_blocking<'a, R>(
24 &mut self,
25 task: impl FnOnce() -> R + Send + 'static,
26 ) -> BoxFuture<'a, Result<R, Box<dyn Error + Send + Sync + 'static>>>
27 where
28 R: Send + 'static;
29
30 fn get_runtime() -> Self;
33}
34
35#[cfg(feature = "tokio")]
74pub type SyncConnectionWrapper<C, B = self::implementation::Tokio> =
75 self::implementation::SyncConnectionWrapper<C, B>;
76
77#[cfg(not(feature = "tokio"))]
87pub use self::implementation::SyncConnectionWrapper;
88
89pub use self::implementation::SyncTransactionManagerWrapper;
90
91mod implementation {
92 use crate::{AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection, TransactionManager};
93 use diesel::backend::{Backend, DieselReserveSpecialization};
94 use diesel::connection::{CacheSize, Instrumentation};
95 use diesel::connection::{
96 Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup,
97 };
98 use diesel::query_builder::{
99 AsQuery, CollectedQuery, MoveableBindCollector, QueryBuilder, QueryFragment, QueryId,
100 };
101 use diesel::row::IntoOwnedRow;
102 use diesel::{ConnectionResult, QueryResult};
103 use futures_core::stream::BoxStream;
104 use futures_util::{FutureExt, StreamExt, TryFutureExt};
105 use std::marker::PhantomData;
106 use std::sync::{Arc, Mutex};
107
108 use super::*;
109
110 fn from_spawn_blocking_error(
111 error: Box<dyn Error + Send + Sync + 'static>,
112 ) -> diesel::result::Error {
113 diesel::result::Error::DatabaseError(
114 diesel::result::DatabaseErrorKind::UnableToSendCommand,
115 Box::new(error.to_string()),
116 )
117 }
118
/// Async adapter around a synchronous diesel connection `C`.
///
/// The connection lives behind `Arc<Mutex<_>>` so that a clone of the handle can
/// be moved into every blocking task spawned on `runtime`.
pub struct SyncConnectionWrapper<C, S> {
    /// The wrapped synchronous connection.
    inner: Arc<Mutex<C>>,
    /// Runtime used to execute the blocking work.
    runtime: S,
}
123
124 impl<C, S> SimpleAsyncConnection for SyncConnectionWrapper<C, S>
125 where
126 C: diesel::connection::Connection + 'static,
127 S: SpawnBlocking + Send,
128 {
129 async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
130 let query = query.to_string();
131 self.spawn_blocking(move |inner| inner.batch_execute(query.as_str()))
132 .await
133 }
134 }
135
136 impl<C, S, MD, O> AsyncConnectionCore for SyncConnectionWrapper<C, S>
137 where
138 <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
140 <C::Backend as Backend>::QueryBuilder: std::default::Default,
141 C: Connection + LoadConnection + WithMetadataLookup + 'static,
143 <C as Connection>::TransactionManager: Send,
144 MD: Send + 'static,
146 for<'a> <C::Backend as Backend>::BindCollector<'a>:
147 MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
148 O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>,
150 for<'conn, 'query> <C as LoadConnection>::Row<'conn, 'query>:
151 IntoOwnedRow<'conn, <C as Connection>::Backend, OwnedRow = O>,
152 S: SpawnBlocking + Send,
154 {
155 type LoadFuture<'conn, 'query> =
156 BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
157 type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
158 type Stream<'conn, 'query> = BoxStream<'static, QueryResult<Self::Row<'conn, 'query>>>;
159 type Row<'conn, 'query> = O;
160 type Backend = <C as Connection>::Backend;
161
162 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
163 where
164 T: AsQuery + 'query,
165 T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
166 {
167 self.execute_with_prepared_query(source.as_query(), |conn, query| {
168 use diesel::row::IntoOwnedRow;
169 let mut cache = <<<C as LoadConnection>::Row<'_, '_> as IntoOwnedRow<
170 <C as Connection>::Backend,
171 >>::Cache as Default>::default();
172 let cursor = conn.load(&query)?;
173
174 let size_hint = cursor.size_hint();
175 let mut out = Vec::with_capacity(size_hint.1.unwrap_or(size_hint.0));
176 for row in cursor {
179 out.push(Ok(IntoOwnedRow::into_owned(row?, &mut cache)));
180 }
181
182 Ok(out)
183 })
184 .map_ok(|rows| futures_util::stream::iter(rows).boxed())
185 .boxed()
186 }
187
188 fn execute_returning_count<'query, T>(
189 &mut self,
190 source: T,
191 ) -> Self::ExecuteFuture<'_, 'query>
192 where
193 T: QueryFragment<Self::Backend> + QueryId,
194 {
195 self.execute_with_prepared_query(source, |conn, query| {
196 conn.execute_returning_count(&query)
197 })
198 }
199 }
200
201 impl<C, S, MD, O> AsyncConnection for SyncConnectionWrapper<C, S>
202 where
203 <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
205 <C::Backend as Backend>::QueryBuilder: std::default::Default,
206 C: Connection + LoadConnection + WithMetadataLookup + 'static,
208 <C as Connection>::TransactionManager: Send,
209 MD: Send + 'static,
211 for<'a> <C::Backend as Backend>::BindCollector<'a>:
212 MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
213 O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>,
215 for<'conn, 'query> <C as LoadConnection>::Row<'conn, 'query>:
216 IntoOwnedRow<'conn, <C as Connection>::Backend, OwnedRow = O>,
217 S: SpawnBlocking + Send,
219 {
220 type TransactionManager =
221 SyncTransactionManagerWrapper<<C as Connection>::TransactionManager>;
222
223 async fn establish(database_url: &str) -> ConnectionResult<Self> {
224 let database_url = database_url.to_string();
225 let mut runtime = S::get_runtime();
226
227 runtime
228 .spawn_blocking(move || C::establish(&database_url))
229 .await
230 .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string())))
231 .map(move |c| SyncConnectionWrapper::with_runtime(c, runtime))
232 }
233
234 fn transaction_state(
235 &mut self,
236 ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData
237 {
238 self.exclusive_connection().transaction_state()
239 }
240
241 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
242 if let Some(inner) = Arc::get_mut(&mut self.inner) {
246 inner
247 .get_mut()
248 .unwrap_or_else(|p| p.into_inner())
249 .instrumentation()
250 } else {
251 panic!("Cannot access shared instrumentation")
252 }
253 }
254
255 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
256 if let Some(inner) = Arc::get_mut(&mut self.inner) {
260 inner
261 .get_mut()
262 .unwrap_or_else(|p| p.into_inner())
263 .set_instrumentation(instrumentation)
264 } else {
265 panic!("Cannot access shared instrumentation")
266 }
267 }
268
269 fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
270 if let Some(inner) = Arc::get_mut(&mut self.inner) {
274 inner
275 .get_mut()
276 .unwrap_or_else(|p| p.into_inner())
277 .set_prepared_statement_cache_size(size)
278 } else {
279 panic!("Cannot access shared cache")
280 }
281 }
282 }
283
/// Transaction manager for [`SyncConnectionWrapper`].
///
/// It delegates every transaction step to the wrapped connection's diesel
/// transaction manager `T`, running each step on a blocking thread. The type is
/// never instantiated; it only carries `T` at the type level.
pub struct SyncTransactionManagerWrapper<T>(PhantomData<T>);
287 impl<T, C, S> TransactionManager<SyncConnectionWrapper<C, S>> for SyncTransactionManagerWrapper<T>
288 where
289 SyncConnectionWrapper<C, S>: AsyncConnection,
290 C: Connection + 'static,
291 S: SpawnBlocking,
292 T: diesel::connection::TransactionManager<C> + Send,
293 {
294 type TransactionStateData = T::TransactionStateData;
295
296 async fn begin_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
297 conn.spawn_blocking(move |inner| T::begin_transaction(inner))
298 .await
299 }
300
301 async fn commit_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
302 conn.spawn_blocking(move |inner| T::commit_transaction(inner))
303 .await
304 }
305
306 async fn rollback_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
307 conn.spawn_blocking(move |inner| T::rollback_transaction(inner))
308 .await
309 }
310
311 fn transaction_manager_status_mut(
312 conn: &mut SyncConnectionWrapper<C, S>,
313 ) -> &mut TransactionManagerStatus {
314 T::transaction_manager_status_mut(conn.exclusive_connection())
315 }
316 }
317
318 impl<C, S> SyncConnectionWrapper<C, S> {
319 pub fn new(connection: C) -> Self
321 where
322 C: Connection,
323 S: SpawnBlocking,
324 {
325 SyncConnectionWrapper {
326 inner: Arc::new(Mutex::new(connection)),
327 runtime: S::get_runtime(),
328 }
329 }
330
331 pub fn with_runtime(connection: C, runtime: S) -> Self
334 where
335 C: Connection,
336 S: SpawnBlocking,
337 {
338 SyncConnectionWrapper {
339 inner: Arc::new(Mutex::new(connection)),
340 runtime,
341 }
342 }
343
344 pub fn spawn_blocking<'a, R>(
371 &mut self,
372 task: impl FnOnce(&mut C) -> QueryResult<R> + Send + 'static,
373 ) -> BoxFuture<'a, QueryResult<R>>
374 where
375 C: Connection + 'static,
376 R: Send + 'static,
377 S: SpawnBlocking,
378 {
379 let inner = self.inner.clone();
380 self.runtime
381 .spawn_blocking(move || {
382 let mut inner = inner.lock().unwrap_or_else(|poison| {
383 inner.clear_poison();
385 poison.into_inner()
386 });
387 task(&mut inner)
388 })
389 .unwrap_or_else(|err| QueryResult::Err(from_spawn_blocking_error(err)))
390 .boxed()
391 }
392
393 fn execute_with_prepared_query<'a, MD, Q, R>(
394 &mut self,
395 query: Q,
396 callback: impl FnOnce(&mut C, &CollectedQuery<MD>) -> QueryResult<R> + Send + 'static,
397 ) -> BoxFuture<'a, QueryResult<R>>
398 where
399 <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
401 <C::Backend as Backend>::QueryBuilder: std::default::Default,
402 C: Connection + LoadConnection + WithMetadataLookup + 'static,
404 <C as Connection>::TransactionManager: Send,
405 MD: Send + 'static,
407 for<'b> <C::Backend as Backend>::BindCollector<'b>:
408 MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
409 Q: QueryFragment<C::Backend> + QueryId,
411 R: Send + 'static,
412 S: SpawnBlocking,
414 {
415 let backend = C::Backend::default();
416
417 let (collect_bind_result, collector_data) = {
418 let exclusive = self.inner.clone();
419 let mut inner = exclusive.lock().unwrap_or_else(|poison| {
420 exclusive.clear_poison();
422 poison.into_inner()
423 });
424 let mut bind_collector =
425 <<C::Backend as Backend>::BindCollector<'_> as Default>::default();
426 let metadata_lookup = inner.metadata_lookup();
427 let result = query.collect_binds(&mut bind_collector, metadata_lookup, &backend);
428 let collector_data = bind_collector.moveable();
429
430 (result, collector_data)
431 };
432
433 let mut query_builder = <<C::Backend as Backend>::QueryBuilder as Default>::default();
434 let sql = query
435 .to_sql(&mut query_builder, &backend)
436 .map(|_| query_builder.finish());
437 let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&backend);
438
439 self.spawn_blocking(|inner| {
440 collect_bind_result?;
441 let query = CollectedQuery::new(sql?, is_safe_to_cache_prepared?, collector_data);
442 callback(inner, &query)
443 })
444 }
445
446 pub(self) fn exclusive_connection(&mut self) -> &mut C
451 where
452 C: Connection,
453 {
454 if let Some(conn_mutex) = Arc::get_mut(&mut self.inner) {
458 conn_mutex
459 .get_mut()
460 .expect("Mutex is poisoned, a thread must have panicked holding it.")
461 } else {
462 panic!("Cannot access shared transaction state")
463 }
464 }
465 }
466
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl<C, S> crate::pooled_connection::PoolableConnection for SyncConnectionWrapper<C, S>
where
    Self: AsyncConnection,
{
    /// Reports the connection as broken whenever its transaction manager does.
    fn is_broken(&mut self) -> bool {
        <Self::TransactionManager as TransactionManager<Self>>::is_broken_transaction_manager(self)
    }
}
481
/// Tokio-backed [`SpawnBlocking`] runtime.
#[cfg(feature = "tokio")]
pub enum Tokio {
    /// Handle to an existing tokio runtime.
    Handle(tokio::runtime::Handle),
    /// A tokio runtime owned by this value.
    ///
    /// `get_runtime` creates one when no runtime is running.
    Runtime(tokio::runtime::Runtime),
}
487
#[cfg(feature = "tokio")]
impl SpawnBlocking for Tokio {
    /// Spawns `task` on the blocking thread pool of the wrapped runtime.
    ///
    /// If the task panics or is cancelled, the future yields the boxed tokio
    /// `JoinError`.
    fn spawn_blocking<'a, R>(
        &mut self,
        task: impl FnOnce() -> R + Send + 'static,
    ) -> BoxFuture<'a, Result<R, Box<dyn Error + Send + Sync + 'static>>>
    where
        R: Send + 'static,
    {
        let join_handle = match self {
            Tokio::Handle(handle) => handle.spawn_blocking(task),
            Tokio::Runtime(runtime) => runtime.spawn_blocking(task),
        };

        join_handle.map_err(Box::from).boxed()
    }

    /// Reuses the runtime of the current context if there is one.
    ///
    /// Otherwise builds a new current-thread runtime and takes ownership of it.
    ///
    /// # Panics
    ///
    /// Panics if no runtime is running and building a new one fails. This method
    /// cannot return an error.
    fn get_runtime() -> Self {
        match tokio::runtime::Handle::try_current() {
            Ok(handle) => Tokio::Handle(handle),
            Err(_) => Tokio::Runtime(
                tokio::runtime::Builder::new_current_thread()
                    .build()
                    .expect("failed to build a tokio runtime for SyncConnectionWrapper"),
            ),
        }
    }
}
517}