use std::cell::Cell;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::rc::Rc;
use std::time::Duration;
use async_trait::async_trait;
use proptest::prelude::*;
use proptest::strategy::{NewTree, ValueTree};
use proptest::test_runner::TestRunner;
use rand::distributions::{Distribution, Uniform};
/// Settings for a stateful property test started with [`test`].
pub struct ProptestStatefulConfig {
/// Inclusive lower bound on the number of operations generated for a test case.
pub min_ops: usize,
/// Inclusive upper bound on the number of operations generated for a test case. [`test`] also
/// derives the shrink-iteration limit from this value.
pub max_ops: usize,
/// Time budget for executing a single test case; a case that exceeds it fails.
pub test_case_timeout: Duration,
/// Underlying proptest configuration. Note that [`test`] overwrites its `max_shrink_iters`.
pub proptest_config: ProptestConfig,
}
/// Model that drives a stateful property test.
///
/// The harness in this file builds operation sequences from a value of this type, replays them
/// against a fresh `Default` instance, and uses the same type to validate shrunken sequences.
#[async_trait(?Send)]
pub trait ModelState: Clone + Debug + Default {
/// A single step of a test case.
type Operation: Clone + Debug;
/// Value produced by `init_test_run` and handed mutably to `run_op`, `check_postconditions`
/// and `clean_up_test_run`.
type RunContext;
/// Proptest strategy type that yields operations.
type OperationStrategy: Strategy<Value = Self::Operation>;
/// Returns the strategies that may produce the next operation for the current model state.
/// The harness picks one of them at random per step, so at least one must be returned.
fn op_generators(&self) -> Vec<Self::OperationStrategy>;
/// Returns whether `op` is valid in the current model state. Used to reject generated
/// sequences and shrink candidates that contain an invalid step.
fn preconditions_met(&self, op: &Self::Operation) -> bool;
/// Advances the model as if `op` had been applied.
fn next_state(&mut self, op: &Self::Operation);
/// Called once at the start of each test case, before any operation is run.
async fn init_test_run(&self) -> Self::RunContext;
/// Executes `op`. The harness calls this before calling `next_state` for the same `op`.
async fn run_op(&self, op: &Self::Operation, ctxt: &mut Self::RunContext);
/// Called after `next_state` for every operation.
async fn check_postconditions(&self, ctxt: &mut Self::RunContext);
/// Called after the last operation of a test case has completed.
/// NOTE(review): appears to be skipped when an earlier step panics or the case times out.
async fn clean_up_test_run(&self, ctxt: &mut Self::RunContext);
}
/// Shrinkable sequence of operations produced by [`TestState`].
///
/// Shrinking toggles entries of `currently_included` instead of editing `ops`.
struct TestTree<T, O>
where
O: Clone + Debug,
T: ModelState<Operation = O>,
{
/// Every operation originally generated for this test case, in execution order.
ops: Vec<O>,
/// One flag per entry of `ops`; only operations whose flag is `true` are returned by
/// `current`.
currently_included: Vec<bool>,
/// Index of the operation most recently considered for removal. Starts one past the end of
/// `ops`; shrinking walks it down towards zero.
last_shrank_idx: usize,
/// Counter shared with the test body, which advances it after each step that completes. Read
/// on the first `simplify` call to find the step at which the failing run stopped.
next_step_idx: Rc<Cell<usize>>,
/// Set on the first `simplify` call so the trim of steps after the failure happens only once.
have_trimmed_after_failure: bool,
/// Ties the tree to its model type without storing a model value.
state_type: PhantomData<T>,
}
impl<T, O> ValueTree for TestTree<T, O>
where
    O: Clone + Debug,
    T: ModelState<Operation = O>,
{
    type Value = Vec<O>;

    /// Returns clones of the operations that are still flagged as included, in their original
    /// order.
    fn current(&self) -> Self::Value {
        self.ops
            .iter()
            .zip(self.currently_included.iter().copied())
            .filter(|&(_, keep)| keep)
            .map(|(op, _)| op.clone())
            .collect()
    }

    /// Proposes a smaller sequence; returns `false` when no further removal is possible.
    ///
    /// On the first call, every step after the one recorded in `next_step_idx` is dropped, and
    /// removal attempts then start just below that step.
    fn simplify(&mut self) -> bool {
        if !self.have_trimmed_after_failure {
            // One-time trim: the shared counter is only meaningful for the first failing run.
            self.have_trimmed_after_failure = true;

            let failed_idx = self.next_step_idx.get();
            self.currently_included
                .iter_mut()
                .skip(failed_idx + 1)
                .for_each(|flag| *flag = false);

            // Starting here means the step at `failed_idx` itself is never removed.
            self.last_shrank_idx = failed_idx;
        }

        self.try_removing_op()
    }

    /// Undoes the most recent removal and moves on to the next removal candidate.
    fn complicate(&mut self) -> bool {
        let restored_idx = self.last_shrank_idx;
        self.currently_included[restored_idx] = true;
        self.try_removing_op()
    }
}
impl<T, O> TestTree<T, O>
where
    O: Clone + Debug,
    T: ModelState<Operation = O>,
{
    /// Walks `last_shrank_idx` towards zero, excluding one operation at a time, and stops at
    /// the first exclusion that leaves a sequence whose preconditions all hold.
    ///
    /// Returns `true` if such an exclusion was found (it stays applied), or `false` once index
    /// zero has been tried and every attempted exclusion has been reverted.
    fn try_removing_op(&mut self) -> bool {
        while let Some(candidate_idx) = self.last_shrank_idx.checked_sub(1) {
            self.last_shrank_idx = candidate_idx;
            self.currently_included[candidate_idx] = false;

            if preconditions_hold::<T, O>(&self.current()) {
                return true;
            }

            // Removing this operation invalidates a later one; put it back and keep looking.
            self.currently_included[candidate_idx] = true;
        }

        false
    }
}
/// Replays `ops` against a fresh `T::default()` model and reports whether every operation's
/// preconditions are met at the point where it would run.
///
/// Evaluation stops at the first operation whose preconditions fail.
fn preconditions_hold<T, O>(ops: &[O]) -> bool
where
    T: ModelState<Operation = O>,
{
    let mut model = T::default();
    ops.iter().all(|op| {
        let allowed = model.preconditions_met(op);
        if allowed {
            model.next_state(op);
        }
        allowed
    })
}
/// Proptest strategy that generates whole operation sequences for a [`ModelState`].
#[derive(Clone, Debug)]
struct TestState<T, O>
where
T: ModelState<Operation = O>,
{
/// Starting model state; cloned for each generated sequence and advanced as operations are
/// generated.
model_state: T,
/// Step counter handed on to every [`TestTree`] this strategy creates.
next_step_idx: Rc<Cell<usize>>,
/// Inclusive lower bound on the number of generated operations.
min_ops: usize,
/// Inclusive upper bound on the number of generated operations.
max_ops: usize,
}
impl<T, O> Strategy for TestState<T, O>
where
    T: ModelState<Operation = O>,
    O: Clone + Debug,
{
    type Value = Vec<O>;
    type Tree = TestTree<T, O>;

    /// Generates a sequence of between `min_ops` and `max_ops` operations (inclusive).
    ///
    /// Each operation comes from a randomly chosen strategy returned by
    /// `ModelState::op_generators` for the model state reached so far, and the model is advanced
    /// with `next_state` after every generated operation.
    ///
    /// # Errors
    ///
    /// Returns an error instead of panicking when `min_ops > max_ops`, when `op_generators`
    /// returns no strategies, or when a chosen strategy fails to produce a value.
    fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
        // `Uniform::new_inclusive` panics on an inverted range; report it as an error instead.
        if self.min_ops > self.max_ops {
            return Err(format!(
                "min_ops ({}) must not exceed max_ops ({})",
                self.min_ops, self.max_ops
            )
            .into());
        }

        let mut symbolic_state = self.model_state.clone();
        let size = Uniform::new_inclusive(self.min_ops, self.max_ops).sample(runner.rng());

        let mut ops = Vec::with_capacity(size);
        for _ in 0..size {
            let mut possible_ops = symbolic_state.op_generators();
            // Sampling from an empty range would panic inside `gen_range`.
            if possible_ops.is_empty() {
                return Err("ModelState::op_generators returned no strategies".into());
            }
            let chosen = runner.rng().gen_range(0..possible_ops.len());
            let op_gen = possible_ops.swap_remove(chosen);
            // Propagate the strategy's own failure reason rather than unwrapping it.
            let next_op = op_gen.new_tree(runner)?.current();
            symbolic_state.next_state(&next_op);
            ops.push(next_op);
        }

        // Every operation starts out included; shrinking clears individual flags.
        let currently_included = vec![true; size];

        Ok(TestTree {
            ops,
            currently_included,
            // One past the end, so the first removal attempt targets the last operation.
            last_shrank_idx: size,
            next_step_idx: self.next_step_idx.clone(),
            have_trimmed_after_failure: false,
            state_type: PhantomData,
        })
    }
}
/// Executes `steps` in order against a fresh `T::default()` model.
///
/// `next_step_idx` is advanced by one after each step whose `run_op` and
/// `check_postconditions` calls have both returned.
async fn run<T, O>(steps: Vec<O>, next_step_idx: Rc<Cell<usize>>)
where
    T: ModelState<Operation = O>,
    O: Debug,
{
    let mut model = T::default();
    let mut run_ctxt = model.init_test_run().await;

    for (idx, op) in steps.iter().enumerate() {
        println!("Running op {idx}: {op:?}");

        // Order matters: execute the operation, advance the model, then check postconditions.
        model.run_op(op, &mut run_ctxt).await;
        model.next_state(op);
        model.check_postconditions(&mut run_ctxt).await;

        // Only reached when the step finished, so after a failure the counter still names the
        // step that did not complete.
        let completed = next_step_idx.get();
        next_step_idx.set(completed + 1);
    }

    model.clean_up_test_run(&mut run_ctxt).await;
}
/// Runs a stateful property test for the model `T`.
///
/// Operation sequences are generated from `T::default()`. Sequences whose preconditions do not
/// hold are rejected, and each remaining sequence is executed on a freshly built Tokio runtime
/// under `stateful_config.test_case_timeout`.
///
/// `stateful_config.proptest_config.max_shrink_iters` is overwritten with twice `max_ops`,
/// saturating at `u32::MAX`.
///
/// # Panics
///
/// Panics if proptest reports a failure; the message contains the proptest error followed by
/// the runner state. Inside a test case, failing to build the runtime or exceeding the timeout
/// panics as well, which proptest records as a failure of that case.
pub fn test<T>(mut stateful_config: ProptestStatefulConfig)
where
    T: ModelState,
{
    // Budget two shrink iterations per operation. Saturate and clamp before narrowing so that a
    // very large `max_ops` can neither overflow the multiplication nor be truncated by the cast.
    let shrink_budget = stateful_config
        .max_ops
        .saturating_mul(2)
        .min(u32::MAX as usize);
    stateful_config.proptest_config.max_shrink_iters = shrink_budget as u32;

    let test_case_timeout = stateful_config.test_case_timeout;

    // Shared between the test body (writer) and the generated trees (reader); the closure given
    // to the runner cannot hold a mutable borrow, hence the `Cell`.
    let next_step_idx = Rc::new(Cell::new(0));

    let test_state = TestState {
        model_state: T::default(),
        next_step_idx: next_step_idx.clone(),
        min_ops: stateful_config.min_ops,
        max_ops: stateful_config.max_ops,
    };

    let mut runner = TestRunner::new(stateful_config.proptest_config);
    let result = runner.run(&test_state, |steps| {
        prop_assume!(preconditions_hold::<T, _>(&steps));

        // Each case counts its completed steps from zero.
        next_step_idx.set(0);

        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("failed to build Tokio runtime for test case");

        // The timeout must be created inside the runtime context, hence the async block.
        rt.block_on(async {
            tokio::time::timeout(test_case_timeout, run::<T, _>(steps, next_step_idx.clone()))
                .await
        })
        .expect("test case exceeded test_case_timeout");

        Ok(())
    });

    if let Err(e) = result {
        panic!("{}\n{}", e, runner);
    }
}