1use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
2use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection};
3use diesel::connection::statement_cache::{
4 MaybeCached, QueryFragmentForCachedStatement, StatementCache,
5};
6use diesel::connection::StrQueryHelper;
7use diesel::connection::{CacheSize, Instrumentation};
8use diesel::connection::{DynInstrumentation, InstrumentationEvent};
9use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
10use diesel::query_builder::QueryBuilder;
11use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
12use diesel::result::{ConnectionError, ConnectionResult};
13use diesel::QueryResult;
14use futures_core::future::BoxFuture;
15use futures_core::stream::BoxStream;
16use futures_util::stream;
17use futures_util::{FutureExt, StreamExt, TryStreamExt};
18use mysql_async::prelude::Queryable;
19use mysql_async::{Opts, OptsBuilder, Statement};
20use std::future::Future;
21
22mod cancel_token;
23mod error_helper;
24mod row;
25mod serialize;
26
27pub use self::cancel_token::MysqlCancelToken;
28use self::error_helper::ErrorHelper;
29use self::row::MysqlRow;
30use self::serialize::ToSqlHelper;
31
32pub struct AsyncMysqlConnection {
35 conn: mysql_async::Conn,
36 stmt_cache: StatementCache<Mysql, Statement>,
37 transaction_manager: AnsiTransactionManager,
38 instrumentation: DynInstrumentation,
39}
40
41impl SimpleAsyncConnection for AsyncMysqlConnection {
42 async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
43 self.instrumentation()
44 .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
45 query,
46 )));
47 let result = self
48 .conn
49 .query_drop(query)
50 .await
51 .map_err(ErrorHelper)
52 .map_err(Into::into);
53 self.instrumentation()
54 .on_connection_event(InstrumentationEvent::finish_query(
55 &StrQueryHelper::new(query),
56 result.as_ref().err(),
57 ));
58 result
59 }
60}
61
/// Session settings applied to every new connection.
///
/// They set the session time zone to `+00:00` and use `utf8mb4` for the client, connection
/// and result character sets.
///
/// Each entry is a single statement without a trailing `;`. The entries are passed to
/// `mysql_async` as init queries, and `try_from` also runs them through diesel as prepared
/// statements. MySQL's prepared-statement API expects statement text without a terminating
/// semicolon.
const CONNECTION_SETUP_QUERIES: &[&str] = &[
    "SET time_zone = '+00:00'",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
    "SET character_set_results = 'utf8mb4'",
];
68
69impl AsyncConnectionCore for AsyncMysqlConnection {
70 type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
71 type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
72 type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
73 type Row<'conn, 'query> = MysqlRow;
74 type Backend = Mysql;
75
76 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
77 where
78 T: diesel::query_builder::AsQuery,
79 T::Query: diesel::query_builder::QueryFragment<Self::Backend>
80 + diesel::query_builder::QueryId
81 + 'query,
82 {
83 self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
84 let stmt_for_exec = match stmt {
85 MaybeCached::Cached(ref s) => (*s).clone(),
86 MaybeCached::CannotCache(ref s) => s.clone(),
87 _ => unreachable!(
88 "Diesel has only two variants here at the time of writing.\n\
89 If you ever see this error message please open in issue in the diesel-async issue tracker"
90 ),
91 };
92
93 let (tx, rx) = futures_channel::mpsc::channel(0);
94
95 let yielder = async move {
96 let r = Self::poll_result_stream(conn, stmt_for_exec, binds, tx).await;
97 if let MaybeCached::CannotCache(stmt) = stmt {
110 conn.close(stmt).await.map_err(ErrorHelper)?;
111 }
112 r
113 };
114
115 let fake_stream = stream::once(yielder).filter_map(|e: QueryResult<()>| async move {
116 if let Err(e) = e {
117 Some(Err(e))
118 } else {
119 None
120 }
121 });
122
123 let stream = stream::select(fake_stream, rx).boxed();
124
125 Ok(stream)
126 })
127 .boxed()
128 }
129
130 fn execute_returning_count<'conn, 'query, T>(
131 &'conn mut self,
132 source: T,
133 ) -> Self::ExecuteFuture<'conn, 'query>
134 where
135 T: diesel::query_builder::QueryFragment<Self::Backend>
136 + diesel::query_builder::QueryId
137 + 'query,
138 {
139 self.with_prepared_statement(source, |conn, stmt, binds| async move {
140 let params = mysql_async::Params::try_from(binds)?;
141 conn.exec_drop(&*stmt, params).await.map_err(ErrorHelper)?;
142 if let MaybeCached::CannotCache(stmt) = stmt {
154 conn.close(stmt).await.map_err(ErrorHelper)?;
155 }
156 conn.affected_rows()
157 .try_into()
158 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
159 })
160 }
161}
162
163impl AsyncConnection for AsyncMysqlConnection {
164 type TransactionManager = AnsiTransactionManager;
165
166 async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
167 let mut instrumentation = DynInstrumentation::default_instrumentation();
168 instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
169 database_url,
170 ));
171 let r = Self::establish_connection_inner(database_url).await;
172 instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
173 database_url,
174 r.as_ref().err(),
175 ));
176 let mut conn = r?;
177 conn.instrumentation = instrumentation;
178 Ok(conn)
179 }
180
181 fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
182 &mut self.transaction_manager
183 }
184
185 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
186 &mut *self.instrumentation
187 }
188
189 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
190 self.instrumentation = instrumentation.into();
191 }
192
193 fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
194 self.stmt_cache.set_cache_size(size);
195 }
196}
197
198#[inline(always)]
199fn update_transaction_manager_status<T>(
200 query_result: QueryResult<T>,
201 transaction_manager: &mut AnsiTransactionManager,
202) -> QueryResult<T> {
203 if let Err(diesel::result::Error::DatabaseError(
204 diesel::result::DatabaseErrorKind::SerializationFailure,
205 _,
206 )) = query_result
207 {
208 transaction_manager
209 .status
210 .set_requires_rollback_maybe_up_to_top_level(true)
211 }
212 query_result
213}
214
215fn prepare_statement_helper<'a>(
216 conn: &'a mut mysql_async::Conn,
217 sql: &str,
218 _is_for_cache: diesel::connection::statement_cache::PrepareForCache,
219 _metadata: &[MysqlType],
220) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
221{
222 let sql = sql.to_owned();
229 CallbackHelper(async move {
230 let s = conn.prep(sql).await.map_err(ErrorHelper)?;
231 Ok((s, conn))
232 })
233}
234
235impl AsyncMysqlConnection {
236 pub async fn try_from(conn: mysql_async::Conn) -> ConnectionResult<Self> {
241 use crate::run_query_dsl::RunQueryDsl;
242 let mut conn = AsyncMysqlConnection {
243 conn,
244 stmt_cache: StatementCache::new(),
245 transaction_manager: AnsiTransactionManager::default(),
246 instrumentation: DynInstrumentation::default_instrumentation(),
247 };
248
249 for stmt in CONNECTION_SETUP_QUERIES {
250 diesel::sql_query(*stmt)
251 .execute(&mut conn)
252 .await
253 .map_err(ConnectionError::CouldntSetupConfiguration)?;
254 }
255
256 Ok(conn)
257 }
258
259 pub fn cancel_token(&self) -> MysqlCancelToken {
261 let kill_id = self.conn.id();
262 let opts = self.conn.opts().clone();
263
264 MysqlCancelToken { kill_id, opts }
265 }
266
267 fn with_prepared_statement<'conn, T, F, R>(
268 &'conn mut self,
269 query: T,
270 callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
271 + Send
272 + 'conn,
273 ) -> BoxFuture<'conn, QueryResult<R>>
274 where
275 R: Send + 'conn,
276 T: QueryFragment<Mysql> + QueryId,
277 F: Future<Output = QueryResult<R>> + Send,
278 {
279 self.instrumentation()
280 .on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
281 &query,
282 )));
283 let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
284 let bind_collector = query
285 .collect_binds(&mut bind_collector, &mut (), &Mysql)
286 .map(|()| bind_collector);
287
288 let AsyncMysqlConnection {
289 ref mut conn,
290 ref mut stmt_cache,
291 ref mut transaction_manager,
292 ref mut instrumentation,
293 ..
294 } = self;
295
296 let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql);
297 let mut qb = MysqlQueryBuilder::new();
298 let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish());
299 let query_id = T::query_id();
300
301 async move {
302 let RawBytesBindCollector {
303 metadata, binds, ..
304 } = bind_collector?;
305 let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
306 let sql = sql?;
307 let helper = QueryFragmentHelper {
308 sql,
309 safe_to_cache: is_safe_to_cache_prepared,
310 };
311 let inner = async {
312 let (stmt, conn) = stmt_cache
313 .cached_statement_non_generic(
314 query_id,
315 &helper,
316 &Mysql,
317 &metadata,
318 conn,
319 prepare_statement_helper,
320 &mut **instrumentation,
321 )
322 .await?;
323 callback(conn, stmt, ToSqlHelper { metadata, binds }).await
324 };
325 let r = update_transaction_manager_status(inner.await, transaction_manager);
326 instrumentation.on_connection_event(InstrumentationEvent::finish_query(
327 &StrQueryHelper::new(&helper.sql),
328 r.as_ref().err(),
329 ));
330 r
331 }
332 .boxed()
333 }
334
335 async fn poll_result_stream(
336 conn: &mut mysql_async::Conn,
337 stmt_for_exec: mysql_async::Statement,
338 binds: ToSqlHelper,
339 mut tx: futures_channel::mpsc::Sender<QueryResult<MysqlRow>>,
340 ) -> QueryResult<()> {
341 use futures_util::sink::SinkExt;
342 let params = mysql_async::Params::try_from(binds)?;
343
344 let res = conn
345 .exec_iter(stmt_for_exec, params)
346 .await
347 .map_err(ErrorHelper)?;
348
349 let mut stream = res
350 .stream_and_drop::<MysqlRow>()
351 .await
352 .map_err(ErrorHelper)?
353 .ok_or_else(|| {
354 diesel::result::Error::DeserializationError(Box::new(
355 diesel::result::UnexpectedEndOfRow,
356 ))
357 })?
358 .map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
359
360 while let Some(row) = stream.next().await {
361 let row = row?;
362 tx.send(Ok(row))
363 .await
364 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))?;
365 }
366
367 Ok(())
368 }
369
370 async fn establish_connection_inner(
371 database_url: &str,
372 ) -> Result<AsyncMysqlConnection, ConnectionError> {
373 let opts = Opts::from_url(database_url)
374 .map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
375 let builder = OptsBuilder::from_opts(opts)
376 .init(CONNECTION_SETUP_QUERIES.to_vec())
377 .stmt_cache_size(0) .client_found_rows(true); let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
381
382 Ok(AsyncMysqlConnection {
383 conn,
384 stmt_cache: StatementCache::new(),
385 transaction_manager: AnsiTransactionManager::default(),
386 instrumentation: DynInstrumentation::none(),
387 })
388 }
389}
390
// Lets the connection pool integrations manage `AsyncMysqlConnection`, using the trait's
// default method implementations. Only compiled when at least one pool feature is enabled.
#[cfg(any(feature = "deadpool", feature = "bb8", feature = "mobc", feature = "r2d2"))]
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
398
#[cfg(test)]
mod tests {
    use crate::RunQueryDsl;
    mod diesel_async {
        pub use crate::*;
    }
    include!("../doctest_setup.rs");

    /// Runs more statements than MySQL's default `max_prepared_stmt_count` (16382).
    ///
    /// If non-cached prepared statements were never closed on the server, this would
    /// presumably start failing once that limit is exceeded.
    #[tokio::test]
    async fn check_statements_are_dropped() {
        use self::schema::users;

        // Fields are only populated through the derive, never read directly.
        #[derive(QueryableByName)]
        #[diesel(table_name = users)]
        #[allow(dead_code)]
        struct User {
            id: i32,
            name: String,
        }

        const STATEMENT_COUNT: i32 = 16_382 + 10;

        let mut conn = establish_connection().await;

        for n in 0..STATEMENT_COUNT {
            let user_name = format!("User{n}");
            diesel::insert_into(users::table)
                .values(Some(users::name.eq(user_name)))
                .execute(&mut conn)
                .await
                .unwrap();
        }

        for n in 0..STATEMENT_COUNT {
            let lookup_name = format!("User{n}");
            diesel::sql_query("SELECT id, name FROM users WHERE name = ?")
                .bind::<diesel::sql_types::Text, _>(lookup_name)
                .load::<User>(&mut conn)
                .await
                .unwrap();
        }
    }
}
441
442impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
443 fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
444 Ok(self.sql.clone())
445 }
446
447 fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
448 Ok(self.safe_to_cache)
449 }
450}