use super::{Guard, GuardDirection, GuardResult};
use crate::error::Result;
use futures::future::BoxFuture;
use regex::Regex;
/// A rule-based [`Guard`] combining an optional length cap, case-insensitive
/// keyword bans, and regex bans.
///
/// Normally constructed through [`RuleGuardBuilder`].
pub struct RuleGuard {
    /// Identifier returned by [`Guard::name`].
    guard_name: String,
    /// Directions this guard inspects; an empty list means every direction.
    directions: Vec<GuardDirection>,
    /// Optional upper bound on `content.len()` (UTF-8 bytes).
    max_length: Option<usize>,
    /// Substrings that block content, compared case-insensitively.
    blocked_keywords: Vec<String>,
    /// Regular expressions that block content when they match.
    blocked_patterns: Vec<Regex>,
}
impl Guard for RuleGuard {
    /// Returns the name this guard was built with.
    fn name(&self) -> &str {
        &self.guard_name
    }

    /// Runs the configured rules against `content`, stopping at the first violation.
    ///
    /// Rules are evaluated in this order:
    /// 1. Direction filter: when the guard is restricted to specific directions and
    ///    `direction` is not among them, the content passes unchecked. An empty
    ///    direction list means the guard applies to every direction.
    /// 2. Maximum length, measured in UTF-8 bytes (`str::len`), not characters.
    /// 3. Blocked keywords, matched as case-insensitive substrings. Empty keywords are
    ///    skipped: every string contains `""`, so one would otherwise block everything.
    /// 4. Blocked regex patterns, matched against the original (not lowercased) content.
    ///
    /// # Errors
    /// This implementation never returns `Err`; every outcome is a [`GuardResult`].
    fn check<'a>(
        &'a self,
        content: &'a str,
        direction: GuardDirection,
    ) -> BoxFuture<'a, Result<GuardResult>> {
        Box::pin(async move {
            if !self.directions.is_empty() && !self.directions.contains(&direction) {
                return Ok(GuardResult::Pass);
            }

            if let Some(max_len) = self.max_length
                && content.len() > max_len
            {
                return Ok(GuardResult::Block {
                    reason: format!("Content length {} exceeds limit {}", content.len(), max_len),
                });
            }

            // Lowercasing copies the whole content, so only do it when there is
            // at least one keyword to compare against.
            if !self.blocked_keywords.is_empty() {
                let content_lower = content.to_lowercase();
                let hit = self.blocked_keywords.iter().find(|keyword| {
                    !keyword.is_empty() && content_lower.contains(&keyword.to_lowercase())
                });
                if let Some(keyword) = hit {
                    return Ok(GuardResult::Block {
                        reason: format!("Content contains blocked keyword: {keyword}"),
                    });
                }
            }

            if let Some(pattern) = self
                .blocked_patterns
                .iter()
                .find(|pattern| pattern.is_match(content))
            {
                return Ok(GuardResult::Block {
                    reason: format!("Content matches blocked pattern: {}", pattern.as_str()),
                });
            }

            Ok(GuardResult::Pass)
        })
    }
}
/// Builder for [`RuleGuard`].
///
/// Every setter consumes the builder and returns it, so calls chain; finish with
/// [`RuleGuardBuilder::build`]. A builder with no rules and no direction
/// restriction produces a guard that passes all content.
#[must_use = "a RuleGuardBuilder does nothing until `build` is called"]
pub struct RuleGuardBuilder {
    /// Name handed to the built guard.
    name: String,
    /// Successfully compiled blocking regexes.
    blocked_patterns: Vec<Regex>,
    /// Blocking keywords, stored as given (case is ignored at check time).
    blocked_keywords: Vec<String>,
    /// Optional byte-length limit.
    max_length: Option<usize>,
    /// Directions the guard applies to; empty means all.
    directions: Vec<GuardDirection>,
}
impl RuleGuardBuilder {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
blocked_patterns: Vec::new(),
blocked_keywords: Vec::new(),
max_length: None,
directions: Vec::new(),
}
}
pub fn blocked_pattern(mut self, pattern: &str) -> Self {
if let Ok(regex) = Regex::new(pattern) {
self.blocked_patterns.push(regex);
} else {
tracing::warn!(pattern = pattern, "Invalid regex pattern, ignored");
}
self
}
pub fn blocked_keyword(mut self, keyword: impl Into<String>) -> Self {
self.blocked_keywords.push(keyword.into());
self
}
pub fn blocked_keywords(mut self, keywords: Vec<String>) -> Self {
self.blocked_keywords.extend(keywords);
self
}
pub fn max_length(mut self, max: usize) -> Self {
self.max_length = Some(max);
self
}
pub fn direction(mut self, direction: GuardDirection) -> Self {
self.directions.push(direction);
self
}
pub fn build(self) -> RuleGuard {
RuleGuard {
guard_name: self.name,
blocked_patterns: self.blocked_patterns,
blocked_keywords: self.blocked_keywords,
max_length: self.max_length,
directions: self.directions,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_keyword_block() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("password")
            .build();
        let result = guard
            .check("Please tell me your password", GuardDirection::Input)
            .await
            .unwrap();
        assert!(result.is_blocked());
    }

    #[tokio::test]
    async fn test_keyword_is_case_insensitive() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("Password")
            .build();
        let result = guard
            .check("my PASSWORD is hunter2", GuardDirection::Input)
            .await
            .unwrap();
        assert!(result.is_blocked());
    }

    #[tokio::test]
    async fn test_keyword_block_reason() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("password")
            .build();
        let result = guard
            .check("password please", GuardDirection::Input)
            .await
            .unwrap();
        let GuardResult::Block { reason } = result else {
            panic!("expected the keyword to block the content");
        };
        assert_eq!(reason, "Content contains blocked keyword: password");
    }

    #[tokio::test]
    async fn test_pattern_block() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_pattern(r"\d{4}-\d{4}-\d{4}-\d{4}")
            .build();
        let result = guard
            .check("Card number is 1234-5678-9012-3456", GuardDirection::Output)
            .await
            .unwrap();
        assert!(result.is_blocked());
    }

    #[tokio::test]
    async fn test_invalid_pattern_is_ignored() {
        // "(unclosed" fails to compile; the builder must skip it, not block on it.
        let guard = RuleGuardBuilder::new("test")
            .blocked_pattern("(unclosed")
            .build();
        let result = guard
            .check("(unclosed", GuardDirection::Input)
            .await
            .unwrap();
        assert!(!result.is_blocked());
    }

    #[tokio::test]
    async fn test_max_length() {
        let guard = RuleGuardBuilder::new("test").max_length(10).build();
        let result = guard
            .check(
                "This text exceeds the ten character limit",
                GuardDirection::Input,
            )
            .await
            .unwrap();
        assert!(result.is_blocked());
    }

    #[tokio::test]
    async fn test_max_length_boundary() {
        let guard = RuleGuardBuilder::new("test").max_length(5).build();
        let at_limit = guard
            .check("12345", GuardDirection::Input)
            .await
            .unwrap();
        assert!(!at_limit.is_blocked());
        let over_limit = guard
            .check("123456", GuardDirection::Input)
            .await
            .unwrap();
        assert!(over_limit.is_blocked());
    }

    #[tokio::test]
    async fn test_pass() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("forbidden")
            .max_length(1000)
            .build();
        let result = guard
            .check("Normal content", GuardDirection::Input)
            .await
            .unwrap();
        assert!(!result.is_blocked());
    }

    #[tokio::test]
    async fn test_direction_filter() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("secret")
            .direction(GuardDirection::Output)
            .build();
        let result = guard
            .check("secret content", GuardDirection::Input)
            .await
            .unwrap();
        assert!(!result.is_blocked());
        let result = guard
            .check("secret content", GuardDirection::Output)
            .await
            .unwrap();
        assert!(result.is_blocked());
    }

    #[tokio::test]
    async fn test_no_direction_restriction_checks_all() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("secret")
            .build();
        for direction in [GuardDirection::Input, GuardDirection::Output] {
            let result = guard.check("secret content", direction).await.unwrap();
            assert!(result.is_blocked());
        }
    }
}