// db_pool/async/backend/postgres/sea_orm.rs

1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use sea_orm::{
7    ActiveModelBehavior, ColumnTrait, ConnectOptions, ConnectionTrait, Database,
8    DatabaseConnection, DbErr, DeriveEntityModel, DerivePrimaryKey, DeriveRelation, EntityTrait,
9    EnumIter, FromQueryResult, PrimaryKeyTrait, QueryFilter, QuerySelect,
10};
11use uuid::Uuid;
12
13use crate::{common::config::PrivilegedPostgresConfig, util::get_db_name};
14
15use super::{
16    super::{
17        common::{
18            conn::sea_orm::PooledConnection,
19            error::sea_orm::{BuildError, ConnectionError, PoolError, QueryError},
20        },
21        error::Error as BackendError,
22        r#trait::Backend,
23    },
24    r#trait::{PostgresBackend, PostgresBackendWrapper},
25};
26
27type CreateEntities = dyn Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
28    + Send
29    + Sync
30    + 'static;
31
32/// [`SeaORM Postgres`](https://docs.rs/sea-orm/1.0.1/sea_orm/type.DbBackend.html#variant.Postgres) backend
33pub struct SeaORMPostgresBackend {
34    privileged_config: PrivilegedPostgresConfig,
35    default_pool: DatabaseConnection,
36    db_conns: Mutex<HashMap<Uuid, DatabaseConnection>>,
37    create_restricted_pool: Box<dyn for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static>,
38    create_entities: Box<CreateEntities>,
39    drop_previous_databases_flag: bool,
40}
41
42impl SeaORMPostgresBackend {
43    /// Creates a new [`SeaORM Postgres`](https://docs.rs/sea-orm/1.0.1/sea_orm/type.DbBackend.html#variant.Postgres) backend
44    /// # Example
45    /// ```
46    /// use bb8::Pool;
47    /// use db_pool::{r#async::SeaORMPostgresBackend, PrivilegedPostgresConfig};
48    /// use diesel::sql_query;
49    /// use diesel_async::RunQueryDsl;
50    /// use dotenvy::dotenv;
51    /// use sea_orm::ConnectionTrait;
52    ///
53    /// async fn f() {
54    ///     dotenv().ok();
55    ///
56    ///     let config = PrivilegedPostgresConfig::from_env().unwrap();
57    ///
58    ///     let backend = SeaORMPostgresBackend::new(
59    ///         config,
60    ///         |opts| {
61    ///             opts.max_connections(10);
62    ///         },
63    ///         |opts| {
64    ///             opts.max_connections(2);
65    ///         },
66    ///         move |conn| {
67    ///             Box::pin(async move {
68    ///                 conn.execute_unprepared(
69    ///                     "CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)",
70    ///                 )
71    ///                 .await
72    ///                 .unwrap();
73    ///             })
74    ///         },
75    ///     )
76    ///     .await
77    ///     .unwrap();
78    /// }
79    ///
80    /// tokio_test::block_on(f());
81    /// ```
82    pub async fn new(
83        privileged_config: PrivilegedPostgresConfig,
84        create_privileged_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions),
85        create_restricted_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static,
86        create_entities: impl Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
87            + Send
88            + Sync
89            + 'static,
90    ) -> Result<Self, DbErr> {
91        let mut opts = ConnectOptions::new(privileged_config.default_connection_url());
92        create_privileged_pool(&mut opts);
93        let default_pool = Database::connect(opts).await?;
94
95        Ok(Self {
96            privileged_config,
97            default_pool,
98            db_conns: Mutex::new(HashMap::new()),
99            create_restricted_pool: Box::new(create_restricted_pool),
100            create_entities: Box::new(create_entities),
101            drop_previous_databases_flag: true,
102        })
103    }
104
105    /// Drop databases created in previous runs upon initialization
106    #[must_use]
107    pub fn drop_previous_databases(self, value: bool) -> Self {
108        Self {
109            drop_previous_databases_flag: value,
110            ..self
111        }
112    }
113}
114
115#[async_trait]
116impl<'pool> PostgresBackend<'pool> for SeaORMPostgresBackend {
117    type Connection = DatabaseConnection;
118    type PooledConnection = PooledConnection;
119    type Pool = DatabaseConnection;
120
121    type BuildError = BuildError;
122    type PoolError = PoolError;
123    type ConnectionError = ConnectionError;
124    type QueryError = QueryError;
125
126    async fn execute_query(
127        &self,
128        query: &str,
129        conn: &mut DatabaseConnection,
130    ) -> Result<(), QueryError> {
131        conn.execute_unprepared(query).await?;
132        Ok(())
133    }
134
135    async fn batch_execute_query<'a>(
136        &self,
137        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
138        conn: &mut DatabaseConnection,
139    ) -> Result<(), QueryError> {
140        let query = query.into_iter().collect::<Vec<_>>().join(";");
141        self.execute_query(query.as_str(), conn).await
142    }
143
144    async fn get_default_connection(&'pool self) -> Result<PooledConnection, PoolError> {
145        Ok(self.default_pool.clone().into())
146    }
147
148    async fn establish_privileged_database_connection(
149        &self,
150        db_id: Uuid,
151    ) -> Result<DatabaseConnection, ConnectionError> {
152        let db_name = get_db_name(db_id);
153        let database_url = self
154            .privileged_config
155            .privileged_database_connection_url(db_name.as_str());
156        let opts = ConnectOptions::new(database_url);
157        Database::connect(opts).await.map_err(Into::into)
158    }
159
160    async fn establish_restricted_database_connection(
161        &self,
162        db_id: Uuid,
163    ) -> Result<DatabaseConnection, ConnectionError> {
164        let db_name = get_db_name(db_id);
165        let db_name = db_name.as_str();
166        let database_url = self.privileged_config.restricted_database_connection_url(
167            db_name,
168            Some(db_name),
169            db_name,
170        );
171        let opts = ConnectOptions::new(database_url);
172        Database::connect(opts).await.map_err(Into::into)
173    }
174
175    fn put_database_connection(&self, db_id: Uuid, conn: DatabaseConnection) {
176        self.db_conns.lock().insert(db_id, conn);
177    }
178
179    fn get_database_connection(&self, db_id: Uuid) -> DatabaseConnection {
180        self.db_conns
181            .lock()
182            .remove(&db_id)
183            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
184    }
185
186    async fn get_previous_database_names(
187        &self,
188        conn: &mut DatabaseConnection,
189    ) -> Result<Vec<String>, QueryError> {
190        #[derive(Clone, Debug, DeriveEntityModel)]
191        #[sea_orm(table_name = "pg_database")]
192        pub struct Model {
193            #[sea_orm(primary_key)]
194            oid: i32,
195            datname: String,
196        }
197
198        #[derive(Debug, EnumIter, DeriveRelation)]
199        pub enum Relation {}
200
201        impl ActiveModelBehavior for ActiveModel {}
202
203        #[derive(FromQueryResult)]
204        struct QueryModel {
205            datname: String,
206        }
207
208        Entity::find()
209            .select_only()
210            .column(Column::Datname)
211            .filter(Column::Datname.like("db_pool_%"))
212            .into_model::<QueryModel>()
213            .all(conn)
214            .await
215            .map(|mut models| models.drain(..).map(|model| model.datname).collect())
216            .map_err(Into::into)
217    }
218
219    async fn create_entities(&self, conn: DatabaseConnection) -> DatabaseConnection {
220        (self.create_entities)(conn.clone()).await;
221        conn
222    }
223
224    async fn create_connection_pool(&self, db_id: Uuid) -> Result<DatabaseConnection, BuildError> {
225        let db_name = get_db_name(db_id);
226        let db_name = db_name.as_str();
227        let database_url = self.privileged_config.restricted_database_connection_url(
228            db_name,
229            Some(db_name),
230            db_name,
231        );
232        let mut opts = ConnectOptions::new(database_url);
233        (self.create_restricted_pool)(&mut opts);
234        Database::connect(opts).await.map_err(Into::into)
235    }
236
237    async fn get_table_names(
238        &self,
239        conn: &mut DatabaseConnection,
240    ) -> Result<Vec<String>, QueryError> {
241        #[derive(Clone, Debug, DeriveEntityModel)]
242        #[sea_orm(table_name = "pg_tables")]
243        pub struct Model {
244            schemaname: String,
245            #[sea_orm(primary_key)]
246            tablename: String,
247        }
248
249        #[derive(Debug, EnumIter, DeriveRelation)]
250        pub enum Relation {}
251
252        impl ActiveModelBehavior for ActiveModel {}
253
254        #[derive(FromQueryResult)]
255        struct QueryModel {
256            tablename: String,
257        }
258
259        Entity::find()
260            .select_only()
261            .column(Column::Tablename)
262            .filter(Column::Schemaname.is_not_in(["pg_catalog", "information_schema"]))
263            .into_model::<QueryModel>()
264            .all(conn)
265            .await
266            .map(|mut models| models.drain(..).map(|model| model.tablename).collect())
267            .map_err(Into::into)
268    }
269
270    fn get_drop_previous_databases(&self) -> bool {
271        self.drop_previous_databases_flag
272    }
273}
274
275type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
276
277#[async_trait]
278impl Backend for SeaORMPostgresBackend {
279    type Pool = DatabaseConnection;
280
281    type BuildError = BuildError;
282    type PoolError = PoolError;
283    type ConnectionError = ConnectionError;
284    type QueryError = QueryError;
285
286    async fn init(&self) -> Result<(), BError> {
287        PostgresBackendWrapper::new(self).init().await
288    }
289
290    async fn create(
291        &self,
292        db_id: uuid::Uuid,
293        restrict_privileges: bool,
294    ) -> Result<DatabaseConnection, BError> {
295        PostgresBackendWrapper::new(self)
296            .create(db_id, restrict_privileges)
297            .await
298    }
299
300    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
301        PostgresBackendWrapper::new(self).clean(db_id).await
302    }
303
304    async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> {
305        PostgresBackendWrapper::new(self)
306            .drop(db_id, is_restricted)
307            .await
308    }
309}
310
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use dotenvy::dotenv;
    use futures::future::join_all;
    use sea_orm::{
        ActiveModelBehavior, ActiveModelTrait, ConnectionTrait, DeriveEntityModel,
        DerivePrimaryKey, DeriveRelation, EntityTrait, EnumIter, FromQueryResult, PaginatorTrait,
        PrimaryKeyTrait, QuerySelect, Set,
    };
    use tokio_shared_rt::test;

    use crate::{
        common::{
            config::PrivilegedPostgresConfig,
            statement::postgres::tests::{
                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
            },
        },
        r#async::{
            backend::postgres::r#trait::tests::{
                test_backend_drops_database, test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases, PgDropLock,
        },
        SeaORMPostgresBackend,
    };

    /// Number of databases pulled concurrently by the multi-database pool tests.
    const NUM_DBS: i64 = 3;

    #[derive(Clone, Debug, DeriveEntityModel)]
    #[sea_orm(table_name = "book")]
    pub struct Model {
        #[sea_orm(primary_key)]
        id: i32,
        title: String,
    }

    #[derive(Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    impl ActiveModelBehavior for ActiveModel {}

    /// Builds an unsaved `book` row with the given title; `id` is left unset.
    fn new_book(title: String) -> ActiveModel {
        ActiveModel {
            title: Set(title),
            ..Default::default()
        }
    }

    /// Creates a backend from the environment; when `with_table` is set, every new database
    /// receives the tables from `CREATE_ENTITIES_STATEMENTS`.
    async fn create_backend(with_table: bool) -> SeaORMPostgresBackend {
        dotenv().ok();
        let config = PrivilegedPostgresConfig::from_env().unwrap();

        SeaORMPostgresBackend::new(config, |_| {}, |_| {}, move |conn| {
            Box::pin(async move {
                if with_table {
                    let ddl = CREATE_ENTITIES_STATEMENTS.join(";");
                    conn.execute_unprepared(ddl.as_str()).await.unwrap();
                }
            })
        })
        .await
        .unwrap()
    }

    /// Like `create_backend`, but the backend leaves databases of previous runs in place.
    async fn create_backend_keeping_previous(with_table: bool) -> SeaORMPostgresBackend {
        create_backend(with_table)
            .await
            .drop_previous_databases(false)
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend_keeping_previous(true).await,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend_keeping_previous(true).await,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(create_backend_keeping_previous(true).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(create_backend_keeping_previous(false).await)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        test_backend_drops_database(create_backend_keeping_previous(true).await, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        test_backend_drops_database(create_backend_keeping_previous(true).await, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_pool_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromQueryResult, Eq, PartialEq, Debug)]
        struct QueryModel {
            title: String,
        }

        let backend = create_backend_keeping_previous(true).await;

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let pooled = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // write one row per database, titled after the database's position
            join_all(pooled.iter().enumerate().map(|(idx, db)| async move {
                new_book(format!("Title {idx}")).insert(&***db).await.unwrap();
            }))
            .await;

            // each database must hold exactly the row written to it
            join_all(pooled.iter().enumerate().map(|(idx, db)| async move {
                let fetched = Entity::find()
                    .select_only()
                    .column(Column::Title)
                    .into_model::<QueryModel>()
                    .all(&***db)
                    .await
                    .unwrap();
                let expected = vec![QueryModel {
                    title: format!("Title {idx}"),
                }];
                assert_eq!(fetched, expected);
            }))
            .await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend_keeping_previous(true).await;

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let db = db_pool.pull_immutable().await;

            // schema changes are rejected
            for sql in DDL_STATEMENTS {
                assert!(db.execute_unprepared(sql).await.is_err());
            }

            // data changes go through
            for sql in DML_STATEMENTS {
                assert!(db.execute_unprepared(sql).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend_keeping_previous(true).await;

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // data changes go through on a single mutable database
            {
                let db = db_pool.create_mutable().await.unwrap();
                for sql in DML_STATEMENTS {
                    assert!(db.execute_unprepared(sql).await.is_ok());
                }
            }

            // schema changes go through, each on its own mutable database
            for sql in DDL_STATEMENTS {
                let db = db_pool.create_mutable().await.unwrap();
                assert!(db.execute_unprepared(sql).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        let backend = create_backend_keeping_previous(true).await;

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // first round: databases start empty, then get dirtied
            {
                let pooled = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(pooled.iter().map(|db| async move {
                    assert_eq!(Entity::find().count(&***db).await.unwrap(), 0);
                }))
                .await;

                join_all(pooled.iter().map(|db| async move {
                    new_book("Title".to_owned()).insert(&***db).await.unwrap();
                }))
                .await;
            }

            // second round: the same databases come back empty again
            {
                let pooled = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(pooled.iter().map(|db| async move {
                    assert_eq!(Entity::find().count(&***db).await.unwrap(), 0);
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false).await).await;
    }
}