db_pool/sync/backend/postgres/diesel.rs

1use std::{borrow::Cow, collections::HashMap};
2
3use diesel::{
4    QueryResult, RunQueryDsl, connection::SimpleConnection, pg::PgConnection, prelude::*,
5    r2d2::ConnectionManager, result::Error, sql_query,
6};
7use parking_lot::Mutex;
8use r2d2::{Builder, Pool, PooledConnection};
9use uuid::Uuid;
10
11use crate::{common::config::postgres::PrivilegedPostgresConfig, util::get_db_name};
12
13use super::{
14    super::{error::Error as BackendError, r#trait::Backend},
15    r#trait::{PostgresBackend, PostgresBackendWrapper},
16};
17
18type Manager = ConnectionManager<PgConnection>;
19
20/// [`Diesel Postgres`](https://docs.rs/diesel/2.2.11/diesel/pg/struct.PgConnection.html) backend
21pub struct DieselPostgresBackend {
22    privileged_config: PrivilegedPostgresConfig,
23    default_pool: Pool<Manager>,
24    db_conns: Mutex<HashMap<Uuid, PgConnection>>,
25    create_restricted_pool: Box<dyn Fn() -> Builder<Manager> + Send + Sync + 'static>,
26    create_entities: Box<dyn Fn(&mut PgConnection) + Send + Sync + 'static>,
27    drop_previous_databases_flag: bool,
28}
29
30impl DieselPostgresBackend {
31    /// Creates a new [`Diesel Postgres`](https://docs.rs/diesel/2.2.11/diesel/pg/struct.PgConnection.html) backend
32    /// # Example
33    /// ```
34    /// use db_pool::{sync::DieselPostgresBackend, PrivilegedPostgresConfig};
35    /// use diesel::{sql_query, RunQueryDsl};
36    /// use dotenvy::dotenv;
37    /// use r2d2::Pool;
38    ///
39    /// dotenv().ok();
40    ///
41    /// let config = PrivilegedPostgresConfig::from_env().unwrap();
42    ///
43    /// let backend = DieselPostgresBackend::new(
44    ///     config,
45    ///     || Pool::builder().max_size(10),
46    ///     || Pool::builder().max_size(2),
47    ///     move |conn| {
48    ///         sql_query("CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)")
49    ///             .execute(conn)
50    ///             .unwrap();
51    ///     },
52    /// )
53    /// .unwrap();
54    /// ```
55    pub fn new(
56        privileged_config: PrivilegedPostgresConfig,
57        create_privileged_pool: impl Fn() -> Builder<Manager>,
58        create_restricted_pool: impl Fn() -> Builder<Manager> + Send + Sync + 'static,
59        create_entities: impl Fn(&mut PgConnection) + Send + Sync + 'static,
60    ) -> Result<Self, r2d2::Error> {
61        let manager = Manager::new(privileged_config.default_connection_url());
62        let default_pool = (create_privileged_pool()).build(manager)?;
63
64        Ok(Self {
65            privileged_config,
66            default_pool,
67            db_conns: Mutex::new(HashMap::new()),
68            create_entities: Box::new(create_entities),
69            create_restricted_pool: Box::new(create_restricted_pool),
70            drop_previous_databases_flag: true,
71        })
72    }
73
74    /// Drop databases created in previous runs upon initialization
75    #[must_use]
76    pub fn drop_previous_databases(self, value: bool) -> Self {
77        Self {
78            drop_previous_databases_flag: value,
79            ..self
80        }
81    }
82}
83
84impl PostgresBackend for DieselPostgresBackend {
85    type ConnectionManager = Manager;
86    type ConnectionError = ConnectionError;
87    type QueryError = Error;
88
89    fn execute_query(&self, query: &str, conn: &mut PgConnection) -> QueryResult<()> {
90        sql_query(query).execute(conn)?;
91        Ok(())
92    }
93
94    fn batch_execute_query<'a>(
95        &self,
96        query: impl IntoIterator<Item = Cow<'a, str>>,
97        conn: &mut PgConnection,
98    ) -> QueryResult<()> {
99        let query = query.into_iter().collect::<Vec<_>>();
100        if query.is_empty() {
101            Ok(())
102        } else {
103            conn.batch_execute(query.join(";").as_str())
104        }
105    }
106
107    fn get_default_connection(&self) -> Result<PooledConnection<Manager>, r2d2::Error> {
108        self.default_pool.get()
109    }
110
111    fn establish_privileged_database_connection(
112        &self,
113        db_id: Uuid,
114    ) -> ConnectionResult<PgConnection> {
115        let db_name = get_db_name(db_id);
116        let database_url = self
117            .privileged_config
118            .privileged_database_connection_url(db_name.as_str());
119        PgConnection::establish(database_url.as_str())
120    }
121
122    fn establish_restricted_database_connection(
123        &self,
124        db_id: Uuid,
125    ) -> ConnectionResult<PgConnection> {
126        let db_name = get_db_name(db_id);
127        let db_name = db_name.as_str();
128        let database_url = self.privileged_config.restricted_database_connection_url(
129            db_name,
130            Some(db_name),
131            db_name,
132        );
133        PgConnection::establish(database_url.as_str())
134    }
135
136    fn put_database_connection(&self, db_id: Uuid, conn: PgConnection) {
137        self.db_conns.lock().insert(db_id, conn);
138    }
139
140    fn get_database_connection(&self, db_id: Uuid) -> PgConnection {
141        self.db_conns
142            .lock()
143            .remove(&db_id)
144            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
145    }
146
147    fn get_previous_database_names(&self, conn: &mut PgConnection) -> QueryResult<Vec<String>> {
148        table! {
149            pg_database (oid) {
150                oid -> Int4,
151                datname -> Text
152            }
153        }
154
155        pg_database::table
156            .select(pg_database::datname)
157            .filter(pg_database::datname.like("db_pool_%"))
158            .load::<String>(conn)
159    }
160
161    fn create_entities(&self, conn: &mut PgConnection) {
162        (self.create_entities)(conn);
163    }
164
165    fn create_connection_pool(
166        &self,
167        db_id: Uuid,
168    ) -> Result<Pool<Self::ConnectionManager>, r2d2::Error> {
169        let db_name = get_db_name(db_id);
170        let db_name = db_name.as_str();
171        let database_url = self.privileged_config.restricted_database_connection_url(
172            db_name,
173            Some(db_name),
174            db_name,
175        );
176        let manager = ConnectionManager::<PgConnection>::new(database_url.as_str());
177        (self.create_restricted_pool)().build(manager)
178    }
179
180    fn get_table_names(&self, conn: &mut PgConnection) -> QueryResult<Vec<String>> {
181        table! {
182            pg_tables (tablename) {
183                #[sql_name = "schemaname"]
184                schema_name -> Text,
185                tablename -> Text
186            }
187        }
188
189        pg_tables::table
190            .filter(pg_tables::schema_name.ne_all(["pg_catalog", "information_schema"]))
191            .select(pg_tables::tablename)
192            .load(conn)
193    }
194
195    fn get_drop_previous_databases(&self) -> bool {
196        self.drop_previous_databases_flag
197    }
198}
199
200impl Backend for DieselPostgresBackend {
201    type ConnectionManager = Manager;
202    type ConnectionError = ConnectionError;
203    type QueryError = Error;
204
205    fn init(&self) -> Result<(), BackendError<ConnectionError, Error>> {
206        PostgresBackendWrapper::new(self).init()
207    }
208
209    fn create(
210        &self,
211        db_id: Uuid,
212        restrict_privileges: bool,
213    ) -> Result<Pool<Manager>, BackendError<ConnectionError, Error>> {
214        PostgresBackendWrapper::new(self).create(db_id, restrict_privileges)
215    }
216
217    fn clean(&self, db_id: Uuid) -> Result<(), BackendError<ConnectionError, Error>> {
218        PostgresBackendWrapper::new(self).clean(db_id)
219    }
220
221    fn drop(
222        &self,
223        db_id: Uuid,
224        is_restricted: bool,
225    ) -> Result<(), BackendError<ConnectionError, Error>> {
226        PostgresBackendWrapper::new(self).drop(db_id, is_restricted)
227    }
228}
229
#[cfg(test)]
mod tests {
    // `unwrap` is the idiomatic way to fail a test on an unexpected error.
    // NOTE(review): `clippy::needless_return` is carried over from the original allow list;
    // no explicit `return` is visible in this module — confirm whether it is still needed.
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use std::borrow::Cow;

    use diesel::{
        Insertable, QueryDsl, RunQueryDsl, connection::SimpleConnection, insert_into, sql_query,
        table,
    };
    use dotenvy::dotenv;
    use r2d2::Pool;

    use crate::{
        common::{
            config::PrivilegedPostgresConfig,
            statement::postgres::tests::{
                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
            },
        },
        sync::db_pool::DatabasePoolBuilder,
    };

    use super::{
        super::r#trait::tests::{
            lock_read, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_database, test_backend_drops_previous_databases,
            test_pool_drops_created_restricted_databases,
            test_pool_drops_created_unrestricted_database, test_pool_drops_previous_databases,
        },
        DieselPostgresBackend,
    };

    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }

    /// Builds a backend from environment configuration (loading `.env` if present).
    ///
    /// When `with_table` is true, new databases get `CREATE_ENTITIES_STATEMENTS` applied.
    fn create_backend(with_table: bool) -> DieselPostgresBackend {
        dotenv().ok();

        let config = PrivilegedPostgresConfig::from_env().unwrap();

        DieselPostgresBackend::new(config, Pool::builder, Pool::builder, {
            move |conn| {
                if with_table {
                    let query = CREATE_ENTITIES_STATEMENTS.join(";");
                    conn.batch_execute(query.as_str()).unwrap();
                }
            }
        })
        .unwrap()
    }

    #[test]
    fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        );
    }

    #[test]
    fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(&backend);
    }

    #[test]
    fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(&backend);
    }

    #[test]
    fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(&backend);
    }

    #[test]
    fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(&backend);
    }

    #[test]
    fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(&backend, true);
    }

    #[test]
    fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(&backend, false);
    }

    #[test]
    fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        );
    }

    #[test]
    fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        // A `_`-prefixed binding (unlike `_`) keeps the guard alive until the end of the test.
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pools = (0..NUM_DBS)
            .map(|_| db_pool.pull_immutable())
            .collect::<Vec<_>>();

        // insert single row into each database
        for (i, conn_pool) in conn_pools.iter().enumerate() {
            let conn = &mut conn_pool.get().unwrap();
            insert_into(book::table)
                .values(NewBook {
                    title: format!("Title {i}").into(),
                })
                .execute(conn)
                .unwrap();
        }

        // rows fetched must be as inserted
        for (i, conn_pool) in conn_pools.iter().enumerate() {
            let conn = &mut conn_pool.get().unwrap();
            assert_eq!(
                book::table
                    .select(book::title)
                    .load::<String>(conn)
                    .unwrap(),
                vec![format!("Title {i}")]
            );
        }
    }

    #[test]
    fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        // A `_`-prefixed binding (unlike `_`) keeps the guard alive until the end of the test.
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pool = db_pool.pull_immutable();
        let conn = &mut conn_pool.get().unwrap();

        // DDL statements must fail
        for stmt in DDL_STATEMENTS {
            assert!(sql_query(stmt).execute(conn).is_err());
        }

        // DML statements must succeed
        for stmt in DML_STATEMENTS {
            assert!(sql_query(stmt).execute(conn).is_ok());
        }
    }

    #[test]
    fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        // A `_`-prefixed binding (unlike `_`) keeps the guard alive until the end of the test.
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // DML statements must succeed
        {
            let conn_pool = db_pool.create_mutable().unwrap();
            let conn = &mut conn_pool.get().unwrap();
            for stmt in DML_STATEMENTS {
                assert!(sql_query(stmt).execute(conn).is_ok());
            }
        }

        // DDL statements must succeed
        for stmt in DDL_STATEMENTS {
            let conn_pool = db_pool.create_mutable().unwrap();
            let conn = &mut conn_pool.get().unwrap();
            assert!(sql_query(stmt).execute(conn).is_ok());
        }
    }

    #[test]
    fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        // A `_`-prefixed binding (unlike `_`) keeps the guard alive until the end of the test.
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // fetch connection pools the first time
        {
            let conn_pools = (0..NUM_DBS)
                .map(|_| db_pool.pull_immutable())
                .collect::<Vec<_>>();

            // databases must be empty
            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                assert_eq!(book::table.count().get_result::<i64>(conn).unwrap(), 0);
            }

            // insert data into each database
            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                insert_into(book::table)
                    .values(NewBook {
                        title: "Title".into(),
                    })
                    .execute(conn)
                    .unwrap();
            }
        }

        // fetch same connection pools a second time
        {
            let conn_pools = (0..NUM_DBS)
                .map(|_| db_pool.pull_immutable())
                .collect::<Vec<_>>();

            // databases must be empty
            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                assert_eq!(book::table.count().get_result::<i64>(conn).unwrap(), 0);
            }
        }
    }

    #[test]
    fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_restricted_databases(backend);
    }

    #[test]
    fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false);
        test_pool_drops_created_unrestricted_database(backend);
    }
}