use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use futures_channel::mpsc;
use futures_core::Stream;
use crate::runtime::{ConnectorLocal, ConnectorSend, RuntimeLocal, RuntimePoll};
/// How long the race waits after starting a connection attempt before starting the next one
/// while earlier attempts are still pending (the "Connection Attempt Delay" of RFC 8305, whose
/// recommended default is 250 ms).
pub(crate) const HAPPY_EYEBALLS_DELAY: Duration = Duration::from_millis(250);
/// Connects to one of `addrs` using a `Send` connector/runtime, Happy-Eyeballs style.
///
/// * An empty `addrs` yields an `AddrNotAvailable` error.
/// * A single address is connected to directly, with no racing.
/// * Otherwise the addresses are interleaved by family (IPv6 first) and raced with staggered
///   starts; the first successful stream wins.
///
/// `local_address`, when set, is forwarded to the connector's `connect_bound`.
/// Returns the stream together with the address it actually connected to.
pub(crate) async fn connect_happy_eyeballs<R: RuntimePoll, C: ConnectorSend>(
    connector: &C,
    addrs: &[SocketAddr],
    local_address: Option<std::net::IpAddr>,
) -> io::Result<(C::Stream, SocketAddr)> {
    match addrs {
        [] => Err(io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            "no addresses to connect to",
        )),
        [only] => {
            let target = *only;
            tcp_connect_send::<C>(connector, target, local_address)
                .await
                .map(|stream| (stream, target))
        }
        _ => {
            let ordered = interleave_addrs(addrs);
            race_connect::<R, C>(connector, &ordered, local_address).await
        }
    }
}
/// Reorders `addrs` so that address families alternate, starting with IPv6.
///
/// Relative order within each family is preserved. Once one family runs out, the remaining
/// addresses of the other family are appended in their original order.
pub(crate) fn interleave_addrs(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
    let mut ipv6 = addrs.iter().copied().filter(SocketAddr::is_ipv6);
    let mut ipv4 = addrs.iter().copied().filter(SocketAddr::is_ipv4);
    let mut ordered = Vec::with_capacity(addrs.len());
    loop {
        match (ipv6.next(), ipv4.next()) {
            (None, None) => break,
            // Push the IPv6 entry (if any) before the IPv4 entry (if any) of this round.
            (v6, v4) => ordered.extend(v6.into_iter().chain(v4)),
        }
    }
    ordered
}
/// Performs a single connection attempt to `addr` through a `Send` connector.
///
/// With a `local_address`, the attempt goes through `connector.connect_bound(addr, ip)`;
/// without one it uses `connector.connect(addr)`.
async fn tcp_connect_send<C: ConnectorSend>(
    connector: &C,
    addr: SocketAddr,
    local_address: Option<std::net::IpAddr>,
) -> io::Result<C::Stream> {
    match local_address {
        None => connector.connect(addr).await,
        Some(bind_ip) => connector.connect_bound(addr, bind_ip).await,
    }
}
/// Generates a staggered ("Happy Eyeballs") connection race and its private helpers.
///
/// Each invocation emits:
/// * `$race_fn`: starts one attempt per address, in `addrs` order. The next attempt is launched
///   when `HAPPY_EYEBALLS_DELAY` elapses, or immediately when every outstanding attempt has
///   failed. It returns the first successful stream together with its address.
/// * `$spawn_fn`: runs a single attempt on the runtime via `R::$spawn_method` and reports the
///   outcome over an unbounded channel.
/// * `$WaitResult` / `$WaitFuture`: a future resolving to whichever comes first, the next
///   attempt result or the stagger delay.
///
/// The macro lets the `Send` flavour and the local flavour share one implementation.
/// `$fut_bound` is the extra bound placed on the boxed delay future.
///
/// NOTE(review): attempts still pending when the race returns are not cancelled. They keep
/// running in their spawned tasks, and their result (including a late stream) is dropped when
/// the send to the closed channel fails.
macro_rules! impl_race_connect {
    (
        race_fn: $race_fn:ident,
        spawn_fn: $spawn_fn:ident,
        connect_fn: $connect_fn:ident,
        wait_result: $WaitResult:ident,
        wait_future: $WaitFuture:ident,
        connector_trait: $C:ident $(: $extra_bound:ident)*,
        runtime_trait: $R:ident,
        spawn_method: $spawn_method:ident,
        future_extra_bound: $fut_bound:tt,
    ) => {
        /// Races connection attempts to `addrs` (in the given order) with staggered starts.
        ///
        /// Returns the first successful `(stream, addr)`. If every attempt fails, returns the
        /// error of the last attempt to report. An empty `addrs` yields `AddrNotAvailable`.
        async fn $race_fn<R: $R, C: $C $(+ $extra_bound)*>(
            connector: &C,
            addrs: &[SocketAddr],
            local_address: Option<std::net::IpAddr>,
        ) -> io::Result<(C::Stream, SocketAddr)> {
            if addrs.is_empty() {
                return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses"));
            }
            let (sender, mut rx) =
                mpsc::unbounded::<Result<(C::Stream, SocketAddr), io::Error>>();
            // Our own sender is kept only while there are addresses left to dispatch
            // (invariant: `tx.is_some()` iff `next_idx < addrs.len()`). After the last spawn
            // it is dropped. Then, if every spawned task goes away without reporting (e.g. it
            // panicked or the runtime dropped it), the channel closes and the race returns
            // instead of waiting forever.
            let mut tx = Some(sender);
            let mut last_err: Option<io::Error> = None;
            let mut next_idx = 0usize;
            let mut in_flight = 0usize;
            let mut delay: Pin<Box<dyn std::future::Future<Output = ()> + $fut_bound>> =
                Box::pin(std::future::pending());
            // The first attempt starts right away.
            let mut start_next = true;
            loop {
                if start_next {
                    start_next = false;
                    if let (Some(sender), Some(&addr)) = (tx.as_ref(), addrs.get(next_idx)) {
                        $spawn_fn::<R, C>(connector, addr, local_address, sender);
                        next_idx += 1;
                        in_flight += 1;
                    }
                    if next_idx < addrs.len() {
                        delay = Box::pin(R::sleep(HAPPY_EYEBALLS_DELAY));
                    } else {
                        // Everything has been dispatched: release our sender and disarm the
                        // stagger timer.
                        tx = None;
                        delay = Box::pin(std::future::pending());
                    }
                }
                match $WaitResult::<C>::wait(&mut rx, &mut delay).await {
                    $WaitResult::Message(Ok((stream, addr))) => return Ok((stream, addr)),
                    $WaitResult::Message(Err(e)) => {
                        in_flight -= 1;
                        if in_flight == 0 {
                            if tx.is_none() {
                                // Every address was tried and nothing is outstanding.
                                return Err(e);
                            }
                            // Nothing is pending, so don't sit out the rest of the delay.
                            start_next = true;
                        }
                        last_err = Some(e);
                    }
                    // The delay is only armed while addresses remain, so this always launches.
                    $WaitResult::Delay => start_next = true,
                    $WaitResult::ChannelClosed => {
                        // Reachable only after our sender is dropped and every spawned task
                        // ended without sending a result.
                        return Err(last_err.unwrap_or_else(|| {
                            io::Error::new(
                                io::ErrorKind::Other,
                                "connection attempts ended without reporting a result",
                            )
                        }));
                    }
                }
            }
        }

        /// Spawns one connection attempt to `addr` and reports the outcome on `tx`.
        ///
        /// The outcome is the stream paired with `addr`, or the error. A failed send (the race
        /// has already returned and dropped the receiver) is ignored, which drops any late
        /// stream.
        fn $spawn_fn<R: $R, C: $C $(+ $extra_bound)*>(
            connector: &C,
            addr: SocketAddr,
            local_address: Option<std::net::IpAddr>,
            tx: &mpsc::UnboundedSender<Result<(C::Stream, SocketAddr), io::Error>>,
        ) {
            let connector = connector.clone();
            let tx = tx.clone();
            R::$spawn_method(async move {
                let result = $connect_fn::<C>(&connector, addr, local_address).await;
                let _ = tx.unbounded_send(result.map(|stream| (stream, addr)));
            });
        }

        /// Which event woke the race loop.
        enum $WaitResult<C: $C $(+ $extra_bound)*> {
            /// An attempt reported its outcome.
            Message(Result<(C::Stream, SocketAddr), io::Error>),
            /// The stagger delay elapsed.
            Delay,
            /// All senders are gone and no buffered messages remain.
            ChannelClosed,
        }

        impl<C: $C $(+ $extra_bound)*> $WaitResult<C> {
            /// Builds a future that resolves to the first of: next channel message, or delay.
            fn wait<'a>(
                rx: &'a mut mpsc::UnboundedReceiver<
                    Result<(C::Stream, SocketAddr), io::Error>,
                >,
                delay: &'a mut Pin<
                    Box<dyn std::future::Future<Output = ()> + $fut_bound>,
                >,
            ) -> $WaitFuture<'a, C> {
                $WaitFuture {
                    rx,
                    delay,
                    _marker: std::marker::PhantomData,
                }
            }
        }

        /// Future behind `$WaitResult::wait`; polls the channel first, then the delay.
        struct $WaitFuture<'a, C: $C $(+ $extra_bound)*> {
            rx: &'a mut mpsc::UnboundedReceiver<
                Result<(C::Stream, SocketAddr), io::Error>,
            >,
            delay: &'a mut Pin<
                Box<dyn std::future::Future<Output = ()> + $fut_bound>,
            >,
            // `fn() -> C` keeps the type parameter without making `Unpin` depend on `C`, so the
            // whole future is `Unpin` and can be polled through the safe `Pin::get_mut`.
            _marker: std::marker::PhantomData<fn() -> C>,
        }

        impl<C: $C $(+ $extra_bound)*> std::future::Future for $WaitFuture<'_, C> {
            type Output = $WaitResult<C>;
            fn poll(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Self::Output> {
                let this = self.get_mut();
                if let Poll::Ready(msg) = Pin::new(&mut *this.rx).poll_next(cx) {
                    return Poll::Ready(match msg {
                        Some(result) => $WaitResult::Message(result),
                        None => $WaitResult::ChannelClosed,
                    });
                }
                if let Poll::Ready(()) = this.delay.as_mut().poll(cx) {
                    return Poll::Ready($WaitResult::Delay);
                }
                Poll::Pending
            }
        }
    };
}
// `Send` flavour: generates `race_connect`, `spawn_attempt`, `WaitResult` and `WaitFuture`.
// Attempts are spawned with `RuntimePoll::spawn_send` and run `tcp_connect_send`; the boxed
// stagger-delay future must be `Send`.
impl_race_connect! {
race_fn: race_connect,
spawn_fn: spawn_attempt,
connect_fn: tcp_connect_send,
wait_result: WaitResult,
wait_future: WaitFuture,
connector_trait: ConnectorSend,
runtime_trait: RuntimePoll,
spawn_method: spawn_send,
future_extra_bound: Send,
}
/// Performs a single connection attempt to `addr` through a local (`ConnectorLocal`) connector.
///
/// With a `local_address`, the attempt goes through `connector.connect_bound(addr, ip)`;
/// without one it uses `connector.connect(addr)`.
async fn tcp_connect_local<C: ConnectorLocal>(
    connector: &C,
    addr: SocketAddr,
    local_address: Option<std::net::IpAddr>,
) -> io::Result<C::Stream> {
    match local_address {
        None => connector.connect(addr).await,
        Some(bind_ip) => connector.connect_bound(addr, bind_ip).await,
    }
}
/// Local-runtime counterpart of `connect_happy_eyeballs`.
///
/// * An empty `addrs` yields an `AddrNotAvailable` error.
/// * A single address is connected to directly.
/// * Otherwise the addresses are interleaved (IPv6 first) and raced with staggered starts,
///   with attempts spawned through `RuntimeLocal::spawn_local`.
///
/// Returns the stream together with the address it connected to.
pub(crate) async fn connect_happy_eyeballs_local<R: RuntimeLocal, C: ConnectorLocal + Clone>(
    connector: &C,
    addrs: &[SocketAddr],
    local_address: Option<std::net::IpAddr>,
) -> io::Result<(C::Stream, SocketAddr)> {
    match addrs {
        [] => Err(io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            "no addresses to connect to",
        )),
        [only] => {
            let target = *only;
            tcp_connect_local::<C>(connector, target, local_address)
                .await
                .map(|stream| (stream, target))
        }
        _ => {
            let ordered = interleave_addrs(addrs);
            race_connect_local::<R, C>(connector, &ordered, local_address).await
        }
    }
}
// Local flavour: generates `race_connect_local`, `spawn_attempt_local`, `WaitResultLocal` and
// `WaitFutureLocal`. Attempts are spawned with `RuntimeLocal::spawn_local` and run
// `tcp_connect_local`; the connector must also be `Clone`, and the boxed stagger-delay future
// only needs to be `'static`.
impl_race_connect! {
race_fn: race_connect_local,
spawn_fn: spawn_attempt_local,
connect_fn: tcp_connect_local,
wait_result: WaitResultLocal,
wait_future: WaitFutureLocal,
connector_trait: ConnectorLocal: Clone,
runtime_trait: RuntimeLocal,
spawn_method: spawn_local,
future_extra_bound: 'static,
}
// Tests for address interleaving and for the connect paths on each supported runtime.
//
// The network tests assume nothing listens on 127.0.0.1 ports 1 and 2 (used as "bad"
// addresses that should be refused).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn interleave_v6_first() {
        let addrs = vec![
            "127.0.0.1:80".parse().unwrap(),
            "[::1]:80".parse().unwrap(),
            "10.0.0.1:80".parse().unwrap(),
            "[::2]:80".parse().unwrap(),
        ];
        let result = interleave_addrs(&addrs);
        assert!(result[0].is_ipv6());
        assert!(result[1].is_ipv4());
        assert!(result[2].is_ipv6());
        assert!(result[3].is_ipv4());
    }

    #[test]
    fn interleave_only_v4() {
        let addrs = vec![
            "1.1.1.1:443".parse().unwrap(),
            "8.8.8.8:443".parse().unwrap(),
        ];
        let result = interleave_addrs(&addrs);
        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|a| a.is_ipv4()));
    }

    #[test]
    fn interleave_empty() {
        let result = interleave_addrs(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn interleave_only_v6() {
        let addrs = vec!["[::1]:443".parse().unwrap(), "[::2]:443".parse().unwrap()];
        let result = interleave_addrs(&addrs);
        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|a| a.is_ipv6()));
    }

    #[test]
    fn interleave_single_v4() {
        let addrs = vec!["1.2.3.4:80".parse().unwrap()];
        let result = interleave_addrs(&addrs);
        assert_eq!(result.len(), 1);
        assert!(result[0].is_ipv4());
    }

    #[test]
    fn interleave_single_v6() {
        let addrs = vec!["[::1]:80".parse().unwrap()];
        let result = interleave_addrs(&addrs);
        assert_eq!(result.len(), 1);
        assert!(result[0].is_ipv6());
    }

    #[test]
    fn interleave_uneven_more_v6() {
        let addrs = vec![
            "[::1]:80".parse().unwrap(),
            "[::2]:80".parse().unwrap(),
            "[::3]:80".parse().unwrap(),
            "1.1.1.1:80".parse().unwrap(),
        ];
        let result = interleave_addrs(&addrs);
        assert_eq!(result.len(), 4);
        assert!(result[0].is_ipv6());
        assert!(result[1].is_ipv4());
        assert!(result[2].is_ipv6());
        assert!(result[3].is_ipv6());
    }

    #[test]
    fn interleave_uneven_more_v4() {
        let addrs = vec![
            "1.1.1.1:80".parse().unwrap(),
            "2.2.2.2:80".parse().unwrap(),
            "3.3.3.3:80".parse().unwrap(),
            "[::1]:80".parse().unwrap(),
        ];
        let result = interleave_addrs(&addrs);
        assert_eq!(result.len(), 4);
        assert!(result[0].is_ipv6());
        assert!(result[1].is_ipv4());
        assert!(result[2].is_ipv4());
        assert!(result[3].is_ipv4());
    }

    #[test]
    fn interleave_preserves_order_within_family() {
        let addrs = vec![
            "1.0.0.1:80".parse().unwrap(),
            "[2001:db8::1]:80".parse().unwrap(),
            "8.8.8.8:80".parse().unwrap(),
            "[2001:db8::2]:80".parse().unwrap(),
        ];
        let result = interleave_addrs(&addrs);
        let v6: Vec<_> = result.iter().filter(|a| a.is_ipv6()).collect();
        let v4: Vec<_> = result.iter().filter(|a| a.is_ipv4()).collect();
        assert_eq!(v6[0].to_string(), "[2001:db8::1]:80");
        assert_eq!(v6[1].to_string(), "[2001:db8::2]:80");
        assert_eq!(v4[0].to_string(), "1.0.0.1:80");
        assert_eq!(v4[1].to_string(), "8.8.8.8:80");
    }

    // Pins the exact output order: alternate families starting with IPv6, then append the
    // leftovers of the longer family.
    #[test]
    fn interleave_exact_order_mixed_uneven() {
        let addrs: Vec<SocketAddr> = vec![
            "1.1.1.1:80".parse().unwrap(),
            "2.2.2.2:80".parse().unwrap(),
            "[::1]:80".parse().unwrap(),
            "3.3.3.3:80".parse().unwrap(),
            "[::2]:80".parse().unwrap(),
        ];
        let expected: Vec<SocketAddr> = vec![
            "[::1]:80".parse().unwrap(),
            "1.1.1.1:80".parse().unwrap(),
            "[::2]:80".parse().unwrap(),
            "2.2.2.2:80".parse().unwrap(),
            "3.3.3.3:80".parse().unwrap(),
        ];
        assert_eq!(interleave_addrs(&addrs), expected);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn connect_empty_addrs_errors() {
        use crate::runtime::tokio_rt::{TcpConnector, TokioRuntime};
        let connector = TcpConnector;
        let result =
            connect_happy_eyeballs::<TokioRuntime, TcpConnector>(&connector, &[], None).await;
        let err = result.err().expect("should be an error");
        assert_eq!(err.kind(), io::ErrorKind::AddrNotAvailable);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn connect_single_addr_succeeds() {
        use crate::runtime::tokio_rt::{TcpConnector, TokioRuntime};
        let connector = TcpConnector;
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (stream, connected_addr) =
            connect_happy_eyeballs::<TokioRuntime, TcpConnector>(&connector, &[addr], None)
                .await
                .unwrap();
        assert_eq!(connected_addr, addr);
        drop(stream);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn connect_multi_addrs_first_succeeds() {
        use crate::runtime::tokio_rt::{TcpConnector, TokioRuntime};
        let connector = TcpConnector;
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let good_addr = listener.local_addr().unwrap();
        let bad_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        let (stream, connected_addr) = connect_happy_eyeballs::<TokioRuntime, TcpConnector>(
            &connector,
            &[good_addr, bad_addr],
            None,
        )
        .await
        .unwrap();
        assert_eq!(connected_addr, good_addr);
        drop(stream);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn connect_multi_addrs_second_succeeds() {
        use crate::runtime::tokio_rt::{TcpConnector, TokioRuntime};
        let connector = TcpConnector;
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let good_addr = listener.local_addr().unwrap();
        let bad_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        let (stream, connected_addr) = connect_happy_eyeballs::<TokioRuntime, TcpConnector>(
            &connector,
            &[bad_addr, good_addr],
            None,
        )
        .await
        .unwrap();
        assert_eq!(connected_addr, good_addr);
        drop(stream);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn connect_all_fail() {
        use crate::runtime::tokio_rt::{TcpConnector, TokioRuntime};
        let connector = TcpConnector;
        let bad1: SocketAddr = "127.0.0.1:1".parse().unwrap();
        let bad2: SocketAddr = "127.0.0.1:2".parse().unwrap();
        let result =
            connect_happy_eyeballs::<TokioRuntime, TcpConnector>(&connector, &[bad1, bad2], None)
                .await;
        assert!(result.is_err());
    }

    #[cfg(feature = "compio")]
    #[test]
    fn local_connect_empty_addrs_errors() {
        use crate::runtime::compio_rt::{CompioRuntime, TcpConnector};
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let connector = TcpConnector;
            let result =
                connect_happy_eyeballs_local::<CompioRuntime, TcpConnector>(&connector, &[], None)
                    .await;
            let err = result.err().expect("should be an error");
            assert_eq!(err.kind(), io::ErrorKind::AddrNotAvailable);
        });
    }

    #[cfg(feature = "compio")]
    #[test]
    fn local_connect_single_addr_succeeds() {
        use crate::runtime::compio_rt::{CompioRuntime, TcpConnector};
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let connector = TcpConnector;
            let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let (_stream, connected_addr) = connect_happy_eyeballs_local::<
                CompioRuntime,
                TcpConnector,
            >(&connector, &[addr], None)
            .await
            .unwrap();
            assert_eq!(connected_addr, addr);
        });
    }

    #[cfg(feature = "compio")]
    #[test]
    fn local_connect_multi_addrs_first_succeeds() {
        use crate::runtime::compio_rt::{CompioRuntime, TcpConnector};
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let connector = TcpConnector;
            let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            let good_addr = listener.local_addr().unwrap();
            let bad_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
            let (_stream, connected_addr) = connect_happy_eyeballs_local::<
                CompioRuntime,
                TcpConnector,
            >(&connector, &[good_addr, bad_addr], None)
            .await
            .unwrap();
            assert_eq!(connected_addr, good_addr);
        });
    }

    #[cfg(feature = "compio")]
    #[test]
    fn local_connect_multi_addrs_second_succeeds() {
        use crate::runtime::compio_rt::{CompioRuntime, TcpConnector};
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let connector = TcpConnector;
            let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            let good_addr = listener.local_addr().unwrap();
            let bad_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
            let (_stream, connected_addr) = connect_happy_eyeballs_local::<
                CompioRuntime,
                TcpConnector,
            >(&connector, &[bad_addr, good_addr], None)
            .await
            .unwrap();
            assert_eq!(connected_addr, good_addr);
        });
    }

    #[cfg(feature = "compio")]
    #[test]
    fn local_connect_all_fail() {
        use crate::runtime::compio_rt::{CompioRuntime, TcpConnector};
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let connector = TcpConnector;
            let bad1: SocketAddr = "127.0.0.1:1".parse().unwrap();
            let bad2: SocketAddr = "127.0.0.1:2".parse().unwrap();
            let result = connect_happy_eyeballs_local::<CompioRuntime, TcpConnector>(
                &connector,
                &[bad1, bad2],
                None,
            )
            .await;
            assert!(result.is_err());
        });
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn connect_single_addr_with_local_address() {
        use crate::runtime::tokio_rt::{TcpConnector, TokioRuntime};
        let connector = TcpConnector;
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let local_ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();
        let result = connect_happy_eyeballs::<TokioRuntime, TcpConnector>(
            &connector,
            &[addr],
            Some(local_ip),
        )
        .await;
        assert!(result.is_ok());
        let (_, connected_addr) = result.unwrap();
        assert_eq!(connected_addr, addr);
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn connect_multi_with_local_address() {
        use crate::runtime::tokio_rt::{TcpConnector, TokioRuntime};
        let connector = TcpConnector;
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let good_addr = listener.local_addr().unwrap();
        let bad_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        let local_ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();
        let result = connect_happy_eyeballs::<TokioRuntime, TcpConnector>(
            &connector,
            &[bad_addr, good_addr],
            Some(local_ip),
        )
        .await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "compio")]
    #[test]
    fn local_connect_single_with_local_address() {
        use crate::runtime::compio_rt::{CompioRuntime, TcpConnector};
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let connector = TcpConnector;
            let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let local_ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();
            let result = connect_happy_eyeballs_local::<CompioRuntime, TcpConnector>(
                &connector,
                &[addr],
                Some(local_ip),
            )
            .await;
            assert!(result.is_ok());
        });
    }

    #[cfg(feature = "compio")]
    #[test]
    fn local_connect_multi_with_local_address() {
        use crate::runtime::compio_rt::{CompioRuntime, TcpConnector};
        compio_runtime::Runtime::new().unwrap().block_on(async {
            let connector = TcpConnector;
            let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            let good_addr = listener.local_addr().unwrap();
            let bad_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
            let local_ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();
            let result = connect_happy_eyeballs_local::<CompioRuntime, TcpConnector>(
                &connector,
                &[bad_addr, good_addr],
                Some(local_ip),
            )
            .await;
            assert!(result.is_ok());
        });
    }

    #[cfg(feature = "smol")]
    #[test]
    fn smol_connect_empty_addrs_errors() {
        use crate::runtime::smol_rt::{SmolRuntime, TcpConnector};
        smol::block_on(async {
            let connector = TcpConnector;
            let result =
                connect_happy_eyeballs::<SmolRuntime, TcpConnector>(&connector, &[], None).await;
            let err = result.err().expect("should be an error");
            assert_eq!(err.kind(), io::ErrorKind::AddrNotAvailable);
        });
    }

    #[cfg(feature = "smol")]
    #[test]
    fn smol_connect_single_addr_succeeds() {
        use crate::runtime::smol_rt::{SmolRuntime, TcpConnector};
        smol::block_on(async {
            let connector = TcpConnector;
            let listener = smol::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let (_stream, connected_addr) =
                connect_happy_eyeballs::<SmolRuntime, TcpConnector>(&connector, &[addr], None)
                    .await
                    .unwrap();
            assert_eq!(connected_addr, addr);
        });
    }

    #[cfg(feature = "smol")]
    #[test]
    fn smol_connect_multi_addrs_second_succeeds() {
        use crate::runtime::smol_rt::{SmolRuntime, TcpConnector};
        smol::block_on(async {
            let connector = TcpConnector;
            let listener = smol::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            let good_addr = listener.local_addr().unwrap();
            let bad_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
            let (_stream, connected_addr) = connect_happy_eyeballs::<SmolRuntime, TcpConnector>(
                &connector,
                &[bad_addr, good_addr],
                None,
            )
            .await
            .unwrap();
            assert_eq!(connected_addr, good_addr);
        });
    }

    #[cfg(feature = "smol")]
    #[test]
    fn smol_connect_all_fail() {
        use crate::runtime::smol_rt::{SmolRuntime, TcpConnector};
        smol::block_on(async {
            let connector = TcpConnector;
            let bad1: SocketAddr = "127.0.0.1:1".parse().unwrap();
            let bad2: SocketAddr = "127.0.0.1:2".parse().unwrap();
            let result = connect_happy_eyeballs::<SmolRuntime, TcpConnector>(
                &connector,
                &[bad1, bad2],
                None,
            )
            .await;
            assert!(result.is_err());
        });
    }

    #[cfg(feature = "smol")]
    #[test]
    fn smol_connect_single_with_local_address() {
        use crate::runtime::smol_rt::{SmolRuntime, TcpConnector};
        smol::block_on(async {
            let connector = TcpConnector;
            let listener = smol::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let local_ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();
            let result = connect_happy_eyeballs::<SmolRuntime, TcpConnector>(
                &connector,
                &[addr],
                Some(local_ip),
            )
            .await;
            assert!(result.is_ok());
        });
    }

    // 192.0.2.1 is in TEST-NET-1 (RFC 5737, reserved for documentation). A connect to it
    // normally gets no answer, so the second address is only started once
    // HAPPY_EYEBALLS_DELAY elapses. On hosts without a route the first attempt may instead
    // fail immediately, in which case the second address is tried right away. Either way the
    // good address must win.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn connect_deadline_reached_then_second_succeeds() {
        use crate::runtime::tokio_rt::{TcpConnector, TokioRuntime};
        let connector = TcpConnector;
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let good_addr = listener.local_addr().unwrap();
        let hanging_addr: SocketAddr = "192.0.2.1:80".parse().unwrap();
        let (stream, connected_addr) = connect_happy_eyeballs::<TokioRuntime, TcpConnector>(
            &connector,
            &[hanging_addr, good_addr],
            None,
        )
        .await
        .unwrap();
        assert_eq!(connected_addr, good_addr);
        drop(stream);
    }
}