use async_trait::async_trait;
use sha2::{Digest, Sha256};
use super::pii_detector::{PIIDetector, PIIType};
use super::{Middleware, MiddlewareContext, MiddlewareError};
use crate::prompt::PromptArgs;
use crate::schemas::agent::{AgentAction, AgentEvent};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PIIStrategy {
Redact,
Mask,
Hash,
Block,
}
impl PIIStrategy {
pub fn as_str(&self) -> &str {
match self {
PIIStrategy::Redact => "redact",
PIIStrategy::Mask => "mask",
PIIStrategy::Hash => "hash",
PIIStrategy::Block => "block",
}
}
}
pub struct PIIMiddleware {
pii_type: PIIType,
strategy: PIIStrategy,
detector: PIIDetector,
apply_to_input: bool,
apply_to_output: bool,
apply_to_tool_results: bool,
custom_detector: Option<String>,
}
impl PIIMiddleware {
pub fn new(pii_type: PIIType, strategy: PIIStrategy) -> Self {
let detector = PIIDetector::new(pii_type.clone());
Self {
pii_type,
strategy,
detector,
apply_to_input: true,
apply_to_output: false,
apply_to_tool_results: false,
custom_detector: None,
}
}
pub fn with_custom_pattern(
pii_type: PIIType,
strategy: PIIStrategy,
pattern: &str,
) -> Result<Self, regex::Error> {
let detector = PIIDetector::with_custom_pattern(pii_type.clone(), pattern)?;
Ok(Self {
pii_type,
strategy,
detector,
apply_to_input: true,
apply_to_output: false,
apply_to_tool_results: false,
custom_detector: Some(pattern.to_string()),
})
}
pub fn with_apply_to_input(mut self, apply: bool) -> Self {
self.apply_to_input = apply;
self
}
pub fn with_apply_to_output(mut self, apply: bool) -> Self {
self.apply_to_output = apply;
self
}
pub fn with_apply_to_tool_results(mut self, apply: bool) -> Self {
self.apply_to_tool_results = apply;
self
}
fn process_text(&self, text: &str) -> Result<String, MiddlewareError> {
let matches = self.detector.detect(text);
if matches.is_empty() {
return Ok(text.to_string());
}
if matches!(self.strategy, PIIStrategy::Block) {
return Err(MiddlewareError::Aborted(format!(
"PII detected: {} instances of {} found",
matches.len(),
self.pii_type.as_str()
)));
}
let mut result = text.to_string();
for mat in matches.iter().rev() {
let replacement = match self.strategy {
PIIStrategy::Redact => {
format!("[REDACTED_{}]", self.pii_type.as_str())
}
PIIStrategy::Mask => self.mask_pii(&mat.matched_text),
PIIStrategy::Hash => self.hash_pii(&mat.matched_text),
PIIStrategy::Block => unreachable!(), };
result.replace_range(mat.start..mat.end, &replacement);
}
Ok(result)
}
fn mask_pii(&self, pii: &str) -> String {
match self.pii_type {
PIIType::CreditCard => {
let digits: String = pii.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.len() >= 4 {
let last_four = &digits[digits.len() - 4..];
format!("****-****-****-{}", last_four)
} else {
"****-****-****-****".to_string()
}
}
PIIType::Email => {
if let Some(at_pos) = pii.find('@') {
if at_pos > 0 {
let first_char = &pii[..1];
format!("{}***{}", first_char, &pii[at_pos..])
} else {
format!("***{}", &pii[at_pos..])
}
} else {
"***@***".to_string()
}
}
PIIType::IPAddress => {
if let Some(last_dot) = pii.rfind('.') {
if let Some(second_last_dot) = pii[..last_dot].rfind('.') {
format!("{}.***.***", &pii[..=second_last_dot])
} else {
format!("{}.***", &pii[..last_dot])
}
} else {
"***.***.***.***".to_string()
}
}
_ => {
if pii.len() > 2 {
format!("{}***{}", &pii[..1], &pii[pii.len() - 1..])
} else {
"***".to_string()
}
}
}
}
fn hash_pii(&self, pii: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(pii.as_bytes());
let hash = hasher.finalize();
format!("{:x}", hash)
}
fn extract_text_from_input(&self, input: &PromptArgs) -> String {
let mut texts = Vec::new();
if let Some(input_val) = input.get("input") {
if let Some(s) = input_val.as_str() {
texts.push(s.to_string());
}
}
if let Some(messages_val) = input.get("messages") {
if let Some(messages_array) = messages_val.as_array() {
for msg in messages_array {
if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
texts.push(content.to_string());
}
}
}
}
texts.join(" ")
}
}
#[async_trait]
impl Middleware for PIIMiddleware {
async fn before_agent_plan(
&self,
input: &PromptArgs,
_steps: &[(AgentAction, String)],
_context: &mut MiddlewareContext,
) -> Result<Option<PromptArgs>, MiddlewareError> {
if !self.apply_to_input {
return Ok(None);
}
let text = self.extract_text_from_input(input);
let processed = self.process_text(&text)?;
if processed != text {
let mut modified_input = input.clone();
if modified_input.contains_key("input") {
if let Some(input_val) = modified_input.get("input") {
if let Some(original) = input_val.as_str() {
if text.contains(original) {
let new_input = processed
.replace(original, &processed)
.split_whitespace()
.next()
.unwrap_or(&processed)
.to_string();
modified_input
.insert("input".to_string(), serde_json::json!(new_input));
}
}
}
}
if let Some(messages_val) = modified_input.get_mut("messages") {
if let Some(messages_array) = messages_val.as_array_mut() {
for msg in messages_array {
if let Some(content) = msg.get_mut("content") {
if let Some(content_str) = content.as_str() {
let processed_content = self.process_text(content_str)?;
*content = serde_json::json!(processed_content);
}
}
}
}
}
return Ok(Some(modified_input));
}
Ok(None)
}
async fn after_agent_plan(
&self,
_input: &PromptArgs,
event: &AgentEvent,
_context: &mut MiddlewareContext,
) -> Result<Option<AgentEvent>, MiddlewareError> {
if !self.apply_to_output {
return Ok(None);
}
match event {
AgentEvent::Finish(finish) => {
let processed = self.process_text(&finish.output)?;
if processed != finish.output {
let mut modified_finish = finish.clone();
modified_finish.output = processed;
return Ok(Some(AgentEvent::Finish(modified_finish)));
}
}
_ => {}
}
Ok(None)
}
async fn after_tool_call(
&self,
_action: &AgentAction,
observation: &str,
_context: &mut MiddlewareContext,
) -> Result<Option<String>, MiddlewareError> {
if !self.apply_to_tool_results {
return Ok(None);
}
let processed = self.process_text(observation)?;
if processed != observation {
return Ok(Some(processed));
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pii_middleware_creation() {
let middleware = PIIMiddleware::new(PIIType::Email, PIIStrategy::Redact);
assert!(middleware.apply_to_input);
assert!(!middleware.apply_to_output);
}
#[test]
fn test_redact_strategy() {
let middleware = PIIMiddleware::new(PIIType::Email, PIIStrategy::Redact);
let text = "Contact test@example.com";
let result = middleware.process_text(text).unwrap();
assert!(result.contains("[REDACTED_EMAIL]"));
}
#[test]
fn test_mask_strategy() {
let middleware = PIIMiddleware::new(PIIType::Email, PIIStrategy::Mask);
let text = "Contact test@example.com";
let result = middleware.process_text(text).unwrap();
assert!(result.contains("***"));
assert!(result.contains("@example.com"));
}
#[test]
fn test_hash_strategy() {
let middleware = PIIMiddleware::new(PIIType::Email, PIIStrategy::Hash);
let text = "Contact test@example.com";
let result = middleware.process_text(text).unwrap();
assert!(result.len() == 64 || result.contains("Contact"));
}
#[test]
fn test_block_strategy() {
let middleware = PIIMiddleware::new(PIIType::Email, PIIStrategy::Block);
let text = "Contact test@example.com";
let result = middleware.process_text(text);
assert!(result.is_err());
}
}