1#![doc = include_str!("../README.md")]
2
/// Implement validated construction for a `String` newtype.
///
/// For `$struct` (a tuple struct whose field `.0` is a `String`) this generates:
///
/// * associated constants `MIN_LENGTH = $min` and `MAX_LENGTH = $max`,
/// * an `as_str` accessor and an `AsRef<str>` impl,
/// * a `FromStr` impl that accepts a value only when its *byte* length lies in
///   `MIN_LENGTH..=MAX_LENGTH` and it contains no NUL byte.
///
/// NOTE(review): the error messages are assembled with `concat!`, so the
/// `{min_length}`, `{max_length}` and `{actual}` placeholders appear literally in
/// the error text instead of being interpolated (the `&'static str` error type
/// rules out `format!`). The unit tests currently assert these literal strings,
/// so fixing this requires updating both together.
macro_rules! from_str_impl {
    ($struct: ident, $min: expr, $max: expr) => {
        impl $struct {
            pub const MIN_LENGTH: usize = $min;
            pub const MAX_LENGTH: usize = $max;

            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl std::str::FromStr for $struct {
            type Err = &'static str;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                let byte_len = value.len();
                let (lower, upper) = (Self::MIN_LENGTH, Self::MAX_LENGTH);

                if byte_len < lower {
                    return Err(concat!(
                        stringify!($struct),
                        " byte min length: {min_length} violated, got: {actual}"
                    ));
                }

                if byte_len > upper {
                    return Err(concat!(
                        stringify!($struct),
                        " byte max length: {max_length} violated, got: {actual}"
                    ));
                }

                // U+0000 is the only char whose UTF-8 encoding contains a 0 byte, so
                // this is equivalent to scanning the raw bytes for NUL.
                if value.contains('\0') {
                    return Err(concat!(stringify!($struct), " contains NUL byte"));
                }

                Ok(Self(value.to_owned()))
            }
        }
    };
}
48
49#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
50pub struct HostName(String);
51
52impl HostName {
53 pub fn as_str(&self) -> &str {
54 &self.0
55 }
56}
57
58impl std::str::FromStr for HostName {
59 type Err = &'static str;
60
61 fn from_str(value: &str) -> Result<Self, Self::Err> {
62 if hostname_validator::is_valid(value) {
63 Ok(Self(value.to_string()))
64 } else {
65 Err("invalid host name")
66 }
67 }
68}
69
70impl<'de> serde::Deserialize<'de> for HostName {
71 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72 where
73 D: serde::Deserializer<'de>,
74 {
75 let s = String::deserialize(deserializer)?;
76 s.parse().map_err(serde::de::Error::custom)
77 }
78}
79
80#[derive(Clone, Debug, PartialEq, Eq)]
81pub enum Host {
82 HostName(HostName),
83 IpAddr(std::net::IpAddr),
84}
85
86impl serde::Serialize for Host {
87 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
88 serializer.serialize_str(&self.to_pg_env_value())
89 }
90}
91
92impl Host {
93 fn to_pg_env_value(&self) -> String {
94 match self {
95 Self::HostName(value) => value.0.clone(),
96 Self::IpAddr(value) => value.to_string(),
97 }
98 }
99}
100
101impl std::str::FromStr for Host {
102 type Err = &'static str;
103
104 fn from_str(value: &str) -> Result<Self, Self::Err> {
105 match std::net::IpAddr::from_str(value) {
106 Ok(addr) => Ok(Self::IpAddr(addr)),
107 Err(_) => match HostName::from_str(value) {
108 Ok(host_name) => Ok(Self::HostName(host_name)),
109 Err(_) => Err("Not a socket address or FQDN"),
110 },
111 }
112 }
113}
114
/// Build a `Host` from a string literal.
///
/// # Panics
///
/// Panics if the literal is neither an IP address nor a valid host name.
#[macro_export]
macro_rules! host {
    ($string: literal) => {
        // `$crate` resolves to this crate however it is named by the caller (and
        // inside the crate itself), unlike a hard-coded `pg_client::` path.
        <$crate::Host as ::std::str::FromStr>::from_str($string)
            .expect("invalid host literal")
    };
}
121
122impl From<HostName> for Host {
123 fn from(value: HostName) -> Self {
124 Self::HostName(value)
125 }
126}
127
128impl From<std::net::IpAddr> for Host {
129 fn from(value: std::net::IpAddr) -> Self {
130 Self::IpAddr(value)
131 }
132}
133
/// A numeric server address, emitted as `hostaddr` in URLs and as `PGHOSTADDR` in
/// the environment.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(pub std::net::IpAddr);

impl From<std::net::IpAddr> for HostAddr {
    fn from(addr: std::net::IpAddr) -> Self {
        HostAddr(addr)
    }
}

impl std::fmt::Display for HostAddr {
    /// Writes the address in its standard form. Width, fill and other flags on the
    /// outer formatter are deliberately not forwarded.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(addr) = self;
        write!(f, "{addr}")
    }
}

impl std::str::FromStr for HostAddr {
    type Err = &'static str;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        value
            .parse::<std::net::IpAddr>()
            .map(HostAddr)
            .map_err(|_| "invalid IP address")
    }
}
190
191#[derive(Clone, Debug, PartialEq, Eq)]
192pub enum Endpoint {
193 Network {
194 host: Host,
195 host_addr: Option<HostAddr>,
196 port: Option<Port>,
197 },
198 SocketPath(std::path::PathBuf),
199}
200
201impl serde::Serialize for Endpoint {
202 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
203 use serde::ser::SerializeStruct;
204 match self {
205 Self::Network {
206 host,
207 host_addr,
208 port,
209 } => {
210 let mut state = serializer.serialize_struct("Endpoint", 3)?;
211 state.serialize_field("host", host)?;
212 if let Some(addr) = host_addr {
213 state.serialize_field("host_addr", &addr.to_string())?;
214 }
215 if let Some(port) = port {
216 state.serialize_field("port", port)?;
217 }
218 state.end()
219 }
220 Self::SocketPath(path) => {
221 let mut state = serializer.serialize_struct("Endpoint", 1)?;
222 state.serialize_field(
223 "socket_path",
224 &path.to_str().expect("socket path contains invalid utf8"),
225 )?;
226 state.end()
227 }
228 }
229 }
230}
231
/// TCP port of the server (`PGPORT`, or the port component of the URL).
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
pub struct Port(pub u16);
234
235impl std::str::FromStr for Port {
236 type Err = &'static str;
237
238 fn from_str(value: &str) -> Result<Self, Self::Err> {
239 match <u16 as std::str::FromStr>::from_str(value) {
240 Ok(port) => Ok(Port(port)),
241 Err(_) => Err("invalid postgresql port string"),
242 }
243 }
244}
245
246impl Port {
247 fn to_pg_env_value(self) -> String {
248 self.0.to_string()
249 }
250}
251
252impl From<Port> for u16 {
253 fn from(port: Port) -> Self {
254 port.0
255 }
256}
257
258impl From<&Port> for u16 {
259 fn from(port: &Port) -> Self {
260 port.0
261 }
262}
263
/// Value for libpq `application_name` (`PGAPPNAME`).
///
/// Must be between 1 and 63 bytes long and must not contain a NUL byte.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
pub struct ApplicationName(String);

// Generates the length/NUL-validating `FromStr`, `as_str`, `AsRef<str>`, and the
// `MIN_LENGTH` / `MAX_LENGTH` constants.
from_str_impl!(ApplicationName, 1, 63);
268
/// Build an `ApplicationName` from a string literal.
///
/// # Panics
///
/// Panics if the literal violates the length or NUL-byte rules.
#[macro_export]
macro_rules! application_name {
    ($string: literal) => {
        // `$crate` keeps the macro working inside this crate and under renames.
        <$crate::ApplicationName as ::std::str::FromStr>::from_str($string)
            .expect("invalid application name literal")
    };
}
275
276impl ApplicationName {
277 fn to_pg_env_value(&self) -> String {
278 self.0.clone()
279 }
280}
281
/// Database name (`dbname` / `PGDATABASE`).
///
/// Must be between 1 and 63 bytes long and must not contain a NUL byte.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
pub struct Database(String);

// Generates the length/NUL-validating `FromStr`, `as_str`, `AsRef<str>`, and the
// `MIN_LENGTH` / `MAX_LENGTH` constants.
from_str_impl!(Database, 1, 63);
286
/// Build a `Database` from a string literal.
///
/// # Panics
///
/// Panics if the literal violates the length or NUL-byte rules.
#[macro_export]
macro_rules! database {
    ($string: literal) => {
        // `$crate` keeps the macro working inside this crate and under renames.
        <$crate::Database as ::std::str::FromStr>::from_str($string)
            .expect("invalid database literal")
    };
}
293
294impl Database {
295 fn to_pg_env_value(&self) -> String {
296 self.0.clone()
297 }
298}
299
/// User name (`user` / `PGUSER`).
///
/// Must be between 1 and 63 bytes long and must not contain a NUL byte.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
pub struct Username(String);

// Generates the length/NUL-validating `FromStr`, `as_str`, `AsRef<str>`, and the
// `MIN_LENGTH` / `MAX_LENGTH` constants.
from_str_impl!(Username, 1, 63);
304
/// Build a `Username` from a string literal.
///
/// # Panics
///
/// Panics if the literal violates the length or NUL-byte rules.
#[macro_export]
macro_rules! username {
    ($string: literal) => {
        // `$crate` keeps the macro working inside this crate and under renames.
        <$crate::Username as ::std::str::FromStr>::from_str($string)
            .expect("invalid username literal")
    };
}
311
312impl Username {
313 fn to_pg_env_value(&self) -> String {
314 self.0.clone()
315 }
316}
317
/// Password (`password` / `PGPASSWORD`).
///
/// Via `FromStr` it may be empty, must be at most 4096 bytes, and must not contain a
/// NUL byte. Note that the derived `Debug` and `Serialize` output the password in
/// plain text.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
pub struct Password(String);

// Generates the length/NUL-validating `FromStr`, `as_str`, `AsRef<str>`, and the
// `MIN_LENGTH` / `MAX_LENGTH` constants.
from_str_impl!(Password, 0, 4096);
322
323impl Password {
324 fn to_pg_env_value(&self) -> String {
325 self.0.clone()
326 }
327}
328
329impl From<String> for Password {
330 fn from(value: String) -> Self {
331 Self(value)
332 }
333}
334
335#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
336#[serde(rename_all = "kebab-case")]
337pub enum SslMode {
338 Allow,
339 Disable,
340 Prefer,
341 Require,
342 VerifyCa,
343 VerifyFull,
344}
345
346impl SslMode {
347 pub fn as_str(&self) -> &'static str {
348 match self {
349 Self::Allow => "allow",
350 Self::Disable => "disable",
351 Self::Prefer => "prefer",
352 Self::Require => "require",
353 Self::VerifyCa => "verify-ca",
354 Self::VerifyFull => "verify-full",
355 }
356 }
357
358 fn to_sqlx_ssl_mode(&self) -> sqlx::postgres::PgSslMode {
359 use sqlx::postgres::PgSslMode;
360
361 match self {
362 Self::Allow => PgSslMode::Allow,
363 Self::Disable => PgSslMode::Disable,
364 Self::Prefer => PgSslMode::Prefer,
365 Self::Require => PgSslMode::Require,
366 Self::VerifyCa => PgSslMode::VerifyCa,
367 Self::VerifyFull => PgSslMode::VerifyFull,
368 }
369 }
370
371 fn to_pg_env_value(&self) -> String {
372 self.as_str().to_string()
373 }
374}
375
/// Source of trusted root certificates (`sslrootcert` / `PGSSLROOTCERT`).
///
/// Serializes as `{"file": "<path>"}` for `File` and as `"system"` for `System`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum SslRootCert {
    /// A certificate file on disk.
    File(std::path::PathBuf),
    /// Rendered as the literal value `system`.
    System,
}
382
383impl SslRootCert {
384 fn to_pg_env_value(&self) -> String {
385 match self {
386 Self::File(path) => path.to_str().unwrap().to_string(),
387 Self::System => "system".to_string(),
388 }
389 }
390}
391
392impl From<std::path::PathBuf> for SslRootCert {
393 fn from(value: std::path::PathBuf) -> Self {
394 Self::File(value)
395 }
396}
397
/// Reasons `Config::to_sqlx_connect_options` refuses to build `PgConnectOptions`.
///
/// Both variants carry the environment variable that was found (`env_key`) and
/// the option it corresponds to (`field_name`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlxOptionsError {
    /// The variable is set while the matching `Config` field is `None`.
    EnvConflict { env_key: String, field_name: String },
    /// The variable is set for an option `Config` cannot express.
    UnsupportedFeature { env_key: String, field_name: String },
}

impl std::fmt::Display for SqlxOptionsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use SqlxOptionsError::{EnvConflict, UnsupportedFeature};

        match self {
            EnvConflict {
                env_key,
                field_name,
            } => {
                write!(
                    f,
                    "`PgConnectOptions::new` has inferred a `{field_name}` from `{env_key}` environment variable, but `pg_client::Config` does not specify a `{field_name}` value. `PgConnectOptions` does not provide an API to construct an instance without inferring from the environment, does not provide an API to unset the field, we have to bail out at this point. Please remove the environment variable!"
                )
            }
            UnsupportedFeature {
                env_key,
                field_name,
            } => {
                write!(
                    f,
                    "`PgConnectOptions::new` has inferred `{field_name}` from the `{env_key}` environment variable, but `pg_client::Config` does not support that feature at this point. As `PgConnectOptions` has no option to unset that field, or a constructor that allows us to bypass the inference: we have to bail out, please remove the environment variable!"
                )
            }
        }
    }
}

impl std::error::Error for SqlxOptionsError {}
426
/// Failure modes of `Config::with_sqlx_connection`.
#[derive(Debug, thiserror::Error)]
pub enum SqlxConnectionError {
    /// Building the connect options failed (environment conflict or unsupported
    /// environment variable).
    #[error("Failed to create SQLx connect options")]
    Options(#[from] SqlxOptionsError),

    /// Opening the connection failed.
    #[error("Failed to connect to database")]
    Connect(#[source] sqlx::Error),

    /// Closing the connection after the action ran failed.
    #[error("Failed to close database connection")]
    Close(#[source] sqlx::Error),
}
438
/// Connection settings for a PostgreSQL server.
///
/// Can be rendered as a `postgres://` URL (`to_url`), as libpq environment
/// variables (`to_pg_env`), or as sqlx connect options
/// (`to_sqlx_connect_options`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Config {
    /// Optional `application_name` / `PGAPPNAME`.
    pub application_name: Option<ApplicationName>,
    /// Database name (`dbname` / `PGDATABASE`).
    pub database: Database,
    /// Network host or socket path to connect to.
    pub endpoint: Endpoint,
    /// Optional password (`password` / `PGPASSWORD`).
    pub password: Option<Password>,
    /// `sslmode` / `PGSSLMODE`.
    pub ssl_mode: SslMode,
    /// Optional `sslrootcert` / `PGSSLROOTCERT`.
    pub ssl_root_cert: Option<SslRootCert>,
    /// User name (`user` / `PGUSER`).
    pub username: Username,
}
457
458impl serde::Serialize for Config {
459 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
460 use serde::ser::SerializeStruct;
461 let mut state = serializer.serialize_struct("Config", 8)?;
462
463 if let Some(application_name) = &self.application_name {
464 state.serialize_field("application_name", application_name)?;
465 }
466
467 state.serialize_field("database", &self.database)?;
468 state.serialize_field("endpoint", &self.endpoint)?;
469
470 if let Some(password) = &self.password {
471 state.serialize_field("password", password)?;
472 }
473
474 state.serialize_field("ssl_mode", &self.ssl_mode)?;
475
476 if let Some(ssl_root_cert) = &self.ssl_root_cert {
477 state.serialize_field("ssl_root_cert", ssl_root_cert)?;
478 }
479
480 state.serialize_field("username", &self.username)?;
481 state.serialize_field("url", &self.to_url())?;
482
483 state.end()
484 }
485}
486
487impl Config {
488 pub fn to_url(&self) -> url::Url {
583 let mut url = url::Url::parse("postgres://").unwrap();
584
585 match &self.endpoint {
586 Endpoint::Network {
587 host,
588 host_addr,
589 port,
590 } => {
591 match host {
593 Host::IpAddr(ip_addr) => {
594 url.set_ip_host(*ip_addr).unwrap();
595 }
596 Host::HostName(hostname) => {
597 url.set_host(Some(hostname.as_str())).unwrap();
598 }
599 }
600 url.set_username(self.username.to_pg_env_value().as_str())
601 .unwrap();
602
603 if let Some(password) = &self.password {
604 url.set_password(Some(password.as_str())).unwrap();
605 }
606
607 if let Some(port) = port {
608 url.set_port(Some(port.0)).unwrap();
609 }
610
611 url.set_path(self.database.as_str());
612
613 if let Some(addr) = host_addr {
615 url.query_pairs_mut()
616 .append_pair("hostaddr", &addr.to_string());
617 }
618 }
619 Endpoint::SocketPath(path) => {
620 url.query_pairs_mut()
622 .append_pair(
623 "host",
624 path.to_str().expect("socket path contains invalid utf8"),
625 )
626 .append_pair("dbname", self.database.as_str())
627 .append_pair("user", self.username.to_pg_env_value().as_str());
628
629 if let Some(password) = &self.password {
630 url.query_pairs_mut()
631 .append_pair("password", password.as_str());
632 }
633 }
634 }
635
636 {
637 let mut pairs = url.query_pairs_mut();
638
639 if let Some(application_name) = &self.application_name {
640 pairs.append_pair("application_name", application_name.as_str());
641 }
642
643 pairs.append_pair("sslmode", &self.ssl_mode.to_pg_env_value());
644
645 if let Some(ssl_root_cert) = &self.ssl_root_cert {
646 pairs.append_pair("sslrootcert", &ssl_root_cert.to_pg_env_value());
647 }
648 }
649
650 url
651 }
652
653 pub fn to_pg_env(&self) -> std::collections::BTreeMap<&'static str, String> {
710 let mut map = std::collections::BTreeMap::new();
711
712 match &self.endpoint {
713 Endpoint::Network {
714 host,
715 host_addr,
716 port,
717 } => {
718 map.insert("PGHOST", host.to_pg_env_value());
719 if let Some(port) = port {
720 map.insert("PGPORT", port.to_pg_env_value());
721 }
722 if let Some(addr) = host_addr {
723 map.insert("PGHOSTADDR", addr.to_string());
724 }
725 }
726 Endpoint::SocketPath(path) => {
727 map.insert(
728 "PGHOST",
729 path.to_str()
730 .expect("socket path contains invalid utf8")
731 .to_string(),
732 );
733 }
734 }
735
736 map.insert("PGSSLMODE", self.ssl_mode.to_pg_env_value());
737 map.insert("PGUSER", self.username.to_pg_env_value());
738 map.insert("PGDATABASE", self.database.to_pg_env_value());
739
740 if let Some(application_name) = &self.application_name {
741 map.insert("PGAPPNAME", application_name.to_pg_env_value());
742 }
743
744 if let Some(password) = &self.password {
745 map.insert("PGPASSWORD", password.to_pg_env_value());
746 }
747
748 if let Some(ssl_root_cert) = &self.ssl_root_cert {
749 map.insert("PGSSLROOTCERT", ssl_root_cert.to_pg_env_value());
750 }
751
752 map
753 }
754
755 pub fn to_sqlx_connect_options(
797 &self,
798 ) -> Result<sqlx::postgres::PgConnectOptions, SqlxOptionsError> {
799 fn reject_env(env_key: &str, field_name: &str) -> Result<(), SqlxOptionsError> {
800 if std::env::var(env_key).is_ok() {
801 Err(SqlxOptionsError::EnvConflict {
802 env_key: env_key.to_string(),
803 field_name: field_name.to_string(),
804 })
805 } else {
806 Ok(())
807 }
808 }
809
810 fn unsupported_env(env_key: &str, field_name: &str) -> Result<(), SqlxOptionsError> {
811 if std::env::var(env_key).is_ok() {
812 Err(SqlxOptionsError::UnsupportedFeature {
813 env_key: env_key.to_string(),
814 field_name: field_name.to_string(),
815 })
816 } else {
817 Ok(())
818 }
819 }
820
821 let mut options = sqlx::postgres::PgConnectOptions::new_without_pgpass();
825
826 unsupported_env("PGSSLKEY", "ssl_client_key")?;
827 unsupported_env("PGSSLCERT", "ssl_client_cert")?;
828 unsupported_env("PGOPTIONS", "options")?;
829
830 options = options.database(self.database.as_str());
831
832 match &self.endpoint {
833 Endpoint::Network {
834 host,
835 host_addr,
836 port,
837 } => {
838 options = options.host(&host.to_pg_env_value());
839 if let Some(port) = port {
840 options = options.port(port.into());
841 } else {
842 reject_env("PGPORT", "port")?;
843 }
844 if let Some(host_addr) = host_addr {
845 options = options.host_addr(&host_addr.to_string())
846 } else {
847 reject_env("PGHOSTADDR", "hostaddr")?;
848 }
849 }
850 Endpoint::SocketPath(path) => {
851 options = options.host(path.to_str().expect("socket path contains invalid utf8"));
852 reject_env("PGPORT", "port")?;
853 reject_env("PGHOSTADDR", "hostaddr")?;
854 }
855 }
856
857 options = options.ssl_mode(self.ssl_mode.to_sqlx_ssl_mode());
858 options = options.username(self.username.as_str());
859
860 if let Some(application_name) = &self.application_name {
861 options = options.application_name(application_name.as_str());
862 } else {
863 reject_env("PGAPPNAME", "application_name")?;
864 }
865
866 if let Some(password) = &self.password {
867 options = options.password(password.as_str());
868 } else {
869 reject_env("PGPASSWORD", "password")?;
870 }
871
872 if let Some(ssl_root_cert) = &self.ssl_root_cert {
873 options = options.ssl_root_cert(ssl_root_cert.to_pg_env_value());
874 } else {
875 reject_env("PGSSLROOTCERT", "ssl_root_cert")?;
876 }
877
878 Ok(options)
879 }
880
881 pub async fn with_sqlx_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
882 &self,
883 mut action: F,
884 ) -> Result<T, SqlxConnectionError> {
885 let config = self.to_sqlx_connect_options()?;
886
887 let mut connection = sqlx::ConnectOptions::connect(&config)
888 .await
889 .map_err(SqlxConnectionError::Connect)?;
890
891 let result = action(&mut connection).await;
892
893 sqlx::Connection::close(connection)
894 .await
895 .map_err(SqlxConnectionError::Close)?;
896
897 Ok(result)
898 }
899
900 pub fn endpoint(self, endpoint: Endpoint) -> Self {
901 Self { endpoint, ..self }
902 }
903}
904
/// Unit tests for the validated newtypes and for `Config` serialization / URL
/// rendering.
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    /// Compare `config`'s JSON serialization (including the derived `url`) with
    /// `expected`.
    fn assert_config(expected: serde_json::Value, config: &Config) {
        assert_eq!(expected, serde_json::to_value(config).unwrap());
    }

    /// A string made of `len` copies of `char`.
    fn repeat(char: char, len: usize) -> String {
        std::iter::repeat_n(char, len).collect()
    }

    // ApplicationName: 1..=63 bytes, no NUL.
    // The `{min_length}` / `{max_length}` / `{actual}` placeholders in the expected
    // errors are literal text: `from_str_impl!` builds these messages with `concat!`.

    #[test]
    fn application_name_lt_min_length() {
        let value = String::new();

        let err = ApplicationName::from_str(&value).expect_err("expected min length failure");

        assert_eq!(
            err,
            "ApplicationName byte min length: {min_length} violated, got: {actual}"
        );
    }

    #[test]
    fn application_name_eq_min_length() {
        let value = repeat('a', 1);

        let application_name =
            ApplicationName::from_str(&value).expect("expected valid min length value");

        assert_eq!(application_name, ApplicationName(value));
    }

    #[test]
    fn application_name_gt_min_length() {
        let value = repeat('a', 2);

        let application_name =
            ApplicationName::from_str(&value).expect("expected valid value greater than min");

        assert_eq!(application_name, ApplicationName(value));
    }

    #[test]
    fn application_name_lt_max_length() {
        let value = repeat('a', 62);

        let application_name =
            ApplicationName::from_str(&value).expect("expected valid value less than max");

        assert_eq!(application_name, ApplicationName(value));
    }

    #[test]
    fn application_name_eq_max_length() {
        let value = repeat('a', 63);

        let application_name =
            ApplicationName::from_str(&value).expect("expected valid value equal to max");

        assert_eq!(application_name, ApplicationName(value));
    }

    #[test]
    fn application_name_gt_max_length() {
        let value = repeat('a', 64);

        let err = ApplicationName::from_str(&value).expect_err("expected max length failure");

        assert_eq!(
            err,
            "ApplicationName byte max length: {max_length} violated, got: {actual}"
        );
    }

    #[test]
    fn application_name_contains_nul() {
        let value = String::from('\0');

        let err = ApplicationName::from_str(&value).expect_err("expected NUL failure");

        assert_eq!(err, "ApplicationName contains NUL byte");
    }

    // Database: 1..=63 bytes, no NUL.

    #[test]
    fn database_lt_min_length() {
        let value = String::new();

        let err = Database::from_str(&value).expect_err("expected min length failure");

        assert_eq!(
            err,
            "Database byte min length: {min_length} violated, got: {actual}"
        );
    }

    #[test]
    fn database_eq_min_length() {
        let value = repeat('d', 1);

        let database = Database::from_str(&value).expect("expected valid min length value");

        assert_eq!(database, Database(value));
    }

    #[test]
    fn database_gt_min_length() {
        let value = repeat('d', 2);

        let database = Database::from_str(&value).expect("expected valid value greater than min");

        assert_eq!(database, Database(value));
    }

    #[test]
    fn database_lt_max_length() {
        let value = repeat('d', 62);

        let database = Database::from_str(&value).expect("expected valid value less than max");

        assert_eq!(database, Database(value));
    }

    #[test]
    fn database_eq_max_length() {
        let value = repeat('d', 63);

        let database = Database::from_str(&value).expect("expected valid value equal to max");

        assert_eq!(database, Database(value));
    }

    #[test]
    fn database_gt_max_length() {
        let value = repeat('d', 64);

        let err = Database::from_str(&value).expect_err("expected max length failure");

        assert_eq!(
            err,
            "Database byte max length: {max_length} violated, got: {actual}"
        );
    }

    #[test]
    fn database_contains_nul() {
        let value = String::from('\0');

        let err = Database::from_str(&value).expect_err("expected NUL failure");

        assert_eq!(err, "Database contains NUL byte");
    }

    // Username: 1..=63 bytes, no NUL.

    #[test]
    fn username_lt_min_length() {
        let value = String::new();

        let err = Username::from_str(&value).expect_err("expected min length failure");

        assert_eq!(
            err,
            "Username byte min length: {min_length} violated, got: {actual}"
        );
    }

    #[test]
    fn username_eq_min_length() {
        let value = repeat('u', 1);

        let username = Username::from_str(&value).expect("expected valid min length value");

        assert_eq!(username, Username(value));
    }

    #[test]
    fn username_gt_min_length() {
        let value = repeat('u', 2);

        let username = Username::from_str(&value).expect("expected valid value greater than min");

        assert_eq!(username, Username(value));
    }

    #[test]
    fn username_lt_max_length() {
        let value = repeat('u', 62);

        let username = Username::from_str(&value).expect("expected valid value less than max");

        assert_eq!(username, Username(value));
    }

    #[test]
    fn username_eq_max_length() {
        let value = repeat('u', 63);

        let username = Username::from_str(&value).expect("expected valid value equal to max");

        assert_eq!(username, Username(value));
    }

    #[test]
    fn username_gt_max_length() {
        let value = repeat('u', 64);

        let err = Username::from_str(&value).expect_err("expected max length failure");

        assert_eq!(
            err,
            "Username byte max length: {max_length} violated, got: {actual}"
        );
    }

    #[test]
    fn username_contains_nul() {
        let value = String::from('\0');

        let err = Username::from_str(&value).expect_err("expected NUL failure");

        assert_eq!(err, "Username contains NUL byte");
    }

    // Password: 0..=4096 bytes (empty is allowed), no NUL.

    #[test]
    fn password_eq_min_length() {
        let value = String::new();

        let password = Password::from_str(&value).expect("expected valid min length value");

        assert_eq!(password, Password(value));
    }

    #[test]
    fn password_gt_min_length() {
        let value = repeat('p', 1);

        let password = Password::from_str(&value).expect("expected valid value greater than min");

        assert_eq!(password, Password(value));
    }

    #[test]
    fn password_lt_max_length() {
        let value = repeat('p', 4095);

        let password = Password::from_str(&value).expect("expected valid value less than max");

        assert_eq!(password, Password(value));
    }

    #[test]
    fn password_eq_max_length() {
        let value = repeat('p', 4096);

        let password = Password::from_str(&value).expect("expected valid value equal to max");

        assert_eq!(password, Password(value));
    }

    #[test]
    fn password_gt_max_length() {
        let value = repeat('p', 4097);

        let err = Password::from_str(&value).expect_err("expected max length failure");

        assert_eq!(
            err,
            "Password byte max length: {max_length} violated, got: {actual}"
        );
    }

    #[test]
    fn password_contains_nul() {
        let value = String::from('\0');

        let err = Password::from_str(&value).expect_err("expected NUL failure");

        assert_eq!(err, "Password contains NUL byte");
    }

    /// JSON serialization of `Config`, including the derived `url`, across endpoint
    /// kinds and optional fields.
    #[test]
    fn test_json() {
        // Baseline: network endpoint with a port, no optional fields set.
        let config = Config {
            application_name: None,
            database: Database::from_str("some-database").unwrap(),
            endpoint: Endpoint::Network {
                host: Host::from_str("some-host").unwrap(),
                host_addr: None,
                port: Some(Port(5432)),
            },
            password: None,
            ssl_mode: SslMode::VerifyFull,
            ssl_root_cert: None,
            username: Username::from_str("some-username").unwrap(),
        };

        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@some-host:5432/some-database?sslmode=verify-full",
                "username": "some-username",
            }),
            &config,
        );

        // All optional fields set: they appear as fields and as URL components.
        assert_config(
            serde_json::json!({
                "application_name": "some-app",
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "password": "some-password",
                "ssl_mode": "verify-full",
                "ssl_root_cert": {
                    "file": "/some.pem"
                },
                "url": "postgres://some-username:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem",
                "username": "some-username"
            }),
            &Config {
                application_name: Some(ApplicationName::from_str("some-app").unwrap()),
                password: Some(Password::from_str("some-password").unwrap()),
                ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
                ..config.clone()
            },
        );

        // An IPv4 literal host.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "127.0.0.1",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@127.0.0.1:5432/some-database?sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::Network {
                    host: Host::from_str("127.0.0.1").unwrap(),
                    host_addr: None,
                    port: Some(Port(5432)),
                },
                ..config.clone()
            },
        );

        // Socket path: connection parameters move into the query string.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "socket_path": "/some/socket",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://?host=%2Fsome%2Fsocket&dbname=some-database&user=some-username&sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::SocketPath("/some/socket".into()),
                ..config.clone()
            },
        );

        // `SslRootCert::System` serializes as the bare string "system".
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "ssl_root_cert": "system",
                "url": "postgres://some-username@some-host:5432/some-database?sslmode=verify-full&sslrootcert=system",
                "username": "some-username"
            }),
            &Config {
                ssl_root_cert: Some(SslRootCert::System),
                ..config.clone()
            },
        );

        // `host_addr` becomes a `hostaddr` query parameter.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "192.168.1.100",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@some-host:5432/some-database?hostaddr=192.168.1.100&sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::Network {
                    host: Host::from_str("some-host").unwrap(),
                    host_addr: Some("192.168.1.100".parse().unwrap()),
                    port: Some(Port(5432)),
                },
                ..config.clone()
            },
        );

        // No port: omitted from both the JSON and the URL authority.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@some-host/some-database?sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::Network {
                    host: Host::from_str("some-host").unwrap(),
                    host_addr: None,
                    port: None,
                },
                ..config.clone()
            },
        );

        // `host_addr` without a port.
        assert_config(
            serde_json::json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "10.0.0.1",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@some-host/some-database?hostaddr=10.0.0.1&sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::Network {
                    host: Host::from_str("some-host").unwrap(),
                    host_addr: Some("10.0.0.1".parse().unwrap()),
                    port: None,
                },
                ..config.clone()
            },
        );
    }

    /// IPv6 hosts must be bracketed in the URL authority; IPv4 addresses and host
    /// names must not.
    #[test]
    fn test_ipv6_url_formation() {
        let config_ipv6_loopback = Config {
            application_name: None,
            database: Database::from_str("testdb").unwrap(),
            endpoint: Endpoint::Network {
                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
                host_addr: None,
                port: Some(Port(5432)),
            },
            password: None,
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            username: Username::from_str("postgres").unwrap(),
        };

        let url = config_ipv6_loopback.to_url();
        assert_eq!(
            url.to_string(),
            "postgres://postgres@[::1]:5432/testdb?sslmode=disable",
            "IPv6 loopback address should be bracketed in URL"
        );

        let config_ipv6_fe80 = Config {
            application_name: None,
            database: Database::from_str("testdb").unwrap(),
            endpoint: Endpoint::Network {
                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
                    0xfe80, 0, 0, 0, 0, 0, 0, 1,
                ))),
                host_addr: None,
                port: Some(Port(5432)),
            },
            password: None,
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            username: Username::from_str("postgres").unwrap(),
        };

        let url = config_ipv6_fe80.to_url();
        assert_eq!(
            url.to_string(),
            "postgres://postgres@[fe80::1]:5432/testdb?sslmode=disable",
            "IPv6 link-local address should be bracketed in URL"
        );

        let config_ipv6_full = Config {
            application_name: None,
            database: Database::from_str("testdb").unwrap(),
            endpoint: Endpoint::Network {
                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
                    0x2001, 0x0db8, 0, 0, 0, 0, 0, 1,
                ))),
                host_addr: None,
                port: Some(Port(5432)),
            },
            password: None,
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            username: Username::from_str("postgres").unwrap(),
        };

        let url = config_ipv6_full.to_url();
        assert_eq!(
            url.to_string(),
            "postgres://postgres@[2001:db8::1]:5432/testdb?sslmode=disable",
            "Full IPv6 address should be bracketed in URL"
        );

        let config_ipv4 = Config {
            application_name: None,
            database: Database::from_str("testdb").unwrap(),
            endpoint: Endpoint::Network {
                host: Host::IpAddr(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)),
                host_addr: None,
                port: Some(Port(5432)),
            },
            password: None,
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            username: Username::from_str("postgres").unwrap(),
        };

        let url = config_ipv4.to_url();
        assert_eq!(
            url.to_string(),
            "postgres://postgres@127.0.0.1:5432/testdb?sslmode=disable",
            "IPv4 address should NOT be bracketed in URL"
        );

        let config_hostname = Config {
            application_name: None,
            database: Database::from_str("testdb").unwrap(),
            endpoint: Endpoint::Network {
                host: Host::from_str("localhost").unwrap(),
                host_addr: None,
                port: Some(Port(5432)),
            },
            password: None,
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            username: Username::from_str("postgres").unwrap(),
        };

        let url = config_hostname.to_url();
        assert_eq!(
            url.to_string(),
            "postgres://postgres@localhost:5432/testdb?sslmode=disable",
            "Hostname should NOT be bracketed in URL"
        );
    }
}