// agent_diva_core/security/path.rs

1//! Path validation utilities for security policy
2
3use std::path::{Component, Path, PathBuf};
4
/// Stateless namespace for path checks used by the security policy.
///
/// The type carries no data; every check is exposed as an associated
/// function (for example `PathValidator::contains_null_bytes`). The derives
/// make the marker printable and trivially copyable/constructible so it can
/// be embedded in other types without extra boilerplate.
#[derive(Debug, Clone, Copy, Default)]
pub struct PathValidator;
7
8impl PathValidator {
9    /// Layer 1: Check for null bytes
10    pub fn contains_null_bytes(path: &str) -> bool {
11        path.contains('\0')
12    }
13
14    /// Layer 2: Check for path traversal components (../)
15    pub fn contains_path_traversal(path: &str) -> bool {
16        Path::new(path)
17            .components()
18            .any(|c| matches!(c, Component::ParentDir))
19    }
20
21    /// Layer 3: Check for URL-encoded traversal
22    pub fn contains_url_encoded_traversal(path: &str) -> bool {
23        let lower = path.to_lowercase();
24        lower.contains("..%2f")
25            || lower.contains("%2f..")
26            || lower.contains("..%5c")
27            || lower.contains("%5c..")
28    }
29
30    /// Layer 4: Check for tilde expansion (~user)
31    pub fn starts_with_tilde(path: &str) -> bool {
32        path.starts_with('~')
33    }
34
35    /// Layer 5: Check if path is absolute
36    pub fn is_absolute(path: &str) -> bool {
37        Path::new(path).is_absolute()
38    }
39
40    /// Layer 6: Check against forbidden path prefixes
41    pub fn matches_forbidden_prefix(path: &str, forbidden: &[String]) -> Option<String> {
42        let normalized = path.to_lowercase().replace('\\', "/");
43        for prefix in forbidden {
44            let norm_prefix = prefix.to_lowercase().replace('\\', "/");
45            if normalized.starts_with(&norm_prefix)
46                || normalized.contains(&format!("/{}", norm_prefix))
47            {
48                return Some(prefix.clone());
49            }
50        }
51        None
52    }
53
54    /// Normalize a path for comparison
55    pub fn normalize_path(path: &str) -> String {
56        path.replace('\\', "/")
57            .to_lowercase()
58            .trim_start_matches('/')
59            .to_string()
60    }
61
62    /// Check if a resolved path is within allowed roots
63    pub fn is_within_allowed_roots(resolved: &Path, allowed_roots: &[PathBuf]) -> bool {
64        // Try to canonicalize for comparison
65        let resolved_canonical = if let Ok(c) = resolved.canonicalize() {
66            c
67        } else {
68            resolved.to_path_buf()
69        };
70
71        for root in allowed_roots {
72            let root_canonical = if let Ok(c) = root.canonicalize() {
73                c
74            } else {
75                root.clone()
76            };
77
78            if resolved_canonical.starts_with(&root_canonical) {
79                return true;
80            }
81        }
82
83        false
84    }
85
86    /// Validate that a path doesn't escape the workspace via symlinks
87    pub async fn validate_no_symlink_escape(path: &Path, workspace: &Path) -> Result<(), String> {
88        // Check if the path itself is a symlink
89        if let Ok(meta) = tokio::fs::symlink_metadata(path).await {
90            if meta.file_type().is_symlink() {
91                return Err(format!("Path is a symbolic link: {}", path.display()));
92            }
93        }
94
95        // Check all parent directories
96        let mut current = path.parent();
97        while let Some(parent) = current {
98            if parent.as_os_str().is_empty() || parent == Path::new("/") {
99                break;
100            }
101
102            if let Ok(meta) = tokio::fs::symlink_metadata(parent).await {
103                if meta.file_type().is_symlink() {
104                    // Resolve the symlink and check if it's within workspace
105                    let resolved = tokio::fs::canonicalize(parent).await.map_err(|e| {
106                        format!("Failed to resolve symlink {}: {}", parent.display(), e)
107                    })?;
108
109                    if !resolved.starts_with(workspace) {
110                        return Err(format!(
111                            "Symlink {} escapes workspace (resolves to {})",
112                            parent.display(),
113                            resolved.display()
114                        ));
115                    }
116                }
117            }
118
119            current = parent.parent();
120        }
121
122        Ok(())
123    }
124
125    /// Get the file extension from a path
126    pub fn get_extension(path: &str) -> Option<String> {
127        Path::new(path)
128            .extension()
129            .and_then(|e| e.to_str())
130            .map(|s| s.to_lowercase())
131    }
132
133    /// Check if an extension is in the forbidden list
134    pub fn is_extension_forbidden(ext: &str, forbidden: &[String]) -> bool {
135        let ext_lower = ext.to_lowercase().trim_start_matches('.').to_string();
136        forbidden
137            .iter()
138            .any(|f| f.to_lowercase().trim_start_matches('.') == ext_lower)
139    }
140
141    /// Sanitize a path component for safe use
142    pub fn sanitize_component(component: &str) -> String {
143        component
144            .replace(['/', '\\'], "_")
145            .replace('\0', "")
146            .replace("..", "_")
147            .trim()
148            .to_string()
149    }
150}
151
// Unit tests for `PathValidator`. Each test exercises one check with a few
// representative accepted and rejected inputs; none of them touch real files.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_null_bytes() {
        assert!(PathValidator::contains_null_bytes("/path\0to/file"));
        assert!(!PathValidator::contains_null_bytes("/path/to/file"));
    }

    #[test]
    fn test_path_traversal() {
        assert!(PathValidator::contains_path_traversal("../etc/passwd"));
        assert!(PathValidator::contains_path_traversal("/path/../file"));
        assert!(!PathValidator::contains_path_traversal("/path/to/file"));
    }

    #[test]
    fn test_url_encoded_traversal() {
        assert!(PathValidator::contains_url_encoded_traversal(
            "..%2fetc/passwd"
        ));
        assert!(PathValidator::contains_url_encoded_traversal(
            "%2f..%5cwindows"
        ));
        assert!(!PathValidator::contains_url_encoded_traversal(
            "/path/to/file"
        ));
    }

    #[test]
    fn test_tilde_expansion() {
        assert!(PathValidator::starts_with_tilde("~/.ssh/id_rsa"));
        assert!(PathValidator::starts_with_tilde("~user/file"));
        assert!(!PathValidator::starts_with_tilde("/home/user/file"));
    }

    #[test]
    fn test_is_absolute() {
        // Absolute-path syntax is platform specific; only assert the positive
        // case where a leading slash is known to be absolute.
        if cfg!(unix) {
            assert!(PathValidator::is_absolute("/etc/passwd"));
        }
        assert!(!PathValidator::is_absolute("relative/file"));
        assert!(!PathValidator::is_absolute(""));
    }

    #[test]
    fn test_forbidden_prefix() {
        let forbidden = vec!["/etc".to_string(), "/root".to_string()];
        assert!(PathValidator::matches_forbidden_prefix("/etc/passwd", &forbidden).is_some());
        assert!(PathValidator::matches_forbidden_prefix("/root/.bashrc", &forbidden).is_some());
        assert!(PathValidator::matches_forbidden_prefix("/home/user/file", &forbidden).is_none());
    }

    #[test]
    fn test_normalize_path() {
        assert_eq!(PathValidator::normalize_path("\\Foo\\Bar"), "foo/bar");
        assert_eq!(PathValidator::normalize_path("//Etc/Passwd"), "etc/passwd");
        assert_eq!(PathValidator::normalize_path(""), "");
    }

    #[test]
    fn test_within_allowed_roots() {
        // Paths are deliberately nonexistent so the result does not depend on
        // the machine running the tests.
        let roots = vec![PathBuf::from("/gr_nonexistent_root/ws")];
        assert!(PathValidator::is_within_allowed_roots(
            Path::new("/gr_nonexistent_root/ws/sub/file.txt"),
            &roots
        ));
        assert!(!PathValidator::is_within_allowed_roots(
            Path::new("/gr_nonexistent_root/other/file.txt"),
            &roots
        ));
        assert!(!PathValidator::is_within_allowed_roots(
            Path::new("/gr_nonexistent_root/ws/file.txt"),
            &[]
        ));
    }

    #[test]
    fn test_extension_validation() {
        let forbidden = vec![".exe".to_string(), ".dll".to_string()];
        assert!(PathValidator::is_extension_forbidden("exe", &forbidden));
        assert!(PathValidator::is_extension_forbidden(".exe", &forbidden));
        assert!(!PathValidator::is_extension_forbidden("txt", &forbidden));

        assert_eq!(
            PathValidator::get_extension("/path/to/file.EXE"),
            Some("exe".to_string())
        );
        assert_eq!(PathValidator::get_extension("no_extension"), None);
    }

    #[test]
    fn test_sanitize_component() {
        assert_eq!(PathValidator::sanitize_component("a/b\\c"), "a_b_c");
        assert_eq!(PathValidator::sanitize_component("../secret"), "__secret");
        assert_eq!(PathValidator::sanitize_component("  name  "), "name");
        assert_eq!(PathValidator::sanitize_component("nul\0byte"), "nulbyte");
    }
}