use std::{
net::SocketAddr,
ops::DerefMut,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use crate::{
error::{Error, ErrorKind, Result},
options::{ServerAddress, Socks5Proxy},
runtime,
};
use super::{
tls::{tls_connect, TlsStream},
TlsConfig,
};
/// Default timeout for establishing a connection (10 seconds).
///
/// NOTE(review): not referenced in this section; presumably applied by callers when no
/// explicit connect timeout is configured — confirm at the call sites.
pub(crate) const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Idle time after which TCP keepalive probes start being sent; applied to every socket
/// opened by `tcp_try_connect`. Not compiled on WASI, where keepalive is not configured.
#[cfg(not(target_os = "wasi"))]
const KEEPALIVE_TIME: Duration = Duration::from_secs(120);
// The TLS-wrapping variants are presumably much larger than the others; the size
// disparity is accepted here instead of boxing them.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
/// A connection to a server over one of the supported transports.
///
/// Reads and writes are forwarded to the wrapped transport by this type's
/// `AsyncRead` / `AsyncWrite` impls.
pub(crate) enum AsyncStream {
    /// No underlying transport: reads complete immediately without data and writes
    /// report zero bytes written.
    Null,
    /// Plain TCP connection.
    Tcp(TcpStream),
    /// TLS over a direct TCP connection.
    Tls(TlsStream<TcpStream>),
    /// Unix domain socket (Unix platforms only).
    #[cfg(unix)]
    Unix(tokio::net::UnixStream),
    /// TCP connection tunnelled through a SOCKS5 proxy.
    #[cfg(feature = "socks5-proxy")]
    Socks5(fast_socks5::client::Socks5Stream<TcpStream>),
    /// TLS over a SOCKS5-tunnelled TCP connection.
    #[cfg(feature = "socks5-proxy")]
    Socks5Tls(TlsStream<fast_socks5::client::Socks5Stream<TcpStream>>),
}
#[cfg(feature = "socks5-proxy")]
impl Socks5Proxy {
    /// Opens a SOCKS5 tunnel through this proxy to `host:port`.
    ///
    /// The proxy is reached at `self.host` on `self.port`, or on port 1080 when no proxy
    /// port is set. The target port falls back to `DEFAULT_PORT` when `port` is `None`.
    /// When `self.authentication` holds a `(username, password)` pair, the handshake uses
    /// username/password authentication; otherwise no credentials are sent.
    ///
    /// # Errors
    ///
    /// An I/O failure inside the SOCKS client is reported as `ErrorKind::Io`; every other
    /// SOCKS failure becomes `ErrorKind::ProxyConnect` carrying the failure's message.
    async fn connect(
        &self,
        host: String,
        port: Option<u16>,
    ) -> Result<fast_socks5::client::Socks5Stream<TcpStream>> {
        use crate::options::DEFAULT_PORT;
        use fast_socks5::{
            client::{Config, Socks5Stream},
            SocksError,
        };

        let proxy_port = self.port.unwrap_or(1080);
        let proxy_address = format!("{}:{}", self.host, proxy_port);
        let target_port = port.unwrap_or(DEFAULT_PORT);

        let outcome = match self.authentication.as_ref() {
            Some((username, password)) => {
                Socks5Stream::connect_with_password(
                    proxy_address,
                    host,
                    target_port,
                    username.clone(),
                    password.clone(),
                    Config::default(),
                )
                .await
            }
            None => {
                Socks5Stream::connect(proxy_address, host, target_port, Config::default()).await
            }
        };

        let tunnel = outcome.map_err(|failure| match failure {
            SocksError::Io(io_error) => ErrorKind::Io(std::sync::Arc::new(io_error)),
            other => ErrorKind::ProxyConnect {
                message: other.to_string(),
            },
        })?;
        Ok(tunnel)
    }
}
impl AsyncStream {
    /// Opens a stream to `address`, optionally wrapped in TLS when `tls_cfg` is `Some`.
    ///
    /// For a TCP address:
    /// - with the `socks5-proxy` feature enabled and `proxy` set, the connection is
    ///   tunnelled through the proxy (no local DNS resolution is performed) and yields
    ///   `Socks5` or `Socks5Tls`;
    /// - otherwise the address is resolved locally, every result is tried by
    ///   `tcp_connect`, and the outcome is `Tcp` or `Tls`.
    ///
    /// For a Unix socket address (Unix only) the socket at `path` is connected directly.
    ///
    /// # Errors
    ///
    /// `ErrorKind::DnsResolve` when resolution yields no addresses; otherwise whatever
    /// resolution, TCP, TLS, proxy, or Unix-socket connection reports. Failures on the
    /// direct TCP/TLS path are passed through `Error::with_backpressure_labels`.
    /// NOTE(review): proxy-path failures are not relabelled — confirm this is intended.
    pub(crate) async fn connect(
        address: ServerAddress,
        tls_cfg: Option<&TlsConfig>,
        #[allow(unused)] proxy: Option<&Socks5Proxy>,
    ) -> Result<Self> {
        match &address {
            // `port` is only read on the proxy path, which may be compiled out.
            #[allow(unused)]
            ServerAddress::Tcp { host, port } => {
                #[cfg(feature = "socks5-proxy")]
                if let Some(socks) = proxy {
                    let tunnel = socks.connect(host.clone(), *port).await?;
                    let stream = if let Some(cfg) = tls_cfg {
                        AsyncStream::Socks5Tls(tls_connect(host, tunnel, cfg).await?)
                    } else {
                        AsyncStream::Socks5(tunnel)
                    };
                    return Ok(stream);
                }

                let candidates: Vec<_> = runtime::resolve_address(&address).await?.collect();
                if candidates.is_empty() {
                    return Err(ErrorKind::DnsResolve {
                        message: format!("No DNS results for domain {address}"),
                    }
                    .into());
                }

                let tcp_stream = tcp_connect(candidates)
                    .await
                    .map_err(Error::with_backpressure_labels)?;
                if let Some(cfg) = tls_cfg {
                    let secured = tls_connect(host, tcp_stream, cfg)
                        .await
                        .map_err(Error::with_backpressure_labels)?;
                    return Ok(AsyncStream::Tls(secured));
                }
                Ok(AsyncStream::Tcp(tcp_stream))
            }
            #[cfg(unix)]
            ServerAddress::Unix { path } => {
                let unix_stream = tokio::net::UnixStream::connect(path.as_path()).await?;
                Ok(AsyncStream::Unix(unix_stream))
            }
        }
    }
}
/// Makes a single TCP connection attempt to `address` and tunes the resulting socket.
///
/// Nagle's algorithm is disabled (`TCP_NODELAY`). Except on WASI, TCP keepalive is also
/// enabled, with `KEEPALIVE_TIME` as the idle time before probes start.
///
/// # Errors
///
/// Any failure to connect or to apply either socket option.
async fn tcp_try_connect(address: &SocketAddr) -> Result<TcpStream> {
    let socket = TcpStream::connect(address).await?;
    socket.set_nodelay(true)?;

    #[cfg(not(target_os = "wasi"))]
    {
        let keepalive = socket2::TcpKeepalive::new().with_time(KEEPALIVE_TIME);
        socket2::SockRef::from(&socket).set_tcp_keepalive(&keepalive)?;
    }

    Ok(socket)
}
/// Connects to the first reachable address in `resolved` using staggered, racing
/// attempts in the style of RFC 8305 ("Happy Eyeballs").
///
/// Addresses are split into IPv6 and IPv4 groups and interleaved, starting with IPv6.
/// One connection task is spawned per address. The next address is started either
/// after `CONNECTION_ATTEMPT_DELAY` passes without a success, or immediately once every
/// outstanding attempt has failed. The first successful stream is returned. Dropping
/// the `JoinSet` on return aborts any attempts still in flight.
///
/// # Errors
///
/// If every attempt fails, the first failure observed is returned. If no attempt was
/// made at all (e.g. `resolved` is empty), an internal error is returned.
pub(crate) async fn tcp_connect(resolved: Vec<SocketAddr>) -> Result<TcpStream> {
    // Head start each attempt gets before the next address is tried; 250 ms is the
    // default "Connection Attempt Delay" recommended by RFC 8305.
    const CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(250);

    // Flattens a task's join result: a task that panicked or was cancelled becomes an
    // internal error instead of propagating the `JoinError`.
    fn handle_join(
        result: std::result::Result<Result<TcpStream>, tokio::task::JoinError>,
    ) -> Result<TcpStream> {
        match result {
            Ok(r) => r,
            Err(e) => Err(Error::internal(format!("TCP connect task failure: {e}"))),
        }
    }

    let (addrs_v6, addrs_v4): (Vec<_>, Vec<_>) = resolved
        .into_iter()
        .partition(|a| matches!(a, SocketAddr::V6(_)));
    let socket_addrs = interleave(addrs_v6, addrs_v4);

    let mut attempts = tokio::task::JoinSet::new();
    // Only the first failure is kept; it is what gets reported if nothing succeeds.
    let mut connect_error = None;

    'spawn: for a in socket_addrs {
        attempts.spawn(async move { tcp_try_connect(&a).await });
        let sleep = tokio::time::sleep(CONNECTION_ATTEMPT_DELAY);
        tokio::pin!(sleep);
        // Wait for an outstanding attempt to finish, but stop waiting once the delay
        // elapses so the next address gets its turn. If every outstanding attempt has
        // already failed, the set empties and the next address starts immediately.
        while !attempts.is_empty() {
            tokio::select! {
                // Poll completed attempts before the timer, so a finished connection
                // wins over an elapsed delay.
                biased;
                connect_res = attempts.join_next() => {
                    match connect_res.map(handle_join) {
                        // Not expected while the set is non-empty; kept defensively.
                        None => return Err(Error::internal("empty TCP connect task set")),
                        Some(Ok(cnx)) => return Ok(cnx),
                        Some(Err(e)) => {
                            connect_error.get_or_insert(e);
                        }
                    }
                }
                _ = &mut sleep => continue 'spawn,
            }
        }
    }

    // Every address has been started; wait for the remaining attempts to settle.
    while let Some(result) = attempts.join_next().await {
        match handle_join(result) {
            Ok(cnx) => return Ok(cnx),
            Err(e) => {
                connect_error.get_or_insert(e);
            }
        }
    }

    Err(connect_error.unwrap_or_else(|| {
        ErrorKind::Internal {
            message: "connecting to all DNS results failed but no error reported".to_string(),
        }
        .into()
    }))
}
/// Merges two vectors by alternating their elements, starting with `left`.
///
/// When one side runs out, the remainder of the other is appended in order, e.g.
/// `[1, 2, 3]` and `[10]` become `[1, 10, 2, 3]`.
fn interleave<T>(left: Vec<T>, right: Vec<T>) -> Vec<T> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    let mut lhs = left.into_iter();
    let mut rhs = right.into_iter();
    loop {
        match lhs.next() {
            Some(item) => merged.push(item),
            None => {
                merged.extend(rhs);
                break;
            }
        }
        match rhs.next() {
            Some(item) => merged.push(item),
            None => {
                merged.extend(lhs);
                break;
            }
        }
    }
    merged
}
impl AsyncRead for AsyncStream {
    /// Forwards the read to whichever transport this stream wraps.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.deref_mut();
        match this {
            Self::Tcp(tcp) => AsyncRead::poll_read(Pin::new(tcp), cx, buf),
            Self::Tls(tls) => AsyncRead::poll_read(Pin::new(tls), cx, buf),
            #[cfg(unix)]
            Self::Unix(unix) => AsyncRead::poll_read(Pin::new(unix), cx, buf),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5(socks) => AsyncRead::poll_read(Pin::new(socks), cx, buf),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5Tls(socks_tls) => AsyncRead::poll_read(Pin::new(socks_tls), cx, buf),
            // A null stream has nothing to read: completing without filling `buf`
            // signals end-of-stream to the caller.
            Self::Null => Poll::Ready(Ok(())),
        }
    }
}
// Every method forwards to the wrapped transport with one uniform idiom: project the
// pinned `self` through `deref_mut()` (sound because `AsyncStream` is `Unpin`) and
// call the trait method in UFCS form on `Pin::new(inner)`. Bindings rely on default
// binding modes; explicit `ref`/`ref mut` there is rejected by the 2024 edition.
impl AsyncWrite for AsyncStream {
    /// Writes `buf` to the wrapped transport.
    ///
    /// `Null` accepts nothing and reports zero bytes written.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.deref_mut() {
            Self::Null => Poll::Ready(Ok(0)),
            Self::Tcp(inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
            Self::Tls(inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
            #[cfg(unix)]
            Self::Unix(inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5(inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5Tls(inner) => AsyncWrite::poll_write(Pin::new(inner), cx, buf),
        }
    }

    /// Flushes the wrapped transport; `Null` has nothing to flush.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.deref_mut() {
            Self::Null => Poll::Ready(Ok(())),
            Self::Tcp(inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
            Self::Tls(inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
            #[cfg(unix)]
            Self::Unix(inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5(inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5Tls(inner) => AsyncWrite::poll_flush(Pin::new(inner), cx),
        }
    }

    /// Shuts down the write side of the wrapped transport; `Null` completes immediately.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.deref_mut() {
            Self::Null => Poll::Ready(Ok(())),
            Self::Tcp(inner) => AsyncWrite::poll_shutdown(Pin::new(inner), cx),
            Self::Tls(inner) => AsyncWrite::poll_shutdown(Pin::new(inner), cx),
            #[cfg(unix)]
            Self::Unix(inner) => AsyncWrite::poll_shutdown(Pin::new(inner), cx),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5(inner) => AsyncWrite::poll_shutdown(Pin::new(inner), cx),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5Tls(inner) => AsyncWrite::poll_shutdown(Pin::new(inner), cx),
        }
    }

    /// Forwards a vectored write to the wrapped transport; `Null` reports zero bytes.
    ///
    /// `bufs` uses `std::io::IoSlice` directly; `futures_io::IoSlice`, which the trait
    /// signature was previously spelled with, is merely a re-export of it.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        match self.deref_mut() {
            Self::Null => Poll::Ready(Ok(0)),
            Self::Tcp(inner) => AsyncWrite::poll_write_vectored(Pin::new(inner), cx, bufs),
            Self::Tls(inner) => AsyncWrite::poll_write_vectored(Pin::new(inner), cx, bufs),
            #[cfg(unix)]
            Self::Unix(inner) => AsyncWrite::poll_write_vectored(Pin::new(inner), cx, bufs),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5(inner) => AsyncWrite::poll_write_vectored(Pin::new(inner), cx, bufs),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5Tls(inner) => {
                AsyncWrite::poll_write_vectored(Pin::new(inner), cx, bufs)
            }
        }
    }

    /// Reports whether the wrapped transport has an efficient vectored-write path;
    /// always `false` for `Null`.
    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Null => false,
            Self::Tcp(inner) => AsyncWrite::is_write_vectored(inner),
            Self::Tls(inner) => AsyncWrite::is_write_vectored(inner),
            #[cfg(unix)]
            Self::Unix(inner) => AsyncWrite::is_write_vectored(inner),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5(inner) => AsyncWrite::is_write_vectored(inner),
            #[cfg(feature = "socks5-proxy")]
            Self::Socks5Tls(inner) => AsyncWrite::is_write_vectored(inner),
        }
    }
}