// diesel_async/mysql/mod.rs

use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection};
use diesel::connection::statement_cache::{
    MaybeCached, QueryFragmentForCachedStatement, StatementCache,
};
use diesel::connection::StrQueryHelper;
use diesel::connection::{CacheSize, Instrumentation};
use diesel::connection::{DynInstrumentation, InstrumentationEvent};
use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
use diesel::query_builder::QueryBuilder;
use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
use diesel::result::{ConnectionError, ConnectionResult};
use diesel::QueryResult;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::stream;
use futures_util::{FutureExt, StreamExt, TryStreamExt};
use mysql_async::prelude::Queryable;
use mysql_async::{Opts, OptsBuilder, Statement};
use std::future::Future;

mod cancel_token;
mod error_helper;
mod row;
mod serialize;

pub use self::cancel_token::MysqlCancelToken;
use self::error_helper::ErrorHelper;
use self::row::MysqlRow;
use self::serialize::ToSqlHelper;
31
32/// A connection to a MySQL database. Connection URLs should be in the form
33/// `mysql://[user[:password]@]host/database_name`
34pub struct AsyncMysqlConnection {
35    conn: mysql_async::Conn,
36    stmt_cache: StatementCache<Mysql, Statement>,
37    transaction_manager: AnsiTransactionManager,
38    instrumentation: DynInstrumentation,
39}
40
41impl SimpleAsyncConnection for AsyncMysqlConnection {
42    async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
43        self.instrumentation()
44            .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
45                query,
46            )));
47        let result = self
48            .conn
49            .query_drop(query)
50            .await
51            .map_err(ErrorHelper)
52            .map_err(Into::into);
53        self.instrumentation()
54            .on_connection_event(InstrumentationEvent::finish_query(
55                &StrQueryHelper::new(query),
56                result.as_ref().err(),
57            ));
58        result
59    }
60}
61
/// Statements run on every new connection before it is handed to the user.
///
/// They pin the session time zone to UTC and the client, connection and
/// result character sets to `utf8mb4`.
const CONNECTION_SETUP_QUERIES: &[&str] = &[
    "SET time_zone = '+00:00';",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
    "SET character_set_results = 'utf8mb4'",
];
68
69impl AsyncConnectionCore for AsyncMysqlConnection {
70    type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
71    type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
72    type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
73    type Row<'conn, 'query> = MysqlRow;
74    type Backend = Mysql;
75
76    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
77    where
78        T: diesel::query_builder::AsQuery,
79        T::Query: diesel::query_builder::QueryFragment<Self::Backend>
80            + diesel::query_builder::QueryId
81            + 'query,
82    {
83        self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
84            let stmt_for_exec = match stmt {
85                MaybeCached::Cached(ref s) => (*s).clone(),
86                MaybeCached::CannotCache(ref s) => s.clone(),
87                _ => unreachable!(
88                    "Diesel has only two variants here at the time of writing.\n\
89                     If you ever see this error message please open in issue in the diesel-async issue tracker"
90                ),
91            };
92
93            let (tx, rx) = futures_channel::mpsc::channel(0);
94
95            let yielder = async move {
96                let r = Self::poll_result_stream(conn, stmt_for_exec, binds, tx).await;
97                // We need to close any non-cached statement explicitly here as otherwise
98                // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
99                // for details
100                //
101                // This might be problematic for cases where the stream is dropped before the end is reached
102                //
103                // Such behaviour might happen if users:
104                // * Just drop the future/stream after polling at least once (timeouts!!)
105                // * Users only fetch a fixed number of elements from the stream
106                //
107                // For now there is not really a good solution to this problem as this would require something like async drop
108                // (and even with async drop that would be really hard to solve due to the involved lifetimes)
109                if let MaybeCached::CannotCache(stmt) = stmt {
110                    conn.close(stmt).await.map_err(ErrorHelper)?;
111                }
112                r
113            };
114
115            let fake_stream = stream::once(yielder).filter_map(|e: QueryResult<()>| async move {
116                if let Err(e) = e {
117                    Some(Err(e))
118                } else {
119                    None
120                }
121            });
122
123            let stream = stream::select(fake_stream, rx).boxed();
124
125            Ok(stream)
126        })
127        .boxed()
128    }
129
130    fn execute_returning_count<'conn, 'query, T>(
131        &'conn mut self,
132        source: T,
133    ) -> Self::ExecuteFuture<'conn, 'query>
134    where
135        T: diesel::query_builder::QueryFragment<Self::Backend>
136            + diesel::query_builder::QueryId
137            + 'query,
138    {
139        self.with_prepared_statement(source, |conn, stmt, binds| async move {
140            let params = mysql_async::Params::try_from(binds)?;
141            conn.exec_drop(&*stmt, params).await.map_err(ErrorHelper)?;
142            // We need to close any non-cached statement explicitly here as otherwise
143            // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
144            // for details
145            //
146            // This might be problematic for cases where the stream is dropped before the end is reached
147            //
148            // Such behaviour might happen if users:
149            // * Just drop the future after polling at least once (timeouts!!)
150            //
151            // For now there is not really a good solution to this problem as this would require something like async drop
152            // (and even with async drop that would be really hard to solve due to the involved lifetimes)
153            if let MaybeCached::CannotCache(stmt) = stmt {
154                conn.close(stmt).await.map_err(ErrorHelper)?;
155            }
156            conn.affected_rows()
157                .try_into()
158                .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
159        })
160    }
161}
162
163impl AsyncConnection for AsyncMysqlConnection {
164    type TransactionManager = AnsiTransactionManager;
165
166    async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
167        let mut instrumentation = DynInstrumentation::default_instrumentation();
168        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
169            database_url,
170        ));
171        let r = Self::establish_connection_inner(database_url).await;
172        instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
173            database_url,
174            r.as_ref().err(),
175        ));
176        let mut conn = r?;
177        conn.instrumentation = instrumentation;
178        Ok(conn)
179    }
180
181    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
182        &mut self.transaction_manager
183    }
184
185    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
186        &mut *self.instrumentation
187    }
188
189    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
190        self.instrumentation = instrumentation.into();
191    }
192
193    fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
194        self.stmt_cache.set_cache_size(size);
195    }
196}
197
198#[inline(always)]
199fn update_transaction_manager_status<T>(
200    query_result: QueryResult<T>,
201    transaction_manager: &mut AnsiTransactionManager,
202) -> QueryResult<T> {
203    if let Err(diesel::result::Error::DatabaseError(
204        diesel::result::DatabaseErrorKind::SerializationFailure,
205        _,
206    )) = query_result
207    {
208        transaction_manager
209            .status
210            .set_requires_rollback_maybe_up_to_top_level(true)
211    }
212    query_result
213}
214
215fn prepare_statement_helper<'a>(
216    conn: &'a mut mysql_async::Conn,
217    sql: &str,
218    _is_for_cache: diesel::connection::statement_cache::PrepareForCache,
219    _metadata: &[MysqlType],
220) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
221{
222    // ideally we wouldn't clone the SQL string here
223    // but as we usually cache statements anyway
224    // this is a fixed one time const
225    //
226    // The probleme with not cloning it is that we then cannot express
227    // the right result lifetime anymore (at least not easily)
228    let sql = sql.to_owned();
229    CallbackHelper(async move {
230        let s = conn.prep(sql).await.map_err(ErrorHelper)?;
231        Ok((s, conn))
232    })
233}
234
235impl AsyncMysqlConnection {
236    /// Wrap an existing [`mysql_async::Conn`] into a async diesel mysql connection
237    ///
238    /// This function constructs a new `AsyncMysqlConnection` based on an existing
239    /// [`mysql_async::Conn]`.
240    pub async fn try_from(conn: mysql_async::Conn) -> ConnectionResult<Self> {
241        use crate::run_query_dsl::RunQueryDsl;
242        let mut conn = AsyncMysqlConnection {
243            conn,
244            stmt_cache: StatementCache::new(),
245            transaction_manager: AnsiTransactionManager::default(),
246            instrumentation: DynInstrumentation::default_instrumentation(),
247        };
248
249        for stmt in CONNECTION_SETUP_QUERIES {
250            diesel::sql_query(*stmt)
251                .execute(&mut conn)
252                .await
253                .map_err(ConnectionError::CouldntSetupConfiguration)?;
254        }
255
256        Ok(conn)
257    }
258
259    /// Constructs a cancellation token that can later be used to request cancellation of a query running on the connection associated with this client.
260    pub fn cancel_token(&self) -> MysqlCancelToken {
261        let kill_id = self.conn.id();
262        let opts = self.conn.opts().clone();
263
264        MysqlCancelToken { kill_id, opts }
265    }
266
267    fn with_prepared_statement<'conn, T, F, R>(
268        &'conn mut self,
269        query: T,
270        callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
271            + Send
272            + 'conn,
273    ) -> BoxFuture<'conn, QueryResult<R>>
274    where
275        R: Send + 'conn,
276        T: QueryFragment<Mysql> + QueryId,
277        F: Future<Output = QueryResult<R>> + Send,
278    {
279        self.instrumentation()
280            .on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
281                &query,
282            )));
283        let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
284        let bind_collector = query
285            .collect_binds(&mut bind_collector, &mut (), &Mysql)
286            .map(|()| bind_collector);
287
288        let AsyncMysqlConnection {
289            ref mut conn,
290            ref mut stmt_cache,
291            ref mut transaction_manager,
292            ref mut instrumentation,
293            ..
294        } = self;
295
296        let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql);
297        let mut qb = MysqlQueryBuilder::new();
298        let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish());
299        let query_id = T::query_id();
300
301        async move {
302            let RawBytesBindCollector {
303                metadata, binds, ..
304            } = bind_collector?;
305            let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
306            let sql = sql?;
307            let helper = QueryFragmentHelper {
308                sql,
309                safe_to_cache: is_safe_to_cache_prepared,
310            };
311            let inner = async {
312                let (stmt, conn) = stmt_cache
313                    .cached_statement_non_generic(
314                        query_id,
315                        &helper,
316                        &Mysql,
317                        &metadata,
318                        conn,
319                        prepare_statement_helper,
320                        &mut **instrumentation,
321                    )
322                    .await?;
323                callback(conn, stmt, ToSqlHelper { metadata, binds }).await
324            };
325            let r = update_transaction_manager_status(inner.await, transaction_manager);
326            instrumentation.on_connection_event(InstrumentationEvent::finish_query(
327                &StrQueryHelper::new(&helper.sql),
328                r.as_ref().err(),
329            ));
330            r
331        }
332        .boxed()
333    }
334
335    async fn poll_result_stream(
336        conn: &mut mysql_async::Conn,
337        stmt_for_exec: mysql_async::Statement,
338        binds: ToSqlHelper,
339        mut tx: futures_channel::mpsc::Sender<QueryResult<MysqlRow>>,
340    ) -> QueryResult<()> {
341        use futures_util::sink::SinkExt;
342        let params = mysql_async::Params::try_from(binds)?;
343
344        let res = conn
345            .exec_iter(stmt_for_exec, params)
346            .await
347            .map_err(ErrorHelper)?;
348
349        let mut stream = res
350            .stream_and_drop::<MysqlRow>()
351            .await
352            .map_err(ErrorHelper)?
353            .ok_or_else(|| {
354                diesel::result::Error::DeserializationError(Box::new(
355                    diesel::result::UnexpectedEndOfRow,
356                ))
357            })?
358            .map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
359
360        while let Some(row) = stream.next().await {
361            let row = row?;
362            tx.send(Ok(row))
363                .await
364                .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))?;
365        }
366
367        Ok(())
368    }
369
370    async fn establish_connection_inner(
371        database_url: &str,
372    ) -> Result<AsyncMysqlConnection, ConnectionError> {
373        let opts = Opts::from_url(database_url)
374            .map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
375        let builder = OptsBuilder::from_opts(opts)
376            .init(CONNECTION_SETUP_QUERIES.to_vec())
377            .stmt_cache_size(0) // We have our own cache
378            .client_found_rows(true); // This allows a consistent behavior between MariaDB/MySQL and PostgreSQL (and is already set in `diesel`)
379
380        let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
381
382        Ok(AsyncMysqlConnection {
383            conn,
384            stmt_cache: StatementCache::new(),
385            transaction_manager: AnsiTransactionManager::default(),
386            instrumentation: DynInstrumentation::none(),
387        })
388    }
389}
390
// Only compiled when at least one connection pool integration is enabled.
// The impl body is empty, so no trait item is overridden here.
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
398
#[cfg(test)]
mod tests {
    use crate::RunQueryDsl;
    // The shared doctest setup refers to the crate as `diesel_async`,
    // so provide that name inside the crate itself.
    mod diesel_async {
        pub use crate::*;
    }
    include!("../doctest_setup.rs");

    /// Regression test for statements not being released: runs more queries
    /// than the server allows open prepared statements, which only succeeds if
    /// statements are closed or reused along the way.
    #[tokio::test]
    async fn check_statements_are_dropped() {
        use self::schema::users;

        let mut conn = establish_connection().await;
        // we cannot set a lower limit here without admin privileges
        // which makes this test really slow
        //
        // 16382 is MySQL's default `max_prepared_stmt_count`
        let stmt_count = 16382 + 10;

        for i in 0..stmt_count {
            diesel::insert_into(users::table)
                .values(Some(users::name.eq(format!("User{i}"))))
                .execute(&mut conn)
                .await
                .unwrap();
        }

        #[derive(QueryableByName)]
        #[diesel(table_name = users)]
        #[allow(dead_code)] // the fields only exist so rows can be deserialized
        struct User {
            id: i32,
            name: String,
        }

        for i in 0..stmt_count {
            diesel::sql_query("SELECT id, name FROM users WHERE name = ?")
                .bind::<diesel::sql_types::Text, _>(format!("User{i}"))
                .load::<User>(&mut conn)
                .await
                .unwrap();
        }
    }
}
441
442impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
443    fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
444        Ok(self.sql.clone())
445    }
446
447    fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
448        Ok(self.safe_to_cache)
449    }
450}