use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use async_net::TcpStream;
use futures_lite::future;
use socket2::{Domain, Protocol, Socket, Type};
use crate::error::{Error, ErrorKind, Result};
pub(crate) async fn connect_tcp_candidates(
addrs: impl IntoIterator<Item = SocketAddr>,
local_addr: Option<SocketAddr>,
timeout: Option<Duration>,
) -> Result<TcpStream> {
let mut last_error = None;
for addr in addrs {
if let Some(local_addr) = local_addr {
if local_addr.is_ipv4() != addr.is_ipv4() {
continue;
}
}
match connect_tcp_addr(addr, local_addr, timeout).await {
Ok(stream) => return Ok(stream),
Err(err) => last_error = Some(err),
}
}
Err(last_error.unwrap_or_else(|| Error::new(ErrorKind::Transport, "connect failed")))
}
pub(crate) async fn connect_happy_eyeballs(
primary: Vec<SocketAddr>,
fallback: Vec<SocketAddr>,
local_addr: Option<SocketAddr>,
timeout: Option<Duration>,
) -> Result<TcpStream> {
if primary.is_empty() {
return connect_tcp_candidates(fallback, local_addr, timeout).await;
}
if fallback.is_empty() {
return connect_tcp_candidates(primary, local_addr, timeout).await;
}
let primary_fut = connect_tcp_candidates(primary, local_addr, timeout);
let fallback_fut = async move {
async_io::Timer::after(Duration::from_millis(250)).await;
connect_tcp_candidates(fallback, local_addr, timeout).await
};
race_first_ok(primary_fut, fallback_fut).await
}
/// Polls `f1` and `f2` together and resolves with the first `Ok` either one produces.
///
/// An `Err` is returned only once both futures have completed without success; the
/// choice of which error is surfaced is made by `RaceOk`.
async fn race_first_ok<T, E, F1, F2>(f1: F1, f2: F2) -> std::result::Result<T, E>
where
    F1: future::Future<Output = std::result::Result<T, E>>,
    F2: future::Future<Output = std::result::Result<T, E>>,
{
    // Both contenders start out unfinished and with no recorded failure.
    let racer = RaceOk {
        f1: Some(f1),
        f2: Some(f2),
        f1_err: None,
        f2_err: None,
        _t: std::marker::PhantomData,
    };
    racer.await
}
/// Future that polls two fallible futures and completes with the first `Ok` result.
///
/// A future's slot is cleared once it has completed. Failures are parked in
/// `f1_err` / `f2_err` until both futures are done.
struct RaceOk<T, E, F1, F2> {
    /// First contender; `None` once it has completed.
    f1: Option<F1>,
    /// Second contender; `None` once it has completed.
    f2: Option<F2>,
    /// Error produced by `f1`, if it failed.
    f1_err: Option<E>,
    /// Error produced by `f2`, if it failed.
    f2_err: Option<E>,
    /// Ties the success type `T` to the struct without storing a value of it.
    _t: std::marker::PhantomData<T>,
}

impl<T, E, F1, F2> std::future::Future for RaceOk<T, E, F1, F2>
where
    F1: future::Future<Output = std::result::Result<T, E>>,
    F2: future::Future<Output = std::result::Result<T, E>>,
{
    type Output = std::result::Result<T, E>;

    /// Polls `f1`, then `f2`, on every call.
    ///
    /// Resolves with the first `Ok`. Once both futures have failed, resolves with the
    /// error of `f2` if there is one, otherwise with the error of `f1`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the futures stored in `f1` / `f2` are never moved out of their slots
        // here; they are only polled in place and dropped in place by overwriting the
        // slot with `None`.
        let state = unsafe { self.get_unchecked_mut() };

        if let Some(first) = state.f1.as_mut() {
            // SAFETY: `first` points into the pinned `RaceOk` and is not moved (see above).
            match unsafe { Pin::new_unchecked(first) }.poll(cx) {
                Poll::Ready(Ok(value)) => {
                    state.f1 = None;
                    return Poll::Ready(Ok(value));
                }
                Poll::Ready(Err(err)) => {
                    state.f1 = None;
                    state.f1_err = Some(err);
                }
                Poll::Pending => {}
            }
        }

        if let Some(second) = state.f2.as_mut() {
            // SAFETY: `second` points into the pinned `RaceOk` and is not moved (see above).
            match unsafe { Pin::new_unchecked(second) }.poll(cx) {
                Poll::Ready(Ok(value)) => {
                    state.f2 = None;
                    return Poll::Ready(Ok(value));
                }
                Poll::Ready(Err(err)) => {
                    state.f2 = None;
                    state.f2_err = Some(err);
                }
                Poll::Pending => {}
            }
        }

        // At least one contender is still running: nothing to report yet.
        if state.f1.is_some() || state.f2.is_some() {
            return Poll::Pending;
        }

        // Both contenders failed; the second future's error takes precedence.
        let err = match state.f2_err.take() {
            Some(err) => err,
            None => state
                .f1_err
                .take()
                .expect("RaceOk completed with no stored error"),
        };
        Poll::Ready(Err(err))
    }
}
/// Opens a TCP connection to `addr` with `TCP_NODELAY` enabled.
///
/// When `local_addr` is given, the socket is created by hand, bound to that address and
/// connected in non-blocking mode. Otherwise `async_net::TcpStream::connect` is used
/// directly. `timeout`, when set, bounds the wait for the connection to be established.
///
/// # Errors
///
/// Returns an `ErrorKind::Transport` error when the socket cannot be created or
/// configured, when the connection attempt fails, or when `timeout` elapses first.
pub(crate) async fn connect_tcp_addr(
    addr: SocketAddr,
    local_addr: Option<SocketAddr>,
    timeout: Option<Duration>,
) -> Result<TcpStream> {
    if let Some(local) = local_addr {
        let domain = if addr.is_ipv4() {
            Domain::IPV4
        } else {
            Domain::IPV6
        };
        let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to create tcp socket", err)
        })?;
        socket.set_nodelay(true).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to set TCP_NODELAY", err)
        })?;
        socket.bind(&local.into()).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind local tcp socket", err)
        })?;
        socket.set_nonblocking(true).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to set non-blocking", err)
        })?;

        // A non-blocking connect normally reports "in progress" here and the real outcome
        // is collected below. The immediate result is kept (not discarded) so that a
        // connect rejected on the spot can be reported with its original cause.
        let immediate = socket.connect(&addr.into());

        let std_stream = std::net::TcpStream::from(socket);
        let stream = async_io::Async::new(std_stream).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to adopt tcp socket", err)
        })?;
        let connect_fut = async {
            // The socket becomes writable once the connection attempt has finished.
            stream.writable().await.map_err(|err| {
                Error::with_source(ErrorKind::Transport, "tcp connect failed", err)
            })?;
            // A failed asynchronous connect leaves its error pending on the socket.
            let so_error = stream.get_ref().take_error().map_err(|err| {
                Error::with_source(ErrorKind::Transport, "tcp connect failed", err)
            })?;
            if let Some(err) = so_error {
                return Err(Error::with_source(
                    ErrorKind::Transport,
                    "tcp connect failed",
                    err,
                ));
            }
            // Writability plus an empty error slot does not prove the handshake completed:
            // a socket whose connect() was rejected immediately can look the same. Only a
            // connected socket has a peer address, so use that as the final confirmation.
            if let Err(probe_err) = stream.get_ref().peer_addr() {
                // Prefer the synchronous connect error; it carries the actual cause.
                let source = immediate.err().unwrap_or(probe_err);
                return Err(Error::with_source(
                    ErrorKind::Transport,
                    "tcp connect failed",
                    source,
                ));
            }
            let inner = stream.into_inner().map_err(|err| {
                Error::with_source(ErrorKind::Transport, "failed to finalize tcp stream", err)
            })?;
            async_net::TcpStream::try_from(inner).map_err(|err| {
                Error::with_source(ErrorKind::Transport, "failed to adopt tcp stream", err)
            })
        };
        return apply_timeout(connect_fut, timeout).await;
    }

    // No local address requested: let async-net create and connect the socket.
    let connect_fut = async move {
        async_net::TcpStream::connect(addr)
            .await
            .map_err(|err| Error::with_source(ErrorKind::Transport, "tcp connect failed", err))
    };
    let stream = apply_timeout(connect_fut, timeout).await?;
    stream.set_nodelay(true).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to set TCP_NODELAY", err)
    })?;
    Ok(stream)
}
async fn apply_timeout<T>(
fut: impl future::Future<Output = Result<T>>,
timeout: Option<Duration>,
) -> Result<T> {
match timeout {
None => fut.await,
Some(dur) => {
futures_lite::pin!(fut);
future::or(fut, async {
async_io::Timer::after(dur).await;
Err(Error::new(ErrorKind::Transport, "tcp connect timed out"))
})
.await
}
}
}