1#![doc = include_str!("../README.md")]
2
3pub mod config;
4pub mod identifier;
5pub mod pg_dump;
6
7pub use identifier::{Database, QualifiedTable, Role, User};
8pub use pg_dump::{PgSchemaDump, RestrictKey};
9
10#[cfg(feature = "sqlx")]
11pub mod sqlx;
12
13pub mod url;
14
15use config::{Endpoint, SslRootCert};
16
/// Connection settings for a PostgreSQL server.
///
/// A `Config` can be rendered as a `postgres://` URL (`to_url`,
/// `to_url_string`) or as a map of `PG*` environment variables
/// (`to_pg_env`), and parsed from a URL with `from_str_url`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Config {
    /// Where the server is reached: a network host (with optional port,
    /// `hostaddr` and channel binding) or a socket path.
    pub endpoint: Endpoint,
    /// Per-connection parameters: database, user, and the optional password
    /// and application name.
    pub session: config::Session,
    /// TLS mode, emitted as `sslmode` in URLs and `PGSSLMODE` in the
    /// environment map.
    pub ssl_mode: config::SslMode,
    /// Optional root-certificate source, emitted as `sslrootcert` /
    /// `PGSSLROOTCERT` when set.
    pub ssl_root_cert: Option<SslRootCert>,
    /// Settings for the `sqlx` integration. Not part of the URL, the
    /// environment map, or the serialized form.
    #[cfg(feature = "sqlx")]
    pub sqlx: crate::sqlx::Settings,
}
34
35pub const PGAPPNAME: cmd_proc::EnvVariableName<'static> =
36 cmd_proc::EnvVariableName::from_static_or_panic("PGAPPNAME");
37pub const PGCHANNELBINDING: cmd_proc::EnvVariableName<'static> =
38 cmd_proc::EnvVariableName::from_static_or_panic("PGCHANNELBINDING");
39pub const PGDATABASE: cmd_proc::EnvVariableName<'static> =
40 cmd_proc::EnvVariableName::from_static_or_panic("PGDATABASE");
41pub const PGHOST: cmd_proc::EnvVariableName<'static> =
42 cmd_proc::EnvVariableName::from_static_or_panic("PGHOST");
43pub const PGHOSTADDR: cmd_proc::EnvVariableName<'static> =
44 cmd_proc::EnvVariableName::from_static_or_panic("PGHOSTADDR");
45pub const PGPASSWORD: cmd_proc::EnvVariableName<'static> =
46 cmd_proc::EnvVariableName::from_static_or_panic("PGPASSWORD");
47pub const PGPORT: cmd_proc::EnvVariableName<'static> =
48 cmd_proc::EnvVariableName::from_static_or_panic("PGPORT");
49pub const PGSSLMODE: cmd_proc::EnvVariableName<'static> =
50 cmd_proc::EnvVariableName::from_static_or_panic("PGSSLMODE");
51pub const PGSSLROOTCERT: cmd_proc::EnvVariableName<'static> =
52 cmd_proc::EnvVariableName::from_static_or_panic("PGSSLROOTCERT");
53pub const PGUSER: cmd_proc::EnvVariableName<'static> =
54 cmd_proc::EnvVariableName::from_static_or_panic("PGUSER");
55
56impl serde::Serialize for Config {
57 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
58 use serde::ser::SerializeStruct;
59 let mut state = serializer.serialize_struct("Config", 8)?;
60
61 if let Some(application_name) = &self.session.application_name {
62 state.serialize_field("application_name", application_name)?;
63 }
64
65 state.serialize_field("database", &self.session.database)?;
66 state.serialize_field("endpoint", &self.endpoint)?;
67
68 if let Some(password) = &self.session.password {
69 state.serialize_field("password", password)?;
70 }
71
72 state.serialize_field("ssl_mode", &self.ssl_mode)?;
73
74 if let Some(ssl_root_cert) = &self.ssl_root_cert {
75 state.serialize_field("ssl_root_cert", ssl_root_cert)?;
76 }
77
78 state.serialize_field("user", &self.session.user)?;
79 state.serialize_field("url", &self.to_url_string())?;
80
81 state.end()
82 }
83}
84
85impl Config {
86 #[must_use]
189 pub fn to_url(&self) -> ::fluent_uri::Uri<String> {
190 use ::fluent_uri::{
191 Uri,
192 build::Builder,
193 component::{Authority, Scheme},
194 pct_enc::{EStr, EString, encoder},
195 };
196
197 use config::Host;
198
199 const POSTGRES: &Scheme = Scheme::new_or_panic("postgres");
200
201 fn append_query_pair(query: &mut EString<encoder::Query>, key: &str, value: &str) {
202 if !query.is_empty() {
203 query.push('&');
204 }
205 query.encode_str::<encoder::Data>(key);
206 query.push('=');
207 query.encode_str::<encoder::Data>(value);
208 }
209
210 let mut query = EString::<encoder::Query>::new();
211
212 match &self.endpoint {
213 Endpoint::Network {
214 host,
215 channel_binding,
216 host_addr,
217 port,
218 } => {
219 let mut userinfo = EString::<encoder::Userinfo>::new();
220 userinfo.encode_str::<encoder::Data>(self.session.user.pg_env_value().as_str());
221 if let Some(password) = &self.session.password {
222 userinfo.push(':');
223 userinfo.encode_str::<encoder::Data>(password.as_str());
224 }
225
226 let mut path = EString::<encoder::Path>::new();
227 path.push('/');
228 path.encode_str::<encoder::Data>(self.session.database.as_str());
229
230 if let Some(addr) = host_addr {
231 append_query_pair(&mut query, "hostaddr", &addr.to_string());
232 }
233 if let Some(channel_binding) = channel_binding {
234 append_query_pair(&mut query, "channel_binding", channel_binding.as_str());
235 }
236 self.append_common_query_params(&mut query, append_query_pair);
237
238 let non_empty_query = if query.is_empty() {
239 None
240 } else {
241 Some(query.as_estr())
242 };
243
244 Uri::builder()
247 .scheme(POSTGRES)
248 .authority_with(|builder| {
249 let builder = builder.userinfo(&userinfo);
250 let builder = match host {
251 Host::IpAddr(addr) => builder.host(*addr),
252 Host::HostName(name) => {
253 let mut encoded = EString::<encoder::RegName>::new();
254 encoded.encode_str::<encoder::Data>(name.as_str());
255 builder.host(encoded.as_estr())
256 }
257 };
258 match port {
259 Some(port) => builder.port(u16::from(port)),
260 None => builder.advance(),
261 }
262 })
263 .path(&path)
264 .optional(Builder::query, non_empty_query)
265 .build()
266 .unwrap()
267 }
268 Endpoint::SocketPath(path) => {
269 append_query_pair(
270 &mut query,
271 "host",
272 path.to_str().expect("socket path contains invalid utf8"),
273 );
274 append_query_pair(&mut query, "dbname", self.session.database.as_str());
275 append_query_pair(
276 &mut query,
277 "user",
278 self.session.user.pg_env_value().as_str(),
279 );
280 if let Some(password) = &self.session.password {
281 append_query_pair(&mut query, "password", password.as_str());
282 }
283 self.append_common_query_params(&mut query, append_query_pair);
284
285 Uri::builder()
288 .scheme(POSTGRES)
289 .authority(Authority::EMPTY)
290 .path(EStr::EMPTY)
291 .query(&query)
292 .build()
293 .unwrap()
294 }
295 }
296 }
297
298 #[must_use]
300 pub fn to_url_string(&self) -> String {
301 self.to_url().into_string()
302 }
303
304 fn append_common_query_params(
305 &self,
306 query: &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
307 append_query_pair: fn(
308 &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
309 &str,
310 &str,
311 ),
312 ) {
313 if let Some(application_name) = &self.session.application_name {
314 append_query_pair(query, "application_name", application_name.as_str());
315 }
316 append_query_pair(query, "sslmode", &self.ssl_mode.pg_env_value());
317 if let Some(ssl_root_cert) = &self.ssl_root_cert {
318 append_query_pair(query, "sslrootcert", &ssl_root_cert.pg_env_value());
319 }
320 }
321
322 #[must_use]
388 pub fn to_pg_env(
389 &self,
390 ) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
391 let mut map = std::collections::BTreeMap::new();
392
393 match &self.endpoint {
394 Endpoint::Network {
395 host,
396 channel_binding,
397 host_addr,
398 port,
399 } => {
400 map.insert(PGHOST.clone(), host.pg_env_value());
401 if let Some(port) = port {
402 map.insert(PGPORT.clone(), port.pg_env_value());
403 }
404 if let Some(channel_binding) = channel_binding {
405 map.insert(PGCHANNELBINDING.clone(), channel_binding.pg_env_value());
406 }
407 if let Some(addr) = host_addr {
408 map.insert(PGHOSTADDR.clone(), addr.to_string());
409 }
410 }
411 Endpoint::SocketPath(path) => {
412 map.insert(
413 PGHOST.clone(),
414 path.to_str()
415 .expect("socket path contains invalid utf8")
416 .to_string(),
417 );
418 }
419 }
420
421 map.insert(PGSSLMODE.clone(), self.ssl_mode.pg_env_value());
422 map.insert(PGUSER.clone(), self.session.user.pg_env_value());
423 map.insert(PGDATABASE.clone(), self.session.database.pg_env_value());
424
425 if let Some(application_name) = &self.session.application_name {
426 map.insert(PGAPPNAME.clone(), application_name.pg_env_value());
427 }
428
429 if let Some(password) = &self.session.password {
430 map.insert(PGPASSWORD.clone(), password.pg_env_value());
431 }
432
433 if let Some(ssl_root_cert) = &self.ssl_root_cert {
434 map.insert(PGSSLROOTCERT.clone(), ssl_root_cert.pg_env_value());
435 }
436
437 map
438 }
439
440 #[must_use]
441 pub fn endpoint(self, endpoint: Endpoint) -> Self {
442 Self { endpoint, ..self }
443 }
444
445 pub fn from_str_url(url: &str) -> Result<Self, crate::url::ParseError> {
452 crate::url::parse(url)
453 }
454}
455
#[cfg(test)]
mod test {
    use super::*;
    use config::*;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    const TEST_DATABASE: Database = Database::from_static_or_panic("some-database");
    const TEST_USER: User = User::from_static_or_panic("some-user");

    /// Serializes `config` to JSON and compares it with `expected`.
    fn assert_config(expected: serde_json::Value, config: &Config) {
        let actual = serde_json::to_value(config).unwrap();
        assert_eq!(expected, actual);
    }

    /// Network endpoint for `host` with no channel binding.
    fn network_endpoint(host: &str, host_addr: Option<&str>, port: Option<Port>) -> Endpoint {
        Endpoint::Network {
            host: Host::from_str(host).unwrap(),
            channel_binding: None,
            host_addr: host_addr.map(|addr| addr.parse().unwrap()),
            port,
        }
    }

    /// Baseline for the JSON tests: `some-user` on `some-database` at
    /// `some-host:5432`, `verify-full`, and no optional settings.
    fn json_base_config() -> Config {
        Config {
            endpoint: network_endpoint("some-host", None, Some(Port::new(5432))),
            session: Session {
                application_name: None,
                database: TEST_DATABASE,
                password: None,
                user: TEST_USER,
            },
            ssl_mode: SslMode::VerifyFull,
            ssl_root_cert: None,
            #[cfg(feature = "sqlx")]
            sqlx: Default::default(),
        }
    }

    /// Baseline for the URL tests: user `postgres` on `some-database` at
    /// `host:5432` with TLS disabled.
    fn url_test_config(host: Host) -> Config {
        Config {
            endpoint: Endpoint::Network {
                host,
                channel_binding: None,
                host_addr: None,
                port: Some(Port::new(5432)),
            },
            session: Session {
                application_name: None,
                database: TEST_DATABASE,
                password: None,
                user: User::POSTGRES,
            },
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            #[cfg(feature = "sqlx")]
            sqlx: Default::default(),
        }
    }

    #[test]
    fn test_json() {
        // Only the mandatory fields.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full",
                "user": "some-user",
            }),
            &json_base_config(),
        );

        // Every optional field populated.
        assert_config(
            serde_json::json!({
                "application_name": "some-app",
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "password": "some-password",
                "ssl_mode": "verify-full",
                "ssl_root_cert": {
                    "file": "/some.pem"
                },
                "url": "postgres://some-user:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem",
                "user": "some-user"
            }),
            &Config {
                session: Session {
                    application_name: Some(ApplicationName::from_str("some-app").unwrap()),
                    password: Some(Password::from_str("some-password").unwrap()),
                    ..json_base_config().session
                },
                ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
                ..json_base_config()
            },
        );

        // IPv4 literal as the host.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "127.0.0.1",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@127.0.0.1:5432/some-database?sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("127.0.0.1", None, Some(Port::new(5432))),
                ..json_base_config()
            },
        );

        // Socket path: everything moves into the query string.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "socket_path": "/some/socket",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://?host=%2Fsome%2Fsocket&dbname=some-database&user=some-user&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: Endpoint::SocketPath("/some/socket".into()),
                ..json_base_config()
            },
        );

        // System root certificate store.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "ssl_root_cert": "system",
                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full&sslrootcert=system",
                "user": "some-user"
            }),
            &Config {
                ssl_root_cert: Some(SslRootCert::System),
                ..json_base_config()
            },
        );

        // Explicit host address alongside a port.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "192.168.1.100",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host:5432/some-database?hostaddr=192.168.1.100&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint(
                    "some-host",
                    Some("192.168.1.100"),
                    Some(Port::new(5432)),
                ),
                ..json_base_config()
            },
        );

        // No port: neither the JSON nor the URL mentions one.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host/some-database?sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("some-host", None, None),
                ..json_base_config()
            },
        );

        // Host address without a port.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "10.0.0.1",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host/some-database?hostaddr=10.0.0.1&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("some-host", Some("10.0.0.1"), None),
                ..json_base_config()
            },
        );
    }

    #[test]
    fn test_ipv6_url_formation() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        // (host, expected URL, failure message)
        let cases = [
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::LOCALHOST)),
                "postgres://postgres@[::1]:5432/some-database?sslmode=disable",
                "IPv6 loopback address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))),
                "postgres://postgres@[fe80::1]:5432/some-database?sslmode=disable",
                "IPv6 link-local address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1))),
                "postgres://postgres@[2001:db8::1]:5432/some-database?sslmode=disable",
                "Full IPv6 address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)),
                "postgres://postgres@127.0.0.1:5432/some-database?sslmode=disable",
                "IPv4 address should NOT be bracketed in URL",
            ),
            (
                Host::from_str("localhost").unwrap(),
                "postgres://postgres@localhost:5432/some-database?sslmode=disable",
                "Hostname should NOT be bracketed in URL",
            ),
        ];

        for (host, expected, message) in cases {
            assert_eq!(url_test_config(host).to_url_string(), expected, "{}", message);
        }
    }
}