use std::{
panic::AssertUnwindSafe,
sync::atomic::{AtomicU32, Ordering},
time::Duration,
};
use inc_complete::{
Db, DbHandle, Storage, define_intermediate, intermediate, storage::HashMapStorage,
};
/// Storage backing the test database (`Db<MyStorage>`): a single hash-map table.
#[derive(Default, Storage)]
struct MyStorage {
    // Table for `Foo`-keyed queries.
    // NOTE(review): presumably the `Storage` derive routes `foo` results here based on the
    // field's type rather than its name — confirm against the inc_complete docs.
    check: HashMapStorage<Foo>,
}
/// Query key for `foo`. The wrapped `bool` picks one of the two nodes in the cycle;
/// `foo` on `Foo(b)` depends on `Foo(!b)`.
///
/// The derived `Debug` form prints as `Foo(false)` / `Foo(true)`, which is the text `foo`
/// searches for in the cycle-panic message (presumably the library formats keys with
/// `Debug` — confirm).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct Foo(bool);
/// Incremented by `foo` each time it catches a cycle panic whose message names both
/// `Foo(false)` and `Foo(true)`. `cycle_between_two_threads` expects it to reach 2.
static ASSERTS_PASSED: AtomicU32 = AtomicU32::new(0);
// Query body for `Foo(b)`: it deliberately depends on `Foo(!b)`, forming a two-node cycle.
//
// The sleep gives the other test thread time to start computing the opposite key, so the
// cycle spans both threads. The nested `get` is expected to panic with a cycle error whose
// message names both keys; when it does, `ASSERTS_PASSED` is bumped. Either way this query
// then panics itself, so it never yields a value.
#[intermediate(id = 0)]
fn foo(ctx: &Foo, db: &DbHandle<MyStorage>) -> u32 {
    println!("Foo({}) on thread {:?}", ctx.0, std::thread::current().id());
    std::thread::sleep(Duration::from_millis(250));
    match std::panic::catch_unwind(AssertUnwindSafe(|| Foo(!ctx.0).get(db))) {
        Ok(_) => panic!("Ran cycle on Foo({}) without error!", ctx.0),
        Err(payload) => {
            // Panic payloads are `String` for formatted messages and `&'static str` for
            // literal ones; accept both so a literal message isn't reported as "no cycle".
            let message = payload
                .downcast_ref::<String>()
                .map(String::as_str)
                .or_else(|| payload.downcast_ref::<&str>().copied());
            if let Some(message) = message {
                // The failure text echoes the original message but adds no key names of
                // its own, so a propagated failure can't satisfy another thread's check.
                assert!(
                    message.contains("Foo(false)"),
                    "unexpected cycle message: {message}"
                );
                assert!(
                    message.contains("Foo(true)"),
                    "unexpected cycle message: {message}"
                );
                println!(
                    "Caught cycle panic on thread {:?}",
                    std::thread::current().id()
                );
                ASSERTS_PASSED.fetch_add(1, Ordering::SeqCst);
            } else {
                println!("No cycle panic on thread {:?}", std::thread::current().id());
            }
            // A cyclic query must not produce a value: keep it failed.
            panic!()
        }
    }
}
/// Runs the two halves of the `Foo(false)` <-> `Foo(true)` cycle on separate threads.
///
/// `foo` always panics, so both queries must fail. Each thread's `foo` must also have
/// caught a cycle panic naming both keys, which is what `ASSERTS_PASSED` counts.
#[test]
fn cycle_between_two_threads() {
    let db = Db::<MyStorage>::new();
    std::thread::scope(|scope| {
        let f = scope.spawn(|| db.get(Foo(false)));
        let t = scope.spawn(|| db.get(Foo(true)));
        // A successful join would mean a cyclic computation produced a value.
        assert!(
            t.join().is_err(),
            "query on the `true` key unexpectedly returned a value"
        );
        assert!(
            f.join().is_err(),
            "query on the `false` key unexpectedly returned a value"
        );
    });
    // The joins above happen-before this load, so both increments are visible.
    assert_eq!(
        ASSERTS_PASSED.load(Ordering::SeqCst),
        2,
        "expected each thread to catch a cycle panic naming both keys"
    );
}