1use rama_core::{
2 Context,
3 combinators::Either,
4 error::{BoxError, ErrorContext, OpaqueError},
5};
6use rama_dns::{DnsOverwrite, DnsResolver, HickoryDns};
7use rama_net::address::{Authority, Domain, Host};
8use rama_net::mode::{ConnectIpMode, DnsResolveIpMode};
9use std::{
10 net::{IpAddr, SocketAddr},
11 ops::Deref,
12 sync::{
13 Arc,
14 atomic::{AtomicBool, Ordering},
15 },
16 time::Duration,
17};
18use tokio::sync::{
19 Semaphore,
20 mpsc::{Sender, channel},
21};
22
23use crate::TcpStream;
24
25pub trait TcpStreamConnector: Clone + Send + Sync + 'static {
28 type Error;
30
31 fn connect(
33 &self,
34 addr: SocketAddr,
35 ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_;
36}
37
38impl TcpStreamConnector for () {
39 type Error = std::io::Error;
40
41 fn connect(
42 &self,
43 addr: SocketAddr,
44 ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
45 TcpStream::connect(addr)
46 }
47}
48
49impl<T: TcpStreamConnector> TcpStreamConnector for Arc<T> {
50 type Error = T::Error;
51
52 fn connect(
53 &self,
54 addr: SocketAddr,
55 ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
56 (**self).connect(addr)
57 }
58}
59
60impl<ConnectFn, ConnectFnFut, ConnectFnErr> TcpStreamConnector for ConnectFn
61where
62 ConnectFn: FnOnce(SocketAddr) -> ConnectFnFut + Clone + Send + Sync + 'static,
63 ConnectFnFut: Future<Output = Result<TcpStream, ConnectFnErr>> + Send + 'static,
64 ConnectFnErr: Into<BoxError> + Send + 'static,
65{
66 type Error = ConnectFnErr;
67
68 fn connect(
69 &self,
70 addr: SocketAddr,
71 ) -> impl Future<Output = Result<TcpStream, Self::Error>> + Send + '_ {
72 (self.clone())(addr)
73 }
74}
75
76macro_rules! impl_stream_connector_either {
77 ($id:ident, $($param:ident),+ $(,)?) => {
78 impl<$($param),+> TcpStreamConnector for ::rama_core::combinators::$id<$($param),+>
79 where
80 $(
81 $param: TcpStreamConnector<Error: Into<BoxError>>,
82 )+
83 {
84 type Error = BoxError;
85
86 async fn connect(
87 &self,
88 addr: SocketAddr,
89 ) -> Result<TcpStream, Self::Error> {
90 match self {
91 $(
92 ::rama_core::combinators::$id::$param(s) => s.connect(addr).await.map_err(Into::into),
93 )+
94 }
95 }
96 }
97 };
98}
99
100::rama_core::combinators::impl_either!(impl_stream_connector_either);
101
102#[inline]
103pub async fn default_tcp_connect<State>(
109 ctx: &Context<State>,
110 authority: Authority,
111) -> Result<(TcpStream, SocketAddr), OpaqueError>
112where
113 State: Clone + Send + Sync + 'static,
114{
115 tcp_connect(ctx, authority, true, HickoryDns::default(), ()).await
116}
117
118pub async fn tcp_connect<State, Dns, Connector>(
120 ctx: &Context<State>,
121 authority: Authority,
122 allow_overwrites: bool,
123 dns: Dns,
124 connector: Connector,
125) -> Result<(TcpStream, SocketAddr), OpaqueError>
126where
127 State: Clone + Send + Sync + 'static,
128 Dns: DnsResolver<Error: Into<BoxError>> + Clone,
129 Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
130{
131 let ip_mode = ctx.get().copied().unwrap_or_default();
132 let dns_mode = ctx.get().copied().unwrap_or_default();
133
134 let (host, port) = authority.into_parts();
135 let domain = match host {
136 Host::Name(domain) => domain,
137 Host::Address(ip) => {
138 match (ip, ip_mode) {
140 (IpAddr::V4(_), ConnectIpMode::Ipv6) => {
141 return Err(OpaqueError::from_display("IPv4 address is not allowed"));
142 }
143 (IpAddr::V6(_), ConnectIpMode::Ipv4) => {
144 return Err(OpaqueError::from_display("IPv6 address is not allowed"));
145 }
146 _ => (),
147 }
148
149 let addr = (ip, port).into();
151 let stream = connector
152 .connect(addr)
153 .await
154 .map_err(|err| OpaqueError::from_boxed(err.into()))
155 .context("establish tcp client connection")?;
156 return Ok((stream, addr));
157 }
158 };
159
160 if allow_overwrites {
161 if let Some(dns_overwrite) = ctx.get::<DnsOverwrite>() {
162 if let Ok(tuple) = tcp_connect_inner(
163 ctx,
164 domain.clone(),
165 port,
166 dns_mode,
167 dns_overwrite.deref().clone(), connector.clone(),
169 ip_mode,
170 )
171 .await
172 {
173 return Ok(tuple);
174 }
175 }
176 }
177
178 tcp_connect_inner(ctx, domain, port, dns_mode, dns, connector, ip_mode).await
182}
183
184async fn tcp_connect_inner<State, Dns, Connector>(
185 ctx: &Context<State>,
186 domain: Domain,
187 port: u16,
188 dns_mode: DnsResolveIpMode,
189 dns: Dns,
190 connector: Connector,
191 connect_mode: ConnectIpMode,
192) -> Result<(TcpStream, SocketAddr), OpaqueError>
193where
194 State: Clone + Send + Sync + 'static,
195 Dns: DnsResolver<Error: Into<BoxError>> + Clone,
196 Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
197{
198 let (tx, mut rx) = channel(1);
199 let connected = Arc::new(AtomicBool::new(false));
200 let sem = Arc::new(Semaphore::new(3));
201
202 if dns_mode.ipv4_supported() {
203 ctx.spawn(tcp_connect_inner_branch(
204 dns_mode,
205 dns.clone(),
206 connect_mode,
207 connector.clone(),
208 IpKind::Ipv4,
209 domain.clone(),
210 port,
211 tx.clone(),
212 connected.clone(),
213 sem.clone(),
214 ));
215 }
216
217 if dns_mode.ipv6_supported() {
218 ctx.spawn(tcp_connect_inner_branch(
219 dns_mode,
220 dns.clone(),
221 connect_mode,
222 connector.clone(),
223 IpKind::Ipv6,
224 domain.clone(),
225 port,
226 tx.clone(),
227 connected.clone(),
228 sem.clone(),
229 ));
230 }
231
232 drop(tx);
233 if let Some((stream, addr)) = rx.recv().await {
234 connected.store(true, Ordering::Release);
235 return Ok((stream, addr));
236 }
237
238 Err(OpaqueError::from_display(format!(
239 "failed to connect to any resolved IP address for {domain} (port {port})"
240 )))
241}
242
/// Address family handled by one resolve-and-connect branch.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
enum IpKind {
    /// Resolve with `ipv4_lookup` and connect to IPv4 addresses.
    Ipv4,
    /// Resolve with `ipv6_lookup` and connect to IPv6 addresses.
    Ipv6,
}
248
249#[allow(clippy::too_many_arguments)]
250async fn tcp_connect_inner_branch<Dns, Connector>(
251 dns_mode: DnsResolveIpMode,
252 dns: Dns,
253 connect_mode: ConnectIpMode,
254 connector: Connector,
255 ip_kind: IpKind,
256 domain: Domain,
257 port: u16,
258 tx: Sender<(TcpStream, SocketAddr)>,
259 connected: Arc<AtomicBool>,
260 sem: Arc<Semaphore>,
261) where
262 Dns: DnsResolver<Error: Into<BoxError>> + Clone,
263 Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
264{
265 let ip_it = match ip_kind {
266 IpKind::Ipv4 => match dns.ipv4_lookup(domain).await {
267 Ok(ips) => Either::A(ips.into_iter().map(IpAddr::V4)),
268 Err(err) => {
269 let err = OpaqueError::from_boxed(err.into());
270 tracing::trace!(err = %err, "[{ip_kind:?}] failed to resolve domain to IPv4 addresses");
271 return;
272 }
273 },
274 IpKind::Ipv6 => match dns.ipv6_lookup(domain).await {
275 Ok(ips) => Either::B(ips.into_iter().map(IpAddr::V6)),
276 Err(err) => {
277 let err = OpaqueError::from_boxed(err.into());
278 tracing::trace!(err = ?err, "[{ip_kind:?}] failed to resolve domain to IPv6 addresses");
279 return;
280 }
281 },
282 };
283
284 let (ipv4_delay_scalar, ipv6_delay_scalar) = match dns_mode {
285 DnsResolveIpMode::DualPreferIpV4 | DnsResolveIpMode::SingleIpV4 => (15 * 2, 21 * 2),
286 _ => (21 * 2, 15 * 2),
287 };
288 for (index, ip) in ip_it.enumerate() {
289 let addr = (ip, port).into();
290
291 let sem = match (ip.is_ipv4(), connect_mode) {
292 (true, ConnectIpMode::Ipv6) => {
293 tracing::trace!(
294 "[{ip_kind:?}] #{index}: abort connect loop to {addr} (IPv4 address is not allowed)"
295 );
296 continue;
297 }
298 (false, ConnectIpMode::Ipv4) => {
299 tracing::trace!(
300 "[{ip_kind:?}] #{index}: abort connect loop to {addr} (IPv6 address is not allowed)"
301 );
302 continue;
303 }
304 _ => sem.clone(),
305 };
306
307 let tx = tx.clone();
308 let connected = connected.clone();
309
310 if index > 0 {
312 let delay = match ip_kind {
313 IpKind::Ipv4 => Duration::from_micros((ipv4_delay_scalar * index) as u64),
314 IpKind::Ipv6 => Duration::from_micros((ipv6_delay_scalar * index) as u64),
315 };
316 tokio::time::sleep(delay).await;
317 }
318
319 if connected.load(Ordering::Acquire) {
320 tracing::trace!(
321 "[{ip_kind:?}] #{index}: abort connect loop to {addr} (connection already established)"
322 );
323 return;
324 }
325
326 let connector = connector.clone();
327 tokio::spawn(async move {
328 let _permit = sem.acquire().await.unwrap();
329 if connected.load(Ordering::Acquire) {
330 tracing::trace!(
331 "[{ip_kind:?}] #{index}: abort spawned attempt to {addr} (connection already established)"
332 );
333 return;
334 }
335
336 tracing::trace!("[{ip_kind:?}] #{index}: tcp connect attempt to {addr}");
337
338 match connector.connect(addr).await {
339 Ok(stream) => {
340 tracing::trace!("[{ip_kind:?}] #{index}: tcp connection stablished to {addr}");
341 if let Err(err) = tx.send((stream, addr)).await {
342 tracing::trace!(err = %err, "[{ip_kind:?}] #{index}: failed to send resolved IP address");
343 }
344 }
345 Err(err) => {
346 let err = OpaqueError::from_boxed(err.into());
347 tracing::trace!(err = %err, "[{ip_kind:?}] #{index}: tcp connector failed to connect");
348 }
349 };
350 });
351 }
352}