// db_pool/async/backend/postgres/diesel.rs

use std::{borrow::Cow, collections::HashMap, pin::Pin};

use async_trait::async_trait;
use diesel::{ConnectionError, prelude::*, result::Error, sql_query, table};
use diesel_async::{
    AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection,
    pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, SetupCallback},
};
use futures::{Future, future::FutureExt};
use parking_lot::Mutex;
use uuid::Uuid;

use crate::{common::config::postgres::PrivilegedPostgresConfig, util::get_db_name};

use super::{
    super::{
        common::pool::diesel::r#trait::DieselPoolAssociation, error::Error as BackendError,
        r#trait::Backend,
    },
    r#trait::{PostgresBackend, PostgresBackendWrapper},
};

23type CreateEntities = dyn Fn(
24        AsyncPgConnection,
25    ) -> Pin<Box<dyn Future<Output = Option<AsyncPgConnection>> + Send + 'static>>
26    + Send
27    + Sync
28    + 'static;
29
30/// [`Diesel async Postgres`](https://docs.rs/diesel-async/0.5.2/diesel_async/struct.AsyncPgConnection.html) backend
31pub struct DieselAsyncPostgresBackend<P: DieselPoolAssociation<AsyncPgConnection>> {
32    privileged_config: PrivilegedPostgresConfig,
33    default_pool: P::Pool,
34    db_conns: Mutex<HashMap<Uuid, AsyncPgConnection>>,
35    create_restricted_pool: Box<
36        dyn Fn(AsyncDieselConnectionManager<AsyncPgConnection>) -> P::Builder
37            + Send
38            + Sync
39            + 'static,
40    >,
41    create_connection: Box<dyn Fn() -> SetupCallback<AsyncPgConnection> + Send + Sync + 'static>,
42    create_entities: Box<CreateEntities>,
43    drop_previous_databases_flag: bool,
44}
45
46impl<P: DieselPoolAssociation<AsyncPgConnection>> DieselAsyncPostgresBackend<P> {
47    /// Creates a new [`Diesel async Postgres`](https://docs.rs/diesel-async/0.5.2/diesel_async/struct.AsyncPgConnection.html) backend
48    /// # Example
49    /// ```
50    /// use bb8::Pool;
51    /// use db_pool::{
52    ///     r#async::{DieselAsyncPostgresBackend, DieselBb8},
53    ///     PrivilegedPostgresConfig,
54    /// };
55    /// use diesel::sql_query;
56    /// use diesel_async::RunQueryDsl;
57    /// use dotenvy::dotenv;
58    ///
59    /// async fn f() {
60    ///     dotenv().ok();
61    ///
62    ///     let config = PrivilegedPostgresConfig::from_env().unwrap();
63    ///
64    ///     let backend = DieselAsyncPostgresBackend::<DieselBb8>::new(
65    ///         config,
66    ///         |_| Pool::builder().max_size(10),
67    ///         |_| Pool::builder().max_size(2),
68    ///         None,
69    ///         move |mut conn| {
70    ///             Box::pin(async {
71    ///                 sql_query("CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)")
72    ///                     .execute(&mut conn)
73    ///                     .await
74    ///                     .unwrap();
75    ///                 Some(conn)
76    ///             })
77    ///         },
78    ///     )
79    ///     .await
80    ///     .unwrap();
81    /// }
82    ///
83    /// tokio_test::block_on(f());
84    /// ```
85    /// If, after running migrations, the connection has not been moved, it is preferable to return the connection with `Some(conn)`.
86    /// However, if after running migrations, the connection has been moved, `None` can be returned at the expense of the backend
87    /// having to establish the same connection again for some of its internal work.
88    pub async fn new(
89        privileged_config: PrivilegedPostgresConfig,
90        create_privileged_pool: impl Fn(AsyncDieselConnectionManager<AsyncPgConnection>) -> P::Builder,
91        create_restricted_pool: impl Fn(AsyncDieselConnectionManager<AsyncPgConnection>) -> P::Builder
92        + Send
93        + Sync
94        + 'static,
95        custom_create_connection: Option<
96            Box<dyn Fn() -> SetupCallback<AsyncPgConnection> + Send + Sync + 'static>,
97        >,
98        create_entities: impl Fn(
99            AsyncPgConnection,
100        ) -> Pin<
101            Box<dyn Future<Output = Option<AsyncPgConnection>> + Send + 'static>,
102        > + Send
103        + Sync
104        + 'static,
105    ) -> Result<Self, P::BuildError> {
106        let create_connection = custom_create_connection.unwrap_or_else(|| {
107            Box::new(|| {
108                Box::new(|connection_url| AsyncPgConnection::establish(connection_url).boxed())
109            })
110        });
111
112        let manager = || {
113            let manager_config = {
114                let mut config = ManagerConfig::default();
115                config.custom_setup = Box::new(create_connection());
116                config
117            };
118            AsyncDieselConnectionManager::new_with_config(
119                privileged_config.default_connection_url(),
120                manager_config,
121            )
122        };
123
124        // Both create and build functions take manager value as a parameter,
125        // but only one should actually use it (depends on the particular connection pool API)
126        let builder = create_privileged_pool(manager());
127
128        let default_pool = P::build_pool(builder, manager()).await?;
129
130        Ok(Self {
131            privileged_config,
132            default_pool,
133            db_conns: Mutex::new(HashMap::new()),
134            create_restricted_pool: Box::new(create_restricted_pool),
135            create_connection,
136            create_entities: Box::new(create_entities),
137            drop_previous_databases_flag: true,
138        })
139    }
140
141    /// Drop databases created in previous runs upon initialization
142    #[must_use]
143    pub fn drop_previous_databases(self, value: bool) -> Self {
144        Self {
145            drop_previous_databases_flag: value,
146            ..self
147        }
148    }
149}
150
151#[async_trait]
152impl<'pool, P: DieselPoolAssociation<AsyncPgConnection>> PostgresBackend<'pool>
153    for DieselAsyncPostgresBackend<P>
154{
155    type Connection = AsyncPgConnection;
156    type PooledConnection = P::PooledConnection<'pool>;
157    type Pool = P::Pool;
158
159    type BuildError = P::BuildError;
160    type PoolError = P::PoolError;
161    type ConnectionError = ConnectionError;
162    type QueryError = Error;
163
164    async fn execute_query(&self, query: &str, conn: &mut AsyncPgConnection) -> QueryResult<()> {
165        sql_query(query).execute(conn).await?;
166        Ok(())
167    }
168
169    async fn batch_execute_query<'a>(
170        &self,
171        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
172        conn: &mut AsyncPgConnection,
173    ) -> QueryResult<()> {
174        let query = query.into_iter().collect::<Vec<_>>();
175        if query.is_empty() {
176            Ok(())
177        } else {
178            conn.batch_execute(query.join(";").as_str()).await
179        }
180    }
181
182    async fn get_default_connection(
183        &'pool self,
184    ) -> Result<P::PooledConnection<'pool>, P::PoolError> {
185        P::get_connection(&self.default_pool).await
186    }
187
188    async fn establish_privileged_database_connection(
189        &self,
190        db_id: Uuid,
191    ) -> ConnectionResult<AsyncPgConnection> {
192        let db_name = get_db_name(db_id);
193        let database_url = self
194            .privileged_config
195            .privileged_database_connection_url(db_name.as_str());
196        (self.create_connection)()(database_url.as_str()).await
197    }
198
199    async fn establish_restricted_database_connection(
200        &self,
201        db_id: Uuid,
202    ) -> ConnectionResult<AsyncPgConnection> {
203        let db_name = get_db_name(db_id);
204        let db_name = db_name.as_str();
205        let database_url = self.privileged_config.restricted_database_connection_url(
206            db_name,
207            Some(db_name),
208            db_name,
209        );
210        (self.create_connection)()(database_url.as_str()).await
211    }
212
213    fn put_database_connection(&self, db_id: Uuid, conn: AsyncPgConnection) {
214        self.db_conns.lock().insert(db_id, conn);
215    }
216
217    fn get_database_connection(&self, db_id: Uuid) -> AsyncPgConnection {
218        self.db_conns
219            .lock()
220            .remove(&db_id)
221            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
222    }
223
224    async fn get_previous_database_names(
225        &self,
226        conn: &mut AsyncPgConnection,
227    ) -> QueryResult<Vec<String>> {
228        table! {
229            pg_database (oid) {
230                oid -> Int4,
231                datname -> Text
232            }
233        }
234
235        pg_database::table
236            .select(pg_database::datname)
237            .filter(pg_database::datname.like("db_pool_%"))
238            .load::<String>(conn)
239            .await
240    }
241
242    async fn create_entities(&self, conn: AsyncPgConnection) -> Option<AsyncPgConnection> {
243        (self.create_entities)(conn).await
244    }
245
246    async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
247        let db_name = get_db_name(db_id);
248        let db_name = db_name.as_str();
249        let database_url = self.privileged_config.restricted_database_connection_url(
250            db_name,
251            Some(db_name),
252            db_name,
253        );
254
255        let manager = {
256            || {
257                let manager_config = {
258                    let mut config = ManagerConfig::default();
259                    config.custom_setup = Box::new((self.create_connection)());
260                    config
261                };
262                AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
263                    database_url.clone(),
264                    manager_config,
265                )
266            }
267        };
268
269        // Both create and build functions take manager value as a parameter,
270        // but only one should actually use it (depends on the particular connection pool API)
271        let builder = (self.create_restricted_pool)(manager());
272
273        P::build_pool(builder, manager()).await
274    }
275
276    async fn get_table_names(
277        &self,
278        privileged_conn: &mut AsyncPgConnection,
279    ) -> QueryResult<Vec<String>> {
280        table! {
281            pg_tables (tablename) {
282                #[sql_name = "schemaname"]
283                schema_name -> Text,
284                tablename -> Text
285            }
286        }
287
288        pg_tables::table
289            .filter(pg_tables::schema_name.ne_all(["pg_catalog", "information_schema"]))
290            .select(pg_tables::tablename)
291            .load(privileged_conn)
292            .await
293    }
294
295    fn get_drop_previous_databases(&self) -> bool {
296        self.drop_previous_databases_flag
297    }
298}
299
300type BError<BuildError, PoolError> = BackendError<BuildError, PoolError, ConnectionError, Error>;
301
302#[async_trait]
303impl<P: DieselPoolAssociation<AsyncPgConnection>> Backend for DieselAsyncPostgresBackend<P> {
304    type Pool = P::Pool;
305
306    type BuildError = P::BuildError;
307    type PoolError = P::PoolError;
308    type ConnectionError = ConnectionError;
309    type QueryError = Error;
310
311    async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
312        PostgresBackendWrapper::new(self).init().await
313    }
314
315    async fn create(
316        &self,
317        db_id: uuid::Uuid,
318        restrict_privileges: bool,
319    ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
320        PostgresBackendWrapper::new(self)
321            .create(db_id, restrict_privileges)
322            .await
323    }
324
325    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
326        PostgresBackendWrapper::new(self).clean(db_id).await
327    }
328
329    async fn drop(
330        &self,
331        db_id: uuid::Uuid,
332        is_restricted: bool,
333    ) -> Result<(), BError<P::BuildError, P::PoolError>> {
334        PostgresBackendWrapper::new(self)
335            .drop(db_id, is_restricted)
336            .await
337    }
338}
339
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use std::borrow::Cow;

    use bb8::Pool;
    use diesel::{Insertable, QueryDsl, insert_into, sql_query, table};
    use diesel_async::{RunQueryDsl, SimpleAsyncConnection};
    use dotenvy::dotenv;
    use futures::future::join_all;
    use tokio_shared_rt::test;

    use crate::{
        r#async::{
            backend::{
                common::pool::diesel::bb8::DieselBb8,
                postgres::r#trait::tests::test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        common::{
            config::PrivilegedPostgresConfig,
            statement::postgres::tests::{
                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
            },
        },
    };

    use super::{
        super::r#trait::tests::{
            PgDropLock, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_database, test_backend_drops_previous_databases,
            test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases,
        },
        DieselAsyncPostgresBackend,
    };

    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }

    /// Builds a bb8-backed backend from the privileged config in the environment
    /// (a `.env` file is loaded first if present).
    ///
    /// When `with_table` is true, `CREATE_ENTITIES_STATEMENTS` are executed on every newly
    /// created database; otherwise the connection is handed back untouched.
    async fn create_backend(with_table: bool) -> DieselAsyncPostgresBackend<DieselBb8> {
        dotenv().ok();

        let config = PrivilegedPostgresConfig::from_env().unwrap();

        DieselAsyncPostgresBackend::new(config, |_| Pool::builder(), |_| Pool::builder(), None, {
            move |mut conn| {
                if with_table {
                    Box::pin(async move {
                        let query = CREATE_ENTITIES_STATEMENTS.join(";");
                        conn.batch_execute(query.as_str()).await.unwrap();
                        Some(conn)
                    })
                } else {
                    Box::pin(async { Some(conn) })
                }
            }
        })
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        // Runs under the shared read lock provided by `PgDropLock`.
        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // insert single row into each database
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        insert_into(book::table)
                            .values(NewBook {
                                title: format!("Title {i}").into(),
                            })
                            .execute(conn)
                            .await
                            .unwrap();
                    }),
            )
            .await;

            // rows fetched must be as inserted
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        assert_eq!(
                            book::table
                                .select(book::title)
                                .load::<String>(conn)
                                .await
                                .unwrap(),
                            vec![format!("Title {i}")]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.get().await.unwrap();

            // DDL statements must fail
            for stmt in DDL_STATEMENTS {
                assert!(
                    sql_query(stmt).execute(conn).await.is_err(),
                    "DDL statement unexpectedly succeeded on a restricted database: {stmt}"
                );
            }

            // DML statements must succeed
            for stmt in DML_STATEMENTS {
                assert!(
                    sql_query(stmt).execute(conn).await.is_ok(),
                    "DML statement failed on a restricted database: {stmt}"
                );
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // DML statements must succeed
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(
                        sql_query(stmt).execute(conn).await.is_ok(),
                        "DML statement failed on an unrestricted database: {stmt}"
                    );
                }
            }

            // DDL statements must succeed; each one gets its own fresh mutable database
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                assert!(
                    sql_query(stmt).execute(conn).await.is_ok(),
                    "DDL statement failed on an unrestricted database: {stmt}"
                );
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // fetch connection pools the first time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        book::table.count().get_result::<i64>(conn).await.unwrap(),
                        0
                    );
                }))
                .await;

                // insert data into each database
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    insert_into(book::table)
                        .values(NewBook {
                            title: "Title".into(),
                        })
                        .execute(conn)
                        .await
                        .unwrap();
                }))
                .await;
            } // conn_pools dropped here, returning the databases to the pool

            // fetch same connection pools a second time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        book::table.count().get_result::<i64>(conn).await.unwrap(),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false).await;
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false).await;
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}