1use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36
37use crate::context::{ContextState, FactId};
38
/// Category an [`Invariant`] is filed under.
///
/// [`InvariantRegistry`] groups registered invariants by this value, and each
/// of its `check_*` entry points runs only the invariants of a single class.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum InvariantClass {
    /// Invariants run by [`InvariantRegistry::check_structural`].
    Structural,

    /// Invariants run by [`InvariantRegistry::check_semantic`].
    Semantic,

    /// Invariants run by [`InvariantRegistry::check_acceptance`].
    Acceptance,
}
54
/// Outcome of evaluating a single [`Invariant`].
#[derive(Debug, Clone, PartialEq)]
pub enum InvariantResult {
    /// The invariant holds.
    Ok,
    /// The invariant does not hold; the payload describes the failure.
    Violated(Violation),
}
63
64impl InvariantResult {
65 #[must_use]
67 pub fn is_ok(&self) -> bool {
68 matches!(self, Self::Ok)
69 }
70
71 #[must_use]
73 pub fn is_violated(&self) -> bool {
74 matches!(self, Self::Violated(_))
75 }
76}
77
/// Details of a failed invariant check.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Violation {
    /// Human-readable explanation of the failure; this is the text shown at
    /// the end of [`InvariantError`]'s `Display` output.
    pub reason: String,
    /// Ids of the facts associated with the violation. Empty when the
    /// violation was built with [`Violation::new`].
    pub fact_ids: Vec<FactId>,
}
86
87impl Violation {
88 #[must_use]
90 pub fn new(reason: impl Into<String>) -> Self {
91 Self {
92 reason: reason.into(),
93 fact_ids: Vec::new(),
94 }
95 }
96
97 #[must_use]
99 pub fn with_facts(reason: impl Into<String>, fact_ids: Vec<FactId>) -> Self {
100 Self {
101 reason: reason.into(),
102 fact_ids,
103 }
104 }
105}
106
/// A named, classified check that can be evaluated against a context.
///
/// Implementors must be `Send + Sync`.
pub trait Invariant: Send + Sync {
    /// Name of the invariant. [`InvariantRegistry::check_class`] copies it
    /// into [`InvariantError::invariant_name`] when the check fails.
    fn name(&self) -> &str;

    /// Class under which the registry files this invariant when it is
    /// registered.
    fn class(&self) -> InvariantClass;

    /// Evaluates the invariant against `ctx` and reports whether it holds.
    fn check(&self, ctx: &dyn crate::Context) -> InvariantResult;
}
120
/// Identifier returned by [`InvariantRegistry::register`].
///
/// The wrapped value is the registry's running counter at registration time;
/// the registry also uses it as an index into its invariant storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct InvariantId(pub(crate) u32);
124
125impl std::fmt::Display for InvariantId {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 write!(f, "Invariant({})", self.0)
128 }
129}
130
/// Owns the registered invariants and indexes them by class.
#[derive(Default)]
pub struct InvariantRegistry {
    /// Registered invariants in registration order. An [`InvariantId`]'s
    /// wrapped value is used as an index into this vector.
    invariants: Vec<Box<dyn Invariant>>,
    /// Ids of the registered invariants, grouped by the class each invariant
    /// reported when it was registered.
    by_class: HashMap<InvariantClass, Vec<InvariantId>>,
    /// Value that the next call to `register` will hand out as an id.
    next_id: u32,
}
138
139impl InvariantRegistry {
140 #[must_use]
142 pub fn new() -> Self {
143 Self::default()
144 }
145
146 pub fn register(&mut self, invariant: impl Invariant + 'static) -> InvariantId {
148 let id = InvariantId(self.next_id);
149 self.next_id += 1;
150
151 let class = invariant.class();
152 self.by_class.entry(class).or_default().push(id);
153 self.invariants.push(Box::new(invariant));
154
155 id
156 }
157
158 #[must_use]
160 pub fn count(&self) -> usize {
161 self.invariants.len()
162 }
163
164 pub fn check_class(
172 &self,
173 class: InvariantClass,
174 ctx: &ContextState,
175 ) -> Result<(), InvariantError> {
176 let ids = self.by_class.get(&class).map_or(&[][..], Vec::as_slice);
177
178 for &id in ids {
179 let invariant = &self.invariants[id.0 as usize];
180 if let InvariantResult::Violated(violation) = invariant.check(ctx) {
181 return Err(InvariantError {
182 invariant_name: invariant.name().to_string(),
183 class,
184 violation,
185 });
186 }
187 }
188
189 Ok(())
190 }
191
192 pub fn check_structural(&self, ctx: &ContextState) -> Result<(), InvariantError> {
198 self.check_class(InvariantClass::Structural, ctx)
199 }
200
201 pub fn check_semantic(&self, ctx: &ContextState) -> Result<(), InvariantError> {
207 self.check_class(InvariantClass::Semantic, ctx)
208 }
209
210 pub fn check_acceptance(&self, ctx: &ContextState) -> Result<(), InvariantError> {
216 self.check_class(InvariantClass::Acceptance, ctx)
217 }
218}
219
/// Error produced by [`InvariantRegistry`]'s `check_*` methods when an
/// invariant is violated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvariantError {
    /// Value of [`Invariant::name`] for the invariant that failed.
    pub invariant_name: String,
    /// Class that was being checked when the failure was found.
    pub class: InvariantClass,
    /// Violation reported by the failing invariant's `check`.
    pub violation: Violation,
}
230
231impl std::fmt::Display for InvariantError {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 write!(
234 f,
235 "{:?} invariant '{}' violated: {}",
236 self.class, self.invariant_name, self.violation.reason
237 )
238 }
239}
240
// Empty impl: `InvariantError` already provides `Debug` and `Display`, and the
// trait's default `source()` is kept, so no underlying cause is reported.
impl std::error::Error for InvariantError {}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ContextKey;

    /// Acceptance-class fixture: holds only when the context has seeds.
    struct RequireSeeds;

    impl Invariant for RequireSeeds {
        fn name(&self) -> &'static str {
            "require_seeds"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Acceptance
        }

        fn check(&self, ctx: &dyn crate::Context) -> InvariantResult {
            if !ctx.has(ContextKey::Seeds) {
                return InvariantResult::Violated(Violation::new("no seeds present"));
            }
            InvariantResult::Ok
        }
    }

    /// Structural-class fixture: rejects any fact whose text is blank after
    /// trimming, reporting the id of the first such fact.
    struct NoEmptyContent;

    impl Invariant for NoEmptyContent {
        fn name(&self) -> &'static str {
            "no_empty_content"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Structural
        }

        fn check(&self, ctx: &dyn crate::Context) -> InvariantResult {
            let checked_keys = [
                ContextKey::Seeds,
                ContextKey::Hypotheses,
                ContextKey::Strategies,
                ContextKey::Competitors,
                ContextKey::Evaluations,
            ];

            for key in &checked_keys {
                for fact in ctx.get(*key) {
                    let is_blank = fact.text().is_some_and(|text| text.trim().is_empty());
                    if is_blank {
                        return InvariantResult::Violated(Violation::with_facts(
                            "empty content not allowed",
                            vec![fact.id().clone()],
                        ));
                    }
                }
            }

            InvariantResult::Ok
        }
    }

    #[test]
    fn registry_registers_invariants() {
        let mut registry = InvariantRegistry::new();
        let seeds_id = registry.register(RequireSeeds);
        let content_id = registry.register(NoEmptyContent);

        assert_eq!(registry.count(), 2);
        assert_ne!(seeds_id, content_id);
    }

    #[test]
    fn acceptance_invariant_passes_with_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let mut ctx = ContextState::new();
        let seed = crate::context::new_fact(ContextKey::Seeds, "s1", "value");
        let _ = ctx.add_fact(seed);

        assert!(registry.check_acceptance(&ctx).is_ok());
    }

    #[test]
    fn acceptance_invariant_fails_without_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let ctx = ContextState::new();
        let outcome = registry.check_acceptance(&ctx);

        assert!(outcome.is_err());
        let InvariantError {
            invariant_name,
            class,
            ..
        } = outcome.unwrap_err();
        assert_eq!(invariant_name, "require_seeds");
        assert_eq!(class, InvariantClass::Acceptance);
    }

    #[test]
    fn structural_invariant_catches_empty_content() {
        let mut registry = InvariantRegistry::new();
        registry.register(NoEmptyContent);

        // A whitespace-only fact must be flagged as empty content.
        let mut ctx = ContextState::new();
        let blank = crate::context::new_fact(ContextKey::Seeds, "bad", " ");
        let _ = ctx.add_fact(blank);

        let outcome = registry.check_structural(&ctx);
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        let flagged = &err.violation.fact_ids;
        assert!(flagged.contains(&"bad".into()));
    }

    #[test]
    fn different_classes_checked_independently() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);
        registry.register(NoEmptyContent);

        let ctx = ContextState::new();

        // An empty context has no facts, so the structural check has nothing
        // to reject.
        assert!(registry.check_structural(&ctx).is_ok());

        // The same empty context has no seeds, so the acceptance check fails.
        assert!(registry.check_acceptance(&ctx).is_err());
    }
}