1use crate::identifier::{Database, Role, User};
2
/// Generates the string API for a tuple newtype wrapping a `String`.
///
/// For `$struct` this emits:
/// * `MIN_LENGTH` / `MAX_LENGTH` associated constants taken from `$min` / `$max`,
/// * an `as_str` accessor and an `AsRef<str>` impl, and
/// * a `FromStr` impl that rejects values whose byte length is outside
///   `MIN_LENGTH..=MAX_LENGTH` or that contain a NUL byte.
///
/// Lengths are measured in bytes (`str::len`), not characters.
macro_rules! from_str_impl {
    ($struct: ident, $min: expr, $max: expr) => {
        impl $struct {
            pub const MIN_LENGTH: usize = $min;
            pub const MAX_LENGTH: usize = $max;

            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }

        impl AsRef<str> for $struct {
            fn as_ref(&self) -> &str {
                self.0.as_str()
            }
        }

        impl std::str::FromStr for $struct {
            type Err = String;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                let type_name = stringify!($struct);
                let lower = Self::MIN_LENGTH;
                let upper = Self::MAX_LENGTH;
                let actual = value.len();

                // Checks run in the same order as the error precedence:
                // too short, too long, then embedded NUL.
                if actual < lower {
                    return Err(format!(
                        "{type_name} byte min length: {lower} violated, got: {actual}"
                    ));
                }

                if actual > upper {
                    return Err(format!(
                        "{type_name} byte max length: {upper} violated, got: {actual}"
                    ));
                }

                if value.bytes().any(|byte| byte == 0) {
                    return Err(format!("{type_name} contains NUL byte"));
                }

                Ok(Self(value.to_owned()))
            }
        }
    };
}
48
49#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
50pub struct HostName(String);
51
52impl HostName {
53 #[must_use]
54 pub fn as_str(&self) -> &str {
55 &self.0
56 }
57}
58
59impl std::str::FromStr for HostName {
60 type Err = &'static str;
61
62 fn from_str(value: &str) -> Result<Self, Self::Err> {
63 if hostname_validator::is_valid(value) {
64 Ok(Self(value.to_string()))
65 } else {
66 Err("invalid host name")
67 }
68 }
69}
70
71impl<'de> serde::Deserialize<'de> for HostName {
72 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
73 where
74 D: serde::Deserializer<'de>,
75 {
76 let s = String::deserialize(deserializer)?;
77 s.parse().map_err(serde::de::Error::custom)
78 }
79}
80
/// A network host, given either as a validated host name or as an IP address.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Host {
    /// Host identified by name.
    HostName(HostName),
    /// Host identified by an IPv4 or IPv6 address.
    IpAddr(std::net::IpAddr),
}
86
87impl serde::Serialize for Host {
88 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
89 serializer.serialize_str(&self.pg_env_value())
90 }
91}
92
93impl Host {
94 pub(crate) fn pg_env_value(&self) -> String {
95 match self {
96 Self::HostName(value) => value.0.clone(),
97 Self::IpAddr(value) => value.to_string(),
98 }
99 }
100}
101
102impl std::str::FromStr for Host {
103 type Err = &'static str;
104
105 fn from_str(value: &str) -> Result<Self, Self::Err> {
106 match std::net::IpAddr::from_str(value) {
107 Ok(addr) => Ok(Self::IpAddr(addr)),
108 Err(_) => match HostName::from_str(value) {
109 Ok(host_name) => Ok(Self::HostName(host_name)),
110 Err(_) => Err("Not a socket address or FQDN"),
111 },
112 }
113 }
114}
115
116impl From<HostName> for Host {
117 fn from(value: HostName) -> Self {
118 Self::HostName(value)
119 }
120}
121
122impl From<std::net::IpAddr> for Host {
123 fn from(value: std::net::IpAddr) -> Self {
124 Self::IpAddr(value)
125 }
126}
127
/// Newtype around an IP address used as an explicit host address.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostAddr(std::net::IpAddr);

impl HostAddr {
    /// Wraps `ip` in a [`HostAddr`].
    #[must_use]
    pub const fn new(ip: std::net::IpAddr) -> Self {
        HostAddr(ip)
    }
}
137
138impl From<std::net::IpAddr> for HostAddr {
139 fn from(value: std::net::IpAddr) -> Self {
149 Self(value)
150 }
151}
152
153impl From<HostAddr> for std::net::IpAddr {
154 fn from(value: HostAddr) -> Self {
155 value.0
156 }
157}
158
159impl From<&HostAddr> for std::net::IpAddr {
160 fn from(value: &HostAddr) -> Self {
161 value.0
162 }
163}
164
165impl std::fmt::Display for HostAddr {
166 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 write!(formatter, "{}", self.0)
175 }
176}
177
178impl std::str::FromStr for HostAddr {
179 type Err = &'static str;
180
181 fn from_str(value: &str) -> Result<Self, Self::Err> {
197 match std::net::IpAddr::from_str(value) {
198 Ok(addr) => Ok(Self(addr)),
199 Err(_) => Err("invalid IP address"),
200 }
201 }
202}
203
/// Where to reach the server: over the network or through a socket path.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Endpoint {
    /// Network endpoint; only `host` is mandatory.
    Network {
        /// Host name or IP address.
        host: Host,
        /// Optional channel binding setting.
        channel_binding: Option<ChannelBinding>,
        /// Optional explicit IP address accompanying `host`.
        // NOTE(review): presumably mirrors libpq's `hostaddr` — confirm.
        host_addr: Option<HostAddr>,
        /// Optional port.
        port: Option<Port>,
    },
    /// Filesystem path of a socket.
    SocketPath(std::path::PathBuf),
}
214
215impl serde::Serialize for Endpoint {
216 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
217 use serde::ser::SerializeStruct;
218 match self {
219 Self::Network {
220 host,
221 channel_binding,
222 host_addr,
223 port,
224 } => {
225 let mut state = serializer.serialize_struct("Endpoint", 4)?;
226 state.serialize_field("host", host)?;
227 if let Some(channel_binding) = channel_binding {
228 state.serialize_field("channel_binding", channel_binding)?;
229 }
230 if let Some(addr) = host_addr {
231 state.serialize_field("host_addr", &addr.to_string())?;
232 }
233 if let Some(port) = port {
234 state.serialize_field("port", port)?;
235 }
236 state.end()
237 }
238 Self::SocketPath(path) => {
239 let mut state = serializer.serialize_struct("Endpoint", 1)?;
240 state.serialize_field(
241 "socket_path",
242 &path.to_str().expect("socket path contains invalid utf8"),
243 )?;
244 state.end()
245 }
246 }
247 }
248}
249
250#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize)]
251pub struct Port(u16);
252
253impl Port {
254 #[must_use]
255 pub const fn new(port: u16) -> Self {
256 Self(port)
257 }
258
259 pub(crate) fn pg_env_value(self) -> String {
260 self.0.to_string()
261 }
262}
263
264impl std::str::FromStr for Port {
265 type Err = &'static str;
266
267 fn from_str(value: &str) -> Result<Self, Self::Err> {
268 match <u16 as std::str::FromStr>::from_str(value) {
269 Ok(port) => Ok(Port(port)),
270 Err(_) => Err("invalid postgresql port string"),
271 }
272 }
273}
274
275impl From<u16> for Port {
276 fn from(port: u16) -> Self {
277 Self(port)
278 }
279}
280
281impl From<Port> for u16 {
282 fn from(port: Port) -> Self {
283 port.0
284 }
285}
286
287impl From<&Port> for u16 {
288 fn from(port: &Port) -> Self {
289 port.0
290 }
291}
292
293#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
294pub struct ApplicationName(String);
295
296from_str_impl!(ApplicationName, 1, 63);
297
298impl ApplicationName {
299 pub(crate) fn pg_env_value(&self) -> String {
300 self.0.clone()
301 }
302}
303
304impl Database {
305 pub(crate) fn pg_env_value(&self) -> String {
306 self.as_str().to_owned()
307 }
308}
309
310impl Role {
311 pub(crate) fn pg_env_value(&self) -> String {
312 self.as_str().to_owned()
313 }
314}
315
316#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
317pub struct Password(String);
318
319from_str_impl!(Password, 0, 4096);
320
321impl Password {
322 pub(crate) fn pg_env_value(&self) -> String {
323 self.0.clone()
324 }
325}
326
327#[derive(
328 Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
329)]
330#[serde(rename_all = "kebab-case")]
331#[strum(serialize_all = "kebab-case")]
332pub enum SslMode {
333 Allow,
334 Disable,
335 Prefer,
336 Require,
337 VerifyCa,
338 VerifyFull,
339}
340
341impl SslMode {
342 #[must_use]
343 pub fn as_str(&self) -> &'static str {
344 self.into()
345 }
346
347 pub(crate) fn pg_env_value(&self) -> String {
348 self.as_str().to_string()
349 }
350}
351
352#[derive(
353 Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, strum::IntoStaticStr, strum::EnumString,
354)]
355#[serde(rename_all = "kebab-case")]
356#[strum(serialize_all = "kebab-case")]
357pub enum ChannelBinding {
358 Disable,
359 Prefer,
360 Require,
361}
362
363impl ChannelBinding {
364 #[must_use]
365 pub fn as_str(&self) -> &'static str {
366 self.into()
367 }
368
369 pub(crate) fn pg_env_value(&self) -> String {
370 self.as_str().to_string()
371 }
372}
373
374#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
375#[serde(rename_all = "kebab-case")]
376pub enum SslRootCert {
377 File(std::path::PathBuf),
378 System,
379}
380
381impl SslRootCert {
382 pub(crate) fn pg_env_value(&self) -> String {
383 match self {
384 Self::File(path) => path.to_str().unwrap().to_string(),
385 Self::System => "system".to_string(),
386 }
387 }
388}
389
390impl From<std::path::PathBuf> for SslRootCert {
391 fn from(value: std::path::PathBuf) -> Self {
392 Self::File(value)
393 }
394}
395
/// Session parameters: who connects, to which database, and under which
/// application name.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Session {
    /// Optional application name.
    pub application_name: Option<ApplicationName>,
    /// Database identifier.
    pub database: Database,
    /// Optional password.
    pub password: Option<Password>,
    /// User identifier.
    pub user: User,
}
407
/// Boundary tests for the length and NUL validation of `ApplicationName` and `Password`.
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    /// Builds a string made of `len` copies of `char`.
    fn repeat(char: char, len: usize) -> String {
        char.to_string().repeat(len)
    }

    #[test]
    fn application_name_lt_min_length() {
        let input = String::new();

        let outcome = ApplicationName::from_str(&input);

        assert_eq!(
            outcome.expect_err("expected min length failure"),
            "ApplicationName byte min length: 1 violated, got: 0"
        );
    }

    #[test]
    fn application_name_eq_min_length() {
        let expected = ApplicationName(repeat('a', 1));

        let parsed = ApplicationName::from_str(expected.as_str())
            .expect("expected valid min length value");

        assert_eq!(parsed, expected);
    }

    #[test]
    fn application_name_gt_min_length() {
        let expected = ApplicationName(repeat('a', 2));

        let parsed = ApplicationName::from_str(expected.as_str())
            .expect("expected valid value greater than min");

        assert_eq!(parsed, expected);
    }

    #[test]
    fn application_name_lt_max_length() {
        let expected = ApplicationName(repeat('a', 62));

        let parsed = ApplicationName::from_str(expected.as_str())
            .expect("expected valid value less than max");

        assert_eq!(parsed, expected);
    }

    #[test]
    fn application_name_eq_max_length() {
        let expected = ApplicationName(repeat('a', 63));

        let parsed = ApplicationName::from_str(expected.as_str())
            .expect("expected valid value equal to max");

        assert_eq!(parsed, expected);
    }

    #[test]
    fn application_name_gt_max_length() {
        let input = repeat('a', 64);

        let outcome = ApplicationName::from_str(&input);

        assert_eq!(
            outcome.expect_err("expected max length failure"),
            "ApplicationName byte max length: 63 violated, got: 64"
        );
    }

    #[test]
    fn application_name_contains_nul() {
        let input = String::from('\0');

        let outcome = ApplicationName::from_str(&input);

        assert_eq!(
            outcome.expect_err("expected NUL failure"),
            "ApplicationName contains NUL byte"
        );
    }

    #[test]
    fn password_eq_min_length() {
        let expected = Password(String::new());

        let parsed =
            Password::from_str(expected.as_str()).expect("expected valid min length value");

        assert_eq!(parsed, expected);
    }

    #[test]
    fn password_gt_min_length() {
        let expected = Password(repeat('p', 1));

        let parsed = Password::from_str(expected.as_str())
            .expect("expected valid value greater than min");

        assert_eq!(parsed, expected);
    }

    #[test]
    fn password_lt_max_length() {
        let expected = Password(repeat('p', 4095));

        let parsed =
            Password::from_str(expected.as_str()).expect("expected valid value less than max");

        assert_eq!(parsed, expected);
    }

    #[test]
    fn password_eq_max_length() {
        let expected = Password(repeat('p', 4096));

        let parsed =
            Password::from_str(expected.as_str()).expect("expected valid value equal to max");

        assert_eq!(parsed, expected);
    }

    #[test]
    fn password_gt_max_length() {
        let input = repeat('p', 4097);

        let outcome = Password::from_str(&input);

        assert_eq!(
            outcome.expect_err("expected max length failure"),
            "Password byte max length: 4096 violated, got: 4097"
        );
    }

    #[test]
    fn password_contains_nul() {
        let input = String::from('\0');

        let outcome = Password::from_str(&input);

        assert_eq!(
            outcome.expect_err("expected NUL failure"),
            "Password contains NUL byte"
        );
    }
}