#![cfg_attr(coverage_nightly, coverage(off))]
use anyhow::Result;
use async_trait::async_trait;
use std::fmt;
use tracing::error;
/// Async interface for an agent modelled as a state machine.
///
/// Implementors pick the concrete state, event and context types and supply
/// the transition logic. No implementation is visible in this file, so the
/// method notes below describe the signatures; intended semantics that cannot
/// be read off the signature are marked as assumptions.
#[async_trait]
pub trait AgentStateMachine: Send + Sync {
/// State type of the machine; must implement `AgentState`.
type State: AgentState;
/// Event type fed to the machine; must implement `AgentEvent`.
type Event: AgentEvent;
/// Context type passed mutably to `transition`; must implement `AgentContext`.
type Context: AgentContext;
/// Returns a `Self::State` value; presumably the state a fresh machine
/// starts in (TODO confirm against implementors).
fn initial_state(&self) -> Self::State;
/// Asynchronously produces a state from the current `state` and an `event`.
///
/// `ctx` is borrowed mutably, so an implementation may update it.
///
/// # Errors
///
/// Failure is reported through the returned `Result`; the conditions are
/// implementation-defined.
async fn transition(
&self,
state: &Self::State,
event: &Self::Event,
ctx: &mut Self::Context,
) -> Result<Self::State>;
/// Checks a move from state `from` to state `to` in the presence of `event`.
///
/// # Errors
///
/// Presumably returns an error when the move is not permitted; the exact
/// rules are implementation-defined (TODO confirm against implementors).
fn validate_transition(
&self,
from: &Self::State,
to: &Self::State,
event: &Self::Event,
) -> Result<()>;
/// Returns a slice of boxed invariants over `Self::State` and `Self::Context`.
fn invariants(&self) -> &[Box<dyn Invariant<Self::State, Self::Context>>];
}
/// Marker trait for state types: it declares no methods and only requires
/// `Clone + Send + Sync + Debug`.
pub trait AgentState: Clone + Send + Sync + fmt::Debug {}
/// Marker trait for event types: it declares no methods and only requires
/// `Clone + Send + Sync + Debug`.
pub trait AgentEvent: Clone + Send + Sync + fmt::Debug {}
/// Marker trait for context types: it declares no methods and only requires
/// `Send + Sync`.
pub trait AgentContext: Send + Sync {}
/// A named check over a state of type `S` and a context of type `C`.
pub trait Invariant<S, C>: Send + Sync {
/// Evaluates the invariant against `state` and `ctx`.
///
/// # Errors
///
/// An `Err` signals a violation; `InvariantChecker::check` uses the error's
/// string form as the violation message.
fn check(&self, state: &S, ctx: &C) -> Result<()>;
/// Returns the name recorded in an `InvariantViolation` when this invariant fails.
fn name(&self) -> &str;
}
/// Bundles a list of invariants over state `S` and context `C` with the
/// `ViolationHandler` consulted whenever one of them fails.
pub struct InvariantChecker<S, C> {
/// Invariants evaluated, in order, by `check`.
invariants: Vec<Box<dyn Invariant<S, C>>>,
/// Chooses the `ViolationAction` for each detected violation.
violation_handler: ViolationHandler,
}
impl<S: AgentState, C: AgentContext> InvariantChecker<S, C> {
#[must_use]
pub fn new(invariants: Vec<Box<dyn Invariant<S, C>>>) -> Self {
Self {
invariants,
violation_handler: ViolationHandler::default(),
}
}
#[must_use]
pub fn with_handler(
invariants: Vec<Box<dyn Invariant<S, C>>>,
handler: ViolationHandler,
) -> Self {
Self {
invariants,
violation_handler: handler,
}
}
pub fn check(&self, state: &S, ctx: &C) -> Result<()> {
for invariant in &self.invariants {
if let Err(e) = invariant.check(state, ctx) {
let violation = InvariantViolation {
invariant_name: invariant.name().to_string(),
message: e.to_string(),
};
match self.violation_handler.handle(&violation) {
ViolationAction::Panic => panic!("{}", violation),
ViolationAction::Log => error!("{}", violation),
ViolationAction::Fallback(_) => {
error!("Fallback not implemented: {}", violation);
}
}
}
}
Ok(())
}
pub fn add_invariant(&mut self, invariant: Box<dyn Invariant<S, C>>) {
self.invariants.push(invariant);
}
#[must_use]
pub fn invariant_count(&self) -> usize {
self.invariants.len()
}
}
/// Decides which `ViolationAction` to take for an invariant violation.
///
/// It stores a single action, which `handle` returns for every violation.
#[derive(Debug, Clone)]
pub struct ViolationHandler {
/// Action returned by `handle`.
default_action: ViolationAction,
}
impl ViolationHandler {
#[must_use]
pub fn new(default_action: ViolationAction) -> Self {
Self { default_action }
}
#[must_use]
pub fn handle(&self, _violation: &InvariantViolation) -> ViolationAction {
self.default_action.clone()
}
}
impl Default for ViolationHandler {
    /// Returns a handler whose action is `ViolationAction::Log`.
    fn default() -> Self {
        Self::new(ViolationAction::Log)
    }
}
/// What to do once an invariant violation has been detected.
#[derive(Debug, Clone)]
pub enum ViolationAction {
    /// Abort by panicking.
    Panic,
    /// Report the violation through logging.
    Log,
    /// Carries a plain function pointer to a fallback routine; how it is used
    /// is up to the code consuming the action.
    Fallback(fn()),
}
/// Record of a single failed invariant.
#[derive(Debug)]
pub struct InvariantViolation {
/// Name of the invariant that failed.
pub invariant_name: String,
/// Description of the failure.
pub message: String,
}
impl fmt::Display for InvariantViolation {
    /// Renders the violation as `Invariant '<name>' violated: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            invariant_name,
            message,
        } = self;
        f.write_fmt(format_args!(
            "Invariant '{}' violated: {}",
            invariant_name, message
        ))
    }
}
/// Invariant that rejects a state whose `Debug` rendering is the empty string.
pub struct NonEmptyInvariant {
/// Label inserted into the error message when the check fails.
field_name: String,
}
impl NonEmptyInvariant {
    /// Creates an invariant labelled with `field_name`.
    ///
    /// Anything convertible into a `String` is accepted; the value is stored
    /// as an owned string.
    #[must_use]
    pub fn new(field_name: impl Into<String>) -> Self {
        Self {
            field_name: field_name.into(),
        }
    }
}
impl<S, C> Invariant<S, C> for NonEmptyInvariant
where
    S: fmt::Debug,
    C: Send + Sync,
{
    /// Fails when the `Debug` rendering of `state` is the empty string.
    ///
    /// The context is ignored.
    ///
    /// NOTE(review): `Debug` output is rarely, if ever, empty (derived impls
    /// print at least the type name), so this check probably never fires.
    /// Confirm whether a different notion of "empty" was intended.
    ///
    /// # Errors
    ///
    /// Returns `"<field_name> cannot be empty"` when the rendering is empty.
    fn check(&self, state: &S, _ctx: &C) -> Result<()> {
        let rendered = format!("{state:?}");
        if !rendered.is_empty() {
            return Ok(());
        }
        Err(anyhow::anyhow!("{} cannot be empty", self.field_name))
    }

    /// Returns the fixed name `"NonEmpty"`, independent of the field label.
    fn name(&self) -> &'static str {
        "NonEmpty"
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state carrying a single integer.
    #[derive(Debug, Clone)]
    struct TestState {
        value: i32,
    }
    impl AgentState for TestState {}

    /// Empty context; the test invariant ignores it.
    struct TestContext;
    impl AgentContext for TestContext {}

    /// Fails for any state whose `value` is zero or negative.
    struct PositiveValueInvariant;
    impl Invariant<TestState, TestContext> for PositiveValueInvariant {
        fn check(&self, state: &TestState, _ctx: &TestContext) -> Result<()> {
            if state.value <= 0 {
                anyhow::bail!("Value must be positive, got {}", state.value);
            }
            Ok(())
        }
        fn name(&self) -> &str {
            "PositiveValue"
        }
    }

    #[test]
    fn test_invariant_checker() {
        let checker: InvariantChecker<TestState, TestContext> =
            InvariantChecker::new(vec![Box::new(PositiveValueInvariant)]);
        let ctx = TestContext;

        let valid_state = TestState { value: 5 };
        assert!(checker.check(&valid_state, &ctx).is_ok());

        // The default handler only logs, so a violation is reported through
        // the log and the call itself still completes with `Ok`.
        let invalid_state = TestState { value: -1 };
        assert!(checker.check(&invalid_state, &ctx).is_ok());
    }

    #[test]
    #[should_panic(expected = "Invariant 'PositiveValue' violated")]
    fn test_panic_handler_panics_on_violation() {
        let checker: InvariantChecker<TestState, TestContext> = InvariantChecker::with_handler(
            vec![Box::new(PositiveValueInvariant)],
            ViolationHandler::new(ViolationAction::Panic),
        );
        let _ = checker.check(&TestState { value: 0 }, &TestContext);
    }

    #[test]
    fn test_violation_handler() {
        let handler = ViolationHandler::new(ViolationAction::Log);
        let violation = InvariantViolation {
            invariant_name: "Test".to_string(),
            message: "Test violation".to_string(),
        };
        assert!(matches!(handler.handle(&violation), ViolationAction::Log));
    }

    #[test]
    fn test_invariant_violation_display() {
        let violation = InvariantViolation {
            invariant_name: "TestInvariant".to_string(),
            message: "Something went wrong".to_string(),
        };
        assert_eq!(
            violation.to_string(),
            "Invariant 'TestInvariant' violated: Something went wrong"
        );
    }

    #[test]
    fn test_add_invariant() {
        let mut checker: InvariantChecker<TestState, TestContext> = InvariantChecker::new(vec![]);
        assert_eq!(checker.invariant_count(), 0);
        checker.add_invariant(Box::new(PositiveValueInvariant));
        assert_eq!(checker.invariant_count(), 1);
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod property_tests {
    use proptest::prelude::*;

    proptest! {
        // Smoke test: accepts any generated string and asserts `true`, so it
        // only exercises the proptest harness.
        #[test]
        fn basic_property_stability(_input in ".*") {
            prop_assert!(true);
        }

        // Values are drawn from `0..1000`, so the bound below always holds.
        #[test]
        fn module_consistency_check(sample in 0u32..1000) {
            prop_assert!(sample < 1001);
        }
    }
}