1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, ready, Poll};
10use std::time::Duration;
11
12use futures_util::future::Either;
13use http::uri::{Scheme, Uri};
14use pin_project_lite::pin_project;
15use socket2::TcpKeepalive;
16use tokio::net::{TcpSocket, TcpStream};
17use tokio::time::Sleep;
18use tracing::{debug, trace, warn};
19
20use super::dns::{self, resolve, GaiResolver, Resolve};
21use super::{Connected, Connection};
22use crate::rt::TokioIo;
23
/// A connector for the `http` scheme (and, when `enforce_http` is disabled,
/// any scheme) that resolves the destination host with a [`Resolve`]
/// implementation and opens a plain TCP connection.
///
/// The configuration lives behind an `Arc`, so cloning a connector shares it;
/// setters go through `Arc::make_mut` and therefore only affect this clone.
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Shared socket/connect settings; copied on write by the setters.
    config: Arc<Config>,
    // Name resolver used when the URI host is not an IP literal.
    resolver: R,
}
37
/// Address information attached to a TCP connection's [`Connected`] metadata.
///
/// Produced by the `Connection` impl for `TcpStream` when both the peer and
/// local address can be queried from the socket.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Address of the remote peer (`TcpStream::peer_addr`).
    remote_addr: SocketAddr,
    // Local address of this end of the connection (`TcpStream::local_addr`).
    local_addr: SocketAddr,
}
66
/// Settings applied by [`HttpConnector`] to every outgoing connection.
#[derive(Clone)]
struct Config {
    // Total connect budget; divided across the addresses of each attempt group.
    connect_timeout: Option<Duration>,
    // When true, only `http://` URIs are accepted.
    enforce_http: bool,
    // Delay before the fallback address group starts racing; `None` disables the split.
    happy_eyeballs_timeout: Option<Duration>,
    // Keepalive parameters; only applied if at least one takes effect.
    tcp_keepalive_config: TcpKeepaliveConfig,
    // Local address to bind when connecting to an IPv4 destination.
    local_address_ipv4: Option<Ipv4Addr>,
    // Local address to bind when connecting to an IPv6 destination.
    local_address_ipv6: Option<Ipv6Addr>,
    // Value applied with `set_nodelay` after the connection is established.
    nodelay: bool,
    // Whether to enable `SO_REUSEADDR` on the socket.
    reuse_address: bool,
    // Optional `SO_SNDBUF` size (clamped to `u32::MAX` when applied).
    send_buffer_size: Option<usize>,
    // Optional `SO_RCVBUF` size (clamped to `u32::MAX` when applied).
    recv_buffer_size: Option<usize>,
    // Interface name passed as raw bytes to `bind_device`.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    interface: Option<String>,
    // Interface name as a C string, resolved with `if_nametoindex` at connect time.
    #[cfg(any(
        target_os = "illumos",
        target_os = "ios",
        target_os = "macos",
        target_os = "solaris",
        target_os = "tvos",
        target_os = "visionos",
        target_os = "watchos",
    ))]
    interface: Option<std::ffi::CString>,
    // Value for `set_tcp_user_timeout`; failures to apply it are only logged.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    tcp_user_timeout: Option<Duration>,
}
94
/// User-requested TCP keepalive parameters, converted to a socket2
/// `TcpKeepalive` by `into_tcpkeepalive`. All fields default to `None`.
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    // Idle time before the first keepalive probe.
    time: Option<Duration>,
    // Interval between probes (ignored on platforms without support).
    interval: Option<Duration>,
    // Number of probes before giving up (ignored on platforms without support).
    retries: Option<u32>,
}
101
102impl TcpKeepaliveConfig {
103 fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
105 let mut dirty = false;
106 let mut ka = TcpKeepalive::new();
107 if let Some(time) = self.time {
108 ka = ka.with_time(time);
109 dirty = true
110 }
111 if let Some(interval) = self.interval {
112 ka = Self::ka_with_interval(ka, interval, &mut dirty)
113 };
114 if let Some(retries) = self.retries {
115 ka = Self::ka_with_retries(ka, retries, &mut dirty)
116 };
117 if dirty {
118 Some(ka)
119 } else {
120 None
121 }
122 }
123
124 #[cfg(
125 any(
127 target_os = "android",
128 target_os = "dragonfly",
129 target_os = "freebsd",
130 target_os = "fuchsia",
131 target_os = "illumos",
132 target_os = "ios",
133 target_os = "visionos",
134 target_os = "linux",
135 target_os = "macos",
136 target_os = "netbsd",
137 target_os = "tvos",
138 target_os = "watchos",
139 target_os = "windows",
140 )
141 )]
142 fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
143 *dirty = true;
144 ka.with_interval(interval)
145 }
146
147 #[cfg(not(
148 any(
150 target_os = "android",
151 target_os = "dragonfly",
152 target_os = "freebsd",
153 target_os = "fuchsia",
154 target_os = "illumos",
155 target_os = "ios",
156 target_os = "visionos",
157 target_os = "linux",
158 target_os = "macos",
159 target_os = "netbsd",
160 target_os = "tvos",
161 target_os = "watchos",
162 target_os = "windows",
163 )
164 ))]
165 fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
166 ka }
168
169 #[cfg(
170 any(
172 target_os = "android",
173 target_os = "dragonfly",
174 target_os = "freebsd",
175 target_os = "fuchsia",
176 target_os = "illumos",
177 target_os = "ios",
178 target_os = "visionos",
179 target_os = "linux",
180 target_os = "macos",
181 target_os = "netbsd",
182 target_os = "tvos",
183 target_os = "watchos",
184 )
185 )]
186 fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
187 *dirty = true;
188 ka.with_retries(retries)
189 }
190
191 #[cfg(not(
192 any(
194 target_os = "android",
195 target_os = "dragonfly",
196 target_os = "freebsd",
197 target_os = "fuchsia",
198 target_os = "illumos",
199 target_os = "ios",
200 target_os = "visionos",
201 target_os = "linux",
202 target_os = "macos",
203 target_os = "netbsd",
204 target_os = "tvos",
205 target_os = "watchos",
206 )
207 ))]
208 fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
209 ka }
211}
212
213impl HttpConnector {
216 pub fn new() -> HttpConnector {
218 HttpConnector::new_with_resolver(GaiResolver::new())
219 }
220}
221
222impl<R> HttpConnector<R> {
223 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
227 HttpConnector {
228 config: Arc::new(Config {
229 connect_timeout: None,
230 enforce_http: true,
231 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
232 tcp_keepalive_config: TcpKeepaliveConfig::default(),
233 local_address_ipv4: None,
234 local_address_ipv6: None,
235 nodelay: false,
236 reuse_address: false,
237 send_buffer_size: None,
238 recv_buffer_size: None,
239 #[cfg(any(
240 target_os = "android",
241 target_os = "fuchsia",
242 target_os = "illumos",
243 target_os = "ios",
244 target_os = "linux",
245 target_os = "macos",
246 target_os = "solaris",
247 target_os = "tvos",
248 target_os = "visionos",
249 target_os = "watchos",
250 ))]
251 interface: None,
252 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
253 tcp_user_timeout: None,
254 }),
255 resolver,
256 }
257 }
258
259 #[inline]
263 pub fn enforce_http(&mut self, is_enforced: bool) {
264 self.config_mut().enforce_http = is_enforced;
265 }
266
267 #[inline]
274 pub fn set_keepalive(&mut self, time: Option<Duration>) {
275 self.config_mut().tcp_keepalive_config.time = time;
276 }
277
278 #[inline]
281 pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
282 self.config_mut().tcp_keepalive_config.interval = interval;
283 }
284
285 #[inline]
287 pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
288 self.config_mut().tcp_keepalive_config.retries = retries;
289 }
290
291 #[inline]
295 pub fn set_nodelay(&mut self, nodelay: bool) {
296 self.config_mut().nodelay = nodelay;
297 }
298
299 #[inline]
301 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
302 self.config_mut().send_buffer_size = size;
303 }
304
305 #[inline]
307 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
308 self.config_mut().recv_buffer_size = size;
309 }
310
311 #[inline]
317 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
318 let (v4, v6) = match addr {
319 Some(IpAddr::V4(a)) => (Some(a), None),
320 Some(IpAddr::V6(a)) => (None, Some(a)),
321 _ => (None, None),
322 };
323
324 let cfg = self.config_mut();
325
326 cfg.local_address_ipv4 = v4;
327 cfg.local_address_ipv6 = v6;
328 }
329
330 #[inline]
333 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
334 let cfg = self.config_mut();
335
336 cfg.local_address_ipv4 = Some(addr_ipv4);
337 cfg.local_address_ipv6 = Some(addr_ipv6);
338 }
339
340 #[inline]
347 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
348 self.config_mut().connect_timeout = dur;
349 }
350
351 #[inline]
364 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
365 self.config_mut().happy_eyeballs_timeout = dur;
366 }
367
368 #[inline]
372 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
373 self.config_mut().reuse_address = reuse_address;
374 self
375 }
376
377 #[cfg(any(
402 target_os = "android",
403 target_os = "fuchsia",
404 target_os = "illumos",
405 target_os = "ios",
406 target_os = "linux",
407 target_os = "macos",
408 target_os = "solaris",
409 target_os = "tvos",
410 target_os = "visionos",
411 target_os = "watchos",
412 ))]
413 #[inline]
414 pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
415 let interface = interface.into();
416 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
417 {
418 self.config_mut().interface = Some(interface);
419 }
420 #[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))]
421 {
422 let interface = std::ffi::CString::new(interface)
423 .expect("interface name should not have nulls in it");
424 self.config_mut().interface = Some(interface);
425 }
426 self
427 }
428
429 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
431 #[inline]
432 pub fn set_tcp_user_timeout(&mut self, time: Option<Duration>) {
433 self.config_mut().tcp_user_timeout = time;
434 }
435
436 fn config_mut(&mut self) -> &mut Config {
439 Arc::make_mut(&mut self.config)
443 }
444}
445
// Static messages for URI validation failures reported by `get_host_port`.
// The tests compare `ConnectError::msg` against these values directly.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
449
450impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
452 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453 f.debug_struct("HttpConnector").finish()
454 }
455}
456
457impl<R> tower_service::Service<Uri> for HttpConnector<R>
458where
459 R: Resolve + Clone + Send + Sync + 'static,
460 R::Future: Send,
461{
462 type Response = TokioIo<TcpStream>;
463 type Error = ConnectError;
464 type Future = HttpConnecting<R>;
465
466 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
467 ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
468 Poll::Ready(Ok(()))
469 }
470
471 fn call(&mut self, dst: Uri) -> Self::Future {
472 let mut self_ = self.clone();
473 HttpConnecting {
474 fut: Box::pin(async move { self_.call_async(dst).await }),
475 _marker: PhantomData,
476 }
477 }
478}
479
480fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
481 trace!(
482 "Http::connect; scheme={:?}, host={:?}, port={:?}",
483 dst.scheme(),
484 dst.host(),
485 dst.port(),
486 );
487
488 if config.enforce_http {
489 if dst.scheme() != Some(&Scheme::HTTP) {
490 return Err(ConnectError {
491 msg: INVALID_NOT_HTTP,
492 addr: None,
493 cause: None,
494 });
495 }
496 } else if dst.scheme().is_none() {
497 return Err(ConnectError {
498 msg: INVALID_MISSING_SCHEME,
499 addr: None,
500 cause: None,
501 });
502 }
503
504 let host = match dst.host() {
505 Some(s) => s,
506 None => {
507 return Err(ConnectError {
508 msg: INVALID_MISSING_HOST,
509 addr: None,
510 cause: None,
511 });
512 }
513 };
514 let port = match dst.port() {
515 Some(port) => port.as_u16(),
516 None => {
517 if dst.scheme() == Some(&Scheme::HTTPS) {
518 443
519 } else {
520 80
521 }
522 }
523 };
524
525 Ok((host, port))
526}
527
528impl<R> HttpConnector<R>
529where
530 R: Resolve,
531{
532 async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
533 let config = &self.config;
534
535 let (host, port) = get_host_port(config, &dst)?;
536 let host = host.trim_start_matches('[').trim_end_matches(']');
537
538 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
541 addrs
542 } else {
543 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
544 .await
545 .map_err(ConnectError::dns)?;
546 let addrs = addrs
547 .map(|mut addr| {
548 set_port(&mut addr, port, dst.port().is_some());
549
550 addr
551 })
552 .collect();
553 dns::SocketAddrs::new(addrs)
554 };
555
556 let c = ConnectingTcp::new(addrs, config);
557
558 let sock = c.connect().await?;
559
560 if let Err(e) = sock.set_nodelay(config.nodelay) {
561 warn!("tcp set_nodelay error: {}", e);
562 }
563
564 Ok(TokioIo::new(sock))
565 }
566}
567
568impl Connection for TcpStream {
569 fn connected(&self) -> Connected {
570 let connected = Connected::new();
571 if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
572 connected.extra(HttpInfo {
573 remote_addr,
574 local_addr,
575 })
576 } else {
577 connected
578 }
579 }
580}
581
/// Unix-domain streams report only the default [`Connected`] metadata (no [`HttpInfo`]).
#[cfg(unix)]
impl Connection for tokio::net::UnixStream {
    fn connected(&self) -> Connected {
        Connected::new()
    }
}

/// Windows named-pipe clients report only the default [`Connected`] metadata.
#[cfg(windows)]
impl Connection for tokio::net::windows::named_pipe::NamedPipeClient {
    fn connected(&self) -> Connected {
        Connected::new()
    }
}

/// A [`TokioIo`] wrapper reports whatever the wrapped I/O type reports.
impl<T> Connection for TokioIo<T>
where
    T: Connection,
{
    fn connected(&self) -> Connected {
        self.inner().connected()
    }
}
606
impl HttpInfo {
    /// Address of the remote peer of the connection.
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }

    /// Local address this connection was made from.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
}
618
pin_project! {
    /// Future returned by `HttpConnector`'s `Service::call`; it resolves to
    /// a connected TCP stream or a [`ConnectError`].
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // Boxed connection future produced by `call_async`.
        #[pin]
        fut: BoxConnecting,
        // Ties the future's type to the resolver type without storing one.
        _marker: PhantomData<R>,
    }
}
633
// Outcome of a connection attempt made by `HttpConnector`.
type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
// Type-erased, `Send` future producing a `ConnectResult`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
636
637impl<R: Resolve> Future for HttpConnecting<R> {
638 type Output = ConnectResult;
639
640 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
641 self.project().fut.poll(cx)
642 }
643}
644
/// Error returned when [`HttpConnector`] fails to produce a connection.
pub struct ConnectError {
    // Static, human-readable summary; this is all that `Display` prints.
    msg: &'static str,
    // Remote address of the failed attempt, when one was attempted.
    addr: Option<SocketAddr>,
    // Underlying error, exposed through `Error::source`.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
651
652impl ConnectError {
653 fn new<E>(msg: &'static str, cause: E) -> ConnectError
654 where
655 E: Into<Box<dyn StdError + Send + Sync>>,
656 {
657 ConnectError {
658 msg,
659 addr: None,
660 cause: Some(cause.into()),
661 }
662 }
663
664 fn dns<E>(cause: E) -> ConnectError
665 where
666 E: Into<Box<dyn StdError + Send + Sync>>,
667 {
668 ConnectError::new("dns error", cause)
669 }
670
671 fn m<E>(msg: &'static str) -> impl FnOnce(E) -> ConnectError
672 where
673 E: Into<Box<dyn StdError + Send + Sync>>,
674 {
675 move |cause| ConnectError::new(msg, cause)
676 }
677}
678
679impl fmt::Debug for ConnectError {
680 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681 let mut b = f.debug_tuple("ConnectError");
682 b.field(&self.msg);
683 if let Some(ref addr) = self.addr {
684 b.field(addr);
685 }
686 if let Some(ref cause) = self.cause {
687 b.field(cause);
688 }
689 b.finish()
690 }
691}
692
/// Displays only the static message; the address and cause are available via
/// `Debug` and `Error::source`. `write_str` ignores width/fill format options.
impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.msg)
    }
}
698
699impl StdError for ConnectError {
700 fn source(&self) -> Option<&(dyn StdError + 'static)> {
701 self.cause.as_ref().map(|e| &**e as _)
702 }
703}
704
/// One happy-eyeballs connection attempt: a preferred address group and an
/// optional, delayed fallback group, sharing the connector's config.
struct ConnectingTcp<'a> {
    // Addresses tried first.
    preferred: ConnectingTcpRemote,
    // Addresses raced against `preferred` after a delay; `None` when not needed.
    fallback: Option<ConnectingTcpFallback>,
    // Socket settings applied to every attempt.
    config: &'a Config,
}
710
711impl<'a> ConnectingTcp<'a> {
712 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
713 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
714 let (preferred_addrs, fallback_addrs) = remote_addrs
715 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
716 if fallback_addrs.is_empty() {
717 return ConnectingTcp {
718 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
719 fallback: None,
720 config,
721 };
722 }
723
724 ConnectingTcp {
725 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
726 fallback: Some(ConnectingTcpFallback {
727 delay: tokio::time::sleep(fallback_timeout),
728 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
729 }),
730 config,
731 }
732 } else {
733 ConnectingTcp {
734 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
735 fallback: None,
736 config,
737 }
738 }
739 }
740}
741
/// The fallback address group of a happy-eyeballs attempt.
struct ConnectingTcpFallback {
    // Must elapse before the fallback group starts racing the preferred one.
    delay: Sleep,
    // Addresses tried once the delay has elapsed.
    remote: ConnectingTcpRemote,
}
746
/// A group of addresses tried one after another.
struct ConnectingTcpRemote {
    // Remaining addresses; consumed as they are attempted.
    addrs: dns::SocketAddrs,
    // Per-address timeout: the total connect timeout divided by the address count.
    connect_timeout: Option<Duration>,
}
751
752impl ConnectingTcpRemote {
753 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
754 let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
755
756 Self {
757 addrs,
758 connect_timeout,
759 }
760 }
761}
762
763impl ConnectingTcpRemote {
764 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
765 let mut err = None;
766 for addr in &mut self.addrs {
767 debug!("connecting to {}", addr);
768 match connect(&addr, config, self.connect_timeout)?.await {
769 Ok(tcp) => {
770 debug!("connected to {}", addr);
771 return Ok(tcp);
772 }
773 Err(mut e) => {
774 trace!("connect error for {}: {:?}", addr, e);
775 e.addr = Some(addr);
776 if err.is_none() {
778 err = Some(e);
779 }
780 }
781 }
782 }
783
784 match err {
785 Some(e) => Err(e),
786 None => Err(ConnectError::new(
787 "tcp connect error",
788 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
789 )),
790 }
791 }
792}
793
794fn bind_local_address(
795 socket: &socket2::Socket,
796 dst_addr: &SocketAddr,
797 local_addr_ipv4: &Option<Ipv4Addr>,
798 local_addr_ipv6: &Option<Ipv6Addr>,
799) -> io::Result<()> {
800 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
801 (SocketAddr::V4(_), Some(addr), _) => {
802 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
803 }
804 (SocketAddr::V6(_), _, Some(addr)) => {
805 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
806 }
807 _ => {
808 if cfg!(windows) {
809 let any: SocketAddr = match *dst_addr {
811 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
812 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
813 };
814 socket.bind(&any.into())?;
815 }
816 }
817 }
818
819 Ok(())
820}
821
822fn connect(
823 addr: &SocketAddr,
824 config: &Config,
825 connect_timeout: Option<Duration>,
826) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
827 use socket2::{Domain, Protocol, Socket, Type};
831
832 let domain = Domain::for_address(*addr);
833 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
834 .map_err(ConnectError::m("tcp open error"))?;
835
836 socket
839 .set_nonblocking(true)
840 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
841
842 if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
843 if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
844 warn!("tcp set_keepalive error: {}", e);
845 }
846 }
847
848 #[cfg(any(
850 target_os = "android",
851 target_os = "fuchsia",
852 target_os = "illumos",
853 target_os = "ios",
854 target_os = "linux",
855 target_os = "macos",
856 target_os = "solaris",
857 target_os = "tvos",
858 target_os = "visionos",
859 target_os = "watchos",
860 ))]
861 if let Some(interface) = &config.interface {
862 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
865 socket
866 .bind_device(Some(interface.as_bytes()))
867 .map_err(ConnectError::m("tcp bind interface error"))?;
868
869 #[cfg(any(
874 target_os = "illumos",
875 target_os = "ios",
876 target_os = "macos",
877 target_os = "solaris",
878 target_os = "tvos",
879 target_os = "visionos",
880 target_os = "watchos",
881 ))]
882 {
883 let idx = unsafe { libc::if_nametoindex(interface.as_ptr()) };
884 let idx = std::num::NonZeroU32::new(idx).ok_or_else(|| {
885 ConnectError::new(
887 "error converting interface name to index",
888 io::Error::last_os_error(),
889 )
890 })?;
891 match addr {
894 SocketAddr::V4(_) => socket.bind_device_by_index_v4(Some(idx)),
895 SocketAddr::V6(_) => socket.bind_device_by_index_v6(Some(idx)),
896 }
897 .map_err(ConnectError::m("tcp bind interface error"))?;
898 }
899 }
900
901 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
902 if let Some(tcp_user_timeout) = &config.tcp_user_timeout {
903 if let Err(e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) {
904 warn!("tcp set_tcp_user_timeout error: {}", e);
905 }
906 }
907
908 bind_local_address(
909 &socket,
910 addr,
911 &config.local_address_ipv4,
912 &config.local_address_ipv6,
913 )
914 .map_err(ConnectError::m("tcp bind local error"))?;
915
916 let socket = TcpSocket::from_std_stream(socket.into());
918
919 if config.reuse_address {
920 if let Err(e) = socket.set_reuseaddr(true) {
921 warn!("tcp set_reuse_address error: {}", e);
922 }
923 }
924
925 if let Some(size) = config.send_buffer_size {
926 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
927 warn!("tcp set_buffer_size error: {}", e);
928 }
929 }
930
931 if let Some(size) = config.recv_buffer_size {
932 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
933 warn!("tcp set_recv_buffer_size error: {}", e);
934 }
935 }
936
937 let connect = socket.connect(*addr);
938 Ok(async move {
939 match connect_timeout {
940 Some(dur) => match tokio::time::timeout(dur, connect).await {
941 Ok(Ok(s)) => Ok(s),
942 Ok(Err(e)) => Err(e),
943 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
944 },
945 None => connect.await,
946 }
947 .map_err(ConnectError::m("tcp connect error"))
948 })
949}
950
impl ConnectingTcp<'_> {
    /// Runs the happy-eyeballs race and returns the first successful stream.
    ///
    /// Without a fallback group, this simply tries the preferred addresses
    /// in order. With one, the preferred group races the fallback delay:
    ///
    /// - The preferred group finishes before the delay elapses: its result
    ///   is kept, and the not-yet-polled fallback group is held in reserve.
    /// - The delay elapses first: both groups race, and `factor_first`
    ///   yields the first result together with the still-running other group.
    ///
    /// If the first result is an error, the remaining group is awaited and
    /// its result returned; otherwise the first result is returned.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                // Created now but not polled until the delay elapses or the preferred group fails.
                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay elapsed: race both groups from here on.
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Whichever group finished first failed; fall back to the other one.
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
989
/// Overwrites the port of a resolved address with the URI's port.
///
/// This happens when the URI named a port explicitly, or when the resolver
/// returned port 0. Otherwise the resolver-supplied port is kept.
fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) {
    let resolver_gave_none = addr.port() == 0;
    if resolver_gave_none || explicit {
        addr.set_port(host_port);
    }
}
998
#[cfg(test)]
mod tests {
    use std::io;
    use std::net::SocketAddr;

    use ::http::Uri;

    use crate::client::legacy::connect::http::TcpKeepaliveConfig;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    use super::set_port;

    /// Drives `connector` through the sealed `Connect` interface, as the client would.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    #[tokio::test]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    /// Finds one bindable IPv4 and one bindable IPv6 address on this host, if any.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    /// Name of the first interface that is up, not loopback, and has an address.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    fn default_interface() -> Option<String> {
        pnet_datalink::interfaces()
            .iter()
            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
            .map(|e| e.name.clone())
    }

    #[tokio::test]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap();

        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{port}"), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{port}"), server6, ip.into()).await;
        }
    }

    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[tokio::test]
    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
    async fn interface() {
        use socket2::SockRef;
        use std::net::TcpListener;

        let interface: Option<String> = default_interface();

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();

        let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap();

        let assert_interface_name =
            |dst: String,
             server: TcpListener,
             bind_iface: Option<String>,
             expected_interface: Option<String>| async move {
                let mut connector = HttpConnector::new();
                if let Some(iface) = bind_iface {
                    connector.set_interface(iface);
                }

                let conn = connect(connector, dst.parse().unwrap()).await.unwrap();
                // Inspect the socket the connector actually produced.
                // A freshly created socket would never carry the device binding.
                let socket = SockRef::from(conn.inner());

                assert_eq!(
                    socket.device().unwrap().as_deref(),
                    expected_interface.as_deref().map(|val| val.as_bytes())
                );

                // Keep the listener open until the connection has been checked.
                drop(server);
            };

        assert_interface_name(
            format!("http://127.0.0.1:{port}"),
            server4,
            interface.clone(),
            interface.clone(),
        )
        .await;
        assert_interface_name(
            format!("http://[::1]:{port}"),
            server6,
            interface.clone(),
            interface.clone(),
        )
        .await;
    }

    // Only gated on the internal feature.
    // A stray unconditional `#[ignore]` used to make it impossible to run even with the feature enabled.
    #[test]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        let scenarios = &[
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (*host, addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        tcp_keepalive_config: TcpKeepaliveConfig::default(),
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        interface: None,
                        #[cfg(any(
                            target_os = "illumos",
                            target_os = "ios",
                            target_os = "macos",
                            target_os = "solaris",
                            target_os = "tvos",
                            target_os = "visionos",
                            target_os = "watchos",
                        ))]
                        interface: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        tcp_user_timeout: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow 150ms of slack on either side of the expected duration.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        /// Returns whether `addr:80` looks reachable (connected or timed out),
        /// plus how long the probe took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }

    use std::time::Duration;

    #[test]
    fn no_tcp_keepalive_config() {
        assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
    }

    #[test]
    fn tcp_keepalive_time_config() {
        let kac = TcpKeepaliveConfig {
            time: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
        } else {
            panic!("test failed");
        }
    }

    #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let kac = TcpKeepaliveConfig {
            interval: Some(Duration::from_secs(1)),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
        } else {
            panic!("test failed");
        }
    }

    #[cfg(not(any(
        target_os = "openbsd",
        target_os = "redox",
        target_os = "solaris",
        target_os = "windows"
    )))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let kac = TcpKeepaliveConfig {
            retries: Some(3),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
        } else {
            panic!("test failed");
        }
    }

    #[test]
    fn test_set_port() {
        // An explicit URI port always wins.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
        set_port(&mut addr, 42, true);
        assert_eq!(addr.port(), 42);

        // Without an explicit port, a resolver-supplied port is kept.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
        set_port(&mut addr, 443, false);
        assert_eq!(addr.port(), 6881);

        // A resolver port of 0 is replaced by the default port.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 0));
        set_port(&mut addr, 443, false);
        assert_eq!(addr.port(), 443);
    }
}