// db_pool/async/backend/postgres/sqlx.rs

1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use sqlx::{
7    pool::PoolConnection,
8    postgres::{PgConnectOptions, PgPoolOptions},
9    Connection, Executor, PgConnection, PgPool, Postgres, Row,
10};
11use uuid::Uuid;
12
13use crate::{common::statement::postgres, util::get_db_name};
14
15use super::{
16    super::{
17        common::error::sqlx::{BuildError, ConnectionError, PoolError, QueryError},
18        error::Error as BackendError,
19        r#trait::Backend,
20    },
21    r#trait::{PostgresBackend, PostgresBackendWrapper},
22};
23
// Boxed entity-creation hook: takes ownership of a connection to a freshly
// created database, performs async setup (e.g. CREATE TABLE statements), and
// hands the connection back when done.
type CreateEntities = dyn Fn(PgConnection) -> Pin<Box<dyn Future<Output = PgConnection> + Send + 'static>>
    + Send
    + Sync
    + 'static;
28
/// [`sqlx Postgres`](https://docs.rs/sqlx/0.8.2/sqlx/struct.Postgres.html) backend
pub struct SqlxPostgresBackend {
    // Privileged connect options; cloned and re-targeted (database/credentials)
    // for per-database connections and pools.
    privileged_opts: PgConnectOptions,
    // Lazily-connecting pool using the privileged options as-is.
    default_pool: PgPool,
    // Privileged connections parked per database id between backend calls.
    db_conns: Mutex<HashMap<Uuid, PgConnection>>,
    // Factory for the pool options of each restricted per-database pool.
    create_restricted_pool: Box<dyn Fn() -> PgPoolOptions + Send + Sync + 'static>,
    // User hook that seeds entities in each newly created database.
    create_entities: Box<CreateEntities>,
    // When true, databases left over from previous runs are dropped on init.
    drop_previous_databases_flag: bool,
}
38
39impl SqlxPostgresBackend {
40    /// Creates a new [`sqlx Postgres`](https://docs.rs/sqlx/0.8.2/sqlx/struct.Postgres.html) backend
41    /// # Example
42    /// ```
43    /// use db_pool::{r#async::SqlxPostgresBackend, PrivilegedPostgresConfig};
44    /// use dotenvy::dotenv;
45    /// use sqlx::{postgres::PgPoolOptions, Executor};
46    ///
47    /// async fn f() {
48    ///     dotenv().ok();
49    ///
50    ///     let config = PrivilegedPostgresConfig::from_env().unwrap();
51    ///
52    ///     let backend = SqlxPostgresBackend::new(
53    ///         config.into(),
54    ///         || PgPoolOptions::new().max_connections(10),
55    ///         || PgPoolOptions::new().max_connections(2),
56    ///         move |mut conn| {
57    ///             Box::pin(async move {
58    ///                 conn.execute("CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)")
59    ///                     .await
60    ///                     .unwrap();
61    ///                 conn
62    ///             })
63    ///         },
64    ///     );
65    /// }
66    ///
67    /// tokio_test::block_on(f());
68    /// ```
69    pub fn new(
70        privileged_options: PgConnectOptions,
71        create_privileged_pool: impl Fn() -> PgPoolOptions,
72        create_restricted_pool: impl Fn() -> PgPoolOptions + Send + Sync + 'static,
73        create_entities: impl Fn(PgConnection) -> Pin<Box<dyn Future<Output = PgConnection> + Send + 'static>>
74            + Send
75            + Sync
76            + 'static,
77    ) -> Self {
78        let pool_opts = create_privileged_pool();
79        let default_pool = pool_opts.connect_lazy_with(privileged_options.clone());
80
81        Self {
82            privileged_opts: privileged_options,
83            default_pool,
84            db_conns: Mutex::new(HashMap::new()),
85            create_restricted_pool: Box::new(create_restricted_pool),
86            create_entities: Box::new(create_entities),
87            drop_previous_databases_flag: true,
88        }
89    }
90
91    /// Drop databases created in previous runs upon initialization
92    #[must_use]
93    pub fn drop_previous_databases(self, value: bool) -> Self {
94        Self {
95            drop_previous_databases_flag: value,
96            ..self
97        }
98    }
99}
100
101#[async_trait]
102impl<'pool> PostgresBackend<'pool> for SqlxPostgresBackend {
103    type Connection = PgConnection;
104    type PooledConnection = PoolConnection<Postgres>;
105    type Pool = PgPool;
106
107    type BuildError = BuildError;
108    type PoolError = PoolError;
109    type ConnectionError = ConnectionError;
110    type QueryError = QueryError;
111
112    async fn execute_query(&self, query: &str, conn: &mut PgConnection) -> Result<(), QueryError> {
113        conn.execute(query).await?;
114        Ok(())
115    }
116
117    async fn batch_execute_query<'a>(
118        &self,
119        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
120        conn: &mut PgConnection,
121    ) -> Result<(), QueryError> {
122        let query = query.into_iter().collect::<Vec<_>>().join(";");
123        self.execute_query(query.as_str(), conn).await
124    }
125
126    async fn get_default_connection(&'pool self) -> Result<PoolConnection<Postgres>, PoolError> {
127        self.default_pool.acquire().await.map_err(Into::into)
128    }
129
130    async fn establish_privileged_database_connection(
131        &self,
132        db_id: Uuid,
133    ) -> Result<PgConnection, ConnectionError> {
134        let db_name = get_db_name(db_id);
135        let opts = self.privileged_opts.clone().database(db_name.as_str());
136        PgConnection::connect_with(&opts).await.map_err(Into::into)
137    }
138
139    async fn establish_restricted_database_connection(
140        &self,
141        db_id: Uuid,
142    ) -> Result<PgConnection, ConnectionError> {
143        let db_name = get_db_name(db_id);
144        let db_name = db_name.as_str();
145        let opts = self
146            .privileged_opts
147            .clone()
148            .username(db_name)
149            .password(db_name)
150            .database(db_name);
151        PgConnection::connect_with(&opts).await.map_err(Into::into)
152    }
153
154    fn put_database_connection(&self, db_id: Uuid, conn: PgConnection) {
155        self.db_conns.lock().insert(db_id, conn);
156    }
157
158    fn get_database_connection(&self, db_id: Uuid) -> PgConnection {
159        self.db_conns
160            .lock()
161            .remove(&db_id)
162            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
163    }
164
165    async fn get_previous_database_names(
166        &self,
167        conn: &mut PgConnection,
168    ) -> Result<Vec<String>, QueryError> {
169        conn.fetch_all(postgres::GET_DATABASE_NAMES)
170            .await?
171            .iter()
172            .map(|row| row.try_get(0))
173            .collect::<Result<Vec<_>, _>>()
174            .map_err(Into::into)
175    }
176
177    async fn create_entities(&self, conn: PgConnection) -> PgConnection {
178        (self.create_entities)(conn).await
179    }
180
181    async fn create_connection_pool(&self, db_id: Uuid) -> Result<PgPool, BuildError> {
182        let db_name = get_db_name(db_id);
183        let db_name = db_name.as_str();
184        let opts = self
185            .privileged_opts
186            .clone()
187            .database(db_name)
188            .username(db_name)
189            .password(db_name);
190        let pool = (self.create_restricted_pool)().connect_lazy_with(opts);
191        Ok(pool)
192    }
193
194    async fn get_table_names(&self, conn: &mut PgConnection) -> Result<Vec<String>, QueryError> {
195        conn.fetch_all(postgres::GET_TABLE_NAMES)
196            .await?
197            .iter()
198            .map(|row| row.try_get(0))
199            .collect::<Result<Vec<_>, _>>()
200            .map_err(Into::into)
201    }
202
203    fn get_drop_previous_databases(&self) -> bool {
204        self.drop_previous_databases_flag
205    }
206}
207
// Shorthand for this backend's fully-parameterized composite error type.
type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
209
210#[async_trait]
211impl Backend for SqlxPostgresBackend {
212    type Pool = PgPool;
213
214    type BuildError = BuildError;
215    type PoolError = PoolError;
216    type ConnectionError = ConnectionError;
217    type QueryError = QueryError;
218
219    async fn init(&self) -> Result<(), BError> {
220        PostgresBackendWrapper::new(self).init().await
221    }
222
223    async fn create(&self, db_id: uuid::Uuid, restrict_privileges: bool) -> Result<PgPool, BError> {
224        PostgresBackendWrapper::new(self)
225            .create(db_id, restrict_privileges)
226            .await
227    }
228
229    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
230        PostgresBackendWrapper::new(self).clean(db_id).await
231    }
232
233    async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> {
234        PostgresBackendWrapper::new(self)
235            .drop(db_id, is_restricted)
236            .await
237    }
238}
239
// Integration tests; they require a reachable local Postgres instance with
// `postgres`/`postgres` credentials and run on a shared tokio runtime.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use futures::{future::join_all, StreamExt};
    use sqlx::{
        postgres::{PgConnectOptions, PgPoolOptions},
        query, query_as, Executor, FromRow, Row,
    };
    use tokio_shared_rt::test;

    use crate::{
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::postgres::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_backend_drops_database, test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases, PgDropLock,
        },
        SqlxPostgresBackend,
    };

    // Builds a backend against the local Postgres instance; `with_table`
    // controls whether the entity hook runs CREATE_ENTITIES_STATEMENTS.
    fn create_backend(with_table: bool) -> SqlxPostgresBackend {
        SqlxPostgresBackend::new(
            PgConnectOptions::new()
                .username("postgres")
                .password("postgres"),
            PgPoolOptions::new,
            PgPoolOptions::new,
            {
                move |mut conn| {
                    if with_table {
                        Box::pin(async move {
                            conn.execute_many(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
                                .collect::<Vec<_>>()
                                .await
                                .drain(..)
                                .collect::<Result<Vec<_>, _>>()
                                .unwrap();
                            conn
                        })
                    } else {
                        Box::pin(async { conn })
                    }
                }
            },
        )
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    // Each pulled database must see only its own rows.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromRow, Eq, PartialEq, Debug)]
        struct Book {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // insert single row into each database
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        query("INSERT INTO book (title) VALUES ($1)")
                            .bind(format!("Title {i}"))
                            .execute(&***conn_pool)
                            .await
                            .unwrap();
                    }),
            )
            .await;

            // rows fetched must be as inserted
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        assert_eq!(
                            query_as::<_, Book>("SELECT title FROM book")
                                .fetch_all(&***conn_pool)
                                .await
                                .unwrap(),
                            vec![Book {
                                title: format!("Title {i}")
                            }]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    // Restricted databases must reject DDL while allowing DML.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.acquire().await.unwrap();

            // DDL statements must fail
            for stmt in DDL_STATEMENTS {
                assert!(conn.execute(stmt).await.is_err());
            }

            // DML statements must succeed
            for stmt in DML_STATEMENTS {
                assert!(conn.execute(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Unrestricted (mutable) databases must allow both DML and DDL.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // DML statements must succeed
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute(stmt).await.is_ok());
                }
            }

            // DDL statements must succeed
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                assert!(conn.execute(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Re-pulled databases must come back cleaned (no rows from prior use).
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // fetch connection pools the first time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;

                // insert data into each database
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    query("INSERT INTO book (title) VALUES ($1)")
                        .bind("Title")
                        .execute(&***conn_pool)
                        .await
                        .unwrap();
                }))
                .await;
            }

            // fetch same connection pools a second time
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases must be empty
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false);
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}
531}