1use crate::error::Kind;
2use crate::Error;
3#[cfg(any(
4 feature = "postgres",
5 feature = "tokio-postgres",
6 feature = "tiberius-config"
7))]
8use std::borrow::Cow;
9use std::convert::TryFrom;
10use std::str::FromStr;
11use url::Url;
12
13#[derive(Debug)]
16#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
17pub struct Config {
18 main: Main,
19}
20
/// The database backends a [`Config`] can describe.
///
/// The variant names are also the values accepted for `db_type` when a config is
/// deserialized (e.g. `db_type = "Sqlite"`).
///
/// Equality on this fieldless enum is total, so `Eq` and `Hash` are derived in
/// addition to `PartialEq`, which lets the type be used as a map or set key.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ConfigDbType {
    /// MySQL, selected by the `mysql` URL scheme.
    Mysql,
    /// PostgreSQL, selected by the `postgres` or `postgresql` URL scheme.
    Postgres,
    /// SQLite, selected by the `sqlite` URL scheme.
    Sqlite,
    /// Microsoft SQL Server, selected by the `mssql` URL scheme.
    Mssql,
}
29
30impl Config {
31 pub fn new(db_type: ConfigDbType) -> Config {
33 Config {
34 main: Main::new(db_type),
35 }
36 }
37
38 pub fn from_env_var(name: &str) -> Result<Config, Error> {
40 let value = std::env::var(name).map_err(|_| {
41 Error::new(
42 Kind::ConfigError(format!("Couldn't find {name} environment variable")),
43 None,
44 )
45 })?;
46 Config::from_str(&value)
47 }
48
49 pub fn db_type(&self) -> ConfigDbType {
50 self.main.db_type
51 }
52
53 #[cfg(feature = "toml")]
55 pub fn from_file_location<T: AsRef<std::path::Path>>(location: T) -> Result<Config, Error> {
56 let file = std::fs::read_to_string(&location).map_err(|err| {
57 Error::new(
58 Kind::ConfigError(format!("could not open config file, {err}")),
59 None,
60 )
61 })?;
62
63 let config: Config = toml::from_str(&file).map_err(|err| {
64 Error::new(
65 Kind::ConfigError(format!("could not parse config file, {err}")),
66 None,
67 )
68 })?;
69
70 #[cfg(feature = "rusqlite")]
72 if config.main.db_type == ConfigDbType::Sqlite {
73 let mut config = config;
74 let mut config_db_path = config.main.db_path.ok_or_else(|| {
75 Error::new(
76 Kind::ConfigError("field path must be present for Sqlite database type".into()),
77 None,
78 )
79 })?;
80
81 if config_db_path.is_relative() {
82 let mut config_db_dir = location
83 .as_ref()
84 .parent()
85 .unwrap_or(&std::env::current_dir().unwrap())
87 .to_path_buf();
88
89 config_db_dir = std::fs::canonicalize(config_db_dir).unwrap();
90 config_db_path = config_db_dir.join(&config_db_path)
91 }
92
93 let config_db_path = config_db_path.canonicalize().map_err(|err| {
94 Error::new(
95 Kind::ConfigError(format!("invalid sqlite db path, {err}")),
96 None,
97 )
98 })?;
99 config.main.db_path = Some(config_db_path);
100
101 return Ok(config);
102 }
103
104 Ok(config)
105 }
106
107 #[cfg(feature = "tiberius-config")]
108 pub fn set_trust_cert(&mut self) {
109 self.main.trust_cert = true;
110 }
111}
112
#[cfg(any(
    feature = "mysql",
    feature = "postgres",
    feature = "tokio-postgres",
    feature = "mysql_async",
    feature = "tiberius-config"
))]
impl Config {
    /// Returns the configured database host, if one was set.
    pub fn db_host(&self) -> Option<&str> {
        self.main.db_host.as_ref().map(String::as_str)
    }

    /// Returns the configured database port (kept as text), if one was set.
    pub fn db_port(&self) -> Option<&str> {
        self.main.db_port.as_ref().map(String::as_str)
    }

    /// Consumes the config and returns it with the database user replaced.
    pub fn set_db_user(self, db_user: &str) -> Config {
        let mut updated = self;
        updated.main.db_user = Some(String::from(db_user));
        updated
    }

    /// Consumes the config and returns it with the database password replaced.
    pub fn set_db_pass(self, db_pass: &str) -> Config {
        let mut updated = self;
        updated.main.db_pass = Some(String::from(db_pass));
        updated
    }

    /// Consumes the config and returns it with the database host replaced.
    pub fn set_db_host(self, db_host: &str) -> Config {
        let mut updated = self;
        updated.main.db_host = Some(String::from(db_host));
        updated
    }

    /// Consumes the config and returns it with the database port replaced.
    pub fn set_db_port(self, db_port: &str) -> Config {
        let mut updated = self;
        updated.main.db_port = Some(String::from(db_port));
        updated
    }

    /// Consumes the config and returns it with the database name replaced.
    pub fn set_db_name(self, db_name: &str) -> Config {
        let mut updated = self;
        updated.main.db_name = Some(String::from(db_name));
        updated
    }
}
174
#[cfg(feature = "rusqlite")]
impl Config {
    /// Returns the path of the Sqlite database file, if one was set.
    pub(crate) fn db_path(&self) -> Option<&std::path::Path> {
        self.main
            .db_path
            .as_ref()
            .map(std::path::PathBuf::as_path)
    }

    /// Consumes the config and returns it with the Sqlite database path replaced.
    pub fn set_db_path(self, db_path: &str) -> Config {
        let mut updated = self;
        updated.main.db_path = Some(std::path::PathBuf::from(db_path));
        updated
    }
}
190
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
impl Config {
    /// Reports whether TLS was requested for the connection.
    pub fn use_tls(&self) -> bool {
        let Main { use_tls, .. } = &self.main;
        *use_tls
    }

    /// Consumes the config and returns it with the TLS flag replaced.
    pub fn set_use_tls(self, use_tls: bool) -> Config {
        let mut updated = self;
        updated.main.use_tls = use_tls;
        updated
    }
}
206
207impl TryFrom<Url> for Config {
208 type Error = Error;
209
210 fn try_from(url: Url) -> Result<Config, Self::Error> {
211 let db_type = match url.scheme() {
212 "mysql" => ConfigDbType::Mysql,
213 "postgres" => ConfigDbType::Postgres,
214 "postgresql" => ConfigDbType::Postgres,
215 "sqlite" => ConfigDbType::Sqlite,
216 "mssql" => ConfigDbType::Mssql,
217 _ => {
218 return Err(Error::new(
219 Kind::ConfigError("Unsupported database".into()),
220 None,
221 ))
222 }
223 };
224
225 Ok(Self {
226 main: Main {
227 db_type,
228 #[cfg(feature = "rusqlite")]
229 db_path: Some(
230 url.as_str()[url.scheme().len()..]
231 .trim_start_matches(':')
232 .trim_start_matches("//")
233 .to_string()
234 .into(),
235 ),
236 #[cfg(any(
237 feature = "mysql",
238 feature = "postgres",
239 feature = "tokio-postgres",
240 feature = "mysql_async",
241 feature = "tiberius-config"
242 ))]
243 db_host: url.host_str().map(|r| r.to_string()),
244 #[cfg(any(
245 feature = "mysql",
246 feature = "postgres",
247 feature = "tokio-postgres",
248 feature = "mysql_async",
249 feature = "tiberius-config"
250 ))]
251 db_port: url.port().map(|r| r.to_string()),
252 #[cfg(any(
253 feature = "mysql",
254 feature = "postgres",
255 feature = "tokio-postgres",
256 feature = "mysql_async",
257 feature = "tiberius-config"
258 ))]
259 db_user: Some(url.username().to_string()),
260 #[cfg(any(
261 feature = "mysql",
262 feature = "postgres",
263 feature = "tokio-postgres",
264 feature = "mysql_async",
265 feature = "tiberius-config"
266 ))]
267 db_pass: url.password().map(|r| r.to_string()),
268 #[cfg(any(
269 feature = "mysql",
270 feature = "postgres",
271 feature = "tokio-postgres",
272 feature = "mysql_async",
273 feature = "tiberius-config"
274 ))]
275 db_name: Some(url.path().trim_start_matches('/').to_string()),
276 #[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
277 use_tls: match url
278 .query_pairs()
279 .collect::<std::collections::HashMap<Cow<'_, str>, Cow<'_, str>>>()
280 .get("sslmode")
281 {
282 Some(Cow::Borrowed("require")) => true,
283 Some(Cow::Borrowed("disable")) | None => false,
284 _ => {
285 return Err(Error::new(
286 Kind::ConfigError(
287 "Invalid sslmode value, please use disable/require".into(),
288 ),
289 None,
290 ))
291 }
292 },
293 #[cfg(feature = "tiberius-config")]
294 trust_cert: url
295 .query_pairs()
296 .collect::<std::collections::HashMap<Cow<'_, str>, Cow<'_, str>>>()
297 .get("trust_cert")
298 .unwrap_or(&Cow::Borrowed("false"))
299 .parse::<bool>()
300 .map_err(|_| {
301 Error::new(
302 Kind::ConfigError(
303 "Invalid trust_cert value, please use true/false".into(),
304 ),
305 None,
306 )
307 })?,
308 },
309 })
310 }
311}
312
313impl FromStr for Config {
314 type Err = Error;
315
316 fn from_str(url_str: &str) -> Result<Config, Self::Err> {
318 let url = Url::parse(url_str).map_err(|_| {
319 Error::new(
320 Kind::ConfigError(format!("Couldn't parse the string '{url_str}' as a URL")),
321 None,
322 )
323 })?;
324 Config::try_from(url)
325 }
326}
327
328#[derive(Debug)]
329#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
330struct Main {
331 db_type: ConfigDbType,
332 #[cfg(feature = "rusqlite")]
333 db_path: Option<std::path::PathBuf>,
334 #[cfg(any(
335 feature = "mysql",
336 feature = "postgres",
337 feature = "tokio-postgres",
338 feature = "mysql_async",
339 feature = "tiberius-config"
340 ))]
341 db_host: Option<String>,
342 #[cfg(any(
343 feature = "mysql",
344 feature = "postgres",
345 feature = "tokio-postgres",
346 feature = "mysql_async",
347 feature = "tiberius-config"
348 ))]
349 db_port: Option<String>,
350 #[cfg(any(
351 feature = "mysql",
352 feature = "postgres",
353 feature = "tokio-postgres",
354 feature = "mysql_async",
355 feature = "tiberius-config"
356 ))]
357 db_user: Option<String>,
358 #[cfg(any(
359 feature = "mysql",
360 feature = "postgres",
361 feature = "tokio-postgres",
362 feature = "mysql_async",
363 feature = "tiberius-config"
364 ))]
365 db_pass: Option<String>,
366 #[cfg(any(
367 feature = "mysql",
368 feature = "postgres",
369 feature = "tokio-postgres",
370 feature = "mysql_async",
371 feature = "tiberius-config"
372 ))]
373 db_name: Option<String>,
374 #[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
375 #[cfg_attr(feature = "serde", serde(default))]
376 use_tls: bool,
377 #[cfg(feature = "tiberius-config")]
378 #[cfg_attr(feature = "serde", serde(default))]
379 trust_cert: bool,
380}
381
382impl Main {
383 fn new(db_type: ConfigDbType) -> Self {
384 Main {
385 db_type,
386 #[cfg(feature = "rusqlite")]
387 db_path: None,
388 #[cfg(any(
389 feature = "mysql",
390 feature = "postgres",
391 feature = "tokio-postgres",
392 feature = "mysql_async",
393 feature = "tiberius-config"
394 ))]
395 db_host: None,
396 #[cfg(any(
397 feature = "mysql",
398 feature = "postgres",
399 feature = "tokio-postgres",
400 feature = "mysql_async",
401 feature = "tiberius-config"
402 ))]
403 db_port: None,
404 #[cfg(any(
405 feature = "mysql",
406 feature = "postgres",
407 feature = "tokio-postgres",
408 feature = "mysql_async",
409 feature = "tiberius-config"
410 ))]
411 db_user: None,
412 #[cfg(any(
413 feature = "mysql",
414 feature = "postgres",
415 feature = "tokio-postgres",
416 feature = "mysql_async",
417 feature = "tiberius-config"
418 ))]
419 db_pass: None,
420 #[cfg(any(
421 feature = "mysql",
422 feature = "postgres",
423 feature = "tokio-postgres",
424 feature = "mysql_async",
425 feature = "tiberius-config"
426 ))]
427 db_name: None,
428 #[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
429 use_tls: false,
430 #[cfg(feature = "tiberius-config")]
431 trust_cert: false,
432 }
433 }
434}
435
/// Assembles a connection URL of the form `name://user:pass@host:port/db_name`
/// from the fields set on `config`.
///
/// Unset fields are skipped. The `@` separator is emitted whenever a user or a
/// password precedes the host, so a password-only config still produces a
/// well-formed userinfo section instead of running the password into the host.
/// Values are inserted verbatim; no percent-encoding is applied here.
#[cfg(any(
    feature = "mysql",
    feature = "postgres",
    feature = "tokio-postgres",
    feature = "mysql_async",
))]
pub(crate) fn build_db_url(name: &str, config: &Config) -> String {
    let main = &config.main;
    let mut url = format!("{name}://");

    if let Some(user) = &main.db_user {
        url.push_str(user);
    }
    if let Some(pass) = &main.db_pass {
        url.push(':');
        url.push_str(pass);
    }
    if let Some(host) = &main.db_host {
        // Userinfo is present when either credential is set, not only the user.
        if main.db_user.is_some() || main.db_pass.is_some() {
            url.push('@');
        }
        url.push_str(host);
    }
    if let Some(port) = &main.db_port {
        url.push(':');
        url.push_str(port);
    }
    if let Some(db_name) = &main.db_name {
        url.push('/');
        url.push_str(db_name);
    }
    url
}
466
/// Converts a [`Config`] into a `tiberius::Config` for connecting to SQL Server.
///
/// Host, port and database are applied only when set. SQL Server authentication
/// is always configured, with a missing user or password replaced by an empty
/// string. The only failure is a port that cannot be parsed, which is reported
/// as a `Kind::ConfigError`.
#[cfg(feature = "tiberius-config")]
impl TryFrom<&Config> for tiberius::Config {
    type Error = Error;

    fn try_from(config: &Config) -> Result<Self, Self::Error> {
        let settings = &config.main;
        let mut tconfig = tiberius::Config::new();

        if let Some(host) = &settings.db_host {
            tconfig.host(host);
        }

        if let Some(raw_port) = &settings.db_port {
            match raw_port.parse() {
                Ok(port) => {
                    tconfig.port(port);
                }
                Err(_) => {
                    return Err(Error::new(
                        Kind::ConfigError(format!("Couldn't parse value {raw_port} as mssql port")),
                        None,
                    ))
                }
            }
        }

        if let Some(database) = &settings.db_name {
            tconfig.database(database);
        }

        if settings.trust_cert {
            tconfig.trust_cert();
        }

        // Missing credentials are passed through as empty strings.
        let user = settings.db_user.as_deref().unwrap_or("");
        let pass = settings.db_pass.as_deref().unwrap_or("");
        tconfig.authentication(tiberius::AuthMethod::sql_server(user, pass));

        Ok(tconfig)
    }
}
502
#[cfg(test)]
mod tests {
    use super::{Config, Kind};
    use std::io::Write;
    use std::str::FromStr;

    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async"
    ))]
    use super::build_db_url;

    // A config file that cannot be read is reported as a "could not open" error.
    #[test]
    #[cfg(feature = "toml")]
    fn returns_config_error_from_invalid_config_location() {
        let err = Config::from_file_location("invalid_path").unwrap_err();
        if let Kind::ConfigError(msg) = err.kind() {
            assert!(msg.contains("could not open config file"));
        } else {
            panic!("test failed");
        }
    }

    // A readable file with malformed TOML is reported as a "could not parse" error.
    #[test]
    #[cfg(feature = "toml")]
    fn returns_config_error_from_invalid_toml_file() {
        let contents = "[<$%\n db_type = \"Sqlite\" \n";

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(contents.as_bytes()).unwrap();

        let err = Config::from_file_location(config_file.path()).unwrap_err();
        if let Kind::ConfigError(msg) = err.kind() {
            assert!(msg.contains("could not parse config file"));
        } else {
            panic!("test failed");
        }
    }

    // A Sqlite config without a database path is rejected.
    #[test]
    #[cfg(all(feature = "toml", feature = "rusqlite"))]
    fn returns_config_error_from_sqlite_with_missing_path() {
        let contents = "[main] \n\n db_type = \"Sqlite\" \n";

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(contents.as_bytes()).unwrap();

        let err = Config::from_file_location(config_file.path()).unwrap_err();
        if let Kind::ConfigError(msg) = err.kind() {
            assert_eq!("field path must be present for Sqlite database type", msg);
        } else {
            panic!("test failed");
        }
    }

    // A relative Sqlite path is resolved next to the config file and canonicalized.
    #[test]
    #[cfg(all(feature = "toml", feature = "rusqlite"))]
    fn builds_sqlite_path_from_relative_path() {
        let db_file = tempfile::NamedTempFile::new_in(".").unwrap();
        let db_file_name = db_file.path().file_name().unwrap().to_str().unwrap();

        let contents = format!(
            "[main] \n\n db_type = \"Sqlite\" \n\n db_path = \"{}\"",
            db_file_name
        );

        let mut config_file = tempfile::NamedTempFile::new_in(".").unwrap();
        config_file.write_all(contents.as_bytes()).unwrap();
        let config = Config::from_file_location(config_file.path()).unwrap();

        let parent = config_file.path().parent().unwrap();
        assert!(parent.is_dir());

        let expected = db_file.path().canonicalize().unwrap();
        assert_eq!(expected, config.main.db_path.unwrap());
    }

    // All fields of a deserialized config end up in the generated URL.
    #[test]
    #[cfg(all(
        feature = "toml",
        any(
            feature = "mysql",
            feature = "postgres",
            feature = "tokio-postgres",
            feature = "mysql_async"
        )
    ))]
    fn builds_db_url() {
        let contents = concat!(
            "[main] \n\n",
            " db_type = \"Postgres\" \n\n",
            " db_host = \"localhost\" \n\n",
            " db_port = \"5432\" \n\n",
            " db_user = \"root\" \n\n",
            " db_pass = \"1234\" \n\n",
            " db_name = \"refinery\"",
        );

        let config: Config = toml::from_str(contents).unwrap();

        assert_eq!(
            "postgres://root:1234@localhost:5432/refinery",
            build_db_url("postgres", &config)
        );
    }

    // A URL read from the environment round-trips through `build_db_url`.
    #[test]
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async"
    ))]
    fn builds_db_env_var() {
        std::env::set_var(
            "TEST_DATABASE_URL",
            "postgres://root:1234@localhost:5432/refinery",
        );

        let config = Config::from_env_var("TEST_DATABASE_URL").unwrap();
        let rebuilt = build_db_url("postgres", &config);

        assert_eq!("postgres://root:1234@localhost:5432/refinery", rebuilt);
    }

    // A URL parsed with `from_str` round-trips through `build_db_url`.
    #[test]
    #[cfg(any(
        feature = "mysql",
        feature = "postgres",
        feature = "tokio-postgres",
        feature = "mysql_async"
    ))]
    fn builds_from_str() {
        let config = Config::from_str("postgres://root:1234@localhost:5432/refinery").unwrap();
        let rebuilt = build_db_url("postgres", &config);

        assert_eq!("postgres://root:1234@localhost:5432/refinery", rebuilt);
    }

    // `sslmode` maps onto `use_tls`, matches the builder API, and rejects unknown values.
    #[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
    #[test]
    fn builds_from_sslmode_str() {
        use crate::config::ConfigDbType;

        let config_disable =
            Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=disable")
                .unwrap();
        assert!(!config_disable.use_tls());

        let config_require =
            Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=require")
                .unwrap();
        assert!(config_require.use_tls());

        let manual_config_disable = Config::new(ConfigDbType::Postgres)
            .set_db_user("root")
            .set_db_pass("1234")
            .set_db_host("localhost")
            .set_db_port("5432")
            .set_db_name("refinery")
            .set_use_tls(false);
        assert_eq!(config_disable.use_tls(), manual_config_disable.use_tls());

        let manual_config_require = Config::new(ConfigDbType::Postgres)
            .set_db_user("root")
            .set_db_pass("1234")
            .set_db_host("localhost")
            .set_db_port("5432")
            .set_db_name("refinery")
            .set_use_tls(true);
        assert_eq!(config_require.use_tls(), manual_config_require.use_tls());

        let invalid =
            Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=invalidvalue");
        assert!(invalid.is_err());
    }

    // An environment value that is not a URL is rejected.
    #[test]
    fn builds_db_env_var_failure() {
        std::env::set_var("TEST_DATABASE_URL_INVALID", "this_is_not_a_url");
        let result = Config::from_env_var("TEST_DATABASE_URL_INVALID");
        assert!(result.is_err());
    }
}