1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use diesel::{prelude::*, result::Error, sql_query, table, ConnectionError};
5use diesel_async::{
6 pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, SetupCallback},
7 AsyncConnection, AsyncPgConnection, RunQueryDsl, SimpleAsyncConnection,
8};
9use futures::{future::FutureExt, Future};
10use parking_lot::Mutex;
11use uuid::Uuid;
12
13use crate::{common::config::postgres::PrivilegedPostgresConfig, util::get_db_name};
14
15use super::{
16 super::{
17 common::pool::diesel::r#trait::DieselPoolAssociation, error::Error as BackendError,
18 r#trait::Backend,
19 },
20 r#trait::{PostgresBackend, PostgresBackendWrapper},
21};
22
23type CreateEntities = dyn Fn(AsyncPgConnection) -> Pin<Box<dyn Future<Output = AsyncPgConnection> + Send + 'static>>
24 + Send
25 + Sync
26 + 'static;
27
/// Postgres backend built on `diesel-async` connections, generic over the pool
/// association `P` (the tests in this file use bb8 via `DieselBb8`).
pub struct DieselAsyncPostgresBackend<P: DieselPoolAssociation<AsyncPgConnection>> {
    /// Connection settings for the privileged Postgres user.
    privileged_config: PrivilegedPostgresConfig,
    /// Privileged pool connected to the server's default database (built in `new`).
    default_pool: P::Pool,
    /// Connections stashed by `put_database_connection` and removed again by
    /// `get_database_connection`, keyed by database id.
    db_conns: Mutex<HashMap<Uuid, AsyncPgConnection>>,
    /// Returns a fresh builder for every restricted per-database pool.
    create_restricted_pool: Box<dyn Fn() -> P::Builder + Send + Sync + 'static>,
    /// Returns the setup callback used to open each new connection.
    create_connection: Box<dyn Fn() -> SetupCallback<AsyncPgConnection> + Send + Sync + 'static>,
    /// User hook run by `create_entities`; receives a connection and hands one back.
    create_entities: Box<CreateEntities>,
    /// Value reported by `get_drop_previous_databases`; `new` sets it to `true`.
    drop_previous_databases_flag: bool,
}
38
39impl<P: DieselPoolAssociation<AsyncPgConnection>> DieselAsyncPostgresBackend<P> {
40 pub async fn new(
79 privileged_config: PrivilegedPostgresConfig,
80 create_privileged_pool: impl Fn() -> P::Builder,
81 create_restricted_pool: impl Fn() -> P::Builder + Send + Sync + 'static,
82 custom_create_connection: Option<
83 Box<dyn Fn() -> SetupCallback<AsyncPgConnection> + Send + Sync + 'static>,
84 >,
85 create_entities: impl Fn(
86 AsyncPgConnection,
87 ) -> Pin<Box<dyn Future<Output = AsyncPgConnection> + Send + 'static>>
88 + Send
89 + Sync
90 + 'static,
91 ) -> Result<Self, P::BuildError> {
92 let create_connection = custom_create_connection.unwrap_or_else(|| {
93 Box::new(|| {
94 Box::new(|connection_url| AsyncPgConnection::establish(connection_url).boxed())
95 })
96 });
97
98 let manager_config = {
99 let mut config = ManagerConfig::default();
100 config.custom_setup = Box::new(create_connection());
101 config
102 };
103 let manager = AsyncDieselConnectionManager::new_with_config(
104 privileged_config.default_connection_url(),
105 manager_config,
106 );
107 let builder = create_privileged_pool();
108 let default_pool = P::build_pool(builder, manager).await?;
109
110 Ok(Self {
111 privileged_config,
112 default_pool,
113 db_conns: Mutex::new(HashMap::new()),
114 create_restricted_pool: Box::new(create_restricted_pool),
115 create_connection,
116 create_entities: Box::new(create_entities),
117 drop_previous_databases_flag: true,
118 })
119 }
120
121 #[must_use]
123 pub fn drop_previous_databases(self, value: bool) -> Self {
124 Self {
125 drop_previous_databases_flag: value,
126 ..self
127 }
128 }
129}
130
131#[async_trait]
132impl<'pool, P: DieselPoolAssociation<AsyncPgConnection>> PostgresBackend<'pool>
133 for DieselAsyncPostgresBackend<P>
134{
135 type Connection = AsyncPgConnection;
136 type PooledConnection = P::PooledConnection<'pool>;
137 type Pool = P::Pool;
138
139 type BuildError = P::BuildError;
140 type PoolError = P::PoolError;
141 type ConnectionError = ConnectionError;
142 type QueryError = Error;
143
144 async fn execute_query(&self, query: &str, conn: &mut AsyncPgConnection) -> QueryResult<()> {
145 sql_query(query).execute(conn).await?;
146 Ok(())
147 }
148
149 async fn batch_execute_query<'a>(
150 &self,
151 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
152 conn: &mut AsyncPgConnection,
153 ) -> QueryResult<()> {
154 let query = query.into_iter().collect::<Vec<_>>();
155 if query.is_empty() {
156 Ok(())
157 } else {
158 conn.batch_execute(query.join(";").as_str()).await
159 }
160 }
161
162 async fn get_default_connection(
163 &'pool self,
164 ) -> Result<P::PooledConnection<'pool>, P::PoolError> {
165 P::get_connection(&self.default_pool).await
166 }
167
168 async fn establish_privileged_database_connection(
169 &self,
170 db_id: Uuid,
171 ) -> ConnectionResult<AsyncPgConnection> {
172 let db_name = get_db_name(db_id);
173 let database_url = self
174 .privileged_config
175 .privileged_database_connection_url(db_name.as_str());
176 (self.create_connection)()(database_url.as_str()).await
177 }
178
179 async fn establish_restricted_database_connection(
180 &self,
181 db_id: Uuid,
182 ) -> ConnectionResult<AsyncPgConnection> {
183 let db_name = get_db_name(db_id);
184 let db_name = db_name.as_str();
185 let database_url = self.privileged_config.restricted_database_connection_url(
186 db_name,
187 Some(db_name),
188 db_name,
189 );
190 (self.create_connection)()(database_url.as_str()).await
191 }
192
193 fn put_database_connection(&self, db_id: Uuid, conn: AsyncPgConnection) {
194 self.db_conns.lock().insert(db_id, conn);
195 }
196
197 fn get_database_connection(&self, db_id: Uuid) -> AsyncPgConnection {
198 self.db_conns
199 .lock()
200 .remove(&db_id)
201 .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
202 }
203
204 async fn get_previous_database_names(
205 &self,
206 conn: &mut AsyncPgConnection,
207 ) -> QueryResult<Vec<String>> {
208 table! {
209 pg_database (oid) {
210 oid -> Int4,
211 datname -> Text
212 }
213 }
214
215 pg_database::table
216 .select(pg_database::datname)
217 .filter(pg_database::datname.like("db_pool_%"))
218 .load::<String>(conn)
219 .await
220 }
221
222 async fn create_entities(&self, conn: AsyncPgConnection) -> AsyncPgConnection {
223 (self.create_entities)(conn).await
224 }
225
226 async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
227 let db_name = get_db_name(db_id);
228 let db_name = db_name.as_str();
229 let database_url = self.privileged_config.restricted_database_connection_url(
230 db_name,
231 Some(db_name),
232 db_name,
233 );
234 let manager_config = {
235 let mut config = ManagerConfig::default();
236 config.custom_setup = Box::new((self.create_connection)());
237 config
238 };
239 let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
240 database_url.as_str(),
241 manager_config,
242 );
243 let builder = (self.create_restricted_pool)();
244 P::build_pool(builder, manager).await
245 }
246
247 async fn get_table_names(
248 &self,
249 privileged_conn: &mut AsyncPgConnection,
250 ) -> QueryResult<Vec<String>> {
251 table! {
252 pg_tables (tablename) {
253 #[sql_name = "schemaname"]
254 schema_name -> Text,
255 tablename -> Text
256 }
257 }
258
259 pg_tables::table
260 .filter(pg_tables::schema_name.ne_all(["pg_catalog", "information_schema"]))
261 .select(pg_tables::tablename)
262 .load(privileged_conn)
263 .await
264 }
265
266 fn get_drop_previous_databases(&self) -> bool {
267 self.drop_previous_databases_flag
268 }
269}
270
271type BError<BuildError, PoolError> = BackendError<BuildError, PoolError, ConnectionError, Error>;
272
273#[async_trait]
274impl<P: DieselPoolAssociation<AsyncPgConnection>> Backend for DieselAsyncPostgresBackend<P> {
275 type Pool = P::Pool;
276
277 type BuildError = P::BuildError;
278 type PoolError = P::PoolError;
279 type ConnectionError = ConnectionError;
280 type QueryError = Error;
281
282 async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
283 PostgresBackendWrapper::new(self).init().await
284 }
285
286 async fn create(
287 &self,
288 db_id: uuid::Uuid,
289 restrict_privileges: bool,
290 ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
291 PostgresBackendWrapper::new(self)
292 .create(db_id, restrict_privileges)
293 .await
294 }
295
296 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
297 PostgresBackendWrapper::new(self).clean(db_id).await
298 }
299
300 async fn drop(
301 &self,
302 db_id: uuid::Uuid,
303 is_restricted: bool,
304 ) -> Result<(), BError<P::BuildError, P::PoolError>> {
305 PostgresBackendWrapper::new(self)
306 .drop(db_id, is_restricted)
307 .await
308 }
309}
310
#[cfg(test)]
mod tests {
    // `unwrap` is acceptable in tests; `needless_return` is presumably triggered by code
    // the test macro expands to (TODO confirm).
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use std::borrow::Cow;

    use bb8::Pool;
    use diesel::{insert_into, sql_query, table, Insertable, QueryDsl};
    use diesel_async::{RunQueryDsl, SimpleAsyncConnection};
    use dotenvy::dotenv;
    use futures::future::join_all;
    use tokio_shared_rt::test;

    use crate::{
        common::{
            config::PrivilegedPostgresConfig,
            statement::postgres::tests::{
                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
            },
        },
        r#async::{
            backend::{
                common::pool::diesel::bb8::DieselBb8,
                postgres::r#trait::tests::test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_database, test_backend_drops_previous_databases,
            test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases,
            PgDropLock,
        },
        DieselAsyncPostgresBackend,
    };

    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    /// Row inserted into `book` by the isolation and cleanliness tests.
    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }

    /// Builds a bb8-backed backend from the privileged config in the environment
    /// (loading `.env` first if present).
    ///
    /// With `with_table`, every new database runs `CREATE_ENTITIES_STATEMENTS`; otherwise
    /// the connection is handed back untouched.
    async fn create_backend(with_table: bool) -> DieselAsyncPostgresBackend<DieselBb8> {
        dotenv().ok();

        let config = PrivilegedPostgresConfig::from_env().unwrap();

        DieselAsyncPostgresBackend::new(config, Pool::builder, Pool::builder, None, {
            move |mut conn| {
                if !with_table {
                    return Box::pin(async { conn });
                }

                Box::pin(async move {
                    let query = CREATE_ENTITIES_STATEMENTS.join(";");
                    conn.batch_execute(query.as_str()).await.unwrap();
                    conn
                })
            }
        })
        .await
        .unwrap()
    }

    /// Same as `create_backend`, with `drop_previous_databases(false)` applied.
    async fn create_backend_no_drop(with_table: bool) -> DieselAsyncPostgresBackend<DieselBb8> {
        create_backend(with_table)
            .await
            .drop_previous_databases(false)
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend_no_drop(false).await;

        test_backend_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend_no_drop(true).await,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend_no_drop(true).await,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(create_backend_no_drop(true).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(create_backend_no_drop(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        test_backend_drops_database(create_backend_no_drop(true).await, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        test_backend_drops_database(create_backend_no_drop(true).await, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend_no_drop(false).await;

        test_pool_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend_no_drop(true).await;

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // write one distinct row into each database...
            let inserts = conn_pools
                .iter()
                .enumerate()
                .map(|(i, conn_pool)| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let new_book = NewBook {
                        title: format!("Title {i}").into(),
                    };
                    insert_into(book::table)
                        .values(new_book)
                        .execute(conn)
                        .await
                        .unwrap();
                });
            join_all(inserts).await;

            // ...and expect each database to contain only its own row
            let checks = conn_pools
                .iter()
                .enumerate()
                .map(|(i, conn_pool)| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let titles = book::table
                        .select(book::title)
                        .load::<String>(conn)
                        .await
                        .unwrap();
                    assert_eq!(titles, vec![format!("Title {i}")]);
                });
            join_all(checks).await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend_no_drop(true).await;

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.get().await.unwrap();

            // schema changes must be rejected...
            for stmt in DDL_STATEMENTS {
                let outcome = sql_query(stmt).execute(conn).await;
                assert!(outcome.is_err());
            }

            // ...while data changes are allowed
            for stmt in DML_STATEMENTS {
                let outcome = sql_query(stmt).execute(conn).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend_no_drop(true).await;

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // all data changes succeed on one mutable database
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    let outcome = sql_query(stmt).execute(conn).await;
                    assert!(outcome.is_ok());
                }
            }

            // every schema change succeeds, each on a freshly created mutable database
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                let outcome = sql_query(stmt).execute(conn).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend_no_drop(true).await;

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // first round: databases start empty, then each gets a row
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                let emptiness_checks = conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let count = book::table.count().get_result::<i64>(conn).await.unwrap();
                    assert_eq!(count, 0);
                });
                join_all(emptiness_checks).await;

                let inserts = conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let new_book = NewBook {
                        title: "Title".into(),
                    };
                    insert_into(book::table)
                        .values(new_book)
                        .execute(conn)
                        .await
                        .unwrap();
                });
                join_all(inserts).await;
            }

            // second round: databases pulled again are expected to be empty
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                let emptiness_checks = conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let count = book::table.count().get_result::<i64>(conn).await.unwrap();
                    assert_eq!(count, 0);
                });
                join_all(emptiness_checks).await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false).await).await;
    }
}