diesel_async/mysql/
mod.rs

1use crate::statement_cache::CacheSize;
2use crate::statement_cache::{MaybeCached, QueryFragmentForCachedStatement, StatementCache};
3use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
4use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
5use diesel::connection::Instrumentation;
6use diesel::connection::InstrumentationEvent;
7use diesel::connection::StrQueryHelper;
8use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
9use diesel::query_builder::QueryBuilder;
10use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
11use diesel::result::{ConnectionError, ConnectionResult};
12use diesel::QueryResult;
13use futures_core::future::BoxFuture;
14use futures_core::stream::BoxStream;
15use futures_util::stream;
16use futures_util::{FutureExt, StreamExt, TryStreamExt};
17use mysql_async::prelude::Queryable;
18use mysql_async::{Opts, OptsBuilder, Statement};
19use std::future::Future;
20
21mod error_helper;
22mod row;
23mod serialize;
24
25use self::error_helper::ErrorHelper;
26use self::row::MysqlRow;
27use self::serialize::ToSqlHelper;
28
29/// A connection to a MySQL database. Connection URLs should be in the form
30/// `mysql://[user[:password]@]host/database_name`
31pub struct AsyncMysqlConnection {
32    conn: mysql_async::Conn,
33    stmt_cache: StatementCache<Mysql, Statement>,
34    transaction_manager: AnsiTransactionManager,
35    instrumentation: Option<Box<dyn Instrumentation>>,
36}
37
38impl SimpleAsyncConnection for AsyncMysqlConnection {
39    async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
40        self.instrumentation()
41            .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
42                query,
43            )));
44        let result = self
45            .conn
46            .query_drop(query)
47            .await
48            .map_err(ErrorHelper)
49            .map_err(Into::into);
50        self.instrumentation()
51            .on_connection_event(InstrumentationEvent::finish_query(
52                &StrQueryHelper::new(query),
53                result.as_ref().err(),
54            ));
55        result
56    }
57}
58
/// Session setup statements executed on every new connection.
///
/// They pin the session time zone to UTC and switch the client, connection and
/// result character sets to `utf8mb4`.
///
/// `establish` passes them to `mysql_async` as init queries. `try_from` runs them
/// through diesel, which sends them as server-side prepared statements. MySQL's
/// prepared-statement API expects a single statement without a terminating
/// semicolon, so none of these entries carries one.
const CONNECTION_SETUP_QUERIES: &[&str] = &[
    "SET time_zone = '+00:00'",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
    "SET character_set_results = 'utf8mb4'",
];
65
66impl AsyncConnection for AsyncMysqlConnection {
67    type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
68    type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
69    type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
70    type Row<'conn, 'query> = MysqlRow;
71    type Backend = Mysql;
72
73    type TransactionManager = AnsiTransactionManager;
74
75    async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
76        let mut instrumentation = diesel::connection::get_default_instrumentation();
77        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
78            database_url,
79        ));
80        let r = Self::establish_connection_inner(database_url).await;
81        instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
82            database_url,
83            r.as_ref().err(),
84        ));
85        let mut conn = r?;
86        conn.instrumentation = instrumentation;
87        Ok(conn)
88    }
89
90    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
91    where
92        T: diesel::query_builder::AsQuery,
93        T::Query: diesel::query_builder::QueryFragment<Self::Backend>
94            + diesel::query_builder::QueryId
95            + 'query,
96    {
97        self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
98            let stmt_for_exec = match stmt {
99                MaybeCached::Cached(ref s) => (*s).clone(),
100                MaybeCached::CannotCache(ref s) => s.clone(),
101            };
102
103            let (tx, rx) = futures_channel::mpsc::channel(0);
104
105            let yielder = async move {
106                let r = Self::poll_result_stream(conn, stmt_for_exec, binds, tx).await;
107                // We need to close any non-cached statement explicitly here as otherwise
108                // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
109                // for details
110                //
111                // This might be problematic for cases where the stream is dropped before the end is reached
112                //
113                // Such behaviour might happen if users:
114                // * Just drop the future/stream after polling at least once (timeouts!!)
115                // * Users only fetch a fixed number of elements from the stream
116                //
117                // For now there is not really a good solution to this problem as this would require something like async drop
118                // (and even with async drop that would be really hard to solve due to the involved lifetimes)
119                if let MaybeCached::CannotCache(stmt) = stmt {
120                    conn.close(stmt).await.map_err(ErrorHelper)?;
121                }
122                r
123            };
124
125            let fake_stream = stream::once(yielder).filter_map(|e: QueryResult<()>| async move {
126                if let Err(e) = e {
127                    Some(Err(e))
128                } else {
129                    None
130                }
131            });
132
133            let stream = stream::select(fake_stream, rx).boxed();
134
135            Ok(stream)
136        })
137        .boxed()
138    }
139
140    fn execute_returning_count<'conn, 'query, T>(
141        &'conn mut self,
142        source: T,
143    ) -> Self::ExecuteFuture<'conn, 'query>
144    where
145        T: diesel::query_builder::QueryFragment<Self::Backend>
146            + diesel::query_builder::QueryId
147            + 'query,
148    {
149        self.with_prepared_statement(source, |conn, stmt, binds| async move {
150            let params = mysql_async::Params::try_from(binds)?;
151            conn.exec_drop(&*stmt, params).await.map_err(ErrorHelper)?;
152            // We need to close any non-cached statement explicitly here as otherwise
153            // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
154            // for details
155            //
156            // This might be problematic for cases where the stream is dropped before the end is reached
157            //
158            // Such behaviour might happen if users:
159            // * Just drop the future after polling at least once (timeouts!!)
160            //
161            // For now there is not really a good solution to this problem as this would require something like async drop
162            // (and even with async drop that would be really hard to solve due to the involved lifetimes)
163            if let MaybeCached::CannotCache(stmt) = stmt {
164                conn.close(stmt).await.map_err(ErrorHelper)?;
165            }
166            conn.affected_rows()
167                .try_into()
168                .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
169        })
170    }
171
172    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
173        &mut self.transaction_manager
174    }
175
176    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
177        &mut self.instrumentation
178    }
179
180    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
181        self.instrumentation = Some(Box::new(instrumentation));
182    }
183
184    fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
185        self.stmt_cache.set_cache_size(size);
186    }
187}
188
189#[inline(always)]
190fn update_transaction_manager_status<T>(
191    query_result: QueryResult<T>,
192    transaction_manager: &mut AnsiTransactionManager,
193) -> QueryResult<T> {
194    if let Err(diesel::result::Error::DatabaseError(
195        diesel::result::DatabaseErrorKind::SerializationFailure,
196        _,
197    )) = query_result
198    {
199        transaction_manager
200            .status
201            .set_requires_rollback_maybe_up_to_top_level(true)
202    }
203    query_result
204}
205
206fn prepare_statement_helper<'a>(
207    conn: &'a mut mysql_async::Conn,
208    sql: &str,
209    _is_for_cache: crate::statement_cache::PrepareForCache,
210    _metadata: &[MysqlType],
211) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
212{
213    // ideally we wouldn't clone the SQL string here
214    // but as we usually cache statements anyway
215    // this is a fixed one time const
216    //
217    // The probleme with not cloning it is that we then cannot express
218    // the right result lifetime anymore (at least not easily)
219    let sql = sql.to_owned();
220    CallbackHelper(async move {
221        let s = conn.prep(sql).await.map_err(ErrorHelper)?;
222        Ok((s, conn))
223    })
224}
225
226impl AsyncMysqlConnection {
227    /// Wrap an existing [`mysql_async::Conn`] into a async diesel mysql connection
228    ///
229    /// This function constructs a new `AsyncMysqlConnection` based on an existing
230    /// [`mysql_async::Conn]`.
231    pub async fn try_from(conn: mysql_async::Conn) -> ConnectionResult<Self> {
232        use crate::run_query_dsl::RunQueryDsl;
233        let mut conn = AsyncMysqlConnection {
234            conn,
235            stmt_cache: StatementCache::new(),
236            transaction_manager: AnsiTransactionManager::default(),
237            instrumentation: diesel::connection::get_default_instrumentation(),
238        };
239
240        for stmt in CONNECTION_SETUP_QUERIES {
241            diesel::sql_query(*stmt)
242                .execute(&mut conn)
243                .await
244                .map_err(ConnectionError::CouldntSetupConfiguration)?;
245        }
246
247        Ok(conn)
248    }
249
250    fn with_prepared_statement<'conn, T, F, R>(
251        &'conn mut self,
252        query: T,
253        callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
254            + Send
255            + 'conn,
256    ) -> BoxFuture<'conn, QueryResult<R>>
257    where
258        R: Send + 'conn,
259        T: QueryFragment<Mysql> + QueryId,
260        F: Future<Output = QueryResult<R>> + Send,
261    {
262        self.instrumentation()
263            .on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
264                &query,
265            )));
266        let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
267        let bind_collector = query
268            .collect_binds(&mut bind_collector, &mut (), &Mysql)
269            .map(|()| bind_collector);
270
271        let AsyncMysqlConnection {
272            ref mut conn,
273            ref mut stmt_cache,
274            ref mut transaction_manager,
275            ref mut instrumentation,
276            ..
277        } = self;
278
279        let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql);
280        let mut qb = MysqlQueryBuilder::new();
281        let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish());
282        let query_id = T::query_id();
283
284        async move {
285            let RawBytesBindCollector {
286                metadata, binds, ..
287            } = bind_collector?;
288            let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
289            let sql = sql?;
290            let helper = QueryFragmentHelper {
291                sql,
292                safe_to_cache: is_safe_to_cache_prepared,
293            };
294
295            let inner = async {
296                let (stmt, conn) = {
297                    stmt_cache
298                        .cached_statement_non_generic(
299                            query_id,
300                            &helper,
301                            &Mysql,
302                            &metadata,
303                            conn,
304                            prepare_statement_helper,
305                            &mut *instrumentation,
306                        )
307                        .await?
308                };
309                callback(conn, stmt, ToSqlHelper { metadata, binds }).await
310            };
311            let r = update_transaction_manager_status(inner.await, transaction_manager);
312
313            instrumentation.on_connection_event(InstrumentationEvent::finish_query(
314                &StrQueryHelper::new(&helper.sql),
315                r.as_ref().err(),
316            ));
317            r
318        }
319        .boxed()
320    }
321
322    async fn poll_result_stream(
323        conn: &mut mysql_async::Conn,
324        stmt_for_exec: mysql_async::Statement,
325        binds: ToSqlHelper,
326        mut tx: futures_channel::mpsc::Sender<QueryResult<MysqlRow>>,
327    ) -> QueryResult<()> {
328        use futures_util::sink::SinkExt;
329        let params = mysql_async::Params::try_from(binds)?;
330
331        let res = conn
332            .exec_iter(stmt_for_exec, params)
333            .await
334            .map_err(ErrorHelper)?;
335
336        let mut stream = res
337            .stream_and_drop::<MysqlRow>()
338            .await
339            .map_err(ErrorHelper)?
340            .ok_or_else(|| {
341                diesel::result::Error::DeserializationError(Box::new(
342                    diesel::result::UnexpectedEndOfRow,
343                ))
344            })?
345            .map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
346
347        while let Some(row) = stream.next().await {
348            let row = row?;
349            tx.send(Ok(row))
350                .await
351                .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))?;
352        }
353
354        Ok(())
355    }
356
357    async fn establish_connection_inner(
358        database_url: &str,
359    ) -> Result<AsyncMysqlConnection, ConnectionError> {
360        let opts = Opts::from_url(database_url)
361            .map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
362        let builder = OptsBuilder::from_opts(opts)
363            .init(CONNECTION_SETUP_QUERIES.to_vec())
364            .stmt_cache_size(0) // We have our own cache
365            .client_found_rows(true); // This allows a consistent behavior between MariaDB/MySQL and PostgreSQL (and is already set in `diesel`)
366
367        let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
368
369        Ok(AsyncMysqlConnection {
370            conn,
371            stmt_cache: StatementCache::new(),
372            transaction_manager: AnsiTransactionManager::default(),
373            instrumentation: diesel::connection::get_default_instrumentation(),
374        })
375    }
376}
377
/// Lets `AsyncMysqlConnection` be managed by any of the optional connection-pool
/// integrations (enabled through the `bb8`, `deadpool`, `mobc` or `r2d2` features).
#[cfg(any(
    feature = "bb8",
    feature = "deadpool",
    feature = "mobc",
    feature = "r2d2",
))]
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
385
#[cfg(test)]
mod tests {
    use crate::RunQueryDsl;
    mod diesel_async {
        pub use crate::*;
    }
    include!("../doctest_setup.rs");

    /// Regression test for https://github.com/weiznich/diesel_async/issues/26.
    ///
    /// It runs more statements than the server allows to be open at once. A
    /// non-cached prepared statement that is not closed after use would make one
    /// of the later queries fail.
    #[tokio::test]
    async fn check_statements_are_dropped() {
        use self::schema::users;

        #[derive(QueryableByName)]
        #[diesel(table_name = users)]
        #[allow(dead_code)]
        struct User {
            id: i32,
            name: String,
        }

        // 16382 matches MySQL's default `max_prepared_stmt_count`. Lowering the limit
        // would need admin privileges, which is why this test is slow.
        let iterations = 16382 + 10;
        let mut conn = establish_connection().await;

        for idx in 0..iterations {
            let user_name = format!("User{idx}");
            diesel::insert_into(users::table)
                .values(Some(users::name.eq(user_name)))
                .execute(&mut conn)
                .await
                .unwrap();
        }

        for idx in 0..iterations {
            let user_name = format!("User{idx}");
            diesel::sql_query("SELECT id, name FROM users WHERE name = ?")
                .bind::<diesel::sql_types::Text, _>(user_name)
                .load::<User>(&mut conn)
                .await
                .unwrap();
        }
    }
}
428
429impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
430    fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
431        Ok(self.sql.clone())
432    }
433
434    fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
435        Ok(self.safe_to_cache)
436    }
437}