gel_stream/common/
target.rs

1use std::{
2    borrow::Cow,
3    net::{IpAddr, Ipv4Addr, SocketAddr},
4    path::Path,
5    sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
13/// A target name describes the TCP or Unix socket that a client will connect to.
14pub struct TargetName {
15    inner: MaybeResolvedTarget,
16}
17
18impl std::fmt::Debug for TargetName {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(f, "{:?}", self.inner)
21    }
22}
23
24impl TargetName {
25    /// Create a new target for a Unix socket.
26    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
27        #[cfg(unix)]
28        {
29            let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
30            Ok(Self {
31                inner: MaybeResolvedTarget::Resolved(path),
32            })
33        }
34        #[cfg(not(unix))]
35        {
36            Err(std::io::Error::new(
37                std::io::ErrorKind::Unsupported,
38                "Unix sockets are not supported on this platform",
39            ))
40        }
41    }
42
43    /// Create a new target for a Unix socket.
44    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
45        #[cfg(any(target_os = "linux", target_os = "android"))]
46        {
47            use std::os::linux::net::SocketAddrExt;
48            let domain =
49                ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
50            Ok(Self {
51                inner: MaybeResolvedTarget::Resolved(domain),
52            })
53        }
54        #[cfg(not(any(target_os = "linux", target_os = "android")))]
55        {
56            Err(std::io::Error::new(
57                std::io::ErrorKind::Unsupported,
58                "Unix domain sockets are not supported on this platform",
59            ))
60        }
61    }
62
63    /// Create a new target for a TCP socket.
64    #[allow(private_bounds)]
65    pub fn new_tcp(host: impl TcpResolve) -> Self {
66        Self { inner: host.into() }
67    }
68
69    /// Resolves the target addresses for a given host.
70    pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
71        use std::net::ToSocketAddrs;
72        let mut result = Vec::new();
73        match &self.inner {
74            MaybeResolvedTarget::Resolved(addr) => {
75                return Ok(vec![addr.clone()]);
76            }
77            MaybeResolvedTarget::Unresolved(host, port, _interface) => {
78                let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
79                result.extend(addrs.map(ResolvedTarget::SocketAddr));
80            }
81        }
82        Ok(result)
83    }
84
85    pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
86        &self.inner
87    }
88
89    pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
90        &mut self.inner
91    }
92
93    /// Check if the target is a TCP connection.
94    pub fn is_tcp(&self) -> bool {
95        self.maybe_resolved().port().is_some()
96    }
97
98    /// Get the port of the target. If the target type does not include a port,
99    /// this will return None.
100    pub fn port(&self) -> Option<u16> {
101        self.maybe_resolved().port()
102    }
103
104    /// Set the port of the target. If the target type does not include a port,
105    /// this will return None. Otherwise, it will return the old port.
106    pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
107        self.maybe_resolved_mut().set_port(port)
108    }
109
110    /// Get the path of the target. If the target type does not include a path,
111    /// this will return None.
112    pub fn path(&self) -> Option<&Path> {
113        self.maybe_resolved().path()
114    }
115
116    /// Get the host of the target. For resolved IP addresses, this is the
117    /// string representation of the IP address. For unresolved hostnames, this
118    /// is the hostname. If the target type does not include a host, this will
119    /// return None.
120    pub fn host(&self) -> Option<Cow<str>> {
121        self.maybe_resolved().host()
122    }
123
124    /// Get the name of the target. For resolved IP addresses, this is the
125    /// string representation of the IP address. For unresolved hostnames, this
126    /// is the hostname.
127    pub fn name(&self) -> Option<ServerName> {
128        self.maybe_resolved().name()
129    }
130
131    /// Get the host and port of the target. If the target type does not include
132    /// a host or port, this will return None.
133    pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
134        self.maybe_resolved().tcp()
135    }
136}
137
138/// A target describes the TCP or Unix socket that a client will connect to,
139/// along with any optional TLS parameters.
140#[derive(Clone)]
141pub struct Target {
142    inner: TargetInner,
143}
144
145impl std::fmt::Debug for Target {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        match &self.inner {
148            TargetInner::NoTls(target) => write!(f, "{:?}", target),
149            TargetInner::Tls(target, _) => write!(f, "{:?} (TLS)", target),
150            TargetInner::StartTls(target, _) => write!(f, "{:?} (STARTTLS)", target),
151        }
152    }
153}
154
155#[allow(private_bounds)]
156impl Target {
157    pub fn new(name: TargetName) -> Self {
158        Self {
159            inner: TargetInner::NoTls(name.inner),
160        }
161    }
162
163    pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
164        Self {
165            inner: TargetInner::Tls(name.inner, params.into()),
166        }
167    }
168
169    pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
170        Self {
171            inner: TargetInner::StartTls(name.inner, params.into()),
172        }
173    }
174
175    pub fn new_resolved(target: ResolvedTarget) -> Self {
176        Self {
177            inner: TargetInner::NoTls(target.into()),
178        }
179    }
180
181    pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
182        Self {
183            inner: TargetInner::Tls(target.into(), params.into()),
184        }
185    }
186
187    pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
188        Self {
189            inner: TargetInner::StartTls(target.into(), params.into()),
190        }
191    }
192
193    /// Create a new target for a Unix socket.
194    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
195        #[cfg(unix)]
196        {
197            let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
198            Ok(Self {
199                inner: TargetInner::NoTls(path.into()),
200            })
201        }
202        #[cfg(not(unix))]
203        {
204            Err(std::io::Error::new(
205                std::io::ErrorKind::Unsupported,
206                "Unix sockets are not supported on this platform",
207            ))
208        }
209    }
210
211    /// Create a new target for a Unix socket.
212    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
213        #[cfg(any(target_os = "linux", target_os = "android"))]
214        {
215            use std::os::linux::net::SocketAddrExt;
216            let domain =
217                ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
218            Ok(Self {
219                inner: TargetInner::NoTls(domain.into()),
220            })
221        }
222        #[cfg(not(any(target_os = "linux", target_os = "android")))]
223        {
224            Err(std::io::Error::new(
225                std::io::ErrorKind::Unsupported,
226                "Unix domain sockets are not supported on this platform",
227            ))
228        }
229    }
230
231    /// Create a new target for a TCP socket.
232    pub fn new_tcp(host: impl TcpResolve) -> Self {
233        Self {
234            inner: TargetInner::NoTls(host.into()),
235        }
236    }
237
238    /// Create a new target for a TCP socket with TLS.
239    pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
240        Self {
241            inner: TargetInner::Tls(host.into(), params.into()),
242        }
243    }
244
245    /// Create a new target for a TCP socket with STARTTLS.
246    pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
247        Self {
248            inner: TargetInner::StartTls(host.into(), params.into()),
249        }
250    }
251
252    pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
253        // Don't set TLS parameters on Unix sockets.
254        if self.maybe_resolved().path().is_some() {
255            return None;
256        }
257
258        let params = params.into();
259
260        // Temporary
261        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
262            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
263        ));
264
265        match std::mem::replace(&mut self.inner, no_target) {
266            TargetInner::NoTls(target) => {
267                self.inner = TargetInner::Tls(target, params);
268                Some(None)
269            }
270            TargetInner::Tls(target, old_params) => {
271                self.inner = TargetInner::Tls(target, params);
272                Some(Some(old_params))
273            }
274            TargetInner::StartTls(target, old_params) => {
275                self.inner = TargetInner::StartTls(target, params);
276                Some(Some(old_params))
277            }
278        }
279    }
280
281    pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
282        // Temporary
283        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
284            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
285        ));
286
287        match std::mem::replace(&mut self.inner, no_target) {
288            TargetInner::NoTls(target) => {
289                self.inner = TargetInner::NoTls(target);
290                None
291            }
292            TargetInner::Tls(target, old_params) => {
293                self.inner = TargetInner::NoTls(target);
294                Some(old_params)
295            }
296            TargetInner::StartTls(target, old_params) => {
297                self.inner = TargetInner::NoTls(target);
298                Some(old_params)
299            }
300        }
301    }
302
303    /// Check if the target is a TCP connection.
304    pub fn is_tcp(&self) -> bool {
305        self.maybe_resolved().port().is_some()
306    }
307
308    /// Get the port of the target. If the target type does not include a port,
309    /// this will return None.
310    pub fn port(&self) -> Option<u16> {
311        self.maybe_resolved().port()
312    }
313
314    /// Set the port of the target. If the target type does not include a port,
315    /// this will return None. Otherwise, it will return the old port.
316    pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
317        self.maybe_resolved_mut().set_port(port)
318    }
319
320    /// Get the path of the target. If the target type does not include a path,
321    /// this will return None.
322    pub fn path(&self) -> Option<&Path> {
323        self.maybe_resolved().path()
324    }
325
326    /// Get the host of the target. For resolved IP addresses, this is the
327    /// string representation of the IP address. For unresolved hostnames, this
328    /// is the hostname. If the target type does not include a host, this will
329    /// return None.
330    pub fn host(&self) -> Option<Cow<str>> {
331        self.maybe_resolved().host()
332    }
333
334    /// Get the name of the target. For resolved IP addresses, this is the
335    /// string representation of the IP address. For unresolved hostnames, this
336    /// is the hostname.
337    pub fn name(&self) -> Option<ServerName> {
338        self.maybe_resolved().name()
339    }
340
341    /// Get the host and port of the target. If the target type does not include
342    /// a host or port, this will return None.
343    pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
344        self.maybe_resolved().tcp()
345    }
346
347    pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
348        match &self.inner {
349            TargetInner::NoTls(target) => target,
350            TargetInner::Tls(target, _) => target,
351            TargetInner::StartTls(target, _) => target,
352        }
353    }
354
355    pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
356        match &mut self.inner {
357            TargetInner::NoTls(target) => target,
358            TargetInner::Tls(target, _) => target,
359            TargetInner::StartTls(target, _) => target,
360        }
361    }
362
363    pub(crate) fn is_starttls(&self) -> bool {
364        matches!(self.inner, TargetInner::StartTls(_, _))
365    }
366
367    pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
368        match &self.inner {
369            TargetInner::NoTls(_) => None,
370            TargetInner::Tls(_, params) => Some(params),
371            TargetInner::StartTls(_, params) => Some(params),
372        }
373    }
374}
375
376#[derive(Clone, derive_more::From)]
377pub(crate) enum MaybeResolvedTarget {
378    Resolved(ResolvedTarget),
379    Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
380}
381
382impl std::fmt::Debug for MaybeResolvedTarget {
383    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        match self {
385            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
386                if let SocketAddr::V6(addr) = addr {
387                    if addr.scope_id() != 0 {
388                        write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
389                    } else {
390                        write!(f, "[{}]:{}", addr.ip(), addr.port())
391                    }
392                } else {
393                    write!(f, "{}:{}", addr.ip(), addr.port())
394                }
395            }
396            #[cfg(unix)]
397            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
398                if let Some(path) = addr.as_pathname() {
399                    return write!(f, "{}", path.to_string_lossy());
400                } else {
401                    #[cfg(any(target_os = "linux", target_os = "android"))]
402                    {
403                        use std::os::linux::net::SocketAddrExt;
404                        if let Some(name) = addr.as_abstract_name() {
405                            return write!(f, "@{}", String::from_utf8_lossy(name));
406                        }
407                    }
408                }
409                Ok(())
410            }
411            MaybeResolvedTarget::Unresolved(host, port, interface) => {
412                write!(f, "{}:{}", host, port)?;
413                if let Some(interface) = interface {
414                    write!(f, "%{}", interface)?;
415                }
416                Ok(())
417            }
418        }
419    }
420}
421
422impl MaybeResolvedTarget {
423    fn name(&self) -> Option<ServerName> {
424        match self {
425            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
426                Some(ServerName::IpAddress(addr.ip().into()))
427            }
428            MaybeResolvedTarget::Unresolved(host, _, _) => {
429                Some(ServerName::DnsName(host.to_string().try_into().ok()?))
430            }
431            _ => None,
432        }
433    }
434
435    fn tcp(&self) -> Option<(Cow<str>, u16)> {
436        match self {
437            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
438                Some((Cow::Owned(addr.ip().to_string()), addr.port()))
439            }
440            MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
441            _ => None,
442        }
443    }
444
445    fn path(&self) -> Option<&Path> {
446        match self {
447            #[cfg(unix)]
448            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
449                addr.as_pathname()
450            }
451            _ => None,
452        }
453    }
454
455    fn host(&self) -> Option<Cow<str>> {
456        match self {
457            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
458                Some(Cow::Owned(addr.ip().to_string()))
459            }
460            MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
461            _ => None,
462        }
463    }
464
465    fn port(&self) -> Option<u16> {
466        match self {
467            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
468            MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
469            _ => None,
470        }
471    }
472
473    fn set_port(&mut self, new_port: u16) -> Option<u16> {
474        match self {
475            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
476                let old_port = addr.port();
477                addr.set_port(new_port);
478                Some(old_port)
479            }
480            MaybeResolvedTarget::Unresolved(_, port, _) => {
481                let old_port = *port;
482                *port = new_port;
483                Some(old_port)
484            }
485            _ => None,
486        }
487    }
488}
489
490/// The type of connection.
491#[derive(Clone, Debug)]
492enum TargetInner {
493    NoTls(MaybeResolvedTarget),
494    Tls(MaybeResolvedTarget, Arc<TlsParameters>),
495    StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
496}
497
498#[derive(Clone, Debug, derive_more::From)]
499/// The resolved target of a connection attempt.
500pub enum ResolvedTarget {
501    SocketAddr(std::net::SocketAddr),
502    #[cfg(unix)]
503    UnixSocketAddr(std::os::unix::net::SocketAddr),
504}
505
506impl ResolvedTarget {
507    pub fn tcp(&self) -> Option<SocketAddr> {
508        match self {
509            ResolvedTarget::SocketAddr(addr) => Some(*addr),
510            _ => None,
511        }
512    }
513}
514
515/// A trait for types that have a local address.
516pub trait LocalAddress {
517    fn local_address(&self) -> std::io::Result<ResolvedTarget>;
518}
519
520trait TcpResolve {
521    fn into(self) -> MaybeResolvedTarget;
522}
523
524impl<S: AsRef<str>> TcpResolve for (S, u16) {
525    fn into(self) -> MaybeResolvedTarget {
526        if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
527            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
528        } else {
529            MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
530        }
531    }
532}
533
534impl TcpResolve for SocketAddr {
535    fn into(self) -> MaybeResolvedTarget {
536        MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
537    }
538}
539
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    /// The TLS server name of an unresolved TCP target is its DNS hostname.
    #[test]
    fn test_target() {
        let expected = ServerName::DnsName("localhost".try_into().unwrap());
        assert_eq!(Target::new_tcp(("localhost", 5432)).name(), Some(expected));
    }

    /// Debug output of `TargetName` for every endpoint kind.
    #[test]
    fn test_target_name() {
        let scoped_v6 = SocketAddr::V6(SocketAddrV6::new(
            "fe80::1ff:fe23:4567:890a".parse().unwrap(),
            5432,
            0,
            2,
        ));
        let cases = [
            (TargetName::new_tcp(("localhost", 5432)), "localhost:5432"),
            (TargetName::new_tcp(("127.0.0.1", 5432)), "127.0.0.1:5432"),
            (TargetName::new_tcp(("::1", 5432)), "[::1]:5432"),
            (
                TargetName::new_tcp(scoped_v6),
                "[fe80::1ff:fe23:4567:890a%2]:5432",
            ),
        ];
        for (target, expected) in cases {
            assert_eq!(format!("{target:?}"), expected);
        }

        #[cfg(unix)]
        {
            let target = TargetName::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = TargetName::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }

    /// Debug output of `Target`, including the TLS-mode suffixes.
    #[test]
    fn test_target_debug() {
        let cases = [
            (Target::new_tcp(("localhost", 5432)), "localhost:5432"),
            (
                Target::new_tcp_tls(("localhost", 5432), TlsParameters::default()),
                "localhost:5432 (TLS)",
            ),
            (
                Target::new_tcp_starttls(("localhost", 5432), TlsParameters::default()),
                "localhost:5432 (STARTTLS)",
            ),
            (Target::new_tcp(("127.0.0.1", 5432)), "127.0.0.1:5432"),
            (Target::new_tcp(("::1", 5432)), "[::1]:5432"),
        ];
        for (target, expected) in cases {
            assert_eq!(format!("{target:?}"), expected);
        }

        #[cfg(unix)]
        {
            let target = Target::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = Target::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }
}