Skip to main content

pg_client/
config.rs

1use crate::identifier::{Database, Role, User};
2
/// Macro to generate `std::str::FromStr` plus helpers for string wrapped newtypes,
/// along with a typed parse-error enum.
///
/// Invoked as `from_str_impl!(Type, ErrorType, MIN, MAX)`, where `Type` is a tuple
/// struct around a single `String`. Lengths are byte lengths (`str::len`), not
/// character counts, and both bounds are inclusive.
///
/// The expansion provides:
/// - `ErrorType`, a `thiserror` enum with `TooShort`, `TooLong` and `ContainsNul`;
/// - `impl FromStr for Type`, which checks the minimum length, then the maximum
///   length, then rejects any NUL byte, returning the first violation found;
/// - `impl AsRef<str> for Type`;
/// - `Type::MIN_LENGTH`, `Type::MAX_LENGTH` and `Type::as_str`.
macro_rules! from_str_impl {
    ($struct: ident, $err: ident, $min: expr, $max: expr) => {
        #[doc = concat!(
            "Error returned when a string cannot be parsed into [`",
            stringify!($struct),
            "`]."
        )]
        #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
        pub enum $err {
            /// The input is shorter than `MIN_LENGTH` bytes.
            #[error("{} byte min length: {min} violated, got: {actual}", stringify!($struct))]
            TooShort { min: usize, actual: usize },
            /// The input is longer than `MAX_LENGTH` bytes.
            #[error("{} byte max length: {max} violated, got: {actual}", stringify!($struct))]
            TooLong { max: usize, actual: usize },
            /// The input contains a NUL (`\0`) byte.
            #[error("{} contains NUL byte", stringify!($struct))]
            ContainsNul,
        }

        impl std::str::FromStr for $struct {
            type Err = $err;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                // Byte length, matching the "byte" wording of the error messages.
                let actual = value.len();

                if actual < Self::MIN_LENGTH {
                    Err($err::TooShort {
                        min: Self::MIN_LENGTH,
                        actual,
                    })
                } else if actual > Self::MAX_LENGTH {
                    Err($err::TooLong {
                        max: Self::MAX_LENGTH,
                        actual,
                    })
                } else if value.as_bytes().contains(&0) {
                    Err($err::ContainsNul)
                } else {
                    Ok(Self(value.to_string()))
                }
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                &self.0
            }
        }

        impl $struct {
            /// Minimum accepted length in bytes (inclusive).
            pub const MIN_LENGTH: usize = $min;
            /// Maximum accepted length in bytes (inclusive).
            pub const MAX_LENGTH: usize = $max;

            /// Borrows the validated value as a string slice.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }
    };
}
57
58#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
59pub struct HostName(String);
60
61impl HostName {
62    #[must_use]
63    pub fn as_str(&self) -> &str {
64        &self.0
65    }
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
69#[error("invalid host name")]
70pub struct HostNameParseError;
71
72impl std::str::FromStr for HostName {
73    type Err = HostNameParseError;
74
75    fn from_str(value: &str) -> Result<Self, Self::Err> {
76        if hostname_validator::is_valid(value) {
77            Ok(Self(value.to_string()))
78        } else {
79            Err(HostNameParseError)
80        }
81    }
82}
83
84impl<'de> serde::Deserialize<'de> for HostName {
85    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
86    where
87        D: serde::Deserializer<'de>,
88    {
89        let s = String::deserialize(deserializer)?;
90        s.parse().map_err(serde::de::Error::custom)
91    }
92}
93
94#[derive(Clone, Debug, PartialEq, Eq)]
95pub enum Host {
96    HostName(HostName),
97    IpAddr(std::net::IpAddr),
98}
99
100impl serde::Serialize for Host {
101    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
102        serializer.serialize_str(&self.pg_env_value())
103    }
104}
105
106impl Host {
107    pub(crate) fn pg_env_value(&self) -> String {
108        match self {
109            Self::HostName(value) => value.0.clone(),
110            Self::IpAddr(value) => value.to_string(),
111        }
112    }
113}
114
115#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
116#[error("Not a socket address or FQDN")]
117pub struct HostParseError;
118
119impl std::str::FromStr for Host {
120    type Err = HostParseError;
121
122    fn from_str(value: &str) -> Result<Self, Self::Err> {
123        match std::net::IpAddr::from_str(value) {
124            Ok(addr) => Ok(Self::IpAddr(addr)),
125            Err(_) => match HostName::from_str(value) {
126                Ok(host_name) => Ok(Self::HostName(host_name)),
127                Err(_) => Err(HostParseError),
128            },
129        }
130    }
131}
132
133impl From<HostName> for Host {
134    fn from(value: HostName) -> Self {
135        Self::HostName(value)
136    }
137}
138
139impl From<std::net::IpAddr> for Host {
140    fn from(value: std::net::IpAddr) -> Self {
141        Self::IpAddr(value)
142    }
143}
144
/// A literal IP address, used as the optional `host_addr` of [`Endpoint::Network`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(std::net::IpAddr);

impl HostAddr {
    /// Wraps `ip` as-is.
    #[must_use]
    pub const fn new(ip: std::net::IpAddr) -> Self {
        Self(ip)
    }
}

impl From<std::net::IpAddr> for HostAddr {
    /// # Example
    /// ```
    /// use pg_client::config::HostAddr;
    /// use std::net::IpAddr;
    ///
    /// let ip: IpAddr = "192.168.1.1".parse().unwrap();
    /// let host_addr = HostAddr::from(ip);
    /// assert_eq!(IpAddr::from(host_addr).to_string(), "192.168.1.1");
    /// ```
    fn from(ip: std::net::IpAddr) -> Self {
        Self::new(ip)
    }
}

impl From<HostAddr> for std::net::IpAddr {
    fn from(HostAddr(ip): HostAddr) -> Self {
        ip
    }
}

impl From<&HostAddr> for std::net::IpAddr {
    fn from(addr: &HostAddr) -> Self {
        let &HostAddr(ip) = addr;
        ip
    }
}

impl std::fmt::Display for HostAddr {
    /// # Example
    /// ```
    /// use pg_client::config::HostAddr;
    ///
    /// let host_addr: HostAddr = "10.0.0.1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "10.0.0.1");
    /// ```
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formats through a fresh `write!`, so outer width/fill flags are not applied.
        let Self(ip) = self;
        write!(formatter, "{ip}")
    }
}
194
195#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
196#[error("invalid IP address")]
197pub struct HostAddrParseError;
198
199impl std::str::FromStr for HostAddr {
200    type Err = HostAddrParseError;
201
202    /// # Example
203    /// ```
204    /// use pg_client::config::HostAddr;
205    /// use std::str::FromStr;
206    ///
207    /// let host_addr = HostAddr::from_str("127.0.0.1").unwrap();
208    /// assert_eq!(host_addr.to_string(), "127.0.0.1");
209    ///
210    /// // Also works with the parse method
211    /// let host_addr: HostAddr = "::1".parse().unwrap();
212    /// assert_eq!(host_addr.to_string(), "::1");
213    ///
214    /// // Invalid IP addresses return an error
215    /// assert!(HostAddr::from_str("not-an-ip").is_err());
216    /// ```
217    fn from_str(value: &str) -> Result<Self, Self::Err> {
218        match std::net::IpAddr::from_str(value) {
219            Ok(addr) => Ok(Self(addr)),
220            Err(_) => Err(HostAddrParseError),
221        }
222    }
223}
224
225#[derive(Clone, Debug, PartialEq, Eq)]
226pub enum Endpoint {
227    Network {
228        host: Host,
229        channel_binding: Option<ChannelBinding>,
230        host_addr: Option<HostAddr>,
231        port: Option<Port>,
232    },
233    SocketPath(std::path::PathBuf),
234}
235
236impl serde::Serialize for Endpoint {
237    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
238        use serde::ser::SerializeStruct;
239        match self {
240            Self::Network {
241                host,
242                channel_binding,
243                host_addr,
244                port,
245            } => {
246                let mut state = serializer.serialize_struct("Endpoint", 4)?;
247                state.serialize_field("host", host)?;
248                if let Some(channel_binding) = channel_binding {
249                    state.serialize_field("channel_binding", channel_binding)?;
250                }
251                if let Some(addr) = host_addr {
252                    state.serialize_field("host_addr", &addr.to_string())?;
253                }
254                if let Some(port) = port {
255                    state.serialize_field("port", port)?;
256                }
257                state.end()
258            }
259            Self::SocketPath(path) => {
260                let mut state = serializer.serialize_struct("Endpoint", 1)?;
261                state.serialize_field(
262                    "socket_path",
263                    &path.to_str().expect("socket path contains invalid utf8"),
264                )?;
265                state.end()
266            }
267        }
268    }
269}
270
271#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
272pub struct Port(u16);
273
274impl Port {
275    #[must_use]
276    pub const fn new(port: u16) -> Self {
277        Self(port)
278    }
279
280    pub(crate) fn pg_env_value(self) -> String {
281        self.0.to_string()
282    }
283}
284
285#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
286#[error("invalid postgresql port string")]
287pub struct PortParseError;
288
289impl std::str::FromStr for Port {
290    type Err = PortParseError;
291
292    fn from_str(value: &str) -> Result<Self, Self::Err> {
293        match <u16 as std::str::FromStr>::from_str(value) {
294            Ok(port) => Ok(Port(port)),
295            Err(_) => Err(PortParseError),
296        }
297    }
298}
299
300impl From<u16> for Port {
301    fn from(port: u16) -> Self {
302        Self(port)
303    }
304}
305
306impl From<Port> for u16 {
307    fn from(port: Port) -> Self {
308        port.0
309    }
310}
311
312impl From<&Port> for u16 {
313    fn from(port: &Port) -> Self {
314        port.0
315    }
316}
317
318#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
319pub struct ApplicationName(String);
320
321from_str_impl!(ApplicationName, ApplicationNameParseError, 1, 63);
322
323impl ApplicationName {
324    pub(crate) fn pg_env_value(&self) -> String {
325        self.0.clone()
326    }
327}
328
329impl Database {
330    pub(crate) fn pg_env_value(&self) -> String {
331        self.as_str().to_owned()
332    }
333}
334
335impl Role {
336    pub(crate) fn pg_env_value(&self) -> String {
337        self.as_str().to_owned()
338    }
339}
340
341#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
342pub struct Password(String);
343
344from_str_impl!(Password, PasswordParseError, 0, 4096);
345
346impl Password {
347    pub(crate) fn pg_env_value(&self) -> String {
348        self.0.clone()
349    }
350}
351
352#[derive(
353    Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
354)]
355#[serde(rename_all = "kebab-case")]
356#[strum(serialize_all = "kebab-case")]
357pub enum SslMode {
358    Allow,
359    Disable,
360    Prefer,
361    Require,
362    VerifyCa,
363    VerifyFull,
364}
365
366impl SslMode {
367    #[must_use]
368    pub fn as_str(&self) -> &'static str {
369        self.into()
370    }
371
372    pub(crate) fn pg_env_value(&self) -> String {
373        self.as_str().to_string()
374    }
375}
376
377#[derive(
378    Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
379)]
380#[serde(rename_all = "kebab-case")]
381#[strum(serialize_all = "kebab-case")]
382pub enum ChannelBinding {
383    Disable,
384    Prefer,
385    Require,
386}
387
388impl ChannelBinding {
389    #[must_use]
390    pub fn as_str(&self) -> &'static str {
391        self.into()
392    }
393
394    pub(crate) fn pg_env_value(&self) -> String {
395        self.as_str().to_string()
396    }
397}
398
399#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
400#[serde(rename_all = "kebab-case")]
401pub enum SslRootCert {
402    File(std::path::PathBuf),
403    System,
404}
405
406impl SslRootCert {
407    pub(crate) fn pg_env_value(&self) -> String {
408        match self {
409            Self::File(path) => path.to_str().unwrap().to_string(),
410            Self::System => "system".to_string(),
411        }
412    }
413}
414
415impl From<std::path::PathBuf> for SslRootCert {
416    fn from(value: std::path::PathBuf) -> Self {
417        Self::File(value)
418    }
419}
420
421/// Session parameters sent during PostgreSQL connection setup.
422///
423/// These are independent of how the connection is established (TCP, Unix socket, etc.)
424/// and represent what the client identifies as during the startup message.
425#[derive(Clone, Debug, PartialEq, Eq)]
426pub struct Session {
427    pub application_name: Option<ApplicationName>,
428    pub database: Database,
429    pub password: Option<Password>,
430    pub user: User,
431}
432
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    fn repeat(char: char, len: usize) -> String {
        std::iter::repeat_n(char, len).collect()
    }

    #[test]
    fn application_name_lt_min_length() {
        let value = String::new();

        let err = ApplicationName::from_str(&value).expect_err("expected min length failure");

        assert_eq!(
            err,
            ApplicationNameParseError::TooShort { min: 1, actual: 0 },
        );
        assert_eq!(
            err.to_string(),
            "ApplicationName byte min length: 1 violated, got: 0",
        );
    }

    #[test]
    fn application_name_eq_min_length() {
        let value = repeat('a', 1);

        let application_name =
            ApplicationName::from_str(&value).expect("expected valid min length value");

        assert_eq!(application_name, ApplicationName(value));
    }

    #[test]
    fn application_name_gt_min_length() {
        let value = repeat('a', 2);

        let application_name =
            ApplicationName::from_str(&value).expect("expected valid value greater than min");

        assert_eq!(application_name, ApplicationName(value));
    }

    #[test]
    fn application_name_lt_max_length() {
        let value = repeat('a', 62);

        let application_name =
            ApplicationName::from_str(&value).expect("expected valid value less than max");

        assert_eq!(application_name, ApplicationName(value));
    }

    #[test]
    fn application_name_eq_max_length() {
        let value = repeat('a', 63);

        let application_name =
            ApplicationName::from_str(&value).expect("expected valid value equal to max");

        assert_eq!(application_name, ApplicationName(value));
    }

    #[test]
    fn application_name_gt_max_length() {
        let value = repeat('a', 64);

        let err = ApplicationName::from_str(&value).expect_err("expected max length failure");

        assert_eq!(
            err,
            ApplicationNameParseError::TooLong {
                max: 63,
                actual: 64,
            },
        );
        assert_eq!(
            err.to_string(),
            "ApplicationName byte max length: 63 violated, got: 64",
        );
    }

    #[test]
    fn application_name_max_length_counts_bytes_not_chars() {
        // U+00E9 is two bytes in UTF-8: 32 characters are 64 bytes, one over the limit.
        let value = repeat('\u{e9}', 32);

        let err = ApplicationName::from_str(&value).expect_err("expected max length failure");

        assert_eq!(
            err,
            ApplicationNameParseError::TooLong {
                max: 63,
                actual: 64,
            },
        );
    }

    #[test]
    fn application_name_contains_nul() {
        let value = String::from('\0');

        let err = ApplicationName::from_str(&value).expect_err("expected NUL failure");

        assert_eq!(err, ApplicationNameParseError::ContainsNul);
        assert_eq!(err.to_string(), "ApplicationName contains NUL byte");
    }

    #[test]
    fn application_name_contains_interior_nul() {
        let value = String::from("a\0b");

        let err = ApplicationName::from_str(&value).expect_err("expected NUL failure");

        assert_eq!(err, ApplicationNameParseError::ContainsNul);
    }

    #[test]
    fn password_eq_min_length() {
        let value = String::new();

        let password = Password::from_str(&value).expect("expected valid min length value");

        assert_eq!(password, Password(value));
    }

    #[test]
    fn password_gt_min_length() {
        let value = repeat('p', 1);

        let password = Password::from_str(&value).expect("expected valid value greater than min");

        assert_eq!(password, Password(value));
    }

    #[test]
    fn password_lt_max_length() {
        let value = repeat('p', 4095);

        let password = Password::from_str(&value).expect("expected valid value less than max");

        assert_eq!(password, Password(value));
    }

    #[test]
    fn password_eq_max_length() {
        let value = repeat('p', 4096);

        let password = Password::from_str(&value).expect("expected valid value equal to max");

        assert_eq!(password, Password(value));
    }

    #[test]
    fn password_gt_max_length() {
        let value = repeat('p', 4097);

        let err = Password::from_str(&value).expect_err("expected max length failure");

        assert_eq!(
            err,
            PasswordParseError::TooLong {
                max: 4096,
                actual: 4097,
            },
        );
        assert_eq!(
            err.to_string(),
            "Password byte max length: 4096 violated, got: 4097",
        );
    }

    #[test]
    fn password_contains_nul() {
        let value = String::from('\0');

        let err = Password::from_str(&value).expect_err("expected NUL failure");

        assert_eq!(err, PasswordParseError::ContainsNul);
        assert_eq!(err.to_string(), "Password contains NUL byte");
    }

    #[test]
    fn port_parse_valid() {
        assert_eq!(Port::from_str("5432"), Ok(Port::new(5432)));
        assert_eq!(Port::from_str("65535"), Ok(Port::new(u16::MAX)));
        assert_eq!(Port::new(5432).pg_env_value(), "5432");
    }

    #[test]
    fn port_parse_invalid() {
        for value in ["", "-1", "65536", "54 32", "port"] {
            assert_eq!(Port::from_str(value), Err(PortParseError), "input: {value:?}");
        }
    }

    #[test]
    fn host_addr_parse() {
        assert_eq!(
            HostAddr::from_str("::1"),
            Ok(HostAddr::new(std::net::IpAddr::V6(
                std::net::Ipv6Addr::LOCALHOST
            ))),
        );
        assert_eq!(HostAddr::from_str("localhost"), Err(HostAddrParseError));
    }

    #[test]
    fn host_parse_prefers_ip_address() {
        let host = Host::from_str("127.0.0.1").expect("expected valid IP host");

        assert_eq!(
            host,
            Host::IpAddr(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)),
        );
        assert_eq!(host.pg_env_value(), "127.0.0.1");
    }

    #[test]
    fn host_parse_host_name() {
        let host = Host::from_str("localhost").expect("expected valid host name");

        assert_eq!(host, Host::HostName(HostName(String::from("localhost"))));
        assert_eq!(host.pg_env_value(), "localhost");
    }
}