1use futures_core::Stream;
13use futures_util::StreamExt;
14use std::future::Future;
15use std::pin::Pin;
16
/// Abstraction over a runtime that can run a future to completion from
/// non-async code.
///
/// An implementor provides a way to construct itself ([`BlockOn::get_runtime`])
/// and a way to obtain the output of a future synchronously
/// ([`BlockOn::block_on`]).
pub trait BlockOn {
    /// Runs the future `f` and returns the value it resolves to.
    fn block_on<F: Future>(&self, f: F) -> F::Output;

    /// Builds a new instance of the implementing runtime type.
    fn get_runtime() -> Self;
}
32
33#[cfg(feature = "tokio")]
91pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
92 self::implementation::AsyncConnectionWrapper<C, B>;
93
94#[cfg(not(feature = "tokio"))]
100pub use self::implementation::AsyncConnectionWrapper;
101
102pub(crate) mod implementation {
103 use diesel::connection::{CacheSize, Instrumentation, SimpleConnection};
104 use std::ops::{Deref, DerefMut};
105
106 use super::*;
107
/// Wrapper that stores a connection of type `C` together with a runtime
/// value of type `B`.
///
/// The trait impls in this module use `runtime` to block on the futures
/// produced by `inner`.
pub struct AsyncConnectionWrapper<C, B> {
    // The wrapped connection.
    inner: C,
    // The runtime value used to block on futures returned by `inner`.
    runtime: B,
}
112
113 impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
114 where
115 C: crate::AsyncConnection,
116 B: BlockOn + Send,
117 {
118 fn from(inner: C) -> Self {
119 Self {
120 inner,
121 runtime: B::get_runtime(),
122 }
123 }
124 }
125
126 impl<C, B> AsyncConnectionWrapper<C, B>
127 where
128 C: crate::AsyncConnection,
129 {
130 pub fn into_inner(self) -> C {
133 self.inner
134 }
135 }
136
137 impl<C, B> Deref for AsyncConnectionWrapper<C, B> {
138 type Target = C;
139
140 fn deref(&self) -> &Self::Target {
141 &self.inner
142 }
143 }
144
145 impl<C, B> DerefMut for AsyncConnectionWrapper<C, B> {
146 fn deref_mut(&mut self) -> &mut Self::Target {
147 &mut self.inner
148 }
149 }
150
151 impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
152 where
153 C: crate::SimpleAsyncConnection,
154 B: BlockOn,
155 {
156 fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
157 let f = self.inner.batch_execute(query);
158 self.runtime.block_on(f)
159 }
160 }
161
// Marker impl of diesel's `ConnectionSealed` for the wrapper; it has no
// methods and places no bounds on `C` or `B`.
impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}
163
164 impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
165 where
166 C: crate::AsyncConnection,
167 B: BlockOn + Send,
168 {
169 type Backend = C::Backend;
170
171 type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;
172
173 fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
174 let runtime = B::get_runtime();
175 let f = C::establish(database_url);
176 let inner = runtime.block_on(f)?;
177 Ok(Self { inner, runtime })
178 }
179
180 fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
181 where
182 T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
183 {
184 let f = self.inner.execute_returning_count(source);
185 self.runtime.block_on(f)
186 }
187
188 fn transaction_state(
189 &mut self,
190 ) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
191 self.inner.transaction_state()
192 }
193
194 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
195 self.inner.instrumentation()
196 }
197
198 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
199 self.inner.set_instrumentation(instrumentation);
200 }
201
202 fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
203 self.inner.set_prepared_statement_cache_size(size)
204 }
205 }
206
207 impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
208 where
209 C: crate::AsyncConnection,
210 B: BlockOn + Send,
211 {
212 type Cursor<'conn, 'query>
213 = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
214 where
215 Self: 'conn;
216
217 type Row<'conn, 'query>
218 = C::Row<'conn, 'query>
219 where
220 Self: 'conn;
221
222 fn load<'conn, 'query, T>(
223 &'conn mut self,
224 source: T,
225 ) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
226 where
227 T: diesel::query_builder::Query
228 + diesel::query_builder::QueryFragment<Self::Backend>
229 + diesel::query_builder::QueryId
230 + 'query,
231 Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
232 {
233 let f = self.inner.load(source);
234 let stream = self.runtime.block_on(f)?;
235
236 Ok(AsyncCursorWrapper {
237 stream: Box::pin(stream),
238 runtime: &self.runtime,
239 })
240 }
241 }
242
/// Cursor type that holds a heap-pinned value of type `S` together with a
/// borrowed runtime of type `B`.
pub struct AsyncCursorWrapper<'a, S, B> {
    // The pinned, boxed stream the items are pulled from.
    stream: Pin<Box<S>>,
    // Runtime borrowed for lifetime `'a`, used to block on the stream.
    runtime: &'a B,
}
247
248 impl<S, B> Iterator for AsyncCursorWrapper<'_, S, B>
249 where
250 S: Stream,
251 B: BlockOn,
252 {
253 type Item = S::Item;
254
255 fn next(&mut self) -> Option<Self::Item> {
256 let f = self.stream.next();
257 self.runtime.block_on(f)
258 }
259 }
260
/// Unit type used as the `TransactionManager` associated type in the
/// `diesel::connection::Connection` impl of `AsyncConnectionWrapper`.
pub struct AsyncConnectionWrapperTransactionManagerWrapper;
262
263 impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
264 for AsyncConnectionWrapperTransactionManagerWrapper
265 where
266 C: crate::AsyncConnection,
267 B: BlockOn + Send,
268 {
269 type TransactionStateData =
270 <C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;
271
272 fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
273 let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
274 &mut conn.inner,
275 );
276 conn.runtime.block_on(f)
277 }
278
279 fn rollback_transaction(
280 conn: &mut AsyncConnectionWrapper<C, B>,
281 ) -> diesel::QueryResult<()> {
282 let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
283 &mut conn.inner,
284 );
285 conn.runtime.block_on(f)
286 }
287
288 fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
289 let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
290 &mut conn.inner,
291 );
292 conn.runtime.block_on(f)
293 }
294
295 fn transaction_manager_status_mut(
296 conn: &mut AsyncConnectionWrapper<C, B>,
297 ) -> &mut diesel::connection::TransactionManagerStatus {
298 <C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
299 &mut conn.inner,
300 )
301 }
302
303 fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
304 <C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
305 &mut conn.inner,
306 )
307 }
308 }
309
/// Lets the wrapper be used as an r2d2 connection by delegating to the inner
/// connection's `PoolableConnection` implementation.
#[cfg(feature = "r2d2")]
impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
where
    B: BlockOn,
    Self: diesel::Connection,
    C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
        + crate::pooled_connection::PoolableConnection
        + 'static,
    diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
        crate::methods::ExecuteDsl<C>,
    diesel::query_builder::SqlQuery: crate::methods::ExecuteDsl<C>,
{
    /// Blocks on the inner connection's `ping` future, requesting the
    /// `RecyclingMethod::Verified` check.
    fn ping(&mut self) -> diesel::QueryResult<()> {
        self.runtime
            .block_on(crate::pooled_connection::PoolableConnection::ping(
                &mut self.inner,
                &crate::pooled_connection::RecyclingMethod::Verified,
            ))
    }

    /// Reports whether the inner connection considers itself broken.
    fn is_broken(&mut self) -> bool {
        let inner = &mut self.inner;
        crate::pooled_connection::PoolableConnection::is_broken(inner)
    }
}
334
335 impl<C, B> diesel::migration::MigrationConnection for AsyncConnectionWrapper<C, B>
336 where
337 B: BlockOn,
338 Self: diesel::Connection,
339 {
340 fn setup(&mut self) -> diesel::QueryResult<usize> {
341 self.batch_execute(diesel::migration::CREATE_MIGRATIONS_TABLE)
342 .map(|()| 0)
343 }
344 }
345
/// [`BlockOn`] implementation backed by Tokio.
///
/// It either borrows the Tokio runtime that was current when it was created
/// (through its handle) or owns a runtime it built itself.
#[cfg(feature = "tokio")]
pub struct Tokio {
    // Exactly one source of execution is held at any time; keeping it in an
    // enum makes the "neither"/"both" states unrepresentable.
    executor: TokioExecutor,
}

/// The two ways a [`Tokio`] value can run futures.
#[cfg(feature = "tokio")]
enum TokioExecutor {
    /// Handle to a runtime that already existed when `get_runtime` was called.
    Ambient(tokio::runtime::Handle),
    /// Runtime built and owned by the [`Tokio`] value itself.
    Owned(tokio::runtime::Runtime),
}

#[cfg(feature = "tokio")]
impl BlockOn for Tokio {
    /// Runs `f` to completion on the held handle or the owned runtime and
    /// returns its output.
    ///
    /// NOTE(review): Tokio documents that `block_on` panics when invoked from
    /// within an asynchronous execution context — confirm callers only use
    /// this from blocking code.
    fn block_on<F>(&self, f: F) -> F::Output
    where
        F: Future,
    {
        match &self.executor {
            TokioExecutor::Ambient(handle) => handle.block_on(f),
            TokioExecutor::Owned(runtime) => runtime.block_on(f),
        }
    }

    /// Uses the handle of the current Tokio runtime if
    /// `Handle::try_current()` succeeds; otherwise builds a new
    /// current-thread runtime with the I/O driver enabled.
    ///
    /// # Panics
    ///
    /// Panics if no runtime is current and building a new one fails.
    fn get_runtime() -> Self {
        let executor = match tokio::runtime::Handle::try_current() {
            Ok(handle) => TokioExecutor::Ambient(handle),
            Err(_) => TokioExecutor::Owned(
                tokio::runtime::Builder::new_current_thread()
                    .enable_io()
                    .build()
                    .expect("failed to build a current-thread Tokio runtime"),
            ),
        };
        Self { executor }
    }
}
385}