1#![doc = include_str!("../README.md")]
2
3pub mod identifier;
4pub mod pg_dump;
5
6pub use identifier::{Database, QualifiedTable, Role, User};
7pub use pg_dump::{PgSchemaDump, RestrictKey};
8
9#[cfg(feature = "sqlx")]
10pub mod sqlx;
11
12pub mod url;
13
// Generates, for a `String` newtype `$struct`:
// - `MIN_LENGTH` / `MAX_LENGTH` constants (inclusive byte bounds),
// - an inherent `as_str` accessor plus `AsRef<str>`,
// - a validating `FromStr` that rejects out-of-range byte lengths and NUL bytes.
macro_rules! from_str_impl {
    ($struct: ident, $min: expr, $max: expr) => {
        impl $struct {
            /// Smallest accepted value length, in bytes (inclusive).
            pub const MIN_LENGTH: usize = $min;
            /// Largest accepted value length, in bytes (inclusive).
            pub const MAX_LENGTH: usize = $max;

            /// Borrows the validated value as a string slice.
            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl std::str::FromStr for $struct {
            type Err = String;

            /// Accepts `value` when its byte length lies within
            /// `MIN_LENGTH..=MAX_LENGTH` and it contains no NUL character.
            /// Length violations are reported before NUL characters.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                let actual = value.len();
                let min_length = Self::MIN_LENGTH;
                let max_length = Self::MAX_LENGTH;

                if actual < min_length {
                    return Err(format!(
                        "{} byte min length: {min_length} violated, got: {actual}",
                        stringify!($struct)
                    ));
                }

                if actual > max_length {
                    return Err(format!(
                        "{} byte max length: {max_length} violated, got: {actual}",
                        stringify!($struct)
                    ));
                }

                // In UTF-8 the byte 0x00 only ever encodes U+0000, so a char
                // search is equivalent to scanning the raw bytes.
                if value.contains('\0') {
                    return Err(format!("{} contains NUL byte", stringify!($struct)));
                }

                Ok(Self(value.to_owned()))
            }
        }
    };
}
59
60#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
61pub struct HostName(String);
62
63impl HostName {
64 #[must_use]
65 pub fn as_str(&self) -> &str {
66 &self.0
67 }
68}
69
70impl std::str::FromStr for HostName {
71 type Err = &'static str;
72
73 fn from_str(value: &str) -> Result<Self, Self::Err> {
74 if hostname_validator::is_valid(value) {
75 Ok(Self(value.to_string()))
76 } else {
77 Err("invalid host name")
78 }
79 }
80}
81
82impl<'de> serde::Deserialize<'de> for HostName {
83 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
84 where
85 D: serde::Deserializer<'de>,
86 {
87 let s = String::deserialize(deserializer)?;
88 s.parse().map_err(serde::de::Error::custom)
89 }
90}
91
92#[derive(Clone, Debug, PartialEq, Eq)]
93pub enum Host {
94 HostName(HostName),
95 IpAddr(std::net::IpAddr),
96}
97
98impl serde::Serialize for Host {
99 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
100 serializer.serialize_str(&self.pg_env_value())
101 }
102}
103
104impl Host {
105 pub(crate) fn pg_env_value(&self) -> String {
106 match self {
107 Self::HostName(value) => value.0.clone(),
108 Self::IpAddr(value) => value.to_string(),
109 }
110 }
111}
112
113impl std::str::FromStr for Host {
114 type Err = &'static str;
115
116 fn from_str(value: &str) -> Result<Self, Self::Err> {
117 match std::net::IpAddr::from_str(value) {
118 Ok(addr) => Ok(Self::IpAddr(addr)),
119 Err(_) => match HostName::from_str(value) {
120 Ok(host_name) => Ok(Self::HostName(host_name)),
121 Err(_) => Err("Not a socket address or FQDN"),
122 },
123 }
124 }
125}
126
127impl From<HostName> for Host {
128 fn from(value: HostName) -> Self {
129 Self::HostName(value)
130 }
131}
132
133impl From<std::net::IpAddr> for Host {
134 fn from(value: std::net::IpAddr) -> Self {
135 Self::IpAddr(value)
136 }
137}
138
/// Numeric server address, exported as `PGHOSTADDR` and as the `hostaddr`
/// URI query parameter.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(std::net::IpAddr);

impl HostAddr {
    /// Wraps an already-parsed IP address.
    #[must_use]
    pub const fn new(ip: std::net::IpAddr) -> Self {
        Self(ip)
    }
}

impl From<std::net::IpAddr> for HostAddr {
    fn from(value: std::net::IpAddr) -> Self {
        Self::new(value)
    }
}

impl From<HostAddr> for std::net::IpAddr {
    fn from(value: HostAddr) -> Self {
        let HostAddr(ip) = value;
        ip
    }
}

impl From<&HostAddr> for std::net::IpAddr {
    fn from(value: &HostAddr) -> Self {
        // `IpAddr` is `Copy`, so the address can be copied out of the borrow.
        let &HostAddr(ip) = value;
        ip
    }
}

impl std::fmt::Display for HostAddr {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(ip) = self;
        write!(formatter, "{ip}")
    }
}

impl std::str::FromStr for HostAddr {
    type Err = &'static str;

    /// Parses an IPv4 or IPv6 address literal.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        value
            .parse::<std::net::IpAddr>()
            .map(Self::new)
            .map_err(|_| "invalid IP address")
    }
}
214
215#[derive(Clone, Debug, PartialEq, Eq)]
216pub enum Endpoint {
217 Network {
218 host: Host,
219 channel_binding: Option<ChannelBinding>,
220 host_addr: Option<HostAddr>,
221 port: Option<Port>,
222 },
223 SocketPath(std::path::PathBuf),
224}
225
226impl serde::Serialize for Endpoint {
227 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
228 use serde::ser::SerializeStruct;
229 match self {
230 Self::Network {
231 host,
232 channel_binding,
233 host_addr,
234 port,
235 } => {
236 let mut state = serializer.serialize_struct("Endpoint", 4)?;
237 state.serialize_field("host", host)?;
238 if let Some(channel_binding) = channel_binding {
239 state.serialize_field("channel_binding", channel_binding)?;
240 }
241 if let Some(addr) = host_addr {
242 state.serialize_field("host_addr", &addr.to_string())?;
243 }
244 if let Some(port) = port {
245 state.serialize_field("port", port)?;
246 }
247 state.end()
248 }
249 Self::SocketPath(path) => {
250 let mut state = serializer.serialize_struct("Endpoint", 1)?;
251 state.serialize_field(
252 "socket_path",
253 &path.to_str().expect("socket path contains invalid utf8"),
254 )?;
255 state.end()
256 }
257 }
258 }
259}
260
261#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
262pub struct Port(u16);
263
264impl Port {
265 #[must_use]
266 pub const fn new(port: u16) -> Self {
267 Self(port)
268 }
269
270 fn pg_env_value(self) -> String {
271 self.0.to_string()
272 }
273}
274
275impl std::str::FromStr for Port {
276 type Err = &'static str;
277
278 fn from_str(value: &str) -> Result<Self, Self::Err> {
279 match <u16 as std::str::FromStr>::from_str(value) {
280 Ok(port) => Ok(Port(port)),
281 Err(_) => Err("invalid postgresql port string"),
282 }
283 }
284}
285
286impl From<u16> for Port {
287 fn from(port: u16) -> Self {
288 Self(port)
289 }
290}
291
292impl From<Port> for u16 {
293 fn from(port: Port) -> Self {
294 port.0
295 }
296}
297
298impl From<&Port> for u16 {
299 fn from(port: &Port) -> Self {
300 port.0
301 }
302}
303
304#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
305pub struct ApplicationName(String);
306
307from_str_impl!(ApplicationName, 1, 63);
308
309impl ApplicationName {
310 fn pg_env_value(&self) -> String {
311 self.0.clone()
312 }
313}
314
315impl Database {
316 fn pg_env_value(&self) -> String {
317 self.as_str().to_owned()
318 }
319}
320
321impl Role {
322 fn pg_env_value(&self) -> String {
323 self.as_str().to_owned()
324 }
325}
326
327#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
328pub struct Password(String);
329
330from_str_impl!(Password, 0, 4096);
331
332impl Password {
333 fn pg_env_value(&self) -> String {
334 self.0.clone()
335 }
336}
337
338#[derive(
339 Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
340)]
341#[serde(rename_all = "kebab-case")]
342#[strum(serialize_all = "kebab-case")]
343pub enum SslMode {
344 Allow,
345 Disable,
346 Prefer,
347 Require,
348 VerifyCa,
349 VerifyFull,
350}
351
352impl SslMode {
353 #[must_use]
354 pub fn as_str(&self) -> &'static str {
355 self.into()
356 }
357
358 fn pg_env_value(&self) -> String {
359 self.as_str().to_string()
360 }
361}
362
363#[derive(
364 Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
365)]
366#[serde(rename_all = "kebab-case")]
367#[strum(serialize_all = "kebab-case")]
368pub enum ChannelBinding {
369 Disable,
370 Prefer,
371 Require,
372}
373
374impl ChannelBinding {
375 #[must_use]
376 pub fn as_str(&self) -> &'static str {
377 self.into()
378 }
379
380 fn pg_env_value(&self) -> String {
381 self.as_str().to_string()
382 }
383}
384
385#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
386#[serde(rename_all = "kebab-case")]
387pub enum SslRootCert {
388 File(std::path::PathBuf),
389 System,
390}
391
392impl SslRootCert {
393 pub(crate) fn pg_env_value(&self) -> String {
394 match self {
395 Self::File(path) => path.to_str().unwrap().to_string(),
396 Self::System => "system".to_string(),
397 }
398 }
399}
400
401impl From<std::path::PathBuf> for SslRootCert {
402 fn from(value: std::path::PathBuf) -> Self {
403 Self::File(value)
404 }
405}
406
/// Connection settings for a PostgreSQL client.
///
/// Rendered either as a `postgres://` URI (`Config::to_url`) or as a map of
/// `PG*` environment variables (`Config::to_pg_env`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Config {
    /// Exported as `PGAPPNAME` / `application_name`, when set.
    pub application_name: Option<ApplicationName>,
    /// Exported as `PGDATABASE`; the URI path for network endpoints, or the
    /// `dbname` query parameter for socket endpoints.
    pub database: Database,
    /// Host (network) or local socket the client connects to.
    pub endpoint: Endpoint,
    /// Exported as `PGPASSWORD`, when set; URI userinfo for network endpoints,
    /// or the `password` query parameter for socket endpoints.
    pub password: Option<Password>,
    /// Always exported as `PGSSLMODE` / `sslmode`.
    pub ssl_mode: SslMode,
    /// Exported as `PGSSLROOTCERT` / `sslrootcert`, when set.
    pub ssl_root_cert: Option<SslRootCert>,
    /// Exported as `PGUSER`; URI userinfo for network endpoints, or the `user`
    /// query parameter for socket endpoints.
    pub user: User,
}

// Environment variable names used as keys in the map returned by
// `Config::to_pg_env`.

/// `PGAPPNAME`: application name.
pub const PGAPPNAME: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGAPPNAME");
/// `PGCHANNELBINDING`: channel-binding policy (network endpoints only).
pub const PGCHANNELBINDING: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGCHANNELBINDING");
/// `PGDATABASE`: database name.
pub const PGDATABASE: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATABASE");
/// `PGHOST`: host name / IP, or the socket path.
pub const PGHOST: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGHOST");
/// `PGHOSTADDR`: numeric host address (network endpoints only).
pub const PGHOSTADDR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGHOSTADDR");
/// `PGPASSWORD`: password.
pub const PGPASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGPASSWORD");
/// `PGPORT`: TCP port (network endpoints only).
pub const PGPORT: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGPORT");
/// `PGSSLMODE`: TLS mode.
pub const PGSSLMODE: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLMODE");
/// `PGSSLROOTCERT`: root certificate source.
pub const PGSSLROOTCERT: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLROOTCERT");
/// `PGUSER`: user name.
pub const PGUSER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGUSER");
446
447impl serde::Serialize for Config {
448 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
449 use serde::ser::SerializeStruct;
450 let mut state = serializer.serialize_struct("Config", 8)?;
451
452 if let Some(application_name) = &self.application_name {
453 state.serialize_field("application_name", application_name)?;
454 }
455
456 state.serialize_field("database", &self.database)?;
457 state.serialize_field("endpoint", &self.endpoint)?;
458
459 if let Some(password) = &self.password {
460 state.serialize_field("password", password)?;
461 }
462
463 state.serialize_field("ssl_mode", &self.ssl_mode)?;
464
465 if let Some(ssl_root_cert) = &self.ssl_root_cert {
466 state.serialize_field("ssl_root_cert", ssl_root_cert)?;
467 }
468
469 state.serialize_field("user", &self.user)?;
470 state.serialize_field("url", &self.to_url_string())?;
471
472 state.end()
473 }
474}
475
impl Config {
    /// Renders this configuration as a `postgres://` connection URI.
    ///
    /// For [`Endpoint::Network`]:
    /// - the user (and optional password) go into the userinfo,
    /// - the host and optional port go into the authority,
    /// - the database goes into the path,
    /// - `hostaddr` and `channel_binding` (when set), followed by the common
    ///   parameters, go into the query.
    ///
    /// For [`Endpoint::SocketPath`], the authority and path are empty. Every
    /// setting is carried as a query parameter: `host` (the socket path),
    /// `dbname`, `user`, the optional `password`, then the common parameters.
    ///
    /// # Panics
    ///
    /// Panics if the socket path or an SSL root certificate file path is not
    /// valid UTF-8, or if the URI builder rejects the assembled components.
    #[must_use]
    pub fn to_url(&self) -> ::fluent_uri::Uri<String> {
        use ::fluent_uri::{
            Uri,
            build::Builder,
            component::{Authority, Scheme},
            pct_enc::{EStr, EString, encoder},
        };

        const POSTGRES: &Scheme = Scheme::new_or_panic("postgres");

        // Appends `key=value` to `query`, percent-encoding both sides and
        // separating consecutive pairs with `&`.
        fn append_query_pair(query: &mut EString<encoder::Query>, key: &str, value: &str) {
            if !query.is_empty() {
                query.push('&');
            }
            query.encode_str::<encoder::Data>(key);
            query.push('=');
            query.encode_str::<encoder::Data>(value);
        }

        let mut query = EString::<encoder::Query>::new();

        match &self.endpoint {
            Endpoint::Network {
                host,
                channel_binding,
                host_addr,
                port,
            } => {
                // userinfo: `user` or `user:password`, each part percent-encoded.
                let mut userinfo = EString::<encoder::Userinfo>::new();
                userinfo.encode_str::<encoder::Data>(self.user.pg_env_value().as_str());
                if let Some(password) = &self.password {
                    userinfo.push(':');
                    userinfo.encode_str::<encoder::Data>(password.as_str());
                }

                // path: `/<database>`, with the database name percent-encoded.
                let mut path = EString::<encoder::Path>::new();
                path.push('/');
                path.encode_str::<encoder::Data>(self.database.as_str());

                if let Some(addr) = host_addr {
                    append_query_pair(&mut query, "hostaddr", &addr.to_string());
                }
                if let Some(channel_binding) = channel_binding {
                    append_query_pair(&mut query, "channel_binding", channel_binding.as_str());
                }
                self.append_common_query_params(&mut query, append_query_pair);

                // `sslmode` is always appended above, so this is `Some` in
                // practice; the check keeps a `?` off an empty query regardless.
                let non_empty_query = if query.is_empty() {
                    None
                } else {
                    Some(query.as_estr())
                };

                Uri::builder()
                    .scheme(POSTGRES)
                    .authority_with(|builder| {
                        let builder = builder.userinfo(&userinfo);
                        // IP addresses are handed to the builder as addresses
                        // (the IPv6 tests expect bracketed output); host names
                        // are encoded as a reg-name.
                        let builder = match host {
                            Host::IpAddr(addr) => builder.host(*addr),
                            Host::HostName(name) => {
                                let mut encoded = EString::<encoder::RegName>::new();
                                encoded.encode_str::<encoder::Data>(name.as_str());
                                builder.host(encoded.as_estr())
                            }
                        };
                        match port {
                            Some(port) => builder.port(port.0),
                            None => builder.advance(),
                        }
                    })
                    .path(&path)
                    .optional(Builder::query, non_empty_query)
                    .build()
                    .unwrap()
            }
            Endpoint::SocketPath(path) => {
                append_query_pair(
                    &mut query,
                    "host",
                    path.to_str().expect("socket path contains invalid utf8"),
                );
                append_query_pair(&mut query, "dbname", self.database.as_str());
                append_query_pair(&mut query, "user", self.user.pg_env_value().as_str());
                if let Some(password) = &self.password {
                    append_query_pair(&mut query, "password", password.as_str());
                }
                self.append_common_query_params(&mut query, append_query_pair);

                // Produces `postgres://?<query>`: empty authority, empty path.
                Uri::builder()
                    .scheme(POSTGRES)
                    .authority(Authority::EMPTY)
                    .path(EStr::EMPTY)
                    .query(&query)
                    .build()
                    .unwrap()
            }
        }
    }

    /// String form of [`Config::to_url`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Config::to_url`].
    #[must_use]
    pub fn to_url_string(&self) -> String {
        self.to_url().into_string()
    }

    /// Appends the query parameters shared by every endpoint kind:
    /// `application_name` (when set), `sslmode` (always), and `sslrootcert`
    /// (when set).
    ///
    /// `append_query_pair` is the pair-appending helper defined inside
    /// [`Config::to_url`]. It is passed in because that function is local to
    /// `to_url`.
    ///
    /// # Panics
    ///
    /// Panics if an SSL root certificate file path is not valid UTF-8.
    fn append_common_query_params(
        &self,
        query: &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
        append_query_pair: fn(
            &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
            &str,
            &str,
        ),
    ) {
        if let Some(application_name) = &self.application_name {
            append_query_pair(query, "application_name", application_name.as_str());
        }
        append_query_pair(query, "sslmode", &self.ssl_mode.pg_env_value());
        if let Some(ssl_root_cert) = &self.ssl_root_cert {
            append_query_pair(query, "sslrootcert", &ssl_root_cert.pg_env_value());
        }
    }

    /// Builds the `PG*` environment variables describing this configuration.
    ///
    /// `PGHOST`, `PGSSLMODE`, `PGUSER` and `PGDATABASE` are always present.
    /// `PGPORT`, `PGCHANNELBINDING` and `PGHOSTADDR` appear only for network
    /// endpoints, and only when set. `PGAPPNAME`, `PGPASSWORD` and
    /// `PGSSLROOTCERT` appear only when set. For socket endpoints, `PGHOST`
    /// holds the socket path.
    ///
    /// # Panics
    ///
    /// Panics if the socket path or an SSL root certificate file path is not
    /// valid UTF-8.
    #[must_use]
    pub fn to_pg_env(
        &self,
    ) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
        let mut map = std::collections::BTreeMap::new();

        match &self.endpoint {
            Endpoint::Network {
                host,
                channel_binding,
                host_addr,
                port,
            } => {
                map.insert(PGHOST.clone(), host.pg_env_value());
                if let Some(port) = port {
                    map.insert(PGPORT.clone(), port.pg_env_value());
                }
                if let Some(channel_binding) = channel_binding {
                    map.insert(PGCHANNELBINDING.clone(), channel_binding.pg_env_value());
                }
                if let Some(addr) = host_addr {
                    map.insert(PGHOSTADDR.clone(), addr.to_string());
                }
            }
            Endpoint::SocketPath(path) => {
                map.insert(
                    PGHOST.clone(),
                    path.to_str()
                        .expect("socket path contains invalid utf8")
                        .to_string(),
                );
            }
        }

        map.insert(PGSSLMODE.clone(), self.ssl_mode.pg_env_value());
        map.insert(PGUSER.clone(), self.user.pg_env_value());
        map.insert(PGDATABASE.clone(), self.database.pg_env_value());

        if let Some(application_name) = &self.application_name {
            map.insert(PGAPPNAME.clone(), application_name.pg_env_value());
        }

        if let Some(password) = &self.password {
            map.insert(PGPASSWORD.clone(), password.pg_env_value());
        }

        if let Some(ssl_root_cert) = &self.ssl_root_cert {
            map.insert(PGSSLROOTCERT.clone(), ssl_root_cert.pg_env_value());
        }

        map
    }

    /// Returns this configuration with its endpoint replaced by `endpoint`.
    #[must_use]
    pub fn endpoint(self, endpoint: Endpoint) -> Self {
        Self { endpoint, ..self }
    }

    /// Parses a connection URL into a [`Config`] by delegating to
    /// `crate::url::parse`.
    ///
    /// # Errors
    ///
    /// Returns whatever `crate::url::ParseError` that function reports.
    pub fn from_str_url(url: &str) -> Result<Self, crate::url::ParseError> {
        crate::url::parse(url)
    }
}
820
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    const TEST_DATABASE: Database = Database::from_static_or_panic("some-database");
    const TEST_USER: User = User::from_static_or_panic("some-user");

    /// Asserts that `config` serializes to exactly the JSON value `expected`.
    fn assert_config(expected: serde_json::Value, config: &Config) {
        let actual = serde_json::to_value(config).unwrap();
        assert_eq!(expected, actual);
    }

    /// Returns a string made of `len` copies of `ch`.
    fn repeat(ch: char, len: usize) -> String {
        ch.to_string().repeat(len)
    }

    /// Builds a network endpoint without channel binding, parsing `host` and
    /// (when given) `host_addr`.
    fn network_endpoint(host: &str, host_addr: Option<&str>, port: Option<u16>) -> Endpoint {
        Endpoint::Network {
            host: Host::from_str(host).unwrap(),
            channel_binding: None,
            host_addr: host_addr.map(|addr| addr.parse().unwrap()),
            port: port.map(Port::new),
        }
    }

    /// Baseline configuration shared by the JSON serialization cases.
    fn json_base_config() -> Config {
        Config {
            application_name: None,
            database: TEST_DATABASE,
            endpoint: network_endpoint("some-host", None, Some(5432)),
            password: None,
            ssl_mode: SslMode::VerifyFull,
            ssl_root_cert: None,
            user: TEST_USER,
        }
    }

    /// TLS-disabled configuration for the `postgres` user on port 5432, used
    /// by the URL formatting cases.
    fn url_config(host: Host) -> Config {
        Config {
            application_name: None,
            database: TEST_DATABASE,
            endpoint: Endpoint::Network {
                host,
                channel_binding: None,
                host_addr: None,
                port: Some(Port::new(5432)),
            },
            password: None,
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            user: User::POSTGRES,
        }
    }

    /// Asserts that an application name of `len` bytes parses unchanged.
    fn assert_application_name_accepted(len: usize, message: &str) {
        let value = repeat('a', len);
        let parsed = ApplicationName::from_str(&value).expect(message);
        assert_eq!(parsed, ApplicationName(value));
    }

    /// Asserts that a password of `len` bytes parses unchanged.
    fn assert_password_accepted(len: usize, message: &str) {
        let value = repeat('p', len);
        let parsed = Password::from_str(&value).expect(message);
        assert_eq!(parsed, Password(value));
    }

    #[test]
    fn application_name_lt_min_length() {
        let err = ApplicationName::from_str("").expect_err("expected min length failure");
        assert_eq!(err, "ApplicationName byte min length: 1 violated, got: 0");
    }

    #[test]
    fn application_name_eq_min_length() {
        assert_application_name_accepted(1, "expected valid min length value");
    }

    #[test]
    fn application_name_gt_min_length() {
        assert_application_name_accepted(2, "expected valid value greater than min");
    }

    #[test]
    fn application_name_lt_max_length() {
        assert_application_name_accepted(62, "expected valid value less than max");
    }

    #[test]
    fn application_name_eq_max_length() {
        assert_application_name_accepted(63, "expected valid value equal to max");
    }

    #[test]
    fn application_name_gt_max_length() {
        let too_long = repeat('a', 64);
        let err = ApplicationName::from_str(&too_long).expect_err("expected max length failure");
        assert_eq!(err, "ApplicationName byte max length: 63 violated, got: 64");
    }

    #[test]
    fn application_name_contains_nul() {
        let err = ApplicationName::from_str("\0").expect_err("expected NUL failure");
        assert_eq!(err, "ApplicationName contains NUL byte");
    }

    #[test]
    fn password_eq_min_length() {
        assert_password_accepted(0, "expected valid min length value");
    }

    #[test]
    fn password_gt_min_length() {
        assert_password_accepted(1, "expected valid value greater than min");
    }

    #[test]
    fn password_lt_max_length() {
        assert_password_accepted(4095, "expected valid value less than max");
    }

    #[test]
    fn password_eq_max_length() {
        assert_password_accepted(4096, "expected valid value equal to max");
    }

    #[test]
    fn password_gt_max_length() {
        let too_long = repeat('p', 4097);
        let err = Password::from_str(&too_long).expect_err("expected max length failure");
        assert_eq!(err, "Password byte max length: 4096 violated, got: 4097");
    }

    #[test]
    fn password_contains_nul() {
        let err = Password::from_str("\0").expect_err("expected NUL failure");
        assert_eq!(err, "Password contains NUL byte");
    }

    #[test]
    fn test_json() {
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full",
                "user": "some-user",
            }),
            &json_base_config(),
        );

        assert_config(
            serde_json::json!({
                "application_name": "some-app",
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "password": "some-password",
                "ssl_mode": "verify-full",
                "ssl_root_cert": {
                    "file": "/some.pem"
                },
                "url": "postgres://some-user:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem",
                "user": "some-user"
            }),
            &Config {
                application_name: Some(ApplicationName::from_str("some-app").unwrap()),
                password: Some(Password::from_str("some-password").unwrap()),
                ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
                ..json_base_config()
            },
        );

        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "127.0.0.1",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@127.0.0.1:5432/some-database?sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("127.0.0.1", None, Some(5432)),
                ..json_base_config()
            },
        );

        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "socket_path": "/some/socket",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://?host=%2Fsome%2Fsocket&dbname=some-database&user=some-user&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: Endpoint::SocketPath("/some/socket".into()),
                ..json_base_config()
            },
        );

        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "ssl_root_cert": "system",
                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full&sslrootcert=system",
                "user": "some-user"
            }),
            &Config {
                ssl_root_cert: Some(SslRootCert::System),
                ..json_base_config()
            },
        );

        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "192.168.1.100",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host:5432/some-database?hostaddr=192.168.1.100&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("some-host", Some("192.168.1.100"), Some(5432)),
                ..json_base_config()
            },
        );

        // Without a port the authority carries no `:port` suffix.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host/some-database?sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("some-host", None, None),
                ..json_base_config()
            },
        );

        // `hostaddr` is still emitted when no port is configured.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "10.0.0.1",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-user@some-host/some-database?hostaddr=10.0.0.1&sslmode=verify-full",
                "user": "some-user"
            }),
            &Config {
                endpoint: network_endpoint("some-host", Some("10.0.0.1"), None),
                ..json_base_config()
            },
        );
    }

    #[test]
    fn test_ipv6_url_formation() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        let cases = [
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::LOCALHOST)),
                "postgres://postgres@[::1]:5432/some-database?sslmode=disable",
                "IPv6 loopback address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))),
                "postgres://postgres@[fe80::1]:5432/some-database?sslmode=disable",
                "IPv6 link-local address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1))),
                "postgres://postgres@[2001:db8::1]:5432/some-database?sslmode=disable",
                "Full IPv6 address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)),
                "postgres://postgres@127.0.0.1:5432/some-database?sslmode=disable",
                "IPv4 address should NOT be bracketed in URL",
            ),
            (
                Host::from_str("localhost").unwrap(),
                "postgres://postgres@localhost:5432/some-database?sslmode=disable",
                "Hostname should NOT be bracketed in URL",
            ),
        ];

        for (host, expected, message) in cases {
            assert_eq!(url_config(host).to_url_string(), expected, "{}", message);
        }
    }
}