use std::collections::HashSet;
use std::sync::{OnceLock, RwLock};
/// Lazily-initialised, process-wide registry of checked invariant messages.
///
/// The first call creates an empty set; every later call returns the same
/// `RwLock`, so all callers share one log for the lifetime of the process.
fn invariant_log() -> &'static RwLock<HashSet<String>> {
    static REGISTRY: OnceLock<RwLock<HashSet<String>>> = OnceLock::new();
    REGISTRY.get_or_init(RwLock::default)
}
/// Checks a runtime invariant and records that it was exercised.
///
/// Every invocation records `$message` in the invariant log, whether or not the
/// condition holds, so that `contract_test` can later confirm the invariant was
/// checked. It then panics with `INVARIANT VIOLATION [<context>]: <message>` if
/// `$condition` is false.
///
/// Forms:
/// - `assert_invariant!(condition, message)`: the context is reported as `"unknown"`.
/// - `assert_invariant!(condition, message, context)`
///
/// A trailing comma is accepted in both forms, as with `assert!`.
#[macro_export]
macro_rules! assert_invariant {
    ($condition:expr, $message:expr $(,)?) => {
        $crate::invariant_ppt::__assert_invariant_impl($condition, $message, None)
    };
    ($condition:expr, $message:expr, $context:expr $(,)?) => {
        $crate::invariant_ppt::__assert_invariant_impl($condition, $message, Some($context))
    };
}
/// Implementation behind `assert_invariant!`; not intended to be called directly.
///
/// Records `message` in the global invariant log so that `contract_test` can
/// later confirm it was exercised. It then panics if `condition` is false.
///
/// # Panics
///
/// Panics with `INVARIANT VIOLATION [<context>]: <message>` when `condition` is
/// false. `<context>` is `"unknown"` when no context was supplied. Because of
/// `#[track_caller]`, the reported location is the caller's, i.e. the macro call site.
#[doc(hidden)]
#[track_caller]
pub fn __assert_invariant_impl(condition: bool, message: &str, context: Option<&str>) {
    {
        // Recover from a poisoned lock instead of silently skipping the record.
        // Otherwise a single earlier panic would make every later check invisible
        // to `contract_test`. This matches how `contract_test` and
        // `get_logged_invariants` already read the log.
        let mut log = invariant_log()
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Only allocate an owned copy the first time a message is seen.
        if !log.contains(message) {
            log.insert(message.to_owned());
        }
    } // Write guard dropped here, before any panic, so a violation cannot poison the log.
    if !condition {
        let ctx = context.unwrap_or("unknown");
        panic!("INVARIANT VIOLATION [{}]: {}", ctx, message);
    }
}
/// Asserts that every name in `required_invariants` has been recorded by an
/// invariant check since the log was last cleared.
///
/// A poisoned log lock is tolerated: the recorded set is read regardless.
///
/// # Panics
///
/// Panics with a `CONTRACT FAILURE [<test_name>]` message that lists, one per
/// line and in the order given, each required invariant that was never checked.
pub fn contract_test(test_name: &str, required_invariants: &[&str]) {
    // Collect the unchecked names inside a scope so the read guard is released
    // before reporting.
    let unchecked: Vec<&str> = {
        let recorded = invariant_log()
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        required_invariants
            .iter()
            .copied()
            .filter(|name| !recorded.contains(*name))
            .collect()
    };

    if unchecked.is_empty() {
        return;
    }

    panic!(
        "CONTRACT FAILURE [{}]: The following invariants were not checked:\n - {}",
        test_name,
        unchecked.join("\n - ")
    );
}
/// Removes every recorded invariant from the global log.
///
/// The tests call this first so that `contract_test` only sees invariants
/// checked afterwards. A poisoned lock is recovered rather than ignored, so the
/// log can still be reset after a panic elsewhere. Previously a single poisoning
/// turned every later clear into a silent no-op.
pub fn clear_invariant_log() {
    invariant_log()
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .clear();
}
/// Returns a snapshot of every invariant message recorded so far.
///
/// The messages are sorted so the result is deterministic; the underlying
/// `HashSet` has no stable iteration order. A poisoned lock is tolerated and the
/// recorded set is read regardless.
pub fn get_logged_invariants() -> Vec<String> {
    let mut snapshot: Vec<String> = invariant_log()
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .iter()
        .cloned()
        .collect();
    snapshot.sort_unstable();
    snapshot
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests in this module.
    ///
    /// Every test reads or clears the single process-wide invariant log, and the
    /// default test harness runs tests on parallel threads. Without this guard,
    /// one test's `clear_invariant_log()` can wipe another test's entries between
    /// its `assert_invariant!` and its assertions, making the suite flaky.
    static TEST_SERIAL: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires the serialization guard. Poisoning is ignored because the
    /// `should_panic` tests deliberately panic while holding it.
    fn serial() -> std::sync::MutexGuard<'static, ()> {
        TEST_SERIAL
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_poisoned_lock_paths_are_handled() {
        let _serial = serial();
        clear_invariant_log();
        // Poison the log's RwLock by panicking while its write guard is held.
        let _ = std::panic::catch_unwind(|| {
            let mut log = invariant_log().write().unwrap();
            log.insert("poisoned invariant".to_string());
            panic!("poison the lock");
        });
        // The read paths must still see the entry written before the panic.
        contract_test("poisoned", &["poisoned invariant"]);
        let logged = get_logged_invariants();
        assert!(logged.contains(&"poisoned invariant".to_string()));
    }

    #[test]
    fn test_invariant_passes() {
        let _serial = serial();
        clear_invariant_log();
        assert_invariant!(true, "test invariant passes");
        let logged = get_logged_invariants();
        assert!(logged.contains(&"test invariant passes".to_string()));
    }

    #[test]
    #[should_panic(expected = "INVARIANT VIOLATION")]
    fn test_invariant_fails() {
        let _serial = serial();
        assert_invariant!(false, "this should fail", "test");
    }

    #[test]
    fn test_contract_passes() {
        let _serial = serial();
        clear_invariant_log();
        assert_invariant!(true, "contract required invariant");
        contract_test("test contract", &["contract required invariant"]);
    }

    #[test]
    #[should_panic(expected = "CONTRACT FAILURE")]
    fn test_contract_fails_missing() {
        let _serial = serial();
        clear_invariant_log();
        contract_test("test missing", &["this invariant was never checked"]);
    }
}