// gel_stream/client/connection.rs

1use std::marker::PhantomData;
2use std::net::SocketAddr;
3
4use crate::common::tokio_stream::{Resolver, TokioStream};
5use crate::{ConnectionError, Ssl, StreamUpgrade, TlsDriver, UpgradableStream};
6use crate::{MaybeResolvedTarget, ResolvedTarget, Target};
7
8type Connection<S, D> = UpgradableStream<S, D>;
9
10/// A connector can be used to connect multiple times to the same target.
11#[allow(private_bounds)]
12pub struct Connector<D: TlsDriver = Ssl> {
13    target: Target,
14    resolver: Resolver,
15    driver: PhantomData<D>,
16    ignore_missing_close_notify: bool,
17    #[cfg(feature = "keepalive")]
18    keepalive: Option<std::time::Duration>,
19}
20
21impl Connector<Ssl> {
22    pub fn new(target: Target) -> Result<Self, std::io::Error> {
23        Self::new_explicit(target)
24    }
25}
26
27#[allow(private_bounds)]
28impl<D: TlsDriver> Connector<D> {
29    pub fn new_explicit(target: Target) -> Result<Self, std::io::Error> {
30        Ok(Self {
31            target,
32            resolver: Resolver::new()?,
33            driver: PhantomData,
34            ignore_missing_close_notify: false,
35            #[cfg(feature = "keepalive")]
36            keepalive: None,
37        })
38    }
39
40    /// Set a keepalive for the connection. This is only supported for TCP
41    /// connections and will be ignored for unix sockets.
42    #[cfg(feature = "keepalive")]
43    pub fn set_keepalive(&mut self, keepalive: Option<std::time::Duration>) {
44        self.keepalive = keepalive;
45    }
46
47    /// For TLS connections, ignore a hard close where the socket was closed
48    /// before receiving CLOSE_NOTIFY.
49    ///
50    /// This may result in vulnerability to truncation attacks for protocols
51    /// that do not include an implicit length, but may also result in spurious
52    /// failures on Windows where sockets may be closed before the CLOSE_NOTIFY
53    /// is received.
54    pub fn ignore_missing_tls_close_notify(&mut self) {
55        self.ignore_missing_close_notify = true;
56    }
57
58    pub async fn connect(&self) -> Result<Connection<TokioStream, D>, ConnectionError> {
59        let stream = match self.target.maybe_resolved() {
60            MaybeResolvedTarget::Resolved(target) => target.connect().await?,
61            MaybeResolvedTarget::Unresolved(host, port, _) => {
62                let ip = self
63                    .resolver
64                    .resolve_remote(host.clone().into_owned())
65                    .await?;
66                ResolvedTarget::SocketAddr(SocketAddr::new(ip, *port))
67                    .connect()
68                    .await?
69            }
70        };
71
72        #[cfg(feature = "keepalive")]
73        if let Some(keepalive) = self.keepalive {
74            if self.target.is_tcp() {
75                stream.set_keepalive(Some(keepalive))?;
76            }
77        }
78
79        if let Some(ssl) = self.target.maybe_ssl() {
80            let ssl = D::init_client(ssl, self.target.name())?;
81            let mut stm = UpgradableStream::new_client(stream, Some(ssl));
82            if self.ignore_missing_close_notify {
83                stm.ignore_missing_close_notify();
84            }
85            if !self.target.is_starttls() {
86                stm.secure_upgrade().await?;
87            }
88            Ok(stm)
89        } else {
90            Ok(UpgradableStream::new_client(stream, None))
91        }
92    }
93}