refinery_core/
config.rs

1use crate::error::Kind;
2use crate::Error;
3#[cfg(any(
4    feature = "postgres",
5    feature = "tokio-postgres",
6    feature = "tiberius-config"
7))]
8use std::borrow::Cow;
9use std::convert::TryFrom;
10use std::str::FromStr;
11use url::Url;
12
13// refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros
14// Config can either be instanced with [`Config::new`] or retrieved from a config file with [`Config::from_file_location`]
15#[derive(Debug)]
16#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
17pub struct Config {
18    main: Main,
19}
20
/// The database backend a [`Config`] targets.
///
/// Selected by the URL scheme when parsing a URL (`mysql`, `postgres`/`postgresql`,
/// `sqlite`, `mssql`) or by the `db_type` key of a config file.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the type can be used as a
/// map/set key; a fieldless enum has no reason not to provide them.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ConfigDbType {
    Mysql,
    Postgres,
    Sqlite,
    Mssql,
}
29
30impl Config {
31    /// create a new config instance
32    pub fn new(db_type: ConfigDbType) -> Config {
33        Config {
34            main: Main::new(db_type),
35        }
36    }
37
38    /// create a new Config instance from an environment variable that contains a URL
39    pub fn from_env_var(name: &str) -> Result<Config, Error> {
40        let value = std::env::var(name).map_err(|_| {
41            Error::new(
42                Kind::ConfigError(format!("Couldn't find {name} environment variable")),
43                None,
44            )
45        })?;
46        Config::from_str(&value)
47    }
48
49    pub fn db_type(&self) -> ConfigDbType {
50        self.main.db_type
51    }
52
53    /// create a new Config instance from a config file located on the file system
54    #[cfg(feature = "toml")]
55    pub fn from_file_location<T: AsRef<std::path::Path>>(location: T) -> Result<Config, Error> {
56        let file = std::fs::read_to_string(&location).map_err(|err| {
57            Error::new(
58                Kind::ConfigError(format!("could not open config file, {err}")),
59                None,
60            )
61        })?;
62
63        let config: Config = toml::from_str(&file).map_err(|err| {
64            Error::new(
65                Kind::ConfigError(format!("could not parse config file, {err}")),
66                None,
67            )
68        })?;
69
70        //replace relative path with canonical path in case of Sqlite db
71        #[cfg(feature = "rusqlite")]
72        if config.main.db_type == ConfigDbType::Sqlite {
73            let mut config = config;
74            let mut config_db_path = config.main.db_path.ok_or_else(|| {
75                Error::new(
76                    Kind::ConfigError("field path must be present for Sqlite database type".into()),
77                    None,
78                )
79            })?;
80
81            if config_db_path.is_relative() {
82                let mut config_db_dir = location
83                    .as_ref()
84                    .parent()
85                    //safe to call unwrap in the below cases as the current dir exists and if config was opened there are permissions on the current dir
86                    .unwrap_or(&std::env::current_dir().unwrap())
87                    .to_path_buf();
88
89                config_db_dir = std::fs::canonicalize(config_db_dir).unwrap();
90                config_db_path = config_db_dir.join(&config_db_path)
91            }
92
93            let config_db_path = config_db_path.canonicalize().map_err(|err| {
94                Error::new(
95                    Kind::ConfigError(format!("invalid sqlite db path, {err}")),
96                    None,
97                )
98            })?;
99            config.main.db_path = Some(config_db_path);
100
101            return Ok(config);
102        }
103
104        Ok(config)
105    }
106
107    #[cfg(feature = "tiberius-config")]
108    pub fn set_trust_cert(&mut self) {
109        self.main.trust_cert = true;
110    }
111}
112
#[cfg(any(
    feature = "mysql",
    feature = "postgres",
    feature = "tokio-postgres",
    feature = "mysql_async",
    feature = "tiberius-config"
))]
impl Config {
    /// Configured database host, if any.
    pub fn db_host(&self) -> Option<&str> {
        self.main.db_host.as_ref().map(String::as_str)
    }

    /// Configured database port, if any, exactly as the string it was given as.
    pub fn db_port(&self) -> Option<&str> {
        self.main.db_port.as_ref().map(String::as_str)
    }

    /// Returns this config with the database user replaced by `db_user`.
    pub fn set_db_user(mut self, db_user: &str) -> Config {
        self.main.db_user = Some(db_user.to_owned());
        self
    }

    /// Returns this config with the database password replaced by `db_pass`.
    pub fn set_db_pass(mut self, db_pass: &str) -> Config {
        self.main.db_pass = Some(db_pass.to_owned());
        self
    }

    /// Returns this config with the database host replaced by `db_host`.
    pub fn set_db_host(mut self, db_host: &str) -> Config {
        self.main.db_host = Some(db_host.to_owned());
        self
    }

    /// Returns this config with the port replaced by `db_port` (stored unvalidated).
    pub fn set_db_port(mut self, db_port: &str) -> Config {
        self.main.db_port = Some(db_port.to_owned());
        self
    }

    /// Returns this config with the database name replaced by `db_name`.
    pub fn set_db_name(mut self, db_name: &str) -> Config {
        self.main.db_name = Some(db_name.to_owned());
        self
    }
}
174
#[cfg(feature = "rusqlite")]
impl Config {
    /// Path of the Sqlite database file, if one has been configured.
    pub(crate) fn db_path(&self) -> Option<&std::path::Path> {
        self.main.db_path.as_ref().map(|path| path.as_path())
    }

    /// Returns this config with the Sqlite database path set to `db_path`, stored as given
    /// (no resolution or canonicalization happens here).
    pub fn set_db_path(mut self, db_path: &str) -> Config {
        self.main.db_path = Some(std::path::PathBuf::from(db_path));
        self
    }
}
190
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
impl Config {
    /// Whether TLS is requested for the Postgres connection (`sslmode=require` in a URL).
    pub fn use_tls(&self) -> bool {
        let tls_requested = self.main.use_tls;
        tls_requested
    }

    /// Returns this config with the TLS flag set to `use_tls`.
    pub fn set_use_tls(mut self, use_tls: bool) -> Config {
        self.main.use_tls = use_tls;
        self
    }
}
206
207impl TryFrom<Url> for Config {
208    type Error = Error;
209
210    fn try_from(url: Url) -> Result<Config, Self::Error> {
211        let db_type = match url.scheme() {
212            "mysql" => ConfigDbType::Mysql,
213            "postgres" => ConfigDbType::Postgres,
214            "postgresql" => ConfigDbType::Postgres,
215            "sqlite" => ConfigDbType::Sqlite,
216            "mssql" => ConfigDbType::Mssql,
217            _ => {
218                return Err(Error::new(
219                    Kind::ConfigError("Unsupported database".into()),
220                    None,
221                ))
222            }
223        };
224
225        Ok(Self {
226            main: Main {
227                db_type,
228                #[cfg(feature = "rusqlite")]
229                db_path: Some(
230                    url.as_str()[url.scheme().len()..]
231                        .trim_start_matches(':')
232                        .trim_start_matches("//")
233                        .to_string()
234                        .into(),
235                ),
236                #[cfg(any(
237                    feature = "mysql",
238                    feature = "postgres",
239                    feature = "tokio-postgres",
240                    feature = "mysql_async",
241                    feature = "tiberius-config"
242                ))]
243                db_host: url.host_str().map(|r| r.to_string()),
244                #[cfg(any(
245                    feature = "mysql",
246                    feature = "postgres",
247                    feature = "tokio-postgres",
248                    feature = "mysql_async",
249                    feature = "tiberius-config"
250                ))]
251                db_port: url.port().map(|r| r.to_string()),
252                #[cfg(any(
253                    feature = "mysql",
254                    feature = "postgres",
255                    feature = "tokio-postgres",
256                    feature = "mysql_async",
257                    feature = "tiberius-config"
258                ))]
259                db_user: Some(url.username().to_string()),
260                #[cfg(any(
261                    feature = "mysql",
262                    feature = "postgres",
263                    feature = "tokio-postgres",
264                    feature = "mysql_async",
265                    feature = "tiberius-config"
266                ))]
267                db_pass: url.password().map(|r| r.to_string()),
268                #[cfg(any(
269                    feature = "mysql",
270                    feature = "postgres",
271                    feature = "tokio-postgres",
272                    feature = "mysql_async",
273                    feature = "tiberius-config"
274                ))]
275                db_name: Some(url.path().trim_start_matches('/').to_string()),
276                #[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
277                use_tls: match url
278                    .query_pairs()
279                    .collect::<std::collections::HashMap<Cow<'_, str>, Cow<'_, str>>>()
280                    .get("sslmode")
281                {
282                    Some(Cow::Borrowed("require")) => true,
283                    Some(Cow::Borrowed("disable")) | None => false,
284                    _ => {
285                        return Err(Error::new(
286                            Kind::ConfigError(
287                                "Invalid sslmode value, please use disable/require".into(),
288                            ),
289                            None,
290                        ))
291                    }
292                },
293                #[cfg(feature = "tiberius-config")]
294                trust_cert: url
295                    .query_pairs()
296                    .collect::<std::collections::HashMap<Cow<'_, str>, Cow<'_, str>>>()
297                    .get("trust_cert")
298                    .unwrap_or(&Cow::Borrowed("false"))
299                    .parse::<bool>()
300                    .map_err(|_| {
301                        Error::new(
302                            Kind::ConfigError(
303                                "Invalid trust_cert value, please use true/false".into(),
304                            ),
305                            None,
306                        )
307                    })?,
308            },
309        })
310    }
311}
312
313impl FromStr for Config {
314    type Err = Error;
315
316    /// create a new Config instance from a string that contains a URL
317    fn from_str(url_str: &str) -> Result<Config, Self::Err> {
318        let url = Url::parse(url_str).map_err(|_| {
319            Error::new(
320                Kind::ConfigError(format!("Couldn't parse the string '{url_str}' as a URL")),
321                None,
322            )
323        })?;
324        Config::try_from(url)
325    }
326}
327
/// Backing data of `Config`, read from the `[main]` table when a config file is
/// deserialized. Each connection field is compiled in only for the drivers that use it.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Main {
    /// Which database backend the config targets.
    db_type: ConfigDbType,
    /// Sqlite database file. `Config::from_file_location` requires it for Sqlite and
    /// replaces it with a canonical path.
    #[cfg(feature = "rusqlite")]
    db_path: Option<std::path::PathBuf>,
    /// Database host name.
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async",
        feature = "tiberius-config"
    ))]
    db_host: Option<String>,
    /// Port, kept as a string. `build_db_url` inserts it verbatim; only the tiberius
    /// conversion parses it.
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async",
        feature = "tiberius-config"
    ))]
    db_port: Option<String>,
    /// Database user. When parsed from a URL this is `Some` (possibly empty) and still
    /// percent-encoded.
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async",
        feature = "tiberius-config"
    ))]
    db_user: Option<String>,
    /// Database password. When parsed from a URL it is still percent-encoded.
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async",
        feature = "tiberius-config"
    ))]
    db_pass: Option<String>,
    /// Name of the database to connect to.
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async",
        feature = "tiberius-config"
    ))]
    db_name: Option<String>,
    /// Postgres TLS flag; `serde(default)` makes it `false` when omitted from a config file.
    #[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
    #[cfg_attr(feature = "serde", serde(default))]
    use_tls: bool,
    /// MSSQL `trust_cert` flag; `serde(default)` makes it `false` when omitted.
    #[cfg(feature = "tiberius-config")]
    #[cfg_attr(feature = "serde", serde(default))]
    trust_cert: bool,
}
381
382impl Main {
383    fn new(db_type: ConfigDbType) -> Self {
384        Main {
385            db_type,
386            #[cfg(feature = "rusqlite")]
387            db_path: None,
388            #[cfg(any(
389                feature = "mysql",
390                feature = "postgres",
391                feature = "tokio-postgres",
392                feature = "mysql_async",
393                feature = "tiberius-config"
394            ))]
395            db_host: None,
396            #[cfg(any(
397                feature = "mysql",
398                feature = "postgres",
399                feature = "tokio-postgres",
400                feature = "mysql_async",
401                feature = "tiberius-config"
402            ))]
403            db_port: None,
404            #[cfg(any(
405                feature = "mysql",
406                feature = "postgres",
407                feature = "tokio-postgres",
408                feature = "mysql_async",
409                feature = "tiberius-config"
410            ))]
411            db_user: None,
412            #[cfg(any(
413                feature = "mysql",
414                feature = "postgres",
415                feature = "tokio-postgres",
416                feature = "mysql_async",
417                feature = "tiberius-config"
418            ))]
419            db_pass: None,
420            #[cfg(any(
421                feature = "mysql",
422                feature = "postgres",
423                feature = "tokio-postgres",
424                feature = "mysql_async",
425                feature = "tiberius-config"
426            ))]
427            db_name: None,
428            #[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
429            use_tls: false,
430            #[cfg(feature = "tiberius-config")]
431            trust_cert: false,
432        }
433    }
434}
435
/// Builds a `name://user:pass@host:port/db_name` connection URL from `config`, using
/// `name` as the scheme and omitting every part that is not set.
///
/// User and password are inserted verbatim. Values parsed from a URL by `Config::from_str`
/// are still percent-encoded, so they round-trip unchanged.
/// NOTE(review): values given to `set_db_user` / `set_db_pass` are not escaped here;
/// confirm callers never pass reserved characters such as `@`, `:` or `/`.
#[cfg(any(
    feature = "mysql",
    feature = "postgres",
    feature = "tokio-postgres",
    feature = "mysql_async",
))]
pub(crate) fn build_db_url(name: &str, config: &Config) -> String {
    let main = &config.main;
    let mut url = format!("{name}://");

    if let Some(user) = &main.db_user {
        url.push_str(user);
    }
    if let Some(pass) = &main.db_pass {
        url.push(':');
        url.push_str(pass);
    }
    if let Some(host) = &main.db_host {
        // Any userinfo (`user`, `:pass` or `user:pass`) must be separated from the host by
        // '@'; a password without a user would otherwise be glued onto the host name.
        if main.db_user.is_some() || main.db_pass.is_some() {
            url.push('@');
        }
        url.push_str(host);
    }
    if let Some(port) = &main.db_port {
        url.push(':');
        url.push_str(port);
    }
    if let Some(db_name) = &main.db_name {
        url.push('/');
        url.push_str(db_name);
    }
    url
}
466
#[cfg(feature = "tiberius-config")]
impl TryFrom<&Config> for tiberius::Config {
    type Error = Error;

    /// Converts a refinery `Config` into a `tiberius::Config` that uses SQL Server
    /// authentication. A missing user or password is passed as an empty string.
    ///
    /// # Errors
    ///
    /// Returns a `ConfigError` if the configured port string cannot be parsed.
    fn try_from(config: &Config) -> Result<Self, Self::Error> {
        let main = &config.main;
        let mut tconfig = tiberius::Config::new();

        if let Some(host) = &main.db_host {
            tconfig.host(host);
        }

        if let Some(port) = &main.db_port {
            let parsed_port = port.parse().map_err(|_| {
                Error::new(
                    Kind::ConfigError(format!("Couldn't parse value {port} as mssql port")),
                    None,
                )
            })?;
            tconfig.port(parsed_port);
        }

        if let Some(database) = &main.db_name {
            tconfig.database(database);
        }

        if main.trust_cert {
            tconfig.trust_cert();
        }

        let (user, pass) = (
            main.db_user.as_deref().unwrap_or_default(),
            main.db_pass.as_deref().unwrap_or_default(),
        );
        tconfig.authentication(tiberius::AuthMethod::sql_server(user, pass));

        Ok(tconfig)
    }
}
502
// Unit tests for building a `Config` from TOML files, URL strings and environment
// variables. Each test is gated on the features whose code paths it exercises.
#[cfg(test)]
mod tests {
    use super::{Config, Kind};
    use std::io::Write;
    use std::str::FromStr;

    // `build_db_url` is compiled only for these drivers, so the import uses the same gate.
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async"
    ))]
    use super::build_db_url;

    // A config file that does not exist is reported as a ConfigError from the read step.
    #[test]
    #[cfg(feature = "toml")]
    fn returns_config_error_from_invalid_config_location() {
        let config = Config::from_file_location("invalid_path").unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => assert!(msg.contains("could not open config file")),
            _ => panic!("test failed"),
        }
    }

    // Malformed TOML is reported as a ConfigError from the parse step.
    #[test]
    #[cfg(feature = "toml")]
    fn returns_config_error_from_invalid_toml_file() {
        let config = "[<$%
                     db_type = \"Sqlite\" \n";

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => assert!(msg.contains("could not parse config file")),
            _ => panic!("test failed"),
        }
    }

    // A Sqlite config without a `db_path` is rejected with this exact message.
    #[test]
    #[cfg(all(feature = "toml", feature = "rusqlite"))]
    fn returns_config_error_from_sqlite_with_missing_path() {
        let config = "[main] \n
                     db_type = \"Sqlite\" \n";

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => {
                assert_eq!("field path must be present for Sqlite database type", msg)
            }
            _ => panic!("test failed"),
        }
    }

    // A relative `db_path` (just the db file's name) is resolved against the config file's
    // directory. Both temp files are created in ".", so the resolved path must equal the db
    // file's canonical path.
    #[test]
    #[cfg(all(feature = "toml", feature = "rusqlite"))]
    fn builds_sqlite_path_from_relative_path() {
        let db_file = tempfile::NamedTempFile::new_in(".").unwrap();

        let config = format!(
            "[main] \n
                       db_type = \"Sqlite\" \n
                       db_path = \"{}\"",
            db_file.path().file_name().unwrap().to_str().unwrap()
        );

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap();

        let parent = config_file.path().parent().unwrap();
        assert!(parent.is_dir());
        assert_eq!(
            db_file.path().canonicalize().unwrap(),
            config.main.db_path.unwrap()
        );
    }

    // Connection fields read from TOML are assembled into the expected URL.
    #[test]
    #[cfg(all(
        feature = "toml",
        any(
            feature = "mysql",
            feature = "postgres",
            feature = "tokio-postgres",
            feature = "mysql_async"
        )
    ))]
    fn builds_db_url() {
        let config = "[main] \n
                     db_type = \"Postgres\" \n
                     db_host = \"localhost\" \n
                     db_port = \"5432\" \n
                     db_user = \"root\" \n
                     db_pass = \"1234\" \n
                     db_name = \"refinery\"";

        let config: Config = toml::from_str(config).unwrap();

        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    // A URL read from an environment variable round-trips through `build_db_url`.
    // NOTE(review): `std::env::set_var` while other test threads may read the environment
    // is a data race on some platforms; consider avoiding it in the multi-threaded harness.
    #[test]
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async"
    ))]
    fn builds_db_env_var() {
        std::env::set_var(
            "TEST_DATABASE_URL",
            "postgres://root:1234@localhost:5432/refinery",
        );
        let config = Config::from_env_var("TEST_DATABASE_URL").unwrap();
        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    // Same round-trip as above, parsing the URL string directly.
    #[test]
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async"
    ))]
    fn builds_from_str() {
        let config = Config::from_str("postgres://root:1234@localhost:5432/refinery").unwrap();
        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    // `sslmode` maps onto `use_tls`, agrees with the builder API, and unknown values fail.
    #[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
    #[test]
    fn builds_from_sslmode_str() {
        use crate::config::ConfigDbType;

        let config_disable =
            Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=disable")
                .unwrap();
        assert!(!config_disable.use_tls());

        let config_require =
            Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=require")
                .unwrap();
        assert!(config_require.use_tls());

        // Verify that manually created config matches parsed URL config
        let manual_config_disable = Config::new(ConfigDbType::Postgres)
            .set_db_user("root")
            .set_db_pass("1234")
            .set_db_host("localhost")
            .set_db_port("5432")
            .set_db_name("refinery")
            .set_use_tls(false);
        assert_eq!(config_disable.use_tls(), manual_config_disable.use_tls());

        let manual_config_require = Config::new(ConfigDbType::Postgres)
            .set_db_user("root")
            .set_db_pass("1234")
            .set_db_host("localhost")
            .set_db_port("5432")
            .set_db_name("refinery")
            .set_use_tls(true);
        assert_eq!(config_require.use_tls(), manual_config_require.use_tls());

        let config =
            Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=invalidvalue");
        assert!(config.is_err());
    }

    // An environment variable whose value is not a URL makes `from_env_var` return an error.
    #[test]
    fn builds_db_env_var_failure() {
        std::env::set_var("TEST_DATABASE_URL_INVALID", "this_is_not_a_url");
        let config = Config::from_env_var("TEST_DATABASE_URL_INVALID");
        assert!(config.is_err());
    }
}