1use cel_interpreter::{Context, Program, Value};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use thiserror::Error;
26use tracing::debug;
27
28#[derive(Debug, Error)]
30pub enum AssertionError {
31 #[error("Failed to compile CEL expression '{expr}': {message}")]
32 CompilationError { expr: String, message: String },
33
34 #[error("Failed to evaluate CEL expression '{expr}': {message}")]
35 EvaluationError { expr: String, message: String },
36
37 #[error("Invalid payload: {0}")]
38 InvalidPayload(String),
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct Assertion {
44 pub id: String,
46
47 pub expression: String,
50
51 pub message: String,
53
54 #[serde(default = "default_severity")]
56 pub severity: AssertionSeverity,
57
58 #[serde(default)]
60 pub tags: Vec<String>,
61}
62
63fn default_severity() -> AssertionSeverity {
64 AssertionSeverity::Error
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
69#[serde(rename_all = "lowercase")]
70pub enum AssertionSeverity {
71 #[default]
73 Error,
74 Warning,
76 Info,
78}
79
80impl Assertion {
81 pub fn new(id: impl Into<String>, expression: impl Into<String>, message: impl Into<String>) -> Self {
83 Self {
84 id: id.into(),
85 expression: expression.into(),
86 message: message.into(),
87 severity: AssertionSeverity::Error,
88 tags: Vec::new(),
89 }
90 }
91
92 pub fn with_severity(mut self, severity: AssertionSeverity) -> Self {
94 self.severity = severity;
95 self
96 }
97
98 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
100 self.tags = tags;
101 self
102 }
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct AssertionResult {
108 pub id: String,
110
111 pub passed: bool,
113
114 pub message: String,
116
117 pub severity: AssertionSeverity,
119
120 #[serde(skip_serializing_if = "Option::is_none")]
122 pub actual_value: Option<String>,
123
124 #[serde(skip_serializing_if = "Option::is_none")]
126 pub error: Option<String>,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize, Default)]
131pub struct AssertionSet {
132 #[serde(default)]
134 pub name: String,
135
136 #[serde(default)]
138 pub description: String,
139
140 pub assertions: Vec<Assertion>,
142
143 #[serde(default)]
145 pub fail_fast: bool,
146}
147
148impl AssertionSet {
149 pub fn new(assertions: Vec<Assertion>) -> Self {
151 Self {
152 name: String::new(),
153 description: String::new(),
154 assertions,
155 fail_fast: false,
156 }
157 }
158
159 pub fn with_name(mut self, name: impl Into<String>) -> Self {
161 self.name = name.into();
162 self
163 }
164
165 pub fn add(&mut self, assertion: Assertion) {
167 self.assertions.push(assertion);
168 }
169
170 pub fn evaluate(&self, payload: &serde_json::Value) -> Result<AssertionSetResult, AssertionError> {
172 let evaluator = AssertionEvaluator::new();
173 evaluator.evaluate_set(self, payload, None)
174 }
175
176 pub fn evaluate_with_context(
178 &self,
179 payload: &serde_json::Value,
180 context: &EvaluationContext,
181 ) -> Result<AssertionSetResult, AssertionError> {
182 let evaluator = AssertionEvaluator::new();
183 evaluator.evaluate_set(self, payload, Some(context))
184 }
185}
186
187#[derive(Debug, Clone, Default, Serialize, Deserialize)]
189pub struct EvaluationContext {
190 #[serde(default)]
192 pub metadata: HashMap<String, serde_json::Value>,
193
194 #[serde(default)]
196 pub stype: Option<String>,
197
198 #[serde(default)]
200 pub tool_name: Option<String>,
201
202 #[serde(default)]
204 pub arguments: Option<serde_json::Value>,
205
206 #[serde(default)]
208 pub response: Option<serde_json::Value>,
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct AssertionSetResult {
214 pub results: Vec<AssertionResult>,
216
217 pub passed_count: usize,
219
220 pub failed_count: usize,
222
223 pub error_count: usize,
225
226 pub warning_count: usize,
228
229 pub ic_score: f64,
231}
232
233impl AssertionSetResult {
234 pub fn passed(&self) -> bool {
236 self.error_count == 0
237 }
238
239 pub fn has_errors(&self) -> bool {
241 self.error_count > 0
242 }
243
244 pub fn failure_messages(&self) -> Vec<&str> {
246 self.results
247 .iter()
248 .filter(|r| !r.passed && r.severity == AssertionSeverity::Error)
249 .map(|r| r.message.as_str())
250 .collect()
251 }
252}
253
/// Compiles and runs assertion expressions.
///
/// The struct carries no data. Its only field is private and zero-sized, which stops code
/// outside this module from building it with a struct literal; construction goes through
/// `new()` or `Default`.
pub struct AssertionEvaluator {
    _marker: core::marker::PhantomData<()>,
}
260
261impl Default for AssertionEvaluator {
262 fn default() -> Self {
263 Self::new()
264 }
265}
266
267impl AssertionEvaluator {
268 pub fn new() -> Self {
270 Self {
271 _marker: std::marker::PhantomData,
272 }
273 }
274
275 pub fn evaluate_set(
277 &self,
278 set: &AssertionSet,
279 payload: &serde_json::Value,
280 context: Option<&EvaluationContext>,
281 ) -> Result<AssertionSetResult, AssertionError> {
282 let mut results = Vec::with_capacity(set.assertions.len());
283 let mut passed_count = 0;
284 let mut failed_count = 0;
285 let mut error_count = 0;
286 let mut warning_count = 0;
287
288 for assertion in &set.assertions {
289 let result = self.evaluate_single(assertion, payload, context);
290
291 match &result {
292 Ok(r) => {
293 if r.passed {
294 passed_count += 1;
295 } else {
296 failed_count += 1;
297 match r.severity {
298 AssertionSeverity::Error => error_count += 1,
299 AssertionSeverity::Warning => warning_count += 1,
300 AssertionSeverity::Info => {}
301 }
302 }
303 results.push(r.clone());
304
305 if set.fail_fast && !r.passed && r.severity == AssertionSeverity::Error {
307 break;
308 }
309 }
310 Err(e) => {
311 failed_count += 1;
313 error_count += 1;
314 results.push(AssertionResult {
315 id: assertion.id.clone(),
316 passed: false,
317 message: assertion.message.clone(),
318 severity: assertion.severity,
319 actual_value: None,
320 error: Some(e.to_string()),
321 });
322 }
323 }
324 }
325
326 let total = set.assertions.len();
328 let ic_score = if total == 0 {
329 1.0
330 } else {
331 passed_count as f64 / total as f64
332 };
333
334 Ok(AssertionSetResult {
335 results,
336 passed_count,
337 failed_count,
338 error_count,
339 warning_count,
340 ic_score,
341 })
342 }
343
344 pub fn evaluate_single(
346 &self,
347 assertion: &Assertion,
348 payload: &serde_json::Value,
349 context: Option<&EvaluationContext>,
350 ) -> Result<AssertionResult, AssertionError> {
351 let program = Program::compile(&assertion.expression).map_err(|e| {
353 AssertionError::CompilationError {
354 expr: assertion.expression.clone(),
355 message: format!("{:?}", e),
356 }
357 })?;
358
359 let mut cel_context = Context::default();
361
362 let payload_value = json_to_cel(payload);
364 cel_context.add_variable("payload", payload_value).ok();
365
366 if let Some(ctx) = context {
368 if let Some(args) = &ctx.arguments {
369 cel_context.add_variable("args", json_to_cel(args)).ok();
370 }
371 if let Some(resp) = &ctx.response {
372 cel_context.add_variable("response", json_to_cel(resp)).ok();
373 }
374 if let Some(stype) = &ctx.stype {
375 cel_context.add_variable("stype", stype.clone()).ok();
376 }
377 if let Some(tool) = &ctx.tool_name {
378 cel_context.add_variable("tool", tool.clone()).ok();
379 }
380
381 let meta_value = json_to_cel(&serde_json::to_value(&ctx.metadata).unwrap_or_default());
383 cel_context.add_variable("metadata", meta_value).ok();
384 }
385
386 let result = program.execute(&cel_context).map_err(|e| {
388 AssertionError::EvaluationError {
389 expr: assertion.expression.clone(),
390 message: format!("{:?}", e),
391 }
392 })?;
393
394 let passed = match &result {
396 Value::Bool(b) => *b,
397 Value::Null => false,
398 Value::Int(i) => *i != 0,
399 Value::UInt(u) => *u != 0,
400 Value::Float(f) => *f != 0.0,
401 Value::String(s) => !s.is_empty(),
402 Value::List(l) => !l.is_empty(),
403 Value::Map(m) => !m.map.is_empty(),
404 _ => true, };
406
407 debug!(
408 assertion_id = %assertion.id,
409 passed = passed,
410 "Assertion evaluated"
411 );
412
413 Ok(AssertionResult {
414 id: assertion.id.clone(),
415 passed,
416 message: assertion.message.clone(),
417 severity: assertion.severity,
418 actual_value: Some(format!("{:?}", result)),
419 error: None,
420 })
421 }
422}
423
424fn json_to_cel(value: &serde_json::Value) -> Value {
426 match value {
427 serde_json::Value::Null => Value::Null,
428 serde_json::Value::Bool(b) => Value::Bool(*b),
429 serde_json::Value::Number(n) => {
430 if let Some(i) = n.as_i64() {
431 Value::Int(i)
432 } else if let Some(u) = n.as_u64() {
433 Value::UInt(u)
434 } else if let Some(f) = n.as_f64() {
435 Value::Float(f)
436 } else {
437 Value::Null
438 }
439 }
440 serde_json::Value::String(s) => Value::String(s.clone().into()),
441 serde_json::Value::Array(arr) => {
442 Value::List(arr.iter().map(json_to_cel).collect::<Vec<_>>().into())
443 }
444 serde_json::Value::Object(obj) => {
445 let map: HashMap<String, Value> = obj
446 .iter()
447 .map(|(k, v)| (k.clone(), json_to_cel(v)))
448 .collect();
449 let cel_map: HashMap<cel_interpreter::objects::Key, Value> = map
451 .into_iter()
452 .map(|(k, v)| (cel_interpreter::objects::Key::String(k.into()), v))
453 .collect();
454 Value::Map(cel_interpreter::objects::Map { map: cel_map.into() })
455 }
456 }
457}
458
459pub fn load_assertions_from_json(json: &str) -> Result<AssertionSet, serde_json::Error> {
461 serde_json::from_str(json)
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_simple_assertion() {
        let assertion = Assertion::new(
            "amount_positive",
            "payload.amount > 0",
            "Amount must be positive",
        );

        let evaluator = AssertionEvaluator::new();

        let payload = json!({"amount": 100});
        let result = evaluator.evaluate_single(&assertion, &payload, None).unwrap();
        assert!(result.passed);

        let payload = json!({"amount": -50});
        let result = evaluator.evaluate_single(&assertion, &payload, None).unwrap();
        assert!(!result.passed);
    }

    #[test]
    fn test_string_assertion() {
        let assertion = Assertion::new(
            "currency_valid",
            "payload.currency in ['USD', 'EUR', 'GBP']",
            "Invalid currency",
        );

        let evaluator = AssertionEvaluator::new();

        let payload = json!({"currency": "USD"});
        let result = evaluator.evaluate_single(&assertion, &payload, None).unwrap();
        assert!(result.passed);

        let payload = json!({"currency": "XYZ"});
        let result = evaluator.evaluate_single(&assertion, &payload, None).unwrap();
        assert!(!result.passed);
    }

    #[test]
    fn test_assertion_set() {
        let set = AssertionSet::new(vec![
            Assertion::new("a1", "payload.x > 0", "X must be positive"),
            Assertion::new("a2", "payload.y < 100", "Y must be less than 100"),
        ]);

        let payload = json!({"x": 10, "y": 50});
        let result = set.evaluate(&payload).unwrap();
        assert!(result.passed());
        assert_eq!(result.ic_score, 1.0);

        let payload = json!({"x": -5, "y": 50});
        let result = set.evaluate(&payload).unwrap();
        assert!(!result.passed());
        assert_eq!(result.ic_score, 0.5);
    }

    #[test]
    fn test_nested_payload() {
        let assertion = Assertion::new(
            "nested_check",
            "payload.user.age >= 18",
            "User must be 18+",
        );

        let evaluator = AssertionEvaluator::new();

        let payload = json!({"user": {"name": "Alice", "age": 25}});
        let result = evaluator.evaluate_single(&assertion, &payload, None).unwrap();
        assert!(result.passed);
    }

    #[test]
    fn test_array_operations() {
        let assertion = Assertion::new(
            "has_items",
            "size(payload.items) > 0",
            "Items cannot be empty",
        );

        let evaluator = AssertionEvaluator::new();

        let payload = json!({"items": [1, 2, 3]});
        let result = evaluator.evaluate_single(&assertion, &payload, None).unwrap();
        assert!(result.passed);

        let payload = json!({"items": []});
        let result = evaluator.evaluate_single(&assertion, &payload, None).unwrap();
        assert!(!result.passed);
    }

    #[test]
    fn test_payload_type_conversion() {
        // Exercises bool, float, array indexing and nested-object conversion in `json_to_cel`.
        let assertion = Assertion::new(
            "types",
            "payload.flag && payload.ratio > 0.25 && payload.tags[1] == 'b' && payload.nested.id == 7",
            "Payload types not converted",
        );

        let payload = json!({
            "flag": true,
            "ratio": 0.5,
            "tags": ["a", "b"],
            "nested": {"id": 7}
        });
        let result = AssertionEvaluator::new()
            .evaluate_single(&assertion, &payload, None)
            .unwrap();
        assert!(result.passed);
    }

    #[test]
    fn test_non_bool_results_use_truthiness() {
        let evaluator = AssertionEvaluator::new();
        let assertion = Assertion::new("name_present", "payload.name", "Name required");

        let result = evaluator
            .evaluate_single(&assertion, &json!({"name": "Ann"}), None)
            .unwrap();
        assert!(result.passed);

        let result = evaluator
            .evaluate_single(&assertion, &json!({"name": ""}), None)
            .unwrap();
        assert!(!result.passed);
    }

    #[test]
    fn test_context_variables() {
        let assertion = Assertion::new(
            "tool_check",
            "tool == 'calendar.create'",
            "Only calendar.create allowed",
        );

        let evaluator = AssertionEvaluator::new();
        let payload = json!({});

        let ctx = EvaluationContext {
            tool_name: Some("calendar.create".to_string()),
            ..Default::default()
        };

        let result = evaluator.evaluate_single(&assertion, &payload, Some(&ctx)).unwrap();
        assert!(result.passed);
    }

    #[test]
    fn test_context_args_and_metadata() {
        let assertion = Assertion::new(
            "ctx_check",
            "args.count == 3 && metadata.env == 'prod'",
            "Context mismatch",
        );

        let mut metadata = HashMap::new();
        metadata.insert("env".to_string(), json!("prod"));
        let ctx = EvaluationContext {
            arguments: Some(json!({"count": 3})),
            metadata,
            ..Default::default()
        };

        let result = AssertionEvaluator::new()
            .evaluate_single(&assertion, &json!({}), Some(&ctx))
            .unwrap();
        assert!(result.passed);
    }

    #[test]
    fn test_severity_levels() {
        let set = AssertionSet::new(vec![
            Assertion::new("error_check", "false", "Error level").with_severity(AssertionSeverity::Error),
            Assertion::new("warn_check", "false", "Warning level").with_severity(AssertionSeverity::Warning),
            Assertion::new("info_check", "false", "Info level").with_severity(AssertionSeverity::Info),
        ]);

        let result = set.evaluate(&json!({})).unwrap();
        assert_eq!(result.error_count, 1);
        assert_eq!(result.warning_count, 1);
        assert_eq!(result.failed_count, 3);
        assert!(!result.passed());
        assert_eq!(result.failure_messages(), vec!["Error level"]);
    }

    #[test]
    fn test_fail_fast_stops_after_first_error() {
        let mut set = AssertionSet::new(vec![
            Assertion::new("warn_first", "false", "Warning first")
                .with_severity(AssertionSeverity::Warning),
            Assertion::new("error_second", "false", "Error second"),
            Assertion::new("never_run", "true", "Never evaluated"),
        ]);
        set.fail_fast = true;

        let result = set.evaluate(&json!({})).unwrap();
        // The warning does not stop evaluation; the first error does.
        assert_eq!(result.results.len(), 2);
        assert_eq!(result.warning_count, 1);
        assert_eq!(result.error_count, 1);
        assert_eq!(result.passed_count, 0);
        // Skipped assertions still count in the ic_score denominator.
        assert_eq!(result.ic_score, 0.0);
    }

    #[test]
    fn test_compilation_error_is_reported() {
        let assertion = Assertion::new("broken", "payload.amount >", "Broken expression");

        let err = AssertionEvaluator::new()
            .evaluate_single(&assertion, &json!({"amount": 1}), None)
            .unwrap_err();
        assert!(matches!(err, AssertionError::CompilationError { .. }));

        let set = AssertionSet::new(vec![assertion]);
        let result = set.evaluate(&json!({"amount": 1})).unwrap();
        assert_eq!(result.error_count, 1);
        assert_eq!(result.failed_count, 1);
        assert!(!result.passed());
        assert!(result.results[0].error.is_some());
        assert!(result.results[0].actual_value.is_none());
    }

    #[test]
    fn test_load_from_json() {
        let json = r#"{
            "name": "finance_checks",
            "description": "Financial payload validations",
            "assertions": [
                {
                    "id": "amount_check",
                    "expression": "payload.amount > 0",
                    "message": "Amount must be positive"
                },
                {
                    "id": "currency_check",
                    "expression": "payload.currency in ['USD', 'EUR']",
                    "message": "Invalid currency",
                    "severity": "warning"
                }
            ]
        }"#;

        let set = load_assertions_from_json(json).unwrap();
        assert_eq!(set.name, "finance_checks");
        assert_eq!(set.assertions.len(), 2);
        assert_eq!(set.assertions[0].severity, AssertionSeverity::Error);
        assert_eq!(set.assertions[1].severity, AssertionSeverity::Warning);
        assert!(!set.fail_fast);
    }
}