gel_stream/common/
target.rs

1use std::{
2    borrow::Cow,
3    net::{IpAddr, Ipv4Addr, SocketAddr},
4    path::Path,
5    sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
13/// A target name describes the TCP or Unix socket that a client will connect to.
14pub struct TargetName {
15    inner: MaybeResolvedTarget,
16}
17
18impl std::fmt::Debug for TargetName {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(f, "{:?}", self.inner)
21    }
22}
23
24impl TargetName {
25    /// Create a new target for a Unix socket.
26    #[cfg(unix)]
27    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
28        let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
29        Ok(Self {
30            inner: MaybeResolvedTarget::Resolved(path),
31        })
32    }
33
34    /// Create a new target for a Unix socket.
35    #[cfg(any(target_os = "linux", target_os = "android"))]
36    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
37        use std::os::linux::net::SocketAddrExt;
38        let domain =
39            ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
40        Ok(Self {
41            inner: MaybeResolvedTarget::Resolved(domain),
42        })
43    }
44
45    /// Create a new target for a TCP socket.
46    #[allow(private_bounds)]
47    pub fn new_tcp(host: impl TcpResolve) -> Self {
48        Self { inner: host.into() }
49    }
50
51    /// Resolves the target addresses for a given host.
52    pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
53        use std::net::ToSocketAddrs;
54        let mut result = Vec::new();
55        match &self.inner {
56            MaybeResolvedTarget::Resolved(addr) => {
57                return Ok(vec![addr.clone()]);
58            }
59            MaybeResolvedTarget::Unresolved(host, port, _interface) => {
60                let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
61                result.extend(addrs.map(ResolvedTarget::SocketAddr));
62            }
63        }
64        Ok(result)
65    }
66}
67
68#[derive(Clone, Debug)]
69pub struct Target {
70    inner: TargetInner,
71}
72
73#[allow(private_bounds)]
74impl Target {
75    pub fn new(name: TargetName) -> Self {
76        Self {
77            inner: TargetInner::NoTls(name.inner),
78        }
79    }
80
81    pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
82        Self {
83            inner: TargetInner::Tls(name.inner, params.into()),
84        }
85    }
86
87    pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
88        Self {
89            inner: TargetInner::StartTls(name.inner, params.into()),
90        }
91    }
92
93    pub fn new_resolved(target: ResolvedTarget) -> Self {
94        Self {
95            inner: TargetInner::NoTls(target.into()),
96        }
97    }
98
99    pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
100        Self {
101            inner: TargetInner::Tls(target.into(), params.into()),
102        }
103    }
104
105    pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
106        Self {
107            inner: TargetInner::StartTls(target.into(), params.into()),
108        }
109    }
110
111    /// Create a new target for a Unix socket.
112    #[cfg(unix)]
113    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
114        let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
115        Ok(Self {
116            inner: TargetInner::NoTls(path.into()),
117        })
118    }
119
120    /// Create a new target for a Unix socket.
121    #[cfg(any(target_os = "linux", target_os = "android"))]
122    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
123        use std::os::linux::net::SocketAddrExt;
124        let domain =
125            ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
126        Ok(Self {
127            inner: TargetInner::NoTls(domain.into()),
128        })
129    }
130
131    /// Create a new target for a TCP socket.
132    pub fn new_tcp(host: impl TcpResolve) -> Self {
133        Self {
134            inner: TargetInner::NoTls(host.into()),
135        }
136    }
137
138    /// Create a new target for a TCP socket with TLS.
139    pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
140        Self {
141            inner: TargetInner::Tls(host.into(), params.into()),
142        }
143    }
144
145    /// Create a new target for a TCP socket with STARTTLS.
146    pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
147        Self {
148            inner: TargetInner::StartTls(host.into(), params.into()),
149        }
150    }
151
152    pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
153        // Don't set TLS parameters on Unix sockets.
154        if self.maybe_resolved().path().is_some() {
155            return None;
156        }
157
158        let params = params.into();
159
160        // Temporary
161        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
162            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
163        ));
164
165        match std::mem::replace(&mut self.inner, no_target) {
166            TargetInner::NoTls(target) => {
167                self.inner = TargetInner::Tls(target, params);
168                Some(None)
169            }
170            TargetInner::Tls(target, old_params) => {
171                self.inner = TargetInner::Tls(target, params);
172                Some(Some(old_params))
173            }
174            TargetInner::StartTls(target, old_params) => {
175                self.inner = TargetInner::StartTls(target, params);
176                Some(Some(old_params))
177            }
178        }
179    }
180
181    pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
182        // Temporary
183        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
184            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
185        ));
186
187        match std::mem::replace(&mut self.inner, no_target) {
188            TargetInner::NoTls(target) => {
189                self.inner = TargetInner::NoTls(target);
190                None
191            }
192            TargetInner::Tls(target, old_params) => {
193                self.inner = TargetInner::NoTls(target);
194                Some(old_params)
195            }
196            TargetInner::StartTls(target, old_params) => {
197                self.inner = TargetInner::NoTls(target);
198                Some(old_params)
199            }
200        }
201    }
202
203    /// Check if the target is a TCP connection.
204    pub fn is_tcp(&self) -> bool {
205        self.maybe_resolved().port().is_some()
206    }
207
208    /// Get the port of the target. If the target type does not include a port,
209    /// this will return None.
210    pub fn port(&self) -> Option<u16> {
211        self.maybe_resolved().port()
212    }
213
214    /// Set the port of the target. If the target type does not include a port,
215    /// this will return None. Otherwise, it will return the old port.
216    pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
217        self.maybe_resolved_mut().set_port(port)
218    }
219
220    /// Get the path of the target. If the target type does not include a path,
221    /// this will return None.
222    pub fn path(&self) -> Option<&Path> {
223        self.maybe_resolved().path()
224    }
225
226    /// Get the host of the target. For resolved IP addresses, this is the
227    /// string representation of the IP address. For unresolved hostnames, this
228    /// is the hostname. If the target type does not include a host, this will
229    /// return None.
230    pub fn host(&self) -> Option<Cow<str>> {
231        self.maybe_resolved().host()
232    }
233
234    /// Get the name of the target. For resolved IP addresses, this is the
235    /// string representation of the IP address. For unresolved hostnames, this
236    /// is the hostname.
237    pub fn name(&self) -> Option<ServerName> {
238        self.maybe_resolved().name()
239    }
240
241    /// Get the host and port of the target. If the target type does not include
242    /// a host or port, this will return None.
243    pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
244        self.maybe_resolved().tcp()
245    }
246
247    pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
248        match &self.inner {
249            TargetInner::NoTls(target) => target,
250            TargetInner::Tls(target, _) => target,
251            TargetInner::StartTls(target, _) => target,
252        }
253    }
254
255    pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
256        match &mut self.inner {
257            TargetInner::NoTls(target) => target,
258            TargetInner::Tls(target, _) => target,
259            TargetInner::StartTls(target, _) => target,
260        }
261    }
262
263    pub(crate) fn is_starttls(&self) -> bool {
264        matches!(self.inner, TargetInner::StartTls(_, _))
265    }
266
267    pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
268        match &self.inner {
269            TargetInner::NoTls(_) => None,
270            TargetInner::Tls(_, params) => Some(params),
271            TargetInner::StartTls(_, params) => Some(params),
272        }
273    }
274}
275
276#[derive(Clone, derive_more::From)]
277pub(crate) enum MaybeResolvedTarget {
278    Resolved(ResolvedTarget),
279    Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
280}
281
282impl std::fmt::Debug for MaybeResolvedTarget {
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        match self {
285            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
286                if let SocketAddr::V6(addr) = addr {
287                    if addr.scope_id() != 0 {
288                        write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
289                    } else {
290                        write!(f, "[{}]:{}", addr.ip(), addr.port())
291                    }
292                } else {
293                    write!(f, "{}:{}", addr.ip(), addr.port())
294                }
295            }
296            #[cfg(unix)]
297            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
298                if let Some(path) = addr.as_pathname() {
299                    return write!(f, "{}", path.to_string_lossy());
300                } else {
301                    #[cfg(any(target_os = "linux", target_os = "android"))]
302                    {
303                        use std::os::linux::net::SocketAddrExt;
304                        if let Some(name) = addr.as_abstract_name() {
305                            return write!(f, "@{}", String::from_utf8_lossy(name));
306                        }
307                    }
308                }
309                Ok(())
310            }
311            MaybeResolvedTarget::Unresolved(host, port, interface) => {
312                write!(f, "{}:{}", host, port)?;
313                if let Some(interface) = interface {
314                    write!(f, "%{}", interface)?;
315                }
316                Ok(())
317            }
318        }
319    }
320}
321
322impl MaybeResolvedTarget {
323    fn name(&self) -> Option<ServerName> {
324        match self {
325            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
326                Some(ServerName::IpAddress(addr.ip().into()))
327            }
328            MaybeResolvedTarget::Unresolved(host, _, _) => {
329                Some(ServerName::DnsName(host.to_string().try_into().ok()?))
330            }
331            _ => None,
332        }
333    }
334
335    fn tcp(&self) -> Option<(Cow<str>, u16)> {
336        match self {
337            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
338                Some((Cow::Owned(addr.ip().to_string()), addr.port()))
339            }
340            MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
341            _ => None,
342        }
343    }
344
345    fn path(&self) -> Option<&Path> {
346        match self {
347            #[cfg(unix)]
348            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
349                addr.as_pathname()
350            }
351            _ => None,
352        }
353    }
354
355    fn host(&self) -> Option<Cow<str>> {
356        match self {
357            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
358                Some(Cow::Owned(addr.ip().to_string()))
359            }
360            MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
361            _ => None,
362        }
363    }
364
365    fn port(&self) -> Option<u16> {
366        match self {
367            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
368            MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
369            _ => None,
370        }
371    }
372
373    fn set_port(&mut self, new_port: u16) -> Option<u16> {
374        match self {
375            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
376                let old_port = addr.port();
377                addr.set_port(new_port);
378                Some(old_port)
379            }
380            MaybeResolvedTarget::Unresolved(_, port, _) => {
381                let old_port = *port;
382                *port = new_port;
383                Some(old_port)
384            }
385            _ => None,
386        }
387    }
388}
389
390/// The type of connection.
391#[derive(Clone, Debug)]
392enum TargetInner {
393    NoTls(MaybeResolvedTarget),
394    Tls(MaybeResolvedTarget, Arc<TlsParameters>),
395    StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
396}
397
398#[derive(Clone, Debug, derive_more::From)]
399/// The resolved target of a connection attempt.
400pub enum ResolvedTarget {
401    SocketAddr(std::net::SocketAddr),
402    #[cfg(unix)]
403    UnixSocketAddr(std::os::unix::net::SocketAddr),
404}
405
406impl ResolvedTarget {
407    pub fn tcp(&self) -> Option<SocketAddr> {
408        match self {
409            ResolvedTarget::SocketAddr(addr) => Some(*addr),
410            _ => None,
411        }
412    }
413}
414
415pub trait LocalAddress {
416    fn local_address(&self) -> std::io::Result<ResolvedTarget>;
417}
418
419trait TcpResolve {
420    fn into(self) -> MaybeResolvedTarget;
421}
422
423impl<S: AsRef<str>> TcpResolve for (S, u16) {
424    fn into(self) -> MaybeResolvedTarget {
425        if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
426            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
427        } else {
428            MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
429        }
430    }
431}
432
433impl TcpResolve for SocketAddr {
434    fn into(self) -> MaybeResolvedTarget {
435        MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
436    }
437}
438
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(
            target.name(),
            Some(ServerName::DnsName("localhost".try_into().unwrap()))
        );
    }

    #[test]
    fn test_target_tcp_accessors() {
        let mut target = Target::new_tcp(("localhost", 5432));
        assert!(target.is_tcp());
        assert_eq!(target.port(), Some(5432));
        assert_eq!(target.host().as_deref(), Some("localhost"));
        assert_eq!(target.path(), None);
        assert_eq!(target.try_set_port(1234), Some(5432));
        assert_eq!(target.port(), Some(1234));
        assert_eq!(target.tcp(), Some((Cow::Borrowed("localhost"), 1234)));

        let target = Target::new_tcp(("127.0.0.1", 5432));
        assert!(target.is_tcp());
        assert_eq!(target.host().as_deref(), Some("127.0.0.1"));
    }

    #[cfg(unix)]
    #[test]
    fn test_target_unix_accessors() {
        let mut target = Target::new_unix_path("/tmp/test.sock").unwrap();
        assert!(!target.is_tcp());
        assert_eq!(target.port(), None);
        assert_eq!(target.host(), None);
        assert_eq!(target.tcp(), None);
        assert_eq!(target.path(), Some(Path::new("/tmp/test.sock")));
        assert_eq!(target.try_set_port(1234), None);
    }

    #[test]
    fn test_target_name() {
        let target = TargetName::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = TargetName::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = TargetName::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        let target = TargetName::new_tcp(SocketAddr::V6(SocketAddrV6::new(
            "fe80::1ff:fe23:4567:890a".parse().unwrap(),
            5432,
            0,
            2,
        )));
        assert_eq!(format!("{target:?}"), "[fe80::1ff:fe23:4567:890a%2]:5432");

        // `TargetName::new_unix_path` only exists on Unix; gate it so the test
        // module still compiles elsewhere.
        #[cfg(unix)]
        {
            let target = TargetName::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = TargetName::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }
}