use crate::core::models::openai::*;
use crate::utils::error::Result;
use regex::Regex;
use std::collections::HashMap;
use tracing::warn;
/// Inspects and sanitizes request content: PII scrubbing, keyword-based
/// moderation, profanity masking, and user-supplied custom filters.
pub struct ContentFilter {
/// Regex detectors for personally identifiable information.
pii_patterns: Vec<PIIPattern>,
/// Moderation rules evaluated against (already PII-scrubbed) text.
moderation_rules: Vec<ModerationRule>,
/// Word-list based profanity detector/masker.
profanity_filter: ProfanityFilter,
// NOTE(review): custom_filters is stored but not consulted by the
// filtering methods visible in this file — confirm intended usage.
custom_filters: Vec<CustomFilter>,
}
/// A single PII detector: a named regex plus the rewrite strategy for matches.
#[derive(Debug, Clone)]
pub struct PIIPattern {
/// Label (e.g. "SSN", "Email") embedded in reported issue types.
pub name: String,
/// Pattern that identifies this kind of PII.
pub pattern: Regex,
/// How matched text is rewritten.
pub replacement: PIIReplacement,
/// Detector confidence; copied verbatim into each reported issue.
pub confidence: f64,
}
/// How matched PII text is rewritten by `apply_pii_replacement`.
#[derive(Debug, Clone)]
pub enum PIIReplacement {
/// Replace each match with "***".
Redact,
/// Replace each match with a fixed placeholder string.
Placeholder(String),
/// Replace each match with the marker "[HASHED]" (no digest is computed).
Hash,
/// Delete the match entirely.
Remove,
/// Keep the first `keep_start` and last `keep_end` characters of the
/// match and mask everything in between with '*'.
PartialMask {
keep_start: usize,
keep_end: usize,
},
}
/// A moderation check: what category to look for, what to do on a hit,
/// and how severe a hit is considered.
#[derive(Debug, Clone)]
pub struct ModerationRule {
/// Human-readable rule name.
pub name: String,
/// Category this rule detects.
pub rule_type: ModerationType,
/// Action taken when the rule matches.
pub action: ModerationAction,
/// Severity recorded on the resulting ContentIssue.
pub severity: ModerationSeverity,
}
/// Moderation categories. Only HateSpeech and Violence have detection
/// logic in this file; the rest currently never match.
#[derive(Debug, Clone)]
pub enum ModerationType {
HateSpeech,
Violence,
Sexual,
SelfHarm,
Harassment,
IllegalActivity,
/// User-defined category identified by name.
Custom(String),
}
/// What to do when a moderation rule matches. In this file only Block
/// (sets `blocked`) and Warn (emits a tracing warning) are acted upon.
#[derive(Debug, Clone)]
pub enum ModerationAction {
Block,
Warn,
Log,
Modify,
HumanReview,
}
/// Issue severity. The derived `PartialOrd` follows declaration order,
/// so `Low < Medium < High < Critical`.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum ModerationSeverity {
Low,
Medium,
High,
Critical,
}
/// Word-list based profanity detector and masker.
#[derive(Debug, Clone)]
pub struct ProfanityFilter {
/// Lowercase words to detect (substring match).
blocked_words: Vec<String>,
/// Character used to mask matched words (one per character).
replacement_char: char,
// NOTE(review): fuzzy_matching is set in `new()` but never read by the
// methods in this file — confirm whether fuzzy matching is planned.
fuzzy_matching: bool,
}
/// A user-supplied regex filter with an associated moderation action.
/// NOTE(review): not yet applied by `ContentFilter::filter_text` in this
/// file — verify where custom filters are meant to run.
#[derive(Debug, Clone)]
pub struct CustomFilter {
pub name: String,
pub pattern: Regex,
pub action: ModerationAction,
}
/// Outcome of filtering a piece of content.
#[derive(Debug, Clone)]
pub struct FilterResult {
/// True when a Block-action moderation rule matched.
pub blocked: bool,
/// Every PII / moderation / profanity finding.
pub issues: Vec<ContentIssue>,
/// Sanitized text when anything changed; `None` when untouched. For
/// whole-request filtering this is a marker string, since messages are
/// rewritten in place.
pub modified_content: Option<String>,
/// Mean confidence across issues; 1.0 when there are none.
pub confidence: f64,
}
/// A single finding produced during filtering.
#[derive(Debug, Clone)]
pub struct ContentIssue {
/// Machine-readable tag, e.g. "PII_SSN", "MODERATION_HateSpeech", "PROFANITY".
pub issue_type: String,
/// Human-readable description of the finding.
pub description: String,
pub severity: ModerationSeverity,
/// Byte range (start, end) of the match in the original text, when known.
pub location: Option<(usize, usize)>,
/// Confidence in [0, 1] for this particular finding.
pub confidence: f64,
}
/// Container for GDPR tooling: retention policies keyed by data type,
/// consent records, and data-export helpers.
/// NOTE(review): no methods for this type are visible in this file —
/// behavior is defined elsewhere (or not yet implemented).
pub struct GDPRCompliance {
retention_policies: HashMap<String, RetentionPolicy>,
consent_manager: ConsentManager,
export_tools: DataExportTools,
}
/// How long a category of data is kept and what happens afterwards.
#[derive(Debug, Clone)]
pub struct RetentionPolicy {
/// Category this policy applies to.
pub data_type: String,
/// Retention window in days.
pub retention_days: u32,
/// Whether data is deleted automatically when the window elapses.
pub auto_delete: bool,
/// Optional anonymization applied instead of (or alongside) deletion.
pub anonymization: Option<AnonymizationRule>,
}
/// Per-user consent records, keyed by an identifier (presumably the
/// user id — no accessor is visible in this file to confirm).
#[derive(Debug, Clone)]
pub struct ConsentManager {
consents: HashMap<String, UserConsent>,
}
/// A user's recorded consent decision.
#[derive(Debug, Clone)]
pub struct UserConsent {
pub user_id: String,
/// Whether consent was granted.
pub consented: bool,
/// When the decision was recorded (UTC).
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Version of the consent terms the user agreed to.
pub version: String,
/// Specific permissions covered by this consent.
pub permissions: Vec<String>,
}
/// Export formats supported for user-data export requests.
#[derive(Debug, Clone)]
pub struct DataExportTools {
formats: Vec<ExportFormat>,
}
/// Supported serialization formats for data export.
#[derive(Debug, Clone)]
pub enum ExportFormat {
JSON,
CSV,
XML,
PDF,
}
/// Which fields to anonymize and the method to use.
#[derive(Debug, Clone)]
pub struct AnonymizationRule {
/// Field names to anonymize.
pub fields: Vec<String>,
/// Technique applied to those fields.
pub method: AnonymizationMethod,
}
/// Anonymization techniques. No implementation is visible in this file;
/// variant semantics are defined by the consumer.
#[derive(Debug, Clone)]
pub enum AnonymizationMethod {
Randomize,
Hash,
Remove,
Generalize,
}
impl ContentFilter {
pub fn new() -> Self {
Self {
pii_patterns: Self::default_pii_patterns(),
moderation_rules: Self::default_moderation_rules(),
profanity_filter: ProfanityFilter::new(),
custom_filters: Vec::new(),
}
}
pub async fn filter_chat_request(
&self,
request: &mut ChatCompletionRequest,
) -> Result<FilterResult> {
let mut issues = Vec::new();
let mut blocked = false;
let mut modified = false;
for message in &mut request.messages {
if let Some(MessageContent::Text(text)) = &mut message.content {
let result = self.filter_text(text).await?;
if result.blocked {
blocked = true;
}
issues.extend(result.issues);
if let Some(modified_text) = result.modified_content {
*text = modified_text;
modified = true;
}
}
}
let confidence = if issues.is_empty() {
1.0
} else {
issues.iter().map(|i| i.confidence).sum::<f64>() / issues.len() as f64
};
Ok(FilterResult {
blocked,
issues,
modified_content: if modified {
Some("Messages modified".to_string())
} else {
None
},
confidence,
})
}
pub async fn filter_text(&self, text: &str) -> Result<FilterResult> {
let mut issues = Vec::new();
let mut modified_text = text.to_string();
let mut blocked = false;
for pattern in &self.pii_patterns {
if let Some(captures) = pattern.pattern.captures(text) {
let issue = ContentIssue {
issue_type: format!("PII_{}", pattern.name),
description: format!("Detected {} in content", pattern.name),
severity: ModerationSeverity::High,
location: captures.get(0).map(|m| (m.start(), m.end())),
confidence: pattern.confidence,
};
issues.push(issue);
modified_text = self.apply_pii_replacement(&modified_text, pattern)?;
}
}
for rule in &self.moderation_rules {
if self.check_moderation_rule(&modified_text, rule).await? {
let issue = ContentIssue {
issue_type: format!("MODERATION_{:?}", rule.rule_type),
description: format!("Content flagged for {:?}", rule.rule_type),
severity: rule.severity.clone(),
location: None,
confidence: 0.8, };
issues.push(issue);
match rule.action {
ModerationAction::Block => blocked = true,
ModerationAction::Warn => warn!("Content warning: {:?}", rule.rule_type),
_ => {}
}
}
}
if self.profanity_filter.contains_profanity(&modified_text) {
modified_text = self.profanity_filter.filter(&modified_text);
issues.push(ContentIssue {
issue_type: "PROFANITY".to_string(),
description: "Profanity detected and filtered".to_string(),
severity: ModerationSeverity::Medium,
location: None,
confidence: 0.9,
});
}
let confidence = if issues.is_empty() {
1.0
} else {
issues.iter().map(|i| i.confidence).sum::<f64>() / issues.len() as f64
};
Ok(FilterResult {
blocked,
issues,
modified_content: if modified_text != text {
Some(modified_text)
} else {
None
},
confidence,
})
}
fn default_pii_patterns() -> Vec<PIIPattern> {
vec![
PIIPattern {
name: "SSN".to_string(),
pattern: Regex::new(r"\b\d{3}-\d{2}-\d{4}\b").unwrap(),
replacement: PIIReplacement::Placeholder("XXX-XX-XXXX".to_string()),
confidence: 0.95,
},
PIIPattern {
name: "Email".to_string(),
pattern: Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b")
.unwrap(),
replacement: PIIReplacement::PartialMask {
keep_start: 2,
keep_end: 0,
},
confidence: 0.9,
},
PIIPattern {
name: "Phone".to_string(),
pattern: Regex::new(r"\b\d{3}-\d{3}-\d{4}\b").unwrap(),
replacement: PIIReplacement::Placeholder("XXX-XXX-XXXX".to_string()),
confidence: 0.85,
},
PIIPattern {
name: "CreditCard".to_string(),
pattern: Regex::new(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b").unwrap(),
replacement: PIIReplacement::Placeholder("XXXX-XXXX-XXXX-XXXX".to_string()),
confidence: 0.9,
},
]
}
fn default_moderation_rules() -> Vec<ModerationRule> {
vec![
ModerationRule {
name: "Hate Speech".to_string(),
rule_type: ModerationType::HateSpeech,
action: ModerationAction::Block,
severity: ModerationSeverity::High,
},
ModerationRule {
name: "Violence".to_string(),
rule_type: ModerationType::Violence,
action: ModerationAction::Warn,
severity: ModerationSeverity::Medium,
},
]
}
fn apply_pii_replacement(&self, text: &str, pattern: &PIIPattern) -> Result<String> {
let result = match &pattern.replacement {
PIIReplacement::Redact => pattern.pattern.replace_all(text, "***").to_string(),
PIIReplacement::Placeholder(placeholder) => pattern
.pattern
.replace_all(text, placeholder.as_str())
.to_string(),
PIIReplacement::Hash => {
pattern.pattern.replace_all(text, "[HASHED]").to_string()
}
PIIReplacement::Remove => pattern.pattern.replace_all(text, "").to_string(),
PIIReplacement::PartialMask {
keep_start,
keep_end,
} => {
pattern
.pattern
.replace_all(text, |caps: ®ex::Captures| {
let matched = caps.get(0).unwrap().as_str();
let len = matched.len();
if len <= keep_start + keep_end {
"*".repeat(len)
} else {
let start = &matched[..*keep_start];
let end = if *keep_end > 0 {
&matched[len - keep_end..]
} else {
""
};
let middle = "*".repeat(len - keep_start - keep_end);
format!("{}{}{}", start, middle, end)
}
})
.to_string()
}
};
Ok(result)
}
async fn check_moderation_rule(&self, text: &str, rule: &ModerationRule) -> Result<bool> {
match rule.rule_type {
ModerationType::HateSpeech => {
Ok(text.to_lowercase().contains("hate") || text.to_lowercase().contains("racist"))
}
ModerationType::Violence => Ok(
text.to_lowercase().contains("violence") || text.to_lowercase().contains("kill")
),
_ => Ok(false),
}
}
}
impl ProfanityFilter {
    /// Creates a filter with a small built-in word list and `*` masking.
    /// (`fuzzy_matching` is stored but not yet consulted by these methods.)
    pub fn new() -> Self {
        Self {
            blocked_words: vec![
                "badword1".to_string(),
                "badword2".to_string(),
            ],
            replacement_char: '*',
            fuzzy_matching: true,
        }
    }

    /// True when any blocked word occurs, case-insensitively, as a
    /// substring of `text`.
    pub fn contains_profanity(&self, text: &str) -> bool {
        let lower_text = text.to_lowercase();
        self.blocked_words
            .iter()
            .any(|word| lower_text.contains(word))
    }

    /// Masks every occurrence of each blocked word with `replacement_char`.
    ///
    /// Fixed: matching is now case-insensitive so it agrees with
    /// `contains_profanity` — previously "BADWORD1" was detected but left
    /// unfiltered because the replacement was case-sensitive.
    pub fn filter(&self, text: &str) -> String {
        let mut result = text.to_string();
        for word in &self.blocked_words {
            if word.is_empty() {
                continue;
            }
            let mask = self.replacement_char.to_string().repeat(word.len());
            let mut from = 0;
            // ASCII lowercasing never changes byte layout, so an offset
            // found in the lowered copy is valid in `result` as well.
            while let Some(rel) = result[from..].to_ascii_lowercase().find(word.as_str()) {
                let start = from + rel;
                result.replace_range(start..start + word.len(), &mask);
                from = start + mask.len();
            }
        }
        result
    }
}
impl Default for ContentFilter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// PII in the input must yield at least one issue and a scrubbed copy.
    #[tokio::test]
    async fn test_pii_detection() {
        let filter = ContentFilter::new();
        let sample = "My SSN is 123-45-6789 and email is test@example.com";
        let outcome = filter.filter_text(sample).await.unwrap();
        assert!(!outcome.issues.is_empty());
        assert!(outcome.modified_content.is_some());
    }

    /// Blocked words must be both detected and masked out of the text.
    #[tokio::test]
    async fn test_profanity_filter() {
        let profanity = ProfanityFilter::new();
        assert!(profanity.contains_profanity("This contains badword1"));
        let cleaned = profanity.filter("This contains badword1");
        assert!(!cleaned.contains("badword1"));
    }
}