1use std::{
2 borrow::Cow,
3 net::{IpAddr, Ipv4Addr, SocketAddr},
4 path::Path,
5 sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
13pub struct TargetName {
15 inner: MaybeResolvedTarget,
16}
17
18impl std::fmt::Debug for TargetName {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(f, "{:?}", self.inner)
21 }
22}
23
24impl TargetName {
25 #[cfg(unix)]
27 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
28 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
29 Ok(Self {
30 inner: MaybeResolvedTarget::Resolved(path),
31 })
32 }
33
34 #[cfg(any(target_os = "linux", target_os = "android"))]
36 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
37 use std::os::linux::net::SocketAddrExt;
38 let domain =
39 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
40 Ok(Self {
41 inner: MaybeResolvedTarget::Resolved(domain),
42 })
43 }
44
45 #[allow(private_bounds)]
47 pub fn new_tcp(host: impl TcpResolve) -> Self {
48 Self { inner: host.into() }
49 }
50
51 pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
53 use std::net::ToSocketAddrs;
54 let mut result = Vec::new();
55 match &self.inner {
56 MaybeResolvedTarget::Resolved(addr) => {
57 return Ok(vec![addr.clone()]);
58 }
59 MaybeResolvedTarget::Unresolved(host, port, _interface) => {
60 let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
61 result.extend(addrs.map(ResolvedTarget::SocketAddr));
62 }
63 }
64 Ok(result)
65 }
66}
67
68#[derive(Clone, Debug)]
69pub struct Target {
70 inner: TargetInner,
71}
72
73#[allow(private_bounds)]
74impl Target {
75 pub fn new(name: TargetName) -> Self {
76 Self {
77 inner: TargetInner::NoTls(name.inner),
78 }
79 }
80
81 pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
82 Self {
83 inner: TargetInner::Tls(name.inner, params.into()),
84 }
85 }
86
87 pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
88 Self {
89 inner: TargetInner::StartTls(name.inner, params.into()),
90 }
91 }
92
93 pub fn new_resolved(target: ResolvedTarget) -> Self {
94 Self {
95 inner: TargetInner::NoTls(target.into()),
96 }
97 }
98
99 pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
100 Self {
101 inner: TargetInner::Tls(target.into(), params.into()),
102 }
103 }
104
105 pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
106 Self {
107 inner: TargetInner::StartTls(target.into(), params.into()),
108 }
109 }
110
111 #[cfg(unix)]
113 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
114 let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
115 Ok(Self {
116 inner: TargetInner::NoTls(path.into()),
117 })
118 }
119
120 #[cfg(any(target_os = "linux", target_os = "android"))]
122 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
123 use std::os::linux::net::SocketAddrExt;
124 let domain =
125 ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
126 Ok(Self {
127 inner: TargetInner::NoTls(domain.into()),
128 })
129 }
130
131 pub fn new_tcp(host: impl TcpResolve) -> Self {
133 Self {
134 inner: TargetInner::NoTls(host.into()),
135 }
136 }
137
138 pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
140 Self {
141 inner: TargetInner::Tls(host.into(), params.into()),
142 }
143 }
144
145 pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
147 Self {
148 inner: TargetInner::StartTls(host.into(), params.into()),
149 }
150 }
151
152 pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
153 if self.maybe_resolved().path().is_some() {
155 return None;
156 }
157
158 let params = params.into();
159
160 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
162 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
163 ));
164
165 match std::mem::replace(&mut self.inner, no_target) {
166 TargetInner::NoTls(target) => {
167 self.inner = TargetInner::Tls(target, params);
168 Some(None)
169 }
170 TargetInner::Tls(target, old_params) => {
171 self.inner = TargetInner::Tls(target, params);
172 Some(Some(old_params))
173 }
174 TargetInner::StartTls(target, old_params) => {
175 self.inner = TargetInner::StartTls(target, params);
176 Some(Some(old_params))
177 }
178 }
179 }
180
181 pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
182 let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
184 ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
185 ));
186
187 match std::mem::replace(&mut self.inner, no_target) {
188 TargetInner::NoTls(target) => {
189 self.inner = TargetInner::NoTls(target);
190 None
191 }
192 TargetInner::Tls(target, old_params) => {
193 self.inner = TargetInner::NoTls(target);
194 Some(old_params)
195 }
196 TargetInner::StartTls(target, old_params) => {
197 self.inner = TargetInner::NoTls(target);
198 Some(old_params)
199 }
200 }
201 }
202
203 pub fn is_tcp(&self) -> bool {
205 self.maybe_resolved().port().is_some()
206 }
207
208 pub fn port(&self) -> Option<u16> {
211 self.maybe_resolved().port()
212 }
213
214 pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
217 self.maybe_resolved_mut().set_port(port)
218 }
219
220 pub fn path(&self) -> Option<&Path> {
223 self.maybe_resolved().path()
224 }
225
226 pub fn host(&self) -> Option<Cow<str>> {
231 self.maybe_resolved().host()
232 }
233
234 pub fn name(&self) -> Option<ServerName> {
238 self.maybe_resolved().name()
239 }
240
241 pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
244 self.maybe_resolved().tcp()
245 }
246
247 pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
248 match &self.inner {
249 TargetInner::NoTls(target) => target,
250 TargetInner::Tls(target, _) => target,
251 TargetInner::StartTls(target, _) => target,
252 }
253 }
254
255 pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
256 match &mut self.inner {
257 TargetInner::NoTls(target) => target,
258 TargetInner::Tls(target, _) => target,
259 TargetInner::StartTls(target, _) => target,
260 }
261 }
262
263 pub(crate) fn is_starttls(&self) -> bool {
264 matches!(self.inner, TargetInner::StartTls(_, _))
265 }
266
267 pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
268 match &self.inner {
269 TargetInner::NoTls(_) => None,
270 TargetInner::Tls(_, params) => Some(params),
271 TargetInner::StartTls(_, params) => Some(params),
272 }
273 }
274}
275
276#[derive(Clone, derive_more::From)]
277pub(crate) enum MaybeResolvedTarget {
278 Resolved(ResolvedTarget),
279 Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
280}
281
282impl std::fmt::Debug for MaybeResolvedTarget {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 match self {
285 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
286 if let SocketAddr::V6(addr) = addr {
287 if addr.scope_id() != 0 {
288 write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
289 } else {
290 write!(f, "[{}]:{}", addr.ip(), addr.port())
291 }
292 } else {
293 write!(f, "{}:{}", addr.ip(), addr.port())
294 }
295 }
296 #[cfg(unix)]
297 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
298 if let Some(path) = addr.as_pathname() {
299 return write!(f, "{}", path.to_string_lossy());
300 } else {
301 #[cfg(any(target_os = "linux", target_os = "android"))]
302 {
303 use std::os::linux::net::SocketAddrExt;
304 if let Some(name) = addr.as_abstract_name() {
305 return write!(f, "@{}", String::from_utf8_lossy(name));
306 }
307 }
308 }
309 Ok(())
310 }
311 MaybeResolvedTarget::Unresolved(host, port, interface) => {
312 write!(f, "{}:{}", host, port)?;
313 if let Some(interface) = interface {
314 write!(f, "%{}", interface)?;
315 }
316 Ok(())
317 }
318 }
319 }
320}
321
322impl MaybeResolvedTarget {
323 fn name(&self) -> Option<ServerName> {
324 match self {
325 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
326 Some(ServerName::IpAddress(addr.ip().into()))
327 }
328 MaybeResolvedTarget::Unresolved(host, _, _) => {
329 Some(ServerName::DnsName(host.to_string().try_into().ok()?))
330 }
331 _ => None,
332 }
333 }
334
335 fn tcp(&self) -> Option<(Cow<str>, u16)> {
336 match self {
337 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
338 Some((Cow::Owned(addr.ip().to_string()), addr.port()))
339 }
340 MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
341 _ => None,
342 }
343 }
344
345 fn path(&self) -> Option<&Path> {
346 match self {
347 #[cfg(unix)]
348 MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
349 addr.as_pathname()
350 }
351 _ => None,
352 }
353 }
354
355 fn host(&self) -> Option<Cow<str>> {
356 match self {
357 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
358 Some(Cow::Owned(addr.ip().to_string()))
359 }
360 MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
361 _ => None,
362 }
363 }
364
365 fn port(&self) -> Option<u16> {
366 match self {
367 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
368 MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
369 _ => None,
370 }
371 }
372
373 fn set_port(&mut self, new_port: u16) -> Option<u16> {
374 match self {
375 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
376 let old_port = addr.port();
377 addr.set_port(new_port);
378 Some(old_port)
379 }
380 MaybeResolvedTarget::Unresolved(_, port, _) => {
381 let old_port = *port;
382 *port = new_port;
383 Some(old_port)
384 }
385 _ => None,
386 }
387 }
388}
389
390#[derive(Clone, Debug)]
392enum TargetInner {
393 NoTls(MaybeResolvedTarget),
394 Tls(MaybeResolvedTarget, Arc<TlsParameters>),
395 StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
396}
397
398#[derive(Clone, Debug, derive_more::From)]
399pub enum ResolvedTarget {
401 SocketAddr(std::net::SocketAddr),
402 #[cfg(unix)]
403 UnixSocketAddr(std::os::unix::net::SocketAddr),
404}
405
406impl ResolvedTarget {
407 pub fn tcp(&self) -> Option<SocketAddr> {
408 match self {
409 ResolvedTarget::SocketAddr(addr) => Some(*addr),
410 _ => None,
411 }
412 }
413}
414
415pub trait LocalAddress {
416 fn local_address(&self) -> std::io::Result<ResolvedTarget>;
417}
418
419trait TcpResolve {
420 fn into(self) -> MaybeResolvedTarget;
421}
422
423impl<S: AsRef<str>> TcpResolve for (S, u16) {
424 fn into(self) -> MaybeResolvedTarget {
425 if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
426 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
427 } else {
428 MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
429 }
430 }
431}
432
433impl TcpResolve for SocketAddr {
434 fn into(self) -> MaybeResolvedTarget {
435 MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
436 }
437}
438
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(
            target.name(),
            Some(ServerName::DnsName("localhost".try_into().unwrap()))
        );
    }

    #[test]
    fn test_target_accessors() {
        let mut target = Target::new_tcp(("localhost", 5432));
        assert!(target.is_tcp());
        assert_eq!(target.port(), Some(5432));
        assert_eq!(target.host().as_deref(), Some("localhost"));
        assert_eq!(target.path(), None);

        assert_eq!(target.try_set_port(6543), Some(5432));
        assert_eq!(target.port(), Some(6543));
        assert_eq!(target.tcp(), Some((Cow::Borrowed("localhost"), 6543)));
    }

    #[test]
    fn test_target_ip_name() {
        let target = Target::new_tcp(("127.0.0.1", 5432));
        assert_eq!(
            target.name(),
            Some(ServerName::IpAddress(IpAddr::V4(Ipv4Addr::LOCALHOST).into()))
        );
        assert_eq!(target.host().as_deref(), Some("127.0.0.1"));
    }

    #[test]
    fn test_to_addrs_sync_resolved() {
        // An IP literal is already resolved, so no DNS lookup happens here.
        let target = TargetName::new_tcp(("127.0.0.1", 5432));
        let addrs = target.to_addrs_sync().unwrap();
        let tcp: Vec<_> = addrs.iter().map(ResolvedTarget::tcp).collect();
        assert_eq!(
            tcp,
            vec![Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5432))]
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_unix_target() {
        let target = Target::new_unix_path("/tmp/test.sock").unwrap();
        assert!(!target.is_tcp());
        assert_eq!(target.port(), None);
        assert_eq!(target.path(), Some(Path::new("/tmp/test.sock")));
        assert_eq!(target.name(), None);
    }

    #[test]
    fn test_target_name() {
        let target = TargetName::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = TargetName::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = TargetName::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        let target = TargetName::new_tcp(SocketAddr::V6(SocketAddrV6::new(
            "fe80::1ff:fe23:4567:890a".parse().unwrap(),
            5432,
            0,
            2,
        )));
        assert_eq!(format!("{target:?}"), "[fe80::1ff:fe23:4567:890a%2]:5432");

        // `new_unix_path` only exists on Unix; calling it unconditionally broke
        // non-Unix test builds.
        #[cfg(unix)]
        {
            let target = TargetName::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = TargetName::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }
}