use std::io::{Read, Seek, SeekFrom, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{io, mem};
use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
#[cfg(test)]
mod tests;
/// Adapter exposing a `std::io` object (`Read` / `Write` / `Seek`) through tokio's
/// `AsyncRead` / `AsyncWrite` / `AsyncSeek` traits.
///
/// Each `poll_*` call invokes the inner object's synchronous method directly. An
/// `io::ErrorKind::WouldBlock` error is turned into `Poll::Pending` and the task is
/// re-woken, either immediately or after `wake_delay`. Because the inner call runs on
/// the polling task, the inner object should presumably be non-blocking (signalling
/// `WouldBlock` instead of waiting) — NOTE(review): confirm for each wrapped type.
#[derive(Debug)]
pub struct AsyncIoCompat<T> {
/// The wrapped synchronous I/O object.
inner: T,
/// Outcome of the most recent `start_seek`, handed out by `poll_complete` and reset
/// to `Ok(0)` once consumed.
last_seek: io::Result<u64>,
/// Target of the most recent `start_seek`, kept so a seek that hit `WouldBlock` can
/// be retried from `poll_complete`.
last_seek_position: SeekFrom,
/// Delay before re-waking the task after `WouldBlock`; zero means wake immediately.
wake_delay: Duration,
}
impl<T> AsyncIoCompat<T> {
    /// Wraps `inner`, re-waking the task immediately whenever it reports `WouldBlock`.
    pub const fn new(inner: T) -> Self {
        Self::new_with_delay(inner, Duration::ZERO)
    }

    /// Wraps `inner`; after a `WouldBlock`, the task is re-woken once `delay` has elapsed.
    ///
    /// A zero `delay` behaves exactly like [`new`](Self::new). A non-zero delay needs a
    /// Tokio runtime with the time driver enabled (see `schedule_wake`).
    pub const fn new_with_delay(inner: T, delay: Duration) -> Self {
        Self {
            inner,
            last_seek: Ok(0),
            last_seek_position: SeekFrom::Start(0),
            wake_delay: delay,
        }
    }

    /// Consumes the adapter and returns the wrapped object.
    // Cannot be `const`: the remaining fields (notably `last_seek`, which may hold an
    // `io::Error`) must be dropped here, and destructors cannot run in a const fn.
    #[allow(clippy::missing_const_for_fn)]
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Arranges for the task behind `ctx` to be polled again.
    ///
    /// With a zero `wake_delay` the waker fires immediately. Otherwise a task is spawned
    /// that sleeps for `wake_delay` and then wakes. `tokio::spawn` panics outside a Tokio
    /// runtime, and `tokio::time::sleep` needs the runtime's time driver.
    fn schedule_wake(&self, ctx: &Context<'_>) {
        if self.wake_delay.is_zero() {
            ctx.waker().wake_by_ref();
        } else {
            let waker = ctx.waker().clone();
            let delay = self.wake_delay;
            tokio::spawn(async move {
                tokio::time::sleep(delay).await;
                waker.wake();
            });
        }
    }

    /// Runs the synchronous operation `f` and adapts its outcome to `Poll`.
    ///
    /// - `Interrupted` is non-fatal per the `std::io` contract, so `f` is simply retried.
    /// - `WouldBlock` schedules a wake-up and yields `Poll::Pending`.
    /// - Any other result (success or error) is returned as `Poll::Ready`.
    fn no_blocking<F, O>(&mut self, ctx: &Context<'_>, mut f: F) -> Poll<io::Result<O>>
    where
        F: FnMut(&mut Self) -> io::Result<O>,
    {
        loop {
            match f(self) {
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    self.schedule_wake(ctx);
                    return Poll::Pending;
                }
                result => return Poll::Ready(result),
            }
        }
    }
}
impl<T: Read + Unpin> AsyncRead for AsyncIoCompat<T> {
    /// Reads from the inner `Read` into the unfilled part of `buf`.
    ///
    /// The unfilled region is zero-initialised so it can be handed to `Read::read` as a
    /// plain slice. The filled cursor then advances by the number of bytes read. A
    /// `WouldBlock` from the inner reader yields `Poll::Pending`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let adapter = self.get_mut();
        adapter.no_blocking(cx, |compat| {
            let destination = buf.initialize_unfilled();
            let read_count = compat.inner.read(destination)?;
            buf.advance(read_count);
            Ok(())
        })
    }
}
impl<T: Write + Unpin> AsyncWrite for AsyncIoCompat<T> {
    /// Writes `buf` through the inner `Write::write`; `WouldBlock` yields `Poll::Pending`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.no_blocking(cx, |this| this.inner.write(buf))
    }

    /// Flushes the inner writer via `Write::flush`; `WouldBlock` yields `Poll::Pending`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.no_blocking(cx, |this| this.inner.flush())
    }

    /// Shuts the writer down by flushing it.
    ///
    /// Tokio's `AsyncWrite::poll_shutdown` contract implies a flush: callers may rely on
    /// shutdown alone to push out buffered data. Examples are data held by an inner
    /// `BufWriter` or a compressor. `std::io::Write` has no close operation, so a
    /// successful flush is the whole shutdown.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}
impl<T: Seek + Unpin> AsyncSeek for AsyncIoCompat<T> {
    /// Performs the seek eagerly and stores its outcome for `poll_complete`.
    ///
    /// The requested position is remembered so a `WouldBlock` outcome can be retried
    /// later. This always returns `Ok(())`; errors from the inner seek surface from
    /// `poll_complete`.
    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        let adapter = self.get_mut();
        adapter.last_seek_position = position;
        adapter.last_seek = adapter.inner.seek(position);
        Ok(())
    }

    /// Reports the outcome of the last `start_seek`.
    ///
    /// If the stored outcome is `WouldBlock`, the seek is retried. A second `WouldBlock`
    /// schedules a wake-up and returns `Poll::Pending`. Otherwise the stored outcome is
    /// returned and reset to `Ok(0)`. With no seek outstanding, this yields `Ok(0)`.
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let adapter = self.get_mut();
        let is_would_block = |outcome: &io::Result<u64>| {
            matches!(outcome, Err(err) if err.kind() == io::ErrorKind::WouldBlock)
        };

        if is_would_block(&adapter.last_seek) {
            adapter.last_seek = adapter.inner.seek(adapter.last_seek_position);
            if is_would_block(&adapter.last_seek) {
                adapter.schedule_wake(cx);
                return Poll::Pending;
            }
        }

        Poll::Ready(mem::replace(&mut adapter.last_seek, Ok(0)))
    }
}
/// Extension methods for wrapping a value in an [`AsyncIoCompat`].
///
/// A blanket implementation covers every sized type.
pub trait CompatHelperTrait {
/// Wraps `self` by value in an [`AsyncIoCompat`] with no wake delay.
fn tokio_io(self) -> AsyncIoCompat<Self>
where
Self: Sized;
/// Wraps a mutable borrow of `self` in an [`AsyncIoCompat`] with no wake delay,
/// leaving ownership with the caller.
fn tokio_io_mut(&mut self) -> AsyncIoCompat<&mut Self>;
}
/// Blanket implementation: any sized value can be wrapped, either by value or by `&mut`.
impl<T> CompatHelperTrait for T {
    fn tokio_io(self) -> AsyncIoCompat<Self>
    where
        Self: Sized,
    {
        // Built through the plain constructor, i.e. with a zero wake delay.
        AsyncIoCompat::<T>::new(self)
    }

    fn tokio_io_mut(&mut self) -> AsyncIoCompat<&mut Self> {
        let borrowed: &mut T = self;
        AsyncIoCompat::<&mut T>::new(borrowed)
    }
}