1#![allow(dead_code)]
12
13use std::io::Read;
14use std::path::{Component, Path, PathBuf};
15
16use path_clean::PathClean;
17use thiserror::Error;
18
/// Failures reported by [`extract_safe`] and its helper checks.
#[derive(Debug, Error)]
pub enum Error {
    /// An archive entry's path is absolute, contains a `..`, root, or prefix
    /// component, or resolves to a location outside the extraction target.
    /// `entry_path` is the path as read from the archive entry.
    #[error("tarball path traversal attempt: {entry_path}")]
    Escape { entry_path: PathBuf },

    /// An archive entry is neither a regular file nor a directory.
    /// `kind` is the short label produced by `classify_entry_type`.
    #[error("unsupported tarball entry type '{kind}': {entry_path}")]
    UnsupportedEntry {
        entry_path: PathBuf,
        kind: &'static str,
    },

    /// A symlink was found among the target's ancestors or its direct
    /// children before extraction started. `path` is the symlink's location.
    #[error("pre-existing symlink in target tree: {path}")]
    PrePlantedSymlink { path: PathBuf },

    /// Any underlying I/O failure, converted automatically via `?`.
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
44
45pub fn extract_safe(source: impl Read, target: &Path) -> Result<(), Error> {
66 let canon_target = target.canonicalize()?;
68
69 check_ancestors_not_symlinks(&canon_target)?;
71
72 check_children_not_symlinks(&canon_target)?;
74
75 let gz = flate2::read::GzDecoder::new(source);
77 let mut archive = tar::Archive::new(gz);
78
79 for entry_result in archive.entries()? {
80 let mut entry = entry_result?;
81 let header = entry.header();
82
83 let entry_type = header.entry_type();
85 if !entry_type.is_file() && !entry_type.is_dir() {
86 let kind = classify_entry_type(entry_type);
87 let entry_path = entry.path()?.into_owned();
88 return Err(Error::UnsupportedEntry { entry_path, kind });
89 }
90
91 let raw_path = entry.path()?.into_owned();
93
94 if raw_path.is_absolute() {
96 return Err(Error::Escape {
97 entry_path: raw_path,
98 });
99 }
100
101 for component in raw_path.components() {
103 match component {
104 Component::ParentDir | Component::RootDir | Component::Prefix(_) => {
105 return Err(Error::Escape {
106 entry_path: raw_path,
107 });
108 }
109 _ => {}
110 }
111 }
112
113 let resolved = canon_target.join(&raw_path).clean();
115 if !starts_with_normalized(&resolved, &canon_target) {
116 return Err(Error::Escape {
117 entry_path: raw_path,
118 });
119 }
120
121 entry.unpack_in(&canon_target)?;
123 }
124
125 Ok(())
126}
127
128fn check_ancestors_not_symlinks(path: &Path) -> Result<(), Error> {
130 let mut current = path.to_path_buf();
131 while let Some(parent) = current.parent() {
132 if parent.as_os_str().is_empty() {
133 break;
134 }
135 if parent.symlink_metadata()?.file_type().is_symlink() {
136 return Err(Error::PrePlantedSymlink {
137 path: parent.to_path_buf(),
138 });
139 }
140 current = parent.to_path_buf();
141 }
142 Ok(())
143}
144
145fn check_children_not_symlinks(path: &Path) -> Result<(), Error> {
147 if !path.is_dir() {
148 return Ok(());
149 }
150 for entry in std::fs::read_dir(path)? {
151 let entry = entry?;
152 if entry.metadata()?.file_type().is_symlink() {
153 return Err(Error::PrePlantedSymlink { path: entry.path() });
154 }
155 }
156 Ok(())
157}
158
/// Reports whether `path` lies at or beneath `prefix`.
///
/// The comparison is made per path component, never on raw characters, so a
/// sibling such as `dest-evil` is not mistaken for being inside `dest`.
/// On Windows components are compared case-insensitively; on every other
/// platform this is exactly [`Path::starts_with`].
fn starts_with_normalized(path: &Path, prefix: &Path) -> bool {
    components_start_with(path, prefix, cfg!(windows))
}

/// Component-wise prefix test with optional case folding.
///
/// With `ignore_case == false` this defers to [`Path::starts_with`]. With
/// `ignore_case == true` each component of `prefix` must equal the component
/// at the same position in `path` after lossy UTF-8 conversion and
/// lowercasing.
fn components_start_with(path: &Path, prefix: &Path, ignore_case: bool) -> bool {
    fn folded(part: Component<'_>) -> String {
        part.as_os_str().to_string_lossy().to_lowercase()
    }

    if !ignore_case {
        return path.starts_with(prefix);
    }

    let mut actual = path.components();
    prefix.components().all(|expected| match actual.next() {
        Some(found) => folded(found) == folded(expected),
        // `path` ran out of components before `prefix` did.
        None => false,
    })
}
173
174fn classify_entry_type(t: tar::EntryType) -> &'static str {
176 match t {
177 tar::EntryType::Symlink => "symlink",
178 tar::EntryType::Link => "hardlink",
179 tar::EntryType::Char => "char device",
180 tar::EntryType::Block => "block device",
181 tar::EntryType::Fifo => "fifo",
182 tar::EntryType::GNUSparse => "sparse",
183 _ => "unknown",
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tempfile::TempDir;

    /// Gzip-compresses raw tar bytes so they can be fed to `extract_safe`.
    fn gzip(tar_bytes: &[u8]) -> Vec<u8> {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(tar_bytes).unwrap();
        encoder.finish().unwrap()
    }

    /// Creates a temporary directory holding an empty `dest` directory.
    ///
    /// The `TempDir` guard is returned alongside the path so the directory
    /// stays alive for the duration of the test.
    fn fresh_target() -> (TempDir, PathBuf) {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("dest");
        std::fs::create_dir_all(&target).unwrap();
        (tmp, target)
    }

    /// Builds a `.tar.gz` from `(path, content, is_dir)` triples.
    /// `content` is ignored for directory entries.
    fn build_tar_gz(entries: &[(&str, &[u8], bool)]) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        for &(path, content, is_dir) in entries {
            let (kind, mode, data): (tar::EntryType, u32, &[u8]) = if is_dir {
                (tar::EntryType::Directory, 0o755, &[])
            } else {
                (tar::EntryType::Regular, 0o644, content)
            };

            let mut header = tar::Header::new_gnu();
            header.set_entry_type(kind);
            header.set_size(data.len() as u64);
            header.set_mode(mode);
            header.set_cksum();
            builder.append_data(&mut header, path, data).unwrap();
        }
        gzip(&builder.into_inner().unwrap())
    }

    /// Builds a `.tar.gz` containing a single link entry of the given kind.
    fn build_tar_gz_with_link(
        kind: tar::EntryType,
        mode: u32,
        link_name: &str,
        target: &str,
    ) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(kind);
        header.set_size(0);
        header.set_mode(mode);
        header.set_cksum();
        builder.append_link(&mut header, link_name, target).unwrap();
        gzip(&builder.into_inner().unwrap())
    }

    /// Builds a `.tar.gz` containing one symlink entry.
    fn build_tar_gz_with_symlink(link_name: &str, target: &str) -> Vec<u8> {
        build_tar_gz_with_link(tar::EntryType::Symlink, 0o777, link_name, target)
    }

    /// Builds a `.tar.gz` containing one hardlink entry.
    fn build_tar_gz_with_hardlink(link_name: &str, target: &str) -> Vec<u8> {
        build_tar_gz_with_link(tar::EntryType::Link, 0o644, link_name, target)
    }

    /// Builds a `.tar.gz` whose single file entry carries `malicious_path`.
    ///
    /// The archive is first written with a harmless placeholder name, then
    /// the name bytes are overwritten in the raw header and the header
    /// checksum is recomputed.
    fn build_tar_gz_with_forged_path(malicious_path: &str) -> Vec<u8> {
        let placeholder = "PLACEHOLDER_PATH_FOR_FORGING";
        assert!(
            malicious_path.len() <= placeholder.len(),
            "malicious path too long for placeholder"
        );

        let content = b"evil";
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(tar::EntryType::Regular);
        header.set_size(content.len() as u64);
        header.set_mode(0o644);
        header.set_cksum();
        builder
            .append_data(&mut header, placeholder, &content[..])
            .unwrap();
        let mut tar_bytes = builder.into_inner().unwrap();

        // Fail loudly if the placeholder is missing: an unforged archive
        // would make the calling test exercise the wrong thing.
        let pos = tar_bytes
            .windows(placeholder.len())
            .position(|w| w == placeholder.as_bytes())
            .expect("placeholder path should appear in the generated tar header");

        // Zero the old name, then write the forged one over it.
        let name_field = &mut tar_bytes[pos..pos + placeholder.len()];
        name_field.fill(0);
        name_field[..malicious_path.len()].copy_from_slice(malicious_path.as_bytes());

        // Recompute the header checksum. The checksum field itself
        // (bytes 148..156) counts as eight ASCII spaces.
        let header_start = pos - (pos % 512);
        let raw_header = &mut tar_bytes[header_start..header_start + 512];
        let sum: u64 = raw_header
            .iter()
            .enumerate()
            .map(|(i, &b)| {
                if (148..156).contains(&i) {
                    0x20
                } else {
                    u64::from(b)
                }
            })
            .sum();
        let cksum = format!("{sum:06o}\0 ");
        raw_header[148..156].copy_from_slice(cksum.as_bytes());

        gzip(&tar_bytes)
    }

    #[test]
    fn extracts_regular_files_to_target() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/manifest.json", b"{\"version\":\"0.1\"}", false),
        ]);

        extract_safe(Cursor::new(gz), &target).unwrap();

        let manifest = target.join("core/manifest.json");
        assert!(manifest.exists(), "manifest.json should exist");
        assert_eq!(
            std::fs::read_to_string(manifest).unwrap(),
            "{\"version\":\"0.1\"}"
        );
    }

    #[test]
    fn extracts_nested_directories() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/skills/", b"", true),
            ("core/skills/query-installation.md", b"# Query", false),
        ]);

        extract_safe(Cursor::new(gz), &target).unwrap();

        let file = target.join("core/skills/query-installation.md");
        assert!(file.exists());
        assert_eq!(std::fs::read_to_string(file).unwrap(), "# Query");
    }

    #[test]
    fn empty_tarball_succeeds() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[]);
        extract_safe(Cursor::new(gz), &target).unwrap();
    }

    #[test]
    fn rejects_dotdot_path_traversal() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_forged_path("../../etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    #[test]
    fn rejects_absolute_path() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_forged_path("/tmp/evil");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    #[test]
    fn rejects_symlink_entry() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_symlink("evil-link", "/etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "symlink"),
            other => panic!("Expected UnsupportedEntry(symlink), got: {other:?}"),
        }
    }

    #[test]
    fn rejects_hardlink_entry() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_hardlink("evil-link", "core/manifest.json");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "hardlink"),
            other => panic!("Expected UnsupportedEntry(hardlink), got: {other:?}"),
        }
    }

    #[test]
    fn rejects_pre_planted_symlink_child() {
        let (tmp, target) = fresh_target();

        let decoy = tmp.path().join("decoy");
        std::fs::create_dir_all(&decoy).unwrap();

        let symlink_path = target.join("evil");
        #[cfg(unix)]
        std::os::unix::fs::symlink(&decoy, &symlink_path).unwrap();
        #[cfg(windows)]
        {
            if std::os::windows::fs::symlink_dir(&decoy, &symlink_path).is_err() {
                eprintln!("Skipping: symlink creation requires Developer Mode");
                return;
            }
        }

        let gz = build_tar_gz(&[("core/manifest.json", b"data", false)]);
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::PrePlantedSymlink { .. }),
            "Expected PrePlantedSymlink, got: {err:?}"
        );
    }

    #[test]
    fn escape_error_names_path() {
        let err = Error::Escape {
            entry_path: PathBuf::from("../../etc/passwd"),
        };
        assert!(err.to_string().contains("../../etc/passwd"));
    }

    #[test]
    fn unsupported_entry_names_kind_and_path() {
        let err = Error::UnsupportedEntry {
            entry_path: PathBuf::from("evil-link"),
            kind: "symlink",
        };
        let display = err.to_string();
        assert!(display.contains("symlink"));
        assert!(display.contains("evil-link"));
    }

    #[test]
    fn pre_planted_symlink_error_names_path() {
        let err = Error::PrePlantedSymlink {
            path: PathBuf::from("/some/path"),
        };
        assert!(err.to_string().contains("/some/path"));
    }
}