1use rama_core::{
2 combinators::Either,
3 error::{BoxError, ErrorContext, OpaqueError},
4 Context,
5};
6use rama_dns::{DnsOverwrite, DnsResolver, HickoryDns};
7use rama_net::address::{Authority, Domain, Host};
8use std::{
9 future::Future,
10 net::{IpAddr, SocketAddr},
11 ops::Deref,
12 sync::{
13 atomic::{AtomicBool, Ordering},
14 Arc,
15 },
16 time::Duration,
17};
18use tokio::{
19 net::TcpStream,
20 sync::{
21 mpsc::{channel, Sender},
22 Semaphore,
23 },
24};
25
26pub trait TcpStreamConnector: Clone + Send + Sync + 'static {
29 type Error;
31
32 fn connect(
34 &self,
35 addr: SocketAddr,
36 ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_;
37}
38
39impl TcpStreamConnector for () {
40 type Error = std::io::Error;
41
42 fn connect(
43 &self,
44 addr: SocketAddr,
45 ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
46 TcpStream::connect(addr)
47 }
48}
49
50impl<T: TcpStreamConnector> TcpStreamConnector for Arc<T> {
51 type Error = T::Error;
52
53 fn connect(
54 &self,
55 addr: SocketAddr,
56 ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
57 (**self).connect(addr)
58 }
59}
60
61impl<ConnectFn, ConnectFnFut, ConnectFnErr> TcpStreamConnector for ConnectFn
62where
63 ConnectFn: FnOnce(SocketAddr) -> ConnectFnFut + Clone + Send + Sync + 'static,
64 ConnectFnFut: Future<Output = Result<TcpStream, ConnectFnErr>> + Send + 'static,
65 ConnectFnErr: Into<BoxError> + Send + 'static,
66{
67 type Error = ConnectFnErr;
68
69 fn connect(
70 &self,
71 addr: SocketAddr,
72 ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
73 (self.clone())(addr)
74 }
75}
76
77macro_rules! impl_stream_connector_either {
78 ($id:ident, $($param:ident),+ $(,)?) => {
79 impl<$($param),+> TcpStreamConnector for ::rama_core::combinators::$id<$($param),+>
80 where
81 $(
82 $param: TcpStreamConnector<Error: Into<BoxError>>,
83 )+
84 {
85 type Error = BoxError;
86
87 async fn connect(
88 &self,
89 addr: SocketAddr,
90 ) -> Result<TcpStream, Self::Error> {
91 match self {
92 $(
93 ::rama_core::combinators::$id::$param(s) => s.connect(addr).await.map_err(Into::into),
94 )+
95 }
96 }
97 }
98 };
99}
100
101::rama_core::combinators::impl_either!(impl_stream_connector_either);
102
103#[inline]
104pub async fn default_tcp_connect<State>(
110 ctx: &Context<State>,
111 authority: Authority,
112) -> Result<(TcpStream, SocketAddr), OpaqueError>
113where
114 State: Clone + Send + Sync + 'static,
115{
116 tcp_connect(ctx, authority, true, HickoryDns::default(), ()).await
117}
118
119pub async fn tcp_connect<State, Dns, Connector>(
121 ctx: &Context<State>,
122 authority: Authority,
123 allow_overwrites: bool,
124 dns: Dns,
125 connector: Connector,
126) -> Result<(TcpStream, SocketAddr), OpaqueError>
127where
128 State: Clone + Send + Sync + 'static,
129 Dns: DnsResolver<Error: Into<BoxError>> + Clone,
130 Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
131{
132 let (host, port) = authority.into_parts();
133 let domain = match host {
134 Host::Name(domain) => domain,
135 Host::Address(ip) => {
136 let addr = (ip, port).into();
138 let stream = connector
139 .connect(addr)
140 .await
141 .map_err(|err| OpaqueError::from_boxed(err.into()))
142 .context("establish tcp client connection")?;
143 return Ok((stream, addr));
144 }
145 };
146
147 if allow_overwrites {
148 if let Some(dns_overwrite) = ctx.get::<DnsOverwrite>() {
149 if let Ok(tuple) = tcp_connect_inner(
150 ctx,
151 domain.clone(),
152 port,
153 dns_overwrite.deref().clone(),
154 connector.clone(),
155 )
156 .await
157 {
158 return Ok(tuple);
159 }
160 }
161 }
162
163 tcp_connect_inner(ctx, domain, port, dns, connector).await
167}
168
169async fn tcp_connect_inner<State, Dns, Connector>(
170 ctx: &Context<State>,
171 domain: Domain,
172 port: u16,
173 dns: Dns,
174 connector: Connector,
175) -> Result<(TcpStream, SocketAddr), OpaqueError>
176where
177 State: Clone + Send + Sync + 'static,
178 Dns: DnsResolver<Error: Into<BoxError>> + Clone,
179 Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
180{
181 let (tx, mut rx) = channel(1);
182
183 let connected = Arc::new(AtomicBool::new(false));
184 let sem = Arc::new(Semaphore::new(3));
185
186 let ipv6_tx = tx.clone();
188 let ipv6_domain = domain.clone();
189 let ipv6_connected = connected.clone();
190 let ipv6_sem = sem.clone();
191 ctx.spawn(tcp_connect_inner_branch(
192 dns.clone(),
193 connector.clone(),
194 IpKind::Ipv6,
195 ipv6_domain,
196 port,
197 ipv6_tx,
198 ipv6_connected,
199 ipv6_sem,
200 ));
201
202 let ipv4_tx = tx;
204 let ipv4_domain = domain.clone();
205 let ipv4_connected = connected.clone();
206 let ipv4_sem = sem;
207 ctx.spawn(tcp_connect_inner_branch(
208 dns,
209 connector,
210 IpKind::Ipv4,
211 ipv4_domain,
212 port,
213 ipv4_tx,
214 ipv4_connected,
215 ipv4_sem,
216 ));
217
218 if let Some((stream, addr)) = rx.recv().await {
221 connected.store(true, Ordering::Release);
222 return Ok((stream, addr));
223 }
224
225 Err(OpaqueError::from_display(format!(
226 "failed to connect to any resolved IP address for {domain} (port {port})"
227 )))
228}
229
/// Address family targeted by one branch of the connection race.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum IpKind {
    /// Resolved through `DnsResolver::ipv4_lookup`.
    Ipv4,
    /// Resolved through `DnsResolver::ipv6_lookup`.
    Ipv6,
}
235
236#[allow(clippy::too_many_arguments)]
237async fn tcp_connect_inner_branch<Dns, Connector>(
238 dns: Dns,
239 connector: Connector,
240 ip_kind: IpKind,
241 domain: Domain,
242 port: u16,
243 tx: Sender<(TcpStream, SocketAddr)>,
244 connected: Arc<AtomicBool>,
245 sem: Arc<Semaphore>,
246) where
247 Dns: DnsResolver<Error: Into<BoxError>> + Clone,
248 Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
249{
250 let ip_it = match ip_kind {
251 IpKind::Ipv4 => match dns.ipv4_lookup(domain).await {
252 Ok(ips) => Either::A(ips.into_iter().map(IpAddr::V4)),
253 Err(err) => {
254 let err = OpaqueError::from_boxed(err.into());
255 tracing::trace!(err = %err, "[{ip_kind:?}] failed to resolve domain to IPv4 addresses");
256 return;
257 }
258 },
259 IpKind::Ipv6 => match dns.ipv6_lookup(domain).await {
260 Ok(ips) => Either::B(ips.into_iter().map(IpAddr::V6)),
261 Err(err) => {
262 let err = OpaqueError::from_boxed(err.into());
263 tracing::trace!(err = ?err, "[{ip_kind:?}] failed to resolve domain to IPv6 addresses");
264 return;
265 }
266 },
267 };
268
269 for (index, ip) in ip_it.enumerate() {
270 let addr = (ip, port).into();
271
272 let sem = sem.clone();
273 let tx = tx.clone();
274 let connected = connected.clone();
275
276 if index > 0 {
278 let delay = match ip_kind {
279 IpKind::Ipv4 => Duration::from_micros((21 * 2 * index) as u64),
280 IpKind::Ipv6 => Duration::from_micros((15 * 2 * index) as u64),
281 };
282 tokio::time::sleep(delay).await;
283 }
284
285 if connected.load(Ordering::Acquire) {
286 tracing::trace!("[{ip_kind:?}] #{index}: abort connect loop to {addr} (connection already established)");
287 return;
288 }
289
290 let connector = connector.clone();
291 tokio::spawn(async move {
292 let _permit = sem.acquire().await.unwrap();
293 if connected.load(Ordering::Acquire) {
294 tracing::trace!("[{ip_kind:?}] #{index}: abort spawned attempt to {addr} (connection already established)");
295 return;
296 }
297
298 tracing::trace!("[{ip_kind:?}] #{index}: tcp connect attempt to {addr}");
299
300 match connector.connect(addr).await {
301 Ok(stream) => {
302 tracing::trace!("[{ip_kind:?}] #{index}: tcp connection stablished to {addr}");
303 if let Err(err) = tx.send((stream, addr)).await {
304 tracing::trace!(err = %err, "[{ip_kind:?}] #{index}: failed to send resolved IP address");
305 }
306 }
307 Err(err) => {
308 let err = OpaqueError::from_boxed(err.into());
309 tracing::trace!(err = %err, "[{ip_kind:?}] #{index}: tcp connector failed to connect");
310 }
311 };
312 });
313 }
314}