db_pool/sync/backend/mysql/
mysql.rs

1use std::borrow::Cow;
2
3use r2d2::{Builder, Pool, PooledConnection};
4use r2d2_mysql::{
5    mysql::{prelude::*, Conn, Error, Opts, OptsBuilder},
6    MySqlConnectionManager,
7};
8use uuid::Uuid;
9
10use crate::{common::statement::mysql, util::get_db_name};
11
12use super::{
13    super::{error::Error as BackendError, r#trait::Backend},
14    r#trait::{MySQLBackend as MySQLBackendTrait, MySQLBackendWrapper},
15};
16
17type Manager = MySqlConnectionManager;
18
19/// MySQL backend
20pub struct MySQLBackend {
21    opts: Opts,
22    default_pool: Pool<Manager>,
23    create_restricted_pool: Box<dyn Fn() -> Builder<Manager> + Send + Sync + 'static>,
24    create_entities: Box<dyn Fn(&mut Conn) + Send + Sync + 'static>,
25    drop_previous_databases_flag: bool,
26}
27
28impl MySQLBackend {
29    /// Creates a new MySQL backend
30    /// # Example
31    /// ```
32    /// use db_pool::{sync::MySQLBackend, PrivilegedMySQLConfig};
33    /// use dotenvy::dotenv;
34    /// use r2d2::Pool;
35    /// use r2d2_mysql::mysql::{prelude::Queryable, OptsBuilder};
36    ///
37    /// dotenv().ok();
38    ///
39    /// let config = PrivilegedMySQLConfig::from_env().unwrap();
40    ///
41    /// let backend = MySQLBackend::new(
42    ///     config.into(),
43    ///     || Pool::builder().max_size(10),
44    ///     || Pool::builder().max_size(2),
45    ///     move |conn| {
46    ///         conn.query_drop(
47    ///             "CREATE TABLE book(id INTEGER PRIMARY KEY AUTO_INCREMENT, title TEXT NOT NULL)",
48    ///         )
49    ///         .unwrap();
50    ///     },
51    /// )
52    /// .unwrap();
53    /// ```
54    pub fn new(
55        opts: Opts,
56        create_privileged_pool: impl Fn() -> Builder<Manager>,
57        create_restricted_pool: impl Fn() -> Builder<Manager> + Send + Sync + 'static,
58        create_entities: impl Fn(&mut Conn) + Send + Sync + 'static,
59    ) -> Result<Self, r2d2::Error> {
60        let manager = Manager::new(OptsBuilder::from_opts(opts.clone()));
61        let default_pool = (create_privileged_pool()).build(manager)?;
62
63        Ok(Self {
64            opts,
65            default_pool,
66            create_entities: Box::new(create_entities),
67            create_restricted_pool: Box::new(create_restricted_pool),
68            drop_previous_databases_flag: true,
69        })
70    }
71
72    /// Drop databases created in previous runs upon initialization
73    #[must_use]
74    pub fn drop_previous_databases(self, value: bool) -> Self {
75        Self {
76            drop_previous_databases_flag: value,
77            ..self
78        }
79    }
80}
81
82impl MySQLBackendTrait for MySQLBackend {
83    type ConnectionManager = Manager;
84    type ConnectionError = Error;
85    type QueryError = Error;
86
87    fn get_connection(&self) -> Result<PooledConnection<Manager>, r2d2::Error> {
88        self.default_pool.get()
89    }
90
91    fn execute(&self, query: &str, conn: &mut Conn) -> Result<(), Error> {
92        conn.query_drop(query)
93    }
94
95    fn batch_execute<'a>(
96        &self,
97        query: impl IntoIterator<Item = Cow<'a, str>>,
98        conn: &mut Conn,
99    ) -> Result<(), Error> {
100        let chunks = query.into_iter().collect::<Vec<_>>();
101        if chunks.is_empty() {
102            Ok(())
103        } else {
104            let query = chunks.join(";");
105            self.execute(query.as_str(), conn)
106        }
107    }
108
109    fn get_host(&self) -> Cow<str> {
110        self.opts.get_ip_or_hostname()
111    }
112
113    fn get_previous_database_names(
114        &self,
115        conn: &mut <Self::ConnectionManager as r2d2::ManageConnection>::Connection,
116    ) -> Result<Vec<String>, Error> {
117        conn.query(mysql::GET_DATABASE_NAMES)
118    }
119
120    fn create_entities(&self, conn: &mut Conn) {
121        (self.create_entities)(conn);
122    }
123
124    fn create_connection_pool(&self, db_id: Uuid) -> Result<Pool<Manager>, r2d2::Error> {
125        let db_name = get_db_name(db_id);
126        let db_name = db_name.as_str();
127        let opts = OptsBuilder::from_opts(self.opts.clone())
128            .db_name(Some(db_name))
129            .user(Some(db_name))
130            .pass(Some(db_name));
131        let manager = MySqlConnectionManager::new(opts);
132        (self.create_restricted_pool)().build(manager)
133    }
134
135    fn get_table_names(&self, db_name: &str, conn: &mut Conn) -> Result<Vec<String>, Error> {
136        conn.query(mysql::get_table_names(db_name))
137    }
138
139    fn get_drop_previous_databases(&self) -> bool {
140        self.drop_previous_databases_flag
141    }
142}
143
144impl From<Error> for BackendError<Error, Error> {
145    fn from(value: Error) -> Self {
146        Self::Query(value)
147    }
148}
149
150impl Backend for MySQLBackend {
151    type ConnectionManager = Manager;
152    type ConnectionError = Error;
153    type QueryError = Error;
154
155    fn init(&self) -> Result<(), BackendError<Error, Error>> {
156        MySQLBackendWrapper::new(self).init()
157    }
158
159    fn create(
160        &self,
161        db_id: Uuid,
162        restrict_privileges: bool,
163    ) -> Result<Pool<Manager>, BackendError<Error, Error>> {
164        MySQLBackendWrapper::new(self).create(db_id, restrict_privileges)
165    }
166
167    fn clean(&self, db_id: Uuid) -> Result<(), BackendError<Error, Error>> {
168        MySQLBackendWrapper::new(self).clean(db_id)
169    }
170
171    fn drop(&self, db_id: Uuid, _is_restricted: bool) -> Result<(), BackendError<Error, Error>> {
172        MySQLBackendWrapper::new(self).drop(db_id)
173    }
174}
175
#[cfg(test)]
mod tests {
    // Lock guards are bound to a name only to keep them alive for the rest of the test, and
    // panicking through `unwrap` is the intended failure mode here.
    #![allow(unused_variables, clippy::unwrap_used)]

    use r2d2::Pool;
    use r2d2_mysql::mysql::{params, prelude::Queryable};

    use crate::{
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        sync::{
            backend::mysql::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_pool_drops_created_unrestricted_database,
            },
            DatabasePoolBuilderTrait,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            lock_read, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases,
        },
        MySQLBackend,
    };

    /// Builds a backend from the privileged test configuration with default pool builders.
    ///
    /// When `with_table` is set, the entity callback runs `CREATE_ENTITIES_STATEMENTS`
    /// joined by `;`; otherwise the callback does nothing.
    fn create_backend(with_table: bool) -> MySQLBackend {
        let privileged_config = get_privileged_mysql_config().clone();
        MySQLBackend::new(
            privileged_config.into(),
            Pool::builder,
            Pool::builder,
            move |conn| {
                if with_table {
                    conn.query_drop(CREATE_ENTITIES_STATEMENTS.join(";"))
                        .unwrap();
                }
            },
        )
        .unwrap()
    }

    #[test]
    fn backend_drops_previous_databases() {
        let default = create_backend(false);
        let enabled = create_backend(false).drop_previous_databases(true);
        let disabled = create_backend(false).drop_previous_databases(false);
        test_backend_drops_previous_databases(default, enabled, disabled);
    }

    #[test]
    fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_restricted_privileges(&backend);
    }

    #[test]
    fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_creates_database_with_unrestricted_privileges(&backend);
    }

    #[test]
    fn backend_cleans_database_with_tables() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_cleans_database_with_tables(&backend);
    }

    #[test]
    fn backend_cleans_database_without_tables() {
        let backend = create_backend(false).drop_previous_databases(false);
        test_backend_cleans_database_without_tables(&backend);
    }

    #[test]
    fn backend_drops_restricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(&backend, true);
    }

    #[test]
    fn backend_drops_unrestricted_database() {
        let backend = create_backend(true).drop_previous_databases(false);
        test_backend_drops_database(&backend, false);
    }

    #[test]
    fn pool_drops_previous_databases() {
        let default = create_backend(false);
        let enabled = create_backend(false).drop_previous_databases(true);
        let disabled = create_backend(false).drop_previous_databases(false);
        test_pool_drops_previous_databases(default, enabled, disabled);
    }

    /// A row written through one pulled database must not be visible through the others.
    #[test]
    fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pools: Vec<_> = (0..NUM_DBS).map(|_| db_pool.pull_immutable()).collect();

        // Write one distinct row into every database.
        for (i, conn_pool) in conn_pools.iter().enumerate() {
            let mut conn = conn_pool.get().unwrap();
            conn.exec_drop(
                "INSERT INTO book (title) VALUES (:title)",
                params! {
                    "title" => format!("Title {i}"),
                },
            )
            .unwrap();
        }

        // Every database must contain exactly the row written into it.
        for (i, conn_pool) in conn_pools.iter().enumerate() {
            let mut conn = conn_pool.get().unwrap();
            let titles = conn.query::<String, _>("SELECT title FROM book").unwrap();
            assert_eq!(titles, vec![format!("Title {i}")]);
        }
    }

    /// Connections of an immutable (restricted) database may run DML but not DDL.
    #[test]
    fn pool_provides_restricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pool = db_pool.pull_immutable();
        let mut conn = conn_pool.get().unwrap();

        // Schema changes have to be rejected.
        for stmt in DDL_STATEMENTS {
            assert!(conn.query_drop(stmt).is_err());
        }

        // Data changes have to go through.
        for stmt in DML_STATEMENTS {
            assert!(conn.query_drop(stmt).is_ok());
        }
    }

    /// Connections of a mutable (unrestricted) database may run both DML and DDL.
    #[test]
    fn pool_provides_unrestricted_databases() {
        let backend = create_backend(true).drop_previous_databases(false);

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // Data changes have to go through; the database is released when the block ends.
        {
            let conn_pool = db_pool.create_mutable().unwrap();
            let mut conn = conn_pool.get().unwrap();
            for stmt in DML_STATEMENTS {
                assert!(conn.query_drop(stmt).is_ok());
            }
        }

        // Schema changes have to go through; each statement gets a database of its own.
        for stmt in DDL_STATEMENTS {
            let conn_pool = db_pool.create_mutable().unwrap();
            let mut conn = conn_pool.get().unwrap();
            assert!(conn.query_drop(stmt).is_ok());
        }
    }

    /// Databases handed out again after being released must not contain earlier rows.
    #[test]
    fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend(true).drop_previous_databases(false);

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // First round: pull the databases, verify they are empty, then dirty them.
        let conn_pools: Vec<_> = (0..NUM_DBS).map(|_| db_pool.pull_immutable()).collect();

        for conn_pool in &conn_pools {
            let mut conn = conn_pool.get().unwrap();
            let book_count = conn
                .query_first::<i64, _>("SELECT COUNT(*) FROM book")
                .unwrap()
                .unwrap();
            assert_eq!(book_count, 0);
        }

        for conn_pool in &conn_pools {
            let mut conn = conn_pool.get().unwrap();
            conn.exec_drop(
                "INSERT INTO book (title) VALUES (:title)",
                params! {
                    "title" => "Title",
                },
            )
            .unwrap();
        }

        // Release the databases before they are pulled a second time.
        drop(conn_pools);

        // Second round: the databases pulled now must be empty.
        let conn_pools: Vec<_> = (0..NUM_DBS).map(|_| db_pool.pull_immutable()).collect();

        for conn_pool in &conn_pools {
            let mut conn = conn_pool.get().unwrap();
            let book_count = conn
                .query_first::<i64, _>("SELECT COUNT(*) FROM book")
                .unwrap()
                .unwrap();
            assert_eq!(book_count, 0);
        }
    }

    #[test]
    fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false));
    }

    #[test]
    fn pool_drops_created_unrestricted_database() {
        test_pool_drops_created_unrestricted_database(create_backend(false));
    }
}