1use std::{borrow::Cow, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use sea_orm::{
6 ActiveModelBehavior, ColumnTrait, ConnectOptions, ConnectionTrait, Database,
7 DatabaseConnection, DbErr, DeriveEntityModel, DerivePrimaryKey, DeriveRelation, EntityTrait,
8 EnumIter, FromQueryResult, PrimaryKeyTrait, QueryFilter, QuerySelect, TransactionError,
9 TransactionTrait,
10};
11use uuid::Uuid;
12
13use crate::{
14 common::{config::PrivilegedMySQLConfig, statement::mysql},
15 util::get_db_name,
16};
17
18use super::{
19 super::{
20 common::{
21 conn::sea_orm::PooledConnection,
22 error::sea_orm::{BuildError, ConnectionError, PoolError, QueryError},
23 },
24 error::Error as BackendError,
25 r#trait::Backend,
26 },
27 r#trait::{MySQLBackend, MySQLBackendWrapper},
28};
29
30type CreateEntities = dyn Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
31 + Send
32 + Sync
33 + 'static;
34
35pub struct SeaORMMySQLBackend {
37 privileged_config: PrivilegedMySQLConfig,
38 default_pool: DatabaseConnection,
39 create_restricted_pool: Box<dyn for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static>,
40 create_entities: Box<CreateEntities>,
41 drop_previous_databases_flag: bool,
42}
43
44impl SeaORMMySQLBackend {
45 pub async fn new(
82 privileged_config: PrivilegedMySQLConfig,
83 create_privileged_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions),
84 create_restricted_pool: impl for<'tmp> Fn(&'tmp mut ConnectOptions) + Send + Sync + 'static,
85 create_entities: impl Fn(DatabaseConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
86 + Send
87 + Sync
88 + 'static,
89 ) -> Result<Self, DbErr> {
90 let mut opts = ConnectOptions::new(privileged_config.default_connection_url());
91 create_privileged_pool(&mut opts);
92 let default_pool = Database::connect(opts).await?;
93
94 Ok(Self {
95 privileged_config,
96 default_pool,
97 create_restricted_pool: Box::new(create_restricted_pool),
98 create_entities: Box::new(create_entities),
99 drop_previous_databases_flag: true,
100 })
101 }
102
103 #[must_use]
105 pub fn drop_previous_databases(self, value: bool) -> Self {
106 Self {
107 drop_previous_databases_flag: value,
108 ..self
109 }
110 }
111}
112
113#[async_trait]
114impl<'pool> MySQLBackend<'pool> for SeaORMMySQLBackend {
115 type Connection = DatabaseConnection;
116 type PooledConnection = PooledConnection;
117 type Pool = DatabaseConnection;
118
119 type BuildError = BuildError;
120 type PoolError = PoolError;
121 type ConnectionError = ConnectionError;
122 type QueryError = QueryError;
123
124 async fn get_connection(&'pool self) -> Result<PooledConnection, PoolError> {
125 Ok(self.default_pool.clone().into())
126 }
127
128 async fn execute_query(
129 &self,
130 query: &str,
131 conn: &mut DatabaseConnection,
132 ) -> Result<(), QueryError> {
133 conn.execute_unprepared(query).await?;
134 Ok(())
135 }
136
137 async fn batch_execute_query<'a>(
138 &self,
139 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
140 conn: &mut DatabaseConnection,
141 ) -> Result<(), QueryError> {
142 let chunks = query.into_iter().collect::<Vec<_>>();
143 if chunks.is_empty() {
144 Ok(())
145 } else {
146 let query = chunks.join(";");
147 self.execute_query(query.as_str(), conn).await
148 }
149 }
150
151 fn get_host(&self) -> &str {
152 self.privileged_config.host.as_str()
153 }
154
155 async fn get_previous_database_names(
156 &self,
157 conn: &mut DatabaseConnection,
158 ) -> Result<Vec<String>, QueryError> {
159 #[derive(Clone, Debug, DeriveEntityModel)]
160 #[sea_orm(table_name = "schemata")]
161 pub struct Model {
162 #[sea_orm(primary_key)]
163 schema_name: String,
164 }
165
166 #[derive(Debug, EnumIter, DeriveRelation)]
167 pub enum Relation {}
168
169 impl ActiveModelBehavior for ActiveModel {}
170
171 conn.transaction(move |txn| {
172 Box::pin(async move {
173 txn.execute_unprepared(mysql::USE_DEFAULT_DATABASE).await?;
174
175 Entity::find()
176 .filter(Column::SchemaName.like("db_pool_%"))
177 .all(txn)
178 .await
179 })
180 })
181 .await
182 .map(|mut models| models.drain(..).map(|model| model.schema_name).collect())
183 .map_err(|err| match err {
184 TransactionError::Connection(err) | TransactionError::Transaction(err) => err.into(),
185 })
186 }
187
188 async fn create_entities(&self, db_name: &str) -> Result<(), ConnectionError> {
189 let database_url = self
190 .privileged_config
191 .privileged_database_connection_url(db_name);
192 let conn = Database::connect(database_url).await?;
193 (self.create_entities)(conn).await;
194 Ok(())
195 }
196
197 async fn create_connection_pool(&self, db_id: Uuid) -> Result<DatabaseConnection, BuildError> {
198 let db_name = get_db_name(db_id);
199 let db_name = db_name.as_str();
200 let database_url = self.privileged_config.restricted_database_connection_url(
201 db_name,
202 Some(db_name),
203 db_name,
204 );
205 let mut opts = ConnectOptions::new(database_url);
206 (self.create_restricted_pool)(&mut opts);
207 Database::connect(opts).await.map_err(Into::into)
208 }
209
210 async fn get_table_names(
212 &self,
213 db_name: &str,
214 conn: &mut DatabaseConnection,
215 ) -> Result<Vec<String>, QueryError> {
216 #[derive(Clone, Debug, DeriveEntityModel)]
217 #[sea_orm(table_name = "tables")]
218 pub struct Model {
219 #[sea_orm(primary_key)]
220 table_name: String,
221 table_schema: String,
222 }
223
224 #[derive(Debug, EnumIter, DeriveRelation)]
225 pub enum Relation {}
226
227 impl ActiveModelBehavior for ActiveModel {}
228
229 #[derive(FromQueryResult)]
230 struct QueryModel {
231 table_name: String,
232 }
233
234 conn.transaction(move |txn| {
235 let db_name = db_name.to_owned();
236 Box::pin(async move {
237 txn.execute_unprepared(mysql::USE_DEFAULT_DATABASE).await?;
238
239 Entity::find()
240 .select_only()
241 .column(Column::TableName)
242 .filter(Column::TableSchema.eq(db_name))
243 .into_model::<QueryModel>()
244 .all(txn)
245 .await
246 })
247 })
248 .await
249 .map(|mut models| models.drain(..).map(|model| model.table_name).collect())
250 .map_err(|err| match err {
251 TransactionError::Connection(err) | TransactionError::Transaction(err) => err.into(),
252 })
253 }
254
255 fn get_drop_previous_databases(&self) -> bool {
256 self.drop_previous_databases_flag
257 }
258}
259
260type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
261
262#[async_trait]
263impl Backend for SeaORMMySQLBackend {
264 type Pool = DatabaseConnection;
265
266 type BuildError = BuildError;
267 type PoolError = PoolError;
268 type ConnectionError = ConnectionError;
269 type QueryError = QueryError;
270
271 async fn init(&self) -> Result<(), BError> {
272 MySQLBackendWrapper::new(self).init().await
273 }
274
275 async fn create(
276 &self,
277 db_id: uuid::Uuid,
278 restrict_privileges: bool,
279 ) -> Result<DatabaseConnection, BError> {
280 MySQLBackendWrapper::new(self)
281 .create(db_id, restrict_privileges)
282 .await
283 }
284
285 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
286 MySQLBackendWrapper::new(self).clean(db_id).await
287 }
288
289 async fn drop(&self, db_id: uuid::Uuid, _is_restricted: bool) -> Result<(), BError> {
290 MySQLBackendWrapper::new(self).drop(db_id).await
291 }
292}
293
// Integration tests for `SeaORMMySQLBackend`. They run against the MySQL server described by
// `get_privileged_mysql_config()` and mostly delegate to the shared MySQL backend test suite.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use futures::future::join_all;
    use sea_orm::{
        ActiveModelBehavior, ActiveModelTrait, ConnectionTrait, DeriveEntityModel,
        DerivePrimaryKey, DeriveRelation, EntityTrait, EnumIter, FromQueryResult, PaginatorTrait,
        PrimaryKeyTrait, QuerySelect, Set,
    };
    use tokio_shared_rt::test;

    use crate::{
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::mysql::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_pool_drops_created_restricted_databases,
                test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_previous_databases,
            MySQLDropLock,
        },
        SeaORMMySQLBackend,
    };

    // SeaORM entity for the `book` table used by the pool tests below.
    // NOTE(review): presumably the table created by `CREATE_ENTITIES_STATEMENTS`; confirm.
    #[derive(Clone, Debug, DeriveEntityModel)]
    #[sea_orm(table_name = "book")]
    pub struct Model {
        #[sea_orm(primary_key)]
        id: i32,
        title: String,
    }

    #[derive(Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    impl ActiveModelBehavior for ActiveModel {}

    // Builds a backend from the shared privileged test config. No extra pool options are set.
    // With `with_table`, the entity-creation callback runs `CREATE_ENTITIES_STATEMENTS`
    // (joined with `;`) on the new database and panics on failure. Otherwise it does nothing.
    async fn create_backend(with_table: bool) -> SeaORMMySQLBackend {
        let config = get_privileged_mysql_config().clone();
        SeaORMMySQLBackend::new(config, |_| {}, |_| {}, {
            move |conn| {
                if with_table {
                    Box::pin(async move {
                        conn.execute_unprepared(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
                            .await
                            .unwrap();
                    })
                } else {
                    Box::pin(async {})
                }
            }
        })
        .await
        .unwrap()
    }

    // Passes a default backend plus explicitly enabled and disabled variants to the
    // shared test.
    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).await.drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false).await,
            create_backend(false).await.drop_previous_databases(true),
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    // Inserts a distinct row through each of NUM_DBS pooled connections, then checks that
    // every connection sees only its own row.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromQueryResult, Eq, PartialEq, Debug)]
        struct QueryModel {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        // NOTE(review): `lock_read` comes from `MySQLDropLock`; presumably a shared lock that
        // keeps these tests from overlapping with tests that drop databases. Confirm.
        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // insert one row per connection, tagged with the connection's index
            join_all(conns.iter().enumerate().map(|(i, conn)| async move {
                let book = ActiveModel {
                    title: Set(format!("Title {i}")),
                    ..Default::default()
                };
                book.insert(&***conn).await.unwrap();
            }))
            .await;

            // each connection must see exactly the row it inserted
            join_all(conns.iter().enumerate().map(|(i, conn)| async move {
                assert_eq!(
                    Entity::find()
                        .select_only()
                        .column(Column::Title)
                        .into_model::<QueryModel>()
                        .all(&***conn)
                        .await
                        .unwrap(),
                    vec![QueryModel {
                        title: format!("Title {i}")
                    }]
                );
            }))
            .await;
        }
        .lock_read()
        .await;
    }

    // Immutable (restricted) connections must reject every DDL statement and accept every
    // DML statement.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn = db_pool.pull_immutable().await;

            // DDL must fail
            for stmt in DDL_STATEMENTS {
                assert!(conn.execute_unprepared(stmt).await.is_err());
            }

            // DML must succeed
            for stmt in DML_STATEMENTS {
                assert!(conn.execute_unprepared(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Mutable (unrestricted) connections must accept both DML and DDL statements.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // all DML statements on a single mutable connection
            {
                let conn = db_pool.create_mutable().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute_unprepared(stmt).await.is_ok());
                }
            }

            // each DDL statement on a fresh mutable connection
            for stmt in DDL_STATEMENTS {
                let conn = db_pool.create_mutable().await.unwrap();
                assert!(conn.execute_unprepared(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Connections pulled a second time must again see an empty `book` table.
    // The inner scopes matter: the first set of connections is dropped before the second
    // pull. NOTE(review): presumably this returns them to the pool for cleaning; confirm.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // fetch connection pools the first time
            {
                let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // databases start out empty
                join_all(conns.iter().map(|conn| async move {
                    assert_eq!(Entity::find().count(&***conn).await.unwrap(), 0);
                }))
                .await;

                // write one row into each database
                join_all(conns.iter().map(|conn| async move {
                    let book = ActiveModel {
                        title: Set("Title".to_owned()),
                        ..Default::default()
                    };
                    book.insert(&***conn).await.unwrap();
                }))
                .await;
            }

            // fetch connection pools the second time
            {
                let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // the rows written above must be gone
                join_all(conns.iter().map(|conn| async move {
                    assert_eq!(Entity::find().count(&***conn).await.unwrap(), 0);
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false).await;
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        let backend = create_backend(false).await;
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}