db_pool/sync/backend/mysql/
diesel.rs

1use std::borrow::Cow;
2
3use diesel::{
4    connection::SimpleConnection,
5    mysql::MysqlConnection,
6    prelude::*,
7    r2d2::ConnectionManager,
8    result::{ConnectionError, Error, QueryResult},
9    sql_query,
10};
11use r2d2::{Builder, Pool, PooledConnection};
12use uuid::Uuid;
13
14use crate::{
15    common::{config::mysql::PrivilegedMySQLConfig, statement::mysql},
16    util::get_db_name,
17};
18
19use super::{
20    super::{error::Error as BackendError, r#trait::Backend},
21    r#trait::{MySQLBackend, MySQLBackendWrapper},
22};
23
/// r2d2 connection manager specialised for Diesel's [`MysqlConnection`]; used for both the
/// privileged pool and the per-database pools built by this backend.
type Manager = ConnectionManager<MysqlConnection>;
25
/// [`Diesel MySQL`](https://docs.rs/diesel/2.2.4/diesel/mysql/struct.MysqlConnection.html) backend
pub struct DieselMySQLBackend {
    /// Privileged server configuration; supplies the host name and the connection URLs
    privileged_config: PrivilegedMySQLConfig,
    /// Pool built in `new` from the privileged default connection URL; serves `get_connection`
    default_pool: Pool<Manager>,
    /// Produces the r2d2 builder used each time a per-database connection pool is created
    create_restricted_pool: Box<dyn Fn() -> Builder<Manager> + Send + Sync + 'static>,
    /// User-supplied callback that creates the entities (e.g. tables) on a given connection
    create_entities: Box<dyn Fn(&mut MysqlConnection) + Send + Sync + 'static>,
    /// Value reported by `get_drop_previous_databases`; set to `true` by `new`
    drop_previous_databases_flag: bool,
}
34
35impl DieselMySQLBackend {
36    /// Creates a new [`Diesel MySQL`](https://docs.rs/diesel/2.2.4/diesel/mysql/struct.MysqlConnection.html) backend
37    /// # Example
38    /// ```
39    /// use db_pool::{sync::DieselMySQLBackend, PrivilegedMySQLConfig};
40    /// use diesel::{sql_query, RunQueryDsl};
41    /// use dotenvy::dotenv;
42    /// use r2d2::Pool;
43    ///
44    /// dotenv().ok();
45    ///
46    /// let config = PrivilegedMySQLConfig::from_env().unwrap();
47    ///
48    /// let backend = DieselMySQLBackend::new(
49    ///     config,
50    ///     || Pool::builder().max_size(10),
51    ///     || Pool::builder().max_size(2),
52    ///     move |conn| {
53    ///         sql_query(
54    ///             "CREATE TABLE book(id INTEGER PRIMARY KEY AUTO_INCREMENT, title TEXT NOT NULL)",
55    ///         )
56    ///         .execute(conn)
57    ///         .unwrap();
58    ///     },
59    /// )
60    /// .unwrap();
61    /// ```
62    pub fn new(
63        privileged_config: PrivilegedMySQLConfig,
64        create_privileged_pool: impl Fn() -> Builder<Manager>,
65        create_restricted_pool: impl Fn() -> Builder<Manager> + Send + Sync + 'static,
66        create_entities: impl Fn(&mut MysqlConnection) + Send + Sync + 'static,
67    ) -> Result<Self, r2d2::Error> {
68        let manager = Manager::new(privileged_config.default_connection_url());
69        let default_pool = (create_privileged_pool()).build(manager)?;
70
71        Ok(Self {
72            privileged_config,
73            default_pool,
74            create_entities: Box::new(create_entities),
75            create_restricted_pool: Box::new(create_restricted_pool),
76            drop_previous_databases_flag: true,
77        })
78    }
79
80    /// Drop databases created in previous runs upon initialization
81    #[must_use]
82    pub fn drop_previous_databases(self, value: bool) -> Self {
83        Self {
84            drop_previous_databases_flag: value,
85            ..self
86        }
87    }
88}
89
90impl MySQLBackend for DieselMySQLBackend {
91    type ConnectionManager = Manager;
92    type ConnectionError = ConnectionError;
93    type QueryError = Error;
94
95    fn get_connection(&self) -> Result<PooledConnection<Manager>, r2d2::Error> {
96        self.default_pool.get()
97    }
98
99    fn execute(&self, query: &str, conn: &mut MysqlConnection) -> QueryResult<()> {
100        sql_query(query).execute(conn)?;
101        Ok(())
102    }
103
104    fn batch_execute<'a>(
105        &self,
106        query: impl IntoIterator<Item = Cow<'a, str>>,
107        conn: &mut MysqlConnection,
108    ) -> QueryResult<()> {
109        let query = query.into_iter().collect::<Vec<_>>();
110        if query.is_empty() {
111            Ok(())
112        } else {
113            conn.batch_execute(query.join(";").as_str())
114        }
115    }
116
117    fn get_host(&self) -> Cow<str> {
118        self.privileged_config.host.as_str().into()
119    }
120
121    fn get_previous_database_names(
122        &self,
123        conn: &mut <Self::ConnectionManager as r2d2::ManageConnection>::Connection,
124    ) -> QueryResult<Vec<String>> {
125        table! {
126            schemata (schema_name) {
127                schema_name -> Text
128            }
129        }
130
131        schemata::table
132            .select(schemata::schema_name)
133            .filter(schemata::schema_name.like("db_pool_%"))
134            .load::<String>(conn)
135    }
136
137    fn create_entities(&self, conn: &mut MysqlConnection) {
138        (self.create_entities)(conn);
139    }
140
141    fn create_connection_pool(
142        &self,
143        db_id: Uuid,
144    ) -> Result<Pool<Self::ConnectionManager>, r2d2::Error> {
145        let db_name = get_db_name(db_id);
146        let db_name = db_name.as_str();
147        let database_url = self.privileged_config.restricted_database_connection_url(
148            db_name,
149            Some(db_name),
150            db_name,
151        );
152        let manager = ConnectionManager::<MysqlConnection>::new(database_url.as_str());
153        (self.create_restricted_pool)().build(manager)
154    }
155
156    fn get_table_names(
157        &self,
158        db_name: &str,
159        conn: &mut MysqlConnection,
160    ) -> QueryResult<Vec<String>> {
161        table! {
162            tables (table_name) {
163                table_name -> Text,
164                table_schema -> Text
165            }
166        }
167
168        sql_query(mysql::USE_DEFAULT_DATABASE).execute(conn)?;
169
170        tables::table
171            .filter(tables::table_schema.eq(db_name))
172            .select(tables::table_name)
173            .load::<String>(conn)
174    }
175
176    fn get_drop_previous_databases(&self) -> bool {
177        self.drop_previous_databases_flag
178    }
179}
180
// Every method below wraps `self` in a `MySQLBackendWrapper` and forwards the call to it.
impl Backend for DieselMySQLBackend {
    type ConnectionManager = Manager;
    type ConnectionError = ConnectionError;
    type QueryError = Error;

    /// Initializes the backend by delegating to `MySQLBackendWrapper::init`.
    fn init(&self) -> Result<(), BackendError<ConnectionError, Error>> {
        MySQLBackendWrapper::new(self).init()
    }

    /// Creates the database for `db_id` and returns its connection pool; both arguments are
    /// forwarded unchanged to `MySQLBackendWrapper::create`.
    fn create(
        &self,
        db_id: Uuid,
        restrict_privileges: bool,
    ) -> Result<Pool<Manager>, BackendError<ConnectionError, Error>> {
        MySQLBackendWrapper::new(self).create(db_id, restrict_privileges)
    }

    /// Cleans the database for `db_id` by delegating to `MySQLBackendWrapper::clean`.
    fn clean(&self, db_id: Uuid) -> Result<(), BackendError<ConnectionError, Error>> {
        MySQLBackendWrapper::new(self).clean(db_id)
    }

    /// Drops the database for `db_id` by delegating to `MySQLBackendWrapper::drop`.
    ///
    /// `_is_restricted` is not used: only `db_id` is forwarded.
    fn drop(
        &self,
        db_id: Uuid,
        _is_restricted: bool,
    ) -> Result<(), BackendError<ConnectionError, Error>> {
        MySQLBackendWrapper::new(self).drop(db_id)
    }
}
210
#[cfg(test)]
mod tests {
    #![allow(unused_variables, clippy::unwrap_used)]

    use std::borrow::Cow;

    use diesel::{
        connection::SimpleConnection, insert_into, sql_query, table, Insertable, QueryDsl,
        RunQueryDsl,
    };
    use r2d2::Pool;

    use crate::{
        common::statement::mysql::tests::{
            CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS,
        },
        sync::{
            backend::mysql::r#trait::tests::{
                test_backend_creates_database_with_unrestricted_privileges,
                test_pool_drops_created_unrestricted_database,
            },
            db_pool::DatabasePoolBuilder,
        },
        tests::get_privileged_mysql_config,
    };

    use super::{
        super::r#trait::tests::{
            lock_read, test_backend_cleans_database_with_tables,
            test_backend_cleans_database_without_tables,
            test_backend_creates_database_with_restricted_privileges, test_backend_drops_database,
            test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases,
            test_pool_drops_previous_databases,
        },
        DieselMySQLBackend,
    };

    // Diesel schema of the `book` table queried by the pool tests
    table! {
        book (id) {
            id -> Int4,
            title -> Text
        }
    }

    /// Row payload for inserting into `book`
    #[derive(Insertable)]
    #[diesel(table_name = book)]
    struct NewBook<'a> {
        title: Cow<'a, str>,
    }

    /// Builds a backend with default pool builders; the entity-creation statements are run
    /// only when `with_table` is set.
    fn create_backend(with_table: bool) -> DieselMySQLBackend {
        let config = get_privileged_mysql_config().clone();
        let backend = DieselMySQLBackend::new(config, Pool::builder, Pool::builder, move |conn| {
            if !with_table {
                return;
            }
            let statements = CREATE_ENTITIES_STATEMENTS.join(";");
            conn.batch_execute(statements.as_str()).unwrap();
        });
        backend.unwrap()
    }

    /// Backend without entities, with the given "drop previous databases" setting
    fn create_empty_backend(drop_previous: bool) -> DieselMySQLBackend {
        create_backend(false).drop_previous_databases(drop_previous)
    }

    /// Backend with entities that leaves databases from previous runs alone
    fn create_backend_with_entities() -> DieselMySQLBackend {
        create_backend(true).drop_previous_databases(false)
    }

    #[test]
    fn backend_drops_previous_databases() {
        test_backend_drops_previous_databases(
            create_backend(false),
            create_empty_backend(true),
            create_empty_backend(false),
        );
    }

    #[test]
    fn backend_creates_database_with_restricted_privileges() {
        let backend = create_backend_with_entities();
        test_backend_creates_database_with_restricted_privileges(&backend);
    }

    #[test]
    fn backend_creates_database_with_unrestricted_privileges() {
        let backend = create_backend_with_entities();
        test_backend_creates_database_with_unrestricted_privileges(&backend);
    }

    #[test]
    fn backend_cleans_database_with_tables() {
        let backend = create_backend_with_entities();
        test_backend_cleans_database_with_tables(&backend);
    }

    #[test]
    fn backend_cleans_database_without_tables() {
        let backend = create_empty_backend(false);
        test_backend_cleans_database_without_tables(&backend);
    }

    #[test]
    fn backend_drops_restricted_database() {
        let backend = create_backend_with_entities();
        test_backend_drops_database(&backend, true);
    }

    #[test]
    fn backend_drops_unrestricted_database() {
        let backend = create_backend_with_entities();
        test_backend_drops_database(&backend, false);
    }

    #[test]
    fn pool_drops_previous_databases() {
        test_pool_drops_previous_databases(
            create_backend(false),
            create_empty_backend(true),
            create_empty_backend(false),
        );
    }

    /// A row written through one pulled database must be the only row visible through it
    #[test]
    fn pool_provides_isolated_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend_with_entities();

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pools: Vec<_> = (0..NUM_DBS).map(|_| db_pool.pull_immutable()).collect();

        // write one distinct row per database
        for (i, conn_pool) in conn_pools.iter().enumerate() {
            let mut conn = conn_pool.get().unwrap();
            let new_book = NewBook {
                title: format!("Title {i}").into(),
            };
            insert_into(book::table)
                .values(new_book)
                .execute(&mut conn)
                .unwrap();
        }

        // each database must hold exactly the row written to it
        for (i, conn_pool) in conn_pools.iter().enumerate() {
            let mut conn = conn_pool.get().unwrap();
            let titles = book::table
                .select(book::title)
                .load::<String>(&mut conn)
                .unwrap();
            assert_eq!(titles, vec![format!("Title {i}")]);
        }
    }

    /// Pulled (immutable) databases must reject DDL but accept DML
    #[test]
    fn pool_provides_restricted_databases() {
        let backend = create_backend_with_entities();

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();
        let conn_pool = db_pool.pull_immutable();
        let mut conn = conn_pool.get().unwrap();

        // every DDL statement is expected to be refused
        for stmt in DDL_STATEMENTS {
            let outcome = sql_query(stmt).execute(&mut conn);
            assert!(outcome.is_err());
        }

        // every DML statement is expected to go through
        for stmt in DML_STATEMENTS {
            let outcome = sql_query(stmt).execute(&mut conn);
            assert!(outcome.is_ok());
        }
    }

    /// Databases obtained with `create_mutable` must accept both DML and DDL
    #[test]
    fn pool_provides_unrestricted_databases() {
        let backend = create_backend_with_entities();

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // all DML statements run against a single mutable database
        {
            let conn_pool = db_pool.create_mutable().unwrap();
            let mut conn = conn_pool.get().unwrap();
            for stmt in DML_STATEMENTS {
                let outcome = sql_query(stmt).execute(&mut conn);
                assert!(outcome.is_ok());
            }
        }

        // each DDL statement runs against its own freshly created mutable database
        for stmt in DDL_STATEMENTS {
            let conn_pool = db_pool.create_mutable().unwrap();
            let mut conn = conn_pool.get().unwrap();
            let outcome = sql_query(stmt).execute(&mut conn);
            assert!(outcome.is_ok());
        }
    }

    /// Databases pulled again after being released must come back empty
    #[test]
    fn pool_provides_clean_databases() {
        const NUM_DBS: i64 = 3;

        let backend = create_backend_with_entities();

        let guard = lock_read();

        let db_pool = backend.create_database_pool().unwrap();

        // first round: verify emptiness, then leave a row behind in every database
        {
            let conn_pools: Vec<_> = (0..NUM_DBS).map(|_| db_pool.pull_immutable()).collect();

            for conn_pool in &conn_pools {
                let mut conn = conn_pool.get().unwrap();
                let row_count = book::table.count().get_result::<i64>(&mut conn).unwrap();
                assert_eq!(row_count, 0);
            }

            for conn_pool in &conn_pools {
                let mut conn = conn_pool.get().unwrap();
                let new_book = NewBook {
                    title: "Title".into(),
                };
                insert_into(book::table)
                    .values(new_book)
                    .execute(&mut conn)
                    .unwrap();
            }
        }

        // second round: the databases pulled now must be empty again
        {
            let conn_pools: Vec<_> = (0..NUM_DBS).map(|_| db_pool.pull_immutable()).collect();

            for conn_pool in &conn_pools {
                let mut conn = conn_pool.get().unwrap();
                let row_count = book::table.count().get_result::<i64>(&mut conn).unwrap();
                assert_eq!(row_count, 0);
            }
        }
    }

    #[test]
    fn pool_drops_created_restricted_databases() {
        test_pool_drops_created_restricted_databases(create_backend(false));
    }

    #[test]
    fn pool_drops_created_unrestricted_databases() {
        test_pool_drops_created_unrestricted_database(create_backend(false));
    }
}