1use std::{
2 borrow::Cow,
3 net::{IpAddr, Ipv4Addr, SocketAddr},
4 path::Path,
5 sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
13pub struct TargetName {
15 inner: MaybeResolvedTarget,
16}
17
18impl std::fmt::Debug for TargetName {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(f, "{:?}", self.inner)
21 }
22}
23
24impl TargetName {
25 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
27 #[cfg(unix)]
28 {
29 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
30 Ok(Self {
31 inner: MaybeResolvedTarget::Resolved(path),
32 })
33 }
34 #[cfg(not(unix))]
35 {
36 Err(std::io::Error::new(
37 std::io::ErrorKind::Unsupported,
38 "Unix sockets are not supported on this platform",
39 ))
40 }
41 }
42
43 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
45 #[cfg(any(target_os = "linux", target_os = "android"))]
46 {
47 use std::os::linux::net::SocketAddrExt;
48 let domain =
49 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
50 Ok(Self {
51 inner: MaybeResolvedTarget::Resolved(domain),
52 })
53 }
54 #[cfg(not(any(target_os = "linux", target_os = "android")))]
55 {
56 Err(std::io::Error::new(
57 std::io::ErrorKind::Unsupported,
58 "Unix domain sockets are not supported on this platform",
59 ))
60 }
61 }
62
63 #[allow(private_bounds)]
65 pub fn new_tcp(host: impl TcpResolve) -> Self {
66 Self { inner: host.into() }
67 }
68
69 pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
71 use std::net::ToSocketAddrs;
72 let mut result = Vec::new();
73 match &self.inner {
74 MaybeResolvedTarget::Resolved(addr) => {
75 return Ok(vec![addr.clone()]);
76 }
77 MaybeResolvedTarget::Unresolved(host, port, _interface) => {
78 let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
79 result.extend(addrs.map(ResolvedTarget::SocketAddr));
80 }
81 }
82 Ok(result)
83 }
84}
85
86#[derive(Clone)]
87pub struct Target {
88 inner: TargetInner,
89}
90
91impl std::fmt::Debug for Target {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match &self.inner {
94 TargetInner::NoTls(target) => write!(f, "{:?}", target),
95 TargetInner::Tls(target, _) => write!(f, "{:?} (TLS)", target),
96 TargetInner::StartTls(target, _) => write!(f, "{:?} (STARTTLS)", target),
97 }
98 }
99}
100
101#[allow(private_bounds)]
102impl Target {
103 pub fn new(name: TargetName) -> Self {
104 Self {
105 inner: TargetInner::NoTls(name.inner),
106 }
107 }
108
109 pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
110 Self {
111 inner: TargetInner::Tls(name.inner, params.into()),
112 }
113 }
114
115 pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
116 Self {
117 inner: TargetInner::StartTls(name.inner, params.into()),
118 }
119 }
120
121 pub fn new_resolved(target: ResolvedTarget) -> Self {
122 Self {
123 inner: TargetInner::NoTls(target.into()),
124 }
125 }
126
127 pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
128 Self {
129 inner: TargetInner::Tls(target.into(), params.into()),
130 }
131 }
132
133 pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
134 Self {
135 inner: TargetInner::StartTls(target.into(), params.into()),
136 }
137 }
138
139 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
141 #[cfg(unix)]
142 {
143 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
144 Ok(Self {
145 inner: TargetInner::NoTls(path.into()),
146 })
147 }
148 #[cfg(not(unix))]
149 {
150 Err(std::io::Error::new(
151 std::io::ErrorKind::Unsupported,
152 "Unix sockets are not supported on this platform",
153 ))
154 }
155 }
156
157 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
159 #[cfg(any(target_os = "linux", target_os = "android"))]
160 {
161 use std::os::linux::net::SocketAddrExt;
162 let domain =
163 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
164 Ok(Self {
165 inner: TargetInner::NoTls(domain.into()),
166 })
167 }
168 #[cfg(not(any(target_os = "linux", target_os = "android")))]
169 {
170 Err(std::io::Error::new(
171 std::io::ErrorKind::Unsupported,
172 "Unix domain sockets are not supported on this platform",
173 ))
174 }
175 }
176
177 pub fn new_tcp(host: impl TcpResolve) -> Self {
179 Self {
180 inner: TargetInner::NoTls(host.into()),
181 }
182 }
183
184 pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
186 Self {
187 inner: TargetInner::Tls(host.into(), params.into()),
188 }
189 }
190
191 pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
193 Self {
194 inner: TargetInner::StartTls(host.into(), params.into()),
195 }
196 }
197
198 pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
199 if self.maybe_resolved().path().is_some() {
201 return None;
202 }
203
204 let params = params.into();
205
206 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
208 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
209 ));
210
211 match std::mem::replace(&mut self.inner, no_target) {
212 TargetInner::NoTls(target) => {
213 self.inner = TargetInner::Tls(target, params);
214 Some(None)
215 }
216 TargetInner::Tls(target, old_params) => {
217 self.inner = TargetInner::Tls(target, params);
218 Some(Some(old_params))
219 }
220 TargetInner::StartTls(target, old_params) => {
221 self.inner = TargetInner::StartTls(target, params);
222 Some(Some(old_params))
223 }
224 }
225 }
226
227 pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
228 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
230 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
231 ));
232
233 match std::mem::replace(&mut self.inner, no_target) {
234 TargetInner::NoTls(target) => {
235 self.inner = TargetInner::NoTls(target);
236 None
237 }
238 TargetInner::Tls(target, old_params) => {
239 self.inner = TargetInner::NoTls(target);
240 Some(old_params)
241 }
242 TargetInner::StartTls(target, old_params) => {
243 self.inner = TargetInner::NoTls(target);
244 Some(old_params)
245 }
246 }
247 }
248
249 pub fn is_tcp(&self) -> bool {
251 self.maybe_resolved().port().is_some()
252 }
253
254 pub fn port(&self) -> Option<u16> {
257 self.maybe_resolved().port()
258 }
259
260 pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
263 self.maybe_resolved_mut().set_port(port)
264 }
265
266 pub fn path(&self) -> Option<&Path> {
269 self.maybe_resolved().path()
270 }
271
272 pub fn host(&self) -> Option<Cow<str>> {
277 self.maybe_resolved().host()
278 }
279
280 pub fn name(&self) -> Option<ServerName> {
284 self.maybe_resolved().name()
285 }
286
287 pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
290 self.maybe_resolved().tcp()
291 }
292
293 pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
294 match &self.inner {
295 TargetInner::NoTls(target) => target,
296 TargetInner::Tls(target, _) => target,
297 TargetInner::StartTls(target, _) => target,
298 }
299 }
300
301 pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
302 match &mut self.inner {
303 TargetInner::NoTls(target) => target,
304 TargetInner::Tls(target, _) => target,
305 TargetInner::StartTls(target, _) => target,
306 }
307 }
308
309 pub(crate) fn is_starttls(&self) -> bool {
310 matches!(self.inner, TargetInner::StartTls(_, _))
311 }
312
313 pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
314 match &self.inner {
315 TargetInner::NoTls(_) => None,
316 TargetInner::Tls(_, params) => Some(params),
317 TargetInner::StartTls(_, params) => Some(params),
318 }
319 }
320}
321
322#[derive(Clone, derive_more::From)]
323pub(crate) enum MaybeResolvedTarget {
324 Resolved(ResolvedTarget),
325 Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
326}
327
328impl std::fmt::Debug for MaybeResolvedTarget {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 match self {
331 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
332 if let SocketAddr::V6(addr) = addr {
333 if addr.scope_id() != 0 {
334 write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
335 } else {
336 write!(f, "[{}]:{}", addr.ip(), addr.port())
337 }
338 } else {
339 write!(f, "{}:{}", addr.ip(), addr.port())
340 }
341 }
342 #[cfg(unix)]
343 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
344 if let Some(path) = addr.as_pathname() {
345 return write!(f, "{}", path.to_string_lossy());
346 } else {
347 #[cfg(any(target_os = "linux", target_os = "android"))]
348 {
349 use std::os::linux::net::SocketAddrExt;
350 if let Some(name) = addr.as_abstract_name() {
351 return write!(f, "@{}", String::from_utf8_lossy(name));
352 }
353 }
354 }
355 Ok(())
356 }
357 MaybeResolvedTarget::Unresolved(host, port, interface) => {
358 write!(f, "{}:{}", host, port)?;
359 if let Some(interface) = interface {
360 write!(f, "%{}", interface)?;
361 }
362 Ok(())
363 }
364 }
365 }
366}
367
368impl MaybeResolvedTarget {
369 fn name(&self) -> Option<ServerName> {
370 match self {
371 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
372 Some(ServerName::IpAddress(addr.ip().into()))
373 }
374 MaybeResolvedTarget::Unresolved(host, _, _) => {
375 Some(ServerName::DnsName(host.to_string().try_into().ok()?))
376 }
377 _ => None,
378 }
379 }
380
381 fn tcp(&self) -> Option<(Cow<str>, u16)> {
382 match self {
383 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
384 Some((Cow::Owned(addr.ip().to_string()), addr.port()))
385 }
386 MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
387 _ => None,
388 }
389 }
390
391 fn path(&self) -> Option<&Path> {
392 match self {
393 #[cfg(unix)]
394 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
395 addr.as_pathname()
396 }
397 _ => None,
398 }
399 }
400
401 fn host(&self) -> Option<Cow<str>> {
402 match self {
403 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
404 Some(Cow::Owned(addr.ip().to_string()))
405 }
406 MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
407 _ => None,
408 }
409 }
410
411 fn port(&self) -> Option<u16> {
412 match self {
413 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
414 MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
415 _ => None,
416 }
417 }
418
419 fn set_port(&mut self, new_port: u16) -> Option<u16> {
420 match self {
421 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
422 let old_port = addr.port();
423 addr.set_port(new_port);
424 Some(old_port)
425 }
426 MaybeResolvedTarget::Unresolved(_, port, _) => {
427 let old_port = *port;
428 *port = new_port;
429 Some(old_port)
430 }
431 _ => None,
432 }
433 }
434}
435
436#[derive(Clone, Debug)]
438enum TargetInner {
439 NoTls(MaybeResolvedTarget),
440 Tls(MaybeResolvedTarget, Arc<TlsParameters>),
441 StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
442}
443
444#[derive(Clone, Debug, derive_more::From)]
445pub enum ResolvedTarget {
447 SocketAddr(std::net::SocketAddr),
448 #[cfg(unix)]
449 UnixSocketAddr(std::os::unix::net::SocketAddr),
450}
451
452impl ResolvedTarget {
453 pub fn tcp(&self) -> Option<SocketAddr> {
454 match self {
455 ResolvedTarget::SocketAddr(addr) => Some(*addr),
456 _ => None,
457 }
458 }
459}
460
461pub trait LocalAddress {
462 fn local_address(&self) -> std::io::Result<ResolvedTarget>;
463}
464
465trait TcpResolve {
466 fn into(self) -> MaybeResolvedTarget;
467}
468
469impl<S: AsRef<str>> TcpResolve for (S, u16) {
470 fn into(self) -> MaybeResolvedTarget {
471 if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
472 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
473 } else {
474 MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
475 }
476 }
477}
478
479impl TcpResolve for SocketAddr {
480 fn into(self) -> MaybeResolvedTarget {
481 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
482 }
483}
484
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(
            target.name(),
            Some(ServerName::DnsName("localhost".try_into().unwrap()))
        );
    }

    #[test]
    fn test_target_name() {
        let target = TargetName::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = TargetName::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = TargetName::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        let target = TargetName::new_tcp(SocketAddr::V6(SocketAddrV6::new(
            "fe80::1ff:fe23:4567:890a".parse().unwrap(),
            5432,
            0,
            2,
        )));
        assert_eq!(format!("{target:?}"), "[fe80::1ff:fe23:4567:890a%2]:5432");

        // Path-based Unix sockets only exist on Unix. Elsewhere the
        // constructor returns `Unsupported`, so an unconditional `unwrap`
        // would panic.
        #[cfg(unix)]
        {
            let target = TargetName::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }
        #[cfg(not(unix))]
        {
            let err = TargetName::new_unix_path("/tmp/test.sock").unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::Unsupported);
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = TargetName::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }

    #[test]
    fn test_target_debug() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = Target::new_tcp_tls(("localhost", 5432), TlsParameters::default());
        assert_eq!(format!("{target:?}"), "localhost:5432 (TLS)");

        let target = Target::new_tcp_starttls(("localhost", 5432), TlsParameters::default());
        assert_eq!(format!("{target:?}"), "localhost:5432 (STARTTLS)");

        let target = Target::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = Target::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        #[cfg(unix)]
        {
            let target = Target::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = Target::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }

    #[test]
    fn test_target_port() {
        let mut target = Target::new_tcp(("localhost", 5432));
        assert!(target.is_tcp());
        assert_eq!(target.port(), Some(5432));
        assert_eq!(target.try_set_port(1234), Some(5432));
        assert_eq!(target.tcp(), Some((Cow::Borrowed("localhost"), 1234)));
        assert_eq!(format!("{target:?}"), "localhost:1234");

        let mut target = Target::new_tcp(("127.0.0.1", 5432));
        assert_eq!(target.try_set_port(1234), Some(5432));
        assert_eq!(target.host(), Some(Cow::Borrowed("127.0.0.1")));
        assert_eq!(format!("{target:?}"), "127.0.0.1:1234");
    }

    #[test]
    fn test_target_tls_toggle() {
        let mut target = Target::new_tcp(("localhost", 5432));
        assert!(target.maybe_ssl().is_none());
        assert!(matches!(
            target.try_set_tls(TlsParameters::default()),
            Some(None)
        ));
        assert_eq!(format!("{target:?}"), "localhost:5432 (TLS)");
        assert!(matches!(
            target.try_set_tls(TlsParameters::default()),
            Some(Some(_))
        ));
        assert!(target.try_remove_tls().is_some());
        assert_eq!(format!("{target:?}"), "localhost:5432");
        assert!(target.try_remove_tls().is_none());

        // Replacing parameters on a STARTTLS target keeps it STARTTLS.
        let mut target =
            Target::new_tcp_starttls(("localhost", 5432), TlsParameters::default());
        assert!(target.is_starttls());
        assert!(matches!(
            target.try_set_tls(TlsParameters::default()),
            Some(Some(_))
        ));
        assert!(target.is_starttls());
    }

    #[cfg(unix)]
    #[test]
    fn test_target_unix_path() {
        let mut target = Target::new_unix_path("/tmp/test.sock").unwrap();
        assert!(!target.is_tcp());
        assert_eq!(target.port(), None);
        assert_eq!(target.try_set_port(1), None);
        assert_eq!(target.path(), Some(Path::new("/tmp/test.sock")));
        assert!(target.try_set_tls(TlsParameters::default()).is_none());
        assert_eq!(format!("{target:?}"), "/tmp/test.sock");
    }
}