use super::state_handle::StateHandle;
/// A named property of the simulation state that is re-checked as the simulation runs.
///
/// NOTE(review): the tests in this module report a violation by panicking (`assert!`)
/// inside [`Invariant::check`]; confirm that this is the contract the runner expects.
pub trait Invariant: 'static {
    /// Identifier for this invariant, e.g. for diagnostics.
    fn name(&self) -> &str;

    /// Inspect `state` at simulation time `sim_time_ms` (the `_ms` suffix suggests
    /// milliseconds).
    fn check(&self, state: &StateHandle, sim_time_ms: u64);

    /// Clear any internal bookkeeping the invariant keeps between checks.
    fn reset(&mut self) {
        // Stateless by default: nothing to clear.
    }
}
/// Boxed check closure stored inside an [`FnInvariant`].
type InvariantCheckFn = Box<dyn Fn(&StateHandle, u64) + 'static>;

/// Adapter that lets a plain closure act as an [`Invariant`].
struct FnInvariant {
    /// Value reported by [`Invariant::name`].
    name: String,
    /// Closure forwarded every [`Invariant::check`] call.
    check_fn: InvariantCheckFn,
}

impl Invariant for FnInvariant {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn check(&self, state: &StateHandle, sim_time_ms: u64) {
        // Delegate directly to the wrapped closure with the same arguments.
        let callback = &self.check_fn;
        callback(state, sim_time_ms);
    }
}
pub fn invariant_fn(
name: impl Into<String>,
f: impl Fn(&StateHandle, u64) + 'static,
) -> Box<dyn Invariant> {
Box::new(FnInvariant {
name: name.into(),
check_fn: Box::new(f),
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-written invariant: fails if the published `"value"` is negative.
    struct TestInvariant;

    impl Invariant for TestInvariant {
        fn name(&self) -> &str {
            "test"
        }

        fn check(&self, state: &StateHandle, _sim_time_ms: u64) {
            // A missing key is not a violation; only a present negative value is.
            if let Some(current) = state.get::<i64>("value") {
                assert!(current >= 0, "value went negative: {}", current);
            }
        }
    }

    /// A direct trait implementation passes on valid state and reports its name.
    #[test]
    fn test_trait_impl() {
        let invariant = TestInvariant;
        let handle = StateHandle::new();
        handle.publish("value", 42i64);

        invariant.check(&handle, 0);
        assert_eq!(invariant.name(), "test");
    }

    /// A closure-backed invariant passes on valid state and keeps the given name.
    #[test]
    fn test_invariant_fn() {
        let invariant = invariant_fn("check_positive", |handle, _now| {
            if let Some(current) = handle.get::<i64>("val") {
                assert!(current >= 0, "negative: {}", current);
            }
        });
        let handle = StateHandle::new();
        handle.publish("val", 10i64);

        invariant.check(&handle, 100);
        assert_eq!(invariant.name(), "check_positive");
    }

    /// A closure-backed invariant surfaces a violation as a panic.
    #[test]
    #[should_panic(expected = "negative")]
    fn test_invariant_violation_panics() {
        let invariant = invariant_fn("must_be_positive", |handle, _now| {
            let current = handle.get::<i64>("val").unwrap_or(0);
            assert!(current >= 0, "negative: {}", current);
        });
        let handle = StateHandle::new();
        handle.publish("val", -1i64);

        invariant.check(&handle, 0);
    }
}