// db_pool/async/backend/mysql/sea_orm.rs

1use std::{borrow::Cow, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use sea_orm::{
6    ActiveModelBehavior, ColumnTrait, ConnectOptions, ConnectionTrait, Database,
7    DatabaseConnection, DbErr, DeriveEntityModel, DerivePrimaryKey, DeriveRelation, EntityTrait,
8    EnumIter, FromQueryResult, PrimaryKeyTrait, QueryFilter, QuerySelect, TransactionError,
9    TransactionTrait,
10};
11use uuid::Uuid;
12
13use crate::{
14    common::{config::PrivilegedMySQLConfig, statement::mysql},
15    util::get_db_name,
16};
17
18use super::{
19    super::{
20        common::{
21            conn::sea_orm::PooledConnection,
22            error::sea_orm::{BuildError, ConnectionError, PoolError, QueryError},
23        },
24        error::Error as BackendError,
25        r#trait::Backend,
26    },
27    r#trait::{MySQLBackend, MySQLBackendWrapper},
28};
29
30type CreateEntities = dyn Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
31    + Send
32    + Sync
33    + 'static;
34
35/// [`SeaORM MySQL`](https://docs.rs/sea-orm/1.0.1/sea_orm/type.DbBackend.html#variant.MySql) backend
36pub struct SeaORMMySQLBackend {
37    privileged_config: PrivilegedMySQLConfig,
38    default_pool: DatabaseConnection,
39    create_restricted_pool: Box<dyn for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static>,
40    create_entities: Box<CreateEntities>,
41    drop_previous_databases_flag: bool,
42}
43
44impl SeaORMMySQLBackend {
45    /// Creates a new [`SeaORM MySQL`](https://docs.rs/sea-orm/1.0.1/sea_orm/type.DbBackend.html#variant.MySql) backend
46    /// # Example
47    /// ```
48    /// use db_pool::{r#async::SeaORMMySQLBackend, PrivilegedMySQLConfig};
49    /// use dotenvy::dotenv;
50    /// use sea_orm::ConnectionTrait;
51    ///
52    /// async fn f() {
53    ///     dotenv().ok();
54    ///
55    ///     let config = PrivilegedMySQLConfig::from_env().unwrap();
56    ///
57    ///     let backend = SeaORMMySQLBackend::new(
58    ///             config,
59    ///             |opts| {
60    ///                 opts.max_connections(10);
61    ///             },
62    ///             |opts| {
63    ///                 opts.max_connections(2);
64    ///             },
65    ///             move |conn| {
66    ///                 Box::pin(async move {
67    ///                     conn.execute_unprepared(
68    ///                         "CREATE TABLE book(id INTEGER PRIMARY KEY AUTO_INCREMENT, title TEXT NOT NULL)",
69    ///                     )
70    ///                     .await
71    ///                     .unwrap();
72    ///                 })
73    ///             },
74    ///         )
75    ///         .await
76    ///         .unwrap();
77    /// }
78    ///
79    /// tokio_test::block_on(f());
80    /// ```
81    pub async fn new(
82        privileged_config: PrivilegedMySQLConfig,
83        create_privileged_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions),
84        create_restricted_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static,
85        create_entities: impl Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
86            + Send
87            + Sync
88            + 'static,
89    ) -> Result<Self, DbErr> {
90        let mut opts = ConnectOptions::new(privileged_config.default_connection_url());
91        create_privileged_pool(&mut opts);
92        let default_pool = Database::connect(opts).await?;
93
94        Ok(Self {
95            privileged_config,
96            default_pool,
97            create_restricted_pool: Box::new(create_restricted_pool),
98            create_entities: Box::new(create_entities),
99            drop_previous_databases_flag: true,
100        })
101    }
102
103    /// Drop databases created in previous runs upon initialization
104    #[must_use]
105    pub fn drop_previous_databases(self, value: bool) -> Self {
106        Self {
107            drop_previous_databases_flag: value,
108            ..self
109        }
110    }
111}
112
113#[async_trait]
114impl<'pool> MySQLBackend<'pool> for SeaORMMySQLBackend {
115    type Connection = DatabaseConnection;
116    type PooledConnection = PooledConnection;
117    type Pool = DatabaseConnection;
118
119    type BuildError = BuildError;
120    type PoolError = PoolError;
121    type ConnectionError = ConnectionError;
122    type QueryError = QueryError;
123
124    async fn get_connection(&'pool self) -> Result<PooledConnection, PoolError> {
125        Ok(self.default_pool.clone().into())
126    }
127
128    async fn execute_query(
129        &self,
130        query: &str,
131        conn: &mut DatabaseConnection,
132    ) -> Result<(), QueryError> {
133        conn.execute_unprepared(query).await?;
134        Ok(())
135    }
136
137    async fn batch_execute_query<'a>(
138        &self,
139        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
140        conn: &mut DatabaseConnection,
141    ) -> Result<(), QueryError> {
142        let chunks = query.into_iter().collect::<Vec<_>>();
143        if chunks.is_empty() {
144            Ok(())
145        } else {
146            let query = chunks.join(";");
147            self.execute_query(query.as_str(), conn).await
148        }
149    }
150
151    fn get_host(&self) -> &str {
152        self.privileged_config.host.as_str()
153    }
154
155    async fn get_previous_database_names(
156        &self,
157        conn: &mut DatabaseConnection,
158    ) -> Result<Vec<String>, QueryError> {
159        #[derive(Clone, Debug, DeriveEntityModel)]
160        #[sea_orm(table_name = "schemata")]
161        pub struct Model {
162            #[sea_orm(primary_key)]
163            schema_name: String,
164        }
165
166        #[derive(Debug, EnumIter, DeriveRelation)]
167        pub enum Relation {}
168
169        impl ActiveModelBehavior for ActiveModel {}
170
171        conn.transaction(move |txn| {
172            Box::pin(async move {
173                txn.execute_unprepared(mysql::USE_DEFAULT_DATABASE).await?;
174
175                Entity::find()
176                    .filter(Column::SchemaName.like("db_pool_%"))
177                    .all(txn)
178                    .await
179            })
180        })
181        .await
182        .map(|mut models| models.drain(..).map(|model| model.schema_name).collect())
183        .map_err(|err| match err {
184            TransactionError::Connection(err) | TransactionError::Transaction(err) => err.into(),
185        })
186    }
187
188    async fn create_entities(&self, db_name: &str) -> Result<(), ConnectionError> {
189        let database_url = self
190            .privileged_config
191            .privileged_database_connection_url(db_name);
192        let conn = Database::connect(database_url).await?;
193        (self.create_entities)(conn).await;
194        Ok(())
195    }
196
197    async fn create_connection_pool(&self, db_id: Uuid) -> Result<DatabaseConnection, BuildError> {
198        let db_name = get_db_name(db_id);
199        let db_name = db_name.as_str();
200        let database_url = self.privileged_config.restricted_database_connection_url(
201            db_name,
202            Some(db_name),
203            db_name,
204        );
205        let mut opts = ConnectOptions::new(database_url);
206        (self.create_restricted_pool)(&mut opts);
207        Database::connect(opts).await.map_err(Into::into)
208    }
209
210    // TODO: improve error in trait to include both query and connection errors
211    async fn get_table_names(
212        &self,
213        db_name: &str,
214        conn: &mut DatabaseConnection,
215    ) -> Result<Vec<String>, QueryError> {
216        #[derive(Clone, Debug, DeriveEntityModel)]
217        #[sea_orm(table_name = "tables")]
218        pub struct Model {
219            #[sea_orm(primary_key)]
220            table_name: String,
221            table_schema: String,
222        }
223
224        #[derive(Debug, EnumIter, DeriveRelation)]
225        pub enum Relation {}
226
227        impl ActiveModelBehavior for ActiveModel {}
228
229        #[derive(FromQueryResult)]
230        struct QueryModel {
231            table_name: String,
232        }
233
234        conn.transaction(move |txn| {
235            let db_name = db_name.to_owned();
236            Box::pin(async move {
237                txn.execute_unprepared(mysql::USE_DEFAULT_DATABASE).await?;
238
239                Entity::find()
240                    .select_only()
241                    .column(Column::TableName)
242                    .filter(Column::TableSchema.eq(db_name))
243                    .into_model::<QueryModel>()
244                    .all(txn)
245                    .await
246            })
247        })
248        .await
249        .map(|mut models| models.drain(..).map(|model| model.table_name).collect())
250        .map_err(|err| match err {
251            TransactionError::Connection(err) | TransactionError::Transaction(err) => err.into(),
252        })
253    }
254
255    fn get_drop_previous_databases(&self) -> bool {
256        self.drop_previous_databases_flag
257    }
258}
259
260type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
261
262#[async_trait]
263impl Backend for SeaORMMySQLBackend {
264    type Pool = DatabaseConnection;
265
266    type BuildError = BuildError;
267    type PoolError = PoolError;
268    type ConnectionError = ConnectionError;
269    type QueryError = QueryError;
270
271    async fn init(&self) -> Result<(), BError> {
272        MySQLBackendWrapper::new(self).init().await
273    }
274
275    async fn create(
276        &self,
277        db_id: uuid::Uuid,
278        restrict_privileges: bool,
279    ) -> Result<DatabaseConnection, BError> {
280        MySQLBackendWrapper::new(self)
281            .create(db_id, restrict_privileges)
282            .await
283    }
284
285    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
286        MySQLBackendWrapper::new(self).clean(db_id).await
287    }
288
289    async fn drop(&self, db_id: uuid::Uuid, _is_restricted: bool) -> Result<(), BError> {
290        MySQLBackendWrapper::new(self).drop(db_id).await
291    }
292}
293
#[cfg(test)]
mod tests {
    // `unwrap` keeps failures loud in tests; `needless_return` is presumably triggered by
    // code generated by the `tokio_shared_rt::test` macro.
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use futures::future::join_all;
    use sea_orm::{
        ActiveModelBehavior, ActiveModelTrait, ConnectionTrait, DeriveEntityModel,
        DerivePrimaryKey, DeriveRelation, EntityTrait, EnumIter, FromQueryResult, PaginatorTrait,
        PrimaryKeyTrait, QuerySelect, Set,
    };
    use tokio_shared_rt::test;

    use crate::{
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::mysql::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_pool_drops_created_restricted_databases,
                test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_previous_databases,
            MySQLDropLock,
        },
        SeaORMMySQLBackend,
    };

    // Entity for the `book` table, expected to match the table created by
    // `CREATE_ENTITIES_STATEMENTS`.
    #[derive(Clone, Debug, DeriveEntityModel)]
    #[sea_orm(table_name = "book")]
    pub struct Model {
        #[sea_orm(primary_key)]
        id: i32,
        title: String,
    }

    #[derive(Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    impl ActiveModelBehavior for ActiveModel {}

    /// Builds a backend with default pool options. Its entity hook runs
    /// `CREATE_ENTITIES_STATEMENTS` when `with_table` is set and does nothing otherwise.
    async fn create_backend(with_table: bool) -> SeaORMMySQLBackend {
        SeaORMMySQLBackend::new(
            get_privileged_mysql_config().clone(),
            |_| {},
            |_| {},
            move |conn| {
                if with_table {
                    Box::pin(async move {
                        let statements = CREATE_ENTITIES_STATEMENTS.join(";");
                        conn.execute_unprepared(statements.as_str()).await.unwrap();
                    })
                } else {
                    Box::pin(async {})
                }
            },
        )
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            true,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            false,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_pool_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromQueryResult, Eq, PartialEq, Debug)]
        struct QueryModel {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();
            let connections = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

            // every database receives exactly one row, tagged with its index
            join_all(connections.iter().enumerate().map(|(index, connection)| async move {
                ActiveModel {
                    title: Set(format!("Title {index}")),
                    ..Default::default()
                }
                .insert(&***connection)
                .await
                .unwrap();
            }))
            .await;

            // reading back must yield only the row inserted into that database
            join_all(connections.iter().enumerate().map(|(index, connection)| async move {
                let rows = Entity::find()
                    .select_only()
                    .column(Column::Title)
                    .into_model::<QueryModel>()
                    .all(&***connection)
                    .await
                    .unwrap();
                assert_eq!(
                    rows,
                    vec![QueryModel {
                        title: format!("Title {index}")
                    }]
                );
            }))
            .await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();
            let connection = pool.pull_immutable().await;

            // DDL statements must be rejected
            for statement in DDL_STATEMENTS {
                assert!(connection.execute_unprepared(statement).await.is_err());
            }

            // DML statements must be accepted
            for statement in DML_STATEMENTS {
                assert!(connection.execute_unprepared(statement).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();

            // DML statements must be accepted on one unrestricted connection
            {
                let connection = pool.create_mutable().await.unwrap();
                for statement in DML_STATEMENTS {
                    assert!(connection.execute_unprepared(statement).await.is_ok());
                }
            }

            // DDL statements must be accepted, each on a fresh unrestricted connection
            for statement in DDL_STATEMENTS {
                let connection = pool.create_mutable().await.unwrap();
                assert!(connection.execute_unprepared(statement).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();

            // first round: databases start empty, then each one receives a row
            {
                let connections = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

                join_all(connections.iter().map(|connection| async move {
                    assert_eq!(Entity::find().count(&***connection).await.unwrap(), 0);
                }))
                .await;

                join_all(connections.iter().map(|connection| async move {
                    ActiveModel {
                        title: Set("Title".to_owned()),
                        ..Default::default()
                    }
                    .insert(&***connection)
                    .await
                    .unwrap();
                }))
                .await;
            }

            // second round: the same connection pools are fetched again and must be empty
            {
                let connections = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

                join_all(connections.iter().map(|connection| async move {
                    assert_eq!(Entity::find().count(&***connection).await.unwrap(), 0);
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false).await).await;
    }
}