1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use sea_orm::{
7 ActiveModelBehavior, ColumnTrait, ConnectOptions, ConnectionTrait, Database,
8 DatabaseConnection, DbErr, DeriveEntityModel, DerivePrimaryKey, DeriveRelation, EntityTrait,
9 EnumIter, FromQueryResult, PrimaryKeyTrait, QueryFilter, QuerySelect,
10};
11use uuid::Uuid;
12
13use crate::{common::config::PrivilegedPostgresConfig, util::get_db_name};
14
15use super::{
16 super::{
17 common::{
18 conn::sea_orm::PooledConnection,
19 error::sea_orm::{BuildError, ConnectionError, PoolError, QueryError},
20 },
21 error::Error as BackendError,
22 r#trait::Backend,
23 },
24 r#trait::{PostgresBackend, PostgresBackendWrapper},
25};
26
27type CreateEntities = dyn Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
28 + Send
29 + Sync
30 + 'static;
31
/// Postgres backend that uses SeaORM [`DatabaseConnection`]s both as its connection
/// type and as its pool type.
pub struct SeaORMPostgresBackend {
    /// Configuration from which every connection URL used by this backend is derived.
    privileged_config: PrivilegedPostgresConfig,
    /// Connection opened in `new` against the configuration's default connection URL.
    default_pool: DatabaseConnection,
    /// Connections keyed by database id; filled by `put_database_connection` and
    /// emptied again by `get_database_connection`.
    db_conns: Mutex<HashMap<Uuid, DatabaseConnection>>,
    /// Callback applied to the [`ConnectOptions`] in `create_connection_pool` right
    /// before connecting.
    create_restricted_pool: Box<dyn for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static>,
    /// Callback invoked (and awaited) by `create_entities` with a clone of the connection.
    create_entities: Box<CreateEntities>,
    /// Value reported by `get_drop_previous_databases`; `new` initializes it to `true`.
    drop_previous_databases_flag: bool,
}
41
42impl SeaORMPostgresBackend {
43 pub async fn new(
83 privileged_config: PrivilegedPostgresConfig,
84 create_privileged_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions),
85 create_restricted_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static,
86 create_entities: impl Fn(
87 DatabaseConnection,
88 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
89 + Send
90 + Sync
91 + 'static,
92 ) -> Result<Self, DbErr> {
93 let mut opts = ConnectOptions::new(privileged_config.default_connection_url());
94 create_privileged_pool(&mut opts);
95 let default_pool = Database::connect(opts).await?;
96
97 Ok(Self {
98 privileged_config,
99 default_pool,
100 db_conns: Mutex::new(HashMap::new()),
101 create_restricted_pool: Box::new(create_restricted_pool),
102 create_entities: Box::new(create_entities),
103 drop_previous_databases_flag: true,
104 })
105 }
106
107 #[must_use]
109 pub fn drop_previous_databases(self, value: bool) -> Self {
110 Self {
111 drop_previous_databases_flag: value,
112 ..self
113 }
114 }
115}
116
117#[async_trait]
118impl<'pool> PostgresBackend<'pool> for SeaORMPostgresBackend {
119 type Connection = DatabaseConnection;
120 type PooledConnection = PooledConnection;
121 type Pool = DatabaseConnection;
122
123 type BuildError = BuildError;
124 type PoolError = PoolError;
125 type ConnectionError = ConnectionError;
126 type QueryError = QueryError;
127
128 async fn execute_query(
129 &self,
130 query: &str,
131 conn: &mut DatabaseConnection,
132 ) -> Result<(), QueryError> {
133 conn.execute_unprepared(query).await?;
134 Ok(())
135 }
136
137 async fn batch_execute_query<'a>(
138 &self,
139 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
140 conn: &mut DatabaseConnection,
141 ) -> Result<(), QueryError> {
142 let query = query.into_iter().collect::<Vec<_>>().join(";");
143 self.execute_query(query.as_str(), conn).await
144 }
145
146 async fn get_default_connection(&'pool self) -> Result<PooledConnection, PoolError> {
147 Ok(self.default_pool.clone().into())
148 }
149
150 async fn establish_privileged_database_connection(
151 &self,
152 db_id: Uuid,
153 ) -> Result<DatabaseConnection, ConnectionError> {
154 let db_name = get_db_name(db_id);
155 let database_url = self
156 .privileged_config
157 .privileged_database_connection_url(db_name.as_str());
158 let opts = ConnectOptions::new(database_url);
159 Database::connect(opts).await.map_err(Into::into)
160 }
161
162 async fn establish_restricted_database_connection(
163 &self,
164 db_id: Uuid,
165 ) -> Result<DatabaseConnection, ConnectionError> {
166 let db_name = get_db_name(db_id);
167 let db_name = db_name.as_str();
168 let database_url = self.privileged_config.restricted_database_connection_url(
169 db_name,
170 Some(db_name),
171 db_name,
172 );
173 let opts = ConnectOptions::new(database_url);
174 Database::connect(opts).await.map_err(Into::into)
175 }
176
177 fn put_database_connection(&self, db_id: Uuid, conn: DatabaseConnection) {
178 self.db_conns.lock().insert(db_id, conn);
179 }
180
181 fn get_database_connection(&self, db_id: Uuid) -> DatabaseConnection {
182 self.db_conns
183 .lock()
184 .remove(&db_id)
185 .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
186 }
187
188 async fn get_previous_database_names(
189 &self,
190 conn: &mut DatabaseConnection,
191 ) -> Result<Vec<String>, QueryError> {
192 #[derive(Clone, Debug, DeriveEntityModel)]
193 #[sea_orm(table_name = "pg_database")]
194 pub struct Model {
195 #[sea_orm(primary_key)]
196 oid: i32,
197 datname: String,
198 }
199
200 #[derive(Debug, EnumIter, DeriveRelation)]
201 pub enum Relation {}
202
203 impl ActiveModelBehavior for ActiveModel {}
204
205 #[derive(FromQueryResult)]
206 struct QueryModel {
207 datname: String,
208 }
209
210 Entity::find()
211 .select_only()
212 .column(Column::Datname)
213 .filter(Column::Datname.like("db_pool_%"))
214 .into_model::<QueryModel>()
215 .all(conn)
216 .await
217 .map(|mut models| models.drain(..).map(|model| model.datname).collect())
218 .map_err(Into::into)
219 }
220
221 async fn create_entities(&self, conn: DatabaseConnection) -> Option<DatabaseConnection> {
222 (self.create_entities)(conn.clone()).await;
223 Some(conn)
224 }
225
226 async fn create_connection_pool(&self, db_id: Uuid) -> Result<DatabaseConnection, BuildError> {
227 let db_name = get_db_name(db_id);
228 let db_name = db_name.as_str();
229 let database_url = self.privileged_config.restricted_database_connection_url(
230 db_name,
231 Some(db_name),
232 db_name,
233 );
234 let mut opts = ConnectOptions::new(database_url);
235 (self.create_restricted_pool)(&mut opts);
236 Database::connect(opts).await.map_err(Into::into)
237 }
238
239 async fn get_table_names(
240 &self,
241 conn: &mut DatabaseConnection,
242 ) -> Result<Vec<String>, QueryError> {
243 #[derive(Clone, Debug, DeriveEntityModel)]
244 #[sea_orm(table_name = "pg_tables")]
245 pub struct Model {
246 schemaname: String,
247 #[sea_orm(primary_key)]
248 tablename: String,
249 }
250
251 #[derive(Debug, EnumIter, DeriveRelation)]
252 pub enum Relation {}
253
254 impl ActiveModelBehavior for ActiveModel {}
255
256 #[derive(FromQueryResult)]
257 struct QueryModel {
258 tablename: String,
259 }
260
261 Entity::find()
262 .select_only()
263 .column(Column::Tablename)
264 .filter(Column::Schemaname.is_not_in(["pg_catalog", "information_schema"]))
265 .into_model::<QueryModel>()
266 .all(conn)
267 .await
268 .map(|mut models| models.drain(..).map(|model| model.tablename).collect())
269 .map_err(Into::into)
270 }
271
    /// Returns the flag set by `new` (initially `true`) or overridden through the
    /// `drop_previous_databases` builder method.
    fn get_drop_previous_databases(&self) -> bool {
        self.drop_previous_databases_flag
    }
275}
276
/// Shorthand for the generic backend error instantiated with this backend's build,
/// pool, connection and query error types.
type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
278
279#[async_trait]
280impl Backend for SeaORMPostgresBackend {
281 type Pool = DatabaseConnection;
282
283 type BuildError = BuildError;
284 type PoolError = PoolError;
285 type ConnectionError = ConnectionError;
286 type QueryError = QueryError;
287
288 async fn init(&self) -> Result<(), BError> {
289 PostgresBackendWrapper::new(self).init().await
290 }
291
292 async fn create(
293 &self,
294 db_id: uuid::Uuid,
295 restrict_privileges: bool,
296 ) -> Result<DatabaseConnection, BError> {
297 PostgresBackendWrapper::new(self)
298 .create(db_id, restrict_privileges)
299 .await
300 }
301
302 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
303 PostgresBackendWrapper::new(self).clean(db_id).await
304 }
305
306 async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> {
307 PostgresBackendWrapper::new(self)
308 .drop(db_id, is_restricted)
309 .await
310 }
311}
312
313#[cfg(test)]
314mod tests {
315 #![allow(clippy::unwrap_used, clippy::needless_return)]
316
317 use dotenvy::dotenv;
318 use futures::future::join_all;
319 use sea_orm::{
320 ActiveModelBehavior, ActiveModelTrait, ConnectionTrait, DeriveEntityModel,
321 DerivePrimaryKey, DeriveRelation, EntityTrait, EnumIter, FromQueryResult, PaginatorTrait,
322 PrimaryKeyTrait, QuerySelect, Set,
323 };
324 use tokio_shared_rt::test;
325
326 use crate::{
327 r#async::{
328 backend::postgres::r#trait::tests::{
329 test_backend_drops_database, test_pool_drops_created_unrestricted_database,
330 },
331 db_pool::DatabasePoolBuilder,
332 },
333 common::{
334 config::PrivilegedPostgresConfig,
335 statement::postgres::tests::{
336 CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
337 },
338 },
339 };
340
341 use super::{
342 super::r#trait::tests::{
343 PgDropLock, test_backend_cleans_database_with_tables,
344 test_backend_cleans_database_without_tables,
345 test_backend_creates_database_with_restricted_privileges,
346 test_backend_creates_database_with_unrestricted_privileges,
347 test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
348 test_pool_drops_previous_databases,
349 },
350 SeaORMPostgresBackend,
351 };
352
    // Minimal SeaORM entity for the `book` table used by the pool tests in this module.
    // NOTE(review): the table is presumably created by `CREATE_ENTITIES_STATEMENTS`;
    // confirm against the statement definitions.
    #[derive(Clone, Debug, DeriveEntityModel)]
    #[sea_orm(table_name = "book")]
    pub struct Model {
        #[sea_orm(primary_key)]
        id: i32,
        title: String,
    }

    // The `book` entity declares no relations.
    #[derive(Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    // Default active-model behavior; nothing is overridden.
    impl ActiveModelBehavior for ActiveModel {}
365
366 async fn create_backend(with_table: bool) -> SeaORMPostgresBackend {
367 dotenv().ok();
368
369 let config = PrivilegedPostgresConfig::from_env().unwrap();
370
371 SeaORMPostgresBackend::new(config, |_| {}, |_| {}, {
372 move |conn| {
373 if with_table {
374 Box::pin(async move {
375 conn.execute_unprepared(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
376 .await
377 .unwrap();
378 })
379 } else {
380 Box::pin(async {})
381 }
382 }
383 })
384 .await
385 .unwrap()
386 }
387
388 #[test(flavor = "multi_thread", shared)]
389 async fn backend_drops_previous_databases() {
390 test_backend_drops_previous_databases(
391 create_backend(false).await,
392 create_backend(false).await.drop_previous_databases(true),
393 create_backend(false).await.drop_previous_databases(false),
394 )
395 .await;
396 }
397
398 #[test(flavor = "multi_thread", shared)]
399 async fn backend_creates_database_with_restricted_privileges() {
400 let backend = create_backend(true).await.drop_previous_databases(false);
401 test_backend_creates_database_with_restricted_privileges(backend).await;
402 }
403
404 #[test(flavor = "multi_thread", shared)]
405 async fn backend_creates_database_with_unrestricted_privileges() {
406 let backend = create_backend(true).await.drop_previous_databases(false);
407 test_backend_creates_database_with_unrestricted_privileges(backend).await;
408 }
409
410 #[test(flavor = "multi_thread", shared)]
411 async fn backend_cleans_database_with_tables() {
412 let backend = create_backend(true).await.drop_previous_databases(false);
413 test_backend_cleans_database_with_tables(backend).await;
414 }
415
416 #[test(flavor = "multi_thread", shared)]
417 async fn backend_cleans_database_without_tables() {
418 let backend = create_backend(false).await.drop_previous_databases(false);
419 test_backend_cleans_database_without_tables(backend).await;
420 }
421
422 #[test(flavor = "multi_thread", shared)]
423 async fn backend_drops_restricted_database() {
424 let backend = create_backend(true).await.drop_previous_databases(false);
425 test_backend_drops_database(backend, true).await;
426 }
427
428 #[test(flavor = "multi_thread", shared)]
429 async fn backend_drops_unrestricted_database() {
430 let backend = create_backend(true).await.drop_previous_databases(false);
431 test_backend_drops_database(backend, false).await;
432 }
433
434 #[test(flavor = "multi_thread", shared)]
435 async fn pool_drops_previous_databases() {
436 test_pool_drops_previous_databases(
437 create_backend(false).await,
438 create_backend(false).await.drop_previous_databases(true),
439 create_backend(false).await.drop_previous_databases(false),
440 )
441 .await;
442 }
443
444 #[test(flavor = "multi_thread", shared)]
445 async fn pool_provides_isolated_databases() {
446 #[derive(FromQueryResult, Eq, PartialEq, Debug)]
447 struct QueryModel {
448 title: String,
449 }
450
451 const NUM_DBS: i64 = 3;
452
453 let backend = create_backend(true).await.drop_previous_databases(false);
454
455 async {
456 let db_pool = backend.create_database_pool().await.unwrap();
457 let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;
458
459 join_all(conns.iter().enumerate().map(|(i, conn)| async move {
461 let book = ActiveModel {
462 title: Set(format!("Title {i}")),
463 ..Default::default()
464 };
465 book.insert(&***conn).await.unwrap();
466 }))
467 .await;
468
469 join_all(conns.iter().enumerate().map(|(i, conn)| async move {
471 assert_eq!(
472 Entity::find()
473 .select_only()
474 .column(Column::Title)
475 .into_model::<QueryModel>()
476 .all(&***conn)
477 .await
478 .unwrap(),
479 vec![QueryModel {
480 title: format!("Title {i}")
481 }]
482 );
483 }))
484 .await;
485 }
486 .lock_read()
487 .await;
488 }
489
490 #[test(flavor = "multi_thread", shared)]
491 async fn pool_provides_restricted_databases() {
492 let backend = create_backend(true).await.drop_previous_databases(false);
493
494 async {
495 let db_pool = backend.create_database_pool().await.unwrap();
496 let conn = db_pool.pull_immutable().await;
497
498 for stmt in DDL_STATEMENTS {
500 assert!(conn.execute_unprepared(stmt).await.is_err());
501 }
502
503 for stmt in DML_STATEMENTS {
505 assert!(conn.execute_unprepared(stmt).await.is_ok());
506 }
507 }
508 .lock_read()
509 .await;
510 }
511
512 #[test(flavor = "multi_thread", shared)]
513 async fn pool_provides_unrestricted_databases() {
514 let backend = create_backend(true).await.drop_previous_databases(false);
515
516 async {
517 let db_pool = backend.create_database_pool().await.unwrap();
518
519 {
521 let conn = db_pool.create_mutable().await.unwrap();
522 for stmt in DML_STATEMENTS {
523 assert!(conn.execute_unprepared(stmt).await.is_ok());
524 }
525 }
526
527 for stmt in DDL_STATEMENTS {
529 let conn = db_pool.create_mutable().await.unwrap();
530 assert!(conn.execute_unprepared(stmt).await.is_ok());
531 }
532 }
533 .lock_read()
534 .await;
535 }
536
537 #[test(flavor = "multi_thread", shared)]
538 async fn pool_provides_clean_databases() {
539 const NUM_DBS: i64 = 3;
540
541 let backend = create_backend(true).await.drop_previous_databases(false);
542
543 async {
544 let db_pool = backend.create_database_pool().await.unwrap();
545
546 {
548 let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;
549
550 join_all(conns.iter().map(|conn| async move {
552 assert_eq!(Entity::find().count(&***conn).await.unwrap(), 0);
553 }))
554 .await;
555
556 join_all(conns.iter().map(|conn| async move {
558 let book = ActiveModel {
559 title: Set("Title".to_owned()),
560 ..Default::default()
561 };
562 book.insert(&***conn).await.unwrap();
563 }))
564 .await;
565 }
566
567 {
569 let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;
570
571 join_all(conns.iter().map(|conn| async move {
573 assert_eq!(Entity::find().count(&***conn).await.unwrap(), 0);
574 }))
575 .await;
576 }
577 }
578 .lock_read()
579 .await;
580 }
581
582 #[test(flavor = "multi_thread", shared)]
583 async fn pool_drops_created_restricted_databases() {
584 let backend = create_backend(false).await;
585 test_pool_drops_created_restricted_databases(backend).await;
586 }
587
588 #[test(flavor = "multi_thread", shared)]
589 async fn pool_drops_created_unrestricted_database() {
590 let backend = create_backend(false).await;
591 test_pool_drops_created_unrestricted_database(backend).await;
592 }
593}