// db_pool/async/backend/postgres/sea_orm.rs

1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use sea_orm::{
7    ActiveModelBehavior, ColumnTrait, ConnectOptions, ConnectionTrait, Database,
8    DatabaseConnection, DbErr, DeriveEntityModel, DerivePrimaryKey, DeriveRelation, EntityTrait,
9    EnumIter, FromQueryResult, PrimaryKeyTrait, QueryFilter, QuerySelect,
10};
11use uuid::Uuid;
12
13use crate::{common::config::PrivilegedPostgresConfig, util::get_db_name};
14
15use super::{
16    super::{
17        common::{
18            conn::sea_orm::PooledConnection,
19            error::sea_orm::{BuildError, ConnectionError, PoolError, QueryError},
20        },
21        error::Error as BackendError,
22        r#trait::Backend,
23    },
24    r#trait::{PostgresBackend, PostgresBackendWrapper},
25};
26
27type CreateEntities = dyn Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
28    + Send
29    + Sync
30    + 'static;
31
32/// [`SeaORM Postgres`](https://docs.rs/sea-orm/1.1.12/sea_orm/type.DbBackend.html#variant.Postgres) backend
33pub struct SeaORMPostgresBackend {
34    privileged_config: PrivilegedPostgresConfig,
35    default_pool: DatabaseConnection,
36    db_conns: Mutex<HashMap<Uuid, DatabaseConnection>>,
37    create_restricted_pool: Box<dyn for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static>,
38    create_entities: Box<CreateEntities>,
39    drop_previous_databases_flag: bool,
40}
41
42impl SeaORMPostgresBackend {
43    /// Creates a new [`SeaORM Postgres`](https://docs.rs/sea-orm/1.1.12/sea_orm/type.DbBackend.html#variant.Postgres) backend
44    /// # Example
45    /// ```
46    /// use bb8::Pool;
47    /// use db_pool::{r#async::SeaORMPostgresBackend, PrivilegedPostgresConfig};
48    /// use diesel::sql_query;
49    /// use diesel_async::RunQueryDsl;
50    /// use dotenvy::dotenv;
51    /// use sea_orm::ConnectionTrait;
52    ///
53    /// async fn f() {
54    ///     dotenv().ok();
55    ///
56    ///     let config = PrivilegedPostgresConfig::from_env().unwrap();
57    ///
58    ///     let backend = SeaORMPostgresBackend::new(
59    ///         config,
60    ///         |opts| {
61    ///             opts.max_connections(10);
62    ///         },
63    ///         |opts| {
64    ///             opts.max_connections(2);
65    ///         },
66    ///         move |conn| {
67    ///             Box::pin(async move {
68    ///                 conn.execute_unprepared(
69    ///                     "CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)",
70    ///                 )
71    ///                 .await
72    ///                 .unwrap();
73    ///             })
74    ///         },
75    ///     )
76    ///     .await
77    ///     .unwrap();
78    /// }
79    ///
80    /// tokio_test::block_on(f());
81    /// ```
82    pub async fn new(
83        privileged_config: PrivilegedPostgresConfig,
84        create_privileged_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions),
85        create_restricted_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static,
86        create_entities: impl Fn(
87            DatabaseConnection,
88        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
89        + Send
90        + Sync
91        + 'static,
92    ) -> Result<Self, DbErr> {
93        let mut opts = ConnectOptions::new(privileged_config.default_connection_url());
94        create_privileged_pool(&mut opts);
95        let default_pool = Database::connect(opts).await?;
96
97        Ok(Self {
98            privileged_config,
99            default_pool,
100            db_conns: Mutex::new(HashMap::new()),
101            create_restricted_pool: Box::new(create_restricted_pool),
102            create_entities: Box::new(create_entities),
103            drop_previous_databases_flag: true,
104        })
105    }
106
107    /// Drop databases created in previous runs upon initialization
108    #[must_use]
109    pub fn drop_previous_databases(self, value: bool) -> Self {
110        Self {
111            drop_previous_databases_flag: value,
112            ..self
113        }
114    }
115}
116
117#[async_trait]
118impl<'pool> PostgresBackend<'pool> for SeaORMPostgresBackend {
119    type Connection = DatabaseConnection;
120    type PooledConnection = PooledConnection;
121    type Pool = DatabaseConnection;
122
123    type BuildError = BuildError;
124    type PoolError = PoolError;
125    type ConnectionError = ConnectionError;
126    type QueryError = QueryError;
127
128    async fn execute_query(
129        &self,
130        query: &str,
131        conn: &mut DatabaseConnection,
132    ) -> Result<(), QueryError> {
133        conn.execute_unprepared(query).await?;
134        Ok(())
135    }
136
137    async fn batch_execute_query<'a>(
138        &self,
139        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
140        conn: &mut DatabaseConnection,
141    ) -> Result<(), QueryError> {
142        let query = query.into_iter().collect::<Vec<_>>().join(";");
143        self.execute_query(query.as_str(), conn).await
144    }
145
146    async fn get_default_connection(&'pool self) -> Result<PooledConnection, PoolError> {
147        Ok(self.default_pool.clone().into())
148    }
149
150    async fn establish_privileged_database_connection(
151        &self,
152        db_id: Uuid,
153    ) -> Result<DatabaseConnection, ConnectionError> {
154        let db_name = get_db_name(db_id);
155        let database_url = self
156            .privileged_config
157            .privileged_database_connection_url(db_name.as_str());
158        let opts = ConnectOptions::new(database_url);
159        Database::connect(opts).await.map_err(Into::into)
160    }
161
162    async fn establish_restricted_database_connection(
163        &self,
164        db_id: Uuid,
165    ) -> Result<DatabaseConnection, ConnectionError> {
166        let db_name = get_db_name(db_id);
167        let db_name = db_name.as_str();
168        let database_url = self.privileged_config.restricted_database_connection_url(
169            db_name,
170            Some(db_name),
171            db_name,
172        );
173        let opts = ConnectOptions::new(database_url);
174        Database::connect(opts).await.map_err(Into::into)
175    }
176
177    fn put_database_connection(&self, db_id: Uuid, conn: DatabaseConnection) {
178        self.db_conns.lock().insert(db_id, conn);
179    }
180
181    fn get_database_connection(&self, db_id: Uuid) -> DatabaseConnection {
182        self.db_conns
183            .lock()
184            .remove(&db_id)
185            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
186    }
187
188    async fn get_previous_database_names(
189        &self,
190        conn: &mut DatabaseConnection,
191    ) -> Result<Vec<String>, QueryError> {
192        #[derive(Clone, Debug, DeriveEntityModel)]
193        #[sea_orm(table_name = "pg_database")]
194        pub struct Model {
195            #[sea_orm(primary_key)]
196            oid: i32,
197            datname: String,
198        }
199
200        #[derive(Debug, EnumIter, DeriveRelation)]
201        pub enum Relation {}
202
203        impl ActiveModelBehavior for ActiveModel {}
204
205        #[derive(FromQueryResult)]
206        struct QueryModel {
207            datname: String,
208        }
209
210        Entity::find()
211            .select_only()
212            .column(Column::Datname)
213            .filter(Column::Datname.like("db_pool_%"))
214            .into_model::<QueryModel>()
215            .all(conn)
216            .await
217            .map(|mut models| models.drain(..).map(|model| model.datname).collect())
218            .map_err(Into::into)
219    }
220
221    async fn create_entities(&self, conn: DatabaseConnection) -> Option<DatabaseConnection> {
222        (self.create_entities)(conn.clone()).await;
223        Some(conn)
224    }
225
226    async fn create_connection_pool(&self, db_id: Uuid) -> Result<DatabaseConnection, BuildError> {
227        let db_name = get_db_name(db_id);
228        let db_name = db_name.as_str();
229        let database_url = self.privileged_config.restricted_database_connection_url(
230            db_name,
231            Some(db_name),
232            db_name,
233        );
234        let mut opts = ConnectOptions::new(database_url);
235        (self.create_restricted_pool)(&mut opts);
236        Database::connect(opts).await.map_err(Into::into)
237    }
238
239    async fn get_table_names(
240        &self,
241        conn: &mut DatabaseConnection,
242    ) -> Result<Vec<String>, QueryError> {
243        #[derive(Clone, Debug, DeriveEntityModel)]
244        #[sea_orm(table_name = "pg_tables")]
245        pub struct Model {
246            schemaname: String,
247            #[sea_orm(primary_key)]
248            tablename: String,
249        }
250
251        #[derive(Debug, EnumIter, DeriveRelation)]
252        pub enum Relation {}
253
254        impl ActiveModelBehavior for ActiveModel {}
255
256        #[derive(FromQueryResult)]
257        struct QueryModel {
258            tablename: String,
259        }
260
261        Entity::find()
262            .select_only()
263            .column(Column::Tablename)
264            .filter(Column::Schemaname.is_not_in(["pg_catalog", "information_schema"]))
265            .into_model::<QueryModel>()
266            .all(conn)
267            .await
268            .map(|mut models| models.drain(..).map(|model| model.tablename).collect())
269            .map_err(Into::into)
270    }
271
272    fn get_drop_previous_databases(&self) -> bool {
273        self.drop_previous_databases_flag
274    }
275}
276
277type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
278
279#[async_trait]
280impl Backend for SeaORMPostgresBackend {
281    type Pool = DatabaseConnection;
282
283    type BuildError = BuildError;
284    type PoolError = PoolError;
285    type ConnectionError = ConnectionError;
286    type QueryError = QueryError;
287
288    async fn init(&self) -> Result<(), BError> {
289        PostgresBackendWrapper::new(self).init().await
290    }
291
292    async fn create(
293        &self,
294        db_id: uuid::Uuid,
295        restrict_privileges: bool,
296    ) -> Result<DatabaseConnection, BError> {
297        PostgresBackendWrapper::new(self)
298            .create(db_id, restrict_privileges)
299            .await
300    }
301
302    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
303        PostgresBackendWrapper::new(self).clean(db_id).await
304    }
305
306    async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> {
307        PostgresBackendWrapper::new(self)
308            .drop(db_id, is_restricted)
309            .await
310    }
311}
312
#[cfg(test)]
mod tests {
    // unwrap is acceptable in tests; needless_return is presumably silenced for the test
    // macro's expansion — confirm
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use dotenvy::dotenv;
    use futures::future::join_all;
    use sea_orm::{
        ActiveModelBehavior, ActiveModelTrait, ConnectionTrait, DeriveEntityModel,
        DerivePrimaryKey, DeriveRelation, EntityTrait, EnumIter, FromQueryResult, PaginatorTrait,
        PrimaryKeyTrait, QuerySelect, Set,
    };
    use tokio_shared_rt::test;

    use crate::{
        r#async::{
            backend::postgres::r#trait::tests::{
                test_backend_drops_database, test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        common::{
            config::PrivilegedPostgresConfig,
            statement::postgres::tests::{
                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
            },
        },
    };

    use super::{
        super::r#trait::tests::{
            PgDropLock, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases,
        },
        SeaORMPostgresBackend,
    };

    // Entity for the `book` table; presumably created by `CREATE_ENTITIES_STATEMENTS` —
    // confirm against the statement list
    #[derive(Clone, Debug, DeriveEntityModel)]
    #[sea_orm(table_name = "book")]
    pub struct Model {
        #[sea_orm(primary_key)]
        id: i32,
        title: String,
    }

    #[derive(Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    impl ActiveModelBehavior for ActiveModel {}

    /// Builds a backend from the environment's privileged config with no pool customization.
    /// When `with_table` is set, `CREATE_ENTITIES_STATEMENTS` run on every database handed to
    /// the entity-creation callback.
    async fn create_backend(with_table: bool) -> SeaORMPostgresBackend {
        dotenv().ok();

        let privileged_config = PrivilegedPostgresConfig::from_env().unwrap();

        SeaORMPostgresBackend::new(
            privileged_config,
            |_| {},
            |_| {},
            move |conn| {
                if !with_table {
                    return Box::pin(async {});
                }

                Box::pin(async move {
                    let statements = CREATE_ENTITIES_STATEMENTS.join(";");
                    conn.execute_unprepared(statements.as_str()).await.unwrap();
                })
            },
        )
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);

        test_backend_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            true,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            false,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);

        test_pool_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromQueryResult, Eq, PartialEq, Debug)]
        struct QueryModel {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();
            let checked_out = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

            // every database receives exactly one row whose title carries its index
            join_all(checked_out.iter().enumerate().map(|(idx, conn)| async move {
                ActiveModel {
                    title: Set(format!("Title {idx}")),
                    ..Default::default()
                }
                .insert(&***conn)
                .await
                .unwrap();
            }))
            .await;

            // reading back must yield only the row written to that same database
            join_all(checked_out.iter().enumerate().map(|(idx, conn)| async move {
                let fetched = Entity::find()
                    .select_only()
                    .column(Column::Title)
                    .into_model::<QueryModel>()
                    .all(&***conn)
                    .await
                    .unwrap();
                let expected = vec![QueryModel {
                    title: format!("Title {idx}"),
                }];
                assert_eq!(fetched, expected);
            }))
            .await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();
            let conn = pool.pull_immutable().await;

            // schema changes are rejected on a restricted database ...
            for stmt in DDL_STATEMENTS {
                let outcome = conn.execute_unprepared(stmt).await;
                assert!(outcome.is_err());
            }

            // ... while data manipulation is allowed
            for stmt in DML_STATEMENTS {
                let outcome = conn.execute_unprepared(stmt).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();

            // data manipulation succeeds, all statements on one mutable database
            {
                let conn = pool.create_mutable().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute_unprepared(stmt).await.is_ok());
                }
            }

            // schema changes succeed, each statement on its own mutable database
            for stmt in DDL_STATEMENTS {
                let conn = pool.create_mutable().await.unwrap();
                assert!(conn.execute_unprepared(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();

            // first checkout: databases start empty, then each receives one row
            {
                let checked_out = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

                join_all(checked_out.iter().map(|conn| async move {
                    assert_eq!(Entity::find().count(&***conn).await.unwrap(), 0);
                }))
                .await;

                join_all(checked_out.iter().map(|conn| async move {
                    ActiveModel {
                        title: Set("Title".to_owned()),
                        ..Default::default()
                    }
                    .insert(&***conn)
                    .await
                    .unwrap();
                }))
                .await;
            }

            // second checkout of the same connection pools: the inserted rows must be gone
            {
                let checked_out = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

                join_all(checked_out.iter().map(|conn| async move {
                    assert_eq!(Entity::find().count(&***conn).await.unwrap(), 0);
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false).await).await;
    }
}