gel_stream/common/
target.rs

1use std::{
2    borrow::Cow,
3    net::{IpAddr, Ipv4Addr, SocketAddr},
4    path::Path,
5    sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
13/// A target name describes the TCP or Unix socket that a client will connect to.
14pub struct TargetName {
15    inner: MaybeResolvedTarget,
16}
17
18impl std::fmt::Debug for TargetName {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(f, "{:?}", self.inner)
21    }
22}
23
24impl TargetName {
25    /// Create a new target for a Unix socket.
26    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
27        #[cfg(unix)]
28        {
29            let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
30            Ok(Self {
31                inner: MaybeResolvedTarget::Resolved(path),
32            })
33        }
34        #[cfg(not(unix))]
35        {
36            Err(std::io::Error::new(
37                std::io::ErrorKind::Unsupported,
38                "Unix sockets are not supported on this platform",
39            ))
40        }
41    }
42
43    /// Create a new target for a Unix socket.
44    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
45        #[cfg(any(target_os = "linux", target_os = "android"))]
46        {
47            use std::os::linux::net::SocketAddrExt;
48            let domain =
49                ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
50            Ok(Self {
51                inner: MaybeResolvedTarget::Resolved(domain),
52            })
53        }
54        #[cfg(not(any(target_os = "linux", target_os = "android")))]
55        {
56            Err(std::io::Error::new(
57                std::io::ErrorKind::Unsupported,
58                "Unix domain sockets are not supported on this platform",
59            ))
60        }
61    }
62
63    /// Create a new target for a TCP socket.
64    #[allow(private_bounds)]
65    pub fn new_tcp(host: impl TcpResolve) -> Self {
66        Self { inner: host.into() }
67    }
68
69    /// Resolves the target addresses for a given host.
70    pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
71        use std::net::ToSocketAddrs;
72        let mut result = Vec::new();
73        match &self.inner {
74            MaybeResolvedTarget::Resolved(addr) => {
75                return Ok(vec![addr.clone()]);
76            }
77            MaybeResolvedTarget::Unresolved(host, port, _interface) => {
78                let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
79                result.extend(addrs.map(ResolvedTarget::SocketAddr));
80            }
81        }
82        Ok(result)
83    }
84}
85
86#[derive(Clone)]
87pub struct Target {
88    inner: TargetInner,
89}
90
91impl std::fmt::Debug for Target {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        match &self.inner {
94            TargetInner::NoTls(target) => write!(f, "{:?}", target),
95            TargetInner::Tls(target, _) => write!(f, "{:?} (TLS)", target),
96            TargetInner::StartTls(target, _) => write!(f, "{:?} (STARTTLS)", target),
97        }
98    }
99}
100
101#[allow(private_bounds)]
102impl Target {
103    pub fn new(name: TargetName) -> Self {
104        Self {
105            inner: TargetInner::NoTls(name.inner),
106        }
107    }
108
109    pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
110        Self {
111            inner: TargetInner::Tls(name.inner, params.into()),
112        }
113    }
114
115    pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
116        Self {
117            inner: TargetInner::StartTls(name.inner, params.into()),
118        }
119    }
120
121    pub fn new_resolved(target: ResolvedTarget) -> Self {
122        Self {
123            inner: TargetInner::NoTls(target.into()),
124        }
125    }
126
127    pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
128        Self {
129            inner: TargetInner::Tls(target.into(), params.into()),
130        }
131    }
132
133    pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
134        Self {
135            inner: TargetInner::StartTls(target.into(), params.into()),
136        }
137    }
138
139    /// Create a new target for a Unix socket.
140    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
141        #[cfg(unix)]
142        {
143            let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
144            Ok(Self {
145                inner: TargetInner::NoTls(path.into()),
146            })
147        }
148        #[cfg(not(unix))]
149        {
150            Err(std::io::Error::new(
151                std::io::ErrorKind::Unsupported,
152                "Unix sockets are not supported on this platform",
153            ))
154        }
155    }
156
157    /// Create a new target for a Unix socket.
158    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
159        #[cfg(any(target_os = "linux", target_os = "android"))]
160        {
161            use std::os::linux::net::SocketAddrExt;
162            let domain =
163                ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
164            Ok(Self {
165                inner: TargetInner::NoTls(domain.into()),
166            })
167        }
168        #[cfg(not(any(target_os = "linux", target_os = "android")))]
169        {
170            Err(std::io::Error::new(
171                std::io::ErrorKind::Unsupported,
172                "Unix domain sockets are not supported on this platform",
173            ))
174        }
175    }
176
177    /// Create a new target for a TCP socket.
178    pub fn new_tcp(host: impl TcpResolve) -> Self {
179        Self {
180            inner: TargetInner::NoTls(host.into()),
181        }
182    }
183
184    /// Create a new target for a TCP socket with TLS.
185    pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
186        Self {
187            inner: TargetInner::Tls(host.into(), params.into()),
188        }
189    }
190
191    /// Create a new target for a TCP socket with STARTTLS.
192    pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
193        Self {
194            inner: TargetInner::StartTls(host.into(), params.into()),
195        }
196    }
197
198    pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
199        // Don't set TLS parameters on Unix sockets.
200        if self.maybe_resolved().path().is_some() {
201            return None;
202        }
203
204        let params = params.into();
205
206        // Temporary
207        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
208            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
209        ));
210
211        match std::mem::replace(&mut self.inner, no_target) {
212            TargetInner::NoTls(target) => {
213                self.inner = TargetInner::Tls(target, params);
214                Some(None)
215            }
216            TargetInner::Tls(target, old_params) => {
217                self.inner = TargetInner::Tls(target, params);
218                Some(Some(old_params))
219            }
220            TargetInner::StartTls(target, old_params) => {
221                self.inner = TargetInner::StartTls(target, params);
222                Some(Some(old_params))
223            }
224        }
225    }
226
227    pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
228        // Temporary
229        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
230            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
231        ));
232
233        match std::mem::replace(&mut self.inner, no_target) {
234            TargetInner::NoTls(target) => {
235                self.inner = TargetInner::NoTls(target);
236                None
237            }
238            TargetInner::Tls(target, old_params) => {
239                self.inner = TargetInner::NoTls(target);
240                Some(old_params)
241            }
242            TargetInner::StartTls(target, old_params) => {
243                self.inner = TargetInner::NoTls(target);
244                Some(old_params)
245            }
246        }
247    }
248
249    /// Check if the target is a TCP connection.
250    pub fn is_tcp(&self) -> bool {
251        self.maybe_resolved().port().is_some()
252    }
253
254    /// Get the port of the target. If the target type does not include a port,
255    /// this will return None.
256    pub fn port(&self) -> Option<u16> {
257        self.maybe_resolved().port()
258    }
259
260    /// Set the port of the target. If the target type does not include a port,
261    /// this will return None. Otherwise, it will return the old port.
262    pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
263        self.maybe_resolved_mut().set_port(port)
264    }
265
266    /// Get the path of the target. If the target type does not include a path,
267    /// this will return None.
268    pub fn path(&self) -> Option<&Path> {
269        self.maybe_resolved().path()
270    }
271
272    /// Get the host of the target. For resolved IP addresses, this is the
273    /// string representation of the IP address. For unresolved hostnames, this
274    /// is the hostname. If the target type does not include a host, this will
275    /// return None.
276    pub fn host(&self) -> Option<Cow<str>> {
277        self.maybe_resolved().host()
278    }
279
280    /// Get the name of the target. For resolved IP addresses, this is the
281    /// string representation of the IP address. For unresolved hostnames, this
282    /// is the hostname.
283    pub fn name(&self) -> Option<ServerName> {
284        self.maybe_resolved().name()
285    }
286
287    /// Get the host and port of the target. If the target type does not include
288    /// a host or port, this will return None.
289    pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
290        self.maybe_resolved().tcp()
291    }
292
293    pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
294        match &self.inner {
295            TargetInner::NoTls(target) => target,
296            TargetInner::Tls(target, _) => target,
297            TargetInner::StartTls(target, _) => target,
298        }
299    }
300
301    pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
302        match &mut self.inner {
303            TargetInner::NoTls(target) => target,
304            TargetInner::Tls(target, _) => target,
305            TargetInner::StartTls(target, _) => target,
306        }
307    }
308
309    pub(crate) fn is_starttls(&self) -> bool {
310        matches!(self.inner, TargetInner::StartTls(_, _))
311    }
312
313    pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
314        match &self.inner {
315            TargetInner::NoTls(_) => None,
316            TargetInner::Tls(_, params) => Some(params),
317            TargetInner::StartTls(_, params) => Some(params),
318        }
319    }
320}
321
322#[derive(Clone, derive_more::From)]
323pub(crate) enum MaybeResolvedTarget {
324    Resolved(ResolvedTarget),
325    Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
326}
327
328impl std::fmt::Debug for MaybeResolvedTarget {
329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330        match self {
331            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
332                if let SocketAddr::V6(addr) = addr {
333                    if addr.scope_id() != 0 {
334                        write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
335                    } else {
336                        write!(f, "[{}]:{}", addr.ip(), addr.port())
337                    }
338                } else {
339                    write!(f, "{}:{}", addr.ip(), addr.port())
340                }
341            }
342            #[cfg(unix)]
343            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
344                if let Some(path) = addr.as_pathname() {
345                    return write!(f, "{}", path.to_string_lossy());
346                } else {
347                    #[cfg(any(target_os = "linux", target_os = "android"))]
348                    {
349                        use std::os::linux::net::SocketAddrExt;
350                        if let Some(name) = addr.as_abstract_name() {
351                            return write!(f, "@{}", String::from_utf8_lossy(name));
352                        }
353                    }
354                }
355                Ok(())
356            }
357            MaybeResolvedTarget::Unresolved(host, port, interface) => {
358                write!(f, "{}:{}", host, port)?;
359                if let Some(interface) = interface {
360                    write!(f, "%{}", interface)?;
361                }
362                Ok(())
363            }
364        }
365    }
366}
367
368impl MaybeResolvedTarget {
369    fn name(&self) -> Option<ServerName> {
370        match self {
371            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
372                Some(ServerName::IpAddress(addr.ip().into()))
373            }
374            MaybeResolvedTarget::Unresolved(host, _, _) => {
375                Some(ServerName::DnsName(host.to_string().try_into().ok()?))
376            }
377            _ => None,
378        }
379    }
380
381    fn tcp(&self) -> Option<(Cow<str>, u16)> {
382        match self {
383            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
384                Some((Cow::Owned(addr.ip().to_string()), addr.port()))
385            }
386            MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
387            _ => None,
388        }
389    }
390
391    fn path(&self) -> Option<&Path> {
392        match self {
393            #[cfg(unix)]
394            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
395                addr.as_pathname()
396            }
397            _ => None,
398        }
399    }
400
401    fn host(&self) -> Option<Cow<str>> {
402        match self {
403            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
404                Some(Cow::Owned(addr.ip().to_string()))
405            }
406            MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
407            _ => None,
408        }
409    }
410
411    fn port(&self) -> Option<u16> {
412        match self {
413            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
414            MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
415            _ => None,
416        }
417    }
418
419    fn set_port(&mut self, new_port: u16) -> Option<u16> {
420        match self {
421            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
422                let old_port = addr.port();
423                addr.set_port(new_port);
424                Some(old_port)
425            }
426            MaybeResolvedTarget::Unresolved(_, port, _) => {
427                let old_port = *port;
428                *port = new_port;
429                Some(old_port)
430            }
431            _ => None,
432        }
433    }
434}
435
436/// The type of connection.
437#[derive(Clone, Debug)]
438enum TargetInner {
439    NoTls(MaybeResolvedTarget),
440    Tls(MaybeResolvedTarget, Arc<TlsParameters>),
441    StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
442}
443
444#[derive(Clone, Debug, derive_more::From)]
445/// The resolved target of a connection attempt.
446pub enum ResolvedTarget {
447    SocketAddr(std::net::SocketAddr),
448    #[cfg(unix)]
449    UnixSocketAddr(std::os::unix::net::SocketAddr),
450}
451
452impl ResolvedTarget {
453    pub fn tcp(&self) -> Option<SocketAddr> {
454        match self {
455            ResolvedTarget::SocketAddr(addr) => Some(*addr),
456            _ => None,
457        }
458    }
459}
460
461pub trait LocalAddress {
462    fn local_address(&self) -> std::io::Result<ResolvedTarget>;
463}
464
465trait TcpResolve {
466    fn into(self) -> MaybeResolvedTarget;
467}
468
469impl<S: AsRef<str>> TcpResolve for (S, u16) {
470    fn into(self) -> MaybeResolvedTarget {
471        if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
472            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
473        } else {
474            MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
475        }
476    }
477}
478
479impl TcpResolve for SocketAddr {
480    fn into(self) -> MaybeResolvedTarget {
481        MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
482    }
483}
484
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(
            target.name(),
            Some(ServerName::DnsName("localhost".try_into().unwrap()))
        );
    }

    #[test]
    fn test_target_name() {
        let target = TargetName::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = TargetName::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = TargetName::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        let target = TargetName::new_tcp(SocketAddr::V6(SocketAddrV6::new(
            "fe80::1ff:fe23:4567:890a".parse().unwrap(),
            5432,
            0,
            2,
        )));
        assert_eq!(format!("{target:?}"), "[fe80::1ff:fe23:4567:890a%2]:5432");

        // Path-based Unix sockets only exist on Unix; elsewhere
        // `new_unix_path` returns an error and `unwrap` would panic.
        #[cfg(unix)]
        {
            let target = TargetName::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = TargetName::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }

    #[test]
    fn test_target_debug() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = Target::new_tcp_tls(("localhost", 5432), TlsParameters::default());
        assert_eq!(format!("{target:?}"), "localhost:5432 (TLS)");

        let target = Target::new_tcp_starttls(("localhost", 5432), TlsParameters::default());
        assert_eq!(format!("{target:?}"), "localhost:5432 (STARTTLS)");

        let target = Target::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = Target::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        #[cfg(unix)]
        {
            let target = Target::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = Target::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }

    #[test]
    fn test_target_accessors() {
        let mut target = Target::new_tcp(("localhost", 5432));
        assert!(target.is_tcp());
        assert_eq!(target.port(), Some(5432));
        assert_eq!(target.host().as_deref(), Some("localhost"));
        assert_eq!(target.tcp(), Some((Cow::Borrowed("localhost"), 5432)));
        assert_eq!(target.path(), None);
        assert_eq!(target.try_set_port(6543), Some(5432));
        assert_eq!(target.port(), Some(6543));

        let target = Target::new_tcp(("127.0.0.1", 5432));
        assert_eq!(target.host().as_deref(), Some("127.0.0.1"));

        #[cfg(unix)]
        {
            let mut target = Target::new_unix_path("/tmp/test.sock").unwrap();
            assert!(!target.is_tcp());
            assert_eq!(target.port(), None);
            assert_eq!(target.try_set_port(1), None);
            assert_eq!(target.path(), Some(Path::new("/tmp/test.sock")));
        }
    }

    #[test]
    fn test_tls_toggle() {
        let mut target = Target::new_tcp(("localhost", 5432));
        assert!(matches!(target.try_set_tls(TlsParameters::default()), Some(None)));
        assert!(target.maybe_ssl().is_some());
        assert!(!target.is_starttls());
        assert!(matches!(target.try_set_tls(TlsParameters::default()), Some(Some(_))));
        assert!(target.try_remove_tls().is_some());
        assert!(target.maybe_ssl().is_none());
        assert!(target.try_remove_tls().is_none());

        // Replacing parameters on a STARTTLS target keeps it STARTTLS.
        let mut target = Target::new_tcp_starttls(("localhost", 5432), TlsParameters::default());
        assert!(target.is_starttls());
        assert!(matches!(target.try_set_tls(TlsParameters::default()), Some(Some(_))));
        assert!(target.is_starttls());

        #[cfg(unix)]
        {
            let mut target = Target::new_unix_path("/tmp/test.sock").unwrap();
            assert!(target.try_set_tls(TlsParameters::default()).is_none());
            assert!(target.maybe_ssl().is_none());
        }
    }
}