db_pool/sync/backend/postgres/
diesel.rs1use std::{borrow::Cow, collections::HashMap};
2
3use diesel::{
4 QueryResult, RunQueryDsl, connection::SimpleConnection, pg::PgConnection, prelude::*,
5 r2d2::ConnectionManager, result::Error, sql_query,
6};
7use parking_lot::Mutex;
8use r2d2::{Builder, Pool, PooledConnection};
9use uuid::Uuid;
10
11use crate::{common::config::postgres::PrivilegedPostgresConfig, util::get_db_name};
12
13use super::{
14 super::{error::Error as BackendError, r#trait::Backend},
15 r#trait::{PostgresBackend, PostgresBackendWrapper},
16};
17
18type Manager = ConnectionManager<PgConnection>;
19
20pub struct DieselPostgresBackend {
22 privileged_config: PrivilegedPostgresConfig,
23 default_pool: Pool<Manager>,
24 db_conns: Mutex<HashMap<Uuid, PgConnection>>,
25 create_restricted_pool: Box<dyn Fn() -> Builder<Manager> + Send + Sync + 'static>,
26 create_entities: Box<dyn Fn(&mut PgConnection) + Send + Sync + 'static>,
27 drop_previous_databases_flag: bool,
28}
29
30impl DieselPostgresBackend {
31 pub fn new(
56 privileged_config: PrivilegedPostgresConfig,
57 create_privileged_pool: impl Fn() -> Builder<Manager>,
58 create_restricted_pool: impl Fn() -> Builder<Manager> + Send + Sync + 'static,
59 create_entities: impl Fn(&mut PgConnection) + Send + Sync + 'static,
60 ) -> Result<Self, r2d2::Error> {
61 let manager = Manager::new(privileged_config.default_connection_url());
62 let default_pool = (create_privileged_pool()).build(manager)?;
63
64 Ok(Self {
65 privileged_config,
66 default_pool,
67 db_conns: Mutex::new(HashMap::new()),
68 create_entities: Box::new(create_entities),
69 create_restricted_pool: Box::new(create_restricted_pool),
70 drop_previous_databases_flag: true,
71 })
72 }
73
74 #[must_use]
76 pub fn drop_previous_databases(self, value: bool) -> Self {
77 Self {
78 drop_previous_databases_flag: value,
79 ..self
80 }
81 }
82}
83
84impl PostgresBackend for DieselPostgresBackend {
85 type ConnectionManager = Manager;
86 type ConnectionError = ConnectionError;
87 type QueryError = Error;
88
89 fn execute_query(&self, query: &str, conn: &mut PgConnection) -> QueryResult<()> {
90 sql_query(query).execute(conn)?;
91 Ok(())
92 }
93
94 fn batch_execute_query<'a>(
95 &self,
96 query: impl IntoIterator<Item = Cow<'a, str>>,
97 conn: &mut PgConnection,
98 ) -> QueryResult<()> {
99 let query = query.into_iter().collect::<Vec<_>>();
100 if query.is_empty() {
101 Ok(())
102 } else {
103 conn.batch_execute(query.join(";").as_str())
104 }
105 }
106
107 fn get_default_connection(&self) -> Result<PooledConnection<Manager>, r2d2::Error> {
108 self.default_pool.get()
109 }
110
111 fn establish_privileged_database_connection(
112 &self,
113 db_id: Uuid,
114 ) -> ConnectionResult<PgConnection> {
115 let db_name = get_db_name(db_id);
116 let database_url = self
117 .privileged_config
118 .privileged_database_connection_url(db_name.as_str());
119 PgConnection::establish(database_url.as_str())
120 }
121
122 fn establish_restricted_database_connection(
123 &self,
124 db_id: Uuid,
125 ) -> ConnectionResult<PgConnection> {
126 let db_name = get_db_name(db_id);
127 let db_name = db_name.as_str();
128 let database_url = self.privileged_config.restricted_database_connection_url(
129 db_name,
130 Some(db_name),
131 db_name,
132 );
133 PgConnection::establish(database_url.as_str())
134 }
135
136 fn put_database_connection(&self, db_id: Uuid, conn: PgConnection) {
137 self.db_conns.lock().insert(db_id, conn);
138 }
139
140 fn get_database_connection(&self, db_id: Uuid) -> PgConnection {
141 self.db_conns
142 .lock()
143 .remove(&db_id)
144 .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
145 }
146
147 fn get_previous_database_names(&self, conn: &mut PgConnection) -> QueryResult<Vec<String>> {
148 table! {
149 pg_database (oid) {
150 oid -> Int4,
151 datname -> Text
152 }
153 }
154
155 pg_database::table
156 .select(pg_database::datname)
157 .filter(pg_database::datname.like("db_pool_%"))
158 .load::<String>(conn)
159 }
160
161 fn create_entities(&self, conn: &mut PgConnection) {
162 (self.create_entities)(conn);
163 }
164
165 fn create_connection_pool(
166 &self,
167 db_id: Uuid,
168 ) -> Result<Pool<Self::ConnectionManager>, r2d2::Error> {
169 let db_name = get_db_name(db_id);
170 let db_name = db_name.as_str();
171 let database_url = self.privileged_config.restricted_database_connection_url(
172 db_name,
173 Some(db_name),
174 db_name,
175 );
176 let manager = ConnectionManager::<PgConnection>::new(database_url.as_str());
177 (self.create_restricted_pool)().build(manager)
178 }
179
180 fn get_table_names(&self, conn: &mut PgConnection) -> QueryResult<Vec<String>> {
181 table! {
182 pg_tables (tablename) {
183 #[sql_name = "schemaname"]
184 schema_name -> Text,
185 tablename -> Text
186 }
187 }
188
189 pg_tables::table
190 .filter(pg_tables::schema_name.ne_all(["pg_catalog", "information_schema"]))
191 .select(pg_tables::tablename)
192 .load(conn)
193 }
194
195 fn get_drop_previous_databases(&self) -> bool {
196 self.drop_previous_databases_flag
197 }
198}
199
200impl Backend for DieselPostgresBackend {
201 type ConnectionManager = Manager;
202 type ConnectionError = ConnectionError;
203 type QueryError = Error;
204
205 fn init(&self) -> Result<(), BackendError<ConnectionError, Error>> {
206 PostgresBackendWrapper::new(self).init()
207 }
208
209 fn create(
210 &self,
211 db_id: Uuid,
212 restrict_privileges: bool,
213 ) -> Result<Pool<Manager>, BackendError<ConnectionError, Error>> {
214 PostgresBackendWrapper::new(self).create(db_id, restrict_privileges)
215 }
216
217 fn clean(&self, db_id: Uuid) -> Result<(), BackendError<ConnectionError, Error>> {
218 PostgresBackendWrapper::new(self).clean(db_id)
219 }
220
221 fn drop(
222 &self,
223 db_id: Uuid,
224 is_restricted: bool,
225 ) -> Result<(), BackendError<ConnectionError, Error>> {
226 PostgresBackendWrapper::new(self).drop(db_id, is_restricted)
227 }
228}
229
#[cfg(test)]
mod tests {
    #![allow(unused_variables, clippy::unwrap_used, clippy::needless_return)]

    use std::borrow::Cow;

    use diesel::{
        Insertable, QueryDsl, RunQueryDsl, connection::SimpleConnection, insert_into, sql_query,
        table,
    };
    use dotenvy::dotenv;
    use r2d2::Pool;

    use crate::{
        common::{
            config::PrivilegedPostgresConfig,
            statement::postgres::tests::{
                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
            },
        },
        sync::db_pool::DatabasePoolBuilder,
    };

    use super::{
        super::r#trait::tests::{
            lock_read, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_database, test_backend_drops_previous_databases,
            test_pool_drops_created_restricted_databases,
            test_pool_drops_created_unrestricted_database, test_pool_drops_previous_databases,
        },
        DieselPostgresBackend,
    };

    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    /// Insertable row for the `book` table.
    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }

    /// Builds a backend from environment configuration (`.env` is loaded if present).
    ///
    /// With `with_table`, the entity callback runs `CREATE_ENTITIES_STATEMENTS`
    /// as a single batch on the connection it is given.
    fn create_backend(with_table: bool) -> DieselPostgresBackend {
        dotenv().ok();

        let config = PrivilegedPostgresConfig::from_env().unwrap();

        DieselPostgresBackend::new(config, Pool::builder, Pool::builder, move |conn| {
            if !with_table {
                return;
            }
            let batch = CREATE_ENTITIES_STATEMENTS.join(";");
            conn.batch_execute(batch.as_str()).unwrap();
        })
        .unwrap()
    }

    /// `create_backend` with dropping of previous databases switched off.
    fn backend_keeping_previous(with_table: bool) -> DieselPostgresBackend {
        create_backend(with_table).drop_previous_databases(false)
    }

    #[test]
    fn backend_drops_previous_databases() {
        let default_backend = create_backend(false);
        let dropping_backend = create_backend(false).drop_previous_databases(true);
        let keeping_backend = backend_keeping_previous(false);
        test_backend_drops_previous_databases(default_backend, dropping_backend, keeping_backend);
    }

    #[test]
    fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(&backend_keeping_previous(true));
    }

    #[test]
    fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(&backend_keeping_previous(
            true,
        ));
    }

    #[test]
    fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(&backend_keeping_previous(true));
    }

    #[test]
    fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(&backend_keeping_previous(false));
    }

    #[test]
    fn backend_drops_restricted_database() {
        test_backend_drops_database(&backend_keeping_previous(true), true);
    }

    #[test]
    fn backend_drops_unrestricted_database() {
        test_backend_drops_database(&backend_keeping_previous(true), false);
    }

    #[test]
    fn pool_drops_previous_databases() {
        let default_backend = create_backend(false);
        let dropping_backend = create_backend(false).drop_previous_databases(true);
        let keeping_backend = backend_keeping_previous(false);
        test_pool_drops_previous_databases(default_backend, dropping_backend, keeping_backend);
    }

    #[test]
    fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = backend_keeping_previous(true);
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pools: Vec<_> = (0..NUM_DBS).map(|_| db_pool.pull_immutable()).collect();

        // Write a distinct title into each database...
        for (i, conn_pool) in conn_pools.iter().enumerate() {
            let conn = &mut conn_pool.get().unwrap();
            let new_book = NewBook {
                title: format!("Title {i}").into(),
            };
            insert_into(book::table).values(new_book).execute(conn).unwrap();
        }

        // ...and check that each one only sees its own row.
        for (i, conn_pool) in conn_pools.iter().enumerate() {
            let conn = &mut conn_pool.get().unwrap();
            let titles = book::table
                .select(book::title)
                .load::<String>(conn)
                .unwrap();
            assert_eq!(titles, vec![format!("Title {i}")]);
        }
    }

    #[test]
    fn pool_provides_restricted_databases() {
        let backend = backend_keeping_previous(true);
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pool = db_pool.pull_immutable();
        let conn = &mut conn_pool.get().unwrap();

        // Schema changes must be rejected on a restricted database.
        for stmt in DDL_STATEMENTS {
            assert!(sql_query(stmt).execute(conn).is_err());
        }

        // Data changes must still be allowed.
        for stmt in DML_STATEMENTS {
            assert!(sql_query(stmt).execute(conn).is_ok());
        }
    }

    #[test]
    fn pool_provides_unrestricted_databases() {
        let backend = backend_keeping_previous(true);
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // All DML statements succeed on one mutable database.
        {
            let conn_pool = db_pool.create_mutable().unwrap();
            let conn = &mut conn_pool.get().unwrap();
            for stmt in DML_STATEMENTS {
                assert!(sql_query(stmt).execute(conn).is_ok());
            }
        }

        // Each DDL statement succeeds on a fresh mutable database of its own.
        for stmt in DDL_STATEMENTS {
            let conn_pool = db_pool.create_mutable().unwrap();
            let conn = &mut conn_pool.get().unwrap();
            assert!(sql_query(stmt).execute(conn).is_ok());
        }
    }

    #[test]
    fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = backend_keeping_previous(true);
        let _guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let pull_all = || {
            (0..NUM_DBS)
                .map(|_| db_pool.pull_immutable())
                .collect::<Vec<_>>()
        };

        // First round: every database starts empty, then gets one row.
        {
            let conn_pools = pull_all();

            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                assert_eq!(book::table.count().get_result::<i64>(conn).unwrap(), 0);
            }

            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                let new_book = NewBook {
                    title: "Title".into(),
                };
                insert_into(book::table).values(new_book).execute(conn).unwrap();
            }
        }

        // Second round: databases handed out again must have been cleaned.
        {
            let conn_pools = pull_all();

            for conn_pool in &conn_pools {
                let conn = &mut conn_pool.get().unwrap();
                assert_eq!(book::table.count().get_result::<i64>(conn).unwrap(), 0);
            }
        }
    }

    #[test]
    fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false));
    }

    #[test]
    fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false));
    }
}