db_pool/async/backend/mysql/
diesel.rs1use std::{borrow::Cow, pin::Pin};
2
3use async_trait::async_trait;
4use diesel::{prelude::*, result::Error, sql_query, table};
5use diesel_async::{
6 pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, SetupCallback},
7 AsyncConnection, AsyncMysqlConnection, RunQueryDsl, SimpleAsyncConnection,
8};
9use futures::{future::FutureExt, Future};
10use uuid::Uuid;
11
12use crate::{
13 common::{config::mysql::PrivilegedMySQLConfig, statement::mysql},
14 util::get_db_name,
15};
16
17use super::{
18 super::{
19 common::pool::diesel::r#trait::DieselPoolAssociation, error::Error as BackendError,
20 r#trait::Backend,
21 },
22 r#trait::{MySQLBackend, MySQLBackendWrapper},
23};
24
25type CreateEntities = dyn Fn(AsyncMysqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
26 + Send
27 + Sync
28 + 'static;
29
30pub struct DieselAsyncMySQLBackend<P: DieselPoolAssociation<AsyncMysqlConnection>> {
32 privileged_config: PrivilegedMySQLConfig,
33 default_pool: P::Pool,
34 create_restricted_pool: Box<dyn Fn() -> P::Builder + Send + Sync + 'static>,
35 create_connection: Box<dyn Fn() -> SetupCallback<AsyncMysqlConnection> + Send + Sync + 'static>,
36 create_entities: Box<CreateEntities>,
37 drop_previous_databases_flag: bool,
38}
39
40impl<P: DieselPoolAssociation<AsyncMysqlConnection>> DieselAsyncMySQLBackend<P> {
41 pub async fn new(
79 privileged_config: PrivilegedMySQLConfig,
80 create_privileged_pool: impl Fn() -> P::Builder,
81 create_restricted_pool: impl Fn() -> P::Builder + Send + Sync + 'static,
82 custom_create_connection: Option<
83 Box<dyn Fn() -> SetupCallback<AsyncMysqlConnection> + Send + Sync + 'static>,
84 >,
85 create_entities: impl Fn(AsyncMysqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
86 + Send
87 + Sync
88 + 'static,
89 ) -> Result<Self, P::BuildError> {
90 let create_connection = custom_create_connection.unwrap_or_else(|| {
91 Box::new(|| {
92 Box::new(|connection_url| AsyncMysqlConnection::establish(connection_url).boxed())
93 })
94 });
95
96 let manager_config = {
97 let mut config = ManagerConfig::default();
98 config.custom_setup = Box::new(create_connection());
99 config
100 };
101 let manager = AsyncDieselConnectionManager::new_with_config(
102 privileged_config.default_connection_url(),
103 manager_config,
104 );
105 let builder = create_privileged_pool();
106 let default_pool = P::build_pool(builder, manager).await?;
107
108 Ok(Self {
109 privileged_config,
110 default_pool,
111 create_restricted_pool: Box::new(create_restricted_pool),
112 create_connection: Box::new(create_connection),
113 create_entities: Box::new(create_entities),
114 drop_previous_databases_flag: true,
115 })
116 }
117
118 #[must_use]
120 pub fn drop_previous_databases(self, value: bool) -> Self {
121 Self {
122 drop_previous_databases_flag: value,
123 ..self
124 }
125 }
126}
127
128#[async_trait]
129impl<'pool, P: DieselPoolAssociation<AsyncMysqlConnection>> MySQLBackend<'pool>
130 for DieselAsyncMySQLBackend<P>
131{
132 type Connection = AsyncMysqlConnection;
133 type PooledConnection = P::PooledConnection<'pool>;
134 type Pool = P::Pool;
135
136 type BuildError = P::BuildError;
137 type PoolError = P::PoolError;
138 type ConnectionError = ConnectionError;
139 type QueryError = Error;
140
141 async fn get_connection(&'pool self) -> Result<P::PooledConnection<'pool>, P::PoolError> {
142 P::get_connection(&self.default_pool).await
143 }
144
145 async fn execute_query(&self, query: &str, conn: &mut AsyncMysqlConnection) -> QueryResult<()> {
146 sql_query(query).execute(conn).await?;
147 Ok(())
148 }
149
150 async fn batch_execute_query<'a>(
151 &self,
152 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
153 conn: &mut AsyncMysqlConnection,
154 ) -> QueryResult<()> {
155 let query = query.into_iter().collect::<Vec<_>>();
156 if query.is_empty() {
157 Ok(())
158 } else {
159 conn.batch_execute(query.join(";").as_str()).await
160 }
161 }
162
163 fn get_host(&self) -> &str {
164 self.privileged_config.host.as_str()
165 }
166
167 async fn get_previous_database_names(
168 &self,
169 conn: &mut AsyncMysqlConnection,
170 ) -> QueryResult<Vec<String>> {
171 table! {
172 schemata (schema_name) {
173 schema_name -> Text
174 }
175 }
176
177 schemata::table
178 .select(schemata::schema_name)
179 .filter(schemata::schema_name.like("db_pool_%"))
180 .load::<String>(conn)
181 .await
182 }
183
184 async fn create_entities(&self, db_name: &str) -> Result<(), ConnectionError> {
185 let database_url = self
186 .privileged_config
187 .privileged_database_connection_url(db_name);
188 let conn = (self.create_connection)()(database_url.as_str()).await?;
189 (self.create_entities)(conn).await;
190 Ok(())
191 }
192
193 async fn create_connection_pool(&self, db_id: Uuid) -> Result<P::Pool, P::BuildError> {
194 let db_name = get_db_name(db_id);
195 let db_name = db_name.as_str();
196 let database_url = self.privileged_config.restricted_database_connection_url(
197 db_name,
198 Some(db_name),
199 db_name,
200 );
201 let manager_config = {
202 let mut config = ManagerConfig::default();
203 config.custom_setup = (self.create_connection)();
204 config
205 };
206 let manager = AsyncDieselConnectionManager::<AsyncMysqlConnection>::new_with_config(
207 database_url.as_str(),
208 manager_config,
209 );
210 let builder = (self.create_restricted_pool)();
211 P::build_pool(builder, manager).await
212 }
213
214 async fn get_table_names(
215 &self,
216 db_name: &str,
217 conn: &mut AsyncMysqlConnection,
218 ) -> QueryResult<Vec<String>> {
219 table! {
220 tables (table_name) {
221 table_name -> Text,
222 table_schema -> Text
223 }
224 }
225
226 sql_query(mysql::USE_DEFAULT_DATABASE).execute(conn).await?;
227
228 tables::table
229 .filter(tables::table_schema.eq(db_name))
230 .select(tables::table_name)
231 .load::<String>(conn)
232 .await
233 }
234
235 fn get_drop_previous_databases(&self) -> bool {
236 self.drop_previous_databases_flag
237 }
238}
239
240type BError<BuildError, PoolError> = BackendError<BuildError, PoolError, ConnectionError, Error>;
241
242#[async_trait]
243impl<P: DieselPoolAssociation<AsyncMysqlConnection>> Backend for DieselAsyncMySQLBackend<P> {
244 type Pool = P::Pool;
245
246 type BuildError = P::BuildError;
247 type PoolError = P::PoolError;
248 type ConnectionError = ConnectionError;
249 type QueryError = Error;
250
251 async fn init(&self) -> Result<(), BError<P::BuildError, P::PoolError>> {
252 MySQLBackendWrapper::new(self).init().await
253 }
254
255 async fn create(
256 &self,
257 db_id: uuid::Uuid,
258 restrict_privileges: bool,
259 ) -> Result<P::Pool, BError<P::BuildError, P::PoolError>> {
260 MySQLBackendWrapper::new(self)
261 .create(db_id, restrict_privileges)
262 .await
263 }
264
265 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError<P::BuildError, P::PoolError>> {
266 MySQLBackendWrapper::new(self).clean(db_id).await
267 }
268
269 async fn drop(
270 &self,
271 db_id: uuid::Uuid,
272 _is_restricted: bool,
273 ) -> Result<(), BError<P::BuildError, P::PoolError>> {
274 MySQLBackendWrapper::new(self).drop(db_id).await
275 }
276}
277
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use std::borrow::Cow;

    use bb8::Pool;
    use diesel::{insert_into, sql_query, table, Insertable, QueryDsl};
    use diesel_async::{RunQueryDsl, SimpleAsyncConnection};
    use futures::future::join_all;
    use tokio_shared_rt::test;

    use crate::{
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::{
                common::pool::diesel::bb8::DieselBb8,
                mysql::r#trait::tests::{
                    test_backend_creates_database_with_unrestricted_privileges,
                    test_pool_drops_created_unrestricted_database,
                },
            },
            db_pool::DatabasePoolBuilder,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases, MySQLDropLock,
        },
        DieselAsyncMySQLBackend,
    };

    // Diesel mapping for the `book` table the tests read and write; assumed to match
    // the table created by `CREATE_ENTITIES_STATEMENTS` — TODO confirm.
    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    /// Row inserted by the tests; `id` is omitted and left to the database.
    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }

    /// Builds a bb8-backed backend from the shared privileged config.
    ///
    /// With `with_table`, every new database is populated by running
    /// `CREATE_ENTITIES_STATEMENTS` as a single `;`-joined batch; otherwise new
    /// databases are left empty.
    async fn create_backend(with_table: bool) -> DieselAsyncMySQLBackend<DieselBb8> {
        let privileged_config = get_privileged_mysql_config().clone();
        DieselAsyncMySQLBackend::new(
            privileged_config,
            Pool::builder,
            Pool::builder,
            None,
            move |mut conn| {
                if !with_table {
                    Box::pin(async {})
                } else {
                    Box::pin(async move {
                        let batch = CREATE_ENTITIES_STATEMENTS.join(";");
                        conn.batch_execute(batch.as_str()).await.unwrap();
                    })
                }
            },
        )
        .await
        .unwrap()
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_backend_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        test_backend_creates_database_with_restricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        test_backend_creates_database_with_unrestricted_privileges(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        test_backend_cleans_database_with_tables(
            create_backend(true).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        test_backend_cleans_database_without_tables(
            create_backend(false).await.drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let restricted = true;
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            restricted,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let restricted = false;
        test_backend_drops_database(
            create_backend(true).await.drop_previous_databases(false),
            restricted,
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        let default_backend = create_backend(false).await;
        let dropping_backend = create_backend(false).await.drop_previous_databases(true);
        let keeping_backend = create_backend(false).await.drop_previous_databases(false);
        test_pool_drops_previous_databases(default_backend, dropping_backend, keeping_backend)
            .await;
    }

    /// Each pulled database receives a distinct row and must see only its own row.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // Write a different title into every database.
            let inserts = conn_pools
                .iter()
                .enumerate()
                .map(|(idx, conn_pool)| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let new_book = NewBook {
                        title: format!("Title {idx}").into(),
                    };
                    insert_into(book::table)
                        .values(new_book)
                        .execute(&mut conn)
                        .await
                        .unwrap();
                });
            join_all(inserts).await;

            // Every database must contain exactly the title written to it.
            let checks = conn_pools
                .iter()
                .enumerate()
                .map(|(idx, conn_pool)| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let titles = book::table
                        .select(book::title)
                        .load::<String>(&mut conn)
                        .await
                        .unwrap();
                    assert_eq!(titles, vec![format!("Title {idx}")]);
                });
            join_all(checks).await;
        }
        .lock_read()
        .await;
    }

    /// Restricted databases reject DDL but accept DML.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pool = db_pool.pull_immutable().await;
            let mut conn = conn_pool.get().await.unwrap();

            for statement in DDL_STATEMENTS {
                let outcome = sql_query(statement).execute(&mut conn).await;
                assert!(outcome.is_err());
            }

            for statement in DML_STATEMENTS {
                let outcome = sql_query(statement).execute(&mut conn).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    /// Unrestricted databases accept both DML and DDL.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // All DML statements run against a single mutable database.
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let mut conn = conn_pool.get().await.unwrap();
                for statement in DML_STATEMENTS {
                    let outcome = sql_query(statement).execute(&mut conn).await;
                    assert!(outcome.is_ok());
                }
            }

            // Each DDL statement gets a freshly created mutable database.
            for statement in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let mut conn = conn_pool.get().await.unwrap();
                let outcome = sql_query(statement).execute(&mut conn).await;
                assert!(outcome.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    /// Databases pulled a second time (presumably the same, reused ones) are empty again.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).await.drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // First round: each database starts empty, then receives one row.
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                let emptiness_checks = conn_pools.iter().map(|conn_pool| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let count = book::table
                        .count()
                        .get_result::<i64>(&mut conn)
                        .await
                        .unwrap();
                    assert_eq!(count, 0);
                });
                join_all(emptiness_checks).await;

                let inserts = conn_pools.iter().map(|conn_pool| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let new_book = NewBook {
                        title: "Title".into(),
                    };
                    insert_into(book::table)
                        .values(new_book)
                        .execute(&mut conn)
                        .await
                        .unwrap();
                });
                join_all(inserts).await;
            }

            // Second round: the rows written above must no longer be visible.
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                let emptiness_checks = conn_pools.iter().map(|conn_pool| async move {
                    let mut conn = conn_pool.get().await.unwrap();
                    let count = book::table
                        .count()
                        .get_result::<i64>(&mut conn)
                        .await
                        .unwrap();
                    assert_eq!(count, 0);
                });
                join_all(emptiness_checks).await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false).await).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false).await).await;
    }
}
572}