db_pool/async/backend/postgres/tokio_postgres.rs

use std::{borrow::Cow, collections::HashMap, convert::Into, pin::Pin};

use async_trait::async_trait;
use deadpool_postgres::Manager;
use futures::Future;
use parking_lot::Mutex;
use tokio_postgres::{Client, Config, NoTls};
use uuid::Uuid;

use crate::{common::statement::postgres, util::get_db_name};

use super::{
    super::{
        common::{
            error::tokio_postgres::{ConnectionError, QueryError},
            pool::tokio_postgres::r#trait::TokioPostgresPoolAssociation,
        },
        error::Error as BackendError,
        r#trait::Backend,
    },
    r#trait::{PostgresBackend, PostgresBackendWrapper},
};

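// `CreateEntities` is the seeding hook run against every freshly created
// database. A minimal value satisfying it (sketch, mirroring the pass-through
// branch of the test helper below) is `|conn| Box::pin(async move { conn })`.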
type CreateEntities = dyn Fn(Client) -> Pin<Box<dyn Future<Output = Client> + Send + 'static>>
    + Send
    + Sync
    + 'static;

pub struct TokioPostgresBackend<P: TokioPostgresPoolAssociation> {
    privileged_config: Config,
    default_pool: P::Pool,
    db_conns: Mutex<HashMap<Uuid, Client>>,
    create_restricted_pool: Box<dyn Fn(Manager) -> P::Builder + Send + Sync + 'static>,
    create_entities: Box<CreateEntities>,
    drop_previous_databases_flag: bool,
}

impl<P: TokioPostgresPoolAssociation> TokioPostgresBackend<P> {
    pub async fn new(
        privileged_config: Config,
        create_privileged_pool: impl Fn(Manager) -> P::Builder,
        create_restricted_pool: impl Fn(Manager) -> P::Builder + Send + Sync + 'static,
        create_entities: impl Fn(Client) -> Pin<Box<dyn Future<Output = Client> + Send + 'static>>
            + Send
            + Sync
            + 'static,
    ) -> Result<Self, P::BuildError> {
        let manager = Manager::new(privileged_config.clone(), NoTls);
        let builder = create_privileged_pool(manager);
        let default_pool = P::build_pool(builder, privileged_config.clone()).await?;

        Ok(Self {
            privileged_config,
            default_pool,
            db_conns: Mutex::new(HashMap::new()),
            create_entities: Box::new(create_entities),
            create_restricted_pool: Box::new(create_restricted_pool),
            drop_previous_databases_flag: true,
        })
    }

    #[must_use]
    pub fn drop_previous_databases(self, value: bool) -> Self {
        Self {
            drop_previous_databases_flag: value,
            ..self
        }
    }
}
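
// Usage sketch: constructing a bb8-backed instance. This mirrors the test
// setup below; the host/user/password values and the pass-through
// `create_entities` closure are illustrative assumptions, not project defaults.
//
//     let mut config = Config::new();
//     config.host("localhost").user("postgres").password("postgres");
//     let backend = TokioPostgresBackend::<TokioPostgresBb8>::new(
//         config,
//         |_| bb8::Pool::builder(),
//         |_| bb8::Pool::builder(),
//         |conn| Box::pin(async move { conn }),
//     )
//     .await?
//     .drop_previous_databases(false);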

#[async_trait]
impl<'pool, P: TokioPostgresPoolAssociation> PostgresBackend<'pool> for TokioPostgresBackend<P> {
    type Connection = Client;
    type PooledConnection = P::PooledConnection<'pool>;
    type Pool = P::Pool;

    type BuildError = P::BuildError;
    type PoolError = P::PoolError;
    type ConnectionError = ConnectionError;
    type QueryError = QueryError;

    async fn execute_query(&self, query: &str, conn: &mut Client) -> Result<(), QueryError> {
        conn.execute(query, &[]).await?;
        Ok(())
    }

    async fn batch_execute_query<'a>(
        &self,
        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
        conn: &mut Client,
    ) -> Result<(), QueryError> {
        let query = query.into_iter().collect::<Vec<_>>().join(";");
        conn.batch_execute(query.as_str()).await?;
        Ok(())
    }

    async fn get_default_connection(
        &'pool self,
    ) -> Result<P::PooledConnection<'pool>, P::PoolError> {
        P::get_connection(&self.default_pool).await
    }

    async fn establish_privileged_database_connection(
        &self,
        db_id: Uuid,
    ) -> Result<Client, ConnectionError> {
        let mut config = self.privileged_config.clone();
        let db_name = get_db_name(db_id);
        config.dbname(db_name.as_str());
        let (client, connection) = config.connect(NoTls).await?;
        tokio::spawn(connection);
        Ok(client)
    }

    async fn establish_restricted_database_connection(
        &self,
        db_id: Uuid,
    ) -> Result<Client, ConnectionError> {
        let mut config = self.privileged_config.clone();
        let db_name = get_db_name(db_id);
        let db_name = db_name.as_str();
        config.user(db_name).password(db_name).dbname(db_name);
        let (client, connection) = config.connect(NoTls).await?;
        tokio::spawn(connection);
        Ok(client)
    }

    fn put_database_connection(&self, db_id: Uuid, conn: Client) {
        self.db_conns.lock().insert(db_id, conn);
    }

    fn get_database_connection(&self, db_id: Uuid) -> Client {
        self.db_conns
            .lock()
            .remove(&db_id)
            .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
    }

    async fn get_previous_database_names(
        &self,
        conn: &mut Client,
    ) -> Result<Vec<String>, QueryError> {
        conn.query(postgres::GET_DATABASE_NAMES, &[])
            .await
            .map(|rows| rows.iter().map(|row| row.get(0)).collect())
            .map_err(Into::into)
    }

    async fn create_entities(&self, conn: Client) -> Option<Client> {
        Some((self.create_entities)(conn).await)
    }

    async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
        let db_name = get_db_name(db_id);
        let db_name = db_name.as_str();
        let mut config = self.privileged_config.clone();
        config.dbname(db_name);
        config.user(db_name);
        config.password(db_name);
        let manager = Manager::new(config.clone(), NoTls);
        let builder = (self.create_restricted_pool)(manager);
        P::build_pool(builder, config).await
    }

    async fn get_table_names(
        &self,
        privileged_conn: &mut Client,
    ) -> Result<Vec<String>, QueryError> {
        privileged_conn
            .query(postgres::GET_TABLE_NAMES, &[])
            .await
            .map(|rows| rows.iter().map(|row| row.get(0)).collect())
            .map_err(Into::into)
    }

    fn get_drop_previous_databases(&self) -> bool {
        self.drop_previous_databases_flag
    }
}

type BError<BuildError, PoolError> =
    BackendError<BuildError, PoolError, ConnectionError, QueryError>;

#[async_trait]
impl<P: TokioPostgresPoolAssociation> Backend for TokioPostgresBackend<P> {
    type Pool = P::Pool;

    type BuildError = P::BuildError;
    type PoolError = P::PoolError;
    type ConnectionError = ConnectionError;
    type QueryError = QueryError;

    async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self).init().await
    }

    async fn create(
        &self,
        db_id: uuid::Uuid,
        restrict_privileges: bool,
    ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self)
            .create(db_id, restrict_privileges)
            .await
    }

    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self).clean(db_id).await
    }

    async fn drop(
        &self,
        db_id: uuid::Uuid,
        is_restricted: bool,
    ) -> Result<(), BError<P::BuildError, P::PoolError>> {
        PostgresBackendWrapper::new(self)
            .drop(db_id, is_restricted)
            .await
    }
}
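
// Consumption sketch: callers do not normally invoke `init`/`create`/`clean`/
// `drop` directly; they go through the `DatabasePoolBuilder` facade, as the
// tests below do. A minimal flow (assuming a configured `backend`) might be:
//
//     let db_pool = backend.create_database_pool().await?;
//     let conn_pool = db_pool.pull_immutable().await;
//     let conn = conn_pool.get().await?;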

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use bb8::Pool;
    use futures::future::join_all;
    use tokio_postgres::Config;
    use tokio_shared_rt::test;

    use crate::{
        r#async::{
            backend::{
                common::pool::tokio_postgres::bb8::TokioPostgresBb8,
                postgres::r#trait::tests::{
                    test_backend_creates_database_with_unrestricted_privileges,
                    test_backend_drops_database, test_pool_drops_created_unrestricted_database,
                },
            },
            db_pool::DatabasePoolBuilder,
        },
        common::statement::postgres::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
    };

    use super::{
        super::r#trait::tests::{
            PgDropLock, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases,
        },
        TokioPostgresBackend,
    };

    async fn create_backend(with_table: bool) -> TokioPostgresBackend<TokioPostgresBb8> {
        let mut config = Config::new();
        config
            .host("localhost")
            .user("postgres")
            .password("postgres");
        TokioPostgresBackend::new(config, |_| Pool::builder(), |_| Pool::builder(), {
            move |conn| {
                if with_table {
                    Box::pin(async move {
                        conn.batch_execute(&CREATE_ENTITIES_STATEMENTS.join(";"))
                            .await
                            .unwrap();
                        conn
                    })
                } else {
                    Box::pin(async { conn })
                }
            }
        })
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        conn.execute(
                            "INSERT INTO book (title) VALUES ($1)",
                            &[&format!("Title {i}").as_str()],
                        )
                        .await
                        .unwrap();
                    }),
            )
            .await;

            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        let conn = &mut conn_pool.get().await.unwrap();
                        assert_eq!(
                            conn.query("SELECT title FROM book", &[])
                                .await
                                .unwrap()
                                .iter()
                                .map(|row| row.get::<_, String>(0))
                                .collect::<Vec<_>>(),
                            vec![format!("Title {i}")]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.get().await.unwrap();

            for stmt in DDL_STATEMENTS {
                assert!(conn.execute(stmt, &[]).await.is_err());
            }

            for stmt in DML_STATEMENTS {
                assert!(conn.execute(stmt, &[]).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute(stmt, &[]).await.is_ok());
                }
            }

            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.get().await.unwrap();
                assert!(conn.execute(stmt, &[]).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        conn.query_one("SELECT COUNT(*) FROM book", &[])
                            .await
                            .unwrap()
                            .get::<_, i64>(0),
                        0
                    );
                }))
                .await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    conn.execute("INSERT INTO book (title) VALUES ($1)", &[&"Title"])
                        .await
                        .unwrap();
                }))
                .await;
            }

            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    let conn = &mut conn_pool.get().await.unwrap();
                    assert_eq!(
                        conn.query_one("SELECT COUNT(*) FROM book", &[])
                            .await
                            .unwrap()
                            .get::<_, i64>(0),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false).await;
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false).await;
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}