db_pool/sync/backend/postgres/
postgres.rs1use std::{borrow::Cow, collections::HashMap, ops::Deref};
2
3use parking_lot::Mutex;
4use r2d2::{Builder, Pool, PooledConnection};
5use r2d2_postgres::{
6 postgres::{Client, Config, Error, NoTls},
7 PostgresConnectionManager,
8};
9use uuid::Uuid;
10
11use crate::{common::statement::postgres, util::get_db_name};
12
13use super::{
14 super::{error::Error as BackendError, r#trait::Backend},
15 r#trait::{PostgresBackend as PostgresBackendTrait, PostgresBackendWrapper},
16};
17
18type Manager = PostgresConnectionManager<NoTls>;
19
20pub struct PostgresBackend {
22 config: Config,
23 default_pool: Pool<Manager>,
24 db_conns: Mutex<HashMap<Uuid, Client>>,
25 create_restricted_pool: Box<dyn Fn() -> Builder<Manager> + Send + Sync + 'static>,
26 create_entities: Box<dyn Fn(&mut Client) + Send + Sync + 'static>,
27 drop_previous_databases_flag: bool,
28}
29
30impl PostgresBackend {
31 pub fn new(
57 config: Config,
58 create_privileged_pool: impl Fn() -> Builder<Manager>,
59 create_restricted_pool: impl Fn() -> Builder<Manager> + Send + Sync + 'static,
60 create_entities: impl Fn(&mut Client) + Send + Sync + 'static,
61 ) -> Result<Self, r2d2::Error> {
62 let manager = Manager::new(config.clone(), NoTls);
63 let default_pool = (create_privileged_pool()).build(manager)?;
64
65 Ok(Self {
66 config,
67 default_pool,
68 db_conns: Mutex::new(HashMap::new()),
69 create_restricted_pool: Box::new(create_restricted_pool),
70 create_entities: Box::new(create_entities),
71 drop_previous_databases_flag: true,
72 })
73 }
74
75 #[must_use]
77 pub fn drop_previous_databases(self, value: bool) -> Self {
78 Self {
79 drop_previous_databases_flag: value,
80 ..self
81 }
82 }
83}
84
85impl PostgresBackendTrait for PostgresBackend {
86 type ConnectionManager = Manager;
87 type ConnectionError = ConnectionError;
88 type QueryError = QueryError;
89
90 fn execute_query(&self, query: &str, conn: &mut Client) -> Result<(), QueryError> {
91 conn.execute(query, &[])?;
92 Ok(())
93 }
94
95 fn batch_execute_query<'a>(
96 &self,
97 query: impl IntoIterator<Item = Cow<'a, str>>,
98 conn: &mut Client,
99 ) -> Result<(), QueryError> {
100 let query = query.into_iter().collect::<Vec<_>>().join(";");
101 conn.batch_execute(query.as_str())?;
102 Ok(())
103 }
104
105 fn get_default_connection(&self) -> Result<PooledConnection<Manager>, r2d2::Error> {
106 self.default_pool.get()
107 }
108
109 fn establish_privileged_database_connection(
110 &self,
111 db_id: Uuid,
112 ) -> Result<Client, ConnectionError> {
113 let mut config = self.config.clone();
114 let db_name = get_db_name(db_id);
115 config.dbname(db_name.as_str());
116 config.connect(NoTls).map_err(Into::into)
117 }
118
119 fn establish_restricted_database_connection(
120 &self,
121 db_id: Uuid,
122 ) -> Result<Client, ConnectionError> {
123 let mut config = self.config.clone();
124 let db_name = get_db_name(db_id);
125 let db_name = db_name.as_str();
126 config.user(db_name).password(db_name).dbname(db_name);
127 config.connect(NoTls).map_err(Into::into)
128 }
129
130 fn put_database_connection(&self, db_id: Uuid, conn: Client) {
131 self.db_conns.lock().insert(db_id, conn);
132 }
133
134 fn get_database_connection(&self, db_id: Uuid) -> Client {
135 self.db_conns
136 .lock()
137 .remove(&db_id)
138 .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
139 }
140
141 fn get_previous_database_names(&self, conn: &mut Client) -> Result<Vec<String>, QueryError> {
142 conn.query(postgres::GET_DATABASE_NAMES, &[])
143 .map(|rows| rows.iter().map(|row| row.get(0)).collect())
144 .map_err(Into::into)
145 }
146
147 fn create_entities(&self, conn: &mut Client) {
148 (self.create_entities)(conn);
149 }
150
151 fn create_connection_pool(&self, db_id: Uuid) -> Result<Pool<Manager>, r2d2::Error> {
152 let mut config = self.config.clone();
153 let db_name = get_db_name(db_id);
154 let db_name = db_name.as_str();
155 config.dbname(db_name);
156 config.user(db_name);
157 config.password(db_name);
158 let manager = PostgresConnectionManager::new(config, NoTls);
159 (self.create_restricted_pool)().build(manager)
160 }
161
162 fn get_table_names(&self, conn: &mut Client) -> Result<Vec<String>, QueryError> {
163 conn.query(postgres::GET_TABLE_NAMES, &[])
164 .map(|rows| rows.iter().map(|row| row.get(0)).collect())
165 .map_err(Into::into)
166 }
167
168 fn get_drop_previous_databases(&self) -> bool {
169 self.drop_previous_databases_flag
170 }
171}
172
173#[derive(Debug)]
174pub struct ConnectionError(Error);
175
176impl Deref for ConnectionError {
177 type Target = Error;
178
179 fn deref(&self) -> &Self::Target {
180 &self.0
181 }
182}
183
184impl From<Error> for ConnectionError {
185 fn from(value: Error) -> Self {
186 Self(value)
187 }
188}
189
190#[derive(Debug)]
191pub struct QueryError(Error);
192
193impl Deref for QueryError {
194 type Target = Error;
195
196 fn deref(&self) -> &Self::Target {
197 &self.0
198 }
199}
200
201impl From<Error> for QueryError {
202 fn from(value: Error) -> Self {
203 Self(value)
204 }
205}
206
207impl From<ConnectionError> for BackendError<ConnectionError, QueryError> {
208 fn from(value: ConnectionError) -> Self {
209 Self::Connection(value)
210 }
211}
212
213impl From<QueryError> for BackendError<ConnectionError, QueryError> {
214 fn from(value: QueryError) -> Self {
215 Self::Query(value)
216 }
217}
218
219impl Backend for PostgresBackend {
220 type ConnectionManager = Manager;
221 type ConnectionError = ConnectionError;
222 type QueryError = QueryError;
223
224 fn init(&self) -> Result<(), BackendError<ConnectionError, QueryError>> {
225 PostgresBackendWrapper::new(self).init()
226 }
227
228 fn create(
229 &self,
230 db_id: Uuid,
231 restrict_privileges: bool,
232 ) -> Result<Pool<Manager>, BackendError<ConnectionError, QueryError>> {
233 PostgresBackendWrapper::new(self).create(db_id, restrict_privileges)
234 }
235
236 fn clean(&self, db_id: Uuid) -> Result<(), BackendError<ConnectionError, QueryError>> {
237 PostgresBackendWrapper::new(self).clean(db_id)
238 }
239
240 fn drop(
241 &self,
242 db_id: Uuid,
243 is_restricted: bool,
244 ) -> Result<(), BackendError<ConnectionError, QueryError>> {
245 PostgresBackendWrapper::new(self).drop(db_id, is_restricted)
246 }
247}
248
#[cfg(test)]
mod tests {
    // These are integration tests against a live database; panicking on failure via
    // `unwrap` is the intended way to fail a test.
    #![allow(clippy::unwrap_used)]

    use dotenvy::dotenv;
    use r2d2::Pool;

    use crate::{
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        sync::{
            backend::postgres::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        PrivilegedPostgresConfig,
    };

    use super::{
        super::r#trait::tests::{
            lock_read, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases,
        },
        PostgresBackend,
    };

    /// Builds a backend from the environment's privileged config; when `with_table` is
    /// set, the entity callback runs `CREATE_ENTITIES_STATEMENTS`.
    fn create_backend(with_table: bool) -> PostgresBackend {
        dotenv().ok();

        let config = PrivilegedPostgresConfig::from_env().unwrap();

        PostgresBackend::new(config.into(), Pool::builder, Pool::builder, move |conn| {
            if with_table {
                conn.batch_execute(&CREATE_ENTITIES_STATEMENTS.join(";"))
                    .unwrap();
            }
        })
        .unwrap()
    }

    #[test]
    fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        );
    }

    #[test]
    fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(&backend);
    }

    #[test]
    fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(&backend);
    }

    #[test]
    fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(&backend);
    }

    #[test]
    fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(&backend);
    }

    #[test]
    fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(&backend, true);
    }

    #[test]
    fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(&backend, false);
    }

    #[test]
    fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        );
    }

    #[test]
    fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        // Held until the end of the test; binding to `_guard` (not `_`) keeps it alive.
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pools = (0..NUM_DBS)
            .map(|_| db_pool.pull_immutable())
            .collect::<Vec<_>>();

        // Write a distinct row into each database.
        conn_pools.iter().enumerate().for_each(|(i, conn_pool)| {
            let conn = &mut conn_pool.get().unwrap();
            conn.execute(
                "INSERT INTO book (title) VALUES ($1)",
                &[&format!("Title {i}").as_str()],
            )
            .unwrap();
        });

        // Each database must contain only its own row.
        conn_pools.iter().enumerate().for_each(|(i, conn_pool)| {
            let conn = &mut conn_pool.get().unwrap();
            assert_eq!(
                conn.query("SELECT title FROM book", &[])
                    .unwrap()
                    .iter()
                    .map(|row| row.get::<_, String>(0))
                    .collect::<Vec<_>>(),
                vec![format!("Title {i}")]
            );
        });
    }

    #[test]
    fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        let conn_pool = db_pool.pull_immutable();
        let conn = &mut conn_pool.get().unwrap();

        // DDL must be rejected on a restricted connection.
        for stmt in DDL_STATEMENTS {
            assert!(conn.execute(stmt, &[]).is_err());
        }

        // DML must still be permitted.
        for stmt in DML_STATEMENTS {
            assert!(conn.execute(stmt, &[]).is_ok());
        }
    }

    #[test]
    fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // DML statements all succeed on one mutable database.
        {
            let conn_pool = db_pool.create_mutable().unwrap();
            let conn = &mut conn_pool.get().unwrap();
            for stmt in DML_STATEMENTS {
                assert!(conn.execute(stmt, &[]).is_ok());
            }
        }

        // Each DDL statement gets a fresh mutable database.
        for stmt in DDL_STATEMENTS {
            let conn_pool = db_pool.create_mutable().unwrap();
            let conn = &mut conn_pool.get().unwrap();
            assert!(conn.execute(stmt, &[]).is_ok());
        }
    }

    #[test]
    fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // First round: databases start empty, then each gets a row.
        {
            let conn_pools = (0..NUM_DBS)
                .map(|_| db_pool.pull_immutable())
                .collect::<Vec<_>>();

            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                assert_eq!(
                    conn.query_one("SELECT COUNT(*) FROM book", &[])
                        .unwrap()
                        .get::<_, i64>(0),
                    0
                );
            }

            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                conn.execute("INSERT INTO book (title) VALUES ($1)", &[&"Title"])
                    .unwrap();
            }
        }

        // Second round: pulled databases must be empty again.
        {
            let conn_pools = (0..NUM_DBS)
                .map(|_| db_pool.pull_immutable())
                .collect::<Vec<_>>();

            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                assert_eq!(
                    conn.query_one("SELECT COUNT(*) FROM book", &[])
                        .unwrap()
                        .get::<_, i64>(0),
                    0
                );
            }
        }
    }

    #[test]
    fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_restricted_databases(backend);
    }

    #[test]
    fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false);
        test_pool_drops_created_unrestricted_database(backend);
    }
}