db_pool/async/backend/postgres/
sqlx.rs

1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use sqlx::{
7    Connection, Executor, PgConnection, PgPool, Postgres, Row,
8    pool::PoolConnection,
9    postgres::{PgConnectOptions, PgPoolOptions},
10};
11use uuid::Uuid;
12
13use crate::{common::statement::postgres, util::get_db_name};
14
15use super::{
16    super::{
17        common::error::sqlx::{BuildError, ConnectionError, PoolError, QueryError},
18        error::Error as BackendError,
19        r#trait::Backend,
20    },
21    r#trait::{PostgresBackend, PostgresBackendWrapper},
22};
23
24type CreateEntities = dyn Fn(PgConnection) -> Pin<Box<dyn Future<Output = PgConnection> + Send + 'static>>
25    + Send
26    + Sync
27    + 'static;
28
29/// [`sqlx Postgres`](https://docs.rs/sqlx/0.8.6/sqlx/struct.Postgres.html) backend
30pub struct SqlxPostgresBackend {
31    privileged_opts: PgConnectOptions,
32    default_pool: PgPool,
33    db_conns: Mutex<HashMap<Uuid, PgConnection>>,
34    create_restricted_pool: Box<dyn Fn() -> PgPoolOptions + Send + Sync + 'static>,
35    create_entities: Box<CreateEntities>,
36    drop_previous_databases_flag: bool,
37}
38
39impl SqlxPostgresBackend {
40    /// Creates a new [`sqlx Postgres`](https://docs.rs/sqlx/0.8.6/sqlx/struct.Postgres.html) backend
41    /// # Example
42    /// ```
43    /// use db_pool::{r#async::SqlxPostgresBackend, PrivilegedPostgresConfig};
44    /// use dotenvy::dotenv;
45    /// use sqlx::{postgres::PgPoolOptions, Executor};
46    ///
47    /// async fn f() {
48    ///     dotenv().ok();
49    ///
50    ///     let config = PrivilegedPostgresConfig::from_env().unwrap();
51    ///
52    ///     let backend = SqlxPostgresBackend::new(
53    ///         config.into(),
54    ///         || PgPoolOptions::new().max_connections(10),
55    ///         || PgPoolOptions::new().max_connections(2),
56    ///         move |mut conn| {
57    ///             Box::pin(async move {
58    ///                 conn.execute("CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)")
59    ///                     .await
60    ///                     .unwrap();
61    ///                 conn
62    ///             })
63    ///         },
64    ///     );
65    /// }
66    ///
67    /// tokio_test::block_on(f());
68    /// ```
69    pub fn new(
70        privileged_options: PgConnectOptions,
71        create_privileged_pool: impl Fn() -> PgPoolOptions,
72        create_restricted_pool: impl Fn() -> PgPoolOptions + Send + Sync + 'static,
73        create_entities: impl Fn(
74            PgConnection,
75        )
76            -> Pin<Box<dyn Future<Output = PgConnection> + Send + 'static>>
77        + Send
78        + Sync
79        + 'static,
80    ) -> Self {
81        let pool_opts = create_privileged_pool();
82        let default_pool = pool_opts.connect_lazy_with(privileged_options.clone());
83
84        Self {
85            privileged_opts: privileged_options,
86            default_pool,
87            db_conns: Mutex::new(HashMap::new()),
88            create_restricted_pool: Box::new(create_restricted_pool),
89            create_entities: Box::new(create_entities),
90            drop_previous_databases_flag: true,
91        }
92    }
93
94    /// Drop databases created in previous runs upon initialization
95    #[must_use]
96    pub fn drop_previous_databases(self, value: bool) -> Self {
97        Self {
98            drop_previous_databases_flag: value,
99            ..self
100        }
101    }
102}
103
104#[async_trait]
105impl<'pool> PostgresBackend<'pool> for SqlxPostgresBackend {
106    type Connection = PgConnection;
107    type PooledConnection = PoolConnection<Postgres>;
108    type Pool = PgPool;
109
110    type BuildError = BuildError;
111    type PoolError = PoolError;
112    type ConnectionError = ConnectionError;
113    type QueryError = QueryError;
114
115    async fn execute_query(&self, query: &str, conn: &mut PgConnection) -> Result<(), QueryError> {
116        conn.execute(query).await?;
117        Ok(())
118    }
119
120    async fn batch_execute_query<'a>(
121        &self,
122        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
123        conn: &mut PgConnection,
124    ) -> Result<(), QueryError> {
125        let query = query.into_iter().collect::<Vec<_>>().join(";");
126        self.execute_query(query.as_str(), conn).await
127    }
128
129    async fn get_default_connection(&'pool self) -> Result<PoolConnection<Postgres>, PoolError> {
130        self.default_pool.acquire().await.map_err(Into::into)
131    }
132
133    async fn establish_privileged_database_connection(
134        &self,
135        db_id: Uuid,
136    ) -> Result<PgConnection, ConnectionError> {
137        let db_name = get_db_name(db_id);
138        let opts = self.privileged_opts.clone().database(db_name.as_str());
139        PgConnection::connect_with(&opts).await.map_err(Into::into)
140    }
141
142    async fn establish_restricted_database_connection(
143        &self,
144        db_id: Uuid,
145    ) -> Result<PgConnection, ConnectionError> {
146        let db_name = get_db_name(db_id);
147        let db_name = db_name.as_str();
148        let opts = self
149            .privileged_opts
150            .clone()
151            .username(db_name)
152            .password(db_name)
153            .database(db_name);
154        PgConnection::connect_with(&opts).await.map_err(Into::into)
155    }
156
157    fn put_database_connection(&self, db_id: Uuid, conn: PgConnection) {
158        self.db_conns.lock().insert(db_id, conn);
159    }
160
161    fn get_database_connection(&self, db_id: Uuid) -> PgConnection {
162        self.db_conns
163            .lock()
164            .remove(&db_id)
165            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
166    }
167
168    async fn get_previous_database_names(
169        &self,
170        conn: &mut PgConnection,
171    ) -> Result<Vec<String>, QueryError> {
172        conn.fetch_all(postgres::GET_DATABASE_NAMES)
173            .await?
174            .iter()
175            .map(|row| row.try_get(0))
176            .collect::<Result<Vec<_>, _>>()
177            .map_err(Into::into)
178    }
179
180    async fn create_entities(&self, conn: PgConnection) -> Option<PgConnection> {
181        Some((self.create_entities)(conn).await)
182    }
183
184    async fn create_connection_pool(&self, db_id: Uuid) -> Result<PgPool, BuildError> {
185        let db_name = get_db_name(db_id);
186        let db_name = db_name.as_str();
187        let opts = self
188            .privileged_opts
189            .clone()
190            .database(db_name)
191            .username(db_name)
192            .password(db_name);
193        let pool = (self.create_restricted_pool)().connect_lazy_with(opts);
194        Ok(pool)
195    }
196
197    async fn get_table_names(&self, conn: &mut PgConnection) -> Result<Vec<String>, QueryError> {
198        conn.fetch_all(postgres::GET_TABLE_NAMES)
199            .await?
200            .iter()
201            .map(|row| row.try_get(0))
202            .collect::<Result<Vec<_>, _>>()
203            .map_err(Into::into)
204    }
205
206    fn get_drop_previous_databases(&self) -> bool {
207        self.drop_previous_databases_flag
208    }
209}
210
211type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
212
213#[async_trait]
214impl Backend for SqlxPostgresBackend {
215    type Pool = PgPool;
216
217    type BuildError = BuildError;
218    type PoolError = PoolError;
219    type ConnectionError = ConnectionError;
220    type QueryError = QueryError;
221
222    async fn init(&self) -> Result<(), BError> {
223        PostgresBackendWrapper::new(self).init().await
224    }
225
226    async fn create(&self, db_id: uuid::Uuid, restrict_privileges: bool) -> Result<PgPool, BError> {
227        PostgresBackendWrapper::new(self)
228            .create(db_id, restrict_privileges)
229            .await
230    }
231
232    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
233        PostgresBackendWrapper::new(self).clean(db_id).await
234    }
235
236    async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> {
237        PostgresBackendWrapper::new(self)
238            .drop(db_id, is_restricted)
239            .await
240    }
241}
242
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use futures::{StreamExt, future::join_all};
    use sqlx::{
        Executor, FromRow, Row,
        postgres::{PgConnectOptions, PgPoolOptions},
        query, query_as,
    };
    use tokio_shared_rt::test;

    use crate::{
        r#async::{
            backend::postgres::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_backend_drops_database, test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
    };

    use super::{
        super::r#trait::tests::{
            PgDropLock, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases,
        },
        SqlxPostgresBackend,
    };

    /// Builds a backend that logs in as `postgres`/`postgres` with default pool options.
    /// When `with_table` is set, the entity callback runs `CREATE_ENTITIES_STATEMENTS`;
    /// otherwise it returns the connection untouched.
    fn create_backend(with_table: bool) -> SqlxPostgresBackend {
        let privileged_options = PgConnectOptions::new()
            .username("postgres")
            .password("postgres");

        SqlxPostgresBackend::new(
            privileged_options,
            PgPoolOptions::new,
            PgPoolOptions::new,
            move |mut conn| {
                if with_table {
                    Box::pin(async move {
                        let outcomes = conn
                            .execute_many(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
                            .collect::<Vec<_>>()
                            .await;
                        // every statement must have succeeded
                        outcomes
                            .into_iter()
                            .collect::<Result<Vec<_>, _>>()
                            .unwrap();
                        conn
                    })
                } else {
                    Box::pin(async { conn })
                }
            },
        )
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default = create_backend(false);
        let enabled = create_backend(false).drop_previous_databases(true);
        let disabled = create_backend(false).drop_previous_databases(false);
        test_backend_drops_previous_databases(default, enabled, disabled).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend(true).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend(true).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(
            create_backend(true).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        test_backend_drops_database(create_backend(true).drop_previous_databases(false), true)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        test_backend_drops_database(create_backend(true).drop_previous_databases(false), false)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default = create_backend(false);
        let enabled = create_backend(false).drop_previous_databases(true);
        let disabled = create_backend(false).drop_previous_databases(false);
        test_pool_drops_previous_databases(default, enabled, disabled).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromRow, Eq, PartialEq, Debug)]
        struct Book {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();
            let databases = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

            // each database receives exactly one row, titled after its index
            let inserts = databases.iter().enumerate().map(|(i, database)| async move {
                query("INSERT INTO book (title) VALUES ($1)")
                    .bind(format!("Title {i}"))
                    .execute(&***database)
                    .await
                    .unwrap();
            });
            join_all(inserts).await;

            // each database must contain only the row inserted into it
            let checks = databases.iter().enumerate().map(|(i, database)| async move {
                let fetched = query_as::<_, Book>("SELECT title FROM book")
                    .fetch_all(&***database)
                    .await
                    .unwrap();
                let expected = vec![Book {
                    title: format!("Title {i}"),
                }];
                assert_eq!(fetched, expected);
            });
            join_all(checks).await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();

            let database = pool.pull_immutable().await;
            let conn = &mut database.acquire().await.unwrap();

            // schema changes are expected to be rejected
            for stmt in DDL_STATEMENTS {
                let outcome = conn.execute(stmt).await;
                assert!(outcome.is_err());
            }

            // data changes are expected to go through
            for stmt in DML_STATEMENTS {
                let outcome = conn.execute(stmt).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();

            // data changes are expected to go through
            {
                let database = pool.create_mutable().await.unwrap();
                let conn = &mut database.acquire().await.unwrap();
                for stmt in DML_STATEMENTS {
                    let outcome = conn.execute(stmt).await;
                    assert!(outcome.is_ok());
                }
            }

            // schema changes are expected to go through, each on a freshly created database
            for stmt in DDL_STATEMENTS {
                let database = pool.create_mutable().await.unwrap();
                let conn = &mut database.acquire().await.unwrap();
                let outcome = conn.execute(stmt).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let pool = backend.create_database_pool().await.unwrap();

            // first round of pulled databases
            {
                let databases = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

                // no rows may be present yet
                let checks = databases.iter().map(|database| async move {
                    let row = query("SELECT COUNT(*) FROM book")
                        .fetch_one(&***database)
                        .await
                        .unwrap();
                    assert_eq!(row.get::<i64, _>(0), 0);
                });
                join_all(checks).await;

                // put one row into every database
                let inserts = databases.iter().map(|database| async move {
                    query("INSERT INTO book (title) VALUES ($1)")
                        .bind("Title")
                        .execute(&***database)
                        .await
                        .unwrap();
                });
                join_all(inserts).await;
            }

            // second round of pulled databases
            {
                let databases = join_all((0..NUM_DBS).map(|_| pool.pull_immutable())).await;

                // the rows inserted in the first round must be gone
                let checks = databases.iter().map(|database| async move {
                    let row = query("SELECT COUNT(*) FROM book")
                        .fetch_one(&***database)
                        .await
                        .unwrap();
                    assert_eq!(row.get::<i64, _>(0), 0);
                });
                join_all(checks).await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false)).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false)).await;
    }
}
534        test_pool_drops_created_unrestricted_database(backend).await;
535    }
536}