Skip to main content

pg_client/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod identifier;
4pub mod pg_dump;
5
6pub use identifier::{Database, QualifiedTable, Role, User};
7pub use pg_dump::{PgSchemaDump, RestrictKey};
8
9#[cfg(feature = "sqlx")]
10pub mod sqlx;
11
12pub mod url;
13
/// Generates, for a `String` newtype `$struct`:
///
/// - `MIN_LENGTH` / `MAX_LENGTH` associated constants (byte bounds `$min` / `$max`),
/// - an inherent `as_str` accessor and an `AsRef<str>` impl,
/// - a `std::str::FromStr` impl whose error is a human readable `String`.
///
/// Parsing checks, in this order: length below `$min` bytes, length above `$max` bytes,
/// presence of a NUL byte. The first failing check determines the error message.
macro_rules! from_str_impl {
    ($struct: ident, $min: expr, $max: expr) => {
        impl $struct {
            pub const MIN_LENGTH: usize = $min;
            pub const MAX_LENGTH: usize = $max;

            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl std::str::FromStr for $struct {
            type Err = String;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                let type_name = stringify!($struct);
                let (min, max) = (Self::MIN_LENGTH, Self::MAX_LENGTH);

                // Lengths are measured in bytes (`str::len`), not characters.
                match value.len() {
                    len if len < min => Err(format!(
                        "{type_name} byte min length: {min} violated, got: {len}"
                    )),
                    len if len > max => Err(format!(
                        "{type_name} byte max length: {max} violated, got: {len}"
                    )),
                    // In UTF-8 a zero byte can only encode U+0000, so a char search
                    // finds exactly the NUL bytes.
                    _ if value.contains('\0') => Err(format!("{type_name} contains NUL byte")),
                    _ => Ok(Self(String::from(value))),
                }
            }
        }
    };
}
59
60#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
61pub struct HostName(String);
62
63impl HostName {
64    #[must_use]
65    pub fn as_str(&self) -> &str {
66        &self.0
67    }
68}
69
70impl std::str::FromStr for HostName {
71    type Err = &'static str;
72
73    fn from_str(value: &str) -> Result<Self, Self::Err> {
74        if hostname_validator::is_valid(value) {
75            Ok(Self(value.to_string()))
76        } else {
77            Err("invalid host name")
78        }
79    }
80}
81
82impl<'de> serde::Deserialize<'de> for HostName {
83    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
84    where
85        D: serde::Deserializer<'de>,
86    {
87        let s = String::deserialize(deserializer)?;
88        s.parse().map_err(serde::de::Error::custom)
89    }
90}
91
92#[derive(Clone, Debug, PartialEq, Eq)]
93pub enum Host {
94    HostName(HostName),
95    IpAddr(std::net::IpAddr),
96}
97
98impl serde::Serialize for Host {
99    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
100        serializer.serialize_str(&self.pg_env_value())
101    }
102}
103
104impl Host {
105    pub(crate) fn pg_env_value(&self) -> String {
106        match self {
107            Self::HostName(value) => value.0.clone(),
108            Self::IpAddr(value) => value.to_string(),
109        }
110    }
111}
112
113impl std::str::FromStr for Host {
114    type Err = &'static str;
115
116    fn from_str(value: &str) -> Result<Self, Self::Err> {
117        match std::net::IpAddr::from_str(value) {
118            Ok(addr) => Ok(Self::IpAddr(addr)),
119            Err(_) => match HostName::from_str(value) {
120                Ok(host_name) => Ok(Self::HostName(host_name)),
121                Err(_) => Err("Not a socket address or FQDN"),
122            },
123        }
124    }
125}
126
127impl From<HostName> for Host {
128    fn from(value: HostName) -> Self {
129        Self::HostName(value)
130    }
131}
132
133impl From<std::net::IpAddr> for Host {
134    fn from(value: std::net::IpAddr) -> Self {
135        Self::IpAddr(value)
136    }
137}
138
/// IP address supplied alongside a [`Host`] (emitted as `PGHOSTADDR` / `hostaddr`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(std::net::IpAddr);

impl HostAddr {
    /// Wrap an already-parsed IP address.
    #[must_use]
    pub const fn new(ip: std::net::IpAddr) -> Self {
        Self(ip)
    }
}

impl From<std::net::IpAddr> for HostAddr {
    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    /// use std::net::IpAddr;
    ///
    /// let ip: IpAddr = "192.168.1.1".parse().unwrap();
    /// let host_addr = HostAddr::from(ip);
    /// assert_eq!(IpAddr::from(host_addr).to_string(), "192.168.1.1");
    /// ```
    fn from(ip: std::net::IpAddr) -> Self {
        Self::new(ip)
    }
}

impl From<HostAddr> for std::net::IpAddr {
    fn from(host_addr: HostAddr) -> Self {
        let HostAddr(ip) = host_addr;
        ip
    }
}

impl From<&HostAddr> for std::net::IpAddr {
    fn from(host_addr: &HostAddr) -> Self {
        let HostAddr(ip) = host_addr;
        *ip
    }
}

impl std::fmt::Display for HostAddr {
    /// Writes the wrapped address in its standard textual form. Width and fill flags on
    /// the outer formatter are not forwarded to the address.
    ///
    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    ///
    /// let host_addr: HostAddr = "10.0.0.1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "10.0.0.1");
    /// ```
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(ip) = self;
        write!(formatter, "{ip}")
    }
}

impl std::str::FromStr for HostAddr {
    type Err = &'static str;

    /// Parse an IPv4 or IPv6 literal. Host names are rejected.
    ///
    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    /// use std::str::FromStr;
    ///
    /// let host_addr = HostAddr::from_str("127.0.0.1").unwrap();
    /// assert_eq!(host_addr.to_string(), "127.0.0.1");
    ///
    /// // Also works with the parse method
    /// let host_addr: HostAddr = "::1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "::1");
    ///
    /// // Invalid IP addresses return an error
    /// assert!(HostAddr::from_str("not-an-ip").is_err());
    /// ```
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        value
            .parse::<std::net::IpAddr>()
            .map(Self::new)
            .map_err(|_| "invalid IP address")
    }
}
214
215#[derive(Clone, Debug, PartialEq, Eq)]
216pub enum Endpoint {
217    Network {
218        host: Host,
219        channel_binding: Option<ChannelBinding>,
220        host_addr: Option<HostAddr>,
221        port: Option<Port>,
222    },
223    SocketPath(std::path::PathBuf),
224}
225
226impl serde::Serialize for Endpoint {
227    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
228        use serde::ser::SerializeStruct;
229        match self {
230            Self::Network {
231                host,
232                channel_binding,
233                host_addr,
234                port,
235            } => {
236                let mut state = serializer.serialize_struct("Endpoint", 4)?;
237                state.serialize_field("host", host)?;
238                if let Some(channel_binding) = channel_binding {
239                    state.serialize_field("channel_binding", channel_binding)?;
240                }
241                if let Some(addr) = host_addr {
242                    state.serialize_field("host_addr", &addr.to_string())?;
243                }
244                if let Some(port) = port {
245                    state.serialize_field("port", port)?;
246                }
247                state.end()
248            }
249            Self::SocketPath(path) => {
250                let mut state = serializer.serialize_struct("Endpoint", 1)?;
251                state.serialize_field(
252                    "socket_path",
253                    &path.to_str().expect("socket path contains invalid utf8"),
254                )?;
255                state.end()
256            }
257        }
258    }
259}
260
261#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
262pub struct Port(u16);
263
264impl Port {
265    #[must_use]
266    pub const fn new(port: u16) -> Self {
267        Self(port)
268    }
269
270    fn pg_env_value(self) -> String {
271        self.0.to_string()
272    }
273}
274
275impl std::str::FromStr for Port {
276    type Err = &'static str;
277
278    fn from_str(value: &str) -> Result<Self, Self::Err> {
279        match <u16 as std::str::FromStr>::from_str(value) {
280            Ok(port) => Ok(Port(port)),
281            Err(_) => Err("invalid postgresql port string"),
282        }
283    }
284}
285
286impl From<u16> for Port {
287    fn from(port: u16) -> Self {
288        Self(port)
289    }
290}
291
292impl From<Port> for u16 {
293    fn from(port: Port) -> Self {
294        port.0
295    }
296}
297
298impl From<&Port> for u16 {
299    fn from(port: &Port) -> Self {
300        port.0
301    }
302}
303
304#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
305pub struct ApplicationName(String);
306
307from_str_impl!(ApplicationName, 1, 63);
308
309impl ApplicationName {
310    fn pg_env_value(&self) -> String {
311        self.0.clone()
312    }
313}
314
315impl Database {
316    fn pg_env_value(&self) -> String {
317        self.as_str().to_owned()
318    }
319}
320
321impl Role {
322    fn pg_env_value(&self) -> String {
323        self.as_str().to_owned()
324    }
325}
326
327#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
328pub struct Password(String);
329
330from_str_impl!(Password, 0, 4096);
331
332impl Password {
333    fn pg_env_value(&self) -> String {
334        self.0.clone()
335    }
336}
337
338#[derive(
339    Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
340)]
341#[serde(rename_all = "kebab-case")]
342#[strum(serialize_all = "kebab-case")]
343pub enum SslMode {
344    Allow,
345    Disable,
346    Prefer,
347    Require,
348    VerifyCa,
349    VerifyFull,
350}
351
352impl SslMode {
353    #[must_use]
354    pub fn as_str(&self) -> &'static str {
355        self.into()
356    }
357
358    fn pg_env_value(&self) -> String {
359        self.as_str().to_string()
360    }
361}
362
363#[derive(
364    Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
365)]
366#[serde(rename_all = "kebab-case")]
367#[strum(serialize_all = "kebab-case")]
368pub enum ChannelBinding {
369    Disable,
370    Prefer,
371    Require,
372}
373
374impl ChannelBinding {
375    #[must_use]
376    pub fn as_str(&self) -> &'static str {
377        self.into()
378    }
379
380    fn pg_env_value(&self) -> String {
381        self.as_str().to_string()
382    }
383}
384
385#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
386#[serde(rename_all = "kebab-case")]
387pub enum SslRootCert {
388    File(std::path::PathBuf),
389    System,
390}
391
392impl SslRootCert {
393    pub(crate) fn pg_env_value(&self) -> String {
394        match self {
395            Self::File(path) => path.to_str().unwrap().to_string(),
396            Self::System => "system".to_string(),
397        }
398    }
399}
400
401impl From<std::path::PathBuf> for SslRootCert {
402    fn from(value: std::path::PathBuf) -> Self {
403        Self::File(value)
404    }
405}
406
#[derive(Clone, Debug, PartialEq, Eq)]
/// PG connection config with various presentation modes.
///
/// Supported:
///
/// 1. Env variables via `to_pg_env()`
/// 2. JSON document via `serde`
/// 3. sqlx connect options via `to_sqlx_connect_options()` (not defined in this file;
///    presumably provided by the `sqlx` module, which is only compiled with the `sqlx`
///    feature)
/// 4. Individual field access
/// 5. Connection URL via `to_url()` / `to_url_string()`
pub struct Config {
    /// Optional client application name (`PGAPPNAME` / `application_name` parameter).
    pub application_name: Option<ApplicationName>,
    /// Database to connect to (`PGDATABASE`; URL path, or `dbname` for socket endpoints).
    pub database: Database,
    /// Network host/port details or a socket path.
    pub endpoint: Endpoint,
    /// Optional password (`PGPASSWORD`; URL userinfo, or `password` for socket endpoints).
    pub password: Option<Password>,
    /// SSL mode; always emitted (`PGSSLMODE` / `sslmode` parameter).
    pub ssl_mode: SslMode,
    /// Optional root certificate source (`PGSSLROOTCERT` / `sslrootcert` parameter).
    pub ssl_root_cert: Option<SslRootCert>,
    /// User to connect as (`PGUSER`; URL userinfo, or `user` for socket endpoints).
    pub user: User,
}
425
// Environment variable names emitted by `Config::to_pg_env`. `from_static_or_panic`
// presumably panics on an invalid name; inside a `const` initializer such a panic surfaces
// as a compile-time error.

/// `PGAPPNAME`: set from [`Config::application_name`] when present.
pub const PGAPPNAME: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGAPPNAME");
/// `PGCHANNELBINDING`: set from the `channel_binding` of an [`Endpoint::Network`] when present.
pub const PGCHANNELBINDING: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGCHANNELBINDING");
/// `PGDATABASE`: always set from [`Config::database`].
pub const PGDATABASE: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATABASE");
/// `PGHOST`: always set, either from the network host or from the [`Endpoint::SocketPath`] path.
pub const PGHOST: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGHOST");
/// `PGHOSTADDR`: set from the `host_addr` of an [`Endpoint::Network`] when present.
pub const PGHOSTADDR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGHOSTADDR");
/// `PGPASSWORD`: set from [`Config::password`] when present.
pub const PGPASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGPASSWORD");
/// `PGPORT`: set from the `port` of an [`Endpoint::Network`] when present.
pub const PGPORT: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGPORT");
/// `PGSSLMODE`: always set from [`Config::ssl_mode`].
pub const PGSSLMODE: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLMODE");
/// `PGSSLROOTCERT`: set from [`Config::ssl_root_cert`] when present.
pub const PGSSLROOTCERT: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGSSLROOTCERT");
/// `PGUSER`: always set from [`Config::user`].
pub const PGUSER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGUSER");
446
447impl serde::Serialize for Config {
448    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
449        use serde::ser::SerializeStruct;
450        let mut state = serializer.serialize_struct("Config", 8)?;
451
452        if let Some(application_name) = &self.application_name {
453            state.serialize_field("application_name", application_name)?;
454        }
455
456        state.serialize_field("database", &self.database)?;
457        state.serialize_field("endpoint", &self.endpoint)?;
458
459        if let Some(password) = &self.password {
460            state.serialize_field("password", password)?;
461        }
462
463        state.serialize_field("ssl_mode", &self.ssl_mode)?;
464
465        if let Some(ssl_root_cert) = &self.ssl_root_cert {
466            state.serialize_field("ssl_root_cert", ssl_root_cert)?;
467        }
468
469        state.serialize_field("user", &self.user)?;
470        state.serialize_field("url", &self.to_url_string())?;
471
472        state.end()
473    }
474}
475
impl Config {
    /// Convert to PG connection URL
    ///
    /// The URL shape depends on [`Endpoint`]:
    ///
    /// - [`Endpoint::Network`]: `postgres://user[:password]@host[:port]/database`. The query
    ///   string holds `hostaddr` and `channel_binding` (when set), then the common parameters.
    /// - [`Endpoint::SocketPath`]: `postgres://` with an empty authority and empty path. The
    ///   socket path, database, user and password go in the query as `host`, `dbname`,
    ///   `user` and `password`, followed by the common parameters.
    ///
    /// The common parameters are `application_name` (when set), `sslmode` (always) and
    /// `sslrootcert` (when set), in that order. Keys and values are percent-encoded.
    ///
    /// # Panics
    ///
    /// Panics if an [`Endpoint::SocketPath`] path or an [`SslRootCert::File`] path is not
    /// valid UTF-8.
    ///
    /// # Examples
    ///
    /// ```
    /// # use pg_client::*;
    /// # use std::str::FromStr;
    ///
    /// let config = Config {
    ///     application_name: None,
    ///     database: Database::from_static_or_panic("some-database"),
    ///     endpoint: Endpoint::Network {
    ///         host: Host::from_str("some-host").unwrap(),
    ///         channel_binding: None,
    ///         host_addr: None,
    ///         port: Some(Port::new(5432)),
    ///     },
    ///     password: None,
    ///     ssl_mode: SslMode::VerifyFull,
    ///     ssl_root_cert: None,
    ///     user: User::from_static_or_panic("some-user"),
    /// };
    ///
    /// assert_eq!(
    ///     config.to_url_string(),
    ///     "postgres://some-user@some-host:5432/some-database?sslmode=verify-full"
    /// );
    ///
    /// assert_eq!(
    ///     Config {
    ///         application_name: Some(ApplicationName::from_str("some-app").unwrap()),
    ///         password: Some(Password::from_str("some-password").unwrap()),
    ///         ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
    ///         ..config.clone()
    ///     }.to_url_string(),
    ///     "postgres://some-user:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem"
    /// );
    ///
    /// assert_eq!(
    ///     Config {
    ///         endpoint: Endpoint::Network {
    ///             host: Host::from_str("some-host").unwrap(),
    ///             channel_binding: None,
    ///             host_addr: Some("127.0.0.1".parse().unwrap()),
    ///             port: Some(Port::new(5432)),
    ///         },
    ///         ..config.clone()
    ///     }.to_url_string(),
    ///     "postgres://some-user@some-host:5432/some-database?hostaddr=127.0.0.1&sslmode=verify-full"
    /// );
    ///
    /// // IPv4 example
    /// let ipv4_config = Config {
    ///     application_name: None,
    ///     database: Database::from_static_or_panic("mydb"),
    ///     endpoint: Endpoint::Network {
    ///         host: Host::IpAddr(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
    ///         channel_binding: None,
    ///         host_addr: None,
    ///         port: Some(Port::new(5432)),
    ///     },
    ///     password: None,
    ///     ssl_mode: SslMode::Disable,
    ///     ssl_root_cert: None,
    ///     user: User::from_static_or_panic("user"),
    /// };
    /// assert_eq!(
    ///     ipv4_config.to_url_string(),
    ///     "postgres://user@127.0.0.1:5432/mydb?sslmode=disable"
    /// );
    ///
    /// // IPv6 example (automatically bracketed)
    /// let ipv6_config = Config {
    ///     application_name: None,
    ///     database: Database::from_static_or_panic("mydb"),
    ///     endpoint: Endpoint::Network {
    ///         host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
    ///         channel_binding: None,
    ///         host_addr: None,
    ///         port: Some(Port::new(5432)),
    ///     },
    ///     password: None,
    ///     ssl_mode: SslMode::Disable,
    ///     ssl_root_cert: None,
    ///     user: User::from_static_or_panic("user"),
    /// };
    /// assert_eq!(
    ///     ipv6_config.to_url_string(),
    ///     "postgres://user@[::1]:5432/mydb?sslmode=disable"
    /// );
    /// ```
    #[must_use]
    pub fn to_url(&self) -> ::fluent_uri::Uri<String> {
        use ::fluent_uri::{
            Uri,
            build::Builder,
            component::{Authority, Scheme},
            pct_enc::{EStr, EString, encoder},
        };

        const POSTGRES: &Scheme = Scheme::new_or_panic("postgres");

        // Appends `key=value` to the query, inserting `&` between pairs. Key and value are
        // encoded separately with the `Data` table, which presumably escapes `&` and `=`
        // so they cannot be read as separators. The `/` -> `%2F` in the doc example above
        // shows the table is strict; confirm against fluent_uri's `encoder::Data`.
        fn append_query_pair(query: &mut EString<encoder::Query>, key: &str, value: &str) {
            if !query.is_empty() {
                query.push('&');
            }
            query.encode_str::<encoder::Data>(key);
            query.push('=');
            query.encode_str::<encoder::Data>(value);
        }

        let mut query = EString::<encoder::Query>::new();

        match &self.endpoint {
            Endpoint::Network {
                host,
                channel_binding,
                host_addr,
                port,
            } => {
                // userinfo is `user` or `user:password`; each part is encoded on its own.
                let mut userinfo = EString::<encoder::Userinfo>::new();
                userinfo.encode_str::<encoder::Data>(self.user.pg_env_value().as_str());
                if let Some(password) = &self.password {
                    userinfo.push(':');
                    userinfo.encode_str::<encoder::Data>(password.as_str());
                }

                // Path is `/` followed by the encoded database name.
                let mut path = EString::<encoder::Path>::new();
                path.push('/');
                path.encode_str::<encoder::Data>(self.database.as_str());

                // Order matters for the exact URL: hostaddr, channel_binding, then common params.
                if let Some(addr) = host_addr {
                    append_query_pair(&mut query, "hostaddr", &addr.to_string());
                }
                if let Some(channel_binding) = channel_binding {
                    append_query_pair(&mut query, "channel_binding", channel_binding.as_str());
                }
                self.append_common_query_params(&mut query, append_query_pair);

                // `sslmode` is always appended above, so in practice this is always `Some`.
                // The emptiness check keeps the builder call correct regardless.
                let non_empty_query = if query.is_empty() {
                    None
                } else {
                    Some(query.as_estr())
                };

                // build() only fails on RFC 3986 structural violations:
                // scheme and authority are always present, path starts with '/'.
                Uri::builder()
                    .scheme(POSTGRES)
                    .authority_with(|builder| {
                        let builder = builder.userinfo(&userinfo);
                        // IP hosts are passed as addresses. The IPv6 doc example shows the
                        // builder renders them bracketed. Host names are reg-name encoded.
                        let builder = match host {
                            Host::IpAddr(addr) => builder.host(*addr),
                            Host::HostName(name) => {
                                let mut encoded = EString::<encoder::RegName>::new();
                                encoded.encode_str::<encoder::Data>(name.as_str());
                                builder.host(encoded.as_estr())
                            }
                        };
                        // With no port, `advance()` moves the builder past the optional port.
                        match port {
                            Some(port) => builder.port(port.0),
                            None => builder.advance(),
                        }
                    })
                    .path(&path)
                    .optional(Builder::query, non_empty_query)
                    .build()
                    .unwrap()
            }
            Endpoint::SocketPath(path) => {
                // No host in the authority: the socket path and credentials all travel as
                // query parameters.
                append_query_pair(
                    &mut query,
                    "host",
                    path.to_str().expect("socket path contains invalid utf8"),
                );
                append_query_pair(&mut query, "dbname", self.database.as_str());
                append_query_pair(&mut query, "user", self.user.pg_env_value().as_str());
                if let Some(password) = &self.password {
                    append_query_pair(&mut query, "password", password.as_str());
                }
                self.append_common_query_params(&mut query, append_query_pair);

                // build() only fails on RFC 3986 structural violations:
                // scheme and authority are always present, path is empty.
                Uri::builder()
                    .scheme(POSTGRES)
                    .authority(Authority::EMPTY)
                    .path(EStr::EMPTY)
                    .query(&query)
                    .build()
                    .unwrap()
            }
        }
    }

    /// Convert to PG connection URL string
    ///
    /// String form of [`Config::to_url`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Config::to_url`]: a non-UTF-8
    /// [`Endpoint::SocketPath`] or [`SslRootCert::File`] path.
    #[must_use]
    pub fn to_url_string(&self) -> String {
        self.to_url().into_string()
    }

    /// Append the query parameters shared by both endpoint kinds: `application_name` (when
    /// set), `sslmode` (always) and `sslrootcert` (when set), in that order.
    ///
    /// `append_query_pair` is passed in because the encoding helper is a function nested
    /// inside [`Config::to_url`] and is not nameable here.
    ///
    /// # Panics
    ///
    /// Panics if an [`SslRootCert::File`] path is not valid UTF-8.
    fn append_common_query_params(
        &self,
        query: &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
        append_query_pair: fn(
            &mut ::fluent_uri::pct_enc::EString<::fluent_uri::pct_enc::encoder::Query>,
            &str,
            &str,
        ),
    ) {
        if let Some(application_name) = &self.application_name {
            append_query_pair(query, "application_name", application_name.as_str());
        }
        append_query_pair(query, "sslmode", &self.ssl_mode.pg_env_value());
        if let Some(ssl_root_cert) = &self.ssl_root_cert {
            append_query_pair(query, "sslrootcert", &ssl_root_cert.pg_env_value());
        }
    }

    /// Convert to PG environment variable names
    ///
    /// `PGHOST`, `PGSSLMODE`, `PGUSER` and `PGDATABASE` are always set. `PGPORT`,
    /// `PGCHANNELBINDING`, `PGHOSTADDR`, `PGAPPNAME`, `PGPASSWORD` and `PGSSLROOTCERT` are set
    /// only when the corresponding optional value is present. For [`Endpoint::SocketPath`] the
    /// path itself becomes `PGHOST`.
    ///
    /// # Panics
    ///
    /// Panics if an [`Endpoint::SocketPath`] path or an [`SslRootCert::File`] path is not
    /// valid UTF-8.
    ///
    /// # Examples
    ///
    /// ```
    /// # use pg_client::*;
    /// # use std::collections::BTreeMap;
    ///
    /// let config = Config {
    ///     application_name: None,
    ///     database: "some-database".parse().unwrap(),
    ///     endpoint: Endpoint::Network {
    ///         host: "some-host".parse().unwrap(),
    ///         channel_binding: None,
    ///         host_addr: None,
    ///         port: Some(Port::new(5432)),
    ///     },
    ///     password: None,
    ///     ssl_mode: SslMode::VerifyFull,
    ///     ssl_root_cert: None,
    ///     user: "some-user".parse().unwrap(),
    /// };
    ///
    /// let expected = BTreeMap::from([
    ///     (PGDATABASE, "some-database".to_string()),
    ///     (PGHOST, "some-host".to_string()),
    ///     (PGPORT, "5432".to_string()),
    ///     (PGSSLMODE, "verify-full".to_string()),
    ///     (PGUSER, "some-user".to_string()),
    /// ]);
    ///
    /// assert_eq!(expected, config.to_pg_env());
    ///
    /// let config_with_optionals = Config {
    ///     application_name: Some("some-app".parse().unwrap()),
    ///     endpoint: Endpoint::Network {
    ///         host: "some-host".parse().unwrap(),
    ///         channel_binding: None,
    ///         host_addr: Some("127.0.0.1".parse().unwrap()),
    ///         port: Some(Port::new(5432)),
    ///     },
    ///     password: Some("some-password".parse().unwrap()),
    ///     ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
    ///     ..config
    /// };
    ///
    /// let expected = BTreeMap::from([
    ///     (PGAPPNAME, "some-app".to_string()),
    ///     (PGDATABASE, "some-database".to_string()),
    ///     (PGHOST, "some-host".to_string()),
    ///     (PGHOSTADDR, "127.0.0.1".to_string()),
    ///     (PGPASSWORD, "some-password".to_string()),
    ///     (PGPORT, "5432".to_string()),
    ///     (PGSSLMODE, "verify-full".to_string()),
    ///     (PGSSLROOTCERT, "/some.pem".to_string()),
    ///     (PGUSER, "some-user".to_string()),
    /// ]);
    ///
    /// assert_eq!(expected, config_with_optionals.to_pg_env());
    /// ```
    #[must_use]
    pub fn to_pg_env(
        &self,
    ) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
        let mut map = std::collections::BTreeMap::new();

        match &self.endpoint {
            Endpoint::Network {
                host,
                channel_binding,
                host_addr,
                port,
            } => {
                map.insert(PGHOST.clone(), host.pg_env_value());
                if let Some(port) = port {
                    map.insert(PGPORT.clone(), port.pg_env_value());
                }
                if let Some(channel_binding) = channel_binding {
                    map.insert(PGCHANNELBINDING.clone(), channel_binding.pg_env_value());
                }
                if let Some(addr) = host_addr {
                    map.insert(PGHOSTADDR.clone(), addr.to_string());
                }
            }
            Endpoint::SocketPath(path) => {
                // The socket path is handed over as PGHOST. Port, host address and channel
                // binding only exist on network endpoints.
                map.insert(
                    PGHOST.clone(),
                    path.to_str()
                        .expect("socket path contains invalid utf8")
                        .to_string(),
                );
            }
        }

        map.insert(PGSSLMODE.clone(), self.ssl_mode.pg_env_value());
        map.insert(PGUSER.clone(), self.user.pg_env_value());
        map.insert(PGDATABASE.clone(), self.database.pg_env_value());

        if let Some(application_name) = &self.application_name {
            map.insert(PGAPPNAME.clone(), application_name.pg_env_value());
        }

        if let Some(password) = &self.password {
            map.insert(PGPASSWORD.clone(), password.pg_env_value());
        }

        if let Some(ssl_root_cert) = &self.ssl_root_cert {
            map.insert(PGSSLROOTCERT.clone(), ssl_root_cert.pg_env_value());
        }

        map
    }

    /// Return this config with its [`Endpoint`] replaced, keeping every other field.
    #[must_use]
    pub fn endpoint(self, endpoint: Endpoint) -> Self {
        Self { endpoint, ..self }
    }

    /// Parse a PostgreSQL connection URL string into a Config.
    ///
    /// When the URL does not specify `sslmode`, it defaults to `verify-full`
    /// to ensure secure connections by default.
    ///
    /// See [`url::parse`] for full documentation.
    ///
    /// # Errors
    ///
    /// Returns the [`crate::url::ParseError`] reported by [`url::parse`].
    pub fn from_str_url(url: &str) -> Result<Self, crate::url::ParseError> {
        crate::url::parse(url)
    }
}
820
821#[cfg(test)]
822mod test {
823    use super::*;
824    use pretty_assertions::assert_eq;
825    use std::str::FromStr;
826
827    const TEST_DATABASE: Database = Database::from_static_or_panic("some-database");
828    const TEST_USER: User = User::from_static_or_panic("some-user");
829
830    fn assert_config(expected: serde_json::Value, config: &Config) {
831        assert_eq!(expected, serde_json::to_value(config).unwrap());
832    }
833
834    fn repeat(char: char, len: usize) -> String {
835        std::iter::repeat_n(char, len).collect()
836    }
837
838    #[test]
839    fn application_name_lt_min_length() {
840        let value = String::new();
841
842        let err = ApplicationName::from_str(&value).expect_err("expected min length failure");
843
844        assert_eq!(err, "ApplicationName byte min length: 1 violated, got: 0");
845    }
846
847    #[test]
848    fn application_name_eq_min_length() {
849        let value = repeat('a', 1);
850
851        let application_name =
852            ApplicationName::from_str(&value).expect("expected valid min length value");
853
854        assert_eq!(application_name, ApplicationName(value));
855    }
856
857    #[test]
858    fn application_name_gt_min_length() {
859        let value = repeat('a', 2);
860
861        let application_name =
862            ApplicationName::from_str(&value).expect("expected valid value greater than min");
863
864        assert_eq!(application_name, ApplicationName(value));
865    }
866
867    #[test]
868    fn application_name_lt_max_length() {
869        let value = repeat('a', 62);
870
871        let application_name =
872            ApplicationName::from_str(&value).expect("expected valid value less than max");
873
874        assert_eq!(application_name, ApplicationName(value));
875    }
876
877    #[test]
878    fn application_name_eq_max_length() {
879        let value = repeat('a', 63);
880
881        let application_name =
882            ApplicationName::from_str(&value).expect("expected valid value equal to max");
883
884        assert_eq!(application_name, ApplicationName(value));
885    }
886
887    #[test]
888    fn application_name_gt_max_length() {
889        let value = repeat('a', 64);
890
891        let err = ApplicationName::from_str(&value).expect_err("expected max length failure");
892
893        assert_eq!(err, "ApplicationName byte max length: 63 violated, got: 64");
894    }
895
896    #[test]
897    fn application_name_contains_nul() {
898        let value = String::from('\0');
899
900        let err = ApplicationName::from_str(&value).expect_err("expected NUL failure");
901
902        assert_eq!(err, "ApplicationName contains NUL byte");
903    }
904
905    #[test]
906    fn password_eq_min_length() {
907        let value = String::new();
908
909        let password = Password::from_str(&value).expect("expected valid min length value");
910
911        assert_eq!(password, Password(value));
912    }
913
914    #[test]
915    fn password_gt_min_length() {
916        let value = repeat('p', 1);
917
918        let password = Password::from_str(&value).expect("expected valid value greater than min");
919
920        assert_eq!(password, Password(value));
921    }
922
923    #[test]
924    fn password_lt_max_length() {
925        let value = repeat('p', 4095);
926
927        let password = Password::from_str(&value).expect("expected valid value less than max");
928
929        assert_eq!(password, Password(value));
930    }
931
932    #[test]
933    fn password_eq_max_length() {
934        let value = repeat('p', 4096);
935
936        let password = Password::from_str(&value).expect("expected valid value equal to max");
937
938        assert_eq!(password, Password(value));
939    }
940
941    #[test]
942    fn password_gt_max_length() {
943        let value = repeat('p', 4097);
944
945        let err = Password::from_str(&value).expect_err("expected max length failure");
946
947        assert_eq!(err, "Password byte max length: 4096 violated, got: 4097");
948    }
949
950    #[test]
951    fn password_contains_nul() {
952        let value = String::from('\0');
953
954        let err = Password::from_str(&value).expect_err("expected NUL failure");
955
956        assert_eq!(err, "Password contains NUL byte");
957    }
958
959    #[test]
960    fn test_json() {
961        let config = Config {
962            application_name: None,
963            database: TEST_DATABASE,
964            endpoint: Endpoint::Network {
965                host: Host::from_str("some-host").unwrap(),
966                channel_binding: None,
967                host_addr: None,
968                port: Some(Port::new(5432)),
969            },
970            password: None,
971            ssl_mode: SslMode::VerifyFull,
972            ssl_root_cert: None,
973            user: TEST_USER,
974        };
975
976        assert_config(
977            serde_json::json!({
978                "database": "some-database",
979                "endpoint": {
980                    "host": "some-host",
981                    "port": 5432,
982                },
983                "ssl_mode": "verify-full",
984                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full",
985                "user": "some-user",
986            }),
987            &config,
988        );
989
990        assert_config(
991            serde_json::json!({
992                "application_name": "some-app",
993                "database": "some-database",
994                "endpoint": {
995                    "host": "some-host",
996                    "port": 5432,
997                },
998                "password": "some-password",
999                "ssl_mode": "verify-full",
1000                "ssl_root_cert": {
1001                    "file": "/some.pem"
1002                },
1003                "url": "postgres://some-user:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem",
1004                "user": "some-user"
1005            }),
1006            &Config {
1007                application_name: Some(ApplicationName::from_str("some-app").unwrap()),
1008                password: Some(Password::from_str("some-password").unwrap()),
1009                ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
1010                ..config.clone()
1011            },
1012        );
1013
1014        assert_config(
1015            serde_json::json!({
1016                "database": "some-database",
1017                "endpoint": {
1018                    "host": "127.0.0.1",
1019                    "port": 5432,
1020                },
1021                "ssl_mode": "verify-full",
1022                "url": "postgres://some-user@127.0.0.1:5432/some-database?sslmode=verify-full",
1023                "user": "some-user"
1024            }),
1025            &Config {
1026                endpoint: Endpoint::Network {
1027                    host: Host::from_str("127.0.0.1").unwrap(),
1028                    channel_binding: None,
1029                    host_addr: None,
1030                    port: Some(Port::new(5432)),
1031                },
1032                ..config.clone()
1033            },
1034        );
1035
1036        assert_config(
1037            serde_json::json!({
1038                "database": "some-database",
1039                "endpoint": {
1040                    "socket_path": "/some/socket",
1041                },
1042                "ssl_mode": "verify-full",
1043                "url": "postgres://?host=%2Fsome%2Fsocket&dbname=some-database&user=some-user&sslmode=verify-full",
1044                "user": "some-user"
1045            }),
1046            &Config {
1047                endpoint: Endpoint::SocketPath("/some/socket".into()),
1048                ..config.clone()
1049            },
1050        );
1051
1052        assert_config(
1053            serde_json::json!({
1054                "database": "some-database",
1055                "endpoint": {
1056                    "host": "some-host",
1057                    "port": 5432,
1058                },
1059                "ssl_mode": "verify-full",
1060                "ssl_root_cert": "system",
1061                "url": "postgres://some-user@some-host:5432/some-database?sslmode=verify-full&sslrootcert=system",
1062                "user": "some-user"
1063            }),
1064            &Config {
1065                ssl_root_cert: Some(SslRootCert::System),
1066                ..config.clone()
1067            },
1068        );
1069
1070        assert_config(
1071            serde_json::json!({
1072                "database": "some-database",
1073                "endpoint": {
1074                    "host": "some-host",
1075                    "host_addr": "192.168.1.100",
1076                    "port": 5432,
1077                },
1078                "ssl_mode": "verify-full",
1079                "url": "postgres://some-user@some-host:5432/some-database?hostaddr=192.168.1.100&sslmode=verify-full",
1080                "user": "some-user"
1081            }),
1082            &Config {
1083                endpoint: Endpoint::Network {
1084                    host: Host::from_str("some-host").unwrap(),
1085                    channel_binding: None,
1086                    host_addr: Some("192.168.1.100".parse().unwrap()),
1087                    port: Some(Port::new(5432)),
1088                },
1089                ..config.clone()
1090            },
1091        );
1092
1093        // Test Network endpoint without port (should use default)
1094        assert_config(
1095            serde_json::json!({
1096                "database": "some-database",
1097                "endpoint": {
1098                    "host": "some-host",
1099                },
1100                "ssl_mode": "verify-full",
1101                "url": "postgres://some-user@some-host/some-database?sslmode=verify-full",
1102                "user": "some-user"
1103            }),
1104            &Config {
1105                endpoint: Endpoint::Network {
1106                    host: Host::from_str("some-host").unwrap(),
1107                    channel_binding: None,
1108                    host_addr: None,
1109                    port: None,
1110                },
1111                ..config.clone()
1112            },
1113        );
1114
1115        // Test Network endpoint with host_addr but without port
1116        assert_config(
1117            serde_json::json!({
1118                "database": "some-database",
1119                "endpoint": {
1120                    "host": "some-host",
1121                    "host_addr": "10.0.0.1",
1122                },
1123                "ssl_mode": "verify-full",
1124                "url": "postgres://some-user@some-host/some-database?hostaddr=10.0.0.1&sslmode=verify-full",
1125                "user": "some-user"
1126            }),
1127            &Config {
1128                endpoint: Endpoint::Network {
1129                    host: Host::from_str("some-host").unwrap(),
1130                    channel_binding: None,
1131                    host_addr: Some("10.0.0.1".parse().unwrap()),
1132                    port: None,
1133                },
1134                ..config.clone()
1135            },
1136        );
1137    }
1138
1139    #[test]
1140    fn test_ipv6_url_formation() {
1141        // Test IPv6 loopback address
1142        let config_ipv6_loopback = Config {
1143            application_name: None,
1144            database: TEST_DATABASE,
1145            endpoint: Endpoint::Network {
1146                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
1147                channel_binding: None,
1148                host_addr: None,
1149                port: Some(Port::new(5432)),
1150            },
1151            password: None,
1152            ssl_mode: SslMode::Disable,
1153            ssl_root_cert: None,
1154            user: User::POSTGRES,
1155        };
1156
1157        assert_eq!(
1158            config_ipv6_loopback.to_url_string(),
1159            "postgres://postgres@[::1]:5432/some-database?sslmode=disable",
1160            "IPv6 loopback address should be bracketed in URL"
1161        );
1162
1163        // Test fe80 link-local IPv6 address
1164        let config_ipv6_fe80 = Config {
1165            application_name: None,
1166            database: TEST_DATABASE,
1167            endpoint: Endpoint::Network {
1168                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
1169                    0xfe80, 0, 0, 0, 0, 0, 0, 1,
1170                ))),
1171                channel_binding: None,
1172                host_addr: None,
1173                port: Some(Port::new(5432)),
1174            },
1175            password: None,
1176            ssl_mode: SslMode::Disable,
1177            ssl_root_cert: None,
1178            user: User::POSTGRES,
1179        };
1180
1181        assert_eq!(
1182            config_ipv6_fe80.to_url_string(),
1183            "postgres://postgres@[fe80::1]:5432/some-database?sslmode=disable",
1184            "IPv6 link-local address should be bracketed in URL"
1185        );
1186
1187        // Test full IPv6 address
1188        let config_ipv6_full = Config {
1189            application_name: None,
1190            database: TEST_DATABASE,
1191            endpoint: Endpoint::Network {
1192                host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
1193                    0x2001, 0x0db8, 0, 0, 0, 0, 0, 1,
1194                ))),
1195                channel_binding: None,
1196                host_addr: None,
1197                port: Some(Port::new(5432)),
1198            },
1199            password: None,
1200            ssl_mode: SslMode::Disable,
1201            ssl_root_cert: None,
1202            user: User::POSTGRES,
1203        };
1204
1205        assert_eq!(
1206            config_ipv6_full.to_url_string(),
1207            "postgres://postgres@[2001:db8::1]:5432/some-database?sslmode=disable",
1208            "Full IPv6 address should be bracketed in URL"
1209        );
1210
1211        // Test IPv4 address (should NOT be bracketed)
1212        let config_ipv4 = Config {
1213            application_name: None,
1214            database: TEST_DATABASE,
1215            endpoint: Endpoint::Network {
1216                host: Host::IpAddr(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)),
1217                channel_binding: None,
1218                host_addr: None,
1219                port: Some(Port::new(5432)),
1220            },
1221            password: None,
1222            ssl_mode: SslMode::Disable,
1223            ssl_root_cert: None,
1224            user: User::POSTGRES,
1225        };
1226
1227        assert_eq!(
1228            config_ipv4.to_url_string(),
1229            "postgres://postgres@127.0.0.1:5432/some-database?sslmode=disable",
1230            "IPv4 address should NOT be bracketed in URL"
1231        );
1232
1233        // Test hostname (should NOT be bracketed)
1234        let config_hostname = Config {
1235            application_name: None,
1236            database: TEST_DATABASE,
1237            endpoint: Endpoint::Network {
1238                host: Host::from_str("localhost").unwrap(),
1239                channel_binding: None,
1240                host_addr: None,
1241                port: Some(Port::new(5432)),
1242            },
1243            password: None,
1244            ssl_mode: SslMode::Disable,
1245            ssl_root_cert: None,
1246            user: User::POSTGRES,
1247        };
1248
1249        assert_eq!(
1250            config_hostname.to_url_string(),
1251            "postgres://postgres@localhost:5432/some-database?sslmode=disable",
1252            "Hostname should NOT be bracketed in URL"
1253        );
1254    }
1255}