Skip to main content

pg_client/
config.rs

1use crate::identifier::{Database, Role, User};
2
/// Generates the shared API of a string-wrapping newtype `$struct(String)`.
///
/// The expansion provides:
/// - `MIN_LENGTH` / `MAX_LENGTH` associated constants set to `$min` / `$max`
///   (inclusive bounds, measured in bytes),
/// - a `std::str::FromStr` impl that rejects values outside those bounds or
///   containing a NUL byte, reporting the failure as a `String`,
/// - an `AsRef<str>` impl and an inherent `as_str` accessor.
macro_rules! from_str_impl {
    ($struct: ident, $min: expr, $max: expr) => {
        impl std::str::FromStr for $struct {
            type Err = String;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                let min_length = Self::MIN_LENGTH;
                let max_length = Self::MAX_LENGTH;
                // `str::len` is a byte count, so multi-byte characters weigh more than one.
                let actual = value.len();

                if actual < min_length {
                    return Err(format!(
                        "{} byte min length: {min_length} violated, got: {actual}",
                        stringify!($struct)
                    ));
                }

                if actual > max_length {
                    return Err(format!(
                        "{} byte max length: {max_length} violated, got: {actual}",
                        stringify!($struct)
                    ));
                }

                if value.as_bytes().contains(&0) {
                    return Err(format!("{} contains NUL byte", stringify!($struct)));
                }

                Ok(Self(value.to_owned()))
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                &self.0
            }
        }

        impl $struct {
            /// Smallest accepted length in bytes (inclusive).
            pub const MIN_LENGTH: usize = $min;
            /// Largest accepted length in bytes (inclusive).
            pub const MAX_LENGTH: usize = $max;

            /// Borrows the wrapped string.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }
    };
}
48
49#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
50pub struct HostName(String);
51
52impl HostName {
53    #[must_use]
54    pub fn as_str(&self) -> &str {
55        &self.0
56    }
57}
58
59impl std::str::FromStr for HostName {
60    type Err = &'static str;
61
62    fn from_str(value: &str) -> Result<Self, Self::Err> {
63        if hostname_validator::is_valid(value) {
64            Ok(Self(value.to_string()))
65        } else {
66            Err("invalid host name")
67        }
68    }
69}
70
71impl<'de> serde::Deserialize<'de> for HostName {
72    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
73    where
74        D: serde::Deserializer<'de>,
75    {
76        let s = String::deserialize(deserializer)?;
77        s.parse().map_err(serde::de::Error::custom)
78    }
79}
80
81#[derive(Clone, Debug, PartialEq, Eq)]
82pub enum Host {
83    HostName(HostName),
84    IpAddr(std::net::IpAddr),
85}
86
87impl serde::Serialize for Host {
88    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
89        serializer.serialize_str(&self.pg_env_value())
90    }
91}
92
93impl Host {
94    pub(crate) fn pg_env_value(&self) -> String {
95        match self {
96            Self::HostName(value) => value.0.clone(),
97            Self::IpAddr(value) => value.to_string(),
98        }
99    }
100}
101
102impl std::str::FromStr for Host {
103    type Err = &'static str;
104
105    fn from_str(value: &str) -> Result<Self, Self::Err> {
106        match std::net::IpAddr::from_str(value) {
107            Ok(addr) => Ok(Self::IpAddr(addr)),
108            Err(_) => match HostName::from_str(value) {
109                Ok(host_name) => Ok(Self::HostName(host_name)),
110                Err(_) => Err("Not a socket address or FQDN"),
111            },
112        }
113    }
114}
115
116impl From<HostName> for Host {
117    fn from(value: HostName) -> Self {
118        Self::HostName(value)
119    }
120}
121
122impl From<std::net::IpAddr> for Host {
123    fn from(value: std::net::IpAddr) -> Self {
124        Self::IpAddr(value)
125    }
126}
127
/// Newtype around `std::net::IpAddr`, convertible to and from the raw address
/// and parseable from / printable as its textual form.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(std::net::IpAddr);

impl HostAddr {
    /// Wraps an IP address.
    #[must_use]
    pub const fn new(ip: std::net::IpAddr) -> Self {
        Self(ip)
    }
}

impl From<std::net::IpAddr> for HostAddr {
    /// Wraps an IP address, equivalent to [`HostAddr::new`].
    ///
    /// # Example
    /// ```
    /// use pg_client::config::HostAddr;
    /// use std::net::IpAddr;
    ///
    /// let ip: IpAddr = "192.168.1.1".parse().unwrap();
    /// let host_addr = HostAddr::from(ip);
    /// assert_eq!(IpAddr::from(host_addr).to_string(), "192.168.1.1");
    /// ```
    fn from(value: std::net::IpAddr) -> Self {
        Self::new(value)
    }
}

impl From<HostAddr> for std::net::IpAddr {
    /// Unwraps the address, consuming the `HostAddr`.
    fn from(value: HostAddr) -> Self {
        let HostAddr(ip) = value;
        ip
    }
}

impl From<&HostAddr> for std::net::IpAddr {
    /// Copies the address out of a borrowed `HostAddr`.
    fn from(value: &HostAddr) -> Self {
        let HostAddr(ip) = value;
        *ip
    }
}

impl std::fmt::Display for HostAddr {
    /// Writes the textual form of the wrapped IP address.
    ///
    /// # Example
    /// ```
    /// use pg_client::config::HostAddr;
    ///
    /// let host_addr: HostAddr = "10.0.0.1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "10.0.0.1");
    /// ```
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_fmt(format_args!("{}", self.0))
    }
}

impl std::str::FromStr for HostAddr {
    type Err = &'static str;

    /// Parses an IPv4 or IPv6 address.
    ///
    /// # Example
    /// ```
    /// use pg_client::config::HostAddr;
    /// use std::str::FromStr;
    ///
    /// let host_addr = HostAddr::from_str("127.0.0.1").unwrap();
    /// assert_eq!(host_addr.to_string(), "127.0.0.1");
    ///
    /// // Also works with the parse method
    /// let host_addr: HostAddr = "::1".parse().unwrap();
    /// assert_eq!(host_addr.to_string(), "::1");
    ///
    /// // Invalid IP addresses return an error
    /// assert!(HostAddr::from_str("not-an-ip").is_err());
    /// ```
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        value
            .parse::<std::net::IpAddr>()
            .map(Self)
            .map_err(|_| "invalid IP address")
    }
}
203
204#[derive(Clone, Debug, PartialEq, Eq)]
205pub enum Endpoint {
206    Network {
207        host: Host,
208        channel_binding: Option<ChannelBinding>,
209        host_addr: Option<HostAddr>,
210        port: Option<Port>,
211    },
212    SocketPath(std::path::PathBuf),
213}
214
215impl serde::Serialize for Endpoint {
216    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
217        use serde::ser::SerializeStruct;
218        match self {
219            Self::Network {
220                host,
221                channel_binding,
222                host_addr,
223                port,
224            } => {
225                let mut state = serializer.serialize_struct("Endpoint", 4)?;
226                state.serialize_field("host", host)?;
227                if let Some(channel_binding) = channel_binding {
228                    state.serialize_field("channel_binding", channel_binding)?;
229                }
230                if let Some(addr) = host_addr {
231                    state.serialize_field("host_addr", &addr.to_string())?;
232                }
233                if let Some(port) = port {
234                    state.serialize_field("port", port)?;
235                }
236                state.end()
237            }
238            Self::SocketPath(path) => {
239                let mut state = serializer.serialize_struct("Endpoint", 1)?;
240                state.serialize_field(
241                    "socket_path",
242                    &path.to_str().expect("socket path contains invalid utf8"),
243                )?;
244                state.end()
245            }
246        }
247    }
248}
249
250#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
251pub struct Port(u16);
252
253impl Port {
254    #[must_use]
255    pub const fn new(port: u16) -> Self {
256        Self(port)
257    }
258
259    pub(crate) fn pg_env_value(self) -> String {
260        self.0.to_string()
261    }
262}
263
264impl std::str::FromStr for Port {
265    type Err = &'static str;
266
267    fn from_str(value: &str) -> Result<Self, Self::Err> {
268        match <u16 as std::str::FromStr>::from_str(value) {
269            Ok(port) => Ok(Port(port)),
270            Err(_) => Err("invalid postgresql port string"),
271        }
272    }
273}
274
275impl From<u16> for Port {
276    fn from(port: u16) -> Self {
277        Self(port)
278    }
279}
280
281impl From<Port> for u16 {
282    fn from(port: Port) -> Self {
283        port.0
284    }
285}
286
287impl From<&Port> for u16 {
288    fn from(port: &Port) -> Self {
289        port.0
290    }
291}
292
293#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
294pub struct ApplicationName(String);
295
296from_str_impl!(ApplicationName, 1, 63);
297
298impl ApplicationName {
299    pub(crate) fn pg_env_value(&self) -> String {
300        self.0.clone()
301    }
302}
303
304impl Database {
305    pub(crate) fn pg_env_value(&self) -> String {
306        self.as_str().to_owned()
307    }
308}
309
310impl Role {
311    pub(crate) fn pg_env_value(&self) -> String {
312        self.as_str().to_owned()
313    }
314}
315
316#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
317pub struct Password(String);
318
319from_str_impl!(Password, 0, 4096);
320
321impl Password {
322    pub(crate) fn pg_env_value(&self) -> String {
323        self.0.clone()
324    }
325}
326
327#[derive(
328    Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
329)]
330#[serde(rename_all = "kebab-case")]
331#[strum(serialize_all = "kebab-case")]
332pub enum SslMode {
333    Allow,
334    Disable,
335    Prefer,
336    Require,
337    VerifyCa,
338    VerifyFull,
339}
340
341impl SslMode {
342    #[must_use]
343    pub fn as_str(&self) -> &'static str {
344        self.into()
345    }
346
347    pub(crate) fn pg_env_value(&self) -> String {
348        self.as_str().to_string()
349    }
350}
351
352#[derive(
353    Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
354)]
355#[serde(rename_all = "kebab-case")]
356#[strum(serialize_all = "kebab-case")]
357pub enum ChannelBinding {
358    Disable,
359    Prefer,
360    Require,
361}
362
363impl ChannelBinding {
364    #[must_use]
365    pub fn as_str(&self) -> &'static str {
366        self.into()
367    }
368
369    pub(crate) fn pg_env_value(&self) -> String {
370        self.as_str().to_string()
371    }
372}
373
374#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
375#[serde(rename_all = "kebab-case")]
376pub enum SslRootCert {
377    File(std::path::PathBuf),
378    System,
379}
380
381impl SslRootCert {
382    pub(crate) fn pg_env_value(&self) -> String {
383        match self {
384            Self::File(path) => path.to_str().unwrap().to_string(),
385            Self::System => "system".to_string(),
386        }
387    }
388}
389
390impl From<std::path::PathBuf> for SslRootCert {
391    fn from(value: std::path::PathBuf) -> Self {
392        Self::File(value)
393    }
394}
395
396/// Session parameters sent during PostgreSQL connection setup.
397///
398/// These are independent of how the connection is established (TCP, Unix socket, etc.)
399/// and represent what the client identifies as during the startup message.
400#[derive(Clone, Debug, PartialEq, Eq)]
401pub struct Session {
402    pub application_name: Option<ApplicationName>,
403    pub database: Database,
404    pub password: Option<Password>,
405    pub user: User,
406}
407
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a string made of `char` repeated `len` times.
    fn repeat(char: char, len: usize) -> String {
        char.to_string().repeat(len)
    }

    // --- ApplicationName: bounds are 1..=63 bytes ---

    #[test]
    fn application_name_lt_min_length() {
        let error = ""
            .parse::<ApplicationName>()
            .expect_err("expected min length failure");

        assert_eq!(error, "ApplicationName byte min length: 1 violated, got: 0");
    }

    #[test]
    fn application_name_eq_min_length() {
        let input = repeat('a', 1);

        let parsed = input
            .parse::<ApplicationName>()
            .expect("expected valid min length value");

        assert_eq!(parsed, ApplicationName(input));
    }

    #[test]
    fn application_name_gt_min_length() {
        let input = repeat('a', 2);

        let parsed = input
            .parse::<ApplicationName>()
            .expect("expected valid value greater than min");

        assert_eq!(parsed, ApplicationName(input));
    }

    #[test]
    fn application_name_lt_max_length() {
        let input = repeat('a', 62);

        let parsed = input
            .parse::<ApplicationName>()
            .expect("expected valid value less than max");

        assert_eq!(parsed, ApplicationName(input));
    }

    #[test]
    fn application_name_eq_max_length() {
        let input = repeat('a', 63);

        let parsed = input
            .parse::<ApplicationName>()
            .expect("expected valid value equal to max");

        assert_eq!(parsed, ApplicationName(input));
    }

    #[test]
    fn application_name_gt_max_length() {
        let input = repeat('a', 64);

        let error = input
            .parse::<ApplicationName>()
            .expect_err("expected max length failure");

        assert_eq!(error, "ApplicationName byte max length: 63 violated, got: 64");
    }

    #[test]
    fn application_name_contains_nul() {
        let error = "\0"
            .parse::<ApplicationName>()
            .expect_err("expected NUL failure");

        assert_eq!(error, "ApplicationName contains NUL byte");
    }

    // --- Password: bounds are 0..=4096 bytes ---

    #[test]
    fn password_eq_min_length() {
        let input = String::new();

        let parsed = input
            .parse::<Password>()
            .expect("expected valid min length value");

        assert_eq!(parsed, Password(input));
    }

    #[test]
    fn password_gt_min_length() {
        let input = repeat('p', 1);

        let parsed = input
            .parse::<Password>()
            .expect("expected valid value greater than min");

        assert_eq!(parsed, Password(input));
    }

    #[test]
    fn password_lt_max_length() {
        let input = repeat('p', 4095);

        let parsed = input
            .parse::<Password>()
            .expect("expected valid value less than max");

        assert_eq!(parsed, Password(input));
    }

    #[test]
    fn password_eq_max_length() {
        let input = repeat('p', 4096);

        let parsed = input
            .parse::<Password>()
            .expect("expected valid value equal to max");

        assert_eq!(parsed, Password(input));
    }

    #[test]
    fn password_gt_max_length() {
        let input = repeat('p', 4097);

        let error = input
            .parse::<Password>()
            .expect_err("expected max length failure");

        assert_eq!(error, "Password byte max length: 4096 violated, got: 4097");
    }

    #[test]
    fn password_contains_nul() {
        let error = "\0".parse::<Password>().expect_err("expected NUL failure");

        assert_eq!(error, "Password contains NUL byte");
    }
}