refinery_core/
config.rs

1use crate::error::Kind;
2use crate::Error;
3use std::convert::TryFrom;
4use std::path::PathBuf;
5use std::str::FromStr;
6use url::Url;
7
8// refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros
9// Config can either be instanced with [`Config::new`] or retrieved from a config file with [`Config::from_file_location`]
10#[derive(Debug)]
11#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
12pub struct Config {
13    main: Main,
14}
15
/// The database backend a `Config` targets.
// `Eq` is derived alongside `PartialEq`: the enum has no fields, so equality is total.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ConfigDbType {
    /// MySQL (URL scheme `mysql`).
    Mysql,
    /// PostgreSQL (URL scheme `postgres` or `postgresql`).
    Postgres,
    /// Sqlite (URL scheme `sqlite`).
    Sqlite,
    /// Microsoft SQL Server (URL scheme `mssql`).
    Mssql,
}
24
25impl Config {
26    /// create a new config instance
27    pub fn new(db_type: ConfigDbType) -> Config {
28        Config {
29            main: Main {
30                db_type,
31                db_path: None,
32                db_host: None,
33                db_port: None,
34                db_user: None,
35                db_pass: None,
36                db_name: None,
37                #[cfg(feature = "tiberius-config")]
38                trust_cert: false,
39            },
40        }
41    }
42
43    /// create a new Config instance from an environment variable that contains a URL
44    pub fn from_env_var(name: &str) -> Result<Config, Error> {
45        let value = std::env::var(name).map_err(|_| {
46            Error::new(
47                Kind::ConfigError(format!("Couldn't find {} environment variable", name)),
48                None,
49            )
50        })?;
51        Config::from_str(&value)
52    }
53
54    /// create a new Config instance from a config file located on the file system
55    #[cfg(feature = "toml")]
56    pub fn from_file_location<T: AsRef<std::path::Path>>(location: T) -> Result<Config, Error> {
57        let file = std::fs::read_to_string(&location).map_err(|err| {
58            Error::new(
59                Kind::ConfigError(format!("could not open config file, {}", err)),
60                None,
61            )
62        })?;
63
64        let mut config: Config = toml::from_str(&file).map_err(|err| {
65            Error::new(
66                Kind::ConfigError(format!("could not parse config file, {}", err)),
67                None,
68            )
69        })?;
70
71        //replace relative path with canonical path in case of Sqlite db
72        if config.main.db_type == ConfigDbType::Sqlite {
73            let mut config_db_path = config.main.db_path.ok_or_else(|| {
74                Error::new(
75                    Kind::ConfigError("field path must be present for Sqlite database type".into()),
76                    None,
77                )
78            })?;
79
80            if config_db_path.is_relative() {
81                let mut config_db_dir = location
82                    .as_ref()
83                    .parent()
84                    //safe to call unwrap in the below cases as the current dir exists and if config was opened there are permissions on the current dir
85                    .unwrap_or(&std::env::current_dir().unwrap())
86                    .to_path_buf();
87
88                config_db_dir = std::fs::canonicalize(config_db_dir).unwrap();
89                config_db_path = config_db_dir.join(&config_db_path)
90            }
91
92            let config_db_path = config_db_path.canonicalize().map_err(|err| {
93                Error::new(
94                    Kind::ConfigError(format!("invalid sqlite db path, {}", err)),
95                    None,
96                )
97            })?;
98
99            config.main.db_path = Some(config_db_path);
100        }
101
102        Ok(config)
103    }
104
105    cfg_if::cfg_if! {
106        if #[cfg(feature = "rusqlite")] {
107            pub(crate) fn db_path(&self) -> Option<&std::path::Path> {
108                self.main.db_path.as_deref()
109            }
110
111            pub fn set_db_path(self, db_path: &str) -> Config {
112                Config {
113                    main: Main {
114                        db_path: Some(db_path.into()),
115                        ..self.main
116                    },
117                }
118            }
119        }
120    }
121
122    cfg_if::cfg_if! {
123        if #[cfg(feature = "tiberius-config")] {
124            pub fn set_trust_cert(&mut self) {
125                self.main.trust_cert = true;
126            }
127        }
128    }
129
130    pub fn db_type(&self) -> ConfigDbType {
131        self.main.db_type
132    }
133
134    pub fn db_host(&self) -> Option<&str> {
135        self.main.db_host.as_deref()
136    }
137
138    pub fn db_port(&self) -> Option<&str> {
139        self.main.db_port.as_deref()
140    }
141
142    pub fn set_db_user(self, db_user: &str) -> Config {
143        Config {
144            main: Main {
145                db_user: Some(db_user.into()),
146                ..self.main
147            },
148        }
149    }
150
151    pub fn set_db_pass(self, db_pass: &str) -> Config {
152        Config {
153            main: Main {
154                db_pass: Some(db_pass.into()),
155                ..self.main
156            },
157        }
158    }
159
160    pub fn set_db_host(self, db_host: &str) -> Config {
161        Config {
162            main: Main {
163                db_host: Some(db_host.into()),
164                ..self.main
165            },
166        }
167    }
168
169    pub fn set_db_port(self, db_port: &str) -> Config {
170        Config {
171            main: Main {
172                db_port: Some(db_port.into()),
173                ..self.main
174            },
175        }
176    }
177
178    pub fn set_db_name(self, db_name: &str) -> Config {
179        Config {
180            main: Main {
181                db_name: Some(db_name.into()),
182                ..self.main
183            },
184        }
185    }
186}
187
188impl TryFrom<Url> for Config {
189    type Error = Error;
190
191    fn try_from(url: Url) -> Result<Config, Self::Error> {
192        let db_type = match url.scheme() {
193            "mysql" => ConfigDbType::Mysql,
194            "postgres" => ConfigDbType::Postgres,
195            "postgresql" => ConfigDbType::Postgres,
196            "sqlite" => ConfigDbType::Sqlite,
197            "mssql" => ConfigDbType::Mssql,
198            _ => {
199                return Err(Error::new(
200                    Kind::ConfigError("Unsupported database".into()),
201                    None,
202                ))
203            }
204        };
205
206        cfg_if::cfg_if! {
207            if #[cfg(feature = "tiberius-config")] {
208                use std::{borrow::Cow, collections::HashMap};
209                let query_params = url
210                    .query_pairs()
211                    .collect::<HashMap< Cow<'_, str>,  Cow<'_, str>>>();
212
213                let trust_cert = query_params.
214                    get("trust_cert")
215                    .unwrap_or(&Cow::Borrowed("false"))
216                    .parse::<bool>()
217                    .map_err(|_| {
218                        Error::new(
219                            Kind::ConfigError("Invalid trust_cert value, please use true/false".into()),
220                            None,
221                        )
222                    })?;
223            }
224        }
225
226        Ok(Self {
227            main: Main {
228                db_type,
229                db_path: Some(
230                    url.as_str()[url.scheme().len()..]
231                        .trim_start_matches(':')
232                        .trim_start_matches("//")
233                        .to_string()
234                        .into(),
235                ),
236                db_host: url.host_str().map(|r| r.to_string()),
237                db_port: url.port().map(|r| r.to_string()),
238                db_user: Some(url.username().to_string()),
239                db_pass: url.password().map(|r| r.to_string()),
240                db_name: Some(url.path().trim_start_matches('/').to_string()),
241                #[cfg(feature = "tiberius-config")]
242                trust_cert,
243            },
244        })
245    }
246}
247
248impl FromStr for Config {
249    type Err = Error;
250
251    /// create a new Config instance from a string that contains a URL
252    fn from_str(url_str: &str) -> Result<Config, Self::Err> {
253        let url = Url::parse(url_str).map_err(|_| {
254            Error::new(
255                Kind::ConfigError(format!("Couldn't parse the string '{}' as a URL", url_str)),
256                None,
257            )
258        })?;
259        Config::try_from(url)
260    }
261}
262
263#[derive(Debug)]
264#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
265struct Main {
266    db_type: ConfigDbType,
267    db_path: Option<PathBuf>,
268    db_host: Option<String>,
269    db_port: Option<String>,
270    db_user: Option<String>,
271    db_pass: Option<String>,
272    db_name: Option<String>,
273    #[cfg(feature = "tiberius-config")]
274    #[serde(default)]
275    trust_cert: bool,
276}
277
/// Build a connection URL `name://[user][:pass][@host][:port][/db_name]` from `config`.
///
/// Components that are not set are omitted. The `@` separator is written only when a
/// host follows at least one credential.
///
/// NOTE(review): components are inserted verbatim, not percent-encoded. Values that came
/// from a parsed URL are already encoded, but a password containing `@`, `:` or `/` that was
/// set via TOML or `set_db_pass` will yield a malformed URL — confirm where encoding belongs.
#[cfg(any(
    feature = "mysql",
    feature = "postgres",
    feature = "tokio-postgres",
    feature = "mysql_async"
))]
pub(crate) fn build_db_url(name: &str, config: &Config) -> String {
    let main = &config.main;
    let mut url = format!("{}://", name);

    if let Some(user) = &main.db_user {
        url.push_str(user);
    }
    if let Some(pass) = &main.db_pass {
        url.push(':');
        url.push_str(pass);
    }
    if let Some(host) = &main.db_host {
        // The userinfo/host separator is needed whenever *any* credential was written;
        // checking only the user let a password-only config run straight into the host.
        if main.db_user.is_some() || main.db_pass.is_some() {
            url.push('@');
        }
        url.push_str(host);
    }
    if let Some(port) = &main.db_port {
        url.push(':');
        url.push_str(port);
    }
    if let Some(db_name) = &main.db_name {
        url.push('/');
        url.push_str(db_name);
    }
    url
}
308
309cfg_if::cfg_if! {
310    if #[cfg(feature = "tiberius-config")] {
311        use tiberius::{AuthMethod, Config as TConfig};
312
313        impl TryFrom<&Config> for TConfig {
314            type Error=Error;
315
316            fn try_from(config: &Config) -> Result<Self, Self::Error> {
317                let mut tconfig = TConfig::new();
318                if let Some(host) = &config.main.db_host {
319                    tconfig.host(host);
320                }
321
322                if let Some(port) = &config.main.db_port {
323                    let port = port.parse().map_err(|_| Error::new(
324                            Kind::ConfigError(format!("Couldn't parse value {} as mssql port", port)),
325                            None,
326                    ))?;
327                    tconfig.port(port);
328                }
329
330                if let Some(db) = &config.main.db_name {
331                    tconfig.database(db);
332                }
333
334                let user = config.main.db_user.as_deref().unwrap_or("");
335                let pass = config.main.db_pass.as_deref().unwrap_or("");
336
337                if config.main.trust_cert {
338                    tconfig.trust_cert();
339                }
340                tconfig.authentication(AuthMethod::sql_server(user, pass));
341
342                Ok(tconfig)
343            }
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::{build_db_url, Config, Kind};
    use std::io::Write;
    use std::str::FromStr;

    #[test]
    fn returns_config_error_from_invalid_config_location() {
        let config = Config::from_file_location("invalid_path").unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => assert!(msg.contains("could not open config file")),
            _ => panic!("test failed"),
        }
    }

    #[test]
    fn returns_config_error_from_invalid_toml_file() {
        let config = "[<$%
                     db_type = \"Sqlite\" \n";

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => assert!(msg.contains("could not parse config file")),
            _ => panic!("test failed"),
        }
    }

    #[test]
    fn returns_config_error_from_sqlite_with_missing_path() {
        let config = "[main] \n
                     db_type = \"Sqlite\" \n";

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => {
                assert_eq!("field path must be present for Sqlite database type", msg)
            }
            _ => panic!("test failed"),
        }
    }

    #[test]
    fn builds_sqlite_path_from_relative_path() {
        let db_file = tempfile::NamedTempFile::new_in(".").unwrap();

        let config = format!(
            "[main] \n
                       db_type = \"Sqlite\" \n
                       db_path = \"{}\"",
            db_file.path().file_name().unwrap().to_str().unwrap()
        );

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap();

        let parent = config_file.path().parent().unwrap();
        assert!(parent.is_dir());
        assert_eq!(
            db_file.path().canonicalize().unwrap(),
            config.main.db_path.unwrap()
        );
    }

    #[test]
    fn builds_db_url() {
        let config = "[main] \n
                     db_type = \"Postgres\" \n
                     db_host = \"localhost\" \n
                     db_port = \"5432\" \n
                     db_user = \"root\" \n
                     db_pass = \"1234\" \n
                     db_name = \"refinery\"";

        let config: Config = toml::from_str(config).unwrap();

        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    // The env-var tests each use their own variable. Tests run on parallel threads that
    // share one process environment, so sharing `DATABASE_URL` let one test overwrite the
    // value another was about to read, making them flaky.
    #[test]
    fn builds_db_env_var() {
        std::env::set_var(
            "REFINERY_TEST_DATABASE_URL_VALID",
            "postgres://root:1234@localhost:5432/refinery",
        );
        let config = Config::from_env_var("REFINERY_TEST_DATABASE_URL_VALID").unwrap();
        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    #[test]
    fn builds_from_str() {
        let config = Config::from_str("postgres://root:1234@localhost:5432/refinery").unwrap();
        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    #[test]
    fn builds_db_env_var_failure() {
        std::env::set_var("REFINERY_TEST_DATABASE_URL_INVALID", "this_is_not_a_url");
        let config = Config::from_env_var("REFINERY_TEST_DATABASE_URL_INVALID");
        assert!(config.is_err());
    }

    #[test]
    fn returns_config_error_from_missing_env_var() {
        let err = Config::from_env_var("REFINERY_TEST_DATABASE_URL_NEVER_SET").unwrap_err();
        match err.kind() {
            Kind::ConfigError(msg) => assert!(msg.contains("Couldn't find")),
            _ => panic!("test failed"),
        }
    }

    #[test]
    fn returns_config_error_from_unsupported_scheme() {
        let err = Config::from_str("oracle://root:1234@localhost:1521/refinery").unwrap_err();
        match err.kind() {
            Kind::ConfigError(msg) => assert_eq!("Unsupported database", msg),
            _ => panic!("test failed"),
        }
    }
}