echo_core 0.1.2

Core traits and types for the echo-agent framework
Documentation
//! Rule-based guard
//!
//! Supports regex matching, keyword blacklisting, and content length limiting.

use super::{Guard, GuardDirection, GuardResult};
use crate::error::Result;
use futures::future::BoxFuture;
use regex::Regex;

/// Rule guard
///
/// A [`Guard`] that screens content against a fixed rule set: an optional
/// maximum length, a case-insensitive keyword blacklist, and a list of regex
/// patterns. All work is done synchronously inside the returned future.
/// Instances are created with [`RuleGuardBuilder`].
///
/// # Example
///
/// ```rust
/// use echo_core::guard::rule::RuleGuardBuilder;
///
/// let guard = RuleGuardBuilder::new("content-filter")
///     .blocked_keyword("password")
///     .blocked_keyword("token")
///     .blocked_pattern(r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b") // 16-digit card number
///     .max_length(10000)
///     .build();
/// ```
pub struct RuleGuard {
    /// Identifier reported by [`Guard::name`].
    guard_name: String,
    /// Optional upper bound on content size, compared against `str::len` (UTF-8 bytes).
    max_length: Option<usize>,
    /// Keyword blacklist; matched case-insensitively against the content.
    blocked_keywords: Vec<String>,
    /// Regex blacklist; any match blocks the content.
    blocked_patterns: Vec<Regex>,
    /// Directions this guard applies to; an empty list means every direction is checked.
    directions: Vec<GuardDirection>,
}

impl Guard for RuleGuard {
    fn name(&self) -> &str {
        &self.guard_name
    }

    /// Check `content` against the configured rules.
    ///
    /// Rules are evaluated in this order: direction filter, length limit,
    /// keyword blacklist, regex patterns. The first failing rule yields
    /// [`GuardResult::Block`] with a human-readable reason; if none fail,
    /// [`GuardResult::Pass`] is returned. This implementation never returns `Err`.
    fn check<'a>(
        &'a self,
        content: &'a str,
        direction: GuardDirection,
    ) -> BoxFuture<'a, Result<GuardResult>> {
        Box::pin(async move {
            // An empty direction list means "apply to every direction".
            if !self.directions.is_empty() && !self.directions.contains(&direction) {
                return Ok(GuardResult::Pass);
            }

            // Length is measured in UTF-8 bytes (`str::len`), not characters. It is
            // checked first so oversized input is rejected before any allocation below.
            if let Some(max_len) = self.max_length
                && content.len() > max_len
            {
                return Ok(GuardResult::Block {
                    reason: format!("Content length {} exceeds limit {}", content.len(), max_len),
                });
            }

            // Only build the lowercase copy when there is something to compare it with.
            if !self.blocked_keywords.is_empty() {
                let content_lower = content.to_lowercase();
                // Empty keywords are skipped: `str::contains("")` is always true, so an
                // empty entry would otherwise block every message.
                let hit = self
                    .blocked_keywords
                    .iter()
                    .filter(|keyword| !keyword.is_empty())
                    .find(|keyword| content_lower.contains(&keyword.to_lowercase()));
                if let Some(keyword) = hit {
                    return Ok(GuardResult::Block {
                        reason: format!("Content contains blocked keyword: {keyword}"),
                    });
                }
            }

            if let Some(pattern) = self.blocked_patterns.iter().find(|p| p.is_match(content)) {
                return Ok(GuardResult::Block {
                    reason: format!("Content matches blocked pattern: {}", pattern.as_str()),
                });
            }

            Ok(GuardResult::Pass)
        })
    }
}

/// Builder for [`RuleGuard`].
///
/// Start with [`RuleGuardBuilder::new`], chain the rule-adding methods, and
/// finish with [`RuleGuardBuilder::build`].
pub struct RuleGuardBuilder {
    /// Name handed to the built guard.
    name: String,
    /// Directions the guard is restricted to (empty = all).
    directions: Vec<GuardDirection>,
    /// Optional content length limit.
    max_length: Option<usize>,
    /// Keywords that block content when present.
    blocked_keywords: Vec<String>,
    /// Compiled regexes that block content when matched.
    blocked_patterns: Vec<Regex>,
}

impl RuleGuardBuilder {
    /// Create a new rule guard builder
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            blocked_patterns: Vec::new(),
            blocked_keywords: Vec::new(),
            max_length: None,
            directions: Vec::new(),
        }
    }

    /// Add a regex match rule (block on match)
    pub fn blocked_pattern(mut self, pattern: &str) -> Self {
        if let Ok(regex) = Regex::new(pattern) {
            self.blocked_patterns.push(regex);
        } else {
            tracing::warn!(pattern = pattern, "Invalid regex pattern, ignored");
        }
        self
    }

    /// Add a keyword to the blacklist (case-insensitive, block on match)
    pub fn blocked_keyword(mut self, keyword: impl Into<String>) -> Self {
        self.blocked_keywords.push(keyword.into());
        self
    }

    /// Batch-add keywords to the blacklist
    pub fn blocked_keywords(mut self, keywords: Vec<String>) -> Self {
        self.blocked_keywords.extend(keywords);
        self
    }

    /// Set maximum content length
    pub fn max_length(mut self, max: usize) -> Self {
        self.max_length = Some(max);
        self
    }

    /// Restrict check direction (checks all directions if not set)
    pub fn direction(mut self, direction: GuardDirection) -> Self {
        self.directions.push(direction);
        self
    }

    /// Build the rule guard
    pub fn build(self) -> RuleGuard {
        RuleGuard {
            guard_name: self.name,
            blocked_patterns: self.blocked_patterns,
            blocked_keywords: self.blocked_keywords,
            max_length: self.max_length,
            directions: self.directions,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Run `guard` on `content` in `direction` and report whether it blocked.
    async fn blocks(guard: &RuleGuard, content: &str, direction: GuardDirection) -> bool {
        guard
            .check(content, direction)
            .await
            .unwrap()
            .is_blocked()
    }

    #[tokio::test]
    async fn test_keyword_block() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("password")
            .build();

        assert!(blocks(&guard, "Please tell me your password", GuardDirection::Input).await);
    }

    #[tokio::test]
    async fn test_pattern_block() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_pattern(r"\d{4}-\d{4}-\d{4}-\d{4}")
            .build();

        let text = "Card number is 1234-5678-9012-3456";
        assert!(blocks(&guard, text, GuardDirection::Output).await);
    }

    #[tokio::test]
    async fn test_max_length() {
        let guard = RuleGuardBuilder::new("test").max_length(10).build();

        let text = "This text exceeds the ten character limit";
        assert!(blocks(&guard, text, GuardDirection::Input).await);
    }

    #[tokio::test]
    async fn test_pass() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("forbidden")
            .max_length(1000)
            .build();

        assert!(!blocks(&guard, "Normal content", GuardDirection::Input).await);
    }

    #[tokio::test]
    async fn test_direction_filter() {
        let guard = RuleGuardBuilder::new("test")
            .blocked_keyword("secret")
            .direction(GuardDirection::Output)
            .build();

        // Input direction is not checked: the guard is restricted to Output.
        assert!(!blocks(&guard, "secret content", GuardDirection::Input).await);

        // Output direction is checked.
        assert!(blocks(&guard, "secret content", GuardDirection::Output).await);
    }
}