// testtools 0.1.3
//
// Helpers for eliminating boilerplate code in tests.
// Documentation
//! Synchronization primitives for threads and async tasks.

use std::{
  future::Future,
  hash::Hash,
  pin::Pin,
  sync::{
    atomic::{AtomicUsize, Ordering},
    Arc
  },
  task::{Context, Poll, Waker},
  thread,
  time::Duration
};

use parking_lot::{Condvar, Mutex};

use hashbrown::{HashMap, HashSet};

/// Mutex-protected state shared by every handle to a `Checkpoint`.
struct Inner<T> {
  /// Checkpoints that have been reported as reached.
  checkpoints: HashSet<T>,
  /// Wakers of pending `WaitForCheckpoints` futures, keyed by the id each
  /// future allocated for itself.
  wakers: HashMap<usize, Waker>
}

impl<T> Inner<T> {
  /// Build an empty state: nothing reached, no wakers registered.
  fn new() -> Self {
    let checkpoints = HashSet::new();
    let wakers = HashMap::new();
    Self { checkpoints, wakers }
  }
}

/// State reference-counted (via `Arc`) between all clones of a `Checkpoint`
/// and the futures created from it.
struct Shared<T> {
  /// Reached checkpoints plus registered async wakers.
  inner: Mutex<Inner<T>>,
  /// Optional pause applied by the blocking `Checkpoint` methods.
  delay: Option<Duration>,
  /// Wakes threads blocked in `Checkpoint::waitfor`.
  signal: Condvar,
  /// Source of ids used to key entries in `Inner::wakers`.
  idgen: AtomicUsize
}

impl<T> Shared<T> {
  /// Wrap `inner` in a mutex and pair it with a fresh condvar and id counter.
  fn new(inner: Inner<T>, delay: Option<Duration>) -> Self {
    let inner = Mutex::new(inner);
    let signal = Condvar::new();
    let idgen = AtomicUsize::new(0);
    Self {
      inner,
      delay,
      signal,
      idgen
    }
  }

  /// Wake every thread waiting on the condvar, then remove and wake every
  /// registered async waker.
  ///
  /// `inner` is expected to be the state borrowed from `self.inner`'s guard.
  fn wake_all(&self, inner: &mut Inner<T>) {
    self.signal.notify_all();
    inner.wakers.drain().for_each(|(_id, waker)| waker.wake());
  }
}

/// The `Checkpoint` is used as a rudimentary synchronization mechanism.
///
/// It is intended to be used primarily for threaded tests where timing can
/// cause non-deterministic call order.
#[derive(Clone)]
pub struct Checkpoint<T>(Arc<Shared<T>>);

impl<T> Checkpoint<T>
where
  T: PartialEq + Eq + Hash
{
  pub fn new() -> Self {
    Self::default()
  }

  pub fn with_delay(dur: Duration) -> Self {
    let inner = Inner::new();
    let shared = Shared::new(inner, Some(dur));
    Self(Arc::new(shared))
  }

  /// Mark a checkpoint as having been reached.
  pub fn reached(&self, checkpoint: T) -> &Self {
    let mut inner = self.0.inner.lock();
    inner.checkpoints.insert(checkpoint);
    self.0.wake_all(&mut inner);

    drop(inner);

    if let Some(dur) = self.0.delay {
      thread::sleep(dur);
    }

    self
  }

  /// Wait for checkpoint(s) to be reached.
  ///
  /// All of the supplied checkpoints must have been reached before this
  /// function returns.
  pub fn waitfor<I>(&self, checkpoints: I) -> &Self
  where
    I: IntoIterator<Item = T>
  {
    let checkpoints = HashSet::from_iter(checkpoints);
    let mut inner = self.0.inner.lock();
    while !checkpoints.is_subset(&inner.checkpoints) {
      self.0.signal.wait(&mut inner);
    }
    drop(inner);

    if let Some(dur) = self.0.delay {
      thread::sleep(dur);
    }

    self
  }

  /// Return a `Future` that will resolve once the specified checkpoint(s) have
  /// been reached.
  ///
  /// All of the supplied checkpoints must have been reached before this
  /// function returns.
  pub fn async_waitfor<I>(&self, checkpoints: I) -> WaitForCheckpoints<T>
  where
    I: IntoIterator<Item = T>
  {
    WaitForCheckpoints {
      sh: Arc::clone(&self.0),
      checkpoints: HashSet::from_iter(checkpoints),
      id: None
    }
  }
}

impl<T> Default for Checkpoint<T> {
  fn default() -> Self {
    let inner = Inner::new();
    let shared = Shared::new(inner, None);
    Self(Arc::new(shared))
  }
}

/// Future used to wait for checkpoints to be reached in async tasks.
///
/// Created by `Checkpoint::async_waitfor`.
pub struct WaitForCheckpoints<T: PartialEq + Eq + Hash> {
  /// Shared state of the `Checkpoint` this future was created from.
  sh: Arc<Shared<T>>,
  /// Checkpoints that must all be reached before the future resolves.
  checkpoints: HashSet<T>,
  /// Key of this future's entry in the shared waker map, set once the
  /// future has registered a waker.
  id: Option<usize>
}

/// Resolves once every checkpoint in `checkpoints` has been reached.
///
/// While pending, the task's waker is stored in the shared waker map under a
/// single id owned by this future; `Checkpoint::reached` drains and wakes all
/// stored wakers, after which the next poll re-checks the condition.
impl<T> Future for WaitForCheckpoints<T>
where
  T: PartialEq + Eq + Hash + Unpin
{
  type Output = ();
  fn poll(
    mut self: Pin<&mut Self>,
    ctx: &mut Context<'_>
  ) -> Poll<Self::Output> {
    let mut inner = self.sh.inner.lock();

    if self.checkpoints.is_subset(&inner.checkpoints) {
      // Done: make sure no entry of ours lingers in the map, and forget the
      // id so `Drop` has nothing left to clean up.
      if let Some(id) = self.id {
        inner.wakers.remove(&id);
      }
      drop(inner);
      self.id = None;
      return Poll::Ready(());
    }

    // Reuse the id from an earlier poll so repeated (possibly spurious)
    // polls overwrite this future's single entry with the latest waker,
    // rather than piling up stale wakers that `Drop` would never remove.
    let id = match self.id {
      Some(id) => id,
      None => loop {
        let id = self.sh.idgen.fetch_add(1, Ordering::SeqCst);
        if !inner.wakers.contains_key(&id) {
          break id;
        }
      }
    };
    inner.wakers.insert(id, ctx.waker().clone());
    drop(inner);

    self.id = Some(id);
    Poll::Pending
  }
}

impl<T> Drop for WaitForCheckpoints<T>
where
  T: Eq + Hash
{
  /// Remove this future's waker entry (if it ever registered one) from the
  /// shared map so it is not woken after the future is gone.
  fn drop(&mut self) {
    if let Some(registered) = self.id.take() {
      self.sh.inner.lock().wakers.remove(&registered);
    }
  }
}

// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 :