1use std::{
2 borrow::Cow,
3 net::{IpAddr, Ipv4Addr, SocketAddr},
4 path::Path,
5 sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
13pub struct TargetName {
15 inner: MaybeResolvedTarget,
16}
17
18impl std::fmt::Debug for TargetName {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(f, "{:?}", self.inner)
21 }
22}
23
24impl TargetName {
25 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
27 #[cfg(unix)]
28 {
29 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
30 Ok(Self {
31 inner: MaybeResolvedTarget::Resolved(path),
32 })
33 }
34 #[cfg(not(unix))]
35 {
36 Err(std::io::Error::new(
37 std::io::ErrorKind::Unsupported,
38 "Unix sockets are not supported on this platform",
39 ))
40 }
41 }
42
43 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
45 #[cfg(any(target_os = "linux", target_os = "android"))]
46 {
47 use std::os::linux::net::SocketAddrExt;
48 let domain =
49 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
50 Ok(Self {
51 inner: MaybeResolvedTarget::Resolved(domain),
52 })
53 }
54 #[cfg(not(any(target_os = "linux", target_os = "android")))]
55 {
56 Err(std::io::Error::new(
57 std::io::ErrorKind::Unsupported,
58 "Unix domain sockets are not supported on this platform",
59 ))
60 }
61 }
62
63 #[allow(private_bounds)]
65 pub fn new_tcp(host: impl TcpResolve) -> Self {
66 Self { inner: host.into() }
67 }
68
69 pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
71 use std::net::ToSocketAddrs;
72 let mut result = Vec::new();
73 match &self.inner {
74 MaybeResolvedTarget::Resolved(addr) => {
75 return Ok(vec![addr.clone()]);
76 }
77 MaybeResolvedTarget::Unresolved(host, port, _interface) => {
78 let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
79 result.extend(addrs.map(ResolvedTarget::SocketAddr));
80 }
81 }
82 Ok(result)
83 }
84}
85
86#[derive(Clone, Debug)]
87pub struct Target {
88 inner: TargetInner,
89}
90
91#[allow(private_bounds)]
92impl Target {
93 pub fn new(name: TargetName) -> Self {
94 Self {
95 inner: TargetInner::NoTls(name.inner),
96 }
97 }
98
99 pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
100 Self {
101 inner: TargetInner::Tls(name.inner, params.into()),
102 }
103 }
104
105 pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
106 Self {
107 inner: TargetInner::StartTls(name.inner, params.into()),
108 }
109 }
110
111 pub fn new_resolved(target: ResolvedTarget) -> Self {
112 Self {
113 inner: TargetInner::NoTls(target.into()),
114 }
115 }
116
117 pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
118 Self {
119 inner: TargetInner::Tls(target.into(), params.into()),
120 }
121 }
122
123 pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
124 Self {
125 inner: TargetInner::StartTls(target.into(), params.into()),
126 }
127 }
128
129 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
131 #[cfg(unix)]
132 {
133 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
134 Ok(Self {
135 inner: TargetInner::NoTls(path.into()),
136 })
137 }
138 #[cfg(not(unix))]
139 {
140 Err(std::io::Error::new(
141 std::io::ErrorKind::Unsupported,
142 "Unix sockets are not supported on this platform",
143 ))
144 }
145 }
146
147 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
149 #[cfg(any(target_os = "linux", target_os = "android"))]
150 {
151 use std::os::linux::net::SocketAddrExt;
152 let domain =
153 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
154 Ok(Self {
155 inner: TargetInner::NoTls(domain.into()),
156 })
157 }
158 #[cfg(not(any(target_os = "linux", target_os = "android")))]
159 {
160 Err(std::io::Error::new(
161 std::io::ErrorKind::Unsupported,
162 "Unix domain sockets are not supported on this platform",
163 ))
164 }
165 }
166
167 pub fn new_tcp(host: impl TcpResolve) -> Self {
169 Self {
170 inner: TargetInner::NoTls(host.into()),
171 }
172 }
173
174 pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
176 Self {
177 inner: TargetInner::Tls(host.into(), params.into()),
178 }
179 }
180
181 pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
183 Self {
184 inner: TargetInner::StartTls(host.into(), params.into()),
185 }
186 }
187
188 pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
189 if self.maybe_resolved().path().is_some() {
191 return None;
192 }
193
194 let params = params.into();
195
196 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
198 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
199 ));
200
201 match std::mem::replace(&mut self.inner, no_target) {
202 TargetInner::NoTls(target) => {
203 self.inner = TargetInner::Tls(target, params);
204 Some(None)
205 }
206 TargetInner::Tls(target, old_params) => {
207 self.inner = TargetInner::Tls(target, params);
208 Some(Some(old_params))
209 }
210 TargetInner::StartTls(target, old_params) => {
211 self.inner = TargetInner::StartTls(target, params);
212 Some(Some(old_params))
213 }
214 }
215 }
216
217 pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
218 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
220 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
221 ));
222
223 match std::mem::replace(&mut self.inner, no_target) {
224 TargetInner::NoTls(target) => {
225 self.inner = TargetInner::NoTls(target);
226 None
227 }
228 TargetInner::Tls(target, old_params) => {
229 self.inner = TargetInner::NoTls(target);
230 Some(old_params)
231 }
232 TargetInner::StartTls(target, old_params) => {
233 self.inner = TargetInner::NoTls(target);
234 Some(old_params)
235 }
236 }
237 }
238
239 pub fn is_tcp(&self) -> bool {
241 self.maybe_resolved().port().is_some()
242 }
243
244 pub fn port(&self) -> Option<u16> {
247 self.maybe_resolved().port()
248 }
249
250 pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
253 self.maybe_resolved_mut().set_port(port)
254 }
255
256 pub fn path(&self) -> Option<&Path> {
259 self.maybe_resolved().path()
260 }
261
262 pub fn host(&self) -> Option<Cow<str>> {
267 self.maybe_resolved().host()
268 }
269
270 pub fn name(&self) -> Option<ServerName> {
274 self.maybe_resolved().name()
275 }
276
277 pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
280 self.maybe_resolved().tcp()
281 }
282
283 pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
284 match &self.inner {
285 TargetInner::NoTls(target) => target,
286 TargetInner::Tls(target, _) => target,
287 TargetInner::StartTls(target, _) => target,
288 }
289 }
290
291 pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
292 match &mut self.inner {
293 TargetInner::NoTls(target) => target,
294 TargetInner::Tls(target, _) => target,
295 TargetInner::StartTls(target, _) => target,
296 }
297 }
298
299 pub(crate) fn is_starttls(&self) -> bool {
300 matches!(self.inner, TargetInner::StartTls(_, _))
301 }
302
303 pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
304 match &self.inner {
305 TargetInner::NoTls(_) => None,
306 TargetInner::Tls(_, params) => Some(params),
307 TargetInner::StartTls(_, params) => Some(params),
308 }
309 }
310}
311
312#[derive(Clone, derive_more::From)]
313pub(crate) enum MaybeResolvedTarget {
314 Resolved(ResolvedTarget),
315 Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
316}
317
318impl std::fmt::Debug for MaybeResolvedTarget {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 match self {
321 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
322 if let SocketAddr::V6(addr) = addr {
323 if addr.scope_id() != 0 {
324 write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
325 } else {
326 write!(f, "[{}]:{}", addr.ip(), addr.port())
327 }
328 } else {
329 write!(f, "{}:{}", addr.ip(), addr.port())
330 }
331 }
332 #[cfg(unix)]
333 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
334 if let Some(path) = addr.as_pathname() {
335 return write!(f, "{}", path.to_string_lossy());
336 } else {
337 #[cfg(any(target_os = "linux", target_os = "android"))]
338 {
339 use std::os::linux::net::SocketAddrExt;
340 if let Some(name) = addr.as_abstract_name() {
341 return write!(f, "@{}", String::from_utf8_lossy(name));
342 }
343 }
344 }
345 Ok(())
346 }
347 MaybeResolvedTarget::Unresolved(host, port, interface) => {
348 write!(f, "{}:{}", host, port)?;
349 if let Some(interface) = interface {
350 write!(f, "%{}", interface)?;
351 }
352 Ok(())
353 }
354 }
355 }
356}
357
358impl MaybeResolvedTarget {
359 fn name(&self) -> Option<ServerName> {
360 match self {
361 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
362 Some(ServerName::IpAddress(addr.ip().into()))
363 }
364 MaybeResolvedTarget::Unresolved(host, _, _) => {
365 Some(ServerName::DnsName(host.to_string().try_into().ok()?))
366 }
367 _ => None,
368 }
369 }
370
371 fn tcp(&self) -> Option<(Cow<str>, u16)> {
372 match self {
373 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
374 Some((Cow::Owned(addr.ip().to_string()), addr.port()))
375 }
376 MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
377 _ => None,
378 }
379 }
380
381 fn path(&self) -> Option<&Path> {
382 match self {
383 #[cfg(unix)]
384 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
385 addr.as_pathname()
386 }
387 _ => None,
388 }
389 }
390
391 fn host(&self) -> Option<Cow<str>> {
392 match self {
393 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
394 Some(Cow::Owned(addr.ip().to_string()))
395 }
396 MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
397 _ => None,
398 }
399 }
400
401 fn port(&self) -> Option<u16> {
402 match self {
403 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
404 MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
405 _ => None,
406 }
407 }
408
409 fn set_port(&mut self, new_port: u16) -> Option<u16> {
410 match self {
411 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
412 let old_port = addr.port();
413 addr.set_port(new_port);
414 Some(old_port)
415 }
416 MaybeResolvedTarget::Unresolved(_, port, _) => {
417 let old_port = *port;
418 *port = new_port;
419 Some(old_port)
420 }
421 _ => None,
422 }
423 }
424}
425
/// Internal state of a `Target`: the address plus the TLS mode and, where
/// applicable, the shared TLS parameters.
#[derive(Clone, Debug)]
enum TargetInner {
    /// Address with no TLS parameters attached.
    NoTls(MaybeResolvedTarget),
    /// Address with TLS parameters attached.
    Tls(MaybeResolvedTarget, Arc<TlsParameters>),
    /// Address with TLS parameters attached; this is the variant
    /// `Target::is_starttls` reports. Presumably the connection starts in
    /// plaintext and upgrades later — confirm against the connector code.
    StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
}
433
434#[derive(Clone, Debug, derive_more::From)]
435pub enum ResolvedTarget {
437 SocketAddr(std::net::SocketAddr),
438 #[cfg(unix)]
439 UnixSocketAddr(std::os::unix::net::SocketAddr),
440}
441
442impl ResolvedTarget {
443 pub fn tcp(&self) -> Option<SocketAddr> {
444 match self {
445 ResolvedTarget::SocketAddr(addr) => Some(*addr),
446 _ => None,
447 }
448 }
449}
450
/// Implemented by types that can report a local address as a
/// `ResolvedTarget`.
pub trait LocalAddress {
    /// Returns the local address, or the I/O error encountered while
    /// obtaining it.
    fn local_address(&self) -> std::io::Result<ResolvedTarget>;
}
454
455trait TcpResolve {
456 fn into(self) -> MaybeResolvedTarget;
457}
458
459impl<S: AsRef<str>> TcpResolve for (S, u16) {
460 fn into(self) -> MaybeResolvedTarget {
461 if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
462 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
463 } else {
464 MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
465 }
466 }
467}
468
469impl TcpResolve for SocketAddr {
470 fn into(self) -> MaybeResolvedTarget {
471 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
472 }
473}
474
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    /// A host that is not an IP literal is exposed as a DNS server name.
    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(
            target.name(),
            Some(ServerName::DnsName("localhost".try_into().unwrap()))
        );
    }

    /// Checks the `Debug` rendering of each kind of target name.
    #[test]
    fn test_target_name() {
        let target = TargetName::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = TargetName::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = TargetName::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        let target = TargetName::new_tcp(SocketAddr::V6(SocketAddrV6::new(
            "fe80::1ff:fe23:4567:890a".parse().unwrap(),
            5432,
            0,
            2,
        )));
        assert_eq!(format!("{target:?}"), "[fe80::1ff:fe23:4567:890a%2]:5432");

        // `new_unix_path` returns `Unsupported` off Unix, so unwrapping it
        // there would panic; only exercise it where it can succeed.
        #[cfg(unix)]
        {
            let target = TargetName::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = TargetName::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }
}