mpl_core/
assertions.rs

1//! CEL-based Assertion System
2//!
3//! Provides assertion definitions and evaluation using the Common Expression Language (CEL).
4//! Assertions are used to compute Instruction Compliance (IC) metrics.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use mpl_core::assertions::{Assertion, AssertionSet, AssertionEvaluator};
10//! use serde_json::json;
11//!
12//! let assertions = AssertionSet::new(vec![
13//!     Assertion::new("amount_positive", "payload.amount > 0", "Amount must be positive"),
14//!     Assertion::new("currency_valid", "payload.currency in ['USD', 'EUR', 'GBP']", "Invalid currency"),
15//! ]);
16//!
17//! let payload = json!({"amount": 100, "currency": "USD"});
18//! let result = assertions.evaluate(&payload)?;
19//! assert!(result.passed());
20//! ```
21
22use cel_interpreter::{Context, Program, Value};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use thiserror::Error;
26use tracing::debug;
27
28/// Assertion evaluation error
29#[derive(Debug, Error)]
30pub enum AssertionError {
31    #[error("Failed to compile CEL expression '{expr}': {message}")]
32    CompilationError { expr: String, message: String },
33
34    #[error("Failed to evaluate CEL expression '{expr}': {message}")]
35    EvaluationError { expr: String, message: String },
36
37    #[error("Invalid payload: {0}")]
38    InvalidPayload(String),
39}
40
41/// A single assertion with a CEL expression
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct Assertion {
44    /// Unique identifier for this assertion
45    pub id: String,
46
47    /// CEL expression to evaluate
48    /// Available variables: payload, metadata, context
49    pub expression: String,
50
51    /// Human-readable message when assertion fails
52    pub message: String,
53
54    /// Severity: error (blocks), warning (logs), info (metrics only)
55    #[serde(default = "default_severity")]
56    pub severity: AssertionSeverity,
57
58    /// Optional tags for categorization
59    #[serde(default)]
60    pub tags: Vec<String>,
61}
62
63fn default_severity() -> AssertionSeverity {
64    AssertionSeverity::Error
65}
66
67/// Assertion severity levels
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
69#[serde(rename_all = "lowercase")]
70pub enum AssertionSeverity {
71    /// Blocks the request if assertion fails
72    #[default]
73    Error,
74    /// Logs a warning but allows the request
75    Warning,
76    /// Only affects metrics, no blocking or logging
77    Info,
78}
79
80impl Assertion {
81    /// Create a new assertion
82    pub fn new(id: impl Into<String>, expression: impl Into<String>, message: impl Into<String>) -> Self {
83        Self {
84            id: id.into(),
85            expression: expression.into(),
86            message: message.into(),
87            severity: AssertionSeverity::Error,
88            tags: Vec::new(),
89        }
90    }
91
92    /// Set severity
93    pub fn with_severity(mut self, severity: AssertionSeverity) -> Self {
94        self.severity = severity;
95        self
96    }
97
98    /// Add tags
99    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
100        self.tags = tags;
101        self
102    }
103}
104
105/// Result of evaluating a single assertion
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct AssertionResult {
108    /// Assertion ID
109    pub id: String,
110
111    /// Whether the assertion passed
112    pub passed: bool,
113
114    /// The assertion message (shown on failure)
115    pub message: String,
116
117    /// Severity of this assertion
118    pub severity: AssertionSeverity,
119
120    /// Actual value returned by the expression (for debugging)
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub actual_value: Option<String>,
123
124    /// Error message if evaluation failed
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub error: Option<String>,
127}
128
129/// A set of assertions to evaluate together
130#[derive(Debug, Clone, Serialize, Deserialize, Default)]
131pub struct AssertionSet {
132    /// Name of this assertion set
133    #[serde(default)]
134    pub name: String,
135
136    /// Description
137    #[serde(default)]
138    pub description: String,
139
140    /// The assertions in this set
141    pub assertions: Vec<Assertion>,
142
143    /// Whether to stop on first error-severity failure
144    #[serde(default)]
145    pub fail_fast: bool,
146}
147
148impl AssertionSet {
149    /// Create a new assertion set
150    pub fn new(assertions: Vec<Assertion>) -> Self {
151        Self {
152            name: String::new(),
153            description: String::new(),
154            assertions,
155            fail_fast: false,
156        }
157    }
158
159    /// Create with name
160    pub fn with_name(mut self, name: impl Into<String>) -> Self {
161        self.name = name.into();
162        self
163    }
164
165    /// Add an assertion
166    pub fn add(&mut self, assertion: Assertion) {
167        self.assertions.push(assertion);
168    }
169
170    /// Evaluate all assertions against a payload
171    pub fn evaluate(&self, payload: &serde_json::Value) -> Result<AssertionSetResult, AssertionError> {
172        let evaluator = AssertionEvaluator::new();
173        evaluator.evaluate_set(self, payload, None)
174    }
175
176    /// Evaluate with additional context
177    pub fn evaluate_with_context(
178        &self,
179        payload: &serde_json::Value,
180        context: &EvaluationContext,
181    ) -> Result<AssertionSetResult, AssertionError> {
182        let evaluator = AssertionEvaluator::new();
183        evaluator.evaluate_set(self, payload, Some(context))
184    }
185}
186
187/// Additional context for assertion evaluation
188#[derive(Debug, Clone, Default, Serialize, Deserialize)]
189pub struct EvaluationContext {
190    /// Metadata from the request
191    #[serde(default)]
192    pub metadata: HashMap<String, serde_json::Value>,
193
194    /// SType being evaluated
195    #[serde(default)]
196    pub stype: Option<String>,
197
198    /// Tool name (if tool call)
199    #[serde(default)]
200    pub tool_name: Option<String>,
201
202    /// Request arguments (for IC)
203    #[serde(default)]
204    pub arguments: Option<serde_json::Value>,
205
206    /// Response data (for TOC)
207    #[serde(default)]
208    pub response: Option<serde_json::Value>,
209}
210
211/// Result of evaluating an assertion set
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct AssertionSetResult {
214    /// Individual assertion results
215    pub results: Vec<AssertionResult>,
216
217    /// Number of assertions that passed
218    pub passed_count: usize,
219
220    /// Number of assertions that failed
221    pub failed_count: usize,
222
223    /// Number of error-severity failures
224    pub error_count: usize,
225
226    /// Number of warning-severity failures
227    pub warning_count: usize,
228
229    /// Computed IC score (0.0 - 1.0)
230    pub ic_score: f64,
231}
232
233impl AssertionSetResult {
234    /// Check if all assertions passed
235    pub fn passed(&self) -> bool {
236        self.error_count == 0
237    }
238
239    /// Check if any error-severity assertions failed
240    pub fn has_errors(&self) -> bool {
241        self.error_count > 0
242    }
243
244    /// Get failed assertion messages
245    pub fn failure_messages(&self) -> Vec<&str> {
246        self.results
247            .iter()
248            .filter(|r| !r.passed && r.severity == AssertionSeverity::Error)
249            .map(|r| r.message.as_str())
250            .collect()
251    }
252}
253
254/// CEL-based assertion evaluator
255pub struct AssertionEvaluator {
256    // Note: Program doesn't implement Clone, so we compile on each evaluation
257    // This is fast enough for typical use cases
258    _marker: std::marker::PhantomData<()>,
259}
260
261impl Default for AssertionEvaluator {
262    fn default() -> Self {
263        Self::new()
264    }
265}
266
267impl AssertionEvaluator {
268    /// Create a new evaluator
269    pub fn new() -> Self {
270        Self {
271            _marker: std::marker::PhantomData,
272        }
273    }
274
275    /// Evaluate an assertion set
276    pub fn evaluate_set(
277        &self,
278        set: &AssertionSet,
279        payload: &serde_json::Value,
280        context: Option<&EvaluationContext>,
281    ) -> Result<AssertionSetResult, AssertionError> {
282        let mut results = Vec::with_capacity(set.assertions.len());
283        let mut passed_count = 0;
284        let mut failed_count = 0;
285        let mut error_count = 0;
286        let mut warning_count = 0;
287
288        for assertion in &set.assertions {
289            let result = self.evaluate_single(assertion, payload, context);
290
291            match &result {
292                Ok(r) => {
293                    if r.passed {
294                        passed_count += 1;
295                    } else {
296                        failed_count += 1;
297                        match r.severity {
298                            AssertionSeverity::Error => error_count += 1,
299                            AssertionSeverity::Warning => warning_count += 1,
300                            AssertionSeverity::Info => {}
301                        }
302                    }
303                    results.push(r.clone());
304
305                    // Fail fast on error-severity failures
306                    if set.fail_fast && !r.passed && r.severity == AssertionSeverity::Error {
307                        break;
308                    }
309                }
310                Err(e) => {
311                    // Evaluation error counts as failure
312                    failed_count += 1;
313                    error_count += 1;
314                    results.push(AssertionResult {
315                        id: assertion.id.clone(),
316                        passed: false,
317                        message: assertion.message.clone(),
318                        severity: assertion.severity,
319                        actual_value: None,
320                        error: Some(e.to_string()),
321                    });
322                }
323            }
324        }
325
326        // Compute IC score
327        let total = set.assertions.len();
328        let ic_score = if total == 0 {
329            1.0
330        } else {
331            passed_count as f64 / total as f64
332        };
333
334        Ok(AssertionSetResult {
335            results,
336            passed_count,
337            failed_count,
338            error_count,
339            warning_count,
340            ic_score,
341        })
342    }
343
344    /// Evaluate a single assertion
345    pub fn evaluate_single(
346        &self,
347        assertion: &Assertion,
348        payload: &serde_json::Value,
349        context: Option<&EvaluationContext>,
350    ) -> Result<AssertionResult, AssertionError> {
351        // Compile the CEL expression
352        let program = Program::compile(&assertion.expression).map_err(|e| {
353            AssertionError::CompilationError {
354                expr: assertion.expression.clone(),
355                message: format!("{:?}", e),
356            }
357        })?;
358
359        // Build CEL context
360        let mut cel_context = Context::default();
361
362        // Add payload as a variable
363        let payload_value = json_to_cel(payload);
364        cel_context.add_variable("payload", payload_value).ok();
365
366        // Add context variables if provided
367        if let Some(ctx) = context {
368            if let Some(args) = &ctx.arguments {
369                cel_context.add_variable("args", json_to_cel(args)).ok();
370            }
371            if let Some(resp) = &ctx.response {
372                cel_context.add_variable("response", json_to_cel(resp)).ok();
373            }
374            if let Some(stype) = &ctx.stype {
375                cel_context.add_variable("stype", stype.clone()).ok();
376            }
377            if let Some(tool) = &ctx.tool_name {
378                cel_context.add_variable("tool", tool.clone()).ok();
379            }
380
381            // Add metadata
382            let meta_value = json_to_cel(&serde_json::to_value(&ctx.metadata).unwrap_or_default());
383            cel_context.add_variable("metadata", meta_value).ok();
384        }
385
386        // Execute the program
387        let result = program.execute(&cel_context).map_err(|e| {
388            AssertionError::EvaluationError {
389                expr: assertion.expression.clone(),
390                message: format!("{:?}", e),
391            }
392        })?;
393
394        // Check if result is truthy
395        let passed = match &result {
396            Value::Bool(b) => *b,
397            Value::Null => false,
398            Value::Int(i) => *i != 0,
399            Value::UInt(u) => *u != 0,
400            Value::Float(f) => *f != 0.0,
401            Value::String(s) => !s.is_empty(),
402            Value::List(l) => !l.is_empty(),
403            Value::Map(m) => !m.map.is_empty(),
404            _ => true, // Other types are considered truthy
405        };
406
407        debug!(
408            assertion_id = %assertion.id,
409            passed = passed,
410            "Assertion evaluated"
411        );
412
413        Ok(AssertionResult {
414            id: assertion.id.clone(),
415            passed,
416            message: assertion.message.clone(),
417            severity: assertion.severity,
418            actual_value: Some(format!("{:?}", result)),
419            error: None,
420        })
421    }
422}
423
424/// Convert JSON value to CEL value
425fn json_to_cel(value: &serde_json::Value) -> Value {
426    match value {
427        serde_json::Value::Null => Value::Null,
428        serde_json::Value::Bool(b) => Value::Bool(*b),
429        serde_json::Value::Number(n) => {
430            if let Some(i) = n.as_i64() {
431                Value::Int(i)
432            } else if let Some(u) = n.as_u64() {
433                Value::UInt(u)
434            } else if let Some(f) = n.as_f64() {
435                Value::Float(f)
436            } else {
437                Value::Null
438            }
439        }
440        serde_json::Value::String(s) => Value::String(s.clone().into()),
441        serde_json::Value::Array(arr) => {
442            Value::List(arr.iter().map(json_to_cel).collect::<Vec<_>>().into())
443        }
444        serde_json::Value::Object(obj) => {
445            let map: HashMap<String, Value> = obj
446                .iter()
447                .map(|(k, v)| (k.clone(), json_to_cel(v)))
448                .collect();
449            // CEL Map requires Key type, use String keys
450            let cel_map: HashMap<cel_interpreter::objects::Key, Value> = map
451                .into_iter()
452                .map(|(k, v)| (cel_interpreter::objects::Key::String(k.into()), v))
453                .collect();
454            Value::Map(cel_interpreter::objects::Map { map: cel_map.into() })
455        }
456    }
457}
458
459/// Load assertions from a JSON file
460pub fn load_assertions_from_json(json: &str) -> Result<AssertionSet, serde_json::Error> {
461    serde_json::from_str(json)
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Evaluates `assertion` against `payload` (with optional context) and
    /// returns whether it passed, panicking on evaluation errors.
    fn passes(
        assertion: &Assertion,
        payload: serde_json::Value,
        context: Option<&EvaluationContext>,
    ) -> bool {
        AssertionEvaluator::new()
            .evaluate_single(assertion, &payload, context)
            .unwrap()
            .passed
    }

    #[test]
    fn test_simple_assertion() {
        let check = Assertion::new(
            "amount_positive",
            "payload.amount > 0",
            "Amount must be positive",
        );

        // Positive amount satisfies the check.
        assert!(passes(&check, json!({"amount": 100}), None));

        // Negative amount violates it.
        assert!(!passes(&check, json!({"amount": -50}), None));
    }

    #[test]
    fn test_string_assertion() {
        let check = Assertion::new(
            "currency_valid",
            "payload.currency in ['USD', 'EUR', 'GBP']",
            "Invalid currency",
        );

        assert!(passes(&check, json!({"currency": "USD"}), None));
        assert!(!passes(&check, json!({"currency": "XYZ"}), None));
    }

    #[test]
    fn test_assertion_set() {
        let checks = AssertionSet::new(vec![
            Assertion::new("a1", "payload.x > 0", "X must be positive"),
            Assertion::new("a2", "payload.y < 100", "Y must be less than 100"),
        ]);

        // Both assertions hold.
        let outcome = checks.evaluate(&json!({"x": 10, "y": 50})).unwrap();
        assert!(outcome.passed());
        assert_eq!(outcome.ic_score, 1.0);

        // Only the second assertion holds.
        let outcome = checks.evaluate(&json!({"x": -5, "y": 50})).unwrap();
        assert!(!outcome.passed());
        assert_eq!(outcome.ic_score, 0.5);
    }

    #[test]
    fn test_nested_payload() {
        let check = Assertion::new(
            "nested_check",
            "payload.user.age >= 18",
            "User must be 18+",
        );

        let payload = json!({"user": {"name": "Alice", "age": 25}});
        assert!(passes(&check, payload, None));
    }

    #[test]
    fn test_array_operations() {
        let check = Assertion::new(
            "has_items",
            "size(payload.items) > 0",
            "Items cannot be empty",
        );

        assert!(passes(&check, json!({"items": [1, 2, 3]}), None));
        assert!(!passes(&check, json!({"items": []}), None));
    }

    #[test]
    fn test_context_variables() {
        let check = Assertion::new(
            "tool_check",
            "tool == 'calendar.create'",
            "Only calendar.create allowed",
        );

        let ctx = EvaluationContext {
            tool_name: Some("calendar.create".to_string()),
            ..Default::default()
        };

        assert!(passes(&check, json!({}), Some(&ctx)));
    }

    #[test]
    fn test_severity_levels() {
        let always_failing = |id: &str, message: &str, severity: AssertionSeverity| {
            Assertion::new(id, "false", message).with_severity(severity)
        };
        let checks = AssertionSet::new(vec![
            always_failing("error_check", "Error level", AssertionSeverity::Error),
            always_failing("warn_check", "Warning level", AssertionSeverity::Warning),
            always_failing("info_check", "Info level", AssertionSeverity::Info),
        ]);

        let outcome = checks.evaluate(&json!({})).unwrap();
        assert_eq!(outcome.error_count, 1);
        assert_eq!(outcome.warning_count, 1);
        assert_eq!(outcome.failed_count, 3);
        assert!(!outcome.passed()); // Has errors
    }

    #[test]
    fn test_load_from_json() {
        let source = r#"{
            "name": "finance_checks",
            "description": "Financial payload validations",
            "assertions": [
                {
                    "id": "amount_check",
                    "expression": "payload.amount > 0",
                    "message": "Amount must be positive"
                },
                {
                    "id": "currency_check",
                    "expression": "payload.currency in ['USD', 'EUR']",
                    "message": "Invalid currency",
                    "severity": "warning"
                }
            ]
        }"#;

        let loaded = load_assertions_from_json(source).unwrap();
        assert_eq!(loaded.name, "finance_checks");
        assert_eq!(loaded.assertions.len(), 2);
        assert_eq!(loaded.assertions[1].severity, AssertionSeverity::Warning);
    }
}