db_pool/async/backend/mysql/
diesel.rs

1use std::{borrow::Cow, pin::Pin};
2
3use async_trait::async_trait;
4use diesel::{prelude::*, result::Error, sql_query, table};
5use diesel_async::{
6    pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, SetupCallback},
7    AsyncConnection, AsyncMysqlConnection, RunQueryDsl, SimpleAsyncConnection,
8};
9use futures::{future::FutureExt, Future};
10use uuid::Uuid;
11
12use crate::{
13    common::{config::mysql::PrivilegedMySQLConfig, statement::mysql},
14    util::get_db_name,
15};
16
17use super::{
18    super::{
19        common::pool::diesel::r#trait::DieselPoolAssociation, error::Error as BackendError,
20        r#trait::Backend,
21    },
22    r#trait::{MySQLBackend, MySQLBackendWrapper},
23};
24
25type CreateEntities = dyn Fn(AsyncMysqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
26    + Send
27    + Sync
28    + 'static;
29
30/// [`Diesel async MySQL`](https://docs.rs/diesel-async/0.5.0/diesel_async/struct.AsyncMysqlConnection.html) backend
31pub struct DieselAsyncMySQLBackend<P: DieselPoolAssociation<AsyncMysqlConnection>> {
32    privileged_config: PrivilegedMySQLConfig,
33    default_pool: P::Pool,
34    create_restricted_pool: Box<dyn Fn() -> P::Builder + Send + Sync + 'static>,
35    create_connection: Box<dyn Fn() -> SetupCallback<AsyncMysqlConnection> + Send + Sync + 'static>,
36    create_entities: Box<CreateEntities>,
37    drop_previous_databases_flag: bool,
38}
39
40impl<P: DieselPoolAssociation<AsyncMysqlConnection>> DieselAsyncMySQLBackend<P> {
41    /// Creates a new [`Diesel async MySQL`](https://docs.rs/diesel-async/0.5.0/diesel_async/struct.AsyncMysqlConnection.html) backend
42    /// # Example
43    /// ```
44    /// use bb8::Pool;
45    /// use db_pool::{
46    ///     r#async::{DieselAsyncMySQLBackend, DieselBb8},
47    ///     PrivilegedMySQLConfig,
48    /// };
49    /// use diesel::sql_query;
50    /// use diesel_async::RunQueryDsl;
51    /// use dotenvy::dotenv;
52    ///
53    /// async fn f() {
54    ///     dotenv().ok();
55    ///
56    ///     let config = PrivilegedMySQLConfig::from_env().unwrap();
57    ///
58    ///     let backend = DieselAsyncMySQLBackend::<DieselBb8>::new(
59    ///         config,
60    ///         || Pool::builder().max_size(10),
61    ///         || Pool::builder().max_size(2),
62    ///         None,
63    ///         move |mut conn| {
64    ///             Box::pin(async move {
65    ///                 sql_query("CREATE TABLE book(id INTEGER PRIMARY KEY AUTO_INCREMENT, title TEXT NOT NULL)")
66    ///                     .execute(&mut conn)
67    ///                     .await
68    ///                     .unwrap();
69    ///             })
70    ///         },
71    ///     )
72    ///     .await
73    ///     .unwrap();
74    /// }
75    ///
76    /// tokio_test::block_on(f());
77    /// ```
78    pub async fn new(
79        privileged_config: PrivilegedMySQLConfig,
80        create_privileged_pool: impl Fn() -> P::Builder,
81        create_restricted_pool: impl Fn() -> P::Builder + Send + Sync + 'static,
82        custom_create_connection: Option<
83            Box<dyn Fn() -> SetupCallback<AsyncMysqlConnection> + Send + Sync + 'static>,
84        >,
85        create_entities: impl Fn(AsyncMysqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
86            + Send
87            + Sync
88            + 'static,
89    ) -> Result<Self, P::BuildError> {
90        let create_connection = custom_create_connection.unwrap_or_else(|| {
91            Box::new(|| {
92                Box::new(|connection_url| AsyncMysqlConnection::establish(connection_url).boxed())
93            })
94        });
95
96        let manager_config = {
97            let mut config = ManagerConfig::default();
98            config.custom_setup = Box::new(create_connection());
99            config
100        };
101        let manager = AsyncDieselConnectionManager::new_with_config(
102            privileged_config.default_connection_url(),
103            manager_config,
104        );
105        let builder = create_privileged_pool();
106        let default_pool = P::build_pool(builder, manager).await?;
107
108        Ok(Self {
109            privileged_config,
110            default_pool,
111            create_restricted_pool: Box::new(create_restricted_pool),
112            create_connection: Box::new(create_connection),
113            create_entities: Box::new(create_entities),
114            drop_previous_databases_flag: true,
115        })
116    }
117
118    /// Drop databases created in previous runs upon initialization
119    #[must_use]
120    pub fn drop_previous_databases(self, value: bool) -> Self {
121        Self {
122            drop_previous_databases_flag: value,
123            ..self
124        }
125    }
126}
127
128#[async_trait]
129impl<'pool, P: DieselPoolAssociation<AsyncMysqlConnection>> MySQLBackend<'pool>
130    for DieselAsyncMySQLBackend<P>
131{
132    type Connection = AsyncMysqlConnection;
133    type PooledConnection = P::PooledConnection<'pool>;
134    type Pool = P::Pool;
135
136    type BuildError = P::BuildError;
137    type PoolError = P::PoolError;
138    type ConnectionError = ConnectionError;
139    type QueryError = Error;
140
141    async fn get_connection(&'pool self) -> Result<P::PooledConnection<'pool>, P::PoolError> {
142        P::get_connection(&self.default_pool).await
143    }
144
145    async fn execute_query(&self, query: &str, conn: &mut AsyncMysqlConnection) -> QueryResult<()> {
146        sql_query(query).execute(conn).await?;
147        Ok(())
148    }
149
150    async fn batch_execute_query<'a>(
151        &self,
152        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
153        conn: &mut AsyncMysqlConnection,
154    ) -> QueryResult<()> {
155        let query = query.into_iter().collect::<Vec<_>>();
156        if query.is_empty() {
157            Ok(())
158        } else {
159            conn.batch_execute(query.join(";").as_str()).await
160        }
161    }
162
163    fn get_host(&self) -> &str {
164        self.privileged_config.host.as_str()
165    }
166
167    async fn get_previous_database_names(
168        &self,
169        conn: &mut AsyncMysqlConnection,
170    ) -> QueryResult<Vec<String>> {
171        table! {
172            schemata (schema_name) {
173                schema_name -> Text
174            }
175        }
176
177        schemata::table
178            .select(schemata::schema_name)
179            .filter(schemata::schema_name.like("db_pool_%"))
180            .load::<String>(conn)
181            .await
182    }
183
184    async fn create_entities(&self, db_name: &str) -> Result<(), ConnectionError> {
185        let database_url = self
186            .privileged_config
187            .privileged_database_connection_url(db_name);
188        let conn = (self.create_connection)()(database_url.as_str()).await?;
189        (self.create_entities)(conn).await;
190        Ok(())
191    }
192
193    async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
194        let db_name = get_db_name(db_id);
195        let db_name = db_name.as_str();
196        let database_url = self.privileged_config.restricted_database_connection_url(
197            db_name,
198            Some(db_name),
199            db_name,
200        );
201        let manager_config = {
202            let mut config = ManagerConfig::default();
203            config.custom_setup = (self.create_connection)();
204            config
205        };
206        let manager = AsyncDieselConnectionManager::<AsyncMysqlConnection>::new_with_config(
207            database_url.as_str(),
208            manager_config,
209        );
210        let builder = (self.create_restricted_pool)();
211        P::build_pool(builder, manager).await
212    }
213
214    async fn get_table_names(
215        &self,
216        db_name: &str,
217        conn: &mut AsyncMysqlConnection,
218    ) -> QueryResult<Vec<String>> {
219        table! {
220            tables (table_name) {
221                table_name -> Text,
222                table_schema -> Text
223            }
224        }
225
226        sql_query(mysql::USE_DEFAULT_DATABASE).execute(conn).await?;
227
228        tables::table
229            .filter(tables::table_schema.eq(db_name))
230            .select(tables::table_name)
231            .load::<String>(conn)
232            .await
233    }
234
235    fn get_drop_previous_databases(&self) -> bool {
236        self.drop_previous_databases_flag
237    }
238}
239
240type BError<BuildError, PoolError> = BackendError<BuildError, PoolError, ConnectionError, Error>;
241
242#[async_trait]
243impl<P: DieselPoolAssociation<AsyncMysqlConnection>> Backend for DieselAsyncMySQLBackend<P> {
244    type Pool = P::Pool;
245
246    type BuildError = P::BuildError;
247    type PoolError = P::PoolError;
248    type ConnectionError = ConnectionError;
249    type QueryError = Error;
250
251    async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
252        MySQLBackendWrapper::new(self).init().await
253    }
254
255    async fn create(
256        &self,
257        db_id: uuid::Uuid,
258        restrict_privileges: bool,
259    ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
260        MySQLBackendWrapper::new(self)
261            .create(db_id, restrict_privileges)
262            .await
263    }
264
265    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
266        MySQLBackendWrapper::new(self).clean(db_id).await
267    }
268
269    async fn drop(
270        &self,
271        db_id: uuid::Uuid,
272        _is_restricted: bool,
273    ) -> Result<(), BError<P::BuildError, P::PoolError>> {
274        MySQLBackendWrapper::new(self).drop(db_id).await
275    }
276}
277
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use std::borrow::Cow;

    use bb8::Pool;
    use diesel::{insert_into, sql_query, table, Insertable, QueryDsl};
    use diesel_async::{RunQueryDsl, SimpleAsyncConnection};
    use futures::future::join_all;
    use tokio_shared_rt::test;

    use crate::{
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::{
                common::pool::diesel::bb8::DieselBb8,
                mysql::r#trait::tests::{
                    test_backend_creates_database_with_unrestricted_privileges,
                    test_pool_drops_created_unrestricted_database,
                },
            },
            db_pool::DatabasePoolBuilder,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases, MySQLDropLock,
        },
        DieselAsyncMySQLBackend,
    };

    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    /// Insertable row for the `book` table; `id` is left to the database.
    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }

    /// Builds a bb8-backed backend whose entity callback runs
    /// `CREATE_ENTITIES_STATEMENTS` when `with_table` is set and does nothing otherwise.
    async fn create_backend(with_table: bool) -> DieselAsyncMySQLBackend<DieselBb8> {
        let config = get_privileged_mysql_config().clone();

        // The closure stays inline so its boxed-future return type is taken from `new`'s bound.
        DieselAsyncMySQLBackend::new(
            config,
            Pool::builder,
            Pool::builder,
            None,
            move |mut conn| {
                if with_table {
                    Box::pin(async move {
                        let statements = CREATE_ENTITIES_STATEMENTS.join(";");
                        conn.batch_execute(statements.as_str()).await.unwrap();
                    })
                } else {
                    Box::pin(async {})
                }
            },
        )
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let enabled_backend = create_backend(false).await.drop_previous_databases(true);
        let disabled_backend = create_backend(false).await.drop_previous_databases(false);

        test_backend_drops_previous_databases(default_backend, enabled_backend, disabled_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            true,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            false,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let enabled_backend = create_backend(false).await.drop_previous_databases(true);
        let disabled_backend = create_backend(false).await.drop_previous_databases(false);

        test_pool_drops_previous_databases(default_backend, enabled_backend, disabled_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        let scenario = async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // write one distinct row into every database
            let inserts = conn_pools
                .iter()
                .enumerate()
                .map(|(index, conn_pool)| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let new_book = NewBook {
                        title: Cow::Owned(format!("Title {index}")),
                    };
                    insert_into(book::table)
                        .values(new_book)
                        .execute(&mut conn)
                        .await
                        .unwrap();
                });
            join_all(inserts).await;

            // every database must contain exactly the row written to it
            let checks = conn_pools
                .iter()
                .enumerate()
                .map(|(index, conn_pool)| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let titles = book::table
                        .select(book::title)
                        .load::<String>(&mut conn)
                        .await
                        .unwrap();
                    assert_eq!(titles, vec![format!("Title {index}")]);
                });
            join_all(checks).await;
        };

        scenario.lock_read().await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        let scenario = async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pool = db_pool.pull_immutable().await;
            let mut conn = conn_pool.get().await.unwrap();

            // schema changes must be rejected
            for stmt in DDL_STATEMENTS {
                let outcome = sql_query(stmt).execute(&mut conn).await;
                assert!(outcome.is_err());
            }

            // data changes must be accepted
            for stmt in DML_STATEMENTS {
                let outcome = sql_query(stmt).execute(&mut conn).await;
                assert!(outcome.is_ok());
            }
        };

        scenario.lock_read().await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        let scenario = async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // data changes must be accepted on a single mutable database
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let mut conn = conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    let outcome = sql_query(stmt).execute(&mut conn).await;
                    assert!(outcome.is_ok());
                }
            }

            // schema changes must be accepted, each on a freshly created mutable database
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let mut conn = conn_pool.get().await.unwrap();
                let outcome = sql_query(stmt).execute(&mut conn).await;
                assert!(outcome.is_ok());
            }
        };

        scenario.lock_read().await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        let scenario = async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // first round: pull the connection pools, verify emptiness, then dirty them
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // no rows may exist yet
                let empty_checks = conn_pools.iter().map(|conn_pool| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let row_count = book::table
                        .count()
                        .get_result::<i64>(&mut conn)
                        .await
                        .unwrap();
                    assert_eq!(row_count, 0);
                });
                join_all(empty_checks).await;

                // add a row to every database
                let inserts = conn_pools.iter().map(|conn_pool| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let new_book = NewBook {
                        title: Cow::Borrowed("Title"),
                    };
                    insert_into(book::table)
                        .values(new_book)
                        .execute(&mut conn)
                        .await
                        .unwrap();
                });
                join_all(inserts).await;
            }

            // second round: pull connection pools again
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // the rows written in the first round must be gone
                let empty_checks = conn_pools.iter().map(|conn_pool| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let row_count = book::table
                        .count()
                        .get_result::<i64>(&mut conn)
                        .await
                        .unwrap();
                    assert_eq!(row_count, 0);
                });
                join_all(empty_checks).await;
            }
        };

        scenario.lock_read().await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false).await).await;
    }
}