db_pool/sync/backend/postgres/postgres.rs

use std::{borrow::Cow, collections::HashMap, ops::Deref};

use parking_lot::Mutex;
use r2d2::{Builder, Pool, PooledConnection};
use r2d2_postgres::{
    postgres::{Client, Config, Error, NoTls},
    PostgresConnectionManager,
};
use uuid::Uuid;

use crate::{common::statement::postgres, util::get_db_name};

use super::{
    super::{error::Error as BackendError, r#trait::Backend},
    r#trait::{PostgresBackend as PostgresBackendTrait, PostgresBackendWrapper},
};

type Manager = PostgresConnectionManager<NoTls>;

/// Postgres backend
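///
/// The sketch below mirrors the usage exercised by this module's tests: build
/// the backend, turn it into a database pool, and pull isolated, pre-cleaned
/// databases from it. It assumes `DatabasePoolBuilder` is re-exported from the
/// public `sync` module.
///
/// ```no_run
/// use db_pool::{
///     sync::{DatabasePoolBuilder, PostgresBackend},
///     PrivilegedPostgresConfig,
/// };
/// use dotenvy::dotenv;
/// use r2d2::Pool;
///
/// dotenv().ok();
/// let config = PrivilegedPostgresConfig::from_env().unwrap();
///
/// let backend = PostgresBackend::new(
///     config.into(),
///     Pool::builder,
///     Pool::builder,
///     |conn| {
///         conn.batch_execute("CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)")
///             .unwrap();
///     },
/// )
/// .unwrap();
///
/// // every pulled object is an isolated database seeded by `create_entities`
/// let db_pool = backend.create_database_pool().unwrap();
/// let conn_pool = db_pool.pull_immutable();
/// let conn = &mut conn_pool.get().unwrap();
/// ```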
pub struct PostgresBackend {
    config: Config,
    default_pool: Pool<Manager>,
    db_conns: Mutex<HashMap<Uuid, Client>>,
    create_restricted_pool: Box<dyn Fn() -> Builder<Manager> + Send + Sync + 'static>,
    create_entities: Box<dyn Fn(&mut Client) + Send + Sync + 'static>,
    drop_previous_databases_flag: bool,
}

impl PostgresBackend {
    /// Creates a new Postgres backend
    /// # Example
    /// ```
    /// use db_pool::{sync::PostgresBackend, PrivilegedPostgresConfig};
    /// use r2d2::Pool;
    /// use dotenvy::dotenv;
    ///
    /// dotenv().ok();
    ///
    /// let config = PrivilegedPostgresConfig::from_env().unwrap();
    ///
    /// let backend = PostgresBackend::new(
    ///     config.into(),
    ///     || Pool::builder().max_size(10),
    ///     || Pool::builder().max_size(2),
    ///     move |conn| {
    ///         conn.query(
    ///             "CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)",
    ///             &[],
    ///         )
    ///         .unwrap();
    ///     },
    /// )
    /// .unwrap();
    /// ```
    pub fn new(
        config: Config,
        create_privileged_pool: impl Fn() -> Builder<Manager>,
        create_restricted_pool: impl Fn() -> Builder<Manager> + Send + Sync + 'static,
        create_entities: impl Fn(&mut Client) + Send + Sync + 'static,
    ) -> Result<Self, r2d2::Error> {
        let manager = Manager::new(config.clone(), NoTls);
        let default_pool = (create_privileged_pool()).build(manager)?;

        Ok(Self {
            config,
            default_pool,
            db_conns: Mutex::new(HashMap::new()),
            create_restricted_pool: Box::new(create_restricted_pool),
            create_entities: Box::new(create_entities),
            drop_previous_databases_flag: true,
        })
    }

    /// Sets whether to drop databases created in previous runs upon initialization
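    ///
    /// # Example
    /// A minimal sketch (the empty `create_entities` closure is only for
    /// illustration) that opts out of dropping leftover databases:
    /// ```
    /// # use db_pool::{sync::PostgresBackend, PrivilegedPostgresConfig};
    /// # use dotenvy::dotenv;
    /// # use r2d2::Pool;
    /// # dotenv().ok();
    /// # let config = PrivilegedPostgresConfig::from_env().unwrap();
    /// let backend =
    ///     PostgresBackend::new(config.into(), Pool::builder, Pool::builder, |_conn| {})
    ///         .unwrap()
    ///         .drop_previous_databases(false);
    /// ```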
    #[must_use]
    pub fn drop_previous_databases(self, value: bool) -> Self {
        Self {
            drop_previous_databases_flag: value,
            ..self
        }
    }
}

impl PostgresBackendTrait for PostgresBackend {
    type ConnectionManager = Manager;
    type ConnectionError = ConnectionError;
    type QueryError = QueryError;

    fn execute_query(&self, query: &str, conn: &mut Client) -> Result<(), QueryError> {
        conn.execute(query, &[])?;
        Ok(())
    }

    fn batch_execute_query<'a>(
        &self,
        query: impl IntoIterator<Item = Cow<'a, str>>,
        conn: &mut Client,
    ) -> Result<(), QueryError> {
        let query = query.into_iter().collect::<Vec<_>>().join(";");
        conn.batch_execute(query.as_str())?;
        Ok(())
    }

    fn get_default_connection(&self) -> Result<PooledConnection<Manager>, r2d2::Error> {
        self.default_pool.get()
    }

    fn establish_privileged_database_connection(
        &self,
        db_id: Uuid,
    ) -> Result<Client, ConnectionError> {
        let mut config = self.config.clone();
        let db_name = get_db_name(db_id);
        config.dbname(db_name.as_str());
        config.connect(NoTls).map_err(Into::into)
    }

    fn establish_restricted_database_connection(
        &self,
        db_id: Uuid,
    ) -> Result<Client, ConnectionError> {
        let mut config = self.config.clone();
        let db_name = get_db_name(db_id);
        let db_name = db_name.as_str();
        config.user(db_name).password(db_name).dbname(db_name);
        config.connect(NoTls).map_err(Into::into)
    }

    fn put_database_connection(&self, db_id: Uuid, conn: Client) {
        self.db_conns.lock().insert(db_id, conn);
    }

    fn get_database_connection(&self, db_id: Uuid) -> Client {
        self.db_conns
            .lock()
            .remove(&db_id)
            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
    }

    fn get_previous_database_names(&self, conn: &mut Client) -> Result<Vec<String>, QueryError> {
        conn.query(postgres::GET_DATABASE_NAMES, &[])
            .map(|rows| rows.iter().map(|row| row.get(0)).collect())
            .map_err(Into::into)
    }

    fn create_entities(&self, conn: &mut Client) {
        (self.create_entities)(conn);
    }

    fn create_connection_pool(&self, db_id: Uuid) -> Result<Pool<Manager>, r2d2::Error> {
        let mut config = self.config.clone();
        let db_name = get_db_name(db_id);
        let db_name = db_name.as_str();
        config.dbname(db_name);
        config.user(db_name);
        config.password(db_name);
        let manager = PostgresConnectionManager::new(config, NoTls);
        (self.create_restricted_pool)().build(manager)
    }

    fn get_table_names(&self, conn: &mut Client) -> Result<Vec<String>, QueryError> {
        conn.query(postgres::GET_TABLE_NAMES, &[])
            .map(|rows| rows.iter().map(|row| row.get(0)).collect())
            .map_err(Into::into)
    }

    fn get_drop_previous_databases(&self) -> bool {
        self.drop_previous_databases_flag
    }
}

#[derive(Debug)]
pub struct ConnectionError(Error);

impl Deref for ConnectionError {
    type Target = Error;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Error> for ConnectionError {
    fn from(value: Error) -> Self {
        Self(value)
    }
}

#[derive(Debug)]
pub struct QueryError(Error);

impl Deref for QueryError {
    type Target = Error;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Error> for QueryError {
    fn from(value: Error) -> Self {
        Self(value)
    }
}

impl From<ConnectionError> for BackendError<ConnectionError, QueryError> {
    fn from(value: ConnectionError) -> Self {
        Self::Connection(value)
    }
}

impl From<QueryError> for BackendError<ConnectionError, QueryError> {
    fn from(value: QueryError) -> Self {
        Self::Query(value)
    }
}

impl Backend for PostgresBackend {
    type ConnectionManager = Manager;
    type ConnectionError = ConnectionError;
    type QueryError = QueryError;

    fn init(&self) -> Result<(), BackendError<ConnectionError, QueryError>> {
        PostgresBackendWrapper::new(self).init()
    }

    fn create(
        &self,
        db_id: Uuid,
        restrict_privileges: bool,
    ) -> Result<Pool<Manager>, BackendError<ConnectionError, QueryError>> {
        PostgresBackendWrapper::new(self).create(db_id, restrict_privileges)
    }

    fn clean(&self, db_id: Uuid) -> Result<(), BackendError<ConnectionError, QueryError>> {
        PostgresBackendWrapper::new(self).clean(db_id)
    }

    fn drop(
        &self,
        db_id: Uuid,
        is_restricted: bool,
    ) -> Result<(), BackendError<ConnectionError, QueryError>> {
        PostgresBackendWrapper::new(self).drop(db_id, is_restricted)
    }
}

#[cfg(test)]
mod tests {
    #![allow(unused_variables, clippy::unwrap_used)]

    use dotenvy::dotenv;
    use r2d2::Pool;

    use crate::{
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        sync::{
            backend::postgres::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        PrivilegedPostgresConfig,
    };

    use super::{
        super::r#trait::tests::{
            lock_read, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases,
        },
        PostgresBackend,
    };

    fn create_backend(with_table: bool) -> PostgresBackend {
        dotenv().ok();

        let config = PrivilegedPostgresConfig::from_env().unwrap();

        PostgresBackend::new(config.into(), Pool::builder, Pool::builder, {
            move |conn| {
                if with_table {
                    conn.batch_execute(&CREATE_ENTITIES_STATEMENTS.join(";"))
                        .unwrap();
                }
            }
        })
        .unwrap()
    }

    #[test]
    fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        );
    }

    #[test]
    fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(&backend);
    }

    #[test]
    fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(&backend);
    }

    #[test]
    fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(&backend);
    }

    #[test]
    fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(&backend);
    }

    #[test]
    fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(&backend, true);
    }

    #[test]
    fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(&backend, false);
    }

    #[test]
    fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        );
    }

    #[test]
    fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pools = (0..NUM_DBS)
            .map(|_| db_pool.pull_immutable())
            .collect::<Vec<_>>();

        // insert single row into each database
        conn_pools.iter().enumerate().for_each(|(i, conn_pool)| {
            let conn = &mut conn_pool.get().unwrap();
            conn.execute(
                "INSERT INTO book (title) VALUES ($1)",
                &[&format!("Title {i}").as_str()],
            )
            .unwrap();
        });

        // rows fetched must be as inserted
        conn_pools.iter().enumerate().for_each(|(i, conn_pool)| {
            let conn = &mut conn_pool.get().unwrap();
            assert_eq!(
                conn.query("SELECT title FROM book", &[])
                    .unwrap()
                    .iter()
                    .map(|row| row.get::<_, String>(0))
                    .collect::<Vec<_>>(),
                vec![format!("Title {i}")]
            );
        });
    }

    #[test]
    fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        let conn_pool = db_pool.pull_immutable();
        let conn = &mut conn_pool.get().unwrap();

        // DDL statements must fail
        for stmt in DDL_STATEMENTS {
            assert!(conn.execute(stmt, &[]).is_err());
        }

        // DML statements must succeed
        for stmt in DML_STATEMENTS {
            assert!(conn.execute(stmt, &[]).is_ok());
        }
    }

    #[test]
    fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // DML statements must succeed
        {
            let conn_pool = db_pool.create_mutable().unwrap();
            let conn = &mut conn_pool.get().unwrap();
            for stmt in DML_STATEMENTS {
                assert!(conn.execute(stmt, &[]).is_ok());
            }
        }

        // DDL statements must succeed
        for stmt in DDL_STATEMENTS {
            let conn_pool = db_pool.create_mutable().unwrap();
            let conn = &mut conn_pool.get().unwrap();
            assert!(conn.execute(stmt, &[]).is_ok());
        }
    }

    #[test]
    fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // fetch connection pools the first time
        {
            let conn_pools = (0..NUM_DBS)
                .map(|_| db_pool.pull_immutable())
                .collect::<Vec<_>>();

            // databases must be empty
            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                assert_eq!(
                    conn.query_one("SELECT COUNT(*) FROM book", &[])
                        .unwrap()
                        .get::<_, i64>(0),
                    0
                );
            }

            // insert data into each database
            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                conn.execute("INSERT INTO book (title) VALUES ($1)", &[&"Title"])
                    .unwrap();
            }
        }

        // fetch same connection pools a second time
        {
            let conn_pools = (0..NUM_DBS)
                .map(|_| db_pool.pull_immutable())
                .collect::<Vec<_>>();

            // databases must be empty
            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                assert_eq!(
                    conn.query_one("SELECT COUNT(*) FROM book", &[])
                        .unwrap()
                        .get::<_, i64>(0),
                    0
                );
            }
        }
    }

    #[test]
    fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_restricted_databases(backend);
    }

    #[test]
    fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false);
        test_pool_drops_created_unrestricted_database(backend);
    }
}