use crate::{
Runtime,
sys::AsSysFd,
traits::{Executor, Reactor, RuntimeKit},
util::Task,
};
use async_compat::{Compat, CompatExt};
use cfg_if::cfg_if;
use futures_core::Stream;
use futures_io::{AsyncRead, AsyncWrite};
use std::{
future::Future,
io::{self, Read, Write},
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
use tokio::{
net::TcpStream,
runtime::{EnterGuard, Handle, Runtime as TokioRT},
time::Sleep,
};
use tokio_stream::{StreamExt, wrappers::IntervalStream};
use task::TTask;
pub type TokioRuntime = Runtime<Tokio>;
impl TokioRuntime {
pub fn tokio() -> io::Result<Self> {
Ok(Self::tokio_with_runtime(TokioRT::new()?))
}
pub fn tokio_current() -> Self {
Self::new(Tokio::current())
}
pub fn tokio_with_handle(handle: Handle) -> Self {
Self::new(Tokio::default().with_handle(handle))
}
pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
Self::new(Tokio::default().with_runtime(runtime))
}
}
#[derive(Default, Clone, Debug)]
pub struct Tokio {
handle: Option<Handle>,
runtime: Option<Arc<TokioRT>>,
}
impl Tokio {
pub fn with_handle(mut self, handle: Handle) -> Self {
self.handle = Some(handle);
self
}
pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
let handle = runtime.handle().clone();
self.runtime = Some(Arc::new(runtime));
self.with_handle(handle)
}
pub fn current() -> Self {
Self::default().with_handle(Handle::current())
}
fn handle(&self) -> Option<Handle> {
self.handle.clone().or_else(|| Handle::try_current().ok())
}
fn enter(&self) -> Option<EnterGuard<'_>> {
self.runtime
.as_ref()
.map(|r| r.handle())
.or(self.handle.as_ref())
.map(Handle::enter)
}
}
impl RuntimeKit for Tokio {}
impl Executor for Tokio {
type Task<T: Send + 'static> = TTask<T>;
fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
if let Some(runtime) = self.runtime.as_ref() {
runtime.block_on(f)
} else if let Some(handle) = self.handle() {
handle.block_on(f)
} else {
Handle::current().block_on(f)
}
}
fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
&self,
f: F,
) -> Task<Self::Task<T>> {
TTask(Some(if let Some(handle) = self.handle() {
handle.spawn(f)
} else {
tokio::task::spawn(f)
}))
.into()
}
fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
&self,
f: F,
) -> Task<Self::Task<T>> {
TTask(Some(if let Some(handle) = self.handle() {
handle.spawn_blocking(f)
} else {
tokio::task::spawn_blocking(f)
}))
.into()
}
}
impl Reactor for Tokio {
type TcpStream = Compat<TcpStream>;
type Sleep = Sleep;
fn register<H: Read + Write + AsSysFd + Send + 'static>(
&self,
socket: H,
) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
let _enter = self.enter();
cfg_if! {
if #[cfg(unix)] {
Ok(unix::AsyncFdWrapper(
tokio::io::unix::AsyncFd::new(socket)?,
))
} else {
Err::<crate::util::DummyIO, _>(io::Error::other(
"Registering FD on tokio reactor is only supported on unix",
))
}
}
}
fn sleep(&self, dur: Duration) -> Self::Sleep {
let _enter = self.enter();
tokio::time::sleep(dur)
}
fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
let _enter = self.enter();
IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
}
fn tcp_connect_addr(
&self,
addr: SocketAddr,
) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
let _enter = self.enter();
async move { Ok(TcpStream::connect(addr).await?.compat()) }
}
}
mod task {
use crate::util::TaskImpl;
use async_trait::async_trait;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
#[derive(Debug)]
pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
#[async_trait]
impl<T: Send + 'static> TaskImpl for TTask<T> {
async fn cancel(&mut self) -> Option<T> {
let task = self.0.take()?;
task.abort();
task.await.ok()
}
}
impl<T: Send + 'static> Future for TTask<T> {
type Output = T;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let task = self.0.as_mut().expect("task has been canceled");
match Pin::new(task).poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
}
}
}
}
#[cfg(unix)]
mod unix {
use super::*;
use futures_io::{AsyncRead, AsyncWrite};
use std::io::{IoSlice, IoSliceMut};
use tokio::io::unix::AsyncFd;
pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
f: F,
) -> Option<Poll<io::Result<usize>>> {
Some(match self.0.poll_read_ready_mut(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
Ok(res) => Poll::Ready(res),
Err(_) => return None,
},
})
}
fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
f: F,
) -> Option<Poll<io::Result<R>>> {
Some(match self.0.poll_write_ready_mut(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
Ok(res) => Poll::Ready(res),
Err(_) => return None,
},
})
}
}
impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
loop {
if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
return res;
}
}
}
fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
loop {
if let Some(res) = self
.as_mut()
.read(cx, |socket| socket.get_mut().read_vectored(bufs))
{
return res;
}
}
}
}
impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
if let Some(res) = self
.as_mut()
.write(cx, |socket| socket.get_mut().write(buf))
{
return res;
}
}
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
loop {
if let Some(res) = self
.as_mut()
.write(cx, |socket| socket.get_mut().write_vectored(bufs))
{
return res;
}
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
return res;
}
}
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
self.poll_flush(cx)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn auto_traits() {
use crate::util::test::*;
let runtime = Runtime::tokio().unwrap();
assert_send(&runtime);
assert_sync(&runtime);
assert_clone(&runtime);
}
}