Skip to main content

ai_agent/utils/
query_guard.rs

1//! Query guard utilities for protecting against malicious queries.
2
3use std::collections::HashSet;
4
/// A guard that validates and sanitizes queries.
///
/// Holds a deny-list of textual patterns and an upper bound on query size.
#[derive(Debug, Clone)]
pub struct QueryGuard {
    /// Patterns whose presence in a query makes validation reject it.
    blocked_patterns: HashSet<String>,
    /// Largest accepted query size, compared against `str::len()` (bytes).
    max_length: usize,
}
10
11impl QueryGuard {
12    pub fn new() -> Self {
13        let blocked_patterns = vec![
14            "rm -rf /".to_string(),
15            "format c:".to_string(),
16            "del /f /s /q".to_string(),
17        ]
18        .into_iter()
19        .collect();
20
21        Self {
22            blocked_patterns,
23            max_length: 10000,
24        }
25    }
26
27    /// Validate a query
28    pub fn validate(&self, query: &str) -> Result<(), QueryGuardError> {
29        // Check length
30        if query.len() > self.max_length {
31            return Err(QueryGuardError::TooLong(query.len()));
32        }
33
34        // Check for blocked patterns
35        for pattern in &self.blocked_patterns {
36            if query.contains(pattern) {
37                return Err(QueryGuardError::BlockedPattern(pattern.clone()));
38            }
39        }
40
41        Ok(())
42    }
43
44    /// Sanitize a query
45    pub fn sanitize(&self, query: &str) -> String {
46        // Remove null bytes
47        let sanitized = query.replace('\0', "");
48
49        // Trim whitespace
50        sanitized.trim().to_string()
51    }
52}
53
54impl Default for QueryGuard {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
/// Errors reported when a query fails validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryGuardError {
    /// The query exceeded the maximum allowed length; carries the length
    /// that was measured.
    TooLong(usize),
    /// The query contained a blocked pattern; carries that pattern.
    BlockedPattern(String),
}
66
67impl std::fmt::Display for QueryGuardError {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        match self {
70            QueryGuardError::TooLong(len) => write!(f, "Query too long: {} characters", len),
71            QueryGuardError::BlockedPattern(pattern) => {
72                write!(f, "Query contains blocked pattern: {}", pattern)
73            }
74        }
75    }
76}