gel_stream/common/
target.rs

1use std::{
2    borrow::Cow,
3    net::{IpAddr, Ipv4Addr, SocketAddr},
4    path::Path,
5    sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
/// A target name describes the TCP or Unix socket that a client will connect to.
///
/// Unlike [`Target`], a `TargetName` carries no TLS configuration; it only
/// identifies the socket.
pub struct TargetName {
    /// Either an already-resolved socket address or an unresolved host/port.
    inner: MaybeResolvedTarget,
}
17
18impl std::fmt::Debug for TargetName {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(f, "{:?}", self.inner)
21    }
22}
23
24impl TargetName {
25    /// Create a new target for a Unix socket.
26    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
27        #[cfg(unix)]
28        {
29            let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
30            Ok(Self {
31                inner: MaybeResolvedTarget::Resolved(path),
32            })
33        }
34        #[cfg(not(unix))]
35        {
36            Err(std::io::Error::new(
37                std::io::ErrorKind::Unsupported,
38                "Unix sockets are not supported on this platform",
39            ))
40        }
41    }
42
43    /// Create a new target for a Unix socket.
44    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
45        #[cfg(any(target_os = "linux", target_os = "android"))]
46        {
47            use std::os::linux::net::SocketAddrExt;
48            let domain =
49                ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
50            Ok(Self {
51                inner: MaybeResolvedTarget::Resolved(domain),
52            })
53        }
54        #[cfg(not(any(target_os = "linux", target_os = "android")))]
55        {
56            Err(std::io::Error::new(
57                std::io::ErrorKind::Unsupported,
58                "Unix domain sockets are not supported on this platform",
59            ))
60        }
61    }
62
63    /// Create a new target for a TCP socket.
64    #[allow(private_bounds)]
65    pub fn new_tcp(host: impl TcpResolve) -> Self {
66        Self { inner: host.into() }
67    }
68
69    /// Resolves the target addresses for a given host.
70    pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
71        use std::net::ToSocketAddrs;
72        let mut result = Vec::new();
73        match &self.inner {
74            MaybeResolvedTarget::Resolved(addr) => {
75                return Ok(vec![addr.clone()]);
76            }
77            MaybeResolvedTarget::Unresolved(host, port, _interface) => {
78                let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
79                result.extend(addrs.map(ResolvedTarget::SocketAddr));
80            }
81        }
82        Ok(result)
83    }
84
85    pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
86        &self.inner
87    }
88
89    pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
90        &mut self.inner
91    }
92
93    /// Check if the target is a TCP connection.
94    pub fn is_tcp(&self) -> bool {
95        self.maybe_resolved().port().is_some()
96    }
97
98    /// Get the port of the target. If the target type does not include a port,
99    /// this will return None.
100    pub fn port(&self) -> Option<u16> {
101        self.maybe_resolved().port()
102    }
103
104    /// Set the port of the target. If the target type does not include a port,
105    /// this will return None. Otherwise, it will return the old port.
106    pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
107        self.maybe_resolved_mut().set_port(port)
108    }
109
110    /// Get the path of the target. If the target type does not include a path,
111    /// this will return None.
112    pub fn path(&self) -> Option<&Path> {
113        self.maybe_resolved().path()
114    }
115
116    /// Get the host of the target. For resolved IP addresses, this is the
117    /// string representation of the IP address. For unresolved hostnames, this
118    /// is the hostname. If the target type does not include a host, this will
119    /// return None.
120    pub fn host(&self) -> Option<Cow<str>> {
121        self.maybe_resolved().host()
122    }
123
124    /// Get the name of the target. For resolved IP addresses, this is the
125    /// string representation of the IP address. For unresolved hostnames, this
126    /// is the hostname.
127    pub fn name(&self) -> Option<ServerName> {
128        self.maybe_resolved().name()
129    }
130
131    /// Get the host and port of the target. If the target type does not include
132    /// a host or port, this will return None.
133    pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
134        self.maybe_resolved().tcp()
135    }
136}
137
/// A connection target: the socket to connect to together with its TLS mode
/// (no TLS, TLS, or STARTTLS) and, when TLS is used, its parameters.
#[derive(Clone)]
pub struct Target {
    inner: TargetInner,
}
142
143impl std::fmt::Debug for Target {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        match &self.inner {
146            TargetInner::NoTls(target) => write!(f, "{:?}", target),
147            TargetInner::Tls(target, _) => write!(f, "{:?} (TLS)", target),
148            TargetInner::StartTls(target, _) => write!(f, "{:?} (STARTTLS)", target),
149        }
150    }
151}
152
153#[allow(private_bounds)]
154impl Target {
155    pub fn new(name: TargetName) -> Self {
156        Self {
157            inner: TargetInner::NoTls(name.inner),
158        }
159    }
160
161    pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
162        Self {
163            inner: TargetInner::Tls(name.inner, params.into()),
164        }
165    }
166
167    pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
168        Self {
169            inner: TargetInner::StartTls(name.inner, params.into()),
170        }
171    }
172
173    pub fn new_resolved(target: ResolvedTarget) -> Self {
174        Self {
175            inner: TargetInner::NoTls(target.into()),
176        }
177    }
178
179    pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
180        Self {
181            inner: TargetInner::Tls(target.into(), params.into()),
182        }
183    }
184
185    pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
186        Self {
187            inner: TargetInner::StartTls(target.into(), params.into()),
188        }
189    }
190
191    /// Create a new target for a Unix socket.
192    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
193        #[cfg(unix)]
194        {
195            let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
196            Ok(Self {
197                inner: TargetInner::NoTls(path.into()),
198            })
199        }
200        #[cfg(not(unix))]
201        {
202            Err(std::io::Error::new(
203                std::io::ErrorKind::Unsupported,
204                "Unix sockets are not supported on this platform",
205            ))
206        }
207    }
208
209    /// Create a new target for a Unix socket.
210    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
211        #[cfg(any(target_os = "linux", target_os = "android"))]
212        {
213            use std::os::linux::net::SocketAddrExt;
214            let domain =
215                ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
216            Ok(Self {
217                inner: TargetInner::NoTls(domain.into()),
218            })
219        }
220        #[cfg(not(any(target_os = "linux", target_os = "android")))]
221        {
222            Err(std::io::Error::new(
223                std::io::ErrorKind::Unsupported,
224                "Unix domain sockets are not supported on this platform",
225            ))
226        }
227    }
228
229    /// Create a new target for a TCP socket.
230    pub fn new_tcp(host: impl TcpResolve) -> Self {
231        Self {
232            inner: TargetInner::NoTls(host.into()),
233        }
234    }
235
236    /// Create a new target for a TCP socket with TLS.
237    pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
238        Self {
239            inner: TargetInner::Tls(host.into(), params.into()),
240        }
241    }
242
243    /// Create a new target for a TCP socket with STARTTLS.
244    pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
245        Self {
246            inner: TargetInner::StartTls(host.into(), params.into()),
247        }
248    }
249
250    pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
251        // Don't set TLS parameters on Unix sockets.
252        if self.maybe_resolved().path().is_some() {
253            return None;
254        }
255
256        let params = params.into();
257
258        // Temporary
259        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
260            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
261        ));
262
263        match std::mem::replace(&mut self.inner, no_target) {
264            TargetInner::NoTls(target) => {
265                self.inner = TargetInner::Tls(target, params);
266                Some(None)
267            }
268            TargetInner::Tls(target, old_params) => {
269                self.inner = TargetInner::Tls(target, params);
270                Some(Some(old_params))
271            }
272            TargetInner::StartTls(target, old_params) => {
273                self.inner = TargetInner::StartTls(target, params);
274                Some(Some(old_params))
275            }
276        }
277    }
278
279    pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
280        // Temporary
281        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
282            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
283        ));
284
285        match std::mem::replace(&mut self.inner, no_target) {
286            TargetInner::NoTls(target) => {
287                self.inner = TargetInner::NoTls(target);
288                None
289            }
290            TargetInner::Tls(target, old_params) => {
291                self.inner = TargetInner::NoTls(target);
292                Some(old_params)
293            }
294            TargetInner::StartTls(target, old_params) => {
295                self.inner = TargetInner::NoTls(target);
296                Some(old_params)
297            }
298        }
299    }
300
301    /// Check if the target is a TCP connection.
302    pub fn is_tcp(&self) -> bool {
303        self.maybe_resolved().port().is_some()
304    }
305
306    /// Get the port of the target. If the target type does not include a port,
307    /// this will return None.
308    pub fn port(&self) -> Option<u16> {
309        self.maybe_resolved().port()
310    }
311
312    /// Set the port of the target. If the target type does not include a port,
313    /// this will return None. Otherwise, it will return the old port.
314    pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
315        self.maybe_resolved_mut().set_port(port)
316    }
317
318    /// Get the path of the target. If the target type does not include a path,
319    /// this will return None.
320    pub fn path(&self) -> Option<&Path> {
321        self.maybe_resolved().path()
322    }
323
324    /// Get the host of the target. For resolved IP addresses, this is the
325    /// string representation of the IP address. For unresolved hostnames, this
326    /// is the hostname. If the target type does not include a host, this will
327    /// return None.
328    pub fn host(&self) -> Option<Cow<str>> {
329        self.maybe_resolved().host()
330    }
331
332    /// Get the name of the target. For resolved IP addresses, this is the
333    /// string representation of the IP address. For unresolved hostnames, this
334    /// is the hostname.
335    pub fn name(&self) -> Option<ServerName> {
336        self.maybe_resolved().name()
337    }
338
339    /// Get the host and port of the target. If the target type does not include
340    /// a host or port, this will return None.
341    pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
342        self.maybe_resolved().tcp()
343    }
344
345    pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
346        match &self.inner {
347            TargetInner::NoTls(target) => target,
348            TargetInner::Tls(target, _) => target,
349            TargetInner::StartTls(target, _) => target,
350        }
351    }
352
353    pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
354        match &mut self.inner {
355            TargetInner::NoTls(target) => target,
356            TargetInner::Tls(target, _) => target,
357            TargetInner::StartTls(target, _) => target,
358        }
359    }
360
361    pub(crate) fn is_starttls(&self) -> bool {
362        matches!(self.inner, TargetInner::StartTls(_, _))
363    }
364
365    pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
366        match &self.inner {
367            TargetInner::NoTls(_) => None,
368            TargetInner::Tls(_, params) => Some(params),
369            TargetInner::StartTls(_, params) => Some(params),
370        }
371    }
372}
373
/// A socket address that may or may not have been resolved yet.
#[derive(Clone, derive_more::From)]
pub(crate) enum MaybeResolvedTarget {
    /// A concrete TCP or Unix socket address.
    Resolved(ResolvedTarget),
    /// An unresolved `(host, port, interface)` triple. The optional interface
    /// is rendered by `Debug` as a `%interface` suffix after the port.
    Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
}
379
380impl std::fmt::Debug for MaybeResolvedTarget {
381    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
382        match self {
383            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
384                if let SocketAddr::V6(addr) = addr {
385                    if addr.scope_id() != 0 {
386                        write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
387                    } else {
388                        write!(f, "[{}]:{}", addr.ip(), addr.port())
389                    }
390                } else {
391                    write!(f, "{}:{}", addr.ip(), addr.port())
392                }
393            }
394            #[cfg(unix)]
395            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
396                if let Some(path) = addr.as_pathname() {
397                    return write!(f, "{}", path.to_string_lossy());
398                } else {
399                    #[cfg(any(target_os = "linux", target_os = "android"))]
400                    {
401                        use std::os::linux::net::SocketAddrExt;
402                        if let Some(name) = addr.as_abstract_name() {
403                            return write!(f, "@{}", String::from_utf8_lossy(name));
404                        }
405                    }
406                }
407                Ok(())
408            }
409            MaybeResolvedTarget::Unresolved(host, port, interface) => {
410                write!(f, "{}:{}", host, port)?;
411                if let Some(interface) = interface {
412                    write!(f, "%{}", interface)?;
413                }
414                Ok(())
415            }
416        }
417    }
418}
419
420impl MaybeResolvedTarget {
421    fn name(&self) -> Option<ServerName> {
422        match self {
423            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
424                Some(ServerName::IpAddress(addr.ip().into()))
425            }
426            MaybeResolvedTarget::Unresolved(host, _, _) => {
427                Some(ServerName::DnsName(host.to_string().try_into().ok()?))
428            }
429            _ => None,
430        }
431    }
432
433    fn tcp(&self) -> Option<(Cow<str>, u16)> {
434        match self {
435            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
436                Some((Cow::Owned(addr.ip().to_string()), addr.port()))
437            }
438            MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
439            _ => None,
440        }
441    }
442
443    fn path(&self) -> Option<&Path> {
444        match self {
445            #[cfg(unix)]
446            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
447                addr.as_pathname()
448            }
449            _ => None,
450        }
451    }
452
453    fn host(&self) -> Option<Cow<str>> {
454        match self {
455            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
456                Some(Cow::Owned(addr.ip().to_string()))
457            }
458            MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
459            _ => None,
460        }
461    }
462
463    fn port(&self) -> Option<u16> {
464        match self {
465            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
466            MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
467            _ => None,
468        }
469    }
470
471    fn set_port(&mut self, new_port: u16) -> Option<u16> {
472        match self {
473            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
474                let old_port = addr.port();
475                addr.set_port(new_port);
476                Some(old_port)
477            }
478            MaybeResolvedTarget::Unresolved(_, port, _) => {
479                let old_port = *port;
480                *port = new_port;
481                Some(old_port)
482            }
483            _ => None,
484        }
485    }
486}
487
/// The type of connection: the socket to connect to plus its TLS mode.
#[derive(Clone, Debug)]
enum TargetInner {
    /// Connection without TLS.
    NoTls(MaybeResolvedTarget),
    /// Connection using TLS with the shared parameters.
    Tls(MaybeResolvedTarget, Arc<TlsParameters>),
    /// Connection using STARTTLS with the shared parameters.
    StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
}
495
#[derive(Clone, Debug, derive_more::From)]
/// The resolved target of a connection attempt.
pub enum ResolvedTarget {
    /// An IP socket address (IPv4 or IPv6), used for TCP connections.
    SocketAddr(std::net::SocketAddr),
    /// A Unix socket address (path-bound, abstract, or unnamed). Unix only.
    #[cfg(unix)]
    UnixSocketAddr(std::os::unix::net::SocketAddr),
}
503
504impl ResolvedTarget {
505    pub fn tcp(&self) -> Option<SocketAddr> {
506        match self {
507            ResolvedTarget::SocketAddr(addr) => Some(*addr),
508            _ => None,
509        }
510    }
511}
512
/// Implemented by types that can report a local address as a [`ResolvedTarget`].
// NOTE(review): presumably the locally-bound end of a connection or listener;
// confirm against the implementors.
pub trait LocalAddress {
    /// Return the local address, or an I/O error if it cannot be determined.
    fn local_address(&self) -> std::io::Result<ResolvedTarget>;
}
516
/// Conversion of a TCP endpoint description (a `(host, port)` pair or a
/// `SocketAddr`) into a [`MaybeResolvedTarget`].
///
/// This trait is private to the crate but appears as a bound on the public
/// `new_tcp*` constructors, which is why those carry `#[allow(private_bounds)]`.
trait TcpResolve {
    /// Convert `self` into a (possibly unresolved) target.
    fn into(self) -> MaybeResolvedTarget;
}
520
521impl<S: AsRef<str>> TcpResolve for (S, u16) {
522    fn into(self) -> MaybeResolvedTarget {
523        if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
524            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
525        } else {
526            MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
527        }
528    }
529}
530
531impl TcpResolve for SocketAddr {
532    fn into(self) -> MaybeResolvedTarget {
533        MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
534    }
535}
536
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    /// Asserts that `value` renders as exactly `expected` under `{:?}`.
    fn assert_debug(value: &impl std::fmt::Debug, expected: &str) {
        assert_eq!(format!("{value:?}"), expected);
    }

    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432));
        let localhost = ServerName::DnsName("localhost".try_into().unwrap());
        assert_eq!(target.name(), Some(localhost));
    }

    #[test]
    fn test_target_name() {
        assert_debug(&TargetName::new_tcp(("localhost", 5432)), "localhost:5432");
        assert_debug(&TargetName::new_tcp(("127.0.0.1", 5432)), "127.0.0.1:5432");
        assert_debug(&TargetName::new_tcp(("::1", 5432)), "[::1]:5432");

        let link_local =
            SocketAddrV6::new("fe80::1ff:fe23:4567:890a".parse().unwrap(), 5432, 0, 2);
        assert_debug(
            &TargetName::new_tcp(SocketAddr::V6(link_local)),
            "[fe80::1ff:fe23:4567:890a%2]:5432",
        );

        #[cfg(unix)]
        {
            assert_debug(
                &TargetName::new_unix_path("/tmp/test.sock").unwrap(),
                "/tmp/test.sock",
            );
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            assert_debug(&TargetName::new_unix_domain("test").unwrap(), "@test");
        }
    }

    #[test]
    fn test_target_debug() {
        assert_debug(&Target::new_tcp(("localhost", 5432)), "localhost:5432");
        assert_debug(
            &Target::new_tcp_tls(("localhost", 5432), TlsParameters::default()),
            "localhost:5432 (TLS)",
        );
        assert_debug(
            &Target::new_tcp_starttls(("localhost", 5432), TlsParameters::default()),
            "localhost:5432 (STARTTLS)",
        );
        assert_debug(&Target::new_tcp(("127.0.0.1", 5432)), "127.0.0.1:5432");
        assert_debug(&Target::new_tcp(("::1", 5432)), "[::1]:5432");

        #[cfg(unix)]
        {
            assert_debug(&Target::new_unix_path("/tmp/test.sock").unwrap(), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            assert_debug(&Target::new_unix_domain("test").unwrap(), "@test");
        }
    }
}