db_pool/async/backend/postgres/
sqlx.rs1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use sqlx::{
7 pool::PoolConnection,
8 postgres::{PgConnectOptions, PgPoolOptions},
9 Connection, Executor, PgConnection, PgPool, Postgres, Row,
10};
11use uuid::Uuid;
12
13use crate::{common::statement::postgres, util::get_db_name};
14
15use super::{
16 super::{
17 common::error::sqlx::{BuildError, ConnectionError, PoolError, QueryError},
18 error::Error as BackendError,
19 r#trait::Backend,
20 },
21 r#trait::{PostgresBackend, PostgresBackendWrapper},
22};
23
24type CreateEntities = dyn Fn(PgConnection) -> Pin<Box<dyn Future<Output = PgConnection> + Send + 'static>>
25 + Send
26 + Sync
27 + 'static;
28
29pub struct SqlxPostgresBackend {
31 privileged_opts: PgConnectOptions,
32 default_pool: PgPool,
33 db_conns: Mutex<HashMap<Uuid, PgConnection>>,
34 create_restricted_pool: Box<dyn Fn() -> PgPoolOptions + Send + Sync + 'static>,
35 create_entities: Box<CreateEntities>,
36 drop_previous_databases_flag: bool,
37}
38
39impl SqlxPostgresBackend {
40 pub fn new(
70 privileged_options: PgConnectOptions,
71 create_privileged_pool: impl Fn() -> PgPoolOptions,
72 create_restricted_pool: impl Fn() -> PgPoolOptions + Send + Sync + 'static,
73 create_entities: impl Fn(PgConnection) -> Pin<Box<dyn Future<Output = PgConnection> + Send + 'static>>
74 + Send
75 + Sync
76 + 'static,
77 ) -> Self {
78 let pool_opts = create_privileged_pool();
79 let default_pool = pool_opts.connect_lazy_with(privileged_options.clone());
80
81 Self {
82 privileged_opts: privileged_options,
83 default_pool,
84 db_conns: Mutex::new(HashMap::new()),
85 create_restricted_pool: Box::new(create_restricted_pool),
86 create_entities: Box::new(create_entities),
87 drop_previous_databases_flag: true,
88 }
89 }
90
91 #[must_use]
93 pub fn drop_previous_databases(self, value: bool) -> Self {
94 Self {
95 drop_previous_databases_flag: value,
96 ..self
97 }
98 }
99}
100
101#[async_trait]
102impl<'pool> PostgresBackend<'pool> for SqlxPostgresBackend {
103 type Connection = PgConnection;
104 type PooledConnection = PoolConnection<Postgres>;
105 type Pool = PgPool;
106
107 type BuildError = BuildError;
108 type PoolError = PoolError;
109 type ConnectionError = ConnectionError;
110 type QueryError = QueryError;
111
112 async fn execute_query(&self, query: &str, conn: &mut PgConnection) -> Result<(), QueryError> {
113 conn.execute(query).await?;
114 Ok(())
115 }
116
117 async fn batch_execute_query<'a>(
118 &self,
119 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
120 conn: &mut PgConnection,
121 ) -> Result<(), QueryError> {
122 let query = query.into_iter().collect::<Vec<_>>().join(";");
123 self.execute_query(query.as_str(), conn).await
124 }
125
126 async fn get_default_connection(&'pool self) -> Result<PoolConnection<Postgres>, PoolError> {
127 self.default_pool.acquire().await.map_err(Into::into)
128 }
129
130 async fn establish_privileged_database_connection(
131 &self,
132 db_id: Uuid,
133 ) -> Result<PgConnection, ConnectionError> {
134 let db_name = get_db_name(db_id);
135 let opts = self.privileged_opts.clone().database(db_name.as_str());
136 PgConnection::connect_with(&opts).await.map_err(Into::into)
137 }
138
139 async fn establish_restricted_database_connection(
140 &self,
141 db_id: Uuid,
142 ) -> Result<PgConnection, ConnectionError> {
143 let db_name = get_db_name(db_id);
144 let db_name = db_name.as_str();
145 let opts = self
146 .privileged_opts
147 .clone()
148 .username(db_name)
149 .password(db_name)
150 .database(db_name);
151 PgConnection::connect_with(&opts).await.map_err(Into::into)
152 }
153
154 fn put_database_connection(&self, db_id: Uuid, conn: PgConnection) {
155 self.db_conns.lock().insert(db_id, conn);
156 }
157
158 fn get_database_connection(&self, db_id: Uuid) -> PgConnection {
159 self.db_conns
160 .lock()
161 .remove(&db_id)
162 .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
163 }
164
165 async fn get_previous_database_names(
166 &self,
167 conn: &mut PgConnection,
168 ) -> Result<Vec<String>, QueryError> {
169 conn.fetch_all(postgres::GET_DATABASE_NAMES)
170 .await?
171 .iter()
172 .map(|row| row.try_get(0))
173 .collect::<Result<Vec<_>, _>>()
174 .map_err(Into::into)
175 }
176
177 async fn create_entities(&self, conn: PgConnection) -> PgConnection {
178 (self.create_entities)(conn).await
179 }
180
181 async fn create_connection_pool(&self, db_id: Uuid) -> Result<PgPool, BuildError> {
182 let db_name = get_db_name(db_id);
183 let db_name = db_name.as_str();
184 let opts = self
185 .privileged_opts
186 .clone()
187 .database(db_name)
188 .username(db_name)
189 .password(db_name);
190 let pool = (self.create_restricted_pool)().connect_lazy_with(opts);
191 Ok(pool)
192 }
193
194 async fn get_table_names(&self, conn: &mut PgConnection) -> Result<Vec<String>, QueryError> {
195 conn.fetch_all(postgres::GET_TABLE_NAMES)
196 .await?
197 .iter()
198 .map(|row| row.try_get(0))
199 .collect::<Result<Vec<_>, _>>()
200 .map_err(Into::into)
201 }
202
203 fn get_drop_previous_databases(&self) -> bool {
204 self.drop_previous_databases_flag
205 }
206}
207
208type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
209
210#[async_trait]
211impl Backend for SqlxPostgresBackend {
212 type Pool = PgPool;
213
214 type BuildError = BuildError;
215 type PoolError = PoolError;
216 type ConnectionError = ConnectionError;
217 type QueryError = QueryError;
218
219 async fn init(&self) -> Result<(), BError> {
220 PostgresBackendWrapper::new(self).init().await
221 }
222
223 async fn create(&self, db_id: uuid::Uuid, restrict_privileges: bool) -> Result<PgPool, BError> {
224 PostgresBackendWrapper::new(self)
225 .create(db_id, restrict_privileges)
226 .await
227 }
228
229 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
230 PostgresBackendWrapper::new(self).clean(db_id).await
231 }
232
233 async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> {
234 PostgresBackendWrapper::new(self)
235 .drop(db_id, is_restricted)
236 .await
237 }
238}
239
240#[cfg(test)]
241mod tests {
242 #![allow(clippy::unwrap_used, clippy::needless_return)]
243
244 use futures::{future::join_all, StreamExt};
245 use sqlx::{
246 postgres::{PgConnectOptions, PgPoolOptions},
247 query, query_as, Executor, FromRow, Row,
248 };
249 use tokio_shared_rt::test;
250
251 use crate::{
252 common::statement::postgres::tests::{
253 CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
254 },
255 r#async::{
256 backend::postgres::r#trait::tests::{
257 test_backend_creates_database_with_unrestricted_privileges,
258 test_backend_drops_database, test_pool_drops_created_unrestricted_database,
259 },
260 db_pool::DatabasePoolBuilder,
261 },
262 };
263
264 use super::{
265 super::r#trait::tests::{
266 test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
267 test_backend_creates_database_with_restricted_privileges,
268 test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
269 test_pool_drops_previous_databases, PgDropLock,
270 },
271 SqlxPostgresBackend,
272 };
273
274 fn create_backend(with_table: bool) -> SqlxPostgresBackend {
275 SqlxPostgresBackend::new(
276 PgConnectOptions::new()
277 .username("postgres")
278 .password("postgres"),
279 PgPoolOptions::new,
280 PgPoolOptions::new,
281 {
282 move |mut conn| {
283 if with_table {
284 Box::pin(async move {
285 conn.execute_many(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
286 .collect::<Vec<_>>()
287 .await
288 .drain(..)
289 .collect::<Result<Vec<_>, _>>()
290 .unwrap();
291 conn
292 })
293 } else {
294 Box::pin(async { conn })
295 }
296 }
297 },
298 )
299 }
300
301 #[test(flavor = "multi_thread", shared)]
302 async fn backend_drops_previous_databases() {
303 test_backend_drops_previous_databases(
304 create_backend(false),
305 create_backend(false).drop_previous_databases(true),
306 create_backend(false).drop_previous_databases(false),
307 )
308 .await;
309 }
310
311 #[test(flavor = "multi_thread", shared)]
312 async fn backend_creates_database_with_restricted_privileges() {
313 let backend = create_backend(true).drop_previous_databases(false);
314 test_backend_creates_database_with_restricted_privileges(backend).await;
315 }
316
317 #[test(flavor = "multi_thread", shared)]
318 async fn backend_creates_database_with_unrestricted_privileges() {
319 let backend = create_backend(true).drop_previous_databases(false);
320 test_backend_creates_database_with_unrestricted_privileges(backend).await;
321 }
322
323 #[test(flavor = "multi_thread", shared)]
324 async fn backend_cleans_database_with_tables() {
325 let backend = create_backend(true).drop_previous_databases(false);
326 test_backend_cleans_database_with_tables(backend).await;
327 }
328
329 #[test(flavor = "multi_thread", shared)]
330 async fn backend_cleans_database_without_tables() {
331 let backend = create_backend(false).drop_previous_databases(false);
332 test_backend_cleans_database_without_tables(backend).await;
333 }
334
335 #[test(flavor = "multi_thread", shared)]
336 async fn backend_drops_restricted_database() {
337 let backend = create_backend(true).drop_previous_databases(false);
338 test_backend_drops_database(backend, true).await;
339 }
340
341 #[test(flavor = "multi_thread", shared)]
342 async fn backend_drops_unrestricted_database() {
343 let backend = create_backend(true).drop_previous_databases(false);
344 test_backend_drops_database(backend, false).await;
345 }
346
347 #[test(flavor = "multi_thread", shared)]
348 async fn pool_drops_previous_databases() {
349 test_pool_drops_previous_databases(
350 create_backend(false),
351 create_backend(false).drop_previous_databases(true),
352 create_backend(false).drop_previous_databases(false),
353 )
354 .await;
355 }
356
357 #[test(flavor = "multi_thread", shared)]
358 async fn pool_provides_isolated_databases() {
359 #[derive(FromRow, Eq, PartialEq, Debug)]
360 struct Book {
361 title: String,
362 }
363
364 const NUM_DBS: i64 = 3;
365
366 let backend = create_backend(true).drop_previous_databases(false);
367
368 async {
369 let db_pool = backend.create_database_pool().await.unwrap();
370 let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;
371
372 join_all(
374 conn_pools
375 .iter()
376 .enumerate()
377 .map(|(i, conn_pool)| async move {
378 query("INSERT INTO book (title) VALUES ($1)")
379 .bind(format!("Title {i}"))
380 .execute(&***conn_pool)
381 .await
382 .unwrap();
383 }),
384 )
385 .await;
386
387 join_all(
389 conn_pools
390 .iter()
391 .enumerate()
392 .map(|(i, conn_pool)| async move {
393 assert_eq!(
394 query_as::<_, Book>("SELECT title FROM book")
395 .fetch_all(&***conn_pool)
396 .await
397 .unwrap(),
398 vec![Book {
399 title: format!("Title {i}")
400 }]
401 );
402 }),
403 )
404 .await;
405 }
406 .lock_read()
407 .await;
408 }
409
410 #[test(flavor = "multi_thread", shared)]
411 async fn pool_provides_restricted_databases() {
412 let backend = create_backend(true).drop_previous_databases(false);
413
414 async {
415 let db_pool = backend.create_database_pool().await.unwrap();
416
417 let conn_pool = db_pool.pull_immutable().await;
418 let conn = &mut conn_pool.acquire().await.unwrap();
419
420 for stmt in DDL_STATEMENTS {
422 assert!(conn.execute(stmt).await.is_err());
423 }
424
425 for stmt in DML_STATEMENTS {
427 assert!(conn.execute(stmt).await.is_ok());
428 }
429 }
430 .lock_read()
431 .await;
432 }
433
434 #[test(flavor = "multi_thread", shared)]
435 async fn pool_provides_unrestricted_databases() {
436 let backend = create_backend(true).drop_previous_databases(false);
437
438 async {
439 let db_pool = backend.create_database_pool().await.unwrap();
440
441 {
443 let conn_pool = db_pool.create_mutable().await.unwrap();
444 let conn = &mut conn_pool.acquire().await.unwrap();
445 for stmt in DML_STATEMENTS {
446 assert!(conn.execute(stmt).await.is_ok());
447 }
448 }
449
450 for stmt in DDL_STATEMENTS {
452 let conn_pool = db_pool.create_mutable().await.unwrap();
453 let conn = &mut conn_pool.acquire().await.unwrap();
454 assert!(conn.execute(stmt).await.is_ok());
455 }
456 }
457 .lock_read()
458 .await;
459 }
460
461 #[test(flavor = "multi_thread", shared)]
462 async fn pool_provides_clean_databases() {
463 const NUM_DBS: i64 = 3;
464
465 let backend = create_backend(true).drop_previous_databases(false);
466
467 async {
468 let db_pool = backend.create_database_pool().await.unwrap();
469
470 {
472 let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;
473
474 join_all(conn_pools.iter().map(|conn_pool| async move {
476 assert_eq!(
477 query("SELECT COUNT(*) FROM book")
478 .fetch_one(&***conn_pool)
479 .await
480 .unwrap()
481 .get::<i64, _>(0),
482 0
483 );
484 }))
485 .await;
486
487 join_all(conn_pools.iter().map(|conn_pool| async move {
489 query("INSERT INTO book (title) VALUES ($1)")
490 .bind("Title")
491 .execute(&***conn_pool)
492 .await
493 .unwrap();
494 }))
495 .await;
496 }
497
498 {
500 let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;
501
502 join_all(conn_pools.iter().map(|conn_pool| async move {
504 assert_eq!(
505 query("SELECT COUNT(*) FROM book")
506 .fetch_one(&***conn_pool)
507 .await
508 .unwrap()
509 .get::<i64, _>(0),
510 0
511 );
512 }))
513 .await;
514 }
515 }
516 .lock_read()
517 .await;
518 }
519
520 #[test(flavor = "multi_thread", shared)]
521 async fn pool_drops_created_restricted_databases() {
522 let backend = create_backend(false);
523 test_pool_drops_created_restricted_databases(backend).await;
524 }
525
526 #[test(flavor = "multi_thread", shared)]
527 async fn pool_drops_created_unrestricted_database() {
528 let backend = create_backend(false);
529 test_pool_drops_created_unrestricted_database(backend).await;
530 }
531}