use crate::invariant::Invariant;
use crate::machine::{EventKind, Machine, StateMachine};
/// Tuning knobs for the state-machine property test driver (`run`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestParams {
    /// Number of generated event sequences to try; forwarded to proptest's `Config::cases`.
    pub cases: u32,
    /// Inclusive upper bound on the number of events in a single generated sequence.
    pub max_steps: usize,
    /// Fixed RNG seed for a reproducible run. `None` builds the runner with
    /// `TestRunner::new`, which picks its own seed.
    pub seed: Option<u64>,
}

impl TestParams {
    /// Returns the default parameters: 500 cases, at most 100 steps per case, no fixed seed.
    ///
    /// This is a `const fn` so it can be used to initialise constants and statics.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            cases: 500,
            max_steps: 100,
            seed: None,
        }
    }
}

impl Default for TestParams {
    /// Same as [`TestParams::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Autoref-based dispatch helpers for collecting a machine's invariants.
///
/// `ViaImpl::collect` is implemented for `&Probe<S>` only when `S: Invariants`,
/// and returns `<S as Invariants>::invariants()`. `ViaNone::collect` is
/// implemented for every `Probe<S>` and returns an empty list.
///
/// The two impls differ by one level of reference, and both methods take `&self`.
/// So on a receiver of type `&&Probe<S>`, method resolution matches `ViaImpl`
/// first (when `S: Invariants` holds). Otherwise it auto-derefs and falls back
/// to `ViaNone`.
///
/// NOTE(review): the call site is not visible here. Confirm it uses a
/// double-reference receiver such as `(&&Probe::<S>(PhantomData)).collect()`
/// with both traits in scope. A single `&` (or none) always selects `ViaNone`.
pub mod probe {
    use crate::invariant::{Invariant, Invariants};
    use crate::machine::StateMachine;
    use core::marker::PhantomData;

    /// Zero-sized marker that carries the state-machine type `S` for dispatch.
    pub struct Probe<S>(pub PhantomData<S>);

    /// Fallback: yields no invariants.
    pub trait ViaNone<S: StateMachine> {
        fn collect(&self) -> Vec<Invariant<S, S::Event>> {
            Vec::new()
        }
    }

    // Implemented on `Probe<S>` itself (one reference shallower than `ViaImpl`),
    // so it is only reached after the `ViaImpl` candidate fails to apply.
    impl<S: StateMachine> ViaNone<S> for Probe<S> {}

    /// Preferred path: yields the invariants declared through `Invariants`.
    pub trait ViaImpl<S: StateMachine> {
        fn collect(&self) -> Vec<Invariant<S, S::Event>>;
    }

    // Implemented on `&Probe<S>` and gated on `S: Invariants`; the extra reference
    // is what gives this impl priority on a `&&Probe<S>` receiver.
    impl<S: StateMachine + Invariants> ViaImpl<S> for &Probe<S> {
        fn collect(&self) -> Vec<Invariant<S, S::Event>> {
            <S as Invariants>::invariants()
        }
    }
}
pub fn run<S: StateMachine>(params: TestParams, invariants: Vec<Invariant<S, S::Event>>) {
use proptest::prelude::*;
use proptest::test_runner::{Config, TestError, TestRng, TestRunner};
let events = S::Event::event_variants();
assert!(
!events.is_empty(),
"`{}` has no event variants to drive testing",
std::any::type_name::<S>()
);
let config = Config {
cases: params.cases,
..Config::default()
};
let mut runner = match params.seed {
Some(seed) => {
let mut bytes = [0u8; 32];
bytes[..8].copy_from_slice(&seed.to_le_bytes());
TestRunner::new_with_rng(
config,
TestRng::from_seed(proptest::test_runner::RngAlgorithm::ChaCha, &bytes),
)
}
None => TestRunner::new(config),
};
let strategy = proptest::collection::vec(0..events.len(), 0..=params.max_steps);
let result = runner.run(&strategy, |indices| {
let mut machine = Machine::<S>::new();
for idx in indices {
let event = events[idx].clone();
let before = machine.state().clone();
let after = machine.apply(event.clone()).ok();
for invariant in &invariants {
if !invariant.holds(&before, &event, &after) {
return Err(TestCaseError::fail(format!(
"invariant violated: {}\n before: {before:?}\n event: {event:?}\n after: {after:?}",
invariant.description(),
)));
}
}
}
Ok(())
});
if let Err(err) = result {
match err {
TestError::Fail(reason, case) => panic!(
"`{}` failed an invariant.\n{reason}\nminimal failing sequence: {case:?}",
std::any::type_name::<S>()
),
TestError::Abort(reason) => {
panic!("`{}` test aborted: {reason}", std::any::type_name::<S>())
}
}
}
}