rama_tcp/client/
connect.rs

1use rama_core::{
2    combinators::Either,
3    error::{BoxError, ErrorContext, OpaqueError},
4    Context,
5};
6use rama_dns::{DnsOverwrite, DnsResolver, HickoryDns};
7use rama_net::address::{Authority, Domain, Host};
8use std::{
9    future::Future,
10    net::{IpAddr, SocketAddr},
11    ops::Deref,
12    sync::{
13        atomic::{AtomicBool, Ordering},
14        Arc,
15    },
16    time::Duration,
17};
18use tokio::{
19    net::TcpStream,
20    sync::{
21        mpsc::{channel, Sender},
22        Semaphore,
23    },
24};
25
26/// Trait used internally by [`tcp_connect`] and the `TcpConnector`
27/// to actually establish the [`TcpStream`.]
28pub trait TcpStreamConnector: Clone + Send + Sync + 'static {
29    /// Type of error that can occurr when establishing the connection failed.
30    type Error;
31
32    /// Connect to the target via the given [`SocketAddr`]ess to establish a [`TcpStream`].
33    fn connect(
34        &self,
35        addr: SocketAddr,
36    ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_;
37}
38
39impl TcpStreamConnector for () {
40    type Error = std::io::Error;
41
42    fn connect(
43        &self,
44        addr: SocketAddr,
45    ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
46        TcpStream::connect(addr)
47    }
48}
49
50impl<T: TcpStreamConnector> TcpStreamConnector for Arc<T> {
51    type Error = T::Error;
52
53    fn connect(
54        &self,
55        addr: SocketAddr,
56    ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
57        (**self).connect(addr)
58    }
59}
60
61impl<ConnectFn, ConnectFnFut, ConnectFnErr> TcpStreamConnector for ConnectFn
62where
63    ConnectFn: FnOnce(SocketAddr) -> ConnectFnFut + Clone + Send + Sync + 'static,
64    ConnectFnFut: Future<Output = Result<TcpStream, ConnectFnErr>> + Send + 'static,
65    ConnectFnErr: Into<BoxError> + Send + 'static,
66{
67    type Error = ConnectFnErr;
68
69    fn connect(
70        &self,
71        addr: SocketAddr,
72    ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
73        (self.clone())(addr)
74    }
75}
76
77macro_rules! impl_stream_connector_either {
78    ($id:ident, $($param:ident),+ $(,)?) => {
79        impl<$($param),+> TcpStreamConnector for ::rama_core::combinators::$id<$($param),+>
80        where
81            $(
82                $param: TcpStreamConnector<Error: Into<BoxError>>,
83            )+
84        {
85            type Error = BoxError;
86
87            async fn connect(
88                &self,
89                addr: SocketAddr,
90            ) -> Result<TcpStream, Self::Error> {
91                match self {
92                    $(
93                        ::rama_core::combinators::$id::$param(s) => s.connect(addr).await.map_err(Into::into),
94                    )+
95                }
96            }
97        }
98    };
99}
100
101::rama_core::combinators::impl_either!(impl_stream_connector_either);
102
103#[inline]
104/// Establish a [`TcpStream`] connection for the given [`Authority`],
105/// using the default settings and no custom state.
106///
107/// Use [`tcp_connect`] in case you want to customise any of these settings,
108/// or use a [`rama_net::client::ConnectorService`] for even more advanced possibilities.
109pub async fn default_tcp_connect<State>(
110    ctx: &Context<State>,
111    authority: Authority,
112) -> Result<(TcpStream, SocketAddr), OpaqueError>
113where
114    State: Clone + Send + Sync + 'static,
115{
116    tcp_connect(ctx, authority, true, HickoryDns::default(), ()).await
117}
118
119/// Establish a [`TcpStream`] connection for the given [`Authority`].
120pub async fn tcp_connect<State, Dns, Connector>(
121    ctx: &Context<State>,
122    authority: Authority,
123    allow_overwrites: bool,
124    dns: Dns,
125    connector: Connector,
126) -> Result<(TcpStream, SocketAddr), OpaqueError>
127where
128    State: Clone + Send + Sync + 'static,
129    Dns: DnsResolver<Error: Into<BoxError>> + Clone,
130    Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
131{
132    let (host, port) = authority.into_parts();
133    let domain = match host {
134        Host::Name(domain) => domain,
135        Host::Address(ip) => {
136            // if the authority is already defined as an IP address, we can directly connect to it
137            let addr = (ip, port).into();
138            let stream = connector
139                .connect(addr)
140                .await
141                .map_err(|err| OpaqueError::from_boxed(err.into()))
142                .context("establish tcp client connection")?;
143            return Ok((stream, addr));
144        }
145    };
146
147    if allow_overwrites {
148        if let Some(dns_overwrite) = ctx.get::<DnsOverwrite>() {
149            if let Ok(tuple) = tcp_connect_inner(
150                ctx,
151                domain.clone(),
152                port,
153                dns_overwrite.deref().clone(),
154                connector.clone(),
155            )
156            .await
157            {
158                return Ok(tuple);
159            }
160        }
161    }
162
163    //... otherwise we'll try to establish a connection,
164    // with dual-stack parallel connections...
165
166    tcp_connect_inner(ctx, domain, port, dns, connector).await
167}
168
169async fn tcp_connect_inner<State, Dns, Connector>(
170    ctx: &Context<State>,
171    domain: Domain,
172    port: u16,
173    dns: Dns,
174    connector: Connector,
175) -> Result<(TcpStream, SocketAddr), OpaqueError>
176where
177    State: Clone + Send + Sync + 'static,
178    Dns: DnsResolver<Error: Into<BoxError>> + Clone,
179    Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
180{
181    let (tx, mut rx) = channel(1);
182
183    let connected = Arc::new(AtomicBool::new(false));
184    let sem = Arc::new(Semaphore::new(3));
185
186    // IPv6
187    let ipv6_tx = tx.clone();
188    let ipv6_domain = domain.clone();
189    let ipv6_connected = connected.clone();
190    let ipv6_sem = sem.clone();
191    ctx.spawn(tcp_connect_inner_branch(
192        dns.clone(),
193        connector.clone(),
194        IpKind::Ipv6,
195        ipv6_domain,
196        port,
197        ipv6_tx,
198        ipv6_connected,
199        ipv6_sem,
200    ));
201
202    // IPv4
203    let ipv4_tx = tx;
204    let ipv4_domain = domain.clone();
205    let ipv4_connected = connected.clone();
206    let ipv4_sem = sem;
207    ctx.spawn(tcp_connect_inner_branch(
208        dns,
209        connector,
210        IpKind::Ipv4,
211        ipv4_domain,
212        port,
213        ipv4_tx,
214        ipv4_connected,
215        ipv4_sem,
216    ));
217
218    // wait for the first connection to succeed,
219    // ignore the rest of the connections (sorry, but not sorry)
220    if let Some((stream, addr)) = rx.recv().await {
221        connected.store(true, Ordering::Release);
222        return Ok((stream, addr));
223    }
224
225    Err(OpaqueError::from_display(format!(
226        "failed to connect to any resolved IP address for {domain} (port {port})"
227    )))
228}
229
/// IP address family targeted by one branch of the dual-stack connect race.
///
/// Variant order is significant: it defines the derived ordering (`Ipv4 < Ipv6`).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
enum IpKind {
    /// Connect to IPv4 addresses (resolved via `ipv4_lookup`).
    Ipv4,
    /// Connect to IPv6 addresses (resolved via `ipv6_lookup`).
    Ipv6,
}
235
236#[allow(clippy::too_many_arguments)]
237async fn tcp_connect_inner_branch<Dns, Connector>(
238    dns: Dns,
239    connector: Connector,
240    ip_kind: IpKind,
241    domain: Domain,
242    port: u16,
243    tx: Sender<(TcpStream, SocketAddr)>,
244    connected: Arc<AtomicBool>,
245    sem: Arc<Semaphore>,
246) where
247    Dns: DnsResolver<Error: Into<BoxError>> + Clone,
248    Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
249{
250    let ip_it = match ip_kind {
251        IpKind::Ipv4 => match dns.ipv4_lookup(domain).await {
252            Ok(ips) => Either::A(ips.into_iter().map(IpAddr::V4)),
253            Err(err) => {
254                let err = OpaqueError::from_boxed(err.into());
255                tracing::trace!(err = %err, "[{ip_kind:?}] failed to resolve domain to IPv4 addresses");
256                return;
257            }
258        },
259        IpKind::Ipv6 => match dns.ipv6_lookup(domain).await {
260            Ok(ips) => Either::B(ips.into_iter().map(IpAddr::V6)),
261            Err(err) => {
262                let err = OpaqueError::from_boxed(err.into());
263                tracing::trace!(err = ?err, "[{ip_kind:?}] failed to resolve domain to IPv6 addresses");
264                return;
265            }
266        },
267    };
268
269    for (index, ip) in ip_it.enumerate() {
270        let addr = (ip, port).into();
271
272        let sem = sem.clone();
273        let tx = tx.clone();
274        let connected = connected.clone();
275
276        // back off retries exponentially
277        if index > 0 {
278            let delay = match ip_kind {
279                IpKind::Ipv4 => Duration::from_micros((21 * 2 * index) as u64),
280                IpKind::Ipv6 => Duration::from_micros((15 * 2 * index) as u64),
281            };
282            tokio::time::sleep(delay).await;
283        }
284
285        if connected.load(Ordering::Acquire) {
286            tracing::trace!("[{ip_kind:?}] #{index}: abort connect loop to {addr} (connection already established)");
287            return;
288        }
289
290        let connector = connector.clone();
291        tokio::spawn(async move {
292            let _permit = sem.acquire().await.unwrap();
293            if connected.load(Ordering::Acquire) {
294                tracing::trace!("[{ip_kind:?}] #{index}: abort spawned attempt to {addr} (connection already established)");
295                return;
296            }
297
298            tracing::trace!("[{ip_kind:?}] #{index}: tcp connect attempt to {addr}");
299
300            match connector.connect(addr).await {
301                Ok(stream) => {
302                    tracing::trace!("[{ip_kind:?}] #{index}: tcp connection stablished to {addr}");
303                    if let Err(err) = tx.send((stream, addr)).await {
304                        tracing::trace!(err = %err, "[{ip_kind:?}] #{index}: failed to send resolved IP address");
305                    }
306                }
307                Err(err) => {
308                    let err = OpaqueError::from_boxed(err.into());
309                    tracing::trace!(err = %err, "[{ip_kind:?}] #{index}: tcp connector failed to connect");
310                }
311            };
312        });
313    }
314}