db_pool/async/backend/postgres/
tokio_postgres.rs1use std::{borrow::Cow, collections::HashMap, convert::Into, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use tokio_postgres::{Client, Config, NoTls};
7use uuid::Uuid;
8
9use crate::{common::statement::postgres, util::get_db_name};
10
11use super::{
12 super::{
13 common::{
14 error::tokio_postgres::{ConnectionError, QueryError},
15 pool::tokio_postgres::r#trait::TokioPostgresPoolAssociation,
16 },
17 error::Error as BackendError,
18 r#trait::Backend,
19 },
20 r#trait::{PostgresBackend, PostgresBackendWrapper},
21};
22
23type CreateEntities = dyn Fn(Client) -> Pin<Box<dyn Future<Output = Client> + Send + 'static>>
24 + Send
25 + Sync
26 + 'static;
27
28pub struct TokioPostgresBackend<P: TokioPostgresPoolAssociation> {
30 privileged_config: Config,
31 default_pool: P::Pool,
32 db_conns: Mutex<HashMap<Uuid, Client>>,
33 create_restricted_pool: Box<dyn Fn() -> P::Builder + Send + Sync + 'static>,
34 create_entities: Box<CreateEntities>,
35 drop_previous_databases_flag: bool,
36}
37
38impl<P: TokioPostgresPoolAssociation> TokioPostgresBackend<P> {
39 pub async fn new(
77 privileged_config: Config,
78 create_privileged_pool: impl Fn() -> P::Builder,
79 create_restricted_pool: impl Fn() -> P::Builder + Send + Sync + 'static,
80 create_entities: impl Fn(Client) -> Pin<Box<dyn Future<Output = Client> + Send + 'static>>
81 + Send
82 + Sync
83 + 'static,
84 ) -> Result<Self, P::BuildError> {
85 let builder = create_privileged_pool();
86 let default_pool = P::build_pool(builder, privileged_config.clone()).await?;
87
88 Ok(Self {
89 privileged_config,
90 default_pool,
91 db_conns: Mutex::new(HashMap::new()),
92 create_entities: Box::new(create_entities),
93 create_restricted_pool: Box::new(create_restricted_pool),
94 drop_previous_databases_flag: true,
95 })
96 }
97
98 #[must_use]
100 pub fn drop_previous_databases(self, value: bool) -> Self {
101 Self {
102 drop_previous_databases_flag: value,
103 ..self
104 }
105 }
106}
107
108#[async_trait]
109impl<'pool, P: TokioPostgresPoolAssociation> PostgresBackend<'pool> for TokioPostgresBackend<P> {
110 type Connection = Client;
111 type PooledConnection = P::PooledConnection<'pool>;
112 type Pool = P::Pool;
113
114 type BuildError = P::BuildError;
115 type PoolError = P::PoolError;
116 type ConnectionError = ConnectionError;
117 type QueryError = QueryError;
118
119 async fn execute_query(&self, query: &str, conn: &mut Client) -> Result<(), QueryError> {
120 conn.execute(query, &[]).await?;
121 Ok(())
122 }
123
124 async fn batch_execute_query<'a>(
125 &self,
126 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
127 conn: &mut Client,
128 ) -> Result<(), QueryError> {
129 let query = query.into_iter().collect::<Vec<_>>().join(";");
130 conn.batch_execute(query.as_str()).await?;
131 Ok(())
132 }
133
134 async fn get_default_connection(
135 &'pool self,
136 ) -> Result<P::PooledConnection<'pool>, P::PoolError> {
137 P::get_connection(&self.default_pool).await
138 }
139
140 async fn establish_privileged_database_connection(
141 &self,
142 db_id: Uuid,
143 ) -> Result<Client, ConnectionError> {
144 let mut config = self.privileged_config.clone();
145 let db_name = get_db_name(db_id);
146 config.dbname(db_name.as_str());
147 let (client, connection) = config.connect(NoTls).await?;
148 tokio::spawn(connection);
149 Ok(client)
150 }
151
152 async fn establish_restricted_database_connection(
153 &self,
154 db_id: Uuid,
155 ) -> Result<Client, ConnectionError> {
156 let mut config = self.privileged_config.clone();
157 let db_name = get_db_name(db_id);
158 let db_name = db_name.as_str();
159 config.user(db_name).password(db_name).dbname(db_name);
160 let (client, connection) = config.connect(NoTls).await?;
161 tokio::spawn(connection);
162 Ok(client)
163 }
164
165 fn put_database_connection(&self, db_id: Uuid, conn: Client) {
166 self.db_conns.lock().insert(db_id, conn);
167 }
168
169 fn get_database_connection(&self, db_id: Uuid) -> Client {
170 self.db_conns
171 .lock()
172 .remove(&db_id)
173 .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
174 }
175
176 async fn get_previous_database_names(
177 &self,
178 conn: &mut Client,
179 ) -> Result<Vec<String>, QueryError> {
180 conn.query(postgres::GET_DATABASE_NAMES, &[])
181 .await
182 .map(|rows| rows.iter().map(|row| row.get(0)).collect())
183 .map_err(Into::into)
184 }
185
186 async fn create_entities(&self, conn: Client) -> Client {
187 (self.create_entities)(conn).await
188 }
189
190 async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
191 let db_name = get_db_name(db_id);
192 let db_name = db_name.as_str();
193 let mut config = self.privileged_config.clone();
194 config.dbname(db_name);
195 config.user(db_name);
196 config.password(db_name);
197 let builder = (self.create_restricted_pool)();
198 P::build_pool(builder, config).await
199 }
200
201 async fn get_table_names(
202 &self,
203 privileged_conn: &mut Client,
204 ) -> Result<Vec<String>, QueryError> {
205 privileged_conn
206 .query(postgres::GET_TABLE_NAMES, &[])
207 .await
208 .map(|rows| rows.iter().map(|row| row.get(0)).collect())
209 .map_err(Into::into)
210 }
211
212 fn get_drop_previous_databases(&self) -> bool {
213 self.drop_previous_databases_flag
214 }
215}
216
217type BError<BuildError, PoolError> =
218 BackendError<BuildError, PoolError, ConnectionError, QueryError>;
219
220#[async_trait]
221impl<P: TokioPostgresPoolAssociation> Backend for TokioPostgresBackend<P> {
222 type Pool = P::Pool;
223
224 type BuildError = P::BuildError;
225 type PoolError = P::PoolError;
226 type ConnectionError = ConnectionError;
227 type QueryError = QueryError;
228
229 async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
230 PostgresBackendWrapper::new(self).init().await
231 }
232
233 async fn create(
234 &self,
235 db_id: uuid::Uuid,
236 restrict_privileges: bool,
237 ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
238 PostgresBackendWrapper::new(self)
239 .create(db_id, restrict_privileges)
240 .await
241 }
242
243 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
244 PostgresBackendWrapper::new(self).clean(db_id).await
245 }
246
247 async fn drop(
248 &self,
249 db_id: uuid::Uuid,
250 is_restricted: bool,
251 ) -> Result<(), BError<P::BuildError, P::PoolError>> {
252 PostgresBackendWrapper::new(self)
253 .drop(db_id, is_restricted)
254 .await
255 }
256}
257
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use bb8::Pool;
    use futures::future::join_all;
    use tokio_postgres::Config;
    use tokio_shared_rt::test;

    use crate::{
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::{
                common::pool::tokio_postgres::bb8::TokioPostgresBb8,
                postgres::r#trait::tests::{
                    test_backend_creates_database_with_unrestricted_privileges,
                    test_backend_drops_database, test_pool_drops_created_unrestricted_database,
                },
            },
            db_pool::DatabasePoolBuilder,
        },
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases, PgDropLock,
        },
        TokioPostgresBackend,
    };

    /// Builds a bb8-backed backend against the local `postgres` superuser.
    ///
    /// When `with_table` is set, the entity callback runs `CREATE_ENTITIES_STATEMENTS`.
    async fn create_backend(with_table: bool) -> TokioPostgresBackend<TokioPostgresBb8> {
        let mut config = Config::new();
        config.host("localhost").user("postgres").password("postgres");

        // The closure must stay inline: its boxed-future return type is deduced from the
        // `new` bound, which lets both branches coerce to the same trait object.
        TokioPostgresBackend::new(config, Pool::builder, Pool::builder, move |conn| {
            if !with_table {
                return Box::pin(async { conn });
            }
            Box::pin(async move {
                conn.batch_execute(&CREATE_ENTITIES_STATEMENTS.join(";"))
                    .await
                    .unwrap();
                conn
            })
        })
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let restricted = true;
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            restricted,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let restricted = false;
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            restricted,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_pool_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // Write a distinct title into each database.
            let inserts = conn_pools
                .iter()
                .enumerate()
                .map(|(i, conn_pool)| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let title = format!("Title {i}");
                    conn.execute("INSERT INTO book (title) VALUES ($1)", &[&title.as_str()])
                        .await
                        .unwrap();
                });
            join_all(inserts).await;

            // Each database must see only the row written to it.
            let checks = conn_pools
                .iter()
                .enumerate()
                .map(|(i, conn_pool)| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let titles = conn
                        .query("SELECT title FROM book", &[])
                        .await
                        .unwrap()
                        .iter()
                        .map(|row| row.get::<_, String>(0))
                        .collect::<Vec<_>>();
                    assert_eq!(titles, vec![format!("Title {i}")]);
                });
            join_all(checks).await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.get().await.unwrap();

            // A restricted database rejects schema changes...
            for stmt in DDL_STATEMENTS {
                let outcome = conn.execute(stmt, &[]).await;
                assert!(outcome.is_err());
            }

            // ...but accepts data manipulation.
            for stmt in DML_STATEMENTS {
                let outcome = conn.execute(stmt, &[]).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // All DML statements run against a single mutable database, which is released
            // at the end of this scope.
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    let outcome = conn.execute(stmt, &[]).await;
                    assert!(outcome.is_ok());
                }
            }

            // Every DDL statement gets its own fresh mutable database.
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                let outcome = conn.execute(stmt, &[]).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // First round: fresh databases are empty; then dirty each one. The pools are
            // returned when this scope ends.
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let count = conn
                        .query_one("SELECT COUNT(*) FROM book", &[])
                        .await
                        .unwrap()
                        .get::<_, i64>(0);
                    assert_eq!(count, 0);
                }))
                .await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    conn.execute("INSERT INTO book (title) VALUES ($1)", &[&"Title"])
                        .await
                        .unwrap();
                }))
                .await;
            }

            // Second round: re-pulled databases must have been cleaned.
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    let count = conn
                        .query_one("SELECT COUNT(*) FROM book", &[])
                        .await
                        .unwrap()
                        .get::<_, i64>(0);
                    assert_eq!(count, 0);
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false).await).await;
    }
}