1use std::{
2 borrow::Cow,
3 net::{IpAddr, Ipv4Addr, SocketAddr},
4 path::Path,
5 sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
13pub struct TargetName {
15 inner: MaybeResolvedTarget,
16}
17
18impl std::fmt::Debug for TargetName {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(f, "{:?}", self.inner)
21 }
22}
23
24impl TargetName {
25 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
27 #[cfg(unix)]
28 {
29 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
30 Ok(Self {
31 inner: MaybeResolvedTarget::Resolved(path),
32 })
33 }
34 #[cfg(not(unix))]
35 {
36 Err(std::io::Error::new(
37 std::io::ErrorKind::Unsupported,
38 "Unix sockets are not supported on this platform",
39 ))
40 }
41 }
42
43 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
45 #[cfg(any(target_os = "linux", target_os = "android"))]
46 {
47 use std::os::linux::net::SocketAddrExt;
48 let domain =
49 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
50 Ok(Self {
51 inner: MaybeResolvedTarget::Resolved(domain),
52 })
53 }
54 #[cfg(not(any(target_os = "linux", target_os = "android")))]
55 {
56 Err(std::io::Error::new(
57 std::io::ErrorKind::Unsupported,
58 "Unix domain sockets are not supported on this platform",
59 ))
60 }
61 }
62
63 #[allow(private_bounds)]
65 pub fn new_tcp(host: impl TcpResolve) -> Self {
66 Self { inner: host.into() }
67 }
68
69 pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
71 use std::net::ToSocketAddrs;
72 let mut result = Vec::new();
73 match &self.inner {
74 MaybeResolvedTarget::Resolved(addr) => {
75 return Ok(vec![addr.clone()]);
76 }
77 MaybeResolvedTarget::Unresolved(host, port, _interface) => {
78 let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
79 result.extend(addrs.map(ResolvedTarget::SocketAddr));
80 }
81 }
82 Ok(result)
83 }
84
85 pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
86 &self.inner
87 }
88
89 pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
90 &mut self.inner
91 }
92
93 pub fn is_tcp(&self) -> bool {
95 self.maybe_resolved().port().is_some()
96 }
97
98 pub fn port(&self) -> Option<u16> {
101 self.maybe_resolved().port()
102 }
103
104 pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
107 self.maybe_resolved_mut().set_port(port)
108 }
109
110 pub fn path(&self) -> Option<&Path> {
113 self.maybe_resolved().path()
114 }
115
116 pub fn host(&self) -> Option<Cow<str>> {
121 self.maybe_resolved().host()
122 }
123
124 pub fn name(&self) -> Option<ServerName> {
128 self.maybe_resolved().name()
129 }
130
131 pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
134 self.maybe_resolved().tcp()
135 }
136}
137
138#[derive(Clone)]
139pub struct Target {
140 inner: TargetInner,
141}
142
143impl std::fmt::Debug for Target {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 match &self.inner {
146 TargetInner::NoTls(target) => write!(f, "{:?}", target),
147 TargetInner::Tls(target, _) => write!(f, "{:?} (TLS)", target),
148 TargetInner::StartTls(target, _) => write!(f, "{:?} (STARTTLS)", target),
149 }
150 }
151}
152
153#[allow(private_bounds)]
154impl Target {
155 pub fn new(name: TargetName) -> Self {
156 Self {
157 inner: TargetInner::NoTls(name.inner),
158 }
159 }
160
161 pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
162 Self {
163 inner: TargetInner::Tls(name.inner, params.into()),
164 }
165 }
166
167 pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
168 Self {
169 inner: TargetInner::StartTls(name.inner, params.into()),
170 }
171 }
172
173 pub fn new_resolved(target: ResolvedTarget) -> Self {
174 Self {
175 inner: TargetInner::NoTls(target.into()),
176 }
177 }
178
179 pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
180 Self {
181 inner: TargetInner::Tls(target.into(), params.into()),
182 }
183 }
184
185 pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
186 Self {
187 inner: TargetInner::StartTls(target.into(), params.into()),
188 }
189 }
190
191 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
193 #[cfg(unix)]
194 {
195 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
196 Ok(Self {
197 inner: TargetInner::NoTls(path.into()),
198 })
199 }
200 #[cfg(not(unix))]
201 {
202 Err(std::io::Error::new(
203 std::io::ErrorKind::Unsupported,
204 "Unix sockets are not supported on this platform",
205 ))
206 }
207 }
208
209 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
211 #[cfg(any(target_os = "linux", target_os = "android"))]
212 {
213 use std::os::linux::net::SocketAddrExt;
214 let domain =
215 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
216 Ok(Self {
217 inner: TargetInner::NoTls(domain.into()),
218 })
219 }
220 #[cfg(not(any(target_os = "linux", target_os = "android")))]
221 {
222 Err(std::io::Error::new(
223 std::io::ErrorKind::Unsupported,
224 "Unix domain sockets are not supported on this platform",
225 ))
226 }
227 }
228
229 pub fn new_tcp(host: impl TcpResolve) -> Self {
231 Self {
232 inner: TargetInner::NoTls(host.into()),
233 }
234 }
235
236 pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
238 Self {
239 inner: TargetInner::Tls(host.into(), params.into()),
240 }
241 }
242
243 pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
245 Self {
246 inner: TargetInner::StartTls(host.into(), params.into()),
247 }
248 }
249
250 pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
251 if self.maybe_resolved().path().is_some() {
253 return None;
254 }
255
256 let params = params.into();
257
258 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
260 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
261 ));
262
263 match std::mem::replace(&mut self.inner, no_target) {
264 TargetInner::NoTls(target) => {
265 self.inner = TargetInner::Tls(target, params);
266 Some(None)
267 }
268 TargetInner::Tls(target, old_params) => {
269 self.inner = TargetInner::Tls(target, params);
270 Some(Some(old_params))
271 }
272 TargetInner::StartTls(target, old_params) => {
273 self.inner = TargetInner::StartTls(target, params);
274 Some(Some(old_params))
275 }
276 }
277 }
278
279 pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
280 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
282 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
283 ));
284
285 match std::mem::replace(&mut self.inner, no_target) {
286 TargetInner::NoTls(target) => {
287 self.inner = TargetInner::NoTls(target);
288 None
289 }
290 TargetInner::Tls(target, old_params) => {
291 self.inner = TargetInner::NoTls(target);
292 Some(old_params)
293 }
294 TargetInner::StartTls(target, old_params) => {
295 self.inner = TargetInner::NoTls(target);
296 Some(old_params)
297 }
298 }
299 }
300
301 pub fn is_tcp(&self) -> bool {
303 self.maybe_resolved().port().is_some()
304 }
305
306 pub fn port(&self) -> Option<u16> {
309 self.maybe_resolved().port()
310 }
311
312 pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
315 self.maybe_resolved_mut().set_port(port)
316 }
317
318 pub fn path(&self) -> Option<&Path> {
321 self.maybe_resolved().path()
322 }
323
324 pub fn host(&self) -> Option<Cow<str>> {
329 self.maybe_resolved().host()
330 }
331
332 pub fn name(&self) -> Option<ServerName> {
336 self.maybe_resolved().name()
337 }
338
339 pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
342 self.maybe_resolved().tcp()
343 }
344
345 pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
346 match &self.inner {
347 TargetInner::NoTls(target) => target,
348 TargetInner::Tls(target, _) => target,
349 TargetInner::StartTls(target, _) => target,
350 }
351 }
352
353 pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
354 match &mut self.inner {
355 TargetInner::NoTls(target) => target,
356 TargetInner::Tls(target, _) => target,
357 TargetInner::StartTls(target, _) => target,
358 }
359 }
360
361 pub(crate) fn is_starttls(&self) -> bool {
362 matches!(self.inner, TargetInner::StartTls(_, _))
363 }
364
365 pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
366 match &self.inner {
367 TargetInner::NoTls(_) => None,
368 TargetInner::Tls(_, params) => Some(params),
369 TargetInner::StartTls(_, params) => Some(params),
370 }
371 }
372}
373
374#[derive(Clone, derive_more::From)]
375pub(crate) enum MaybeResolvedTarget {
376 Resolved(ResolvedTarget),
377 Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
378}
379
380impl std::fmt::Debug for MaybeResolvedTarget {
381 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
382 match self {
383 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
384 if let SocketAddr::V6(addr) = addr {
385 if addr.scope_id() != 0 {
386 write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
387 } else {
388 write!(f, "[{}]:{}", addr.ip(), addr.port())
389 }
390 } else {
391 write!(f, "{}:{}", addr.ip(), addr.port())
392 }
393 }
394 #[cfg(unix)]
395 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
396 if let Some(path) = addr.as_pathname() {
397 return write!(f, "{}", path.to_string_lossy());
398 } else {
399 #[cfg(any(target_os = "linux", target_os = "android"))]
400 {
401 use std::os::linux::net::SocketAddrExt;
402 if let Some(name) = addr.as_abstract_name() {
403 return write!(f, "@{}", String::from_utf8_lossy(name));
404 }
405 }
406 }
407 Ok(())
408 }
409 MaybeResolvedTarget::Unresolved(host, port, interface) => {
410 write!(f, "{}:{}", host, port)?;
411 if let Some(interface) = interface {
412 write!(f, "%{}", interface)?;
413 }
414 Ok(())
415 }
416 }
417 }
418}
419
420impl MaybeResolvedTarget {
421 fn name(&self) -> Option<ServerName> {
422 match self {
423 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
424 Some(ServerName::IpAddress(addr.ip().into()))
425 }
426 MaybeResolvedTarget::Unresolved(host, _, _) => {
427 Some(ServerName::DnsName(host.to_string().try_into().ok()?))
428 }
429 _ => None,
430 }
431 }
432
433 fn tcp(&self) -> Option<(Cow<str>, u16)> {
434 match self {
435 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
436 Some((Cow::Owned(addr.ip().to_string()), addr.port()))
437 }
438 MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
439 _ => None,
440 }
441 }
442
443 fn path(&self) -> Option<&Path> {
444 match self {
445 #[cfg(unix)]
446 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
447 addr.as_pathname()
448 }
449 _ => None,
450 }
451 }
452
453 fn host(&self) -> Option<Cow<str>> {
454 match self {
455 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
456 Some(Cow::Owned(addr.ip().to_string()))
457 }
458 MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
459 _ => None,
460 }
461 }
462
463 fn port(&self) -> Option<u16> {
464 match self {
465 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
466 MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
467 _ => None,
468 }
469 }
470
471 fn set_port(&mut self, new_port: u16) -> Option<u16> {
472 match self {
473 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
474 let old_port = addr.port();
475 addr.set_port(new_port);
476 Some(old_port)
477 }
478 MaybeResolvedTarget::Unresolved(_, port, _) => {
479 let old_port = *port;
480 *port = new_port;
481 Some(old_port)
482 }
483 _ => None,
484 }
485 }
486}
487
488#[derive(Clone, Debug)]
490enum TargetInner {
491 NoTls(MaybeResolvedTarget),
492 Tls(MaybeResolvedTarget, Arc<TlsParameters>),
493 StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
494}
495
496#[derive(Clone, Debug, derive_more::From)]
497pub enum ResolvedTarget {
499 SocketAddr(std::net::SocketAddr),
500 #[cfg(unix)]
501 UnixSocketAddr(std::os::unix::net::SocketAddr),
502}
503
504impl ResolvedTarget {
505 pub fn tcp(&self) -> Option<SocketAddr> {
506 match self {
507 ResolvedTarget::SocketAddr(addr) => Some(*addr),
508 _ => None,
509 }
510 }
511}
512
513pub trait LocalAddress {
514 fn local_address(&self) -> std::io::Result<ResolvedTarget>;
515}
516
517trait TcpResolve {
518 fn into(self) -> MaybeResolvedTarget;
519}
520
521impl<S: AsRef<str>> TcpResolve for (S, u16) {
522 fn into(self) -> MaybeResolvedTarget {
523 if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
524 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
525 } else {
526 MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
527 }
528 }
529}
530
531impl TcpResolve for SocketAddr {
532 fn into(self) -> MaybeResolvedTarget {
533 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
534 }
535}
536
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(
            target.name(),
            Some(ServerName::DnsName("localhost".try_into().unwrap()))
        );
    }

    #[test]
    fn test_target_name() {
        let target = TargetName::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = TargetName::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = TargetName::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        let target = TargetName::new_tcp(SocketAddr::V6(SocketAddrV6::new(
            "fe80::1ff:fe23:4567:890a".parse().unwrap(),
            5432,
            0,
            2,
        )));
        assert_eq!(format!("{target:?}"), "[fe80::1ff:fe23:4567:890a%2]:5432");

        // `new_unix_path` returns `Unsupported` on non-Unix platforms, so only
        // exercise it where Unix sockets exist (as `test_target_debug` does).
        #[cfg(unix)]
        {
            let target = TargetName::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = TargetName::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }

    #[test]
    fn test_target_debug() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = Target::new_tcp_tls(("localhost", 5432), TlsParameters::default());
        assert_eq!(format!("{target:?}"), "localhost:5432 (TLS)");

        let target = Target::new_tcp_starttls(("localhost", 5432), TlsParameters::default());
        assert_eq!(format!("{target:?}"), "localhost:5432 (STARTTLS)");

        let target = Target::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = Target::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        #[cfg(unix)]
        {
            let target = Target::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = Target::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }

    #[test]
    fn test_port_accessors() {
        let mut target = TargetName::new_tcp(("localhost", 5432));
        assert!(target.is_tcp());
        assert_eq!(target.port(), Some(5432));
        assert_eq!(target.try_set_port(6543), Some(5432));
        assert_eq!(target.port(), Some(6543));
        assert_eq!(format!("{target:?}"), "localhost:6543");
        assert_eq!(target.path(), None);
    }

    #[test]
    fn test_try_set_and_remove_tls() {
        let mut target = Target::new_tcp(("localhost", 5432));
        assert!(matches!(target.try_set_tls(TlsParameters::default()), Some(None)));
        assert_eq!(format!("{target:?}"), "localhost:5432 (TLS)");

        assert!(matches!(target.try_set_tls(TlsParameters::default()), Some(Some(_))));
        assert_eq!(format!("{target:?}"), "localhost:5432 (TLS)");

        assert!(target.try_remove_tls().is_some());
        assert_eq!(format!("{target:?}"), "localhost:5432");
        assert!(target.try_remove_tls().is_none());

        // Replacing parameters on a STARTTLS target keeps the STARTTLS mode.
        let mut target = Target::new_tcp_starttls(("localhost", 5432), TlsParameters::default());
        assert!(matches!(target.try_set_tls(TlsParameters::default()), Some(Some(_))));
        assert!(target.is_starttls());
        assert_eq!(format!("{target:?}"), "localhost:5432 (STARTTLS)");
    }

    #[cfg(unix)]
    #[test]
    fn test_unix_path_target() {
        let mut target = Target::new_unix_path("/tmp/test.sock").unwrap();
        assert!(!target.is_tcp());
        assert_eq!(target.port(), None);
        assert_eq!(target.path(), Some(Path::new("/tmp/test.sock")));
        assert_eq!(target.try_set_port(1), None);
        assert!(target.try_set_tls(TlsParameters::default()).is_none());
        assert_eq!(format!("{target:?}"), "/tmp/test.sock");
    }
}