mod util;
use crate::guard::{SlotIdx, StateNodeIdx, state_node_iter};
use crate::memory::MemoryBacking;
use crate::polling::{PollableSlotQuery, SlotStatus};
use crate::tests::util::Racer;
use crate::{EngineFields, SharedStorageHeader, SlotMemory, SlotPoller, StackSlots};
use cache_padded::CachePadded;
use futures_buffered::FuturesUnordered;
use futures_util::StreamExt;
use manual_future::ManualFuture;
use slotpoller_test_util::{DropVerify, FutureVerify};
use std::cell::Cell;
use std::collections::HashSet;
use std::future::poll_fn;
use std::hint::black_box;
use std::pin::{Pin, pin};
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::thread;
// Consistency checker for pollers whose futures are wrapped in `FutureVerify`.
impl<F, M> SlotPoller<FutureVerify<F>, M>
where
    F: Future,
    M: SlotMemory<FutureVerify<F>>,
{
    /// Cross-checks the poller's internal bookkeeping and panics on the first inconsistency.
    ///
    /// Three things are verified:
    /// 1. the empty-slot stack (followed via `empty_link`) has no cycle and holds exactly
    ///    `slots.len() - slots_active` entries;
    /// 2. the poll queue (followed via `poll_queue_link`) has no cycle, contains the stub node
    ///    at most once, and ends at the recorded tail;
    /// 3. every state node's status agrees with its membership in those two structures.
    ///
    /// # Panics
    ///
    /// Panics if any invariant is violated. Before the panic is propagated, the `Debug`
    /// representation of `self` is printed to stdout.
    fn assert_integrity(&mut self) {
        // Slot indices gathered while walking one linked structure: `items` keeps them in
        // visiting order, `seen_before` allows duplicate detection and membership queries.
        struct OrderedColl {
            label: &'static str,
            items: Vec<SlotIdx>,
            seen_before: HashSet<SlotIdx>,
        }
        impl OrderedColl {
            // Creates an empty collection; `label` is only used in assertion messages.
            fn new(label: &'static str) -> Self {
                Self {
                    label,
                    items: Vec::new(),
                    seen_before: HashSet::new(),
                }
            }
            // Records `slot_idx`; visiting the same index twice means the walked links loop.
            fn add_idx(&mut self, slot_idx: SlotIdx) {
                assert!(
                    self.seen_before.insert(slot_idx),
                    "Circular queue/stack detected"
                );
                self.items.push(slot_idx);
            }
            // Panics unless `slot_idx` was recorded by `add_idx`.
            fn assert_contains(&self, slot_idx: SlotIdx) {
                assert!(
                    self.seen_before.contains(&slot_idx),
                    "Expected {:?} to be inside {}",
                    slot_idx,
                    self.label
                );
            }
            // Panics if `slot_idx` was recorded by `add_idx`.
            fn assert_not_contains(&self, slot_idx: SlotIdx) {
                assert!(
                    !self.seen_before.contains(&slot_idx),
                    "Expected {:?} to NOT be inside {}",
                    slot_idx,
                    self.label
                );
            }
        }
        // Run all checks under `catch_unwind` so that the poller state can be dumped before
        // the original panic is re-raised at the bottom of this function.
        let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let EngineFields {
                activity,
                mut slots,
                shared_storage,
            } = self.memory.fields();
            // --- Empty-slot stack: follow `empty_link` starting at `empty_head`. ---
            let mut empty_stack = OrderedColl::new("empty_stack");
            {
                let mut current = activity.empty_head;
                while current.is_set() {
                    let current_idx = current.into_slot_idx();
                    empty_stack.add_idx(current_idx);
                    let slot = current_idx.get_slot(slots.as_mut());
                    current = slot.empty_link;
                }
                // Every slot not counted as active must be reachable from the empty stack.
                assert_eq!(
                    empty_stack.items.len(),
                    slots.len() - (activity.slots_active as usize)
                );
            }
            // --- Poll queue: follow `poll_queue_link` starting at `poll_queue_head`. ---
            let mut pollable_queue = OrderedColl::new("pollable_queue");
            {
                let head = activity.poll_queue_head;
                // NOTE(review): Relaxed loads here and below presumably rely on no waker
                // mutating the queue while the check runs — confirm against callers.
                let tail = shared_storage
                    .header
                    .poll_queue_tail
                    .load(Ordering::Relaxed);
                assert_ne!(head, StateNodeIdx::UNSET);
                assert_ne!(tail, StateNodeIdx::UNSET);
                // NOTE(review): `found_stub` only rejects a second stub; nothing below
                // asserts that the stub was encountered at all.
                let mut found_stub = false;
                let mut current = head;
                let mut last = current;
                while current != StateNodeIdx::UNSET {
                    // SAFETY: NOTE(review) — `current` is either the queue head or a link read
                    // from a previous node; presumably that makes it a valid node index for
                    // `shared_storage`. Confirm against `get_state_node`'s contract.
                    let current_node = unsafe {
                        current.get_state_node(shared_storage)
                    };
                    if current == StateNodeIdx::STUB {
                        assert!(!found_stub, "Cannot find stub more than once");
                        found_stub = true;
                    } else {
                        // Only non-stub nodes are recorded as pollable slots.
                        pollable_queue.add_idx(current_node.slot_idx);
                    }
                    last = current;
                    current = current_node.poll_queue_link.load(Ordering::Relaxed);
                }
                assert_eq!(tail, last, "Tail should be last of the chain");
                // The node table holds exactly one more entry than there are slots.
                assert_eq!(slots.len() + 1, shared_storage.header.nodes_len as usize);
            }
            // --- Per-node status must match membership in the two structures above. ---
            for (idx, state_node) in state_node_iter(shared_storage) {
                assert_eq!(idx, state_node.slot_idx);
                let slot = idx.get_slot(slots.as_mut());
                let status = state_node.status.load(Ordering::Relaxed);
                // `has_future` is true for the statuses that are checked to be absent from
                // the empty stack.
                let has_future = match status {
                    SlotStatus::Uninit => {
                        empty_stack.assert_contains(idx);
                        pollable_queue.assert_not_contains(idx);
                        false
                    }
                    SlotStatus::UninitButEnqueued => {
                        empty_stack.assert_contains(idx);
                        pollable_queue.assert_contains(idx);
                        false
                    }
                    SlotStatus::Waiting => {
                        empty_stack.assert_not_contains(idx);
                        pollable_queue.assert_not_contains(idx);
                        true
                    }
                    SlotStatus::Woken | SlotStatus::Init => {
                        empty_stack.assert_not_contains(idx);
                        pollable_queue.assert_contains(idx);
                        true
                    }
                };
                if has_future {
                    // SAFETY: NOTE(review) — relies on the invariant that a slot whose status
                    // is `Waiting`, `Woken` or `Init` holds an initialized future; that is the
                    // invariant under test, so confirm it against the code that writes slots.
                    let future = unsafe {
                        slot.future.assume_init_ref()
                    };
                    // Presumably keeps the reference (and thus the read) from being
                    // optimized away — TODO confirm intent.
                    black_box(future);
                }
            }
        }));
        if let Err(panic) = panic {
            // Dump the poller for diagnosis, then continue unwinding with the original payload.
            println!("Before panicking, self is {:?}", self);
            std::panic::resume_unwind(panic)
        }
    }
}
/// Exercises a 5-slot poller with futures that complete on their first poll.
///
/// Each future yields the current value of a shared counter and increments it. The test
/// runs three rounds — draining a full poller, refilling through vacancies and collecting
/// completions, then evicting from a saturated poller — and checks the poller's integrity
/// after every round. Finally it verifies that every wrapped future was dropped.
#[test]
fn completed_future() {
    let mut drop_verify = DropVerify::default();
    {
        let counter = Cell::new(1u32);
        // Produces a future that returns the counter's current value and bumps it by one.
        let mut next_future = || {
            FutureVerify::with_drop_detect(
                async {
                    let current = counter.get();
                    counter.set(current + 1);
                    current
                },
                &mut drop_verify,
            )
        };
        let slots = pin!(StackSlots::<5, _>::new());
        let mut poller = SlotPoller::new(slots);

        // Round 1: fill every slot, then drain all completions (1 + 2 + 3 + 4 + 5).
        for _ in 0..5 {
            poller.try_push(next_future()).unwrap();
        }
        let drained_total = pollster::block_on(async {
            let mut total = 0;
            poller
                .drain(|completion| {
                    println!("Received completion {:?}", completion);
                    total += completion;
                })
                .await;
            total
        });
        assert_eq!(15, drained_total);
        poller.assert_integrity();

        // Round 2: the poller is empty, so every vacancy must come without a completion.
        // With the counter reset to 2 the completions are 2 + 3 + 4 + 5 + 6.
        counter.set(2);
        let refilled_total = pollster::block_on(async {
            for _ in 0..5 {
                let (evicted, vacancy) = poller.next_vacancy().await;
                assert!(evicted.is_none(), "Poller should have free space");
                vacancy.insert(next_future());
            }
            let mut total = 0;
            for _ in 0..5 {
                total += poller.next_completion().unwrap().await;
            }
            total
        });
        assert_eq!(20, refilled_total);
        poller.assert_integrity();

        // Round 3: the poller is full, so every vacancy must first yield a completion.
        // The futures are lazy, so resetting the counter after pushing still applies.
        for _ in 0..5 {
            poller.try_push(next_future()).unwrap();
        }
        counter.set(0);
        let evicted_total = pollster::block_on(async {
            let mut total = 0;
            for _ in 0..5 {
                let (evicted, vacancy) = poller.next_vacancy().await;
                total += evicted.expect("Poller should be saturated");
                vacancy.insert(next_future());
            }
            total
        });
        assert_eq!(10, evicted_total);
        poller.assert_integrity();
    }
    drop_verify.verify();
}
/// Hammers a single slot's waker from many threads at once.
///
/// The pushed future stores the waker it is polled with in `waker_store`, signals
/// availability through `waker_avail`, races five `wake_by_ref` calls against each other,
/// and then polls the backing `ManualFuture` with a no-op waker, so only the explicit
/// wake calls can cause a re-poll. A helper thread later completes the backing future and
/// races twenty more wake calls through the stored waker. The poller must yield the
/// completed value `3` and remain internally consistent.
#[test]
fn race_to_wake_future() {
    let mut drop_verify = DropVerify::default();
    pollster::block_on(async {
        let slots = pin!(StackSlots::<1, _>::new());
        let mut poller = SlotPoller::new(slots);
        let waker_store = Arc::new(Mutex::new(None::<Waker>));
        let waker_avail = Arc::new(AtomicBool::new(false));
        let (backing, completable) = ManualFuture::new();
        let future = {
            let waker_store = waker_store.clone();
            let waker_avail = waker_avail.clone();
            let mut backing = Box::pin(backing);
            FutureVerify::with_drop_detect(
                poll_fn(move |cx| {
                    let waker = cx.waker();
                    {
                        // Publish (or refresh) the waker; the guard is released at the end
                        // of this scope, before the availability flag is raised.
                        let mut waker_store = waker_store.lock().unwrap();
                        if let Some(prev_waker) = waker_store.as_mut() {
                            prev_waker.clone_from(waker);
                        } else {
                            waker_store.replace(waker.clone());
                        }
                    }
                    waker_avail.store(true, Ordering::Release);
                    let mut racer = Racer::default();
                    for _ in 0..5 {
                        racer.add_task(|| waker.wake_by_ref());
                    }
                    racer.execute();
                    // Polled with a no-op waker on purpose: completion alone must not wake
                    // the slot, only the racing `wake_by_ref` calls do.
                    backing
                        .as_mut()
                        .poll(&mut Context::from_waker(&Waker::noop()))
                }),
                &mut drop_verify,
            )
        };
        poller.try_push(future).unwrap();
        let completing_thread = thread::spawn(move || {
            thread::sleep(std::time::Duration::from_millis(200));
            println!("Finished sleeping for 200 millis");
            // Wait for the first poll to publish a waker BEFORE taking the mutex. The poll
            // closure needs that same mutex to store the waker and only then raises
            // `waker_avail`; spinning while holding the lock would deadlock whenever this
            // thread wins the race to the mutex.
            while !waker_avail.load(Ordering::Acquire) {
                thread::yield_now();
            }
            let waker_store = waker_store.lock().unwrap();
            let waker = waker_store.as_ref().unwrap();
            let mut racer = Racer::default();
            for _ in 0..20 {
                racer.add_task(|| waker.wake_by_ref());
            }
            pollster::block_on(completable.complete(3));
            racer.execute();
        });
        let (res, _vacancy) = poller.next_vacancy().await;
        assert_eq!(Some(3), res);
        completing_thread.join().unwrap();
        poller.assert_integrity();
    });
    drop_verify.verify();
}
/// Races many slots becoming pollable at the same time.
///
/// `COUNT` futures are pushed; each one, when polled, races five `wake_by_ref` calls on
/// its own waker before polling its backing `ManualFuture`. A helper thread then completes
/// all backing futures concurrently, handing each a distinct value taken from a shared
/// atomic counter. The test collects `COUNT` completions, checks their sum, and finally
/// checks the poller's integrity and that every wrapped future was dropped.
#[test]
fn race_to_enqueue_as_pollable() {
    let mut drop_verify = DropVerify::default();
    pollster::block_on(async {
        const COUNT: usize = 10;
        let slots = pin!(StackSlots::<COUNT, _>::new());
        let mut poller = SlotPoller::new(slots);

        // One completer per pushed future, kept so the helper thread can finish them later.
        let mut completers = Vec::with_capacity(COUNT);
        for _ in 0..COUNT {
            let (backing, completer) = ManualFuture::new();
            let mut backing = Box::pin(backing);
            let racing_future = poll_fn(move |cx| {
                let waker = cx.waker();
                let mut wake_race = Racer::default();
                for _ in 0..5 {
                    wake_race.add_task(|| waker.wake_by_ref());
                }
                wake_race.execute();
                backing.as_mut().poll(cx)
            });
            let verified = FutureVerify::with_drop_detect(racing_future, &mut drop_verify);
            poller.try_push(verified).unwrap();
            completers.push(completer);
        }

        let next_value = Arc::new(AtomicU32::new(0));
        let completer_thread = thread::spawn(move || {
            thread::sleep(std::time::Duration::from_millis(200));
            let mut complete_race = Racer::default();
            for completer in completers {
                let next_value = Arc::clone(&next_value);
                complete_race.add_task(move || {
                    // Each task claims a unique value 0..COUNT from the shared counter.
                    let value = next_value.fetch_add(1, Ordering::AcqRel);
                    pollster::block_on(completer.complete(value));
                });
            }
            complete_race.execute();
        });

        let mut total = 0;
        for _ in 0..COUNT {
            total += poller.next_completion().unwrap().await;
        }
        // 0 + 1 + ... + 9
        assert_eq!(45, total);
        completer_thread.join().unwrap();
        poller.assert_integrity();
    });
    drop_verify.verify();
}
/// A future that is one of two concrete future types sharing the same output type.
///
/// Lets two differently-typed futures be stored in a single homogeneous collection.
enum EitherFut<F1, F2> {
    One(F1),
    Two(F2),
}

impl<F1, F2, O> Future for EitherFut<F1, F2>
where
    F1: Future<Output = O>,
    F2: Future<Output = O>,
{
    type Output = O;

    /// Forwards the poll to whichever variant is active.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the reference is only used to re-pin the active variant's payload in
        // place; nothing is moved out of or swapped within `self`.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            EitherFut::One(inner) => {
                // SAFETY: `inner` lives inside the pinned `self` and is never moved.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                pinned.poll(cx)
            }
            EitherFut::Two(inner) => {
                // SAFETY: `inner` lives inside the pinned `self` and is never moved.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                pinned.poll(cx)
            }
        }
    }
}
/// Checks that a `drain` future keeps working when the waker changes between polls.
///
/// * `task1` is the future stored in the poller; it completes once `connector_fut` does.
/// * `task2` builds the poller, polls `drain` once with a no-op waker (which must return
///   `Pending`), then awaits it normally, i.e. with the executor's waker. It sets
///   `test_pass` once the drain finishes.
/// * `task3` releases `task2`, waits until it has started, completes the connector, yields
///   ten times to give `task2` a chance to run, and then asserts that `test_pass` was set.
///
/// `task2` and `task3` are driven concurrently on one thread through `FuturesUnordered`.
#[test]
fn change_waker_across_polls() {
    let (start_fut, start_completer) = ManualFuture::new();
    let (connector_fut, connector_completer) = ManualFuture::new();
    let test_started = Rc::new(Cell::new(false));
    let test_pass = Rc::new(Cell::new(false));

    let task1 = async move {
        connector_fut.await;
    };

    let task2 = {
        let test_started = test_started.clone();
        let test_pass = test_pass.clone();
        async move {
            start_fut.await;
            test_started.set(true);
            let stack_slots = pin!(StackSlots::<2, _>::new());
            let mut poller = SlotPoller::new(stack_slots);
            poller.try_push(task1).unwrap();
            let drain = poller.drain(|()| {});
            let mut drain = pin!(drain);
            // First poll uses a throwaway waker; the connector is not completed yet, so the
            // drain cannot be finished at this point.
            assert!(
                drain
                    .as_mut()
                    .poll(&mut Context::from_waker(&Waker::noop()))
                    .is_pending(),
                "broken test setup"
            );
            // Subsequent polls use the executor's waker.
            drain.await;
            test_pass.set(true);
        }
    };

    // `task3` is the last user of the flags, so it takes ownership of the original handles.
    let task3 = async move {
        start_completer.complete(()).await;
        // Busy-yield until `task2` has been polled past its start gate.
        poll_fn(|cx| {
            if test_started.get() {
                Poll::Ready(())
            } else {
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
        connector_completer.complete(()).await;
        // Yield to the executor ten times before checking the outcome.
        let mut yields = 0;
        poll_fn(|cx| {
            if yields < 10 {
                yields += 1;
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        })
        .await;
        assert!(
            test_pass.get(),
            "drain did not complete after its waker changed between polls"
        );
    };

    pollster::block_on(async move {
        let mut all = FuturesUnordered::new();
        all.push(EitherFut::One(task2));
        all.push(EitherFut::Two(task3));
        while all.next().await.is_some() {}
    });
}
/// Pins down the memory footprint of two layout-sensitive types.
///
/// `SharedStorageHeader` must be exactly as large as a `CachePadded<usize>`, and
/// `PollableSlotQuery` must be exactly two machine words.
#[test]
fn optimal_layout() {
    let padded_word_size = size_of::<CachePadded<usize>>();
    let header_size = size_of::<SharedStorageHeader>();
    assert_eq!(
        padded_word_size, header_size,
        "Shared storage header should fit into single cache line"
    );

    let two_words = 2 * size_of::<usize>();
    let query_size = size_of::<PollableSlotQuery<'_>>();
    assert_eq!(
        two_words, query_size,
        "PollableSlotQuery should use NPO optimization"
    );
}