diesel_async/mysql/
mod.rs

1use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
2use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection};
3use diesel::connection::statement_cache::{
4    MaybeCached, QueryFragmentForCachedStatement, StatementCache,
5};
6use diesel::connection::StrQueryHelper;
7use diesel::connection::{CacheSize, Instrumentation};
8use diesel::connection::{DynInstrumentation, InstrumentationEvent};
9use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
10use diesel::query_builder::QueryBuilder;
11use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
12use diesel::result::{ConnectionError, ConnectionResult};
13use diesel::QueryResult;
14use futures_core::future::BoxFuture;
15use futures_core::stream::BoxStream;
16use futures_core::Stream;
17use futures_util::{FutureExt, StreamExt, TryStreamExt};
18use mysql_async::prelude::Queryable;
19use mysql_async::{Opts, OptsBuilder, Statement};
20use std::future::Future;
21
22mod cancel_token;
23mod error_helper;
24mod row;
25mod serialize;
26
27pub use self::cancel_token::MysqlCancelToken;
28use self::error_helper::ErrorHelper;
29use self::row::MysqlRow;
30use self::serialize::ToSqlHelper;
31
32/// A connection to a MySQL database. Connection URLs should be in the form
33/// `mysql://[user[:password]@]host/database_name`
34pub struct AsyncMysqlConnection {
35    conn: mysql_async::Conn,
36    stmt_cache: StatementCache<Mysql, Statement>,
37    transaction_manager: AnsiTransactionManager,
38    instrumentation: DynInstrumentation,
39    stmt_to_free: Vec<mysql_async::Statement>,
40}
41
42impl SimpleAsyncConnection for AsyncMysqlConnection {
43    async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
44        self.instrumentation()
45            .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
46                query,
47            )));
48        let result = self
49            .conn
50            .query_drop(query)
51            .await
52            .map_err(ErrorHelper)
53            .map_err(Into::into);
54        self.instrumentation()
55            .on_connection_event(InstrumentationEvent::finish_query(
56                &StrQueryHelper::new(query),
57                result.as_ref().err(),
58            ));
59        result
60    }
61}
62
/// Statements run on every connection this module creates or wraps.
///
/// They pin the session time zone to `+00:00` and switch the client,
/// connection and result character sets to `utf8mb4`.
const CONNECTION_SETUP_QUERIES: &[&str] = &[
    "SET time_zone = '+00:00';",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
    "SET character_set_results = 'utf8mb4'",
];
69
70impl AsyncConnectionCore for AsyncMysqlConnection {
71    type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
72    type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
73    type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
74    type Row<'conn, 'query> = MysqlRow;
75    type Backend = Mysql;
76
77    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
78    where
79        T: diesel::query_builder::AsQuery,
80        T::Query: diesel::query_builder::QueryFragment<Self::Backend>
81            + diesel::query_builder::QueryId
82            + 'query,
83    {
84        self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
85            Ok(Self::poll_result_stream(conn, stmt, binds).await?.boxed())
86        })
87        .boxed()
88    }
89
90    fn execute_returning_count<'conn, 'query, T>(
91        &'conn mut self,
92        source: T,
93    ) -> Self::ExecuteFuture<'conn, 'query>
94    where
95        T: diesel::query_builder::QueryFragment<Self::Backend>
96            + diesel::query_builder::QueryId
97            + 'query,
98    {
99        self.with_prepared_statement(source, |conn, stmt, binds| async move {
100            let params = mysql_async::Params::try_from(binds)?;
101            conn.exec_drop(&*stmt, params).await.map_err(ErrorHelper)?;
102            conn.affected_rows()
103                .try_into()
104                .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
105        })
106    }
107}
108
109impl AsyncConnection for AsyncMysqlConnection {
110    type TransactionManager = AnsiTransactionManager;
111
112    async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
113        let mut instrumentation = DynInstrumentation::default_instrumentation();
114        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
115            database_url,
116        ));
117        let r = Self::establish_connection_inner(database_url).await;
118        instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
119            database_url,
120            r.as_ref().err(),
121        ));
122        let mut conn = r?;
123        conn.instrumentation = instrumentation;
124        Ok(conn)
125    }
126
127    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
128        &mut self.transaction_manager
129    }
130
131    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
132        &mut *self.instrumentation
133    }
134
135    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
136        self.instrumentation = instrumentation.into();
137    }
138
139    fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
140        self.stmt_cache.set_cache_size(size);
141    }
142}
143
144#[inline(always)]
145fn update_transaction_manager_status<T>(
146    query_result: QueryResult<T>,
147    transaction_manager: &mut AnsiTransactionManager,
148) -> QueryResult<T> {
149    if let Err(diesel::result::Error::DatabaseError(
150        diesel::result::DatabaseErrorKind::SerializationFailure,
151        _,
152    )) = query_result
153    {
154        transaction_manager
155            .status
156            .set_requires_rollback_maybe_up_to_top_level(true)
157    }
158    query_result
159}
160
161fn prepare_statement_helper<'a>(
162    conn: &'a mut mysql_async::Conn,
163    sql: &str,
164    _is_for_cache: diesel::connection::statement_cache::PrepareForCache,
165    _metadata: &[MysqlType],
166) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
167{
168    // ideally we wouldn't clone the SQL string here
169    // but as we usually cache statements anyway
170    // this is a fixed one time const
171    //
172    // The probleme with not cloning it is that we then cannot express
173    // the right result lifetime anymore (at least not easily)
174    let sql = sql.to_owned();
175    CallbackHelper(async move {
176        let s = conn.prep(sql).await.map_err(ErrorHelper)?;
177        Ok((s, conn))
178    })
179}
180
181impl AsyncMysqlConnection {
182    /// Wrap an existing [`mysql_async::Conn`] into a async diesel mysql connection
183    ///
184    /// This function constructs a new `AsyncMysqlConnection` based on an existing
185    /// [`mysql_async::Conn]`.
186    pub async fn try_from(conn: mysql_async::Conn) -> ConnectionResult<Self> {
187        use crate::run_query_dsl::RunQueryDsl;
188        let mut conn = AsyncMysqlConnection {
189            conn,
190            stmt_cache: StatementCache::new(),
191            transaction_manager: AnsiTransactionManager::default(),
192            instrumentation: DynInstrumentation::default_instrumentation(),
193            stmt_to_free: Vec::new(),
194        };
195
196        for stmt in CONNECTION_SETUP_QUERIES {
197            diesel::sql_query(*stmt)
198                .execute(&mut conn)
199                .await
200                .map_err(ConnectionError::CouldntSetupConfiguration)?;
201        }
202
203        Ok(conn)
204    }
205
206    /// Constructs a cancellation token that can later be used to request cancellation of a query running on the connection associated with this client.
207    pub fn cancel_token(&self) -> MysqlCancelToken {
208        let kill_id = self.conn.id();
209        let opts = self.conn.opts().clone();
210
211        MysqlCancelToken { kill_id, opts }
212    }
213
214    fn with_prepared_statement<'conn, T, F, R>(
215        &'conn mut self,
216        query: T,
217        callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
218            + Send
219            + 'conn,
220    ) -> BoxFuture<'conn, QueryResult<R>>
221    where
222        R: Send + 'conn,
223        T: QueryFragment<Mysql> + QueryId,
224        F: Future<Output = QueryResult<R>> + Send,
225    {
226        self.instrumentation()
227            .on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
228                &query,
229            )));
230        let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
231        let bind_collector = query
232            .collect_binds(&mut bind_collector, &mut (), &Mysql)
233            .map(|()| bind_collector);
234
235        let AsyncMysqlConnection {
236            ref mut conn,
237            ref mut stmt_cache,
238            ref mut transaction_manager,
239            ref mut instrumentation,
240            ref mut stmt_to_free,
241            ..
242        } = self;
243
244        let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql);
245        let mut qb = MysqlQueryBuilder::new();
246        let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish());
247        let query_id = T::query_id();
248
249        async move {
250            // We need to close any non-cached statement explicitly here as otherwise
251            // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
252            // and https://github.com/weiznich/diesel_async/issues/269 for details
253            //
254            // We remember these statements from the last run as there is currenly no relaible way to
255            // run this as destruction step after the execution finished. Users might abort polling the future, etc
256            //
257            // The overhead for this is keeping one additional statement open until the connection is used
258            // next, so you would need to have `max_prepared_stmt_count - 1` other statements open for this to cause issues.
259            // This is hopefully not a problem in practice
260            for stmt in std::mem::take(stmt_to_free) {
261                conn.close(stmt).await.map_err(ErrorHelper)?;
262            }
263            let RawBytesBindCollector {
264                metadata, binds, ..
265            } = bind_collector?;
266            let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
267            let sql = sql?;
268            let helper = QueryFragmentHelper {
269                sql,
270                safe_to_cache: is_safe_to_cache_prepared,
271            };
272            let inner = async {
273                let (stmt, conn) = stmt_cache
274                    .cached_statement_non_generic(
275                        query_id,
276                        &helper,
277                        &Mysql,
278                        &metadata,
279                        conn,
280                        prepare_statement_helper,
281                        &mut **instrumentation,
282                    )
283                    .await?;
284                // for any not cached statement we need to remember to close them on the next connection usage
285                if let MaybeCached::CannotCache(stmt) = &stmt {
286                    stmt_to_free.push(stmt.clone());
287                }
288                callback(conn, stmt, ToSqlHelper { metadata, binds }).await
289            };
290            let r = update_transaction_manager_status(inner.await, transaction_manager);
291            instrumentation.on_connection_event(InstrumentationEvent::finish_query(
292                &StrQueryHelper::new(&helper.sql),
293                r.as_ref().err(),
294            ));
295            r
296        }
297        .boxed()
298    }
299
300    async fn poll_result_stream<'conn>(
301        conn: &'conn mut mysql_async::Conn,
302        stmt: MaybeCached<'_, mysql_async::Statement>,
303        binds: ToSqlHelper,
304    ) -> QueryResult<impl Stream<Item = QueryResult<MysqlRow>> + Send + use<'conn>> {
305        let params = mysql_async::Params::try_from(binds)?;
306        let stmt_for_exec = match stmt {
307                MaybeCached::Cached(ref s) => {
308                    (*s).clone()
309                },
310                MaybeCached::CannotCache(ref s) => {
311                    s.clone()
312                },
313                _ => unreachable!(
314                    "Diesel has only two variants here at the time of writing.\n\
315                     If you ever see this error message please open in issue in the diesel-async issue tracker"
316                ),
317            };
318
319        let res = conn
320            .exec_iter(stmt_for_exec, params)
321            .await
322            .map_err(ErrorHelper)?;
323
324        let stream = res
325            .stream_and_drop::<MysqlRow>()
326            .await
327            .map_err(ErrorHelper)?
328            .ok_or_else(|| {
329                diesel::result::Error::DeserializationError(Box::new(
330                    diesel::result::UnexpectedEndOfRow,
331                ))
332            })?
333            .map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
334
335        Ok(stream)
336    }
337
338    async fn establish_connection_inner(
339        database_url: &str,
340    ) -> Result<AsyncMysqlConnection, ConnectionError> {
341        let opts = Opts::from_url(database_url)
342            .map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
343        let builder = OptsBuilder::from_opts(opts)
344            .init(CONNECTION_SETUP_QUERIES.to_vec())
345            .stmt_cache_size(0) // We have our own cache
346            .client_found_rows(true); // This allows a consistent behavior between MariaDB/MySQL and PostgreSQL (and is already set in `diesel`)
347
348        let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
349
350        Ok(AsyncMysqlConnection {
351            conn,
352            stmt_cache: StatementCache::new(),
353            transaction_manager: AnsiTransactionManager::default(),
354            instrumentation: DynInstrumentation::none(),
355            stmt_to_free: Vec::new(),
356        })
357    }
358}
359
/// Allows `AsyncMysqlConnection` to be managed by any of the supported
/// connection pools; the trait's default method implementations are used.
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
367
#[cfg(test)]
mod tests {
    use crate::RunQueryDsl;
    mod diesel_async {
        pub use crate::*;
    }
    include!("../doctest_setup.rs");

    /// Number of queries issued per test: 1000 more than 16382, which is
    /// presumably MySQL's default `max_prepared_stmt_count`. If each query leaked
    /// a prepared statement, the server would start rejecting prepares.
    const STMT_COUNT: usize = 16382 + 1000;

    #[derive(Queryable)]
    #[expect(dead_code, reason = "used for the test as loading target")]
    struct User {
        id: i32,
        name: String,
    }

    /// Repeatedly loads through a query that can be cached.
    #[tokio::test]
    async fn check_cached_statements_are_dropped() {
        use self::schema::users;

        let mut conn = establish_connection().await;
        for _ in 0..STMT_COUNT {
            let loaded = users::table
                .select(users::id)
                .load::<i32>(&mut conn)
                .await;
            loaded.unwrap();
        }
    }

    /// Repeatedly loads through an `eq_any` query, which is not cached.
    #[tokio::test]
    async fn check_uncached_statements_are_dropped() {
        use self::schema::users;

        let mut conn = establish_connection().await;
        for _ in 0..STMT_COUNT {
            let loaded = users::table
                .filter(users::dsl::id.eq_any(&[1, 2]))
                .load::<User>(&mut conn)
                .await;
            loaded.unwrap();
        }
    }

    /// Same as the cached load test, but through `get_result`.
    #[tokio::test]
    async fn check_cached_statements_are_dropped_get_result() {
        use self::schema::users;
        use diesel::OptionalExtension;

        let mut conn = establish_connection().await;
        for _ in 0..STMT_COUNT {
            let fetched = users::table
                .select(users::id)
                .get_result::<i32>(&mut conn)
                .await;
            fetched.optional().unwrap();
        }
    }

    /// Same as the uncached load test, but through `get_result`.
    #[tokio::test]
    async fn check_uncached_statements_are_dropped_get_result() {
        use self::schema::users;
        use diesel::OptionalExtension;

        let mut conn = establish_connection().await;
        for _ in 0..STMT_COUNT {
            let fetched = users::table
                .filter(users::dsl::id.eq_any(&[1, 2]))
                .get_result::<User>(&mut conn)
                .await;
            fetched.optional().unwrap();
        }
    }
}
449
450impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
451    fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
452        Ok(self.sql.clone())
453    }
454
455    fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
456        Ok(self.safe_to_cache)
457    }
458}