db_pool/async/backend/postgres/tokio_postgres.rs

1use std::{borrow::Cow, collections::HashMap, convert::Into, pin::Pin};
2
3use async_trait::async_trait;
4use deadpool_postgres::Manager;
5use futures::Future;
6use parking_lot::Mutex;
7use tokio_postgres::{Client, Config, NoTls};
8use uuid::Uuid;
9
10use crate::{common::statement::postgres, util::get_db_name};
11
12use super::{
13    super::{
14        common::{
15            error::tokio_postgres::{ConnectionError, QueryError},
16            pool::tokio_postgres::r#trait::TokioPostgresPoolAssociation,
17        },
18        error::Error as BackendError,
19        r#trait::Backend,
20    },
21    r#trait::{PostgresBackend, PostgresBackendWrapper},
22};
23
24type CreateEntities = dyn Fn(Client) -> Pin<Box<dyn Future<Output = Client> + Send + 'static>>
25    + Send
26    + Sync
27    + 'static;
28
29/// [`tokio-postgres`](https://docs.rs/tokio-postgres/0.7.13/tokio_postgres/) backend
30pub struct TokioPostgresBackend<P: TokioPostgresPoolAssociation> {
31    privileged_config: Config,
32    default_pool: P::Pool,
33    db_conns: Mutex<HashMap<Uuid, Client>>,
34    create_restricted_pool: Box<dyn Fn(Manager) -> P::Builder + Send + Sync + 'static>,
35    create_entities: Box<CreateEntities>,
36    drop_previous_databases_flag: bool,
37}
38
39impl<P: TokioPostgresPoolAssociation> TokioPostgresBackend<P> {
40    /// Creates a new [`tokio-postgres`](https://docs.rs/tokio-postgres/0.7.13/tokio_postgres/) backend
41    /// # Example
42    /// ```
43    /// use bb8::Pool;
44    /// use db_pool::{
45    ///     r#async::{TokioPostgresBackend, TokioPostgresBb8},
46    ///     PrivilegedPostgresConfig,
47    /// };
48    /// use dotenvy::dotenv;
49    ///
50    /// async fn f() {
51    ///     dotenv().ok();
52    ///
53    ///     let config = PrivilegedPostgresConfig::from_env().unwrap();
54    ///     
55    ///     let backend = TokioPostgresBackend::<TokioPostgresBb8>::new(
56    ///         config.into(),
57    ///         |_| Pool::builder().max_size(10),
58    ///         |_| Pool::builder().max_size(2),
59    ///         move |conn| {
60    ///             Box::pin(async move {
61    ///                 conn.execute(
62    ///                     "CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)",
63    ///                     &[],
64    ///                 )
65    ///                 .await
66    ///                 .unwrap();
67    ///                 conn
68    ///             })
69    ///         },
70    ///     )
71    ///     .await
72    ///     .unwrap();
73    /// }
74    ///
75    /// tokio_test::block_on(f());
76    /// ```
77    pub async fn new(
78        privileged_config: Config,
79        create_privileged_pool: impl Fn(Manager) -> P::Builder,
80        create_restricted_pool: impl Fn(Manager) -> P::Builder + Send + Sync + 'static,
81        create_entities: impl Fn(Client) -> Pin<Box<dyn Future<Output = Client> + Send + 'static>>
82        + Send
83        + Sync
84        + 'static,
85    ) -> Result<Self, P::BuildError> {
86        let manager = Manager::new(privileged_config.clone(), NoTls);
87        let builder = create_privileged_pool(manager);
88        let default_pool = P::build_pool(builder, privileged_config.clone()).await?;
89
90        Ok(Self {
91            privileged_config,
92            default_pool,
93            db_conns: Mutex::new(HashMap::new()),
94            create_entities: Box::new(create_entities),
95            create_restricted_pool: Box::new(create_restricted_pool),
96            drop_previous_databases_flag: true,
97        })
98    }
99
100    /// Drop databases created in previous runs upon initialization
101    #[must_use]
102    pub fn drop_previous_databases(self, value: bool) -> Self {
103        Self {
104            drop_previous_databases_flag: value,
105            ..self
106        }
107    }
108}
109
110#[async_trait]
111impl<'pool, P: TokioPostgresPoolAssociation> PostgresBackend<'pool> for TokioPostgresBackend<P> {
112    type Connection = Client;
113    type PooledConnection = P::PooledConnection<'pool>;
114    type Pool = P::Pool;
115
116    type BuildError = P::BuildError;
117    type PoolError = P::PoolError;
118    type ConnectionError = ConnectionError;
119    type QueryError = QueryError;
120
121    async fn execute_query(&self, query: &str, conn: &mut Client) -> Result<(), QueryError> {
122        conn.execute(query, &[]).await?;
123        Ok(())
124    }
125
126    async fn batch_execute_query<'a>(
127        &self,
128        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
129        conn: &mut Client,
130    ) -> Result<(), QueryError> {
131        let query = query.into_iter().collect::<Vec<_>>().join(";");
132        conn.batch_execute(query.as_str()).await?;
133        Ok(())
134    }
135
136    async fn get_default_connection(
137        &'pool self,
138    ) -> Result<P::PooledConnection<'pool>, P::PoolError> {
139        P::get_connection(&self.default_pool).await
140    }
141
142    async fn establish_privileged_database_connection(
143        &self,
144        db_id: Uuid,
145    ) -> Result<Client, ConnectionError> {
146        let mut config = self.privileged_config.clone();
147        let db_name = get_db_name(db_id);
148        config.dbname(db_name.as_str());
149        let (client, connection) = config.connect(NoTls).await?;
150        tokio::spawn(connection);
151        Ok(client)
152    }
153
154    async fn establish_restricted_database_connection(
155        &self,
156        db_id: Uuid,
157    ) -> Result<Client, ConnectionError> {
158        let mut config = self.privileged_config.clone();
159        let db_name = get_db_name(db_id);
160        let db_name = db_name.as_str();
161        config.user(db_name).password(db_name).dbname(db_name);
162        let (client, connection) = config.connect(NoTls).await?;
163        tokio::spawn(connection);
164        Ok(client)
165    }
166
167    fn put_database_connection(&self, db_id: Uuid, conn: Client) {
168        self.db_conns.lock().insert(db_id, conn);
169    }
170
171    fn get_database_connection(&self, db_id: Uuid) -> Client {
172        self.db_conns
173            .lock()
174            .remove(&db_id)
175            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
176    }
177
178    async fn get_previous_database_names(
179        &self,
180        conn: &mut Client,
181    ) -> Result<Vec<String>, QueryError> {
182        conn.query(postgres::GET_DATABASE_NAMES, &[])
183            .await
184            .map(|rows| rows.iter().map(|row| row.get(0)).collect())
185            .map_err(Into::into)
186    }
187
188    async fn create_entities(&self, conn: Client) -> Option<Client> {
189        Some((self.create_entities)(conn).await)
190    }
191
192    async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
193        let db_name = get_db_name(db_id);
194        let db_name = db_name.as_str();
195        let mut config = self.privileged_config.clone();
196        config.dbname(db_name);
197        config.user(db_name);
198        config.password(db_name);
199        let manager = Manager::new(config.clone(), NoTls);
200        let builder = (self.create_restricted_pool)(manager);
201        P::build_pool(builder, config).await
202    }
203
204    async fn get_table_names(
205        &self,
206        privileged_conn: &mut Client,
207    ) -> Result<Vec<String>, QueryError> {
208        privileged_conn
209            .query(postgres::GET_TABLE_NAMES, &[])
210            .await
211            .map(|rows| rows.iter().map(|row| row.get(0)).collect())
212            .map_err(Into::into)
213    }
214
215    fn get_drop_previous_databases(&self) -> bool {
216        self.drop_previous_databases_flag
217    }
218}
219
220type BError<BuildError, PoolError> =
221    BackendError<BuildError, PoolError, ConnectionError, QueryError>;
222
223#[async_trait]
224impl<P: TokioPostgresPoolAssociation> Backend for TokioPostgresBackend<P> {
225    type Pool = P::Pool;
226
227    type BuildError = P::BuildError;
228    type PoolError = P::PoolError;
229    type ConnectionError = ConnectionError;
230    type QueryError = QueryError;
231
232    async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
233        PostgresBackendWrapper::new(self).init().await
234    }
235
236    async fn create(
237        &self,
238        db_id: uuid::Uuid,
239        restrict_privileges: bool,
240    ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
241        PostgresBackendWrapper::new(self)
242            .create(db_id, restrict_privileges)
243            .await
244    }
245
246    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
247        PostgresBackendWrapper::new(self).clean(db_id).await
248    }
249
250    async fn drop(
251        &self,
252        db_id: uuid::Uuid,
253        is_restricted: bool,
254    ) -> Result<(), BError<P::BuildError, P::PoolError>> {
255        PostgresBackendWrapper::new(self)
256            .drop(db_id, is_restricted)
257            .await
258    }
259}
260
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use bb8::Pool;
    use futures::future::join_all;
    use tokio_postgres::Config;
    use tokio_shared_rt::test;

    use crate::{
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::{
                common::pool::tokio_postgres::bb8::TokioPostgresBb8,
                postgres::r#trait::tests::{
                    PgDropLock, test_backend_cleans_database_with_tables,
                    test_backend_cleans_database_without_tables,
                    test_backend_creates_database_with_restricted_privileges,
                    test_backend_creates_database_with_unrestricted_privileges,
                    test_backend_drops_database, test_backend_drops_previous_databases,
                    test_pool_drops_created_restricted_databases,
                    test_pool_drops_created_unrestricted_database,
                    test_pool_drops_previous_databases,
                },
            },
            db_pool::DatabasePoolBuilder,
        },
    };

    use super::TokioPostgresBackend;

    /// Builds a backend for the local `postgres` user with default bb8 pool builders.
    /// When `with_table` is set, `CREATE_ENTITIES_STATEMENTS` are run in every new database.
    async fn create_backend(with_table: bool) -> TokioPostgresBackend<TokioPostgresBb8> {
        let mut config = Config::new();
        config.host("localhost");
        config.user("postgres");
        config.password("postgres");

        TokioPostgresBackend::new(
            config,
            |_| Pool::builder(),
            |_| Pool::builder(),
            move |conn| {
                Box::pin(async move {
                    if with_table {
                        conn.batch_execute(&CREATE_ENTITIES_STATEMENTS.join(";"))
                            .await
                            .unwrap();
                    }
                    conn
                })
            },
        )
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            true,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            false,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_pool_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // every database gets exactly one row, tagged with its index
            let inserts = conn_pools
                .iter()
                .enumerate()
                .map(|(i, conn_pool)| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let title = format!("Title {i}");
                    conn.execute("INSERT INTO book (title) VALUES ($1)", &[&title.as_str()])
                        .await
                        .unwrap();
                });
            join_all(inserts).await;

            // every database must contain only its own row
            let checks = conn_pools
                .iter()
                .enumerate()
                .map(|(i, conn_pool)| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let titles = conn
                        .query("SELECT title FROM book", &[])
                        .await
                        .unwrap()
                        .iter()
                        .map(|row| row.get::<_, String>(0))
                        .collect::<Vec<_>>();
                    assert_eq!(titles, vec![format!("Title {i}")]);
                });
            join_all(checks).await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.get().await.unwrap();

            // schema changes must be rejected
            for stmt in DDL_STATEMENTS {
                let outcome = conn.execute(stmt, &[]).await;
                assert!(outcome.is_err());
            }

            // data changes must be accepted
            for stmt in DML_STATEMENTS {
                let outcome = conn.execute(stmt, &[]).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // data changes must be accepted, all on one mutable database
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    let outcome = conn.execute(stmt, &[]).await;
                    assert!(outcome.is_ok());
                }
            }

            // schema changes must be accepted, each on a fresh mutable database
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                let outcome = conn.execute(stmt, &[]).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // Pull the databases twice. Rows inserted during the first round must not be
            // visible when the same databases are pulled again.
            for is_first_round in [true, false] {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let row_count = conn
                        .query_one("SELECT COUNT(*) FROM book", &[])
                        .await
                        .unwrap()
                        .get::<_, i64>(0);
                    assert_eq!(row_count, 0);
                }))
                .await;

                // dirty every database once so the second round has something to verify
                if is_first_round {
                    join_all(conn_pools.iter().map(|conn_pool| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        conn.execute("INSERT INTO book (title) VALUES ($1)", &[&"Title"])
                            .await
                            .unwrap();
                    }))
                    .await;
                }
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false).await).await;
    }
}
548}