use async_backtrace::frame;
use async_backtrace::framed;
use futures::FutureExt;
use futures::future::BoxFuture;
use futures::stream;
use futures::stream::StreamExt;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future;
use std::panic;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use tokio::time;
use tracing::Instrument;
use tracing::Span;
use tracing::error;
use tracing::error_span;
use tracing::info;
/// Boxed `'static` future with no output; this is what a wrapped test function
/// hands back to the runner.
type TestFuture = BoxFuture<'static, ()>;
/// Boxed function that receives a fixture of type `F` and returns the
/// [`TestFuture`] to execute.
type TestFn<F> = Box<dyn Fn(F) -> TestFuture>;
/// Outcome of a test run, reduced to the names of the tests that failed.
#[derive(Debug, Default)]
pub(crate) struct RunReport {
// Names of the failed tests; an empty list means the run succeeded.
failed_tests: Vec<String>,
}
impl RunReport {
    /// Renders the failed test names through [`format_failed_tests`].
    ///
    /// Returns `None` when no test failed.
    pub(crate) fn failed_tests_summary(&self) -> Option<String> {
        let names = self.failed_tests.as_slice();
        format_failed_tests(names)
    }

    /// Returns `true` when the report contains no failed tests.
    pub(crate) fn is_success(&self) -> bool {
        let failures = &self.failed_tests;
        failures.is_empty()
    }
}
impl From<Statistics> for RunReport {
fn from(stats: Statistics) -> Self {
Self {
failed_tests: stats.failed_tests,
}
}
}
/// Formats failed test names as a bulleted list under a `Failed tests:` header.
///
/// Each name is placed on its own line prefixed with `- `; there is no
/// trailing newline. Returns `None` when `failed_tests` is empty.
pub(crate) fn format_failed_tests(failed_tests: &[String]) -> Option<String> {
    // `split_first` yields `None` for an empty slice, which is the "no failures" case.
    let (first, rest) = failed_tests.split_first()?;
    let mut summary = format!("Failed tests:\n- {first}");
    for failed_test in rest {
        summary.push_str("\n- ");
        summary.push_str(failed_test);
    }
    Some(summary)
}
/// Counters collected while running tests, plus the names of failed tests.
#[derive(Debug)]
pub(crate) struct Statistics {
    /// Number of steps expected to run.
    total: usize,
    /// Number of steps that were actually started.
    launched: usize,
    /// Number of steps that completed successfully.
    ok: usize,
    /// Number of steps that failed.
    failed: usize,
    /// Names of the failed steps, in the order they were recorded.
    failed_tests: Vec<String>,
}

impl Statistics {
    /// Creates statistics expecting `total` steps, with every counter at zero.
    fn new(total: usize) -> Self {
        let (launched, ok, failed) = (0, 0, 0);
        Self {
            total,
            launched,
            ok,
            failed,
            failed_tests: Vec::new(),
        }
    }

    /// Adds every counter of `other` to `self` and appends copies of its
    /// failed test names after the ones already recorded.
    fn append(&mut self, other: &Self) {
        let Self {
            total,
            launched,
            ok,
            failed,
            failed_tests,
        } = other;
        self.total += total;
        self.launched += launched;
        self.ok += ok;
        self.failed += failed;
        self.failed_tests.extend_from_slice(failed_tests);
    }

    /// Counts one more failure and remembers the name of the failed test.
    fn record_failure(&mut self, failed_test: impl Into<String>) {
        self.failed += 1;
        let name: String = failed_test.into();
        self.failed_tests.push(name);
    }
}
/// A collection of tests that all receive a clone of the same fixture `F`,
/// with an optional step run before the tests and an optional step run after.
pub struct TestCase<F>
where
F: Clone + Send + Sync + 'static,
{
// Optional step executed before any test: its timeout and its function.
init: Option<(Duration, TestFn<F>)>,
// Registered tests as (name, timeout, function), in registration order.
tests: Vec<(String, Duration, TestFn<F>)>,
// Optional step executed after the tests: its timeout and its function.
cleanup: Option<(Duration, TestFn<F>)>,
}
impl<F> TestCase<F>
where
    F: Clone + Send + Sync + 'static,
{
    /// Creates a test case with no init step, no tests and no cleanup step.
    pub fn empty() -> Self {
        Self {
            init: None,
            tests: Vec::new(),
            cleanup: None,
        }
    }

    /// Returns the registered tests as `(name, timeout, function)` entries.
    pub fn tests(&self) -> &Vec<(String, Duration, TestFn<F>)> {
        &self.tests
    }

    /// Sets the init step, replacing any previously configured one.
    ///
    /// `timeout` bounds the execution of the future returned by `test_fn`.
    pub fn with_init<TF, R>(self, timeout: Duration, test_fn: TF) -> Self
    where
        TF: Fn(F) -> R + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        Self {
            init: Some((timeout, wrap_test_fn(test_fn))),
            ..self
        }
    }

    /// Appends a test called `name` whose future must finish within `timeout`.
    pub fn with_test<TF, R>(mut self, name: impl ToString, timeout: Duration, test_fn: TF) -> Self
    where
        TF: Fn(F) -> R + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        let entry = (name.to_string(), timeout, wrap_test_fn(test_fn));
        self.tests.push(entry);
        self
    }

    /// Sets the cleanup step, replacing any previously configured one.
    ///
    /// `timeout` bounds the execution of the future returned by `test_fn`.
    pub fn with_cleanup<TF, R>(self, timeout: Duration, test_fn: TF) -> Self
    where
        TF: Fn(F) -> R + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        Self {
            cleanup: Some((timeout, wrap_test_fn(test_fn))),
            ..self
        }
    }

    /// Runs the init step, the selected tests and the cleanup step, in that order.
    ///
    /// * `fixture` - cloned and handed to every step.
    /// * `test_case_name` - prefix used when recording failures
    ///   (`<test_case_name>::<step>`).
    /// * `test_cases` - names of the tests to run; an empty set selects all of them.
    /// * `backtrace` - shared snapshot passed to `run_single` for failure reporting.
    ///
    /// A failing init step is recorded and ends the run immediately: neither the
    /// tests nor the cleanup step are executed. Test and cleanup failures are
    /// recorded and the run continues.
    #[framed]
    async fn run(
        &self,
        fixture: F,
        test_case_name: &str,
        test_cases: &HashSet<String>,
        backtrace: Backtrace,
    ) -> Statistics {
        // With a non-empty filter the expected count is the filter size, not the
        // number of registered tests that actually match it.
        let expected_tests = if test_cases.is_empty() {
            self.tests.len()
        } else {
            test_cases.len()
        };
        let extra_steps = usize::from(self.init.is_some()) + usize::from(self.cleanup.is_some());
        let mut stats = Statistics::new(expected_tests + extra_steps);

        if let Some((timeout, init)) = &self.init {
            stats.launched += 1;
            let init_ok = run_single(
                error_span!("init"),
                *timeout,
                init(fixture.clone()),
                backtrace.clone(),
            )
            .await;
            if !init_ok {
                stats.record_failure(format!("{test_case_name}::init"));
                return stats;
            }
            stats.ok += 1;
        }

        let selected = self
            .tests
            .iter()
            .filter(|(name, _, _)| test_cases.is_empty() || test_cases.contains(name));
        // `then` awaits each test before pulling the next one, so tests run one
        // at a time in registration order.
        stream::iter(selected)
            .then(|(name, timeout, test)| {
                let test_fixture = fixture.clone();
                let backtrace = backtrace.clone();
                async move {
                    let span = error_span!("test", name);
                    let passed = run_single(span, *timeout, test(test_fixture), backtrace).await;
                    (name, passed)
                }
            })
            .for_each(|(name, passed)| {
                stats.launched += 1;
                if passed {
                    stats.ok += 1;
                } else {
                    stats.record_failure(format!("{test_case_name}::{name}"));
                }
                future::ready(())
            })
            .await;

        if let Some((timeout, cleanup)) = &self.cleanup {
            stats.launched += 1;
            let cleanup_ok = run_single(
                error_span!("cleanup"),
                *timeout,
                cleanup(fixture.clone()),
                backtrace.clone(),
            )
            .await;
            if cleanup_ok {
                stats.ok += 1;
            } else {
                stats.record_failure(format!("{test_case_name}::cleanup"));
            }
        }

        stats
    }
}
/// Adapts a test function into a boxed [`TestFn`].
///
/// The returned closure calls `test_fn` with the fixture and boxes the future
/// it produces, so test functions with different future types can be stored
/// side by side.
fn wrap_test_fn<TF, R, F>(test_fn: TF) -> TestFn<F>
where
    TF: Fn(F) -> R + 'static,
    R: Future<Output = ()> + Send + 'static,
{
    Box::new(move |fixture: F| test_fn(fixture).boxed())
}
#[framed]
async fn run_single(
span: Span,
timeout: Duration,
future: TestFuture,
backtrace: Backtrace,
) -> bool {
let task = tokio::spawn(frame!(
async move {
time::timeout(timeout, future)
.await
.expect("test timed out");
}
.instrument(span.clone())
));
if let Err(err) = task.await {
let backtrace = backtrace.get();
error!(parent: &span, "test failed: {err}\n{backtrace}");
false
} else {
info!(parent: &span, "test ok");
true
}
}
#[framed]
pub(crate) async fn run<F>(
fixture: F,
test_cases: Vec<(String, TestCase<F>)>,
filter_map: Arc<HashMap<String, HashSet<String>>>,
) -> RunReport
where
F: Clone + Send + Sync + 'static,
{
let backtrace = setup_panic_hook();
let stats = stream::iter(test_cases.into_iter())
.filter(|(file_name, _)| {
let process = filter_map.is_empty() || filter_map.contains_key(file_name);
async move { process }
})
.then(|(name, test_case)| {
let fixture = fixture.clone();
let filter = filter_map.clone();
let file_name = name.clone();
let backtrace = backtrace.clone();
async move {
let stats = test_case
.run(
fixture,
&file_name,
filter.get(&file_name).unwrap_or(&HashSet::new()),
backtrace,
)
.instrument(error_span!("test-case", name))
.await;
if stats.failed > 0 {
error!("test case failed: {stats:?}");
} else {
info!("test case ok: {stats:?}");
}
stats
}
})
.fold(Statistics::new(0), |mut acc, stats| async move {
acc.append(&stats);
acc
})
.await;
clear_panic_hook();
if stats.failed > 0 {
error!("test run failed: {stats:?}");
return stats.into();
}
info!("test run ok: {stats:?}");
stats.into()
}
/// Shared, mutable text slot holding the most recently captured backtrace.
///
/// Clones refer to the same underlying string.
#[derive(Clone, Debug)]
struct Backtrace(Arc<RwLock<String>>);

impl Backtrace {
    /// Creates a slot containing an empty string.
    fn new() -> Self {
        Self(Arc::default())
    }

    /// Returns a copy of the stored text.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    fn get(&self) -> String {
        let guard = self.0.read().unwrap();
        guard.as_str().to_owned()
    }
}
/// Installs a panic hook that stores an async task dump before delegating to
/// the hook that was previously registered.
///
/// Returns the [`Backtrace`] slot the hook writes into; every panic overwrites
/// its content with the output of `async_backtrace::taskdump_tree(true)`.
fn setup_panic_hook() -> Backtrace {
    let backtrace = Backtrace::new();
    let captured = backtrace.clone();
    let previous_hook = panic::take_hook();
    panic::set_hook(Box::new(move |info| {
        *captured.0.write().unwrap() = async_backtrace::taskdump_tree(true);
        // Keep the previous reporting behavior (e.g. printing the panic message).
        previous_hook(info);
    }));
    backtrace
}
/// Unregisters the current panic hook and discards it.
///
/// NOTE: `panic::take_hook` leaves the *default* hook registered, so a custom
/// hook that was active before `setup_panic_hook` ran is not reinstated.
fn clear_panic_hook() {
    drop(panic::take_hook());
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// The summary is a header line followed by one `- name` line per failure.
    #[test]
    fn formats_failed_tests_as_bulleted_list() {
        let failed_tests = ["crud::boom", "crud::cleanup"].map(String::from);
        let expected = String::from("Failed tests:\n- crud::boom\n- crud::cleanup");
        assert_eq!(format_failed_tests(&failed_tests), Some(expected));
    }

    /// A panicking test and a panicking cleanup step are both recorded, each
    /// prefixed with the test case name.
    #[tokio::test]
    async fn collects_failed_test_names() {
        let timeout = Duration::from_secs(1);
        // An empty filter selects every registered test.
        let no_filter = HashSet::new();
        let test_case = TestCase::empty()
            .with_test("ok", timeout, |_actors| async {})
            .with_test("boom", timeout, |_actors| async {
                panic!("boom");
            })
            .with_cleanup(timeout, |_actors| async {
                panic!("cleanup");
            });

        let stats = test_case
            .run((), "crud", &no_filter, Backtrace::new())
            .await;

        assert_eq!(stats.failed, 2);
        let expected = vec![String::from("crud::boom"), String::from("crud::cleanup")];
        assert_eq!(stats.failed_tests, expected);
    }
}