// db_pool/async/backend/postgres/sqlx.rs

use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use sqlx::{
7 Connection, Executor, PgConnection, PgPool, Postgres, Row,
8 pool::PoolConnection,
9 postgres::{PgConnectOptions, PgPoolOptions},
10};
11use uuid::Uuid;
12
13use crate::{common::statement::postgres, util::get_db_name};
14
15use super::{
16 super::{
17 common::error::sqlx::{BuildError, ConnectionError, PoolError, QueryError},
18 error::Error as BackendError,
19 r#trait::Backend,
20 },
21 r#trait::{PostgresBackend, PostgresBackendWrapper},
22};
23
/// Boxed user-supplied async callback: receives a privileged connection to a
/// freshly created database and returns it after seeding entities
/// (tables, fixtures, …). Stored boxed so the backend itself stays non-generic.
type CreateEntities = dyn Fn(PgConnection) -> Pin<Box<dyn Future<Output = PgConnection> + Send + 'static>>
    + Send
    + Sync
    + 'static;
28
/// Postgres backend implemented on top of `sqlx` (`PgPool`/`PgConnection`).
pub struct SqlxPostgresBackend {
    /// Privileged (admin) connection options; cloned and re-targeted per database.
    privileged_opts: PgConnectOptions,
    /// Lazily-connecting pool against the default database, using the privileged options.
    default_pool: PgPool,
    /// Privileged per-database connections parked between backend operations,
    /// keyed by database id (see `put_database_connection`/`get_database_connection`).
    db_conns: Mutex<HashMap<Uuid, PgConnection>>,
    /// Factory producing the pool options used for restricted per-database pools.
    create_restricted_pool: Box<dyn Fn() -> PgPoolOptions + Send + Sync + 'static>,
    /// Async callback that seeds entities in each newly created database.
    create_entities: Box<CreateEntities>,
    /// When `true`, previously created databases are dropped on init.
    drop_previous_databases_flag: bool,
}
38
39impl SqlxPostgresBackend {
40 pub fn new(
70 privileged_options: PgConnectOptions,
71 create_privileged_pool: impl Fn() -> PgPoolOptions,
72 create_restricted_pool: impl Fn() -> PgPoolOptions + Send + Sync + 'static,
73 create_entities: impl Fn(
74 PgConnection,
75 )
76 -> Pin<Box<dyn Future<Output = PgConnection> + Send + 'static>>
77 + Send
78 + Sync
79 + 'static,
80 ) -> Self {
81 let pool_opts = create_privileged_pool();
82 let default_pool = pool_opts.connect_lazy_with(privileged_options.clone());
83
84 Self {
85 privileged_opts: privileged_options,
86 default_pool,
87 db_conns: Mutex::new(HashMap::new()),
88 create_restricted_pool: Box::new(create_restricted_pool),
89 create_entities: Box::new(create_entities),
90 drop_previous_databases_flag: true,
91 }
92 }
93
94 #[must_use]
96 pub fn drop_previous_databases(self, value: bool) -> Self {
97 Self {
98 drop_previous_databases_flag: value,
99 ..self
100 }
101 }
102}
103
104#[async_trait]
105impl<'pool> PostgresBackend<'pool> for SqlxPostgresBackend {
106 type Connection = PgConnection;
107 type PooledConnection = PoolConnection<Postgres>;
108 type Pool = PgPool;
109
110 type BuildError = BuildError;
111 type PoolError = PoolError;
112 type ConnectionError = ConnectionError;
113 type QueryError = QueryError;
114
115 async fn execute_query(&self, query: &str, conn: &mut PgConnection) -> Result<(), QueryError> {
116 conn.execute(query).await?;
117 Ok(())
118 }
119
120 async fn batch_execute_query<'a>(
121 &self,
122 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
123 conn: &mut PgConnection,
124 ) -> Result<(), QueryError> {
125 let query = query.into_iter().collect::<Vec<_>>().join(";");
126 self.execute_query(query.as_str(), conn).await
127 }
128
129 async fn get_default_connection(&'pool self) -> Result<PoolConnection<Postgres>, PoolError> {
130 self.default_pool.acquire().await.map_err(Into::into)
131 }
132
133 async fn establish_privileged_database_connection(
134 &self,
135 db_id: Uuid,
136 ) -> Result<PgConnection, ConnectionError> {
137 let db_name = get_db_name(db_id);
138 let opts = self.privileged_opts.clone().database(db_name.as_str());
139 PgConnection::connect_with(&opts).await.map_err(Into::into)
140 }
141
142 async fn establish_restricted_database_connection(
143 &self,
144 db_id: Uuid,
145 ) -> Result<PgConnection, ConnectionError> {
146 let db_name = get_db_name(db_id);
147 let db_name = db_name.as_str();
148 let opts = self
149 .privileged_opts
150 .clone()
151 .username(db_name)
152 .password(db_name)
153 .database(db_name);
154 PgConnection::connect_with(&opts).await.map_err(Into::into)
155 }
156
157 fn put_database_connection(&self, db_id: Uuid, conn: PgConnection) {
158 self.db_conns.lock().insert(db_id, conn);
159 }
160
161 fn get_database_connection(&self, db_id: Uuid) -> PgConnection {
162 self.db_conns
163 .lock()
164 .remove(&db_id)
165 .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
166 }
167
168 async fn get_previous_database_names(
169 &self,
170 conn: &mut PgConnection,
171 ) -> Result<Vec<String>, QueryError> {
172 conn.fetch_all(postgres::GET_DATABASE_NAMES)
173 .await?
174 .iter()
175 .map(|row| row.try_get(0))
176 .collect::<Result<Vec<_>, _>>()
177 .map_err(Into::into)
178 }
179
180 async fn create_entities(&self, conn: PgConnection) -> Option<PgConnection> {
181 Some((self.create_entities)(conn).await)
182 }
183
184 async fn create_connection_pool(&self, db_id: Uuid) -> Result<PgPool, BuildError> {
185 let db_name = get_db_name(db_id);
186 let db_name = db_name.as_str();
187 let opts = self
188 .privileged_opts
189 .clone()
190 .database(db_name)
191 .username(db_name)
192 .password(db_name);
193 let pool = (self.create_restricted_pool)().connect_lazy_with(opts);
194 Ok(pool)
195 }
196
197 async fn get_table_names(&self, conn: &mut PgConnection) -> Result<Vec<String>, QueryError> {
198 conn.fetch_all(postgres::GET_TABLE_NAMES)
199 .await?
200 .iter()
201 .map(|row| row.try_get(0))
202 .collect::<Result<Vec<_>, _>>()
203 .map_err(Into::into)
204 }
205
206 fn get_drop_previous_databases(&self) -> bool {
207 self.drop_previous_databases_flag
208 }
209}
210
/// Concrete error type produced by every `Backend` method below.
type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
212
213#[async_trait]
214impl Backend for SqlxPostgresBackend {
215 type Pool = PgPool;
216
217 type BuildError = BuildError;
218 type PoolError = PoolError;
219 type ConnectionError = ConnectionError;
220 type QueryError = QueryError;
221
222 async fn init(&self) -> Result<(), BError> {
223 PostgresBackendWrapper::new(self).init().await
224 }
225
226 async fn create(&self, db_id: uuid::Uuid, restrict_privileges: bool) -> Result<PgPool, BError> {
227 PostgresBackendWrapper::new(self)
228 .create(db_id, restrict_privileges)
229 .await
230 }
231
232 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
233 PostgresBackendWrapper::new(self).clean(db_id).await
234 }
235
236 async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> {
237 PostgresBackendWrapper::new(self)
238 .drop(db_id, is_restricted)
239 .await
240 }
241}
242
// Integration tests — require a running Postgres with a `postgres`/`postgres`
// superuser account; isolation across tests is handled by `PgDropLock`
// read/write locking and the shared tokio runtime (`tokio_shared_rt`).
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use futures::{StreamExt, future::join_all};
    use sqlx::{
        Executor, FromRow, Row,
        postgres::{PgConnectOptions, PgPoolOptions},
        query, query_as,
    };
    use tokio_shared_rt::test;

    use crate::{
        r#async::{
            backend::postgres::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_backend_drops_database, test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
    };

    use super::{
        super::r#trait::tests::{
            PgDropLock, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases,
        },
        SqlxPostgresBackend,
    };

    // Builds a backend against the local Postgres; when `with_table` is true the
    // entity callback creates the test schema (the `book` table) in each database.
    fn create_backend(with_table: bool) -> SqlxPostgresBackend {
        SqlxPostgresBackend::new(
            PgConnectOptions::new()
                .username("postgres")
                .password("postgres"),
            PgPoolOptions::new,
            PgPoolOptions::new,
            {
                move |mut conn| {
                    if with_table {
                        Box::pin(async move {
                            // Run all seed statements as one `;`-joined batch and
                            // fail the test immediately if any statement errors.
                            conn.execute_many(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
                                .collect::<Vec<_>>()
                                .await
                                .drain(..)
                                .collect::<Result<Vec<_>, _>>()
                                .unwrap();
                            conn
                        })
                    } else {
                        // No schema requested: hand the connection back untouched.
                        Box::pin(async { conn })
                    }
                }
            },
        )
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        // Exercises default (drop), explicit enable and explicit disable.
        test_backend_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    // Each pulled database must see only its own writes.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromRow, Eq, PartialEq, Debug)]
        struct Book {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // Insert a distinct row into each database concurrently.
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        query("INSERT INTO book (title) VALUES ($1)")
                            .bind(format!("Title {i}"))
                            .execute(&***conn_pool)
                            .await
                            .unwrap();
                    }),
            )
            .await;

            // Each database must contain exactly its own row and nothing else.
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        assert_eq!(
                            query_as::<_, Book>("SELECT title FROM book")
                                .fetch_all(&***conn_pool)
                                .await
                                .unwrap(),
                            vec![Book {
                                title: format!("Title {i}")
                            }]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    // Restricted databases must reject DDL but accept DML.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.acquire().await.unwrap();

            for stmt in DDL_STATEMENTS {
                assert!(conn.execute(stmt).await.is_err());
            }

            for stmt in DML_STATEMENTS {
                assert!(conn.execute(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Unrestricted (mutable) databases must accept both DML and DDL.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute(stmt).await.is_ok());
                }
            }

            // Fresh database per DDL statement so statements don't conflict.
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                assert!(conn.execute(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Databases returned to the pool must be cleaned before being handed out again.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // Fresh databases start empty.
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;

                // Dirty them so the clean step has something to remove.
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    query("INSERT INTO book (title) VALUES ($1)")
                        .bind("Title")
                        .execute(&***conn_pool)
                        .await
                        .unwrap();
                }))
                .await;
            } // pools dropped here — databases are returned and cleaned

            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // Re-pulled databases must be empty again.
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false);
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}