1use cel_interpreter::objects::Map;
2use cel_interpreter::{Context, Program, Value};
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Arc;
6
7use crate::config::{GuardrailAction, GuardrailDirection, GuardrailOnError, GuardrailRule};
8use crate::message::InboundMessage;
9
10pub fn load_rules_from_dir(dir: &Path) -> Vec<GuardrailRule> {
11 if !dir.exists() {
12 return vec![];
13 }
14 let mut entries: Vec<_> = match std::fs::read_dir(dir) {
15 Ok(rd) => rd.filter_map(|e| e.ok()).collect(),
16 Err(e) => {
17 tracing::error!(error = %e, "Failed to read guardrails dir");
18 return vec![];
19 }
20 };
21 entries.sort_by_key(|e| e.file_name());
22 let mut rules = vec![];
23 for entry in entries {
24 let path = entry.path();
25 if path.extension().and_then(|e| e.to_str()) != Some("json") {
26 continue;
27 }
28 let content = match std::fs::read_to_string(&path) {
29 Ok(c) => c,
30 Err(e) => {
31 tracing::error!(path = %path.display(), error = %e, "Failed to read rule file");
32 continue;
33 }
34 };
35 let rule: GuardrailRule = match serde_json::from_str(&content) {
36 Ok(r) => r,
37 Err(e) => {
38 tracing::error!(path = %path.display(), error = %e, "Failed to parse rule file");
39 continue;
40 }
41 };
42 if !rule.enabled {
43 tracing::debug!(name = %rule.name, "Skipping disabled guardrail rule");
44 continue;
45 }
46 rules.push(rule);
47 }
48 rules
49}
50
51pub fn json_to_cel_value(value: serde_json::Value) -> Value {
65 match value {
66 serde_json::Value::Null => Value::Null,
67 serde_json::Value::Bool(b) => Value::Bool(b),
68 serde_json::Value::String(s) => Value::String(Arc::new(s)),
69 serde_json::Value::Number(n) => {
70 if let Some(i) = n.as_i64() {
71 Value::Int(i)
72 } else if let Some(f) = n.as_f64() {
73 Value::Float(f)
74 } else {
75 Value::String(Arc::new(n.to_string()))
77 }
78 }
79 serde_json::Value::Array(arr) => {
80 let cel_list: Vec<Value> = arr.into_iter().map(json_to_cel_value).collect();
81 Value::List(Arc::new(cel_list))
82 }
83 serde_json::Value::Object(obj) => {
84 let mut map: HashMap<cel_interpreter::objects::Key, Value> = HashMap::new();
85 for (k, v) in obj {
86 map.insert(
87 cel_interpreter::objects::Key::String(Arc::new(k)),
88 json_to_cel_value(v),
89 );
90 }
91 Value::Map(Map::from(map))
92 }
93 }
94}
95
/// Outcome of evaluating a message against the guardrail rules.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardrailVerdict {
    /// The message may proceed.
    Allow,
    /// A rule stopped the message.
    Block {
        /// Name of the rule that produced the block.
        rule_name: String,
        /// Text describing the rejection: the rule's `reject_message` when it
        /// sets one, otherwise the rule name.
        reject_message: String,
    },
}
104
/// A [`GuardrailRule`] whose CEL expression has been compiled and is ready
/// to evaluate.
pub struct CompiledRule {
    /// Rule name; used in log output and as the fallback reject message.
    pub name: String,
    /// Compiled CEL program, executed against a context holding a `message`
    /// variable.
    pub program: Arc<Program>,
    /// What to do when the program evaluates to `true`.
    pub action: GuardrailAction,
    /// Which message direction(s) the rule applies to.
    pub direction: GuardrailDirection,
    /// Policy applied when executing the program fails.
    pub on_error: GuardrailOnError,
    /// Optional rejection text; the rule name is used when this is `None`.
    pub reject_message: Option<String>,
}
113
/// Evaluates messages against an ordered list of compiled guardrail rules.
pub struct GuardrailEngine {
    /// Successfully compiled rules, in evaluation order (the order they were
    /// given to `from_rules`).
    rules: Vec<CompiledRule>,
}
117
118impl GuardrailEngine {
119 pub fn from_rules(rules: Vec<GuardrailRule>) -> Self {
120 let mut compiled = Vec::new();
121 for rule in rules {
122 match Program::compile(&rule.expression) {
123 Ok(program) => {
124 compiled.push(CompiledRule {
125 name: rule.name,
126 program: Arc::new(program),
127 action: rule.action,
128 direction: rule.direction,
129 on_error: rule.on_error,
130 reject_message: rule.reject_message,
131 });
132 }
133 Err(e) => {
134 tracing::error!(
135 rule_name = %rule.name,
136 expression = %rule.expression,
137 error = %e,
138 "Failed to compile guardrail CEL expression, skipping rule"
139 );
140 }
141 }
142 }
143 GuardrailEngine { rules: compiled }
144 }
145
146 pub fn evaluate_inbound(&self, message: &InboundMessage) -> GuardrailVerdict {
147 let mut json_val = match serde_json::to_value(message) {
148 Ok(v) => v,
149 Err(e) => {
150 tracing::error!(error = %e, "Failed to serialize InboundMessage for guardrail evaluation");
151 return GuardrailVerdict::Allow;
152 }
153 };
154 if let Some(obj) = json_val.as_object_mut() {
157 obj.entry("attachments")
158 .or_insert(serde_json::Value::Array(vec![]));
159 }
160 let cel_val = json_to_cel_value(json_val);
161
162 let mut ctx = Context::default();
163 ctx.add_variable_from_value("message", cel_val);
164
165 for rule in &self.rules {
166 if rule.direction == GuardrailDirection::Outbound {
167 continue;
168 }
169
170 match rule.program.execute(&ctx) {
171 Ok(Value::Bool(true)) => match rule.action {
172 GuardrailAction::Block => {
173 let reject_msg = rule
174 .reject_message
175 .clone()
176 .unwrap_or_else(|| rule.name.clone());
177 return GuardrailVerdict::Block {
178 rule_name: rule.name.clone(),
179 reject_message: reject_msg,
180 };
181 }
182 GuardrailAction::Log => {
183 tracing::warn!(
184 rule_name = %rule.name,
185 "Guardrail rule matched (log only)"
186 );
187 }
188 },
189 Ok(_) => {}
190 Err(e) => {
191 tracing::error!(
192 rule_name = %rule.name,
193 error = %e,
194 "Guardrail rule evaluation error"
195 );
196 match rule.on_error {
197 GuardrailOnError::Block => {
198 let reject_msg = rule
199 .reject_message
200 .clone()
201 .unwrap_or_else(|| rule.name.clone());
202 return GuardrailVerdict::Block {
203 rule_name: rule.name.clone(),
204 reject_message: reject_msg,
205 };
206 }
207 GuardrailOnError::Allow => {}
208 }
209 }
210 }
211 }
212
213 GuardrailVerdict::Allow
214 }
215
216 #[allow(dead_code)]
217 pub fn is_empty(&self) -> bool {
218 self.rules.is_empty()
219 }
220}
221
#[cfg(test)]
mod tests {
    // `super::*` already brings in GuardrailAction/Direction/OnError/Rule,
    // InboundMessage, Arc, Value, etc.; only names the parent does not import
    // are listed here.
    use super::*;
    use crate::config::GuardrailType;
    use crate::message::{MessageSource, UserInfo};
    use cel_interpreter::objects::Key;
    use chrono::Utc;
    use serde_json::json;
    use std::fs;
    use tempfile::TempDir;

    // ----------------------------------------------------------------------
    // Rule loading
    // ----------------------------------------------------------------------

    /// Writes `content` to `filename` inside the temporary rules directory.
    fn write_rule(dir: &TempDir, filename: &str, content: &str) {
        fs::write(dir.path().join(filename), content).unwrap();
    }

    /// Builds a minimal rule file body whose expression is always `true`.
    fn rule_json(name: &str, enabled: bool) -> String {
        format!(
            r#"{{"name":"{name}","expression":"true","enabled":{enabled}}}"#,
            name = name,
            enabled = enabled
        )
    }

    /// Names of the loaded rules, in load order.
    fn rule_names(rules: &[GuardrailRule]) -> Vec<&str> {
        rules.iter().map(|r| r.name.as_str()).collect()
    }

    #[test]
    fn test_load_rules_three_valid_files_in_filename_order() {
        let dir = TempDir::new().unwrap();
        write_rule(&dir, "03_c.json", &rule_json("rule_c", true));
        write_rule(&dir, "01_a.json", &rule_json("rule_a", true));
        write_rule(&dir, "02_b.json", &rule_json("rule_b", true));

        let rules = load_rules_from_dir(dir.path());

        assert_eq!(rule_names(&rules), ["rule_a", "rule_b", "rule_c"]);
    }

    #[test]
    fn test_load_rules_sorted_lexicographically() {
        let dir = TempDir::new().unwrap();
        write_rule(&dir, "02_b.json", &rule_json("rule_b", true));
        write_rule(&dir, "01_a.json", &rule_json("rule_a", true));

        let rules = load_rules_from_dir(dir.path());

        assert_eq!(rule_names(&rules), ["rule_a", "rule_b"]);
    }

    #[test]
    fn test_load_rules_skips_malformed_json() {
        let dir = TempDir::new().unwrap();
        write_rule(&dir, "01_valid.json", &rule_json("valid_rule", true));
        write_rule(&dir, "02_bad.json", "this is not json {{{");

        let rules = load_rules_from_dir(dir.path());

        assert_eq!(rule_names(&rules), ["valid_rule"]);
    }

    #[test]
    fn test_load_rules_skips_disabled_rules() {
        let dir = TempDir::new().unwrap();
        write_rule(&dir, "01_enabled.json", &rule_json("enabled_rule", true));
        write_rule(&dir, "02_disabled.json", &rule_json("disabled_rule", false));

        let rules = load_rules_from_dir(dir.path());

        assert_eq!(rule_names(&rules), ["enabled_rule"]);
    }

    #[test]
    fn test_load_rules_nonexistent_dir_returns_empty() {
        let rules = load_rules_from_dir(Path::new("/nonexistent/path/that/does/not/exist"));
        assert!(rules.is_empty());
    }

    #[test]
    fn test_load_rules_empty_dir_returns_empty() {
        let dir = TempDir::new().unwrap();
        let rules = load_rules_from_dir(dir.path());
        assert!(rules.is_empty());
    }

    #[test]
    fn test_load_rules_ignores_non_json_files() {
        let dir = TempDir::new().unwrap();
        write_rule(&dir, "rule.txt", &rule_json("txt_rule", true));
        write_rule(&dir, "rule.disabled", &rule_json("disabled_ext_rule", true));
        write_rule(&dir, "rule.json", &rule_json("json_rule", true));

        let rules = load_rules_from_dir(dir.path());

        assert_eq!(rule_names(&rules), ["json_rule"]);
    }

    // ----------------------------------------------------------------------
    // JSON -> CEL conversion
    // ----------------------------------------------------------------------

    /// Looks up `key` in a CEL map; `None` if `cel_val` is not a map or the
    /// key is absent.
    fn get_map_value(cel_val: &Value, key: &str) -> Option<Value> {
        if let Value::Map(m) = cel_val {
            m.map.get(&Key::String(Arc::new(key.to_string()))).cloned()
        } else {
            None
        }
    }

    #[test]
    fn test_null() {
        let result = json_to_cel_value(serde_json::Value::Null);
        assert!(matches!(result, Value::Null));
    }

    #[test]
    fn test_bool_true() {
        let result = json_to_cel_value(json!(true));
        assert!(matches!(result, Value::Bool(true)));
    }

    #[test]
    fn test_bool_false() {
        let result = json_to_cel_value(json!(false));
        assert!(matches!(result, Value::Bool(false)));
    }

    #[test]
    fn test_string() {
        let result = json_to_cel_value(json!("hello"));
        assert!(matches!(result, Value::String(s) if s.as_ref() == "hello"));
    }

    #[test]
    fn test_empty_string() {
        let result = json_to_cel_value(json!(""));
        assert!(matches!(result, Value::String(s) if s.as_ref() == ""));
    }

    #[test]
    fn test_positive_integer() {
        let result = json_to_cel_value(json!(42));
        assert!(matches!(result, Value::Int(42)));
    }

    #[test]
    fn test_negative_integer() {
        let result = json_to_cel_value(json!(-7));
        assert!(matches!(result, Value::Int(-7)));
    }

    #[test]
    fn test_zero() {
        let result = json_to_cel_value(json!(0));
        assert!(matches!(result, Value::Int(0)));
    }

    #[test]
    fn test_float() {
        let result = json_to_cel_value(json!(1.5));
        if let Value::Float(f) = result {
            assert!((f - 1.5_f64).abs() < 1e-10);
        } else {
            panic!("Expected Float, got {:?}", result);
        }
    }

    #[test]
    fn test_empty_array() {
        let result = json_to_cel_value(json!([]));
        if let Value::List(list) = result {
            assert!(list.is_empty());
        } else {
            panic!("Expected List");
        }
    }

    #[test]
    fn test_array_of_strings() {
        let result = json_to_cel_value(json!(["a", "b", "c"]));
        if let Value::List(list) = result {
            assert_eq!(list.len(), 3);
            assert!(matches!(&list[0], Value::String(s) if s.as_ref() == "a"));
            assert!(matches!(&list[1], Value::String(s) if s.as_ref() == "b"));
            assert!(matches!(&list[2], Value::String(s) if s.as_ref() == "c"));
        } else {
            panic!("Expected List");
        }
    }

    #[test]
    fn test_mixed_array() {
        let result = json_to_cel_value(json!([1, "two", null, true]));
        if let Value::List(list) = result {
            assert_eq!(list.len(), 4);
            assert!(matches!(&list[0], Value::Int(1)));
            assert!(matches!(&list[1], Value::String(s) if s.as_ref() == "two"));
            assert!(matches!(&list[2], Value::Null));
            assert!(matches!(&list[3], Value::Bool(true)));
        } else {
            panic!("Expected List");
        }
    }

    #[test]
    fn test_empty_object() {
        let result = json_to_cel_value(json!({}));
        match result {
            Value::Map(m) => assert!(m.map.is_empty()),
            other => panic!("Expected Map, got {:?}", other),
        }
    }

    #[test]
    fn test_simple_object() {
        let result = json_to_cel_value(json!({"name": "alice", "age": 30}));
        let name = get_map_value(&result, "name");
        let age = get_map_value(&result, "age");
        assert!(matches!(name, Some(Value::String(s)) if s.as_ref() == "alice"));
        assert!(matches!(age, Some(Value::Int(30))));
    }

    #[test]
    fn test_nested_object() {
        let result = json_to_cel_value(json!({"a": {"b": null}}));
        let a = get_map_value(&result, "a").expect("key 'a' missing");
        let b = get_map_value(&a, "b").expect("key 'b' missing");
        assert!(matches!(b, Value::Null));
    }

    #[test]
    fn test_option_field_as_null() {
        let result = json_to_cel_value(json!({"username": null}));
        let username = get_map_value(&result, "username");
        assert!(
            matches!(username, Some(Value::Null)),
            "Option<T> None must map to CEL Null, got {:?}",
            username
        );
    }

    #[test]
    fn test_complex_nested() {
        let result = json_to_cel_value(json!({
            "source": {
                "protocol": "telegram",
                "from": {
                    "id": "123",
                    "username": null
                }
            },
            "text": "hello",
            "attachments": []
        }));
        let source = get_map_value(&result, "source").expect("source missing");
        let protocol = get_map_value(&source, "protocol").expect("protocol missing");
        assert!(matches!(protocol, Value::String(s) if s.as_ref() == "telegram"));

        let from = get_map_value(&source, "from").expect("from missing");
        let username = get_map_value(&from, "username").expect("username key missing");
        assert!(matches!(username, Value::Null));

        let attachments = get_map_value(&result, "attachments").expect("attachments missing");
        assert!(matches!(attachments, Value::List(l) if l.is_empty()));
    }

    // ----------------------------------------------------------------------
    // Engine evaluation
    // ----------------------------------------------------------------------

    /// Builds an enabled CEL rule with the given settings.
    fn make_rule(
        name: &str,
        expression: &str,
        action: GuardrailAction,
        direction: GuardrailDirection,
        on_error: GuardrailOnError,
        reject_message: Option<&str>,
    ) -> GuardrailRule {
        GuardrailRule {
            name: name.to_string(),
            r#type: GuardrailType::Cel,
            expression: expression.to_string(),
            action,
            direction,
            on_error,
            reject_message: reject_message.map(|s| s.to_string()),
            enabled: true,
        }
    }

    /// Builds an inbound test message carrying `text`.
    fn test_message(text: &str) -> InboundMessage {
        InboundMessage {
            route: json!({"channel": "test"}),
            credential_id: "test_cred".to_string(),
            source: MessageSource {
                protocol: "test".to_string(),
                chat_id: "chat_1".to_string(),
                message_id: "msg_1".to_string(),
                reply_to_message_id: None,
                from: UserInfo {
                    id: "user_1".to_string(),
                    username: Some("testuser".to_string()),
                    display_name: None,
                },
            },
            text: text.to_string(),
            attachments: vec![],
            timestamp: Utc::now(),
            extra_data: None,
        }
    }

    /// Expected block verdict for the given rule name and reject message.
    fn blocked(rule_name: &str, reject_message: &str) -> GuardrailVerdict {
        GuardrailVerdict::Block {
            rule_name: rule_name.to_string(),
            reject_message: reject_message.to_string(),
        }
    }

    #[test]
    fn test_engine_true_block_returns_block() {
        let rules = vec![make_rule(
            "block_all",
            "true",
            GuardrailAction::Block,
            GuardrailDirection::Inbound,
            GuardrailOnError::Allow,
            None,
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), blocked("block_all", "block_all"));
    }

    #[test]
    fn test_engine_false_block_returns_allow() {
        let rules = vec![make_rule(
            "never_fire",
            "false",
            GuardrailAction::Block,
            GuardrailDirection::Inbound,
            GuardrailOnError::Allow,
            None,
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), GuardrailVerdict::Allow);
    }

    #[test]
    fn test_engine_true_log_returns_allow() {
        let rules = vec![make_rule(
            "log_all",
            "true",
            GuardrailAction::Log,
            GuardrailDirection::Inbound,
            GuardrailOnError::Allow,
            None,
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), GuardrailVerdict::Allow);
    }

    #[test]
    fn test_engine_short_circuits_on_first_block() {
        let rules = vec![
            make_rule(
                "first",
                "true",
                GuardrailAction::Block,
                GuardrailDirection::Inbound,
                GuardrailOnError::Allow,
                None,
            ),
            make_rule(
                "second",
                "true",
                GuardrailAction::Block,
                GuardrailDirection::Inbound,
                GuardrailOnError::Allow,
                None,
            ),
        ];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), blocked("first", "first"));
    }

    #[test]
    fn test_engine_invalid_expression_skipped() {
        let rules = vec![
            make_rule(
                "bad_rule",
                "this is not valid CEL !!!",
                GuardrailAction::Block,
                GuardrailDirection::Inbound,
                GuardrailOnError::Allow,
                None,
            ),
            make_rule(
                "good_rule",
                "true",
                GuardrailAction::Block,
                GuardrailDirection::Inbound,
                GuardrailOnError::Allow,
                None,
            ),
        ];
        let engine = GuardrailEngine::from_rules(rules);
        assert_eq!(engine.rules.len(), 1);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), blocked("good_rule", "good_rule"));
    }

    #[test]
    fn test_engine_on_error_allow() {
        let rules = vec![make_rule(
            "error_rule",
            "message.nonexistent_field == true",
            GuardrailAction::Block,
            GuardrailDirection::Inbound,
            GuardrailOnError::Allow,
            None,
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), GuardrailVerdict::Allow);
    }

    #[test]
    fn test_engine_on_error_block() {
        let rules = vec![make_rule(
            "error_rule",
            "message.nonexistent_field == true",
            GuardrailAction::Block,
            GuardrailDirection::Inbound,
            GuardrailOnError::Block,
            None,
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), blocked("error_rule", "error_rule"));
    }

    #[test]
    fn test_engine_message_text_matches_password_block() {
        let rules = vec![make_rule(
            "no_passwords",
            r#"message.text.matches("password")"#,
            GuardrailAction::Block,
            GuardrailDirection::Inbound,
            GuardrailOnError::Allow,
            Some("Message contains sensitive content"),
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("my password is secret");
        assert_eq!(
            engine.evaluate_inbound(&msg),
            blocked("no_passwords", "Message contains sensitive content")
        );
    }

    #[test]
    fn test_engine_message_text_matches_password_allow() {
        let rules = vec![make_rule(
            "no_passwords",
            r#"message.text.matches("password")"#,
            GuardrailAction::Block,
            GuardrailDirection::Inbound,
            GuardrailOnError::Allow,
            None,
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello world");
        assert_eq!(engine.evaluate_inbound(&msg), GuardrailVerdict::Allow);
    }

    #[test]
    fn test_engine_is_empty() {
        let empty = GuardrailEngine::from_rules(vec![]);
        assert!(empty.is_empty());

        let non_empty = GuardrailEngine::from_rules(vec![make_rule(
            "rule",
            "true",
            GuardrailAction::Block,
            GuardrailDirection::Inbound,
            GuardrailOnError::Allow,
            None,
        )]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_engine_outbound_rule_skipped_for_inbound() {
        let rules = vec![make_rule(
            "outbound_only",
            "true",
            GuardrailAction::Block,
            GuardrailDirection::Outbound,
            GuardrailOnError::Allow,
            None,
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), GuardrailVerdict::Allow);
    }

    #[test]
    fn test_engine_both_direction_applies_to_inbound() {
        let rules = vec![make_rule(
            "both_dir",
            "true",
            GuardrailAction::Block,
            GuardrailDirection::Both,
            GuardrailOnError::Allow,
            None,
        )];
        let engine = GuardrailEngine::from_rules(rules);
        let msg = test_message("hello");
        assert_eq!(engine.evaluate_inbound(&msg), blocked("both_dir", "both_dir"));
    }
}