db_pool/async/backend/postgres/tokio_postgres.rs

1use std::{borrow::Cow, collections::HashMap, convert::Into, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use tokio_postgres::{Client, Config, NoTls};
7use uuid::Uuid;
8
9use crate::{common::statement::postgres, util::get_db_name};
10
11use super::{
12    super::{
13        common::{
14            error::tokio_postgres::{ConnectionError, QueryError},
15            pool::tokio_postgres::r#trait::TokioPostgresPoolAssociation,
16        },
17        error::Error as BackendError,
18        r#trait::Backend,
19    },
20    r#trait::{PostgresBackend, PostgresBackendWrapper},
21};
22
/// Boxed callback used to seed a freshly created database: it receives a
/// [`Client`] connected to the new database, performs its setup, and hands
/// the client back so the backend can keep using it.
type CreateEntities = dyn Fn(Client) -> Pin<Box<dyn Future<Output = Client> + Send + 'static>>
    + Send
    + Sync
    + 'static;
27
/// [`tokio-postgres`](https://docs.rs/tokio-postgres/0.7.10/tokio_postgres/) backend
pub struct TokioPostgresBackend<P: TokioPostgresPoolAssociation> {
    // Privileged connection configuration; cloned and adjusted per database
    // when connecting as the restricted user or to a specific database.
    privileged_config: Config,
    // Pool of privileged connections to the default database.
    default_pool: P::Pool,
    // Privileged per-database connections parked between backend steps,
    // keyed by database id.
    db_conns: Mutex<HashMap<Uuid, Client>>,
    // Factory producing pool builders for restricted, per-database pools.
    create_restricted_pool: Box<dyn Fn() -> P::Builder + Send + Sync + 'static>,
    // User-supplied callback that creates entities (tables, etc.) in a new database.
    create_entities: Box<CreateEntities>,
    // When `true`, databases left over from previous runs are dropped on init.
    drop_previous_databases_flag: bool,
}
37
38impl<P: TokioPostgresPoolAssociation> TokioPostgresBackend<P> {
39    /// Creates a new [`tokio-postgres`](https://docs.rs/tokio-postgres/0.7.10/tokio_postgres/) backend
40    /// # Example
41    /// ```
42    /// use bb8::Pool;
43    /// use db_pool::{
44    ///     r#async::{TokioPostgresBackend, TokioPostgresBb8},
45    ///     PrivilegedPostgresConfig,
46    /// };
47    /// use dotenvy::dotenv;
48    ///
49    /// async fn f() {
50    ///     dotenv().ok();
51    ///
52    ///     let config = PrivilegedPostgresConfig::from_env().unwrap();
53    ///     
54    ///     let backend = TokioPostgresBackend::<TokioPostgresBb8>::new(
55    ///         config.into(),
56    ///         || Pool::builder().max_size(10),
57    ///         || Pool::builder().max_size(2),
58    ///         move |conn| {
59    ///             Box::pin(async move {
60    ///                 conn.execute(
61    ///                     "CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)",
62    ///                     &[],
63    ///                 )
64    ///                 .await
65    ///                 .unwrap();
66    ///                 conn
67    ///             })
68    ///         },
69    ///     )
70    ///     .await
71    ///     .unwrap();
72    /// }
73    ///
74    /// tokio_test::block_on(f());
75    /// ```
76    pub async fn new(
77        privileged_config: Config,
78        create_privileged_pool: impl Fn() -> P::Builder,
79        create_restricted_pool: impl Fn() -> P::Builder + Send + Sync + 'static,
80        create_entities: impl Fn(Client) -> Pin<Box<dyn Future<Output = Client> + Send + 'static>>
81            + Send
82            + Sync
83            + 'static,
84    ) -> Result<Self, P::BuildError> {
85        let builder = create_privileged_pool();
86        let default_pool = P::build_pool(builder, privileged_config.clone()).await?;
87
88        Ok(Self {
89            privileged_config,
90            default_pool,
91            db_conns: Mutex::new(HashMap::new()),
92            create_entities: Box::new(create_entities),
93            create_restricted_pool: Box::new(create_restricted_pool),
94            drop_previous_databases_flag: true,
95        })
96    }
97
98    /// Drop databases created in previous runs upon initialization
99    #[must_use]
100    pub fn drop_previous_databases(self, value: bool) -> Self {
101        Self {
102            drop_previous_databases_flag: value,
103            ..self
104        }
105    }
106}
107
108#[async_trait]
109impl<'pool, P: TokioPostgresPoolAssociation> PostgresBackend<'pool> for TokioPostgresBackend<P> {
110    type Connection = Client;
111    type PooledConnection = P::PooledConnection<'pool>;
112    type Pool = P::Pool;
113
114    type BuildError = P::BuildError;
115    type PoolError = P::PoolError;
116    type ConnectionError = ConnectionError;
117    type QueryError = QueryError;
118
119    async fn execute_query(&self, query: &str, conn: &mut Client) -> Result<(), QueryError> {
120        conn.execute(query, &[]).await?;
121        Ok(())
122    }
123
124    async fn batch_execute_query<'a>(
125        &self,
126        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
127        conn: &mut Client,
128    ) -> Result<(), QueryError> {
129        let query = query.into_iter().collect::<Vec<_>>().join(";");
130        conn.batch_execute(query.as_str()).await?;
131        Ok(())
132    }
133
134    async fn get_default_connection(
135        &'pool self,
136    ) -> Result<P::PooledConnection<'pool>, P::PoolError> {
137        P::get_connection(&self.default_pool).await
138    }
139
140    async fn establish_privileged_database_connection(
141        &self,
142        db_id: Uuid,
143    ) -> Result<Client, ConnectionError> {
144        let mut config = self.privileged_config.clone();
145        let db_name = get_db_name(db_id);
146        config.dbname(db_name.as_str());
147        let (client, connection) = config.connect(NoTls).await?;
148        tokio::spawn(connection);
149        Ok(client)
150    }
151
152    async fn establish_restricted_database_connection(
153        &self,
154        db_id: Uuid,
155    ) -> Result<Client, ConnectionError> {
156        let mut config = self.privileged_config.clone();
157        let db_name = get_db_name(db_id);
158        let db_name = db_name.as_str();
159        config.user(db_name).password(db_name).dbname(db_name);
160        let (client, connection) = config.connect(NoTls).await?;
161        tokio::spawn(connection);
162        Ok(client)
163    }
164
165    fn put_database_connection(&self, db_id: Uuid, conn: Client) {
166        self.db_conns.lock().insert(db_id, conn);
167    }
168
169    fn get_database_connection(&self, db_id: Uuid) -> Client {
170        self.db_conns
171            .lock()
172            .remove(&db_id)
173            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
174    }
175
176    async fn get_previous_database_names(
177        &self,
178        conn: &mut Client,
179    ) -> Result<Vec<String>, QueryError> {
180        conn.query(postgres::GET_DATABASE_NAMES, &[])
181            .await
182            .map(|rows| rows.iter().map(|row| row.get(0)).collect())
183            .map_err(Into::into)
184    }
185
186    async fn create_entities(&self, conn: Client) -> Client {
187        (self.create_entities)(conn).await
188    }
189
190    async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
191        let db_name = get_db_name(db_id);
192        let db_name = db_name.as_str();
193        let mut config = self.privileged_config.clone();
194        config.dbname(db_name);
195        config.user(db_name);
196        config.password(db_name);
197        let builder = (self.create_restricted_pool)();
198        P::build_pool(builder, config).await
199    }
200
201    async fn get_table_names(
202        &self,
203        privileged_conn: &mut Client,
204    ) -> Result<Vec<String>, QueryError> {
205        privileged_conn
206            .query(postgres::GET_TABLE_NAMES, &[])
207            .await
208            .map(|rows| rows.iter().map(|row| row.get(0)).collect())
209            .map_err(Into::into)
210    }
211
212    fn get_drop_previous_databases(&self) -> bool {
213        self.drop_previous_databases_flag
214    }
215}
216
/// Shorthand for the backend error type with this backend's connection and
/// query error types filled in.
type BError<BuildError, PoolError> =
    BackendError<BuildError, PoolError, ConnectionError, QueryError>;
219
220#[async_trait]
221impl<P: TokioPostgresPoolAssociation> Backend for TokioPostgresBackend<P> {
222    type Pool = P::Pool;
223
224    type BuildError = P::BuildError;
225    type PoolError = P::PoolError;
226    type ConnectionError = ConnectionError;
227    type QueryError = QueryError;
228
229    async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
230        PostgresBackendWrapper::new(self).init().await
231    }
232
233    async fn create(
234        &self,
235        db_id: uuid::Uuid,
236        restrict_privileges: bool,
237    ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
238        PostgresBackendWrapper::new(self)
239            .create(db_id, restrict_privileges)
240            .await
241    }
242
243    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
244        PostgresBackendWrapper::new(self).clean(db_id).await
245    }
246
247    async fn drop(
248        &self,
249        db_id: uuid::Uuid,
250        is_restricted: bool,
251    ) -> Result<(), BError<P::BuildError, P::PoolError>> {
252        PostgresBackendWrapper::new(self)
253            .drop(db_id, is_restricted)
254            .await
255    }
256}
257
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    // NOTE(review): these are integration tests — they connect to a Postgres
    // server at localhost with user/password "postgres"; confirm the CI
    // environment provides one.

    use bb8::Pool;
    use futures::future::join_all;
    use tokio_postgres::Config;
    use tokio_shared_rt::test;

    use crate::{
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::{
                common::pool::tokio_postgres::bb8::TokioPostgresBb8,
                postgres::r#trait::tests::{
                    test_backend_creates_database_with_unrestricted_privileges,
                    test_backend_drops_database, test_pool_drops_created_unrestricted_database,
                },
            },
            db_pool::DatabasePoolBuilder,
        },
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases, PgDropLock,
        },
        TokioPostgresBackend,
    };

    // Builds a bb8-backed backend against the local test server; `with_table`
    // controls whether the entity-creation callback creates the test tables.
    async fn create_backend(with_table: bool) -> TokioPostgresBackend<TokioPostgresBb8> {
        let mut config = Config::new();
        config
            .host("localhost")
            .user("postgres")
            .password("postgres");
        TokioPostgresBackend::new(config, Pool::builder, Pool::builder, {
            move |conn| {
                if with_table {
                    Box::pin(async move {
                        conn.batch_execute(&CREATE_ENTITIES_STATEMENTS.join(";"))
                            .await
                            .unwrap();
                        conn
                    })
                } else {
                    Box::pin(async { conn })
                }
            }
        })
        .await
        .unwrap()
    }

    // The drop-previous-databases flag must be honored: default, explicitly
    // enabled, and explicitly disabled backends are exercised by the helper.
    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    // Each pulled connection pool must point at its own database: a row
    // inserted through one pool must not be visible through another.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // insert single row into each database
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        conn.execute(
                            "INSERT INTO book (title) VALUES ($1)",
                            &[&format!("Title {i}").as_str()],
                        )
                        .await
                        .unwrap();
                    }),
            )
            .await;

            // rows fetched must be as inserted
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        assert_eq!(
                            conn.query("SELECT title FROM book", &[])
                                .await
                                .unwrap()
                                .iter()
                                .map(|row| row.get::<_, String>(0))
                                .collect::<Vec<_>>(),
                            vec![format!("Title {i}")]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    // Restricted databases: schema changes (DDL) are forbidden, data access
    // (DML) is allowed.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.get().await.unwrap();

            // DDL statements must fail
            for stmt in DDL_STATEMENTS {
                assert!(conn.execute(stmt, &[]).await.is_err());
            }

            // DML statements must succeed
            for stmt in DML_STATEMENTS {
                assert!(conn.execute(stmt, &[]).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Unrestricted (mutable) databases: both DML and DDL must succeed.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // DML statements must succeed
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute(stmt, &[]).await.is_ok());
                }
            }

            // DDL statements must succeed
            // Each statement gets a fresh database so earlier DDL cannot
            // interfere with later statements.
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                assert!(conn.execute(stmt, &[]).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Re-pulled connection pools must hand back cleaned databases: data
    // inserted during the first pull must be gone on the second pull.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // fetch connection pools the first time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        conn.query_one("SELECT COUNT(*) FROM book", &[])
                            .await
                            .unwrap()
                            .get::<_, i64>(0),
                        0
                    );
                }))
                .await;

                // insert data into each database
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    conn.execute("INSERT INTO book (title) VALUES ($1)", &[&"Title"])
                        .await
                        .unwrap();
                }))
                .await;
            }

            // fetch same connection pools a second time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        conn.query_one("SELECT COUNT(*) FROM book", &[])
                            .await
                            .unwrap()
                            .get::<_, i64>(0),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false).await;
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false).await;
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}