use std::collections::HashMap;
use async_trait::async_trait;
use serde_json::Value;
use cognis_core::error::Result;
use cognis_core::messages::base::MessageContent;
use cognis_core::messages::{Message, MessageType};
use super::redaction::{
apply_rules, builtin_detectors, get_builtin_detector, Detector, PIIMatch, RedactionStrategy,
ResolvedRedactionRule,
};
use super::types::{AgentMiddleware, AgentState};
/// Agent middleware that detects personally identifiable information (PII)
/// in conversation messages and rewrites it according to a
/// [`RedactionStrategy`].
///
/// Build it with [`PIIMiddleware::new`] (all builtin detectors) or
/// [`PIIMiddleware::for_type`] (a single builtin detector). You can
/// optionally override detection with [`PIIMiddleware::with_detector`].
#[derive(Clone)]
pub struct PIIMiddleware {
    /// Name of the single builtin detector to run (e.g. `"email"`).
    /// `None` runs every builtin detector. Ignored when `detector` is set.
    pub pii_type: Option<String>,
    /// How detected matches are rewritten.
    pub strategy: RedactionStrategy,
    /// Redact the most recent human message in the `before_model` hook.
    pub apply_to_input: bool,
    /// Redact the most recent AI message in the `after_model` hook.
    pub apply_to_output: bool,
    /// Also redact tool messages that follow the most recent AI message in
    /// the `before_model` hook.
    pub apply_to_tool_results: bool,
    /// Custom detector; when set it replaces `pii_type`-based detection.
    pub detector: Option<Detector>,
}
impl PIIMiddleware {
    /// Creates a middleware that runs every builtin detector and rewrites
    /// matches using `strategy`.
    ///
    /// Input and output redaction are enabled; tool-result redaction is
    /// disabled.
    pub fn new(strategy: RedactionStrategy) -> Self {
        Self {
            pii_type: None,
            strategy,
            apply_to_input: true,
            apply_to_output: true,
            apply_to_tool_results: false,
            detector: None,
        }
    }

    /// Creates a middleware that runs only the builtin detector registered
    /// under `pii_type` (for example `"email"`).
    ///
    /// If no builtin detector has that name, nothing is ever detected. The
    /// `apply_to_*` flags default as in [`PIIMiddleware::new`].
    pub fn for_type(pii_type: impl Into<String>, strategy: RedactionStrategy) -> Self {
        Self {
            pii_type: Some(pii_type.into()),
            ..Self::new(strategy)
        }
    }

    /// Installs a custom detector, which takes precedence over `pii_type`
    /// and the builtin detectors.
    #[must_use]
    pub fn with_detector(mut self, detector: Detector) -> Self {
        self.detector = Some(detector);
        self
    }

    /// Finds PII in `text`.
    ///
    /// Precedence:
    /// 1. the custom `detector`, if set;
    /// 2. otherwise the builtin detector named by `pii_type` (an empty
    ///    result if that name is unknown);
    /// 3. otherwise every builtin detector, with results merged and ordered
    ///    by start offset.
    fn detect(&self, text: &str) -> Vec<PIIMatch> {
        if let Some(detector) = &self.detector {
            return detector(text);
        }
        if let Some(pii_type) = &self.pii_type {
            if let Some(detector) = get_builtin_detector(pii_type) {
                return detector(text);
            }
            return Vec::new();
        }
        // NOTE(review): matches from different builtin detectors may overlap;
        // handling of overlapping spans is left to `apply_rules` — confirm it
        // tolerates them.
        let mut all = Vec::new();
        for (_name, detector) in builtin_detectors() {
            all.extend(detector(text));
        }
        all.sort_by_key(|m| m.start);
        all
    }

    /// Returns `text` with every detected match rewritten per `strategy`.
    ///
    /// Text with no matches is returned unchanged. Under
    /// [`RedactionStrategy::Block`], any match discards the whole text and
    /// an empty string is returned.
    fn redact_text(&self, text: &str) -> String {
        let matches = self.detect(text);
        if matches.is_empty() {
            return text.to_string();
        }
        if self.strategy == RedactionStrategy::Block {
            return String::new();
        }
        let rules: Vec<ResolvedRedactionRule> = matches
            .into_iter()
            .map(|pii_match| ResolvedRedactionRule {
                pii_match,
                strategy: self.strategy.clone(),
            })
            .collect();
        // Fail closed: if the rules cannot be applied, emit an empty string
        // rather than falling back to the unredacted input.
        apply_rules(text, rules).unwrap_or_default()
    }

    /// Redacts the text content of `msg`.
    ///
    /// Returns the (possibly rewritten) message and whether anything changed.
    /// Only the content is replaced; id, name, tool calls and other fields
    /// are carried over from a clone of the original.
    fn redact_message(&self, msg: &Message) -> (Message, bool) {
        let text = msg.content().text();
        let redacted = self.redact_text(&text);
        if redacted == text {
            return (msg.clone(), false);
        }
        // NOTE(review): content is replaced with plain `MessageContent::Text`,
        // so any non-text parts not returned by `text()` are dropped. Confirm
        // this is acceptable for multimodal messages.
        let new_content = MessageContent::Text(redacted);
        let mut new_msg = msg.clone();
        match &mut new_msg {
            Message::Human(m) => m.base.content = new_content,
            Message::Ai(m) => m.base.content = new_content,
            Message::System(m) => m.base.content = new_content,
            Message::Tool(m) => m.base.content = new_content,
            Message::Function(m) => m.base.content = new_content,
            Message::Chat(m) => m.base.content = new_content,
            Message::HumanChunk(m) => m.base.content = new_content,
            Message::AiChunk(m) => m.base.content = new_content,
            Message::SystemChunk(m) => m.base.content = new_content,
            Message::ToolChunk(m) => m.base.content = new_content,
            Message::FunctionChunk(m) => m.base.content = new_content,
            Message::ChatChunk(m) => m.base.content = new_content,
            Message::Remove(_) => {}
        }
        (new_msg, true)
    }

    /// Redacts the messages that are about to be sent to the model.
    ///
    /// - When `apply_to_input` is set, the most recent human message is
    ///   redacted.
    /// - When `apply_to_tool_results` is set, every tool message after the
    ///   most recent AI message is redacted. If there is no AI message, this
    ///   covers every tool message.
    ///
    /// The two switches are independent. Returns `None` when nothing changed.
    fn redact_input_messages(&self, messages: &[Message]) -> Option<Vec<Message>> {
        if messages.is_empty() {
            return None;
        }
        let mut new_messages = messages.to_vec();
        let mut changed = false;

        if self.apply_to_input {
            if let Some(last_human_idx) = new_messages
                .iter()
                .rposition(|m| m.message_type() == MessageType::Human)
            {
                let (redacted, did_change) = self.redact_message(&new_messages[last_human_idx]);
                if did_change {
                    new_messages[last_human_idx] = redacted;
                    changed = true;
                }
            }
        }

        if self.apply_to_tool_results {
            // Tool results of the current turn are the tool messages emitted
            // after the AI message that requested them.
            let start_idx = new_messages
                .iter()
                .rposition(|m| m.message_type() == MessageType::Ai)
                .map_or(0, |i| i + 1);
            for msg in new_messages.iter_mut().skip(start_idx) {
                if msg.message_type() == MessageType::Tool {
                    let (redacted, did_change) = self.redact_message(msg);
                    if did_change {
                        *msg = redacted;
                        changed = true;
                    }
                }
            }
        }

        if changed {
            Some(new_messages)
        } else {
            None
        }
    }

    /// Redacts the most recent AI message.
    ///
    /// Returns the full rewritten list, or `None` when there is no AI
    /// message or it contains no PII.
    fn redact_output_messages(&self, messages: &[Message]) -> Option<Vec<Message>> {
        if messages.is_empty() {
            return None;
        }
        let last_ai_idx = messages
            .iter()
            .rposition(|m| m.message_type() == MessageType::Ai)?;
        let (redacted, did_change) = self.redact_message(&messages[last_ai_idx]);
        if !did_change {
            return None;
        }
        let mut new_messages = messages.to_vec();
        new_messages[last_ai_idx] = redacted;
        Some(new_messages)
    }

    /// Wraps a rewritten message list as a `{"messages": ...}` state update.
    ///
    /// # Errors
    /// Propagates any error from serializing the messages to JSON.
    fn messages_update(messages: &[Message]) -> Result<HashMap<String, Value>> {
        let mut updates = HashMap::new();
        updates.insert("messages".to_string(), serde_json::to_value(messages)?);
        Ok(updates)
    }
}

#[async_trait]
impl AgentMiddleware for PIIMiddleware {
    fn name(&self) -> &str {
        "PIIMiddleware"
    }

    /// Redacts user input and/or current tool results before the model call.
    ///
    /// Returns a `messages` state update only when something was redacted.
    async fn before_model(&self, state: &AgentState) -> Result<Option<HashMap<String, Value>>> {
        // Tool-result redaction is configured independently of input
        // redaction, so the hook is a no-op only when both are disabled.
        if !self.apply_to_input && !self.apply_to_tool_results {
            return Ok(None);
        }
        self.redact_input_messages(&state.messages)
            .map(|messages| Self::messages_update(&messages))
            .transpose()
    }

    /// Redacts the model's latest reply after the model call.
    ///
    /// Returns a `messages` state update only when something was redacted.
    async fn after_model(&self, state: &AgentState) -> Result<Option<HashMap<String, Value>>> {
        if !self.apply_to_output {
            return Ok(None);
        }
        self.redact_output_messages(&state.messages)
            .map(|messages| Self::messages_update(&messages))
            .transpose()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_pii_middleware_detect_email() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let matches = mw.detect("Send to user@example.com");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pii_type, "email");
    }

    #[test]
    fn test_pii_middleware_redact_text() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let result = mw.redact_text("Contact user@example.com please");
        assert!(result.contains("[REDACTED_EMAIL]"));
        assert!(!result.contains("user@example.com"));
    }

    #[test]
    fn test_pii_middleware_all_types() {
        let mw = PIIMiddleware::new(RedactionStrategy::Redact);
        let text = "Email: a@b.com, IP: 10.0.0.1";
        let matches = mw.detect(text);
        assert!(matches.len() >= 2);
    }

    #[test]
    fn test_pii_middleware_block_strategy() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Block);
        let result = mw.redact_text("Email: a@b.com");
        assert_eq!(result, "");
    }

    #[test]
    fn test_pii_middleware_no_match() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let result = mw.redact_text("No PII here");
        assert_eq!(result, "No PII here");
    }

    #[test]
    fn test_pii_middleware_empty_message_lists() {
        let mw = PIIMiddleware::new(RedactionStrategy::Redact);
        assert!(mw.redact_input_messages(&[]).is_none());
        assert!(mw.redact_output_messages(&[]).is_none());
    }

    #[test]
    fn test_pii_middleware_redact_input_messages_only_last_human() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let msgs = vec![
            Message::human("My email is first@example.com"),
            Message::ai("Got it"),
            Message::human("My email is test@example.com"),
        ];
        let redacted = mw.redact_input_messages(&msgs);
        assert!(redacted.is_some());
        let redacted = redacted.unwrap();
        let first_text = redacted[0].content().text();
        assert!(first_text.contains("first@example.com"));
        let last_text = redacted[2].content().text();
        assert!(last_text.contains("[REDACTED_EMAIL]"));
    }

    #[test]
    fn test_pii_middleware_redact_output_messages_only_last_ai() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let msgs = vec![
            Message::human("Hi"),
            Message::ai("First reply with first@example.com"),
            Message::human("Another"),
            Message::ai("Second reply with second@example.com"),
        ];
        let redacted = mw.redact_output_messages(&msgs);
        assert!(redacted.is_some());
        let redacted = redacted.unwrap();
        let first_ai = redacted[1].content().text();
        assert!(first_ai.contains("first@example.com"));
        let last_ai = redacted[3].content().text();
        assert!(last_ai.contains("[REDACTED_EMAIL]"));
    }

    #[test]
    fn test_pii_middleware_preserves_metadata() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let mut ai_msg = cognis_core::messages::AIMessage::new("Email is ai@test.com");
        ai_msg.base.id = Some("msg-123".to_string());
        ai_msg.base.name = Some("assistant".to_string());
        ai_msg.tool_calls = vec![cognis_core::messages::tool_types::ToolCall {
            name: "search".into(),
            args: Default::default(),
            id: Some("tc-1".into()),
        }];
        let msgs = vec![Message::Ai(ai_msg)];
        let redacted = mw.redact_output_messages(&msgs);
        assert!(redacted.is_some());
        let redacted = redacted.unwrap();
        if let Message::Ai(ref m) = redacted[0] {
            assert_eq!(m.base.id, Some("msg-123".to_string()));
            assert_eq!(m.base.name, Some("assistant".to_string()));
            assert_eq!(m.tool_calls.len(), 1);
            assert_eq!(m.tool_calls[0].name, "search");
        } else {
            panic!("Expected AI message");
        }
    }

    #[test]
    fn test_pii_middleware_custom_detector() {
        let custom_detector: Detector = Arc::new(|text: &str| {
            let mut matches = Vec::new();
            if let Some(start) = text.find("SECRET") {
                matches.push(PIIMatch {
                    pii_type: "secret".into(),
                    value: "SECRET".into(),
                    start,
                    end: start + 6,
                });
            }
            matches
        });
        let mw = PIIMiddleware::new(RedactionStrategy::Redact).with_detector(custom_detector);
        let result = mw.redact_text("My SECRET code");
        assert_eq!(result, "My [REDACTED_SECRET] code");
    }

    #[test]
    fn test_pii_middleware_name() {
        let mw = PIIMiddleware::new(RedactionStrategy::Mask);
        assert_eq!(mw.name(), "PIIMiddleware");
    }

    #[tokio::test]
    async fn test_pii_middleware_before_model_redacts() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let state = AgentState::new(vec![Message::human("Send to foo@bar.com")]);
        let result = mw.before_model(&state).await.unwrap();
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn test_pii_middleware_before_model_no_pii() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let state = AgentState::new(vec![Message::human("Hello world")]);
        let result = mw.before_model(&state).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_pii_middleware_before_model_disabled() {
        let mut mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        mw.apply_to_input = false;
        let state = AgentState::new(vec![Message::human("Send to foo@bar.com")]);
        let result = mw.before_model(&state).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_pii_middleware_after_model_redacts() {
        let mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        let state = AgentState::new(vec![
            Message::human("Hi"),
            Message::ai("Reach me at bot@example.com"),
        ]);
        let updates = mw
            .after_model(&state)
            .await
            .unwrap()
            .expect("AI reply containing an email should be redacted");
        let serialized = updates["messages"].to_string();
        assert!(serialized.contains("[REDACTED_EMAIL]"));
        assert!(!serialized.contains("bot@example.com"));
    }

    #[tokio::test]
    async fn test_pii_middleware_after_model_disabled() {
        let mut mw = PIIMiddleware::for_type("email", RedactionStrategy::Redact);
        mw.apply_to_output = false;
        let state = AgentState::new(vec![Message::ai("Reach me at bot@example.com")]);
        let result = mw.after_model(&state).await.unwrap();
        assert!(result.is_none());
    }
}