// gel_stream/common/target.rs

1use std::{
2    borrow::Cow,
3    net::{IpAddr, Ipv4Addr, SocketAddr},
4    path::Path,
5    sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
/// A target name describes the TCP or Unix socket that a client will connect to.
///
/// A `TargetName` holds only the address, with no TLS configuration. Pass it to
/// [`Target::new`], [`Target::new_tls`] or [`Target::new_starttls`] to build a
/// [`Target`].
pub struct TargetName {
    // Either an already-resolved socket address or a host/port pair that still
    // needs name resolution.
    inner: MaybeResolvedTarget,
}
17
18impl std::fmt::Debug for TargetName {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(f, "{:?}", self.inner)
21    }
22}
23
24impl TargetName {
25    /// Create a new target for a Unix socket.
26    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
27        #[cfg(unix)]
28        {
29            let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
30            Ok(Self {
31                inner: MaybeResolvedTarget::Resolved(path),
32            })
33        }
34        #[cfg(not(unix))]
35        {
36            Err(std::io::Error::new(
37                std::io::ErrorKind::Unsupported,
38                "Unix sockets are not supported on this platform",
39            ))
40        }
41    }
42
43    /// Create a new target for a Unix socket.
44    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
45        #[cfg(any(target_os = "linux", target_os = "android"))]
46        {
47            use std::os::linux::net::SocketAddrExt;
48            let domain =
49                ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
50            Ok(Self {
51                inner: MaybeResolvedTarget::Resolved(domain),
52            })
53        }
54        #[cfg(not(any(target_os = "linux", target_os = "android")))]
55        {
56            Err(std::io::Error::new(
57                std::io::ErrorKind::Unsupported,
58                "Unix domain sockets are not supported on this platform",
59            ))
60        }
61    }
62
63    /// Create a new target for a TCP socket.
64    #[allow(private_bounds)]
65    pub fn new_tcp(host: impl TcpResolve) -> Self {
66        Self { inner: host.into() }
67    }
68
69    /// Resolves the target addresses for a given host.
70    pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
71        use std::net::ToSocketAddrs;
72        let mut result = Vec::new();
73        match &self.inner {
74            MaybeResolvedTarget::Resolved(addr) => {
75                return Ok(vec![addr.clone()]);
76            }
77            MaybeResolvedTarget::Unresolved(host, port, _interface) => {
78                let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
79                result.extend(addrs.map(ResolvedTarget::SocketAddr));
80            }
81        }
82        Ok(result)
83    }
84}
85
/// A connection target: an address together with its TLS mode (no TLS, TLS,
/// or STARTTLS) and, when TLS is used, the shared TLS parameters.
#[derive(Clone, Debug)]
pub struct Target {
    inner: TargetInner,
}
90
91#[allow(private_bounds)]
92impl Target {
93    pub fn new(name: TargetName) -> Self {
94        Self {
95            inner: TargetInner::NoTls(name.inner),
96        }
97    }
98
99    pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
100        Self {
101            inner: TargetInner::Tls(name.inner, params.into()),
102        }
103    }
104
105    pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
106        Self {
107            inner: TargetInner::StartTls(name.inner, params.into()),
108        }
109    }
110
111    pub fn new_resolved(target: ResolvedTarget) -> Self {
112        Self {
113            inner: TargetInner::NoTls(target.into()),
114        }
115    }
116
117    pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
118        Self {
119            inner: TargetInner::Tls(target.into(), params.into()),
120        }
121    }
122
123    pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
124        Self {
125            inner: TargetInner::StartTls(target.into(), params.into()),
126        }
127    }
128
129    /// Create a new target for a Unix socket.
130    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
131        #[cfg(unix)]
132        {
133            let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
134            Ok(Self {
135                inner: TargetInner::NoTls(path.into()),
136            })
137        }
138        #[cfg(not(unix))]
139        {
140            Err(std::io::Error::new(
141                std::io::ErrorKind::Unsupported,
142                "Unix sockets are not supported on this platform",
143            ))
144        }
145    }
146
147    /// Create a new target for a Unix socket.
148    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
149        #[cfg(any(target_os = "linux", target_os = "android"))]
150        {
151            use std::os::linux::net::SocketAddrExt;
152            let domain =
153                ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
154            Ok(Self {
155                inner: TargetInner::NoTls(domain.into()),
156            })
157        }
158        #[cfg(not(any(target_os = "linux", target_os = "android")))]
159        {
160            Err(std::io::Error::new(
161                std::io::ErrorKind::Unsupported,
162                "Unix domain sockets are not supported on this platform",
163            ))
164        }
165    }
166
167    /// Create a new target for a TCP socket.
168    pub fn new_tcp(host: impl TcpResolve) -> Self {
169        Self {
170            inner: TargetInner::NoTls(host.into()),
171        }
172    }
173
174    /// Create a new target for a TCP socket with TLS.
175    pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
176        Self {
177            inner: TargetInner::Tls(host.into(), params.into()),
178        }
179    }
180
181    /// Create a new target for a TCP socket with STARTTLS.
182    pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
183        Self {
184            inner: TargetInner::StartTls(host.into(), params.into()),
185        }
186    }
187
188    pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
189        // Don't set TLS parameters on Unix sockets.
190        if self.maybe_resolved().path().is_some() {
191            return None;
192        }
193
194        let params = params.into();
195
196        // Temporary
197        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
198            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
199        ));
200
201        match std::mem::replace(&mut self.inner, no_target) {
202            TargetInner::NoTls(target) => {
203                self.inner = TargetInner::Tls(target, params);
204                Some(None)
205            }
206            TargetInner::Tls(target, old_params) => {
207                self.inner = TargetInner::Tls(target, params);
208                Some(Some(old_params))
209            }
210            TargetInner::StartTls(target, old_params) => {
211                self.inner = TargetInner::StartTls(target, params);
212                Some(Some(old_params))
213            }
214        }
215    }
216
217    pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
218        // Temporary
219        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
220            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
221        ));
222
223        match std::mem::replace(&mut self.inner, no_target) {
224            TargetInner::NoTls(target) => {
225                self.inner = TargetInner::NoTls(target);
226                None
227            }
228            TargetInner::Tls(target, old_params) => {
229                self.inner = TargetInner::NoTls(target);
230                Some(old_params)
231            }
232            TargetInner::StartTls(target, old_params) => {
233                self.inner = TargetInner::NoTls(target);
234                Some(old_params)
235            }
236        }
237    }
238
239    /// Check if the target is a TCP connection.
240    pub fn is_tcp(&self) -> bool {
241        self.maybe_resolved().port().is_some()
242    }
243
244    /// Get the port of the target. If the target type does not include a port,
245    /// this will return None.
246    pub fn port(&self) -> Option<u16> {
247        self.maybe_resolved().port()
248    }
249
250    /// Set the port of the target. If the target type does not include a port,
251    /// this will return None. Otherwise, it will return the old port.
252    pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
253        self.maybe_resolved_mut().set_port(port)
254    }
255
256    /// Get the path of the target. If the target type does not include a path,
257    /// this will return None.
258    pub fn path(&self) -> Option<&Path> {
259        self.maybe_resolved().path()
260    }
261
262    /// Get the host of the target. For resolved IP addresses, this is the
263    /// string representation of the IP address. For unresolved hostnames, this
264    /// is the hostname. If the target type does not include a host, this will
265    /// return None.
266    pub fn host(&self) -> Option<Cow<str>> {
267        self.maybe_resolved().host()
268    }
269
270    /// Get the name of the target. For resolved IP addresses, this is the
271    /// string representation of the IP address. For unresolved hostnames, this
272    /// is the hostname.
273    pub fn name(&self) -> Option<ServerName> {
274        self.maybe_resolved().name()
275    }
276
277    /// Get the host and port of the target. If the target type does not include
278    /// a host or port, this will return None.
279    pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
280        self.maybe_resolved().tcp()
281    }
282
283    pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
284        match &self.inner {
285            TargetInner::NoTls(target) => target,
286            TargetInner::Tls(target, _) => target,
287            TargetInner::StartTls(target, _) => target,
288        }
289    }
290
291    pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
292        match &mut self.inner {
293            TargetInner::NoTls(target) => target,
294            TargetInner::Tls(target, _) => target,
295            TargetInner::StartTls(target, _) => target,
296        }
297    }
298
299    pub(crate) fn is_starttls(&self) -> bool {
300        matches!(self.inner, TargetInner::StartTls(_, _))
301    }
302
303    pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
304        match &self.inner {
305            TargetInner::NoTls(_) => None,
306            TargetInner::Tls(_, params) => Some(params),
307            TargetInner::StartTls(_, params) => Some(params),
308        }
309    }
310}
311
/// A connection address that may or may not have been resolved yet.
#[derive(Clone, derive_more::From)]
pub(crate) enum MaybeResolvedTarget {
    /// A concrete socket address (TCP, or a Unix socket on Unix platforms).
    Resolved(ResolvedTarget),
    /// A host name and port still awaiting resolution, plus an optional
    /// interface (shown as a `%interface` suffix by `Debug`).
    Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
}
317
318impl std::fmt::Debug for MaybeResolvedTarget {
319    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320        match self {
321            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
322                if let SocketAddr::V6(addr) = addr {
323                    if addr.scope_id() != 0 {
324                        write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
325                    } else {
326                        write!(f, "[{}]:{}", addr.ip(), addr.port())
327                    }
328                } else {
329                    write!(f, "{}:{}", addr.ip(), addr.port())
330                }
331            }
332            #[cfg(unix)]
333            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
334                if let Some(path) = addr.as_pathname() {
335                    return write!(f, "{}", path.to_string_lossy());
336                } else {
337                    #[cfg(any(target_os = "linux", target_os = "android"))]
338                    {
339                        use std::os::linux::net::SocketAddrExt;
340                        if let Some(name) = addr.as_abstract_name() {
341                            return write!(f, "@{}", String::from_utf8_lossy(name));
342                        }
343                    }
344                }
345                Ok(())
346            }
347            MaybeResolvedTarget::Unresolved(host, port, interface) => {
348                write!(f, "{}:{}", host, port)?;
349                if let Some(interface) = interface {
350                    write!(f, "%{}", interface)?;
351                }
352                Ok(())
353            }
354        }
355    }
356}
357
358impl MaybeResolvedTarget {
359    fn name(&self) -> Option<ServerName> {
360        match self {
361            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
362                Some(ServerName::IpAddress(addr.ip().into()))
363            }
364            MaybeResolvedTarget::Unresolved(host, _, _) => {
365                Some(ServerName::DnsName(host.to_string().try_into().ok()?))
366            }
367            _ => None,
368        }
369    }
370
371    fn tcp(&self) -> Option<(Cow<str>, u16)> {
372        match self {
373            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
374                Some((Cow::Owned(addr.ip().to_string()), addr.port()))
375            }
376            MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
377            _ => None,
378        }
379    }
380
381    fn path(&self) -> Option<&Path> {
382        match self {
383            #[cfg(unix)]
384            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
385                addr.as_pathname()
386            }
387            _ => None,
388        }
389    }
390
391    fn host(&self) -> Option<Cow<str>> {
392        match self {
393            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
394                Some(Cow::Owned(addr.ip().to_string()))
395            }
396            MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
397            _ => None,
398        }
399    }
400
401    fn port(&self) -> Option<u16> {
402        match self {
403            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
404            MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
405            _ => None,
406        }
407    }
408
409    fn set_port(&mut self, new_port: u16) -> Option<u16> {
410        match self {
411            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
412                let old_port = addr.port();
413                addr.set_port(new_port);
414                Some(old_port)
415            }
416            MaybeResolvedTarget::Unresolved(_, port, _) => {
417                let old_port = *port;
418                *port = new_port;
419                Some(old_port)
420            }
421            _ => None,
422        }
423    }
424}
425
/// The type of connection: the target address plus its TLS mode.
#[derive(Clone, Debug)]
enum TargetInner {
    /// Plain connection without TLS.
    NoTls(MaybeResolvedTarget),
    /// TLS connection using the shared parameters.
    Tls(MaybeResolvedTarget, Arc<TlsParameters>),
    /// STARTTLS connection using the shared parameters.
    StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
}
433
/// The resolved target of a connection attempt.
#[derive(Clone, Debug, derive_more::From)]
pub enum ResolvedTarget {
    /// An IP socket address (returned by [`ResolvedTarget::tcp`]).
    SocketAddr(std::net::SocketAddr),
    /// A Unix socket address: path-based, abstract (Linux/Android), or unnamed.
    #[cfg(unix)]
    UnixSocketAddr(std::os::unix::net::SocketAddr),
}
441
442impl ResolvedTarget {
443    pub fn tcp(&self) -> Option<SocketAddr> {
444        match self {
445            ResolvedTarget::SocketAddr(addr) => Some(*addr),
446            _ => None,
447        }
448    }
449}
450
/// Provides the local address of the implementor as a [`ResolvedTarget`].
pub trait LocalAddress {
    /// Returns the local address, or an I/O error if it cannot be obtained.
    fn local_address(&self) -> std::io::Result<ResolvedTarget>;
}
454
/// Conversion of a TCP host specification into a [`MaybeResolvedTarget`].
///
/// Crate-private; it is used as a bound on the public `new_tcp*` constructors,
/// which is why those are marked `#[allow(private_bounds)]`.
trait TcpResolve {
    /// Convert `self` into a possibly-resolved target. Note that this method
    /// shares its name with `Into::into`.
    fn into(self) -> MaybeResolvedTarget;
}
458
459impl<S: AsRef<str>> TcpResolve for (S, u16) {
460    fn into(self) -> MaybeResolvedTarget {
461        if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
462            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
463        } else {
464            MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
465        }
466    }
467}
468
469impl TcpResolve for SocketAddr {
470    fn into(self) -> MaybeResolvedTarget {
471        MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
472    }
473}
474
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    /// An unresolved host name yields a DNS server name.
    #[test]
    fn test_target() {
        let expected = ServerName::DnsName("localhost".try_into().unwrap());
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(target.name(), Some(expected));
    }

    /// `TargetName`'s `Debug` output is the bare address for every target kind.
    #[test]
    fn test_target_name() {
        let check = |target: TargetName, expected: &str| {
            assert_eq!(format!("{target:?}"), expected);
        };

        check(TargetName::new_tcp(("localhost", 5432)), "localhost:5432");
        check(TargetName::new_tcp(("127.0.0.1", 5432)), "127.0.0.1:5432");
        check(TargetName::new_tcp(("::1", 5432)), "[::1]:5432");

        let scoped = SocketAddrV6::new("fe80::1ff:fe23:4567:890a".parse().unwrap(), 5432, 0, 2);
        check(
            TargetName::new_tcp(SocketAddr::V6(scoped)),
            "[fe80::1ff:fe23:4567:890a%2]:5432",
        );

        check(TargetName::new_unix_path("/tmp/test.sock").unwrap(), "/tmp/test.sock");

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            check(TargetName::new_unix_domain("test").unwrap(), "@test");
        }
    }
}