db_pool/async/backend/mysql/
sqlx.rs1use std::{borrow::Cow, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Future;
5use sqlx::{
6 mysql::{MySqlConnectOptions, MySqlPoolOptions},
7 pool::PoolConnection,
8 Connection, Executor, MySql, MySqlConnection, MySqlPool, Row,
9};
10use uuid::Uuid;
11
12use crate::{common::statement::mysql, util::get_db_name};
13
14use super::{
15 super::{
16 common::error::sqlx::{BuildError, ConnectionError, PoolError, QueryError},
17 error::Error as BackendError,
18 r#trait::Backend,
19 },
20 r#trait::{MySQLBackend, MySQLBackendWrapper},
21};
22
/// User-supplied callback run once per newly created database: it receives a
/// privileged connection scoped to that database and returns a future that
/// seeds the database's entities (tables, seed rows, ...).
type CreateEntities = dyn Fn(MySqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
    + Send
    + Sync
    + 'static;
27
/// MySQL backend implemented on top of [`sqlx`] for the async database pool.
pub struct SqlxMySQLBackend {
    // Connection options for the privileged (admin) user; also cloned and
    // re-targeted when connecting to individual databases.
    privileged_opts: MySqlConnectOptions,
    // Lazily-connecting pool built from the privileged options.
    default_pool: MySqlPool,
    // Factory producing pool options for each restricted, per-database pool.
    create_restricted_pool: Box<dyn Fn() -> MySqlPoolOptions + Send + Sync + 'static>,
    // User-supplied routine that seeds entities in each new database.
    create_entities: Box<CreateEntities>,
    // When `true`, databases left over from previous runs are dropped on init.
    drop_previous_databases_flag: bool,
}
36
37impl SqlxMySQLBackend {
38 pub fn new(
67 privileged_options: MySqlConnectOptions,
68 create_privileged_pool: impl Fn() -> MySqlPoolOptions,
69 create_restricted_pool: impl Fn() -> MySqlPoolOptions + Send + Sync + 'static,
70 create_entities: impl Fn(MySqlConnection) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
71 + Send
72 + Sync
73 + 'static,
74 ) -> Self {
75 let pool_opts = create_privileged_pool();
76 let default_pool = pool_opts.connect_lazy_with(privileged_options.clone());
77
78 Self {
79 privileged_opts: privileged_options,
80 default_pool,
81 create_restricted_pool: Box::new(create_restricted_pool),
82 create_entities: Box::new(create_entities),
83 drop_previous_databases_flag: true,
84 }
85 }
86
87 #[must_use]
89 pub fn drop_previous_databases(self, value: bool) -> Self {
90 Self {
91 drop_previous_databases_flag: value,
92 ..self
93 }
94 }
95}
96
97#[async_trait]
98impl<'pool> MySQLBackend<'pool> for SqlxMySQLBackend {
99 type Connection = MySqlConnection;
100 type PooledConnection = PoolConnection<MySql>;
101 type Pool = MySqlPool;
102
103 type BuildError = BuildError;
104 type PoolError = PoolError;
105 type ConnectionError = ConnectionError;
106 type QueryError = QueryError;
107
108 async fn get_connection(&'pool self) -> Result<PoolConnection<MySql>, PoolError> {
109 self.default_pool.acquire().await.map_err(Into::into)
110 }
111
112 async fn execute_query(
113 &self,
114 query: &str,
115 conn: &mut MySqlConnection,
116 ) -> Result<(), QueryError> {
117 conn.execute(query).await?;
118 Ok(())
119 }
120
121 async fn batch_execute_query<'a>(
122 &self,
123 query: impl IntoIterator<Item = Cow<'a, str>> + Send,
124 conn: &mut MySqlConnection,
125 ) -> Result<(), QueryError> {
126 let chunks = query.into_iter().collect::<Vec<_>>();
127 if chunks.is_empty() {
128 Ok(())
129 } else {
130 let query = chunks.join(";");
131 self.execute_query(query.as_str(), conn).await
132 }
133 }
134
135 fn get_host(&self) -> &str {
136 self.privileged_opts.get_host()
137 }
138
139 async fn get_previous_database_names(
140 &self,
141 conn: &mut MySqlConnection,
142 ) -> Result<Vec<String>, QueryError> {
143 conn.fetch_all(mysql::GET_DATABASE_NAMES)
144 .await?
145 .iter()
146 .map(|row| row.try_get(0))
147 .collect::<Result<Vec<_>, _>>()
148 .map_err(Into::into)
149 }
150
151 async fn create_entities(&self, db_name: &str) -> Result<(), ConnectionError> {
152 let opts = self.privileged_opts.clone().database(db_name);
153 let conn = MySqlConnection::connect_with(&opts).await?;
154 (self.create_entities)(conn).await;
155 Ok(())
156 }
157
158 async fn create_connection_pool(&self, db_id: Uuid) -> Result<MySqlPool, BuildError> {
159 let db_name = get_db_name(db_id);
160 let db_name = db_name.as_str();
161 let opts = self
162 .privileged_opts
163 .clone()
164 .database(db_name)
165 .username(db_name)
166 .password(db_name);
167 let pool = (self.create_restricted_pool)().connect_lazy_with(opts);
168 Ok(pool)
169 }
170
171 async fn get_table_names(
172 &self,
173 db_name: &str,
174 conn: &mut MySqlConnection,
175 ) -> Result<Vec<String>, QueryError> {
176 conn.fetch_all(mysql::get_table_names(db_name).as_str())
177 .await?
178 .iter()
179 .map(|row| row.try_get(0))
180 .collect::<Result<Vec<_>, _>>()
181 .map_err(Into::into)
182 }
183
184 fn get_drop_previous_databases(&self) -> bool {
185 self.drop_previous_databases_flag
186 }
187}
188
/// Backend error with this backend's concrete sqlx error types plugged in.
type BError = BackendError<BuildError, PoolError, ConnectionError, QueryError>;
190
191#[async_trait]
192impl Backend for SqlxMySQLBackend {
193 type Pool = MySqlPool;
194
195 type BuildError = BuildError;
196 type PoolError = PoolError;
197 type ConnectionError = ConnectionError;
198 type QueryError = QueryError;
199
200 async fn init(&self) -> Result<(), BError> {
201 MySQLBackendWrapper::new(self).init().await
202 }
203
204 async fn create(
205 &self,
206 db_id: uuid::Uuid,
207 restrict_privileges: bool,
208 ) -> Result<MySqlPool, BError> {
209 MySQLBackendWrapper::new(self)
210 .create(db_id, restrict_privileges)
211 .await
212 }
213
214 async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> {
215 MySQLBackendWrapper::new(self).clean(db_id).await
216 }
217
218 async fn drop(&self, db_id: uuid::Uuid, _is_restricted: bool) -> Result<(), BError> {
219 MySQLBackendWrapper::new(self).drop(db_id).await
220 }
221}
222
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::needless_return)]

    use futures::{future::join_all, StreamExt};
    use sqlx::{
        mysql::{MySqlConnectOptions, MySqlPoolOptions},
        query, query_as, Executor, FromRow, Row,
    };
    use tokio_shared_rt::test;

    use crate::{
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        r#async::{
            backend::mysql::r#trait::tests::test_backend_creates_database_with_unrestricted_privileges,
            db_pool::DatabasePoolBuilder,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_created_unrestricted_database, test_pool_drops_previous_databases,
            MySQLDropLock,
        },
        SqlxMySQLBackend,
    };

    // Builds a backend against the privileged test config. When `with_table`
    // is true, every created database is seeded with the shared test entities
    // (executed as a single `;`-joined batch).
    // NOTE(review): these tests require a live MySQL instance reachable via
    // `get_privileged_mysql_config` — they are integration tests.
    fn create_backend(with_table: bool) -> SqlxMySQLBackend {
        let config = get_privileged_mysql_config();
        let opts = MySqlConnectOptions::new().username(config.username.as_str());
        // A password is optional in the test config.
        let opts = if let Some(password) = &config.password {
            opts.password(password)
        } else {
            opts
        };
        SqlxMySQLBackend::new(opts, MySqlPoolOptions::new, MySqlPoolOptions::new, {
            move |mut conn| {
                if with_table {
                    Box::pin(async move {
                        // Run all seed statements and fail fast on the first error.
                        conn.execute_many(CREATE_ENTITIES_STATEMENTS.join(";").as_str())
                            .collect::<Vec<_>>()
                            .await
                            .drain(..)
                            .collect::<Result<Vec<_>, _>>()
                            .unwrap();
                    })
                } else {
                    // No seeding requested: a no-op future.
                    Box::pin(async {})
                }
            }
        })
    }

    // Init must drop leftover databases by default and when explicitly enabled,
    // and must keep them when explicitly disabled.
    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, true).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(backend, false).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_backend(false).drop_previous_databases(true),
            create_backend(false).drop_previous_databases(false),
        )
        .await;
    }

    // A row inserted through one pulled pool must not be visible through any
    // other pulled pool: each pool is backed by its own database.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_isolated_databases() {
        #[derive(FromRow, Eq, PartialEq, Debug)]
        struct Book {
            title: String,
        }

        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();
            let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

            // Insert a distinct title into each database.
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        query("INSERT INTO book (title) VALUES (?)")
                            .bind(format!("Title {i}"))
                            .execute(&***conn_pool)
                            .await
                            .unwrap();
                    }),
            )
            .await;

            // Each database must contain exactly its own title and nothing else.
            join_all(
                conn_pools
                    .iter()
                    .enumerate()
                    .map(|(i, conn_pool)| async move {
                        assert_eq!(
                            query_as::<_, Book>("SELECT title FROM book")
                                .fetch_all(&***conn_pool)
                                .await
                                .unwrap(),
                            vec![Book {
                                title: format!("Title {i}")
                            }]
                        );
                    }),
            )
            .await;
        }
        .lock_read()
        .await;
    }

    // Connections from a pulled (restricted) pool may run DML but not DDL.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            let conn_pool = db_pool.pull_immutable().await;
            let conn = &mut conn_pool.acquire().await.unwrap();

            // DDL must be rejected for the restricted user.
            for stmt in DDL_STATEMENTS {
                assert!(conn.execute(stmt).await.is_err());
            }

            // DML must succeed for the restricted user.
            for stmt in DML_STATEMENTS {
                assert!(conn.execute(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Connections from a mutable (unrestricted) pool may run both DML and DDL.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // DML on an unrestricted database.
            {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                for stmt in DML_STATEMENTS {
                    assert!(conn.execute(stmt).await.is_ok());
                }
            }

            // DDL on an unrestricted database — a fresh database per statement
            // so earlier DDL cannot interfere with later statements.
            for stmt in DDL_STATEMENTS {
                let conn_pool = db_pool.create_mutable().await.unwrap();
                let conn = &mut conn_pool.acquire().await.unwrap();
                assert!(conn.execute(stmt).await.is_ok());
            }
        }
        .lock_read()
        .await;
    }

    // Databases returned to the pool must be cleaned: a second pull sees empty
    // tables even though rows were inserted during the first pull.
    #[test(flavor = "multi_thread", shared)]
    async fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        async {
            let db_pool = backend.create_database_pool().await.unwrap();

            // Fetch databases for the first time and dirty them.
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                // Confirm each database starts empty.
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;

                // Insert a row so the databases are dirty when returned.
                join_all(conn_pools.iter().map(|conn_pool| async move {
                    query("INSERT INTO book (title) VALUES (?)")
                        .bind("Title")
                        .execute(&***conn_pool)
                        .await
                        .unwrap();
                }))
                .await;
            }

            // Fetch databases for the second time: they must be clean again.
            {
                let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await;

                join_all(conn_pools.iter().map(|conn_pool| async move {
                    assert_eq!(
                        query("SELECT COUNT(*) FROM book")
                            .fetch_one(&***conn_pool)
                            .await
                            .unwrap()
                            .get::<i64, _>(0),
                        0
                    );
                }))
                .await;
            }
        }
        .lock_read()
        .await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_restricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_restricted_databases(backend).await;
    }

    #[test(flavor = "multi_thread", shared)]
    async fn pool_drops_created_unrestricted_databases() {
        let backend = create_backend(false);
        test_pool_drops_created_unrestricted_database(backend).await;
    }
}