use std::cell::RefCell;
use std::collections::HashSet;
use std::thread_local;
thread_local! {
    /// Set of invariant messages that `assert_invariant!` has evaluated.
    ///
    /// Storage is per thread: messages recorded on one thread are not visible
    /// to `contract_test` or `get_logged_invariants` running on another.
    static INVARIANT_LOG: RefCell<HashSet<String>> = RefCell::default();
}
/// Checks a runtime invariant and records that it was evaluated.
///
/// The message is inserted into the current thread's invariant log whether or
/// not the condition holds, so `contract_test` can later confirm which
/// invariants were exercised. If the condition is `false`, this panics with
/// `INVARIANT VIOLATION [<context>]: <message>`, where `<context>` is
/// `"unknown"` when omitted.
///
/// Accepted forms (each may end with a trailing comma):
/// - `assert_invariant!(condition, message)`
/// - `assert_invariant!(condition, message, context)`
///
/// `condition` must be a `bool`; `message` and `context` must be `&str`.
///
/// NOTE(review): the expansion calls `$crate::invariant_ppt::...`, so this file
/// is assumed to be mounted as the `invariant_ppt` module at the crate root.
#[macro_export]
macro_rules! assert_invariant {
    ($condition:expr, $message:expr $(,)?) => {
        $crate::invariant_ppt::__assert_invariant_impl($condition, $message, None)
    };
    ($condition:expr, $message:expr, $context:expr $(,)?) => {
        $crate::invariant_ppt::__assert_invariant_impl($condition, $message, Some($context))
    };
}
/// Implementation detail of `assert_invariant!`; not part of the public API.
///
/// Records `message` in the thread-local invariant log, regardless of the
/// outcome, and then checks `condition`.
///
/// # Panics
///
/// Panics with `INVARIANT VIOLATION [<context>]: <message>` when `condition`
/// is `false`; `context` defaults to `"unknown"` when `None`.
#[doc(hidden)]
// Report the panic at the caller (the `assert_invariant!` invocation) rather
// than at the `panic!` line inside this helper.
#[track_caller]
pub fn __assert_invariant_impl(condition: bool, message: &str, context: Option<&str>) {
    INVARIANT_LOG.with(|log| {
        let mut log = log.borrow_mut();
        // Only allocate an owned copy the first time a message is seen, so
        // re-checking the same invariant does not allocate.
        if !log.contains(message) {
            log.insert(message.to_owned());
        }
    });
    if !condition {
        let ctx = context.unwrap_or("unknown");
        panic!("INVARIANT VIOLATION [{}]: {}", ctx, message);
    }
}
/// Asserts that every invariant in `required_invariants` has been checked on
/// the current thread since the log was last cleared.
///
/// An invariant counts as checked when `assert_invariant!` has been called
/// with exactly that message; the message is logged before the condition is
/// evaluated, so a failed check also counts.
///
/// # Panics
///
/// Panics with `CONTRACT FAILURE [<test_name>]: ...`, listing every required
/// invariant absent from the log, in the order they appear in
/// `required_invariants`.
pub fn contract_test(test_name: &str, required_invariants: &[&str]) {
    // Inspect the log in place instead of cloning the whole set. The borrow
    // ends when the closure returns, before any panic below.
    let missing: Vec<&str> = INVARIANT_LOG.with(|log| {
        let log = log.borrow();
        required_invariants
            .iter()
            .copied()
            .filter(|invariant| !log.contains(*invariant))
            .collect()
    });
    if !missing.is_empty() {
        panic!(
            "CONTRACT FAILURE [{}]: The following invariants were not checked:\n - {}",
            test_name,
            missing.join("\n - ")
        );
    }
}
/// Removes every recorded invariant message from the current thread's log.
///
/// Call this before exercising code under test so that messages recorded
/// earlier on the same thread cannot satisfy a later `contract_test`.
pub fn clear_invariant_log() {
    INVARIANT_LOG.with(|entries| {
        let mut entries = entries.borrow_mut();
        entries.clear();
    });
}
/// Returns an owned snapshot of every invariant message recorded on the
/// current thread since the log was last cleared.
///
/// The order of the returned messages is unspecified; it follows `HashSet`
/// iteration order.
pub fn get_logged_invariants() -> Vec<String> {
    INVARIANT_LOG.with(|entries| {
        let snapshot = entries.borrow();
        snapshot.iter().map(String::clone).collect()
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // The log is a thread-local `RefCell`, so there is no lock that could be
    // poisoned. Despite its name (kept for stability of test filters), this
    // test exercises a record -> contract -> snapshot round trip.
    #[test]
    fn test_poisoned_lock_paths_are_handled() {
        clear_invariant_log();
        assert_invariant!(true, "poisoned invariant");
        contract_test("poisoned", &["poisoned invariant"]);
        let logged = get_logged_invariants();
        assert!(logged.contains(&"poisoned invariant".to_string()));
    }

    #[test]
    fn test_invariant_passes() {
        clear_invariant_log();
        assert_invariant!(true, "test invariant passes");
        let logged = get_logged_invariants();
        assert!(logged.contains(&"test invariant passes".to_string()));
    }

    // Pin the full panic message, including the explicit context.
    #[test]
    #[should_panic(expected = "INVARIANT VIOLATION [test]: this should fail")]
    fn test_invariant_fails() {
        clear_invariant_log();
        assert_invariant!(false, "this should fail", "test");
    }

    // Without an explicit context the message falls back to "unknown".
    #[test]
    #[should_panic(expected = "INVARIANT VIOLATION [unknown]: no context given")]
    fn test_invariant_fails_without_context() {
        clear_invariant_log();
        assert_invariant!(false, "no context given");
    }

    // A failing invariant is recorded before the panic is raised.
    #[test]
    fn test_failed_invariant_is_still_logged() {
        clear_invariant_log();
        let outcome = std::panic::catch_unwind(|| {
            assert_invariant!(false, "logged before panic", "test");
        });
        assert!(outcome.is_err());
        assert!(get_logged_invariants().contains(&"logged before panic".to_string()));
    }

    #[test]
    fn test_contract_passes() {
        clear_invariant_log();
        assert_invariant!(true, "contract required invariant");
        contract_test("test contract", &["contract required invariant"]);
    }

    #[test]
    #[should_panic(
        expected = "CONTRACT FAILURE [test missing]: The following invariants were not checked:\n - this invariant was never checked"
    )]
    fn test_contract_fails_missing() {
        clear_invariant_log();
        contract_test("test missing", &["this invariant was never checked"]);
    }

    // Only invariants that were never checked appear in the failure message.
    #[test]
    fn test_contract_lists_only_missing() {
        clear_invariant_log();
        assert_invariant!(true, "alpha checked");
        let outcome = std::panic::catch_unwind(|| {
            contract_test("partial", &["alpha checked", "beta skipped"]);
        });
        let payload = outcome.expect_err("contract should fail");
        let message = payload
            .downcast_ref::<String>()
            .expect("panic payload should be a formatted String");
        assert!(message.contains(" - beta skipped"));
        assert!(!message.contains("alpha checked"));
    }
}