use once_cell::sync::Lazy;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::panic::{self, AssertUnwindSafe};
use std::sync::{Mutex, PoisonError};
/// A fixture callback. The `Send + Sync + 'static` bounds allow fixtures to be
/// stored in process-wide registries and invoked from any test thread.
pub type FixtureFunc = Box<dyn Fn() + Send + Sync + 'static>;

/// Fixture callbacks grouped by the module path they were registered under.
type FixtureRegistry = Lazy<Mutex<HashMap<&'static str, Vec<FixtureFunc>>>>;

/// Setup fixtures, run by `run_test_with_fixtures` before each test body of a module.
static SETUP_FIXTURES: FixtureRegistry = Lazy::new(Default::default);
/// Teardown fixtures, run by `run_test_with_fixtures` after each test body of a module.
static TEARDOWN_FIXTURES: FixtureRegistry = Lazy::new(Default::default);
/// Before-all fixtures, run once per module ahead of its first fixture test.
static BEFORE_ALL_FIXTURES: FixtureRegistry = Lazy::new(Default::default);
/// After-all fixtures, run per recorded module by `run_after_all_fixtures`.
static AFTER_ALL_FIXTURES: FixtureRegistry = Lazy::new(Default::default);
/// Module paths that have been recorded as having run fixture tests; decides
/// which modules' after-all fixtures `run_after_all_fixtures` invokes.
static EXECUTED_MODULES: Lazy<Mutex<HashSet<&'static str>>> = Lazy::new(Default::default);
pub fn register_setup(module_path: &'static str, func: FixtureFunc) {
let mut fixtures = SETUP_FIXTURES.lock().unwrap();
fixtures.entry(module_path).or_default().push(func);
}
pub fn register_teardown(module_path: &'static str, func: FixtureFunc) {
let mut fixtures = TEARDOWN_FIXTURES.lock().unwrap();
fixtures.entry(module_path).or_default().push(func);
}
pub fn register_before_all(module_path: &'static str, func: FixtureFunc) {
let mut fixtures = BEFORE_ALL_FIXTURES.lock().unwrap();
fixtures.entry(module_path).or_default().push(func);
}
pub fn register_after_all(module_path: &'static str, func: FixtureFunc) {
let mut fixtures = AFTER_ALL_FIXTURES.lock().unwrap();
fixtures.entry(module_path).or_default().push(func);
}
thread_local! {
// Per-thread flag: `run_test_with_fixtures` sets it to `true` before running any
// fixture for a test and back to `false` after teardown; `is_in_fixture_test`
// reads it. Being thread-local, concurrent tests on other threads see their own copy.
static IN_FIXTURE_TEST: RefCell<bool> = const { RefCell::new(false) };
}
pub fn run_test_with_fixtures<F>(module_path: &'static str, test_fn: AssertUnwindSafe<F>)
where
F: FnOnce(),
{
IN_FIXTURE_TEST.with(|flag| {
*flag.borrow_mut() = true;
});
run_before_all_if_needed(module_path);
if let Ok(fixtures) = SETUP_FIXTURES.lock() {
if let Some(setup_funcs) = fixtures.get(module_path) {
for setup_fn in setup_funcs {
setup_fn();
}
}
}
let result = panic::catch_unwind(test_fn);
if let Ok(fixtures) = TEARDOWN_FIXTURES.lock() {
if let Some(teardown_funcs) = fixtures.get(module_path) {
for teardown_fn in teardown_funcs {
teardown_fn();
}
}
}
IN_FIXTURE_TEST.with(|flag| {
*flag.borrow_mut() = false;
});
register_after_all_handler(module_path);
if let Err(err) = result {
panic::resume_unwind(err);
}
}
fn run_before_all_if_needed(module_path: &'static str) {
let mut executed = EXECUTED_MODULES.lock().unwrap();
if !executed.contains(module_path) {
executed.insert(module_path);
if let Ok(fixtures) = BEFORE_ALL_FIXTURES.lock() {
if let Some(before_all_funcs) = fixtures.get(module_path) {
for before_fn in before_all_funcs {
before_fn();
}
}
}
}
}
fn register_after_all_handler(module_path: &'static str) {
let mut executed = EXECUTED_MODULES.lock().unwrap();
executed.insert(module_path);
}
#[doc(hidden)]
pub fn run_after_all_fixtures() {
let executed = EXECUTED_MODULES.lock().unwrap();
if let Ok(fixtures) = AFTER_ALL_FIXTURES.lock() {
for module_path in executed.iter() {
if let Some(after_all_funcs) = fixtures.get(module_path) {
for after_fn in after_all_funcs {
after_fn();
}
}
}
}
}
/// Reports whether the current thread's fixture-test flag is set.
///
/// `run_test_with_fixtures` sets the flag before running any fixture for a test
/// and clears it once teardown has finished. The flag is thread-local, so tests
/// running concurrently on other threads do not affect the result.
pub fn is_in_fixture_test() -> bool {
    IN_FIXTURE_TEST.with(|flag| *flag.borrow())
}