hyper_util_wasm/client/legacy/connect/http.rs

1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, Poll};
10use std::time::Duration;
11
12use futures_util::future::Either;
13use http::uri::{Scheme, Uri};
14use pin_project_lite::pin_project;
15use socket2::TcpKeepalive;
16use tokio::net::{TcpSocket, TcpStream};
17use tokio::time::Sleep;
18use tracing::{debug, trace, warn};
19
20use super::dns::{self, resolve, GaiResolver, Resolve};
21use super::{Connected, Connection};
22use crate::rt::TokioIo;
23
24/// A connector for the `http` scheme.
25///
26/// Performs DNS resolution in a thread pool, and then connects over TCP.
27///
28/// # Note
29///
30/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
31/// transport information such as the remote socket address used.
32#[derive(Clone)]
33pub struct HttpConnector<R = GaiResolver> {
34    config: Arc<Config>,
35    resolver: R,
36}
37
/// Transport details recorded on responses when an [`HttpConnector`](HttpConnector)
/// made the connection.
///
/// # Example
///
/// ```
/// # fn doc(res: http::Response<()>) {
/// use hyper_util::client::legacy::connect::HttpInfo;
///
/// // res = http::Response
/// res
///     .extensions()
///     .get::<HttpInfo>()
///     .map(|info| {
///         println!("remote addr = {}", info.remote_addr());
///     });
/// # }
/// ```
///
/// # Note
///
/// Connections made by any other connector will not carry this value in the
/// extensions. Check that connector's documentation for whatever "extra"
/// information it attaches to responses instead.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Peer address of the TCP stream.
    remote_addr: SocketAddr,
    // Local address of the TCP stream.
    local_addr: SocketAddr,
}
66
67#[derive(Clone)]
68struct Config {
69    connect_timeout: Option<Duration>,
70    enforce_http: bool,
71    happy_eyeballs_timeout: Option<Duration>,
72    tcp_keepalive_config: TcpKeepaliveConfig,
73    local_address_ipv4: Option<Ipv4Addr>,
74    local_address_ipv6: Option<Ipv6Addr>,
75    nodelay: bool,
76    reuse_address: bool,
77    send_buffer_size: Option<usize>,
78    recv_buffer_size: Option<usize>,
79    interface: Option<String>,
80}
81
/// User-requested TCP keepalive parameters; every field defaults to `None`
/// ("leave unset").
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    // Idle time before the first keepalive probe.
    time: Option<Duration>,
    // Time between unacknowledged probes.
    interval: Option<Duration>,
    // Number of unacknowledged probes before the peer is considered gone.
    retries: Option<u32>,
}
88
89impl TcpKeepaliveConfig {
90    /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration.
91    fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
92        let mut dirty = false;
93        let mut ka = TcpKeepalive::new();
94        if let Some(time) = self.time {
95            ka = ka.with_time(time);
96            dirty = true
97        }
98        if let Some(interval) = self.interval {
99            ka = Self::ka_with_interval(ka, interval, &mut dirty)
100        };
101        if let Some(retries) = self.retries {
102            ka = Self::ka_with_retries(ka, retries, &mut dirty)
103        };
104        if dirty {
105            Some(ka)
106        } else {
107            None
108        }
109    }
110
111    #[cfg(not(any(
112        target_os = "aix",
113        target_os = "openbsd",
114        target_os = "redox",
115        target_os = "solaris"
116    )))]
117    fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
118        *dirty = true;
119        ka.with_interval(interval)
120    }
121
122    #[cfg(any(
123        target_os = "aix",
124        target_os = "openbsd",
125        target_os = "redox",
126        target_os = "solaris"
127    ))]
128    fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
129        ka // no-op as keepalive interval is not supported on this platform
130    }
131
132    #[cfg(not(any(
133        target_os = "aix",
134        target_os = "openbsd",
135        target_os = "redox",
136        target_os = "solaris",
137        target_os = "windows"
138    )))]
139    fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
140        *dirty = true;
141        ka.with_retries(retries)
142    }
143
144    #[cfg(any(
145        target_os = "aix",
146        target_os = "openbsd",
147        target_os = "redox",
148        target_os = "solaris",
149        target_os = "windows"
150    ))]
151    fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
152        ka // no-op as keepalive retries is not supported on this platform
153    }
154}
155
156// ===== impl HttpConnector =====
157
158impl HttpConnector {
159    /// Construct a new HttpConnector.
160    pub fn new() -> HttpConnector {
161        HttpConnector::new_with_resolver(GaiResolver::new())
162    }
163}
164
165impl<R> HttpConnector<R> {
166    /// Construct a new HttpConnector.
167    ///
168    /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
169    pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
170        HttpConnector {
171            config: Arc::new(Config {
172                connect_timeout: None,
173                enforce_http: true,
174                happy_eyeballs_timeout: Some(Duration::from_millis(300)),
175                tcp_keepalive_config: TcpKeepaliveConfig::default(),
176                local_address_ipv4: None,
177                local_address_ipv6: None,
178                nodelay: false,
179                reuse_address: false,
180                send_buffer_size: None,
181                recv_buffer_size: None,
182                interface: None,
183            }),
184            resolver,
185        }
186    }
187
188    /// Option to enforce all `Uri`s have the `http` scheme.
189    ///
190    /// Enabled by default.
191    #[inline]
192    pub fn enforce_http(&mut self, is_enforced: bool) {
193        self.config_mut().enforce_http = is_enforced;
194    }
195
196    /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration
197    /// to remain idle before sending TCP keepalive probes.
198    ///
199    /// If `None`, keepalive is disabled.
200    ///
201    /// Default is `None`.
202    #[inline]
203    pub fn set_keepalive(&mut self, time: Option<Duration>) {
204        self.config_mut().tcp_keepalive_config.time = time;
205    }
206
207    /// Set the duration between two successive TCP keepalive retransmissions,
208    /// if acknowledgement to the previous keepalive transmission is not received.
209    #[inline]
210    pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
211        self.config_mut().tcp_keepalive_config.interval = interval;
212    }
213
214    /// Set the number of retransmissions to be carried out before declaring that remote end is not available.
215    #[inline]
216    pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
217        self.config_mut().tcp_keepalive_config.retries = retries;
218    }
219
220    /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
221    ///
222    /// Default is `false`.
223    #[inline]
224    pub fn set_nodelay(&mut self, nodelay: bool) {
225        self.config_mut().nodelay = nodelay;
226    }
227
228    /// Sets the value of the SO_SNDBUF option on the socket.
229    #[inline]
230    pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
231        self.config_mut().send_buffer_size = size;
232    }
233
234    /// Sets the value of the SO_RCVBUF option on the socket.
235    #[inline]
236    pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
237        self.config_mut().recv_buffer_size = size;
238    }
239
240    /// Set that all sockets are bound to the configured address before connection.
241    ///
242    /// If `None`, the sockets will not be bound.
243    ///
244    /// Default is `None`.
245    #[inline]
246    pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
247        let (v4, v6) = match addr {
248            Some(IpAddr::V4(a)) => (Some(a), None),
249            Some(IpAddr::V6(a)) => (None, Some(a)),
250            _ => (None, None),
251        };
252
253        let cfg = self.config_mut();
254
255        cfg.local_address_ipv4 = v4;
256        cfg.local_address_ipv6 = v6;
257    }
258
259    /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
260    /// preferences) before connection.
261    #[inline]
262    pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
263        let cfg = self.config_mut();
264
265        cfg.local_address_ipv4 = Some(addr_ipv4);
266        cfg.local_address_ipv6 = Some(addr_ipv6);
267    }
268
269    /// Set the connect timeout.
270    ///
271    /// If a domain resolves to multiple IP addresses, the timeout will be
272    /// evenly divided across them.
273    ///
274    /// Default is `None`.
275    #[inline]
276    pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
277        self.config_mut().connect_timeout = dur;
278    }
279
280    /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
281    ///
282    /// If hostname resolves to both IPv4 and IPv6 addresses and connection
283    /// cannot be established using preferred address family before timeout
284    /// elapses, then connector will in parallel attempt connection using other
285    /// address family.
286    ///
287    /// If `None`, parallel connection attempts are disabled.
288    ///
289    /// Default is 300 milliseconds.
290    ///
291    /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
292    #[inline]
293    pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
294        self.config_mut().happy_eyeballs_timeout = dur;
295    }
296
297    /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
298    ///
299    /// Default is `false`.
300    #[inline]
301    pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
302        self.config_mut().reuse_address = reuse_address;
303        self
304    }
305
306    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
307    ///
308    /// If a socket is bound to an interface, only packets received from that particular
309    /// interface are processed by the socket. Note that this only works for some socket
310    /// types, particularly AF_INET sockets.
311    ///
312    /// On Linux it can be used to specify a [VRF], but the binary needs
313    /// to either have `CAP_NET_RAW` or to be run as root.
314    ///
315    /// This function is only available on Android、Fuchsia and Linux.
316    ///
317    /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt
318    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
319    #[inline]
320    pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
321        self.config_mut().interface = Some(interface.into());
322        self
323    }
324
325    // private
326
327    fn config_mut(&mut self) -> &mut Config {
328        // If the are HttpConnector clones, this will clone the inner
329        // config. So mutating the config won't ever affect previous
330        // clones.
331        Arc::make_mut(&mut self.config)
332    }
333}
334
/// Rejection message when `enforce_http` is on and the scheme is not `http`.
const INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Rejection message when `enforce_http` is off and the URI has no scheme.
const INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Rejection message when the URI has no host component.
const INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
338
339// R: Debug required for now to allow adding it to debug output later...
340impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        f.debug_struct("HttpConnector").finish()
343    }
344}
345
346impl<R> tower_service::Service<Uri> for HttpConnector<R>
347where
348    R: Resolve + Clone + Send + Sync + 'static,
349    R::Future: Send,
350{
351    type Response = TokioIo<TcpStream>;
352    type Error = ConnectError;
353    type Future = HttpConnecting<R>;
354
355    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
356        futures_util::ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
357        Poll::Ready(Ok(()))
358    }
359
360    fn call(&mut self, dst: Uri) -> Self::Future {
361        let mut self_ = self.clone();
362        HttpConnecting {
363            fut: Box::pin(async move { self_.call_async(dst).await }),
364            _marker: PhantomData,
365        }
366    }
367}
368
369fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
370    trace!(
371        "Http::connect; scheme={:?}, host={:?}, port={:?}",
372        dst.scheme(),
373        dst.host(),
374        dst.port(),
375    );
376
377    if config.enforce_http {
378        if dst.scheme() != Some(&Scheme::HTTP) {
379            return Err(ConnectError {
380                msg: INVALID_NOT_HTTP.into(),
381                cause: None,
382            });
383        }
384    } else if dst.scheme().is_none() {
385        return Err(ConnectError {
386            msg: INVALID_MISSING_SCHEME.into(),
387            cause: None,
388        });
389    }
390
391    let host = match dst.host() {
392        Some(s) => s,
393        None => {
394            return Err(ConnectError {
395                msg: INVALID_MISSING_HOST.into(),
396                cause: None,
397            })
398        }
399    };
400    let port = match dst.port() {
401        Some(port) => port.as_u16(),
402        None => {
403            if dst.scheme() == Some(&Scheme::HTTPS) {
404                443
405            } else {
406                80
407            }
408        }
409    };
410
411    Ok((host, port))
412}
413
414impl<R> HttpConnector<R>
415where
416    R: Resolve,
417{
418    async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
419        let config = &self.config;
420
421        let (host, port) = get_host_port(config, &dst)?;
422        let host = host.trim_start_matches('[').trim_end_matches(']');
423
424        // If the host is already an IP addr (v4 or v6),
425        // skip resolving the dns and start connecting right away.
426        let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
427            addrs
428        } else {
429            let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
430                .await
431                .map_err(ConnectError::dns)?;
432            let addrs = addrs
433                .map(|mut addr| {
434                    addr.set_port(port);
435                    addr
436                })
437                .collect();
438            dns::SocketAddrs::new(addrs)
439        };
440
441        let c = ConnectingTcp::new(addrs, config);
442
443        let sock = c.connect().await?;
444
445        if let Err(e) = sock.set_nodelay(config.nodelay) {
446            warn!("tcp set_nodelay error: {}", e);
447        }
448
449        Ok(TokioIo::new(sock))
450    }
451}
452
453impl Connection for TcpStream {
454    fn connected(&self) -> Connected {
455        let connected = Connected::new();
456        if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
457            connected.extra(HttpInfo {
458                remote_addr,
459                local_addr,
460            })
461        } else {
462            connected
463        }
464    }
465}
466
467// Implement `Connection` for generic `TokioIo<T>` so that external crates can
468// implement their own `HttpConnector` with `TokioIo<CustomTcpStream>`.
469impl<T> Connection for TokioIo<T>
470where
471    T: Connection,
472{
473    fn connected(&self) -> Connected {
474        self.inner().connected()
475    }
476}
477
478impl HttpInfo {
479    /// Get the remote address of the transport used.
480    pub fn remote_addr(&self) -> SocketAddr {
481        self.remote_addr
482    }
483
484    /// Get the local address of the transport used.
485    pub fn local_addr(&self) -> SocketAddr {
486        self.local_addr
487    }
488}
489
490pin_project! {
491    // Not publicly exported (so missing_docs doesn't trigger).
492    //
493    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
494    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
495    // (and thus we can change the type in the future).
496    #[must_use = "futures do nothing unless polled"]
497    #[allow(missing_debug_implementations)]
498    pub struct HttpConnecting<R> {
499        #[pin]
500        fut: BoxConnecting,
501        _marker: PhantomData<R>,
502    }
503}
504
505type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
506type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
507
508impl<R: Resolve> Future for HttpConnecting<R> {
509    type Output = ConnectResult;
510
511    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
512        self.project().fut.poll(cx)
513    }
514}
515
// Not publicly exported (so missing_docs doesn't trigger).
//
// Error produced by `HttpConnector`: a message plus an optional underlying
// cause, which is exposed through `Error::source`.
pub struct ConnectError {
    msg: Box<str>,
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
521
522impl ConnectError {
523    fn new<S, E>(msg: S, cause: E) -> ConnectError
524    where
525        S: Into<Box<str>>,
526        E: Into<Box<dyn StdError + Send + Sync>>,
527    {
528        ConnectError {
529            msg: msg.into(),
530            cause: Some(cause.into()),
531        }
532    }
533
534    fn dns<E>(cause: E) -> ConnectError
535    where
536        E: Into<Box<dyn StdError + Send + Sync>>,
537    {
538        ConnectError::new("dns error", cause)
539    }
540
541    fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
542    where
543        S: Into<Box<str>>,
544        E: Into<Box<dyn StdError + Send + Sync>>,
545    {
546        move |cause| ConnectError::new(msg, cause)
547    }
548}
549
550impl fmt::Debug for ConnectError {
551    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552        if let Some(ref cause) = self.cause {
553            f.debug_tuple("ConnectError")
554                .field(&self.msg)
555                .field(cause)
556                .finish()
557        } else {
558            self.msg.fmt(f)
559        }
560    }
561}
562
563impl fmt::Display for ConnectError {
564    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
565        f.write_str(&self.msg)?;
566
567        if let Some(ref cause) = self.cause {
568            write!(f, ": {}", cause)?;
569        }
570
571        Ok(())
572    }
573}
574
575impl StdError for ConnectError {
576    fn source(&self) -> Option<&(dyn StdError + 'static)> {
577        self.cause.as_ref().map(|e| &**e as _)
578    }
579}
580
581struct ConnectingTcp<'a> {
582    preferred: ConnectingTcpRemote,
583    fallback: Option<ConnectingTcpFallback>,
584    config: &'a Config,
585}
586
587impl<'a> ConnectingTcp<'a> {
588    fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
589        if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
590            let (preferred_addrs, fallback_addrs) = remote_addrs
591                .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
592            if fallback_addrs.is_empty() {
593                return ConnectingTcp {
594                    preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
595                    fallback: None,
596                    config,
597                };
598            }
599
600            ConnectingTcp {
601                preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
602                fallback: Some(ConnectingTcpFallback {
603                    delay: tokio::time::sleep(fallback_timeout),
604                    remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
605                }),
606                config,
607            }
608        } else {
609            ConnectingTcp {
610                preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
611                fallback: None,
612                config,
613            }
614        }
615    }
616}
617
618struct ConnectingTcpFallback {
619    delay: Sleep,
620    remote: ConnectingTcpRemote,
621}
622
623struct ConnectingTcpRemote {
624    addrs: dns::SocketAddrs,
625    connect_timeout: Option<Duration>,
626}
627
628impl ConnectingTcpRemote {
629    fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
630        let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
631
632        Self {
633            addrs,
634            connect_timeout,
635        }
636    }
637}
638
639impl ConnectingTcpRemote {
640    async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
641        let mut err = None;
642        for addr in &mut self.addrs {
643            debug!("connecting to {}", addr);
644            match connect(&addr, config, self.connect_timeout)?.await {
645                Ok(tcp) => {
646                    debug!("connected to {}", addr);
647                    return Ok(tcp);
648                }
649                Err(e) => {
650                    trace!("connect error for {}: {:?}", addr, e);
651                    err = Some(e);
652                }
653            }
654        }
655
656        match err {
657            Some(e) => Err(e),
658            None => Err(ConnectError::new(
659                "tcp connect error",
660                std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
661            )),
662        }
663    }
664}
665
666fn bind_local_address(
667    socket: &socket2::Socket,
668    dst_addr: &SocketAddr,
669    local_addr_ipv4: &Option<Ipv4Addr>,
670    local_addr_ipv6: &Option<Ipv6Addr>,
671) -> io::Result<()> {
672    match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
673        (SocketAddr::V4(_), Some(addr), _) => {
674            socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
675        }
676        (SocketAddr::V6(_), _, Some(addr)) => {
677            socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
678        }
679        _ => {
680            if cfg!(windows) {
681                // Windows requires a socket be bound before calling connect
682                let any: SocketAddr = match *dst_addr {
683                    SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
684                    SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
685                };
686                socket.bind(&any.into())?;
687            }
688        }
689    }
690
691    Ok(())
692}
693
694fn connect(
695    addr: &SocketAddr,
696    config: &Config,
697    connect_timeout: Option<Duration>,
698) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
699    // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
700    // keepalive timeout, it would be nice to use that instead of socket2,
701    // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
702    use socket2::{Domain, Protocol, Socket, Type};
703
704    let domain = Domain::for_address(*addr);
705    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
706        .map_err(ConnectError::m("tcp open error"))?;
707
708    // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
709    // responsible for ensuring O_NONBLOCK is set.
710    socket
711        .set_nonblocking(true)
712        .map_err(ConnectError::m("tcp set_nonblocking error"))?;
713
714    if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
715        if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
716            warn!("tcp set_keepalive error: {}", e);
717        }
718    }
719
720    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
721    // That this only works for some socket types, particularly AF_INET sockets.
722    if let Some(interface) = &config.interface {
723        socket
724            .bind_device(Some(interface.as_bytes()))
725            .map_err(ConnectError::m("tcp bind interface error"))?;
726    }
727
728    bind_local_address(
729        &socket,
730        addr,
731        &config.local_address_ipv4,
732        &config.local_address_ipv6,
733    )
734    .map_err(ConnectError::m("tcp bind local error"))?;
735
736    #[cfg(unix)]
737    let socket = unsafe {
738        // Safety: `from_raw_fd` is only safe to call if ownership of the raw
739        // file descriptor is transferred. Since we call `into_raw_fd` on the
740        // socket2 socket, it gives up ownership of the fd and will not close
741        // it, so this is safe.
742        use std::os::unix::io::{FromRawFd, IntoRawFd};
743        TcpSocket::from_raw_fd(socket.into_raw_fd())
744    };
745    #[cfg(windows)]
746    let socket = unsafe {
747        // Safety: `from_raw_socket` is only safe to call if ownership of the raw
748        // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
749        // socket2 socket, it gives up ownership of the SOCKET and will not close
750        // it, so this is safe.
751        use std::os::windows::io::{FromRawSocket, IntoRawSocket};
752        TcpSocket::from_raw_socket(socket.into_raw_socket())
753    };
754
755    if config.reuse_address {
756        if let Err(e) = socket.set_reuseaddr(true) {
757            warn!("tcp set_reuse_address error: {}", e);
758        }
759    }
760
761    if let Some(size) = config.send_buffer_size {
762        if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
763            warn!("tcp set_buffer_size error: {}", e);
764        }
765    }
766
767    if let Some(size) = config.recv_buffer_size {
768        if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
769            warn!("tcp set_recv_buffer_size error: {}", e);
770        }
771    }
772
773    let connect = socket.connect(*addr);
774    Ok(async move {
775        match connect_timeout {
776            Some(dur) => match tokio::time::timeout(dur, connect).await {
777                Ok(Ok(s)) => Ok(s),
778                Ok(Err(e)) => Err(e),
779                Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
780            },
781            None => connect.await,
782        }
783        .map_err(ConnectError::m("tcp connect error"))
784    })
785}
786
787impl ConnectingTcp<'_> {
788    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
789        match self.fallback {
790            None => self.preferred.connect(self.config).await,
791            Some(mut fallback) => {
792                let preferred_fut = self.preferred.connect(self.config);
793                futures_util::pin_mut!(preferred_fut);
794
795                let fallback_fut = fallback.remote.connect(self.config);
796                futures_util::pin_mut!(fallback_fut);
797
798                let fallback_delay = fallback.delay;
799                futures_util::pin_mut!(fallback_delay);
800
801                let (result, future) =
802                    match futures_util::future::select(preferred_fut, fallback_delay).await {
803                        Either::Left((result, _fallback_delay)) => {
804                            (result, Either::Right(fallback_fut))
805                        }
806                        Either::Right(((), preferred_fut)) => {
807                            // Delay is done, start polling both the preferred and the fallback
808                            futures_util::future::select(preferred_fut, fallback_fut)
809                                .await
810                                .factor_first()
811                        }
812                    };
813
814                if result.is_err() {
815                    // Fallback to the remaining future (could be preferred or fallback)
816                    // if we get an error
817                    future.await
818                } else {
819                    result
820                }
821            }
822        }
823    }
824}
825
826#[cfg(test)]
827mod tests {
828    use std::io;
829
830    use ::http::Uri;
831
832    use crate::client::legacy::connect::http::TcpKeepaliveConfig;
833
834    use super::super::sealed::{Connect, ConnectSvc};
835    use super::{Config, ConnectError, HttpConnector};
836
837    async fn connect<C>(
838        connector: C,
839        dst: Uri,
840    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
841    where
842        C: Connect,
843    {
844        connector.connect(super::super::sealed::Internal, dst).await
845    }
846
847    #[tokio::test]
848    #[cfg_attr(miri, ignore)]
849    async fn test_errors_enforce_http() {
850        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
851        let connector = HttpConnector::new();
852
853        let err = connect(connector, dst).await.unwrap_err();
854        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
855    }
856
857    #[cfg(any(target_os = "linux", target_os = "macos"))]
858    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
859        use std::net::{IpAddr, TcpListener};
860
861        let mut ip_v4 = None;
862        let mut ip_v6 = None;
863
864        let ips = pnet_datalink::interfaces()
865            .into_iter()
866            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
867
868        for ip in ips {
869            match ip {
870                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
871                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
872                _ => (),
873            }
874
875            if ip_v4.is_some() && ip_v6.is_some() {
876                break;
877            }
878        }
879
880        (ip_v4, ip_v6)
881    }
882
883    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
884    fn default_interface() -> Option<String> {
885        pnet_datalink::interfaces()
886            .iter()
887            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
888            .map(|e| e.name.clone())
889    }
890
891    #[tokio::test]
892    #[cfg_attr(miri, ignore)]
893    async fn test_errors_missing_scheme() {
894        let dst = "example.domain".parse().unwrap();
895        let mut connector = HttpConnector::new();
896        connector.enforce_http(false);
897
898        let err = connect(connector, dst).await.unwrap_err();
899        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
900    }
901
902    // NOTE: pnet crate that we use in this test doesn't compile on Windows
903    #[cfg(any(target_os = "linux", target_os = "macos"))]
904    #[cfg_attr(miri, ignore)]
905    #[tokio::test]
906    async fn local_address() {
907        use std::net::{IpAddr, TcpListener};
908
909        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
910        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
911        let port = server4.local_addr().unwrap().port();
912        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
913
914        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
915            let mut connector = HttpConnector::new();
916
917            match (bind_ip_v4, bind_ip_v6) {
918                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
919                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
920                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
921                _ => unreachable!(),
922            }
923
924            connect(connector, dst.parse().unwrap()).await.unwrap();
925
926            let (_, client_addr) = server.accept().unwrap();
927
928            assert_eq!(client_addr.ip(), expected_ip);
929        };
930
931        if let Some(ip) = bind_ip_v4 {
932            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
933        }
934
935        if let Some(ip) = bind_ip_v6 {
936            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
937        }
938    }
939
    // NOTE: pnet crate that we use in this test doesn't compile on Windows.
    // The cfg below is narrower still: it limits the test to the targets where
    // socket2 exposes `Socket::device` (the `SO_BINDTODEVICE` getter).
    //
    // Intended to check that `HttpConnector::set_interface` binds outgoing
    // connections to the named device, for both an IPv4 and an IPv6 loopback
    // destination.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[tokio::test]
    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
    async fn interface() {
        use socket2::{Domain, Protocol, Socket, Type};
        use std::net::TcpListener;

        // First interface that is up, not loopback, and has an IP (see `default_interface`).
        let interface: Option<String> = default_interface();

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();

        // Reuse the IPv4 listener's port so both destinations differ only by family.
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        // Connects to `dst` with `bind_iface` (if any) configured on the
        // connector, then compares a socket's bound device with
        // `expected_interface`. `server` is only used to pick the address family.
        let assert_interface_name =
            |dst: String,
             server: TcpListener,
             bind_iface: Option<String>,
             expected_interface: Option<String>| async move {
                let mut connector = HttpConnector::new();
                if let Some(iface) = bind_iface {
                    connector.set_interface(iface);
                }

                connect(connector, dst.parse().unwrap()).await.unwrap();
                // NOTE(review): this inspects a brand-new `Socket` created right
                // here, not the socket the connector used, and `set_interface` is
                // never applied to it. As written, the assertion does not appear to
                // observe the connector's behaviour at all. It would likely fail
                // whenever `expected_interface` is `Some`. Confirm; the connected
                // stream returned by `connect` should probably be inspected instead.
                let domain = Domain::for_address(server.local_addr().unwrap());
                let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();

                assert_eq!(
                    socket.device().unwrap().as_deref(),
                    expected_interface.as_deref().map(|val| val.as_bytes())
                );
            };

        assert_interface_name(
            format!("http://127.0.0.1:{}", port),
            server4,
            interface.clone(),
            interface.clone(),
        )
        .await;
        assert_interface_name(
            format!("http://[::1]:{}", port),
            server6,
            interface.clone(),
            interface.clone(),
        )
        .await;
    }
990
991    #[test]
992    #[ignore] // TODO
993    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
994    fn client_happy_eyeballs() {
995        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
996        use std::time::{Duration, Instant};
997
998        use super::dns;
999        use super::ConnectingTcp;
1000
1001        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
1002        let addr = server4.local_addr().unwrap();
1003        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
1004        let rt = tokio::runtime::Builder::new_current_thread()
1005            .enable_all()
1006            .build()
1007            .unwrap();
1008
1009        let local_timeout = Duration::default();
1010        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
1011        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
1012        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
1013            + Duration::from_millis(250);
1014
1015        let scenarios = &[
1016            // Fast primary, without fallback.
1017            (&[local_ipv4_addr()][..], 4, local_timeout, false),
1018            (&[local_ipv6_addr()][..], 6, local_timeout, false),
1019            // Fast primary, with (unused) fallback.
1020            (
1021                &[local_ipv4_addr(), local_ipv6_addr()][..],
1022                4,
1023                local_timeout,
1024                false,
1025            ),
1026            (
1027                &[local_ipv6_addr(), local_ipv4_addr()][..],
1028                6,
1029                local_timeout,
1030                false,
1031            ),
1032            // Unreachable + fast primary, without fallback.
1033            (
1034                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
1035                4,
1036                unreachable_v4_timeout,
1037                false,
1038            ),
1039            (
1040                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
1041                6,
1042                unreachable_v6_timeout,
1043                false,
1044            ),
1045            // Unreachable + fast primary, with (unused) fallback.
1046            (
1047                &[
1048                    unreachable_ipv4_addr(),
1049                    local_ipv4_addr(),
1050                    local_ipv6_addr(),
1051                ][..],
1052                4,
1053                unreachable_v4_timeout,
1054                false,
1055            ),
1056            (
1057                &[
1058                    unreachable_ipv6_addr(),
1059                    local_ipv6_addr(),
1060                    local_ipv4_addr(),
1061                ][..],
1062                6,
1063                unreachable_v6_timeout,
1064                true,
1065            ),
1066            // Slow primary, with (used) fallback.
1067            (
1068                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
1069                6,
1070                fallback_timeout,
1071                false,
1072            ),
1073            (
1074                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
1075                4,
1076                fallback_timeout,
1077                true,
1078            ),
1079            // Slow primary, with (used) unreachable + fast fallback.
1080            (
1081                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
1082                6,
1083                fallback_timeout + unreachable_v6_timeout,
1084                false,
1085            ),
1086            (
1087                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
1088                4,
1089                fallback_timeout + unreachable_v4_timeout,
1090                true,
1091            ),
1092        ];
1093
1094        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
1095        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
1096        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;
1097
1098        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
1099            if needs_ipv6_access && !ipv6_accessible {
1100                continue;
1101            }
1102
1103            let (start, stream) = rt
1104                .block_on(async move {
1105                    let addrs = hosts
1106                        .iter()
1107                        .map(|host| (host.clone(), addr.port()).into())
1108                        .collect();
1109                    let cfg = Config {
1110                        local_address_ipv4: None,
1111                        local_address_ipv6: None,
1112                        connect_timeout: None,
1113                        tcp_keepalive_config: TcpKeepaliveConfig::default(),
1114                        happy_eyeballs_timeout: Some(fallback_timeout),
1115                        nodelay: false,
1116                        reuse_address: false,
1117                        enforce_http: false,
1118                        send_buffer_size: None,
1119                        recv_buffer_size: None,
1120                        interface: None,
1121                    };
1122                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
1123                    let start = Instant::now();
1124                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
1125                })
1126                .unwrap();
1127            let res = if stream.peer_addr().unwrap().is_ipv4() {
1128                4
1129            } else {
1130                6
1131            };
1132            let duration = start.elapsed();
1133
1134            // Allow actual duration to be +/- 150ms off.
1135            let min_duration = if timeout >= Duration::from_millis(150) {
1136                timeout - Duration::from_millis(150)
1137            } else {
1138                Duration::default()
1139            };
1140            let max_duration = timeout + Duration::from_millis(150);
1141
1142            assert_eq!(res, family);
1143            assert!(duration >= min_duration);
1144            assert!(duration <= max_duration);
1145        }
1146
1147        fn local_ipv4_addr() -> IpAddr {
1148            Ipv4Addr::new(127, 0, 0, 1).into()
1149        }
1150
1151        fn local_ipv6_addr() -> IpAddr {
1152            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
1153        }
1154
1155        fn unreachable_ipv4_addr() -> IpAddr {
1156            Ipv4Addr::new(127, 0, 0, 2).into()
1157        }
1158
1159        fn unreachable_ipv6_addr() -> IpAddr {
1160            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
1161        }
1162
1163        fn slow_ipv4_addr() -> IpAddr {
1164            // RFC 6890 reserved IPv4 address.
1165            Ipv4Addr::new(198, 18, 0, 25).into()
1166        }
1167
1168        fn slow_ipv6_addr() -> IpAddr {
1169            // RFC 6890 reserved IPv6 address.
1170            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
1171        }
1172
1173        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
1174            let start = Instant::now();
1175            let result =
1176                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
1177
1178            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
1179            let duration = start.elapsed();
1180            (reachable, duration)
1181        }
1182    }
1183
1184    use std::time::Duration;
1185
1186    #[test]
1187    fn no_tcp_keepalive_config() {
1188        assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
1189    }
1190
1191    #[test]
1192    fn tcp_keepalive_time_config() {
1193        let mut kac = TcpKeepaliveConfig::default();
1194        kac.time = Some(Duration::from_secs(60));
1195        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1196            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
1197        } else {
1198            panic!("test failed");
1199        }
1200    }
1201
1202    #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
1203    #[test]
1204    fn tcp_keepalive_interval_config() {
1205        let mut kac = TcpKeepaliveConfig::default();
1206        kac.interval = Some(Duration::from_secs(1));
1207        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1208            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
1209        } else {
1210            panic!("test failed");
1211        }
1212    }
1213
1214    #[cfg(not(any(
1215        target_os = "openbsd",
1216        target_os = "redox",
1217        target_os = "solaris",
1218        target_os = "windows"
1219    )))]
1220    #[test]
1221    fn tcp_keepalive_retries_config() {
1222        let mut kac = TcpKeepaliveConfig::default();
1223        kac.retries = Some(3);
1224        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1225            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
1226        } else {
1227            panic!("test failed");
1228        }
1229    }
1230}