gel_stream/common/
target.rs

1use std::{
2    borrow::Cow,
3    net::{IpAddr, Ipv4Addr, SocketAddr},
4    path::Path,
5    sync::Arc,
6};
7
8use derive_more::Debug;
9use rustls_pki_types::ServerName;
10
11use crate::TlsParameters;
12
13/// A target name describes the TCP or Unix socket that a client will connect to.
14pub struct TargetName {
15    inner: MaybeResolvedTarget,
16}
17
18impl std::fmt::Debug for TargetName {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(f, "{:?}", self.inner)
21    }
22}
23
24impl TargetName {
25    /// Create a new target for a Unix socket.
26    #[cfg(unix)]
27    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
28        let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
29        Ok(Self {
30            inner: MaybeResolvedTarget::Resolved(path),
31        })
32    }
33
34    /// Create a new target for a Unix socket.
35    #[cfg(any(target_os = "linux", target_os = "android"))]
36    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
37        use std::os::linux::net::SocketAddrExt;
38        let domain =
39            ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
40        Ok(Self {
41            inner: MaybeResolvedTarget::Resolved(domain),
42        })
43    }
44
45    /// Create a new target for a TCP socket.
46    #[allow(private_bounds)]
47    pub fn new_tcp(host: impl TcpResolve) -> Self {
48        Self { inner: host.into() }
49    }
50
51    /// Resolves the target addresses for a given host.
52    pub fn to_addrs_sync(&self) -> Result<Vec<ResolvedTarget>, std::io::Error> {
53        use std::net::ToSocketAddrs;
54        let mut result = Vec::new();
55        match &self.inner {
56            MaybeResolvedTarget::Resolved(addr) => {
57                return Ok(vec![addr.clone()]);
58            }
59            MaybeResolvedTarget::Unresolved(host, port, _interface) => {
60                let addrs = format!("{}:{}", host, port).to_socket_addrs()?;
61                result.extend(addrs.map(ResolvedTarget::SocketAddr));
62            }
63        }
64        Ok(result)
65    }
66}
67
68#[derive(Clone, Debug)]
69pub struct Target {
70    inner: TargetInner,
71}
72
73#[allow(private_bounds)]
74impl Target {
75    pub fn new(name: TargetName) -> Self {
76        Self {
77            inner: TargetInner::NoTls(name.inner),
78        }
79    }
80
81    pub fn new_tls(name: TargetName, params: TlsParameters) -> Self {
82        Self {
83            inner: TargetInner::Tls(name.inner, params.into()),
84        }
85    }
86
87    pub fn new_starttls(name: TargetName, params: TlsParameters) -> Self {
88        Self {
89            inner: TargetInner::StartTls(name.inner, params.into()),
90        }
91    }
92
93    pub fn new_resolved(target: ResolvedTarget) -> Self {
94        Self {
95            inner: TargetInner::NoTls(target.into()),
96        }
97    }
98
99    pub fn new_resolved_tls(target: ResolvedTarget, params: TlsParameters) -> Self {
100        Self {
101            inner: TargetInner::Tls(target.into(), params.into()),
102        }
103    }
104
105    pub fn new_resolved_starttls(target: ResolvedTarget, params: TlsParameters) -> Self {
106        Self {
107            inner: TargetInner::StartTls(target.into(), params.into()),
108        }
109    }
110
111    /// Create a new target for a Unix socket.
112    #[cfg(unix)]
113    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
114        let path = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
115        Ok(Self {
116            inner: TargetInner::NoTls(path.into()),
117        })
118    }
119
120    /// Create a new target for a Unix socket.
121    #[cfg(any(target_os = "linux", target_os = "android"))]
122    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
123        use std::os::linux::net::SocketAddrExt;
124        let domain =
125            ResolvedTarget::from(std::os::unix::net::SocketAddr::from_abstract_name(domain)?);
126        Ok(Self {
127            inner: TargetInner::NoTls(domain.into()),
128        })
129    }
130
131    /// Create a new target for a TCP socket.
132    pub fn new_tcp(host: impl TcpResolve) -> Self {
133        Self {
134            inner: TargetInner::NoTls(host.into()),
135        }
136    }
137
138    /// Create a new target for a TCP socket with TLS.
139    pub fn new_tcp_tls(host: impl TcpResolve, params: TlsParameters) -> Self {
140        Self {
141            inner: TargetInner::Tls(host.into(), params.into()),
142        }
143    }
144
145    /// Create a new target for a TCP socket with STARTTLS.
146    pub fn new_tcp_starttls(host: impl TcpResolve, params: TlsParameters) -> Self {
147        Self {
148            inner: TargetInner::StartTls(host.into(), params.into()),
149        }
150    }
151
152    pub fn try_set_tls(&mut self, params: TlsParameters) -> Option<Option<Arc<TlsParameters>>> {
153        // Don't set TLS parameters on Unix sockets.
154        if self.maybe_resolved().path().is_some() {
155            return None;
156        }
157
158        let params = params.into();
159
160        // Temporary
161        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
162            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
163        ));
164
165        match std::mem::replace(&mut self.inner, no_target) {
166            TargetInner::NoTls(target) => {
167                self.inner = TargetInner::Tls(target, params);
168                Some(None)
169            }
170            TargetInner::Tls(target, old_params) => {
171                self.inner = TargetInner::Tls(target, params);
172                Some(Some(old_params))
173            }
174            TargetInner::StartTls(target, old_params) => {
175                self.inner = TargetInner::StartTls(target, params);
176                Some(Some(old_params))
177            }
178        }
179    }
180
181    pub fn try_remove_tls(&mut self) -> Option<Arc<TlsParameters>> {
182        // Temporary
183        let no_target = TargetInner::NoTls(MaybeResolvedTarget::Resolved(
184            ResolvedTarget::SocketAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)),
185        ));
186
187        match std::mem::replace(&mut self.inner, no_target) {
188            TargetInner::NoTls(target) => {
189                self.inner = TargetInner::NoTls(target);
190                None
191            }
192            TargetInner::Tls(target, old_params) => {
193                self.inner = TargetInner::NoTls(target);
194                Some(old_params)
195            }
196            TargetInner::StartTls(target, old_params) => {
197                self.inner = TargetInner::NoTls(target);
198                Some(old_params)
199            }
200        }
201    }
202
203    /// Check if the target is a TCP connection.
204    pub fn is_tcp(&self) -> bool {
205        self.maybe_resolved().port().is_some()
206    }
207
208    /// Get the port of the target. If the target type does not include a port,
209    /// this will return None.
210    pub fn port(&self) -> Option<u16> {
211        self.maybe_resolved().port()
212    }
213
214    /// Set the port of the target. If the target type does not include a port,
215    /// this will return None. Otherwise, it will return the old port.
216    pub fn try_set_port(&mut self, port: u16) -> Option<u16> {
217        self.maybe_resolved_mut().set_port(port)
218    }
219
220    /// Get the path of the target. If the target type does not include a path,
221    /// this will return None.
222    pub fn path(&self) -> Option<&Path> {
223        self.maybe_resolved().path()
224    }
225
226    /// Get the host of the target. For resolved IP addresses, this is the
227    /// string representation of the IP address. For unresolved hostnames, this
228    /// is the hostname. If the target type does not include a host, this will
229    /// return None.
230    pub fn host(&self) -> Option<Cow<str>> {
231        self.maybe_resolved().host()
232    }
233
234    /// Get the name of the target. For resolved IP addresses, this is the
235    /// string representation of the IP address. For unresolved hostnames, this
236    /// is the hostname.
237    pub fn name(&self) -> Option<ServerName> {
238        self.maybe_resolved().name()
239    }
240
241    /// Get the host and port of the target. If the target type does not include
242    /// a host or port, this will return None.
243    pub fn tcp(&self) -> Option<(Cow<str>, u16)> {
244        self.maybe_resolved().tcp()
245    }
246
247    pub(crate) fn maybe_resolved(&self) -> &MaybeResolvedTarget {
248        match &self.inner {
249            TargetInner::NoTls(target) => target,
250            TargetInner::Tls(target, _) => target,
251            TargetInner::StartTls(target, _) => target,
252        }
253    }
254
255    pub(crate) fn maybe_resolved_mut(&mut self) -> &mut MaybeResolvedTarget {
256        match &mut self.inner {
257            TargetInner::NoTls(target) => target,
258            TargetInner::Tls(target, _) => target,
259            TargetInner::StartTls(target, _) => target,
260        }
261    }
262
263    pub(crate) fn is_starttls(&self) -> bool {
264        matches!(self.inner, TargetInner::StartTls(_, _))
265    }
266
267    pub(crate) fn maybe_ssl(&self) -> Option<&TlsParameters> {
268        match &self.inner {
269            TargetInner::NoTls(_) => None,
270            TargetInner::Tls(_, params) => Some(params),
271            TargetInner::StartTls(_, params) => Some(params),
272        }
273    }
274}
275
276#[derive(Clone, derive_more::From)]
277pub(crate) enum MaybeResolvedTarget {
278    Resolved(ResolvedTarget),
279    Unresolved(Cow<'static, str>, u16, Option<Cow<'static, str>>),
280}
281
282impl std::fmt::Debug for MaybeResolvedTarget {
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        match self {
285            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
286                if let SocketAddr::V6(addr) = addr {
287                    if addr.scope_id() != 0 {
288                        write!(f, "[{}%{}]:{}", addr.ip(), addr.scope_id(), addr.port())
289                    } else {
290                        write!(f, "[{}]:{}", addr.ip(), addr.port())
291                    }
292                } else {
293                    write!(f, "{}:{}", addr.ip(), addr.port())
294                }
295            }
296            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
297                if let Some(path) = addr.as_pathname() {
298                    return write!(f, "{}", path.to_string_lossy());
299                } else {
300                    #[cfg(any(target_os = "linux", target_os = "android"))]
301                    {
302                        use std::os::linux::net::SocketAddrExt;
303                        if let Some(name) = addr.as_abstract_name() {
304                            return write!(f, "@{}", String::from_utf8_lossy(name));
305                        }
306                    }
307                }
308                Ok(())
309            }
310            MaybeResolvedTarget::Unresolved(host, port, interface) => {
311                write!(f, "{}:{}", host, port)?;
312                if let Some(interface) = interface {
313                    write!(f, "%{}", interface)?;
314                }
315                Ok(())
316            }
317        }
318    }
319}
320
321impl MaybeResolvedTarget {
322    fn name(&self) -> Option<ServerName> {
323        match self {
324            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
325                Some(ServerName::IpAddress(addr.ip().into()))
326            }
327            MaybeResolvedTarget::Unresolved(host, _, _) => {
328                Some(ServerName::DnsName(host.to_string().try_into().ok()?))
329            }
330            _ => None,
331        }
332    }
333
334    fn tcp(&self) -> Option<(Cow<str>, u16)> {
335        match self {
336            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
337                Some((Cow::Owned(addr.ip().to_string()), addr.port()))
338            }
339            MaybeResolvedTarget::Unresolved(host, port, _) => Some((Cow::Borrowed(host), *port)),
340            _ => None,
341        }
342    }
343
344    fn path(&self) -> Option<&Path> {
345        match self {
346            MaybeResolvedTarget::Resolved(ResolvedTarget::UnixSocketAddr(addr)) => {
347                addr.as_pathname()
348            }
349            _ => None,
350        }
351    }
352
353    fn host(&self) -> Option<Cow<str>> {
354        match self {
355            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
356                Some(Cow::Owned(addr.ip().to_string()))
357            }
358            MaybeResolvedTarget::Unresolved(host, _, _) => Some(Cow::Borrowed(host)),
359            _ => None,
360        }
361    }
362
363    fn port(&self) -> Option<u16> {
364        match self {
365            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => Some(addr.port()),
366            MaybeResolvedTarget::Unresolved(_, port, _) => Some(*port),
367            _ => None,
368        }
369    }
370
371    fn set_port(&mut self, new_port: u16) -> Option<u16> {
372        match self {
373            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(addr)) => {
374                let old_port = addr.port();
375                addr.set_port(new_port);
376                Some(old_port)
377            }
378            MaybeResolvedTarget::Unresolved(_, port, _) => {
379                let old_port = *port;
380                *port = new_port;
381                Some(old_port)
382            }
383            _ => None,
384        }
385    }
386}
387
388/// The type of connection.
389#[derive(Clone, Debug)]
390enum TargetInner {
391    NoTls(MaybeResolvedTarget),
392    Tls(MaybeResolvedTarget, Arc<TlsParameters>),
393    StartTls(MaybeResolvedTarget, Arc<TlsParameters>),
394}
395
396#[derive(Clone, Debug, derive_more::From)]
397/// The resolved target of a connection attempt.
398pub enum ResolvedTarget {
399    SocketAddr(std::net::SocketAddr),
400    #[cfg(unix)]
401    UnixSocketAddr(std::os::unix::net::SocketAddr),
402}
403
404impl ResolvedTarget {
405    pub fn tcp(&self) -> Option<SocketAddr> {
406        match self {
407            ResolvedTarget::SocketAddr(addr) => Some(*addr),
408            _ => None,
409        }
410    }
411}
412
413pub trait LocalAddress {
414    fn local_address(&self) -> std::io::Result<ResolvedTarget>;
415}
416
417trait TcpResolve {
418    fn into(self) -> MaybeResolvedTarget;
419}
420
421impl<S: AsRef<str>> TcpResolve for (S, u16) {
422    fn into(self) -> MaybeResolvedTarget {
423        if let Ok(addr) = self.0.as_ref().parse::<IpAddr>() {
424            MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(SocketAddr::new(addr, self.1)))
425        } else {
426            MaybeResolvedTarget::Unresolved(Cow::Owned(self.0.as_ref().to_owned()), self.1, None)
427        }
428    }
429}
430
431impl TcpResolve for SocketAddr {
432    fn into(self) -> MaybeResolvedTarget {
433        MaybeResolvedTarget::Resolved(ResolvedTarget::SocketAddr(self))
434    }
435}
436
#[cfg(test)]
mod tests {
    use std::net::SocketAddrV6;

    use super::*;

    #[test]
    fn test_target() {
        let target = Target::new_tcp(("localhost", 5432));
        assert_eq!(
            target.name(),
            Some(ServerName::DnsName("localhost".try_into().unwrap()))
        );
    }

    #[test]
    fn test_target_tcp_accessors() {
        let mut target = Target::new_tcp(("localhost", 5432));
        assert!(target.is_tcp());
        assert_eq!(target.host().as_deref(), Some("localhost"));
        assert_eq!(target.path(), None);
        assert_eq!(target.try_set_port(1234), Some(5432));
        assert_eq!(target.port(), Some(1234));
        assert_eq!(target.tcp(), Some((Cow::Borrowed("localhost"), 1234)));
        assert!(target.try_remove_tls().is_none());
    }

    #[test]
    fn test_target_name() {
        let target = TargetName::new_tcp(("localhost", 5432));
        assert_eq!(format!("{target:?}"), "localhost:5432");

        let target = TargetName::new_tcp(("127.0.0.1", 5432));
        assert_eq!(format!("{target:?}"), "127.0.0.1:5432");

        let target = TargetName::new_tcp(("::1", 5432));
        assert_eq!(format!("{target:?}"), "[::1]:5432");

        let target = TargetName::new_tcp(SocketAddr::V6(SocketAddrV6::new(
            "fe80::1ff:fe23:4567:890a".parse().unwrap(),
            5432,
            0,
            2,
        )));
        assert_eq!(format!("{target:?}"), "[fe80::1ff:fe23:4567:890a%2]:5432");

        // Unix sockets only exist (and `new_unix_path` only compiles) on Unix.
        #[cfg(unix)]
        {
            let target = TargetName::new_unix_path("/tmp/test.sock").unwrap();
            assert_eq!(format!("{target:?}"), "/tmp/test.sock");
        }

        #[cfg(any(target_os = "linux", target_os = "android"))]
        {
            let target = TargetName::new_unix_domain("test").unwrap();
            assert_eq!(format!("{target:?}"), "@test");
        }
    }

    #[cfg(unix)]
    #[test]
    fn test_unix_path_target() {
        let target = Target::new_unix_path("/tmp/test.sock").unwrap();
        assert!(!target.is_tcp());
        assert_eq!(target.port(), None);
        assert_eq!(target.host(), None);
        assert_eq!(target.path(), Some(Path::new("/tmp/test.sock")));
    }

    #[test]
    fn test_unresolved_debug_with_interface() {
        let target =
            MaybeResolvedTarget::Unresolved("example.com".into(), 5656, Some("eth0".into()));
        assert_eq!(format!("{target:?}"), "example.com:5656%eth0");
    }
}