Skip to main content

grapsus_proxy/inference/
guardrails.rs

1//! Semantic guardrails for inference routes.
2//!
3//! Provides content inspection via external agents:
4//! - Prompt injection detection on requests
5//! - PII detection on responses
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use async_trait::async_trait;
12use pingora_timeout::timeout;
13use tracing::{debug, trace, warn};
14use grapsus_agent_protocol::{
15    Decision, GuardrailDetection, GuardrailInspectEvent, GuardrailInspectionType, GuardrailResponse,
16};
17use grapsus_config::{
18    GuardrailAction, GuardrailFailureMode, PiiDetectionConfig, PromptInjectionConfig,
19};
20
21use crate::agents::AgentManager;
22
23/// Result of a prompt injection check
24#[derive(Debug)]
25pub enum PromptInjectionResult {
26    /// Content is clean (no injection detected)
27    Clean,
28    /// Injection detected, request should be blocked
29    Blocked {
30        status: u16,
31        message: String,
32        detections: Vec<GuardrailDetection>,
33    },
34    /// Injection detected but allowed (logged only)
35    Detected { detections: Vec<GuardrailDetection> },
36    /// Injection detected, add warning header
37    Warning { detections: Vec<GuardrailDetection> },
38    /// Agent error (behavior depends on failure mode)
39    Error { message: String },
40}
41
42/// Result of a PII detection check
43#[derive(Debug)]
44pub enum PiiCheckResult {
45    /// Content is clean (no PII detected)
46    Clean,
47    /// PII detected
48    Detected {
49        detections: Vec<GuardrailDetection>,
50        redacted_content: Option<String>,
51    },
52    /// Agent error
53    Error { message: String },
54}
55
56/// Trait for calling guardrail agents.
57///
58/// This trait allows for mocking agent calls in tests.
59#[async_trait]
60pub trait GuardrailAgentCaller: Send + Sync {
61    /// Call a guardrail agent with an inspection event.
62    async fn call_guardrail_agent(
63        &self,
64        agent_name: &str,
65        event: GuardrailInspectEvent,
66    ) -> Result<GuardrailResponse, String>;
67}
68
69/// Default implementation using the agent manager.
70pub struct AgentManagerCaller {
71    agent_manager: Arc<AgentManager>,
72}
73
74impl AgentManagerCaller {
75    /// Create a new agent manager caller.
76    pub fn new(agent_manager: Arc<AgentManager>) -> Self {
77        Self { agent_manager }
78    }
79}
80
81#[async_trait]
82impl GuardrailAgentCaller for AgentManagerCaller {
83    async fn call_guardrail_agent(
84        &self,
85        agent_name: &str,
86        event: GuardrailInspectEvent,
87    ) -> Result<GuardrailResponse, String> {
88        trace!(
89            agent = agent_name,
90            inspection_type = ?event.inspection_type,
91            "Calling guardrail agent via agent manager"
92        );
93
94        let response = self
95            .agent_manager
96            .call_guardrail_agent(agent_name, event)
97            .await
98            .map_err(|e| format!("Guardrail agent '{}' call failed: {}", agent_name, e))?;
99
100        // Map AgentResponse → GuardrailResponse.
101        //
102        // If the agent puts a serialized GuardrailResponse in
103        // routing_metadata["guardrail_response"], use that directly.
104        // Otherwise, infer from the Decision field.
105        if let Some(raw) = response.routing_metadata.get("guardrail_response") {
106            serde_json::from_str::<GuardrailResponse>(raw).map_err(|e| {
107                format!(
108                    "Failed to parse guardrail response from agent '{}': {}",
109                    agent_name, e
110                )
111            })
112        } else {
113            // Infer from decision
114            match &response.decision {
115                Decision::Allow => Ok(GuardrailResponse::default()),
116                Decision::Block { body, .. } => Ok(GuardrailResponse {
117                    detected: true,
118                    confidence: 1.0,
119                    detections: vec![GuardrailDetection {
120                        category: "agent_block".to_string(),
121                        description: body
122                            .clone()
123                            .unwrap_or_else(|| "Blocked by guardrail agent".to_string()),
124                        severity: grapsus_agent_protocol::DetectionSeverity::High,
125                        confidence: Some(1.0),
126                        span: None,
127                    }],
128                    redacted_content: None,
129                }),
130                _ => Ok(GuardrailResponse::default()),
131            }
132        }
133    }
134}
135
136/// Guardrail processor for semantic content analysis.
137///
138/// Uses external agents to inspect content for security issues
139/// like prompt injection and PII leakage.
140pub struct GuardrailProcessor {
141    agent_caller: Arc<dyn GuardrailAgentCaller>,
142}
143
144impl GuardrailProcessor {
145    /// Create a new guardrail processor with the default agent manager caller.
146    pub fn new(agent_manager: Arc<AgentManager>) -> Self {
147        Self {
148            agent_caller: Arc::new(AgentManagerCaller::new(agent_manager)),
149        }
150    }
151
152    /// Create a new guardrail processor with a custom agent caller.
153    ///
154    /// This is useful for testing with mock implementations.
155    pub fn with_caller(agent_caller: Arc<dyn GuardrailAgentCaller>) -> Self {
156        Self { agent_caller }
157    }
158
159    /// Check request content for prompt injection.
160    ///
161    /// # Arguments
162    /// * `config` - Prompt injection detection configuration
163    /// * `content` - Request body content to inspect
164    /// * `model` - Model name if available
165    /// * `route_id` - Route ID for context
166    /// * `correlation_id` - Request correlation ID
167    pub async fn check_prompt_injection(
168        &self,
169        config: &PromptInjectionConfig,
170        content: &str,
171        model: Option<&str>,
172        route_id: Option<&str>,
173        correlation_id: &str,
174    ) -> PromptInjectionResult {
175        if !config.enabled {
176            return PromptInjectionResult::Clean;
177        }
178
179        trace!(
180            correlation_id = correlation_id,
181            agent = %config.agent,
182            content_len = content.len(),
183            "Checking content for prompt injection"
184        );
185
186        let event = GuardrailInspectEvent {
187            correlation_id: correlation_id.to_string(),
188            inspection_type: GuardrailInspectionType::PromptInjection,
189            content: content.to_string(),
190            model: model.map(String::from),
191            categories: vec![],
192            route_id: route_id.map(String::from),
193            metadata: HashMap::new(),
194        };
195
196        let start = Instant::now();
197        let timeout_duration = Duration::from_millis(config.timeout_ms);
198
199        // Call the agent
200        match timeout(
201            timeout_duration,
202            self.agent_caller.call_guardrail_agent(&config.agent, event),
203        )
204        .await
205        {
206            Ok(Ok(response)) => {
207                let duration = start.elapsed();
208                debug!(
209                    correlation_id = correlation_id,
210                    agent = %config.agent,
211                    detected = response.detected,
212                    confidence = response.confidence,
213                    detection_count = response.detections.len(),
214                    duration_ms = duration.as_millis(),
215                    "Prompt injection check completed"
216                );
217
218                if response.detected {
219                    match config.action {
220                        GuardrailAction::Block => PromptInjectionResult::Blocked {
221                            status: config.block_status,
222                            message: config.block_message.clone().unwrap_or_else(|| {
223                                "Request blocked: potential prompt injection detected".to_string()
224                            }),
225                            detections: response.detections,
226                        },
227                        GuardrailAction::Log => PromptInjectionResult::Detected {
228                            detections: response.detections,
229                        },
230                        GuardrailAction::Warn => PromptInjectionResult::Warning {
231                            detections: response.detections,
232                        },
233                    }
234                } else {
235                    PromptInjectionResult::Clean
236                }
237            }
238            Ok(Err(e)) => {
239                warn!(
240                    correlation_id = correlation_id,
241                    agent = %config.agent,
242                    error = %e,
243                    failure_mode = ?config.failure_mode,
244                    "Prompt injection agent call failed"
245                );
246
247                match config.failure_mode {
248                    GuardrailFailureMode::Open => PromptInjectionResult::Clean,
249                    GuardrailFailureMode::Closed => PromptInjectionResult::Blocked {
250                        status: 503,
251                        message: "Guardrail check unavailable".to_string(),
252                        detections: vec![],
253                    },
254                }
255            }
256            Err(_) => {
257                warn!(
258                    correlation_id = correlation_id,
259                    agent = %config.agent,
260                    timeout_ms = config.timeout_ms,
261                    failure_mode = ?config.failure_mode,
262                    "Prompt injection agent call timed out"
263                );
264
265                match config.failure_mode {
266                    GuardrailFailureMode::Open => PromptInjectionResult::Clean,
267                    GuardrailFailureMode::Closed => PromptInjectionResult::Blocked {
268                        status: 504,
269                        message: "Guardrail check timed out".to_string(),
270                        detections: vec![],
271                    },
272                }
273            }
274        }
275    }
276
277    /// Check response content for PII.
278    ///
279    /// # Arguments
280    /// * `config` - PII detection configuration
281    /// * `content` - Response content to inspect
282    /// * `route_id` - Route ID for context
283    /// * `correlation_id` - Request correlation ID
284    pub async fn check_pii(
285        &self,
286        config: &PiiDetectionConfig,
287        content: &str,
288        route_id: Option<&str>,
289        correlation_id: &str,
290    ) -> PiiCheckResult {
291        if !config.enabled {
292            return PiiCheckResult::Clean;
293        }
294
295        trace!(
296            correlation_id = correlation_id,
297            agent = %config.agent,
298            content_len = content.len(),
299            categories = ?config.categories,
300            "Checking response for PII"
301        );
302
303        let event = GuardrailInspectEvent {
304            correlation_id: correlation_id.to_string(),
305            inspection_type: GuardrailInspectionType::PiiDetection,
306            content: content.to_string(),
307            model: None,
308            categories: config.categories.clone(),
309            route_id: route_id.map(String::from),
310            metadata: HashMap::new(),
311        };
312
313        let start = Instant::now();
314        let timeout_duration = Duration::from_millis(config.timeout_ms);
315
316        match timeout(
317            timeout_duration,
318            self.agent_caller.call_guardrail_agent(&config.agent, event),
319        )
320        .await
321        {
322            Ok(Ok(response)) => {
323                let duration = start.elapsed();
324                debug!(
325                    correlation_id = correlation_id,
326                    agent = %config.agent,
327                    detected = response.detected,
328                    detection_count = response.detections.len(),
329                    duration_ms = duration.as_millis(),
330                    "PII check completed"
331                );
332
333                if response.detected {
334                    PiiCheckResult::Detected {
335                        detections: response.detections,
336                        redacted_content: response.redacted_content,
337                    }
338                } else {
339                    PiiCheckResult::Clean
340                }
341            }
342            Ok(Err(e)) => {
343                warn!(
344                    correlation_id = correlation_id,
345                    agent = %config.agent,
346                    error = %e,
347                    "PII detection agent call failed"
348                );
349
350                PiiCheckResult::Error {
351                    message: e.to_string(),
352                }
353            }
354            Err(_) => {
355                warn!(
356                    correlation_id = correlation_id,
357                    agent = %config.agent,
358                    timeout_ms = config.timeout_ms,
359                    "PII detection agent call timed out"
360                );
361
362                PiiCheckResult::Error {
363                    message: "Agent timeout".to_string(),
364                }
365            }
366        }
367    }
368}
369
370/// Extract message content from an inference request body.
371///
372/// Attempts to parse the body as JSON and extract message content
373/// from common inference API formats (OpenAI, Anthropic, etc.)
374pub fn extract_inference_content(body: &[u8]) -> Option<String> {
375    let json: serde_json::Value = serde_json::from_slice(body).ok()?;
376
377    // OpenAI format: {"messages": [{"content": "..."}]}
378    if let Some(messages) = json.get("messages").and_then(|m| m.as_array()) {
379        let content: Vec<String> = messages
380            .iter()
381            .filter_map(|msg| msg.get("content").and_then(|c| c.as_str()))
382            .map(String::from)
383            .collect();
384        if !content.is_empty() {
385            return Some(content.join("\n"));
386        }
387    }
388
389    // Anthropic format: {"prompt": "..."}
390    if let Some(prompt) = json.get("prompt").and_then(|p| p.as_str()) {
391        return Some(prompt.to_string());
392    }
393
394    // Generic: look for common content fields
395    for field in &["input", "text", "query", "question"] {
396        if let Some(value) = json.get(*field).and_then(|v| v.as_str()) {
397            return Some(value.to_string());
398        }
399    }
400
401    None
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use grapsus_agent_protocol::{DetectionSeverity, TextSpan};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    // ==================== Mock Agent Caller ====================

    /// How [`MockAgentCaller`] answers each call.
    enum MockBehavior {
        /// Return a clone of this canned result.
        Respond(Result<GuardrailResponse, String>),
        /// Never answer, so the processor's timeout fires.
        Hang,
        /// No response configured: every call fails.
        Unconfigured,
    }

    /// Mock agent caller for testing the guardrail processor.
    ///
    /// The behavior is fixed at construction and needs no lock. Only the call
    /// counter and the captured event change during calls.
    struct MockAgentCaller {
        behavior: MockBehavior,
        call_count: AtomicUsize,
        last_event: Mutex<Option<GuardrailInspectEvent>>,
    }

    impl MockAgentCaller {
        fn with_behavior(behavior: MockBehavior) -> Self {
            Self {
                behavior,
                call_count: AtomicUsize::new(0),
                last_event: Mutex::new(None),
            }
        }

        /// A mock with no canned response; every call returns an error.
        fn new() -> Self {
            Self::with_behavior(MockBehavior::Unconfigured)
        }

        fn with_response(response: Result<GuardrailResponse, String>) -> Self {
            Self::with_behavior(MockBehavior::Respond(response))
        }

        /// A mock whose calls never complete.
        fn hanging() -> Self {
            Self::with_behavior(MockBehavior::Hang)
        }

        fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Takes the most recent event passed to the mock, if any.
        fn take_last_event(&self) -> Option<GuardrailInspectEvent> {
            self.last_event.lock().unwrap().take()
        }
    }

    #[async_trait]
    impl GuardrailAgentCaller for MockAgentCaller {
        async fn call_guardrail_agent(
            &self,
            _agent_name: &str,
            event: GuardrailInspectEvent,
        ) -> Result<GuardrailResponse, String> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            {
                // Scoped so the std guard is released before any await.
                let mut slot = self.last_event.lock().unwrap();
                *slot = Some(event);
            }

            match &self.behavior {
                MockBehavior::Respond(response) => response.clone(),
                MockBehavior::Hang => std::future::pending().await,
                MockBehavior::Unconfigured => Err("No mock response configured".to_string()),
            }
        }
    }

    // ==================== Test Helpers ====================

    fn create_prompt_injection_config(
        action: GuardrailAction,
        failure_mode: GuardrailFailureMode,
    ) -> PromptInjectionConfig {
        PromptInjectionConfig {
            enabled: true,
            agent: "test-agent".to_string(),
            action,
            block_status: 400,
            block_message: Some("Blocked: injection detected".to_string()),
            timeout_ms: 5000,
            failure_mode,
        }
    }

    fn create_pii_config() -> PiiDetectionConfig {
        PiiDetectionConfig {
            enabled: true,
            agent: "pii-scanner".to_string(),
            action: grapsus_config::PiiAction::Log,
            categories: vec!["ssn".to_string(), "email".to_string()],
            timeout_ms: 5000,
            failure_mode: GuardrailFailureMode::Open,
        }
    }

    fn create_detection(category: &str, description: &str) -> GuardrailDetection {
        GuardrailDetection {
            category: category.to_string(),
            description: description.to_string(),
            severity: DetectionSeverity::High,
            confidence: Some(0.95),
            span: Some(TextSpan { start: 0, end: 10 }),
        }
    }

    fn create_guardrail_response(
        detected: bool,
        detections: Vec<GuardrailDetection>,
    ) -> GuardrailResponse {
        GuardrailResponse {
            detected,
            confidence: if detected { 0.95 } else { 0.0 },
            detections,
            redacted_content: None,
        }
    }

    // ==================== extract_inference_content Tests ====================

    #[test]
    fn test_extract_openai_content() {
        let body = br#"{"messages": [{"role": "user", "content": "Hello world"}]}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Hello world".to_string()));
    }

    #[test]
    fn test_extract_openai_multi_message() {
        let body = br#"{
            "messages": [
                {"role": "system", "content": "You are helpful"},
                {"role": "user", "content": "Hello"}
            ]
        }"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("You are helpful\nHello".to_string()));
    }

    #[test]
    fn test_extract_anthropic_content() {
        let body = br#"{"prompt": "Human: Hello\n\nAssistant:"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Human: Hello\n\nAssistant:".to_string()));
    }

    #[test]
    fn test_extract_generic_input() {
        let body = br#"{"input": "Test query"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Test query".to_string()));
    }

    #[test]
    fn test_extract_generic_text() {
        let body = br#"{"text": "Some text content"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Some text content".to_string()));
    }

    #[test]
    fn test_extract_generic_query() {
        let body = br#"{"query": "What is the weather?"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("What is the weather?".to_string()));
    }

    #[test]
    fn test_extract_generic_question() {
        let body = br#"{"question": "How does this work?"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("How does this work?".to_string()));
    }

    #[test]
    fn test_extract_invalid_json() {
        let body = b"not json";
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_empty_messages() {
        let body = br#"{"messages": []}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_messages_without_content() {
        let body = br#"{"messages": [{"role": "user"}]}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_empty_object() {
        let body = br#"{}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_nested_content() {
        // Messages with mixed content types (some with content, some without)
        let body = br#"{
            "messages": [
                {"role": "system"},
                {"role": "user", "content": "Valid content"},
                {"role": "assistant"}
            ]
        }"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Valid content".to_string()));
    }

    // ==================== Prompt Injection Tests ====================

    #[tokio::test]
    async fn test_prompt_injection_disabled() {
        let mock = Arc::new(MockAgentCaller::new());
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);
        config.enabled = false;

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        assert!(matches!(result, PromptInjectionResult::Clean));
        assert_eq!(mock.call_count(), 0); // Agent should not be called
    }

    #[tokio::test]
    async fn test_prompt_injection_clean() {
        let response = create_guardrail_response(false, vec![]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(
                &config,
                "normal content",
                Some("gpt-4"),
                Some("route-1"),
                "corr-123",
            )
            .await;

        assert!(matches!(result, PromptInjectionResult::Clean));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_prompt_injection_event_contents() {
        let response = create_guardrail_response(false, vec![]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let _ = processor
            .check_prompt_injection(
                &config,
                "inspect me",
                Some("gpt-4"),
                Some("route-1"),
                "corr-456",
            )
            .await;

        let event = mock
            .take_last_event()
            .expect("agent should have been called");
        assert!(matches!(
            event.inspection_type,
            GuardrailInspectionType::PromptInjection
        ));
        assert_eq!(event.correlation_id, "corr-456");
        assert_eq!(event.content, "inspect me");
        assert_eq!(event.model.as_deref(), Some("gpt-4"));
        assert_eq!(event.route_id.as_deref(), Some("route-1"));
        assert!(event.categories.is_empty());
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_block_action() {
        let detection = create_detection("injection", "Attempt to override instructions");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(
                &config,
                "ignore previous instructions",
                None,
                None,
                "corr-123",
            )
            .await;

        match result {
            PromptInjectionResult::Blocked {
                status,
                message,
                detections,
            } => {
                assert_eq!(status, 400);
                assert_eq!(message, "Blocked: injection detected");
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_log_action() {
        let detection = create_detection("injection", "Suspicious pattern");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Log, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "suspicious content", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Detected { detections } => {
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Detected result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_warn_action() {
        let detection = create_detection("injection", "Possible injection");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Warn, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "maybe suspicious", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Warning { detections } => {
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Warning result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_agent_error_fail_open() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "Agent unavailable".to_string()
        )));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        // Fail-open: allow the request despite agent error
        assert!(matches!(result, PromptInjectionResult::Clean));
    }

    #[tokio::test]
    async fn test_prompt_injection_agent_error_fail_closed() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "Agent unavailable".to_string()
        )));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Closed);

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        // Fail-closed: block the request on agent error
        match result {
            PromptInjectionResult::Blocked {
                status, message, ..
            } => {
                assert_eq!(status, 503);
                assert_eq!(message, "Guardrail check unavailable");
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_timeout_fail_open() {
        let mock = Arc::new(MockAgentCaller::hanging());
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);
        config.timeout_ms = 10;

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        // Fail-open: a timed-out agent lets the request through
        assert!(matches!(result, PromptInjectionResult::Clean));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_prompt_injection_timeout_fail_closed() {
        let mock = Arc::new(MockAgentCaller::hanging());
        let processor = GuardrailProcessor::with_caller(mock);

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Closed);
        config.timeout_ms = 10;

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        // Fail-closed: a timed-out agent blocks the request with 504
        match result {
            PromptInjectionResult::Blocked {
                status,
                message,
                detections,
            } => {
                assert_eq!(status, 504);
                assert_eq!(message, "Guardrail check timed out");
                assert!(detections.is_empty());
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_default_block_message() {
        let detection = create_detection("injection", "Test");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);
        config.block_message = None; // Use default message

        let result = processor
            .check_prompt_injection(&config, "injection attempt", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Blocked { message, .. } => {
                assert_eq!(
                    message,
                    "Request blocked: potential prompt injection detected"
                );
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    // ==================== PII Detection Tests ====================

    #[tokio::test]
    async fn test_pii_disabled() {
        let mock = Arc::new(MockAgentCaller::new());
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let mut config = create_pii_config();
        config.enabled = false;

        let result = processor
            .check_pii(&config, "content with SSN 123-45-6789", None, "corr-123")
            .await;

        assert!(matches!(result, PiiCheckResult::Clean));
        assert_eq!(mock.call_count(), 0);
    }

    #[tokio::test]
    async fn test_pii_clean() {
        let response = create_guardrail_response(false, vec![]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config = create_pii_config();

        let result = processor
            .check_pii(
                &config,
                "No sensitive data here",
                Some("route-1"),
                "corr-123",
            )
            .await;

        assert!(matches!(result, PiiCheckResult::Clean));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_pii_event_contents() {
        let response = create_guardrail_response(false, vec![]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config = create_pii_config();

        let _ = processor
            .check_pii(&config, "response body", Some("route-9"), "corr-789")
            .await;

        let event = mock
            .take_last_event()
            .expect("agent should have been called");
        assert!(matches!(
            event.inspection_type,
            GuardrailInspectionType::PiiDetection
        ));
        assert_eq!(event.correlation_id, "corr-789");
        assert_eq!(event.content, "response body");
        assert_eq!(event.model, None);
        assert_eq!(event.route_id.as_deref(), Some("route-9"));
        assert_eq!(
            event.categories,
            vec!["ssn".to_string(), "email".to_string()]
        );
    }

    #[tokio::test]
    async fn test_pii_detected() {
        let ssn_detection = create_detection("ssn", "Social Security Number detected");
        let email_detection = create_detection("email", "Email address detected");
        let mut response = create_guardrail_response(true, vec![ssn_detection, email_detection]);
        response.redacted_content =
            Some("My SSN is [REDACTED] and email is [REDACTED]".to_string());

        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config = create_pii_config();

        let result = processor
            .check_pii(
                &config,
                "My SSN is 123-45-6789 and email is test@example.com",
                None,
                "corr-123",
            )
            .await;

        match result {
            PiiCheckResult::Detected {
                detections,
                redacted_content,
            } => {
                assert_eq!(detections.len(), 2);
                assert!(redacted_content.is_some());
                assert!(redacted_content.unwrap().contains("[REDACTED]"));
            }
            _ => panic!("Expected Detected result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_pii_agent_error() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "PII scanner unavailable".to_string(),
        )));
        let processor = GuardrailProcessor::with_caller(mock);

        let config = create_pii_config();

        let result = processor
            .check_pii(&config, "test content", None, "corr-123")
            .await;

        match result {
            PiiCheckResult::Error { message } => {
                assert!(message.contains("unavailable"));
            }
            _ => panic!("Expected Error result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_pii_agent_timeout() {
        let mock = Arc::new(MockAgentCaller::hanging());
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let mut config = create_pii_config();
        config.timeout_ms = 10;

        let result = processor
            .check_pii(&config, "test content", None, "corr-123")
            .await;

        match result {
            PiiCheckResult::Error { message } => {
                assert_eq!(message, "Agent timeout");
            }
            _ => panic!("Expected Error result, got {:?}", result),
        }
        assert_eq!(mock.call_count(), 1);
    }

    // ==================== Result Type Tests ====================

    #[test]
    fn test_prompt_injection_result_debug() {
        let result = PromptInjectionResult::Clean;
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Clean"));

        let result = PromptInjectionResult::Blocked {
            status: 400,
            message: "test".to_string(),
            detections: vec![],
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Blocked"));
    }

    #[test]
    fn test_pii_check_result_debug() {
        let result = PiiCheckResult::Clean;
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Clean"));

        let result = PiiCheckResult::Error {
            message: "test error".to_string(),
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Error"));
    }
}