use loom::thread;
use std::fmt::Debug;
use std::panic::{self, UnwindSafe};
use std::rc::Rc;
use crate::checker::*;
use crate::execution::*;
use crate::recorder::{self, *};
use crate::spec::*;
/// A concurrent test scenario split into three phases.
///
/// `execute_scenario_with_loom` runs `init_part` sequentially, then each inner `Vec` of
/// `parallel_part` on its own thread, then `post_part` sequentially after every parallel
/// thread has been joined.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Scenario<Op> {
/// Operations executed sequentially before any parallel thread is spawned.
pub init_part: Vec<Op>,
/// One operation list per thread; the lists run concurrently with each other.
pub parallel_part: Vec<Vec<Op>>,
/// Operations executed sequentially after all parallel threads have finished.
pub post_part: Vec<Op>,
}
/// Runs `scenario` under `loom::model` and checks every recorded execution for
/// linearizability against the sequential specification `Conc::Seq`.
///
/// A non-linearizable execution is carried out of the model as a panic payload, which is
/// then caught and returned to the caller.
///
/// # Errors
///
/// Returns the [`Execution`] that `LinearizabilityChecker` rejected.
///
/// # Panics
///
/// Re-raises any panic from the model that is not a linearizability counterexample
/// (e.g. an assertion inside the implementation under test). `&str`/`String` messages are
/// reported after the original panic hook has been restored; other payloads are resumed
/// unchanged with [`panic::resume_unwind`].
///
/// Note: the panic hook is process-wide. While the model runs it is replaced with a silent
/// hook, so panic messages from *other* threads in the process are suppressed as well.
pub fn check_scenario_with_loom<Conc>(
    scenario: Scenario<ConcOp<Conc>>,
) -> Result<(), Execution<ConcOp<Conc>, ConcRet<Conc>>>
where
    Conc: ConcurrentSpec + Send + Sync + 'static,
    Conc::Seq: Send + Sync + 'static,
    ConcOp<Conc>: Send + Sync + Clone + Debug + UnwindSafe + 'static,
    ConcRet<Conc>: PartialEq + Clone + Debug + Send,
{
    // Counterexamples are reported by unwinding; silence the hook so those expected panics
    // do not print while the model is being explored.
    let old_hook = panic::take_hook();
    panic::set_hook(Box::new(|_| {}));
    let result = panic::catch_unwind(|| {
        loom::model(move || {
            let execution = execute_scenario_with_loom::<Conc>(scenario.clone());
            if !LinearizabilityChecker::<Conc::Seq>::check(&execution) {
                panic::panic_any(execution);
            }
        });
    });
    // `catch_unwind` never lets a panic escape, so the original hook is always restored.
    panic::set_hook(old_hook);

    let payload = match result {
        Ok(()) => return Ok(()),
        Err(payload) => payload,
    };
    match payload.downcast::<Execution<ConcOp<Conc>, ConcRet<Conc>>>() {
        Ok(execution) => Err(*execution),
        Err(other) => {
            // A genuine panic, not a counterexample. Its message was swallowed by the silent
            // hook, so surface it now instead of replacing it with a generic message.
            if let Some(msg) = other.downcast_ref::<&'static str>() {
                panic!("loom::model panicked: {}", msg);
            }
            if let Some(msg) = other.downcast_ref::<String>() {
                panic!("loom::model panicked: {}", msg);
            }
            panic::resume_unwind(other)
        }
    }
}
pub fn execute_scenario_with_loom<Conc>(
scenario: Scenario<ConcOp<Conc>>,
) -> Execution<ConcOp<Conc>, ConcRet<Conc>>
where
Conc: ConcurrentSpec + Send + Sync + 'static,
ConcOp<Conc>: Send + Sync + Clone + 'static,
ConcRet<Conc>: PartialEq,
{
let conc = Rc::new(Conc::default());
let mut recorder = recorder::record_init_part_with_capacity(scenario.init_part.len());
for op in scenario.init_part {
recorder.record(op.clone(), || conc.exec(op));
}
let total_parallel_ops = scenario.parallel_part.iter().map(Vec::len).sum();
let recorder = Rc::new(recorder.record_parallel_part_with_capacity(total_parallel_ops));
let handles: Vec<_> = scenario
.parallel_part
.into_iter()
.map(|thread_ops| {
let conc = conc.clone();
let recorder = recorder.clone();
thread::spawn(move || {
let mut recorder = recorder.record_thread_with_capacity(thread_ops.len());
for op in thread_ops {
recorder.record(op.clone(), || conc.exec(op));
}
})
})
.collect();
for handle in handles {
handle.join().unwrap();
}
let mut recorder = recorder.record_post_part_with_capacity(scenario.post_part.len());
for op in scenario.post_part {
recorder.record(op.clone(), || conc.exec(op));
}
recorder.finish() }