// db_pool/async/backend/postgres/diesel.rs

1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use diesel::{prelude::*, result::Error, sql_query, table, ConnectionError};
5use diesel_async::{
6    pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, SetupCallback},
7    AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection,
8};
9use futures::{future::FutureExt, Future};
10use parking_lot::Mutex;
11use uuid::Uuid;
12
13use crate::{common::config::postgres::PrivilegedPostgresConfig, util::get_db_name};
14
15use super::{
16    super::{
17        common::pool::diesel::r#trait::DieselPoolAssociation, error::Error as BackendError,
18        r#trait::Backend,
19    },
20    r#trait::{PostgresBackend, PostgresBackendWrapper},
21};
22
/// Seeding callback run against each freshly created database: it takes
/// ownership of the connection, creates whatever entities (tables, etc.) the
/// user wants, and hands the connection back.
type CreateEntities = dyn Fn(AsyncPgConnection) -> Pin<Box<dyn Future<Output = AsyncPgConnection> + Send + 'static>>
    + Send
    + Sync
    + 'static;
27
/// [`Diesel async Postgres`](https://docs.rs/diesel-async/0.5.0/diesel_async/struct.AsyncPgConnection.html) backend
pub struct DieselAsyncPostgresBackend<P: DieselPoolAssociation<AsyncPgConnection>> {
    // Privileged (admin) credentials/host settings used to reach Postgres.
    privileged_config: PrivilegedPostgresConfig,
    // Pool of privileged connections to the default database.
    default_pool: P::Pool,
    // Connections stashed per database id via `put_database_connection` and
    // reclaimed later by `get_database_connection`.
    db_conns: Mutex<HashMap<Uuid, AsyncPgConnection>>,
    // Factory for the pool builder used by per-database restricted pools.
    create_restricted_pool: Box<dyn Fn() -> P::Builder + Send + Sync + 'static>,
    // Factory producing the connection-establishing callback (custom TLS, etc.).
    create_connection: Box<dyn Fn() -> SetupCallback<AsyncPgConnection> + Send + Sync + 'static>,
    // User hook that seeds entities into each new database.
    create_entities: Box<CreateEntities>,
    // Whether `init` should drop `db_pool_%` databases left by previous runs.
    drop_previous_databases_flag: bool,
}
38
39impl<P: DieselPoolAssociation<AsyncPgConnection>> DieselAsyncPostgresBackend<P> {
40    /// Creates a new [`Diesel async Postgres`](https://docs.rs/diesel-async/0.5.0/diesel_async/struct.AsyncPgConnection.html) backend
41    /// # Example
42    /// ```
43    /// use bb8::Pool;
44    /// use db_pool::{
45    ///     r#async::{DieselAsyncPostgresBackend, DieselBb8},
46    ///     PrivilegedPostgresConfig,
47    /// };
48    /// use diesel::sql_query;
49    /// use diesel_async::RunQueryDsl;
50    /// use dotenvy::dotenv;
51    ///
52    /// async fn f() {
53    ///     dotenv().ok();
54    ///
55    ///     let config = PrivilegedPostgresConfig::from_env().unwrap();
56    ///
57    ///     let backend = DieselAsyncPostgresBackend::<DieselBb8>::new(
58    ///         config,
59    ///         || Pool::builder().max_size(10),
60    ///         || Pool::builder().max_size(2),
61    ///         None,
62    ///         move |mut conn| {
63    ///             Box::pin(async {
64    ///                 sql_query("CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)")
65    ///                     .execute(&mut conn)
66    ///                     .await
67    ///                     .unwrap();
68    ///                 conn
69    ///             })
70    ///         },
71    ///     )
72    ///     .await
73    ///     .unwrap();
74    /// }
75    ///
76    /// tokio_test::block_on(f());
77    /// ```
78    pub async fn new(
79        privileged_config: PrivilegedPostgresConfig,
80        create_privileged_pool: impl Fn() -> P::Builder,
81        create_restricted_pool: impl Fn() -> P::Builder + Send + Sync + 'static,
82        custom_create_connection: Option<
83            Box<dyn Fn() -> SetupCallback<AsyncPgConnection> + Send + Sync + 'static>,
84        >,
85        create_entities: impl Fn(
86                AsyncPgConnection,
87            ) -> Pin<Box<dyn Future<Output = AsyncPgConnection> + Send + 'static>>
88            + Send
89            + Sync
90            + 'static,
91    ) -> Result<Self, P::BuildError> {
92        let create_connection = custom_create_connection.unwrap_or_else(|| {
93            Box::new(|| {
94                Box::new(|connection_url| AsyncPgConnection::establish(connection_url).boxed())
95            })
96        });
97
98        let manager_config = {
99            let mut config = ManagerConfig::default();
100            config.custom_setup = Box::new(create_connection());
101            config
102        };
103        let manager = AsyncDieselConnectionManager::new_with_config(
104            privileged_config.default_connection_url(),
105            manager_config,
106        );
107        let builder = create_privileged_pool();
108        let default_pool = P::build_pool(builder, manager).await?;
109
110        Ok(Self {
111            privileged_config,
112            default_pool,
113            db_conns: Mutex::new(HashMap::new()),
114            create_restricted_pool: Box::new(create_restricted_pool),
115            create_connection,
116            create_entities: Box::new(create_entities),
117            drop_previous_databases_flag: true,
118        })
119    }
120
121    /// Drop databases created in previous runs upon initialization
122    #[must_use]
123    pub fn drop_previous_databases(self, value: bool) -> Self {
124        Self {
125            drop_previous_databases_flag: value,
126            ..self
127        }
128    }
129}
130
131#[async_trait]
132impl<'pool, P: DieselPoolAssociation<AsyncPgConnection>> PostgresBackend<'pool>
133    for DieselAsyncPostgresBackend<P>
134{
135    type Connection = AsyncPgConnection;
136    type PooledConnection = P::PooledConnection<'pool>;
137    type Pool = P::Pool;
138
139    type BuildError = P::BuildError;
140    type PoolError = P::PoolError;
141    type ConnectionError = ConnectionError;
142    type QueryError = Error;
143
144    async fn execute_query(&self, query: &str, conn: &mut AsyncPgConnection) -> QueryResult<()> {
145        sql_query(query).execute(conn).await?;
146        Ok(())
147    }
148
149    async fn batch_execute_query<'a>(
150        &self,
151        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
152        conn: &mut AsyncPgConnection,
153    ) -> QueryResult<()> {
154        let query = query.into_iter().collect::<Vec<_>>();
155        if query.is_empty() {
156            Ok(())
157        } else {
158            conn.batch_execute(query.join(";").as_str()).await
159        }
160    }
161
162    async fn get_default_connection(
163        &'pool self,
164    ) -> Result<P::PooledConnection<'pool>, P::PoolError> {
165        P::get_connection(&self.default_pool).await
166    }
167
168    async fn establish_privileged_database_connection(
169        &self,
170        db_id: Uuid,
171    ) -> ConnectionResult<AsyncPgConnection> {
172        let db_name = get_db_name(db_id);
173        let database_url = self
174            .privileged_config
175            .privileged_database_connection_url(db_name.as_str());
176        (self.create_connection)()(database_url.as_str()).await
177    }
178
179    async fn establish_restricted_database_connection(
180        &self,
181        db_id: Uuid,
182    ) -> ConnectionResult<AsyncPgConnection> {
183        let db_name = get_db_name(db_id);
184        let db_name = db_name.as_str();
185        let database_url = self.privileged_config.restricted_database_connection_url(
186            db_name,
187            Some(db_name),
188            db_name,
189        );
190        (self.create_connection)()(database_url.as_str()).await
191    }
192
193    fn put_database_connection(&self, db_id: Uuid, conn: AsyncPgConnection) {
194        self.db_conns.lock().insert(db_id, conn);
195    }
196
197    fn get_database_connection(&self, db_id: Uuid) -> AsyncPgConnection {
198        self.db_conns
199            .lock()
200            .remove(&db_id)
201            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
202    }
203
204    async fn get_previous_database_names(
205        &self,
206        conn: &mut AsyncPgConnection,
207    ) -> QueryResult<Vec<String>> {
208        table! {
209            pg_database (oid) {
210                oid -> Int4,
211                datname -> Text
212            }
213        }
214
215        pg_database::table
216            .select(pg_database::datname)
217            .filter(pg_database::datname.like("db_pool_%"))
218            .load::<String>(conn)
219            .await
220    }
221
222    async fn create_entities(&self, conn: AsyncPgConnection) -> AsyncPgConnection {
223        (self.create_entities)(conn).await
224    }
225
226    async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
227        let db_name = get_db_name(db_id);
228        let db_name = db_name.as_str();
229        let database_url = self.privileged_config.restricted_database_connection_url(
230            db_name,
231            Some(db_name),
232            db_name,
233        );
234        let manager_config = {
235            let mut config = ManagerConfig::default();
236            config.custom_setup = Box::new((self.create_connection)());
237            config
238        };
239        let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
240            database_url.as_str(),
241            manager_config,
242        );
243        let builder = (self.create_restricted_pool)();
244        P::build_pool(builder, manager).await
245    }
246
247    async fn get_table_names(
248        &self,
249        privileged_conn: &mut AsyncPgConnection,
250    ) -> QueryResult<Vec<String>> {
251        table! {
252            pg_tables (tablename) {
253                #[sql_name = "schemaname"]
254                schema_name -> Text,
255                tablename -> Text
256            }
257        }
258
259        pg_tables::table
260            .filter(pg_tables::schema_name.ne_all(["pg_catalog", "information_schema"]))
261            .select(pg_tables::tablename)
262            .load(privileged_conn)
263            .await
264    }
265
266    fn get_drop_previous_databases(&self) -> bool {
267        self.drop_previous_databases_flag
268    }
269}
270
/// Shorthand for the shared backend error with this backend's connection and
/// query error types pinned to Diesel's.
type BError<BuildError, PoolError> = BackendError<BuildError, PoolError, ConnectionError, Error>;
272
#[async_trait]
impl<P: DieselPoolAssociation<AsyncPgConnection>> Backend for DieselAsyncPostgresBackend<P> {
    type Pool = P::Pool;

    type BuildError = P::BuildError;
    type PoolError = P::PoolError;
    type ConnectionError = ConnectionError;
    type QueryError = Error;

    // Every operation delegates to the generic Postgres logic in
    // `PostgresBackendWrapper`; this impl only adapts the associated types.

    async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self).init().await
    }

    async fn create(
        &self,
        db_id: uuid::Uuid,
        restrict_privileges: bool,
    ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self)
            .create(db_id, restrict_privileges)
            .await
    }

    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self).clean(db_id).await
    }

    async fn drop(
        &self,
        db_id: uuid::Uuid,
        is_restricted: bool,
    ) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self)
            .drop(db_id, is_restricted)
            .await
    }
}
310
311#[cfg(test)]
312mod tests {
313    #![allow(clippy::unwrap_used, clippy::needless_return)]
314
315    use std::borrow::Cow;
316
317    use bb8::Pool;
318    use diesel::{insert_into, sql_query, table, Insertable, QueryDsl};
319    use diesel_async::{RunQueryDsl, SimpleAsyncConnection};
320    use dotenvy::dotenv;
321    use futures::future::join_all;
322    use tokio_shared_rt::test;
323
324    use crate::{
325        common::{
326            config::PrivilegedPostgresConfig,
327            statement::postgres::tests::{
328                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
329            },
330        },
331        r#async::{
332            backend::{
333                common::pool::diesel::bb8::DieselBb8,
334                postgres::r#trait::tests::test_pool_drops_created_unrestricted_database,
335            },
336            db_pool::DatabasePoolBuilder,
337        },
338    };
339
340    use super::{
341        super::r#trait::tests::{
342            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
343            test_backend_creates_database_with_restricted_privileges,
344            test_backend_creates_database_with_unrestricted_privileges,
345            test_backend_drops_database, test_backend_drops_previous_databases,
346            test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases,
347            PgDropLock,
348        },
349        DieselAsyncPostgresBackend,
350    };
351
    // Test-only Diesel mapping for the `book` table that the seeding
    // statements create in each templated database.
    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    // Insertable row for `book`; `Cow` lets tests supply borrowed or owned
    // titles without extra allocation.
    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }
364
365    async fn create_backend(with_table: bool) -> DieselAsyncPostgresBackend<DieselBb8> {
366        dotenv().ok();
367
368        let config = PrivilegedPostgresConfig::from_env().unwrap();
369
370        DieselAsyncPostgresBackend::new(config, Pool::builder, Pool::builder, None, {
371            move |mut conn| {
372                if with_table {
373                    Box::pin(async move {
374                        let query = CREATE_ENTITIES_STATEMENTS.join(";");
375                        conn.batch_execute(query.as_str()).await.unwrap();
376                        conn
377                    })
378                } else {
379                    Box::pin(async { conn })
380                }
381            }
382        })
383        .await
384        .unwrap()
385    }
386
    // The backend must honor `drop_previous_databases`: default and explicit
    // `true` drop leftovers, explicit `false` keeps them.
    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    // Delegations to the shared trait-level test helpers; each exercises one
    // create/clean/drop path against a live Postgres instance.

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    // Same `drop_previous_databases` contract, exercised through the pool API.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }
442
    // Each pulled connection pool must map to its own database: a row written
    // through one pool must only be visible through that same pool.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        // `lock_read` serializes this body against tests that drop databases
        // (PgDropLock) — NOTE(review): inferred from the imported helper; confirm.
        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // insert single row into each database
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        insert_into(book::table)
                            .values(NewBook {
                                title: format!("Title {i}").into(),
                            })
                            .execute(conn)
                            .await
                            .unwrap();
                    }),
            )
            .await;

            // rows fetched must be as inserted
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        assert_eq!(
                            book::table
                                .select(book::title)
                                .load::<String>(conn)
                                .await
                                .unwrap(),
                            vec![format!("Title {i}")]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }
493
    // Immutably pulled databases must reject schema changes (DDL) while still
    // allowing data manipulation (DML).
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.get().await.unwrap();

            // DDL statements must fail
            for stmt in DDL_STATEMENTS {
                assert!(sql_query(stmt).execute(conn).await.is_err());
            }

            // DML statements must succeed
            for stmt in DML_STATEMENTS {
                assert!(sql_query(stmt).execute(conn).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }
516
    // Mutable (unrestricted) databases must accept both DML and DDL. A fresh
    // database is created per DDL statement so earlier schema changes cannot
    // make later statements fail.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // DML statements must succeed
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(sql_query(stmt).execute(conn).await.is_ok());
                }
            }

            // DDL statements must succeed
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                assert!(sql_query(stmt).execute(conn).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }
543
    // Re-pulling a database from the pool must yield it cleaned: rows
    // inserted during the first checkout must be gone on the second.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // fetch connection pools the first time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        book::table.count().get_result::<i64>(conn).await.unwrap(),
                        0
                    );
                }))
                .await;

                // insert data into each database
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    insert_into(book::table)
                        .values(NewBook {
                            title: "Title".into(),
                        })
                        .execute(conn)
                        .await
                        .unwrap();
                }))
                .await;
            }
            // dropping the scope returns the pools, triggering clean-up

            // fetch same connection pools a second time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        book::table.count().get_result::<i64>(conn).await.unwrap(),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }
599
    // Databases created through the pool must be dropped again, per the
    // shared trait-level test helpers (restricted and unrestricted variants).

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false).await;
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false).await;
        test_pool_drops_created_unrestricted_database(backend).await;
    }
611}