pg_client/
lib.rs

1#![doc = include_str!("../README.md")]
2
/// Generate `std::str::FromStr`, `AsRef<str>`, `as_str` and the
/// `MIN_LENGTH` / `MAX_LENGTH` constants for a `String` newtype.
///
/// Parsing accepts a value when its length in bytes lies within `$min..=$max`
/// and it contains no NUL byte. The checks run in that order (min length,
/// max length, NUL) and the first failure is returned.
///
/// NOTE: the error messages are assembled with `concat!`, so the
/// `{min_length}`, `{max_length}` and `{actual}` placeholders appear literally
/// in the message; they are not interpolated.
macro_rules! from_str_impl {
    ($struct: ident, $min: expr, $max: expr) => {
        impl $struct {
            pub const MIN_LENGTH: usize = $min;
            pub const MAX_LENGTH: usize = $max;

            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl std::str::FromStr for $struct {
            type Err = &'static str;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                let lower_bound = Self::MIN_LENGTH;
                let upper_bound = Self::MAX_LENGTH;
                let byte_len = value.len();

                if byte_len < lower_bound {
                    return Err(concat!(
                        stringify!($struct),
                        " byte min length: {min_length} violated, got: {actual}"
                    ));
                }

                if byte_len > upper_bound {
                    return Err(concat!(
                        stringify!($struct),
                        " byte max length: {max_length} violated, got: {actual}"
                    ));
                }

                if value.bytes().any(|byte| byte == 0) {
                    return Err(concat!(stringify!($struct), " contains NUL byte"));
                }

                Ok(Self(value.to_owned()))
            }
        }
    };
}
48
49#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
50pub struct HostName(String);
51
52impl HostName {
53    pub fn as_str(&self) -> &str {
54        &self.0
55    }
56}
57
58impl std::str::FromStr for HostName {
59    type Err = &'static str;
60
61    fn from_str(value: &str) -> Result<Self, Self::Err> {
62        if hostname_validator::is_valid(value) {
63            Ok(Self(value.to_string()))
64        } else {
65            Err("invalid host name")
66        }
67    }
68}
69
70impl<'de> serde::Deserialize<'de> for HostName {
71    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72    where
73        D: serde::Deserializer<'de>,
74    {
75        let s = String::deserialize(deserializer)?;
76        s.parse().map_err(serde::de::Error::custom)
77    }
78}
79
80#[derive(Clone, Debug, PartialEq, Eq)]
81pub enum Host {
82    HostName(HostName),
83    IpAddr(std::net::IpAddr),
84}
85
86impl serde::Serialize for Host {
87    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
88        serializer.serialize_str(&self.to_pg_env_value())
89    }
90}
91
92impl Host {
93    fn to_pg_env_value(&self) -> String {
94        match self {
95            Self::HostName(value) => value.0.clone(),
96            Self::IpAddr(value) => value.to_string(),
97        }
98    }
99}
100
101impl std::str::FromStr for Host {
102    type Err = &'static str;
103
104    fn from_str(value: &str) -> Result<Self, Self::Err> {
105        match std::net::IpAddr::from_str(value) {
106            Ok(addr) => Ok(Self::IpAddr(addr)),
107            Err(_) => match HostName::from_str(value) {
108                Ok(host_name) => Ok(Self::HostName(host_name)),
109                Err(_) => Err("Not a socket address or FQDN"),
110            },
111        }
112    }
113}
114
/// Parse a string literal into a `Host`, panicking if the literal is invalid.
///
/// Meant for hard-coded values (tests, examples), where an invalid literal is a
/// programming error. The path goes through `$crate`, so the macro works inside
/// this crate and when the dependency is renamed.
#[macro_export]
macro_rules! host {
    ($string: literal) => {
        <$crate::Host as ::std::str::FromStr>::from_str($string).unwrap()
    };
}
121
122impl From<HostName> for Host {
123    fn from(value: HostName) -> Self {
124        Self::HostName(value)
125    }
126}
127
128impl From<std::net::IpAddr> for Host {
129    fn from(value: std::net::IpAddr) -> Self {
130        Self::IpAddr(value)
131    }
132}
133
/// An IP address used as the `hostaddr` / `PGHOSTADDR` connection value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(pub std::net::IpAddr);

impl From<std::net::IpAddr> for HostAddr {
    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    /// use std::net::IpAddr;
    ///
    /// let ip: IpAddr = "192.168.1.1".parse().unwrap();
    /// let host_addr = HostAddr::from(ip);
    /// assert_eq!(host_addr.0.to_string(), "192.168.1.1");
    /// ```
    fn from(addr: std::net::IpAddr) -> Self {
        HostAddr(addr)
    }
}

impl std::fmt::Display for HostAddr {
    /// Writes the wrapped address. Width/fill flags of the outer formatter are
    /// not forwarded to the address.
    ///
    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    ///
    /// let host_addr: HostAddr = "10.0.0.1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "10.0.0.1");
    /// ```
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(addr) = self;
        write!(formatter, "{addr}")
    }
}

impl std::str::FromStr for HostAddr {
    type Err = &'static str;

    /// # Example
    /// ```
    /// use pg_client::HostAddr;
    /// use std::str::FromStr;
    ///
    /// let host_addr = HostAddr::from_str("127.0.0.1").unwrap();
    /// assert_eq!(host_addr.to_string(), "127.0.0.1");
    ///
    /// // Also works with the parse method
    /// let host_addr: HostAddr = "::1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "::1");
    ///
    /// // Invalid IP addresses return an error
    /// assert!(HostAddr::from_str("not-an-ip").is_err());
    /// ```
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        value
            .parse::<std::net::IpAddr>()
            .map(HostAddr)
            .map_err(|_| "invalid IP address")
    }
}
190
191#[derive(Clone, Debug, PartialEq, Eq)]
192pub enum Endpoint {
193    Network {
194        host: Host,
195        host_addr: Option<HostAddr>,
196        port: Option<Port>,
197    },
198    SocketPath(std::path::PathBuf),
199}
200
201impl serde::Serialize for Endpoint {
202    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
203        use serde::ser::SerializeStruct;
204        match self {
205            Self::Network {
206                host,
207                host_addr,
208                port,
209            } => {
210                let mut state = serializer.serialize_struct("Endpoint", 3)?;
211                state.serialize_field("host", host)?;
212                if let Some(addr) = host_addr {
213                    state.serialize_field("host_addr", &addr.to_string())?;
214                }
215                if let Some(port) = port {
216                    state.serialize_field("port", port)?;
217                }
218                state.end()
219            }
220            Self::SocketPath(path) => {
221                let mut state = serializer.serialize_struct("Endpoint", 1)?;
222                state.serialize_field(
223                    "socket_path",
224                    &path.to_str().expect("socket path contains invalid utf8"),
225                )?;
226                state.end()
227            }
228        }
229    }
230}
231
232#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
233pub struct Port(pub u16);
234
235impl std::str::FromStr for Port {
236    type Err = &'static str;
237
238    fn from_str(value: &str) -> Result<Self, Self::Err> {
239        match <u16 as std::str::FromStr>::from_str(value) {
240            Ok(port) => Ok(Port(port)),
241            Err(_) => Err("invalid postgresql port string"),
242        }
243    }
244}
245
246impl Port {
247    fn to_pg_env_value(self) -> String {
248        self.0.to_string()
249    }
250}
251
252impl From<Port> for u16 {
253    fn from(port: Port) -> Self {
254        port.0
255    }
256}
257
258impl From<&Port> for u16 {
259    fn from(port: &Port) -> Self {
260        port.0
261    }
262}
263
264#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
265pub struct ApplicationName(String);
266
267from_str_impl!(ApplicationName, 1, 63);
268
269#[macro_export]
270macro_rules! application_name {
271    ($string: literal) => {
272        <pg_client::ApplicationName as std::str::FromStr>::from_str($string).unwrap()
273    };
274}
275
276impl ApplicationName {
277    fn to_pg_env_value(&self) -> String {
278        self.0.clone()
279    }
280}
281
282#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
283pub struct Database(String);
284
285from_str_impl!(Database, 1, 63);
286
287#[macro_export]
288macro_rules! database {
289    ($string: literal) => {
290        <pg_client::Database as std::str::FromStr>::from_str($string).unwrap()
291    };
292}
293
294impl Database {
295    fn to_pg_env_value(&self) -> String {
296        self.0.clone()
297    }
298}
299
300#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
301pub struct Username(String);
302
303from_str_impl!(Username, 1, 63);
304
305#[macro_export]
306macro_rules! username {
307    ($string: literal) => {
308        <pg_client::Username as std::str::FromStr>::from_str($string).unwrap()
309    };
310}
311
312impl Username {
313    fn to_pg_env_value(&self) -> String {
314        self.0.clone()
315    }
316}
317
318#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
319pub struct Password(String);
320
321from_str_impl!(Password, 0, 4096);
322
323impl Password {
324    fn to_pg_env_value(&self) -> String {
325        self.0.clone()
326    }
327}
328
329impl From<String> for Password {
330    fn from(value: String) -> Self {
331        Self(value)
332    }
333}
334
335#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
336#[serde(rename_all = "kebab-case")]
337pub enum SslMode {
338    Allow,
339    Disable,
340    Prefer,
341    Require,
342    VerifyCa,
343    VerifyFull,
344}
345
346impl SslMode {
347    pub fn as_str(&self) -> &'static str {
348        match self {
349            Self::Allow => "allow",
350            Self::Disable => "disable",
351            Self::Prefer => "prefer",
352            Self::Require => "require",
353            Self::VerifyCa => "verify-ca",
354            Self::VerifyFull => "verify-full",
355        }
356    }
357
358    fn to_sqlx_ssl_mode(&self) -> sqlx::postgres::PgSslMode {
359        use sqlx::postgres::PgSslMode;
360
361        match self {
362            Self::Allow => PgSslMode::Allow,
363            Self::Disable => PgSslMode::Disable,
364            Self::Prefer => PgSslMode::Prefer,
365            Self::Require => PgSslMode::Require,
366            Self::VerifyCa => PgSslMode::VerifyCa,
367            Self::VerifyFull => PgSslMode::VerifyFull,
368        }
369    }
370
371    fn to_pg_env_value(&self) -> String {
372        self.as_str().to_string()
373    }
374}
375
376#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
377#[serde(rename_all = "kebab-case")]
378pub enum SslRootCert {
379    File(std::path::PathBuf),
380    System,
381}
382
383impl SslRootCert {
384    fn to_pg_env_value(&self) -> String {
385        match self {
386            Self::File(path) => path.to_str().unwrap().to_string(),
387            Self::System => "system".to_string(),
388        }
389    }
390}
391
392impl From<std::path::PathBuf> for SslRootCert {
393    fn from(value: std::path::PathBuf) -> Self {
394        Self::File(value)
395    }
396}
397
/// Reasons `Config::to_sqlx_connect_options` refuses to build options because
/// of `PG*` variables set in the process environment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlxOptionsError {
    /// An environment variable sets a field that `Config` leaves unset.
    EnvConflict { env_key: String, field_name: String },
    /// An environment variable sets a field that `Config` has no support for.
    UnsupportedFeature { env_key: String, field_name: String },
}

impl std::fmt::Display for SqlxOptionsError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SqlxOptionsError::EnvConflict {
                env_key,
                field_name,
            } => write!(
                formatter,
                "`PgConnectOptions::new` has inferred a `{0}` from `{1}` environment variable, but `pg_client::Config` does not specify a `{0}` value. `PgConnectOptions` does not provide an API to construct an instance without inferring from the environment, does not provide an API to unset the field, we have to bail out at this point. Please remove the environment variable!",
                field_name, env_key
            ),
            SqlxOptionsError::UnsupportedFeature {
                env_key,
                field_name,
            } => write!(
                formatter,
                "`PgConnectOptions::new` has inferred `{0}` from the `{1}` environment variable, but `pg_client::Config` does not support that feature at this point. As `PgConnectOptions` has no option to unset that field, or a constructor that allows us to bypass the inference: we have to bail out, please remove the environment variable!",
                field_name, env_key
            ),
        }
    }
}

impl std::error::Error for SqlxOptionsError {}
426
427#[derive(Debug, thiserror::Error)]
428pub enum SqlxConnectionError {
429    #[error("Failed to create SQLx connect options")]
430    Options(#[from] SqlxOptionsError),
431
432    #[error("Failed to connect to database")]
433    Connect(#[source] sqlx::Error),
434
435    #[error("Failed to close database connection")]
436    Close(#[source] sqlx::Error),
437}
438
439#[derive(Clone, Debug, PartialEq, Eq)]
440/// PG connection config with various presentation modes.
441///
442/// Supported:
443///
444/// 1. Env variables via `to_pg_env()`
445/// 2. JSON document via `serde`
446/// 3. sqlx connect options via `to_sqlx_connect_options()`
447/// 4. Individual field access
448pub struct Config {
449    pub application_name: Option<ApplicationName>,
450    pub database: Database,
451    pub endpoint: Endpoint,
452    pub password: Option<Password>,
453    pub ssl_mode: SslMode,
454    pub ssl_root_cert: Option<SslRootCert>,
455    pub username: Username,
456}
457
458impl serde::Serialize for Config {
459    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
460        use serde::ser::SerializeStruct;
461        let mut state = serializer.serialize_struct("Config", 8)?;
462
463        if let Some(application_name) = &self.application_name {
464            state.serialize_field("application_name", application_name)?;
465        }
466
467        state.serialize_field("database", &self.database)?;
468        state.serialize_field("endpoint", &self.endpoint)?;
469
470        if let Some(password) = &self.password {
471            state.serialize_field("password", password)?;
472        }
473
474        state.serialize_field("ssl_mode", &self.ssl_mode)?;
475
476        if let Some(ssl_root_cert) = &self.ssl_root_cert {
477            state.serialize_field("ssl_root_cert", ssl_root_cert)?;
478        }
479
480        state.serialize_field("username", &self.username)?;
481        state.serialize_field("url", &self.to_url())?;
482
483        state.end()
484    }
485}
486
487impl Config {
488    /// Convert to PG connection URL
489    ///
490    /// ```
491    /// # use pg_client::*;
492    /// # use std::str::FromStr;
493    /// # use url::Url;
494    ///
495    /// let config = Config {
496    ///     application_name: None,
497    ///     database: Database::from_str("some-database").unwrap(),
498    ///     endpoint: Endpoint::Network {
499    ///         host: Host::from_str("some-host").unwrap(),
500    ///         host_addr: None,
501    ///         port: Some(Port(5432)),
502    ///     },
503    ///     password: None,
504    ///     ssl_mode: SslMode::VerifyFull,
505    ///     ssl_root_cert: None,
506    ///     username: Username::from_str("some-username").unwrap(),
507    /// };
508    ///
509    /// let options = config.to_sqlx_connect_options();
510    ///
511    /// assert_eq!(
512    ///     Url::parse(
513    ///         "postgres://some-username@some-host:5432/some-database?sslmode=verify-full"
514    ///     ).unwrap(),
515    ///     config.to_url()
516    /// );
517    ///
518    /// assert_eq!(
519    ///     Url::parse(
520    ///         "postgres://some-username:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem"
521    ///     ).unwrap(),
522    ///     Config {
523    ///         application_name: Some(ApplicationName::from_str("some-app").unwrap()),
524    ///         password: Some(Password::from_str("some-password").unwrap()),
525    ///         ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
526    ///         ..config.clone()
527    ///     }.to_url()
528    /// );
529    ///
530    /// assert_eq!(
531    ///     Url::parse(
532    ///         "postgres://some-username@some-host:5432/some-database?hostaddr=127.0.0.1&sslmode=verify-full"
533    ///     ).unwrap(),
534    ///     Config {
535    ///         endpoint: Endpoint::Network {
536    ///             host: Host::from_str("some-host").unwrap(),
537    ///             host_addr: Some("127.0.0.1".parse().unwrap()),
538    ///             port: Some(Port(5432)),
539    ///         },
540    ///         ..config.clone()
541    ///     }.to_url()
542    /// );
543    ///
544    /// // IPv4 example
545    /// let ipv4_config = Config {
546    ///     application_name: None,
547    ///     database: Database::from_str("mydb").unwrap(),
548    ///     endpoint: Endpoint::Network {
549    ///         host: Host::IpAddr(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
550    ///         host_addr: None,
551    ///         port: Some(Port(5432)),
552    ///     },
553    ///     password: None,
554    ///     ssl_mode: SslMode::Disable,
555    ///     ssl_root_cert: None,
556    ///     username: Username::from_str("user").unwrap(),
557    /// };
558    /// assert_eq!(
559    ///     ipv4_config.to_url().to_string(),
560    ///     "postgres://user@127.0.0.1:5432/mydb?sslmode=disable"
561    /// );
562    ///
563    /// // IPv6 example (automatically bracketed)
564    /// let ipv6_config = Config {
565    ///     application_name: None,
566    ///     database: Database::from_str("mydb").unwrap(),
567    ///     endpoint: Endpoint::Network {
568    ///         host: Host::IpAddr(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
569    ///         host_addr: None,
570    ///         port: Some(Port(5432)),
571    ///     },
572    ///     password: None,
573    ///     ssl_mode: SslMode::Disable,
574    ///     ssl_root_cert: None,
575    ///     username: Username::from_str("user").unwrap(),
576    /// };
577    /// assert_eq!(
578    ///     ipv6_config.to_url().to_string(),
579    ///     "postgres://user@[::1]:5432/mydb?sslmode=disable"
580    /// );
581    /// ```
582    pub fn to_url(&self) -> url::Url {
583        let mut url = url::Url::parse("postgres://").unwrap();
584
585        match &self.endpoint {
586            Endpoint::Network {
587                host,
588                host_addr,
589                port,
590            } => {
591                // Use set_ip_host for IP addresses to handle IPv6 bracketing automatically
592                match host {
593                    Host::IpAddr(ip_addr) => {
594                        url.set_ip_host(*ip_addr).unwrap();
595                    }
596                    Host::HostName(hostname) => {
597                        url.set_host(Some(hostname.as_str())).unwrap();
598                    }
599                }
600                url.set_username(self.username.to_pg_env_value().as_str())
601                    .unwrap();
602
603                if let Some(password) = &self.password {
604                    url.set_password(Some(password.as_str())).unwrap();
605                }
606
607                if let Some(port) = port {
608                    url.set_port(Some(port.0)).unwrap();
609                }
610
611                url.set_path(self.database.as_str());
612
613                // host_addr has no dedicated URL component
614                if let Some(addr) = host_addr {
615                    url.query_pairs_mut()
616                        .append_pair("hostaddr", &addr.to_string());
617                }
618            }
619            Endpoint::SocketPath(path) => {
620                // Socket paths require query parameters (no dedicated URL components without a network host)
621                url.query_pairs_mut()
622                    .append_pair(
623                        "host",
624                        path.to_str().expect("socket path contains invalid utf8"),
625                    )
626                    .append_pair("dbname", self.database.as_str())
627                    .append_pair("user", self.username.to_pg_env_value().as_str());
628
629                if let Some(password) = &self.password {
630                    url.query_pairs_mut()
631                        .append_pair("password", password.as_str());
632                }
633            }
634        }
635
636        {
637            let mut pairs = url.query_pairs_mut();
638
639            if let Some(application_name) = &self.application_name {
640                pairs.append_pair("application_name", application_name.as_str());
641            }
642
643            pairs.append_pair("sslmode", &self.ssl_mode.to_pg_env_value());
644
645            if let Some(ssl_root_cert) = &self.ssl_root_cert {
646                pairs.append_pair("sslrootcert", &ssl_root_cert.to_pg_env_value());
647            }
648        }
649
650        url
651    }
652
653    /// Convert to PG environment variable names
654    ///
655    /// ```
656    /// # use pg_client::*;
657    /// # use std::collections::BTreeMap;
658    ///
659    /// let config = Config {
660    ///     application_name: None,
661    ///     database: "some-database".parse().unwrap(),
662    ///     endpoint: Endpoint::Network {
663    ///         host: "some-host".parse().unwrap(),
664    ///         host_addr: None,
665    ///         port: Some(Port(5432)),
666    ///     },
667    ///     password: None,
668    ///     ssl_mode: SslMode::VerifyFull,
669    ///     ssl_root_cert: None,
670    ///     username: "some-username".parse().unwrap(),
671    /// };
672    ///
673    /// let expected = BTreeMap::from([
674    ///     ("PGDATABASE", "some-database".to_string()),
675    ///     ("PGHOST", "some-host".to_string()),
676    ///     ("PGPORT", "5432".to_string()),
677    ///     ("PGSSLMODE", "verify-full".to_string()),
678    ///     ("PGUSER", "some-username".to_string()),
679    /// ]);
680    ///
681    /// assert_eq!(expected, config.to_pg_env());
682    ///
683    /// let config_with_optionals = Config {
684    ///     application_name: Some("some-app".parse().unwrap()),
685    ///     endpoint: Endpoint::Network {
686    ///         host: "some-host".parse().unwrap(),
687    ///         host_addr: Some("127.0.0.1".parse().unwrap()),
688    ///         port: Some(Port(5432)),
689    ///     },
690    ///     password: Some("some-password".parse().unwrap()),
691    ///     ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
692    ///     ..config
693    /// };
694    ///
695    /// let expected = BTreeMap::from([
696    ///     ("PGAPPNAME", "some-app".to_string()),
697    ///     ("PGDATABASE", "some-database".to_string()),
698    ///     ("PGHOST", "some-host".to_string()),
699    ///     ("PGHOSTADDR", "127.0.0.1".to_string()),
700    ///     ("PGPASSWORD", "some-password".to_string()),
701    ///     ("PGPORT", "5432".to_string()),
702    ///     ("PGSSLMODE", "verify-full".to_string()),
703    ///     ("PGSSLROOTCERT", "/some.pem".to_string()),
704    ///     ("PGUSER", "some-username".to_string()),
705    /// ]);
706    ///
707    /// assert_eq!(expected, config_with_optionals.to_pg_env());
708    /// ```
709    pub fn to_pg_env(&self) -> std::collections::BTreeMap<&'static str, String> {
710        let mut map = std::collections::BTreeMap::new();
711
712        match &self.endpoint {
713            Endpoint::Network {
714                host,
715                host_addr,
716                port,
717            } => {
718                map.insert("PGHOST", host.to_pg_env_value());
719                if let Some(port) = port {
720                    map.insert("PGPORT", port.to_pg_env_value());
721                }
722                if let Some(addr) = host_addr {
723                    map.insert("PGHOSTADDR", addr.to_string());
724                }
725            }
726            Endpoint::SocketPath(path) => {
727                map.insert(
728                    "PGHOST",
729                    path.to_str()
730                        .expect("socket path contains invalid utf8")
731                        .to_string(),
732                );
733            }
734        }
735
736        map.insert("PGSSLMODE", self.ssl_mode.to_pg_env_value());
737        map.insert("PGUSER", self.username.to_pg_env_value());
738        map.insert("PGDATABASE", self.database.to_pg_env_value());
739
740        if let Some(application_name) = &self.application_name {
741            map.insert("PGAPPNAME", application_name.to_pg_env_value());
742        }
743
744        if let Some(password) = &self.password {
745            map.insert("PGPASSWORD", password.to_pg_env_value());
746        }
747
748        if let Some(ssl_root_cert) = &self.ssl_root_cert {
749            map.insert("PGSSLROOTCERT", ssl_root_cert.to_pg_env_value());
750        }
751
752        map
753    }
754
755    /// Convert to an sqlx pg connection config
756    ///
757    /// ```
758    /// # use pg_client::*;
759    /// # use std::str::FromStr;
760    ///
761    /// let config = Config {
762    ///     application_name: Some(ApplicationName::from_str("some-app").unwrap()),
763    ///     database: Database::from_str("some-database").unwrap(),
764    ///     endpoint: Endpoint::Network {
765    ///         host: Host::from_str("some-host").unwrap(),
766    ///         host_addr: None,
767    ///         port: Some(Port(5432)),
768    ///     },
769    ///     password: Some(Password::from_str("some-password").unwrap()),
770    ///     ssl_mode: SslMode::VerifyFull,
771    ///     ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
772    ///     username: Username::from_str("some-username").unwrap(),
773    /// };
774    ///
775    /// let options = config.to_sqlx_connect_options().unwrap();
776    ///
777    /// // `PgConnectOptions` does not have `PartialEq` and only partial getters
778    /// // so we can only assert a few fields.
779    /// assert_eq!(Some("some-app"), options.get_application_name());
780    /// assert_eq!("some-host", options.get_host());
781    /// assert_eq!(5432, options.get_port());
782    /// assert_eq!("some-username", options.get_username());
783    /// // No PartialEQ instance
784    /// assert_eq!(format!("{:#?}", sqlx::postgres::PgSslMode::VerifyFull), format!("{:#?}", options.get_ssl_mode()));
785    /// assert_eq!(Some("some-database"), options.get_database());
786    /// // Unsupported.
787    /// // assert_eq!("some-password", options.get_password());
788    /// // assert_eq!("/some.pem", options.get_ssl_root_cert());
789    /// ```
790    ///
791    /// # Errors
792    ///
793    /// Returns an error if fields inferred from the process environment variables
794    /// by `PgConnectOptions::new` contradict the settings in `Config`, and
795    /// there is no public API in `PgConnectOptions` to reset these values.
796    pub fn to_sqlx_connect_options(
797        &self,
798    ) -> Result<sqlx::postgres::PgConnectOptions, SqlxOptionsError> {
799        fn reject_env(env_key: &str, field_name: &str) -> Result<(), SqlxOptionsError> {
800            if std::env::var(env_key).is_ok() {
801                Err(SqlxOptionsError::EnvConflict {
802                    env_key: env_key.to_string(),
803                    field_name: field_name.to_string(),
804                })
805            } else {
806                Ok(())
807            }
808        }
809
810        fn unsupported_env(env_key: &str, field_name: &str) -> Result<(), SqlxOptionsError> {
811            if std::env::var(env_key).is_ok() {
812                Err(SqlxOptionsError::UnsupportedFeature {
813                    env_key: env_key.to_string(),
814                    field_name: field_name.to_string(),
815                })
816            } else {
817                Ok(())
818            }
819        }
820
821        // This is the "least powerful" API available to create a `PgConnectOptions`
822        // instance. Still it does ENV variable snooping and we below try hard to
823        // reset all of that snooped variables.
824        let mut options = sqlx::postgres::PgConnectOptions::new_without_pgpass();
825
826        unsupported_env("PGSSLKEY", "ssl_client_key")?;
827        unsupported_env("PGSSLCERT", "ssl_client_cert")?;
828        unsupported_env("PGOPTIONS", "options")?;
829
830        options = options.database(self.database.as_str());
831
832        match &self.endpoint {
833            Endpoint::Network {
834                host,
835                host_addr,
836                port,
837            } => {
838                options = options.host(&host.to_pg_env_value());
839                if let Some(port) = port {
840                    options = options.port(port.into());
841                } else {
842                    reject_env("PGPORT", "port")?;
843                }
844                if let Some(host_addr) = host_addr {
845                    options = options.host_addr(&host_addr.to_string())
846                } else {
847                    reject_env("PGHOSTADDR", "hostaddr")?;
848                }
849            }
850            Endpoint::SocketPath(path) => {
851                options = options.host(path.to_str().expect("socket path contains invalid utf8"));
852                reject_env("PGPORT", "port")?;
853                reject_env("PGHOSTADDR", "hostaddr")?;
854            }
855        }
856
857        options = options.ssl_mode(self.ssl_mode.to_sqlx_ssl_mode());
858        options = options.username(self.username.as_str());
859
860        if let Some(application_name) = &self.application_name {
861            options = options.application_name(application_name.as_str());
862        } else {
863            reject_env("PGAPPNAME", "application_name")?;
864        }
865
866        if let Some(password) = &self.password {
867            options = options.password(password.as_str());
868        } else {
869            reject_env("PGPASSWORD", "password")?;
870        }
871
872        if let Some(ssl_root_cert) = &self.ssl_root_cert {
873            options = options.ssl_root_cert(ssl_root_cert.to_pg_env_value());
874        } else {
875            reject_env("PGSSLROOTCERT", "ssl_root_cert")?;
876        }
877
878        Ok(options)
879    }
880
881    pub async fn with_sqlx_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
882        &self,
883        mut action: F,
884    ) -> Result<T, SqlxConnectionError> {
885        let config = self.to_sqlx_connect_options()?;
886
887        let mut connection = sqlx::ConnectOptions::connect(&config)
888            .await
889            .map_err(SqlxConnectionError::Connect)?;
890
891        let result = action(&mut connection).await;
892
893        sqlx::Connection::close(connection)
894            .await
895            .map_err(SqlxConnectionError::Close)?;
896
897        Ok(result)
898    }
899
900    pub fn endpoint(self, endpoint: Endpoint) -> Self {
901        Self { endpoint, ..self }
902    }
903}
904
#[cfg(test)]
mod test {
    //! Tests for the string newtypes' length / NUL validation and for `Config`
    //! serialization and URL rendering.

    use super::*;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    // NOTE(review): `from_str_impl!` builds its length errors with `concat!`, which does not
    // interpolate. The messages therefore literally contain `{min_length}`, `{max_length}` and
    // `{actual}`. The expectations below pin that current output; update them together with
    // the macro if it ever switches to formatted (owned) error messages.

    /// Asserts that `config` serializes to exactly `expected`.
    fn assert_config(expected: serde_json::Value, config: &Config) {
        let actual = serde_json::to_value(config).unwrap();
        assert_eq!(expected, actual);
    }

    /// Returns a string made of `ch` repeated `len` times.
    fn repeat(ch: char, len: usize) -> String {
        ch.to_string().repeat(len)
    }

    /// Builds the `postgres` / `testdb`, TLS-disabled config used by the URL rendering test,
    /// pointed at `host` on port 5432.
    fn url_config(host: Host) -> Config {
        Config {
            application_name: None,
            database: Database::from_str("testdb").unwrap(),
            endpoint: Endpoint::Network {
                host,
                host_addr: None,
                port: Some(Port(5432)),
            },
            password: None,
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            username: Username::from_str("postgres").unwrap(),
        }
    }

    // ---- ApplicationName ----------------------------------------------------------------

    #[test]
    fn application_name_lt_min_length() {
        let err = ""
            .parse::<ApplicationName>()
            .expect_err("expected min length failure");

        assert_eq!(
            err,
            "ApplicationName byte min length: {min_length} violated, got: {actual}"
        );
    }

    #[test]
    fn application_name_eq_min_length() {
        let input = repeat('a', 1);
        let parsed: ApplicationName = input.parse().expect("expected valid min length value");
        assert_eq!(parsed, ApplicationName(input));
    }

    #[test]
    fn application_name_gt_min_length() {
        let input = repeat('a', 2);
        let parsed: ApplicationName = input
            .parse()
            .expect("expected valid value greater than min");
        assert_eq!(parsed, ApplicationName(input));
    }

    #[test]
    fn application_name_lt_max_length() {
        let input = repeat('a', 62);
        let parsed: ApplicationName = input.parse().expect("expected valid value less than max");
        assert_eq!(parsed, ApplicationName(input));
    }

    #[test]
    fn application_name_eq_max_length() {
        let input = repeat('a', 63);
        let parsed: ApplicationName = input.parse().expect("expected valid value equal to max");
        assert_eq!(parsed, ApplicationName(input));
    }

    #[test]
    fn application_name_gt_max_length() {
        let err = repeat('a', 64)
            .parse::<ApplicationName>()
            .expect_err("expected max length failure");

        assert_eq!(
            err,
            "ApplicationName byte max length: {max_length} violated, got: {actual}"
        );
    }

    #[test]
    fn application_name_contains_nul() {
        let err = "\0"
            .parse::<ApplicationName>()
            .expect_err("expected NUL failure");

        assert_eq!(err, "ApplicationName contains NUL byte");
    }

    // ---- Database -----------------------------------------------------------------------

    #[test]
    fn database_lt_min_length() {
        let err = ""
            .parse::<Database>()
            .expect_err("expected min length failure");

        assert_eq!(
            err,
            "Database byte min length: {min_length} violated, got: {actual}"
        );
    }

    #[test]
    fn database_eq_min_length() {
        let input = repeat('d', 1);
        let parsed: Database = input.parse().expect("expected valid min length value");
        assert_eq!(parsed, Database(input));
    }

    #[test]
    fn database_gt_min_length() {
        let input = repeat('d', 2);
        let parsed: Database = input
            .parse()
            .expect("expected valid value greater than min");
        assert_eq!(parsed, Database(input));
    }

    #[test]
    fn database_lt_max_length() {
        let input = repeat('d', 62);
        let parsed: Database = input.parse().expect("expected valid value less than max");
        assert_eq!(parsed, Database(input));
    }

    #[test]
    fn database_eq_max_length() {
        let input = repeat('d', 63);
        let parsed: Database = input.parse().expect("expected valid value equal to max");
        assert_eq!(parsed, Database(input));
    }

    #[test]
    fn database_gt_max_length() {
        let err = repeat('d', 64)
            .parse::<Database>()
            .expect_err("expected max length failure");

        assert_eq!(
            err,
            "Database byte max length: {max_length} violated, got: {actual}"
        );
    }

    #[test]
    fn database_contains_nul() {
        let err = "\0"
            .parse::<Database>()
            .expect_err("expected NUL failure");

        assert_eq!(err, "Database contains NUL byte");
    }

    // ---- Username -----------------------------------------------------------------------

    #[test]
    fn username_lt_min_length() {
        let err = ""
            .parse::<Username>()
            .expect_err("expected min length failure");

        assert_eq!(
            err,
            "Username byte min length: {min_length} violated, got: {actual}"
        );
    }

    #[test]
    fn username_eq_min_length() {
        let input = repeat('u', 1);
        let parsed: Username = input.parse().expect("expected valid min length value");
        assert_eq!(parsed, Username(input));
    }

    #[test]
    fn username_gt_min_length() {
        let input = repeat('u', 2);
        let parsed: Username = input
            .parse()
            .expect("expected valid value greater than min");
        assert_eq!(parsed, Username(input));
    }

    #[test]
    fn username_lt_max_length() {
        let input = repeat('u', 62);
        let parsed: Username = input.parse().expect("expected valid value less than max");
        assert_eq!(parsed, Username(input));
    }

    #[test]
    fn username_eq_max_length() {
        let input = repeat('u', 63);
        let parsed: Username = input.parse().expect("expected valid value equal to max");
        assert_eq!(parsed, Username(input));
    }

    #[test]
    fn username_gt_max_length() {
        let err = repeat('u', 64)
            .parse::<Username>()
            .expect_err("expected max length failure");

        assert_eq!(
            err,
            "Username byte max length: {max_length} violated, got: {actual}"
        );
    }

    #[test]
    fn username_contains_nul() {
        let err = "\0"
            .parse::<Username>()
            .expect_err("expected NUL failure");

        assert_eq!(err, "Username contains NUL byte");
    }

    // ---- Password (empty is allowed) ----------------------------------------------------

    #[test]
    fn password_eq_min_length() {
        let input = String::new();
        let parsed: Password = input.parse().expect("expected valid min length value");
        assert_eq!(parsed, Password(input));
    }

    #[test]
    fn password_gt_min_length() {
        let input = repeat('p', 1);
        let parsed: Password = input
            .parse()
            .expect("expected valid value greater than min");
        assert_eq!(parsed, Password(input));
    }

    #[test]
    fn password_lt_max_length() {
        let input = repeat('p', 4095);
        let parsed: Password = input.parse().expect("expected valid value less than max");
        assert_eq!(parsed, Password(input));
    }

    #[test]
    fn password_eq_max_length() {
        let input = repeat('p', 4096);
        let parsed: Password = input.parse().expect("expected valid value equal to max");
        assert_eq!(parsed, Password(input));
    }

    #[test]
    fn password_gt_max_length() {
        let err = repeat('p', 4097)
            .parse::<Password>()
            .expect_err("expected max length failure");

        assert_eq!(
            err,
            "Password byte max length: {max_length} violated, got: {actual}"
        );
    }

    #[test]
    fn password_contains_nul() {
        let err = "\0"
            .parse::<Password>()
            .expect_err("expected NUL failure");

        assert_eq!(err, "Password contains NUL byte");
    }

    // ---- Config serialization -----------------------------------------------------------

    #[test]
    fn test_json() {
        use serde_json::json;

        // Baseline: required fields only, network endpoint with an explicit port.
        let base = Config {
            application_name: None,
            database: Database::from_str("some-database").unwrap(),
            endpoint: Endpoint::Network {
                host: Host::from_str("some-host").unwrap(),
                host_addr: None,
                port: Some(Port(5432)),
            },
            password: None,
            ssl_mode: SslMode::VerifyFull,
            ssl_root_cert: None,
            username: Username::from_str("some-username").unwrap(),
        };

        // Unset optional fields are absent from both the JSON and the URL.
        assert_config(
            json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@some-host:5432/some-database?sslmode=verify-full",
                "username": "some-username",
            }),
            &base,
        );

        // Populated optional fields appear in the JSON and in the URL.
        assert_config(
            json!({
                "application_name": "some-app",
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "password": "some-password",
                "ssl_mode": "verify-full",
                "ssl_root_cert": {
                    "file": "/some.pem"
                },
                "url": "postgres://some-username:some-password@some-host:5432/some-database?application_name=some-app&sslmode=verify-full&sslrootcert=%2Fsome.pem",
                "username": "some-username"
            }),
            &Config {
                application_name: Some(ApplicationName::from_str("some-app").unwrap()),
                password: Some(Password::from_str("some-password").unwrap()),
                ssl_root_cert: Some(SslRootCert::File("/some.pem".into())),
                ..base.clone()
            },
        );

        // IPv4 literal as the host.
        assert_config(
            json!({
                "database": "some-database",
                "endpoint": {
                    "host": "127.0.0.1",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@127.0.0.1:5432/some-database?sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::Network {
                    host: Host::from_str("127.0.0.1").unwrap(),
                    host_addr: None,
                    port: Some(Port(5432)),
                },
                ..base.clone()
            },
        );

        // Socket path endpoint: host, dbname and user move into the query string.
        assert_config(
            json!({
                "database": "some-database",
                "endpoint": {
                    "socket_path": "/some/socket",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://?host=%2Fsome%2Fsocket&dbname=some-database&user=some-username&sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::SocketPath("/some/socket".into()),
                ..base.clone()
            },
        );

        // `SslRootCert::System` is rendered as the bare string `system`.
        assert_config(
            json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "ssl_root_cert": "system",
                "url": "postgres://some-username@some-host:5432/some-database?sslmode=verify-full&sslrootcert=system",
                "username": "some-username"
            }),
            &Config {
                ssl_root_cert: Some(SslRootCert::System),
                ..base.clone()
            },
        );

        // `host_addr` is emitted as a `hostaddr` query parameter.
        assert_config(
            json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "192.168.1.100",
                    "port": 5432,
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@some-host:5432/some-database?hostaddr=192.168.1.100&sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::Network {
                    host: Host::from_str("some-host").unwrap(),
                    host_addr: Some("192.168.1.100".parse().unwrap()),
                    port: Some(Port(5432)),
                },
                ..base.clone()
            },
        );

        // No port configured: the port is omitted from both the JSON and the URL.
        assert_config(
            json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@some-host/some-database?sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::Network {
                    host: Host::from_str("some-host").unwrap(),
                    host_addr: None,
                    port: None,
                },
                ..base.clone()
            },
        );

        // `host_addr` without a port.
        assert_config(
            json!({
                "database": "some-database",
                "endpoint": {
                    "host": "some-host",
                    "host_addr": "10.0.0.1",
                },
                "ssl_mode": "verify-full",
                "url": "postgres://some-username@some-host/some-database?hostaddr=10.0.0.1&sslmode=verify-full",
                "username": "some-username"
            }),
            &Config {
                endpoint: Endpoint::Network {
                    host: Host::from_str("some-host").unwrap(),
                    host_addr: Some("10.0.0.1".parse().unwrap()),
                    port: None,
                },
                ..base.clone()
            },
        );
    }

    // ---- URL host rendering -------------------------------------------------------------

    #[test]
    fn test_ipv6_url_formation() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        // (host, expected URL, failure message): IPv6 hosts must be bracketed,
        // IPv4 literals and host names must not be.
        let cases = [
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::LOCALHOST)),
                "postgres://postgres@[::1]:5432/testdb?sslmode=disable",
                "IPv6 loopback address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))),
                "postgres://postgres@[fe80::1]:5432/testdb?sslmode=disable",
                "IPv6 link-local address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V6(Ipv6Addr::new(
                    0x2001, 0x0db8, 0, 0, 0, 0, 0, 1,
                ))),
                "postgres://postgres@[2001:db8::1]:5432/testdb?sslmode=disable",
                "Full IPv6 address should be bracketed in URL",
            ),
            (
                Host::IpAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)),
                "postgres://postgres@127.0.0.1:5432/testdb?sslmode=disable",
                "IPv4 address should NOT be bracketed in URL",
            ),
            (
                Host::from_str("localhost").unwrap(),
                "postgres://postgres@localhost:5432/testdb?sslmode=disable",
                "Hostname should NOT be bracketed in URL",
            ),
        ];

        for (host, expected_url, message) in cases {
            let config = url_config(host);
            assert_eq!(config.to_url().to_string(), expected_url, "{message}");
        }
    }
}