1use crate::statement_cache::CacheSize;
2use crate::statement_cache::{MaybeCached, QueryFragmentForCachedStatement, StatementCache};
3use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
4use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
5use diesel::connection::Instrumentation;
6use diesel::connection::InstrumentationEvent;
7use diesel::connection::StrQueryHelper;
8use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
9use diesel::query_builder::QueryBuilder;
10use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
11use diesel::result::{ConnectionError, ConnectionResult};
12use diesel::QueryResult;
13use futures_core::future::BoxFuture;
14use futures_core::stream::BoxStream;
15use futures_util::stream;
16use futures_util::{FutureExt, StreamExt, TryStreamExt};
17use mysql_async::prelude::Queryable;
18use mysql_async::{Opts, OptsBuilder, Statement};
19use std::future::Future;
20
21mod error_helper;
22mod row;
23mod serialize;
24
25use self::error_helper::ErrorHelper;
26use self::row::MysqlRow;
27use self::serialize::ToSqlHelper;
28
/// An asynchronous diesel connection to a MySQL database, backed by a
/// [`mysql_async::Conn`].
///
/// Besides the raw connection it carries the per-connection state needed by the
/// [`AsyncConnection`] implementation: a prepared statement cache, the transaction
/// manager state and an optional instrumentation hook.
pub struct AsyncMysqlConnection {
    /// The underlying `mysql_async` connection every query is sent through.
    conn: mysql_async::Conn,
    /// Cache of prepared statements, handed out by `with_prepared_statement`.
    stmt_cache: StatementCache<Mysql, Statement>,
    /// Transaction bookkeeping exposed through `transaction_state()`.
    transaction_manager: AnsiTransactionManager,
    /// Receiver of connection/query events; `None` when no instrumentation is installed.
    instrumentation: Option<Box<dyn Instrumentation>>,
}
37
38impl SimpleAsyncConnection for AsyncMysqlConnection {
39 async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
40 self.instrumentation()
41 .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
42 query,
43 )));
44 let result = self
45 .conn
46 .query_drop(query)
47 .await
48 .map_err(ErrorHelper)
49 .map_err(Into::into);
50 self.instrumentation()
51 .on_connection_event(InstrumentationEvent::finish_query(
52 &StrQueryHelper::new(query),
53 result.as_ref().err(),
54 ));
55 result
56 }
57}
58
/// Session configuration statements applied to every connection managed by this module.
///
/// They pin the session time zone to UTC (`+00:00`) and the client, connection and
/// result character sets to `utf8mb4`. `establish_connection_inner` passes them to
/// `mysql_async` as init statements, while `AsyncMysqlConnection::try_from` executes
/// them one by one on an already opened connection.
const CONNECTION_SETUP_QUERIES: &[&str] = &[
    "SET time_zone = '+00:00';",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
    "SET character_set_results = 'utf8mb4'",
];
65
66impl AsyncConnection for AsyncMysqlConnection {
67 type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
68 type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
69 type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
70 type Row<'conn, 'query> = MysqlRow;
71 type Backend = Mysql;
72
73 type TransactionManager = AnsiTransactionManager;
74
75 async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
76 let mut instrumentation = diesel::connection::get_default_instrumentation();
77 instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
78 database_url,
79 ));
80 let r = Self::establish_connection_inner(database_url).await;
81 instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
82 database_url,
83 r.as_ref().err(),
84 ));
85 let mut conn = r?;
86 conn.instrumentation = instrumentation;
87 Ok(conn)
88 }
89
90 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
91 where
92 T: diesel::query_builder::AsQuery,
93 T::Query: diesel::query_builder::QueryFragment<Self::Backend>
94 + diesel::query_builder::QueryId
95 + 'query,
96 {
97 self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
98 let stmt_for_exec = match stmt {
99 MaybeCached::Cached(ref s) => (*s).clone(),
100 MaybeCached::CannotCache(ref s) => s.clone(),
101 };
102
103 let (tx, rx) = futures_channel::mpsc::channel(0);
104
105 let yielder = async move {
106 let r = Self::poll_result_stream(conn, stmt_for_exec, binds, tx).await;
107 if let MaybeCached::CannotCache(stmt) = stmt {
120 conn.close(stmt).await.map_err(ErrorHelper)?;
121 }
122 r
123 };
124
125 let fake_stream = stream::once(yielder).filter_map(|e: QueryResult<()>| async move {
126 if let Err(e) = e {
127 Some(Err(e))
128 } else {
129 None
130 }
131 });
132
133 let stream = stream::select(fake_stream, rx).boxed();
134
135 Ok(stream)
136 })
137 .boxed()
138 }
139
140 fn execute_returning_count<'conn, 'query, T>(
141 &'conn mut self,
142 source: T,
143 ) -> Self::ExecuteFuture<'conn, 'query>
144 where
145 T: diesel::query_builder::QueryFragment<Self::Backend>
146 + diesel::query_builder::QueryId
147 + 'query,
148 {
149 self.with_prepared_statement(source, |conn, stmt, binds| async move {
150 let params = mysql_async::Params::try_from(binds)?;
151 conn.exec_drop(&*stmt, params).await.map_err(ErrorHelper)?;
152 if let MaybeCached::CannotCache(stmt) = stmt {
164 conn.close(stmt).await.map_err(ErrorHelper)?;
165 }
166 conn.affected_rows()
167 .try_into()
168 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
169 })
170 }
171
172 fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
173 &mut self.transaction_manager
174 }
175
176 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
177 &mut self.instrumentation
178 }
179
180 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
181 self.instrumentation = Some(Box::new(instrumentation));
182 }
183
184 fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
185 self.stmt_cache.set_cache_size(size);
186 }
187}
188
189#[inline(always)]
190fn update_transaction_manager_status<T>(
191 query_result: QueryResult<T>,
192 transaction_manager: &mut AnsiTransactionManager,
193) -> QueryResult<T> {
194 if let Err(diesel::result::Error::DatabaseError(
195 diesel::result::DatabaseErrorKind::SerializationFailure,
196 _,
197 )) = query_result
198 {
199 transaction_manager
200 .status
201 .set_requires_rollback_maybe_up_to_top_level(true)
202 }
203 query_result
204}
205
206fn prepare_statement_helper<'a>(
207 conn: &'a mut mysql_async::Conn,
208 sql: &str,
209 _is_for_cache: crate::statement_cache::PrepareForCache,
210 _metadata: &[MysqlType],
211) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
212{
213 let sql = sql.to_owned();
220 CallbackHelper(async move {
221 let s = conn.prep(sql).await.map_err(ErrorHelper)?;
222 Ok((s, conn))
223 })
224}
225
226impl AsyncMysqlConnection {
227 pub async fn try_from(conn: mysql_async::Conn) -> ConnectionResult<Self> {
232 use crate::run_query_dsl::RunQueryDsl;
233 let mut conn = AsyncMysqlConnection {
234 conn,
235 stmt_cache: StatementCache::new(),
236 transaction_manager: AnsiTransactionManager::default(),
237 instrumentation: diesel::connection::get_default_instrumentation(),
238 };
239
240 for stmt in CONNECTION_SETUP_QUERIES {
241 diesel::sql_query(*stmt)
242 .execute(&mut conn)
243 .await
244 .map_err(ConnectionError::CouldntSetupConfiguration)?;
245 }
246
247 Ok(conn)
248 }
249
250 fn with_prepared_statement<'conn, T, F, R>(
251 &'conn mut self,
252 query: T,
253 callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
254 + Send
255 + 'conn,
256 ) -> BoxFuture<'conn, QueryResult<R>>
257 where
258 R: Send + 'conn,
259 T: QueryFragment<Mysql> + QueryId,
260 F: Future<Output = QueryResult<R>> + Send,
261 {
262 self.instrumentation()
263 .on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
264 &query,
265 )));
266 let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
267 let bind_collector = query
268 .collect_binds(&mut bind_collector, &mut (), &Mysql)
269 .map(|()| bind_collector);
270
271 let AsyncMysqlConnection {
272 ref mut conn,
273 ref mut stmt_cache,
274 ref mut transaction_manager,
275 ref mut instrumentation,
276 ..
277 } = self;
278
279 let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql);
280 let mut qb = MysqlQueryBuilder::new();
281 let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish());
282 let query_id = T::query_id();
283
284 async move {
285 let RawBytesBindCollector {
286 metadata, binds, ..
287 } = bind_collector?;
288 let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
289 let sql = sql?;
290 let helper = QueryFragmentHelper {
291 sql,
292 safe_to_cache: is_safe_to_cache_prepared,
293 };
294
295 let inner = async {
296 let (stmt, conn) = {
297 stmt_cache
298 .cached_statement_non_generic(
299 query_id,
300 &helper,
301 &Mysql,
302 &metadata,
303 conn,
304 prepare_statement_helper,
305 &mut *instrumentation,
306 )
307 .await?
308 };
309 callback(conn, stmt, ToSqlHelper { metadata, binds }).await
310 };
311 let r = update_transaction_manager_status(inner.await, transaction_manager);
312
313 instrumentation.on_connection_event(InstrumentationEvent::finish_query(
314 &StrQueryHelper::new(&helper.sql),
315 r.as_ref().err(),
316 ));
317 r
318 }
319 .boxed()
320 }
321
322 async fn poll_result_stream(
323 conn: &mut mysql_async::Conn,
324 stmt_for_exec: mysql_async::Statement,
325 binds: ToSqlHelper,
326 mut tx: futures_channel::mpsc::Sender<QueryResult<MysqlRow>>,
327 ) -> QueryResult<()> {
328 use futures_util::sink::SinkExt;
329 let params = mysql_async::Params::try_from(binds)?;
330
331 let res = conn
332 .exec_iter(stmt_for_exec, params)
333 .await
334 .map_err(ErrorHelper)?;
335
336 let mut stream = res
337 .stream_and_drop::<MysqlRow>()
338 .await
339 .map_err(ErrorHelper)?
340 .ok_or_else(|| {
341 diesel::result::Error::DeserializationError(Box::new(
342 diesel::result::UnexpectedEndOfRow,
343 ))
344 })?
345 .map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
346
347 while let Some(row) = stream.next().await {
348 let row = row?;
349 tx.send(Ok(row))
350 .await
351 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))?;
352 }
353
354 Ok(())
355 }
356
357 async fn establish_connection_inner(
358 database_url: &str,
359 ) -> Result<AsyncMysqlConnection, ConnectionError> {
360 let opts = Opts::from_url(database_url)
361 .map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
362 let builder = OptsBuilder::from_opts(opts)
363 .init(CONNECTION_SETUP_QUERIES.to_vec())
364 .stmt_cache_size(0) .client_found_rows(true); let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
368
369 Ok(AsyncMysqlConnection {
370 conn,
371 stmt_cache: StatementCache::new(),
372 transaction_manager: AnsiTransactionManager::default(),
373 instrumentation: diesel::connection::get_default_instrumentation(),
374 })
375 }
376}
377
// Only compiled when at least one of the connection pool features is enabled.
// The impl body is empty, so every method comes from the trait's provided defaults.
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
385
#[cfg(test)]
mod tests {
    use crate::RunQueryDsl;
    // The shared doctest setup refers to the crate as `diesel_async`, so provide that
    // name inside this module.
    mod diesel_async {
        pub use crate::*;
    }
    include!("../doctest_setup.rs");

    /// Runs a large number of statements on a single connection; every one of them
    /// must succeed, which requires that statements are released again after use.
    #[tokio::test]
    async fn check_statements_are_dropped() {
        use self::schema::users;

        #[derive(QueryableByName)]
        #[diesel(table_name = users)]
        #[allow(dead_code)]
        struct User {
            id: i32,
            name: String,
        }

        let mut conn = establish_connection().await;
        // NOTE(review): 16382 is presumably MySQL's default `max_prepared_stmt_count`;
        // going a few statements past it would expose leaked statements - confirm.
        let stmt_count = 16382 + 10;

        // Inserts with an optional value.
        for i in 0..stmt_count {
            let insert = diesel::insert_into(users::table)
                .values(Some(users::name.eq(format!("User{i}"))));
            insert.execute(&mut conn).await.unwrap();
        }

        // Raw SQL queries with one bind each.
        for i in 0..stmt_count {
            let select = diesel::sql_query("SELECT id, name FROM users WHERE name = ?")
                .bind::<diesel::sql_types::Text, _>(format!("User{i}"));
            select.load::<User>(&mut conn).await.unwrap();
        }
    }
}
428
429impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
430 fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
431 Ok(self.sql.clone())
432 }
433
434 fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
435 Ok(self.safe_to_cache)
436 }
437}