tokio-test 0.4.2

Testing utilities for Tokio- and futures-based code
Documentation
//! Task-based helpers for testing futures and streams.

#![allow(clippy::mutex_atomic)]

use std::future::Future;
use std::mem;
use std::ops;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

use tokio_stream::Stream;

/// Wraps `task` in a [`Spawn`] driven by a fresh mock task.
///
/// The value is moved into a pinned heap allocation, so it does not need to
/// be `Unpin`. Nothing is polled here; polling happens only when a method
/// such as `Spawn::poll`, `Spawn::poll_next` or `Spawn::enter` is called.
pub fn spawn<T>(task: T) -> Spawn<T> {
    let mock = MockTask::new();
    Spawn {
        task: mock,
        future: Box::pin(task),
    }
}

/// Future spawned on a mock task
///
/// Created by `spawn`. It owns the spawned value in a pinned box together
/// with the `MockTask` whose waker is handed to the value when it is polled
/// through `Spawn::poll`, `Spawn::poll_next` or `Spawn::enter`.
#[derive(Debug)]
pub struct Spawn<T> {
    /// Mock task that records wake notifications for the spawned value.
    task: MockTask,
    /// The spawned value, pinned on the heap so `T` need not be `Unpin`.
    future: Pin<Box<T>>,
}

/// Mock task
///
/// A mock task is able to intercept and track wake notifications.
/// Because the state lives behind an `Arc`, clones share the same
/// `ThreadWaker`.
#[derive(Debug, Clone)]
struct MockTask {
    /// Shared wake state; its `Arc::strong_count` is what `waker_ref_count`
    /// reports.
    waker: Arc<ThreadWaker>,
}

/// Wake state shared between a `MockTask` and every `Waker` built from it.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Notified by `wake` when the previous state was `SLEEP`.
    // NOTE(review): no visible code waits on this condvar or sets `SLEEP`.
    condvar: Condvar,
}

/// No wake has been recorded since the last `clear`.
const IDLE: usize = 0;
/// A wake has been recorded.
const WAKE: usize = 1;
/// State in which `wake` also notifies `ThreadWaker::condvar`.
// NOTE(review): no visible code ever stores this value.
const SLEEP: usize = 2;

impl<T> Spawn<T> {
    /// Consumes `self` returning the inner value
    pub fn into_inner(self) -> T
    where
        T: Unpin,
    {
        *Pin::into_inner(self.future)
    }

    /// Returns `true` if the inner future has received a wake notification
    /// since the last call to `enter`.
    pub fn is_woken(&self) -> bool {
        self.task.is_woken()
    }

    /// Returns the number of references to the task waker
    ///
    /// The task itself holds a reference. The return value will never be zero.
    pub fn waker_ref_count(&self) -> usize {
        self.task.waker_ref_count()
    }

    /// Enter the task context
    pub fn enter<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
    {
        let fut = self.future.as_mut();
        self.task.enter(|cx| f(cx, fut))
    }
}

impl<T: Unpin> ops::Deref for Spawn<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.future
    }
}

impl<T: Unpin> ops::DerefMut for Spawn<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.future
    }
}

impl<T: Future> Spawn<T> {
    /// Polls a future
    pub fn poll(&mut self) -> Poll<T::Output> {
        let fut = self.future.as_mut();
        self.task.enter(|cx| fut.poll(cx))
    }
}

impl<T: Stream> Spawn<T> {
    /// Polls a stream
    pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
        let stream = self.future.as_mut();
        self.task.enter(|cx| stream.poll_next(cx))
    }
}

impl<T: Future> Future for Spawn<T> {
    type Output = T::Output;

    /// Forwards to the inner future using the caller's `cx`, not the mock
    /// task's waker.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Spawn<T>` is `Unpin` (its fields are an `Arc` and a `Pin<Box<T>>`).
        let this = self.get_mut();
        this.future.as_mut().poll(cx)
    }
}

impl<T: Stream> Stream for Spawn<T> {
    type Item = T::Item;

    /// Forwards to the inner stream using the caller's `cx`, not the mock
    /// task's waker.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Spawn<T>` is `Unpin` (its fields are an `Arc` and a `Pin<Box<T>>`).
        let this = self.get_mut();
        this.future.as_mut().poll_next(cx)
    }
}

impl MockTask {
    /// Creates a mock task with fresh, idle wake state.
    fn new() -> Self {
        let waker = Arc::new(ThreadWaker::new());
        Self { waker }
    }

    /// Runs `f` with a `Context` whose waker reports back to this task.
    ///
    /// Any wake recorded earlier is cleared first, so afterwards `is_woken`
    /// reflects only wakes issued from this call onward.
    fn enter<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut Context<'_>) -> R,
    {
        self.waker.clear();
        let task_waker = self.waker();
        let mut context = Context::from_waker(&task_waker);
        f(&mut context)
    }

    /// Whether a wake has been recorded since the last `enter`.
    fn is_woken(&self) -> bool {
        self.waker.is_woken()
    }

    /// Strong reference count of the shared wake state.
    ///
    /// This task holds one reference itself, so the result is never zero.
    fn waker_ref_count(&self) -> usize {
        Arc::strong_count(&self.waker)
    }

    /// Builds a `Waker` that owns its own reference to this task's state.
    fn waker(&self) -> Waker {
        // SAFETY: `to_raw` receives an owned clone of the `Arc`, so the raw
        // waker carries exactly one strong reference, released via `VTABLE`
        // when the `Waker` is woken by value or dropped.
        unsafe { Waker::from_raw(to_raw(Arc::clone(&self.waker))) }
    }
}

impl Default for MockTask {
    fn default() -> Self {
        Self::new()
    }
}

impl ThreadWaker {
    /// Creates wake state in the `IDLE` state.
    fn new() -> Self {
        ThreadWaker {
            state: Mutex::new(IDLE),
            condvar: Condvar::new(),
        }
    }

    /// Resets the state to `IDLE`, discarding any earlier wake so it cannot
    /// show up as a spurious notification. Call this only immediately before
    /// running the task.
    fn clear(&self) {
        let mut state = self.state.lock().unwrap();
        *state = IDLE;
    }

    /// Whether a wake has been recorded since the last `clear`.
    ///
    /// # Panics
    ///
    /// Panics if the state is neither `IDLE` nor `WAKE`.
    fn is_woken(&self) -> bool {
        let guard = self.state.lock().unwrap();
        match *guard {
            WAKE => true,
            IDLE => false,
            _ => unreachable!(),
        }
    }

    /// Records a wake notification; always takes the state lock.
    ///
    /// `WAKE` is left unchanged and `IDLE` becomes `WAKE`. From `SLEEP`, the
    /// state becomes `WAKE` and the condvar is notified.
    fn wake(&self) {
        let mut state = self.state.lock().unwrap();
        match *state {
            WAKE => {}
            IDLE => *state = WAKE,
            prev => {
                *state = WAKE;
                // The other half is sleeping on the condvar, so wake it up.
                assert_eq!(prev, SLEEP);
                self.condvar.notify_one();
            }
        }
    }
}

/// Vtable for every `RawWaker` built over an `Arc<ThreadWaker>`.
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);

/// Turns an owned `Arc<ThreadWaker>` into a `RawWaker` that uses `VTABLE`.
///
/// # Safety
///
/// The returned `RawWaker` owns the strong reference carried by `waker`. It
/// should be passed to `Waker::from_raw` (or otherwise end in exactly one
/// call to `wake` or `drop_waker`); otherwise that reference leaks.
unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
    RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
}

/// Reclaims the `Arc<ThreadWaker>` behind a pointer produced by `to_raw`.
///
/// # Safety
///
/// `raw` must come from `to_raw`. If the caller only borrows the reference
/// `raw` represents, it must keep the returned `Arc` from being dropped
/// (e.g. with `ManuallyDrop`).
unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
    Arc::from_raw(raw as *const ThreadWaker)
}

/// `RawWakerVTable` clone entry: returns a `RawWaker` that owns a new strong
/// reference, leaving the one held by `raw` untouched.
///
/// # Safety
///
/// `raw` must be a live pointer produced by `to_raw`.
unsafe fn clone(raw: *const ()) -> RawWaker {
    // `raw` is only borrowed here, so the reclaimed `Arc` must never be dropped.
    let waker = mem::ManuallyDrop::new(from_raw(raw));
    to_raw(Arc::clone(&*waker))
}

/// `RawWakerVTable` wake entry: wakes and consumes the reference held by `raw`.
///
/// # Safety
///
/// `raw` must be a live pointer produced by `to_raw`, owned by the caller.
unsafe fn wake(raw: *const ()) {
    let waker = from_raw(raw);
    waker.wake();
}

/// `RawWakerVTable` wake_by_ref entry: wakes without consuming `raw`.
///
/// # Safety
///
/// `raw` must be a live pointer produced by `to_raw`.
unsafe fn wake_by_ref(raw: *const ()) {
    // We don't own a reference here. `ManuallyDrop` (rather than a trailing
    // `mem::forget`) keeps the `Arc` from being released even if
    // `ThreadWaker::wake` panics and this frame unwinds.
    let waker = mem::ManuallyDrop::new(from_raw(raw));
    waker.wake();
}

/// `RawWakerVTable` drop entry: releases the reference held by `raw`.
///
/// # Safety
///
/// `raw` must be a live pointer produced by `to_raw`, owned by the caller.
unsafe fn drop_waker(raw: *const ()) {
    drop(from_raw(raw));
}