1use std::{borrow::Cow, collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use parking_lot::Mutex;
6use sea_orm::{
7 ActiveModelBehavior, ColumnTrait, ConnectOptions, ConnectionTrait, Database,
8 DatabaseConnection, DbErr, DeriveEntityModel, DerivePrimaryKey, DeriveRelation, EntityTrait,
9 EnumIter, FromQueryResult, PrimaryKeyTrait, QueryFilter, QuerySelect,
10};
11use uuid::Uuid;
12
13use crate::{common::config::PrivilegedPostgresConfig, util::get_db_name};
14
15use super::{
16 super::{
17 common::{
18 conn::sea_orm::PooledConnection,
19 error::sea_orm::{BuildError, ConnectionError, PoolError, QueryError},
20 },
21 error::Error as BackendError,
22 r#trait::Backend,
23 },
24 r#trait::{PostgresBackend, PostgresBackendWrapper},
25};
26
27type CreateEntities = dyn Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
28 + Send
29 + Sync
30 + 'static;
31
32pub struct SeaORMPostgresBackend {
34 privileged_config: PrivilegedPostgresConfig,
35 default_pool: DatabaseConnection,
36 db_conns: Mutex<HashMap<Uuid, DatabaseConnection>>,
37 create_restricted_pool: Box<dyn for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static>,
38 create_entities: Box<CreateEntities>,
39 drop_previous_databases_flag: bool,
40}
41
42impl SeaORMPostgresBackend {
43 pub async fn new(
83 privileged_config: PrivilegedPostgresConfig,
84 create_privileged_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions),
85 create_restricted_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static,
86 create_entities: impl Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
87 + Send
88 + Sync
89 + 'static,
90 ) -> Result<Self, DbErr> {
91 let mut opts = ConnectOptions::new(privileged_config.default_connection_url());
92 create_privileged_pool(&mut opts);
93 let default_pool = Database::connect(opts).await?;
94
95 Ok(Self {
96 privileged_config,
97 default_pool,
98 db_conns: Mutex::new(HashMap::new()),
99 create_restricted_pool: Box::new(create_restricted_pool),
100 create_entities: Box::new(create_entities),
101 drop_previous_databases_flag: true,
102 })
103 }
104
105 #[must_use]
107 pub fn drop_previous_databases(self, value: bool) -> Self {
108 Self {
109 drop_previous_databases_flag: value,
110 ..self
111 }
112 }
113}
114
115#[async_trait]
116impl<'pool> PostgresBackend<'pool> for SeaORMPostgresBackend {
117 type Connection = DatabaseConnection;
118 type PooledConnection = PooledConnection;
119 type Pool = DatabaseConnection;
120
121 type BuildError = BuildError;
122 type PoolError = PoolError;
123 type ConnectionError = ConnectionError;
124 type QueryError = QueryError;
125
126 async fn execute_query(
127 &self,
128 query: &str,
129 conn: &mut DatabaseConnection,
130 ) -> Result<(), QueryError> {
131 conn.execute_unprepared(query).await?;
132 Ok(())
133 }
134
135 async fn batch_execute_query<'a>(
136 &self,
137 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
138 conn: &mut DatabaseConnection,
139 ) -> Result<(), QueryError> {
140 let query = query.into_iter().collect::<Vec<_>>().join(";");
141 self.execute_query(query.as_str(), conn).await
142 }
143
144 async fn get_default_connection(&'pool self) -> Result<PooledConnection, PoolError> {
145 Ok(self.default_pool.clone().into())
146 }
147
148 async fn establish_privileged_database_connection(
149 &self,
150 db_id: Uuid,
151 ) -> Result<DatabaseConnection, ConnectionError> {
152 let db_name = get_db_name(db_id);
153 let database_url = self
154 .privileged_config
155 .privileged_database_connection_url(db_name.as_str());
156 let opts = ConnectOptions::new(database_url);
157 Database::connect(opts).await.map_err(Into::into)
158 }
159
160 async fn establish_restricted_database_connection(
161 &self,
162 db_id: Uuid,
163 ) -> Result<DatabaseConnection, ConnectionError> {
164 let db_name = get_db_name(db_id);
165 let db_name = db_name.as_str();
166 let database_url = self.privileged_config.restricted_database_connection_url(
167 db_name,
168 Some(db_name),
169 db_name,
170 );
171 let opts = ConnectOptions::new(database_url);
172 Database::connect(opts).await.map_err(Into::into)
173 }
174
175 fn put_database_connection(&self, db_id: Uuid, conn: DatabaseConnection) {
176 self.db_conns.lock().insert(db_id, conn);
177 }
178
179 fn get_database_connection(&self, db_id: Uuid) -> DatabaseConnection {
180 self.db_conns
181 .lock()
182 .remove(&db_id)
183 .unwrap_or_else(|| panic!("connection map must have a connection for {db_id}"))
184 }
185
186 async fn get_previous_database_names(
187 &self,
188 conn: &mut DatabaseConnection,
189 ) -> Result<Vec<String>, QueryError> {
190 #[derive(Clone, Debug, DeriveEntityModel)]
191 #[sea_orm(table_name = "pg_database")]
192 pub struct Model {
193 #[sea_orm(primary_key)]
194 oid: i32,
195 datname: String,
196 }
197
198 #[derive(Debug, EnumIter, DeriveRelation)]
199 pub enum Relation {}
200
201 impl ActiveModelBehavior for ActiveModel {}
202
203 #[derive(FromQueryResult)]
204 struct QueryModel {
205 datname: String,
206 }
207
208 Entity::find()
209 .select_only()
210 .column(Column::Datname)
211 .filter(Column::Datname.like("db_pool_%"))
212 .into_model::<QueryModel>()
213 .all(conn)
214 .await
215 .map(|mut models| models.drain(..).map(|model| model.datname).collect())
216 .map_err(Into::into)
217 }
218
219 async fn create_entities(&self, conn: DatabaseConnection) -> DatabaseConnection {
220 (self.create_entities)(conn.clone()).await;
221 conn
222 }
223
224 async fn create_connection_pool(&self, db_id: Uuid) -> Result<DatabaseConnection, BuildError> {
225 let db_name = get_db_name(db_id);
226 let db_name = db_name.as_str();
227 let database_url = self.privileged_config.restricted_database_connection_url(
228 db_name,
229 Some(db_name),
230 db_name,
231 );
232 let mut opts = ConnectOptions::new(database_url);
233 (self.create_restricted_pool)(&mut opts);
234 Database::connect(opts).await.map_err(Into::into)
235 }
236
237 async fn get_table_names(
238 &self,
239 conn: &mut DatabaseConnection,
240 ) -> Result<Vec<String>, QueryError> {
241 #[derive(Clone, Debug, DeriveEntityModel)]
242 #[sea_orm(table_name = "pg_tables")]
243 pub struct Model {
244 schemaname: String,
245 #[sea_orm(primary_key)]
246 tablename: String,
247 }
248
249 #[derive(Debug, EnumIter, DeriveRelation)]
250 pub enum Relation {}
251
252 impl ActiveModelBehavior for ActiveModel {}
253
254 #[derive(FromQueryResult)]
255 struct QueryModel {
256 tablename: String,
257 }
258
259 Entity::find()
260 .select_only()
261 .column(Column::Tablename)
262 .filter(Column::Schemaname.is_not_in(["pg_catalog", "information_schema"]))
263 .into_model::<QueryModel>()
264 .all(conn)
265 .await
266 .map(|mut models| models.drain(..).map(|model| model.tablename).collect())
267 .map_err(Into::into)
268 }
269
270 fn get_drop_previous_databases(&self) -> bool {
271 self.drop_previous_databases_flag
272 }
273}
274
275type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
276
277#[async_trait]
278impl Backend for SeaORMPostgresBackend {
279 type Pool = DatabaseConnection;
280
281 type BuildError = BuildError;
282 type PoolError = PoolError;
283 type ConnectionError = ConnectionError;
284 type QueryError = QueryError;
285
286 async fn init(&self) -> Result<(), BError> {
287 PostgresBackendWrapper::new(self).init().await
288 }
289
290 async fn create(
291 &self,
292 db_id: uuid::Uuid,
293 restrict_privileges: bool,
294 ) -> Result<DatabaseConnection, BError> {
295 PostgresBackendWrapper::new(self)
296 .create(db_id, restrict_privileges)
297 .await
298 }
299
300 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
301 PostgresBackendWrapper::new(self).clean(db_id).await
302 }
303
304 async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> {
305 PostgresBackendWrapper::new(self)
306 .drop(db_id, is_restricted)
307 .await
308 }
309}
310
#[cfg(test)]
mod tests {
    // Test code: unwraps are acceptable here, and the needless-return lint is silenced for
    // this module.
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use dotenvy::dotenv;
    use futures::future::join_all;
    use sea_orm::{
        ActiveModelBehavior, ActiveModelTrait, ConnectionTrait, DeriveEntityModel,
        DerivePrimaryKey, DeriveRelation, EntityTrait, EnumIter, FromQueryResult, PaginatorTrait,
        PrimaryKeyTrait, QuerySelect, Set,
    };
    use tokio_shared_rt::test;

    use crate::{
        common::{
            config::PrivilegedPostgresConfig,
            statement::postgres::tests::{
                CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
            },
        },
        r#async::{
            backend::postgres::r#trait::tests::{
                test_backend_drops_database, test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges,
            test_backend_creates_database_with_unrestricted_privileges,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases, PgDropLock,
        },
        SeaORMPostgresBackend,
    };

    /// SeaORM entity for the `book` table used by the pool tests (presumably created by
    /// `CREATE_ENTITIES_STATEMENTS`).
    #[derive(Clone, Debug, DeriveEntityModel)]
    #[sea_orm(table_name = "book")]
    pub struct Model {
        #[sea_orm(primary_key)]
        id: i32,
        title: String,
    }

    #[derive(Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    impl ActiveModelBehavior for ActiveModel {}

    /// Builds a backend from environment configuration, loading `.env` if present.
    ///
    /// When `with_table` is true, every newly created database gets the entities from
    /// `CREATE_ENTITIES_STATEMENTS`; otherwise entity creation is a no-op.
    async fn create_backend(with_table: bool) -> SeaORMPostgresBackend {
        // A missing `.env` file is not an error; values may come from the real environment.
        dotenv().ok();

        let config = PrivilegedPostgresConfig::from_env().unwrap();

        SeaORMPostgresBackend::new(config, |_| {}, |_| {}, {
            move |conn| {
                if with_table {
                    Box::pin(async move {
                        conn.execute_unprepared(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
                            .await
                            .unwrap();
                    })
                } else {
                    Box::pin(async {})
                }
            }
        })
        .await
        .unwrap()
    }

    // All tests run on tokio_shared_rt's multi-threaded, shared runtime.

    /// Runs the shared scenario with a default backend, one with dropping previous databases
    /// enabled, and one with it disabled.
    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    /// Pool-level variant of `backend_drops_previous_databases`.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    /// Each pulled connection points at its own database: a row written through one
    /// connection is the only row visible through that same connection.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromQueryResult, Eq, PartialEq, Debug)]
        struct QueryModel {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        // The body runs under `PgDropLock::lock_read`; presumably a shared lock that keeps it
        // from racing with tests that drop databases.
        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // Insert a distinct title into each database.
            join_all(conns.iter().enumerate().map(|(i, conn)| async move {
                let book = ActiveModel {
                    title: Set(format!("Title {i}")),
                    ..Default::default()
                };
                book.insert(&***conn).await.unwrap();
            }))
            .await;

            // Each database must contain exactly the title written through its connection.
            join_all(conns.iter().enumerate().map(|(i, conn)| async move {
                assert_eq!(
                    Entity::find()
                        .select_only()
                        .column(Column::Title)
                        .into_model::<QueryModel>()
                        .all(&***conn)
                        .await
                        .unwrap(),
                    vec![QueryModel {
                        title: format!("Title {i}")
                    }]
                );
            }))
            .await;
        }
        .lock_read()
        .await;
    }

    /// Connections from `pull_immutable` reject every DDL statement but accept DML.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn = db_pool.pull_immutable().await;

            // Schema changes must fail on a restricted database.
            for stmt in DDL_STATEMENTS {
                assert!(conn.execute_unprepared(stmt).await.is_err());
            }

            // Data manipulation must still succeed.
            for stmt in DML_STATEMENTS {
                assert!(conn.execute_unprepared(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    /// Databases from `create_mutable` accept both DML and DDL statements.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // All DML statements run on a single mutable database, released at scope end.
            {
                let conn = db_pool.create_mutable().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute_unprepared(stmt).await.is_ok());
                }
            }

            // Each DDL statement gets its own fresh mutable database.
            for stmt in DDL_STATEMENTS {
                let conn = db_pool.create_mutable().await.unwrap();
                assert!(conn.execute_unprepared(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    /// Databases pulled again after being written to are empty once more.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // First round: check the databases start empty, then write a row into each. The
            // connections are released when `conns` drops at the end of this scope.
            {
                let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // Every database starts with an empty `book` table.
                join_all(conns.iter().map(|conn| async move {
                    assert_eq!(Entity::find().count(&***conn).await.unwrap(), 0);
                }))
                .await;

                // Dirty every database with one row.
                join_all(conns.iter().map(|conn| async move {
                    let book = ActiveModel {
                        title: Set("Title".to_owned()),
                        ..Default::default()
                    };
                    book.insert(&***conn).await.unwrap();
                }))
                .await;
            }

            // Second round: pull the same number of databases again; presumably they are the
            // ones released above, and they must have been cleaned.
            {
                let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // The rows inserted in the first round are gone.
                join_all(conns.iter().map(|conn| async move {
                    assert_eq!(Entity::find().count(&***conn).await.unwrap(), 0);
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false).await;
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false).await;
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}
590}