Skip to main content

pg_client/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod identifier;
4
5pub use identifier::{Database, Role, User};
6
7#[cfg(feature = "sqlx")]
8pub mod sqlx;
9
10pub mod url;
11
/// Macro to generate `std::str::FromStr` plus helpers for string wrapped newtypes.
///
/// For a tuple struct `$struct(String)` this generates:
///
/// * `FromStr` (error type `String`) that rejects values whose byte length is
///   outside `$min..=$max` or that contain a NUL byte,
/// * `AsRef<str>`,
/// * the associated constants `MIN_LENGTH` / `MAX_LENGTH` and an `as_str` accessor.
macro_rules! from_str_impl {
    ($struct: ident, $min: expr, $max: expr) => {
        impl std::str::FromStr for $struct {
            type Err = String;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                // Bound to locals so a `MIN_LENGTH` of 0 does not turn the
                // comparison below into a constant expression.
                let min_length = Self::MIN_LENGTH;
                let max_length = Self::MAX_LENGTH;
                // Lengths are measured in bytes, not chars.
                let actual = value.len();

                if actual < min_length {
                    Err(format!(
                        "{} byte min length: {min_length} violated, got: {actual}",
                        stringify!($struct)
                    ))
                } else if actual > max_length {
                    Err(format!(
                        "{} byte max length: {max_length} violated, got: {actual}",
                        stringify!($struct)
                    ))
                } else if value.as_bytes().contains(&0) {
                    Err(format!("{} contains NUL byte", stringify!($struct)))
                } else {
                    Ok(Self(value.to_string()))
                }
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                &self.0
            }
        }

        impl $struct {
            /// Minimum accepted length in bytes (inclusive).
            pub const MIN_LENGTH: usize = $min;
            /// Maximum accepted length in bytes (inclusive).
            pub const MAX_LENGTH: usize = $max;

            /// Borrow the wrapped value as a string slice.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }
    };
}
57
58#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
59pub struct HostName(String);
60
61impl HostName {
62    #[must_use]
63    pub fn as_str(&self) -> &str {
64        &self.0
65    }
66}
67
68impl std::str::FromStr for HostName {
69    type Err = &'static str;
70
71    fn from_str(value: &str) -> Result<Self, Self::Err> {
72        if hostname_validator::is_valid(value) {
73            Ok(Self(value.to_string()))
74        } else {
75            Err("invalid host name")
76        }
77    }
78}
79
80impl<'de> serde::Deserialize<'de> for HostName {
81    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82    where
83        D: serde::Deserializer<'de>,
84    {
85        let s = String::deserialize(deserializer)?;
86        s.parse().map_err(serde::de::Error::custom)
87    }
88}
89
90#[derive(Clone, Debug, PartialEq, Eq)]
91pub enum Host {
92    HostName(HostName),
93    IpAddr(std::net::IpAddr),
94}
95
96impl serde::Serialize for Host {
97    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
98        serializer.serialize_str(&self.pg_env_value())
99    }
100}
101
102impl Host {
103    pub(crate) fn pg_env_value(&self) -> String {
104        match self {
105            Self::HostName(value) => value.0.clone(),
106            Self::IpAddr(value) => value.to_string(),
107        }
108    }
109}
110
111impl std::str::FromStr for Host {
112    type Err = &'static str;
113
114    fn from_str(value: &str) -> Result<Self, Self::Err> {
115        match std::net::IpAddr::from_str(value) {
116            Ok(addr) => Ok(Self::IpAddr(addr)),
117            Err(_) => match HostName::from_str(value) {
118                Ok(host_name) => Ok(Self::HostName(host_name)),
119                Err(_) => Err("Not a socket address or FQDN"),
120            },
121        }
122    }
123}
124
125impl From<HostName> for Host {
126    fn from(value: HostName) -> Self {
127        Self::HostName(value)
128    }
129}
130
131impl From<std::net::IpAddr> for Host {
132    fn from(value: std::net::IpAddr) -> Self {
133        Self::IpAddr(value)
134    }
135}
136
/// Explicit IP address for the server, exported as `hostaddr` / `PGHOSTADDR`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(std::net::IpAddr);

impl HostAddr {
    /// Wrap an IP address.
    #[must_use]
    pub const fn new(ip: std::net::IpAddr) -> Self {
        Self(ip)
    }
}

impl From<std::net::IpAddr> for HostAddr {
    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    /// use std::net::IpAddr;
    ///
    /// let ip: IpAddr = "192.168.1.1".parse().unwrap();
    /// let host_addr = HostAddr::from(ip);
    /// assert_eq!(IpAddr::from(host_addr).to_string(), "192.168.1.1");
    /// ```
    fn from(value: std::net::IpAddr) -> Self {
        Self(value)
    }
}

impl From<HostAddr> for std::net::IpAddr {
    fn from(value: HostAddr) -> Self {
        value.0
    }
}

impl From<&HostAddr> for std::net::IpAddr {
    fn from(value: &HostAddr) -> Self {
        value.0
    }
}

impl std::fmt::Display for HostAddr {
    /// Formats exactly like the wrapped [`std::net::IpAddr`], including any
    /// width / fill / alignment flags supplied by the caller.
    ///
    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    ///
    /// let host_addr: HostAddr = "10.0.0.1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "10.0.0.1");
    /// assert_eq!(format!("{host_addr:>10}"), "  10.0.0.1");
    /// ```
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate instead of `write!(formatter, "{}", ..)`: the latter starts a
        // fresh format spec and silently drops the caller's padding flags.
        std::fmt::Display::fmt(&self.0, formatter)
    }
}

impl std::str::FromStr for HostAddr {
    type Err = &'static str;

    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    /// use std::str::FromStr;
    ///
    /// let host_addr = HostAddr::from_str("127.0.0.1").unwrap();
    /// assert_eq!(host_addr.to_string(), "127.0.0.1");
    ///
    /// // Also works with the parse method
    /// let host_addr: HostAddr = "::1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "::1");
    ///
    /// // Invalid IP addresses return an error
    /// assert!(HostAddr::from_str("not-an-ip").is_err());
    /// ```
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        match std::net::IpAddr::from_str(value) {
            Ok(addr) => Ok(Self(addr)),
            Err(_) => Err("invalid IP address"),
        }
    }
}
212
213#[derive(Clone, Debug, PartialEq, Eq)]
214pub enum Endpoint {
215    Network {
216        host: Host,
217        channel_binding: Option<ChannelBinding>,
218        host_addr: Option<HostAddr>,
219        port: Option<Port>,
220    },
221    SocketPath(std::path::PathBuf),
222}
223
224impl serde::Serialize for Endpoint {
225    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
226        use serde::ser::SerializeStruct;
227        match self {
228            Self::Network {
229                host,
230                channel_binding,
231                host_addr,
232                port,
233            } => {
234                let mut state = serializer.serialize_struct("Endpoint", 4)?;
235                state.serialize_field("host", host)?;
236                if let Some(channel_binding) = channel_binding {
237                    state.serialize_field("channel_binding", channel_binding)?;
238                }
239                if let Some(addr) = host_addr {
240                    state.serialize_field("host_addr", &addr.to_string())?;
241                }
242                if let Some(port) = port {
243                    state.serialize_field("port", port)?;
244                }
245                state.end()
246            }
247            Self::SocketPath(path) => {
248                let mut state = serializer.serialize_struct("Endpoint", 1)?;
249                state.serialize_field(
250                    "socket_path",
251                    &path.to_str().expect("socket path contains invalid utf8"),
252                )?;
253                state.end()
254            }
255        }
256    }
257}
258
259#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
260pub struct Port(u16);
261
262impl Port {
263    #[must_use]
264    pub const fn new(port: u16) -> Self {
265        Self(port)
266    }
267
268    fn pg_env_value(self) -> String {
269        self.0.to_string()
270    }
271}
272
273impl std::str::FromStr for Port {
274    type Err = &'static str;
275
276    fn from_str(value: &str) -> Result<Self, Self::Err> {
277        match <u16 as std::str::FromStr>::from_str(value) {
278            Ok(port) => Ok(Port(port)),
279            Err(_) => Err("invalid postgresql port string"),
280        }
281    }
282}
283
284impl From<u16> for Port {
285    fn from(port: u16) -> Self {
286        Self(port)
287    }
288}
289
290impl From<Port> for u16 {
291    fn from(port: Port) -> Self {
292        port.0
293    }
294}
295
296impl From<&Port> for u16 {
297    fn from(port: &Port) -> Self {
298        port.0
299    }
300}
301
302#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
303pub struct ApplicationName(String);
304
305from_str_impl!(ApplicationName, 1, 63);
306
307impl ApplicationName {
308    fn pg_env_value(&self) -> String {
309        self.0.clone()
310    }
311}
312
313impl Database {
314    fn pg_env_value(&self) -> String {
315        self.as_str().to_owned()
316    }
317}
318
319impl Role {
320    fn pg_env_value(&self) -> String {
321        self.as_str().to_owned()
322    }
323}
324
325#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
326pub struct Password(String);
327
328from_str_impl!(Password, 0, 4096);
329
330impl Password {
331    fn pg_env_value(&self) -> String {
332        self.0.clone()
333    }
334}
335
336#[derive(
337    Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
338)]
339#[serde(rename_all = "kebab-case")]
340#[strum(serialize_all = "kebab-case")]
341pub enum SslMode {
342    Allow,
343    Disable,
344    Prefer,
345    Require,
346    VerifyCa,
347    VerifyFull,
348}
349
350impl SslMode {
351    #[must_use]
352    pub fn as_str(&self) -> &'static str {
353        self.into()
354    }
355
356    fn pg_env_value(&self) -> String {
357        self.as_str().to_string()
358    }
359}
360
361#[derive(
362    Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
363)]
364#[serde(rename_all = "kebab-case")]
365#[strum(serialize_all = "kebab-case")]
366pub enum ChannelBinding {
367    Disable,
368    Prefer,
369    Require,
370}
371
372impl ChannelBinding {
373    #[must_use]
374    pub fn as_str(&self) -> &'static str {
375        self.into()
376    }
377
378    fn pg_env_value(&self) -> String {
379        self.as_str().to_string()
380    }
381}
382
383#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
384#[serde(rename_all = "kebab-case")]
385pub enum SslRootCert {
386    File(std::path::PathBuf),
387    System,
388}
389
390impl SslRootCert {
391    pub(crate) fn pg_env_value(&self) -> String {
392        match self {
393            Self::File(path) => path.to_str().unwrap().to_string(),
394            Self::System => "system".to_string(),
395        }
396    }
397}
398
399impl From<std::path::PathBuf> for SslRootCert {
400    fn from(value: std::path::PathBuf) -> Self {
401        Self::File(value)
402    }
403}
404
405#[derive(Clone, Debug, PartialEq, Eq)]
406/// PG connection config with various presentation modes.
407///
408/// Supported:
409///
410/// 1. Env variables via `to_pg_env()`
411/// 2. JSON document via `serde`
412/// 3. sqlx connect options via `to_sqlx_connect_options()`
413/// 4. Individual field access
414pub struct Config {
415    pub application_name: Option<ApplicationName>,
416    pub database: Database,
417    pub endpoint: Endpoint,
418    pub password: Option<Password>,
419    pub ssl_mode: SslMode,
420    pub ssl_root_cert: Option<SslRootCert>,
421    pub user: User,
422}
423
424pub const PGAPPNAME: cmd_proc::EnvVariableName<'static> =
425    cmd_proc::EnvVariableName::from_static_or_panic("PGAPPNAME");
426pub const PGCHANNELBINDING: cmd_proc::EnvVariableName<'static> =
427    cmd_proc::EnvVariableName::from_static_or_panic("PGCHANNELBINDING");
428pub const PGDATABASE: cmd_proc::EnvVariableName<'static> =
429    cmd_proc::EnvVariableName::from_static_or_panic("PGDATABASE");
430pub const PGHOST: cmd_proc::EnvVariableName<'static> =
431    cmd_proc::EnvVariableName::from_static_or_panic("PGHOST");
432pub const PGHOSTADDR: cmd_proc::EnvVariableName<'static> =
433    cmd_proc::EnvVariableName::from_static_or_panic("PGHOSTADDR");
434pub const PGPASSWORD: cmd_proc::EnvVariableName<'static> =
435    cmd_proc::EnvVariableName::from_static_or_panic("PGPASSWORD");
436pub const PGPORT: cmd_proc::EnvVariableName<'static> =
437    cmd_proc::EnvVariableName::from_static_or_panic("PGPORT");
438pub const PGSSLMODE: cmd_proc::EnvVariableName<'static> =
439    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLMODE");
440pub const PGSSLROOTCERT: cmd_proc::EnvVariableName<'static> =
441    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLROOTCERT");
442pub const PGUSER: cmd_proc::EnvVariableName<'static> =
443    cmd_proc::EnvVariableName::from_static_or_panic("PGUSER");
444
445impl serde::Serialize for Config {
446    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
447        use serde::ser::SerializeStruct;
448        let mut state = serializer.serialize_struct("Config", 8)?;
449
450        if let Some(application_name) = &self.application_name {
451            state.serialize_field("application_name", application_name)?;
452        }
453
454        state.serialize_field("database", &self.database)?;
455        state.serialize_field("endpoint", &self.endpoint)?;
456
457        if let Some(password) = &self.password {
458            state.serialize_field("password", password)?;
459        }
460
461        state.serialize_field("ssl_mode", &self.ssl_mode)?;
462
463        if let Some(ssl_root_cert) = &self.ssl_root_cert {
464            state.serialize_field("ssl_root_cert", ssl_root_cert)?;
465        }
466
467        state.serialize_field("user", &self.user)?;
468        state.serialize_field("url", &self.to_url_string())?;
469
470        state.end()
471    }
472}
473
474impl Config {
475    /// Convert to PG connection URL
476    ///
477    /// ```
478    /// # use pg_client::*;
479    /// # use std::str::FromStr;
480    ///
481    /// let config = Config {
482    ///     application_name: None,
483    ///     database: Database::from_static_or_panic("some-database"),
484    ///     endpoint: Endpoint::Network {
485    ///         host: Host::from_str("some-host").unwrap(),
486    ///         channel_binding: None,
487    ///         host_addr: None,
488    ///         port: Some(Port::new(5432)),
489    ///     },
490    ///     password: None,
491    ///     ssl_mode: SslMode::VerifyFull,
492    ///     ssl_root_cert: None,
493    ///     user: User::from_static_or_panic("some-user"),
494    /// };
495    ///
496    /// assert_eq!(
497    ///     config.to_url_string(),
498    ///     "postgres://some-user@some-host:5432/some-database?sslmode=verify-full"
499    /// );
500    ///
501    /// assert_eq!(
502    ///     Config {
503    ///         application_name: Some(ApplicationName::from_str("some-app").unwrap()),
504    ///         password: Some(Password::from_str("some-password").unwrap()),
505    ///         ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
506    ///         ..config.clone()
507    ///     }.to_url_string(),
508    ///     "postgres://some-user:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem"
509    /// );
510    ///
511    /// assert_eq!(
512    ///     Config {
513    ///         endpoint: Endpoint::Network {
514    ///             host: Host::from_str("some-host").unwrap(),
515    ///             channel_binding: None,
516    ///             host_addr: Some("127.0.0.1".parse().unwrap()),
517    ///             port: Some(Port::new(5432)),
518    ///         },
519    ///         ..config.clone()
520    ///     }.to_url_string(),
521    ///     "postgres://some-user@some-host:5432/some-database?hostaddr=127.0.0.1&sslmode=verify-full"
522    /// );
523    ///
524    /// // IPv4 example
525    /// let ipv4_config = Config {
526    ///     application_name: None,
527    ///     database: Database::from_static_or_panic("mydb"),
528    ///     endpoint: Endpoint::Network {
529    ///         host: Host::IpAddr(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
530    ///         channel_binding: None,
531    ///         host_addr: None,
532    ///         port: Some(Port::new(5432)),
533    ///     },
534    ///     password: None,
535    ///     ssl_mode: SslMode::Disable,
536    ///     ssl_root_cert: None,
537    ///     user: User::from_static_or_panic("user"),
538    /// };
539    /// assert_eq!(
540    ///     ipv4_config.to_url_string(),
541    ///     "postgres://user@127.0.0.1:5432/mydb?sslmode=disable"
542    /// );
543    ///
544    /// // IPv6 example (automatically bracketed)
545    /// let ipv6_config = Config {
546    ///     application_name: None,
547    ///     database: Database::from_static_or_panic("mydb"),
548    ///     endpoint: Endpoint::Network {
549    ///         host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
550    ///         channel_binding: None,
551    ///         host_addr: None,
552    ///         port: Some(Port::new(5432)),
553    ///     },
554    ///     password: None,
555    ///     ssl_mode: SslMode::Disable,
556    ///     ssl_root_cert: None,
557    ///     user: User::from_static_or_panic("user"),
558    /// };
559    /// assert_eq!(
560    ///     ipv6_config.to_url_string(),
561    ///     "postgres://user@[::1]:5432/mydb?sslmode=disable"
562    /// );
563    /// ```
564    #[must_use]
565    pub fn to_url(&self) -> ::fluent_uri::Uri<String> {
566        use ::fluent_uri::{
567            Uri,
568            build::Builder,
569            component::{Authority, Scheme},
570            pct_enc::{EStr, EString, encoder},
571        };
572
573        const POSTGRES: &Scheme = Scheme::new_or_panic("postgres");
574
575        fn append_query_pair(query: &mut EString<encoder::Query>, key: &str, value: &str) {
576            if !query.is_empty() {
577                query.push('&');
578            }
579            query.encode_str::<encoder::Data>(key);
580            query.push('=');
581            query.encode_str::<encoder::Data>(value);
582        }
583
584        let mut query = EString::<encoder::Query>::new();
585
586        match &self.endpoint {
587            Endpoint::Network {
588                host,
589                channel_binding,
590                host_addr,
591                port,
592            } => {
593                let mut userinfo = EString::<encoder::Userinfo>::new();
594                userinfo.encode_str::<encoder::Data>(self.user.pg_env_value().as_str());
595                if let Some(password) = &self.password {
596                    userinfo.push(':');
597                    userinfo.encode_str::<encoder::Data>(password.as_str());
598                }
599
600                let mut path = EString::<encoder::Path>::new();
601                path.push('/');
602                path.encode_str::<encoder::Data>(self.database.as_str());
603
604                if let Some(addr) = host_addr {
605                    append_query_pair(&mut query, "hostaddr", &addr.to_string());
606                }
607                if let Some(channel_binding) = channel_binding {
608                    append_query_pair(&mut query, "channel_binding", channel_binding.as_str());
609                }
610                self.append_common_query_params(&mut query, append_query_pair);
611
612                let non_empty_query = if query.is_empty() {
613                    None
614                } else {
615                    Some(query.as_estr())
616                };
617
618                // build() only fails on RFC 3986 structural violations:
619                // scheme and authority are always present, path starts with '/'.
620                Uri::builder()
621                    .scheme(POSTGRES)
622                    .authority_with(|builder| {
623                        let builder = builder.userinfo(&userinfo);
624                        let builder = match host {
625                            Host::IpAddr(addr) => builder.host(*addr),
626                            Host::HostName(name) => {
627                                let mut encoded = EString::<encoder::RegName>::new();
628                                encoded.encode_str::<encoder::Data>(name.as_str());
629                                builder.host(encoded.as_estr())
630                            }
631                        };
632                        match port {
633                            Some(port) => builder.port(port.0),
634                            None => builder.advance(),
635                        }
636                    })
637                    .path(&path)
638                    .optional(Builder::query, non_empty_query)
639                    .build()
640                    .unwrap()
641            }
642            Endpoint::SocketPath(path) => {
643                append_query_pair(
644                    &mut query,
645                    "host",
646                    path.to_str().expect("socket path contains invalid utf8"),
647                );
648                append_query_pair(&mut query, "dbname", self.database.as_str());
649                append_query_pair(&mut query, "user", self.user.pg_env_value().as_str());
650                if let Some(password) = &self.password {
651                    append_query_pair(&mut query, "password", password.as_str());
652                }
653                self.append_common_query_params(&mut query, append_query_pair);
654
655                // build() only fails on RFC 3986 structural violations:
656                // scheme and authority are always present, path is empty.
657                Uri::builder()
658                    .scheme(POSTGRES)
659                    .authority(Authority::EMPTY)
660                    .path(EStr::EMPTY)
661                    .query(&query)
662                    .build()
663                    .unwrap()
664            }
665        }
666    }
667
668    /// Convert to PG connection URL string
669    #[must_use]
670    pub fn to_url_string(&self) -> String {
671        self.to_url().into_string()
672    }
673
674    fn append_common_query_params(
675        &self,
676        query: &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
677        append_query_pair: fn(
678            &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
679            &str,
680            &str,
681        ),
682    ) {
683        if let Some(application_name) = &self.application_name {
684            append_query_pair(query, "application_name", application_name.as_str());
685        }
686        append_query_pair(query, "sslmode", &self.ssl_mode.pg_env_value());
687        if let Some(ssl_root_cert) = &self.ssl_root_cert {
688            append_query_pair(query, "sslrootcert", &ssl_root_cert.pg_env_value());
689        }
690    }
691
692    /// Convert to PG environment variable names
693    ///
694    /// ```
695    /// # use pg_client::*;
696    /// # use std::collections::BTreeMap;
697    ///
698    /// let config = Config {
699    ///     application_name: None,
700    ///     database: "some-database".parse().unwrap(),
701    ///     endpoint: Endpoint::Network {
702    ///         host: "some-host".parse().unwrap(),
703    ///         channel_binding: None,
704    ///         host_addr: None,
705    ///         port: Some(Port::new(5432)),
706    ///     },
707    ///     password: None,
708    ///     ssl_mode: SslMode::VerifyFull,
709    ///     ssl_root_cert: None,
710    ///     user: "some-user".parse().unwrap(),
711    /// };
712    ///
713    /// let expected = BTreeMap::from([
714    ///     (PGDATABASE, "some-database".to_string()),
715    ///     (PGHOST, "some-host".to_string()),
716    ///     (PGPORT, "5432".to_string()),
717    ///     (PGSSLMODE, "verify-full".to_string()),
718    ///     (PGUSER, "some-user".to_string()),
719    /// ]);
720    ///
721    /// assert_eq!(expected, config.to_pg_env());
722    ///
723    /// let config_with_optionals = Config {
724    ///     application_name: Some("some-app".parse().unwrap()),
725    ///     endpoint: Endpoint::Network {
726    ///         host: "some-host".parse().unwrap(),
727    ///         channel_binding: None,
728    ///         host_addr: Some("127.0.0.1".parse().unwrap()),
729    ///         port: Some(Port::new(5432)),
730    ///     },
731    ///     password: Some("some-password".parse().unwrap()),
732    ///     ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
733    ///     ..config
734    /// };
735    ///
736    /// let expected = BTreeMap::from([
737    ///     (PGAPPNAME, "some-app".to_string()),
738    ///     (PGDATABASE, "some-database".to_string()),
739    ///     (PGHOST, "some-host".to_string()),
740    ///     (PGHOSTADDR, "127.0.0.1".to_string()),
741    ///     (PGPASSWORD, "some-password".to_string()),
742    ///     (PGPORT, "5432".to_string()),
743    ///     (PGSSLMODE, "verify-full".to_string()),
744    ///     (PGSSLROOTCERT, "/some.pem".to_string()),
745    ///     (PGUSER, "some-user".to_string()),
746    /// ]);
747    ///
748    /// assert_eq!(expected, config_with_optionals.to_pg_env());
749    /// ```
750    #[must_use]
751    pub fn to_pg_env(
752        &self,
753    ) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
754        let mut map = std::collections::BTreeMap::new();
755
756        match &self.endpoint {
757            Endpoint::Network {
758                host,
759                channel_binding,
760                host_addr,
761                port,
762            } => {
763                map.insert(PGHOST.clone(), host.pg_env_value());
764                if let Some(port) = port {
765                    map.insert(PGPORT.clone(), port.pg_env_value());
766                }
767                if let Some(channel_binding) = channel_binding {
768                    map.insert(PGCHANNELBINDING.clone(), channel_binding.pg_env_value());
769                }
770                if let Some(addr) = host_addr {
771                    map.insert(PGHOSTADDR.clone(), addr.to_string());
772                }
773            }
774            Endpoint::SocketPath(path) => {
775                map.insert(
776                    PGHOST.clone(),
777                    path.to_str()
778                        .expect("socket path contains invalid utf8")
779                        .to_string(),
780                );
781            }
782        }
783
784        map.insert(PGSSLMODE.clone(), self.ssl_mode.pg_env_value());
785        map.insert(PGUSER.clone(), self.user.pg_env_value());
786        map.insert(PGDATABASE.clone(), self.database.pg_env_value());
787
788        if let Some(application_name) = &self.application_name {
789            map.insert(PGAPPNAME.clone(), application_name.pg_env_value());
790        }
791
792        if let Some(password) = &self.password {
793            map.insert(PGPASSWORD.clone(), password.pg_env_value());
794        }
795
796        if let Some(ssl_root_cert) = &self.ssl_root_cert {
797            map.insert(PGSSLROOTCERT.clone(), ssl_root_cert.pg_env_value());
798        }
799
800        map
801    }
802
803    #[must_use]
804    pub fn endpoint(self, endpoint: Endpoint) -> Self {
805        Self { endpoint, ..self }
806    }
807
808    /// Parse a PostgreSQL connection URL string into a Config.
809    ///
810    /// When the URL does not specify `sslmode`, it defaults to `verify-full`
811    /// to ensure secure connections by default.
812    ///
813    /// See [`url::parse`] for full documentation.
814    pub fn from_str_url(url: &str) -> Result<Self, crate::url::ParseError> {
815        crate::url::parse(url)
816    }
817}
818
819#[cfg(test)]
820mod test {
821    use super::*;
822    use pretty_assertions::assert_eq;
823    use std::str::FromStr;
824
825    const TEST_DATABASE: Database = Database::from_static_or_panic("some-database");
826    const TEST_USER: User = User::from_static_or_panic("some-user");
827
828    fn assert_config(expected: serde_json::Value, config: &Config) {
829        assert_eq!(expected, serde_json::to_value(config).unwrap());
830    }
831
832    fn repeat(char: char, len: usize) -> String {
833        std::iter::repeat_n(char, len).collect()
834    }
835
836    #[test]
837    fn application_name_lt_min_length() {
838        let value = String::new();
839
840        let err = ApplicationName::from_str(&value).expect_err("expected min length failure");
841
842        assert_eq!(err, "ApplicationName byte min length: 1 violated, got: 0");
843    }
844
845    #[test]
846    fn application_name_eq_min_length() {
847        let value = repeat('a', 1);
848
849        let application_name =
850            ApplicationName::from_str(&value).expect("expected valid min length value");
851
852        assert_eq!(application_name, ApplicationName(value));
853    }
854
855    #[test]
856    fn application_name_gt_min_length() {
857        let value = repeat('a', 2);
858
859        let application_name =
860            ApplicationName::from_str(&value).expect("expected valid value greater than min");
861
862        assert_eq!(application_name, ApplicationName(value));
863    }
864
865    #[test]
866    fn application_name_lt_max_length() {
867        let value = repeat('a', 62);
868
869        let application_name =
870            ApplicationName::from_str(&value).expect("expected valid value less than max");
871
872        assert_eq!(application_name, ApplicationName(value));
873    }
874
875    #[test]
876    fn application_name_eq_max_length() {
877        let value = repeat('a', 63);
878
879        let application_name =
880            ApplicationName::from_str(&value).expect("expected valid value equal to max");
881
882        assert_eq!(application_name, ApplicationName(value));
883    }
884
885    #[test]
886    fn application_name_gt_max_length() {
887        let value = repeat('a', 64);
888
889        let err = ApplicationName::from_str(&value).expect_err("expected max length failure");
890
891        assert_eq!(err, "ApplicationName byte max length: 63 violated, got: 64");
892    }
893
894    #[test]
895    fn application_name_contains_nul() {
896        let value = String::from('\0');
897
898        let err = ApplicationName::from_str(&value).expect_err("expected NUL failure");
899
900        assert_eq!(err, "ApplicationName contains NUL byte");
901    }
902
903    #[test]
904    fn password_eq_min_length() {
905        let value = String::new();
906
907        let password = Password::from_str(&value).expect("expected valid min length value");
908
909        assert_eq!(password, Password(value));
910    }
911
912    #[test]
913    fn password_gt_min_length() {
914        let value = repeat('p', 1);
915
916        let password = Password::from_str(&value).expect("expected valid value greater than min");
917
918        assert_eq!(password, Password(value));
919    }
920
921    #[test]
922    fn password_lt_max_length() {
923        let value = repeat('p', 4095);
924
925        let password = Password::from_str(&value).expect("expected valid value less than max");
926
927        assert_eq!(password, Password(value));
928    }
929
930    #[test]
931    fn password_eq_max_length() {
932        let value = repeat('p', 4096);
933
934        let password = Password::from_str(&value).expect("expected valid value equal to max");
935
936        assert_eq!(password, Password(value));
937    }
938
939    #[test]
940    fn password_gt_max_length() {
941        let value = repeat('p', 4097);
942
943        let err = Password::from_str(&value).expect_err("expected max length failure");
944
945        assert_eq!(err, "Password byte max length: 4096 violated, got: 4097");
946    }
947
948    #[test]
949    fn password_contains_nul() {
950        let value = String::from('\0');
951
952        let err = Password::from_str(&value).expect_err("expected NUL failure");
953
954        assert_eq!(err, "Password contains NUL byte");
955    }
956
957    #[test]
958    fn test_json() {
959        let config = Config {
960            application_name: None,
961            database: TEST_DATABASE,
962            endpoint: Endpoint::Network {
963                host: Host::from_str("some-host").unwrap(),
964                channel_binding: None,
965                host_addr: None,
966                port: Some(Port::new(5432)),
967            },
968            password: None,
969            ssl_mode: SslMode::VerifyFull,
970            ssl_root_cert: None,
971            user: TEST_USER,
972        };
973
974        assert_config(
975            serde_json::json!({
976                "database": "some-database",
977                "endpoint": {
978                    "host": "some-host",
979                    "port": 5432,
980                },
981                "ssl_mode": "verify-full",
982                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full",
983                "user": "some-user",
984            }),
985            &config,
986        );
987
988        assert_config(
989            serde_json::json!({
990                "application_name": "some-app",
991                "database": "some-database",
992                "endpoint": {
993                    "host": "some-host",
994                    "port": 5432,
995                },
996                "password": "some-password",
997                "ssl_mode": "verify-full",
998                "ssl_root_cert": {
999                    "file": "/some.pem"
1000                },
1001                "url": "postgres://some-user:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem",
1002                "user": "some-user"
1003            }),
1004            &Config {
1005                application_name: Some(ApplicationName::from_str("some-app").unwrap()),
1006                password: Some(Password::from_str("some-password").unwrap()),
1007                ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
1008                ..config.clone()
1009            },
1010        );
1011
1012        assert_config(
1013            serde_json::json!({
1014                "database": "some-database",
1015                "endpoint": {
1016                    "host": "127.0.0.1",
1017                    "port": 5432,
1018                },
1019                "ssl_mode": "verify-full",
1020                "url": "postgres://some-user@127.0.0.1:5432/some-database?sslmode=verify-full",
1021                "user": "some-user"
1022            }),
1023            &Config {
1024                endpoint: Endpoint::Network {
1025                    host: Host::from_str("127.0.0.1").unwrap(),
1026                    channel_binding: None,
1027                    host_addr: None,
1028                    port: Some(Port::new(5432)),
1029                },
1030                ..config.clone()
1031            },
1032        );
1033
1034        assert_config(
1035            serde_json::json!({
1036                "database": "some-database",
1037                "endpoint": {
1038                    "socket_path": "/some/socket",
1039                },
1040                "ssl_mode": "verify-full",
1041                "url": "postgres://?host=%2Fsome%2Fsocket&dbname=some-database&user=some-user&sslmode=verify-full",
1042                "user": "some-user"
1043            }),
1044            &Config {
1045                endpoint: Endpoint::SocketPath("/some/socket".into()),
1046                ..config.clone()
1047            },
1048        );
1049
1050        assert_config(
1051            serde_json::json!({
1052                "database": "some-database",
1053                "endpoint": {
1054                    "host": "some-host",
1055                    "port": 5432,
1056                },
1057                "ssl_mode": "verify-full",
1058                "ssl_root_cert": "system",
1059                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full&sslrootcert=system",
1060                "user": "some-user"
1061            }),
1062            &Config {
1063                ssl_root_cert: Some(SslRootCert::System),
1064                ..config.clone()
1065            },
1066        );
1067
1068        assert_config(
1069            serde_json::json!({
1070                "database": "some-database",
1071                "endpoint": {
1072                    "host": "some-host",
1073                    "host_addr": "192.168.1.100",
1074                    "port": 5432,
1075                },
1076                "ssl_mode": "verify-full",
1077                "url": "postgres://some-user@some-host:5432/some-database?hostaddr=192.168.1.100&sslmode=verify-full",
1078                "user": "some-user"
1079            }),
1080            &Config {
1081                endpoint: Endpoint::Network {
1082                    host: Host::from_str("some-host").unwrap(),
1083                    channel_binding: None,
1084                    host_addr: Some("192.168.1.100".parse().unwrap()),
1085                    port: Some(Port::new(5432)),
1086                },
1087                ..config.clone()
1088            },
1089        );
1090
1091        // Test Network endpoint without port (should use default)
1092        assert_config(
1093            serde_json::json!({
1094                "database": "some-database",
1095                "endpoint": {
1096                    "host": "some-host",
1097                },
1098                "ssl_mode": "verify-full",
1099                "url": "postgres://some-user@some-host/some-database?sslmode=verify-full",
1100                "user": "some-user"
1101            }),
1102            &Config {
1103                endpoint: Endpoint::Network {
1104                    host: Host::from_str("some-host").unwrap(),
1105                    channel_binding: None,
1106                    host_addr: None,
1107                    port: None,
1108                },
1109                ..config.clone()
1110            },
1111        );
1112
1113        // Test Network endpoint with host_addr but without port
1114        assert_config(
1115            serde_json::json!({
1116                "database": "some-database",
1117                "endpoint": {
1118                    "host": "some-host",
1119                    "host_addr": "10.0.0.1",
1120                },
1121                "ssl_mode": "verify-full",
1122                "url": "postgres://some-user@some-host/some-database?hostaddr=10.0.0.1&sslmode=verify-full",
1123                "user": "some-user"
1124            }),
1125            &Config {
1126                endpoint: Endpoint::Network {
1127                    host: Host::from_str("some-host").unwrap(),
1128                    channel_binding: None,
1129                    host_addr: Some("10.0.0.1".parse().unwrap()),
1130                    port: None,
1131                },
1132                ..config.clone()
1133            },
1134        );
1135    }
1136
1137    #[test]
1138    fn test_ipv6_url_formation() {
1139        // Test IPv6 loopback address
1140        let config_ipv6_loopback = Config {
1141            application_name: None,
1142            database: TEST_DATABASE,
1143            endpoint: Endpoint::Network {
1144                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
1145                channel_binding: None,
1146                host_addr: None,
1147                port: Some(Port::new(5432)),
1148            },
1149            password: None,
1150            ssl_mode: SslMode::Disable,
1151            ssl_root_cert: None,
1152            user: User::POSTGRES,
1153        };
1154
1155        assert_eq!(
1156            config_ipv6_loopback.to_url_string(),
1157            "postgres://postgres@[::1]:5432/some-database?sslmode=disable",
1158            "IPv6 loopback address should be bracketed in URL"
1159        );
1160
1161        // Test fe80 link-local IPv6 address
1162        let config_ipv6_fe80 = Config {
1163            application_name: None,
1164            database: TEST_DATABASE,
1165            endpoint: Endpoint::Network {
1166                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
1167                    0xfe80, 0, 0, 0, 0, 0, 0, 1,
1168                ))),
1169                channel_binding: None,
1170                host_addr: None,
1171                port: Some(Port::new(5432)),
1172            },
1173            password: None,
1174            ssl_mode: SslMode::Disable,
1175            ssl_root_cert: None,
1176            user: User::POSTGRES,
1177        };
1178
1179        assert_eq!(
1180            config_ipv6_fe80.to_url_string(),
1181            "postgres://postgres@[fe80::1]:5432/some-database?sslmode=disable",
1182            "IPv6 link-local address should be bracketed in URL"
1183        );
1184
1185        // Test full IPv6 address
1186        let config_ipv6_full = Config {
1187            application_name: None,
1188            database: TEST_DATABASE,
1189            endpoint: Endpoint::Network {
1190                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
1191                    0x2001, 0x0db8, 0, 0, 0, 0, 0, 1,
1192                ))),
1193                channel_binding: None,
1194                host_addr: None,
1195                port: Some(Port::new(5432)),
1196            },
1197            password: None,
1198            ssl_mode: SslMode::Disable,
1199            ssl_root_cert: None,
1200            user: User::POSTGRES,
1201        };
1202
1203        assert_eq!(
1204            config_ipv6_full.to_url_string(),
1205            "postgres://postgres@[2001:db8::1]:5432/some-database?sslmode=disable",
1206            "Full IPv6 address should be bracketed in URL"
1207        );
1208
1209        // Test IPv4 address (should NOT be bracketed)
1210        let config_ipv4 = Config {
1211            application_name: None,
1212            database: TEST_DATABASE,
1213            endpoint: Endpoint::Network {
1214                host: Host::IpAddr(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)),
1215                channel_binding: None,
1216                host_addr: None,
1217                port: Some(Port::new(5432)),
1218            },
1219            password: None,
1220            ssl_mode: SslMode::Disable,
1221            ssl_root_cert: None,
1222            user: User::POSTGRES,
1223        };
1224
1225        assert_eq!(
1226            config_ipv4.to_url_string(),
1227            "postgres://postgres@127.0.0.1:5432/some-database?sslmode=disable",
1228            "IPv4 address should NOT be bracketed in URL"
1229        );
1230
1231        // Test hostname (should NOT be bracketed)
1232        let config_hostname = Config {
1233            application_name: None,
1234            database: TEST_DATABASE,
1235            endpoint: Endpoint::Network {
1236                host: Host::from_str("localhost").unwrap(),
1237                channel_binding: None,
1238                host_addr: None,
1239                port: Some(Port::new(5432)),
1240            },
1241            password: None,
1242            ssl_mode: SslMode::Disable,
1243            ssl_root_cert: None,
1244            user: User::POSTGRES,
1245        };
1246
1247        assert_eq!(
1248            config_hostname.to_url_string(),
1249            "postgres://postgres@localhost:5432/some-database?sslmode=disable",
1250            "Hostname should NOT be bracketed in URL"
1251        );
1252    }
1253}