1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use async_trait::async_trait;
12use pingora_timeout::timeout;
13use sentinel_agent_protocol::{
14 GuardrailDetection, GuardrailInspectEvent, GuardrailInspectionType, GuardrailResponse,
15};
16use sentinel_config::{
17 GuardrailAction, GuardrailFailureMode, PiiDetectionConfig, PromptInjectionConfig,
18};
19use tracing::{debug, trace, warn};
20
21use crate::agents::AgentManager;
22
23#[derive(Debug)]
25pub enum PromptInjectionResult {
26 Clean,
28 Blocked {
30 status: u16,
31 message: String,
32 detections: Vec<GuardrailDetection>,
33 },
34 Detected { detections: Vec<GuardrailDetection> },
36 Warning { detections: Vec<GuardrailDetection> },
38 Error { message: String },
40}
41
42#[derive(Debug)]
44pub enum PiiCheckResult {
45 Clean,
47 Detected {
49 detections: Vec<GuardrailDetection>,
50 redacted_content: Option<String>,
51 },
52 Error { message: String },
54}
55
56#[async_trait]
60pub trait GuardrailAgentCaller: Send + Sync {
61 async fn call_guardrail_agent(
63 &self,
64 agent_name: &str,
65 event: GuardrailInspectEvent,
66 ) -> Result<GuardrailResponse, String>;
67}
68
69pub struct AgentManagerCaller {
71 #[allow(dead_code)]
72 agent_manager: Arc<AgentManager>,
73}
74
75impl AgentManagerCaller {
76 pub fn new(agent_manager: Arc<AgentManager>) -> Self {
78 Self { agent_manager }
79 }
80}
81
82#[async_trait]
83impl GuardrailAgentCaller for AgentManagerCaller {
84 async fn call_guardrail_agent(
85 &self,
86 agent_name: &str,
87 event: GuardrailInspectEvent,
88 ) -> Result<GuardrailResponse, String> {
89 trace!(
96 agent = agent_name,
97 inspection_type = ?event.inspection_type,
98 "Calling guardrail agent"
99 );
100
101 Err(format!(
104 "Agent '{}' not configured for guardrail inspection",
105 agent_name
106 ))
107 }
108}
109
110pub struct GuardrailProcessor {
115 agent_caller: Arc<dyn GuardrailAgentCaller>,
116}
117
118impl GuardrailProcessor {
119 pub fn new(agent_manager: Arc<AgentManager>) -> Self {
121 Self {
122 agent_caller: Arc::new(AgentManagerCaller::new(agent_manager)),
123 }
124 }
125
126 pub fn with_caller(agent_caller: Arc<dyn GuardrailAgentCaller>) -> Self {
130 Self { agent_caller }
131 }
132
133 pub async fn check_prompt_injection(
142 &self,
143 config: &PromptInjectionConfig,
144 content: &str,
145 model: Option<&str>,
146 route_id: Option<&str>,
147 correlation_id: &str,
148 ) -> PromptInjectionResult {
149 if !config.enabled {
150 return PromptInjectionResult::Clean;
151 }
152
153 trace!(
154 correlation_id = correlation_id,
155 agent = %config.agent,
156 content_len = content.len(),
157 "Checking content for prompt injection"
158 );
159
160 let event = GuardrailInspectEvent {
161 correlation_id: correlation_id.to_string(),
162 inspection_type: GuardrailInspectionType::PromptInjection,
163 content: content.to_string(),
164 model: model.map(String::from),
165 categories: vec![],
166 route_id: route_id.map(String::from),
167 metadata: HashMap::new(),
168 };
169
170 let start = Instant::now();
171 let timeout_duration = Duration::from_millis(config.timeout_ms);
172
173 match timeout(
175 timeout_duration,
176 self.agent_caller.call_guardrail_agent(&config.agent, event),
177 )
178 .await
179 {
180 Ok(Ok(response)) => {
181 let duration = start.elapsed();
182 debug!(
183 correlation_id = correlation_id,
184 agent = %config.agent,
185 detected = response.detected,
186 confidence = response.confidence,
187 detection_count = response.detections.len(),
188 duration_ms = duration.as_millis(),
189 "Prompt injection check completed"
190 );
191
192 if response.detected {
193 match config.action {
194 GuardrailAction::Block => PromptInjectionResult::Blocked {
195 status: config.block_status,
196 message: config.block_message.clone().unwrap_or_else(|| {
197 "Request blocked: potential prompt injection detected".to_string()
198 }),
199 detections: response.detections,
200 },
201 GuardrailAction::Log => PromptInjectionResult::Detected {
202 detections: response.detections,
203 },
204 GuardrailAction::Warn => PromptInjectionResult::Warning {
205 detections: response.detections,
206 },
207 }
208 } else {
209 PromptInjectionResult::Clean
210 }
211 }
212 Ok(Err(e)) => {
213 warn!(
214 correlation_id = correlation_id,
215 agent = %config.agent,
216 error = %e,
217 failure_mode = ?config.failure_mode,
218 "Prompt injection agent call failed"
219 );
220
221 match config.failure_mode {
222 GuardrailFailureMode::Open => PromptInjectionResult::Clean,
223 GuardrailFailureMode::Closed => PromptInjectionResult::Blocked {
224 status: 503,
225 message: "Guardrail check unavailable".to_string(),
226 detections: vec![],
227 },
228 }
229 }
230 Err(_) => {
231 warn!(
232 correlation_id = correlation_id,
233 agent = %config.agent,
234 timeout_ms = config.timeout_ms,
235 failure_mode = ?config.failure_mode,
236 "Prompt injection agent call timed out"
237 );
238
239 match config.failure_mode {
240 GuardrailFailureMode::Open => PromptInjectionResult::Clean,
241 GuardrailFailureMode::Closed => PromptInjectionResult::Blocked {
242 status: 504,
243 message: "Guardrail check timed out".to_string(),
244 detections: vec![],
245 },
246 }
247 }
248 }
249 }
250
251 pub async fn check_pii(
259 &self,
260 config: &PiiDetectionConfig,
261 content: &str,
262 route_id: Option<&str>,
263 correlation_id: &str,
264 ) -> PiiCheckResult {
265 if !config.enabled {
266 return PiiCheckResult::Clean;
267 }
268
269 trace!(
270 correlation_id = correlation_id,
271 agent = %config.agent,
272 content_len = content.len(),
273 categories = ?config.categories,
274 "Checking response for PII"
275 );
276
277 let event = GuardrailInspectEvent {
278 correlation_id: correlation_id.to_string(),
279 inspection_type: GuardrailInspectionType::PiiDetection,
280 content: content.to_string(),
281 model: None,
282 categories: config.categories.clone(),
283 route_id: route_id.map(String::from),
284 metadata: HashMap::new(),
285 };
286
287 let start = Instant::now();
288 let timeout_duration = Duration::from_millis(config.timeout_ms);
289
290 match timeout(
291 timeout_duration,
292 self.agent_caller.call_guardrail_agent(&config.agent, event),
293 )
294 .await
295 {
296 Ok(Ok(response)) => {
297 let duration = start.elapsed();
298 debug!(
299 correlation_id = correlation_id,
300 agent = %config.agent,
301 detected = response.detected,
302 detection_count = response.detections.len(),
303 duration_ms = duration.as_millis(),
304 "PII check completed"
305 );
306
307 if response.detected {
308 PiiCheckResult::Detected {
309 detections: response.detections,
310 redacted_content: response.redacted_content,
311 }
312 } else {
313 PiiCheckResult::Clean
314 }
315 }
316 Ok(Err(e)) => {
317 warn!(
318 correlation_id = correlation_id,
319 agent = %config.agent,
320 error = %e,
321 "PII detection agent call failed"
322 );
323
324 PiiCheckResult::Error {
325 message: e.to_string(),
326 }
327 }
328 Err(_) => {
329 warn!(
330 correlation_id = correlation_id,
331 agent = %config.agent,
332 timeout_ms = config.timeout_ms,
333 "PII detection agent call timed out"
334 );
335
336 PiiCheckResult::Error {
337 message: "Agent timeout".to_string(),
338 }
339 }
340 }
341 }
342}
343
344pub fn extract_inference_content(body: &[u8]) -> Option<String> {
349 let json: serde_json::Value = serde_json::from_slice(body).ok()?;
350
351 if let Some(messages) = json.get("messages").and_then(|m| m.as_array()) {
353 let content: Vec<String> = messages
354 .iter()
355 .filter_map(|msg| msg.get("content").and_then(|c| c.as_str()))
356 .map(String::from)
357 .collect();
358 if !content.is_empty() {
359 return Some(content.join("\n"));
360 }
361 }
362
363 if let Some(prompt) = json.get("prompt").and_then(|p| p.as_str()) {
365 return Some(prompt.to_string());
366 }
367
368 for field in &["input", "text", "query", "question"] {
370 if let Some(value) = json.get(*field).and_then(|v| v.as_str()) {
371 return Some(value.to_string());
372 }
373 }
374
375 None
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_agent_protocol::{DetectionSeverity, TextSpan};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::Mutex;

    /// Test double for [`GuardrailAgentCaller`]: replays one canned response
    /// and counts how often it was called.
    struct MockAgentCaller {
        response: Mutex<Option<Result<GuardrailResponse, String>>>,
        call_count: AtomicUsize,
    }

    impl MockAgentCaller {
        /// Mock without a canned response; every call returns an error.
        fn new() -> Self {
            Self {
                response: Mutex::new(None),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Mock that answers every call with a clone of `response`.
        fn with_response(response: Result<GuardrailResponse, String>) -> Self {
            Self {
                response: Mutex::new(Some(response)),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of agent calls made so far.
        fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl GuardrailAgentCaller for MockAgentCaller {
        async fn call_guardrail_agent(
            &self,
            _agent_name: &str,
            _event: GuardrailInspectEvent,
        ) -> Result<GuardrailResponse, String> {
            self.call_count.fetch_add(1, Ordering::SeqCst);

            let guard = self.response.lock().await;
            match &*guard {
                Some(response) => response.clone(),
                None => Err("No mock response configured".to_string()),
            }
        }
    }

    /// Enabled prompt-injection config with the given action and failure mode.
    fn create_prompt_injection_config(
        action: GuardrailAction,
        failure_mode: GuardrailFailureMode,
    ) -> PromptInjectionConfig {
        PromptInjectionConfig {
            enabled: true,
            agent: "test-agent".to_string(),
            action,
            block_status: 400,
            block_message: Some("Blocked: injection detected".to_string()),
            timeout_ms: 5000,
            failure_mode,
        }
    }

    /// Enabled PII config scanning for SSNs and e-mail addresses.
    fn create_pii_config() -> PiiDetectionConfig {
        PiiDetectionConfig {
            enabled: true,
            agent: "pii-scanner".to_string(),
            action: sentinel_config::PiiAction::Log,
            categories: vec!["ssn".to_string(), "email".to_string()],
            timeout_ms: 5000,
            failure_mode: GuardrailFailureMode::Open,
        }
    }

    /// High-severity detection with fixed confidence and span.
    fn create_detection(category: &str, description: &str) -> GuardrailDetection {
        GuardrailDetection {
            category: category.to_string(),
            description: description.to_string(),
            severity: DetectionSeverity::High,
            confidence: Some(0.95),
            span: Some(TextSpan { start: 0, end: 10 }),
        }
    }

    /// Agent response without redacted content.
    fn create_guardrail_response(
        detected: bool,
        detections: Vec<GuardrailDetection>,
    ) -> GuardrailResponse {
        GuardrailResponse {
            detected,
            confidence: if detected { 0.95 } else { 0.0 },
            detections,
            redacted_content: None,
        }
    }

    // ---- extract_inference_content ----

    #[test]
    fn test_extract_openai_content() {
        let body = br#"{"messages": [{"role": "user", "content": "Hello world"}]}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Hello world".to_string()));
    }

    #[test]
    fn test_extract_openai_multi_message() {
        let body = br#"{
            "messages": [
                {"role": "system", "content": "You are helpful"},
                {"role": "user", "content": "Hello"}
            ]
        }"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("You are helpful\nHello".to_string()));
    }

    #[test]
    fn test_extract_anthropic_content() {
        let body = br#"{"prompt": "Human: Hello\n\nAssistant:"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Human: Hello\n\nAssistant:".to_string()));
    }

    #[test]
    fn test_extract_generic_input() {
        let body = br#"{"input": "Test query"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Test query".to_string()));
    }

    #[test]
    fn test_extract_generic_text() {
        let body = br#"{"text": "Some text content"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Some text content".to_string()));
    }

    #[test]
    fn test_extract_generic_query() {
        let body = br#"{"query": "What is the weather?"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("What is the weather?".to_string()));
    }

    #[test]
    fn test_extract_generic_question() {
        let body = br#"{"question": "How does this work?"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("How does this work?".to_string()));
    }

    #[test]
    fn test_extract_invalid_json() {
        let body = b"not json";
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_empty_messages() {
        let body = br#"{"messages": []}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_messages_without_content() {
        let body = br#"{"messages": [{"role": "user"}]}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_empty_object() {
        let body = br#"{}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_nested_content() {
        let body = br#"{
            "messages": [
                {"role": "system"},
                {"role": "user", "content": "Valid content"},
                {"role": "assistant"}
            ]
        }"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Valid content".to_string()));
    }

    // ---- check_prompt_injection ----

    #[tokio::test]
    async fn test_prompt_injection_disabled() {
        let mock = Arc::new(MockAgentCaller::new());
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);
        config.enabled = false;

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        assert!(matches!(result, PromptInjectionResult::Clean));
        // A disabled check must never reach the agent.
        assert_eq!(mock.call_count(), 0);
    }

    #[tokio::test]
    async fn test_prompt_injection_clean() {
        let response = create_guardrail_response(false, vec![]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(
                &config,
                "normal content",
                Some("gpt-4"),
                Some("route-1"),
                "corr-123",
            )
            .await;

        assert!(matches!(result, PromptInjectionResult::Clean));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_block_action() {
        let detection = create_detection("injection", "Attempt to override instructions");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(
                &config,
                "ignore previous instructions",
                None,
                None,
                "corr-123",
            )
            .await;

        match result {
            PromptInjectionResult::Blocked {
                status,
                message,
                detections,
            } => {
                assert_eq!(status, 400);
                assert_eq!(message, "Blocked: injection detected");
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_log_action() {
        let detection = create_detection("injection", "Suspicious pattern");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Log, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "suspicious content", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Detected { detections } => {
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Detected result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_warn_action() {
        let detection = create_detection("injection", "Possible injection");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Warn, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "maybe suspicious", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Warning { detections } => {
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Warning result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_agent_error_fail_open() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "Agent unavailable".to_string()
        )));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        // Fail-open: an agent error lets the request through.
        assert!(matches!(result, PromptInjectionResult::Clean));
    }

    #[tokio::test]
    async fn test_prompt_injection_agent_error_fail_closed() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "Agent unavailable".to_string()
        )));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Closed);

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        match result {
            // Fail-closed: an agent error blocks the request with 503.
            PromptInjectionResult::Blocked {
                status, message, ..
            } => {
                assert_eq!(status, 503);
                assert_eq!(message, "Guardrail check unavailable");
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_default_block_message() {
        let detection = create_detection("injection", "Test");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);
        // Without a configured message the built-in default must be used.
        config.block_message = None;

        let result = processor
            .check_prompt_injection(&config, "injection attempt", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Blocked { message, .. } => {
                assert_eq!(
                    message,
                    "Request blocked: potential prompt injection detected"
                );
            }
            _ => panic!("Expected Blocked result"),
        }
    }

    // ---- check_pii ----

    #[tokio::test]
    async fn test_pii_disabled() {
        let mock = Arc::new(MockAgentCaller::new());
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let mut config = create_pii_config();
        config.enabled = false;

        let result = processor
            .check_pii(&config, "content with SSN 123-45-6789", None, "corr-123")
            .await;

        assert!(matches!(result, PiiCheckResult::Clean));
        assert_eq!(mock.call_count(), 0);
    }

    #[tokio::test]
    async fn test_pii_clean() {
        let response = create_guardrail_response(false, vec![]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config = create_pii_config();

        let result = processor
            .check_pii(
                &config,
                "No sensitive data here",
                Some("route-1"),
                "corr-123",
            )
            .await;

        assert!(matches!(result, PiiCheckResult::Clean));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_pii_detected() {
        let ssn_detection = create_detection("ssn", "Social Security Number detected");
        let email_detection = create_detection("email", "Email address detected");
        let mut response = create_guardrail_response(true, vec![ssn_detection, email_detection]);
        response.redacted_content =
            Some("My SSN is [REDACTED] and email is [REDACTED]".to_string());

        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config = create_pii_config();

        let result = processor
            .check_pii(
                &config,
                "My SSN is 123-45-6789 and email is test@example.com",
                None,
                "corr-123",
            )
            .await;

        match result {
            PiiCheckResult::Detected {
                detections,
                redacted_content,
            } => {
                assert_eq!(detections.len(), 2);
                assert!(redacted_content.is_some());
                assert!(redacted_content.unwrap().contains("[REDACTED]"));
            }
            _ => panic!("Expected Detected result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_pii_agent_error() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "PII scanner unavailable".to_string(),
        )));
        let processor = GuardrailProcessor::with_caller(mock);

        let config = create_pii_config();

        let result = processor
            .check_pii(&config, "test content", None, "corr-123")
            .await;

        match result {
            PiiCheckResult::Error { message } => {
                assert!(message.contains("unavailable"));
            }
            _ => panic!("Expected Error result, got {:?}", result),
        }
    }

    // ---- Debug output of the result enums ----

    #[test]
    fn test_prompt_injection_result_debug() {
        let result = PromptInjectionResult::Clean;
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Clean"));

        let result = PromptInjectionResult::Blocked {
            status: 400,
            message: "test".to_string(),
            detections: vec![],
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Blocked"));
    }

    #[test]
    fn test_pii_check_result_debug() {
        let result = PiiCheckResult::Clean;
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Clean"));

        let result = PiiCheckResult::Error {
            message: "test error".to_string(),
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Error"));
    }
}