use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::{Cancelled, ParallelDatabase};
/// Joins `$thread` and asserts that it ended by panicking with a `Cancelled`
/// payload.
///
/// * If the thread returned normally, panics and reports the returned value.
/// * If the thread panicked with any other payload, that panic is re-raised
///   on the joining thread unchanged.
macro_rules! assert_cancelled {
    ($thread:expr) => {{
        // Classify the panic payload up front: `Err(Ok(_))` means the payload
        // was a `Cancelled`, `Err(Err(_))` means it was something else.
        let outcome = $thread
            .join()
            .map_err(|payload| payload.downcast::<Cancelled>());
        match outcome {
            Err(Ok(_cancelled)) => {}
            Err(Err(foreign)) => ::std::panic::resume_unwind(foreign),
            Ok(value) => panic!("expected cancellation, got {:?}", value),
        }
    }};
}
/// A `sum("abc")` running on a snapshot is expected to be cancelled once the
/// main handle writes an input; a second snapshot taken after the write must
/// still compute the full sum (111).
#[test]
fn in_par_get_set_cancellation_immediate() {
    let mut db = ParDatabaseImpl::default();
    for &(key, value) in &[('a', 100), ('b', 10), ('c', 1), ('d', 0)] {
        db.set_input(key, value);
    }

    // First worker: runs `sum` with the "signal on entry" and "wait for
    // cancellation" knobs set (their exact semantics live in `crate::setup`).
    let first_snapshot = db.snapshot();
    let cancelled_worker = std::thread::spawn(move || {
        first_snapshot.knobs().sum_signal_on_entry.with_value(1, || {
            first_snapshot
                .knobs()
                .sum_wait_for_cancellation
                .with_value(CancellationFlag::Panic, || first_snapshot.sum("abc"))
        })
    });

    // Presumably blocks until the worker has signalled stage 1 — see
    // `crate::setup` for the definition of `wait_for`.
    db.wait_for(1);

    // Writing an input while the first worker is still in flight.
    db.set_input('d', 1000);

    // Second worker: plain `sum` on a snapshot taken after the write.
    let second_snapshot = db.snapshot();
    let recompute_worker = std::thread::spawn(move || second_snapshot.sum("abc"));

    assert_eq!(db.sum("d"), 1000);
    assert_cancelled!(cancelled_worker);
    assert_eq!(recompute_worker.join().unwrap(), 111);
}
/// Same scenario as the "immediate" cancellation test, but driven through
/// `sum2` instead of `sum`: the in-flight `sum2("abc")` on a snapshot is
/// expected to be cancelled by a write, and a later snapshot must still
/// produce 111.
#[test]
fn in_par_get_set_cancellation_transitive() {
    let mut db = ParDatabaseImpl::default();
    for &(key, value) in &[('a', 100), ('b', 10), ('c', 1), ('d', 0)] {
        db.set_input(key, value);
    }

    // First worker: runs `sum2` with the `sum_*` knobs set (their exact
    // semantics live in `crate::setup`).
    let first_snapshot = db.snapshot();
    let cancelled_worker = std::thread::spawn(move || {
        first_snapshot.knobs().sum_signal_on_entry.with_value(1, || {
            first_snapshot
                .knobs()
                .sum_wait_for_cancellation
                .with_value(CancellationFlag::Panic, || first_snapshot.sum2("abc"))
        })
    });

    // Presumably blocks until the worker has signalled stage 1 — see
    // `crate::setup` for the definition of `wait_for`.
    db.wait_for(1);

    // Writing an input while the first worker is still in flight.
    db.set_input('d', 1000);

    // Second worker: plain `sum2` on a snapshot taken after the write.
    let second_snapshot = db.snapshot();
    let recompute_worker = std::thread::spawn(move || second_snapshot.sum2("abc"));

    assert_eq!(db.sum2("d"), 1000);
    assert_cancelled!(cancelled_worker);
    assert_eq!(recompute_worker.join().unwrap(), 111);
}
/// After a `sum3("a")` on a snapshot has been cancelled by a write to an
/// unrelated input, the main handle must still compute correct `sum3` values,
/// both right after the write and after further changes to `'a'`.
#[test]
fn no_back_dating_in_cancellation() {
    let mut db = ParDatabaseImpl::default();
    db.set_input('a', 1);

    // Worker expected to be cancelled: runs `sum3` with the `sum_*` knobs set
    // (their exact semantics live in `crate::setup`).
    let snapshot = db.snapshot();
    let cancelled_worker = std::thread::spawn(move || {
        snapshot.knobs().sum_signal_on_entry.with_value(1, || {
            snapshot
                .knobs()
                .sum_wait_for_cancellation
                .with_value(CancellationFlag::Panic, || snapshot.sum3("a"))
        })
    });

    // Presumably blocks until the worker has signalled stage 1 — see
    // `crate::setup` for the definition of `wait_for`.
    db.wait_for(1);

    // Write to an input that `sum3("a")` is not asked about.
    db.set_input('b', 2);

    // The value for "a" alone must still be the one set above.
    let after_unrelated_write = db.sum3("a");
    assert_eq!(after_unrelated_write, 1);

    assert_cancelled!(cancelled_worker);

    // Two consecutive writes to 'a'; the last one (4) plus 'b' (2) gives 6.
    db.set_input('a', 3);
    db.set_input('a', 4);
    let combined = db.sum3("ab");
    assert_eq!(combined, 6);
}