use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::context::Context;
/// Category that an [`Invariant`] reports through [`Invariant::class`].
///
/// [`InvariantRegistry`] files every registered invariant under its class so
/// that one class can be checked at a time with
/// [`InvariantRegistry::check_class`].
///
/// NOTE(review): the variant names suggest distinct phases of validation, but
/// this module does not define when each class is run — confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum InvariantClass {
/// Class checked by [`InvariantRegistry::check_structural`].
Structural,
/// Class checked by [`InvariantRegistry::check_semantic`].
Semantic,
/// Class checked by [`InvariantRegistry::check_acceptance`].
Acceptance,
}
/// Outcome of evaluating a single [`Invariant`] via [`Invariant::check`].
#[derive(Debug, Clone, PartialEq)]
pub enum InvariantResult {
/// The invariant reported no problem.
Ok,
/// The invariant reported a problem, described by the carried [`Violation`].
Violated(Violation),
}
impl InvariantResult {
    /// Returns `true` when this result is [`InvariantResult::Ok`].
    #[must_use]
    pub fn is_ok(&self) -> bool {
        *self == Self::Ok
    }

    /// Returns `true` when this result is [`InvariantResult::Violated`].
    #[must_use]
    pub fn is_violated(&self) -> bool {
        // The enum has exactly two variants, so "violated" is the complement
        // of "ok". Revisit this if a third variant is ever introduced.
        !self.is_ok()
    }
}
/// Details of a failed invariant check, carried by
/// [`InvariantResult::Violated`] and by [`InvariantError`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Violation {
/// Text explaining the failure; this is the part shown by
/// `InvariantError`'s `Display` output.
pub reason: String,
/// Identifiers attached to the violation; empty when built with
/// [`Violation::new`].
/// NOTE(review): presumably the ids of the offending facts — confirm
/// against the context module.
pub fact_ids: Vec<String>,
}
impl Violation {
    /// Builds a violation that carries only `reason` and an empty `fact_ids`
    /// list.
    #[must_use]
    pub fn new(reason: impl Into<String>) -> Self {
        Self::with_facts(reason, Vec::new())
    }

    /// Builds a violation from `reason` together with the given `fact_ids`,
    /// which are stored as passed.
    #[must_use]
    pub fn with_facts(reason: impl Into<String>, fact_ids: Vec<String>) -> Self {
        let reason = reason.into();
        Self { reason, fact_ids }
    }
}
/// A named, classified check that is evaluated against a context view.
///
/// Implementors must be `Send + Sync`. [`InvariantRegistry`] stores them as
/// boxed trait objects and runs them one class at a time.
pub trait Invariant: Send + Sync {
/// Name identifying this invariant. It is copied into
/// [`InvariantError::invariant_name`] when the check reports a violation.
fn name(&self) -> &str;
/// Class this invariant belongs to. The registry reads it once, at
/// registration, to decide which class list the invariant is filed under.
fn class(&self) -> InvariantClass;
/// Evaluates the invariant against `ctx` and returns
/// [`InvariantResult::Ok`] or [`InvariantResult::Violated`].
fn check(&self, ctx: &dyn crate::ContextView) -> InvariantResult;
}
/// Identifier returned by [`InvariantRegistry::register`].
///
/// The wrapped value is a counter that starts at 0 and increases by one per
/// registration; the registry also uses it as the index into its list of
/// stored invariants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct InvariantId(pub(crate) u32);
impl std::fmt::Display for InvariantId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Invariant({})", self.0)
}
}
/// Owns registered [`Invariant`]s and indexes them by [`InvariantClass`].
///
/// The derived `Default` yields an empty registry whose first id wraps 0.
#[derive(Default)]
pub struct InvariantRegistry {
/// Registered invariants in registration order; an `InvariantId`'s inner
/// value is used as the index into this list.
invariants: Vec<Box<dyn Invariant>>,
/// Ids of the invariants filed under each class, in registration order.
/// A class with no registered invariants has no entry.
by_class: HashMap<InvariantClass, Vec<InvariantId>>,
/// Value the next registered invariant's id will wrap; incremented by one
/// on every registration.
next_id: u32,
}
impl InvariantRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `invariant`, files it under the class it reports, and returns
    /// the id assigned to it.
    ///
    /// Ids are handed out sequentially, starting from 0 for a fresh registry.
    pub fn register(&mut self, invariant: impl Invariant + 'static) -> InvariantId {
        let boxed: Box<dyn Invariant> = Box::new(invariant);

        let assigned = InvariantId(self.next_id);
        self.next_id += 1;

        self.by_class
            .entry(boxed.class())
            .or_default()
            .push(assigned);
        self.invariants.push(boxed);

        assigned
    }

    /// Number of invariants registered so far, across all classes.
    #[must_use]
    pub fn count(&self) -> usize {
        self.invariants.len()
    }

    /// Runs every invariant filed under `class` against `ctx`, in
    /// registration order, stopping at the first violation.
    ///
    /// # Errors
    ///
    /// Returns an [`InvariantError`] describing the first invariant of that
    /// class whose check reports a violation. Invariants registered after it
    /// are not evaluated. A class with no registered invariants yields
    /// `Ok(())`.
    pub fn check_class(&self, class: InvariantClass, ctx: &Context) -> Result<(), InvariantError> {
        let first_failure = self
            .by_class
            .get(&class)
            .into_iter()
            .flatten()
            .find_map(|id| {
                let invariant = &self.invariants[id.0 as usize];
                match invariant.check(ctx) {
                    InvariantResult::Violated(violation) => Some((invariant, violation)),
                    InvariantResult::Ok => None,
                }
            });

        match first_failure {
            Some((invariant, violation)) => Err(InvariantError {
                invariant_name: invariant.name().to_string(),
                class,
                violation,
            }),
            None => Ok(()),
        }
    }

    /// Shorthand for [`Self::check_class`] with [`InvariantClass::Structural`].
    ///
    /// # Errors
    ///
    /// Returns the first violation found among structural invariants.
    pub fn check_structural(&self, ctx: &Context) -> Result<(), InvariantError> {
        self.check_class(InvariantClass::Structural, ctx)
    }

    /// Shorthand for [`Self::check_class`] with [`InvariantClass::Semantic`].
    ///
    /// # Errors
    ///
    /// Returns the first violation found among semantic invariants.
    pub fn check_semantic(&self, ctx: &Context) -> Result<(), InvariantError> {
        self.check_class(InvariantClass::Semantic, ctx)
    }

    /// Shorthand for [`Self::check_class`] with [`InvariantClass::Acceptance`].
    ///
    /// # Errors
    ///
    /// Returns the first violation found among acceptance invariants.
    pub fn check_acceptance(&self, ctx: &Context) -> Result<(), InvariantError> {
        self.check_class(InvariantClass::Acceptance, ctx)
    }
}
/// Error returned by [`InvariantRegistry::check_class`] (and its per-class
/// shorthands) for the first invariant that reports a violation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvariantError {
/// Copy of the failing invariant's [`Invariant::name`].
pub invariant_name: String,
/// Class that was being checked when the violation was found.
pub class: InvariantClass,
/// Violation details produced by the failing invariant's check.
pub violation: Violation,
}
impl std::fmt::Display for InvariantError {
    /// Formats as `<Class> invariant '<name>' violated: <reason>`, using the
    /// class's `Debug` form. The violation's fact ids are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            invariant_name,
            class,
            violation,
        } = self;
        let reason = &violation.reason;
        write!(f, "{class:?} invariant '{invariant_name}' violated: {reason}")
    }
}
// Lets `InvariantError` be used wherever a `std::error::Error` is expected.
// The impl body is empty, so every trait method keeps its default; in
// particular no underlying `source` error is reported.
impl std::error::Error for InvariantError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{ContextKey, Fact};

    /// Acceptance-class invariant: violated unless the context has seeds.
    struct RequireSeeds;

    impl Invariant for RequireSeeds {
        fn name(&self) -> &'static str {
            "require_seeds"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Acceptance
        }

        fn check(&self, ctx: &dyn crate::ContextView) -> InvariantResult {
            if !ctx.has(ContextKey::Seeds) {
                return InvariantResult::Violated(Violation::new("no seeds present"));
            }
            InvariantResult::Ok
        }
    }

    /// Structural-class invariant: violated by the first fact whose content
    /// is empty after trimming whitespace.
    struct NoEmptyContent;

    impl Invariant for NoEmptyContent {
        fn name(&self) -> &'static str {
            "no_empty_content"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Structural
        }

        fn check(&self, ctx: &dyn crate::ContextView) -> InvariantResult {
            // Keys are scanned in this order, so the first blank fact found
            // under the earliest key is the one reported.
            let scanned_keys = [
                ContextKey::Seeds,
                ContextKey::Hypotheses,
                ContextKey::Strategies,
                ContextKey::Competitors,
                ContextKey::Evaluations,
            ];

            let first_blank = scanned_keys
                .iter()
                .copied()
                .flat_map(|key| ctx.get(key))
                .find(|fact| fact.content.trim().is_empty());

            match first_blank {
                Some(fact) => InvariantResult::Violated(Violation::with_facts(
                    "empty content not allowed",
                    vec![fact.id.clone()],
                )),
                None => InvariantResult::Ok,
            }
        }
    }

    /// Builds a context holding a single seed fact with the given id and
    /// content.
    fn context_with_seed(id: &'static str, content: &'static str) -> Context {
        let mut ctx = Context::new();
        let _ = ctx.add_fact(Fact {
            key: ContextKey::Seeds,
            id: id.into(),
            content: content.into(),
        });
        ctx
    }

    #[test]
    fn registry_registers_invariants() {
        let mut registry = InvariantRegistry::new();

        let seeds_id = registry.register(RequireSeeds);
        let content_id = registry.register(NoEmptyContent);

        assert_eq!(registry.count(), 2);
        assert_ne!(seeds_id, content_id);
    }

    #[test]
    fn acceptance_invariant_passes_with_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let ctx = context_with_seed("s1", "value");

        assert!(registry.check_acceptance(&ctx).is_ok());
    }

    #[test]
    fn acceptance_invariant_fails_without_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let empty = Context::new();
        let err = registry
            .check_acceptance(&empty)
            .expect_err("acceptance check must fail when no seeds are present");

        assert_eq!(err.invariant_name, "require_seeds");
        assert_eq!(err.class, InvariantClass::Acceptance);
    }

    #[test]
    fn structural_invariant_catches_empty_content() {
        let mut registry = InvariantRegistry::new();
        registry.register(NoEmptyContent);

        // Whitespace-only content counts as empty once trimmed.
        let ctx = context_with_seed("bad", " ");
        let err = registry
            .check_structural(&ctx)
            .expect_err("structural check must reject whitespace-only content");

        assert!(err.violation.fact_ids.iter().any(|id| id == "bad"));
    }

    #[test]
    fn different_classes_checked_independently() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);
        registry.register(NoEmptyContent);

        let empty = Context::new();

        assert!(registry.check_structural(&empty).is_ok());
        assert!(registry.check_acceptance(&empty).is_err());
    }
}