Skip to main content

exo_core/
invariants.rs

1// Copyright 2026 Exochain Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at:
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17//! Invariant checking primitives for EXOCHAIN.
18//!
19//! Every operation in the system must pass a set of invariants before it
20//! is committed.  This module provides the trait, context, and set
21//! abstractions to express and enforce those invariants.
22
23use serde::{Deserialize, Serialize};
24
25use crate::{
26    error::{ExoError, Result},
27    types::{DeterministicMap, Did, Hash256, Timestamp},
28};
29
30// ---------------------------------------------------------------------------
31// InvariantViolation
32// ---------------------------------------------------------------------------
33
34/// A detailed report of a single invariant violation.
35#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
36pub struct InvariantViolation {
37    /// Human-readable name of the violated invariant.
38    pub invariant_name: String,
39    /// Description of what went wrong.
40    pub description: String,
41    /// Severity level.
42    pub severity: ViolationSeverity,
43    /// Optional context key-value pairs for diagnostics.
44    pub context: DeterministicMap<String, String>,
45}
46
47/// Severity of an invariant violation.
48#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
49pub enum ViolationSeverity {
50    /// Advisory — logged but does not block.
51    Warning,
52    /// Blocks the current operation.
53    Error,
54    /// Critical system integrity issue.
55    Critical,
56}
57
58impl core::fmt::Display for ViolationSeverity {
59    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
60        match self {
61            ViolationSeverity::Warning => write!(f, "WARNING"),
62            ViolationSeverity::Error => write!(f, "ERROR"),
63            ViolationSeverity::Critical => write!(f, "CRITICAL"),
64        }
65    }
66}
67
68// ---------------------------------------------------------------------------
69// InvariantContext
70// ---------------------------------------------------------------------------
71
72/// Snapshot of current state available to invariant checks.
73#[derive(Clone, Debug)]
74pub struct InvariantContext {
75    /// The actor performing the current operation.
76    pub actor_did: Did,
77    /// The current HLC timestamp.
78    pub timestamp: Timestamp,
79    /// The hash of the current state being validated.
80    pub state_hash: Hash256,
81    /// Arbitrary string properties for flexible invariant checking.
82    pub properties: DeterministicMap<String, String>,
83}
84
85impl InvariantContext {
86    /// Create a new context.
87    #[must_use]
88    pub fn new(actor_did: Did, timestamp: Timestamp, state_hash: Hash256) -> Self {
89        Self {
90            actor_did,
91            timestamp,
92            state_hash,
93            properties: DeterministicMap::new(),
94        }
95    }
96
97    /// Add a property to the context.
98    pub fn set_property(&mut self, key: impl Into<String>, value: impl Into<String>) {
99        self.properties.insert(key.into(), value.into());
100    }
101
102    /// Retrieve a property.
103    #[must_use]
104    pub fn get_property(&self, key: &str) -> Option<&String> {
105        self.properties.get(&key.to_owned())
106    }
107}
108
109// ---------------------------------------------------------------------------
110// Invariant trait
111// ---------------------------------------------------------------------------
112
113/// A single invariant that can be checked against a context.
114pub trait Invariant: core::fmt::Debug {
115    /// The name of this invariant for reporting.
116    fn name(&self) -> &str;
117
118    /// Check the invariant.  Return `Ok(())` if it holds, or an
119    /// `InvariantViolation` describing the failure.
120    fn check(&self, context: &InvariantContext) -> core::result::Result<(), InvariantViolation>;
121}
122
123// ---------------------------------------------------------------------------
124// InvariantSet
125// ---------------------------------------------------------------------------
126
127/// A collection of invariants that must all pass.
128pub struct InvariantSet {
129    invariants: Vec<Box<dyn Invariant>>,
130}
131
132impl InvariantSet {
133    /// Create an empty set.
134    #[must_use]
135    pub fn new() -> Self {
136        Self {
137            invariants: Vec::new(),
138        }
139    }
140
141    /// Add an invariant to the set.
142    pub fn add(&mut self, invariant: impl Invariant + 'static) {
143        self.invariants.push(Box::new(invariant));
144    }
145
146    /// Number of invariants in the set.
147    #[must_use]
148    pub fn len(&self) -> usize {
149        self.invariants.len()
150    }
151
152    /// Is the set empty?
153    #[must_use]
154    pub fn is_empty(&self) -> bool {
155        self.invariants.is_empty()
156    }
157
158    /// Check all invariants, collecting every violation.
159    ///
160    /// Returns `Ok(())` if all invariants pass, or `Err` with the first
161    /// blocking violation found.
162    pub fn check_all(&self, context: &InvariantContext) -> Result<()> {
163        for inv in &self.invariants {
164            if let Err(violation) = inv.check(context) {
165                return Err(ExoError::InvariantViolation {
166                    description: format!(
167                        "[{}] {}: {}",
168                        violation.severity, violation.invariant_name, violation.description
169                    ),
170                });
171            }
172        }
173        Ok(())
174    }
175
176    /// Check all invariants and return all violations (does not short-circuit).
177    #[must_use]
178    pub fn check_all_collect(&self, context: &InvariantContext) -> Vec<InvariantViolation> {
179        let mut violations = Vec::new();
180        for inv in &self.invariants {
181            if let Err(v) = inv.check(context) {
182                violations.push(v);
183            }
184        }
185        violations
186    }
187}
188
189impl Default for InvariantSet {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195impl core::fmt::Debug for InvariantSet {
196    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
197        f.debug_struct("InvariantSet")
198            .field("count", &self.invariants.len())
199            .finish()
200    }
201}
202
203/// Convenience function: check all invariants in a set against a context.
204///
205/// # Errors
206///
207/// Returns `ExoError::InvariantViolation` on the first failure.
208pub fn check_all(invariants: &InvariantSet, context: &InvariantContext) -> Result<()> {
209    invariants.check_all(context)
210}
211
212// ===========================================================================
213// Tests
214// ===========================================================================
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Did, Hash256, Timestamp};

    // -- Test invariant implementations ------------------------------------

    /// An invariant that always passes.
    #[derive(Debug)]
    struct AlwaysPass;

    impl Invariant for AlwaysPass {
        fn name(&self) -> &str {
            "always_pass"
        }

        fn check(
            &self,
            _context: &InvariantContext,
        ) -> core::result::Result<(), InvariantViolation> {
            Ok(())
        }
    }

    /// An invariant that always fails with the configured severity.
    #[derive(Debug)]
    struct AlwaysFail {
        severity: ViolationSeverity,
    }

    impl Invariant for AlwaysFail {
        fn name(&self) -> &str {
            "always_fail"
        }

        fn check(
            &self,
            _context: &InvariantContext,
        ) -> core::result::Result<(), InvariantViolation> {
            Err(InvariantViolation {
                invariant_name: self.name().to_string(),
                description: "this always fails".to_string(),
                severity: self.severity,
                context: DeterministicMap::new(),
            })
        }
    }

    /// An invariant requiring context property `key` to equal `expected`.
    #[derive(Debug)]
    struct RequireProperty {
        key: String,
        expected: String,
    }

    impl Invariant for RequireProperty {
        fn name(&self) -> &str {
            "require_property"
        }

        fn check(
            &self,
            context: &InvariantContext,
        ) -> core::result::Result<(), InvariantViolation> {
            match context.get_property(&self.key) {
                Some(v) if v == &self.expected => Ok(()),
                Some(v) => {
                    let mut ctx = DeterministicMap::new();
                    ctx.insert("expected".to_string(), self.expected.clone());
                    ctx.insert("actual".to_string(), v.clone());
                    Err(InvariantViolation {
                        invariant_name: self.name().to_string(),
                        description: format!("property '{}' mismatch", self.key),
                        severity: ViolationSeverity::Error,
                        context: ctx,
                    })
                }
                None => Err(InvariantViolation {
                    invariant_name: self.name().to_string(),
                    description: format!("property '{}' missing", self.key),
                    severity: ViolationSeverity::Error,
                    context: DeterministicMap::new(),
                }),
            }
        }
    }

    /// Fixed context shared by the tests below.
    fn test_context() -> InvariantContext {
        InvariantContext::new(
            Did::new("did:exo:tester").expect("valid"),
            Timestamp::new(1000, 0),
            Hash256::ZERO,
        )
    }

    /// Extract the description carried by an `ExoError::InvariantViolation`,
    /// failing the test for any other error variant.
    fn violation_description(err: ExoError) -> String {
        let ExoError::InvariantViolation { description } = err else {
            panic!("expected ExoError::InvariantViolation");
        };
        description
    }

    // -- InvariantViolation ------------------------------------------------

    #[test]
    fn violation_serde_roundtrip() {
        let v = InvariantViolation {
            invariant_name: "test".into(),
            description: "something broke".into(),
            severity: ViolationSeverity::Critical,
            context: DeterministicMap::new(),
        };
        let json = serde_json::to_string(&v).expect("ser");
        let v2: InvariantViolation = serde_json::from_str(&json).expect("de");
        assert_eq!(v, v2);
    }

    #[test]
    fn violation_severity_display() {
        assert_eq!(ViolationSeverity::Warning.to_string(), "WARNING");
        assert_eq!(ViolationSeverity::Error.to_string(), "ERROR");
        assert_eq!(ViolationSeverity::Critical.to_string(), "CRITICAL");
    }

    #[test]
    fn violation_severity_ord() {
        assert!(ViolationSeverity::Warning < ViolationSeverity::Error);
        assert!(ViolationSeverity::Error < ViolationSeverity::Critical);
    }

    // -- InvariantContext ---------------------------------------------------

    #[test]
    fn context_new() {
        let ctx = test_context();
        assert_eq!(ctx.actor_did.as_str(), "did:exo:tester");
        assert_eq!(ctx.timestamp, Timestamp::new(1000, 0));
        assert_eq!(ctx.state_hash, Hash256::ZERO);
        assert!(ctx.properties.is_empty());
    }

    #[test]
    fn context_set_get_property() {
        let mut ctx = test_context();
        ctx.set_property("role", "admin");
        assert_eq!(ctx.get_property("role"), Some(&"admin".to_string()));
        assert_eq!(ctx.get_property("missing"), None);
    }

    // -- Invariant implementations -----------------------------------------

    #[test]
    fn always_pass_passes() {
        let inv = AlwaysPass;
        assert_eq!(inv.name(), "always_pass");
        let ctx = test_context();
        assert!(inv.check(&ctx).is_ok());
    }

    #[test]
    fn always_fail_fails() {
        let inv = AlwaysFail {
            severity: ViolationSeverity::Error,
        };
        let ctx = test_context();
        let err = inv.check(&ctx).unwrap_err();
        assert_eq!(err.invariant_name, "always_fail");
        assert_eq!(err.severity, ViolationSeverity::Error);
    }

    #[test]
    fn require_property_pass() {
        let inv = RequireProperty {
            key: "mode".into(),
            expected: "production".into(),
        };
        let mut ctx = test_context();
        ctx.set_property("mode", "production");
        assert!(inv.check(&ctx).is_ok());
    }

    #[test]
    fn require_property_mismatch() {
        let inv = RequireProperty {
            key: "mode".into(),
            expected: "production".into(),
        };
        let mut ctx = test_context();
        ctx.set_property("mode", "debug");
        let err = inv.check(&ctx).unwrap_err();
        assert!(err.description.contains("mismatch"));
        assert!(err.context.contains_key(&"expected".to_string()));
        assert!(err.context.contains_key(&"actual".to_string()));
    }

    #[test]
    fn require_property_missing() {
        let inv = RequireProperty {
            key: "mode".into(),
            expected: "production".into(),
        };
        let ctx = test_context();
        let err = inv.check(&ctx).unwrap_err();
        assert!(err.description.contains("missing"));
    }

    // -- InvariantSet ------------------------------------------------------

    #[test]
    fn empty_set_passes() {
        let set = InvariantSet::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        let ctx = test_context();
        assert!(set.check_all(&ctx).is_ok());
        assert!(set.check_all_collect(&ctx).is_empty());
    }

    #[test]
    fn set_all_pass() {
        let mut set = InvariantSet::new();
        set.add(AlwaysPass);
        set.add(AlwaysPass);
        assert_eq!(set.len(), 2);
        assert!(!set.is_empty());
        let ctx = test_context();
        assert!(set.check_all(&ctx).is_ok());
        assert!(set.check_all_collect(&ctx).is_empty());
    }

    #[test]
    fn set_one_fails() {
        let mut set = InvariantSet::new();
        set.add(AlwaysPass);
        set.add(AlwaysFail {
            severity: ViolationSeverity::Critical,
        });
        set.add(AlwaysPass);
        let ctx = test_context();
        let err = set.check_all(&ctx).unwrap_err();
        assert!(matches!(err, ExoError::InvariantViolation { .. }));
    }

    #[test]
    fn set_error_message_names_violation() {
        let mut set = InvariantSet::new();
        set.add(AlwaysFail {
            severity: ViolationSeverity::Critical,
        });
        let err = set.check_all(&test_context()).unwrap_err();
        assert_eq!(
            violation_description(err),
            "[CRITICAL] always_fail: this always fails"
        );
    }

    #[test]
    fn set_reports_first_blocking_violation() {
        let mut set = InvariantSet::new();
        set.add(AlwaysFail {
            severity: ViolationSeverity::Error,
        });
        set.add(AlwaysFail {
            severity: ViolationSeverity::Critical,
        });
        let err = set.check_all(&test_context()).unwrap_err();
        assert!(violation_description(err).starts_with("[ERROR]"));
    }

    #[test]
    fn set_collect_all_violations() {
        let mut set = InvariantSet::new();
        set.add(AlwaysFail {
            severity: ViolationSeverity::Warning,
        });
        set.add(AlwaysPass);
        set.add(AlwaysFail {
            severity: ViolationSeverity::Critical,
        });
        let ctx = test_context();
        let violations = set.check_all_collect(&ctx);
        assert_eq!(violations.len(), 2);
        assert_eq!(violations[0].severity, ViolationSeverity::Warning);
        assert_eq!(violations[1].severity, ViolationSeverity::Critical);
    }

    #[test]
    fn check_all_function() {
        let mut set = InvariantSet::new();
        set.add(AlwaysPass);
        let ctx = test_context();
        assert!(check_all(&set, &ctx).is_ok());

        let mut failing = InvariantSet::new();
        failing.add(AlwaysFail {
            severity: ViolationSeverity::Error,
        });
        let err = check_all(&failing, &ctx).unwrap_err();
        assert!(matches!(err, ExoError::InvariantViolation { .. }));
    }

    #[test]
    fn set_default() {
        let set = InvariantSet::default();
        assert!(set.is_empty());
    }

    #[test]
    fn set_debug() {
        let mut set = InvariantSet::new();
        set.add(AlwaysPass);
        let dbg = format!("{set:?}");
        // Pin the actual field rather than any stray "1" in the output.
        assert_eq!(dbg, "InvariantSet { count: 1 }");
    }

    #[test]
    fn set_with_property_check() {
        let mut set = InvariantSet::new();
        set.add(RequireProperty {
            key: "consent".into(),
            expected: "granted".into(),
        });

        // Fails without property
        let ctx = test_context();
        assert!(set.check_all(&ctx).is_err());

        // Passes with correct property
        let mut ctx2 = test_context();
        ctx2.set_property("consent", "granted");
        assert!(set.check_all(&ctx2).is_ok());
    }

    #[test]
    fn violation_context_is_deterministic() {
        let inv = RequireProperty {
            key: "x".into(),
            expected: "y".into(),
        };
        let mut ctx = test_context();
        ctx.set_property("x", "wrong");
        let v1 = inv.check(&ctx).unwrap_err();
        let v2 = inv.check(&ctx).unwrap_err();
        assert_eq!(v1, v2);
    }
}