1#![doc = include_str!("../README.md")]
2
3pub mod identifier;
4
5pub use identifier::{Database, Role, User};
6
7#[cfg(feature = "sqlx")]
8pub mod sqlx;
9
10pub mod url;
11
// Generates the boilerplate shared by `String`-backed newtypes.
//
// `$struct` must be a tuple struct whose only field is a `String`. The
// expansion provides:
//
// * `MIN_LENGTH` / `MAX_LENGTH` associated constants taken from `$min` / `$max`,
// * a `FromStr` impl that accepts values whose byte length lies within
//   `MIN_LENGTH..=MAX_LENGTH` and that contain no NUL byte,
// * `AsRef<str>` and an inherent `as_str` accessor.
macro_rules! from_str_impl {
    ($struct: ident, $min: expr, $max: expr) => {
        impl ::std::str::FromStr for $struct {
            type Err = String;

            /// Validates `value` and wraps it.
            ///
            /// # Errors
            ///
            /// Returns a message naming the type when the byte length is below
            /// `MIN_LENGTH`, above `MAX_LENGTH`, or when `value` contains a NUL
            /// byte. The checks run in that order and the first failure wins.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                // Bound to locals so the format strings can capture them inline.
                let min_length = Self::MIN_LENGTH;
                let max_length = Self::MAX_LENGTH;
                // `str::len` counts bytes, not characters.
                let actual = value.len();

                if actual < min_length {
                    return Err(format!(
                        "{} byte min length: {min_length} violated, got: {actual}",
                        stringify!($struct)
                    ));
                }

                if actual > max_length {
                    return Err(format!(
                        "{} byte max length: {max_length} violated, got: {actual}",
                        stringify!($struct)
                    ));
                }

                if value.as_bytes().contains(&0) {
                    return Err(format!("{} contains NUL byte", stringify!($struct)));
                }

                Ok(Self(value.to_owned()))
            }
        }

        impl ::std::convert::AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                &self.0
            }
        }

        impl $struct {
            /// Smallest accepted length, in bytes.
            pub const MIN_LENGTH: usize = $min;
            /// Largest accepted length, in bytes.
            pub const MAX_LENGTH: usize = $max;

            /// Borrows the wrapped value as a string slice.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }
    };
}
57
58#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
59pub struct HostName(String);
60
61impl HostName {
62 #[must_use]
63 pub fn as_str(&self) -> &str {
64 &self.0
65 }
66}
67
68impl std::str::FromStr for HostName {
69 type Err = &'static str;
70
71 fn from_str(value: &str) -> Result<Self, Self::Err> {
72 if hostname_validator::is_valid(value) {
73 Ok(Self(value.to_string()))
74 } else {
75 Err("invalid host name")
76 }
77 }
78}
79
80impl<'de> serde::Deserialize<'de> for HostName {
81 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82 where
83 D: serde::Deserializer<'de>,
84 {
85 let s = String::deserialize(deserializer)?;
86 s.parse().map_err(serde::de::Error::custom)
87 }
88}
89
90#[derive(Clone, Debug, PartialEq, Eq)]
91pub enum Host {
92 HostName(HostName),
93 IpAddr(std::net::IpAddr),
94}
95
96impl serde::Serialize for Host {
97 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
98 serializer.serialize_str(&self.pg_env_value())
99 }
100}
101
102impl Host {
103 pub(crate) fn pg_env_value(&self) -> String {
104 match self {
105 Self::HostName(value) => value.0.clone(),
106 Self::IpAddr(value) => value.to_string(),
107 }
108 }
109}
110
111impl std::str::FromStr for Host {
112 type Err = &'static str;
113
114 fn from_str(value: &str) -> Result<Self, Self::Err> {
115 match std::net::IpAddr::from_str(value) {
116 Ok(addr) => Ok(Self::IpAddr(addr)),
117 Err(_) => match HostName::from_str(value) {
118 Ok(host_name) => Ok(Self::HostName(host_name)),
119 Err(_) => Err("Not a socket address or FQDN"),
120 },
121 }
122 }
123}
124
125impl From<HostName> for Host {
126 fn from(value: HostName) -> Self {
127 Self::HostName(value)
128 }
129}
130
131impl From<std::net::IpAddr> for Host {
132 fn from(value: std::net::IpAddr) -> Self {
133 Self::IpAddr(value)
134 }
135}
136
/// An IP address supplied alongside a host (emitted as `hostaddr` /
/// `PGHOSTADDR` elsewhere in this file).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(std::net::IpAddr);

impl HostAddr {
    /// Wraps an already-parsed IP address.
    #[must_use]
    pub const fn new(ip: std::net::IpAddr) -> Self {
        HostAddr(ip)
    }
}

impl From<std::net::IpAddr> for HostAddr {
    fn from(ip: std::net::IpAddr) -> Self {
        HostAddr(ip)
    }
}

impl From<HostAddr> for std::net::IpAddr {
    fn from(host_addr: HostAddr) -> Self {
        let HostAddr(ip) = host_addr;
        ip
    }
}

impl From<&HostAddr> for std::net::IpAddr {
    fn from(host_addr: &HostAddr) -> Self {
        // `IpAddr` is `Copy`, so the address is copied out of the borrow.
        let HostAddr(ip) = host_addr;
        *ip
    }
}

impl std::fmt::Display for HostAddr {
    /// Writes the wrapped address using its own `Display` form.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_fmt(format_args!("{}", self.0))
    }
}

impl std::str::FromStr for HostAddr {
    type Err = &'static str;

    /// Parses an IPv4 or IPv6 address.
    ///
    /// # Errors
    ///
    /// Returns `"invalid IP address"` when `value` is not a valid address.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        value
            .parse::<std::net::IpAddr>()
            .map(HostAddr)
            .map_err(|_| "invalid IP address")
    }
}
212
213#[derive(Clone, Debug, PartialEq, Eq)]
214pub enum Endpoint {
215 Network {
216 host: Host,
217 channel_binding: Option<ChannelBinding>,
218 host_addr: Option<HostAddr>,
219 port: Option<Port>,
220 },
221 SocketPath(std::path::PathBuf),
222}
223
224impl serde::Serialize for Endpoint {
225 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
226 use serde::ser::SerializeStruct;
227 match self {
228 Self::Network {
229 host,
230 channel_binding,
231 host_addr,
232 port,
233 } => {
234 let mut state = serializer.serialize_struct("Endpoint", 4)?;
235 state.serialize_field("host", host)?;
236 if let Some(channel_binding) = channel_binding {
237 state.serialize_field("channel_binding", channel_binding)?;
238 }
239 if let Some(addr) = host_addr {
240 state.serialize_field("host_addr", &addr.to_string())?;
241 }
242 if let Some(port) = port {
243 state.serialize_field("port", port)?;
244 }
245 state.end()
246 }
247 Self::SocketPath(path) => {
248 let mut state = serializer.serialize_struct("Endpoint", 1)?;
249 state.serialize_field(
250 "socket_path",
251 &path.to_str().expect("socket path contains invalid utf8"),
252 )?;
253 state.end()
254 }
255 }
256 }
257}
258
259#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
260pub struct Port(u16);
261
262impl Port {
263 #[must_use]
264 pub const fn new(port: u16) -> Self {
265 Self(port)
266 }
267
268 fn pg_env_value(self) -> String {
269 self.0.to_string()
270 }
271}
272
273impl std::str::FromStr for Port {
274 type Err = &'static str;
275
276 fn from_str(value: &str) -> Result<Self, Self::Err> {
277 match <u16 as std::str::FromStr>::from_str(value) {
278 Ok(port) => Ok(Port(port)),
279 Err(_) => Err("invalid postgresql port string"),
280 }
281 }
282}
283
284impl From<u16> for Port {
285 fn from(port: u16) -> Self {
286 Self(port)
287 }
288}
289
290impl From<Port> for u16 {
291 fn from(port: Port) -> Self {
292 port.0
293 }
294}
295
296impl From<&Port> for u16 {
297 fn from(port: &Port) -> Self {
298 port.0
299 }
300}
301
302#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
303pub struct ApplicationName(String);
304
305from_str_impl!(ApplicationName, 1, 63);
306
307impl ApplicationName {
308 fn pg_env_value(&self) -> String {
309 self.0.clone()
310 }
311}
312
313impl Database {
314 fn pg_env_value(&self) -> String {
315 self.as_str().to_owned()
316 }
317}
318
319impl Role {
320 fn pg_env_value(&self) -> String {
321 self.as_str().to_owned()
322 }
323}
324
325#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
326pub struct Password(String);
327
328from_str_impl!(Password, 0, 4096);
329
330impl Password {
331 fn pg_env_value(&self) -> String {
332 self.0.clone()
333 }
334}
335
336#[derive(
337 Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
338)]
339#[serde(rename_all = "kebab-case")]
340#[strum(serialize_all = "kebab-case")]
341pub enum SslMode {
342 Allow,
343 Disable,
344 Prefer,
345 Require,
346 VerifyCa,
347 VerifyFull,
348}
349
350impl SslMode {
351 #[must_use]
352 pub fn as_str(&self) -> &'static str {
353 self.into()
354 }
355
356 fn pg_env_value(&self) -> String {
357 self.as_str().to_string()
358 }
359}
360
361#[derive(
362 Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
363)]
364#[serde(rename_all = "kebab-case")]
365#[strum(serialize_all = "kebab-case")]
366pub enum ChannelBinding {
367 Disable,
368 Prefer,
369 Require,
370}
371
372impl ChannelBinding {
373 #[must_use]
374 pub fn as_str(&self) -> &'static str {
375 self.into()
376 }
377
378 fn pg_env_value(&self) -> String {
379 self.as_str().to_string()
380 }
381}
382
383#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
384#[serde(rename_all = "kebab-case")]
385pub enum SslRootCert {
386 File(std::path::PathBuf),
387 System,
388}
389
390impl SslRootCert {
391 pub(crate) fn pg_env_value(&self) -> String {
392 match self {
393 Self::File(path) => path.to_str().unwrap().to_string(),
394 Self::System => "system".to_string(),
395 }
396 }
397}
398
399impl From<std::path::PathBuf> for SslRootCert {
400 fn from(value: std::path::PathBuf) -> Self {
401 Self::File(value)
402 }
403}
404
/// Connection settings that can be rendered as a `postgres://` URL
/// (`to_url`) or as a map of `PG*` environment variables (`to_pg_env`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Config {
    /// Optional application name (`application_name` / `PGAPPNAME`).
    pub application_name: Option<ApplicationName>,
    /// Database name (`PGDATABASE`).
    pub database: Database,
    /// Network host or socket path to connect to.
    pub endpoint: Endpoint,
    /// Optional password (`PGPASSWORD`).
    pub password: Option<Password>,
    /// SSL mode; always emitted (`sslmode` / `PGSSLMODE`).
    pub ssl_mode: SslMode,
    /// Optional root certificate source (`sslrootcert` / `PGSSLROOTCERT`).
    pub ssl_root_cert: Option<SslRootCert>,
    /// User to connect as (`PGUSER`).
    pub user: User,
}
423
/// Name of the `PGAPPNAME` variable; `to_pg_env` stores the application name under it.
pub const PGAPPNAME: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGAPPNAME");
/// Name of the `PGCHANNELBINDING` variable; `to_pg_env` stores the channel-binding setting under it.
pub const PGCHANNELBINDING: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGCHANNELBINDING");
/// Name of the `PGDATABASE` variable; `to_pg_env` stores the database name under it.
pub const PGDATABASE: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATABASE");
/// Name of the `PGHOST` variable; `to_pg_env` stores the host or socket path under it.
pub const PGHOST: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGHOST");
/// Name of the `PGHOSTADDR` variable; `to_pg_env` stores the host address under it.
pub const PGHOSTADDR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGHOSTADDR");
/// Name of the `PGPASSWORD` variable; `to_pg_env` stores the password under it.
pub const PGPASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGPASSWORD");
/// Name of the `PGPORT` variable; `to_pg_env` stores the port under it.
pub const PGPORT: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGPORT");
/// Name of the `PGSSLMODE` variable; `to_pg_env` stores the SSL mode under it.
pub const PGSSLMODE: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLMODE");
/// Name of the `PGSSLROOTCERT` variable; `to_pg_env` stores the root certificate source under it.
pub const PGSSLROOTCERT: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLROOTCERT");
/// Name of the `PGUSER` variable; `to_pg_env` stores the user under it.
pub const PGUSER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGUSER");
444
445impl serde::Serialize for Config {
446 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
447 use serde::ser::SerializeStruct;
448 let mut state = serializer.serialize_struct("Config", 8)?;
449
450 if let Some(application_name) = &self.application_name {
451 state.serialize_field("application_name", application_name)?;
452 }
453
454 state.serialize_field("database", &self.database)?;
455 state.serialize_field("endpoint", &self.endpoint)?;
456
457 if let Some(password) = &self.password {
458 state.serialize_field("password", password)?;
459 }
460
461 state.serialize_field("ssl_mode", &self.ssl_mode)?;
462
463 if let Some(ssl_root_cert) = &self.ssl_root_cert {
464 state.serialize_field("ssl_root_cert", ssl_root_cert)?;
465 }
466
467 state.serialize_field("user", &self.user)?;
468 state.serialize_field("url", &self.to_url_string())?;
469
470 state.end()
471 }
472}
473
474impl Config {
475 #[must_use]
565 pub fn to_url(&self) -> ::fluent_uri::Uri<String> {
566 use ::fluent_uri::{
567 Uri,
568 build::Builder,
569 component::{Authority, Scheme},
570 pct_enc::{EStr, EString, encoder},
571 };
572
573 const POSTGRES: &Scheme = Scheme::new_or_panic("postgres");
574
575 fn append_query_pair(query: &mut EString<encoder::Query>, key: &str, value: &str) {
576 if !query.is_empty() {
577 query.push('&');
578 }
579 query.encode_str::<encoder::Data>(key);
580 query.push('=');
581 query.encode_str::<encoder::Data>(value);
582 }
583
584 let mut query = EString::<encoder::Query>::new();
585
586 match &self.endpoint {
587 Endpoint::Network {
588 host,
589 channel_binding,
590 host_addr,
591 port,
592 } => {
593 let mut userinfo = EString::<encoder::Userinfo>::new();
594 userinfo.encode_str::<encoder::Data>(self.user.pg_env_value().as_str());
595 if let Some(password) = &self.password {
596 userinfo.push(':');
597 userinfo.encode_str::<encoder::Data>(password.as_str());
598 }
599
600 let mut path = EString::<encoder::Path>::new();
601 path.push('/');
602 path.encode_str::<encoder::Data>(self.database.as_str());
603
604 if let Some(addr) = host_addr {
605 append_query_pair(&mut query, "hostaddr", &addr.to_string());
606 }
607 if let Some(channel_binding) = channel_binding {
608 append_query_pair(&mut query, "channel_binding", channel_binding.as_str());
609 }
610 self.append_common_query_params(&mut query, append_query_pair);
611
612 let non_empty_query = if query.is_empty() {
613 None
614 } else {
615 Some(query.as_estr())
616 };
617
618 Uri::builder()
621 .scheme(POSTGRES)
622 .authority_with(|builder| {
623 let builder = builder.userinfo(&userinfo);
624 let builder = match host {
625 Host::IpAddr(addr) => builder.host(*addr),
626 Host::HostName(name) => {
627 let mut encoded = EString::<encoder::RegName>::new();
628 encoded.encode_str::<encoder::Data>(name.as_str());
629 builder.host(encoded.as_estr())
630 }
631 };
632 match port {
633 Some(port) => builder.port(port.0),
634 None => builder.advance(),
635 }
636 })
637 .path(&path)
638 .optional(Builder::query, non_empty_query)
639 .build()
640 .unwrap()
641 }
642 Endpoint::SocketPath(path) => {
643 append_query_pair(
644 &mut query,
645 "host",
646 path.to_str().expect("socket path contains invalid utf8"),
647 );
648 append_query_pair(&mut query, "dbname", self.database.as_str());
649 append_query_pair(&mut query, "user", self.user.pg_env_value().as_str());
650 if let Some(password) = &self.password {
651 append_query_pair(&mut query, "password", password.as_str());
652 }
653 self.append_common_query_params(&mut query, append_query_pair);
654
655 Uri::builder()
658 .scheme(POSTGRES)
659 .authority(Authority::EMPTY)
660 .path(EStr::EMPTY)
661 .query(&query)
662 .build()
663 .unwrap()
664 }
665 }
666 }
667
668 #[must_use]
670 pub fn to_url_string(&self) -> String {
671 self.to_url().into_string()
672 }
673
674 fn append_common_query_params(
675 &self,
676 query: &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
677 append_query_pair: fn(
678 &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
679 &str,
680 &str,
681 ),
682 ) {
683 if let Some(application_name) = &self.application_name {
684 append_query_pair(query, "application_name", application_name.as_str());
685 }
686 append_query_pair(query, "sslmode", &self.ssl_mode.pg_env_value());
687 if let Some(ssl_root_cert) = &self.ssl_root_cert {
688 append_query_pair(query, "sslrootcert", &ssl_root_cert.pg_env_value());
689 }
690 }
691
692 #[must_use]
751 pub fn to_pg_env(
752 &self,
753 ) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
754 let mut map = std::collections::BTreeMap::new();
755
756 match &self.endpoint {
757 Endpoint::Network {
758 host,
759 channel_binding,
760 host_addr,
761 port,
762 } => {
763 map.insert(PGHOST.clone(), host.pg_env_value());
764 if let Some(port) = port {
765 map.insert(PGPORT.clone(), port.pg_env_value());
766 }
767 if let Some(channel_binding) = channel_binding {
768 map.insert(PGCHANNELBINDING.clone(), channel_binding.pg_env_value());
769 }
770 if let Some(addr) = host_addr {
771 map.insert(PGHOSTADDR.clone(), addr.to_string());
772 }
773 }
774 Endpoint::SocketPath(path) => {
775 map.insert(
776 PGHOST.clone(),
777 path.to_str()
778 .expect("socket path contains invalid utf8")
779 .to_string(),
780 );
781 }
782 }
783
784 map.insert(PGSSLMODE.clone(), self.ssl_mode.pg_env_value());
785 map.insert(PGUSER.clone(), self.user.pg_env_value());
786 map.insert(PGDATABASE.clone(), self.database.pg_env_value());
787
788 if let Some(application_name) = &self.application_name {
789 map.insert(PGAPPNAME.clone(), application_name.pg_env_value());
790 }
791
792 if let Some(password) = &self.password {
793 map.insert(PGPASSWORD.clone(), password.pg_env_value());
794 }
795
796 if let Some(ssl_root_cert) = &self.ssl_root_cert {
797 map.insert(PGSSLROOTCERT.clone(), ssl_root_cert.pg_env_value());
798 }
799
800 map
801 }
802
803 #[must_use]
804 pub fn endpoint(self, endpoint: Endpoint) -> Self {
805 Self { endpoint, ..self }
806 }
807
    /// Parses `url` into a [`Config`] by delegating to [`crate::url::parse`].
    ///
    /// # Errors
    ///
    /// Returns the [`crate::url::ParseError`] reported by the parser.
    pub fn from_str_url(url: &str) -> Result<Self, crate::url::ParseError> {
        crate::url::parse(url)
    }
817}
818
819#[cfg(test)]
820mod test {
821 use super::*;
822 use pretty_assertions::assert_eq;
823 use std::str::FromStr;
824
825 const TEST_DATABASE: Database = Database::from_static_or_panic("some-database");
826 const TEST_USER: User = User::from_static_or_panic("some-user");
827
828 fn assert_config(expected: serde_json::Value, config: &Config) {
829 assert_eq!(expected, serde_json::to_value(config).unwrap());
830 }
831
832 fn repeat(char: char, len: usize) -> String {
833 std::iter::repeat_n(char, len).collect()
834 }
835
836 #[test]
837 fn application_name_lt_min_length() {
838 let value = String::new();
839
840 let err = ApplicationName::from_str(&value).expect_err("expected min length failure");
841
842 assert_eq!(err, "ApplicationName byte min length: 1 violated, got: 0");
843 }
844
845 #[test]
846 fn application_name_eq_min_length() {
847 let value = repeat('a', 1);
848
849 let application_name =
850 ApplicationName::from_str(&value).expect("expected valid min length value");
851
852 assert_eq!(application_name, ApplicationName(value));
853 }
854
855 #[test]
856 fn application_name_gt_min_length() {
857 let value = repeat('a', 2);
858
859 let application_name =
860 ApplicationName::from_str(&value).expect("expected valid value greater than min");
861
862 assert_eq!(application_name, ApplicationName(value));
863 }
864
865 #[test]
866 fn application_name_lt_max_length() {
867 let value = repeat('a', 62);
868
869 let application_name =
870 ApplicationName::from_str(&value).expect("expected valid value less than max");
871
872 assert_eq!(application_name, ApplicationName(value));
873 }
874
875 #[test]
876 fn application_name_eq_max_length() {
877 let value = repeat('a', 63);
878
879 let application_name =
880 ApplicationName::from_str(&value).expect("expected valid value equal to max");
881
882 assert_eq!(application_name, ApplicationName(value));
883 }
884
885 #[test]
886 fn application_name_gt_max_length() {
887 let value = repeat('a', 64);
888
889 let err = ApplicationName::from_str(&value).expect_err("expected max length failure");
890
891 assert_eq!(err, "ApplicationName byte max length: 63 violated, got: 64");
892 }
893
894 #[test]
895 fn application_name_contains_nul() {
896 let value = String::from('\0');
897
898 let err = ApplicationName::from_str(&value).expect_err("expected NUL failure");
899
900 assert_eq!(err, "ApplicationName contains NUL byte");
901 }
902
903 #[test]
904 fn password_eq_min_length() {
905 let value = String::new();
906
907 let password = Password::from_str(&value).expect("expected valid min length value");
908
909 assert_eq!(password, Password(value));
910 }
911
912 #[test]
913 fn password_gt_min_length() {
914 let value = repeat('p', 1);
915
916 let password = Password::from_str(&value).expect("expected valid value greater than min");
917
918 assert_eq!(password, Password(value));
919 }
920
921 #[test]
922 fn password_lt_max_length() {
923 let value = repeat('p', 4095);
924
925 let password = Password::from_str(&value).expect("expected valid value less than max");
926
927 assert_eq!(password, Password(value));
928 }
929
930 #[test]
931 fn password_eq_max_length() {
932 let value = repeat('p', 4096);
933
934 let password = Password::from_str(&value).expect("expected valid value equal to max");
935
936 assert_eq!(password, Password(value));
937 }
938
939 #[test]
940 fn password_gt_max_length() {
941 let value = repeat('p', 4097);
942
943 let err = Password::from_str(&value).expect_err("expected max length failure");
944
945 assert_eq!(err, "Password byte max length: 4096 violated, got: 4097");
946 }
947
948 #[test]
949 fn password_contains_nul() {
950 let value = String::from('\0');
951
952 let err = Password::from_str(&value).expect_err("expected NUL failure");
953
954 assert_eq!(err, "Password contains NUL byte");
955 }
956
957 #[test]
958 fn test_json() {
959 let config = Config {
960 application_name: None,
961 database: TEST_DATABASE,
962 endpoint: Endpoint::Network {
963 host: Host::from_str("some-host").unwrap(),
964 channel_binding: None,
965 host_addr: None,
966 port: Some(Port::new(5432)),
967 },
968 password: None,
969 ssl_mode: SslMode::VerifyFull,
970 ssl_root_cert: None,
971 user: TEST_USER,
972 };
973
974 assert_config(
975 serde_json::json!({
976 "database": "some-database",
977 "endpoint": {
978 "host": "some-host",
979 "port": 5432,
980 },
981 "ssl_mode": "verify-full",
982 "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full",
983 "user": "some-user",
984 }),
985 &config,
986 );
987
988 assert_config(
989 serde_json::json!({
990 "application_name": "some-app",
991 "database": "some-database",
992 "endpoint": {
993 "host": "some-host",
994 "port": 5432,
995 },
996 "password": "some-password",
997 "ssl_mode": "verify-full",
998 "ssl_root_cert": {
999 "file": "/some.pem"
1000 },
1001 "url": "postgres://some-user:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem",
1002 "user": "some-user"
1003 }),
1004 &Config {
1005 application_name: Some(ApplicationName::from_str("some-app").unwrap()),
1006 password: Some(Password::from_str("some-password").unwrap()),
1007 ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
1008 ..config.clone()
1009 },
1010 );
1011
1012 assert_config(
1013 serde_json::json!({
1014 "database": "some-database",
1015 "endpoint": {
1016 "host": "127.0.0.1",
1017 "port": 5432,
1018 },
1019 "ssl_mode": "verify-full",
1020 "url": "postgres://some-user@127.0.0.1:5432/some-database?sslmode=verify-full",
1021 "user": "some-user"
1022 }),
1023 &Config {
1024 endpoint: Endpoint::Network {
1025 host: Host::from_str("127.0.0.1").unwrap(),
1026 channel_binding: None,
1027 host_addr: None,
1028 port: Some(Port::new(5432)),
1029 },
1030 ..config.clone()
1031 },
1032 );
1033
1034 assert_config(
1035 serde_json::json!({
1036 "database": "some-database",
1037 "endpoint": {
1038 "socket_path": "/some/socket",
1039 },
1040 "ssl_mode": "verify-full",
1041 "url": "postgres://?host=%2Fsome%2Fsocket&dbname=some-database&user=some-user&sslmode=verify-full",
1042 "user": "some-user"
1043 }),
1044 &Config {
1045 endpoint: Endpoint::SocketPath("/some/socket".into()),
1046 ..config.clone()
1047 },
1048 );
1049
1050 assert_config(
1051 serde_json::json!({
1052 "database": "some-database",
1053 "endpoint": {
1054 "host": "some-host",
1055 "port": 5432,
1056 },
1057 "ssl_mode": "verify-full",
1058 "ssl_root_cert": "system",
1059 "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full&sslrootcert=system",
1060 "user": "some-user"
1061 }),
1062 &Config {
1063 ssl_root_cert: Some(SslRootCert::System),
1064 ..config.clone()
1065 },
1066 );
1067
1068 assert_config(
1069 serde_json::json!({
1070 "database": "some-database",
1071 "endpoint": {
1072 "host": "some-host",
1073 "host_addr": "192.168.1.100",
1074 "port": 5432,
1075 },
1076 "ssl_mode": "verify-full",
1077 "url": "postgres://some-user@some-host:5432/some-database?hostaddr=192.168.1.100&sslmode=verify-full",
1078 "user": "some-user"
1079 }),
1080 &Config {
1081 endpoint: Endpoint::Network {
1082 host: Host::from_str("some-host").unwrap(),
1083 channel_binding: None,
1084 host_addr: Some("192.168.1.100".parse().unwrap()),
1085 port: Some(Port::new(5432)),
1086 },
1087 ..config.clone()
1088 },
1089 );
1090
1091 assert_config(
1093 serde_json::json!({
1094 "database": "some-database",
1095 "endpoint": {
1096 "host": "some-host",
1097 },
1098 "ssl_mode": "verify-full",
1099 "url": "postgres://some-user@some-host/some-database?sslmode=verify-full",
1100 "user": "some-user"
1101 }),
1102 &Config {
1103 endpoint: Endpoint::Network {
1104 host: Host::from_str("some-host").unwrap(),
1105 channel_binding: None,
1106 host_addr: None,
1107 port: None,
1108 },
1109 ..config.clone()
1110 },
1111 );
1112
1113 assert_config(
1115 serde_json::json!({
1116 "database": "some-database",
1117 "endpoint": {
1118 "host": "some-host",
1119 "host_addr": "10.0.0.1",
1120 },
1121 "ssl_mode": "verify-full",
1122 "url": "postgres://some-user@some-host/some-database?hostaddr=10.0.0.1&sslmode=verify-full",
1123 "user": "some-user"
1124 }),
1125 &Config {
1126 endpoint: Endpoint::Network {
1127 host: Host::from_str("some-host").unwrap(),
1128 channel_binding: None,
1129 host_addr: Some("10.0.0.1".parse().unwrap()),
1130 port: None,
1131 },
1132 ..config.clone()
1133 },
1134 );
1135 }
1136
1137 #[test]
1138 fn test_ipv6_url_formation() {
1139 let config_ipv6_loopback = Config {
1141 application_name: None,
1142 database: TEST_DATABASE,
1143 endpoint: Endpoint::Network {
1144 host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
1145 channel_binding: None,
1146 host_addr: None,
1147 port: Some(Port::new(5432)),
1148 },
1149 password: None,
1150 ssl_mode: SslMode::Disable,
1151 ssl_root_cert: None,
1152 user: User::POSTGRES,
1153 };
1154
1155 assert_eq!(
1156 config_ipv6_loopback.to_url_string(),
1157 "postgres://postgres@[::1]:5432/some-database?sslmode=disable",
1158 "IPv6 loopback address should be bracketed in URL"
1159 );
1160
1161 let config_ipv6_fe80 = Config {
1163 application_name: None,
1164 database: TEST_DATABASE,
1165 endpoint: Endpoint::Network {
1166 host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
1167 0xfe80, 0, 0, 0, 0, 0, 0, 1,
1168 ))),
1169 channel_binding: None,
1170 host_addr: None,
1171 port: Some(Port::new(5432)),
1172 },
1173 password: None,
1174 ssl_mode: SslMode::Disable,
1175 ssl_root_cert: None,
1176 user: User::POSTGRES,
1177 };
1178
1179 assert_eq!(
1180 config_ipv6_fe80.to_url_string(),
1181 "postgres://postgres@[fe80::1]:5432/some-database?sslmode=disable",
1182 "IPv6 link-local address should be bracketed in URL"
1183 );
1184
1185 let config_ipv6_full = Config {
1187 application_name: None,
1188 database: TEST_DATABASE,
1189 endpoint: Endpoint::Network {
1190 host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
1191 0x2001, 0x0db8, 0, 0, 0, 0, 0, 1,
1192 ))),
1193 channel_binding: None,
1194 host_addr: None,
1195 port: Some(Port::new(5432)),
1196 },
1197 password: None,
1198 ssl_mode: SslMode::Disable,
1199 ssl_root_cert: None,
1200 user: User::POSTGRES,
1201 };
1202
1203 assert_eq!(
1204 config_ipv6_full.to_url_string(),
1205 "postgres://postgres@[2001:db8::1]:5432/some-database?sslmode=disable",
1206 "Full IPv6 address should be bracketed in URL"
1207 );
1208
1209 let config_ipv4 = Config {
1211 application_name: None,
1212 database: TEST_DATABASE,
1213 endpoint: Endpoint::Network {
1214 host: Host::IpAddr(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)),
1215 channel_binding: None,
1216 host_addr: None,
1217 port: Some(Port::new(5432)),
1218 },
1219 password: None,
1220 ssl_mode: SslMode::Disable,
1221 ssl_root_cert: None,
1222 user: User::POSTGRES,
1223 };
1224
1225 assert_eq!(
1226 config_ipv4.to_url_string(),
1227 "postgres://postgres@127.0.0.1:5432/some-database?sslmode=disable",
1228 "IPv4 address should NOT be bracketed in URL"
1229 );
1230
1231 let config_hostname = Config {
1233 application_name: None,
1234 database: TEST_DATABASE,
1235 endpoint: Endpoint::Network {
1236 host: Host::from_str("localhost").unwrap(),
1237 channel_binding: None,
1238 host_addr: None,
1239 port: Some(Port::new(5432)),
1240 },
1241 password: None,
1242 ssl_mode: SslMode::Disable,
1243 ssl_root_cert: None,
1244 user: User::POSTGRES,
1245 };
1246
1247 assert_eq!(
1248 config_hostname.to_url_string(),
1249 "postgres://postgres@localhost:5432/some-database?sslmode=disable",
1250 "Hostname should NOT be bracketed in URL"
1251 );
1252 }
1253}