enact-core 0.0.1

Core agent runtime for Enact - Graph-Native AI agents
Documentation
//! PII Input Processor
//!
//! Detects PII in user input before execution.
//! Can block, warn, or allow based on configuration.

use super::input_processor::{InputProcessor, InputProcessorResult};
use super::PolicyContext;
use async_trait::async_trait;

#[cfg(feature = "guardrails")]
use enact_guardrails::{PiiClass, PiiDetector};

/// How a [`PiiInputProcessor`] reacts when PII is found in user input.
///
/// The default mode is [`PiiInputMode::Warn`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PiiInputMode {
    /// Skip PII checking entirely; every input is allowed.
    Allow,
    /// Log a warning when PII is detected, but let the input through.
    Warn,
    /// Block input when Direct PII (email, SSN, etc.) is detected.
    BlockDirect,
    /// Block input when PII of any class is detected.
    BlockAll,
}

impl Default for PiiInputMode {
    /// Returns [`PiiInputMode::Warn`].
    fn default() -> Self {
        PiiInputMode::Warn
    }
}

/// PII Input Processor
///
/// Detects PII in user input before sending to LLM.
/// Behavior is configurable via `PiiInputMode`.
///
/// When the `guardrails` feature is disabled there is no detector, and the
/// processor lets every input through regardless of the configured mode.
pub struct PiiInputProcessor {
    /// How detected PII is handled (allow, warn, or block).
    mode: PiiInputMode,
    /// Detector from `enact_guardrails`; only present when the `guardrails`
    /// feature is enabled.
    #[cfg(feature = "guardrails")]
    detector: PiiDetector,
}

impl PiiInputProcessor {
    /// Create a new PII input processor with the default mode (`Warn`).
    #[cfg(feature = "guardrails")]
    #[must_use]
    pub fn new() -> Self {
        Self {
            mode: PiiInputMode::default(),
            detector: PiiDetector::new(),
        }
    }

    /// Create a new PII input processor (no-op when guardrails disabled).
    ///
    /// Without the `guardrails` feature there is no detector, so the mode is
    /// `Allow` and every input passes.
    #[cfg(not(feature = "guardrails"))]
    #[must_use]
    pub fn new() -> Self {
        Self {
            mode: PiiInputMode::Allow,
        }
    }

    /// Set the detection mode (builder style; returns the updated processor).
    ///
    /// Without the `guardrails` feature the mode is stored but has no effect:
    /// input is never inspected.
    #[must_use]
    pub fn with_mode(mut self, mode: PiiInputMode) -> Self {
        self.mode = mode;
        self
    }

    /// Check input for PII.
    ///
    /// Returns `None` when the detector finds nothing. Otherwise it returns the
    /// highest classification among all matches, together with the names of
    /// the patterns that matched. Each name appears once, in first-match order.
    #[cfg(feature = "guardrails")]
    fn check_pii(&self, input: &str) -> Option<(PiiClass, Vec<String>)> {
        let matches = self.detector.detect(input);
        if matches.is_empty() {
            return None;
        }

        // Get highest classification from matches.
        let highest = matches
            .iter()
            .fold(PiiClass::None, |acc, m| acc.max(m.class));

        // Collect pattern names for reporting. Duplicates are dropped so that
        // several hits of one pattern do not produce "email, email, ..." in
        // block reasons and warnings. Match counts are small, so a linear
        // `contains` check is sufficient.
        let mut patterns: Vec<String> = Vec::new();
        for m in matches.iter() {
            if !patterns.contains(&m.pattern_name) {
                patterns.push(m.pattern_name.clone());
            }
        }

        Some((highest, patterns))
    }

    /// Check input for PII (no-op when guardrails disabled).
    ///
    /// Always returns `None`.
    #[cfg(not(feature = "guardrails"))]
    // Nothing calls this in the non-guardrails build (`process` passes input
    // straight through). It is kept so both feature configurations expose
    // the same private helper.
    #[allow(dead_code)]
    fn check_pii(&self, _input: &str) -> Option<((), Vec<String>)> {
        None
    }
}

impl Default for PiiInputProcessor {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl InputProcessor for PiiInputProcessor {
    /// Identifier of this processor; also reported in the `processor` field of
    /// `Block` results.
    fn name(&self) -> &str {
        "pii-input"
    }

    /// Pipeline ordering hint; `50` is meant to run this check early.
    fn priority(&self) -> u32 {
        50
    }

    /// Inspect `input` for PII and decide, per `self.mode`, whether it may proceed.
    ///
    /// - `Allow`: detection is skipped and the input passes.
    /// - `Warn`: a warning is logged and the input passes.
    /// - `BlockDirect`: blocks when the highest detected class is
    ///   `PiiClass::Direct`; otherwise logs a warning and passes.
    /// - `BlockAll`: any detection blocks.
    #[cfg(feature = "guardrails")]
    async fn process(
        &self,
        input: &str,
        _ctx: &PolicyContext,
    ) -> anyhow::Result<InputProcessorResult> {
        // `Allow` never needs the detector, so avoid running it at all.
        if self.mode == PiiInputMode::Allow {
            return Ok(InputProcessorResult::Pass);
        }

        let (class, patterns) = match self.check_pii(input) {
            Some(found) => found,
            None => return Ok(InputProcessorResult::Pass),
        };
        let pattern_list = patterns.join(", ");

        // Builds a `Block` result attributed to this processor.
        let blocked = |reason: String| InputProcessorResult::Block {
            reason,
            processor: self.name().to_string(),
        };

        let outcome = match self.mode {
            // Filtered out by the early return above; kept for exhaustiveness.
            PiiInputMode::Allow => InputProcessorResult::Pass,
            PiiInputMode::Warn => {
                tracing::warn!(
                    pii_class = ?class,
                    patterns = %pattern_list,
                    "PII detected in input"
                );
                InputProcessorResult::Pass
            }
            PiiInputMode::BlockDirect if class == PiiClass::Direct => {
                blocked(format!("Direct PII detected in input: {}", pattern_list))
            }
            PiiInputMode::BlockDirect => {
                tracing::warn!(
                    pii_class = ?class,
                    patterns = %pattern_list,
                    "Indirect/Sensitive PII detected in input (allowed)"
                );
                InputProcessorResult::Pass
            }
            PiiInputMode::BlockAll => blocked(format!(
                "PII detected in input ({:?}): {}",
                class, pattern_list
            )),
        };

        Ok(outcome)
    }

    /// Without the `guardrails` feature there is no detector, so every input
    /// passes regardless of the configured mode.
    #[cfg(not(feature = "guardrails"))]
    async fn process(
        &self,
        _input: &str,
        _ctx: &PolicyContext,
    ) -> anyhow::Result<InputProcessorResult> {
        Ok(InputProcessorResult::Pass)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::PolicyAction;
    use std::collections::HashMap;

    /// Minimal policy context; the PII processor does not read it.
    fn test_context() -> PolicyContext {
        PolicyContext {
            tenant_id: None,
            user_id: None,
            action: PolicyAction::StartExecution { graph_id: None },
            metadata: HashMap::new(),
        }
    }

    // `name` and `priority` are synchronous, so no async runtime is needed.
    #[test]
    fn test_pii_input_processor_name() {
        let processor = PiiInputProcessor::new();
        assert_eq!(processor.name(), "pii-input");
    }

    #[test]
    fn test_pii_input_processor_priority() {
        let processor = PiiInputProcessor::new();
        assert_eq!(processor.priority(), 50);
    }

    #[test]
    fn test_pii_input_mode_default_is_warn() {
        assert_eq!(PiiInputMode::default(), PiiInputMode::Warn);
    }

    #[cfg(feature = "guardrails")]
    #[tokio::test]
    async fn test_pii_input_allow_mode() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::Allow);
        let ctx = test_context();

        // Even with PII, should pass
        let result = processor
            .process("Email: user@example.com", &ctx)
            .await
            .unwrap();
        assert!(result.should_proceed());
    }

    #[cfg(feature = "guardrails")]
    #[tokio::test]
    async fn test_pii_input_warn_mode() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::Warn);
        let ctx = test_context();

        // With PII, should warn but pass
        let result = processor
            .process("Email: user@example.com", &ctx)
            .await
            .unwrap();
        assert!(result.should_proceed());
    }

    #[cfg(feature = "guardrails")]
    #[tokio::test]
    async fn test_pii_input_block_direct() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::BlockDirect);
        let ctx = test_context();

        // Direct PII (email) should block
        let result = processor
            .process("Email: user@example.com", &ctx)
            .await
            .unwrap();
        assert!(result.is_blocked());

        // No PII should pass
        let result = processor.process("Hello world", &ctx).await.unwrap();
        assert!(result.should_proceed());
    }

    #[cfg(feature = "guardrails")]
    #[tokio::test]
    async fn test_pii_input_block_all() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::BlockAll);
        let ctx = test_context();

        // Any PII should block
        let result = processor.process("IP: 192.168.1.1", &ctx).await.unwrap();
        assert!(result.is_blocked());
    }

    // Without the detector, even a blocking mode lets PII through (fail-open).
    // This pins that behavior so a change to it is deliberate.
    #[cfg(not(feature = "guardrails"))]
    #[tokio::test]
    async fn test_pii_input_without_guardrails_always_passes() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::BlockAll);
        let ctx = test_context();

        let result = processor
            .process("Email: user@example.com", &ctx)
            .await
            .unwrap();
        assert!(result.should_proceed());
    }

    #[tokio::test]
    async fn test_pii_input_no_pii() {
        let processor = PiiInputProcessor::new();
        let ctx = test_context();

        // No PII should always pass
        let result = processor
            .process("Hello, how can I help?", &ctx)
            .await
            .unwrap();
        assert!(result.should_proceed());
    }
}