#![cfg(not(loom))]
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc;
use tokio::time::{self, Duration, Instant, Sleep};
use tokio_stream::wrappers::UnboundedReceiverStream;
use futures_core::Stream;
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, ready, Poll, Waker};
use std::{cmp, io};
/// An I/O object that plays back a predefined script of actions.
///
/// Reads are served from scripted `read` data. Writes are compared against
/// scripted `write` data and panic on a byte mismatch. Scripted waits and
/// errors are surfaced in order.
///
/// Create one with [`Builder`]. Dropping a `Mock` that still holds unread or
/// unwritten scripted data panics, unless the thread is already panicking.
#[derive(Debug)]
pub struct Mock {
    // All playback state; `Mock` itself only adds the trait impls.
    inner: Inner,
}
/// Appends actions to a [`Mock`]'s script after it has been built.
///
/// Actions are sent over a channel. The mock appends them to the back of its
/// queue when it polls that channel for more work.
#[derive(Debug)]
pub struct Handle {
    // Sending side of the channel whose receiver lives in `Inner::rx`.
    tx: mpsc::UnboundedSender<Action>,
}
/// Scripts the reads, writes, waits and errors that a [`Mock`] will play back.
///
/// Actions are played in the order they are added. Building clones the
/// script, so one builder can produce several mocks.
#[derive(Debug, Clone, Default)]
pub struct Builder {
    // Script in playback order.
    actions: VecDeque<Action>,
    // Optional label shown in panic messages; empty when unset.
    name: String,
}
/// One step of a mock's script.
#[derive(Debug, Clone)]
enum Action {
    /// Bytes handed out to reads; drained as they are read.
    Read(Vec<u8>),
    /// Bytes that writes must match exactly; drained as they are matched.
    Write(Vec<u8>),
    /// Pause before the next action (clamped to at least 1 ms by `Builder::wait`).
    Wait(Duration),
    /// Error returned by a read. It becomes `None` once handed out. The `Arc`
    /// lets `Action` derive `Clone`, since `io::Error` is not `Clone`.
    ReadError(Option<Arc<io::Error>>),
    /// Error returned by a write; same representation as `ReadError`.
    WriteError(Option<Arc<io::Error>>),
}
/// Playback state used by both the `AsyncRead` and `AsyncWrite` impls of `Mock`.
struct Inner {
    /// Scripted actions still to be played, front first.
    actions: VecDeque<Action>,
    /// Deadline of the `Wait` at the front of `actions`. `action()` records it
    /// when that wait first reaches the front and clears it once it has passed.
    waiting: Option<Instant>,
    /// Timer armed by `poll_read` / `poll_write` while a `Wait` is pending.
    sleep: Option<Pin<Box<Sleep>>>,
    /// Waker of a reader parked in `poll_read`; woken by `maybe_wakeup_reader`.
    read_wait: Option<Waker>,
    /// Actions sent through a `Handle`; appended to `actions` when polled.
    rx: UnboundedReceiverStream<Action>,
    /// Label used in panic messages and `Debug` output; empty when unset.
    name: String,
}
impl Builder {
pub fn new() -> Self {
Self::default()
}
pub fn read(&mut self, buf: &[u8]) -> &mut Self {
self.actions.push_back(Action::Read(buf.into()));
self
}
pub fn read_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.actions.push_back(Action::ReadError(error));
self
}
pub fn write(&mut self, buf: &[u8]) -> &mut Self {
self.actions.push_back(Action::Write(buf.into()));
self
}
pub fn write_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.actions.push_back(Action::WriteError(error));
self
}
pub fn wait(&mut self, duration: Duration) -> &mut Self {
let duration = cmp::max(duration, Duration::from_millis(1));
self.actions.push_back(Action::Wait(duration));
self
}
pub fn name(&mut self, name: impl Into<String>) -> &mut Self {
self.name = name.into();
self
}
pub fn build(&mut self) -> Mock {
let (mock, _) = self.build_with_handle();
mock
}
pub fn build_with_handle(&mut self) -> (Mock, Handle) {
let (inner, handle) = Inner::new(self.actions.clone(), self.name.clone());
let mock = Mock { inner };
(mock, handle)
}
}
impl Handle {
pub fn read(&mut self, buf: &[u8]) -> &mut Self {
self.tx.send(Action::Read(buf.into())).unwrap();
self
}
pub fn read_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.tx.send(Action::ReadError(error)).unwrap();
self
}
pub fn write(&mut self, buf: &[u8]) -> &mut Self {
self.tx.send(Action::Write(buf.into())).unwrap();
self
}
pub fn write_error(&mut self, error: io::Error) -> &mut Self {
let error = Some(error.into());
self.tx.send(Action::WriteError(error)).unwrap();
self
}
}
impl Inner {
    /// Creates playback state seeded with `actions`. Also returns a [`Handle`]
    /// whose sender feeds further actions into `rx`.
    fn new(actions: VecDeque<Action>, name: String) -> (Inner, Handle) {
        let (tx, rx) = mpsc::unbounded_channel();
        let rx = UnboundedReceiverStream::new(rx);
        let inner = Inner {
            actions,
            sleep: None,
            read_wait: None,
            rx,
            waiting: None,
            name,
        };
        let handle = Handle { tx };
        (inner, handle)
    }

    /// Polls the handle channel for the next action sent through a `Handle`.
    ///
    /// Returns `Ready(None)` when the stream ends (presumably once every
    /// sender has been dropped).
    fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Action>> {
        Pin::new(&mut self.rx).poll_next(cx)
    }

    /// Serves a read from the front of the queue.
    ///
    /// * `Read`: copies as many bytes as fit into `dst` and drains them.
    /// * `ReadError`: returns the scripted error (consuming it).
    /// * `Wait` or `Write`: returns `WouldBlock`, so the caller can sleep or park.
    /// * Empty queue: returns `Ok(())` without filling `dst`.
    fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
        match self.action() {
            Some(&mut Action::Read(ref mut data)) => {
                let n = cmp::min(dst.remaining(), data.len());
                dst.put_slice(&data[..n]);
                data.drain(..n);
                Ok(())
            }
            Some(&mut Action::ReadError(ref mut err)) => Err(Self::take_error(err)),
            Some(_) => Err(io::ErrorKind::WouldBlock.into()),
            None => Ok(()),
        }
    }

    /// Matches `src` against the scripted `Write` actions.
    ///
    /// Returns one of:
    /// * the number of bytes of `src` that matched expected write data;
    /// * `BrokenPipe` when nothing at all is queued;
    /// * `WouldBlock` when a `Wait` is at the front;
    /// * the scripted error when a `WriteError` is at the front.
    ///
    /// Matching walks forward through the queue, so one call may consume
    /// several consecutive `Write` actions. `Read` / `ReadError` entries are
    /// skipped. The walk stops at the first `Wait` or `WriteError`.
    ///
    /// # Panics
    /// Panics if the written bytes differ from the expected bytes.
    fn write(&mut self, mut src: &[u8]) -> io::Result<usize> {
        if self.actions.is_empty() {
            return Err(io::ErrorKind::BrokenPipe.into());
        }
        match self.action() {
            Some(&mut Action::Wait(..)) => return Err(io::ErrorKind::WouldBlock.into()),
            Some(&mut Action::WriteError(ref mut err)) => return Err(Self::take_error(err)),
            _ => {}
        }

        let mut ret = 0;
        for (i, action) in self.actions.iter_mut().enumerate() {
            match action {
                Action::Write(expect) => {
                    let n = cmp::min(src.len(), expect.len());
                    assert_eq!(&src[..n], &expect[..n], "name={} i={}", self.name, i);
                    // Drop the expected bytes that were just matched.
                    expect.drain(..n);
                    src = &src[n..];
                    ret += n;
                    if src.is_empty() {
                        return Ok(ret);
                    }
                }
                // Later writes are not matched past a pause or a scripted error.
                Action::Wait(..) | Action::WriteError(..) => break,
                // Pending reads do not stop write matching.
                Action::Read(..) | Action::ReadError(..) => {}
            }
        }
        Ok(ret)
    }

    /// Returns how long the `Wait` at the front of the queue still has to run.
    /// Returns `None` if the front action is not a `Wait`.
    ///
    /// The remaining time is measured against the deadline `action()` recorded
    /// in `waiting` when the wait first reached the front. A late caller
    /// therefore does not restart the full wait.
    fn remaining_wait(&mut self) -> Option<Duration> {
        if !matches!(self.action(), Some(Action::Wait(_))) {
            return None;
        }
        // `action()` only yields a `Wait` after recording its deadline.
        self.waiting
            .map(|until| until.saturating_duration_since(Instant::now()))
    }

    /// Pops finished actions and returns the first unfinished one.
    ///
    /// These actions count as finished:
    /// * `Read` / `Write` with no bytes left;
    /// * error actions whose error was already handed out;
    /// * `Wait`s whose deadline has passed.
    ///
    /// The first time a `Wait` reaches the front, its deadline
    /// (`now + duration`) is recorded in `waiting`.
    ///
    /// Returns `None` once the queue is empty.
    fn action(&mut self) -> Option<&mut Action> {
        loop {
            if self.actions.is_empty() {
                return None;
            }
            match self.actions[0] {
                Action::Read(ref mut data) => {
                    if !data.is_empty() {
                        break;
                    }
                }
                Action::Write(ref mut data) => {
                    if !data.is_empty() {
                        break;
                    }
                }
                Action::Wait(ref mut dur) => {
                    if let Some(until) = self.waiting {
                        let now = Instant::now();
                        if now < until {
                            break;
                        } else {
                            self.waiting = None;
                        }
                    } else {
                        self.waiting = Some(Instant::now() + *dur);
                        break;
                    }
                }
                Action::ReadError(ref mut error) | Action::WriteError(ref mut error) => {
                    if error.is_some() {
                        break;
                    }
                }
            }
            let _action = self.actions.pop_front();
        }
        self.actions.front_mut()
    }

    /// Moves a scripted error out of its action slot.
    ///
    /// `Action` is `Clone`, and `Builder::build_with_handle` clones the whole
    /// script. The builder may therefore still hold another `Arc` to the same
    /// error. In that case an equivalent error with the same kind and message
    /// is returned instead of panicking.
    fn take_error(slot: &mut Option<Arc<io::Error>>) -> io::Error {
        // `action()` only surfaces an error action while its slot is `Some`.
        let shared = slot.take().expect("Should have been removed from actions.");
        Arc::try_unwrap(shared)
            .unwrap_or_else(|shared| io::Error::new(shared.kind(), shared.to_string()))
    }
}
impl Mock {
    /// Wakes a reader parked in `poll_read` when a read could now make progress.
    ///
    /// A read can progress when the front action is a `Read` or `ReadError`,
    /// or when the script is exhausted. Called after a write completes.
    fn maybe_wakeup_reader(&mut self) {
        let reader_unblocked = matches!(
            self.inner.action(),
            None | Some(Action::Read(_)) | Some(Action::ReadError(_))
        );
        if !reader_unblocked {
            return;
        }
        if let Some(parked) = self.inner.read_wait.take() {
            parked.wake();
        }
    }
}
impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
if let Some(ref mut sleep) = self.inner.sleep {
ready!(Pin::new(sleep).poll(cx));
}
self.inner.sleep = None;
let filled = buf.filled().len();
match self.inner.read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Some(rem) = self.inner.remaining_wait() {
let until = Instant::now() + rem;
self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
} else {
self.inner.read_wait = Some(cx.waker().clone());
return Poll::Pending;
}
}
Ok(()) => {
if buf.filled().len() == filled {
match ready!(self.inner.poll_action(cx)) {
Some(action) => {
self.inner.actions.push_back(action);
continue;
}
None => {
return Poll::Ready(Ok(()));
}
}
} else {
return Poll::Ready(Ok(()));
}
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}
}
/// Writes are matched against the scripted `write` actions. Flush and
/// shutdown always succeed immediately.
impl AsyncWrite for Mock {
    /// Checks `buf` against the expected write data.
    ///
    /// Returns the number of matched bytes, or an error (a scripted
    /// `WriteError`, or `BrokenPipe` when nothing is queued). Byte mismatches
    /// panic inside `Inner::write`. A `Wait` at the front is honoured by
    /// sleeping, then retrying.
    ///
    /// # Panics
    /// * `unexpected write`: the handle channel has ended and no actions remain.
    /// * `unexpected WouldBlock`: `Inner::write` reported `WouldBlock` but no
    ///   `Wait` was found at the front afterwards.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            // Finish any pending `Wait` before matching more data.
            if let Some(ref mut sleep) = self.inner.sleep {
                ready!(Pin::new(sleep).poll(cx));
            }
            // Reaching here means the sleep (if any) has completed.
            self.inner.sleep = None;
            // With nothing queued, try to pull an action sent through a `Handle`.
            if self.inner.actions.is_empty() {
                match self.inner.poll_action(cx) {
                    Poll::Pending => {
                        // Deliberately not propagated. The queue is still empty,
                        // so `Inner::write` below reports `BrokenPipe`, which is
                        // returned to the caller as an error.
                    }
                    Poll::Ready(Some(action)) => {
                        self.inner.actions.push_back(action);
                    }
                    Poll::Ready(None) => {
                        panic!("unexpected write {}", self.pmsg());
                    }
                }
            }
            match self.inner.write(buf) {
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    // A `Wait` is at the front: arm a timer and loop to poll it.
                    if let Some(rem) = self.inner.remaining_wait() {
                        let until = Instant::now() + rem;
                        self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
                    } else {
                        // NOTE(review): reachable for a scripted `WriteError` of kind
                        // `WouldBlock`. It is also reachable if the `Wait` expires (and
                        // `action()` pops it) between `write` and `remaining_wait`.
                        // Confirm a panic is wanted there.
                        panic!("unexpected WouldBlock {}", self.pmsg());
                    }
                }
                Ok(0) => {
                    // Nothing matched. Examples: an empty `buf`, or only reads
                    // queued before the next write.
                    // NOTE(review): the `Pending` below registers no waker on this
                    // path, and nothing in this file stores a writer's waker, so the
                    // task may never be polled again. Confirm this is intended.
                    if !self.inner.actions.is_empty() {
                        return Poll::Pending;
                    }
                    // Queue drained: wait for (or pull) the next action from the handle.
                    match ready!(self.inner.poll_action(cx)) {
                        Some(action) => {
                            self.inner.actions.push_back(action);
                            continue;
                        }
                        None => {
                            panic!("unexpected write {}", self.pmsg());
                        }
                    }
                }
                ret => {
                    // Matched bytes or an error. The queue may now expose a read,
                    // so give a parked reader a chance to run.
                    self.maybe_wakeup_reader();
                    return Poll::Ready(ret);
                }
            }
        }
    }
    /// Always reports success immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
    /// Always reports success immediately.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
impl Drop for Mock {
    /// Asserts that every scripted read and write was fully consumed.
    ///
    /// The check is skipped while the thread is already panicking, because a
    /// second panic during unwinding would abort the process. Waits and
    /// unconsumed errors are not checked.
    fn drop(&mut self) {
        if !std::thread::panicking() {
            for action in &self.inner.actions {
                match action {
                    Action::Read(data) => assert!(
                        data.is_empty(),
                        "There is still data left to read. {}",
                        self.pmsg()
                    ),
                    Action::Write(data) => assert!(
                        data.is_empty(),
                        "There is still data left to write. {}",
                        self.pmsg()
                    ),
                    Action::Wait(_) | Action::ReadError(_) | Action::WriteError(_) => {}
                }
            }
        }
    }
}
impl fmt::Debug for Inner {
    /// Prints an abbreviated form that shows only the mock's name, if set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.name.as_str() {
            "" => write!(f, "Inner {{...}}"),
            label => write!(f, "Inner {{name={}, ...}}", label),
        }
    }
}
/// Suffix appended to this mock's panic messages. It shows how many actions
/// remain queued and, when set, the mock's name.
struct PanicMsgSnippet<'a>(&'a Inner);

impl fmt::Display for PanicMsgSnippet<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = self.0;
        let remaining = inner.actions.len();
        match inner.name.as_str() {
            "" => write!(f, "({} actions remain)", remaining),
            label => write!(f, "(name {}, {} actions remain)", label, remaining),
        }
    }
}
impl Mock {
    /// Returns the diagnostic suffix used in this mock's panic messages.
    fn pmsg(&self) -> PanicMsgSnippet<'_> {
        let state = &self.inner;
        PanicMsgSnippet(state)
    }
}