use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Boxed predicate over a raw byte slice; the storage type for checker properties.
type PropertyFn = Box<dyn Fn(&[u8]) -> bool>;
/// Boxed predicate relating an input `T` to an output `U`.
type PostConditionFn<T, U> = Box<dyn Fn(&T, &U) -> bool>;
/// Boxed transition: given the current state and an event name, yields the next state
/// or `None` when the transition does not apply.
type TransitionFn<S> = Box<dyn Fn(&S, &str) -> Option<S>>;
/// A named predicate that a value of type `T` is expected to satisfy.
///
/// `T` may be unsized (for example `[u8]`), so invariants can be stated over slices.
pub struct Invariant<T: ?Sized> {
    name: String,
    predicate: Box<dyn Fn(&T) -> bool>,
}

impl<T: ?Sized> Invariant<T> {
    /// Creates an invariant labelled `name` that holds whenever `predicate` returns `true`.
    pub fn new<F: Fn(&T) -> bool + 'static>(name: &str, predicate: F) -> Self {
        Self {
            predicate: Box::new(predicate),
            name: String::from(name),
        }
    }

    /// Evaluates the predicate against `state`; `true` means the invariant holds.
    pub fn check(&self, state: &T) -> bool {
        let holds = &*self.predicate;
        holds(state)
    }

    /// The label supplied when the invariant was created.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// A named predicate that an input of type `T` must satisfy before an operation runs.
///
/// `T` may be unsized, so a pre-condition can be written directly over `[u8]`.
pub struct PreCondition<T: ?Sized> {
    name: String,
    predicate: Box<dyn Fn(&T) -> bool>,
}

impl<T: ?Sized> PreCondition<T> {
    /// Wraps `predicate` as a pre-condition identified by `name`.
    pub fn new<F: Fn(&T) -> bool + 'static>(name: &str, predicate: F) -> Self {
        let owned_name = name.to_owned();
        Self {
            predicate: Box::new(predicate),
            name: owned_name,
        }
    }

    /// Returns `true` when `input` is acceptable according to the predicate.
    pub fn check(&self, input: &T) -> bool {
        let accepts = &*self.predicate;
        accepts(input)
    }

    /// The identifier supplied at construction.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// A named relation that must hold between an operation's input `T` and its output `U`.
///
/// Both `T` and `U` may be unsized, so slices such as `[u8]` can be used directly.
pub struct PostCondition<T: ?Sized, U: ?Sized> {
    name: String,
    predicate: PostConditionFn<T, U>,
}

impl<T: ?Sized, U: ?Sized> PostCondition<T, U> {
    /// Wraps `predicate` as a post-condition identified by `name`.
    pub fn new<F: Fn(&T, &U) -> bool + 'static>(name: &str, predicate: F) -> Self {
        Self {
            predicate: Box::new(predicate),
            name: name.to_owned(),
        }
    }

    /// Returns `true` when `output` is acceptable for `input` according to the predicate.
    pub fn check(&self, input: &T, output: &U) -> bool {
        let relation = &*self.predicate;
        relation(input, output)
    }

    /// The identifier supplied at construction.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Randomised property tester over fixed-length byte inputs.
///
/// Each registered property is a predicate over a byte slice. A check feeds it
/// `num_cases` random buffers of `input_len` bytes and counts how many inputs it
/// accepts versus rejects. The default is 100 cases of 32 bytes each.
pub struct PropertyChecker {
    properties: HashMap<String, PropertyFn>,
    num_cases: usize,
    input_len: usize,
}

impl Default for PropertyChecker {
    fn default() -> Self {
        Self::new()
    }
}

impl PropertyChecker {
    /// Number of random inputs generated per property unless overridden.
    const DEFAULT_NUM_CASES: usize = 100;
    /// Length in bytes of each random input unless overridden.
    const DEFAULT_INPUT_LEN: usize = 32;

    /// Creates an empty checker: 100 cases per property, 32-byte inputs.
    pub fn new() -> Self {
        Self {
            properties: HashMap::new(),
            num_cases: Self::DEFAULT_NUM_CASES,
            input_len: Self::DEFAULT_INPUT_LEN,
        }
    }

    /// Sets how many random inputs each property is evaluated against.
    #[must_use]
    pub fn with_num_cases(mut self, num: usize) -> Self {
        self.num_cases = num;
        self
    }

    /// Sets the length in bytes of every generated input (default 32).
    #[must_use]
    pub fn with_input_len(mut self, len: usize) -> Self {
        self.input_len = len;
        self
    }

    /// Registers `property` under `name`.
    ///
    /// Any property already registered with the same name is replaced.
    pub fn add_property<F>(&mut self, name: &str, property: F)
    where
        F: Fn(&[u8]) -> bool + 'static,
    {
        self.properties.insert(name.to_string(), Box::new(property));
    }

    /// Runs every registered property and returns a tally keyed by property name.
    pub fn check_all(&self) -> PropertyCheckResult {
        let results = self
            .properties
            .iter()
            .map(|(name, property)| (name.clone(), self.run_property(&**property)))
            .collect();
        PropertyCheckResult { results }
    }

    /// Runs the single property registered under `name`.
    ///
    /// Returns `None` if no property with that name exists.
    pub fn check_property(&self, name: &str) -> Option<PropertyResult> {
        self.properties
            .get(name)
            .map(|property| self.run_property(&**property))
    }

    /// Evaluates `property` against `num_cases` random inputs and counts accepts and rejects.
    fn run_property(&self, property: &dyn Fn(&[u8]) -> bool) -> PropertyResult {
        use rand::Rng as _;
        let mut rng = rand::rng();
        // One buffer is reused across cases. `fill_bytes` overwrites every byte each
        // round, so each case still sees a completely fresh random input.
        let mut data = vec![0u8; self.input_len];
        let passed = (0..self.num_cases)
            .filter(|_| {
                rng.fill_bytes(&mut data);
                property(&data)
            })
            .count();
        PropertyResult {
            passed,
            failed: self.num_cases - passed,
            total: self.num_cases,
        }
    }
}
#[derive(Debug, Clone)]
pub struct PropertyCheckResult {
pub results: HashMap<String, PropertyResult>,
}
impl PropertyCheckResult {
pub fn all_passed(&self) -> bool {
self.results.values().all(|r| r.failed == 0)
}
pub fn failed_properties(&self) -> Vec<String> {
self.results
.iter()
.filter(|(_, r)| r.failed > 0)
.map(|(name, _)| name.clone())
.collect()
}
}
/// Tally for a single property: how many inputs it accepted, how many it rejected,
/// and how many were generated in total.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropertyResult {
    pub passed: usize,
    pub failed: usize,
    pub total: usize,
}

impl PropertyResult {
    /// Ratio `passed / total` as a float.
    ///
    /// Returns `0.0` when `total` is zero, so an empty run never divides by zero.
    pub fn success_rate(&self) -> f64 {
        match self.total {
            0 => 0.0,
            total => self.passed as f64 / total as f64,
        }
    }

    /// Returns `true` when no input was rejected.
    pub fn all_passed(&self) -> bool {
        matches!(self.failed, 0)
    }
}
/// A minimal state machine whose transitions are guarded by invariants.
///
/// Transition functions are consulted in registration order. The first one that
/// returns `Some(next)` for the current state and event is applied. If the resulting
/// state violates any registered invariant, the machine rolls back to the previous state.
pub struct StateMachine<S> {
    current_state: S,
    transitions: Vec<TransitionFn<S>>,
    invariants: Vec<Invariant<S>>,
}

impl<S> StateMachine<S> {
    /// Creates a machine in `initial_state` with no transitions and no invariants.
    pub fn new(initial_state: S) -> Self {
        Self {
            current_state: initial_state,
            transitions: Vec::new(),
            invariants: Vec::new(),
        }
    }

    /// Appends a transition function.
    ///
    /// Earlier registrations take precedence over later ones.
    pub fn add_transition<F>(&mut self, transition: F)
    where
        F: Fn(&S, &str) -> Option<S> + 'static,
    {
        self.transitions.push(Box::new(transition));
    }

    /// Registers an invariant that is checked after every applied transition.
    ///
    /// The invariant is not evaluated against the current state at registration time.
    pub fn add_invariant(&mut self, invariant: Invariant<S>) {
        self.invariants.push(invariant);
    }

    /// Names of the invariants that the current state violates; empty when all hold.
    pub fn check_invariants(&self) -> Vec<String> {
        self.invariants
            .iter()
            .filter(|inv| !inv.check(&self.current_state))
            .map(|inv| inv.name().to_string())
            .collect()
    }

    /// Applies the first registered transition that accepts `event`.
    ///
    /// # Errors
    ///
    /// - If no transition returns `Some` for the current state and `event`, returns
    ///   `"No valid transition for event: <event>"`. The state is left unchanged.
    /// - If the new state violates any invariant, returns `"Invariant violations: [...]"`
    ///   listing the violated names. The previous state is restored.
    pub fn transition(&mut self, event: &str) -> Result<(), String> {
        let next_state = self
            .transitions
            .iter()
            .find_map(|transition| transition(&self.current_state, event))
            .ok_or_else(|| format!("No valid transition for event: {}", event))?;
        // Move the new state in and keep the old one by value, so a rollback does not
        // need a defensive clone of the state.
        let previous_state = std::mem::replace(&mut self.current_state, next_state);
        let violations = self.check_invariants();
        if violations.is_empty() {
            Ok(())
        } else {
            self.current_state = previous_state;
            Err(format!("Invariant violations: {:?}", violations))
        }
    }

    /// Borrows the current state.
    pub fn current_state(&self) -> &S {
        &self.current_state
    }
}
/// A named verification obligation: a human-readable description plus a formula.
///
/// The formula is stored as text only; this type does not parse or check it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationCondition {
    pub name: String,
    pub description: String,
    pub formula: String,
}

impl VerificationCondition {
    /// Builds a condition by copying each borrowed string into an owned field.
    pub fn new(name: &str, description: &str, formula: &str) -> Self {
        let [name, description, formula] = [name, description, formula].map(String::from);
        Self {
            name,
            description,
            formula,
        }
    }
}
/// Asserts that `state` satisfies `predicate`.
///
/// Marked `#[track_caller]` so the reported panic location is the call site that
/// supplied the violating state, not this helper.
///
/// # Panics
///
/// Panics with `Invariant '<name>' violated` when `predicate(state)` returns `false`.
#[track_caller]
pub fn check_invariant<T: ?Sized>(name: &str, state: &T, predicate: impl Fn(&T) -> bool) {
    if !predicate(state) {
        panic!("Invariant '{}' violated", name);
    }
}
/// Asserts that `input` satisfies `predicate` before an operation proceeds.
///
/// Marked `#[track_caller]` so the reported panic location is the call site that
/// passed the bad input, not this helper.
///
/// # Panics
///
/// Panics with `Pre-condition '<name>' violated` when `predicate(input)` returns `false`.
#[track_caller]
pub fn check_precondition<T: ?Sized>(name: &str, input: &T, predicate: impl Fn(&T) -> bool) {
    if !predicate(input) {
        panic!("Pre-condition '{}' violated", name);
    }
}
/// Asserts that `output` is acceptable for `input` according to `predicate`.
///
/// Marked `#[track_caller]` so the reported panic location is the call site that
/// produced the bad output, not this helper.
///
/// # Panics
///
/// Panics with `Post-condition '<name>' violated` when `predicate(input, output)`
/// returns `false`.
#[track_caller]
pub fn check_postcondition<T: ?Sized, U: ?Sized>(
    name: &str,
    input: &T,
    output: &U,
    predicate: impl Fn(&T, &U) -> bool,
) {
    if !predicate(input, output) {
        panic!("Post-condition '{}' violated", name);
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the contract helpers, the property checker and the state machine.
    use super::*;

    #[test]
    fn test_invariant_creation() {
        let inv = Invariant::new("test", |x: &i32| *x > 0);
        assert_eq!(inv.name(), "test");
        assert!(inv.check(&5));
        assert!(!inv.check(&-5));
    }

    #[test]
    fn test_invariant_key_length() {
        let inv = Invariant::new("key_length_32", |key: &[u8]| key.len() == 32);
        assert!(inv.check(&[0u8; 32]));
        assert!(!inv.check(&[0u8; 16]));
    }

    #[test]
    fn test_precondition() {
        let pre = PreCondition::new("non_empty", |data: &[u8]| !data.is_empty());
        assert!(pre.check(&[1, 2, 3]));
        assert!(!pre.check(&[]));
    }

    #[test]
    fn test_postcondition() {
        let post = PostCondition::new("output_not_empty", |_input: &[u8], output: &[u8]| {
            !output.is_empty()
        });
        assert!(post.check(&[1, 2], &[3, 4]));
        assert!(!post.check(&[1, 2], &[]));
    }

    #[test]
    fn test_property_checker() {
        let mut checker = PropertyChecker::new().with_num_cases(10);
        checker.add_property("always_true", |_| true);
        checker.add_property("always_false", |_| false);
        let results = checker.check_all();
        assert!(results.results["always_true"].all_passed());
        assert!(!results.results["always_false"].all_passed());
        assert!(!results.all_passed());
    }

    #[test]
    fn test_property_result_success_rate() {
        let result = PropertyResult {
            passed: 75,
            failed: 25,
            total: 100,
        };
        assert_eq!(result.success_rate(), 0.75);
    }

    #[test]
    fn test_property_result_zero_total() {
        let result = PropertyResult {
            passed: 0,
            failed: 0,
            total: 0,
        };
        assert_eq!(result.success_rate(), 0.0);
        assert!(result.all_passed());
    }

    #[test]
    fn test_property_checker_single_property() {
        let mut checker = PropertyChecker::new().with_num_cases(20);
        checker.add_property("test_prop", |_| true);
        let result = checker.check_property("test_prop").unwrap();
        assert_eq!(result.passed, 20);
        assert_eq!(result.failed, 0);
        assert!(result.all_passed());
    }

    #[test]
    fn test_property_checker_missing_property() {
        let checker = PropertyChecker::new();
        assert!(checker.check_property("does_not_exist").is_none());
    }

    #[test]
    fn test_property_checker_zero_cases() {
        let mut checker = PropertyChecker::new().with_num_cases(0);
        checker.add_property("never_run", |_| false);
        let result = checker.check_property("never_run").unwrap();
        assert_eq!((result.passed, result.failed, result.total), (0, 0, 0));
        assert!(result.all_passed());
    }

    #[test]
    fn test_property_checker_default_input_length() {
        let mut checker = PropertyChecker::new().with_num_cases(5);
        checker.add_property("len_32", |data| data.len() == 32);
        assert!(checker.check_all().all_passed());
    }

    #[test]
    fn test_state_machine_basic() {
        #[derive(Clone, PartialEq, Debug)]
        enum State {
            Init,
            Ready,
            Running,
        }
        let mut sm = StateMachine::new(State::Init);
        sm.add_transition(|state, event| match (state, event) {
            (State::Init, "start") => Some(State::Ready),
            (State::Ready, "run") => Some(State::Running),
            _ => None,
        });
        assert!(sm.transition("start").is_ok());
        assert_eq!(*sm.current_state(), State::Ready);
        assert!(sm.transition("run").is_ok());
        assert_eq!(*sm.current_state(), State::Running);
        assert!(sm.transition("start").is_err());
    }

    #[test]
    fn test_state_machine_with_invariant() {
        let mut sm = StateMachine::new(0i32);
        sm.add_invariant(Invariant::new("non_negative", |s: &i32| *s >= 0));
        sm.add_transition(|state, event| {
            if event == "increment" {
                Some(state + 1)
            } else {
                None
            }
        });
        assert!(sm.transition("increment").is_ok());
        assert_eq!(*sm.current_state(), 1);
        assert!(sm.check_invariants().is_empty());
    }

    #[test]
    fn test_state_machine_invariant_violation() {
        let mut sm = StateMachine::new(5i32);
        sm.add_invariant(Invariant::new("max_10", |s: &i32| *s <= 10));
        sm.add_transition(|state, event| {
            if event == "add_10" {
                Some(state + 10)
            } else {
                None
            }
        });
        assert!(sm.transition("add_10").is_err());
        assert_eq!(*sm.current_state(), 5);
    }

    #[test]
    fn test_state_machine_violation_message() {
        let mut sm = StateMachine::new(9i32);
        sm.add_invariant(Invariant::new("max_10", |s: &i32| *s <= 10));
        sm.add_transition(|state, event| (event == "add_10").then(|| state + 10));
        assert_eq!(
            sm.transition("add_10"),
            Err("Invariant violations: [\"max_10\"]".to_string())
        );
        assert_eq!(*sm.current_state(), 9);
    }

    #[test]
    fn test_state_machine_unknown_event_keeps_state() {
        let mut sm = StateMachine::new(3i32);
        sm.add_transition(|state, event| (event == "inc").then(|| state + 1));
        let err = sm.transition("dec").unwrap_err();
        assert_eq!(err, "No valid transition for event: dec");
        assert_eq!(*sm.current_state(), 3);
    }

    #[test]
    fn test_state_machine_first_matching_transition_wins() {
        let mut sm = StateMachine::new(0i32);
        sm.add_transition(|state, event| (event == "step").then(|| state + 1));
        sm.add_transition(|state, event| (event == "step").then(|| state + 100));
        assert!(sm.transition("step").is_ok());
        assert_eq!(*sm.current_state(), 1);
    }

    #[test]
    fn test_check_invariant_helper() {
        let state = vec![1, 2, 3];
        check_invariant("non_empty", &state, |s| !s.is_empty());
    }

    #[test]
    #[should_panic(expected = "Invariant 'empty' violated")]
    fn test_check_invariant_helper_panic() {
        let state = vec![1, 2, 3];
        check_invariant("empty", &state, |s| s.is_empty());
    }

    #[test]
    fn test_check_precondition_helper() {
        check_precondition("positive", &1i32, |x| *x > 0);
    }

    #[test]
    #[should_panic(expected = "Pre-condition 'positive' violated")]
    fn test_check_precondition_helper_panic() {
        check_precondition("positive", &-1i32, |x| *x > 0);
    }

    #[test]
    fn test_check_postcondition_helper() {
        check_postcondition("doubled", &2i32, &4i32, |i, o| *o == *i * 2);
    }

    #[test]
    #[should_panic(expected = "Post-condition 'doubled' violated")]
    fn test_check_postcondition_helper_panic() {
        check_postcondition("doubled", &2i32, &5i32, |i, o| *o == *i * 2);
    }

    #[test]
    fn test_verification_condition() {
        let vc = VerificationCondition::new(
            "encryption_correctness",
            "Decryption of encrypted data returns original",
            "forall m, k: decrypt(encrypt(m, k), k) = m",
        );
        assert_eq!(vc.name, "encryption_correctness");
        assert!(vc.formula.contains("decrypt"));
    }

    #[test]
    fn test_failed_properties() {
        let mut checker = PropertyChecker::new().with_num_cases(10);
        checker.add_property("pass1", |_| true);
        checker.add_property("fail1", |_| false);
        checker.add_property("pass2", |_| true);
        let results = checker.check_all();
        let failed = results.failed_properties();
        assert_eq!(failed.len(), 1);
        assert!(failed.contains(&"fail1".to_string()));
    }
}