1use std::{
2 borrow::Cow,
3 net::{IpAddr, Ipv4Addr, SocketAddr},
4 path::Path,
5 sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
/// An address to connect to, without any TLS configuration attached.
///
/// Wraps either an already-resolved address (a TCP socket address or, on Unix, a
/// Unix socket address) or an unresolved host/port pair that still needs name
/// resolution. Combine with TLS settings via `Target::new_tls` /
/// `Target::new_starttls`, or use it plain via `Target::new`.
pub struct TargetName {
    // The address itself; TLS mode is tracked separately by `Target`.
    inner: MaybeResolvedTarget,
}
17
18impl std::fmt::Debug for TargetName {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(f, "{:?}", self.inner)
21 }
22}
23
24impl TargetName {
25 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
27 #[cfg(unix)]
28 {
29 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
30 Ok(Self {
31 inner: MaybeResolvedTarget::Resolved(path),
32 })
33 }
34 #[cfg(not(unix))]
35 {
36 Err(std::io::Error::new(
37 std::io::ErrorKind::Unsupported,
38 "Unix sockets are not supported on this platform",
39 ))
40 }
41 }
42
43 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
45 #[cfg(any(target_os = "linux", target_os = "android"))]
46 {
47 use std::os::linux::net::SocketAddrExt;
48 let domain =
49 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
50 Ok(Self {
51 inner: MaybeResolvedTarget::Resolved(domain),
52 })
53 }
54 #[cfg(not(any(target_os = "linux", target_os = "android")))]
55 {
56 Err(std::io::Error::new(
57 std::io::ErrorKind::Unsupported,
58 "Unix domain sockets are not supported on this platform",
59 ))
60 }
61 }
62
63 #[allow(private_bounds)]
65 pub fn new_tcp(host: impl TcpResolve) -> Self {
66 Self { inner: host.into() }
67 }
68
69 pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
71 use std::net::ToSocketAddrs;
72 let mut result = Vec::new();
73 match &self.inner {
74 MaybeResolvedTarget::Resolved(addr) => {
75 return Ok(vec![addr.clone()]);
76 }
77 MaybeResolvedTarget::Unresolved(host, port, _interface) => {
78 let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
79 result.extend(addrs.map(ResolvedTarget::SocketAddr));
80 }
81 }
82 Ok(result)
83 }
84
85 pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
86 &self.inner
87 }
88
89 pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
90 &mut self.inner
91 }
92
93 pub fn is_tcp(&self) -> bool {
95 self.maybe_resolved().port().is_some()
96 }
97
98 pub fn port(&self) -> Option<u16> {
101 self.maybe_resolved().port()
102 }
103
104 pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
107 self.maybe_resolved_mut().set_port(port)
108 }
109
110 pub fn path(&self) -> Option<&Path> {
113 self.maybe_resolved().path()
114 }
115
116 pub fn host(&self) -> Option<Cow<str>> {
121 self.maybe_resolved().host()
122 }
123
124 pub fn name(&self) -> Option<ServerName> {
128 self.maybe_resolved().name()
129 }
130
131 pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
134 self.maybe_resolved().tcp()
135 }
136}
137
/// A connection target: an address (resolved or not) together with its TLS mode —
/// no TLS, TLS, or STARTTLS — and, for the TLS modes, the shared [`TlsParameters`].
#[derive(Clone)]
pub struct Target {
    inner: TargetInner,
}
144
145impl std::fmt::Debug for Target {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 match &self.inner {
148 TargetInner::NoTls(target) => write!(f, "{:?}", target),
149 TargetInner::Tls(target, _) => write!(f, "{:?} (TLS)", target),
150 TargetInner::StartTls(target, _) => write!(f, "{:?} (STARTTLS)", target),
151 }
152 }
153}
154
155#[allow(private_bounds)]
156impl Target {
157 pub fn new(name: TargetName) -> Self {
158 Self {
159 inner: TargetInner::NoTls(name.inner),
160 }
161 }
162
163 pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
164 Self {
165 inner: TargetInner::Tls(name.inner, params.into()),
166 }
167 }
168
169 pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
170 Self {
171 inner: TargetInner::StartTls(name.inner, params.into()),
172 }
173 }
174
175 pub fn new_resolved(target: ResolvedTarget) -> Self {
176 Self {
177 inner: TargetInner::NoTls(target.into()),
178 }
179 }
180
181 pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
182 Self {
183 inner: TargetInner::Tls(target.into(), params.into()),
184 }
185 }
186
187 pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
188 Self {
189 inner: TargetInner::StartTls(target.into(), params.into()),
190 }
191 }
192
193 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
195 #[cfg(unix)]
196 {
197 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
198 Ok(Self {
199 inner: TargetInner::NoTls(path.into()),
200 })
201 }
202 #[cfg(not(unix))]
203 {
204 Err(std::io::Error::new(
205 std::io::ErrorKind::Unsupported,
206 "Unix sockets are not supported on this platform",
207 ))
208 }
209 }
210
211 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
213 #[cfg(any(target_os = "linux", target_os = "android"))]
214 {
215 use std::os::linux::net::SocketAddrExt;
216 let domain =
217 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
218 Ok(Self {
219 inner: TargetInner::NoTls(domain.into()),
220 })
221 }
222 #[cfg(not(any(target_os = "linux", target_os = "android")))]
223 {
224 Err(std::io::Error::new(
225 std::io::ErrorKind::Unsupported,
226 "Unix domain sockets are not supported on this platform",
227 ))
228 }
229 }
230
231 pub fn new_tcp(host: impl TcpResolve) -> Self {
233 Self {
234 inner: TargetInner::NoTls(host.into()),
235 }
236 }
237
238 pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
240 Self {
241 inner: TargetInner::Tls(host.into(), params.into()),
242 }
243 }
244
245 pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
247 Self {
248 inner: TargetInner::StartTls(host.into(), params.into()),
249 }
250 }
251
252 pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
253 if self.maybe_resolved().path().is_some() {
255 return None;
256 }
257
258 let params = params.into();
259
260 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
262 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
263 ));
264
265 match std::mem::replace(&mut self.inner, no_target) {
266 TargetInner::NoTls(target) => {
267 self.inner = TargetInner::Tls(target, params);
268 Some(None)
269 }
270 TargetInner::Tls(target, old_params) => {
271 self.inner = TargetInner::Tls(target, params);
272 Some(Some(old_params))
273 }
274 TargetInner::StartTls(target, old_params) => {
275 self.inner = TargetInner::StartTls(target, params);
276 Some(Some(old_params))
277 }
278 }
279 }
280
281 pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
282 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
284 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
285 ));
286
287 match std::mem::replace(&mut self.inner, no_target) {
288 TargetInner::NoTls(target) => {
289 self.inner = TargetInner::NoTls(target);
290 None
291 }
292 TargetInner::Tls(target, old_params) => {
293 self.inner = TargetInner::NoTls(target);
294 Some(old_params)
295 }
296 TargetInner::StartTls(target, old_params) => {
297 self.inner = TargetInner::NoTls(target);
298 Some(old_params)
299 }
300 }
301 }
302
303 pub fn is_tcp(&self) -> bool {
305 self.maybe_resolved().port().is_some()
306 }
307
308 pub fn port(&self) -> Option<u16> {
311 self.maybe_resolved().port()
312 }
313
314 pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
317 self.maybe_resolved_mut().set_port(port)
318 }
319
320 pub fn path(&self) -> Option<&Path> {
323 self.maybe_resolved().path()
324 }
325
326 pub fn host(&self) -> Option<Cow<str>> {
331 self.maybe_resolved().host()
332 }
333
334 pub fn name(&self) -> Option<ServerName> {
338 self.maybe_resolved().name()
339 }
340
341 pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
344 self.maybe_resolved().tcp()
345 }
346
347 pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
348 match &self.inner {
349 TargetInner::NoTls(target) => target,
350 TargetInner::Tls(target, _) => target,
351 TargetInner::StartTls(target, _) => target,
352 }
353 }
354
355 pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
356 match &mut self.inner {
357 TargetInner::NoTls(target) => target,
358 TargetInner::Tls(target, _) => target,
359 TargetInner::StartTls(target, _) => target,
360 }
361 }
362
363 pub(crate) fn is_starttls(&self) -> bool {
364 matches!(self.inner, TargetInner::StartTls(_, _))
365 }
366
367 pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
368 match &self.inner {
369 TargetInner::NoTls(_) => None,
370 TargetInner::Tls(_, params) => Some(params),
371 TargetInner::StartTls(_, params) => Some(params),
372 }
373 }
374}
375
/// An address that is either already resolved or still a host name to look up.
#[derive(Clone, derive_more::From)]
pub(crate) enum MaybeResolvedTarget {
    /// A concrete TCP or (on Unix) Unix socket address.
    Resolved(ResolvedTarget),
    /// An unresolved `(host, port, interface)` triple.
    ///
    /// The optional interface is rendered as a `%interface` suffix by `Debug`; it is
    /// ignored by `TargetName::to_addrs_sync` when resolving.
    Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
}
381
382impl std::fmt::Debug for MaybeResolvedTarget {
383 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 match self {
385 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
386 if let SocketAddr::V6(addr) = addr {
387 if addr.scope_id() != 0 {
388 write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
389 } else {
390 write!(f, "[{}]:{}", addr.ip(), addr.port())
391 }
392 } else {
393 write!(f, "{}:{}", addr.ip(), addr.port())
394 }
395 }
396 #[cfg(unix)]
397 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
398 if let Some(path) = addr.as_pathname() {
399 return write!(f, "{}", path.to_string_lossy());
400 } else {
401 #[cfg(any(target_os = "linux", target_os = "android"))]
402 {
403 use std::os::linux::net::SocketAddrExt;
404 if let Some(name) = addr.as_abstract_name() {
405 return write!(f, "@{}", String::from_utf8_lossy(name));
406 }
407 }
408 }
409 Ok(())
410 }
411 MaybeResolvedTarget::Unresolved(host, port, interface) => {
412 write!(f, "{}:{}", host, port)?;
413 if let Some(interface) = interface {
414 write!(f, "%{}", interface)?;
415 }
416 Ok(())
417 }
418 }
419 }
420}
421
422impl MaybeResolvedTarget {
423 fn name(&self) -> Option<ServerName> {
424 match self {
425 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
426 Some(ServerName::IpAddress(addr.ip().into()))
427 }
428 MaybeResolvedTarget::Unresolved(host, _, _) => {
429 Some(ServerName::DnsName(host.to_string().try_into().ok()?))
430 }
431 _ => None,
432 }
433 }
434
435 fn tcp(&self) -> Option<(Cow<str>, u16)> {
436 match self {
437 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
438 Some((Cow::Owned(addr.ip().to_string()), addr.port()))
439 }
440 MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
441 _ => None,
442 }
443 }
444
445 fn path(&self) -> Option<&Path> {
446 match self {
447 #[cfg(unix)]
448 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
449 addr.as_pathname()
450 }
451 _ => None,
452 }
453 }
454
455 fn host(&self) -> Option<Cow<str>> {
456 match self {
457 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
458 Some(Cow::Owned(addr.ip().to_string()))
459 }
460 MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
461 _ => None,
462 }
463 }
464
465 fn port(&self) -> Option<u16> {
466 match self {
467 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
468 MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
469 _ => None,
470 }
471 }
472
473 fn set_port(&mut self, new_port: u16) -> Option<u16> {
474 match self {
475 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
476 let old_port = addr.port();
477 addr.set_port(new_port);
478 Some(old_port)
479 }
480 MaybeResolvedTarget::Unresolved(_, port, _) => {
481 let old_port = *port;
482 *port = new_port;
483 Some(old_port)
484 }
485 _ => None,
486 }
487 }
488}
489
/// The TLS mode of a [`Target`]; every variant carries the address.
// Note: `Debug` in this derive resolves to `derive_more::Debug`, imported at the top
// of the file.
#[derive(Clone, Debug)]
enum TargetInner {
    /// Plaintext connection, no TLS.
    NoTls(MaybeResolvedTarget),
    /// TLS mode (as opposed to STARTTLS), with shared parameters.
    Tls(MaybeResolvedTarget, Arc<TlsParameters>),
    /// STARTTLS mode, reported by `Target::is_starttls`, with shared parameters.
    StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
}
497
/// A concrete socket address that needs no further name resolution.
#[derive(Clone, Debug, derive_more::From)]
pub enum ResolvedTarget {
    /// An IP (IPv4 or IPv6) socket address, used for TCP targets.
    SocketAddr(std::net::SocketAddr),
    /// A Unix domain socket address; only exists on Unix platforms.
    #[cfg(unix)]
    UnixSocketAddr(std::os::unix::net::SocketAddr),
}
505
506impl ResolvedTarget {
507 pub fn tcp(&self) -> Option<SocketAddr> {
508 match self {
509 ResolvedTarget::SocketAddr(addr) => Some(*addr),
510 _ => None,
511 }
512 }
513}
514
/// Reports the local endpoint address of a connection-like object.
///
/// NOTE(review): implementors are not in this file; presumably sockets or streams.
/// Confirm against the implementing types.
pub trait LocalAddress {
    /// Returns the local address as a [`ResolvedTarget`].
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the address cannot be obtained. The exact sources are
    /// up to each implementor.
    fn local_address(&self) -> std::io::Result<ResolvedTarget>;
}
519
/// Converts a TCP endpoint description into a [`MaybeResolvedTarget`].
///
/// This trait is private, yet it is used as an `impl TcpResolve` bound on the public
/// `new_tcp*` constructors, hence their `#[allow(private_bounds)]`. Callers can
/// therefore only pass the types implemented in this module: `(S, u16)` with
/// `S: AsRef<str>`, and `SocketAddr`.
///
/// NOTE(review): the method name `into` overlaps with `Into::into`; keep that in mind
/// at call sites.
trait TcpResolve {
    fn into(self) -> MaybeResolvedTarget;
}
523
524impl<S: AsRef<str>> TcpResolve for (S, u16) {
525 fn into(self) -> MaybeResolvedTarget {
526 if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
527 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
528 } else {
529 MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
530 }
531 }
532}
533
534impl TcpResolve for SocketAddr {
535 fn into(self) -> MaybeResolvedTarget {
536 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
537 }
538}
539
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    /// Asserts that `value` renders as `expected` under `{:?}`.
    fn assert_debug(value: &impl std::fmt::Debug, expected: &str) {
        assert_eq!(format!("{value:?}"), expected);
    }

    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432_u16));
        let expected = ServerName::DnsName("localhost".try_into().unwrap());
        assert_eq!(target.name(), Some(expected));
    }

    #[test]
    fn test_target_name() {
        let cases = [
            ("localhost", "localhost:5432"),
            ("127.0.0.1", "127.0.0.1:5432"),
            ("::1", "[::1]:5432"),
        ];
        for (host, expected) in cases {
            assert_debug(&TargetName::new_tcp((host, 5432_u16)), expected);
        }

        let scoped = SocketAddrV6::new("fe80::1ff:fe23:4567:890a".parse().unwrap(), 5432, 0, 2);
        assert_debug(
            &TargetName::new_tcp(SocketAddr::V6(scoped)),
            "[fe80::1ff:fe23:4567:890a%2]:5432",
        );

        #[cfg(unix)]
        assert_debug(
            &TargetName::new_unix_path("/tmp/test.sock").unwrap(),
            "/tmp/test.sock",
        );

        #[cfg(any(target_os = "linux", target_os = "android"))]
        assert_debug(&TargetName::new_unix_domain("test").unwrap(), "@test");
    }

    #[test]
    fn test_target_debug() {
        let localhost = ("localhost", 5432_u16);
        assert_debug(&Target::new_tcp(localhost), "localhost:5432");
        assert_debug(
            &Target::new_tcp_tls(localhost, TlsParameters::default()),
            "localhost:5432 (TLS)",
        );
        assert_debug(
            &Target::new_tcp_starttls(localhost, TlsParameters::default()),
            "localhost:5432 (STARTTLS)",
        );
        assert_debug(&Target::new_tcp(("127.0.0.1", 5432_u16)), "127.0.0.1:5432");
        assert_debug(&Target::new_tcp(("::1", 5432_u16)), "[::1]:5432");

        #[cfg(unix)]
        assert_debug(
            &Target::new_unix_path("/tmp/test.sock").unwrap(),
            "/tmp/test.sock",
        );

        #[cfg(any(target_os = "linux", target_os = "android"))]
        assert_debug(&Target::new_unix_domain("test").unwrap(), "@test");
    }
}