use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use hyper::rt::{self, Read, Write};
use pin_project_lite::pin_project;
use super::Runtime;
/// Zero-sized marker type whose [`Runtime`] implementation is backed by Tokio.
pub struct TokioRuntime;
impl Runtime for TokioRuntime {
type TcpStream = TokioIo<tokio::net::TcpStream>;
type Sleep = TokioSleep;
/// Opens a TCP connection to `addr` and wraps it in [`TokioIo`].
///
/// `TCP_NODELAY` is switched on before the stream is handed back.
///
/// # Errors
///
/// Propagates the I/O error from connecting or from enabling `TCP_NODELAY`.
async fn connect(addr: SocketAddr) -> io::Result<Self::TcpStream> {
    tokio::net::TcpStream::connect(addr).await.and_then(|tcp| {
        tcp.set_nodelay(true)?;
        Ok(TokioIo::new(tcp))
    })
}
/// Resolves `host:port` through Tokio and returns every address produced.
///
/// # Errors
///
/// Propagates the lookup error, or returns [`io::ErrorKind::AddrNotAvailable`]
/// when the lookup succeeds but yields no addresses.
async fn resolve_all(host: &str, port: u16) -> io::Result<Vec<SocketAddr>> {
    let target = format!("{host}:{port}");
    let resolved = tokio::net::lookup_host(target)
        .await?
        .collect::<Vec<SocketAddr>>();
    if !resolved.is_empty() {
        return Ok(resolved);
    }
    // An empty result is reported as an error so callers never see `Ok(vec![])`.
    Err(io::Error::new(
        io::ErrorKind::AddrNotAvailable,
        "no addresses found",
    ))
}
/// Creates a [`TokioSleep`] future wrapping `tokio::time::sleep(duration)`.
fn sleep(duration: Duration) -> Self::Sleep {
    let inner = tokio::time::sleep(duration);
    TokioSleep { inner }
}
/// Hands `future` to `tokio::spawn` and discards the join handle.
///
/// Per Tokio's documentation, dropping the handle detaches the task, and
/// `tokio::spawn` panics when called outside a Tokio runtime context.
fn spawn<F>(future: F)
where
    F: Future<Output = ()> + Send + 'static,
{
    // The task's result is `()`, so there is nothing to join on.
    drop(tokio::spawn(future));
}
/// Applies TCP keepalive settings to `stream` via `socket2`.
///
/// * `time` is always applied.
/// * `interval` is applied when it is `Some`.
/// * `retries` is applied when it is `Some` and the target OS is one of those
///   listed in the `cfg` below; on every other target it is ignored.
///
/// # Errors
///
/// Returns whatever error `socket2` reports when setting the options.
fn set_tcp_keepalive(
    stream: &Self::TcpStream,
    time: Duration,
    interval: Option<Duration>,
    retries: Option<u32>,
) -> io::Result<()> {
    use socket2::{SockRef, TcpKeepalive};

    let params = TcpKeepalive::new().with_time(time);
    // NOTE(review): `with_interval` is itself platform-gated in socket2; confirm
    // every target this crate builds for provides it.
    let params = match interval {
        Some(every) => params.with_interval(every),
        None => params,
    };
    #[cfg(any(
        target_os = "linux",
        target_os = "macos",
        target_os = "ios",
        target_os = "freebsd",
        target_os = "netbsd",
    ))]
    let params = match retries {
        Some(count) => params.with_retries(count),
        None => params,
    };
    // Elsewhere the retry count is not applied; consume it to avoid an unused warning.
    #[cfg(not(any(
        target_os = "linux",
        target_os = "macos",
        target_os = "ios",
        target_os = "freebsd",
        target_os = "netbsd",
    )))]
    let _ = retries;

    SockRef::from(stream.inner()).set_tcp_keepalive(&params)
}
/// Sets the `TCP_FASTOPEN_CONNECT` socket option to 1 on `stream` (Linux only).
///
/// The option is set with a direct `setsockopt` call on the stream's raw file
/// descriptor.
///
/// NOTE(review): per tcp(7) this option affects a subsequent `connect()`;
/// confirm callers invoke this at a point where it can still take effect.
///
/// # Errors
///
/// Returns the last OS error when `setsockopt` reports failure.
#[cfg(target_os = "linux")]
fn set_tcp_fast_open(stream: &Self::TcpStream) -> io::Result<()> {
    use socket2::SockRef;
    use std::ffi::{c_int, c_void};
    use std::os::unix::io::AsRawFd;

    // Linux values for the protocol level and option name.
    const IPPROTO_TCP: c_int = 6;
    const TCP_FASTOPEN_CONNECT: c_int = 30;

    unsafe extern "C" {
        fn setsockopt(
            sockfd: c_int,
            level: c_int,
            optname: c_int,
            optval: *const c_void,
            optlen: u32,
        ) -> c_int;
    }

    let raw_fd = SockRef::from(stream.inner()).as_raw_fd();
    let enable: c_int = 1;

    // SAFETY: `raw_fd` belongs to the stream borrowed for the duration of this
    // call, `optval` points at a live `c_int` on this stack frame, and `optlen`
    // is exactly that value's size. The declaration above is assumed to match
    // libc's `setsockopt` on Linux.
    let status = unsafe {
        setsockopt(
            raw_fd,
            IPPROTO_TCP,
            TCP_FASTOPEN_CONNECT,
            (&enable as *const c_int).cast::<c_void>(),
            std::mem::size_of::<c_int>() as u32,
        )
    };

    if status == 0 {
        Ok(())
    } else {
        Err(io::Error::last_os_error())
    }
}
/// Binds `stream`'s socket to the network interface named `interface` (Linux only).
///
/// # Errors
///
/// Returns the error reported by `socket2` when the bind fails.
#[cfg(target_os = "linux")]
fn bind_device(stream: &Self::TcpStream, interface: &str) -> io::Result<()> {
    use socket2::SockRef;

    let device_name = interface.as_bytes();
    SockRef::from(stream.inner()).bind_device(Some(device_name))
}
/// Converts a standard-library TCP stream into this runtime's stream type.
///
/// The stream is put into non-blocking mode and `TCP_NODELAY` is enabled before
/// it is registered with Tokio.
///
/// # Errors
///
/// Propagates any I/O error from configuring or registering the stream.
fn from_std_tcp(stream: std::net::TcpStream) -> io::Result<Self::TcpStream> {
    // Tokio expects the socket to already be non-blocking when adopting it.
    stream.set_nonblocking(true)?;
    stream.set_nodelay(true)?;
    tokio::net::TcpStream::from_std(stream).map(TokioIo::new)
}
/// Connects to `addr` from a socket bound to the local address `local`.
///
/// The socket's address family is chosen from `addr`; the local port is left
/// as 0. `TCP_NODELAY` is enabled on the connected stream.
///
/// # Errors
///
/// Propagates any I/O error from creating, binding, connecting, or configuring
/// the socket.
async fn connect_bound(
    addr: SocketAddr,
    local: std::net::IpAddr,
) -> io::Result<Self::TcpStream> {
    let socket = match addr {
        SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(),
        SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(),
    }?;

    let local_endpoint = SocketAddr::new(local, 0);
    socket.bind(local_endpoint)?;

    let tcp = socket.connect(addr).await?;
    tcp.set_nodelay(true)?;
    Ok(TokioIo::new(tcp))
}
/// Unix domain socket stream type: a Tokio `UnixStream` wrapped in [`TokioIo`].
#[cfg(unix)]
type UnixStream = TokioIo<tokio::net::UnixStream>;

/// Connects to the Unix domain socket at `path`.
///
/// # Errors
///
/// Propagates the I/O error from the connection attempt.
#[cfg(unix)]
async fn connect_unix(path: &std::path::Path) -> io::Result<Self::UnixStream> {
    tokio::net::UnixStream::connect(path)
        .await
        .map(TokioIo::new)
}
}
pin_project! {
/// Sleep future built by `TokioRuntime::sleep`; wraps a `tokio::time::Sleep`.
pub struct TokioSleep {
// Marked `#[pin]` so `poll` can obtain a pinned reference to the inner timer.
#[pin]
inner: tokio::time::Sleep,
}
}
/// Polling a `TokioSleep` simply polls the wrapped Tokio timer.
impl Future for TokioSleep {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        Future::poll(this.inner, cx)
    }
}
pin_project! {
/// Adapter that wraps an I/O object `T`; the `Read`/`Write` impls in this file
/// forward hyper's I/O traits to Tokio's `AsyncRead`/`AsyncWrite` on `T`.
pub struct TokioIo<T> {
// Marked `#[pin]` so the trait impls can project to a pinned `T`.
#[pin]
inner: T,
}
}
impl<T> TokioIo<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}
pub fn inner(&self) -> &T {
&self.inner
}
}
/// Implements hyper's `Read` by delegating to Tokio's `AsyncRead` on the inner value.
impl<T> Read for TokioIo<T>
where
    T: tokio::io::AsyncRead,
{
    /// Reads into hyper's cursor by exposing its unfilled region to Tokio as a
    /// `ReadBuf`, then advancing the cursor by the number of bytes Tokio filled.
    ///
    /// `Pending` and error results from the inner reader are returned unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: rt::ReadBufCursor<'_>,
    ) -> Poll<io::Result<()>> {
        let filled = {
            // SAFETY: the slice is passed straight to `ReadBuf::uninit`, which
            // treats it as uninitialized storage to be written into; nothing
            // here de-initializes bytes the cursor already holds.
            let spare = unsafe { buf.as_mut() };
            let mut tbuf = tokio::io::ReadBuf::uninit(spare);
            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
                Poll::Ready(Ok(())) => tbuf.filled().len(),
                pending_or_err => return pending_or_err,
            }
        };

        // SAFETY: `filled` is the length of the region the Tokio `ReadBuf`
        // reports as filled, and that `ReadBuf` was a view over this cursor's
        // memory, so that many bytes have been initialized.
        unsafe { buf.advance(filled) };
        Poll::Ready(Ok(()))
    }
}
/// Implements hyper's `Write` by forwarding each method to the Tokio
/// `AsyncWrite` method of the same name on the inner value.
impl<T> Write for TokioIo<T>
where
    T: tokio::io::AsyncWrite,
{
    /// Forwards to the inner writer's `poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().inner;
        <T as tokio::io::AsyncWrite>::poll_write(inner, cx, buf)
    }

    /// Forwards to the inner writer's `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().inner;
        <T as tokio::io::AsyncWrite>::poll_flush(inner, cx)
    }

    /// Forwards to the inner writer's `poll_shutdown`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().inner;
        <T as tokio::io::AsyncWrite>::poll_shutdown(inner, cx)
    }

    /// Forwards to the inner writer's `poll_write_vectored`.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().inner;
        <T as tokio::io::AsyncWrite>::poll_write_vectored(inner, cx, bufs)
    }

    /// Reports whatever the inner writer reports for vectored-write support.
    fn is_write_vectored(&self) -> bool {
        <T as tokio::io::AsyncWrite>::is_write_vectored(&self.inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Runtime;

    /// Binds a TCP listener to an OS-assigned loopback port and returns it
    /// together with the address it ended up on.
    async fn loopback_listener() -> (tokio::net::TcpListener, std::net::SocketAddr) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_addr = listener.local_addr().unwrap();
        (listener, bound_addr)
    }

    /// Resolving `localhost` must produce at least one address.
    #[tokio::test]
    async fn resolve_all_localhost() {
        let resolved = TokioRuntime::resolve_all("localhost", 80).await.unwrap();
        assert!(!resolved.is_empty());
    }

    /// Keepalive with time, interval and retry count is accepted on a live connection.
    #[tokio::test]
    async fn connect_and_set_keepalive_with_interval_retries() {
        let (_listener, addr) = loopback_listener().await;
        let client = TokioRuntime::connect(addr).await.unwrap();

        let outcome = TokioRuntime::set_tcp_keepalive(
            &client,
            Duration::from_secs(60),
            Some(Duration::from_secs(10)),
            Some(3),
        );
        assert!(outcome.is_ok());
    }

    /// A connected std stream can be adopted and still knows its peer.
    #[tokio::test]
    async fn from_std_tcp_succeeds() {
        let (_listener, addr) = loopback_listener().await;
        let blocking_stream = std::net::TcpStream::connect(addr).unwrap();

        let adopted = TokioRuntime::from_std_tcp(blocking_stream).unwrap();
        assert!(adopted.inner().peer_addr().is_ok());
    }

    /// The wrapper reports vectored-write support for a TCP stream.
    #[tokio::test]
    async fn is_write_vectored_returns_true() {
        let (_listener, addr) = loopback_listener().await;
        let client = TokioRuntime::connect(addr).await.unwrap();
        assert!(Write::is_write_vectored(&client));
    }

    /// A vectored write through the wrapper arrives intact on the peer.
    #[tokio::test]
    async fn write_vectored_delivers_data() {
        use std::future::poll_fn;
        use tokio::io::AsyncReadExt;

        let (listener, addr) = loopback_listener().await;
        let mut client = TokioRuntime::connect(addr).await.unwrap();
        let (mut server, _) = listener.accept().await.unwrap();

        let parts = [
            io::IoSlice::new(b"hello"),
            io::IoSlice::new(b" "),
            io::IoSlice::new(b"world"),
        ];
        let written = poll_fn(|cx| Pin::new(&mut client).poll_write_vectored(cx, &parts))
            .await
            .unwrap();
        assert_eq!(written, 11);

        let mut received = [0u8; 11];
        server.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"hello world");
    }

    /// Connecting to a freshly bound Unix socket succeeds.
    #[cfg(unix)]
    #[tokio::test]
    async fn connect_unix_succeeds() {
        let scratch_dir = std::env::temp_dir().join("aioduct_rt_unix_test");
        let _ = std::fs::create_dir_all(&scratch_dir);
        let socket_path = scratch_dir.join("rt_test.sock");
        // Clear any stale socket file so the bind below does not fail on it.
        let _ = std::fs::remove_file(&socket_path);

        let _listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
        let client = TokioRuntime::connect_unix(&socket_path).await.unwrap();
        drop(client);

        // Best-effort cleanup; failures here are deliberately ignored.
        let _ = std::fs::remove_file(&socket_path);
        let _ = std::fs::remove_dir(&scratch_dir);
    }
}