db_pool/async/backend/mysql/sqlx.rs

use std::{borrow::Cow, pin::Pin};

use async_trait::async_trait;
use futures::Future;
use sqlx::{
    Connection, Executor, MySql, MySqlConnection, MySqlPool, Row,
    mysql::{MySqlConnectOptions, MySqlPoolOptions},
    pool::PoolConnection,
};
use uuid::Uuid;

use crate::{common::statement::mysql, util::get_db_name};

use super::{
    super::{
        common::error::sqlx::{BuildError, ConnectionError, PoolError, QueryError},
        error::Error as BackendError,
        r#trait::Backend,
    },
    r#trait::{MySQLBackend, MySQLBackendWrapper},
};

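// Type alias for the user-provided async callback that seeds entities in a
// freshly created database.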
type CreateEntities = dyn Fn(MySqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
    + Send
    + Sync
    + 'static;

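/// MySQL backend that uses `sqlx` connections and pools.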
pub struct SqlxMySQLBackend {
    privileged_opts: MySqlConnectOptions,
    default_pool: MySqlPool,
    create_restricted_pool: Box<dyn Fn() -> MySqlPoolOptions + Send + Sync + 'static>,
    create_entities: Box<CreateEntities>,
    drop_previous_databases_flag: bool,
}

impl SqlxMySQLBackend {
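    /// Creates a new `sqlx` MySQL backend.
    ///
    /// The example below is an illustrative sketch modeled on the tests in this
    /// module; the import path, credentials, and entity statement are
    /// placeholders, not a prescribed configuration.
    ///
    /// ```rust,ignore
    /// use db_pool::r#async::SqlxMySQLBackend; // assumed public re-export path
    /// use sqlx::{
    ///     Executor,
    ///     mysql::{MySqlConnectOptions, MySqlPoolOptions},
    /// };
    ///
    /// let backend = SqlxMySQLBackend::new(
    ///     // privileged connection options (placeholder credentials)
    ///     MySqlConnectOptions::new().username("root").password("root"),
    ///     // factories for the privileged and restricted pools
    ///     MySqlPoolOptions::new,
    ///     MySqlPoolOptions::new,
    ///     // entity-creation callback run once per new database
    ///     |mut conn| {
    ///         Box::pin(async move {
    ///             conn.execute("CREATE TABLE book (id INTEGER PRIMARY KEY AUTO_INCREMENT, title TEXT NOT NULL)")
    ///                 .await
    ///                 .unwrap();
    ///         })
    ///     },
    /// );
    /// ```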
    pub fn new(
        privileged_options: MySqlConnectOptions,
        create_privileged_pool: impl Fn() -> MySqlPoolOptions,
        create_restricted_pool: impl Fn() -> MySqlPoolOptions + Send + Sync + 'static,
        create_entities: impl Fn(MySqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
            + Send
            + Sync
            + 'static,
    ) -> Self {
        let pool_opts = create_privileged_pool();
        let default_pool = pool_opts.connect_lazy_with(privileged_options.clone());

        Self {
            privileged_opts: privileged_options,
            default_pool,
            create_restricted_pool: Box::new(create_restricted_pool),
            create_entities: Box::new(create_entities),
            drop_previous_databases_flag: true,
        }
    }

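    /// Sets whether databases left over from previous runs are dropped.
    /// Defaults to `true`.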
    #[must_use]
    pub fn drop_previous_databases(self, value: bool) -> Self {
        Self {
            drop_previous_databases_flag: value,
            ..self
        }
    }
}

#[async_trait]
impl<'pool> MySQLBackend<'pool> for SqlxMySQLBackend {
    type Connection = MySqlConnection;
    type PooledConnection = PoolConnection<MySql>;
    type Pool = MySqlPool;

    type BuildError = BuildError;
    type PoolError = PoolError;
    type ConnectionError = ConnectionError;
    type QueryError = QueryError;

    async fn get_connection(&'pool self) -> Result<PoolConnection<MySql>, PoolError> {
        self.default_pool.acquire().await.map_err(Into::into)
    }

    async fn execute_query(
        &self,
        query: &str,
        conn: &mut MySqlConnection,
    ) -> Result<(), QueryError> {
        conn.execute(query).await?;
        Ok(())
    }

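    // Joins the statements with ';' and executes them as a single query string;
    // an empty iterator is a no-op.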
    async fn batch_execute_query<'a>(
        &self,
        query: impl IntoIterator<Item = Cow<'a, str>> + Send,
        conn: &mut MySqlConnection,
    ) -> Result<(), QueryError> {
        let chunks = query.into_iter().collect::<Vec<_>>();
        if chunks.is_empty() {
            Ok(())
        } else {
            let query = chunks.join(";");
            self.execute_query(query.as_str(), conn).await
        }
    }

    fn get_host(&self) -> &str {
        self.privileged_opts.get_host()
    }

    async fn get_previous_database_names(
        &self,
        conn: &mut MySqlConnection,
    ) -> Result<Vec<String>, QueryError> {
        conn.fetch_all(mysql::GET_DATABASE_NAMES)
            .await?
            .iter()
            .map(|row| row.try_get(0))
            .collect::<Result<Vec<_>, _>>()
            .map_err(Into::into)
    }

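    // Connects to the given database with privileged options and runs the
    // user-supplied entity-creation callback against it.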
    async fn create_entities(&self, db_name: &str) -> Result<(), ConnectionError> {
        let opts = self.privileged_opts.clone().database(db_name);
        let conn = MySqlConnection::connect_with(&opts).await?;
        (self.create_entities)(conn).await;
        Ok(())
    }

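    // Builds a lazily connecting pool for the restricted per-database user; the
    // generated database name doubles as both username and password.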
    async fn create_connection_pool(&self, db_id: Uuid) -> Result<MySqlPool, BuildError> {
        let db_name = get_db_name(db_id);
        let db_name = db_name.as_str();
        let opts = self
            .privileged_opts
            .clone()
            .database(db_name)
            .username(db_name)
            .password(db_name);
        let pool = (self.create_restricted_pool)().connect_lazy_with(opts);
        Ok(pool)
    }

    async fn get_table_names(
        &self,
        db_name: &str,
        conn: &mut MySqlConnection,
    ) -> Result<Vec<String>, QueryError> {
        conn.fetch_all(mysql::get_table_names(db_name).as_str())
            .await?
            .iter()
            .map(|row| row.try_get(0))
            .collect::<Result<Vec<_>, _>>()
            .map_err(Into::into)
    }

    fn get_drop_previous_databases(&self) -> bool {
        self.drop_previous_databases_flag
    }
}

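// Convenience alias for the composite error type returned by the `Backend` impl.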
type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;

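// The public `Backend` impl delegates all operations to the shared
// `MySQLBackendWrapper` logic.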
#[async_trait]
impl Backend for SqlxMySQLBackend {
    type Pool = MySqlPool;

    type BuildError = BuildError;
    type PoolError = PoolError;
    type ConnectionError = ConnectionError;
    type QueryError = QueryError;

    async fn init(&self) -> Result<(), BError> {
        MySQLBackendWrapper::new(self).init().await
    }

    async fn create(
        &self,
        db_id: uuid::Uuid,
        restrict_privileges: bool,
    ) -> Result<MySqlPool, BError> {
        MySQLBackendWrapper::new(self)
            .create(db_id, restrict_privileges)
            .await
    }

    async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
        MySQLBackendWrapper::new(self).clean(db_id).await
    }

    async fn drop(&self, db_id: uuid::Uuid, _is_restricted: bool) -> Result<(), BError> {
        MySQLBackendWrapper::new(self).drop(db_id).await
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use futures::{StreamExt, future::join_all};
    use sqlx::{
        Executor, FromRow, Row,
        mysql::{MySqlConnectOptions, MySqlPoolOptions},
        query, query_as,
    };
    use tokio_shared_rt::test;

    use crate::{
        r#async::{
            backend::mysql::r#trait::tests::test_backend_creates_database_with_unrestricted_privileges,
            db_pool::DatabasePoolBuilder,
        },
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            MySQLDropLock, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_created_unrestricted_database, test_pool_drops_previous_databases,
        },
        SqlxMySQLBackend,
    };

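    // Builds a backend from the privileged MySQL config used by the test suite;
    // `with_table` controls whether the entity callback creates the test tables.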
    fn create_backend(with_table: bool) -> SqlxMySQLBackend {
        let config = get_privileged_mysql_config();
        let opts = MySqlConnectOptions::new().username(config.username.as_str());
        let opts = if let Some(password) = &config.password {
            opts.password(password)
        } else {
            opts
        };
        SqlxMySQLBackend::new(opts, MySqlPoolOptions::new, MySqlPoolOptions::new, {
            move |mut conn| {
                if with_table {
                    Box::pin(async move {
                        conn.execute_many(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
                            .collect::<Vec<_>>()
                            .await
                            .drain(..)
                            .collect::<Result<Vec<_>, _>>()
                            .unwrap();
                    })
                } else {
                    Box::pin(async {})
                }
            }
        })
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromRow, Eq, PartialEq, Debug)]
        struct Book {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

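            // insert a different row into each database concurrently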
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        query("INSERT INTO book (title) VALUES (?)")
                            .bind(format!("Title {i}"))
                            .execute(&***conn_pool)
                            .await
                            .unwrap();
                    }),
            )
            .await;

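            // each database must only see the row that was inserted into it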
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        assert_eq!(
                            query_as::<_, Book>("SELECT title FROM book")
                                .fetch_all(&***conn_pool)
                                .await
                                .unwrap(),
                            vec![Book {
                                title: format!("Title {i}")
                            }]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.acquire().await.unwrap();

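            // DDL statements must be rejected on a restricted database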
            for stmt in DDL_STATEMENTS {
                assert!(conn.execute(stmt).await.is_err());
            }

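            // DML statements are still allowed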
            for stmt in DML_STATEMENTS {
                assert!(conn.execute(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            {
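                // DML statements succeed on an unrestricted database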
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute(stmt).await.is_ok());
                }
            }

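            // DDL statements also succeed; each one runs against a freshly
            // created database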
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                assert!(conn.execute(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            {
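                // first round: pulled databases start empty, then a row is
                // inserted into each one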
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    query("INSERT INTO book (title) VALUES (?)")
                        .bind("Title")
                        .execute(&***conn_pool)
                        .await
                        .unwrap();
                }))
                .await;
            }

            {
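                // second round: freshly pulled databases must be clean again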
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}