use futures_core::Stream;
use futures_util::StreamExt;
use std::future::Future;
use std::pin::Pin;
pub trait BlockOn {
fn block_on<F>(&self, f: F) -> F::Output
where
F: Future;
fn get_runtime() -> Self;
}
#[cfg(feature = "tokio")]
pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
self::implementation::AsyncConnectionWrapper<C, B>;
#[cfg(not(feature = "tokio"))]
pub use self::implementation::AsyncConnectionWrapper;
pub(crate) mod implementation {
use diesel::connection::{CacheSize, Instrumentation, SimpleConnection};
use std::ops::{Deref, DerefMut};
use super::*;
pub struct AsyncConnectionWrapper<C, B> {
inner: C,
runtime: B,
}
impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
where
C: crate::AsyncConnection,
B: BlockOn + Send,
{
fn from(inner: C) -> Self {
Self {
inner,
runtime: B::get_runtime(),
}
}
}
impl<C, B> AsyncConnectionWrapper<C, B>
where
C: crate::AsyncConnection,
{
pub fn into_inner(self) -> C {
self.inner
}
}
impl<C, B> Deref for AsyncConnectionWrapper<C, B> {
type Target = C;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<C, B> DerefMut for AsyncConnectionWrapper<C, B> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
where
C: crate::SimpleAsyncConnection,
B: BlockOn,
{
fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
let f = self.inner.batch_execute(query);
self.runtime.block_on(f)
}
}
impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}
impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
where
C: crate::AsyncConnection,
B: BlockOn + Send,
{
type Backend = C::Backend;
type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;
fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
let runtime = B::get_runtime();
let f = C::establish(database_url);
let inner = runtime.block_on(f)?;
Ok(Self { inner, runtime })
}
fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
where
T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
{
let f = self.inner.execute_returning_count(source);
self.runtime.block_on(f)
}
fn transaction_state(
&mut self,
) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
self.inner.transaction_state()
}
fn instrumentation(&mut self) -> &mut dyn Instrumentation {
self.inner.instrumentation()
}
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
self.inner.set_instrumentation(instrumentation);
}
fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
self.inner.set_prepared_statement_cache_size(size)
}
}
impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
where
C: crate::AsyncConnection,
B: BlockOn + Send,
{
type Cursor<'conn, 'query>
= AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
where
Self: 'conn;
type Row<'conn, 'query>
= C::Row<'conn, 'query>
where
Self: 'conn;
fn load<'conn, 'query, T>(
&'conn mut self,
source: T,
) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
where
T: diesel::query_builder::Query
+ diesel::query_builder::QueryFragment<Self::Backend>
+ diesel::query_builder::QueryId
+ 'query,
Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
{
let f = self.inner.load(source);
let stream = self.runtime.block_on(f)?;
Ok(AsyncCursorWrapper {
stream: Box::pin(stream),
runtime: &self.runtime,
})
}
}
pub struct AsyncCursorWrapper<'a, S, B> {
stream: Pin<Box<S>>,
runtime: &'a B,
}
impl<S, B> Iterator for AsyncCursorWrapper<'_, S, B>
where
S: Stream,
B: BlockOn,
{
type Item = S::Item;
fn next(&mut self) -> Option<Self::Item> {
let f = self.stream.next();
self.runtime.block_on(f)
}
}
pub struct AsyncConnectionWrapperTransactionManagerWrapper;
impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
for AsyncConnectionWrapperTransactionManagerWrapper
where
C: crate::AsyncConnection,
B: BlockOn + Send,
{
type TransactionStateData =
<C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;
fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
&mut conn.inner,
);
conn.runtime.block_on(f)
}
fn rollback_transaction(
conn: &mut AsyncConnectionWrapper<C, B>,
) -> diesel::QueryResult<()> {
let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
&mut conn.inner,
);
conn.runtime.block_on(f)
}
fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
&mut conn.inner,
);
conn.runtime.block_on(f)
}
fn transaction_manager_status_mut(
conn: &mut AsyncConnectionWrapper<C, B>,
) -> &mut diesel::connection::TransactionManagerStatus {
<C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
&mut conn.inner,
)
}
fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
<C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
&mut conn.inner,
)
}
}
#[cfg(feature = "r2d2")]
impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
where
B: BlockOn,
Self: diesel::Connection,
C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
+ crate::pooled_connection::PoolableConnection
+ 'static,
diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
crate::methods::ExecuteDsl<C>,
diesel::query_builder::SqlQuery: crate::methods::ExecuteDsl<C>,
{
fn ping(&mut self) -> diesel::QueryResult<()> {
let fut = crate::pooled_connection::PoolableConnection::ping(
&mut self.inner,
&crate::pooled_connection::RecyclingMethod::Verified,
);
self.runtime.block_on(fut)
}
fn is_broken(&mut self) -> bool {
crate::pooled_connection::PoolableConnection::is_broken(&mut self.inner)
}
}
impl<C, B> diesel::migration::MigrationConnection for AsyncConnectionWrapper<C, B>
where
B: BlockOn,
Self: diesel::Connection,
{
fn setup(&mut self) -> diesel::QueryResult<usize> {
self.batch_execute(diesel::migration::CREATE_MIGRATIONS_TABLE)
.map(|()| 0)
}
}
#[cfg(feature = "tokio")]
pub struct Tokio {
handle: Option<tokio::runtime::Handle>,
runtime: Option<tokio::runtime::Runtime>,
}
#[cfg(feature = "tokio")]
impl BlockOn for Tokio {
fn block_on<F>(&self, f: F) -> F::Output
where
F: Future,
{
if let Some(handle) = &self.handle {
handle.block_on(f)
} else if let Some(runtime) = &self.runtime {
runtime.block_on(f)
} else {
unreachable!()
}
}
fn get_runtime() -> Self {
if let Ok(handle) = tokio::runtime::Handle::try_current() {
Self {
handle: Some(handle),
runtime: None,
}
} else {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.build()
.unwrap();
Self {
handle: None,
runtime: Some(runtime),
}
}
}
}
}