1use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
2use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection};
3use diesel::connection::statement_cache::{
4 MaybeCached, QueryFragmentForCachedStatement, StatementCache,
5};
6use diesel::connection::StrQueryHelper;
7use diesel::connection::{CacheSize, Instrumentation};
8use diesel::connection::{DynInstrumentation, InstrumentationEvent};
9use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
10use diesel::query_builder::QueryBuilder;
11use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
12use diesel::result::{ConnectionError, ConnectionResult};
13use diesel::QueryResult;
14use futures_core::future::BoxFuture;
15use futures_core::stream::BoxStream;
16use futures_core::Stream;
17use futures_util::{FutureExt, StreamExt, TryStreamExt};
18use mysql_async::prelude::Queryable;
19use mysql_async::{Opts, OptsBuilder, Statement};
20use std::future::Future;
21
22mod cancel_token;
23mod error_helper;
24mod row;
25mod serialize;
26
27pub use self::cancel_token::MysqlCancelToken;
28use self::error_helper::ErrorHelper;
29use self::row::MysqlRow;
30use self::serialize::ToSqlHelper;
31
/// An asynchronous diesel connection to a MySQL database, built on top of `mysql_async`.
///
/// Create one with [`AsyncConnection::establish`], passing a URL that
/// `mysql_async::Opts::from_url` accepts, or wrap an existing `mysql_async::Conn`
/// with [`AsyncMysqlConnection::try_from`].
pub struct AsyncMysqlConnection {
    /// The underlying `mysql_async` connection that every query is sent over.
    conn: mysql_async::Conn,
    /// diesel-side cache of prepared statements. Connections created through
    /// `establish` set `mysql_async`'s own `stmt_cache_size` to 0.
    stmt_cache: StatementCache<Mysql, Statement>,
    /// Transaction state handed out by `transaction_state`. It is marked as requiring
    /// rollback when a query fails with a serialization failure.
    transaction_manager: AnsiTransactionManager,
    /// Receives the connection and query instrumentation events.
    instrumentation: DynInstrumentation,
    /// Prepared statements that were not stored in `stmt_cache`. They are closed via
    /// `conn.close` before the next query on this connection is prepared.
    stmt_to_free: Vec<mysql_async::Statement>,
}
41
42impl SimpleAsyncConnection for AsyncMysqlConnection {
43 async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
44 self.instrumentation()
45 .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
46 query,
47 )));
48 let result = self
49 .conn
50 .query_drop(query)
51 .await
52 .map_err(ErrorHelper)
53 .map_err(Into::into);
54 self.instrumentation()
55 .on_connection_event(InstrumentationEvent::finish_query(
56 &StrQueryHelper::new(query),
57 result.as_ref().err(),
58 ));
59 result
60 }
61}
62
/// Session settings applied to every new connection: a `+00:00` time zone, plus
/// `utf8mb4` for the client, connection and result character sets.
///
/// `establish_connection_inner` installs them as `mysql_async` init queries.
/// `AsyncMysqlConnection::try_from` executes them one by one.
const CONNECTION_SETUP_QUERIES: &[&str] = &[
    // NOTE(review): only this entry has a trailing `;`. Presumably the server accepts
    // it either way, but it is inconsistent with the other entries.
    "SET time_zone = '+00:00';",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
    "SET character_set_results = 'utf8mb4'",
];
69
70impl AsyncConnectionCore for AsyncMysqlConnection {
71 type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
72 type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
73 type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
74 type Row<'conn, 'query> = MysqlRow;
75 type Backend = Mysql;
76
77 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
78 where
79 T: diesel::query_builder::AsQuery,
80 T::Query: diesel::query_builder::QueryFragment<Self::Backend>
81 + diesel::query_builder::QueryId
82 + 'query,
83 {
84 self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
85 Ok(Self::poll_result_stream(conn, stmt, binds).await?.boxed())
86 })
87 .boxed()
88 }
89
90 fn execute_returning_count<'conn, 'query, T>(
91 &'conn mut self,
92 source: T,
93 ) -> Self::ExecuteFuture<'conn, 'query>
94 where
95 T: diesel::query_builder::QueryFragment<Self::Backend>
96 + diesel::query_builder::QueryId
97 + 'query,
98 {
99 self.with_prepared_statement(source, |conn, stmt, binds| async move {
100 let params = mysql_async::Params::try_from(binds)?;
101 conn.exec_drop(&*stmt, params).await.map_err(ErrorHelper)?;
102 conn.affected_rows()
103 .try_into()
104 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
105 })
106 }
107}
108
109impl AsyncConnection for AsyncMysqlConnection {
110 type TransactionManager = AnsiTransactionManager;
111
112 async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
113 let mut instrumentation = DynInstrumentation::default_instrumentation();
114 instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
115 database_url,
116 ));
117 let r = Self::establish_connection_inner(database_url).await;
118 instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
119 database_url,
120 r.as_ref().err(),
121 ));
122 let mut conn = r?;
123 conn.instrumentation = instrumentation;
124 Ok(conn)
125 }
126
127 fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
128 &mut self.transaction_manager
129 }
130
131 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
132 &mut *self.instrumentation
133 }
134
135 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
136 self.instrumentation = instrumentation.into();
137 }
138
139 fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
140 self.stmt_cache.set_cache_size(size);
141 }
142}
143
144#[inline(always)]
145fn update_transaction_manager_status<T>(
146 query_result: QueryResult<T>,
147 transaction_manager: &mut AnsiTransactionManager,
148) -> QueryResult<T> {
149 if let Err(diesel::result::Error::DatabaseError(
150 diesel::result::DatabaseErrorKind::SerializationFailure,
151 _,
152 )) = query_result
153 {
154 transaction_manager
155 .status
156 .set_requires_rollback_maybe_up_to_top_level(true)
157 }
158 query_result
159}
160
161fn prepare_statement_helper<'a>(
162 conn: &'a mut mysql_async::Conn,
163 sql: &str,
164 _is_for_cache: diesel::connection::statement_cache::PrepareForCache,
165 _metadata: &[MysqlType],
166) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
167{
168 let sql = sql.to_owned();
175 CallbackHelper(async move {
176 let s = conn.prep(sql).await.map_err(ErrorHelper)?;
177 Ok((s, conn))
178 })
179}
180
181impl AsyncMysqlConnection {
182 pub async fn try_from(conn: mysql_async::Conn) -> ConnectionResult<Self> {
187 use crate::run_query_dsl::RunQueryDsl;
188 let mut conn = AsyncMysqlConnection {
189 conn,
190 stmt_cache: StatementCache::new(),
191 transaction_manager: AnsiTransactionManager::default(),
192 instrumentation: DynInstrumentation::default_instrumentation(),
193 stmt_to_free: Vec::new(),
194 };
195
196 for stmt in CONNECTION_SETUP_QUERIES {
197 diesel::sql_query(*stmt)
198 .execute(&mut conn)
199 .await
200 .map_err(ConnectionError::CouldntSetupConfiguration)?;
201 }
202
203 Ok(conn)
204 }
205
206 pub fn cancel_token(&self) -> MysqlCancelToken {
208 let kill_id = self.conn.id();
209 let opts = self.conn.opts().clone();
210
211 MysqlCancelToken { kill_id, opts }
212 }
213
214 fn with_prepared_statement<'conn, T, F, R>(
215 &'conn mut self,
216 query: T,
217 callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
218 + Send
219 + 'conn,
220 ) -> BoxFuture<'conn, QueryResult<R>>
221 where
222 R: Send + 'conn,
223 T: QueryFragment<Mysql> + QueryId,
224 F: Future<Output = QueryResult<R>> + Send,
225 {
226 self.instrumentation()
227 .on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
228 &query,
229 )));
230 let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
231 let bind_collector = query
232 .collect_binds(&mut bind_collector, &mut (), &Mysql)
233 .map(|()| bind_collector);
234
235 let AsyncMysqlConnection {
236 ref mut conn,
237 ref mut stmt_cache,
238 ref mut transaction_manager,
239 ref mut instrumentation,
240 ref mut stmt_to_free,
241 ..
242 } = self;
243
244 let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql);
245 let mut qb = MysqlQueryBuilder::new();
246 let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish());
247 let query_id = T::query_id();
248
249 async move {
250 for stmt in std::mem::take(stmt_to_free) {
261 conn.close(stmt).await.map_err(ErrorHelper)?;
262 }
263 let RawBytesBindCollector {
264 metadata, binds, ..
265 } = bind_collector?;
266 let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
267 let sql = sql?;
268 let helper = QueryFragmentHelper {
269 sql,
270 safe_to_cache: is_safe_to_cache_prepared,
271 };
272 let inner = async {
273 let (stmt, conn) = stmt_cache
274 .cached_statement_non_generic(
275 query_id,
276 &helper,
277 &Mysql,
278 &metadata,
279 conn,
280 prepare_statement_helper,
281 &mut **instrumentation,
282 )
283 .await?;
284 if let MaybeCached::CannotCache(stmt) = &stmt {
286 stmt_to_free.push(stmt.clone());
287 }
288 callback(conn, stmt, ToSqlHelper { metadata, binds }).await
289 };
290 let r = update_transaction_manager_status(inner.await, transaction_manager);
291 instrumentation.on_connection_event(InstrumentationEvent::finish_query(
292 &StrQueryHelper::new(&helper.sql),
293 r.as_ref().err(),
294 ));
295 r
296 }
297 .boxed()
298 }
299
300 async fn poll_result_stream<'conn>(
301 conn: &'conn mut mysql_async::Conn,
302 stmt: MaybeCached<'_, mysql_async::Statement>,
303 binds: ToSqlHelper,
304 ) -> QueryResult<impl Stream<Item = QueryResult<MysqlRow>> + Send + use<'conn>> {
305 let params = mysql_async::Params::try_from(binds)?;
306 let stmt_for_exec = match stmt {
307 MaybeCached::Cached(ref s) => {
308 (*s).clone()
309 },
310 MaybeCached::CannotCache(ref s) => {
311 s.clone()
312 },
313 _ => unreachable!(
314 "Diesel has only two variants here at the time of writing.\n\
315 If you ever see this error message please open in issue in the diesel-async issue tracker"
316 ),
317 };
318
319 let res = conn
320 .exec_iter(stmt_for_exec, params)
321 .await
322 .map_err(ErrorHelper)?;
323
324 let stream = res
325 .stream_and_drop::<MysqlRow>()
326 .await
327 .map_err(ErrorHelper)?
328 .ok_or_else(|| {
329 diesel::result::Error::DeserializationError(Box::new(
330 diesel::result::UnexpectedEndOfRow,
331 ))
332 })?
333 .map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
334
335 Ok(stream)
336 }
337
338 async fn establish_connection_inner(
339 database_url: &str,
340 ) -> Result<AsyncMysqlConnection, ConnectionError> {
341 let opts = Opts::from_url(database_url)
342 .map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
343 let builder = OptsBuilder::from_opts(opts)
344 .init(CONNECTION_SETUP_QUERIES.to_vec())
345 .stmt_cache_size(0) .client_found_rows(true); let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
349
350 Ok(AsyncMysqlConnection {
351 conn,
352 stmt_cache: StatementCache::new(),
353 transaction_manager: AnsiTransactionManager::default(),
354 instrumentation: DynInstrumentation::none(),
355 stmt_to_free: Vec::new(),
356 })
357 }
358}
359
// Lets `AsyncMysqlConnection` be used with the connection-pool integrations. This is
// only compiled when at least one pool feature is enabled. The impl is empty, so it
// relies entirely on the trait's provided methods (see `crate::pooled_connection`).
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
367
#[cfg(test)]
mod tests {
    use crate::RunQueryDsl;
    // NOTE(review): presumably lets `diesel_async::...` paths in the included setup file
    // resolve to this crate. Confirm against doctest_setup.rs.
    mod diesel_async {
        pub use crate::*;
    }
    // Shared setup file. Judging by the usages below, it provides at least
    // `schema::users`, `establish_connection` and the diesel query DSL.
    include!("../doctest_setup.rs");

    // Number of queries issued per test.
    // NOTE(review): 16382 looks like MySQL's default `max_prepared_stmt_count`. Going
    // 1000 past it should make these tests fail if prepared statements leaked instead
    // of being reused or closed. Confirm.
    const STMT_COUNT: usize = 16382 + 1000;

    // Row type loaded by the `User` tests. Its fields are never read.
    #[derive(Queryable)]
    #[expect(dead_code, reason = "used for the test as loading target")]
    struct User {
        id: i32,
        name: String,
    }

    // Repeats the same query via `load`. Per the test name, this exercises the
    // cached-statement path.
    #[tokio::test]
    async fn check_cached_statements_are_dropped() {
        use self::schema::users;

        let mut conn = establish_connection().await;

        for _i in 0..STMT_COUNT {
            users::table
                .select(users::id)
                .load::<i32>(&mut conn)
                .await
                .unwrap();
        }
    }

    // Repeats an `eq_any` query via `load`. Per the test name, this exercises the
    // uncached path, where statements are queued in `stmt_to_free` and closed later.
    #[tokio::test]
    async fn check_uncached_statements_are_dropped() {
        use self::schema::users;

        let mut conn = establish_connection().await;

        for _i in 0..STMT_COUNT {
            users::table
                .filter(users::dsl::id.eq_any(&[1, 2]))
                .load::<User>(&mut conn)
                .await
                .unwrap();
        }
    }

    // Same as the cached `load` test, but via `get_result`. `.optional()` tolerates an
    // empty table.
    #[tokio::test]
    async fn check_cached_statements_are_dropped_get_result() {
        use self::schema::users;
        use diesel::OptionalExtension;

        let mut conn = establish_connection().await;

        for _i in 0..STMT_COUNT {
            users::table
                .select(users::id)
                .get_result::<i32>(&mut conn)
                .await
                .optional()
                .unwrap();
        }
    }

    // Same as the uncached `load` test, but via `get_result`. `.optional()` tolerates
    // an empty result.
    #[tokio::test]
    async fn check_uncached_statements_are_dropped_get_result() {
        use self::schema::users;
        use diesel::OptionalExtension;

        let mut conn = establish_connection().await;

        for _i in 0..STMT_COUNT {
            users::table
                .filter(users::dsl::id.eq_any(&[1, 2]))
                .get_result::<User>(&mut conn)
                .await
                .optional()
                .unwrap();
        }
    }
}
449
450impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
451 fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
452 Ok(self.sql.clone())
453 }
454
455 fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
456 Ok(self.safe_to_cache)
457 }
458}