1use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36
37use crate::context::Context;
38
/// Category an [`Invariant`] belongs to.
///
/// [`InvariantRegistry`] indexes registered invariants by class and checks each
/// class independently via `InvariantRegistry::check_class`.
// NOTE(review): keep variant order stable — serde formats that encode variant
// indices would otherwise change their wire representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum InvariantClass {
    /// Invariants checked by `InvariantRegistry::check_structural`.
    Structural,

    /// Invariants checked by `InvariantRegistry::check_semantic`.
    Semantic,

    /// Invariants checked by `InvariantRegistry::check_acceptance`.
    Acceptance,
}
54
55#[derive(Debug, Clone, PartialEq)]
57pub enum InvariantResult {
58 Ok,
60 Violated(Violation),
62}
63
64impl InvariantResult {
65 #[must_use]
67 pub fn is_ok(&self) -> bool {
68 matches!(self, Self::Ok)
69 }
70
71 #[must_use]
73 pub fn is_violated(&self) -> bool {
74 matches!(self, Self::Violated(_))
75 }
76}
77
/// Details of a failed invariant check, carried by [`InvariantResult::Violated`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Violation {
    /// Human-readable explanation; shown in `InvariantError`'s `Display` output.
    pub reason: String,
    /// Ids of the facts implicated in the violation; empty when built with
    /// `Violation::new`.
    pub fact_ids: Vec<String>,
}
86
87impl Violation {
88 #[must_use]
90 pub fn new(reason: impl Into<String>) -> Self {
91 Self {
92 reason: reason.into(),
93 fact_ids: Vec::new(),
94 }
95 }
96
97 #[must_use]
99 pub fn with_facts(reason: impl Into<String>, fact_ids: Vec<String>) -> Self {
100 Self {
101 reason: reason.into(),
102 fact_ids,
103 }
104 }
105}
106
/// A named rule evaluated against a context view.
///
/// Implementors are registered with `InvariantRegistry::register`, which takes
/// ownership of them; the trait requires `Send + Sync`.
pub trait Invariant: Send + Sync {
    /// Name copied into `InvariantError::invariant_name` when this invariant is violated.
    fn name(&self) -> &str;

    /// Class under which `InvariantRegistry` indexes this invariant.
    ///
    /// The registry reads this once, at registration time; later return values
    /// do not move the invariant to a different class.
    fn class(&self) -> InvariantClass;

    /// Evaluates the invariant against `ctx`, returning `InvariantResult::Ok`
    /// or `InvariantResult::Violated` with the failure details.
    fn check(&self, ctx: &dyn crate::ContextView) -> InvariantResult;
}
120
/// Handle returned by `InvariantRegistry::register`.
///
/// The wrapped value is used by the registry as an index into its own storage,
/// so an id is only meaningful for the registry that issued it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct InvariantId(pub(crate) u32);
124
125impl std::fmt::Display for InvariantId {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 write!(f, "Invariant({})", self.0)
128 }
129}
130
131#[derive(Default)]
133pub struct InvariantRegistry {
134 invariants: Vec<Box<dyn Invariant>>,
135 by_class: HashMap<InvariantClass, Vec<InvariantId>>,
136 next_id: u32,
137}
138
139impl InvariantRegistry {
140 #[must_use]
142 pub fn new() -> Self {
143 Self::default()
144 }
145
146 pub fn register(&mut self, invariant: impl Invariant + 'static) -> InvariantId {
148 let id = InvariantId(self.next_id);
149 self.next_id += 1;
150
151 let class = invariant.class();
152 self.by_class.entry(class).or_default().push(id);
153 self.invariants.push(Box::new(invariant));
154
155 id
156 }
157
158 #[must_use]
160 pub fn count(&self) -> usize {
161 self.invariants.len()
162 }
163
164 pub fn check_class(&self, class: InvariantClass, ctx: &Context) -> Result<(), InvariantError> {
172 let ids = self.by_class.get(&class).map_or(&[][..], Vec::as_slice);
173
174 for &id in ids {
175 let invariant = &self.invariants[id.0 as usize];
176 if let InvariantResult::Violated(violation) = invariant.check(ctx) {
177 return Err(InvariantError {
178 invariant_name: invariant.name().to_string(),
179 class,
180 violation,
181 });
182 }
183 }
184
185 Ok(())
186 }
187
188 pub fn check_structural(&self, ctx: &Context) -> Result<(), InvariantError> {
194 self.check_class(InvariantClass::Structural, ctx)
195 }
196
197 pub fn check_semantic(&self, ctx: &Context) -> Result<(), InvariantError> {
203 self.check_class(InvariantClass::Semantic, ctx)
204 }
205
206 pub fn check_acceptance(&self, ctx: &Context) -> Result<(), InvariantError> {
212 self.check_class(InvariantClass::Acceptance, ctx)
213 }
214}
215
/// Error returned by `InvariantRegistry::check_class` (and the per-class
/// wrappers) for the first invariant found to be violated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvariantError {
    /// Value of `Invariant::name` for the violated invariant.
    pub invariant_name: String,
    /// Class that was being checked when the violation was found.
    pub class: InvariantClass,
    /// Violation details returned by the invariant's `check`.
    pub violation: Violation,
}
226
227impl std::fmt::Display for InvariantError {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 write!(
230 f,
231 "{:?} invariant '{}' violated: {}",
232 self.class, self.invariant_name, self.violation.reason
233 )
234 }
235}
236
237impl std::error::Error for InvariantError {}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ContextKey;

    /// Acceptance invariant: the context must contain seeds.
    struct RequireSeeds;

    impl Invariant for RequireSeeds {
        fn name(&self) -> &'static str {
            "require_seeds"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Acceptance
        }

        fn check(&self, ctx: &dyn crate::ContextView) -> InvariantResult {
            if !ctx.has(ContextKey::Seeds) {
                return InvariantResult::Violated(Violation::new("no seeds present"));
            }
            InvariantResult::Ok
        }
    }

    /// Structural invariant: no fact under the listed keys may have blank content.
    struct NoEmptyContent;

    impl Invariant for NoEmptyContent {
        fn name(&self) -> &'static str {
            "no_empty_content"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Structural
        }

        fn check(&self, ctx: &dyn crate::ContextView) -> InvariantResult {
            let checked_keys = [
                ContextKey::Seeds,
                ContextKey::Hypotheses,
                ContextKey::Strategies,
                ContextKey::Competitors,
                ContextKey::Evaluations,
            ];

            for key in checked_keys.iter().copied() {
                // First fact under this key whose content is only whitespace, if any.
                let blank = ctx
                    .get(key)
                    .into_iter()
                    .find(|fact| fact.content.trim().is_empty());

                if let Some(fact) = blank {
                    return InvariantResult::Violated(Violation::with_facts(
                        "empty content not allowed",
                        vec![fact.id.clone()],
                    ));
                }
            }

            InvariantResult::Ok
        }
    }

    #[test]
    fn registry_registers_invariants() {
        let mut registry = InvariantRegistry::new();
        let first = registry.register(RequireSeeds);
        let second = registry.register(NoEmptyContent);

        assert_eq!(registry.count(), 2);
        assert_ne!(first, second);
    }

    #[test]
    fn acceptance_invariant_passes_with_seeds() {
        let mut ctx = Context::new();
        let _ = ctx.add_fact(crate::context::new_fact(ContextKey::Seeds, "s1", "value"));

        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        assert!(registry.check_acceptance(&ctx).is_ok());
    }

    #[test]
    fn acceptance_invariant_fails_without_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let outcome = registry.check_acceptance(&Context::new());
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert_eq!(err.invariant_name, "require_seeds");
        assert_eq!(err.class, InvariantClass::Acceptance);
    }

    #[test]
    fn structural_invariant_catches_empty_content() {
        let mut registry = InvariantRegistry::new();
        registry.register(NoEmptyContent);

        let mut ctx = Context::new();
        let _ = ctx.add_fact(crate::context::new_fact(ContextKey::Seeds, "bad", " "));

        let result = registry.check_structural(&ctx);
        assert!(result.is_err());

        let flagged = result.unwrap_err().violation.fact_ids;
        assert!(flagged.iter().any(|id| id == "bad"));
    }

    #[test]
    fn different_classes_checked_independently() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);
        registry.register(NoEmptyContent);
        let ctx = Context::new();

        // An empty context has no blank facts, so the structural check passes...
        assert!(registry.check_structural(&ctx).is_ok());

        // ...while the acceptance check still fails because there are no seeds.
        assert!(registry.check_acceptance(&ctx).is_err());
    }
}