strut_database/
config.rs

1use serde::de::{Error, IgnoredAny, MapAccess, Visitor};
2use serde::{Deserialize, Deserializer};
3use serde_value::Value;
4use std::collections::BTreeMap;
5use std::fmt::Formatter;
6use strut_factory::impl_deserialize_field;
7
/// Application-level configuration section for database connectivity.
///
/// It holds the connection details (instance URL and credentials) of every
/// database server this application talks to: an optional *default* handle,
/// plus a collection of named handles for each enabled database backend.
///
/// Which fields exist depends on the enabled Cargo features:
///
/// - `default_handle` is present when a `default-*` feature selects the
///   backend explicitly, or when exactly one of `mysql` / `postgres` /
///   `sqlite` is enabled;
/// - each `*_handles` collection is present when its backend feature is
///   enabled.
///
/// Deserialization is hand-written (see the [`Deserialize`] impl) so that the
/// textual configuration can be expressed in a more human-friendly way.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct DatabaseConfig {
    // Default handle, MySQL flavor.
    #[cfg(any(
        feature = "default-mysql",
        all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
    ))]
    default_handle: crate::MySqlHandle,

    // Default handle, PostgreSQL flavor.
    #[cfg(any(
        feature = "default-postgres",
        all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
    ))]
    default_handle: crate::PostgresHandle,

    // Default handle, SQLite flavor.
    #[cfg(any(
        feature = "default-sqlite",
        all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
    ))]
    default_handle: crate::SqliteHandle,

    // Named MySQL handles.
    #[cfg(feature = "mysql")]
    mysql_handles: crate::MySqlHandleCollection,

    // Named PostgreSQL handles.
    #[cfg(feature = "postgres")]
    postgres_handles: crate::PostgresHandleCollection,

    // Named SQLite handles.
    #[cfg(feature = "sqlite")]
    sqlite_handles: crate::SqliteHandleCollection,
}
43
44impl DatabaseConfig {
45    /// Returns the default [`MySqlHandle`](crate::MySqlHandle) for this
46    /// configuration.
47    #[cfg(any(
48        feature = "default-mysql",
49        all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
50    ))]
51    pub fn default_handle(&self) -> &crate::MySqlHandle {
52        &self.default_handle
53    }
54
55    /// Returns the default [`PostgresHandle`](crate::PostgresHandle) for this
56    /// configuration.
57    #[cfg(any(
58        feature = "default-postgres",
59        all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
60    ))]
61    pub fn default_handle(&self) -> &crate::PostgresHandle {
62        &self.default_handle
63    }
64
65    /// Returns the default [`SqliteHandle`](crate::SqliteHandle) for this
66    /// configuration.
67    #[cfg(any(
68        feature = "default-sqlite",
69        all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
70    ))]
71    pub fn default_handle(&self) -> &crate::SqliteHandle {
72        &self.default_handle
73    }
74
75    /// Returns the named [`MySqlHandle`](crate::MySqlHandle)s for this
76    /// configuration.
77    #[cfg(feature = "mysql")]
78    pub fn mysql_handles(&self) -> &crate::MySqlHandleCollection {
79        &self.mysql_handles
80    }
81
82    /// Returns the named [`PostgresHandle`](crate::PostgresHandle)s for this
83    /// configuration.
84    #[cfg(feature = "postgres")]
85    pub fn postgres_handles(&self) -> &crate::PostgresHandleCollection {
86        &self.postgres_handles
87    }
88
89    /// Returns the named [`SqliteHandle`](crate::SqliteHandle)s for this
90    /// configuration.
91    #[cfg(feature = "sqlite")]
92    pub fn sqlite_handles(&self) -> &crate::SqliteHandleCollection {
93        &self.sqlite_handles
94    }
95}
96
97const _: () = {
98    impl<'de> Deserialize<'de> for DatabaseConfig {
99        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
100        where
101            D: Deserializer<'de>,
102        {
103            deserializer.deserialize_map(DatabaseConfigVisitor)
104        }
105    }
106
107    struct DatabaseConfigVisitor;
108
109    impl<'de> Visitor<'de> for DatabaseConfigVisitor {
110        type Value = DatabaseConfig;
111
112        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
113            formatter.write_str(
114                "a map of application database configuration or a string URL for default database",
115            )
116        }
117
118        fn visit_str<E>(self, _value: &str) -> Result<Self::Value, E>
119        where
120            E: Error,
121        {
122            #[cfg(any(
123                feature = "default-mysql",
124                all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
125            ))]
126            let default_handle = crate::repr::handle::mysql::visit_url::<E>(_value, None)?;
127
128            #[cfg(any(
129                feature = "default-postgres",
130                all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
131            ))]
132            let default_handle = crate::repr::handle::postgres::visit_url::<E>(_value, None)?;
133
134            #[cfg(any(
135                feature = "default-sqlite",
136                all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
137            ))]
138            let default_handle = crate::repr::handle::sqlite::visit_url::<E>(_value, None)?;
139
140            Ok(DatabaseConfig {
141                #[cfg(any(
142                    feature = "default-mysql",
143                    feature = "default-postgres",
144                    feature = "default-sqlite",
145                    all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
146                    all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
147                    all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
148                ))]
149                default_handle,
150                ..DatabaseConfig::default()
151            })
152        }
153
154        fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
155        where
156            A: MapAccess<'de>,
157        {
158            #[cfg(any(
159                feature = "default-mysql",
160                feature = "default-postgres",
161                feature = "default-sqlite",
162                all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
163                all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
164                all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
165            ))]
166            let mut default_handle = None;
167
168            #[cfg(feature = "mysql")]
169            let mut mysql_handles = None;
170
171            #[cfg(feature = "postgres")]
172            let mut postgres_handles = None;
173
174            #[cfg(feature = "sqlite")]
175            let mut sqlite_handles = None;
176
177            let mut discarded = BTreeMap::<Value, Value>::new();
178
179            while let Some(key) = map.next_key::<Value>()? {
180                let field = DatabaseConfigField::deserialize(key.clone()).map_err(Error::custom)?;
181
182                match field {
183                    #[cfg(any(
184                        feature = "default-mysql",
185                        feature = "default-postgres",
186                        feature = "default-sqlite",
187                        all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
188                        all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
189                        all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
190                    ))]
191                    DatabaseConfigField::default_handle => {
192                        field.poll(&mut map, &mut default_handle)?
193                    }
194
195                    #[cfg(not(any(
196                        feature = "default-mysql",
197                        feature = "default-postgres",
198                        feature = "default-sqlite",
199                        all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
200                        all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
201                        all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
202                    )))]
203                    DatabaseConfigField::default_handle => map.next_value::<IgnoredAny>()?,
204
205                    #[cfg(feature = "mysql")]
206                    DatabaseConfigField::mysql_handles => {
207                        field.poll(&mut map, &mut mysql_handles)?
208                    }
209                    #[cfg(not(feature = "mysql"))]
210                    DatabaseConfigField::mysql_handles => map.next_value()?,
211
212                    #[cfg(feature = "postgres")]
213                    DatabaseConfigField::postgres_handles => {
214                        field.poll(&mut map, &mut postgres_handles)?
215                    }
216                    #[cfg(not(feature = "postgres"))]
217                    DatabaseConfigField::postgres_handles => map.next_value()?,
218
219                    #[cfg(feature = "sqlite")]
220                    DatabaseConfigField::sqlite_handles => {
221                        field.poll(&mut map, &mut sqlite_handles)?
222                    }
223                    #[cfg(not(feature = "sqlite"))]
224                    DatabaseConfigField::sqlite_handles => map.next_value()?,
225
226                    DatabaseConfigField::__ignore => {
227                        discarded.insert(key, map.next_value()?);
228                        IgnoredAny
229                    }
230                };
231            }
232
233            #[cfg(any(
234                feature = "default-mysql",
235                all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
236            ))]
237            if default_handle.is_none() {
238                default_handle = Some(
239                    crate::MySqlHandle::deserialize(Value::Map(discarded))
240                        .map_err(Error::custom)?,
241                );
242            }
243
244            #[cfg(any(
245                feature = "default-postgres",
246                all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
247            ))]
248            if default_handle.is_none() {
249                default_handle = Some(
250                    crate::PostgresHandle::deserialize(Value::Map(discarded))
251                        .map_err(Error::custom)?,
252                );
253            }
254
255            #[cfg(any(
256                feature = "default-sqlite",
257                all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
258            ))]
259            if default_handle.is_none() {
260                default_handle = Some(
261                    crate::SqliteHandle::deserialize(Value::Map(discarded))
262                        .map_err(Error::custom)?,
263                );
264            }
265
266            Ok(DatabaseConfig {
267                #[cfg(any(
268                    feature = "default-mysql",
269                    feature = "default-postgres",
270                    feature = "default-sqlite",
271                    all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
272                    all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
273                    all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
274                ))]
275                default_handle: default_handle.unwrap_or_default(),
276
277                #[cfg(feature = "mysql")]
278                mysql_handles: mysql_handles.unwrap_or_default(),
279
280                #[cfg(feature = "postgres")]
281                postgres_handles: postgres_handles.unwrap_or_default(),
282
283                #[cfg(feature = "sqlite")]
284                sqlite_handles: sqlite_handles.unwrap_or_default(),
285            })
286        }
287    }
288
289    impl_deserialize_field!(
290        DatabaseConfigField,
291        strut_deserialize::Slug::eq_as_slugs,
292        default_handle | default,
293        mysql_handles | mysql,
294        postgres_handles | postgres | pg | postgre_sql | postgresql,
295        sqlite_handles | sqlite,
296    );
297};
298
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Deserializes `yaml` into a [`DatabaseConfig`], panicking on any error.
    fn parse(yaml: &str) -> DatabaseConfig {
        serde_yml::from_str::<DatabaseConfig>(yaml).unwrap()
    }

    /// An empty document yields the default configuration.
    #[test]
    fn empty_input() {
        let expected = DatabaseConfig::default();

        let actual = parse(
            r#"
"#,
        );

        assert_eq!(expected, actual);
    }

    /// A top-level `url` key populates the default MySQL handle.
    #[test]
    #[cfg(any(
        feature = "default-mysql",
        all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
    ))]
    fn default_mysql_url() {
        let expected = DatabaseConfig {
            default_handle: make_test_mysql(None),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
url: mysql://alice:secret@example.com:9999/candy_shop
"#,
        );

        assert_eq!(expected, actual);
    }

    /// Exploded top-level connection keys populate the default MySQL handle.
    #[test]
    #[cfg(any(
        feature = "default-mysql",
        all(feature = "mysql", not(feature = "postgres"), not(feature = "sqlite")),
    ))]
    fn default_mysql_exploded() {
        let expected = DatabaseConfig {
            default_handle: make_test_mysql(None),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
host: example.com
port: 9999
username: alice
password: secret
database: candy_shop
"#,
        );

        assert_eq!(expected, actual);
    }

    /// A top-level `url` key populates the default PostgreSQL handle.
    #[test]
    #[cfg(any(
        feature = "default-postgres",
        all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
    ))]
    fn default_postgres_url() {
        let expected = DatabaseConfig {
            default_handle: make_test_postgres(None),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
url: postgres://alice:secret@example.com:9999/candy_shop
"#,
        );

        assert_eq!(expected, actual);
    }

    /// Exploded top-level connection keys populate the default PostgreSQL
    /// handle.
    #[test]
    #[cfg(any(
        feature = "default-postgres",
        all(feature = "postgres", not(feature = "mysql"), not(feature = "sqlite")),
    ))]
    fn default_postgres_exploded() {
        let expected = DatabaseConfig {
            default_handle: make_test_postgres(None),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
host: example.com
port: 9999
username: alice
password: secret
database: candy_shop
"#,
        );

        assert_eq!(expected, actual);
    }

    /// A top-level `url` key populates the default SQLite handle.
    #[test]
    #[cfg(any(
        feature = "default-sqlite",
        all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
    ))]
    fn default_sqlite_url() {
        let expected = DatabaseConfig {
            default_handle: make_test_sqlite(None),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
url: sqlite://file.db
"#,
        );

        assert_eq!(expected, actual);
    }

    /// An exploded top-level `filename` key populates the default SQLite
    /// handle.
    #[test]
    #[cfg(any(
        feature = "default-sqlite",
        all(feature = "sqlite", not(feature = "mysql"), not(feature = "postgres")),
    ))]
    fn default_sqlite_exploded() {
        let expected = DatabaseConfig {
            default_handle: make_test_sqlite(None),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
filename: file.db
"#,
        );

        assert_eq!(expected, actual);
    }

    /// Named MySQL handles can be given either as URLs or as exploded maps.
    #[test]
    #[cfg(feature = "mysql")]
    fn mysql() {
        let expected = DatabaseConfig {
            mysql_handles: crate::MySqlHandleCollection::from([
                ("mysql_a", make_test_mysql(Some("mysql_a"))),
                ("mysql_b", make_test_mysql(Some("mysql_b"))),
            ]),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
mysql:
    mysql_a: mysql://alice:secret@example.com:9999/candy_shop
    mysql_b:
        host: example.com
        port: 9999
        username: alice
        password: secret
        database: candy_shop
"#,
        );

        assert_eq!(expected, actual);
    }

    /// Named PostgreSQL handles can be given either as URLs or as exploded
    /// maps.
    #[test]
    #[cfg(feature = "postgres")]
    fn postgres() {
        let expected = DatabaseConfig {
            postgres_handles: crate::PostgresHandleCollection::from([
                ("postgres_a", make_test_postgres(Some("postgres_a"))),
                ("postgres_b", make_test_postgres(Some("postgres_b"))),
            ]),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
postgres:
    postgres_a: postgres://alice:secret@example.com:9999/candy_shop
    postgres_b:
        host: example.com
        port: 9999
        username: alice
        password: secret
        database: candy_shop
"#,
        );

        assert_eq!(expected, actual);
    }

    /// Named SQLite handles can be given either as URLs or as exploded maps.
    #[test]
    #[cfg(feature = "sqlite")]
    fn sqlite() {
        let expected = DatabaseConfig {
            sqlite_handles: crate::SqliteHandleCollection::from([
                ("sqlite_a", make_test_sqlite(Some("sqlite_a"))),
                ("sqlite_b", make_test_sqlite(Some("sqlite_b"))),
            ]),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
sqlite:
    sqlite_a: sqlite://file.db
    sqlite_b:
        filename: file.db
"#,
        );

        assert_eq!(expected, actual);
    }

    /// All three backend sections can appear side by side in one document.
    #[test]
    #[cfg(all(feature = "mysql", feature = "postgres", feature = "sqlite"))]
    fn all_types() {
        let expected = DatabaseConfig {
            mysql_handles: crate::MySqlHandleCollection::from([
                ("mysql_a", make_test_mysql(Some("mysql_a"))),
                ("mysql_b", make_test_mysql(Some("mysql_b"))),
            ]),
            postgres_handles: crate::PostgresHandleCollection::from([
                ("postgres_a", make_test_postgres(Some("postgres_a"))),
                ("postgres_b", make_test_postgres(Some("postgres_b"))),
            ]),
            sqlite_handles: crate::SqliteHandleCollection::from([
                ("sqlite_a", make_test_sqlite(Some("sqlite_a"))),
                ("sqlite_b", make_test_sqlite(Some("sqlite_b"))),
            ]),
            ..DatabaseConfig::default()
        };

        let actual = parse(
            r#"
mysql:
    mysql_a: mysql://alice:secret@example.com:9999/candy_shop
    mysql_b:
        host: example.com
        port: 9999
        username: alice
        password: secret
        database: candy_shop

postgres:
    postgres_a: postgres://alice:secret@example.com:9999/candy_shop
    postgres_b:
        host: example.com
        port: 9999
        username: alice
        password: secret
        database: candy_shop

sqlite:
    sqlite_a: sqlite://file.db
    sqlite_b:
        filename: file.db
"#,
        );

        assert_eq!(expected, actual);
    }

    /// Builds the MySQL handle that the MySQL fixtures above parse to,
    /// renamed to `name` when one is given.
    #[cfg(feature = "mysql")]
    fn make_test_mysql(name: Option<&str>) -> crate::MySqlHandle {
        let handle = crate::MySqlHandle::default().recreate_with_connect_options(|options| {
            options
                .username("alice")
                .password("secret")
                .host("example.com")
                .port(9999)
                .database("candy_shop")
        });

        match name {
            Some(name) => handle.recreate_with_name(name),
            None => handle,
        }
    }

    /// Builds the PostgreSQL handle that the PostgreSQL fixtures above parse
    /// to, renamed to `name` when one is given.
    #[cfg(feature = "postgres")]
    fn make_test_postgres(name: Option<&str>) -> crate::PostgresHandle {
        let handle = crate::PostgresHandle::default().recreate_with_connect_options(|options| {
            options
                .username("alice")
                .password("secret")
                .host("example.com")
                .port(9999)
                .database("candy_shop")
        });

        match name {
            Some(name) => handle.recreate_with_name(name),
            None => handle,
        }
    }

    /// Builds the SQLite handle that the SQLite fixtures above parse to,
    /// renamed to `name` when one is given.
    #[cfg(feature = "sqlite")]
    fn make_test_sqlite(name: Option<&str>) -> crate::SqliteHandle {
        let handle = crate::SqliteHandle::default()
            .recreate_with_connect_options(|options| options.filename("file.db"));

        match name {
            Some(name) => handle.recreate_with_name(name),
            None => handle,
        }
    }
}