use core::future::Future;
use core::marker::Send;
use core::net::SocketAddr;
use core::pin::Pin;
use core::time::Duration;
#[cfg(feature = "__quic")]
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use std::{
future::poll_fn,
io,
task::{Context, Poll},
};
use async_trait::async_trait;
use futures_io::{AsyncRead, AsyncWrite};
#[cfg(any(test, feature = "tokio"))]
use tokio::runtime::Runtime;
#[cfg(any(test, feature = "tokio"))]
use tokio::task::JoinHandle;
/// Spawns `background` onto the given tokio `runtime`.
///
/// Returns the [`JoinHandle`] through which the future's output `R` can be awaited.
#[cfg(any(test, feature = "tokio"))]
pub fn spawn_bg<F, R>(runtime: &Runtime, background: F) -> JoinHandle<R>
where
    F: Future<Output = R> + Send + 'static,
    R: Send + 'static,
{
    Runtime::spawn(runtime, background)
}
#[cfg(feature = "tokio")]
#[doc(hidden)]
pub mod iocompat {
    //! Adapters that bridge tokio's I/O traits and the `futures-io` I/O traits in both directions.

    use core::pin::Pin;
    use core::task::{Context, Poll};
    use std::io::{self, IoSlice};

    use futures_io::{AsyncRead, AsyncWrite};
    use tokio::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite, ReadBuf};

    /// Wraps a tokio I/O object so it can be driven through the `futures-io`
    /// [`AsyncRead`] / [`AsyncWrite`] traits.
    pub struct AsyncIoTokioAsStd<T: TokioAsyncRead + TokioAsyncWrite>(pub T);

    impl<T: TokioAsyncRead + TokioAsyncWrite + Unpin> Unpin for AsyncIoTokioAsStd<T> {}

    impl<R: TokioAsyncRead + TokioAsyncWrite + Unpin> AsyncRead for AsyncIoTokioAsStd<R> {
        /// Reads into `buf` via tokio's `poll_read` and returns the number of bytes filled.
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // `buf` is an initialized slice, so a fresh `ReadBuf` (filled = 0) over it is
            // valid; afterwards the filled length is exactly the number of bytes read.
            let mut buf = ReadBuf::new(buf);
            let polled = Pin::new(&mut self.0).poll_read(cx, &mut buf);
            polled.map_ok(|_| buf.filled().len())
        }
    }

    impl<W: TokioAsyncRead + TokioAsyncWrite + Unpin> AsyncWrite for AsyncIoTokioAsStd<W> {
        /// Forwards to tokio's `poll_write`.
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.0).poll_write(cx, buf)
        }

        /// Forwards to tokio's `poll_write_vectored`.
        fn poll_write_vectored(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
        }

        /// Forwards to tokio's `poll_flush`.
        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.0).poll_flush(cx)
        }

        /// `futures-io` "close" corresponds to tokio's `poll_shutdown`.
        fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.0).poll_shutdown(cx)
        }
    }

    /// Wraps a `futures-io` I/O object so it can be driven through tokio's
    /// [`TokioAsyncRead`] / [`TokioAsyncWrite`] traits.
    pub struct AsyncIoStdAsTokio<T: AsyncRead + AsyncWrite>(pub T);

    impl<T: AsyncRead + AsyncWrite + Unpin> Unpin for AsyncIoStdAsTokio<T> {}

    impl<R: AsyncRead + AsyncWrite + Unpin> TokioAsyncRead for AsyncIoStdAsTokio<R> {
        /// Reads into the unfilled tail of `buf` and advances its filled cursor by the
        /// number of bytes the inner reader produced.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // Only the *unfilled* part may be handed to the inner reader.
            // `initialized_mut()` would also expose the already-filled prefix, and the
            // reader would overwrite the caller's data. It is also empty for a `ReadBuf`
            // over uninitialized memory, so the read would be misreported as EOF.
            // `initialize_unfilled()` zeroes any uninitialized tail so it can be lent out
            // as `&mut [u8]`.
            let unfilled = buf.initialize_unfilled();
            Pin::new(&mut self.get_mut().0)
                .poll_read(cx, unfilled)
                .map_ok(|len| buf.advance(len))
        }
    }

    impl<W: AsyncRead + AsyncWrite + Unpin> TokioAsyncWrite for AsyncIoStdAsTokio<W> {
        /// Forwards to the `futures-io` `poll_write`.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
        }

        /// Forwards to the `futures-io` `poll_flush`.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_flush(cx)
        }

        /// tokio "shutdown" corresponds to `futures-io`'s `poll_close`.
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_close(cx)
        }
    }
}
#[cfg(feature = "tokio")]
#[allow(unreachable_pub)]
mod tokio_runtime {
    use std::sync::Arc;
    use std::sync::Mutex;
    #[cfg(feature = "__quic")]
    use quinn::Runtime;
    use tokio::net::{TcpSocket, TcpStream, UdpSocket as TokioUdpSocket};
    use tokio::task::JoinSet;
    use tokio::time::timeout;
    use tracing::debug;
    use super::iocompat::AsyncIoTokioAsStd;
    use super::*;
    use crate::xfer::CONNECT_TIMEOUT;

    /// Cloneable task spawner backed by a tokio [`JoinSet`].
    ///
    /// Every clone shares the same `JoinSet` through the `Arc<Mutex<_>>`.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<JoinSet<()>>>,
    }

    impl Spawn for TokioHandle {
        /// Spawns `future` onto the shared `JoinSet`, then discards any tasks that have
        /// already completed so they are not retained in the set.
        ///
        /// # Panics
        ///
        /// Panics if the internal mutex is poisoned.
        fn spawn_bg(&mut self, future: impl Future<Output = ()> + Send + 'static) {
            let mut tasks = self.join_set.lock().unwrap();
            tasks.spawn(future);
            reap_tasks(&mut tasks);
        }
    }

    /// [`RuntimeProvider`] backed by tokio sockets and timers.
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Creates a provider with a fresh, empty task handle; equivalent to `Default`.
        pub fn new() -> Self {
            Default::default()
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TcpStream>;

        /// Returns a clone of the shared handle; all clones spawn into the same `JoinSet`.
        fn create_handle(&self) -> Self::Handle {
            TokioHandle::clone(&self.0)
        }

        /// Opens a TCP connection with `TCP_NODELAY` enabled.
        ///
        /// The socket family follows `server_addr`. It is bound to `bind_addr` when one
        /// is given. The connect attempt is limited to `wait_for`, or to
        /// `CONNECT_TIMEOUT` when `None`. A timeout yields an `io::ErrorKind::TimedOut`
        /// error.
        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
            bind_addr: Option<SocketAddr>,
            wait_for: Option<Duration>,
        ) -> Pin<Box<dyn Send + Future<Output = Result<Self::Tcp, io::Error>>>> {
            Box::pin(async move {
                let socket = if server_addr.is_ipv4() {
                    TcpSocket::new_v4()?
                } else {
                    TcpSocket::new_v6()?
                };
                if let Some(local) = bind_addr {
                    socket.bind(local)?;
                }
                socket.set_nodelay(true)?;

                let limit = wait_for.unwrap_or(CONNECT_TIMEOUT);
                let Ok(connected) = timeout(limit, socket.connect(server_addr)).await else {
                    debug!(%server_addr, "TCP connect timeout");
                    return Err(io::Error::new(
                        io::ErrorKind::TimedOut,
                        "TCP connect timed out",
                    ));
                };
                connected.map(AsyncIoTokioAsStd)
            })
        }

        /// Binds a tokio UDP socket to `local_addr`; the server address is not used here.
        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = Result<Self::Udp, io::Error>>>> {
            Box::pin(async move { TokioUdpSocket::bind(local_addr).await })
        }

        /// QUIC sockets are supported through quinn's tokio runtime.
        #[cfg(feature = "__quic")]
        fn quic_binder(&self) -> Option<&dyn QuicSocketBinder> {
            Some(&TokioQuicSocketBinder)
        }
    }

    /// Removes every already-finished task from `join_set` without waiting.
    ///
    /// Each task's result, including any `JoinError`, is dropped.
    fn reap_tasks(join_set: &mut JoinSet<()>) {
        std::iter::from_fn(|| join_set.try_join_next()).for_each(drop);
    }

    /// Binds QUIC UDP sockets and wraps them with `quinn::TokioRuntime`.
    #[cfg(feature = "__quic")]
    struct TokioQuicSocketBinder;

    #[cfg(feature = "__quic")]
    impl QuicSocketBinder for TokioQuicSocketBinder {
        /// Binds a std UDP socket to `local_addr` and hands it to quinn's tokio runtime;
        /// the server address is not used here.
        fn bind_quic(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Result<Arc<dyn quinn::AsyncUdpSocket>, io::Error> {
            std::net::UdpSocket::bind(local_addr)
                .and_then(|std_socket| quinn::TokioRuntime.wrap_udp_socket(std_socket))
        }
    }
}
#[cfg(feature = "tokio")]
pub use tokio_runtime::{TokioHandle, TokioRuntimeProvider};
/// Abstraction over an async runtime: background task spawning, timers, and
/// creation of TCP/UDP (and optionally QUIC) sockets.
pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
/// Handle used to spawn background tasks (see [`Spawn`]).
type Handle: Clone + Send + Spawn + Sync + Unpin;
/// Timer implementation for this runtime.
type Timer: Time;
/// UDP socket type produced by [`Self::bind_udp`].
type Udp: DnsUdpSocket;
/// TCP stream type produced by [`Self::connect_tcp`].
type Tcp: DnsTcpStream;
/// Returns a handle that can spawn background tasks on this runtime.
fn create_handle(&self) -> Self::Handle;
/// Opens a TCP connection to `server_addr`.
///
/// * `bind_addr` - optional local address to bind the socket to before connecting.
/// * `timeout` - optional limit on the connect attempt; how `None` is handled is
///   implementation-defined (the tokio provider falls back to `CONNECT_TIMEOUT`).
fn connect_tcp(
&self,
server_addr: SocketAddr,
bind_addr: Option<SocketAddr>,
timeout: Option<Duration>,
) -> Pin<Box<dyn Send + Future<Output = Result<Self::Tcp, io::Error>>>>;
/// Binds a UDP socket to `local_addr`.
///
/// `server_addr` is the intended remote peer; implementations may ignore it.
fn bind_udp(
&self,
local_addr: SocketAddr,
server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = Result<Self::Udp, io::Error>>>>;
/// Returns a binder for QUIC sockets, or `None` (the default) if this provider
/// does not support QUIC.
fn quic_binder(&self) -> Option<&dyn QuicSocketBinder> {
None
}
}
/// Runtime-agnostic UDP socket abstraction.
///
/// Implementors supply the two poll-based primitives. The async `recv_from` /
/// `send_to` wrappers are provided on top of them.
#[async_trait]
pub trait DnsUdpSocket
where
    Self: Send + Sync + Sized + Unpin,
{
    /// Timer implementation associated with this socket type.
    type Time: Time;

    /// Polls for one incoming datagram, writing it into `buf`.
    ///
    /// Resolves to a `(usize, SocketAddr)` pair — by convention the number of bytes
    /// written and the sender's address.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>>;

    /// Async form of [`Self::poll_recv_from`].
    async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        let pending = poll_fn(|cx| self.poll_recv_from(cx, buf));
        pending.await
    }

    /// Polls to send `buf` as one datagram to `target`, resolving to a byte count.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>>;

    /// Async form of [`Self::poll_send_to`].
    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        let pending = poll_fn(|cx| self.poll_send_to(cx, buf, target));
        pending.await
    }
}
/// Empty stand-in used when the `__quic` feature is disabled.
///
/// It keeps [`RuntimeProvider::quic_binder`] well-typed in every configuration.
#[cfg(not(feature = "__quic"))]
pub trait QuicSocketBinder {}
/// Creates UDP sockets wrapped as `quinn::AsyncUdpSocket` for QUIC use.
#[cfg(feature = "__quic")]
pub trait QuicSocketBinder {
/// Binds a socket for `_local_addr` and returns it as a shared `quinn::AsyncUdpSocket`.
///
/// `_server_addr` is the intended remote peer; implementations may ignore it.
fn bind_quic(
&self,
_local_addr: SocketAddr,
_server_addr: SocketAddr,
) -> Result<Arc<dyn quinn::AsyncUdpSocket>, io::Error>;
}
/// TCP stream type produced by [`RuntimeProvider::connect_tcp`].
///
/// It is a `futures-io` byte stream that is `Unpin + Send + Sync + 'static`.
pub trait DnsTcpStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + Sized + 'static {
/// Timer implementation associated with this stream type.
type Time: Time;
}
/// Something that can launch detached background work.
pub trait Spawn {
/// Spawns `future` to run in the background.
///
/// Nothing is returned, so callers cannot await or cancel the task through this API.
fn spawn_bg(&mut self, future: impl Future<Output = ()> + Send + 'static);
}
/// Timer abstraction: sleeping, bounding a future by a deadline, and reading wall-clock time.
#[async_trait]
pub trait Time: Send + Sync + Unpin {
    /// Completes once `duration` has elapsed.
    async fn delay_for(duration: Duration);

    /// Drives `future` to completion.
    ///
    /// Implementations are expected to return an `io::Error` if it does not finish
    /// within `duration`.
    async fn timeout<F: 'static + Future + Send>(
        duration: Duration,
        future: F,
    ) -> Result<F::Output, io::Error>;

    /// Returns the current system time as whole seconds since the Unix epoch.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time earlier than the Unix epoch.
    fn current_time() -> u64 {
        let now = SystemTime::now();
        let since_epoch = now.duration_since(UNIX_EPOCH).unwrap();
        since_epoch.as_secs()
    }
}
/// [`Time`] implementation backed by tokio's timer.
#[cfg(any(test, feature = "tokio"))]
#[derive(Clone, Copy, Debug)]
pub struct TokioTime;

#[cfg(any(test, feature = "tokio"))]
#[async_trait]
impl Time for TokioTime {
    /// Sleeps via `tokio::time::sleep`.
    async fn delay_for(duration: Duration) {
        tokio::time::sleep(duration).await;
    }

    /// Races `future` against `duration`.
    ///
    /// On expiry it returns an `io::ErrorKind::TimedOut` error.
    async fn timeout<F: 'static + Future + Send>(
        duration: Duration,
        future: F,
    ) -> Result<F::Output, io::Error> {
        match tokio::time::timeout(duration, future).await {
            Ok(output) => Ok(output),
            Err(_elapsed) => Err(io::Error::new(io::ErrorKind::TimedOut, "future timed out")),
        }
    }
}