rama_tcp/client/
connect.rs

1use rama_core::{
2    Context,
3    combinators::Either,
4    error::{BoxError, ErrorContext, OpaqueError},
5};
6use rama_dns::{DnsOverwrite, DnsResolver, HickoryDns};
7use rama_net::address::{Authority, Domain, Host};
8use rama_net::mode::{ConnectIpMode, DnsResolveIpMode};
9use std::{
10    net::{IpAddr, SocketAddr},
11    ops::Deref,
12    sync::{
13        Arc,
14        atomic::{AtomicBool, Ordering},
15    },
16    time::Duration,
17};
18use tokio::sync::{
19    Semaphore,
20    mpsc::{Sender, channel},
21};
22
23use crate::TcpStream;
24
/// Trait used internally by [`tcp_connect`] and the `TcpConnector`
/// to actually establish the [`TcpStream`].
///
/// Implementors must be `Clone + Send + Sync + 'static`, as required by the
/// supertrait bounds on this trait.
pub trait TcpStreamConnector: Clone + Send + Sync + 'static {
    /// Type of error that can occur when establishing the connection failed.
    type Error;

    /// Connect to the target via the given [`SocketAddr`]ess to establish a [`TcpStream`].
    ///
    /// The returned future is `Send` and is allowed to borrow from `self`.
    fn connect(
        &self,
        addr: SocketAddr,
    ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_;
}
37
38impl TcpStreamConnector for () {
39    type Error = std::io::Error;
40
41    fn connect(
42        &self,
43        addr: SocketAddr,
44    ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
45        TcpStream::connect(addr)
46    }
47}
48
49impl<T: TcpStreamConnector> TcpStreamConnector for Arc<T> {
50    type Error = T::Error;
51
52    fn connect(
53        &self,
54        addr: SocketAddr,
55    ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
56        (**self).connect(addr)
57    }
58}
59
60impl<ConnectFn, ConnectFnFut, ConnectFnErr> TcpStreamConnector for ConnectFn
61where
62    ConnectFn: FnOnce(SocketAddr) -> ConnectFnFut + Clone + Send + Sync + 'static,
63    ConnectFnFut: Future<Output = Result<TcpStream, ConnectFnErr>> + Send + 'static,
64    ConnectFnErr: Into<BoxError> + Send + 'static,
65{
66    type Error = ConnectFnErr;
67
68    fn connect(
69        &self,
70        addr: SocketAddr,
71    ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
72        (self.clone())(addr)
73    }
74}
75
76macro_rules! impl_stream_connector_either {
77    ($id:ident, $($param:ident),+ $(,)?) => {
78        impl<$($param),+> TcpStreamConnector for ::rama_core::combinators::$id<$($param),+>
79        where
80            $(
81                $param: TcpStreamConnector<Error: Into<BoxError>>,
82            )+
83        {
84            type Error = BoxError;
85
86            async fn connect(
87                &self,
88                addr: SocketAddr,
89            ) -> Result<TcpStream, Self::Error> {
90                match self {
91                    $(
92                        ::rama_core::combinators::$id::$param(s) => s.connect(addr).await.map_err(Into::into),
93                    )+
94                }
95            }
96        }
97    };
98}
99
100::rama_core::combinators::impl_either!(impl_stream_connector_either);
101
102#[inline]
103/// Establish a [`TcpStream`] connection for the given [`Authority`],
104/// using the default settings and no custom state.
105///
106/// Use [`tcp_connect`] in case you want to customise any of these settings,
107/// or use a [`rama_net::client::ConnectorService`] for even more advanced possibilities.
108pub async fn default_tcp_connect<State>(
109    ctx: &Context<State>,
110    authority: Authority,
111) -> Result<(TcpStream, SocketAddr), OpaqueError>
112where
113    State: Clone + Send + Sync + 'static,
114{
115    tcp_connect(ctx, authority, true, HickoryDns::default(), ()).await
116}
117
118/// Establish a [`TcpStream`] connection for the given [`Authority`].
119pub async fn tcp_connect<State, Dns, Connector>(
120    ctx: &Context<State>,
121    authority: Authority,
122    allow_overwrites: bool,
123    dns: Dns,
124    connector: Connector,
125) -> Result<(TcpStream, SocketAddr), OpaqueError>
126where
127    State: Clone + Send + Sync + 'static,
128    Dns: DnsResolver<Error: Into<BoxError>> + Clone,
129    Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
130{
131    let ip_mode = ctx.get().copied().unwrap_or_default();
132    let dns_mode = ctx.get().copied().unwrap_or_default();
133
134    let (host, port) = authority.into_parts();
135    let domain = match host {
136        Host::Name(domain) => domain,
137        Host::Address(ip) => {
138            //check if IP Version is allowed
139            match (ip, ip_mode) {
140                (IpAddr::V4(_), ConnectIpMode::Ipv6) => {
141                    return Err(OpaqueError::from_display("IPv4 address is not allowed"));
142                }
143                (IpAddr::V6(_), ConnectIpMode::Ipv4) => {
144                    return Err(OpaqueError::from_display("IPv6 address is not allowed"));
145                }
146                _ => (),
147            }
148
149            // if the authority is already defined as an IP address, we can directly connect to it
150            let addr = (ip, port).into();
151            let stream = connector
152                .connect(addr)
153                .await
154                .map_err(|err| OpaqueError::from_boxed(err.into()))
155                .context("establish tcp client connection")?;
156            return Ok((stream, addr));
157        }
158    };
159
160    if allow_overwrites {
161        if let Some(dns_overwrite) = ctx.get::<DnsOverwrite>() {
162            if let Ok(tuple) = tcp_connect_inner(
163                ctx,
164                domain.clone(),
165                port,
166                dns_mode,
167                dns_overwrite.deref().clone(), // Convert DnsOverwrite to a DnsResolver
168                connector.clone(),
169                ip_mode,
170            )
171            .await
172            {
173                return Ok(tuple);
174            }
175        }
176    }
177
178    //... otherwise we'll try to establish a connection,
179    // with dual-stack parallel connections...
180
181    tcp_connect_inner(ctx, domain, port, dns_mode, dns, connector, ip_mode).await
182}
183
184async fn tcp_connect_inner<State, Dns, Connector>(
185    ctx: &Context<State>,
186    domain: Domain,
187    port: u16,
188    dns_mode: DnsResolveIpMode,
189    dns: Dns,
190    connector: Connector,
191    connect_mode: ConnectIpMode,
192) -> Result<(TcpStream, SocketAddr), OpaqueError>
193where
194    State: Clone + Send + Sync + 'static,
195    Dns: DnsResolver<Error: Into<BoxError>> + Clone,
196    Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
197{
198    let (tx, mut rx) = channel(1);
199    let connected = Arc::new(AtomicBool::new(false));
200    let sem = Arc::new(Semaphore::new(3));
201
202    if dns_mode.ipv4_supported() {
203        ctx.spawn(tcp_connect_inner_branch(
204            dns_mode,
205            dns.clone(),
206            connect_mode,
207            connector.clone(),
208            IpKind::Ipv4,
209            domain.clone(),
210            port,
211            tx.clone(),
212            connected.clone(),
213            sem.clone(),
214        ));
215    }
216
217    if dns_mode.ipv6_supported() {
218        ctx.spawn(tcp_connect_inner_branch(
219            dns_mode,
220            dns.clone(),
221            connect_mode,
222            connector.clone(),
223            IpKind::Ipv6,
224            domain.clone(),
225            port,
226            tx.clone(),
227            connected.clone(),
228            sem.clone(),
229        ));
230    }
231
232    drop(tx);
233    if let Some((stream, addr)) = rx.recv().await {
234        connected.store(true, Ordering::Release);
235        return Ok((stream, addr));
236    }
237
238    Err(OpaqueError::from_display(format!(
239        "failed to connect to any resolved IP address for {domain} (port {port})"
240    )))
241}
242
/// IP version handled by a single connect branch.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
enum IpKind {
    /// The branch resolves and connects to IPv4 addresses.
    Ipv4,
    /// The branch resolves and connects to IPv6 addresses.
    Ipv6,
}
248
249#[allow(clippy::too_many_arguments)]
250async fn tcp_connect_inner_branch<Dns, Connector>(
251    dns_mode: DnsResolveIpMode,
252    dns: Dns,
253    connect_mode: ConnectIpMode,
254    connector: Connector,
255    ip_kind: IpKind,
256    domain: Domain,
257    port: u16,
258    tx: Sender<(TcpStream, SocketAddr)>,
259    connected: Arc<AtomicBool>,
260    sem: Arc<Semaphore>,
261) where
262    Dns: DnsResolver<Error: Into<BoxError>> + Clone,
263    Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
264{
265    let ip_it = match ip_kind {
266        IpKind::Ipv4 => match dns.ipv4_lookup(domain).await {
267            Ok(ips) => Either::A(ips.into_iter().map(IpAddr::V4)),
268            Err(err) => {
269                let err = OpaqueError::from_boxed(err.into());
270                tracing::trace!(err = %err, "[{ip_kind:?}] failed to resolve domain to IPv4 addresses");
271                return;
272            }
273        },
274        IpKind::Ipv6 => match dns.ipv6_lookup(domain).await {
275            Ok(ips) => Either::B(ips.into_iter().map(IpAddr::V6)),
276            Err(err) => {
277                let err = OpaqueError::from_boxed(err.into());
278                tracing::trace!(err = ?err, "[{ip_kind:?}] failed to resolve domain to IPv6 addresses");
279                return;
280            }
281        },
282    };
283
284    let (ipv4_delay_scalar, ipv6_delay_scalar) = match dns_mode {
285        DnsResolveIpMode::DualPreferIpV4 | DnsResolveIpMode::SingleIpV4 => (15 * 2, 21 * 2),
286        _ => (21 * 2, 15 * 2),
287    };
288    for (index, ip) in ip_it.enumerate() {
289        let addr = (ip, port).into();
290
291        let sem = match (ip.is_ipv4(), connect_mode) {
292            (true, ConnectIpMode::Ipv6) => {
293                tracing::trace!(
294                    "[{ip_kind:?}] #{index}: abort connect loop to {addr} (IPv4 address is not allowed)"
295                );
296                continue;
297            }
298            (false, ConnectIpMode::Ipv4) => {
299                tracing::trace!(
300                    "[{ip_kind:?}] #{index}: abort connect loop to {addr} (IPv6 address is not allowed)"
301                );
302                continue;
303            }
304            _ => sem.clone(),
305        };
306
307        let tx = tx.clone();
308        let connected = connected.clone();
309
310        // back off retries exponentially
311        if index > 0 {
312            let delay = match ip_kind {
313                IpKind::Ipv4 => Duration::from_micros((ipv4_delay_scalar * index) as u64),
314                IpKind::Ipv6 => Duration::from_micros((ipv6_delay_scalar * index) as u64),
315            };
316            tokio::time::sleep(delay).await;
317        }
318
319        if connected.load(Ordering::Acquire) {
320            tracing::trace!(
321                "[{ip_kind:?}] #{index}: abort connect loop to {addr} (connection already established)"
322            );
323            return;
324        }
325
326        let connector = connector.clone();
327        tokio::spawn(async move {
328            let _permit = sem.acquire().await.unwrap();
329            if connected.load(Ordering::Acquire) {
330                tracing::trace!(
331                    "[{ip_kind:?}] #{index}: abort spawned attempt to {addr} (connection already established)"
332                );
333                return;
334            }
335
336            tracing::trace!("[{ip_kind:?}] #{index}: tcp connect attempt to {addr}");
337
338            match connector.connect(addr).await {
339                Ok(stream) => {
340                    tracing::trace!("[{ip_kind:?}] #{index}: tcp connection stablished to {addr}");
341                    if let Err(err) = tx.send((stream, addr)).await {
342                        tracing::trace!(err = %err, "[{ip_kind:?}] #{index}: failed to send resolved IP address");
343                    }
344                }
345                Err(err) => {
346                    let err = OpaqueError::from_boxed(err.into());
347                    tracing::trace!(err = %err, "[{ip_kind:?}] #{index}: tcp connector failed to connect");
348                }
349            };
350        });
351    }
352}