#[cfg(feature = "tokio-rt")]
pub mod tokio;
#[cfg(feature = "compio-rt")]
pub mod compio;
use std::{
error::Error as StdError,
fmt,
future::Future,
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
pin::{Pin, pin},
time::Duration,
};
use futures_util::future::Either;
use socket2::TcpKeepalive;
use crate::{
conn::{Connection, net::SocketBindOptions},
dns,
error::BoxError,
};
/// A boxed, `Send`, type-erased future that resolves to `Result<T, E>`.
///
/// NOTE(review): not referenced in this file; presumably used by the
/// runtime-specific `tokio` / `compio` submodules — confirm.
type BoxConnecting<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
/// Runtime hooks used by the generic TCP connect logic in this module.
///
/// The module opens and configures a `socket2::Socket` itself. An implementor
/// turns that socket into its runtime's socket type, drives the actual connect,
/// and provides a sleep future. The sleep is used both for the per-address
/// connect timeout and for the Happy Eyeballs fallback delay.
pub trait TcpConnector: Clone + Send + Sync + 'static {
    /// Runtime socket type, built from a socket2 socket that is already
    /// non-blocking and configured.
    type TcpStream: From<socket2::Socket> + Send + Sync + 'static;
    /// The connected stream returned on success.
    type Connection: ::tokio::io::AsyncRead
        + ::tokio::io::AsyncWrite
        + Connection
        + Send
        + Unpin
        + 'static;
    /// Error of a failed connect; converted into the boxed cause of a
    /// `ConnectError` ("tcp connect error").
    type Error: Into<Box<dyn StdError + Send + Sync>>;
    /// Future returned by [`TcpConnector::connect`].
    type Future: Future<Output = Result<Self::Connection, Self::Error>> + Send + 'static;
    /// Future returned by [`TcpConnector::sleep`]; its completion is treated
    /// as the timeout/delay having elapsed.
    type Sleep: Future<Output = ()> + Send + 'static;
    /// Starts connecting `socket` to `addr`.
    fn connect(&self, socket: Self::TcpStream, addr: SocketAddr) -> Self::Future;
    /// Returns a future that is expected to complete once `duration` has elapsed.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
/// A connection plan: a preferred set of addresses plus an optional fallback
/// set that is raced against it (Happy Eyeballs style).
pub(crate) struct ConnectingTcp<S: TcpConnector> {
    /// Addresses attempted first.
    preferred: ConnectingTcpRemote<S>,
    /// Fallback addresses. They start once the fallback delay elapses, or once
    /// the preferred attempt fails before that. `None` when Happy Eyeballs is
    /// disabled or there are no fallback addresses.
    fallback: Option<ConnectingTcpFallback<S>>,
}
/// The fallback half of a [`ConnectingTcp`] plan.
struct ConnectingTcpFallback<S: TcpConnector> {
    /// Sleep created when the plan is built; the fallback attempt begins
    /// racing the preferred one when this completes.
    delay: S::Sleep,
    /// Fallback addresses to try.
    remote: ConnectingTcpRemote<S>,
}
/// A list of addresses that are tried one after another.
struct ConnectingTcpRemote<S: TcpConnector> {
    /// Candidate addresses, consumed in order as they are tried.
    addrs: dns::SocketAddrs,
    /// Per-address timeout: the configured total divided by the number of
    /// addresses, or `None` for no timeout.
    connect_timeout: Option<Duration>,
    /// Runtime used to connect and to sleep.
    connector: S,
}
impl<S: TcpConnector> ConnectingTcp<S>
where
    S::TcpStream: From<socket2::Socket>,
{
    /// Builds a connection plan for `remote_addrs`.
    ///
    /// Without `happy_eyeballs_timeout`, every address is tried in order and
    /// there is no fallback. With it, the addresses are split by preference
    /// using the configured local IPv4/IPv6 bind addresses. If any fallback
    /// addresses remain, a fallback attempt is scheduled; it starts racing the
    /// preferred attempt after `happy_eyeballs_timeout`.
    pub(crate) fn new(remote_addrs: dns::SocketAddrs, config: &TcpOptions, connector: S) -> Self {
        let connect_timeout = config.connect_timeout;

        let Some(fallback_timeout) = config.happy_eyeballs_timeout else {
            return ConnectingTcp {
                preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout, connector),
                fallback: None,
            };
        };

        let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference(
            config.socket_bind.ipv4_address,
            config.socket_bind.ipv6_address,
        );

        // Only pay for the extra connector clone and delay timer when there is
        // actually something to fall back to.
        let fallback = if fallback_addrs.is_empty() {
            None
        } else {
            Some(ConnectingTcpFallback {
                delay: connector.sleep(fallback_timeout),
                remote: ConnectingTcpRemote::new(
                    fallback_addrs,
                    connect_timeout,
                    connector.clone(),
                ),
            })
        };

        ConnectingTcp {
            preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout, connector),
            fallback,
        }
    }
}
impl<S: TcpConnector> ConnectingTcpRemote<S>
where
S::TcpStream: From<socket2::Socket>,
{
fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>, connector: S) -> Self {
let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
Self {
addrs,
connect_timeout,
connector,
}
}
async fn connect(&mut self, config: &TcpOptions) -> Result<S::Connection, ConnectError> {
let mut err = None;
for addr in &mut self.addrs {
debug!("connecting to {}", addr);
match connect(&addr, config, self.connect_timeout, &self.connector) {
Ok(fut) => match fut.await {
Ok(tcp) => {
debug!("connected to {}", addr);
return Ok(tcp);
}
Err(mut e) => {
trace!("connect error for {}: {:?}", addr, e);
e.addr = Some(addr);
if err.is_none() {
err = Some(e);
}
}
},
Err(mut e) => {
trace!("connect error for {}: {:?}", addr, e);
e.addr = Some(addr);
if err.is_none() {
err = Some(e);
}
}
}
}
match err {
Some(e) => Err(e),
None => Err(ConnectError::new(
"tcp connect error",
std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
)),
}
}
}
/// Binds `socket` to a local address suitable for connecting to `dst_addr`.
///
/// A local IP may be configured for the destination's address family
/// (`local_addr_ipv4` for IPv4 destinations, `local_addr_ipv6` for IPv6). If
/// so, the socket is bound to that IP with port 0, so the OS chooses the port.
/// Otherwise the socket is left unbound, except on Windows, where it is bound
/// to the unspecified address of the matching family.
///
/// # Errors
/// Returns the error from `bind` if binding fails.
fn bind_local_address(
    socket: &socket2::Socket,
    dst_addr: &SocketAddr,
    local_addr_ipv4: &Option<Ipv4Addr>,
    local_addr_ipv6: &Option<Ipv6Addr>,
) -> io::Result<()> {
    // Only a configured IP of the same family as the destination is usable.
    let configured = match dst_addr {
        SocketAddr::V4(_) => local_addr_ipv4.map(|ip| SocketAddr::new(ip.into(), 0)),
        SocketAddr::V6(_) => local_addr_ipv6.map(|ip| SocketAddr::new(ip.into(), 0)),
    };

    let local = match configured {
        Some(local) => local,
        // NOTE(review): Windows presumably needs the socket bound before
        // connecting, so bind to the wildcard address — confirm.
        None if cfg!(windows) => match dst_addr {
            SocketAddr::V4(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
            SocketAddr::V6(_) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
        },
        None => return Ok(()),
    };

    socket.bind(&local.into())
}
fn connect<S: TcpConnector>(
addr: &SocketAddr,
config: &TcpOptions,
connect_timeout: Option<Duration>,
connector: &S,
) -> Result<impl Future<Output = Result<S::Connection, ConnectError>>, ConnectError>
where
S::TcpStream: From<socket2::Socket>,
{
use socket2::{Domain, Protocol, Socket, Type};
let domain = Domain::for_address(*addr);
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.map_err(ConnectError::m("tcp open error"))?;
socket
.set_nonblocking(true)
.map_err(ConnectError::m("tcp set_nonblocking error"))?;
if let Some(tcp_keepalive) = &config.tcp_keepalive.into_tcpkeepalive() {
if let Err(_e) = socket.set_tcp_keepalive(tcp_keepalive) {
warn!("tcp set_keepalive error: {_e}");
}
}
#[cfg(any(
target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "solaris",
target_os = "tvos",
target_os = "visionos",
target_os = "watchos",
))]
if let Some(interface) = &config.socket_bind.interface {
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
socket
.bind_device(Some(interface.as_bytes()))
.map_err(ConnectError::m("tcp bind interface error"))?;
#[cfg(any(
target_os = "illumos",
target_os = "ios",
target_os = "macos",
target_os = "solaris",
target_os = "tvos",
target_os = "visionos",
target_os = "watchos",
))]
if let Ok(interface) = std::ffi::CString::new(interface.as_bytes()) {
#[allow(unsafe_code)]
let idx = unsafe { libc::if_nametoindex(interface.as_ptr()) };
let idx = std::num::NonZeroU32::new(idx).ok_or_else(|| {
ConnectError::new(
"error converting interface name to index",
io::Error::last_os_error(),
)
})?;
match addr {
SocketAddr::V4(_) => socket.bind_device_by_index_v4(Some(idx)),
SocketAddr::V6(_) => socket.bind_device_by_index_v6(Some(idx)),
}
.map_err(ConnectError::m("tcp bind interface error"))?;
}
}
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(tcp_user_timeout) = &config.tcp_user_timeout {
if let Err(_e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) {
warn!("tcp set_tcp_user_timeout error: {_e}");
}
}
bind_local_address(
&socket,
addr,
&config.socket_bind.ipv4_address,
&config.socket_bind.ipv6_address,
)
.map_err(ConnectError::m("tcp bind local error"))?;
if config.reuse_address {
if let Err(_e) = socket.set_reuse_address(true) {
warn!("tcp set_reuse_address error: {_e}");
}
}
if let Some(size) = config.send_buffer_size {
if let Err(_e) = socket.set_send_buffer_size(size) {
warn!("tcp set_buffer_size error: {_e}");
}
}
if let Some(size) = config.recv_buffer_size {
if let Err(_e) = socket.set_recv_buffer_size(size) {
warn!("tcp set_recv_buffer_size error: {_e}");
}
}
if let Err(_e) = socket.set_tcp_nodelay(config.nodelay) {
warn!("tcp set_tcp_nodelay error: {_e}");
}
let connect = connector.connect(socket.into(), *addr);
let sleep = connect_timeout.map(|dur| connector.sleep(dur));
Ok(async move {
match sleep {
Some(sleep) => match futures_util::future::select(pin!(sleep), pin!(connect)).await {
Either::Left(((), _)) => {
Err(io::Error::new(io::ErrorKind::TimedOut, "connect timeout").into())
}
Either::Right((Ok(s), _)) => Ok(s),
Either::Right((Err(e), _)) => Err(e.into()),
},
None => connect.await.map_err(Into::into),
}
.map_err(ConnectError::m("tcp connect error"))
})
}
impl<S: TcpConnector> ConnectingTcp<S>
where
    S::TcpStream: From<socket2::Socket>,
{
    /// Runs the connection plan and returns the first successful connection.
    ///
    /// Without a fallback, the preferred addresses are tried in order. With a
    /// fallback, the preferred attempt runs alone until it finishes or the
    /// fallback delay elapses:
    /// - If the delay elapses first, both attempts are polled concurrently.
    /// - If the attempt that finishes first succeeds, its connection is returned.
    /// - If it fails, the other attempt's result is awaited and returned
    ///   instead, which may itself be an error.
    pub(crate) async fn connect(
        mut self,
        config: &TcpOptions,
    ) -> Result<S::Connection, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(config).await,
            Some(mut fallback) => {
                // Both attempts are created here, but the fallback makes no
                // progress until it is first polled.
                let preferred_fut = pin!(self.preferred.connect(config));
                let fallback_fut = pin!(fallback.remote.connect(config));
                let fallback_delay = pin!(fallback.delay);
                // `result` is the attempt that finished first; `future` is the
                // attempt still outstanding. NOTE: mixing `Either::Right(fallback_fut)`
                // with `factor_first()` type-checks only because both pinned
                // futures come from `ConnectingTcpRemote::connect` and so share
                // one type.
                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        // Preferred finished before the delay: the fallback has
                        // not been polled yet and becomes the remaining attempt.
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        // Delay elapsed first: race preferred against fallback.
                        Either::Right(((), preferred_fut)) => {
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };
                // Only fall back to the remaining attempt when the first one failed.
                if result.is_err() {
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
/// Error produced while establishing a TCP connection.
///
/// Fields:
/// - `msg`: a static description of the step that failed, e.g. "tcp open error".
/// - `addr`: the remote address the failure is attributed to, once one has
///   been assigned.
/// - `cause`: the underlying error, exposed through [`StdError::source`].
///
/// `Display` prints only `msg`; `Debug` also includes `addr` and `cause` when present.
pub struct ConnectError {
    pub(crate) msg: &'static str,
    pub(crate) addr: Option<SocketAddr>,
    pub(crate) cause: Option<BoxError>,
}

impl ConnectError {
    /// Creates an error described by `msg` wrapping `cause`, with no address yet.
    pub(crate) fn new<E>(msg: &'static str, cause: E) -> ConnectError
    where
        E: Into<BoxError>,
    {
        let cause: BoxError = cause.into();
        ConnectError {
            addr: None,
            cause: Some(cause),
            msg,
        }
    }

    /// Creates an error for a failed name resolution ("dns error").
    pub(crate) fn dns<E>(cause: E) -> ConnectError
    where
        E: Into<BoxError>,
    {
        Self::new("dns error", cause)
    }

    /// Returns a closure that wraps its argument with `msg`; intended for `map_err`.
    pub(crate) fn m<E>(msg: &'static str) -> impl FnOnce(E) -> ConnectError
    where
        E: Into<BoxError>,
    {
        move |cause: E| Self::new(msg, cause)
    }
}

impl fmt::Debug for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("ConnectError");
        tuple.field(&self.msg);
        if let Some(addr) = self.addr.as_ref() {
            tuple.field(addr);
        }
        if let Some(cause) = self.cause.as_ref() {
            tuple.field(cause);
        }
        tuple.finish()
    }
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.msg)
    }
}

impl StdError for ConnectError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match &self.cause {
            Some(cause) => Some(&**cause as &(dyn StdError + 'static)),
            None => None,
        }
    }
}
/// Socket and connection settings used when establishing TCP connections.
#[derive(Clone)]
pub(crate) struct TcpOptions {
    /// NOTE(review): not read in this file; presumably checked by the caller
    /// to only allow `http` URIs — confirm.
    pub enforce_http: bool,
    /// Total connect timeout per address group (preferred/fallback). It is
    /// split evenly across that group's addresses; `None` disables it.
    pub connect_timeout: Option<Duration>,
    /// Delay before fallback addresses start racing the preferred ones.
    /// `None` disables the fallback, and all addresses are tried in order.
    pub happy_eyeballs_timeout: Option<Duration>,
    /// Value passed to `set_tcp_nodelay` (best-effort).
    pub nodelay: bool,
    /// Whether to enable address reuse on the socket (best-effort).
    pub reuse_address: bool,
    /// Send buffer size to request, if any (best-effort).
    pub send_buffer_size: Option<usize>,
    /// Receive buffer size to request, if any (best-effort).
    pub recv_buffer_size: Option<usize>,
    /// TCP user timeout to set, if any (Android/Fuchsia/Linux only; best-effort).
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    pub tcp_user_timeout: Option<Duration>,
    /// Keepalive settings; applied only if at least one knob is set.
    pub tcp_keepalive: TcpKeepaliveOptions,
    /// Interface and local-address binding settings.
    pub socket_bind: SocketBindOptions,
}
/// TCP keepalive knobs. Every field defaults to `None`, which leaves keepalive
/// unconfigured (see `into_tcpkeepalive`).
#[derive(Default, Debug, Clone, Copy)]
pub(crate) struct TcpKeepaliveOptions {
    /// Value for `TcpKeepalive::with_time`, if set.
    pub time: Option<Duration>,
    /// Value for `TcpKeepalive::with_interval`, if set (only on the listed platforms).
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "ios",
        target_os = "visionos",
        target_os = "linux",
        target_os = "macos",
        target_os = "netbsd",
        target_os = "tvos",
        target_os = "watchos",
        target_os = "windows",
        target_os = "cygwin",
    ))]
    pub interval: Option<Duration>,
    /// Value for `TcpKeepalive::with_retries`, if set (only on the listed platforms).
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "ios",
        target_os = "visionos",
        target_os = "linux",
        target_os = "macos",
        target_os = "netbsd",
        target_os = "tvos",
        target_os = "watchos",
        target_os = "cygwin",
        target_os = "windows",
    ))]
    pub retries: Option<u32>,
}
impl TcpKeepaliveOptions {
    /// Converts these options into a socket2 `TcpKeepalive`.
    ///
    /// Returns `None` when no knob available on this platform is set, so the
    /// caller can skip configuring keepalive entirely. Otherwise returns a
    /// builder with every configured knob applied.
    pub(crate) fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
        // Stays `None` until the first knob is seen; each later knob is layered
        // onto the same builder.
        let mut keepalive: Option<TcpKeepalive> = None;

        if let Some(time) = self.time {
            keepalive = Some(keepalive.unwrap_or_else(TcpKeepalive::new).with_time(time));
        }

        #[cfg(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "fuchsia",
            target_os = "illumos",
            target_os = "ios",
            target_os = "visionos",
            target_os = "linux",
            target_os = "macos",
            target_os = "netbsd",
            target_os = "tvos",
            target_os = "watchos",
            target_os = "windows",
            target_os = "cygwin",
        ))]
        if let Some(interval) = self.interval {
            keepalive = Some(
                keepalive
                    .unwrap_or_else(TcpKeepalive::new)
                    .with_interval(interval),
            );
        }

        #[cfg(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "fuchsia",
            target_os = "illumos",
            target_os = "ios",
            target_os = "visionos",
            target_os = "linux",
            target_os = "macos",
            target_os = "netbsd",
            target_os = "tvos",
            target_os = "watchos",
            target_os = "cygwin",
            target_os = "windows",
        ))]
        if let Some(retries) = self.retries {
            keepalive = Some(
                keepalive
                    .unwrap_or_else(TcpKeepalive::new)
                    .with_retries(retries),
            );
        }

        keepalive
    }
}