bb8_diesel/
lib.rs

//! bb8-diesel allows the bb8 asynchronous connection pool
//! to be used underneath Diesel.
//!
//! This is currently implemented against Diesel's synchronous
//! API. New connections are established with
//! [`tokio::task::spawn_blocking`], and all other synchronous operations are
//! wrapped in [`tokio::task::block_in_place`], so they can be performed
//! safely from an asynchronous task.
7
8use async_trait::async_trait;
9use diesel::{
10    backend::UsesAnsiSavepointSyntax,
11    connection::{AnsiTransactionManager, SimpleConnection},
12    deserialize::QueryableByName,
13    query_builder::{AsQuery, QueryFragment, QueryId},
14    query_dsl::UpdateAndFetchResults,
15    r2d2::{self, ManageConnection},
16    sql_types::HasSqlType,
17    ConnectionError, ConnectionResult, QueryResult, Queryable,
18};
19use std::{
20    fmt::Debug,
21    ops::{Deref, DerefMut},
22    sync::{Arc, Mutex},
23};
24use tokio::task;
25
/// A connection manager which implements [`bb8::ManageConnection`] to
/// integrate with bb8.
///
/// It wraps diesel's [`r2d2::ConnectionManager`]. New connections are
/// established on tokio's blocking thread pool via
/// [`tokio::task::spawn_blocking`], while connection validity checks run
/// through [`tokio::task::block_in_place`].
///
/// Cloning a manager is cheap: all clones share the same underlying
/// `r2d2::ConnectionManager` through an [`Arc`].
///
/// ```no_run
/// #[macro_use]
/// extern crate diesel;
///
/// use diesel::prelude::*;
/// use diesel::pg::PgConnection;
///
/// table! {
///     users (id) {
///         id -> Integer,
///     }
/// }
///
/// #[tokio::main]
/// async fn main() {
///     use users::dsl;
///
///     // Creates a Diesel-specific connection manager for bb8.
///     let mgr = bb8_diesel::DieselConnectionManager::<PgConnection>::new("localhost:1234");
///     let pool = bb8::Pool::builder().build(mgr).await.unwrap();
///     let conn = pool.get().await.unwrap();
///
///     diesel::insert_into(dsl::users)
///         .values(dsl::id.eq(1337))
///         .execute(&*conn)
///         .unwrap();
/// }
/// ```
#[derive(Clone)]
pub struct DieselConnectionManager<T> {
    // The wrapped r2d2 manager, shared by every clone of this manager.
    // The lock is held for the whole duration of each blocking call made
    // through it, so connection establishment and validity checks are
    // serialized with respect to each other.
    inner: Arc<Mutex<r2d2::ConnectionManager<T>>>,
}
61
62impl<T: Send + 'static> DieselConnectionManager<T> {
63    pub fn new<S: Into<String>>(database_url: S) -> Self {
64        Self {
65            inner: Arc::new(Mutex::new(r2d2::ConnectionManager::new(database_url))),
66        }
67    }
68
69    async fn run_blocking<R, F>(&self, f: F) -> R
70    where
71        R: Send + 'static,
72        F: Send + 'static + FnOnce(&r2d2::ConnectionManager<T>) -> R,
73    {
74        let cloned = self.inner.clone();
75        tokio::task::spawn_blocking(move || f(&*cloned.lock().unwrap()))
76            .await
77            // Intentionally panic if the inner closure panics.
78            .unwrap()
79    }
80
81    async fn run_blocking_in_place<R, F>(&self, f: F) -> R
82    where
83        F: FnOnce(&r2d2::ConnectionManager<T>) -> R,
84    {
85        task::block_in_place(|| f(&*self.inner.lock().unwrap()))
86    }
87}
88
89#[async_trait]
90impl<T> bb8::ManageConnection for DieselConnectionManager<T>
91where
92    T: diesel::Connection + Send + 'static,
93{
94    type Connection = DieselConnection<T>;
95    type Error = <r2d2::ConnectionManager<T> as r2d2::ManageConnection>::Error;
96
97    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
98        self.run_blocking(|m| m.connect())
99            .await
100            .map(DieselConnection)
101    }
102
103    async fn is_valid(
104        &self,
105        conn: &mut bb8::PooledConnection<'_, Self>,
106    ) -> Result<(), Self::Error> {
107        self.run_blocking_in_place(|m| {
108            m.is_valid(&mut *conn)?;
109            Ok(())
110        })
111        .await
112    }
113
114    fn has_broken(&self, _: &mut Self::Connection) -> bool {
115        // Diesel returns this value internally. We have no way of calling the
116        // inner method without blocking as this method is not async, but `bb8`
117        // indicates that this method is not mandatory.
118        false
119    }
120}
121
/// An async-safe analogue of any connection that implements
/// [`diesel::Connection`].
///
/// These connections are created by [`DieselConnectionManager`].
///
/// All blocking methods within this type delegate to
/// [`tokio::task::block_in_place`]. The number of threads blocked this way is
/// still bounded, because it is limited by the connections handed out by the
/// truly asynchronous [`bb8::Pool`] owner. This type makes it easy to use
/// diesel without fear of blocking the runtime and without fear of spawning
/// too many child threads.
///
/// NOTE: per tokio's documentation, `block_in_place` panics when called from
/// a `current_thread` runtime, so these methods must be called from a
/// multi-threaded runtime.
///
/// Methods reached through `Deref`/`DerefMut` on the wrapped connection are
/// called on it directly and do *not* go through `block_in_place`.
///
/// Note that trying to construct this type via
/// [`diesel::connection::Connection::establish`] will return an error.
///
/// The only correct way to construct this type is by using a bb8 pool.
pub struct DieselConnection<C>(pub(crate) C);
138
139impl<C> Deref for DieselConnection<C> {
140    type Target = C;
141
142    fn deref(&self) -> &Self::Target {
143        &self.0
144    }
145}
146
147impl<C> DerefMut for DieselConnection<C> {
148    fn deref_mut(&mut self) -> &mut Self::Target {
149        &mut self.0
150    }
151}
152
153impl<C> SimpleConnection for DieselConnection<C>
154where
155    C: SimpleConnection,
156{
157    fn batch_execute(&self, query: &str) -> QueryResult<()> {
158        task::block_in_place(|| self.0.batch_execute(query))
159    }
160}
161
162impl<Conn, Changes, Output> UpdateAndFetchResults<Changes, Output> for DieselConnection<Conn>
163where
164    Conn: UpdateAndFetchResults<Changes, Output>,
165    Conn: diesel::Connection<TransactionManager = AnsiTransactionManager>,
166    Conn::Backend: UsesAnsiSavepointSyntax,
167{
168    fn update_and_fetch(&self, changeset: Changes) -> QueryResult<Output> {
169        task::block_in_place(|| self.0.update_and_fetch(changeset))
170    }
171}
172
173impl<C> diesel::Connection for DieselConnection<C>
174where
175    C: diesel::Connection<TransactionManager = AnsiTransactionManager>,
176    C::Backend: UsesAnsiSavepointSyntax,
177{
178    type Backend = C::Backend;
179
180    // This type is hidden in the docs so we can assume it is only called via
181    // the implemented methods below.
182    type TransactionManager = AnsiTransactionManager;
183
184    fn establish(_database_url: &str) -> ConnectionResult<Self> {
185        // This is taken from `diesel::r2d2`
186        Err(ConnectionError::BadConnection(String::from(
187            "Cannot directly establish a pooled connection",
188        )))
189    }
190
191    fn transaction<T, E, F>(&self, f: F) -> Result<T, E>
192    where
193        F: FnOnce() -> Result<T, E>,
194        E: From<diesel::result::Error>,
195    {
196        task::block_in_place(|| self.0.transaction(f))
197    }
198
199    fn begin_test_transaction(&self) -> QueryResult<()> {
200        task::block_in_place(|| self.0.begin_test_transaction())
201    }
202
203    fn test_transaction<T, E, F>(&self, f: F) -> T
204    where
205        F: FnOnce() -> Result<T, E>,
206        E: Debug,
207    {
208        task::block_in_place(|| self.0.test_transaction(f))
209    }
210
211    fn execute(&self, query: &str) -> QueryResult<usize> {
212        task::block_in_place(|| self.0.execute(query))
213    }
214
215    fn query_by_index<T, U>(&self, source: T) -> QueryResult<Vec<U>>
216    where
217        T: AsQuery,
218        T::Query: QueryFragment<Self::Backend> + QueryId,
219        Self::Backend: HasSqlType<T::SqlType>,
220        U: Queryable<T::SqlType, Self::Backend>,
221    {
222        task::block_in_place(|| self.0.query_by_index(source))
223    }
224
225    fn query_by_name<T, U>(&self, source: &T) -> QueryResult<Vec<U>>
226    where
227        T: QueryFragment<Self::Backend> + QueryId,
228        U: QueryableByName<Self::Backend>,
229    {
230        task::block_in_place(|| self.0.query_by_name(source))
231    }
232
233    fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize>
234    where
235        T: QueryFragment<Self::Backend> + QueryId,
236    {
237        task::block_in_place(|| self.0.execute_returning_count(source))
238    }
239
240    fn transaction_manager(&self) -> &Self::TransactionManager {
241        &self.0.transaction_manager()
242    }
243}