1use serde::{Deserialize, Serialize};
24
25use crate::{
26 error::{ExoError, Result},
27 types::{DeterministicMap, Did, Hash256, Timestamp},
28};
29
30#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
36pub struct InvariantViolation {
37 pub invariant_name: String,
39 pub description: String,
41 pub severity: ViolationSeverity,
43 pub context: DeterministicMap<String, String>,
45}
46
47#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
49pub enum ViolationSeverity {
50 Warning,
52 Error,
54 Critical,
56}
57
58impl core::fmt::Display for ViolationSeverity {
59 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
60 match self {
61 ViolationSeverity::Warning => write!(f, "WARNING"),
62 ViolationSeverity::Error => write!(f, "ERROR"),
63 ViolationSeverity::Critical => write!(f, "CRITICAL"),
64 }
65 }
66}
67
68#[derive(Clone, Debug)]
74pub struct InvariantContext {
75 pub actor_did: Did,
77 pub timestamp: Timestamp,
79 pub state_hash: Hash256,
81 pub properties: DeterministicMap<String, String>,
83}
84
85impl InvariantContext {
86 #[must_use]
88 pub fn new(actor_did: Did, timestamp: Timestamp, state_hash: Hash256) -> Self {
89 Self {
90 actor_did,
91 timestamp,
92 state_hash,
93 properties: DeterministicMap::new(),
94 }
95 }
96
97 pub fn set_property(&mut self, key: impl Into<String>, value: impl Into<String>) {
99 self.properties.insert(key.into(), value.into());
100 }
101
102 #[must_use]
104 pub fn get_property(&self, key: &str) -> Option<&String> {
105 self.properties.get(&key.to_owned())
106 }
107}
108
109pub trait Invariant: core::fmt::Debug {
115 fn name(&self) -> &str;
117
118 fn check(&self, context: &InvariantContext) -> core::result::Result<(), InvariantViolation>;
121}
122
123pub struct InvariantSet {
129 invariants: Vec<Box<dyn Invariant>>,
130}
131
132impl InvariantSet {
133 #[must_use]
135 pub fn new() -> Self {
136 Self {
137 invariants: Vec::new(),
138 }
139 }
140
141 pub fn add(&mut self, invariant: impl Invariant + 'static) {
143 self.invariants.push(Box::new(invariant));
144 }
145
146 #[must_use]
148 pub fn len(&self) -> usize {
149 self.invariants.len()
150 }
151
152 #[must_use]
154 pub fn is_empty(&self) -> bool {
155 self.invariants.is_empty()
156 }
157
158 pub fn check_all(&self, context: &InvariantContext) -> Result<()> {
163 for inv in &self.invariants {
164 if let Err(violation) = inv.check(context) {
165 return Err(ExoError::InvariantViolation {
166 description: format!(
167 "[{}] {}: {}",
168 violation.severity, violation.invariant_name, violation.description
169 ),
170 });
171 }
172 }
173 Ok(())
174 }
175
176 #[must_use]
178 pub fn check_all_collect(&self, context: &InvariantContext) -> Vec<InvariantViolation> {
179 let mut violations = Vec::new();
180 for inv in &self.invariants {
181 if let Err(v) = inv.check(context) {
182 violations.push(v);
183 }
184 }
185 violations
186 }
187}
188
189impl Default for InvariantSet {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195impl core::fmt::Debug for InvariantSet {
196 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
197 f.debug_struct("InvariantSet")
198 .field("count", &self.invariants.len())
199 .finish()
200 }
201}
202
203pub fn check_all(invariants: &InvariantSet, context: &InvariantContext) -> Result<()> {
209 invariants.check_all(context)
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Did, Hash256, Timestamp};

    /// Invariant that never fails.
    #[derive(Debug)]
    struct AlwaysPass;

    impl Invariant for AlwaysPass {
        fn name(&self) -> &str {
            "always_pass"
        }

        fn check(
            &self,
            _context: &InvariantContext,
        ) -> core::result::Result<(), InvariantViolation> {
            Ok(())
        }
    }

    /// Invariant that always fails with the configured severity.
    #[derive(Debug)]
    struct AlwaysFail {
        severity: ViolationSeverity,
    }

    impl Invariant for AlwaysFail {
        fn name(&self) -> &str {
            "always_fail"
        }

        fn check(
            &self,
            _context: &InvariantContext,
        ) -> core::result::Result<(), InvariantViolation> {
            Err(InvariantViolation {
                invariant_name: self.name().to_string(),
                description: "this always fails".to_string(),
                severity: self.severity,
                context: DeterministicMap::new(),
            })
        }
    }

    /// Invariant requiring property `key` to equal `expected`.
    #[derive(Debug)]
    struct RequireProperty {
        key: String,
        expected: String,
    }

    impl Invariant for RequireProperty {
        fn name(&self) -> &str {
            "require_property"
        }

        fn check(
            &self,
            context: &InvariantContext,
        ) -> core::result::Result<(), InvariantViolation> {
            match context.get_property(&self.key) {
                Some(v) if v == &self.expected => Ok(()),
                Some(v) => {
                    let mut ctx = DeterministicMap::new();
                    ctx.insert("expected".to_string(), self.expected.clone());
                    ctx.insert("actual".to_string(), v.clone());
                    Err(InvariantViolation {
                        invariant_name: self.name().to_string(),
                        description: format!("property '{}' mismatch", self.key),
                        severity: ViolationSeverity::Error,
                        context: ctx,
                    })
                }
                None => Err(InvariantViolation {
                    invariant_name: self.name().to_string(),
                    description: format!("property '{}' missing", self.key),
                    severity: ViolationSeverity::Error,
                    context: DeterministicMap::new(),
                }),
            }
        }
    }

    fn test_context() -> InvariantContext {
        InvariantContext::new(
            Did::new("did:exo:tester").expect("valid"),
            Timestamp::new(1000, 0),
            Hash256::ZERO,
        )
    }

    // ---- InvariantViolation / ViolationSeverity ----

    #[test]
    fn violation_serde_roundtrip() {
        let v = InvariantViolation {
            invariant_name: "test".into(),
            description: "something broke".into(),
            severity: ViolationSeverity::Critical,
            context: DeterministicMap::new(),
        };
        let json = serde_json::to_string(&v).expect("ser");
        let v2: InvariantViolation = serde_json::from_str(&json).expect("de");
        assert_eq!(v, v2);
    }

    #[test]
    fn violation_severity_display() {
        assert_eq!(ViolationSeverity::Warning.to_string(), "WARNING");
        assert_eq!(ViolationSeverity::Error.to_string(), "ERROR");
        assert_eq!(ViolationSeverity::Critical.to_string(), "CRITICAL");
    }

    #[test]
    fn violation_severity_ord() {
        assert!(ViolationSeverity::Warning < ViolationSeverity::Error);
        assert!(ViolationSeverity::Error < ViolationSeverity::Critical);
    }

    // ---- InvariantContext ----

    #[test]
    fn context_new() {
        let ctx = test_context();
        assert_eq!(ctx.actor_did.as_str(), "did:exo:tester");
        assert_eq!(ctx.timestamp, Timestamp::new(1000, 0));
        assert_eq!(ctx.state_hash, Hash256::ZERO);
        assert!(ctx.properties.is_empty());
    }

    #[test]
    fn context_set_get_property() {
        let mut ctx = test_context();
        ctx.set_property("role", "admin");
        assert_eq!(ctx.get_property("role"), Some(&"admin".to_string()));
        assert_eq!(ctx.get_property("missing"), None);
    }

    // ---- individual invariants ----

    #[test]
    fn always_pass_passes() {
        let inv = AlwaysPass;
        assert_eq!(inv.name(), "always_pass");
        let ctx = test_context();
        assert!(inv.check(&ctx).is_ok());
    }

    #[test]
    fn always_fail_fails() {
        let inv = AlwaysFail {
            severity: ViolationSeverity::Error,
        };
        let ctx = test_context();
        let err = inv.check(&ctx).unwrap_err();
        assert_eq!(err.invariant_name, "always_fail");
        assert_eq!(err.severity, ViolationSeverity::Error);
    }

    #[test]
    fn require_property_pass() {
        let inv = RequireProperty {
            key: "mode".into(),
            expected: "production".into(),
        };
        let mut ctx = test_context();
        ctx.set_property("mode", "production");
        assert!(inv.check(&ctx).is_ok());
    }

    #[test]
    fn require_property_mismatch() {
        let inv = RequireProperty {
            key: "mode".into(),
            expected: "production".into(),
        };
        let mut ctx = test_context();
        ctx.set_property("mode", "debug");
        let err = inv.check(&ctx).unwrap_err();
        assert!(err.description.contains("mismatch"));
        assert!(err.context.contains_key(&"expected".to_string()));
        assert!(err.context.contains_key(&"actual".to_string()));
        // The diagnostics must carry the actual values, not just the keys.
        assert_eq!(
            err.context.get(&"expected".to_string()),
            Some(&"production".to_string())
        );
        assert_eq!(
            err.context.get(&"actual".to_string()),
            Some(&"debug".to_string())
        );
    }

    #[test]
    fn require_property_missing() {
        let inv = RequireProperty {
            key: "mode".into(),
            expected: "production".into(),
        };
        let ctx = test_context();
        let err = inv.check(&ctx).unwrap_err();
        assert!(err.description.contains("missing"));
    }

    // ---- InvariantSet ----

    #[test]
    fn empty_set_passes() {
        let set = InvariantSet::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        let ctx = test_context();
        assert!(set.check_all(&ctx).is_ok());
        assert!(set.check_all_collect(&ctx).is_empty());
    }

    #[test]
    fn set_all_pass() {
        let mut set = InvariantSet::new();
        set.add(AlwaysPass);
        set.add(AlwaysPass);
        assert_eq!(set.len(), 2);
        assert!(!set.is_empty());
        let ctx = test_context();
        assert!(set.check_all(&ctx).is_ok());
        assert!(set.check_all_collect(&ctx).is_empty());
    }

    #[test]
    fn set_one_fails() {
        let mut set = InvariantSet::new();
        set.add(AlwaysPass);
        set.add(AlwaysFail {
            severity: ViolationSeverity::Critical,
        });
        set.add(AlwaysPass);
        let ctx = test_context();
        let err = set.check_all(&ctx).unwrap_err();
        // Pin the "[SEVERITY] name: description" message format.
        assert!(matches!(
            &err,
            ExoError::InvariantViolation { description }
                if description == "[CRITICAL] always_fail: this always fails"
        ));
    }

    #[test]
    fn check_all_reports_first_violation_only() {
        let mut set = InvariantSet::new();
        set.add(AlwaysFail {
            severity: ViolationSeverity::Warning,
        });
        set.add(AlwaysFail {
            severity: ViolationSeverity::Critical,
        });
        let ctx = test_context();
        let err = set.check_all(&ctx).unwrap_err();
        // Fail-fast: the first registered failure wins, even if a later one
        // is more severe.
        assert!(matches!(
            &err,
            ExoError::InvariantViolation { description }
                if description == "[WARNING] always_fail: this always fails"
        ));
    }

    #[test]
    fn set_collect_all_violations() {
        let mut set = InvariantSet::new();
        set.add(AlwaysFail {
            severity: ViolationSeverity::Warning,
        });
        set.add(AlwaysPass);
        set.add(AlwaysFail {
            severity: ViolationSeverity::Critical,
        });
        let ctx = test_context();
        let violations = set.check_all_collect(&ctx);
        assert_eq!(violations.len(), 2);
        assert_eq!(violations[0].severity, ViolationSeverity::Warning);
        assert_eq!(violations[1].severity, ViolationSeverity::Critical);
    }

    #[test]
    fn check_all_function() {
        let mut set = InvariantSet::new();
        set.add(AlwaysPass);
        let ctx = test_context();
        assert!(check_all(&set, &ctx).is_ok());

        let mut failing = InvariantSet::new();
        failing.add(AlwaysFail {
            severity: ViolationSeverity::Error,
        });
        let err = check_all(&failing, &ctx).unwrap_err();
        assert!(matches!(err, ExoError::InvariantViolation { .. }));
    }

    #[test]
    fn set_default() {
        let set = InvariantSet::default();
        assert!(set.is_empty());
    }

    #[test]
    fn set_debug() {
        let mut set = InvariantSet::new();
        set.add(AlwaysPass);
        let dbg = format!("{set:?}");
        assert!(dbg.contains("InvariantSet"));
        assert!(dbg.contains("1"));
        // Exact shape: only the count is shown, not the invariants.
        assert_eq!(dbg, "InvariantSet { count: 1 }");
    }

    #[test]
    fn set_with_property_check() {
        let mut set = InvariantSet::new();
        set.add(RequireProperty {
            key: "consent".into(),
            expected: "granted".into(),
        });

        // Property absent: the set must fail.
        let ctx = test_context();
        assert!(set.check_all(&ctx).is_err());

        // Property present with the expected value: the set must pass.
        let mut ctx2 = test_context();
        ctx2.set_property("consent", "granted");
        assert!(set.check_all(&ctx2).is_ok());
    }

    #[test]
    fn violation_context_is_deterministic() {
        let inv = RequireProperty {
            key: "x".into(),
            expected: "y".into(),
        };
        let mut ctx = test_context();
        ctx.set_property("x", "wrong");
        let v1 = inv.check(&ctx).unwrap_err();
        let v2 = inv.check(&ctx).unwrap_err();
        assert_eq!(v1, v2);
    }
}
529}