use super::detector::{PiiDetector, PiiMatch, PiiType};
/// How a `PiiRedactor` renders each detected PII match.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RedactionStyle {
    /// Replace the match with its placeholder text (`PiiMatch::placeholder`).
    Placeholder,
    /// Replace the match with `*` repeated `PiiMatch::len()` times.
    Asterisks,
    /// Replace the match with `X` repeated `PiiMatch::len()` times.
    Xs,
    /// Keep the first two and last two characters and mask the rest with `*`;
    /// matches of four characters or fewer are masked entirely.
    PartialMask,
    /// Currently rendered exactly like [`RedactionStyle::Placeholder`].
    Custom,
}
/// Replaces PII found by a [`PiiDetector`] in text.
///
/// Build one with `new`/`with_detector`, then adjust it with the
/// builder-style `style` and `custom_placeholder` methods.
#[derive(Clone, Debug)]
pub struct PiiRedactor {
    /// Finds the PII spans that get replaced.
    detector: PiiDetector,
    /// How a match is rendered when no per-type placeholder applies.
    style: RedactionStyle,
    /// Per-type replacement strings registered via `custom_placeholder`.
    custom_placeholders: std::collections::HashMap<PiiType, String>,
}
impl Default for PiiRedactor {
fn default() -> Self {
Self::new()
}
}
impl PiiRedactor {
    /// Creates a redactor with the default [`PiiDetector`], the
    /// [`RedactionStyle::Placeholder`] style and no custom placeholders.
    #[must_use]
    pub fn new() -> Self {
        Self::with_detector(PiiDetector::new())
    }

    /// Creates a redactor that uses `detector` instead of the default one.
    ///
    /// Style and custom placeholders start out as in [`PiiRedactor::new`].
    #[must_use]
    pub fn with_detector(detector: PiiDetector) -> Self {
        // Fields are spelled out instead of `..Self::new()`: struct-update
        // syntax would build a default `PiiDetector` only to drop it.
        Self {
            detector,
            style: RedactionStyle::Placeholder,
            custom_placeholders: std::collections::HashMap::new(),
        }
    }

    /// Sets how matches are rendered (builder style).
    #[must_use]
    pub const fn style(mut self, style: RedactionStyle) -> Self {
        self.style = style;
        self
    }

    /// Registers `placeholder` as the replacement for matches of `pii_type`,
    /// replacing any earlier placeholder for the same type.
    ///
    /// It takes precedence over the configured style, but only for matches
    /// whose `is_custom()` is false.
    #[must_use]
    pub fn custom_placeholder(mut self, pii_type: PiiType, placeholder: impl Into<String>) -> Self {
        self.custom_placeholders
            .insert(pii_type, placeholder.into());
        self
    }

    /// Returns a copy of `text` with every detected PII span replaced.
    ///
    /// Matches are applied in order of their start offset. A match lying
    /// entirely inside an already-replaced span is skipped. A match that
    /// overlaps a previous one and extends past it is clamped to start where
    /// the previous replacement ended, so its tail is still redacted rather
    /// than copied through. Any match set therefore produces output without
    /// panicking on an inverted slice range.
    #[must_use]
    pub fn redact(&self, text: &str) -> String {
        let matches = self.detector.detect(text);
        if matches.is_empty() {
            return text.to_string();
        }
        // NOTE(review): the detector may already return sorted, disjoint
        // matches — TODO confirm. A stable sort of references is cheap and is
        // the identity on already-sorted input.
        let mut ordered: Vec<&PiiMatch> = matches.iter().collect();
        ordered.sort_by_key(|m| m.start);

        let mut result = String::with_capacity(text.len());
        let mut last_end = 0;
        for m in ordered {
            if m.start < last_end && m.end <= last_end {
                // Fully covered by a replacement already written.
                continue;
            }
            let start = m.start.max(last_end);
            result.push_str(&text[last_end..start]);
            result.push_str(&self.get_replacement(m));
            last_end = m.end;
        }
        result.push_str(&text[last_end..]);
        result
    }

    /// Byte-oriented variant of [`PiiRedactor::redact`].
    ///
    /// Input is decoded with `String::from_utf8_lossy`, so invalid UTF-8
    /// sequences become U+FFFD in the output even where no PII was found.
    #[must_use]
    pub fn redact_bytes(&self, data: &[u8]) -> Vec<u8> {
        let text = String::from_utf8_lossy(data);
        self.redact(&text).into_bytes()
    }

    /// Builds the replacement text for a single match.
    ///
    /// For matches whose `is_custom()` is false, a placeholder registered via
    /// `custom_placeholder` wins over the configured style.
    fn get_replacement(&self, m: &PiiMatch) -> String {
        if !m.is_custom()
            && let Some(custom) = self.custom_placeholders.get(&m.pii_type)
        {
            return custom.clone();
        }
        match self.style {
            // `Custom` currently renders the same as `Placeholder`.
            RedactionStyle::Placeholder | RedactionStyle::Custom => m.placeholder().to_string(),
            RedactionStyle::Asterisks => "*".repeat(m.len()),
            RedactionStyle::Xs => "X".repeat(m.len()),
            RedactionStyle::PartialMask => self.partial_mask(&m.text),
        }
    }

    /// Masks `text` with `*`, leaving the first and last two chars visible.
    ///
    /// Inputs of four chars or fewer are masked entirely. Works on chars, not
    /// bytes, so multi-byte characters are never split.
    // Reads no redactor state; the `allow` silences clippy's `unused_self`.
    #[allow(clippy::unused_self)]
    fn partial_mask(&self, text: &str) -> String {
        const VISIBLE: usize = 2;
        let chars: Vec<char> = text.chars().collect();
        if chars.len() <= VISIBLE * 2 {
            return "*".repeat(chars.len());
        }
        let hidden = chars.len() - VISIBLE * 2;
        let mut out = String::with_capacity(text.len());
        out.extend(&chars[..VISIBLE]);
        out.push_str(&"*".repeat(hidden));
        out.extend(&chars[chars.len() - VISIBLE..]);
        out
    }

    /// Returns the detector used to find PII.
    #[must_use]
    pub const fn detector(&self) -> &PiiDetector {
        &self.detector
    }
}
/// Redacts PII from text that arrives in chunks.
///
/// Incoming text is held back until a break point is available, so PII that
/// is split across chunk boundaries has a chance to be matched as a whole.
pub struct StreamingRedactor {
    /// Redactor applied to every emitted segment.
    redactor: PiiRedactor,
    /// Text received but not yet emitted.
    buffer: String,
    /// Buffer length, in bytes, at which output is forced.
    max_buffer: usize,
}
impl StreamingRedactor {
#[must_use]
pub const fn new(redactor: PiiRedactor) -> Self {
Self {
redactor,
buffer: String::new(),
max_buffer: 4096,
}
}
#[must_use]
pub const fn max_buffer(mut self, size: usize) -> Self {
self.max_buffer = size;
self
}
pub fn process(&mut self, data: &str) -> String {
self.buffer.push_str(data);
let safe_point = self.find_safe_point();
if safe_point > 0 {
let to_process = self.buffer[..safe_point].to_string();
self.buffer = self.buffer[safe_point..].to_string();
self.redactor.redact(&to_process)
} else {
String::new()
}
}
pub fn flush(&mut self) -> String {
let remaining = std::mem::take(&mut self.buffer);
self.redactor.redact(&remaining)
}
fn find_safe_point(&self) -> usize {
if self.buffer.len() >= self.max_buffer {
return self.max_buffer;
}
if let Some(pos) = self.buffer.rfind('\n') {
return pos + 1;
}
if self.buffer.len() > 100
&& let Some(pos) = self.buffer[..100].rfind(' ')
{
return pos + 1;
}
0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Redacts `input` with a default-configured redactor.
    fn redact_default(input: &str) -> String {
        PiiRedactor::new().redact(input)
    }

    #[test]
    fn redact_ssn() {
        let out = redact_default("My SSN is 123-45-6789");
        assert!(out.contains("[SSN REDACTED]"));
        assert!(!out.contains("123-45-6789"));
    }

    #[test]
    fn redact_email() {
        let out = redact_default("Email: user@example.com");
        assert!(out.contains("[EMAIL REDACTED]"));
    }

    #[test]
    fn redact_asterisks() {
        let out = PiiRedactor::new()
            .style(RedactionStyle::Asterisks)
            .redact("SSN: 123-45-6789");
        assert!(out.contains("***********"));
    }

    #[test]
    fn partial_mask() {
        let out = PiiRedactor::new()
            .style(RedactionStyle::PartialMask)
            .redact("Email: user@example.com");
        assert!(!out.contains("user@example.com"));
    }

    #[test]
    fn custom_placeholder() {
        let out = PiiRedactor::new()
            .custom_placeholder(PiiType::Email, "***EMAIL***")
            .redact("Contact: test@test.com");
        assert!(out.contains("***EMAIL***"));
    }

    #[test]
    fn streaming_redactor() {
        let mut streaming = StreamingRedactor::new(PiiRedactor::new());
        // Array elements are evaluated left to right, so call order is kept.
        let combined = [
            streaming.process("Email: user@"),
            streaming.process("example.com\n"),
            streaming.flush(),
        ]
        .concat();
        assert!(!combined.contains("user@example.com"));
    }

    #[test]
    fn redact_custom_pattern() {
        let detector = PiiDetector::new().add_pattern(
            "employee_id",
            r"EMP-\d{6}",
            "[EMPLOYEE ID REDACTED]",
            0.9,
        );
        let out = PiiRedactor::with_detector(detector).redact("Contact EMP-123456 for assistance");
        assert!(out.contains("[EMPLOYEE ID REDACTED]"));
        assert!(!out.contains("EMP-123456"));
    }

    #[test]
    fn redact_custom_with_builtin() {
        let detector =
            PiiDetector::new().add_pattern("project", r"PROJ-[A-Z]{4}", "[PROJECT]", 0.9);
        let out = PiiRedactor::with_detector(detector).redact("PROJ-ABCD owner: user@example.com");
        assert!(out.contains("[PROJECT]"));
        assert!(out.contains("[EMAIL REDACTED]"));
        assert!(!out.contains("PROJ-ABCD"));
        assert!(!out.contains("user@example.com"));
    }

    #[test]
    fn redact_custom_asterisks() {
        let detector = PiiDetector::custom_only().add_pattern("code", r"CODE-\d{4}", "[CODE]", 0.9);
        let out = PiiRedactor::with_detector(detector)
            .style(RedactionStyle::Asterisks)
            .redact("Use CODE-1234 to access");
        assert!(out.contains("*********"));
        assert!(!out.contains("CODE-1234"));
    }
}