1use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39
40use crate::context::Context;
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
44pub enum InvariantClass {
45 Structural,
48
49 Semantic,
52
53 Acceptance,
56}
57
58#[derive(Debug, Clone, PartialEq)]
60pub enum InvariantResult {
61 Ok,
63 Violated(Violation),
65}
66
67impl InvariantResult {
68 #[must_use]
70 pub fn is_ok(&self) -> bool {
71 matches!(self, Self::Ok)
72 }
73
74 #[must_use]
76 pub fn is_violated(&self) -> bool {
77 matches!(self, Self::Violated(_))
78 }
79}
80
81#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83pub struct Violation {
84 pub reason: String,
86 pub fact_ids: Vec<String>,
88}
89
90impl Violation {
91 #[must_use]
93 pub fn new(reason: impl Into<String>) -> Self {
94 Self {
95 reason: reason.into(),
96 fact_ids: Vec::new(),
97 }
98 }
99
100 #[must_use]
102 pub fn with_facts(reason: impl Into<String>, fact_ids: Vec<String>) -> Self {
103 Self {
104 reason: reason.into(),
105 fact_ids,
106 }
107 }
108}
109
110pub trait Invariant: Send + Sync {
114 fn name(&self) -> &str;
116
117 fn class(&self) -> InvariantClass;
119
120 fn check(&self, ctx: &Context) -> InvariantResult;
122}
123
124#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
126pub struct InvariantId(pub(crate) u32);
127
128impl std::fmt::Display for InvariantId {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 write!(f, "Invariant({})", self.0)
131 }
132}
133
134#[derive(Default)]
136pub struct InvariantRegistry {
137 invariants: Vec<Box<dyn Invariant>>,
138 by_class: HashMap<InvariantClass, Vec<InvariantId>>,
139 next_id: u32,
140}
141
142impl InvariantRegistry {
143 #[must_use]
145 pub fn new() -> Self {
146 Self::default()
147 }
148
149 pub fn register(&mut self, invariant: impl Invariant + 'static) -> InvariantId {
151 let id = InvariantId(self.next_id);
152 self.next_id += 1;
153
154 let class = invariant.class();
155 self.by_class.entry(class).or_default().push(id);
156 self.invariants.push(Box::new(invariant));
157
158 id
159 }
160
161 #[must_use]
163 pub fn count(&self) -> usize {
164 self.invariants.len()
165 }
166
167 pub fn check_class(&self, class: InvariantClass, ctx: &Context) -> Result<(), InvariantError> {
175 let ids = self.by_class.get(&class).map_or(&[][..], Vec::as_slice);
176
177 for &id in ids {
178 let invariant = &self.invariants[id.0 as usize];
179 if let InvariantResult::Violated(violation) = invariant.check(ctx) {
180 return Err(InvariantError {
181 invariant_name: invariant.name().to_string(),
182 class,
183 violation,
184 });
185 }
186 }
187
188 Ok(())
189 }
190
191 pub fn check_structural(&self, ctx: &Context) -> Result<(), InvariantError> {
197 self.check_class(InvariantClass::Structural, ctx)
198 }
199
200 pub fn check_semantic(&self, ctx: &Context) -> Result<(), InvariantError> {
206 self.check_class(InvariantClass::Semantic, ctx)
207 }
208
209 pub fn check_acceptance(&self, ctx: &Context) -> Result<(), InvariantError> {
215 self.check_class(InvariantClass::Acceptance, ctx)
216 }
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct InvariantError {
222 pub invariant_name: String,
224 pub class: InvariantClass,
226 pub violation: Violation,
228}
229
230impl std::fmt::Display for InvariantError {
231 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232 write!(
233 f,
234 "{:?} invariant '{}' violated: {}",
235 self.class, self.invariant_name, self.violation.reason
236 )
237 }
238}
239
240impl std::error::Error for InvariantError {}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{ContextKey, Fact};

    /// Acceptance invariant gated on `ctx.has(ContextKey::Seeds)`.
    struct RequireSeeds;

    impl Invariant for RequireSeeds {
        fn name(&self) -> &'static str {
            "require_seeds"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Acceptance
        }

        fn check(&self, ctx: &Context) -> InvariantResult {
            if !ctx.has(ContextKey::Seeds) {
                return InvariantResult::Violated(Violation::new("no seeds present"));
            }
            InvariantResult::Ok
        }
    }

    /// Structural invariant rejecting facts whose content is empty after trimming.
    struct NoEmptyContent;

    impl Invariant for NoEmptyContent {
        fn name(&self) -> &'static str {
            "no_empty_content"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Structural
        }

        fn check(&self, ctx: &Context) -> InvariantResult {
            // Scan keys in this order and report only the first offending fact.
            let offending = [
                ContextKey::Seeds,
                ContextKey::Hypotheses,
                ContextKey::Strategies,
                ContextKey::Competitors,
                ContextKey::Evaluations,
            ]
            .iter()
            .flat_map(|&key| ctx.get(key))
            .find(|fact| fact.content.trim().is_empty());

            match offending {
                Some(fact) => InvariantResult::Violated(Violation::with_facts(
                    "empty content not allowed",
                    vec![fact.id.clone()],
                )),
                None => InvariantResult::Ok,
            }
        }
    }

    #[test]
    fn registry_registers_invariants() {
        let mut registry = InvariantRegistry::new();
        let first = registry.register(RequireSeeds);
        let second = registry.register(NoEmptyContent);

        assert_eq!(registry.count(), 2);
        assert_ne!(first, second);
    }

    #[test]
    fn acceptance_invariant_passes_with_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let mut ctx = Context::new();
        let seed = Fact {
            key: ContextKey::Seeds,
            id: "s1".into(),
            content: "value".into(),
        };
        let _ = ctx.add_fact(seed);

        assert!(registry.check_acceptance(&ctx).is_ok());
    }

    #[test]
    fn acceptance_invariant_fails_without_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let err = registry
            .check_acceptance(&Context::new())
            .expect_err("acceptance must fail when no seeds are present");

        assert_eq!(err.invariant_name, "require_seeds");
        assert_eq!(err.class, InvariantClass::Acceptance);
    }

    #[test]
    fn structural_invariant_catches_empty_content() {
        let mut registry = InvariantRegistry::new();
        registry.register(NoEmptyContent);

        let mut ctx = Context::new();
        let _ = ctx.add_fact(Fact {
            key: ContextKey::Seeds,
            id: "bad".into(),
            content: " ".into(),
        });

        let err = registry
            .check_structural(&ctx)
            .expect_err("whitespace-only content must be rejected");
        assert!(err.violation.fact_ids.iter().any(|id| id == "bad"));
    }

    #[test]
    fn different_classes_checked_independently() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);
        registry.register(NoEmptyContent);
        let ctx = Context::new();

        // With no facts added, there is no content for the structural check to reject...
        assert!(registry.check_structural(&ctx).is_ok());
        // ...while the acceptance check still fails for lack of seeds.
        assert!(registry.check_acceptance(&ctx).is_err());
    }
}