1use crate::identifier::{Database, Role, User};
2
/// Generates parsing and accessor boilerplate for a tuple struct wrapping a
/// single `String` whose byte length must lie within `[$min, $max]` and which
/// must not contain a NUL byte.
///
/// Expands to:
/// - a `$err` error enum (`TooShort`, `TooLong`, `ContainsNul`) deriving `thiserror::Error`,
/// - `FromStr for $struct`, which checks the length bounds before checking for NUL,
/// - `AsRef<str> for $struct`,
/// - `MIN_LENGTH` / `MAX_LENGTH` associated constants and an `as_str` accessor.
macro_rules! from_str_impl {
    ($struct: ident, $err: ident, $min: expr, $max: expr) => {
        #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
        pub enum $err {
            /// Input is shorter than `MIN_LENGTH` bytes.
            #[error("{} byte min length: {min} violated, got: {actual}", stringify!($struct))]
            TooShort { min: usize, actual: usize },
            /// Input is longer than `MAX_LENGTH` bytes.
            #[error("{} byte max length: {max} violated, got: {actual}", stringify!($struct))]
            TooLong { max: usize, actual: usize },
            /// Input contains a `\0` byte.
            #[error("{} contains NUL byte", stringify!($struct))]
            ContainsNul,
        }

        impl std::str::FromStr for $struct {
            type Err = $err;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                // Lengths are measured in UTF-8 bytes, not chars.
                let actual = value.len();

                if actual < Self::MIN_LENGTH {
                    Err($err::TooShort {
                        min: Self::MIN_LENGTH,
                        actual,
                    })
                } else if actual > Self::MAX_LENGTH {
                    Err($err::TooLong {
                        max: Self::MAX_LENGTH,
                        actual,
                    })
                } else if value.contains('\0') {
                    // NOTE(review): presumably rejected because these values are
                    // handed on as environment values (see the `pg_env_value`
                    // helpers), which cannot contain NUL.
                    Err($err::ContainsNul)
                } else {
                    Ok(Self(value.to_string()))
                }
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                &self.0
            }
        }

        impl $struct {
            /// Minimum accepted length in bytes.
            pub const MIN_LENGTH: usize = $min;
            /// Maximum accepted length in bytes.
            pub const MAX_LENGTH: usize = $max;

            /// Borrows the validated value as a string slice.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }
    };
}
57
58#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
59pub struct HostName(String);
60
61impl HostName {
62 #[must_use]
63 pub fn as_str(&self) -> &str {
64 &self.0
65 }
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
69#[error("invalid host name")]
70pub struct HostNameParseError;
71
72impl std::str::FromStr for HostName {
73 type Err = HostNameParseError;
74
75 fn from_str(value: &str) -> Result<Self, Self::Err> {
76 if hostname_validator::is_valid(value) {
77 Ok(Self(value.to_string()))
78 } else {
79 Err(HostNameParseError)
80 }
81 }
82}
83
84impl<'de> serde::Deserialize<'de> for HostName {
85 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
86 where
87 D: serde::Deserializer<'de>,
88 {
89 let s = String::deserialize(deserializer)?;
90 s.parse().map_err(serde::de::Error::custom)
91 }
92}
93
94#[derive(Clone, Debug, PartialEq, Eq)]
95pub enum Host {
96 HostName(HostName),
97 IpAddr(std::net::IpAddr),
98}
99
100impl serde::Serialize for Host {
101 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
102 serializer.serialize_str(&self.pg_env_value())
103 }
104}
105
106impl Host {
107 pub(crate) fn pg_env_value(&self) -> String {
108 match self {
109 Self::HostName(value) => value.0.clone(),
110 Self::IpAddr(value) => value.to_string(),
111 }
112 }
113}
114
115#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
116#[error("Not a socket address or FQDN")]
117pub struct HostParseError;
118
119impl std::str::FromStr for Host {
120 type Err = HostParseError;
121
122 fn from_str(value: &str) -> Result<Self, Self::Err> {
123 match std::net::IpAddr::from_str(value) {
124 Ok(addr) => Ok(Self::IpAddr(addr)),
125 Err(_) => match HostName::from_str(value) {
126 Ok(host_name) => Ok(Self::HostName(host_name)),
127 Err(_) => Err(HostParseError),
128 },
129 }
130 }
131}
132
133impl From<HostName> for Host {
134 fn from(value: HostName) -> Self {
135 Self::HostName(value)
136 }
137}
138
139impl From<std::net::IpAddr> for Host {
140 fn from(value: std::net::IpAddr) -> Self {
141 Self::IpAddr(value)
142 }
143}
144
/// An explicit IP address for a network endpoint.
///
/// Thin wrapper around [`std::net::IpAddr`], convertible to and from it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(std::net::IpAddr);

impl HostAddr {
    /// Wraps `ip`; every IP address is accepted.
    #[must_use]
    pub const fn new(ip: std::net::IpAddr) -> Self {
        Self(ip)
    }
}

impl From<std::net::IpAddr> for HostAddr {
    fn from(value: std::net::IpAddr) -> Self {
        Self(value)
    }
}

impl From<HostAddr> for std::net::IpAddr {
    fn from(value: HostAddr) -> Self {
        value.0
    }
}

impl From<&HostAddr> for std::net::IpAddr {
    fn from(value: &HostAddr) -> Self {
        value.0
    }
}

impl std::fmt::Display for HostAddr {
    /// Formats the wrapped address exactly like `IpAddr`'s own `Display`,
    /// including width, fill and alignment flags (e.g. `{:>15}`).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate instead of `write!(formatter, "{}", self.0)`: `write!` starts a
        // fresh format spec and would silently drop the caller's padding flags.
        std::fmt::Display::fmt(&self.0, formatter)
    }
}
194
195#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
196#[error("invalid IP address")]
197pub struct HostAddrParseError;
198
199impl std::str::FromStr for HostAddr {
200 type Err = HostAddrParseError;
201
202 fn from_str(value: &str) -> Result<Self, Self::Err> {
218 match std::net::IpAddr::from_str(value) {
219 Ok(addr) => Ok(Self(addr)),
220 Err(_) => Err(HostAddrParseError),
221 }
222 }
223}
224
225#[derive(Clone, Debug, PartialEq, Eq)]
226pub enum Endpoint {
227 Network {
228 host: Host,
229 channel_binding: Option<ChannelBinding>,
230 host_addr: Option<HostAddr>,
231 port: Option<Port>,
232 },
233 SocketPath(std::path::PathBuf),
234}
235
236impl serde::Serialize for Endpoint {
237 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
238 use serde::ser::SerializeStruct;
239 match self {
240 Self::Network {
241 host,
242 channel_binding,
243 host_addr,
244 port,
245 } => {
246 let mut state = serializer.serialize_struct("Endpoint", 4)?;
247 state.serialize_field("host", host)?;
248 if let Some(channel_binding) = channel_binding {
249 state.serialize_field("channel_binding", channel_binding)?;
250 }
251 if let Some(addr) = host_addr {
252 state.serialize_field("host_addr", &addr.to_string())?;
253 }
254 if let Some(port) = port {
255 state.serialize_field("port", port)?;
256 }
257 state.end()
258 }
259 Self::SocketPath(path) => {
260 let mut state = serializer.serialize_struct("Endpoint", 1)?;
261 state.serialize_field(
262 "socket_path",
263 &path.to_str().expect("socket path contains invalid utf8"),
264 )?;
265 state.end()
266 }
267 }
268 }
269}
270
271#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
272pub struct Port(u16);
273
274impl Port {
275 #[must_use]
276 pub const fn new(port: u16) -> Self {
277 Self(port)
278 }
279
280 pub(crate) fn pg_env_value(self) -> String {
281 self.0.to_string()
282 }
283}
284
285#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
286#[error("invalid postgresql port string")]
287pub struct PortParseError;
288
289impl std::str::FromStr for Port {
290 type Err = PortParseError;
291
292 fn from_str(value: &str) -> Result<Self, Self::Err> {
293 match <u16 as std::str::FromStr>::from_str(value) {
294 Ok(port) => Ok(Port(port)),
295 Err(_) => Err(PortParseError),
296 }
297 }
298}
299
300impl From<u16> for Port {
301 fn from(port: u16) -> Self {
302 Self(port)
303 }
304}
305
306impl From<Port> for u16 {
307 fn from(port: Port) -> Self {
308 port.0
309 }
310}
311
312impl From<&Port> for u16 {
313 fn from(port: &Port) -> Self {
314 port.0
315 }
316}
317
318#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
319pub struct ApplicationName(String);
320
321from_str_impl!(ApplicationName, ApplicationNameParseError, 1, 63);
322
323impl ApplicationName {
324 pub(crate) fn pg_env_value(&self) -> String {
325 self.0.clone()
326 }
327}
328
329impl Database {
330 pub(crate) fn pg_env_value(&self) -> String {
331 self.as_str().to_owned()
332 }
333}
334
335impl Role {
336 pub(crate) fn pg_env_value(&self) -> String {
337 self.as_str().to_owned()
338 }
339}
340
341#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
342pub struct Password(String);
343
344from_str_impl!(Password, PasswordParseError, 0, 4096);
345
346impl Password {
347 pub(crate) fn pg_env_value(&self) -> String {
348 self.0.clone()
349 }
350}
351
352#[derive(
353 Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
354)]
355#[serde(rename_all = "kebab-case")]
356#[strum(serialize_all = "kebab-case")]
357pub enum SslMode {
358 Allow,
359 Disable,
360 Prefer,
361 Require,
362 VerifyCa,
363 VerifyFull,
364}
365
366impl SslMode {
367 #[must_use]
368 pub fn as_str(&self) -> &'static str {
369 self.into()
370 }
371
372 pub(crate) fn pg_env_value(&self) -> String {
373 self.as_str().to_string()
374 }
375}
376
377#[derive(
378 Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
379)]
380#[serde(rename_all = "kebab-case")]
381#[strum(serialize_all = "kebab-case")]
382pub enum ChannelBinding {
383 Disable,
384 Prefer,
385 Require,
386}
387
388impl ChannelBinding {
389 #[must_use]
390 pub fn as_str(&self) -> &'static str {
391 self.into()
392 }
393
394 pub(crate) fn pg_env_value(&self) -> String {
395 self.as_str().to_string()
396 }
397}
398
399#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
400#[serde(rename_all = "kebab-case")]
401pub enum SslRootCert {
402 File(std::path::PathBuf),
403 System,
404}
405
406impl SslRootCert {
407 pub(crate) fn pg_env_value(&self) -> String {
408 match self {
409 Self::File(path) => path.to_str().unwrap().to_string(),
410 Self::System => "system".to_string(),
411 }
412 }
413}
414
415impl From<std::path::PathBuf> for SslRootCert {
416 fn from(value: std::path::PathBuf) -> Self {
417 Self::File(value)
418 }
419}
420
421#[derive(Clone, Debug, PartialEq, Eq)]
426pub struct Session {
427 pub application_name: Option<ApplicationName>,
428 pub database: Database,
429 pub password: Option<Password>,
430 pub user: User,
431}
432
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Returns a string made of `len` copies of `ch`.
    fn filler(ch: char, len: usize) -> String {
        ch.to_string().repeat(len)
    }

    /// Parses `len` copies of `'a'` and asserts the result wraps the same string.
    fn assert_application_name_accepted(len: usize, expectation: &str) {
        let value = filler('a', len);
        let parsed = value.parse::<ApplicationName>().expect(expectation);
        assert_eq!(parsed, ApplicationName(value));
    }

    /// Parses `len` copies of `'p'` and asserts the result wraps the same string.
    fn assert_password_accepted(len: usize, expectation: &str) {
        let value = filler('p', len);
        let parsed = value.parse::<Password>().expect(expectation);
        assert_eq!(parsed, Password(value));
    }

    #[test]
    fn application_name_lt_min_length() {
        let err = ""
            .parse::<ApplicationName>()
            .expect_err("expected min length failure");

        assert_eq!(
            err,
            ApplicationNameParseError::TooShort { min: 1, actual: 0 },
        );
        assert_eq!(
            err.to_string(),
            "ApplicationName byte min length: 1 violated, got: 0",
        );
    }

    #[test]
    fn application_name_eq_min_length() {
        assert_application_name_accepted(1, "expected valid min length value");
    }

    #[test]
    fn application_name_gt_min_length() {
        assert_application_name_accepted(2, "expected valid value greater than min");
    }

    #[test]
    fn application_name_lt_max_length() {
        assert_application_name_accepted(62, "expected valid value less than max");
    }

    #[test]
    fn application_name_eq_max_length() {
        assert_application_name_accepted(63, "expected valid value equal to max");
    }

    #[test]
    fn application_name_gt_max_length() {
        let err = filler('a', 64)
            .parse::<ApplicationName>()
            .expect_err("expected max length failure");

        assert_eq!(
            err,
            ApplicationNameParseError::TooLong {
                max: 63,
                actual: 64,
            },
        );
        assert_eq!(
            err.to_string(),
            "ApplicationName byte max length: 63 violated, got: 64",
        );
    }

    #[test]
    fn application_name_contains_nul() {
        let err = "\0"
            .parse::<ApplicationName>()
            .expect_err("expected NUL failure");

        assert_eq!(err, ApplicationNameParseError::ContainsNul);
        assert_eq!(err.to_string(), "ApplicationName contains NUL byte");
    }

    #[test]
    fn password_eq_min_length() {
        assert_password_accepted(0, "expected valid min length value");
    }

    #[test]
    fn password_gt_min_length() {
        assert_password_accepted(1, "expected valid value greater than min");
    }

    #[test]
    fn password_lt_max_length() {
        assert_password_accepted(4095, "expected valid value less than max");
    }

    #[test]
    fn password_eq_max_length() {
        assert_password_accepted(4096, "expected valid value equal to max");
    }

    #[test]
    fn password_gt_max_length() {
        let err = filler('p', 4097)
            .parse::<Password>()
            .expect_err("expected max length failure");

        assert_eq!(
            err,
            PasswordParseError::TooLong {
                max: 4096,
                actual: 4097,
            },
        );
        assert_eq!(
            err.to_string(),
            "Password byte max length: 4096 violated, got: 4097",
        );
    }

    #[test]
    fn password_contains_nul() {
        let err = "\0"
            .parse::<Password>()
            .expect_err("expected NUL failure");

        assert_eq!(err, PasswordParseError::ContainsNul);
        assert_eq!(err.to_string(), "Password contains NUL byte");
    }
}