1use anyhow::{Context, Result, bail};
9use std::path::{Component, Path, PathBuf};
10
11pub fn sanitize_filename(name: &str) -> Result<String> {
16 let cleaned: String = name
18 .chars()
19 .filter(|c| !c.is_control() || *c == ' ')
20 .collect();
21
22 let path = Path::new(&cleaned);
23
24 let filename = path
26 .components()
27 .filter_map(|c| match c {
28 Component::Normal(s) => s.to_str(),
29 _ => None,
30 })
31 .next_back();
32
33 match filename {
34 Some(f) if !f.is_empty() => Ok(f.to_string()),
35 _ => bail!("Filename is empty or contains only traversal components"),
36 }
37}
38
39pub fn validate_path_within(target: &Path, base_dir: &Path) -> Result<PathBuf> {
41 let canon_base = base_dir
42 .canonicalize()
43 .with_context(|| format!("Cannot canonicalize base dir: {}", base_dir.display()))?;
44 let canon_target = target
45 .canonicalize()
46 .with_context(|| format!("Cannot canonicalize target: {}", target.display()))?;
47
48 if canon_target.starts_with(&canon_base) {
49 Ok(canon_target)
50 } else {
51 bail!(
52 "Path '{}' escapes base directory '{}'",
53 target.display(),
54 base_dir.display()
55 )
56 }
57}
58
59pub fn safe_join(base_dir: &Path, untrusted_name: &str) -> Result<PathBuf> {
61 let safe_name = sanitize_filename(untrusted_name)?;
62 let joined = base_dir.join(&safe_name);
63
64 if base_dir.exists() {
67 if !joined.exists() {
69 if let Some(parent) = joined.parent() {
73 let canon_parent = parent
74 .canonicalize()
75 .with_context(|| format!("Cannot canonicalize parent: {}", parent.display()))?;
76 let canon_base = base_dir.canonicalize()?;
77 if !canon_parent.starts_with(&canon_base) {
78 bail!(
79 "Path '{}' escapes base directory '{}'",
80 joined.display(),
81 base_dir.display()
82 );
83 }
84 }
85 } else {
86 validate_path_within(&joined, base_dir)?;
87 }
88 }
89
90 Ok(joined)
91}
92
/// Creates `path` as a directory, along with any missing parents.
///
/// On Unix, every directory this call creates is requested with mode `0o700`
/// (owner-only). The process umask is still applied by the OS. On other
/// platforms this is a plain [`std::fs::create_dir_all`] with no permission
/// handling.
///
/// A directory that already exists is not an error, and its existing
/// permissions are left untouched.
///
/// # Errors
///
/// Propagates any I/O error raised while creating the directories.
pub fn create_dir_restricted(path: &Path) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::DirBuilderExt;

        // `recursive` makes parents get the same mode as the leaf directory.
        let mut builder = std::fs::DirBuilder::new();
        builder.recursive(true).mode(0o700);
        builder.create(path)?;
    }

    #[cfg(not(unix))]
    {
        std::fs::create_dir_all(path)?;
    }

    Ok(())
}
114
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    #[test]
    fn test_normal_filename_unchanged() {
        assert_eq!(sanitize_filename("report.pdf").unwrap(), "report.pdf");
    }

    #[test]
    fn test_traversal_etc_passwd() {
        assert_eq!(sanitize_filename("../../etc/passwd").unwrap(), "passwd");
    }

    #[test]
    fn test_absolute_path_stripped() {
        assert_eq!(sanitize_filename("/etc/shadow").unwrap(), "shadow");
    }

    #[test]
    fn test_windows_traversal() {
        assert_eq!(
            sanitize_filename("..\\..\\windows\\system32\\config").unwrap(),
            if cfg!(windows) {
                "config".to_string()
            } else {
                // On Unix a backslash is an ordinary filename character, so the
                // whole string is a single (harmless) component.
                "..\\..\\windows\\system32\\config".to_string()
            }
        );
    }

    #[test]
    fn test_empty_string_errors() {
        assert!(sanitize_filename("").is_err());
    }

    #[test]
    fn test_dotdot_alone_errors() {
        assert!(sanitize_filename("..").is_err());
    }

    #[test]
    fn test_root_only_errors() {
        assert!(sanitize_filename("/").is_err());
    }

    #[test]
    fn test_nul_bytes_stripped() {
        assert_eq!(sanitize_filename("file\0name.txt").unwrap(), "filename.txt");
    }

    #[test]
    fn test_control_chars_stripped_spaces_kept() {
        assert_eq!(sanitize_filename("a b\tc\n.txt").unwrap(), "a bc.txt");
    }

    #[test]
    fn test_unicode_preserved() {
        assert_eq!(
            sanitize_filename("日本語ファイル.txt").unwrap(),
            "日本語ファイル.txt"
        );
    }

    #[test]
    fn test_validate_path_within_rejects_escape() {
        let tmp = tempfile::tempdir().unwrap();
        let base = tmp.path().join("base");
        // Shares `base` as a string prefix; component-wise comparison must
        // still treat it as outside.
        let sibling = tmp.path().join("base2");
        fs::create_dir(&base).unwrap();
        fs::create_dir(&sibling).unwrap();

        let inside = base.join("safe.txt");
        fs::write(&inside, "ok").unwrap();
        assert!(validate_path_within(&inside, &base).is_ok());

        // Use a file that really exists, so the rejection comes from the
        // containment check rather than from a canonicalization failure.
        let outside = sibling.join("secret.txt");
        fs::write(&outside, "no").unwrap();
        assert!(validate_path_within(&outside, &base).is_err());

        // `..` segments are resolved before the containment check.
        let dotted = base.join("..").join("base2").join("secret.txt");
        assert!(validate_path_within(&dotted, &base).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_validate_path_within_rejects_symlink_escape() {
        let tmp = tempfile::tempdir().unwrap();
        let base = tmp.path().join("base");
        fs::create_dir(&base).unwrap();
        let secret = tmp.path().join("secret.txt");
        fs::write(&secret, "no").unwrap();

        let link = base.join("link.txt");
        std::os::unix::fs::symlink(&secret, &link).unwrap();
        assert!(validate_path_within(&link, &base).is_err());
    }

    #[test]
    fn test_safe_join_combines_correctly() {
        let tmp = tempfile::tempdir().unwrap();
        let base = tmp.path();

        let result = safe_join(base, "report.pdf").unwrap();
        assert_eq!(result, base.join("report.pdf"));
    }

    #[test]
    fn test_safe_join_existing_file_inside_base() {
        let tmp = tempfile::tempdir().unwrap();
        let base = tmp.path();
        fs::write(base.join("existing.txt"), "ok").unwrap();

        let result = safe_join(base, "existing.txt").unwrap();
        assert_eq!(result, base.join("existing.txt"));
    }

    #[test]
    fn test_safe_join_strips_traversal() {
        let tmp = tempfile::tempdir().unwrap();
        let base = tmp.path();

        let result = safe_join(base, "../../etc/passwd").unwrap();
        assert_eq!(result, base.join("passwd"));
    }

    #[cfg(unix)]
    #[test]
    fn test_safe_join_rejects_existing_symlink_escape() {
        let tmp = tempfile::tempdir().unwrap();
        let base = tmp.path().join("base");
        fs::create_dir(&base).unwrap();
        let secret = tmp.path().join("secret.txt");
        fs::write(&secret, "no").unwrap();

        std::os::unix::fs::symlink(&secret, base.join("evil")).unwrap();
        assert!(safe_join(&base, "evil").is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_create_dir_restricted_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempfile::tempdir().unwrap();
        let dir = tmp.path().join("restricted");

        create_dir_restricted(&dir).unwrap();

        let perms = fs::metadata(&dir).unwrap().permissions();
        assert_eq!(perms.mode() & 0o777, 0o700);
    }
}