use std::{borrow::Cow, collections::HashMap, pin::Pin};

use async_trait::async_trait;
use diesel::{ConnectionError, prelude::*, result::Error, sql_query, table};
use diesel_async::{
    AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection,
    pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, SetupCallback},
};
use futures::{Future, future::FutureExt};
use parking_lot::Mutex;
use uuid::Uuid;

use crate::{common::config::postgres::PrivilegedPostgresConfig, util::get_db_name};

use super::{
    super::{
        common::pool::diesel::r#trait::DieselPoolAssociation, error::Error as BackendError,
        r#trait::Backend,
    },
    r#trait::{PostgresBackend, PostgresBackendWrapper},
};

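/// Boxed callback that receives a connection to a newly created database, creates the required
/// entities on it, and may hand the connection back for reuse.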
type CreateEntities = dyn Fn(
        AsyncPgConnection,
    ) -> Pin<Box<dyn Future<Output = Option<AsyncPgConnection>> + Send + 'static>>
    + Send
    + Sync
    + 'static;

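/// Asynchronous Postgres backend implemented on top of `diesel-async`, generic over the
/// connection pool association `P`.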
pub struct DieselAsyncPostgresBackend<P: DieselPoolAssociation<AsyncPgConnection>> {
    privileged_config: PrivilegedPostgresConfig,
    default_pool: P::Pool,
    db_conns: Mutex<HashMap<Uuid, AsyncPgConnection>>,
    create_restricted_pool: Box<
        dyn Fn(AsyncDieselConnectionManager<AsyncPgConnection>) -> P::Builder
            + Send
            + Sync
            + 'static,
    >,
    create_connection: Box<dyn Fn() -> SetupCallback<AsyncPgConnection> + Send + Sync + 'static>,
    create_entities: Box<CreateEntities>,
    drop_previous_databases_flag: bool,
}

impl<P: DieselPoolAssociation<AsyncPgConnection>> DieselAsyncPostgresBackend<P> {
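    /// Constructs a new `diesel-async` Postgres backend.
    ///
    /// The privileged pool is built via `create_privileged_pool`, per-database restricted pools
    /// via `create_restricted_pool`; `custom_create_connection` optionally overrides how raw
    /// connections are established, and `create_entities` is called to seed each newly created
    /// database.
    ///
    /// A minimal usage sketch, assuming the `bb8`-based pool association (`DieselBb8`) and a
    /// privileged configuration read from the environment; import paths and the seed statement
    /// are illustrative, hence the example is not compiled:
    ///
    /// ```rust,ignore
    /// let config = PrivilegedPostgresConfig::from_env().unwrap();
    ///
    /// let backend = DieselAsyncPostgresBackend::<DieselBb8>::new(
    ///     config,
    ///     |_| Pool::builder(),
    ///     |_| Pool::builder(),
    ///     None,
    ///     |mut conn| {
    ///         Box::pin(async move {
    ///             // create whatever entities the test databases need (hypothetical schema)
    ///             conn.batch_execute("CREATE TABLE book (id SERIAL PRIMARY KEY, title TEXT NOT NULL)")
    ///                 .await
    ///                 .unwrap();
    ///             Some(conn)
    ///         })
    ///     },
    /// )
    /// .await
    /// .unwrap();
    /// ```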
    pub async fn new(
        privileged_config: PrivilegedPostgresConfig,
        create_privileged_pool: impl Fn(AsyncDieselConnectionManager<AsyncPgConnection>) -> P::Builder,
        create_restricted_pool: impl Fn(AsyncDieselConnectionManager<AsyncPgConnection>) -> P::Builder
            + Send
            + Sync
            + 'static,
        custom_create_connection: Option<
            Box<dyn Fn() -> SetupCallback<AsyncPgConnection> + Send + Sync + 'static>,
        >,
        create_entities: impl Fn(
                AsyncPgConnection,
            ) -> Pin<
                Box<dyn Future<Output = Option<AsyncPgConnection>> + Send + 'static>,
            > + Send
            + Sync
            + 'static,
    ) -> Result<Self, P::BuildError> {
        let create_connection = custom_create_connection.unwrap_or_else(|| {
            Box::new(|| {
                Box::new(|connection_url| AsyncPgConnection::establish(connection_url).boxed())
            })
        });

        let manager = || {
            let manager_config = {
                let mut config = ManagerConfig::default();
                config.custom_setup = Box::new(create_connection());
                config
            };
            AsyncDieselConnectionManager::new_with_config(
                privileged_config.default_connection_url(),
                manager_config,
            )
        };

        let builder = create_privileged_pool(manager());

        let default_pool = P::build_pool(builder, manager()).await?;

        Ok(Self {
            privileged_config,
            default_pool,
            db_conns: Mutex::new(HashMap::new()),
            create_restricted_pool: Box::new(create_restricted_pool),
            create_connection,
            create_entities: Box::new(create_entities),
            drop_previous_databases_flag: true,
        })
    }

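    /// Sets whether databases created by previous runs (named `db_pool_*`) should be dropped
    /// when the backend is initialized. Enabled by default.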
    #[must_use]
    pub fn drop_previous_databases(self, value: bool) -> Self {
        Self {
            drop_previous_databases_flag: value,
            ..self
        }
    }
}

#[async_trait]
impl<'pool, P: DieselPoolAssociation<AsyncPgConnection>> PostgresBackend<'pool>
    for DieselAsyncPostgresBackend<P>
{
    type Connection = AsyncPgConnection;
    type PooledConnection = P::PooledConnection<'pool>;
    type Pool = P::Pool;

    type BuildError = P::BuildError;
    type PoolError = P::PoolError;
    type ConnectionError = ConnectionError;
    type QueryError = Error;

    async fn execute_query(&self, query: &str, conn: &mut AsyncPgConnection) -> QueryResult<()> {
        sql_query(query).execute(conn).await?;
        Ok(())
    }

    async fn batch_execute_query<'a>(
        &self,
        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
        conn: &mut AsyncPgConnection,
    ) -> QueryResult<()> {
        let query = query.into_iter().collect::<Vec<_>>();
        if query.is_empty() {
            Ok(())
        } else {
            conn.batch_execute(query.join(";").as_str()).await
        }
    }

    async fn get_default_connection(
        &'pool self,
    ) -> Result<P::PooledConnection<'pool>, P::PoolError> {
        P::get_connection(&self.default_pool).await
    }

    async fn establish_privileged_database_connection(
        &self,
        db_id: Uuid,
    ) -> ConnectionResult<AsyncPgConnection> {
        let db_name = get_db_name(db_id);
        let database_url = self
            .privileged_config
            .privileged_database_connection_url(db_name.as_str());
        (self.create_connection)()(database_url.as_str()).await
    }

    async fn establish_restricted_database_connection(
        &self,
        db_id: Uuid,
    ) -> ConnectionResult<AsyncPgConnection> {
        let db_name = get_db_name(db_id);
        let db_name = db_name.as_str();
        let database_url = self.privileged_config.restricted_database_connection_url(
            db_name,
            Some(db_name),
            db_name,
        );
        (self.create_connection)()(database_url.as_str()).await
    }

    fn put_database_connection(&self, db_id: Uuid, conn: AsyncPgConnection) {
        self.db_conns.lock().insert(db_id, conn);
    }

    fn get_database_connection(&self, db_id: Uuid) -> AsyncPgConnection {
        self.db_conns
            .lock()
            .remove(&db_id)
            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
    }

    async fn get_previous_database_names(
        &self,
        conn: &mut AsyncPgConnection,
    ) -> QueryResult<Vec<String>> {
        table! {
            pg_database (oid) {
                oid -> Int4,
                datname -> Text
            }
        }

        pg_database::table
            .select(pg_database::datname)
            .filter(pg_database::datname.like("db_pool_%"))
            .load::<String>(conn)
            .await
    }

    async fn create_entities(&self, conn: AsyncPgConnection) -> Option<AsyncPgConnection> {
        (self.create_entities)(conn).await
    }

    async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
        let db_name = get_db_name(db_id);
        let db_name = db_name.as_str();
        let database_url = self.privileged_config.restricted_database_connection_url(
            db_name,
            Some(db_name),
            db_name,
        );

        let manager = {
            || {
                let manager_config = {
                    let mut config = ManagerConfig::default();
                    config.custom_setup = Box::new((self.create_connection)());
                    config
                };
                AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
                    database_url.clone(),
                    manager_config,
                )
            }
        };

        let builder = (self.create_restricted_pool)(manager());

        P::build_pool(builder, manager()).await
    }

    async fn get_table_names(
        &self,
        privileged_conn: &mut AsyncPgConnection,
    ) -> QueryResult<Vec<String>> {
        table! {
            pg_tables (tablename) {
                #[sql_name = "schemaname"]
                schema_name -> Text,
                tablename -> Text
            }
        }

        pg_tables::table
            .filter(pg_tables::schema_name.ne_all(["pg_catalog", "information_schema"]))
            .select(pg_tables::tablename)
            .load(privileged_conn)
            .await
    }

    fn get_drop_previous_databases(&self) -> bool {
        self.drop_previous_databases_flag
    }
}

type BError<BuildError, PoolError> = BackendError<BuildError, PoolError, ConnectionError, Error>;

#[async_trait]
impl<P: DieselPoolAssociation<AsyncPgConnection>> Backend for DieselAsyncPostgresBackend<P> {
    type Pool = P::Pool;

    type BuildError = P::BuildError;
    type PoolError = P::PoolError;
    type ConnectionError = ConnectionError;
    type QueryError = Error;

    async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self).init().await
    }

    async fn create(
        &self,
        db_id: uuid::Uuid,
        restrict_privileges: bool,
    ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self)
            .create(db_id, restrict_privileges)
            .await
    }

    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self).clean(db_id).await
    }

    async fn drop(
        &self,
        db_id: uuid::Uuid,
        is_restricted: bool,
    ) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self)
            .drop(db_id, is_restricted)
            .await
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use std::borrow::Cow;

    use bb8::Pool;
    use diesel::{Insertable, QueryDsl, insert_into, sql_query, table};
    use diesel_async::{RunQueryDsl, SimpleAsyncConnection};
    use dotenvy::dotenv;
    use futures::future::join_all;
    use tokio_shared_rt::test;

    use crate::{
        r#async::{
            backend::{
                common::pool::diesel::bb8::DieselBb8,
                postgres::r#trait::tests::test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        common::{
            config::PrivilegedPostgresConfig,
            statement::postgres::tests::{
                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
            },
        },
    };

    use super::{
        super::r#trait::tests::{
            PgDropLock, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_database, test_backend_drops_previous_databases,
            test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases,
        },
        DieselAsyncPostgresBackend,
    };

    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }

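    // Test helper: builds a bb8-backed backend from environment configuration, optionally
    // seeding each database with the shared test entities.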
    async fn create_backend(with_table: bool) -> DieselAsyncPostgresBackend<DieselBb8> {
        dotenv().ok();

        let config = PrivilegedPostgresConfig::from_env().unwrap();

        DieselAsyncPostgresBackend::new(config, |_| Pool::builder(), |_| Pool::builder(), None, {
            move |mut conn| {
                if with_table {
                    Box::pin(async move {
                        let query = CREATE_ENTITIES_STATEMENTS.join(";");
                        conn.batch_execute(query.as_str()).await.unwrap();
                        Some(conn)
                    })
                } else {
                    Box::pin(async { Some(conn) })
                }
            }
        })
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        insert_into(book::table)
                            .values(NewBook {
                                title: format!("Title {i}").into(),
                            })
                            .execute(conn)
                            .await
                            .unwrap();
                    }),
            )
            .await;

            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        assert_eq!(
                            book::table
                                .select(book::title)
                                .load::<String>(conn)
                                .await
                                .unwrap(),
                            vec![format!("Title {i}")]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.get().await.unwrap();

            for stmt in DDL_STATEMENTS {
                assert!(sql_query(stmt).execute(conn).await.is_err());
            }

            for stmt in DML_STATEMENTS {
                assert!(sql_query(stmt).execute(conn).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(sql_query(stmt).execute(conn).await.is_ok());
                }
            }

            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                assert!(sql_query(stmt).execute(conn).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        book::table.count().get_result::<i64>(conn).await.unwrap(),
                        0
                    );
                }))
                .await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    insert_into(book::table)
                        .values(NewBook {
                            title: "Title".into(),
                        })
                        .execute(conn)
                        .await
                        .unwrap();
                }))
                .await;
            }

            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        book::table.count().get_result::<i64>(conn).await.unwrap(),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false).await;
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false).await;
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}