gel_stream/client/connection.rs

1use std::marker::PhantomData;
2
3use crate::common::resolver::Resolver;
4use crate::common::tokio_stream::TokioStream;
5use crate::Target;
6use crate::{ConnectionError, ResolvedTarget, Ssl, StreamUpgrade, TlsDriver, UpgradableStream};
7
8type Connection<S, D> = UpgradableStream<S, D>;
9
10#[derive(derive_more::Debug, Clone)]
11enum ConnectorInner {
12    #[debug("{:?}", _0)]
13    Unresolved(Target, Resolver),
14    #[debug("{:?}", _0)]
15    Resolved(ResolvedTarget),
16}
17
18/// A connector can be used to connect multiple times to the same target.
19#[derive(derive_more::Debug, Clone)]
20#[allow(private_bounds)]
21pub struct Connector<D: TlsDriver = Ssl> {
22    target: ConnectorInner,
23    #[debug(skip)]
24    driver: PhantomData<D>,
25    ignore_missing_close_notify: bool,
26    #[cfg(feature = "keepalive")]
27    keepalive: Option<std::time::Duration>,
28}
29
30impl Connector<Ssl> {
31    /// Create a new connector with the given target and default resolver.
32    pub fn new(target: Target) -> Result<Self, std::io::Error> {
33        Self::new_explicit(target)
34    }
35
36    /// Create a new connector with the given resolved target.
37    pub fn new_resolved(target: ResolvedTarget) -> Self {
38        Self::new_explicit_resolved(target.into())
39    }
40
41    /// Create a new connector with the given target and resolver.
42    pub fn new_with_resolver(target: Target, resolver: Resolver) -> Self {
43        Self::new_explicit_with_resolver(target, resolver)
44    }
45}
46
47#[allow(private_bounds)]
48impl<D: TlsDriver> Connector<D> {
49    /// Create a new connector with the given TLS driver and default resolver.
50    pub fn new_explicit(target: Target) -> Result<Self, std::io::Error> {
51        Ok(Self {
52            target: ConnectorInner::Unresolved(target, Resolver::new()?),
53            driver: PhantomData,
54            ignore_missing_close_notify: false,
55            #[cfg(feature = "keepalive")]
56            keepalive: None,
57        })
58    }
59
60    /// Create a new connector with the given TLS driver and resolved target.
61    pub fn new_explicit_resolved(target: ResolvedTarget) -> Self {
62        Self {
63            target: ConnectorInner::Resolved(target),
64            driver: PhantomData,
65            ignore_missing_close_notify: false,
66            #[cfg(feature = "keepalive")]
67            keepalive: None,
68        }
69    }
70
71    /// Create a new connector with the given TLS driver and resolver.
72    pub fn new_explicit_with_resolver(target: Target, resolver: Resolver) -> Self {
73        Self {
74            target: ConnectorInner::Unresolved(target, resolver),
75            driver: PhantomData,
76            ignore_missing_close_notify: false,
77            #[cfg(feature = "keepalive")]
78            keepalive: None,
79        }
80    }
81
82    /// Set a keepalive for the connection. This is only supported for TCP
83    /// connections and will be ignored for unix sockets.
84    #[cfg(feature = "keepalive")]
85    pub fn set_keepalive(&mut self, keepalive: Option<std::time::Duration>) {
86        self.keepalive = keepalive;
87    }
88
89    /// For TLS connections, ignore a hard close where the socket was closed
90    /// before receiving CLOSE_NOTIFY.
91    ///
92    /// This may result in vulnerability to truncation attacks for protocols
93    /// that do not include an implicit length, but may also result in spurious
94    /// failures on Windows where sockets may be closed before the CLOSE_NOTIFY
95    /// is received.
96    pub fn ignore_missing_tls_close_notify(&mut self) {
97        self.ignore_missing_close_notify = true;
98    }
99
100    /// Connect to the target.
101    pub async fn connect(&self) -> Result<Connection<TokioStream, D>, ConnectionError> {
102        let target = match &self.target {
103            ConnectorInner::Unresolved(target, resolver) => {
104                resolver.resolve_remote(target.maybe_resolved()).await?
105            }
106            ConnectorInner::Resolved(target) => target.clone(),
107        };
108        let stream = target.connect().await?;
109
110        #[cfg(feature = "keepalive")]
111        if let Some(keepalive) = self.keepalive {
112            if target.is_tcp() {
113                stream.set_keepalive(Some(keepalive))?;
114            }
115        }
116
117        if let ConnectorInner::Unresolved(target, _) = &self.target {
118            if let Some(ssl) = target.maybe_ssl() {
119                let ssl = D::init_client(ssl, target.name())?;
120                let mut stm = UpgradableStream::new_client(stream, Some(ssl));
121                if self.ignore_missing_close_notify {
122                    stm.ignore_missing_close_notify();
123                }
124                if !target.is_starttls() {
125                    stm = stm.secure_upgrade().await?;
126                }
127                Ok(stm)
128            } else {
129                Ok(UpgradableStream::new_client(stream, None))
130            }
131        } else {
132            Ok(UpgradableStream::new_client(stream, None))
133        }
134    }
135}