1use async_trait::async_trait;
9use diesel::{
10 backend::UsesAnsiSavepointSyntax,
11 connection::{AnsiTransactionManager, SimpleConnection},
12 deserialize::QueryableByName,
13 query_builder::{AsQuery, QueryFragment, QueryId},
14 query_dsl::UpdateAndFetchResults,
15 r2d2::{self, ManageConnection},
16 sql_types::HasSqlType,
17 ConnectionError, ConnectionResult, QueryResult, Queryable,
18};
19use std::{
20 fmt::Debug,
21 ops::{Deref, DerefMut},
22 sync::{Arc, Mutex},
23};
24use tokio::task;
25
26#[derive(Clone)]
58pub struct DieselConnectionManager<T> {
59 inner: Arc<Mutex<r2d2::ConnectionManager<T>>>,
60}
61
62impl<T: Send + 'static> DieselConnectionManager<T> {
63 pub fn new<S: Into<String>>(database_url: S) -> Self {
64 Self {
65 inner: Arc::new(Mutex::new(r2d2::ConnectionManager::new(database_url))),
66 }
67 }
68
69 async fn run_blocking<R, F>(&self, f: F) -> R
70 where
71 R: Send + 'static,
72 F: Send + 'static + FnOnce(&r2d2::ConnectionManager<T>) -> R,
73 {
74 let cloned = self.inner.clone();
75 tokio::task::spawn_blocking(move || f(&*cloned.lock().unwrap()))
76 .await
77 .unwrap()
79 }
80
81 async fn run_blocking_in_place<R, F>(&self, f: F) -> R
82 where
83 F: FnOnce(&r2d2::ConnectionManager<T>) -> R,
84 {
85 task::block_in_place(|| f(&*self.inner.lock().unwrap()))
86 }
87}
88
89#[async_trait]
90impl<T> bb8::ManageConnection for DieselConnectionManager<T>
91where
92 T: diesel::Connection + Send + 'static,
93{
94 type Connection = DieselConnection<T>;
95 type Error = <r2d2::ConnectionManager<T> as r2d2::ManageConnection>::Error;
96
97 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
98 self.run_blocking(|m| m.connect())
99 .await
100 .map(DieselConnection)
101 }
102
103 async fn is_valid(
104 &self,
105 conn: &mut bb8::PooledConnection<'_, Self>,
106 ) -> Result<(), Self::Error> {
107 self.run_blocking_in_place(|m| {
108 m.is_valid(&mut *conn)?;
109 Ok(())
110 })
111 .await
112 }
113
114 fn has_broken(&self, _: &mut Self::Connection) -> bool {
115 false
119 }
120}
121
/// Newtype wrapper around a connection of type `C`.
///
/// The wrapped connection is the single tuple field, which is visible only
/// within this crate. The trait impls in this file forward to it, running the
/// forwarded calls inside `tokio::task::block_in_place`.
pub struct DieselConnection<C>(pub(crate) C);
138
139impl<C> Deref for DieselConnection<C> {
140 type Target = C;
141
142 fn deref(&self) -> &Self::Target {
143 &self.0
144 }
145}
146
147impl<C> DerefMut for DieselConnection<C> {
148 fn deref_mut(&mut self) -> &mut Self::Target {
149 &mut self.0
150 }
151}
152
153impl<C> SimpleConnection for DieselConnection<C>
154where
155 C: SimpleConnection,
156{
157 fn batch_execute(&self, query: &str) -> QueryResult<()> {
158 task::block_in_place(|| self.0.batch_execute(query))
159 }
160}
161
162impl<Conn, Changes, Output> UpdateAndFetchResults<Changes, Output> for DieselConnection<Conn>
163where
164 Conn: UpdateAndFetchResults<Changes, Output>,
165 Conn: diesel::Connection<TransactionManager = AnsiTransactionManager>,
166 Conn::Backend: UsesAnsiSavepointSyntax,
167{
168 fn update_and_fetch(&self, changeset: Changes) -> QueryResult<Output> {
169 task::block_in_place(|| self.0.update_and_fetch(changeset))
170 }
171}
172
173impl<C> diesel::Connection for DieselConnection<C>
174where
175 C: diesel::Connection<TransactionManager = AnsiTransactionManager>,
176 C::Backend: UsesAnsiSavepointSyntax,
177{
178 type Backend = C::Backend;
179
180 type TransactionManager = AnsiTransactionManager;
183
184 fn establish(_database_url: &str) -> ConnectionResult<Self> {
185 Err(ConnectionError::BadConnection(String::from(
187 "Cannot directly establish a pooled connection",
188 )))
189 }
190
191 fn transaction<T, E, F>(&self, f: F) -> Result<T, E>
192 where
193 F: FnOnce() -> Result<T, E>,
194 E: From<diesel::result::Error>,
195 {
196 task::block_in_place(|| self.0.transaction(f))
197 }
198
199 fn begin_test_transaction(&self) -> QueryResult<()> {
200 task::block_in_place(|| self.0.begin_test_transaction())
201 }
202
203 fn test_transaction<T, E, F>(&self, f: F) -> T
204 where
205 F: FnOnce() -> Result<T, E>,
206 E: Debug,
207 {
208 task::block_in_place(|| self.0.test_transaction(f))
209 }
210
211 fn execute(&self, query: &str) -> QueryResult<usize> {
212 task::block_in_place(|| self.0.execute(query))
213 }
214
215 fn query_by_index<T, U>(&self, source: T) -> QueryResult<Vec<U>>
216 where
217 T: AsQuery,
218 T::Query: QueryFragment<Self::Backend> + QueryId,
219 Self::Backend: HasSqlType<T::SqlType>,
220 U: Queryable<T::SqlType, Self::Backend>,
221 {
222 task::block_in_place(|| self.0.query_by_index(source))
223 }
224
225 fn query_by_name<T, U>(&self, source: &T) -> QueryResult<Vec<U>>
226 where
227 T: QueryFragment<Self::Backend> + QueryId,
228 U: QueryableByName<Self::Backend>,
229 {
230 task::block_in_place(|| self.0.query_by_name(source))
231 }
232
233 fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize>
234 where
235 T: QueryFragment<Self::Backend> + QueryId,
236 {
237 task::block_in_place(|| self.0.execute_returning_count(source))
238 }
239
240 fn transaction_manager(&self) -> &Self::TransactionManager {
241 &self.0.transaction_manager()
242 }
243}