1#![doc = include_str!("../README.md")]
2
3pub mod config;
4pub mod identifier;
5pub mod parameter;
6pub mod pg_dump;
7
8pub use identifier::{Database, QualifiedTable, Role, User};
9pub use pg_dump::{PgSchemaDump, RestrictKey};
10
11#[cfg(feature = "sqlx")]
12pub mod sqlx;
13
14pub mod url;
15
16use config::{Endpoint, SslRootCert};
17
/// PostgreSQL connection settings.
///
/// A `Config` can be rendered as a `postgres://` connection URL ([`Config::to_url`]) or as a
/// map of `PG*` environment variables ([`Config::pg_env`]).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Config {
    /// Where to connect: either a network host (with optional port, host address and channel
    /// binding) or a socket path.
    pub endpoint: Endpoint,
    /// Session settings: user, database, and the optional password and application name.
    pub session: config::Session,
    /// SSL mode, emitted as the `sslmode` URL parameter / `PGSSLMODE` variable.
    pub ssl_mode: config::SslMode,
    /// Optional SSL root certificate, emitted as `sslrootcert` / `PGSSLROOTCERT` when set.
    pub ssl_root_cert: Option<SslRootCert>,
    /// Settings for the `sqlx` integration (only present with the `sqlx` feature). Not part of
    /// the serialized form of `Config`.
    #[cfg(feature = "sqlx")]
    pub sqlx: crate::sqlx::Settings,
}
35
/// `PGAPPNAME`: set by [`Config::pg_env`] from the session's application name, when present.
pub const PGAPPNAME: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGAPPNAME");
/// `PGCHANNELBINDING`: set by [`Config::pg_env`] from a network endpoint's channel binding,
/// when present.
pub const PGCHANNELBINDING: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGCHANNELBINDING");
/// `PGDATABASE`: set by [`Config::pg_env`] from the session's database.
pub const PGDATABASE: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATABASE");
/// `PGHOST`: set by [`Config::pg_env`] to the network host, or to the socket path for socket
/// endpoints.
pub const PGHOST: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGHOST");
/// `PGHOSTADDR`: set by [`Config::pg_env`] from a network endpoint's host address, when present.
pub const PGHOSTADDR: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGHOSTADDR");
/// `PGPASSWORD`: set by [`Config::pg_env`] from the session's password, when present.
pub const PGPASSWORD: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGPASSWORD");
/// `PGPORT`: set by [`Config::pg_env`] from a network endpoint's port, when present.
pub const PGPORT: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGPORT");
/// `PGSSLMODE`: set by [`Config::pg_env`] from the configured SSL mode.
pub const PGSSLMODE: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLMODE");
/// `PGSSLROOTCERT`: set by [`Config::pg_env`] from the SSL root certificate, when present.
pub const PGSSLROOTCERT: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLROOTCERT");
/// `PGUSER`: set by [`Config::pg_env`] from the session's user.
pub const PGUSER: cmd_proc::EnvVariableName =
    cmd_proc::EnvVariableName::from_static_or_panic("PGUSER");
56
57impl serde::Serialize for Config {
58 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
59 use serde::ser::SerializeStruct;
60 let mut state = serializer.serialize_struct("Config", 8)?;
61
62 if let Some(application_name) = &self.session.application_name {
63 state.serialize_field("application_name", application_name)?;
64 }
65
66 state.serialize_field("database", &self.session.database)?;
67 state.serialize_field("endpoint", &self.endpoint)?;
68
69 if let Some(password) = &self.session.password {
70 state.serialize_field("password", password)?;
71 }
72
73 state.serialize_field("ssl_mode", &self.ssl_mode)?;
74
75 if let Some(ssl_root_cert) = &self.ssl_root_cert {
76 state.serialize_field("ssl_root_cert", ssl_root_cert)?;
77 }
78
79 state.serialize_field("user", &self.session.user)?;
80 state.serialize_field("url", &self.to_url_string())?;
81
82 state.end()
83 }
84}
85
86impl Config {
87 #[must_use]
190 pub fn to_url(&self) -> ::fluent_uri::Uri<String> {
191 use ::fluent_uri::{
192 Uri,
193 build::Builder,
194 component::{Authority, Scheme},
195 pct_enc::{EStr, EString, encoder},
196 };
197
198 use config::Host;
199
200 const POSTGRES: &Scheme = Scheme::new_or_panic("postgres");
201
202 fn append_query_pair(query: &mut EString<encoder::Query>, key: &str, value: &str) {
203 if !query.is_empty() {
204 query.push('&');
205 }
206 query.encode_str::<encoder::Data>(key);
207 query.push('=');
208 query.encode_str::<encoder::Data>(value);
209 }
210
211 let mut query = EString::<encoder::Query>::new();
212
213 match &self.endpoint {
214 Endpoint::Network {
215 host,
216 channel_binding,
217 host_addr,
218 port,
219 } => {
220 let mut userinfo = EString::<encoder::Userinfo>::new();
221 userinfo.encode_str::<encoder::Data>(self.session.user.pg_env_value().as_str());
222 if let Some(password) = &self.session.password {
223 userinfo.push(':');
224 userinfo.encode_str::<encoder::Data>(password.as_str());
225 }
226
227 let mut path = EString::<encoder::Path>::new();
228 path.push('/');
229 path.encode_str::<encoder::Data>(self.session.database.as_str());
230
231 if let Some(addr) = host_addr {
232 append_query_pair(&mut query, "hostaddr", &addr.to_string());
233 }
234 if let Some(channel_binding) = channel_binding {
235 append_query_pair(&mut query, "channel_binding", channel_binding.as_str());
236 }
237 self.append_common_query_params(&mut query, append_query_pair);
238
239 let non_empty_query = if query.is_empty() {
240 None
241 } else {
242 Some(query.as_estr())
243 };
244
245 Uri::builder()
248 .scheme(POSTGRES)
249 .authority_with(|builder| {
250 let builder = builder.userinfo(&userinfo);
251 let builder = match host {
252 Host::IpAddr(addr) => builder.host(*addr),
253 Host::HostName(name) => {
254 let mut encoded = EString::<encoder::RegName>::new();
255 encoded.encode_str::<encoder::Data>(name.as_str());
256 builder.host(encoded.as_estr())
257 }
258 };
259 match port {
260 Some(port) => builder.port(u16::from(port)),
261 None => builder.advance(),
262 }
263 })
264 .path(&path)
265 .optional(Builder::query, non_empty_query)
266 .build()
267 .unwrap()
268 }
269 Endpoint::SocketPath(path) => {
270 append_query_pair(
271 &mut query,
272 "host",
273 path.to_str().expect("socket path contains invalid utf8"),
274 );
275 append_query_pair(&mut query, "dbname", self.session.database.as_str());
276 append_query_pair(
277 &mut query,
278 "user",
279 self.session.user.pg_env_value().as_str(),
280 );
281 if let Some(password) = &self.session.password {
282 append_query_pair(&mut query, "password", password.as_str());
283 }
284 self.append_common_query_params(&mut query, append_query_pair);
285
286 Uri::builder()
289 .scheme(POSTGRES)
290 .authority(Authority::EMPTY)
291 .path(EStr::EMPTY)
292 .query(&query)
293 .build()
294 .unwrap()
295 }
296 }
297 }
298
299 #[must_use]
301 pub fn to_url_string(&self) -> String {
302 self.to_url().into_string()
303 }
304
305 fn append_common_query_params(
306 &self,
307 query: &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
308 append_query_pair: fn(
309 &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
310 &str,
311 &str,
312 ),
313 ) {
314 if let Some(application_name) = &self.session.application_name {
315 append_query_pair(query, "application_name", application_name.as_str());
316 }
317 append_query_pair(query, "sslmode", &self.ssl_mode.pg_env_value());
318 if let Some(ssl_root_cert) = &self.ssl_root_cert {
319 append_query_pair(query, "sslrootcert", &ssl_root_cert.pg_env_value());
320 }
321 }
322
323 pub fn pg_env(
389 &self,
390 ) -> Result<
391 std::collections::BTreeMap<cmd_proc::EnvVariableName, cmd_proc::EnvVariableValue>,
392 cmd_proc::EnvVariableValueError,
393 > {
394 let mut map = std::collections::BTreeMap::new();
395
396 match &self.endpoint {
397 Endpoint::Network {
398 host,
399 channel_binding,
400 host_addr,
401 port,
402 } => {
403 map.insert(
404 PGHOST.clone(),
405 cmd_proc::EnvVariableValue::try_from(host.pg_env_value())?,
406 );
407 if let Some(port) = port {
408 map.insert(
409 PGPORT.clone(),
410 cmd_proc::EnvVariableValue::try_from(port.pg_env_value())?,
411 );
412 }
413 if let Some(channel_binding) = channel_binding {
414 map.insert(
415 PGCHANNELBINDING.clone(),
416 cmd_proc::EnvVariableValue::try_from(channel_binding.pg_env_value())?,
417 );
418 }
419 if let Some(addr) = host_addr {
420 map.insert(
421 PGHOSTADDR.clone(),
422 cmd_proc::EnvVariableValue::try_from(addr.to_string())?,
423 );
424 }
425 }
426 Endpoint::SocketPath(path) => {
427 map.insert(
428 PGHOST.clone(),
429 cmd_proc::EnvVariableValue::try_from(
430 path.to_str()
431 .expect("socket path contains invalid utf8")
432 .to_string(),
433 )?,
434 );
435 }
436 }
437
438 map.insert(
439 PGSSLMODE.clone(),
440 cmd_proc::EnvVariableValue::try_from(self.ssl_mode.pg_env_value())?,
441 );
442 map.insert(
443 PGUSER.clone(),
444 cmd_proc::EnvVariableValue::try_from(self.session.user.pg_env_value())?,
445 );
446 map.insert(
447 PGDATABASE.clone(),
448 cmd_proc::EnvVariableValue::try_from(self.session.database.pg_env_value())?,
449 );
450
451 if let Some(application_name) = &self.session.application_name {
452 map.insert(
453 PGAPPNAME.clone(),
454 cmd_proc::EnvVariableValue::try_from(application_name.pg_env_value())?,
455 );
456 }
457
458 if let Some(password) = &self.session.password {
459 map.insert(
460 PGPASSWORD.clone(),
461 cmd_proc::EnvVariableValue::try_from(password.pg_env_value())?,
462 );
463 }
464
465 if let Some(ssl_root_cert) = &self.ssl_root_cert {
466 map.insert(
467 PGSSLROOTCERT.clone(),
468 cmd_proc::EnvVariableValue::try_from(ssl_root_cert.pg_env_value())?,
469 );
470 }
471
472 Ok(map)
473 }
474
475 #[must_use]
476 pub fn endpoint(self, endpoint: Endpoint) -> Self {
477 Self { endpoint, ..self }
478 }
479
480 pub fn from_str_url(url: &str) -> Result<Self, crate::url::ParseError> {
487 crate::url::parse(url)
488 }
489}
490
#[cfg(test)]
mod test {
    use super::*;
    use config::*;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    const TEST_DATABASE: Database = Database::from_static_or_panic("some-database");
    const TEST_USER: User = User::from_static_or_panic("some-user");

    /// Asserts that `config` serializes to exactly the `expected` JSON value.
    fn assert_config(expected: serde_json::Value, config: &Config) {
        let actual = serde_json::to_value(config).unwrap();
        assert_eq!(expected, actual);
    }

    /// Builds a network endpoint on `host` without channel binding.
    ///
    /// `host_addr`, when given, is parsed into the endpoint's host-address type.
    fn network_endpoint(host: &str, host_addr: Option<&str>, port: Option<Port>) -> Endpoint {
        Endpoint::Network {
            host: Host::from_str(host).unwrap(),
            channel_binding: None,
            host_addr: host_addr.map(|addr| addr.parse().unwrap()),
            port,
        }
    }

    /// Builds a session for `user` on `TEST_DATABASE` with no password or application name.
    fn bare_session(user: User) -> Session {
        Session {
            application_name: None,
            database: TEST_DATABASE,
            password: None,
            user,
        }
    }

    #[test]
    fn test_json() {
        let base = Config {
            endpoint: network_endpoint("some-host", None, Some(Port::new(5432))),
            session: bare_session(TEST_USER),
            ssl_mode: SslMode::VerifyFull,
            ssl_root_cert: None,
            #[cfg(feature = "sqlx")]
            sqlx: Default::default(),
        };

        // Minimal network configuration.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full",
                "user": "some-user",
            }),
            &base,
        );

        // Every optional field populated.
        let fully_populated = Config {
            session: Session {
                application_name: Some(ApplicationName::from_str("some-app").unwrap()),
                password: Some(Password::from_str("some-password").unwrap()),
                ..base.session.clone()
            },
            ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
            ..base.clone()
        };
        assert_config(
            serde_json::json!({
                "application_name": "some-app",
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "password": "some-password",
                "ssl_mode": "verify-full",
                "ssl_root_cert": {
                    "file": "/some.pem"
                },
                "url": "postgres://some-user:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem",
                "user": "some-user"
            }),
            &fully_populated,
        );

        // IPv4 literal as host.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "127.0.0.1",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@127.0.0.1:5432/some-database?sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("127.0.0.1", None, Some(Port::new(5432))),
                ..base.clone()
            },
        );

        // Unix socket endpoint: everything moves into the query string.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "socket_path": "/some/socket",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://?host=%2Fsome%2Fsocket&dbname=some-database&user=some-user&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: Endpoint::SocketPath("/some/socket".into()),
                ..base.clone()
            },
        );

        // System root certificate store.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "ssl_root_cert": "system",
                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full&sslrootcert=system",
                "user": "some-user"
            }),
            &Config {
                ssl_root_cert: Some(SslRootCert::System),
                ..base.clone()
            },
        );

        // Explicit host address alongside the host name.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "192.168.1.100",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host:5432/some-database?hostaddr=192.168.1.100&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint(
                    "some-host",
                    Some("192.168.1.100"),
                    Some(Port::new(5432)),
                ),
                ..base.clone()
            },
        );

        // No port: the authority has no `:port` suffix.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host/some-database?sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("some-host", None, None),
                ..base.clone()
            },
        );

        // Host address without a port.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "10.0.0.1",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host/some-database?hostaddr=10.0.0.1&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("some-host", Some("10.0.0.1"), None),
                ..base.clone()
            },
        );
    }

    #[test]
    fn test_ipv6_url_formation() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        // Each case is: (host, expected URL, failure message).
        let cases = [
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::LOCALHOST)),
                "postgres://postgres@[::1]:5432/some-database?sslmode=disable",
                "IPv6 loopback address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))),
                "postgres://postgres@[fe80::1]:5432/some-database?sslmode=disable",
                "IPv6 link-local address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::new(
                    0x2001, 0x0db8, 0, 0, 0, 0, 0, 1,
                ))),
                "postgres://postgres@[2001:db8::1]:5432/some-database?sslmode=disable",
                "Full IPv6 address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)),
                "postgres://postgres@127.0.0.1:5432/some-database?sslmode=disable",
                "IPv4 address should NOT be bracketed in URL",
            ),
            (
                Host::from_str("localhost").unwrap(),
                "postgres://postgres@localhost:5432/some-database?sslmode=disable",
                "Hostname should NOT be bracketed in URL",
            ),
        ];

        for (host, expected_url, message) in cases {
            let config = Config {
                endpoint: Endpoint::Network {
                    host,
                    channel_binding: None,
                    host_addr: None,
                    port: Some(Port::new(5432)),
                },
                session: bare_session(User::POSTGRES),
                ssl_mode: SslMode::Disable,
                ssl_root_cert: None,
                #[cfg(feature = "sqlx")]
                sqlx: Default::default(),
            };

            assert_eq!(config.to_url_string(), expected_url, "{}", message);
        }
    }
}