converge_core/
invariant.rs

1// Copyright 2024-2025 Aprio One AB, Sweden
2// Author: Kenneth Pernyer, kenneth@aprio.one
3// SPDX-License-Identifier: LicenseRef-Proprietary
4// All rights reserved. This source code is proprietary and confidential.
5// Unauthorized copying, modification, or distribution is strictly prohibited.
6
//! Invariant system for Converge.
//!
//! Invariants are runtime constraints that the engine enforces.
//! They are compiled from Gherkin specs into Rust predicates over the context.
//!
//! # Invariant Classes
//!
//! - **Structural**: checked on every merge. A violation is an immediate failure.
//! - **Semantic**: checked at the end of each cycle. A violation blocks convergence.
//! - **Acceptance**: checked when convergence is claimed. A violation rejects the results.
//!
//! # Example
//!
//! ```
//! use converge_core::invariant::{Invariant, InvariantClass, InvariantResult};
//! use converge_core::Context;
//!
//! struct NoEmptyFacts;
//!
//! impl Invariant for NoEmptyFacts {
//!     fn name(&self) -> &str { "no_empty_facts" }
//!     fn class(&self) -> InvariantClass { InvariantClass::Structural }
//!
//!     fn check(&self, ctx: &Context) -> InvariantResult {
//!         // Inspect `ctx` here and return `InvariantResult::Violated(..)` on failure.
//!         InvariantResult::Ok
//!     }
//! }
//! ```
36
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39
40use crate::context::Context;
41
42/// The class of an invariant determines when it's checked and how violations are handled.
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
44pub enum InvariantClass {
45    /// Checked on every merge. Violation causes immediate failure.
46    /// Examples: schema validity, type correctness, forbidden combinations.
47    Structural,
48
49    /// Checked at the end of each cycle. Violation blocks convergence.
50    /// Examples: "no strategy violates brand safety".
51    Semantic,
52
53    /// Checked when convergence is claimed. Violation rejects results.
54    /// Examples: "at least two viable strategies exist".
55    Acceptance,
56}
57
58/// The result of checking an invariant.
59#[derive(Debug, Clone, PartialEq)]
60pub enum InvariantResult {
61    /// Invariant holds.
62    Ok,
63    /// Invariant is violated.
64    Violated(Violation),
65}
66
67impl InvariantResult {
68    /// Returns true if the invariant holds.
69    #[must_use]
70    pub fn is_ok(&self) -> bool {
71        matches!(self, Self::Ok)
72    }
73
74    /// Returns true if the invariant is violated.
75    #[must_use]
76    pub fn is_violated(&self) -> bool {
77        matches!(self, Self::Violated(_))
78    }
79}
80
81/// Details of an invariant violation.
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83pub struct Violation {
84    /// Human-readable description of what went wrong.
85    pub reason: String,
86    /// Optional: which facts contributed to the violation.
87    pub fact_ids: Vec<String>,
88}
89
90impl Violation {
91    /// Creates a new violation with just a reason.
92    #[must_use]
93    pub fn new(reason: impl Into<String>) -> Self {
94        Self {
95            reason: reason.into(),
96            fact_ids: Vec::new(),
97        }
98    }
99
100    /// Creates a violation with associated fact IDs.
101    #[must_use]
102    pub fn with_facts(reason: impl Into<String>, fact_ids: Vec<String>) -> Self {
103        Self {
104            reason: reason.into(),
105            fact_ids,
106        }
107    }
108}
109
110/// A runtime invariant that the engine enforces.
111///
112/// Invariants are the "law" that Gherkin specs compile to.
113pub trait Invariant: Send + Sync {
114    /// Human-readable name for tracing and error messages.
115    fn name(&self) -> &str;
116
117    /// The class determines when this invariant is checked.
118    fn class(&self) -> InvariantClass;
119
120    /// Check the invariant against the current context.
121    fn check(&self, ctx: &Context) -> InvariantResult;
122}
123
124/// Unique identifier for a registered invariant.
125#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
126pub struct InvariantId(pub(crate) u32);
127
128impl std::fmt::Display for InvariantId {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        write!(f, "Invariant({})", self.0)
131    }
132}
133
134/// Registry of invariants, organized by class for efficient checking.
135#[derive(Default)]
136pub struct InvariantRegistry {
137    invariants: Vec<Box<dyn Invariant>>,
138    by_class: HashMap<InvariantClass, Vec<InvariantId>>,
139    next_id: u32,
140}
141
142impl InvariantRegistry {
143    /// Creates an empty registry.
144    #[must_use]
145    pub fn new() -> Self {
146        Self::default()
147    }
148
149    /// Registers an invariant and returns its ID.
150    pub fn register(&mut self, invariant: impl Invariant + 'static) -> InvariantId {
151        let id = InvariantId(self.next_id);
152        self.next_id += 1;
153
154        let class = invariant.class();
155        self.by_class.entry(class).or_default().push(id);
156        self.invariants.push(Box::new(invariant));
157
158        id
159    }
160
161    /// Returns the number of registered invariants.
162    #[must_use]
163    pub fn count(&self) -> usize {
164        self.invariants.len()
165    }
166
167    /// Checks all invariants of a given class.
168    ///
169    /// Returns the first violation found, or Ok if all pass.
170    ///
171    /// # Errors
172    ///
173    /// Returns `InvariantError` if any invariant of the given class is violated.
174    pub fn check_class(&self, class: InvariantClass, ctx: &Context) -> Result<(), InvariantError> {
175        let ids = self.by_class.get(&class).map_or(&[][..], Vec::as_slice);
176
177        for &id in ids {
178            let invariant = &self.invariants[id.0 as usize];
179            if let InvariantResult::Violated(violation) = invariant.check(ctx) {
180                return Err(InvariantError {
181                    invariant_name: invariant.name().to_string(),
182                    class,
183                    violation,
184                });
185            }
186        }
187
188        Ok(())
189    }
190
191    /// Checks all structural invariants.
192    ///
193    /// # Errors
194    ///
195    /// Returns `InvariantError` if any structural invariant is violated.
196    pub fn check_structural(&self, ctx: &Context) -> Result<(), InvariantError> {
197        self.check_class(InvariantClass::Structural, ctx)
198    }
199
200    /// Checks all semantic invariants.
201    ///
202    /// # Errors
203    ///
204    /// Returns `InvariantError` if any semantic invariant is violated.
205    pub fn check_semantic(&self, ctx: &Context) -> Result<(), InvariantError> {
206        self.check_class(InvariantClass::Semantic, ctx)
207    }
208
209    /// Checks all acceptance invariants.
210    ///
211    /// # Errors
212    ///
213    /// Returns `InvariantError` if any acceptance invariant is violated.
214    pub fn check_acceptance(&self, ctx: &Context) -> Result<(), InvariantError> {
215        self.check_class(InvariantClass::Acceptance, ctx)
216    }
217}
218
219/// Error returned when an invariant is violated.
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct InvariantError {
222    /// Name of the invariant that was violated.
223    pub invariant_name: String,
224    /// Class of the invariant.
225    pub class: InvariantClass,
226    /// Details of the violation.
227    pub violation: Violation,
228}
229
230impl std::fmt::Display for InvariantError {
231    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232        write!(
233            f,
234            "{:?} invariant '{}' violated: {}",
235            self.class, self.invariant_name, self.violation.reason
236        )
237    }
238}
239
240impl std::error::Error for InvariantError {}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{ContextKey, Fact};

    /// Acceptance invariant: the context must hold at least one seed fact.
    struct RequireSeeds;

    impl Invariant for RequireSeeds {
        fn name(&self) -> &'static str {
            "require_seeds"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Acceptance
        }

        fn check(&self, ctx: &Context) -> InvariantResult {
            if !ctx.has(ContextKey::Seeds) {
                return InvariantResult::Violated(Violation::new("no seeds present"));
            }
            InvariantResult::Ok
        }
    }

    /// Structural invariant: no fact under the scanned keys may have blank content.
    struct NoEmptyContent;

    impl Invariant for NoEmptyContent {
        fn name(&self) -> &'static str {
            "no_empty_content"
        }

        fn class(&self) -> InvariantClass {
            InvariantClass::Structural
        }

        fn check(&self, ctx: &Context) -> InvariantResult {
            let scanned = [
                ContextKey::Seeds,
                ContextKey::Hypotheses,
                ContextKey::Strategies,
                ContextKey::Competitors,
                ContextKey::Evaluations,
            ];
            for key in scanned {
                let blank = ctx
                    .get(key)
                    .into_iter()
                    .find(|fact| fact.content.trim().is_empty());
                if let Some(fact) = blank {
                    return InvariantResult::Violated(Violation::with_facts(
                        "empty content not allowed",
                        vec![fact.id.clone()],
                    ));
                }
            }
            InvariantResult::Ok
        }
    }

    /// Builds a context containing one seed fact with the given id and content.
    fn ctx_with_seed(id: &'static str, content: &'static str) -> Context {
        let mut ctx = Context::new();
        let _ = ctx.add_fact(Fact {
            key: ContextKey::Seeds,
            id: id.into(),
            content: content.into(),
        });
        ctx
    }

    #[test]
    fn registry_registers_invariants() {
        let mut registry = InvariantRegistry::new();
        let first = registry.register(RequireSeeds);
        let second = registry.register(NoEmptyContent);

        assert_ne!(first, second);
        assert_eq!(registry.count(), 2);
    }

    #[test]
    fn acceptance_invariant_passes_with_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let ctx = ctx_with_seed("s1", "value");
        assert!(registry.check_acceptance(&ctx).is_ok());
    }

    #[test]
    fn acceptance_invariant_fails_without_seeds() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds);

        let err = registry
            .check_acceptance(&Context::new())
            .expect_err("acceptance must fail when no seeds are present");
        assert_eq!(err.invariant_name, "require_seeds");
        assert_eq!(err.class, InvariantClass::Acceptance);
    }

    #[test]
    fn structural_invariant_catches_empty_content() {
        let mut registry = InvariantRegistry::new();
        registry.register(NoEmptyContent);

        // Whitespace only, so blank once trimmed.
        let ctx = ctx_with_seed("bad", "   ");
        let err = registry
            .check_structural(&ctx)
            .expect_err("blank content must violate the structural invariant");
        assert!(err.violation.fact_ids.iter().any(|id| id == "bad"));
    }

    #[test]
    fn different_classes_checked_independently() {
        let mut registry = InvariantRegistry::new();
        registry.register(RequireSeeds); // acceptance
        registry.register(NoEmptyContent); // structural

        let empty = Context::new();

        // No facts at all, so the structural check has nothing to reject...
        assert!(registry.check_structural(&empty).is_ok());
        // ...while the acceptance check still demands seeds.
        assert!(registry.check_acceptance(&empty).is_err());
    }
}