#![cfg(not(loom))]
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc;
use tokio::time::{self, Duration, Instant, Sleep};
use tokio_stream::wrappers::UnboundedReceiverStream;
use futures_core::{ready, Stream};
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Poll, Waker};
use std::{cmp, io};
/// A scripted I/O object.
///
/// `Mock` implements `AsyncRead` and `AsyncWrite` by replaying a queue of
/// scripted actions held in its `Inner` state. When dropped (outside of a
/// panic) it asserts that no scripted read or write data is left over.
#[derive(Debug)]
pub struct Mock {
/// Action queue, timers and wakers backing the read/write implementations.
inner: Inner,
}
/// A handle used to append further actions to a `Mock` after it has been built.
///
/// Actions are delivered over an unbounded channel; the `Mock` picks them up
/// from its poll methods once its current queue yields no progress.
#[derive(Debug)]
pub struct Handle {
/// Sending half of the channel whose receiving half lives in `Inner::rx`.
tx: mpsc::UnboundedSender<Action>,
}
/// Builder that accumulates the scripted actions for a `Mock`.
///
/// The builder methods append to the script in call order; `build` and
/// `build_with_handle` clone the accumulated script into a new `Mock`.
#[derive(Debug, Clone, Default)]
pub struct Builder {
/// Scripted actions, in the order they should be observed by the `Mock`.
actions: VecDeque<Action>,
}
/// One scripted step of a `Mock`.
#[derive(Debug, Clone)]
enum Action {
/// Bytes handed out to readers; drained from the front as they are consumed.
Read(Vec<u8>),
/// Bytes a writer is expected to write; drained from the front as they match.
Write(Vec<u8>),
/// A pause of the given duration before later actions are processed.
Wait(Duration),
/// Error returned from a read. The `Option` is taken (left `None`) when the
/// error is handed out, which marks the action as consumed.
ReadError(Option<Arc<io::Error>>),
/// Error returned from a write. The `Option` is taken (left `None`) when the
/// error is handed out, which marks the action as consumed.
WriteError(Option<Arc<io::Error>>),
}
/// Internal state shared by the `AsyncRead` and `AsyncWrite` implementations
/// of `Mock`.
struct Inner {
/// Pending scripted actions; the front element is the current one.
actions: VecDeque<Action>,
/// Deadline recorded by `action()` when a `Wait` reaches the front of the queue.
waiting: Option<Instant>,
/// Timer armed by the poll methods while a `Wait` is in effect.
sleep: Option<Pin<Box<Sleep>>>,
/// Waker of a reader parked because the current action is not a read;
/// woken by `Mock::maybe_wakeup_reader`.
read_wait: Option<Waker>,
/// Receiving side of the channel fed by `Handle`.
rx: UnboundedReceiverStream<Action>,
}
impl Builder {
pub fn new() -> Self {
Self::default()
}
pub fn read(&mut self, buf: &[u8]) -> &mut Self {
self.actions.push_back(Action::Read(buf.into()));
self
}
pub fn read_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.actions.push_back(Action::ReadError(error));
self
}
pub fn write(&mut self, buf: &[u8]) -> &mut Self {
self.actions.push_back(Action::Write(buf.into()));
self
}
pub fn write_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.actions.push_back(Action::WriteError(error));
self
}
pub fn wait(&mut self, duration: Duration) -> &mut Self {
let duration = cmp::max(duration, Duration::from_millis(1));
self.actions.push_back(Action::Wait(duration));
self
}
pub fn build(&mut self) -> Mock {
let (mock, _) = self.build_with_handle();
mock
}
pub fn build_with_handle(&mut self) -> (Mock, Handle) {
let (inner, handle) = Inner::new(self.actions.clone());
let mock = Mock { inner };
(mock, handle)
}
}
impl Handle {
pub fn read(&mut self, buf: &[u8]) -> &mut Self {
self.tx.send(Action::Read(buf.into())).unwrap();
self
}
pub fn read_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.tx.send(Action::ReadError(error)).unwrap();
self
}
pub fn write(&mut self, buf: &[u8]) -> &mut Self {
self.tx.send(Action::Write(buf.into())).unwrap();
self
}
pub fn write_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.tx.send(Action::WriteError(error)).unwrap();
self
}
}
impl Inner {
fn new(actions: VecDeque<Action>) -> (Inner, Handle) {
let (tx, rx) = mpsc::unbounded_channel();
let rx = UnboundedReceiverStream::new(rx);
let inner = Inner {
actions,
sleep: None,
read_wait: None,
rx,
waiting: None,
};
let handle = Handle { tx };
(inner, handle)
}
fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Action>> {
Pin::new(&mut self.rx).poll_next(cx)
}
fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
match self.action() {
Some(&mut Action::Read(ref mut data)) => {
let n = cmp::min(dst.remaining(), data.len());
dst.put_slice(&data[..n]);
data.drain(..n);
Ok(())
}
Some(&mut Action::ReadError(ref mut err)) => {
let err = err.take().expect("Should have been removed from actions.");
let err = Arc::try_unwrap(err).expect("There are no other references.");
Err(err)
}
Some(_) => {
Err(io::ErrorKind::WouldBlock.into())
}
None => Ok(()),
}
}
fn write(&mut self, mut src: &[u8]) -> io::Result<usize> {
let mut ret = 0;
if self.actions.is_empty() {
return Err(io::ErrorKind::BrokenPipe.into());
}
if let Some(&mut Action::Wait(..)) = self.action() {
return Err(io::ErrorKind::WouldBlock.into());
}
if let Some(&mut Action::WriteError(ref mut err)) = self.action() {
let err = err.take().expect("Should have been removed from actions.");
let err = Arc::try_unwrap(err).expect("There are no other references.");
return Err(err);
}
for i in 0..self.actions.len() {
match self.actions[i] {
Action::Write(ref mut expect) => {
let n = cmp::min(src.len(), expect.len());
assert_eq!(&src[..n], &expect[..n]);
expect.drain(..n);
src = &src[n..];
ret += n;
if src.is_empty() {
return Ok(ret);
}
}
Action::Wait(..) | Action::WriteError(..) => {
break;
}
_ => {}
}
}
Ok(ret)
}
fn remaining_wait(&mut self) -> Option<Duration> {
match self.action() {
Some(&mut Action::Wait(dur)) => Some(dur),
_ => None,
}
}
fn action(&mut self) -> Option<&mut Action> {
loop {
if self.actions.is_empty() {
return None;
}
match self.actions[0] {
Action::Read(ref mut data) => {
if !data.is_empty() {
break;
}
}
Action::Write(ref mut data) => {
if !data.is_empty() {
break;
}
}
Action::Wait(ref mut dur) => {
if let Some(until) = self.waiting {
let now = Instant::now();
if now < until {
break;
}
} else {
self.waiting = Some(Instant::now() + *dur);
break;
}
}
Action::ReadError(ref mut error) | Action::WriteError(ref mut error) => {
if error.is_some() {
break;
}
}
}
let _action = self.actions.pop_front();
}
self.actions.front_mut()
}
}
impl Mock {
    /// Wakes a parked reader, if any, when the current action is a read, a
    /// read error, or the script has run out.
    fn maybe_wakeup_reader(&mut self) {
        let reader_may_proceed = matches!(
            self.inner.action(),
            Some(Action::Read(_)) | Some(Action::ReadError(_)) | None
        );
        if !reader_may_proceed {
            return;
        }

        if let Some(parked) = self.inner.read_wait.take() {
            parked.wake();
        }
    }
}
/// Reads from the scripted actions.
///
/// Completes with data when the current action is a `Read`, with the scripted
/// error when it is a `ReadError`, and with an unfilled buffer once the script
/// is empty and the handle channel has ended. It stays pending while a `Wait`
/// is in effect or while the current action belongs to the write side.
impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
// Drive an armed timer first; `ready!` returns `Pending` until it fires.
if let Some(ref mut sleep) = self.inner.sleep {
ready!(Pin::new(sleep).poll(cx));
}
// Reaching this point means any armed timer has completed.
self.inner.sleep = None;
// Remember the fill level so we can tell whether `read` produced data.
let filled = buf.filled().len();
match self.inner.read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(rem) = self.inner.remaining_wait() {
// Current action is a `Wait`: arm a timer and loop to poll it.
let until = Instant::now() + rem;
self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
} else {
// Current action is not a read: park until `maybe_wakeup_reader`
// wakes us after a write has made progress.
self.inner.read_wait = Some(cx.waker().clone());
return Poll::Pending;
}
}
Ok(()) => {
if buf.filled().len() == filled {
// Nothing was copied: look for more actions sent through the handle.
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
// The handle channel has ended: report completion with no new data.
return Poll::Ready(Ok(()));
}
}
} else {
return Poll::Ready(Ok(()));
}
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}
}
/// Checks written bytes against the scripted `Write` actions.
///
/// Flush and shutdown are no-ops that complete immediately.
impl AsyncWrite for Mock {
/// Matches `buf` against the script and reports how many bytes were accepted.
///
/// # Panics
///
/// Panics with "unexpected WouldBlock" if the inner write reports
/// `WouldBlock` while the current action is not a `Wait`, and with
/// "unexpected write" if nothing could be written, the script is empty and
/// the handle channel has ended. The inner write also panics when the bytes
/// differ from the scripted ones.
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
// Drive an armed timer first; `ready!` returns `Pending` until it fires.
if let Some(ref mut sleep) = self.inner.sleep {
ready!(Pin::new(sleep).poll(cx));
}
// Reaching this point means any armed timer has completed.
self.inner.sleep = None;
match self.inner.write(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(rem) = self.inner.remaining_wait() {
// Current action is a `Wait`: arm a timer and loop to poll it.
let until = Instant::now() + rem;
self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
} else {
panic!("unexpected WouldBlock");
}
}
Ok(0) => {
// Actions remain but none accepted any bytes.
// NOTE(review): this path returns `Pending` without registering a waker
// from `cx` — confirm callers do not rely on being woken here.
if !self.inner.actions.is_empty() {
return Poll::Pending;
}
// The script is empty: look for more actions sent through the handle.
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
panic!("unexpected write");
}
}
}
ret => {
// Progress or an error: a parked reader may now be able to continue.
self.maybe_wakeup_reader();
return Poll::Ready(ret);
}
}
}
}
/// Always completes immediately with `Ok(())`.
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
/// Always completes immediately with `Ok(())`.
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl Drop for Mock {
    /// Asserts that every scripted read and write has been fully consumed.
    ///
    /// The check is skipped while the thread is already panicking, so a failed
    /// test does not trigger a second panic from the destructor.
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }

        for leftover in &self.inner.actions {
            match leftover {
                Action::Read(data) => {
                    assert!(data.is_empty(), "There is still data left to read.")
                }
                Action::Write(data) => {
                    assert!(data.is_empty(), "There is still data left to write.")
                }
                _ => {}
            }
        }
    }
}
impl fmt::Debug for Inner {
    /// Prints an opaque placeholder without showing any fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Inner {...}")
    }
}