use crate::Runtime;
use futures::FutureExt;
use std::{
collections::HashMap,
future::Future,
hash::Hash,
panic::{resume_unwind, AssertUnwindSafe},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tokio::{
sync::watch::{self, Sender},
task::JoinHandle,
};
/// Registry of [`TestRuntime`]s keyed by a caller-chosen label `L`.
///
/// Each distinct label gets its own runtime, created lazily on first lookup,
/// so a test can track completed tasks separately per label. The manager's
/// `Drop` impl reports panics from tasks spawned through its runtimes.
pub struct TestRuntimeManager<L> {
/// One shared runtime handle per label.
map: HashMap<L, TestRuntime>,
}
impl<L> TestRuntimeManager<L> {
pub fn new() -> TestRuntimeManager<L> {
TestRuntimeManager {
map: HashMap::new(),
}
}
}
impl<L> TestRuntimeManager<L>
where
L: Eq + Hash,
{
pub fn with_label(&mut self, label: L) -> TestRuntime {
self.map.entry(label).or_default().clone()
}
pub async fn wait_for_completed_tasks(&mut self, label: L, target_count: usize) {
let labeled_runtime = self.with_label(label);
let mut receiver = labeled_runtime.inner.sender.subscribe();
while *receiver.borrow_and_update() < target_count {
receiver.changed().await.expect(
"The channel sender should not be dropped before waits have finished, this likely \
indicates an issue with a test.",
);
}
}
}
impl<L> Default for TestRuntimeManager<L> {
fn default() -> TestRuntimeManager<L> {
TestRuntimeManager::new()
}
}
impl<L> Drop for TestRuntimeManager<L> {
    /// Fails the test if any task spawned through this manager's runtimes
    /// panicked.
    ///
    /// A task panic otherwise surfaces only through its `JoinHandle`. A test
    /// may never await that handle, or may discard the result
    /// (`let _ = handle.await`). This check makes such panics visible.
    ///
    /// # Panics
    ///
    /// Panics with "An async task panicked" when any runtime's panic flag is
    /// set. The check is skipped if the current thread is already unwinding,
    /// for example after a failed assertion. Panicking again inside `drop`
    /// during unwinding would abort the entire test process and hide the
    /// original failure.
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }
        let any_task_panicked = self
            .map
            .values()
            .any(|runtime| runtime.inner.any_panic.load(Ordering::Acquire));
        if any_task_panicked {
            panic!("An async task panicked");
        }
    }
}
/// Cheaply cloneable handle to a test runtime.
///
/// All clones share the same completed-task counter and panic flag through
/// an `Arc`.
#[derive(Clone)]
pub struct TestRuntime {
inner: Arc<Inner>,
}
/// State shared by every clone of a [`TestRuntime`].
struct Inner {
/// Set once any future spawned through this runtime panics.
any_panic: AtomicBool,
/// Holds the number of spawned tasks that have finished, normally or by
/// panicking. Waiters obtain receivers from it via `subscribe`.
sender: Sender<usize>,
}
impl TestRuntime {
    /// Builds a runtime with no recorded panic and a completed-task count of zero.
    fn new() -> TestRuntime {
        // Keep only the sender; receivers are created on demand with `subscribe`.
        let sender = watch::channel(0).0;
        let inner = Inner {
            any_panic: AtomicBool::new(false),
            sender,
        };
        TestRuntime {
            inner: Arc::new(inner),
        }
    }
}
impl Default for TestRuntime {
fn default() -> TestRuntime {
TestRuntime::new()
}
}
impl Runtime for TestRuntime {
    /// Spawns `future` on the ambient tokio runtime and records its outcome.
    ///
    /// The future runs under `catch_unwind`. Once it finishes, this runtime's
    /// completed-task counter is incremented, whether it returned normally or
    /// panicked. A panic also sets the shared `any_panic` flag and is then
    /// re-raised with `resume_unwind`, so it still propagates out of the task
    /// to the returned [`JoinHandle`].
    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let runtime = self.clone();
        tokio::task::spawn(async move {
            // Arbitrary futures are not `UnwindSafe`. The panic is re-raised
            // below rather than swallowed, so asserting unwind safety is fine.
            let outcome = AssertUnwindSafe(future).catch_unwind().await;
            // Record the panic *before* publishing the completion. Anyone who
            // observes the new count must also observe the flag, e.g. a test
            // that calls `wait_for_completed_tasks` and then drops the manager.
            if outcome.is_err() {
                runtime.inner.any_panic.store(true, Ordering::Release);
            }
            runtime.inner.sender.send_modify(|completed| *completed += 1);
            match outcome {
                Ok(output) => output,
                Err(payload) => resume_unwind(payload),
            }
        })
    }
}
#[cfg(test)]
mod tests {
use super::{Runtime, TestRuntimeManager};
use std::sync::Arc;
use tokio::{sync::Barrier, task::JoinHandle};
// Exercises per-label completion counting. This includes a task spawned
// while building another task's future, with its handle returned through
// that future.
#[tokio::test]
async fn mock_runtime() {
#[derive(PartialEq, Eq, Hash)]
enum Label {
A,
B,
}
// Carries `handle_a_3` out of `handle_a_2`'s future. NOTE(review): presumably
// `#[must_use]` is meant to stop the inner handle being silently discarded —
// confirm intent.
#[must_use]
struct TaskHandleNewtype(JoinHandle<()>);
let mut runtime = TestRuntimeManager::<Label>::new();
let runtime_a = runtime.with_label(Label::A);
let runtime_b = runtime.with_label(Label::B);
// Two parties: this test task and the `handle_a_2` task spawned below.
let barrier = Arc::new(Barrier::new(2));
// One immediately-ready task per label.
let handle_a_1 = runtime_a.spawn(std::future::ready(()));
let handle_b_1 = runtime_b.spawn(std::future::ready(()));
runtime.wait_for_completed_tasks(Label::A, 1).await;
runtime.wait_for_completed_tasks(Label::B, 1).await;
// Counters are tracked independently per label.
assert_eq!(*runtime_a.inner.sender.borrow(), 1);
assert_eq!(*runtime_b.inner.sender.borrow(), 1);
// The block expression spawns `handle_a_3` first. It then produces the future
// for `handle_a_2`, which waits on the barrier before returning `handle_a_3`.
let handle_a_2 = runtime_a.spawn({
let handle_a_3 = runtime_a.spawn(std::future::ready(()));
let barrier = Arc::clone(&barrier);
async move {
barrier.wait().await;
TaskHandleNewtype(handle_a_3)
}
});
// Neither new task is counted yet. NOTE(review): this relies on the spawned
// tasks not having run yet. That holds with #[tokio::test]'s default
// current-thread runtime — re-check if the runtime flavor changes.
assert_eq!(*runtime_a.inner.sender.borrow(), 1);
// Releases `handle_a_2`'s future so it can finish.
barrier.wait().await;
runtime.wait_for_completed_tasks(Label::A, 2).await;
runtime.wait_for_completed_tasks(Label::A, 3).await;
// Waiting for a count that has already been reached returns immediately.
runtime.wait_for_completed_tasks(Label::A, 2).await;
handle_a_1.await.unwrap();
let handle_a_3 = handle_a_2.await.unwrap().0;
handle_a_3.await.unwrap();
handle_b_1.await.unwrap();
assert_eq!(*runtime_a.inner.sender.borrow(), 3);
assert_eq!(*runtime_b.inner.sender.borrow(), 1);
}
// A task panic whose JoinHandle result is discarded must still fail the test.
// The expected panic comes from dropping the manager.
#[tokio::test]
#[should_panic]
async fn noisy_task_panic() {
let mut runtime = TestRuntimeManager::<()>::new();
let handle = runtime.with_label(()).spawn(async {
panic!("panic inside separate task");
});
// Waits for the task to finish and deliberately ignores its JoinError.
let _ = handle.await;
drop(runtime);
}
}