use super::input_processor::{InputProcessor, InputProcessorResult};
use super::PolicyContext;
use async_trait::async_trait;
#[cfg(feature = "guardrails")]
use enact_guardrails::{PiiClass, PiiDetector};
/// Policy applied by the PII input processor when PII is found in an input.
///
/// The default mode is [`PiiInputMode::Warn`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum PiiInputMode {
    /// Let every input through; detection is skipped entirely.
    Allow,
    /// Log a warning when PII is detected, but still let the input through.
    #[default]
    Warn,
    /// Block the input only when the highest detected class is direct PII;
    /// other detections are logged and allowed.
    BlockDirect,
    /// Block the input whenever any PII is detected, regardless of class.
    BlockAll,
}
/// Input processor that applies a [`PiiInputMode`] policy to incoming text.
///
/// When the `guardrails` feature is disabled the struct carries only the
/// mode and no detector.
pub struct PiiInputProcessor {
/// Policy deciding what happens when PII is detected in an input.
mode: PiiInputMode,
/// Detector used to find PII matches; only compiled in with the
/// `guardrails` feature.
#[cfg(feature = "guardrails")]
detector: PiiDetector,
}
impl PiiInputProcessor {
    /// Creates a processor with the default [`PiiInputMode`] and a freshly
    /// constructed `PiiDetector`.
    #[cfg(feature = "guardrails")]
    pub fn new() -> Self {
        let mode = PiiInputMode::default();
        let detector = PiiDetector::new();
        Self { mode, detector }
    }

    /// Creates a processor in [`PiiInputMode::Allow`]; without the
    /// `guardrails` feature there is no detector to run.
    #[cfg(not(feature = "guardrails"))]
    pub fn new() -> Self {
        let mode = PiiInputMode::Allow;
        Self { mode }
    }

    /// Returns the processor with its mode replaced by `mode` (builder style).
    pub fn with_mode(self, mode: PiiInputMode) -> Self {
        let mut configured = self;
        configured.mode = mode;
        configured
    }

    /// Runs the detector over `input`.
    ///
    /// Returns `None` when nothing is detected. Otherwise returns the
    /// greatest `PiiClass` among the matches (per the type's ordering,
    /// starting from `PiiClass::None`) together with the pattern name of
    /// every match, in detection order.
    #[cfg(feature = "guardrails")]
    fn check_pii(&self, input: &str) -> Option<(PiiClass, Vec<String>)> {
        let hits = self.detector.detect(input);
        if hits.is_empty() {
            return None;
        }

        // Single pass: track the running maximum class and gather the names.
        let mut highest = PiiClass::None;
        let mut patterns = Vec::new();
        for hit in hits.iter() {
            highest = highest.max(hit.class);
            patterns.push(hit.pattern_name.clone());
        }
        Some((highest, patterns))
    }

    /// Stand-in used when the `guardrails` feature is off: never reports PII.
    #[cfg(not(feature = "guardrails"))]
    // Nothing calls this in the feature-less build, hence the lint opt-out.
    #[allow(dead_code)]
    fn check_pii(&self, _input: &str) -> Option<((), Vec<String>)> {
        None
    }
}
impl Default for PiiInputProcessor {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl InputProcessor for PiiInputProcessor {
    /// Identifier of this processor; also reported as `processor` in block results.
    fn name(&self) -> &str {
        "pii-input"
    }

    /// Fixed priority value of this processor.
    fn priority(&self) -> u32 {
        50
    }

    /// Checks `input` for PII and applies the configured [`PiiInputMode`].
    ///
    /// * `Allow` passes immediately without running detection.
    /// * When nothing is detected the input passes in every mode.
    /// * `Warn` logs the detection and passes.
    /// * `BlockDirect` blocks when the reported class is `PiiClass::Direct`,
    ///   otherwise logs and passes.
    /// * `BlockAll` blocks on any detection.
    #[cfg(feature = "guardrails")]
    async fn process(
        &self,
        input: &str,
        _ctx: &PolicyContext,
    ) -> anyhow::Result<InputProcessorResult> {
        // Skip the detector entirely when the policy lets everything through.
        if self.mode == PiiInputMode::Allow {
            return Ok(InputProcessorResult::Pass);
        }

        let (class, patterns) = match self.check_pii(input) {
            Some(found) => found,
            None => return Ok(InputProcessorResult::Pass),
        };
        let pattern_list = patterns.join(", ");

        let outcome = match self.mode {
            // Already handled by the guard above; kept for exhaustiveness.
            PiiInputMode::Allow => InputProcessorResult::Pass,
            PiiInputMode::Warn => {
                tracing::warn!(
                    pii_class = ?class,
                    patterns = %pattern_list,
                    "PII detected in input"
                );
                InputProcessorResult::Pass
            }
            // NOTE(review): `class` is the maximum over all matches, so this
            // only fires if `Direct` sorts highest in `PiiClass` — confirm.
            PiiInputMode::BlockDirect if class == PiiClass::Direct => {
                InputProcessorResult::Block {
                    reason: format!("Direct PII detected in input: {}", pattern_list),
                    processor: self.name().to_string(),
                }
            }
            PiiInputMode::BlockDirect => {
                tracing::warn!(
                    pii_class = ?class,
                    patterns = %pattern_list,
                    "Indirect/Sensitive PII detected in input (allowed)"
                );
                InputProcessorResult::Pass
            }
            PiiInputMode::BlockAll => InputProcessorResult::Block {
                reason: format!("PII detected in input ({:?}): {}", class, pattern_list),
                processor: self.name().to_string(),
            },
        };
        Ok(outcome)
    }

    /// Without the `guardrails` feature no detection is available, so every
    /// input passes regardless of the configured mode.
    #[cfg(not(feature = "guardrails"))]
    async fn process(
        &self,
        _input: &str,
        _ctx: &PolicyContext,
    ) -> anyhow::Result<InputProcessorResult> {
        Ok(InputProcessorResult::Pass)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::PolicyAction;
    use std::collections::HashMap;

    /// Builds a bare context: no tenant, no user, empty metadata.
    fn test_context() -> PolicyContext {
        PolicyContext {
            tenant_id: None,
            user_id: None,
            action: PolicyAction::StartExecution { graph_id: None },
            metadata: HashMap::new(),
        }
    }

    /// Feeds `input` through `processor` with a test context, panicking if
    /// processing returns an error.
    async fn run(processor: &PiiInputProcessor, input: &str) -> InputProcessorResult {
        let ctx = test_context();
        processor.process(input, &ctx).await.unwrap()
    }

    #[tokio::test]
    async fn test_pii_input_processor_name() {
        assert_eq!(PiiInputProcessor::new().name(), "pii-input");
    }

    #[tokio::test]
    async fn test_pii_input_processor_priority() {
        assert_eq!(PiiInputProcessor::new().priority(), 50);
    }

    /// `Allow` lets an input containing an email address through.
    #[cfg(feature = "guardrails")]
    #[tokio::test]
    async fn test_pii_input_allow_mode() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::Allow);
        let outcome = run(&processor, "Email: user@example.com").await;
        assert!(outcome.should_proceed());
    }

    /// `Warn` only logs, so the input still proceeds.
    #[cfg(feature = "guardrails")]
    #[tokio::test]
    async fn test_pii_input_warn_mode() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::Warn);
        let outcome = run(&processor, "Email: user@example.com").await;
        assert!(outcome.should_proceed());
    }

    /// `BlockDirect` blocks an email address but passes PII-free text.
    #[cfg(feature = "guardrails")]
    #[tokio::test]
    async fn test_pii_input_block_direct() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::BlockDirect);

        let with_email = run(&processor, "Email: user@example.com").await;
        assert!(with_email.is_blocked());

        let clean = run(&processor, "Hello world").await;
        assert!(clean.should_proceed());
    }

    /// `BlockAll` blocks an input containing an IP address.
    #[cfg(feature = "guardrails")]
    #[tokio::test]
    async fn test_pii_input_block_all() {
        let processor = PiiInputProcessor::new().with_mode(PiiInputMode::BlockAll);
        let outcome = run(&processor, "IP: 192.168.1.1").await;
        assert!(outcome.is_blocked());
    }

    /// PII-free text proceeds under the constructor's default mode.
    #[tokio::test]
    async fn test_pii_input_no_pii() {
        let processor = PiiInputProcessor::new();
        let outcome = run(&processor, "Hello, how can I help?").await;
        assert!(outcome.should_proceed());
    }
}