// ai_agent/utils/query_guard.rs
use std::collections::HashSet;
/// Screens incoming query text before it is acted on.
///
/// Holds a set of forbidden substrings and an upper bound on query length;
/// the `validate` and `sanitize` methods apply them.
#[derive(Debug, Clone)]
pub struct QueryGuard {
    /// Substrings whose presence causes a query to be rejected.
    blocked_patterns: HashSet<String>,
    /// Maximum accepted query length.
    max_length: usize,
}
11impl QueryGuard {
12 pub fn new() -> Self {
13 let blocked_patterns = vec![
14 "rm -rf /".to_string(),
15 "format c:".to_string(),
16 "del /f /s /q".to_string(),
17 ]
18 .into_iter()
19 .collect();
20
21 Self {
22 blocked_patterns,
23 max_length: 10000,
24 }
25 }
26
27 pub fn validate(&self, query: &str) -> Result<(), QueryGuardError> {
29 if query.len() > self.max_length {
31 return Err(QueryGuardError::TooLong(query.len()));
32 }
33
34 for pattern in &self.blocked_patterns {
36 if query.contains(pattern) {
37 return Err(QueryGuardError::BlockedPattern(pattern.clone()));
38 }
39 }
40
41 Ok(())
42 }
43
44 pub fn sanitize(&self, query: &str) -> String {
46 let sanitized = query.replace('\0', "");
48
49 sanitized.trim().to_string()
51 }
52}
53
54impl Default for QueryGuard {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
/// Reasons `QueryGuard::validate` can reject a query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryGuardError {
    /// The query exceeded the configured maximum length; carries the measured length.
    TooLong(usize),
    /// The query contained a blocked pattern; carries that pattern.
    BlockedPattern(String),
}
67impl std::fmt::Display for QueryGuardError {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 match self {
70 QueryGuardError::TooLong(len) => write!(f, "Query too long: {} characters", len),
71 QueryGuardError::BlockedPattern(pattern) => {
72 write!(f, "Query contains blocked pattern: {}", pattern)
73 }
74 }
75 }
76}