#[cfg(any(feature = "native-tls", feature = "rustls"))]
use crate::client::TlsConfig;
use crate::{Error, Result, client::Config};
use futures_util::{Future, FutureExt};
use log::{debug, info};
use socket2::TcpKeepalive;
#[cfg(feature = "tokio-runtime")]
use std::sync::Arc;
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
// Runtime-specific stream-half type aliases. Each alias is defined once per enabled
// runtime/TLS feature; enabling two features that define the same alias would be a
// duplicate-definition compile error.
/// Read half of a plain tokio TCP stream.
#[cfg(feature = "tokio-runtime")]
pub(crate) type TcpStreamReader = tokio::io::ReadHalf<tokio::net::TcpStream>;
/// Write half of a plain tokio TCP stream.
#[cfg(feature = "tokio-runtime")]
pub(crate) type TcpStreamWriter = tokio::io::WriteHalf<tokio::net::TcpStream>;
/// Read half of a tokio-rustls client TLS stream over tokio TCP.
#[cfg(feature = "tokio-rustls")]
pub(crate) type TcpTlsStreamReader =
tokio::io::ReadHalf<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>;
/// Write half of a tokio-rustls client TLS stream over tokio TCP.
#[cfg(feature = "tokio-rustls")]
pub(crate) type TcpTlsStreamWriter =
tokio::io::WriteHalf<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>;
/// Read half of a tokio-native-tls TLS stream over tokio TCP.
#[cfg(feature = "tokio-native-tls")]
pub(crate) type TcpTlsStreamReader =
tokio::io::ReadHalf<tokio_native_tls::TlsStream<tokio::net::TcpStream>>;
/// Write half of a tokio-native-tls TLS stream over tokio TCP.
#[cfg(feature = "tokio-native-tls")]
pub(crate) type TcpTlsStreamWriter =
tokio::io::WriteHalf<tokio_native_tls::TlsStream<tokio::net::TcpStream>>;
/// Read half of a plain async-std TCP stream, wrapped in `tokio_util::compat::Compat`
/// (futures-io to tokio-io adapter).
#[cfg(feature = "async-std-runtime")]
pub(crate) type TcpStreamReader =
tokio_util::compat::Compat<futures_util::io::ReadHalf<async_std::net::TcpStream>>;
/// Write half of a plain async-std TCP stream, wrapped in `tokio_util::compat::Compat`.
#[cfg(feature = "async-std-runtime")]
pub(crate) type TcpStreamWriter =
tokio_util::compat::Compat<futures_util::io::WriteHalf<async_std::net::TcpStream>>;
/// Read half of an async-native-tls TLS stream over async-std TCP, wrapped in `Compat`.
#[cfg(feature = "async-std-native-tls")]
pub(crate) type TcpTlsStreamReader = tokio_util::compat::Compat<
futures_util::io::ReadHalf<async_native_tls::TlsStream<async_std::net::TcpStream>>,
>;
/// Write half of an async-native-tls TLS stream over async-std TCP, wrapped in `Compat`.
#[cfg(feature = "async-std-native-tls")]
pub(crate) type TcpTlsStreamWriter = tokio_util::compat::Compat<
futures_util::io::WriteHalf<async_native_tls::TlsStream<async_std::net::TcpStream>>,
>;
/// Opens a plain TCP connection to `host:port` and splits it into read/write halves.
///
/// Both runtime backends apply the same `config` options:
/// - the connection attempt is bounded by `config.connect_timeout`;
/// - TCP keep-alive is enabled only when `config.keep_alive` is `Some`, using that duration;
/// - `TCP_NODELAY` is set when `config.no_delay` is `true`.
///
/// # Errors
///
/// Returns [`Error::Timeout`] if the connection is not established within
/// `config.connect_timeout`. Any I/O error from address resolution, socket creation,
/// connecting, or setting socket options is propagated.
pub(crate) async fn tcp_connect(
    host: &str,
    port: u16,
    config: &Config,
) -> Result<(TcpStreamReader, TcpStreamWriter)> {
    debug!(
        "Connecting to {host}:{port} with timeout {:?}...",
        config.connect_timeout
    );
    // Assigned by the runtime-specific block below that is compiled in.
    let reader: TcpStreamReader;
    let writer: TcpStreamWriter;
    #[cfg(feature = "tokio-runtime")]
    {
        let stream = timeout(
            config.connect_timeout,
            tokio::net::TcpStream::connect((host, port)),
        )
        .await??;
        if let Some(keep_alive) = config.keep_alive {
            socket2::SockRef::from(&stream)
                .set_tcp_keepalive(&TcpKeepalive::new().with_time(keep_alive))?;
        }
        if config.no_delay {
            stream.set_nodelay(true)?;
        }
        (reader, writer) = tokio::io::split(stream);
    }
    #[cfg(feature = "async-std-runtime")]
    {
        use async_std::net::TcpStream;
        use futures_util::AsyncReadExt;
        use socket2::{Domain, Protocol, Socket, Type};
        use std::net::{SocketAddr, ToSocketAddrs};
        use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};

        /// Resolves `host:port` and returns the first address found (no fallback to
        /// later addresses).
        fn resolve_address(host: &str, port: u16) -> std::io::Result<SocketAddr> {
            let mut addrs_iter = (host, port).to_socket_addrs()?;
            addrs_iter
                .next()
                .ok_or_else(|| std::io::Error::other("No address found"))
        }

        // NOTE(review): name resolution and the socket2 connect below are blocking calls
        // made from async code; the connect is bounded by `connect_timeout`, but it still
        // occupies the executor thread while it waits.
        let addr = resolve_address(host, port)?;
        let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
        // Honour the configured keep-alive exactly like the tokio branch, instead of
        // unconditionally forcing a hard-coded 60 s keep-alive.
        if let Some(keep_alive) = config.keep_alive {
            socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(keep_alive))?;
        }
        // Bound the connect by the configured timeout and report expiry as
        // `Error::Timeout`, matching what the tokio branch returns via `timeout`.
        socket
            .connect_timeout(&addr.into(), config.connect_timeout)
            .map_err(|e| {
                if e.kind() == std::io::ErrorKind::TimedOut {
                    Error::Timeout
                } else {
                    Error::from(e)
                }
            })?;
        let std_stream: std::net::TcpStream = socket.into();
        let stream = TcpStream::from(std_stream);
        if config.no_delay {
            stream.set_nodelay(true)?;
        }
        let (r, w) = stream.split();
        reader = r.compat();
        writer = w.compat_write();
    }
    info!("Connected to {host}:{port}");
    Ok((reader, writer))
}
/// Opens a TLS connection to `host:port` and splits the TLS stream into read/write halves.
///
/// The TLS backend is chosen at compile time: `tokio-rustls` or `tokio-native-tls` under
/// the tokio runtime, `async-std-native-tls` under async-std. `host` is passed to the
/// handshake as the server name.
///
/// Only the TCP connect is bounded by `connect_timeout` (via [`timeout`]); the TLS
/// handshake that follows is awaited without a deadline here.
///
/// # Errors
///
/// Returns an error if the TCP connect times out or fails, if the native-tls connector
/// cannot be built, if `host` cannot be converted to a rustls server name, or if the TLS
/// handshake fails.
#[cfg(any(feature = "native-tls", feature = "rustls"))]
pub(crate) async fn tcp_tls_connect(
host: &str,
port: u16,
tls_config: &TlsConfig,
connect_timeout: Duration,
) -> Result<(TcpTlsStreamReader, TcpTlsStreamWriter)> {
debug!("Connecting to {host}:{port} with timeout {connect_timeout:?}...");
// Assigned by the backend-specific block below that is compiled in.
let reader: TcpTlsStreamReader;
let writer: TcpTlsStreamWriter;
// tokio + rustls: time-bounded TCP connect, then a rustls client handshake.
#[cfg(feature = "tokio-runtime")]
#[cfg(feature = "tokio-rustls")]
{
let stream = timeout(
connect_timeout,
tokio::net::TcpStream::connect((host, port)),
)
.await??;
let tls_connector = tokio_rustls::TlsConnector::from(tls_config.rustls_config.clone());
// Convert the host into the rustls server-name type; fails for invalid names.
let server_name = host.to_owned().try_into()?;
let tls_stream = tls_connector.connect(server_name, stream).await?;
(reader, writer) = tokio::io::split(tls_stream);
}
// tokio + native-tls: builder from `tls_config`, time-bounded TCP connect, handshake.
#[cfg(feature = "tokio-runtime")]
#[cfg(feature = "tokio-native-tls")]
{
let builder = tls_config.into_tls_connector_builder();
let stream = timeout(
connect_timeout,
tokio::net::TcpStream::connect((host, port)),
)
.await??;
let tls_connector: native_tls::TlsConnector = builder.build()?;
let tls_connector = tokio_native_tls::TlsConnector::from(tls_connector);
let tls_stream = tls_connector.connect(host, stream).await?;
(reader, writer) = tokio::io::split(tls_stream);
}
// async-std + native-tls: halves are wrapped with `compat()` / `compat_write()` to
// produce the `Compat` alias types.
#[cfg(feature = "async-std-runtime")]
#[cfg(feature = "async-std-native-tls")]
{
use futures_util::AsyncReadExt;
use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};
let stream = timeout(
connect_timeout,
async_std::net::TcpStream::connect((host, port)),
)
.await??;
let builder = tls_config.into_tls_connector_builder();
let tls_connector: async_native_tls::TlsConnector = builder.into();
let tls_stream = tls_connector.connect(host, stream).await?;
let (r, w) = tls_stream.split();
reader = r.compat();
writer = w.compat_write();
}
info!("Connected to {host}:{port}");
Ok((reader, writer))
}
/// Handle to a task spawned with [`spawn`] on whichever async runtime is enabled.
///
/// Awaiting it yields the task's output wrapped in this crate's [`Result`].
pub enum JoinHandle<T> {
/// Handle returned by `tokio::spawn`.
#[cfg(feature = "tokio-runtime")]
Tokio(tokio::task::JoinHandle<T>),
/// Handle returned by `async_std::task::spawn`.
#[cfg(feature = "async-std-runtime")]
AsyncStd(async_std::task::JoinHandle<T>),
}
impl<T> Future for JoinHandle<T> {
    type Output = Result<T>;

    /// Polls the wrapped runtime handle.
    ///
    /// A tokio join error is converted into [`Error::TokioJoin`]; the async-std handle
    /// yields the output directly, so it is always wrapped in `Ok`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match this {
            #[cfg(feature = "tokio-runtime")]
            JoinHandle::Tokio(handle) => handle
                .poll_unpin(cx)
                .map(|joined| joined.map_err(|err| Error::TokioJoin(Arc::new(err)))),
            #[cfg(feature = "async-std-runtime")]
            JoinHandle::AsyncStd(handle) => handle.poll_unpin(cx).map(Ok),
        }
    }
}
/// Spawns `future` onto the enabled async runtime and returns a [`JoinHandle`] for it.
pub(crate) fn spawn<F, T>(future: F) -> JoinHandle<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    #[cfg(feature = "tokio-runtime")]
    {
        let task = tokio::spawn(future);
        return JoinHandle::Tokio(task);
    }
    #[cfg(feature = "async-std-runtime")]
    {
        let task = async_std::task::spawn(future);
        return JoinHandle::AsyncStd(task);
    }
}
/// Suspends the current task for `duration` using the enabled runtime's timer.
// `dead_code` is allowed so that builds which never call this helper stay warning-free.
#[allow(dead_code)]
pub(crate) async fn sleep(duration: Duration) {
    #[cfg(feature = "tokio-runtime")]
    {
        let delay = tokio::time::sleep(duration);
        delay.await;
    }
    #[cfg(feature = "async-std-runtime")]
    {
        let delay = async_std::task::sleep(duration);
        delay.await;
    }
}
/// Awaits `future`, failing with [`Error::Timeout`] if it does not complete within
/// `timeout`.
///
/// Under async-std, a `timeout` of [`Duration::MAX`] disables the deadline and the future
/// is simply awaited.
// `dead_code` is allowed so that builds which never call this helper stay warning-free.
#[allow(dead_code)]
pub(crate) async fn timeout<F: Future>(timeout: Duration, future: F) -> Result<F::Output> {
    #[cfg(feature = "tokio-runtime")]
    {
        match tokio::time::timeout(timeout, future).await {
            Ok(output) => Ok(output),
            Err(_elapsed) => Err(Error::Timeout),
        }
    }
    #[cfg(feature = "async-std-runtime")]
    {
        // NOTE(review): `Duration::MAX` is special-cased, presumably because async-std's
        // timer cannot represent a deadline that far in the future — confirm.
        if timeout != Duration::MAX {
            return async_std::future::timeout(timeout, future)
                .await
                .map_err(|_| Error::Timeout);
        }
        Ok(future.await)
    }
}