1use futures_core::Stream;
13use futures_util::StreamExt;
14use std::future::Future;
15use std::pin::Pin;
16
/// Abstraction over an executor that can synchronously drive a future to completion.
///
/// `AsyncConnectionWrapper` uses an implementation of this trait to turn every
/// asynchronous connection operation into a plain (blocking) call.
pub trait BlockOn {
    /// Drives `f` until it resolves and hands back its output.
    ///
    /// This is a synchronous method, so the future's result is only returned once
    /// the future has completed.
    fn block_on<F: Future>(&self, f: F) -> F::Output;

    /// Constructs (or acquires) the executor instance that [`BlockOn::block_on`]
    /// is later called on.
    fn get_runtime() -> Self;
}
32
33#[cfg(feature = "tokio")]
91pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
92 self::implementation::AsyncConnectionWrapper<C, B>;
93
94#[cfg(not(feature = "tokio"))]
100pub use self::implementation::AsyncConnectionWrapper;
101
102mod implementation {
103 use diesel::connection::{Instrumentation, SimpleConnection};
104 use std::ops::{Deref, DerefMut};
105
106 use super::*;
107
108 pub struct AsyncConnectionWrapper<C, B> {
109 inner: C,
110 runtime: B,
111 }
112
113 impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
114 where
115 C: crate::AsyncConnection,
116 B: BlockOn + Send,
117 {
118 fn from(inner: C) -> Self {
119 Self {
120 inner,
121 runtime: B::get_runtime(),
122 }
123 }
124 }
125
126 impl<C, B> Deref for AsyncConnectionWrapper<C, B> {
127 type Target = C;
128
129 fn deref(&self) -> &Self::Target {
130 &self.inner
131 }
132 }
133
134 impl<C, B> DerefMut for AsyncConnectionWrapper<C, B> {
135 fn deref_mut(&mut self) -> &mut Self::Target {
136 &mut self.inner
137 }
138 }
139
140 impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
141 where
142 C: crate::SimpleAsyncConnection,
143 B: BlockOn,
144 {
145 fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
146 let f = self.inner.batch_execute(query);
147 self.runtime.block_on(f)
148 }
149 }
150
151 impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}
152
153 impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
154 where
155 C: crate::AsyncConnection,
156 B: BlockOn + Send,
157 {
158 type Backend = C::Backend;
159
160 type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;
161
162 fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
163 let runtime = B::get_runtime();
164 let f = C::establish(database_url);
165 let inner = runtime.block_on(f)?;
166 Ok(Self { inner, runtime })
167 }
168
169 fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
170 where
171 T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
172 {
173 let f = self.inner.execute_returning_count(source);
174 self.runtime.block_on(f)
175 }
176
177 fn transaction_state(
178 &mut self,
179 ) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
180 self.inner.transaction_state()
181 }
182
183 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
184 self.inner.instrumentation()
185 }
186
187 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
188 self.inner.set_instrumentation(instrumentation);
189 }
190 }
191
192 impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
193 where
194 C: crate::AsyncConnection,
195 B: BlockOn + Send,
196 {
197 type Cursor<'conn, 'query>
198 = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
199 where
200 Self: 'conn;
201
202 type Row<'conn, 'query>
203 = C::Row<'conn, 'query>
204 where
205 Self: 'conn;
206
207 fn load<'conn, 'query, T>(
208 &'conn mut self,
209 source: T,
210 ) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
211 where
212 T: diesel::query_builder::Query
213 + diesel::query_builder::QueryFragment<Self::Backend>
214 + diesel::query_builder::QueryId
215 + 'query,
216 Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
217 {
218 let f = self.inner.load(source);
219 let stream = self.runtime.block_on(f)?;
220
221 Ok(AsyncCursorWrapper {
222 stream: Box::pin(stream),
223 runtime: &self.runtime,
224 })
225 }
226 }
227
228 pub struct AsyncCursorWrapper<'a, S, B> {
229 stream: Pin<Box<S>>,
230 runtime: &'a B,
231 }
232
233 impl<S, B> Iterator for AsyncCursorWrapper<'_, S, B>
234 where
235 S: Stream,
236 B: BlockOn,
237 {
238 type Item = S::Item;
239
240 fn next(&mut self) -> Option<Self::Item> {
241 let f = self.stream.next();
242 self.runtime.block_on(f)
243 }
244 }
245
246 pub struct AsyncConnectionWrapperTransactionManagerWrapper;
247
248 impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
249 for AsyncConnectionWrapperTransactionManagerWrapper
250 where
251 C: crate::AsyncConnection,
252 B: BlockOn + Send,
253 {
254 type TransactionStateData =
255 <C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;
256
257 fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
258 let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
259 &mut conn.inner,
260 );
261 conn.runtime.block_on(f)
262 }
263
264 fn rollback_transaction(
265 conn: &mut AsyncConnectionWrapper<C, B>,
266 ) -> diesel::QueryResult<()> {
267 let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
268 &mut conn.inner,
269 );
270 conn.runtime.block_on(f)
271 }
272
273 fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
274 let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
275 &mut conn.inner,
276 );
277 conn.runtime.block_on(f)
278 }
279
280 fn transaction_manager_status_mut(
281 conn: &mut AsyncConnectionWrapper<C, B>,
282 ) -> &mut diesel::connection::TransactionManagerStatus {
283 <C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
284 &mut conn.inner,
285 )
286 }
287
288 fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
289 <C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
290 &mut conn.inner,
291 )
292 }
293 }
294
295 #[cfg(feature = "r2d2")]
296 impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
297 where
298 B: BlockOn,
299 Self: diesel::Connection,
300 C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
301 + crate::pooled_connection::PoolableConnection
302 + 'static,
303 diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
304 crate::methods::ExecuteDsl<C>,
305 diesel::query_builder::SqlQuery: crate::methods::ExecuteDsl<C>,
306 {
307 fn ping(&mut self) -> diesel::QueryResult<()> {
308 let fut = crate::pooled_connection::PoolableConnection::ping(
309 &mut self.inner,
310 &crate::pooled_connection::RecyclingMethod::Verified,
311 );
312 self.runtime.block_on(fut)
313 }
314
315 fn is_broken(&mut self) -> bool {
316 crate::pooled_connection::PoolableConnection::is_broken(&mut self.inner)
317 }
318 }
319
320 impl<C, B> diesel::migration::MigrationConnection for AsyncConnectionWrapper<C, B>
321 where
322 B: BlockOn,
323 Self: diesel::Connection,
324 {
325 fn setup(&mut self) -> diesel::QueryResult<usize> {
326 self.batch_execute(diesel::migration::CREATE_MIGRATIONS_TABLE)
327 .map(|()| 0)
328 }
329 }
330
331 #[cfg(feature = "tokio")]
332 pub struct Tokio {
333 handle: Option<tokio::runtime::Handle>,
334 runtime: Option<tokio::runtime::Runtime>,
335 }
336
337 #[cfg(feature = "tokio")]
338 impl BlockOn for Tokio {
339 fn block_on<F>(&self, f: F) -> F::Output
340 where
341 F: Future,
342 {
343 if let Some(handle) = &self.handle {
344 handle.block_on(f)
345 } else if let Some(runtime) = &self.runtime {
346 runtime.block_on(f)
347 } else {
348 unreachable!()
349 }
350 }
351
352 fn get_runtime() -> Self {
353 if let Ok(handle) = tokio::runtime::Handle::try_current() {
354 Self {
355 handle: Some(handle),
356 runtime: None,
357 }
358 } else {
359 let runtime = tokio::runtime::Builder::new_current_thread()
360 .enable_io()
361 .build()
362 .unwrap();
363 Self {
364 handle: None,
365 runtime: Some(runtime),
366 }
367 }
368 }
369 }
370}