1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, Poll};
10use std::time::Duration;
11
12use futures_util::future::Either;
13use http::uri::{Scheme, Uri};
14use pin_project_lite::pin_project;
15use socket2::TcpKeepalive;
16use tokio::net::{TcpSocket, TcpStream};
17use tokio::time::Sleep;
18use tracing::{debug, trace, warn};
19
20use super::dns::{self, resolve, GaiResolver, Resolve};
21use super::{Connected, Connection};
22use crate::rt::TokioIo;
23
/// A connector for the `http` scheme.
///
/// Resolves the destination host with the resolver `R` and opens a TCP
/// connection to it. By default only `http` URIs are accepted (see
/// `enforce_http`).
///
/// The configuration lives behind an `Arc`. Setters go through `config_mut`
/// (`Arc::make_mut`), so changing one clone's settings never affects other
/// clones.
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    config: Arc<Config>,
    resolver: R,
}
37
/// Extra information about the transport when an `HttpConnector` is used.
///
/// Attached to `Connected` via `extra` when both the peer and the local
/// socket address of the `TcpStream` can be read.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    remote_addr: SocketAddr,
    local_addr: SocketAddr,
}
66
/// Connector settings, shared via `Arc` between clones of an `HttpConnector`.
#[derive(Clone)]
struct Config {
    // Total connect timeout, divided across the addresses of each address
    // group that is tried.
    connect_timeout: Option<Duration>,
    // When true, only URIs whose scheme is `http` are accepted.
    enforce_http: bool,
    // Delay before the fallback address group is started. `None` disables
    // the race and tries all addresses sequentially.
    happy_eyeballs_timeout: Option<Duration>,
    // TCP keepalive options applied to every new socket.
    tcp_keepalive_config: TcpKeepaliveConfig,
    // Local address to bind when the destination is IPv4.
    local_address_ipv4: Option<Ipv4Addr>,
    // Local address to bind when the destination is IPv6.
    local_address_ipv6: Option<Ipv6Addr>,
    // Value passed to `TcpStream::set_nodelay` after connecting.
    nodelay: bool,
    // Whether SO_REUSEADDR is requested on new sockets.
    reuse_address: bool,
    // Requested socket send buffer size, if any.
    send_buffer_size: Option<usize>,
    // Requested socket receive buffer size, if any.
    recv_buffer_size: Option<usize>,
    // Network interface to bind sockets to (Android/Fuchsia/Linux only).
    interface: Option<String>,
}
81
/// User-supplied TCP keepalive settings. Each `None` leaves that option unset.
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    // Passed to `TcpKeepalive::with_time`.
    time: Option<Duration>,
    // Passed to `TcpKeepalive::with_interval` where the platform supports it.
    interval: Option<Duration>,
    // Passed to `TcpKeepalive::with_retries` where the platform supports it.
    retries: Option<u32>,
}
88
89impl TcpKeepaliveConfig {
90 fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
92 let mut dirty = false;
93 let mut ka = TcpKeepalive::new();
94 if let Some(time) = self.time {
95 ka = ka.with_time(time);
96 dirty = true
97 }
98 if let Some(interval) = self.interval {
99 ka = Self::ka_with_interval(ka, interval, &mut dirty)
100 };
101 if let Some(retries) = self.retries {
102 ka = Self::ka_with_retries(ka, retries, &mut dirty)
103 };
104 if dirty {
105 Some(ka)
106 } else {
107 None
108 }
109 }
110
111 #[cfg(not(any(
112 target_os = "aix",
113 target_os = "openbsd",
114 target_os = "redox",
115 target_os = "solaris"
116 )))]
117 fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
118 *dirty = true;
119 ka.with_interval(interval)
120 }
121
122 #[cfg(any(
123 target_os = "aix",
124 target_os = "openbsd",
125 target_os = "redox",
126 target_os = "solaris"
127 ))]
128 fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
129 ka }
131
132 #[cfg(not(any(
133 target_os = "aix",
134 target_os = "openbsd",
135 target_os = "redox",
136 target_os = "solaris",
137 target_os = "windows"
138 )))]
139 fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
140 *dirty = true;
141 ka.with_retries(retries)
142 }
143
144 #[cfg(any(
145 target_os = "aix",
146 target_os = "openbsd",
147 target_os = "redox",
148 target_os = "solaris",
149 target_os = "windows"
150 ))]
151 fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
152 ka }
154}
155
156impl HttpConnector {
159 pub fn new() -> HttpConnector {
161 HttpConnector::new_with_resolver(GaiResolver::new())
162 }
163}
164
165impl<R> HttpConnector<R> {
166 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
170 HttpConnector {
171 config: Arc::new(Config {
172 connect_timeout: None,
173 enforce_http: true,
174 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
175 tcp_keepalive_config: TcpKeepaliveConfig::default(),
176 local_address_ipv4: None,
177 local_address_ipv6: None,
178 nodelay: false,
179 reuse_address: false,
180 send_buffer_size: None,
181 recv_buffer_size: None,
182 interface: None,
183 }),
184 resolver,
185 }
186 }
187
188 #[inline]
192 pub fn enforce_http(&mut self, is_enforced: bool) {
193 self.config_mut().enforce_http = is_enforced;
194 }
195
196 #[inline]
203 pub fn set_keepalive(&mut self, time: Option<Duration>) {
204 self.config_mut().tcp_keepalive_config.time = time;
205 }
206
207 #[inline]
210 pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
211 self.config_mut().tcp_keepalive_config.interval = interval;
212 }
213
214 #[inline]
216 pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
217 self.config_mut().tcp_keepalive_config.retries = retries;
218 }
219
220 #[inline]
224 pub fn set_nodelay(&mut self, nodelay: bool) {
225 self.config_mut().nodelay = nodelay;
226 }
227
228 #[inline]
230 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
231 self.config_mut().send_buffer_size = size;
232 }
233
234 #[inline]
236 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
237 self.config_mut().recv_buffer_size = size;
238 }
239
240 #[inline]
246 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
247 let (v4, v6) = match addr {
248 Some(IpAddr::V4(a)) => (Some(a), None),
249 Some(IpAddr::V6(a)) => (None, Some(a)),
250 _ => (None, None),
251 };
252
253 let cfg = self.config_mut();
254
255 cfg.local_address_ipv4 = v4;
256 cfg.local_address_ipv6 = v6;
257 }
258
259 #[inline]
262 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
263 let cfg = self.config_mut();
264
265 cfg.local_address_ipv4 = Some(addr_ipv4);
266 cfg.local_address_ipv6 = Some(addr_ipv6);
267 }
268
269 #[inline]
276 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
277 self.config_mut().connect_timeout = dur;
278 }
279
280 #[inline]
293 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
294 self.config_mut().happy_eyeballs_timeout = dur;
295 }
296
297 #[inline]
301 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
302 self.config_mut().reuse_address = reuse_address;
303 self
304 }
305
306 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
319 #[inline]
320 pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
321 self.config_mut().interface = Some(interface.into());
322 self
323 }
324
325 fn config_mut(&mut self) -> &mut Config {
328 Arc::make_mut(&mut self.config)
332 }
333}
334
// Returned when `enforce_http` is set and the URI scheme is not `http`.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
// Returned when `enforce_http` is off and the URI has no scheme at all.
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
// Returned when the URI has no host component.
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
338
339impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 f.debug_struct("HttpConnector").finish()
343 }
344}
345
346impl<R> tower_service::Service<Uri> for HttpConnector<R>
347where
348 R: Resolve + Clone + Send + Sync + 'static,
349 R::Future: Send,
350{
351 type Response = TokioIo<TcpStream>;
352 type Error = ConnectError;
353 type Future = HttpConnecting<R>;
354
355 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
356 futures_util::ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
357 Poll::Ready(Ok(()))
358 }
359
360 fn call(&mut self, dst: Uri) -> Self::Future {
361 let mut self_ = self.clone();
362 HttpConnecting {
363 fut: Box::pin(async move { self_.call_async(dst).await }),
364 _marker: PhantomData,
365 }
366 }
367}
368
369fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
370 trace!(
371 "Http::connect; scheme={:?}, host={:?}, port={:?}",
372 dst.scheme(),
373 dst.host(),
374 dst.port(),
375 );
376
377 if config.enforce_http {
378 if dst.scheme() != Some(&Scheme::HTTP) {
379 return Err(ConnectError {
380 msg: INVALID_NOT_HTTP.into(),
381 cause: None,
382 });
383 }
384 } else if dst.scheme().is_none() {
385 return Err(ConnectError {
386 msg: INVALID_MISSING_SCHEME.into(),
387 cause: None,
388 });
389 }
390
391 let host = match dst.host() {
392 Some(s) => s,
393 None => {
394 return Err(ConnectError {
395 msg: INVALID_MISSING_HOST.into(),
396 cause: None,
397 })
398 }
399 };
400 let port = match dst.port() {
401 Some(port) => port.as_u16(),
402 None => {
403 if dst.scheme() == Some(&Scheme::HTTPS) {
404 443
405 } else {
406 80
407 }
408 }
409 };
410
411 Ok((host, port))
412}
413
414impl<R> HttpConnector<R>
415where
416 R: Resolve,
417{
418 async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
419 let config = &self.config;
420
421 let (host, port) = get_host_port(config, &dst)?;
422 let host = host.trim_start_matches('[').trim_end_matches(']');
423
424 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
427 addrs
428 } else {
429 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
430 .await
431 .map_err(ConnectError::dns)?;
432 let addrs = addrs
433 .map(|mut addr| {
434 addr.set_port(port);
435 addr
436 })
437 .collect();
438 dns::SocketAddrs::new(addrs)
439 };
440
441 let c = ConnectingTcp::new(addrs, config);
442
443 let sock = c.connect().await?;
444
445 if let Err(e) = sock.set_nodelay(config.nodelay) {
446 warn!("tcp set_nodelay error: {}", e);
447 }
448
449 Ok(TokioIo::new(sock))
450 }
451}
452
453impl Connection for TcpStream {
454 fn connected(&self) -> Connected {
455 let connected = Connected::new();
456 if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
457 connected.extra(HttpInfo {
458 remote_addr,
459 local_addr,
460 })
461 } else {
462 connected
463 }
464 }
465}
466
467impl<T> Connection for TokioIo<T>
470where
471 T: Connection,
472{
473 fn connected(&self) -> Connected {
474 self.inner().connected()
475 }
476}
477
478impl HttpInfo {
479 pub fn remote_addr(&self) -> SocketAddr {
481 self.remote_addr
482 }
483
484 pub fn local_addr(&self) -> SocketAddr {
486 self.local_addr
487 }
488}
489
pin_project! {
    /// Future returned by `HttpConnector`'s `Service::call`.
    ///
    /// Resolves to the connected stream or a `ConnectError`.
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        #[pin]
        fut: BoxConnecting,
        // Keeps the resolver type parameter without storing a resolver.
        _marker: PhantomData<R>,
    }
}
504
// Output of an `HttpConnecting` future.
type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
// Boxed, type-erased connect future stored inside `HttpConnecting`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
507
508impl<R: Resolve> Future for HttpConnecting<R> {
509 type Output = ConnectResult;
510
511 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
512 self.project().fut.poll(cx)
513 }
514}
515
/// Error returned by `HttpConnector` when it cannot connect.
///
/// Holds a short description and an optional underlying cause.
pub struct ConnectError {
    msg: Box<str>,
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
521
522impl ConnectError {
523 fn new<S, E>(msg: S, cause: E) -> ConnectError
524 where
525 S: Into<Box<str>>,
526 E: Into<Box<dyn StdError + Send + Sync>>,
527 {
528 ConnectError {
529 msg: msg.into(),
530 cause: Some(cause.into()),
531 }
532 }
533
534 fn dns<E>(cause: E) -> ConnectError
535 where
536 E: Into<Box<dyn StdError + Send + Sync>>,
537 {
538 ConnectError::new("dns error", cause)
539 }
540
541 fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
542 where
543 S: Into<Box<str>>,
544 E: Into<Box<dyn StdError + Send + Sync>>,
545 {
546 move |cause| ConnectError::new(msg, cause)
547 }
548}
549
550impl fmt::Debug for ConnectError {
551 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552 if let Some(ref cause) = self.cause {
553 f.debug_tuple("ConnectError")
554 .field(&self.msg)
555 .field(cause)
556 .finish()
557 } else {
558 self.msg.fmt(f)
559 }
560 }
561}
562
563impl fmt::Display for ConnectError {
564 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
565 f.write_str(&self.msg)?;
566
567 if let Some(ref cause) = self.cause {
568 write!(f, ": {}", cause)?;
569 }
570
571 Ok(())
572 }
573}
574
575impl StdError for ConnectError {
576 fn source(&self) -> Option<&(dyn StdError + 'static)> {
577 self.cause.as_ref().map(|e| &**e as _)
578 }
579}
580
/// A pending connection attempt.
///
/// The preferred address group may be raced against a delayed fallback group
/// (Happy Eyeballs).
struct ConnectingTcp<'a> {
    // Addresses tried first.
    preferred: ConnectingTcpRemote,
    // Present only when Happy Eyeballs is enabled and the fallback group is non-empty.
    fallback: Option<ConnectingTcpFallback>,
    // Connector settings applied to every socket.
    config: &'a Config,
}
586
587impl<'a> ConnectingTcp<'a> {
588 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
589 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
590 let (preferred_addrs, fallback_addrs) = remote_addrs
591 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
592 if fallback_addrs.is_empty() {
593 return ConnectingTcp {
594 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
595 fallback: None,
596 config,
597 };
598 }
599
600 ConnectingTcp {
601 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
602 fallback: Some(ConnectingTcpFallback {
603 delay: tokio::time::sleep(fallback_timeout),
604 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
605 }),
606 config,
607 }
608 } else {
609 ConnectingTcp {
610 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
611 fallback: None,
612 config,
613 }
614 }
615 }
616}
617
/// The secondary address group of a Happy Eyeballs attempt.
struct ConnectingTcpFallback {
    // Timer that must elapse before this group is raced against the preferred one.
    delay: Sleep,
    // Addresses tried once the delay elapses or the preferred group fails.
    remote: ConnectingTcpRemote,
}
622
/// An ordered group of addresses that are tried one after another.
struct ConnectingTcpRemote {
    addrs: dns::SocketAddrs,
    // Per-address timeout: the configured connect timeout divided by the
    // number of addresses in this group.
    connect_timeout: Option<Duration>,
}
627
628impl ConnectingTcpRemote {
629 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
630 let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
631
632 Self {
633 addrs,
634 connect_timeout,
635 }
636 }
637}
638
639impl ConnectingTcpRemote {
640 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
641 let mut err = None;
642 for addr in &mut self.addrs {
643 debug!("connecting to {}", addr);
644 match connect(&addr, config, self.connect_timeout)?.await {
645 Ok(tcp) => {
646 debug!("connected to {}", addr);
647 return Ok(tcp);
648 }
649 Err(e) => {
650 trace!("connect error for {}: {:?}", addr, e);
651 err = Some(e);
652 }
653 }
654 }
655
656 match err {
657 Some(e) => Err(e),
658 None => Err(ConnectError::new(
659 "tcp connect error",
660 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
661 )),
662 }
663 }
664}
665
666fn bind_local_address(
667 socket: &socket2::Socket,
668 dst_addr: &SocketAddr,
669 local_addr_ipv4: &Option<Ipv4Addr>,
670 local_addr_ipv6: &Option<Ipv6Addr>,
671) -> io::Result<()> {
672 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
673 (SocketAddr::V4(_), Some(addr), _) => {
674 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
675 }
676 (SocketAddr::V6(_), _, Some(addr)) => {
677 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
678 }
679 _ => {
680 if cfg!(windows) {
681 let any: SocketAddr = match *dst_addr {
683 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
684 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
685 };
686 socket.bind(&any.into())?;
687 }
688 }
689 }
690
691 Ok(())
692}
693
694fn connect(
695 addr: &SocketAddr,
696 config: &Config,
697 connect_timeout: Option<Duration>,
698) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
699 use socket2::{Domain, Protocol, Socket, Type};
703
704 let domain = Domain::for_address(*addr);
705 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
706 .map_err(ConnectError::m("tcp open error"))?;
707
708 socket
711 .set_nonblocking(true)
712 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
713
714 if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
715 if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
716 warn!("tcp set_keepalive error: {}", e);
717 }
718 }
719
720 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
721 if let Some(interface) = &config.interface {
723 socket
724 .bind_device(Some(interface.as_bytes()))
725 .map_err(ConnectError::m("tcp bind interface error"))?;
726 }
727
728 bind_local_address(
729 &socket,
730 addr,
731 &config.local_address_ipv4,
732 &config.local_address_ipv6,
733 )
734 .map_err(ConnectError::m("tcp bind local error"))?;
735
736 #[cfg(unix)]
737 let socket = unsafe {
738 use std::os::unix::io::{FromRawFd, IntoRawFd};
743 TcpSocket::from_raw_fd(socket.into_raw_fd())
744 };
745 #[cfg(windows)]
746 let socket = unsafe {
747 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
752 TcpSocket::from_raw_socket(socket.into_raw_socket())
753 };
754
755 if config.reuse_address {
756 if let Err(e) = socket.set_reuseaddr(true) {
757 warn!("tcp set_reuse_address error: {}", e);
758 }
759 }
760
761 if let Some(size) = config.send_buffer_size {
762 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
763 warn!("tcp set_buffer_size error: {}", e);
764 }
765 }
766
767 if let Some(size) = config.recv_buffer_size {
768 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
769 warn!("tcp set_recv_buffer_size error: {}", e);
770 }
771 }
772
773 let connect = socket.connect(*addr);
774 Ok(async move {
775 match connect_timeout {
776 Some(dur) => match tokio::time::timeout(dur, connect).await {
777 Ok(Ok(s)) => Ok(s),
778 Ok(Err(e)) => Err(e),
779 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
780 },
781 None => connect.await,
782 }
783 .map_err(ConnectError::m("tcp connect error"))
784 })
785}
786
impl ConnectingTcp<'_> {
    /// Runs the connection attempt, racing the fallback group if there is one.
    ///
    /// Without a fallback, only the preferred group is tried. With one, the
    /// preferred group starts immediately:
    /// - If it finishes before the fallback delay, the fallback group is run
    ///   only when that result is an error.
    /// - Once the delay elapses, both groups are polled concurrently. The first
    ///   to finish wins, unless it failed, in which case the other group's
    ///   result is awaited and returned.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                // Created now, but not polled until the delay elapses or the
                // preferred attempt fails.
                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        // Preferred finished before the delay. Keep the
                        // not-yet-started fallback as the backup.
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        // Delay elapsed. Race preferred against fallback; the
                        // unfinished one becomes the backup.
                        Either::Right(((), preferred_fut)) => {
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // The first result was an error: use the remaining attempt.
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
825
#[cfg(test)]
mod tests {
    use std::io;

    use ::http::Uri;

    use crate::client::legacy::connect::http::TcpKeepaliveConfig;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    /// Connects via the sealed `Connect` trait.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    /// Finds a bindable local IPv4 and IPv6 address among the host's interfaces.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    /// Name of the first interface that is up, not loopback, and has an address.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    fn default_interface() -> Option<String> {
        pnet_datalink::interfaces()
            .iter()
            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
            .map(|e| e.name.clone())
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }

    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[tokio::test]
    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
    async fn interface() {
        use socket2::{Domain, Protocol, Socket, Type};
        use std::net::TcpListener;

        let interface: Option<String> = default_interface();

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();

        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        let assert_interface_name =
            |dst: String,
             server: TcpListener,
             bind_iface: Option<String>,
             expected_interface: Option<String>| async move {
                let mut connector = HttpConnector::new();
                if let Some(iface) = bind_iface {
                    connector.set_interface(iface);
                }

                connect(connector, dst.parse().unwrap()).await.unwrap();
                // NOTE(review): the device is read from a freshly created
                // socket, not from the socket the connector used, so this
                // assertion does not observe `set_interface`. Confirm intent.
                let domain = Domain::for_address(server.local_addr().unwrap());
                let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();

                assert_eq!(
                    socket.device().unwrap().as_deref(),
                    expected_interface.as_deref().map(|val| val.as_bytes())
                );
            };

        assert_interface_name(
            format!("http://127.0.0.1:{}", port),
            server4,
            interface.clone(),
            interface.clone(),
        )
        .await;
        assert_interface_name(
            format!("http://[::1]:{}", port),
            server6,
            interface.clone(),
            interface.clone(),
        )
        .await;
    }

    #[test]
    // NOTE(review): the unconditional `#[ignore]` makes the `cfg_attr` below
    // redundant. This looks like an intentional temporary disable; confirm
    // before removing it.
    #[ignore]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // Each entry: (addresses, family expected to win, expected elapsed
        // time, whether the scenario needs IPv6 connectivity).
        let scenarios = &[
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (host.clone(), addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        tcp_keepalive_config: TcpKeepaliveConfig::default(),
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                        interface: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow 150ms of slack on either side of the expected duration.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        /// Connects to port 80 on `addr` with a 1s timeout.
        ///
        /// Returns whether the address counted as reachable (success or
        /// timeout) and how long the attempt took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }

    use std::time::Duration;

    #[test]
    fn no_tcp_keepalive_config() {
        assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
    }

    #[test]
    fn tcp_keepalive_time_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.time = Some(Duration::from_secs(60));
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
        } else {
            panic!("test failed");
        }
    }

    // Must exclude every target where `ka_with_interval` is a no-op,
    // including `aix`; otherwise `into_tcpkeepalive` returns `None` there.
    #[cfg(not(any(
        target_os = "aix",
        target_os = "openbsd",
        target_os = "redox",
        target_os = "solaris"
    )))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.interval = Some(Duration::from_secs(1));
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
        } else {
            panic!("test failed");
        }
    }

    // Must exclude every target where `ka_with_retries` is a no-op,
    // including `aix`; otherwise `into_tcpkeepalive` returns `None` there.
    #[cfg(not(any(
        target_os = "aix",
        target_os = "openbsd",
        target_os = "redox",
        target_os = "solaris",
        target_os = "windows"
    )))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.retries = Some(3);
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
        } else {
            panic!("test failed");
        }
    }
}