use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use nexus_net::{ParserSink, WireStream};
/// Thin wrapper around a stream `S`, compiled only with the `tokio-rt` feature.
///
/// When `S` implements `tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin`,
/// this type implements `WireStream` by forwarding to the wrapped stream.
#[cfg(feature = "tokio-rt")]
pub struct AsyncReadAdapter<S> {
/// The wrapped stream that all I/O is forwarded to.
inner: S,
}
#[cfg(feature = "tokio-rt")]
impl<S> AsyncReadAdapter<S> {
    /// Wraps `inner` in a new adapter, taking ownership of it.
    pub fn new(inner: S) -> Self {
        AsyncReadAdapter { inner }
    }

    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let Self { inner } = self;
        inner
    }

    /// Returns an exclusive reference to the wrapped stream.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { inner } = self;
        inner
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> S {
        let Self { inner } = self;
        inner
    }
}
/// `WireStream` implementation that forwards every operation to the wrapped
/// tokio stream.
#[cfg(feature = "tokio-rt")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> WireStream for AsyncReadAdapter<S> {
    /// Performs one read from the wrapped stream straight into `sink.spare()`.
    ///
    /// At most `min(max, sink.spare().len())` bytes are requested. When the
    /// read yields `n > 0` bytes they are committed with `sink.filled(n)`;
    /// a zero-length read is reported as `Ok(0)` without touching the sink.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::InvalidInput` without polling the wrapped
    /// stream when `max == 0` or the sink has no spare capacity. Errors from
    /// the wrapped stream are passed through unchanged.
    fn poll_fill_into<P: ParserSink>(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        sink: &mut P,
        max: usize,
    ) -> Poll<io::Result<usize>> {
        // Explicit `Pin::get_mut` so it is not mistaken for the adapter's own
        // `get_mut` accessor.
        let adapter = Pin::get_mut(self);
        let window = sink.spare();

        // The usable length is zero exactly when `max == 0` or the spare
        // region is empty, so one comparison covers both preconditions.
        let limit = window.len().min(max);
        if limit == 0 {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "poll_fill_into called with no buffer space \
                 (max == 0 or sink.spare() is empty)",
            )));
        }

        let mut read_buf = tokio::io::ReadBuf::new(&mut window[..limit]);
        // `ready!` returns early on `Pending`; `?` returns early on `Err`.
        std::task::ready!(tokio::io::AsyncRead::poll_read(
            Pin::new(&mut adapter.inner),
            cx,
            &mut read_buf,
        ))?;

        let received = read_buf.filled().len();
        if received > 0 {
            sink.filled(received);
        }
        Poll::Ready(Ok(received))
    }

    /// Forwards the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let adapter = Pin::get_mut(self);
        tokio::io::AsyncWrite::poll_write(Pin::new(&mut adapter.inner), cx, buf)
    }

    /// Forwards the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let adapter = Pin::get_mut(self);
        tokio::io::AsyncWrite::poll_flush(Pin::new(&mut adapter.inner), cx)
    }

    /// Forwards the shutdown to the wrapped stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let adapter = Pin::get_mut(self);
        tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut adapter.inner), cx)
    }
}
/// Thin wrapper around a stream `S`, compiled only with the `nexus` feature.
///
/// When `S` implements `nexus_async_rt::AsyncRead + nexus_async_rt::AsyncWrite + Unpin`,
/// this type implements `WireStream` by forwarding to the wrapped stream.
#[cfg(feature = "nexus")]
pub struct NexusAsyncReadAdapter<S> {
/// The wrapped stream that all I/O is forwarded to.
inner: S,
}
#[cfg(feature = "nexus")]
impl<S> NexusAsyncReadAdapter<S> {
    /// Wraps `inner` in a new adapter, taking ownership of it.
    pub fn new(inner: S) -> Self {
        NexusAsyncReadAdapter { inner }
    }

    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let Self { inner } = self;
        inner
    }

    /// Returns an exclusive reference to the wrapped stream.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { inner } = self;
        inner
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> S {
        let Self { inner } = self;
        inner
    }
}
/// `WireStream` implementation that forwards every operation to the wrapped
/// `nexus_async_rt` stream.
#[cfg(feature = "nexus")]
impl<S: nexus_async_rt::AsyncRead + nexus_async_rt::AsyncWrite + Unpin> WireStream
    for NexusAsyncReadAdapter<S>
{
    /// Performs one read from the wrapped stream straight into `sink.spare()`.
    ///
    /// At most `min(max, sink.spare().len())` bytes are requested. When the
    /// wrapped stream reports `n > 0` bytes they are committed with
    /// `sink.filled(n)`; a zero-length read is reported as `Ok(0)` without
    /// touching the sink.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::InvalidInput` without polling the wrapped
    /// stream when `max == 0` or the sink has no spare capacity. Errors from
    /// the wrapped stream are passed through unchanged.
    fn poll_fill_into<P: ParserSink>(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        sink: &mut P,
        max: usize,
    ) -> Poll<io::Result<usize>> {
        // Explicit `Pin::get_mut` so it is not mistaken for the adapter's own
        // `get_mut` accessor.
        let adapter = Pin::get_mut(self);
        let window = sink.spare();

        // The usable length is zero exactly when `max == 0` or the spare
        // region is empty, so one comparison covers both preconditions.
        let limit = window.len().min(max);
        if limit == 0 {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "poll_fill_into called with no buffer space \
                 (max == 0 or sink.spare() is empty)",
            )));
        }

        // `ready!` returns early on `Pending`; `?` returns early on `Err`.
        let received = std::task::ready!(nexus_async_rt::AsyncRead::poll_read(
            Pin::new(&mut adapter.inner),
            cx,
            &mut window[..limit],
        ))?;

        if received > 0 {
            sink.filled(received);
        }
        Poll::Ready(Ok(received))
    }

    /// Forwards the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let adapter = Pin::get_mut(self);
        nexus_async_rt::AsyncWrite::poll_write(Pin::new(&mut adapter.inner), cx, buf)
    }

    /// Forwards the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let adapter = Pin::get_mut(self);
        nexus_async_rt::AsyncWrite::poll_flush(Pin::new(&mut adapter.inner), cx)
    }

    /// Forwards the shutdown to the wrapped stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let adapter = Pin::get_mut(self);
        nexus_async_rt::AsyncWrite::poll_shutdown(Pin::new(&mut adapter.inner), cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::poll_fn;

    /// Minimal `ParserSink` backed by a fixed-length, zero-filled buffer.
    struct StubSink {
        buf: Vec<u8>,
        committed: usize,
    }

    impl StubSink {
        /// Builds a sink whose buffer holds `cap` zero bytes, none committed.
        fn with_capacity(cap: usize) -> Self {
            let buf = vec![0u8; cap];
            Self { buf, committed: 0 }
        }
    }

    impl ParserSink for StubSink {
        /// Everything past the committed prefix.
        fn spare(&mut self) -> &mut [u8] {
            let start = self.committed;
            &mut self.buf[start..]
        }

        /// Extends the committed prefix by `n` bytes.
        fn filled(&mut self, n: usize) {
            self.committed += n;
        }
    }

    /// Stream whose every method panics; used to show that the precondition
    /// check returns before the wrapped stream is touched.
    struct UnpolledStream;

    #[cfg(feature = "tokio-rt")]
    impl tokio::io::AsyncRead for UnpolledStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            panic!("UnpolledStream::poll_read should not be reached")
        }
    }

    #[cfg(feature = "tokio-rt")]
    impl tokio::io::AsyncWrite for UnpolledStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &[u8],
        ) -> Poll<io::Result<usize>> {
            panic!("unreached")
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
    }

    /// A sink with no spare bytes must be rejected with `InvalidInput`.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_empty_spare_returns_invalid_input() {
        let mut sink = StubSink::with_capacity(0);
        let mut adapter = AsyncReadAdapter::new(UnpolledStream);
        let outcome =
            poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192)).await;
        let err = outcome.expect_err("must error on empty sink");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// `max == 0` must be rejected with `InvalidInput` even when the sink has room.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_max_zero_returns_invalid_input() {
        let mut sink = StubSink::with_capacity(64);
        let mut adapter = AsyncReadAdapter::new(UnpolledStream);
        let outcome = poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 0)).await;
        let err = outcome.expect_err("must error on max == 0");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[cfg(feature = "nexus")]
    impl nexus_async_rt::AsyncRead for UnpolledStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            panic!("UnpolledStream::poll_read should not be reached")
        }
    }

    #[cfg(feature = "nexus")]
    impl nexus_async_rt::AsyncWrite for UnpolledStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &[u8],
        ) -> Poll<io::Result<usize>> {
            panic!("unreached")
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
    }

    /// Polls `f` exactly once with a no-op waker and returns its output.
    ///
    /// Panics if that single poll returns `Pending`.
    #[cfg(feature = "nexus")]
    fn block_on<F: std::future::Future>(f: F) -> F::Output {
        use std::task::{RawWaker, RawWakerVTable, Waker};

        const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        fn noop(_: *const ()) {}
        fn noop_clone(p: *const ()) -> RawWaker {
            RawWaker::new(p, &VTABLE)
        }

        let raw = RawWaker::new(std::ptr::null(), &VTABLE);
        // SAFETY: every vtable entry is a no-op that never dereferences the
        // data pointer, and `noop_clone` hands back a waker that uses this
        // same vtable, so the null data pointer is never read.
        let waker = unsafe { Waker::from_raw(raw) };
        let mut cx = Context::from_waker(&waker);
        let mut f = std::pin::pin!(f);
        let Poll::Ready(v) = f.as_mut().poll(&mut cx) else {
            panic!("precondition error must be synchronous")
        };
        v
    }

    /// A sink with no spare bytes must be rejected with `InvalidInput`.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_empty_spare_returns_invalid_input() {
        let mut sink = StubSink::with_capacity(0);
        let mut adapter = NexusAsyncReadAdapter::new(UnpolledStream);
        let outcome = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192)
        }));
        let err = outcome.expect_err("must error on empty sink");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// `max == 0` must be rejected with `InvalidInput` even when the sink has room.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_max_zero_returns_invalid_input() {
        let mut sink = StubSink::with_capacity(64);
        let mut adapter = NexusAsyncReadAdapter::new(UnpolledStream);
        let outcome = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 0)
        }));
        let err = outcome.expect_err("must error on max == 0");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
}