Skip to main content

graphify_security/
path_validator.rs

//! Path traversal prevention and graph file validation.

use std::io::ErrorKind;
use std::path::{Path, PathBuf};

use crate::SecurityError;

7/// Ensure a path stays within an allowed directory (no `../` traversal).
8///
9/// Uses `canonicalize` to resolve symlinks and relative components.
10/// Returns `PathNotFound` for non-existent paths (distinguishable from
11/// `PathTraversal`) and `PathTraversal` for actual escape attempts.
12pub fn safe_path(path: &Path, allowed_root: &Path) -> Result<PathBuf, SecurityError> {
13    let canonical = path.canonicalize().map_err(|e| {
14        if e.kind() == std::io::ErrorKind::NotFound {
15            SecurityError::PathNotFound(path.to_string_lossy().to_string())
16        } else {
17            SecurityError::PathTraversal(path.to_string_lossy().to_string())
18        }
19    })?;
20    let root = allowed_root
21        .canonicalize()
22        .map_err(|_| SecurityError::PathTraversal(allowed_root.to_string_lossy().to_string()))?;
23
24    if canonical.starts_with(&root) {
25        Ok(canonical)
26    } else {
27        Err(SecurityError::PathTraversal(
28            path.to_string_lossy().to_string(),
29        ))
30    }
31}
32
33/// Validate a graph file path: must have a `.json` extension.
34pub fn validate_graph_path(path: &str) -> Result<PathBuf, SecurityError> {
35    let p = PathBuf::from(path);
36    if p.extension().and_then(|e| e.to_str()) != Some("json") {
37        return Err(SecurityError::InvalidPath(
38            "graph file must be .json".into(),
39        ));
40    }
41    Ok(p)
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Scratch directory under the system temp dir that is deleted, with all of
    /// its contents, when dropped. Fixtures are therefore cleaned up even when an
    /// assertion panics part-way through a test.
    struct TempDir(PathBuf);

    impl TempDir {
        /// Create a fresh, empty directory. The process id is part of the name so
        /// concurrent `cargo test` runs never share (or delete) each other's files.
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "graphify_security_test_{}_{}",
                name,
                std::process::id()
            ));
            // Remove leftovers from an earlier run that was killed mid-test.
            let _ = fs::remove_dir_all(&dir);
            fs::create_dir_all(&dir).expect("failed to create temp test directory");
            TempDir(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure here must not mask the test outcome.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_safe_path_within_root() {
        let tmp = TempDir::new("safe");
        let file = tmp.path().join("test.json");
        fs::write(&file, "{}").unwrap();

        let resolved = safe_path(&file, tmp.path()).expect("file inside root must be accepted");
        // The canonical path is returned, not the input as given.
        assert_eq!(resolved, file.canonicalize().unwrap());
    }

    #[test]
    fn test_safe_path_traversal_blocked() {
        let tmp = TempDir::new("traversal");
        let sub = tmp.path().join("sub");
        fs::create_dir_all(&sub).unwrap();
        fs::write(tmp.path().join("secret.txt"), "secret").unwrap();

        let traversal = sub.join("../secret.txt");
        let result = safe_path(&traversal, &sub);
        assert!(matches!(result, Err(SecurityError::PathTraversal(_))));
    }

    #[test]
    fn test_safe_path_sibling_prefix_blocked() {
        // `root_evil` shares a string prefix with `root` but is not inside it;
        // containment must be decided per path component, not per character.
        let tmp = TempDir::new("prefix");
        let root = tmp.path().join("root");
        let sibling = tmp.path().join("root_evil");
        fs::create_dir_all(&root).unwrap();
        fs::create_dir_all(&sibling).unwrap();
        let file = sibling.join("data.json");
        fs::write(&file, "{}").unwrap();

        let result = safe_path(&file, &root);
        assert!(matches!(result, Err(SecurityError::PathTraversal(_))));
    }

    #[cfg(unix)]
    #[test]
    fn test_safe_path_symlink_escape_blocked() {
        let tmp = TempDir::new("symlink");
        let root = tmp.path().join("root");
        fs::create_dir_all(&root).unwrap();
        let outside = tmp.path().join("outside.txt");
        fs::write(&outside, "secret").unwrap();
        let link = root.join("link.txt");
        std::os::unix::fs::symlink(&outside, &link).unwrap();

        // The link itself lives under `root`, but its target does not.
        let result = safe_path(&link, &root);
        assert!(matches!(result, Err(SecurityError::PathTraversal(_))));
    }

    #[test]
    fn test_safe_path_nonexistent_file() {
        // Non-existence is reported before containment is checked.
        let result = safe_path(
            Path::new("/nonexistent/path/file.txt"),
            &std::env::temp_dir(),
        );
        assert!(matches!(result, Err(SecurityError::PathNotFound(_))));
    }

    #[test]
    fn test_safe_path_missing_root_fails_closed() {
        let tmp = TempDir::new("missing_root");
        let file = tmp.path().join("test.json");
        fs::write(&file, "{}").unwrap();

        let result = safe_path(&file, &tmp.path().join("no_such_root"));
        assert!(matches!(result, Err(SecurityError::PathTraversal(_))));
    }

    #[test]
    fn test_validate_graph_path_json() {
        let result = validate_graph_path("output/graph.json");
        assert_eq!(result.unwrap(), PathBuf::from("output/graph.json"));
    }

    #[test]
    fn test_validate_graph_path_non_json() {
        let result = validate_graph_path("output/graph.xml");
        assert!(matches!(result, Err(SecurityError::InvalidPath(_))));
    }

    #[test]
    fn test_validate_graph_path_no_extension() {
        let result = validate_graph_path("output/graph");
        assert!(matches!(result, Err(SecurityError::InvalidPath(_))));
    }

    #[test]
    fn test_validate_graph_path_dot_json_in_middle() {
        let result = validate_graph_path("foo.json.bak");
        assert!(matches!(result, Err(SecurityError::InvalidPath(_))));
    }

    #[test]
    fn test_validate_graph_path_is_case_sensitive() {
        let result = validate_graph_path("output/graph.JSON");
        assert!(matches!(result, Err(SecurityError::InvalidPath(_))));
    }
}