use tokio_executor::enter;
use pin_convert::AsPinMut;
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
/// Runs `f` exactly once with a `Context` backed by a fresh [`MockTask`].
///
/// The closure only needs to be callable once, so any `FnOnce` is accepted
/// (every `Fn`/`FnMut` closure still works, keeping existing callers valid).
///
/// # Panics
/// Panics if `tokio_executor::enter()` fails inside [`MockTask::enter`]
/// (presumably when the thread is already inside an executor — not verified here).
pub fn mock<F, R>(f: F) -> R
where
    F: FnOnce(&mut Context<'_>) -> R,
{
    let mut task = MockTask::new();
    // `enter` accepts the same closure shape, so hand `f` over directly.
    task.enter(f)
}
/// A hand-driven task context for polling futures in tests.
///
/// Clones share the same underlying waker state, because the `Arc` is cloned
/// rather than the state itself.
#[derive(Clone, Debug)]
pub struct MockTask {
    /// Shared wake-tracking state handed out to every `Waker` this task creates.
    waker: Arc<ThreadWaker>,
}
/// A heap-pinned future bundled with the [`MockTask`] used to poll it.
#[derive(Debug)]
pub struct Spawn<T> {
    /// Context used for every poll of `future`.
    task: MockTask,
    /// The future under test, pinned so it can be polled repeatedly.
    future: Pin<Box<T>>,
}
/// Pins `task` on the heap and pairs it with a fresh [`MockTask`] so it can be
/// polled manually via [`Spawn::poll`].
pub fn spawn<T>(task: T) -> Spawn<T> {
    let pinned = Box::pin(task);
    let driver = MockTask::new();
    Spawn {
        task: driver,
        future: pinned,
    }
}
/// Wake-tracking state shared between a [`MockTask`] and the `Waker`s it creates.
#[derive(Debug)]
struct ThreadWaker {
    /// One of the `IDLE` / `WAKE` / `SLEEP` state constants.
    state: Mutex<usize>,
    /// Notified by `wake` when the previous state was `SLEEP`.
    condvar: Condvar,
}
/// No wakeup recorded since the state was last cleared.
const IDLE: usize = 0;
/// A wakeup has been recorded.
const WAKE: usize = IDLE + 1;
/// State in which `ThreadWaker::wake` also notifies the condvar.
const SLEEP: usize = WAKE + 1;
impl<T: Future> Spawn<T> {
pub fn poll(&mut self) -> Poll<T::Output> {
let fut = self.future.as_mut();
self.task.enter(|cx| fut.poll(cx))
}
pub fn is_woken(&self) -> bool {
self.task.is_woken()
}
pub fn waker_ref_count(&self) -> usize {
self.task.waker_ref_count()
}
}
impl MockTask {
    /// Creates a task whose shared waker state starts out un-woken.
    pub fn new() -> Self {
        let waker = Arc::new(ThreadWaker::new());
        MockTask { waker }
    }

    /// Polls `fut` once inside this task's context.
    ///
    /// `fut` is anything that can hand out a pinned mutable reference to the
    /// future `F` (via `AsPinMut`).
    pub fn poll<T, F>(&mut self, mut fut: T) -> Poll<F::Output>
    where
        T: AsPinMut<F>,
        F: Future,
    {
        self.enter(|cx| {
            let pinned = fut.as_pin_mut();
            pinned.poll(cx)
        })
    }

    /// Runs `f` with a `Context` whose waker reports back to this task.
    ///
    /// The woken flag is cleared before `f` runs, so `is_woken` afterwards only
    /// reflects wakeups from this call onward.
    ///
    /// # Panics
    /// Panics if `tokio_executor::enter()` returns an error (presumably when the
    /// thread is already inside an executor — not verified here).
    pub fn enter<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut Context<'_>) -> R,
    {
        // Hold the value from `enter()` for the whole call; binding it to a
        // bare `_` would drop it immediately.
        let _executor_guard = enter().unwrap();
        self.waker.clear();
        let task_waker = self.waker();
        let mut context = Context::from_waker(&task_waker);
        f(&mut context)
    }

    /// Whether a wakeup has been recorded since the shared state was last
    /// cleared (each `enter` clears it).
    pub fn is_woken(&self) -> bool {
        self.waker.is_woken()
    }

    /// Strong reference count of the shared state: this task, its clones, and
    /// every outstanding `Waker` created from it.
    pub fn waker_ref_count(&self) -> usize {
        Arc::strong_count(&self.waker)
    }

    /// Builds a `Waker` that owns its own strong reference to the shared state.
    fn waker(&self) -> Waker {
        let shared = Arc::clone(&self.waker);
        // SAFETY: `to_raw` pairs an `Arc::into_raw` pointer with `VTABLE`, whose
        // functions all treat the data pointer as exactly such an `Arc<ThreadWaker>`.
        unsafe { Waker::from_raw(to_raw(shared)) }
    }
}
impl Default for MockTask {
fn default() -> Self {
Self::new()
}
}
impl ThreadWaker {
fn new() -> Self {
ThreadWaker {
state: Mutex::new(IDLE),
condvar: Condvar::new(),
}
}
fn clear(&self) {
*self.state.lock().unwrap() = IDLE;
}
fn is_woken(&self) -> bool {
match *self.state.lock().unwrap() {
IDLE => false,
WAKE => true,
_ => unreachable!(),
}
}
fn wake(&self) {
let mut state = self.state.lock().unwrap();
let prev = *state;
if prev == WAKE {
return;
}
*state = WAKE;
if prev == IDLE {
return;
}
assert_eq!(prev, SLEEP);
self.condvar.notify_one();
}
}
/// Function table for every `Waker` built by `to_raw`. The data pointer is always
/// an `Arc<ThreadWaker>` converted with `Arc::into_raw`.
static VTABLE: RawWakerVTable = RawWakerVTable::new(
    clone,       // adds one strong reference
    wake,        // consumes the reference
    wake_by_ref, // leaves the reference untouched
    drop,        // releases the reference
);
/// Converts an owned `Arc<ThreadWaker>` into a `RawWaker` that carries that
/// strong reference.
///
/// # Safety
/// The returned `RawWaker` owns one strong reference. It must eventually be
/// released through `VTABLE`'s `wake` or `drop`, or it leaks.
unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
    let data = Arc::into_raw(waker) as *const ();
    RawWaker::new(data, &VTABLE)
}
/// Reconstructs the `Arc<ThreadWaker>` behind a raw waker data pointer.
///
/// # Safety
/// `raw` must have come from `to_raw`, i.e. from `Arc::into_raw` on an
/// `Arc<ThreadWaker>`. The returned `Arc` takes over one strong reference.
unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
    let typed = raw as *const ThreadWaker;
    Arc::from_raw(typed)
}
unsafe fn clone(raw: *const ()) -> RawWaker {
let waker = from_raw(raw);
mem::forget(waker.clone());
to_raw(waker)
}
/// Vtable `wake`: records a wakeup, then releases the reference owned by `raw`.
///
/// # Safety
/// `raw` must be a live data pointer produced by `to_raw`, and is consumed here.
unsafe fn wake(raw: *const ()) {
    from_raw(raw).wake();
}
/// Vtable `wake_by_ref`: records a wakeup without touching the reference count.
///
/// The reconstructed `Arc` is only borrowed. It is wrapped in `ManuallyDrop` so
/// it is never released here, even if `ThreadWaker::wake` panics (a poisoned
/// mutex `unwrap` or its `assert_eq!`). The previous `mem::forget` after the
/// call would be skipped on unwind. That would drop a strong reference the
/// caller still owns and could lead to a use-after-free.
///
/// # Safety
/// `raw` must be a live data pointer produced by `to_raw`; ownership stays with
/// the caller.
unsafe fn wake_by_ref(raw: *const ()) {
    let waker = mem::ManuallyDrop::new(from_raw(raw));
    waker.wake();
}
unsafe fn drop(raw: *const ()) {
let _ = from_raw(raw);
}