#![cfg_attr(coverage_nightly, coverage(off))]
use anyhow::Result;
use async_trait::async_trait;
use std::fmt;
use tracing::error;
/// A state machine driving an agent.
///
/// It provides an initial state, computes transitions in response to events
/// (with mutable access to a shared context), validates transitions, and
/// exposes a list of invariants defined over its state and context types.
#[async_trait]
pub trait AgentStateMachine: Send + Sync {
    /// State type held by the machine.
    type State: AgentState;
    /// Event type that drives transitions.
    type Event: AgentEvent;
    /// Context threaded (mutably) through [`AgentStateMachine::transition`].
    type Context: AgentContext;

    /// Returns the state the machine starts from.
    fn initial_state(&self) -> Self::State;

    /// Computes the state that follows `state` when `event` occurs.
    ///
    /// `ctx` is borrowed mutably, so implementations may update it.
    ///
    /// # Errors
    /// Implementation-defined; presumably `Err` when `event` cannot be applied
    /// to `state` (NOTE(review): confirm per implementor).
    async fn transition(
        &self,
        state: &Self::State,
        event: &Self::Event,
        ctx: &mut Self::Context,
    ) -> Result<Self::State>;

    /// Checks the move from `from` to `to` triggered by `event`.
    ///
    /// # Errors
    /// Presumably `Err` when the transition is not permitted
    /// (NOTE(review): confirm per implementor).
    fn validate_transition(
        &self,
        from: &Self::State,
        to: &Self::State,
        event: &Self::Event,
    ) -> Result<()>;

    /// Returns the invariants defined over this machine's state and context types.
    fn invariants(&self) -> &[Box<dyn Invariant<Self::State, Self::Context>>];
}
/// Marker trait for the state type of an [`AgentStateMachine`]
/// (cloneable, debuggable, and shareable across threads).
pub trait AgentState: fmt::Debug + Clone + Sync + Send {}

/// Marker trait for the event type of an [`AgentStateMachine`]
/// (cloneable, debuggable, and shareable across threads).
pub trait AgentEvent: fmt::Debug + Clone + Sync + Send {}

/// Marker trait for the context type of an [`AgentStateMachine`]
/// (shareable across threads).
pub trait AgentContext: Sync + Send {}
/// A named predicate over a state `S` and a context `C`.
pub trait Invariant<S, C>: Send + Sync {
    /// Evaluates the invariant.
    ///
    /// # Errors
    /// Returns `Err` describing the violation when the invariant does not hold.
    fn check(&self, state: &S, ctx: &C) -> Result<()>;

    /// Human-readable identifier; `InvariantChecker` copies it into violation reports.
    fn name(&self) -> &str;
}
pub struct InvariantChecker<S, C> {
invariants: Vec<Box<dyn Invariant<S, C>>>,
violation_handler: ViolationHandler,
}
impl<S: AgentState, C: AgentContext> InvariantChecker<S, C> {
#[must_use]
pub fn new(invariants: Vec<Box<dyn Invariant<S, C>>>) -> Self {
Self {
invariants,
violation_handler: ViolationHandler::default(),
}
}
#[must_use]
#[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
pub fn with_handler(
invariants: Vec<Box<dyn Invariant<S, C>>>,
handler: ViolationHandler,
) -> Self {
Self {
invariants,
violation_handler: handler,
}
}
#[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
pub fn check(&self, state: &S, ctx: &C) -> Result<()> {
for invariant in &self.invariants {
if let Err(e) = invariant.check(state, ctx) {
let violation = InvariantViolation {
invariant_name: invariant.name().to_string(),
message: e.to_string(),
};
match self.violation_handler.handle(&violation) {
ViolationAction::Panic => panic!("{}", violation),
ViolationAction::Log => error!("{}", violation),
ViolationAction::Fallback(_) => {
error!("Fallback not implemented: {}", violation);
}
}
}
}
Ok(())
}
#[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
pub fn add_invariant(&mut self, invariant: Box<dyn Invariant<S, C>>) {
self.invariants.push(invariant);
}
#[must_use]
#[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
pub fn invariant_count(&self) -> usize {
self.invariants.len()
}
}
/// Maps an [`InvariantViolation`] to the [`ViolationAction`] to take.
///
/// The current implementation ignores the violation's contents and always
/// answers with its configured default action.
#[derive(Debug, Clone)]
pub struct ViolationHandler {
    /// Action returned for every violation.
    default_action: ViolationAction,
}

impl ViolationHandler {
    /// Builds a handler that answers every violation with `default_action`.
    #[must_use]
    pub fn new(default_action: ViolationAction) -> Self {
        Self { default_action }
    }

    /// Returns a copy of the configured default action; `_violation` is not inspected.
    #[must_use]
    #[provable_contracts_macros::contract("pmat-core.yaml", equation = "check_compliance")]
    pub fn handle(&self, _violation: &InvariantViolation) -> ViolationAction {
        ViolationAction::clone(&self.default_action)
    }
}

impl Default for ViolationHandler {
    /// A handler whose default action is `ViolationAction::Log`.
    fn default() -> Self {
        Self::new(ViolationAction::Log)
    }
}
/// The reaction [`InvariantChecker::check`] applies to an invariant violation.
#[derive(Debug, Clone)]
pub enum ViolationAction {
    /// Panic with the violation's `Display` text.
    Panic,
    /// Log the violation at `error` level and continue checking.
    Log,
    /// Carries a recovery callback; see `InvariantChecker::check` for how it is handled.
    // `fn()` is the same type as `fn() -> ()`; the explicit unit return is
    // redundant (clippy::unused_unit).
    Fallback(fn()),
}
/// A single failed invariant check.
#[derive(Debug)]
pub struct InvariantViolation {
    /// Name of the invariant that failed.
    pub invariant_name: String,
    /// Text describing why it failed.
    pub message: String,
}

impl fmt::Display for InvariantViolation {
    /// Renders as `Invariant '<invariant_name>' violated: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            invariant_name,
            message,
        } = self;
        write!(f, "Invariant '{}' violated: {}", invariant_name, message)
    }
}
/// Invariant that rejects a state whose `Debug` rendering is empty.
///
/// `field_name` appears only in the error message.
pub struct NonEmptyInvariant {
    /// Label used in the "`<field_name>` cannot be empty" error.
    field_name: String,
}

impl NonEmptyInvariant {
    /// Creates the invariant, labelling its error message with `field_name`.
    // `#[must_use]` for consistency with the file's other constructors.
    #[must_use]
    pub fn new(field_name: impl Into<String>) -> Self {
        Self {
            field_name: field_name.into(),
        }
    }
}
impl<S, C> Invariant<S, C> for NonEmptyInvariant
where
    S: fmt::Debug,
    C: Send + Sync,
{
    /// Fails with "`<field_name>` cannot be empty" when `format!("{state:?}")` is empty.
    ///
    /// NOTE(review): derived `Debug` output always contains at least the type
    /// name, so this only fires for hand-written `Debug` impls that print
    /// nothing. `field_name` does not select a field of `state`; confirm this
    /// is the intended check.
    fn check(&self, state: &S, _ctx: &C) -> Result<()> {
        let rendered = format!("{state:?}");
        anyhow::ensure!(!rendered.is_empty(), "{} cannot be empty", self.field_name);
        Ok(())
    }

    /// Always `"NonEmpty"`, independent of `field_name`.
    fn name(&self) -> &'static str {
        "NonEmpty"
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[derive(Debug, Clone)]
    struct TestState {
        value: i32,
    }
    impl AgentState for TestState {}

    struct TestContext;
    impl AgentContext for TestContext {}

    /// Fails for non-positive `value`s.
    struct PositiveValueInvariant;
    impl Invariant<TestState, TestContext> for PositiveValueInvariant {
        fn check(&self, state: &TestState, _ctx: &TestContext) -> Result<()> {
            if state.value <= 0 {
                anyhow::bail!("Value must be positive, got {}", state.value);
            }
            Ok(())
        }
        fn name(&self) -> &str {
            "PositiveValue"
        }
    }

    /// Always passes; counts how many times it was evaluated.
    struct CountingInvariant(Arc<AtomicUsize>);
    impl Invariant<TestState, TestContext> for CountingInvariant {
        fn check(&self, _state: &TestState, _ctx: &TestContext) -> Result<()> {
            self.0.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        fn name(&self) -> &str {
            "Counting"
        }
    }

    #[test]
    fn test_invariant_checker() {
        let checker: InvariantChecker<TestState, TestContext> =
            InvariantChecker::new(vec![Box::new(PositiveValueInvariant)]);
        let valid_state = TestState { value: 5 };
        let ctx = TestContext;
        assert!(checker.check(&valid_state, &ctx).is_ok());
        let invalid_state = TestState { value: -1 };
        // The default (Log) handler reports violations without returning Err.
        assert!(checker.check(&invalid_state, &ctx).is_ok());
    }

    #[test]
    fn test_violation_handler() {
        let handler = ViolationHandler::new(ViolationAction::Log);
        let violation = InvariantViolation {
            invariant_name: "Test".to_string(),
            message: "Test violation".to_string(),
        };
        assert!(matches!(handler.handle(&violation), ViolationAction::Log));
    }

    #[test]
    fn test_invariant_violation_display() {
        let violation = InvariantViolation {
            invariant_name: "TestInvariant".to_string(),
            message: "Something went wrong".to_string(),
        };
        assert_eq!(
            violation.to_string(),
            "Invariant 'TestInvariant' violated: Something went wrong"
        );
    }

    #[test]
    fn test_add_invariant() {
        let mut checker: InvariantChecker<TestState, TestContext> = InvariantChecker::new(vec![]);
        assert_eq!(checker.invariant_count(), 0);
        checker.add_invariant(Box::new(PositiveValueInvariant));
        assert_eq!(checker.invariant_count(), 1);
    }

    struct AlwaysFailInvariant;
    impl Invariant<TestState, TestContext> for AlwaysFailInvariant {
        fn check(&self, _state: &TestState, _ctx: &TestContext) -> Result<()> {
            anyhow::bail!("always fails");
        }
        fn name(&self) -> &str {
            "AlwaysFail"
        }
    }

    fn fallback_noop() {}

    #[test]
    fn test_violation_handler_default_is_log() {
        let h = ViolationHandler::default();
        let v = InvariantViolation {
            invariant_name: "x".into(),
            message: "m".into(),
        };
        assert!(matches!(h.handle(&v), ViolationAction::Log));
    }

    #[test]
    fn test_violation_handler_returns_panic_action() {
        let h = ViolationHandler::new(ViolationAction::Panic);
        let v = InvariantViolation {
            invariant_name: "x".into(),
            message: "m".into(),
        };
        assert!(matches!(h.handle(&v), ViolationAction::Panic));
    }

    #[test]
    fn test_violation_handler_returns_fallback_action() {
        let h = ViolationHandler::new(ViolationAction::Fallback(fallback_noop));
        let v = InvariantViolation {
            invariant_name: "x".into(),
            message: "m".into(),
        };
        assert!(matches!(h.handle(&v), ViolationAction::Fallback(_)));
    }

    #[test]
    fn test_invariant_checker_with_handler_uses_custom_handler() {
        let checker: InvariantChecker<TestState, TestContext> = InvariantChecker::with_handler(
            vec![Box::new(AlwaysFailInvariant)],
            ViolationHandler::new(ViolationAction::Fallback(fallback_noop)),
        );
        let ctx = TestContext;
        let state = TestState { value: 1 };
        assert!(checker.check(&state, &ctx).is_ok());
    }

    #[test]
    fn test_invariant_checker_panic_path_on_violation() {
        let checker: InvariantChecker<TestState, TestContext> = InvariantChecker::with_handler(
            vec![Box::new(AlwaysFailInvariant)],
            ViolationHandler::new(ViolationAction::Panic),
        );
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _ = checker.check(&TestState { value: 0 }, &TestContext);
        }));
        let err = result.expect_err("expected panic from Panic handler");
        let panic_msg = err
            .downcast_ref::<String>()
            .cloned()
            .or_else(|| err.downcast_ref::<&'static str>().map(|s| s.to_string()))
            .unwrap_or_default();
        assert!(
            panic_msg.contains("AlwaysFail") && panic_msg.contains("violated"),
            "panic msg was: {panic_msg}"
        );
    }

    #[test]
    fn test_invariant_checker_all_passing_returns_ok() {
        let checker: InvariantChecker<TestState, TestContext> = InvariantChecker::new(vec![
            Box::new(PositiveValueInvariant),
            Box::new(PositiveValueInvariant),
        ]);
        assert_eq!(checker.invariant_count(), 2);
        let ctx = TestContext;
        checker.check(&TestState { value: 7 }, &ctx).expect("ok");
    }

    #[test]
    fn test_invariant_checker_continues_past_first_failure() {
        let evaluations = Arc::new(AtomicUsize::new(0));
        let checker: InvariantChecker<TestState, TestContext> = InvariantChecker::new(vec![
            Box::new(AlwaysFailInvariant),
            Box::new(CountingInvariant(Arc::clone(&evaluations))),
        ]);
        let result = checker.check(&TestState { value: 7 }, &TestContext);
        assert!(result.is_ok());
        assert_eq!(
            evaluations.load(Ordering::SeqCst),
            1,
            "invariant after the failing one must still be evaluated"
        );
    }

    #[test]
    fn test_non_empty_invariant_name_is_non_empty() {
        let inv = NonEmptyInvariant::new("some_field");
        let name: &str = <NonEmptyInvariant as Invariant<TestState, TestContext>>::name(&inv);
        assert_eq!(name, "NonEmpty");
    }

    #[test]
    fn test_non_empty_invariant_check_passes_on_debuggable_state() {
        let inv = NonEmptyInvariant::new("field");
        let ctx = TestContext;
        let state = TestState { value: 3 };
        Invariant::<TestState, TestContext>::check(&inv, &state, &ctx).expect("Ok");
    }

    #[test]
    fn test_non_empty_invariant_preserves_field_name() {
        let from_str = NonEmptyInvariant::new("a");
        let from_string = NonEmptyInvariant::new(String::from("b"));
        assert_eq!(from_str.field_name, "a");
        assert_eq!(from_string.field_name, "b");
    }

    #[test]
    fn test_invariant_violation_fields_public() {
        let v = InvariantViolation {
            invariant_name: "Inv".into(),
            message: "msg".into(),
        };
        assert_eq!(v.invariant_name, "Inv");
        assert_eq!(v.message, "msg");
    }

    #[test]
    fn test_violation_action_clone_variants() {
        let a = ViolationAction::Panic;
        let b = ViolationAction::Log;
        let c = ViolationAction::Fallback(fallback_noop);
        assert!(matches!(a.clone(), ViolationAction::Panic));
        assert!(matches!(b.clone(), ViolationAction::Log));
        assert!(matches!(c.clone(), ViolationAction::Fallback(_)));
        assert_eq!(format!("{a:?}"), "Panic");
        assert_eq!(format!("{b:?}"), "Log");
        assert!(format!("{c:?}").starts_with("Fallback("));
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod property_tests {
    // Explicit imports (not `super::*`) so anyhow's `Result` alias cannot
    // shadow anything the proptest macros expand to.
    use super::{InvariantViolation, ViolationAction, ViolationHandler};
    use proptest::prelude::*;

    proptest! {
        // Display embeds the invariant name and message verbatim, for arbitrary strings.
        #[test]
        fn basic_property_stability(name in ".*", message in ".*") {
            let expected = format!("Invariant '{name}' violated: {message}");
            let violation = InvariantViolation { invariant_name: name, message };
            prop_assert_eq!(violation.to_string(), expected);
        }

        // A handler answers with its configured action regardless of the violation.
        #[test]
        fn module_consistency_check(x in 0u32..1000) {
            let violation = InvariantViolation {
                invariant_name: x.to_string(),
                message: String::new(),
            };
            for handler in [ViolationHandler::default(), ViolationHandler::new(ViolationAction::Log)] {
                prop_assert!(matches!(handler.handle(&violation), ViolationAction::Log));
            }
            let panicking = ViolationHandler::new(ViolationAction::Panic);
            prop_assert!(matches!(panicking.handle(&violation), ViolationAction::Panic));
        }
    }
}