// db_pool/async/backend/mysql/sqlx.rs

1use std::{borrow::Cow, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use sqlx::{
6    Connection, Executor, MySql, MySqlConnection, MySqlPool, Row,
7    mysql::{MySqlConnectOptions, MySqlPoolOptions},
8    pool::PoolConnection,
9};
10use uuid::Uuid;
11
12use crate::{common::statement::mysql, util::get_db_name};
13
14use super::{
15    super::{
16        common::error::sqlx::{BuildError, ConnectionError, PoolError, QueryError},
17        error::Error as BackendError,
18        r#trait::Backend,
19    },
20    r#trait::{MySQLBackend, MySQLBackendWrapper},
21};
22
23type CreateEntities = dyn Fn(MySqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
24    + Send
25    + Sync
26    + 'static;
27
28/// [`sqlx MySQL`](https://docs.rs/sqlx/0.8.6/sqlx/struct.MySql.html) backend
29pub struct SqlxMySQLBackend {
30    privileged_opts: MySqlConnectOptions,
31    default_pool: MySqlPool,
32    create_restricted_pool: Box<dyn Fn() -> MySqlPoolOptions + Send + Sync + 'static>,
33    create_entities: Box<CreateEntities>,
34    drop_previous_databases_flag: bool,
35}
36
37impl SqlxMySQLBackend {
38    /// Creates a new [`sqlx MySQL`](https://docs.rs/sqlx/0.8.6/sqlx/struct.MySql.html) backend
39    /// # Example
40    /// ```
41    /// use db_pool::{r#async::SqlxMySQLBackend, PrivilegedMySQLConfig};
42    /// use dotenvy::dotenv;
43    /// use sqlx::{mysql::MySqlPoolOptions, Executor};
44    ///
45    /// async fn f() {
46    ///     dotenv().ok();
47    ///
48    ///     let config = PrivilegedMySQLConfig::from_env().unwrap();
49    ///
50    ///     let backend = SqlxMySQLBackend::new(
51    ///         config.into(),
52    ///         || MySqlPoolOptions::new().max_connections(10),
53    ///         || MySqlPoolOptions::new().max_connections(2),
54    ///         move |mut conn| {
55    ///             Box::pin(async move {
56    ///                 conn.execute("CREATE TABLE book(id INTEGER PRIMARY KEY AUTO_INCREMENT, title TEXT NOT NULL)")
57    ///                      .await
58    ///                      .unwrap();
59    ///             })
60    ///         },
61    ///     );
62    /// }
63    ///
64    /// tokio_test::block_on(f());
65    /// ```
66    pub fn new(
67        privileged_options: MySqlConnectOptions,
68        create_privileged_pool: impl Fn() -> MySqlPoolOptions,
69        create_restricted_pool: impl Fn() -> MySqlPoolOptions + Send + Sync + 'static,
70        create_entities: impl Fn(MySqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
71        + Send
72        + Sync
73        + 'static,
74    ) -> Self {
75        let pool_opts = create_privileged_pool();
76        let default_pool = pool_opts.connect_lazy_with(privileged_options.clone());
77
78        Self {
79            privileged_opts: privileged_options,
80            default_pool,
81            create_restricted_pool: Box::new(create_restricted_pool),
82            create_entities: Box::new(create_entities),
83            drop_previous_databases_flag: true,
84        }
85    }
86
87    /// Drop databases created in previous runs upon initialization
88    #[must_use]
89    pub fn drop_previous_databases(self, value: bool) -> Self {
90        Self {
91            drop_previous_databases_flag: value,
92            ..self
93        }
94    }
95}
96
97#[async_trait]
98impl<'pool> MySQLBackend<'pool> for SqlxMySQLBackend {
99    type Connection = MySqlConnection;
100    type PooledConnection = PoolConnection<MySql>;
101    type Pool = MySqlPool;
102
103    type BuildError = BuildError;
104    type PoolError = PoolError;
105    type ConnectionError = ConnectionError;
106    type QueryError = QueryError;
107
108    async fn get_connection(&'pool self) -> Result<PoolConnection<MySql>, PoolError> {
109        self.default_pool.acquire().await.map_err(Into::into)
110    }
111
112    async fn execute_query(
113        &self,
114        query: &str,
115        conn: &mut MySqlConnection,
116    ) -> Result<(), QueryError> {
117        conn.execute(query).await?;
118        Ok(())
119    }
120
121    async fn batch_execute_query<'a>(
122        &self,
123        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
124        conn: &mut MySqlConnection,
125    ) -> Result<(), QueryError> {
126        let chunks = query.into_iter().collect::<Vec<_>>();
127        if chunks.is_empty() {
128            Ok(())
129        } else {
130            let query = chunks.join(";");
131            self.execute_query(query.as_str(), conn).await
132        }
133    }
134
135    fn get_host(&self) -> &str {
136        self.privileged_opts.get_host()
137    }
138
139    async fn get_previous_database_names(
140        &self,
141        conn: &mut MySqlConnection,
142    ) -> Result<Vec<String>, QueryError> {
143        conn.fetch_all(mysql::GET_DATABASE_NAMES)
144            .await?
145            .iter()
146            .map(|row| row.try_get(0))
147            .collect::<Result<Vec<_>, _>>()
148            .map_err(Into::into)
149    }
150
151    async fn create_entities(&self, db_name: &str) -> Result<(), ConnectionError> {
152        let opts = self.privileged_opts.clone().database(db_name);
153        let conn = MySqlConnection::connect_with(&opts).await?;
154        (self.create_entities)(conn).await;
155        Ok(())
156    }
157
158    async fn create_connection_pool(&self, db_id: Uuid) -> Result<MySqlPool, BuildError> {
159        let db_name = get_db_name(db_id);
160        let db_name = db_name.as_str();
161        let opts = self
162            .privileged_opts
163            .clone()
164            .database(db_name)
165            .username(db_name)
166            .password(db_name);
167        let pool = (self.create_restricted_pool)().connect_lazy_with(opts);
168        Ok(pool)
169    }
170
171    async fn get_table_names(
172        &self,
173        db_name: &str,
174        conn: &mut MySqlConnection,
175    ) -> Result<Vec<String>, QueryError> {
176        conn.fetch_all(mysql::get_table_names(db_name).as_str())
177            .await?
178            .iter()
179            .map(|row| row.try_get(0))
180            .collect::<Result<Vec<_>, _>>()
181            .map_err(Into::into)
182    }
183
184    fn get_drop_previous_databases(&self) -> bool {
185        self.drop_previous_databases_flag
186    }
187}
188
189type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
190
191#[async_trait]
192impl Backend for SqlxMySQLBackend {
193    type Pool = MySqlPool;
194
195    type BuildError = BuildError;
196    type PoolError = PoolError;
197    type ConnectionError = ConnectionError;
198    type QueryError = QueryError;
199
200    async fn init(&self) -> Result<(), BError> {
201        MySQLBackendWrapper::new(self).init().await
202    }
203
204    async fn create(
205        &self,
206        db_id: uuid::Uuid,
207        restrict_privileges: bool,
208    ) -> Result<MySqlPool, BError> {
209        MySQLBackendWrapper::new(self)
210            .create(db_id, restrict_privileges)
211            .await
212    }
213
214    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
215        MySQLBackendWrapper::new(self).clean(db_id).await
216    }
217
218    async fn drop(&self, db_id: uuid::Uuid, _is_restricted: bool) -> Result<(), BError> {
219        MySQLBackendWrapper::new(self).drop(db_id).await
220    }
221}
222
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use futures::{StreamExt, future::join_all};
    use sqlx::{
        Executor, FromRow, Row,
        mysql::{MySqlConnectOptions, MySqlPoolOptions},
        query, query_as,
    };
    use tokio_shared_rt::test;

    use crate::{
        r#async::db_pool::DatabasePoolBuilder,
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            MySQLDropLock, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_database, test_backend_drops_previous_databases,
            test_pool_drops_created_restricted_databases,
            test_pool_drops_created_unrestricted_database, test_pool_drops_previous_databases,
        },
        SqlxMySQLBackend,
    };

    /// Builds a backend that connects as the privileged user from the test configuration.
    ///
    /// When `with_table` is `true`, every created database runs `CREATE_ENTITIES_STATEMENTS`;
    /// otherwise created databases are left empty.
    fn create_backend(with_table: bool) -> SqlxMySQLBackend {
        let config = get_privileged_mysql_config();
        let opts = MySqlConnectOptions::new().username(config.username.as_str());
        let opts = if let Some(password) = &config.password {
            opts.password(password)
        } else {
            opts
        };
        SqlxMySQLBackend::new(opts, MySqlPoolOptions::new, MySqlPoolOptions::new, {
            move |mut conn| {
                if with_table {
                    Box::pin(async move {
                        // One result per statement: gather them all, then fail on any error.
                        conn.execute_many(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
                            .collect::<Vec<_>>()
                            .await
                            .into_iter()
                            .collect::<Result<Vec<_>, _>>()
                            .unwrap();
                    })
                } else {
                    Box::pin(async {})
                }
            }
        })
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromRow, Eq, PartialEq, Debug)]
        struct Book {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // insert single row into each database
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        query("INSERT INTO book (title) VALUES (?)")
                            .bind(format!("Title {i}"))
                            .execute(&***conn_pool)
                            .await
                            .unwrap();
                    }),
            )
            .await;

            // rows fetched must be as inserted
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        assert_eq!(
                            query_as::<_, Book>("SELECT title FROM book")
                                .fetch_all(&***conn_pool)
                                .await
                                .unwrap(),
                            vec![Book {
                                title: format!("Title {i}")
                            }]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.acquire().await.unwrap();

            // DDL statements must fail
            for stmt in DDL_STATEMENTS {
                assert!(
                    conn.execute(stmt).await.is_err(),
                    "DDL statement unexpectedly succeeded on a restricted database: {stmt}"
                );
            }

            // DML statements must succeed
            for stmt in DML_STATEMENTS {
                let result = conn.execute(stmt).await;
                assert!(
                    result.is_ok(),
                    "DML statement failed on a restricted database: {stmt}: {:?}",
                    result.err()
                );
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // DML statements must succeed
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                for stmt in DML_STATEMENTS {
                    let result = conn.execute(stmt).await;
                    assert!(
                        result.is_ok(),
                        "DML statement failed on an unrestricted database: {stmt}: {:?}",
                        result.err()
                    );
                }
            }

            // DDL statements must succeed (each on its own fresh database)
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                let result = conn.execute(stmt).await;
                assert!(
                    result.is_ok(),
                    "DDL statement failed on an unrestricted database: {stmt}: {:?}",
                    result.err()
                );
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // fetch connection pools the first time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;

                // insert data into each database
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    query("INSERT INTO book (title) VALUES (?)")
                        .bind("Title")
                        .execute(&***conn_pool)
                        .await
                        .unwrap();
                }))
                .await;
            }

            // fetch same connection pools a second time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}