1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, Poll};
10use std::time::Duration;
11
12use futures_core::ready;
13use futures_util::future::Either;
14use http::uri::{Scheme, Uri};
15use pin_project_lite::pin_project;
16use socket2::TcpKeepalive;
17use tokio::net::{TcpSocket, TcpStream};
18use tokio::time::Sleep;
19use tracing::{debug, trace, warn};
20
21use super::dns::{self, resolve, GaiResolver, Resolve};
22use super::{Connected, Connection};
23use crate::rt::TokioIo;
24
25#[derive(Clone)]
34pub struct HttpConnector<R = GaiResolver> {
35 config: Arc<Config>,
36 resolver: R,
37}
38
/// Address information about an established TCP connection.
///
/// The `Connection` impl for `TcpStream` attaches one of these to the
/// returned `Connected` value when both addresses can be read.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    /// Address of the remote peer.
    remote_addr: SocketAddr,
    /// Address of the local end of the socket.
    local_addr: SocketAddr,
}
67
68#[derive(Clone)]
69struct Config {
70 connect_timeout: Option<Duration>,
71 enforce_http: bool,
72 happy_eyeballs_timeout: Option<Duration>,
73 tcp_keepalive_config: TcpKeepaliveConfig,
74 local_address_ipv4: Option<Ipv4Addr>,
75 local_address_ipv6: Option<Ipv6Addr>,
76 nodelay: bool,
77 reuse_address: bool,
78 send_buffer_size: Option<usize>,
79 recv_buffer_size: Option<usize>,
80 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
81 interface: Option<String>,
82 #[cfg(any(
83 target_os = "illumos",
84 target_os = "ios",
85 target_os = "macos",
86 target_os = "solaris",
87 target_os = "tvos",
88 target_os = "visionos",
89 target_os = "watchos",
90 ))]
91 interface: Option<std::ffi::CString>,
92 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
93 tcp_user_timeout: Option<Duration>,
94}
95
/// Optional TCP keepalive parameters.
///
/// Every field defaults to `None`; `into_tcpkeepalive` yields a value only
/// when at least one supported parameter has been set.
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    /// Forwarded to `TcpKeepalive::with_time`.
    time: Option<Duration>,
    /// Forwarded to `TcpKeepalive::with_interval` on supporting platforms.
    interval: Option<Duration>,
    /// Forwarded to `TcpKeepalive::with_retries` on supporting platforms.
    retries: Option<u32>,
}
102
103impl TcpKeepaliveConfig {
104 fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
106 let mut dirty = false;
107 let mut ka = TcpKeepalive::new();
108 if let Some(time) = self.time {
109 ka = ka.with_time(time);
110 dirty = true
111 }
112 if let Some(interval) = self.interval {
113 ka = Self::ka_with_interval(ka, interval, &mut dirty)
114 };
115 if let Some(retries) = self.retries {
116 ka = Self::ka_with_retries(ka, retries, &mut dirty)
117 };
118 if dirty {
119 Some(ka)
120 } else {
121 None
122 }
123 }
124
125 #[cfg(
126 any(
128 target_os = "android",
129 target_os = "dragonfly",
130 target_os = "freebsd",
131 target_os = "fuchsia",
132 target_os = "illumos",
133 target_os = "ios",
134 target_os = "visionos",
135 target_os = "linux",
136 target_os = "macos",
137 target_os = "netbsd",
138 target_os = "tvos",
139 target_os = "watchos",
140 target_os = "windows",
141 )
142 )]
143 fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
144 *dirty = true;
145 ka.with_interval(interval)
146 }
147
148 #[cfg(not(
149 any(
151 target_os = "android",
152 target_os = "dragonfly",
153 target_os = "freebsd",
154 target_os = "fuchsia",
155 target_os = "illumos",
156 target_os = "ios",
157 target_os = "visionos",
158 target_os = "linux",
159 target_os = "macos",
160 target_os = "netbsd",
161 target_os = "tvos",
162 target_os = "watchos",
163 target_os = "windows",
164 )
165 ))]
166 fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
167 ka }
169
170 #[cfg(
171 any(
173 target_os = "android",
174 target_os = "dragonfly",
175 target_os = "freebsd",
176 target_os = "fuchsia",
177 target_os = "illumos",
178 target_os = "ios",
179 target_os = "visionos",
180 target_os = "linux",
181 target_os = "macos",
182 target_os = "netbsd",
183 target_os = "tvos",
184 target_os = "watchos",
185 )
186 )]
187 fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
188 *dirty = true;
189 ka.with_retries(retries)
190 }
191
192 #[cfg(not(
193 any(
195 target_os = "android",
196 target_os = "dragonfly",
197 target_os = "freebsd",
198 target_os = "fuchsia",
199 target_os = "illumos",
200 target_os = "ios",
201 target_os = "visionos",
202 target_os = "linux",
203 target_os = "macos",
204 target_os = "netbsd",
205 target_os = "tvos",
206 target_os = "watchos",
207 )
208 ))]
209 fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
210 ka }
212}
213
214impl HttpConnector {
217 pub fn new() -> HttpConnector {
219 HttpConnector::new_with_resolver(GaiResolver::new())
220 }
221}
222
223impl<R> HttpConnector<R> {
224 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
228 HttpConnector {
229 config: Arc::new(Config {
230 connect_timeout: None,
231 enforce_http: true,
232 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
233 tcp_keepalive_config: TcpKeepaliveConfig::default(),
234 local_address_ipv4: None,
235 local_address_ipv6: None,
236 nodelay: false,
237 reuse_address: false,
238 send_buffer_size: None,
239 recv_buffer_size: None,
240 #[cfg(any(
241 target_os = "android",
242 target_os = "fuchsia",
243 target_os = "illumos",
244 target_os = "ios",
245 target_os = "linux",
246 target_os = "macos",
247 target_os = "solaris",
248 target_os = "tvos",
249 target_os = "visionos",
250 target_os = "watchos",
251 ))]
252 interface: None,
253 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
254 tcp_user_timeout: None,
255 }),
256 resolver,
257 }
258 }
259
260 #[inline]
264 pub fn enforce_http(&mut self, is_enforced: bool) {
265 self.config_mut().enforce_http = is_enforced;
266 }
267
268 #[inline]
275 pub fn set_keepalive(&mut self, time: Option<Duration>) {
276 self.config_mut().tcp_keepalive_config.time = time;
277 }
278
279 #[inline]
282 pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
283 self.config_mut().tcp_keepalive_config.interval = interval;
284 }
285
286 #[inline]
288 pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
289 self.config_mut().tcp_keepalive_config.retries = retries;
290 }
291
292 #[inline]
296 pub fn set_nodelay(&mut self, nodelay: bool) {
297 self.config_mut().nodelay = nodelay;
298 }
299
300 #[inline]
302 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
303 self.config_mut().send_buffer_size = size;
304 }
305
306 #[inline]
308 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
309 self.config_mut().recv_buffer_size = size;
310 }
311
312 #[inline]
318 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
319 let (v4, v6) = match addr {
320 Some(IpAddr::V4(a)) => (Some(a), None),
321 Some(IpAddr::V6(a)) => (None, Some(a)),
322 _ => (None, None),
323 };
324
325 let cfg = self.config_mut();
326
327 cfg.local_address_ipv4 = v4;
328 cfg.local_address_ipv6 = v6;
329 }
330
331 #[inline]
334 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
335 let cfg = self.config_mut();
336
337 cfg.local_address_ipv4 = Some(addr_ipv4);
338 cfg.local_address_ipv6 = Some(addr_ipv6);
339 }
340
341 #[inline]
348 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
349 self.config_mut().connect_timeout = dur;
350 }
351
352 #[inline]
365 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
366 self.config_mut().happy_eyeballs_timeout = dur;
367 }
368
369 #[inline]
373 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
374 self.config_mut().reuse_address = reuse_address;
375 self
376 }
377
378 #[cfg(any(
403 target_os = "android",
404 target_os = "fuchsia",
405 target_os = "illumos",
406 target_os = "ios",
407 target_os = "linux",
408 target_os = "macos",
409 target_os = "solaris",
410 target_os = "tvos",
411 target_os = "visionos",
412 target_os = "watchos",
413 ))]
414 #[inline]
415 pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
416 let interface = interface.into();
417 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
418 {
419 self.config_mut().interface = Some(interface);
420 }
421 #[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))]
422 {
423 let interface = std::ffi::CString::new(interface)
424 .expect("interface name should not have nulls in it");
425 self.config_mut().interface = Some(interface);
426 }
427 self
428 }
429
430 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
432 #[inline]
433 pub fn set_tcp_user_timeout(&mut self, time: Option<Duration>) {
434 self.config_mut().tcp_user_timeout = time;
435 }
436
437 fn config_mut(&mut self) -> &mut Config {
440 Arc::make_mut(&mut self.config)
444 }
445}
446
/// Error message used when `http` is enforced and the URI has another (or no) scheme.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Error message used when `http` is not enforced but the URI has no scheme at all.
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Error message used when the URI has no host component.
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
450
451impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
453 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454 f.debug_struct("HttpConnector").finish()
455 }
456}
457
458impl<R> tower_service::Service<Uri> for HttpConnector<R>
459where
460 R: Resolve + Clone + Send + Sync + 'static,
461 R::Future: Send,
462{
463 type Response = TokioIo<TcpStream>;
464 type Error = ConnectError;
465 type Future = HttpConnecting<R>;
466
467 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
468 ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
469 Poll::Ready(Ok(()))
470 }
471
472 fn call(&mut self, dst: Uri) -> Self::Future {
473 let mut self_ = self.clone();
474 HttpConnecting {
475 fut: Box::pin(async move { self_.call_async(dst).await }),
476 _marker: PhantomData,
477 }
478 }
479}
480
481fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
482 trace!(
483 "Http::connect; scheme={:?}, host={:?}, port={:?}",
484 dst.scheme(),
485 dst.host(),
486 dst.port(),
487 );
488
489 if config.enforce_http {
490 if dst.scheme() != Some(&Scheme::HTTP) {
491 return Err(ConnectError {
492 msg: INVALID_NOT_HTTP,
493 addr: None,
494 cause: None,
495 });
496 }
497 } else if dst.scheme().is_none() {
498 return Err(ConnectError {
499 msg: INVALID_MISSING_SCHEME,
500 addr: None,
501 cause: None,
502 });
503 }
504
505 let host = match dst.host() {
506 Some(s) => s,
507 None => {
508 return Err(ConnectError {
509 msg: INVALID_MISSING_HOST,
510 addr: None,
511 cause: None,
512 });
513 }
514 };
515 let port = match dst.port() {
516 Some(port) => port.as_u16(),
517 None => {
518 if dst.scheme() == Some(&Scheme::HTTPS) {
519 443
520 } else {
521 80
522 }
523 }
524 };
525
526 Ok((host, port))
527}
528
529impl<R> HttpConnector<R>
530where
531 R: Resolve,
532{
533 async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
534 let config = &self.config;
535
536 let (host, port) = get_host_port(config, &dst)?;
537 let host = host.trim_start_matches('[').trim_end_matches(']');
538
539 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
542 addrs
543 } else {
544 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
545 .await
546 .map_err(ConnectError::dns)?;
547 let addrs = addrs
548 .map(|mut addr| {
549 set_port(&mut addr, port, dst.port().is_some());
550
551 addr
552 })
553 .collect();
554 dns::SocketAddrs::new(addrs)
555 };
556
557 let c = ConnectingTcp::new(addrs, config);
558
559 let sock = c.connect().await?;
560
561 if let Err(e) = sock.set_nodelay(config.nodelay) {
562 warn!("tcp set_nodelay error: {}", e);
563 }
564
565 Ok(TokioIo::new(sock))
566 }
567}
568
569impl Connection for TcpStream {
570 fn connected(&self) -> Connected {
571 let connected = Connected::new();
572 if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
573 connected.extra(HttpInfo {
574 remote_addr,
575 local_addr,
576 })
577 } else {
578 connected
579 }
580 }
581}
582
583#[cfg(unix)]
584impl Connection for tokio::net::UnixStream {
585 fn connected(&self) -> Connected {
586 Connected::new()
587 }
588}
589
// Windows named pipes carry no extra connection metadata.
#[cfg(windows)]
impl Connection for tokio::net::windows::named_pipe::NamedPipeClient {
    /// Returns a plain `Connected` value with nothing attached.
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
596
597impl<T> Connection for TokioIo<T>
600where
601 T: Connection,
602{
603 fn connected(&self) -> Connected {
604 self.inner().connected()
605 }
606}
607
608impl HttpInfo {
609 pub fn remote_addr(&self) -> SocketAddr {
611 self.remote_addr
612 }
613
614 pub fn local_addr(&self) -> SocketAddr {
616 self.local_addr
617 }
618}
619
620pin_project! {
621 #[must_use = "futures do nothing unless polled"]
627 #[allow(missing_debug_implementations)]
628 pub struct HttpConnecting<R> {
629 #[pin]
630 fut: BoxConnecting,
631 _marker: PhantomData<R>,
632 }
633}
634
635type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
636type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
637
638impl<R: Resolve> Future for HttpConnecting<R> {
639 type Output = ConnectResult;
640
641 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
642 self.project().fut.poll(cx)
643 }
644}
645
/// Error returned when establishing a connection fails.
pub struct ConnectError {
    /// Short static description of what went wrong.
    msg: &'static str,
    /// Address that was being connected to, when known.
    addr: Option<SocketAddr>,
    /// Underlying error, when there is one.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
652
653impl ConnectError {
654 fn new<E>(msg: &'static str, cause: E) -> ConnectError
655 where
656 E: Into<Box<dyn StdError + Send + Sync>>,
657 {
658 ConnectError {
659 msg,
660 addr: None,
661 cause: Some(cause.into()),
662 }
663 }
664
665 fn dns<E>(cause: E) -> ConnectError
666 where
667 E: Into<Box<dyn StdError + Send + Sync>>,
668 {
669 ConnectError::new("dns error", cause)
670 }
671
672 fn m<E>(msg: &'static str) -> impl FnOnce(E) -> ConnectError
673 where
674 E: Into<Box<dyn StdError + Send + Sync>>,
675 {
676 move |cause| ConnectError::new(msg, cause)
677 }
678}
679
680impl fmt::Debug for ConnectError {
681 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
682 let mut b = f.debug_tuple("ConnectError");
683 b.field(&self.msg);
684 if let Some(ref addr) = self.addr {
685 b.field(addr);
686 }
687 if let Some(ref cause) = self.cause {
688 b.field(cause);
689 }
690 b.finish()
691 }
692}
693
694impl fmt::Display for ConnectError {
695 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
696 f.write_str(self.msg)
697 }
698}
699
700impl StdError for ConnectError {
701 fn source(&self) -> Option<&(dyn StdError + 'static)> {
702 self.cause.as_ref().map(|e| &**e as _)
703 }
704}
705
706struct ConnectingTcp<'a> {
707 preferred: ConnectingTcpRemote,
708 fallback: Option<ConnectingTcpFallback>,
709 config: &'a Config,
710}
711
712impl<'a> ConnectingTcp<'a> {
713 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
714 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
715 let (preferred_addrs, fallback_addrs) = remote_addrs
716 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
717 if fallback_addrs.is_empty() {
718 return ConnectingTcp {
719 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
720 fallback: None,
721 config,
722 };
723 }
724
725 ConnectingTcp {
726 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
727 fallback: Some(ConnectingTcpFallback {
728 delay: tokio::time::sleep(fallback_timeout),
729 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
730 }),
731 config,
732 }
733 } else {
734 ConnectingTcp {
735 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
736 fallback: None,
737 config,
738 }
739 }
740 }
741}
742
743struct ConnectingTcpFallback {
744 delay: Sleep,
745 remote: ConnectingTcpRemote,
746}
747
748struct ConnectingTcpRemote {
749 addrs: dns::SocketAddrs,
750 connect_timeout: Option<Duration>,
751}
752
753impl ConnectingTcpRemote {
754 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
755 let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
756
757 Self {
758 addrs,
759 connect_timeout,
760 }
761 }
762}
763
764impl ConnectingTcpRemote {
765 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
766 let mut err = None;
767 for addr in &mut self.addrs {
768 debug!("connecting to {}", addr);
769 match connect(&addr, config, self.connect_timeout)?.await {
770 Ok(tcp) => {
771 debug!("connected to {}", addr);
772 return Ok(tcp);
773 }
774 Err(mut e) => {
775 trace!("connect error for {}: {:?}", addr, e);
776 e.addr = Some(addr);
777 if err.is_none() {
779 err = Some(e);
780 }
781 }
782 }
783 }
784
785 match err {
786 Some(e) => Err(e),
787 None => Err(ConnectError::new(
788 "tcp connect error",
789 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
790 )),
791 }
792 }
793}
794
795fn bind_local_address(
796 socket: &socket2::Socket,
797 dst_addr: &SocketAddr,
798 local_addr_ipv4: &Option<Ipv4Addr>,
799 local_addr_ipv6: &Option<Ipv6Addr>,
800) -> io::Result<()> {
801 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
802 (SocketAddr::V4(_), Some(addr), _) => {
803 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
804 }
805 (SocketAddr::V6(_), _, Some(addr)) => {
806 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
807 }
808 _ => {
809 if cfg!(windows) {
810 let any: SocketAddr = match *dst_addr {
812 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
813 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
814 };
815 socket.bind(&any.into())?;
816 }
817 }
818 }
819
820 Ok(())
821}
822
823fn connect(
824 addr: &SocketAddr,
825 config: &Config,
826 connect_timeout: Option<Duration>,
827) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
828 use socket2::{Domain, Protocol, Socket, Type};
832
833 let domain = Domain::for_address(*addr);
834 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
835 .map_err(ConnectError::m("tcp open error"))?;
836
837 socket
840 .set_nonblocking(true)
841 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
842
843 if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
844 if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
845 warn!("tcp set_keepalive error: {}", e);
846 }
847 }
848
849 #[cfg(any(
851 target_os = "android",
852 target_os = "fuchsia",
853 target_os = "illumos",
854 target_os = "ios",
855 target_os = "linux",
856 target_os = "macos",
857 target_os = "solaris",
858 target_os = "tvos",
859 target_os = "visionos",
860 target_os = "watchos",
861 ))]
862 if let Some(interface) = &config.interface {
863 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
866 socket
867 .bind_device(Some(interface.as_bytes()))
868 .map_err(ConnectError::m("tcp bind interface error"))?;
869
870 #[cfg(any(
875 target_os = "illumos",
876 target_os = "ios",
877 target_os = "macos",
878 target_os = "solaris",
879 target_os = "tvos",
880 target_os = "visionos",
881 target_os = "watchos",
882 ))]
883 {
884 let idx = unsafe { libc::if_nametoindex(interface.as_ptr()) };
885 let idx = std::num::NonZeroU32::new(idx).ok_or_else(|| {
886 ConnectError::new(
888 "error converting interface name to index",
889 io::Error::last_os_error(),
890 )
891 })?;
892 match addr {
895 SocketAddr::V4(_) => socket.bind_device_by_index_v4(Some(idx)),
896 SocketAddr::V6(_) => socket.bind_device_by_index_v6(Some(idx)),
897 }
898 .map_err(ConnectError::m("tcp bind interface error"))?;
899 }
900 }
901
902 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
903 if let Some(tcp_user_timeout) = &config.tcp_user_timeout {
904 if let Err(e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) {
905 warn!("tcp set_tcp_user_timeout error: {}", e);
906 }
907 }
908
909 bind_local_address(
910 &socket,
911 addr,
912 &config.local_address_ipv4,
913 &config.local_address_ipv6,
914 )
915 .map_err(ConnectError::m("tcp bind local error"))?;
916
917 let socket = TcpSocket::from_std_stream(socket.into());
919
920 if config.reuse_address {
921 if let Err(e) = socket.set_reuseaddr(true) {
922 warn!("tcp set_reuse_address error: {}", e);
923 }
924 }
925
926 if let Some(size) = config.send_buffer_size {
927 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
928 warn!("tcp set_buffer_size error: {}", e);
929 }
930 }
931
932 if let Some(size) = config.recv_buffer_size {
933 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
934 warn!("tcp set_recv_buffer_size error: {}", e);
935 }
936 }
937
938 let connect = socket.connect(*addr);
939 Ok(async move {
940 match connect_timeout {
941 Some(dur) => match tokio::time::timeout(dur, connect).await {
942 Ok(Ok(s)) => Ok(s),
943 Ok(Err(e)) => Err(e),
944 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
945 },
946 None => connect.await,
947 }
948 .map_err(ConnectError::m("tcp connect error"))
949 })
950}
951
952impl ConnectingTcp<'_> {
953 async fn connect(mut self) -> Result<TcpStream, ConnectError> {
954 match self.fallback {
955 None => self.preferred.connect(self.config).await,
956 Some(mut fallback) => {
957 let preferred_fut = self.preferred.connect(self.config);
958 futures_util::pin_mut!(preferred_fut);
959
960 let fallback_fut = fallback.remote.connect(self.config);
961 futures_util::pin_mut!(fallback_fut);
962
963 let fallback_delay = fallback.delay;
964 futures_util::pin_mut!(fallback_delay);
965
966 let (result, future) =
967 match futures_util::future::select(preferred_fut, fallback_delay).await {
968 Either::Left((result, _fallback_delay)) => {
969 (result, Either::Right(fallback_fut))
970 }
971 Either::Right(((), preferred_fut)) => {
972 futures_util::future::select(preferred_fut, fallback_fut)
974 .await
975 .factor_first()
976 }
977 };
978
979 if result.is_err() {
980 future.await
983 } else {
984 result
985 }
986 }
987 }
988 }
989}
990
/// Overwrites the port of `addr` with `host_port` when the port was given
/// explicitly (`explicit`) or when `addr` has no port yet (port 0).
///
/// Otherwise the port already present in `addr` is kept.
fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) {
    let port_unset = addr.port() == 0;
    if !(explicit || port_unset) {
        return;
    }
    addr.set_port(host_port);
}
999
1000#[cfg(test)]
1001mod tests {
1002 use std::io;
1003 use std::net::SocketAddr;
1004
1005 use ::http::Uri;
1006
1007 use crate::client::legacy::connect::http::TcpKeepaliveConfig;
1008
1009 use super::super::sealed::{Connect, ConnectSvc};
1010 use super::{Config, ConnectError, HttpConnector};
1011
1012 use super::set_port;
1013
1014 async fn connect<C>(
1015 connector: C,
1016 dst: Uri,
1017 ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
1018 where
1019 C: Connect,
1020 {
1021 connector.connect(super::super::sealed::Internal, dst).await
1022 }
1023
1024 #[tokio::test]
1025 async fn test_errors_enforce_http() {
1026 let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
1027 let connector = HttpConnector::new();
1028
1029 let err = connect(connector, dst).await.unwrap_err();
1030 assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
1031 }
1032
1033 #[cfg(any(target_os = "linux", target_os = "macos"))]
1034 fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
1035 use std::net::{IpAddr, TcpListener};
1036
1037 let mut ip_v4 = None;
1038 let mut ip_v6 = None;
1039
1040 let ips = pnet_datalink::interfaces()
1041 .into_iter()
1042 .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
1043
1044 for ip in ips {
1045 match ip {
1046 IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
1047 IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
1048 _ => (),
1049 }
1050
1051 if ip_v4.is_some() && ip_v6.is_some() {
1052 break;
1053 }
1054 }
1055
1056 (ip_v4, ip_v6)
1057 }
1058
1059 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
1060 fn default_interface() -> Option<String> {
1061 pnet_datalink::interfaces()
1062 .iter()
1063 .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
1064 .map(|e| e.name.clone())
1065 }
1066
1067 #[tokio::test]
1068 async fn test_errors_missing_scheme() {
1069 let dst = "example.domain".parse().unwrap();
1070 let mut connector = HttpConnector::new();
1071 connector.enforce_http(false);
1072
1073 let err = connect(connector, dst).await.unwrap_err();
1074 assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
1075 }
1076
1077 #[cfg(any(target_os = "linux", target_os = "macos"))]
1079 #[cfg_attr(miri, ignore)]
1080 #[tokio::test]
1081 async fn local_address() {
1082 use std::net::{IpAddr, TcpListener};
1083
1084 let (bind_ip_v4, bind_ip_v6) = get_local_ips();
1085 let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
1086 let port = server4.local_addr().unwrap().port();
1087 let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap();
1088
1089 let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
1090 let mut connector = HttpConnector::new();
1091
1092 match (bind_ip_v4, bind_ip_v6) {
1093 (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
1094 (Some(v4), None) => connector.set_local_address(Some(v4.into())),
1095 (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
1096 _ => unreachable!(),
1097 }
1098
1099 connect(connector, dst.parse().unwrap()).await.unwrap();
1100
1101 let (_, client_addr) = server.accept().unwrap();
1102
1103 assert_eq!(client_addr.ip(), expected_ip);
1104 };
1105
1106 if let Some(ip) = bind_ip_v4 {
1107 assert_client_ip(format!("http://127.0.0.1:{port}"), server4, ip.into()).await;
1108 }
1109
1110 if let Some(ip) = bind_ip_v6 {
1111 assert_client_ip(format!("http://[::1]:{port}"), server6, ip.into()).await;
1112 }
1113 }
1114
1115 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
1117 #[tokio::test]
1118 #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
1119 async fn interface() {
1120 use socket2::{Domain, Protocol, Socket, Type};
1121 use std::net::TcpListener;
1122
1123 let interface: Option<String> = default_interface();
1124
1125 let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
1126 let port = server4.local_addr().unwrap().port();
1127
1128 let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap();
1129
1130 let assert_interface_name =
1131 |dst: String,
1132 server: TcpListener,
1133 bind_iface: Option<String>,
1134 expected_interface: Option<String>| async move {
1135 let mut connector = HttpConnector::new();
1136 if let Some(iface) = bind_iface {
1137 connector.set_interface(iface);
1138 }
1139
1140 connect(connector, dst.parse().unwrap()).await.unwrap();
1141 let domain = Domain::for_address(server.local_addr().unwrap());
1142 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();
1143
1144 assert_eq!(
1145 socket.device().unwrap().as_deref(),
1146 expected_interface.as_deref().map(|val| val.as_bytes())
1147 );
1148 };
1149
1150 assert_interface_name(
1151 format!("http://127.0.0.1:{port}"),
1152 server4,
1153 interface.clone(),
1154 interface.clone(),
1155 )
1156 .await;
1157 assert_interface_name(
1158 format!("http://[::1]:{port}"),
1159 server6,
1160 interface.clone(),
1161 interface.clone(),
1162 )
1163 .await;
1164 }
1165
1166 #[test]
1167 #[ignore] #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
1169 fn client_happy_eyeballs() {
1170 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
1171 use std::time::{Duration, Instant};
1172
1173 use super::dns;
1174 use super::ConnectingTcp;
1175
1176 let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
1177 let addr = server4.local_addr().unwrap();
1178 let _server6 = TcpListener::bind(format!("[::1]:{}", addr.port())).unwrap();
1179 let rt = tokio::runtime::Builder::new_current_thread()
1180 .enable_all()
1181 .build()
1182 .unwrap();
1183
1184 let local_timeout = Duration::default();
1185 let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
1186 let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
1187 let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
1188 + Duration::from_millis(250);
1189
1190 let scenarios = &[
1191 (&[local_ipv4_addr()][..], 4, local_timeout, false),
1193 (&[local_ipv6_addr()][..], 6, local_timeout, false),
1194 (
1196 &[local_ipv4_addr(), local_ipv6_addr()][..],
1197 4,
1198 local_timeout,
1199 false,
1200 ),
1201 (
1202 &[local_ipv6_addr(), local_ipv4_addr()][..],
1203 6,
1204 local_timeout,
1205 false,
1206 ),
1207 (
1209 &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
1210 4,
1211 unreachable_v4_timeout,
1212 false,
1213 ),
1214 (
1215 &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
1216 6,
1217 unreachable_v6_timeout,
1218 false,
1219 ),
1220 (
1222 &[
1223 unreachable_ipv4_addr(),
1224 local_ipv4_addr(),
1225 local_ipv6_addr(),
1226 ][..],
1227 4,
1228 unreachable_v4_timeout,
1229 false,
1230 ),
1231 (
1232 &[
1233 unreachable_ipv6_addr(),
1234 local_ipv6_addr(),
1235 local_ipv4_addr(),
1236 ][..],
1237 6,
1238 unreachable_v6_timeout,
1239 true,
1240 ),
1241 (
1243 &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
1244 6,
1245 fallback_timeout,
1246 false,
1247 ),
1248 (
1249 &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
1250 4,
1251 fallback_timeout,
1252 true,
1253 ),
1254 (
1256 &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
1257 6,
1258 fallback_timeout + unreachable_v6_timeout,
1259 false,
1260 ),
1261 (
1262 &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
1263 4,
1264 fallback_timeout + unreachable_v4_timeout,
1265 true,
1266 ),
1267 ];
1268
1269 let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;
1272
1273 for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
1274 if needs_ipv6_access && !ipv6_accessible {
1275 continue;
1276 }
1277
1278 let (start, stream) = rt
1279 .block_on(async move {
1280 let addrs = hosts
1281 .iter()
1282 .map(|host| (*host, addr.port()).into())
1283 .collect();
1284 let cfg = Config {
1285 local_address_ipv4: None,
1286 local_address_ipv6: None,
1287 connect_timeout: None,
1288 tcp_keepalive_config: TcpKeepaliveConfig::default(),
1289 happy_eyeballs_timeout: Some(fallback_timeout),
1290 nodelay: false,
1291 reuse_address: false,
1292 enforce_http: false,
1293 send_buffer_size: None,
1294 recv_buffer_size: None,
1295 #[cfg(any(
1296 target_os = "android",
1297 target_os = "fuchsia",
1298 target_os = "linux"
1299 ))]
1300 interface: None,
1301 #[cfg(any(
1302 target_os = "illumos",
1303 target_os = "ios",
1304 target_os = "macos",
1305 target_os = "solaris",
1306 target_os = "tvos",
1307 target_os = "visionos",
1308 target_os = "watchos",
1309 ))]
1310 interface: None,
1311 #[cfg(any(
1312 target_os = "android",
1313 target_os = "fuchsia",
1314 target_os = "linux"
1315 ))]
1316 tcp_user_timeout: None,
1317 };
1318 let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
1319 let start = Instant::now();
1320 Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
1321 })
1322 .unwrap();
1323 let res = if stream.peer_addr().unwrap().is_ipv4() {
1324 4
1325 } else {
1326 6
1327 };
1328 let duration = start.elapsed();
1329
1330 let min_duration = if timeout >= Duration::from_millis(150) {
1332 timeout - Duration::from_millis(150)
1333 } else {
1334 Duration::default()
1335 };
1336 let max_duration = timeout + Duration::from_millis(150);
1337
1338 assert_eq!(res, family);
1339 assert!(duration >= min_duration);
1340 assert!(duration <= max_duration);
1341 }
1342
1343 fn local_ipv4_addr() -> IpAddr {
1344 Ipv4Addr::new(127, 0, 0, 1).into()
1345 }
1346
1347 fn local_ipv6_addr() -> IpAddr {
1348 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
1349 }
1350
1351 fn unreachable_ipv4_addr() -> IpAddr {
1352 Ipv4Addr::new(127, 0, 0, 2).into()
1353 }
1354
1355 fn unreachable_ipv6_addr() -> IpAddr {
1356 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
1357 }
1358
1359 fn slow_ipv4_addr() -> IpAddr {
1360 Ipv4Addr::new(198, 18, 0, 25).into()
1362 }
1363
1364 fn slow_ipv6_addr() -> IpAddr {
1365 Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
1367 }
1368
1369 fn measure_connect(addr: IpAddr) -> (bool, Duration) {
1370 let start = Instant::now();
1371 let result =
1372 std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
1373
1374 let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
1375 let duration = start.elapsed();
1376 (reachable, duration)
1377 }
1378 }
1379
1380 use std::time::Duration;
1381
1382 #[test]
1383 fn no_tcp_keepalive_config() {
1384 assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
1385 }
1386
1387 #[test]
1388 fn tcp_keepalive_time_config() {
1389 let kac = TcpKeepaliveConfig {
1390 time: Some(Duration::from_secs(60)),
1391 ..Default::default()
1392 };
1393 if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1394 assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
1395 } else {
1396 panic!("test failed");
1397 }
1398 }
1399
1400 #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
1401 #[test]
1402 fn tcp_keepalive_interval_config() {
1403 let kac = TcpKeepaliveConfig {
1404 interval: Some(Duration::from_secs(1)),
1405 ..Default::default()
1406 };
1407 if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1408 assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
1409 } else {
1410 panic!("test failed");
1411 }
1412 }
1413
1414 #[cfg(not(any(
1415 target_os = "openbsd",
1416 target_os = "redox",
1417 target_os = "solaris",
1418 target_os = "windows"
1419 )))]
1420 #[test]
1421 fn tcp_keepalive_retries_config() {
1422 let kac = TcpKeepaliveConfig {
1423 retries: Some(3),
1424 ..Default::default()
1425 };
1426 if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1427 assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
1428 } else {
1429 panic!("test failed");
1430 }
1431 }
1432
1433 #[test]
1434 fn test_set_port() {
1435 let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
1437 set_port(&mut addr, 42, true);
1438 assert_eq!(addr.port(), 42);
1439
1440 let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
1442 set_port(&mut addr, 443, false);
1443 assert_eq!(addr.port(), 6881);
1444
1445 let mut addr = SocketAddr::from(([0, 0, 0, 0], 0));
1447 set_port(&mut addr, 443, false);
1448 assert_eq!(addr.port(), 443);
1449 }
1450}