1use crate::error::Kind;
2use crate::Error;
3use std::convert::TryFrom;
4use std::path::PathBuf;
5use std::str::FromStr;
6use url::Url;
7
8#[derive(Debug)]
11#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
12pub struct Config {
13 main: Main,
14}
15
/// The database backend a configuration targets.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the type can be used as
/// a key in hash-based collections.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ConfigDbType {
    /// MySQL (URL scheme `mysql`).
    Mysql,
    /// PostgreSQL (URL scheme `postgres` or `postgresql`).
    Postgres,
    /// SQLite (URL scheme `sqlite`); a database path is required.
    Sqlite,
    /// Microsoft SQL Server (URL scheme `mssql`).
    Mssql,
}
24
25impl Config {
26 pub fn new(db_type: ConfigDbType) -> Config {
28 Config {
29 main: Main {
30 db_type,
31 db_path: None,
32 db_host: None,
33 db_port: None,
34 db_user: None,
35 db_pass: None,
36 db_name: None,
37 #[cfg(feature = "tiberius-config")]
38 trust_cert: false,
39 },
40 }
41 }
42
43 pub fn from_env_var(name: &str) -> Result<Config, Error> {
45 let value = std::env::var(name).map_err(|_| {
46 Error::new(
47 Kind::ConfigError(format!("Couldn't find {} environment variable", name)),
48 None,
49 )
50 })?;
51 Config::from_str(&value)
52 }
53
54 #[cfg(feature = "toml")]
56 pub fn from_file_location<T: AsRef<std::path::Path>>(location: T) -> Result<Config, Error> {
57 let file = std::fs::read_to_string(&location).map_err(|err| {
58 Error::new(
59 Kind::ConfigError(format!("could not open config file, {}", err)),
60 None,
61 )
62 })?;
63
64 let mut config: Config = toml::from_str(&file).map_err(|err| {
65 Error::new(
66 Kind::ConfigError(format!("could not parse config file, {}", err)),
67 None,
68 )
69 })?;
70
71 if config.main.db_type == ConfigDbType::Sqlite {
73 let mut config_db_path = config.main.db_path.ok_or_else(|| {
74 Error::new(
75 Kind::ConfigError("field path must be present for Sqlite database type".into()),
76 None,
77 )
78 })?;
79
80 if config_db_path.is_relative() {
81 let mut config_db_dir = location
82 .as_ref()
83 .parent()
84 .unwrap_or(&std::env::current_dir().unwrap())
86 .to_path_buf();
87
88 config_db_dir = std::fs::canonicalize(config_db_dir).unwrap();
89 config_db_path = config_db_dir.join(&config_db_path)
90 }
91
92 let config_db_path = config_db_path.canonicalize().map_err(|err| {
93 Error::new(
94 Kind::ConfigError(format!("invalid sqlite db path, {}", err)),
95 None,
96 )
97 })?;
98
99 config.main.db_path = Some(config_db_path);
100 }
101
102 Ok(config)
103 }
104
105 cfg_if::cfg_if! {
106 if #[cfg(feature = "rusqlite")] {
107 pub(crate) fn db_path(&self) -> Option<&std::path::Path> {
108 self.main.db_path.as_deref()
109 }
110
111 pub fn set_db_path(self, db_path: &str) -> Config {
112 Config {
113 main: Main {
114 db_path: Some(db_path.into()),
115 ..self.main
116 },
117 }
118 }
119 }
120 }
121
122 cfg_if::cfg_if! {
123 if #[cfg(feature = "tiberius-config")] {
124 pub fn set_trust_cert(&mut self) {
125 self.main.trust_cert = true;
126 }
127 }
128 }
129
130 pub fn db_type(&self) -> ConfigDbType {
131 self.main.db_type
132 }
133
134 pub fn db_host(&self) -> Option<&str> {
135 self.main.db_host.as_deref()
136 }
137
138 pub fn db_port(&self) -> Option<&str> {
139 self.main.db_port.as_deref()
140 }
141
142 pub fn set_db_user(self, db_user: &str) -> Config {
143 Config {
144 main: Main {
145 db_user: Some(db_user.into()),
146 ..self.main
147 },
148 }
149 }
150
151 pub fn set_db_pass(self, db_pass: &str) -> Config {
152 Config {
153 main: Main {
154 db_pass: Some(db_pass.into()),
155 ..self.main
156 },
157 }
158 }
159
160 pub fn set_db_host(self, db_host: &str) -> Config {
161 Config {
162 main: Main {
163 db_host: Some(db_host.into()),
164 ..self.main
165 },
166 }
167 }
168
169 pub fn set_db_port(self, db_port: &str) -> Config {
170 Config {
171 main: Main {
172 db_port: Some(db_port.into()),
173 ..self.main
174 },
175 }
176 }
177
178 pub fn set_db_name(self, db_name: &str) -> Config {
179 Config {
180 main: Main {
181 db_name: Some(db_name.into()),
182 ..self.main
183 },
184 }
185 }
186}
187
188impl TryFrom<Url> for Config {
189 type Error = Error;
190
191 fn try_from(url: Url) -> Result<Config, Self::Error> {
192 let db_type = match url.scheme() {
193 "mysql" => ConfigDbType::Mysql,
194 "postgres" => ConfigDbType::Postgres,
195 "postgresql" => ConfigDbType::Postgres,
196 "sqlite" => ConfigDbType::Sqlite,
197 "mssql" => ConfigDbType::Mssql,
198 _ => {
199 return Err(Error::new(
200 Kind::ConfigError("Unsupported database".into()),
201 None,
202 ))
203 }
204 };
205
206 cfg_if::cfg_if! {
207 if #[cfg(feature = "tiberius-config")] {
208 use std::{borrow::Cow, collections::HashMap};
209 let query_params = url
210 .query_pairs()
211 .collect::<HashMap< Cow<'_, str>, Cow<'_, str>>>();
212
213 let trust_cert = query_params.
214 get("trust_cert")
215 .unwrap_or(&Cow::Borrowed("false"))
216 .parse::<bool>()
217 .map_err(|_| {
218 Error::new(
219 Kind::ConfigError("Invalid trust_cert value, please use true/false".into()),
220 None,
221 )
222 })?;
223 }
224 }
225
226 Ok(Self {
227 main: Main {
228 db_type,
229 db_path: Some(
230 url.as_str()[url.scheme().len()..]
231 .trim_start_matches(':')
232 .trim_start_matches("//")
233 .to_string()
234 .into(),
235 ),
236 db_host: url.host_str().map(|r| r.to_string()),
237 db_port: url.port().map(|r| r.to_string()),
238 db_user: Some(url.username().to_string()),
239 db_pass: url.password().map(|r| r.to_string()),
240 db_name: Some(url.path().trim_start_matches('/').to_string()),
241 #[cfg(feature = "tiberius-config")]
242 trust_cert,
243 },
244 })
245 }
246}
247
248impl FromStr for Config {
249 type Err = Error;
250
251 fn from_str(url_str: &str) -> Result<Config, Self::Err> {
253 let url = Url::parse(url_str).map_err(|_| {
254 Error::new(
255 Kind::ConfigError(format!("Couldn't parse the string '{}' as a URL", url_str)),
256 None,
257 )
258 })?;
259 Config::try_from(url)
260 }
261}
262
263#[derive(Debug)]
264#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
265struct Main {
266 db_type: ConfigDbType,
267 db_path: Option<PathBuf>,
268 db_host: Option<String>,
269 db_port: Option<String>,
270 db_user: Option<String>,
271 db_pass: Option<String>,
272 db_name: Option<String>,
273 #[cfg(feature = "tiberius-config")]
274 #[serde(default)]
275 trust_cert: bool,
276}
277
/// Assembles a connection URL of the form
/// `name://[user][:pass][@host][:port][/db_name]` from `config`.
///
/// Values are concatenated verbatim; no percent-encoding is applied here.
/// NOTE(review): credentials set by hand or read from TOML that contain
/// URL-reserved characters (`@`, `/`, `:`) are not escaped — confirm callers
/// never rely on that.
#[cfg(any(
    feature = "mysql",
    feature = "postgres",
    feature = "tokio-postgres",
    feature = "mysql_async"
))]
pub(crate) fn build_db_url(name: &str, config: &Config) -> String {
    let main = &config.main;
    let mut url = format!("{}://", name);

    if let Some(user) = &main.db_user {
        url.push_str(user);
    }
    if let Some(pass) = &main.db_pass {
        url.push(':');
        url.push_str(pass);
    }
    if let Some(host) = &main.db_host {
        // `@` separates the userinfo from the host. It is needed whenever any
        // credential was written, including a password without a user.
        if main.db_user.is_some() || main.db_pass.is_some() {
            url.push('@');
        }
        url.push_str(host);
    }
    if let Some(port) = &main.db_port {
        url.push(':');
        url.push_str(port);
    }
    if let Some(db_name) = &main.db_name {
        url.push('/');
        url.push_str(db_name);
    }
    url
}
308
309cfg_if::cfg_if! {
310 if #[cfg(feature = "tiberius-config")] {
311 use tiberius::{AuthMethod, Config as TConfig};
312
313 impl TryFrom<&Config> for TConfig {
314 type Error=Error;
315
316 fn try_from(config: &Config) -> Result<Self, Self::Error> {
317 let mut tconfig = TConfig::new();
318 if let Some(host) = &config.main.db_host {
319 tconfig.host(host);
320 }
321
322 if let Some(port) = &config.main.db_port {
323 let port = port.parse().map_err(|_| Error::new(
324 Kind::ConfigError(format!("Couldn't parse value {} as mssql port", port)),
325 None,
326 ))?;
327 tconfig.port(port);
328 }
329
330 if let Some(db) = &config.main.db_name {
331 tconfig.database(db);
332 }
333
334 let user = config.main.db_user.as_deref().unwrap_or("");
335 let pass = config.main.db_pass.as_deref().unwrap_or("");
336
337 if config.main.trust_cert {
338 tconfig.trust_cert();
339 }
340 tconfig.authentication(AuthMethod::sql_server(user, pass));
341
342 Ok(tconfig)
343 }
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::{build_db_url, Config, Kind};
    use std::io::Write;
    use std::str::FromStr;

    #[test]
    fn returns_config_error_from_invalid_config_location() {
        let config = Config::from_file_location("invalid_path").unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => assert!(msg.contains("could not open config file")),
            _ => panic!("test failed"),
        }
    }

    #[test]
    fn returns_config_error_from_invalid_toml_file() {
        let config = "[<$%
            db_type = \"Sqlite\" \n";

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => assert!(msg.contains("could not parse config file")),
            _ => panic!("test failed"),
        }
    }

    #[test]
    fn returns_config_error_from_sqlite_with_missing_path() {
        let config = "[main] \n
            db_type = \"Sqlite\" \n";

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap_err();
        match config.kind() {
            Kind::ConfigError(msg) => {
                assert_eq!("field path must be present for Sqlite database type", msg)
            }
            _ => panic!("test failed"),
        }
    }

    #[test]
    fn builds_sqlite_path_from_relative_path() {
        let db_file = tempfile::NamedTempFile::new_in(".").unwrap();

        let config = format!(
            "[main] \n
            db_type = \"Sqlite\" \n
            db_path = \"{}\"",
            db_file.path().file_name().unwrap().to_str().unwrap()
        );

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(config.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap();

        let parent = config_file.path().parent().unwrap();
        assert!(parent.is_dir());
        assert_eq!(
            db_file.path().canonicalize().unwrap(),
            config.main.db_path.unwrap()
        );
    }

    #[test]
    fn builds_db_url() {
        let config = "[main] \n
            db_type = \"Postgres\" \n
            db_host = \"localhost\" \n
            db_port = \"5432\" \n
            db_user = \"root\" \n
            db_pass = \"1234\" \n
            db_name = \"refinery\"";

        let config: Config = toml::from_str(config).unwrap();

        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    // The env-var tests each use their own variable. The environment is
    // process-global and tests run on parallel threads, so a shared name would
    // let one test overwrite the other's value between `set_var` and the read.
    #[test]
    fn builds_db_env_var() {
        std::env::set_var(
            "REFINERY_TEST_BUILDS_DB_ENV_VAR",
            "postgres://root:1234@localhost:5432/refinery",
        );
        let config = Config::from_env_var("REFINERY_TEST_BUILDS_DB_ENV_VAR").unwrap();
        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    #[test]
    fn builds_from_str() {
        let config = Config::from_str("postgres://root:1234@localhost:5432/refinery").unwrap();
        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    #[test]
    fn builds_db_env_var_failure() {
        std::env::set_var("REFINERY_TEST_BUILDS_DB_ENV_VAR_FAILURE", "this_is_not_a_url");
        let config = Config::from_env_var("REFINERY_TEST_BUILDS_DB_ENV_VAR_FAILURE");
        assert!(config.is_err());
    }
}