use core::future::Future;
use core::pin::Pin;
use std::boxed::Box;
use std::io;
use std::net::SocketAddr;
use std::task::{Context, Poll};
use tokio::io::ReadBuf;
use tokio::net::{TcpStream, UdpSocket};
/// Number of additional attempts made to bind a local UDP socket after a
/// failed bind, before the bind error is handed back to the caller.
const RETRY_RANDOM_PORT: usize = 10;
/// Abstraction over anything that can asynchronously establish a connection.
///
/// Calling [`connect`](AsyncConnect::connect) hands back a future which
/// resolves either to the established connection or to the I/O error that
/// prevented it from being established.
pub trait AsyncConnect {
    /// The connection type produced on success.
    type Connection;

    /// The future returned by [`connect`](AsyncConnect::connect).
    ///
    /// It is required to be both `Send` and `Sync`.
    type Fut: Future<Output = io::Result<Self::Connection>> + Send + Sync;

    /// Starts a new connection attempt and returns the future driving it.
    fn connect(&self) -> Self::Fut;
}
/// Connection factory for TCP streams to a fixed remote address.
#[derive(Debug, Clone, Copy)]
pub struct TcpConnect {
    /// Remote address that connections are opened to.
    addr: SocketAddr,
}

impl TcpConnect {
    /// Creates a new factory that connects to `addr`.
    pub fn new(addr: SocketAddr) -> Self {
        TcpConnect { addr }
    }
}
impl AsyncConnect for TcpConnect {
type Connection = TcpStream;
type Fut = Pin<
Box<
dyn Future<Output = Result<Self::Connection, std::io::Error>>
+ Send
+ Sync,
>,
>;
fn connect(&self) -> Self::Fut {
Box::pin(TcpStream::connect(self.addr))
}
}
/// Connection factory for TLS streams (over TCP) to a fixed remote address.
///
/// Only compiled when the `tokio-rustls` feature is enabled.
#[cfg(feature = "tokio-rustls")]
#[derive(Clone, Debug)]
pub struct TlsConnect {
/// Shared TLS client configuration the connector is built from.
client_config: std::sync::Arc<tokio_rustls::rustls::ClientConfig>,
/// Server name handed to the TLS connector when connecting.
server_name: tokio_rustls::rustls::pki_types::ServerName<'static>,
/// Remote address the underlying TCP connection is opened to.
addr: SocketAddr,
}
#[cfg(feature = "tokio-rustls")]
impl TlsConnect {
    /// Creates a new TLS connection factory.
    ///
    /// `client_config` is converted into a shared `Arc<ClientConfig>`,
    /// `server_name` is the name later handed to the TLS connector, and
    /// `addr` is the remote address the TCP connection is opened to.
    pub fn new<Conf>(
        client_config: Conf,
        server_name: tokio_rustls::rustls::pki_types::ServerName<'static>,
        addr: SocketAddr,
    ) -> Self
    where
        Conf: Into<std::sync::Arc<tokio_rustls::rustls::ClientConfig>>,
    {
        let client_config = client_config.into();
        TlsConnect {
            client_config,
            server_name,
            addr,
        }
    }
}
#[cfg(feature = "tokio-rustls")]
impl AsyncConnect for TlsConnect {
    type Connection = tokio_rustls::client::TlsStream<TcpStream>;
    type Fut = Pin<
        Box<
            dyn Future<Output = Result<Self::Connection, std::io::Error>>
                + Send
                + Sync,
        >,
    >;

    /// Returns a boxed future that first opens a TCP connection to the
    /// stored address and then runs the TLS connector over it using the
    /// stored server name.
    ///
    /// The future resolves to the TLS stream, or to the I/O error reported
    /// by either the TCP connect or the TLS connect step.
    fn connect(&self) -> Self::Fut {
        // Everything the future needs is cloned/copied up front so that the
        // future owns its data and does not borrow from `self`.
        let connector =
            tokio_rustls::TlsConnector::from(self.client_config.clone());
        let server_name = self.server_name.clone();
        let addr = self.addr;
        Box::pin(async move {
            let tcp = TcpStream::connect(addr).await?;
            // The connector is used directly; wrapping it in a `Box` first
            // would only add a heap allocation per connection attempt.
            connector.connect(server_name, tcp).await
        })
    }
}
/// Connection factory for UDP sockets connected to a fixed remote address.
#[derive(Clone, Copy, Debug)]
pub struct UdpConnect {
/// Remote address the socket gets connected to.
addr: SocketAddr,
}
impl UdpConnect {
pub fn new(addr: SocketAddr) -> Self {
Self { addr }
}
async fn bind_and_connect(self) -> Result<UdpSocket, io::Error> {
let mut i = 0;
let sock = loop {
let local: SocketAddr = if self.addr.is_ipv4() {
([0u8; 4], 0).into()
} else {
([0u16; 8], 0).into()
};
match UdpSocket::bind(&local).await {
Ok(sock) => break sock,
Err(err) => {
if i == RETRY_RANDOM_PORT {
return Err(err);
} else {
i += 1
}
}
}
};
sock.connect(self.addr).await?;
Ok(sock)
}
}
impl AsyncConnect for UdpConnect {
type Connection = UdpSocket;
type Fut = Pin<
Box<
dyn Future<Output = Result<Self::Connection, std::io::Error>>
+ Send
+ Sync,
>,
>;
fn connect(&self) -> Self::Fut {
Box::pin(self.bind_and_connect())
}
}
/// Poll-based receiving of a single datagram.
pub trait AsyncDgramRecv {
    /// Attempts to receive a datagram into `buf`.
    ///
    /// Returns `Poll::Pending` when no result is available yet, otherwise
    /// `Poll::Ready` with either `Ok(())` or the I/O error that occurred.
    /// After `Poll::Ready(Ok(()))` the received data is read from the filled
    /// part of `buf`.
    fn poll_recv(
        &self,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>>;
}
/// Receives datagrams on a tokio [`UdpSocket`] by delegating to the
/// socket's own `poll_recv`.
impl AsyncDgramRecv for UdpSocket {
fn poll_recv(
&self,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), io::Error>> {
// The path `UdpSocket::poll_recv` names the socket's inherent method
// (inherent items take precedence over trait methods), so this call
// forwards to tokio rather than recursing into this trait method.
UdpSocket::poll_recv(self, cx, buf)
}
}
/// Extension trait offering a future-based `recv` on top of
/// [`AsyncDgramRecv::poll_recv`].
pub trait AsyncDgramRecvEx: AsyncDgramRecv {
    /// Returns a future that receives one datagram into `buf`.
    ///
    /// The returned [`DgramRecv`] borrows both the receiver and `buf` for
    /// its whole lifetime.
    fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> DgramRecv<'a, Self>
    where
        Self: Unpin,
    {
        // NOTE(review): only a shared reference is stored in the future, so
        // `&mut self` looks stricter than required -- confirm whether the
        // exclusivity is intentional before relaxing it.
        let receiver: &'a Self = self;
        DgramRecv { receiver, buf }
    }
}
/// Blanket implementation: every [`AsyncDgramRecv`] implementor gets the
/// provided methods of [`AsyncDgramRecvEx`].
impl<R: AsyncDgramRecv> AsyncDgramRecvEx for R {}
/// Future returned by [`AsyncDgramRecvEx::recv`].
///
/// When polled it polls the receiver and, on success, resolves to the number
/// of bytes that were filled into the buffer.
pub struct DgramRecv<'a, R: ?Sized> {
/// Receiver that is polled for a datagram.
receiver: &'a R,
/// Buffer the datagram is received into.
buf: &'a mut [u8],
}
impl<R: AsyncDgramRecv + Unpin> Future for DgramRecv<'_, R> {
    type Output = io::Result<usize>;

    /// Polls the receiver once.
    ///
    /// Resolves to the number of bytes filled into the buffer on success,
    /// passes errors through unchanged, and stays pending while the receiver
    /// is pending.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // A fresh `ReadBuf` is wrapped around the caller's buffer on every
        // poll, so the filled length always starts from zero.
        let mut read_buf = ReadBuf::new(this.buf);
        this.receiver
            .poll_recv(cx, &mut read_buf)
            .map_ok(|()| read_buf.filled().len())
    }
}
/// Poll-based sending of a single datagram.
pub trait AsyncDgramSend {
    /// Attempts to send the bytes in `buf` as a datagram.
    ///
    /// Returns `Poll::Pending` when no result is available yet, otherwise
    /// `Poll::Ready` with either the `usize` reported by the implementation
    /// or the I/O error that occurred.
    fn poll_send(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>>;
}
/// Sends datagrams on a tokio [`UdpSocket`] by delegating to the socket's
/// own `poll_send`.
impl AsyncDgramSend for UdpSocket {
fn poll_send(
&self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
// The path `UdpSocket::poll_send` names the socket's inherent method
// (inherent items take precedence over trait methods), so this call
// forwards to tokio rather than recursing into this trait method.
UdpSocket::poll_send(self, cx, buf)
}
}
/// Extension trait offering a future-based `send` on top of
/// [`AsyncDgramSend::poll_send`].
pub trait AsyncDgramSendEx: AsyncDgramSend {
    /// Returns a future that sends the bytes in `buf` as a datagram.
    ///
    /// The returned [`DgramSend`] borrows both the sender and `buf` for its
    /// whole lifetime.
    fn send<'a>(&'a self, buf: &'a [u8]) -> DgramSend<'a, Self>
    where
        Self: Unpin,
    {
        let sender: &'a Self = self;
        DgramSend { sender, buf }
    }
}
/// Blanket implementation: every [`AsyncDgramSend`] implementor gets the
/// provided methods of [`AsyncDgramSendEx`].
impl<S: AsyncDgramSend> AsyncDgramSendEx for S {}
/// Future returned by [`AsyncDgramSendEx::send`].
///
/// When polled it forwards to the sender's `poll_send` with the stored
/// buffer and resolves to whatever that call reports.
pub struct DgramSend<'a, S: ?Sized> {
/// Sender that is polled to send the datagram.
sender: &'a S,
/// Bytes to be sent.
buf: &'a [u8],
}
impl<S: AsyncDgramSend + Unpin> Future for DgramSend<'_, S> {
    type Output = io::Result<usize>;

    /// Polls the sender once with the stored buffer and returns its result
    /// unchanged.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<usize>> {
        // Only shared access to the fields is needed.
        let this = &*self;
        this.sender.poll_send(cx, this.buf)
    }
}