rs_utcp/
security.rs

1use anyhow::{anyhow, Result};
2use std::path::PathBuf;
3
4/// Security utilities for validating inputs and preventing common vulnerabilities.
5
6/// Validates that a file path is safe and doesn't allow directory traversal.
7///
8/// # Arguments
9/// * `path` - The path to validate
10/// * `allowed_base` - Optional base directory that the path must be within
11///
12/// # Returns
13/// Canonicalized path if valid, error otherwise
14pub fn validate_file_path(path: &str, allowed_base: Option<&str>) -> Result<PathBuf> {
15    let path_buf = PathBuf::from(path);
16
17    // Prevent absolute paths from escaping the base
18    if let Some(base) = allowed_base {
19        let canon_path = std::fs::canonicalize(&path_buf)
20            .map_err(|e| anyhow!("Failed to canonicalize path '{}': {}", path, e))?;
21
22        let canon_base = std::fs::canonicalize(base)
23            .map_err(|e| anyhow!("Failed to canonicalize base '{}': {}", base, e))?;
24
25        if !canon_path.starts_with(canon_base) {
26            return Err(anyhow!(
27                "Path '{}' is outside allowed directory '{}'",
28                path,
29                base
30            ));
31        }
32
33        Ok(canon_path)
34    } else {
35        // Just canonicalize without base restriction
36        std::fs::canonicalize(&path_buf).map_err(|e| anyhow!("Invalid path '{}': {}", path, e))
37    }
38}
39
40/// Validates a command name against an allowlist.
41/// This helps prevent command injection attacks.
42///
43/// # Arguments
44/// * `command` - The command to validate
45/// * `allowed_commands` - List of permitted command names or paths
46///
47/// # Returns
48/// Ok if command is in allowlist, error otherwise
49pub fn validate_command(command: &str, allowed_commands: &[&str]) -> Result<()> {
50    // Check for shell metacharacters that could enable injection
51    const DANGEROUS_CHARS: &[char] = &[
52        '|', '&', ';', '\n', '`', '$', '(', ')', '<', '>', '"', '\'', '\\',
53    ];
54
55    if command.chars().any(|c| DANGEROUS_CHARS.contains(&c)) {
56        return Err(anyhow!(
57            "Command contains dangerous characters: '{}'",
58            command
59        ));
60    }
61
62    // Extract just the command name (first component of path)
63    let path_buf = PathBuf::from(command);
64    let cmd_name = path_buf
65        .file_name()
66        .and_then(|s| s.to_str())
67        .unwrap_or(command);
68
69    // Check against allowlist
70    if !allowed_commands.is_empty() && !allowed_commands.contains(&cmd_name) {
71        return Err(anyhow!(
72            "Command '{}' is not in the allowed list. Permitted commands: {:?}",
73            cmd_name,
74            allowed_commands
75        ));
76    }
77
78    Ok(())
79}
80
81/// Validates command arguments for potentially dangerous content.
82///
83/// # Arguments
84/// * `args` - The arguments to validate
85///
86/// # Returns
87/// Ok if arguments appear safe, error otherwise
88pub fn validate_command_args(args: &[String]) -> Result<()> {
89    for arg in args {
90        // Check for shell injection patterns
91        if arg.contains("&&") || arg.contains("||") || arg.contains(";") || arg.contains("|") {
92            return Err(anyhow!(
93                "Argument contains dangerous shell operators: '{}'",
94                arg
95            ));
96        }
97
98        // Check for command substitution
99        if arg.contains("$(") || arg.contains("`") {
100            return Err(anyhow!("Argument contains command substitution: '{}'", arg));
101        }
102    }
103
104    Ok(())
105}
106
107/// Validates that a URL uses a secure protocol (https://, wss://, etc.)
108///
109/// # Arguments
110/// * `url` - The URL to validate
111/// * `require_tls` - Whether to enforce TLS/SSL
112///
113/// # Returns
114/// Ok if URL is valid and secure, error otherwise
115pub fn validate_url_security(url: &str, require_tls: bool) -> Result<()> {
116    let url_lower = url.to_lowercase();
117
118    if require_tls {
119        if !(url_lower.starts_with("https://")
120            || url_lower.starts_with("wss://")
121            || url_lower.starts_with("grpcs://"))
122        {
123            return Err(anyhow!(
124                "URL must use TLS/SSL (https://, wss://, grpcs://): '{}'",
125                url
126            ));
127        }
128    }
129
130    // Warn about localhost/127.0.0.1 in production (but allow it)
131    if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
132        // This is just informational - don't fail
133        eprintln!("Warning: URL uses localhost/127.0.0.1: '{}'", url);
134    }
135
136    Ok(())
137}
138
139/// Validates the size of input data to prevent DoS attacks.
140///
141/// # Arguments
142/// * `data` - The data to check
143/// * `max_size` - Maximum allowed size in bytes
144///
145/// # Returns
146/// Ok if data is within limits, error otherwise
147pub fn validate_size_limit(data: &[u8], max_size: usize) -> Result<()> {
148    if data.len() > max_size {
149        return Err(anyhow!(
150            "Data size {} bytes exceeds maximum allowed size {} bytes",
151            data.len(),
152            max_size
153        ));
154    }
155
156    Ok(())
157}
158
159/// Validates a timeout value to ensure it's reasonable.
160///
161/// # Arguments
162/// * `timeout_ms` - Timeout in milliseconds
163/// * `max_timeout_ms` - Maximum allowed timeout
164///
165/// # Returns
166/// Ok if timeout is within limits, error otherwise
167pub fn validate_timeout(timeout_ms: u64, max_timeout_ms: u64) -> Result<()> {
168    if timeout_ms == 0 {
169        return Err(anyhow!("Timeout cannot be zero"));
170    }
171
172    if timeout_ms > max_timeout_ms {
173        return Err(anyhow!(
174            "Timeout {}ms exceeds maximum allowed {}ms",
175            timeout_ms,
176            max_timeout_ms
177        ));
178    }
179
180    Ok(())
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_validate_command_rejects_dangerous_chars() {
        let allowed = ["python3", "node"];

        assert!(validate_command("python3", &allowed).is_ok());
        assert!(validate_command("ls; rm -rf /", &[]).is_err());
        assert!(validate_command("cat /etc/passwd | grep root", &[]).is_err());
        assert!(validate_command("echo `whoami`", &[]).is_err());
        assert!(validate_command("cmd && evil", &[]).is_err());
        assert!(validate_command("echo $(id)", &[]).is_err());
        assert!(validate_command("cat < /etc/shadow", &[]).is_err());
    }

    #[test]
    fn test_validate_command_allowlist() {
        let allowed = ["python3", "node", "npm"];

        assert!(validate_command("python3", &allowed).is_ok());
        assert!(validate_command("node", &allowed).is_ok());
        assert!(validate_command("npm", &allowed).is_ok());
        assert!(validate_command("bash", &allowed).is_err());
        assert!(validate_command("/bin/bash", &allowed).is_err());
        // A path is accepted when its basename matches an allowlist entry.
        assert!(validate_command("/usr/bin/python3", &allowed).is_ok());
    }

    #[test]
    fn test_validate_command_args() {
        assert!(validate_command_args(&[]).is_ok());
        assert!(validate_command_args(&["--help".to_string()]).is_ok());
        assert!(validate_command_args(&["-v".to_string(), "file.txt".to_string()]).is_ok());

        assert!(validate_command_args(&["arg && evil".to_string()]).is_err());
        assert!(validate_command_args(&["arg; evil".to_string()]).is_err());
        assert!(validate_command_args(&["$(whoami)".to_string()]).is_err());
        assert!(validate_command_args(&["`id`".to_string()]).is_err());
        assert!(validate_command_args(&["arg | grep".to_string()]).is_err());
        // A single dangerous argument anywhere in the list fails the whole list.
        assert!(
            validate_command_args(&["ok".to_string(), "arg || evil".to_string()]).is_err()
        );
    }

    #[test]
    fn test_validate_url_security() {
        assert!(validate_url_security("https://api.example.com", true).is_ok());
        assert!(validate_url_security("wss://ws.example.com", true).is_ok());
        assert!(validate_url_security("grpcs://grpc.example.com", true).is_ok());
        // The scheme check is case-insensitive.
        assert!(validate_url_security("HTTPS://API.EXAMPLE.COM", true).is_ok());
        assert!(validate_url_security("http://api.example.com", true).is_err());
        assert!(validate_url_security("ws://ws.example.com", true).is_err());
        assert!(validate_url_security("http://api.example.com", false).is_ok());
    }

    #[test]
    fn test_validate_size_limit() {
        let data = [0u8; 1000];
        assert!(validate_size_limit(&data, 2000).is_ok());
        // The limit is inclusive.
        assert!(validate_size_limit(&data, 1000).is_ok());
        assert!(validate_size_limit(&data, 999).is_err());
        assert!(validate_size_limit(&data, 500).is_err());
        assert!(validate_size_limit(&[], 0).is_ok());
    }

    #[test]
    fn test_validate_timeout() {
        assert!(validate_timeout(1000, 60000).is_ok());
        // The maximum is inclusive.
        assert!(validate_timeout(60000, 60000).is_ok());
        assert!(validate_timeout(0, 60000).is_err());
        assert!(validate_timeout(100000, 60000).is_err());
    }

    #[test]
    fn test_validate_file_path() {
        let temp_dir = TempDir::new().unwrap();
        let temp_path = temp_dir.path();
        let base = temp_path.to_str().unwrap();

        // Create a test file
        let test_file = temp_path.join("test.txt");
        fs::write(&test_file, b"test").unwrap();

        // Valid path within base
        assert!(validate_file_path(test_file.to_str().unwrap(), Some(base)).is_ok());
        // Without a base, only existence is required
        assert!(validate_file_path(test_file.to_str().unwrap(), None).is_ok());

        // A missing file cannot be canonicalized
        let missing = temp_path.join("missing.txt");
        assert!(validate_file_path(missing.to_str().unwrap(), Some(base)).is_err());

        // An existing file in another directory must be rejected by the containment check,
        // not merely because it does not exist.
        let other_dir = TempDir::new().unwrap();
        let outside = other_dir.path().join("outside.txt");
        fs::write(&outside, b"outside").unwrap();
        let err = validate_file_path(outside.to_str().unwrap(), Some(base)).unwrap_err();
        assert!(err.to_string().contains("outside allowed directory"));

        // `..` traversal out of the base resolves to the same outside file and is rejected.
        // Both temp dirs are created in the same parent directory.
        let traversal = temp_path
            .join("..")
            .join(other_dir.path().file_name().unwrap())
            .join("outside.txt");
        let err = validate_file_path(traversal.to_str().unwrap(), Some(base)).unwrap_err();
        assert!(err.to_string().contains("outside allowed directory"));
    }
}