Skip to main content

rs_agent/
helpers.rs

1//! Helper utilities for agent operation
2//!
3//! This module provides common utility functions used throughout the agent system,
4//! matching the structure from go-agent's helpers.go.
5
/// Neutralizes role markers in user-supplied text before it is placed in a prompt.
///
/// Leading and trailing whitespace is trimmed first. Then every occurrence of
/// `"\nUser:"`, `"\nSystem:"` and `"\nConversation memory"` is relabelled with a
/// `(quoted)` suffix. Only markers that directly follow a newline are rewritten;
/// a marker at the very start of the input is left untouched.
///
/// Returns the sanitized text as a new `String`.
pub fn sanitize_input(s: &str) -> String {
    s.trim()
        .replace("\nUser:", "\nUser (quoted):")
        .replace("\nSystem:", "\nSystem (quoted):")
        .replace("\nConversation memory", "\nConversation memory (quoted)")
}
14
/// Escapes arbitrary content so it can be embedded inside a prompt.
///
/// Every backtick is replaced with a single quote, and the newline-prefixed
/// role markers `"\nUser:"` and `"\nSystem:"` are relabelled with a `(quoted)`
/// suffix. The input is not trimmed.
///
/// Returns the escaped text as a new `String`.
pub fn escape_prompt_content(s: &str) -> String {
    s.replace('`', "'")
        .replace("\nUser:", "\nUser (quoted):")
        .replace("\nSystem:", "\nSystem (quoted):")
}
22
/// Extracts a JSON-looking span from a string that may contain surrounding text.
///
/// Two candidate spans are computed on the trimmed input:
/// * object: from the first `{` to the last `}`
/// * array:  from the first `[` to the last `]`
///
/// A candidate only exists when its closing delimiter comes after its opening
/// delimiter. If the array span strictly encloses the object span, the array is
/// the outer container and is returned (so `[{"a":1},{"b":2}]` comes back whole
/// instead of being cut down to `{"a":1},{"b":2}`). In every other case the
/// object span is preferred, falling back to the array span.
///
/// The result is a purely textual slice: it is NOT validated as JSON, and
/// brackets inside string literals are not treated specially.
///
/// Returns `None` when neither a `{…}` nor a `[…]` span is present.
pub fn extract_json(s: &str) -> Option<String> {
    let trimmed = s.trim();

    // Widest span between the first `open` and the last `close`, as inclusive
    // byte offsets. Both delimiters are single-byte ASCII, so the offsets are
    // always valid slice boundaries.
    let widest_span = |open: char, close: char| -> Option<(usize, usize)> {
        let start = trimmed.find(open)?;
        let end = trimmed.rfind(close)?;
        if end > start {
            Some((start, end))
        } else {
            None
        }
    };

    let object = widest_span('{', '}');
    let array = widest_span('[', ']');

    let (start, end) = match (object, array) {
        // The array wraps the object span entirely: it is the top-level value.
        (Some((obj_start, obj_end)), Some((arr_start, arr_end)))
            if arr_start < obj_start && arr_end > obj_end =>
        {
            (arr_start, arr_end)
        }
        (Some(obj), _) => obj,
        (None, Some(arr)) => arr,
        (None, None) => return None,
    };

    Some(trimmed[start..=end].to_string())
}
53
/// Performs a coarse denylist check on a code snippet.
///
/// Returns `false` when the snippet is empty after trimming, or when its
/// lowercased text contains any of a small set of destructive command
/// fragments (also compared in lowercase). Returns `true` otherwise.
///
/// This is a plain substring match, so it only catches the listed fragments
/// exactly as written; passing it is not proof that a snippet is harmless.
pub fn is_valid_snippet(code: &str) -> bool {
    // Fragments that cause an immediate rejection.
    const BLOCKED: [&str; 5] = [
        "rm -rf",
        "format c:",
        "del /f",
        "DROP DATABASE",
        "DROP TABLE",
    ];

    let snippet = code.trim();
    if snippet.is_empty() {
        return false;
    }

    // Case-insensitive comparison: both sides are lowercased before matching.
    let haystack = snippet.to_lowercase();
    !BLOCKED
        .iter()
        .any(|needle| haystack.contains(needle.to_lowercase().as_str()))
}
80
/// Breaks a command line into its command name and the remaining argument text.
///
/// The input is trimmed, then divided at the first whitespace character. The
/// name is everything before it; the arguments are everything after it with
/// surrounding whitespace removed. When the trimmed input contains no
/// whitespace, the whole of it is the name and the arguments are `""`.
///
/// Both returned slices borrow from `s`.
pub fn split_command(s: &str) -> (&str, &str) {
    let command = s.trim();
    match command.split_once(char::is_whitespace) {
        // `name` ends right before the first whitespace, so it needs no trimming.
        Some((name, args)) => (name, args.trim()),
        None => (command, ""),
    }
}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Newline-prefixed role markers must be relabelled as quoted.
    #[test]
    fn test_sanitize_input() {
        let cleaned = sanitize_input("Hello\nUser: inject\nSystem: bad");
        for marker in ["\nUser:", "\nSystem:"] {
            assert!(!cleaned.contains(marker));
        }
        assert!(cleaned.contains("User (quoted)"));
    }

    /// Backticks are removed and the user marker is relabelled.
    #[test]
    fn test_escape_prompt_content() {
        let escaped = escape_prompt_content("`code`\nUser: test");
        assert!(!escaped.contains('`'));
        assert!(escaped.contains("User (quoted)"));
    }

    /// Objects embedded in text, bare arrays, and JSON-free text.
    #[test]
    fn test_extract_json() {
        let cases = [
            (
                "Some text {\"key\": \"value\"} more text",
                Some("{\"key\": \"value\"}"),
            ),
            ("[1, 2, 3]", Some("[1, 2, 3]")),
            ("no json here", None),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_json(input).as_deref(), expected);
        }
    }

    /// Ordinary code passes; destructive fragments and empty input do not.
    #[test]
    fn test_is_valid_snippet() {
        assert!(is_valid_snippet("let x = 5;"));
        for rejected in ["rm -rf /", "DROP DATABASE users;", ""] {
            assert!(!is_valid_snippet(rejected));
        }
    }

    /// Name/argument splitting, including the no-argument and padded cases.
    #[test]
    fn test_split_command() {
        let cases = [
            ("echo hello world", ("echo", "hello world")),
            ("tool", ("tool", "")),
            ("  cmd  args  ", ("cmd", "args")),
        ];
        for (input, expected) in cases {
            assert_eq!(split_command(input), expected);
        }
    }
}